Add support for LLM single-turn experience

This commit is contained in:
Jing Jin 2025-04-26 21:35:03 -07:00
parent 46aaee2654
commit d94fec0674
39 changed files with 2376 additions and 295 deletions

View file

@ -30,7 +30,7 @@ android {
minSdk = 24
targetSdk = 35
versionCode = 1
versionName = "20250421"
versionName = "20250428"
// Needed for HuggingFace auth workflows.
manifestPlaceholders["appAuthRedirectScheme"] = "com.google.aiedge.gallery.oauth"

View file

@ -187,6 +187,5 @@ fun GalleryTopAppBar(
else -> {}
}
}
)
}

View file

@ -23,7 +23,7 @@ import androidx.datastore.preferences.core.Preferences
import androidx.datastore.preferences.preferencesDataStore
import com.google.aiedge.gallery.data.AppContainer
import com.google.aiedge.gallery.data.DefaultAppContainer
import com.google.aiedge.gallery.data.TASKS
import com.google.aiedge.gallery.ui.common.processTasks
import com.google.aiedge.gallery.ui.theme.ThemeSettings
private val Context.dataStore: DataStore<Preferences> by preferencesDataStore(name = "app_gallery_preferences")
@ -36,12 +36,7 @@ class GalleryApplication : Application() {
super.onCreate()
// Process tasks.
for ((index, task) in TASKS.withIndex()) {
task.index = index
for (model in task.models) {
model.preProcess(task = task)
}
}
processTasks()
container = DefaultAppContainer(this, dataStore)

View file

@ -92,15 +92,13 @@ data class Model(
val imported: Boolean = false,
// The following fields are managed by the app. Don't need to set manually.
var taskType: TaskType? = null,
var instance: Any? = null,
var initializing: Boolean = false,
var configValues: Map<String, Any> = mapOf(),
var totalBytes: Long = 0L,
var accessToken: String? = null,
) {
fun preProcess(task: Task) {
this.taskType = task.type
fun preProcess() {
val configValues: MutableMap<String, Any> = mutableMapOf()
for (config in this.configs) {
configValues[config.key.label] = config.defaultValue
@ -246,6 +244,7 @@ val MODEL_LLM_GEMMA_2B_GPU_INT4: Model = Model(
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma-2b-it-gpu-int4.bin",
sizeInBytes = 1354301440L,
configs = createLlmChatConfigs(),
showBenchmarkButton = false,
info = LLM_CHAT_INFO,
learnMoreUrl = "https://huggingface.co/litert-community",
)
@ -256,6 +255,7 @@ val MODEL_LLM_GEMMA_2_2B_GPU_INT8: Model = Model(
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma2-2b-it-gpu-int8.bin",
sizeInBytes = 2627141632L,
configs = createLlmChatConfigs(),
showBenchmarkButton = false,
info = LLM_CHAT_INFO,
learnMoreUrl = "https://huggingface.co/litert-community",
)
@ -271,6 +271,7 @@ val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model(
defaultTopP = 0.95f,
accelerators = listOf(Accelerator.CPU, Accelerator.GPU)
),
showBenchmarkButton = false,
info = LLM_CHAT_INFO,
learnMoreUrl = "https://huggingface.co/litert-community/Gemma3-1B-IT",
llmPromptTemplates = listOf(
@ -299,6 +300,7 @@ val MODEL_LLM_DEEPSEEK: Model = Model(
defaultTopP = 0.7f,
accelerators = listOf(Accelerator.CPU)
),
showBenchmarkButton = false,
info = LLM_CHAT_INFO,
learnMoreUrl = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B",
)
@ -389,7 +391,7 @@ val MODELS_IMAGE_CLASSIFICATION: MutableList<Model> = mutableListOf(
MODEL_IMAGE_CLASSIFICATION_MOBILENET_V2,
)
val MODELS_LLM_CHAT: MutableList<Model> = mutableListOf(
val MODELS_LLM: MutableList<Model> = mutableListOf(
MODEL_LLM_GEMMA_2B_GPU_INT4,
MODEL_LLM_GEMMA_2_2B_GPU_INT8,
MODEL_LLM_GEMMA_3_1B_INT4,

View file

@ -18,9 +18,11 @@ package com.google.aiedge.gallery.data
import androidx.annotation.StringRes
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.outlined.Forum
import androidx.compose.material.icons.outlined.Widgets
import androidx.compose.material.icons.rounded.ImageSearch
import androidx.compose.runtime.MutableState
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.mutableLongStateOf
import androidx.compose.ui.graphics.vector.ImageVector
import com.google.aiedge.gallery.R
@ -30,6 +32,7 @@ enum class TaskType(val label: String) {
IMAGE_CLASSIFICATION("Image Classification"),
IMAGE_GENERATION("Image Generation"),
LLM_CHAT("LLM Chat"),
LLM_SINGLE_TURN("LLM Use Cases"),
TEST_TASK_1("Test task 1"),
TEST_TASK_2("Test task 2")
@ -67,7 +70,7 @@ data class Task(
// The following fields are managed by the app. Don't need to set manually.
var index: Int = -1,
val updateTrigger: MutableState<Long> = mutableStateOf(0)
val updateTrigger: MutableState<Long> = mutableLongStateOf(0)
)
val TASK_TEXT_CLASSIFICATION = Task(
@ -87,9 +90,19 @@ val TASK_IMAGE_CLASSIFICATION = Task(
val TASK_LLM_CHAT = Task(
type = TaskType.LLM_CHAT,
iconVectorResourceId = R.drawable.chat_spark,
models = MODELS_LLM_CHAT,
description = "Chat? with a on-device large language model",
icon = Icons.Outlined.Forum,
models = MODELS_LLM,
description = "Chat with a on-device large language model",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
)
val TASK_LLM_SINGLE_TURN = Task(
type = TaskType.LLM_SINGLE_TURN,
icon = Icons.Outlined.Widgets,
models = MODELS_LLM,
description = "Single turn use cases with on-device large language model",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
@ -108,9 +121,10 @@ val TASK_IMAGE_GENERATION = Task(
/** All tasks. */
val TASKS: List<Task> = listOf(
// TASK_TEXT_CLASSIFICATION,
// TASK_IMAGE_CLASSIFICATION,
TASK_IMAGE_CLASSIFICATION,
TASK_IMAGE_GENERATION,
TASK_LLM_CHAT,
TASK_LLM_SINGLE_TURN,
)
fun getModelByName(name: String): Model? {

View file

@ -25,6 +25,7 @@ import com.google.aiedge.gallery.GalleryApplication
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationViewModel
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationViewModel
import com.google.aiedge.gallery.ui.llmchat.LlmChatViewModel
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.textclassification.TextClassificationViewModel
@ -56,6 +57,11 @@ object ViewModelProvider {
LlmChatViewModel()
}
// Initializer for LlmSingleTurnViewModel..
initializer {
LlmSingleTurnViewModel()
}
initializer {
ImageGenerationViewModel()
}

View file

@ -0,0 +1,213 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.ArrowBack
import androidx.compose.material.icons.rounded.ArrowDropDown
import androidx.compose.material.icons.rounded.Settings
import androidx.compose.material3.CenterAlignedTopAppBar
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ModalBottomSheet
import androidx.compose.material3.Text
import androidx.compose.material3.rememberModalBottomSheetState
import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.vectorResource
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.chat.ConfigDialog
import com.google.aiedge.gallery.ui.common.modelitem.StatusIcon
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun ModelPageAppBar(
task: Task,
model: Model,
modelManagerViewModel: ModelManagerViewModel,
onBackClicked: () -> Unit,
onModelSelected: (Model) -> Unit,
modifier: Modifier = Modifier,
onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit = { _, _ -> },
) {
var showConfigDialog by remember { mutableStateOf(false) }
var showModelPicker by remember { mutableStateOf(false) }
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
val context = LocalContext.current
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[model.name]
CenterAlignedTopAppBar(title = {
Column(horizontalAlignment = Alignment.CenterHorizontally) {
// Task type.
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp)
) {
Icon(
task.icon ?: ImageVector.vectorResource(task.iconVectorResourceId!!),
tint = getTaskIconColor(task = task),
modifier = Modifier.size(16.dp),
contentDescription = "",
)
Text(
task.type.label,
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.SemiBold),
color = getTaskIconColor(task = task)
)
}
// Model name.
Row(verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(2.dp),
modifier = Modifier
.clip(CircleShape)
.background(MaterialTheme.colorScheme.surfaceContainerHigh)
.clickable {
showModelPicker = true
}
.padding(start = 8.dp, end = 2.dp)) {
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
Text(
model.name,
style = MaterialTheme.typography.labelSmall,
modifier = Modifier.padding(start = 4.dp),
)
Icon(
Icons.Rounded.ArrowDropDown,
modifier = Modifier.size(20.dp),
contentDescription = "",
)
}
}
}, modifier = modifier,
// The back button.
navigationIcon = {
IconButton(onClick = onBackClicked) {
Icon(
imageVector = Icons.AutoMirrored.Rounded.ArrowBack,
contentDescription = "",
)
}
},
// The config button for the model (if existed).
actions = {
if (model.configs.isNotEmpty() && curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
IconButton(onClick = { showConfigDialog = true }) {
Icon(
imageVector = Icons.Rounded.Settings,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary
)
}
}
})
// Config dialog.
if (showConfigDialog) {
ConfigDialog(
title = "Model configs",
configs = model.configs,
initialValues = model.configValues,
onDismissed = { showConfigDialog = false },
onOk = { curConfigValues ->
// Hide config dialog.
showConfigDialog = false
// Check if the configs are changed or not. Also check if the model needs to be
// re-initialized.
var same = true
var needReinitialization = false
for (config in model.configs) {
val key = config.key.label
val oldValue = convertValueToTargetType(
value = model.configValues.getValue(key), valueType = config.valueType
)
val newValue = convertValueToTargetType(
value = curConfigValues.getValue(key), valueType = config.valueType
)
if (oldValue != newValue) {
same = false
if (config.needReinitialization) {
needReinitialization = true
}
break
}
}
if (same) {
return@ConfigDialog
}
// Save the config values to Model.
val oldConfigValues = model.configValues
model.configValues = curConfigValues
// Force to re-initialize the model with the new configs.
if (needReinitialization) {
modelManagerViewModel.initializeModel(
context = context, task = task, model = model, force = true
)
}
// Notify.
onConfigChanged(oldConfigValues, model.configValues)
},
)
}
// Model picker.
if (showModelPicker) {
ModalBottomSheet(
onDismissRequest = { showModelPicker = false },
sheetState = sheetState,
) {
ModelPicker(
task = task,
modelManagerViewModel = modelManagerViewModel,
onModelSelected = { model ->
showModelPicker = false
onModelSelected(model)
}
)
}
}
}

View file

@ -0,0 +1,132 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.width
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.outlined.CheckCircle
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.vectorResource
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.modelitem.StatusIcon
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.labelSmallNarrow
@Composable
fun ModelPicker(
task: Task,
modelManagerViewModel: ModelManagerViewModel,
onModelSelected: (Model) -> Unit
) {
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
Column(modifier = Modifier.padding(bottom = 8.dp)) {
// Title
Row(
modifier = Modifier
.padding(horizontal = 16.dp)
.padding(top = 4.dp, bottom = 4.dp),
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(8.dp),
) {
Icon(
task.icon ?: ImageVector.vectorResource(task.iconVectorResourceId!!),
tint = getTaskIconColor(task = task),
modifier = Modifier.size(16.dp),
contentDescription = "",
)
Text(
"${task.type.label} models",
modifier = Modifier.fillMaxWidth(),
style = MaterialTheme.typography.titleMedium,
color = getTaskIconColor(task = task),
)
}
// Model list.
for (model in task.models) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.SpaceBetween,
modifier = Modifier
.fillMaxWidth()
.clickable {
onModelSelected(model)
}
.padding(horizontal = 16.dp, vertical = 8.dp),
) {
Spacer(modifier = Modifier.width(24.dp))
Column(modifier = Modifier.weight(1f)) {
Text(model.name, style = MaterialTheme.typography.bodyMedium)
Row(horizontalArrangement = Arrangement.spacedBy(4.dp)) {
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
Text(
model.sizeInBytes.humanReadableSize(),
color = MaterialTheme.colorScheme.secondary,
style = labelSmallNarrow.copy(lineHeight = 10.sp)
)
}
}
if (model.name == modelManagerUiState.selectedModel.name) {
Icon(
Icons.Outlined.CheckCircle,
modifier = Modifier.size(16.dp),
contentDescription = ""
)
}
}
}
}
}
@Preview(showBackground = true)
@Composable
fun ModelPickerPreview() {
val context = LocalContext.current
GalleryTheme {
ModelPicker(
task = TASK_TEST1,
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
onModelSelected = {},
)
}
}

View file

@ -29,6 +29,7 @@ import androidx.core.content.ContextCompat
import androidx.core.content.FileProvider
import com.google.aiedge.gallery.data.Config
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASKS
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.ValueType
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkResult
@ -57,6 +58,9 @@ interface LatencyProvider {
val latencyMs: Float
}
private const val START_THINKING = "***Thinking...***"
private const val DONE_THINKING = "***Done thinking***"
/** Format the bytes into a human-readable format. */
fun Long.humanReadableSize(si: Boolean = true, extraDecimalForGbAndAbove: Boolean = false): String {
val bytes = this
@ -452,3 +456,33 @@ fun cleanUpMediapipeTaskErrorMessage(message: String): String {
}
return message
}
fun processTasks() {
for ((index, task) in TASKS.withIndex()) {
task.index = index
for (model in task.models) {
model.preProcess()
}
}
}
fun processLlmResponse(response: String): String {
// Add "thinking" and "done thinking" around the thinking content.
var newContent = response
.replace("<think>", "$START_THINKING\n")
.replace("</think>", "\n$DONE_THINKING")
// Remove empty thinking content.
val endThinkingIndex = newContent.indexOf(DONE_THINKING)
if (endThinkingIndex >= 0) {
val thinkingContent =
newContent.substring(0, endThinkingIndex + DONE_THINKING.length)
.replace(START_THINKING, "")
.replace(DONE_THINKING, "")
if (thinkingContent.isBlank()) {
newContent = newContent.substring(endThinkingIndex + DONE_THINKING.length)
}
}
return newContent
}

View file

@ -38,6 +38,7 @@ import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.width
import androidx.compose.foundation.layout.wrapContentHeight
import androidx.compose.foundation.layout.wrapContentWidth
import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items
import androidx.compose.foundation.lazy.rememberLazyListState
@ -334,7 +335,10 @@ fun ChatPanel(
is ChatMessageBenchmarkResult -> MessageBodyBenchmark(message = message)
// Benchmark LLM result.
is ChatMessageBenchmarkLlmResult -> MessageBodyBenchmarkLlm(message = message)
is ChatMessageBenchmarkLlmResult -> MessageBodyBenchmarkLlm(
message = message,
modifier = Modifier.wrapContentWidth()
)
else -> {}
}
@ -346,7 +350,7 @@ fun ChatPanel(
) {
LatencyText(message = message)
// A button to show stats for the LLM message.
if (selectedModel.taskType == TaskType.LLM_CHAT && message is ChatMessageText
if (task.type == TaskType.LLM_CHAT && message is ChatMessageText
// This means we only want to show the action button when the message is done
// generating, at which point the latency will be set.
&& message.latencyMs >= 0
@ -403,21 +407,17 @@ fun ChatPanel(
}
// Benchmark button
// if (selectedModel.showBenchmarkButton) {
// MessageActionButton(
// label = stringResource(R.string.benchmark),
// icon = Icons.Outlined.Timer,
// onClick = {
// if (selectedModel.taskType == TaskType.LLM_CHAT) {
// onBenchmarkClicked(selectedModel, message, 0, 0)
// } else {
// showBenchmarkConfigsDialog = true
// benchmarkMessage.value = message
// }
// },
// enabled = !uiState.inProgress
// )
// }
if (selectedModel.showBenchmarkButton) {
MessageActionButton(
label = stringResource(R.string.benchmark),
icon = Icons.Outlined.Timer,
onClick = {
showBenchmarkConfigsDialog = true
benchmarkMessage.value = message
},
enabled = !uiState.inProgress
)
}
}
}
}
@ -443,7 +443,7 @@ fun ChatPanel(
// Chat input
when (chatInputType) {
ChatInputType.TEXT -> {
val isLlmTask = selectedModel.taskType == TaskType.LLM_CHAT
val isLlmTask = task.type == TaskType.LLM_CHAT
val notLlmStartScreen = !(messages.size == 1 && messages[0] is ChatMessagePromptTemplates)
MessageInputText(
modelManagerViewModel = modelManagerViewModel,

View file

@ -18,18 +18,10 @@ package com.google.aiedge.gallery.ui.common.chat
import android.util.Log
import androidx.activity.compose.BackHandler
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.animation.AnimatedVisibility
import androidx.compose.animation.fadeIn
import androidx.compose.animation.fadeOut
import androidx.compose.animation.scaleIn
import androidx.compose.animation.scaleOut
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.pager.HorizontalPager
import androidx.compose.foundation.pager.rememberPagerState
@ -40,29 +32,21 @@ import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.tooling.preview.Preview
import com.google.aiedge.gallery.GalleryTopAppBar
import com.google.aiedge.gallery.data.AppBarAction
import com.google.aiedge.gallery.data.AppBarActionType
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.checkNotificationPermissionAndStartDownload
import com.google.aiedge.gallery.ui.common.ModelPageAppBar
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.PreviewChatModel
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlin.math.absoluteValue
@ -77,7 +61,6 @@ private const val TAG = "AGChatView"
* manages model initialization, cleanup, and download status, and handles navigation and system
* back gestures.
*/
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun ChatView(
task: Task,
@ -96,34 +79,29 @@ fun ChatView(
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val selectedModel = modelManagerUiState.selectedModel
val pagerState = rememberPagerState(initialPage = task.models.indexOf(selectedModel),
val pagerState = rememberPagerState(
initialPage = task.models.indexOf(selectedModel),
pageCount = { task.models.size })
val context = LocalContext.current
val scope = rememberCoroutineScope()
val launcher = rememberLauncherForActivityResult(
ActivityResultContracts.RequestPermission()
) {
modelManagerViewModel.downloadModel(task = task, model = selectedModel)
}
val handleNavigateUp = {
navigateUp()
// clean up all models.
scope.launch(Dispatchers.Default) {
for (model in task.models) {
modelManagerViewModel.cleanupModel(model = model)
modelManagerViewModel.cleanupModel(task = task, model = model)
}
}
}
// Initialize model when model/download state changes.
val status = modelManagerUiState.modelDownloadStatus[selectedModel.name]
LaunchedEffect(status, selectedModel.name) {
if (status?.status == ModelDownloadStatusType.SUCCEEDED) {
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name]
LaunchedEffect(curDownloadStatus, selectedModel.name) {
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
Log.d(TAG, "Initializing model '${selectedModel.name}' from ChatView launched effect")
modelManagerViewModel.initializeModel(context, model = selectedModel)
modelManagerViewModel.initializeModel(context, task = task, model = selectedModel)
}
}
@ -135,7 +113,7 @@ fun ChatView(
"Pager settled on model '${curSelectedModel.name}' from '${selectedModel.name}'. Updating selected model."
)
if (curSelectedModel.name != selectedModel.name) {
modelManagerViewModel.cleanupModel(model = selectedModel)
modelManagerViewModel.cleanupModel(task = task, model = selectedModel)
}
modelManagerViewModel.selectModel(curSelectedModel)
}
@ -146,24 +124,36 @@ fun ChatView(
}
Scaffold(modifier = modifier, topBar = {
GalleryTopAppBar(
title = task.type.label,
leftAction = AppBarAction(actionType = AppBarActionType.NAVIGATE_UP, actionFn = {
ModelPageAppBar(
task = task,
model = selectedModel,
modelManagerViewModel = modelManagerViewModel,
onConfigChanged = { old, new ->
viewModel.addConfigChangedMessage(
oldConfigValues = old,
newConfigValues = new,
model = selectedModel
)
},
onBackClicked = {
handleNavigateUp()
}),
rightAction = AppBarAction(actionType = AppBarActionType.NO_ACTION, actionFn = {}),
},
onModelSelected = { model ->
scope.launch {
pagerState.animateScrollToPage(task.models.indexOf(model))
}
},
)
}) { innerPadding ->
Box {
// A horizontal scrollable pager to switch between models.
HorizontalPager(state = pagerState) { pageIndex ->
val curSelectedModel = task.models[pageIndex]
val curModelDownloadStatus = modelManagerUiState.modelDownloadStatus[curSelectedModel.name]
// Calculate the alpha of the current page based on how far they are from the center.
val pageOffset = (
(pagerState.currentPage - pageIndex) + pagerState
.currentPageOffsetFraction
).absoluteValue
val pageOffset =
((pagerState.currentPage - pageIndex) + pagerState.currentPageOffsetFraction).absoluteValue
val curAlpha = 1f - pageOffset.coerceIn(0f, 1f)
Column(
@ -172,91 +162,14 @@ fun ChatView(
.fillMaxSize()
.background(MaterialTheme.colorScheme.surface)
) {
// Model selector at the top.
ModelSelector(
ModelDownloadStatusInfoPanel(
model = curSelectedModel,
task = task,
modelManagerViewModel = modelManagerViewModel,
onConfigChanged = { old, new ->
viewModel.addConfigChangedMessage(
oldConfigValues = old,
newConfigValues = new,
model = curSelectedModel
modelManagerViewModel = modelManagerViewModel
)
},
modifier = Modifier.fillMaxWidth(),
contentAlpha = curAlpha,
)
// Manages the conditional display of UI elements (download model button and downloading
// animation) based on the corresponding download status.
//
// It uses delayed visibility ensuring they are shown only after a short delay if their
// respective conditions remain true. This prevents UI flickering and provides a smoother
// user experience.
val curStatus = modelManagerUiState.modelDownloadStatus[curSelectedModel.name]
var shouldShowDownloadingAnimation by remember { mutableStateOf(false) }
var downloadingAnimationConditionMet by remember { mutableStateOf(false) }
var shouldShowDownloadModelButton by remember { mutableStateOf(false) }
var downloadModelButtonConditionMet by remember { mutableStateOf(false) }
downloadingAnimationConditionMet =
curStatus?.status == ModelDownloadStatusType.IN_PROGRESS ||
curStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED ||
curStatus?.status == ModelDownloadStatusType.UNZIPPING
downloadModelButtonConditionMet =
curStatus?.status == ModelDownloadStatusType.FAILED ||
curStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED
LaunchedEffect(downloadingAnimationConditionMet) {
if (downloadingAnimationConditionMet) {
delay(100)
shouldShowDownloadingAnimation = true
} else {
shouldShowDownloadingAnimation = false
}
}
LaunchedEffect(downloadModelButtonConditionMet) {
if (downloadModelButtonConditionMet) {
delay(700)
shouldShowDownloadModelButton = true
} else {
shouldShowDownloadModelButton = false
}
}
AnimatedVisibility(
visible = shouldShowDownloadingAnimation,
enter = scaleIn(initialScale = 0.9f) + fadeIn(),
exit = scaleOut(targetScale = 0.9f) + fadeOut()
) {
Box(
modifier = Modifier.fillMaxSize(),
contentAlignment = Alignment.Center
) {
ModelDownloadingAnimation()
}
}
AnimatedVisibility(
visible = shouldShowDownloadModelButton,
enter = fadeIn(),
exit = fadeOut()
) {
ModelNotDownloaded(modifier = Modifier.weight(1f), onClicked = {
checkNotificationPermissionAndStartDownload(
context = context,
launcher = launcher,
modelManagerViewModel = modelManagerViewModel,
task = task,
model = curSelectedModel
)
})
}
// The main messages panel.
if (curStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
if (curModelDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
ChatPanel(
modelManagerViewModel = modelManagerViewModel,
task = task,

View file

@ -20,13 +20,12 @@ import android.util.Log
import androidx.lifecycle.ViewModel
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.processLlmResponse
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.update
private const val TAG = "AGChatViewModel"
private const val START_THINKING = "***Thinking...***"
private const val DONE_THINKING = "***Done thinking***"
data class ChatUiState(
/**
@ -121,26 +120,7 @@ open class ChatViewModel(val task: Task) : ViewModel() {
if (newMessages.size > 0) {
val lastMessage = newMessages.last()
if (lastMessage is ChatMessageText) {
var newContent = "${lastMessage.content}${partialContent}"
// TODO: special handling for deepseek to remove the <think> tag.
// Add "thinking" and "done thinking" around the thinking content.
newContent = newContent
.replace("<think>", "$START_THINKING\n")
.replace("</think>", "\n$DONE_THINKING")
// Remove empty thinking content.
val endThinkingIndex = newContent.indexOf(DONE_THINKING)
if (endThinkingIndex >= 0) {
val thinkingContent =
newContent.substring(0, endThinkingIndex + DONE_THINKING.length)
.replace(START_THINKING, "")
.replace(DONE_THINKING, "")
if (thinkingContent.isBlank()) {
newContent = newContent.substring(endThinkingIndex + DONE_THINKING.length)
}
}
val newContent = processLlmResponse(response = "${lastMessage.content}${partialContent}")
val newLastMessage = ChatMessageText(
content = newContent,
side = lastMessage.side,

View file

@ -45,7 +45,7 @@ fun MarkdownText(
ProvideTextStyle(
value = TextStyle(
fontSize = fontSize,
lineHeight = fontSize * 1.2,
lineHeight = fontSize * 1.4,
)
) {
RichText(

View file

@ -48,15 +48,16 @@ fun MessageActionButton(
label: String,
icon: ImageVector,
onClick: () -> Unit,
modifier: Modifier = Modifier,
enabled: Boolean = true
) {
val modifier = Modifier
val curModifier = modifier
.padding(top = 4.dp)
.clip(CircleShape)
.background(if (enabled) MaterialTheme.colorScheme.secondaryContainer else MaterialTheme.colorScheme.surfaceContainerHigh)
val alpha: Float = if (enabled) 1.0f else 0.3f
Row(
modifier = if (enabled) modifier.clickable { onClick() } else modifier,
modifier = if (enabled) curModifier.clickable { onClick() } else modifier,
verticalAlignment = Alignment.CenterVertically,
) {
Icon(

View file

@ -19,8 +19,8 @@ package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.wrapContentWidth
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.tooling.preview.Preview
@ -33,16 +33,14 @@ import com.google.aiedge.gallery.ui.theme.GalleryTheme
* This function renders benchmark statistics (e.g., various token speed) in data cards
*/
@Composable
fun MessageBodyBenchmarkLlm(message: ChatMessageBenchmarkLlmResult) {
fun MessageBodyBenchmarkLlm(message: ChatMessageBenchmarkLlmResult, modifier: Modifier = Modifier) {
Column(
modifier = Modifier
.padding(12.dp)
.wrapContentWidth(),
modifier = modifier.padding(12.dp),
verticalArrangement = Arrangement.spacedBy(8.dp)
) {
// Data cards.
Row(
modifier = Modifier.wrapContentWidth(), horizontalArrangement = Arrangement.spacedBy(16.dp)
modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.SpaceBetween
) {
for (stat in message.orderedStats) {
DataCard(

View file

@ -82,7 +82,7 @@ fun MessageBodyPromptTemplates(
style = MaterialTheme.typography.titleSmall,
modifier = Modifier
.fillMaxWidth()
.offset(y = -4.dp),
.offset(y = (-4).dp),
textAlign = TextAlign.Center,
)
}
@ -140,7 +140,7 @@ fun MessageBodyPromptTemplatesPreview() {
for ((index, task) in ALL_PREVIEW_TASKS.withIndex()) {
task.index = index
for (model in task.models) {
model.preProcess(task = task)
model.preProcess()
}
}

View file

@ -70,7 +70,6 @@ import com.google.aiedge.gallery.ui.theme.GalleryTheme
* This function renders a row containing a text field for message input and a send button.
* It handles message composition, input validation, and sending messages.
*/
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun MessageInputText(
modelManagerViewModel: ModelManagerViewModel,
@ -190,7 +189,7 @@ fun MessageInputText(
Icons.AutoMirrored.Rounded.Send,
contentDescription = "",
modifier = Modifier.offset(x = 2.dp),
tint = if (inProgress) MaterialTheme.colorScheme.surfaceContainerHigh else MaterialTheme.colorScheme.primary
tint = MaterialTheme.colorScheme.primary
)
}
}

View file

@ -0,0 +1,119 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.animation.AnimatedVisibility
import androidx.compose.animation.fadeIn
import androidx.compose.animation.fadeOut
import androidx.compose.animation.scaleIn
import androidx.compose.animation.scaleOut
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.DownloadAndTryButton
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import kotlinx.coroutines.delay
@Composable
fun ModelDownloadStatusInfoPanel(
model: Model,
task: Task,
modelManagerViewModel: ModelManagerViewModel
) {
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
// Manages the conditional display of UI elements (download model button and downloading
// animation) based on the corresponding download status.
//
// It uses delayed visibility ensuring they are shown only after a short delay if their
// respective conditions remain true. This prevents UI flickering and provides a smoother
// user experience.
val curStatus = modelManagerUiState.modelDownloadStatus[model.name]
var shouldShowDownloadingAnimation by remember { mutableStateOf(false) }
var downloadingAnimationConditionMet by remember { mutableStateOf(false) }
var shouldShowDownloadModelButton by remember { mutableStateOf(false) }
var downloadModelButtonConditionMet by remember { mutableStateOf(false) }
downloadingAnimationConditionMet =
curStatus?.status == ModelDownloadStatusType.IN_PROGRESS || curStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED || curStatus?.status == ModelDownloadStatusType.UNZIPPING
downloadModelButtonConditionMet =
curStatus?.status == ModelDownloadStatusType.FAILED || curStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED
LaunchedEffect(downloadingAnimationConditionMet) {
if (downloadingAnimationConditionMet) {
delay(100)
shouldShowDownloadingAnimation = true
} else {
shouldShowDownloadingAnimation = false
}
}
LaunchedEffect(downloadModelButtonConditionMet) {
if (downloadModelButtonConditionMet) {
delay(700)
shouldShowDownloadModelButton = true
} else {
shouldShowDownloadModelButton = false
}
}
AnimatedVisibility(
visible = shouldShowDownloadingAnimation,
enter = scaleIn(initialScale = 0.9f) + fadeIn(),
exit = scaleOut(targetScale = 0.9f) + fadeOut()
) {
Box(
modifier = Modifier.fillMaxSize(), contentAlignment = Alignment.Center
) {
ModelDownloadingAnimation(
model = model, task = task, modelManagerViewModel = modelManagerViewModel
)
}
}
AnimatedVisibility(
visible = shouldShowDownloadModelButton, enter = fadeIn(), exit = fadeOut()
) {
Column(
modifier = Modifier.fillMaxSize(),
verticalArrangement = Arrangement.Center,
horizontalAlignment = Alignment.CenterHorizontally
) {
DownloadAndTryButton(
task = task,
model = model,
enabled = true,
needToDownloadFirst = true,
modelManagerViewModel = modelManagerViewModel,
onClicked = {}
)
}
}
}

View file

@ -24,6 +24,7 @@ import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
@ -32,23 +33,40 @@ import androidx.compose.foundation.layout.width
import androidx.compose.foundation.lazy.grid.GridCells
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
import androidx.compose.foundation.lazy.grid.itemsIndexed
import androidx.compose.material3.LinearProgressIndicator
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.getValue
import androidx.compose.runtime.remember
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.ColorFilter
import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.formatToHourMinSecond
import com.google.aiedge.gallery.ui.common.getTaskIconColor
import com.google.aiedge.gallery.ui.common.humanReadableSize
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.MODEL_TEST1
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.labelSmallNarrow
import kotlinx.coroutines.delay
import kotlin.math.cos
import kotlin.math.pow
@ -66,8 +84,19 @@ private const val END_SCALE = 0.6f
* scaling and rotation effect.
*/
@Composable
fun ModelDownloadingAnimation() {
fun ModelDownloadingAnimation(
model: Model,
task: Task,
modelManagerViewModel: ModelManagerViewModel
) {
val scale = remember { Animatable(END_SCALE) }
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val downloadStatus by remember {
derivedStateOf { modelManagerUiState.modelDownloadStatus[model.name] }
}
val inProgress = downloadStatus?.status == ModelDownloadStatusType.IN_PROGRESS
val isPartiallyDownloaded = downloadStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED
var curDownloadProgress = 0f
LaunchedEffect(Unit) { // Run this once
while (true) {
@ -93,6 +122,21 @@ fun ModelDownloadingAnimation() {
}
}
// Failure message.
val curDownloadStatus = downloadStatus
if (curDownloadStatus != null && curDownloadStatus.status == ModelDownloadStatusType.FAILED) {
Row(verticalAlignment = Alignment.CenterVertically) {
Text(
curDownloadStatus.errorMessage,
color = MaterialTheme.colorScheme.error,
style = labelSmallNarrow,
overflow = TextOverflow.Ellipsis,
)
}
}
// No failure
else {
Column(
horizontalAlignment = Alignment.CenterHorizontally,
modifier = Modifier.offset(y = -GRID_SIZE / 8)
@ -146,6 +190,78 @@ fun ModelDownloadingAnimation() {
}
}
// Download stats
var sizeLabel = model.totalBytes.humanReadableSize()
if (curDownloadStatus != null) {
// For in-progress model, show {receivedSize} / {totalSize} - {rate} - {remainingTime}
if (inProgress || isPartiallyDownloaded) {
var totalSize = curDownloadStatus.totalBytes
if (totalSize == 0L) {
totalSize = model.totalBytes
}
sizeLabel =
"${curDownloadStatus.receivedBytes.humanReadableSize(extraDecimalForGbAndAbove = true)} of ${totalSize.humanReadableSize()}"
if (curDownloadStatus.bytesPerSecond > 0) {
sizeLabel =
"$sizeLabel · ${curDownloadStatus.bytesPerSecond.humanReadableSize()} / s"
if (curDownloadStatus.remainingMs >= 0) {
sizeLabel =
"$sizeLabel · ${curDownloadStatus.remainingMs.formatToHourMinSecond()} left"
}
}
if (isPartiallyDownloaded) {
sizeLabel = "$sizeLabel (resuming...)"
}
curDownloadProgress =
curDownloadStatus.receivedBytes.toFloat() / curDownloadStatus.totalBytes.toFloat()
if (curDownloadProgress.isNaN()) {
curDownloadProgress = 0f
}
}
// Status for unzipping.
else if (curDownloadStatus.status == ModelDownloadStatusType.UNZIPPING) {
sizeLabel = "Unzipping..."
}
Text(
sizeLabel,
color = MaterialTheme.colorScheme.secondary,
style = labelSmallNarrow.copy(fontSize = 9.sp, lineHeight = 10.sp),
textAlign = TextAlign.Center,
overflow = TextOverflow.Visible,
modifier = Modifier
.padding(bottom = 4.dp)
)
}
// Download progress.
if (inProgress || isPartiallyDownloaded) {
val animatedProgress = remember { Animatable(0f) }
LinearProgressIndicator(
progress = { animatedProgress.value },
color = getTaskIconColor(task = task),
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
modifier = Modifier
.fillMaxWidth()
.padding(bottom = 36.dp)
.padding(horizontal = 36.dp)
)
LaunchedEffect(curDownloadProgress) {
animatedProgress.animateTo(curDownloadProgress, animationSpec = tween(150))
}
}
// Unzipping progress.
else if (downloadStatus?.status == ModelDownloadStatusType.UNZIPPING) {
LinearProgressIndicator(
color = getTaskIconColor(task = task),
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
modifier = Modifier
.fillMaxWidth()
.padding(bottom = 36.dp)
.padding(horizontal = 36.dp)
)
}
Text(
"Feel free to switch apps or lock your device.\n"
+ "The download will continue in the background.\n"
@ -156,6 +272,8 @@ fun ModelDownloadingAnimation() {
}
}
}
// Custom Easing function for a multi-bounce effect
fun multiBounceEasing(bounces: Int, decay: Float): Easing = Easing { x ->
if (x == 1f) {
@ -168,9 +286,15 @@ fun multiBounceEasing(bounces: Int, decay: Float): Easing = Easing { x ->
@Preview(showBackground = true)
@Composable
fun ModelDownloadingAnimationPreview() {
val context = LocalContext.current
GalleryTheme {
Row(modifier = Modifier.padding(16.dp)) {
ModelDownloadingAnimation()
ModelDownloadingAnimation(
model = MODEL_TEST1,
task = TASK_TEST1,
modelManagerViewModel = PreviewModelManagerViewModel(context = context)
)
}
}
}

View file

@ -63,7 +63,8 @@ fun ModelSelector(
) {
Box(
modifier = Modifier
.fillMaxWidth().padding(bottom = 8.dp),
.fillMaxWidth()
.padding(bottom = 8.dp),
contentAlignment = Alignment.Center
) {
// Model row.
@ -134,7 +135,12 @@ fun ModelSelector(
// Force to re-initialize the model with the new configs.
if (needReinitialization) {
modelManagerViewModel.initializeModel(context = context, model = model, force = true)
modelManagerViewModel.initializeModel(
context = context,
task = task,
model = model,
force = true
)
}
// Notify.

View file

@ -181,7 +181,5 @@ fun ModelNameAndStatus(
.padding(top = 2.dp),
)
}
}
}

View file

@ -23,8 +23,10 @@ import androidx.compose.foundation.layout.size
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.outlined.HelpOutline
import androidx.compose.material.icons.filled.DownloadForOffline
import androidx.compose.material.icons.rounded.Downloading
import androidx.compose.material.icons.rounded.Error
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
@ -34,6 +36,7 @@ import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.ModelDownloadStatus
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.customColors
/**
* Composable function to display an icon representing the download status of a model.
@ -56,7 +59,7 @@ fun StatusIcon(downloadStatus: ModelDownloadStatus?, modifier: Modifier = Modifi
ModelDownloadStatusType.SUCCEEDED -> {
Icon(
Icons.Filled.DownloadForOffline,
tint = Color(0xff3d860b),
tint = MaterialTheme.customColors.successColor,
contentDescription = "",
modifier = Modifier.size(14.dp)
)
@ -69,6 +72,12 @@ fun StatusIcon(downloadStatus: ModelDownloadStatus?, modifier: Modifier = Modifi
modifier = Modifier.size(14.dp)
)
ModelDownloadStatusType.IN_PROGRESS -> Icon(
Icons.Rounded.Downloading,
contentDescription = "",
modifier = Modifier.size(14.dp)
)
else -> {}
}
}

View file

@ -73,7 +73,6 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip
import androidx.compose.ui.draw.scale
import androidx.compose.ui.focus.focusModifier
import androidx.compose.ui.graphics.Brush
import androidx.compose.ui.input.nestedscroll.nestedScroll
import androidx.compose.ui.layout.layout
@ -91,7 +90,6 @@ import com.google.aiedge.gallery.data.AppBarAction
import com.google.aiedge.gallery.data.AppBarActionType
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.ImportedModelInfo
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.TaskIcon
import com.google.aiedge.gallery.ui.common.getTaskBgColor
@ -275,7 +273,6 @@ fun HomeScreen(
onDismiss = { showImportingDialog = false },
onDone = {
modelManagerViewModel.addImportedLlmModel(
task = TASK_LLM_CHAT,
info = it,
)
showImportingDialog = false

View file

@ -30,7 +30,7 @@ private const val TAG = "AGLlmChatModelHelper"
typealias ResultListener = (partialResult: String, done: Boolean) -> Unit
typealias CleanUpListener = () -> Unit
data class LlmModelInstance(val engine: LlmInference, val session: LlmInferenceSession)
data class LlmModelInstance(val engine: LlmInference, var session: LlmInferenceSession)
object LlmChatModelHelper {
// Indexed by model name.
@ -74,6 +74,24 @@ object LlmChatModelHelper {
onDone("")
}
fun resetSession(model: Model) {
val instance = model.instance as LlmModelInstance? ?: return
val session = instance.session
session.close()
val inference = instance.engine
val topK = model.getIntConfigValue(key = ConfigKey.TOPK, defaultValue = DEFAULT_TOPK)
val topP = model.getFloatConfigValue(key = ConfigKey.TOPP, defaultValue = DEFAULT_TOPP)
val temperature =
model.getFloatConfigValue(key = ConfigKey.TEMPERATURE, defaultValue = DEFAULT_TEMPERATURE)
val newSession = LlmInferenceSession.createFromOptions(
inference,
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
.setTemperature(temperature).build()
)
instance.session = newSession
}
fun cleanUp(model: Model) {
if (model.instance == null) {
return
@ -99,7 +117,11 @@ object LlmChatModelHelper {
input: String,
resultListener: ResultListener,
cleanUpListener: CleanUpListener,
singleTurn: Boolean = false,
) {
if (singleTurn) {
resetSession(model = model)
}
val instance = model.instance as LlmModelInstance
// Set listener.

View file

@ -32,7 +32,7 @@ import kotlinx.coroutines.launch
private const val TAG = "AGLlmChatViewModel"
private val STATS = listOf(
Stat(id = "time_to_first_token", label = "Time to 1st token", unit = "sec"),
Stat(id = "time_to_first_token", label = "1st token", unit = "sec"),
Stat(id = "prefill_speed", label = "Prefill speed", unit = "tokens/s"),
Stat(id = "decode_speed", label = "Decode speed", unit = "tokens/s"),
Stat(id = "latency", label = "Latency", unit = "sec")

View file

@ -0,0 +1,206 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.llmsingleturn
import android.util.Log
import androidx.activity.compose.BackHandler
import androidx.compose.animation.AnimatedVisibility
import androidx.compose.animation.fadeIn
import androidx.compose.animation.fadeOut
import androidx.compose.animation.scaleIn
import androidx.compose.animation.scaleOut
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.calculateStartPadding
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Scaffold
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalLayoutDirection
import androidx.compose.ui.tooling.preview.Preview
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.ui.ViewModelProvider
import com.google.aiedge.gallery.ui.common.ModelPageAppBar
import com.google.aiedge.gallery.ui.common.chat.ModelDownloadStatusInfoPanel
import com.google.aiedge.gallery.ui.common.chat.ModelInitializationStatusChip
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.PreviewLlmSingleTurnViewModel
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.customColors
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.serialization.Serializable
/** Navigation destination data */
object LlmSingleTurnDestination {
@Serializable
val route = "LlmSingleTurnRoute"
}
private const val TAG = "AGLlmSingleTurnScreen"
@Composable
fun LlmSingleTurnScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
viewModel: LlmSingleTurnViewModel = viewModel(
factory = ViewModelProvider.Factory
),
) {
val task = viewModel.task
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val selectedModel = modelManagerUiState.selectedModel
val scope = rememberCoroutineScope()
val context = LocalContext.current
val handleNavigateUp = {
navigateUp()
// clean up all models.
scope.launch(Dispatchers.Default) {
for (model in task.models) {
modelManagerViewModel.cleanupModel(task = task, model = model)
}
}
}
// Handle system's edge swipe.
BackHandler {
handleNavigateUp()
}
// Initialize model when model/download state changes.
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name]
LaunchedEffect(curDownloadStatus, selectedModel.name) {
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
Log.d(
TAG,
"Initializing model '${selectedModel.name}' from LlmsingleTurnScreen launched effect"
)
modelManagerViewModel.initializeModel(context, task = task, model = selectedModel)
}
}
val modelInitializationStatus =
modelManagerUiState.modelInitializationStatus[selectedModel.name]
Scaffold(modifier = modifier, topBar = {
ModelPageAppBar(
task = task,
model = selectedModel,
modelManagerViewModel = modelManagerViewModel,
onConfigChanged = { _, _ -> },
onBackClicked = { handleNavigateUp() },
onModelSelected = { newSelectedModel ->
scope.launch(Dispatchers.Default) {
// Clean up current model.
modelManagerViewModel.cleanupModel(task = task, model = selectedModel)
// Update selected model.
modelManagerViewModel.selectModel(model = newSelectedModel)
}
}
)
}) { innerPadding ->
Column(
modifier = Modifier.padding(
top = innerPadding.calculateTopPadding(),
start = innerPadding.calculateStartPadding(LocalLayoutDirection.current),
end = innerPadding.calculateStartPadding(LocalLayoutDirection.current),
)
) {
ModelDownloadStatusInfoPanel(
model = selectedModel,
task = task,
modelManagerViewModel = modelManagerViewModel
)
// Main UI after model is downloaded.
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
Box(
contentAlignment = Alignment.BottomCenter,
modifier = Modifier.weight(1f)
) {
VerticalSplitView(modifier = Modifier.fillMaxSize(),
topView = {
PromptTemplatesPanel(
model = selectedModel,
viewModel = viewModel,
onSend = { fullPrompt ->
viewModel.generateResponse(model = selectedModel, input = fullPrompt)
}, modifier = Modifier.fillMaxSize()
)
},
bottomView = {
Box(
contentAlignment = Alignment.BottomCenter,
modifier = Modifier
.fillMaxSize()
.background(MaterialTheme.customColors.agentBubbleBgColor)
) {
ResponsePanel(
model = selectedModel,
viewModel = viewModel,
modifier = Modifier
.fillMaxSize()
.padding(bottom = innerPadding.calculateBottomPadding())
)
}
})
// Model initialization in-progress message.
this@Column.AnimatedVisibility(
visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
enter = scaleIn() + fadeIn(),
exit = scaleOut() + fadeOut(),
modifier = Modifier.offset(y = -innerPadding.calculateBottomPadding())
) {
ModelInitializationStatusChip()
}
}
}
}
}
}
@Preview(showBackground = true)
@Composable
fun LlmSingleTurnScreenPreview() {
val context = LocalContext.current
GalleryTheme {
LlmSingleTurnScreen(
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
viewModel = PreviewLlmSingleTurnViewModel(),
navigateUp = {},
)
}
}

View file

@ -0,0 +1,210 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.llmsingleturn
import android.util.Log
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.aiedge.gallery.ui.common.chat.Stat
import com.google.aiedge.gallery.ui.common.processLlmResponse
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
import com.google.aiedge.gallery.ui.llmchat.LlmModelInstance
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch
private const val TAG = "AGLlmSingleTurnViewModel"
data class LlmSingleTurnUiState(
/**
* Indicates whether the runtime is currently processing a message.
*/
val inProgress: Boolean = false,
/**
* Indicates whether the model is currently being initialized.
*/
val initializing: Boolean = false,
// model -> <template label -> response>
val responsesByModel: Map<String, Map<String, String>>,
// model -> <template label -> benchmark result>
val benchmarkByModel: Map<String, Map<String, ChatMessageBenchmarkLlmResult>>,
/** Selected prompt template type. */
val selectedPromptTemplateType: PromptTemplateType = PromptTemplateType.entries[0],
)
private val STATS = listOf(
Stat(id = "time_to_first_token", label = "1st token", unit = "sec"),
Stat(id = "prefill_speed", label = "Prefill speed", unit = "tokens/s"),
Stat(id = "decode_speed", label = "Decode speed", unit = "tokens/s"),
Stat(id = "latency", label = "Latency", unit = "sec")
)
open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_SINGLE_TURN) : ViewModel() {
private val _uiState = MutableStateFlow(createUiState(task = task))
val uiState = _uiState.asStateFlow()
fun generateResponse(model: Model, input: String) {
viewModelScope.launch(Dispatchers.Default) {
setInProgress(true)
setInitializing(true)
// Wait for instance to be initialized.
while (model.instance == null) {
delay(100)
}
// Run inference.
val instance = model.instance as LlmModelInstance
val prefillTokens = instance.session.sizeInTokens(input)
var firstRun = true
var timeToFirstToken = 0f
var firstTokenTs = 0L
var decodeTokens = 0
var prefillSpeed = 0f
var decodeSpeed: Float
val start = System.currentTimeMillis()
var response = ""
var lastBenchmarkUpdateTs = 0L
LlmChatModelHelper.runInference(model = model,
input = input,
resultListener = { partialResult, done ->
val curTs = System.currentTimeMillis()
if (firstRun) {
setInitializing(false)
firstTokenTs = System.currentTimeMillis()
timeToFirstToken = (firstTokenTs - start) / 1000f
prefillSpeed = prefillTokens / timeToFirstToken
firstRun = false
} else {
decodeTokens++
}
// Incrementally update the streamed partial results.
response = processLlmResponse(response = "$response$partialResult")
// Update response.
updateResponse(
model = model,
promptTemplateType = uiState.value.selectedPromptTemplateType,
response = response
)
// Update benchmark (with throttling).
if (curTs - lastBenchmarkUpdateTs > 200) {
decodeSpeed = decodeTokens / ((curTs - firstTokenTs) / 1000f)
if (decodeSpeed.isNaN()) {
decodeSpeed = 0f
}
val benchmark = ChatMessageBenchmarkLlmResult(
orderedStats = STATS,
statValues = mutableMapOf(
"prefill_speed" to prefillSpeed,
"decode_speed" to decodeSpeed,
"time_to_first_token" to timeToFirstToken,
"latency" to (curTs - start).toFloat() / 1000f,
),
running = !done,
latencyMs = -1f,
)
updateBenchmark(
model = model,
promptTemplateType = uiState.value.selectedPromptTemplateType,
benchmark = benchmark
)
lastBenchmarkUpdateTs = curTs
}
if (done) {
setInProgress(false)
}
},
singleTurn = true,
cleanUpListener = {
setInitializing(false)
setInProgress(false)
})
}
}
fun selectPromptTemplate(model: Model, promptTemplateType: PromptTemplateType) {
Log.d(TAG, "selecting prompt template: ${promptTemplateType.label}")
// Clear response.
updateResponse(model = model, promptTemplateType = promptTemplateType, response = "")
this._uiState.update { this.uiState.value.copy(selectedPromptTemplateType = promptTemplateType) }
}
fun setInProgress(inProgress: Boolean) {
_uiState.update { _uiState.value.copy(inProgress = inProgress) }
}
fun setInitializing(initializing: Boolean) {
_uiState.update { _uiState.value.copy(initializing = initializing) }
}
fun updateResponse(model: Model, promptTemplateType: PromptTemplateType, response: String) {
_uiState.update { currentState ->
val currentResponses = currentState.responsesByModel
val modelResponses = currentResponses[model.name]?.toMutableMap() ?: mutableMapOf()
modelResponses[promptTemplateType.label] = response
val newResponses = currentResponses.toMutableMap()
newResponses[model.name] = modelResponses
currentState.copy(responsesByModel = newResponses)
}
}
fun updateBenchmark(
model: Model, promptTemplateType: PromptTemplateType, benchmark: ChatMessageBenchmarkLlmResult
) {
_uiState.update { currentState ->
val currentBenchmark = currentState.benchmarkByModel
val modelBenchmarks = currentBenchmark[model.name]?.toMutableMap() ?: mutableMapOf()
modelBenchmarks[promptTemplateType.label] = benchmark
val newBenchmarks = currentBenchmark.toMutableMap()
newBenchmarks[model.name] = modelBenchmarks
currentState.copy(benchmarkByModel = newBenchmarks)
}
}
private fun createUiState(task: Task): LlmSingleTurnUiState {
val responsesByModel: MutableMap<String, Map<String, String>> = mutableMapOf()
val benchmarkByModel: MutableMap<String, Map<String, ChatMessageBenchmarkLlmResult>> =
mutableMapOf()
for (model in task.models) {
responsesByModel[model.name] = mutableMapOf()
benchmarkByModel[model.name] = mutableMapOf()
}
return LlmSingleTurnUiState(
responsesByModel = responsesByModel,
benchmarkByModel = benchmarkByModel,
)
}
}

View file

@ -0,0 +1,185 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.llmsingleturn
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.text.buildAnnotatedString
import androidx.compose.ui.text.withStyle
import androidx.compose.ui.text.SpanStyle
import androidx.compose.ui.graphics.Brush.Companion.linearGradient
enum class PromptTemplateInputEditorType {
SINGLE_SELECT
}
enum class RewriteToneType(val label: String) {
FORMAL(label = "Formal"), CASUAL(label = "Casual"), FRIENDLY(label = "Friendly"), POLITE(label = "Polite"), ENTHUSIASTIC(
label = "Enthusiastic"
),
CONCISE(label = "Concise"),
}
enum class SummarizationType(val label: String) {
KEY_BULLET_POINT(label = "Key bullet points (3-5)"),
SHORT_PARAGRAPH(label = "Short paragraph (1-2 sentences)"),
CONCISE_SUMMARY(label = "Concise summary (~50 words)"),
HEADLINE_TITLE(label = "Headline / title"),
ONE_SENTENCE_SUMMARY(label = "One-sentence summary"),
}
enum class LanguageType(val label: String) {
CPP(label = "C++"),
JAVA(label = "Java"),
JAVASCRIPT(label = "JavaScript"),
KOTLIN(label = "Kotlin"),
PYTHON(label = "Python"),
SWIFT(label = "Swift"),
TYPESCRIPT(label = "TypeScript"),
}
enum class InputEditorLabel(val label: String) {
TONE(label = "Tone"),
STYLE(label = "Style"),
LANGUAGE(label = "Language"),
}
open class PromptTemplateInputEditor(
open val label: String,
open val type: PromptTemplateInputEditorType,
open val defaultOption: String = "",
)
/** Single select that shows options in bottom sheet. */
class PromptTemplateSingleSelectInputEditor(
override val label: String,
val options: List<String> = listOf(),
override val defaultOption: String = "",
) : PromptTemplateInputEditor(
label = label, type = PromptTemplateInputEditorType.SINGLE_SELECT, defaultOption = defaultOption
)
data class PromptTemplateConfig(val inputEditors: List<PromptTemplateInputEditor> = listOf())
private val GEMINI_GRADIENT_STYLE = SpanStyle(
brush = linearGradient(
colors = listOf(Color(0xFF4285f4), Color(0xFF9b72cb), Color(0xFFd96570))
)
)
enum class PromptTemplateType(
val label: String,
val config: PromptTemplateConfig,
val genFullPrompt: (userInput: String, inputEditorValues: Map<String, Any>) -> AnnotatedString = { _, _ ->
AnnotatedString("")
},
val examplePrompts: List<String> = listOf(),
) {
FREE_FORM(
label = "Free form",
config = PromptTemplateConfig(),
genFullPrompt = { userInput, _ -> AnnotatedString(userInput) },
examplePrompts = listOf(
"Suggest 3 topics for a podcast about \"Friendships in your 20s\".",
"Outline the key sections needed in a basic logo design brief.",
"List 3 pros and 3 cons to consider before buying a smart watch.",
"Write a short, optimistic quote about the future of technology.",
"Generate 3 potential names for a mobile app that helps users identify plants.",
"Explain the difference between AI and machine learning in 2 sentences.",
"Create a simple haiku about a cat sleeping in the sun.",
"List 3 ways to make instant noodles taste better using common kitchen ingredients."
)
),
REWRITE_TONE(
label = "Rewrite tone", config = PromptTemplateConfig(
inputEditors = listOf(
PromptTemplateSingleSelectInputEditor(
label = InputEditorLabel.TONE.label,
options = RewriteToneType.entries.map { it.label },
defaultOption = RewriteToneType.FORMAL.label
)
)
), genFullPrompt = { userInput, inputEditorValues ->
val tone = inputEditorValues[InputEditorLabel.TONE.label] as String
buildAnnotatedString {
withStyle(GEMINI_GRADIENT_STYLE) {
append("Rewrite the following text using a ${tone.lowercase()} tone: ")
}
append(userInput)
}
}, examplePrompts = listOf(
"Hey team, just wanted to remind everyone about the meeting tomorrow @ 10. Be there!",
"Our new software update includes several bug fixes and performance improvements.",
"Due to the fact that the weather was bad, we decided to postpone the event.",
"Please find attached the requested documentation for your perusal.",
"Welcome to the team. Review the onboarding materials.",
)
),
SUMMARIZE_TEXT(
label = "Summarize text",
config = PromptTemplateConfig(
inputEditors = listOf(
PromptTemplateSingleSelectInputEditor(
label = InputEditorLabel.STYLE.label,
options = SummarizationType.entries.map { it.label },
defaultOption = SummarizationType.KEY_BULLET_POINT.label
)
)
),
genFullPrompt = { userInput, inputEditorValues ->
val style = inputEditorValues[InputEditorLabel.STYLE.label] as String
buildAnnotatedString {
withStyle(GEMINI_GRADIENT_STYLE) {
append("Please summarize the following in ${style.lowercase()}: ")
}
append(userInput)
}
},
examplePrompts = listOf(
"The new Pixel phone features an advanced camera system with improved low-light performance and AI-powered editing tools. The display is brighter and more energy-efficient. It runs on the latest Tensor chip, offering faster processing and enhanced security features. Battery life has also been extended, providing all-day power for most users.",
"Beginning this Friday, January 24, giant pandas Bao Li and Qing Bao are officially on view to the public at the Smithsonians National Zoo and Conservation Biology Institute (NZCBI). The 3-year-old bears arrived in Washington this past October, undergoing a quarantine period before making their debut. Under NZCBIs new agreement with the CWCA, Qing Bao and Bao Li will remain in the United States for ten years, until April 2034, in exchange for an annual fee of \$1 million. The pair are still too young to breed, as pandas only reach sexual maturity between ages 4 and 7. “Kind of picture them as like awkward teenagers right now,” Lally told WUSA9. “We still have about two years before we would probably even see signs that theyre ready to start mating.”",
),
),
CODE_SNIPPET(
label = "Code snippet",
config = PromptTemplateConfig(
inputEditors = listOf(
PromptTemplateSingleSelectInputEditor(
label = InputEditorLabel.LANGUAGE.label,
options = LanguageType.entries.map { it.label },
defaultOption = LanguageType.JAVASCRIPT.label
)
)
),
genFullPrompt = { userInput, inputEditorValues ->
val language = inputEditorValues[InputEditorLabel.LANGUAGE.label] as String
buildAnnotatedString {
withStyle(GEMINI_GRADIENT_STYLE) {
append("Write a $language code snippet to ")
}
append(userInput)
}
},
examplePrompts = listOf(
"Create an alert box that says \"Hello, World!\"",
"Declare an immutable variable named 'appName' with the value \"AI Gallery\"",
"Print the numbers from 1 to 5 using a for loop.",
"Write a function that returns the square of an integer input.",
),
),
}

View file

@ -0,0 +1,426 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.llmsingleturn
import androidx.compose.foundation.BorderStroke
import androidx.compose.foundation.background
import androidx.compose.foundation.border
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.wrapContentHeight
import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.foundation.verticalScroll
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.Send
import androidx.compose.material.icons.outlined.ContentCopy
import androidx.compose.material.icons.outlined.Description
import androidx.compose.material.icons.outlined.ExpandLess
import androidx.compose.material.icons.outlined.ExpandMore
import androidx.compose.material.icons.rounded.Add
import androidx.compose.material.icons.rounded.Visibility
import androidx.compose.material.icons.rounded.VisibilityOff
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.FilterChipDefaults
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButtonDefaults
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ModalBottomSheet
import androidx.compose.material3.OutlinedIconButton
import androidx.compose.material3.PrimaryScrollableTabRow
import androidx.compose.material3.Tab
import androidx.compose.material3.Text
import androidx.compose.material3.TextField
import androidx.compose.material3.TextFieldDefaults
import androidx.compose.material3.rememberModalBottomSheetState
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.mutableStateMapOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.runtime.snapshots.SnapshotStateMap
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.platform.LocalClipboardManager
import androidx.compose.ui.res.dimensionResource
import androidx.compose.ui.text.TextLayoutResult
import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.ui.common.chat.MessageBubbleShape
import com.google.aiedge.gallery.ui.theme.customColors
private val promptTemplateTypes: List<PromptTemplateType> = PromptTemplateType.entries
private val TAB_TITLES = PromptTemplateType.entries.map { it.label }
private val ICON_BUTTON_SIZE = 42.dp
const val FULL_PROMPT_SWITCH_KEY = "full_prompt"
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun PromptTemplatesPanel(
model: Model,
viewModel: LlmSingleTurnViewModel,
onSend: (fullPrompt: String) -> Unit,
modifier: Modifier = Modifier
) {
val uiState by viewModel.uiState.collectAsState()
val selectedPromptTemplateType = uiState.selectedPromptTemplateType
val inProgress = uiState.inProgress
var selectedTabIndex by remember { mutableIntStateOf(0) }
var curTextInputContent by remember { mutableStateOf("") }
val inputEditorValues: SnapshotStateMap<String, Any> = remember {
mutableStateMapOf(FULL_PROMPT_SWITCH_KEY to false)
}
val fullPrompt by remember {
derivedStateOf {
uiState.selectedPromptTemplateType.genFullPrompt(curTextInputContent, inputEditorValues)
}
}
val clipboardManager = LocalClipboardManager.current
val expandedStates = remember { mutableStateMapOf<String, Boolean>() }
// Update input editor values when prompt template changes.
LaunchedEffect(selectedPromptTemplateType) {
for (config in selectedPromptTemplateType.config.inputEditors) {
inputEditorValues[config.label] = config.defaultOption
}
expandedStates.clear()
}
var showExamplePromptBottomSheet by remember { mutableStateOf(false) }
val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
val bubbleBorderRadius = dimensionResource(R.dimen.chat_bubble_corner_radius)
Column(modifier = modifier) {
// Scrollable tab row for all prompt templates.
PrimaryScrollableTabRow(selectedTabIndex = selectedTabIndex) {
TAB_TITLES.forEachIndexed { index, title ->
Tab(selected = selectedTabIndex == index, onClick = {
// Clear input when tab changes.
curTextInputContent = ""
// Reset full prompt switch.
inputEditorValues[FULL_PROMPT_SWITCH_KEY] = false
selectedTabIndex = index
viewModel.selectPromptTemplate(
model = model,
promptTemplateType = promptTemplateTypes[index]
)
}, text = { Text(text = title) })
}
}
// Content.
Column(
modifier = Modifier
.weight(1f)
.fillMaxWidth()
) {
// Input editor row.
if (selectedPromptTemplateType.config.inputEditors.isNotEmpty()) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(8.dp),
modifier = Modifier
.fillMaxWidth()
.background(MaterialTheme.colorScheme.surfaceContainerLow)
.padding(horizontal = 16.dp, vertical = 10.dp)
) {
// Input editors.
for (inputEditor in selectedPromptTemplateType.config.inputEditors) {
when (inputEditor.type) {
PromptTemplateInputEditorType.SINGLE_SELECT -> SingleSelectButton(config = inputEditor as PromptTemplateSingleSelectInputEditor,
onSelected = { option ->
inputEditorValues[inputEditor.label] = option
})
}
}
}
}
// Text input box.
Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.weight(1f)) {
Column(
modifier = Modifier
.fillMaxSize()
.verticalScroll(rememberScrollState())
) {
if (inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean) {
Text(
fullPrompt,
style = MaterialTheme.typography.bodyMedium,
modifier = Modifier
.fillMaxWidth()
.padding(16.dp)
.padding(bottom = 32.dp)
.clip(MessageBubbleShape(radius = bubbleBorderRadius))
.background(MaterialTheme.customColors.agentBubbleBgColor)
.padding(16.dp)
)
} else {
TextField(
value = curTextInputContent,
onValueChange = { curTextInputContent = it },
colors = TextFieldDefaults.colors(
unfocusedContainerColor = Color.Transparent,
focusedContainerColor = Color.Transparent,
focusedIndicatorColor = Color.Transparent,
unfocusedIndicatorColor = Color.Transparent,
disabledIndicatorColor = Color.Transparent,
disabledContainerColor = Color.Transparent,
),
textStyle = MaterialTheme.typography.bodyMedium,
placeholder = { Text("Enter content") },
modifier = Modifier.padding(bottom = 32.dp)
)
}
}
// Text action row.
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(4.dp),
modifier = Modifier
.fillMaxWidth()
.padding(vertical = 4.dp, horizontal = 16.dp)
) {
// Full prompt switch.
if (selectedPromptTemplateType != PromptTemplateType.FREE_FORM && curTextInputContent.isNotEmpty()) {
Row(verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(4.dp),
modifier = Modifier
.clip(CircleShape)
.background(if (inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean) MaterialTheme.colorScheme.secondaryContainer else MaterialTheme.customColors.agentBubbleBgColor)
.clickable {
inputEditorValues[FULL_PROMPT_SWITCH_KEY] =
!(inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean)
}
.height(40.dp)
.border(
width = 1.dp, color = MaterialTheme.colorScheme.surface, shape = CircleShape
)
.padding(horizontal = 12.dp)) {
if (inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean) {
Icon(
imageVector = Icons.Rounded.Visibility,
contentDescription = "",
modifier = Modifier.size(FilterChipDefaults.IconSize),
)
} else {
Icon(
imageVector = Icons.Rounded.VisibilityOff,
contentDescription = "",
modifier = Modifier
.size(FilterChipDefaults.IconSize)
.alpha(0.3f),
)
}
Text("Preview prompt", style = MaterialTheme.typography.labelMedium)
}
}
Spacer(modifier = Modifier.weight(1f))
// Button to copy full prompt.
if (curTextInputContent.isNotEmpty()) {
OutlinedIconButton(
onClick = {
val clipData = fullPrompt
clipboardManager.setText(clipData)
},
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.customColors.agentBubbleBgColor,
disabledContainerColor = MaterialTheme.customColors.agentBubbleBgColor.copy(alpha = 0.4f),
contentColor = MaterialTheme.colorScheme.primary,
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.2f),
),
border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface),
modifier = Modifier.size(ICON_BUTTON_SIZE)
) {
Icon(
Icons.Outlined.ContentCopy, contentDescription = "", modifier = Modifier.size(20.dp)
)
}
}
// Add example prompt button.
OutlinedIconButton(
enabled = !inProgress,
onClick = { showExamplePromptBottomSheet = true },
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.customColors.agentBubbleBgColor,
disabledContainerColor = MaterialTheme.customColors.agentBubbleBgColor.copy(alpha = 0.4f),
contentColor = MaterialTheme.colorScheme.primary,
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.2f),
),
border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface),
modifier = Modifier.size(ICON_BUTTON_SIZE)
) {
Icon(
Icons.Rounded.Add,
contentDescription = "",
modifier = Modifier.size(20.dp),
)
}
// Send button
OutlinedIconButton(
enabled = !inProgress && curTextInputContent.isNotEmpty(),
onClick = {
onSend(fullPrompt.text)
},
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.secondaryContainer,
disabledContainerColor = MaterialTheme.colorScheme.secondaryContainer.copy(alpha = 0.3f),
contentColor = MaterialTheme.colorScheme.primary,
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.1f),
),
border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface),
modifier = Modifier.size(ICON_BUTTON_SIZE)
) {
Icon(
Icons.AutoMirrored.Rounded.Send,
contentDescription = "",
modifier = Modifier
.size(20.dp)
.offset(x = 2.dp),
)
}
}
}
}
}
if (showExamplePromptBottomSheet) {
ModalBottomSheet(
onDismissRequest = { showExamplePromptBottomSheet = false },
sheetState = sheetState,
modifier = Modifier.wrapContentHeight(),
) {
Column(modifier = Modifier.padding(bottom = 16.dp)) {
// Title
Text(
"Select an example",
modifier = Modifier
.fillMaxWidth()
.padding(16.dp),
style = MaterialTheme.typography.titleLarge
)
// Examples
for (prompt in selectedPromptTemplateType.examplePrompts) {
var textLayoutResultState by remember { mutableStateOf<TextLayoutResult?>(null) }
val hasOverflow = remember(textLayoutResultState) {
textLayoutResultState?.hasVisualOverflow ?: false
}
val isExpanded = expandedStates[prompt] ?: false
Column(
modifier = Modifier
.fillMaxWidth()
.clickable {
curTextInputContent = prompt
showExamplePromptBottomSheet = false
}
.padding(horizontal = 16.dp, vertical = 8.dp),
) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(8.dp),
) {
Icon(Icons.Outlined.Description, contentDescription = "")
Text(prompt,
maxLines = if (isExpanded) Int.MAX_VALUE else 3,
overflow = TextOverflow.Ellipsis,
style = MaterialTheme.typography.bodySmall,
modifier = Modifier.weight(1f),
onTextLayout = { textLayoutResultState = it }
)
}
if (hasOverflow && !isExpanded) {
Row(
modifier = Modifier
.fillMaxWidth()
.padding(top = 2.dp),
horizontalArrangement = Arrangement.End
) {
Box(modifier = Modifier
.padding(end = 16.dp)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.surfaceContainerHighest)
.clickable {
expandedStates[prompt] = true
}
.padding(vertical = 1.dp, horizontal = 6.dp)) {
Icon(
Icons.Outlined.ExpandMore,
contentDescription = "",
modifier = Modifier.size(12.dp)
)
}
}
} else if (isExpanded) {
Row(
modifier = Modifier
.fillMaxWidth()
.padding(top = 2.dp),
horizontalArrangement = Arrangement.End
) {
Box(modifier = Modifier
.padding(end = 16.dp)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.surfaceContainerHighest)
.clickable {
expandedStates[prompt] = false
}
.padding(vertical = 1.dp, horizontal = 6.dp)) {
Icon(
Icons.Outlined.ExpandLess,
contentDescription = "",
modifier = Modifier.size(12.dp)
)
}
}
}
}
}
}
}
}
}

View file

@ -0,0 +1,206 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.llmsingleturn
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.verticalScroll
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.outlined.AutoAwesome
import androidx.compose.material.icons.outlined.ContentCopy
import androidx.compose.material.icons.outlined.Timer
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.IconButtonDefaults
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.PrimaryTabRow
import androidx.compose.material3.Tab
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.platform.LocalClipboardManager
import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
import com.google.aiedge.gallery.ui.common.chat.MarkdownText
import com.google.aiedge.gallery.ui.common.chat.MessageBodyBenchmarkLlm
import com.google.aiedge.gallery.ui.common.chat.MessageBodyLoading
import com.google.aiedge.gallery.ui.theme.GalleryTheme
private val OPTIONS = listOf("Response", "Benchmark")
private val ICONS = listOf(Icons.Outlined.AutoAwesome, Icons.Outlined.Timer)
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun ResponsePanel(
model: Model,
viewModel: LlmSingleTurnViewModel,
modifier: Modifier = Modifier,
) {
val uiState by viewModel.uiState.collectAsState()
val inProgress = uiState.inProgress
val initializing = uiState.initializing
val selectedPromptTemplateType = uiState.selectedPromptTemplateType
val response = uiState.responsesByModel[model.name]?.get(selectedPromptTemplateType.label) ?: ""
val benchmark = uiState.benchmarkByModel[model.name]?.get(selectedPromptTemplateType.label)
val responseScrollState = rememberScrollState()
var selectedOptionIndex by remember { mutableIntStateOf(0) }
val clipboardManager = LocalClipboardManager.current
// Scroll to bottom when response changes.
LaunchedEffect(response) {
if (inProgress) {
responseScrollState.animateScrollTo(responseScrollState.maxValue)
}
}
// Select the "response" tab when prompt template changes.
LaunchedEffect(selectedPromptTemplateType) {
selectedOptionIndex = 0
}
if (initializing) {
Box(
contentAlignment = Alignment.TopStart,
modifier = modifier
.fillMaxSize()
.padding(horizontal = 16.dp)
) {
MessageBodyLoading()
}
} else {
// Message when response is empty.
if (response.isEmpty()) {
Row(
modifier = Modifier.fillMaxSize(),
horizontalArrangement = Arrangement.Center,
verticalAlignment = Alignment.CenterVertically
) {
Text(
"Response will appear here",
modifier = Modifier.alpha(0.5f),
style = MaterialTheme.typography.labelMedium,
)
}
}
// Response markdown.
else {
Column(
modifier = modifier
.padding(horizontal = 16.dp)
.padding(bottom = 4.dp)
) {
// Response/benchmark switch.
Row(modifier = Modifier.fillMaxWidth()) {
PrimaryTabRow(
selectedTabIndex = selectedOptionIndex,
containerColor = Color.Transparent,
) {
OPTIONS.forEachIndexed { index, title ->
Tab(selected = selectedOptionIndex == index, onClick = {
selectedOptionIndex = index
}, text = {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(4.dp)
) {
Icon(
ICONS[index],
contentDescription = "",
modifier = Modifier
.size(16.dp)
.alpha(0.7f)
)
Text(text = title)
}
})
}
}
}
if (selectedOptionIndex == 0) {
Box(
contentAlignment = Alignment.BottomEnd,
modifier = Modifier.weight(1f)
) {
Column(
modifier = Modifier
.fillMaxSize()
.verticalScroll(responseScrollState)
) {
MarkdownText(
text = response,
modifier = Modifier.padding(top = 8.dp, bottom = 40.dp)
)
}
// Copy button.
IconButton(
onClick = {
val clipData = AnnotatedString(response)
clipboardManager.setText(clipData)
},
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.surfaceContainerHighest,
contentColor = MaterialTheme.colorScheme.primary,
),
) {
Icon(
Icons.Outlined.ContentCopy,
contentDescription = "",
modifier = Modifier.size(20.dp),
)
}
}
} else if (selectedOptionIndex == 1) {
if (benchmark != null) {
MessageBodyBenchmarkLlm(message = benchmark, modifier = Modifier.fillMaxWidth())
}
}
}
}
}
}
@Preview(showBackground = true)
@Composable
fun ResponsePanelPreview() {
GalleryTheme {
ResponsePanel(
model = TASK_LLM_SINGLE_TURN.models[0],
viewModel = LlmSingleTurnViewModel(),
modifier = Modifier.fillMaxSize()
)
}
}

View file

@ -0,0 +1,90 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.llmsingleturn
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.ArrowDropDown
import androidx.compose.material3.DropdownMenu
import androidx.compose.material3.DropdownMenuItem
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.unit.dp
@Composable
fun SingleSelectButton(
config: PromptTemplateSingleSelectInputEditor,
onSelected: (String) -> Unit
) {
var showMenu by remember { mutableStateOf(false) }
var selectedOption by remember { mutableStateOf(config.defaultOption) }
LaunchedEffect(config) {
selectedOption = config.defaultOption
}
Box {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(2.dp),
modifier = Modifier
.clip(RoundedCornerShape(8.dp))
.background(MaterialTheme.colorScheme.secondaryContainer)
.clickable {
showMenu = true
}
.padding(vertical = 4.dp, horizontal = 6.dp)
.padding(start = 8.dp)
) {
Text("${config.label}: $selectedOption", style = MaterialTheme.typography.labelLarge)
Icon(Icons.Rounded.ArrowDropDown, contentDescription = "")
}
DropdownMenu(
expanded = showMenu,
onDismissRequest = { showMenu = false }
) {
// Options
for (option in config.options) {
DropdownMenuItem(
text = { Text(option) },
onClick = {
selectedOption = option
showMenu = false
onSelected(option)
}
)
}
}
}
}

View file

@ -0,0 +1,133 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.llmsingleturn
import androidx.compose.foundation.background
import androidx.compose.foundation.gestures.detectDragGestures
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.width
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableFloatStateOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.input.pointer.pointerInput
import androidx.compose.ui.layout.onGloballyPositioned
import androidx.compose.ui.platform.LocalDensity
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.Dp
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.customColors
@Composable
fun VerticalSplitView(
topView: @Composable () -> Unit,
bottomView: @Composable () -> Unit,
modifier: Modifier = Modifier,
initialRatio: Float = 0.5f,
minTopHeight: Dp = 250.dp,
minBottomHeight: Dp = 200.dp,
handleThickness: Dp = 20.dp
) {
var splitRatio by remember { mutableFloatStateOf(initialRatio) }
var columnHeightPx by remember {
mutableFloatStateOf(0f)
}
var columnHeightDp by remember {
mutableStateOf(0.dp)
}
val localDensity = LocalDensity.current
Column(modifier = modifier
.fillMaxSize()
.onGloballyPositioned { coordinates ->
// Set column height using the LayoutCoordinates
columnHeightPx = coordinates.size.height.toFloat()
columnHeightDp = with(localDensity) { coordinates.size.height.toDp() }
}
) {
Box(
modifier = Modifier
.fillMaxWidth()
.weight(splitRatio)
) {
topView()
}
Box(
modifier = Modifier
.fillMaxWidth()
.height(handleThickness)
.background(MaterialTheme.customColors.agentBubbleBgColor)
.pointerInput(Unit) {
detectDragGestures { change, dragAmount ->
val newTopHeightPx = columnHeightPx * splitRatio + dragAmount.y
var newTopHeightDp = with(localDensity) { newTopHeightPx.toDp() }
if (newTopHeightDp < minTopHeight) {
newTopHeightDp = minTopHeight
}
if (columnHeightDp - newTopHeightDp < minBottomHeight) {
newTopHeightDp = columnHeightDp - minBottomHeight
}
splitRatio = newTopHeightDp / columnHeightDp
change.consume()
}
},
contentAlignment = Alignment.Center
) {
Box(
modifier = Modifier
.width(32.dp)
.height(4.dp)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.4f))
)
}
Box(
modifier = Modifier
.fillMaxWidth()
.weight(1f - splitRatio)
) {
bottomView()
}
}
}
@Preview(showBackground = true)
@Composable
fun VerticalSplitViewPreview() {
GalleryTheme {
VerticalSplitView(topView = {
Text("top")
}, bottomView = {
Text("bottom")
})
}
}

View file

@ -40,6 +40,7 @@ import com.google.aiedge.gallery.data.ModelDownloadStatus
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.TASKS
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.TaskType
import com.google.aiedge.gallery.data.ValueType
@ -175,9 +176,13 @@ open class ModelManagerViewModel(
// Kick off downloads for these models .
withContext(Dispatchers.Main) {
val tokenStatusAndData = getTokenStatusAndData()
for (info in inProgressWorkInfos) {
val model: Model? = getModelByName(info.modelName)
if (model != null) {
if (tokenStatusAndData.status == TokenStatus.NOT_EXPIRED && tokenStatusAndData.data != null) {
model.accessToken = tokenStatusAndData.data.accessToken
}
Log.d(TAG, "Sending a new download request for '${model.name}'")
downloadRepository.downloadModel(
model, onStatusUpdated = this@ModelManagerViewModel::setDownloadStatus
@ -233,11 +238,13 @@ open class ModelManagerViewModel(
// Delete model from the list if model is imported as a local model.
if (model.imported) {
val index = task.models.indexOf(model)
for (curTask in TASKS) {
val index = curTask.models.indexOf(model)
if (index >= 0) {
task.models.removeAt(index)
curTask.models.removeAt(index)
}
curTask.updateTrigger.value = System.currentTimeMillis()
}
task.updateTrigger.value = System.currentTimeMillis()
curModelDownloadStatus.remove(model.name)
// Update preference.
@ -252,7 +259,7 @@ open class ModelManagerViewModel(
_uiState.update { newUiState }
}
fun initializeModel(context: Context, model: Model, force: Boolean = false) {
fun initializeModel(context: Context, task: Task, model: Model, force: Boolean = false) {
viewModelScope.launch(Dispatchers.Default) {
// Skip if initialized already.
if (!force && uiState.value.modelInitializationStatus[model.name]?.status == ModelInitializationStatusType.INITIALIZED) {
@ -267,7 +274,7 @@ open class ModelManagerViewModel(
}
// Clean up.
cleanupModel(model = model)
cleanupModel(task = task, model = model)
// Start initialization.
Log.d(TAG, "Initializing model '${model.name}'...")
@ -301,7 +308,7 @@ open class ModelManagerViewModel(
)
}
}
when (model.taskType) {
when (task.type) {
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.initialize(
context = context,
model = model,
@ -320,24 +327,33 @@ open class ModelManagerViewModel(
onDone = onDone,
)
TaskType.LLM_SINGLE_TURN -> LlmChatModelHelper.initialize(
context = context,
model = model,
onDone = onDone,
)
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.initialize(
context = context, model = model, onDone = onDone
)
else -> {}
TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {}
}
}
}
fun cleanupModel(model: Model) {
fun cleanupModel(task: Task, model: Model) {
if (model.instance != null) {
Log.d(TAG, "Cleaning up model '${model.name}'...")
when (model.taskType) {
when (task.type) {
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model)
TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.cleanUp(model = model)
TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model)
TaskType.LLM_SINGLE_TURN -> LlmChatModelHelper.cleanUp(model = model)
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.cleanUp(model = model)
else -> {}
TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {}
}
model.instance = null
model.initializing = false
@ -421,10 +437,11 @@ open class ModelManagerViewModel(
return connection.responseCode
}
fun addImportedLlmModel(task: Task, info: ImportedModelInfo) {
fun addImportedLlmModel(info: ImportedModelInfo) {
Log.d(TAG, "adding imported llm model: $info")
// Remove duplicated imported model if existed.
val task = TASK_LLM_CHAT
val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported }
if (modelIndex >= 0) {
Log.d(TAG, "duplicated imported model found in task. Removing it first")
@ -455,6 +472,9 @@ open class ModelManagerViewModel(
)
}
task.updateTrigger.value = System.currentTimeMillis()
// Also need to update single turn task.
TASK_LLM_SINGLE_TURN.updateTrigger.value = System.currentTimeMillis()
// Add to preference storage.
val importedModels = dataStoreRepository.readImportedModels().toMutableList()
@ -630,11 +650,8 @@ open class ModelManagerViewModel(
private fun createModelFromImportedModelInfo(info: ImportedModelInfo, task: Task): Model {
val accelerators: List<Accelerator> = (convertValueToTargetType(
info.defaultValues[ConfigKey.COMPATIBLE_ACCELERATORS.label]!!,
ValueType.STRING
) as String)
.split(",")
.mapNotNull { acceleratorLabel ->
info.defaultValues[ConfigKey.COMPATIBLE_ACCELERATORS.label]!!, ValueType.STRING
) as String).split(",").mapNotNull { acceleratorLabel ->
when (acceleratorLabel.trim()) {
Accelerator.GPU.label -> Accelerator.GPU
Accelerator.CPU.label -> Accelerator.CPU
@ -643,20 +660,16 @@ open class ModelManagerViewModel(
}
val configs: List<Config> = createLlmChatConfigs(
defaultMaxToken = convertValueToTargetType(
info.defaultValues[ConfigKey.DEFAULT_MAX_TOKENS.label]!!,
ValueType.INT
info.defaultValues[ConfigKey.DEFAULT_MAX_TOKENS.label]!!, ValueType.INT
) as Int,
defaultTopK = convertValueToTargetType(
info.defaultValues[ConfigKey.DEFAULT_TOPK.label]!!,
ValueType.INT
info.defaultValues[ConfigKey.DEFAULT_TOPK.label]!!, ValueType.INT
) as Int,
defaultTopP = convertValueToTargetType(
info.defaultValues[ConfigKey.DEFAULT_TOPP.label]!!,
ValueType.FLOAT
info.defaultValues[ConfigKey.DEFAULT_TOPP.label]!!, ValueType.FLOAT
) as Float,
defaultTemperature = convertValueToTargetType(
info.defaultValues[ConfigKey.DEFAULT_TEMPERATURE.label]!!,
ValueType.FLOAT
info.defaultValues[ConfigKey.DEFAULT_TEMPERATURE.label]!!, ValueType.FLOAT
) as Float,
accelerators = accelerators,
)
@ -666,9 +679,10 @@ open class ModelManagerViewModel(
configs = configs,
sizeInBytes = info.fileSize,
downloadFileName = "$IMPORTS_DIR/${info.fileName}",
showBenchmarkButton = false,
imported = true,
)
model.preProcess(task = task)
model.preProcess()
return model
}
@ -741,7 +755,7 @@ open class ModelManagerViewModel(
val task = TASKS.find { it.type.label == hfModel.task }
val model = hfModel.toModel()
if (task != null && task.models.find { it.hfModelId == model.hfModelId } == null) {
model.preProcess(task = task)
model.preProcess()
Log.d(TAG, "AG model: $model")
task.models.add(model)

View file

@ -46,6 +46,7 @@ import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION
import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.TaskType
@ -58,6 +59,8 @@ import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationDestination
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationScreen
import com.google.aiedge.gallery.ui.llmchat.LlmChatDestination
import com.google.aiedge.gallery.ui.llmchat.LlmChatScreen
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnDestination
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnScreen
import com.google.aiedge.gallery.ui.modelmanager.ModelManager
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.textclassification.TextClassificationDestination
@ -131,7 +134,7 @@ fun GalleryNavHost(
task = curPickedTask,
onModelClicked = { model ->
navigateToTaskScreen(
navController = navController, taskType = model.taskType!!, model = model
navController = navController, taskType = curPickedTask.type, model = model
)
},
navigateUp = { showModelManager = false })
@ -220,6 +223,24 @@ fun GalleryNavHost(
)
}
}
// LLMm single turn.
composable(
route = "${LlmSingleTurnDestination.route}/{modelName}",
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() },
exitTransition = { slideExit() },
) {
getModelFromNavigationParam(it, TASK_LLM_SINGLE_TURN)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel)
LlmSingleTurnScreen(
modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() },
)
}
}
}
// Handle incoming intents for deep links
@ -231,9 +252,10 @@ fun GalleryNavHost(
if (data.toString().startsWith("com.google.aiedge.gallery://model/")) {
val modelName = data.pathSegments.last()
getModelByName(modelName)?.let { model ->
// TODO(jingjin): need to show a list of possible tasks for this model.
navigateToTaskScreen(
navController = navController,
taskType = model.taskType!!,
taskType = TaskType.LLM_CHAT,
model = model
)
}
@ -249,6 +271,7 @@ fun navigateToTaskScreen(
TaskType.TEXT_CLASSIFICATION -> navController.navigate("${TextClassificationDestination.route}/${modelName}")
TaskType.IMAGE_CLASSIFICATION -> navController.navigate("${ImageClassificationDestination.route}/${modelName}")
TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
TaskType.LLM_SINGLE_TURN -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
TaskType.IMAGE_GENERATION -> navController.navigate("${ImageGenerationDestination.route}/${modelName}")
TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {}

View file

@ -0,0 +1,21 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.preview
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
class PreviewLlmSingleTurnViewModel : LlmSingleTurnViewModel(task = TASK_TEST1)

View file

@ -34,7 +34,7 @@ class PreviewModelManagerViewModel(context: Context) :
for ((index, task) in ALL_PREVIEW_TASKS.withIndex()) {
task.index = index
for (model in task.models) {
model.preProcess(task = task)
model.preProcess()
}
}

View file

@ -25,7 +25,8 @@ val primaryContainerLight = Color(0xFFD0E4FF)
val onPrimaryContainerLight = Color(0xFF144A74)
val secondaryLight = Color(0xFF526070)
val onSecondaryLight = Color(0xFFFFFFFF)
val secondaryContainerLight = Color(0xFFD6E4F7)
//val secondaryContainerLight = Color(0xFFD6E4F7)
val secondaryContainerLight = Color(0xFFC2E7FF)
val onSecondaryContainerLight = Color(0xFF3B4857)
val tertiaryLight = Color(0xFF775A0B)
val onTertiaryLight = Color(0xFFFFFFFF)

View file

@ -116,6 +116,7 @@ data class CustomColors(
val userBubbleBgColor: Color = Color.Transparent,
val agentBubbleBgColor: Color = Color.Transparent,
val linkColor: Color = Color.Transparent,
val successColor: Color = Color.Transparent,
)
val LocalCustomColors = staticCompositionLocalOf { CustomColors() }
@ -145,6 +146,7 @@ val lightCustomColors = CustomColors(
agentBubbleBgColor = Color(0xFFe9eef6),
userBubbleBgColor = Color(0xFF32628D),
linkColor = Color(0xFF32628D),
successColor = Color(0xff3d860b),
)
val darkCustomColors = CustomColors(
@ -172,6 +174,7 @@ val darkCustomColors = CustomColors(
agentBubbleBgColor = Color(0xFF1b1c1d),
userBubbleBgColor = Color(0xFF1f3760),
linkColor = Color(0xFF9DCAFC),
successColor = Color(0xFFA1CE83),
)
val MaterialTheme.customColors: CustomColors

View file

@ -92,6 +92,7 @@ class DownloadWorker(context: Context, params: WorkerParameters) :
val connection = url.openConnection() as HttpURLConnection
if (accessToken != null) {
Log.d(TAG, "Using access token: ${accessToken.subSequence(0, 10)}...")
connection.setRequestProperty("Authorization", "Bearer $accessToken")
}
@ -176,6 +177,7 @@ class DownloadWorker(context: Context, params: WorkerParameters) :
KEY_MODEL_DOWNLOAD_REMAINING_MS, remainingMs.toLong()
).build()
)
Log.d(TAG, "downloadedBytes: $downloadedBytes")
lastSetProgressTs = curTs
}
}