Various bug fixes.

This commit is contained in:
Jing Jin 2025-05-17 15:11:52 -07:00
parent 0c49efc054
commit 37a58d1a41
35 changed files with 1517 additions and 995 deletions

View file

@ -27,10 +27,10 @@ android {
defaultConfig { defaultConfig {
applicationId = "com.google.aiedge.gallery" applicationId = "com.google.aiedge.gallery"
minSdk = 24 minSdk = 26
targetSdk = 35 targetSdk = 35
versionCode = 1 versionCode = 1
versionName = "20250428" versionName = "0.9.0"
// Needed for HuggingFace auth workflows. // Needed for HuggingFace auth workflows.
manifestPlaceholders["appAuthRedirectScheme"] = "com.google.aiedge.gallery.oauth" manifestPlaceholders["appAuthRedirectScheme"] = "com.google.aiedge.gallery.oauth"

View file

@ -41,7 +41,9 @@
android:name=".MainActivity" android:name=".MainActivity"
android:exported="true" android:exported="true"
android:theme="@style/Theme.Gallery.SplashScreen" android:theme="@style/Theme.Gallery.SplashScreen"
android:windowSoftInputMode="adjustResize"> android:screenOrientation="portrait"
android:windowSoftInputMode="adjustResize"
tools:ignore="DiscouragedApi,LockedOrientationActivity">
<!-- This is for putting the app into launcher --> <!-- This is for putting the app into launcher -->
<intent-filter> <intent-filter>
<action android:name="android.intent.action.MAIN" /> <action android:name="android.intent.action.MAIN" />

View file

@ -68,7 +68,6 @@ fun GalleryTopAppBar(
leftAction: AppBarAction? = null, leftAction: AppBarAction? = null,
rightAction: AppBarAction? = null, rightAction: AppBarAction? = null,
scrollBehavior: TopAppBarScrollBehavior? = null, scrollBehavior: TopAppBarScrollBehavior? = null,
loadingHfModels: Boolean = false,
subtitle: String = "", subtitle: String = "",
) { ) {
CenterAlignedTopAppBar( CenterAlignedTopAppBar(
@ -151,28 +150,6 @@ fun GalleryTopAppBar(
} }
} }
// Click an icon to open "download manager".
AppBarActionType.DOWNLOAD_MANAGER -> {
if (loadingHfModels) {
CircularProgressIndicator(
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
strokeWidth = 3.dp,
modifier = Modifier
.padding(end = 12.dp)
.size(20.dp)
)
}
// else {
// IconButton(onClick = rightAction.actionFn) {
// Icon(
// imageVector = Deployed_code,
// contentDescription = "",
// tint = MaterialTheme.colorScheme.primary
// )
// }
// }
}
AppBarActionType.MODEL_SELECTOR -> { AppBarActionType.MODEL_SELECTOR -> {
Text("ms") Text("ms")
} }

View file

@ -37,7 +37,7 @@ import javax.crypto.SecretKey
data class AccessTokenData( data class AccessTokenData(
val accessToken: String, val accessToken: String,
val refreshToken: String, val refreshToken: String,
val expiresAtSeconds: Long val expiresAtMs: Long
) )
interface DataStoreRepository { interface DataStoreRepository {
@ -46,6 +46,7 @@ interface DataStoreRepository {
fun saveThemeOverride(theme: String) fun saveThemeOverride(theme: String)
fun readThemeOverride(): String fun readThemeOverride(): String
fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long) fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long)
fun clearAccessTokenData()
fun readAccessTokenData(): AccessTokenData? fun readAccessTokenData(): AccessTokenData?
fun saveImportedModels(importedModels: List<ImportedModelInfo>) fun saveImportedModels(importedModels: List<ImportedModelInfo>)
fun readImportedModels(): List<ImportedModelInfo> fun readImportedModels(): List<ImportedModelInfo>
@ -135,6 +136,18 @@ class DefaultDataStoreRepository(
} }
} }
override fun clearAccessTokenData() {
return runBlocking {
dataStore.edit { preferences ->
preferences.remove(PreferencesKeys.ENCRYPTED_ACCESS_TOKEN)
preferences.remove(PreferencesKeys.ACCESS_TOKEN_IV)
preferences.remove(PreferencesKeys.ENCRYPTED_REFRESH_TOKEN)
preferences.remove(PreferencesKeys.REFRESH_TOKEN_IV)
preferences.remove(PreferencesKeys.ACCESS_TOKEN_EXPIRES_AT)
}
}
}
override fun readAccessTokenData(): AccessTokenData? { override fun readAccessTokenData(): AccessTokenData? {
return runBlocking { return runBlocking {
val preferences = dataStore.data.first() val preferences = dataStore.data.first()

View file

@ -43,7 +43,7 @@ data class AllowedModel(
// Config. // Config.
val isLlmModel = val isLlmModel =
taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_USECASES.type.id) taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_PROMPT_LAB.type.id)
var configs: List<Config> = listOf() var configs: List<Config> = listOf()
if (isLlmModel) { if (isLlmModel) {
var defaultTopK: Int = DEFAULT_TOPK var defaultTopK: Int = DEFAULT_TOPK

View file

@ -32,9 +32,9 @@ enum class TaskType(val label: String, val id: String) {
TEXT_CLASSIFICATION(label = "Text Classification", id = "text_classification"), TEXT_CLASSIFICATION(label = "Text Classification", id = "text_classification"),
IMAGE_CLASSIFICATION(label = "Image Classification", id = "image_classification"), IMAGE_CLASSIFICATION(label = "Image Classification", id = "image_classification"),
IMAGE_GENERATION(label = "Image Generation", id = "image_generation"), IMAGE_GENERATION(label = "Image Generation", id = "image_generation"),
LLM_CHAT(label = "LLM Chat", id = "llm_chat"), LLM_CHAT(label = "AI Chat", id = "llm_chat"),
LLM_USECASES(label = "LLM Use Cases", id = "llm_usecases"), LLM_PROMPT_LAB(label = "Prompt Lab", id = "llm_prompt_lab"),
LLM_IMAGE_TO_TEXT(label = "LLM Image to Text", id = "llm_image_to_text"), LLM_ASK_IMAGE(label = "Ask Image", id = "llm_ask_image"),
TEST_TASK_1(label = "Test task 1", id = "test_task_1"), TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
TEST_TASK_2(label = "Test task 2", id = "test_task_2") TEST_TASK_2(label = "Test task 2", id = "test_task_2")
@ -93,7 +93,6 @@ val TASK_IMAGE_CLASSIFICATION = Task(
val TASK_LLM_CHAT = Task( val TASK_LLM_CHAT = Task(
type = TaskType.LLM_CHAT, type = TaskType.LLM_CHAT,
icon = Icons.Outlined.Forum, icon = Icons.Outlined.Forum,
// models = MODELS_LLM,
models = mutableListOf(), models = mutableListOf(),
description = "Chat with on-device large language models", description = "Chat with on-device large language models",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
@ -101,10 +100,9 @@ val TASK_LLM_CHAT = Task(
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
) )
val TASK_LLM_USECASES = Task( val TASK_LLM_PROMPT_LAB = Task(
type = TaskType.LLM_USECASES, type = TaskType.LLM_PROMPT_LAB,
icon = Icons.Outlined.Widgets, icon = Icons.Outlined.Widgets,
// models = MODELS_LLM,
models = mutableListOf(), models = mutableListOf(),
description = "Single turn use cases with on-device large language model", description = "Single turn use cases with on-device large language model",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
@ -112,10 +110,9 @@ val TASK_LLM_USECASES = Task(
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
) )
val TASK_LLM_IMAGE_TO_TEXT = Task( val TASK_LLM_ASK_IMAGE = Task(
type = TaskType.LLM_IMAGE_TO_TEXT, type = TaskType.LLM_ASK_IMAGE,
icon = Icons.Outlined.Mms, icon = Icons.Outlined.Mms,
// models = MODELS_LLM,
models = mutableListOf(), models = mutableListOf(),
description = "Ask questions about images with on-device large language models", description = "Ask questions about images with on-device large language models",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
@ -135,12 +132,9 @@ val TASK_IMAGE_GENERATION = Task(
/** All tasks. */ /** All tasks. */
val TASKS: List<Task> = listOf( val TASKS: List<Task> = listOf(
// TASK_TEXT_CLASSIFICATION, TASK_LLM_ASK_IMAGE,
// TASK_IMAGE_CLASSIFICATION, TASK_LLM_PROMPT_LAB,
// TASK_IMAGE_GENERATION,
TASK_LLM_USECASES,
TASK_LLM_CHAT, TASK_LLM_CHAT,
TASK_LLM_IMAGE_TO_TEXT
) )
fun getModelByName(name: String): Model? { fun getModelByName(name: String): Model? {

View file

@ -25,7 +25,7 @@ import com.google.aiedge.gallery.GalleryApplication
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationViewModel import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationViewModel
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationViewModel import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationViewModel
import com.google.aiedge.gallery.ui.llmchat.LlmChatViewModel import com.google.aiedge.gallery.ui.llmchat.LlmChatViewModel
import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextViewModel import com.google.aiedge.gallery.ui.llmchat.LlmAskImageViewModel
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.textclassification.TextClassificationViewModel import com.google.aiedge.gallery.ui.textclassification.TextClassificationViewModel
@ -63,9 +63,9 @@ object ViewModelProvider {
LlmSingleTurnViewModel() LlmSingleTurnViewModel()
} }
// Initializer for LlmImageToTextViewModel. // Initializer for LlmAskImageViewModel.
initializer { initializer {
LlmImageToTextViewModel() LlmAskImageViewModel()
} }
// Initializer for ImageGenerationViewModel. // Initializer for ImageGenerationViewModel.

View file

@ -26,6 +26,8 @@ import androidx.browser.customtabs.CustomTabsIntent
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.wrapContentHeight import androidx.compose.foundation.layout.wrapContentHeight
import androidx.compose.foundation.text.BasicText
import androidx.compose.foundation.text.TextAutoSize
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.ArrowForward import androidx.compose.material.icons.automirrored.rounded.ArrowForward
import androidx.compose.material.icons.rounded.Error import androidx.compose.material.icons.rounded.Error
@ -48,6 +50,7 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
@ -291,11 +294,32 @@ fun DownloadAndTryButton(
modifier = Modifier.padding(end = 4.dp) modifier = Modifier.padding(end = 4.dp)
) )
val textColor = MaterialTheme.colorScheme.onPrimary
if (checkingToken) { if (checkingToken) {
Text("Checking access...") BasicText(
text = "Checking access...",
maxLines = 1,
color = { textColor },
style = MaterialTheme.typography.bodyMedium,
autoSize = TextAutoSize.StepBased(
minFontSize = 8.sp,
maxFontSize = 14.sp,
stepSize = 1.sp
)
)
} else { } else {
if (needToDownloadFirst) { if (needToDownloadFirst) {
Text("Download & Try it", maxLines = 1) BasicText(
text = "Download & Try",
maxLines = 1,
color = { textColor },
style = MaterialTheme.typography.bodyMedium,
autoSize = TextAutoSize.StepBased(
minFontSize = 8.sp,
maxFontSize = 14.sp,
stepSize = 1.sp
)
)
} else { } else {
Text("Try it", maxLines = 1) Text("Try it", maxLines = 1)
} }

View file

@ -63,6 +63,8 @@ fun ModelPageAppBar(
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
onBackClicked: () -> Unit, onBackClicked: () -> Unit,
onModelSelected: (Model) -> Unit, onModelSelected: (Model) -> Unit,
inProgress: Boolean,
modelPreparing: Boolean,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
isResettingSession: Boolean = false, isResettingSession: Boolean = false,
onResetSessionClicked: (Model) -> Unit = {}, onResetSessionClicked: (Model) -> Unit = {},
@ -129,15 +131,16 @@ fun ModelPageAppBar(
val isModelInitializing = val isModelInitializing =
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING
if (showConfigButton) { if (showConfigButton) {
val enableConfigButton = !isModelInitializing && !inProgress
IconButton( IconButton(
onClick = { onClick = {
showConfigDialog = true showConfigDialog = true
}, },
enabled = !isModelInitializing, enabled = enableConfigButton,
modifier = Modifier modifier = Modifier
.scale(0.75f) .scale(0.75f)
.offset(x = configButtonOffset) .offset(x = configButtonOffset)
.alpha(if (isModelInitializing) 0.5f else 1f) .alpha(if (!enableConfigButton) 0.5f else 1f)
) { ) {
Icon( Icon(
imageVector = Icons.Rounded.Tune, imageVector = Icons.Rounded.Tune,
@ -154,14 +157,15 @@ fun ModelPageAppBar(
modifier = Modifier.size(16.dp) modifier = Modifier.size(16.dp)
) )
} else { } else {
val enableResetButton = !isModelInitializing && !modelPreparing
IconButton( IconButton(
onClick = { onClick = {
onResetSessionClicked(model) onResetSessionClicked(model)
}, },
enabled = !isModelInitializing, enabled = enableResetButton,
modifier = Modifier modifier = Modifier
.scale(0.75f) .scale(0.75f)
.alpha(if (isModelInitializing) 0.5f else 1f) .alpha(if (!enableResetButton) 0.5f else 1f)
) { ) {
Icon( Icon(
imageVector = Icons.Rounded.MapsUgc, imageVector = Icons.Rounded.MapsUgc,

View file

@ -83,8 +83,11 @@ fun ModelPickerChipsPager(
val scope = rememberCoroutineScope() val scope = rememberCoroutineScope()
val density = LocalDensity.current val density = LocalDensity.current
val windowInfo = LocalWindowInfo.current val windowInfo = LocalWindowInfo.current
val screenWidthDp = val screenWidthDp = remember {
remember { with(density) { windowInfo.containerSize.width.toDp() } } with(density) {
windowInfo.containerSize.width.toDp()
}
}
val pagerState = rememberPagerState(initialPage = task.models.indexOf(initialModel), val pagerState = rememberPagerState(initialPage = task.models.indexOf(initialModel),
pageCount = { task.models.size }) pageCount = { task.models.size })
@ -147,7 +150,7 @@ fun ModelPickerChipsPager(
} }
Text( Text(
model.name, model.name,
style = MaterialTheme.typography.labelSmall, style = MaterialTheme.typography.labelMedium,
modifier = Modifier modifier = Modifier
.padding(start = 4.dp) .padding(start = 4.dp)
.widthIn(0.dp, screenWidthDp - 250.dp), .widthIn(0.dp, screenWidthDp - 250.dp),

View file

@ -21,6 +21,7 @@ import android.content.Context
import android.content.pm.PackageManager import android.content.pm.PackageManager
import android.net.Uri import android.net.Uri
import android.os.Build import android.os.Build
import android.util.Log
import androidx.activity.compose.ManagedActivityResultLauncher import androidx.activity.compose.ManagedActivityResultLauncher
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
@ -39,7 +40,11 @@ import com.google.aiedge.gallery.ui.common.chat.Histogram
import com.google.aiedge.gallery.ui.common.chat.Stat import com.google.aiedge.gallery.ui.common.chat.Stat
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.theme.customColors import com.google.aiedge.gallery.ui.theme.customColors
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.json.Json
import java.io.File import java.io.File
import java.net.HttpURLConnection
import java.net.URL
import kotlin.math.abs import kotlin.math.abs
import kotlin.math.ln import kotlin.math.ln
import kotlin.math.max import kotlin.math.max
@ -488,3 +493,38 @@ fun processLlmResponse(response: String): String {
return newContent return newContent
} }
@OptIn(ExperimentalSerializationApi::class)
inline fun <reified T> getJsonResponse(url: String): T? {
try {
val connection = URL(url).openConnection() as HttpURLConnection
connection.requestMethod = "GET"
connection.connect()
val responseCode = connection.responseCode
if (responseCode == HttpURLConnection.HTTP_OK) {
val inputStream = connection.inputStream
val response = inputStream.bufferedReader().use { it.readText() }
// Parse JSON using kotlinx.serialization
val json = Json {
// Handle potential extra fields
ignoreUnknownKeys = true
allowComments = true
allowTrailingComma = true
}
val jsonObj = json.decodeFromString<T>(response)
return jsonObj
} else {
Log.e("AGUtils", "HTTP error: $responseCode")
}
} catch (e: Exception) {
Log.e(
"AGUtils",
"Error when getting json response: ${e.message}"
)
e.printStackTrace()
}
return null
}

View file

@ -16,9 +16,14 @@
package com.google.aiedge.gallery.ui.common.chat package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.animation.AnimatedContent
import androidx.compose.animation.ExperimentalSharedTransitionApi
import androidx.compose.animation.SharedTransitionLayout
import androidx.compose.foundation.Image
import androidx.compose.foundation.background import androidx.compose.foundation.background
import androidx.compose.foundation.clickable import androidx.compose.foundation.clickable
import androidx.compose.foundation.gestures.detectTapGestures import androidx.compose.foundation.gestures.detectTapGestures
import androidx.compose.foundation.gestures.detectTransformGestures
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
@ -28,6 +33,7 @@ import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.ime import androidx.compose.foundation.layout.ime
import androidx.compose.foundation.layout.imePadding import androidx.compose.foundation.layout.imePadding
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.width import androidx.compose.foundation.layout.width
@ -39,12 +45,15 @@ import androidx.compose.foundation.lazy.rememberLazyListState
import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.outlined.Timer import androidx.compose.material.icons.outlined.Timer
import androidx.compose.material.icons.rounded.Close
import androidx.compose.material.icons.rounded.ContentCopy import androidx.compose.material.icons.rounded.ContentCopy
import androidx.compose.material.icons.rounded.Refresh import androidx.compose.material.icons.rounded.Refresh
import androidx.compose.material3.Button import androidx.compose.material3.Button
import androidx.compose.material3.Card import androidx.compose.material3.Card
import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.IconButtonDefaults
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ModalBottomSheet import androidx.compose.material3.ModalBottomSheet
import androidx.compose.material3.SnackbarHost import androidx.compose.material3.SnackbarHost
@ -55,6 +64,7 @@ import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.MutableState import androidx.compose.runtime.MutableState
import androidx.compose.runtime.collectAsState import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableFloatStateOf
import androidx.compose.runtime.mutableIntStateOf import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember import androidx.compose.runtime.remember
@ -65,12 +75,16 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip import androidx.compose.ui.draw.clip
import androidx.compose.ui.geometry.Offset import androidx.compose.ui.geometry.Offset
import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.RectangleShape
import androidx.compose.ui.graphics.asImageBitmap import androidx.compose.ui.graphics.asImageBitmap
import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.hapticfeedback.HapticFeedbackType import androidx.compose.ui.hapticfeedback.HapticFeedbackType
import androidx.compose.ui.input.nestedscroll.NestedScrollConnection import androidx.compose.ui.input.nestedscroll.NestedScrollConnection
import androidx.compose.ui.input.nestedscroll.NestedScrollSource import androidx.compose.ui.input.nestedscroll.NestedScrollSource
import androidx.compose.ui.input.nestedscroll.nestedScroll import androidx.compose.ui.input.nestedscroll.nestedScroll
import androidx.compose.ui.input.pointer.pointerInput import androidx.compose.ui.input.pointer.pointerInput
import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.layout.onSizeChanged
import androidx.compose.ui.platform.LocalClipboardManager import androidx.compose.ui.platform.LocalClipboardManager
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalDensity import androidx.compose.ui.platform.LocalDensity
@ -80,6 +94,7 @@ import androidx.compose.ui.res.dimensionResource
import androidx.compose.ui.res.stringResource import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.AnnotatedString import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.IntSize
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog import androidx.compose.ui.window.Dialog
import com.google.aiedge.gallery.R import com.google.aiedge.gallery.R
@ -102,7 +117,7 @@ enum class ChatInputType {
/** /**
* Composable function for the main chat panel, displaying messages and handling user input. * Composable function for the main chat panel, displaying messages and handling user input.
*/ */
@OptIn(ExperimentalMaterial3Api::class) @OptIn(ExperimentalMaterial3Api::class, ExperimentalSharedTransitionApi::class)
@Composable @Composable
fun ChatPanel( fun ChatPanel(
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
@ -127,6 +142,7 @@ fun ChatPanel(
val snackbarHostState = remember { SnackbarHostState() } val snackbarHostState = remember { SnackbarHostState() }
val scope = rememberCoroutineScope() val scope = rememberCoroutineScope()
val haptic = LocalHapticFeedback.current val haptic = LocalHapticFeedback.current
var selectedImageMessage by remember { mutableStateOf<ChatMessageImage?>(null) }
var curMessage by remember { mutableStateOf("") } // Correct state var curMessage by remember { mutableStateOf("") } // Correct state
val focusManager = LocalFocusManager.current val focusManager = LocalFocusManager.current
@ -200,13 +216,14 @@ fun ChatPanel(
} }
} }
val modelInitializationStatus = val modelInitializationStatus = modelManagerUiState.modelInitializationStatus[selectedModel.name]
modelManagerUiState.modelInitializationStatus[selectedModel.name]
LaunchedEffect(modelInitializationStatus) { LaunchedEffect(modelInitializationStatus) {
showErrorDialog = modelInitializationStatus?.status == ModelInitializationStatusType.ERROR showErrorDialog = modelInitializationStatus?.status == ModelInitializationStatusType.ERROR
} }
SharedTransitionLayout(modifier = Modifier.fillMaxSize()) {
AnimatedContent(targetState = selectedImageMessage) { targetSelectedImageMessage ->
Column( Column(
modifier = modifier.imePadding() modifier = modifier.imePadding()
) { ) {
@ -297,8 +314,7 @@ fun ChatPanel(
) )
.background(backgroundColor) .background(backgroundColor)
if (message is ChatMessageText) { if (message is ChatMessageText) {
messageBubbleModifier = messageBubbleModifier messageBubbleModifier = messageBubbleModifier.pointerInput(Unit) {
.pointerInput(Unit) {
detectTapGestures( detectTapGestures(
onLongPress = { onLongPress = {
haptic.performHapticFeedback(HapticFeedbackType.LongPress) haptic.performHapticFeedback(HapticFeedbackType.LongPress)
@ -316,7 +332,21 @@ fun ChatPanel(
is ChatMessageText -> MessageBodyText(message = message) is ChatMessageText -> MessageBodyText(message = message)
// Image // Image
is ChatMessageImage -> MessageBodyImage(message = message) is ChatMessageImage -> {
if (targetSelectedImageMessage != message) {
MessageBodyImage(
message = message,
modifier = Modifier
.clickable {
selectedImageMessage = message
}
.sharedElement(
sharedContentState = rememberSharedContentState(key = "selected_image"),
animatedVisibilityScope = this@AnimatedContent
),
)
}
}
// Image with history (for image gen) // Image with history (for image gen)
is ChatMessageImageWithHistory -> MessageBodyImageWithHistory( is ChatMessageImageWithHistory -> MessageBodyImageWithHistory(
@ -325,8 +355,9 @@ fun ChatPanel(
// Classification result // Classification result
is ChatMessageClassification -> MessageBodyClassification( is ChatMessageClassification -> MessageBodyClassification(
message = message, message = message, modifier = Modifier.width(
modifier = Modifier.width(message.maxBarWidth ?: CLASSIFICATION_BAR_MAX_WIDTH) message.maxBarWidth ?: CLASSIFICATION_BAR_MAX_WIDTH
)
) )
// Benchmark result. // Benchmark result.
@ -334,8 +365,7 @@ fun ChatPanel(
// Benchmark LLM result. // Benchmark LLM result.
is ChatMessageBenchmarkLlmResult -> MessageBodyBenchmarkLlm( is ChatMessageBenchmarkLlmResult -> MessageBodyBenchmarkLlm(
message = message, message = message, modifier = Modifier.wrapContentWidth()
modifier = Modifier.wrapContentWidth()
) )
else -> {} else -> {}
@ -348,7 +378,7 @@ fun ChatPanel(
) { ) {
LatencyText(message = message) LatencyText(message = message)
// A button to show stats for the LLM message. // A button to show stats for the LLM message.
if (task.type == TaskType.LLM_CHAT && message is ChatMessageText if ((task.type == TaskType.LLM_CHAT || task.type == TaskType.LLM_ASK_IMAGE) && message is ChatMessageText
// This means we only want to show the action button when the message is done // This means we only want to show the action button when the message is done
// generating, at which point the latency will be set. // generating, at which point the latency will be set.
&& message.latencyMs >= 0 && message.latencyMs >= 0
@ -363,7 +393,10 @@ fun ChatPanel(
viewModel.toggleShowingStats(selectedModel, message) viewModel.toggleShowingStats(selectedModel, message)
// Add the stats message after the LLM message. // Add the stats message after the LLM message.
if (viewModel.isShowingStats(model = selectedModel, message = message)) { if (viewModel.isShowingStats(
model = selectedModel, message = message
)
) {
val llmBenchmarkResult = message.llmBenchmarkResult val llmBenchmarkResult = message.llmBenchmarkResult
if (llmBenchmarkResult != null) { if (llmBenchmarkResult != null) {
viewModel.insertMessageAfter( viewModel.insertMessageAfter(
@ -376,10 +409,12 @@ fun ChatPanel(
// Remove the stats message. // Remove the stats message.
else { else {
val curMessageIndex = val curMessageIndex =
viewModel.getMessageIndex(model = selectedModel, message = message) viewModel.getMessageIndex(
viewModel.removeMessageAt(
model = selectedModel, model = selectedModel,
index = curMessageIndex + 1 message = message
)
viewModel.removeMessageAt(
model = selectedModel, index = curMessageIndex + 1
) )
} }
}, },
@ -425,8 +460,23 @@ fun ChatPanel(
} }
SnackbarHost(hostState = snackbarHostState, modifier = Modifier.padding(vertical = 4.dp)) SnackbarHost(hostState = snackbarHostState, modifier = Modifier.padding(vertical = 4.dp))
}
// Show an info message for ask image task to get users started.
if (task.type == TaskType.LLM_ASK_IMAGE && messages.isEmpty()) {
Column(
modifier = Modifier
.padding(horizontal = 16.dp)
.fillMaxSize(),
horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.Center
) {
MessageBodyInfo(
ChatMessageInfo(content = "To get started, click + below to add an image and type a prompt to ask a question about it."),
smallFontSize = false
)
}
}
}
// Chat input // Chat input
when (chatInputType) { when (chatInputType) {
@ -439,6 +489,7 @@ fun ChatPanel(
curMessage = curMessage, curMessage = curMessage,
inProgress = uiState.inProgress, inProgress = uiState.inProgress,
isResettingSession = uiState.isResettingSession, isResettingSession = uiState.isResettingSession,
modelPreparing = uiState.preparing,
hasImageMessage = hasImageMessage, hasImageMessage = hasImageMessage,
modelInitializing = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING, modelInitializing = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
textFieldPlaceHolderRes = task.textInputPlaceHolderRes, textFieldPlaceHolderRes = task.textInputPlaceHolderRes,
@ -459,7 +510,7 @@ fun ChatPanel(
onStopButtonClicked = onStopButtonClicked, onStopButtonClicked = onStopButtonClicked,
// showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen, // showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen,
showPromptTemplatesInMenu = false, showPromptTemplatesInMenu = false,
showImagePickerInMenu = selectedModel.llmSupportImage == true, showImagePickerInMenu = selectedModel.llmSupportImage,
showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress, showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress,
) )
} }
@ -488,6 +539,58 @@ fun ChatPanel(
} }
} }
// A full-screen image viewer.
if (targetSelectedImageMessage != null) {
ZoomableBox(
modifier = Modifier
.fillMaxSize()
.background(Color.Black.copy(alpha = 0.9f))
.sharedElement(
rememberSharedContentState(key = "bounds"),
animatedVisibilityScope = this,
)
.skipToLookaheadSize(),
) {
// Image.
Image(
bitmap = targetSelectedImageMessage.imageBitMap,
contentDescription = "",
modifier = modifier
.fillMaxSize()
.graphicsLayer(
scaleX = scale,
scaleY = scale,
translationX = offsetX,
translationY = offsetY
)
.sharedElement(
sharedContentState = rememberSharedContentState(key = "selected_image"),
animatedVisibilityScope = this@AnimatedContent
),
contentScale = ContentScale.Fit,
)
// Close button.
IconButton(
onClick = {
selectedImageMessage = null
},
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.surfaceVariant,
),
modifier = Modifier.offset(x = (-8).dp, y = 8.dp)
) {
Icon(
Icons.Rounded.Close,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary
)
}
}
}
}
}
// Error dialog. // Error dialog.
if (showErrorDialog) { if (showErrorDialog) {
Dialog( Dialog(
@ -498,9 +601,7 @@ fun ChatPanel(
) { ) {
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) { Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
Column( Column(
modifier = Modifier modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
.padding(20.dp),
verticalArrangement = Arrangement.spacedBy(16.dp)
) { ) {
// Title // Title
Text( Text(
@ -568,13 +669,10 @@ fun ChatPanel(
Row( Row(
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp), horizontalArrangement = Arrangement.spacedBy(6.dp),
modifier = Modifier modifier = Modifier.padding(vertical = 8.dp, horizontal = 16.dp)
.padding(vertical = 8.dp, horizontal = 16.dp)
) { ) {
Icon( Icon(
Icons.Rounded.ContentCopy, Icons.Rounded.ContentCopy, contentDescription = "", modifier = Modifier.size(18.dp)
contentDescription = "",
modifier = Modifier.size(18.dp)
) )
Text("Copy text") Text("Copy text")
} }
@ -586,6 +684,51 @@ fun ChatPanel(
} }
} }
@Composable
fun ZoomableBox(
modifier: Modifier = Modifier,
minScale: Float = 1f,
maxScale: Float = 5f,
content: @Composable ZoomableBoxScope.() -> Unit
) {
var scale by remember { mutableFloatStateOf(1f) }
var offsetX by remember { mutableFloatStateOf(0f) }
var offsetY by remember { mutableFloatStateOf(0f) }
var size by remember { mutableStateOf(IntSize.Zero) }
Box(
modifier = modifier
.clip(RectangleShape)
.onSizeChanged { size = it }
.pointerInput(Unit) {
detectTransformGestures { _, pan, zoom, _ ->
scale = maxOf(minScale, minOf(scale * zoom, maxScale))
val maxX = (size.width * (scale - 1)) / 2
val minX = -maxX
offsetX = maxOf(minX, minOf(maxX, offsetX + pan.x))
val maxY = (size.height * (scale - 1)) / 2
val minY = -maxY
offsetY = maxOf(minY, minOf(maxY, offsetY + pan.y))
}
},
contentAlignment = Alignment.TopEnd
) {
val scope = ZoomableBoxScopeImpl(scale, offsetX, offsetY)
scope.content()
}
}
interface ZoomableBoxScope {
val scale: Float
val offsetX: Float
val offsetY: Float
}
private data class ZoomableBoxScopeImpl(
override val scale: Float,
override val offsetX: Float,
override val offsetY: Float
) : ZoomableBoxScope
@Preview(showBackground = true) @Preview(showBackground = true)
@Composable @Composable
fun ChatPanelPreview() { fun ChatPanelPreview() {

View file

@ -19,6 +19,7 @@ package com.google.aiedge.gallery.ui.common.chat
import android.util.Log import android.util.Log
import androidx.activity.compose.BackHandler import androidx.activity.compose.BackHandler
import androidx.compose.foundation.background import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxSize
@ -27,6 +28,7 @@ import androidx.compose.foundation.pager.HorizontalPager
import androidx.compose.foundation.pager.rememberPagerState import androidx.compose.foundation.pager.rememberPagerState
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Scaffold import androidx.compose.material3.Scaffold
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState import androidx.compose.runtime.collectAsState
@ -36,10 +38,13 @@ import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue import androidx.compose.runtime.setValue
import androidx.compose.runtime.snapshotFlow import androidx.compose.runtime.snapshotFlow
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.graphicsLayer import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatusType import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.Task
@ -81,7 +86,7 @@ fun ChatView(
chatInputType: ChatInputType = ChatInputType.TEXT, chatInputType: ChatInputType = ChatInputType.TEXT,
showStopButtonInInputWhenInProgress: Boolean = false, showStopButtonInInputWhenInProgress: Boolean = false,
) { ) {
val uiStat by viewModel.uiState.collectAsState() val uiState by viewModel.uiState.collectAsState()
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val selectedModel = modelManagerUiState.selectedModel val selectedModel = modelManagerUiState.selectedModel
@ -158,7 +163,9 @@ fun ChatView(
model = selectedModel, model = selectedModel,
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,
showResetSessionButton = true, showResetSessionButton = true,
isResettingSession = uiStat.isResettingSession, isResettingSession = uiState.isResettingSession,
inProgress = uiState.inProgress,
modelPreparing = uiState.preparing,
onResetSessionClicked = onResetSessionClicked, onResetSessionClicked = onResetSessionClicked,
onConfigChanged = { old, new -> onConfigChanged = { old, new ->
viewModel.addConfigChangedMessage( viewModel.addConfigChangedMessage(

View file

@ -38,6 +38,11 @@ data class ChatUiState(
*/ */
val isResettingSession: Boolean = false, val isResettingSession: Boolean = false,
/**
* Indicates whether the model is preparing (before outputting any result and after initializing).
*/
val preparing: Boolean = false,
/** /**
* A map of model names to lists of chat messages. * A map of model names to lists of chat messages.
*/ */
@ -204,6 +209,10 @@ open class ChatViewModel(val task: Task) : ViewModel() {
_uiState.update { _uiState.value.copy(isResettingSession = isResettingSession) } _uiState.update { _uiState.value.copy(isResettingSession = isResettingSession) }
} }
fun setPreparing(preparing: Boolean) {
_uiState.update { _uiState.value.copy(preparing = preparing) }
}
fun addConfigChangedMessage( fun addConfigChangedMessage(
oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>, model: Model oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>, model: Model
) { ) {

View file

@ -18,6 +18,8 @@ package com.google.aiedge.gallery.ui.common.chat
import android.util.Log import android.util.Log
import androidx.compose.foundation.border import androidx.compose.foundation.border
import androidx.compose.foundation.clickable
import androidx.compose.foundation.interaction.MutableInteractionSource
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
@ -28,9 +30,11 @@ import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.offset import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.width import androidx.compose.foundation.layout.width
import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.foundation.text.BasicTextField import androidx.compose.foundation.text.BasicTextField
import androidx.compose.foundation.text.KeyboardOptions import androidx.compose.foundation.text.KeyboardOptions
import androidx.compose.foundation.verticalScroll
import androidx.compose.material3.Button import androidx.compose.material3.Button
import androidx.compose.material3.Card import androidx.compose.material3.Card
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
@ -53,6 +57,9 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.focus.FocusRequester import androidx.compose.ui.focus.FocusRequester
import androidx.compose.ui.focus.focusRequester import androidx.compose.ui.focus.focusRequester
import androidx.compose.ui.focus.onFocusChanged import androidx.compose.ui.focus.onFocusChanged
import androidx.compose.ui.graphics.SolidColor
import androidx.compose.ui.platform.LocalFocusManager
import androidx.compose.ui.text.TextStyle
import androidx.compose.ui.text.input.KeyboardType import androidx.compose.ui.text.input.KeyboardType
import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
@ -89,9 +96,20 @@ fun ConfigDialog(
putAll(initialValues) putAll(initialValues)
} }
} }
val interactionSource = remember { MutableInteractionSource() }
Dialog(onDismissRequest = onDismissed) { Dialog(onDismissRequest = onDismissed) {
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) { val focusManager = LocalFocusManager.current
Card(
modifier = Modifier
.fillMaxWidth()
.clickable(
interactionSource = interactionSource, indication = null // Disable the ripple effect
) {
focusManager.clearFocus()
},
shape = RoundedCornerShape(16.dp)
) {
Column( Column(
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp) modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
) { ) {
@ -114,7 +132,14 @@ fun ConfigDialog(
} }
// List of config rows. // List of config rows.
Column(
modifier = Modifier
.verticalScroll(rememberScrollState())
.weight(1f, fill = false),
verticalArrangement = Arrangement.spacedBy(16.dp)
) {
ConfigEditorsPanel(configs = configs, values = values) ConfigEditorsPanel(configs = configs, values = values)
}
// Button row. // Button row.
Row( Row(
@ -264,6 +289,8 @@ fun NumberSliderRow(config: NumberSliderConfig, values: SnapshotStateMap<String,
values[config.key.label] = NaN values[config.key.label] = NaN
} }
}, },
textStyle = TextStyle(color = MaterialTheme.colorScheme.onSurface),
cursorBrush = SolidColor(MaterialTheme.colorScheme.onSurface),
) { innerTextField -> ) { innerTextField ->
Box( Box(
modifier = Modifier.border( modifier = Modifier.border(

View file

@ -25,17 +25,18 @@ import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
@Composable @Composable
fun MessageBodyImage(message: ChatMessageImage) { fun MessageBodyImage(message: ChatMessageImage, modifier: Modifier = Modifier) {
val bitmapWidth = message.bitmap.width val bitmapWidth = message.bitmap.width
val bitmapHeight = message.bitmap.height val bitmapHeight = message.bitmap.height
val imageWidth = val imageWidth =
if (bitmapWidth >= bitmapHeight) 200 else (200f / bitmapHeight * bitmapWidth).toInt() if (bitmapWidth >= bitmapHeight) 200 else (200f / bitmapHeight * bitmapWidth).toInt()
val imageHeight = val imageHeight =
if (bitmapHeight >= bitmapWidth) 200 else (200f / bitmapWidth * bitmapHeight).toInt() if (bitmapHeight >= bitmapWidth) 200 else (200f / bitmapWidth * bitmapHeight).toInt()
Image( Image(
bitmap = message.imageBitMap, bitmap = message.imageBitMap,
contentDescription = "", contentDescription = "",
modifier = Modifier modifier = modifier
.height(imageHeight.dp) .height(imageHeight.dp)
.width(imageWidth.dp), .width(imageWidth.dp),
contentScale = ContentScale.Fit, contentScale = ContentScale.Fit,

View file

@ -38,7 +38,7 @@ import com.google.aiedge.gallery.ui.theme.customColors
* Supports markdown. * Supports markdown.
*/ */
@Composable @Composable
fun MessageBodyInfo(message: ChatMessageInfo) { fun MessageBodyInfo(message: ChatMessageInfo, smallFontSize: Boolean = true) {
Row( Row(
modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center
) { ) {
@ -47,7 +47,11 @@ fun MessageBodyInfo(message: ChatMessageInfo) {
.clip(RoundedCornerShape(16.dp)) .clip(RoundedCornerShape(16.dp))
.background(MaterialTheme.customColors.agentBubbleBgColor) .background(MaterialTheme.customColors.agentBubbleBgColor)
) { ) {
MarkdownText(text = message.content, modifier = Modifier.padding(12.dp), smallFontSize = true) MarkdownText(
text = message.content,
modifier = Modifier.padding(12.dp),
smallFontSize = smallFontSize
)
} }
} }
} }

View file

@ -103,6 +103,7 @@ fun MessageInputText(
@StringRes textFieldPlaceHolderRes: Int, @StringRes textFieldPlaceHolderRes: Int,
onValueChanged: (String) -> Unit, onValueChanged: (String) -> Unit,
onSendMessage: (List<ChatMessage>) -> Unit, onSendMessage: (List<ChatMessage>) -> Unit,
modelPreparing: Boolean = false,
onOpenPromptTemplatesClicked: () -> Unit = {}, onOpenPromptTemplatesClicked: () -> Unit = {},
onStopButtonClicked: () -> Unit = {}, onStopButtonClicked: () -> Unit = {},
showPromptTemplatesInMenu: Boolean = false, showPromptTemplatesInMenu: Boolean = false,
@ -173,7 +174,7 @@ fun MessageInputText(
.height(80.dp) .height(80.dp)
.shadow(2.dp, shape = RoundedCornerShape(8.dp)) .shadow(2.dp, shape = RoundedCornerShape(8.dp))
.clip(RoundedCornerShape(8.dp)) .clip(RoundedCornerShape(8.dp))
.border(1.dp, MaterialTheme.colorScheme.outlineVariant, RoundedCornerShape(8.dp)), .border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp)),
) )
Box(modifier = Modifier Box(modifier = Modifier
.offset(x = 10.dp, y = (-10).dp) .offset(x = 10.dp, y = (-10).dp)
@ -219,7 +220,7 @@ fun MessageInputText(
expanded = showAddContentMenu, expanded = showAddContentMenu,
onDismissRequest = { showAddContentMenu = false }) { onDismissRequest = { showAddContentMenu = false }) {
if (showImagePickerInMenu) { if (showImagePickerInMenu) {
// Take a photo. // Take a picture.
DropdownMenuItem( DropdownMenuItem(
text = { text = {
Row( Row(
@ -227,7 +228,7 @@ fun MessageInputText(
horizontalArrangement = Arrangement.spacedBy(6.dp) horizontalArrangement = Arrangement.spacedBy(6.dp)
) { ) {
Icon(Icons.Rounded.PhotoCamera, contentDescription = "") Icon(Icons.Rounded.PhotoCamera, contentDescription = "")
Text("Take a photo") Text("Take a picture")
} }
}, },
enabled = pickedImages.isEmpty() && !hasImageMessage, enabled = pickedImages.isEmpty() && !hasImageMessage,
@ -321,7 +322,7 @@ fun MessageInputText(
Spacer(modifier = Modifier.width(8.dp)) Spacer(modifier = Modifier.width(8.dp))
if (inProgress && showStopButtonWhenInProgress) { if (inProgress && showStopButtonWhenInProgress) {
if (!modelInitializing) { if (!modelInitializing && !modelPreparing) {
IconButton( IconButton(
onClick = onStopButtonClicked, onClick = onStopButtonClicked,
colors = IconButtonDefaults.iconButtonColors( colors = IconButtonDefaults.iconButtonColors(

View file

@ -20,8 +20,9 @@ import android.content.Intent
import android.net.Uri import android.net.Uri
import androidx.activity.compose.rememberLauncherForActivityResult import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.contract.ActivityResultContracts import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.animation.core.animateFloatAsState import androidx.compose.animation.AnimatedContent
import androidx.compose.animation.core.tween import androidx.compose.animation.ExperimentalSharedTransitionApi
import androidx.compose.animation.SharedTransitionLayout
import androidx.compose.foundation.background import androidx.compose.foundation.background
import androidx.compose.foundation.clickable import androidx.compose.foundation.clickable
import androidx.compose.foundation.interaction.MutableInteractionSource import androidx.compose.foundation.interaction.MutableInteractionSource
@ -30,17 +31,13 @@ import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.heightIn
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.ChevronRight import androidx.compose.material.icons.rounded.ChevronRight
import androidx.compose.material.icons.rounded.Settings
import androidx.compose.material.icons.rounded.UnfoldLess import androidx.compose.material.icons.rounded.UnfoldLess
import androidx.compose.material.icons.rounded.UnfoldMore import androidx.compose.material.icons.rounded.UnfoldMore
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.OutlinedButton import androidx.compose.material3.OutlinedButton
import androidx.compose.material3.Text import androidx.compose.material3.Text
import androidx.compose.material3.ripple import androidx.compose.material3.ripple
@ -48,16 +45,12 @@ import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.derivedStateOf import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.runtime.movableContentOf
import androidx.compose.runtime.movableContentWithReceiverOf
import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip import androidx.compose.ui.draw.clip
import androidx.compose.ui.layout.LookaheadScope
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.Dp import androidx.compose.ui.unit.Dp
@ -91,6 +84,7 @@ private val DEFAULT_VERTICAL_PADDING = 16.dp
* model description and buttons for learning more (opening a URL) and downloading/trying * model description and buttons for learning more (opening a URL) and downloading/trying
* the model. * the model.
*/ */
@OptIn(ExperimentalSharedTransitionApi::class)
@Composable @Composable
fun ModelItem( fun ModelItem(
model: Model, model: Model,
@ -117,37 +111,57 @@ fun ModelItem(
var isExpanded by remember { mutableStateOf(false) } var isExpanded by remember { mutableStateOf(false) }
// Animate alpha for model description and button rows when switching between layouts. var boxModifier = modifier
val alphaAnimation by animateFloatAsState( .fillMaxWidth()
targetValue = if (isExpanded) 1f else 0f, .clip(RoundedCornerShape(size = 42.dp))
animationSpec = tween(durationMillis = LAYOUT_ANIMATION_DURATION - 50) .background(
getTaskBgColor(task)
) )
boxModifier = if (canExpand) {
boxModifier.clickable(onClick = {
if (!model.imported) {
isExpanded = !isExpanded
} else {
onModelClicked(model)
}
}, interactionSource = remember { MutableInteractionSource() }, indication = ripple(
bounded = true,
radius = 1000.dp,
)
)
} else {
boxModifier
}
LookaheadScope { Box(
// Task icon. modifier = boxModifier,
val taskIcon = remember { contentAlignment = Alignment.Center,
movableContentOf { ) {
SharedTransitionLayout {
AnimatedContent(
isExpanded, label = "item_layout_transition",
) { targetState ->
val taskIcon = @Composable {
TaskIcon( TaskIcon(
task = task, modifier = Modifier.animateLayout() task = task, modifier = Modifier.sharedElement(
sharedContentState = rememberSharedContentState(key = "task_icon"),
animatedVisibilityScope = this@AnimatedContent,
)
) )
} }
}
// Model name and status. val modelNameAndStatus = @Composable {
val modelNameAndStatus = remember {
movableContentOf {
ModelNameAndStatus( ModelNameAndStatus(
model = model, model = model,
task = task, task = task,
downloadStatus = downloadStatus, downloadStatus = downloadStatus,
isExpanded = isExpanded, isExpanded = isExpanded,
modifier = Modifier.animateLayout() animatedVisibilityScope = this@AnimatedContent,
sharedTransitionScope = this@SharedTransitionLayout
) )
} }
}
val actionButton = remember { val actionButton = @Composable {
movableContentOf {
ModelItemActionButton( ModelItemActionButton(
context = context, context = context,
model = model, model = model,
@ -165,61 +179,48 @@ fun ModelItem(
}, },
showDeleteButton = showDeleteButton, showDeleteButton = showDeleteButton,
showDownloadButton = false, showDownloadButton = false,
modifier = Modifier.sharedElement(
sharedContentState = rememberSharedContentState(key = "action_button"),
animatedVisibilityScope = this@AnimatedContent,
)
) )
}
} }
// Expand/collapse icon, or the config icon. val expandButton = @Composable {
val expandButton = remember {
movableContentOf {
if (showConfigButtonIfExisted) {
if (downloadStatus?.status === ModelDownloadStatusType.SUCCEEDED) {
if (model.configs.isNotEmpty()) {
IconButton(onClick = onConfigClicked) {
Icon(
Icons.Rounded.Settings,
contentDescription = "",
tint = getTaskIconColor(task)
)
}
}
}
} else {
Icon( Icon(
// For imported model, show ">" directly indicating users can just tap the model item to // For imported model, show ">" directly indicating users can just tap the model item to
// go into it without needing to expand it first. // go into it without needing to expand it first.
if (model.imported) Icons.Rounded.ChevronRight else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore, if (model.imported) Icons.Rounded.ChevronRight else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore,
contentDescription = "", contentDescription = "",
tint = getTaskIconColor(task), tint = getTaskIconColor(task),
modifier = Modifier.sharedElement(
sharedContentState = rememberSharedContentState(key = "expand_button"),
animatedVisibilityScope = this@AnimatedContent,
)
) )
} }
}
}
// Model description shown in expanded layout. val description = @Composable {
val modelDescription = remember {
movableContentOf { m: Modifier ->
if (model.info.isNotEmpty()) { if (model.info.isNotEmpty()) {
MarkdownText( MarkdownText(
model.info, model.info, modifier = Modifier
modifier = Modifier .sharedElement(
.heightIn(min = 0.dp, max = if (isExpanded) 1000.dp else 0.dp) sharedContentState = rememberSharedContentState(key = "description"),
.animateLayout() animatedVisibilityScope = this@AnimatedContent,
.then(m) )
.skipToLookaheadSize()
) )
} }
} }
}
// Button rows shown in expanded layout. val buttonsRow = @Composable {
val buttonRows = remember {
movableContentOf { m: Modifier ->
Row( Row(
modifier = Modifier horizontalArrangement = Arrangement.spacedBy(12.dp), modifier = Modifier
.heightIn(min = 0.dp, max = if (isExpanded) 1000.dp else 0.dp) .sharedElement(
.animateLayout() sharedContentState = rememberSharedContentState(key = "buttons_row"),
.then(m), animatedVisibilityScope = this@AnimatedContent,
horizontalArrangement = Arrangement.spacedBy(12.dp), )
.skipToLookaheadSize()
) { ) {
// The "learn more" button. Click to show related urls in a bottom sheet. // The "learn more" button. Click to show related urls in a bottom sheet.
if (model.learnMoreUrl.isNotEmpty()) { if (model.learnMoreUrl.isNotEmpty()) {
@ -238,60 +239,42 @@ fun ModelItem(
// Button to start the download and start the chat session with the model. // Button to start the download and start the chat session with the model.
val needToDownloadFirst = val needToDownloadFirst =
downloadStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED || downloadStatus?.status == ModelDownloadStatusType.FAILED downloadStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED || downloadStatus?.status == ModelDownloadStatusType.FAILED
DownloadAndTryButton( DownloadAndTryButton(task = task,
task = task,
model = model, model = model,
enabled = isExpanded, enabled = isExpanded,
needToDownloadFirst = needToDownloadFirst, needToDownloadFirst = needToDownloadFirst,
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,
onClicked = { onModelClicked(model) } onClicked = { onModelClicked(model) })
)
}
} }
} }
val container = remember { // Collapsed state.
movableContentWithReceiverOf<LookaheadScope, @Composable () -> Unit> { content -> if (!targetState) {
Box( Column(
modifier = Modifier.animateLayout(), horizontalAlignment = Alignment.CenterHorizontally,
contentAlignment = Alignment.TopEnd,
) { ) {
content() Row(
} verticalAlignment = Alignment.CenterVertically,
} horizontalArrangement = Arrangement.spacedBy(12.dp),
} modifier = Modifier
var boxModifier = modifier
.fillMaxWidth() .fillMaxWidth()
.clip(RoundedCornerShape(size = 42.dp)) .padding(start = 18.dp, end = 18.dp)
.background( .padding(vertical = verticalSpacing)
getTaskBgColor(task)
)
boxModifier = if (canExpand) {
boxModifier.clickable(
onClick = {
if (!model.imported) {
isExpanded = !isExpanded
} else {
onModelClicked(model)
}
},
interactionSource = remember { MutableInteractionSource() },
indication = ripple(
bounded = true,
radius = 500.dp,
)
)
} else {
boxModifier
}
Box(
modifier = boxModifier,
contentAlignment = Alignment.Center
) { ) {
if (isExpanded) { // Icon at the left.
container { taskIcon()
// The main part (icon, model name, status, etc) // Model name and status at the center.
Row(modifier = Modifier.weight(1f)) {
modelNameAndStatus()
}
// Action button and expand/collapse button at the right.
Row(verticalAlignment = Alignment.CenterVertically) {
actionButton()
expandButton()
}
}
}
} else {
Column( Column(
verticalArrangement = Arrangement.spacedBy(14.dp), verticalArrangement = Arrangement.spacedBy(14.dp),
horizontalAlignment = Alignment.CenterHorizontally, horizontalAlignment = Alignment.CenterHorizontally,
@ -300,7 +283,9 @@ fun ModelItem(
.padding(vertical = verticalSpacing, horizontal = 18.dp) .padding(vertical = verticalSpacing, horizontal = 18.dp)
) { ) {
Box(contentAlignment = Alignment.Center) { Box(contentAlignment = Alignment.Center) {
// Icon at the top-center.
taskIcon() taskIcon()
// Action button and expand/collapse button at the right.
Row( Row(
modifier = Modifier.fillMaxWidth(), modifier = Modifier.fillMaxWidth(),
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
@ -310,39 +295,12 @@ fun ModelItem(
expandButton() expandButton()
} }
} }
// Name and status below the icon.
modelNameAndStatus() modelNameAndStatus()
modelDescription(Modifier.alpha(alphaAnimation)) // Description.
buttonRows(Modifier.alpha(alphaAnimation)) // Apply alpha here description()
} // Buttons
} buttonsRow()
} else {
container {
Column(horizontalAlignment = Alignment.CenterHorizontally) {
// The main part (icon, model name, status, etc)
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(12.dp),
modifier = Modifier
.fillMaxWidth()
.padding(start = 18.dp, end = 18.dp)
.padding(vertical = verticalSpacing)
) {
taskIcon()
Row(modifier = Modifier.weight(1f)) {
modelNameAndStatus()
}
Row(verticalAlignment = Alignment.CenterVertically) {
actionButton()
expandButton()
}
}
Column(
modifier = Modifier.offset(y = 30.dp),
horizontalAlignment = Alignment.CenterHorizontally
) {
modelDescription(Modifier.alpha(alphaAnimation))
buttonRows(Modifier.alpha(alphaAnimation))
}
} }
} }
} }

View file

@ -108,7 +108,8 @@ fun ModelItemActionButton(
// Button to cancel the download when it is in progress. // Button to cancel the download when it is in progress.
ModelDownloadStatusType.IN_PROGRESS, ModelDownloadStatusType.UNZIPPING -> IconButton(onClick = { ModelDownloadStatusType.IN_PROGRESS, ModelDownloadStatusType.UNZIPPING -> IconButton(onClick = {
modelManagerViewModel.cancelDownloadModel( modelManagerViewModel.cancelDownloadModel(
model task = task,
model = model
) )
}) { }) {
Icon( Icon(

View file

@ -16,6 +16,9 @@
package com.google.aiedge.gallery.ui.common.modelitem package com.google.aiedge.gallery.ui.common.modelitem
import androidx.compose.animation.AnimatedVisibilityScope
import androidx.compose.animation.ExperimentalSharedTransitionApi
import androidx.compose.animation.SharedTransitionScope
import androidx.compose.animation.core.Animatable import androidx.compose.animation.core.Animatable
import androidx.compose.animation.core.tween import androidx.compose.animation.core.tween
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
@ -30,6 +33,7 @@ import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.remember import androidx.compose.runtime.remember
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.focus.focusModifier
import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
@ -54,18 +58,22 @@ import com.google.aiedge.gallery.ui.theme.labelSmallNarrow
* - "Unzipping..." status for unzipping processes. * - "Unzipping..." status for unzipping processes.
* - Model size for successful downloads. * - Model size for successful downloads.
*/ */
@OptIn(ExperimentalSharedTransitionApi::class)
@Composable @Composable
fun ModelNameAndStatus( fun ModelNameAndStatus(
model: Model, model: Model,
task: Task, task: Task,
downloadStatus: ModelDownloadStatus?, downloadStatus: ModelDownloadStatus?,
isExpanded: Boolean, isExpanded: Boolean,
sharedTransitionScope: SharedTransitionScope,
animatedVisibilityScope: AnimatedVisibilityScope,
modifier: Modifier = Modifier modifier: Modifier = Modifier
) { ) {
val inProgress = downloadStatus?.status == ModelDownloadStatusType.IN_PROGRESS val inProgress = downloadStatus?.status == ModelDownloadStatusType.IN_PROGRESS
val isPartiallyDownloaded = downloadStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED val isPartiallyDownloaded = downloadStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED
var curDownloadProgress = 0f var curDownloadProgress = 0f
with(sharedTransitionScope) {
Column( Column(
horizontalAlignment = if (isExpanded) Alignment.CenterHorizontally else Alignment.Start horizontalAlignment = if (isExpanded) Alignment.CenterHorizontally else Alignment.Start
) { ) {
@ -78,7 +86,10 @@ fun ModelNameAndStatus(
maxLines = 1, maxLines = 1,
overflow = TextOverflow.MiddleEllipsis, overflow = TextOverflow.MiddleEllipsis,
style = MaterialTheme.typography.titleMedium, style = MaterialTheme.typography.titleMedium,
modifier = modifier, modifier = Modifier.sharedElement(
rememberSharedContentState(key = "model_name"),
animatedVisibilityScope = animatedVisibilityScope
)
) )
} }
@ -87,7 +98,12 @@ fun ModelNameAndStatus(
if (!inProgress && !isPartiallyDownloaded) { if (!inProgress && !isPartiallyDownloaded) {
StatusIcon( StatusIcon(
downloadStatus = downloadStatus, downloadStatus = downloadStatus,
modifier = modifier.padding(end = 4.dp) modifier = modifier
.padding(end = 4.dp)
.sharedElement(
rememberSharedContentState(key = "download_status_icon"),
animatedVisibilityScope = animatedVisibilityScope
)
) )
} }
@ -99,7 +115,10 @@ fun ModelNameAndStatus(
color = MaterialTheme.colorScheme.error, color = MaterialTheme.colorScheme.error,
style = labelSmallNarrow, style = labelSmallNarrow,
overflow = TextOverflow.Ellipsis, overflow = TextOverflow.Ellipsis,
modifier = modifier, modifier = Modifier.sharedElement(
rememberSharedContentState(key = "failure_messsage"),
animatedVisibilityScope = animatedVisibilityScope
)
) )
} }
} }
@ -154,7 +173,12 @@ fun ModelNameAndStatus(
style = labelSmallNarrow.copy(fontSize = fontSize, lineHeight = 10.sp), style = labelSmallNarrow.copy(fontSize = fontSize, lineHeight = 10.sp),
textAlign = if (isExpanded) TextAlign.Center else TextAlign.Start, textAlign = if (isExpanded) TextAlign.Center else TextAlign.Start,
overflow = TextOverflow.Visible, overflow = TextOverflow.Visible,
modifier = modifier.offset(y = if (index == 0) 0.dp else (-1).dp) modifier = Modifier
.offset(y = if (index == 0) 0.dp else (-1).dp)
.sharedElement(
rememberSharedContentState(key = "status_label_${index}"),
animatedVisibilityScope = animatedVisibilityScope
)
) )
} }
} }
@ -168,7 +192,12 @@ fun ModelNameAndStatus(
progress = { animatedProgress.value }, progress = { animatedProgress.value },
color = getTaskIconColor(task = task), color = getTaskIconColor(task = task),
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest, trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
modifier = modifier.padding(top = 2.dp) modifier = Modifier
.padding(top = 2.dp)
.sharedElement(
rememberSharedContentState(key = "download_progress_bar"),
animatedVisibilityScope = animatedVisibilityScope
)
) )
LaunchedEffect(curDownloadProgress) { LaunchedEffect(curDownloadProgress) {
animatedProgress.animateTo(curDownloadProgress, animationSpec = tween(150)) animatedProgress.animateTo(curDownloadProgress, animationSpec = tween(150))
@ -180,8 +209,13 @@ fun ModelNameAndStatus(
color = getTaskIconColor(task = task), color = getTaskIconColor(task = task),
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest, trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
modifier = Modifier modifier = Modifier
.padding(top = 2.dp), .padding(top = 2.dp)
.sharedElement(
rememberSharedContentState(key = "unzip_progress_bar"),
animatedVisibilityScope = animatedVisibilityScope
)
) )
} }
} }
} }
}

View file

@ -16,8 +16,11 @@
package com.google.aiedge.gallery.ui.home package com.google.aiedge.gallery.ui.home
import android.app.Activity
import android.content.Context
import android.content.Intent import android.content.Intent
import android.net.Uri import android.net.Uri
import android.provider.OpenableColumns
import android.util.Log import android.util.Log
import androidx.activity.compose.rememberLauncherForActivityResult import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.ActivityResultLauncher import androidx.activity.result.ActivityResultLauncher
@ -49,6 +52,7 @@ import androidx.compose.material.icons.automirrored.outlined.NoteAdd
import androidx.compose.material.icons.filled.Add import androidx.compose.material.icons.filled.Add
import androidx.compose.material.icons.rounded.Error import androidx.compose.material.icons.rounded.Error
import androidx.compose.material3.AlertDialog import androidx.compose.material3.AlertDialog
import androidx.compose.material3.Button
import androidx.compose.material3.Card import androidx.compose.material3.Card
import androidx.compose.material3.CardDefaults import androidx.compose.material3.CardDefaults
import androidx.compose.material3.CircularProgressIndicator import androidx.compose.material3.CircularProgressIndicator
@ -80,12 +84,18 @@ import androidx.compose.ui.draw.clip
import androidx.compose.ui.draw.scale import androidx.compose.ui.draw.scale
import androidx.compose.ui.graphics.Brush import androidx.compose.ui.graphics.Brush
import androidx.compose.ui.input.nestedscroll.nestedScroll import androidx.compose.ui.input.nestedscroll.nestedScroll
import androidx.compose.ui.layout.layout
import androidx.compose.ui.platform.LocalConfiguration
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalDensity
import androidx.compose.ui.platform.LocalWindowInfo
import androidx.compose.ui.res.stringResource import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.LinkAnnotation
import androidx.compose.ui.text.SpanStyle
import androidx.compose.ui.text.TextLinkStyles
import androidx.compose.ui.text.buildAnnotatedString
import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.text.style.TextDecoration
import androidx.compose.ui.text.withLink
import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp import androidx.compose.ui.unit.sp
@ -93,7 +103,6 @@ import com.google.aiedge.gallery.GalleryTopAppBar
import com.google.aiedge.gallery.R import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.data.AppBarAction import com.google.aiedge.gallery.data.AppBarAction
import com.google.aiedge.gallery.data.AppBarActionType import com.google.aiedge.gallery.data.AppBarActionType
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.ImportedModelInfo import com.google.aiedge.gallery.data.ImportedModelInfo
import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.TaskIcon import com.google.aiedge.gallery.ui.common.TaskIcon
@ -101,7 +110,6 @@ import com.google.aiedge.gallery.ui.common.getTaskBgColor
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.theme.GalleryTheme import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.ThemeSettings
import com.google.aiedge.gallery.ui.theme.customColors import com.google.aiedge.gallery.ui.theme.customColors
import com.google.aiedge.gallery.ui.theme.titleMediumNarrow import com.google.aiedge.gallery.ui.theme.titleMediumNarrow
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
@ -109,6 +117,12 @@ import kotlinx.coroutines.launch
private const val TAG = "AGHomeScreen" private const val TAG = "AGHomeScreen"
private const val TASK_COUNT_ANIMATION_DURATION = 250 private const val TASK_COUNT_ANIMATION_DURATION = 250
private const val MAX_TASK_CARD_PADDING = 24
private const val MIN_TASK_CARD_PADDING = 18
private const val MAX_TASK_CARD_RADIUS = 43.5
private const val MIN_TASK_CARD_RADIUS = 30
private const val MAX_TASK_CARD_ICON_SIZE = 56
private const val MIN_TASK_CARD_ICON_SIZE = 50
/** Navigation destination data */ /** Navigation destination data */
object HomeScreenDestination { object HomeScreenDestination {
@ -127,6 +141,7 @@ fun HomeScreen(
val uiState by modelManagerViewModel.uiState.collectAsState() val uiState by modelManagerViewModel.uiState.collectAsState()
var showSettingsDialog by remember { mutableStateOf(false) } var showSettingsDialog by remember { mutableStateOf(false) }
var showImportModelSheet by remember { mutableStateOf(false) } var showImportModelSheet by remember { mutableStateOf(false) }
var showUnsupportedFileTypeDialog by remember { mutableStateOf(false) }
val sheetState = rememberModalBottomSheetState() val sheetState = rememberModalBottomSheetState()
var showImportDialog by remember { mutableStateOf(false) } var showImportDialog by remember { mutableStateOf(false) }
var showImportingDialog by remember { mutableStateOf(false) } var showImportingDialog by remember { mutableStateOf(false) }
@ -135,17 +150,21 @@ fun HomeScreen(
val coroutineScope = rememberCoroutineScope() val coroutineScope = rememberCoroutineScope()
val snackbarHostState = remember { SnackbarHostState() } val snackbarHostState = remember { SnackbarHostState() }
val scope = rememberCoroutineScope() val scope = rememberCoroutineScope()
val context = LocalContext.current
val nonEmptyTasks = uiState.tasks.filter { it.models.size > 0 }
val loadingHfModels = uiState.loadingHfModels
val filePickerLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult( val filePickerLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult(
contract = ActivityResultContracts.StartActivityForResult() contract = ActivityResultContracts.StartActivityForResult()
) { result -> ) { result ->
if (result.resultCode == android.app.Activity.RESULT_OK) { if (result.resultCode == android.app.Activity.RESULT_OK) {
result.data?.data?.let { uri -> result.data?.data?.let { uri ->
val fileName = getFileName(context = context, uri = uri)
Log.d(TAG, "Selected file: $fileName")
if (fileName != null && !fileName.endsWith(".task")) {
showUnsupportedFileTypeDialog = true
} else {
selectedLocalModelFileUri.value = uri selectedLocalModelFileUri.value = uri
showImportDialog = true showImportDialog = true
}
} ?: run { } ?: run {
Log.d(TAG, "No file selected or URI is null.") Log.d(TAG, "No file selected or URI is null.")
} }
@ -154,21 +173,15 @@ fun HomeScreen(
} }
} }
Scaffold( Scaffold(modifier = modifier.nestedScroll(scrollBehavior.nestedScrollConnection), topBar = {
modifier = modifier.nestedScroll(scrollBehavior.nestedScrollConnection),
topBar = {
GalleryTopAppBar( GalleryTopAppBar(
title = stringResource(HomeScreenDestination.titleRes), title = stringResource(HomeScreenDestination.titleRes),
rightAction = AppBarAction( rightAction = AppBarAction(actionType = AppBarActionType.APP_SETTING, actionFn = {
actionType = AppBarActionType.APP_SETTING, actionFn = {
showSettingsDialog = true showSettingsDialog = true
} }),
),
loadingHfModels = loadingHfModels,
scrollBehavior = scrollBehavior, scrollBehavior = scrollBehavior,
) )
}, }, floatingActionButton = {
floatingActionButton = {
// A floating action button to show "import model" bottom sheet. // A floating action button to show "import model" bottom sheet.
SmallFloatingActionButton( SmallFloatingActionButton(
onClick = { onClick = {
@ -179,11 +192,10 @@ fun HomeScreen(
) { ) {
Icon(Icons.Filled.Add, "") Icon(Icons.Filled.Add, "")
} }
} }) { innerPadding ->
) { innerPadding ->
Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.fillMaxSize()) { Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.fillMaxSize()) {
TaskList( TaskList(
tasks = nonEmptyTasks, tasks = uiState.tasks,
navigateToTaskScreen = navigateToTaskScreen, navigateToTaskScreen = navigateToTaskScreen,
loadingModelAllowlist = uiState.loadingModelAllowlist, loadingModelAllowlist = uiState.loadingModelAllowlist,
modifier = Modifier.fillMaxSize(), modifier = Modifier.fillMaxSize(),
@ -198,16 +210,8 @@ fun HomeScreen(
if (showSettingsDialog) { if (showSettingsDialog) {
SettingsDialog( SettingsDialog(
curThemeOverride = modelManagerViewModel.readThemeOverride(), curThemeOverride = modelManagerViewModel.readThemeOverride(),
modelManagerViewModel = modelManagerViewModel,
onDismissed = { showSettingsDialog = false }, onDismissed = { showSettingsDialog = false },
onOk = { curConfigValues ->
// Update theme settings.
// This will update app's theme.
val themeOverride = curConfigValues[ConfigKey.THEME.label] as String
ThemeSettings.themeOverride.value = themeOverride
// Save to data store.
modelManagerViewModel.saveThemeOverride(themeOverride)
},
) )
} }
@ -232,10 +236,6 @@ fun HomeScreen(
val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply { val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
addCategory(Intent.CATEGORY_OPENABLE) addCategory(Intent.CATEGORY_OPENABLE)
type = "*/*" type = "*/*"
putExtra(
Intent.EXTRA_MIME_TYPES,
arrayOf("application/x-binary", "application/octet-stream")
)
// Single select. // Single select.
putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false) putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
} }
@ -259,9 +259,7 @@ fun HomeScreen(
// Import dialog // Import dialog
if (showImportDialog) { if (showImportDialog) {
selectedLocalModelFileUri.value?.let { uri -> selectedLocalModelFileUri.value?.let { uri ->
ModelImportDialog(uri = uri, ModelImportDialog(uri = uri, onDismiss = { showImportDialog = false }, onDone = { info ->
onDismiss = { showImportDialog = false },
onDone = { info ->
selectedImportedModelInfo.value = info selectedImportedModelInfo.value = info
showImportDialog = false showImportDialog = false
showImportingDialog = true showImportingDialog = true
@ -273,8 +271,7 @@ fun HomeScreen(
if (showImportingDialog) { if (showImportingDialog) {
selectedLocalModelFileUri.value?.let { uri -> selectedLocalModelFileUri.value?.let { uri ->
selectedImportedModelInfo.value?.let { info -> selectedImportedModelInfo.value?.let { info ->
ModelImportingDialog( ModelImportingDialog(uri = uri,
uri = uri,
info = info, info = info,
onDismiss = { showImportingDialog = false }, onDismiss = { showImportingDialog = false },
onDone = { onDone = {
@ -292,6 +289,22 @@ fun HomeScreen(
} }
} }
// Alert dialog for unsupported file type.
if (showUnsupportedFileTypeDialog) {
AlertDialog(
onDismissRequest = { showUnsupportedFileTypeDialog = false },
title = { Text("Unsupported file type") },
text = {
Text("Only \".task\" file type is supported.")
},
confirmButton = {
Button(onClick = { showUnsupportedFileTypeDialog = false }) {
Text(stringResource(R.string.ok))
}
},
)
}
if (uiState.loadingModelAllowlistError.isNotEmpty()) { if (uiState.loadingModelAllowlistError.isNotEmpty()) {
AlertDialog( AlertDialog(
icon = { icon = {
@ -307,11 +320,9 @@ fun HomeScreen(
modelManagerViewModel.loadModelAllowlist() modelManagerViewModel.loadModelAllowlist()
}, },
confirmButton = { confirmButton = {
TextButton( TextButton(onClick = {
onClick = {
modelManagerViewModel.loadModelAllowlist() modelManagerViewModel.loadModelAllowlist()
} }) {
) {
Text("Retry") Text("Retry")
} }
}, },
@ -327,6 +338,38 @@ private fun TaskList(
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
contentPadding: PaddingValues = PaddingValues(0.dp), contentPadding: PaddingValues = PaddingValues(0.dp),
) { ) {
val density = LocalDensity.current
val windowInfo = LocalWindowInfo.current
val screenWidthDp = remember {
with(density) {
windowInfo.containerSize.width.toDp()
}
}
val screenHeightDp = remember {
with(density) {
windowInfo.containerSize.height.toDp()
}
}
val sizeFraction = remember { ((screenWidthDp - 360.dp) / (410.dp - 360.dp)).coerceIn(0f, 1f) }
val linkColor = MaterialTheme.customColors.linkColor
val introText = buildAnnotatedString {
append("Welcome to AI Edge Gallery! Explore a world of \namazing on-device models from ")
withLink(
link = LinkAnnotation.Url(
url = "https://huggingface.co/litert-community", // Replace with the actual URL
styles = TextLinkStyles(
style = SpanStyle(
color = linkColor,
textDecoration = TextDecoration.Underline,
)
)
)
) {
append("LiteRT community")
}
}
Box(modifier = modifier.fillMaxSize()) { Box(modifier = modifier.fillMaxSize()) {
LazyVerticalGrid( LazyVerticalGrid(
columns = GridCells.Fixed(count = 2), columns = GridCells.Fixed(count = 2),
@ -335,10 +378,15 @@ private fun TaskList(
horizontalArrangement = Arrangement.spacedBy(8.dp), horizontalArrangement = Arrangement.spacedBy(8.dp),
verticalArrangement = Arrangement.spacedBy(8.dp), verticalArrangement = Arrangement.spacedBy(8.dp),
) { ) {
// New rel
item(key = "newReleaseNotification", span = { GridItemSpan(2) }) {
NewReleaseNotification()
}
// Headline. // Headline.
item(key = "headline", span = { GridItemSpan(2) }) { item(key = "headline", span = { GridItemSpan(2) }) {
Text( Text(
"Welcome to AI Edge Gallery! Explore a world of \namazing on-device models from LiteRT community", introText,
textAlign = TextAlign.Center, textAlign = TextAlign.Center,
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold), style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
modifier = Modifier.padding(bottom = 20.dp) modifier = Modifier.padding(bottom = 20.dp)
@ -364,14 +412,21 @@ private fun TaskList(
} }
} }
} else { } else {
// Cards. // LLM Cards.
item(key = "llmCardsHeader", span = { GridItemSpan(2) }) {
Text(
"Example LLM Use Cases",
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.SemiBold),
color = MaterialTheme.colorScheme.onSurfaceVariant,
modifier = Modifier.padding(bottom = 4.dp)
)
}
items(tasks) { task -> items(tasks) { task ->
TaskCard( TaskCard(
task = task, sizeFraction = sizeFraction, task = task, onClick = {
onClick = {
navigateToTaskScreen(task) navigateToTaskScreen(task)
}, }, modifier = Modifier
modifier = Modifier
.fillMaxWidth() .fillMaxWidth()
.aspectRatio(1f) .aspectRatio(1f)
) )
@ -388,7 +443,7 @@ private fun TaskList(
Box( Box(
modifier = Modifier modifier = Modifier
.fillMaxWidth() .fillMaxWidth()
.height(LocalConfiguration.current.screenHeightDp.dp * 0.25f) .height(screenHeightDp * 0.25f)
.background( .background(
Brush.verticalGradient( Brush.verticalGradient(
colors = MaterialTheme.customColors.homeBottomGradient, colors = MaterialTheme.customColors.homeBottomGradient,
@ -400,7 +455,15 @@ private fun TaskList(
} }
@Composable @Composable
private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modifier) { private fun TaskCard(
task: Task, onClick: () -> Unit, sizeFraction: Float, modifier: Modifier = Modifier
) {
val padding =
(MAX_TASK_CARD_PADDING - MIN_TASK_CARD_PADDING) * sizeFraction + MIN_TASK_CARD_PADDING
val radius = (MAX_TASK_CARD_RADIUS - MIN_TASK_CARD_RADIUS) * sizeFraction + MIN_TASK_CARD_RADIUS
val iconSize =
(MAX_TASK_CARD_ICON_SIZE - MIN_TASK_CARD_ICON_SIZE) * sizeFraction + MIN_TASK_CARD_ICON_SIZE
// Observes the model count and updates the model count label with a fade-in/fade-out animation // Observes the model count and updates the model count label with a fade-in/fade-out animation
// whenever the count changes. // whenever the count changes.
val modelCount by remember { val modelCount by remember {
@ -445,7 +508,7 @@ private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modif
Card( Card(
modifier = modifier modifier = modifier
.clip(RoundedCornerShape(43.5.dp)) .clip(RoundedCornerShape(radius.dp))
.clickable( .clickable(
onClick = onClick, onClick = onClick,
), ),
@ -456,39 +519,24 @@ private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modif
Column( Column(
modifier = Modifier modifier = Modifier
.fillMaxSize() .fillMaxSize()
.padding(24.dp), .padding(padding.dp),
) { ) {
// Icon. // Icon.
TaskIcon(task = task) TaskIcon(task = task, width = iconSize.dp)
Spacer(modifier = Modifier.weight(1f)) Spacer(modifier = Modifier.weight(2f))
// Title. // Title.
val pair = task.type.label.splitByFirstSpace()
Text( Text(
pair.first, task.type.label,
color = MaterialTheme.colorScheme.primary, color = MaterialTheme.colorScheme.primary,
style = titleMediumNarrow.copy( style = titleMediumNarrow.copy(
fontSize = 20.sp, fontSize = 20.sp,
fontWeight = FontWeight.Bold, fontWeight = FontWeight.Bold,
), ),
) )
if (pair.second.isNotEmpty()) {
Text( Spacer(modifier = Modifier.weight(1f))
pair.second,
color = MaterialTheme.colorScheme.primary,
style = titleMediumNarrow.copy(
fontSize = 18.sp,
fontWeight = FontWeight.Bold,
),
modifier = Modifier.layout { measurable, constraints ->
val placeable = measurable.measure(constraints)
layout(placeable.width, placeable.height) {
placeable.placeRelative(0, -4.dp.roundToPx())
}
}
)
}
// Model count. // Model count.
Text( Text(
@ -503,12 +551,21 @@ private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modif
} }
} }
private fun String.splitByFirstSpace(): Pair<String, String> { // Helper function to get the file name from a URI
val spaceIndex = this.indexOf(' ') fun getFileName(context: Context, uri: Uri): String? {
if (spaceIndex == -1) { if (uri.scheme == "content") {
return Pair(this, "") context.contentResolver.query(uri, null, null, null, null)?.use { cursor ->
if (cursor.moveToFirst()) {
val nameIndex = cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME)
if (nameIndex != -1) {
return cursor.getString(nameIndex)
} }
return Pair(this.substring(0, spaceIndex), this.substring(spaceIndex + 1)) }
}
} else if (uri.scheme == "file") {
return uri.lastPathSegment
}
return null
} }
@Preview @Preview

View file

@ -22,12 +22,16 @@ import android.provider.OpenableColumns
import android.util.Log import android.util.Log
import androidx.compose.animation.core.Animatable import androidx.compose.animation.core.Animatable
import androidx.compose.animation.core.tween import androidx.compose.animation.core.tween
import androidx.compose.foundation.clickable
import androidx.compose.foundation.interaction.MutableInteractionSource
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.foundation.verticalScroll
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.Error import androidx.compose.material.icons.rounded.Error
import androidx.compose.material3.Button import androidx.compose.material3.Button
@ -51,6 +55,7 @@ import androidx.compose.runtime.snapshots.SnapshotStateMap
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalFocusManager
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog import androidx.compose.ui.window.Dialog
import androidx.compose.ui.window.DialogProperties import androidx.compose.ui.window.DialogProperties
@ -82,41 +87,34 @@ import java.nio.charset.StandardCharsets
private const val TAG = "AGModelImportDialog" private const val TAG = "AGModelImportDialog"
private val IMPORT_CONFIGS_LLM: List<Config> = listOf( private val IMPORT_CONFIGS_LLM: List<Config> = listOf(
LabelConfig(key = ConfigKey.NAME), LabelConfig(key = ConfigKey.NAME), LabelConfig(key = ConfigKey.MODEL_TYPE), NumberSliderConfig(
LabelConfig(key = ConfigKey.MODEL_TYPE),
NumberSliderConfig(
key = ConfigKey.DEFAULT_MAX_TOKENS, key = ConfigKey.DEFAULT_MAX_TOKENS,
sliderMin = 100f, sliderMin = 100f,
sliderMax = 1024f, sliderMax = 1024f,
defaultValue = DEFAULT_MAX_TOKEN.toFloat(), defaultValue = DEFAULT_MAX_TOKEN.toFloat(),
valueType = ValueType.INT valueType = ValueType.INT
), ), NumberSliderConfig(
NumberSliderConfig(
key = ConfigKey.DEFAULT_TOPK, key = ConfigKey.DEFAULT_TOPK,
sliderMin = 5f, sliderMin = 5f,
sliderMax = 40f, sliderMax = 40f,
defaultValue = DEFAULT_TOPK.toFloat(), defaultValue = DEFAULT_TOPK.toFloat(),
valueType = ValueType.INT valueType = ValueType.INT
), ), NumberSliderConfig(
NumberSliderConfig(
key = ConfigKey.DEFAULT_TOPP, key = ConfigKey.DEFAULT_TOPP,
sliderMin = 0.0f, sliderMin = 0.0f,
sliderMax = 1.0f, sliderMax = 1.0f,
defaultValue = DEFAULT_TOPP, defaultValue = DEFAULT_TOPP,
valueType = ValueType.FLOAT valueType = ValueType.FLOAT
), ), NumberSliderConfig(
NumberSliderConfig(
key = ConfigKey.DEFAULT_TEMPERATURE, key = ConfigKey.DEFAULT_TEMPERATURE,
sliderMin = 0.0f, sliderMin = 0.0f,
sliderMax = 2.0f, sliderMax = 2.0f,
defaultValue = DEFAULT_TEMPERATURE, defaultValue = DEFAULT_TEMPERATURE,
valueType = ValueType.FLOAT valueType = ValueType.FLOAT
), ), BooleanSwitchConfig(
BooleanSwitchConfig(
key = ConfigKey.SUPPORT_IMAGE, key = ConfigKey.SUPPORT_IMAGE,
defaultValue = false, defaultValue = false,
), ), SegmentedButtonConfig(
SegmentedButtonConfig(
key = ConfigKey.COMPATIBLE_ACCELERATORS, key = ConfigKey.COMPATIBLE_ACCELERATORS,
defaultValue = Accelerator.CPU.label, defaultValue = Accelerator.CPU.label,
options = listOf(Accelerator.CPU.label, Accelerator.GPU.label), options = listOf(Accelerator.CPU.label, Accelerator.GPU.label),
@ -126,9 +124,7 @@ private val IMPORT_CONFIGS_LLM: List<Config> = listOf(
@Composable @Composable
fun ModelImportDialog( fun ModelImportDialog(
uri: Uri, uri: Uri, onDismiss: () -> Unit, onDone: (ImportedModelInfo) -> Unit
onDismiss: () -> Unit,
onDone: (ImportedModelInfo) -> Unit
) { ) {
val context = LocalContext.current val context = LocalContext.current
val info = remember { getFileSizeAndDisplayNameFromUri(context = context, uri = uri) } val info = remember { getFileSizeAndDisplayNameFromUri(context = context, uri = uri) }
@ -150,15 +146,23 @@ fun ModelImportDialog(
putAll(initialValues) putAll(initialValues)
} }
} }
val interactionSource = remember { MutableInteractionSource() }
Dialog( Dialog(
onDismissRequest = onDismiss, onDismissRequest = onDismiss,
) { ) {
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) { val focusManager = LocalFocusManager.current
Column( Card(
modifier = Modifier modifier = Modifier
.padding(20.dp), .fillMaxWidth()
verticalArrangement = Arrangement.spacedBy(16.dp) .clickable(
interactionSource = interactionSource, indication = null // Disable the ripple effect
) {
focusManager.clearFocus()
}, shape = RoundedCornerShape(16.dp)
) {
Column(
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
) { ) {
// Title. // Title.
Text( Text(
@ -167,11 +171,18 @@ fun ModelImportDialog(
modifier = Modifier.padding(bottom = 8.dp) modifier = Modifier.padding(bottom = 8.dp)
) )
Column(
modifier = Modifier
.verticalScroll(rememberScrollState())
.weight(1f, fill = false),
verticalArrangement = Arrangement.spacedBy(16.dp)
) {
// Default configs for users to set. // Default configs for users to set.
ConfigEditorsPanel( ConfigEditorsPanel(
configs = IMPORT_CONFIGS_LLM, configs = IMPORT_CONFIGS_LLM,
values = values, values = values,
) )
}
// Button row. // Button row.
Row( Row(
@ -210,10 +221,7 @@ fun ModelImportDialog(
@Composable @Composable
fun ModelImportingDialog( fun ModelImportingDialog(
uri: Uri, uri: Uri, info: ImportedModelInfo, onDismiss: () -> Unit, onDone: (ImportedModelInfo) -> Unit
info: ImportedModelInfo,
onDismiss: () -> Unit,
onDone: (ImportedModelInfo) -> Unit
) { ) {
var error by remember { mutableStateOf("") } var error by remember { mutableStateOf("") }
val context = LocalContext.current val context = LocalContext.current
@ -222,8 +230,7 @@ fun ModelImportingDialog(
LaunchedEffect(Unit) { LaunchedEffect(Unit) {
// Import. // Import.
importModel( importModel(context = context,
context = context,
coroutineScope = coroutineScope, coroutineScope = coroutineScope,
fileName = info.fileName, fileName = info.fileName,
fileSize = info.fileSize, fileSize = info.fileSize,
@ -236,8 +243,7 @@ fun ModelImportingDialog(
}, },
onError = { onError = {
error = it error = it
} })
)
} }
Dialog( Dialog(
@ -246,9 +252,7 @@ fun ModelImportingDialog(
) { ) {
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) { Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
Column( Column(
modifier = Modifier modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
.padding(20.dp),
verticalArrangement = Arrangement.spacedBy(16.dp)
) { ) {
// Title. // Title.
Text( Text(
@ -280,13 +284,10 @@ fun ModelImportingDialog(
// Has error. // Has error.
else { else {
Row( Row(
verticalAlignment = Alignment.Top, verticalAlignment = Alignment.Top, horizontalArrangement = Arrangement.spacedBy(6.dp)
horizontalArrangement = Arrangement.spacedBy(6.dp)
) { ) {
Icon( Icon(
Icons.Rounded.Error, Icons.Rounded.Error, contentDescription = "", tint = MaterialTheme.colorScheme.error
contentDescription = "",
tint = MaterialTheme.colorScheme.error
) )
Text( Text(
error, error,

View file

@ -0,0 +1,121 @@
package com.google.aiedge.gallery.ui.home
import android.util.Log
import androidx.compose.animation.AnimatedVisibility
import androidx.compose.animation.expandVertically
import androidx.compose.animation.fadeIn
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.OpenInNew
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.BuildConfig
import com.google.aiedge.gallery.ui.common.getJsonResponse
import com.google.aiedge.gallery.ui.modelmanager.ClickableLink
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import kotlinx.serialization.Serializable
import kotlin.math.max
private const val TAG = "AGNewReleaseNotification"
private const val REPO = "google-ai-edge/gallery"
@Serializable
data class ReleaseInfo(
val html_url: String,
val tag_name: String,
)
@Composable
fun NewReleaseNotification() {
var newReleaseVersion by remember { mutableStateOf("") }
var newReleaseUrl by remember { mutableStateOf("") }
LaunchedEffect(Unit) {
withContext(Dispatchers.IO) {
Log.d("AGNewReleaseNotification", "Checking for new release...")
val info = getJsonResponse<ReleaseInfo>("https://api.github.com/repos/$REPO/releases/latest")
if (info != null) {
val curRelease = BuildConfig.VERSION_NAME
val newRelease = info.tag_name
val isNewer = isNewerRelease(currentRelease = curRelease, newRelease = newRelease)
Log.d(TAG, "curRelease: $curRelease, newRelease: $newRelease, isNewer: $isNewer")
if (isNewer) {
newReleaseVersion = newRelease
newReleaseUrl = info.html_url
}
}
}
}
AnimatedVisibility(
visible = newReleaseVersion.isNotEmpty(),
enter = fadeIn() + expandVertically()
) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.SpaceBetween,
modifier = Modifier
.padding(horizontal = 16.dp)
.padding(bottom = 12.dp)
.clip(
CircleShape
)
.background(MaterialTheme.colorScheme.tertiaryContainer)
.padding(4.dp)
) {
Text(
"New release $newReleaseVersion available",
style = MaterialTheme.typography.bodyMedium,
modifier = Modifier.padding(start = 12.dp)
)
Row(
modifier = Modifier.padding(end = 12.dp),
verticalAlignment = Alignment.CenterVertically
) {
ClickableLink(
url = newReleaseUrl,
linkText = "View",
icon = Icons.AutoMirrored.Rounded.OpenInNew,
)
}
}
}
}
fun isNewerRelease(currentRelease: String, newRelease: String): Boolean {
// Split the version strings into their individual components (e.g., "0.9.0" -> ["0", "9", "0"])
val currentComponents = currentRelease.split('.').map { it.toIntOrNull() ?: 0 }
val newComponents = newRelease.split('.').map { it.toIntOrNull() ?: 0 }
// Determine the maximum number of components to iterate through
val maxComponents = max(currentComponents.size, newComponents.size)
// Iterate through the components from left to right (major, minor, patch, etc.)
for (i in 0 until maxComponents) {
val currentComponent = currentComponents.getOrElse(i) { 0 }
val newComponent = newComponents.getOrElse(i) { 0 }
if (newComponent > currentComponent) {
return true
} else if (newComponent < currentComponent) {
return false
}
}
return false
}

View file

@ -16,45 +16,266 @@
package com.google.aiedge.gallery.ui.home package com.google.aiedge.gallery.ui.home
import androidx.compose.foundation.border
import androidx.compose.foundation.clickable
import androidx.compose.foundation.interaction.MutableInteractionSource
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.wrapContentHeight
import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.foundation.text.BasicTextField
import androidx.compose.foundation.verticalScroll
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.CheckCircle
import androidx.compose.material3.Button
import androidx.compose.material3.Card
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.MultiChoiceSegmentedButtonRow
import androidx.compose.material3.OutlinedButton
import androidx.compose.material3.SegmentedButton
import androidx.compose.material3.SegmentedButtonDefaults
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.focus.FocusRequester
import androidx.compose.ui.focus.focusRequester
import androidx.compose.ui.focus.onFocusChanged
import androidx.compose.ui.graphics.SolidColor
import androidx.compose.ui.platform.LocalFocusManager
import androidx.compose.ui.text.TextStyle
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog
import com.google.aiedge.gallery.BuildConfig import com.google.aiedge.gallery.BuildConfig
import com.google.aiedge.gallery.data.Config import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.SegmentedButtonConfig
import com.google.aiedge.gallery.ui.common.chat.ConfigDialog
import com.google.aiedge.gallery.ui.theme.THEME_AUTO import com.google.aiedge.gallery.ui.theme.THEME_AUTO
import com.google.aiedge.gallery.ui.theme.THEME_DARK import com.google.aiedge.gallery.ui.theme.THEME_DARK
import com.google.aiedge.gallery.ui.theme.THEME_LIGHT import com.google.aiedge.gallery.ui.theme.THEME_LIGHT
import com.google.aiedge.gallery.ui.theme.ThemeSettings
import com.google.aiedge.gallery.ui.theme.labelSmallNarrow
import java.time.Instant
import java.time.ZoneId
import java.time.format.DateTimeFormatter
import java.util.Locale
import kotlin.math.min
private val CONFIGS: List<Config> = listOf( private val THEME_OPTIONS = listOf(THEME_AUTO, THEME_LIGHT, THEME_DARK)
SegmentedButtonConfig(
key = ConfigKey.THEME,
defaultValue = THEME_AUTO,
options = listOf(THEME_AUTO, THEME_LIGHT, THEME_DARK),
)
)
@Composable @Composable
fun SettingsDialog( fun SettingsDialog(
curThemeOverride: String, curThemeOverride: String,
modelManagerViewModel: ModelManagerViewModel,
onDismissed: () -> Unit, onDismissed: () -> Unit,
onOk: (Map<String, Any>) -> Unit,
) { ) {
val initialValues = mapOf( var selectedTheme by remember { mutableStateOf(curThemeOverride) }
ConfigKey.THEME.label to curThemeOverride var hfToken by remember { mutableStateOf(modelManagerViewModel.getTokenStatusAndData().data) }
) val dateFormatter = remember {
ConfigDialog( DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss").withZone(ZoneId.systemDefault())
title = "Settings", .withLocale(Locale.getDefault())
subtitle = "App version: ${BuildConfig.VERSION_NAME}", }
okBtnLabel = "OK", var customHfToken by remember { mutableStateOf("") }
configs = CONFIGS, var isFocused by remember { mutableStateOf(false) }
initialValues = initialValues, val focusRequester = remember { FocusRequester() }
onDismissed = onDismissed, val interactionSource = remember { MutableInteractionSource() }
onOk = { curConfigValues ->
onOk(curConfigValues)
// Hide config dialog. Dialog(onDismissRequest = onDismissed) {
onDismissed() val focusManager = LocalFocusManager.current
}, Card(
modifier = Modifier
.fillMaxWidth()
.clickable(
interactionSource = interactionSource, indication = null // Disable the ripple effect
) {
focusManager.clearFocus()
}, shape = RoundedCornerShape(16.dp)
) {
Column(
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
) {
// Dialog title and subtitle.
Column {
Text(
"Settings",
style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(bottom = 8.dp)
)
// Subtitle.
Text(
"App version: ${BuildConfig.VERSION_NAME}",
style = labelSmallNarrow,
color = MaterialTheme.colorScheme.onSurfaceVariant,
modifier = Modifier.offset(y = (-6).dp)
) )
} }
Column(
modifier = Modifier
.verticalScroll(rememberScrollState())
.weight(1f, fill = false),
verticalArrangement = Arrangement.spacedBy(16.dp)
) {
// Theme switcher.
Column(
modifier = Modifier.fillMaxWidth()
) {
Text(
"Theme",
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.Bold)
)
MultiChoiceSegmentedButtonRow {
THEME_OPTIONS.forEachIndexed { index, label ->
SegmentedButton(shape = SegmentedButtonDefaults.itemShape(
index = index, count = THEME_OPTIONS.size
), onCheckedChange = {
selectedTheme = label
// Update theme settings.
// This will update app's theme.
ThemeSettings.themeOverride.value = label
// Save to data store.
modelManagerViewModel.saveThemeOverride(label)
}, checked = label == selectedTheme, label = { Text(label) })
}
}
}
// HF Token management.
Column(
modifier = Modifier.fillMaxWidth(), verticalArrangement = Arrangement.spacedBy(4.dp)
) {
Text(
"HuggingFace access token",
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.Bold)
)
// Show the start of the token.
val curHfToken = hfToken
if (curHfToken != null) {
Text(
curHfToken.accessToken.substring(0, min(16, curHfToken.accessToken.length)) + "...",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
Text(
"Expired at: ${dateFormatter.format(Instant.ofEpochMilli(curHfToken.expiresAtMs))}",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
} else {
Text(
"Not available",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
Text(
"The token will be automatically retrieved when a gated model is downloaded",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
Row(horizontalArrangement = Arrangement.spacedBy(4.dp)) {
OutlinedButton(
onClick = {
modelManagerViewModel.clearAccessToken()
hfToken = null
}, enabled = curHfToken != null
) {
Text("Clear")
}
BasicTextField(
value = customHfToken,
singleLine = true,
modifier = Modifier
.fillMaxWidth()
.padding(top = 4.dp)
.focusRequester(focusRequester)
.onFocusChanged {
isFocused = it.isFocused
},
onValueChange = {
customHfToken = it
},
textStyle = TextStyle(color = MaterialTheme.colorScheme.onSurface),
cursorBrush = SolidColor(MaterialTheme.colorScheme.onSurface),
) { innerTextField ->
Box(
modifier = Modifier
.border(
width = if (isFocused) 2.dp else 1.dp,
color = if (isFocused) MaterialTheme.colorScheme.primary else MaterialTheme.colorScheme.outline,
shape = CircleShape,
)
.height(40.dp), contentAlignment = Alignment.CenterStart
) {
Row(verticalAlignment = Alignment.CenterVertically) {
Box(
modifier = Modifier
.padding(start = 16.dp)
.weight(1f)
) {
if (customHfToken.isEmpty()) {
Text(
"Enter token manually",
color = MaterialTheme.colorScheme.onSurfaceVariant,
style = MaterialTheme.typography.bodySmall
)
}
innerTextField()
}
if (customHfToken.isNotEmpty()) {
IconButton(
modifier = Modifier.offset(x = 1.dp),
onClick = {
modelManagerViewModel.saveAccessToken(
accessToken = customHfToken,
refreshToken = "",
expiresAt = System.currentTimeMillis() + 1000L * 60 * 60 * 24 * 365 * 10,
)
hfToken = modelManagerViewModel.getTokenStatusAndData().data
}) {
Icon(Icons.Rounded.CheckCircle, contentDescription = "")
}
}
}
}
}
}
}
}
// Button row.
Row(
modifier = Modifier
.fillMaxWidth()
.padding(top = 8.dp),
horizontalArrangement = Arrangement.End,
) {
// Close button
Button(
onClick = {
onDismissed()
},
) {
Text("Close")
}
}
}
}
}
}

View file

@ -34,9 +34,9 @@ object LlmChatDestination {
val route = "LlmChatRoute" val route = "LlmChatRoute"
} }
object LlmImageToTextDestination { object LlmAskImageDestination {
@Serializable @Serializable
val route = "LlmImageToTextRoute" val route = "LlmAskImageRoute"
} }
@Composable @Composable
@ -57,11 +57,11 @@ fun LlmChatScreen(
} }
@Composable @Composable
fun LlmImageToTextScreen( fun LlmAskImageScreen(
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit, navigateUp: () -> Unit,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
viewModel: LlmImageToTextViewModel = viewModel( viewModel: LlmAskImageViewModel = viewModel(
factory = ViewModelProvider.Factory factory = ViewModelProvider.Factory
), ),
) { ) {
@ -129,12 +129,7 @@ fun ChatViewWrapper(
}) })
} }
}, },
onBenchmarkClicked = { model, message, warmUpIterations, benchmarkIterations -> onBenchmarkClicked = { _, _, _, _ ->
if (message is ChatMessageText) {
viewModel.benchmark(
model = model, message = message
)
}
}, },
onResetSessionClicked = { model -> onResetSessionClicked = { model ->
viewModel.resetSession(model = model) viewModel.resetSession(model = model)

View file

@ -22,7 +22,7 @@ import android.util.Log
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_LLM_CHAT import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT import com.google.aiedge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.aiedge.gallery.ui.common.chat.ChatMessageLoading import com.google.aiedge.gallery.ui.common.chat.ChatMessageLoading
@ -49,6 +49,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
fun generateResponse(model: Model, input: String, image: Bitmap? = null, onError: () -> Unit) { fun generateResponse(model: Model, input: String, image: Bitmap? = null, onError: () -> Unit) {
viewModelScope.launch(Dispatchers.Default) { viewModelScope.launch(Dispatchers.Default) {
setInProgress(true) setInProgress(true)
setPreparing(true)
// Loading. // Loading.
addMessage( addMessage(
@ -88,6 +89,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
timeToFirstToken = (firstTokenTs - start) / 1000f timeToFirstToken = (firstTokenTs - start) / 1000f
prefillSpeed = prefillTokens / timeToFirstToken prefillSpeed = prefillTokens / timeToFirstToken
firstRun = false firstRun = false
setPreparing(false)
} else { } else {
decodeTokens++ decodeTokens++
} }
@ -137,10 +139,12 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
}, },
cleanUpListener = { cleanUpListener = {
setInProgress(false) setInProgress(false)
setPreparing(false)
}) })
} catch (e: Exception) { } catch (e: Exception) {
Log.e(TAG, "Error occurred while running inference", e) Log.e(TAG, "Error occurred while running inference", e)
setInProgress(false) setInProgress(false)
setPreparing(false)
onError() onError()
} }
} }
@ -194,98 +198,6 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
} }
} }
fun benchmark(model: Model, message: ChatMessageText) {
viewModelScope.launch(Dispatchers.Default) {
setInProgress(true)
// Wait for model to be initialized.
while (model.instance == null) {
delay(100)
}
val instance = model.instance as LlmModelInstance
val prefillTokens = instance.session.sizeInTokens(message.content)
// Add the message to show benchmark results.
val benchmarkLlmResult = ChatMessageBenchmarkLlmResult(
orderedStats = STATS,
statValues = mutableMapOf(),
running = true,
latencyMs = -1f,
)
addMessage(model = model, message = benchmarkLlmResult)
// Run inference.
val result = StringBuilder()
var firstRun = true
var timeToFirstToken = 0f
var firstTokenTs = 0L
var decodeTokens = 0
var prefillSpeed = 0f
var decodeSpeed: Float
val start = System.currentTimeMillis()
var lastUpdateTime = 0L
LlmChatModelHelper.runInference(model = model,
input = message.content,
resultListener = { partialResult, done ->
val curTs = System.currentTimeMillis()
if (firstRun) {
firstTokenTs = System.currentTimeMillis()
timeToFirstToken = (firstTokenTs - start) / 1000f
prefillSpeed = prefillTokens / timeToFirstToken
firstRun = false
// Update message to show prefill speed.
replaceLastMessage(
model = model,
message = ChatMessageBenchmarkLlmResult(
orderedStats = STATS,
statValues = mutableMapOf(
"prefill_speed" to prefillSpeed,
"time_to_first_token" to timeToFirstToken,
"latency" to (curTs - start).toFloat() / 1000f,
),
running = false,
latencyMs = -1f,
),
type = ChatMessageType.BENCHMARK_LLM_RESULT,
)
} else {
decodeTokens++
}
result.append(partialResult)
if (curTs - lastUpdateTime > 500 || done) {
decodeSpeed = decodeTokens / ((curTs - firstTokenTs) / 1000f)
if (decodeSpeed.isNaN()) {
decodeSpeed = 0f
}
replaceLastMessage(
model = model, message = ChatMessageBenchmarkLlmResult(
orderedStats = STATS,
statValues = mutableMapOf(
"prefill_speed" to prefillSpeed,
"decode_speed" to decodeSpeed,
"time_to_first_token" to timeToFirstToken,
"latency" to (curTs - start).toFloat() / 1000f,
),
running = !done,
latencyMs = -1f,
), type = ChatMessageType.BENCHMARK_LLM_RESULT
)
lastUpdateTime = curTs
if (done) {
setInProgress(false)
}
}
},
cleanUpListener = {
setInProgress(false)
})
}
}
fun handleError( fun handleError(
context: Context, context: Context,
model: Model, model: Model,
@ -320,15 +232,8 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
) )
// Re-generate the response automatically. // Re-generate the response automatically.
generateResponse(model = model, input = triggeredMessage.content, onError = { generateResponse(model = model, input = triggeredMessage.content, onError = {})
handleError(
context = context,
model = model,
modelManagerViewModel = modelManagerViewModel,
triggeredMessage = triggeredMessage
)
})
} }
} }
class LlmImageToTextViewModel : LlmChatViewModel(curTask = TASK_LLM_IMAGE_TO_TEXT) class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE)

View file

@ -73,6 +73,7 @@ fun LlmSingleTurnScreen(
) { ) {
val task = viewModel.task val task = viewModel.task
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val uiState by viewModel.uiState.collectAsState()
val selectedModel = modelManagerUiState.selectedModel val selectedModel = modelManagerUiState.selectedModel
val scope = rememberCoroutineScope() val scope = rememberCoroutineScope()
val context = LocalContext.current val context = LocalContext.current
@ -114,6 +115,8 @@ fun LlmSingleTurnScreen(
task = task, task = task,
model = selectedModel, model = selectedModel,
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,
inProgress = uiState.inProgress,
modelPreparing = uiState.preparing,
onConfigChanged = { _, _ -> }, onConfigChanged = { _, _ -> },
onBackClicked = { handleNavigateUp() }, onBackClicked = { handleNavigateUp() },
onModelSelected = { newSelectedModel -> onModelSelected = { newSelectedModel ->

View file

@ -20,10 +20,9 @@ import android.util.Log
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_LLM_USECASES import com.google.aiedge.gallery.data.TASK_LLM_PROMPT_LAB
import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.aiedge.gallery.ui.common.chat.ChatMessageLoading
import com.google.aiedge.gallery.ui.common.chat.Stat import com.google.aiedge.gallery.ui.common.chat.Stat
import com.google.aiedge.gallery.ui.common.processLlmResponse import com.google.aiedge.gallery.ui.common.processLlmResponse
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
@ -44,9 +43,9 @@ data class LlmSingleTurnUiState(
val inProgress: Boolean = false, val inProgress: Boolean = false,
/** /**
* Indicates whether the model is currently being initialized. * Indicates whether the model is preparing (before outputting any result and after initializing).
*/ */
val initializing: Boolean = false, val preparing: Boolean = false,
// model -> <template label -> response> // model -> <template label -> response>
val responsesByModel: Map<String, Map<String, String>>, val responsesByModel: Map<String, Map<String, String>>,
@ -65,14 +64,14 @@ private val STATS = listOf(
Stat(id = "latency", label = "Latency", unit = "sec") Stat(id = "latency", label = "Latency", unit = "sec")
) )
open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_USECASES) : ViewModel() { open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewModel() {
private val _uiState = MutableStateFlow(createUiState(task = task)) private val _uiState = MutableStateFlow(createUiState(task = task))
val uiState = _uiState.asStateFlow() val uiState = _uiState.asStateFlow()
fun generateResponse(model: Model, input: String) { fun generateResponse(model: Model, input: String) {
viewModelScope.launch(Dispatchers.Default) { viewModelScope.launch(Dispatchers.Default) {
setInProgress(true) setInProgress(true)
setInitializing(true) setPreparing(true)
// Wait for instance to be initialized. // Wait for instance to be initialized.
while (model.instance == null) { while (model.instance == null) {
@ -98,7 +97,7 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_USECASES) : ViewMode
val curTs = System.currentTimeMillis() val curTs = System.currentTimeMillis()
if (firstRun) { if (firstRun) {
setInitializing(false) setPreparing(false)
firstTokenTs = System.currentTimeMillis() firstTokenTs = System.currentTimeMillis()
timeToFirstToken = (firstTokenTs - start) / 1000f timeToFirstToken = (firstTokenTs - start) / 1000f
prefillSpeed = prefillTokens / timeToFirstToken prefillSpeed = prefillTokens / timeToFirstToken
@ -148,7 +147,7 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_USECASES) : ViewMode
}, },
singleTurn = true, singleTurn = true,
cleanUpListener = { cleanUpListener = {
setInitializing(false) setPreparing(false)
setInProgress(false) setInProgress(false)
}) })
} }
@ -167,8 +166,8 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_USECASES) : ViewMode
_uiState.update { _uiState.value.copy(inProgress = inProgress) } _uiState.update { _uiState.value.copy(inProgress = inProgress) }
} }
fun setInitializing(initializing: Boolean) { fun setPreparing(preparing: Boolean) {
_uiState.update { _uiState.value.copy(initializing = initializing) } _uiState.update { _uiState.value.copy(preparing = preparing) }
} }
fun updateResponse(model: Model, promptTemplateType: PromptTemplateType, response: String) { fun updateResponse(model: Model, promptTemplateType: PromptTemplateType, response: String) {

View file

@ -339,7 +339,7 @@ fun PromptTemplatesPanel(
val modelInitializing = val modelInitializing =
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING
if (inProgress && !modelInitializing) { if (inProgress && !modelInitializing && !uiState.preparing) {
IconButton( IconButton(
onClick = { onClick = {
onStopButtonClicked(model) onStopButtonClicked(model)

View file

@ -57,7 +57,7 @@ import androidx.compose.ui.platform.LocalClipboardManager
import androidx.compose.ui.text.AnnotatedString import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_LLM_USECASES import com.google.aiedge.gallery.data.TASK_LLM_PROMPT_LAB
import com.google.aiedge.gallery.ui.common.chat.MarkdownText import com.google.aiedge.gallery.ui.common.chat.MarkdownText
import com.google.aiedge.gallery.ui.common.chat.MessageBodyBenchmarkLlm import com.google.aiedge.gallery.ui.common.chat.MessageBodyBenchmarkLlm
import com.google.aiedge.gallery.ui.common.chat.MessageBodyLoading import com.google.aiedge.gallery.ui.common.chat.MessageBodyLoading
@ -76,11 +76,11 @@ fun ResponsePanel(
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
) { ) {
val task = TASK_LLM_USECASES val task = TASK_LLM_PROMPT_LAB
val uiState by viewModel.uiState.collectAsState() val uiState by viewModel.uiState.collectAsState()
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val inProgress = uiState.inProgress val inProgress = uiState.inProgress
val initializing = uiState.initializing val initializing = uiState.preparing
val selectedPromptTemplateType = uiState.selectedPromptTemplateType val selectedPromptTemplateType = uiState.selectedPromptTemplateType
val responseScrollState = rememberScrollState() val responseScrollState = rememberScrollState()
var selectedOptionIndex by remember { mutableIntStateOf(0) } var selectedOptionIndex by remember { mutableIntStateOf(0) }

View file

@ -70,7 +70,7 @@ fun ModelList(
) { ) {
// This is just to update "models" list when task.updateTrigger is updated so that the UI can // This is just to update "models" list when task.updateTrigger is updated so that the UI can
// be properly updated. // be properly updated.
val models by remember { val models by remember(task) {
derivedStateOf { derivedStateOf {
val trigger = task.updateTrigger.value val trigger = task.updateTrigger.value
if (trigger >= 0) { if (trigger >= 0) {
@ -80,7 +80,7 @@ fun ModelList(
} }
} }
} }
val importedModels by remember { val importedModels by remember(task) {
derivedStateOf { derivedStateOf {
val trigger = task.updateTrigger.value val trigger = task.updateTrigger.value
if (trigger >= 0) { if (trigger >= 0) {

View file

@ -37,15 +37,16 @@ import com.google.aiedge.gallery.data.ModelAllowlist
import com.google.aiedge.gallery.data.ModelDownloadStatus import com.google.aiedge.gallery.data.ModelDownloadStatus
import com.google.aiedge.gallery.data.ModelDownloadStatusType import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.TASKS import com.google.aiedge.gallery.data.TASKS
import com.google.aiedge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.aiedge.gallery.data.TASK_LLM_CHAT import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT import com.google.aiedge.gallery.data.TASK_LLM_PROMPT_LAB
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.TaskType import com.google.aiedge.gallery.data.TaskType
import com.google.aiedge.gallery.data.ValueType import com.google.aiedge.gallery.data.ValueType
import com.google.aiedge.gallery.data.getModelByName import com.google.aiedge.gallery.data.getModelByName
import com.google.aiedge.gallery.ui.common.AuthConfig import com.google.aiedge.gallery.ui.common.AuthConfig
import com.google.aiedge.gallery.ui.common.convertValueToTargetType import com.google.aiedge.gallery.ui.common.convertValueToTargetType
import com.google.aiedge.gallery.ui.common.getJsonResponse
import com.google.aiedge.gallery.ui.common.processTasks import com.google.aiedge.gallery.ui.common.processTasks
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper
@ -59,8 +60,6 @@ import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.update import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext import kotlinx.coroutines.withContext
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.json.Json
import net.openid.appauth.AuthorizationException import net.openid.appauth.AuthorizationException
import net.openid.appauth.AuthorizationRequest import net.openid.appauth.AuthorizationRequest
import net.openid.appauth.AuthorizationResponse import net.openid.appauth.AuthorizationResponse
@ -73,7 +72,7 @@ import java.net.URL
private const val TAG = "AGModelManagerViewModel" private const val TAG = "AGModelManagerViewModel"
private const val TEXT_INPUT_HISTORY_MAX_SIZE = 50 private const val TEXT_INPUT_HISTORY_MAX_SIZE = 50
private const val MODEL_ALLOWLIST_URL = private const val MODEL_ALLOWLIST_URL =
"https://raw.githubusercontent.com/jinjingforever/kokoro-codelab-jingjin/refs/heads/main/model_allowlist.json" "https://raw.githubusercontent.com/google-ai-edge/gallery/refs/heads/main/model_allowlist.json"
data class ModelInitializationStatus( data class ModelInitializationStatus(
val status: ModelInitializationStatusType, var error: String = "" val status: ModelInitializationStatusType, var error: String = ""
@ -116,11 +115,6 @@ data class ModelManagerUiState(
*/ */
val modelInitializationStatus: Map<String, ModelInitializationStatus>, val modelInitializationStatus: Map<String, ModelInitializationStatus>,
/**
* Whether Hugging Face models from the given community are currently being loaded.
*/
val loadingHfModels: Boolean = false,
/** /**
* Whether the app is loading and processing the model allowlist. * Whether the app is loading and processing the model allowlist.
*/ */
@ -196,8 +190,9 @@ open class ModelManagerViewModel(
) )
} }
fun cancelDownloadModel(model: Model) { fun cancelDownloadModel(task: Task, model: Model) {
downloadRepository.cancelDownloadModel(model) downloadRepository.cancelDownloadModel(model)
deleteModel(task = task, model = model)
} }
fun deleteModel(task: Task, model: Model) { fun deleteModel(task: Task, model: Model) {
@ -311,13 +306,13 @@ open class ModelManagerViewModel(
onDone = onDone, onDone = onDone,
) )
TaskType.LLM_USECASES -> LlmChatModelHelper.initialize( TaskType.LLM_PROMPT_LAB -> LlmChatModelHelper.initialize(
context = context, context = context,
model = model, model = model,
onDone = onDone, onDone = onDone,
) )
TaskType.LLM_IMAGE_TO_TEXT -> LlmChatModelHelper.initialize( TaskType.LLM_ASK_IMAGE -> LlmChatModelHelper.initialize(
context = context, context = context,
model = model, model = model,
onDone = onDone, onDone = onDone,
@ -341,8 +336,8 @@ open class ModelManagerViewModel(
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model) TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model)
TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.cleanUp(model = model) TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.cleanUp(model = model)
TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model) TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model)
TaskType.LLM_USECASES -> LlmChatModelHelper.cleanUp(model = model) TaskType.LLM_PROMPT_LAB -> LlmChatModelHelper.cleanUp(model = model)
TaskType.LLM_IMAGE_TO_TEXT -> LlmChatModelHelper.cleanUp(model = model) TaskType.LLM_ASK_IMAGE -> LlmChatModelHelper.cleanUp(model = model)
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.cleanUp(model = model) TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.cleanUp(model = model)
TaskType.TEST_TASK_1 -> {} TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {} TaskType.TEST_TASK_2 -> {}
@ -446,14 +441,14 @@ open class ModelManagerViewModel(
// Create model. // Create model.
val model = createModelFromImportedModelInfo(info = info) val model = createModelFromImportedModelInfo(info = info)
for (task in listOf(TASK_LLM_CHAT, TASK_LLM_USECASES, TASK_LLM_IMAGE_TO_TEXT)) { for (task in listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)) {
// Remove duplicated imported model if existed. // Remove duplicated imported model if existed.
val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported } val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported }
if (modelIndex >= 0) { if (modelIndex >= 0) {
Log.d(TAG, "duplicated imported model found in task. Removing it first") Log.d(TAG, "duplicated imported model found in task. Removing it first")
task.models.removeAt(modelIndex) task.models.removeAt(modelIndex)
} }
if (task == TASK_LLM_IMAGE_TO_TEXT && model.llmSupportImage || task != TASK_LLM_IMAGE_TO_TEXT) { if (task == TASK_LLM_ASK_IMAGE && model.llmSupportImage || task != TASK_LLM_ASK_IMAGE) {
task.models.add(model) task.models.add(model)
} }
task.updateTrigger.value = System.currentTimeMillis() task.updateTrigger.value = System.currentTimeMillis()
@ -502,7 +497,7 @@ open class ModelManagerViewModel(
// Check expiration (with 5-minute buffer). // Check expiration (with 5-minute buffer).
val curTs = System.currentTimeMillis() val curTs = System.currentTimeMillis()
val expirationTs = tokenData.expiresAtSeconds - 5 * 60 val expirationTs = tokenData.expiresAtMs - 5 * 60
Log.d( Log.d(
TAG, TAG,
"Checking whether token has expired or not. Current ts: $curTs, expires at: $expirationTs" "Checking whether token has expired or not. Current ts: $curTs, expires at: $expirationTs"
@ -562,7 +557,7 @@ open class ModelManagerViewModel(
} else { } else {
// Token exchange successful. Store the tokens securely // Token exchange successful. Store the tokens securely
Log.d(TAG, "Token exchange successful. Storing tokens...") Log.d(TAG, "Token exchange successful. Storing tokens...")
dataStoreRepository.saveAccessTokenData( saveAccessToken(
accessToken = tokenResponse.accessToken!!, accessToken = tokenResponse.accessToken!!,
refreshToken = tokenResponse.refreshToken!!, refreshToken = tokenResponse.refreshToken!!,
expiresAt = tokenResponse.accessTokenExpirationTime!! expiresAt = tokenResponse.accessTokenExpirationTime!!
@ -606,6 +601,18 @@ open class ModelManagerViewModel(
} }
} }
fun saveAccessToken(accessToken: String, refreshToken: String, expiresAt: Long) {
dataStoreRepository.saveAccessTokenData(
accessToken = accessToken,
refreshToken = refreshToken,
expiresAt = expiresAt,
)
}
fun clearAccessToken() {
dataStoreRepository.clearAccessTokenData()
}
private fun processPendingDownloads() { private fun processPendingDownloads() {
Log.d(TAG, "In-progress worker infos: $inProgressWorkInfos") Log.d(TAG, "In-progress worker infos: $inProgressWorkInfos")
@ -673,11 +680,11 @@ open class ModelManagerViewModel(
if (allowedModel.taskTypes.contains(TASK_LLM_CHAT.type.id)) { if (allowedModel.taskTypes.contains(TASK_LLM_CHAT.type.id)) {
TASK_LLM_CHAT.models.add(model) TASK_LLM_CHAT.models.add(model)
} }
if (allowedModel.taskTypes.contains(TASK_LLM_USECASES.type.id)) { if (allowedModel.taskTypes.contains(TASK_LLM_PROMPT_LAB.type.id)) {
TASK_LLM_USECASES.models.add(model) TASK_LLM_PROMPT_LAB.models.add(model)
} }
if (allowedModel.taskTypes.contains(TASK_LLM_IMAGE_TO_TEXT.type.id)) { if (allowedModel.taskTypes.contains(TASK_LLM_ASK_IMAGE.type.id)) {
TASK_LLM_IMAGE_TO_TEXT.models.add(model) TASK_LLM_ASK_IMAGE.models.add(model)
} }
} }
@ -732,9 +739,9 @@ open class ModelManagerViewModel(
// Add to task. // Add to task.
TASK_LLM_CHAT.models.add(model) TASK_LLM_CHAT.models.add(model)
TASK_LLM_USECASES.models.add(model) TASK_LLM_PROMPT_LAB.models.add(model)
if (model.llmSupportImage) { if (model.llmSupportImage) {
TASK_LLM_IMAGE_TO_TEXT.models.add(model) TASK_LLM_ASK_IMAGE.models.add(model)
} }
// Update status. // Update status.
@ -827,38 +834,6 @@ open class ModelManagerViewModel(
) )
} }
@OptIn(ExperimentalSerializationApi::class)
private inline fun <reified T> getJsonResponse(url: String): T? {
try {
val connection = URL(url).openConnection() as HttpURLConnection
connection.requestMethod = "GET"
connection.connect()
val responseCode = connection.responseCode
if (responseCode == HttpURLConnection.HTTP_OK) {
val inputStream = connection.inputStream
val response = inputStream.bufferedReader().use { it.readText() }
// Parse JSON using kotlinx.serialization
val json = Json {
// Handle potential extra fields
ignoreUnknownKeys = true
allowComments = true
allowTrailingComma = true
}
val jsonObj = json.decodeFromString<T>(response)
return jsonObj
} else {
Log.e(TAG, "HTTP error: $responseCode")
}
} catch (e: Exception) {
Log.e(TAG, "Error when getting json response: ${e.message}")
e.printStackTrace()
}
return null
}
private fun isFileInExternalFilesDir(fileName: String): Boolean { private fun isFileInExternalFilesDir(fileName: String): Boolean {
if (externalFilesDir != null) { if (externalFilesDir != null) {
val file = File(externalFilesDir, fileName) val file = File(externalFilesDir, fileName)

View file

@ -46,8 +46,8 @@ import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION
import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION
import com.google.aiedge.gallery.data.TASK_LLM_CHAT import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT import com.google.aiedge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.aiedge.gallery.data.TASK_LLM_USECASES import com.google.aiedge.gallery.data.TASK_LLM_PROMPT_LAB
import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION
import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.TaskType import com.google.aiedge.gallery.data.TaskType
@ -60,8 +60,8 @@ import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationDestination
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationScreen import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationScreen
import com.google.aiedge.gallery.ui.llmchat.LlmChatDestination import com.google.aiedge.gallery.ui.llmchat.LlmChatDestination
import com.google.aiedge.gallery.ui.llmchat.LlmChatScreen import com.google.aiedge.gallery.ui.llmchat.LlmChatScreen
import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextDestination import com.google.aiedge.gallery.ui.llmchat.LlmAskImageDestination
import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextScreen import com.google.aiedge.gallery.ui.llmchat.LlmAskImageScreen
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnDestination import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnDestination
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnScreen import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnScreen
import com.google.aiedge.gallery.ui.modelmanager.ModelManager import com.google.aiedge.gallery.ui.modelmanager.ModelManager
@ -233,7 +233,7 @@ fun GalleryNavHost(
enterTransition = { slideEnter() }, enterTransition = { slideEnter() },
exitTransition = { slideExit() }, exitTransition = { slideExit() },
) { ) {
getModelFromNavigationParam(it, TASK_LLM_USECASES)?.let { defaultModel -> getModelFromNavigationParam(it, TASK_LLM_PROMPT_LAB)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel) modelManagerViewModel.selectModel(defaultModel)
LlmSingleTurnScreen( LlmSingleTurnScreen(
@ -245,15 +245,15 @@ fun GalleryNavHost(
// LLM image to text. // LLM image to text.
composable( composable(
route = "${LlmImageToTextDestination.route}/{modelName}", route = "${LlmAskImageDestination.route}/{modelName}",
arguments = listOf(navArgument("modelName") { type = NavType.StringType }), arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() }, enterTransition = { slideEnter() },
exitTransition = { slideExit() }, exitTransition = { slideExit() },
) { ) {
getModelFromNavigationParam(it, TASK_LLM_IMAGE_TO_TEXT)?.let { defaultModel -> getModelFromNavigationParam(it, TASK_LLM_ASK_IMAGE)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel) modelManagerViewModel.selectModel(defaultModel)
LlmImageToTextScreen( LlmAskImageScreen(
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() }, navigateUp = { navController.navigateUp() },
) )
@ -287,8 +287,8 @@ fun navigateToTaskScreen(
TaskType.TEXT_CLASSIFICATION -> navController.navigate("${TextClassificationDestination.route}/${modelName}") TaskType.TEXT_CLASSIFICATION -> navController.navigate("${TextClassificationDestination.route}/${modelName}")
TaskType.IMAGE_CLASSIFICATION -> navController.navigate("${ImageClassificationDestination.route}/${modelName}") TaskType.IMAGE_CLASSIFICATION -> navController.navigate("${ImageClassificationDestination.route}/${modelName}")
TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}") TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
TaskType.LLM_IMAGE_TO_TEXT -> navController.navigate("${LlmImageToTextDestination.route}/${modelName}") TaskType.LLM_ASK_IMAGE -> navController.navigate("${LlmAskImageDestination.route}/${modelName}")
TaskType.LLM_USECASES -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}") TaskType.LLM_PROMPT_LAB -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
TaskType.IMAGE_GENERATION -> navController.navigate("${ImageGenerationDestination.route}/${modelName}") TaskType.IMAGE_GENERATION -> navController.navigate("${ImageGenerationDestination.route}/${modelName}")
TaskType.TEST_TASK_1 -> {} TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {} TaskType.TEST_TASK_2 -> {}

View file

@ -42,6 +42,9 @@ class PreviewDataStoreRepository : DataStoreRepository {
return null return null
} }
override fun clearAccessTokenData() {
}
override fun saveImportedModels(importedModels: List<ImportedModelInfo>) { override fun saveImportedModels(importedModels: List<ImportedModelInfo>) {
} }