mirror of
https://github.com/google-ai-edge/gallery.git
synced 2025-07-06 06:30:30 -04:00
- Add "show stats" after each LLM response to see token related stats more easily.
- Back to multi-turn (no exception handling yet). - Show app's version in App Info screen. - Show API doc link and source code link in the model list screen. - Jump to model info page in HF when clicking "learn more".
This commit is contained in:
parent
ea31fd0544
commit
42d442389d
16 changed files with 343 additions and 183 deletions
|
@ -30,7 +30,7 @@ android {
|
||||||
minSdk = 24
|
minSdk = 24
|
||||||
targetSdk = 35
|
targetSdk = 35
|
||||||
versionCode = 1
|
versionCode = 1
|
||||||
versionName = "1.0"
|
versionName = "20250416"
|
||||||
|
|
||||||
// Needed for HuggingFace auth workflows.
|
// Needed for HuggingFace auth workflows.
|
||||||
manifestPlaceholders["appAuthRedirectScheme"] = "com.google.aiedge.gallery.oauth"
|
manifestPlaceholders["appAuthRedirectScheme"] = "com.google.aiedge.gallery.oauth"
|
||||||
|
@ -58,6 +58,7 @@ android {
|
||||||
}
|
}
|
||||||
buildFeatures {
|
buildFeatures {
|
||||||
compose = true
|
compose = true
|
||||||
|
buildConfig = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,19 +0,0 @@
|
||||||
/*
|
|
||||||
* Copyright 2025 Google LLC
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package com.google.aiedge.gallery
|
|
||||||
|
|
||||||
const val VERSION = "20250413"
|
|
|
@ -218,9 +218,6 @@ const val IMAGE_CLASSIFICATION_LEARN_MORE_URL = "https://ai.google.dev/edge/lite
|
||||||
const val LLM_CHAT_INFO =
|
const val LLM_CHAT_INFO =
|
||||||
"Some description about this large language model. A community org for developers to discover models that are ready for deployment to edge platforms"
|
"Some description about this large language model. A community org for developers to discover models that are ready for deployment to edge platforms"
|
||||||
|
|
||||||
const val LLM_CHAT_LEARN_MORE_URL =
|
|
||||||
"https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android"
|
|
||||||
|
|
||||||
const val IMAGE_GENERATION_INFO =
|
const val IMAGE_GENERATION_INFO =
|
||||||
"Powered by [MediaPipe Image Generation API](https://ai.google.dev/edge/mediapipe/solutions/vision/image_generator/android)"
|
"Powered by [MediaPipe Image Generation API](https://ai.google.dev/edge/mediapipe/solutions/vision/image_generator/android)"
|
||||||
|
|
||||||
|
@ -234,7 +231,7 @@ val MODEL_LLM_GEMMA_2B_GPU_INT4: Model = Model(
|
||||||
sizeInBytes = 1354301440L,
|
sizeInBytes = 1354301440L,
|
||||||
configs = createLlmChatConfigs(),
|
configs = createLlmChatConfigs(),
|
||||||
info = LLM_CHAT_INFO,
|
info = LLM_CHAT_INFO,
|
||||||
learnMoreUrl = LLM_CHAT_LEARN_MORE_URL,
|
learnMoreUrl = "https://huggingface.co/litert-community",
|
||||||
)
|
)
|
||||||
|
|
||||||
val MODEL_LLM_GEMMA_2_2B_GPU_INT8: Model = Model(
|
val MODEL_LLM_GEMMA_2_2B_GPU_INT8: Model = Model(
|
||||||
|
@ -244,7 +241,7 @@ val MODEL_LLM_GEMMA_2_2B_GPU_INT8: Model = Model(
|
||||||
sizeInBytes = 2627141632L,
|
sizeInBytes = 2627141632L,
|
||||||
configs = createLlmChatConfigs(),
|
configs = createLlmChatConfigs(),
|
||||||
info = LLM_CHAT_INFO,
|
info = LLM_CHAT_INFO,
|
||||||
learnMoreUrl = LLM_CHAT_LEARN_MORE_URL,
|
learnMoreUrl = "https://huggingface.co/litert-community",
|
||||||
)
|
)
|
||||||
|
|
||||||
val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model(
|
val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model(
|
||||||
|
@ -254,7 +251,7 @@ val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model(
|
||||||
sizeInBytes = 554661243L,
|
sizeInBytes = 554661243L,
|
||||||
configs = createLlmChatConfigs(defaultTopK = 64, defaultTopP = 0.95f),
|
configs = createLlmChatConfigs(defaultTopK = 64, defaultTopP = 0.95f),
|
||||||
info = LLM_CHAT_INFO,
|
info = LLM_CHAT_INFO,
|
||||||
learnMoreUrl = LLM_CHAT_LEARN_MORE_URL,
|
learnMoreUrl = "https://huggingface.co/litert-community/Gemma3-1B-IT",
|
||||||
llmPromptTemplates = listOf(
|
llmPromptTemplates = listOf(
|
||||||
PromptTemplate(
|
PromptTemplate(
|
||||||
title = "Emoji Fun",
|
title = "Emoji Fun",
|
||||||
|
@ -277,7 +274,7 @@ val MODEL_LLM_DEEPSEEK: Model = Model(
|
||||||
llmBackend = LlmBackend.CPU,
|
llmBackend = LlmBackend.CPU,
|
||||||
configs = createLlmChatConfigs(defaultTemperature = 0.6f, defaultTopK = 40, defaultTopP = 0.7f),
|
configs = createLlmChatConfigs(defaultTemperature = 0.6f, defaultTopK = 40, defaultTopP = 0.7f),
|
||||||
info = LLM_CHAT_INFO,
|
info = LLM_CHAT_INFO,
|
||||||
learnMoreUrl = LLM_CHAT_LEARN_MORE_URL,
|
learnMoreUrl = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B",
|
||||||
)
|
)
|
||||||
|
|
||||||
val MODEL_TEXT_CLASSIFICATION_MOBILEBERT: Model = Model(
|
val MODEL_TEXT_CLASSIFICATION_MOBILEBERT: Model = Model(
|
||||||
|
@ -343,6 +340,7 @@ val MODEL_IMAGE_GENERATION_STABLE_DIFFUSION: Model = Model(
|
||||||
showBenchmarkButton = false,
|
showBenchmarkButton = false,
|
||||||
info = IMAGE_GENERATION_INFO,
|
info = IMAGE_GENERATION_INFO,
|
||||||
configs = IMAGE_GENERATION_CONFIGS,
|
configs = IMAGE_GENERATION_CONFIGS,
|
||||||
|
learnMoreUrl = "https://huggingface.co/litert-community",
|
||||||
)
|
)
|
||||||
|
|
||||||
val EMPTY_MODEL: Model = Model(
|
val EMPTY_MODEL: Model = Model(
|
||||||
|
|
|
@ -50,6 +50,12 @@ data class Task(
|
||||||
/** Description of the task. */
|
/** Description of the task. */
|
||||||
val description: String,
|
val description: String,
|
||||||
|
|
||||||
|
/** Documentation url for the task. */
|
||||||
|
val docUrl: String = "",
|
||||||
|
|
||||||
|
/** Source code url for the model-related functions. */
|
||||||
|
val sourceCodeUrl: String = "",
|
||||||
|
|
||||||
/** Placeholder text for the name of the agent shown above chat messages. */
|
/** Placeholder text for the name of the agent shown above chat messages. */
|
||||||
@StringRes val agentNameRes: Int = R.string.chat_generic_agent_name,
|
@StringRes val agentNameRes: Int = R.string.chat_generic_agent_name,
|
||||||
|
|
||||||
|
@ -80,6 +86,8 @@ val TASK_LLM_CHAT = Task(
|
||||||
iconVectorResourceId = R.drawable.chat_spark,
|
iconVectorResourceId = R.drawable.chat_spark,
|
||||||
models = MODELS_LLM_CHAT,
|
models = MODELS_LLM_CHAT,
|
||||||
description = "Chat? with a on-device large language model",
|
description = "Chat? with a on-device large language model",
|
||||||
|
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
||||||
|
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
|
||||||
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
|
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -88,13 +96,15 @@ val TASK_IMAGE_GENERATION = Task(
|
||||||
iconVectorResourceId = R.drawable.image_spark,
|
iconVectorResourceId = R.drawable.image_spark,
|
||||||
models = MODELS_IMAGE_GENERATION,
|
models = MODELS_IMAGE_GENERATION,
|
||||||
description = "Generate images from text",
|
description = "Generate images from text",
|
||||||
|
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/vision/image_generator/android",
|
||||||
|
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/imagegeneration/ImageGenerationModelHelper.kt",
|
||||||
textInputPlaceHolderRes = R.string.text_image_generation_text_field_placeholder
|
textInputPlaceHolderRes = R.string.text_image_generation_text_field_placeholder
|
||||||
)
|
)
|
||||||
|
|
||||||
/** All tasks. */
|
/** All tasks. */
|
||||||
val TASKS: List<Task> = listOf(
|
val TASKS: List<Task> = listOf(
|
||||||
TASK_TEXT_CLASSIFICATION,
|
// TASK_TEXT_CLASSIFICATION,
|
||||||
TASK_IMAGE_CLASSIFICATION,
|
// TASK_IMAGE_CLASSIFICATION,
|
||||||
TASK_IMAGE_GENERATION,
|
TASK_IMAGE_GENERATION,
|
||||||
TASK_LLM_CHAT,
|
TASK_LLM_CHAT,
|
||||||
)
|
)
|
||||||
|
|
|
@ -71,13 +71,17 @@ open class ChatMessageText(
|
||||||
// Negative numbers will hide the latency display.
|
// Negative numbers will hide the latency display.
|
||||||
override val latencyMs: Float = 0f,
|
override val latencyMs: Float = 0f,
|
||||||
val isMarkdown: Boolean = true,
|
val isMarkdown: Boolean = true,
|
||||||
|
|
||||||
|
// Benchmark result for LLM response.
|
||||||
|
var llmBenchmarkResult: ChatMessageBenchmarkLlmResult? = null,
|
||||||
) : ChatMessage(type = ChatMessageType.TEXT, side = side, latencyMs = latencyMs) {
|
) : ChatMessage(type = ChatMessageType.TEXT, side = side, latencyMs = latencyMs) {
|
||||||
override fun clone(): ChatMessageText {
|
override fun clone(): ChatMessageText {
|
||||||
return ChatMessageText(
|
return ChatMessageText(
|
||||||
content = content,
|
content = content,
|
||||||
side = side,
|
side = side,
|
||||||
latencyMs = latencyMs,
|
latencyMs = latencyMs,
|
||||||
isMarkdown = isMarkdown
|
isMarkdown = isMarkdown,
|
||||||
|
llmBenchmarkResult = llmBenchmarkResult,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -150,6 +150,8 @@ fun ChatPanel(
|
||||||
lastMessageContent.value = tmpLastMessage.content
|
lastMessageContent.value = tmpLastMessage.content
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
val lastShowingStatsByModel: MutableState<Map<String, MutableSet<ChatMessage>>> =
|
||||||
|
remember { mutableStateOf(mapOf()) }
|
||||||
|
|
||||||
// Scroll the content to the bottom when any of these changes.
|
// Scroll the content to the bottom when any of these changes.
|
||||||
LaunchedEffect(
|
LaunchedEffect(
|
||||||
|
@ -158,9 +160,27 @@ fun ChatPanel(
|
||||||
lastMessageContent.value,
|
lastMessageContent.value,
|
||||||
WindowInsets.ime.getBottom(density),
|
WindowInsets.ime.getBottom(density),
|
||||||
) {
|
) {
|
||||||
|
// Only scroll if showingStatsByModel is not changed. In other words, when showingStatsByModel
|
||||||
|
// changes we want the display to not scroll.
|
||||||
if (messages.isNotEmpty()) {
|
if (messages.isNotEmpty()) {
|
||||||
listState.animateScrollToItem(messages.lastIndex, scrollOffset = 10000)
|
if (uiState.showingStatsByModel === lastShowingStatsByModel.value) {
|
||||||
|
listState.animateScrollToItem(messages.lastIndex, scrollOffset = 10000)
|
||||||
|
} else {
|
||||||
|
// Scroll to bottom if the message to show stats is the last message.
|
||||||
|
val curShowingStats =
|
||||||
|
uiState.showingStatsByModel[selectedModel.name]?.toMutableSet() ?: mutableSetOf()
|
||||||
|
val lastShowingStats = lastShowingStatsByModel.value[selectedModel.name] ?: mutableSetOf()
|
||||||
|
curShowingStats.removeAll(lastShowingStats)
|
||||||
|
if (curShowingStats.isNotEmpty()) {
|
||||||
|
val index =
|
||||||
|
viewModel.getMessageIndex(model = selectedModel, message = curShowingStats.first())
|
||||||
|
if (index == messages.size - 2) {
|
||||||
|
listState.animateScrollToItem(messages.lastIndex, scrollOffset = 10000)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
lastShowingStatsByModel.value = uiState.showingStatsByModel
|
||||||
}
|
}
|
||||||
|
|
||||||
val nestedScrollConnection = remember {
|
val nestedScrollConnection = remember {
|
||||||
|
@ -309,7 +329,51 @@ fun ChatPanel(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (message.side == ChatSide.AGENT) {
|
if (message.side == ChatSide.AGENT) {
|
||||||
LatencyText(message = message)
|
Row(
|
||||||
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(8.dp),
|
||||||
|
) {
|
||||||
|
LatencyText(message = message)
|
||||||
|
// A button to show stats for the LLM message.
|
||||||
|
if (selectedModel.taskType == TaskType.LLM_CHAT && message is ChatMessageText
|
||||||
|
// This means we only want to show the action button when the message is done
|
||||||
|
// generating, at which point the latency will be set.
|
||||||
|
&& message.latencyMs >= 0
|
||||||
|
) {
|
||||||
|
val showingStats =
|
||||||
|
viewModel.isShowingStats(model = selectedModel, message = message)
|
||||||
|
MessageActionButton(
|
||||||
|
label = if (showingStats) "Hide stats" else "Show stats",
|
||||||
|
icon = Icons.Outlined.Timer,
|
||||||
|
onClick = {
|
||||||
|
// Toggle showing stats.
|
||||||
|
viewModel.toggleShowingStats(selectedModel, message)
|
||||||
|
|
||||||
|
// Add the stats message after the LLM message.
|
||||||
|
if (viewModel.isShowingStats(model = selectedModel, message = message)) {
|
||||||
|
val llmBenchmarkResult = message.llmBenchmarkResult
|
||||||
|
if (llmBenchmarkResult != null) {
|
||||||
|
viewModel.insertMessageAfter(
|
||||||
|
model = selectedModel,
|
||||||
|
anchorMessage = message,
|
||||||
|
messageToAdd = llmBenchmarkResult,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Remove the stats message.
|
||||||
|
else {
|
||||||
|
val curMessageIndex =
|
||||||
|
viewModel.getMessageIndex(model = selectedModel, message = message)
|
||||||
|
viewModel.removeMessageAt(
|
||||||
|
model = selectedModel,
|
||||||
|
index = curMessageIndex + 1
|
||||||
|
)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
enabled = !uiState.inProgress
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
} else if (message.side == ChatSide.USER) {
|
} else if (message.side == ChatSide.USER) {
|
||||||
Row(
|
Row(
|
||||||
verticalAlignment = Alignment.CenterVertically,
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
@ -328,21 +392,21 @@ fun ChatPanel(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Benchmark button
|
// Benchmark button
|
||||||
if (selectedModel.showBenchmarkButton) {
|
// if (selectedModel.showBenchmarkButton) {
|
||||||
MessageActionButton(
|
// MessageActionButton(
|
||||||
label = stringResource(R.string.benchmark),
|
// label = stringResource(R.string.benchmark),
|
||||||
icon = Icons.Outlined.Timer,
|
// icon = Icons.Outlined.Timer,
|
||||||
onClick = {
|
// onClick = {
|
||||||
if (selectedModel.taskType == TaskType.LLM_CHAT) {
|
// if (selectedModel.taskType == TaskType.LLM_CHAT) {
|
||||||
onBenchmarkClicked(selectedModel, message, 0, 0)
|
// onBenchmarkClicked(selectedModel, message, 0, 0)
|
||||||
} else {
|
// } else {
|
||||||
showBenchmarkConfigsDialog = true
|
// showBenchmarkConfigsDialog = true
|
||||||
benchmarkMessage.value = message
|
// benchmarkMessage.value = message
|
||||||
}
|
// }
|
||||||
},
|
// },
|
||||||
enabled = !uiState.inProgress
|
// enabled = !uiState.inProgress
|
||||||
)
|
// )
|
||||||
}
|
// }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,6 +43,12 @@ data class ChatUiState(
|
||||||
* A map of model names to the currently streaming chat message.
|
* A map of model names to the currently streaming chat message.
|
||||||
*/
|
*/
|
||||||
val streamingMessagesByModel: Map<String, ChatMessage> = mapOf(),
|
val streamingMessagesByModel: Map<String, ChatMessage> = mapOf(),
|
||||||
|
|
||||||
|
/*
|
||||||
|
* A map of model names to a map of chat messages to a boolean indicating whether the message is
|
||||||
|
* showing the stats below it.
|
||||||
|
*/
|
||||||
|
val showingStatsByModel: Map<String, MutableSet<ChatMessage>> = mapOf(),
|
||||||
)
|
)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -66,6 +72,31 @@ open class ChatViewModel(val task: Task) : ViewModel() {
|
||||||
_uiState.update { _uiState.value.copy(messagesByModel = newMessagesByModel) }
|
_uiState.update { _uiState.value.copy(messagesByModel = newMessagesByModel) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun insertMessageAfter(model: Model, anchorMessage: ChatMessage, messageToAdd: ChatMessage) {
|
||||||
|
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
|
||||||
|
val newMessages = newMessagesByModel[model.name]?.toMutableList()
|
||||||
|
if (newMessages != null) {
|
||||||
|
newMessagesByModel[model.name] = newMessages
|
||||||
|
// Find the index of the anchor message
|
||||||
|
val anchorIndex = newMessages.indexOf(anchorMessage)
|
||||||
|
if (anchorIndex != -1) {
|
||||||
|
// Insert the new message after the anchor message
|
||||||
|
newMessages.add(anchorIndex + 1, messageToAdd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_uiState.update { _uiState.value.copy(messagesByModel = newMessagesByModel) }
|
||||||
|
}
|
||||||
|
|
||||||
|
fun removeMessageAt(model: Model, index: Int) {
|
||||||
|
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
|
||||||
|
val newMessages = newMessagesByModel[model.name]?.toMutableList()
|
||||||
|
if (newMessages != null) {
|
||||||
|
newMessagesByModel[model.name] = newMessages
|
||||||
|
newMessages.removeAt(index)
|
||||||
|
}
|
||||||
|
_uiState.update { _uiState.value.copy(messagesByModel = newMessagesByModel) }
|
||||||
|
}
|
||||||
|
|
||||||
fun removeLastMessage(model: Model) {
|
fun removeLastMessage(model: Model) {
|
||||||
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
|
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
|
||||||
val newMessages = newMessagesByModel[model.name]?.toMutableList() ?: mutableListOf()
|
val newMessages = newMessagesByModel[model.name]?.toMutableList() ?: mutableListOf()
|
||||||
|
@ -80,7 +111,7 @@ open class ChatViewModel(val task: Task) : ViewModel() {
|
||||||
return (_uiState.value.messagesByModel[model.name] ?: listOf()).lastOrNull()
|
return (_uiState.value.messagesByModel[model.name] ?: listOf()).lastOrNull()
|
||||||
}
|
}
|
||||||
|
|
||||||
fun updateLastMessageContentIncrementally(
|
fun updateLastTextMessageContentIncrementally(
|
||||||
model: Model,
|
model: Model,
|
||||||
partialContent: String,
|
partialContent: String,
|
||||||
latencyMs: Float,
|
latencyMs: Float,
|
||||||
|
@ -124,6 +155,25 @@ open class ChatViewModel(val task: Task) : ViewModel() {
|
||||||
_uiState.update { newUiState }
|
_uiState.update { newUiState }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun updateLastTextMessageLlmBenchmarkResult(
|
||||||
|
model: Model,
|
||||||
|
llmBenchmarkResult: ChatMessageBenchmarkLlmResult
|
||||||
|
) {
|
||||||
|
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
|
||||||
|
val newMessages = newMessagesByModel[model.name]?.toMutableList() ?: mutableListOf()
|
||||||
|
if (newMessages.size > 0) {
|
||||||
|
val lastMessage = newMessages.last()
|
||||||
|
if (lastMessage is ChatMessageText) {
|
||||||
|
lastMessage.llmBenchmarkResult = llmBenchmarkResult
|
||||||
|
newMessages.removeAt(newMessages.size - 1)
|
||||||
|
newMessages.add(lastMessage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
newMessagesByModel[model.name] = newMessages
|
||||||
|
val newUiState = _uiState.value.copy(messagesByModel = newMessagesByModel)
|
||||||
|
_uiState.update { newUiState }
|
||||||
|
}
|
||||||
|
|
||||||
fun replaceLastMessage(model: Model, message: ChatMessage, type: ChatMessageType) {
|
fun replaceLastMessage(model: Model, message: ChatMessage, type: ChatMessageType) {
|
||||||
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
|
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
|
||||||
val newMessages = newMessagesByModel[model.name]?.toMutableList() ?: mutableListOf()
|
val newMessages = newMessagesByModel[model.name]?.toMutableList() ?: mutableListOf()
|
||||||
|
@ -159,10 +209,6 @@ open class ChatViewModel(val task: Task) : ViewModel() {
|
||||||
_uiState.update { _uiState.value.copy(inProgress = inProgress) }
|
_uiState.update { _uiState.value.copy(inProgress = inProgress) }
|
||||||
}
|
}
|
||||||
|
|
||||||
fun isInProgress(): Boolean {
|
|
||||||
return _uiState.value.inProgress
|
|
||||||
}
|
|
||||||
|
|
||||||
fun addConfigChangedMessage(
|
fun addConfigChangedMessage(
|
||||||
oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>, model: Model
|
oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>, model: Model
|
||||||
) {
|
) {
|
||||||
|
@ -173,6 +219,26 @@ open class ChatViewModel(val task: Task) : ViewModel() {
|
||||||
addMessage(message = message, model = model)
|
addMessage(message = message, model = model)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun getMessageIndex(model: Model, message: ChatMessage): Int {
|
||||||
|
return (_uiState.value.messagesByModel[model.name] ?: listOf()).indexOf(message)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun isShowingStats(model: Model, message: ChatMessage): Boolean {
|
||||||
|
return _uiState.value.showingStatsByModel[model.name]?.contains(message) ?: false
|
||||||
|
}
|
||||||
|
|
||||||
|
fun toggleShowingStats(model: Model, message: ChatMessage) {
|
||||||
|
val newShowingStatsByModel = _uiState.value.showingStatsByModel.toMutableMap()
|
||||||
|
val newShowingStats = newShowingStatsByModel[model.name]?.toMutableSet() ?: mutableSetOf()
|
||||||
|
if (newShowingStats.contains(message)) {
|
||||||
|
newShowingStats.remove(message)
|
||||||
|
} else {
|
||||||
|
newShowingStats.add(message)
|
||||||
|
}
|
||||||
|
newShowingStatsByModel[model.name] = newShowingStats
|
||||||
|
_uiState.update { _uiState.value.copy(showingStatsByModel = newShowingStatsByModel) }
|
||||||
|
}
|
||||||
|
|
||||||
private fun createUiState(task: Task): ChatUiState {
|
private fun createUiState(task: Task): ChatUiState {
|
||||||
val messagesByModel: MutableMap<String, MutableList<ChatMessage>> = mutableMapOf()
|
val messagesByModel: MutableMap<String, MutableList<ChatMessage>> = mutableMapOf()
|
||||||
for (model in task.models) {
|
for (model in task.models) {
|
||||||
|
|
|
@ -179,7 +179,7 @@ private fun getMessageLayoutConfig(
|
||||||
is ChatMessageBenchmarkLlmResult -> {
|
is ChatMessageBenchmarkLlmResult -> {
|
||||||
horizontalArrangement = Arrangement.SpaceBetween
|
horizontalArrangement = Arrangement.SpaceBetween
|
||||||
modifier = modifier.fillMaxWidth()
|
modifier = modifier.fillMaxWidth()
|
||||||
userLabel = "Benchmark"
|
userLabel = "Stats"
|
||||||
}
|
}
|
||||||
|
|
||||||
is ChatMessageImageWithHistory -> {
|
is ChatMessageImageWithHistory -> {
|
||||||
|
|
|
@ -38,6 +38,7 @@ import androidx.compose.material.icons.Icons
|
||||||
import androidx.compose.material.icons.rounded.Settings
|
import androidx.compose.material.icons.rounded.Settings
|
||||||
import androidx.compose.material.icons.rounded.UnfoldLess
|
import androidx.compose.material.icons.rounded.UnfoldLess
|
||||||
import androidx.compose.material.icons.rounded.UnfoldMore
|
import androidx.compose.material.icons.rounded.UnfoldMore
|
||||||
|
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||||
import androidx.compose.material3.Icon
|
import androidx.compose.material3.Icon
|
||||||
import androidx.compose.material3.IconButton
|
import androidx.compose.material3.IconButton
|
||||||
import androidx.compose.material3.OutlinedButton
|
import androidx.compose.material3.OutlinedButton
|
||||||
|
@ -90,6 +91,7 @@ private val DEFAULT_VERTICAL_PADDING = 16.dp
|
||||||
* model description and buttons for learning more (opening a URL) and downloading/trying
|
* model description and buttons for learning more (opening a URL) and downloading/trying
|
||||||
* the model.
|
* the model.
|
||||||
*/
|
*/
|
||||||
|
@OptIn(ExperimentalMaterial3Api::class)
|
||||||
@Composable
|
@Composable
|
||||||
fun ModelItem(
|
fun ModelItem(
|
||||||
model: Model,
|
model: Model,
|
||||||
|
@ -217,7 +219,7 @@ fun ModelItem(
|
||||||
.then(m),
|
.then(m),
|
||||||
horizontalArrangement = Arrangement.spacedBy(12.dp),
|
horizontalArrangement = Arrangement.spacedBy(12.dp),
|
||||||
) {
|
) {
|
||||||
// The "learn more" button. Click to show url in default browser.
|
// The "learn more" button. Click to show related urls in a bottom sheet.
|
||||||
if (model.learnMoreUrl.isNotEmpty()) {
|
if (model.learnMoreUrl.isNotEmpty()) {
|
||||||
OutlinedButton(
|
OutlinedButton(
|
||||||
onClick = {
|
onClick = {
|
||||||
|
@ -241,35 +243,6 @@ fun ModelItem(
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
onClicked = { onModelClicked(model) }
|
onClicked = { onModelClicked(model) }
|
||||||
)
|
)
|
||||||
// Button(
|
|
||||||
// onClick = {
|
|
||||||
// if (isExpanded) {
|
|
||||||
// onModelClicked(model)
|
|
||||||
// if (needToDownloadFirst) {
|
|
||||||
// scope.launch {
|
|
||||||
// delay(80)
|
|
||||||
// checkNotificationPermissonAndStartDownload(
|
|
||||||
// context = context,
|
|
||||||
// launcher = launcher,
|
|
||||||
// modelManagerViewModel = modelManagerViewModel,
|
|
||||||
// model = model
|
|
||||||
// )
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// },
|
|
||||||
// ) {
|
|
||||||
// Icon(
|
|
||||||
// Icons.AutoMirrored.Rounded.ArrowForward,
|
|
||||||
// contentDescription = "",
|
|
||||||
// modifier = Modifier.padding(end = 4.dp)
|
|
||||||
// )
|
|
||||||
// if (needToDownloadFirst) {
|
|
||||||
// Text("Download & Try it", maxLines = 1)
|
|
||||||
// } else {
|
|
||||||
// Text("Try it", maxLines = 1)
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -366,7 +339,6 @@ fun ModelItem(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Preview(showBackground = true)
|
@Preview(showBackground = true)
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
package com.google.aiedge.gallery.ui.home
|
package com.google.aiedge.gallery.ui.home
|
||||||
|
|
||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
import com.google.aiedge.gallery.VERSION
|
import com.google.aiedge.gallery.BuildConfig
|
||||||
import com.google.aiedge.gallery.data.Config
|
import com.google.aiedge.gallery.data.Config
|
||||||
import com.google.aiedge.gallery.data.ConfigKey
|
import com.google.aiedge.gallery.data.ConfigKey
|
||||||
import com.google.aiedge.gallery.data.SegmentedButtonConfig
|
import com.google.aiedge.gallery.data.SegmentedButtonConfig
|
||||||
|
@ -45,7 +45,7 @@ fun SettingsDialog(
|
||||||
)
|
)
|
||||||
ConfigDialog(
|
ConfigDialog(
|
||||||
title = "Settings",
|
title = "Settings",
|
||||||
subtitle = "App version: $VERSION",
|
subtitle = "App version: ${BuildConfig.VERSION_NAME}",
|
||||||
okBtnLabel = "OK",
|
okBtnLabel = "OK",
|
||||||
configs = CONFIGS,
|
configs = CONFIGS,
|
||||||
initialValues = initialValues,
|
initialValues = initialValues,
|
||||||
|
|
|
@ -18,12 +18,11 @@ package com.google.aiedge.gallery.ui.llmchat
|
||||||
|
|
||||||
import android.content.Context
|
import android.content.Context
|
||||||
import android.util.Log
|
import android.util.Log
|
||||||
import com.google.common.util.concurrent.ListenableFuture
|
|
||||||
import com.google.mediapipe.tasks.genai.llminference.LlmInference
|
|
||||||
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
|
|
||||||
import com.google.aiedge.gallery.data.ConfigKey
|
import com.google.aiedge.gallery.data.ConfigKey
|
||||||
import com.google.aiedge.gallery.data.LlmBackend
|
import com.google.aiedge.gallery.data.LlmBackend
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
|
import com.google.mediapipe.tasks.genai.llminference.LlmInference
|
||||||
|
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
|
||||||
|
|
||||||
private const val TAG = "AGLlmChatModelHelper"
|
private const val TAG = "AGLlmChatModelHelper"
|
||||||
private const val DEFAULT_MAX_TOKEN = 1024
|
private const val DEFAULT_MAX_TOKEN = 1024
|
||||||
|
@ -39,8 +38,6 @@ data class LlmModelInstance(val engine: LlmInference, val session: LlmInferenceS
|
||||||
object LlmChatModelHelper {
|
object LlmChatModelHelper {
|
||||||
// Indexed by model name.
|
// Indexed by model name.
|
||||||
private val cleanUpListeners: MutableMap<String, CleanUpListener> = mutableMapOf()
|
private val cleanUpListeners: MutableMap<String, CleanUpListener> = mutableMapOf()
|
||||||
private val generateResponseListenableFutures: MutableMap<String, ListenableFuture<String>> =
|
|
||||||
mutableMapOf()
|
|
||||||
|
|
||||||
fun initialize(
|
fun initialize(
|
||||||
context: Context, model: Model, onDone: () -> Unit
|
context: Context, model: Model, onDone: () -> Unit
|
||||||
|
@ -64,13 +61,12 @@ object LlmChatModelHelper {
|
||||||
try {
|
try {
|
||||||
val llmInference = LlmInference.createFromOptions(context, options)
|
val llmInference = LlmInference.createFromOptions(context, options)
|
||||||
|
|
||||||
// val session = LlmInferenceSession.createFromOptions(
|
val session = LlmInferenceSession.createFromOptions(
|
||||||
// llmInference,
|
llmInference,
|
||||||
// LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
|
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
|
||||||
// .setTemperature(temperature).build()
|
.setTemperature(temperature).build()
|
||||||
// )
|
)
|
||||||
model.instance = llmInference
|
model.instance = LlmModelInstance(engine = llmInference, session = session)
|
||||||
// LlmModelInstance(engine = llmInference, session = session)
|
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
e.printStackTrace()
|
e.printStackTrace()
|
||||||
}
|
}
|
||||||
|
@ -82,11 +78,10 @@ object LlmChatModelHelper {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
val instance = model.instance as LlmInference
|
val instance = model.instance as LlmModelInstance
|
||||||
try {
|
try {
|
||||||
instance.close()
|
instance.session.close()
|
||||||
// instance.session.close()
|
instance.engine.close()
|
||||||
// instance.engine.close()
|
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
// ignore
|
// ignore
|
||||||
}
|
}
|
||||||
|
@ -104,7 +99,7 @@ object LlmChatModelHelper {
|
||||||
resultListener: ResultListener,
|
resultListener: ResultListener,
|
||||||
cleanUpListener: CleanUpListener,
|
cleanUpListener: CleanUpListener,
|
||||||
) {
|
) {
|
||||||
val instance = model.instance as LlmInference
|
val instance = model.instance as LlmModelInstance
|
||||||
|
|
||||||
// Set listener.
|
// Set listener.
|
||||||
if (!cleanUpListeners.containsKey(model.name)) {
|
if (!cleanUpListeners.containsKey(model.name)) {
|
||||||
|
@ -112,24 +107,8 @@ object LlmChatModelHelper {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start async inference.
|
// Start async inference.
|
||||||
val future = instance.generateResponseAsync(input, resultListener)
|
val session = instance.session
|
||||||
generateResponseListenableFutures[model.name] = future
|
session.addQueryChunk(input)
|
||||||
|
session.generateResponseAsync(resultListener)
|
||||||
// val session = instance.session
|
|
||||||
// TODO: need to count token and reset session.
|
|
||||||
// session.addQueryChunk(input)
|
|
||||||
// session.generateResponseAsync(resultListener)
|
|
||||||
}
|
|
||||||
|
|
||||||
fun stopInference(model: Model) {
|
|
||||||
val instance = model.instance as LlmInference
|
|
||||||
if (instance != null) {
|
|
||||||
instance.close()
|
|
||||||
}
|
|
||||||
// val future = generateResponseListenableFutures[model.name]
|
|
||||||
// if (future != null) {
|
|
||||||
// future.cancel(true)
|
|
||||||
// generateResponseListenableFutures.remove(model.name)
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
package com.google.aiedge.gallery.ui.llmchat
|
package com.google.aiedge.gallery.ui.llmchat
|
||||||
|
|
||||||
import androidx.lifecycle.viewModelScope
|
import androidx.lifecycle.viewModelScope
|
||||||
import com.google.mediapipe.tasks.genai.llminference.LlmInference
|
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
|
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
|
||||||
|
@ -56,11 +55,31 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run inference.
|
// Run inference.
|
||||||
|
val instance = model.instance as LlmModelInstance
|
||||||
|
val prefillTokens = instance.session.sizeInTokens(input)
|
||||||
|
|
||||||
|
var firstRun = true
|
||||||
|
var timeToFirstToken = 0f
|
||||||
|
var firstTokenTs = 0L
|
||||||
|
var decodeTokens = 0
|
||||||
|
var prefillSpeed = 0f
|
||||||
|
var decodeSpeed: Float
|
||||||
val start = System.currentTimeMillis()
|
val start = System.currentTimeMillis()
|
||||||
LlmChatModelHelper.runInference(
|
LlmChatModelHelper.runInference(
|
||||||
model = model,
|
model = model,
|
||||||
input = input,
|
input = input,
|
||||||
resultListener = { partialResult, done ->
|
resultListener = { partialResult, done ->
|
||||||
|
val curTs = System.currentTimeMillis()
|
||||||
|
|
||||||
|
if (firstRun) {
|
||||||
|
firstTokenTs = System.currentTimeMillis()
|
||||||
|
timeToFirstToken = (firstTokenTs - start) / 1000f
|
||||||
|
prefillSpeed = prefillTokens / timeToFirstToken
|
||||||
|
firstRun = false
|
||||||
|
} else {
|
||||||
|
decodeTokens++
|
||||||
|
}
|
||||||
|
|
||||||
// Remove the last message if it is a "loading" message.
|
// Remove the last message if it is a "loading" message.
|
||||||
// This will only be done once.
|
// This will only be done once.
|
||||||
val lastMessage = getLastMessage(model = model)
|
val lastMessage = getLastMessage(model = model)
|
||||||
|
@ -76,7 +95,7 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
||||||
|
|
||||||
// Incrementally update the streamed partial results.
|
// Incrementally update the streamed partial results.
|
||||||
val latencyMs: Long = if (done) System.currentTimeMillis() - start else -1
|
val latencyMs: Long = if (done) System.currentTimeMillis() - start else -1
|
||||||
updateLastMessageContentIncrementally(
|
updateLastTextMessageContentIncrementally(
|
||||||
model = model,
|
model = model,
|
||||||
partialContent = partialResult,
|
partialContent = partialResult,
|
||||||
latencyMs = latencyMs.toFloat()
|
latencyMs = latencyMs.toFloat()
|
||||||
|
@ -84,6 +103,29 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
||||||
|
|
||||||
if (done) {
|
if (done) {
|
||||||
setInProgress(false)
|
setInProgress(false)
|
||||||
|
|
||||||
|
decodeSpeed =
|
||||||
|
decodeTokens / ((curTs - firstTokenTs) / 1000f)
|
||||||
|
if (decodeSpeed.isNaN()) {
|
||||||
|
decodeSpeed = 0f
|
||||||
|
}
|
||||||
|
|
||||||
|
if (lastMessage is ChatMessageText) {
|
||||||
|
updateLastTextMessageLlmBenchmarkResult(
|
||||||
|
model = model, llmBenchmarkResult =
|
||||||
|
ChatMessageBenchmarkLlmResult(
|
||||||
|
orderedStats = STATS,
|
||||||
|
statValues = mutableMapOf(
|
||||||
|
"prefill_speed" to prefillSpeed,
|
||||||
|
"decode_speed" to decodeSpeed,
|
||||||
|
"time_to_first_token" to timeToFirstToken,
|
||||||
|
"latency" to (curTs - start).toFloat() / 1000f,
|
||||||
|
),
|
||||||
|
running = false,
|
||||||
|
latencyMs = -1f,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}, cleanUpListener = {
|
}, cleanUpListener = {
|
||||||
setInProgress(false)
|
setInProgress(false)
|
||||||
|
@ -117,8 +159,8 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
||||||
while (model.instance == null) {
|
while (model.instance == null) {
|
||||||
delay(100)
|
delay(100)
|
||||||
}
|
}
|
||||||
val instance = model.instance as LlmInference
|
val instance = model.instance as LlmModelInstance
|
||||||
val prefillTokens = instance.sizeInTokens(message.content)
|
val prefillTokens = instance.session.sizeInTokens(message.content)
|
||||||
|
|
||||||
// Add the message to show benchmark results.
|
// Add the message to show benchmark results.
|
||||||
val benchmarkLlmResult = ChatMessageBenchmarkLlmResult(
|
val benchmarkLlmResult = ChatMessageBenchmarkLlmResult(
|
||||||
|
|
|
@ -16,21 +16,34 @@
|
||||||
|
|
||||||
package com.google.aiedge.gallery.ui.modelmanager
|
package com.google.aiedge.gallery.ui.modelmanager
|
||||||
|
|
||||||
import androidx.compose.foundation.ExperimentalFoundationApi
|
import androidx.compose.foundation.clickable
|
||||||
import androidx.compose.foundation.layout.Arrangement
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
import androidx.compose.foundation.layout.Box
|
import androidx.compose.foundation.layout.Box
|
||||||
|
import androidx.compose.foundation.layout.Column
|
||||||
import androidx.compose.foundation.layout.PaddingValues
|
import androidx.compose.foundation.layout.PaddingValues
|
||||||
|
import androidx.compose.foundation.layout.Row
|
||||||
import androidx.compose.foundation.layout.fillMaxWidth
|
import androidx.compose.foundation.layout.fillMaxWidth
|
||||||
import androidx.compose.foundation.layout.padding
|
import androidx.compose.foundation.layout.padding
|
||||||
|
import androidx.compose.foundation.layout.size
|
||||||
import androidx.compose.foundation.lazy.LazyColumn
|
import androidx.compose.foundation.lazy.LazyColumn
|
||||||
import androidx.compose.foundation.lazy.items
|
import androidx.compose.foundation.lazy.items
|
||||||
|
import androidx.compose.material.icons.Icons
|
||||||
|
import androidx.compose.material.icons.outlined.Code
|
||||||
|
import androidx.compose.material.icons.outlined.Description
|
||||||
|
import androidx.compose.material3.Icon
|
||||||
import androidx.compose.material3.MaterialTheme
|
import androidx.compose.material3.MaterialTheme
|
||||||
import androidx.compose.material3.Text
|
import androidx.compose.material3.Text
|
||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
|
import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
|
import androidx.compose.ui.graphics.vector.ImageVector
|
||||||
import androidx.compose.ui.platform.LocalContext
|
import androidx.compose.ui.platform.LocalContext
|
||||||
|
import androidx.compose.ui.platform.LocalUriHandler
|
||||||
|
import androidx.compose.ui.text.AnnotatedString
|
||||||
|
import androidx.compose.ui.text.SpanStyle
|
||||||
import androidx.compose.ui.text.font.FontWeight
|
import androidx.compose.ui.text.font.FontWeight
|
||||||
import androidx.compose.ui.text.style.TextAlign
|
import androidx.compose.ui.text.style.TextAlign
|
||||||
|
import androidx.compose.ui.text.style.TextDecoration
|
||||||
import androidx.compose.ui.tooling.preview.Preview
|
import androidx.compose.ui.tooling.preview.Preview
|
||||||
import androidx.compose.ui.unit.dp
|
import androidx.compose.ui.unit.dp
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
|
@ -39,9 +52,9 @@ import com.google.aiedge.gallery.ui.common.modelitem.ModelItem
|
||||||
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||||
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
|
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
|
||||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||||
|
import com.google.aiedge.gallery.ui.theme.customColors
|
||||||
|
|
||||||
/** The list of models in the model manager. */
|
/** The list of models in the model manager. */
|
||||||
@OptIn(ExperimentalFoundationApi::class)
|
|
||||||
@Composable
|
@Composable
|
||||||
fun ModelList(
|
fun ModelList(
|
||||||
task: Task,
|
task: Task,
|
||||||
|
@ -62,11 +75,40 @@ fun ModelList(
|
||||||
textAlign = TextAlign.Center,
|
textAlign = TextAlign.Center,
|
||||||
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
|
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
|
||||||
modifier = Modifier
|
modifier = Modifier
|
||||||
.padding(bottom = 20.dp)
|
|
||||||
.fillMaxWidth()
|
.fillMaxWidth()
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// URLs.
|
||||||
|
item(key = "urls") {
|
||||||
|
Row(
|
||||||
|
horizontalArrangement = Arrangement.Center,
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.padding(top = 12.dp, bottom = 16.dp),
|
||||||
|
) {
|
||||||
|
Column(
|
||||||
|
horizontalAlignment = Alignment.Start,
|
||||||
|
verticalArrangement = Arrangement.spacedBy(4.dp),
|
||||||
|
) {
|
||||||
|
if (task.docUrl.isNotEmpty()) {
|
||||||
|
ClickableLink(
|
||||||
|
url = task.docUrl,
|
||||||
|
linkText = "API Documentation",
|
||||||
|
icon = Icons.Outlined.Description
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if (task.sourceCodeUrl.isNotEmpty()) {
|
||||||
|
ClickableLink(
|
||||||
|
url = task.sourceCodeUrl,
|
||||||
|
linkText = "Example code",
|
||||||
|
icon = Icons.Outlined.Code
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// List of models within a task.
|
// List of models within a task.
|
||||||
items(items = task.models) { model ->
|
items(items = task.models) { model ->
|
||||||
Box {
|
Box {
|
||||||
|
@ -82,6 +124,45 @@ fun ModelList(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
fun ClickableLink(
|
||||||
|
url: String,
|
||||||
|
linkText: String,
|
||||||
|
icon: ImageVector,
|
||||||
|
) {
|
||||||
|
val uriHandler = LocalUriHandler.current
|
||||||
|
val annotatedText = AnnotatedString(
|
||||||
|
text = linkText,
|
||||||
|
spanStyles = listOf(
|
||||||
|
AnnotatedString.Range(
|
||||||
|
item = SpanStyle(
|
||||||
|
color = MaterialTheme.customColors.linkColor,
|
||||||
|
textDecoration = TextDecoration.Underline
|
||||||
|
),
|
||||||
|
start = 0,
|
||||||
|
end = linkText.length
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
Row(
|
||||||
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
horizontalArrangement = Arrangement.Center,
|
||||||
|
) {
|
||||||
|
Icon(icon, contentDescription = "", modifier = Modifier.size(16.dp))
|
||||||
|
Text(
|
||||||
|
text = annotatedText,
|
||||||
|
textAlign = TextAlign.Center,
|
||||||
|
style = MaterialTheme.typography.bodySmall,
|
||||||
|
modifier = Modifier
|
||||||
|
.padding(start = 6.dp)
|
||||||
|
.clickable {
|
||||||
|
uriHandler.openUri(url)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Preview(showBackground = true)
|
@Preview(showBackground = true)
|
||||||
@Composable
|
@Composable
|
||||||
fun ModelListPreview() {
|
fun ModelListPreview() {
|
||||||
|
|
|
@ -21,9 +21,6 @@ import androidx.compose.foundation.layout.fillMaxSize
|
||||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||||
import androidx.compose.material3.Scaffold
|
import androidx.compose.material3.Scaffold
|
||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
import androidx.compose.runtime.collectAsState
|
|
||||||
import androidx.compose.runtime.getValue
|
|
||||||
import androidx.compose.runtime.rememberCoroutineScope
|
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.platform.LocalContext
|
import androidx.compose.ui.platform.LocalContext
|
||||||
import androidx.compose.ui.tooling.preview.Preview
|
import androidx.compose.ui.tooling.preview.Preview
|
||||||
|
@ -31,9 +28,7 @@ import com.google.aiedge.gallery.GalleryTopAppBar
|
||||||
import com.google.aiedge.gallery.data.AppBarAction
|
import com.google.aiedge.gallery.data.AppBarAction
|
||||||
import com.google.aiedge.gallery.data.AppBarActionType
|
import com.google.aiedge.gallery.data.AppBarActionType
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.data.getModelByName
|
|
||||||
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||||
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
|
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
|
||||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||||
|
@ -48,9 +43,6 @@ fun ModelManager(
|
||||||
onModelClicked: (Model) -> Unit,
|
onModelClicked: (Model) -> Unit,
|
||||||
modifier: Modifier = Modifier,
|
modifier: Modifier = Modifier,
|
||||||
) {
|
) {
|
||||||
val uiState by viewModel.uiState.collectAsState()
|
|
||||||
val coroutineScope = rememberCoroutineScope()
|
|
||||||
|
|
||||||
// Set title based on the task.
|
// Set title based on the task.
|
||||||
var title = "${task.type.label} model"
|
var title = "${task.type.label} model"
|
||||||
if (task.models.size != 1) {
|
if (task.models.size != 1) {
|
||||||
|
@ -67,27 +59,7 @@ fun ModelManager(
|
||||||
topBar = {
|
topBar = {
|
||||||
GalleryTopAppBar(
|
GalleryTopAppBar(
|
||||||
title = title,
|
title = title,
|
||||||
// subtitle = String.format(
|
|
||||||
// stringResource(R.string.downloaded_size),
|
|
||||||
// totalSizeInBytes.humanReadableSize()
|
|
||||||
// ),
|
|
||||||
|
|
||||||
// Refresh model list button at the left side of the app bar.
|
|
||||||
// leftAction = AppBarAction(actionType = if (uiState.loadingHfModels) {
|
|
||||||
// AppBarActionType.REFRESHING_MODELS
|
|
||||||
// } else {
|
|
||||||
// AppBarActionType.REFRESH_MODELS
|
|
||||||
// }, actionFn = {
|
|
||||||
// coroutineScope.launch(Dispatchers.IO) {
|
|
||||||
// viewModel.loadHfModels()
|
|
||||||
// }
|
|
||||||
// }),
|
|
||||||
leftAction = AppBarAction(actionType = AppBarActionType.NAVIGATE_UP, actionFn = navigateUp)
|
leftAction = AppBarAction(actionType = AppBarActionType.NAVIGATE_UP, actionFn = navigateUp)
|
||||||
|
|
||||||
// "Done" button at the right side of the app bar to navigate up.
|
|
||||||
// rightAction = AppBarAction(
|
|
||||||
// actionType = AppBarActionType.NAVIGATE_UP, actionFn = navigateUp
|
|
||||||
// ),
|
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
) { innerPadding ->
|
) { innerPadding ->
|
||||||
|
@ -101,19 +73,6 @@ fun ModelManager(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun getTotalDownloadedFileSize(uiState: ModelManagerUiState): Long {
|
|
||||||
var totalSizeInBytes = 0L
|
|
||||||
for ((name, status) in uiState.modelDownloadStatus.entries) {
|
|
||||||
if (status.status == ModelDownloadStatusType.SUCCEEDED) {
|
|
||||||
totalSizeInBytes += getModelByName(name)?.totalBytes ?: 0L
|
|
||||||
} else if (status.status == ModelDownloadStatusType.IN_PROGRESS) {
|
|
||||||
totalSizeInBytes += status.receivedBytes
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return totalSizeInBytes
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Preview
|
@Preview
|
||||||
@Composable
|
@Composable
|
||||||
fun ModelManagerPreview() {
|
fun ModelManagerPreview() {
|
||||||
|
|
|
@ -115,26 +115,27 @@ data class CustomColors(
|
||||||
val homeBottomGradient: List<Color> = listOf(),
|
val homeBottomGradient: List<Color> = listOf(),
|
||||||
val userBubbleBgColor: Color = Color.Transparent,
|
val userBubbleBgColor: Color = Color.Transparent,
|
||||||
val agentBubbleBgColor: Color = Color.Transparent,
|
val agentBubbleBgColor: Color = Color.Transparent,
|
||||||
|
val linkColor: Color = Color.Transparent,
|
||||||
)
|
)
|
||||||
|
|
||||||
val LocalCustomColors = staticCompositionLocalOf { CustomColors() }
|
val LocalCustomColors = staticCompositionLocalOf { CustomColors() }
|
||||||
|
|
||||||
val lightCustomColors = CustomColors(
|
val lightCustomColors = CustomColors(
|
||||||
taskBgColors = listOf(
|
taskBgColors = listOf(
|
||||||
|
// green
|
||||||
|
Color(0xFFE1F6DE),
|
||||||
|
// blue
|
||||||
|
Color(0xFFEDF0FF),
|
||||||
// yellow
|
// yellow
|
||||||
Color(0xFFFFEFC9),
|
Color(0xFFFFEFC9),
|
||||||
// red
|
// red
|
||||||
Color(0xFFFFEDE6),
|
Color(0xFFFFEDE6),
|
||||||
// green
|
|
||||||
Color(0xFFE1F6DE),
|
|
||||||
// blue
|
|
||||||
Color(0xFFEDF0FF)
|
|
||||||
),
|
),
|
||||||
taskIconColors = listOf(
|
taskIconColors = listOf(
|
||||||
Color(0xFFE37400),
|
|
||||||
Color(0xFFD93025),
|
|
||||||
Color(0xFF34A853),
|
Color(0xFF34A853),
|
||||||
Color(0xFF1967D2),
|
Color(0xFF1967D2),
|
||||||
|
Color(0xFFE37400),
|
||||||
|
Color(0xFFD93025),
|
||||||
),
|
),
|
||||||
taskIconShapeBgColor = Color.White,
|
taskIconShapeBgColor = Color.White,
|
||||||
homeBottomGradient = listOf(
|
homeBottomGradient = listOf(
|
||||||
|
@ -143,24 +144,25 @@ val lightCustomColors = CustomColors(
|
||||||
),
|
),
|
||||||
agentBubbleBgColor = Color(0xFFe9eef6),
|
agentBubbleBgColor = Color(0xFFe9eef6),
|
||||||
userBubbleBgColor = Color(0xFF32628D),
|
userBubbleBgColor = Color(0xFF32628D),
|
||||||
|
linkColor = Color(0xFF32628D),
|
||||||
)
|
)
|
||||||
|
|
||||||
val darkCustomColors = CustomColors(
|
val darkCustomColors = CustomColors(
|
||||||
taskBgColors = listOf(
|
taskBgColors = listOf(
|
||||||
|
// green
|
||||||
|
Color(0xFF2E312D),
|
||||||
|
// blue
|
||||||
|
Color(0xFF303033),
|
||||||
// yellow
|
// yellow
|
||||||
Color(0xFF33302A),
|
Color(0xFF33302A),
|
||||||
// red
|
// red
|
||||||
Color(0xFF362F2D),
|
Color(0xFF362F2D),
|
||||||
// green
|
|
||||||
Color(0xFF2E312D),
|
|
||||||
// blue
|
|
||||||
Color(0xFF303033)
|
|
||||||
),
|
),
|
||||||
taskIconColors = listOf(
|
taskIconColors = listOf(
|
||||||
Color(0xFFFFB955),
|
|
||||||
Color(0xFFFFB4AB),
|
|
||||||
Color(0xFF6DD58C),
|
Color(0xFF6DD58C),
|
||||||
Color(0xFFAAC7FF),
|
Color(0xFFAAC7FF),
|
||||||
|
Color(0xFFFFB955),
|
||||||
|
Color(0xFFFFB4AB),
|
||||||
),
|
),
|
||||||
taskIconShapeBgColor = Color(0xFF202124),
|
taskIconShapeBgColor = Color(0xFF202124),
|
||||||
homeBottomGradient = listOf(
|
homeBottomGradient = listOf(
|
||||||
|
@ -169,6 +171,7 @@ val darkCustomColors = CustomColors(
|
||||||
),
|
),
|
||||||
agentBubbleBgColor = Color(0xFF1b1c1d),
|
agentBubbleBgColor = Color(0xFF1b1c1d),
|
||||||
userBubbleBgColor = Color(0xFF1f3760),
|
userBubbleBgColor = Color(0xFF1f3760),
|
||||||
|
linkColor = Color(0xFF9DCAFC),
|
||||||
)
|
)
|
||||||
|
|
||||||
val MaterialTheme.customColors: CustomColors
|
val MaterialTheme.customColors: CustomColors
|
||||||
|
|
|
@ -35,7 +35,7 @@
|
||||||
<string name="model_is_initializing_msg">Initializing model…</string>
|
<string name="model_is_initializing_msg">Initializing model…</string>
|
||||||
<string name="text_input_placeholder_text_classification">Type movie review to classify…</string>
|
<string name="text_input_placeholder_text_classification">Type movie review to classify…</string>
|
||||||
<string name="text_image_generation_text_field_placeholder">Type prompt…</string>
|
<string name="text_image_generation_text_field_placeholder">Type prompt…</string>
|
||||||
<string name="text_input_placeholder_llm_chat">Type prompt (one shot)…</string>
|
<string name="text_input_placeholder_llm_chat">Type prompt…</string>
|
||||||
<string name="run_again">Run again</string>
|
<string name="run_again">Run again</string>
|
||||||
<string name="benchmark">Run benchmark</string>
|
<string name="benchmark">Run benchmark</string>
|
||||||
<string name="warming_up">warming up…</string>
|
<string name="warming_up">warming up…</string>
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue