Add support for LLM single-turn experience

This commit is contained in:
Jing Jin 2025-04-26 21:35:03 -07:00
parent 46aaee2654
commit d94fec0674
39 changed files with 2376 additions and 295 deletions

View file

@ -30,7 +30,7 @@ android {
minSdk = 24 minSdk = 24
targetSdk = 35 targetSdk = 35
versionCode = 1 versionCode = 1
versionName = "20250421" versionName = "20250428"
// Needed for HuggingFace auth workflows. // Needed for HuggingFace auth workflows.
manifestPlaceholders["appAuthRedirectScheme"] = "com.google.aiedge.gallery.oauth" manifestPlaceholders["appAuthRedirectScheme"] = "com.google.aiedge.gallery.oauth"

View file

@ -187,6 +187,5 @@ fun GalleryTopAppBar(
else -> {} else -> {}
} }
} }
) )
} }

View file

@ -23,7 +23,7 @@ import androidx.datastore.preferences.core.Preferences
import androidx.datastore.preferences.preferencesDataStore import androidx.datastore.preferences.preferencesDataStore
import com.google.aiedge.gallery.data.AppContainer import com.google.aiedge.gallery.data.AppContainer
import com.google.aiedge.gallery.data.DefaultAppContainer import com.google.aiedge.gallery.data.DefaultAppContainer
import com.google.aiedge.gallery.data.TASKS import com.google.aiedge.gallery.ui.common.processTasks
import com.google.aiedge.gallery.ui.theme.ThemeSettings import com.google.aiedge.gallery.ui.theme.ThemeSettings
private val Context.dataStore: DataStore<Preferences> by preferencesDataStore(name = "app_gallery_preferences") private val Context.dataStore: DataStore<Preferences> by preferencesDataStore(name = "app_gallery_preferences")
@ -36,12 +36,7 @@ class GalleryApplication : Application() {
super.onCreate() super.onCreate()
// Process tasks. // Process tasks.
for ((index, task) in TASKS.withIndex()) { processTasks()
task.index = index
for (model in task.models) {
model.preProcess(task = task)
}
}
container = DefaultAppContainer(this, dataStore) container = DefaultAppContainer(this, dataStore)

View file

@ -92,15 +92,13 @@ data class Model(
val imported: Boolean = false, val imported: Boolean = false,
// The following fields are managed by the app. Don't need to set manually. // The following fields are managed by the app. Don't need to set manually.
var taskType: TaskType? = null,
var instance: Any? = null, var instance: Any? = null,
var initializing: Boolean = false, var initializing: Boolean = false,
var configValues: Map<String, Any> = mapOf(), var configValues: Map<String, Any> = mapOf(),
var totalBytes: Long = 0L, var totalBytes: Long = 0L,
var accessToken: String? = null, var accessToken: String? = null,
) { ) {
fun preProcess(task: Task) { fun preProcess() {
this.taskType = task.type
val configValues: MutableMap<String, Any> = mutableMapOf() val configValues: MutableMap<String, Any> = mutableMapOf()
for (config in this.configs) { for (config in this.configs) {
configValues[config.key.label] = config.defaultValue configValues[config.key.label] = config.defaultValue
@ -246,6 +244,7 @@ val MODEL_LLM_GEMMA_2B_GPU_INT4: Model = Model(
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma-2b-it-gpu-int4.bin", url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma-2b-it-gpu-int4.bin",
sizeInBytes = 1354301440L, sizeInBytes = 1354301440L,
configs = createLlmChatConfigs(), configs = createLlmChatConfigs(),
showBenchmarkButton = false,
info = LLM_CHAT_INFO, info = LLM_CHAT_INFO,
learnMoreUrl = "https://huggingface.co/litert-community", learnMoreUrl = "https://huggingface.co/litert-community",
) )
@ -256,6 +255,7 @@ val MODEL_LLM_GEMMA_2_2B_GPU_INT8: Model = Model(
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma2-2b-it-gpu-int8.bin", url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma2-2b-it-gpu-int8.bin",
sizeInBytes = 2627141632L, sizeInBytes = 2627141632L,
configs = createLlmChatConfigs(), configs = createLlmChatConfigs(),
showBenchmarkButton = false,
info = LLM_CHAT_INFO, info = LLM_CHAT_INFO,
learnMoreUrl = "https://huggingface.co/litert-community", learnMoreUrl = "https://huggingface.co/litert-community",
) )
@ -271,6 +271,7 @@ val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model(
defaultTopP = 0.95f, defaultTopP = 0.95f,
accelerators = listOf(Accelerator.CPU, Accelerator.GPU) accelerators = listOf(Accelerator.CPU, Accelerator.GPU)
), ),
showBenchmarkButton = false,
info = LLM_CHAT_INFO, info = LLM_CHAT_INFO,
learnMoreUrl = "https://huggingface.co/litert-community/Gemma3-1B-IT", learnMoreUrl = "https://huggingface.co/litert-community/Gemma3-1B-IT",
llmPromptTemplates = listOf( llmPromptTemplates = listOf(
@ -299,6 +300,7 @@ val MODEL_LLM_DEEPSEEK: Model = Model(
defaultTopP = 0.7f, defaultTopP = 0.7f,
accelerators = listOf(Accelerator.CPU) accelerators = listOf(Accelerator.CPU)
), ),
showBenchmarkButton = false,
info = LLM_CHAT_INFO, info = LLM_CHAT_INFO,
learnMoreUrl = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B", learnMoreUrl = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B",
) )
@ -389,7 +391,7 @@ val MODELS_IMAGE_CLASSIFICATION: MutableList<Model> = mutableListOf(
MODEL_IMAGE_CLASSIFICATION_MOBILENET_V2, MODEL_IMAGE_CLASSIFICATION_MOBILENET_V2,
) )
val MODELS_LLM_CHAT: MutableList<Model> = mutableListOf( val MODELS_LLM: MutableList<Model> = mutableListOf(
MODEL_LLM_GEMMA_2B_GPU_INT4, MODEL_LLM_GEMMA_2B_GPU_INT4,
MODEL_LLM_GEMMA_2_2B_GPU_INT8, MODEL_LLM_GEMMA_2_2B_GPU_INT8,
MODEL_LLM_GEMMA_3_1B_INT4, MODEL_LLM_GEMMA_3_1B_INT4,

View file

@ -18,9 +18,11 @@ package com.google.aiedge.gallery.data
import androidx.annotation.StringRes import androidx.annotation.StringRes
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.outlined.Forum
import androidx.compose.material.icons.outlined.Widgets
import androidx.compose.material.icons.rounded.ImageSearch import androidx.compose.material.icons.rounded.ImageSearch
import androidx.compose.runtime.MutableState import androidx.compose.runtime.MutableState
import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.mutableLongStateOf
import androidx.compose.ui.graphics.vector.ImageVector import androidx.compose.ui.graphics.vector.ImageVector
import com.google.aiedge.gallery.R import com.google.aiedge.gallery.R
@ -30,6 +32,7 @@ enum class TaskType(val label: String) {
IMAGE_CLASSIFICATION("Image Classification"), IMAGE_CLASSIFICATION("Image Classification"),
IMAGE_GENERATION("Image Generation"), IMAGE_GENERATION("Image Generation"),
LLM_CHAT("LLM Chat"), LLM_CHAT("LLM Chat"),
LLM_SINGLE_TURN("LLM Use Cases"),
TEST_TASK_1("Test task 1"), TEST_TASK_1("Test task 1"),
TEST_TASK_2("Test task 2") TEST_TASK_2("Test task 2")
@ -67,7 +70,7 @@ data class Task(
// The following fields are managed by the app. Don't need to set manually. // The following fields are managed by the app. Don't need to set manually.
var index: Int = -1, var index: Int = -1,
val updateTrigger: MutableState<Long> = mutableStateOf(0) val updateTrigger: MutableState<Long> = mutableLongStateOf(0)
) )
val TASK_TEXT_CLASSIFICATION = Task( val TASK_TEXT_CLASSIFICATION = Task(
@ -87,9 +90,19 @@ val TASK_IMAGE_CLASSIFICATION = Task(
val TASK_LLM_CHAT = Task( val TASK_LLM_CHAT = Task(
type = TaskType.LLM_CHAT, type = TaskType.LLM_CHAT,
iconVectorResourceId = R.drawable.chat_spark, icon = Icons.Outlined.Forum,
models = MODELS_LLM_CHAT, models = MODELS_LLM,
description = "Chat? with a on-device large language model", description = "Chat with a on-device large language model",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
)
val TASK_LLM_SINGLE_TURN = Task(
type = TaskType.LLM_SINGLE_TURN,
icon = Icons.Outlined.Widgets,
models = MODELS_LLM,
description = "Single turn use cases with on-device large language model",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt", sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
@ -108,9 +121,10 @@ val TASK_IMAGE_GENERATION = Task(
/** All tasks. */ /** All tasks. */
val TASKS: List<Task> = listOf( val TASKS: List<Task> = listOf(
// TASK_TEXT_CLASSIFICATION, // TASK_TEXT_CLASSIFICATION,
// TASK_IMAGE_CLASSIFICATION, TASK_IMAGE_CLASSIFICATION,
TASK_IMAGE_GENERATION, TASK_IMAGE_GENERATION,
TASK_LLM_CHAT, TASK_LLM_CHAT,
TASK_LLM_SINGLE_TURN,
) )
fun getModelByName(name: String): Model? { fun getModelByName(name: String): Model? {

View file

@ -25,6 +25,7 @@ import com.google.aiedge.gallery.GalleryApplication
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationViewModel import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationViewModel
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationViewModel import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationViewModel
import com.google.aiedge.gallery.ui.llmchat.LlmChatViewModel import com.google.aiedge.gallery.ui.llmchat.LlmChatViewModel
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.textclassification.TextClassificationViewModel import com.google.aiedge.gallery.ui.textclassification.TextClassificationViewModel
@ -56,6 +57,11 @@ object ViewModelProvider {
LlmChatViewModel() LlmChatViewModel()
} }
// Initializer for LlmSingleTurnViewModel..
initializer {
LlmSingleTurnViewModel()
}
initializer { initializer {
ImageGenerationViewModel() ImageGenerationViewModel()
} }

View file

@ -0,0 +1,213 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.ArrowBack
import androidx.compose.material.icons.rounded.ArrowDropDown
import androidx.compose.material.icons.rounded.Settings
import androidx.compose.material3.CenterAlignedTopAppBar
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ModalBottomSheet
import androidx.compose.material3.Text
import androidx.compose.material3.rememberModalBottomSheetState
import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.vectorResource
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.chat.ConfigDialog
import com.google.aiedge.gallery.ui.common.modelitem.StatusIcon
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun ModelPageAppBar(
task: Task,
model: Model,
modelManagerViewModel: ModelManagerViewModel,
onBackClicked: () -> Unit,
onModelSelected: (Model) -> Unit,
modifier: Modifier = Modifier,
onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit = { _, _ -> },
) {
var showConfigDialog by remember { mutableStateOf(false) }
var showModelPicker by remember { mutableStateOf(false) }
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
val context = LocalContext.current
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[model.name]
CenterAlignedTopAppBar(title = {
Column(horizontalAlignment = Alignment.CenterHorizontally) {
// Task type.
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp)
) {
Icon(
task.icon ?: ImageVector.vectorResource(task.iconVectorResourceId!!),
tint = getTaskIconColor(task = task),
modifier = Modifier.size(16.dp),
contentDescription = "",
)
Text(
task.type.label,
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.SemiBold),
color = getTaskIconColor(task = task)
)
}
// Model name.
Row(verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(2.dp),
modifier = Modifier
.clip(CircleShape)
.background(MaterialTheme.colorScheme.surfaceContainerHigh)
.clickable {
showModelPicker = true
}
.padding(start = 8.dp, end = 2.dp)) {
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
Text(
model.name,
style = MaterialTheme.typography.labelSmall,
modifier = Modifier.padding(start = 4.dp),
)
Icon(
Icons.Rounded.ArrowDropDown,
modifier = Modifier.size(20.dp),
contentDescription = "",
)
}
}
}, modifier = modifier,
// The back button.
navigationIcon = {
IconButton(onClick = onBackClicked) {
Icon(
imageVector = Icons.AutoMirrored.Rounded.ArrowBack,
contentDescription = "",
)
}
},
// The config button for the model (if existed).
actions = {
if (model.configs.isNotEmpty() && curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
IconButton(onClick = { showConfigDialog = true }) {
Icon(
imageVector = Icons.Rounded.Settings,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary
)
}
}
})
// Config dialog.
if (showConfigDialog) {
ConfigDialog(
title = "Model configs",
configs = model.configs,
initialValues = model.configValues,
onDismissed = { showConfigDialog = false },
onOk = { curConfigValues ->
// Hide config dialog.
showConfigDialog = false
// Check if the configs are changed or not. Also check if the model needs to be
// re-initialized.
var same = true
var needReinitialization = false
for (config in model.configs) {
val key = config.key.label
val oldValue = convertValueToTargetType(
value = model.configValues.getValue(key), valueType = config.valueType
)
val newValue = convertValueToTargetType(
value = curConfigValues.getValue(key), valueType = config.valueType
)
if (oldValue != newValue) {
same = false
if (config.needReinitialization) {
needReinitialization = true
}
break
}
}
if (same) {
return@ConfigDialog
}
// Save the config values to Model.
val oldConfigValues = model.configValues
model.configValues = curConfigValues
// Force to re-initialize the model with the new configs.
if (needReinitialization) {
modelManagerViewModel.initializeModel(
context = context, task = task, model = model, force = true
)
}
// Notify.
onConfigChanged(oldConfigValues, model.configValues)
},
)
}
// Model picker.
if (showModelPicker) {
ModalBottomSheet(
onDismissRequest = { showModelPicker = false },
sheetState = sheetState,
) {
ModelPicker(
task = task,
modelManagerViewModel = modelManagerViewModel,
onModelSelected = { model ->
showModelPicker = false
onModelSelected(model)
}
)
}
}
}

View file

@ -0,0 +1,132 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.width
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.outlined.CheckCircle
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.vectorResource
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.modelitem.StatusIcon
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.labelSmallNarrow
@Composable
fun ModelPicker(
task: Task,
modelManagerViewModel: ModelManagerViewModel,
onModelSelected: (Model) -> Unit
) {
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
Column(modifier = Modifier.padding(bottom = 8.dp)) {
// Title
Row(
modifier = Modifier
.padding(horizontal = 16.dp)
.padding(top = 4.dp, bottom = 4.dp),
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(8.dp),
) {
Icon(
task.icon ?: ImageVector.vectorResource(task.iconVectorResourceId!!),
tint = getTaskIconColor(task = task),
modifier = Modifier.size(16.dp),
contentDescription = "",
)
Text(
"${task.type.label} models",
modifier = Modifier.fillMaxWidth(),
style = MaterialTheme.typography.titleMedium,
color = getTaskIconColor(task = task),
)
}
// Model list.
for (model in task.models) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.SpaceBetween,
modifier = Modifier
.fillMaxWidth()
.clickable {
onModelSelected(model)
}
.padding(horizontal = 16.dp, vertical = 8.dp),
) {
Spacer(modifier = Modifier.width(24.dp))
Column(modifier = Modifier.weight(1f)) {
Text(model.name, style = MaterialTheme.typography.bodyMedium)
Row(horizontalArrangement = Arrangement.spacedBy(4.dp)) {
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
Text(
model.sizeInBytes.humanReadableSize(),
color = MaterialTheme.colorScheme.secondary,
style = labelSmallNarrow.copy(lineHeight = 10.sp)
)
}
}
if (model.name == modelManagerUiState.selectedModel.name) {
Icon(
Icons.Outlined.CheckCircle,
modifier = Modifier.size(16.dp),
contentDescription = ""
)
}
}
}
}
}
@Preview(showBackground = true)
@Composable
fun ModelPickerPreview() {
val context = LocalContext.current
GalleryTheme {
ModelPicker(
task = TASK_TEST1,
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
onModelSelected = {},
)
}
}

View file

@ -29,6 +29,7 @@ import androidx.core.content.ContextCompat
import androidx.core.content.FileProvider import androidx.core.content.FileProvider
import com.google.aiedge.gallery.data.Config import com.google.aiedge.gallery.data.Config
import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASKS
import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.ValueType import com.google.aiedge.gallery.data.ValueType
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkResult import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkResult
@ -57,6 +58,9 @@ interface LatencyProvider {
val latencyMs: Float val latencyMs: Float
} }
private const val START_THINKING = "***Thinking...***"
private const val DONE_THINKING = "***Done thinking***"
/** Format the bytes into a human-readable format. */ /** Format the bytes into a human-readable format. */
fun Long.humanReadableSize(si: Boolean = true, extraDecimalForGbAndAbove: Boolean = false): String { fun Long.humanReadableSize(si: Boolean = true, extraDecimalForGbAndAbove: Boolean = false): String {
val bytes = this val bytes = this
@ -452,3 +456,33 @@ fun cleanUpMediapipeTaskErrorMessage(message: String): String {
} }
return message return message
} }
fun processTasks() {
for ((index, task) in TASKS.withIndex()) {
task.index = index
for (model in task.models) {
model.preProcess()
}
}
}
fun processLlmResponse(response: String): String {
// Add "thinking" and "done thinking" around the thinking content.
var newContent = response
.replace("<think>", "$START_THINKING\n")
.replace("</think>", "\n$DONE_THINKING")
// Remove empty thinking content.
val endThinkingIndex = newContent.indexOf(DONE_THINKING)
if (endThinkingIndex >= 0) {
val thinkingContent =
newContent.substring(0, endThinkingIndex + DONE_THINKING.length)
.replace(START_THINKING, "")
.replace(DONE_THINKING, "")
if (thinkingContent.isBlank()) {
newContent = newContent.substring(endThinkingIndex + DONE_THINKING.length)
}
}
return newContent
}

View file

@ -38,6 +38,7 @@ import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.width import androidx.compose.foundation.layout.width
import androidx.compose.foundation.layout.wrapContentHeight import androidx.compose.foundation.layout.wrapContentHeight
import androidx.compose.foundation.layout.wrapContentWidth
import androidx.compose.foundation.lazy.LazyColumn import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items import androidx.compose.foundation.lazy.items
import androidx.compose.foundation.lazy.rememberLazyListState import androidx.compose.foundation.lazy.rememberLazyListState
@ -334,7 +335,10 @@ fun ChatPanel(
is ChatMessageBenchmarkResult -> MessageBodyBenchmark(message = message) is ChatMessageBenchmarkResult -> MessageBodyBenchmark(message = message)
// Benchmark LLM result. // Benchmark LLM result.
is ChatMessageBenchmarkLlmResult -> MessageBodyBenchmarkLlm(message = message) is ChatMessageBenchmarkLlmResult -> MessageBodyBenchmarkLlm(
message = message,
modifier = Modifier.wrapContentWidth()
)
else -> {} else -> {}
} }
@ -346,7 +350,7 @@ fun ChatPanel(
) { ) {
LatencyText(message = message) LatencyText(message = message)
// A button to show stats for the LLM message. // A button to show stats for the LLM message.
if (selectedModel.taskType == TaskType.LLM_CHAT && message is ChatMessageText if (task.type == TaskType.LLM_CHAT && message is ChatMessageText
// This means we only want to show the action button when the message is done // This means we only want to show the action button when the message is done
// generating, at which point the latency will be set. // generating, at which point the latency will be set.
&& message.latencyMs >= 0 && message.latencyMs >= 0
@ -403,21 +407,17 @@ fun ChatPanel(
} }
// Benchmark button // Benchmark button
// if (selectedModel.showBenchmarkButton) { if (selectedModel.showBenchmarkButton) {
// MessageActionButton( MessageActionButton(
// label = stringResource(R.string.benchmark), label = stringResource(R.string.benchmark),
// icon = Icons.Outlined.Timer, icon = Icons.Outlined.Timer,
// onClick = { onClick = {
// if (selectedModel.taskType == TaskType.LLM_CHAT) { showBenchmarkConfigsDialog = true
// onBenchmarkClicked(selectedModel, message, 0, 0) benchmarkMessage.value = message
// } else { },
// showBenchmarkConfigsDialog = true enabled = !uiState.inProgress
// benchmarkMessage.value = message )
// } }
// },
// enabled = !uiState.inProgress
// )
// }
} }
} }
} }
@ -443,7 +443,7 @@ fun ChatPanel(
// Chat input // Chat input
when (chatInputType) { when (chatInputType) {
ChatInputType.TEXT -> { ChatInputType.TEXT -> {
val isLlmTask = selectedModel.taskType == TaskType.LLM_CHAT val isLlmTask = task.type == TaskType.LLM_CHAT
val notLlmStartScreen = !(messages.size == 1 && messages[0] is ChatMessagePromptTemplates) val notLlmStartScreen = !(messages.size == 1 && messages[0] is ChatMessagePromptTemplates)
MessageInputText( MessageInputText(
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,

View file

@ -18,18 +18,10 @@ package com.google.aiedge.gallery.ui.common.chat
import android.util.Log import android.util.Log
import androidx.activity.compose.BackHandler import androidx.activity.compose.BackHandler
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.animation.AnimatedVisibility
import androidx.compose.animation.fadeIn
import androidx.compose.animation.fadeOut
import androidx.compose.animation.scaleIn
import androidx.compose.animation.scaleOut
import androidx.compose.foundation.background import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.pager.HorizontalPager import androidx.compose.foundation.pager.HorizontalPager
import androidx.compose.foundation.pager.rememberPagerState import androidx.compose.foundation.pager.rememberPagerState
@ -40,29 +32,21 @@ import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.graphicsLayer import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.tooling.preview.Preview
import com.google.aiedge.gallery.GalleryTopAppBar
import com.google.aiedge.gallery.data.AppBarAction
import com.google.aiedge.gallery.data.AppBarActionType
import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatusType import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.checkNotificationPermissionAndStartDownload import com.google.aiedge.gallery.ui.common.ModelPageAppBar
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.PreviewChatModel import com.google.aiedge.gallery.ui.preview.PreviewChatModel
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.TASK_TEST1 import com.google.aiedge.gallery.ui.preview.TASK_TEST1
import com.google.aiedge.gallery.ui.theme.GalleryTheme import com.google.aiedge.gallery.ui.theme.GalleryTheme
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlin.math.absoluteValue import kotlin.math.absoluteValue
@ -77,7 +61,6 @@ private const val TAG = "AGChatView"
* manages model initialization, cleanup, and download status, and handles navigation and system * manages model initialization, cleanup, and download status, and handles navigation and system
* back gestures. * back gestures.
*/ */
@OptIn(ExperimentalMaterial3Api::class)
@Composable @Composable
fun ChatView( fun ChatView(
task: Task, task: Task,
@ -96,34 +79,29 @@ fun ChatView(
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val selectedModel = modelManagerUiState.selectedModel val selectedModel = modelManagerUiState.selectedModel
val pagerState = rememberPagerState(initialPage = task.models.indexOf(selectedModel), val pagerState = rememberPagerState(
initialPage = task.models.indexOf(selectedModel),
pageCount = { task.models.size }) pageCount = { task.models.size })
val context = LocalContext.current val context = LocalContext.current
val scope = rememberCoroutineScope() val scope = rememberCoroutineScope()
val launcher = rememberLauncherForActivityResult(
ActivityResultContracts.RequestPermission()
) {
modelManagerViewModel.downloadModel(task = task, model = selectedModel)
}
val handleNavigateUp = { val handleNavigateUp = {
navigateUp() navigateUp()
// clean up all models. // clean up all models.
scope.launch(Dispatchers.Default) { scope.launch(Dispatchers.Default) {
for (model in task.models) { for (model in task.models) {
modelManagerViewModel.cleanupModel(model = model) modelManagerViewModel.cleanupModel(task = task, model = model)
} }
} }
} }
// Initialize model when model/download state changes. // Initialize model when model/download state changes.
val status = modelManagerUiState.modelDownloadStatus[selectedModel.name] val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name]
LaunchedEffect(status, selectedModel.name) { LaunchedEffect(curDownloadStatus, selectedModel.name) {
if (status?.status == ModelDownloadStatusType.SUCCEEDED) { if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
Log.d(TAG, "Initializing model '${selectedModel.name}' from ChatView launched effect") Log.d(TAG, "Initializing model '${selectedModel.name}' from ChatView launched effect")
modelManagerViewModel.initializeModel(context, model = selectedModel) modelManagerViewModel.initializeModel(context, task = task, model = selectedModel)
} }
} }
@ -135,7 +113,7 @@ fun ChatView(
"Pager settled on model '${curSelectedModel.name}' from '${selectedModel.name}'. Updating selected model." "Pager settled on model '${curSelectedModel.name}' from '${selectedModel.name}'. Updating selected model."
) )
if (curSelectedModel.name != selectedModel.name) { if (curSelectedModel.name != selectedModel.name) {
modelManagerViewModel.cleanupModel(model = selectedModel) modelManagerViewModel.cleanupModel(task = task, model = selectedModel)
} }
modelManagerViewModel.selectModel(curSelectedModel) modelManagerViewModel.selectModel(curSelectedModel)
} }
@ -146,24 +124,36 @@ fun ChatView(
} }
Scaffold(modifier = modifier, topBar = { Scaffold(modifier = modifier, topBar = {
GalleryTopAppBar( ModelPageAppBar(
title = task.type.label, task = task,
leftAction = AppBarAction(actionType = AppBarActionType.NAVIGATE_UP, actionFn = { model = selectedModel,
modelManagerViewModel = modelManagerViewModel,
onConfigChanged = { old, new ->
viewModel.addConfigChangedMessage(
oldConfigValues = old,
newConfigValues = new,
model = selectedModel
)
},
onBackClicked = {
handleNavigateUp() handleNavigateUp()
}), },
rightAction = AppBarAction(actionType = AppBarActionType.NO_ACTION, actionFn = {}), onModelSelected = { model ->
scope.launch {
pagerState.animateScrollToPage(task.models.indexOf(model))
}
},
) )
}) { innerPadding -> }) { innerPadding ->
Box { Box {
// A horizontal scrollable pager to switch between models. // A horizontal scrollable pager to switch between models.
HorizontalPager(state = pagerState) { pageIndex -> HorizontalPager(state = pagerState) { pageIndex ->
val curSelectedModel = task.models[pageIndex] val curSelectedModel = task.models[pageIndex]
val curModelDownloadStatus = modelManagerUiState.modelDownloadStatus[curSelectedModel.name]
// Calculate the alpha of the current page based on how far they are from the center. // Calculate the alpha of the current page based on how far they are from the center.
val pageOffset = ( val pageOffset =
(pagerState.currentPage - pageIndex) + pagerState ((pagerState.currentPage - pageIndex) + pagerState.currentPageOffsetFraction).absoluteValue
.currentPageOffsetFraction
).absoluteValue
val curAlpha = 1f - pageOffset.coerceIn(0f, 1f) val curAlpha = 1f - pageOffset.coerceIn(0f, 1f)
Column( Column(
@ -172,91 +162,14 @@ fun ChatView(
.fillMaxSize() .fillMaxSize()
.background(MaterialTheme.colorScheme.surface) .background(MaterialTheme.colorScheme.surface)
) { ) {
// Model selector at the top. ModelDownloadStatusInfoPanel(
ModelSelector(
model = curSelectedModel, model = curSelectedModel,
task = task, task = task,
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel
onConfigChanged = { old, new ->
viewModel.addConfigChangedMessage(
oldConfigValues = old,
newConfigValues = new,
model = curSelectedModel
)
},
modifier = Modifier.fillMaxWidth(),
contentAlpha = curAlpha,
) )
// Manages the conditional display of UI elements (download model button and downloading
// animation) based on the corresponding download status.
//
// It uses delayed visibility ensuring they are shown only after a short delay if their
// respective conditions remain true. This prevents UI flickering and provides a smoother
// user experience.
val curStatus = modelManagerUiState.modelDownloadStatus[curSelectedModel.name]
var shouldShowDownloadingAnimation by remember { mutableStateOf(false) }
var downloadingAnimationConditionMet by remember { mutableStateOf(false) }
var shouldShowDownloadModelButton by remember { mutableStateOf(false) }
var downloadModelButtonConditionMet by remember { mutableStateOf(false) }
downloadingAnimationConditionMet =
curStatus?.status == ModelDownloadStatusType.IN_PROGRESS ||
curStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED ||
curStatus?.status == ModelDownloadStatusType.UNZIPPING
downloadModelButtonConditionMet =
curStatus?.status == ModelDownloadStatusType.FAILED ||
curStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED
LaunchedEffect(downloadingAnimationConditionMet) {
if (downloadingAnimationConditionMet) {
delay(100)
shouldShowDownloadingAnimation = true
} else {
shouldShowDownloadingAnimation = false
}
}
LaunchedEffect(downloadModelButtonConditionMet) {
if (downloadModelButtonConditionMet) {
delay(700)
shouldShowDownloadModelButton = true
} else {
shouldShowDownloadModelButton = false
}
}
AnimatedVisibility(
visible = shouldShowDownloadingAnimation,
enter = scaleIn(initialScale = 0.9f) + fadeIn(),
exit = scaleOut(targetScale = 0.9f) + fadeOut()
) {
Box(
modifier = Modifier.fillMaxSize(),
contentAlignment = Alignment.Center
) {
ModelDownloadingAnimation()
}
}
AnimatedVisibility(
visible = shouldShowDownloadModelButton,
enter = fadeIn(),
exit = fadeOut()
) {
ModelNotDownloaded(modifier = Modifier.weight(1f), onClicked = {
checkNotificationPermissionAndStartDownload(
context = context,
launcher = launcher,
modelManagerViewModel = modelManagerViewModel,
task = task,
model = curSelectedModel
)
})
}
// The main messages panel. // The main messages panel.
if (curStatus?.status == ModelDownloadStatusType.SUCCEEDED) { if (curModelDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
ChatPanel( ChatPanel(
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,
task = task, task = task,

View file

@ -20,13 +20,12 @@ import android.util.Log
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.processLlmResponse
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.update import kotlinx.coroutines.flow.update
private const val TAG = "AGChatViewModel" private const val TAG = "AGChatViewModel"
private const val START_THINKING = "***Thinking...***"
private const val DONE_THINKING = "***Done thinking***"
data class ChatUiState( data class ChatUiState(
/** /**
@ -121,26 +120,7 @@ open class ChatViewModel(val task: Task) : ViewModel() {
if (newMessages.size > 0) { if (newMessages.size > 0) {
val lastMessage = newMessages.last() val lastMessage = newMessages.last()
if (lastMessage is ChatMessageText) { if (lastMessage is ChatMessageText) {
var newContent = "${lastMessage.content}${partialContent}" val newContent = processLlmResponse(response = "${lastMessage.content}${partialContent}")
// TODO: special handling for deepseek to remove the <think> tag.
// Add "thinking" and "done thinking" around the thinking content.
newContent = newContent
.replace("<think>", "$START_THINKING\n")
.replace("</think>", "\n$DONE_THINKING")
// Remove empty thinking content.
val endThinkingIndex = newContent.indexOf(DONE_THINKING)
if (endThinkingIndex >= 0) {
val thinkingContent =
newContent.substring(0, endThinkingIndex + DONE_THINKING.length)
.replace(START_THINKING, "")
.replace(DONE_THINKING, "")
if (thinkingContent.isBlank()) {
newContent = newContent.substring(endThinkingIndex + DONE_THINKING.length)
}
}
val newLastMessage = ChatMessageText( val newLastMessage = ChatMessageText(
content = newContent, content = newContent,
side = lastMessage.side, side = lastMessage.side,

View file

@ -45,7 +45,7 @@ fun MarkdownText(
ProvideTextStyle( ProvideTextStyle(
value = TextStyle( value = TextStyle(
fontSize = fontSize, fontSize = fontSize,
lineHeight = fontSize * 1.2, lineHeight = fontSize * 1.4,
) )
) { ) {
RichText( RichText(

View file

@ -48,15 +48,16 @@ fun MessageActionButton(
label: String, label: String,
icon: ImageVector, icon: ImageVector,
onClick: () -> Unit, onClick: () -> Unit,
modifier: Modifier = Modifier,
enabled: Boolean = true enabled: Boolean = true
) { ) {
val modifier = Modifier val curModifier = modifier
.padding(top = 4.dp) .padding(top = 4.dp)
.clip(CircleShape) .clip(CircleShape)
.background(if (enabled) MaterialTheme.colorScheme.secondaryContainer else MaterialTheme.colorScheme.surfaceContainerHigh) .background(if (enabled) MaterialTheme.colorScheme.secondaryContainer else MaterialTheme.colorScheme.surfaceContainerHigh)
val alpha: Float = if (enabled) 1.0f else 0.3f val alpha: Float = if (enabled) 1.0f else 0.3f
Row( Row(
modifier = if (enabled) modifier.clickable { onClick() } else modifier, modifier = if (enabled) curModifier.clickable { onClick() } else modifier,
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
) { ) {
Icon( Icon(

View file

@ -19,8 +19,8 @@ package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.wrapContentWidth
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.tooling.preview.Preview
@ -33,16 +33,14 @@ import com.google.aiedge.gallery.ui.theme.GalleryTheme
* This function renders benchmark statistics (e.g., various token speed) in data cards * This function renders benchmark statistics (e.g., various token speed) in data cards
*/ */
@Composable @Composable
fun MessageBodyBenchmarkLlm(message: ChatMessageBenchmarkLlmResult) { fun MessageBodyBenchmarkLlm(message: ChatMessageBenchmarkLlmResult, modifier: Modifier = Modifier) {
Column( Column(
modifier = Modifier modifier = modifier.padding(12.dp),
.padding(12.dp)
.wrapContentWidth(),
verticalArrangement = Arrangement.spacedBy(8.dp) verticalArrangement = Arrangement.spacedBy(8.dp)
) { ) {
// Data cards. // Data cards.
Row( Row(
modifier = Modifier.wrapContentWidth(), horizontalArrangement = Arrangement.spacedBy(16.dp) modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.SpaceBetween
) { ) {
for (stat in message.orderedStats) { for (stat in message.orderedStats) {
DataCard( DataCard(

View file

@ -82,7 +82,7 @@ fun MessageBodyPromptTemplates(
style = MaterialTheme.typography.titleSmall, style = MaterialTheme.typography.titleSmall,
modifier = Modifier modifier = Modifier
.fillMaxWidth() .fillMaxWidth()
.offset(y = -4.dp), .offset(y = (-4).dp),
textAlign = TextAlign.Center, textAlign = TextAlign.Center,
) )
} }
@ -140,7 +140,7 @@ fun MessageBodyPromptTemplatesPreview() {
for ((index, task) in ALL_PREVIEW_TASKS.withIndex()) { for ((index, task) in ALL_PREVIEW_TASKS.withIndex()) {
task.index = index task.index = index
for (model in task.models) { for (model in task.models) {
model.preProcess(task = task) model.preProcess()
} }
} }

View file

@ -70,7 +70,6 @@ import com.google.aiedge.gallery.ui.theme.GalleryTheme
* This function renders a row containing a text field for message input and a send button. * This function renders a row containing a text field for message input and a send button.
* It handles message composition, input validation, and sending messages. * It handles message composition, input validation, and sending messages.
*/ */
@OptIn(ExperimentalMaterial3Api::class)
@Composable @Composable
fun MessageInputText( fun MessageInputText(
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
@ -190,7 +189,7 @@ fun MessageInputText(
Icons.AutoMirrored.Rounded.Send, Icons.AutoMirrored.Rounded.Send,
contentDescription = "", contentDescription = "",
modifier = Modifier.offset(x = 2.dp), modifier = Modifier.offset(x = 2.dp),
tint = if (inProgress) MaterialTheme.colorScheme.surfaceContainerHigh else MaterialTheme.colorScheme.primary tint = MaterialTheme.colorScheme.primary
) )
} }
} }

View file

@ -0,0 +1,119 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.animation.AnimatedVisibility
import androidx.compose.animation.fadeIn
import androidx.compose.animation.fadeOut
import androidx.compose.animation.scaleIn
import androidx.compose.animation.scaleOut
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.DownloadAndTryButton
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import kotlinx.coroutines.delay
@Composable
fun ModelDownloadStatusInfoPanel(
model: Model,
task: Task,
modelManagerViewModel: ModelManagerViewModel
) {
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
// Manages the conditional display of UI elements (download model button and downloading
// animation) based on the corresponding download status.
//
// It uses delayed visibility ensuring they are shown only after a short delay if their
// respective conditions remain true. This prevents UI flickering and provides a smoother
// user experience.
val curStatus = modelManagerUiState.modelDownloadStatus[model.name]
var shouldShowDownloadingAnimation by remember { mutableStateOf(false) }
var downloadingAnimationConditionMet by remember { mutableStateOf(false) }
var shouldShowDownloadModelButton by remember { mutableStateOf(false) }
var downloadModelButtonConditionMet by remember { mutableStateOf(false) }
downloadingAnimationConditionMet =
curStatus?.status == ModelDownloadStatusType.IN_PROGRESS || curStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED || curStatus?.status == ModelDownloadStatusType.UNZIPPING
downloadModelButtonConditionMet =
curStatus?.status == ModelDownloadStatusType.FAILED || curStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED
LaunchedEffect(downloadingAnimationConditionMet) {
if (downloadingAnimationConditionMet) {
delay(100)
shouldShowDownloadingAnimation = true
} else {
shouldShowDownloadingAnimation = false
}
}
LaunchedEffect(downloadModelButtonConditionMet) {
if (downloadModelButtonConditionMet) {
delay(700)
shouldShowDownloadModelButton = true
} else {
shouldShowDownloadModelButton = false
}
}
AnimatedVisibility(
visible = shouldShowDownloadingAnimation,
enter = scaleIn(initialScale = 0.9f) + fadeIn(),
exit = scaleOut(targetScale = 0.9f) + fadeOut()
) {
Box(
modifier = Modifier.fillMaxSize(), contentAlignment = Alignment.Center
) {
ModelDownloadingAnimation(
model = model, task = task, modelManagerViewModel = modelManagerViewModel
)
}
}
AnimatedVisibility(
visible = shouldShowDownloadModelButton, enter = fadeIn(), exit = fadeOut()
) {
Column(
modifier = Modifier.fillMaxSize(),
verticalArrangement = Arrangement.Center,
horizontalAlignment = Alignment.CenterHorizontally
) {
DownloadAndTryButton(
task = task,
model = model,
enabled = true,
needToDownloadFirst = true,
modelManagerViewModel = modelManagerViewModel,
onClicked = {}
)
}
}
}

View file

@ -24,6 +24,7 @@ import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.offset import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
@ -32,23 +33,40 @@ import androidx.compose.foundation.layout.width
import androidx.compose.foundation.lazy.grid.GridCells import androidx.compose.foundation.lazy.grid.GridCells
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
import androidx.compose.foundation.lazy.grid.itemsIndexed import androidx.compose.foundation.lazy.grid.itemsIndexed
import androidx.compose.material3.LinearProgressIndicator
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text import androidx.compose.material3.Text
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.getValue
import androidx.compose.runtime.remember import androidx.compose.runtime.remember
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.ColorFilter import androidx.compose.ui.graphics.ColorFilter
import androidx.compose.ui.graphics.graphicsLayer import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.layout.ContentScale import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.painterResource import androidx.compose.ui.res.painterResource
import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import com.google.aiedge.gallery.R import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.formatToHourMinSecond
import com.google.aiedge.gallery.ui.common.getTaskIconColor import com.google.aiedge.gallery.ui.common.getTaskIconColor
import com.google.aiedge.gallery.ui.common.humanReadableSize
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.MODEL_TEST1
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
import com.google.aiedge.gallery.ui.theme.GalleryTheme import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.labelSmallNarrow
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlin.math.cos import kotlin.math.cos
import kotlin.math.pow import kotlin.math.pow
@ -66,8 +84,19 @@ private const val END_SCALE = 0.6f
* scaling and rotation effect. * scaling and rotation effect.
*/ */
@Composable @Composable
fun ModelDownloadingAnimation() { fun ModelDownloadingAnimation(
model: Model,
task: Task,
modelManagerViewModel: ModelManagerViewModel
) {
val scale = remember { Animatable(END_SCALE) } val scale = remember { Animatable(END_SCALE) }
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val downloadStatus by remember {
derivedStateOf { modelManagerUiState.modelDownloadStatus[model.name] }
}
val inProgress = downloadStatus?.status == ModelDownloadStatusType.IN_PROGRESS
val isPartiallyDownloaded = downloadStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED
var curDownloadProgress = 0f
LaunchedEffect(Unit) { // Run this once LaunchedEffect(Unit) { // Run this once
while (true) { while (true) {
@ -93,67 +122,156 @@ fun ModelDownloadingAnimation() {
} }
} }
Column(
horizontalAlignment = Alignment.CenterHorizontally,
modifier = Modifier.offset(y = -GRID_SIZE / 8)
) {
LazyVerticalGrid(
columns = GridCells.Fixed(2),
horizontalArrangement = Arrangement.spacedBy(GRID_SPACING),
verticalArrangement = Arrangement.spacedBy(GRID_SPACING),
modifier = Modifier
.width(GRID_SIZE)
.height(GRID_SIZE)
) {
itemsIndexed(
listOf(
R.drawable.pantegon,
R.drawable.double_circle,
R.drawable.circle,
R.drawable.four_circle
)
) { index, imageResource ->
val currentScale =
if (index == 0 || index == 3) scale.value else START_SCALE + END_SCALE - scale.value
Box( // Failure message.
modifier = Modifier val curDownloadStatus = downloadStatus
.width((GRID_SIZE - GRID_SPACING) / 2) if (curDownloadStatus != null && curDownloadStatus.status == ModelDownloadStatusType.FAILED) {
.height((GRID_SIZE - GRID_SPACING) / 2), Row(verticalAlignment = Alignment.CenterVertically) {
contentAlignment = when (index) { Text(
0 -> Alignment.BottomEnd curDownloadStatus.errorMessage,
1 -> Alignment.BottomStart color = MaterialTheme.colorScheme.error,
2 -> Alignment.TopEnd style = labelSmallNarrow,
3 -> Alignment.TopStart overflow = TextOverflow.Ellipsis,
else -> Alignment.Center )
} }
) { }
Image( // No failure
painter = painterResource(id = imageResource), else {
contentDescription = "", Column(
contentScale = ContentScale.Fit, horizontalAlignment = Alignment.CenterHorizontally,
colorFilter = ColorFilter.tint(getTaskIconColor(index = index)), modifier = Modifier.offset(y = -GRID_SIZE / 8)
modifier = Modifier ) {
.graphicsLayer { LazyVerticalGrid(
scaleX = currentScale columns = GridCells.Fixed(2),
scaleY = currentScale horizontalArrangement = Arrangement.spacedBy(GRID_SPACING),
rotationZ = currentScale * 120 verticalArrangement = Arrangement.spacedBy(GRID_SPACING),
alpha = 0.8f modifier = Modifier
} .width(GRID_SIZE)
.size(70.dp) .height(GRID_SIZE)
) {
itemsIndexed(
listOf(
R.drawable.pantegon,
R.drawable.double_circle,
R.drawable.circle,
R.drawable.four_circle
) )
) { index, imageResource ->
val currentScale =
if (index == 0 || index == 3) scale.value else START_SCALE + END_SCALE - scale.value
Box(
modifier = Modifier
.width((GRID_SIZE - GRID_SPACING) / 2)
.height((GRID_SIZE - GRID_SPACING) / 2),
contentAlignment = when (index) {
0 -> Alignment.BottomEnd
1 -> Alignment.BottomStart
2 -> Alignment.TopEnd
3 -> Alignment.TopStart
else -> Alignment.Center
}
) {
Image(
painter = painterResource(id = imageResource),
contentDescription = "",
contentScale = ContentScale.Fit,
colorFilter = ColorFilter.tint(getTaskIconColor(index = index)),
modifier = Modifier
.graphicsLayer {
scaleX = currentScale
scaleY = currentScale
rotationZ = currentScale * 120
alpha = 0.8f
}
.size(70.dp)
)
}
} }
} }
}
Text(
"Feel free to switch apps or lock your device.\n" // Download stats
+ "The download will continue in the background.\n" var sizeLabel = model.totalBytes.humanReadableSize()
+ "We'll send a notification when it's done.", if (curDownloadStatus != null) {
style = MaterialTheme.typography.bodyMedium, // For in-progress model, show {receivedSize} / {totalSize} - {rate} - {remainingTime}
textAlign = TextAlign.Center if (inProgress || isPartiallyDownloaded) {
) var totalSize = curDownloadStatus.totalBytes
if (totalSize == 0L) {
totalSize = model.totalBytes
}
sizeLabel =
"${curDownloadStatus.receivedBytes.humanReadableSize(extraDecimalForGbAndAbove = true)} of ${totalSize.humanReadableSize()}"
if (curDownloadStatus.bytesPerSecond > 0) {
sizeLabel =
"$sizeLabel · ${curDownloadStatus.bytesPerSecond.humanReadableSize()} / s"
if (curDownloadStatus.remainingMs >= 0) {
sizeLabel =
"$sizeLabel · ${curDownloadStatus.remainingMs.formatToHourMinSecond()} left"
}
}
if (isPartiallyDownloaded) {
sizeLabel = "$sizeLabel (resuming...)"
}
curDownloadProgress =
curDownloadStatus.receivedBytes.toFloat() / curDownloadStatus.totalBytes.toFloat()
if (curDownloadProgress.isNaN()) {
curDownloadProgress = 0f
}
}
// Status for unzipping.
else if (curDownloadStatus.status == ModelDownloadStatusType.UNZIPPING) {
sizeLabel = "Unzipping..."
}
Text(
sizeLabel,
color = MaterialTheme.colorScheme.secondary,
style = labelSmallNarrow.copy(fontSize = 9.sp, lineHeight = 10.sp),
textAlign = TextAlign.Center,
overflow = TextOverflow.Visible,
modifier = Modifier
.padding(bottom = 4.dp)
)
}
// Download progress.
if (inProgress || isPartiallyDownloaded) {
val animatedProgress = remember { Animatable(0f) }
LinearProgressIndicator(
progress = { animatedProgress.value },
color = getTaskIconColor(task = task),
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
modifier = Modifier
.fillMaxWidth()
.padding(bottom = 36.dp)
.padding(horizontal = 36.dp)
)
LaunchedEffect(curDownloadProgress) {
animatedProgress.animateTo(curDownloadProgress, animationSpec = tween(150))
}
}
// Unzipping progress.
else if (downloadStatus?.status == ModelDownloadStatusType.UNZIPPING) {
LinearProgressIndicator(
color = getTaskIconColor(task = task),
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
modifier = Modifier
.fillMaxWidth()
.padding(bottom = 36.dp)
.padding(horizontal = 36.dp)
)
}
Text(
"Feel free to switch apps or lock your device.\n"
+ "The download will continue in the background.\n"
+ "We'll send a notification when it's done.",
style = MaterialTheme.typography.bodyMedium,
textAlign = TextAlign.Center
)
}
} }
} }
// Custom Easing function for a multi-bounce effect // Custom Easing function for a multi-bounce effect
@ -168,9 +286,15 @@ fun multiBounceEasing(bounces: Int, decay: Float): Easing = Easing { x ->
@Preview(showBackground = true) @Preview(showBackground = true)
@Composable @Composable
fun ModelDownloadingAnimationPreview() { fun ModelDownloadingAnimationPreview() {
val context = LocalContext.current
GalleryTheme { GalleryTheme {
Row(modifier = Modifier.padding(16.dp)) { Row(modifier = Modifier.padding(16.dp)) {
ModelDownloadingAnimation() ModelDownloadingAnimation(
model = MODEL_TEST1,
task = TASK_TEST1,
modelManagerViewModel = PreviewModelManagerViewModel(context = context)
)
} }
} }
} }

View file

@ -63,7 +63,8 @@ fun ModelSelector(
) { ) {
Box( Box(
modifier = Modifier modifier = Modifier
.fillMaxWidth().padding(bottom = 8.dp), .fillMaxWidth()
.padding(bottom = 8.dp),
contentAlignment = Alignment.Center contentAlignment = Alignment.Center
) { ) {
// Model row. // Model row.
@ -134,7 +135,12 @@ fun ModelSelector(
// Force to re-initialize the model with the new configs. // Force to re-initialize the model with the new configs.
if (needReinitialization) { if (needReinitialization) {
modelManagerViewModel.initializeModel(context = context, model = model, force = true) modelManagerViewModel.initializeModel(
context = context,
task = task,
model = model,
force = true
)
} }
// Notify. // Notify.

View file

@ -181,7 +181,5 @@ fun ModelNameAndStatus(
.padding(top = 2.dp), .padding(top = 2.dp),
) )
} }
} }
} }

View file

@ -23,8 +23,10 @@ import androidx.compose.foundation.layout.size
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.outlined.HelpOutline import androidx.compose.material.icons.automirrored.outlined.HelpOutline
import androidx.compose.material.icons.filled.DownloadForOffline import androidx.compose.material.icons.filled.DownloadForOffline
import androidx.compose.material.icons.rounded.Downloading
import androidx.compose.material.icons.rounded.Error import androidx.compose.material.icons.rounded.Error
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
@ -34,6 +36,7 @@ import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.ModelDownloadStatus import com.google.aiedge.gallery.data.ModelDownloadStatus
import com.google.aiedge.gallery.data.ModelDownloadStatusType import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.ui.theme.GalleryTheme import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.customColors
/** /**
* Composable function to display an icon representing the download status of a model. * Composable function to display an icon representing the download status of a model.
@ -56,7 +59,7 @@ fun StatusIcon(downloadStatus: ModelDownloadStatus?, modifier: Modifier = Modifi
ModelDownloadStatusType.SUCCEEDED -> { ModelDownloadStatusType.SUCCEEDED -> {
Icon( Icon(
Icons.Filled.DownloadForOffline, Icons.Filled.DownloadForOffline,
tint = Color(0xff3d860b), tint = MaterialTheme.customColors.successColor,
contentDescription = "", contentDescription = "",
modifier = Modifier.size(14.dp) modifier = Modifier.size(14.dp)
) )
@ -69,6 +72,12 @@ fun StatusIcon(downloadStatus: ModelDownloadStatus?, modifier: Modifier = Modifi
modifier = Modifier.size(14.dp) modifier = Modifier.size(14.dp)
) )
ModelDownloadStatusType.IN_PROGRESS -> Icon(
Icons.Rounded.Downloading,
contentDescription = "",
modifier = Modifier.size(14.dp)
)
else -> {} else -> {}
} }
} }

View file

@ -73,7 +73,6 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip import androidx.compose.ui.draw.clip
import androidx.compose.ui.draw.scale import androidx.compose.ui.draw.scale
import androidx.compose.ui.focus.focusModifier
import androidx.compose.ui.graphics.Brush import androidx.compose.ui.graphics.Brush
import androidx.compose.ui.input.nestedscroll.nestedScroll import androidx.compose.ui.input.nestedscroll.nestedScroll
import androidx.compose.ui.layout.layout import androidx.compose.ui.layout.layout
@ -91,7 +90,6 @@ import com.google.aiedge.gallery.data.AppBarAction
import com.google.aiedge.gallery.data.AppBarActionType import com.google.aiedge.gallery.data.AppBarActionType
import com.google.aiedge.gallery.data.ConfigKey import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.ImportedModelInfo import com.google.aiedge.gallery.data.ImportedModelInfo
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.TaskIcon import com.google.aiedge.gallery.ui.common.TaskIcon
import com.google.aiedge.gallery.ui.common.getTaskBgColor import com.google.aiedge.gallery.ui.common.getTaskBgColor
@ -275,7 +273,6 @@ fun HomeScreen(
onDismiss = { showImportingDialog = false }, onDismiss = { showImportingDialog = false },
onDone = { onDone = {
modelManagerViewModel.addImportedLlmModel( modelManagerViewModel.addImportedLlmModel(
task = TASK_LLM_CHAT,
info = it, info = it,
) )
showImportingDialog = false showImportingDialog = false

View file

@ -30,7 +30,7 @@ private const val TAG = "AGLlmChatModelHelper"
typealias ResultListener = (partialResult: String, done: Boolean) -> Unit typealias ResultListener = (partialResult: String, done: Boolean) -> Unit
typealias CleanUpListener = () -> Unit typealias CleanUpListener = () -> Unit
data class LlmModelInstance(val engine: LlmInference, val session: LlmInferenceSession) data class LlmModelInstance(val engine: LlmInference, var session: LlmInferenceSession)
object LlmChatModelHelper { object LlmChatModelHelper {
// Indexed by model name. // Indexed by model name.
@ -74,6 +74,24 @@ object LlmChatModelHelper {
onDone("") onDone("")
} }
fun resetSession(model: Model) {
val instance = model.instance as LlmModelInstance? ?: return
val session = instance.session
session.close()
val inference = instance.engine
val topK = model.getIntConfigValue(key = ConfigKey.TOPK, defaultValue = DEFAULT_TOPK)
val topP = model.getFloatConfigValue(key = ConfigKey.TOPP, defaultValue = DEFAULT_TOPP)
val temperature =
model.getFloatConfigValue(key = ConfigKey.TEMPERATURE, defaultValue = DEFAULT_TEMPERATURE)
val newSession = LlmInferenceSession.createFromOptions(
inference,
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
.setTemperature(temperature).build()
)
instance.session = newSession
}
fun cleanUp(model: Model) { fun cleanUp(model: Model) {
if (model.instance == null) { if (model.instance == null) {
return return
@ -99,7 +117,11 @@ object LlmChatModelHelper {
input: String, input: String,
resultListener: ResultListener, resultListener: ResultListener,
cleanUpListener: CleanUpListener, cleanUpListener: CleanUpListener,
singleTurn: Boolean = false,
) { ) {
if (singleTurn) {
resetSession(model = model)
}
val instance = model.instance as LlmModelInstance val instance = model.instance as LlmModelInstance
// Set listener. // Set listener.

View file

@ -32,7 +32,7 @@ import kotlinx.coroutines.launch
private const val TAG = "AGLlmChatViewModel" private const val TAG = "AGLlmChatViewModel"
private val STATS = listOf( private val STATS = listOf(
Stat(id = "time_to_first_token", label = "Time to 1st token", unit = "sec"), Stat(id = "time_to_first_token", label = "1st token", unit = "sec"),
Stat(id = "prefill_speed", label = "Prefill speed", unit = "tokens/s"), Stat(id = "prefill_speed", label = "Prefill speed", unit = "tokens/s"),
Stat(id = "decode_speed", label = "Decode speed", unit = "tokens/s"), Stat(id = "decode_speed", label = "Decode speed", unit = "tokens/s"),
Stat(id = "latency", label = "Latency", unit = "sec") Stat(id = "latency", label = "Latency", unit = "sec")

View file

@ -0,0 +1,206 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.llmsingleturn
import android.util.Log
import androidx.activity.compose.BackHandler
import androidx.compose.animation.AnimatedVisibility
import androidx.compose.animation.fadeIn
import androidx.compose.animation.fadeOut
import androidx.compose.animation.scaleIn
import androidx.compose.animation.scaleOut
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.calculateStartPadding
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Scaffold
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalLayoutDirection
import androidx.compose.ui.tooling.preview.Preview
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.ui.ViewModelProvider
import com.google.aiedge.gallery.ui.common.ModelPageAppBar
import com.google.aiedge.gallery.ui.common.chat.ModelDownloadStatusInfoPanel
import com.google.aiedge.gallery.ui.common.chat.ModelInitializationStatusChip
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.PreviewLlmSingleTurnViewModel
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.customColors
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.serialization.Serializable
/** Navigation destination data */
object LlmSingleTurnDestination {
@Serializable
val route = "LlmSingleTurnRoute"
}
private const val TAG = "AGLlmSingleTurnScreen"
@Composable
fun LlmSingleTurnScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
viewModel: LlmSingleTurnViewModel = viewModel(
factory = ViewModelProvider.Factory
),
) {
val task = viewModel.task
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val selectedModel = modelManagerUiState.selectedModel
val scope = rememberCoroutineScope()
val context = LocalContext.current
val handleNavigateUp = {
navigateUp()
// clean up all models.
scope.launch(Dispatchers.Default) {
for (model in task.models) {
modelManagerViewModel.cleanupModel(task = task, model = model)
}
}
}
// Handle system's edge swipe.
BackHandler {
handleNavigateUp()
}
// Initialize model when model/download state changes.
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name]
LaunchedEffect(curDownloadStatus, selectedModel.name) {
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
Log.d(
TAG,
"Initializing model '${selectedModel.name}' from LlmsingleTurnScreen launched effect"
)
modelManagerViewModel.initializeModel(context, task = task, model = selectedModel)
}
}
val modelInitializationStatus =
modelManagerUiState.modelInitializationStatus[selectedModel.name]
Scaffold(modifier = modifier, topBar = {
ModelPageAppBar(
task = task,
model = selectedModel,
modelManagerViewModel = modelManagerViewModel,
onConfigChanged = { _, _ -> },
onBackClicked = { handleNavigateUp() },
onModelSelected = { newSelectedModel ->
scope.launch(Dispatchers.Default) {
// Clean up current model.
modelManagerViewModel.cleanupModel(task = task, model = selectedModel)
// Update selected model.
modelManagerViewModel.selectModel(model = newSelectedModel)
}
}
)
}) { innerPadding ->
Column(
modifier = Modifier.padding(
top = innerPadding.calculateTopPadding(),
start = innerPadding.calculateStartPadding(LocalLayoutDirection.current),
end = innerPadding.calculateStartPadding(LocalLayoutDirection.current),
)
) {
ModelDownloadStatusInfoPanel(
model = selectedModel,
task = task,
modelManagerViewModel = modelManagerViewModel
)
// Main UI after model is downloaded.
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
Box(
contentAlignment = Alignment.BottomCenter,
modifier = Modifier.weight(1f)
) {
VerticalSplitView(modifier = Modifier.fillMaxSize(),
topView = {
PromptTemplatesPanel(
model = selectedModel,
viewModel = viewModel,
onSend = { fullPrompt ->
viewModel.generateResponse(model = selectedModel, input = fullPrompt)
}, modifier = Modifier.fillMaxSize()
)
},
bottomView = {
Box(
contentAlignment = Alignment.BottomCenter,
modifier = Modifier
.fillMaxSize()
.background(MaterialTheme.customColors.agentBubbleBgColor)
) {
ResponsePanel(
model = selectedModel,
viewModel = viewModel,
modifier = Modifier
.fillMaxSize()
.padding(bottom = innerPadding.calculateBottomPadding())
)
}
})
// Model initialization in-progress message.
this@Column.AnimatedVisibility(
visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
enter = scaleIn() + fadeIn(),
exit = scaleOut() + fadeOut(),
modifier = Modifier.offset(y = -innerPadding.calculateBottomPadding())
) {
ModelInitializationStatusChip()
}
}
}
}
}
}
@Preview(showBackground = true)
@Composable
fun LlmSingleTurnScreenPreview() {
val context = LocalContext.current
GalleryTheme {
LlmSingleTurnScreen(
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
viewModel = PreviewLlmSingleTurnViewModel(),
navigateUp = {},
)
}
}

View file

@ -0,0 +1,210 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.llmsingleturn
import android.util.Log
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.aiedge.gallery.ui.common.chat.Stat
import com.google.aiedge.gallery.ui.common.processLlmResponse
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
import com.google.aiedge.gallery.ui.llmchat.LlmModelInstance
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch
private const val TAG = "AGLlmSingleTurnViewModel"
data class LlmSingleTurnUiState(
/**
* Indicates whether the runtime is currently processing a message.
*/
val inProgress: Boolean = false,
/**
* Indicates whether the model is currently being initialized.
*/
val initializing: Boolean = false,
// model -> <template label -> response>
val responsesByModel: Map<String, Map<String, String>>,
// model -> <template label -> benchmark result>
val benchmarkByModel: Map<String, Map<String, ChatMessageBenchmarkLlmResult>>,
/** Selected prompt template type. */
val selectedPromptTemplateType: PromptTemplateType = PromptTemplateType.entries[0],
)
private val STATS = listOf(
Stat(id = "time_to_first_token", label = "1st token", unit = "sec"),
Stat(id = "prefill_speed", label = "Prefill speed", unit = "tokens/s"),
Stat(id = "decode_speed", label = "Decode speed", unit = "tokens/s"),
Stat(id = "latency", label = "Latency", unit = "sec")
)
open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_SINGLE_TURN) : ViewModel() {
private val _uiState = MutableStateFlow(createUiState(task = task))
val uiState = _uiState.asStateFlow()
fun generateResponse(model: Model, input: String) {
viewModelScope.launch(Dispatchers.Default) {
setInProgress(true)
setInitializing(true)
// Wait for instance to be initialized.
while (model.instance == null) {
delay(100)
}
// Run inference.
val instance = model.instance as LlmModelInstance
val prefillTokens = instance.session.sizeInTokens(input)
var firstRun = true
var timeToFirstToken = 0f
var firstTokenTs = 0L
var decodeTokens = 0
var prefillSpeed = 0f
var decodeSpeed: Float
val start = System.currentTimeMillis()
var response = ""
var lastBenchmarkUpdateTs = 0L
LlmChatModelHelper.runInference(model = model,
input = input,
resultListener = { partialResult, done ->
val curTs = System.currentTimeMillis()
if (firstRun) {
setInitializing(false)
firstTokenTs = System.currentTimeMillis()
timeToFirstToken = (firstTokenTs - start) / 1000f
prefillSpeed = prefillTokens / timeToFirstToken
firstRun = false
} else {
decodeTokens++
}
// Incrementally update the streamed partial results.
response = processLlmResponse(response = "$response$partialResult")
// Update response.
updateResponse(
model = model,
promptTemplateType = uiState.value.selectedPromptTemplateType,
response = response
)
// Update benchmark (with throttling).
if (curTs - lastBenchmarkUpdateTs > 200) {
decodeSpeed = decodeTokens / ((curTs - firstTokenTs) / 1000f)
if (decodeSpeed.isNaN()) {
decodeSpeed = 0f
}
val benchmark = ChatMessageBenchmarkLlmResult(
orderedStats = STATS,
statValues = mutableMapOf(
"prefill_speed" to prefillSpeed,
"decode_speed" to decodeSpeed,
"time_to_first_token" to timeToFirstToken,
"latency" to (curTs - start).toFloat() / 1000f,
),
running = !done,
latencyMs = -1f,
)
updateBenchmark(
model = model,
promptTemplateType = uiState.value.selectedPromptTemplateType,
benchmark = benchmark
)
lastBenchmarkUpdateTs = curTs
}
if (done) {
setInProgress(false)
}
},
singleTurn = true,
cleanUpListener = {
setInitializing(false)
setInProgress(false)
})
}
}
fun selectPromptTemplate(model: Model, promptTemplateType: PromptTemplateType) {
Log.d(TAG, "selecting prompt template: ${promptTemplateType.label}")
// Clear response.
updateResponse(model = model, promptTemplateType = promptTemplateType, response = "")
this._uiState.update { this.uiState.value.copy(selectedPromptTemplateType = promptTemplateType) }
}
fun setInProgress(inProgress: Boolean) {
_uiState.update { _uiState.value.copy(inProgress = inProgress) }
}
fun setInitializing(initializing: Boolean) {
_uiState.update { _uiState.value.copy(initializing = initializing) }
}
fun updateResponse(model: Model, promptTemplateType: PromptTemplateType, response: String) {
_uiState.update { currentState ->
val currentResponses = currentState.responsesByModel
val modelResponses = currentResponses[model.name]?.toMutableMap() ?: mutableMapOf()
modelResponses[promptTemplateType.label] = response
val newResponses = currentResponses.toMutableMap()
newResponses[model.name] = modelResponses
currentState.copy(responsesByModel = newResponses)
}
}
fun updateBenchmark(
model: Model, promptTemplateType: PromptTemplateType, benchmark: ChatMessageBenchmarkLlmResult
) {
_uiState.update { currentState ->
val currentBenchmark = currentState.benchmarkByModel
val modelBenchmarks = currentBenchmark[model.name]?.toMutableMap() ?: mutableMapOf()
modelBenchmarks[promptTemplateType.label] = benchmark
val newBenchmarks = currentBenchmark.toMutableMap()
newBenchmarks[model.name] = modelBenchmarks
currentState.copy(benchmarkByModel = newBenchmarks)
}
}
private fun createUiState(task: Task): LlmSingleTurnUiState {
val responsesByModel: MutableMap<String, Map<String, String>> = mutableMapOf()
val benchmarkByModel: MutableMap<String, Map<String, ChatMessageBenchmarkLlmResult>> =
mutableMapOf()
for (model in task.models) {
responsesByModel[model.name] = mutableMapOf()
benchmarkByModel[model.name] = mutableMapOf()
}
return LlmSingleTurnUiState(
responsesByModel = responsesByModel,
benchmarkByModel = benchmarkByModel,
)
}
}

View file

@ -0,0 +1,185 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.llmsingleturn
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.text.buildAnnotatedString
import androidx.compose.ui.text.withStyle
import androidx.compose.ui.text.SpanStyle
import androidx.compose.ui.graphics.Brush.Companion.linearGradient
enum class PromptTemplateInputEditorType {
SINGLE_SELECT
}
enum class RewriteToneType(val label: String) {
FORMAL(label = "Formal"), CASUAL(label = "Casual"), FRIENDLY(label = "Friendly"), POLITE(label = "Polite"), ENTHUSIASTIC(
label = "Enthusiastic"
),
CONCISE(label = "Concise"),
}
enum class SummarizationType(val label: String) {
KEY_BULLET_POINT(label = "Key bullet points (3-5)"),
SHORT_PARAGRAPH(label = "Short paragraph (1-2 sentences)"),
CONCISE_SUMMARY(label = "Concise summary (~50 words)"),
HEADLINE_TITLE(label = "Headline / title"),
ONE_SENTENCE_SUMMARY(label = "One-sentence summary"),
}
enum class LanguageType(val label: String) {
CPP(label = "C++"),
JAVA(label = "Java"),
JAVASCRIPT(label = "JavaScript"),
KOTLIN(label = "Kotlin"),
PYTHON(label = "Python"),
SWIFT(label = "Swift"),
TYPESCRIPT(label = "TypeScript"),
}
enum class InputEditorLabel(val label: String) {
TONE(label = "Tone"),
STYLE(label = "Style"),
LANGUAGE(label = "Language"),
}
open class PromptTemplateInputEditor(
open val label: String,
open val type: PromptTemplateInputEditorType,
open val defaultOption: String = "",
)
/** Single select that shows options in bottom sheet. */
class PromptTemplateSingleSelectInputEditor(
override val label: String,
val options: List<String> = listOf(),
override val defaultOption: String = "",
) : PromptTemplateInputEditor(
label = label, type = PromptTemplateInputEditorType.SINGLE_SELECT, defaultOption = defaultOption
)
data class PromptTemplateConfig(val inputEditors: List<PromptTemplateInputEditor> = listOf())
private val GEMINI_GRADIENT_STYLE = SpanStyle(
brush = linearGradient(
colors = listOf(Color(0xFF4285f4), Color(0xFF9b72cb), Color(0xFFd96570))
)
)
enum class PromptTemplateType(
val label: String,
val config: PromptTemplateConfig,
val genFullPrompt: (userInput: String, inputEditorValues: Map<String, Any>) -> AnnotatedString = { _, _ ->
AnnotatedString("")
},
val examplePrompts: List<String> = listOf(),
) {
FREE_FORM(
label = "Free form",
config = PromptTemplateConfig(),
genFullPrompt = { userInput, _ -> AnnotatedString(userInput) },
examplePrompts = listOf(
"Suggest 3 topics for a podcast about \"Friendships in your 20s\".",
"Outline the key sections needed in a basic logo design brief.",
"List 3 pros and 3 cons to consider before buying a smart watch.",
"Write a short, optimistic quote about the future of technology.",
"Generate 3 potential names for a mobile app that helps users identify plants.",
"Explain the difference between AI and machine learning in 2 sentences.",
"Create a simple haiku about a cat sleeping in the sun.",
"List 3 ways to make instant noodles taste better using common kitchen ingredients."
)
),
REWRITE_TONE(
label = "Rewrite tone", config = PromptTemplateConfig(
inputEditors = listOf(
PromptTemplateSingleSelectInputEditor(
label = InputEditorLabel.TONE.label,
options = RewriteToneType.entries.map { it.label },
defaultOption = RewriteToneType.FORMAL.label
)
)
), genFullPrompt = { userInput, inputEditorValues ->
val tone = inputEditorValues[InputEditorLabel.TONE.label] as String
buildAnnotatedString {
withStyle(GEMINI_GRADIENT_STYLE) {
append("Rewrite the following text using a ${tone.lowercase()} tone: ")
}
append(userInput)
}
}, examplePrompts = listOf(
"Hey team, just wanted to remind everyone about the meeting tomorrow @ 10. Be there!",
"Our new software update includes several bug fixes and performance improvements.",
"Due to the fact that the weather was bad, we decided to postpone the event.",
"Please find attached the requested documentation for your perusal.",
"Welcome to the team. Review the onboarding materials.",
)
),
SUMMARIZE_TEXT(
label = "Summarize text",
config = PromptTemplateConfig(
inputEditors = listOf(
PromptTemplateSingleSelectInputEditor(
label = InputEditorLabel.STYLE.label,
options = SummarizationType.entries.map { it.label },
defaultOption = SummarizationType.KEY_BULLET_POINT.label
)
)
),
genFullPrompt = { userInput, inputEditorValues ->
val style = inputEditorValues[InputEditorLabel.STYLE.label] as String
buildAnnotatedString {
withStyle(GEMINI_GRADIENT_STYLE) {
append("Please summarize the following in ${style.lowercase()}: ")
}
append(userInput)
}
},
examplePrompts = listOf(
"The new Pixel phone features an advanced camera system with improved low-light performance and AI-powered editing tools. The display is brighter and more energy-efficient. It runs on the latest Tensor chip, offering faster processing and enhanced security features. Battery life has also been extended, providing all-day power for most users.",
"Beginning this Friday, January 24, giant pandas Bao Li and Qing Bao are officially on view to the public at the Smithsonians National Zoo and Conservation Biology Institute (NZCBI). The 3-year-old bears arrived in Washington this past October, undergoing a quarantine period before making their debut. Under NZCBIs new agreement with the CWCA, Qing Bao and Bao Li will remain in the United States for ten years, until April 2034, in exchange for an annual fee of \$1 million. The pair are still too young to breed, as pandas only reach sexual maturity between ages 4 and 7. “Kind of picture them as like awkward teenagers right now,” Lally told WUSA9. “We still have about two years before we would probably even see signs that theyre ready to start mating.”",
),
),
CODE_SNIPPET(
label = "Code snippet",
config = PromptTemplateConfig(
inputEditors = listOf(
PromptTemplateSingleSelectInputEditor(
label = InputEditorLabel.LANGUAGE.label,
options = LanguageType.entries.map { it.label },
defaultOption = LanguageType.JAVASCRIPT.label
)
)
),
genFullPrompt = { userInput, inputEditorValues ->
val language = inputEditorValues[InputEditorLabel.LANGUAGE.label] as String
buildAnnotatedString {
withStyle(GEMINI_GRADIENT_STYLE) {
append("Write a $language code snippet to ")
}
append(userInput)
}
},
examplePrompts = listOf(
"Create an alert box that says \"Hello, World!\"",
"Declare an immutable variable named 'appName' with the value \"AI Gallery\"",
"Print the numbers from 1 to 5 using a for loop.",
"Write a function that returns the square of an integer input.",
),
),
}

View file

@ -0,0 +1,426 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.llmsingleturn
import androidx.compose.foundation.BorderStroke
import androidx.compose.foundation.background
import androidx.compose.foundation.border
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.wrapContentHeight
import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.foundation.verticalScroll
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.Send
import androidx.compose.material.icons.outlined.ContentCopy
import androidx.compose.material.icons.outlined.Description
import androidx.compose.material.icons.outlined.ExpandLess
import androidx.compose.material.icons.outlined.ExpandMore
import androidx.compose.material.icons.rounded.Add
import androidx.compose.material.icons.rounded.Visibility
import androidx.compose.material.icons.rounded.VisibilityOff
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.FilterChipDefaults
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButtonDefaults
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ModalBottomSheet
import androidx.compose.material3.OutlinedIconButton
import androidx.compose.material3.PrimaryScrollableTabRow
import androidx.compose.material3.Tab
import androidx.compose.material3.Text
import androidx.compose.material3.TextField
import androidx.compose.material3.TextFieldDefaults
import androidx.compose.material3.rememberModalBottomSheetState
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.mutableStateMapOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.runtime.snapshots.SnapshotStateMap
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.platform.LocalClipboardManager
import androidx.compose.ui.res.dimensionResource
import androidx.compose.ui.text.TextLayoutResult
import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.ui.common.chat.MessageBubbleShape
import com.google.aiedge.gallery.ui.theme.customColors
private val promptTemplateTypes: List<PromptTemplateType> = PromptTemplateType.entries
private val TAB_TITLES = PromptTemplateType.entries.map { it.label }
private val ICON_BUTTON_SIZE = 42.dp
const val FULL_PROMPT_SWITCH_KEY = "full_prompt"
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun PromptTemplatesPanel(
model: Model,
viewModel: LlmSingleTurnViewModel,
onSend: (fullPrompt: String) -> Unit,
modifier: Modifier = Modifier
) {
val uiState by viewModel.uiState.collectAsState()
val selectedPromptTemplateType = uiState.selectedPromptTemplateType
val inProgress = uiState.inProgress
var selectedTabIndex by remember { mutableIntStateOf(0) }
var curTextInputContent by remember { mutableStateOf("") }
val inputEditorValues: SnapshotStateMap<String, Any> = remember {
mutableStateMapOf(FULL_PROMPT_SWITCH_KEY to false)
}
val fullPrompt by remember {
derivedStateOf {
uiState.selectedPromptTemplateType.genFullPrompt(curTextInputContent, inputEditorValues)
}
}
val clipboardManager = LocalClipboardManager.current
val expandedStates = remember { mutableStateMapOf<String, Boolean>() }
// Update input editor values when prompt template changes.
LaunchedEffect(selectedPromptTemplateType) {
for (config in selectedPromptTemplateType.config.inputEditors) {
inputEditorValues[config.label] = config.defaultOption
}
expandedStates.clear()
}
var showExamplePromptBottomSheet by remember { mutableStateOf(false) }
val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
val bubbleBorderRadius = dimensionResource(R.dimen.chat_bubble_corner_radius)
Column(modifier = modifier) {
// Scrollable tab row for all prompt templates.
PrimaryScrollableTabRow(selectedTabIndex = selectedTabIndex) {
TAB_TITLES.forEachIndexed { index, title ->
Tab(selected = selectedTabIndex == index, onClick = {
// Clear input when tab changes.
curTextInputContent = ""
// Reset full prompt switch.
inputEditorValues[FULL_PROMPT_SWITCH_KEY] = false
selectedTabIndex = index
viewModel.selectPromptTemplate(
model = model,
promptTemplateType = promptTemplateTypes[index]
)
}, text = { Text(text = title) })
}
}
// Content.
Column(
modifier = Modifier
.weight(1f)
.fillMaxWidth()
) {
// Input editor row.
if (selectedPromptTemplateType.config.inputEditors.isNotEmpty()) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(8.dp),
modifier = Modifier
.fillMaxWidth()
.background(MaterialTheme.colorScheme.surfaceContainerLow)
.padding(horizontal = 16.dp, vertical = 10.dp)
) {
// Input editors.
for (inputEditor in selectedPromptTemplateType.config.inputEditors) {
when (inputEditor.type) {
PromptTemplateInputEditorType.SINGLE_SELECT -> SingleSelectButton(config = inputEditor as PromptTemplateSingleSelectInputEditor,
onSelected = { option ->
inputEditorValues[inputEditor.label] = option
})
}
}
}
}
// Text input box.
Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.weight(1f)) {
Column(
modifier = Modifier
.fillMaxSize()
.verticalScroll(rememberScrollState())
) {
if (inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean) {
Text(
fullPrompt,
style = MaterialTheme.typography.bodyMedium,
modifier = Modifier
.fillMaxWidth()
.padding(16.dp)
.padding(bottom = 32.dp)
.clip(MessageBubbleShape(radius = bubbleBorderRadius))
.background(MaterialTheme.customColors.agentBubbleBgColor)
.padding(16.dp)
)
} else {
TextField(
value = curTextInputContent,
onValueChange = { curTextInputContent = it },
colors = TextFieldDefaults.colors(
unfocusedContainerColor = Color.Transparent,
focusedContainerColor = Color.Transparent,
focusedIndicatorColor = Color.Transparent,
unfocusedIndicatorColor = Color.Transparent,
disabledIndicatorColor = Color.Transparent,
disabledContainerColor = Color.Transparent,
),
textStyle = MaterialTheme.typography.bodyMedium,
placeholder = { Text("Enter content") },
modifier = Modifier.padding(bottom = 32.dp)
)
}
}
// Text action row.
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(4.dp),
modifier = Modifier
.fillMaxWidth()
.padding(vertical = 4.dp, horizontal = 16.dp)
) {
// Full prompt switch.
if (selectedPromptTemplateType != PromptTemplateType.FREE_FORM && curTextInputContent.isNotEmpty()) {
Row(verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(4.dp),
modifier = Modifier
.clip(CircleShape)
.background(if (inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean) MaterialTheme.colorScheme.secondaryContainer else MaterialTheme.customColors.agentBubbleBgColor)
.clickable {
inputEditorValues[FULL_PROMPT_SWITCH_KEY] =
!(inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean)
}
.height(40.dp)
.border(
width = 1.dp, color = MaterialTheme.colorScheme.surface, shape = CircleShape
)
.padding(horizontal = 12.dp)) {
if (inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean) {
Icon(
imageVector = Icons.Rounded.Visibility,
contentDescription = "",
modifier = Modifier.size(FilterChipDefaults.IconSize),
)
} else {
Icon(
imageVector = Icons.Rounded.VisibilityOff,
contentDescription = "",
modifier = Modifier
.size(FilterChipDefaults.IconSize)
.alpha(0.3f),
)
}
Text("Preview prompt", style = MaterialTheme.typography.labelMedium)
}
}
Spacer(modifier = Modifier.weight(1f))
// Button to copy full prompt.
if (curTextInputContent.isNotEmpty()) {
OutlinedIconButton(
onClick = {
val clipData = fullPrompt
clipboardManager.setText(clipData)
},
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.customColors.agentBubbleBgColor,
disabledContainerColor = MaterialTheme.customColors.agentBubbleBgColor.copy(alpha = 0.4f),
contentColor = MaterialTheme.colorScheme.primary,
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.2f),
),
border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface),
modifier = Modifier.size(ICON_BUTTON_SIZE)
) {
Icon(
Icons.Outlined.ContentCopy, contentDescription = "", modifier = Modifier.size(20.dp)
)
}
}
// Add example prompt button.
OutlinedIconButton(
enabled = !inProgress,
onClick = { showExamplePromptBottomSheet = true },
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.customColors.agentBubbleBgColor,
disabledContainerColor = MaterialTheme.customColors.agentBubbleBgColor.copy(alpha = 0.4f),
contentColor = MaterialTheme.colorScheme.primary,
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.2f),
),
border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface),
modifier = Modifier.size(ICON_BUTTON_SIZE)
) {
Icon(
Icons.Rounded.Add,
contentDescription = "",
modifier = Modifier.size(20.dp),
)
}
// Send button
OutlinedIconButton(
enabled = !inProgress && curTextInputContent.isNotEmpty(),
onClick = {
onSend(fullPrompt.text)
},
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.secondaryContainer,
disabledContainerColor = MaterialTheme.colorScheme.secondaryContainer.copy(alpha = 0.3f),
contentColor = MaterialTheme.colorScheme.primary,
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.1f),
),
border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface),
modifier = Modifier.size(ICON_BUTTON_SIZE)
) {
Icon(
Icons.AutoMirrored.Rounded.Send,
contentDescription = "",
modifier = Modifier
.size(20.dp)
.offset(x = 2.dp),
)
}
}
}
}
}
if (showExamplePromptBottomSheet) {
ModalBottomSheet(
onDismissRequest = { showExamplePromptBottomSheet = false },
sheetState = sheetState,
modifier = Modifier.wrapContentHeight(),
) {
Column(modifier = Modifier.padding(bottom = 16.dp)) {
// Title
Text(
"Select an example",
modifier = Modifier
.fillMaxWidth()
.padding(16.dp),
style = MaterialTheme.typography.titleLarge
)
// Examples
for (prompt in selectedPromptTemplateType.examplePrompts) {
var textLayoutResultState by remember { mutableStateOf<TextLayoutResult?>(null) }
val hasOverflow = remember(textLayoutResultState) {
textLayoutResultState?.hasVisualOverflow ?: false
}
val isExpanded = expandedStates[prompt] ?: false
Column(
modifier = Modifier
.fillMaxWidth()
.clickable {
curTextInputContent = prompt
showExamplePromptBottomSheet = false
}
.padding(horizontal = 16.dp, vertical = 8.dp),
) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(8.dp),
) {
Icon(Icons.Outlined.Description, contentDescription = "")
Text(prompt,
maxLines = if (isExpanded) Int.MAX_VALUE else 3,
overflow = TextOverflow.Ellipsis,
style = MaterialTheme.typography.bodySmall,
modifier = Modifier.weight(1f),
onTextLayout = { textLayoutResultState = it }
)
}
if (hasOverflow && !isExpanded) {
Row(
modifier = Modifier
.fillMaxWidth()
.padding(top = 2.dp),
horizontalArrangement = Arrangement.End
) {
Box(modifier = Modifier
.padding(end = 16.dp)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.surfaceContainerHighest)
.clickable {
expandedStates[prompt] = true
}
.padding(vertical = 1.dp, horizontal = 6.dp)) {
Icon(
Icons.Outlined.ExpandMore,
contentDescription = "",
modifier = Modifier.size(12.dp)
)
}
}
} else if (isExpanded) {
Row(
modifier = Modifier
.fillMaxWidth()
.padding(top = 2.dp),
horizontalArrangement = Arrangement.End
) {
Box(modifier = Modifier
.padding(end = 16.dp)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.surfaceContainerHighest)
.clickable {
expandedStates[prompt] = false
}
.padding(vertical = 1.dp, horizontal = 6.dp)) {
Icon(
Icons.Outlined.ExpandLess,
contentDescription = "",
modifier = Modifier.size(12.dp)
)
}
}
}
}
}
}
}
}
}

View file

@ -0,0 +1,206 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.llmsingleturn
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.verticalScroll
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.outlined.AutoAwesome
import androidx.compose.material.icons.outlined.ContentCopy
import androidx.compose.material.icons.outlined.Timer
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.IconButtonDefaults
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.PrimaryTabRow
import androidx.compose.material3.Tab
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.platform.LocalClipboardManager
import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
import com.google.aiedge.gallery.ui.common.chat.MarkdownText
import com.google.aiedge.gallery.ui.common.chat.MessageBodyBenchmarkLlm
import com.google.aiedge.gallery.ui.common.chat.MessageBodyLoading
import com.google.aiedge.gallery.ui.theme.GalleryTheme
private val OPTIONS = listOf("Response", "Benchmark")
private val ICONS = listOf(Icons.Outlined.AutoAwesome, Icons.Outlined.Timer)
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun ResponsePanel(
model: Model,
viewModel: LlmSingleTurnViewModel,
modifier: Modifier = Modifier,
) {
val uiState by viewModel.uiState.collectAsState()
val inProgress = uiState.inProgress
val initializing = uiState.initializing
val selectedPromptTemplateType = uiState.selectedPromptTemplateType
val response = uiState.responsesByModel[model.name]?.get(selectedPromptTemplateType.label) ?: ""
val benchmark = uiState.benchmarkByModel[model.name]?.get(selectedPromptTemplateType.label)
val responseScrollState = rememberScrollState()
var selectedOptionIndex by remember { mutableIntStateOf(0) }
val clipboardManager = LocalClipboardManager.current
// Scroll to bottom when response changes.
LaunchedEffect(response) {
if (inProgress) {
responseScrollState.animateScrollTo(responseScrollState.maxValue)
}
}
// Select the "response" tab when prompt template changes.
LaunchedEffect(selectedPromptTemplateType) {
selectedOptionIndex = 0
}
if (initializing) {
Box(
contentAlignment = Alignment.TopStart,
modifier = modifier
.fillMaxSize()
.padding(horizontal = 16.dp)
) {
MessageBodyLoading()
}
} else {
// Message when response is empty.
if (response.isEmpty()) {
Row(
modifier = Modifier.fillMaxSize(),
horizontalArrangement = Arrangement.Center,
verticalAlignment = Alignment.CenterVertically
) {
Text(
"Response will appear here",
modifier = Modifier.alpha(0.5f),
style = MaterialTheme.typography.labelMedium,
)
}
}
// Response markdown.
else {
Column(
modifier = modifier
.padding(horizontal = 16.dp)
.padding(bottom = 4.dp)
) {
// Response/benchmark switch.
Row(modifier = Modifier.fillMaxWidth()) {
PrimaryTabRow(
selectedTabIndex = selectedOptionIndex,
containerColor = Color.Transparent,
) {
OPTIONS.forEachIndexed { index, title ->
Tab(selected = selectedOptionIndex == index, onClick = {
selectedOptionIndex = index
}, text = {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(4.dp)
) {
Icon(
ICONS[index],
contentDescription = "",
modifier = Modifier
.size(16.dp)
.alpha(0.7f)
)
Text(text = title)
}
})
}
}
}
if (selectedOptionIndex == 0) {
Box(
contentAlignment = Alignment.BottomEnd,
modifier = Modifier.weight(1f)
) {
Column(
modifier = Modifier
.fillMaxSize()
.verticalScroll(responseScrollState)
) {
MarkdownText(
text = response,
modifier = Modifier.padding(top = 8.dp, bottom = 40.dp)
)
}
// Copy button.
IconButton(
onClick = {
val clipData = AnnotatedString(response)
clipboardManager.setText(clipData)
},
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.surfaceContainerHighest,
contentColor = MaterialTheme.colorScheme.primary,
),
) {
Icon(
Icons.Outlined.ContentCopy,
contentDescription = "",
modifier = Modifier.size(20.dp),
)
}
}
} else if (selectedOptionIndex == 1) {
if (benchmark != null) {
MessageBodyBenchmarkLlm(message = benchmark, modifier = Modifier.fillMaxWidth())
}
}
}
}
}
}
@Preview(showBackground = true)
@Composable
fun ResponsePanelPreview() {
GalleryTheme {
ResponsePanel(
model = TASK_LLM_SINGLE_TURN.models[0],
viewModel = LlmSingleTurnViewModel(),
modifier = Modifier.fillMaxSize()
)
}
}

