mirror of
https://github.com/google-ai-edge/gallery.git
synced 2025-07-08 15:40:31 -04:00
Minor ux improvements
This commit is contained in:
parent
d94fec0674
commit
705b16f062
3 changed files with 60 additions and 28 deletions
|
@ -19,7 +19,6 @@ package com.google.aiedge.gallery.data
|
||||||
import android.content.Context
|
import android.content.Context
|
||||||
import com.google.aiedge.gallery.ui.common.chat.PromptTemplate
|
import com.google.aiedge.gallery.ui.common.chat.PromptTemplate
|
||||||
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
|
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
|
||||||
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_ACCELERATORS
|
|
||||||
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
|
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
|
||||||
|
|
||||||
data class ModelDataFile(
|
data class ModelDataFile(
|
||||||
|
@ -82,9 +81,6 @@ data class Model(
|
||||||
/** The name of the directory to unzip the model to (if it's a zip file). */
|
/** The name of the directory to unzip the model to (if it's a zip file). */
|
||||||
val unzipDir: String = "",
|
val unzipDir: String = "",
|
||||||
|
|
||||||
/** The accelerators the the model can run with. */
|
|
||||||
val accelerators: List<Accelerator> = DEFAULT_ACCELERATORS,
|
|
||||||
|
|
||||||
/** The prompt templates for the model (only for LLM). */
|
/** The prompt templates for the model (only for LLM). */
|
||||||
val llmPromptTemplates: List<PromptTemplate> = listOf(),
|
val llmPromptTemplates: List<PromptTemplate> = listOf(),
|
||||||
|
|
||||||
|
@ -243,7 +239,9 @@ val MODEL_LLM_GEMMA_2B_GPU_INT4: Model = Model(
|
||||||
downloadFileName = "gemma-2b-it-gpu-int4.bin",
|
downloadFileName = "gemma-2b-it-gpu-int4.bin",
|
||||||
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma-2b-it-gpu-int4.bin",
|
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma-2b-it-gpu-int4.bin",
|
||||||
sizeInBytes = 1354301440L,
|
sizeInBytes = 1354301440L,
|
||||||
configs = createLlmChatConfigs(),
|
configs = createLlmChatConfigs(
|
||||||
|
accelerators = listOf(Accelerator.GPU)
|
||||||
|
),
|
||||||
showBenchmarkButton = false,
|
showBenchmarkButton = false,
|
||||||
info = LLM_CHAT_INFO,
|
info = LLM_CHAT_INFO,
|
||||||
learnMoreUrl = "https://huggingface.co/litert-community",
|
learnMoreUrl = "https://huggingface.co/litert-community",
|
||||||
|
@ -254,7 +252,9 @@ val MODEL_LLM_GEMMA_2_2B_GPU_INT8: Model = Model(
|
||||||
downloadFileName = "gemma2-2b-it-gpu-int8.bin",
|
downloadFileName = "gemma2-2b-it-gpu-int8.bin",
|
||||||
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma2-2b-it-gpu-int8.bin",
|
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma2-2b-it-gpu-int8.bin",
|
||||||
sizeInBytes = 2627141632L,
|
sizeInBytes = 2627141632L,
|
||||||
configs = createLlmChatConfigs(),
|
configs = createLlmChatConfigs(
|
||||||
|
accelerators = listOf(Accelerator.GPU)
|
||||||
|
),
|
||||||
showBenchmarkButton = false,
|
showBenchmarkButton = false,
|
||||||
info = LLM_CHAT_INFO,
|
info = LLM_CHAT_INFO,
|
||||||
learnMoreUrl = "https://huggingface.co/litert-community",
|
learnMoreUrl = "https://huggingface.co/litert-community",
|
||||||
|
@ -265,7 +265,6 @@ val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model(
|
||||||
downloadFileName = "gemma3-1b-it-int4.task",
|
downloadFileName = "gemma3-1b-it-int4.task",
|
||||||
url = "https://huggingface.co/litert-community/Gemma3-1B-IT/resolve/main/gemma3-1b-it-int4.task?download=true",
|
url = "https://huggingface.co/litert-community/Gemma3-1B-IT/resolve/main/gemma3-1b-it-int4.task?download=true",
|
||||||
sizeInBytes = 554661243L,
|
sizeInBytes = 554661243L,
|
||||||
accelerators = listOf(Accelerator.CPU, Accelerator.GPU),
|
|
||||||
configs = createLlmChatConfigs(
|
configs = createLlmChatConfigs(
|
||||||
defaultTopK = 64,
|
defaultTopK = 64,
|
||||||
defaultTopP = 0.95f,
|
defaultTopP = 0.95f,
|
||||||
|
@ -293,7 +292,6 @@ val MODEL_LLM_DEEPSEEK: Model = Model(
|
||||||
downloadFileName = "deepseek.task",
|
downloadFileName = "deepseek.task",
|
||||||
url = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B/resolve/main/deepseek_q8_ekv1280.task?download=true",
|
url = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B/resolve/main/deepseek_q8_ekv1280.task?download=true",
|
||||||
sizeInBytes = 1860686856L,
|
sizeInBytes = 1860686856L,
|
||||||
accelerators = listOf(Accelerator.CPU),
|
|
||||||
configs = createLlmChatConfigs(
|
configs = createLlmChatConfigs(
|
||||||
defaultTemperature = 0.6f,
|
defaultTemperature = 0.6f,
|
||||||
defaultTopK = 40,
|
defaultTopK = 40,
|
||||||
|
|
|
@ -36,7 +36,10 @@ import androidx.compose.runtime.Composable
|
||||||
import androidx.compose.runtime.LaunchedEffect
|
import androidx.compose.runtime.LaunchedEffect
|
||||||
import androidx.compose.runtime.collectAsState
|
import androidx.compose.runtime.collectAsState
|
||||||
import androidx.compose.runtime.getValue
|
import androidx.compose.runtime.getValue
|
||||||
|
import androidx.compose.runtime.mutableStateOf
|
||||||
|
import androidx.compose.runtime.remember
|
||||||
import androidx.compose.runtime.rememberCoroutineScope
|
import androidx.compose.runtime.rememberCoroutineScope
|
||||||
|
import androidx.compose.runtime.setValue
|
||||||
import androidx.compose.ui.Alignment
|
import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.platform.LocalContext
|
import androidx.compose.ui.platform.LocalContext
|
||||||
|
@ -80,8 +83,10 @@ fun LlmSingleTurnScreen(
|
||||||
val selectedModel = modelManagerUiState.selectedModel
|
val selectedModel = modelManagerUiState.selectedModel
|
||||||
val scope = rememberCoroutineScope()
|
val scope = rememberCoroutineScope()
|
||||||
val context = LocalContext.current
|
val context = LocalContext.current
|
||||||
|
var navigatingUp by remember { mutableStateOf(false) }
|
||||||
|
|
||||||
val handleNavigateUp = {
|
val handleNavigateUp = {
|
||||||
|
navigatingUp = true
|
||||||
navigateUp()
|
navigateUp()
|
||||||
|
|
||||||
// clean up all models.
|
// clean up all models.
|
||||||
|
@ -100,12 +105,14 @@ fun LlmSingleTurnScreen(
|
||||||
// Initialize model when model/download state changes.
|
// Initialize model when model/download state changes.
|
||||||
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name]
|
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name]
|
||||||
LaunchedEffect(curDownloadStatus, selectedModel.name) {
|
LaunchedEffect(curDownloadStatus, selectedModel.name) {
|
||||||
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
|
if (!navigatingUp) {
|
||||||
Log.d(
|
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
|
||||||
TAG,
|
Log.d(
|
||||||
"Initializing model '${selectedModel.name}' from LlmsingleTurnScreen launched effect"
|
TAG,
|
||||||
)
|
"Initializing model '${selectedModel.name}' from LlmsingleTurnScreen launched effect"
|
||||||
modelManagerViewModel.initializeModel(context, task = task, model = selectedModel)
|
)
|
||||||
|
modelManagerViewModel.initializeModel(context, task = task, model = selectedModel)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,7 @@ import androidx.compose.foundation.BorderStroke
|
||||||
import androidx.compose.foundation.background
|
import androidx.compose.foundation.background
|
||||||
import androidx.compose.foundation.border
|
import androidx.compose.foundation.border
|
||||||
import androidx.compose.foundation.clickable
|
import androidx.compose.foundation.clickable
|
||||||
|
import androidx.compose.foundation.interaction.MutableInteractionSource
|
||||||
import androidx.compose.foundation.layout.Arrangement
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
import androidx.compose.foundation.layout.Box
|
import androidx.compose.foundation.layout.Box
|
||||||
import androidx.compose.foundation.layout.Column
|
import androidx.compose.foundation.layout.Column
|
||||||
|
@ -72,8 +73,11 @@ import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.draw.alpha
|
import androidx.compose.ui.draw.alpha
|
||||||
import androidx.compose.ui.draw.clip
|
import androidx.compose.ui.draw.clip
|
||||||
|
import androidx.compose.ui.focus.FocusRequester
|
||||||
|
import androidx.compose.ui.focus.focusRequester
|
||||||
import androidx.compose.ui.graphics.Color
|
import androidx.compose.ui.graphics.Color
|
||||||
import androidx.compose.ui.platform.LocalClipboardManager
|
import androidx.compose.ui.platform.LocalClipboardManager
|
||||||
|
import androidx.compose.ui.platform.LocalFocusManager
|
||||||
import androidx.compose.ui.res.dimensionResource
|
import androidx.compose.ui.res.dimensionResource
|
||||||
import androidx.compose.ui.text.TextLayoutResult
|
import androidx.compose.ui.text.TextLayoutResult
|
||||||
import androidx.compose.ui.text.style.TextOverflow
|
import androidx.compose.ui.text.style.TextOverflow
|
||||||
|
@ -111,6 +115,9 @@ fun PromptTemplatesPanel(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
val clipboardManager = LocalClipboardManager.current
|
val clipboardManager = LocalClipboardManager.current
|
||||||
|
val focusRequester = remember { FocusRequester() }
|
||||||
|
val focusManager = LocalFocusManager.current
|
||||||
|
val interactionSource = remember { MutableInteractionSource() }
|
||||||
val expandedStates = remember { mutableStateMapOf<String, Boolean>() }
|
val expandedStates = remember { mutableStateMapOf<String, Boolean>() }
|
||||||
|
|
||||||
// Update input editor values when prompt template changes.
|
// Update input editor values when prompt template changes.
|
||||||
|
@ -127,20 +134,30 @@ fun PromptTemplatesPanel(
|
||||||
|
|
||||||
Column(modifier = modifier) {
|
Column(modifier = modifier) {
|
||||||
// Scrollable tab row for all prompt templates.
|
// Scrollable tab row for all prompt templates.
|
||||||
PrimaryScrollableTabRow(selectedTabIndex = selectedTabIndex) {
|
PrimaryScrollableTabRow(
|
||||||
|
selectedTabIndex = selectedTabIndex
|
||||||
|
) {
|
||||||
TAB_TITLES.forEachIndexed { index, title ->
|
TAB_TITLES.forEachIndexed { index, title ->
|
||||||
Tab(selected = selectedTabIndex == index, onClick = {
|
Tab(selected = selectedTabIndex == index,
|
||||||
// Clear input when tab changes.
|
enabled = !inProgress,
|
||||||
curTextInputContent = ""
|
onClick = {
|
||||||
// Reset full prompt switch.
|
// Clear input when tab changes.
|
||||||
inputEditorValues[FULL_PROMPT_SWITCH_KEY] = false
|
curTextInputContent = ""
|
||||||
|
// Reset full prompt switch.
|
||||||
|
inputEditorValues[FULL_PROMPT_SWITCH_KEY] = false
|
||||||
|
|
||||||
selectedTabIndex = index
|
selectedTabIndex = index
|
||||||
viewModel.selectPromptTemplate(
|
viewModel.selectPromptTemplate(
|
||||||
model = model,
|
model = model,
|
||||||
promptTemplateType = promptTemplateTypes[index]
|
promptTemplateType = promptTemplateTypes[index]
|
||||||
)
|
)
|
||||||
}, text = { Text(text = title) })
|
},
|
||||||
|
text = {
|
||||||
|
Text(
|
||||||
|
text = title,
|
||||||
|
modifier = Modifier.alpha(if (inProgress) 0.5f else 1f)
|
||||||
|
)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -178,6 +195,13 @@ fun PromptTemplatesPanel(
|
||||||
modifier = Modifier
|
modifier = Modifier
|
||||||
.fillMaxSize()
|
.fillMaxSize()
|
||||||
.verticalScroll(rememberScrollState())
|
.verticalScroll(rememberScrollState())
|
||||||
|
.clickable(
|
||||||
|
interactionSource = interactionSource,
|
||||||
|
indication = null // Disable the ripple effect
|
||||||
|
) {
|
||||||
|
// Request focus on the TextField when the Column is clicked
|
||||||
|
focusRequester.requestFocus()
|
||||||
|
}
|
||||||
) {
|
) {
|
||||||
if (inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean) {
|
if (inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean) {
|
||||||
Text(
|
Text(
|
||||||
|
@ -186,7 +210,7 @@ fun PromptTemplatesPanel(
|
||||||
modifier = Modifier
|
modifier = Modifier
|
||||||
.fillMaxWidth()
|
.fillMaxWidth()
|
||||||
.padding(16.dp)
|
.padding(16.dp)
|
||||||
.padding(bottom = 32.dp)
|
.padding(bottom = 40.dp)
|
||||||
.clip(MessageBubbleShape(radius = bubbleBorderRadius))
|
.clip(MessageBubbleShape(radius = bubbleBorderRadius))
|
||||||
.background(MaterialTheme.customColors.agentBubbleBgColor)
|
.background(MaterialTheme.customColors.agentBubbleBgColor)
|
||||||
.padding(16.dp)
|
.padding(16.dp)
|
||||||
|
@ -205,7 +229,9 @@ fun PromptTemplatesPanel(
|
||||||
),
|
),
|
||||||
textStyle = MaterialTheme.typography.bodyMedium,
|
textStyle = MaterialTheme.typography.bodyMedium,
|
||||||
placeholder = { Text("Enter content") },
|
placeholder = { Text("Enter content") },
|
||||||
modifier = Modifier.padding(bottom = 32.dp)
|
modifier = Modifier
|
||||||
|
.padding(bottom = 40.dp)
|
||||||
|
.focusRequester(focusRequester)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -301,6 +327,7 @@ fun PromptTemplatesPanel(
|
||||||
OutlinedIconButton(
|
OutlinedIconButton(
|
||||||
enabled = !inProgress && curTextInputContent.isNotEmpty(),
|
enabled = !inProgress && curTextInputContent.isNotEmpty(),
|
||||||
onClick = {
|
onClick = {
|
||||||
|
focusManager.clearFocus()
|
||||||
onSend(fullPrompt.text)
|
onSend(fullPrompt.text)
|
||||||
},
|
},
|
||||||
colors = IconButtonDefaults.iconButtonColors(
|
colors = IconButtonDefaults.iconButtonColors(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue