mirror of
https://github.com/google-ai-edge/gallery.git
synced 2025-07-06 06:30:30 -04:00
Add support for image to text models
This commit is contained in:
parent
ef290cd7b0
commit
bedc488a15
29 changed files with 763 additions and 231 deletions
|
@ -19,7 +19,6 @@ package com.google.aiedge.gallery.data
|
|||
import android.content.Context
|
||||
import com.google.aiedge.gallery.ui.common.chat.PromptTemplate
|
||||
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
|
||||
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
|
||||
import java.io.File
|
||||
|
||||
data class ModelDataFile(
|
||||
|
@ -91,6 +90,9 @@ data class Model(
|
|||
/** The prompt templates for the model (only for LLM). */
|
||||
val llmPromptTemplates: List<PromptTemplate> = listOf(),
|
||||
|
||||
/** Whether the LLM model supports image input. */
|
||||
val llmSupportImage: Boolean = false,
|
||||
|
||||
/** Whether the model is imported or not. */
|
||||
val imported: Boolean = false,
|
||||
|
||||
|
@ -204,6 +206,7 @@ enum class ConfigKey(val label: String) {
|
|||
DEFAULT_TOPK("Default TopK"),
|
||||
DEFAULT_TOPP("Default TopP"),
|
||||
DEFAULT_TEMPERATURE("Default temperature"),
|
||||
SUPPORT_IMAGE("Support image"),
|
||||
MAX_RESULT_COUNT("Max result count"),
|
||||
USE_GPU("Use GPU"),
|
||||
ACCELERATOR("Accelerator"),
|
||||
|
@ -250,83 +253,9 @@ const val IMAGE_CLASSIFICATION_INFO = ""
|
|||
|
||||
const val IMAGE_CLASSIFICATION_LEARN_MORE_URL = "https://ai.google.dev/edge/litert/android"
|
||||
|
||||
const val LLM_CHAT_INFO =
|
||||
"Some description about this large language model. A community org for developers to discover models that are ready for deployment to edge platforms"
|
||||
|
||||
const val IMAGE_GENERATION_INFO =
|
||||
"Powered by [MediaPipe Image Generation API](https://ai.google.dev/edge/mediapipe/solutions/vision/image_generator/android)"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Model spec.
|
||||
|
||||
val MODEL_LLM_GEMMA_2B_GPU_INT4: Model = Model(
|
||||
name = "Gemma 2B (GPU int4)",
|
||||
downloadFileName = "gemma-2b-it-gpu-int4.bin",
|
||||
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma-2b-it-gpu-int4.bin",
|
||||
sizeInBytes = 1354301440L,
|
||||
configs = createLlmChatConfigs(
|
||||
accelerators = listOf(Accelerator.GPU)
|
||||
),
|
||||
showBenchmarkButton = false,
|
||||
info = LLM_CHAT_INFO,
|
||||
learnMoreUrl = "https://huggingface.co/litert-community",
|
||||
)
|
||||
|
||||
val MODEL_LLM_GEMMA_2_2B_GPU_INT8: Model = Model(
|
||||
name = "Gemma 2 2B (GPU int8)",
|
||||
downloadFileName = "gemma2-2b-it-gpu-int8.bin",
|
||||
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma2-2b-it-gpu-int8.bin",
|
||||
sizeInBytes = 2627141632L,
|
||||
configs = createLlmChatConfigs(
|
||||
accelerators = listOf(Accelerator.GPU)
|
||||
),
|
||||
showBenchmarkButton = false,
|
||||
info = LLM_CHAT_INFO,
|
||||
learnMoreUrl = "https://huggingface.co/litert-community",
|
||||
)
|
||||
|
||||
val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model(
|
||||
name = "Gemma 3 1B (int4)",
|
||||
downloadFileName = "gemma3-1b-it-int4.task",
|
||||
url = "https://huggingface.co/litert-community/Gemma3-1B-IT/resolve/main/gemma3-1b-it-int4.task?download=true",
|
||||
sizeInBytes = 554661243L,
|
||||
configs = createLlmChatConfigs(
|
||||
defaultTopK = 64,
|
||||
defaultTopP = 0.95f,
|
||||
accelerators = listOf(Accelerator.CPU, Accelerator.GPU)
|
||||
),
|
||||
showBenchmarkButton = false,
|
||||
info = LLM_CHAT_INFO,
|
||||
learnMoreUrl = "https://huggingface.co/litert-community/Gemma3-1B-IT",
|
||||
llmPromptTemplates = listOf(
|
||||
PromptTemplate(
|
||||
title = "Emoji Fun",
|
||||
description = "Generate emojis by emotions",
|
||||
prompt = "Show me emojis grouped by emotions"
|
||||
),
|
||||
PromptTemplate(
|
||||
title = "Trip Planner",
|
||||
description = "Plan a trip to a destination",
|
||||
prompt = "Plan a two-day trip to San Francisco"
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
val MODEL_LLM_DEEPSEEK: Model = Model(
|
||||
name = "Deepseek",
|
||||
downloadFileName = "deepseek.task",
|
||||
url = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B/resolve/main/deepseek_q8_ekv1280.task?download=true",
|
||||
sizeInBytes = 1860686856L,
|
||||
configs = createLlmChatConfigs(
|
||||
defaultTemperature = 0.6f,
|
||||
defaultTopK = 40,
|
||||
defaultTopP = 0.7f,
|
||||
accelerators = listOf(Accelerator.CPU)
|
||||
),
|
||||
showBenchmarkButton = false,
|
||||
info = LLM_CHAT_INFO,
|
||||
learnMoreUrl = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B",
|
||||
)
|
||||
|
||||
val MODEL_TEXT_CLASSIFICATION_MOBILEBERT: Model = Model(
|
||||
name = "MobileBert",
|
||||
|
@ -414,12 +343,5 @@ val MODELS_IMAGE_CLASSIFICATION: MutableList<Model> = mutableListOf(
|
|||
MODEL_IMAGE_CLASSIFICATION_MOBILENET_V2,
|
||||
)
|
||||
|
||||
val MODELS_LLM: MutableList<Model> = mutableListOf(
|
||||
// MODEL_LLM_GEMMA_2B_GPU_INT4,
|
||||
// MODEL_LLM_GEMMA_2_2B_GPU_INT8,
|
||||
MODEL_LLM_GEMMA_3_1B_INT4,
|
||||
MODEL_LLM_DEEPSEEK,
|
||||
)
|
||||
|
||||
val MODELS_IMAGE_GENERATION: MutableList<Model> =
|
||||
mutableListOf(MODEL_IMAGE_GENERATION_STABLE_DIFFUSION)
|
||||
|
|
|
@ -35,6 +35,7 @@ data class AllowedModel(
|
|||
val defaultConfig: Map<String, ConfigValue>,
|
||||
val taskTypes: List<String>,
|
||||
val disabled: Boolean? = null,
|
||||
val llmSupportImage: Boolean? = null,
|
||||
) {
|
||||
fun toModel(): Model {
|
||||
// Construct HF download url.
|
||||
|
@ -48,7 +49,7 @@ data class AllowedModel(
|
|||
var defaultTopK: Int = DEFAULT_TOPK
|
||||
var defaultTopP: Float = DEFAULT_TOPP
|
||||
var defaultTemperature: Float = DEFAULT_TEMPERATURE
|
||||
var defaultMaxToken: Int = 1024
|
||||
var defaultMaxToken = 1024
|
||||
var accelerators: List<Accelerator> = DEFAULT_ACCELERATORS
|
||||
if (defaultConfig.containsKey("topK")) {
|
||||
defaultTopK = getIntConfigValue(defaultConfig["topK"], defaultTopK)
|
||||
|
@ -84,9 +85,10 @@ data class AllowedModel(
|
|||
|
||||
// Misc.
|
||||
var showBenchmarkButton = true
|
||||
val showRunAgainButton = true
|
||||
if (taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_USECASES.type.id)) {
|
||||
var showRunAgainButton = true
|
||||
if (isLlmModel) {
|
||||
showBenchmarkButton = false
|
||||
showRunAgainButton = false
|
||||
}
|
||||
|
||||
return Model(
|
||||
|
@ -99,7 +101,8 @@ data class AllowedModel(
|
|||
downloadFileName = modelFile,
|
||||
showBenchmarkButton = showBenchmarkButton,
|
||||
showRunAgainButton = showRunAgainButton,
|
||||
learnMoreUrl = "https://huggingface.co/${modelId}"
|
||||
learnMoreUrl = "https://huggingface.co/${modelId}",
|
||||
llmSupportImage = llmSupportImage == true,
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ package com.google.aiedge.gallery.data
|
|||
import androidx.annotation.StringRes
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.outlined.Forum
|
||||
import androidx.compose.material.icons.outlined.Mms
|
||||
import androidx.compose.material.icons.outlined.Widgets
|
||||
import androidx.compose.material.icons.rounded.ImageSearch
|
||||
import androidx.compose.runtime.MutableState
|
||||
|
@ -33,6 +34,7 @@ enum class TaskType(val label: String, val id: String) {
|
|||
IMAGE_GENERATION(label = "Image Generation", id = "image_generation"),
|
||||
LLM_CHAT(label = "LLM Chat", id = "llm_chat"),
|
||||
LLM_USECASES(label = "LLM Use Cases", id = "llm_usecases"),
|
||||
LLM_IMAGE_TO_TEXT(label = "LLM Image to Text", id = "llm_image_to_text"),
|
||||
|
||||
TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
|
||||
TEST_TASK_2(label = "Test task 2", id = "test_task_2")
|
||||
|
@ -93,7 +95,7 @@ val TASK_LLM_CHAT = Task(
|
|||
icon = Icons.Outlined.Forum,
|
||||
// models = MODELS_LLM,
|
||||
models = mutableListOf(),
|
||||
description = "Chat with a on-device large language model",
|
||||
description = "Chat with on-device large language models",
|
||||
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
||||
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
|
||||
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
|
||||
|
@ -110,6 +112,17 @@ val TASK_LLM_USECASES = Task(
|
|||
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
|
||||
)
|
||||
|
||||
val TASK_LLM_IMAGE_TO_TEXT = Task(
|
||||
type = TaskType.LLM_IMAGE_TO_TEXT,
|
||||
icon = Icons.Outlined.Mms,
|
||||
// models = MODELS_LLM,
|
||||
models = mutableListOf(),
|
||||
description = "Ask questions about images with on-device large language models",
|
||||
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
||||
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
|
||||
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
|
||||
)
|
||||
|
||||
val TASK_IMAGE_GENERATION = Task(
|
||||
type = TaskType.IMAGE_GENERATION,
|
||||
iconVectorResourceId = R.drawable.image_spark,
|
||||
|
@ -127,6 +140,7 @@ val TASKS: List<Task> = listOf(
|
|||
// TASK_IMAGE_GENERATION,
|
||||
TASK_LLM_USECASES,
|
||||
TASK_LLM_CHAT,
|
||||
TASK_LLM_IMAGE_TO_TEXT
|
||||
)
|
||||
|
||||
fun getModelByName(name: String): Model? {
|
||||
|
|
|
@ -25,6 +25,7 @@ import com.google.aiedge.gallery.GalleryApplication
|
|||
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationViewModel
|
||||
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationViewModel
|
||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatViewModel
|
||||
import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextViewModel
|
||||
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.textclassification.TextClassificationViewModel
|
||||
|
@ -62,6 +63,12 @@ object ViewModelProvider {
|
|||
LlmSingleTurnViewModel()
|
||||
}
|
||||
|
||||
// Initializer for LlmImageToTextViewModel.
|
||||
initializer {
|
||||
LlmImageToTextViewModel()
|
||||
}
|
||||
|
||||
// Initializer for ImageGenerationViewModel.
|
||||
initializer {
|
||||
ImageGenerationViewModel()
|
||||
}
|
||||
|
|
|
@ -17,13 +17,17 @@
|
|||
package com.google.aiedge.gallery.ui.common
|
||||
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.offset
|
||||
import androidx.compose.foundation.layout.size
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.automirrored.rounded.ArrowBack
|
||||
import androidx.compose.material.icons.rounded.Settings
|
||||
import androidx.compose.material.icons.rounded.MapsUgc
|
||||
import androidx.compose.material.icons.rounded.Tune
|
||||
import androidx.compose.material3.CenterAlignedTopAppBar
|
||||
import androidx.compose.material3.CircularProgressIndicator
|
||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.IconButton
|
||||
|
@ -38,6 +42,7 @@ import androidx.compose.runtime.setValue
|
|||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.alpha
|
||||
import androidx.compose.ui.draw.scale
|
||||
import androidx.compose.ui.graphics.vector.ImageVector
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.res.vectorResource
|
||||
|
@ -47,6 +52,7 @@ import com.google.aiedge.gallery.data.Model
|
|||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.ui.common.chat.ConfigDialog
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
|
@ -58,12 +64,17 @@ fun ModelPageAppBar(
|
|||
onBackClicked: () -> Unit,
|
||||
onModelSelected: (Model) -> Unit,
|
||||
modifier: Modifier = Modifier,
|
||||
isResettingSession: Boolean = false,
|
||||
onResetSessionClicked: (Model) -> Unit = {},
|
||||
showResetSessionButton: Boolean = false,
|
||||
onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit = { _, _ -> },
|
||||
) {
|
||||
var showConfigDialog by remember { mutableStateOf(false) }
|
||||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||
val context = LocalContext.current
|
||||
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[model.name]
|
||||
val modelInitializationStatus =
|
||||
modelManagerUiState.modelInitializationStatus[model.name]
|
||||
|
||||
CenterAlignedTopAppBar(title = {
|
||||
Column(
|
||||
|
@ -110,19 +121,58 @@ fun ModelPageAppBar(
|
|||
actions = {
|
||||
val showConfigButton =
|
||||
model.configs.isNotEmpty() && curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED
|
||||
IconButton(
|
||||
onClick = {
|
||||
showConfigDialog = true
|
||||
},
|
||||
enabled = showConfigButton,
|
||||
modifier = Modifier.alpha(if (showConfigButton) 1f else 0f)
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Rounded.Settings,
|
||||
contentDescription = "",
|
||||
tint = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
Box(modifier = Modifier.size(42.dp), contentAlignment = Alignment.Center) {
|
||||
var configButtonOffset = 0.dp
|
||||
if (showConfigButton && showResetSessionButton) {
|
||||
configButtonOffset = (-40).dp
|
||||
}
|
||||
val isModelInitializing =
|
||||
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING
|
||||
if (showConfigButton) {
|
||||
IconButton(
|
||||
onClick = {
|
||||
showConfigDialog = true
|
||||
},
|
||||
enabled = !isModelInitializing,
|
||||
modifier = Modifier
|
||||
.scale(0.75f)
|
||||
.offset(x = configButtonOffset)
|
||||
.alpha(if (isModelInitializing) 0.5f else 1f)
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Rounded.Tune,
|
||||
contentDescription = "",
|
||||
tint = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
}
|
||||
}
|
||||
if (showResetSessionButton) {
|
||||
if (isResettingSession) {
|
||||
CircularProgressIndicator(
|
||||
trackColor = MaterialTheme.colorScheme.surfaceVariant,
|
||||
strokeWidth = 2.dp,
|
||||
modifier = Modifier.size(16.dp)
|
||||
)
|
||||
} else {
|
||||
IconButton(
|
||||
onClick = {
|
||||
onResetSessionClicked(model)
|
||||
},
|
||||
enabled = !isModelInitializing,
|
||||
modifier = Modifier
|
||||
.scale(0.75f)
|
||||
.alpha(if (isModelInitializing) 0.5f else 1f)
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Rounded.MapsUgc,
|
||||
contentDescription = "",
|
||||
tint = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
// Config dialog.
|
||||
|
|
|
@ -29,6 +29,7 @@ import androidx.compose.foundation.layout.Row
|
|||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.layout.size
|
||||
import androidx.compose.foundation.layout.widthIn
|
||||
import androidx.compose.foundation.pager.HorizontalPager
|
||||
import androidx.compose.foundation.pager.rememberPagerState
|
||||
import androidx.compose.foundation.shape.CircleShape
|
||||
|
@ -54,6 +55,9 @@ import androidx.compose.ui.Modifier
|
|||
import androidx.compose.ui.draw.alpha
|
||||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.graphics.graphicsLayer
|
||||
import androidx.compose.ui.platform.LocalDensity
|
||||
import androidx.compose.ui.platform.LocalWindowInfo
|
||||
import androidx.compose.ui.text.style.TextOverflow
|
||||
import androidx.compose.ui.unit.dp
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
|
@ -77,6 +81,10 @@ fun ModelPickerChipsPager(
|
|||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||
val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
|
||||
val scope = rememberCoroutineScope()
|
||||
val density = LocalDensity.current
|
||||
val windowInfo = LocalWindowInfo.current
|
||||
val screenWidthDp =
|
||||
remember { with(density) { windowInfo.containerSize.width.toDp() } }
|
||||
|
||||
val pagerState = rememberPagerState(initialPage = task.models.indexOf(initialModel),
|
||||
pageCount = { task.models.size })
|
||||
|
@ -140,7 +148,12 @@ fun ModelPickerChipsPager(
|
|||
Text(
|
||||
model.name,
|
||||
style = MaterialTheme.typography.labelSmall,
|
||||
modifier = Modifier.padding(start = 4.dp),
|
||||
modifier = Modifier
|
||||
.padding(start = 4.dp)
|
||||
.widthIn(0.dp, screenWidthDp - 250.dp),
|
||||
maxLines = 1,
|
||||
overflow = TextOverflow.MiddleEllipsis
|
||||
|
||||
)
|
||||
Icon(
|
||||
Icons.Rounded.ArrowDropDown,
|
||||
|
|
|
@ -109,7 +109,7 @@ fun ChatPanel(
|
|||
task: Task,
|
||||
selectedModel: Model,
|
||||
viewModel: ChatViewModel,
|
||||
onSendMessage: (Model, ChatMessage) -> Unit,
|
||||
onSendMessage: (Model, List<ChatMessage>) -> Unit,
|
||||
onRunAgainClicked: (Model, ChatMessage) -> Unit,
|
||||
onBenchmarkClicked: (Model, ChatMessage, warmUpIterations: Int, benchmarkIterations: Int) -> Unit,
|
||||
navigateUp: () -> Unit,
|
||||
|
@ -280,7 +280,8 @@ fun ChatPanel(
|
|||
task = task,
|
||||
onPromptClicked = { template ->
|
||||
onSendMessage(
|
||||
selectedModel, ChatMessageText(content = template.prompt, side = ChatSide.USER)
|
||||
selectedModel,
|
||||
listOf(ChatMessageText(content = template.prompt, side = ChatSide.USER))
|
||||
)
|
||||
})
|
||||
|
||||
|
@ -430,12 +431,15 @@ fun ChatPanel(
|
|||
// Chat input
|
||||
when (chatInputType) {
|
||||
ChatInputType.TEXT -> {
|
||||
val isLlmTask = task.type == TaskType.LLM_CHAT
|
||||
val notLlmStartScreen = !(messages.size == 1 && messages[0] is ChatMessagePromptTemplates)
|
||||
// val isLlmTask = task.type == TaskType.LLM_CHAT
|
||||
// val notLlmStartScreen = !(messages.size == 1 && messages[0] is ChatMessagePromptTemplates)
|
||||
val hasImageMessage = messages.any { it is ChatMessageImage }
|
||||
MessageInputText(
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
curMessage = curMessage,
|
||||
inProgress = uiState.inProgress,
|
||||
isResettingSession = uiState.isResettingSession,
|
||||
hasImageMessage = hasImageMessage,
|
||||
modelInitializing = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
|
||||
textFieldPlaceHolderRes = task.textInputPlaceHolderRes,
|
||||
onValueChanged = { curMessage = it },
|
||||
|
@ -445,13 +449,17 @@ fun ChatPanel(
|
|||
},
|
||||
onOpenPromptTemplatesClicked = {
|
||||
onSendMessage(
|
||||
selectedModel, ChatMessagePromptTemplates(
|
||||
templates = selectedModel.llmPromptTemplates, showMakeYourOwn = false
|
||||
selectedModel, listOf(
|
||||
ChatMessagePromptTemplates(
|
||||
templates = selectedModel.llmPromptTemplates, showMakeYourOwn = false
|
||||
)
|
||||
)
|
||||
)
|
||||
},
|
||||
onStopButtonClicked = onStopButtonClicked,
|
||||
showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen,
|
||||
// showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen,
|
||||
showPromptTemplatesInMenu = false,
|
||||
showImagePickerInMenu = selectedModel.llmSupportImage == true,
|
||||
showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress,
|
||||
)
|
||||
}
|
||||
|
@ -461,8 +469,10 @@ fun ChatPanel(
|
|||
streamingMessage = streamingMessage,
|
||||
onImageSelected = { bitmap ->
|
||||
onSendMessage(
|
||||
selectedModel, ChatMessageImage(
|
||||
bitmap = bitmap, imageBitMap = bitmap.asImageBitmap(), side = ChatSide.USER
|
||||
selectedModel, listOf(
|
||||
ChatMessageImage(
|
||||
bitmap = bitmap, imageBitMap = bitmap.asImageBitmap(), side = ChatSide.USER
|
||||
)
|
||||
)
|
||||
)
|
||||
},
|
||||
|
|
|
@ -70,16 +70,18 @@ fun ChatView(
|
|||
task: Task,
|
||||
viewModel: ChatViewModel,
|
||||
modelManagerViewModel: ModelManagerViewModel,
|
||||
onSendMessage: (Model, ChatMessage) -> Unit,
|
||||
onSendMessage: (Model, List<ChatMessage>) -> Unit,
|
||||
onRunAgainClicked: (Model, ChatMessage) -> Unit,
|
||||
onBenchmarkClicked: (Model, ChatMessage, Int, Int) -> Unit,
|
||||
navigateUp: () -> Unit,
|
||||
modifier: Modifier = Modifier,
|
||||
onResetSessionClicked: (Model) -> Unit = {},
|
||||
onStreamImageMessage: (Model, ChatMessageImage) -> Unit = { _, _ -> },
|
||||
onStopButtonClicked: (Model) -> Unit = {},
|
||||
chatInputType: ChatInputType = ChatInputType.TEXT,
|
||||
showStopButtonInInputWhenInProgress: Boolean = false,
|
||||
) {
|
||||
val uiStat by viewModel.uiState.collectAsState()
|
||||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||
val selectedModel = modelManagerUiState.selectedModel
|
||||
|
||||
|
@ -155,6 +157,9 @@ fun ChatView(
|
|||
task = task,
|
||||
model = selectedModel,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
showResetSessionButton = true,
|
||||
isResettingSession = uiStat.isResettingSession,
|
||||
onResetSessionClicked = onResetSessionClicked,
|
||||
onConfigChanged = { old, new ->
|
||||
viewModel.addConfigChangedMessage(
|
||||
oldConfigValues = old,
|
||||
|
|
|
@ -33,6 +33,11 @@ data class ChatUiState(
|
|||
*/
|
||||
val inProgress: Boolean = false,
|
||||
|
||||
/**
|
||||
* Indicates whether the session is being reset.
|
||||
*/
|
||||
val isResettingSession: Boolean = false,
|
||||
|
||||
/**
|
||||
* A map of model names to lists of chat messages.
|
||||
*/
|
||||
|
@ -106,6 +111,12 @@ open class ChatViewModel(val task: Task) : ViewModel() {
|
|||
_uiState.update { _uiState.value.copy(messagesByModel = newMessagesByModel) }
|
||||
}
|
||||
|
||||
fun clearAllMessages(model: Model) {
|
||||
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
|
||||
newMessagesByModel[model.name] = mutableListOf()
|
||||
_uiState.update { _uiState.value.copy(messagesByModel = newMessagesByModel) }
|
||||
}
|
||||
|
||||
fun getLastMessage(model: Model): ChatMessage? {
|
||||
return (_uiState.value.messagesByModel[model.name] ?: listOf()).lastOrNull()
|
||||
}
|
||||
|
@ -189,6 +200,10 @@ open class ChatViewModel(val task: Task) : ViewModel() {
|
|||
_uiState.update { _uiState.value.copy(inProgress = inProgress) }
|
||||
}
|
||||
|
||||
fun setIsResettingSession(isResettingSession: Boolean) {
|
||||
_uiState.update { _uiState.value.copy(isResettingSession = isResettingSession) }
|
||||
}
|
||||
|
||||
fun addConfigChangedMessage(
|
||||
oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>, model: Model
|
||||
) {
|
||||
|
|
|
@ -21,14 +21,18 @@ import androidx.compose.material3.ProvideTextStyle
|
|||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.CompositionLocalProvider
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.text.SpanStyle
|
||||
import androidx.compose.ui.text.TextLinkStyles
|
||||
import androidx.compose.ui.text.TextStyle
|
||||
import androidx.compose.ui.text.font.FontFamily
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||
import com.google.aiedge.gallery.ui.theme.customColors
|
||||
import com.halilibo.richtext.commonmark.Markdown
|
||||
import com.halilibo.richtext.ui.CodeBlockStyle
|
||||
import com.halilibo.richtext.ui.RichTextStyle
|
||||
import com.halilibo.richtext.ui.material3.RichText
|
||||
import com.halilibo.richtext.ui.string.RichTextStringStyle
|
||||
|
||||
/**
|
||||
* Composable function to display Markdown-formatted text.
|
||||
|
@ -56,6 +60,11 @@ fun MarkdownText(
|
|||
fontSize = MaterialTheme.typography.bodySmall.fontSize,
|
||||
fontFamily = FontFamily.Monospace,
|
||||
)
|
||||
),
|
||||
stringStyle = RichTextStringStyle(
|
||||
linkStyle = TextLinkStyles(
|
||||
style = SpanStyle(color = MaterialTheme.customColors.linkColor)
|
||||
)
|
||||
)
|
||||
),
|
||||
) {
|
||||
|
|
|
@ -46,7 +46,11 @@ fun MessageBodyWarning(message: ChatMessageWarning) {
|
|||
.clip(RoundedCornerShape(16.dp))
|
||||
.background(MaterialTheme.colorScheme.tertiaryContainer)
|
||||
) {
|
||||
MarkdownText(text = message.content, modifier = Modifier.padding(12.dp), smallFontSize = true)
|
||||
MarkdownText(
|
||||
text = message.content,
|
||||
modifier = Modifier.padding(horizontal = 16.dp, vertical = 6.dp),
|
||||
smallFontSize = true
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,28 +16,45 @@
|
|||
|
||||
package com.google.aiedge.gallery.ui.common.chat
|
||||
|
||||
import android.Manifest
|
||||
import android.content.Context
|
||||
import android.content.pm.PackageManager
|
||||
import android.graphics.Bitmap
|
||||
import android.graphics.BitmapFactory
|
||||
import android.graphics.Matrix
|
||||
import android.net.Uri
|
||||
import androidx.activity.compose.rememberLauncherForActivityResult
|
||||
import androidx.activity.result.PickVisualMediaRequest
|
||||
import androidx.activity.result.contract.ActivityResultContracts
|
||||
import androidx.annotation.StringRes
|
||||
import androidx.compose.foundation.Image
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.border
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.Spacer
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.height
|
||||
import androidx.compose.foundation.layout.offset
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.layout.size
|
||||
import androidx.compose.foundation.layout.width
|
||||
import androidx.compose.foundation.shape.CircleShape
|
||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.automirrored.rounded.Send
|
||||
import androidx.compose.material.icons.rounded.Add
|
||||
import androidx.compose.material.icons.rounded.Close
|
||||
import androidx.compose.material.icons.rounded.History
|
||||
import androidx.compose.material.icons.rounded.Photo
|
||||
import androidx.compose.material.icons.rounded.PhotoCamera
|
||||
import androidx.compose.material.icons.rounded.PostAdd
|
||||
import androidx.compose.material.icons.rounded.Stop
|
||||
import androidx.compose.material3.DropdownMenu
|
||||
import androidx.compose.material3.DropdownMenuItem
|
||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.IconButton
|
||||
import androidx.compose.material3.IconButtonDefaults
|
||||
|
@ -54,12 +71,17 @@ import androidx.compose.runtime.setValue
|
|||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.alpha
|
||||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.draw.shadow
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.graphics.asImageBitmap
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.res.stringResource
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.core.content.ContextCompat
|
||||
import com.google.aiedge.gallery.R
|
||||
import com.google.aiedge.gallery.ui.common.createTempPictureUri
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||
|
@ -74,24 +96,107 @@ import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
|||
fun MessageInputText(
|
||||
modelManagerViewModel: ModelManagerViewModel,
|
||||
curMessage: String,
|
||||
isResettingSession: Boolean,
|
||||
inProgress: Boolean,
|
||||
hasImageMessage: Boolean,
|
||||
modelInitializing: Boolean,
|
||||
@StringRes textFieldPlaceHolderRes: Int,
|
||||
onValueChanged: (String) -> Unit,
|
||||
onSendMessage: (ChatMessage) -> Unit,
|
||||
onSendMessage: (List<ChatMessage>) -> Unit,
|
||||
onOpenPromptTemplatesClicked: () -> Unit = {},
|
||||
onStopButtonClicked: () -> Unit = {},
|
||||
showPromptTemplatesInMenu: Boolean = true,
|
||||
showPromptTemplatesInMenu: Boolean = false,
|
||||
showImagePickerInMenu: Boolean = false,
|
||||
showStopButtonWhenInProgress: Boolean = false,
|
||||
) {
|
||||
val context = LocalContext.current
|
||||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||
var showAddContentMenu by remember { mutableStateOf(false) }
|
||||
var showTextInputHistorySheet by remember { mutableStateOf(false) }
|
||||
var tempPhotoUri by remember { mutableStateOf(value = Uri.EMPTY) }
|
||||
var pickedImages by remember { mutableStateOf<List<Bitmap>>(listOf()) }
|
||||
val updatePickedImages: (Bitmap) -> Unit = { bitmap ->
|
||||
val newPickedImages: MutableList<Bitmap> = mutableListOf()
|
||||
newPickedImages.addAll(pickedImages)
|
||||
newPickedImages.add(bitmap)
|
||||
pickedImages = newPickedImages.toList()
|
||||
}
|
||||
|
||||
// launches camera
|
||||
val cameraLauncher =
|
||||
rememberLauncherForActivityResult(ActivityResultContracts.TakePicture()) { isImageSaved ->
|
||||
if (isImageSaved) {
|
||||
handleImageSelected(
|
||||
context = context,
|
||||
uri = tempPhotoUri,
|
||||
onImageSelected = { bitmap ->
|
||||
updatePickedImages(bitmap)
|
||||
},
|
||||
rotateForPortrait = true,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Permission request when taking picture.
|
||||
val takePicturePermissionLauncher = rememberLauncherForActivityResult(
|
||||
ActivityResultContracts.RequestPermission()
|
||||
) { permissionGranted ->
|
||||
if (permissionGranted) {
|
||||
showAddContentMenu = false
|
||||
tempPhotoUri = context.createTempPictureUri()
|
||||
cameraLauncher.launch(tempPhotoUri)
|
||||
}
|
||||
}
|
||||
|
||||
// Registers a photo picker activity launcher in single-select mode.
|
||||
val pickMedia =
|
||||
rememberLauncherForActivityResult(ActivityResultContracts.PickVisualMedia()) { uri ->
|
||||
// Callback is invoked after the user selects a media item or closes the
|
||||
// photo picker.
|
||||
if (uri != null) {
|
||||
handleImageSelected(context = context, uri = uri, onImageSelected = { bitmap ->
|
||||
updatePickedImages(bitmap)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Box(contentAlignment = Alignment.CenterStart) {
|
||||
// A preview panel for the selected image.
|
||||
if (pickedImages.isNotEmpty()) {
|
||||
Box(
|
||||
contentAlignment = Alignment.TopEnd, modifier = Modifier.offset(x = 16.dp, y = (-80).dp)
|
||||
) {
|
||||
Image(
|
||||
bitmap = pickedImages.last().asImageBitmap(),
|
||||
contentDescription = "",
|
||||
modifier = Modifier
|
||||
.height(80.dp)
|
||||
.shadow(2.dp, shape = RoundedCornerShape(8.dp))
|
||||
.clip(RoundedCornerShape(8.dp))
|
||||
.border(1.dp, MaterialTheme.colorScheme.outlineVariant, RoundedCornerShape(8.dp)),
|
||||
)
|
||||
Box(modifier = Modifier
|
||||
.offset(x = 10.dp, y = (-10).dp)
|
||||
.clip(CircleShape)
|
||||
.background(MaterialTheme.colorScheme.surface)
|
||||
.border((1.5).dp, MaterialTheme.colorScheme.outline, CircleShape)
|
||||
.clickable {
|
||||
pickedImages = listOf()
|
||||
}) {
|
||||
Icon(
|
||||
Icons.Rounded.Close,
|
||||
contentDescription = "",
|
||||
modifier = Modifier
|
||||
.padding(3.dp)
|
||||
.size(16.dp)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A plus button to show a popup menu to add stuff to the chat.
|
||||
IconButton(
|
||||
enabled = !inProgress,
|
||||
enabled = !inProgress && !isResettingSession,
|
||||
onClick = { showAddContentMenu = true },
|
||||
modifier = Modifier
|
||||
.offset(x = 16.dp)
|
||||
|
@ -113,6 +218,58 @@ fun MessageInputText(
|
|||
DropdownMenu(
|
||||
expanded = showAddContentMenu,
|
||||
onDismissRequest = { showAddContentMenu = false }) {
|
||||
if (showImagePickerInMenu) {
|
||||
// Take a photo.
|
||||
DropdownMenuItem(
|
||||
text = {
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(6.dp)
|
||||
) {
|
||||
Icon(Icons.Rounded.PhotoCamera, contentDescription = "")
|
||||
Text("Take a photo")
|
||||
}
|
||||
},
|
||||
enabled = pickedImages.isEmpty() && !hasImageMessage,
|
||||
onClick = {
|
||||
// Check permission
|
||||
when (PackageManager.PERMISSION_GRANTED) {
|
||||
// Already got permission. Call the lambda.
|
||||
ContextCompat.checkSelfPermission(
|
||||
context, Manifest.permission.CAMERA
|
||||
) -> {
|
||||
showAddContentMenu = false
|
||||
tempPhotoUri = context.createTempPictureUri()
|
||||
cameraLauncher.launch(tempPhotoUri)
|
||||
}
|
||||
|
||||
// Otherwise, ask for permission
|
||||
else -> {
|
||||
takePicturePermissionLauncher.launch(Manifest.permission.CAMERA)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Pick an image from album.
|
||||
DropdownMenuItem(
|
||||
text = {
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(6.dp)
|
||||
) {
|
||||
Icon(Icons.Rounded.Photo, contentDescription = "")
|
||||
Text("Pick from album")
|
||||
}
|
||||
},
|
||||
enabled = pickedImages.isEmpty() && !hasImageMessage,
|
||||
onClick = {
|
||||
// Launch the photo picker and let the user choose only images.
|
||||
pickMedia.launch(PickVisualMediaRequest(ActivityResultContracts.PickVisualMedia.ImageOnly))
|
||||
showAddContentMenu = false
|
||||
})
|
||||
}
|
||||
|
||||
// Prompt templates.
|
||||
if (showPromptTemplatesInMenu) {
|
||||
DropdownMenuItem(text = {
|
||||
Row(
|
||||
|
@ -127,6 +284,7 @@ fun MessageInputText(
|
|||
showAddContentMenu = false
|
||||
})
|
||||
}
|
||||
// Prompt history.
|
||||
DropdownMenuItem(text = {
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
|
@ -171,18 +329,19 @@ fun MessageInputText(
|
|||
),
|
||||
) {
|
||||
Icon(
|
||||
Icons.Rounded.Stop,
|
||||
contentDescription = "",
|
||||
tint = MaterialTheme.colorScheme.primary
|
||||
Icons.Rounded.Stop, contentDescription = "", tint = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
}
|
||||
}
|
||||
} // Send button. Only shown when text is not empty.
|
||||
else if (curMessage.isNotEmpty()) {
|
||||
IconButton(
|
||||
enabled = !inProgress,
|
||||
enabled = !inProgress && !isResettingSession,
|
||||
onClick = {
|
||||
onSendMessage(ChatMessageText(content = curMessage.trim(), side = ChatSide.USER))
|
||||
onSendMessage(
|
||||
createMessagesToSend(pickedImages = pickedImages, text = curMessage.trim())
|
||||
)
|
||||
pickedImages = listOf()
|
||||
},
|
||||
colors = IconButtonDefaults.iconButtonColors(
|
||||
containerColor = MaterialTheme.colorScheme.secondaryContainer,
|
||||
|
@ -200,7 +359,6 @@ fun MessageInputText(
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// A bottom sheet to show the text input history to pick from.
|
||||
if (showTextInputHistorySheet) {
|
||||
TextInputHistorySheet(
|
||||
|
@ -209,7 +367,8 @@ fun MessageInputText(
|
|||
showTextInputHistorySheet = false
|
||||
},
|
||||
onHistoryItemClicked = { item ->
|
||||
onSendMessage(ChatMessageText(content = item, side = ChatSide.USER))
|
||||
onSendMessage(createMessagesToSend(pickedImages = pickedImages, text = item))
|
||||
pickedImages = listOf()
|
||||
modelManagerViewModel.promoteTextInputHistoryItem(item)
|
||||
},
|
||||
onHistoryItemDeleted = { item ->
|
||||
|
@ -217,9 +376,55 @@ fun MessageInputText(
|
|||
},
|
||||
onHistoryItemsDeleteAll = {
|
||||
modelManagerViewModel.clearTextInputHistory()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
private fun handleImageSelected(
|
||||
context: Context,
|
||||
uri: Uri,
|
||||
onImageSelected: (Bitmap) -> Unit,
|
||||
// For some reason, some Android phone would store the picture taken by the camera rotated
|
||||
// horizontally. Use this flag to rotate the image back to portrait if the picture's width
|
||||
// is bigger than height.
|
||||
rotateForPortrait: Boolean = false,
|
||||
) {
|
||||
val bitmap: Bitmap? = try {
|
||||
val inputStream = context.contentResolver.openInputStream(uri)
|
||||
val tmpBitmap = BitmapFactory.decodeStream(inputStream)
|
||||
if (rotateForPortrait && tmpBitmap.width > tmpBitmap.height) {
|
||||
val matrix = Matrix()
|
||||
matrix.postRotate(90f)
|
||||
Bitmap.createBitmap(tmpBitmap, 0, 0, tmpBitmap.width, tmpBitmap.height, matrix, true)
|
||||
} else {
|
||||
tmpBitmap
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
e.printStackTrace()
|
||||
null
|
||||
}
|
||||
if (bitmap != null) {
|
||||
onImageSelected(bitmap)
|
||||
}
|
||||
}
|
||||
|
||||
private fun createMessagesToSend(pickedImages: List<Bitmap>, text: String): List<ChatMessage> {
|
||||
val messages: MutableList<ChatMessage> = mutableListOf()
|
||||
if (pickedImages.isNotEmpty()) {
|
||||
val lastImage = pickedImages.last()
|
||||
messages.add(
|
||||
ChatMessageImage(
|
||||
bitmap = lastImage, imageBitMap = lastImage.asImageBitmap(), side = ChatSide.USER
|
||||
)
|
||||
)
|
||||
}
|
||||
messages.add(
|
||||
ChatMessageText(
|
||||
content = text, side = ChatSide.USER
|
||||
)
|
||||
)
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
@Preview(showBackground = true)
|
||||
|
@ -233,6 +438,21 @@ fun MessageInputTextPreview() {
|
|||
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
||||
curMessage = "hello",
|
||||
inProgress = false,
|
||||
isResettingSession = false,
|
||||
modelInitializing = false,
|
||||
hasImageMessage = false,
|
||||
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
|
||||
onValueChanged = {},
|
||||
onSendMessage = {},
|
||||
showStopButtonWhenInProgress = true,
|
||||
showImagePickerInMenu = true,
|
||||
)
|
||||
MessageInputText(
|
||||
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
||||
curMessage = "hello",
|
||||
inProgress = false,
|
||||
isResettingSession = false,
|
||||
hasImageMessage = false,
|
||||
modelInitializing = false,
|
||||
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
|
||||
onValueChanged = {},
|
||||
|
@ -243,6 +463,8 @@ fun MessageInputTextPreview() {
|
|||
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
||||
curMessage = "hello",
|
||||
inProgress = true,
|
||||
isResettingSession = false,
|
||||
hasImageMessage = false,
|
||||
modelInitializing = false,
|
||||
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
|
||||
onValueChanged = {},
|
||||
|
@ -252,6 +474,8 @@ fun MessageInputTextPreview() {
|
|||
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
||||
curMessage = "",
|
||||
inProgress = false,
|
||||
isResettingSession = false,
|
||||
hasImageMessage = false,
|
||||
modelInitializing = false,
|
||||
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
|
||||
onValueChanged = {},
|
||||
|
@ -261,6 +485,8 @@ fun MessageInputTextPreview() {
|
|||
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
||||
curMessage = "",
|
||||
inProgress = true,
|
||||
isResettingSession = false,
|
||||
hasImageMessage = false,
|
||||
modelInitializing = false,
|
||||
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
|
||||
onValueChanged = {},
|
||||
|
|
|
@ -52,6 +52,7 @@ import androidx.compose.ui.Modifier
|
|||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.res.stringResource
|
||||
import androidx.compose.ui.text.style.TextAlign
|
||||
import androidx.compose.ui.text.style.TextOverflow
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
import androidx.compose.ui.unit.dp
|
||||
import com.google.aiedge.gallery.R
|
||||
|
@ -149,6 +150,8 @@ private fun SheetContent(
|
|||
Text(
|
||||
item,
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
maxLines = 3,
|
||||
overflow = TextOverflow.Ellipsis,
|
||||
modifier = Modifier
|
||||
.padding(vertical = 16.dp)
|
||||
.padding(start = 16.dp)
|
||||
|
|
|
@ -76,7 +76,7 @@ fun ModelNameAndStatus(
|
|||
Text(
|
||||
model.name,
|
||||
maxLines = 1,
|
||||
overflow = TextOverflow.Ellipsis,
|
||||
overflow = TextOverflow.MiddleEllipsis,
|
||||
style = MaterialTheme.typography.titleMedium,
|
||||
modifier = modifier,
|
||||
)
|
||||
|
|
|
@ -136,7 +136,7 @@ fun HomeScreen(
|
|||
val snackbarHostState = remember { SnackbarHostState() }
|
||||
val scope = rememberCoroutineScope()
|
||||
|
||||
val tasks = uiState.tasks
|
||||
val nonEmptyTasks = uiState.tasks.filter { it.models.size > 0 }
|
||||
val loadingHfModels = uiState.loadingHfModels
|
||||
|
||||
val filePickerLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult(
|
||||
|
@ -183,14 +183,14 @@ fun HomeScreen(
|
|||
) { innerPadding ->
|
||||
Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.fillMaxSize()) {
|
||||
TaskList(
|
||||
tasks = tasks,
|
||||
tasks = nonEmptyTasks,
|
||||
navigateToTaskScreen = navigateToTaskScreen,
|
||||
loadingModelAllowlist = uiState.loadingModelAllowlist,
|
||||
modifier = Modifier.fillMaxSize(),
|
||||
contentPadding = innerPadding,
|
||||
)
|
||||
|
||||
SnackbarHost(hostState = snackbarHostState, modifier = Modifier.padding(bottom = 16.dp))
|
||||
SnackbarHost(hostState = snackbarHostState, modifier = Modifier.padding(bottom = 32.dp))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -285,7 +285,7 @@ fun HomeScreen(
|
|||
|
||||
// Show a snack bar for successful import.
|
||||
scope.launch {
|
||||
snackbarHostState.showSnackbar("✅ Model imported successfully")
|
||||
snackbarHostState.showSnackbar("Model imported successfully")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -316,7 +316,6 @@ fun HomeScreen(
|
|||
}
|
||||
},
|
||||
)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -55,6 +55,7 @@ import androidx.compose.ui.unit.dp
|
|||
import androidx.compose.ui.window.Dialog
|
||||
import androidx.compose.ui.window.DialogProperties
|
||||
import com.google.aiedge.gallery.data.Accelerator
|
||||
import com.google.aiedge.gallery.data.BooleanSwitchConfig
|
||||
import com.google.aiedge.gallery.data.Config
|
||||
import com.google.aiedge.gallery.data.ConfigKey
|
||||
import com.google.aiedge.gallery.data.IMPORTS_DIR
|
||||
|
@ -111,6 +112,10 @@ private val IMPORT_CONFIGS_LLM: List<Config> = listOf(
|
|||
defaultValue = DEFAULT_TEMPERATURE,
|
||||
valueType = ValueType.FLOAT
|
||||
),
|
||||
BooleanSwitchConfig(
|
||||
key = ConfigKey.SUPPORT_IMAGE,
|
||||
defaultValue = false,
|
||||
),
|
||||
SegmentedButtonConfig(
|
||||
key = ConfigKey.COMPATIBLE_ACCELERATORS,
|
||||
defaultValue = Accelerator.CPU.label,
|
||||
|
|
|
@ -50,7 +50,8 @@ fun ImageClassificationScreen(
|
|||
task = viewModel.task,
|
||||
viewModel = viewModel,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
onSendMessage = { model, message ->
|
||||
onSendMessage = { model, messages ->
|
||||
val message = messages[0]
|
||||
viewModel.addMessage(
|
||||
model = model,
|
||||
message = message,
|
||||
|
|
|
@ -44,7 +44,8 @@ fun ImageGenerationScreen(
|
|||
task = viewModel.task,
|
||||
viewModel = viewModel,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
onSendMessage = { model, message ->
|
||||
onSendMessage = { model, messages ->
|
||||
val message = messages[0]
|
||||
viewModel.addMessage(
|
||||
model = model,
|
||||
message = message,
|
||||
|
|
|
@ -17,11 +17,14 @@
|
|||
package com.google.aiedge.gallery.ui.llmchat
|
||||
|
||||
import android.content.Context
|
||||
import android.graphics.Bitmap
|
||||
import android.util.Log
|
||||
import com.google.aiedge.gallery.data.Accelerator
|
||||
import com.google.aiedge.gallery.data.ConfigKey
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage
|
||||
import com.google.mediapipe.framework.image.BitmapImageBuilder
|
||||
import com.google.mediapipe.tasks.genai.llminference.GraphOptions
|
||||
import com.google.mediapipe.tasks.genai.llminference.LlmInference
|
||||
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
|
||||
|
||||
|
@ -55,7 +58,9 @@ object LlmChatModelHelper {
|
|||
}
|
||||
val options =
|
||||
LlmInference.LlmInferenceOptions.builder().setModelPath(model.getPath(context = context))
|
||||
.setMaxTokens(maxTokens).setPreferredBackend(preferredBackend).build()
|
||||
.setMaxTokens(maxTokens).setPreferredBackend(preferredBackend)
|
||||
.setMaxNumImages(if (model.llmSupportImage) 1 else 0)
|
||||
.build()
|
||||
|
||||
// Create an instance of the LLM Inference task
|
||||
try {
|
||||
|
@ -64,7 +69,10 @@ object LlmChatModelHelper {
|
|||
val session = LlmInferenceSession.createFromOptions(
|
||||
llmInference,
|
||||
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
|
||||
.setTemperature(temperature).build()
|
||||
.setTemperature(temperature)
|
||||
.setGraphOptions(
|
||||
GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
|
||||
).build()
|
||||
)
|
||||
model.instance = LlmModelInstance(engine = llmInference, session = session)
|
||||
} catch (e: Exception) {
|
||||
|
@ -75,6 +83,8 @@ object LlmChatModelHelper {
|
|||
}
|
||||
|
||||
fun resetSession(model: Model) {
|
||||
Log.d(TAG, "Resetting session for model '${model.name}'")
|
||||
|
||||
val instance = model.instance as LlmModelInstance? ?: return
|
||||
val session = instance.session
|
||||
session.close()
|
||||
|
@ -87,9 +97,13 @@ object LlmChatModelHelper {
|
|||
val newSession = LlmInferenceSession.createFromOptions(
|
||||
inference,
|
||||
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
|
||||
.setTemperature(temperature).build()
|
||||
.setTemperature(temperature)
|
||||
.setGraphOptions(
|
||||
GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
|
||||
).build()
|
||||
)
|
||||
instance.session = newSession
|
||||
Log.d(TAG, "Resetting done")
|
||||
}
|
||||
|
||||
fun cleanUp(model: Model) {
|
||||
|
@ -118,6 +132,7 @@ object LlmChatModelHelper {
|
|||
resultListener: ResultListener,
|
||||
cleanUpListener: CleanUpListener,
|
||||
singleTurn: Boolean = false,
|
||||
image: Bitmap? = null,
|
||||
) {
|
||||
if (singleTurn) {
|
||||
resetSession(model = model)
|
||||
|
@ -132,6 +147,9 @@ object LlmChatModelHelper {
|
|||
// Start async inference.
|
||||
val session = instance.session
|
||||
session.addQueryChunk(input)
|
||||
if (image != null) {
|
||||
session.addImage(BitmapImageBuilder(image).build())
|
||||
}
|
||||
session.generateResponseAsync(resultListener)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,15 +16,14 @@
|
|||
|
||||
package com.google.aiedge.gallery.ui.llmchat
|
||||
|
||||
import android.util.Log
|
||||
import android.graphics.Bitmap
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.lifecycle.viewmodel.compose.viewModel
|
||||
import com.google.aiedge.gallery.ui.ViewModelProvider
|
||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageInfo
|
||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageImage
|
||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageText
|
||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageWarning
|
||||
import com.google.aiedge.gallery.ui.common.chat.ChatView
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
import kotlinx.serialization.Serializable
|
||||
|
@ -35,6 +34,11 @@ object LlmChatDestination {
|
|||
val route = "LlmChatRoute"
|
||||
}
|
||||
|
||||
object LlmImageToTextDestination {
|
||||
@Serializable
|
||||
val route = "LlmImageToTextRoute"
|
||||
}
|
||||
|
||||
@Composable
|
||||
fun LlmChatScreen(
|
||||
modelManagerViewModel: ModelManagerViewModel,
|
||||
|
@ -43,6 +47,38 @@ fun LlmChatScreen(
|
|||
viewModel: LlmChatViewModel = viewModel(
|
||||
factory = ViewModelProvider.Factory
|
||||
),
|
||||
) {
|
||||
ChatViewWrapper(
|
||||
viewModel = viewModel,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
navigateUp = navigateUp,
|
||||
modifier = modifier,
|
||||
)
|
||||
}
|
||||
|
||||
@Composable
|
||||
fun LlmImageToTextScreen(
|
||||
modelManagerViewModel: ModelManagerViewModel,
|
||||
navigateUp: () -> Unit,
|
||||
modifier: Modifier = Modifier,
|
||||
viewModel: LlmImageToTextViewModel = viewModel(
|
||||
factory = ViewModelProvider.Factory
|
||||
),
|
||||
) {
|
||||
ChatViewWrapper(
|
||||
viewModel = viewModel,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
navigateUp = navigateUp,
|
||||
modifier = modifier,
|
||||
)
|
||||
}
|
||||
|
||||
@Composable
|
||||
fun ChatViewWrapper(
|
||||
viewModel: LlmChatViewModel,
|
||||
modelManagerViewModel: ModelManagerViewModel,
|
||||
navigateUp: () -> Unit,
|
||||
modifier: Modifier = Modifier
|
||||
) {
|
||||
val context = LocalContext.current
|
||||
|
||||
|
@ -50,21 +86,33 @@ fun LlmChatScreen(
|
|||
task = viewModel.task,
|
||||
viewModel = viewModel,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
onSendMessage = { model, message ->
|
||||
viewModel.addMessage(
|
||||
model = model,
|
||||
message = message,
|
||||
)
|
||||
if (message is ChatMessageText) {
|
||||
modelManagerViewModel.addTextInputHistory(message.content)
|
||||
viewModel.generateResponse(model = model, input = message.content, onError = {
|
||||
viewModel.addMessage(
|
||||
model = model,
|
||||
message = ChatMessageWarning(content = "Error occurred. Re-initializing the engine.")
|
||||
)
|
||||
onSendMessage = { model, messages ->
|
||||
for (message in messages) {
|
||||
viewModel.addMessage(
|
||||
model = model,
|
||||
message = message,
|
||||
)
|
||||
}
|
||||
|
||||
modelManagerViewModel.initializeModel(
|
||||
context = context, task = viewModel.task, model = model, force = true
|
||||
var text = ""
|
||||
var image: Bitmap? = null
|
||||
var chatMessageText: ChatMessageText? = null
|
||||
for (message in messages) {
|
||||
if (message is ChatMessageText) {
|
||||
chatMessageText = message
|
||||
text = message.content
|
||||
} else if (message is ChatMessageImage) {
|
||||
image = message.bitmap
|
||||
}
|
||||
}
|
||||
if (text.isNotEmpty() && chatMessageText != null) {
|
||||
modelManagerViewModel.addTextInputHistory(text)
|
||||
viewModel.generateResponse(model = model, input = text, image = image, onError = {
|
||||
viewModel.handleError(
|
||||
context = context,
|
||||
model = model,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
triggeredMessage = chatMessageText,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
@ -72,13 +120,11 @@ fun LlmChatScreen(
|
|||
onRunAgainClicked = { model, message ->
|
||||
if (message is ChatMessageText) {
|
||||
viewModel.runAgain(model = model, message = message, onError = {
|
||||
viewModel.addMessage(
|
||||
viewModel.handleError(
|
||||
context = context,
|
||||
model = model,
|
||||
message = ChatMessageWarning(content = "Error occurred. Re-initializing the engine.")
|
||||
)
|
||||
|
||||
modelManagerViewModel.initializeModel(
|
||||
context = context, task = viewModel.task, model = model, force = true
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
triggeredMessage = message,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
@ -90,6 +136,9 @@ fun LlmChatScreen(
|
|||
)
|
||||
}
|
||||
},
|
||||
onResetSessionClicked = { model ->
|
||||
viewModel.resetSession(model = model)
|
||||
},
|
||||
showStopButtonInInputWhenInProgress = true,
|
||||
onStopButtonClicked = { model ->
|
||||
viewModel.stopResponse(model = model)
|
||||
|
@ -98,4 +147,3 @@ fun LlmChatScreen(
|
|||
modifier = modifier,
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
@ -16,18 +16,23 @@
|
|||
|
||||
package com.google.aiedge.gallery.ui.llmchat
|
||||
|
||||
import android.content.Context
|
||||
import android.graphics.Bitmap
|
||||
import android.util.Log
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
|
||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageLoading
|
||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageText
|
||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageType
|
||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageWarning
|
||||
import com.google.aiedge.gallery.ui.common.chat.ChatSide
|
||||
import com.google.aiedge.gallery.ui.common.chat.ChatViewModel
|
||||
import com.google.aiedge.gallery.ui.common.chat.Stat
|
||||
import kotlinx.coroutines.CoroutineExceptionHandler
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.launch
|
||||
|
@ -40,8 +45,8 @@ private val STATS = listOf(
|
|||
Stat(id = "latency", label = "Latency", unit = "sec")
|
||||
)
|
||||
|
||||
class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
||||
fun generateResponse(model: Model, input: String, onError: () -> Unit) {
|
||||
open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task = curTask) {
|
||||
fun generateResponse(model: Model, input: String, image: Bitmap? = null, onError: () -> Unit) {
|
||||
viewModelScope.launch(Dispatchers.Default) {
|
||||
setInProgress(true)
|
||||
|
||||
|
@ -58,7 +63,10 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
|||
|
||||
// Run inference.
|
||||
val instance = model.instance as LlmModelInstance
|
||||
val prefillTokens = instance.session.sizeInTokens(input)
|
||||
var prefillTokens = instance.session.sizeInTokens(input)
|
||||
if (image != null) {
|
||||
prefillTokens += 257
|
||||
}
|
||||
|
||||
var firstRun = true
|
||||
var timeToFirstToken = 0f
|
||||
|
@ -69,9 +77,9 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
|||
val start = System.currentTimeMillis()
|
||||
|
||||
try {
|
||||
LlmChatModelHelper.runInference(
|
||||
model = model,
|
||||
LlmChatModelHelper.runInference(model = model,
|
||||
input = input,
|
||||
image = image,
|
||||
resultListener = { partialResult, done ->
|
||||
val curTs = System.currentTimeMillis()
|
||||
|
||||
|
@ -92,32 +100,27 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
|||
|
||||
// Add an empty message that will receive streaming results.
|
||||
addMessage(
|
||||
model = model,
|
||||
message = ChatMessageText(content = "", side = ChatSide.AGENT)
|
||||
model = model, message = ChatMessageText(content = "", side = ChatSide.AGENT)
|
||||
)
|
||||
}
|
||||
|
||||
// Incrementally update the streamed partial results.
|
||||
val latencyMs: Long = if (done) System.currentTimeMillis() - start else -1
|
||||
updateLastTextMessageContentIncrementally(
|
||||
model = model,
|
||||
partialContent = partialResult,
|
||||
latencyMs = latencyMs.toFloat()
|
||||
model = model, partialContent = partialResult, latencyMs = latencyMs.toFloat()
|
||||
)
|
||||
|
||||
if (done) {
|
||||
setInProgress(false)
|
||||
|
||||
decodeSpeed =
|
||||
decodeTokens / ((curTs - firstTokenTs) / 1000f)
|
||||
decodeSpeed = decodeTokens / ((curTs - firstTokenTs) / 1000f)
|
||||
if (decodeSpeed.isNaN()) {
|
||||
decodeSpeed = 0f
|
||||
}
|
||||
|
||||
if (lastMessage is ChatMessageText) {
|
||||
updateLastTextMessageLlmBenchmarkResult(
|
||||
model = model, llmBenchmarkResult =
|
||||
ChatMessageBenchmarkLlmResult(
|
||||
model = model, llmBenchmarkResult = ChatMessageBenchmarkLlmResult(
|
||||
orderedStats = STATS,
|
||||
statValues = mutableMapOf(
|
||||
"prefill_speed" to prefillSpeed,
|
||||
|
@ -131,10 +134,12 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
|||
)
|
||||
}
|
||||
}
|
||||
}, cleanUpListener = {
|
||||
},
|
||||
cleanUpListener = {
|
||||
setInProgress(false)
|
||||
})
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error occurred while running inference", e)
|
||||
setInProgress(false)
|
||||
onError()
|
||||
}
|
||||
|
@ -143,6 +148,9 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
|||
|
||||
fun stopResponse(model: Model) {
|
||||
Log.d(TAG, "Stopping response for model ${model.name}...")
|
||||
if (getLastMessage(model = model) is ChatMessageLoading) {
|
||||
removeLastMessage(model = model)
|
||||
}
|
||||
viewModelScope.launch(Dispatchers.Default) {
|
||||
setInProgress(false)
|
||||
val instance = model.instance as LlmModelInstance
|
||||
|
@ -150,6 +158,25 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
|||
}
|
||||
}
|
||||
|
||||
fun resetSession(model: Model) {
|
||||
viewModelScope.launch(Dispatchers.Default) {
|
||||
setIsResettingSession(true)
|
||||
clearAllMessages(model = model)
|
||||
stopResponse(model = model)
|
||||
|
||||
while (true) {
|
||||
try {
|
||||
LlmChatModelHelper.resetSession(model = model)
|
||||
break
|
||||
} catch (e: Exception) {
|
||||
Log.d(TAG, "Failed to reset session. Trying again")
|
||||
}
|
||||
delay(200)
|
||||
}
|
||||
setIsResettingSession(false)
|
||||
}
|
||||
}
|
||||
|
||||
fun runAgain(model: Model, message: ChatMessageText, onError: () -> Unit) {
|
||||
viewModelScope.launch(Dispatchers.Default) {
|
||||
// Wait for model to be initialized.
|
||||
|
@ -162,9 +189,7 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
|||
|
||||
// Run inference.
|
||||
generateResponse(
|
||||
model = model,
|
||||
input = message.content,
|
||||
onError = onError
|
||||
model = model, input = message.content, onError = onError
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -199,8 +224,7 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
|||
var decodeSpeed: Float
|
||||
val start = System.currentTimeMillis()
|
||||
var lastUpdateTime = 0L
|
||||
LlmChatModelHelper.runInference(
|
||||
model = model,
|
||||
LlmChatModelHelper.runInference(model = model,
|
||||
input = message.content,
|
||||
resultListener = { partialResult, done ->
|
||||
val curTs = System.currentTimeMillis()
|
||||
|
@ -232,14 +256,12 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
|||
result.append(partialResult)
|
||||
|
||||
if (curTs - lastUpdateTime > 500 || done) {
|
||||
decodeSpeed =
|
||||
decodeTokens / ((curTs - firstTokenTs) / 1000f)
|
||||
decodeSpeed = decodeTokens / ((curTs - firstTokenTs) / 1000f)
|
||||
if (decodeSpeed.isNaN()) {
|
||||
decodeSpeed = 0f
|
||||
}
|
||||
replaceLastMessage(
|
||||
model = model,
|
||||
message = ChatMessageBenchmarkLlmResult(
|
||||
model = model, message = ChatMessageBenchmarkLlmResult(
|
||||
orderedStats = STATS,
|
||||
statValues = mutableMapOf(
|
||||
"prefill_speed" to prefillSpeed,
|
||||
|
@ -249,8 +271,7 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
|||
),
|
||||
running = !done,
|
||||
latencyMs = -1f,
|
||||
),
|
||||
type = ChatMessageType.BENCHMARK_LLM_RESULT
|
||||
), type = ChatMessageType.BENCHMARK_LLM_RESULT
|
||||
)
|
||||
lastUpdateTime = curTs
|
||||
|
||||
|
@ -261,9 +282,53 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
|||
},
|
||||
cleanUpListener = {
|
||||
setInProgress(false)
|
||||
}
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fun handleError(
|
||||
context: Context,
|
||||
model: Model,
|
||||
modelManagerViewModel: ModelManagerViewModel,
|
||||
triggeredMessage: ChatMessageText,
|
||||
) {
|
||||
// Clean up.
|
||||
modelManagerViewModel.cleanupModel(task = task, model = model)
|
||||
|
||||
// Remove the "loading" message.
|
||||
if (getLastMessage(model = model) is ChatMessageLoading) {
|
||||
removeLastMessage(model = model)
|
||||
}
|
||||
|
||||
// Remove the last Text message.
|
||||
if (getLastMessage(model = model) == triggeredMessage) {
|
||||
removeLastMessage(model = model)
|
||||
}
|
||||
|
||||
// Add a warning message for re-initializing the session.
|
||||
addMessage(
|
||||
model = model,
|
||||
message = ChatMessageWarning(content = "Error occurred. Re-initializing the session.")
|
||||
)
|
||||
|
||||
// Add the triggered message back.
|
||||
addMessage(model = model, message = triggeredMessage)
|
||||
|
||||
// Re-initialize the session/engine.
|
||||
modelManagerViewModel.initializeModel(
|
||||
context = context, task = task, model = model
|
||||
)
|
||||
|
||||
// Re-generate the response automatically.
|
||||
generateResponse(model = model, input = triggeredMessage.content, onError = {
|
||||
handleError(
|
||||
context = context,
|
||||
model = model,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
triggeredMessage = triggeredMessage
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
class LlmImageToTextViewModel : LlmChatViewModel(curTask = TASK_LLM_IMAGE_TO_TEXT)
|
|
@ -155,9 +155,14 @@ fun LlmSingleTurnScreen(
|
|||
PromptTemplatesPanel(
|
||||
model = selectedModel,
|
||||
viewModel = viewModel,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
onSend = { fullPrompt ->
|
||||
viewModel.generateResponse(model = selectedModel, input = fullPrompt)
|
||||
}, modifier = Modifier.fillMaxSize()
|
||||
},
|
||||
onStopButtonClicked = { model ->
|
||||
viewModel.stopResponse(model = model)
|
||||
},
|
||||
modifier = Modifier.fillMaxSize()
|
||||
)
|
||||
},
|
||||
bottomView = {
|
||||
|
|
|
@ -23,6 +23,7 @@ import com.google.aiedge.gallery.data.Model
|
|||
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
|
||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageLoading
|
||||
import com.google.aiedge.gallery.ui.common.chat.Stat
|
||||
import com.google.aiedge.gallery.ui.common.processLlmResponse
|
||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
|
||||
|
@ -194,6 +195,15 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_USECASES) : ViewMode
|
|||
}
|
||||
}
|
||||
|
||||
fun stopResponse(model: Model) {
|
||||
Log.d(TAG, "Stopping response for model ${model.name}...")
|
||||
viewModelScope.launch(Dispatchers.Default) {
|
||||
setInProgress(false)
|
||||
val instance = model.instance as LlmModelInstance
|
||||
instance.session.cancelGenerateResponseAsync()
|
||||
}
|
||||
}
|
||||
|
||||
private fun createUiState(task: Task): LlmSingleTurnUiState {
|
||||
val responsesByModel: MutableMap<String, Map<String, String>> = mutableMapOf()
|
||||
val benchmarkByModel: MutableMap<String, Map<String, ChatMessageBenchmarkLlmResult>> =
|
||||
|
|
|
@ -43,11 +43,13 @@ import androidx.compose.material.icons.outlined.Description
|
|||
import androidx.compose.material.icons.outlined.ExpandLess
|
||||
import androidx.compose.material.icons.outlined.ExpandMore
|
||||
import androidx.compose.material.icons.rounded.Add
|
||||
import androidx.compose.material.icons.rounded.Stop
|
||||
import androidx.compose.material.icons.rounded.Visibility
|
||||
import androidx.compose.material.icons.rounded.VisibilityOff
|
||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
import androidx.compose.material3.FilterChipDefaults
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.IconButton
|
||||
import androidx.compose.material3.IconButtonDefaults
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.ModalBottomSheet
|
||||
|
@ -86,9 +88,11 @@ import androidx.compose.ui.unit.dp
|
|||
import com.google.aiedge.gallery.R
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.ui.common.chat.MessageBubbleShape
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.theme.customColors
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.launch
|
||||
|
||||
private val promptTemplateTypes: List<PromptTemplateType> = PromptTemplateType.entries
|
||||
private val TAB_TITLES = PromptTemplateType.entries.map { it.label }
|
||||
|
@ -101,11 +105,14 @@ const val FULL_PROMPT_SWITCH_KEY = "full_prompt"
|
|||
fun PromptTemplatesPanel(
|
||||
model: Model,
|
||||
viewModel: LlmSingleTurnViewModel,
|
||||
modelManagerViewModel: ModelManagerViewModel,
|
||||
onSend: (fullPrompt: String) -> Unit,
|
||||
onStopButtonClicked: (Model) -> Unit,
|
||||
modifier: Modifier = Modifier
|
||||
) {
|
||||
val scope = rememberCoroutineScope()
|
||||
val uiState by viewModel.uiState.collectAsState()
|
||||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||
val selectedPromptTemplateType = uiState.selectedPromptTemplateType
|
||||
val inProgress = uiState.inProgress
|
||||
var selectedTabIndex by remember { mutableIntStateOf(0) }
|
||||
|
@ -123,6 +130,8 @@ fun PromptTemplatesPanel(
|
|||
val focusManager = LocalFocusManager.current
|
||||
val interactionSource = remember { MutableInteractionSource() }
|
||||
val expandedStates = remember { mutableStateMapOf<String, Boolean>() }
|
||||
val modelInitializationStatus =
|
||||
modelManagerUiState.modelInitializationStatus[model.name]
|
||||
|
||||
// Update input editor values when prompt template changes.
|
||||
LaunchedEffect(selectedPromptTemplateType) {
|
||||
|
@ -328,29 +337,49 @@ fun PromptTemplatesPanel(
|
|||
)
|
||||
}
|
||||
|
||||
// Send button
|
||||
OutlinedIconButton(
|
||||
enabled = !inProgress && curTextInputContent.isNotEmpty(),
|
||||
onClick = {
|
||||
focusManager.clearFocus()
|
||||
onSend(fullPrompt.text)
|
||||
},
|
||||
colors = IconButtonDefaults.iconButtonColors(
|
||||
containerColor = MaterialTheme.colorScheme.secondaryContainer,
|
||||
disabledContainerColor = MaterialTheme.colorScheme.secondaryContainer.copy(alpha = 0.3f),
|
||||
contentColor = MaterialTheme.colorScheme.primary,
|
||||
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.1f),
|
||||
),
|
||||
border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface),
|
||||
modifier = Modifier.size(ICON_BUTTON_SIZE)
|
||||
) {
|
||||
Icon(
|
||||
Icons.AutoMirrored.Rounded.Send,
|
||||
contentDescription = "",
|
||||
modifier = Modifier
|
||||
.size(20.dp)
|
||||
.offset(x = 2.dp),
|
||||
)
|
||||
val modelInitializing =
|
||||
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING
|
||||
if (inProgress && !modelInitializing) {
|
||||
IconButton(
|
||||
onClick = {
|
||||
onStopButtonClicked(model)
|
||||
},
|
||||
colors = IconButtonDefaults.iconButtonColors(
|
||||
containerColor = MaterialTheme.colorScheme.secondaryContainer,
|
||||
),
|
||||
modifier = Modifier.size(ICON_BUTTON_SIZE)
|
||||
) {
|
||||
Icon(
|
||||
Icons.Rounded.Stop,
|
||||
contentDescription = "",
|
||||
tint = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
}
|
||||
} else {
|
||||
// Send button
|
||||
OutlinedIconButton(
|
||||
enabled = !inProgress && curTextInputContent.isNotEmpty(),
|
||||
onClick = {
|
||||
focusManager.clearFocus()
|
||||
onSend(fullPrompt.text)
|
||||
},
|
||||
colors = IconButtonDefaults.iconButtonColors(
|
||||
containerColor = MaterialTheme.colorScheme.secondaryContainer,
|
||||
disabledContainerColor = MaterialTheme.colorScheme.secondaryContainer.copy(alpha = 0.3f),
|
||||
contentColor = MaterialTheme.colorScheme.primary,
|
||||
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.1f),
|
||||
),
|
||||
border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface),
|
||||
modifier = Modifier.size(ICON_BUTTON_SIZE)
|
||||
) {
|
||||
Icon(
|
||||
Icons.AutoMirrored.Rounded.Send,
|
||||
contentDescription = "",
|
||||
modifier = Modifier
|
||||
.size(20.dp)
|
||||
.offset(x = 2.dp),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,10 @@ import androidx.compose.foundation.layout.fillMaxSize
|
|||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
import androidx.compose.material3.Scaffold
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.derivedStateOf
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
|
@ -48,6 +52,24 @@ fun ModelManager(
|
|||
if (task.models.size != 1) {
|
||||
title += "s"
|
||||
}
|
||||
// Model count.
|
||||
val modelCount by remember {
|
||||
derivedStateOf {
|
||||
val trigger = task.updateTrigger.value
|
||||
if (trigger >= 0) {
|
||||
task.models.size
|
||||
} else {
|
||||
-1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Navigate up when there are no models left.
|
||||
LaunchedEffect(modelCount) {
|
||||
if (modelCount == 0) {
|
||||
navigateUp()
|
||||
}
|
||||
}
|
||||
|
||||
// Handle system's edge swipe.
|
||||
BackHandler {
|
||||
|
|
|
@ -38,6 +38,7 @@ import com.google.aiedge.gallery.data.ModelDownloadStatus
|
|||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||
import com.google.aiedge.gallery.data.TASKS
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.data.TaskType
|
||||
|
@ -58,6 +59,7 @@ import kotlinx.coroutines.flow.asStateFlow
|
|||
import kotlinx.coroutines.flow.update
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.withContext
|
||||
import kotlinx.serialization.ExperimentalSerializationApi
|
||||
import kotlinx.serialization.json.Json
|
||||
import net.openid.appauth.AuthorizationException
|
||||
import net.openid.appauth.AuthorizationRequest
|
||||
|
@ -229,7 +231,10 @@ open class ModelManagerViewModel(
|
|||
}
|
||||
dataStoreRepository.saveImportedModels(importedModels = importedModels)
|
||||
}
|
||||
val newUiState = uiState.value.copy(modelDownloadStatus = curModelDownloadStatus)
|
||||
val newUiState = uiState.value.copy(
|
||||
modelDownloadStatus = curModelDownloadStatus,
|
||||
tasks = uiState.value.tasks.toList()
|
||||
)
|
||||
_uiState.update { newUiState }
|
||||
}
|
||||
|
||||
|
@ -312,6 +317,12 @@ open class ModelManagerViewModel(
|
|||
onDone = onDone,
|
||||
)
|
||||
|
||||
TaskType.LLM_IMAGE_TO_TEXT -> LlmChatModelHelper.initialize(
|
||||
context = context,
|
||||
model = model,
|
||||
onDone = onDone,
|
||||
)
|
||||
|
||||
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.initialize(
|
||||
context = context, model = model, onDone = onDone
|
||||
)
|
||||
|
@ -331,6 +342,7 @@ open class ModelManagerViewModel(
|
|||
TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.cleanUp(model = model)
|
||||
TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model)
|
||||
TaskType.LLM_USECASES -> LlmChatModelHelper.cleanUp(model = model)
|
||||
TaskType.LLM_IMAGE_TO_TEXT -> LlmChatModelHelper.cleanUp(model = model)
|
||||
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.cleanUp(model = model)
|
||||
TaskType.TEST_TASK_1 -> {}
|
||||
TaskType.TEST_TASK_2 -> {}
|
||||
|
@ -434,14 +446,16 @@ open class ModelManagerViewModel(
|
|||
// Create model.
|
||||
val model = createModelFromImportedModelInfo(info = info)
|
||||
|
||||
// Remove duplicated imported model if existed.
|
||||
for (task in listOf(TASK_LLM_CHAT, TASK_LLM_USECASES)) {
|
||||
for (task in listOf(TASK_LLM_CHAT, TASK_LLM_USECASES, TASK_LLM_IMAGE_TO_TEXT)) {
|
||||
// Remove duplicated imported model if existed.
|
||||
val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported }
|
||||
if (modelIndex >= 0) {
|
||||
Log.d(TAG, "duplicated imported model found in task. Removing it first")
|
||||
task.models.removeAt(modelIndex)
|
||||
}
|
||||
task.models.add(model)
|
||||
if (task == TASK_LLM_IMAGE_TO_TEXT && model.llmSupportImage || task != TASK_LLM_IMAGE_TO_TEXT) {
|
||||
task.models.add(model)
|
||||
}
|
||||
task.updateTrigger.value = System.currentTimeMillis()
|
||||
}
|
||||
|
||||
|
@ -632,8 +646,7 @@ open class ModelManagerViewModel(
|
|||
fun loadModelAllowlist() {
|
||||
_uiState.update {
|
||||
uiState.value.copy(
|
||||
loadingModelAllowlist = true,
|
||||
loadingModelAllowlistError = ""
|
||||
loadingModelAllowlist = true, loadingModelAllowlistError = ""
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -663,6 +676,9 @@ open class ModelManagerViewModel(
|
|||
if (allowedModel.taskTypes.contains(TASK_LLM_USECASES.type.id)) {
|
||||
TASK_LLM_USECASES.models.add(model)
|
||||
}
|
||||
if (allowedModel.taskTypes.contains(TASK_LLM_IMAGE_TO_TEXT.type.id)) {
|
||||
TASK_LLM_IMAGE_TO_TEXT.models.add(model)
|
||||
}
|
||||
}
|
||||
|
||||
// Pre-process all tasks.
|
||||
|
@ -717,6 +733,9 @@ open class ModelManagerViewModel(
|
|||
// Add to task.
|
||||
TASK_LLM_CHAT.models.add(model)
|
||||
TASK_LLM_USECASES.models.add(model)
|
||||
if (model.llmSupportImage) {
|
||||
TASK_LLM_IMAGE_TO_TEXT.models.add(model)
|
||||
}
|
||||
|
||||
// Update status.
|
||||
modelDownloadStatus[model.name] = ModelDownloadStatus(
|
||||
|
@ -731,7 +750,7 @@ open class ModelManagerViewModel(
|
|||
|
||||
Log.d(TAG, "model download status: $modelDownloadStatus")
|
||||
return ModelManagerUiState(
|
||||
tasks = TASKS,
|
||||
tasks = TASKS.toList(),
|
||||
modelDownloadStatus = modelDownloadStatus,
|
||||
modelInitializationStatus = modelInstances,
|
||||
textInputHistory = textInputHistory,
|
||||
|
@ -763,6 +782,9 @@ open class ModelManagerViewModel(
|
|||
) as Float,
|
||||
accelerators = accelerators,
|
||||
)
|
||||
val llmSupportImage = convertValueToTargetType(
|
||||
info.defaultValues[ConfigKey.SUPPORT_IMAGE.label] ?: false, ValueType.BOOLEAN
|
||||
) as Boolean
|
||||
val model = Model(
|
||||
name = info.fileName,
|
||||
url = "",
|
||||
|
@ -770,7 +792,9 @@ open class ModelManagerViewModel(
|
|||
sizeInBytes = info.fileSize,
|
||||
downloadFileName = "$IMPORTS_DIR/${info.fileName}",
|
||||
showBenchmarkButton = false,
|
||||
showRunAgainButton = false,
|
||||
imported = true,
|
||||
llmSupportImage = llmSupportImage,
|
||||
)
|
||||
model.preProcess()
|
||||
|
||||
|
@ -803,6 +827,7 @@ open class ModelManagerViewModel(
|
|||
)
|
||||
}
|
||||
|
||||
@OptIn(ExperimentalSerializationApi::class)
|
||||
private inline fun <reified T> getJsonResponse(url: String): T? {
|
||||
try {
|
||||
val connection = URL(url).openConnection() as HttpURLConnection
|
||||
|
@ -815,7 +840,12 @@ open class ModelManagerViewModel(
|
|||
val response = inputStream.bufferedReader().use { it.readText() }
|
||||
|
||||
// Parse JSON using kotlinx.serialization
|
||||
val json = Json { ignoreUnknownKeys = true } // Handle potential extra fields
|
||||
val json = Json {
|
||||
// Handle potential extra fields
|
||||
ignoreUnknownKeys = true
|
||||
allowComments = true
|
||||
allowTrailingComma = true
|
||||
}
|
||||
val jsonObj = json.decodeFromString<T>(response)
|
||||
return jsonObj
|
||||
} else {
|
||||
|
|
|
@ -46,6 +46,7 @@ import com.google.aiedge.gallery.data.Model
|
|||
import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION
|
||||
import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
|
||||
import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
|
@ -59,6 +60,8 @@ import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationDestination
|
|||
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationScreen
|
||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatDestination
|
||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatScreen
|
||||
import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextDestination
|
||||
import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextScreen
|
||||
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnDestination
|
||||
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnScreen
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManager
|
||||
|
@ -129,8 +132,7 @@ fun GalleryNavHost(
|
|||
) {
|
||||
val curPickedTask = pickedTask
|
||||
if (curPickedTask != null) {
|
||||
ModelManager(
|
||||
viewModel = modelManagerViewModel,
|
||||
ModelManager(viewModel = modelManagerViewModel,
|
||||
task = curPickedTask,
|
||||
onModelClicked = { model ->
|
||||
navigateToTaskScreen(
|
||||
|
@ -207,7 +209,7 @@ fun GalleryNavHost(
|
|||
}
|
||||
}
|
||||
|
||||
// LLMm chat demos.
|
||||
// LLM chat demos.
|
||||
composable(
|
||||
route = "${LlmChatDestination.route}/{modelName}",
|
||||
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
|
||||
|
@ -224,7 +226,7 @@ fun GalleryNavHost(
|
|||
}
|
||||
}
|
||||
|
||||
// LLMm single turn.
|
||||
// LLM single turn.
|
||||
composable(
|
||||
route = "${LlmSingleTurnDestination.route}/{modelName}",
|
||||
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
|
||||
|
@ -241,6 +243,22 @@ fun GalleryNavHost(
|
|||
}
|
||||
}
|
||||
|
||||
// LLM image to text.
|
||||
composable(
|
||||
route = "${LlmImageToTextDestination.route}/{modelName}",
|
||||
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
|
||||
enterTransition = { slideEnter() },
|
||||
exitTransition = { slideExit() },
|
||||
) {
|
||||
getModelFromNavigationParam(it, TASK_LLM_IMAGE_TO_TEXT)?.let { defaultModel ->
|
||||
modelManagerViewModel.selectModel(defaultModel)
|
||||
|
||||
LlmImageToTextScreen(
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
navigateUp = { navController.navigateUp() },
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle incoming intents for deep links
|
||||
|
@ -254,9 +272,7 @@ fun GalleryNavHost(
|
|||
getModelByName(modelName)?.let { model ->
|
||||
// TODO(jingjin): need to show a list of possible tasks for this model.
|
||||
navigateToTaskScreen(
|
||||
navController = navController,
|
||||
taskType = TaskType.LLM_CHAT,
|
||||
model = model
|
||||
navController = navController, taskType = TaskType.LLM_CHAT, model = model
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -271,6 +287,7 @@ fun navigateToTaskScreen(
|
|||
TaskType.TEXT_CLASSIFICATION -> navController.navigate("${TextClassificationDestination.route}/${modelName}")
|
||||
TaskType.IMAGE_CLASSIFICATION -> navController.navigate("${ImageClassificationDestination.route}/${modelName}")
|
||||
TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
|
||||
TaskType.LLM_IMAGE_TO_TEXT -> navController.navigate("${LlmImageToTextDestination.route}/${modelName}")
|
||||
TaskType.LLM_USECASES -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
|
||||
TaskType.IMAGE_GENERATION -> navController.navigate("${ImageGenerationDestination.route}/${modelName}")
|
||||
TaskType.TEST_TASK_1 -> {}
|
||||
|
|
|
@ -44,7 +44,8 @@ fun TextClassificationScreen(
|
|||
task = viewModel.task,
|
||||
viewModel = viewModel,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
onSendMessage = { model, message ->
|
||||
onSendMessage = { model, messages ->
|
||||
val message = messages[0]
|
||||
viewModel.addMessage(
|
||||
model = model,
|
||||
message = message,
|
||||
|
|
|
@ -7,7 +7,7 @@ junitVersion = "1.2.1"
|
|||
espressoCore = "3.6.1"
|
||||
lifecycleRuntimeKtx = "2.8.7"
|
||||
activityCompose = "1.10.1"
|
||||
composeBom = "2025.03.01"
|
||||
composeBom = "2025.05.00"
|
||||
navigation = "2.8.9"
|
||||
serializationPlugin = "2.0.21"
|
||||
serializationJson = "1.7.3"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue