mirror of
https://github.com/google-ai-edge/gallery.git
synced 2025-07-06 06:30:30 -04:00
Add support for LLM single-turn experience
This commit is contained in:
parent
46aaee2654
commit
d94fec0674
39 changed files with 2376 additions and 295 deletions
|
@ -30,7 +30,7 @@ android {
|
||||||
minSdk = 24
|
minSdk = 24
|
||||||
targetSdk = 35
|
targetSdk = 35
|
||||||
versionCode = 1
|
versionCode = 1
|
||||||
versionName = "20250421"
|
versionName = "20250428"
|
||||||
|
|
||||||
// Needed for HuggingFace auth workflows.
|
// Needed for HuggingFace auth workflows.
|
||||||
manifestPlaceholders["appAuthRedirectScheme"] = "com.google.aiedge.gallery.oauth"
|
manifestPlaceholders["appAuthRedirectScheme"] = "com.google.aiedge.gallery.oauth"
|
||||||
|
|
|
@ -187,6 +187,5 @@ fun GalleryTopAppBar(
|
||||||
else -> {}
|
else -> {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
)
|
)
|
||||||
}
|
}
|
|
@ -23,7 +23,7 @@ import androidx.datastore.preferences.core.Preferences
|
||||||
import androidx.datastore.preferences.preferencesDataStore
|
import androidx.datastore.preferences.preferencesDataStore
|
||||||
import com.google.aiedge.gallery.data.AppContainer
|
import com.google.aiedge.gallery.data.AppContainer
|
||||||
import com.google.aiedge.gallery.data.DefaultAppContainer
|
import com.google.aiedge.gallery.data.DefaultAppContainer
|
||||||
import com.google.aiedge.gallery.data.TASKS
|
import com.google.aiedge.gallery.ui.common.processTasks
|
||||||
import com.google.aiedge.gallery.ui.theme.ThemeSettings
|
import com.google.aiedge.gallery.ui.theme.ThemeSettings
|
||||||
|
|
||||||
private val Context.dataStore: DataStore<Preferences> by preferencesDataStore(name = "app_gallery_preferences")
|
private val Context.dataStore: DataStore<Preferences> by preferencesDataStore(name = "app_gallery_preferences")
|
||||||
|
@ -36,12 +36,7 @@ class GalleryApplication : Application() {
|
||||||
super.onCreate()
|
super.onCreate()
|
||||||
|
|
||||||
// Process tasks.
|
// Process tasks.
|
||||||
for ((index, task) in TASKS.withIndex()) {
|
processTasks()
|
||||||
task.index = index
|
|
||||||
for (model in task.models) {
|
|
||||||
model.preProcess(task = task)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
container = DefaultAppContainer(this, dataStore)
|
container = DefaultAppContainer(this, dataStore)
|
||||||
|
|
||||||
|
|
|
@ -92,15 +92,13 @@ data class Model(
|
||||||
val imported: Boolean = false,
|
val imported: Boolean = false,
|
||||||
|
|
||||||
// The following fields are managed by the app. Don't need to set manually.
|
// The following fields are managed by the app. Don't need to set manually.
|
||||||
var taskType: TaskType? = null,
|
|
||||||
var instance: Any? = null,
|
var instance: Any? = null,
|
||||||
var initializing: Boolean = false,
|
var initializing: Boolean = false,
|
||||||
var configValues: Map<String, Any> = mapOf(),
|
var configValues: Map<String, Any> = mapOf(),
|
||||||
var totalBytes: Long = 0L,
|
var totalBytes: Long = 0L,
|
||||||
var accessToken: String? = null,
|
var accessToken: String? = null,
|
||||||
) {
|
) {
|
||||||
fun preProcess(task: Task) {
|
fun preProcess() {
|
||||||
this.taskType = task.type
|
|
||||||
val configValues: MutableMap<String, Any> = mutableMapOf()
|
val configValues: MutableMap<String, Any> = mutableMapOf()
|
||||||
for (config in this.configs) {
|
for (config in this.configs) {
|
||||||
configValues[config.key.label] = config.defaultValue
|
configValues[config.key.label] = config.defaultValue
|
||||||
|
@ -246,6 +244,7 @@ val MODEL_LLM_GEMMA_2B_GPU_INT4: Model = Model(
|
||||||
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma-2b-it-gpu-int4.bin",
|
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma-2b-it-gpu-int4.bin",
|
||||||
sizeInBytes = 1354301440L,
|
sizeInBytes = 1354301440L,
|
||||||
configs = createLlmChatConfigs(),
|
configs = createLlmChatConfigs(),
|
||||||
|
showBenchmarkButton = false,
|
||||||
info = LLM_CHAT_INFO,
|
info = LLM_CHAT_INFO,
|
||||||
learnMoreUrl = "https://huggingface.co/litert-community",
|
learnMoreUrl = "https://huggingface.co/litert-community",
|
||||||
)
|
)
|
||||||
|
@ -256,6 +255,7 @@ val MODEL_LLM_GEMMA_2_2B_GPU_INT8: Model = Model(
|
||||||
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma2-2b-it-gpu-int8.bin",
|
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma2-2b-it-gpu-int8.bin",
|
||||||
sizeInBytes = 2627141632L,
|
sizeInBytes = 2627141632L,
|
||||||
configs = createLlmChatConfigs(),
|
configs = createLlmChatConfigs(),
|
||||||
|
showBenchmarkButton = false,
|
||||||
info = LLM_CHAT_INFO,
|
info = LLM_CHAT_INFO,
|
||||||
learnMoreUrl = "https://huggingface.co/litert-community",
|
learnMoreUrl = "https://huggingface.co/litert-community",
|
||||||
)
|
)
|
||||||
|
@ -271,6 +271,7 @@ val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model(
|
||||||
defaultTopP = 0.95f,
|
defaultTopP = 0.95f,
|
||||||
accelerators = listOf(Accelerator.CPU, Accelerator.GPU)
|
accelerators = listOf(Accelerator.CPU, Accelerator.GPU)
|
||||||
),
|
),
|
||||||
|
showBenchmarkButton = false,
|
||||||
info = LLM_CHAT_INFO,
|
info = LLM_CHAT_INFO,
|
||||||
learnMoreUrl = "https://huggingface.co/litert-community/Gemma3-1B-IT",
|
learnMoreUrl = "https://huggingface.co/litert-community/Gemma3-1B-IT",
|
||||||
llmPromptTemplates = listOf(
|
llmPromptTemplates = listOf(
|
||||||
|
@ -299,6 +300,7 @@ val MODEL_LLM_DEEPSEEK: Model = Model(
|
||||||
defaultTopP = 0.7f,
|
defaultTopP = 0.7f,
|
||||||
accelerators = listOf(Accelerator.CPU)
|
accelerators = listOf(Accelerator.CPU)
|
||||||
),
|
),
|
||||||
|
showBenchmarkButton = false,
|
||||||
info = LLM_CHAT_INFO,
|
info = LLM_CHAT_INFO,
|
||||||
learnMoreUrl = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B",
|
learnMoreUrl = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B",
|
||||||
)
|
)
|
||||||
|
@ -389,7 +391,7 @@ val MODELS_IMAGE_CLASSIFICATION: MutableList<Model> = mutableListOf(
|
||||||
MODEL_IMAGE_CLASSIFICATION_MOBILENET_V2,
|
MODEL_IMAGE_CLASSIFICATION_MOBILENET_V2,
|
||||||
)
|
)
|
||||||
|
|
||||||
val MODELS_LLM_CHAT: MutableList<Model> = mutableListOf(
|
val MODELS_LLM: MutableList<Model> = mutableListOf(
|
||||||
MODEL_LLM_GEMMA_2B_GPU_INT4,
|
MODEL_LLM_GEMMA_2B_GPU_INT4,
|
||||||
MODEL_LLM_GEMMA_2_2B_GPU_INT8,
|
MODEL_LLM_GEMMA_2_2B_GPU_INT8,
|
||||||
MODEL_LLM_GEMMA_3_1B_INT4,
|
MODEL_LLM_GEMMA_3_1B_INT4,
|
||||||
|
|
|
@ -18,9 +18,11 @@ package com.google.aiedge.gallery.data
|
||||||
|
|
||||||
import androidx.annotation.StringRes
|
import androidx.annotation.StringRes
|
||||||
import androidx.compose.material.icons.Icons
|
import androidx.compose.material.icons.Icons
|
||||||
|
import androidx.compose.material.icons.outlined.Forum
|
||||||
|
import androidx.compose.material.icons.outlined.Widgets
|
||||||
import androidx.compose.material.icons.rounded.ImageSearch
|
import androidx.compose.material.icons.rounded.ImageSearch
|
||||||
import androidx.compose.runtime.MutableState
|
import androidx.compose.runtime.MutableState
|
||||||
import androidx.compose.runtime.mutableStateOf
|
import androidx.compose.runtime.mutableLongStateOf
|
||||||
import androidx.compose.ui.graphics.vector.ImageVector
|
import androidx.compose.ui.graphics.vector.ImageVector
|
||||||
import com.google.aiedge.gallery.R
|
import com.google.aiedge.gallery.R
|
||||||
|
|
||||||
|
@ -30,6 +32,7 @@ enum class TaskType(val label: String) {
|
||||||
IMAGE_CLASSIFICATION("Image Classification"),
|
IMAGE_CLASSIFICATION("Image Classification"),
|
||||||
IMAGE_GENERATION("Image Generation"),
|
IMAGE_GENERATION("Image Generation"),
|
||||||
LLM_CHAT("LLM Chat"),
|
LLM_CHAT("LLM Chat"),
|
||||||
|
LLM_SINGLE_TURN("LLM Use Cases"),
|
||||||
|
|
||||||
TEST_TASK_1("Test task 1"),
|
TEST_TASK_1("Test task 1"),
|
||||||
TEST_TASK_2("Test task 2")
|
TEST_TASK_2("Test task 2")
|
||||||
|
@ -67,7 +70,7 @@ data class Task(
|
||||||
// The following fields are managed by the app. Don't need to set manually.
|
// The following fields are managed by the app. Don't need to set manually.
|
||||||
var index: Int = -1,
|
var index: Int = -1,
|
||||||
|
|
||||||
val updateTrigger: MutableState<Long> = mutableStateOf(0)
|
val updateTrigger: MutableState<Long> = mutableLongStateOf(0)
|
||||||
)
|
)
|
||||||
|
|
||||||
val TASK_TEXT_CLASSIFICATION = Task(
|
val TASK_TEXT_CLASSIFICATION = Task(
|
||||||
|
@ -87,9 +90,19 @@ val TASK_IMAGE_CLASSIFICATION = Task(
|
||||||
|
|
||||||
val TASK_LLM_CHAT = Task(
|
val TASK_LLM_CHAT = Task(
|
||||||
type = TaskType.LLM_CHAT,
|
type = TaskType.LLM_CHAT,
|
||||||
iconVectorResourceId = R.drawable.chat_spark,
|
icon = Icons.Outlined.Forum,
|
||||||
models = MODELS_LLM_CHAT,
|
models = MODELS_LLM,
|
||||||
description = "Chat? with a on-device large language model",
|
description = "Chat with a on-device large language model",
|
||||||
|
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
||||||
|
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
|
||||||
|
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
|
||||||
|
)
|
||||||
|
|
||||||
|
val TASK_LLM_SINGLE_TURN = Task(
|
||||||
|
type = TaskType.LLM_SINGLE_TURN,
|
||||||
|
icon = Icons.Outlined.Widgets,
|
||||||
|
models = MODELS_LLM,
|
||||||
|
description = "Single turn use cases with on-device large language model",
|
||||||
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
||||||
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
|
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
|
||||||
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
|
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
|
||||||
|
@ -108,9 +121,10 @@ val TASK_IMAGE_GENERATION = Task(
|
||||||
/** All tasks. */
|
/** All tasks. */
|
||||||
val TASKS: List<Task> = listOf(
|
val TASKS: List<Task> = listOf(
|
||||||
// TASK_TEXT_CLASSIFICATION,
|
// TASK_TEXT_CLASSIFICATION,
|
||||||
// TASK_IMAGE_CLASSIFICATION,
|
TASK_IMAGE_CLASSIFICATION,
|
||||||
TASK_IMAGE_GENERATION,
|
TASK_IMAGE_GENERATION,
|
||||||
TASK_LLM_CHAT,
|
TASK_LLM_CHAT,
|
||||||
|
TASK_LLM_SINGLE_TURN,
|
||||||
)
|
)
|
||||||
|
|
||||||
fun getModelByName(name: String): Model? {
|
fun getModelByName(name: String): Model? {
|
||||||
|
|
|
@ -25,6 +25,7 @@ import com.google.aiedge.gallery.GalleryApplication
|
||||||
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationViewModel
|
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationViewModel
|
||||||
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationViewModel
|
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationViewModel
|
||||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatViewModel
|
import com.google.aiedge.gallery.ui.llmchat.LlmChatViewModel
|
||||||
|
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
import com.google.aiedge.gallery.ui.textclassification.TextClassificationViewModel
|
import com.google.aiedge.gallery.ui.textclassification.TextClassificationViewModel
|
||||||
|
|
||||||
|
@ -56,6 +57,11 @@ object ViewModelProvider {
|
||||||
LlmChatViewModel()
|
LlmChatViewModel()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initializer for LlmSingleTurnViewModel..
|
||||||
|
initializer {
|
||||||
|
LlmSingleTurnViewModel()
|
||||||
|
}
|
||||||
|
|
||||||
initializer {
|
initializer {
|
||||||
ImageGenerationViewModel()
|
ImageGenerationViewModel()
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,213 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.google.aiedge.gallery.ui.common
|
||||||
|
|
||||||
|
import androidx.compose.foundation.background
|
||||||
|
import androidx.compose.foundation.clickable
|
||||||
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
|
import androidx.compose.foundation.layout.Column
|
||||||
|
import androidx.compose.foundation.layout.Row
|
||||||
|
import androidx.compose.foundation.layout.padding
|
||||||
|
import androidx.compose.foundation.layout.size
|
||||||
|
import androidx.compose.foundation.shape.CircleShape
|
||||||
|
import androidx.compose.material.icons.Icons
|
||||||
|
import androidx.compose.material.icons.automirrored.rounded.ArrowBack
|
||||||
|
import androidx.compose.material.icons.rounded.ArrowDropDown
|
||||||
|
import androidx.compose.material.icons.rounded.Settings
|
||||||
|
import androidx.compose.material3.CenterAlignedTopAppBar
|
||||||
|
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||||
|
import androidx.compose.material3.Icon
|
||||||
|
import androidx.compose.material3.IconButton
|
||||||
|
import androidx.compose.material3.MaterialTheme
|
||||||
|
import androidx.compose.material3.ModalBottomSheet
|
||||||
|
import androidx.compose.material3.Text
|
||||||
|
import androidx.compose.material3.rememberModalBottomSheetState
|
||||||
|
import androidx.compose.runtime.Composable
|
||||||
|
import androidx.compose.runtime.collectAsState
|
||||||
|
import androidx.compose.runtime.getValue
|
||||||
|
import androidx.compose.runtime.mutableStateOf
|
||||||
|
import androidx.compose.runtime.remember
|
||||||
|
import androidx.compose.runtime.setValue
|
||||||
|
import androidx.compose.ui.Alignment
|
||||||
|
import androidx.compose.ui.Modifier
|
||||||
|
import androidx.compose.ui.draw.clip
|
||||||
|
import androidx.compose.ui.graphics.vector.ImageVector
|
||||||
|
import androidx.compose.ui.platform.LocalContext
|
||||||
|
import androidx.compose.ui.res.vectorResource
|
||||||
|
import androidx.compose.ui.text.font.FontWeight
|
||||||
|
import androidx.compose.ui.unit.dp
|
||||||
|
import com.google.aiedge.gallery.data.Model
|
||||||
|
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||||
|
import com.google.aiedge.gallery.data.Task
|
||||||
|
import com.google.aiedge.gallery.ui.common.chat.ConfigDialog
|
||||||
|
import com.google.aiedge.gallery.ui.common.modelitem.StatusIcon
|
||||||
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
|
|
||||||
|
@OptIn(ExperimentalMaterial3Api::class)
|
||||||
|
@Composable
|
||||||
|
fun ModelPageAppBar(
|
||||||
|
task: Task,
|
||||||
|
model: Model,
|
||||||
|
modelManagerViewModel: ModelManagerViewModel,
|
||||||
|
onBackClicked: () -> Unit,
|
||||||
|
onModelSelected: (Model) -> Unit,
|
||||||
|
modifier: Modifier = Modifier,
|
||||||
|
onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit = { _, _ -> },
|
||||||
|
) {
|
||||||
|
var showConfigDialog by remember { mutableStateOf(false) }
|
||||||
|
var showModelPicker by remember { mutableStateOf(false) }
|
||||||
|
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||||
|
val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
|
||||||
|
val context = LocalContext.current
|
||||||
|
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[model.name]
|
||||||
|
|
||||||
|
CenterAlignedTopAppBar(title = {
|
||||||
|
Column(horizontalAlignment = Alignment.CenterHorizontally) {
|
||||||
|
// Task type.
|
||||||
|
Row(
|
||||||
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(6.dp)
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
task.icon ?: ImageVector.vectorResource(task.iconVectorResourceId!!),
|
||||||
|
tint = getTaskIconColor(task = task),
|
||||||
|
modifier = Modifier.size(16.dp),
|
||||||
|
contentDescription = "",
|
||||||
|
)
|
||||||
|
Text(
|
||||||
|
task.type.label,
|
||||||
|
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.SemiBold),
|
||||||
|
color = getTaskIconColor(task = task)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model name.
|
||||||
|
Row(verticalAlignment = Alignment.CenterVertically,
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(2.dp),
|
||||||
|
modifier = Modifier
|
||||||
|
.clip(CircleShape)
|
||||||
|
.background(MaterialTheme.colorScheme.surfaceContainerHigh)
|
||||||
|
.clickable {
|
||||||
|
showModelPicker = true
|
||||||
|
}
|
||||||
|
.padding(start = 8.dp, end = 2.dp)) {
|
||||||
|
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
|
||||||
|
Text(
|
||||||
|
model.name,
|
||||||
|
style = MaterialTheme.typography.labelSmall,
|
||||||
|
modifier = Modifier.padding(start = 4.dp),
|
||||||
|
)
|
||||||
|
Icon(
|
||||||
|
Icons.Rounded.ArrowDropDown,
|
||||||
|
modifier = Modifier.size(20.dp),
|
||||||
|
contentDescription = "",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}, modifier = modifier,
|
||||||
|
// The back button.
|
||||||
|
navigationIcon = {
|
||||||
|
IconButton(onClick = onBackClicked) {
|
||||||
|
Icon(
|
||||||
|
imageVector = Icons.AutoMirrored.Rounded.ArrowBack,
|
||||||
|
contentDescription = "",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
// The config button for the model (if existed).
|
||||||
|
actions = {
|
||||||
|
if (model.configs.isNotEmpty() && curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
|
||||||
|
IconButton(onClick = { showConfigDialog = true }) {
|
||||||
|
Icon(
|
||||||
|
imageVector = Icons.Rounded.Settings,
|
||||||
|
contentDescription = "",
|
||||||
|
tint = MaterialTheme.colorScheme.primary
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Config dialog.
|
||||||
|
if (showConfigDialog) {
|
||||||
|
ConfigDialog(
|
||||||
|
title = "Model configs",
|
||||||
|
configs = model.configs,
|
||||||
|
initialValues = model.configValues,
|
||||||
|
onDismissed = { showConfigDialog = false },
|
||||||
|
onOk = { curConfigValues ->
|
||||||
|
// Hide config dialog.
|
||||||
|
showConfigDialog = false
|
||||||
|
|
||||||
|
// Check if the configs are changed or not. Also check if the model needs to be
|
||||||
|
// re-initialized.
|
||||||
|
var same = true
|
||||||
|
var needReinitialization = false
|
||||||
|
for (config in model.configs) {
|
||||||
|
val key = config.key.label
|
||||||
|
val oldValue = convertValueToTargetType(
|
||||||
|
value = model.configValues.getValue(key), valueType = config.valueType
|
||||||
|
)
|
||||||
|
val newValue = convertValueToTargetType(
|
||||||
|
value = curConfigValues.getValue(key), valueType = config.valueType
|
||||||
|
)
|
||||||
|
if (oldValue != newValue) {
|
||||||
|
same = false
|
||||||
|
if (config.needReinitialization) {
|
||||||
|
needReinitialization = true
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (same) {
|
||||||
|
return@ConfigDialog
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save the config values to Model.
|
||||||
|
val oldConfigValues = model.configValues
|
||||||
|
model.configValues = curConfigValues
|
||||||
|
|
||||||
|
// Force to re-initialize the model with the new configs.
|
||||||
|
if (needReinitialization) {
|
||||||
|
modelManagerViewModel.initializeModel(
|
||||||
|
context = context, task = task, model = model, force = true
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Notify.
|
||||||
|
onConfigChanged(oldConfigValues, model.configValues)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model picker.
|
||||||
|
if (showModelPicker) {
|
||||||
|
ModalBottomSheet(
|
||||||
|
onDismissRequest = { showModelPicker = false },
|
||||||
|
sheetState = sheetState,
|
||||||
|
) {
|
||||||
|
ModelPicker(
|
||||||
|
task = task,
|
||||||
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
|
onModelSelected = { model ->
|
||||||
|
showModelPicker = false
|
||||||
|
onModelSelected(model)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,132 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.google.aiedge.gallery.ui.common
|
||||||
|
|
||||||
|
import androidx.compose.foundation.clickable
|
||||||
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
|
import androidx.compose.foundation.layout.Column
|
||||||
|
import androidx.compose.foundation.layout.Row
|
||||||
|
import androidx.compose.foundation.layout.Spacer
|
||||||
|
import androidx.compose.foundation.layout.fillMaxWidth
|
||||||
|
import androidx.compose.foundation.layout.padding
|
||||||
|
import androidx.compose.foundation.layout.size
|
||||||
|
import androidx.compose.foundation.layout.width
|
||||||
|
import androidx.compose.material.icons.Icons
|
||||||
|
import androidx.compose.material.icons.outlined.CheckCircle
|
||||||
|
import androidx.compose.material3.Icon
|
||||||
|
import androidx.compose.material3.MaterialTheme
|
||||||
|
import androidx.compose.material3.Text
|
||||||
|
import androidx.compose.runtime.Composable
|
||||||
|
import androidx.compose.runtime.collectAsState
|
||||||
|
import androidx.compose.runtime.getValue
|
||||||
|
import androidx.compose.ui.Alignment
|
||||||
|
import androidx.compose.ui.Modifier
|
||||||
|
import androidx.compose.ui.graphics.vector.ImageVector
|
||||||
|
import androidx.compose.ui.platform.LocalContext
|
||||||
|
import androidx.compose.ui.res.vectorResource
|
||||||
|
import androidx.compose.ui.tooling.preview.Preview
|
||||||
|
import androidx.compose.ui.unit.dp
|
||||||
|
import androidx.compose.ui.unit.sp
|
||||||
|
import com.google.aiedge.gallery.data.Model
|
||||||
|
import com.google.aiedge.gallery.data.Task
|
||||||
|
import com.google.aiedge.gallery.ui.common.modelitem.StatusIcon
|
||||||
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
|
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||||
|
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
|
||||||
|
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||||
|
import com.google.aiedge.gallery.ui.theme.labelSmallNarrow
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
fun ModelPicker(
|
||||||
|
task: Task,
|
||||||
|
modelManagerViewModel: ModelManagerViewModel,
|
||||||
|
onModelSelected: (Model) -> Unit
|
||||||
|
) {
|
||||||
|
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||||
|
|
||||||
|
Column(modifier = Modifier.padding(bottom = 8.dp)) {
|
||||||
|
// Title
|
||||||
|
Row(
|
||||||
|
modifier = Modifier
|
||||||
|
.padding(horizontal = 16.dp)
|
||||||
|
.padding(top = 4.dp, bottom = 4.dp),
|
||||||
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(8.dp),
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
task.icon ?: ImageVector.vectorResource(task.iconVectorResourceId!!),
|
||||||
|
tint = getTaskIconColor(task = task),
|
||||||
|
modifier = Modifier.size(16.dp),
|
||||||
|
contentDescription = "",
|
||||||
|
)
|
||||||
|
Text(
|
||||||
|
"${task.type.label} models",
|
||||||
|
modifier = Modifier.fillMaxWidth(),
|
||||||
|
style = MaterialTheme.typography.titleMedium,
|
||||||
|
color = getTaskIconColor(task = task),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model list.
|
||||||
|
for (model in task.models) {
|
||||||
|
Row(
|
||||||
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
horizontalArrangement = Arrangement.SpaceBetween,
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.clickable {
|
||||||
|
onModelSelected(model)
|
||||||
|
}
|
||||||
|
.padding(horizontal = 16.dp, vertical = 8.dp),
|
||||||
|
) {
|
||||||
|
Spacer(modifier = Modifier.width(24.dp))
|
||||||
|
Column(modifier = Modifier.weight(1f)) {
|
||||||
|
Text(model.name, style = MaterialTheme.typography.bodyMedium)
|
||||||
|
Row(horizontalArrangement = Arrangement.spacedBy(4.dp)) {
|
||||||
|
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
|
||||||
|
Text(
|
||||||
|
model.sizeInBytes.humanReadableSize(),
|
||||||
|
color = MaterialTheme.colorScheme.secondary,
|
||||||
|
style = labelSmallNarrow.copy(lineHeight = 10.sp)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (model.name == modelManagerUiState.selectedModel.name) {
|
||||||
|
Icon(
|
||||||
|
Icons.Outlined.CheckCircle,
|
||||||
|
modifier = Modifier.size(16.dp),
|
||||||
|
contentDescription = ""
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Preview(showBackground = true)
|
||||||
|
@Composable
|
||||||
|
fun ModelPickerPreview() {
|
||||||
|
val context = LocalContext.current
|
||||||
|
|
||||||
|
GalleryTheme {
|
||||||
|
ModelPicker(
|
||||||
|
task = TASK_TEST1,
|
||||||
|
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
||||||
|
onModelSelected = {},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
|
@ -29,6 +29,7 @@ import androidx.core.content.ContextCompat
|
||||||
import androidx.core.content.FileProvider
|
import androidx.core.content.FileProvider
|
||||||
import com.google.aiedge.gallery.data.Config
|
import com.google.aiedge.gallery.data.Config
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
|
import com.google.aiedge.gallery.data.TASKS
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.data.ValueType
|
import com.google.aiedge.gallery.data.ValueType
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkResult
|
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkResult
|
||||||
|
@ -57,6 +58,9 @@ interface LatencyProvider {
|
||||||
val latencyMs: Float
|
val latencyMs: Float
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private const val START_THINKING = "***Thinking...***"
|
||||||
|
private const val DONE_THINKING = "***Done thinking***"
|
||||||
|
|
||||||
/** Format the bytes into a human-readable format. */
|
/** Format the bytes into a human-readable format. */
|
||||||
fun Long.humanReadableSize(si: Boolean = true, extraDecimalForGbAndAbove: Boolean = false): String {
|
fun Long.humanReadableSize(si: Boolean = true, extraDecimalForGbAndAbove: Boolean = false): String {
|
||||||
val bytes = this
|
val bytes = this
|
||||||
|
@ -452,3 +456,33 @@ fun cleanUpMediapipeTaskErrorMessage(message: String): String {
|
||||||
}
|
}
|
||||||
return message
|
return message
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun processTasks() {
|
||||||
|
for ((index, task) in TASKS.withIndex()) {
|
||||||
|
task.index = index
|
||||||
|
for (model in task.models) {
|
||||||
|
model.preProcess()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun processLlmResponse(response: String): String {
|
||||||
|
// Add "thinking" and "done thinking" around the thinking content.
|
||||||
|
var newContent = response
|
||||||
|
.replace("<think>", "$START_THINKING\n")
|
||||||
|
.replace("</think>", "\n$DONE_THINKING")
|
||||||
|
|
||||||
|
// Remove empty thinking content.
|
||||||
|
val endThinkingIndex = newContent.indexOf(DONE_THINKING)
|
||||||
|
if (endThinkingIndex >= 0) {
|
||||||
|
val thinkingContent =
|
||||||
|
newContent.substring(0, endThinkingIndex + DONE_THINKING.length)
|
||||||
|
.replace(START_THINKING, "")
|
||||||
|
.replace(DONE_THINKING, "")
|
||||||
|
if (thinkingContent.isBlank()) {
|
||||||
|
newContent = newContent.substring(endThinkingIndex + DONE_THINKING.length)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return newContent
|
||||||
|
}
|
|
@ -38,6 +38,7 @@ import androidx.compose.foundation.layout.padding
|
||||||
import androidx.compose.foundation.layout.size
|
import androidx.compose.foundation.layout.size
|
||||||
import androidx.compose.foundation.layout.width
|
import androidx.compose.foundation.layout.width
|
||||||
import androidx.compose.foundation.layout.wrapContentHeight
|
import androidx.compose.foundation.layout.wrapContentHeight
|
||||||
|
import androidx.compose.foundation.layout.wrapContentWidth
|
||||||
import androidx.compose.foundation.lazy.LazyColumn
|
import androidx.compose.foundation.lazy.LazyColumn
|
||||||
import androidx.compose.foundation.lazy.items
|
import androidx.compose.foundation.lazy.items
|
||||||
import androidx.compose.foundation.lazy.rememberLazyListState
|
import androidx.compose.foundation.lazy.rememberLazyListState
|
||||||
|
@ -334,7 +335,10 @@ fun ChatPanel(
|
||||||
is ChatMessageBenchmarkResult -> MessageBodyBenchmark(message = message)
|
is ChatMessageBenchmarkResult -> MessageBodyBenchmark(message = message)
|
||||||
|
|
||||||
// Benchmark LLM result.
|
// Benchmark LLM result.
|
||||||
is ChatMessageBenchmarkLlmResult -> MessageBodyBenchmarkLlm(message = message)
|
is ChatMessageBenchmarkLlmResult -> MessageBodyBenchmarkLlm(
|
||||||
|
message = message,
|
||||||
|
modifier = Modifier.wrapContentWidth()
|
||||||
|
)
|
||||||
|
|
||||||
else -> {}
|
else -> {}
|
||||||
}
|
}
|
||||||
|
@ -346,7 +350,7 @@ fun ChatPanel(
|
||||||
) {
|
) {
|
||||||
LatencyText(message = message)
|
LatencyText(message = message)
|
||||||
// A button to show stats for the LLM message.
|
// A button to show stats for the LLM message.
|
||||||
if (selectedModel.taskType == TaskType.LLM_CHAT && message is ChatMessageText
|
if (task.type == TaskType.LLM_CHAT && message is ChatMessageText
|
||||||
// This means we only want to show the action button when the message is done
|
// This means we only want to show the action button when the message is done
|
||||||
// generating, at which point the latency will be set.
|
// generating, at which point the latency will be set.
|
||||||
&& message.latencyMs >= 0
|
&& message.latencyMs >= 0
|
||||||
|
@ -403,21 +407,17 @@ fun ChatPanel(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Benchmark button
|
// Benchmark button
|
||||||
// if (selectedModel.showBenchmarkButton) {
|
if (selectedModel.showBenchmarkButton) {
|
||||||
// MessageActionButton(
|
MessageActionButton(
|
||||||
// label = stringResource(R.string.benchmark),
|
label = stringResource(R.string.benchmark),
|
||||||
// icon = Icons.Outlined.Timer,
|
icon = Icons.Outlined.Timer,
|
||||||
// onClick = {
|
onClick = {
|
||||||
// if (selectedModel.taskType == TaskType.LLM_CHAT) {
|
showBenchmarkConfigsDialog = true
|
||||||
// onBenchmarkClicked(selectedModel, message, 0, 0)
|
benchmarkMessage.value = message
|
||||||
// } else {
|
},
|
||||||
// showBenchmarkConfigsDialog = true
|
enabled = !uiState.inProgress
|
||||||
// benchmarkMessage.value = message
|
)
|
||||||
// }
|
}
|
||||||
// },
|
|
||||||
// enabled = !uiState.inProgress
|
|
||||||
// )
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -443,7 +443,7 @@ fun ChatPanel(
|
||||||
// Chat input
|
// Chat input
|
||||||
when (chatInputType) {
|
when (chatInputType) {
|
||||||
ChatInputType.TEXT -> {
|
ChatInputType.TEXT -> {
|
||||||
val isLlmTask = selectedModel.taskType == TaskType.LLM_CHAT
|
val isLlmTask = task.type == TaskType.LLM_CHAT
|
||||||
val notLlmStartScreen = !(messages.size == 1 && messages[0] is ChatMessagePromptTemplates)
|
val notLlmStartScreen = !(messages.size == 1 && messages[0] is ChatMessagePromptTemplates)
|
||||||
MessageInputText(
|
MessageInputText(
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
|
|
|
@ -18,18 +18,10 @@ package com.google.aiedge.gallery.ui.common.chat
|
||||||
|
|
||||||
import android.util.Log
|
import android.util.Log
|
||||||
import androidx.activity.compose.BackHandler
|
import androidx.activity.compose.BackHandler
|
||||||
import androidx.activity.compose.rememberLauncherForActivityResult
|
|
||||||
import androidx.activity.result.contract.ActivityResultContracts
|
|
||||||
import androidx.compose.animation.AnimatedVisibility
|
|
||||||
import androidx.compose.animation.fadeIn
|
|
||||||
import androidx.compose.animation.fadeOut
|
|
||||||
import androidx.compose.animation.scaleIn
|
|
||||||
import androidx.compose.animation.scaleOut
|
|
||||||
import androidx.compose.foundation.background
|
import androidx.compose.foundation.background
|
||||||
import androidx.compose.foundation.layout.Box
|
import androidx.compose.foundation.layout.Box
|
||||||
import androidx.compose.foundation.layout.Column
|
import androidx.compose.foundation.layout.Column
|
||||||
import androidx.compose.foundation.layout.fillMaxSize
|
import androidx.compose.foundation.layout.fillMaxSize
|
||||||
import androidx.compose.foundation.layout.fillMaxWidth
|
|
||||||
import androidx.compose.foundation.layout.padding
|
import androidx.compose.foundation.layout.padding
|
||||||
import androidx.compose.foundation.pager.HorizontalPager
|
import androidx.compose.foundation.pager.HorizontalPager
|
||||||
import androidx.compose.foundation.pager.rememberPagerState
|
import androidx.compose.foundation.pager.rememberPagerState
|
||||||
|
@ -40,29 +32,21 @@ import androidx.compose.runtime.Composable
|
||||||
import androidx.compose.runtime.LaunchedEffect
|
import androidx.compose.runtime.LaunchedEffect
|
||||||
import androidx.compose.runtime.collectAsState
|
import androidx.compose.runtime.collectAsState
|
||||||
import androidx.compose.runtime.getValue
|
import androidx.compose.runtime.getValue
|
||||||
import androidx.compose.runtime.mutableStateOf
|
|
||||||
import androidx.compose.runtime.remember
|
|
||||||
import androidx.compose.runtime.rememberCoroutineScope
|
import androidx.compose.runtime.rememberCoroutineScope
|
||||||
import androidx.compose.runtime.setValue
|
|
||||||
import androidx.compose.ui.Alignment
|
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.graphics.graphicsLayer
|
import androidx.compose.ui.graphics.graphicsLayer
|
||||||
import androidx.compose.ui.platform.LocalContext
|
import androidx.compose.ui.platform.LocalContext
|
||||||
import androidx.compose.ui.tooling.preview.Preview
|
import androidx.compose.ui.tooling.preview.Preview
|
||||||
import com.google.aiedge.gallery.GalleryTopAppBar
|
|
||||||
import com.google.aiedge.gallery.data.AppBarAction
|
|
||||||
import com.google.aiedge.gallery.data.AppBarActionType
|
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.ui.common.checkNotificationPermissionAndStartDownload
|
import com.google.aiedge.gallery.ui.common.ModelPageAppBar
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
import com.google.aiedge.gallery.ui.preview.PreviewChatModel
|
import com.google.aiedge.gallery.ui.preview.PreviewChatModel
|
||||||
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||||
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
|
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
|
||||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
||||||
import kotlinx.coroutines.delay
|
|
||||||
import kotlinx.coroutines.launch
|
import kotlinx.coroutines.launch
|
||||||
import kotlin.math.absoluteValue
|
import kotlin.math.absoluteValue
|
||||||
|
|
||||||
|
@ -77,7 +61,6 @@ private const val TAG = "AGChatView"
|
||||||
* manages model initialization, cleanup, and download status, and handles navigation and system
|
* manages model initialization, cleanup, and download status, and handles navigation and system
|
||||||
* back gestures.
|
* back gestures.
|
||||||
*/
|
*/
|
||||||
@OptIn(ExperimentalMaterial3Api::class)
|
|
||||||
@Composable
|
@Composable
|
||||||
fun ChatView(
|
fun ChatView(
|
||||||
task: Task,
|
task: Task,
|
||||||
|
@ -96,34 +79,29 @@ fun ChatView(
|
||||||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||||
val selectedModel = modelManagerUiState.selectedModel
|
val selectedModel = modelManagerUiState.selectedModel
|
||||||
|
|
||||||
val pagerState = rememberPagerState(initialPage = task.models.indexOf(selectedModel),
|
val pagerState = rememberPagerState(
|
||||||
|
initialPage = task.models.indexOf(selectedModel),
|
||||||
pageCount = { task.models.size })
|
pageCount = { task.models.size })
|
||||||
val context = LocalContext.current
|
val context = LocalContext.current
|
||||||
val scope = rememberCoroutineScope()
|
val scope = rememberCoroutineScope()
|
||||||
|
|
||||||
val launcher = rememberLauncherForActivityResult(
|
|
||||||
ActivityResultContracts.RequestPermission()
|
|
||||||
) {
|
|
||||||
modelManagerViewModel.downloadModel(task = task, model = selectedModel)
|
|
||||||
}
|
|
||||||
|
|
||||||
val handleNavigateUp = {
|
val handleNavigateUp = {
|
||||||
navigateUp()
|
navigateUp()
|
||||||
|
|
||||||
// clean up all models.
|
// clean up all models.
|
||||||
scope.launch(Dispatchers.Default) {
|
scope.launch(Dispatchers.Default) {
|
||||||
for (model in task.models) {
|
for (model in task.models) {
|
||||||
modelManagerViewModel.cleanupModel(model = model)
|
modelManagerViewModel.cleanupModel(task = task, model = model)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize model when model/download state changes.
|
// Initialize model when model/download state changes.
|
||||||
val status = modelManagerUiState.modelDownloadStatus[selectedModel.name]
|
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name]
|
||||||
LaunchedEffect(status, selectedModel.name) {
|
LaunchedEffect(curDownloadStatus, selectedModel.name) {
|
||||||
if (status?.status == ModelDownloadStatusType.SUCCEEDED) {
|
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
|
||||||
Log.d(TAG, "Initializing model '${selectedModel.name}' from ChatView launched effect")
|
Log.d(TAG, "Initializing model '${selectedModel.name}' from ChatView launched effect")
|
||||||
modelManagerViewModel.initializeModel(context, model = selectedModel)
|
modelManagerViewModel.initializeModel(context, task = task, model = selectedModel)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -135,7 +113,7 @@ fun ChatView(
|
||||||
"Pager settled on model '${curSelectedModel.name}' from '${selectedModel.name}'. Updating selected model."
|
"Pager settled on model '${curSelectedModel.name}' from '${selectedModel.name}'. Updating selected model."
|
||||||
)
|
)
|
||||||
if (curSelectedModel.name != selectedModel.name) {
|
if (curSelectedModel.name != selectedModel.name) {
|
||||||
modelManagerViewModel.cleanupModel(model = selectedModel)
|
modelManagerViewModel.cleanupModel(task = task, model = selectedModel)
|
||||||
}
|
}
|
||||||
modelManagerViewModel.selectModel(curSelectedModel)
|
modelManagerViewModel.selectModel(curSelectedModel)
|
||||||
}
|
}
|
||||||
|
@ -146,24 +124,36 @@ fun ChatView(
|
||||||
}
|
}
|
||||||
|
|
||||||
Scaffold(modifier = modifier, topBar = {
|
Scaffold(modifier = modifier, topBar = {
|
||||||
GalleryTopAppBar(
|
ModelPageAppBar(
|
||||||
title = task.type.label,
|
task = task,
|
||||||
leftAction = AppBarAction(actionType = AppBarActionType.NAVIGATE_UP, actionFn = {
|
model = selectedModel,
|
||||||
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
|
onConfigChanged = { old, new ->
|
||||||
|
viewModel.addConfigChangedMessage(
|
||||||
|
oldConfigValues = old,
|
||||||
|
newConfigValues = new,
|
||||||
|
model = selectedModel
|
||||||
|
)
|
||||||
|
},
|
||||||
|
onBackClicked = {
|
||||||
handleNavigateUp()
|
handleNavigateUp()
|
||||||
}),
|
},
|
||||||
rightAction = AppBarAction(actionType = AppBarActionType.NO_ACTION, actionFn = {}),
|
onModelSelected = { model ->
|
||||||
|
scope.launch {
|
||||||
|
pagerState.animateScrollToPage(task.models.indexOf(model))
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
}) { innerPadding ->
|
}) { innerPadding ->
|
||||||
Box {
|
Box {
|
||||||
// A horizontal scrollable pager to switch between models.
|
// A horizontal scrollable pager to switch between models.
|
||||||
HorizontalPager(state = pagerState) { pageIndex ->
|
HorizontalPager(state = pagerState) { pageIndex ->
|
||||||
val curSelectedModel = task.models[pageIndex]
|
val curSelectedModel = task.models[pageIndex]
|
||||||
|
val curModelDownloadStatus = modelManagerUiState.modelDownloadStatus[curSelectedModel.name]
|
||||||
|
|
||||||
// Calculate the alpha of the current page based on how far they are from the center.
|
// Calculate the alpha of the current page based on how far they are from the center.
|
||||||
val pageOffset = (
|
val pageOffset =
|
||||||
(pagerState.currentPage - pageIndex) + pagerState
|
((pagerState.currentPage - pageIndex) + pagerState.currentPageOffsetFraction).absoluteValue
|
||||||
.currentPageOffsetFraction
|
|
||||||
).absoluteValue
|
|
||||||
val curAlpha = 1f - pageOffset.coerceIn(0f, 1f)
|
val curAlpha = 1f - pageOffset.coerceIn(0f, 1f)
|
||||||
|
|
||||||
Column(
|
Column(
|
||||||
|
@ -172,91 +162,14 @@ fun ChatView(
|
||||||
.fillMaxSize()
|
.fillMaxSize()
|
||||||
.background(MaterialTheme.colorScheme.surface)
|
.background(MaterialTheme.colorScheme.surface)
|
||||||
) {
|
) {
|
||||||
// Model selector at the top.
|
ModelDownloadStatusInfoPanel(
|
||||||
ModelSelector(
|
|
||||||
model = curSelectedModel,
|
model = curSelectedModel,
|
||||||
task = task,
|
task = task,
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
modelManagerViewModel = modelManagerViewModel
|
||||||
onConfigChanged = { old, new ->
|
|
||||||
viewModel.addConfigChangedMessage(
|
|
||||||
oldConfigValues = old,
|
|
||||||
newConfigValues = new,
|
|
||||||
model = curSelectedModel
|
|
||||||
)
|
)
|
||||||
},
|
|
||||||
modifier = Modifier.fillMaxWidth(),
|
|
||||||
contentAlpha = curAlpha,
|
|
||||||
)
|
|
||||||
|
|
||||||
// Manages the conditional display of UI elements (download model button and downloading
|
|
||||||
// animation) based on the corresponding download status.
|
|
||||||
//
|
|
||||||
// It uses delayed visibility ensuring they are shown only after a short delay if their
|
|
||||||
// respective conditions remain true. This prevents UI flickering and provides a smoother
|
|
||||||
// user experience.
|
|
||||||
val curStatus = modelManagerUiState.modelDownloadStatus[curSelectedModel.name]
|
|
||||||
var shouldShowDownloadingAnimation by remember { mutableStateOf(false) }
|
|
||||||
var downloadingAnimationConditionMet by remember { mutableStateOf(false) }
|
|
||||||
var shouldShowDownloadModelButton by remember { mutableStateOf(false) }
|
|
||||||
var downloadModelButtonConditionMet by remember { mutableStateOf(false) }
|
|
||||||
|
|
||||||
downloadingAnimationConditionMet =
|
|
||||||
curStatus?.status == ModelDownloadStatusType.IN_PROGRESS ||
|
|
||||||
curStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED ||
|
|
||||||
curStatus?.status == ModelDownloadStatusType.UNZIPPING
|
|
||||||
downloadModelButtonConditionMet =
|
|
||||||
curStatus?.status == ModelDownloadStatusType.FAILED ||
|
|
||||||
curStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED
|
|
||||||
|
|
||||||
LaunchedEffect(downloadingAnimationConditionMet) {
|
|
||||||
if (downloadingAnimationConditionMet) {
|
|
||||||
delay(100)
|
|
||||||
shouldShowDownloadingAnimation = true
|
|
||||||
} else {
|
|
||||||
shouldShowDownloadingAnimation = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
LaunchedEffect(downloadModelButtonConditionMet) {
|
|
||||||
if (downloadModelButtonConditionMet) {
|
|
||||||
delay(700)
|
|
||||||
shouldShowDownloadModelButton = true
|
|
||||||
} else {
|
|
||||||
shouldShowDownloadModelButton = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
AnimatedVisibility(
|
|
||||||
visible = shouldShowDownloadingAnimation,
|
|
||||||
enter = scaleIn(initialScale = 0.9f) + fadeIn(),
|
|
||||||
exit = scaleOut(targetScale = 0.9f) + fadeOut()
|
|
||||||
) {
|
|
||||||
Box(
|
|
||||||
modifier = Modifier.fillMaxSize(),
|
|
||||||
contentAlignment = Alignment.Center
|
|
||||||
) {
|
|
||||||
ModelDownloadingAnimation()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
AnimatedVisibility(
|
|
||||||
visible = shouldShowDownloadModelButton,
|
|
||||||
enter = fadeIn(),
|
|
||||||
exit = fadeOut()
|
|
||||||
) {
|
|
||||||
ModelNotDownloaded(modifier = Modifier.weight(1f), onClicked = {
|
|
||||||
checkNotificationPermissionAndStartDownload(
|
|
||||||
context = context,
|
|
||||||
launcher = launcher,
|
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
|
||||||
task = task,
|
|
||||||
model = curSelectedModel
|
|
||||||
)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// The main messages panel.
|
// The main messages panel.
|
||||||
if (curStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
|
if (curModelDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
|
||||||
ChatPanel(
|
ChatPanel(
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
task = task,
|
task = task,
|
||||||
|
|
|
@ -20,13 +20,12 @@ import android.util.Log
|
||||||
import androidx.lifecycle.ViewModel
|
import androidx.lifecycle.ViewModel
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
|
import com.google.aiedge.gallery.ui.common.processLlmResponse
|
||||||
import kotlinx.coroutines.flow.MutableStateFlow
|
import kotlinx.coroutines.flow.MutableStateFlow
|
||||||
import kotlinx.coroutines.flow.asStateFlow
|
import kotlinx.coroutines.flow.asStateFlow
|
||||||
import kotlinx.coroutines.flow.update
|
import kotlinx.coroutines.flow.update
|
||||||
|
|
||||||
private const val TAG = "AGChatViewModel"
|
private const val TAG = "AGChatViewModel"
|
||||||
private const val START_THINKING = "***Thinking...***"
|
|
||||||
private const val DONE_THINKING = "***Done thinking***"
|
|
||||||
|
|
||||||
data class ChatUiState(
|
data class ChatUiState(
|
||||||
/**
|
/**
|
||||||
|
@ -121,26 +120,7 @@ open class ChatViewModel(val task: Task) : ViewModel() {
|
||||||
if (newMessages.size > 0) {
|
if (newMessages.size > 0) {
|
||||||
val lastMessage = newMessages.last()
|
val lastMessage = newMessages.last()
|
||||||
if (lastMessage is ChatMessageText) {
|
if (lastMessage is ChatMessageText) {
|
||||||
var newContent = "${lastMessage.content}${partialContent}"
|
val newContent = processLlmResponse(response = "${lastMessage.content}${partialContent}")
|
||||||
// TODO: special handling for deepseek to remove the <think> tag.
|
|
||||||
|
|
||||||
// Add "thinking" and "done thinking" around the thinking content.
|
|
||||||
newContent = newContent
|
|
||||||
.replace("<think>", "$START_THINKING\n")
|
|
||||||
.replace("</think>", "\n$DONE_THINKING")
|
|
||||||
|
|
||||||
// Remove empty thinking content.
|
|
||||||
val endThinkingIndex = newContent.indexOf(DONE_THINKING)
|
|
||||||
if (endThinkingIndex >= 0) {
|
|
||||||
val thinkingContent =
|
|
||||||
newContent.substring(0, endThinkingIndex + DONE_THINKING.length)
|
|
||||||
.replace(START_THINKING, "")
|
|
||||||
.replace(DONE_THINKING, "")
|
|
||||||
if (thinkingContent.isBlank()) {
|
|
||||||
newContent = newContent.substring(endThinkingIndex + DONE_THINKING.length)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
val newLastMessage = ChatMessageText(
|
val newLastMessage = ChatMessageText(
|
||||||
content = newContent,
|
content = newContent,
|
||||||
side = lastMessage.side,
|
side = lastMessage.side,
|
||||||
|
|
|
@ -45,7 +45,7 @@ fun MarkdownText(
|
||||||
ProvideTextStyle(
|
ProvideTextStyle(
|
||||||
value = TextStyle(
|
value = TextStyle(
|
||||||
fontSize = fontSize,
|
fontSize = fontSize,
|
||||||
lineHeight = fontSize * 1.2,
|
lineHeight = fontSize * 1.4,
|
||||||
)
|
)
|
||||||
) {
|
) {
|
||||||
RichText(
|
RichText(
|
||||||
|
|
|
@ -48,15 +48,16 @@ fun MessageActionButton(
|
||||||
label: String,
|
label: String,
|
||||||
icon: ImageVector,
|
icon: ImageVector,
|
||||||
onClick: () -> Unit,
|
onClick: () -> Unit,
|
||||||
|
modifier: Modifier = Modifier,
|
||||||
enabled: Boolean = true
|
enabled: Boolean = true
|
||||||
) {
|
) {
|
||||||
val modifier = Modifier
|
val curModifier = modifier
|
||||||
.padding(top = 4.dp)
|
.padding(top = 4.dp)
|
||||||
.clip(CircleShape)
|
.clip(CircleShape)
|
||||||
.background(if (enabled) MaterialTheme.colorScheme.secondaryContainer else MaterialTheme.colorScheme.surfaceContainerHigh)
|
.background(if (enabled) MaterialTheme.colorScheme.secondaryContainer else MaterialTheme.colorScheme.surfaceContainerHigh)
|
||||||
val alpha: Float = if (enabled) 1.0f else 0.3f
|
val alpha: Float = if (enabled) 1.0f else 0.3f
|
||||||
Row(
|
Row(
|
||||||
modifier = if (enabled) modifier.clickable { onClick() } else modifier,
|
modifier = if (enabled) curModifier.clickable { onClick() } else modifier,
|
||||||
verticalAlignment = Alignment.CenterVertically,
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
) {
|
) {
|
||||||
Icon(
|
Icon(
|
||||||
|
|
|
@ -19,8 +19,8 @@ package com.google.aiedge.gallery.ui.common.chat
|
||||||
import androidx.compose.foundation.layout.Arrangement
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
import androidx.compose.foundation.layout.Column
|
import androidx.compose.foundation.layout.Column
|
||||||
import androidx.compose.foundation.layout.Row
|
import androidx.compose.foundation.layout.Row
|
||||||
|
import androidx.compose.foundation.layout.fillMaxWidth
|
||||||
import androidx.compose.foundation.layout.padding
|
import androidx.compose.foundation.layout.padding
|
||||||
import androidx.compose.foundation.layout.wrapContentWidth
|
|
||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.tooling.preview.Preview
|
import androidx.compose.ui.tooling.preview.Preview
|
||||||
|
@ -33,16 +33,14 @@ import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||||
* This function renders benchmark statistics (e.g., various token speed) in data cards
|
* This function renders benchmark statistics (e.g., various token speed) in data cards
|
||||||
*/
|
*/
|
||||||
@Composable
|
@Composable
|
||||||
fun MessageBodyBenchmarkLlm(message: ChatMessageBenchmarkLlmResult) {
|
fun MessageBodyBenchmarkLlm(message: ChatMessageBenchmarkLlmResult, modifier: Modifier = Modifier) {
|
||||||
Column(
|
Column(
|
||||||
modifier = Modifier
|
modifier = modifier.padding(12.dp),
|
||||||
.padding(12.dp)
|
|
||||||
.wrapContentWidth(),
|
|
||||||
verticalArrangement = Arrangement.spacedBy(8.dp)
|
verticalArrangement = Arrangement.spacedBy(8.dp)
|
||||||
) {
|
) {
|
||||||
// Data cards.
|
// Data cards.
|
||||||
Row(
|
Row(
|
||||||
modifier = Modifier.wrapContentWidth(), horizontalArrangement = Arrangement.spacedBy(16.dp)
|
modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.SpaceBetween
|
||||||
) {
|
) {
|
||||||
for (stat in message.orderedStats) {
|
for (stat in message.orderedStats) {
|
||||||
DataCard(
|
DataCard(
|
||||||
|
|
|
@ -82,7 +82,7 @@ fun MessageBodyPromptTemplates(
|
||||||
style = MaterialTheme.typography.titleSmall,
|
style = MaterialTheme.typography.titleSmall,
|
||||||
modifier = Modifier
|
modifier = Modifier
|
||||||
.fillMaxWidth()
|
.fillMaxWidth()
|
||||||
.offset(y = -4.dp),
|
.offset(y = (-4).dp),
|
||||||
textAlign = TextAlign.Center,
|
textAlign = TextAlign.Center,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -140,7 +140,7 @@ fun MessageBodyPromptTemplatesPreview() {
|
||||||
for ((index, task) in ALL_PREVIEW_TASKS.withIndex()) {
|
for ((index, task) in ALL_PREVIEW_TASKS.withIndex()) {
|
||||||
task.index = index
|
task.index = index
|
||||||
for (model in task.models) {
|
for (model in task.models) {
|
||||||
model.preProcess(task = task)
|
model.preProcess()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -70,7 +70,6 @@ import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||||
* This function renders a row containing a text field for message input and a send button.
|
* This function renders a row containing a text field for message input and a send button.
|
||||||
* It handles message composition, input validation, and sending messages.
|
* It handles message composition, input validation, and sending messages.
|
||||||
*/
|
*/
|
||||||
@OptIn(ExperimentalMaterial3Api::class)
|
|
||||||
@Composable
|
@Composable
|
||||||
fun MessageInputText(
|
fun MessageInputText(
|
||||||
modelManagerViewModel: ModelManagerViewModel,
|
modelManagerViewModel: ModelManagerViewModel,
|
||||||
|
@ -190,7 +189,7 @@ fun MessageInputText(
|
||||||
Icons.AutoMirrored.Rounded.Send,
|
Icons.AutoMirrored.Rounded.Send,
|
||||||
contentDescription = "",
|
contentDescription = "",
|
||||||
modifier = Modifier.offset(x = 2.dp),
|
modifier = Modifier.offset(x = 2.dp),
|
||||||
tint = if (inProgress) MaterialTheme.colorScheme.surfaceContainerHigh else MaterialTheme.colorScheme.primary
|
tint = MaterialTheme.colorScheme.primary
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,119 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.google.aiedge.gallery.ui.common.chat
|
||||||
|
|
||||||
|
import androidx.compose.animation.AnimatedVisibility
|
||||||
|
import androidx.compose.animation.fadeIn
|
||||||
|
import androidx.compose.animation.fadeOut
|
||||||
|
import androidx.compose.animation.scaleIn
|
||||||
|
import androidx.compose.animation.scaleOut
|
||||||
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
|
import androidx.compose.foundation.layout.Box
|
||||||
|
import androidx.compose.foundation.layout.Column
|
||||||
|
import androidx.compose.foundation.layout.fillMaxSize
|
||||||
|
import androidx.compose.runtime.Composable
|
||||||
|
import androidx.compose.runtime.LaunchedEffect
|
||||||
|
import androidx.compose.runtime.collectAsState
|
||||||
|
import androidx.compose.runtime.getValue
|
||||||
|
import androidx.compose.runtime.mutableStateOf
|
||||||
|
import androidx.compose.runtime.remember
|
||||||
|
import androidx.compose.runtime.setValue
|
||||||
|
import androidx.compose.ui.Alignment
|
||||||
|
import androidx.compose.ui.Modifier
|
||||||
|
import com.google.aiedge.gallery.data.Model
|
||||||
|
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||||
|
import com.google.aiedge.gallery.data.Task
|
||||||
|
import com.google.aiedge.gallery.ui.common.DownloadAndTryButton
|
||||||
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
|
import kotlinx.coroutines.delay
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
fun ModelDownloadStatusInfoPanel(
|
||||||
|
model: Model,
|
||||||
|
task: Task,
|
||||||
|
modelManagerViewModel: ModelManagerViewModel
|
||||||
|
) {
|
||||||
|
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||||
|
|
||||||
|
// Manages the conditional display of UI elements (download model button and downloading
|
||||||
|
// animation) based on the corresponding download status.
|
||||||
|
//
|
||||||
|
// It uses delayed visibility ensuring they are shown only after a short delay if their
|
||||||
|
// respective conditions remain true. This prevents UI flickering and provides a smoother
|
||||||
|
// user experience.
|
||||||
|
val curStatus = modelManagerUiState.modelDownloadStatus[model.name]
|
||||||
|
var shouldShowDownloadingAnimation by remember { mutableStateOf(false) }
|
||||||
|
var downloadingAnimationConditionMet by remember { mutableStateOf(false) }
|
||||||
|
var shouldShowDownloadModelButton by remember { mutableStateOf(false) }
|
||||||
|
var downloadModelButtonConditionMet by remember { mutableStateOf(false) }
|
||||||
|
|
||||||
|
downloadingAnimationConditionMet =
|
||||||
|
curStatus?.status == ModelDownloadStatusType.IN_PROGRESS || curStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED || curStatus?.status == ModelDownloadStatusType.UNZIPPING
|
||||||
|
downloadModelButtonConditionMet =
|
||||||
|
curStatus?.status == ModelDownloadStatusType.FAILED || curStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED
|
||||||
|
|
||||||
|
LaunchedEffect(downloadingAnimationConditionMet) {
|
||||||
|
if (downloadingAnimationConditionMet) {
|
||||||
|
delay(100)
|
||||||
|
shouldShowDownloadingAnimation = true
|
||||||
|
} else {
|
||||||
|
shouldShowDownloadingAnimation = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
LaunchedEffect(downloadModelButtonConditionMet) {
|
||||||
|
if (downloadModelButtonConditionMet) {
|
||||||
|
delay(700)
|
||||||
|
shouldShowDownloadModelButton = true
|
||||||
|
} else {
|
||||||
|
shouldShowDownloadModelButton = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
AnimatedVisibility(
|
||||||
|
visible = shouldShowDownloadingAnimation,
|
||||||
|
enter = scaleIn(initialScale = 0.9f) + fadeIn(),
|
||||||
|
exit = scaleOut(targetScale = 0.9f) + fadeOut()
|
||||||
|
) {
|
||||||
|
Box(
|
||||||
|
modifier = Modifier.fillMaxSize(), contentAlignment = Alignment.Center
|
||||||
|
) {
|
||||||
|
ModelDownloadingAnimation(
|
||||||
|
model = model, task = task, modelManagerViewModel = modelManagerViewModel
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
AnimatedVisibility(
|
||||||
|
visible = shouldShowDownloadModelButton, enter = fadeIn(), exit = fadeOut()
|
||||||
|
) {
|
||||||
|
Column(
|
||||||
|
modifier = Modifier.fillMaxSize(),
|
||||||
|
verticalArrangement = Arrangement.Center,
|
||||||
|
horizontalAlignment = Alignment.CenterHorizontally
|
||||||
|
) {
|
||||||
|
DownloadAndTryButton(
|
||||||
|
task = task,
|
||||||
|
model = model,
|
||||||
|
enabled = true,
|
||||||
|
needToDownloadFirst = true,
|
||||||
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
|
onClicked = {}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -24,6 +24,7 @@ import androidx.compose.foundation.layout.Arrangement
|
||||||
import androidx.compose.foundation.layout.Box
|
import androidx.compose.foundation.layout.Box
|
||||||
import androidx.compose.foundation.layout.Column
|
import androidx.compose.foundation.layout.Column
|
||||||
import androidx.compose.foundation.layout.Row
|
import androidx.compose.foundation.layout.Row
|
||||||
|
import androidx.compose.foundation.layout.fillMaxWidth
|
||||||
import androidx.compose.foundation.layout.height
|
import androidx.compose.foundation.layout.height
|
||||||
import androidx.compose.foundation.layout.offset
|
import androidx.compose.foundation.layout.offset
|
||||||
import androidx.compose.foundation.layout.padding
|
import androidx.compose.foundation.layout.padding
|
||||||
|
@ -32,23 +33,40 @@ import androidx.compose.foundation.layout.width
|
||||||
import androidx.compose.foundation.lazy.grid.GridCells
|
import androidx.compose.foundation.lazy.grid.GridCells
|
||||||
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
|
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
|
||||||
import androidx.compose.foundation.lazy.grid.itemsIndexed
|
import androidx.compose.foundation.lazy.grid.itemsIndexed
|
||||||
|
import androidx.compose.material3.LinearProgressIndicator
|
||||||
import androidx.compose.material3.MaterialTheme
|
import androidx.compose.material3.MaterialTheme
|
||||||
import androidx.compose.material3.Text
|
import androidx.compose.material3.Text
|
||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
import androidx.compose.runtime.LaunchedEffect
|
import androidx.compose.runtime.LaunchedEffect
|
||||||
|
import androidx.compose.runtime.collectAsState
|
||||||
|
import androidx.compose.runtime.derivedStateOf
|
||||||
|
import androidx.compose.runtime.getValue
|
||||||
import androidx.compose.runtime.remember
|
import androidx.compose.runtime.remember
|
||||||
import androidx.compose.ui.Alignment
|
import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.graphics.ColorFilter
|
import androidx.compose.ui.graphics.ColorFilter
|
||||||
import androidx.compose.ui.graphics.graphicsLayer
|
import androidx.compose.ui.graphics.graphicsLayer
|
||||||
import androidx.compose.ui.layout.ContentScale
|
import androidx.compose.ui.layout.ContentScale
|
||||||
|
import androidx.compose.ui.platform.LocalContext
|
||||||
import androidx.compose.ui.res.painterResource
|
import androidx.compose.ui.res.painterResource
|
||||||
import androidx.compose.ui.text.style.TextAlign
|
import androidx.compose.ui.text.style.TextAlign
|
||||||
|
import androidx.compose.ui.text.style.TextOverflow
|
||||||
import androidx.compose.ui.tooling.preview.Preview
|
import androidx.compose.ui.tooling.preview.Preview
|
||||||
import androidx.compose.ui.unit.dp
|
import androidx.compose.ui.unit.dp
|
||||||
|
import androidx.compose.ui.unit.sp
|
||||||
import com.google.aiedge.gallery.R
|
import com.google.aiedge.gallery.R
|
||||||
|
import com.google.aiedge.gallery.data.Model
|
||||||
|
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||||
|
import com.google.aiedge.gallery.data.Task
|
||||||
|
import com.google.aiedge.gallery.ui.common.formatToHourMinSecond
|
||||||
import com.google.aiedge.gallery.ui.common.getTaskIconColor
|
import com.google.aiedge.gallery.ui.common.getTaskIconColor
|
||||||
|
import com.google.aiedge.gallery.ui.common.humanReadableSize
|
||||||
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
|
import com.google.aiedge.gallery.ui.preview.MODEL_TEST1
|
||||||
|
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||||
|
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
|
||||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||||
|
import com.google.aiedge.gallery.ui.theme.labelSmallNarrow
|
||||||
import kotlinx.coroutines.delay
|
import kotlinx.coroutines.delay
|
||||||
import kotlin.math.cos
|
import kotlin.math.cos
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
|
@ -66,8 +84,19 @@ private const val END_SCALE = 0.6f
|
||||||
* scaling and rotation effect.
|
* scaling and rotation effect.
|
||||||
*/
|
*/
|
||||||
@Composable
|
@Composable
|
||||||
fun ModelDownloadingAnimation() {
|
fun ModelDownloadingAnimation(
|
||||||
|
model: Model,
|
||||||
|
task: Task,
|
||||||
|
modelManagerViewModel: ModelManagerViewModel
|
||||||
|
) {
|
||||||
val scale = remember { Animatable(END_SCALE) }
|
val scale = remember { Animatable(END_SCALE) }
|
||||||
|
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||||
|
val downloadStatus by remember {
|
||||||
|
derivedStateOf { modelManagerUiState.modelDownloadStatus[model.name] }
|
||||||
|
}
|
||||||
|
val inProgress = downloadStatus?.status == ModelDownloadStatusType.IN_PROGRESS
|
||||||
|
val isPartiallyDownloaded = downloadStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED
|
||||||
|
var curDownloadProgress = 0f
|
||||||
|
|
||||||
LaunchedEffect(Unit) { // Run this once
|
LaunchedEffect(Unit) { // Run this once
|
||||||
while (true) {
|
while (true) {
|
||||||
|
@ -93,6 +122,21 @@ fun ModelDownloadingAnimation() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Failure message.
|
||||||
|
val curDownloadStatus = downloadStatus
|
||||||
|
if (curDownloadStatus != null && curDownloadStatus.status == ModelDownloadStatusType.FAILED) {
|
||||||
|
Row(verticalAlignment = Alignment.CenterVertically) {
|
||||||
|
Text(
|
||||||
|
curDownloadStatus.errorMessage,
|
||||||
|
color = MaterialTheme.colorScheme.error,
|
||||||
|
style = labelSmallNarrow,
|
||||||
|
overflow = TextOverflow.Ellipsis,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// No failure
|
||||||
|
else {
|
||||||
Column(
|
Column(
|
||||||
horizontalAlignment = Alignment.CenterHorizontally,
|
horizontalAlignment = Alignment.CenterHorizontally,
|
||||||
modifier = Modifier.offset(y = -GRID_SIZE / 8)
|
modifier = Modifier.offset(y = -GRID_SIZE / 8)
|
||||||
|
@ -146,6 +190,78 @@ fun ModelDownloadingAnimation() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Download stats
|
||||||
|
var sizeLabel = model.totalBytes.humanReadableSize()
|
||||||
|
if (curDownloadStatus != null) {
|
||||||
|
// For in-progress model, show {receivedSize} / {totalSize} - {rate} - {remainingTime}
|
||||||
|
if (inProgress || isPartiallyDownloaded) {
|
||||||
|
var totalSize = curDownloadStatus.totalBytes
|
||||||
|
if (totalSize == 0L) {
|
||||||
|
totalSize = model.totalBytes
|
||||||
|
}
|
||||||
|
sizeLabel =
|
||||||
|
"${curDownloadStatus.receivedBytes.humanReadableSize(extraDecimalForGbAndAbove = true)} of ${totalSize.humanReadableSize()}"
|
||||||
|
if (curDownloadStatus.bytesPerSecond > 0) {
|
||||||
|
sizeLabel =
|
||||||
|
"$sizeLabel · ${curDownloadStatus.bytesPerSecond.humanReadableSize()} / s"
|
||||||
|
if (curDownloadStatus.remainingMs >= 0) {
|
||||||
|
sizeLabel =
|
||||||
|
"$sizeLabel · ${curDownloadStatus.remainingMs.formatToHourMinSecond()} left"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (isPartiallyDownloaded) {
|
||||||
|
sizeLabel = "$sizeLabel (resuming...)"
|
||||||
|
}
|
||||||
|
curDownloadProgress =
|
||||||
|
curDownloadStatus.receivedBytes.toFloat() / curDownloadStatus.totalBytes.toFloat()
|
||||||
|
if (curDownloadProgress.isNaN()) {
|
||||||
|
curDownloadProgress = 0f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Status for unzipping.
|
||||||
|
else if (curDownloadStatus.status == ModelDownloadStatusType.UNZIPPING) {
|
||||||
|
sizeLabel = "Unzipping..."
|
||||||
|
}
|
||||||
|
Text(
|
||||||
|
sizeLabel,
|
||||||
|
color = MaterialTheme.colorScheme.secondary,
|
||||||
|
style = labelSmallNarrow.copy(fontSize = 9.sp, lineHeight = 10.sp),
|
||||||
|
textAlign = TextAlign.Center,
|
||||||
|
overflow = TextOverflow.Visible,
|
||||||
|
modifier = Modifier
|
||||||
|
.padding(bottom = 4.dp)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Download progress.
|
||||||
|
if (inProgress || isPartiallyDownloaded) {
|
||||||
|
val animatedProgress = remember { Animatable(0f) }
|
||||||
|
LinearProgressIndicator(
|
||||||
|
progress = { animatedProgress.value },
|
||||||
|
color = getTaskIconColor(task = task),
|
||||||
|
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.padding(bottom = 36.dp)
|
||||||
|
.padding(horizontal = 36.dp)
|
||||||
|
)
|
||||||
|
LaunchedEffect(curDownloadProgress) {
|
||||||
|
animatedProgress.animateTo(curDownloadProgress, animationSpec = tween(150))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Unzipping progress.
|
||||||
|
else if (downloadStatus?.status == ModelDownloadStatusType.UNZIPPING) {
|
||||||
|
LinearProgressIndicator(
|
||||||
|
color = getTaskIconColor(task = task),
|
||||||
|
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.padding(bottom = 36.dp)
|
||||||
|
.padding(horizontal = 36.dp)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
Text(
|
Text(
|
||||||
"Feel free to switch apps or lock your device.\n"
|
"Feel free to switch apps or lock your device.\n"
|
||||||
+ "The download will continue in the background.\n"
|
+ "The download will continue in the background.\n"
|
||||||
|
@ -156,6 +272,8 @@ fun ModelDownloadingAnimation() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
// Custom Easing function for a multi-bounce effect
|
// Custom Easing function for a multi-bounce effect
|
||||||
fun multiBounceEasing(bounces: Int, decay: Float): Easing = Easing { x ->
|
fun multiBounceEasing(bounces: Int, decay: Float): Easing = Easing { x ->
|
||||||
if (x == 1f) {
|
if (x == 1f) {
|
||||||
|
@ -168,9 +286,15 @@ fun multiBounceEasing(bounces: Int, decay: Float): Easing = Easing { x ->
|
||||||
@Preview(showBackground = true)
|
@Preview(showBackground = true)
|
||||||
@Composable
|
@Composable
|
||||||
fun ModelDownloadingAnimationPreview() {
|
fun ModelDownloadingAnimationPreview() {
|
||||||
|
val context = LocalContext.current
|
||||||
|
|
||||||
GalleryTheme {
|
GalleryTheme {
|
||||||
Row(modifier = Modifier.padding(16.dp)) {
|
Row(modifier = Modifier.padding(16.dp)) {
|
||||||
ModelDownloadingAnimation()
|
ModelDownloadingAnimation(
|
||||||
|
model = MODEL_TEST1,
|
||||||
|
task = TASK_TEST1,
|
||||||
|
modelManagerViewModel = PreviewModelManagerViewModel(context = context)
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -63,7 +63,8 @@ fun ModelSelector(
|
||||||
) {
|
) {
|
||||||
Box(
|
Box(
|
||||||
modifier = Modifier
|
modifier = Modifier
|
||||||
.fillMaxWidth().padding(bottom = 8.dp),
|
.fillMaxWidth()
|
||||||
|
.padding(bottom = 8.dp),
|
||||||
contentAlignment = Alignment.Center
|
contentAlignment = Alignment.Center
|
||||||
) {
|
) {
|
||||||
// Model row.
|
// Model row.
|
||||||
|
@ -134,7 +135,12 @@ fun ModelSelector(
|
||||||
|
|
||||||
// Force to re-initialize the model with the new configs.
|
// Force to re-initialize the model with the new configs.
|
||||||
if (needReinitialization) {
|
if (needReinitialization) {
|
||||||
modelManagerViewModel.initializeModel(context = context, model = model, force = true)
|
modelManagerViewModel.initializeModel(
|
||||||
|
context = context,
|
||||||
|
task = task,
|
||||||
|
model = model,
|
||||||
|
force = true
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Notify.
|
// Notify.
|
||||||
|
|
|
@ -181,7 +181,5 @@ fun ModelNameAndStatus(
|
||||||
.padding(top = 2.dp),
|
.padding(top = 2.dp),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,8 +23,10 @@ import androidx.compose.foundation.layout.size
|
||||||
import androidx.compose.material.icons.Icons
|
import androidx.compose.material.icons.Icons
|
||||||
import androidx.compose.material.icons.automirrored.outlined.HelpOutline
|
import androidx.compose.material.icons.automirrored.outlined.HelpOutline
|
||||||
import androidx.compose.material.icons.filled.DownloadForOffline
|
import androidx.compose.material.icons.filled.DownloadForOffline
|
||||||
|
import androidx.compose.material.icons.rounded.Downloading
|
||||||
import androidx.compose.material.icons.rounded.Error
|
import androidx.compose.material.icons.rounded.Error
|
||||||
import androidx.compose.material3.Icon
|
import androidx.compose.material3.Icon
|
||||||
|
import androidx.compose.material3.MaterialTheme
|
||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
import androidx.compose.ui.Alignment
|
import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
|
@ -34,6 +36,7 @@ import androidx.compose.ui.unit.dp
|
||||||
import com.google.aiedge.gallery.data.ModelDownloadStatus
|
import com.google.aiedge.gallery.data.ModelDownloadStatus
|
||||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||||
|
import com.google.aiedge.gallery.ui.theme.customColors
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Composable function to display an icon representing the download status of a model.
|
* Composable function to display an icon representing the download status of a model.
|
||||||
|
@ -56,7 +59,7 @@ fun StatusIcon(downloadStatus: ModelDownloadStatus?, modifier: Modifier = Modifi
|
||||||
ModelDownloadStatusType.SUCCEEDED -> {
|
ModelDownloadStatusType.SUCCEEDED -> {
|
||||||
Icon(
|
Icon(
|
||||||
Icons.Filled.DownloadForOffline,
|
Icons.Filled.DownloadForOffline,
|
||||||
tint = Color(0xff3d860b),
|
tint = MaterialTheme.customColors.successColor,
|
||||||
contentDescription = "",
|
contentDescription = "",
|
||||||
modifier = Modifier.size(14.dp)
|
modifier = Modifier.size(14.dp)
|
||||||
)
|
)
|
||||||
|
@ -69,6 +72,12 @@ fun StatusIcon(downloadStatus: ModelDownloadStatus?, modifier: Modifier = Modifi
|
||||||
modifier = Modifier.size(14.dp)
|
modifier = Modifier.size(14.dp)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ModelDownloadStatusType.IN_PROGRESS -> Icon(
|
||||||
|
Icons.Rounded.Downloading,
|
||||||
|
contentDescription = "",
|
||||||
|
modifier = Modifier.size(14.dp)
|
||||||
|
)
|
||||||
|
|
||||||
else -> {}
|
else -> {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -73,7 +73,6 @@ import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.draw.alpha
|
import androidx.compose.ui.draw.alpha
|
||||||
import androidx.compose.ui.draw.clip
|
import androidx.compose.ui.draw.clip
|
||||||
import androidx.compose.ui.draw.scale
|
import androidx.compose.ui.draw.scale
|
||||||
import androidx.compose.ui.focus.focusModifier
|
|
||||||
import androidx.compose.ui.graphics.Brush
|
import androidx.compose.ui.graphics.Brush
|
||||||
import androidx.compose.ui.input.nestedscroll.nestedScroll
|
import androidx.compose.ui.input.nestedscroll.nestedScroll
|
||||||
import androidx.compose.ui.layout.layout
|
import androidx.compose.ui.layout.layout
|
||||||
|
@ -91,7 +90,6 @@ import com.google.aiedge.gallery.data.AppBarAction
|
||||||
import com.google.aiedge.gallery.data.AppBarActionType
|
import com.google.aiedge.gallery.data.AppBarActionType
|
||||||
import com.google.aiedge.gallery.data.ConfigKey
|
import com.google.aiedge.gallery.data.ConfigKey
|
||||||
import com.google.aiedge.gallery.data.ImportedModelInfo
|
import com.google.aiedge.gallery.data.ImportedModelInfo
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.ui.common.TaskIcon
|
import com.google.aiedge.gallery.ui.common.TaskIcon
|
||||||
import com.google.aiedge.gallery.ui.common.getTaskBgColor
|
import com.google.aiedge.gallery.ui.common.getTaskBgColor
|
||||||
|
@ -275,7 +273,6 @@ fun HomeScreen(
|
||||||
onDismiss = { showImportingDialog = false },
|
onDismiss = { showImportingDialog = false },
|
||||||
onDone = {
|
onDone = {
|
||||||
modelManagerViewModel.addImportedLlmModel(
|
modelManagerViewModel.addImportedLlmModel(
|
||||||
task = TASK_LLM_CHAT,
|
|
||||||
info = it,
|
info = it,
|
||||||
)
|
)
|
||||||
showImportingDialog = false
|
showImportingDialog = false
|
||||||
|
|
|
@ -30,7 +30,7 @@ private const val TAG = "AGLlmChatModelHelper"
|
||||||
typealias ResultListener = (partialResult: String, done: Boolean) -> Unit
|
typealias ResultListener = (partialResult: String, done: Boolean) -> Unit
|
||||||
typealias CleanUpListener = () -> Unit
|
typealias CleanUpListener = () -> Unit
|
||||||
|
|
||||||
data class LlmModelInstance(val engine: LlmInference, val session: LlmInferenceSession)
|
data class LlmModelInstance(val engine: LlmInference, var session: LlmInferenceSession)
|
||||||
|
|
||||||
object LlmChatModelHelper {
|
object LlmChatModelHelper {
|
||||||
// Indexed by model name.
|
// Indexed by model name.
|
||||||
|
@ -74,6 +74,24 @@ object LlmChatModelHelper {
|
||||||
onDone("")
|
onDone("")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun resetSession(model: Model) {
|
||||||
|
val instance = model.instance as LlmModelInstance? ?: return
|
||||||
|
val session = instance.session
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
val inference = instance.engine
|
||||||
|
val topK = model.getIntConfigValue(key = ConfigKey.TOPK, defaultValue = DEFAULT_TOPK)
|
||||||
|
val topP = model.getFloatConfigValue(key = ConfigKey.TOPP, defaultValue = DEFAULT_TOPP)
|
||||||
|
val temperature =
|
||||||
|
model.getFloatConfigValue(key = ConfigKey.TEMPERATURE, defaultValue = DEFAULT_TEMPERATURE)
|
||||||
|
val newSession = LlmInferenceSession.createFromOptions(
|
||||||
|
inference,
|
||||||
|
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
|
||||||
|
.setTemperature(temperature).build()
|
||||||
|
)
|
||||||
|
instance.session = newSession
|
||||||
|
}
|
||||||
|
|
||||||
fun cleanUp(model: Model) {
|
fun cleanUp(model: Model) {
|
||||||
if (model.instance == null) {
|
if (model.instance == null) {
|
||||||
return
|
return
|
||||||
|
@ -99,7 +117,11 @@ object LlmChatModelHelper {
|
||||||
input: String,
|
input: String,
|
||||||
resultListener: ResultListener,
|
resultListener: ResultListener,
|
||||||
cleanUpListener: CleanUpListener,
|
cleanUpListener: CleanUpListener,
|
||||||
|
singleTurn: Boolean = false,
|
||||||
) {
|
) {
|
||||||
|
if (singleTurn) {
|
||||||
|
resetSession(model = model)
|
||||||
|
}
|
||||||
val instance = model.instance as LlmModelInstance
|
val instance = model.instance as LlmModelInstance
|
||||||
|
|
||||||
// Set listener.
|
// Set listener.
|
||||||
|
|
|
@ -32,7 +32,7 @@ import kotlinx.coroutines.launch
|
||||||
|
|
||||||
private const val TAG = "AGLlmChatViewModel"
|
private const val TAG = "AGLlmChatViewModel"
|
||||||
private val STATS = listOf(
|
private val STATS = listOf(
|
||||||
Stat(id = "time_to_first_token", label = "Time to 1st token", unit = "sec"),
|
Stat(id = "time_to_first_token", label = "1st token", unit = "sec"),
|
||||||
Stat(id = "prefill_speed", label = "Prefill speed", unit = "tokens/s"),
|
Stat(id = "prefill_speed", label = "Prefill speed", unit = "tokens/s"),
|
||||||
Stat(id = "decode_speed", label = "Decode speed", unit = "tokens/s"),
|
Stat(id = "decode_speed", label = "Decode speed", unit = "tokens/s"),
|
||||||
Stat(id = "latency", label = "Latency", unit = "sec")
|
Stat(id = "latency", label = "Latency", unit = "sec")
|
||||||
|
|
|
@ -0,0 +1,206 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.google.aiedge.gallery.ui.llmsingleturn
|
||||||
|
|
||||||
|
import android.util.Log
|
||||||
|
import androidx.activity.compose.BackHandler
|
||||||
|
import androidx.compose.animation.AnimatedVisibility
|
||||||
|
import androidx.compose.animation.fadeIn
|
||||||
|
import androidx.compose.animation.fadeOut
|
||||||
|
import androidx.compose.animation.scaleIn
|
||||||
|
import androidx.compose.animation.scaleOut
|
||||||
|
import androidx.compose.foundation.background
|
||||||
|
import androidx.compose.foundation.layout.Box
|
||||||
|
import androidx.compose.foundation.layout.Column
|
||||||
|
import androidx.compose.foundation.layout.calculateStartPadding
|
||||||
|
import androidx.compose.foundation.layout.fillMaxSize
|
||||||
|
import androidx.compose.foundation.layout.offset
|
||||||
|
import androidx.compose.foundation.layout.padding
|
||||||
|
import androidx.compose.material3.MaterialTheme
|
||||||
|
import androidx.compose.material3.Scaffold
|
||||||
|
import androidx.compose.runtime.Composable
|
||||||
|
import androidx.compose.runtime.LaunchedEffect
|
||||||
|
import androidx.compose.runtime.collectAsState
|
||||||
|
import androidx.compose.runtime.getValue
|
||||||
|
import androidx.compose.runtime.rememberCoroutineScope
|
||||||
|
import androidx.compose.ui.Alignment
|
||||||
|
import androidx.compose.ui.Modifier
|
||||||
|
import androidx.compose.ui.platform.LocalContext
|
||||||
|
import androidx.compose.ui.platform.LocalLayoutDirection
|
||||||
|
import androidx.compose.ui.tooling.preview.Preview
|
||||||
|
import androidx.lifecycle.viewmodel.compose.viewModel
|
||||||
|
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||||
|
import com.google.aiedge.gallery.ui.ViewModelProvider
|
||||||
|
import com.google.aiedge.gallery.ui.common.ModelPageAppBar
|
||||||
|
import com.google.aiedge.gallery.ui.common.chat.ModelDownloadStatusInfoPanel
|
||||||
|
import com.google.aiedge.gallery.ui.common.chat.ModelInitializationStatusChip
|
||||||
|
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType
|
||||||
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
|
import com.google.aiedge.gallery.ui.preview.PreviewLlmSingleTurnViewModel
|
||||||
|
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||||
|
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||||
|
import com.google.aiedge.gallery.ui.theme.customColors
|
||||||
|
import kotlinx.coroutines.Dispatchers
|
||||||
|
import kotlinx.coroutines.launch
|
||||||
|
import kotlinx.serialization.Serializable
|
||||||
|
|
||||||
|
/** Navigation destination data */
|
||||||
|
object LlmSingleTurnDestination {
|
||||||
|
@Serializable
|
||||||
|
val route = "LlmSingleTurnRoute"
|
||||||
|
}
|
||||||
|
|
||||||
|
private const val TAG = "AGLlmSingleTurnScreen"
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
fun LlmSingleTurnScreen(
|
||||||
|
modelManagerViewModel: ModelManagerViewModel,
|
||||||
|
navigateUp: () -> Unit,
|
||||||
|
modifier: Modifier = Modifier,
|
||||||
|
viewModel: LlmSingleTurnViewModel = viewModel(
|
||||||
|
factory = ViewModelProvider.Factory
|
||||||
|
),
|
||||||
|
) {
|
||||||
|
val task = viewModel.task
|
||||||
|
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||||
|
val selectedModel = modelManagerUiState.selectedModel
|
||||||
|
val scope = rememberCoroutineScope()
|
||||||
|
val context = LocalContext.current
|
||||||
|
|
||||||
|
val handleNavigateUp = {
|
||||||
|
navigateUp()
|
||||||
|
|
||||||
|
// clean up all models.
|
||||||
|
scope.launch(Dispatchers.Default) {
|
||||||
|
for (model in task.models) {
|
||||||
|
modelManagerViewModel.cleanupModel(task = task, model = model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle system's edge swipe.
|
||||||
|
BackHandler {
|
||||||
|
handleNavigateUp()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize model when model/download state changes.
|
||||||
|
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name]
|
||||||
|
LaunchedEffect(curDownloadStatus, selectedModel.name) {
|
||||||
|
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
|
||||||
|
Log.d(
|
||||||
|
TAG,
|
||||||
|
"Initializing model '${selectedModel.name}' from LlmsingleTurnScreen launched effect"
|
||||||
|
)
|
||||||
|
modelManagerViewModel.initializeModel(context, task = task, model = selectedModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val modelInitializationStatus =
|
||||||
|
modelManagerUiState.modelInitializationStatus[selectedModel.name]
|
||||||
|
|
||||||
|
Scaffold(modifier = modifier, topBar = {
|
||||||
|
ModelPageAppBar(
|
||||||
|
task = task,
|
||||||
|
model = selectedModel,
|
||||||
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
|
onConfigChanged = { _, _ -> },
|
||||||
|
onBackClicked = { handleNavigateUp() },
|
||||||
|
onModelSelected = { newSelectedModel ->
|
||||||
|
scope.launch(Dispatchers.Default) {
|
||||||
|
// Clean up current model.
|
||||||
|
modelManagerViewModel.cleanupModel(task = task, model = selectedModel)
|
||||||
|
|
||||||
|
// Update selected model.
|
||||||
|
modelManagerViewModel.selectModel(model = newSelectedModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}) { innerPadding ->
|
||||||
|
Column(
|
||||||
|
modifier = Modifier.padding(
|
||||||
|
top = innerPadding.calculateTopPadding(),
|
||||||
|
start = innerPadding.calculateStartPadding(LocalLayoutDirection.current),
|
||||||
|
end = innerPadding.calculateStartPadding(LocalLayoutDirection.current),
|
||||||
|
)
|
||||||
|
) {
|
||||||
|
ModelDownloadStatusInfoPanel(
|
||||||
|
model = selectedModel,
|
||||||
|
task = task,
|
||||||
|
modelManagerViewModel = modelManagerViewModel
|
||||||
|
)
|
||||||
|
|
||||||
|
// Main UI after model is downloaded.
|
||||||
|
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
|
||||||
|
Box(
|
||||||
|
contentAlignment = Alignment.BottomCenter,
|
||||||
|
modifier = Modifier.weight(1f)
|
||||||
|
) {
|
||||||
|
VerticalSplitView(modifier = Modifier.fillMaxSize(),
|
||||||
|
topView = {
|
||||||
|
PromptTemplatesPanel(
|
||||||
|
model = selectedModel,
|
||||||
|
viewModel = viewModel,
|
||||||
|
onSend = { fullPrompt ->
|
||||||
|
viewModel.generateResponse(model = selectedModel, input = fullPrompt)
|
||||||
|
}, modifier = Modifier.fillMaxSize()
|
||||||
|
)
|
||||||
|
},
|
||||||
|
bottomView = {
|
||||||
|
Box(
|
||||||
|
contentAlignment = Alignment.BottomCenter,
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxSize()
|
||||||
|
.background(MaterialTheme.customColors.agentBubbleBgColor)
|
||||||
|
) {
|
||||||
|
ResponsePanel(
|
||||||
|
model = selectedModel,
|
||||||
|
viewModel = viewModel,
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxSize()
|
||||||
|
.padding(bottom = innerPadding.calculateBottomPadding())
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Model initialization in-progress message.
|
||||||
|
this@Column.AnimatedVisibility(
|
||||||
|
visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
|
||||||
|
enter = scaleIn() + fadeIn(),
|
||||||
|
exit = scaleOut() + fadeOut(),
|
||||||
|
modifier = Modifier.offset(y = -innerPadding.calculateBottomPadding())
|
||||||
|
) {
|
||||||
|
ModelInitializationStatusChip()
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Preview(showBackground = true)
|
||||||
|
@Composable
|
||||||
|
fun LlmSingleTurnScreenPreview() {
|
||||||
|
val context = LocalContext.current
|
||||||
|
GalleryTheme {
|
||||||
|
LlmSingleTurnScreen(
|
||||||
|
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
||||||
|
viewModel = PreviewLlmSingleTurnViewModel(),
|
||||||
|
navigateUp = {},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,210 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.google.aiedge.gallery.ui.llmsingleturn
|
||||||
|
|
||||||
|
import android.util.Log
|
||||||
|
import androidx.lifecycle.ViewModel
|
||||||
|
import androidx.lifecycle.viewModelScope
|
||||||
|
import com.google.aiedge.gallery.data.Model
|
||||||
|
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
|
||||||
|
import com.google.aiedge.gallery.data.Task
|
||||||
|
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
|
||||||
|
import com.google.aiedge.gallery.ui.common.chat.Stat
|
||||||
|
import com.google.aiedge.gallery.ui.common.processLlmResponse
|
||||||
|
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
|
||||||
|
import com.google.aiedge.gallery.ui.llmchat.LlmModelInstance
|
||||||
|
import kotlinx.coroutines.Dispatchers
|
||||||
|
import kotlinx.coroutines.delay
|
||||||
|
import kotlinx.coroutines.flow.MutableStateFlow
|
||||||
|
import kotlinx.coroutines.flow.asStateFlow
|
||||||
|
import kotlinx.coroutines.flow.update
|
||||||
|
import kotlinx.coroutines.launch
|
||||||
|
|
||||||
|
private const val TAG = "AGLlmSingleTurnViewModel"
|
||||||
|
|
||||||
|
data class LlmSingleTurnUiState(
|
||||||
|
/**
|
||||||
|
* Indicates whether the runtime is currently processing a message.
|
||||||
|
*/
|
||||||
|
val inProgress: Boolean = false,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Indicates whether the model is currently being initialized.
|
||||||
|
*/
|
||||||
|
val initializing: Boolean = false,
|
||||||
|
|
||||||
|
// model -> <template label -> response>
|
||||||
|
val responsesByModel: Map<String, Map<String, String>>,
|
||||||
|
|
||||||
|
// model -> <template label -> benchmark result>
|
||||||
|
val benchmarkByModel: Map<String, Map<String, ChatMessageBenchmarkLlmResult>>,
|
||||||
|
|
||||||
|
/** Selected prompt template type. */
|
||||||
|
val selectedPromptTemplateType: PromptTemplateType = PromptTemplateType.entries[0],
|
||||||
|
)
|
||||||
|
|
||||||
|
private val STATS = listOf(
|
||||||
|
Stat(id = "time_to_first_token", label = "1st token", unit = "sec"),
|
||||||
|
Stat(id = "prefill_speed", label = "Prefill speed", unit = "tokens/s"),
|
||||||
|
Stat(id = "decode_speed", label = "Decode speed", unit = "tokens/s"),
|
||||||
|
Stat(id = "latency", label = "Latency", unit = "sec")
|
||||||
|
)
|
||||||
|
|
||||||
|
open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_SINGLE_TURN) : ViewModel() {
|
||||||
|
private val _uiState = MutableStateFlow(createUiState(task = task))
|
||||||
|
val uiState = _uiState.asStateFlow()
|
||||||
|
|
||||||
|
fun generateResponse(model: Model, input: String) {
|
||||||
|
viewModelScope.launch(Dispatchers.Default) {
|
||||||
|
setInProgress(true)
|
||||||
|
setInitializing(true)
|
||||||
|
|
||||||
|
// Wait for instance to be initialized.
|
||||||
|
while (model.instance == null) {
|
||||||
|
delay(100)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run inference.
|
||||||
|
val instance = model.instance as LlmModelInstance
|
||||||
|
val prefillTokens = instance.session.sizeInTokens(input)
|
||||||
|
|
||||||
|
var firstRun = true
|
||||||
|
var timeToFirstToken = 0f
|
||||||
|
var firstTokenTs = 0L
|
||||||
|
var decodeTokens = 0
|
||||||
|
var prefillSpeed = 0f
|
||||||
|
var decodeSpeed: Float
|
||||||
|
val start = System.currentTimeMillis()
|
||||||
|
var response = ""
|
||||||
|
var lastBenchmarkUpdateTs = 0L
|
||||||
|
LlmChatModelHelper.runInference(model = model,
|
||||||
|
input = input,
|
||||||
|
resultListener = { partialResult, done ->
|
||||||
|
val curTs = System.currentTimeMillis()
|
||||||
|
|
||||||
|
if (firstRun) {
|
||||||
|
setInitializing(false)
|
||||||
|
firstTokenTs = System.currentTimeMillis()
|
||||||
|
timeToFirstToken = (firstTokenTs - start) / 1000f
|
||||||
|
prefillSpeed = prefillTokens / timeToFirstToken
|
||||||
|
firstRun = false
|
||||||
|
} else {
|
||||||
|
decodeTokens++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Incrementally update the streamed partial results.
|
||||||
|
response = processLlmResponse(response = "$response$partialResult")
|
||||||
|
|
||||||
|
// Update response.
|
||||||
|
updateResponse(
|
||||||
|
model = model,
|
||||||
|
promptTemplateType = uiState.value.selectedPromptTemplateType,
|
||||||
|
response = response
|
||||||
|
)
|
||||||
|
|
||||||
|
// Update benchmark (with throttling).
|
||||||
|
if (curTs - lastBenchmarkUpdateTs > 200) {
|
||||||
|
decodeSpeed = decodeTokens / ((curTs - firstTokenTs) / 1000f)
|
||||||
|
if (decodeSpeed.isNaN()) {
|
||||||
|
decodeSpeed = 0f
|
||||||
|
}
|
||||||
|
val benchmark = ChatMessageBenchmarkLlmResult(
|
||||||
|
orderedStats = STATS,
|
||||||
|
statValues = mutableMapOf(
|
||||||
|
"prefill_speed" to prefillSpeed,
|
||||||
|
"decode_speed" to decodeSpeed,
|
||||||
|
"time_to_first_token" to timeToFirstToken,
|
||||||
|
"latency" to (curTs - start).toFloat() / 1000f,
|
||||||
|
),
|
||||||
|
running = !done,
|
||||||
|
latencyMs = -1f,
|
||||||
|
)
|
||||||
|
updateBenchmark(
|
||||||
|
model = model,
|
||||||
|
promptTemplateType = uiState.value.selectedPromptTemplateType,
|
||||||
|
benchmark = benchmark
|
||||||
|
)
|
||||||
|
lastBenchmarkUpdateTs = curTs
|
||||||
|
}
|
||||||
|
|
||||||
|
if (done) {
|
||||||
|
setInProgress(false)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
singleTurn = true,
|
||||||
|
cleanUpListener = {
|
||||||
|
setInitializing(false)
|
||||||
|
setInProgress(false)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun selectPromptTemplate(model: Model, promptTemplateType: PromptTemplateType) {
|
||||||
|
Log.d(TAG, "selecting prompt template: ${promptTemplateType.label}")
|
||||||
|
|
||||||
|
// Clear response.
|
||||||
|
updateResponse(model = model, promptTemplateType = promptTemplateType, response = "")
|
||||||
|
|
||||||
|
this._uiState.update { this.uiState.value.copy(selectedPromptTemplateType = promptTemplateType) }
|
||||||
|
}
|
||||||
|
|
||||||
|
fun setInProgress(inProgress: Boolean) {
|
||||||
|
_uiState.update { _uiState.value.copy(inProgress = inProgress) }
|
||||||
|
}
|
||||||
|
|
||||||
|
fun setInitializing(initializing: Boolean) {
|
||||||
|
_uiState.update { _uiState.value.copy(initializing = initializing) }
|
||||||
|
}
|
||||||
|
|
||||||
|
fun updateResponse(model: Model, promptTemplateType: PromptTemplateType, response: String) {
|
||||||
|
_uiState.update { currentState ->
|
||||||
|
val currentResponses = currentState.responsesByModel
|
||||||
|
val modelResponses = currentResponses[model.name]?.toMutableMap() ?: mutableMapOf()
|
||||||
|
modelResponses[promptTemplateType.label] = response
|
||||||
|
val newResponses = currentResponses.toMutableMap()
|
||||||
|
newResponses[model.name] = modelResponses
|
||||||
|
currentState.copy(responsesByModel = newResponses)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun updateBenchmark(
|
||||||
|
model: Model, promptTemplateType: PromptTemplateType, benchmark: ChatMessageBenchmarkLlmResult
|
||||||
|
) {
|
||||||
|
_uiState.update { currentState ->
|
||||||
|
val currentBenchmark = currentState.benchmarkByModel
|
||||||
|
val modelBenchmarks = currentBenchmark[model.name]?.toMutableMap() ?: mutableMapOf()
|
||||||
|
modelBenchmarks[promptTemplateType.label] = benchmark
|
||||||
|
val newBenchmarks = currentBenchmark.toMutableMap()
|
||||||
|
newBenchmarks[model.name] = modelBenchmarks
|
||||||
|
currentState.copy(benchmarkByModel = newBenchmarks)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun createUiState(task: Task): LlmSingleTurnUiState {
|
||||||
|
val responsesByModel: MutableMap<String, Map<String, String>> = mutableMapOf()
|
||||||
|
val benchmarkByModel: MutableMap<String, Map<String, ChatMessageBenchmarkLlmResult>> =
|
||||||
|
mutableMapOf()
|
||||||
|
for (model in task.models) {
|
||||||
|
responsesByModel[model.name] = mutableMapOf()
|
||||||
|
benchmarkByModel[model.name] = mutableMapOf()
|
||||||
|
}
|
||||||
|
return LlmSingleTurnUiState(
|
||||||
|
responsesByModel = responsesByModel,
|
||||||
|
benchmarkByModel = benchmarkByModel,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,185 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.google.aiedge.gallery.ui.llmsingleturn
|
||||||
|
|
||||||
|
import androidx.compose.ui.graphics.Color
|
||||||
|
import androidx.compose.ui.text.AnnotatedString
|
||||||
|
import androidx.compose.ui.text.buildAnnotatedString
|
||||||
|
import androidx.compose.ui.text.withStyle
|
||||||
|
import androidx.compose.ui.text.SpanStyle
|
||||||
|
import androidx.compose.ui.graphics.Brush.Companion.linearGradient
|
||||||
|
|
||||||
|
enum class PromptTemplateInputEditorType {
|
||||||
|
SINGLE_SELECT
|
||||||
|
}
|
||||||
|
|
||||||
|
enum class RewriteToneType(val label: String) {
|
||||||
|
FORMAL(label = "Formal"), CASUAL(label = "Casual"), FRIENDLY(label = "Friendly"), POLITE(label = "Polite"), ENTHUSIASTIC(
|
||||||
|
label = "Enthusiastic"
|
||||||
|
),
|
||||||
|
CONCISE(label = "Concise"),
|
||||||
|
}
|
||||||
|
|
||||||
|
enum class SummarizationType(val label: String) {
|
||||||
|
KEY_BULLET_POINT(label = "Key bullet points (3-5)"),
|
||||||
|
SHORT_PARAGRAPH(label = "Short paragraph (1-2 sentences)"),
|
||||||
|
CONCISE_SUMMARY(label = "Concise summary (~50 words)"),
|
||||||
|
HEADLINE_TITLE(label = "Headline / title"),
|
||||||
|
ONE_SENTENCE_SUMMARY(label = "One-sentence summary"),
|
||||||
|
}
|
||||||
|
|
||||||
|
enum class LanguageType(val label: String) {
|
||||||
|
CPP(label = "C++"),
|
||||||
|
JAVA(label = "Java"),
|
||||||
|
JAVASCRIPT(label = "JavaScript"),
|
||||||
|
KOTLIN(label = "Kotlin"),
|
||||||
|
PYTHON(label = "Python"),
|
||||||
|
SWIFT(label = "Swift"),
|
||||||
|
TYPESCRIPT(label = "TypeScript"),
|
||||||
|
}
|
||||||
|
|
||||||
|
enum class InputEditorLabel(val label: String) {
|
||||||
|
TONE(label = "Tone"),
|
||||||
|
STYLE(label = "Style"),
|
||||||
|
LANGUAGE(label = "Language"),
|
||||||
|
}
|
||||||
|
|
||||||
|
open class PromptTemplateInputEditor(
|
||||||
|
open val label: String,
|
||||||
|
open val type: PromptTemplateInputEditorType,
|
||||||
|
open val defaultOption: String = "",
|
||||||
|
)
|
||||||
|
|
||||||
|
/** Single select that shows options in bottom sheet. */
|
||||||
|
class PromptTemplateSingleSelectInputEditor(
|
||||||
|
override val label: String,
|
||||||
|
val options: List<String> = listOf(),
|
||||||
|
override val defaultOption: String = "",
|
||||||
|
) : PromptTemplateInputEditor(
|
||||||
|
label = label, type = PromptTemplateInputEditorType.SINGLE_SELECT, defaultOption = defaultOption
|
||||||
|
)
|
||||||
|
|
||||||
|
data class PromptTemplateConfig(val inputEditors: List<PromptTemplateInputEditor> = listOf())
|
||||||
|
|
||||||
|
private val GEMINI_GRADIENT_STYLE = SpanStyle(
|
||||||
|
brush = linearGradient(
|
||||||
|
colors = listOf(Color(0xFF4285f4), Color(0xFF9b72cb), Color(0xFFd96570))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
enum class PromptTemplateType(
|
||||||
|
val label: String,
|
||||||
|
val config: PromptTemplateConfig,
|
||||||
|
val genFullPrompt: (userInput: String, inputEditorValues: Map<String, Any>) -> AnnotatedString = { _, _ ->
|
||||||
|
AnnotatedString("")
|
||||||
|
},
|
||||||
|
val examplePrompts: List<String> = listOf(),
|
||||||
|
) {
|
||||||
|
FREE_FORM(
|
||||||
|
label = "Free form",
|
||||||
|
config = PromptTemplateConfig(),
|
||||||
|
genFullPrompt = { userInput, _ -> AnnotatedString(userInput) },
|
||||||
|
examplePrompts = listOf(
|
||||||
|
"Suggest 3 topics for a podcast about \"Friendships in your 20s\".",
|
||||||
|
"Outline the key sections needed in a basic logo design brief.",
|
||||||
|
"List 3 pros and 3 cons to consider before buying a smart watch.",
|
||||||
|
"Write a short, optimistic quote about the future of technology.",
|
||||||
|
"Generate 3 potential names for a mobile app that helps users identify plants.",
|
||||||
|
"Explain the difference between AI and machine learning in 2 sentences.",
|
||||||
|
"Create a simple haiku about a cat sleeping in the sun.",
|
||||||
|
"List 3 ways to make instant noodles taste better using common kitchen ingredients."
|
||||||
|
)
|
||||||
|
),
|
||||||
|
REWRITE_TONE(
|
||||||
|
label = "Rewrite tone", config = PromptTemplateConfig(
|
||||||
|
inputEditors = listOf(
|
||||||
|
PromptTemplateSingleSelectInputEditor(
|
||||||
|
label = InputEditorLabel.TONE.label,
|
||||||
|
options = RewriteToneType.entries.map { it.label },
|
||||||
|
defaultOption = RewriteToneType.FORMAL.label
|
||||||
|
)
|
||||||
|
)
|
||||||
|
), genFullPrompt = { userInput, inputEditorValues ->
|
||||||
|
val tone = inputEditorValues[InputEditorLabel.TONE.label] as String
|
||||||
|
buildAnnotatedString {
|
||||||
|
withStyle(GEMINI_GRADIENT_STYLE) {
|
||||||
|
append("Rewrite the following text using a ${tone.lowercase()} tone: ")
|
||||||
|
}
|
||||||
|
append(userInput)
|
||||||
|
}
|
||||||
|
}, examplePrompts = listOf(
|
||||||
|
"Hey team, just wanted to remind everyone about the meeting tomorrow @ 10. Be there!",
|
||||||
|
"Our new software update includes several bug fixes and performance improvements.",
|
||||||
|
"Due to the fact that the weather was bad, we decided to postpone the event.",
|
||||||
|
"Please find attached the requested documentation for your perusal.",
|
||||||
|
"Welcome to the team. Review the onboarding materials.",
|
||||||
|
)
|
||||||
|
),
|
||||||
|
SUMMARIZE_TEXT(
|
||||||
|
label = "Summarize text",
|
||||||
|
config = PromptTemplateConfig(
|
||||||
|
inputEditors = listOf(
|
||||||
|
PromptTemplateSingleSelectInputEditor(
|
||||||
|
label = InputEditorLabel.STYLE.label,
|
||||||
|
options = SummarizationType.entries.map { it.label },
|
||||||
|
defaultOption = SummarizationType.KEY_BULLET_POINT.label
|
||||||
|
)
|
||||||
|
)
|
||||||
|
),
|
||||||
|
genFullPrompt = { userInput, inputEditorValues ->
|
||||||
|
val style = inputEditorValues[InputEditorLabel.STYLE.label] as String
|
||||||
|
buildAnnotatedString {
|
||||||
|
withStyle(GEMINI_GRADIENT_STYLE) {
|
||||||
|
append("Please summarize the following in ${style.lowercase()}: ")
|
||||||
|
}
|
||||||
|
append(userInput)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
examplePrompts = listOf(
|
||||||
|
"The new Pixel phone features an advanced camera system with improved low-light performance and AI-powered editing tools. The display is brighter and more energy-efficient. It runs on the latest Tensor chip, offering faster processing and enhanced security features. Battery life has also been extended, providing all-day power for most users.",
|
||||||
|
"Beginning this Friday, January 24, giant pandas Bao Li and Qing Bao are officially on view to the public at the Smithsonian’s National Zoo and Conservation Biology Institute (NZCBI). The 3-year-old bears arrived in Washington this past October, undergoing a quarantine period before making their debut. Under NZCBI’s new agreement with the CWCA, Qing Bao and Bao Li will remain in the United States for ten years, until April 2034, in exchange for an annual fee of \$1 million. The pair are still too young to breed, as pandas only reach sexual maturity between ages 4 and 7. “Kind of picture them as like awkward teenagers right now,” Lally told WUSA9. “We still have about two years before we would probably even see signs that they’re ready to start mating.”",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
CODE_SNIPPET(
|
||||||
|
label = "Code snippet",
|
||||||
|
config = PromptTemplateConfig(
|
||||||
|
inputEditors = listOf(
|
||||||
|
PromptTemplateSingleSelectInputEditor(
|
||||||
|
label = InputEditorLabel.LANGUAGE.label,
|
||||||
|
options = LanguageType.entries.map { it.label },
|
||||||
|
defaultOption = LanguageType.JAVASCRIPT.label
|
||||||
|
)
|
||||||
|
)
|
||||||
|
),
|
||||||
|
genFullPrompt = { userInput, inputEditorValues ->
|
||||||
|
val language = inputEditorValues[InputEditorLabel.LANGUAGE.label] as String
|
||||||
|
buildAnnotatedString {
|
||||||
|
withStyle(GEMINI_GRADIENT_STYLE) {
|
||||||
|
append("Write a $language code snippet to ")
|
||||||
|
}
|
||||||
|
append(userInput)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
examplePrompts = listOf(
|
||||||
|
"Create an alert box that says \"Hello, World!\"",
|
||||||
|
"Declare an immutable variable named 'appName' with the value \"AI Gallery\"",
|
||||||
|
"Print the numbers from 1 to 5 using a for loop.",
|
||||||
|
"Write a function that returns the square of an integer input.",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,426 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.google.aiedge.gallery.ui.llmsingleturn
|
||||||
|
|
||||||
|
import androidx.compose.foundation.BorderStroke
|
||||||
|
import androidx.compose.foundation.background
|
||||||
|
import androidx.compose.foundation.border
|
||||||
|
import androidx.compose.foundation.clickable
|
||||||
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
|
import androidx.compose.foundation.layout.Box
|
||||||
|
import androidx.compose.foundation.layout.Column
|
||||||
|
import androidx.compose.foundation.layout.Row
|
||||||
|
import androidx.compose.foundation.layout.Spacer
|
||||||
|
import androidx.compose.foundation.layout.fillMaxSize
|
||||||
|
import androidx.compose.foundation.layout.fillMaxWidth
|
||||||
|
import androidx.compose.foundation.layout.height
|
||||||
|
import androidx.compose.foundation.layout.offset
|
||||||
|
import androidx.compose.foundation.layout.padding
|
||||||
|
import androidx.compose.foundation.layout.size
|
||||||
|
import androidx.compose.foundation.layout.wrapContentHeight
|
||||||
|
import androidx.compose.foundation.rememberScrollState
|
||||||
|
import androidx.compose.foundation.shape.CircleShape
|
||||||
|
import androidx.compose.foundation.verticalScroll
|
||||||
|
import androidx.compose.material.icons.Icons
|
||||||
|
import androidx.compose.material.icons.automirrored.rounded.Send
|
||||||
|
import androidx.compose.material.icons.outlined.ContentCopy
|
||||||
|
import androidx.compose.material.icons.outlined.Description
|
||||||
|
import androidx.compose.material.icons.outlined.ExpandLess
|
||||||
|
import androidx.compose.material.icons.outlined.ExpandMore
|
||||||
|
import androidx.compose.material.icons.rounded.Add
|
||||||
|
import androidx.compose.material.icons.rounded.Visibility
|
||||||
|
import androidx.compose.material.icons.rounded.VisibilityOff
|
||||||
|
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||||
|
import androidx.compose.material3.FilterChipDefaults
|
||||||
|
import androidx.compose.material3.Icon
|
||||||
|
import androidx.compose.material3.IconButtonDefaults
|
||||||
|
import androidx.compose.material3.MaterialTheme
|
||||||
|
import androidx.compose.material3.ModalBottomSheet
|
||||||
|
import androidx.compose.material3.OutlinedIconButton
|
||||||
|
import androidx.compose.material3.PrimaryScrollableTabRow
|
||||||
|
import androidx.compose.material3.Tab
|
||||||
|
import androidx.compose.material3.Text
|
||||||
|
import androidx.compose.material3.TextField
|
||||||
|
import androidx.compose.material3.TextFieldDefaults
|
||||||
|
import androidx.compose.material3.rememberModalBottomSheetState
|
||||||
|
import androidx.compose.runtime.Composable
|
||||||
|
import androidx.compose.runtime.LaunchedEffect
|
||||||
|
import androidx.compose.runtime.collectAsState
|
||||||
|
import androidx.compose.runtime.derivedStateOf
|
||||||
|
import androidx.compose.runtime.getValue
|
||||||
|
import androidx.compose.runtime.mutableIntStateOf
|
||||||
|
import androidx.compose.runtime.mutableStateMapOf
|
||||||
|
import androidx.compose.runtime.mutableStateOf
|
||||||
|
import androidx.compose.runtime.remember
|
||||||
|
import androidx.compose.runtime.setValue
|
||||||
|
import androidx.compose.runtime.snapshots.SnapshotStateMap
|
||||||
|
import androidx.compose.ui.Alignment
|
||||||
|
import androidx.compose.ui.Modifier
|
||||||
|
import androidx.compose.ui.draw.alpha
|
||||||
|
import androidx.compose.ui.draw.clip
|
||||||
|
import androidx.compose.ui.graphics.Color
|
||||||
|
import androidx.compose.ui.platform.LocalClipboardManager
|
||||||
|
import androidx.compose.ui.res.dimensionResource
|
||||||
|
import androidx.compose.ui.text.TextLayoutResult
|
||||||
|
import androidx.compose.ui.text.style.TextOverflow
|
||||||
|
import androidx.compose.ui.unit.dp
|
||||||
|
import com.google.aiedge.gallery.R
|
||||||
|
import com.google.aiedge.gallery.data.Model
|
||||||
|
import com.google.aiedge.gallery.ui.common.chat.MessageBubbleShape
|
||||||
|
import com.google.aiedge.gallery.ui.theme.customColors
|
||||||
|
|
||||||
|
private val promptTemplateTypes: List<PromptTemplateType> = PromptTemplateType.entries
|
||||||
|
private val TAB_TITLES = PromptTemplateType.entries.map { it.label }
|
||||||
|
private val ICON_BUTTON_SIZE = 42.dp
|
||||||
|
|
||||||
|
const val FULL_PROMPT_SWITCH_KEY = "full_prompt"
|
||||||
|
|
||||||
|
@OptIn(ExperimentalMaterial3Api::class)
|
||||||
|
@Composable
|
||||||
|
fun PromptTemplatesPanel(
|
||||||
|
model: Model,
|
||||||
|
viewModel: LlmSingleTurnViewModel,
|
||||||
|
onSend: (fullPrompt: String) -> Unit,
|
||||||
|
modifier: Modifier = Modifier
|
||||||
|
) {
|
||||||
|
val uiState by viewModel.uiState.collectAsState()
|
||||||
|
val selectedPromptTemplateType = uiState.selectedPromptTemplateType
|
||||||
|
val inProgress = uiState.inProgress
|
||||||
|
var selectedTabIndex by remember { mutableIntStateOf(0) }
|
||||||
|
var curTextInputContent by remember { mutableStateOf("") }
|
||||||
|
val inputEditorValues: SnapshotStateMap<String, Any> = remember {
|
||||||
|
mutableStateMapOf(FULL_PROMPT_SWITCH_KEY to false)
|
||||||
|
}
|
||||||
|
val fullPrompt by remember {
|
||||||
|
derivedStateOf {
|
||||||
|
uiState.selectedPromptTemplateType.genFullPrompt(curTextInputContent, inputEditorValues)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
val clipboardManager = LocalClipboardManager.current
|
||||||
|
val expandedStates = remember { mutableStateMapOf<String, Boolean>() }
|
||||||
|
|
||||||
|
// Update input editor values when prompt template changes.
|
||||||
|
LaunchedEffect(selectedPromptTemplateType) {
|
||||||
|
for (config in selectedPromptTemplateType.config.inputEditors) {
|
||||||
|
inputEditorValues[config.label] = config.defaultOption
|
||||||
|
}
|
||||||
|
expandedStates.clear()
|
||||||
|
}
|
||||||
|
|
||||||
|
var showExamplePromptBottomSheet by remember { mutableStateOf(false) }
|
||||||
|
val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
|
||||||
|
val bubbleBorderRadius = dimensionResource(R.dimen.chat_bubble_corner_radius)
|
||||||
|
|
||||||
|
Column(modifier = modifier) {
|
||||||
|
// Scrollable tab row for all prompt templates.
|
||||||
|
PrimaryScrollableTabRow(selectedTabIndex = selectedTabIndex) {
|
||||||
|
TAB_TITLES.forEachIndexed { index, title ->
|
||||||
|
Tab(selected = selectedTabIndex == index, onClick = {
|
||||||
|
// Clear input when tab changes.
|
||||||
|
curTextInputContent = ""
|
||||||
|
// Reset full prompt switch.
|
||||||
|
inputEditorValues[FULL_PROMPT_SWITCH_KEY] = false
|
||||||
|
|
||||||
|
selectedTabIndex = index
|
||||||
|
viewModel.selectPromptTemplate(
|
||||||
|
model = model,
|
||||||
|
promptTemplateType = promptTemplateTypes[index]
|
||||||
|
)
|
||||||
|
}, text = { Text(text = title) })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Content.
|
||||||
|
Column(
|
||||||
|
modifier = Modifier
|
||||||
|
.weight(1f)
|
||||||
|
.fillMaxWidth()
|
||||||
|
) {
|
||||||
|
// Input editor row.
|
||||||
|
if (selectedPromptTemplateType.config.inputEditors.isNotEmpty()) {
|
||||||
|
Row(
|
||||||
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(8.dp),
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.background(MaterialTheme.colorScheme.surfaceContainerLow)
|
||||||
|
.padding(horizontal = 16.dp, vertical = 10.dp)
|
||||||
|
) {
|
||||||
|
// Input editors.
|
||||||
|
for (inputEditor in selectedPromptTemplateType.config.inputEditors) {
|
||||||
|
when (inputEditor.type) {
|
||||||
|
PromptTemplateInputEditorType.SINGLE_SELECT -> SingleSelectButton(config = inputEditor as PromptTemplateSingleSelectInputEditor,
|
||||||
|
onSelected = { option ->
|
||||||
|
inputEditorValues[inputEditor.label] = option
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Text input box.
|
||||||
|
Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.weight(1f)) {
|
||||||
|
Column(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxSize()
|
||||||
|
.verticalScroll(rememberScrollState())
|
||||||
|
) {
|
||||||
|
if (inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean) {
|
||||||
|
Text(
|
||||||
|
fullPrompt,
|
||||||
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.padding(16.dp)
|
||||||
|
.padding(bottom = 32.dp)
|
||||||
|
.clip(MessageBubbleShape(radius = bubbleBorderRadius))
|
||||||
|
.background(MaterialTheme.customColors.agentBubbleBgColor)
|
||||||
|
.padding(16.dp)
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
TextField(
|
||||||
|
value = curTextInputContent,
|
||||||
|
onValueChange = { curTextInputContent = it },
|
||||||
|
colors = TextFieldDefaults.colors(
|
||||||
|
unfocusedContainerColor = Color.Transparent,
|
||||||
|
focusedContainerColor = Color.Transparent,
|
||||||
|
focusedIndicatorColor = Color.Transparent,
|
||||||
|
unfocusedIndicatorColor = Color.Transparent,
|
||||||
|
disabledIndicatorColor = Color.Transparent,
|
||||||
|
disabledContainerColor = Color.Transparent,
|
||||||
|
),
|
||||||
|
textStyle = MaterialTheme.typography.bodyMedium,
|
||||||
|
placeholder = { Text("Enter content") },
|
||||||
|
modifier = Modifier.padding(bottom = 32.dp)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Text action row.
|
||||||
|
Row(
|
||||||
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(4.dp),
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.padding(vertical = 4.dp, horizontal = 16.dp)
|
||||||
|
) {
|
||||||
|
// Full prompt switch.
|
||||||
|
if (selectedPromptTemplateType != PromptTemplateType.FREE_FORM && curTextInputContent.isNotEmpty()) {
|
||||||
|
Row(verticalAlignment = Alignment.CenterVertically,
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(4.dp),
|
||||||
|
modifier = Modifier
|
||||||
|
.clip(CircleShape)
|
||||||
|
.background(if (inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean) MaterialTheme.colorScheme.secondaryContainer else MaterialTheme.customColors.agentBubbleBgColor)
|
||||||
|
.clickable {
|
||||||
|
inputEditorValues[FULL_PROMPT_SWITCH_KEY] =
|
||||||
|
!(inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean)
|
||||||
|
}
|
||||||
|
.height(40.dp)
|
||||||
|
.border(
|
||||||
|
width = 1.dp, color = MaterialTheme.colorScheme.surface, shape = CircleShape
|
||||||
|
)
|
||||||
|
.padding(horizontal = 12.dp)) {
|
||||||
|
if (inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean) {
|
||||||
|
Icon(
|
||||||
|
imageVector = Icons.Rounded.Visibility,
|
||||||
|
contentDescription = "",
|
||||||
|
modifier = Modifier.size(FilterChipDefaults.IconSize),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
Icon(
|
||||||
|
imageVector = Icons.Rounded.VisibilityOff,
|
||||||
|
contentDescription = "",
|
||||||
|
modifier = Modifier
|
||||||
|
.size(FilterChipDefaults.IconSize)
|
||||||
|
.alpha(0.3f),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
Text("Preview prompt", style = MaterialTheme.typography.labelMedium)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.weight(1f))
|
||||||
|
|
||||||
|
// Button to copy full prompt.
|
||||||
|
if (curTextInputContent.isNotEmpty()) {
|
||||||
|
OutlinedIconButton(
|
||||||
|
onClick = {
|
||||||
|
val clipData = fullPrompt
|
||||||
|
clipboardManager.setText(clipData)
|
||||||
|
},
|
||||||
|
colors = IconButtonDefaults.iconButtonColors(
|
||||||
|
containerColor = MaterialTheme.customColors.agentBubbleBgColor,
|
||||||
|
disabledContainerColor = MaterialTheme.customColors.agentBubbleBgColor.copy(alpha = 0.4f),
|
||||||
|
contentColor = MaterialTheme.colorScheme.primary,
|
||||||
|
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.2f),
|
||||||
|
),
|
||||||
|
border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface),
|
||||||
|
modifier = Modifier.size(ICON_BUTTON_SIZE)
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
Icons.Outlined.ContentCopy, contentDescription = "", modifier = Modifier.size(20.dp)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add example prompt button.
|
||||||
|
OutlinedIconButton(
|
||||||
|
enabled = !inProgress,
|
||||||
|
onClick = { showExamplePromptBottomSheet = true },
|
||||||
|
colors = IconButtonDefaults.iconButtonColors(
|
||||||
|
containerColor = MaterialTheme.customColors.agentBubbleBgColor,
|
||||||
|
disabledContainerColor = MaterialTheme.customColors.agentBubbleBgColor.copy(alpha = 0.4f),
|
||||||
|
contentColor = MaterialTheme.colorScheme.primary,
|
||||||
|
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.2f),
|
||||||
|
),
|
||||||
|
border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface),
|
||||||
|
modifier = Modifier.size(ICON_BUTTON_SIZE)
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
Icons.Rounded.Add,
|
||||||
|
contentDescription = "",
|
||||||
|
modifier = Modifier.size(20.dp),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send button
|
||||||
|
OutlinedIconButton(
|
||||||
|
enabled = !inProgress && curTextInputContent.isNotEmpty(),
|
||||||
|
onClick = {
|
||||||
|
onSend(fullPrompt.text)
|
||||||
|
},
|
||||||
|
colors = IconButtonDefaults.iconButtonColors(
|
||||||
|
containerColor = MaterialTheme.colorScheme.secondaryContainer,
|
||||||
|
disabledContainerColor = MaterialTheme.colorScheme.secondaryContainer.copy(alpha = 0.3f),
|
||||||
|
contentColor = MaterialTheme.colorScheme.primary,
|
||||||
|
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.1f),
|
||||||
|
),
|
||||||
|
border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface),
|
||||||
|
modifier = Modifier.size(ICON_BUTTON_SIZE)
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
Icons.AutoMirrored.Rounded.Send,
|
||||||
|
contentDescription = "",
|
||||||
|
modifier = Modifier
|
||||||
|
.size(20.dp)
|
||||||
|
.offset(x = 2.dp),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (showExamplePromptBottomSheet) {
|
||||||
|
ModalBottomSheet(
|
||||||
|
onDismissRequest = { showExamplePromptBottomSheet = false },
|
||||||
|
sheetState = sheetState,
|
||||||
|
modifier = Modifier.wrapContentHeight(),
|
||||||
|
) {
|
||||||
|
Column(modifier = Modifier.padding(bottom = 16.dp)) {
|
||||||
|
// Title
|
||||||
|
Text(
|
||||||
|
"Select an example",
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.padding(16.dp),
|
||||||
|
style = MaterialTheme.typography.titleLarge
|
||||||
|
)
|
||||||
|
|
||||||
|
// Examples
|
||||||
|
for (prompt in selectedPromptTemplateType.examplePrompts) {
|
||||||
|
var textLayoutResultState by remember { mutableStateOf<TextLayoutResult?>(null) }
|
||||||
|
val hasOverflow = remember(textLayoutResultState) {
|
||||||
|
textLayoutResultState?.hasVisualOverflow ?: false
|
||||||
|
}
|
||||||
|
val isExpanded = expandedStates[prompt] ?: false
|
||||||
|
|
||||||
|
Column(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.clickable {
|
||||||
|
curTextInputContent = prompt
|
||||||
|
showExamplePromptBottomSheet = false
|
||||||
|
}
|
||||||
|
.padding(horizontal = 16.dp, vertical = 8.dp),
|
||||||
|
) {
|
||||||
|
Row(
|
||||||
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(8.dp),
|
||||||
|
) {
|
||||||
|
Icon(Icons.Outlined.Description, contentDescription = "")
|
||||||
|
Text(prompt,
|
||||||
|
maxLines = if (isExpanded) Int.MAX_VALUE else 3,
|
||||||
|
overflow = TextOverflow.Ellipsis,
|
||||||
|
style = MaterialTheme.typography.bodySmall,
|
||||||
|
modifier = Modifier.weight(1f),
|
||||||
|
onTextLayout = { textLayoutResultState = it }
|
||||||
|
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (hasOverflow && !isExpanded) {
|
||||||
|
Row(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.padding(top = 2.dp),
|
||||||
|
horizontalArrangement = Arrangement.End
|
||||||
|
) {
|
||||||
|
Box(modifier = Modifier
|
||||||
|
.padding(end = 16.dp)
|
||||||
|
.clip(CircleShape)
|
||||||
|
.background(MaterialTheme.colorScheme.surfaceContainerHighest)
|
||||||
|
.clickable {
|
||||||
|
expandedStates[prompt] = true
|
||||||
|
}
|
||||||
|
.padding(vertical = 1.dp, horizontal = 6.dp)) {
|
||||||
|
Icon(
|
||||||
|
Icons.Outlined.ExpandMore,
|
||||||
|
contentDescription = "",
|
||||||
|
modifier = Modifier.size(12.dp)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (isExpanded) {
|
||||||
|
Row(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.padding(top = 2.dp),
|
||||||
|
horizontalArrangement = Arrangement.End
|
||||||
|
) {
|
||||||
|
Box(modifier = Modifier
|
||||||
|
.padding(end = 16.dp)
|
||||||
|
.clip(CircleShape)
|
||||||
|
.background(MaterialTheme.colorScheme.surfaceContainerHighest)
|
||||||
|
.clickable {
|
||||||
|
expandedStates[prompt] = false
|
||||||
|
}
|
||||||
|
.padding(vertical = 1.dp, horizontal = 6.dp)) {
|
||||||
|
Icon(
|
||||||
|
Icons.Outlined.ExpandLess,
|
||||||
|
contentDescription = "",
|
||||||
|
modifier = Modifier.size(12.dp)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,206 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.google.aiedge.gallery.ui.llmsingleturn
|
||||||
|
|
||||||
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
|
import androidx.compose.foundation.layout.Box
|
||||||
|
import androidx.compose.foundation.layout.Column
|
||||||
|
import androidx.compose.foundation.layout.Row
|
||||||
|
import androidx.compose.foundation.layout.fillMaxSize
|
||||||
|
import androidx.compose.foundation.layout.fillMaxWidth
|
||||||
|
import androidx.compose.foundation.layout.padding
|
||||||
|
import androidx.compose.foundation.layout.size
|
||||||
|
import androidx.compose.foundation.rememberScrollState
|
||||||
|
import androidx.compose.foundation.verticalScroll
|
||||||
|
import androidx.compose.material.icons.Icons
|
||||||
|
import androidx.compose.material.icons.outlined.AutoAwesome
|
||||||
|
import androidx.compose.material.icons.outlined.ContentCopy
|
||||||
|
import androidx.compose.material.icons.outlined.Timer
|
||||||
|
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||||
|
import androidx.compose.material3.Icon
|
||||||
|
import androidx.compose.material3.IconButton
|
||||||
|
import androidx.compose.material3.IconButtonDefaults
|
||||||
|
import androidx.compose.material3.MaterialTheme
|
||||||
|
import androidx.compose.material3.PrimaryTabRow
|
||||||
|
import androidx.compose.material3.Tab
|
||||||
|
import androidx.compose.material3.Text
|
||||||
|
import androidx.compose.runtime.Composable
|
||||||
|
import androidx.compose.runtime.LaunchedEffect
|
||||||
|
import androidx.compose.runtime.collectAsState
|
||||||
|
import androidx.compose.runtime.getValue
|
||||||
|
import androidx.compose.runtime.mutableIntStateOf
|
||||||
|
import androidx.compose.runtime.remember
|
||||||
|
import androidx.compose.runtime.setValue
|
||||||
|
import androidx.compose.ui.Alignment
|
||||||
|
import androidx.compose.ui.Modifier
|
||||||
|
import androidx.compose.ui.draw.alpha
|
||||||
|
import androidx.compose.ui.graphics.Color
|
||||||
|
import androidx.compose.ui.platform.LocalClipboardManager
|
||||||
|
import androidx.compose.ui.text.AnnotatedString
|
||||||
|
import androidx.compose.ui.tooling.preview.Preview
|
||||||
|
import androidx.compose.ui.unit.dp
|
||||||
|
import com.google.aiedge.gallery.data.Model
|
||||||
|
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
|
||||||
|
import com.google.aiedge.gallery.ui.common.chat.MarkdownText
|
||||||
|
import com.google.aiedge.gallery.ui.common.chat.MessageBodyBenchmarkLlm
|
||||||
|
import com.google.aiedge.gallery.ui.common.chat.MessageBodyLoading
|
||||||
|
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||||
|
|
||||||
|
private val OPTIONS = listOf("Response", "Benchmark")
|
||||||
|
private val ICONS = listOf(Icons.Outlined.AutoAwesome, Icons.Outlined.Timer)
|
||||||
|
|
||||||
|
@OptIn(ExperimentalMaterial3Api::class)
|
||||||
|
@Composable
|
||||||
|
fun ResponsePanel(
|
||||||
|
model: Model,
|
||||||
|
viewModel: LlmSingleTurnViewModel,
|
||||||
|
modifier: Modifier = Modifier,
|
||||||
|
) {
|
||||||
|
val uiState by viewModel.uiState.collectAsState()
|
||||||
|
val inProgress = uiState.inProgress
|
||||||
|
val initializing = uiState.initializing
|
||||||
|
val selectedPromptTemplateType = uiState.selectedPromptTemplateType
|
||||||
|
val response = uiState.responsesByModel[model.name]?.get(selectedPromptTemplateType.label) ?: ""
|
||||||
|
val benchmark = uiState.benchmarkByModel[model.name]?.get(selectedPromptTemplateType.label)
|
||||||
|
val responseScrollState = rememberScrollState()
|
||||||
|
var selectedOptionIndex by remember { mutableIntStateOf(0) }
|
||||||
|
val clipboardManager = LocalClipboardManager.current
|
||||||
|
|
||||||
|
// Scroll to bottom when response changes.
|
||||||
|
LaunchedEffect(response) {
|
||||||
|
if (inProgress) {
|
||||||
|
responseScrollState.animateScrollTo(responseScrollState.maxValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select the "response" tab when prompt template changes.
|
||||||
|
LaunchedEffect(selectedPromptTemplateType) {
|
||||||
|
selectedOptionIndex = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if (initializing) {
|
||||||
|
Box(
|
||||||
|
contentAlignment = Alignment.TopStart,
|
||||||
|
modifier = modifier
|
||||||
|
.fillMaxSize()
|
||||||
|
.padding(horizontal = 16.dp)
|
||||||
|
) {
|
||||||
|
MessageBodyLoading()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Message when response is empty.
|
||||||
|
if (response.isEmpty()) {
|
||||||
|
Row(
|
||||||
|
modifier = Modifier.fillMaxSize(),
|
||||||
|
horizontalArrangement = Arrangement.Center,
|
||||||
|
verticalAlignment = Alignment.CenterVertically
|
||||||
|
) {
|
||||||
|
Text(
|
||||||
|
"Response will appear here",
|
||||||
|
modifier = Modifier.alpha(0.5f),
|
||||||
|
style = MaterialTheme.typography.labelMedium,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Response markdown.
|
||||||
|
else {
|
||||||
|
Column(
|
||||||
|
modifier = modifier
|
||||||
|
.padding(horizontal = 16.dp)
|
||||||
|
.padding(bottom = 4.dp)
|
||||||
|
) {
|
||||||
|
// Response/benchmark switch.
|
||||||
|
Row(modifier = Modifier.fillMaxWidth()) {
|
||||||
|
PrimaryTabRow(
|
||||||
|
selectedTabIndex = selectedOptionIndex,
|
||||||
|
containerColor = Color.Transparent,
|
||||||
|
) {
|
||||||
|
OPTIONS.forEachIndexed { index, title ->
|
||||||
|
Tab(selected = selectedOptionIndex == index, onClick = {
|
||||||
|
selectedOptionIndex = index
|
||||||
|
}, text = {
|
||||||
|
Row(
|
||||||
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(4.dp)
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
ICONS[index],
|
||||||
|
contentDescription = "",
|
||||||
|
modifier = Modifier
|
||||||
|
.size(16.dp)
|
||||||
|
.alpha(0.7f)
|
||||||
|
)
|
||||||
|
Text(text = title)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (selectedOptionIndex == 0) {
|
||||||
|
Box(
|
||||||
|
contentAlignment = Alignment.BottomEnd,
|
||||||
|
modifier = Modifier.weight(1f)
|
||||||
|
) {
|
||||||
|
Column(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxSize()
|
||||||
|
.verticalScroll(responseScrollState)
|
||||||
|
) {
|
||||||
|
MarkdownText(
|
||||||
|
text = response,
|
||||||
|
modifier = Modifier.padding(top = 8.dp, bottom = 40.dp)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
// Copy button.
|
||||||
|
IconButton(
|
||||||
|
onClick = {
|
||||||
|
val clipData = AnnotatedString(response)
|
||||||
|
clipboardManager.setText(clipData)
|
||||||
|
},
|
||||||
|
colors = IconButtonDefaults.iconButtonColors(
|
||||||
|
containerColor = MaterialTheme.colorScheme.surfaceContainerHighest,
|
||||||
|
contentColor = MaterialTheme.colorScheme.primary,
|
||||||
|
),
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
Icons.Outlined.ContentCopy,
|
||||||
|
contentDescription = "",
|
||||||
|
modifier = Modifier.size(20.dp),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (selectedOptionIndex == 1) {
|
||||||
|
if (benchmark != null) {
|
||||||
|
MessageBodyBenchmarkLlm(message = benchmark, modifier = Modifier.fillMaxWidth())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Preview(showBackground = true)
|
||||||
|
@Composable
|
||||||
|
fun ResponsePanelPreview() {
|
||||||
|
GalleryTheme {
|
||||||
|
ResponsePanel(
|
||||||
|
model = TASK_LLM_SINGLE_TURN.models[0],
|
||||||
|
viewModel = LlmSingleTurnViewModel(),
|
||||||
|
modifier = Modifier.fillMaxSize()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,90 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.google.aiedge.gallery.ui.llmsingleturn
|
||||||
|
|
||||||
|
import androidx.compose.foundation.background
|
||||||
|
import androidx.compose.foundation.clickable
|
||||||
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
|
import androidx.compose.foundation.layout.Box
|
||||||
|
import androidx.compose.foundation.layout.Row
|
||||||
|
import androidx.compose.foundation.layout.padding
|
||||||
|
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||||
|
import androidx.compose.material.icons.Icons
|
||||||
|
import androidx.compose.material.icons.rounded.ArrowDropDown
|
||||||
|
import androidx.compose.material3.DropdownMenu
|
||||||
|
import androidx.compose.material3.DropdownMenuItem
|
||||||
|
import androidx.compose.material3.Icon
|
||||||
|
import androidx.compose.material3.MaterialTheme
|
||||||
|
import androidx.compose.material3.Text
|
||||||
|
import androidx.compose.runtime.Composable
|
||||||
|
import androidx.compose.runtime.LaunchedEffect
|
||||||
|
import androidx.compose.runtime.getValue
|
||||||
|
import androidx.compose.runtime.mutableStateOf
|
||||||
|
import androidx.compose.runtime.remember
|
||||||
|
import androidx.compose.runtime.setValue
|
||||||
|
import androidx.compose.ui.Alignment
|
||||||
|
import androidx.compose.ui.Modifier
|
||||||
|
import androidx.compose.ui.draw.clip
|
||||||
|
import androidx.compose.ui.unit.dp
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
fun SingleSelectButton(
|
||||||
|
config: PromptTemplateSingleSelectInputEditor,
|
||||||
|
onSelected: (String) -> Unit
|
||||||
|
) {
|
||||||
|
var showMenu by remember { mutableStateOf(false) }
|
||||||
|
var selectedOption by remember { mutableStateOf(config.defaultOption) }
|
||||||
|
|
||||||
|
LaunchedEffect(config) {
|
||||||
|
selectedOption = config.defaultOption
|
||||||
|
}
|
||||||
|
|
||||||
|
Box {
|
||||||
|
Row(
|
||||||
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(2.dp),
|
||||||
|
modifier = Modifier
|
||||||
|
.clip(RoundedCornerShape(8.dp))
|
||||||
|
.background(MaterialTheme.colorScheme.secondaryContainer)
|
||||||
|
.clickable {
|
||||||
|
showMenu = true
|
||||||
|
}
|
||||||
|
.padding(vertical = 4.dp, horizontal = 6.dp)
|
||||||
|
.padding(start = 8.dp)
|
||||||
|
) {
|
||||||
|
Text("${config.label}: $selectedOption", style = MaterialTheme.typography.labelLarge)
|
||||||
|
Icon(Icons.Rounded.ArrowDropDown, contentDescription = "")
|
||||||
|
}
|
||||||
|
|
||||||
|
DropdownMenu(
|
||||||
|
expanded = showMenu,
|
||||||
|
onDismissRequest = { showMenu = false }
|
||||||
|
) {
|
||||||
|
// Options
|
||||||
|
for (option in config.options) {
|
||||||
|
DropdownMenuItem(
|
||||||
|
text = { Text(option) },
|
||||||
|
onClick = {
|
||||||
|
selectedOption = option
|
||||||
|
showMenu = false
|
||||||
|
onSelected(option)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,133 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.google.aiedge.gallery.ui.llmsingleturn
|
||||||
|
|
||||||
|
import androidx.compose.foundation.background
|
||||||
|
import androidx.compose.foundation.gestures.detectDragGestures
|
||||||
|
import androidx.compose.foundation.layout.Box
|
||||||
|
import androidx.compose.foundation.layout.Column
|
||||||
|
import androidx.compose.foundation.layout.fillMaxSize
|
||||||
|
import androidx.compose.foundation.layout.fillMaxWidth
|
||||||
|
import androidx.compose.foundation.layout.height
|
||||||
|
import androidx.compose.foundation.layout.width
|
||||||
|
import androidx.compose.foundation.shape.CircleShape
|
||||||
|
import androidx.compose.material3.MaterialTheme
|
||||||
|
import androidx.compose.material3.Text
|
||||||
|
import androidx.compose.runtime.Composable
|
||||||
|
import androidx.compose.runtime.getValue
|
||||||
|
import androidx.compose.runtime.mutableFloatStateOf
|
||||||
|
import androidx.compose.runtime.mutableStateOf
|
||||||
|
import androidx.compose.runtime.remember
|
||||||
|
import androidx.compose.runtime.setValue
|
||||||
|
import androidx.compose.ui.Alignment
|
||||||
|
import androidx.compose.ui.Modifier
|
||||||
|
import androidx.compose.ui.draw.clip
|
||||||
|
import androidx.compose.ui.input.pointer.pointerInput
|
||||||
|
import androidx.compose.ui.layout.onGloballyPositioned
|
||||||
|
import androidx.compose.ui.platform.LocalDensity
|
||||||
|
import androidx.compose.ui.tooling.preview.Preview
|
||||||
|
import androidx.compose.ui.unit.Dp
|
||||||
|
import androidx.compose.ui.unit.dp
|
||||||
|
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||||
|
import com.google.aiedge.gallery.ui.theme.customColors
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
fun VerticalSplitView(
|
||||||
|
topView: @Composable () -> Unit,
|
||||||
|
bottomView: @Composable () -> Unit,
|
||||||
|
modifier: Modifier = Modifier,
|
||||||
|
initialRatio: Float = 0.5f,
|
||||||
|
minTopHeight: Dp = 250.dp,
|
||||||
|
minBottomHeight: Dp = 200.dp,
|
||||||
|
handleThickness: Dp = 20.dp
|
||||||
|
) {
|
||||||
|
var splitRatio by remember { mutableFloatStateOf(initialRatio) }
|
||||||
|
var columnHeightPx by remember {
|
||||||
|
mutableFloatStateOf(0f)
|
||||||
|
}
|
||||||
|
var columnHeightDp by remember {
|
||||||
|
mutableStateOf(0.dp)
|
||||||
|
}
|
||||||
|
val localDensity = LocalDensity.current
|
||||||
|
|
||||||
|
Column(modifier = modifier
|
||||||
|
.fillMaxSize()
|
||||||
|
.onGloballyPositioned { coordinates ->
|
||||||
|
// Set column height using the LayoutCoordinates
|
||||||
|
columnHeightPx = coordinates.size.height.toFloat()
|
||||||
|
columnHeightDp = with(localDensity) { coordinates.size.height.toDp() }
|
||||||
|
}
|
||||||
|
) {
|
||||||
|
Box(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.weight(splitRatio)
|
||||||
|
) {
|
||||||
|
topView()
|
||||||
|
}
|
||||||
|
|
||||||
|
Box(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.height(handleThickness)
|
||||||
|
.background(MaterialTheme.customColors.agentBubbleBgColor)
|
||||||
|
.pointerInput(Unit) {
|
||||||
|
detectDragGestures { change, dragAmount ->
|
||||||
|
val newTopHeightPx = columnHeightPx * splitRatio + dragAmount.y
|
||||||
|
var newTopHeightDp = with(localDensity) { newTopHeightPx.toDp() }
|
||||||
|
if (newTopHeightDp < minTopHeight) {
|
||||||
|
newTopHeightDp = minTopHeight
|
||||||
|
}
|
||||||
|
if (columnHeightDp - newTopHeightDp < minBottomHeight) {
|
||||||
|
newTopHeightDp = columnHeightDp - minBottomHeight
|
||||||
|
}
|
||||||
|
splitRatio = newTopHeightDp / columnHeightDp
|
||||||
|
change.consume()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
contentAlignment = Alignment.Center
|
||||||
|
) {
|
||||||
|
Box(
|
||||||
|
modifier = Modifier
|
||||||
|
.width(32.dp)
|
||||||
|
.height(4.dp)
|
||||||
|
.clip(CircleShape)
|
||||||
|
.background(MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.4f))
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
Box(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.weight(1f - splitRatio)
|
||||||
|
) {
|
||||||
|
bottomView()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Preview(showBackground = true)
|
||||||
|
@Composable
|
||||||
|
fun VerticalSplitViewPreview() {
|
||||||
|
GalleryTheme {
|
||||||
|
VerticalSplitView(topView = {
|
||||||
|
Text("top")
|
||||||
|
}, bottomView = {
|
||||||
|
Text("bottom")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -40,6 +40,7 @@ import com.google.aiedge.gallery.data.ModelDownloadStatus
|
||||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||||
import com.google.aiedge.gallery.data.TASKS
|
import com.google.aiedge.gallery.data.TASKS
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||||
|
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.data.TaskType
|
import com.google.aiedge.gallery.data.TaskType
|
||||||
import com.google.aiedge.gallery.data.ValueType
|
import com.google.aiedge.gallery.data.ValueType
|
||||||
|
@ -175,9 +176,13 @@ open class ModelManagerViewModel(
|
||||||
|
|
||||||
// Kick off downloads for these models .
|
// Kick off downloads for these models .
|
||||||
withContext(Dispatchers.Main) {
|
withContext(Dispatchers.Main) {
|
||||||
|
val tokenStatusAndData = getTokenStatusAndData()
|
||||||
for (info in inProgressWorkInfos) {
|
for (info in inProgressWorkInfos) {
|
||||||
val model: Model? = getModelByName(info.modelName)
|
val model: Model? = getModelByName(info.modelName)
|
||||||
if (model != null) {
|
if (model != null) {
|
||||||
|
if (tokenStatusAndData.status == TokenStatus.NOT_EXPIRED && tokenStatusAndData.data != null) {
|
||||||
|
model.accessToken = tokenStatusAndData.data.accessToken
|
||||||
|
}
|
||||||
Log.d(TAG, "Sending a new download request for '${model.name}'")
|
Log.d(TAG, "Sending a new download request for '${model.name}'")
|
||||||
downloadRepository.downloadModel(
|
downloadRepository.downloadModel(
|
||||||
model, onStatusUpdated = this@ModelManagerViewModel::setDownloadStatus
|
model, onStatusUpdated = this@ModelManagerViewModel::setDownloadStatus
|
||||||
|
@ -233,11 +238,13 @@ open class ModelManagerViewModel(
|
||||||
|
|
||||||
// Delete model from the list if model is imported as a local model.
|
// Delete model from the list if model is imported as a local model.
|
||||||
if (model.imported) {
|
if (model.imported) {
|
||||||
val index = task.models.indexOf(model)
|
for (curTask in TASKS) {
|
||||||
|
val index = curTask.models.indexOf(model)
|
||||||
if (index >= 0) {
|
if (index >= 0) {
|
||||||
task.models.removeAt(index)
|
curTask.models.removeAt(index)
|
||||||
|
}
|
||||||
|
curTask.updateTrigger.value = System.currentTimeMillis()
|
||||||
}
|
}
|
||||||
task.updateTrigger.value = System.currentTimeMillis()
|
|
||||||
curModelDownloadStatus.remove(model.name)
|
curModelDownloadStatus.remove(model.name)
|
||||||
|
|
||||||
// Update preference.
|
// Update preference.
|
||||||
|
@ -252,7 +259,7 @@ open class ModelManagerViewModel(
|
||||||
_uiState.update { newUiState }
|
_uiState.update { newUiState }
|
||||||
}
|
}
|
||||||
|
|
||||||
fun initializeModel(context: Context, model: Model, force: Boolean = false) {
|
fun initializeModel(context: Context, task: Task, model: Model, force: Boolean = false) {
|
||||||
viewModelScope.launch(Dispatchers.Default) {
|
viewModelScope.launch(Dispatchers.Default) {
|
||||||
// Skip if initialized already.
|
// Skip if initialized already.
|
||||||
if (!force && uiState.value.modelInitializationStatus[model.name]?.status == ModelInitializationStatusType.INITIALIZED) {
|
if (!force && uiState.value.modelInitializationStatus[model.name]?.status == ModelInitializationStatusType.INITIALIZED) {
|
||||||
|
@ -267,7 +274,7 @@ open class ModelManagerViewModel(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clean up.
|
// Clean up.
|
||||||
cleanupModel(model = model)
|
cleanupModel(task = task, model = model)
|
||||||
|
|
||||||
// Start initialization.
|
// Start initialization.
|
||||||
Log.d(TAG, "Initializing model '${model.name}'...")
|
Log.d(TAG, "Initializing model '${model.name}'...")
|
||||||
|
@ -301,7 +308,7 @@ open class ModelManagerViewModel(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
when (model.taskType) {
|
when (task.type) {
|
||||||
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.initialize(
|
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.initialize(
|
||||||
context = context,
|
context = context,
|
||||||
model = model,
|
model = model,
|
||||||
|
@ -320,24 +327,33 @@ open class ModelManagerViewModel(
|
||||||
onDone = onDone,
|
onDone = onDone,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
TaskType.LLM_SINGLE_TURN -> LlmChatModelHelper.initialize(
|
||||||
|
context = context,
|
||||||
|
model = model,
|
||||||
|
onDone = onDone,
|
||||||
|
)
|
||||||
|
|
||||||
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.initialize(
|
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.initialize(
|
||||||
context = context, model = model, onDone = onDone
|
context = context, model = model, onDone = onDone
|
||||||
)
|
)
|
||||||
|
|
||||||
else -> {}
|
TaskType.TEST_TASK_1 -> {}
|
||||||
|
TaskType.TEST_TASK_2 -> {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun cleanupModel(model: Model) {
|
fun cleanupModel(task: Task, model: Model) {
|
||||||
if (model.instance != null) {
|
if (model.instance != null) {
|
||||||
Log.d(TAG, "Cleaning up model '${model.name}'...")
|
Log.d(TAG, "Cleaning up model '${model.name}'...")
|
||||||
when (model.taskType) {
|
when (task.type) {
|
||||||
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model)
|
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model)
|
||||||
TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.cleanUp(model = model)
|
TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.cleanUp(model = model)
|
||||||
TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model)
|
TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model)
|
||||||
|
TaskType.LLM_SINGLE_TURN -> LlmChatModelHelper.cleanUp(model = model)
|
||||||
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.cleanUp(model = model)
|
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.cleanUp(model = model)
|
||||||
else -> {}
|
TaskType.TEST_TASK_1 -> {}
|
||||||
|
TaskType.TEST_TASK_2 -> {}
|
||||||
}
|
}
|
||||||
model.instance = null
|
model.instance = null
|
||||||
model.initializing = false
|
model.initializing = false
|
||||||
|
@ -421,10 +437,11 @@ open class ModelManagerViewModel(
|
||||||
return connection.responseCode
|
return connection.responseCode
|
||||||
}
|
}
|
||||||
|
|
||||||
fun addImportedLlmModel(task: Task, info: ImportedModelInfo) {
|
fun addImportedLlmModel(info: ImportedModelInfo) {
|
||||||
Log.d(TAG, "adding imported llm model: $info")
|
Log.d(TAG, "adding imported llm model: $info")
|
||||||
|
|
||||||
// Remove duplicated imported model if existed.
|
// Remove duplicated imported model if existed.
|
||||||
|
val task = TASK_LLM_CHAT
|
||||||
val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported }
|
val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported }
|
||||||
if (modelIndex >= 0) {
|
if (modelIndex >= 0) {
|
||||||
Log.d(TAG, "duplicated imported model found in task. Removing it first")
|
Log.d(TAG, "duplicated imported model found in task. Removing it first")
|
||||||
|
@ -455,6 +472,9 @@ open class ModelManagerViewModel(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
task.updateTrigger.value = System.currentTimeMillis()
|
task.updateTrigger.value = System.currentTimeMillis()
|
||||||
|
// Also need to update single turn task.
|
||||||
|
TASK_LLM_SINGLE_TURN.updateTrigger.value = System.currentTimeMillis()
|
||||||
|
|
||||||
|
|
||||||
// Add to preference storage.
|
// Add to preference storage.
|
||||||
val importedModels = dataStoreRepository.readImportedModels().toMutableList()
|
val importedModels = dataStoreRepository.readImportedModels().toMutableList()
|
||||||
|
@ -630,11 +650,8 @@ open class ModelManagerViewModel(
|
||||||
|
|
||||||
private fun createModelFromImportedModelInfo(info: ImportedModelInfo, task: Task): Model {
|
private fun createModelFromImportedModelInfo(info: ImportedModelInfo, task: Task): Model {
|
||||||
val accelerators: List<Accelerator> = (convertValueToTargetType(
|
val accelerators: List<Accelerator> = (convertValueToTargetType(
|
||||||
info.defaultValues[ConfigKey.COMPATIBLE_ACCELERATORS.label]!!,
|
info.defaultValues[ConfigKey.COMPATIBLE_ACCELERATORS.label]!!, ValueType.STRING
|
||||||
ValueType.STRING
|
) as String).split(",").mapNotNull { acceleratorLabel ->
|
||||||
) as String)
|
|
||||||
.split(",")
|
|
||||||
.mapNotNull { acceleratorLabel ->
|
|
||||||
when (acceleratorLabel.trim()) {
|
when (acceleratorLabel.trim()) {
|
||||||
Accelerator.GPU.label -> Accelerator.GPU
|
Accelerator.GPU.label -> Accelerator.GPU
|
||||||
Accelerator.CPU.label -> Accelerator.CPU
|
Accelerator.CPU.label -> Accelerator.CPU
|
||||||
|
@ -643,20 +660,16 @@ open class ModelManagerViewModel(
|
||||||
}
|
}
|
||||||
val configs: List<Config> = createLlmChatConfigs(
|
val configs: List<Config> = createLlmChatConfigs(
|
||||||
defaultMaxToken = convertValueToTargetType(
|
defaultMaxToken = convertValueToTargetType(
|
||||||
info.defaultValues[ConfigKey.DEFAULT_MAX_TOKENS.label]!!,
|
info.defaultValues[ConfigKey.DEFAULT_MAX_TOKENS.label]!!, ValueType.INT
|
||||||
ValueType.INT
|
|
||||||
) as Int,
|
) as Int,
|
||||||
defaultTopK = convertValueToTargetType(
|
defaultTopK = convertValueToTargetType(
|
||||||
info.defaultValues[ConfigKey.DEFAULT_TOPK.label]!!,
|
info.defaultValues[ConfigKey.DEFAULT_TOPK.label]!!, ValueType.INT
|
||||||
ValueType.INT
|
|
||||||
) as Int,
|
) as Int,
|
||||||
defaultTopP = convertValueToTargetType(
|
defaultTopP = convertValueToTargetType(
|
||||||
info.defaultValues[ConfigKey.DEFAULT_TOPP.label]!!,
|
info.defaultValues[ConfigKey.DEFAULT_TOPP.label]!!, ValueType.FLOAT
|
||||||
ValueType.FLOAT
|
|
||||||
) as Float,
|
) as Float,
|
||||||
defaultTemperature = convertValueToTargetType(
|
defaultTemperature = convertValueToTargetType(
|
||||||
info.defaultValues[ConfigKey.DEFAULT_TEMPERATURE.label]!!,
|
info.defaultValues[ConfigKey.DEFAULT_TEMPERATURE.label]!!, ValueType.FLOAT
|
||||||
ValueType.FLOAT
|
|
||||||
) as Float,
|
) as Float,
|
||||||
accelerators = accelerators,
|
accelerators = accelerators,
|
||||||
)
|
)
|
||||||
|
@ -666,9 +679,10 @@ open class ModelManagerViewModel(
|
||||||
configs = configs,
|
configs = configs,
|
||||||
sizeInBytes = info.fileSize,
|
sizeInBytes = info.fileSize,
|
||||||
downloadFileName = "$IMPORTS_DIR/${info.fileName}",
|
downloadFileName = "$IMPORTS_DIR/${info.fileName}",
|
||||||
|
showBenchmarkButton = false,
|
||||||
imported = true,
|
imported = true,
|
||||||
)
|
)
|
||||||
model.preProcess(task = task)
|
model.preProcess()
|
||||||
|
|
||||||
return model
|
return model
|
||||||
}
|
}
|
||||||
|
@ -741,7 +755,7 @@ open class ModelManagerViewModel(
|
||||||
val task = TASKS.find { it.type.label == hfModel.task }
|
val task = TASKS.find { it.type.label == hfModel.task }
|
||||||
val model = hfModel.toModel()
|
val model = hfModel.toModel()
|
||||||
if (task != null && task.models.find { it.hfModelId == model.hfModelId } == null) {
|
if (task != null && task.models.find { it.hfModelId == model.hfModelId } == null) {
|
||||||
model.preProcess(task = task)
|
model.preProcess()
|
||||||
Log.d(TAG, "AG model: $model")
|
Log.d(TAG, "AG model: $model")
|
||||||
task.models.add(model)
|
task.models.add(model)
|
||||||
|
|
||||||
|
|
|
@ -46,6 +46,7 @@ import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION
|
import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION
|
||||||
import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION
|
import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||||
|
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
|
||||||
import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION
|
import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.data.TaskType
|
import com.google.aiedge.gallery.data.TaskType
|
||||||
|
@ -58,6 +59,8 @@ import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationDestination
|
||||||
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationScreen
|
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationScreen
|
||||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatDestination
|
import com.google.aiedge.gallery.ui.llmchat.LlmChatDestination
|
||||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatScreen
|
import com.google.aiedge.gallery.ui.llmchat.LlmChatScreen
|
||||||
|
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnDestination
|
||||||
|
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnScreen
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManager
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManager
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
import com.google.aiedge.gallery.ui.textclassification.TextClassificationDestination
|
import com.google.aiedge.gallery.ui.textclassification.TextClassificationDestination
|
||||||
|
@ -131,7 +134,7 @@ fun GalleryNavHost(
|
||||||
task = curPickedTask,
|
task = curPickedTask,
|
||||||
onModelClicked = { model ->
|
onModelClicked = { model ->
|
||||||
navigateToTaskScreen(
|
navigateToTaskScreen(
|
||||||
navController = navController, taskType = model.taskType!!, model = model
|
navController = navController, taskType = curPickedTask.type, model = model
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
navigateUp = { showModelManager = false })
|
navigateUp = { showModelManager = false })
|
||||||
|
@ -220,6 +223,24 @@ fun GalleryNavHost(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LLMm single turn.
|
||||||
|
composable(
|
||||||
|
route = "${LlmSingleTurnDestination.route}/{modelName}",
|
||||||
|
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
|
||||||
|
enterTransition = { slideEnter() },
|
||||||
|
exitTransition = { slideExit() },
|
||||||
|
) {
|
||||||
|
getModelFromNavigationParam(it, TASK_LLM_SINGLE_TURN)?.let { defaultModel ->
|
||||||
|
modelManagerViewModel.selectModel(defaultModel)
|
||||||
|
|
||||||
|
LlmSingleTurnScreen(
|
||||||
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
|
navigateUp = { navController.navigateUp() },
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle incoming intents for deep links
|
// Handle incoming intents for deep links
|
||||||
|
@ -231,9 +252,10 @@ fun GalleryNavHost(
|
||||||
if (data.toString().startsWith("com.google.aiedge.gallery://model/")) {
|
if (data.toString().startsWith("com.google.aiedge.gallery://model/")) {
|
||||||
val modelName = data.pathSegments.last()
|
val modelName = data.pathSegments.last()
|
||||||
getModelByName(modelName)?.let { model ->
|
getModelByName(modelName)?.let { model ->
|
||||||
|
// TODO(jingjin): need to show a list of possible tasks for this model.
|
||||||
navigateToTaskScreen(
|
navigateToTaskScreen(
|
||||||
navController = navController,
|
navController = navController,
|
||||||
taskType = model.taskType!!,
|
taskType = TaskType.LLM_CHAT,
|
||||||
model = model
|
model = model
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -249,6 +271,7 @@ fun navigateToTaskScreen(
|
||||||
TaskType.TEXT_CLASSIFICATION -> navController.navigate("${TextClassificationDestination.route}/${modelName}")
|
TaskType.TEXT_CLASSIFICATION -> navController.navigate("${TextClassificationDestination.route}/${modelName}")
|
||||||
TaskType.IMAGE_CLASSIFICATION -> navController.navigate("${ImageClassificationDestination.route}/${modelName}")
|
TaskType.IMAGE_CLASSIFICATION -> navController.navigate("${ImageClassificationDestination.route}/${modelName}")
|
||||||
TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
|
TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
|
||||||
|
TaskType.LLM_SINGLE_TURN -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
|
||||||
TaskType.IMAGE_GENERATION -> navController.navigate("${ImageGenerationDestination.route}/${modelName}")
|
TaskType.IMAGE_GENERATION -> navController.navigate("${ImageGenerationDestination.route}/${modelName}")
|
||||||
TaskType.TEST_TASK_1 -> {}
|
TaskType.TEST_TASK_1 -> {}
|
||||||
TaskType.TEST_TASK_2 -> {}
|
TaskType.TEST_TASK_2 -> {}
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.google.aiedge.gallery.ui.preview
|
||||||
|
|
||||||
|
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
|
||||||
|
|
||||||
|
class PreviewLlmSingleTurnViewModel : LlmSingleTurnViewModel(task = TASK_TEST1)
|
|
@ -34,7 +34,7 @@ class PreviewModelManagerViewModel(context: Context) :
|
||||||
for ((index, task) in ALL_PREVIEW_TASKS.withIndex()) {
|
for ((index, task) in ALL_PREVIEW_TASKS.withIndex()) {
|
||||||
task.index = index
|
task.index = index
|
||||||
for (model in task.models) {
|
for (model in task.models) {
|
||||||
model.preProcess(task = task)
|
model.preProcess()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,8 @@ val primaryContainerLight = Color(0xFFD0E4FF)
|
||||||
val onPrimaryContainerLight = Color(0xFF144A74)
|
val onPrimaryContainerLight = Color(0xFF144A74)
|
||||||
val secondaryLight = Color(0xFF526070)
|
val secondaryLight = Color(0xFF526070)
|
||||||
val onSecondaryLight = Color(0xFFFFFFFF)
|
val onSecondaryLight = Color(0xFFFFFFFF)
|
||||||
val secondaryContainerLight = Color(0xFFD6E4F7)
|
//val secondaryContainerLight = Color(0xFFD6E4F7)
|
||||||
|
val secondaryContainerLight = Color(0xFFC2E7FF)
|
||||||
val onSecondaryContainerLight = Color(0xFF3B4857)
|
val onSecondaryContainerLight = Color(0xFF3B4857)
|
||||||
val tertiaryLight = Color(0xFF775A0B)
|
val tertiaryLight = Color(0xFF775A0B)
|
||||||
val onTertiaryLight = Color(0xFFFFFFFF)
|
val onTertiaryLight = Color(0xFFFFFFFF)
|
||||||
|
|
|
@ -116,6 +116,7 @@ data class CustomColors(
|
||||||
val userBubbleBgColor: Color = Color.Transparent,
|
val userBubbleBgColor: Color = Color.Transparent,
|
||||||
val agentBubbleBgColor: Color = Color.Transparent,
|
val agentBubbleBgColor: Color = Color.Transparent,
|
||||||
val linkColor: Color = Color.Transparent,
|
val linkColor: Color = Color.Transparent,
|
||||||
|
val successColor: Color = Color.Transparent,
|
||||||
)
|
)
|
||||||
|
|
||||||
val LocalCustomColors = staticCompositionLocalOf { CustomColors() }
|
val LocalCustomColors = staticCompositionLocalOf { CustomColors() }
|
||||||
|
@ -145,6 +146,7 @@ val lightCustomColors = CustomColors(
|
||||||
agentBubbleBgColor = Color(0xFFe9eef6),
|
agentBubbleBgColor = Color(0xFFe9eef6),
|
||||||
userBubbleBgColor = Color(0xFF32628D),
|
userBubbleBgColor = Color(0xFF32628D),
|
||||||
linkColor = Color(0xFF32628D),
|
linkColor = Color(0xFF32628D),
|
||||||
|
successColor = Color(0xff3d860b),
|
||||||
)
|
)
|
||||||
|
|
||||||
val darkCustomColors = CustomColors(
|
val darkCustomColors = CustomColors(
|
||||||
|
@ -172,6 +174,7 @@ val darkCustomColors = CustomColors(
|
||||||
agentBubbleBgColor = Color(0xFF1b1c1d),
|
agentBubbleBgColor = Color(0xFF1b1c1d),
|
||||||
userBubbleBgColor = Color(0xFF1f3760),
|
userBubbleBgColor = Color(0xFF1f3760),
|
||||||
linkColor = Color(0xFF9DCAFC),
|
linkColor = Color(0xFF9DCAFC),
|
||||||
|
successColor = Color(0xFFA1CE83),
|
||||||
)
|
)
|
||||||
|
|
||||||
val MaterialTheme.customColors: CustomColors
|
val MaterialTheme.customColors: CustomColors
|
||||||
|
|
|
@ -92,6 +92,7 @@ class DownloadWorker(context: Context, params: WorkerParameters) :
|
||||||
|
|
||||||
val connection = url.openConnection() as HttpURLConnection
|
val connection = url.openConnection() as HttpURLConnection
|
||||||
if (accessToken != null) {
|
if (accessToken != null) {
|
||||||
|
Log.d(TAG, "Using access token: ${accessToken.subSequence(0, 10)}...")
|
||||||
connection.setRequestProperty("Authorization", "Bearer $accessToken")
|
connection.setRequestProperty("Authorization", "Bearer $accessToken")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -176,6 +177,7 @@ class DownloadWorker(context: Context, params: WorkerParameters) :
|
||||||
KEY_MODEL_DOWNLOAD_REMAINING_MS, remainingMs.toLong()
|
KEY_MODEL_DOWNLOAD_REMAINING_MS, remainingMs.toLong()
|
||||||
).build()
|
).build()
|
||||||
)
|
)
|
||||||
|
Log.d(TAG, "downloadedBytes: $downloadedBytes")
|
||||||
lastSetProgressTs = curTs
|
lastSetProgressTs = curTs
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue