Add support for image-to-text models

This commit is contained in:
Jing Jin 2025-05-16 00:17:24 -07:00
parent ef290cd7b0
commit bedc488a15
29 changed files with 763 additions and 231 deletions

View file

@ -19,7 +19,6 @@ package com.google.aiedge.gallery.data
import android.content.Context
import com.google.aiedge.gallery.ui.common.chat.PromptTemplate
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
import java.io.File
data class ModelDataFile(
@ -91,6 +90,9 @@ data class Model(
/** The prompt templates for the model (only for LLM). */
val llmPromptTemplates: List<PromptTemplate> = listOf(),
/** Whether the LLM model supports image input. */
val llmSupportImage: Boolean = false,
/** Whether the model is imported or not. */
val imported: Boolean = false,
@ -204,6 +206,7 @@ enum class ConfigKey(val label: String) {
DEFAULT_TOPK("Default TopK"),
DEFAULT_TOPP("Default TopP"),
DEFAULT_TEMPERATURE("Default temperature"),
SUPPORT_IMAGE("Support image"),
MAX_RESULT_COUNT("Max result count"),
USE_GPU("Use GPU"),
ACCELERATOR("Accelerator"),
@ -250,83 +253,9 @@ const val IMAGE_CLASSIFICATION_INFO = ""
const val IMAGE_CLASSIFICATION_LEARN_MORE_URL = "https://ai.google.dev/edge/litert/android"
const val LLM_CHAT_INFO =
"Some description about this large language model. A community org for developers to discover models that are ready for deployment to edge platforms"
const val IMAGE_GENERATION_INFO =
"Powered by [MediaPipe Image Generation API](https://ai.google.dev/edge/mediapipe/solutions/vision/image_generator/android)"
////////////////////////////////////////////////////////////////////////////////////////////////////
// Model spec.
// Gemma 2B instruction-tuned model, int4-quantized, configured for GPU execution only.
// NOTE(review): currently commented out of MODELS_LLM below — kept for reference.
val MODEL_LLM_GEMMA_2B_GPU_INT4: Model = Model(
name = "Gemma 2B (GPU int4)",
downloadFileName = "gemma-2b-it-gpu-int4.bin",
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma-2b-it-gpu-int4.bin",
// Expected size of the downloaded .bin file in bytes (~1.35 GB).
sizeInBytes = 1354301440L,
configs = createLlmChatConfigs(
accelerators = listOf(Accelerator.GPU)
),
showBenchmarkButton = false,
info = LLM_CHAT_INFO,
learnMoreUrl = "https://huggingface.co/litert-community",
)
// Gemma 2 2B instruction-tuned model, int8-quantized, configured for GPU execution only.
// NOTE(review): currently commented out of MODELS_LLM below — kept for reference.
val MODEL_LLM_GEMMA_2_2B_GPU_INT8: Model = Model(
name = "Gemma 2 2B (GPU int8)",
downloadFileName = "gemma2-2b-it-gpu-int8.bin",
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma2-2b-it-gpu-int8.bin",
// Expected size of the downloaded .bin file in bytes (~2.63 GB).
sizeInBytes = 2627141632L,
configs = createLlmChatConfigs(
accelerators = listOf(Accelerator.GPU)
),
showBenchmarkButton = false,
info = LLM_CHAT_INFO,
learnMoreUrl = "https://huggingface.co/litert-community",
)
// Gemma 3 1B instruction-tuned model, int4-quantized, downloaded from Hugging Face.
// Supports both CPU and GPU accelerators and ships with sample prompt templates.
val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model(
name = "Gemma 3 1B (int4)",
downloadFileName = "gemma3-1b-it-int4.task",
url = "https://huggingface.co/litert-community/Gemma3-1B-IT/resolve/main/gemma3-1b-it-int4.task?download=true",
// Expected size of the downloaded .task file in bytes (~555 MB).
sizeInBytes = 554661243L,
configs = createLlmChatConfigs(
// Sampling defaults tuned for this model.
defaultTopK = 64,
defaultTopP = 0.95f,
accelerators = listOf(Accelerator.CPU, Accelerator.GPU)
),
showBenchmarkButton = false,
info = LLM_CHAT_INFO,
learnMoreUrl = "https://huggingface.co/litert-community/Gemma3-1B-IT",
// Starter prompts shown to the user before they type their own.
llmPromptTemplates = listOf(
PromptTemplate(
title = "Emoji Fun",
description = "Generate emojis by emotions",
prompt = "Show me emojis grouped by emotions"
),
PromptTemplate(
title = "Trip Planner",
description = "Plan a trip to a destination",
prompt = "Plan a two-day trip to San Francisco"
),
)
)
// DeepSeek R1 Distill (Qwen 1.5B) model, q8-quantized, downloaded from Hugging Face.
// CPU-only configuration.
val MODEL_LLM_DEEPSEEK: Model = Model(
name = "Deepseek",
downloadFileName = "deepseek.task",
url = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B/resolve/main/deepseek_q8_ekv1280.task?download=true",
// Expected size of the downloaded .task file in bytes (~1.86 GB).
sizeInBytes = 1860686856L,
configs = createLlmChatConfigs(
// Sampling defaults tuned for this model.
defaultTemperature = 0.6f,
defaultTopK = 40,
defaultTopP = 0.7f,
accelerators = listOf(Accelerator.CPU)
),
showBenchmarkButton = false,
info = LLM_CHAT_INFO,
learnMoreUrl = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B",
)
val MODEL_TEXT_CLASSIFICATION_MOBILEBERT: Model = Model(
name = "MobileBert",
@ -414,12 +343,5 @@ val MODELS_IMAGE_CLASSIFICATION: MutableList<Model> = mutableListOf(
MODEL_IMAGE_CLASSIFICATION_MOBILENET_V2,
)
// Models available for the LLM tasks. The Gemma 2B / Gemma 2 2B GPU variants are
// intentionally disabled (commented out) here.
val MODELS_LLM: MutableList<Model> = mutableListOf(
// MODEL_LLM_GEMMA_2B_GPU_INT4,
// MODEL_LLM_GEMMA_2_2B_GPU_INT8,
MODEL_LLM_GEMMA_3_1B_INT4,
MODEL_LLM_DEEPSEEK,
)
// Models available for the image-generation task.
val MODELS_IMAGE_GENERATION: MutableList<Model> =
mutableListOf(MODEL_IMAGE_GENERATION_STABLE_DIFFUSION)

View file

@ -35,6 +35,7 @@ data class AllowedModel(
val defaultConfig: Map<String, ConfigValue>,
val taskTypes: List<String>,
val disabled: Boolean? = null,
val llmSupportImage: Boolean? = null,
) {
fun toModel(): Model {
// Construct HF download url.
@ -48,7 +49,7 @@ data class AllowedModel(
var defaultTopK: Int = DEFAULT_TOPK
var defaultTopP: Float = DEFAULT_TOPP
var defaultTemperature: Float = DEFAULT_TEMPERATURE
var defaultMaxToken: Int = 1024
var defaultMaxToken = 1024
var accelerators: List<Accelerator> = DEFAULT_ACCELERATORS
if (defaultConfig.containsKey("topK")) {
defaultTopK = getIntConfigValue(defaultConfig["topK"], defaultTopK)
@ -84,9 +85,10 @@ data class AllowedModel(
// Misc.
var showBenchmarkButton = true
val showRunAgainButton = true
if (taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_USECASES.type.id)) {
var showRunAgainButton = true
if (isLlmModel) {
showBenchmarkButton = false
showRunAgainButton = false
}
return Model(
@ -99,7 +101,8 @@ data class AllowedModel(
downloadFileName = modelFile,
showBenchmarkButton = showBenchmarkButton,
showRunAgainButton = showRunAgainButton,
learnMoreUrl = "https://huggingface.co/${modelId}"
learnMoreUrl = "https://huggingface.co/${modelId}",
llmSupportImage = llmSupportImage == true,
)
}

View file

@ -19,6 +19,7 @@ package com.google.aiedge.gallery.data
import androidx.annotation.StringRes
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.outlined.Forum
import androidx.compose.material.icons.outlined.Mms
import androidx.compose.material.icons.outlined.Widgets
import androidx.compose.material.icons.rounded.ImageSearch
import androidx.compose.runtime.MutableState
@ -33,6 +34,7 @@ enum class TaskType(val label: String, val id: String) {
IMAGE_GENERATION(label = "Image Generation", id = "image_generation"),
LLM_CHAT(label = "LLM Chat", id = "llm_chat"),
LLM_USECASES(label = "LLM Use Cases", id = "llm_usecases"),
LLM_IMAGE_TO_TEXT(label = "LLM Image to Text", id = "llm_image_to_text"),
TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
TEST_TASK_2(label = "Test task 2", id = "test_task_2")
@ -93,7 +95,7 @@ val TASK_LLM_CHAT = Task(
icon = Icons.Outlined.Forum,
// models = MODELS_LLM,
models = mutableListOf(),
description = "Chat with a on-device large language model",
description = "Chat with on-device large language models",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
@ -110,6 +112,17 @@ val TASK_LLM_USECASES = Task(
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
)
// Task for asking questions about images with on-device LLMs. Reuses the LLM-chat
// documentation/source links and text-input placeholder string.
val TASK_LLM_IMAGE_TO_TEXT = Task(
type = TaskType.LLM_IMAGE_TO_TEXT,
icon = Icons.Outlined.Mms,
// models = MODELS_LLM,
// NOTE(review): starts empty; presumably populated at runtime — confirm.
models = mutableListOf(),
description = "Ask questions about images with on-device large language models",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
)
val TASK_IMAGE_GENERATION = Task(
type = TaskType.IMAGE_GENERATION,
iconVectorResourceId = R.drawable.image_spark,
@ -127,6 +140,7 @@ val TASKS: List<Task> = listOf(
// TASK_IMAGE_GENERATION,
TASK_LLM_USECASES,
TASK_LLM_CHAT,
TASK_LLM_IMAGE_TO_TEXT
)
fun getModelByName(name: String): Model? {

View file

@ -25,6 +25,7 @@ import com.google.aiedge.gallery.GalleryApplication
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationViewModel
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationViewModel
import com.google.aiedge.gallery.ui.llmchat.LlmChatViewModel
import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextViewModel
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.textclassification.TextClassificationViewModel
@ -62,6 +63,12 @@ object ViewModelProvider {
LlmSingleTurnViewModel()
}
// Initializer for LlmImageToTextViewModel.
initializer {
LlmImageToTextViewModel()
}
// Initializer for ImageGenerationViewModel.
initializer {
ImageGenerationViewModel()
}

View file

@ -17,13 +17,17 @@
package com.google.aiedge.gallery.ui.common
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.size
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.ArrowBack
import androidx.compose.material.icons.rounded.Settings
import androidx.compose.material.icons.rounded.MapsUgc
import androidx.compose.material.icons.rounded.Tune
import androidx.compose.material3.CenterAlignedTopAppBar
import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
@ -38,6 +42,7 @@ import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.scale
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.vectorResource
@ -47,6 +52,7 @@ import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.chat.ConfigDialog
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
@OptIn(ExperimentalMaterial3Api::class)
@ -58,12 +64,17 @@ fun ModelPageAppBar(
onBackClicked: () -> Unit,
onModelSelected: (Model) -> Unit,
modifier: Modifier = Modifier,
isResettingSession: Boolean = false,
onResetSessionClicked: (Model) -> Unit = {},
showResetSessionButton: Boolean = false,
onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit = { _, _ -> },
) {
var showConfigDialog by remember { mutableStateOf(false) }
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val context = LocalContext.current
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[model.name]
val modelInitializationStatus =
modelManagerUiState.modelInitializationStatus[model.name]
CenterAlignedTopAppBar(title = {
Column(
@ -110,19 +121,58 @@ fun ModelPageAppBar(
actions = {
val showConfigButton =
model.configs.isNotEmpty() && curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED
Box(modifier = Modifier.size(42.dp), contentAlignment = Alignment.Center) {
var configButtonOffset = 0.dp
if (showConfigButton && showResetSessionButton) {
configButtonOffset = (-40).dp
}
val isModelInitializing =
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING
if (showConfigButton) {
IconButton(
onClick = {
showConfigDialog = true
},
enabled = showConfigButton,
modifier = Modifier.alpha(if (showConfigButton) 1f else 0f)
enabled = !isModelInitializing,
modifier = Modifier
.scale(0.75f)
.offset(x = configButtonOffset)
.alpha(if (isModelInitializing) 0.5f else 1f)
) {
Icon(
imageVector = Icons.Rounded.Settings,
imageVector = Icons.Rounded.Tune,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary
)
}
}
if (showResetSessionButton) {
if (isResettingSession) {
CircularProgressIndicator(
trackColor = MaterialTheme.colorScheme.surfaceVariant,
strokeWidth = 2.dp,
modifier = Modifier.size(16.dp)
)
} else {
IconButton(
onClick = {
onResetSessionClicked(model)
},
enabled = !isModelInitializing,
modifier = Modifier
.scale(0.75f)
.alpha(if (isModelInitializing) 0.5f else 1f)
) {
Icon(
imageVector = Icons.Rounded.MapsUgc,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary
)
}
}
}
}
})
// Config dialog.

View file

@ -29,6 +29,7 @@ import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.widthIn
import androidx.compose.foundation.pager.HorizontalPager
import androidx.compose.foundation.pager.rememberPagerState
import androidx.compose.foundation.shape.CircleShape
@ -54,6 +55,9 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.platform.LocalDensity
import androidx.compose.ui.platform.LocalWindowInfo
import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.Task
@ -77,6 +81,10 @@ fun ModelPickerChipsPager(
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
val scope = rememberCoroutineScope()
val density = LocalDensity.current
val windowInfo = LocalWindowInfo.current
val screenWidthDp =
remember { with(density) { windowInfo.containerSize.width.toDp() } }
val pagerState = rememberPagerState(initialPage = task.models.indexOf(initialModel),
pageCount = { task.models.size })
@ -140,7 +148,12 @@ fun ModelPickerChipsPager(
Text(
model.name,
style = MaterialTheme.typography.labelSmall,
modifier = Modifier.padding(start = 4.dp),
modifier = Modifier
.padding(start = 4.dp)
.widthIn(0.dp, screenWidthDp - 250.dp),
maxLines = 1,
overflow = TextOverflow.MiddleEllipsis
)
Icon(
Icons.Rounded.ArrowDropDown,

View file

@ -109,7 +109,7 @@ fun ChatPanel(
task: Task,
selectedModel: Model,
viewModel: ChatViewModel,
onSendMessage: (Model, ChatMessage) -> Unit,
onSendMessage: (Model, List<ChatMessage>) -> Unit,
onRunAgainClicked: (Model, ChatMessage) -> Unit,
onBenchmarkClicked: (Model, ChatMessage, warmUpIterations: Int, benchmarkIterations: Int) -> Unit,
navigateUp: () -> Unit,
@ -280,7 +280,8 @@ fun ChatPanel(
task = task,
onPromptClicked = { template ->
onSendMessage(
selectedModel, ChatMessageText(content = template.prompt, side = ChatSide.USER)
selectedModel,
listOf(ChatMessageText(content = template.prompt, side = ChatSide.USER))
)
})
@ -430,12 +431,15 @@ fun ChatPanel(
// Chat input
when (chatInputType) {
ChatInputType.TEXT -> {
val isLlmTask = task.type == TaskType.LLM_CHAT
val notLlmStartScreen = !(messages.size == 1 && messages[0] is ChatMessagePromptTemplates)
// val isLlmTask = task.type == TaskType.LLM_CHAT
// val notLlmStartScreen = !(messages.size == 1 && messages[0] is ChatMessagePromptTemplates)
val hasImageMessage = messages.any { it is ChatMessageImage }
MessageInputText(
modelManagerViewModel = modelManagerViewModel,
curMessage = curMessage,
inProgress = uiState.inProgress,
isResettingSession = uiState.isResettingSession,
hasImageMessage = hasImageMessage,
modelInitializing = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
textFieldPlaceHolderRes = task.textInputPlaceHolderRes,
onValueChanged = { curMessage = it },
@ -445,13 +449,17 @@ fun ChatPanel(
},
onOpenPromptTemplatesClicked = {
onSendMessage(
selectedModel, ChatMessagePromptTemplates(
selectedModel, listOf(
ChatMessagePromptTemplates(
templates = selectedModel.llmPromptTemplates, showMakeYourOwn = false
)
)
)
},
onStopButtonClicked = onStopButtonClicked,
showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen,
// showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen,
showPromptTemplatesInMenu = false,
showImagePickerInMenu = selectedModel.llmSupportImage == true,
showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress,
)
}
@ -461,10 +469,12 @@ fun ChatPanel(
streamingMessage = streamingMessage,
onImageSelected = { bitmap ->
onSendMessage(
selectedModel, ChatMessageImage(
selectedModel, listOf(
ChatMessageImage(
bitmap = bitmap, imageBitMap = bitmap.asImageBitmap(), side = ChatSide.USER
)
)
)
},
onStreamImage = { bitmap ->
onStreamImageMessage(

View file

@ -70,16 +70,18 @@ fun ChatView(
task: Task,
viewModel: ChatViewModel,
modelManagerViewModel: ModelManagerViewModel,
onSendMessage: (Model, ChatMessage) -> Unit,
onSendMessage: (Model, List<ChatMessage>) -> Unit,
onRunAgainClicked: (Model, ChatMessage) -> Unit,
onBenchmarkClicked: (Model, ChatMessage, Int, Int) -> Unit,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
onResetSessionClicked: (Model) -> Unit = {},
onStreamImageMessage: (Model, ChatMessageImage) -> Unit = { _, _ -> },
onStopButtonClicked: (Model) -> Unit = {},
chatInputType: ChatInputType = ChatInputType.TEXT,
showStopButtonInInputWhenInProgress: Boolean = false,
) {
val uiStat by viewModel.uiState.collectAsState()
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val selectedModel = modelManagerUiState.selectedModel
@ -155,6 +157,9 @@ fun ChatView(
task = task,
model = selectedModel,
modelManagerViewModel = modelManagerViewModel,
showResetSessionButton = true,
isResettingSession = uiStat.isResettingSession,
onResetSessionClicked = onResetSessionClicked,
onConfigChanged = { old, new ->
viewModel.addConfigChangedMessage(
oldConfigValues = old,

View file

@ -33,6 +33,11 @@ data class ChatUiState(
*/
val inProgress: Boolean = false,
/**
* Indicates whether the session is being reset.
*/
val isResettingSession: Boolean = false,
/**
* A map of model names to lists of chat messages.
*/
@ -106,6 +111,12 @@ open class ChatViewModel(val task: Task) : ViewModel() {
_uiState.update { _uiState.value.copy(messagesByModel = newMessagesByModel) }
}
/** Removes every chat message recorded for [model] from the UI state. */
fun clearAllMessages(model: Model) {
  val cleared = _uiState.value.messagesByModel.toMutableMap().apply {
    this[model.name] = mutableListOf()
  }
  _uiState.update { _uiState.value.copy(messagesByModel = cleared) }
}
/** Returns the most recent chat message for [model], or null when there are none. */
fun getLastMessage(model: Model): ChatMessage? =
  _uiState.value.messagesByModel[model.name]?.lastOrNull()
@ -189,6 +200,10 @@ open class ChatViewModel(val task: Task) : ViewModel() {
_uiState.update { _uiState.value.copy(inProgress = inProgress) }
}
/** Records in the UI state whether a session reset is currently in progress. */
fun setIsResettingSession(isResettingSession: Boolean) {
  _uiState.update { it.copy(isResettingSession = isResettingSession) }
}
fun addConfigChangedMessage(
oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>, model: Model
) {

View file

@ -21,14 +21,18 @@ import androidx.compose.material3.ProvideTextStyle
import androidx.compose.runtime.Composable
import androidx.compose.runtime.CompositionLocalProvider
import androidx.compose.ui.Modifier
import androidx.compose.ui.text.SpanStyle
import androidx.compose.ui.text.TextLinkStyles
import androidx.compose.ui.text.TextStyle
import androidx.compose.ui.text.font.FontFamily
import androidx.compose.ui.tooling.preview.Preview
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.customColors
import com.halilibo.richtext.commonmark.Markdown
import com.halilibo.richtext.ui.CodeBlockStyle
import com.halilibo.richtext.ui.RichTextStyle
import com.halilibo.richtext.ui.material3.RichText
import com.halilibo.richtext.ui.string.RichTextStringStyle
/**
* Composable function to display Markdown-formatted text.
@ -56,6 +60,11 @@ fun MarkdownText(
fontSize = MaterialTheme.typography.bodySmall.fontSize,
fontFamily = FontFamily.Monospace,
)
),
stringStyle = RichTextStringStyle(
linkStyle = TextLinkStyles(
style = SpanStyle(color = MaterialTheme.customColors.linkColor)
)
)
),
) {

View file

@ -46,7 +46,11 @@ fun MessageBodyWarning(message: ChatMessageWarning) {
.clip(RoundedCornerShape(16.dp))
.background(MaterialTheme.colorScheme.tertiaryContainer)
) {
MarkdownText(text = message.content, modifier = Modifier.padding(12.dp), smallFontSize = true)
MarkdownText(
text = message.content,
modifier = Modifier.padding(horizontal = 16.dp, vertical = 6.dp),
smallFontSize = true
)
}
}
}

View file

@ -16,28 +16,45 @@
package com.google.aiedge.gallery.ui.common.chat
import android.Manifest
import android.content.Context
import android.content.pm.PackageManager
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.graphics.Matrix
import android.net.Uri
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.PickVisualMediaRequest
import androidx.activity.result.contract.ActivityResultContracts
import androidx.annotation.StringRes
import androidx.compose.foundation.Image
import androidx.compose.foundation.background
import androidx.compose.foundation.border
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.width
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.Send
import androidx.compose.material.icons.rounded.Add
import androidx.compose.material.icons.rounded.Close
import androidx.compose.material.icons.rounded.History
import androidx.compose.material.icons.rounded.Photo
import androidx.compose.material.icons.rounded.PhotoCamera
import androidx.compose.material.icons.rounded.PostAdd
import androidx.compose.material.icons.rounded.Stop
import androidx.compose.material3.DropdownMenu
import androidx.compose.material3.DropdownMenuItem
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.IconButtonDefaults
@ -54,12 +71,17 @@ import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip
import androidx.compose.ui.draw.shadow
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.asImageBitmap
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import androidx.core.content.ContextCompat
import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.ui.common.createTempPictureUri
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.theme.GalleryTheme
@ -74,24 +96,107 @@ import com.google.aiedge.gallery.ui.theme.GalleryTheme
fun MessageInputText(
modelManagerViewModel: ModelManagerViewModel,
curMessage: String,
isResettingSession: Boolean,
inProgress: Boolean,
hasImageMessage: Boolean,
modelInitializing: Boolean,
@StringRes textFieldPlaceHolderRes: Int,
onValueChanged: (String) -> Unit,
onSendMessage: (ChatMessage) -> Unit,
onSendMessage: (List<ChatMessage>) -> Unit,
onOpenPromptTemplatesClicked: () -> Unit = {},
onStopButtonClicked: () -> Unit = {},
showPromptTemplatesInMenu: Boolean = true,
showPromptTemplatesInMenu: Boolean = false,
showImagePickerInMenu: Boolean = false,
showStopButtonWhenInProgress: Boolean = false,
) {
val context = LocalContext.current
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
var showAddContentMenu by remember { mutableStateOf(false) }
var showTextInputHistorySheet by remember { mutableStateOf(false) }
var tempPhotoUri by remember { mutableStateOf(value = Uri.EMPTY) }
var pickedImages by remember { mutableStateOf<List<Bitmap>>(listOf()) }
val updatePickedImages: (Bitmap) -> Unit = { bitmap ->
val newPickedImages: MutableList<Bitmap> = mutableListOf()
newPickedImages.addAll(pickedImages)
newPickedImages.add(bitmap)
pickedImages = newPickedImages.toList()
}
// launches camera
val cameraLauncher =
rememberLauncherForActivityResult(ActivityResultContracts.TakePicture()) { isImageSaved ->
if (isImageSaved) {
handleImageSelected(
context = context,
uri = tempPhotoUri,
onImageSelected = { bitmap ->
updatePickedImages(bitmap)
},
rotateForPortrait = true,
)
}
}
// Permission request when taking picture.
val takePicturePermissionLauncher = rememberLauncherForActivityResult(
ActivityResultContracts.RequestPermission()
) { permissionGranted ->
if (permissionGranted) {
showAddContentMenu = false
tempPhotoUri = context.createTempPictureUri()
cameraLauncher.launch(tempPhotoUri)
}
}
// Registers a photo picker activity launcher in single-select mode.
val pickMedia =
rememberLauncherForActivityResult(ActivityResultContracts.PickVisualMedia()) { uri ->
// Callback is invoked after the user selects a media item or closes the
// photo picker.
if (uri != null) {
handleImageSelected(context = context, uri = uri, onImageSelected = { bitmap ->
updatePickedImages(bitmap)
})
}
}
Box(contentAlignment = Alignment.CenterStart) {
// A preview panel for the selected image.
if (pickedImages.isNotEmpty()) {
Box(
contentAlignment = Alignment.TopEnd, modifier = Modifier.offset(x = 16.dp, y = (-80).dp)
) {
Image(
bitmap = pickedImages.last().asImageBitmap(),
contentDescription = "",
modifier = Modifier
.height(80.dp)
.shadow(2.dp, shape = RoundedCornerShape(8.dp))
.clip(RoundedCornerShape(8.dp))
.border(1.dp, MaterialTheme.colorScheme.outlineVariant, RoundedCornerShape(8.dp)),
)
Box(modifier = Modifier
.offset(x = 10.dp, y = (-10).dp)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.surface)
.border((1.5).dp, MaterialTheme.colorScheme.outline, CircleShape)
.clickable {
pickedImages = listOf()
}) {
Icon(
Icons.Rounded.Close,
contentDescription = "",
modifier = Modifier
.padding(3.dp)
.size(16.dp)
)
}
}
}
// A plus button to show a popup menu to add stuff to the chat.
IconButton(
enabled = !inProgress,
enabled = !inProgress && !isResettingSession,
onClick = { showAddContentMenu = true },
modifier = Modifier
.offset(x = 16.dp)
@ -113,6 +218,58 @@ fun MessageInputText(
DropdownMenu(
expanded = showAddContentMenu,
onDismissRequest = { showAddContentMenu = false }) {
if (showImagePickerInMenu) {
// Take a photo.
DropdownMenuItem(
text = {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp)
) {
Icon(Icons.Rounded.PhotoCamera, contentDescription = "")
Text("Take a photo")
}
},
enabled = pickedImages.isEmpty() && !hasImageMessage,
onClick = {
// Check permission
when (PackageManager.PERMISSION_GRANTED) {
// Already got permission. Call the lambda.
ContextCompat.checkSelfPermission(
context, Manifest.permission.CAMERA
) -> {
showAddContentMenu = false
tempPhotoUri = context.createTempPictureUri()
cameraLauncher.launch(tempPhotoUri)
}
// Otherwise, ask for permission
else -> {
takePicturePermissionLauncher.launch(Manifest.permission.CAMERA)
}
}
})
// Pick an image from album.
DropdownMenuItem(
text = {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp)
) {
Icon(Icons.Rounded.Photo, contentDescription = "")
Text("Pick from album")
}
},
enabled = pickedImages.isEmpty() && !hasImageMessage,
onClick = {
// Launch the photo picker and let the user choose only images.
pickMedia.launch(PickVisualMediaRequest(ActivityResultContracts.PickVisualMedia.ImageOnly))
showAddContentMenu = false
})
}
// Prompt templates.
if (showPromptTemplatesInMenu) {
DropdownMenuItem(text = {
Row(
@ -127,6 +284,7 @@ fun MessageInputText(
showAddContentMenu = false
})
}
// Prompt history.
DropdownMenuItem(text = {
Row(
verticalAlignment = Alignment.CenterVertically,
@ -171,18 +329,19 @@ fun MessageInputText(
),
) {
Icon(
Icons.Rounded.Stop,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary
Icons.Rounded.Stop, contentDescription = "", tint = MaterialTheme.colorScheme.primary
)
}
}
} // Send button. Only shown when text is not empty.
else if (curMessage.isNotEmpty()) {
IconButton(
enabled = !inProgress,
enabled = !inProgress && !isResettingSession,
onClick = {
onSendMessage(ChatMessageText(content = curMessage.trim(), side = ChatSide.USER))
onSendMessage(
createMessagesToSend(pickedImages = pickedImages, text = curMessage.trim())
)
pickedImages = listOf()
},
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.secondaryContainer,
@ -200,7 +359,6 @@ fun MessageInputText(
}
}
// A bottom sheet to show the text input history to pick from.
if (showTextInputHistorySheet) {
TextInputHistorySheet(
@ -209,7 +367,8 @@ fun MessageInputText(
showTextInputHistorySheet = false
},
onHistoryItemClicked = { item ->
onSendMessage(ChatMessageText(content = item, side = ChatSide.USER))
onSendMessage(createMessagesToSend(pickedImages = pickedImages, text = item))
pickedImages = listOf()
modelManagerViewModel.promoteTextInputHistoryItem(item)
},
onHistoryItemDeleted = { item ->
@ -217,9 +376,55 @@ fun MessageInputText(
},
onHistoryItemsDeleteAll = {
modelManagerViewModel.clearTextInputHistory()
})
}
}
/**
 * Decodes the image at [uri] and, on success, invokes [onImageSelected] with the bitmap.
 *
 * Fixes over the previous version: the input stream is now closed via [use] (it was
 * leaked before), and a null result from [BitmapFactory.decodeStream] is handled
 * explicitly instead of relying on the broad catch to absorb the resulting NPE.
 * Decode failures remain best-effort: they are logged and the callback is skipped.
 *
 * @param rotateForPortrait For some reason, some Android phones store the picture taken
 *   by the camera rotated horizontally. Set this flag to rotate the image back to
 *   portrait when the picture's width is bigger than its height.
 */
private fun handleImageSelected(
  context: Context,
  uri: Uri,
  onImageSelected: (Bitmap) -> Unit,
  rotateForPortrait: Boolean = false,
) {
  val bitmap: Bitmap? = try {
    context.contentResolver.openInputStream(uri)?.use { stream ->
      // decodeStream returns null when the stream cannot be decoded as an image.
      val decoded: Bitmap? = BitmapFactory.decodeStream(stream)
      if (decoded != null && rotateForPortrait && decoded.width > decoded.height) {
        val matrix = Matrix().apply { postRotate(90f) }
        Bitmap.createBitmap(decoded, 0, 0, decoded.width, decoded.height, matrix, true)
      } else {
        decoded
      }
    }
  } catch (e: Exception) {
    e.printStackTrace()
    null
  }
  if (bitmap != null) {
    onImageSelected(bitmap)
  }
}
/**
 * Builds the list of chat messages to send: the most recently picked image (if any),
 * followed by the user's text message.
 */
private fun createMessagesToSend(pickedImages: List<Bitmap>, text: String): List<ChatMessage> =
  buildList {
    pickedImages.lastOrNull()?.let { image ->
      add(
        ChatMessageImage(
          bitmap = image, imageBitMap = image.asImageBitmap(), side = ChatSide.USER
        )
      )
    }
    add(
      ChatMessageText(
        content = text, side = ChatSide.USER
      )
    )
  }
@Preview(showBackground = true)
@ -233,6 +438,21 @@ fun MessageInputTextPreview() {
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "hello",
inProgress = false,
isResettingSession = false,
modelInitializing = false,
hasImageMessage = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},
onSendMessage = {},
showStopButtonWhenInProgress = true,
showImagePickerInMenu = true,
)
MessageInputText(
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "hello",
inProgress = false,
isResettingSession = false,
hasImageMessage = false,
modelInitializing = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},
@ -243,6 +463,8 @@ fun MessageInputTextPreview() {
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "hello",
inProgress = true,
isResettingSession = false,
hasImageMessage = false,
modelInitializing = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},
@ -252,6 +474,8 @@ fun MessageInputTextPreview() {
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "",
inProgress = false,
isResettingSession = false,
hasImageMessage = false,
modelInitializing = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},
@ -261,6 +485,8 @@ fun MessageInputTextPreview() {
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "",
inProgress = true,
isResettingSession = false,
hasImageMessage = false,
modelInitializing = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},

View file

@ -52,6 +52,7 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.R
@ -149,6 +150,8 @@ private fun SheetContent(
Text(
item,
style = MaterialTheme.typography.bodyMedium,
maxLines = 3,
overflow = TextOverflow.Ellipsis,
modifier = Modifier
.padding(vertical = 16.dp)
.padding(start = 16.dp)

View file

@ -76,7 +76,7 @@ fun ModelNameAndStatus(
Text(
model.name,
maxLines = 1,
overflow = TextOverflow.Ellipsis,
overflow = TextOverflow.MiddleEllipsis,
style = MaterialTheme.typography.titleMedium,
modifier = modifier,
)

View file

@ -136,7 +136,7 @@ fun HomeScreen(
val snackbarHostState = remember { SnackbarHostState() }
val scope = rememberCoroutineScope()
val tasks = uiState.tasks
val nonEmptyTasks = uiState.tasks.filter { it.models.size > 0 }
val loadingHfModels = uiState.loadingHfModels
val filePickerLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult(
@ -183,14 +183,14 @@ fun HomeScreen(
) { innerPadding ->
Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.fillMaxSize()) {
TaskList(
tasks = tasks,
tasks = nonEmptyTasks,
navigateToTaskScreen = navigateToTaskScreen,
loadingModelAllowlist = uiState.loadingModelAllowlist,
modifier = Modifier.fillMaxSize(),
contentPadding = innerPadding,
)
SnackbarHost(hostState = snackbarHostState, modifier = Modifier.padding(bottom = 16.dp))
SnackbarHost(hostState = snackbarHostState, modifier = Modifier.padding(bottom = 32.dp))
}
}
@ -285,7 +285,7 @@ fun HomeScreen(
// Show a snack bar for successful import.
scope.launch {
snackbarHostState.showSnackbar("Model imported successfully")
snackbarHostState.showSnackbar("Model imported successfully")
}
})
}
@ -316,7 +316,6 @@ fun HomeScreen(
}
},
)
}
}

View file

@ -55,6 +55,7 @@ import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog
import androidx.compose.ui.window.DialogProperties
import com.google.aiedge.gallery.data.Accelerator
import com.google.aiedge.gallery.data.BooleanSwitchConfig
import com.google.aiedge.gallery.data.Config
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.IMPORTS_DIR
@ -111,6 +112,10 @@ private val IMPORT_CONFIGS_LLM: List<Config> = listOf(
defaultValue = DEFAULT_TEMPERATURE,
valueType = ValueType.FLOAT
),
BooleanSwitchConfig(
key = ConfigKey.SUPPORT_IMAGE,
defaultValue = false,
),
SegmentedButtonConfig(
key = ConfigKey.COMPATIBLE_ACCELERATORS,
defaultValue = Accelerator.CPU.label,

View file

@ -50,7 +50,8 @@ fun ImageClassificationScreen(
task = viewModel.task,
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
onSendMessage = { model, message ->
onSendMessage = { model, messages ->
val message = messages[0]
viewModel.addMessage(
model = model,
message = message,

View file

@ -44,7 +44,8 @@ fun ImageGenerationScreen(
task = viewModel.task,
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
onSendMessage = { model, message ->
onSendMessage = { model, messages ->
val message = messages[0]
viewModel.addMessage(
model = model,
message = message,

View file

@ -17,11 +17,14 @@
package com.google.aiedge.gallery.ui.llmchat
import android.content.Context
import android.graphics.Bitmap
import android.util.Log
import com.google.aiedge.gallery.data.Accelerator
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage
import com.google.mediapipe.framework.image.BitmapImageBuilder
import com.google.mediapipe.tasks.genai.llminference.GraphOptions
import com.google.mediapipe.tasks.genai.llminference.LlmInference
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
@ -55,7 +58,9 @@ object LlmChatModelHelper {
}
val options =
LlmInference.LlmInferenceOptions.builder().setModelPath(model.getPath(context = context))
.setMaxTokens(maxTokens).setPreferredBackend(preferredBackend).build()
.setMaxTokens(maxTokens).setPreferredBackend(preferredBackend)
.setMaxNumImages(if (model.llmSupportImage) 1 else 0)
.build()
// Create an instance of the LLM Inference task
try {
@ -64,7 +69,10 @@ object LlmChatModelHelper {
val session = LlmInferenceSession.createFromOptions(
llmInference,
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
.setTemperature(temperature).build()
.setTemperature(temperature)
.setGraphOptions(
GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
).build()
)
model.instance = LlmModelInstance(engine = llmInference, session = session)
} catch (e: Exception) {
@ -75,6 +83,8 @@ object LlmChatModelHelper {
}
fun resetSession(model: Model) {
Log.d(TAG, "Resetting session for model '${model.name}'")
val instance = model.instance as LlmModelInstance? ?: return
val session = instance.session
session.close()
@ -87,9 +97,13 @@ object LlmChatModelHelper {
val newSession = LlmInferenceSession.createFromOptions(
inference,
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
.setTemperature(temperature).build()
.setTemperature(temperature)
.setGraphOptions(
GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
).build()
)
instance.session = newSession
Log.d(TAG, "Resetting done")
}
fun cleanUp(model: Model) {
@ -118,6 +132,7 @@ object LlmChatModelHelper {
resultListener: ResultListener,
cleanUpListener: CleanUpListener,
singleTurn: Boolean = false,
image: Bitmap? = null,
) {
if (singleTurn) {
resetSession(model = model)
@ -132,6 +147,9 @@ object LlmChatModelHelper {
// Start async inference.
val session = instance.session
session.addQueryChunk(input)
if (image != null) {
session.addImage(BitmapImageBuilder(image).build())
}
session.generateResponseAsync(resultListener)
}
}

View file

@ -16,15 +16,14 @@
package com.google.aiedge.gallery.ui.llmchat
import android.util.Log
import android.graphics.Bitmap
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.aiedge.gallery.ui.ViewModelProvider
import com.google.aiedge.gallery.ui.common.chat.ChatMessageInfo
import com.google.aiedge.gallery.ui.common.chat.ChatMessageImage
import com.google.aiedge.gallery.ui.common.chat.ChatMessageText
import com.google.aiedge.gallery.ui.common.chat.ChatMessageWarning
import com.google.aiedge.gallery.ui.common.chat.ChatView
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import kotlinx.serialization.Serializable
@ -35,6 +34,11 @@ object LlmChatDestination {
val route = "LlmChatRoute"
}
// Navigation destination for the LLM image-to-text screen.
object LlmImageToTextDestination {
  @Serializable
  val route = "LlmImageToTextRoute"
}
@Composable
fun LlmChatScreen(
modelManagerViewModel: ModelManagerViewModel,
@ -43,6 +47,38 @@ fun LlmChatScreen(
viewModel: LlmChatViewModel = viewModel(
factory = ViewModelProvider.Factory
),
) {
ChatViewWrapper(
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
navigateUp = navigateUp,
modifier = modifier,
)
}
/**
 * Screen for the LLM image-to-text task.
 *
 * Thin wrapper around [ChatViewWrapper] that supplies an [LlmImageToTextViewModel]
 * instead of the base chat view model; all chat UI and behavior are shared.
 */
@Composable
fun LlmImageToTextScreen(
  modelManagerViewModel: ModelManagerViewModel,
  navigateUp: () -> Unit,
  modifier: Modifier = Modifier,
  viewModel: LlmImageToTextViewModel = viewModel(
    factory = ViewModelProvider.Factory
  ),
) {
  ChatViewWrapper(
    viewModel = viewModel,
    modelManagerViewModel = modelManagerViewModel,
    navigateUp = navigateUp,
    modifier = modifier,
  )
}
@Composable
fun ChatViewWrapper(
viewModel: LlmChatViewModel,
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier
) {
val context = LocalContext.current
@ -50,21 +86,33 @@ fun LlmChatScreen(
task = viewModel.task,
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
onSendMessage = { model, message ->
onSendMessage = { model, messages ->
for (message in messages) {
viewModel.addMessage(
model = model,
message = message,
)
if (message is ChatMessageText) {
modelManagerViewModel.addTextInputHistory(message.content)
viewModel.generateResponse(model = model, input = message.content, onError = {
viewModel.addMessage(
model = model,
message = ChatMessageWarning(content = "Error occurred. Re-initializing the engine.")
)
}
modelManagerViewModel.initializeModel(
context = context, task = viewModel.task, model = model, force = true
var text = ""
var image: Bitmap? = null
var chatMessageText: ChatMessageText? = null
for (message in messages) {
if (message is ChatMessageText) {
chatMessageText = message
text = message.content
} else if (message is ChatMessageImage) {
image = message.bitmap
}
}
if (text.isNotEmpty() && chatMessageText != null) {
modelManagerViewModel.addTextInputHistory(text)
viewModel.generateResponse(model = model, input = text, image = image, onError = {
viewModel.handleError(
context = context,
model = model,
modelManagerViewModel = modelManagerViewModel,
triggeredMessage = chatMessageText,
)
})
}
@ -72,13 +120,11 @@ fun LlmChatScreen(
onRunAgainClicked = { model, message ->
if (message is ChatMessageText) {
viewModel.runAgain(model = model, message = message, onError = {
viewModel.addMessage(
viewModel.handleError(
context = context,
model = model,
message = ChatMessageWarning(content = "Error occurred. Re-initializing the engine.")
)
modelManagerViewModel.initializeModel(
context = context, task = viewModel.task, model = model, force = true
modelManagerViewModel = modelManagerViewModel,
triggeredMessage = message,
)
})
}
@ -90,6 +136,9 @@ fun LlmChatScreen(
)
}
},
onResetSessionClicked = { model ->
viewModel.resetSession(model = model)
},
showStopButtonInInputWhenInProgress = true,
onStopButtonClicked = { model ->
viewModel.stopResponse(model = model)
@ -98,4 +147,3 @@ fun LlmChatScreen(
modifier = modifier,
)
}

View file

@ -16,18 +16,23 @@
package com.google.aiedge.gallery.ui.llmchat
import android.content.Context
import android.graphics.Bitmap
import android.util.Log
import androidx.lifecycle.viewModelScope
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.aiedge.gallery.ui.common.chat.ChatMessageLoading
import com.google.aiedge.gallery.ui.common.chat.ChatMessageText
import com.google.aiedge.gallery.ui.common.chat.ChatMessageType
import com.google.aiedge.gallery.ui.common.chat.ChatMessageWarning
import com.google.aiedge.gallery.ui.common.chat.ChatSide
import com.google.aiedge.gallery.ui.common.chat.ChatViewModel
import com.google.aiedge.gallery.ui.common.chat.Stat
import kotlinx.coroutines.CoroutineExceptionHandler
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
@ -40,8 +45,8 @@ private val STATS = listOf(
Stat(id = "latency", label = "Latency", unit = "sec")
)
class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
fun generateResponse(model: Model, input: String, onError: () -> Unit) {
open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task = curTask) {
fun generateResponse(model: Model, input: String, image: Bitmap? = null, onError: () -> Unit) {
viewModelScope.launch(Dispatchers.Default) {
setInProgress(true)
@ -58,7 +63,10 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
// Run inference.
val instance = model.instance as LlmModelInstance
val prefillTokens = instance.session.sizeInTokens(input)
var prefillTokens = instance.session.sizeInTokens(input)
if (image != null) {
prefillTokens += 257
}
var firstRun = true
var timeToFirstToken = 0f
@ -69,9 +77,9 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
val start = System.currentTimeMillis()
try {
LlmChatModelHelper.runInference(
model = model,
LlmChatModelHelper.runInference(model = model,
input = input,
image = image,
resultListener = { partialResult, done ->
val curTs = System.currentTimeMillis()
@ -92,32 +100,27 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
// Add an empty message that will receive streaming results.
addMessage(
model = model,
message = ChatMessageText(content = "", side = ChatSide.AGENT)
model = model, message = ChatMessageText(content = "", side = ChatSide.AGENT)
)
}
// Incrementally update the streamed partial results.
val latencyMs: Long = if (done) System.currentTimeMillis() - start else -1
updateLastTextMessageContentIncrementally(
model = model,
partialContent = partialResult,
latencyMs = latencyMs.toFloat()
model = model, partialContent = partialResult, latencyMs = latencyMs.toFloat()
)
if (done) {
setInProgress(false)
decodeSpeed =
decodeTokens / ((curTs - firstTokenTs) / 1000f)
decodeSpeed = decodeTokens / ((curTs - firstTokenTs) / 1000f)
if (decodeSpeed.isNaN()) {
decodeSpeed = 0f
}
if (lastMessage is ChatMessageText) {
updateLastTextMessageLlmBenchmarkResult(
model = model, llmBenchmarkResult =
ChatMessageBenchmarkLlmResult(
model = model, llmBenchmarkResult = ChatMessageBenchmarkLlmResult(
orderedStats = STATS,
statValues = mutableMapOf(
"prefill_speed" to prefillSpeed,
@ -131,10 +134,12 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
)
}
}
}, cleanUpListener = {
},
cleanUpListener = {
setInProgress(false)
})
} catch (e: Exception) {
Log.e(TAG, "Error occurred while running inference", e)
setInProgress(false)
onError()
}
@ -143,6 +148,9 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
fun stopResponse(model: Model) {
Log.d(TAG, "Stopping response for model ${model.name}...")
if (getLastMessage(model = model) is ChatMessageLoading) {
removeLastMessage(model = model)
}
viewModelScope.launch(Dispatchers.Default) {
setInProgress(false)
val instance = model.instance as LlmModelInstance
@ -150,6 +158,25 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
}
}
// Clears all chat messages and re-creates the LLM inference session for `model`,
// toggling the "resetting session" UI flag while the work runs off the main thread.
fun resetSession(model: Model) {
  viewModelScope.launch(Dispatchers.Default) {
    setIsResettingSession(true)
    clearAllMessages(model = model)
    // Cancel any in-flight response first so the session is free to be rebuilt.
    stopResponse(model = model)
    // The session may still be busy tearing down the previous response, in which case
    // resetSession throws; retry every 200 ms until it succeeds.
    // NOTE(review): this loop has no upper bound — confirm a permanently stuck session
    // cannot keep the coroutine (and the resetting flag) spinning forever.
    while (true) {
      try {
        LlmChatModelHelper.resetSession(model = model)
        break
      } catch (e: Exception) {
        Log.d(TAG, "Failed to reset session. Trying again")
      }
      delay(200)
    }
    setIsResettingSession(false)
  }
}
fun runAgain(model: Model, message: ChatMessageText, onError: () -> Unit) {
viewModelScope.launch(Dispatchers.Default) {
// Wait for model to be initialized.
@ -162,9 +189,7 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
// Run inference.
generateResponse(
model = model,
input = message.content,
onError = onError
model = model, input = message.content, onError = onError
)
}
}
@ -199,8 +224,7 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
var decodeSpeed: Float
val start = System.currentTimeMillis()
var lastUpdateTime = 0L
LlmChatModelHelper.runInference(
model = model,
LlmChatModelHelper.runInference(model = model,
input = message.content,
resultListener = { partialResult, done ->
val curTs = System.currentTimeMillis()
@ -232,14 +256,12 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
result.append(partialResult)
if (curTs - lastUpdateTime > 500 || done) {
decodeSpeed =
decodeTokens / ((curTs - firstTokenTs) / 1000f)
decodeSpeed = decodeTokens / ((curTs - firstTokenTs) / 1000f)
if (decodeSpeed.isNaN()) {
decodeSpeed = 0f
}
replaceLastMessage(
model = model,
message = ChatMessageBenchmarkLlmResult(
model = model, message = ChatMessageBenchmarkLlmResult(
orderedStats = STATS,
statValues = mutableMapOf(
"prefill_speed" to prefillSpeed,
@ -249,8 +271,7 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
),
running = !done,
latencyMs = -1f,
),
type = ChatMessageType.BENCHMARK_LLM_RESULT
), type = ChatMessageType.BENCHMARK_LLM_RESULT
)
lastUpdateTime = curTs
@ -261,9 +282,53 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
},
cleanUpListener = {
setInProgress(false)
})
}
}
fun handleError(
context: Context,
model: Model,
modelManagerViewModel: ModelManagerViewModel,
triggeredMessage: ChatMessageText,
) {
// Clean up.
modelManagerViewModel.cleanupModel(task = task, model = model)
// Remove the "loading" message.
if (getLastMessage(model = model) is ChatMessageLoading) {
removeLastMessage(model = model)
}
// Remove the last Text message.
if (getLastMessage(model = model) == triggeredMessage) {
removeLastMessage(model = model)
}
// Add a warning message for re-initializing the session.
addMessage(
model = model,
message = ChatMessageWarning(content = "Error occurred. Re-initializing the session.")
)
}
// Add the triggered message back.
addMessage(model = model, message = triggeredMessage)
// Re-initialize the session/engine.
modelManagerViewModel.initializeModel(
context = context, task = task, model = model
)
// Re-generate the response automatically.
generateResponse(model = model, input = triggeredMessage.content, onError = {
handleError(
context = context,
model = model,
modelManagerViewModel = modelManagerViewModel,
triggeredMessage = triggeredMessage
)
})
}
}
/** View model for the image-to-text task; inherits all behavior from [LlmChatViewModel], only the task differs. */
class LlmImageToTextViewModel : LlmChatViewModel(curTask = TASK_LLM_IMAGE_TO_TEXT)

View file

@ -155,9 +155,14 @@ fun LlmSingleTurnScreen(
PromptTemplatesPanel(
model = selectedModel,
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
onSend = { fullPrompt ->
viewModel.generateResponse(model = selectedModel, input = fullPrompt)
}, modifier = Modifier.fillMaxSize()
},
onStopButtonClicked = { model ->
viewModel.stopResponse(model = model)
},
modifier = Modifier.fillMaxSize()
)
},
bottomView = {

View file

@ -23,6 +23,7 @@ import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.aiedge.gallery.ui.common.chat.ChatMessageLoading
import com.google.aiedge.gallery.ui.common.chat.Stat
import com.google.aiedge.gallery.ui.common.processLlmResponse
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
@ -194,6 +195,15 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_USECASES) : ViewMode
}
}
// Cancels the in-flight async response generation for `model` and clears the
// in-progress UI flag.
fun stopResponse(model: Model) {
  Log.d(TAG, "Stopping response for model ${model.name}...")
  viewModelScope.launch(Dispatchers.Default) {
    setInProgress(false)
    // NOTE(review): unconditional cast assumes the model instance is initialized; a stop
    // request before initialization would throw here — confirm callers only expose the
    // stop button while a live session exists.
    val instance = model.instance as LlmModelInstance
    instance.session.cancelGenerateResponseAsync()
  }
}
private fun createUiState(task: Task): LlmSingleTurnUiState {
val responsesByModel: MutableMap<String, Map<String, String>> = mutableMapOf()
val benchmarkByModel: MutableMap<String, Map<String, ChatMessageBenchmarkLlmResult>> =

View file

@ -43,11 +43,13 @@ import androidx.compose.material.icons.outlined.Description
import androidx.compose.material.icons.outlined.ExpandLess
import androidx.compose.material.icons.outlined.ExpandMore
import androidx.compose.material.icons.rounded.Add
import androidx.compose.material.icons.rounded.Stop
import androidx.compose.material.icons.rounded.Visibility
import androidx.compose.material.icons.rounded.VisibilityOff
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.FilterChipDefaults
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.IconButtonDefaults
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ModalBottomSheet
@ -86,9 +88,11 @@ import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.ui.common.chat.MessageBubbleShape
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.theme.customColors
import kotlinx.coroutines.launch
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
private val promptTemplateTypes: List<PromptTemplateType> = PromptTemplateType.entries
private val TAB_TITLES = PromptTemplateType.entries.map { it.label }
@ -101,11 +105,14 @@ const val FULL_PROMPT_SWITCH_KEY = "full_prompt"
fun PromptTemplatesPanel(
model: Model,
viewModel: LlmSingleTurnViewModel,
modelManagerViewModel: ModelManagerViewModel,
onSend: (fullPrompt: String) -> Unit,
onStopButtonClicked: (Model) -> Unit,
modifier: Modifier = Modifier
) {
val scope = rememberCoroutineScope()
val uiState by viewModel.uiState.collectAsState()
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val selectedPromptTemplateType = uiState.selectedPromptTemplateType
val inProgress = uiState.inProgress
var selectedTabIndex by remember { mutableIntStateOf(0) }
@ -123,6 +130,8 @@ fun PromptTemplatesPanel(
val focusManager = LocalFocusManager.current
val interactionSource = remember { MutableInteractionSource() }
val expandedStates = remember { mutableStateMapOf<String, Boolean>() }
val modelInitializationStatus =
modelManagerUiState.modelInitializationStatus[model.name]
// Update input editor values when prompt template changes.
LaunchedEffect(selectedPromptTemplateType) {
@ -328,6 +337,25 @@ fun PromptTemplatesPanel(
)
}
val modelInitializing =
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING
if (inProgress && !modelInitializing) {
IconButton(
onClick = {
onStopButtonClicked(model)
},
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.secondaryContainer,
),
modifier = Modifier.size(ICON_BUTTON_SIZE)
) {
Icon(
Icons.Rounded.Stop,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary
)
}
} else {
// Send button
OutlinedIconButton(
enabled = !inProgress && curTextInputContent.isNotEmpty(),
@ -356,6 +384,7 @@ fun PromptTemplatesPanel(
}
}
}
}
if (showExamplePromptBottomSheet) {
ModalBottomSheet(

View file

@ -21,6 +21,10 @@ import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Scaffold
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.getValue
import androidx.compose.runtime.remember
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.tooling.preview.Preview
@ -48,6 +52,24 @@ fun ModelManager(
if (task.models.size != 1) {
title += "s"
}
// Model count.
val modelCount by remember {
derivedStateOf {
val trigger = task.updateTrigger.value
if (trigger >= 0) {
task.models.size
} else {
-1
}
}
}
// Navigate up when there are no models left.
LaunchedEffect(modelCount) {
if (modelCount == 0) {
navigateUp()
}
}
// Handle system's edge swipe.
BackHandler {

View file

@ -38,6 +38,7 @@ import com.google.aiedge.gallery.data.ModelDownloadStatus
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.TASKS
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.TaskType
@ -58,6 +59,7 @@ import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.json.Json
import net.openid.appauth.AuthorizationException
import net.openid.appauth.AuthorizationRequest
@ -229,7 +231,10 @@ open class ModelManagerViewModel(
}
dataStoreRepository.saveImportedModels(importedModels = importedModels)
}
val newUiState = uiState.value.copy(modelDownloadStatus = curModelDownloadStatus)
val newUiState = uiState.value.copy(
modelDownloadStatus = curModelDownloadStatus,
tasks = uiState.value.tasks.toList()
)
_uiState.update { newUiState }
}
@ -312,6 +317,12 @@ open class ModelManagerViewModel(
onDone = onDone,
)
TaskType.LLM_IMAGE_TO_TEXT -> LlmChatModelHelper.initialize(
context = context,
model = model,
onDone = onDone,
)
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.initialize(
context = context, model = model, onDone = onDone
)
@ -331,6 +342,7 @@ open class ModelManagerViewModel(
TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.cleanUp(model = model)
TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model)
TaskType.LLM_USECASES -> LlmChatModelHelper.cleanUp(model = model)
TaskType.LLM_IMAGE_TO_TEXT -> LlmChatModelHelper.cleanUp(model = model)
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.cleanUp(model = model)
TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {}
@ -434,14 +446,16 @@ open class ModelManagerViewModel(
// Create model.
val model = createModelFromImportedModelInfo(info = info)
for (task in listOf(TASK_LLM_CHAT, TASK_LLM_USECASES, TASK_LLM_IMAGE_TO_TEXT)) {
// Remove duplicated imported model if existed.
for (task in listOf(TASK_LLM_CHAT, TASK_LLM_USECASES)) {
val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported }
if (modelIndex >= 0) {
Log.d(TAG, "duplicated imported model found in task. Removing it first")
task.models.removeAt(modelIndex)
}
if (task == TASK_LLM_IMAGE_TO_TEXT && model.llmSupportImage || task != TASK_LLM_IMAGE_TO_TEXT) {
task.models.add(model)
}
task.updateTrigger.value = System.currentTimeMillis()
}
@ -632,8 +646,7 @@ open class ModelManagerViewModel(
fun loadModelAllowlist() {
_uiState.update {
uiState.value.copy(
loadingModelAllowlist = true,
loadingModelAllowlistError = ""
loadingModelAllowlist = true, loadingModelAllowlistError = ""
)
}
@ -663,6 +676,9 @@ open class ModelManagerViewModel(
if (allowedModel.taskTypes.contains(TASK_LLM_USECASES.type.id)) {
TASK_LLM_USECASES.models.add(model)
}
if (allowedModel.taskTypes.contains(TASK_LLM_IMAGE_TO_TEXT.type.id)) {
TASK_LLM_IMAGE_TO_TEXT.models.add(model)
}
}
// Pre-process all tasks.
@ -717,6 +733,9 @@ open class ModelManagerViewModel(
// Add to task.
TASK_LLM_CHAT.models.add(model)
TASK_LLM_USECASES.models.add(model)
if (model.llmSupportImage) {
TASK_LLM_IMAGE_TO_TEXT.models.add(model)
}
// Update status.
modelDownloadStatus[model.name] = ModelDownloadStatus(
@ -731,7 +750,7 @@ open class ModelManagerViewModel(
Log.d(TAG, "model download status: $modelDownloadStatus")
return ModelManagerUiState(
tasks = TASKS,
tasks = TASKS.toList(),
modelDownloadStatus = modelDownloadStatus,
modelInitializationStatus = modelInstances,
textInputHistory = textInputHistory,
@ -763,6 +782,9 @@ open class ModelManagerViewModel(
) as Float,
accelerators = accelerators,
)
val llmSupportImage = convertValueToTargetType(
info.defaultValues[ConfigKey.SUPPORT_IMAGE.label] ?: false, ValueType.BOOLEAN
) as Boolean
val model = Model(
name = info.fileName,
url = "",
@ -770,7 +792,9 @@ open class ModelManagerViewModel(
sizeInBytes = info.fileSize,
downloadFileName = "$IMPORTS_DIR/${info.fileName}",
showBenchmarkButton = false,
showRunAgainButton = false,
imported = true,
llmSupportImage = llmSupportImage,
)
model.preProcess()
@ -803,6 +827,7 @@ open class ModelManagerViewModel(
)
}
@OptIn(ExperimentalSerializationApi::class)
private inline fun <reified T> getJsonResponse(url: String): T? {
try {
val connection = URL(url).openConnection() as HttpURLConnection
@ -815,7 +840,12 @@ open class ModelManagerViewModel(
val response = inputStream.bufferedReader().use { it.readText() }
// Parse JSON using kotlinx.serialization
val json = Json { ignoreUnknownKeys = true } // Handle potential extra fields
val json = Json {
// Handle potential extra fields
ignoreUnknownKeys = true
allowComments = true
allowTrailingComma = true
}
val jsonObj = json.decodeFromString<T>(response)
return jsonObj
} else {

View file

@ -46,6 +46,7 @@ import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION
import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION
import com.google.aiedge.gallery.data.Task
@ -59,6 +60,8 @@ import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationDestination
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationScreen
import com.google.aiedge.gallery.ui.llmchat.LlmChatDestination
import com.google.aiedge.gallery.ui.llmchat.LlmChatScreen
import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextDestination
import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextScreen
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnDestination
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnScreen
import com.google.aiedge.gallery.ui.modelmanager.ModelManager
@ -129,8 +132,7 @@ fun GalleryNavHost(
) {
val curPickedTask = pickedTask
if (curPickedTask != null) {
ModelManager(
viewModel = modelManagerViewModel,
ModelManager(viewModel = modelManagerViewModel,
task = curPickedTask,
onModelClicked = { model ->
navigateToTaskScreen(
@ -207,7 +209,7 @@ fun GalleryNavHost(
}
}
// LLMm chat demos.
// LLM chat demos.
composable(
route = "${LlmChatDestination.route}/{modelName}",
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
@ -224,7 +226,7 @@ fun GalleryNavHost(
}
}
// LLMm single turn.
// LLM single turn.
composable(
route = "${LlmSingleTurnDestination.route}/{modelName}",
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
@ -241,6 +243,22 @@ fun GalleryNavHost(
}
}
// LLM image to text.
composable(
route = "${LlmImageToTextDestination.route}/{modelName}",
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() },
exitTransition = { slideExit() },
) {
getModelFromNavigationParam(it, TASK_LLM_IMAGE_TO_TEXT)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel)
LlmImageToTextScreen(
modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() },
)
}
}
}
// Handle incoming intents for deep links
@ -254,9 +272,7 @@ fun GalleryNavHost(
getModelByName(modelName)?.let { model ->
// TODO(jingjin): need to show a list of possible tasks for this model.
navigateToTaskScreen(
navController = navController,
taskType = TaskType.LLM_CHAT,
model = model
navController = navController, taskType = TaskType.LLM_CHAT, model = model
)
}
}
@ -271,6 +287,7 @@ fun navigateToTaskScreen(
TaskType.TEXT_CLASSIFICATION -> navController.navigate("${TextClassificationDestination.route}/${modelName}")
TaskType.IMAGE_CLASSIFICATION -> navController.navigate("${ImageClassificationDestination.route}/${modelName}")
TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
TaskType.LLM_IMAGE_TO_TEXT -> navController.navigate("${LlmImageToTextDestination.route}/${modelName}")
TaskType.LLM_USECASES -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
TaskType.IMAGE_GENERATION -> navController.navigate("${ImageGenerationDestination.route}/${modelName}")
TaskType.TEST_TASK_1 -> {}

View file

@ -44,7 +44,8 @@ fun TextClassificationScreen(
task = viewModel.task,
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
onSendMessage = { model, message ->
onSendMessage = { model, messages ->
val message = messages[0]
viewModel.addMessage(
model = model,
message = message,

View file

@ -7,7 +7,7 @@ junitVersion = "1.2.1"
espressoCore = "3.6.1"
lifecycleRuntimeKtx = "2.8.7"
activityCompose = "1.10.1"
composeBom = "2025.03.01"
composeBom = "2025.05.00"
navigation = "2.8.9"
serializationPlugin = "2.0.21"
serializationJson = "1.7.3"