Add support for image to text models

Jing Jin 2025-05-16 00:17:24 -07:00
parent ef290cd7b0
commit bedc488a15
29 changed files with 763 additions and 231 deletions


@@ -19,7 +19,6 @@ package com.google.aiedge.gallery.data
import android.content.Context
import com.google.aiedge.gallery.ui.common.chat.PromptTemplate
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
-import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
import java.io.File
data class ModelDataFile(
@@ -91,6 +90,9 @@ data class Model(
/** The prompt templates for the model (only for LLM). */
val llmPromptTemplates: List<PromptTemplate> = listOf(),
+/** Whether the LLM model supports image input. */
+val llmSupportImage: Boolean = false,
/** Whether the model is imported or not. */
val imported: Boolean = false,
@@ -204,6 +206,7 @@ enum class ConfigKey(val label: String) {
DEFAULT_TOPK("Default TopK"),
DEFAULT_TOPP("Default TopP"),
DEFAULT_TEMPERATURE("Default temperature"),
+SUPPORT_IMAGE("Support image"),
MAX_RESULT_COUNT("Max result count"),
USE_GPU("Use GPU"),
ACCELERATOR("Accelerator"),
@@ -250,83 +253,9 @@ const val IMAGE_CLASSIFICATION_INFO = ""
const val IMAGE_CLASSIFICATION_LEARN_MORE_URL = "https://ai.google.dev/edge/litert/android"
-const val LLM_CHAT_INFO =
-"Some description about this large language model. A community org for developers to discover models that are ready for deployment to edge platforms"
const val IMAGE_GENERATION_INFO =
"Powered by [MediaPipe Image Generation API](https://ai.google.dev/edge/mediapipe/solutions/vision/image_generator/android)"
-////////////////////////////////////////////////////////////////////////////////////////////////////
-// Model spec.
-val MODEL_LLM_GEMMA_2B_GPU_INT4: Model = Model(
-name = "Gemma 2B (GPU int4)",
-downloadFileName = "gemma-2b-it-gpu-int4.bin",
-url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma-2b-it-gpu-int4.bin",
-sizeInBytes = 1354301440L,
-configs = createLlmChatConfigs(
-accelerators = listOf(Accelerator.GPU)
-),
-showBenchmarkButton = false,
-info = LLM_CHAT_INFO,
-learnMoreUrl = "https://huggingface.co/litert-community",
-)
-val MODEL_LLM_GEMMA_2_2B_GPU_INT8: Model = Model(
-name = "Gemma 2 2B (GPU int8)",
-downloadFileName = "gemma2-2b-it-gpu-int8.bin",
-url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma2-2b-it-gpu-int8.bin",
-sizeInBytes = 2627141632L,
-configs = createLlmChatConfigs(
-accelerators = listOf(Accelerator.GPU)
-),
-showBenchmarkButton = false,
-info = LLM_CHAT_INFO,
-learnMoreUrl = "https://huggingface.co/litert-community",
-)
-val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model(
-name = "Gemma 3 1B (int4)",
-downloadFileName = "gemma3-1b-it-int4.task",
-url = "https://huggingface.co/litert-community/Gemma3-1B-IT/resolve/main/gemma3-1b-it-int4.task?download=true",
-sizeInBytes = 554661243L,
-configs = createLlmChatConfigs(
-defaultTopK = 64,
-defaultTopP = 0.95f,
-accelerators = listOf(Accelerator.CPU, Accelerator.GPU)
-),
-showBenchmarkButton = false,
-info = LLM_CHAT_INFO,
-learnMoreUrl = "https://huggingface.co/litert-community/Gemma3-1B-IT",
-llmPromptTemplates = listOf(
-PromptTemplate(
-title = "Emoji Fun",
-description = "Generate emojis by emotions",
-prompt = "Show me emojis grouped by emotions"
-),
-PromptTemplate(
-title = "Trip Planner",
-description = "Plan a trip to a destination",
-prompt = "Plan a two-day trip to San Francisco"
-),
-)
-)
-val MODEL_LLM_DEEPSEEK: Model = Model(
-name = "Deepseek",
-downloadFileName = "deepseek.task",
-url = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B/resolve/main/deepseek_q8_ekv1280.task?download=true",
-sizeInBytes = 1860686856L,
-configs = createLlmChatConfigs(
-defaultTemperature = 0.6f,
-defaultTopK = 40,
-defaultTopP = 0.7f,
-accelerators = listOf(Accelerator.CPU)
-),
-showBenchmarkButton = false,
-info = LLM_CHAT_INFO,
-learnMoreUrl = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B",
-)
val MODEL_TEXT_CLASSIFICATION_MOBILEBERT: Model = Model(
name = "MobileBert",
@@ -414,12 +343,5 @@ val MODELS_IMAGE_CLASSIFICATION: MutableList<Model> = mutableListOf(
MODEL_IMAGE_CLASSIFICATION_MOBILENET_V2,
)
-val MODELS_LLM: MutableList<Model> = mutableListOf(
-// MODEL_LLM_GEMMA_2B_GPU_INT4,
-// MODEL_LLM_GEMMA_2_2B_GPU_INT8,
-MODEL_LLM_GEMMA_3_1B_INT4,
-MODEL_LLM_DEEPSEEK,
-)
val MODELS_IMAGE_GENERATION: MutableList<Model> =
mutableListOf(MODEL_IMAGE_GENERATION_STABLE_DIFFUSION)


@@ -35,6 +35,7 @@ data class AllowedModel(
val defaultConfig: Map<String, ConfigValue>,
val taskTypes: List<String>,
val disabled: Boolean? = null,
+val llmSupportImage: Boolean? = null,
) {
fun toModel(): Model {
// Construct HF download url.
@@ -48,7 +49,7 @@ data class AllowedModel(
var defaultTopK: Int = DEFAULT_TOPK
var defaultTopP: Float = DEFAULT_TOPP
var defaultTemperature: Float = DEFAULT_TEMPERATURE
-var defaultMaxToken: Int = 1024
+var defaultMaxToken = 1024
var accelerators: List<Accelerator> = DEFAULT_ACCELERATORS
if (defaultConfig.containsKey("topK")) {
defaultTopK = getIntConfigValue(defaultConfig["topK"], defaultTopK)
@@ -84,9 +85,10 @@ data class AllowedModel(
// Misc.
var showBenchmarkButton = true
-val showRunAgainButton = true
+var showRunAgainButton = true
-if (taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_USECASES.type.id)) {
+if (isLlmModel) {
showBenchmarkButton = false
+showRunAgainButton = false
}
return Model(
@@ -99,7 +101,8 @@ data class AllowedModel(
downloadFileName = modelFile,
showBenchmarkButton = showBenchmarkButton,
showRunAgainButton = showRunAgainButton,
-learnMoreUrl = "https://huggingface.co/${modelId}"
+learnMoreUrl = "https://huggingface.co/${modelId}",
+llmSupportImage = llmSupportImage == true,
)
}


@@ -19,6 +19,7 @@ package com.google.aiedge.gallery.data
import androidx.annotation.StringRes
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.outlined.Forum
+import androidx.compose.material.icons.outlined.Mms
import androidx.compose.material.icons.outlined.Widgets
import androidx.compose.material.icons.rounded.ImageSearch
import androidx.compose.runtime.MutableState
@@ -33,6 +34,7 @@ enum class TaskType(val label: String, val id: String) {
IMAGE_GENERATION(label = "Image Generation", id = "image_generation"),
LLM_CHAT(label = "LLM Chat", id = "llm_chat"),
LLM_USECASES(label = "LLM Use Cases", id = "llm_usecases"),
+LLM_IMAGE_TO_TEXT(label = "LLM Image to Text", id = "llm_image_to_text"),
TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
TEST_TASK_2(label = "Test task 2", id = "test_task_2")
@@ -93,7 +95,7 @@ val TASK_LLM_CHAT = Task(
icon = Icons.Outlined.Forum,
// models = MODELS_LLM,
models = mutableListOf(),
-description = "Chat with a on-device large language model",
+description = "Chat with on-device large language models",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
@@ -110,6 +112,17 @@ val TASK_LLM_USECASES = Task(
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
)
+val TASK_LLM_IMAGE_TO_TEXT = Task(
+type = TaskType.LLM_IMAGE_TO_TEXT,
+icon = Icons.Outlined.Mms,
+// models = MODELS_LLM,
+models = mutableListOf(),
+description = "Ask questions about images with on-device large language models",
+docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
+sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
+textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
+)
val TASK_IMAGE_GENERATION = Task(
type = TaskType.IMAGE_GENERATION,
iconVectorResourceId = R.drawable.image_spark,
@@ -127,6 +140,7 @@ val TASKS: List<Task> = listOf(
// TASK_IMAGE_GENERATION,
TASK_LLM_USECASES,
TASK_LLM_CHAT,
+TASK_LLM_IMAGE_TO_TEXT
)
fun getModelByName(name: String): Model? {


@@ -25,6 +25,7 @@ import com.google.aiedge.gallery.GalleryApplication
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationViewModel
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationViewModel
import com.google.aiedge.gallery.ui.llmchat.LlmChatViewModel
+import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextViewModel
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.textclassification.TextClassificationViewModel
@@ -62,6 +63,12 @@ object ViewModelProvider {
LlmSingleTurnViewModel()
}
+// Initializer for LlmImageToTextViewModel.
+initializer {
+LlmImageToTextViewModel()
+}
+// Initializer for ImageGenerationViewModel.
initializer {
ImageGenerationViewModel()
}
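The LlmImageToTextViewModel registered above is not shown anywhere in this diff. Since ChatViewModel takes its Task as a constructor parameter and LlmImageToTextScreen (in the last file below) reuses the same ChatViewWrapper as LlmChatScreen, it is presumably a thin LlmChatViewModel subclass bound to TASK_LLM_IMAGE_TO_TEXT. A speculative sketch, not part of the commit:

  // Speculative sketch only -- the real class is not included in this diff.
  // Assumes LlmChatViewModel lets subclasses pass the task to bind to,
  // mirroring `open class ChatViewModel(val task: Task)`.
  class LlmImageToTextViewModel : LlmChatViewModel(TASK_LLM_IMAGE_TO_TEXT)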


@@ -17,13 +17,17 @@
package com.google.aiedge.gallery.ui.common
import androidx.compose.foundation.layout.Arrangement
+import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
+import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.size
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.ArrowBack
-import androidx.compose.material.icons.rounded.Settings
+import androidx.compose.material.icons.rounded.MapsUgc
+import androidx.compose.material.icons.rounded.Tune
import androidx.compose.material3.CenterAlignedTopAppBar
+import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
@@ -38,6 +42,7 @@ import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
+import androidx.compose.ui.draw.scale
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.vectorResource
@@ -47,6 +52,7 @@ import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.chat.ConfigDialog
+import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
@OptIn(ExperimentalMaterial3Api::class)
@@ -58,12 +64,17 @@ fun ModelPageAppBar(
onBackClicked: () -> Unit,
onModelSelected: (Model) -> Unit,
modifier: Modifier = Modifier,
+isResettingSession: Boolean = false,
+onResetSessionClicked: (Model) -> Unit = {},
+showResetSessionButton: Boolean = false,
onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit = { _, _ -> },
) {
var showConfigDialog by remember { mutableStateOf(false) }
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val context = LocalContext.current
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[model.name]
+val modelInitializationStatus =
+modelManagerUiState.modelInitializationStatus[model.name]
CenterAlignedTopAppBar(title = {
Column(
@@ -110,19 +121,58 @@ fun ModelPageAppBar(
actions = {
val showConfigButton =
model.configs.isNotEmpty() && curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED
+Box(modifier = Modifier.size(42.dp), contentAlignment = Alignment.Center) {
+var configButtonOffset = 0.dp
+if (showConfigButton && showResetSessionButton) {
+configButtonOffset = (-40).dp
+}
+val isModelInitializing =
+modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING
+if (showConfigButton) {
IconButton(
onClick = {
showConfigDialog = true
},
-enabled = showConfigButton,
+enabled = !isModelInitializing,
-modifier = Modifier.alpha(if (showConfigButton) 1f else 0f)
+modifier = Modifier
+.scale(0.75f)
+.offset(x = configButtonOffset)
+.alpha(if (isModelInitializing) 0.5f else 1f)
) {
Icon(
-imageVector = Icons.Rounded.Settings,
+imageVector = Icons.Rounded.Tune,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary
)
}
+}
+if (showResetSessionButton) {
+if (isResettingSession) {
+CircularProgressIndicator(
+trackColor = MaterialTheme.colorScheme.surfaceVariant,
+strokeWidth = 2.dp,
+modifier = Modifier.size(16.dp)
+)
+} else {
+IconButton(
+onClick = {
+onResetSessionClicked(model)
+},
+enabled = !isModelInitializing,
+modifier = Modifier
+.scale(0.75f)
+.alpha(if (isModelInitializing) 0.5f else 1f)
+) {
+Icon(
+imageVector = Icons.Rounded.MapsUgc,
+contentDescription = "",
+tint = MaterialTheme.colorScheme.primary
+)
+}
+}
+}
+}
})
// Config dialog.


@@ -29,6 +29,7 @@ import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
+import androidx.compose.foundation.layout.widthIn
import androidx.compose.foundation.pager.HorizontalPager
import androidx.compose.foundation.pager.rememberPagerState
import androidx.compose.foundation.shape.CircleShape
@@ -54,6 +55,9 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.graphicsLayer
+import androidx.compose.ui.platform.LocalDensity
+import androidx.compose.ui.platform.LocalWindowInfo
+import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.Task
@@ -77,6 +81,10 @@ fun ModelPickerChipsPager(
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
val scope = rememberCoroutineScope()
+val density = LocalDensity.current
+val windowInfo = LocalWindowInfo.current
+val screenWidthDp =
+remember { with(density) { windowInfo.containerSize.width.toDp() } }
val pagerState = rememberPagerState(initialPage = task.models.indexOf(initialModel),
pageCount = { task.models.size })
@@ -140,7 +148,12 @@
Text(
model.name,
style = MaterialTheme.typography.labelSmall,
-modifier = Modifier.padding(start = 4.dp),
+modifier = Modifier
+.padding(start = 4.dp)
+.widthIn(0.dp, screenWidthDp - 250.dp),
+maxLines = 1,
+overflow = TextOverflow.MiddleEllipsis
)
Icon(
Icons.Rounded.ArrowDropDown,


@@ -109,7 +109,7 @@ fun ChatPanel(
task: Task,
selectedModel: Model,
viewModel: ChatViewModel,
-onSendMessage: (Model, ChatMessage) -> Unit,
+onSendMessage: (Model, List<ChatMessage>) -> Unit,
onRunAgainClicked: (Model, ChatMessage) -> Unit,
onBenchmarkClicked: (Model, ChatMessage, warmUpIterations: Int, benchmarkIterations: Int) -> Unit,
navigateUp: () -> Unit,
@@ -280,7 +280,8 @@
task = task,
onPromptClicked = { template ->
onSendMessage(
-selectedModel, ChatMessageText(content = template.prompt, side = ChatSide.USER)
+selectedModel,
+listOf(ChatMessageText(content = template.prompt, side = ChatSide.USER))
)
})
@@ -430,12 +431,15 @@
// Chat input
when (chatInputType) {
ChatInputType.TEXT -> {
-val isLlmTask = task.type == TaskType.LLM_CHAT
+// val isLlmTask = task.type == TaskType.LLM_CHAT
-val notLlmStartScreen = !(messages.size == 1 && messages[0] is ChatMessagePromptTemplates)
+// val notLlmStartScreen = !(messages.size == 1 && messages[0] is ChatMessagePromptTemplates)
+val hasImageMessage = messages.any { it is ChatMessageImage }
MessageInputText(
modelManagerViewModel = modelManagerViewModel,
curMessage = curMessage,
inProgress = uiState.inProgress,
+isResettingSession = uiState.isResettingSession,
+hasImageMessage = hasImageMessage,
modelInitializing = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
textFieldPlaceHolderRes = task.textInputPlaceHolderRes,
onValueChanged = { curMessage = it },
@@ -445,13 +449,17 @@
},
onOpenPromptTemplatesClicked = {
onSendMessage(
-selectedModel, ChatMessagePromptTemplates(
+selectedModel, listOf(
+ChatMessagePromptTemplates(
templates = selectedModel.llmPromptTemplates, showMakeYourOwn = false
)
)
+)
},
onStopButtonClicked = onStopButtonClicked,
-showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen,
+// showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen,
+showPromptTemplatesInMenu = false,
+showImagePickerInMenu = selectedModel.llmSupportImage == true,
showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress,
)
}
@@ -461,10 +469,12 @@
streamingMessage = streamingMessage,
onImageSelected = { bitmap ->
onSendMessage(
-selectedModel, ChatMessageImage(
+selectedModel, listOf(
+ChatMessageImage(
bitmap = bitmap, imageBitMap = bitmap.asImageBitmap(), side = ChatSide.USER
)
)
+)
},
onStreamImage = { bitmap ->
onStreamImageMessage(


@@ -70,16 +70,18 @@ fun ChatView(
task: Task,
viewModel: ChatViewModel,
modelManagerViewModel: ModelManagerViewModel,
-onSendMessage: (Model, ChatMessage) -> Unit,
+onSendMessage: (Model, List<ChatMessage>) -> Unit,
onRunAgainClicked: (Model, ChatMessage) -> Unit,
onBenchmarkClicked: (Model, ChatMessage, Int, Int) -> Unit,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
+onResetSessionClicked: (Model) -> Unit = {},
onStreamImageMessage: (Model, ChatMessageImage) -> Unit = { _, _ -> },
onStopButtonClicked: (Model) -> Unit = {},
chatInputType: ChatInputType = ChatInputType.TEXT,
showStopButtonInInputWhenInProgress: Boolean = false,
) {
+val uiStat by viewModel.uiState.collectAsState()
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val selectedModel = modelManagerUiState.selectedModel
@@ -155,6 +157,9 @@
task = task,
model = selectedModel,
modelManagerViewModel = modelManagerViewModel,
+showResetSessionButton = true,
+isResettingSession = uiStat.isResettingSession,
+onResetSessionClicked = onResetSessionClicked,
onConfigChanged = { old, new ->
viewModel.addConfigChangedMessage(
oldConfigValues = old,


@@ -33,6 +33,11 @@ data class ChatUiState(
*/
val inProgress: Boolean = false,
+/**
+* Indicates whether the session is being reset.
+*/
+val isResettingSession: Boolean = false,
/**
* A map of model names to lists of chat messages.
*/
@@ -106,6 +111,12 @@ open class ChatViewModel(val task: Task) : ViewModel() {
_uiState.update { _uiState.value.copy(messagesByModel = newMessagesByModel) }
}
+fun clearAllMessages(model: Model) {
+val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
+newMessagesByModel[model.name] = mutableListOf()
+_uiState.update { _uiState.value.copy(messagesByModel = newMessagesByModel) }
+}
fun getLastMessage(model: Model): ChatMessage? {
return (_uiState.value.messagesByModel[model.name] ?: listOf()).lastOrNull()
}
@@ -189,6 +200,10 @@
_uiState.update { _uiState.value.copy(inProgress = inProgress) }
}
+fun setIsResettingSession(isResettingSession: Boolean) {
+_uiState.update { _uiState.value.copy(isResettingSession = isResettingSession) }
+}
fun addConfigChangedMessage(
oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>, model: Model
) {
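How these pieces fit together is not shown in this excerpt: the reset button added to ModelPageAppBar invokes onResetSessionClicked, and a concrete chat view model is expected to flip isResettingSession, rebuild the inference session, and clear the transcript. A speculative sketch of that flow (not part of the commit), combining the clearAllMessages and setIsResettingSession helpers above with LlmChatModelHelper.resetSession from a later file in this diff:

  // Speculative sketch -- member function of a hypothetical ChatViewModel
  // subclass; the real implementation and its threading may differ.
  fun resetSession(model: Model) {
    viewModelScope.launch(Dispatchers.Default) {
      setIsResettingSession(true)
      LlmChatModelHelper.resetSession(model = model) // close and recreate the LlmInferenceSession
      clearAllMessages(model = model)                // wipe the on-screen transcript for this model
      setIsResettingSession(false)
    }
  }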


@@ -21,14 +21,18 @@ import androidx.compose.material3.ProvideTextStyle
import androidx.compose.runtime.Composable
import androidx.compose.runtime.CompositionLocalProvider
import androidx.compose.ui.Modifier
+import androidx.compose.ui.text.SpanStyle
+import androidx.compose.ui.text.TextLinkStyles
import androidx.compose.ui.text.TextStyle
import androidx.compose.ui.text.font.FontFamily
import androidx.compose.ui.tooling.preview.Preview
import com.google.aiedge.gallery.ui.theme.GalleryTheme
+import com.google.aiedge.gallery.ui.theme.customColors
import com.halilibo.richtext.commonmark.Markdown
import com.halilibo.richtext.ui.CodeBlockStyle
import com.halilibo.richtext.ui.RichTextStyle
import com.halilibo.richtext.ui.material3.RichText
+import com.halilibo.richtext.ui.string.RichTextStringStyle
/**
* Composable function to display Markdown-formatted text.
@@ -56,6 +60,11 @@ fun MarkdownText(
fontSize = MaterialTheme.typography.bodySmall.fontSize,
fontFamily = FontFamily.Monospace,
)
+),
+stringStyle = RichTextStringStyle(
+linkStyle = TextLinkStyles(
+style = SpanStyle(color = MaterialTheme.customColors.linkColor)
+)
)
),
) {


@@ -46,7 +46,11 @@ fun MessageBodyWarning(message: ChatMessageWarning) {
.clip(RoundedCornerShape(16.dp))
.background(MaterialTheme.colorScheme.tertiaryContainer)
) {
-MarkdownText(text = message.content, modifier = Modifier.padding(12.dp), smallFontSize = true)
+MarkdownText(
+text = message.content,
+modifier = Modifier.padding(horizontal = 16.dp, vertical = 6.dp),
+smallFontSize = true
+)
}
}
}


@@ -16,28 +16,45 @@
package com.google.aiedge.gallery.ui.common.chat
+import android.Manifest
+import android.content.Context
+import android.content.pm.PackageManager
+import android.graphics.Bitmap
+import android.graphics.BitmapFactory
+import android.graphics.Matrix
+import android.net.Uri
+import androidx.activity.compose.rememberLauncherForActivityResult
+import androidx.activity.result.PickVisualMediaRequest
+import androidx.activity.result.contract.ActivityResultContracts
import androidx.annotation.StringRes
+import androidx.compose.foundation.Image
+import androidx.compose.foundation.background
import androidx.compose.foundation.border
+import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxWidth
+import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.width
+import androidx.compose.foundation.shape.CircleShape
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.Send
import androidx.compose.material.icons.rounded.Add
+import androidx.compose.material.icons.rounded.Close
import androidx.compose.material.icons.rounded.History
+import androidx.compose.material.icons.rounded.Photo
+import androidx.compose.material.icons.rounded.PhotoCamera
import androidx.compose.material.icons.rounded.PostAdd
import androidx.compose.material.icons.rounded.Stop
import androidx.compose.material3.DropdownMenu
import androidx.compose.material3.DropdownMenuItem
+import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.IconButtonDefaults
@@ -54,12 +71,17 @@ import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
+import androidx.compose.ui.draw.clip
+import androidx.compose.ui.draw.shadow
import androidx.compose.ui.graphics.Color
+import androidx.compose.ui.graphics.asImageBitmap
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
+import androidx.core.content.ContextCompat
import com.google.aiedge.gallery.R
+import com.google.aiedge.gallery.ui.common.createTempPictureUri
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.theme.GalleryTheme
@@ -74,24 +96,107 @@ import com.google.aiedge.gallery.ui.theme.GalleryTheme
fun MessageInputText(
modelManagerViewModel: ModelManagerViewModel,
curMessage: String,
+isResettingSession: Boolean,
inProgress: Boolean,
+hasImageMessage: Boolean,
modelInitializing: Boolean,
@StringRes textFieldPlaceHolderRes: Int,
onValueChanged: (String) -> Unit,
-onSendMessage: (ChatMessage) -> Unit,
+onSendMessage: (List<ChatMessage>) -> Unit,
onOpenPromptTemplatesClicked: () -> Unit = {},
onStopButtonClicked: () -> Unit = {},
-showPromptTemplatesInMenu: Boolean = true,
+showPromptTemplatesInMenu: Boolean = false,
+showImagePickerInMenu: Boolean = false,
showStopButtonWhenInProgress: Boolean = false,
) {
+val context = LocalContext.current
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
var showAddContentMenu by remember { mutableStateOf(false) }
var showTextInputHistorySheet by remember { mutableStateOf(false) }
+var tempPhotoUri by remember { mutableStateOf(value = Uri.EMPTY) }
+var pickedImages by remember { mutableStateOf<List<Bitmap>>(listOf()) }
+val updatePickedImages: (Bitmap) -> Unit = { bitmap ->
+val newPickedImages: MutableList<Bitmap> = mutableListOf()
+newPickedImages.addAll(pickedImages)
+newPickedImages.add(bitmap)
+pickedImages = newPickedImages.toList()
+}
+// launches camera
+val cameraLauncher =
+rememberLauncherForActivityResult(ActivityResultContracts.TakePicture()) { isImageSaved ->
+if (isImageSaved) {
+handleImageSelected(
+context = context,
+uri = tempPhotoUri,
+onImageSelected = { bitmap ->
+updatePickedImages(bitmap)
+},
+rotateForPortrait = true,
+)
+}
+}
+// Permission request when taking picture.
+val takePicturePermissionLauncher = rememberLauncherForActivityResult(
+ActivityResultContracts.RequestPermission()
+) { permissionGranted ->
+if (permissionGranted) {
+showAddContentMenu = false
+tempPhotoUri = context.createTempPictureUri()
+cameraLauncher.launch(tempPhotoUri)
+}
+}
+// Registers a photo picker activity launcher in single-select mode.
+val pickMedia =
+rememberLauncherForActivityResult(ActivityResultContracts.PickVisualMedia()) { uri ->
+// Callback is invoked after the user selects a media item or closes the
+// photo picker.
+if (uri != null) {
+handleImageSelected(context = context, uri = uri, onImageSelected = { bitmap ->
+updatePickedImages(bitmap)
+})
+}
+}
Box(contentAlignment = Alignment.CenterStart) {
+// A preview panel for the selected image.
+if (pickedImages.isNotEmpty()) {
+Box(
+contentAlignment = Alignment.TopEnd, modifier = Modifier.offset(x = 16.dp, y = (-80).dp)
+) {
+Image(
+bitmap = pickedImages.last().asImageBitmap(),
+contentDescription = "",
+modifier = Modifier
+.height(80.dp)
+.shadow(2.dp, shape = RoundedCornerShape(8.dp))
+.clip(RoundedCornerShape(8.dp))
+.border(1.dp, MaterialTheme.colorScheme.outlineVariant, RoundedCornerShape(8.dp)),
+)
+Box(modifier = Modifier
+.offset(x = 10.dp, y = (-10).dp)
+.clip(CircleShape)
+.background(MaterialTheme.colorScheme.surface)
+.border((1.5).dp, MaterialTheme.colorScheme.outline, CircleShape)
+.clickable {
+pickedImages = listOf()
+}) {
+Icon(
+Icons.Rounded.Close,
+contentDescription = "",
+modifier = Modifier
+.padding(3.dp)
+.size(16.dp)
+)
+}
+}
+}
// A plus button to show a popup menu to add stuff to the chat.
IconButton(
-enabled = !inProgress,
+enabled = !inProgress && !isResettingSession,
onClick = { showAddContentMenu = true },
modifier = Modifier
.offset(x = 16.dp)
@@ -113,6 +218,58 @@ fun MessageInputText(
DropdownMenu(
expanded = showAddContentMenu,
onDismissRequest = { showAddContentMenu = false }) {
+if (showImagePickerInMenu) {
+// Take a photo.
+DropdownMenuItem(
+text = {
+Row(
+verticalAlignment = Alignment.CenterVertically,
+horizontalArrangement = Arrangement.spacedBy(6.dp)
+) {
+Icon(Icons.Rounded.PhotoCamera, contentDescription = "")
+Text("Take a photo")
+}
+},
+enabled = pickedImages.isEmpty() && !hasImageMessage,
+onClick = {
+// Check permission
+when (PackageManager.PERMISSION_GRANTED) {
+// Already got permission. Call the lambda.
+ContextCompat.checkSelfPermission(
+context, Manifest.permission.CAMERA
+) -> {
+showAddContentMenu = false
+tempPhotoUri = context.createTempPictureUri()
+cameraLauncher.launch(tempPhotoUri)
+}
+// Otherwise, ask for permission
+else -> {
+takePicturePermissionLauncher.launch(Manifest.permission.CAMERA)
+}
+}
+})
+// Pick an image from album.
+DropdownMenuItem(
+text = {
+Row(
+verticalAlignment = Alignment.CenterVertically,
+horizontalArrangement = Arrangement.spacedBy(6.dp)
+) {
+Icon(Icons.Rounded.Photo, contentDescription = "")
+Text("Pick from album")
+}
+},
+enabled = pickedImages.isEmpty() && !hasImageMessage,
+onClick = {
+// Launch the photo picker and let the user choose only images.
+pickMedia.launch(PickVisualMediaRequest(ActivityResultContracts.PickVisualMedia.ImageOnly))
+showAddContentMenu = false
+})
+}
+// Prompt templates.
if (showPromptTemplatesInMenu) {
DropdownMenuItem(text = {
Row(
@@ -127,6 +284,7 @@ fun MessageInputText(
showAddContentMenu = false
})
}
+// Prompt history.
DropdownMenuItem(text = {
Row(
verticalAlignment = Alignment.CenterVertically,
@@ -171,18 +329,19 @@
),
) {
Icon(
-Icons.Rounded.Stop,
+Icons.Rounded.Stop, contentDescription = "", tint = MaterialTheme.colorScheme.primary
-contentDescription = "",
-tint = MaterialTheme.colorScheme.primary
)
}
}
} // Send button. Only shown when text is not empty.
else if (curMessage.isNotEmpty()) {
IconButton(
-enabled = !inProgress,
+enabled = !inProgress && !isResettingSession,
onClick = {
-onSendMessage(ChatMessageText(content = curMessage.trim(), side = ChatSide.USER))
+onSendMessage(
+createMessagesToSend(pickedImages = pickedImages, text = curMessage.trim())
+)
+pickedImages = listOf()
},
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.secondaryContainer,
@@ -200,7 +359,6 @@
}
}
// A bottom sheet to show the text input history to pick from.
if (showTextInputHistorySheet) {
TextInputHistorySheet(
@@ -209,7 +367,8 @@
showTextInputHistorySheet = false
},
onHistoryItemClicked = { item ->
-onSendMessage(ChatMessageText(content = item, side = ChatSide.USER))
+onSendMessage(createMessagesToSend(pickedImages = pickedImages, text = item))
+pickedImages = listOf()
modelManagerViewModel.promoteTextInputHistoryItem(item)
},
onHistoryItemDeleted = { item ->
@@ -217,9 +376,55 @@
},
onHistoryItemsDeleteAll = {
modelManagerViewModel.clearTextInputHistory()
+})
+}
+}
+private fun handleImageSelected(
+context: Context,
+uri: Uri,
+onImageSelected: (Bitmap) -> Unit,
+// For some reason, some Android phone would store the picture taken by the camera rotated
+// horizontally. Use this flag to rotate the image back to portrait if the picture's width
+// is bigger than height.
+rotateForPortrait: Boolean = false,
+) {
+val bitmap: Bitmap? = try {
+val inputStream = context.contentResolver.openInputStream(uri)
+val tmpBitmap = BitmapFactory.decodeStream(inputStream)
+if (rotateForPortrait && tmpBitmap.width > tmpBitmap.height) {
+val matrix = Matrix()
+matrix.postRotate(90f)
+Bitmap.createBitmap(tmpBitmap, 0, 0, tmpBitmap.width, tmpBitmap.height, matrix, true)
+} else {
+tmpBitmap
+}
+} catch (e: Exception) {
+e.printStackTrace()
+null
+}
+if (bitmap != null) {
+onImageSelected(bitmap)
+}
+}
+private fun createMessagesToSend(pickedImages: List<Bitmap>, text: String): List<ChatMessage> {
+val messages: MutableList<ChatMessage> = mutableListOf()
+if (pickedImages.isNotEmpty()) {
+val lastImage = pickedImages.last()
+messages.add(
+ChatMessageImage(
+bitmap = lastImage, imageBitMap = lastImage.asImageBitmap(), side = ChatSide.USER
+)
)
}
+messages.add(
+ChatMessageText(
+content = text, side = ChatSide.USER
+)
+)
+return messages
}
@Preview(showBackground = true)
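For reference, a minimal usage sketch of the createMessagesToSend helper above (not part of the diff; pickedBitmap stands in for an image obtained from the camera or the photo picker):

  // Illustrative only. With one attached image and non-empty text, the send
  // action emits an image message followed by a text message, both on the user
  // side, and ChatPanel forwards them in a single onSendMessage(model, messages) call.
  val messages = createMessagesToSend(
    pickedImages = listOf(pickedBitmap), // pickedBitmap: android.graphics.Bitmap, assumed in scope
    text = "Describe this photo",
  )
  // messages[0] is a ChatMessageImage, messages[1] is a ChatMessageText.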
@@ -233,6 +438,21 @@ fun MessageInputTextPreview() {
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "hello",
inProgress = false,
+isResettingSession = false,
+modelInitializing = false,
+hasImageMessage = false,
+textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
+onValueChanged = {},
+onSendMessage = {},
+showStopButtonWhenInProgress = true,
+showImagePickerInMenu = true,
+)
+MessageInputText(
+modelManagerViewModel = PreviewModelManagerViewModel(context = context),
+curMessage = "hello",
+inProgress = false,
+isResettingSession = false,
+hasImageMessage = false,
modelInitializing = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},
@@ -243,6 +463,8 @@
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "hello",
inProgress = true,
+isResettingSession = false,
+hasImageMessage = false,
modelInitializing = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},
@@ -252,6 +474,8 @@
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "",
inProgress = false,
+isResettingSession = false,
+hasImageMessage = false,
modelInitializing = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},
@@ -261,6 +485,8 @@
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "",
inProgress = true,
+isResettingSession = false,
+hasImageMessage = false,
modelInitializing = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},


@@ -52,6 +52,7 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.style.TextAlign
+import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.R
@@ -149,6 +150,8 @@ private fun SheetContent(
Text(
item,
style = MaterialTheme.typography.bodyMedium,
+maxLines = 3,
+overflow = TextOverflow.Ellipsis,
modifier = Modifier
.padding(vertical = 16.dp)
.padding(start = 16.dp)


@@ -76,7 +76,7 @@ fun ModelNameAndStatus(
Text(
model.name,
maxLines = 1,
-overflow = TextOverflow.Ellipsis,
+overflow = TextOverflow.MiddleEllipsis,
style = MaterialTheme.typography.titleMedium,
modifier = modifier,
)


@@ -136,7 +136,7 @@ fun HomeScreen(
val snackbarHostState = remember { SnackbarHostState() }
val scope = rememberCoroutineScope()
-val tasks = uiState.tasks
+val nonEmptyTasks = uiState.tasks.filter { it.models.size > 0 }
val loadingHfModels = uiState.loadingHfModels
val filePickerLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult(
@@ -183,14 +183,14 @@
) { innerPadding ->
Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.fillMaxSize()) {
TaskList(
-tasks = tasks,
+tasks = nonEmptyTasks,
navigateToTaskScreen = navigateToTaskScreen,
loadingModelAllowlist = uiState.loadingModelAllowlist,
modifier = Modifier.fillMaxSize(),
contentPadding = innerPadding,
)
-SnackbarHost(hostState = snackbarHostState, modifier = Modifier.padding(bottom = 16.dp))
+SnackbarHost(hostState = snackbarHostState, modifier = Modifier.padding(bottom = 32.dp))
}
}
@@ -285,7 +285,7 @@ fun HomeScreen(
// Show a snack bar for successful import.
scope.launch {
snackbarHostState.showSnackbar("Model imported successfully")
}
})
}
@@ -316,7 +316,6 @@
}
},
)
-}
}


@@ -55,6 +55,7 @@ import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog
import androidx.compose.ui.window.DialogProperties
import com.google.aiedge.gallery.data.Accelerator
+import com.google.aiedge.gallery.data.BooleanSwitchConfig
import com.google.aiedge.gallery.data.Config
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.IMPORTS_DIR
@@ -111,6 +112,10 @@ private val IMPORT_CONFIGS_LLM: List<Config> = listOf(
defaultValue = DEFAULT_TEMPERATURE,
valueType = ValueType.FLOAT
),
+BooleanSwitchConfig(
+key = ConfigKey.SUPPORT_IMAGE,
+defaultValue = false,
+),
SegmentedButtonConfig(
key = ConfigKey.COMPATIBLE_ACCELERATORS,
defaultValue = Accelerator.CPU.label,


@@ -50,7 +50,8 @@ fun ImageClassificationScreen(
task = viewModel.task,
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
-onSendMessage = { model, message ->
+onSendMessage = { model, messages ->
+val message = messages[0]
viewModel.addMessage(
model = model,
message = message,


@@ -44,7 +44,8 @@ fun ImageGenerationScreen(
task = viewModel.task,
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
-onSendMessage = { model, message ->
+onSendMessage = { model, messages ->
+val message = messages[0]
viewModel.addMessage(
model = model,
message = message,


@@ -17,11 +17,14 @@
package com.google.aiedge.gallery.ui.llmchat
import android.content.Context
+import android.graphics.Bitmap
import android.util.Log
import com.google.aiedge.gallery.data.Accelerator
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage
+import com.google.mediapipe.framework.image.BitmapImageBuilder
+import com.google.mediapipe.tasks.genai.llminference.GraphOptions
import com.google.mediapipe.tasks.genai.llminference.LlmInference
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
@@ -55,7 +58,9 @@ object LlmChatModelHelper {
}
val options =
LlmInference.LlmInferenceOptions.builder().setModelPath(model.getPath(context = context))
-.setMaxTokens(maxTokens).setPreferredBackend(preferredBackend).build()
+.setMaxTokens(maxTokens).setPreferredBackend(preferredBackend)
+.setMaxNumImages(if (model.llmSupportImage) 1 else 0)
+.build()
// Create an instance of the LLM Inference task
try {
@@ -64,7 +69,10 @@
val session = LlmInferenceSession.createFromOptions(
llmInference,
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
-.setTemperature(temperature).build()
+.setTemperature(temperature)
+.setGraphOptions(
+GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
+).build()
)
model.instance = LlmModelInstance(engine = llmInference, session = session)
} catch (e: Exception) {
@@ -75,6 +83,8 @@
}
fun resetSession(model: Model) {
+Log.d(TAG, "Resetting session for model '${model.name}'")
val instance = model.instance as LlmModelInstance? ?: return
val session = instance.session
session.close()
@@ -87,9 +97,13 @@
val newSession = LlmInferenceSession.createFromOptions(
inference,
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
-.setTemperature(temperature).build()
+.setTemperature(temperature)
+.setGraphOptions(
+GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
+).build()
)
instance.session = newSession
+Log.d(TAG, "Resetting done")
}
fun cleanUp(model: Model) {
@@ -118,6 +132,7 @@
resultListener: ResultListener,
cleanUpListener: CleanUpListener,
singleTurn: Boolean = false,
+image: Bitmap? = null,
) {
if (singleTurn) {
resetSession(model = model)
@@ -132,6 +147,9 @@
// Start async inference.
val session = instance.session
session.addQueryChunk(input)
+if (image != null) {
+session.addImage(BitmapImageBuilder(image).build())
+}
session.generateResponseAsync(resultListener)
}
}
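Taken together, the hunks above enable image input in three places: the engine is created with setMaxNumImages, every session enables vision modality through GraphOptions, and runInference adds the bitmap to the session before generating. A condensed sketch of that MediaPipe LLM Inference pattern (not part of the diff; the model path, token limit, and sampling values are illustrative):

  // Condensed sketch of the calls used above; values are illustrative only.
  // context: android.content.Context, bitmap: android.graphics.Bitmap (assumed in scope).
  val llmInference = LlmInference.createFromOptions(
    context,
    LlmInference.LlmInferenceOptions.builder()
      .setModelPath("/data/local/tmp/gemma3-vision.task") // illustrative path
      .setMaxTokens(1024)
      .setMaxNumImages(1) // > 0 only for models with llmSupportImage = true
      .build()
  )
  val session = LlmInferenceSession.createFromOptions(
    llmInference,
    LlmInferenceSession.LlmInferenceSessionOptions.builder()
      .setTopK(40).setTopP(0.95f).setTemperature(0.8f)
      .setGraphOptions(GraphOptions.builder().setEnableVisionModality(true).build())
      .build()
  )
  session.addQueryChunk("Describe this photo")
  session.addImage(BitmapImageBuilder(bitmap).build())
  session.generateResponseAsync { partialResult, done -> /* stream tokens to the UI */ }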


@@ -16,15 +16,14 @@
package com.google.aiedge.gallery.ui.llmchat
-import android.util.Log
+import android.graphics.Bitmap
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.aiedge.gallery.ui.ViewModelProvider
-import com.google.aiedge.gallery.ui.common.chat.ChatMessageInfo
+import com.google.aiedge.gallery.ui.common.chat.ChatMessageImage
import com.google.aiedge.gallery.ui.common.chat.ChatMessageText
-import com.google.aiedge.gallery.ui.common.chat.ChatMessageWarning
import com.google.aiedge.gallery.ui.common.chat.ChatView
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import kotlinx.serialization.Serializable
@@ -35,6 +34,11 @@ object LlmChatDestination {
val route = "LlmChatRoute"
}
+object LlmImageToTextDestination {
+@Serializable
+val route = "LlmImageToTextRoute"
+}
@Composable
fun LlmChatScreen(
modelManagerViewModel: ModelManagerViewModel,
@@ -43,6 +47,38 @@ fun LlmChatScreen(
viewModel: LlmChatViewModel = viewModel(
factory = ViewModelProvider.Factory
),
+) {
+ChatViewWrapper(
+viewModel = viewModel,
+modelManagerViewModel = modelManagerViewModel,
+navigateUp = navigateUp,
+modifier = modifier,
+)
+}
+@Composable
+fun LlmImageToTextScreen(
+modelManagerViewModel: ModelManagerViewModel,
+navigateUp: () -> Unit,
+modifier: Modifier = Modifier,
+viewModel: LlmImageToTextViewModel = viewModel(
+factory = ViewModelProvider.Factory
+),
+) {
+ChatViewWrapper(
+viewModel = viewModel,
+modelManagerViewModel = modelManagerViewModel,
+navigateUp = navigateUp,
+modifier = modifier,
+)
+}
+@Composable
+fun ChatViewWrapper(
+viewModel: LlmChatViewModel,
+modelManagerViewModel: ModelManagerViewModel,
+navigateUp: () -> Unit,
+modifier: Modifier = Modifier
) {
val context = LocalContext.current
@@ -50,21 +86,33 @@ fun LlmChatScreen(
task = viewModel.task,
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
-onSendMessage = { model, message ->
+onSendMessage = { model, messages ->
+for (message in messages) {
viewModel.addMessage(
model = model,
message = message,
)
-if (message is ChatMessageText) {
-modelManagerViewModel.addTextInputHistory(message.content)
-viewModel.generateResponse(model = model, input = message.content, onError = {
-viewModel.addMessage(
-model = model,
-message = ChatMessageWarning(content = "Error occurred. Re-initializing the engine.")
-)
-modelManagerViewModel.initializeModel(
-context = context, task = viewModel.task, model = model, force = true
+}
+var text = ""
+var image: Bitmap? = null
+var chatMessageText: ChatMessageText? = null
+for (message in messages) {
+if (message is ChatMessageText) {
+chatMessageText = message
+text = message.content
+} else if (message is ChatMessageImage) {
+image = message.bitmap
+}
+}
+if (text.isNotEmpty() && chatMessageText != null) {
+modelManagerViewModel.addTextInputHistory(text)
+viewModel.generateResponse(model = model, input = text, image = image, onError = {
+viewModel.handleError(
+context = context,
+model = model,
+modelManagerViewModel = modelManagerViewModel,
+triggeredMessage = chatMessageText,
)
})
}
@@ -72,13 +120,11 @@ fun LlmChatScreen(
onRunAgainClicked = { model, message ->
if (message is ChatMessageText) {
viewModel.runAgain(model = model, message = message, onError = {
-viewModel.addMessage(
+viewModel.handleError(
+context = context,
model = model,
-message = ChatMessageWarning(content = "Error occurred. Re-initializing the engine.")
-)
-modelManagerViewModel.initializeModel(
-context = context, task = viewModel.task, model = model, force = true
+modelManagerViewModel = modelManagerViewModel,
+triggeredMessage = message,
)
})
}
@@ -90,6 +136,9 @@ fun LlmChatScreen(
)
}
},
+onResetSessionClicked = { model ->
+viewModel.resetSession(model = model)
+},
showStopButtonInInputWhenInProgress = true,
onStopButtonClicked = { model ->
viewModel.stopResponse(model = model)
@@ -98,4 +147,3 @@ fun LlmChatScreen(
modifier = modifier,
)
}


@@ -16,18 +16,23 @@
package com.google.aiedge.gallery.ui.llmchat
+import android.content.Context
+import android.graphics.Bitmap
import android.util.Log
import androidx.lifecycle.viewModelScope
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
+import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT
+import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.aiedge.gallery.ui.common.chat.ChatMessageLoading
import com.google.aiedge.gallery.ui.common.chat.ChatMessageText
import com.google.aiedge.gallery.ui.common.chat.ChatMessageType
+import com.google.aiedge.gallery.ui.common.chat.ChatMessageWarning
import com.google.aiedge.gallery.ui.common.chat.ChatSide
import com.google.aiedge.gallery.ui.common.chat.ChatViewModel
import com.google.aiedge.gallery.ui.common.chat.Stat
-import kotlinx.coroutines.CoroutineExceptionHandler
+import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
@@ -40,8 +45,8 @@ private val STATS = listOf(
Stat(id = "latency", label = "Latency", unit = "sec")
)
-class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
+open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task = curTask) {
-fun generateResponse(model: Model, input: String, onError: () -> Unit) {
+fun generateResponse(model: Model, input: String, image: Bitmap? = null, onError: () -> Unit) {
viewModelScope.launch(Dispatchers.Default) {
setInProgress(true)
@@ -58,7 +63,10 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
// Run inference.
val instance = model.instance as LlmModelInstance
-val prefillTokens = instance.session.sizeInTokens(input)
+var prefillTokens = instance.session.sizeInTokens(input)
+if (image != null) {
+prefillTokens += 257
+}
var firstRun = true
var timeToFirstToken = 0f
@@ -69,9 +77,9 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
val start = System.currentTimeMillis()
try {
-LlmChatModelHelper.runInference(
-model = model,
+LlmChatModelHelper.runInference(model = model,
input = input,
+image = image,
resultListener = { partialResult, done ->
val curTs = System.currentTimeMillis()
@@ -92,32 +100,27 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
// Add an empty message that will receive streaming results.
addMessage(
-model = model,
-message = ChatMessageText(content = "", side = ChatSide.AGENT)
+model = model, message = ChatMessageText(content = "", side = ChatSide.AGENT)
)
}
// Incrementally update the streamed partial results.
val latencyMs: Long = if (done) System.currentTimeMillis() - start else -1
updateLastTextMessageContentIncrementally(
-model = model,
-partialContent = partialResult,
-latencyMs = latencyMs.toFloat()
+model = model, partialContent = partialResult, latencyMs = latencyMs.toFloat()
)
if (done) {
setInProgress(false)
-decodeSpeed =
-decodeTokens / ((curTs - firstTokenTs) / 1000f)
+decodeSpeed = decodeTokens / ((curTs - firstTokenTs) / 1000f)
if (decodeSpeed.isNaN()) {
decodeSpeed = 0f
}
if (lastMessage is ChatMessageText) {
updateLastTextMessageLlmBenchmarkResult(
-model = model, llmBenchmarkResult =
-ChatMessageBenchmarkLlmResult(
+model = model, llmBenchmarkResult = ChatMessageBenchmarkLlmResult(
orderedStats = STATS,
statValues = mutableMapOf(
"prefill_speed" to prefillSpeed,
@@ -131,10 +134,12 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
)
}
}
-}, cleanUpListener = {
+},
+cleanUpListener = {
setInProgress(false)
})
} catch (e: Exception) {
+Log.e(TAG, "Error occurred while running inference", e)
setInProgress(false)
onError()
}
@@ -143,6 +148,9 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
fun stopResponse(model: Model) {
Log.d(TAG, "Stopping response for model ${model.name}...")
+if (getLastMessage(model = model) is ChatMessageLoading) {
+removeLastMessage(model = model)
+}
viewModelScope.launch(Dispatchers.Default) {
setInProgress(false)
val instance = model.instance as LlmModelInstance
@@ -150,6 +158,25 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
}
}
+fun resetSession(model: Model) {
+viewModelScope.launch(Dispatchers.Default) {
+setIsResettingSession(true)
+clearAllMessages(model = model)
+stopResponse(model = model)
+while (true) {
+try {
+LlmChatModelHelper.resetSession(model = model)
+break
+} catch (e: Exception) {
+Log.d(TAG, "Failed to reset session. Trying again")
+}
+delay(200)
+}
+setIsResettingSession(false)
+}
+}
fun runAgain(model: Model, message: ChatMessageText, onError: () -> Unit) {
viewModelScope.launch(Dispatchers.Default) {
// Wait for model to be initialized.
@@ -162,9 +189,7 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
// Run inference.
generateResponse(
-model = model,
-input = message.content,
-onError = onError
+model = model, input = message.content, onError = onError
)
}
}
@@ -199,8 +224,7 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
var decodeSpeed: Float
val start = System.currentTimeMillis()
var lastUpdateTime = 0L
-LlmChatModelHelper.runInference(
-model = model,
+LlmChatModelHelper.runInference(model = model,
input = message.content,
resultListener = { partialResult, done ->
val curTs = System.currentTimeMillis()
@@ -232,14 +256,12 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
result.append(partialResult)
if (curTs - lastUpdateTime > 500 || done) {
-decodeSpeed =
-decodeTokens / ((curTs - firstTokenTs) / 1000f)
+decodeSpeed = decodeTokens / ((curTs - firstTokenTs) / 1000f)
if (decodeSpeed.isNaN()) {
decodeSpeed = 0f
}
replaceLastMessage(
-model = model,
-message = ChatMessageBenchmarkLlmResult(
+model = model, message = ChatMessageBenchmarkLlmResult(
orderedStats = STATS,
statValues = mutableMapOf(
"prefill_speed" to prefillSpeed,
@@ -249,8 +271,7 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
),
running = !done,
latencyMs = -1f,
-),
-type = ChatMessageType.BENCHMARK_LLM_RESULT
+), type = ChatMessageType.BENCHMARK_LLM_RESULT
)
lastUpdateTime = curTs
@@ -261,9 +282,53 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
},
cleanUpListener = {
setInProgress(false)
-}
-)
-}
+})
}
}
+fun handleError(
+context: Context,
+model: Model,
+modelManagerViewModel: ModelManagerViewModel,
+triggeredMessage: ChatMessageText,
+) {
+// Clean up.
+modelManagerViewModel.cleanupModel(task = task, model = model)
+// Remove the "loading" message.
+if (getLastMessage(model = model) is ChatMessageLoading) {
+removeLastMessage(model = model)
+}
+// Remove the last Text message.
+if (getLastMessage(model = model) == triggeredMessage) {
+removeLastMessage(model = model)
+}
+// Add a warning message for re-initializing the session.
+addMessage(
+model = model,
+message = ChatMessageWarning(content = "Error occurred. Re-initializing the session.")
+)
+// Add the triggered message back.
+addMessage(model = model, message = triggeredMessage)
+// Re-initialize the session/engine.
+modelManagerViewModel.initializeModel(
+context = context, task = task, model = model
+)
+// Re-generate the response automatically.
+generateResponse(model = model, input = triggeredMessage.content, onError = {
+handleError(
+context = context,
+model = model,
+modelManagerViewModel = modelManagerViewModel,
+triggeredMessage = triggeredMessage
+)
+})
+}
+}
+class LlmImageToTextViewModel : LlmChatViewModel(curTask = TASK_LLM_IMAGE_TO_TEXT)
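The streaming callback above folds its timing bookkeeping into one-liners; isolated, the calculation looks roughly like the sketch below. The decode-speed formula matches the one in the diff, while treating prefill speed as prefill tokens divided by time-to-first-token is an assumption, since that computation (and the rationale for the fixed 257-token charge per attached image) sits outside this excerpt.

import kotlin.math.max

// Hypothetical helper mirroring the view model's latency stats.
data class LlmRunStats(
  val timeToFirstTokenSec: Float,
  val prefillSpeed: Float, // tokens/sec, assumed definition
  val decodeSpeed: Float,  // tokens/sec, as computed in the diff
)

fun computeStats(
  startMs: Long,
  firstTokenMs: Long,
  lastTokenMs: Long,
  prefillTokens: Int,
  decodeTokens: Int,
): LlmRunStats {
  val timeToFirstTokenSec = (firstTokenMs - startMs) / 1000f
  val prefillSpeed =
    if (timeToFirstTokenSec > 0f) prefillTokens / timeToFirstTokenSec else 0f
  // Guard against a zero-length decode window.
  val decodeSec = max(1L, lastTokenMs - firstTokenMs) / 1000f
  val decodeSpeed = decodeTokens / decodeSec
  return LlmRunStats(
    timeToFirstTokenSec = timeToFirstTokenSec,
    prefillSpeed = prefillSpeed,
    decodeSpeed = if (decodeSpeed.isNaN()) 0f else decodeSpeed,
  )
}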


@@ -155,9 +155,14 @@ fun LlmSingleTurnScreen(
PromptTemplatesPanel(
model = selectedModel,
viewModel = viewModel,
+modelManagerViewModel = modelManagerViewModel,
onSend = { fullPrompt ->
viewModel.generateResponse(model = selectedModel, input = fullPrompt)
-}, modifier = Modifier.fillMaxSize()
+},
+onStopButtonClicked = { model ->
+viewModel.stopResponse(model = model)
+},
+modifier = Modifier.fillMaxSize()
)
},
bottomView = {


@@ -23,6 +23,7 @@ import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
+import com.google.aiedge.gallery.ui.common.chat.ChatMessageLoading
import com.google.aiedge.gallery.ui.common.chat.Stat
import com.google.aiedge.gallery.ui.common.processLlmResponse
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
@@ -194,6 +195,15 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_USECASES) : ViewMode
}
}
+fun stopResponse(model: Model) {
+Log.d(TAG, "Stopping response for model ${model.name}...")
+viewModelScope.launch(Dispatchers.Default) {
+setInProgress(false)
+val instance = model.instance as LlmModelInstance
+instance.session.cancelGenerateResponseAsync()
+}
+}
private fun createUiState(task: Task): LlmSingleTurnUiState {
val responsesByModel: MutableMap<String, Map<String, String>> = mutableMapOf()
val benchmarkByModel: MutableMap<String, Map<String, ChatMessageBenchmarkLlmResult>> =


@@ -43,11 +43,13 @@ import androidx.compose.material.icons.outlined.Description
import androidx.compose.material.icons.outlined.ExpandLess
import androidx.compose.material.icons.outlined.ExpandMore
import androidx.compose.material.icons.rounded.Add
+import androidx.compose.material.icons.rounded.Stop
import androidx.compose.material.icons.rounded.Visibility
import androidx.compose.material.icons.rounded.VisibilityOff
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.FilterChipDefaults
import androidx.compose.material3.Icon
+import androidx.compose.material3.IconButton
import androidx.compose.material3.IconButtonDefaults
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ModalBottomSheet
@@ -86,9 +88,11 @@ import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.ui.common.chat.MessageBubbleShape
+import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType
+import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.theme.customColors
-import kotlinx.coroutines.launch
import kotlinx.coroutines.delay
+import kotlinx.coroutines.launch
private val promptTemplateTypes: List<PromptTemplateType> = PromptTemplateType.entries
private val TAB_TITLES = PromptTemplateType.entries.map { it.label }
@@ -101,11 +105,14 @@ const val FULL_PROMPT_SWITCH_KEY = "full_prompt"
fun PromptTemplatesPanel(
model: Model,
viewModel: LlmSingleTurnViewModel,
+modelManagerViewModel: ModelManagerViewModel,
onSend: (fullPrompt: String) -> Unit,
+onStopButtonClicked: (Model) -> Unit,
modifier: Modifier = Modifier
) {
val scope = rememberCoroutineScope()
val uiState by viewModel.uiState.collectAsState()
+val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val selectedPromptTemplateType = uiState.selectedPromptTemplateType
val inProgress = uiState.inProgress
var selectedTabIndex by remember { mutableIntStateOf(0) }
@@ -123,6 +130,8 @@ fun PromptTemplatesPanel(
val focusManager = LocalFocusManager.current
val interactionSource = remember { MutableInteractionSource() }
val expandedStates = remember { mutableStateMapOf<String, Boolean>() }
+val modelInitializationStatus =
+modelManagerUiState.modelInitializationStatus[model.name]
// Update input editor values when prompt template changes.
LaunchedEffect(selectedPromptTemplateType) {
@@ -328,6 +337,25 @@ fun PromptTemplatesPanel(
)
}
+val modelInitializing =
+modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING
+if (inProgress && !modelInitializing) {
+IconButton(
+onClick = {
+onStopButtonClicked(model)
+},
+colors = IconButtonDefaults.iconButtonColors(
+containerColor = MaterialTheme.colorScheme.secondaryContainer,
+),
+modifier = Modifier.size(ICON_BUTTON_SIZE)
+) {
+Icon(
+Icons.Rounded.Stop,
+contentDescription = "",
+tint = MaterialTheme.colorScheme.primary
+)
+}
+} else {
// Send button
OutlinedIconButton(
enabled = !inProgress && curTextInputContent.isNotEmpty(),
@@ -356,6 +384,7 @@ fun PromptTemplatesPanel(
}
}
}
+}
if (showExamplePromptBottomSheet) {
ModalBottomSheet(


@@ -21,6 +21,10 @@ import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Scaffold
import androidx.compose.runtime.Composable
+import androidx.compose.runtime.LaunchedEffect
+import androidx.compose.runtime.derivedStateOf
+import androidx.compose.runtime.getValue
+import androidx.compose.runtime.remember
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.tooling.preview.Preview
@@ -48,6 +52,24 @@ fun ModelManager(
if (task.models.size != 1) {
title += "s"
}
+// Model count.
+val modelCount by remember {
+derivedStateOf {
+val trigger = task.updateTrigger.value
+if (trigger >= 0) {
+task.models.size
+} else {
+-1
+}
+}
+}
+// Navigate up when there are no models left.
+LaunchedEffect(modelCount) {
+if (modelCount == 0) {
+navigateUp()
+}
+}
// Handle system's edge swipe.
BackHandler {
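The model-count block added above pairs derivedStateOf with LaunchedEffect so the screen closes itself once its task has no models left. A generic, self-contained sketch of the same pattern, using a plain observable list instead of the app's Task and updateTrigger:

import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.getValue
import androidx.compose.runtime.remember
import androidx.compose.runtime.snapshots.SnapshotStateList

@Composable
fun AutoNavigateUpWhenEmpty(models: SnapshotStateList<String>, navigateUp: () -> Unit) {
  // Readers recompose only when the derived value (the size) actually changes.
  val modelCount by remember { derivedStateOf { models.size } }
  // Re-runs whenever modelCount changes; fires navigateUp once the list empties.
  LaunchedEffect(modelCount) {
    if (modelCount == 0) {
      navigateUp()
    }
  }
}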


@@ -38,6 +38,7 @@ import com.google.aiedge.gallery.data.ModelDownloadStatus
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.TASKS
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
+import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.TaskType
@@ -58,6 +59,7 @@ import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
+import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.json.Json
import net.openid.appauth.AuthorizationException
import net.openid.appauth.AuthorizationRequest
@@ -229,7 +231,10 @@ open class ModelManagerViewModel(
}
dataStoreRepository.saveImportedModels(importedModels = importedModels)
}
-val newUiState = uiState.value.copy(modelDownloadStatus = curModelDownloadStatus)
+val newUiState = uiState.value.copy(
+modelDownloadStatus = curModelDownloadStatus,
+tasks = uiState.value.tasks.toList()
+)
_uiState.update { newUiState }
}
@@ -312,6 +317,12 @@ open class ModelManagerViewModel(
onDone = onDone,
)
+TaskType.LLM_IMAGE_TO_TEXT -> LlmChatModelHelper.initialize(
+context = context,
+model = model,
+onDone = onDone,
+)
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.initialize(
context = context, model = model, onDone = onDone
)
@@ -331,6 +342,7 @@ open class ModelManagerViewModel(
TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.cleanUp(model = model)
TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model)
TaskType.LLM_USECASES -> LlmChatModelHelper.cleanUp(model = model)
+TaskType.LLM_IMAGE_TO_TEXT -> LlmChatModelHelper.cleanUp(model = model)
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.cleanUp(model = model)
TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {}
@@ -434,14 +446,16 @@ open class ModelManagerViewModel(
// Create model.
val model = createModelFromImportedModelInfo(info = info)
+for (task in listOf(TASK_LLM_CHAT, TASK_LLM_USECASES, TASK_LLM_IMAGE_TO_TEXT)) {
// Remove duplicated imported model if existed.
-for (task in listOf(TASK_LLM_CHAT, TASK_LLM_USECASES)) {
val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported }
if (modelIndex >= 0) {
Log.d(TAG, "duplicated imported model found in task. Removing it first")
task.models.removeAt(modelIndex)
}
+if (task == TASK_LLM_IMAGE_TO_TEXT && model.llmSupportImage || task != TASK_LLM_IMAGE_TO_TEXT) {
task.models.add(model)
+}
task.updateTrigger.value = System.currentTimeMillis()
}
@@ -632,8 +646,7 @@ open class ModelManagerViewModel(
fun loadModelAllowlist() {
_uiState.update {
uiState.value.copy(
-loadingModelAllowlist = true,
-loadingModelAllowlistError = ""
+loadingModelAllowlist = true, loadingModelAllowlistError = ""
)
}
@@ -663,6 +676,9 @@ open class ModelManagerViewModel(
if (allowedModel.taskTypes.contains(TASK_LLM_USECASES.type.id)) {
TASK_LLM_USECASES.models.add(model)
}
+if (allowedModel.taskTypes.contains(TASK_LLM_IMAGE_TO_TEXT.type.id)) {
+TASK_LLM_IMAGE_TO_TEXT.models.add(model)
+}
}
// Pre-process all tasks.
@@ -717,6 +733,9 @@ open class ModelManagerViewModel(
// Add to task.
TASK_LLM_CHAT.models.add(model)
TASK_LLM_USECASES.models.add(model)
+if (model.llmSupportImage) {
+TASK_LLM_IMAGE_TO_TEXT.models.add(model)
+}
// Update status.
modelDownloadStatus[model.name] = ModelDownloadStatus(
@@ -731,7 +750,7 @@ open class ModelManagerViewModel(
Log.d(TAG, "model download status: $modelDownloadStatus")
return ModelManagerUiState(
-tasks = TASKS,
+tasks = TASKS.toList(),
modelDownloadStatus = modelDownloadStatus,
modelInitializationStatus = modelInstances,
textInputHistory = textInputHistory,
@@ -763,6 +782,9 @@ open class ModelManagerViewModel(
) as Float,
accelerators = accelerators,
)
+val llmSupportImage = convertValueToTargetType(
+info.defaultValues[ConfigKey.SUPPORT_IMAGE.label] ?: false, ValueType.BOOLEAN
+) as Boolean
val model = Model(
name = info.fileName,
url = "",
@@ -770,7 +792,9 @@ open class ModelManagerViewModel(
sizeInBytes = info.fileSize,
downloadFileName = "$IMPORTS_DIR/${info.fileName}",
showBenchmarkButton = false,
+showRunAgainButton = false,
imported = true,
+llmSupportImage = llmSupportImage,
)
model.preProcess()
@@ -803,6 +827,7 @@ open class ModelManagerViewModel(
)
}
+@OptIn(ExperimentalSerializationApi::class)
private inline fun <reified T> getJsonResponse(url: String): T? {
try {
val connection = URL(url).openConnection() as HttpURLConnection
@@ -815,7 +840,12 @@ open class ModelManagerViewModel(
val response = inputStream.bufferedReader().use { it.readText() }
// Parse JSON using kotlinx.serialization
-val json = Json { ignoreUnknownKeys = true } // Handle potential extra fields
+val json = Json {
+// Handle potential extra fields
+ignoreUnknownKeys = true
+allowComments = true
+allowTrailingComma = true
+}
val jsonObj = json.decodeFromString<T>(response)
return jsonObj
} else {
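The lenient parser configured above tolerates unknown keys, comments, and trailing commas in the downloaded allowlist; allowComments and allowTrailingComma sit behind ExperimentalSerializationApi, which is why the @OptIn annotation is added to getJsonResponse. A small self-contained sketch of the same configuration; the AllowedModel shape and sample values here are stand-ins, not the app's actual schema:

import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.Serializable
import kotlinx.serialization.decodeFromString
import kotlinx.serialization.json.Json

@Serializable
data class AllowedModel(val name: String, val taskTypes: List<String> = emptyList())

@OptIn(ExperimentalSerializationApi::class)
fun main() {
  val json = Json {
    ignoreUnknownKeys = true   // extra fields in the allowlist are dropped
    allowComments = true       // "//" and "/* */" comments are skipped
    allowTrailingComma = true  // a comma after the last element is accepted
  }
  val response = """
    {
      // comments are tolerated
      "name": "some-model.task",
      "taskTypes": ["llm_chat", "llm_image_to_text"],
      "extraField": 42,
    }
  """.trimIndent()
  println(json.decodeFromString<AllowedModel>(response))
}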


@@ -46,6 +46,7 @@ import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION
import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
+import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION
import com.google.aiedge.gallery.data.Task
@@ -59,6 +60,8 @@ import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationDestination
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationScreen
import com.google.aiedge.gallery.ui.llmchat.LlmChatDestination
import com.google.aiedge.gallery.ui.llmchat.LlmChatScreen
+import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextDestination
+import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextScreen
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnDestination
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnScreen
import com.google.aiedge.gallery.ui.modelmanager.ModelManager
@@ -129,8 +132,7 @@ fun GalleryNavHost(
) {
val curPickedTask = pickedTask
if (curPickedTask != null) {
-ModelManager(
-viewModel = modelManagerViewModel,
+ModelManager(viewModel = modelManagerViewModel,
task = curPickedTask,
onModelClicked = { model ->
navigateToTaskScreen(
@@ -207,7 +209,7 @@ fun GalleryNavHost(
}
}
-// LLMm chat demos.
+// LLM chat demos.
composable(
route = "${LlmChatDestination.route}/{modelName}",
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
@@ -224,7 +226,7 @@ fun GalleryNavHost(
}
}
-// LLMm single turn.
+// LLM single turn.
composable(
route = "${LlmSingleTurnDestination.route}/{modelName}",
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
@@ -241,6 +243,22 @@ fun GalleryNavHost(
}
}
+// LLM image to text.
+composable(
+route = "${LlmImageToTextDestination.route}/{modelName}",
+arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
+enterTransition = { slideEnter() },
+exitTransition = { slideExit() },
+) {
+getModelFromNavigationParam(it, TASK_LLM_IMAGE_TO_TEXT)?.let { defaultModel ->
+modelManagerViewModel.selectModel(defaultModel)
+LlmImageToTextScreen(
+modelManagerViewModel = modelManagerViewModel,
+navigateUp = { navController.navigateUp() },
+)
+}
+}
}
// Handle incoming intents for deep links
@@ -254,9 +272,7 @@ fun GalleryNavHost(
getModelByName(modelName)?.let { model ->
// TODO(jingjin): need to show a list of possible tasks for this model.
navigateToTaskScreen(
-navController = navController,
-taskType = TaskType.LLM_CHAT,
-model = model
+navController = navController, taskType = TaskType.LLM_CHAT, model = model
)
}
}
@@ -271,6 +287,7 @@ fun navigateToTaskScreen(
TaskType.TEXT_CLASSIFICATION -> navController.navigate("${TextClassificationDestination.route}/${modelName}")
TaskType.IMAGE_CLASSIFICATION -> navController.navigate("${ImageClassificationDestination.route}/${modelName}")
TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
+TaskType.LLM_IMAGE_TO_TEXT -> navController.navigate("${LlmImageToTextDestination.route}/${modelName}")
TaskType.LLM_USECASES -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
TaskType.IMAGE_GENERATION -> navController.navigate("${ImageGenerationDestination.route}/${modelName}")
TaskType.TEST_TASK_1 -> {}


@@ -44,7 +44,8 @@ fun TextClassificationScreen(
task = viewModel.task,
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
-onSendMessage = { model, message ->
+onSendMessage = { model, messages ->
+val message = messages[0]
viewModel.addMessage(
model = model,
message = message,


@@ -7,7 +7,7 @@ junitVersion = "1.2.1"
espressoCore = "3.6.1"
lifecycleRuntimeKtx = "2.8.7"
activityCompose = "1.10.1"
-composeBom = "2025.03.01"
+composeBom = "2025.05.00"
navigation = "2.8.9"
serializationPlugin = "2.0.21"
serializationJson = "1.7.3"