View file

@ -0,0 +1,90 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.llmsingleturn
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.ArrowDropDown
import androidx.compose.material3.DropdownMenu
import androidx.compose.material3.DropdownMenuItem
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.unit.dp
@Composable
fun SingleSelectButton(
config: PromptTemplateSingleSelectInputEditor,
onSelected: (String) -> Unit
) {
var showMenu by remember { mutableStateOf(false) }
var selectedOption by remember { mutableStateOf(config.defaultOption) }
LaunchedEffect(config) {
selectedOption = config.defaultOption
}
Box {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(2.dp),
modifier = Modifier
.clip(RoundedCornerShape(8.dp))
.background(MaterialTheme.colorScheme.secondaryContainer)
.clickable {
showMenu = true
}
.padding(vertical = 4.dp, horizontal = 6.dp)
.padding(start = 8.dp)
) {
Text("${config.label}: $selectedOption", style = MaterialTheme.typography.labelLarge)
Icon(Icons.Rounded.ArrowDropDown, contentDescription = "")
}
DropdownMenu(
expanded = showMenu,
onDismissRequest = { showMenu = false }
) {
// Options
for (option in config.options) {
DropdownMenuItem(
text = { Text(option) },
onClick = {
selectedOption = option
showMenu = false
onSelected(option)
}
)
}
}
}
}

