Minor ux improvements

This commit is contained in:
Jing Jin 2025-04-27 11:18:10 -07:00
parent d94fec0674
commit 705b16f062
3 changed files with 60 additions and 28 deletions

View file

@ -19,7 +19,6 @@ package com.google.aiedge.gallery.data
import android.content.Context import android.content.Context
import com.google.aiedge.gallery.ui.common.chat.PromptTemplate import com.google.aiedge.gallery.ui.common.chat.PromptTemplate
import com.google.aiedge.gallery.ui.common.convertValueToTargetType import com.google.aiedge.gallery.ui.common.convertValueToTargetType
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_ACCELERATORS
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
data class ModelDataFile( data class ModelDataFile(
@ -82,9 +81,6 @@ data class Model(
/** The name of the directory to unzip the model to (if it's a zip file). */ /** The name of the directory to unzip the model to (if it's a zip file). */
val unzipDir: String = "", val unzipDir: String = "",
/** The accelerators the the model can run with. */
val accelerators: List<Accelerator> = DEFAULT_ACCELERATORS,
/** The prompt templates for the model (only for LLM). */ /** The prompt templates for the model (only for LLM). */
val llmPromptTemplates: List<PromptTemplate> = listOf(), val llmPromptTemplates: List<PromptTemplate> = listOf(),
@ -243,7 +239,9 @@ val MODEL_LLM_GEMMA_2B_GPU_INT4: Model = Model(
downloadFileName = "gemma-2b-it-gpu-int4.bin", downloadFileName = "gemma-2b-it-gpu-int4.bin",
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma-2b-it-gpu-int4.bin", url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma-2b-it-gpu-int4.bin",
sizeInBytes = 1354301440L, sizeInBytes = 1354301440L,
configs = createLlmChatConfigs(), configs = createLlmChatConfigs(
accelerators = listOf(Accelerator.GPU)
),
showBenchmarkButton = false, showBenchmarkButton = false,
info = LLM_CHAT_INFO, info = LLM_CHAT_INFO,
learnMoreUrl = "https://huggingface.co/litert-community", learnMoreUrl = "https://huggingface.co/litert-community",
@ -254,7 +252,9 @@ val MODEL_LLM_GEMMA_2_2B_GPU_INT8: Model = Model(
downloadFileName = "gemma2-2b-it-gpu-int8.bin", downloadFileName = "gemma2-2b-it-gpu-int8.bin",
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma2-2b-it-gpu-int8.bin", url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma2-2b-it-gpu-int8.bin",
sizeInBytes = 2627141632L, sizeInBytes = 2627141632L,
configs = createLlmChatConfigs(), configs = createLlmChatConfigs(
accelerators = listOf(Accelerator.GPU)
),
showBenchmarkButton = false, showBenchmarkButton = false,
info = LLM_CHAT_INFO, info = LLM_CHAT_INFO,
learnMoreUrl = "https://huggingface.co/litert-community", learnMoreUrl = "https://huggingface.co/litert-community",
@ -265,7 +265,6 @@ val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model(
downloadFileName = "gemma3-1b-it-int4.task", downloadFileName = "gemma3-1b-it-int4.task",
url = "https://huggingface.co/litert-community/Gemma3-1B-IT/resolve/main/gemma3-1b-it-int4.task?download=true", url = "https://huggingface.co/litert-community/Gemma3-1B-IT/resolve/main/gemma3-1b-it-int4.task?download=true",
sizeInBytes = 554661243L, sizeInBytes = 554661243L,
accelerators = listOf(Accelerator.CPU, Accelerator.GPU),
configs = createLlmChatConfigs( configs = createLlmChatConfigs(
defaultTopK = 64, defaultTopK = 64,
defaultTopP = 0.95f, defaultTopP = 0.95f,
@ -293,7 +292,6 @@ val MODEL_LLM_DEEPSEEK: Model = Model(
downloadFileName = "deepseek.task", downloadFileName = "deepseek.task",
url = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B/resolve/main/deepseek_q8_ekv1280.task?download=true", url = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B/resolve/main/deepseek_q8_ekv1280.task?download=true",
sizeInBytes = 1860686856L, sizeInBytes = 1860686856L,
accelerators = listOf(Accelerator.CPU),
configs = createLlmChatConfigs( configs = createLlmChatConfigs(
defaultTemperature = 0.6f, defaultTemperature = 0.6f,
defaultTopK = 40, defaultTopK = 40,

View file

@ -36,7 +36,10 @@ import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
@ -80,8 +83,10 @@ fun LlmSingleTurnScreen(
val selectedModel = modelManagerUiState.selectedModel val selectedModel = modelManagerUiState.selectedModel
val scope = rememberCoroutineScope() val scope = rememberCoroutineScope()
val context = LocalContext.current val context = LocalContext.current
var navigatingUp by remember { mutableStateOf(false) }
val handleNavigateUp = { val handleNavigateUp = {
navigatingUp = true
navigateUp() navigateUp()
// clean up all models. // clean up all models.
@ -100,12 +105,14 @@ fun LlmSingleTurnScreen(
// Initialize model when model/download state changes. // Initialize model when model/download state changes.
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name] val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name]
LaunchedEffect(curDownloadStatus, selectedModel.name) { LaunchedEffect(curDownloadStatus, selectedModel.name) {
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) { if (!navigatingUp) {
Log.d( if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
TAG, Log.d(
"Initializing model '${selectedModel.name}' from LlmsingleTurnScreen launched effect" TAG,
) "Initializing model '${selectedModel.name}' from LlmsingleTurnScreen launched effect"
modelManagerViewModel.initializeModel(context, task = task, model = selectedModel) )
modelManagerViewModel.initializeModel(context, task = task, model = selectedModel)
}
} }
} }

View file

@ -20,6 +20,7 @@ import androidx.compose.foundation.BorderStroke
import androidx.compose.foundation.background import androidx.compose.foundation.background
import androidx.compose.foundation.border import androidx.compose.foundation.border
import androidx.compose.foundation.clickable import androidx.compose.foundation.clickable
import androidx.compose.foundation.interaction.MutableInteractionSource
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
@ -72,8 +73,11 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip import androidx.compose.ui.draw.clip
import androidx.compose.ui.focus.FocusRequester
import androidx.compose.ui.focus.focusRequester
import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.Color
import androidx.compose.ui.platform.LocalClipboardManager import androidx.compose.ui.platform.LocalClipboardManager
import androidx.compose.ui.platform.LocalFocusManager
import androidx.compose.ui.res.dimensionResource import androidx.compose.ui.res.dimensionResource
import androidx.compose.ui.text.TextLayoutResult import androidx.compose.ui.text.TextLayoutResult
import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.text.style.TextOverflow
@ -111,6 +115,9 @@ fun PromptTemplatesPanel(
} }
} }
val clipboardManager = LocalClipboardManager.current val clipboardManager = LocalClipboardManager.current
val focusRequester = remember { FocusRequester() }
val focusManager = LocalFocusManager.current
val interactionSource = remember { MutableInteractionSource() }
val expandedStates = remember { mutableStateMapOf<String, Boolean>() } val expandedStates = remember { mutableStateMapOf<String, Boolean>() }
// Update input editor values when prompt template changes. // Update input editor values when prompt template changes.
@ -127,20 +134,30 @@ fun PromptTemplatesPanel(
Column(modifier = modifier) { Column(modifier = modifier) {
// Scrollable tab row for all prompt templates. // Scrollable tab row for all prompt templates.
PrimaryScrollableTabRow(selectedTabIndex = selectedTabIndex) { PrimaryScrollableTabRow(
selectedTabIndex = selectedTabIndex
) {
TAB_TITLES.forEachIndexed { index, title -> TAB_TITLES.forEachIndexed { index, title ->
Tab(selected = selectedTabIndex == index, onClick = { Tab(selected = selectedTabIndex == index,
// Clear input when tab changes. enabled = !inProgress,
curTextInputContent = "" onClick = {
// Reset full prompt switch. // Clear input when tab changes.
inputEditorValues[FULL_PROMPT_SWITCH_KEY] = false curTextInputContent = ""
// Reset full prompt switch.
inputEditorValues[FULL_PROMPT_SWITCH_KEY] = false
selectedTabIndex = index selectedTabIndex = index
viewModel.selectPromptTemplate( viewModel.selectPromptTemplate(
model = model, model = model,
promptTemplateType = promptTemplateTypes[index] promptTemplateType = promptTemplateTypes[index]
) )
}, text = { Text(text = title) }) },
text = {
Text(
text = title,
modifier = Modifier.alpha(if (inProgress) 0.5f else 1f)
)
})
} }
} }
@ -178,6 +195,13 @@ fun PromptTemplatesPanel(
modifier = Modifier modifier = Modifier
.fillMaxSize() .fillMaxSize()
.verticalScroll(rememberScrollState()) .verticalScroll(rememberScrollState())
.clickable(
interactionSource = interactionSource,
indication = null // Disable the ripple effect
) {
// Request focus on the TextField when the Column is clicked
focusRequester.requestFocus()
}
) { ) {
if (inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean) { if (inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean) {
Text( Text(
@ -186,7 +210,7 @@ fun PromptTemplatesPanel(
modifier = Modifier modifier = Modifier
.fillMaxWidth() .fillMaxWidth()
.padding(16.dp) .padding(16.dp)
.padding(bottom = 32.dp) .padding(bottom = 40.dp)
.clip(MessageBubbleShape(radius = bubbleBorderRadius)) .clip(MessageBubbleShape(radius = bubbleBorderRadius))
.background(MaterialTheme.customColors.agentBubbleBgColor) .background(MaterialTheme.customColors.agentBubbleBgColor)
.padding(16.dp) .padding(16.dp)
@ -205,7 +229,9 @@ fun PromptTemplatesPanel(
), ),
textStyle = MaterialTheme.typography.bodyMedium, textStyle = MaterialTheme.typography.bodyMedium,
placeholder = { Text("Enter content") }, placeholder = { Text("Enter content") },
modifier = Modifier.padding(bottom = 32.dp) modifier = Modifier
.padding(bottom = 40.dp)
.focusRequester(focusRequester)
) )
} }
} }
@ -301,6 +327,7 @@ fun PromptTemplatesPanel(
OutlinedIconButton( OutlinedIconButton(
enabled = !inProgress && curTextInputContent.isNotEmpty(), enabled = !inProgress && curTextInputContent.isNotEmpty(),
onClick = { onClick = {
focusManager.clearFocus()
onSend(fullPrompt.text) onSend(fullPrompt.text)
}, },
colors = IconButtonDefaults.iconButtonColors( colors = IconButtonDefaults.iconButtonColors(