From d94fec06747aaf667cc31382ca2dbb068ee3f05d Mon Sep 17 00:00:00 2001 From: Jing Jin <8752427+jinjingforever@users.noreply.github.com> Date: Sat, 26 Apr 2025 21:35:03 -0700 Subject: [PATCH] Add support for LLM single-turn experience --- Android/src/app/build.gradle.kts | 2 +- .../com/google/aiedge/gallery/GalleryApp.kt | 1 - .../aiedge/gallery/GalleryApplication.kt | 9 +- .../com/google/aiedge/gallery/data/Model.kt | 10 +- .../com/google/aiedge/gallery/data/Tasks.kt | 26 +- .../aiedge/gallery/ui/ViewModelProvider.kt | 6 + .../gallery/ui/common/ModelPageAppBar.kt | 213 +++++++++ .../aiedge/gallery/ui/common/ModelPicker.kt | 132 ++++++ .../google/aiedge/gallery/ui/common/Utils.kt | 34 ++ .../gallery/ui/common/chat/ChatPanel.kt | 36 +- .../aiedge/gallery/ui/common/chat/ChatView.kt | 153 ++----- .../gallery/ui/common/chat/ChatViewModel.kt | 24 +- .../gallery/ui/common/chat/MarkdownText.kt | 2 +- .../ui/common/chat/MessageActionButton.kt | 5 +- .../ui/common/chat/MessageBodyBenchmarkLlm.kt | 10 +- .../common/chat/MessageBodyPromptTemplates.kt | 4 +- .../ui/common/chat/MessageInputText.kt | 3 +- .../chat/ModelDownloadStatusInfoPanel.kt | 119 +++++ .../common/chat/ModelDownloadingAnimation.kt | 238 +++++++--- .../gallery/ui/common/chat/ModelSelector.kt | 10 +- .../ui/common/modelitem/ModelNameAndStatus.kt | 2 - .../gallery/ui/common/modelitem/StatusIcon.kt | 11 +- .../aiedge/gallery/ui/home/HomeScreen.kt | 3 - .../gallery/ui/llmchat/LlmChatModelHelper.kt | 24 +- .../gallery/ui/llmchat/LlmChatViewModel.kt | 2 +- .../ui/llmsingleturn/LlmSingleTurnScreen.kt | 206 +++++++++ .../llmsingleturn/LlmSingleTurnViewModel.kt | 210 +++++++++ .../ui/llmsingleturn/PromptTemplateConfigs.kt | 185 ++++++++ .../ui/llmsingleturn/PromptTemplatesPanel.kt | 426 ++++++++++++++++++ .../gallery/ui/llmsingleturn/ResponsePanel.kt | 206 +++++++++ .../ui/llmsingleturn/SingleSelectButton.kt | 90 ++++ .../ui/llmsingleturn/VerticalSplitView.kt | 133 ++++++ .../ui/modelmanager/ModelManagerViewModel.kt | 78 ++-- .../gallery/ui/navigation/GalleryNavGraph.kt | 27 +- .../preview/PreviewLlmSingleTurnViewModel.kt | 21 + .../preview/PreviewModelManagerViewModel.kt | 2 +- .../google/aiedge/gallery/ui/theme/Color.kt | 3 +- .../google/aiedge/gallery/ui/theme/Theme.kt | 3 + .../aiedge/gallery/worker/DownloadWorker.kt | 2 + 39 files changed, 2376 insertions(+), 295 deletions(-) create mode 100644 Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPageAppBar.kt create mode 100644 Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPicker.kt create mode 100644 Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ModelDownloadStatusInfoPanel.kt create mode 100644 Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt create mode 100644 Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt create mode 100644 Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/PromptTemplateConfigs.kt create mode 100644 Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/PromptTemplatesPanel.kt create mode 100644 Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/ResponsePanel.kt create mode 100644 Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/SingleSelectButton.kt create mode 100644 Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/VerticalSplitView.kt create mode 100644 Android/src/app/src/main/java/com/google/aiedge/gallery/ui/preview/PreviewLlmSingleTurnViewModel.kt diff --git a/Android/src/app/build.gradle.kts b/Android/src/app/build.gradle.kts index 675415a..d21b009 100644 --- a/Android/src/app/build.gradle.kts +++ b/Android/src/app/build.gradle.kts @@ -30,7 +30,7 @@ android { minSdk = 24 targetSdk = 35 versionCode = 1 - versionName = "20250421" + versionName = "20250428" // Needed for HuggingFace auth workflows. manifestPlaceholders["appAuthRedirectScheme"] = "com.google.aiedge.gallery.oauth" diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/GalleryApp.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/GalleryApp.kt index 9c01f75..4cd0fa8 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/GalleryApp.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/GalleryApp.kt @@ -187,6 +187,5 @@ fun GalleryTopAppBar( else -> {} } } - ) } \ No newline at end of file diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/GalleryApplication.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/GalleryApplication.kt index f460d80..b422851 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/GalleryApplication.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/GalleryApplication.kt @@ -23,7 +23,7 @@ import androidx.datastore.preferences.core.Preferences import androidx.datastore.preferences.preferencesDataStore import com.google.aiedge.gallery.data.AppContainer import com.google.aiedge.gallery.data.DefaultAppContainer -import com.google.aiedge.gallery.data.TASKS +import com.google.aiedge.gallery.ui.common.processTasks import com.google.aiedge.gallery.ui.theme.ThemeSettings private val Context.dataStore: DataStore by preferencesDataStore(name = "app_gallery_preferences") @@ -36,12 +36,7 @@ class GalleryApplication : Application() { super.onCreate() // Process tasks. - for ((index, task) in TASKS.withIndex()) { - task.index = index - for (model in task.models) { - model.preProcess(task = task) - } - } + processTasks() container = DefaultAppContainer(this, dataStore) diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt index b617ea1..00fa1e3 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt @@ -92,15 +92,13 @@ data class Model( val imported: Boolean = false, // The following fields are managed by the app. Don't need to set manually. - var taskType: TaskType? = null, var instance: Any? = null, var initializing: Boolean = false, var configValues: Map = mapOf(), var totalBytes: Long = 0L, var accessToken: String? = null, ) { - fun preProcess(task: Task) { - this.taskType = task.type + fun preProcess() { val configValues: MutableMap = mutableMapOf() for (config in this.configs) { configValues[config.key.label] = config.defaultValue @@ -246,6 +244,7 @@ val MODEL_LLM_GEMMA_2B_GPU_INT4: Model = Model( url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma-2b-it-gpu-int4.bin", sizeInBytes = 1354301440L, configs = createLlmChatConfigs(), + showBenchmarkButton = false, info = LLM_CHAT_INFO, learnMoreUrl = "https://huggingface.co/litert-community", ) @@ -256,6 +255,7 @@ val MODEL_LLM_GEMMA_2_2B_GPU_INT8: Model = Model( url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma2-2b-it-gpu-int8.bin", sizeInBytes = 2627141632L, configs = createLlmChatConfigs(), + showBenchmarkButton = false, info = LLM_CHAT_INFO, learnMoreUrl = "https://huggingface.co/litert-community", ) @@ -271,6 +271,7 @@ val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model( defaultTopP = 0.95f, accelerators = listOf(Accelerator.CPU, Accelerator.GPU) ), + showBenchmarkButton = false, info = LLM_CHAT_INFO, learnMoreUrl = "https://huggingface.co/litert-community/Gemma3-1B-IT", llmPromptTemplates = listOf( @@ -299,6 +300,7 @@ val MODEL_LLM_DEEPSEEK: Model = Model( defaultTopP = 0.7f, accelerators = listOf(Accelerator.CPU) ), + showBenchmarkButton = false, info = LLM_CHAT_INFO, learnMoreUrl = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B", ) @@ -389,7 +391,7 @@ val MODELS_IMAGE_CLASSIFICATION: MutableList = mutableListOf( MODEL_IMAGE_CLASSIFICATION_MOBILENET_V2, ) -val MODELS_LLM_CHAT: MutableList = mutableListOf( +val MODELS_LLM: MutableList = mutableListOf( MODEL_LLM_GEMMA_2B_GPU_INT4, MODEL_LLM_GEMMA_2_2B_GPU_INT8, MODEL_LLM_GEMMA_3_1B_INT4, diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Tasks.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Tasks.kt index c1ca12e..4d853dd 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Tasks.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Tasks.kt @@ -18,9 +18,11 @@ package com.google.aiedge.gallery.data import androidx.annotation.StringRes import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.outlined.Forum +import androidx.compose.material.icons.outlined.Widgets import androidx.compose.material.icons.rounded.ImageSearch import androidx.compose.runtime.MutableState -import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.mutableLongStateOf import androidx.compose.ui.graphics.vector.ImageVector import com.google.aiedge.gallery.R @@ -30,6 +32,7 @@ enum class TaskType(val label: String) { IMAGE_CLASSIFICATION("Image Classification"), IMAGE_GENERATION("Image Generation"), LLM_CHAT("LLM Chat"), + LLM_SINGLE_TURN("LLM Use Cases"), TEST_TASK_1("Test task 1"), TEST_TASK_2("Test task 2") @@ -67,7 +70,7 @@ data class Task( // The following fields are managed by the app. Don't need to set manually. var index: Int = -1, - val updateTrigger: MutableState = mutableStateOf(0) + val updateTrigger: MutableState = mutableLongStateOf(0) ) val TASK_TEXT_CLASSIFICATION = Task( @@ -87,9 +90,19 @@ val TASK_IMAGE_CLASSIFICATION = Task( val TASK_LLM_CHAT = Task( type = TaskType.LLM_CHAT, - iconVectorResourceId = R.drawable.chat_spark, - models = MODELS_LLM_CHAT, - description = "Chat? with a on-device large language model", + icon = Icons.Outlined.Forum, + models = MODELS_LLM, + description = "Chat with a on-device large language model", + docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", + sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt", + textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat +) + +val TASK_LLM_SINGLE_TURN = Task( + type = TaskType.LLM_SINGLE_TURN, + icon = Icons.Outlined.Widgets, + models = MODELS_LLM, + description = "Single turn use cases with on-device large language model", docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt", textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat @@ -108,9 +121,10 @@ val TASK_IMAGE_GENERATION = Task( /** All tasks. */ val TASKS: List = listOf( // TASK_TEXT_CLASSIFICATION, -// TASK_IMAGE_CLASSIFICATION, + TASK_IMAGE_CLASSIFICATION, TASK_IMAGE_GENERATION, TASK_LLM_CHAT, + TASK_LLM_SINGLE_TURN, ) fun getModelByName(name: String): Model? { diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/ViewModelProvider.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/ViewModelProvider.kt index 6f25663..a1a4739 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/ViewModelProvider.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/ViewModelProvider.kt @@ -25,6 +25,7 @@ import com.google.aiedge.gallery.GalleryApplication import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationViewModel import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationViewModel import com.google.aiedge.gallery.ui.llmchat.LlmChatViewModel +import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.textclassification.TextClassificationViewModel @@ -56,6 +57,11 @@ object ViewModelProvider { LlmChatViewModel() } + // Initializer for LlmSingleTurnViewModel.. + initializer { + LlmSingleTurnViewModel() + } + initializer { ImageGenerationViewModel() } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPageAppBar.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPageAppBar.kt new file mode 100644 index 0000000..6832e59 --- /dev/null +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPageAppBar.kt @@ -0,0 +1,213 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.aiedge.gallery.ui.common + +import androidx.compose.foundation.background +import androidx.compose.foundation.clickable +import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.size +import androidx.compose.foundation.shape.CircleShape +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.automirrored.rounded.ArrowBack +import androidx.compose.material.icons.rounded.ArrowDropDown +import androidx.compose.material.icons.rounded.Settings +import androidx.compose.material3.CenterAlignedTopAppBar +import androidx.compose.material3.ExperimentalMaterial3Api +import androidx.compose.material3.Icon +import androidx.compose.material3.IconButton +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.ModalBottomSheet +import androidx.compose.material3.Text +import androidx.compose.material3.rememberModalBottomSheetState +import androidx.compose.runtime.Composable +import androidx.compose.runtime.collectAsState +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember +import androidx.compose.runtime.setValue +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.clip +import androidx.compose.ui.graphics.vector.ImageVector +import androidx.compose.ui.platform.LocalContext +import androidx.compose.ui.res.vectorResource +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.unit.dp +import com.google.aiedge.gallery.data.Model +import com.google.aiedge.gallery.data.ModelDownloadStatusType +import com.google.aiedge.gallery.data.Task +import com.google.aiedge.gallery.ui.common.chat.ConfigDialog +import com.google.aiedge.gallery.ui.common.modelitem.StatusIcon +import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel + +@OptIn(ExperimentalMaterial3Api::class) +@Composable +fun ModelPageAppBar( + task: Task, + model: Model, + modelManagerViewModel: ModelManagerViewModel, + onBackClicked: () -> Unit, + onModelSelected: (Model) -> Unit, + modifier: Modifier = Modifier, + onConfigChanged: (oldConfigValues: Map, newConfigValues: Map) -> Unit = { _, _ -> }, +) { + var showConfigDialog by remember { mutableStateOf(false) } + var showModelPicker by remember { mutableStateOf(false) } + val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() + val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true) + val context = LocalContext.current + val curDownloadStatus = modelManagerUiState.modelDownloadStatus[model.name] + + CenterAlignedTopAppBar(title = { + Column(horizontalAlignment = Alignment.CenterHorizontally) { + // Task type. + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(6.dp) + ) { + Icon( + task.icon ?: ImageVector.vectorResource(task.iconVectorResourceId!!), + tint = getTaskIconColor(task = task), + modifier = Modifier.size(16.dp), + contentDescription = "", + ) + Text( + task.type.label, + style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.SemiBold), + color = getTaskIconColor(task = task) + ) + } + + // Model name. + Row(verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(2.dp), + modifier = Modifier + .clip(CircleShape) + .background(MaterialTheme.colorScheme.surfaceContainerHigh) + .clickable { + showModelPicker = true + } + .padding(start = 8.dp, end = 2.dp)) { + StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name]) + Text( + model.name, + style = MaterialTheme.typography.labelSmall, + modifier = Modifier.padding(start = 4.dp), + ) + Icon( + Icons.Rounded.ArrowDropDown, + modifier = Modifier.size(20.dp), + contentDescription = "", + ) + } + + } + }, modifier = modifier, + // The back button. + navigationIcon = { + IconButton(onClick = onBackClicked) { + Icon( + imageVector = Icons.AutoMirrored.Rounded.ArrowBack, + contentDescription = "", + ) + } + }, + // The config button for the model (if existed). + actions = { + if (model.configs.isNotEmpty() && curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) { + IconButton(onClick = { showConfigDialog = true }) { + Icon( + imageVector = Icons.Rounded.Settings, + contentDescription = "", + tint = MaterialTheme.colorScheme.primary + ) + } + } + }) + + // Config dialog. + if (showConfigDialog) { + ConfigDialog( + title = "Model configs", + configs = model.configs, + initialValues = model.configValues, + onDismissed = { showConfigDialog = false }, + onOk = { curConfigValues -> + // Hide config dialog. + showConfigDialog = false + + // Check if the configs are changed or not. Also check if the model needs to be + // re-initialized. + var same = true + var needReinitialization = false + for (config in model.configs) { + val key = config.key.label + val oldValue = convertValueToTargetType( + value = model.configValues.getValue(key), valueType = config.valueType + ) + val newValue = convertValueToTargetType( + value = curConfigValues.getValue(key), valueType = config.valueType + ) + if (oldValue != newValue) { + same = false + if (config.needReinitialization) { + needReinitialization = true + } + break + } + } + if (same) { + return@ConfigDialog + } + + // Save the config values to Model. + val oldConfigValues = model.configValues + model.configValues = curConfigValues + + // Force to re-initialize the model with the new configs. + if (needReinitialization) { + modelManagerViewModel.initializeModel( + context = context, task = task, model = model, force = true + ) + } + + // Notify. + onConfigChanged(oldConfigValues, model.configValues) + }, + ) + } + + // Model picker. + if (showModelPicker) { + ModalBottomSheet( + onDismissRequest = { showModelPicker = false }, + sheetState = sheetState, + ) { + ModelPicker( + task = task, + modelManagerViewModel = modelManagerViewModel, + onModelSelected = { model -> + showModelPicker = false + onModelSelected(model) + } + ) + } + } +} \ No newline at end of file diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPicker.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPicker.kt new file mode 100644 index 0000000..c27f74d --- /dev/null +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPicker.kt @@ -0,0 +1,132 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.aiedge.gallery.ui.common + +import androidx.compose.foundation.clickable +import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.Spacer +import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.size +import androidx.compose.foundation.layout.width +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.outlined.CheckCircle +import androidx.compose.material3.Icon +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Text +import androidx.compose.runtime.Composable +import androidx.compose.runtime.collectAsState +import androidx.compose.runtime.getValue +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.graphics.vector.ImageVector +import androidx.compose.ui.platform.LocalContext +import androidx.compose.ui.res.vectorResource +import androidx.compose.ui.tooling.preview.Preview +import androidx.compose.ui.unit.dp +import androidx.compose.ui.unit.sp +import com.google.aiedge.gallery.data.Model +import com.google.aiedge.gallery.data.Task +import com.google.aiedge.gallery.ui.common.modelitem.StatusIcon +import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel +import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel +import com.google.aiedge.gallery.ui.preview.TASK_TEST1 +import com.google.aiedge.gallery.ui.theme.GalleryTheme +import com.google.aiedge.gallery.ui.theme.labelSmallNarrow + +@Composable +fun ModelPicker( + task: Task, + modelManagerViewModel: ModelManagerViewModel, + onModelSelected: (Model) -> Unit +) { + val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() + + Column(modifier = Modifier.padding(bottom = 8.dp)) { + // Title + Row( + modifier = Modifier + .padding(horizontal = 16.dp) + .padding(top = 4.dp, bottom = 4.dp), + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(8.dp), + ) { + Icon( + task.icon ?: ImageVector.vectorResource(task.iconVectorResourceId!!), + tint = getTaskIconColor(task = task), + modifier = Modifier.size(16.dp), + contentDescription = "", + ) + Text( + "${task.type.label} models", + modifier = Modifier.fillMaxWidth(), + style = MaterialTheme.typography.titleMedium, + color = getTaskIconColor(task = task), + ) + } + + // Model list. + for (model in task.models) { + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.SpaceBetween, + modifier = Modifier + .fillMaxWidth() + .clickable { + onModelSelected(model) + } + .padding(horizontal = 16.dp, vertical = 8.dp), + ) { + Spacer(modifier = Modifier.width(24.dp)) + Column(modifier = Modifier.weight(1f)) { + Text(model.name, style = MaterialTheme.typography.bodyMedium) + Row(horizontalArrangement = Arrangement.spacedBy(4.dp)) { + StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name]) + Text( + model.sizeInBytes.humanReadableSize(), + color = MaterialTheme.colorScheme.secondary, + style = labelSmallNarrow.copy(lineHeight = 10.sp) + ) + } + } + if (model.name == modelManagerUiState.selectedModel.name) { + Icon( + Icons.Outlined.CheckCircle, + modifier = Modifier.size(16.dp), + contentDescription = "" + ) + } + } + } + } +} + +@Preview(showBackground = true) +@Composable +fun ModelPickerPreview() { + val context = LocalContext.current + + GalleryTheme { + ModelPicker( + task = TASK_TEST1, + modelManagerViewModel = PreviewModelManagerViewModel(context = context), + onModelSelected = {}, + ) + } +} \ No newline at end of file diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/Utils.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/Utils.kt index d9709fb..d614b3c 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/Utils.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/Utils.kt @@ -29,6 +29,7 @@ import androidx.core.content.ContextCompat import androidx.core.content.FileProvider import com.google.aiedge.gallery.data.Config import com.google.aiedge.gallery.data.Model +import com.google.aiedge.gallery.data.TASKS import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.ValueType import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkResult @@ -57,6 +58,9 @@ interface LatencyProvider { val latencyMs: Float } +private const val START_THINKING = "***Thinking...***" +private const val DONE_THINKING = "***Done thinking***" + /** Format the bytes into a human-readable format. */ fun Long.humanReadableSize(si: Boolean = true, extraDecimalForGbAndAbove: Boolean = false): String { val bytes = this @@ -452,3 +456,33 @@ fun cleanUpMediapipeTaskErrorMessage(message: String): String { } return message } + +fun processTasks() { + for ((index, task) in TASKS.withIndex()) { + task.index = index + for (model in task.models) { + model.preProcess() + } + } +} + +fun processLlmResponse(response: String): String { + // Add "thinking" and "done thinking" around the thinking content. + var newContent = response + .replace("", "$START_THINKING\n") + .replace("", "\n$DONE_THINKING") + + // Remove empty thinking content. + val endThinkingIndex = newContent.indexOf(DONE_THINKING) + if (endThinkingIndex >= 0) { + val thinkingContent = + newContent.substring(0, endThinkingIndex + DONE_THINKING.length) + .replace(START_THINKING, "") + .replace(DONE_THINKING, "") + if (thinkingContent.isBlank()) { + newContent = newContent.substring(endThinkingIndex + DONE_THINKING.length) + } + } + + return newContent +} \ No newline at end of file diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt index 15a4c2d..a573590 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt @@ -38,6 +38,7 @@ import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.size import androidx.compose.foundation.layout.width import androidx.compose.foundation.layout.wrapContentHeight +import androidx.compose.foundation.layout.wrapContentWidth import androidx.compose.foundation.lazy.LazyColumn import androidx.compose.foundation.lazy.items import androidx.compose.foundation.lazy.rememberLazyListState @@ -334,7 +335,10 @@ fun ChatPanel( is ChatMessageBenchmarkResult -> MessageBodyBenchmark(message = message) // Benchmark LLM result. - is ChatMessageBenchmarkLlmResult -> MessageBodyBenchmarkLlm(message = message) + is ChatMessageBenchmarkLlmResult -> MessageBodyBenchmarkLlm( + message = message, + modifier = Modifier.wrapContentWidth() + ) else -> {} } @@ -346,7 +350,7 @@ fun ChatPanel( ) { LatencyText(message = message) // A button to show stats for the LLM message. - if (selectedModel.taskType == TaskType.LLM_CHAT && message is ChatMessageText + if (task.type == TaskType.LLM_CHAT && message is ChatMessageText // This means we only want to show the action button when the message is done // generating, at which point the latency will be set. && message.latencyMs >= 0 @@ -403,21 +407,17 @@ fun ChatPanel( } // Benchmark button -// if (selectedModel.showBenchmarkButton) { -// MessageActionButton( -// label = stringResource(R.string.benchmark), -// icon = Icons.Outlined.Timer, -// onClick = { -// if (selectedModel.taskType == TaskType.LLM_CHAT) { -// onBenchmarkClicked(selectedModel, message, 0, 0) -// } else { -// showBenchmarkConfigsDialog = true -// benchmarkMessage.value = message -// } -// }, -// enabled = !uiState.inProgress -// ) -// } + if (selectedModel.showBenchmarkButton) { + MessageActionButton( + label = stringResource(R.string.benchmark), + icon = Icons.Outlined.Timer, + onClick = { + showBenchmarkConfigsDialog = true + benchmarkMessage.value = message + }, + enabled = !uiState.inProgress + ) + } } } } @@ -443,7 +443,7 @@ fun ChatPanel( // Chat input when (chatInputType) { ChatInputType.TEXT -> { - val isLlmTask = selectedModel.taskType == TaskType.LLM_CHAT + val isLlmTask = task.type == TaskType.LLM_CHAT val notLlmStartScreen = !(messages.size == 1 && messages[0] is ChatMessagePromptTemplates) MessageInputText( modelManagerViewModel = modelManagerViewModel, diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatView.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatView.kt index 22d1178..44d37b7 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatView.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatView.kt @@ -18,18 +18,10 @@ package com.google.aiedge.gallery.ui.common.chat import android.util.Log import androidx.activity.compose.BackHandler -import androidx.activity.compose.rememberLauncherForActivityResult -import androidx.activity.result.contract.ActivityResultContracts -import androidx.compose.animation.AnimatedVisibility -import androidx.compose.animation.fadeIn -import androidx.compose.animation.fadeOut -import androidx.compose.animation.scaleIn -import androidx.compose.animation.scaleOut import androidx.compose.foundation.background import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.fillMaxSize -import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.padding import androidx.compose.foundation.pager.HorizontalPager import androidx.compose.foundation.pager.rememberPagerState @@ -40,29 +32,21 @@ import androidx.compose.runtime.Composable import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.collectAsState import androidx.compose.runtime.getValue -import androidx.compose.runtime.mutableStateOf -import androidx.compose.runtime.remember import androidx.compose.runtime.rememberCoroutineScope -import androidx.compose.runtime.setValue -import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.graphics.graphicsLayer import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.tooling.preview.Preview -import com.google.aiedge.gallery.GalleryTopAppBar -import com.google.aiedge.gallery.data.AppBarAction -import com.google.aiedge.gallery.data.AppBarActionType import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.ModelDownloadStatusType import com.google.aiedge.gallery.data.Task -import com.google.aiedge.gallery.ui.common.checkNotificationPermissionAndStartDownload +import com.google.aiedge.gallery.ui.common.ModelPageAppBar import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.preview.PreviewChatModel import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel import com.google.aiedge.gallery.ui.preview.TASK_TEST1 import com.google.aiedge.gallery.ui.theme.GalleryTheme import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.delay import kotlinx.coroutines.launch import kotlin.math.absoluteValue @@ -77,7 +61,6 @@ private const val TAG = "AGChatView" * manages model initialization, cleanup, and download status, and handles navigation and system * back gestures. */ -@OptIn(ExperimentalMaterial3Api::class) @Composable fun ChatView( task: Task, @@ -96,34 +79,29 @@ fun ChatView( val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() val selectedModel = modelManagerUiState.selectedModel - val pagerState = rememberPagerState(initialPage = task.models.indexOf(selectedModel), + val pagerState = rememberPagerState( + initialPage = task.models.indexOf(selectedModel), pageCount = { task.models.size }) val context = LocalContext.current val scope = rememberCoroutineScope() - val launcher = rememberLauncherForActivityResult( - ActivityResultContracts.RequestPermission() - ) { - modelManagerViewModel.downloadModel(task = task, model = selectedModel) - } - val handleNavigateUp = { navigateUp() // clean up all models. scope.launch(Dispatchers.Default) { for (model in task.models) { - modelManagerViewModel.cleanupModel(model = model) + modelManagerViewModel.cleanupModel(task = task, model = model) } } } // Initialize model when model/download state changes. - val status = modelManagerUiState.modelDownloadStatus[selectedModel.name] - LaunchedEffect(status, selectedModel.name) { - if (status?.status == ModelDownloadStatusType.SUCCEEDED) { + val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name] + LaunchedEffect(curDownloadStatus, selectedModel.name) { + if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) { Log.d(TAG, "Initializing model '${selectedModel.name}' from ChatView launched effect") - modelManagerViewModel.initializeModel(context, model = selectedModel) + modelManagerViewModel.initializeModel(context, task = task, model = selectedModel) } } @@ -135,7 +113,7 @@ fun ChatView( "Pager settled on model '${curSelectedModel.name}' from '${selectedModel.name}'. Updating selected model." ) if (curSelectedModel.name != selectedModel.name) { - modelManagerViewModel.cleanupModel(model = selectedModel) + modelManagerViewModel.cleanupModel(task = task, model = selectedModel) } modelManagerViewModel.selectModel(curSelectedModel) } @@ -146,24 +124,36 @@ fun ChatView( } Scaffold(modifier = modifier, topBar = { - GalleryTopAppBar( - title = task.type.label, - leftAction = AppBarAction(actionType = AppBarActionType.NAVIGATE_UP, actionFn = { + ModelPageAppBar( + task = task, + model = selectedModel, + modelManagerViewModel = modelManagerViewModel, + onConfigChanged = { old, new -> + viewModel.addConfigChangedMessage( + oldConfigValues = old, + newConfigValues = new, + model = selectedModel + ) + }, + onBackClicked = { handleNavigateUp() - }), - rightAction = AppBarAction(actionType = AppBarActionType.NO_ACTION, actionFn = {}), + }, + onModelSelected = { model -> + scope.launch { + pagerState.animateScrollToPage(task.models.indexOf(model)) + } + }, ) }) { innerPadding -> Box { // A horizontal scrollable pager to switch between models. HorizontalPager(state = pagerState) { pageIndex -> val curSelectedModel = task.models[pageIndex] + val curModelDownloadStatus = modelManagerUiState.modelDownloadStatus[curSelectedModel.name] // Calculate the alpha of the current page based on how far they are from the center. - val pageOffset = ( - (pagerState.currentPage - pageIndex) + pagerState - .currentPageOffsetFraction - ).absoluteValue + val pageOffset = + ((pagerState.currentPage - pageIndex) + pagerState.currentPageOffsetFraction).absoluteValue val curAlpha = 1f - pageOffset.coerceIn(0f, 1f) Column( @@ -172,91 +162,14 @@ fun ChatView( .fillMaxSize() .background(MaterialTheme.colorScheme.surface) ) { - // Model selector at the top. - ModelSelector( + ModelDownloadStatusInfoPanel( model = curSelectedModel, task = task, - modelManagerViewModel = modelManagerViewModel, - onConfigChanged = { old, new -> - viewModel.addConfigChangedMessage( - oldConfigValues = old, - newConfigValues = new, - model = curSelectedModel - ) - }, - modifier = Modifier.fillMaxWidth(), - contentAlpha = curAlpha, + modelManagerViewModel = modelManagerViewModel ) - // Manages the conditional display of UI elements (download model button and downloading - // animation) based on the corresponding download status. - // - // It uses delayed visibility ensuring they are shown only after a short delay if their - // respective conditions remain true. This prevents UI flickering and provides a smoother - // user experience. - val curStatus = modelManagerUiState.modelDownloadStatus[curSelectedModel.name] - var shouldShowDownloadingAnimation by remember { mutableStateOf(false) } - var downloadingAnimationConditionMet by remember { mutableStateOf(false) } - var shouldShowDownloadModelButton by remember { mutableStateOf(false) } - var downloadModelButtonConditionMet by remember { mutableStateOf(false) } - - downloadingAnimationConditionMet = - curStatus?.status == ModelDownloadStatusType.IN_PROGRESS || - curStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED || - curStatus?.status == ModelDownloadStatusType.UNZIPPING - downloadModelButtonConditionMet = - curStatus?.status == ModelDownloadStatusType.FAILED || - curStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED - - LaunchedEffect(downloadingAnimationConditionMet) { - if (downloadingAnimationConditionMet) { - delay(100) - shouldShowDownloadingAnimation = true - } else { - shouldShowDownloadingAnimation = false - } - } - - LaunchedEffect(downloadModelButtonConditionMet) { - if (downloadModelButtonConditionMet) { - delay(700) - shouldShowDownloadModelButton = true - } else { - shouldShowDownloadModelButton = false - } - } - - AnimatedVisibility( - visible = shouldShowDownloadingAnimation, - enter = scaleIn(initialScale = 0.9f) + fadeIn(), - exit = scaleOut(targetScale = 0.9f) + fadeOut() - ) { - Box( - modifier = Modifier.fillMaxSize(), - contentAlignment = Alignment.Center - ) { - ModelDownloadingAnimation() - } - } - - AnimatedVisibility( - visible = shouldShowDownloadModelButton, - enter = fadeIn(), - exit = fadeOut() - ) { - ModelNotDownloaded(modifier = Modifier.weight(1f), onClicked = { - checkNotificationPermissionAndStartDownload( - context = context, - launcher = launcher, - modelManagerViewModel = modelManagerViewModel, - task = task, - model = curSelectedModel - ) - }) - } - // The main messages panel. - if (curStatus?.status == ModelDownloadStatusType.SUCCEEDED) { + if (curModelDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) { ChatPanel( modelManagerViewModel = modelManagerViewModel, task = task, diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatViewModel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatViewModel.kt index c5fbe2b..34f475e 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatViewModel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatViewModel.kt @@ -20,13 +20,12 @@ import android.util.Log import androidx.lifecycle.ViewModel import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.Task +import com.google.aiedge.gallery.ui.common.processLlmResponse import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.update private const val TAG = "AGChatViewModel" -private const val START_THINKING = "***Thinking...***" -private const val DONE_THINKING = "***Done thinking***" data class ChatUiState( /** @@ -121,26 +120,7 @@ open class ChatViewModel(val task: Task) : ViewModel() { if (newMessages.size > 0) { val lastMessage = newMessages.last() if (lastMessage is ChatMessageText) { - var newContent = "${lastMessage.content}${partialContent}" - // TODO: special handling for deepseek to remove the tag. - - // Add "thinking" and "done thinking" around the thinking content. - newContent = newContent - .replace("", "$START_THINKING\n") - .replace("", "\n$DONE_THINKING") - - // Remove empty thinking content. - val endThinkingIndex = newContent.indexOf(DONE_THINKING) - if (endThinkingIndex >= 0) { - val thinkingContent = - newContent.substring(0, endThinkingIndex + DONE_THINKING.length) - .replace(START_THINKING, "") - .replace(DONE_THINKING, "") - if (thinkingContent.isBlank()) { - newContent = newContent.substring(endThinkingIndex + DONE_THINKING.length) - } - } - + val newContent = processLlmResponse(response = "${lastMessage.content}${partialContent}") val newLastMessage = ChatMessageText( content = newContent, side = lastMessage.side, diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MarkdownText.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MarkdownText.kt index b0aca49..fc0d315 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MarkdownText.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MarkdownText.kt @@ -45,7 +45,7 @@ fun MarkdownText( ProvideTextStyle( value = TextStyle( fontSize = fontSize, - lineHeight = fontSize * 1.2, + lineHeight = fontSize * 1.4, ) ) { RichText( diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageActionButton.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageActionButton.kt index ddf3467..35ab9d7 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageActionButton.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageActionButton.kt @@ -48,15 +48,16 @@ fun MessageActionButton( label: String, icon: ImageVector, onClick: () -> Unit, + modifier: Modifier = Modifier, enabled: Boolean = true ) { - val modifier = Modifier + val curModifier = modifier .padding(top = 4.dp) .clip(CircleShape) .background(if (enabled) MaterialTheme.colorScheme.secondaryContainer else MaterialTheme.colorScheme.surfaceContainerHigh) val alpha: Float = if (enabled) 1.0f else 0.3f Row( - modifier = if (enabled) modifier.clickable { onClick() } else modifier, + modifier = if (enabled) curModifier.clickable { onClick() } else modifier, verticalAlignment = Alignment.CenterVertically, ) { Icon( diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageBodyBenchmarkLlm.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageBodyBenchmarkLlm.kt index 3b25de1..dc87652 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageBodyBenchmarkLlm.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageBodyBenchmarkLlm.kt @@ -19,8 +19,8 @@ package com.google.aiedge.gallery.ui.common.chat import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.padding -import androidx.compose.foundation.layout.wrapContentWidth import androidx.compose.runtime.Composable import androidx.compose.ui.Modifier import androidx.compose.ui.tooling.preview.Preview @@ -33,16 +33,14 @@ import com.google.aiedge.gallery.ui.theme.GalleryTheme * This function renders benchmark statistics (e.g., various token speed) in data cards */ @Composable -fun MessageBodyBenchmarkLlm(message: ChatMessageBenchmarkLlmResult) { +fun MessageBodyBenchmarkLlm(message: ChatMessageBenchmarkLlmResult, modifier: Modifier = Modifier) { Column( - modifier = Modifier - .padding(12.dp) - .wrapContentWidth(), + modifier = modifier.padding(12.dp), verticalArrangement = Arrangement.spacedBy(8.dp) ) { // Data cards. Row( - modifier = Modifier.wrapContentWidth(), horizontalArrangement = Arrangement.spacedBy(16.dp) + modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.SpaceBetween ) { for (stat in message.orderedStats) { DataCard( diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageBodyPromptTemplates.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageBodyPromptTemplates.kt index 0547ea4..1a81f2b 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageBodyPromptTemplates.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageBodyPromptTemplates.kt @@ -82,7 +82,7 @@ fun MessageBodyPromptTemplates( style = MaterialTheme.typography.titleSmall, modifier = Modifier .fillMaxWidth() - .offset(y = -4.dp), + .offset(y = (-4).dp), textAlign = TextAlign.Center, ) } @@ -140,7 +140,7 @@ fun MessageBodyPromptTemplatesPreview() { for ((index, task) in ALL_PREVIEW_TASKS.withIndex()) { task.index = index for (model in task.models) { - model.preProcess(task = task) + model.preProcess() } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageInputText.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageInputText.kt index 2be0ecc..c550645 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageInputText.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageInputText.kt @@ -70,7 +70,6 @@ import com.google.aiedge.gallery.ui.theme.GalleryTheme * This function renders a row containing a text field for message input and a send button. * It handles message composition, input validation, and sending messages. */ -@OptIn(ExperimentalMaterial3Api::class) @Composable fun MessageInputText( modelManagerViewModel: ModelManagerViewModel, @@ -190,7 +189,7 @@ fun MessageInputText( Icons.AutoMirrored.Rounded.Send, contentDescription = "", modifier = Modifier.offset(x = 2.dp), - tint = if (inProgress) MaterialTheme.colorScheme.surfaceContainerHigh else MaterialTheme.colorScheme.primary + tint = MaterialTheme.colorScheme.primary ) } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ModelDownloadStatusInfoPanel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ModelDownloadStatusInfoPanel.kt new file mode 100644 index 0000000..37777ce --- /dev/null +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ModelDownloadStatusInfoPanel.kt @@ -0,0 +1,119 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.aiedge.gallery.ui.common.chat + +import androidx.compose.animation.AnimatedVisibility +import androidx.compose.animation.fadeIn +import androidx.compose.animation.fadeOut +import androidx.compose.animation.scaleIn +import androidx.compose.animation.scaleOut +import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.Box +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.fillMaxSize +import androidx.compose.runtime.Composable +import androidx.compose.runtime.LaunchedEffect +import androidx.compose.runtime.collectAsState +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember +import androidx.compose.runtime.setValue +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import com.google.aiedge.gallery.data.Model +import com.google.aiedge.gallery.data.ModelDownloadStatusType +import com.google.aiedge.gallery.data.Task +import com.google.aiedge.gallery.ui.common.DownloadAndTryButton +import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel +import kotlinx.coroutines.delay + +@Composable +fun ModelDownloadStatusInfoPanel( + model: Model, + task: Task, + modelManagerViewModel: ModelManagerViewModel +) { + val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() + + // Manages the conditional display of UI elements (download model button and downloading + // animation) based on the corresponding download status. + // + // It uses delayed visibility ensuring they are shown only after a short delay if their + // respective conditions remain true. This prevents UI flickering and provides a smoother + // user experience. + val curStatus = modelManagerUiState.modelDownloadStatus[model.name] + var shouldShowDownloadingAnimation by remember { mutableStateOf(false) } + var downloadingAnimationConditionMet by remember { mutableStateOf(false) } + var shouldShowDownloadModelButton by remember { mutableStateOf(false) } + var downloadModelButtonConditionMet by remember { mutableStateOf(false) } + + downloadingAnimationConditionMet = + curStatus?.status == ModelDownloadStatusType.IN_PROGRESS || curStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED || curStatus?.status == ModelDownloadStatusType.UNZIPPING + downloadModelButtonConditionMet = + curStatus?.status == ModelDownloadStatusType.FAILED || curStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED + + LaunchedEffect(downloadingAnimationConditionMet) { + if (downloadingAnimationConditionMet) { + delay(100) + shouldShowDownloadingAnimation = true + } else { + shouldShowDownloadingAnimation = false + } + } + + LaunchedEffect(downloadModelButtonConditionMet) { + if (downloadModelButtonConditionMet) { + delay(700) + shouldShowDownloadModelButton = true + } else { + shouldShowDownloadModelButton = false + } + } + + AnimatedVisibility( + visible = shouldShowDownloadingAnimation, + enter = scaleIn(initialScale = 0.9f) + fadeIn(), + exit = scaleOut(targetScale = 0.9f) + fadeOut() + ) { + Box( + modifier = Modifier.fillMaxSize(), contentAlignment = Alignment.Center + ) { + ModelDownloadingAnimation( + model = model, task = task, modelManagerViewModel = modelManagerViewModel + ) + } + } + + AnimatedVisibility( + visible = shouldShowDownloadModelButton, enter = fadeIn(), exit = fadeOut() + ) { + Column( + modifier = Modifier.fillMaxSize(), + verticalArrangement = Arrangement.Center, + horizontalAlignment = Alignment.CenterHorizontally + ) { + DownloadAndTryButton( + task = task, + model = model, + enabled = true, + needToDownloadFirst = true, + modelManagerViewModel = modelManagerViewModel, + onClicked = {} + ) + } + } +} \ No newline at end of file diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ModelDownloadingAnimation.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ModelDownloadingAnimation.kt index cdfbda6..155974c 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ModelDownloadingAnimation.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ModelDownloadingAnimation.kt @@ -24,6 +24,7 @@ import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.height import androidx.compose.foundation.layout.offset import androidx.compose.foundation.layout.padding @@ -32,23 +33,40 @@ import androidx.compose.foundation.layout.width import androidx.compose.foundation.lazy.grid.GridCells import androidx.compose.foundation.lazy.grid.LazyVerticalGrid import androidx.compose.foundation.lazy.grid.itemsIndexed +import androidx.compose.material3.LinearProgressIndicator import androidx.compose.material3.MaterialTheme import androidx.compose.material3.Text import androidx.compose.runtime.Composable import androidx.compose.runtime.LaunchedEffect +import androidx.compose.runtime.collectAsState +import androidx.compose.runtime.derivedStateOf +import androidx.compose.runtime.getValue import androidx.compose.runtime.remember import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.graphics.ColorFilter import androidx.compose.ui.graphics.graphicsLayer import androidx.compose.ui.layout.ContentScale +import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.res.painterResource import androidx.compose.ui.text.style.TextAlign +import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.unit.dp +import androidx.compose.ui.unit.sp import com.google.aiedge.gallery.R +import com.google.aiedge.gallery.data.Model +import com.google.aiedge.gallery.data.ModelDownloadStatusType +import com.google.aiedge.gallery.data.Task +import com.google.aiedge.gallery.ui.common.formatToHourMinSecond import com.google.aiedge.gallery.ui.common.getTaskIconColor +import com.google.aiedge.gallery.ui.common.humanReadableSize +import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel +import com.google.aiedge.gallery.ui.preview.MODEL_TEST1 +import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel +import com.google.aiedge.gallery.ui.preview.TASK_TEST1 import com.google.aiedge.gallery.ui.theme.GalleryTheme +import com.google.aiedge.gallery.ui.theme.labelSmallNarrow import kotlinx.coroutines.delay import kotlin.math.cos import kotlin.math.pow @@ -66,8 +84,19 @@ private const val END_SCALE = 0.6f * scaling and rotation effect. */ @Composable -fun ModelDownloadingAnimation() { +fun ModelDownloadingAnimation( + model: Model, + task: Task, + modelManagerViewModel: ModelManagerViewModel +) { val scale = remember { Animatable(END_SCALE) } + val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() + val downloadStatus by remember { + derivedStateOf { modelManagerUiState.modelDownloadStatus[model.name] } + } + val inProgress = downloadStatus?.status == ModelDownloadStatusType.IN_PROGRESS + val isPartiallyDownloaded = downloadStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED + var curDownloadProgress = 0f LaunchedEffect(Unit) { // Run this once while (true) { @@ -93,67 +122,156 @@ fun ModelDownloadingAnimation() { } } - Column( - horizontalAlignment = Alignment.CenterHorizontally, - modifier = Modifier.offset(y = -GRID_SIZE / 8) - ) { - LazyVerticalGrid( - columns = GridCells.Fixed(2), - horizontalArrangement = Arrangement.spacedBy(GRID_SPACING), - verticalArrangement = Arrangement.spacedBy(GRID_SPACING), - modifier = Modifier - .width(GRID_SIZE) - .height(GRID_SIZE) - ) { - itemsIndexed( - listOf( - R.drawable.pantegon, - R.drawable.double_circle, - R.drawable.circle, - R.drawable.four_circle - ) - ) { index, imageResource -> - val currentScale = - if (index == 0 || index == 3) scale.value else START_SCALE + END_SCALE - scale.value - Box( - modifier = Modifier - .width((GRID_SIZE - GRID_SPACING) / 2) - .height((GRID_SIZE - GRID_SPACING) / 2), - contentAlignment = when (index) { - 0 -> Alignment.BottomEnd - 1 -> Alignment.BottomStart - 2 -> Alignment.TopEnd - 3 -> Alignment.TopStart - else -> Alignment.Center - } - ) { - Image( - painter = painterResource(id = imageResource), - contentDescription = "", - contentScale = ContentScale.Fit, - colorFilter = ColorFilter.tint(getTaskIconColor(index = index)), - modifier = Modifier - .graphicsLayer { - scaleX = currentScale - scaleY = currentScale - rotationZ = currentScale * 120 - alpha = 0.8f - } - .size(70.dp) + // Failure message. + val curDownloadStatus = downloadStatus + if (curDownloadStatus != null && curDownloadStatus.status == ModelDownloadStatusType.FAILED) { + Row(verticalAlignment = Alignment.CenterVertically) { + Text( + curDownloadStatus.errorMessage, + color = MaterialTheme.colorScheme.error, + style = labelSmallNarrow, + overflow = TextOverflow.Ellipsis, + ) + } + } + // No failure + else { + Column( + horizontalAlignment = Alignment.CenterHorizontally, + modifier = Modifier.offset(y = -GRID_SIZE / 8) + ) { + LazyVerticalGrid( + columns = GridCells.Fixed(2), + horizontalArrangement = Arrangement.spacedBy(GRID_SPACING), + verticalArrangement = Arrangement.spacedBy(GRID_SPACING), + modifier = Modifier + .width(GRID_SIZE) + .height(GRID_SIZE) + ) { + itemsIndexed( + listOf( + R.drawable.pantegon, + R.drawable.double_circle, + R.drawable.circle, + R.drawable.four_circle ) + ) { index, imageResource -> + val currentScale = + if (index == 0 || index == 3) scale.value else START_SCALE + END_SCALE - scale.value + + Box( + modifier = Modifier + .width((GRID_SIZE - GRID_SPACING) / 2) + .height((GRID_SIZE - GRID_SPACING) / 2), + contentAlignment = when (index) { + 0 -> Alignment.BottomEnd + 1 -> Alignment.BottomStart + 2 -> Alignment.TopEnd + 3 -> Alignment.TopStart + else -> Alignment.Center + } + ) { + Image( + painter = painterResource(id = imageResource), + contentDescription = "", + contentScale = ContentScale.Fit, + colorFilter = ColorFilter.tint(getTaskIconColor(index = index)), + modifier = Modifier + .graphicsLayer { + scaleX = currentScale + scaleY = currentScale + rotationZ = currentScale * 120 + alpha = 0.8f + } + .size(70.dp) + ) + } } } - } - Text( - "Feel free to switch apps or lock your device.\n" - + "The download will continue in the background.\n" - + "We'll send a notification when it's done.", - style = MaterialTheme.typography.bodyMedium, - textAlign = TextAlign.Center - ) + + // Download stats + var sizeLabel = model.totalBytes.humanReadableSize() + if (curDownloadStatus != null) { + // For in-progress model, show {receivedSize} / {totalSize} - {rate} - {remainingTime} + if (inProgress || isPartiallyDownloaded) { + var totalSize = curDownloadStatus.totalBytes + if (totalSize == 0L) { + totalSize = model.totalBytes + } + sizeLabel = + "${curDownloadStatus.receivedBytes.humanReadableSize(extraDecimalForGbAndAbove = true)} of ${totalSize.humanReadableSize()}" + if (curDownloadStatus.bytesPerSecond > 0) { + sizeLabel = + "$sizeLabel · ${curDownloadStatus.bytesPerSecond.humanReadableSize()} / s" + if (curDownloadStatus.remainingMs >= 0) { + sizeLabel = + "$sizeLabel · ${curDownloadStatus.remainingMs.formatToHourMinSecond()} left" + } + } + if (isPartiallyDownloaded) { + sizeLabel = "$sizeLabel (resuming...)" + } + curDownloadProgress = + curDownloadStatus.receivedBytes.toFloat() / curDownloadStatus.totalBytes.toFloat() + if (curDownloadProgress.isNaN()) { + curDownloadProgress = 0f + } + } + // Status for unzipping. + else if (curDownloadStatus.status == ModelDownloadStatusType.UNZIPPING) { + sizeLabel = "Unzipping..." + } + Text( + sizeLabel, + color = MaterialTheme.colorScheme.secondary, + style = labelSmallNarrow.copy(fontSize = 9.sp, lineHeight = 10.sp), + textAlign = TextAlign.Center, + overflow = TextOverflow.Visible, + modifier = Modifier + .padding(bottom = 4.dp) + ) + } + + // Download progress. + if (inProgress || isPartiallyDownloaded) { + val animatedProgress = remember { Animatable(0f) } + LinearProgressIndicator( + progress = { animatedProgress.value }, + color = getTaskIconColor(task = task), + trackColor = MaterialTheme.colorScheme.surfaceContainerHighest, + modifier = Modifier + .fillMaxWidth() + .padding(bottom = 36.dp) + .padding(horizontal = 36.dp) + ) + LaunchedEffect(curDownloadProgress) { + animatedProgress.animateTo(curDownloadProgress, animationSpec = tween(150)) + } + } + // Unzipping progress. + else if (downloadStatus?.status == ModelDownloadStatusType.UNZIPPING) { + LinearProgressIndicator( + color = getTaskIconColor(task = task), + trackColor = MaterialTheme.colorScheme.surfaceContainerHighest, + modifier = Modifier + .fillMaxWidth() + .padding(bottom = 36.dp) + .padding(horizontal = 36.dp) + ) + } + + Text( + "Feel free to switch apps or lock your device.\n" + + "The download will continue in the background.\n" + + "We'll send a notification when it's done.", + style = MaterialTheme.typography.bodyMedium, + textAlign = TextAlign.Center + ) + } } + } // Custom Easing function for a multi-bounce effect @@ -168,9 +286,15 @@ fun multiBounceEasing(bounces: Int, decay: Float): Easing = Easing { x -> @Preview(showBackground = true) @Composable fun ModelDownloadingAnimationPreview() { + val context = LocalContext.current + GalleryTheme { Row(modifier = Modifier.padding(16.dp)) { - ModelDownloadingAnimation() + ModelDownloadingAnimation( + model = MODEL_TEST1, + task = TASK_TEST1, + modelManagerViewModel = PreviewModelManagerViewModel(context = context) + ) } } } \ No newline at end of file diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ModelSelector.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ModelSelector.kt index b8853ff..46b059c 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ModelSelector.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ModelSelector.kt @@ -63,7 +63,8 @@ fun ModelSelector( ) { Box( modifier = Modifier - .fillMaxWidth().padding(bottom = 8.dp), + .fillMaxWidth() + .padding(bottom = 8.dp), contentAlignment = Alignment.Center ) { // Model row. @@ -134,7 +135,12 @@ fun ModelSelector( // Force to re-initialize the model with the new configs. if (needReinitialization) { - modelManagerViewModel.initializeModel(context = context, model = model, force = true) + modelManagerViewModel.initializeModel( + context = context, + task = task, + model = model, + force = true + ) } // Notify. diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelNameAndStatus.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelNameAndStatus.kt index 07cc6a0..182d016 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelNameAndStatus.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelNameAndStatus.kt @@ -181,7 +181,5 @@ fun ModelNameAndStatus( .padding(top = 2.dp), ) } - - } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/StatusIcon.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/StatusIcon.kt index cb94e8b..9c775a8 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/StatusIcon.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/StatusIcon.kt @@ -23,8 +23,10 @@ import androidx.compose.foundation.layout.size import androidx.compose.material.icons.Icons import androidx.compose.material.icons.automirrored.outlined.HelpOutline import androidx.compose.material.icons.filled.DownloadForOffline +import androidx.compose.material.icons.rounded.Downloading import androidx.compose.material.icons.rounded.Error import androidx.compose.material3.Icon +import androidx.compose.material3.MaterialTheme import androidx.compose.runtime.Composable import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier @@ -34,6 +36,7 @@ import androidx.compose.ui.unit.dp import com.google.aiedge.gallery.data.ModelDownloadStatus import com.google.aiedge.gallery.data.ModelDownloadStatusType import com.google.aiedge.gallery.ui.theme.GalleryTheme +import com.google.aiedge.gallery.ui.theme.customColors /** * Composable function to display an icon representing the download status of a model. @@ -56,7 +59,7 @@ fun StatusIcon(downloadStatus: ModelDownloadStatus?, modifier: Modifier = Modifi ModelDownloadStatusType.SUCCEEDED -> { Icon( Icons.Filled.DownloadForOffline, - tint = Color(0xff3d860b), + tint = MaterialTheme.customColors.successColor, contentDescription = "", modifier = Modifier.size(14.dp) ) @@ -69,6 +72,12 @@ fun StatusIcon(downloadStatus: ModelDownloadStatus?, modifier: Modifier = Modifi modifier = Modifier.size(14.dp) ) + ModelDownloadStatusType.IN_PROGRESS -> Icon( + Icons.Rounded.Downloading, + contentDescription = "", + modifier = Modifier.size(14.dp) + ) + else -> {} } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/HomeScreen.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/HomeScreen.kt index 5b476fe..1e21e1e 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/HomeScreen.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/HomeScreen.kt @@ -73,7 +73,6 @@ import androidx.compose.ui.Modifier import androidx.compose.ui.draw.alpha import androidx.compose.ui.draw.clip import androidx.compose.ui.draw.scale -import androidx.compose.ui.focus.focusModifier import androidx.compose.ui.graphics.Brush import androidx.compose.ui.input.nestedscroll.nestedScroll import androidx.compose.ui.layout.layout @@ -91,7 +90,6 @@ import com.google.aiedge.gallery.data.AppBarAction import com.google.aiedge.gallery.data.AppBarActionType import com.google.aiedge.gallery.data.ConfigKey import com.google.aiedge.gallery.data.ImportedModelInfo -import com.google.aiedge.gallery.data.TASK_LLM_CHAT import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.ui.common.TaskIcon import com.google.aiedge.gallery.ui.common.getTaskBgColor @@ -275,7 +273,6 @@ fun HomeScreen( onDismiss = { showImportingDialog = false }, onDone = { modelManagerViewModel.addImportedLlmModel( - task = TASK_LLM_CHAT, info = it, ) showImportingDialog = false diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt index fe2be1b..d33bdad 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt @@ -30,7 +30,7 @@ private const val TAG = "AGLlmChatModelHelper" typealias ResultListener = (partialResult: String, done: Boolean) -> Unit typealias CleanUpListener = () -> Unit -data class LlmModelInstance(val engine: LlmInference, val session: LlmInferenceSession) +data class LlmModelInstance(val engine: LlmInference, var session: LlmInferenceSession) object LlmChatModelHelper { // Indexed by model name. @@ -74,6 +74,24 @@ object LlmChatModelHelper { onDone("") } + fun resetSession(model: Model) { + val instance = model.instance as LlmModelInstance? ?: return + val session = instance.session + session.close() + + val inference = instance.engine + val topK = model.getIntConfigValue(key = ConfigKey.TOPK, defaultValue = DEFAULT_TOPK) + val topP = model.getFloatConfigValue(key = ConfigKey.TOPP, defaultValue = DEFAULT_TOPP) + val temperature = + model.getFloatConfigValue(key = ConfigKey.TEMPERATURE, defaultValue = DEFAULT_TEMPERATURE) + val newSession = LlmInferenceSession.createFromOptions( + inference, + LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP) + .setTemperature(temperature).build() + ) + instance.session = newSession + } + fun cleanUp(model: Model) { if (model.instance == null) { return @@ -99,7 +117,11 @@ object LlmChatModelHelper { input: String, resultListener: ResultListener, cleanUpListener: CleanUpListener, + singleTurn: Boolean = false, ) { + if (singleTurn) { + resetSession(model = model) + } val instance = model.instance as LlmModelInstance // Set listener. diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatViewModel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatViewModel.kt index 59862be..d2936d2 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatViewModel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatViewModel.kt @@ -32,7 +32,7 @@ import kotlinx.coroutines.launch private const val TAG = "AGLlmChatViewModel" private val STATS = listOf( - Stat(id = "time_to_first_token", label = "Time to 1st token", unit = "sec"), + Stat(id = "time_to_first_token", label = "1st token", unit = "sec"), Stat(id = "prefill_speed", label = "Prefill speed", unit = "tokens/s"), Stat(id = "decode_speed", label = "Decode speed", unit = "tokens/s"), Stat(id = "latency", label = "Latency", unit = "sec") diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt new file mode 100644 index 0000000..326e284 --- /dev/null +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt @@ -0,0 +1,206 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.aiedge.gallery.ui.llmsingleturn + +import android.util.Log +import androidx.activity.compose.BackHandler +import androidx.compose.animation.AnimatedVisibility +import androidx.compose.animation.fadeIn +import androidx.compose.animation.fadeOut +import androidx.compose.animation.scaleIn +import androidx.compose.animation.scaleOut +import androidx.compose.foundation.background +import androidx.compose.foundation.layout.Box +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.calculateStartPadding +import androidx.compose.foundation.layout.fillMaxSize +import androidx.compose.foundation.layout.offset +import androidx.compose.foundation.layout.padding +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Scaffold +import androidx.compose.runtime.Composable +import androidx.compose.runtime.LaunchedEffect +import androidx.compose.runtime.collectAsState +import androidx.compose.runtime.getValue +import androidx.compose.runtime.rememberCoroutineScope +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.platform.LocalContext +import androidx.compose.ui.platform.LocalLayoutDirection +import androidx.compose.ui.tooling.preview.Preview +import androidx.lifecycle.viewmodel.compose.viewModel +import com.google.aiedge.gallery.data.ModelDownloadStatusType +import com.google.aiedge.gallery.ui.ViewModelProvider +import com.google.aiedge.gallery.ui.common.ModelPageAppBar +import com.google.aiedge.gallery.ui.common.chat.ModelDownloadStatusInfoPanel +import com.google.aiedge.gallery.ui.common.chat.ModelInitializationStatusChip +import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType +import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel +import com.google.aiedge.gallery.ui.preview.PreviewLlmSingleTurnViewModel +import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel +import com.google.aiedge.gallery.ui.theme.GalleryTheme +import com.google.aiedge.gallery.ui.theme.customColors +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.launch +import kotlinx.serialization.Serializable + +/** Navigation destination data */ +object LlmSingleTurnDestination { + @Serializable + val route = "LlmSingleTurnRoute" +} + +private const val TAG = "AGLlmSingleTurnScreen" + +@Composable +fun LlmSingleTurnScreen( + modelManagerViewModel: ModelManagerViewModel, + navigateUp: () -> Unit, + modifier: Modifier = Modifier, + viewModel: LlmSingleTurnViewModel = viewModel( + factory = ViewModelProvider.Factory + ), +) { + val task = viewModel.task + val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() + val selectedModel = modelManagerUiState.selectedModel + val scope = rememberCoroutineScope() + val context = LocalContext.current + + val handleNavigateUp = { + navigateUp() + + // clean up all models. + scope.launch(Dispatchers.Default) { + for (model in task.models) { + modelManagerViewModel.cleanupModel(task = task, model = model) + } + } + } + + // Handle system's edge swipe. + BackHandler { + handleNavigateUp() + } + + // Initialize model when model/download state changes. + val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name] + LaunchedEffect(curDownloadStatus, selectedModel.name) { + if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) { + Log.d( + TAG, + "Initializing model '${selectedModel.name}' from LlmsingleTurnScreen launched effect" + ) + modelManagerViewModel.initializeModel(context, task = task, model = selectedModel) + } + } + + val modelInitializationStatus = + modelManagerUiState.modelInitializationStatus[selectedModel.name] + + Scaffold(modifier = modifier, topBar = { + ModelPageAppBar( + task = task, + model = selectedModel, + modelManagerViewModel = modelManagerViewModel, + onConfigChanged = { _, _ -> }, + onBackClicked = { handleNavigateUp() }, + onModelSelected = { newSelectedModel -> + scope.launch(Dispatchers.Default) { + // Clean up current model. + modelManagerViewModel.cleanupModel(task = task, model = selectedModel) + + // Update selected model. + modelManagerViewModel.selectModel(model = newSelectedModel) + } + } + ) + }) { innerPadding -> + Column( + modifier = Modifier.padding( + top = innerPadding.calculateTopPadding(), + start = innerPadding.calculateStartPadding(LocalLayoutDirection.current), + end = innerPadding.calculateStartPadding(LocalLayoutDirection.current), + ) + ) { + ModelDownloadStatusInfoPanel( + model = selectedModel, + task = task, + modelManagerViewModel = modelManagerViewModel + ) + + // Main UI after model is downloaded. + if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) { + Box( + contentAlignment = Alignment.BottomCenter, + modifier = Modifier.weight(1f) + ) { + VerticalSplitView(modifier = Modifier.fillMaxSize(), + topView = { + PromptTemplatesPanel( + model = selectedModel, + viewModel = viewModel, + onSend = { fullPrompt -> + viewModel.generateResponse(model = selectedModel, input = fullPrompt) + }, modifier = Modifier.fillMaxSize() + ) + }, + bottomView = { + Box( + contentAlignment = Alignment.BottomCenter, + modifier = Modifier + .fillMaxSize() + .background(MaterialTheme.customColors.agentBubbleBgColor) + ) { + ResponsePanel( + model = selectedModel, + viewModel = viewModel, + modifier = Modifier + .fillMaxSize() + .padding(bottom = innerPadding.calculateBottomPadding()) + ) + } + }) + + // Model initialization in-progress message. + this@Column.AnimatedVisibility( + visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING, + enter = scaleIn() + fadeIn(), + exit = scaleOut() + fadeOut(), + modifier = Modifier.offset(y = -innerPadding.calculateBottomPadding()) + ) { + ModelInitializationStatusChip() + } + + } + } + } + } +} + +@Preview(showBackground = true) +@Composable +fun LlmSingleTurnScreenPreview() { + val context = LocalContext.current + GalleryTheme { + LlmSingleTurnScreen( + modelManagerViewModel = PreviewModelManagerViewModel(context = context), + viewModel = PreviewLlmSingleTurnViewModel(), + navigateUp = {}, + ) + } +} diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt new file mode 100644 index 0000000..fe5ffed --- /dev/null +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt @@ -0,0 +1,210 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.aiedge.gallery.ui.llmsingleturn + +import android.util.Log +import androidx.lifecycle.ViewModel +import androidx.lifecycle.viewModelScope +import com.google.aiedge.gallery.data.Model +import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN +import com.google.aiedge.gallery.data.Task +import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult +import com.google.aiedge.gallery.ui.common.chat.Stat +import com.google.aiedge.gallery.ui.common.processLlmResponse +import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper +import com.google.aiedge.gallery.ui.llmchat.LlmModelInstance +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.flow.update +import kotlinx.coroutines.launch + +private const val TAG = "AGLlmSingleTurnViewModel" + +data class LlmSingleTurnUiState( + /** + * Indicates whether the runtime is currently processing a message. + */ + val inProgress: Boolean = false, + + /** + * Indicates whether the model is currently being initialized. + */ + val initializing: Boolean = false, + + // model ->