- Add "show stats" after each LLM response to see token related stats more easily.

- Back to multi-turn (no exception handling yet).
- Show the app's version in the App Info screen.
- Show API doc link and source code link in the model list screen.
- Jump to the model info page on HF when clicking "learn more".
Jing Jin 2025-04-16 14:51:45 -07:00
parent ea31fd0544
commit 42d442389d
16 changed files with 343 additions and 183 deletions

View file

@ -30,7 +30,7 @@ android {
minSdk = 24
targetSdk = 35
versionCode = 1
versionName = "1.0"
versionName = "20250416"
// Needed for HuggingFace auth workflows.
manifestPlaceholders["appAuthRedirectScheme"] = "com.google.aiedge.gallery.oauth"
@ -58,6 +58,7 @@ android {
}
buildFeatures {
compose = true
buildConfig = true
}
}

View file

@ -1,19 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery
const val VERSION = "20250413"

View file

@ -218,9 +218,6 @@ const val IMAGE_CLASSIFICATION_LEARN_MORE_URL = "https://ai.google.dev/edge/lite
const val LLM_CHAT_INFO =
"Some description about this large language model. A community org for developers to discover models that are ready for deployment to edge platforms"
const val LLM_CHAT_LEARN_MORE_URL =
"https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android"
const val IMAGE_GENERATION_INFO =
"Powered by [MediaPipe Image Generation API](https://ai.google.dev/edge/mediapipe/solutions/vision/image_generator/android)"
@ -234,7 +231,7 @@ val MODEL_LLM_GEMMA_2B_GPU_INT4: Model = Model(
sizeInBytes = 1354301440L,
configs = createLlmChatConfigs(),
info = LLM_CHAT_INFO,
learnMoreUrl = LLM_CHAT_LEARN_MORE_URL,
learnMoreUrl = "https://huggingface.co/litert-community",
)
val MODEL_LLM_GEMMA_2_2B_GPU_INT8: Model = Model(
@ -244,7 +241,7 @@ val MODEL_LLM_GEMMA_2_2B_GPU_INT8: Model = Model(
sizeInBytes = 2627141632L,
configs = createLlmChatConfigs(),
info = LLM_CHAT_INFO,
learnMoreUrl = LLM_CHAT_LEARN_MORE_URL,
learnMoreUrl = "https://huggingface.co/litert-community",
)
val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model(
@ -254,7 +251,7 @@ val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model(
sizeInBytes = 554661243L,
configs = createLlmChatConfigs(defaultTopK = 64, defaultTopP = 0.95f),
info = LLM_CHAT_INFO,
learnMoreUrl = LLM_CHAT_LEARN_MORE_URL,
learnMoreUrl = "https://huggingface.co/litert-community/Gemma3-1B-IT",
llmPromptTemplates = listOf(
PromptTemplate(
title = "Emoji Fun",
@ -277,7 +274,7 @@ val MODEL_LLM_DEEPSEEK: Model = Model(
llmBackend = LlmBackend.CPU,
configs = createLlmChatConfigs(defaultTemperature = 0.6f, defaultTopK = 40, defaultTopP = 0.7f),
info = LLM_CHAT_INFO,
learnMoreUrl = LLM_CHAT_LEARN_MORE_URL,
learnMoreUrl = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B",
)
val MODEL_TEXT_CLASSIFICATION_MOBILEBERT: Model = Model(
@ -343,6 +340,7 @@ val MODEL_IMAGE_GENERATION_STABLE_DIFFUSION: Model = Model(
showBenchmarkButton = false,
info = IMAGE_GENERATION_INFO,
configs = IMAGE_GENERATION_CONFIGS,
learnMoreUrl = "https://huggingface.co/litert-community",
)
val EMPTY_MODEL: Model = Model(

View file

@ -50,6 +50,12 @@ data class Task(
/** Description of the task. */
val description: String,
/** Documentation url for the task. */
val docUrl: String = "",
/** Source code url for the model-related functions. */
val sourceCodeUrl: String = "",
/** Placeholder text for the name of the agent shown above chat messages. */
@StringRes val agentNameRes: Int = R.string.chat_generic_agent_name,
@ -80,6 +86,8 @@ val TASK_LLM_CHAT = Task(
iconVectorResourceId = R.drawable.chat_spark,
models = MODELS_LLM_CHAT,
description = "Chat? with a on-device large language model",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
)
@ -88,13 +96,15 @@ val TASK_IMAGE_GENERATION = Task(
iconVectorResourceId = R.drawable.image_spark,
models = MODELS_IMAGE_GENERATION,
description = "Generate images from text",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/vision/image_generator/android",
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/imagegeneration/ImageGenerationModelHelper.kt",
textInputPlaceHolderRes = R.string.text_image_generation_text_field_placeholder
)
/** All tasks. */
val TASKS: List<Task> = listOf(
TASK_TEXT_CLASSIFICATION,
TASK_IMAGE_CLASSIFICATION,
// TASK_TEXT_CLASSIFICATION,
// TASK_IMAGE_CLASSIFICATION,
TASK_IMAGE_GENERATION,
TASK_LLM_CHAT,
)

View file

@ -71,13 +71,17 @@ open class ChatMessageText(
// Negative numbers will hide the latency display.
override val latencyMs: Float = 0f,
val isMarkdown: Boolean = true,
// Benchmark result for LLM response.
var llmBenchmarkResult: ChatMessageBenchmarkLlmResult? = null,
) : ChatMessage(type = ChatMessageType.TEXT, side = side, latencyMs = latencyMs) {
override fun clone(): ChatMessageText {
return ChatMessageText(
content = content,
side = side,
latencyMs = latencyMs,
isMarkdown = isMarkdown
isMarkdown = isMarkdown,
llmBenchmarkResult = llmBenchmarkResult,
)
}
}

View file

@ -150,6 +150,8 @@ fun ChatPanel(
lastMessageContent.value = tmpLastMessage.content
}
}
val lastShowingStatsByModel: MutableState<Map<String, MutableSet<ChatMessage>>> =
remember { mutableStateOf(mapOf()) }
// Scroll the content to the bottom when any of these changes.
LaunchedEffect(
@ -158,9 +160,27 @@ fun ChatPanel(
lastMessageContent.value,
WindowInsets.ime.getBottom(density),
) {
// Only auto-scroll when showingStatsByModel is unchanged (compared by reference below); when it
// changes, keep the scroll position, except when the toggled stats belong to the last message.
if (messages.isNotEmpty()) {
listState.animateScrollToItem(messages.lastIndex, scrollOffset = 10000)
if (uiState.showingStatsByModel === lastShowingStatsByModel.value) {
listState.animateScrollToItem(messages.lastIndex, scrollOffset = 10000)
} else {
// Scroll to bottom if the message to show stats is the last message.
val curShowingStats =
uiState.showingStatsByModel[selectedModel.name]?.toMutableSet() ?: mutableSetOf()
val lastShowingStats = lastShowingStatsByModel.value[selectedModel.name] ?: mutableSetOf()
curShowingStats.removeAll(lastShowingStats)
if (curShowingStats.isNotEmpty()) {
val index =
viewModel.getMessageIndex(model = selectedModel, message = curShowingStats.first())
if (index == messages.size - 2) {
listState.animateScrollToItem(messages.lastIndex, scrollOffset = 10000)
}
}
}
}
lastShowingStatsByModel.value = uiState.showingStatsByModel
}
val nestedScrollConnection = remember {
@ -309,7 +329,51 @@ fun ChatPanel(
}
}
if (message.side == ChatSide.AGENT) {
LatencyText(message = message)
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(8.dp),
) {
LatencyText(message = message)
// A button to show stats for the LLM message.
if (selectedModel.taskType == TaskType.LLM_CHAT && message is ChatMessageText
// Only show the action button once the message has finished generating, at which
// point latencyMs will have been set to a non-negative value.
&& message.latencyMs >= 0
) {
val showingStats =
viewModel.isShowingStats(model = selectedModel, message = message)
MessageActionButton(
label = if (showingStats) "Hide stats" else "Show stats",
icon = Icons.Outlined.Timer,
onClick = {
// Toggle showing stats.
viewModel.toggleShowingStats(selectedModel, message)
// Add the stats message after the LLM message.
if (viewModel.isShowingStats(model = selectedModel, message = message)) {
val llmBenchmarkResult = message.llmBenchmarkResult
if (llmBenchmarkResult != null) {
viewModel.insertMessageAfter(
model = selectedModel,
anchorMessage = message,
messageToAdd = llmBenchmarkResult,
)
}
}
// Remove the stats message.
else {
val curMessageIndex =
viewModel.getMessageIndex(model = selectedModel, message = message)
viewModel.removeMessageAt(
model = selectedModel,
index = curMessageIndex + 1
)
}
},
enabled = !uiState.inProgress
)
}
}
} else if (message.side == ChatSide.USER) {
Row(
verticalAlignment = Alignment.CenterVertically,
@ -328,21 +392,21 @@ fun ChatPanel(
}
// Benchmark button
if (selectedModel.showBenchmarkButton) {
MessageActionButton(
label = stringResource(R.string.benchmark),
icon = Icons.Outlined.Timer,
onClick = {
if (selectedModel.taskType == TaskType.LLM_CHAT) {
onBenchmarkClicked(selectedModel, message, 0, 0)
} else {
showBenchmarkConfigsDialog = true
benchmarkMessage.value = message
}
},
enabled = !uiState.inProgress
)
}
// if (selectedModel.showBenchmarkButton) {
// MessageActionButton(
// label = stringResource(R.string.benchmark),
// icon = Icons.Outlined.Timer,
// onClick = {
// if (selectedModel.taskType == TaskType.LLM_CHAT) {
// onBenchmarkClicked(selectedModel, message, 0, 0)
// } else {
// showBenchmarkConfigsDialog = true
// benchmarkMessage.value = message
// }
// },
// enabled = !uiState.inProgress
// )
// }
}
}
}

View file

@ -43,6 +43,12 @@ data class ChatUiState(
* A map of model names to the currently streaming chat message.
*/
val streamingMessagesByModel: Map<String, ChatMessage> = mapOf(),
/**
 * A map of model names to the set of chat messages whose stats are currently shown below them.
 */
val showingStatsByModel: Map<String, MutableSet<ChatMessage>> = mapOf(),
)
/**
@ -66,6 +72,31 @@ open class ChatViewModel(val task: Task) : ViewModel() {
_uiState.update { _uiState.value.copy(messagesByModel = newMessagesByModel) }
}
fun insertMessageAfter(model: Model, anchorMessage: ChatMessage, messageToAdd: ChatMessage) {
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
val newMessages = newMessagesByModel[model.name]?.toMutableList()
if (newMessages != null) {
newMessagesByModel[model.name] = newMessages
// Find the index of the anchor message
val anchorIndex = newMessages.indexOf(anchorMessage)
if (anchorIndex != -1) {
// Insert the new message after the anchor message
newMessages.add(anchorIndex + 1, messageToAdd)
}
}
_uiState.update { _uiState.value.copy(messagesByModel = newMessagesByModel) }
}
fun removeMessageAt(model: Model, index: Int) {
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
val newMessages = newMessagesByModel[model.name]?.toMutableList()
if (newMessages != null) {
newMessagesByModel[model.name] = newMessages
newMessages.removeAt(index)
}
_uiState.update { _uiState.value.copy(messagesByModel = newMessagesByModel) }
}
fun removeLastMessage(model: Model) {
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
val newMessages = newMessagesByModel[model.name]?.toMutableList() ?: mutableListOf()
@ -80,7 +111,7 @@ open class ChatViewModel(val task: Task) : ViewModel() {
return (_uiState.value.messagesByModel[model.name] ?: listOf()).lastOrNull()
}
fun updateLastMessageContentIncrementally(
fun updateLastTextMessageContentIncrementally(
model: Model,
partialContent: String,
latencyMs: Float,
@ -124,6 +155,25 @@ open class ChatViewModel(val task: Task) : ViewModel() {
_uiState.update { newUiState }
}
fun updateLastTextMessageLlmBenchmarkResult(
model: Model,
llmBenchmarkResult: ChatMessageBenchmarkLlmResult
) {
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
val newMessages = newMessagesByModel[model.name]?.toMutableList() ?: mutableListOf()
if (newMessages.size > 0) {
val lastMessage = newMessages.last()
if (lastMessage is ChatMessageText) {
lastMessage.llmBenchmarkResult = llmBenchmarkResult
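// Remove and re-add the mutated message (presumably so the updated message is
// picked up when the new list state is published below).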
newMessages.removeAt(newMessages.size - 1)
newMessages.add(lastMessage)
}
}
newMessagesByModel[model.name] = newMessages
val newUiState = _uiState.value.copy(messagesByModel = newMessagesByModel)
_uiState.update { newUiState }
}
fun replaceLastMessage(model: Model, message: ChatMessage, type: ChatMessageType) {
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
val newMessages = newMessagesByModel[model.name]?.toMutableList() ?: mutableListOf()
@ -159,10 +209,6 @@ open class ChatViewModel(val task: Task) : ViewModel() {
_uiState.update { _uiState.value.copy(inProgress = inProgress) }
}
fun isInProgress(): Boolean {
return _uiState.value.inProgress
}
fun addConfigChangedMessage(
oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>, model: Model
) {
@ -173,6 +219,26 @@ open class ChatViewModel(val task: Task) : ViewModel() {
addMessage(message = message, model = model)
}
fun getMessageIndex(model: Model, message: ChatMessage): Int {
return (_uiState.value.messagesByModel[model.name] ?: listOf()).indexOf(message)
}
fun isShowingStats(model: Model, message: ChatMessage): Boolean {
return _uiState.value.showingStatsByModel[model.name]?.contains(message) ?: false
}
fun toggleShowingStats(model: Model, message: ChatMessage) {
val newShowingStatsByModel = _uiState.value.showingStatsByModel.toMutableMap()
val newShowingStats = newShowingStatsByModel[model.name]?.toMutableSet() ?: mutableSetOf()
if (newShowingStats.contains(message)) {
newShowingStats.remove(message)
} else {
newShowingStats.add(message)
}
newShowingStatsByModel[model.name] = newShowingStats
_uiState.update { _uiState.value.copy(showingStatsByModel = newShowingStatsByModel) }
}
private fun createUiState(task: Task): ChatUiState {
val messagesByModel: MutableMap<String, MutableList<ChatMessage>> = mutableMapOf()
for (model in task.models) {

View file

@ -179,7 +179,7 @@ private fun getMessageLayoutConfig(
is ChatMessageBenchmarkLlmResult -> {
horizontalArrangement = Arrangement.SpaceBetween
modifier = modifier.fillMaxWidth()
userLabel = "Benchmark"
userLabel = "Stats"
}
is ChatMessageImageWithHistory -> {

View file

@ -38,6 +38,7 @@ import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.Settings
import androidx.compose.material.icons.rounded.UnfoldLess
import androidx.compose.material.icons.rounded.UnfoldMore
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.OutlinedButton
@ -90,6 +91,7 @@ private val DEFAULT_VERTICAL_PADDING = 16.dp
* model description and buttons for learning more (opening a URL) and downloading/trying
* the model.
*/
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun ModelItem(
model: Model,
@ -217,7 +219,7 @@ fun ModelItem(
.then(m),
horizontalArrangement = Arrangement.spacedBy(12.dp),
) {
// The "learn more" button. Click to show url in default browser.
// The "learn more" button. Click to show related urls in a bottom sheet.
if (model.learnMoreUrl.isNotEmpty()) {
OutlinedButton(
onClick = {
@ -241,35 +243,6 @@ fun ModelItem(
modelManagerViewModel = modelManagerViewModel,
onClicked = { onModelClicked(model) }
)
// Button(
// onClick = {
// if (isExpanded) {
// onModelClicked(model)
// if (needToDownloadFirst) {
// scope.launch {
// delay(80)
// checkNotificationPermissonAndStartDownload(
// context = context,
// launcher = launcher,
// modelManagerViewModel = modelManagerViewModel,
// model = model
// )
// }
// }
// }
// },
// ) {
// Icon(
// Icons.AutoMirrored.Rounded.ArrowForward,
// contentDescription = "",
// modifier = Modifier.padding(end = 4.dp)
// )
// if (needToDownloadFirst) {
// Text("Download & Try it", maxLines = 1)
// } else {
// Text("Try it", maxLines = 1)
// }
// }
}
}
}
@ -366,7 +339,6 @@ fun ModelItem(
}
}
}
}
@Preview(showBackground = true)

View file

@ -17,7 +17,7 @@
package com.google.aiedge.gallery.ui.home
import androidx.compose.runtime.Composable
import com.google.aiedge.gallery.VERSION
import com.google.aiedge.gallery.BuildConfig
import com.google.aiedge.gallery.data.Config
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.SegmentedButtonConfig
@ -45,7 +45,7 @@ fun SettingsDialog(
)
ConfigDialog(
title = "Settings",
subtitle = "App version: $VERSION",
subtitle = "App version: ${BuildConfig.VERSION_NAME}",
okBtnLabel = "OK",
configs = CONFIGS,
initialValues = initialValues,

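Note: BuildConfig.VERSION_NAME resolves here because the first hunk in this commit sets versionName = "20250416" and enables buildConfig = true in the app module's build file, replacing the hand-maintained VERSION constant deleted above. A minimal sketch of the resulting usage (the BuildConfig class itself is generated by AGP):

import com.google.aiedge.gallery.BuildConfig

// VERSION_NAME is generated from the module's versionName, so after this
// commit it evaluates to "20250416" without a manually updated constant.
val appVersion: String = BuildConfig.VERSION_NAME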
View file

@ -18,12 +18,11 @@ package com.google.aiedge.gallery.ui.llmchat
import android.content.Context
import android.util.Log
import com.google.common.util.concurrent.ListenableFuture
import com.google.mediapipe.tasks.genai.llminference.LlmInference
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.LlmBackend
import com.google.aiedge.gallery.data.Model
import com.google.mediapipe.tasks.genai.llminference.LlmInference
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
private const val TAG = "AGLlmChatModelHelper"
private const val DEFAULT_MAX_TOKEN = 1024
@ -39,8 +38,6 @@ data class LlmModelInstance(val engine: LlmInference, val session: LlmInferenceS
object LlmChatModelHelper {
// Indexed by model name.
private val cleanUpListeners: MutableMap<String, CleanUpListener> = mutableMapOf()
private val generateResponseListenableFutures: MutableMap<String, ListenableFuture<String>> =
mutableMapOf()
fun initialize(
context: Context, model: Model, onDone: () -> Unit
@ -64,13 +61,12 @@ object LlmChatModelHelper {
try {
val llmInference = LlmInference.createFromOptions(context, options)
// val session = LlmInferenceSession.createFromOptions(
// llmInference,
// LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
// .setTemperature(temperature).build()
// )
model.instance = llmInference
// LlmModelInstance(engine = llmInference, session = session)
val session = LlmInferenceSession.createFromOptions(
llmInference,
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
.setTemperature(temperature).build()
)
model.instance = LlmModelInstance(engine = llmInference, session = session)
} catch (e: Exception) {
e.printStackTrace()
}
@ -82,11 +78,10 @@ object LlmChatModelHelper {
return
}
val instance = model.instance as LlmInference
val instance = model.instance as LlmModelInstance
try {
instance.close()
// instance.session.close()
// instance.engine.close()
instance.session.close()
instance.engine.close()
} catch (e: Exception) {
// ignore
}
@ -104,7 +99,7 @@ object LlmChatModelHelper {
resultListener: ResultListener,
cleanUpListener: CleanUpListener,
) {
val instance = model.instance as LlmInference
val instance = model.instance as LlmModelInstance
// Set listener.
if (!cleanUpListeners.containsKey(model.name)) {
@ -112,24 +107,8 @@ object LlmChatModelHelper {
}
// Start async inference.
val future = instance.generateResponseAsync(input, resultListener)
generateResponseListenableFutures[model.name] = future
// val session = instance.session
// TODO: need to count token and reset session.
// session.addQueryChunk(input)
// session.generateResponseAsync(resultListener)
}
fun stopInference(model: Model) {
val instance = model.instance as LlmInference
if (instance != null) {
instance.close()
}
// val future = generateResponseListenableFutures[model.name]
// if (future != null) {
// future.cancel(true)
// generateResponseListenableFutures.remove(model.name)
// }
val session = instance.session
session.addQueryChunk(input)
session.generateResponseAsync(resultListener)
}
}
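
Side note: the switch above from one-shot LlmInference calls back to an LlmInferenceSession is what re-enables multi-turn chat, since the session accumulates query chunks across calls. A minimal usage sketch, assuming only the APIs visible in this diff (LlmModelInstance, addQueryChunk, generateResponseAsync):

// Multi-turn sketch (illustrative; each turn should start only after the
// previous one reports done). The session keeps earlier turns in context.
val instance = model.instance as LlmModelInstance

// Turn 1.
instance.session.addQueryChunk("Summarize the plot of Hamlet in one sentence.")
instance.session.generateResponseAsync { partialResult, done ->
  if (done) { /* turn 1 finished; safe to issue turn 2 */ }
}

// Turn 2 (after turn 1 completes): "it" resolves against turn 1's context.
instance.session.addQueryChunk("Who wrote it?")
instance.session.generateResponseAsync { partialResult, done ->
  if (done) { /* turn 2 finished */ }
}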

View file

@ -17,7 +17,6 @@
package com.google.aiedge.gallery.ui.llmchat
import androidx.lifecycle.viewModelScope
import com.google.mediapipe.tasks.genai.llminference.LlmInference
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
@ -56,11 +55,31 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
}
// Run inference.
val instance = model.instance as LlmModelInstance
val prefillTokens = instance.session.sizeInTokens(input)
var firstRun = true
var timeToFirstToken = 0f
var firstTokenTs = 0L
var decodeTokens = 0
var prefillSpeed = 0f
var decodeSpeed: Float
val start = System.currentTimeMillis()
LlmChatModelHelper.runInference(
model = model,
input = input,
resultListener = { partialResult, done ->
val curTs = System.currentTimeMillis()
if (firstRun) {
firstTokenTs = System.currentTimeMillis()
timeToFirstToken = (firstTokenTs - start) / 1000f
prefillSpeed = prefillTokens / timeToFirstToken
firstRun = false
} else {
decodeTokens++
}
// Remove the last message if it is a "loading" message.
// This will only be done once.
val lastMessage = getLastMessage(model = model)
@ -76,7 +95,7 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
// Incrementally update the streamed partial results.
val latencyMs: Long = if (done) System.currentTimeMillis() - start else -1
updateLastMessageContentIncrementally(
updateLastTextMessageContentIncrementally(
model = model,
partialContent = partialResult,
latencyMs = latencyMs.toFloat()
@ -84,6 +103,29 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
if (done) {
setInProgress(false)
decodeSpeed =
decodeTokens / ((curTs - firstTokenTs) / 1000f)
if (decodeSpeed.isNaN()) {
decodeSpeed = 0f
}
if (lastMessage is ChatMessageText) {
updateLastTextMessageLlmBenchmarkResult(
model = model, llmBenchmarkResult =
ChatMessageBenchmarkLlmResult(
orderedStats = STATS,
statValues = mutableMapOf(
"prefill_speed" to prefillSpeed,
"decode_speed" to decodeSpeed,
"time_to_first_token" to timeToFirstToken,
"latency" to (curTs - start).toFloat() / 1000f,
),
running = false,
latencyMs = -1f,
)
)
}
}
}, cleanUpListener = {
setInProgress(false)
@ -117,8 +159,8 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
while (model.instance == null) {
delay(100)
}
val instance = model.instance as LlmInference
val prefillTokens = instance.sizeInTokens(message.content)
val instance = model.instance as LlmModelInstance
val prefillTokens = instance.session.sizeInTokens(message.content)
// Add the message to show benchmark results.
val benchmarkLlmResult = ChatMessageBenchmarkLlmResult(

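For a quick sense of how the stats wired up above combine, a worked example with made-up numbers (not measured output):

// Illustrative only; mirrors the formulas in the resultListener above.
val prefillTokens = 256          // from instance.session.sizeInTokens(input)
val timeToFirstToken = 0.5f      // seconds from start to the first streamed token
val prefillSpeed = prefillTokens / timeToFirstToken    // 256 / 0.5 = 512 tokens/s
val decodeTokens = 120           // partial results received after the first token
val decodeSeconds = 4.0f         // (curTs - firstTokenTs) / 1000f
val decodeSpeed = decodeTokens / decodeSeconds         // 120 / 4 = 30 tokens/s
val latencySeconds = timeToFirstToken + decodeSeconds  // 4.5 s end-to-end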
View file

@ -16,21 +16,34 @@
package com.google.aiedge.gallery.ui.modelmanager
import androidx.compose.foundation.ExperimentalFoundationApi
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.PaddingValues
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.outlined.Code
import androidx.compose.material.icons.outlined.Description
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalUriHandler
import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.text.SpanStyle
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.text.style.TextDecoration
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.Model
@ -39,9 +52,9 @@ import com.google.aiedge.gallery.ui.common.modelitem.ModelItem
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.customColors
/** The list of models in the model manager. */
@OptIn(ExperimentalFoundationApi::class)
@Composable
fun ModelList(
task: Task,
@ -62,11 +75,40 @@ fun ModelList(
textAlign = TextAlign.Center,
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
modifier = Modifier
.padding(bottom = 20.dp)
.fillMaxWidth()
)
}
// URLs.
item(key = "urls") {
Row(
horizontalArrangement = Arrangement.Center,
modifier = Modifier
.fillMaxWidth()
.padding(top = 12.dp, bottom = 16.dp),
) {
Column(
horizontalAlignment = Alignment.Start,
verticalArrangement = Arrangement.spacedBy(4.dp),
) {
if (task.docUrl.isNotEmpty()) {
ClickableLink(
url = task.docUrl,
linkText = "API Documentation",
icon = Icons.Outlined.Description
)
}
if (task.sourceCodeUrl.isNotEmpty()) {
ClickableLink(
url = task.sourceCodeUrl,
linkText = "Example code",
icon = Icons.Outlined.Code
)
}
}
}
}
// List of models within a task.
items(items = task.models) { model ->
Box {
@ -82,6 +124,45 @@ fun ModelList(
}
}
@Composable
fun ClickableLink(
url: String,
linkText: String,
icon: ImageVector,
) {
val uriHandler = LocalUriHandler.current
val annotatedText = AnnotatedString(
text = linkText,
spanStyles = listOf(
AnnotatedString.Range(
item = SpanStyle(
color = MaterialTheme.customColors.linkColor,
textDecoration = TextDecoration.Underline
),
start = 0,
end = linkText.length
)
)
)
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.Center,
) {
Icon(icon, contentDescription = "", modifier = Modifier.size(16.dp))
Text(
text = annotatedText,
textAlign = TextAlign.Center,
style = MaterialTheme.typography.bodySmall,
modifier = Modifier
.padding(start = 6.dp)
.clickable {
uriHandler.openUri(url)
},
)
}
}
@Preview(showBackground = true)
@Composable
fun ModelListPreview() {

View file

@ -21,9 +21,6 @@ import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Scaffold
import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.tooling.preview.Preview
@ -31,9 +28,7 @@ import com.google.aiedge.gallery.GalleryTopAppBar
import com.google.aiedge.gallery.data.AppBarAction
import com.google.aiedge.gallery.data.AppBarActionType
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.getModelByName
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
import com.google.aiedge.gallery.ui.theme.GalleryTheme
@ -48,9 +43,6 @@ fun ModelManager(
onModelClicked: (Model) -> Unit,
modifier: Modifier = Modifier,
) {
val uiState by viewModel.uiState.collectAsState()
val coroutineScope = rememberCoroutineScope()
// Set title based on the task.
var title = "${task.type.label} model"
if (task.models.size != 1) {
@ -67,27 +59,7 @@ fun ModelManager(
topBar = {
GalleryTopAppBar(
title = title,
// subtitle = String.format(
// stringResource(R.string.downloaded_size),
// totalSizeInBytes.humanReadableSize()
// ),
// Refresh model list button at the left side of the app bar.
// leftAction = AppBarAction(actionType = if (uiState.loadingHfModels) {
// AppBarActionType.REFRESHING_MODELS
// } else {
// AppBarActionType.REFRESH_MODELS
// }, actionFn = {
// coroutineScope.launch(Dispatchers.IO) {
// viewModel.loadHfModels()
// }
// }),
leftAction = AppBarAction(actionType = AppBarActionType.NAVIGATE_UP, actionFn = navigateUp)
// "Done" button at the right side of the app bar to navigate up.
// rightAction = AppBarAction(
// actionType = AppBarActionType.NAVIGATE_UP, actionFn = navigateUp
// ),
)
},
) { innerPadding ->
@ -101,19 +73,6 @@ fun ModelManager(
}
}
private fun getTotalDownloadedFileSize(uiState: ModelManagerUiState): Long {
var totalSizeInBytes = 0L
for ((name, status) in uiState.modelDownloadStatus.entries) {
if (status.status == ModelDownloadStatusType.SUCCEEDED) {
totalSizeInBytes += getModelByName(name)?.totalBytes ?: 0L
} else if (status.status == ModelDownloadStatusType.IN_PROGRESS) {
totalSizeInBytes += status.receivedBytes
}
}
return totalSizeInBytes
}
@Preview
@Composable
fun ModelManagerPreview() {

View file

@ -115,26 +115,27 @@ data class CustomColors(
val homeBottomGradient: List<Color> = listOf(),
val userBubbleBgColor: Color = Color.Transparent,
val agentBubbleBgColor: Color = Color.Transparent,
val linkColor: Color = Color.Transparent,
)
val LocalCustomColors = staticCompositionLocalOf { CustomColors() }
val lightCustomColors = CustomColors(
taskBgColors = listOf(
// green
Color(0xFFE1F6DE),
// blue
Color(0xFFEDF0FF),
// yellow
Color(0xFFFFEFC9),
// red
Color(0xFFFFEDE6),
// green
Color(0xFFE1F6DE),
// blue
Color(0xFFEDF0FF)
),
taskIconColors = listOf(
Color(0xFFE37400),
Color(0xFFD93025),
Color(0xFF34A853),
Color(0xFF1967D2),
Color(0xFFE37400),
Color(0xFFD93025),
),
taskIconShapeBgColor = Color.White,
homeBottomGradient = listOf(
@ -143,24 +144,25 @@ val lightCustomColors = CustomColors(
),
agentBubbleBgColor = Color(0xFFe9eef6),
userBubbleBgColor = Color(0xFF32628D),
linkColor = Color(0xFF32628D),
)
val darkCustomColors = CustomColors(
taskBgColors = listOf(
// green
Color(0xFF2E312D),
// blue
Color(0xFF303033),
// yellow
Color(0xFF33302A),
// red
Color(0xFF362F2D),
// green
Color(0xFF2E312D),
// blue
Color(0xFF303033)
),
taskIconColors = listOf(
Color(0xFFFFB955),
Color(0xFFFFB4AB),
Color(0xFF6DD58C),
Color(0xFFAAC7FF),
Color(0xFFFFB955),
Color(0xFFFFB4AB),
),
taskIconShapeBgColor = Color(0xFF202124),
homeBottomGradient = listOf(
@ -169,6 +171,7 @@ val darkCustomColors = CustomColors(
),
agentBubbleBgColor = Color(0xFF1b1c1d),
userBubbleBgColor = Color(0xFF1f3760),
linkColor = Color(0xFF9DCAFC),
)
val MaterialTheme.customColors: CustomColors

View file

@ -35,7 +35,7 @@
<string name="model_is_initializing_msg">Initializing model…</string>
<string name="text_input_placeholder_text_classification">Type movie review to classify…</string>
<string name="text_image_generation_text_field_placeholder">Type prompt…</string>
<string name="text_input_placeholder_llm_chat">Type prompt (one shot)</string>
<string name="text_input_placeholder_llm_chat">Type prompt…</string>
<string name="run_again">Run again</string>
<string name="benchmark">Run benchmark</string>
<string name="warming_up">warming up…</string>