diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt index 00fa1e3..78a897c 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt @@ -19,7 +19,6 @@ package com.google.aiedge.gallery.data import android.content.Context import com.google.aiedge.gallery.ui.common.chat.PromptTemplate import com.google.aiedge.gallery.ui.common.convertValueToTargetType -import com.google.aiedge.gallery.ui.llmchat.DEFAULT_ACCELERATORS import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs data class ModelDataFile( @@ -82,9 +81,6 @@ data class Model( /** The name of the directory to unzip the model to (if it's a zip file). */ val unzipDir: String = "", - /** The accelerators the the model can run with. */ - val accelerators: List = DEFAULT_ACCELERATORS, - /** The prompt templates for the model (only for LLM). */ val llmPromptTemplates: List = listOf(), @@ -243,7 +239,9 @@ val MODEL_LLM_GEMMA_2B_GPU_INT4: Model = Model( downloadFileName = "gemma-2b-it-gpu-int4.bin", url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma-2b-it-gpu-int4.bin", sizeInBytes = 1354301440L, - configs = createLlmChatConfigs(), + configs = createLlmChatConfigs( + accelerators = listOf(Accelerator.GPU) + ), showBenchmarkButton = false, info = LLM_CHAT_INFO, learnMoreUrl = "https://huggingface.co/litert-community", @@ -254,7 +252,9 @@ val MODEL_LLM_GEMMA_2_2B_GPU_INT8: Model = Model( downloadFileName = "gemma2-2b-it-gpu-int8.bin", url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma2-2b-it-gpu-int8.bin", sizeInBytes = 2627141632L, - configs = createLlmChatConfigs(), + configs = createLlmChatConfigs( + accelerators = listOf(Accelerator.GPU) + ), showBenchmarkButton = false, info = LLM_CHAT_INFO, learnMoreUrl = "https://huggingface.co/litert-community", @@ -265,7 +265,6 @@ val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model( downloadFileName = "gemma3-1b-it-int4.task", url = "https://huggingface.co/litert-community/Gemma3-1B-IT/resolve/main/gemma3-1b-it-int4.task?download=true", sizeInBytes = 554661243L, - accelerators = listOf(Accelerator.CPU, Accelerator.GPU), configs = createLlmChatConfigs( defaultTopK = 64, defaultTopP = 0.95f, @@ -293,7 +292,6 @@ val MODEL_LLM_DEEPSEEK: Model = Model( downloadFileName = "deepseek.task", url = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B/resolve/main/deepseek_q8_ekv1280.task?download=true", sizeInBytes = 1860686856L, - accelerators = listOf(Accelerator.CPU), configs = createLlmChatConfigs( defaultTemperature = 0.6f, defaultTopK = 40, diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt index 326e284..0d8a2c7 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt @@ -36,7 +36,10 @@ import androidx.compose.runtime.Composable import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.collectAsState import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember import androidx.compose.runtime.rememberCoroutineScope +import androidx.compose.runtime.setValue import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.platform.LocalContext @@ -80,8 +83,10 @@ fun LlmSingleTurnScreen( val selectedModel = modelManagerUiState.selectedModel val scope = rememberCoroutineScope() val context = LocalContext.current + var navigatingUp by remember { mutableStateOf(false) } val handleNavigateUp = { + navigatingUp = true navigateUp() // clean up all models. @@ -100,12 +105,14 @@ fun LlmSingleTurnScreen( // Initialize model when model/download state changes. val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name] LaunchedEffect(curDownloadStatus, selectedModel.name) { - if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) { - Log.d( - TAG, - "Initializing model '${selectedModel.name}' from LlmsingleTurnScreen launched effect" - ) - modelManagerViewModel.initializeModel(context, task = task, model = selectedModel) + if (!navigatingUp) { + if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) { + Log.d( + TAG, + "Initializing model '${selectedModel.name}' from LlmsingleTurnScreen launched effect" + ) + modelManagerViewModel.initializeModel(context, task = task, model = selectedModel) + } } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/PromptTemplatesPanel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/PromptTemplatesPanel.kt index 5f9e53d..8c5aa08 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/PromptTemplatesPanel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/PromptTemplatesPanel.kt @@ -20,6 +20,7 @@ import androidx.compose.foundation.BorderStroke import androidx.compose.foundation.background import androidx.compose.foundation.border import androidx.compose.foundation.clickable +import androidx.compose.foundation.interaction.MutableInteractionSource import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column @@ -72,8 +73,11 @@ import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.draw.alpha import androidx.compose.ui.draw.clip +import androidx.compose.ui.focus.FocusRequester +import androidx.compose.ui.focus.focusRequester import androidx.compose.ui.graphics.Color import androidx.compose.ui.platform.LocalClipboardManager +import androidx.compose.ui.platform.LocalFocusManager import androidx.compose.ui.res.dimensionResource import androidx.compose.ui.text.TextLayoutResult import androidx.compose.ui.text.style.TextOverflow @@ -111,6 +115,9 @@ fun PromptTemplatesPanel( } } val clipboardManager = LocalClipboardManager.current + val focusRequester = remember { FocusRequester() } + val focusManager = LocalFocusManager.current + val interactionSource = remember { MutableInteractionSource() } val expandedStates = remember { mutableStateMapOf() } // Update input editor values when prompt template changes. @@ -127,20 +134,30 @@ fun PromptTemplatesPanel( Column(modifier = modifier) { // Scrollable tab row for all prompt templates. - PrimaryScrollableTabRow(selectedTabIndex = selectedTabIndex) { + PrimaryScrollableTabRow( + selectedTabIndex = selectedTabIndex + ) { TAB_TITLES.forEachIndexed { index, title -> - Tab(selected = selectedTabIndex == index, onClick = { - // Clear input when tab changes. - curTextInputContent = "" - // Reset full prompt switch. - inputEditorValues[FULL_PROMPT_SWITCH_KEY] = false + Tab(selected = selectedTabIndex == index, + enabled = !inProgress, + onClick = { + // Clear input when tab changes. + curTextInputContent = "" + // Reset full prompt switch. + inputEditorValues[FULL_PROMPT_SWITCH_KEY] = false - selectedTabIndex = index - viewModel.selectPromptTemplate( - model = model, - promptTemplateType = promptTemplateTypes[index] - ) - }, text = { Text(text = title) }) + selectedTabIndex = index + viewModel.selectPromptTemplate( + model = model, + promptTemplateType = promptTemplateTypes[index] + ) + }, + text = { + Text( + text = title, + modifier = Modifier.alpha(if (inProgress) 0.5f else 1f) + ) + }) } } @@ -178,6 +195,13 @@ fun PromptTemplatesPanel( modifier = Modifier .fillMaxSize() .verticalScroll(rememberScrollState()) + .clickable( + interactionSource = interactionSource, + indication = null // Disable the ripple effect + ) { + // Request focus on the TextField when the Column is clicked + focusRequester.requestFocus() + } ) { if (inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean) { Text( @@ -186,7 +210,7 @@ fun PromptTemplatesPanel( modifier = Modifier .fillMaxWidth() .padding(16.dp) - .padding(bottom = 32.dp) + .padding(bottom = 40.dp) .clip(MessageBubbleShape(radius = bubbleBorderRadius)) .background(MaterialTheme.customColors.agentBubbleBgColor) .padding(16.dp) @@ -205,7 +229,9 @@ fun PromptTemplatesPanel( ), textStyle = MaterialTheme.typography.bodyMedium, placeholder = { Text("Enter content") }, - modifier = Modifier.padding(bottom = 32.dp) + modifier = Modifier + .padding(bottom = 40.dp) + .focusRequester(focusRequester) ) } } @@ -301,6 +327,7 @@ fun PromptTemplatesPanel( OutlinedIconButton( enabled = !inProgress && curTextInputContent.isNotEmpty(), onClick = { + focusManager.clearFocus() onSend(fullPrompt.text) }, colors = IconButtonDefaults.iconButtonColors(