mirror of
https://github.com/google-ai-edge/gallery.git
synced 2025-07-06 06:30:30 -04:00
Add support for image to text models
This commit is contained in:
parent
ef290cd7b0
commit
bedc488a15
29 changed files with 763 additions and 231 deletions
|
@ -19,7 +19,6 @@ package com.google.aiedge.gallery.data
|
||||||
import android.content.Context
|
import android.content.Context
|
||||||
import com.google.aiedge.gallery.ui.common.chat.PromptTemplate
|
import com.google.aiedge.gallery.ui.common.chat.PromptTemplate
|
||||||
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
|
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
|
||||||
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
|
|
||||||
import java.io.File
|
import java.io.File
|
||||||
|
|
||||||
data class ModelDataFile(
|
data class ModelDataFile(
|
||||||
|
@ -91,6 +90,9 @@ data class Model(
|
||||||
/** The prompt templates for the model (only for LLM). */
|
/** The prompt templates for the model (only for LLM). */
|
||||||
val llmPromptTemplates: List<PromptTemplate> = listOf(),
|
val llmPromptTemplates: List<PromptTemplate> = listOf(),
|
||||||
|
|
||||||
|
/** Whether the LLM model supports image input. */
|
||||||
|
val llmSupportImage: Boolean = false,
|
||||||
|
|
||||||
/** Whether the model is imported or not. */
|
/** Whether the model is imported or not. */
|
||||||
val imported: Boolean = false,
|
val imported: Boolean = false,
|
||||||
|
|
||||||
|
@ -204,6 +206,7 @@ enum class ConfigKey(val label: String) {
|
||||||
DEFAULT_TOPK("Default TopK"),
|
DEFAULT_TOPK("Default TopK"),
|
||||||
DEFAULT_TOPP("Default TopP"),
|
DEFAULT_TOPP("Default TopP"),
|
||||||
DEFAULT_TEMPERATURE("Default temperature"),
|
DEFAULT_TEMPERATURE("Default temperature"),
|
||||||
|
SUPPORT_IMAGE("Support image"),
|
||||||
MAX_RESULT_COUNT("Max result count"),
|
MAX_RESULT_COUNT("Max result count"),
|
||||||
USE_GPU("Use GPU"),
|
USE_GPU("Use GPU"),
|
||||||
ACCELERATOR("Accelerator"),
|
ACCELERATOR("Accelerator"),
|
||||||
|
@ -250,83 +253,9 @@ const val IMAGE_CLASSIFICATION_INFO = ""
|
||||||
|
|
||||||
const val IMAGE_CLASSIFICATION_LEARN_MORE_URL = "https://ai.google.dev/edge/litert/android"
|
const val IMAGE_CLASSIFICATION_LEARN_MORE_URL = "https://ai.google.dev/edge/litert/android"
|
||||||
|
|
||||||
const val LLM_CHAT_INFO =
|
|
||||||
"Some description about this large language model. A community org for developers to discover models that are ready for deployment to edge platforms"
|
|
||||||
|
|
||||||
const val IMAGE_GENERATION_INFO =
|
const val IMAGE_GENERATION_INFO =
|
||||||
"Powered by [MediaPipe Image Generation API](https://ai.google.dev/edge/mediapipe/solutions/vision/image_generator/android)"
|
"Powered by [MediaPipe Image Generation API](https://ai.google.dev/edge/mediapipe/solutions/vision/image_generator/android)"
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Model spec.
|
|
||||||
|
|
||||||
val MODEL_LLM_GEMMA_2B_GPU_INT4: Model = Model(
|
|
||||||
name = "Gemma 2B (GPU int4)",
|
|
||||||
downloadFileName = "gemma-2b-it-gpu-int4.bin",
|
|
||||||
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma-2b-it-gpu-int4.bin",
|
|
||||||
sizeInBytes = 1354301440L,
|
|
||||||
configs = createLlmChatConfigs(
|
|
||||||
accelerators = listOf(Accelerator.GPU)
|
|
||||||
),
|
|
||||||
showBenchmarkButton = false,
|
|
||||||
info = LLM_CHAT_INFO,
|
|
||||||
learnMoreUrl = "https://huggingface.co/litert-community",
|
|
||||||
)
|
|
||||||
|
|
||||||
val MODEL_LLM_GEMMA_2_2B_GPU_INT8: Model = Model(
|
|
||||||
name = "Gemma 2 2B (GPU int8)",
|
|
||||||
downloadFileName = "gemma2-2b-it-gpu-int8.bin",
|
|
||||||
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma2-2b-it-gpu-int8.bin",
|
|
||||||
sizeInBytes = 2627141632L,
|
|
||||||
configs = createLlmChatConfigs(
|
|
||||||
accelerators = listOf(Accelerator.GPU)
|
|
||||||
),
|
|
||||||
showBenchmarkButton = false,
|
|
||||||
info = LLM_CHAT_INFO,
|
|
||||||
learnMoreUrl = "https://huggingface.co/litert-community",
|
|
||||||
)
|
|
||||||
|
|
||||||
val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model(
|
|
||||||
name = "Gemma 3 1B (int4)",
|
|
||||||
downloadFileName = "gemma3-1b-it-int4.task",
|
|
||||||
url = "https://huggingface.co/litert-community/Gemma3-1B-IT/resolve/main/gemma3-1b-it-int4.task?download=true",
|
|
||||||
sizeInBytes = 554661243L,
|
|
||||||
configs = createLlmChatConfigs(
|
|
||||||
defaultTopK = 64,
|
|
||||||
defaultTopP = 0.95f,
|
|
||||||
accelerators = listOf(Accelerator.CPU, Accelerator.GPU)
|
|
||||||
),
|
|
||||||
showBenchmarkButton = false,
|
|
||||||
info = LLM_CHAT_INFO,
|
|
||||||
learnMoreUrl = "https://huggingface.co/litert-community/Gemma3-1B-IT",
|
|
||||||
llmPromptTemplates = listOf(
|
|
||||||
PromptTemplate(
|
|
||||||
title = "Emoji Fun",
|
|
||||||
description = "Generate emojis by emotions",
|
|
||||||
prompt = "Show me emojis grouped by emotions"
|
|
||||||
),
|
|
||||||
PromptTemplate(
|
|
||||||
title = "Trip Planner",
|
|
||||||
description = "Plan a trip to a destination",
|
|
||||||
prompt = "Plan a two-day trip to San Francisco"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
val MODEL_LLM_DEEPSEEK: Model = Model(
|
|
||||||
name = "Deepseek",
|
|
||||||
downloadFileName = "deepseek.task",
|
|
||||||
url = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B/resolve/main/deepseek_q8_ekv1280.task?download=true",
|
|
||||||
sizeInBytes = 1860686856L,
|
|
||||||
configs = createLlmChatConfigs(
|
|
||||||
defaultTemperature = 0.6f,
|
|
||||||
defaultTopK = 40,
|
|
||||||
defaultTopP = 0.7f,
|
|
||||||
accelerators = listOf(Accelerator.CPU)
|
|
||||||
),
|
|
||||||
showBenchmarkButton = false,
|
|
||||||
info = LLM_CHAT_INFO,
|
|
||||||
learnMoreUrl = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B",
|
|
||||||
)
|
|
||||||
|
|
||||||
val MODEL_TEXT_CLASSIFICATION_MOBILEBERT: Model = Model(
|
val MODEL_TEXT_CLASSIFICATION_MOBILEBERT: Model = Model(
|
||||||
name = "MobileBert",
|
name = "MobileBert",
|
||||||
|
@ -414,12 +343,5 @@ val MODELS_IMAGE_CLASSIFICATION: MutableList<Model> = mutableListOf(
|
||||||
MODEL_IMAGE_CLASSIFICATION_MOBILENET_V2,
|
MODEL_IMAGE_CLASSIFICATION_MOBILENET_V2,
|
||||||
)
|
)
|
||||||
|
|
||||||
val MODELS_LLM: MutableList<Model> = mutableListOf(
|
|
||||||
// MODEL_LLM_GEMMA_2B_GPU_INT4,
|
|
||||||
// MODEL_LLM_GEMMA_2_2B_GPU_INT8,
|
|
||||||
MODEL_LLM_GEMMA_3_1B_INT4,
|
|
||||||
MODEL_LLM_DEEPSEEK,
|
|
||||||
)
|
|
||||||
|
|
||||||
val MODELS_IMAGE_GENERATION: MutableList<Model> =
|
val MODELS_IMAGE_GENERATION: MutableList<Model> =
|
||||||
mutableListOf(MODEL_IMAGE_GENERATION_STABLE_DIFFUSION)
|
mutableListOf(MODEL_IMAGE_GENERATION_STABLE_DIFFUSION)
|
||||||
|
|
|
@ -35,6 +35,7 @@ data class AllowedModel(
|
||||||
val defaultConfig: Map<String, ConfigValue>,
|
val defaultConfig: Map<String, ConfigValue>,
|
||||||
val taskTypes: List<String>,
|
val taskTypes: List<String>,
|
||||||
val disabled: Boolean? = null,
|
val disabled: Boolean? = null,
|
||||||
|
val llmSupportImage: Boolean? = null,
|
||||||
) {
|
) {
|
||||||
fun toModel(): Model {
|
fun toModel(): Model {
|
||||||
// Construct HF download url.
|
// Construct HF download url.
|
||||||
|
@ -48,7 +49,7 @@ data class AllowedModel(
|
||||||
var defaultTopK: Int = DEFAULT_TOPK
|
var defaultTopK: Int = DEFAULT_TOPK
|
||||||
var defaultTopP: Float = DEFAULT_TOPP
|
var defaultTopP: Float = DEFAULT_TOPP
|
||||||
var defaultTemperature: Float = DEFAULT_TEMPERATURE
|
var defaultTemperature: Float = DEFAULT_TEMPERATURE
|
||||||
var defaultMaxToken: Int = 1024
|
var defaultMaxToken = 1024
|
||||||
var accelerators: List<Accelerator> = DEFAULT_ACCELERATORS
|
var accelerators: List<Accelerator> = DEFAULT_ACCELERATORS
|
||||||
if (defaultConfig.containsKey("topK")) {
|
if (defaultConfig.containsKey("topK")) {
|
||||||
defaultTopK = getIntConfigValue(defaultConfig["topK"], defaultTopK)
|
defaultTopK = getIntConfigValue(defaultConfig["topK"], defaultTopK)
|
||||||
|
@ -84,9 +85,10 @@ data class AllowedModel(
|
||||||
|
|
||||||
// Misc.
|
// Misc.
|
||||||
var showBenchmarkButton = true
|
var showBenchmarkButton = true
|
||||||
val showRunAgainButton = true
|
var showRunAgainButton = true
|
||||||
if (taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_USECASES.type.id)) {
|
if (isLlmModel) {
|
||||||
showBenchmarkButton = false
|
showBenchmarkButton = false
|
||||||
|
showRunAgainButton = false
|
||||||
}
|
}
|
||||||
|
|
||||||
return Model(
|
return Model(
|
||||||
|
@ -99,7 +101,8 @@ data class AllowedModel(
|
||||||
downloadFileName = modelFile,
|
downloadFileName = modelFile,
|
||||||
showBenchmarkButton = showBenchmarkButton,
|
showBenchmarkButton = showBenchmarkButton,
|
||||||
showRunAgainButton = showRunAgainButton,
|
showRunAgainButton = showRunAgainButton,
|
||||||
learnMoreUrl = "https://huggingface.co/${modelId}"
|
learnMoreUrl = "https://huggingface.co/${modelId}",
|
||||||
|
llmSupportImage = llmSupportImage == true,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@ package com.google.aiedge.gallery.data
|
||||||
import androidx.annotation.StringRes
|
import androidx.annotation.StringRes
|
||||||
import androidx.compose.material.icons.Icons
|
import androidx.compose.material.icons.Icons
|
||||||
import androidx.compose.material.icons.outlined.Forum
|
import androidx.compose.material.icons.outlined.Forum
|
||||||
|
import androidx.compose.material.icons.outlined.Mms
|
||||||
import androidx.compose.material.icons.outlined.Widgets
|
import androidx.compose.material.icons.outlined.Widgets
|
||||||
import androidx.compose.material.icons.rounded.ImageSearch
|
import androidx.compose.material.icons.rounded.ImageSearch
|
||||||
import androidx.compose.runtime.MutableState
|
import androidx.compose.runtime.MutableState
|
||||||
|
@ -33,6 +34,7 @@ enum class TaskType(val label: String, val id: String) {
|
||||||
IMAGE_GENERATION(label = "Image Generation", id = "image_generation"),
|
IMAGE_GENERATION(label = "Image Generation", id = "image_generation"),
|
||||||
LLM_CHAT(label = "LLM Chat", id = "llm_chat"),
|
LLM_CHAT(label = "LLM Chat", id = "llm_chat"),
|
||||||
LLM_USECASES(label = "LLM Use Cases", id = "llm_usecases"),
|
LLM_USECASES(label = "LLM Use Cases", id = "llm_usecases"),
|
||||||
|
LLM_IMAGE_TO_TEXT(label = "LLM Image to Text", id = "llm_image_to_text"),
|
||||||
|
|
||||||
TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
|
TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
|
||||||
TEST_TASK_2(label = "Test task 2", id = "test_task_2")
|
TEST_TASK_2(label = "Test task 2", id = "test_task_2")
|
||||||
|
@ -93,7 +95,7 @@ val TASK_LLM_CHAT = Task(
|
||||||
icon = Icons.Outlined.Forum,
|
icon = Icons.Outlined.Forum,
|
||||||
// models = MODELS_LLM,
|
// models = MODELS_LLM,
|
||||||
models = mutableListOf(),
|
models = mutableListOf(),
|
||||||
description = "Chat with a on-device large language model",
|
description = "Chat with on-device large language models",
|
||||||
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
||||||
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
|
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
|
||||||
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
|
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
|
||||||
|
@ -110,6 +112,17 @@ val TASK_LLM_USECASES = Task(
|
||||||
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
|
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
|
||||||
)
|
)
|
||||||
|
|
||||||
|
val TASK_LLM_IMAGE_TO_TEXT = Task(
|
||||||
|
type = TaskType.LLM_IMAGE_TO_TEXT,
|
||||||
|
icon = Icons.Outlined.Mms,
|
||||||
|
// models = MODELS_LLM,
|
||||||
|
models = mutableListOf(),
|
||||||
|
description = "Ask questions about images with on-device large language models",
|
||||||
|
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
||||||
|
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
|
||||||
|
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
|
||||||
|
)
|
||||||
|
|
||||||
val TASK_IMAGE_GENERATION = Task(
|
val TASK_IMAGE_GENERATION = Task(
|
||||||
type = TaskType.IMAGE_GENERATION,
|
type = TaskType.IMAGE_GENERATION,
|
||||||
iconVectorResourceId = R.drawable.image_spark,
|
iconVectorResourceId = R.drawable.image_spark,
|
||||||
|
@ -127,6 +140,7 @@ val TASKS: List<Task> = listOf(
|
||||||
// TASK_IMAGE_GENERATION,
|
// TASK_IMAGE_GENERATION,
|
||||||
TASK_LLM_USECASES,
|
TASK_LLM_USECASES,
|
||||||
TASK_LLM_CHAT,
|
TASK_LLM_CHAT,
|
||||||
|
TASK_LLM_IMAGE_TO_TEXT
|
||||||
)
|
)
|
||||||
|
|
||||||
fun getModelByName(name: String): Model? {
|
fun getModelByName(name: String): Model? {
|
||||||
|
|
|
@ -25,6 +25,7 @@ import com.google.aiedge.gallery.GalleryApplication
|
||||||
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationViewModel
|
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationViewModel
|
||||||
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationViewModel
|
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationViewModel
|
||||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatViewModel
|
import com.google.aiedge.gallery.ui.llmchat.LlmChatViewModel
|
||||||
|
import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextViewModel
|
||||||
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
|
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
import com.google.aiedge.gallery.ui.textclassification.TextClassificationViewModel
|
import com.google.aiedge.gallery.ui.textclassification.TextClassificationViewModel
|
||||||
|
@ -62,6 +63,12 @@ object ViewModelProvider {
|
||||||
LlmSingleTurnViewModel()
|
LlmSingleTurnViewModel()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initializer for LlmImageToTextViewModel.
|
||||||
|
initializer {
|
||||||
|
LlmImageToTextViewModel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initializer for ImageGenerationViewModel.
|
||||||
initializer {
|
initializer {
|
||||||
ImageGenerationViewModel()
|
ImageGenerationViewModel()
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,13 +17,17 @@
|
||||||
package com.google.aiedge.gallery.ui.common
|
package com.google.aiedge.gallery.ui.common
|
||||||
|
|
||||||
import androidx.compose.foundation.layout.Arrangement
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
|
import androidx.compose.foundation.layout.Box
|
||||||
import androidx.compose.foundation.layout.Column
|
import androidx.compose.foundation.layout.Column
|
||||||
import androidx.compose.foundation.layout.Row
|
import androidx.compose.foundation.layout.Row
|
||||||
|
import androidx.compose.foundation.layout.offset
|
||||||
import androidx.compose.foundation.layout.size
|
import androidx.compose.foundation.layout.size
|
||||||
import androidx.compose.material.icons.Icons
|
import androidx.compose.material.icons.Icons
|
||||||
import androidx.compose.material.icons.automirrored.rounded.ArrowBack
|
import androidx.compose.material.icons.automirrored.rounded.ArrowBack
|
||||||
import androidx.compose.material.icons.rounded.Settings
|
import androidx.compose.material.icons.rounded.MapsUgc
|
||||||
|
import androidx.compose.material.icons.rounded.Tune
|
||||||
import androidx.compose.material3.CenterAlignedTopAppBar
|
import androidx.compose.material3.CenterAlignedTopAppBar
|
||||||
|
import androidx.compose.material3.CircularProgressIndicator
|
||||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||||
import androidx.compose.material3.Icon
|
import androidx.compose.material3.Icon
|
||||||
import androidx.compose.material3.IconButton
|
import androidx.compose.material3.IconButton
|
||||||
|
@ -38,6 +42,7 @@ import androidx.compose.runtime.setValue
|
||||||
import androidx.compose.ui.Alignment
|
import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.draw.alpha
|
import androidx.compose.ui.draw.alpha
|
||||||
|
import androidx.compose.ui.draw.scale
|
||||||
import androidx.compose.ui.graphics.vector.ImageVector
|
import androidx.compose.ui.graphics.vector.ImageVector
|
||||||
import androidx.compose.ui.platform.LocalContext
|
import androidx.compose.ui.platform.LocalContext
|
||||||
import androidx.compose.ui.res.vectorResource
|
import androidx.compose.ui.res.vectorResource
|
||||||
|
@ -47,6 +52,7 @@ import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ConfigDialog
|
import com.google.aiedge.gallery.ui.common.chat.ConfigDialog
|
||||||
|
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
|
|
||||||
@OptIn(ExperimentalMaterial3Api::class)
|
@OptIn(ExperimentalMaterial3Api::class)
|
||||||
|
@ -58,12 +64,17 @@ fun ModelPageAppBar(
|
||||||
onBackClicked: () -> Unit,
|
onBackClicked: () -> Unit,
|
||||||
onModelSelected: (Model) -> Unit,
|
onModelSelected: (Model) -> Unit,
|
||||||
modifier: Modifier = Modifier,
|
modifier: Modifier = Modifier,
|
||||||
|
isResettingSession: Boolean = false,
|
||||||
|
onResetSessionClicked: (Model) -> Unit = {},
|
||||||
|
showResetSessionButton: Boolean = false,
|
||||||
onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit = { _, _ -> },
|
onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit = { _, _ -> },
|
||||||
) {
|
) {
|
||||||
var showConfigDialog by remember { mutableStateOf(false) }
|
var showConfigDialog by remember { mutableStateOf(false) }
|
||||||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||||
val context = LocalContext.current
|
val context = LocalContext.current
|
||||||
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[model.name]
|
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[model.name]
|
||||||
|
val modelInitializationStatus =
|
||||||
|
modelManagerUiState.modelInitializationStatus[model.name]
|
||||||
|
|
||||||
CenterAlignedTopAppBar(title = {
|
CenterAlignedTopAppBar(title = {
|
||||||
Column(
|
Column(
|
||||||
|
@ -110,19 +121,58 @@ fun ModelPageAppBar(
|
||||||
actions = {
|
actions = {
|
||||||
val showConfigButton =
|
val showConfigButton =
|
||||||
model.configs.isNotEmpty() && curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED
|
model.configs.isNotEmpty() && curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED
|
||||||
IconButton(
|
Box(modifier = Modifier.size(42.dp), contentAlignment = Alignment.Center) {
|
||||||
onClick = {
|
var configButtonOffset = 0.dp
|
||||||
showConfigDialog = true
|
if (showConfigButton && showResetSessionButton) {
|
||||||
},
|
configButtonOffset = (-40).dp
|
||||||
enabled = showConfigButton,
|
}
|
||||||
modifier = Modifier.alpha(if (showConfigButton) 1f else 0f)
|
val isModelInitializing =
|
||||||
) {
|
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING
|
||||||
Icon(
|
if (showConfigButton) {
|
||||||
imageVector = Icons.Rounded.Settings,
|
IconButton(
|
||||||
contentDescription = "",
|
onClick = {
|
||||||
tint = MaterialTheme.colorScheme.primary
|
showConfigDialog = true
|
||||||
)
|
},
|
||||||
|
enabled = !isModelInitializing,
|
||||||
|
modifier = Modifier
|
||||||
|
.scale(0.75f)
|
||||||
|
.offset(x = configButtonOffset)
|
||||||
|
.alpha(if (isModelInitializing) 0.5f else 1f)
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
imageVector = Icons.Rounded.Tune,
|
||||||
|
contentDescription = "",
|
||||||
|
tint = MaterialTheme.colorScheme.primary
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (showResetSessionButton) {
|
||||||
|
if (isResettingSession) {
|
||||||
|
CircularProgressIndicator(
|
||||||
|
trackColor = MaterialTheme.colorScheme.surfaceVariant,
|
||||||
|
strokeWidth = 2.dp,
|
||||||
|
modifier = Modifier.size(16.dp)
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
IconButton(
|
||||||
|
onClick = {
|
||||||
|
onResetSessionClicked(model)
|
||||||
|
},
|
||||||
|
enabled = !isModelInitializing,
|
||||||
|
modifier = Modifier
|
||||||
|
.scale(0.75f)
|
||||||
|
.alpha(if (isModelInitializing) 0.5f else 1f)
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
imageVector = Icons.Rounded.MapsUgc,
|
||||||
|
contentDescription = "",
|
||||||
|
tint = MaterialTheme.colorScheme.primary
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
// Config dialog.
|
// Config dialog.
|
||||||
|
|
|
@ -29,6 +29,7 @@ import androidx.compose.foundation.layout.Row
|
||||||
import androidx.compose.foundation.layout.fillMaxWidth
|
import androidx.compose.foundation.layout.fillMaxWidth
|
||||||
import androidx.compose.foundation.layout.padding
|
import androidx.compose.foundation.layout.padding
|
||||||
import androidx.compose.foundation.layout.size
|
import androidx.compose.foundation.layout.size
|
||||||
|
import androidx.compose.foundation.layout.widthIn
|
||||||
import androidx.compose.foundation.pager.HorizontalPager
|
import androidx.compose.foundation.pager.HorizontalPager
|
||||||
import androidx.compose.foundation.pager.rememberPagerState
|
import androidx.compose.foundation.pager.rememberPagerState
|
||||||
import androidx.compose.foundation.shape.CircleShape
|
import androidx.compose.foundation.shape.CircleShape
|
||||||
|
@ -54,6 +55,9 @@ import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.draw.alpha
|
import androidx.compose.ui.draw.alpha
|
||||||
import androidx.compose.ui.draw.clip
|
import androidx.compose.ui.draw.clip
|
||||||
import androidx.compose.ui.graphics.graphicsLayer
|
import androidx.compose.ui.graphics.graphicsLayer
|
||||||
|
import androidx.compose.ui.platform.LocalDensity
|
||||||
|
import androidx.compose.ui.platform.LocalWindowInfo
|
||||||
|
import androidx.compose.ui.text.style.TextOverflow
|
||||||
import androidx.compose.ui.unit.dp
|
import androidx.compose.ui.unit.dp
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
|
@ -77,6 +81,10 @@ fun ModelPickerChipsPager(
|
||||||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||||
val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
|
val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
|
||||||
val scope = rememberCoroutineScope()
|
val scope = rememberCoroutineScope()
|
||||||
|
val density = LocalDensity.current
|
||||||
|
val windowInfo = LocalWindowInfo.current
|
||||||
|
val screenWidthDp =
|
||||||
|
remember { with(density) { windowInfo.containerSize.width.toDp() } }
|
||||||
|
|
||||||
val pagerState = rememberPagerState(initialPage = task.models.indexOf(initialModel),
|
val pagerState = rememberPagerState(initialPage = task.models.indexOf(initialModel),
|
||||||
pageCount = { task.models.size })
|
pageCount = { task.models.size })
|
||||||
|
@ -140,7 +148,12 @@ fun ModelPickerChipsPager(
|
||||||
Text(
|
Text(
|
||||||
model.name,
|
model.name,
|
||||||
style = MaterialTheme.typography.labelSmall,
|
style = MaterialTheme.typography.labelSmall,
|
||||||
modifier = Modifier.padding(start = 4.dp),
|
modifier = Modifier
|
||||||
|
.padding(start = 4.dp)
|
||||||
|
.widthIn(0.dp, screenWidthDp - 250.dp),
|
||||||
|
maxLines = 1,
|
||||||
|
overflow = TextOverflow.MiddleEllipsis
|
||||||
|
|
||||||
)
|
)
|
||||||
Icon(
|
Icon(
|
||||||
Icons.Rounded.ArrowDropDown,
|
Icons.Rounded.ArrowDropDown,
|
||||||
|
|
|
@ -109,7 +109,7 @@ fun ChatPanel(
|
||||||
task: Task,
|
task: Task,
|
||||||
selectedModel: Model,
|
selectedModel: Model,
|
||||||
viewModel: ChatViewModel,
|
viewModel: ChatViewModel,
|
||||||
onSendMessage: (Model, ChatMessage) -> Unit,
|
onSendMessage: (Model, List<ChatMessage>) -> Unit,
|
||||||
onRunAgainClicked: (Model, ChatMessage) -> Unit,
|
onRunAgainClicked: (Model, ChatMessage) -> Unit,
|
||||||
onBenchmarkClicked: (Model, ChatMessage, warmUpIterations: Int, benchmarkIterations: Int) -> Unit,
|
onBenchmarkClicked: (Model, ChatMessage, warmUpIterations: Int, benchmarkIterations: Int) -> Unit,
|
||||||
navigateUp: () -> Unit,
|
navigateUp: () -> Unit,
|
||||||
|
@ -280,7 +280,8 @@ fun ChatPanel(
|
||||||
task = task,
|
task = task,
|
||||||
onPromptClicked = { template ->
|
onPromptClicked = { template ->
|
||||||
onSendMessage(
|
onSendMessage(
|
||||||
selectedModel, ChatMessageText(content = template.prompt, side = ChatSide.USER)
|
selectedModel,
|
||||||
|
listOf(ChatMessageText(content = template.prompt, side = ChatSide.USER))
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -430,12 +431,15 @@ fun ChatPanel(
|
||||||
// Chat input
|
// Chat input
|
||||||
when (chatInputType) {
|
when (chatInputType) {
|
||||||
ChatInputType.TEXT -> {
|
ChatInputType.TEXT -> {
|
||||||
val isLlmTask = task.type == TaskType.LLM_CHAT
|
// val isLlmTask = task.type == TaskType.LLM_CHAT
|
||||||
val notLlmStartScreen = !(messages.size == 1 && messages[0] is ChatMessagePromptTemplates)
|
// val notLlmStartScreen = !(messages.size == 1 && messages[0] is ChatMessagePromptTemplates)
|
||||||
|
val hasImageMessage = messages.any { it is ChatMessageImage }
|
||||||
MessageInputText(
|
MessageInputText(
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
curMessage = curMessage,
|
curMessage = curMessage,
|
||||||
inProgress = uiState.inProgress,
|
inProgress = uiState.inProgress,
|
||||||
|
isResettingSession = uiState.isResettingSession,
|
||||||
|
hasImageMessage = hasImageMessage,
|
||||||
modelInitializing = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
|
modelInitializing = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
|
||||||
textFieldPlaceHolderRes = task.textInputPlaceHolderRes,
|
textFieldPlaceHolderRes = task.textInputPlaceHolderRes,
|
||||||
onValueChanged = { curMessage = it },
|
onValueChanged = { curMessage = it },
|
||||||
|
@ -445,13 +449,17 @@ fun ChatPanel(
|
||||||
},
|
},
|
||||||
onOpenPromptTemplatesClicked = {
|
onOpenPromptTemplatesClicked = {
|
||||||
onSendMessage(
|
onSendMessage(
|
||||||
selectedModel, ChatMessagePromptTemplates(
|
selectedModel, listOf(
|
||||||
templates = selectedModel.llmPromptTemplates, showMakeYourOwn = false
|
ChatMessagePromptTemplates(
|
||||||
|
templates = selectedModel.llmPromptTemplates, showMakeYourOwn = false
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
onStopButtonClicked = onStopButtonClicked,
|
onStopButtonClicked = onStopButtonClicked,
|
||||||
showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen,
|
// showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen,
|
||||||
|
showPromptTemplatesInMenu = false,
|
||||||
|
showImagePickerInMenu = selectedModel.llmSupportImage == true,
|
||||||
showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress,
|
showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -461,8 +469,10 @@ fun ChatPanel(
|
||||||
streamingMessage = streamingMessage,
|
streamingMessage = streamingMessage,
|
||||||
onImageSelected = { bitmap ->
|
onImageSelected = { bitmap ->
|
||||||
onSendMessage(
|
onSendMessage(
|
||||||
selectedModel, ChatMessageImage(
|
selectedModel, listOf(
|
||||||
bitmap = bitmap, imageBitMap = bitmap.asImageBitmap(), side = ChatSide.USER
|
ChatMessageImage(
|
||||||
|
bitmap = bitmap, imageBitMap = bitmap.asImageBitmap(), side = ChatSide.USER
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
|
|
|
@ -70,16 +70,18 @@ fun ChatView(
|
||||||
task: Task,
|
task: Task,
|
||||||
viewModel: ChatViewModel,
|
viewModel: ChatViewModel,
|
||||||
modelManagerViewModel: ModelManagerViewModel,
|
modelManagerViewModel: ModelManagerViewModel,
|
||||||
onSendMessage: (Model, ChatMessage) -> Unit,
|
onSendMessage: (Model, List<ChatMessage>) -> Unit,
|
||||||
onRunAgainClicked: (Model, ChatMessage) -> Unit,
|
onRunAgainClicked: (Model, ChatMessage) -> Unit,
|
||||||
onBenchmarkClicked: (Model, ChatMessage, Int, Int) -> Unit,
|
onBenchmarkClicked: (Model, ChatMessage, Int, Int) -> Unit,
|
||||||
navigateUp: () -> Unit,
|
navigateUp: () -> Unit,
|
||||||
modifier: Modifier = Modifier,
|
modifier: Modifier = Modifier,
|
||||||
|
onResetSessionClicked: (Model) -> Unit = {},
|
||||||
onStreamImageMessage: (Model, ChatMessageImage) -> Unit = { _, _ -> },
|
onStreamImageMessage: (Model, ChatMessageImage) -> Unit = { _, _ -> },
|
||||||
onStopButtonClicked: (Model) -> Unit = {},
|
onStopButtonClicked: (Model) -> Unit = {},
|
||||||
chatInputType: ChatInputType = ChatInputType.TEXT,
|
chatInputType: ChatInputType = ChatInputType.TEXT,
|
||||||
showStopButtonInInputWhenInProgress: Boolean = false,
|
showStopButtonInInputWhenInProgress: Boolean = false,
|
||||||
) {
|
) {
|
||||||
|
val uiStat by viewModel.uiState.collectAsState()
|
||||||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||||
val selectedModel = modelManagerUiState.selectedModel
|
val selectedModel = modelManagerUiState.selectedModel
|
||||||
|
|
||||||
|
@ -155,6 +157,9 @@ fun ChatView(
|
||||||
task = task,
|
task = task,
|
||||||
model = selectedModel,
|
model = selectedModel,
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
|
showResetSessionButton = true,
|
||||||
|
isResettingSession = uiStat.isResettingSession,
|
||||||
|
onResetSessionClicked = onResetSessionClicked,
|
||||||
onConfigChanged = { old, new ->
|
onConfigChanged = { old, new ->
|
||||||
viewModel.addConfigChangedMessage(
|
viewModel.addConfigChangedMessage(
|
||||||
oldConfigValues = old,
|
oldConfigValues = old,
|
||||||
|
|
|
@ -33,6 +33,11 @@ data class ChatUiState(
|
||||||
*/
|
*/
|
||||||
val inProgress: Boolean = false,
|
val inProgress: Boolean = false,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Indicates whether the session is being reset.
|
||||||
|
*/
|
||||||
|
val isResettingSession: Boolean = false,
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A map of model names to lists of chat messages.
|
* A map of model names to lists of chat messages.
|
||||||
*/
|
*/
|
||||||
|
@ -106,6 +111,12 @@ open class ChatViewModel(val task: Task) : ViewModel() {
|
||||||
_uiState.update { _uiState.value.copy(messagesByModel = newMessagesByModel) }
|
_uiState.update { _uiState.value.copy(messagesByModel = newMessagesByModel) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun clearAllMessages(model: Model) {
|
||||||
|
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
|
||||||
|
newMessagesByModel[model.name] = mutableListOf()
|
||||||
|
_uiState.update { _uiState.value.copy(messagesByModel = newMessagesByModel) }
|
||||||
|
}
|
||||||
|
|
||||||
fun getLastMessage(model: Model): ChatMessage? {
|
fun getLastMessage(model: Model): ChatMessage? {
|
||||||
return (_uiState.value.messagesByModel[model.name] ?: listOf()).lastOrNull()
|
return (_uiState.value.messagesByModel[model.name] ?: listOf()).lastOrNull()
|
||||||
}
|
}
|
||||||
|
@ -189,6 +200,10 @@ open class ChatViewModel(val task: Task) : ViewModel() {
|
||||||
_uiState.update { _uiState.value.copy(inProgress = inProgress) }
|
_uiState.update { _uiState.value.copy(inProgress = inProgress) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun setIsResettingSession(isResettingSession: Boolean) {
|
||||||
|
_uiState.update { _uiState.value.copy(isResettingSession = isResettingSession) }
|
||||||
|
}
|
||||||
|
|
||||||
fun addConfigChangedMessage(
|
fun addConfigChangedMessage(
|
||||||
oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>, model: Model
|
oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>, model: Model
|
||||||
) {
|
) {
|
||||||
|
|
|
@ -21,14 +21,18 @@ import androidx.compose.material3.ProvideTextStyle
|
||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
import androidx.compose.runtime.CompositionLocalProvider
|
import androidx.compose.runtime.CompositionLocalProvider
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
|
import androidx.compose.ui.text.SpanStyle
|
||||||
|
import androidx.compose.ui.text.TextLinkStyles
|
||||||
import androidx.compose.ui.text.TextStyle
|
import androidx.compose.ui.text.TextStyle
|
||||||
import androidx.compose.ui.text.font.FontFamily
|
import androidx.compose.ui.text.font.FontFamily
|
||||||
import androidx.compose.ui.tooling.preview.Preview
|
import androidx.compose.ui.tooling.preview.Preview
|
||||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||||
|
import com.google.aiedge.gallery.ui.theme.customColors
|
||||||
import com.halilibo.richtext.commonmark.Markdown
|
import com.halilibo.richtext.commonmark.Markdown
|
||||||
import com.halilibo.richtext.ui.CodeBlockStyle
|
import com.halilibo.richtext.ui.CodeBlockStyle
|
||||||
import com.halilibo.richtext.ui.RichTextStyle
|
import com.halilibo.richtext.ui.RichTextStyle
|
||||||
import com.halilibo.richtext.ui.material3.RichText
|
import com.halilibo.richtext.ui.material3.RichText
|
||||||
|
import com.halilibo.richtext.ui.string.RichTextStringStyle
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Composable function to display Markdown-formatted text.
|
* Composable function to display Markdown-formatted text.
|
||||||
|
@ -56,6 +60,11 @@ fun MarkdownText(
|
||||||
fontSize = MaterialTheme.typography.bodySmall.fontSize,
|
fontSize = MaterialTheme.typography.bodySmall.fontSize,
|
||||||
fontFamily = FontFamily.Monospace,
|
fontFamily = FontFamily.Monospace,
|
||||||
)
|
)
|
||||||
|
),
|
||||||
|
stringStyle = RichTextStringStyle(
|
||||||
|
linkStyle = TextLinkStyles(
|
||||||
|
style = SpanStyle(color = MaterialTheme.customColors.linkColor)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
) {
|
) {
|
||||||
|
|
|
@ -46,7 +46,11 @@ fun MessageBodyWarning(message: ChatMessageWarning) {
|
||||||
.clip(RoundedCornerShape(16.dp))
|
.clip(RoundedCornerShape(16.dp))
|
||||||
.background(MaterialTheme.colorScheme.tertiaryContainer)
|
.background(MaterialTheme.colorScheme.tertiaryContainer)
|
||||||
) {
|
) {
|
||||||
MarkdownText(text = message.content, modifier = Modifier.padding(12.dp), smallFontSize = true)
|
MarkdownText(
|
||||||
|
text = message.content,
|
||||||
|
modifier = Modifier.padding(horizontal = 16.dp, vertical = 6.dp),
|
||||||
|
smallFontSize = true
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,28 +16,45 @@
|
||||||
|
|
||||||
package com.google.aiedge.gallery.ui.common.chat
|
package com.google.aiedge.gallery.ui.common.chat
|
||||||
|
|
||||||
|
import android.Manifest
|
||||||
|
import android.content.Context
|
||||||
|
import android.content.pm.PackageManager
|
||||||
|
import android.graphics.Bitmap
|
||||||
|
import android.graphics.BitmapFactory
|
||||||
|
import android.graphics.Matrix
|
||||||
|
import android.net.Uri
|
||||||
|
import androidx.activity.compose.rememberLauncherForActivityResult
|
||||||
|
import androidx.activity.result.PickVisualMediaRequest
|
||||||
|
import androidx.activity.result.contract.ActivityResultContracts
|
||||||
import androidx.annotation.StringRes
|
import androidx.annotation.StringRes
|
||||||
|
import androidx.compose.foundation.Image
|
||||||
|
import androidx.compose.foundation.background
|
||||||
import androidx.compose.foundation.border
|
import androidx.compose.foundation.border
|
||||||
|
import androidx.compose.foundation.clickable
|
||||||
import androidx.compose.foundation.layout.Arrangement
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
import androidx.compose.foundation.layout.Box
|
import androidx.compose.foundation.layout.Box
|
||||||
import androidx.compose.foundation.layout.Column
|
import androidx.compose.foundation.layout.Column
|
||||||
import androidx.compose.foundation.layout.Row
|
import androidx.compose.foundation.layout.Row
|
||||||
import androidx.compose.foundation.layout.Spacer
|
import androidx.compose.foundation.layout.Spacer
|
||||||
import androidx.compose.foundation.layout.fillMaxWidth
|
import androidx.compose.foundation.layout.fillMaxWidth
|
||||||
|
import androidx.compose.foundation.layout.height
|
||||||
import androidx.compose.foundation.layout.offset
|
import androidx.compose.foundation.layout.offset
|
||||||
import androidx.compose.foundation.layout.padding
|
import androidx.compose.foundation.layout.padding
|
||||||
import androidx.compose.foundation.layout.size
|
import androidx.compose.foundation.layout.size
|
||||||
import androidx.compose.foundation.layout.width
|
import androidx.compose.foundation.layout.width
|
||||||
|
import androidx.compose.foundation.shape.CircleShape
|
||||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||||
import androidx.compose.material.icons.Icons
|
import androidx.compose.material.icons.Icons
|
||||||
import androidx.compose.material.icons.automirrored.rounded.Send
|
import androidx.compose.material.icons.automirrored.rounded.Send
|
||||||
import androidx.compose.material.icons.rounded.Add
|
import androidx.compose.material.icons.rounded.Add
|
||||||
|
import androidx.compose.material.icons.rounded.Close
|
||||||
import androidx.compose.material.icons.rounded.History
|
import androidx.compose.material.icons.rounded.History
|
||||||
|
import androidx.compose.material.icons.rounded.Photo
|
||||||
|
import androidx.compose.material.icons.rounded.PhotoCamera
|
||||||
import androidx.compose.material.icons.rounded.PostAdd
|
import androidx.compose.material.icons.rounded.PostAdd
|
||||||
import androidx.compose.material.icons.rounded.Stop
|
import androidx.compose.material.icons.rounded.Stop
|
||||||
import androidx.compose.material3.DropdownMenu
|
import androidx.compose.material3.DropdownMenu
|
||||||
import androidx.compose.material3.DropdownMenuItem
|
import androidx.compose.material3.DropdownMenuItem
|
||||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
|
||||||
import androidx.compose.material3.Icon
|
import androidx.compose.material3.Icon
|
||||||
import androidx.compose.material3.IconButton
|
import androidx.compose.material3.IconButton
|
||||||
import androidx.compose.material3.IconButtonDefaults
|
import androidx.compose.material3.IconButtonDefaults
|
||||||
|
@ -54,12 +71,17 @@ import androidx.compose.runtime.setValue
|
||||||
import androidx.compose.ui.Alignment
|
import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.draw.alpha
|
import androidx.compose.ui.draw.alpha
|
||||||
|
import androidx.compose.ui.draw.clip
|
||||||
|
import androidx.compose.ui.draw.shadow
|
||||||
import androidx.compose.ui.graphics.Color
|
import androidx.compose.ui.graphics.Color
|
||||||
|
import androidx.compose.ui.graphics.asImageBitmap
|
||||||
import androidx.compose.ui.platform.LocalContext
|
import androidx.compose.ui.platform.LocalContext
|
||||||
import androidx.compose.ui.res.stringResource
|
import androidx.compose.ui.res.stringResource
|
||||||
import androidx.compose.ui.tooling.preview.Preview
|
import androidx.compose.ui.tooling.preview.Preview
|
||||||
import androidx.compose.ui.unit.dp
|
import androidx.compose.ui.unit.dp
|
||||||
|
import androidx.core.content.ContextCompat
|
||||||
import com.google.aiedge.gallery.R
|
import com.google.aiedge.gallery.R
|
||||||
|
import com.google.aiedge.gallery.ui.common.createTempPictureUri
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||||
|
@ -74,24 +96,107 @@ import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||||
fun MessageInputText(
|
fun MessageInputText(
|
||||||
modelManagerViewModel: ModelManagerViewModel,
|
modelManagerViewModel: ModelManagerViewModel,
|
||||||
curMessage: String,
|
curMessage: String,
|
||||||
|
isResettingSession: Boolean,
|
||||||
inProgress: Boolean,
|
inProgress: Boolean,
|
||||||
|
hasImageMessage: Boolean,
|
||||||
modelInitializing: Boolean,
|
modelInitializing: Boolean,
|
||||||
@StringRes textFieldPlaceHolderRes: Int,
|
@StringRes textFieldPlaceHolderRes: Int,
|
||||||
onValueChanged: (String) -> Unit,
|
onValueChanged: (String) -> Unit,
|
||||||
onSendMessage: (ChatMessage) -> Unit,
|
onSendMessage: (List<ChatMessage>) -> Unit,
|
||||||
onOpenPromptTemplatesClicked: () -> Unit = {},
|
onOpenPromptTemplatesClicked: () -> Unit = {},
|
||||||
onStopButtonClicked: () -> Unit = {},
|
onStopButtonClicked: () -> Unit = {},
|
||||||
showPromptTemplatesInMenu: Boolean = true,
|
showPromptTemplatesInMenu: Boolean = false,
|
||||||
|
showImagePickerInMenu: Boolean = false,
|
||||||
showStopButtonWhenInProgress: Boolean = false,
|
showStopButtonWhenInProgress: Boolean = false,
|
||||||
) {
|
) {
|
||||||
|
val context = LocalContext.current
|
||||||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||||
var showAddContentMenu by remember { mutableStateOf(false) }
|
var showAddContentMenu by remember { mutableStateOf(false) }
|
||||||
var showTextInputHistorySheet by remember { mutableStateOf(false) }
|
var showTextInputHistorySheet by remember { mutableStateOf(false) }
|
||||||
|
var tempPhotoUri by remember { mutableStateOf(value = Uri.EMPTY) }
|
||||||
|
var pickedImages by remember { mutableStateOf<List<Bitmap>>(listOf()) }
|
||||||
|
val updatePickedImages: (Bitmap) -> Unit = { bitmap ->
|
||||||
|
val newPickedImages: MutableList<Bitmap> = mutableListOf()
|
||||||
|
newPickedImages.addAll(pickedImages)
|
||||||
|
newPickedImages.add(bitmap)
|
||||||
|
pickedImages = newPickedImages.toList()
|
||||||
|
}
|
||||||
|
|
||||||
|
// launches camera
|
||||||
|
val cameraLauncher =
|
||||||
|
rememberLauncherForActivityResult(ActivityResultContracts.TakePicture()) { isImageSaved ->
|
||||||
|
if (isImageSaved) {
|
||||||
|
handleImageSelected(
|
||||||
|
context = context,
|
||||||
|
uri = tempPhotoUri,
|
||||||
|
onImageSelected = { bitmap ->
|
||||||
|
updatePickedImages(bitmap)
|
||||||
|
},
|
||||||
|
rotateForPortrait = true,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Permission request when taking picture.
|
||||||
|
val takePicturePermissionLauncher = rememberLauncherForActivityResult(
|
||||||
|
ActivityResultContracts.RequestPermission()
|
||||||
|
) { permissionGranted ->
|
||||||
|
if (permissionGranted) {
|
||||||
|
showAddContentMenu = false
|
||||||
|
tempPhotoUri = context.createTempPictureUri()
|
||||||
|
cameraLauncher.launch(tempPhotoUri)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Registers a photo picker activity launcher in single-select mode.
|
||||||
|
val pickMedia =
|
||||||
|
rememberLauncherForActivityResult(ActivityResultContracts.PickVisualMedia()) { uri ->
|
||||||
|
// Callback is invoked after the user selects a media item or closes the
|
||||||
|
// photo picker.
|
||||||
|
if (uri != null) {
|
||||||
|
handleImageSelected(context = context, uri = uri, onImageSelected = { bitmap ->
|
||||||
|
updatePickedImages(bitmap)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Box(contentAlignment = Alignment.CenterStart) {
|
Box(contentAlignment = Alignment.CenterStart) {
|
||||||
|
// A preview panel for the selected image.
|
||||||
|
if (pickedImages.isNotEmpty()) {
|
||||||
|
Box(
|
||||||
|
contentAlignment = Alignment.TopEnd, modifier = Modifier.offset(x = 16.dp, y = (-80).dp)
|
||||||
|
) {
|
||||||
|
Image(
|
||||||
|
bitmap = pickedImages.last().asImageBitmap(),
|
||||||
|
contentDescription = "",
|
||||||
|
modifier = Modifier
|
||||||
|
.height(80.dp)
|
||||||
|
.shadow(2.dp, shape = RoundedCornerShape(8.dp))
|
||||||
|
.clip(RoundedCornerShape(8.dp))
|
||||||
|
.border(1.dp, MaterialTheme.colorScheme.outlineVariant, RoundedCornerShape(8.dp)),
|
||||||
|
)
|
||||||
|
Box(modifier = Modifier
|
||||||
|
.offset(x = 10.dp, y = (-10).dp)
|
||||||
|
.clip(CircleShape)
|
||||||
|
.background(MaterialTheme.colorScheme.surface)
|
||||||
|
.border((1.5).dp, MaterialTheme.colorScheme.outline, CircleShape)
|
||||||
|
.clickable {
|
||||||
|
pickedImages = listOf()
|
||||||
|
}) {
|
||||||
|
Icon(
|
||||||
|
Icons.Rounded.Close,
|
||||||
|
contentDescription = "",
|
||||||
|
modifier = Modifier
|
||||||
|
.padding(3.dp)
|
||||||
|
.size(16.dp)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// A plus button to show a popup menu to add stuff to the chat.
|
// A plus button to show a popup menu to add stuff to the chat.
|
||||||
IconButton(
|
IconButton(
|
||||||
enabled = !inProgress,
|
enabled = !inProgress && !isResettingSession,
|
||||||
onClick = { showAddContentMenu = true },
|
onClick = { showAddContentMenu = true },
|
||||||
modifier = Modifier
|
modifier = Modifier
|
||||||
.offset(x = 16.dp)
|
.offset(x = 16.dp)
|
||||||
|
@ -113,6 +218,58 @@ fun MessageInputText(
|
||||||
DropdownMenu(
|
DropdownMenu(
|
||||||
expanded = showAddContentMenu,
|
expanded = showAddContentMenu,
|
||||||
onDismissRequest = { showAddContentMenu = false }) {
|
onDismissRequest = { showAddContentMenu = false }) {
|
||||||
|
if (showImagePickerInMenu) {
|
||||||
|
// Take a photo.
|
||||||
|
DropdownMenuItem(
|
||||||
|
text = {
|
||||||
|
Row(
|
||||||
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(6.dp)
|
||||||
|
) {
|
||||||
|
Icon(Icons.Rounded.PhotoCamera, contentDescription = "")
|
||||||
|
Text("Take a photo")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
enabled = pickedImages.isEmpty() && !hasImageMessage,
|
||||||
|
onClick = {
|
||||||
|
// Check permission
|
||||||
|
when (PackageManager.PERMISSION_GRANTED) {
|
||||||
|
// Already got permission. Call the lambda.
|
||||||
|
ContextCompat.checkSelfPermission(
|
||||||
|
context, Manifest.permission.CAMERA
|
||||||
|
) -> {
|
||||||
|
showAddContentMenu = false
|
||||||
|
tempPhotoUri = context.createTempPictureUri()
|
||||||
|
cameraLauncher.launch(tempPhotoUri)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, ask for permission
|
||||||
|
else -> {
|
||||||
|
takePicturePermissionLauncher.launch(Manifest.permission.CAMERA)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Pick an image from album.
|
||||||
|
DropdownMenuItem(
|
||||||
|
text = {
|
||||||
|
Row(
|
||||||
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(6.dp)
|
||||||
|
) {
|
||||||
|
Icon(Icons.Rounded.Photo, contentDescription = "")
|
||||||
|
Text("Pick from album")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
enabled = pickedImages.isEmpty() && !hasImageMessage,
|
||||||
|
onClick = {
|
||||||
|
// Launch the photo picker and let the user choose only images.
|
||||||
|
pickMedia.launch(PickVisualMediaRequest(ActivityResultContracts.PickVisualMedia.ImageOnly))
|
||||||
|
showAddContentMenu = false
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prompt templates.
|
||||||
if (showPromptTemplatesInMenu) {
|
if (showPromptTemplatesInMenu) {
|
||||||
DropdownMenuItem(text = {
|
DropdownMenuItem(text = {
|
||||||
Row(
|
Row(
|
||||||
|
@ -127,6 +284,7 @@ fun MessageInputText(
|
||||||
showAddContentMenu = false
|
showAddContentMenu = false
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
// Prompt history.
|
||||||
DropdownMenuItem(text = {
|
DropdownMenuItem(text = {
|
||||||
Row(
|
Row(
|
||||||
verticalAlignment = Alignment.CenterVertically,
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
@ -171,18 +329,19 @@ fun MessageInputText(
|
||||||
),
|
),
|
||||||
) {
|
) {
|
||||||
Icon(
|
Icon(
|
||||||
Icons.Rounded.Stop,
|
Icons.Rounded.Stop, contentDescription = "", tint = MaterialTheme.colorScheme.primary
|
||||||
contentDescription = "",
|
|
||||||
tint = MaterialTheme.colorScheme.primary
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // Send button. Only shown when text is not empty.
|
} // Send button. Only shown when text is not empty.
|
||||||
else if (curMessage.isNotEmpty()) {
|
else if (curMessage.isNotEmpty()) {
|
||||||
IconButton(
|
IconButton(
|
||||||
enabled = !inProgress,
|
enabled = !inProgress && !isResettingSession,
|
||||||
onClick = {
|
onClick = {
|
||||||
onSendMessage(ChatMessageText(content = curMessage.trim(), side = ChatSide.USER))
|
onSendMessage(
|
||||||
|
createMessagesToSend(pickedImages = pickedImages, text = curMessage.trim())
|
||||||
|
)
|
||||||
|
pickedImages = listOf()
|
||||||
},
|
},
|
||||||
colors = IconButtonDefaults.iconButtonColors(
|
colors = IconButtonDefaults.iconButtonColors(
|
||||||
containerColor = MaterialTheme.colorScheme.secondaryContainer,
|
containerColor = MaterialTheme.colorScheme.secondaryContainer,
|
||||||
|
@ -200,7 +359,6 @@ fun MessageInputText(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// A bottom sheet to show the text input history to pick from.
|
// A bottom sheet to show the text input history to pick from.
|
||||||
if (showTextInputHistorySheet) {
|
if (showTextInputHistorySheet) {
|
||||||
TextInputHistorySheet(
|
TextInputHistorySheet(
|
||||||
|
@ -209,7 +367,8 @@ fun MessageInputText(
|
||||||
showTextInputHistorySheet = false
|
showTextInputHistorySheet = false
|
||||||
},
|
},
|
||||||
onHistoryItemClicked = { item ->
|
onHistoryItemClicked = { item ->
|
||||||
onSendMessage(ChatMessageText(content = item, side = ChatSide.USER))
|
onSendMessage(createMessagesToSend(pickedImages = pickedImages, text = item))
|
||||||
|
pickedImages = listOf()
|
||||||
modelManagerViewModel.promoteTextInputHistoryItem(item)
|
modelManagerViewModel.promoteTextInputHistoryItem(item)
|
||||||
},
|
},
|
||||||
onHistoryItemDeleted = { item ->
|
onHistoryItemDeleted = { item ->
|
||||||
|
@ -217,9 +376,55 @@ fun MessageInputText(
|
||||||
},
|
},
|
||||||
onHistoryItemsDeleteAll = {
|
onHistoryItemsDeleteAll = {
|
||||||
modelManagerViewModel.clearTextInputHistory()
|
modelManagerViewModel.clearTextInputHistory()
|
||||||
}
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun handleImageSelected(
|
||||||
|
context: Context,
|
||||||
|
uri: Uri,
|
||||||
|
onImageSelected: (Bitmap) -> Unit,
|
||||||
|
// For some reason, some Android phone would store the picture taken by the camera rotated
|
||||||
|
// horizontally. Use this flag to rotate the image back to portrait if the picture's width
|
||||||
|
// is bigger than height.
|
||||||
|
rotateForPortrait: Boolean = false,
|
||||||
|
) {
|
||||||
|
val bitmap: Bitmap? = try {
|
||||||
|
val inputStream = context.contentResolver.openInputStream(uri)
|
||||||
|
val tmpBitmap = BitmapFactory.decodeStream(inputStream)
|
||||||
|
if (rotateForPortrait && tmpBitmap.width > tmpBitmap.height) {
|
||||||
|
val matrix = Matrix()
|
||||||
|
matrix.postRotate(90f)
|
||||||
|
Bitmap.createBitmap(tmpBitmap, 0, 0, tmpBitmap.width, tmpBitmap.height, matrix, true)
|
||||||
|
} else {
|
||||||
|
tmpBitmap
|
||||||
|
}
|
||||||
|
} catch (e: Exception) {
|
||||||
|
e.printStackTrace()
|
||||||
|
null
|
||||||
|
}
|
||||||
|
if (bitmap != null) {
|
||||||
|
onImageSelected(bitmap)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun createMessagesToSend(pickedImages: List<Bitmap>, text: String): List<ChatMessage> {
|
||||||
|
val messages: MutableList<ChatMessage> = mutableListOf()
|
||||||
|
if (pickedImages.isNotEmpty()) {
|
||||||
|
val lastImage = pickedImages.last()
|
||||||
|
messages.add(
|
||||||
|
ChatMessageImage(
|
||||||
|
bitmap = lastImage, imageBitMap = lastImage.asImageBitmap(), side = ChatSide.USER
|
||||||
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
messages.add(
|
||||||
|
ChatMessageText(
|
||||||
|
content = text, side = ChatSide.USER
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return messages
|
||||||
}
|
}
|
||||||
|
|
||||||
@Preview(showBackground = true)
|
@Preview(showBackground = true)
|
||||||
|
@ -233,6 +438,21 @@ fun MessageInputTextPreview() {
|
||||||
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
||||||
curMessage = "hello",
|
curMessage = "hello",
|
||||||
inProgress = false,
|
inProgress = false,
|
||||||
|
isResettingSession = false,
|
||||||
|
modelInitializing = false,
|
||||||
|
hasImageMessage = false,
|
||||||
|
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
|
||||||
|
onValueChanged = {},
|
||||||
|
onSendMessage = {},
|
||||||
|
showStopButtonWhenInProgress = true,
|
||||||
|
showImagePickerInMenu = true,
|
||||||
|
)
|
||||||
|
MessageInputText(
|
||||||
|
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
||||||
|
curMessage = "hello",
|
||||||
|
inProgress = false,
|
||||||
|
isResettingSession = false,
|
||||||
|
hasImageMessage = false,
|
||||||
modelInitializing = false,
|
modelInitializing = false,
|
||||||
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
|
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
|
||||||
onValueChanged = {},
|
onValueChanged = {},
|
||||||
|
@ -243,6 +463,8 @@ fun MessageInputTextPreview() {
|
||||||
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
||||||
curMessage = "hello",
|
curMessage = "hello",
|
||||||
inProgress = true,
|
inProgress = true,
|
||||||
|
isResettingSession = false,
|
||||||
|
hasImageMessage = false,
|
||||||
modelInitializing = false,
|
modelInitializing = false,
|
||||||
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
|
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
|
||||||
onValueChanged = {},
|
onValueChanged = {},
|
||||||
|
@ -252,6 +474,8 @@ fun MessageInputTextPreview() {
|
||||||
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
||||||
curMessage = "",
|
curMessage = "",
|
||||||
inProgress = false,
|
inProgress = false,
|
||||||
|
isResettingSession = false,
|
||||||
|
hasImageMessage = false,
|
||||||
modelInitializing = false,
|
modelInitializing = false,
|
||||||
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
|
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
|
||||||
onValueChanged = {},
|
onValueChanged = {},
|
||||||
|
@ -261,6 +485,8 @@ fun MessageInputTextPreview() {
|
||||||
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
||||||
curMessage = "",
|
curMessage = "",
|
||||||
inProgress = true,
|
inProgress = true,
|
||||||
|
isResettingSession = false,
|
||||||
|
hasImageMessage = false,
|
||||||
modelInitializing = false,
|
modelInitializing = false,
|
||||||
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
|
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
|
||||||
onValueChanged = {},
|
onValueChanged = {},
|
||||||
|
|
|
@ -52,6 +52,7 @@ import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.draw.clip
|
import androidx.compose.ui.draw.clip
|
||||||
import androidx.compose.ui.res.stringResource
|
import androidx.compose.ui.res.stringResource
|
||||||
import androidx.compose.ui.text.style.TextAlign
|
import androidx.compose.ui.text.style.TextAlign
|
||||||
|
import androidx.compose.ui.text.style.TextOverflow
|
||||||
import androidx.compose.ui.tooling.preview.Preview
|
import androidx.compose.ui.tooling.preview.Preview
|
||||||
import androidx.compose.ui.unit.dp
|
import androidx.compose.ui.unit.dp
|
||||||
import com.google.aiedge.gallery.R
|
import com.google.aiedge.gallery.R
|
||||||
|
@ -149,6 +150,8 @@ private fun SheetContent(
|
||||||
Text(
|
Text(
|
||||||
item,
|
item,
|
||||||
style = MaterialTheme.typography.bodyMedium,
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
maxLines = 3,
|
||||||
|
overflow = TextOverflow.Ellipsis,
|
||||||
modifier = Modifier
|
modifier = Modifier
|
||||||
.padding(vertical = 16.dp)
|
.padding(vertical = 16.dp)
|
||||||
.padding(start = 16.dp)
|
.padding(start = 16.dp)
|
||||||
|
|
|
@ -76,7 +76,7 @@ fun ModelNameAndStatus(
|
||||||
Text(
|
Text(
|
||||||
model.name,
|
model.name,
|
||||||
maxLines = 1,
|
maxLines = 1,
|
||||||
overflow = TextOverflow.Ellipsis,
|
overflow = TextOverflow.MiddleEllipsis,
|
||||||
style = MaterialTheme.typography.titleMedium,
|
style = MaterialTheme.typography.titleMedium,
|
||||||
modifier = modifier,
|
modifier = modifier,
|
||||||
)
|
)
|
||||||
|
|
|
@ -136,7 +136,7 @@ fun HomeScreen(
|
||||||
val snackbarHostState = remember { SnackbarHostState() }
|
val snackbarHostState = remember { SnackbarHostState() }
|
||||||
val scope = rememberCoroutineScope()
|
val scope = rememberCoroutineScope()
|
||||||
|
|
||||||
val tasks = uiState.tasks
|
val nonEmptyTasks = uiState.tasks.filter { it.models.size > 0 }
|
||||||
val loadingHfModels = uiState.loadingHfModels
|
val loadingHfModels = uiState.loadingHfModels
|
||||||
|
|
||||||
val filePickerLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult(
|
val filePickerLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult(
|
||||||
|
@ -183,14 +183,14 @@ fun HomeScreen(
|
||||||
) { innerPadding ->
|
) { innerPadding ->
|
||||||
Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.fillMaxSize()) {
|
Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.fillMaxSize()) {
|
||||||
TaskList(
|
TaskList(
|
||||||
tasks = tasks,
|
tasks = nonEmptyTasks,
|
||||||
navigateToTaskScreen = navigateToTaskScreen,
|
navigateToTaskScreen = navigateToTaskScreen,
|
||||||
loadingModelAllowlist = uiState.loadingModelAllowlist,
|
loadingModelAllowlist = uiState.loadingModelAllowlist,
|
||||||
modifier = Modifier.fillMaxSize(),
|
modifier = Modifier.fillMaxSize(),
|
||||||
contentPadding = innerPadding,
|
contentPadding = innerPadding,
|
||||||
)
|
)
|
||||||
|
|
||||||
SnackbarHost(hostState = snackbarHostState, modifier = Modifier.padding(bottom = 16.dp))
|
SnackbarHost(hostState = snackbarHostState, modifier = Modifier.padding(bottom = 32.dp))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -285,7 +285,7 @@ fun HomeScreen(
|
||||||
|
|
||||||
// Show a snack bar for successful import.
|
// Show a snack bar for successful import.
|
||||||
scope.launch {
|
scope.launch {
|
||||||
snackbarHostState.showSnackbar("✅ Model imported successfully")
|
snackbarHostState.showSnackbar("Model imported successfully")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -316,7 +316,6 @@ fun HomeScreen(
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -55,6 +55,7 @@ import androidx.compose.ui.unit.dp
|
||||||
import androidx.compose.ui.window.Dialog
|
import androidx.compose.ui.window.Dialog
|
||||||
import androidx.compose.ui.window.DialogProperties
|
import androidx.compose.ui.window.DialogProperties
|
||||||
import com.google.aiedge.gallery.data.Accelerator
|
import com.google.aiedge.gallery.data.Accelerator
|
||||||
|
import com.google.aiedge.gallery.data.BooleanSwitchConfig
|
||||||
import com.google.aiedge.gallery.data.Config
|
import com.google.aiedge.gallery.data.Config
|
||||||
import com.google.aiedge.gallery.data.ConfigKey
|
import com.google.aiedge.gallery.data.ConfigKey
|
||||||
import com.google.aiedge.gallery.data.IMPORTS_DIR
|
import com.google.aiedge.gallery.data.IMPORTS_DIR
|
||||||
|
@ -111,6 +112,10 @@ private val IMPORT_CONFIGS_LLM: List<Config> = listOf(
|
||||||
defaultValue = DEFAULT_TEMPERATURE,
|
defaultValue = DEFAULT_TEMPERATURE,
|
||||||
valueType = ValueType.FLOAT
|
valueType = ValueType.FLOAT
|
||||||
),
|
),
|
||||||
|
BooleanSwitchConfig(
|
||||||
|
key = ConfigKey.SUPPORT_IMAGE,
|
||||||
|
defaultValue = false,
|
||||||
|
),
|
||||||
SegmentedButtonConfig(
|
SegmentedButtonConfig(
|
||||||
key = ConfigKey.COMPATIBLE_ACCELERATORS,
|
key = ConfigKey.COMPATIBLE_ACCELERATORS,
|
||||||
defaultValue = Accelerator.CPU.label,
|
defaultValue = Accelerator.CPU.label,
|
||||||
|
|
|
@ -50,7 +50,8 @@ fun ImageClassificationScreen(
|
||||||
task = viewModel.task,
|
task = viewModel.task,
|
||||||
viewModel = viewModel,
|
viewModel = viewModel,
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
onSendMessage = { model, message ->
|
onSendMessage = { model, messages ->
|
||||||
|
val message = messages[0]
|
||||||
viewModel.addMessage(
|
viewModel.addMessage(
|
||||||
model = model,
|
model = model,
|
||||||
message = message,
|
message = message,
|
||||||
|
|
|
@ -44,7 +44,8 @@ fun ImageGenerationScreen(
|
||||||
task = viewModel.task,
|
task = viewModel.task,
|
||||||
viewModel = viewModel,
|
viewModel = viewModel,
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
onSendMessage = { model, message ->
|
onSendMessage = { model, messages ->
|
||||||
|
val message = messages[0]
|
||||||
viewModel.addMessage(
|
viewModel.addMessage(
|
||||||
model = model,
|
model = model,
|
||||||
message = message,
|
message = message,
|
||||||
|
|
|
@ -17,11 +17,14 @@
|
||||||
package com.google.aiedge.gallery.ui.llmchat
|
package com.google.aiedge.gallery.ui.llmchat
|
||||||
|
|
||||||
import android.content.Context
|
import android.content.Context
|
||||||
|
import android.graphics.Bitmap
|
||||||
import android.util.Log
|
import android.util.Log
|
||||||
import com.google.aiedge.gallery.data.Accelerator
|
import com.google.aiedge.gallery.data.Accelerator
|
||||||
import com.google.aiedge.gallery.data.ConfigKey
|
import com.google.aiedge.gallery.data.ConfigKey
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage
|
import com.google.aiedge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage
|
||||||
|
import com.google.mediapipe.framework.image.BitmapImageBuilder
|
||||||
|
import com.google.mediapipe.tasks.genai.llminference.GraphOptions
|
||||||
import com.google.mediapipe.tasks.genai.llminference.LlmInference
|
import com.google.mediapipe.tasks.genai.llminference.LlmInference
|
||||||
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
|
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
|
||||||
|
|
||||||
|
@ -55,7 +58,9 @@ object LlmChatModelHelper {
|
||||||
}
|
}
|
||||||
val options =
|
val options =
|
||||||
LlmInference.LlmInferenceOptions.builder().setModelPath(model.getPath(context = context))
|
LlmInference.LlmInferenceOptions.builder().setModelPath(model.getPath(context = context))
|
||||||
.setMaxTokens(maxTokens).setPreferredBackend(preferredBackend).build()
|
.setMaxTokens(maxTokens).setPreferredBackend(preferredBackend)
|
||||||
|
.setMaxNumImages(if (model.llmSupportImage) 1 else 0)
|
||||||
|
.build()
|
||||||
|
|
||||||
// Create an instance of the LLM Inference task
|
// Create an instance of the LLM Inference task
|
||||||
try {
|
try {
|
||||||
|
@ -64,7 +69,10 @@ object LlmChatModelHelper {
|
||||||
val session = LlmInferenceSession.createFromOptions(
|
val session = LlmInferenceSession.createFromOptions(
|
||||||
llmInference,
|
llmInference,
|
||||||
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
|
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
|
||||||
.setTemperature(temperature).build()
|
.setTemperature(temperature)
|
||||||
|
.setGraphOptions(
|
||||||
|
GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
|
||||||
|
).build()
|
||||||
)
|
)
|
||||||
model.instance = LlmModelInstance(engine = llmInference, session = session)
|
model.instance = LlmModelInstance(engine = llmInference, session = session)
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
|
@ -75,6 +83,8 @@ object LlmChatModelHelper {
|
||||||
}
|
}
|
||||||
|
|
||||||
fun resetSession(model: Model) {
|
fun resetSession(model: Model) {
|
||||||
|
Log.d(TAG, "Resetting session for model '${model.name}'")
|
||||||
|
|
||||||
val instance = model.instance as LlmModelInstance? ?: return
|
val instance = model.instance as LlmModelInstance? ?: return
|
||||||
val session = instance.session
|
val session = instance.session
|
||||||
session.close()
|
session.close()
|
||||||
|
@ -87,9 +97,13 @@ object LlmChatModelHelper {
|
||||||
val newSession = LlmInferenceSession.createFromOptions(
|
val newSession = LlmInferenceSession.createFromOptions(
|
||||||
inference,
|
inference,
|
||||||
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
|
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
|
||||||
.setTemperature(temperature).build()
|
.setTemperature(temperature)
|
||||||
|
.setGraphOptions(
|
||||||
|
GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
|
||||||
|
).build()
|
||||||
)
|
)
|
||||||
instance.session = newSession
|
instance.session = newSession
|
||||||
|
Log.d(TAG, "Resetting done")
|
||||||
}
|
}
|
||||||
|
|
||||||
fun cleanUp(model: Model) {
|
fun cleanUp(model: Model) {
|
||||||
|
@ -118,6 +132,7 @@ object LlmChatModelHelper {
|
||||||
resultListener: ResultListener,
|
resultListener: ResultListener,
|
||||||
cleanUpListener: CleanUpListener,
|
cleanUpListener: CleanUpListener,
|
||||||
singleTurn: Boolean = false,
|
singleTurn: Boolean = false,
|
||||||
|
image: Bitmap? = null,
|
||||||
) {
|
) {
|
||||||
if (singleTurn) {
|
if (singleTurn) {
|
||||||
resetSession(model = model)
|
resetSession(model = model)
|
||||||
|
@ -132,6 +147,9 @@ object LlmChatModelHelper {
|
||||||
// Start async inference.
|
// Start async inference.
|
||||||
val session = instance.session
|
val session = instance.session
|
||||||
session.addQueryChunk(input)
|
session.addQueryChunk(input)
|
||||||
|
if (image != null) {
|
||||||
|
session.addImage(BitmapImageBuilder(image).build())
|
||||||
|
}
|
||||||
session.generateResponseAsync(resultListener)
|
session.generateResponseAsync(resultListener)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,15 +16,14 @@
|
||||||
|
|
||||||
package com.google.aiedge.gallery.ui.llmchat
|
package com.google.aiedge.gallery.ui.llmchat
|
||||||
|
|
||||||
import android.util.Log
|
import android.graphics.Bitmap
|
||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.platform.LocalContext
|
import androidx.compose.ui.platform.LocalContext
|
||||||
import androidx.lifecycle.viewmodel.compose.viewModel
|
import androidx.lifecycle.viewmodel.compose.viewModel
|
||||||
import com.google.aiedge.gallery.ui.ViewModelProvider
|
import com.google.aiedge.gallery.ui.ViewModelProvider
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageInfo
|
import com.google.aiedge.gallery.ui.common.chat.ChatMessageImage
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageText
|
import com.google.aiedge.gallery.ui.common.chat.ChatMessageText
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageWarning
|
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ChatView
|
import com.google.aiedge.gallery.ui.common.chat.ChatView
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
import kotlinx.serialization.Serializable
|
import kotlinx.serialization.Serializable
|
||||||
|
@ -35,6 +34,11 @@ object LlmChatDestination {
|
||||||
val route = "LlmChatRoute"
|
val route = "LlmChatRoute"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
object LlmImageToTextDestination {
|
||||||
|
@Serializable
|
||||||
|
val route = "LlmImageToTextRoute"
|
||||||
|
}
|
||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
fun LlmChatScreen(
|
fun LlmChatScreen(
|
||||||
modelManagerViewModel: ModelManagerViewModel,
|
modelManagerViewModel: ModelManagerViewModel,
|
||||||
|
@ -43,6 +47,38 @@ fun LlmChatScreen(
|
||||||
viewModel: LlmChatViewModel = viewModel(
|
viewModel: LlmChatViewModel = viewModel(
|
||||||
factory = ViewModelProvider.Factory
|
factory = ViewModelProvider.Factory
|
||||||
),
|
),
|
||||||
|
) {
|
||||||
|
ChatViewWrapper(
|
||||||
|
viewModel = viewModel,
|
||||||
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
|
navigateUp = navigateUp,
|
||||||
|
modifier = modifier,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
fun LlmImageToTextScreen(
|
||||||
|
modelManagerViewModel: ModelManagerViewModel,
|
||||||
|
navigateUp: () -> Unit,
|
||||||
|
modifier: Modifier = Modifier,
|
||||||
|
viewModel: LlmImageToTextViewModel = viewModel(
|
||||||
|
factory = ViewModelProvider.Factory
|
||||||
|
),
|
||||||
|
) {
|
||||||
|
ChatViewWrapper(
|
||||||
|
viewModel = viewModel,
|
||||||
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
|
navigateUp = navigateUp,
|
||||||
|
modifier = modifier,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
fun ChatViewWrapper(
|
||||||
|
viewModel: LlmChatViewModel,
|
||||||
|
modelManagerViewModel: ModelManagerViewModel,
|
||||||
|
navigateUp: () -> Unit,
|
||||||
|
modifier: Modifier = Modifier
|
||||||
) {
|
) {
|
||||||
val context = LocalContext.current
|
val context = LocalContext.current
|
||||||
|
|
||||||
|
@ -50,21 +86,33 @@ fun LlmChatScreen(
|
||||||
task = viewModel.task,
|
task = viewModel.task,
|
||||||
viewModel = viewModel,
|
viewModel = viewModel,
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
onSendMessage = { model, message ->
|
onSendMessage = { model, messages ->
|
||||||
viewModel.addMessage(
|
for (message in messages) {
|
||||||
model = model,
|
viewModel.addMessage(
|
||||||
message = message,
|
model = model,
|
||||||
)
|
message = message,
|
||||||
if (message is ChatMessageText) {
|
)
|
||||||
modelManagerViewModel.addTextInputHistory(message.content)
|
}
|
||||||
viewModel.generateResponse(model = model, input = message.content, onError = {
|
|
||||||
viewModel.addMessage(
|
|
||||||
model = model,
|
|
||||||
message = ChatMessageWarning(content = "Error occurred. Re-initializing the engine.")
|
|
||||||
)
|
|
||||||
|
|
||||||
modelManagerViewModel.initializeModel(
|
var text = ""
|
||||||
context = context, task = viewModel.task, model = model, force = true
|
var image: Bitmap? = null
|
||||||
|
var chatMessageText: ChatMessageText? = null
|
||||||
|
for (message in messages) {
|
||||||
|
if (message is ChatMessageText) {
|
||||||
|
chatMessageText = message
|
||||||
|
text = message.content
|
||||||
|
} else if (message is ChatMessageImage) {
|
||||||
|
image = message.bitmap
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (text.isNotEmpty() && chatMessageText != null) {
|
||||||
|
modelManagerViewModel.addTextInputHistory(text)
|
||||||
|
viewModel.generateResponse(model = model, input = text, image = image, onError = {
|
||||||
|
viewModel.handleError(
|
||||||
|
context = context,
|
||||||
|
model = model,
|
||||||
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
|
triggeredMessage = chatMessageText,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -72,13 +120,11 @@ fun LlmChatScreen(
|
||||||
onRunAgainClicked = { model, message ->
|
onRunAgainClicked = { model, message ->
|
||||||
if (message is ChatMessageText) {
|
if (message is ChatMessageText) {
|
||||||
viewModel.runAgain(model = model, message = message, onError = {
|
viewModel.runAgain(model = model, message = message, onError = {
|
||||||
viewModel.addMessage(
|
viewModel.handleError(
|
||||||
|
context = context,
|
||||||
model = model,
|
model = model,
|
||||||
message = ChatMessageWarning(content = "Error occurred. Re-initializing the engine.")
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
)
|
triggeredMessage = message,
|
||||||
|
|
||||||
modelManagerViewModel.initializeModel(
|
|
||||||
context = context, task = viewModel.task, model = model, force = true
|
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -90,6 +136,9 @@ fun LlmChatScreen(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
onResetSessionClicked = { model ->
|
||||||
|
viewModel.resetSession(model = model)
|
||||||
|
},
|
||||||
showStopButtonInInputWhenInProgress = true,
|
showStopButtonInInputWhenInProgress = true,
|
||||||
onStopButtonClicked = { model ->
|
onStopButtonClicked = { model ->
|
||||||
viewModel.stopResponse(model = model)
|
viewModel.stopResponse(model = model)
|
||||||
|
@ -97,5 +146,4 @@ fun LlmChatScreen(
|
||||||
navigateUp = navigateUp,
|
navigateUp = navigateUp,
|
||||||
modifier = modifier,
|
modifier = modifier,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,18 +16,23 @@
|
||||||
|
|
||||||
package com.google.aiedge.gallery.ui.llmchat
|
package com.google.aiedge.gallery.ui.llmchat
|
||||||
|
|
||||||
|
import android.content.Context
|
||||||
|
import android.graphics.Bitmap
|
||||||
import android.util.Log
|
import android.util.Log
|
||||||
import androidx.lifecycle.viewModelScope
|
import androidx.lifecycle.viewModelScope
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||||
|
import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT
|
||||||
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
|
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageLoading
|
import com.google.aiedge.gallery.ui.common.chat.ChatMessageLoading
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageText
|
import com.google.aiedge.gallery.ui.common.chat.ChatMessageText
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageType
|
import com.google.aiedge.gallery.ui.common.chat.ChatMessageType
|
||||||
|
import com.google.aiedge.gallery.ui.common.chat.ChatMessageWarning
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ChatSide
|
import com.google.aiedge.gallery.ui.common.chat.ChatSide
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ChatViewModel
|
import com.google.aiedge.gallery.ui.common.chat.ChatViewModel
|
||||||
import com.google.aiedge.gallery.ui.common.chat.Stat
|
import com.google.aiedge.gallery.ui.common.chat.Stat
|
||||||
import kotlinx.coroutines.CoroutineExceptionHandler
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
||||||
import kotlinx.coroutines.delay
|
import kotlinx.coroutines.delay
|
||||||
import kotlinx.coroutines.launch
|
import kotlinx.coroutines.launch
|
||||||
|
@ -40,8 +45,8 @@ private val STATS = listOf(
|
||||||
Stat(id = "latency", label = "Latency", unit = "sec")
|
Stat(id = "latency", label = "Latency", unit = "sec")
|
||||||
)
|
)
|
||||||
|
|
||||||
class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task = curTask) {
|
||||||
fun generateResponse(model: Model, input: String, onError: () -> Unit) {
|
fun generateResponse(model: Model, input: String, image: Bitmap? = null, onError: () -> Unit) {
|
||||||
viewModelScope.launch(Dispatchers.Default) {
|
viewModelScope.launch(Dispatchers.Default) {
|
||||||
setInProgress(true)
|
setInProgress(true)
|
||||||
|
|
||||||
|
@ -58,7 +63,10 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
||||||
|
|
||||||
// Run inference.
|
// Run inference.
|
||||||
val instance = model.instance as LlmModelInstance
|
val instance = model.instance as LlmModelInstance
|
||||||
val prefillTokens = instance.session.sizeInTokens(input)
|
var prefillTokens = instance.session.sizeInTokens(input)
|
||||||
|
if (image != null) {
|
||||||
|
prefillTokens += 257
|
||||||
|
}
|
||||||
|
|
||||||
var firstRun = true
|
var firstRun = true
|
||||||
var timeToFirstToken = 0f
|
var timeToFirstToken = 0f
|
||||||
|
@ -69,9 +77,9 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
||||||
val start = System.currentTimeMillis()
|
val start = System.currentTimeMillis()
|
||||||
|
|
||||||
try {
|
try {
|
||||||
LlmChatModelHelper.runInference(
|
LlmChatModelHelper.runInference(model = model,
|
||||||
model = model,
|
|
||||||
input = input,
|
input = input,
|
||||||
|
image = image,
|
||||||
resultListener = { partialResult, done ->
|
resultListener = { partialResult, done ->
|
||||||
val curTs = System.currentTimeMillis()
|
val curTs = System.currentTimeMillis()
|
||||||
|
|
||||||
|
@ -92,32 +100,27 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
||||||
|
|
||||||
// Add an empty message that will receive streaming results.
|
// Add an empty message that will receive streaming results.
|
||||||
addMessage(
|
addMessage(
|
||||||
model = model,
|
model = model, message = ChatMessageText(content = "", side = ChatSide.AGENT)
|
||||||
message = ChatMessageText(content = "", side = ChatSide.AGENT)
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Incrementally update the streamed partial results.
|
// Incrementally update the streamed partial results.
|
||||||
val latencyMs: Long = if (done) System.currentTimeMillis() - start else -1
|
val latencyMs: Long = if (done) System.currentTimeMillis() - start else -1
|
||||||
updateLastTextMessageContentIncrementally(
|
updateLastTextMessageContentIncrementally(
|
||||||
model = model,
|
model = model, partialContent = partialResult, latencyMs = latencyMs.toFloat()
|
||||||
partialContent = partialResult,
|
|
||||||
latencyMs = latencyMs.toFloat()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if (done) {
|
if (done) {
|
||||||
setInProgress(false)
|
setInProgress(false)
|
||||||
|
|
||||||
decodeSpeed =
|
decodeSpeed = decodeTokens / ((curTs - firstTokenTs) / 1000f)
|
||||||
decodeTokens / ((curTs - firstTokenTs) / 1000f)
|
|
||||||
if (decodeSpeed.isNaN()) {
|
if (decodeSpeed.isNaN()) {
|
||||||
decodeSpeed = 0f
|
decodeSpeed = 0f
|
||||||
}
|
}
|
||||||
|
|
||||||
if (lastMessage is ChatMessageText) {
|
if (lastMessage is ChatMessageText) {
|
||||||
updateLastTextMessageLlmBenchmarkResult(
|
updateLastTextMessageLlmBenchmarkResult(
|
||||||
model = model, llmBenchmarkResult =
|
model = model, llmBenchmarkResult = ChatMessageBenchmarkLlmResult(
|
||||||
ChatMessageBenchmarkLlmResult(
|
|
||||||
orderedStats = STATS,
|
orderedStats = STATS,
|
||||||
statValues = mutableMapOf(
|
statValues = mutableMapOf(
|
||||||
"prefill_speed" to prefillSpeed,
|
"prefill_speed" to prefillSpeed,
|
||||||
|
@ -131,10 +134,12 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}, cleanUpListener = {
|
},
|
||||||
|
cleanUpListener = {
|
||||||
setInProgress(false)
|
setInProgress(false)
|
||||||
})
|
})
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
|
Log.e(TAG, "Error occurred while running inference", e)
|
||||||
setInProgress(false)
|
setInProgress(false)
|
||||||
onError()
|
onError()
|
||||||
}
|
}
|
||||||
|
@ -143,6 +148,9 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
||||||
|
|
||||||
fun stopResponse(model: Model) {
|
fun stopResponse(model: Model) {
|
||||||
Log.d(TAG, "Stopping response for model ${model.name}...")
|
Log.d(TAG, "Stopping response for model ${model.name}...")
|
||||||
|
if (getLastMessage(model = model) is ChatMessageLoading) {
|
||||||
|
removeLastMessage(model = model)
|
||||||
|
}
|
||||||
viewModelScope.launch(Dispatchers.Default) {
|
viewModelScope.launch(Dispatchers.Default) {
|
||||||
setInProgress(false)
|
setInProgress(false)
|
||||||
val instance = model.instance as LlmModelInstance
|
val instance = model.instance as LlmModelInstance
|
||||||
|
@ -150,6 +158,25 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun resetSession(model: Model) {
|
||||||
|
viewModelScope.launch(Dispatchers.Default) {
|
||||||
|
setIsResettingSession(true)
|
||||||
|
clearAllMessages(model = model)
|
||||||
|
stopResponse(model = model)
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
try {
|
||||||
|
LlmChatModelHelper.resetSession(model = model)
|
||||||
|
break
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.d(TAG, "Failed to reset session. Trying again")
|
||||||
|
}
|
||||||
|
delay(200)
|
||||||
|
}
|
||||||
|
setIsResettingSession(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fun runAgain(model: Model, message: ChatMessageText, onError: () -> Unit) {
|
fun runAgain(model: Model, message: ChatMessageText, onError: () -> Unit) {
|
||||||
viewModelScope.launch(Dispatchers.Default) {
|
viewModelScope.launch(Dispatchers.Default) {
|
||||||
// Wait for model to be initialized.
|
// Wait for model to be initialized.
|
||||||
|
@ -162,9 +189,7 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
||||||
|
|
||||||
// Run inference.
|
// Run inference.
|
||||||
generateResponse(
|
generateResponse(
|
||||||
model = model,
|
model = model, input = message.content, onError = onError
|
||||||
input = message.content,
|
|
||||||
onError = onError
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -199,8 +224,7 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
||||||
var decodeSpeed: Float
|
var decodeSpeed: Float
|
||||||
val start = System.currentTimeMillis()
|
val start = System.currentTimeMillis()
|
||||||
var lastUpdateTime = 0L
|
var lastUpdateTime = 0L
|
||||||
LlmChatModelHelper.runInference(
|
LlmChatModelHelper.runInference(model = model,
|
||||||
model = model,
|
|
||||||
input = message.content,
|
input = message.content,
|
||||||
resultListener = { partialResult, done ->
|
resultListener = { partialResult, done ->
|
||||||
val curTs = System.currentTimeMillis()
|
val curTs = System.currentTimeMillis()
|
||||||
|
@ -232,14 +256,12 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
||||||
result.append(partialResult)
|
result.append(partialResult)
|
||||||
|
|
||||||
if (curTs - lastUpdateTime > 500 || done) {
|
if (curTs - lastUpdateTime > 500 || done) {
|
||||||
decodeSpeed =
|
decodeSpeed = decodeTokens / ((curTs - firstTokenTs) / 1000f)
|
||||||
decodeTokens / ((curTs - firstTokenTs) / 1000f)
|
|
||||||
if (decodeSpeed.isNaN()) {
|
if (decodeSpeed.isNaN()) {
|
||||||
decodeSpeed = 0f
|
decodeSpeed = 0f
|
||||||
}
|
}
|
||||||
replaceLastMessage(
|
replaceLastMessage(
|
||||||
model = model,
|
model = model, message = ChatMessageBenchmarkLlmResult(
|
||||||
message = ChatMessageBenchmarkLlmResult(
|
|
||||||
orderedStats = STATS,
|
orderedStats = STATS,
|
||||||
statValues = mutableMapOf(
|
statValues = mutableMapOf(
|
||||||
"prefill_speed" to prefillSpeed,
|
"prefill_speed" to prefillSpeed,
|
||||||
|
@ -249,8 +271,7 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
||||||
),
|
),
|
||||||
running = !done,
|
running = !done,
|
||||||
latencyMs = -1f,
|
latencyMs = -1f,
|
||||||
),
|
), type = ChatMessageType.BENCHMARK_LLM_RESULT
|
||||||
type = ChatMessageType.BENCHMARK_LLM_RESULT
|
|
||||||
)
|
)
|
||||||
lastUpdateTime = curTs
|
lastUpdateTime = curTs
|
||||||
|
|
||||||
|
@ -261,9 +282,53 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
||||||
},
|
},
|
||||||
cleanUpListener = {
|
cleanUpListener = {
|
||||||
setInProgress(false)
|
setInProgress(false)
|
||||||
}
|
})
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun handleError(
|
||||||
|
context: Context,
|
||||||
|
model: Model,
|
||||||
|
modelManagerViewModel: ModelManagerViewModel,
|
||||||
|
triggeredMessage: ChatMessageText,
|
||||||
|
) {
|
||||||
|
// Clean up.
|
||||||
|
modelManagerViewModel.cleanupModel(task = task, model = model)
|
||||||
|
|
||||||
|
// Remove the "loading" message.
|
||||||
|
if (getLastMessage(model = model) is ChatMessageLoading) {
|
||||||
|
removeLastMessage(model = model)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the last Text message.
|
||||||
|
if (getLastMessage(model = model) == triggeredMessage) {
|
||||||
|
removeLastMessage(model = model)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add a warning message for re-initializing the session.
|
||||||
|
addMessage(
|
||||||
|
model = model,
|
||||||
|
message = ChatMessageWarning(content = "Error occurred. Re-initializing the session.")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Add the triggered message back.
|
||||||
|
addMessage(model = model, message = triggeredMessage)
|
||||||
|
|
||||||
|
// Re-initialize the session/engine.
|
||||||
|
modelManagerViewModel.initializeModel(
|
||||||
|
context = context, task = task, model = model
|
||||||
|
)
|
||||||
|
|
||||||
|
// Re-generate the response automatically.
|
||||||
|
generateResponse(model = model, input = triggeredMessage.content, onError = {
|
||||||
|
handleError(
|
||||||
|
context = context,
|
||||||
|
model = model,
|
||||||
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
|
triggeredMessage = triggeredMessage
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class LlmImageToTextViewModel : LlmChatViewModel(curTask = TASK_LLM_IMAGE_TO_TEXT)
|
|
@ -155,9 +155,14 @@ fun LlmSingleTurnScreen(
|
||||||
PromptTemplatesPanel(
|
PromptTemplatesPanel(
|
||||||
model = selectedModel,
|
model = selectedModel,
|
||||||
viewModel = viewModel,
|
viewModel = viewModel,
|
||||||
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
onSend = { fullPrompt ->
|
onSend = { fullPrompt ->
|
||||||
viewModel.generateResponse(model = selectedModel, input = fullPrompt)
|
viewModel.generateResponse(model = selectedModel, input = fullPrompt)
|
||||||
}, modifier = Modifier.fillMaxSize()
|
},
|
||||||
|
onStopButtonClicked = { model ->
|
||||||
|
viewModel.stopResponse(model = model)
|
||||||
|
},
|
||||||
|
modifier = Modifier.fillMaxSize()
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
bottomView = {
|
bottomView = {
|
||||||
|
|
|
@ -23,6 +23,7 @@ import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
|
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
|
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
|
||||||
|
import com.google.aiedge.gallery.ui.common.chat.ChatMessageLoading
|
||||||
import com.google.aiedge.gallery.ui.common.chat.Stat
|
import com.google.aiedge.gallery.ui.common.chat.Stat
|
||||||
import com.google.aiedge.gallery.ui.common.processLlmResponse
|
import com.google.aiedge.gallery.ui.common.processLlmResponse
|
||||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
|
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
|
||||||
|
@ -194,6 +195,15 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_USECASES) : ViewMode
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun stopResponse(model: Model) {
|
||||||
|
Log.d(TAG, "Stopping response for model ${model.name}...")
|
||||||
|
viewModelScope.launch(Dispatchers.Default) {
|
||||||
|
setInProgress(false)
|
||||||
|
val instance = model.instance as LlmModelInstance
|
||||||
|
instance.session.cancelGenerateResponseAsync()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private fun createUiState(task: Task): LlmSingleTurnUiState {
|
private fun createUiState(task: Task): LlmSingleTurnUiState {
|
||||||
val responsesByModel: MutableMap<String, Map<String, String>> = mutableMapOf()
|
val responsesByModel: MutableMap<String, Map<String, String>> = mutableMapOf()
|
||||||
val benchmarkByModel: MutableMap<String, Map<String, ChatMessageBenchmarkLlmResult>> =
|
val benchmarkByModel: MutableMap<String, Map<String, ChatMessageBenchmarkLlmResult>> =
|
||||||
|
|
|
@ -43,11 +43,13 @@ import androidx.compose.material.icons.outlined.Description
|
||||||
import androidx.compose.material.icons.outlined.ExpandLess
|
import androidx.compose.material.icons.outlined.ExpandLess
|
||||||
import androidx.compose.material.icons.outlined.ExpandMore
|
import androidx.compose.material.icons.outlined.ExpandMore
|
||||||
import androidx.compose.material.icons.rounded.Add
|
import androidx.compose.material.icons.rounded.Add
|
||||||
|
import androidx.compose.material.icons.rounded.Stop
|
||||||
import androidx.compose.material.icons.rounded.Visibility
|
import androidx.compose.material.icons.rounded.Visibility
|
||||||
import androidx.compose.material.icons.rounded.VisibilityOff
|
import androidx.compose.material.icons.rounded.VisibilityOff
|
||||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||||
import androidx.compose.material3.FilterChipDefaults
|
import androidx.compose.material3.FilterChipDefaults
|
||||||
import androidx.compose.material3.Icon
|
import androidx.compose.material3.Icon
|
||||||
|
import androidx.compose.material3.IconButton
|
||||||
import androidx.compose.material3.IconButtonDefaults
|
import androidx.compose.material3.IconButtonDefaults
|
||||||
import androidx.compose.material3.MaterialTheme
|
import androidx.compose.material3.MaterialTheme
|
||||||
import androidx.compose.material3.ModalBottomSheet
|
import androidx.compose.material3.ModalBottomSheet
|
||||||
|
@ -86,9 +88,11 @@ import androidx.compose.ui.unit.dp
|
||||||
import com.google.aiedge.gallery.R
|
import com.google.aiedge.gallery.R
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.ui.common.chat.MessageBubbleShape
|
import com.google.aiedge.gallery.ui.common.chat.MessageBubbleShape
|
||||||
|
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType
|
||||||
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
import com.google.aiedge.gallery.ui.theme.customColors
|
import com.google.aiedge.gallery.ui.theme.customColors
|
||||||
import kotlinx.coroutines.launch
|
|
||||||
import kotlinx.coroutines.delay
|
import kotlinx.coroutines.delay
|
||||||
|
import kotlinx.coroutines.launch
|
||||||
|
|
||||||
private val promptTemplateTypes: List<PromptTemplateType> = PromptTemplateType.entries
|
private val promptTemplateTypes: List<PromptTemplateType> = PromptTemplateType.entries
|
||||||
private val TAB_TITLES = PromptTemplateType.entries.map { it.label }
|
private val TAB_TITLES = PromptTemplateType.entries.map { it.label }
|
||||||
|
@ -101,11 +105,14 @@ const val FULL_PROMPT_SWITCH_KEY = "full_prompt"
|
||||||
fun PromptTemplatesPanel(
|
fun PromptTemplatesPanel(
|
||||||
model: Model,
|
model: Model,
|
||||||
viewModel: LlmSingleTurnViewModel,
|
viewModel: LlmSingleTurnViewModel,
|
||||||
|
modelManagerViewModel: ModelManagerViewModel,
|
||||||
onSend: (fullPrompt: String) -> Unit,
|
onSend: (fullPrompt: String) -> Unit,
|
||||||
|
onStopButtonClicked: (Model) -> Unit,
|
||||||
modifier: Modifier = Modifier
|
modifier: Modifier = Modifier
|
||||||
) {
|
) {
|
||||||
val scope = rememberCoroutineScope()
|
val scope = rememberCoroutineScope()
|
||||||
val uiState by viewModel.uiState.collectAsState()
|
val uiState by viewModel.uiState.collectAsState()
|
||||||
|
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||||
val selectedPromptTemplateType = uiState.selectedPromptTemplateType
|
val selectedPromptTemplateType = uiState.selectedPromptTemplateType
|
||||||
val inProgress = uiState.inProgress
|
val inProgress = uiState.inProgress
|
||||||
var selectedTabIndex by remember { mutableIntStateOf(0) }
|
var selectedTabIndex by remember { mutableIntStateOf(0) }
|
||||||
|
@ -123,6 +130,8 @@ fun PromptTemplatesPanel(
|
||||||
val focusManager = LocalFocusManager.current
|
val focusManager = LocalFocusManager.current
|
||||||
val interactionSource = remember { MutableInteractionSource() }
|
val interactionSource = remember { MutableInteractionSource() }
|
||||||
val expandedStates = remember { mutableStateMapOf<String, Boolean>() }
|
val expandedStates = remember { mutableStateMapOf<String, Boolean>() }
|
||||||
|
val modelInitializationStatus =
|
||||||
|
modelManagerUiState.modelInitializationStatus[model.name]
|
||||||
|
|
||||||
// Update input editor values when prompt template changes.
|
// Update input editor values when prompt template changes.
|
||||||
LaunchedEffect(selectedPromptTemplateType) {
|
LaunchedEffect(selectedPromptTemplateType) {
|
||||||
|
@ -328,29 +337,49 @@ fun PromptTemplatesPanel(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send button
|
val modelInitializing =
|
||||||
OutlinedIconButton(
|
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING
|
||||||
enabled = !inProgress && curTextInputContent.isNotEmpty(),
|
if (inProgress && !modelInitializing) {
|
||||||
onClick = {
|
IconButton(
|
||||||
focusManager.clearFocus()
|
onClick = {
|
||||||
onSend(fullPrompt.text)
|
onStopButtonClicked(model)
|
||||||
},
|
},
|
||||||
colors = IconButtonDefaults.iconButtonColors(
|
colors = IconButtonDefaults.iconButtonColors(
|
||||||
containerColor = MaterialTheme.colorScheme.secondaryContainer,
|
containerColor = MaterialTheme.colorScheme.secondaryContainer,
|
||||||
disabledContainerColor = MaterialTheme.colorScheme.secondaryContainer.copy(alpha = 0.3f),
|
),
|
||||||
contentColor = MaterialTheme.colorScheme.primary,
|
modifier = Modifier.size(ICON_BUTTON_SIZE)
|
||||||
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.1f),
|
) {
|
||||||
),
|
Icon(
|
||||||
border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface),
|
Icons.Rounded.Stop,
|
||||||
modifier = Modifier.size(ICON_BUTTON_SIZE)
|
contentDescription = "",
|
||||||
) {
|
tint = MaterialTheme.colorScheme.primary
|
||||||
Icon(
|
)
|
||||||
Icons.AutoMirrored.Rounded.Send,
|
}
|
||||||
contentDescription = "",
|
} else {
|
||||||
modifier = Modifier
|
// Send button
|
||||||
.size(20.dp)
|
OutlinedIconButton(
|
||||||
.offset(x = 2.dp),
|
enabled = !inProgress && curTextInputContent.isNotEmpty(),
|
||||||
)
|
onClick = {
|
||||||
|
focusManager.clearFocus()
|
||||||
|
onSend(fullPrompt.text)
|
||||||
|
},
|
||||||
|
colors = IconButtonDefaults.iconButtonColors(
|
||||||
|
containerColor = MaterialTheme.colorScheme.secondaryContainer,
|
||||||
|
disabledContainerColor = MaterialTheme.colorScheme.secondaryContainer.copy(alpha = 0.3f),
|
||||||
|
contentColor = MaterialTheme.colorScheme.primary,
|
||||||
|
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.1f),
|
||||||
|
),
|
||||||
|
border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface),
|
||||||
|
modifier = Modifier.size(ICON_BUTTON_SIZE)
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
Icons.AutoMirrored.Rounded.Send,
|
||||||
|
contentDescription = "",
|
||||||
|
modifier = Modifier
|
||||||
|
.size(20.dp)
|
||||||
|
.offset(x = 2.dp),
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,10 @@ import androidx.compose.foundation.layout.fillMaxSize
|
||||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||||
import androidx.compose.material3.Scaffold
|
import androidx.compose.material3.Scaffold
|
||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
|
import androidx.compose.runtime.LaunchedEffect
|
||||||
|
import androidx.compose.runtime.derivedStateOf
|
||||||
|
import androidx.compose.runtime.getValue
|
||||||
|
import androidx.compose.runtime.remember
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.platform.LocalContext
|
import androidx.compose.ui.platform.LocalContext
|
||||||
import androidx.compose.ui.tooling.preview.Preview
|
import androidx.compose.ui.tooling.preview.Preview
|
||||||
|
@ -48,6 +52,24 @@ fun ModelManager(
|
||||||
if (task.models.size != 1) {
|
if (task.models.size != 1) {
|
||||||
title += "s"
|
title += "s"
|
||||||
}
|
}
|
||||||
|
// Model count.
|
||||||
|
val modelCount by remember {
|
||||||
|
derivedStateOf {
|
||||||
|
val trigger = task.updateTrigger.value
|
||||||
|
if (trigger >= 0) {
|
||||||
|
task.models.size
|
||||||
|
} else {
|
||||||
|
-1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Navigate up when there are no models left.
|
||||||
|
LaunchedEffect(modelCount) {
|
||||||
|
if (modelCount == 0) {
|
||||||
|
navigateUp()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Handle system's edge swipe.
|
// Handle system's edge swipe.
|
||||||
BackHandler {
|
BackHandler {
|
||||||
|
|
|
@ -38,6 +38,7 @@ import com.google.aiedge.gallery.data.ModelDownloadStatus
|
||||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||||
import com.google.aiedge.gallery.data.TASKS
|
import com.google.aiedge.gallery.data.TASKS
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||||
|
import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
|
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.data.TaskType
|
import com.google.aiedge.gallery.data.TaskType
|
||||||
|
@ -58,6 +59,7 @@ import kotlinx.coroutines.flow.asStateFlow
|
||||||
import kotlinx.coroutines.flow.update
|
import kotlinx.coroutines.flow.update
|
||||||
import kotlinx.coroutines.launch
|
import kotlinx.coroutines.launch
|
||||||
import kotlinx.coroutines.withContext
|
import kotlinx.coroutines.withContext
|
||||||
|
import kotlinx.serialization.ExperimentalSerializationApi
|
||||||
import kotlinx.serialization.json.Json
|
import kotlinx.serialization.json.Json
|
||||||
import net.openid.appauth.AuthorizationException
|
import net.openid.appauth.AuthorizationException
|
||||||
import net.openid.appauth.AuthorizationRequest
|
import net.openid.appauth.AuthorizationRequest
|
||||||
|
@ -229,7 +231,10 @@ open class ModelManagerViewModel(
|
||||||
}
|
}
|
||||||
dataStoreRepository.saveImportedModels(importedModels = importedModels)
|
dataStoreRepository.saveImportedModels(importedModels = importedModels)
|
||||||
}
|
}
|
||||||
val newUiState = uiState.value.copy(modelDownloadStatus = curModelDownloadStatus)
|
val newUiState = uiState.value.copy(
|
||||||
|
modelDownloadStatus = curModelDownloadStatus,
|
||||||
|
tasks = uiState.value.tasks.toList()
|
||||||
|
)
|
||||||
_uiState.update { newUiState }
|
_uiState.update { newUiState }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -312,6 +317,12 @@ open class ModelManagerViewModel(
|
||||||
onDone = onDone,
|
onDone = onDone,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
TaskType.LLM_IMAGE_TO_TEXT -> LlmChatModelHelper.initialize(
|
||||||
|
context = context,
|
||||||
|
model = model,
|
||||||
|
onDone = onDone,
|
||||||
|
)
|
||||||
|
|
||||||
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.initialize(
|
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.initialize(
|
||||||
context = context, model = model, onDone = onDone
|
context = context, model = model, onDone = onDone
|
||||||
)
|
)
|
||||||
|
@ -331,6 +342,7 @@ open class ModelManagerViewModel(
|
||||||
TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.cleanUp(model = model)
|
TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.cleanUp(model = model)
|
||||||
TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model)
|
TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model)
|
||||||
TaskType.LLM_USECASES -> LlmChatModelHelper.cleanUp(model = model)
|
TaskType.LLM_USECASES -> LlmChatModelHelper.cleanUp(model = model)
|
||||||
|
TaskType.LLM_IMAGE_TO_TEXT -> LlmChatModelHelper.cleanUp(model = model)
|
||||||
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.cleanUp(model = model)
|
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.cleanUp(model = model)
|
||||||
TaskType.TEST_TASK_1 -> {}
|
TaskType.TEST_TASK_1 -> {}
|
||||||
TaskType.TEST_TASK_2 -> {}
|
TaskType.TEST_TASK_2 -> {}
|
||||||
|
@ -434,14 +446,16 @@ open class ModelManagerViewModel(
|
||||||
// Create model.
|
// Create model.
|
||||||
val model = createModelFromImportedModelInfo(info = info)
|
val model = createModelFromImportedModelInfo(info = info)
|
||||||
|
|
||||||
// Remove duplicated imported model if existed.
|
for (task in listOf(TASK_LLM_CHAT, TASK_LLM_USECASES, TASK_LLM_IMAGE_TO_TEXT)) {
|
||||||
for (task in listOf(TASK_LLM_CHAT, TASK_LLM_USECASES)) {
|
// Remove duplicated imported model if existed.
|
||||||
val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported }
|
val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported }
|
||||||
if (modelIndex >= 0) {
|
if (modelIndex >= 0) {
|
||||||
Log.d(TAG, "duplicated imported model found in task. Removing it first")
|
Log.d(TAG, "duplicated imported model found in task. Removing it first")
|
||||||
task.models.removeAt(modelIndex)
|
task.models.removeAt(modelIndex)
|
||||||
}
|
}
|
||||||
task.models.add(model)
|
if (task == TASK_LLM_IMAGE_TO_TEXT && model.llmSupportImage || task != TASK_LLM_IMAGE_TO_TEXT) {
|
||||||
|
task.models.add(model)
|
||||||
|
}
|
||||||
task.updateTrigger.value = System.currentTimeMillis()
|
task.updateTrigger.value = System.currentTimeMillis()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -632,8 +646,7 @@ open class ModelManagerViewModel(
|
||||||
fun loadModelAllowlist() {
|
fun loadModelAllowlist() {
|
||||||
_uiState.update {
|
_uiState.update {
|
||||||
uiState.value.copy(
|
uiState.value.copy(
|
||||||
loadingModelAllowlist = true,
|
loadingModelAllowlist = true, loadingModelAllowlistError = ""
|
||||||
loadingModelAllowlistError = ""
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -663,6 +676,9 @@ open class ModelManagerViewModel(
|
||||||
if (allowedModel.taskTypes.contains(TASK_LLM_USECASES.type.id)) {
|
if (allowedModel.taskTypes.contains(TASK_LLM_USECASES.type.id)) {
|
||||||
TASK_LLM_USECASES.models.add(model)
|
TASK_LLM_USECASES.models.add(model)
|
||||||
}
|
}
|
||||||
|
if (allowedModel.taskTypes.contains(TASK_LLM_IMAGE_TO_TEXT.type.id)) {
|
||||||
|
TASK_LLM_IMAGE_TO_TEXT.models.add(model)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pre-process all tasks.
|
// Pre-process all tasks.
|
||||||
|
@ -717,6 +733,9 @@ open class ModelManagerViewModel(
|
||||||
// Add to task.
|
// Add to task.
|
||||||
TASK_LLM_CHAT.models.add(model)
|
TASK_LLM_CHAT.models.add(model)
|
||||||
TASK_LLM_USECASES.models.add(model)
|
TASK_LLM_USECASES.models.add(model)
|
||||||
|
if (model.llmSupportImage) {
|
||||||
|
TASK_LLM_IMAGE_TO_TEXT.models.add(model)
|
||||||
|
}
|
||||||
|
|
||||||
// Update status.
|
// Update status.
|
||||||
modelDownloadStatus[model.name] = ModelDownloadStatus(
|
modelDownloadStatus[model.name] = ModelDownloadStatus(
|
||||||
|
@ -731,7 +750,7 @@ open class ModelManagerViewModel(
|
||||||
|
|
||||||
Log.d(TAG, "model download status: $modelDownloadStatus")
|
Log.d(TAG, "model download status: $modelDownloadStatus")
|
||||||
return ModelManagerUiState(
|
return ModelManagerUiState(
|
||||||
tasks = TASKS,
|
tasks = TASKS.toList(),
|
||||||
modelDownloadStatus = modelDownloadStatus,
|
modelDownloadStatus = modelDownloadStatus,
|
||||||
modelInitializationStatus = modelInstances,
|
modelInitializationStatus = modelInstances,
|
||||||
textInputHistory = textInputHistory,
|
textInputHistory = textInputHistory,
|
||||||
|
@ -763,6 +782,9 @@ open class ModelManagerViewModel(
|
||||||
) as Float,
|
) as Float,
|
||||||
accelerators = accelerators,
|
accelerators = accelerators,
|
||||||
)
|
)
|
||||||
|
val llmSupportImage = convertValueToTargetType(
|
||||||
|
info.defaultValues[ConfigKey.SUPPORT_IMAGE.label] ?: false, ValueType.BOOLEAN
|
||||||
|
) as Boolean
|
||||||
val model = Model(
|
val model = Model(
|
||||||
name = info.fileName,
|
name = info.fileName,
|
||||||
url = "",
|
url = "",
|
||||||
|
@ -770,7 +792,9 @@ open class ModelManagerViewModel(
|
||||||
sizeInBytes = info.fileSize,
|
sizeInBytes = info.fileSize,
|
||||||
downloadFileName = "$IMPORTS_DIR/${info.fileName}",
|
downloadFileName = "$IMPORTS_DIR/${info.fileName}",
|
||||||
showBenchmarkButton = false,
|
showBenchmarkButton = false,
|
||||||
|
showRunAgainButton = false,
|
||||||
imported = true,
|
imported = true,
|
||||||
|
llmSupportImage = llmSupportImage,
|
||||||
)
|
)
|
||||||
model.preProcess()
|
model.preProcess()
|
||||||
|
|
||||||
|
@ -803,6 +827,7 @@ open class ModelManagerViewModel(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@OptIn(ExperimentalSerializationApi::class)
|
||||||
private inline fun <reified T> getJsonResponse(url: String): T? {
|
private inline fun <reified T> getJsonResponse(url: String): T? {
|
||||||
try {
|
try {
|
||||||
val connection = URL(url).openConnection() as HttpURLConnection
|
val connection = URL(url).openConnection() as HttpURLConnection
|
||||||
|
@ -815,7 +840,12 @@ open class ModelManagerViewModel(
|
||||||
val response = inputStream.bufferedReader().use { it.readText() }
|
val response = inputStream.bufferedReader().use { it.readText() }
|
||||||
|
|
||||||
// Parse JSON using kotlinx.serialization
|
// Parse JSON using kotlinx.serialization
|
||||||
val json = Json { ignoreUnknownKeys = true } // Handle potential extra fields
|
val json = Json {
|
||||||
|
// Handle potential extra fields
|
||||||
|
ignoreUnknownKeys = true
|
||||||
|
allowComments = true
|
||||||
|
allowTrailingComma = true
|
||||||
|
}
|
||||||
val jsonObj = json.decodeFromString<T>(response)
|
val jsonObj = json.decodeFromString<T>(response)
|
||||||
return jsonObj
|
return jsonObj
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -46,6 +46,7 @@ import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION
|
import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION
|
||||||
import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION
|
import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||||
|
import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
|
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
|
||||||
import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION
|
import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
|
@ -59,6 +60,8 @@ import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationDestination
|
||||||
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationScreen
|
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationScreen
|
||||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatDestination
|
import com.google.aiedge.gallery.ui.llmchat.LlmChatDestination
|
||||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatScreen
|
import com.google.aiedge.gallery.ui.llmchat.LlmChatScreen
|
||||||
|
import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextDestination
|
||||||
|
import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextScreen
|
||||||
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnDestination
|
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnDestination
|
||||||
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnScreen
|
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnScreen
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManager
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManager
|
||||||
|
@ -129,8 +132,7 @@ fun GalleryNavHost(
|
||||||
) {
|
) {
|
||||||
val curPickedTask = pickedTask
|
val curPickedTask = pickedTask
|
||||||
if (curPickedTask != null) {
|
if (curPickedTask != null) {
|
||||||
ModelManager(
|
ModelManager(viewModel = modelManagerViewModel,
|
||||||
viewModel = modelManagerViewModel,
|
|
||||||
task = curPickedTask,
|
task = curPickedTask,
|
||||||
onModelClicked = { model ->
|
onModelClicked = { model ->
|
||||||
navigateToTaskScreen(
|
navigateToTaskScreen(
|
||||||
|
@ -207,7 +209,7 @@ fun GalleryNavHost(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// LLMm chat demos.
|
// LLM chat demos.
|
||||||
composable(
|
composable(
|
||||||
route = "${LlmChatDestination.route}/{modelName}",
|
route = "${LlmChatDestination.route}/{modelName}",
|
||||||
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
|
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
|
||||||
|
@ -224,7 +226,7 @@ fun GalleryNavHost(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// LLMm single turn.
|
// LLM single turn.
|
||||||
composable(
|
composable(
|
||||||
route = "${LlmSingleTurnDestination.route}/{modelName}",
|
route = "${LlmSingleTurnDestination.route}/{modelName}",
|
||||||
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
|
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
|
||||||
|
@ -241,6 +243,22 @@ fun GalleryNavHost(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LLM image to text.
|
||||||
|
composable(
|
||||||
|
route = "${LlmImageToTextDestination.route}/{modelName}",
|
||||||
|
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
|
||||||
|
enterTransition = { slideEnter() },
|
||||||
|
exitTransition = { slideExit() },
|
||||||
|
) {
|
||||||
|
getModelFromNavigationParam(it, TASK_LLM_IMAGE_TO_TEXT)?.let { defaultModel ->
|
||||||
|
modelManagerViewModel.selectModel(defaultModel)
|
||||||
|
|
||||||
|
LlmImageToTextScreen(
|
||||||
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
|
navigateUp = { navController.navigateUp() },
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle incoming intents for deep links
|
// Handle incoming intents for deep links
|
||||||
|
@ -254,9 +272,7 @@ fun GalleryNavHost(
|
||||||
getModelByName(modelName)?.let { model ->
|
getModelByName(modelName)?.let { model ->
|
||||||
// TODO(jingjin): need to show a list of possible tasks for this model.
|
// TODO(jingjin): need to show a list of possible tasks for this model.
|
||||||
navigateToTaskScreen(
|
navigateToTaskScreen(
|
||||||
navController = navController,
|
navController = navController, taskType = TaskType.LLM_CHAT, model = model
|
||||||
taskType = TaskType.LLM_CHAT,
|
|
||||||
model = model
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -271,6 +287,7 @@ fun navigateToTaskScreen(
|
||||||
TaskType.TEXT_CLASSIFICATION -> navController.navigate("${TextClassificationDestination.route}/${modelName}")
|
TaskType.TEXT_CLASSIFICATION -> navController.navigate("${TextClassificationDestination.route}/${modelName}")
|
||||||
TaskType.IMAGE_CLASSIFICATION -> navController.navigate("${ImageClassificationDestination.route}/${modelName}")
|
TaskType.IMAGE_CLASSIFICATION -> navController.navigate("${ImageClassificationDestination.route}/${modelName}")
|
||||||
TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
|
TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
|
||||||
|
TaskType.LLM_IMAGE_TO_TEXT -> navController.navigate("${LlmImageToTextDestination.route}/${modelName}")
|
||||||
TaskType.LLM_USECASES -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
|
TaskType.LLM_USECASES -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
|
||||||
TaskType.IMAGE_GENERATION -> navController.navigate("${ImageGenerationDestination.route}/${modelName}")
|
TaskType.IMAGE_GENERATION -> navController.navigate("${ImageGenerationDestination.route}/${modelName}")
|
||||||
TaskType.TEST_TASK_1 -> {}
|
TaskType.TEST_TASK_1 -> {}
|
||||||
|
|
|
@ -44,7 +44,8 @@ fun TextClassificationScreen(
|
||||||
task = viewModel.task,
|
task = viewModel.task,
|
||||||
viewModel = viewModel,
|
viewModel = viewModel,
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
onSendMessage = { model, message ->
|
onSendMessage = { model, messages ->
|
||||||
|
val message = messages[0]
|
||||||
viewModel.addMessage(
|
viewModel.addMessage(
|
||||||
model = model,
|
model = model,
|
||||||
message = message,
|
message = message,
|
||||||
|
|
|
@ -7,7 +7,7 @@ junitVersion = "1.2.1"
|
||||||
espressoCore = "3.6.1"
|
espressoCore = "3.6.1"
|
||||||
lifecycleRuntimeKtx = "2.8.7"
|
lifecycleRuntimeKtx = "2.8.7"
|
||||||
activityCompose = "1.10.1"
|
activityCompose = "1.10.1"
|
||||||
composeBom = "2025.03.01"
|
composeBom = "2025.05.00"
|
||||||
navigation = "2.8.9"
|
navigation = "2.8.9"
|
||||||
serializationPlugin = "2.0.21"
|
serializationPlugin = "2.0.21"
|
||||||
serializationJson = "1.7.3"
|
serializationJson = "1.7.3"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue