mirror of
https://github.com/google-ai-edge/gallery.git
synced 2025-07-06 06:30:30 -04:00
- Add "show stats" after each LLM response to see token related stats more easily.
- Back to multi-turn (no exception handling yet). - Show app's version in App Info screen. - Show API doc link and source code link in the model list screen. - Jump to model info page in HF when clicking "learn more".
This commit is contained in:
parent
ea31fd0544
commit
42d442389d
16 changed files with 343 additions and 183 deletions
|
@ -30,7 +30,7 @@ android {
|
|||
minSdk = 24
|
||||
targetSdk = 35
|
||||
versionCode = 1
|
||||
versionName = "1.0"
|
||||
versionName = "20250416"
|
||||
|
||||
// Needed for HuggingFace auth workflows.
|
||||
manifestPlaceholders["appAuthRedirectScheme"] = "com.google.aiedge.gallery.oauth"
|
||||
|
@ -58,6 +58,7 @@ android {
|
|||
}
|
||||
buildFeatures {
|
||||
compose = true
|
||||
buildConfig = true
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,19 +0,0 @@
|
|||
/*
|
||||
* Copyright 2025 Google LLC
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.google.aiedge.gallery
|
||||
|
||||
const val VERSION = "20250413"
|
|
@ -218,9 +218,6 @@ const val IMAGE_CLASSIFICATION_LEARN_MORE_URL = "https://ai.google.dev/edge/lite
|
|||
const val LLM_CHAT_INFO =
|
||||
"Some description about this large language model. A community org for developers to discover models that are ready for deployment to edge platforms"
|
||||
|
||||
const val LLM_CHAT_LEARN_MORE_URL =
|
||||
"https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android"
|
||||
|
||||
const val IMAGE_GENERATION_INFO =
|
||||
"Powered by [MediaPipe Image Generation API](https://ai.google.dev/edge/mediapipe/solutions/vision/image_generator/android)"
|
||||
|
||||
|
@ -234,7 +231,7 @@ val MODEL_LLM_GEMMA_2B_GPU_INT4: Model = Model(
|
|||
sizeInBytes = 1354301440L,
|
||||
configs = createLlmChatConfigs(),
|
||||
info = LLM_CHAT_INFO,
|
||||
learnMoreUrl = LLM_CHAT_LEARN_MORE_URL,
|
||||
learnMoreUrl = "https://huggingface.co/litert-community",
|
||||
)
|
||||
|
||||
val MODEL_LLM_GEMMA_2_2B_GPU_INT8: Model = Model(
|
||||
|
@ -244,7 +241,7 @@ val MODEL_LLM_GEMMA_2_2B_GPU_INT8: Model = Model(
|
|||
sizeInBytes = 2627141632L,
|
||||
configs = createLlmChatConfigs(),
|
||||
info = LLM_CHAT_INFO,
|
||||
learnMoreUrl = LLM_CHAT_LEARN_MORE_URL,
|
||||
learnMoreUrl = "https://huggingface.co/litert-community",
|
||||
)
|
||||
|
||||
val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model(
|
||||
|
@ -254,7 +251,7 @@ val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model(
|
|||
sizeInBytes = 554661243L,
|
||||
configs = createLlmChatConfigs(defaultTopK = 64, defaultTopP = 0.95f),
|
||||
info = LLM_CHAT_INFO,
|
||||
learnMoreUrl = LLM_CHAT_LEARN_MORE_URL,
|
||||
learnMoreUrl = "https://huggingface.co/litert-community/Gemma3-1B-IT",
|
||||
llmPromptTemplates = listOf(
|
||||
PromptTemplate(
|
||||
title = "Emoji Fun",
|
||||
|
@ -277,7 +274,7 @@ val MODEL_LLM_DEEPSEEK: Model = Model(
|
|||
llmBackend = LlmBackend.CPU,
|
||||
configs = createLlmChatConfigs(defaultTemperature = 0.6f, defaultTopK = 40, defaultTopP = 0.7f),
|
||||
info = LLM_CHAT_INFO,
|
||||
learnMoreUrl = LLM_CHAT_LEARN_MORE_URL,
|
||||
learnMoreUrl = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B",
|
||||
)
|
||||
|
||||
val MODEL_TEXT_CLASSIFICATION_MOBILEBERT: Model = Model(
|
||||
|
@ -343,6 +340,7 @@ val MODEL_IMAGE_GENERATION_STABLE_DIFFUSION: Model = Model(
|
|||
showBenchmarkButton = false,
|
||||
info = IMAGE_GENERATION_INFO,
|
||||
configs = IMAGE_GENERATION_CONFIGS,
|
||||
learnMoreUrl = "https://huggingface.co/litert-community",
|
||||
)
|
||||
|
||||
val EMPTY_MODEL: Model = Model(
|
||||
|
|
|
@ -50,6 +50,12 @@ data class Task(
|
|||
/** Description of the task. */
|
||||
val description: String,
|
||||
|
||||
/** Documentation url for the task. */
|
||||
val docUrl: String = "",
|
||||
|
||||
/** Source code url for the model-related functions. */
|
||||
val sourceCodeUrl: String = "",
|
||||
|
||||
/** Placeholder text for the name of the agent shown above chat messages. */
|
||||
@StringRes val agentNameRes: Int = R.string.chat_generic_agent_name,
|
||||
|
||||
|
@ -80,6 +86,8 @@ val TASK_LLM_CHAT = Task(
|
|||
iconVectorResourceId = R.drawable.chat_spark,
|
||||
models = MODELS_LLM_CHAT,
|
||||
description = "Chat? with a on-device large language model",
|
||||
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
||||
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
|
||||
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
|
||||
)
|
||||
|
||||
|
@ -88,13 +96,15 @@ val TASK_IMAGE_GENERATION = Task(
|
|||
iconVectorResourceId = R.drawable.image_spark,
|
||||
models = MODELS_IMAGE_GENERATION,
|
||||
description = "Generate images from text",
|
||||
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/vision/image_generator/android",
|
||||
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/imagegeneration/ImageGenerationModelHelper.kt",
|
||||
textInputPlaceHolderRes = R.string.text_image_generation_text_field_placeholder
|
||||
)
|
||||
|
||||
/** All tasks. */
|
||||
val TASKS: List<Task> = listOf(
|
||||
TASK_TEXT_CLASSIFICATION,
|
||||
TASK_IMAGE_CLASSIFICATION,
|
||||
// TASK_TEXT_CLASSIFICATION,
|
||||
// TASK_IMAGE_CLASSIFICATION,
|
||||
TASK_IMAGE_GENERATION,
|
||||
TASK_LLM_CHAT,
|
||||
)
|
||||
|
|
|
@ -71,13 +71,17 @@ open class ChatMessageText(
|
|||
// Negative numbers will hide the latency display.
|
||||
override val latencyMs: Float = 0f,
|
||||
val isMarkdown: Boolean = true,
|
||||
|
||||
// Benchmark result for LLM response.
|
||||
var llmBenchmarkResult: ChatMessageBenchmarkLlmResult? = null,
|
||||
) : ChatMessage(type = ChatMessageType.TEXT, side = side, latencyMs = latencyMs) {
|
||||
override fun clone(): ChatMessageText {
|
||||
return ChatMessageText(
|
||||
content = content,
|
||||
side = side,
|
||||
latencyMs = latencyMs,
|
||||
isMarkdown = isMarkdown
|
||||
isMarkdown = isMarkdown,
|
||||
llmBenchmarkResult = llmBenchmarkResult,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -150,6 +150,8 @@ fun ChatPanel(
|
|||
lastMessageContent.value = tmpLastMessage.content
|
||||
}
|
||||
}
|
||||
val lastShowingStatsByModel: MutableState<Map<String, MutableSet<ChatMessage>>> =
|
||||
remember { mutableStateOf(mapOf()) }
|
||||
|
||||
// Scroll the content to the bottom when any of these changes.
|
||||
LaunchedEffect(
|
||||
|
@ -158,9 +160,27 @@ fun ChatPanel(
|
|||
lastMessageContent.value,
|
||||
WindowInsets.ime.getBottom(density),
|
||||
) {
|
||||
// Only scroll if showingStatsByModel is not changed. In other words, when showingStatsByModel
|
||||
// changes we want the display to not scroll.
|
||||
if (messages.isNotEmpty()) {
|
||||
listState.animateScrollToItem(messages.lastIndex, scrollOffset = 10000)
|
||||
if (uiState.showingStatsByModel === lastShowingStatsByModel.value) {
|
||||
listState.animateScrollToItem(messages.lastIndex, scrollOffset = 10000)
|
||||
} else {
|
||||
// Scroll to bottom if the message to show stats is the last message.
|
||||
val curShowingStats =
|
||||
uiState.showingStatsByModel[selectedModel.name]?.toMutableSet() ?: mutableSetOf()
|
||||
val lastShowingStats = lastShowingStatsByModel.value[selectedModel.name] ?: mutableSetOf()
|
||||
curShowingStats.removeAll(lastShowingStats)
|
||||
if (curShowingStats.isNotEmpty()) {
|
||||
val index =
|
||||
viewModel.getMessageIndex(model = selectedModel, message = curShowingStats.first())
|
||||
if (index == messages.size - 2) {
|
||||
listState.animateScrollToItem(messages.lastIndex, scrollOffset = 10000)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
lastShowingStatsByModel.value = uiState.showingStatsByModel
|
||||
}
|
||||
|
||||
val nestedScrollConnection = remember {
|
||||
|
@ -309,7 +329,51 @@ fun ChatPanel(
|
|||
}
|
||||
}
|
||||
if (message.side == ChatSide.AGENT) {
|
||||
LatencyText(message = message)
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(8.dp),
|
||||
) {
|
||||
LatencyText(message = message)
|
||||
// A button to show stats for the LLM message.
|
||||
if (selectedModel.taskType == TaskType.LLM_CHAT && message is ChatMessageText
|
||||
// This means we only want to show the action button when the message is done
|
||||
// generating, at which point the latency will be set.
|
||||
&& message.latencyMs >= 0
|
||||
) {
|
||||
val showingStats =
|
||||
viewModel.isShowingStats(model = selectedModel, message = message)
|
||||
MessageActionButton(
|
||||
label = if (showingStats) "Hide stats" else "Show stats",
|
||||
icon = Icons.Outlined.Timer,
|
||||
onClick = {
|
||||
// Toggle showing stats.
|
||||
viewModel.toggleShowingStats(selectedModel, message)
|
||||
|
||||
// Add the stats message after the LLM message.
|
||||
if (viewModel.isShowingStats(model = selectedModel, message = message)) {
|
||||
val llmBenchmarkResult = message.llmBenchmarkResult
|
||||
if (llmBenchmarkResult != null) {
|
||||
viewModel.insertMessageAfter(
|
||||
model = selectedModel,
|
||||
anchorMessage = message,
|
||||
messageToAdd = llmBenchmarkResult,
|
||||
)
|
||||
}
|
||||
}
|
||||
// Remove the stats message.
|
||||
else {
|
||||
val curMessageIndex =
|
||||
viewModel.getMessageIndex(model = selectedModel, message = message)
|
||||
viewModel.removeMessageAt(
|
||||
model = selectedModel,
|
||||
index = curMessageIndex + 1
|
||||
)
|
||||
}
|
||||
},
|
||||
enabled = !uiState.inProgress
|
||||
)
|
||||
}
|
||||
}
|
||||
} else if (message.side == ChatSide.USER) {
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
|
@ -328,21 +392,21 @@ fun ChatPanel(
|
|||
}
|
||||
|
||||
// Benchmark button
|
||||
if (selectedModel.showBenchmarkButton) {
|
||||
MessageActionButton(
|
||||
label = stringResource(R.string.benchmark),
|
||||
icon = Icons.Outlined.Timer,
|
||||
onClick = {
|
||||
if (selectedModel.taskType == TaskType.LLM_CHAT) {
|
||||
onBenchmarkClicked(selectedModel, message, 0, 0)
|
||||
} else {
|
||||
showBenchmarkConfigsDialog = true
|
||||
benchmarkMessage.value = message
|
||||
}
|
||||
},
|
||||
enabled = !uiState.inProgress
|
||||
)
|
||||
}
|
||||
// if (selectedModel.showBenchmarkButton) {
|
||||
// MessageActionButton(
|
||||
// label = stringResource(R.string.benchmark),
|
||||
// icon = Icons.Outlined.Timer,
|
||||
// onClick = {
|
||||
// if (selectedModel.taskType == TaskType.LLM_CHAT) {
|
||||
// onBenchmarkClicked(selectedModel, message, 0, 0)
|
||||
// } else {
|
||||
// showBenchmarkConfigsDialog = true
|
||||
// benchmarkMessage.value = message
|
||||
// }
|
||||
// },
|
||||
// enabled = !uiState.inProgress
|
||||
// )
|
||||
// }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -43,6 +43,12 @@ data class ChatUiState(
|
|||
* A map of model names to the currently streaming chat message.
|
||||
*/
|
||||
val streamingMessagesByModel: Map<String, ChatMessage> = mapOf(),
|
||||
|
||||
/*
|
||||
* A map of model names to a map of chat messages to a boolean indicating whether the message is
|
||||
* showing the stats below it.
|
||||
*/
|
||||
val showingStatsByModel: Map<String, MutableSet<ChatMessage>> = mapOf(),
|
||||
)
|
||||
|
||||
/**
|
||||
|
@ -66,6 +72,31 @@ open class ChatViewModel(val task: Task) : ViewModel() {
|
|||
_uiState.update { _uiState.value.copy(messagesByModel = newMessagesByModel) }
|
||||
}
|
||||
|
||||
fun insertMessageAfter(model: Model, anchorMessage: ChatMessage, messageToAdd: ChatMessage) {
|
||||
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
|
||||
val newMessages = newMessagesByModel[model.name]?.toMutableList()
|
||||
if (newMessages != null) {
|
||||
newMessagesByModel[model.name] = newMessages
|
||||
// Find the index of the anchor message
|
||||
val anchorIndex = newMessages.indexOf(anchorMessage)
|
||||
if (anchorIndex != -1) {
|
||||
// Insert the new message after the anchor message
|
||||
newMessages.add(anchorIndex + 1, messageToAdd)
|
||||
}
|
||||
}
|
||||
_uiState.update { _uiState.value.copy(messagesByModel = newMessagesByModel) }
|
||||
}
|
||||
|
||||
fun removeMessageAt(model: Model, index: Int) {
|
||||
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
|
||||
val newMessages = newMessagesByModel[model.name]?.toMutableList()
|
||||
if (newMessages != null) {
|
||||
newMessagesByModel[model.name] = newMessages
|
||||
newMessages.removeAt(index)
|
||||
}
|
||||
_uiState.update { _uiState.value.copy(messagesByModel = newMessagesByModel) }
|
||||
}
|
||||
|
||||
fun removeLastMessage(model: Model) {
|
||||
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
|
||||
val newMessages = newMessagesByModel[model.name]?.toMutableList() ?: mutableListOf()
|
||||
|
@ -80,7 +111,7 @@ open class ChatViewModel(val task: Task) : ViewModel() {
|
|||
return (_uiState.value.messagesByModel[model.name] ?: listOf()).lastOrNull()
|
||||
}
|
||||
|
||||
fun updateLastMessageContentIncrementally(
|
||||
fun updateLastTextMessageContentIncrementally(
|
||||
model: Model,
|
||||
partialContent: String,
|
||||
latencyMs: Float,
|
||||
|
@ -124,6 +155,25 @@ open class ChatViewModel(val task: Task) : ViewModel() {
|
|||
_uiState.update { newUiState }
|
||||
}
|
||||
|
||||
fun updateLastTextMessageLlmBenchmarkResult(
|
||||
model: Model,
|
||||
llmBenchmarkResult: ChatMessageBenchmarkLlmResult
|
||||
) {
|
||||
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
|
||||
val newMessages = newMessagesByModel[model.name]?.toMutableList() ?: mutableListOf()
|
||||
if (newMessages.size > 0) {
|
||||
val lastMessage = newMessages.last()
|
||||
if (lastMessage is ChatMessageText) {
|
||||
lastMessage.llmBenchmarkResult = llmBenchmarkResult
|
||||
newMessages.removeAt(newMessages.size - 1)
|
||||
newMessages.add(lastMessage)
|
||||
}
|
||||
}
|
||||
newMessagesByModel[model.name] = newMessages
|
||||
val newUiState = _uiState.value.copy(messagesByModel = newMessagesByModel)
|
||||
_uiState.update { newUiState }
|
||||
}
|
||||
|
||||
fun replaceLastMessage(model: Model, message: ChatMessage, type: ChatMessageType) {
|
||||
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
|
||||
val newMessages = newMessagesByModel[model.name]?.toMutableList() ?: mutableListOf()
|
||||
|
@ -159,10 +209,6 @@ open class ChatViewModel(val task: Task) : ViewModel() {
|
|||
_uiState.update { _uiState.value.copy(inProgress = inProgress) }
|
||||
}
|
||||
|
||||
fun isInProgress(): Boolean {
|
||||
return _uiState.value.inProgress
|
||||
}
|
||||
|
||||
fun addConfigChangedMessage(
|
||||
oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>, model: Model
|
||||
) {
|
||||
|
@ -173,6 +219,26 @@ open class ChatViewModel(val task: Task) : ViewModel() {
|
|||
addMessage(message = message, model = model)
|
||||
}
|
||||
|
||||
fun getMessageIndex(model: Model, message: ChatMessage): Int {
|
||||
return (_uiState.value.messagesByModel[model.name] ?: listOf()).indexOf(message)
|
||||
}
|
||||
|
||||
fun isShowingStats(model: Model, message: ChatMessage): Boolean {
|
||||
return _uiState.value.showingStatsByModel[model.name]?.contains(message) ?: false
|
||||
}
|
||||
|
||||
fun toggleShowingStats(model: Model, message: ChatMessage) {
|
||||
val newShowingStatsByModel = _uiState.value.showingStatsByModel.toMutableMap()
|
||||
val newShowingStats = newShowingStatsByModel[model.name]?.toMutableSet() ?: mutableSetOf()
|
||||
if (newShowingStats.contains(message)) {
|
||||
newShowingStats.remove(message)
|
||||
} else {
|
||||
newShowingStats.add(message)
|
||||
}
|
||||
newShowingStatsByModel[model.name] = newShowingStats
|
||||
_uiState.update { _uiState.value.copy(showingStatsByModel = newShowingStatsByModel) }
|
||||
}
|
||||
|
||||
private fun createUiState(task: Task): ChatUiState {
|
||||
val messagesByModel: MutableMap<String, MutableList<ChatMessage>> = mutableMapOf()
|
||||
for (model in task.models) {
|
||||
|
|
|
@ -179,7 +179,7 @@ private fun getMessageLayoutConfig(
|
|||
is ChatMessageBenchmarkLlmResult -> {
|
||||
horizontalArrangement = Arrangement.SpaceBetween
|
||||
modifier = modifier.fillMaxWidth()
|
||||
userLabel = "Benchmark"
|
||||
userLabel = "Stats"
|
||||
}
|
||||
|
||||
is ChatMessageImageWithHistory -> {
|
||||
|
|
|
@ -38,6 +38,7 @@ import androidx.compose.material.icons.Icons
|
|||
import androidx.compose.material.icons.rounded.Settings
|
||||
import androidx.compose.material.icons.rounded.UnfoldLess
|
||||
import androidx.compose.material.icons.rounded.UnfoldMore
|
||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.IconButton
|
||||
import androidx.compose.material3.OutlinedButton
|
||||
|
@ -90,6 +91,7 @@ private val DEFAULT_VERTICAL_PADDING = 16.dp
|
|||
* model description and buttons for learning more (opening a URL) and downloading/trying
|
||||
* the model.
|
||||
*/
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@Composable
|
||||
fun ModelItem(
|
||||
model: Model,
|
||||
|
@ -217,7 +219,7 @@ fun ModelItem(
|
|||
.then(m),
|
||||
horizontalArrangement = Arrangement.spacedBy(12.dp),
|
||||
) {
|
||||
// The "learn more" button. Click to show url in default browser.
|
||||
// The "learn more" button. Click to show related urls in a bottom sheet.
|
||||
if (model.learnMoreUrl.isNotEmpty()) {
|
||||
OutlinedButton(
|
||||
onClick = {
|
||||
|
@ -241,35 +243,6 @@ fun ModelItem(
|
|||
modelManagerViewModel = modelManagerViewModel,
|
||||
onClicked = { onModelClicked(model) }
|
||||
)
|
||||
// Button(
|
||||
// onClick = {
|
||||
// if (isExpanded) {
|
||||
// onModelClicked(model)
|
||||
// if (needToDownloadFirst) {
|
||||
// scope.launch {
|
||||
// delay(80)
|
||||
// checkNotificationPermissonAndStartDownload(
|
||||
// context = context,
|
||||
// launcher = launcher,
|
||||
// modelManagerViewModel = modelManagerViewModel,
|
||||
// model = model
|
||||
// )
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// },
|
||||
// ) {
|
||||
// Icon(
|
||||
// Icons.AutoMirrored.Rounded.ArrowForward,
|
||||
// contentDescription = "",
|
||||
// modifier = Modifier.padding(end = 4.dp)
|
||||
// )
|
||||
// if (needToDownloadFirst) {
|
||||
// Text("Download & Try it", maxLines = 1)
|
||||
// } else {
|
||||
// Text("Try it", maxLines = 1)
|
||||
// }
|
||||
// }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -366,7 +339,6 @@ fun ModelItem(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Preview(showBackground = true)
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package com.google.aiedge.gallery.ui.home
|
||||
|
||||
import androidx.compose.runtime.Composable
|
||||
import com.google.aiedge.gallery.VERSION
|
||||
import com.google.aiedge.gallery.BuildConfig
|
||||
import com.google.aiedge.gallery.data.Config
|
||||
import com.google.aiedge.gallery.data.ConfigKey
|
||||
import com.google.aiedge.gallery.data.SegmentedButtonConfig
|
||||
|
@ -45,7 +45,7 @@ fun SettingsDialog(
|
|||
)
|
||||
ConfigDialog(
|
||||
title = "Settings",
|
||||
subtitle = "App version: $VERSION",
|
||||
subtitle = "App version: ${BuildConfig.VERSION_NAME}",
|
||||
okBtnLabel = "OK",
|
||||
configs = CONFIGS,
|
||||
initialValues = initialValues,
|
||||
|
|
|
@ -18,12 +18,11 @@ package com.google.aiedge.gallery.ui.llmchat
|
|||
|
||||
import android.content.Context
|
||||
import android.util.Log
|
||||
import com.google.common.util.concurrent.ListenableFuture
|
||||
import com.google.mediapipe.tasks.genai.llminference.LlmInference
|
||||
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
|
||||
import com.google.aiedge.gallery.data.ConfigKey
|
||||
import com.google.aiedge.gallery.data.LlmBackend
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.mediapipe.tasks.genai.llminference.LlmInference
|
||||
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
|
||||
|
||||
private const val TAG = "AGLlmChatModelHelper"
|
||||
private const val DEFAULT_MAX_TOKEN = 1024
|
||||
|
@ -39,8 +38,6 @@ data class LlmModelInstance(val engine: LlmInference, val session: LlmInferenceS
|
|||
object LlmChatModelHelper {
|
||||
// Indexed by model name.
|
||||
private val cleanUpListeners: MutableMap<String, CleanUpListener> = mutableMapOf()
|
||||
private val generateResponseListenableFutures: MutableMap<String, ListenableFuture<String>> =
|
||||
mutableMapOf()
|
||||
|
||||
fun initialize(
|
||||
context: Context, model: Model, onDone: () -> Unit
|
||||
|
@ -64,13 +61,12 @@ object LlmChatModelHelper {
|
|||
try {
|
||||
val llmInference = LlmInference.createFromOptions(context, options)
|
||||
|
||||
// val session = LlmInferenceSession.createFromOptions(
|
||||
// llmInference,
|
||||
// LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
|
||||
// .setTemperature(temperature).build()
|
||||
// )
|
||||
model.instance = llmInference
|
||||
// LlmModelInstance(engine = llmInference, session = session)
|
||||
val session = LlmInferenceSession.createFromOptions(
|
||||
llmInference,
|
||||
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
|
||||
.setTemperature(temperature).build()
|
||||
)
|
||||
model.instance = LlmModelInstance(engine = llmInference, session = session)
|
||||
} catch (e: Exception) {
|
||||
e.printStackTrace()
|
||||
}
|
||||
|
@ -82,11 +78,10 @@ object LlmChatModelHelper {
|
|||
return
|
||||
}
|
||||
|
||||
val instance = model.instance as LlmInference
|
||||
val instance = model.instance as LlmModelInstance
|
||||
try {
|
||||
instance.close()
|
||||
// instance.session.close()
|
||||
// instance.engine.close()
|
||||
instance.session.close()
|
||||
instance.engine.close()
|
||||
} catch (e: Exception) {
|
||||
// ignore
|
||||
}
|
||||
|
@ -104,7 +99,7 @@ object LlmChatModelHelper {
|
|||
resultListener: ResultListener,
|
||||
cleanUpListener: CleanUpListener,
|
||||
) {
|
||||
val instance = model.instance as LlmInference
|
||||
val instance = model.instance as LlmModelInstance
|
||||
|
||||
// Set listener.
|
||||
if (!cleanUpListeners.containsKey(model.name)) {
|
||||
|
@ -112,24 +107,8 @@ object LlmChatModelHelper {
|
|||
}
|
||||
|
||||
// Start async inference.
|
||||
val future = instance.generateResponseAsync(input, resultListener)
|
||||
generateResponseListenableFutures[model.name] = future
|
||||
|
||||
// val session = instance.session
|
||||
// TODO: need to count token and reset session.
|
||||
// session.addQueryChunk(input)
|
||||
// session.generateResponseAsync(resultListener)
|
||||
}
|
||||
|
||||
fun stopInference(model: Model) {
|
||||
val instance = model.instance as LlmInference
|
||||
if (instance != null) {
|
||||
instance.close()
|
||||
}
|
||||
// val future = generateResponseListenableFutures[model.name]
|
||||
// if (future != null) {
|
||||
// future.cancel(true)
|
||||
// generateResponseListenableFutures.remove(model.name)
|
||||
// }
|
||||
val session = instance.session
|
||||
session.addQueryChunk(input)
|
||||
session.generateResponseAsync(resultListener)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
package com.google.aiedge.gallery.ui.llmchat
|
||||
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import com.google.mediapipe.tasks.genai.llminference.LlmInference
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
|
||||
|
@ -56,11 +55,31 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
|||
}
|
||||
|
||||
// Run inference.
|
||||
val instance = model.instance as LlmModelInstance
|
||||
val prefillTokens = instance.session.sizeInTokens(input)
|
||||
|
||||
var firstRun = true
|
||||
var timeToFirstToken = 0f
|
||||
var firstTokenTs = 0L
|
||||
var decodeTokens = 0
|
||||
var prefillSpeed = 0f
|
||||
var decodeSpeed: Float
|
||||
val start = System.currentTimeMillis()
|
||||
LlmChatModelHelper.runInference(
|
||||
model = model,
|
||||
input = input,
|
||||
resultListener = { partialResult, done ->
|
||||
val curTs = System.currentTimeMillis()
|
||||
|
||||
if (firstRun) {
|
||||
firstTokenTs = System.currentTimeMillis()
|
||||
timeToFirstToken = (firstTokenTs - start) / 1000f
|
||||
prefillSpeed = prefillTokens / timeToFirstToken
|
||||
firstRun = false
|
||||
} else {
|
||||
decodeTokens++
|
||||
}
|
||||
|
||||
// Remove the last message if it is a "loading" message.
|
||||
// This will only be done once.
|
||||
val lastMessage = getLastMessage(model = model)
|
||||
|
@ -76,7 +95,7 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
|||
|
||||
// Incrementally update the streamed partial results.
|
||||
val latencyMs: Long = if (done) System.currentTimeMillis() - start else -1
|
||||
updateLastMessageContentIncrementally(
|
||||
updateLastTextMessageContentIncrementally(
|
||||
model = model,
|
||||
partialContent = partialResult,
|
||||
latencyMs = latencyMs.toFloat()
|
||||
|
@ -84,6 +103,29 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
|||
|
||||
if (done) {
|
||||
setInProgress(false)
|
||||
|
||||
decodeSpeed =
|
||||
decodeTokens / ((curTs - firstTokenTs) / 1000f)
|
||||
if (decodeSpeed.isNaN()) {
|
||||
decodeSpeed = 0f
|
||||
}
|
||||
|
||||
if (lastMessage is ChatMessageText) {
|
||||
updateLastTextMessageLlmBenchmarkResult(
|
||||
model = model, llmBenchmarkResult =
|
||||
ChatMessageBenchmarkLlmResult(
|
||||
orderedStats = STATS,
|
||||
statValues = mutableMapOf(
|
||||
"prefill_speed" to prefillSpeed,
|
||||
"decode_speed" to decodeSpeed,
|
||||
"time_to_first_token" to timeToFirstToken,
|
||||
"latency" to (curTs - start).toFloat() / 1000f,
|
||||
),
|
||||
running = false,
|
||||
latencyMs = -1f,
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
}, cleanUpListener = {
|
||||
setInProgress(false)
|
||||
|
@ -117,8 +159,8 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
|||
while (model.instance == null) {
|
||||
delay(100)
|
||||
}
|
||||
val instance = model.instance as LlmInference
|
||||
val prefillTokens = instance.sizeInTokens(message.content)
|
||||
val instance = model.instance as LlmModelInstance
|
||||
val prefillTokens = instance.session.sizeInTokens(message.content)
|
||||
|
||||
// Add the message to show benchmark results.
|
||||
val benchmarkLlmResult = ChatMessageBenchmarkLlmResult(
|
||||
|
|
|
@ -16,21 +16,34 @@
|
|||
|
||||
package com.google.aiedge.gallery.ui.modelmanager
|
||||
|
||||
import androidx.compose.foundation.ExperimentalFoundationApi
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.PaddingValues
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.layout.size
|
||||
import androidx.compose.foundation.lazy.LazyColumn
|
||||
import androidx.compose.foundation.lazy.items
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.outlined.Code
|
||||
import androidx.compose.material.icons.outlined.Description
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.graphics.vector.ImageVector
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.platform.LocalUriHandler
|
||||
import androidx.compose.ui.text.AnnotatedString
|
||||
import androidx.compose.ui.text.SpanStyle
|
||||
import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.compose.ui.text.style.TextAlign
|
||||
import androidx.compose.ui.text.style.TextDecoration
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
import androidx.compose.ui.unit.dp
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
|
@ -39,9 +52,9 @@ import com.google.aiedge.gallery.ui.common.modelitem.ModelItem
|
|||
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
|
||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||
import com.google.aiedge.gallery.ui.theme.customColors
|
||||
|
||||
/** The list of models in the model manager. */
|
||||
@OptIn(ExperimentalFoundationApi::class)
|
||||
@Composable
|
||||
fun ModelList(
|
||||
task: Task,
|
||||
|
@ -62,11 +75,40 @@ fun ModelList(
|
|||
textAlign = TextAlign.Center,
|
||||
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
|
||||
modifier = Modifier
|
||||
.padding(bottom = 20.dp)
|
||||
.fillMaxWidth()
|
||||
)
|
||||
}
|
||||
|
||||
// URLs.
|
||||
item(key = "urls") {
|
||||
Row(
|
||||
horizontalArrangement = Arrangement.Center,
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(top = 12.dp, bottom = 16.dp),
|
||||
) {
|
||||
Column(
|
||||
horizontalAlignment = Alignment.Start,
|
||||
verticalArrangement = Arrangement.spacedBy(4.dp),
|
||||
) {
|
||||
if (task.docUrl.isNotEmpty()) {
|
||||
ClickableLink(
|
||||
url = task.docUrl,
|
||||
linkText = "API Documentation",
|
||||
icon = Icons.Outlined.Description
|
||||
)
|
||||
}
|
||||
if (task.sourceCodeUrl.isNotEmpty()) {
|
||||
ClickableLink(
|
||||
url = task.sourceCodeUrl,
|
||||
linkText = "Example code",
|
||||
icon = Icons.Outlined.Code
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// List of models within a task.
|
||||
items(items = task.models) { model ->
|
||||
Box {
|
||||
|
@ -82,6 +124,45 @@ fun ModelList(
|
|||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
fun ClickableLink(
|
||||
url: String,
|
||||
linkText: String,
|
||||
icon: ImageVector,
|
||||
) {
|
||||
val uriHandler = LocalUriHandler.current
|
||||
val annotatedText = AnnotatedString(
|
||||
text = linkText,
|
||||
spanStyles = listOf(
|
||||
AnnotatedString.Range(
|
||||
item = SpanStyle(
|
||||
color = MaterialTheme.customColors.linkColor,
|
||||
textDecoration = TextDecoration.Underline
|
||||
),
|
||||
start = 0,
|
||||
end = linkText.length
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.Center,
|
||||
) {
|
||||
Icon(icon, contentDescription = "", modifier = Modifier.size(16.dp))
|
||||
Text(
|
||||
text = annotatedText,
|
||||
textAlign = TextAlign.Center,
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
modifier = Modifier
|
||||
.padding(start = 6.dp)
|
||||
.clickable {
|
||||
uriHandler.openUri(url)
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@Preview(showBackground = true)
|
||||
@Composable
|
||||
fun ModelListPreview() {
|
||||
|
|
|
@ -21,9 +21,6 @@ import androidx.compose.foundation.layout.fillMaxSize
|
|||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
import androidx.compose.material3.Scaffold
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.rememberCoroutineScope
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
|
@ -31,9 +28,7 @@ import com.google.aiedge.gallery.GalleryTopAppBar
|
|||
import com.google.aiedge.gallery.data.AppBarAction
|
||||
import com.google.aiedge.gallery.data.AppBarActionType
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.data.getModelByName
|
||||
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
|
||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||
|
@ -48,9 +43,6 @@ fun ModelManager(
|
|||
onModelClicked: (Model) -> Unit,
|
||||
modifier: Modifier = Modifier,
|
||||
) {
|
||||
val uiState by viewModel.uiState.collectAsState()
|
||||
val coroutineScope = rememberCoroutineScope()
|
||||
|
||||
// Set title based on the task.
|
||||
var title = "${task.type.label} model"
|
||||
if (task.models.size != 1) {
|
||||
|
@ -67,27 +59,7 @@ fun ModelManager(
|
|||
topBar = {
|
||||
GalleryTopAppBar(
|
||||
title = title,
|
||||
// subtitle = String.format(
|
||||
// stringResource(R.string.downloaded_size),
|
||||
// totalSizeInBytes.humanReadableSize()
|
||||
// ),
|
||||
|
||||
// Refresh model list button at the left side of the app bar.
|
||||
// leftAction = AppBarAction(actionType = if (uiState.loadingHfModels) {
|
||||
// AppBarActionType.REFRESHING_MODELS
|
||||
// } else {
|
||||
// AppBarActionType.REFRESH_MODELS
|
||||
// }, actionFn = {
|
||||
// coroutineScope.launch(Dispatchers.IO) {
|
||||
// viewModel.loadHfModels()
|
||||
// }
|
||||
// }),
|
||||
leftAction = AppBarAction(actionType = AppBarActionType.NAVIGATE_UP, actionFn = navigateUp)
|
||||
|
||||
// "Done" button at the right side of the app bar to navigate up.
|
||||
// rightAction = AppBarAction(
|
||||
// actionType = AppBarActionType.NAVIGATE_UP, actionFn = navigateUp
|
||||
// ),
|
||||
)
|
||||
},
|
||||
) { innerPadding ->
|
||||
|
@ -101,19 +73,6 @@ fun ModelManager(
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Sums the bytes accounted for by model downloads in [uiState]:
 * the full model size for completed downloads plus the bytes received
 * so far for downloads still in progress.
 */
private fun getTotalDownloadedFileSize(uiState: ModelManagerUiState): Long =
  uiState.modelDownloadStatus.entries.sumOf { (modelName, downloadStatus) ->
    when (downloadStatus.status) {
      // Completed: count the model's full size (0 if the model is unknown).
      ModelDownloadStatusType.SUCCEEDED -> getModelByName(modelName)?.totalBytes ?: 0L
      // In flight: count only what has arrived so far.
      ModelDownloadStatusType.IN_PROGRESS -> downloadStatus.receivedBytes
      else -> 0L
    }
  }
|
||||
|
||||
|
||||
@Preview
|
||||
@Composable
|
||||
fun ModelManagerPreview() {
|
||||
|
|
|
@ -115,26 +115,27 @@ data class CustomColors(
|
|||
val homeBottomGradient: List<Color> = listOf(),
|
||||
val userBubbleBgColor: Color = Color.Transparent,
|
||||
val agentBubbleBgColor: Color = Color.Transparent,
|
||||
val linkColor: Color = Color.Transparent,
|
||||
)
|
||||
|
||||
val LocalCustomColors = staticCompositionLocalOf { CustomColors() }
|
||||
|
||||
val lightCustomColors = CustomColors(
|
||||
taskBgColors = listOf(
|
||||
// green
|
||||
Color(0xFFE1F6DE),
|
||||
// blue
|
||||
Color(0xFFEDF0FF),
|
||||
// yellow
|
||||
Color(0xFFFFEFC9),
|
||||
// red
|
||||
Color(0xFFFFEDE6),
|
||||
// green
|
||||
Color(0xFFE1F6DE),
|
||||
// blue
|
||||
Color(0xFFEDF0FF)
|
||||
),
|
||||
taskIconColors = listOf(
|
||||
Color(0xFFE37400),
|
||||
Color(0xFFD93025),
|
||||
Color(0xFF34A853),
|
||||
Color(0xFF1967D2),
|
||||
Color(0xFFE37400),
|
||||
Color(0xFFD93025),
|
||||
),
|
||||
taskIconShapeBgColor = Color.White,
|
||||
homeBottomGradient = listOf(
|
||||
|
@ -143,24 +144,25 @@ val lightCustomColors = CustomColors(
|
|||
),
|
||||
agentBubbleBgColor = Color(0xFFe9eef6),
|
||||
userBubbleBgColor = Color(0xFF32628D),
|
||||
linkColor = Color(0xFF32628D),
|
||||
)
|
||||
|
||||
val darkCustomColors = CustomColors(
|
||||
taskBgColors = listOf(
|
||||
// green
|
||||
Color(0xFF2E312D),
|
||||
// blue
|
||||
Color(0xFF303033),
|
||||
// yellow
|
||||
Color(0xFF33302A),
|
||||
// red
|
||||
Color(0xFF362F2D),
|
||||
// green
|
||||
Color(0xFF2E312D),
|
||||
// blue
|
||||
Color(0xFF303033)
|
||||
),
|
||||
taskIconColors = listOf(
|
||||
Color(0xFFFFB955),
|
||||
Color(0xFFFFB4AB),
|
||||
Color(0xFF6DD58C),
|
||||
Color(0xFFAAC7FF),
|
||||
Color(0xFFFFB955),
|
||||
Color(0xFFFFB4AB),
|
||||
),
|
||||
taskIconShapeBgColor = Color(0xFF202124),
|
||||
homeBottomGradient = listOf(
|
||||
|
@ -169,6 +171,7 @@ val darkCustomColors = CustomColors(
|
|||
),
|
||||
agentBubbleBgColor = Color(0xFF1b1c1d),
|
||||
userBubbleBgColor = Color(0xFF1f3760),
|
||||
linkColor = Color(0xFF9DCAFC),
|
||||
)
|
||||
|
||||
val MaterialTheme.customColors: CustomColors
|
||||
|
|
|
@ -35,7 +35,7 @@
|
|||
<string name="model_is_initializing_msg">Initializing model…</string>
|
||||
<string name="text_input_placeholder_text_classification">Type movie review to classify…</string>
|
||||
<string name="text_image_generation_text_field_placeholder">Type prompt…</string>
|
||||
<string name="text_input_placeholder_llm_chat">Type prompt (one shot)…</string>
|
||||
<string name="text_input_placeholder_llm_chat">Type prompt…</string>
|
||||
<string name="run_again">Run again</string>
|
||||
<string name="benchmark">Run benchmark</string>
|
||||
<string name="warming_up">warming up…</string>
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue