diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt
index 2fe0b03..508e9f3 100644
--- a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt
+++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt
@@ -19,7 +19,6 @@ package com.google.aiedge.gallery.data
 import android.content.Context
 import com.google.aiedge.gallery.ui.common.chat.PromptTemplate
 import com.google.aiedge.gallery.ui.common.convertValueToTargetType
-import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
 import java.io.File
 
 data class ModelDataFile(
@@ -91,6 +90,9 @@ data class Model(
   /** The prompt templates for the model (only for LLM). */
   val llmPromptTemplates: List<PromptTemplate> = listOf(),
 
+  /** Whether the LLM model supports image input. */
+  val llmSupportImage: Boolean = false,
+
   /** Whether the model is imported or not. */
   val imported: Boolean = false,
 
@@ -204,6 +206,7 @@ enum class ConfigKey(val label: String) {
   DEFAULT_TOPK("Default TopK"),
   DEFAULT_TOPP("Default TopP"),
   DEFAULT_TEMPERATURE("Default temperature"),
+  SUPPORT_IMAGE("Support image"),
   MAX_RESULT_COUNT("Max result count"),
   USE_GPU("Use GPU"),
   ACCELERATOR("Accelerator"),
@@ -250,83 +253,9 @@ const val IMAGE_CLASSIFICATION_INFO = ""
 const val IMAGE_CLASSIFICATION_LEARN_MORE_URL = "https://ai.google.dev/edge/litert/android"
 
-const val LLM_CHAT_INFO =
-  "Some description about this large language model. A community org for developers to discover models that are ready for deployment to edge platforms"
-
 const val IMAGE_GENERATION_INFO =
   "Powered by [MediaPipe Image Generation API](https://ai.google.dev/edge/mediapipe/solutions/vision/image_generator/android)"
 
-////////////////////////////////////////////////////////////////////////////////////////////////////
-// Model spec.
-
-val MODEL_LLM_GEMMA_2B_GPU_INT4: Model = Model(
-  name = "Gemma 2B (GPU int4)",
-  downloadFileName = "gemma-2b-it-gpu-int4.bin",
-  url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma-2b-it-gpu-int4.bin",
-  sizeInBytes = 1354301440L,
-  configs = createLlmChatConfigs(
-    accelerators = listOf(Accelerator.GPU)
-  ),
-  showBenchmarkButton = false,
-  info = LLM_CHAT_INFO,
-  learnMoreUrl = "https://huggingface.co/litert-community",
-)
-
-val MODEL_LLM_GEMMA_2_2B_GPU_INT8: Model = Model(
-  name = "Gemma 2 2B (GPU int8)",
-  downloadFileName = "gemma2-2b-it-gpu-int8.bin",
-  url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma2-2b-it-gpu-int8.bin",
-  sizeInBytes = 2627141632L,
-  configs = createLlmChatConfigs(
-    accelerators = listOf(Accelerator.GPU)
-  ),
-  showBenchmarkButton = false,
-  info = LLM_CHAT_INFO,
-  learnMoreUrl = "https://huggingface.co/litert-community",
-)
-
-val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model(
-  name = "Gemma 3 1B (int4)",
-  downloadFileName = "gemma3-1b-it-int4.task",
-  url = "https://huggingface.co/litert-community/Gemma3-1B-IT/resolve/main/gemma3-1b-it-int4.task?download=true",
-  sizeInBytes = 554661243L,
-  configs = createLlmChatConfigs(
-    defaultTopK = 64,
-    defaultTopP = 0.95f,
-    accelerators = listOf(Accelerator.CPU, Accelerator.GPU)
-  ),
-  showBenchmarkButton = false,
-  info = LLM_CHAT_INFO,
-  learnMoreUrl = "https://huggingface.co/litert-community/Gemma3-1B-IT",
-  llmPromptTemplates = listOf(
-    PromptTemplate(
-      title = "Emoji Fun",
-      description = "Generate emojis by emotions",
-      prompt = "Show me emojis grouped by emotions"
-    ),
-    PromptTemplate(
-      title = "Trip Planner",
-      description = "Plan a trip to a destination",
-      prompt = "Plan a two-day trip to San Francisco"
-    ),
-  )
-)
-
-val MODEL_LLM_DEEPSEEK: Model = Model(
-  name = "Deepseek",
-  downloadFileName = "deepseek.task",
-  url = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B/resolve/main/deepseek_q8_ekv1280.task?download=true",
-  sizeInBytes = 1860686856L,
-  configs = createLlmChatConfigs(
-    defaultTemperature = 0.6f,
-    defaultTopK = 40,
-    defaultTopP = 0.7f,
-    accelerators = listOf(Accelerator.CPU)
-  ),
-  showBenchmarkButton = false,
-  info = LLM_CHAT_INFO,
-  learnMoreUrl = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B",
-)
 
 val MODEL_TEXT_CLASSIFICATION_MOBILEBERT: Model = Model(
   name = "MobileBert",
@@ -414,12 +343,5 @@ val MODELS_IMAGE_CLASSIFICATION: MutableList<Model> = mutableListOf(
   MODEL_IMAGE_CLASSIFICATION_MOBILENET_V2,
 )
 
-val MODELS_LLM: MutableList<Model> = mutableListOf(
-// MODEL_LLM_GEMMA_2B_GPU_INT4,
-// MODEL_LLM_GEMMA_2_2B_GPU_INT8,
-  MODEL_LLM_GEMMA_3_1B_INT4,
-  MODEL_LLM_DEEPSEEK,
-)
-
 val MODELS_IMAGE_GENERATION: MutableList<Model> =
   mutableListOf(MODEL_IMAGE_GENERATION_STABLE_DIFFUSION)
diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/ModelAllowlist.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/ModelAllowlist.kt
index 1ddfd0f..5d301e9 100644
--- a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/ModelAllowlist.kt
+++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/ModelAllowlist.kt
@@ -35,6 +35,7 @@ data class AllowedModel(
   val defaultConfig: Map<String, Any>,
   val taskTypes: List<String>,
   val disabled: Boolean? = null,
+  val llmSupportImage: Boolean? = null,
 ) {
   fun toModel(): Model {
     // Construct HF download url.
@@ -48,7 +49,7 @@ data class AllowedModel(
     var defaultTopK: Int = DEFAULT_TOPK
     var defaultTopP: Float = DEFAULT_TOPP
     var defaultTemperature: Float = DEFAULT_TEMPERATURE
-    var defaultMaxToken: Int = 1024
+    var defaultMaxToken = 1024
     var accelerators: List<Accelerator> = DEFAULT_ACCELERATORS
     if (defaultConfig.containsKey("topK")) {
       defaultTopK = getIntConfigValue(defaultConfig["topK"], defaultTopK)
@@ -84,9 +85,10 @@ data class AllowedModel(
 
     // Misc.
     var showBenchmarkButton = true
-    val showRunAgainButton = true
-    if (taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_USECASES.type.id)) {
+    var showRunAgainButton = true
+    if (isLlmModel) {
       showBenchmarkButton = false
+      showRunAgainButton = false
     }
 
     return Model(
@@ -99,7 +101,8 @@ data class AllowedModel(
       downloadFileName = modelFile,
       showBenchmarkButton = showBenchmarkButton,
       showRunAgainButton = showRunAgainButton,
-      learnMoreUrl = "https://huggingface.co/${modelId}"
+      learnMoreUrl = "https://huggingface.co/${modelId}",
+      llmSupportImage = llmSupportImage == true,
     )
   }
 
diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Tasks.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Tasks.kt
index e50fb47..6218046 100644
--- a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Tasks.kt
+++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Tasks.kt
@@ -19,6 +19,7 @@ package com.google.aiedge.gallery.data
 import androidx.annotation.StringRes
 import androidx.compose.material.icons.Icons
 import androidx.compose.material.icons.outlined.Forum
+import androidx.compose.material.icons.outlined.Mms
 import androidx.compose.material.icons.outlined.Widgets
 import androidx.compose.material.icons.rounded.ImageSearch
 import androidx.compose.runtime.MutableState
@@ -33,6 +34,7 @@ enum class TaskType(val label: String, val id: String) {
   IMAGE_GENERATION(label = "Image Generation", id = "image_generation"),
   LLM_CHAT(label = "LLM Chat", id = "llm_chat"),
   LLM_USECASES(label = "LLM Use Cases", id = "llm_usecases"),
+  LLM_IMAGE_TO_TEXT(label = "LLM Image to Text", id = "llm_image_to_text"),
 
   TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
   TEST_TASK_2(label = "Test task 2", id = "test_task_2")
@@ -93,7 +95,7 @@ val TASK_LLM_CHAT = Task(
   icon = Icons.Outlined.Forum,
// models = MODELS_LLM,
   models = mutableListOf(),
-  description = "Chat with a on-device large language model",
+  description = "Chat with on-device large language models",
   docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
   sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
   textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
@@ -110,6 +112,17 @@ val TASK_LLM_USECASES = Task(
   textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
 )
 
+val TASK_LLM_IMAGE_TO_TEXT = Task(
+  type = TaskType.LLM_IMAGE_TO_TEXT,
+  icon = Icons.Outlined.Mms,
+// models = MODELS_LLM,
+  models = mutableListOf(),
+  description = "Ask questions about images with on-device large language models",
+  docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
+  sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
+  textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
+)
+
 val TASK_IMAGE_GENERATION = Task(
   type = TaskType.IMAGE_GENERATION,
iconVectorResourceId = R.drawable.image_spark, @@ -127,6 +140,7 @@ val TASKS: List = listOf( // TASK_IMAGE_GENERATION, TASK_LLM_USECASES, TASK_LLM_CHAT, + TASK_LLM_IMAGE_TO_TEXT ) fun getModelByName(name: String): Model? { diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/ViewModelProvider.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/ViewModelProvider.kt index a1a4739..bc3d4ab 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/ViewModelProvider.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/ViewModelProvider.kt @@ -25,6 +25,7 @@ import com.google.aiedge.gallery.GalleryApplication import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationViewModel import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationViewModel import com.google.aiedge.gallery.ui.llmchat.LlmChatViewModel +import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextViewModel import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.textclassification.TextClassificationViewModel @@ -62,6 +63,12 @@ object ViewModelProvider { LlmSingleTurnViewModel() } + // Initializer for LlmImageToTextViewModel. + initializer { + LlmImageToTextViewModel() + } + + // Initializer for ImageGenerationViewModel. initializer { ImageGenerationViewModel() } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPageAppBar.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPageAppBar.kt index 1d940db..c9b17e3 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPageAppBar.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPageAppBar.kt @@ -17,13 +17,17 @@ package com.google.aiedge.gallery.ui.common import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.offset import androidx.compose.foundation.layout.size import androidx.compose.material.icons.Icons import androidx.compose.material.icons.automirrored.rounded.ArrowBack -import androidx.compose.material.icons.rounded.Settings +import androidx.compose.material.icons.rounded.MapsUgc +import androidx.compose.material.icons.rounded.Tune import androidx.compose.material3.CenterAlignedTopAppBar +import androidx.compose.material3.CircularProgressIndicator import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.Icon import androidx.compose.material3.IconButton @@ -38,6 +42,7 @@ import androidx.compose.runtime.setValue import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.draw.alpha +import androidx.compose.ui.draw.scale import androidx.compose.ui.graphics.vector.ImageVector import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.res.vectorResource @@ -47,6 +52,7 @@ import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.ModelDownloadStatusType import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.ui.common.chat.ConfigDialog +import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel @OptIn(ExperimentalMaterial3Api::class) @@ -58,12 +64,17 @@ fun 
ModelPageAppBar(
   onBackClicked: () -> Unit,
   onModelSelected: (Model) -> Unit,
   modifier: Modifier = Modifier,
+  isResettingSession: Boolean = false,
+  onResetSessionClicked: (Model) -> Unit = {},
+  showResetSessionButton: Boolean = false,
   onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit = { _, _ -> },
 ) {
   var showConfigDialog by remember { mutableStateOf(false) }
   val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
   val context = LocalContext.current
   val curDownloadStatus = modelManagerUiState.modelDownloadStatus[model.name]
+  val modelInitializationStatus =
+    modelManagerUiState.modelInitializationStatus[model.name]
 
   CenterAlignedTopAppBar(title = {
     Column(
@@ -110,19 +121,58 @@ fun ModelPageAppBar(
   actions = {
     val showConfigButton =
       model.configs.isNotEmpty() && curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED
-    IconButton(
-      onClick = {
-        showConfigDialog = true
-      },
-      enabled = showConfigButton,
-      modifier = Modifier.alpha(if (showConfigButton) 1f else 0f)
-    ) {
-      Icon(
-        imageVector = Icons.Rounded.Settings,
-        contentDescription = "",
-        tint = MaterialTheme.colorScheme.primary
-      )
+    Box(modifier = Modifier.size(42.dp), contentAlignment = Alignment.Center) {
+      var configButtonOffset = 0.dp
+      if (showConfigButton && showResetSessionButton) {
+        configButtonOffset = (-40).dp
+      }
+      val isModelInitializing =
+        modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING
+      if (showConfigButton) {
+        IconButton(
+          onClick = {
+            showConfigDialog = true
+          },
+          enabled = !isModelInitializing,
+          modifier = Modifier
+            .scale(0.75f)
+            .offset(x = configButtonOffset)
+            .alpha(if (isModelInitializing) 0.5f else 1f)
+        ) {
+          Icon(
+            imageVector = Icons.Rounded.Tune,
+            contentDescription = "",
+            tint = MaterialTheme.colorScheme.primary
+          )
+        }
+      }
+      if (showResetSessionButton) {
+        if (isResettingSession) {
+          CircularProgressIndicator(
+            trackColor = MaterialTheme.colorScheme.surfaceVariant,
+            strokeWidth = 2.dp,
+            modifier = Modifier.size(16.dp)
+          )
+        } else {
+          IconButton(
+            onClick = {
+              onResetSessionClicked(model)
+            },
+            enabled = !isModelInitializing,
+            modifier = Modifier
+              .scale(0.75f)
+              .alpha(if (isModelInitializing) 0.5f else 1f)
+          ) {
+            Icon(
+              imageVector = Icons.Rounded.MapsUgc,
+              contentDescription = "",
+              tint = MaterialTheme.colorScheme.primary
+            )
+          }
+        }
+      }
     }
   })
 
   // Config dialog.
diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPickerChipsPager.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPickerChipsPager.kt
index 8c82b58..b4636f5 100644
--- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPickerChipsPager.kt
+++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPickerChipsPager.kt
@@ -29,6 +29,7 @@ import androidx.compose.foundation.layout.Row
 import androidx.compose.foundation.layout.fillMaxWidth
 import androidx.compose.foundation.layout.padding
 import androidx.compose.foundation.layout.size
+import androidx.compose.foundation.layout.widthIn
 import androidx.compose.foundation.pager.HorizontalPager
 import androidx.compose.foundation.pager.rememberPagerState
 import androidx.compose.foundation.shape.CircleShape
@@ -54,6 +55,9 @@ import androidx.compose.ui.Modifier
 import androidx.compose.ui.draw.alpha
 import androidx.compose.ui.draw.clip
 import androidx.compose.ui.graphics.graphicsLayer
+import androidx.compose.ui.platform.LocalDensity
+import androidx.compose.ui.platform.LocalWindowInfo
+import androidx.compose.ui.text.style.TextOverflow
 import androidx.compose.ui.unit.dp
 import com.google.aiedge.gallery.data.Model
 import com.google.aiedge.gallery.data.Task
@@ -77,6 +81,10 @@ fun ModelPickerChipsPager(
   val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
   val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
   val scope = rememberCoroutineScope()
+  val density = LocalDensity.current
+  val windowInfo = LocalWindowInfo.current
+  val screenWidthDp =
+    remember { with(density) { windowInfo.containerSize.width.toDp() } }
   val pagerState =
     rememberPagerState(initialPage = task.models.indexOf(initialModel),
       pageCount = { task.models.size })
@@ -140,7 +148,12 @@ fun ModelPickerChipsPager(
             Text(
               model.name,
               style = MaterialTheme.typography.labelSmall,
-              modifier = Modifier.padding(start = 4.dp),
+              modifier = Modifier
+                .padding(start = 4.dp)
+                .widthIn(0.dp, screenWidthDp - 250.dp),
+              maxLines = 1,
+              overflow = TextOverflow.MiddleEllipsis
+            )
             Icon(
               Icons.Rounded.ArrowDropDown,
diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt
index 20c967d..ea1c650 100644
--- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt
+++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt
@@ -109,7 +109,7 @@ fun ChatPanel(
   task: Task,
   selectedModel: Model,
   viewModel: ChatViewModel,
-  onSendMessage: (Model, ChatMessage) -> Unit,
+  onSendMessage: (Model, List<ChatMessage>) -> Unit,
   onRunAgainClicked: (Model, ChatMessage) -> Unit,
   onBenchmarkClicked: (Model, ChatMessage, warmUpIterations: Int, benchmarkIterations: Int) -> Unit,
   navigateUp: () -> Unit,
@@ -280,7 +280,8 @@ fun ChatPanel(
         task = task,
         onPromptClicked = { template ->
           onSendMessage(
-            selectedModel, ChatMessageText(content = template.prompt, side = ChatSide.USER)
+            selectedModel,
+            listOf(ChatMessageText(content = template.prompt, side = ChatSide.USER))
           )
         })
 
@@ -430,12 +431,15 @@ fun ChatPanel(
       // Chat input
       when (chatInputType) {
         ChatInputType.TEXT -> {
-          val isLlmTask = task.type == TaskType.LLM_CHAT
-          val notLlmStartScreen = !(messages.size == 1 && messages[0] is ChatMessagePromptTemplates)
+//          val isLlmTask = task.type == TaskType.LLM_CHAT
+//          val notLlmStartScreen = !(messages.size == 1 && messages[0] is ChatMessagePromptTemplates)
+          val hasImageMessage = messages.any { it is ChatMessageImage }
           MessageInputText(
             modelManagerViewModel = modelManagerViewModel,
             curMessage = curMessage,
             inProgress = uiState.inProgress,
+            isResettingSession = uiState.isResettingSession,
+            hasImageMessage = hasImageMessage,
             modelInitializing = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
             textFieldPlaceHolderRes = task.textInputPlaceHolderRes,
             onValueChanged = { curMessage = it },
@@ -445,13 +449,17 @@ fun ChatPanel(
             },
             onOpenPromptTemplatesClicked = {
               onSendMessage(
-                selectedModel, ChatMessagePromptTemplates(
-                  templates = selectedModel.llmPromptTemplates, showMakeYourOwn = false
+                selectedModel, listOf(
+                  ChatMessagePromptTemplates(
+                    templates = selectedModel.llmPromptTemplates, showMakeYourOwn = false
+                  )
                 )
               )
             },
             onStopButtonClicked = onStopButtonClicked,
-            showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen,
+//            showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen,
+            showPromptTemplatesInMenu = false,
+            showImagePickerInMenu = selectedModel.llmSupportImage == true,
             showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress,
           )
         }
@@ -461,8 +469,10 @@ fun ChatPanel(
         streamingMessage = streamingMessage,
         onImageSelected = { bitmap ->
           onSendMessage(
-            selectedModel, ChatMessageImage(
-              bitmap = bitmap, imageBitMap = bitmap.asImageBitmap(), side = ChatSide.USER
+            selectedModel, listOf(
+              ChatMessageImage(
+                bitmap = bitmap, imageBitMap = bitmap.asImageBitmap(), side = ChatSide.USER
+              )
             )
           )
         },
diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatView.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatView.kt
index 60992de..91bb5f7 100644
--- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatView.kt
+++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatView.kt
@@ -70,16 +70,18 @@ fun ChatView(
   task: Task,
   viewModel: ChatViewModel,
   modelManagerViewModel: ModelManagerViewModel,
-  onSendMessage: (Model, ChatMessage) -> Unit,
+  onSendMessage: (Model, List<ChatMessage>) -> Unit,
   onRunAgainClicked: (Model, ChatMessage) -> Unit,
   onBenchmarkClicked: (Model, ChatMessage, Int, Int) -> Unit,
   navigateUp: () -> Unit,
   modifier: Modifier = Modifier,
+  onResetSessionClicked: (Model) -> Unit = {},
   onStreamImageMessage: (Model, ChatMessageImage) -> Unit = { _, _ -> },
   onStopButtonClicked: (Model) -> Unit = {},
   chatInputType: ChatInputType = ChatInputType.TEXT,
   showStopButtonInInputWhenInProgress: Boolean = false,
 ) {
+  val uiStat by viewModel.uiState.collectAsState()
   val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
   val selectedModel = modelManagerUiState.selectedModel
 
@@ -155,6 +157,9 @@ fun ChatView(
       task = task,
       model = selectedModel,
       modelManagerViewModel = modelManagerViewModel,
+      showResetSessionButton = true,
+      isResettingSession = uiStat.isResettingSession,
+      onResetSessionClicked = onResetSessionClicked,
       onConfigChanged = { old, new ->
         viewModel.addConfigChangedMessage(
           oldConfigValues = old,
diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatViewModel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatViewModel.kt
index 34f475e..c9a2cff 100644
--- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatViewModel.kt
+++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatViewModel.kt
@@ -33,6 +33,11 @@ data class ChatUiState(
   */
  val inProgress: Boolean = false,
 
+  /**
+   * Indicates whether the session is being reset.
+   */
+  val isResettingSession: Boolean = false,
+
  /**
   * A map of model names to lists of chat messages.
   */
@@ -106,6 +111,12 @@ open class ChatViewModel(val task: Task) : ViewModel() {
     _uiState.update { _uiState.value.copy(messagesByModel = newMessagesByModel) }
   }
 
+  fun clearAllMessages(model: Model) {
+    val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
+    newMessagesByModel[model.name] = mutableListOf()
+    _uiState.update { _uiState.value.copy(messagesByModel = newMessagesByModel) }
+  }
+
   fun getLastMessage(model: Model): ChatMessage? {
     return (_uiState.value.messagesByModel[model.name] ?: listOf()).lastOrNull()
   }
@@ -189,6 +200,10 @@ open class ChatViewModel(val task: Task) : ViewModel() {
     _uiState.update { _uiState.value.copy(inProgress = inProgress) }
   }
 
+  fun setIsResettingSession(isResettingSession: Boolean) {
+    _uiState.update { _uiState.value.copy(isResettingSession = isResettingSession) }
+  }
+
   fun addConfigChangedMessage(
     oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>, model: Model
   ) {
diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MarkdownText.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MarkdownText.kt
index fc0d315..265d93a 100644
--- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MarkdownText.kt
+++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MarkdownText.kt
@@ -21,14 +21,18 @@ import androidx.compose.material3.ProvideTextStyle
 import androidx.compose.runtime.Composable
 import androidx.compose.runtime.CompositionLocalProvider
 import androidx.compose.ui.Modifier
+import androidx.compose.ui.text.SpanStyle
+import androidx.compose.ui.text.TextLinkStyles
 import androidx.compose.ui.text.TextStyle
 import androidx.compose.ui.text.font.FontFamily
 import androidx.compose.ui.tooling.preview.Preview
 import com.google.aiedge.gallery.ui.theme.GalleryTheme
+import com.google.aiedge.gallery.ui.theme.customColors
 import com.halilibo.richtext.commonmark.Markdown
 import com.halilibo.richtext.ui.CodeBlockStyle
 import com.halilibo.richtext.ui.RichTextStyle
 import com.halilibo.richtext.ui.material3.RichText
+import com.halilibo.richtext.ui.string.RichTextStringStyle
 
 /**
  * Composable function to display Markdown-formatted text.
@@ -56,6 +60,11 @@ fun MarkdownText( fontSize = MaterialTheme.typography.bodySmall.fontSize, fontFamily = FontFamily.Monospace, ) + ), + stringStyle = RichTextStringStyle( + linkStyle = TextLinkStyles( + style = SpanStyle(color = MaterialTheme.customColors.linkColor) + ) ) ), ) { diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageBodyWarning.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageBodyWarning.kt index 1c79ad2..b2039c1 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageBodyWarning.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageBodyWarning.kt @@ -46,7 +46,11 @@ fun MessageBodyWarning(message: ChatMessageWarning) { .clip(RoundedCornerShape(16.dp)) .background(MaterialTheme.colorScheme.tertiaryContainer) ) { - MarkdownText(text = message.content, modifier = Modifier.padding(12.dp), smallFontSize = true) + MarkdownText( + text = message.content, + modifier = Modifier.padding(horizontal = 16.dp, vertical = 6.dp), + smallFontSize = true + ) } } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageInputText.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageInputText.kt index 2029882..a24c34f 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageInputText.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageInputText.kt @@ -16,28 +16,45 @@ package com.google.aiedge.gallery.ui.common.chat +import android.Manifest +import android.content.Context +import android.content.pm.PackageManager +import android.graphics.Bitmap +import android.graphics.BitmapFactory +import android.graphics.Matrix +import android.net.Uri +import androidx.activity.compose.rememberLauncherForActivityResult +import androidx.activity.result.PickVisualMediaRequest +import androidx.activity.result.contract.ActivityResultContracts import androidx.annotation.StringRes +import androidx.compose.foundation.Image +import androidx.compose.foundation.background import androidx.compose.foundation.border +import androidx.compose.foundation.clickable import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.height import androidx.compose.foundation.layout.offset import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.size import androidx.compose.foundation.layout.width +import androidx.compose.foundation.shape.CircleShape import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.material.icons.Icons import androidx.compose.material.icons.automirrored.rounded.Send import androidx.compose.material.icons.rounded.Add +import androidx.compose.material.icons.rounded.Close import androidx.compose.material.icons.rounded.History +import androidx.compose.material.icons.rounded.Photo +import androidx.compose.material.icons.rounded.PhotoCamera import androidx.compose.material.icons.rounded.PostAdd import androidx.compose.material.icons.rounded.Stop import androidx.compose.material3.DropdownMenu import androidx.compose.material3.DropdownMenuItem -import androidx.compose.material3.ExperimentalMaterial3Api import 
androidx.compose.material3.Icon
 import androidx.compose.material3.IconButton
 import androidx.compose.material3.IconButtonDefaults
@@ -54,12 +71,17 @@ import androidx.compose.runtime.setValue
 import androidx.compose.ui.Alignment
 import androidx.compose.ui.Modifier
 import androidx.compose.ui.draw.alpha
+import androidx.compose.ui.draw.clip
+import androidx.compose.ui.draw.shadow
 import androidx.compose.ui.graphics.Color
+import androidx.compose.ui.graphics.asImageBitmap
 import androidx.compose.ui.platform.LocalContext
 import androidx.compose.ui.res.stringResource
 import androidx.compose.ui.tooling.preview.Preview
 import androidx.compose.ui.unit.dp
+import androidx.core.content.ContextCompat
 import com.google.aiedge.gallery.R
+import com.google.aiedge.gallery.ui.common.createTempPictureUri
 import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
 import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
 import com.google.aiedge.gallery.ui.theme.GalleryTheme
@@ -74,24 +96,107 @@ fun MessageInputText(
   modelManagerViewModel: ModelManagerViewModel,
   curMessage: String,
+  isResettingSession: Boolean,
   inProgress: Boolean,
+  hasImageMessage: Boolean,
   modelInitializing: Boolean,
   @StringRes textFieldPlaceHolderRes: Int,
   onValueChanged: (String) -> Unit,
-  onSendMessage: (ChatMessage) -> Unit,
+  onSendMessage: (List<ChatMessage>) -> Unit,
   onOpenPromptTemplatesClicked: () -> Unit = {},
   onStopButtonClicked: () -> Unit = {},
-  showPromptTemplatesInMenu: Boolean = true,
+  showPromptTemplatesInMenu: Boolean = false,
+  showImagePickerInMenu: Boolean = false,
   showStopButtonWhenInProgress: Boolean = false,
 ) {
+  val context = LocalContext.current
   val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
   var showAddContentMenu by remember { mutableStateOf(false) }
   var showTextInputHistorySheet by remember { mutableStateOf(false) }
+  var tempPhotoUri by remember { mutableStateOf(value = Uri.EMPTY) }
+  var pickedImages by remember { mutableStateOf<List<Bitmap>>(listOf()) }
+  val updatePickedImages: (Bitmap) -> Unit = { bitmap ->
+    val newPickedImages: MutableList<Bitmap> = mutableListOf()
+    newPickedImages.addAll(pickedImages)
+    newPickedImages.add(bitmap)
+    pickedImages = newPickedImages.toList()
+  }
+
+  // launches camera
+  val cameraLauncher =
+    rememberLauncherForActivityResult(ActivityResultContracts.TakePicture()) { isImageSaved ->
+      if (isImageSaved) {
+        handleImageSelected(
+          context = context,
+          uri = tempPhotoUri,
+          onImageSelected = { bitmap ->
+            updatePickedImages(bitmap)
+          },
+          rotateForPortrait = true,
+        )
+      }
+    }
+
+  // Permission request when taking picture.
+  val takePicturePermissionLauncher = rememberLauncherForActivityResult(
+    ActivityResultContracts.RequestPermission()
+  ) { permissionGranted ->
+    if (permissionGranted) {
+      showAddContentMenu = false
+      tempPhotoUri = context.createTempPictureUri()
+      cameraLauncher.launch(tempPhotoUri)
+    }
+  }
+
+  // Registers a photo picker activity launcher in single-select mode.
+  val pickMedia =
+    rememberLauncherForActivityResult(ActivityResultContracts.PickVisualMedia()) { uri ->
+      // Callback is invoked after the user selects a media item or closes the
+      // photo picker.
+      if (uri != null) {
+        handleImageSelected(context = context, uri = uri, onImageSelected = { bitmap ->
+          updatePickedImages(bitmap)
+        })
+      }
+    }
 
   Box(contentAlignment = Alignment.CenterStart) {
+    // A preview panel for the selected image.
+ if (pickedImages.isNotEmpty()) { + Box( + contentAlignment = Alignment.TopEnd, modifier = Modifier.offset(x = 16.dp, y = (-80).dp) + ) { + Image( + bitmap = pickedImages.last().asImageBitmap(), + contentDescription = "", + modifier = Modifier + .height(80.dp) + .shadow(2.dp, shape = RoundedCornerShape(8.dp)) + .clip(RoundedCornerShape(8.dp)) + .border(1.dp, MaterialTheme.colorScheme.outlineVariant, RoundedCornerShape(8.dp)), + ) + Box(modifier = Modifier + .offset(x = 10.dp, y = (-10).dp) + .clip(CircleShape) + .background(MaterialTheme.colorScheme.surface) + .border((1.5).dp, MaterialTheme.colorScheme.outline, CircleShape) + .clickable { + pickedImages = listOf() + }) { + Icon( + Icons.Rounded.Close, + contentDescription = "", + modifier = Modifier + .padding(3.dp) + .size(16.dp) + ) + } + } + } + // A plus button to show a popup menu to add stuff to the chat. IconButton( - enabled = !inProgress, + enabled = !inProgress && !isResettingSession, onClick = { showAddContentMenu = true }, modifier = Modifier .offset(x = 16.dp) @@ -113,6 +218,58 @@ fun MessageInputText( DropdownMenu( expanded = showAddContentMenu, onDismissRequest = { showAddContentMenu = false }) { + if (showImagePickerInMenu) { + // Take a photo. + DropdownMenuItem( + text = { + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(6.dp) + ) { + Icon(Icons.Rounded.PhotoCamera, contentDescription = "") + Text("Take a photo") + } + }, + enabled = pickedImages.isEmpty() && !hasImageMessage, + onClick = { + // Check permission + when (PackageManager.PERMISSION_GRANTED) { + // Already got permission. Call the lambda. + ContextCompat.checkSelfPermission( + context, Manifest.permission.CAMERA + ) -> { + showAddContentMenu = false + tempPhotoUri = context.createTempPictureUri() + cameraLauncher.launch(tempPhotoUri) + } + + // Otherwise, ask for permission + else -> { + takePicturePermissionLauncher.launch(Manifest.permission.CAMERA) + } + } + }) + + // Pick an image from album. + DropdownMenuItem( + text = { + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(6.dp) + ) { + Icon(Icons.Rounded.Photo, contentDescription = "") + Text("Pick from album") + } + }, + enabled = pickedImages.isEmpty() && !hasImageMessage, + onClick = { + // Launch the photo picker and let the user choose only images. + pickMedia.launch(PickVisualMediaRequest(ActivityResultContracts.PickVisualMedia.ImageOnly)) + showAddContentMenu = false + }) + } + + // Prompt templates. if (showPromptTemplatesInMenu) { DropdownMenuItem(text = { Row( @@ -127,6 +284,7 @@ fun MessageInputText( showAddContentMenu = false }) } + // Prompt history. DropdownMenuItem(text = { Row( verticalAlignment = Alignment.CenterVertically, @@ -171,18 +329,19 @@ fun MessageInputText( ), ) { Icon( - Icons.Rounded.Stop, - contentDescription = "", - tint = MaterialTheme.colorScheme.primary + Icons.Rounded.Stop, contentDescription = "", tint = MaterialTheme.colorScheme.primary ) } } } // Send button. Only shown when text is not empty. 
    else if (curMessage.isNotEmpty()) {
      IconButton(
-        enabled = !inProgress,
+        enabled = !inProgress && !isResettingSession,
        onClick = {
-          onSendMessage(ChatMessageText(content = curMessage.trim(), side = ChatSide.USER))
+          onSendMessage(
+            createMessagesToSend(pickedImages = pickedImages, text = curMessage.trim())
+          )
+          pickedImages = listOf()
        },
        colors = IconButtonDefaults.iconButtonColors(
          containerColor = MaterialTheme.colorScheme.secondaryContainer,
@@ -200,7 +359,6 @@ fun MessageInputText(
     }
   }
-
   // A bottom sheet to show the text input history to pick from.
   if (showTextInputHistorySheet) {
     TextInputHistorySheet(
@@ -209,7 +367,8 @@ fun MessageInputText(
        showTextInputHistorySheet = false
      },
      onHistoryItemClicked = { item ->
-        onSendMessage(ChatMessageText(content = item, side = ChatSide.USER))
+        onSendMessage(createMessagesToSend(pickedImages = pickedImages, text = item))
+        pickedImages = listOf()
        modelManagerViewModel.promoteTextInputHistoryItem(item)
      },
      onHistoryItemDeleted = { item ->
@@ -217,9 +376,55 @@ fun MessageInputText(
      },
      onHistoryItemsDeleteAll = {
        modelManagerViewModel.clearTextInputHistory()
-      }
+      })
+  }
+}
+
+private fun handleImageSelected(
+  context: Context,
+  uri: Uri,
+  onImageSelected: (Bitmap) -> Unit,
+  // For some reason, some Android phone would store the picture taken by the camera rotated
+  // horizontally. Use this flag to rotate the image back to portrait if the picture's width
+  // is bigger than height.
+  rotateForPortrait: Boolean = false,
+) {
+  val bitmap: Bitmap? = try {
+    val inputStream = context.contentResolver.openInputStream(uri)
+    val tmpBitmap = BitmapFactory.decodeStream(inputStream)
+    if (rotateForPortrait && tmpBitmap.width > tmpBitmap.height) {
+      val matrix = Matrix()
+      matrix.postRotate(90f)
+      Bitmap.createBitmap(tmpBitmap, 0, 0, tmpBitmap.width, tmpBitmap.height, matrix, true)
+    } else {
+      tmpBitmap
+    }
+  } catch (e: Exception) {
+    e.printStackTrace()
+    null
+  }
+  if (bitmap != null) {
+    onImageSelected(bitmap)
+  }
+}
+
+private fun createMessagesToSend(pickedImages: List<Bitmap>, text: String): List<ChatMessage> {
+  val messages: MutableList<ChatMessage> = mutableListOf()
+  if (pickedImages.isNotEmpty()) {
+    val lastImage = pickedImages.last()
+    messages.add(
+      ChatMessageImage(
+        bitmap = lastImage, imageBitMap = lastImage.asImageBitmap(), side = ChatSide.USER
+      )
+    )
+  }
+  messages.add(
+    ChatMessageText(
+      content = text, side = ChatSide.USER
+    )
+  )
+
+  return messages
 }
 
 @Preview(showBackground = true)
@@ -233,6 +438,21 @@ fun MessageInputTextPreview() {
      modelManagerViewModel = PreviewModelManagerViewModel(context = context),
      curMessage = "hello",
      inProgress = false,
+      isResettingSession = false,
+      modelInitializing = false,
+      hasImageMessage = false,
+      textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
+      onValueChanged = {},
+      onSendMessage = {},
+      showStopButtonWhenInProgress = true,
+      showImagePickerInMenu = true,
+    )
+    MessageInputText(
+      modelManagerViewModel = PreviewModelManagerViewModel(context = context),
+      curMessage = "hello",
+      inProgress = false,
+      isResettingSession = false,
+      hasImageMessage = false,
      modelInitializing = false,
      textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
      onValueChanged = {},
@@ -243,6 +463,8 @@ fun MessageInputTextPreview() {
      modelManagerViewModel = PreviewModelManagerViewModel(context = context),
      curMessage = "hello",
      inProgress = true,
+      isResettingSession = false,
+      hasImageMessage = false,
      modelInitializing = false,
      textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
      onValueChanged = {},
@@ -252,6
+474,8 @@ fun MessageInputTextPreview() { modelManagerViewModel = PreviewModelManagerViewModel(context = context), curMessage = "", inProgress = false, + isResettingSession = false, + hasImageMessage = false, modelInitializing = false, textFieldPlaceHolderRes = R.string.chat_textinput_placeholder, onValueChanged = {}, @@ -261,6 +485,8 @@ fun MessageInputTextPreview() { modelManagerViewModel = PreviewModelManagerViewModel(context = context), curMessage = "", inProgress = true, + isResettingSession = false, + hasImageMessage = false, modelInitializing = false, textFieldPlaceHolderRes = R.string.chat_textinput_placeholder, onValueChanged = {}, diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/TextInputHistorySheet.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/TextInputHistorySheet.kt index 67363a5..50629f2 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/TextInputHistorySheet.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/TextInputHistorySheet.kt @@ -52,6 +52,7 @@ import androidx.compose.ui.Modifier import androidx.compose.ui.draw.clip import androidx.compose.ui.res.stringResource import androidx.compose.ui.text.style.TextAlign +import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.unit.dp import com.google.aiedge.gallery.R @@ -149,6 +150,8 @@ private fun SheetContent( Text( item, style = MaterialTheme.typography.bodyMedium, + maxLines = 3, + overflow = TextOverflow.Ellipsis, modifier = Modifier .padding(vertical = 16.dp) .padding(start = 16.dp) diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelNameAndStatus.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelNameAndStatus.kt index d60afff..da435f9 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelNameAndStatus.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelNameAndStatus.kt @@ -76,7 +76,7 @@ fun ModelNameAndStatus( Text( model.name, maxLines = 1, - overflow = TextOverflow.Ellipsis, + overflow = TextOverflow.MiddleEllipsis, style = MaterialTheme.typography.titleMedium, modifier = modifier, ) diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/HomeScreen.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/HomeScreen.kt index 4438e1f..0461118 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/HomeScreen.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/HomeScreen.kt @@ -136,7 +136,7 @@ fun HomeScreen( val snackbarHostState = remember { SnackbarHostState() } val scope = rememberCoroutineScope() - val tasks = uiState.tasks + val nonEmptyTasks = uiState.tasks.filter { it.models.size > 0 } val loadingHfModels = uiState.loadingHfModels val filePickerLauncher: ActivityResultLauncher = rememberLauncherForActivityResult( @@ -183,14 +183,14 @@ fun HomeScreen( ) { innerPadding -> Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.fillMaxSize()) { TaskList( - tasks = tasks, + tasks = nonEmptyTasks, navigateToTaskScreen = navigateToTaskScreen, loadingModelAllowlist = uiState.loadingModelAllowlist, modifier = Modifier.fillMaxSize(), contentPadding = innerPadding, ) - SnackbarHost(hostState = snackbarHostState, modifier = Modifier.padding(bottom = 16.dp)) + SnackbarHost(hostState = 
snackbarHostState, modifier = Modifier.padding(bottom = 32.dp)) } } @@ -285,7 +285,7 @@ fun HomeScreen( // Show a snack bar for successful import. scope.launch { - snackbarHostState.showSnackbar("✅ Model imported successfully") + snackbarHostState.showSnackbar("Model imported successfully") } }) } @@ -316,7 +316,6 @@ fun HomeScreen( } }, ) - } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/ModelImportDialog.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/ModelImportDialog.kt index aefbbe9..c5e9733 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/ModelImportDialog.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/ModelImportDialog.kt @@ -55,6 +55,7 @@ import androidx.compose.ui.unit.dp import androidx.compose.ui.window.Dialog import androidx.compose.ui.window.DialogProperties import com.google.aiedge.gallery.data.Accelerator +import com.google.aiedge.gallery.data.BooleanSwitchConfig import com.google.aiedge.gallery.data.Config import com.google.aiedge.gallery.data.ConfigKey import com.google.aiedge.gallery.data.IMPORTS_DIR @@ -111,6 +112,10 @@ private val IMPORT_CONFIGS_LLM: List = listOf( defaultValue = DEFAULT_TEMPERATURE, valueType = ValueType.FLOAT ), + BooleanSwitchConfig( + key = ConfigKey.SUPPORT_IMAGE, + defaultValue = false, + ), SegmentedButtonConfig( key = ConfigKey.COMPATIBLE_ACCELERATORS, defaultValue = Accelerator.CPU.label, diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/imageclassification/ImageClassificationScreen.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/imageclassification/ImageClassificationScreen.kt index 9dc2bea..98290bb 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/imageclassification/ImageClassificationScreen.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/imageclassification/ImageClassificationScreen.kt @@ -50,7 +50,8 @@ fun ImageClassificationScreen( task = viewModel.task, viewModel = viewModel, modelManagerViewModel = modelManagerViewModel, - onSendMessage = { model, message -> + onSendMessage = { model, messages -> + val message = messages[0] viewModel.addMessage( model = model, message = message, diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/imagegeneration/ImageGenerationScreen.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/imagegeneration/ImageGenerationScreen.kt index 4c20b63..ffa0227 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/imagegeneration/ImageGenerationScreen.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/imagegeneration/ImageGenerationScreen.kt @@ -44,7 +44,8 @@ fun ImageGenerationScreen( task = viewModel.task, viewModel = viewModel, modelManagerViewModel = modelManagerViewModel, - onSendMessage = { model, message -> + onSendMessage = { model, messages -> + val message = messages[0] viewModel.addMessage( model = model, message = message, diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt index d33bdad..f0b8dec 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt @@ -17,11 +17,14 @@ package com.google.aiedge.gallery.ui.llmchat import android.content.Context +import android.graphics.Bitmap 
import android.util.Log import com.google.aiedge.gallery.data.Accelerator import com.google.aiedge.gallery.data.ConfigKey import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage +import com.google.mediapipe.framework.image.BitmapImageBuilder +import com.google.mediapipe.tasks.genai.llminference.GraphOptions import com.google.mediapipe.tasks.genai.llminference.LlmInference import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession @@ -55,7 +58,9 @@ object LlmChatModelHelper { } val options = LlmInference.LlmInferenceOptions.builder().setModelPath(model.getPath(context = context)) - .setMaxTokens(maxTokens).setPreferredBackend(preferredBackend).build() + .setMaxTokens(maxTokens).setPreferredBackend(preferredBackend) + .setMaxNumImages(if (model.llmSupportImage) 1 else 0) + .build() // Create an instance of the LLM Inference task try { @@ -64,7 +69,10 @@ object LlmChatModelHelper { val session = LlmInferenceSession.createFromOptions( llmInference, LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP) - .setTemperature(temperature).build() + .setTemperature(temperature) + .setGraphOptions( + GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build() + ).build() ) model.instance = LlmModelInstance(engine = llmInference, session = session) } catch (e: Exception) { @@ -75,6 +83,8 @@ object LlmChatModelHelper { } fun resetSession(model: Model) { + Log.d(TAG, "Resetting session for model '${model.name}'") + val instance = model.instance as LlmModelInstance? ?: return val session = instance.session session.close() @@ -87,9 +97,13 @@ object LlmChatModelHelper { val newSession = LlmInferenceSession.createFromOptions( inference, LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP) - .setTemperature(temperature).build() + .setTemperature(temperature) + .setGraphOptions( + GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build() + ).build() ) instance.session = newSession + Log.d(TAG, "Resetting done") } fun cleanUp(model: Model) { @@ -118,6 +132,7 @@ object LlmChatModelHelper { resultListener: ResultListener, cleanUpListener: CleanUpListener, singleTurn: Boolean = false, + image: Bitmap? = null, ) { if (singleTurn) { resetSession(model = model) @@ -132,6 +147,9 @@ object LlmChatModelHelper { // Start async inference. 
val session = instance.session session.addQueryChunk(input) + if (image != null) { + session.addImage(BitmapImageBuilder(image).build()) + } session.generateResponseAsync(resultListener) } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatScreen.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatScreen.kt index a3db07d..740e4a1 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatScreen.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatScreen.kt @@ -16,15 +16,14 @@ package com.google.aiedge.gallery.ui.llmchat -import android.util.Log +import android.graphics.Bitmap import androidx.compose.runtime.Composable import androidx.compose.ui.Modifier import androidx.compose.ui.platform.LocalContext import androidx.lifecycle.viewmodel.compose.viewModel import com.google.aiedge.gallery.ui.ViewModelProvider -import com.google.aiedge.gallery.ui.common.chat.ChatMessageInfo +import com.google.aiedge.gallery.ui.common.chat.ChatMessageImage import com.google.aiedge.gallery.ui.common.chat.ChatMessageText -import com.google.aiedge.gallery.ui.common.chat.ChatMessageWarning import com.google.aiedge.gallery.ui.common.chat.ChatView import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import kotlinx.serialization.Serializable @@ -35,6 +34,11 @@ object LlmChatDestination { val route = "LlmChatRoute" } +object LlmImageToTextDestination { + @Serializable + val route = "LlmImageToTextRoute" +} + @Composable fun LlmChatScreen( modelManagerViewModel: ModelManagerViewModel, @@ -43,6 +47,38 @@ fun LlmChatScreen( viewModel: LlmChatViewModel = viewModel( factory = ViewModelProvider.Factory ), +) { + ChatViewWrapper( + viewModel = viewModel, + modelManagerViewModel = modelManagerViewModel, + navigateUp = navigateUp, + modifier = modifier, + ) +} + +@Composable +fun LlmImageToTextScreen( + modelManagerViewModel: ModelManagerViewModel, + navigateUp: () -> Unit, + modifier: Modifier = Modifier, + viewModel: LlmImageToTextViewModel = viewModel( + factory = ViewModelProvider.Factory + ), +) { + ChatViewWrapper( + viewModel = viewModel, + modelManagerViewModel = modelManagerViewModel, + navigateUp = navigateUp, + modifier = modifier, + ) +} + +@Composable +fun ChatViewWrapper( + viewModel: LlmChatViewModel, + modelManagerViewModel: ModelManagerViewModel, + navigateUp: () -> Unit, + modifier: Modifier = Modifier ) { val context = LocalContext.current @@ -50,21 +86,33 @@ fun LlmChatScreen( task = viewModel.task, viewModel = viewModel, modelManagerViewModel = modelManagerViewModel, - onSendMessage = { model, message -> - viewModel.addMessage( - model = model, - message = message, - ) - if (message is ChatMessageText) { - modelManagerViewModel.addTextInputHistory(message.content) - viewModel.generateResponse(model = model, input = message.content, onError = { - viewModel.addMessage( - model = model, - message = ChatMessageWarning(content = "Error occurred. Re-initializing the engine.") - ) + onSendMessage = { model, messages -> + for (message in messages) { + viewModel.addMessage( + model = model, + message = message, + ) + } - modelManagerViewModel.initializeModel( - context = context, task = viewModel.task, model = model, force = true + var text = "" + var image: Bitmap? = null + var chatMessageText: ChatMessageText? 
= null + for (message in messages) { + if (message is ChatMessageText) { + chatMessageText = message + text = message.content + } else if (message is ChatMessageImage) { + image = message.bitmap + } + } + if (text.isNotEmpty() && chatMessageText != null) { + modelManagerViewModel.addTextInputHistory(text) + viewModel.generateResponse(model = model, input = text, image = image, onError = { + viewModel.handleError( + context = context, + model = model, + modelManagerViewModel = modelManagerViewModel, + triggeredMessage = chatMessageText, ) }) } @@ -72,13 +120,11 @@ fun LlmChatScreen( onRunAgainClicked = { model, message -> if (message is ChatMessageText) { viewModel.runAgain(model = model, message = message, onError = { - viewModel.addMessage( + viewModel.handleError( + context = context, model = model, - message = ChatMessageWarning(content = "Error occurred. Re-initializing the engine.") - ) - - modelManagerViewModel.initializeModel( - context = context, task = viewModel.task, model = model, force = true + modelManagerViewModel = modelManagerViewModel, + triggeredMessage = message, ) }) } @@ -90,6 +136,9 @@ fun LlmChatScreen( ) } }, + onResetSessionClicked = { model -> + viewModel.resetSession(model = model) + }, showStopButtonInInputWhenInProgress = true, onStopButtonClicked = { model -> viewModel.stopResponse(model = model) @@ -97,5 +146,4 @@ fun LlmChatScreen( navigateUp = navigateUp, modifier = modifier, ) -} - +} \ No newline at end of file diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatViewModel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatViewModel.kt index f994978..3f4a2c6 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatViewModel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatViewModel.kt @@ -16,18 +16,23 @@ package com.google.aiedge.gallery.ui.llmchat +import android.content.Context +import android.graphics.Bitmap import android.util.Log import androidx.lifecycle.viewModelScope import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.TASK_LLM_CHAT +import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT +import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult import com.google.aiedge.gallery.ui.common.chat.ChatMessageLoading import com.google.aiedge.gallery.ui.common.chat.ChatMessageText import com.google.aiedge.gallery.ui.common.chat.ChatMessageType +import com.google.aiedge.gallery.ui.common.chat.ChatMessageWarning import com.google.aiedge.gallery.ui.common.chat.ChatSide import com.google.aiedge.gallery.ui.common.chat.ChatViewModel import com.google.aiedge.gallery.ui.common.chat.Stat -import kotlinx.coroutines.CoroutineExceptionHandler +import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.delay import kotlinx.coroutines.launch @@ -40,8 +45,8 @@ private val STATS = listOf( Stat(id = "latency", label = "Latency", unit = "sec") ) -class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) { - fun generateResponse(model: Model, input: String, onError: () -> Unit) { +open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task = curTask) { + fun generateResponse(model: Model, input: String, image: Bitmap? 
= null, onError: () -> Unit) { viewModelScope.launch(Dispatchers.Default) { setInProgress(true) @@ -58,7 +63,10 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) { // Run inference. val instance = model.instance as LlmModelInstance - val prefillTokens = instance.session.sizeInTokens(input) + var prefillTokens = instance.session.sizeInTokens(input) + if (image != null) { + prefillTokens += 257 + } var firstRun = true var timeToFirstToken = 0f @@ -69,9 +77,9 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) { val start = System.currentTimeMillis() try { - LlmChatModelHelper.runInference( - model = model, + LlmChatModelHelper.runInference(model = model, input = input, + image = image, resultListener = { partialResult, done -> val curTs = System.currentTimeMillis() @@ -92,32 +100,27 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) { // Add an empty message that will receive streaming results. addMessage( - model = model, - message = ChatMessageText(content = "", side = ChatSide.AGENT) + model = model, message = ChatMessageText(content = "", side = ChatSide.AGENT) ) } // Incrementally update the streamed partial results. val latencyMs: Long = if (done) System.currentTimeMillis() - start else -1 updateLastTextMessageContentIncrementally( - model = model, - partialContent = partialResult, - latencyMs = latencyMs.toFloat() + model = model, partialContent = partialResult, latencyMs = latencyMs.toFloat() ) if (done) { setInProgress(false) - decodeSpeed = - decodeTokens / ((curTs - firstTokenTs) / 1000f) + decodeSpeed = decodeTokens / ((curTs - firstTokenTs) / 1000f) if (decodeSpeed.isNaN()) { decodeSpeed = 0f } if (lastMessage is ChatMessageText) { updateLastTextMessageLlmBenchmarkResult( - model = model, llmBenchmarkResult = - ChatMessageBenchmarkLlmResult( + model = model, llmBenchmarkResult = ChatMessageBenchmarkLlmResult( orderedStats = STATS, statValues = mutableMapOf( "prefill_speed" to prefillSpeed, @@ -131,10 +134,12 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) { ) } } - }, cleanUpListener = { + }, + cleanUpListener = { setInProgress(false) }) } catch (e: Exception) { + Log.e(TAG, "Error occurred while running inference", e) setInProgress(false) onError() } @@ -143,6 +148,9 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) { fun stopResponse(model: Model) { Log.d(TAG, "Stopping response for model ${model.name}...") + if (getLastMessage(model = model) is ChatMessageLoading) { + removeLastMessage(model = model) + } viewModelScope.launch(Dispatchers.Default) { setInProgress(false) val instance = model.instance as LlmModelInstance @@ -150,6 +158,25 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) { } } + fun resetSession(model: Model) { + viewModelScope.launch(Dispatchers.Default) { + setIsResettingSession(true) + clearAllMessages(model = model) + stopResponse(model = model) + + while (true) { + try { + LlmChatModelHelper.resetSession(model = model) + break + } catch (e: Exception) { + Log.d(TAG, "Failed to reset session. Trying again") + } + delay(200) + } + setIsResettingSession(false) + } + } + fun runAgain(model: Model, message: ChatMessageText, onError: () -> Unit) { viewModelScope.launch(Dispatchers.Default) { // Wait for model to be initialized. @@ -162,9 +189,7 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) { // Run inference. 
generateResponse( - model = model, - input = message.content, - onError = onError + model = model, input = message.content, onError = onError ) } } @@ -199,8 +224,7 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) { var decodeSpeed: Float val start = System.currentTimeMillis() var lastUpdateTime = 0L - LlmChatModelHelper.runInference( - model = model, + LlmChatModelHelper.runInference(model = model, input = message.content, resultListener = { partialResult, done -> val curTs = System.currentTimeMillis() @@ -232,14 +256,12 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) { result.append(partialResult) if (curTs - lastUpdateTime > 500 || done) { - decodeSpeed = - decodeTokens / ((curTs - firstTokenTs) / 1000f) + decodeSpeed = decodeTokens / ((curTs - firstTokenTs) / 1000f) if (decodeSpeed.isNaN()) { decodeSpeed = 0f } replaceLastMessage( - model = model, - message = ChatMessageBenchmarkLlmResult( + model = model, message = ChatMessageBenchmarkLlmResult( orderedStats = STATS, statValues = mutableMapOf( "prefill_speed" to prefillSpeed, @@ -249,8 +271,7 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) { ), running = !done, latencyMs = -1f, - ), - type = ChatMessageType.BENCHMARK_LLM_RESULT + ), type = ChatMessageType.BENCHMARK_LLM_RESULT ) lastUpdateTime = curTs @@ -261,9 +282,53 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) { }, cleanUpListener = { setInProgress(false) - } - ) + }) } } + + fun handleError( + context: Context, + model: Model, + modelManagerViewModel: ModelManagerViewModel, + triggeredMessage: ChatMessageText, + ) { + // Clean up. + modelManagerViewModel.cleanupModel(task = task, model = model) + + // Remove the "loading" message. + if (getLastMessage(model = model) is ChatMessageLoading) { + removeLastMessage(model = model) + } + + // Remove the last Text message. + if (getLastMessage(model = model) == triggeredMessage) { + removeLastMessage(model = model) + } + + // Add a warning message for re-initializing the session. + addMessage( + model = model, + message = ChatMessageWarning(content = "Error occurred. Re-initializing the session.") + ) + + // Add the triggered message back. + addMessage(model = model, message = triggeredMessage) + + // Re-initialize the session/engine. + modelManagerViewModel.initializeModel( + context = context, task = task, model = model + ) + + // Re-generate the response automatically. 
+ generateResponse(model = model, input = triggeredMessage.content, onError = { + handleError( + context = context, + model = model, + modelManagerViewModel = modelManagerViewModel, + triggeredMessage = triggeredMessage + ) + }) + } } +class LlmImageToTextViewModel : LlmChatViewModel(curTask = TASK_LLM_IMAGE_TO_TEXT) \ No newline at end of file diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt index f34b1f4..1b96140 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt @@ -155,9 +155,14 @@ fun LlmSingleTurnScreen( PromptTemplatesPanel( model = selectedModel, viewModel = viewModel, + modelManagerViewModel = modelManagerViewModel, onSend = { fullPrompt -> viewModel.generateResponse(model = selectedModel, input = fullPrompt) - }, modifier = Modifier.fillMaxSize() + }, + onStopButtonClicked = { model -> + viewModel.stopResponse(model = model) + }, + modifier = Modifier.fillMaxSize() ) }, bottomView = { diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt index cc045cf..21e1b3c 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt @@ -23,6 +23,7 @@ import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.TASK_LLM_USECASES import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult +import com.google.aiedge.gallery.ui.common.chat.ChatMessageLoading import com.google.aiedge.gallery.ui.common.chat.Stat import com.google.aiedge.gallery.ui.common.processLlmResponse import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper @@ -194,6 +195,15 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_USECASES) : ViewMode } } + fun stopResponse(model: Model) { + Log.d(TAG, "Stopping response for model ${model.name}...") + viewModelScope.launch(Dispatchers.Default) { + setInProgress(false) + val instance = model.instance as LlmModelInstance + instance.session.cancelGenerateResponseAsync() + } + } + private fun createUiState(task: Task): LlmSingleTurnUiState { val responsesByModel: MutableMap> = mutableMapOf() val benchmarkByModel: MutableMap> = diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/PromptTemplatesPanel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/PromptTemplatesPanel.kt index 661daaa..2ca98d7 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/PromptTemplatesPanel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/PromptTemplatesPanel.kt @@ -43,11 +43,13 @@ import androidx.compose.material.icons.outlined.Description import androidx.compose.material.icons.outlined.ExpandLess import androidx.compose.material.icons.outlined.ExpandMore import androidx.compose.material.icons.rounded.Add +import androidx.compose.material.icons.rounded.Stop import androidx.compose.material.icons.rounded.Visibility import 
androidx.compose.material.icons.rounded.VisibilityOff import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.FilterChipDefaults import androidx.compose.material3.Icon +import androidx.compose.material3.IconButton import androidx.compose.material3.IconButtonDefaults import androidx.compose.material3.MaterialTheme import androidx.compose.material3.ModalBottomSheet @@ -86,9 +88,11 @@ import androidx.compose.ui.unit.dp import com.google.aiedge.gallery.R import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.ui.common.chat.MessageBubbleShape +import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType +import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.theme.customColors -import kotlinx.coroutines.launch import kotlinx.coroutines.delay +import kotlinx.coroutines.launch private val promptTemplateTypes: List = PromptTemplateType.entries private val TAB_TITLES = PromptTemplateType.entries.map { it.label } @@ -101,11 +105,14 @@ const val FULL_PROMPT_SWITCH_KEY = "full_prompt" fun PromptTemplatesPanel( model: Model, viewModel: LlmSingleTurnViewModel, + modelManagerViewModel: ModelManagerViewModel, onSend: (fullPrompt: String) -> Unit, + onStopButtonClicked: (Model) -> Unit, modifier: Modifier = Modifier ) { val scope = rememberCoroutineScope() val uiState by viewModel.uiState.collectAsState() + val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() val selectedPromptTemplateType = uiState.selectedPromptTemplateType val inProgress = uiState.inProgress var selectedTabIndex by remember { mutableIntStateOf(0) } @@ -123,6 +130,8 @@ fun PromptTemplatesPanel( val focusManager = LocalFocusManager.current val interactionSource = remember { MutableInteractionSource() } val expandedStates = remember { mutableStateMapOf() } + val modelInitializationStatus = + modelManagerUiState.modelInitializationStatus[model.name] // Update input editor values when prompt template changes. 
LaunchedEffect(selectedPromptTemplateType) { @@ -328,29 +337,49 @@ fun PromptTemplatesPanel( ) } - // Send button - OutlinedIconButton( - enabled = !inProgress && curTextInputContent.isNotEmpty(), - onClick = { - focusManager.clearFocus() - onSend(fullPrompt.text) - }, - colors = IconButtonDefaults.iconButtonColors( - containerColor = MaterialTheme.colorScheme.secondaryContainer, - disabledContainerColor = MaterialTheme.colorScheme.secondaryContainer.copy(alpha = 0.3f), - contentColor = MaterialTheme.colorScheme.primary, - disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.1f), - ), - border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface), - modifier = Modifier.size(ICON_BUTTON_SIZE) - ) { - Icon( - Icons.AutoMirrored.Rounded.Send, - contentDescription = "", - modifier = Modifier - .size(20.dp) - .offset(x = 2.dp), - ) + val modelInitializing = + modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING + if (inProgress && !modelInitializing) { + IconButton( + onClick = { + onStopButtonClicked(model) + }, + colors = IconButtonDefaults.iconButtonColors( + containerColor = MaterialTheme.colorScheme.secondaryContainer, + ), + modifier = Modifier.size(ICON_BUTTON_SIZE) + ) { + Icon( + Icons.Rounded.Stop, + contentDescription = "", + tint = MaterialTheme.colorScheme.primary + ) + } + } else { + // Send button + OutlinedIconButton( + enabled = !inProgress && curTextInputContent.isNotEmpty(), + onClick = { + focusManager.clearFocus() + onSend(fullPrompt.text) + }, + colors = IconButtonDefaults.iconButtonColors( + containerColor = MaterialTheme.colorScheme.secondaryContainer, + disabledContainerColor = MaterialTheme.colorScheme.secondaryContainer.copy(alpha = 0.3f), + contentColor = MaterialTheme.colorScheme.primary, + disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.1f), + ), + border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface), + modifier = Modifier.size(ICON_BUTTON_SIZE) + ) { + Icon( + Icons.AutoMirrored.Rounded.Send, + contentDescription = "", + modifier = Modifier + .size(20.dp) + .offset(x = 2.dp), + ) + } } } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManager.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManager.kt index 3a563b3..49994e8 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManager.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManager.kt @@ -21,6 +21,10 @@ import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.Scaffold import androidx.compose.runtime.Composable +import androidx.compose.runtime.LaunchedEffect +import androidx.compose.runtime.derivedStateOf +import androidx.compose.runtime.getValue +import androidx.compose.runtime.remember import androidx.compose.ui.Modifier import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.tooling.preview.Preview @@ -48,6 +52,24 @@ fun ModelManager( if (task.models.size != 1) { title += "s" } + // Model count. + val modelCount by remember { + derivedStateOf { + val trigger = task.updateTrigger.value + if (trigger >= 0) { + task.models.size + } else { + -1 + } + } + } + + // Navigate up when there are no models left. + LaunchedEffect(modelCount) { + if (modelCount == 0) { + navigateUp() + } + } // Handle system's edge swipe. 
BackHandler { diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt index d9598cf..d977e09 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt @@ -38,6 +38,7 @@ import com.google.aiedge.gallery.data.ModelDownloadStatus import com.google.aiedge.gallery.data.ModelDownloadStatusType import com.google.aiedge.gallery.data.TASKS import com.google.aiedge.gallery.data.TASK_LLM_CHAT +import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT import com.google.aiedge.gallery.data.TASK_LLM_USECASES import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.TaskType @@ -58,6 +59,7 @@ import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.update import kotlinx.coroutines.launch import kotlinx.coroutines.withContext +import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.json.Json import net.openid.appauth.AuthorizationException import net.openid.appauth.AuthorizationRequest @@ -229,7 +231,10 @@ open class ModelManagerViewModel( } dataStoreRepository.saveImportedModels(importedModels = importedModels) } - val newUiState = uiState.value.copy(modelDownloadStatus = curModelDownloadStatus) + val newUiState = uiState.value.copy( + modelDownloadStatus = curModelDownloadStatus, + tasks = uiState.value.tasks.toList() + ) _uiState.update { newUiState } } @@ -312,6 +317,12 @@ open class ModelManagerViewModel( onDone = onDone, ) + TaskType.LLM_IMAGE_TO_TEXT -> LlmChatModelHelper.initialize( + context = context, + model = model, + onDone = onDone, + ) + TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.initialize( context = context, model = model, onDone = onDone ) @@ -331,6 +342,7 @@ open class ModelManagerViewModel( TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.cleanUp(model = model) TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model) TaskType.LLM_USECASES -> LlmChatModelHelper.cleanUp(model = model) + TaskType.LLM_IMAGE_TO_TEXT -> LlmChatModelHelper.cleanUp(model = model) TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.cleanUp(model = model) TaskType.TEST_TASK_1 -> {} TaskType.TEST_TASK_2 -> {} @@ -434,14 +446,16 @@ open class ModelManagerViewModel( // Create model. val model = createModelFromImportedModelInfo(info = info) - // Remove duplicated imported model if existed. - for (task in listOf(TASK_LLM_CHAT, TASK_LLM_USECASES)) { + for (task in listOf(TASK_LLM_CHAT, TASK_LLM_USECASES, TASK_LLM_IMAGE_TO_TEXT)) { + // Remove duplicated imported model if existed. val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported } if (modelIndex >= 0) { Log.d(TAG, "duplicated imported model found in task. 
Removing it first") task.models.removeAt(modelIndex) } - task.models.add(model) + if (task == TASK_LLM_IMAGE_TO_TEXT && model.llmSupportImage || task != TASK_LLM_IMAGE_TO_TEXT) { + task.models.add(model) + } task.updateTrigger.value = System.currentTimeMillis() } @@ -632,8 +646,7 @@ open class ModelManagerViewModel( fun loadModelAllowlist() { _uiState.update { uiState.value.copy( - loadingModelAllowlist = true, - loadingModelAllowlistError = "" + loadingModelAllowlist = true, loadingModelAllowlistError = "" ) } @@ -663,6 +676,9 @@ open class ModelManagerViewModel( if (allowedModel.taskTypes.contains(TASK_LLM_USECASES.type.id)) { TASK_LLM_USECASES.models.add(model) } + if (allowedModel.taskTypes.contains(TASK_LLM_IMAGE_TO_TEXT.type.id)) { + TASK_LLM_IMAGE_TO_TEXT.models.add(model) + } } // Pre-process all tasks. @@ -717,6 +733,9 @@ open class ModelManagerViewModel( // Add to task. TASK_LLM_CHAT.models.add(model) TASK_LLM_USECASES.models.add(model) + if (model.llmSupportImage) { + TASK_LLM_IMAGE_TO_TEXT.models.add(model) + } // Update status. modelDownloadStatus[model.name] = ModelDownloadStatus( @@ -731,7 +750,7 @@ open class ModelManagerViewModel( Log.d(TAG, "model download status: $modelDownloadStatus") return ModelManagerUiState( - tasks = TASKS, + tasks = TASKS.toList(), modelDownloadStatus = modelDownloadStatus, modelInitializationStatus = modelInstances, textInputHistory = textInputHistory, @@ -763,6 +782,9 @@ open class ModelManagerViewModel( ) as Float, accelerators = accelerators, ) + val llmSupportImage = convertValueToTargetType( + info.defaultValues[ConfigKey.SUPPORT_IMAGE.label] ?: false, ValueType.BOOLEAN + ) as Boolean val model = Model( name = info.fileName, url = "", @@ -770,7 +792,9 @@ open class ModelManagerViewModel( sizeInBytes = info.fileSize, downloadFileName = "$IMPORTS_DIR/${info.fileName}", showBenchmarkButton = false, + showRunAgainButton = false, imported = true, + llmSupportImage = llmSupportImage, ) model.preProcess() @@ -803,6 +827,7 @@ open class ModelManagerViewModel( ) } + @OptIn(ExperimentalSerializationApi::class) private inline fun getJsonResponse(url: String): T? 
{ try { val connection = URL(url).openConnection() as HttpURLConnection @@ -815,7 +840,12 @@ open class ModelManagerViewModel( val response = inputStream.bufferedReader().use { it.readText() } // Parse JSON using kotlinx.serialization - val json = Json { ignoreUnknownKeys = true } // Handle potential extra fields + val json = Json { + // Handle potential extra fields + ignoreUnknownKeys = true + allowComments = true + allowTrailingComma = true + } val jsonObj = json.decodeFromString(response) return jsonObj } else { diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/navigation/GalleryNavGraph.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/navigation/GalleryNavGraph.kt index 844500e..61efd1a 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/navigation/GalleryNavGraph.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/navigation/GalleryNavGraph.kt @@ -46,6 +46,7 @@ import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION import com.google.aiedge.gallery.data.TASK_LLM_CHAT +import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT import com.google.aiedge.gallery.data.TASK_LLM_USECASES import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION import com.google.aiedge.gallery.data.Task @@ -59,6 +60,8 @@ import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationDestination import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationScreen import com.google.aiedge.gallery.ui.llmchat.LlmChatDestination import com.google.aiedge.gallery.ui.llmchat.LlmChatScreen +import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextDestination +import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextScreen import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnDestination import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnScreen import com.google.aiedge.gallery.ui.modelmanager.ModelManager @@ -129,8 +132,7 @@ fun GalleryNavHost( ) { val curPickedTask = pickedTask if (curPickedTask != null) { - ModelManager( - viewModel = modelManagerViewModel, + ModelManager(viewModel = modelManagerViewModel, task = curPickedTask, onModelClicked = { model -> navigateToTaskScreen( @@ -207,7 +209,7 @@ fun GalleryNavHost( } } - // LLMm chat demos. + // LLM chat demos. composable( route = "${LlmChatDestination.route}/{modelName}", arguments = listOf(navArgument("modelName") { type = NavType.StringType }), @@ -224,7 +226,7 @@ fun GalleryNavHost( } } - // LLMm single turn. + // LLM single turn. composable( route = "${LlmSingleTurnDestination.route}/{modelName}", arguments = listOf(navArgument("modelName") { type = NavType.StringType }), @@ -241,6 +243,22 @@ fun GalleryNavHost( } } + // LLM image to text. + composable( + route = "${LlmImageToTextDestination.route}/{modelName}", + arguments = listOf(navArgument("modelName") { type = NavType.StringType }), + enterTransition = { slideEnter() }, + exitTransition = { slideExit() }, + ) { + getModelFromNavigationParam(it, TASK_LLM_IMAGE_TO_TEXT)?.let { defaultModel -> + modelManagerViewModel.selectModel(defaultModel) + + LlmImageToTextScreen( + modelManagerViewModel = modelManagerViewModel, + navigateUp = { navController.navigateUp() }, + ) + } + } } // Handle incoming intents for deep links @@ -254,9 +272,7 @@ fun GalleryNavHost( getModelByName(modelName)?.let { model -> // TODO(jingjin): need to show a list of possible tasks for this model. 
navigateToTaskScreen( - navController = navController, - taskType = TaskType.LLM_CHAT, - model = model + navController = navController, taskType = TaskType.LLM_CHAT, model = model ) } } @@ -271,6 +287,7 @@ fun navigateToTaskScreen( TaskType.TEXT_CLASSIFICATION -> navController.navigate("${TextClassificationDestination.route}/${modelName}") TaskType.IMAGE_CLASSIFICATION -> navController.navigate("${ImageClassificationDestination.route}/${modelName}") TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}") + TaskType.LLM_IMAGE_TO_TEXT -> navController.navigate("${LlmImageToTextDestination.route}/${modelName}") TaskType.LLM_USECASES -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}") TaskType.IMAGE_GENERATION -> navController.navigate("${ImageGenerationDestination.route}/${modelName}") TaskType.TEST_TASK_1 -> {} diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/textclassification/TextClassificationScreen.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/textclassification/TextClassificationScreen.kt index 7dab53c..f31394c 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/textclassification/TextClassificationScreen.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/textclassification/TextClassificationScreen.kt @@ -44,7 +44,8 @@ fun TextClassificationScreen( task = viewModel.task, viewModel = viewModel, modelManagerViewModel = modelManagerViewModel, - onSendMessage = { model, message -> + onSendMessage = { model, messages -> + val message = messages[0] viewModel.addMessage( model = model, message = message, diff --git a/Android/src/gradle/libs.versions.toml b/Android/src/gradle/libs.versions.toml index 66c5fc7..7eca9d2 100644 --- a/Android/src/gradle/libs.versions.toml +++ b/Android/src/gradle/libs.versions.toml @@ -7,7 +7,7 @@ junitVersion = "1.2.1" espressoCore = "3.6.1" lifecycleRuntimeKtx = "2.8.7" activityCompose = "1.10.1" -composeBom = "2025.03.01" +composeBom = "2025.05.00" navigation = "2.8.9" serializationPlugin = "2.0.21" serializationJson = "1.7.3"
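// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch above). It restates, in isolation,
// two pieces of logic the diff introduces, so they are easy to reason about and
// unit-test:
//   1. LlmChatViewModel.generateResponse adds 257 tokens to the prefill count
//      whenever an image accompanies the prompt.
//   2. An imported model is only added to the image-to-text task when its
//      llmSupportImage flag is set; the other LLM tasks accept it regardless.
// Names such as estimatePrefillTokens, tasksForImportedModel, SketchModel and
// SketchTask are assumptions made for this sketch only; they do not exist in
// the app's source tree.
// ---------------------------------------------------------------------------

// Token cost the patch attributes to a single attached image during prefill.
const val IMAGE_PREFILL_TOKENS = 257

// Hypothetical helper mirroring the prefill accounting in generateResponse:
// text tokens from session.sizeInTokens(input), plus a fixed cost per image.
fun estimatePrefillTokens(textTokens: Int, hasImage: Boolean): Int =
  textTokens + if (hasImage) IMAGE_PREFILL_TOKENS else 0

// Minimal stand-ins for the app's Model/Task types, just for this sketch.
data class SketchModel(val name: String, val llmSupportImage: Boolean)

enum class SketchTask { LLM_CHAT, LLM_USECASES, LLM_IMAGE_TO_TEXT }

// Hypothetical helper mirroring the task-membership gating for imported models:
// equivalent to the patch's condition
// (task == TASK_LLM_IMAGE_TO_TEXT && model.llmSupportImage || task != TASK_LLM_IMAGE_TO_TEXT).
fun tasksForImportedModel(model: SketchModel): List<SketchTask> =
  SketchTask.entries.filter { task ->
    task != SketchTask.LLM_IMAGE_TO_TEXT || model.llmSupportImage
  }

fun main() {
  // 128 text tokens plus one attached image -> 385 prefill tokens.
  println(estimatePrefillTokens(textTokens = 128, hasImage = true))

  // A text-only imported model stays out of the image-to-text task.
  println(tasksForImportedModel(SketchModel("text-only.task", llmSupportImage = false)))

  // A vision-capable imported model is added to all three LLM tasks.
  println(tasksForImportedModel(SketchModel("vision.task", llmSupportImage = true)))
}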