View file

@ -0,0 +1,133 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.llmsingleturn
import androidx.compose.foundation.background
import androidx.compose.foundation.gestures.detectDragGestures
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.width
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableFloatStateOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.input.pointer.pointerInput
import androidx.compose.ui.layout.onGloballyPositioned
import androidx.compose.ui.platform.LocalDensity
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.Dp
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.customColors
@Composable
fun VerticalSplitView(
topView: @Composable () -> Unit,
bottomView: @Composable () -> Unit,
modifier: Modifier = Modifier,
initialRatio: Float = 0.5f,
minTopHeight: Dp = 250.dp,
minBottomHeight: Dp = 200.dp,
handleThickness: Dp = 20.dp
) {
var splitRatio by remember { mutableFloatStateOf(initialRatio) }
var columnHeightPx by remember {
mutableFloatStateOf(0f)
}
var columnHeightDp by remember {
mutableStateOf(0.dp)
}
val localDensity = LocalDensity.current
Column(modifier = modifier
.fillMaxSize()
.onGloballyPositioned { coordinates ->
// Set column height using the LayoutCoordinates
columnHeightPx = coordinates.size.height.toFloat()
columnHeightDp = with(localDensity) { coordinates.size.height.toDp() }
}
) {
Box(
modifier = Modifier
.fillMaxWidth()
.weight(splitRatio)
) {
topView()
}
Box(
modifier = Modifier
.fillMaxWidth()
.height(handleThickness)
.background(MaterialTheme.customColors.agentBubbleBgColor)
.pointerInput(Unit) {
detectDragGestures { change, dragAmount ->
val newTopHeightPx = columnHeightPx * splitRatio + dragAmount.y
var newTopHeightDp = with(localDensity) { newTopHeightPx.toDp() }
if (newTopHeightDp < minTopHeight) {
newTopHeightDp = minTopHeight
}
if (columnHeightDp - newTopHeightDp < minBottomHeight) {
newTopHeightDp = columnHeightDp - minBottomHeight
}
splitRatio = newTopHeightDp / columnHeightDp
change.consume()
}
},
contentAlignment = Alignment.Center
) {
Box(
modifier = Modifier
.width(32.dp)
.height(4.dp)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.4f))
)
}
Box(
modifier = Modifier
.fillMaxWidth()
.weight(1f - splitRatio)
) {
bottomView()
}
}
}
@Preview(showBackground = true)
@Composable
fun VerticalSplitViewPreview() {
GalleryTheme {
VerticalSplitView(topView = {
Text("top")
}, bottomView = {
Text("bottom")
})
}
}

View file

@ -40,6 +40,7 @@ import com.google.aiedge.gallery.data.ModelDownloadStatus
import com.google.aiedge.gallery.data.ModelDownloadStatusType import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.TASKS import com.google.aiedge.gallery.data.TASKS
import com.google.aiedge.gallery.data.TASK_LLM_CHAT import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.TaskType import com.google.aiedge.gallery.data.TaskType
import com.google.aiedge.gallery.data.ValueType import com.google.aiedge.gallery.data.ValueType
@ -175,9 +176,13 @@ open class ModelManagerViewModel(
// Kick off downloads for these models . // Kick off downloads for these models .
withContext(Dispatchers.Main) { withContext(Dispatchers.Main) {
val tokenStatusAndData = getTokenStatusAndData()
for (info in inProgressWorkInfos) { for (info in inProgressWorkInfos) {
val model: Model? = getModelByName(info.modelName) val model: Model? = getModelByName(info.modelName)
if (model != null) { if (model != null) {
if (tokenStatusAndData.status == TokenStatus.NOT_EXPIRED && tokenStatusAndData.data != null) {
model.accessToken = tokenStatusAndData.data.accessToken
}
Log.d(TAG, "Sending a new download request for '${model.name}'") Log.d(TAG, "Sending a new download request for '${model.name}'")
downloadRepository.downloadModel( downloadRepository.downloadModel(
model, onStatusUpdated = this@ModelManagerViewModel::setDownloadStatus model, onStatusUpdated = this@ModelManagerViewModel::setDownloadStatus
@ -233,11 +238,13 @@ open class ModelManagerViewModel(
// Delete model from the list if model is imported as a local model. // Delete model from the list if model is imported as a local model.
if (model.imported) { if (model.imported) {
val index = task.models.indexOf(model) for (curTask in TASKS) {
if (index >= 0) { val index = curTask.models.indexOf(model)
task.models.removeAt(index) if (index >= 0) {
curTask.models.removeAt(index)
}
curTask.updateTrigger.value = System.currentTimeMillis()
} }
task.updateTrigger.value = System.currentTimeMillis()
curModelDownloadStatus.remove(model.name) curModelDownloadStatus.remove(model.name)
// Update preference. // Update preference.
@ -252,7 +259,7 @@ open class ModelManagerViewModel(
_uiState.update { newUiState } _uiState.update { newUiState }
} }
fun initializeModel(context: Context, model: Model, force: Boolean = false) { fun initializeModel(context: Context, task: Task, model: Model, force: Boolean = false) {
viewModelScope.launch(Dispatchers.Default) { viewModelScope.launch(Dispatchers.Default) {
// Skip if initialized already. // Skip if initialized already.
if (!force && uiState.value.modelInitializationStatus[model.name]?.status == ModelInitializationStatusType.INITIALIZED) { if (!force && uiState.value.modelInitializationStatus[model.name]?.status == ModelInitializationStatusType.INITIALIZED) {
@ -267,7 +274,7 @@ open class ModelManagerViewModel(
} }
// Clean up. // Clean up.
cleanupModel(model = model) cleanupModel(task = task, model = model)
// Start initialization. // Start initialization.
Log.d(TAG, "Initializing model '${model.name}'...") Log.d(TAG, "Initializing model '${model.name}'...")
@ -301,7 +308,7 @@ open class ModelManagerViewModel(
) )
} }
} }
when (model.taskType) { when (task.type) {
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.initialize( TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.initialize(
context = context, context = context,
model = model, model = model,
@ -320,24 +327,33 @@ open class ModelManagerViewModel(
onDone = onDone, onDone = onDone,
) )
TaskType.LLM_SINGLE_TURN -> LlmChatModelHelper.initialize(
context = context,
model = model,
onDone = onDone,
)
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.initialize( TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.initialize(
context = context, model = model, onDone = onDone context = context, model = model, onDone = onDone
) )
else -> {} TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {}
} }
} }
} }
fun cleanupModel(model: Model) { fun cleanupModel(task: Task, model: Model) {
if (model.instance != null) { if (model.instance != null) {
Log.d(TAG, "Cleaning up model '${model.name}'...") Log.d(TAG, "Cleaning up model '${model.name}'...")
when (model.taskType) { when (task.type) {
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model) TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model)
TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.cleanUp(model = model) TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.cleanUp(model = model)
TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model) TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model)
TaskType.LLM_SINGLE_TURN -> LlmChatModelHelper.cleanUp(model = model)
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.cleanUp(model = model) TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.cleanUp(model = model)
else -> {} TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {}
} }
model.instance = null model.instance = null
model.initializing = false model.initializing = false
@ -421,10 +437,11 @@ open class ModelManagerViewModel(
return connection.responseCode return connection.responseCode
} }
fun addImportedLlmModel(task: Task, info: ImportedModelInfo) { fun addImportedLlmModel(info: ImportedModelInfo) {
Log.d(TAG, "adding imported llm model: $info") Log.d(TAG, "adding imported llm model: $info")
// Remove duplicated imported model if existed. // Remove duplicated imported model if existed.
val task = TASK_LLM_CHAT
val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported } val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported }
if (modelIndex >= 0) { if (modelIndex >= 0) {
Log.d(TAG, "duplicated imported model found in task. Removing it first") Log.d(TAG, "duplicated imported model found in task. Removing it first")
@ -455,6 +472,9 @@ open class ModelManagerViewModel(
) )
} }
task.updateTrigger.value = System.currentTimeMillis() task.updateTrigger.value = System.currentTimeMillis()
// Also need to update single turn task.
TASK_LLM_SINGLE_TURN.updateTrigger.value = System.currentTimeMillis()
// Add to preference storage. // Add to preference storage.
val importedModels = dataStoreRepository.readImportedModels().toMutableList() val importedModels = dataStoreRepository.readImportedModels().toMutableList()
@ -630,33 +650,26 @@ open class ModelManagerViewModel(
private fun createModelFromImportedModelInfo(info: ImportedModelInfo, task: Task): Model { private fun createModelFromImportedModelInfo(info: ImportedModelInfo, task: Task): Model {
val accelerators: List<Accelerator> = (convertValueToTargetType( val accelerators: List<Accelerator> = (convertValueToTargetType(
info.defaultValues[ConfigKey.COMPATIBLE_ACCELERATORS.label]!!, info.defaultValues[ConfigKey.COMPATIBLE_ACCELERATORS.label]!!, ValueType.STRING
ValueType.STRING ) as String).split(",").mapNotNull { acceleratorLabel ->
) as String) when (acceleratorLabel.trim()) {
.split(",") Accelerator.GPU.label -> Accelerator.GPU
.mapNotNull { acceleratorLabel -> Accelerator.CPU.label -> Accelerator.CPU
when (acceleratorLabel.trim()) { else -> null // Ignore unknown accelerator labels
Accelerator.GPU.label -> Accelerator.GPU
Accelerator.CPU.label -> Accelerator.CPU
else -> null // Ignore unknown accelerator labels
}
} }
}
val configs: List<Config> = createLlmChatConfigs( val configs: List<Config> = createLlmChatConfigs(
defaultMaxToken = convertValueToTargetType( defaultMaxToken = convertValueToTargetType(
info.defaultValues[ConfigKey.DEFAULT_MAX_TOKENS.label]!!, info.defaultValues[ConfigKey.DEFAULT_MAX_TOKENS.label]!!, ValueType.INT
ValueType.INT
) as Int, ) as Int,
defaultTopK = convertValueToTargetType( defaultTopK = convertValueToTargetType(
info.defaultValues[ConfigKey.DEFAULT_TOPK.label]!!, info.defaultValues[ConfigKey.DEFAULT_TOPK.label]!!, ValueType.INT
ValueType.INT
) as Int, ) as Int,
defaultTopP = convertValueToTargetType( defaultTopP = convertValueToTargetType(
info.defaultValues[ConfigKey.DEFAULT_TOPP.label]!!, info.defaultValues[ConfigKey.DEFAULT_TOPP.label]!!, ValueType.FLOAT
ValueType.FLOAT
) as Float, ) as Float,
defaultTemperature = convertValueToTargetType( defaultTemperature = convertValueToTargetType(
info.defaultValues[ConfigKey.DEFAULT_TEMPERATURE.label]!!, info.defaultValues[ConfigKey.DEFAULT_TEMPERATURE.label]!!, ValueType.FLOAT
ValueType.FLOAT
) as Float, ) as Float,
accelerators = accelerators, accelerators = accelerators,
) )
@ -666,9 +679,10 @@ open class ModelManagerViewModel(
configs = configs, configs = configs,
sizeInBytes = info.fileSize, sizeInBytes = info.fileSize,
downloadFileName = "$IMPORTS_DIR/${info.fileName}", downloadFileName = "$IMPORTS_DIR/${info.fileName}",
showBenchmarkButton = false,
imported = true, imported = true,
) )
model.preProcess(task = task) model.preProcess()
return model return model
} }
@ -741,7 +755,7 @@ open class ModelManagerViewModel(
val task = TASKS.find { it.type.label == hfModel.task } val task = TASKS.find { it.type.label == hfModel.task }
val model = hfModel.toModel() val model = hfModel.toModel()
if (task != null && task.models.find { it.hfModelId == model.hfModelId } == null) { if (task != null && task.models.find { it.hfModelId == model.hfModelId } == null) {
model.preProcess(task = task) model.preProcess()
Log.d(TAG, "AG model: $model") Log.d(TAG, "AG model: $model")
task.models.add(model) task.models.add(model)

View file

@ -46,6 +46,7 @@ import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION
import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION
import com.google.aiedge.gallery.data.TASK_LLM_CHAT import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION
import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.TaskType import com.google.aiedge.gallery.data.TaskType
@ -58,6 +59,8 @@ import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationDestination
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationScreen import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationScreen
import com.google.aiedge.gallery.ui.llmchat.LlmChatDestination import com.google.aiedge.gallery.ui.llmchat.LlmChatDestination
import com.google.aiedge.gallery.ui.llmchat.LlmChatScreen import com.google.aiedge.gallery.ui.llmchat.LlmChatScreen
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnDestination
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnScreen
import com.google.aiedge.gallery.ui.modelmanager.ModelManager import com.google.aiedge.gallery.ui.modelmanager.ModelManager
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.textclassification.TextClassificationDestination import com.google.aiedge.gallery.ui.textclassification.TextClassificationDestination
@ -131,7 +134,7 @@ fun GalleryNavHost(
task = curPickedTask, task = curPickedTask,
onModelClicked = { model -> onModelClicked = { model ->
navigateToTaskScreen( navigateToTaskScreen(
navController = navController, taskType = model.taskType!!, model = model navController = navController, taskType = curPickedTask.type, model = model
) )
}, },
navigateUp = { showModelManager = false }) navigateUp = { showModelManager = false })
@ -220,6 +223,24 @@ fun GalleryNavHost(
) )
} }
} }
// LLMm single turn.
composable(
route = "${LlmSingleTurnDestination.route}/{modelName}",
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() },
exitTransition = { slideExit() },
) {
getModelFromNavigationParam(it, TASK_LLM_SINGLE_TURN)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel)
LlmSingleTurnScreen(
modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() },
)
}
}
} }
// Handle incoming intents for deep links // Handle incoming intents for deep links
@ -231,9 +252,10 @@ fun GalleryNavHost(
if (data.toString().startsWith("com.google.aiedge.gallery://model/")) { if (data.toString().startsWith("com.google.aiedge.gallery://model/")) {
val modelName = data.pathSegments.last() val modelName = data.pathSegments.last()
getModelByName(modelName)?.let { model -> getModelByName(modelName)?.let { model ->
// TODO(jingjin): need to show a list of possible tasks for this model.
navigateToTaskScreen( navigateToTaskScreen(
navController = navController, navController = navController,
taskType = model.taskType!!, taskType = TaskType.LLM_CHAT,
model = model model = model
) )
} }
@ -249,6 +271,7 @@ fun navigateToTaskScreen(
TaskType.TEXT_CLASSIFICATION -> navController.navigate("${TextClassificationDestination.route}/${modelName}") TaskType.TEXT_CLASSIFICATION -> navController.navigate("${TextClassificationDestination.route}/${modelName}")
TaskType.IMAGE_CLASSIFICATION -> navController.navigate("${ImageClassificationDestination.route}/${modelName}") TaskType.IMAGE_CLASSIFICATION -> navController.navigate("${ImageClassificationDestination.route}/${modelName}")
TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}") TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
TaskType.LLM_SINGLE_TURN -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
TaskType.IMAGE_GENERATION -> navController.navigate("${ImageGenerationDestination.route}/${modelName}") TaskType.IMAGE_GENERATION -> navController.navigate("${ImageGenerationDestination.route}/${modelName}")
TaskType.TEST_TASK_1 -> {} TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {} TaskType.TEST_TASK_2 -> {}

View file

@ -0,0 +1,21 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.preview
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
class PreviewLlmSingleTurnViewModel : LlmSingleTurnViewModel(task = TASK_TEST1)

View file

@ -34,7 +34,7 @@ class PreviewModelManagerViewModel(context: Context) :
for ((index, task) in ALL_PREVIEW_TASKS.withIndex()) { for ((index, task) in ALL_PREVIEW_TASKS.withIndex()) {
task.index = index task.index = index
for (model in task.models) { for (model in task.models) {
model.preProcess(task = task) model.preProcess()
} }
} }

View file

@ -25,7 +25,8 @@ val primaryContainerLight = Color(0xFFD0E4FF)
val onPrimaryContainerLight = Color(0xFF144A74) val onPrimaryContainerLight = Color(0xFF144A74)
val secondaryLight = Color(0xFF526070) val secondaryLight = Color(0xFF526070)
val onSecondaryLight = Color(0xFFFFFFFF) val onSecondaryLight = Color(0xFFFFFFFF)
val secondaryContainerLight = Color(0xFFD6E4F7) //val secondaryContainerLight = Color(0xFFD6E4F7)
val secondaryContainerLight = Color(0xFFC2E7FF)
val onSecondaryContainerLight = Color(0xFF3B4857) val onSecondaryContainerLight = Color(0xFF3B4857)
val tertiaryLight = Color(0xFF775A0B) val tertiaryLight = Color(0xFF775A0B)
val onTertiaryLight = Color(0xFFFFFFFF) val onTertiaryLight = Color(0xFFFFFFFF)

View file

@ -116,6 +116,7 @@ data class CustomColors(
val userBubbleBgColor: Color = Color.Transparent, val userBubbleBgColor: Color = Color.Transparent,
val agentBubbleBgColor: Color = Color.Transparent, val agentBubbleBgColor: Color = Color.Transparent,
val linkColor: Color = Color.Transparent, val linkColor: Color = Color.Transparent,
val successColor: Color = Color.Transparent,
) )
val LocalCustomColors = staticCompositionLocalOf { CustomColors() } val LocalCustomColors = staticCompositionLocalOf { CustomColors() }
@ -145,6 +146,7 @@ val lightCustomColors = CustomColors(
agentBubbleBgColor = Color(0xFFe9eef6), agentBubbleBgColor = Color(0xFFe9eef6),
userBubbleBgColor = Color(0xFF32628D), userBubbleBgColor = Color(0xFF32628D),
linkColor = Color(0xFF32628D), linkColor = Color(0xFF32628D),
successColor = Color(0xff3d860b),
) )
val darkCustomColors = CustomColors( val darkCustomColors = CustomColors(
@ -172,6 +174,7 @@ val darkCustomColors = CustomColors(
agentBubbleBgColor = Color(0xFF1b1c1d), agentBubbleBgColor = Color(0xFF1b1c1d),
userBubbleBgColor = Color(0xFF1f3760), userBubbleBgColor = Color(0xFF1f3760),
linkColor = Color(0xFF9DCAFC), linkColor = Color(0xFF9DCAFC),
successColor = Color(0xFFA1CE83),
) )
val MaterialTheme.customColors: CustomColors val MaterialTheme.customColors: CustomColors

View file

@ -92,6 +92,7 @@ class DownloadWorker(context: Context, params: WorkerParameters) :
val connection = url.openConnection() as HttpURLConnection val connection = url.openConnection() as HttpURLConnection
if (accessToken != null) { if (accessToken != null) {
Log.d(TAG, "Using access token: ${accessToken.subSequence(0, 10)}...")
connection.setRequestProperty("Authorization", "Bearer $accessToken") connection.setRequestProperty("Authorization", "Bearer $accessToken")
} }
@ -176,6 +177,7 @@ class DownloadWorker(context: Context, params: WorkerParameters) :
KEY_MODEL_DOWNLOAD_REMAINING_MS, remainingMs.toLong() KEY_MODEL_DOWNLOAD_REMAINING_MS, remainingMs.toLong()
).build() ).build()
) )
Log.d(TAG, "downloadedBytes: $downloadedBytes")
lastSetProgressTs = curTs lastSetProgressTs = curTs
} }
} }