From 42d442389dccfcacf1a6368ce311722e1bc41753 Mon Sep 17 00:00:00 2001 From: Jing Jin <8752427+jinjingforever@users.noreply.github.com> Date: Wed, 16 Apr 2025 14:51:45 -0700 Subject: [PATCH] - Add "show stats" after each LLM response to see token related stats more easily. - Back to multi-turn (no exception handling yet). - Show app's version in App Info screen. - Show API doc link and source code link in the model list screen. - Jump to model info page in HF when clicking "learn more". --- Android/src/app/build.gradle.kts | 3 +- .../java/com/google/aiedge/gallery/Version.kt | 19 ---- .../com/google/aiedge/gallery/data/Model.kt | 12 +-- .../com/google/aiedge/gallery/data/Tasks.kt | 14 ++- .../gallery/ui/common/chat/ChatMessage.kt | 6 +- .../gallery/ui/common/chat/ChatPanel.kt | 98 +++++++++++++++---- .../gallery/ui/common/chat/ChatViewModel.kt | 76 +++++++++++++- .../gallery/ui/common/chat/MessageSender.kt | 2 +- .../gallery/ui/common/modelitem/ModelItem.kt | 34 +------ .../aiedge/gallery/ui/home/SettingsDialog.kt | 4 +- .../gallery/ui/llmchat/LlmChatModelHelper.kt | 51 +++------- .../gallery/ui/llmchat/LlmChatViewModel.kt | 50 +++++++++- .../gallery/ui/modelmanager/ModelList.kt | 87 +++++++++++++++- .../gallery/ui/modelmanager/ModelManager.kt | 41 -------- .../google/aiedge/gallery/ui/theme/Theme.kt | 27 ++--- .../src/app/src/main/res/values/strings.xml | 2 +- 16 files changed, 343 insertions(+), 183 deletions(-) delete mode 100644 Android/src/app/src/main/java/com/google/aiedge/gallery/Version.kt diff --git a/Android/src/app/build.gradle.kts b/Android/src/app/build.gradle.kts index d06b8d7..dfe361f 100644 --- a/Android/src/app/build.gradle.kts +++ b/Android/src/app/build.gradle.kts @@ -30,7 +30,7 @@ android { minSdk = 24 targetSdk = 35 versionCode = 1 - versionName = "1.0" + versionName = "20250416" // Needed for HuggingFace auth workflows. manifestPlaceholders["appAuthRedirectScheme"] = "com.google.aiedge.gallery.oauth" @@ -58,6 +58,7 @@ android { } buildFeatures { compose = true + buildConfig = true } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/Version.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/Version.kt deleted file mode 100644 index 7924e7a..0000000 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/Version.kt +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.aiedge.gallery - -const val VERSION = "20250413" \ No newline at end of file diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt index a99ccb8..54c8a4e 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt @@ -218,9 +218,6 @@ const val IMAGE_CLASSIFICATION_LEARN_MORE_URL = "https://ai.google.dev/edge/lite const val LLM_CHAT_INFO = "Some description about this large language model. A community org for developers to discover models that are ready for deployment to edge platforms" -const val LLM_CHAT_LEARN_MORE_URL = - "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android" - const val IMAGE_GENERATION_INFO = "Powered by [MediaPipe Image Generation API](https://ai.google.dev/edge/mediapipe/solutions/vision/image_generator/android)" @@ -234,7 +231,7 @@ val MODEL_LLM_GEMMA_2B_GPU_INT4: Model = Model( sizeInBytes = 1354301440L, configs = createLlmChatConfigs(), info = LLM_CHAT_INFO, - learnMoreUrl = LLM_CHAT_LEARN_MORE_URL, + learnMoreUrl = "https://huggingface.co/litert-community", ) val MODEL_LLM_GEMMA_2_2B_GPU_INT8: Model = Model( @@ -244,7 +241,7 @@ val MODEL_LLM_GEMMA_2_2B_GPU_INT8: Model = Model( sizeInBytes = 2627141632L, configs = createLlmChatConfigs(), info = LLM_CHAT_INFO, - learnMoreUrl = LLM_CHAT_LEARN_MORE_URL, + learnMoreUrl = "https://huggingface.co/litert-community", ) val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model( @@ -254,7 +251,7 @@ val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model( sizeInBytes = 554661243L, configs = createLlmChatConfigs(defaultTopK = 64, defaultTopP = 0.95f), info = LLM_CHAT_INFO, - learnMoreUrl = LLM_CHAT_LEARN_MORE_URL, + learnMoreUrl = "https://huggingface.co/litert-community/Gemma3-1B-IT", llmPromptTemplates = listOf( PromptTemplate( title = "Emoji Fun", @@ -277,7 +274,7 @@ val MODEL_LLM_DEEPSEEK: Model = Model( llmBackend = LlmBackend.CPU, configs = createLlmChatConfigs(defaultTemperature = 0.6f, defaultTopK = 40, defaultTopP = 0.7f), info = LLM_CHAT_INFO, - learnMoreUrl = LLM_CHAT_LEARN_MORE_URL, + learnMoreUrl = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B", ) val MODEL_TEXT_CLASSIFICATION_MOBILEBERT: Model = Model( @@ -343,6 +340,7 @@ val MODEL_IMAGE_GENERATION_STABLE_DIFFUSION: Model = Model( showBenchmarkButton = false, info = IMAGE_GENERATION_INFO, configs = IMAGE_GENERATION_CONFIGS, + learnMoreUrl = "https://huggingface.co/litert-community", ) val EMPTY_MODEL: Model = Model( diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Tasks.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Tasks.kt index f96dd8a..5b7c9b3 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Tasks.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Tasks.kt @@ -50,6 +50,12 @@ data class Task( /** Description of the task. */ val description: String, + /** Documentation url for the task. */ + val docUrl: String = "", + + /** Source code url for the model-related functions. */ + val sourceCodeUrl: String = "", + /** Placeholder text for the name of the agent shown above chat messages. */ @StringRes val agentNameRes: Int = R.string.chat_generic_agent_name, @@ -80,6 +86,8 @@ val TASK_LLM_CHAT = Task( iconVectorResourceId = R.drawable.chat_spark, models = MODELS_LLM_CHAT, description = "Chat? with a on-device large language model", + docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", + sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt", textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat ) @@ -88,13 +96,15 @@ val TASK_IMAGE_GENERATION = Task( iconVectorResourceId = R.drawable.image_spark, models = MODELS_IMAGE_GENERATION, description = "Generate images from text", + docUrl = "https://ai.google.dev/edge/mediapipe/solutions/vision/image_generator/android", + sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/imagegeneration/ImageGenerationModelHelper.kt", textInputPlaceHolderRes = R.string.text_image_generation_text_field_placeholder ) /** All tasks. */ val TASKS: List = listOf( - TASK_TEXT_CLASSIFICATION, - TASK_IMAGE_CLASSIFICATION, +// TASK_TEXT_CLASSIFICATION, +// TASK_IMAGE_CLASSIFICATION, TASK_IMAGE_GENERATION, TASK_LLM_CHAT, ) diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatMessage.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatMessage.kt index c47b7ab..17938e1 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatMessage.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatMessage.kt @@ -71,13 +71,17 @@ open class ChatMessageText( // Negative numbers will hide the latency display. override val latencyMs: Float = 0f, val isMarkdown: Boolean = true, + + // Benchmark result for LLM response. + var llmBenchmarkResult: ChatMessageBenchmarkLlmResult? = null, ) : ChatMessage(type = ChatMessageType.TEXT, side = side, latencyMs = latencyMs) { override fun clone(): ChatMessageText { return ChatMessageText( content = content, side = side, latencyMs = latencyMs, - isMarkdown = isMarkdown + isMarkdown = isMarkdown, + llmBenchmarkResult = llmBenchmarkResult, ) } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt index 8211b7c..1ef6f68 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt @@ -150,6 +150,8 @@ fun ChatPanel( lastMessageContent.value = tmpLastMessage.content } } + val lastShowingStatsByModel: MutableState>> = + remember { mutableStateOf(mapOf()) } // Scroll the content to the bottom when any of these changes. LaunchedEffect( @@ -158,9 +160,27 @@ fun ChatPanel( lastMessageContent.value, WindowInsets.ime.getBottom(density), ) { + // Only scroll if showingStatsByModel is not changed. In other words, when showingStatsByModel + // changes we want the display to not scroll. if (messages.isNotEmpty()) { - listState.animateScrollToItem(messages.lastIndex, scrollOffset = 10000) + if (uiState.showingStatsByModel === lastShowingStatsByModel.value) { + listState.animateScrollToItem(messages.lastIndex, scrollOffset = 10000) + } else { + // Scroll to bottom if the message to show stats is the last message. + val curShowingStats = + uiState.showingStatsByModel[selectedModel.name]?.toMutableSet() ?: mutableSetOf() + val lastShowingStats = lastShowingStatsByModel.value[selectedModel.name] ?: mutableSetOf() + curShowingStats.removeAll(lastShowingStats) + if (curShowingStats.isNotEmpty()) { + val index = + viewModel.getMessageIndex(model = selectedModel, message = curShowingStats.first()) + if (index == messages.size - 2) { + listState.animateScrollToItem(messages.lastIndex, scrollOffset = 10000) + } + } + } } + lastShowingStatsByModel.value = uiState.showingStatsByModel } val nestedScrollConnection = remember { @@ -309,7 +329,51 @@ fun ChatPanel( } } if (message.side == ChatSide.AGENT) { - LatencyText(message = message) + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(8.dp), + ) { + LatencyText(message = message) + // A button to show stats for the LLM message. + if (selectedModel.taskType == TaskType.LLM_CHAT && message is ChatMessageText + // This means we only want to show the action button when the message is done + // generating, at which point the latency will be set. + && message.latencyMs >= 0 + ) { + val showingStats = + viewModel.isShowingStats(model = selectedModel, message = message) + MessageActionButton( + label = if (showingStats) "Hide stats" else "Show stats", + icon = Icons.Outlined.Timer, + onClick = { + // Toggle showing stats. + viewModel.toggleShowingStats(selectedModel, message) + + // Add the stats message after the LLM message. + if (viewModel.isShowingStats(model = selectedModel, message = message)) { + val llmBenchmarkResult = message.llmBenchmarkResult + if (llmBenchmarkResult != null) { + viewModel.insertMessageAfter( + model = selectedModel, + anchorMessage = message, + messageToAdd = llmBenchmarkResult, + ) + } + } + // Remove the stats message. + else { + val curMessageIndex = + viewModel.getMessageIndex(model = selectedModel, message = message) + viewModel.removeMessageAt( + model = selectedModel, + index = curMessageIndex + 1 + ) + } + }, + enabled = !uiState.inProgress + ) + } + } } else if (message.side == ChatSide.USER) { Row( verticalAlignment = Alignment.CenterVertically, @@ -328,21 +392,21 @@ fun ChatPanel( } // Benchmark button - if (selectedModel.showBenchmarkButton) { - MessageActionButton( - label = stringResource(R.string.benchmark), - icon = Icons.Outlined.Timer, - onClick = { - if (selectedModel.taskType == TaskType.LLM_CHAT) { - onBenchmarkClicked(selectedModel, message, 0, 0) - } else { - showBenchmarkConfigsDialog = true - benchmarkMessage.value = message - } - }, - enabled = !uiState.inProgress - ) - } +// if (selectedModel.showBenchmarkButton) { +// MessageActionButton( +// label = stringResource(R.string.benchmark), +// icon = Icons.Outlined.Timer, +// onClick = { +// if (selectedModel.taskType == TaskType.LLM_CHAT) { +// onBenchmarkClicked(selectedModel, message, 0, 0) +// } else { +// showBenchmarkConfigsDialog = true +// benchmarkMessage.value = message +// } +// }, +// enabled = !uiState.inProgress +// ) +// } } } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatViewModel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatViewModel.kt index 0d3fde7..c5fbe2b 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatViewModel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatViewModel.kt @@ -43,6 +43,12 @@ data class ChatUiState( * A map of model names to the currently streaming chat message. */ val streamingMessagesByModel: Map = mapOf(), + + /* + * A map of model names to a map of chat messages to a boolean indicating whether the message is + * showing the stats below it. + */ + val showingStatsByModel: Map> = mapOf(), ) /** @@ -66,6 +72,31 @@ open class ChatViewModel(val task: Task) : ViewModel() { _uiState.update { _uiState.value.copy(messagesByModel = newMessagesByModel) } } + fun insertMessageAfter(model: Model, anchorMessage: ChatMessage, messageToAdd: ChatMessage) { + val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap() + val newMessages = newMessagesByModel[model.name]?.toMutableList() + if (newMessages != null) { + newMessagesByModel[model.name] = newMessages + // Find the index of the anchor message + val anchorIndex = newMessages.indexOf(anchorMessage) + if (anchorIndex != -1) { + // Insert the new message after the anchor message + newMessages.add(anchorIndex + 1, messageToAdd) + } + } + _uiState.update { _uiState.value.copy(messagesByModel = newMessagesByModel) } + } + + fun removeMessageAt(model: Model, index: Int) { + val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap() + val newMessages = newMessagesByModel[model.name]?.toMutableList() + if (newMessages != null) { + newMessagesByModel[model.name] = newMessages + newMessages.removeAt(index) + } + _uiState.update { _uiState.value.copy(messagesByModel = newMessagesByModel) } + } + fun removeLastMessage(model: Model) { val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap() val newMessages = newMessagesByModel[model.name]?.toMutableList() ?: mutableListOf() @@ -80,7 +111,7 @@ open class ChatViewModel(val task: Task) : ViewModel() { return (_uiState.value.messagesByModel[model.name] ?: listOf()).lastOrNull() } - fun updateLastMessageContentIncrementally( + fun updateLastTextMessageContentIncrementally( model: Model, partialContent: String, latencyMs: Float, @@ -124,6 +155,25 @@ open class ChatViewModel(val task: Task) : ViewModel() { _uiState.update { newUiState } } + fun updateLastTextMessageLlmBenchmarkResult( + model: Model, + llmBenchmarkResult: ChatMessageBenchmarkLlmResult + ) { + val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap() + val newMessages = newMessagesByModel[model.name]?.toMutableList() ?: mutableListOf() + if (newMessages.size > 0) { + val lastMessage = newMessages.last() + if (lastMessage is ChatMessageText) { + lastMessage.llmBenchmarkResult = llmBenchmarkResult + newMessages.removeAt(newMessages.size - 1) + newMessages.add(lastMessage) + } + } + newMessagesByModel[model.name] = newMessages + val newUiState = _uiState.value.copy(messagesByModel = newMessagesByModel) + _uiState.update { newUiState } + } + fun replaceLastMessage(model: Model, message: ChatMessage, type: ChatMessageType) { val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap() val newMessages = newMessagesByModel[model.name]?.toMutableList() ?: mutableListOf() @@ -159,10 +209,6 @@ open class ChatViewModel(val task: Task) : ViewModel() { _uiState.update { _uiState.value.copy(inProgress = inProgress) } } - fun isInProgress(): Boolean { - return _uiState.value.inProgress - } - fun addConfigChangedMessage( oldConfigValues: Map, newConfigValues: Map, model: Model ) { @@ -173,6 +219,26 @@ open class ChatViewModel(val task: Task) : ViewModel() { addMessage(message = message, model = model) } + fun getMessageIndex(model: Model, message: ChatMessage): Int { + return (_uiState.value.messagesByModel[model.name] ?: listOf()).indexOf(message) + } + + fun isShowingStats(model: Model, message: ChatMessage): Boolean { + return _uiState.value.showingStatsByModel[model.name]?.contains(message) ?: false + } + + fun toggleShowingStats(model: Model, message: ChatMessage) { + val newShowingStatsByModel = _uiState.value.showingStatsByModel.toMutableMap() + val newShowingStats = newShowingStatsByModel[model.name]?.toMutableSet() ?: mutableSetOf() + if (newShowingStats.contains(message)) { + newShowingStats.remove(message) + } else { + newShowingStats.add(message) + } + newShowingStatsByModel[model.name] = newShowingStats + _uiState.update { _uiState.value.copy(showingStatsByModel = newShowingStatsByModel) } + } + private fun createUiState(task: Task): ChatUiState { val messagesByModel: MutableMap> = mutableMapOf() for (model in task.models) { diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageSender.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageSender.kt index 56cb84f..3bf4106 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageSender.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageSender.kt @@ -179,7 +179,7 @@ private fun getMessageLayoutConfig( is ChatMessageBenchmarkLlmResult -> { horizontalArrangement = Arrangement.SpaceBetween modifier = modifier.fillMaxWidth() - userLabel = "Benchmark" + userLabel = "Stats" } is ChatMessageImageWithHistory -> { diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItem.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItem.kt index 97491fd..96f59c5 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItem.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItem.kt @@ -38,6 +38,7 @@ import androidx.compose.material.icons.Icons import androidx.compose.material.icons.rounded.Settings import androidx.compose.material.icons.rounded.UnfoldLess import androidx.compose.material.icons.rounded.UnfoldMore +import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.Icon import androidx.compose.material3.IconButton import androidx.compose.material3.OutlinedButton @@ -90,6 +91,7 @@ private val DEFAULT_VERTICAL_PADDING = 16.dp * model description and buttons for learning more (opening a URL) and downloading/trying * the model. */ +@OptIn(ExperimentalMaterial3Api::class) @Composable fun ModelItem( model: Model, @@ -217,7 +219,7 @@ fun ModelItem( .then(m), horizontalArrangement = Arrangement.spacedBy(12.dp), ) { - // The "learn more" button. Click to show url in default browser. + // The "learn more" button. Click to show related urls in a bottom sheet. if (model.learnMoreUrl.isNotEmpty()) { OutlinedButton( onClick = { @@ -241,35 +243,6 @@ fun ModelItem( modelManagerViewModel = modelManagerViewModel, onClicked = { onModelClicked(model) } ) -// Button( -// onClick = { -// if (isExpanded) { -// onModelClicked(model) -// if (needToDownloadFirst) { -// scope.launch { -// delay(80) -// checkNotificationPermissonAndStartDownload( -// context = context, -// launcher = launcher, -// modelManagerViewModel = modelManagerViewModel, -// model = model -// ) -// } -// } -// } -// }, -// ) { -// Icon( -// Icons.AutoMirrored.Rounded.ArrowForward, -// contentDescription = "", -// modifier = Modifier.padding(end = 4.dp) -// ) -// if (needToDownloadFirst) { -// Text("Download & Try it", maxLines = 1) -// } else { -// Text("Try it", maxLines = 1) -// } -// } } } } @@ -366,7 +339,6 @@ fun ModelItem( } } } - } @Preview(showBackground = true) diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/SettingsDialog.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/SettingsDialog.kt index 1852d71..5d111fc 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/SettingsDialog.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/SettingsDialog.kt @@ -17,7 +17,7 @@ package com.google.aiedge.gallery.ui.home import androidx.compose.runtime.Composable -import com.google.aiedge.gallery.VERSION +import com.google.aiedge.gallery.BuildConfig import com.google.aiedge.gallery.data.Config import com.google.aiedge.gallery.data.ConfigKey import com.google.aiedge.gallery.data.SegmentedButtonConfig @@ -45,7 +45,7 @@ fun SettingsDialog( ) ConfigDialog( title = "Settings", - subtitle = "App version: $VERSION", + subtitle = "App version: ${BuildConfig.VERSION_NAME}", okBtnLabel = "OK", configs = CONFIGS, initialValues = initialValues, diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt index 9d3b4f7..15cfe30 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt @@ -18,12 +18,11 @@ package com.google.aiedge.gallery.ui.llmchat import android.content.Context import android.util.Log -import com.google.common.util.concurrent.ListenableFuture -import com.google.mediapipe.tasks.genai.llminference.LlmInference -import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession import com.google.aiedge.gallery.data.ConfigKey import com.google.aiedge.gallery.data.LlmBackend import com.google.aiedge.gallery.data.Model +import com.google.mediapipe.tasks.genai.llminference.LlmInference +import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession private const val TAG = "AGLlmChatModelHelper" private const val DEFAULT_MAX_TOKEN = 1024 @@ -39,8 +38,6 @@ data class LlmModelInstance(val engine: LlmInference, val session: LlmInferenceS object LlmChatModelHelper { // Indexed by model name. private val cleanUpListeners: MutableMap = mutableMapOf() - private val generateResponseListenableFutures: MutableMap> = - mutableMapOf() fun initialize( context: Context, model: Model, onDone: () -> Unit @@ -64,13 +61,12 @@ object LlmChatModelHelper { try { val llmInference = LlmInference.createFromOptions(context, options) -// val session = LlmInferenceSession.createFromOptions( -// llmInference, -// LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP) -// .setTemperature(temperature).build() -// ) - model.instance = llmInference -// LlmModelInstance(engine = llmInference, session = session) + val session = LlmInferenceSession.createFromOptions( + llmInference, + LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP) + .setTemperature(temperature).build() + ) + model.instance = LlmModelInstance(engine = llmInference, session = session) } catch (e: Exception) { e.printStackTrace() } @@ -82,11 +78,10 @@ object LlmChatModelHelper { return } - val instance = model.instance as LlmInference + val instance = model.instance as LlmModelInstance try { - instance.close() -// instance.session.close() -// instance.engine.close() + instance.session.close() + instance.engine.close() } catch (e: Exception) { // ignore } @@ -104,7 +99,7 @@ object LlmChatModelHelper { resultListener: ResultListener, cleanUpListener: CleanUpListener, ) { - val instance = model.instance as LlmInference + val instance = model.instance as LlmModelInstance // Set listener. if (!cleanUpListeners.containsKey(model.name)) { @@ -112,24 +107,8 @@ object LlmChatModelHelper { } // Start async inference. - val future = instance.generateResponseAsync(input, resultListener) - generateResponseListenableFutures[model.name] = future - -// val session = instance.session -// TODO: need to count token and reset session. -// session.addQueryChunk(input) -// session.generateResponseAsync(resultListener) - } - - fun stopInference(model: Model) { - val instance = model.instance as LlmInference - if (instance != null) { - instance.close() - } -// val future = generateResponseListenableFutures[model.name] -// if (future != null) { -// future.cancel(true) -// generateResponseListenableFutures.remove(model.name) -// } + val session = instance.session + session.addQueryChunk(input) + session.generateResponseAsync(resultListener) } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatViewModel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatViewModel.kt index 324c254..59862be 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatViewModel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatViewModel.kt @@ -17,7 +17,6 @@ package com.google.aiedge.gallery.ui.llmchat import androidx.lifecycle.viewModelScope -import com.google.mediapipe.tasks.genai.llminference.LlmInference import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.TASK_LLM_CHAT import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult @@ -56,11 +55,31 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) { } // Run inference. + val instance = model.instance as LlmModelInstance + val prefillTokens = instance.session.sizeInTokens(input) + + var firstRun = true + var timeToFirstToken = 0f + var firstTokenTs = 0L + var decodeTokens = 0 + var prefillSpeed = 0f + var decodeSpeed: Float val start = System.currentTimeMillis() LlmChatModelHelper.runInference( model = model, input = input, resultListener = { partialResult, done -> + val curTs = System.currentTimeMillis() + + if (firstRun) { + firstTokenTs = System.currentTimeMillis() + timeToFirstToken = (firstTokenTs - start) / 1000f + prefillSpeed = prefillTokens / timeToFirstToken + firstRun = false + } else { + decodeTokens++ + } + // Remove the last message if it is a "loading" message. // This will only be done once. val lastMessage = getLastMessage(model = model) @@ -76,7 +95,7 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) { // Incrementally update the streamed partial results. val latencyMs: Long = if (done) System.currentTimeMillis() - start else -1 - updateLastMessageContentIncrementally( + updateLastTextMessageContentIncrementally( model = model, partialContent = partialResult, latencyMs = latencyMs.toFloat() @@ -84,6 +103,29 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) { if (done) { setInProgress(false) + + decodeSpeed = + decodeTokens / ((curTs - firstTokenTs) / 1000f) + if (decodeSpeed.isNaN()) { + decodeSpeed = 0f + } + + if (lastMessage is ChatMessageText) { + updateLastTextMessageLlmBenchmarkResult( + model = model, llmBenchmarkResult = + ChatMessageBenchmarkLlmResult( + orderedStats = STATS, + statValues = mutableMapOf( + "prefill_speed" to prefillSpeed, + "decode_speed" to decodeSpeed, + "time_to_first_token" to timeToFirstToken, + "latency" to (curTs - start).toFloat() / 1000f, + ), + running = false, + latencyMs = -1f, + ) + ) + } } }, cleanUpListener = { setInProgress(false) @@ -117,8 +159,8 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) { while (model.instance == null) { delay(100) } - val instance = model.instance as LlmInference - val prefillTokens = instance.sizeInTokens(message.content) + val instance = model.instance as LlmModelInstance + val prefillTokens = instance.session.sizeInTokens(message.content) // Add the message to show benchmark results. val benchmarkLlmResult = ChatMessageBenchmarkLlmResult( diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelList.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelList.kt index 59ed91c..abcb0ee 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelList.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelList.kt @@ -16,21 +16,34 @@ package com.google.aiedge.gallery.ui.modelmanager -import androidx.compose.foundation.ExperimentalFoundationApi +import androidx.compose.foundation.clickable import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Box +import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.PaddingValues +import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.size import androidx.compose.foundation.lazy.LazyColumn import androidx.compose.foundation.lazy.items +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.outlined.Code +import androidx.compose.material.icons.outlined.Description +import androidx.compose.material3.Icon import androidx.compose.material3.MaterialTheme import androidx.compose.material3.Text import androidx.compose.runtime.Composable +import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier +import androidx.compose.ui.graphics.vector.ImageVector import androidx.compose.ui.platform.LocalContext +import androidx.compose.ui.platform.LocalUriHandler +import androidx.compose.ui.text.AnnotatedString +import androidx.compose.ui.text.SpanStyle import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.text.style.TextAlign +import androidx.compose.ui.text.style.TextDecoration import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.unit.dp import com.google.aiedge.gallery.data.Model @@ -39,9 +52,9 @@ import com.google.aiedge.gallery.ui.common.modelitem.ModelItem import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel import com.google.aiedge.gallery.ui.preview.TASK_TEST1 import com.google.aiedge.gallery.ui.theme.GalleryTheme +import com.google.aiedge.gallery.ui.theme.customColors /** The list of models in the model manager. */ -@OptIn(ExperimentalFoundationApi::class) @Composable fun ModelList( task: Task, @@ -62,11 +75,40 @@ fun ModelList( textAlign = TextAlign.Center, style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold), modifier = Modifier - .padding(bottom = 20.dp) .fillMaxWidth() ) } + // URLs. + item(key = "urls") { + Row( + horizontalArrangement = Arrangement.Center, + modifier = Modifier + .fillMaxWidth() + .padding(top = 12.dp, bottom = 16.dp), + ) { + Column( + horizontalAlignment = Alignment.Start, + verticalArrangement = Arrangement.spacedBy(4.dp), + ) { + if (task.docUrl.isNotEmpty()) { + ClickableLink( + url = task.docUrl, + linkText = "API Documentation", + icon = Icons.Outlined.Description + ) + } + if (task.sourceCodeUrl.isNotEmpty()) { + ClickableLink( + url = task.sourceCodeUrl, + linkText = "Example code", + icon = Icons.Outlined.Code + ) + } + } + } + } + // List of models within a task. items(items = task.models) { model -> Box { @@ -82,6 +124,45 @@ fun ModelList( } } +@Composable +fun ClickableLink( + url: String, + linkText: String, + icon: ImageVector, +) { + val uriHandler = LocalUriHandler.current + val annotatedText = AnnotatedString( + text = linkText, + spanStyles = listOf( + AnnotatedString.Range( + item = SpanStyle( + color = MaterialTheme.customColors.linkColor, + textDecoration = TextDecoration.Underline + ), + start = 0, + end = linkText.length + ) + ) + ) + + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.Center, + ) { + Icon(icon, contentDescription = "", modifier = Modifier.size(16.dp)) + Text( + text = annotatedText, + textAlign = TextAlign.Center, + style = MaterialTheme.typography.bodySmall, + modifier = Modifier + .padding(start = 6.dp) + .clickable { + uriHandler.openUri(url) + }, + ) + } +} + @Preview(showBackground = true) @Composable fun ModelListPreview() { diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManager.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManager.kt index 88eaabe..3a563b3 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManager.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManager.kt @@ -21,9 +21,6 @@ import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.Scaffold import androidx.compose.runtime.Composable -import androidx.compose.runtime.collectAsState -import androidx.compose.runtime.getValue -import androidx.compose.runtime.rememberCoroutineScope import androidx.compose.ui.Modifier import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.tooling.preview.Preview @@ -31,9 +28,7 @@ import com.google.aiedge.gallery.GalleryTopAppBar import com.google.aiedge.gallery.data.AppBarAction import com.google.aiedge.gallery.data.AppBarActionType import com.google.aiedge.gallery.data.Model -import com.google.aiedge.gallery.data.ModelDownloadStatusType import com.google.aiedge.gallery.data.Task -import com.google.aiedge.gallery.data.getModelByName import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel import com.google.aiedge.gallery.ui.preview.TASK_TEST1 import com.google.aiedge.gallery.ui.theme.GalleryTheme @@ -48,9 +43,6 @@ fun ModelManager( onModelClicked: (Model) -> Unit, modifier: Modifier = Modifier, ) { - val uiState by viewModel.uiState.collectAsState() - val coroutineScope = rememberCoroutineScope() - // Set title based on the task. var title = "${task.type.label} model" if (task.models.size != 1) { @@ -67,27 +59,7 @@ fun ModelManager( topBar = { GalleryTopAppBar( title = title, -// subtitle = String.format( -// stringResource(R.string.downloaded_size), -// totalSizeInBytes.humanReadableSize() -// ), - - // Refresh model list button at the left side of the app bar. -// leftAction = AppBarAction(actionType = if (uiState.loadingHfModels) { -// AppBarActionType.REFRESHING_MODELS -// } else { -// AppBarActionType.REFRESH_MODELS -// }, actionFn = { -// coroutineScope.launch(Dispatchers.IO) { -// viewModel.loadHfModels() -// } -// }), leftAction = AppBarAction(actionType = AppBarActionType.NAVIGATE_UP, actionFn = navigateUp) - - // "Done" button at the right side of the app bar to navigate up. -// rightAction = AppBarAction( -// actionType = AppBarActionType.NAVIGATE_UP, actionFn = navigateUp -// ), ) }, ) { innerPadding -> @@ -101,19 +73,6 @@ fun ModelManager( } } -private fun getTotalDownloadedFileSize(uiState: ModelManagerUiState): Long { - var totalSizeInBytes = 0L - for ((name, status) in uiState.modelDownloadStatus.entries) { - if (status.status == ModelDownloadStatusType.SUCCEEDED) { - totalSizeInBytes += getModelByName(name)?.totalBytes ?: 0L - } else if (status.status == ModelDownloadStatusType.IN_PROGRESS) { - totalSizeInBytes += status.receivedBytes - } - } - return totalSizeInBytes -} - - @Preview @Composable fun ModelManagerPreview() { diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/theme/Theme.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/theme/Theme.kt index d223a8f..a5f6081 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/theme/Theme.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/theme/Theme.kt @@ -115,26 +115,27 @@ data class CustomColors( val homeBottomGradient: List = listOf(), val userBubbleBgColor: Color = Color.Transparent, val agentBubbleBgColor: Color = Color.Transparent, + val linkColor: Color = Color.Transparent, ) val LocalCustomColors = staticCompositionLocalOf { CustomColors() } val lightCustomColors = CustomColors( taskBgColors = listOf( + // green + Color(0xFFE1F6DE), + // blue + Color(0xFFEDF0FF), // yellow Color(0xFFFFEFC9), // red Color(0xFFFFEDE6), - // green - Color(0xFFE1F6DE), - // blue - Color(0xFFEDF0FF) ), taskIconColors = listOf( - Color(0xFFE37400), - Color(0xFFD93025), Color(0xFF34A853), Color(0xFF1967D2), + Color(0xFFE37400), + Color(0xFFD93025), ), taskIconShapeBgColor = Color.White, homeBottomGradient = listOf( @@ -143,24 +144,25 @@ val lightCustomColors = CustomColors( ), agentBubbleBgColor = Color(0xFFe9eef6), userBubbleBgColor = Color(0xFF32628D), + linkColor = Color(0xFF32628D), ) val darkCustomColors = CustomColors( taskBgColors = listOf( + // green + Color(0xFF2E312D), + // blue + Color(0xFF303033), // yellow Color(0xFF33302A), // red Color(0xFF362F2D), - // green - Color(0xFF2E312D), - // blue - Color(0xFF303033) ), taskIconColors = listOf( - Color(0xFFFFB955), - Color(0xFFFFB4AB), Color(0xFF6DD58C), Color(0xFFAAC7FF), + Color(0xFFFFB955), + Color(0xFFFFB4AB), ), taskIconShapeBgColor = Color(0xFF202124), homeBottomGradient = listOf( @@ -169,6 +171,7 @@ val darkCustomColors = CustomColors( ), agentBubbleBgColor = Color(0xFF1b1c1d), userBubbleBgColor = Color(0xFF1f3760), + linkColor = Color(0xFF9DCAFC), ) val MaterialTheme.customColors: CustomColors diff --git a/Android/src/app/src/main/res/values/strings.xml b/Android/src/app/src/main/res/values/strings.xml index bba336b..7b0242d 100644 --- a/Android/src/app/src/main/res/values/strings.xml +++ b/Android/src/app/src/main/res/values/strings.xml @@ -35,7 +35,7 @@ Initializing model… Type movie review to classify… Type prompt… - Type prompt (one shot)… + Type prompt… Run again Run benchmark warming up…