mirror of
https://github.com/google-ai-edge/gallery.git
synced 2025-07-08 07:30:35 -04:00
Minor ux improvements
This commit is contained in:
parent
d94fec0674
commit
705b16f062
3 changed files with 60 additions and 28 deletions
|
@ -19,7 +19,6 @@ package com.google.aiedge.gallery.data
|
|||
import android.content.Context
|
||||
import com.google.aiedge.gallery.ui.common.chat.PromptTemplate
|
||||
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
|
||||
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_ACCELERATORS
|
||||
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
|
||||
|
||||
data class ModelDataFile(
|
||||
|
@ -82,9 +81,6 @@ data class Model(
|
|||
/** The name of the directory to unzip the model to (if it's a zip file). */
|
||||
val unzipDir: String = "",
|
||||
|
||||
/** The accelerators the the model can run with. */
|
||||
val accelerators: List<Accelerator> = DEFAULT_ACCELERATORS,
|
||||
|
||||
/** The prompt templates for the model (only for LLM). */
|
||||
val llmPromptTemplates: List<PromptTemplate> = listOf(),
|
||||
|
||||
|
@ -243,7 +239,9 @@ val MODEL_LLM_GEMMA_2B_GPU_INT4: Model = Model(
|
|||
downloadFileName = "gemma-2b-it-gpu-int4.bin",
|
||||
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma-2b-it-gpu-int4.bin",
|
||||
sizeInBytes = 1354301440L,
|
||||
configs = createLlmChatConfigs(),
|
||||
configs = createLlmChatConfigs(
|
||||
accelerators = listOf(Accelerator.GPU)
|
||||
),
|
||||
showBenchmarkButton = false,
|
||||
info = LLM_CHAT_INFO,
|
||||
learnMoreUrl = "https://huggingface.co/litert-community",
|
||||
|
@ -254,7 +252,9 @@ val MODEL_LLM_GEMMA_2_2B_GPU_INT8: Model = Model(
|
|||
downloadFileName = "gemma2-2b-it-gpu-int8.bin",
|
||||
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma2-2b-it-gpu-int8.bin",
|
||||
sizeInBytes = 2627141632L,
|
||||
configs = createLlmChatConfigs(),
|
||||
configs = createLlmChatConfigs(
|
||||
accelerators = listOf(Accelerator.GPU)
|
||||
),
|
||||
showBenchmarkButton = false,
|
||||
info = LLM_CHAT_INFO,
|
||||
learnMoreUrl = "https://huggingface.co/litert-community",
|
||||
|
@ -265,7 +265,6 @@ val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model(
|
|||
downloadFileName = "gemma3-1b-it-int4.task",
|
||||
url = "https://huggingface.co/litert-community/Gemma3-1B-IT/resolve/main/gemma3-1b-it-int4.task?download=true",
|
||||
sizeInBytes = 554661243L,
|
||||
accelerators = listOf(Accelerator.CPU, Accelerator.GPU),
|
||||
configs = createLlmChatConfigs(
|
||||
defaultTopK = 64,
|
||||
defaultTopP = 0.95f,
|
||||
|
@ -293,7 +292,6 @@ val MODEL_LLM_DEEPSEEK: Model = Model(
|
|||
downloadFileName = "deepseek.task",
|
||||
url = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B/resolve/main/deepseek_q8_ekv1280.task?download=true",
|
||||
sizeInBytes = 1860686856L,
|
||||
accelerators = listOf(Accelerator.CPU),
|
||||
configs = createLlmChatConfigs(
|
||||
defaultTemperature = 0.6f,
|
||||
defaultTopK = 40,
|
||||
|
|
|
@ -36,7 +36,10 @@ import androidx.compose.runtime.Composable
|
|||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.rememberCoroutineScope
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
|
@ -80,8 +83,10 @@ fun LlmSingleTurnScreen(
|
|||
val selectedModel = modelManagerUiState.selectedModel
|
||||
val scope = rememberCoroutineScope()
|
||||
val context = LocalContext.current
|
||||
var navigatingUp by remember { mutableStateOf(false) }
|
||||
|
||||
val handleNavigateUp = {
|
||||
navigatingUp = true
|
||||
navigateUp()
|
||||
|
||||
// clean up all models.
|
||||
|
@ -100,12 +105,14 @@ fun LlmSingleTurnScreen(
|
|||
// Initialize model when model/download state changes.
|
||||
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name]
|
||||
LaunchedEffect(curDownloadStatus, selectedModel.name) {
|
||||
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
|
||||
Log.d(
|
||||
TAG,
|
||||
"Initializing model '${selectedModel.name}' from LlmsingleTurnScreen launched effect"
|
||||
)
|
||||
modelManagerViewModel.initializeModel(context, task = task, model = selectedModel)
|
||||
if (!navigatingUp) {
|
||||
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
|
||||
Log.d(
|
||||
TAG,
|
||||
"Initializing model '${selectedModel.name}' from LlmsingleTurnScreen launched effect"
|
||||
)
|
||||
modelManagerViewModel.initializeModel(context, task = task, model = selectedModel)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ import androidx.compose.foundation.BorderStroke
|
|||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.border
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.interaction.MutableInteractionSource
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
|
@ -72,8 +73,11 @@ import androidx.compose.ui.Alignment
|
|||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.alpha
|
||||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.focus.FocusRequester
|
||||
import androidx.compose.ui.focus.focusRequester
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.platform.LocalClipboardManager
|
||||
import androidx.compose.ui.platform.LocalFocusManager
|
||||
import androidx.compose.ui.res.dimensionResource
|
||||
import androidx.compose.ui.text.TextLayoutResult
|
||||
import androidx.compose.ui.text.style.TextOverflow
|
||||
|
@ -111,6 +115,9 @@ fun PromptTemplatesPanel(
|
|||
}
|
||||
}
|
||||
val clipboardManager = LocalClipboardManager.current
|
||||
val focusRequester = remember { FocusRequester() }
|
||||
val focusManager = LocalFocusManager.current
|
||||
val interactionSource = remember { MutableInteractionSource() }
|
||||
val expandedStates = remember { mutableStateMapOf<String, Boolean>() }
|
||||
|
||||
// Update input editor values when prompt template changes.
|
||||
|
@ -127,20 +134,30 @@ fun PromptTemplatesPanel(
|
|||
|
||||
Column(modifier = modifier) {
|
||||
// Scrollable tab row for all prompt templates.
|
||||
PrimaryScrollableTabRow(selectedTabIndex = selectedTabIndex) {
|
||||
PrimaryScrollableTabRow(
|
||||
selectedTabIndex = selectedTabIndex
|
||||
) {
|
||||
TAB_TITLES.forEachIndexed { index, title ->
|
||||
Tab(selected = selectedTabIndex == index, onClick = {
|
||||
// Clear input when tab changes.
|
||||
curTextInputContent = ""
|
||||
// Reset full prompt switch.
|
||||
inputEditorValues[FULL_PROMPT_SWITCH_KEY] = false
|
||||
Tab(selected = selectedTabIndex == index,
|
||||
enabled = !inProgress,
|
||||
onClick = {
|
||||
// Clear input when tab changes.
|
||||
curTextInputContent = ""
|
||||
// Reset full prompt switch.
|
||||
inputEditorValues[FULL_PROMPT_SWITCH_KEY] = false
|
||||
|
||||
selectedTabIndex = index
|
||||
viewModel.selectPromptTemplate(
|
||||
model = model,
|
||||
promptTemplateType = promptTemplateTypes[index]
|
||||
)
|
||||
}, text = { Text(text = title) })
|
||||
selectedTabIndex = index
|
||||
viewModel.selectPromptTemplate(
|
||||
model = model,
|
||||
promptTemplateType = promptTemplateTypes[index]
|
||||
)
|
||||
},
|
||||
text = {
|
||||
Text(
|
||||
text = title,
|
||||
modifier = Modifier.alpha(if (inProgress) 0.5f else 1f)
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -178,6 +195,13 @@ fun PromptTemplatesPanel(
|
|||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.verticalScroll(rememberScrollState())
|
||||
.clickable(
|
||||
interactionSource = interactionSource,
|
||||
indication = null // Disable the ripple effect
|
||||
) {
|
||||
// Request focus on the TextField when the Column is clicked
|
||||
focusRequester.requestFocus()
|
||||
}
|
||||
) {
|
||||
if (inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean) {
|
||||
Text(
|
||||
|
@ -186,7 +210,7 @@ fun PromptTemplatesPanel(
|
|||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(16.dp)
|
||||
.padding(bottom = 32.dp)
|
||||
.padding(bottom = 40.dp)
|
||||
.clip(MessageBubbleShape(radius = bubbleBorderRadius))
|
||||
.background(MaterialTheme.customColors.agentBubbleBgColor)
|
||||
.padding(16.dp)
|
||||
|
@ -205,7 +229,9 @@ fun PromptTemplatesPanel(
|
|||
),
|
||||
textStyle = MaterialTheme.typography.bodyMedium,
|
||||
placeholder = { Text("Enter content") },
|
||||
modifier = Modifier.padding(bottom = 32.dp)
|
||||
modifier = Modifier
|
||||
.padding(bottom = 40.dp)
|
||||
.focusRequester(focusRequester)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -301,6 +327,7 @@ fun PromptTemplatesPanel(
|
|||
OutlinedIconButton(
|
||||
enabled = !inProgress && curTextInputContent.isNotEmpty(),
|
||||
onClick = {
|
||||
focusManager.clearFocus()
|
||||
onSend(fullPrompt.text)
|
||||
},
|
||||
colors = IconButtonDefaults.iconButtonColors(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue