mirror of
https://github.com/google-ai-edge/gallery.git
synced 2025-07-06 06:30:30 -04:00
Add support for LLM single-turn experience
This commit is contained in:
parent
46aaee2654
commit
d94fec0674
39 changed files with 2376 additions and 295 deletions
|
@ -30,7 +30,7 @@ android {
|
|||
minSdk = 24
|
||||
targetSdk = 35
|
||||
versionCode = 1
|
||||
versionName = "20250421"
|
||||
versionName = "20250428"
|
||||
|
||||
// Needed for HuggingFace auth workflows.
|
||||
manifestPlaceholders["appAuthRedirectScheme"] = "com.google.aiedge.gallery.oauth"
|
||||
|
|
|
@ -187,6 +187,5 @@ fun GalleryTopAppBar(
|
|||
else -> {}
|
||||
}
|
||||
}
|
||||
|
||||
)
|
||||
}
|
|
@ -23,7 +23,7 @@ import androidx.datastore.preferences.core.Preferences
|
|||
import androidx.datastore.preferences.preferencesDataStore
|
||||
import com.google.aiedge.gallery.data.AppContainer
|
||||
import com.google.aiedge.gallery.data.DefaultAppContainer
|
||||
import com.google.aiedge.gallery.data.TASKS
|
||||
import com.google.aiedge.gallery.ui.common.processTasks
|
||||
import com.google.aiedge.gallery.ui.theme.ThemeSettings
|
||||
|
||||
private val Context.dataStore: DataStore<Preferences> by preferencesDataStore(name = "app_gallery_preferences")
|
||||
|
@ -36,12 +36,7 @@ class GalleryApplication : Application() {
|
|||
super.onCreate()
|
||||
|
||||
// Process tasks.
|
||||
for ((index, task) in TASKS.withIndex()) {
|
||||
task.index = index
|
||||
for (model in task.models) {
|
||||
model.preProcess(task = task)
|
||||
}
|
||||
}
|
||||
processTasks()
|
||||
|
||||
container = DefaultAppContainer(this, dataStore)
|
||||
|
||||
|
|
|
@ -92,15 +92,13 @@ data class Model(
|
|||
val imported: Boolean = false,
|
||||
|
||||
// The following fields are managed by the app. Don't need to set manually.
|
||||
var taskType: TaskType? = null,
|
||||
var instance: Any? = null,
|
||||
var initializing: Boolean = false,
|
||||
var configValues: Map<String, Any> = mapOf(),
|
||||
var totalBytes: Long = 0L,
|
||||
var accessToken: String? = null,
|
||||
) {
|
||||
fun preProcess(task: Task) {
|
||||
this.taskType = task.type
|
||||
fun preProcess() {
|
||||
val configValues: MutableMap<String, Any> = mutableMapOf()
|
||||
for (config in this.configs) {
|
||||
configValues[config.key.label] = config.defaultValue
|
||||
|
@ -246,6 +244,7 @@ val MODEL_LLM_GEMMA_2B_GPU_INT4: Model = Model(
|
|||
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma-2b-it-gpu-int4.bin",
|
||||
sizeInBytes = 1354301440L,
|
||||
configs = createLlmChatConfigs(),
|
||||
showBenchmarkButton = false,
|
||||
info = LLM_CHAT_INFO,
|
||||
learnMoreUrl = "https://huggingface.co/litert-community",
|
||||
)
|
||||
|
@ -256,6 +255,7 @@ val MODEL_LLM_GEMMA_2_2B_GPU_INT8: Model = Model(
|
|||
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma2-2b-it-gpu-int8.bin",
|
||||
sizeInBytes = 2627141632L,
|
||||
configs = createLlmChatConfigs(),
|
||||
showBenchmarkButton = false,
|
||||
info = LLM_CHAT_INFO,
|
||||
learnMoreUrl = "https://huggingface.co/litert-community",
|
||||
)
|
||||
|
@ -271,6 +271,7 @@ val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model(
|
|||
defaultTopP = 0.95f,
|
||||
accelerators = listOf(Accelerator.CPU, Accelerator.GPU)
|
||||
),
|
||||
showBenchmarkButton = false,
|
||||
info = LLM_CHAT_INFO,
|
||||
learnMoreUrl = "https://huggingface.co/litert-community/Gemma3-1B-IT",
|
||||
llmPromptTemplates = listOf(
|
||||
|
@ -299,6 +300,7 @@ val MODEL_LLM_DEEPSEEK: Model = Model(
|
|||
defaultTopP = 0.7f,
|
||||
accelerators = listOf(Accelerator.CPU)
|
||||
),
|
||||
showBenchmarkButton = false,
|
||||
info = LLM_CHAT_INFO,
|
||||
learnMoreUrl = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B",
|
||||
)
|
||||
|
@ -389,7 +391,7 @@ val MODELS_IMAGE_CLASSIFICATION: MutableList<Model> = mutableListOf(
|
|||
MODEL_IMAGE_CLASSIFICATION_MOBILENET_V2,
|
||||
)
|
||||
|
||||
val MODELS_LLM_CHAT: MutableList<Model> = mutableListOf(
|
||||
val MODELS_LLM: MutableList<Model> = mutableListOf(
|
||||
MODEL_LLM_GEMMA_2B_GPU_INT4,
|
||||
MODEL_LLM_GEMMA_2_2B_GPU_INT8,
|
||||
MODEL_LLM_GEMMA_3_1B_INT4,
|
||||
|
|
|
@ -18,9 +18,11 @@ package com.google.aiedge.gallery.data
|
|||
|
||||
import androidx.annotation.StringRes
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.outlined.Forum
|
||||
import androidx.compose.material.icons.outlined.Widgets
|
||||
import androidx.compose.material.icons.rounded.ImageSearch
|
||||
import androidx.compose.runtime.MutableState
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.mutableLongStateOf
|
||||
import androidx.compose.ui.graphics.vector.ImageVector
|
||||
import com.google.aiedge.gallery.R
|
||||
|
||||
|
@ -30,6 +32,7 @@ enum class TaskType(val label: String) {
|
|||
IMAGE_CLASSIFICATION("Image Classification"),
|
||||
IMAGE_GENERATION("Image Generation"),
|
||||
LLM_CHAT("LLM Chat"),
|
||||
LLM_SINGLE_TURN("LLM Use Cases"),
|
||||
|
||||
TEST_TASK_1("Test task 1"),
|
||||
TEST_TASK_2("Test task 2")
|
||||
|
@ -67,7 +70,7 @@ data class Task(
|
|||
// The following fields are managed by the app. Don't need to set manually.
|
||||
var index: Int = -1,
|
||||
|
||||
val updateTrigger: MutableState<Long> = mutableStateOf(0)
|
||||
val updateTrigger: MutableState<Long> = mutableLongStateOf(0)
|
||||
)
|
||||
|
||||
val TASK_TEXT_CLASSIFICATION = Task(
|
||||
|
@ -87,9 +90,19 @@ val TASK_IMAGE_CLASSIFICATION = Task(
|
|||
|
||||
val TASK_LLM_CHAT = Task(
|
||||
type = TaskType.LLM_CHAT,
|
||||
iconVectorResourceId = R.drawable.chat_spark,
|
||||
models = MODELS_LLM_CHAT,
|
||||
description = "Chat? with a on-device large language model",
|
||||
icon = Icons.Outlined.Forum,
|
||||
models = MODELS_LLM,
|
||||
description = "Chat with a on-device large language model",
|
||||
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
||||
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
|
||||
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
|
||||
)
|
||||
|
||||
val TASK_LLM_SINGLE_TURN = Task(
|
||||
type = TaskType.LLM_SINGLE_TURN,
|
||||
icon = Icons.Outlined.Widgets,
|
||||
models = MODELS_LLM,
|
||||
description = "Single turn use cases with on-device large language model",
|
||||
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
||||
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
|
||||
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
|
||||
|
@ -108,9 +121,10 @@ val TASK_IMAGE_GENERATION = Task(
|
|||
/** All tasks. */
|
||||
val TASKS: List<Task> = listOf(
|
||||
// TASK_TEXT_CLASSIFICATION,
|
||||
// TASK_IMAGE_CLASSIFICATION,
|
||||
TASK_IMAGE_CLASSIFICATION,
|
||||
TASK_IMAGE_GENERATION,
|
||||
TASK_LLM_CHAT,
|
||||
TASK_LLM_SINGLE_TURN,
|
||||
)
|
||||
|
||||
fun getModelByName(name: String): Model? {
|
||||
|
|
|
@ -25,6 +25,7 @@ import com.google.aiedge.gallery.GalleryApplication
|
|||
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationViewModel
|
||||
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationViewModel
|
||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatViewModel
|
||||
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.textclassification.TextClassificationViewModel
|
||||
|
||||
|
@ -56,6 +57,11 @@ object ViewModelProvider {
|
|||
LlmChatViewModel()
|
||||
}
|
||||
|
||||
// Initializer for LlmSingleTurnViewModel..
|
||||
initializer {
|
||||
LlmSingleTurnViewModel()
|
||||
}
|
||||
|
||||
initializer {
|
||||
ImageGenerationViewModel()
|
||||
}
|
||||
|
|
|
@ -0,0 +1,213 @@
|
|||
/*
|
||||
* Copyright 2025 Google LLC
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.google.aiedge.gallery.ui.common
|
||||
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.layout.size
|
||||
import androidx.compose.foundation.shape.CircleShape
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.automirrored.rounded.ArrowBack
|
||||
import androidx.compose.material.icons.rounded.ArrowDropDown
|
||||
import androidx.compose.material.icons.rounded.Settings
|
||||
import androidx.compose.material3.CenterAlignedTopAppBar
|
||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.IconButton
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.ModalBottomSheet
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.material3.rememberModalBottomSheetState
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.graphics.vector.ImageVector
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.res.vectorResource
|
||||
import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.compose.ui.unit.dp
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.ui.common.chat.ConfigDialog
|
||||
import com.google.aiedge.gallery.ui.common.modelitem.StatusIcon
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@Composable
|
||||
fun ModelPageAppBar(
|
||||
task: Task,
|
||||
model: Model,
|
||||
modelManagerViewModel: ModelManagerViewModel,
|
||||
onBackClicked: () -> Unit,
|
||||
onModelSelected: (Model) -> Unit,
|
||||
modifier: Modifier = Modifier,
|
||||
onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit = { _, _ -> },
|
||||
) {
|
||||
var showConfigDialog by remember { mutableStateOf(false) }
|
||||
var showModelPicker by remember { mutableStateOf(false) }
|
||||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||
val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
|
||||
val context = LocalContext.current
|
||||
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[model.name]
|
||||
|
||||
CenterAlignedTopAppBar(title = {
|
||||
Column(horizontalAlignment = Alignment.CenterHorizontally) {
|
||||
// Task type.
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(6.dp)
|
||||
) {
|
||||
Icon(
|
||||
task.icon ?: ImageVector.vectorResource(task.iconVectorResourceId!!),
|
||||
tint = getTaskIconColor(task = task),
|
||||
modifier = Modifier.size(16.dp),
|
||||
contentDescription = "",
|
||||
)
|
||||
Text(
|
||||
task.type.label,
|
||||
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.SemiBold),
|
||||
color = getTaskIconColor(task = task)
|
||||
)
|
||||
}
|
||||
|
||||
// Model name.
|
||||
Row(verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(2.dp),
|
||||
modifier = Modifier
|
||||
.clip(CircleShape)
|
||||
.background(MaterialTheme.colorScheme.surfaceContainerHigh)
|
||||
.clickable {
|
||||
showModelPicker = true
|
||||
}
|
||||
.padding(start = 8.dp, end = 2.dp)) {
|
||||
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
|
||||
Text(
|
||||
model.name,
|
||||
style = MaterialTheme.typography.labelSmall,
|
||||
modifier = Modifier.padding(start = 4.dp),
|
||||
)
|
||||
Icon(
|
||||
Icons.Rounded.ArrowDropDown,
|
||||
modifier = Modifier.size(20.dp),
|
||||
contentDescription = "",
|
||||
)
|
||||
}
|
||||
|
||||
}
|
||||
}, modifier = modifier,
|
||||
// The back button.
|
||||
navigationIcon = {
|
||||
IconButton(onClick = onBackClicked) {
|
||||
Icon(
|
||||
imageVector = Icons.AutoMirrored.Rounded.ArrowBack,
|
||||
contentDescription = "",
|
||||
)
|
||||
}
|
||||
},
|
||||
// The config button for the model (if existed).
|
||||
actions = {
|
||||
if (model.configs.isNotEmpty() && curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
|
||||
IconButton(onClick = { showConfigDialog = true }) {
|
||||
Icon(
|
||||
imageVector = Icons.Rounded.Settings,
|
||||
contentDescription = "",
|
||||
tint = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Config dialog.
|
||||
if (showConfigDialog) {
|
||||
ConfigDialog(
|
||||
title = "Model configs",
|
||||
configs = model.configs,
|
||||
initialValues = model.configValues,
|
||||
onDismissed = { showConfigDialog = false },
|
||||
onOk = { curConfigValues ->
|
||||
// Hide config dialog.
|
||||
showConfigDialog = false
|
||||
|
||||
// Check if the configs are changed or not. Also check if the model needs to be
|
||||
// re-initialized.
|
||||
var same = true
|
||||
var needReinitialization = false
|
||||
for (config in model.configs) {
|
||||
val key = config.key.label
|
||||
val oldValue = convertValueToTargetType(
|
||||
value = model.configValues.getValue(key), valueType = config.valueType
|
||||
)
|
||||
val newValue = convertValueToTargetType(
|
||||
value = curConfigValues.getValue(key), valueType = config.valueType
|
||||
)
|
||||
if (oldValue != newValue) {
|
||||
same = false
|
||||
if (config.needReinitialization) {
|
||||
needReinitialization = true
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if (same) {
|
||||
return@ConfigDialog
|
||||
}
|
||||
|
||||
// Save the config values to Model.
|
||||
val oldConfigValues = model.configValues
|
||||
model.configValues = curConfigValues
|
||||
|
||||
// Force to re-initialize the model with the new configs.
|
||||
if (needReinitialization) {
|
||||
modelManagerViewModel.initializeModel(
|
||||
context = context, task = task, model = model, force = true
|
||||
)
|
||||
}
|
||||
|
||||
// Notify.
|
||||
onConfigChanged(oldConfigValues, model.configValues)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// Model picker.
|
||||
if (showModelPicker) {
|
||||
ModalBottomSheet(
|
||||
onDismissRequest = { showModelPicker = false },
|
||||
sheetState = sheetState,
|
||||
) {
|
||||
ModelPicker(
|
||||
task = task,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
onModelSelected = { model ->
|
||||
showModelPicker = false
|
||||
onModelSelected(model)
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,132 @@
|
|||
/*
|
||||
* Copyright 2025 Google LLC
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.google.aiedge.gallery.ui.common
|
||||
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.Spacer
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.layout.size
|
||||
import androidx.compose.foundation.layout.width
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.outlined.CheckCircle
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.graphics.vector.ImageVector
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.res.vectorResource
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.unit.sp
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.ui.common.modelitem.StatusIcon
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
|
||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||
import com.google.aiedge.gallery.ui.theme.labelSmallNarrow
|
||||
|
||||
@Composable
|
||||
fun ModelPicker(
|
||||
task: Task,
|
||||
modelManagerViewModel: ModelManagerViewModel,
|
||||
onModelSelected: (Model) -> Unit
|
||||
) {
|
||||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||
|
||||
Column(modifier = Modifier.padding(bottom = 8.dp)) {
|
||||
// Title
|
||||
Row(
|
||||
modifier = Modifier
|
||||
.padding(horizontal = 16.dp)
|
||||
.padding(top = 4.dp, bottom = 4.dp),
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(8.dp),
|
||||
) {
|
||||
Icon(
|
||||
task.icon ?: ImageVector.vectorResource(task.iconVectorResourceId!!),
|
||||
tint = getTaskIconColor(task = task),
|
||||
modifier = Modifier.size(16.dp),
|
||||
contentDescription = "",
|
||||
)
|
||||
Text(
|
||||
"${task.type.label} models",
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
style = MaterialTheme.typography.titleMedium,
|
||||
color = getTaskIconColor(task = task),
|
||||
)
|
||||
}
|
||||
|
||||
// Model list.
|
||||
for (model in task.models) {
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.SpaceBetween,
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.clickable {
|
||||
onModelSelected(model)
|
||||
}
|
||||
.padding(horizontal = 16.dp, vertical = 8.dp),
|
||||
) {
|
||||
Spacer(modifier = Modifier.width(24.dp))
|
||||
Column(modifier = Modifier.weight(1f)) {
|
||||
Text(model.name, style = MaterialTheme.typography.bodyMedium)
|
||||
Row(horizontalArrangement = Arrangement.spacedBy(4.dp)) {
|
||||
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
|
||||
Text(
|
||||
model.sizeInBytes.humanReadableSize(),
|
||||
color = MaterialTheme.colorScheme.secondary,
|
||||
style = labelSmallNarrow.copy(lineHeight = 10.sp)
|
||||
)
|
||||
}
|
||||
}
|
||||
if (model.name == modelManagerUiState.selectedModel.name) {
|
||||
Icon(
|
||||
Icons.Outlined.CheckCircle,
|
||||
modifier = Modifier.size(16.dp),
|
||||
contentDescription = ""
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Preview(showBackground = true)
|
||||
@Composable
|
||||
fun ModelPickerPreview() {
|
||||
val context = LocalContext.current
|
||||
|
||||
GalleryTheme {
|
||||
ModelPicker(
|
||||
task = TASK_TEST1,
|
||||
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
||||
onModelSelected = {},
|
||||
)
|
||||
}
|
||||
}
|
|
@ -29,6 +29,7 @@ import androidx.core.content.ContextCompat
|
|||
import androidx.core.content.FileProvider
|
||||
import com.google.aiedge.gallery.data.Config
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.TASKS
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.data.ValueType
|
||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkResult
|
||||
|
@ -57,6 +58,9 @@ interface LatencyProvider {
|
|||
val latencyMs: Float
|
||||
}
|
||||
|
||||
private const val START_THINKING = "***Thinking...***"
|
||||
private const val DONE_THINKING = "***Done thinking***"
|
||||
|
||||
/** Format the bytes into a human-readable format. */
|
||||
fun Long.humanReadableSize(si: Boolean = true, extraDecimalForGbAndAbove: Boolean = false): String {
|
||||
val bytes = this
|
||||
|
@ -452,3 +456,33 @@ fun cleanUpMediapipeTaskErrorMessage(message: String): String {
|
|||
}
|
||||
return message
|
||||
}
|
||||
|
||||
fun processTasks() {
|
||||
for ((index, task) in TASKS.withIndex()) {
|
||||
task.index = index
|
||||
for (model in task.models) {
|
||||
model.preProcess()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun processLlmResponse(response: String): String {
|
||||
// Add "thinking" and "done thinking" around the thinking content.
|
||||
var newContent = response
|
||||
.replace("<think>", "$START_THINKING\n")
|
||||
.replace("</think>", "\n$DONE_THINKING")
|
||||
|
||||
// Remove empty thinking content.
|
||||
val endThinkingIndex = newContent.indexOf(DONE_THINKING)
|
||||
if (endThinkingIndex >= 0) {
|
||||
val thinkingContent =
|
||||
newContent.substring(0, endThinkingIndex + DONE_THINKING.length)
|
||||
.replace(START_THINKING, "")
|
||||
.replace(DONE_THINKING, "")
|
||||
if (thinkingContent.isBlank()) {
|
||||
newContent = newContent.substring(endThinkingIndex + DONE_THINKING.length)
|
||||
}
|
||||
}
|
||||
|
||||
return newContent
|
||||
}
|
|
@ -38,6 +38,7 @@ import androidx.compose.foundation.layout.padding
|
|||
import androidx.compose.foundation.layout.size
|
||||
import androidx.compose.foundation.layout.width
|
||||
import androidx.compose.foundation.layout.wrapContentHeight
|
||||
import androidx.compose.foundation.layout.wrapContentWidth
|
||||
import androidx.compose.foundation.lazy.LazyColumn
|
||||
import androidx.compose.foundation.lazy.items
|
||||
import androidx.compose.foundation.lazy.rememberLazyListState
|
||||
|
@ -334,7 +335,10 @@ fun ChatPanel(
|
|||
is ChatMessageBenchmarkResult -> MessageBodyBenchmark(message = message)
|
||||
|
||||
// Benchmark LLM result.
|
||||
is ChatMessageBenchmarkLlmResult -> MessageBodyBenchmarkLlm(message = message)
|
||||
is ChatMessageBenchmarkLlmResult -> MessageBodyBenchmarkLlm(
|
||||
message = message,
|
||||
modifier = Modifier.wrapContentWidth()
|
||||
)
|
||||
|
||||
else -> {}
|
||||
}
|
||||
|
@ -346,7 +350,7 @@ fun ChatPanel(
|
|||
) {
|
||||
LatencyText(message = message)
|
||||
// A button to show stats for the LLM message.
|
||||
if (selectedModel.taskType == TaskType.LLM_CHAT && message is ChatMessageText
|
||||
if (task.type == TaskType.LLM_CHAT && message is ChatMessageText
|
||||
// This means we only want to show the action button when the message is done
|
||||
// generating, at which point the latency will be set.
|
||||
&& message.latencyMs >= 0
|
||||
|
@ -403,21 +407,17 @@ fun ChatPanel(
|
|||
}
|
||||
|
||||
// Benchmark button
|
||||
// if (selectedModel.showBenchmarkButton) {
|
||||
// MessageActionButton(
|
||||
// label = stringResource(R.string.benchmark),
|
||||
// icon = Icons.Outlined.Timer,
|
||||
// onClick = {
|
||||
// if (selectedModel.taskType == TaskType.LLM_CHAT) {
|
||||
// onBenchmarkClicked(selectedModel, message, 0, 0)
|
||||
// } else {
|
||||
// showBenchmarkConfigsDialog = true
|
||||
// benchmarkMessage.value = message
|
||||
// }
|
||||
// },
|
||||
// enabled = !uiState.inProgress
|
||||
// )
|
||||
// }
|
||||
if (selectedModel.showBenchmarkButton) {
|
||||
MessageActionButton(
|
||||
label = stringResource(R.string.benchmark),
|
||||
icon = Icons.Outlined.Timer,
|
||||
onClick = {
|
||||
showBenchmarkConfigsDialog = true
|
||||
benchmarkMessage.value = message
|
||||
},
|
||||
enabled = !uiState.inProgress
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -443,7 +443,7 @@ fun ChatPanel(
|
|||
// Chat input
|
||||
when (chatInputType) {
|
||||
ChatInputType.TEXT -> {
|
||||
val isLlmTask = selectedModel.taskType == TaskType.LLM_CHAT
|
||||
val isLlmTask = task.type == TaskType.LLM_CHAT
|
||||
val notLlmStartScreen = !(messages.size == 1 && messages[0] is ChatMessagePromptTemplates)
|
||||
MessageInputText(
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
|
|
|
@ -18,18 +18,10 @@ package com.google.aiedge.gallery.ui.common.chat
|
|||
|
||||
import android.util.Log
|
||||
import androidx.activity.compose.BackHandler
|
||||
import androidx.activity.compose.rememberLauncherForActivityResult
|
||||
import androidx.activity.result.contract.ActivityResultContracts
|
||||
import androidx.compose.animation.AnimatedVisibility
|
||||
import androidx.compose.animation.fadeIn
|
||||
import androidx.compose.animation.fadeOut
|
||||
import androidx.compose.animation.scaleIn
|
||||
import androidx.compose.animation.scaleOut
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.fillMaxSize
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.pager.HorizontalPager
|
||||
import androidx.compose.foundation.pager.rememberPagerState
|
||||
|
@ -40,29 +32,21 @@ import androidx.compose.runtime.Composable
|
|||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.rememberCoroutineScope
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.graphics.graphicsLayer
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
import com.google.aiedge.gallery.GalleryTopAppBar
|
||||
import com.google.aiedge.gallery.data.AppBarAction
|
||||
import com.google.aiedge.gallery.data.AppBarActionType
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.ui.common.checkNotificationPermissionAndStartDownload
|
||||
import com.google.aiedge.gallery.ui.common.ModelPageAppBar
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.preview.PreviewChatModel
|
||||
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
|
||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlin.math.absoluteValue
|
||||
|
||||
|
@ -77,7 +61,6 @@ private const val TAG = "AGChatView"
|
|||
* manages model initialization, cleanup, and download status, and handles navigation and system
|
||||
* back gestures.
|
||||
*/
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@Composable
|
||||
fun ChatView(
|
||||
task: Task,
|
||||
|
@ -96,34 +79,29 @@ fun ChatView(
|
|||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||
val selectedModel = modelManagerUiState.selectedModel
|
||||
|
||||
val pagerState = rememberPagerState(initialPage = task.models.indexOf(selectedModel),
|
||||
val pagerState = rememberPagerState(
|
||||
initialPage = task.models.indexOf(selectedModel),
|
||||
pageCount = { task.models.size })
|
||||
val context = LocalContext.current
|
||||
val scope = rememberCoroutineScope()
|
||||
|
||||
val launcher = rememberLauncherForActivityResult(
|
||||
ActivityResultContracts.RequestPermission()
|
||||
) {
|
||||
modelManagerViewModel.downloadModel(task = task, model = selectedModel)
|
||||
}
|
||||
|
||||
val handleNavigateUp = {
|
||||
navigateUp()
|
||||
|
||||
// clean up all models.
|
||||
scope.launch(Dispatchers.Default) {
|
||||
for (model in task.models) {
|
||||
modelManagerViewModel.cleanupModel(model = model)
|
||||
modelManagerViewModel.cleanupModel(task = task, model = model)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize model when model/download state changes.
|
||||
val status = modelManagerUiState.modelDownloadStatus[selectedModel.name]
|
||||
LaunchedEffect(status, selectedModel.name) {
|
||||
if (status?.status == ModelDownloadStatusType.SUCCEEDED) {
|
||||
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name]
|
||||
LaunchedEffect(curDownloadStatus, selectedModel.name) {
|
||||
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
|
||||
Log.d(TAG, "Initializing model '${selectedModel.name}' from ChatView launched effect")
|
||||
modelManagerViewModel.initializeModel(context, model = selectedModel)
|
||||
modelManagerViewModel.initializeModel(context, task = task, model = selectedModel)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -135,7 +113,7 @@ fun ChatView(
|
|||
"Pager settled on model '${curSelectedModel.name}' from '${selectedModel.name}'. Updating selected model."
|
||||
)
|
||||
if (curSelectedModel.name != selectedModel.name) {
|
||||
modelManagerViewModel.cleanupModel(model = selectedModel)
|
||||
modelManagerViewModel.cleanupModel(task = task, model = selectedModel)
|
||||
}
|
||||
modelManagerViewModel.selectModel(curSelectedModel)
|
||||
}
|
||||
|
@ -146,24 +124,36 @@ fun ChatView(
|
|||
}
|
||||
|
||||
Scaffold(modifier = modifier, topBar = {
|
||||
GalleryTopAppBar(
|
||||
title = task.type.label,
|
||||
leftAction = AppBarAction(actionType = AppBarActionType.NAVIGATE_UP, actionFn = {
|
||||
ModelPageAppBar(
|
||||
task = task,
|
||||
model = selectedModel,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
onConfigChanged = { old, new ->
|
||||
viewModel.addConfigChangedMessage(
|
||||
oldConfigValues = old,
|
||||
newConfigValues = new,
|
||||
model = selectedModel
|
||||
)
|
||||
},
|
||||
onBackClicked = {
|
||||
handleNavigateUp()
|
||||
}),
|
||||
rightAction = AppBarAction(actionType = AppBarActionType.NO_ACTION, actionFn = {}),
|
||||
},
|
||||
onModelSelected = { model ->
|
||||
scope.launch {
|
||||
pagerState.animateScrollToPage(task.models.indexOf(model))
|
||||
}
|
||||
},
|
||||
)
|
||||
}) { innerPadding ->
|
||||
Box {
|
||||
// A horizontal scrollable pager to switch between models.
|
||||
HorizontalPager(state = pagerState) { pageIndex ->
|
||||
val curSelectedModel = task.models[pageIndex]
|
||||
val curModelDownloadStatus = modelManagerUiState.modelDownloadStatus[curSelectedModel.name]
|
||||
|
||||
// Calculate the alpha of the current page based on how far they are from the center.
|
||||
val pageOffset = (
|
||||
(pagerState.currentPage - pageIndex) + pagerState
|
||||
.currentPageOffsetFraction
|
||||
).absoluteValue
|
||||
val pageOffset =
|
||||
((pagerState.currentPage - pageIndex) + pagerState.currentPageOffsetFraction).absoluteValue
|
||||
val curAlpha = 1f - pageOffset.coerceIn(0f, 1f)
|
||||
|
||||
Column(
|
||||
|
@ -172,91 +162,14 @@ fun ChatView(
|
|||
.fillMaxSize()
|
||||
.background(MaterialTheme.colorScheme.surface)
|
||||
) {
|
||||
// Model selector at the top.
|
||||
ModelSelector(
|
||||
ModelDownloadStatusInfoPanel(
|
||||
model = curSelectedModel,
|
||||
task = task,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
onConfigChanged = { old, new ->
|
||||
viewModel.addConfigChangedMessage(
|
||||
oldConfigValues = old,
|
||||
newConfigValues = new,
|
||||
model = curSelectedModel
|
||||
modelManagerViewModel = modelManagerViewModel
|
||||
)
|
||||
},
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
contentAlpha = curAlpha,
|
||||
)
|
||||
|
||||
// Manages the conditional display of UI elements (download model button and downloading
|
||||
// animation) based on the corresponding download status.
|
||||
//
|
||||
// It uses delayed visibility ensuring they are shown only after a short delay if their
|
||||
// respective conditions remain true. This prevents UI flickering and provides a smoother
|
||||
// user experience.
|
||||
val curStatus = modelManagerUiState.modelDownloadStatus[curSelectedModel.name]
|
||||
var shouldShowDownloadingAnimation by remember { mutableStateOf(false) }
|
||||
var downloadingAnimationConditionMet by remember { mutableStateOf(false) }
|
||||
var shouldShowDownloadModelButton by remember { mutableStateOf(false) }
|
||||
var downloadModelButtonConditionMet by remember { mutableStateOf(false) }
|
||||
|
||||
downloadingAnimationConditionMet =
|
||||
curStatus?.status == ModelDownloadStatusType.IN_PROGRESS ||
|
||||
curStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED ||
|
||||
curStatus?.status == ModelDownloadStatusType.UNZIPPING
|
||||
downloadModelButtonConditionMet =
|
||||
curStatus?.status == ModelDownloadStatusType.FAILED ||
|
||||
curStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED
|
||||
|
||||
LaunchedEffect(downloadingAnimationConditionMet) {
|
||||
if (downloadingAnimationConditionMet) {
|
||||
delay(100)
|
||||
shouldShowDownloadingAnimation = true
|
||||
} else {
|
||||
shouldShowDownloadingAnimation = false
|
||||
}
|
||||
}
|
||||
|
||||
LaunchedEffect(downloadModelButtonConditionMet) {
|
||||
if (downloadModelButtonConditionMet) {
|
||||
delay(700)
|
||||
shouldShowDownloadModelButton = true
|
||||
} else {
|
||||
shouldShowDownloadModelButton = false
|
||||
}
|
||||
}
|
||||
|
||||
AnimatedVisibility(
|
||||
visible = shouldShowDownloadingAnimation,
|
||||
enter = scaleIn(initialScale = 0.9f) + fadeIn(),
|
||||
exit = scaleOut(targetScale = 0.9f) + fadeOut()
|
||||
) {
|
||||
Box(
|
||||
modifier = Modifier.fillMaxSize(),
|
||||
contentAlignment = Alignment.Center
|
||||
) {
|
||||
ModelDownloadingAnimation()
|
||||
}
|
||||
}
|
||||
|
||||
AnimatedVisibility(
|
||||
visible = shouldShowDownloadModelButton,
|
||||
enter = fadeIn(),
|
||||
exit = fadeOut()
|
||||
) {
|
||||
ModelNotDownloaded(modifier = Modifier.weight(1f), onClicked = {
|
||||
checkNotificationPermissionAndStartDownload(
|
||||
context = context,
|
||||
launcher = launcher,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
task = task,
|
||||
model = curSelectedModel
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// The main messages panel.
|
||||
if (curStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
|
||||
if (curModelDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
|
||||
ChatPanel(
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
task = task,
|
||||
|
|
|
@ -20,13 +20,12 @@ import android.util.Log
|
|||
import androidx.lifecycle.ViewModel
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.ui.common.processLlmResponse
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.asStateFlow
|
||||
import kotlinx.coroutines.flow.update
|
||||
|
||||
private const val TAG = "AGChatViewModel"
|
||||
private const val START_THINKING = "***Thinking...***"
|
||||
private const val DONE_THINKING = "***Done thinking***"
|
||||
|
||||
data class ChatUiState(
|
||||
/**
|
||||
|
@ -121,26 +120,7 @@ open class ChatViewModel(val task: Task) : ViewModel() {
|
|||
if (newMessages.size > 0) {
|
||||
val lastMessage = newMessages.last()
|
||||
if (lastMessage is ChatMessageText) {
|
||||
var newContent = "${lastMessage.content}${partialContent}"
|
||||
// TODO: special handling for deepseek to remove the <think> tag.
|
||||
|
||||
// Add "thinking" and "done thinking" around the thinking content.
|
||||
newContent = newContent
|
||||
.replace("<think>", "$START_THINKING\n")
|
||||
.replace("</think>", "\n$DONE_THINKING")
|
||||
|
||||
// Remove empty thinking content.
|
||||
val endThinkingIndex = newContent.indexOf(DONE_THINKING)
|
||||
if (endThinkingIndex >= 0) {
|
||||
val thinkingContent =
|
||||
newContent.substring(0, endThinkingIndex + DONE_THINKING.length)
|
||||
.replace(START_THINKING, "")
|
||||
.replace(DONE_THINKING, "")
|
||||
if (thinkingContent.isBlank()) {
|
||||
newContent = newContent.substring(endThinkingIndex + DONE_THINKING.length)
|
||||
}
|
||||
}
|
||||
|
||||
val newContent = processLlmResponse(response = "${lastMessage.content}${partialContent}")
|
||||
val newLastMessage = ChatMessageText(
|
||||
content = newContent,
|
||||
side = lastMessage.side,
|
||||
|
|
|
@ -45,7 +45,7 @@ fun MarkdownText(
|
|||
ProvideTextStyle(
|
||||
value = TextStyle(
|
||||
fontSize = fontSize,
|
||||
lineHeight = fontSize * 1.2,
|
||||
lineHeight = fontSize * 1.4,
|
||||
)
|
||||
) {
|
||||
RichText(
|
||||
|
|
|
@ -48,15 +48,16 @@ fun MessageActionButton(
|
|||
label: String,
|
||||
icon: ImageVector,
|
||||
onClick: () -> Unit,
|
||||
modifier: Modifier = Modifier,
|
||||
enabled: Boolean = true
|
||||
) {
|
||||
val modifier = Modifier
|
||||
val curModifier = modifier
|
||||
.padding(top = 4.dp)
|
||||
.clip(CircleShape)
|
||||
.background(if (enabled) MaterialTheme.colorScheme.secondaryContainer else MaterialTheme.colorScheme.surfaceContainerHigh)
|
||||
val alpha: Float = if (enabled) 1.0f else 0.3f
|
||||
Row(
|
||||
modifier = if (enabled) modifier.clickable { onClick() } else modifier,
|
||||
modifier = if (enabled) curModifier.clickable { onClick() } else modifier,
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
) {
|
||||
Icon(
|
||||
|
|
|
@ -19,8 +19,8 @@ package com.google.aiedge.gallery.ui.common.chat
|
|||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.layout.wrapContentWidth
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
|
@ -33,16 +33,14 @@ import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
|||
* This function renders benchmark statistics (e.g., various token speed) in data cards
|
||||
*/
|
||||
@Composable
|
||||
fun MessageBodyBenchmarkLlm(message: ChatMessageBenchmarkLlmResult) {
|
||||
fun MessageBodyBenchmarkLlm(message: ChatMessageBenchmarkLlmResult, modifier: Modifier = Modifier) {
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.padding(12.dp)
|
||||
.wrapContentWidth(),
|
||||
modifier = modifier.padding(12.dp),
|
||||
verticalArrangement = Arrangement.spacedBy(8.dp)
|
||||
) {
|
||||
// Data cards.
|
||||
Row(
|
||||
modifier = Modifier.wrapContentWidth(), horizontalArrangement = Arrangement.spacedBy(16.dp)
|
||||
modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.SpaceBetween
|
||||
) {
|
||||
for (stat in message.orderedStats) {
|
||||
DataCard(
|
||||
|
|
|
@ -82,7 +82,7 @@ fun MessageBodyPromptTemplates(
|
|||
style = MaterialTheme.typography.titleSmall,
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.offset(y = -4.dp),
|
||||
.offset(y = (-4).dp),
|
||||
textAlign = TextAlign.Center,
|
||||
)
|
||||
}
|
||||
|
@ -140,7 +140,7 @@ fun MessageBodyPromptTemplatesPreview() {
|
|||
for ((index, task) in ALL_PREVIEW_TASKS.withIndex()) {
|
||||
task.index = index
|
||||
for (model in task.models) {
|
||||
model.preProcess(task = task)
|
||||
model.preProcess()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -70,7 +70,6 @@ import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
|||
* This function renders a row containing a text field for message input and a send button.
|
||||
* It handles message composition, input validation, and sending messages.
|
||||
*/
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@Composable
|
||||
fun MessageInputText(
|
||||
modelManagerViewModel: ModelManagerViewModel,
|
||||
|
@ -190,7 +189,7 @@ fun MessageInputText(
|
|||
Icons.AutoMirrored.Rounded.Send,
|
||||
contentDescription = "",
|
||||
modifier = Modifier.offset(x = 2.dp),
|
||||
tint = if (inProgress) MaterialTheme.colorScheme.surfaceContainerHigh else MaterialTheme.colorScheme.primary
|
||||
tint = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,119 @@
|
|||
/*
|
||||
* Copyright 2025 Google LLC
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.google.aiedge.gallery.ui.common.chat
|
||||
|
||||
import androidx.compose.animation.AnimatedVisibility
|
||||
import androidx.compose.animation.fadeIn
|
||||
import androidx.compose.animation.fadeOut
|
||||
import androidx.compose.animation.scaleIn
|
||||
import androidx.compose.animation.scaleOut
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.fillMaxSize
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.ui.common.DownloadAndTryButton
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
import kotlinx.coroutines.delay
|
||||
|
||||
@Composable
|
||||
fun ModelDownloadStatusInfoPanel(
|
||||
model: Model,
|
||||
task: Task,
|
||||
modelManagerViewModel: ModelManagerViewModel
|
||||
) {
|
||||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||
|
||||
// Manages the conditional display of UI elements (download model button and downloading
|
||||
// animation) based on the corresponding download status.
|
||||
//
|
||||
// It uses delayed visibility ensuring they are shown only after a short delay if their
|
||||
// respective conditions remain true. This prevents UI flickering and provides a smoother
|
||||
// user experience.
|
||||
val curStatus = modelManagerUiState.modelDownloadStatus[model.name]
|
||||
var shouldShowDownloadingAnimation by remember { mutableStateOf(false) }
|
||||
var downloadingAnimationConditionMet by remember { mutableStateOf(false) }
|
||||
var shouldShowDownloadModelButton by remember { mutableStateOf(false) }
|
||||
var downloadModelButtonConditionMet by remember { mutableStateOf(false) }
|
||||
|
||||
downloadingAnimationConditionMet =
|
||||
curStatus?.status == ModelDownloadStatusType.IN_PROGRESS || curStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED || curStatus?.status == ModelDownloadStatusType.UNZIPPING
|
||||
downloadModelButtonConditionMet =
|
||||
curStatus?.status == ModelDownloadStatusType.FAILED || curStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED
|
||||
|
||||
LaunchedEffect(downloadingAnimationConditionMet) {
|
||||
if (downloadingAnimationConditionMet) {
|
||||
delay(100)
|
||||
shouldShowDownloadingAnimation = true
|
||||
} else {
|
||||
shouldShowDownloadingAnimation = false
|
||||
}
|
||||
}
|
||||
|
||||
LaunchedEffect(downloadModelButtonConditionMet) {
|
||||
if (downloadModelButtonConditionMet) {
|
||||
delay(700)
|
||||
shouldShowDownloadModelButton = true
|
||||
} else {
|
||||
shouldShowDownloadModelButton = false
|
||||
}
|
||||
}
|
||||
|
||||
AnimatedVisibility(
|
||||
visible = shouldShowDownloadingAnimation,
|
||||
enter = scaleIn(initialScale = 0.9f) + fadeIn(),
|
||||
exit = scaleOut(targetScale = 0.9f) + fadeOut()
|
||||
) {
|
||||
Box(
|
||||
modifier = Modifier.fillMaxSize(), contentAlignment = Alignment.Center
|
||||
) {
|
||||
ModelDownloadingAnimation(
|
||||
model = model, task = task, modelManagerViewModel = modelManagerViewModel
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
AnimatedVisibility(
|
||||
visible = shouldShowDownloadModelButton, enter = fadeIn(), exit = fadeOut()
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier.fillMaxSize(),
|
||||
verticalArrangement = Arrangement.Center,
|
||||
horizontalAlignment = Alignment.CenterHorizontally
|
||||
) {
|
||||
DownloadAndTryButton(
|
||||
task = task,
|
||||
model = model,
|
||||
enabled = true,
|
||||
needToDownloadFirst = true,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
onClicked = {}
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -24,6 +24,7 @@ import androidx.compose.foundation.layout.Arrangement
|
|||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.height
|
||||
import androidx.compose.foundation.layout.offset
|
||||
import androidx.compose.foundation.layout.padding
|
||||
|
@ -32,23 +33,40 @@ import androidx.compose.foundation.layout.width
|
|||
import androidx.compose.foundation.lazy.grid.GridCells
|
||||
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
|
||||
import androidx.compose.foundation.lazy.grid.itemsIndexed
|
||||
import androidx.compose.material3.LinearProgressIndicator
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.derivedStateOf
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.graphics.ColorFilter
|
||||
import androidx.compose.ui.graphics.graphicsLayer
|
||||
import androidx.compose.ui.layout.ContentScale
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.res.painterResource
|
||||
import androidx.compose.ui.text.style.TextAlign
|
||||
import androidx.compose.ui.text.style.TextOverflow
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.unit.sp
|
||||
import com.google.aiedge.gallery.R
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.ui.common.formatToHourMinSecond
|
||||
import com.google.aiedge.gallery.ui.common.getTaskIconColor
|
||||
import com.google.aiedge.gallery.ui.common.humanReadableSize
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.preview.MODEL_TEST1
|
||||
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
|
||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||
import com.google.aiedge.gallery.ui.theme.labelSmallNarrow
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlin.math.cos
|
||||
import kotlin.math.pow
|
||||
|
@ -66,8 +84,19 @@ private const val END_SCALE = 0.6f
|
|||
* scaling and rotation effect.
|
||||
*/
|
||||
@Composable
|
||||
fun ModelDownloadingAnimation() {
|
||||
fun ModelDownloadingAnimation(
|
||||
model: Model,
|
||||
task: Task,
|
||||
modelManagerViewModel: ModelManagerViewModel
|
||||
) {
|
||||
val scale = remember { Animatable(END_SCALE) }
|
||||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||
val downloadStatus by remember {
|
||||
derivedStateOf { modelManagerUiState.modelDownloadStatus[model.name] }
|
||||
}
|
||||
val inProgress = downloadStatus?.status == ModelDownloadStatusType.IN_PROGRESS
|
||||
val isPartiallyDownloaded = downloadStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED
|
||||
var curDownloadProgress = 0f
|
||||
|
||||
LaunchedEffect(Unit) { // Run this once
|
||||
while (true) {
|
||||
|
@ -93,6 +122,21 @@ fun ModelDownloadingAnimation() {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// Failure message.
|
||||
val curDownloadStatus = downloadStatus
|
||||
if (curDownloadStatus != null && curDownloadStatus.status == ModelDownloadStatusType.FAILED) {
|
||||
Row(verticalAlignment = Alignment.CenterVertically) {
|
||||
Text(
|
||||
curDownloadStatus.errorMessage,
|
||||
color = MaterialTheme.colorScheme.error,
|
||||
style = labelSmallNarrow,
|
||||
overflow = TextOverflow.Ellipsis,
|
||||
)
|
||||
}
|
||||
}
|
||||
// No failure
|
||||
else {
|
||||
Column(
|
||||
horizontalAlignment = Alignment.CenterHorizontally,
|
||||
modifier = Modifier.offset(y = -GRID_SIZE / 8)
|
||||
|
@ -146,6 +190,78 @@ fun ModelDownloadingAnimation() {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// Download stats
|
||||
var sizeLabel = model.totalBytes.humanReadableSize()
|
||||
if (curDownloadStatus != null) {
|
||||
// For in-progress model, show {receivedSize} / {totalSize} - {rate} - {remainingTime}
|
||||
if (inProgress || isPartiallyDownloaded) {
|
||||
var totalSize = curDownloadStatus.totalBytes
|
||||
if (totalSize == 0L) {
|
||||
totalSize = model.totalBytes
|
||||
}
|
||||
sizeLabel =
|
||||
"${curDownloadStatus.receivedBytes.humanReadableSize(extraDecimalForGbAndAbove = true)} of ${totalSize.humanReadableSize()}"
|
||||
if (curDownloadStatus.bytesPerSecond > 0) {
|
||||
sizeLabel =
|
||||
"$sizeLabel · ${curDownloadStatus.bytesPerSecond.humanReadableSize()} / s"
|
||||
if (curDownloadStatus.remainingMs >= 0) {
|
||||
sizeLabel =
|
||||
"$sizeLabel · ${curDownloadStatus.remainingMs.formatToHourMinSecond()} left"
|
||||
}
|
||||
}
|
||||
if (isPartiallyDownloaded) {
|
||||
sizeLabel = "$sizeLabel (resuming...)"
|
||||
}
|
||||
curDownloadProgress =
|
||||
curDownloadStatus.receivedBytes.toFloat() / curDownloadStatus.totalBytes.toFloat()
|
||||
if (curDownloadProgress.isNaN()) {
|
||||
curDownloadProgress = 0f
|
||||
}
|
||||
}
|
||||
// Status for unzipping.
|
||||
else if (curDownloadStatus.status == ModelDownloadStatusType.UNZIPPING) {
|
||||
sizeLabel = "Unzipping..."
|
||||
}
|
||||
Text(
|
||||
sizeLabel,
|
||||
color = MaterialTheme.colorScheme.secondary,
|
||||
style = labelSmallNarrow.copy(fontSize = 9.sp, lineHeight = 10.sp),
|
||||
textAlign = TextAlign.Center,
|
||||
overflow = TextOverflow.Visible,
|
||||
modifier = Modifier
|
||||
.padding(bottom = 4.dp)
|
||||
)
|
||||
}
|
||||
|
||||
// Download progress.
|
||||
if (inProgress || isPartiallyDownloaded) {
|
||||
val animatedProgress = remember { Animatable(0f) }
|
||||
LinearProgressIndicator(
|
||||
progress = { animatedProgress.value },
|
||||
color = getTaskIconColor(task = task),
|
||||
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(bottom = 36.dp)
|
||||
.padding(horizontal = 36.dp)
|
||||
)
|
||||
LaunchedEffect(curDownloadProgress) {
|
||||
animatedProgress.animateTo(curDownloadProgress, animationSpec = tween(150))
|
||||
}
|
||||
}
|
||||
// Unzipping progress.
|
||||
else if (downloadStatus?.status == ModelDownloadStatusType.UNZIPPING) {
|
||||
LinearProgressIndicator(
|
||||
color = getTaskIconColor(task = task),
|
||||
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(bottom = 36.dp)
|
||||
.padding(horizontal = 36.dp)
|
||||
)
|
||||
}
|
||||
|
||||
Text(
|
||||
"Feel free to switch apps or lock your device.\n"
|
||||
+ "The download will continue in the background.\n"
|
||||
|
@ -154,6 +270,8 @@ fun ModelDownloadingAnimation() {
|
|||
textAlign = TextAlign.Center
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Custom Easing function for a multi-bounce effect
|
||||
|
@ -168,9 +286,15 @@ fun multiBounceEasing(bounces: Int, decay: Float): Easing = Easing { x ->
|
|||
@Preview(showBackground = true)
|
||||
@Composable
|
||||
fun ModelDownloadingAnimationPreview() {
|
||||
val context = LocalContext.current
|
||||
|
||||
GalleryTheme {
|
||||
Row(modifier = Modifier.padding(16.dp)) {
|
||||
ModelDownloadingAnimation()
|
||||
ModelDownloadingAnimation(
|
||||
model = MODEL_TEST1,
|
||||
task = TASK_TEST1,
|
||||
modelManagerViewModel = PreviewModelManagerViewModel(context = context)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -63,7 +63,8 @@ fun ModelSelector(
|
|||
) {
|
||||
Box(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth().padding(bottom = 8.dp),
|
||||
.fillMaxWidth()
|
||||
.padding(bottom = 8.dp),
|
||||
contentAlignment = Alignment.Center
|
||||
) {
|
||||
// Model row.
|
||||
|
@ -134,7 +135,12 @@ fun ModelSelector(
|
|||
|
||||
// Force to re-initialize the model with the new configs.
|
||||
if (needReinitialization) {
|
||||
modelManagerViewModel.initializeModel(context = context, model = model, force = true)
|
||||
modelManagerViewModel.initializeModel(
|
||||
context = context,
|
||||
task = task,
|
||||
model = model,
|
||||
force = true
|
||||
)
|
||||
}
|
||||
|
||||
// Notify.
|
||||
|
|
|
@ -181,7 +181,5 @@ fun ModelNameAndStatus(
|
|||
.padding(top = 2.dp),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,8 +23,10 @@ import androidx.compose.foundation.layout.size
|
|||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.automirrored.outlined.HelpOutline
|
||||
import androidx.compose.material.icons.filled.DownloadForOffline
|
||||
import androidx.compose.material.icons.rounded.Downloading
|
||||
import androidx.compose.material.icons.rounded.Error
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
|
@ -34,6 +36,7 @@ import androidx.compose.ui.unit.dp
|
|||
import com.google.aiedge.gallery.data.ModelDownloadStatus
|
||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||
import com.google.aiedge.gallery.ui.theme.customColors
|
||||
|
||||
/**
|
||||
* Composable function to display an icon representing the download status of a model.
|
||||
|
@ -56,7 +59,7 @@ fun StatusIcon(downloadStatus: ModelDownloadStatus?, modifier: Modifier = Modifi
|
|||
ModelDownloadStatusType.SUCCEEDED -> {
|
||||
Icon(
|
||||
Icons.Filled.DownloadForOffline,
|
||||
tint = Color(0xff3d860b),
|
||||
tint = MaterialTheme.customColors.successColor,
|
||||
contentDescription = "",
|
||||
modifier = Modifier.size(14.dp)
|
||||
)
|
||||
|
@ -69,6 +72,12 @@ fun StatusIcon(downloadStatus: ModelDownloadStatus?, modifier: Modifier = Modifi
|
|||
modifier = Modifier.size(14.dp)
|
||||
)
|
||||
|
||||
ModelDownloadStatusType.IN_PROGRESS -> Icon(
|
||||
Icons.Rounded.Downloading,
|
||||
contentDescription = "",
|
||||
modifier = Modifier.size(14.dp)
|
||||
)
|
||||
|
||||
else -> {}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -73,7 +73,6 @@ import androidx.compose.ui.Modifier
|
|||
import androidx.compose.ui.draw.alpha
|
||||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.draw.scale
|
||||
import androidx.compose.ui.focus.focusModifier
|
||||
import androidx.compose.ui.graphics.Brush
|
||||
import androidx.compose.ui.input.nestedscroll.nestedScroll
|
||||
import androidx.compose.ui.layout.layout
|
||||
|
@ -91,7 +90,6 @@ import com.google.aiedge.gallery.data.AppBarAction
|
|||
import com.google.aiedge.gallery.data.AppBarActionType
|
||||
import com.google.aiedge.gallery.data.ConfigKey
|
||||
import com.google.aiedge.gallery.data.ImportedModelInfo
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.ui.common.TaskIcon
|
||||
import com.google.aiedge.gallery.ui.common.getTaskBgColor
|
||||
|
@ -275,7 +273,6 @@ fun HomeScreen(
|
|||
onDismiss = { showImportingDialog = false },
|
||||
onDone = {
|
||||
modelManagerViewModel.addImportedLlmModel(
|
||||
task = TASK_LLM_CHAT,
|
||||
info = it,
|
||||
)
|
||||
showImportingDialog = false
|
||||
|
|
|
@ -30,7 +30,7 @@ private const val TAG = "AGLlmChatModelHelper"
|
|||
typealias ResultListener = (partialResult: String, done: Boolean) -> Unit
|
||||
typealias CleanUpListener = () -> Unit
|
||||
|
||||
data class LlmModelInstance(val engine: LlmInference, val session: LlmInferenceSession)
|
||||
data class LlmModelInstance(val engine: LlmInference, var session: LlmInferenceSession)
|
||||
|
||||
object LlmChatModelHelper {
|
||||
// Indexed by model name.
|
||||
|
@ -74,6 +74,24 @@ object LlmChatModelHelper {
|
|||
onDone("")
|
||||
}
|
||||
|
||||
fun resetSession(model: Model) {
|
||||
val instance = model.instance as LlmModelInstance? ?: return
|
||||
val session = instance.session
|
||||
session.close()
|
||||
|
||||
val inference = instance.engine
|
||||
val topK = model.getIntConfigValue(key = ConfigKey.TOPK, defaultValue = DEFAULT_TOPK)
|
||||
val topP = model.getFloatConfigValue(key = ConfigKey.TOPP, defaultValue = DEFAULT_TOPP)
|
||||
val temperature =
|
||||
model.getFloatConfigValue(key = ConfigKey.TEMPERATURE, defaultValue = DEFAULT_TEMPERATURE)
|
||||
val newSession = LlmInferenceSession.createFromOptions(
|
||||
inference,
|
||||
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
|
||||
.setTemperature(temperature).build()
|
||||
)
|
||||
instance.session = newSession
|
||||
}
|
||||
|
||||
fun cleanUp(model: Model) {
|
||||
if (model.instance == null) {
|
||||
return
|
||||
|
@ -99,7 +117,11 @@ object LlmChatModelHelper {
|
|||
input: String,
|
||||
resultListener: ResultListener,
|
||||
cleanUpListener: CleanUpListener,
|
||||
singleTurn: Boolean = false,
|
||||
) {
|
||||
if (singleTurn) {
|
||||
resetSession(model = model)
|
||||
}
|
||||
val instance = model.instance as LlmModelInstance
|
||||
|
||||
// Set listener.
|
||||
|
|
|
@ -32,7 +32,7 @@ import kotlinx.coroutines.launch
|
|||
|
||||
private const val TAG = "AGLlmChatViewModel"
|
||||
private val STATS = listOf(
|
||||
Stat(id = "time_to_first_token", label = "Time to 1st token", unit = "sec"),
|
||||
Stat(id = "time_to_first_token", label = "1st token", unit = "sec"),
|
||||
Stat(id = "prefill_speed", label = "Prefill speed", unit = "tokens/s"),
|
||||
Stat(id = "decode_speed", label = "Decode speed", unit = "tokens/s"),
|
||||
Stat(id = "latency", label = "Latency", unit = "sec")
|
||||
|
|
|
@ -0,0 +1,206 @@
|
|||
/*
|
||||
* Copyright 2025 Google LLC
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.google.aiedge.gallery.ui.llmsingleturn
|
||||
|
||||
import android.util.Log
|
||||
import androidx.activity.compose.BackHandler
|
||||
import androidx.compose.animation.AnimatedVisibility
|
||||
import androidx.compose.animation.fadeIn
|
||||
import androidx.compose.animation.fadeOut
|
||||
import androidx.compose.animation.scaleIn
|
||||
import androidx.compose.animation.scaleOut
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.calculateStartPadding
|
||||
import androidx.compose.foundation.layout.fillMaxSize
|
||||
import androidx.compose.foundation.layout.offset
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.Scaffold
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.rememberCoroutineScope
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.platform.LocalLayoutDirection
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
import androidx.lifecycle.viewmodel.compose.viewModel
|
||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||
import com.google.aiedge.gallery.ui.ViewModelProvider
|
||||
import com.google.aiedge.gallery.ui.common.ModelPageAppBar
|
||||
import com.google.aiedge.gallery.ui.common.chat.ModelDownloadStatusInfoPanel
|
||||
import com.google.aiedge.gallery.ui.common.chat.ModelInitializationStatusChip
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.preview.PreviewLlmSingleTurnViewModel
|
||||
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||
import com.google.aiedge.gallery.ui.theme.customColors
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.serialization.Serializable
|
||||
|
||||
/** Navigation destination data */
|
||||
object LlmSingleTurnDestination {
|
||||
@Serializable
|
||||
val route = "LlmSingleTurnRoute"
|
||||
}
|
||||
|
||||
private const val TAG = "AGLlmSingleTurnScreen"
|
||||
|
||||
@Composable
|
||||
fun LlmSingleTurnScreen(
|
||||
modelManagerViewModel: ModelManagerViewModel,
|
||||
navigateUp: () -> Unit,
|
||||
modifier: Modifier = Modifier,
|
||||
viewModel: LlmSingleTurnViewModel = viewModel(
|
||||
factory = ViewModelProvider.Factory
|
||||
),
|
||||
) {
|
||||
val task = viewModel.task
|
||||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||
val selectedModel = modelManagerUiState.selectedModel
|
||||
val scope = rememberCoroutineScope()
|
||||
val context = LocalContext.current
|
||||
|
||||
val handleNavigateUp = {
|
||||
navigateUp()
|
||||
|
||||
// clean up all models.
|
||||
scope.launch(Dispatchers.Default) {
|
||||
for (model in task.models) {
|
||||
modelManagerViewModel.cleanupModel(task = task, model = model)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle system's edge swipe.
|
||||
BackHandler {
|
||||
handleNavigateUp()
|
||||
}
|
||||
|
||||
// Initialize model when model/download state changes.
|
||||
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name]
|
||||
LaunchedEffect(curDownloadStatus, selectedModel.name) {
|
||||
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
|
||||
Log.d(
|
||||
TAG,
|
||||
"Initializing model '${selectedModel.name}' from LlmsingleTurnScreen launched effect"
|
||||
)
|
||||
modelManagerViewModel.initializeModel(context, task = task, model = selectedModel)
|
||||
}
|
||||
}
|
||||
|
||||
val modelInitializationStatus =
|
||||
modelManagerUiState.modelInitializationStatus[selectedModel.name]
|
||||
|
||||
Scaffold(modifier = modifier, topBar = {
|
||||
ModelPageAppBar(
|
||||
task = task,
|
||||
model = selectedModel,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
onConfigChanged = { _, _ -> },
|
||||
onBackClicked = { handleNavigateUp() },
|
||||
onModelSelected = { newSelectedModel ->
|
||||
scope.launch(Dispatchers.Default) {
|
||||
// Clean up current model.
|
||||
modelManagerViewModel.cleanupModel(task = task, model = selectedModel)
|
||||
|
||||
// Update selected model.
|
||||
modelManagerViewModel.selectModel(model = newSelectedModel)
|
||||
}
|
||||
}
|
||||
)
|
||||
}) { innerPadding ->
|
||||
Column(
|
||||
modifier = Modifier.padding(
|
||||
top = innerPadding.calculateTopPadding(),
|
||||
start = innerPadding.calculateStartPadding(LocalLayoutDirection.current),
|
||||
end = innerPadding.calculateStartPadding(LocalLayoutDirection.current),
|
||||
)
|
||||
) {
|
||||
ModelDownloadStatusInfoPanel(
|
||||
model = selectedModel,
|
||||
task = task,
|
||||
modelManagerViewModel = modelManagerViewModel
|
||||
)
|
||||
|
||||
// Main UI after model is downloaded.
|
||||
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
|
||||
Box(
|
||||
contentAlignment = Alignment.BottomCenter,
|
||||
modifier = Modifier.weight(1f)
|
||||
) {
|
||||
VerticalSplitView(modifier = Modifier.fillMaxSize(),
|
||||
topView = {
|
||||
PromptTemplatesPanel(
|
||||
model = selectedModel,
|
||||
viewModel = viewModel,
|
||||
onSend = { fullPrompt ->
|
||||
viewModel.generateResponse(model = selectedModel, input = fullPrompt)
|
||||
}, modifier = Modifier.fillMaxSize()
|
||||
)
|
||||
},
|
||||
bottomView = {
|
||||
Box(
|
||||
contentAlignment = Alignment.BottomCenter,
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.background(MaterialTheme.customColors.agentBubbleBgColor)
|
||||
) {
|
||||
ResponsePanel(
|
||||
model = selectedModel,
|
||||
viewModel = viewModel,
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.padding(bottom = innerPadding.calculateBottomPadding())
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
// Model initialization in-progress message.
|
||||
this@Column.AnimatedVisibility(
|
||||
visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
|
||||
enter = scaleIn() + fadeIn(),
|
||||
exit = scaleOut() + fadeOut(),
|
||||
modifier = Modifier.offset(y = -innerPadding.calculateBottomPadding())
|
||||
) {
|
||||
ModelInitializationStatusChip()
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Preview(showBackground = true)
|
||||
@Composable
|
||||
fun LlmSingleTurnScreenPreview() {
|
||||
val context = LocalContext.current
|
||||
GalleryTheme {
|
||||
LlmSingleTurnScreen(
|
||||
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
||||
viewModel = PreviewLlmSingleTurnViewModel(),
|
||||
navigateUp = {},
|
||||
)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,210 @@
|
|||
/*
|
||||
* Copyright 2025 Google LLC
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.google.aiedge.gallery.ui.llmsingleturn
|
||||
|
||||
import android.util.Log
|
||||
import androidx.lifecycle.ViewModel
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
|
||||
import com.google.aiedge.gallery.ui.common.chat.Stat
|
||||
import com.google.aiedge.gallery.ui.common.processLlmResponse
|
||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
|
||||
import com.google.aiedge.gallery.ui.llmchat.LlmModelInstance
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.asStateFlow
|
||||
import kotlinx.coroutines.flow.update
|
||||
import kotlinx.coroutines.launch
|
||||
|
||||
private const val TAG = "AGLlmSingleTurnViewModel"
|
||||
|
||||
data class LlmSingleTurnUiState(
|
||||
/**
|
||||
* Indicates whether the runtime is currently processing a message.
|
||||
*/
|
||||
val inProgress: Boolean = false,
|
||||
|
||||
/**
|
||||
* Indicates whether the model is currently being initialized.
|
||||
*/
|
||||
val initializing: Boolean = false,
|
||||
|
||||
// model -> <template label -> response>
|
||||
val responsesByModel: Map<String, Map<String, String>>,
|
||||
|
||||
// model -> <template label -> benchmark result>
|
||||
val benchmarkByModel: Map<String, Map<String, ChatMessageBenchmarkLlmResult>>,
|
||||
|
||||
/** Selected prompt template type. */
|
||||
val selectedPromptTemplateType: PromptTemplateType = PromptTemplateType.entries[0],
|
||||
)
|
||||
|
||||
private val STATS = listOf(
|
||||
Stat(id = "time_to_first_token", label = "1st token", unit = "sec"),
|
||||
Stat(id = "prefill_speed", label = "Prefill speed", unit = "tokens/s"),
|
||||
Stat(id = "decode_speed", label = "Decode speed", unit = "tokens/s"),
|
||||
Stat(id = "latency", label = "Latency", unit = "sec")
|
||||
)
|
||||
|
||||
open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_SINGLE_TURN) : ViewModel() {
|
||||
private val _uiState = MutableStateFlow(createUiState(task = task))
|
||||
val uiState = _uiState.asStateFlow()
|
||||
|
||||
fun generateResponse(model: Model, input: String) {
|
||||
viewModelScope.launch(Dispatchers.Default) {
|
||||
setInProgress(true)
|
||||
setInitializing(true)
|
||||
|
||||
// Wait for instance to be initialized.
|
||||
while (model.instance == null) {
|
||||
delay(100)
|
||||
}
|
||||
|
||||
// Run inference.
|
||||
val instance = model.instance as LlmModelInstance
|
||||
val prefillTokens = instance.session.sizeInTokens(input)
|
||||
|
||||
var firstRun = true
|
||||
var timeToFirstToken = 0f
|
||||
var firstTokenTs = 0L
|
||||
var decodeTokens = 0
|
||||
var prefillSpeed = 0f
|
||||
var decodeSpeed: Float
|
||||
val start = System.currentTimeMillis()
|
||||
var response = ""
|
||||
var lastBenchmarkUpdateTs = 0L
|
||||
LlmChatModelHelper.runInference(model = model,
|
||||
input = input,
|
||||
resultListener = { partialResult, done ->
|
||||
val curTs = System.currentTimeMillis()
|
||||
|
||||
if (firstRun) {
|
||||
setInitializing(false)
|
||||
firstTokenTs = System.currentTimeMillis()
|
||||
timeToFirstToken = (firstTokenTs - start) / 1000f
|
||||
prefillSpeed = prefillTokens / timeToFirstToken
|
||||
firstRun = false
|
||||
} else {
|
||||
decodeTokens++
|
||||
}
|
||||
|
||||
// Incrementally update the streamed partial results.
|
||||
response = processLlmResponse(response = "$response$partialResult")
|
||||
|
||||
// Update response.
|
||||
updateResponse(
|
||||
model = model,
|
||||
promptTemplateType = uiState.value.selectedPromptTemplateType,
|
||||
response = response
|
||||
)
|
||||
|
||||
// Update benchmark (with throttling).
|
||||
if (curTs - lastBenchmarkUpdateTs > 200) {
|
||||
decodeSpeed = decodeTokens / ((curTs - firstTokenTs) / 1000f)
|
||||
if (decodeSpeed.isNaN()) {
|
||||
decodeSpeed = 0f
|
||||
}
|
||||
val benchmark = ChatMessageBenchmarkLlmResult(
|
||||
orderedStats = STATS,
|
||||
statValues = mutableMapOf(
|
||||
"prefill_speed" to prefillSpeed,
|
||||
"decode_speed" to decodeSpeed,
|
||||
"time_to_first_token" to timeToFirstToken,
|
||||
"latency" to (curTs - start).toFloat() / 1000f,
|
||||
),
|
||||
running = !done,
|
||||
latencyMs = -1f,
|
||||
)
|
||||
updateBenchmark(
|
||||
model = model,
|
||||
promptTemplateType = uiState.value.selectedPromptTemplateType,
|
||||
benchmark = benchmark
|
||||
)
|
||||
lastBenchmarkUpdateTs = curTs
|
||||
}
|
||||
|
||||
if (done) {
|
||||
setInProgress(false)
|
||||
}
|
||||
},
|
||||
singleTurn = true,
|
||||
cleanUpListener = {
|
||||
setInitializing(false)
|
||||
setInProgress(false)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fun selectPromptTemplate(model: Model, promptTemplateType: PromptTemplateType) {
|
||||
Log.d(TAG, "selecting prompt template: ${promptTemplateType.label}")
|
||||
|
||||
// Clear response.
|
||||
updateResponse(model = model, promptTemplateType = promptTemplateType, response = "")
|
||||
|
||||
this._uiState.update { this.uiState.value.copy(selectedPromptTemplateType = promptTemplateType) }
|
||||
}
|
||||
|
||||
fun setInProgress(inProgress: Boolean) {
|
||||
_uiState.update { _uiState.value.copy(inProgress = inProgress) }
|
||||
}
|
||||
|
||||
fun setInitializing(initializing: Boolean) {
|
||||
_uiState.update { _uiState.value.copy(initializing = initializing) }
|
||||
}
|
||||
|
||||
fun updateResponse(model: Model, promptTemplateType: PromptTemplateType, response: String) {
|
||||
_uiState.update { currentState ->
|
||||
val currentResponses = currentState.responsesByModel
|
||||
val modelResponses = currentResponses[model.name]?.toMutableMap() ?: mutableMapOf()
|
||||
modelResponses[promptTemplateType.label] = response
|
||||
val newResponses = currentResponses.toMutableMap()
|
||||
newResponses[model.name] = modelResponses
|
||||
currentState.copy(responsesByModel = newResponses)
|
||||
}
|
||||
}
|
||||
|
||||
fun updateBenchmark(
|
||||
model: Model, promptTemplateType: PromptTemplateType, benchmark: ChatMessageBenchmarkLlmResult
|
||||
) {
|
||||
_uiState.update { currentState ->
|
||||
val currentBenchmark = currentState.benchmarkByModel
|
||||
val modelBenchmarks = currentBenchmark[model.name]?.toMutableMap() ?: mutableMapOf()
|
||||
modelBenchmarks[promptTemplateType.label] = benchmark
|
||||
val newBenchmarks = currentBenchmark.toMutableMap()
|
||||
newBenchmarks[model.name] = modelBenchmarks
|
||||
currentState.copy(benchmarkByModel = newBenchmarks)
|
||||
}
|
||||
}
|
||||
|
||||
private fun createUiState(task: Task): LlmSingleTurnUiState {
|
||||
val responsesByModel: MutableMap<String, Map<String, String>> = mutableMapOf()
|
||||
val benchmarkByModel: MutableMap<String, Map<String, ChatMessageBenchmarkLlmResult>> =
|
||||
mutableMapOf()
|
||||
for (model in task.models) {
|
||||
responsesByModel[model.name] = mutableMapOf()
|
||||
benchmarkByModel[model.name] = mutableMapOf()
|
||||
}
|
||||
return LlmSingleTurnUiState(
|
||||
responsesByModel = responsesByModel,
|
||||
benchmarkByModel = benchmarkByModel,
|
||||
)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,185 @@
|
|||
/*
|
||||
* Copyright 2025 Google LLC
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.google.aiedge.gallery.ui.llmsingleturn
|
||||
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.text.AnnotatedString
|
||||
import androidx.compose.ui.text.buildAnnotatedString
|
||||
import androidx.compose.ui.text.withStyle
|
||||
import androidx.compose.ui.text.SpanStyle
|
||||
import androidx.compose.ui.graphics.Brush.Companion.linearGradient
|
||||
|
||||
enum class PromptTemplateInputEditorType {
|
||||
SINGLE_SELECT
|
||||
}
|
||||
|
||||
enum class RewriteToneType(val label: String) {
|
||||
FORMAL(label = "Formal"), CASUAL(label = "Casual"), FRIENDLY(label = "Friendly"), POLITE(label = "Polite"), ENTHUSIASTIC(
|
||||
label = "Enthusiastic"
|
||||
),
|
||||
CONCISE(label = "Concise"),
|
||||
}
|
||||
|
||||
enum class SummarizationType(val label: String) {
|
||||
KEY_BULLET_POINT(label = "Key bullet points (3-5)"),
|
||||
SHORT_PARAGRAPH(label = "Short paragraph (1-2 sentences)"),
|
||||
CONCISE_SUMMARY(label = "Concise summary (~50 words)"),
|
||||
HEADLINE_TITLE(label = "Headline / title"),
|
||||
ONE_SENTENCE_SUMMARY(label = "One-sentence summary"),
|
||||
}
|
||||
|
||||
enum class LanguageType(val label: String) {
|
||||
CPP(label = "C++"),
|
||||
JAVA(label = "Java"),
|
||||
JAVASCRIPT(label = "JavaScript"),
|
||||
KOTLIN(label = "Kotlin"),
|
||||
PYTHON(label = "Python"),
|
||||
SWIFT(label = "Swift"),
|
||||
TYPESCRIPT(label = "TypeScript"),
|
||||
}
|
||||
|
||||
enum class InputEditorLabel(val label: String) {
|
||||
TONE(label = "Tone"),
|
||||
STYLE(label = "Style"),
|
||||
LANGUAGE(label = "Language"),
|
||||
}
|
||||
|
||||
open class PromptTemplateInputEditor(
|
||||
open val label: String,
|
||||
open val type: PromptTemplateInputEditorType,
|
||||
open val defaultOption: String = "",
|
||||
)
|
||||
|
||||
/** Single select that shows options in bottom sheet. */
|
||||
class PromptTemplateSingleSelectInputEditor(
|
||||
override val label: String,
|
||||
val options: List<String> = listOf(),
|
||||
override val defaultOption: String = "",
|
||||
) : PromptTemplateInputEditor(
|
||||
label = label, type = PromptTemplateInputEditorType.SINGLE_SELECT, defaultOption = defaultOption
|
||||
)
|
||||
|
||||
data class PromptTemplateConfig(val inputEditors: List<PromptTemplateInputEditor> = listOf())
|
||||
|
||||
private val GEMINI_GRADIENT_STYLE = SpanStyle(
|
||||
brush = linearGradient(
|
||||
colors = listOf(Color(0xFF4285f4), Color(0xFF9b72cb), Color(0xFFd96570))
|
||||
)
|
||||
)
|
||||
|
||||
enum class PromptTemplateType(
|
||||
val label: String,
|
||||
val config: PromptTemplateConfig,
|
||||
val genFullPrompt: (userInput: String, inputEditorValues: Map<String, Any>) -> AnnotatedString = { _, _ ->
|
||||
AnnotatedString("")
|
||||
},
|
||||
val examplePrompts: List<String> = listOf(),
|
||||
) {
|
||||
FREE_FORM(
|
||||
label = "Free form",
|
||||
config = PromptTemplateConfig(),
|
||||
genFullPrompt = { userInput, _ -> AnnotatedString(userInput) },
|
||||
examplePrompts = listOf(
|
||||
"Suggest 3 topics for a podcast about \"Friendships in your 20s\".",
|
||||
"Outline the key sections needed in a basic logo design brief.",
|
||||
"List 3 pros and 3 cons to consider before buying a smart watch.",
|
||||
"Write a short, optimistic quote about the future of technology.",
|
||||
"Generate 3 potential names for a mobile app that helps users identify plants.",
|
||||
"Explain the difference between AI and machine learning in 2 sentences.",
|
||||
"Create a simple haiku about a cat sleeping in the sun.",
|
||||
"List 3 ways to make instant noodles taste better using common kitchen ingredients."
|
||||
)
|
||||
),
|
||||
REWRITE_TONE(
|
||||
label = "Rewrite tone", config = PromptTemplateConfig(
|
||||
inputEditors = listOf(
|
||||
PromptTemplateSingleSelectInputEditor(
|
||||
label = InputEditorLabel.TONE.label,
|
||||
options = RewriteToneType.entries.map { it.label },
|
||||
defaultOption = RewriteToneType.FORMAL.label
|
||||
)
|
||||
)
|
||||
), genFullPrompt = { userInput, inputEditorValues ->
|
||||
val tone = inputEditorValues[InputEditorLabel.TONE.label] as String
|
||||
buildAnnotatedString {
|
||||
withStyle(GEMINI_GRADIENT_STYLE) {
|
||||
append("Rewrite the following text using a ${tone.lowercase()} tone: ")
|
||||
}
|
||||
append(userInput)
|
||||
}
|
||||
}, examplePrompts = listOf(
|
||||
"Hey team, just wanted to remind everyone about the meeting tomorrow @ 10. Be there!",
|
||||
"Our new software update includes several bug fixes and performance improvements.",
|
||||
"Due to the fact that the weather was bad, we decided to postpone the event.",
|
||||
"Please find attached the requested documentation for your perusal.",
|
||||
"Welcome to the team. Review the onboarding materials.",
|
||||
)
|
||||
),
|
||||
SUMMARIZE_TEXT(
|
||||
label = "Summarize text",
|
||||
config = PromptTemplateConfig(
|
||||
inputEditors = listOf(
|
||||
PromptTemplateSingleSelectInputEditor(
|
||||
label = InputEditorLabel.STYLE.label,
|
||||
options = SummarizationType.entries.map { it.label },
|
||||
defaultOption = SummarizationType.KEY_BULLET_POINT.label
|
||||
)
|
||||
)
|
||||
),
|
||||
genFullPrompt = { userInput, inputEditorValues ->
|
||||
val style = inputEditorValues[InputEditorLabel.STYLE.label] as String
|
||||
buildAnnotatedString {
|
||||
withStyle(GEMINI_GRADIENT_STYLE) {
|
||||
append("Please summarize the following in ${style.lowercase()}: ")
|
||||
}
|
||||
append(userInput)
|
||||
}
|
||||
},
|
||||
examplePrompts = listOf(
|
||||
"The new Pixel phone features an advanced camera system with improved low-light performance and AI-powered editing tools. The display is brighter and more energy-efficient. It runs on the latest Tensor chip, offering faster processing and enhanced security features. Battery life has also been extended, providing all-day power for most users.",
|
||||
"Beginning this Friday, January 24, giant pandas Bao Li and Qing Bao are officially on view to the public at the Smithsonian’s National Zoo and Conservation Biology Institute (NZCBI). The 3-year-old bears arrived in Washington this past October, undergoing a quarantine period before making their debut. Under NZCBI’s new agreement with the CWCA, Qing Bao and Bao Li will remain in the United States for ten years, until April 2034, in exchange for an annual fee of \$1 million. The pair are still too young to breed, as pandas only reach sexual maturity between ages 4 and 7. “Kind of picture them as like awkward teenagers right now,” Lally told WUSA9. “We still have about two years before we would probably even see signs that they’re ready to start mating.”",
|
||||
),
|
||||
),
|
||||
CODE_SNIPPET(
|
||||
label = "Code snippet",
|
||||
config = PromptTemplateConfig(
|
||||
inputEditors = listOf(
|
||||
PromptTemplateSingleSelectInputEditor(
|
||||
label = InputEditorLabel.LANGUAGE.label,
|
||||
options = LanguageType.entries.map { it.label },
|
||||
defaultOption = LanguageType.JAVASCRIPT.label
|
||||
)
|
||||
)
|
||||
),
|
||||
genFullPrompt = { userInput, inputEditorValues ->
|
||||
val language = inputEditorValues[InputEditorLabel.LANGUAGE.label] as String
|
||||
buildAnnotatedString {
|
||||
withStyle(GEMINI_GRADIENT_STYLE) {
|
||||
append("Write a $language code snippet to ")
|
||||
}
|
||||
append(userInput)
|
||||
}
|
||||
},
|
||||
examplePrompts = listOf(
|
||||
"Create an alert box that says \"Hello, World!\"",
|
||||
"Declare an immutable variable named 'appName' with the value \"AI Gallery\"",
|
||||
"Print the numbers from 1 to 5 using a for loop.",
|
||||
"Write a function that returns the square of an integer input.",
|
||||
),
|
||||
),
|
||||
}
|
||||
|
|
@ -0,0 +1,426 @@
|
|||
/*
|
||||
* Copyright 2025 Google LLC
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.google.aiedge.gallery.ui.llmsingleturn
|
||||
|
||||
import androidx.compose.foundation.BorderStroke
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.border
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.Spacer
|
||||
import androidx.compose.foundation.layout.fillMaxSize
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.height
|
||||
import androidx.compose.foundation.layout.offset
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.layout.size
|
||||
import androidx.compose.foundation.layout.wrapContentHeight
|
||||
import androidx.compose.foundation.rememberScrollState
|
||||
import androidx.compose.foundation.shape.CircleShape
|
||||
import androidx.compose.foundation.verticalScroll
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.automirrored.rounded.Send
|
||||
import androidx.compose.material.icons.outlined.ContentCopy
|
||||
import androidx.compose.material.icons.outlined.Description
|
||||
import androidx.compose.material.icons.outlined.ExpandLess
|
||||
import androidx.compose.material.icons.outlined.ExpandMore
|
||||
import androidx.compose.material.icons.rounded.Add
|
||||
import androidx.compose.material.icons.rounded.Visibility
|
||||
import androidx.compose.material.icons.rounded.VisibilityOff
|
||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
import androidx.compose.material3.FilterChipDefaults
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.IconButtonDefaults
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.ModalBottomSheet
|
||||
import androidx.compose.material3.OutlinedIconButton
|
||||
import androidx.compose.material3.PrimaryScrollableTabRow
|
||||
import androidx.compose.material3.Tab
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.material3.TextField
|
||||
import androidx.compose.material3.TextFieldDefaults
|
||||
import androidx.compose.material3.rememberModalBottomSheetState
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.derivedStateOf
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableIntStateOf
|
||||
import androidx.compose.runtime.mutableStateMapOf
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.runtime.snapshots.SnapshotStateMap
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.alpha
|
||||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.platform.LocalClipboardManager
|
||||
import androidx.compose.ui.res.dimensionResource
|
||||
import androidx.compose.ui.text.TextLayoutResult
|
||||
import androidx.compose.ui.text.style.TextOverflow
|
||||
import androidx.compose.ui.unit.dp
|
||||
import com.google.aiedge.gallery.R
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.ui.common.chat.MessageBubbleShape
|
||||
import com.google.aiedge.gallery.ui.theme.customColors
|
||||
|
||||
private val promptTemplateTypes: List<PromptTemplateType> = PromptTemplateType.entries
|
||||
private val TAB_TITLES = PromptTemplateType.entries.map { it.label }
|
||||
private val ICON_BUTTON_SIZE = 42.dp
|
||||
|
||||
const val FULL_PROMPT_SWITCH_KEY = "full_prompt"
|
||||
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@Composable
|
||||
fun PromptTemplatesPanel(
|
||||
model: Model,
|
||||
viewModel: LlmSingleTurnViewModel,
|
||||
onSend: (fullPrompt: String) -> Unit,
|
||||
modifier: Modifier = Modifier
|
||||
) {
|
||||
val uiState by viewModel.uiState.collectAsState()
|
||||
val selectedPromptTemplateType = uiState.selectedPromptTemplateType
|
||||
val inProgress = uiState.inProgress
|
||||
var selectedTabIndex by remember { mutableIntStateOf(0) }
|
||||
var curTextInputContent by remember { mutableStateOf("") }
|
||||
val inputEditorValues: SnapshotStateMap<String, Any> = remember {
|
||||
mutableStateMapOf(FULL_PROMPT_SWITCH_KEY to false)
|
||||
}
|
||||
val fullPrompt by remember {
|
||||
derivedStateOf {
|
||||
uiState.selectedPromptTemplateType.genFullPrompt(curTextInputContent, inputEditorValues)
|
||||
}
|
||||
}
|
||||
val clipboardManager = LocalClipboardManager.current
|
||||
val expandedStates = remember { mutableStateMapOf<String, Boolean>() }
|
||||
|
||||
// Update input editor values when prompt template changes.
|
||||
LaunchedEffect(selectedPromptTemplateType) {
|
||||
for (config in selectedPromptTemplateType.config.inputEditors) {
|
||||
inputEditorValues[config.label] = config.defaultOption
|
||||
}
|
||||
expandedStates.clear()
|
||||
}
|
||||
|
||||
var showExamplePromptBottomSheet by remember { mutableStateOf(false) }
|
||||
val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
|
||||
val bubbleBorderRadius = dimensionResource(R.dimen.chat_bubble_corner_radius)
|
||||
|
||||
Column(modifier = modifier) {
|
||||
// Scrollable tab row for all prompt templates.
|
||||
PrimaryScrollableTabRow(selectedTabIndex = selectedTabIndex) {
|
||||
TAB_TITLES.forEachIndexed { index, title ->
|
||||
Tab(selected = selectedTabIndex == index, onClick = {
|
||||
// Clear input when tab changes.
|
||||
curTextInputContent = ""
|
||||
// Reset full prompt switch.
|
||||
inputEditorValues[FULL_PROMPT_SWITCH_KEY] = false
|
||||
|
||||
selectedTabIndex = index
|
||||
viewModel.selectPromptTemplate(
|
||||
model = model,
|
||||
promptTemplateType = promptTemplateTypes[index]
|
||||
)
|
||||
}, text = { Text(text = title) })
|
||||
}
|
||||
}
|
||||
|
||||
// Content.
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.weight(1f)
|
||||
.fillMaxWidth()
|
||||
) {
|
||||
// Input editor row.
|
||||
if (selectedPromptTemplateType.config.inputEditors.isNotEmpty()) {
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(8.dp),
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.background(MaterialTheme.colorScheme.surfaceContainerLow)
|
||||
.padding(horizontal = 16.dp, vertical = 10.dp)
|
||||
) {
|
||||
// Input editors.
|
||||
for (inputEditor in selectedPromptTemplateType.config.inputEditors) {
|
||||
when (inputEditor.type) {
|
||||
PromptTemplateInputEditorType.SINGLE_SELECT -> SingleSelectButton(config = inputEditor as PromptTemplateSingleSelectInputEditor,
|
||||
onSelected = { option ->
|
||||
inputEditorValues[inputEditor.label] = option
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Text input box.
|
||||
Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.weight(1f)) {
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.verticalScroll(rememberScrollState())
|
||||
) {
|
||||
if (inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean) {
|
||||
Text(
|
||||
fullPrompt,
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(16.dp)
|
||||
.padding(bottom = 32.dp)
|
||||
.clip(MessageBubbleShape(radius = bubbleBorderRadius))
|
||||
.background(MaterialTheme.customColors.agentBubbleBgColor)
|
||||
.padding(16.dp)
|
||||
)
|
||||
} else {
|
||||
TextField(
|
||||
value = curTextInputContent,
|
||||
onValueChange = { curTextInputContent = it },
|
||||
colors = TextFieldDefaults.colors(
|
||||
unfocusedContainerColor = Color.Transparent,
|
||||
focusedContainerColor = Color.Transparent,
|
||||
focusedIndicatorColor = Color.Transparent,
|
||||
unfocusedIndicatorColor = Color.Transparent,
|
||||
disabledIndicatorColor = Color.Transparent,
|
||||
disabledContainerColor = Color.Transparent,
|
||||
),
|
||||
textStyle = MaterialTheme.typography.bodyMedium,
|
||||
placeholder = { Text("Enter content") },
|
||||
modifier = Modifier.padding(bottom = 32.dp)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Text action row.
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(4.dp),
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(vertical = 4.dp, horizontal = 16.dp)
|
||||
) {
|
||||
// Full prompt switch.
|
||||
if (selectedPromptTemplateType != PromptTemplateType.FREE_FORM && curTextInputContent.isNotEmpty()) {
|
||||
Row(verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(4.dp),
|
||||
modifier = Modifier
|
||||
.clip(CircleShape)
|
||||
.background(if (inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean) MaterialTheme.colorScheme.secondaryContainer else MaterialTheme.customColors.agentBubbleBgColor)
|
||||
.clickable {
|
||||
inputEditorValues[FULL_PROMPT_SWITCH_KEY] =
|
||||
!(inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean)
|
||||
}
|
||||
.height(40.dp)
|
||||
.border(
|
||||
width = 1.dp, color = MaterialTheme.colorScheme.surface, shape = CircleShape
|
||||
)
|
||||
.padding(horizontal = 12.dp)) {
|
||||
if (inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean) {
|
||||
Icon(
|
||||
imageVector = Icons.Rounded.Visibility,
|
||||
contentDescription = "",
|
||||
modifier = Modifier.size(FilterChipDefaults.IconSize),
|
||||
)
|
||||
} else {
|
||||
Icon(
|
||||
imageVector = Icons.Rounded.VisibilityOff,
|
||||
contentDescription = "",
|
||||
modifier = Modifier
|
||||
.size(FilterChipDefaults.IconSize)
|
||||
.alpha(0.3f),
|
||||
)
|
||||
}
|
||||
Text("Preview prompt", style = MaterialTheme.typography.labelMedium)
|
||||
}
|
||||
}
|
||||
|
||||
Spacer(modifier = Modifier.weight(1f))
|
||||
|
||||
// Button to copy full prompt.
|
||||
if (curTextInputContent.isNotEmpty()) {
|
||||
OutlinedIconButton(
|
||||
onClick = {
|
||||
val clipData = fullPrompt
|
||||
clipboardManager.setText(clipData)
|
||||
},
|
||||
colors = IconButtonDefaults.iconButtonColors(
|
||||
containerColor = MaterialTheme.customColors.agentBubbleBgColor,
|
||||
disabledContainerColor = MaterialTheme.customColors.agentBubbleBgColor.copy(alpha = 0.4f),
|
||||
contentColor = MaterialTheme.colorScheme.primary,
|
||||
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.2f),
|
||||
),
|
||||
border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface),
|
||||
modifier = Modifier.size(ICON_BUTTON_SIZE)
|
||||
) {
|
||||
Icon(
|
||||
Icons.Outlined.ContentCopy, contentDescription = "", modifier = Modifier.size(20.dp)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Add example prompt button.
|
||||
OutlinedIconButton(
|
||||
enabled = !inProgress,
|
||||
onClick = { showExamplePromptBottomSheet = true },
|
||||
colors = IconButtonDefaults.iconButtonColors(
|
||||
containerColor = MaterialTheme.customColors.agentBubbleBgColor,
|
||||
disabledContainerColor = MaterialTheme.customColors.agentBubbleBgColor.copy(alpha = 0.4f),
|
||||
contentColor = MaterialTheme.colorScheme.primary,
|
||||
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.2f),
|
||||
),
|
||||
border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface),
|
||||
modifier = Modifier.size(ICON_BUTTON_SIZE)
|
||||
) {
|
||||
Icon(
|
||||
Icons.Rounded.Add,
|
||||
contentDescription = "",
|
||||
modifier = Modifier.size(20.dp),
|
||||
)
|
||||
}
|
||||
|
||||
// Send button
|
||||
OutlinedIconButton(
|
||||
enabled = !inProgress && curTextInputContent.isNotEmpty(),
|
||||
onClick = {
|
||||
onSend(fullPrompt.text)
|
||||
},
|
||||
colors = IconButtonDefaults.iconButtonColors(
|
||||
containerColor = MaterialTheme.colorScheme.secondaryContainer,
|
||||
disabledContainerColor = MaterialTheme.colorScheme.secondaryContainer.copy(alpha = 0.3f),
|
||||
contentColor = MaterialTheme.colorScheme.primary,
|
||||
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.1f),
|
||||
),
|
||||
border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface),
|
||||
modifier = Modifier.size(ICON_BUTTON_SIZE)
|
||||
) {
|
||||
Icon(
|
||||
Icons.AutoMirrored.Rounded.Send,
|
||||
contentDescription = "",
|
||||
modifier = Modifier
|
||||
.size(20.dp)
|
||||
.offset(x = 2.dp),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (showExamplePromptBottomSheet) {
|
||||
ModalBottomSheet(
|
||||
onDismissRequest = { showExamplePromptBottomSheet = false },
|
||||
sheetState = sheetState,
|
||||
modifier = Modifier.wrapContentHeight(),
|
||||
) {
|
||||
Column(modifier = Modifier.padding(bottom = 16.dp)) {
|
||||
// Title
|
||||
Text(
|
||||
"Select an example",
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(16.dp),
|
||||
style = MaterialTheme.typography.titleLarge
|
||||
)
|
||||
|
||||
// Examples
|
||||
for (prompt in selectedPromptTemplateType.examplePrompts) {
|
||||
var textLayoutResultState by remember { mutableStateOf<TextLayoutResult?>(null) }
|
||||
val hasOverflow = remember(textLayoutResultState) {
|
||||
textLayoutResultState?.hasVisualOverflow ?: false
|
||||
}
|
||||
val isExpanded = expandedStates[prompt] ?: false
|
||||
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.clickable {
|
||||
curTextInputContent = prompt
|
||||
showExamplePromptBottomSheet = false
|
||||
}
|
||||
.padding(horizontal = 16.dp, vertical = 8.dp),
|
||||
) {
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(8.dp),
|
||||
) {
|
||||
Icon(Icons.Outlined.Description, contentDescription = "")
|
||||
Text(prompt,
|
||||
maxLines = if (isExpanded) Int.MAX_VALUE else 3,
|
||||
overflow = TextOverflow.Ellipsis,
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
modifier = Modifier.weight(1f),
|
||||
onTextLayout = { textLayoutResultState = it }
|
||||
|
||||
)
|
||||
}
|
||||
|
||||
if (hasOverflow && !isExpanded) {
|
||||
Row(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(top = 2.dp),
|
||||
horizontalArrangement = Arrangement.End
|
||||
) {
|
||||
Box(modifier = Modifier
|
||||
.padding(end = 16.dp)
|
||||
.clip(CircleShape)
|
||||
.background(MaterialTheme.colorScheme.surfaceContainerHighest)
|
||||
.clickable {
|
||||
expandedStates[prompt] = true
|
||||
}
|
||||
.padding(vertical = 1.dp, horizontal = 6.dp)) {
|
||||
Icon(
|
||||
Icons.Outlined.ExpandMore,
|
||||
contentDescription = "",
|
||||
modifier = Modifier.size(12.dp)
|
||||
)
|
||||
}
|
||||
}
|
||||
} else if (isExpanded) {
|
||||
Row(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(top = 2.dp),
|
||||
horizontalArrangement = Arrangement.End
|
||||
) {
|
||||
Box(modifier = Modifier
|
||||
.padding(end = 16.dp)
|
||||
.clip(CircleShape)
|
||||
.background(MaterialTheme.colorScheme.surfaceContainerHighest)
|
||||
.clickable {
|
||||
expandedStates[prompt] = false
|
||||
}
|
||||
.padding(vertical = 1.dp, horizontal = 6.dp)) {
|
||||
Icon(
|
||||
Icons.Outlined.ExpandLess,
|
||||
contentDescription = "",
|
||||
modifier = Modifier.size(12.dp)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,206 @@
|
|||
/*
|
||||
* Copyright 2025 Google LLC
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.google.aiedge.gallery.ui.llmsingleturn
|
||||
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.fillMaxSize
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.layout.size
|
||||
import androidx.compose.foundation.rememberScrollState
|
||||
import androidx.compose.foundation.verticalScroll
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.outlined.AutoAwesome
|
||||
import androidx.compose.material.icons.outlined.ContentCopy
|
||||
import androidx.compose.material.icons.outlined.Timer
|
||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.IconButton
|
||||
import androidx.compose.material3.IconButtonDefaults
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.PrimaryTabRow
|
||||
import androidx.compose.material3.Tab
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableIntStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.alpha
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.platform.LocalClipboardManager
|
||||
import androidx.compose.ui.text.AnnotatedString
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
import androidx.compose.ui.unit.dp
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
|
||||
import com.google.aiedge.gallery.ui.common.chat.MarkdownText
|
||||
import com.google.aiedge.gallery.ui.common.chat.MessageBodyBenchmarkLlm
|
||||
import com.google.aiedge.gallery.ui.common.chat.MessageBodyLoading
|
||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||
|
||||
private val OPTIONS = listOf("Response", "Benchmark")
|
||||
private val ICONS = listOf(Icons.Outlined.AutoAwesome, Icons.Outlined.Timer)
|
||||
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@Composable
|
||||
fun ResponsePanel(
|
||||
model: Model,
|
||||
viewModel: LlmSingleTurnViewModel,
|
||||
modifier: Modifier = Modifier,
|
||||
) {
|
||||
val uiState by viewModel.uiState.collectAsState()
|
||||
val inProgress = uiState.inProgress
|
||||
val initializing = uiState.initializing
|
||||
val selectedPromptTemplateType = uiState.selectedPromptTemplateType
|
||||
val response = uiState.responsesByModel[model.name]?.get(selectedPromptTemplateType.label) ?: ""
|
||||
val benchmark = uiState.benchmarkByModel[model.name]?.get(selectedPromptTemplateType.label)
|
||||
val responseScrollState = rememberScrollState()
|
||||
var selectedOptionIndex by remember { mutableIntStateOf(0) }
|
||||
val clipboardManager = LocalClipboardManager.current
|
||||
|
||||
// Scroll to bottom when response changes.
|
||||
LaunchedEffect(response) {
|
||||
if (inProgress) {
|
||||
responseScrollState.animateScrollTo(responseScrollState.maxValue)
|
||||
}
|
||||
}
|
||||
|
||||
// Select the "response" tab when prompt template changes.
|
||||
LaunchedEffect(selectedPromptTemplateType) {
|
||||
selectedOptionIndex = 0
|
||||
}
|
||||
|
||||
if (initializing) {
|
||||
Box(
|
||||
contentAlignment = Alignment.TopStart,
|
||||
modifier = modifier
|
||||
.fillMaxSize()
|
||||
.padding(horizontal = 16.dp)
|
||||
) {
|
||||
MessageBodyLoading()
|
||||
}
|
||||
} else {
|
||||
// Message when response is empty.
|
||||
if (response.isEmpty()) {
|
||||
Row(
|
||||
modifier = Modifier.fillMaxSize(),
|
||||
horizontalArrangement = Arrangement.Center,
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
Text(
|
||||
"Response will appear here",
|
||||
modifier = Modifier.alpha(0.5f),
|
||||
style = MaterialTheme.typography.labelMedium,
|
||||
)
|
||||
}
|
||||
}
|
||||
// Response markdown.
|
||||
else {
|
||||
Column(
|
||||
modifier = modifier
|
||||
.padding(horizontal = 16.dp)
|
||||
.padding(bottom = 4.dp)
|
||||
) {
|
||||
// Response/benchmark switch.
|
||||
Row(modifier = Modifier.fillMaxWidth()) {
|
||||
PrimaryTabRow(
|
||||
selectedTabIndex = selectedOptionIndex,
|
||||
containerColor = Color.Transparent,
|
||||
) {
|
||||
OPTIONS.forEachIndexed { index, title ->
|
||||
Tab(selected = selectedOptionIndex == index, onClick = {
|
||||
selectedOptionIndex = index
|
||||
}, text = {
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(4.dp)
|
||||
) {
|
||||
Icon(
|
||||
ICONS[index],
|
||||
contentDescription = "",
|
||||
modifier = Modifier
|
||||
.size(16.dp)
|
||||
.alpha(0.7f)
|
||||
)
|
||||
Text(text = title)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
if (selectedOptionIndex == 0) {
|
||||
Box(
|
||||
contentAlignment = Alignment.BottomEnd,
|
||||
modifier = Modifier.weight(1f)
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.verticalScroll(responseScrollState)
|
||||
) {
|
||||
MarkdownText(
|
||||
text = response,
|
||||
modifier = Modifier.padding(top = 8.dp, bottom = 40.dp)
|
||||
)
|
||||
}
|
||||
// Copy button.
|
||||
IconButton(
|
||||
onClick = {
|
||||
val clipData = AnnotatedString(response)
|
||||
clipboardManager.setText(clipData)
|
||||
},
|
||||
colors = IconButtonDefaults.iconButtonColors(
|
||||
containerColor = MaterialTheme.colorScheme.surfaceContainerHighest,
|
||||
contentColor = MaterialTheme.colorScheme.primary,
|
||||
),
|
||||
) {
|
||||
Icon(
|
||||
Icons.Outlined.ContentCopy,
|
||||
contentDescription = "",
|
||||
modifier = Modifier.size(20.dp),
|
||||
)
|
||||
}
|
||||
}
|
||||
} else if (selectedOptionIndex == 1) {
|
||||
if (benchmark != null) {
|
||||
MessageBodyBenchmarkLlm(message = benchmark, modifier = Modifier.fillMaxWidth())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Preview(showBackground = true)
|
||||
@Composable
|
||||
fun ResponsePanelPreview() {
|
||||
GalleryTheme {
|
||||
ResponsePanel(
|
||||
model = TASK_LLM_SINGLE_TURN.models[0],
|
||||
viewModel = LlmSingleTurnViewModel(),
|
||||
modifier = Modifier.fillMaxSize()
|
||||
)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,90 @@
|
|||
/*
|
||||
* Copyright 2025 Google LLC
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.google.aiedge.gallery.ui.llmsingleturn
|
||||
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.rounded.ArrowDropDown
|
||||
import androidx.compose.material3.DropdownMenu
|
||||
import androidx.compose.material3.DropdownMenuItem
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.unit.dp
|
||||
|
||||
@Composable
|
||||
fun SingleSelectButton(
|
||||
config: PromptTemplateSingleSelectInputEditor,
|
||||
onSelected: (String) -> Unit
|
||||
) {
|
||||
var showMenu by remember { mutableStateOf(false) }
|
||||
var selectedOption by remember { mutableStateOf(config.defaultOption) }
|
||||
|
||||
LaunchedEffect(config) {
|
||||
selectedOption = config.defaultOption
|
||||
}
|
||||
|
||||
Box {
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(2.dp),
|
||||
modifier = Modifier
|
||||
.clip(RoundedCornerShape(8.dp))
|
||||
.background(MaterialTheme.colorScheme.secondaryContainer)
|
||||
.clickable {
|
||||
showMenu = true
|
||||
}
|
||||
.padding(vertical = 4.dp, horizontal = 6.dp)
|
||||
.padding(start = 8.dp)
|
||||
) {
|
||||
Text("${config.label}: $selectedOption", style = MaterialTheme.typography.labelLarge)
|
||||
Icon(Icons.Rounded.ArrowDropDown, contentDescription = "")
|
||||
}
|
||||
|
||||
DropdownMenu(
|
||||
expanded = showMenu,
|
||||
onDismissRequest = { showMenu = false }
|
||||
) {
|
||||
// Options
|
||||
for (option in config.options) {
|
||||
DropdownMenuItem(
|
||||
text = { Text(option) },
|
||||
onClick = {
|
||||
selectedOption = option
|
||||
showMenu = false
|
||||
onSelected(option)
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,133 @@
|
|||
/*
|
||||
* Copyright 2025 Google LLC
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.google.aiedge.gallery.ui.llmsingleturn
|
||||
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.gestures.detectDragGestures
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.fillMaxSize
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.height
|
||||
import androidx.compose.foundation.layout.width
|
||||
import androidx.compose.foundation.shape.CircleShape
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableFloatStateOf
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.input.pointer.pointerInput
|
||||
import androidx.compose.ui.layout.onGloballyPositioned
|
||||
import androidx.compose.ui.platform.LocalDensity
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
import androidx.compose.ui.unit.Dp
|
||||
import androidx.compose.ui.unit.dp
|
||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||
import com.google.aiedge.gallery.ui.theme.customColors
|
||||
|
||||
@Composable
|
||||
fun VerticalSplitView(
|
||||
topView: @Composable () -> Unit,
|
||||
bottomView: @Composable () -> Unit,
|
||||
modifier: Modifier = Modifier,
|
||||
initialRatio: Float = 0.5f,
|
||||
minTopHeight: Dp = 250.dp,
|
||||
minBottomHeight: Dp = 200.dp,
|
||||
handleThickness: Dp = 20.dp
|
||||
) {
|
||||
var splitRatio by remember { mutableFloatStateOf(initialRatio) }
|
||||
var columnHeightPx by remember {
|
||||
mutableFloatStateOf(0f)
|
||||
}
|
||||
var columnHeightDp by remember {
|
||||
mutableStateOf(0.dp)
|
||||
}
|
||||
val localDensity = LocalDensity.current
|
||||
|
||||
Column(modifier = modifier
|
||||
.fillMaxSize()
|
||||
.onGloballyPositioned { coordinates ->
|
||||
// Set column height using the LayoutCoordinates
|
||||
columnHeightPx = coordinates.size.height.toFloat()
|
||||
columnHeightDp = with(localDensity) { coordinates.size.height.toDp() }
|
||||
}
|
||||
) {
|
||||
Box(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.weight(splitRatio)
|
||||
) {
|
||||
topView()
|
||||
}
|
||||
|
||||
Box(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.height(handleThickness)
|
||||
.background(MaterialTheme.customColors.agentBubbleBgColor)
|
||||
.pointerInput(Unit) {
|
||||
detectDragGestures { change, dragAmount ->
|
||||
val newTopHeightPx = columnHeightPx * splitRatio + dragAmount.y
|
||||
var newTopHeightDp = with(localDensity) { newTopHeightPx.toDp() }
|
||||
if (newTopHeightDp < minTopHeight) {
|
||||
newTopHeightDp = minTopHeight
|
||||
}
|
||||
if (columnHeightDp - newTopHeightDp < minBottomHeight) {
|
||||
newTopHeightDp = columnHeightDp - minBottomHeight
|
||||
}
|
||||
splitRatio = newTopHeightDp / columnHeightDp
|
||||
change.consume()
|
||||
}
|
||||
},
|
||||
contentAlignment = Alignment.Center
|
||||
) {
|
||||
Box(
|
||||
modifier = Modifier
|
||||
.width(32.dp)
|
||||
.height(4.dp)
|
||||
.clip(CircleShape)
|
||||
.background(MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.4f))
|
||||
)
|
||||
}
|
||||
|
||||
Box(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.weight(1f - splitRatio)
|
||||
) {
|
||||
bottomView()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Preview(showBackground = true)
|
||||
@Composable
|
||||
fun VerticalSplitViewPreview() {
|
||||
GalleryTheme {
|
||||
VerticalSplitView(topView = {
|
||||
Text("top")
|
||||
}, bottomView = {
|
||||
Text("bottom")
|
||||
})
|
||||
}
|
||||
}
|
|
@ -40,6 +40,7 @@ import com.google.aiedge.gallery.data.ModelDownloadStatus
|
|||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||
import com.google.aiedge.gallery.data.TASKS
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.data.TaskType
|
||||
import com.google.aiedge.gallery.data.ValueType
|
||||
|
@ -175,9 +176,13 @@ open class ModelManagerViewModel(
|
|||
|
||||
// Kick off downloads for these models .
|
||||
withContext(Dispatchers.Main) {
|
||||
val tokenStatusAndData = getTokenStatusAndData()
|
||||
for (info in inProgressWorkInfos) {
|
||||
val model: Model? = getModelByName(info.modelName)
|
||||
if (model != null) {
|
||||
if (tokenStatusAndData.status == TokenStatus.NOT_EXPIRED && tokenStatusAndData.data != null) {
|
||||
model.accessToken = tokenStatusAndData.data.accessToken
|
||||
}
|
||||
Log.d(TAG, "Sending a new download request for '${model.name}'")
|
||||
downloadRepository.downloadModel(
|
||||
model, onStatusUpdated = this@ModelManagerViewModel::setDownloadStatus
|
||||
|
@ -233,11 +238,13 @@ open class ModelManagerViewModel(
|
|||
|
||||
// Delete model from the list if model is imported as a local model.
|
||||
if (model.imported) {
|
||||
val index = task.models.indexOf(model)
|
||||
for (curTask in TASKS) {
|
||||
val index = curTask.models.indexOf(model)
|
||||
if (index >= 0) {
|
||||
task.models.removeAt(index)
|
||||
curTask.models.removeAt(index)
|
||||
}
|
||||
curTask.updateTrigger.value = System.currentTimeMillis()
|
||||
}
|
||||
task.updateTrigger.value = System.currentTimeMillis()
|
||||
curModelDownloadStatus.remove(model.name)
|
||||
|
||||
// Update preference.
|
||||
|
@ -252,7 +259,7 @@ open class ModelManagerViewModel(
|
|||
_uiState.update { newUiState }
|
||||
}
|
||||
|
||||
fun initializeModel(context: Context, model: Model, force: Boolean = false) {
|
||||
fun initializeModel(context: Context, task: Task, model: Model, force: Boolean = false) {
|
||||
viewModelScope.launch(Dispatchers.Default) {
|
||||
// Skip if initialized already.
|
||||
if (!force && uiState.value.modelInitializationStatus[model.name]?.status == ModelInitializationStatusType.INITIALIZED) {
|
||||
|
@ -267,7 +274,7 @@ open class ModelManagerViewModel(
|
|||
}
|
||||
|
||||
// Clean up.
|
||||
cleanupModel(model = model)
|
||||
cleanupModel(task = task, model = model)
|
||||
|
||||
// Start initialization.
|
||||
Log.d(TAG, "Initializing model '${model.name}'...")
|
||||
|
@ -301,7 +308,7 @@ open class ModelManagerViewModel(
|
|||
)
|
||||
}
|
||||
}
|
||||
when (model.taskType) {
|
||||
when (task.type) {
|
||||
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.initialize(
|
||||
context = context,
|
||||
model = model,
|
||||
|
@ -320,24 +327,33 @@ open class ModelManagerViewModel(
|
|||
onDone = onDone,
|
||||
)
|
||||
|
||||
TaskType.LLM_SINGLE_TURN -> LlmChatModelHelper.initialize(
|
||||
context = context,
|
||||
model = model,
|
||||
onDone = onDone,
|
||||
)
|
||||
|
||||
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.initialize(
|
||||
context = context, model = model, onDone = onDone
|
||||
)
|
||||
|
||||
else -> {}
|
||||
TaskType.TEST_TASK_1 -> {}
|
||||
TaskType.TEST_TASK_2 -> {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun cleanupModel(model: Model) {
|
||||
fun cleanupModel(task: Task, model: Model) {
|
||||
if (model.instance != null) {
|
||||
Log.d(TAG, "Cleaning up model '${model.name}'...")
|
||||
when (model.taskType) {
|
||||
when (task.type) {
|
||||
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model)
|
||||
TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.cleanUp(model = model)
|
||||
TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model)
|
||||
TaskType.LLM_SINGLE_TURN -> LlmChatModelHelper.cleanUp(model = model)
|
||||
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.cleanUp(model = model)
|
||||
else -> {}
|
||||
TaskType.TEST_TASK_1 -> {}
|
||||
TaskType.TEST_TASK_2 -> {}
|
||||
}
|
||||
model.instance = null
|
||||
model.initializing = false
|
||||
|
@ -421,10 +437,11 @@ open class ModelManagerViewModel(
|
|||
return connection.responseCode
|
||||
}
|
||||
|
||||
fun addImportedLlmModel(task: Task, info: ImportedModelInfo) {
|
||||
fun addImportedLlmModel(info: ImportedModelInfo) {
|
||||
Log.d(TAG, "adding imported llm model: $info")
|
||||
|
||||
// Remove duplicated imported model if existed.
|
||||
val task = TASK_LLM_CHAT
|
||||
val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported }
|
||||
if (modelIndex >= 0) {
|
||||
Log.d(TAG, "duplicated imported model found in task. Removing it first")
|
||||
|
@ -455,6 +472,9 @@ open class ModelManagerViewModel(
|
|||
)
|
||||
}
|
||||
task.updateTrigger.value = System.currentTimeMillis()
|
||||
// Also need to update single turn task.
|
||||
TASK_LLM_SINGLE_TURN.updateTrigger.value = System.currentTimeMillis()
|
||||
|
||||
|
||||
// Add to preference storage.
|
||||
val importedModels = dataStoreRepository.readImportedModels().toMutableList()
|
||||
|
@ -630,11 +650,8 @@ open class ModelManagerViewModel(
|
|||
|
||||
private fun createModelFromImportedModelInfo(info: ImportedModelInfo, task: Task): Model {
|
||||
val accelerators: List<Accelerator> = (convertValueToTargetType(
|
||||
info.defaultValues[ConfigKey.COMPATIBLE_ACCELERATORS.label]!!,
|
||||
ValueType.STRING
|
||||
) as String)
|
||||
.split(",")
|
||||
.mapNotNull { acceleratorLabel ->
|
||||
info.defaultValues[ConfigKey.COMPATIBLE_ACCELERATORS.label]!!, ValueType.STRING
|
||||
) as String).split(",").mapNotNull { acceleratorLabel ->
|
||||
when (acceleratorLabel.trim()) {
|
||||
Accelerator.GPU.label -> Accelerator.GPU
|
||||
Accelerator.CPU.label -> Accelerator.CPU
|
||||
|
@ -643,20 +660,16 @@ open class ModelManagerViewModel(
|
|||
}
|
||||
val configs: List<Config> = createLlmChatConfigs(
|
||||
defaultMaxToken = convertValueToTargetType(
|
||||
info.defaultValues[ConfigKey.DEFAULT_MAX_TOKENS.label]!!,
|
||||
ValueType.INT
|
||||
info.defaultValues[ConfigKey.DEFAULT_MAX_TOKENS.label]!!, ValueType.INT
|
||||
) as Int,
|
||||
defaultTopK = convertValueToTargetType(
|
||||
info.defaultValues[ConfigKey.DEFAULT_TOPK.label]!!,
|
||||
ValueType.INT
|
||||
info.defaultValues[ConfigKey.DEFAULT_TOPK.label]!!, ValueType.INT
|
||||
) as Int,
|
||||
defaultTopP = convertValueToTargetType(
|
||||
info.defaultValues[ConfigKey.DEFAULT_TOPP.label]!!,
|
||||
ValueType.FLOAT
|
||||
info.defaultValues[ConfigKey.DEFAULT_TOPP.label]!!, ValueType.FLOAT
|
||||
) as Float,
|
||||
defaultTemperature = convertValueToTargetType(
|
||||
info.defaultValues[ConfigKey.DEFAULT_TEMPERATURE.label]!!,
|
||||
ValueType.FLOAT
|
||||
info.defaultValues[ConfigKey.DEFAULT_TEMPERATURE.label]!!, ValueType.FLOAT
|
||||
) as Float,
|
||||
accelerators = accelerators,
|
||||
)
|
||||
|
@ -666,9 +679,10 @@ open class ModelManagerViewModel(
|
|||
configs = configs,
|
||||
sizeInBytes = info.fileSize,
|
||||
downloadFileName = "$IMPORTS_DIR/${info.fileName}",
|
||||
showBenchmarkButton = false,
|
||||
imported = true,
|
||||
)
|
||||
model.preProcess(task = task)
|
||||
model.preProcess()
|
||||
|
||||
return model
|
||||
}
|
||||
|
@ -741,7 +755,7 @@ open class ModelManagerViewModel(
|
|||
val task = TASKS.find { it.type.label == hfModel.task }
|
||||
val model = hfModel.toModel()
|
||||
if (task != null && task.models.find { it.hfModelId == model.hfModelId } == null) {
|
||||
model.preProcess(task = task)
|
||||
model.preProcess()
|
||||
Log.d(TAG, "AG model: $model")
|
||||
task.models.add(model)
|
||||
|
||||
|
|
|
@ -46,6 +46,7 @@ import com.google.aiedge.gallery.data.Model
|
|||
import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION
|
||||
import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
|
||||
import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.data.TaskType
|
||||
|
@ -58,6 +59,8 @@ import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationDestination
|
|||
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationScreen
|
||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatDestination
|
||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatScreen
|
||||
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnDestination
|
||||
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnScreen
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManager
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.textclassification.TextClassificationDestination
|
||||
|
@ -131,7 +134,7 @@ fun GalleryNavHost(
|
|||
task = curPickedTask,
|
||||
onModelClicked = { model ->
|
||||
navigateToTaskScreen(
|
||||
navController = navController, taskType = model.taskType!!, model = model
|
||||
navController = navController, taskType = curPickedTask.type, model = model
|
||||
)
|
||||
},
|
||||
navigateUp = { showModelManager = false })
|
||||
|
@ -220,6 +223,24 @@ fun GalleryNavHost(
|
|||
)
|
||||
}
|
||||
}
|
||||
|
||||
// LLMm single turn.
|
||||
composable(
|
||||
route = "${LlmSingleTurnDestination.route}/{modelName}",
|
||||
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
|
||||
enterTransition = { slideEnter() },
|
||||
exitTransition = { slideExit() },
|
||||
) {
|
||||
getModelFromNavigationParam(it, TASK_LLM_SINGLE_TURN)?.let { defaultModel ->
|
||||
modelManagerViewModel.selectModel(defaultModel)
|
||||
|
||||
LlmSingleTurnScreen(
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
navigateUp = { navController.navigateUp() },
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Handle incoming intents for deep links
|
||||
|
@ -231,9 +252,10 @@ fun GalleryNavHost(
|
|||
if (data.toString().startsWith("com.google.aiedge.gallery://model/")) {
|
||||
val modelName = data.pathSegments.last()
|
||||
getModelByName(modelName)?.let { model ->
|
||||
// TODO(jingjin): need to show a list of possible tasks for this model.
|
||||
navigateToTaskScreen(
|
||||
navController = navController,
|
||||
taskType = model.taskType!!,
|
||||
taskType = TaskType.LLM_CHAT,
|
||||
model = model
|
||||
)
|
||||
}
|
||||
|
@ -249,6 +271,7 @@ fun navigateToTaskScreen(
|
|||
TaskType.TEXT_CLASSIFICATION -> navController.navigate("${TextClassificationDestination.route}/${modelName}")
|
||||
TaskType.IMAGE_CLASSIFICATION -> navController.navigate("${ImageClassificationDestination.route}/${modelName}")
|
||||
TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
|
||||
TaskType.LLM_SINGLE_TURN -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
|
||||
TaskType.IMAGE_GENERATION -> navController.navigate("${ImageGenerationDestination.route}/${modelName}")
|
||||
TaskType.TEST_TASK_1 -> {}
|
||||
TaskType.TEST_TASK_2 -> {}
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
* Copyright 2025 Google LLC
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.google.aiedge.gallery.ui.preview
|
||||
|
||||
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
|
||||
|
||||
class PreviewLlmSingleTurnViewModel : LlmSingleTurnViewModel(task = TASK_TEST1)
|
|
@ -34,7 +34,7 @@ class PreviewModelManagerViewModel(context: Context) :
|
|||
for ((index, task) in ALL_PREVIEW_TASKS.withIndex()) {
|
||||
task.index = index
|
||||
for (model in task.models) {
|
||||
model.preProcess(task = task)
|
||||
model.preProcess()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -25,7 +25,8 @@ val primaryContainerLight = Color(0xFFD0E4FF)
|
|||
val onPrimaryContainerLight = Color(0xFF144A74)
|
||||
val secondaryLight = Color(0xFF526070)
|
||||
val onSecondaryLight = Color(0xFFFFFFFF)
|
||||
val secondaryContainerLight = Color(0xFFD6E4F7)
|
||||
//val secondaryContainerLight = Color(0xFFD6E4F7)
|
||||
val secondaryContainerLight = Color(0xFFC2E7FF)
|
||||
val onSecondaryContainerLight = Color(0xFF3B4857)
|
||||
val tertiaryLight = Color(0xFF775A0B)
|
||||
val onTertiaryLight = Color(0xFFFFFFFF)
|
||||
|
|
|
@ -116,6 +116,7 @@ data class CustomColors(
|
|||
val userBubbleBgColor: Color = Color.Transparent,
|
||||
val agentBubbleBgColor: Color = Color.Transparent,
|
||||
val linkColor: Color = Color.Transparent,
|
||||
val successColor: Color = Color.Transparent,
|
||||
)
|
||||
|
||||
val LocalCustomColors = staticCompositionLocalOf { CustomColors() }
|
||||
|
@ -145,6 +146,7 @@ val lightCustomColors = CustomColors(
|
|||
agentBubbleBgColor = Color(0xFFe9eef6),
|
||||
userBubbleBgColor = Color(0xFF32628D),
|
||||
linkColor = Color(0xFF32628D),
|
||||
successColor = Color(0xff3d860b),
|
||||
)
|
||||
|
||||
val darkCustomColors = CustomColors(
|
||||
|
@ -172,6 +174,7 @@ val darkCustomColors = CustomColors(
|
|||
agentBubbleBgColor = Color(0xFF1b1c1d),
|
||||
userBubbleBgColor = Color(0xFF1f3760),
|
||||
linkColor = Color(0xFF9DCAFC),
|
||||
successColor = Color(0xFFA1CE83),
|
||||
)
|
||||
|
||||
val MaterialTheme.customColors: CustomColors
|
||||
|
|
|
@ -92,6 +92,7 @@ class DownloadWorker(context: Context, params: WorkerParameters) :
|
|||
|
||||
val connection = url.openConnection() as HttpURLConnection
|
||||
if (accessToken != null) {
|
||||
Log.d(TAG, "Using access token: ${accessToken.subSequence(0, 10)}...")
|
||||
connection.setRequestProperty("Authorization", "Bearer $accessToken")
|
||||
}
|
||||
|
||||
|
@ -176,6 +177,7 @@ class DownloadWorker(context: Context, params: WorkerParameters) :
|
|||
KEY_MODEL_DOWNLOAD_REMAINING_MS, remainingMs.toLong()
|
||||
).build()
|
||||
)
|
||||
Log.d(TAG, "downloadedBytes: $downloadedBytes")
|
||||
lastSetProgressTs = curTs
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue