diff --git a/Android/src/app/build.gradle.kts b/Android/src/app/build.gradle.kts index d21b009..091b60c 100644 --- a/Android/src/app/build.gradle.kts +++ b/Android/src/app/build.gradle.kts @@ -27,10 +27,10 @@ android { defaultConfig { applicationId = "com.google.aiedge.gallery" - minSdk = 24 + minSdk = 26 targetSdk = 35 versionCode = 1 - versionName = "20250428" + versionName = "0.9.0" // Needed for HuggingFace auth workflows. manifestPlaceholders["appAuthRedirectScheme"] = "com.google.aiedge.gallery.oauth" diff --git a/Android/src/app/src/main/AndroidManifest.xml b/Android/src/app/src/main/AndroidManifest.xml index 4d3c832..1f7f37e 100644 --- a/Android/src/app/src/main/AndroidManifest.xml +++ b/Android/src/app/src/main/AndroidManifest.xml @@ -41,7 +41,9 @@ android:name=".MainActivity" android:exported="true" android:theme="@style/Theme.Gallery.SplashScreen" - android:windowSoftInputMode="adjustResize"> + android:screenOrientation="portrait" + android:windowSoftInputMode="adjustResize" + tools:ignore="DiscouragedApi,LockedOrientationActivity"> diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/GalleryApp.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/GalleryApp.kt index 4cd0fa8..9fb9caa 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/GalleryApp.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/GalleryApp.kt @@ -68,7 +68,6 @@ fun GalleryTopAppBar( leftAction: AppBarAction? = null, rightAction: AppBarAction? = null, scrollBehavior: TopAppBarScrollBehavior? = null, - loadingHfModels: Boolean = false, subtitle: String = "", ) { CenterAlignedTopAppBar( @@ -151,28 +150,6 @@ fun GalleryTopAppBar( } } - // Click an icon to open "download manager". - AppBarActionType.DOWNLOAD_MANAGER -> { - if (loadingHfModels) { - CircularProgressIndicator( - trackColor = MaterialTheme.colorScheme.surfaceContainerHighest, - strokeWidth = 3.dp, - modifier = Modifier - .padding(end = 12.dp) - .size(20.dp) - ) - } -// else { -// IconButton(onClick = rightAction.actionFn) { -// Icon( -// imageVector = Deployed_code, -// contentDescription = "", -// tint = MaterialTheme.colorScheme.primary -// ) -// } -// } - } - AppBarActionType.MODEL_SELECTOR -> { Text("ms") } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/DataStoreRepository.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/DataStoreRepository.kt index d1cc28d..3046580 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/DataStoreRepository.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/DataStoreRepository.kt @@ -37,7 +37,7 @@ import javax.crypto.SecretKey data class AccessTokenData( val accessToken: String, val refreshToken: String, - val expiresAtSeconds: Long + val expiresAtMs: Long ) interface DataStoreRepository { @@ -46,6 +46,7 @@ interface DataStoreRepository { fun saveThemeOverride(theme: String) fun readThemeOverride(): String fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long) + fun clearAccessTokenData() fun readAccessTokenData(): AccessTokenData? fun saveImportedModels(importedModels: List) fun readImportedModels(): List @@ -135,6 +136,18 @@ class DefaultDataStoreRepository( } } + override fun clearAccessTokenData() { + return runBlocking { + dataStore.edit { preferences -> + preferences.remove(PreferencesKeys.ENCRYPTED_ACCESS_TOKEN) + preferences.remove(PreferencesKeys.ACCESS_TOKEN_IV) + preferences.remove(PreferencesKeys.ENCRYPTED_REFRESH_TOKEN) + preferences.remove(PreferencesKeys.REFRESH_TOKEN_IV) + preferences.remove(PreferencesKeys.ACCESS_TOKEN_EXPIRES_AT) + } + } + } + override fun readAccessTokenData(): AccessTokenData? { return runBlocking { val preferences = dataStore.data.first() diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/ModelAllowlist.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/ModelAllowlist.kt index 5d301e9..02b97c1 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/ModelAllowlist.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/ModelAllowlist.kt @@ -43,7 +43,7 @@ data class AllowedModel( // Config. val isLlmModel = - taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_USECASES.type.id) + taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_PROMPT_LAB.type.id) var configs: List = listOf() if (isLlmModel) { var defaultTopK: Int = DEFAULT_TOPK diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Tasks.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Tasks.kt index 6218046..64ce7a5 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Tasks.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Tasks.kt @@ -32,9 +32,9 @@ enum class TaskType(val label: String, val id: String) { TEXT_CLASSIFICATION(label = "Text Classification", id = "text_classification"), IMAGE_CLASSIFICATION(label = "Image Classification", id = "image_classification"), IMAGE_GENERATION(label = "Image Generation", id = "image_generation"), - LLM_CHAT(label = "LLM Chat", id = "llm_chat"), - LLM_USECASES(label = "LLM Use Cases", id = "llm_usecases"), - LLM_IMAGE_TO_TEXT(label = "LLM Image to Text", id = "llm_image_to_text"), + LLM_CHAT(label = "AI Chat", id = "llm_chat"), + LLM_PROMPT_LAB(label = "Prompt Lab", id = "llm_prompt_lab"), + LLM_ASK_IMAGE(label = "Ask Image", id = "llm_ask_image"), TEST_TASK_1(label = "Test task 1", id = "test_task_1"), TEST_TASK_2(label = "Test task 2", id = "test_task_2") @@ -93,7 +93,6 @@ val TASK_IMAGE_CLASSIFICATION = Task( val TASK_LLM_CHAT = Task( type = TaskType.LLM_CHAT, icon = Icons.Outlined.Forum, -// models = MODELS_LLM, models = mutableListOf(), description = "Chat with on-device large language models", docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", @@ -101,10 +100,9 @@ val TASK_LLM_CHAT = Task( textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat ) -val TASK_LLM_USECASES = Task( - type = TaskType.LLM_USECASES, +val TASK_LLM_PROMPT_LAB = Task( + type = TaskType.LLM_PROMPT_LAB, icon = Icons.Outlined.Widgets, -// models = MODELS_LLM, models = mutableListOf(), description = "Single turn use cases with on-device large language model", docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", @@ -112,10 +110,9 @@ val TASK_LLM_USECASES = Task( textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat ) -val TASK_LLM_IMAGE_TO_TEXT = Task( - type = TaskType.LLM_IMAGE_TO_TEXT, +val TASK_LLM_ASK_IMAGE = Task( + type = TaskType.LLM_ASK_IMAGE, icon = Icons.Outlined.Mms, -// models = MODELS_LLM, models = mutableListOf(), description = "Ask questions about images with on-device large language models", docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", @@ -135,12 +132,9 @@ val TASK_IMAGE_GENERATION = Task( /** All tasks. */ val TASKS: List = listOf( -// TASK_TEXT_CLASSIFICATION, -// TASK_IMAGE_CLASSIFICATION, -// TASK_IMAGE_GENERATION, - TASK_LLM_USECASES, + TASK_LLM_ASK_IMAGE, + TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT, - TASK_LLM_IMAGE_TO_TEXT ) fun getModelByName(name: String): Model? { diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/ViewModelProvider.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/ViewModelProvider.kt index bc3d4ab..cf22471 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/ViewModelProvider.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/ViewModelProvider.kt @@ -25,7 +25,7 @@ import com.google.aiedge.gallery.GalleryApplication import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationViewModel import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationViewModel import com.google.aiedge.gallery.ui.llmchat.LlmChatViewModel -import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextViewModel +import com.google.aiedge.gallery.ui.llmchat.LlmAskImageViewModel import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.textclassification.TextClassificationViewModel @@ -63,9 +63,9 @@ object ViewModelProvider { LlmSingleTurnViewModel() } - // Initializer for LlmImageToTextViewModel. + // Initializer for LlmAskImageViewModel. initializer { - LlmImageToTextViewModel() + LlmAskImageViewModel() } // Initializer for ImageGenerationViewModel. diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/DownloadAndTryButton.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/DownloadAndTryButton.kt index ba156ca..b2871db 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/DownloadAndTryButton.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/DownloadAndTryButton.kt @@ -26,6 +26,8 @@ import androidx.browser.customtabs.CustomTabsIntent import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.wrapContentHeight +import androidx.compose.foundation.text.BasicText +import androidx.compose.foundation.text.TextAutoSize import androidx.compose.material.icons.Icons import androidx.compose.material.icons.automirrored.rounded.ArrowForward import androidx.compose.material.icons.rounded.Error @@ -48,6 +50,7 @@ import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.unit.dp +import androidx.compose.ui.unit.sp import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel @@ -291,11 +294,32 @@ fun DownloadAndTryButton( modifier = Modifier.padding(end = 4.dp) ) + val textColor = MaterialTheme.colorScheme.onPrimary if (checkingToken) { - Text("Checking access...") + BasicText( + text = "Checking access...", + maxLines = 1, + color = { textColor }, + style = MaterialTheme.typography.bodyMedium, + autoSize = TextAutoSize.StepBased( + minFontSize = 8.sp, + maxFontSize = 14.sp, + stepSize = 1.sp + ) + ) } else { if (needToDownloadFirst) { - Text("Download & Try it", maxLines = 1) + BasicText( + text = "Download & Try", + maxLines = 1, + color = { textColor }, + style = MaterialTheme.typography.bodyMedium, + autoSize = TextAutoSize.StepBased( + minFontSize = 8.sp, + maxFontSize = 14.sp, + stepSize = 1.sp + ) + ) } else { Text("Try it", maxLines = 1) } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPageAppBar.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPageAppBar.kt index c9b17e3..8986bf1 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPageAppBar.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPageAppBar.kt @@ -63,6 +63,8 @@ fun ModelPageAppBar( modelManagerViewModel: ModelManagerViewModel, onBackClicked: () -> Unit, onModelSelected: (Model) -> Unit, + inProgress: Boolean, + modelPreparing: Boolean, modifier: Modifier = Modifier, isResettingSession: Boolean = false, onResetSessionClicked: (Model) -> Unit = {}, @@ -129,15 +131,16 @@ fun ModelPageAppBar( val isModelInitializing = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING if (showConfigButton) { + val enableConfigButton = !isModelInitializing && !inProgress IconButton( onClick = { showConfigDialog = true }, - enabled = !isModelInitializing, + enabled = enableConfigButton, modifier = Modifier .scale(0.75f) .offset(x = configButtonOffset) - .alpha(if (isModelInitializing) 0.5f else 1f) + .alpha(if (!enableConfigButton) 0.5f else 1f) ) { Icon( imageVector = Icons.Rounded.Tune, @@ -154,14 +157,15 @@ fun ModelPageAppBar( modifier = Modifier.size(16.dp) ) } else { + val enableResetButton = !isModelInitializing && !modelPreparing IconButton( onClick = { onResetSessionClicked(model) }, - enabled = !isModelInitializing, + enabled = enableResetButton, modifier = Modifier .scale(0.75f) - .alpha(if (isModelInitializing) 0.5f else 1f) + .alpha(if (!enableResetButton) 0.5f else 1f) ) { Icon( imageVector = Icons.Rounded.MapsUgc, diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPickerChipsPager.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPickerChipsPager.kt index b4636f5..160d454 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPickerChipsPager.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPickerChipsPager.kt @@ -83,8 +83,11 @@ fun ModelPickerChipsPager( val scope = rememberCoroutineScope() val density = LocalDensity.current val windowInfo = LocalWindowInfo.current - val screenWidthDp = - remember { with(density) { windowInfo.containerSize.width.toDp() } } + val screenWidthDp = remember { + with(density) { + windowInfo.containerSize.width.toDp() + } + } val pagerState = rememberPagerState(initialPage = task.models.indexOf(initialModel), pageCount = { task.models.size }) @@ -147,7 +150,7 @@ fun ModelPickerChipsPager( } Text( model.name, - style = MaterialTheme.typography.labelSmall, + style = MaterialTheme.typography.labelMedium, modifier = Modifier .padding(start = 4.dp) .widthIn(0.dp, screenWidthDp - 250.dp), diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/Utils.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/Utils.kt index da4d98d..665a073 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/Utils.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/Utils.kt @@ -21,6 +21,7 @@ import android.content.Context import android.content.pm.PackageManager import android.net.Uri import android.os.Build +import android.util.Log import androidx.activity.compose.ManagedActivityResultLauncher import androidx.compose.material3.MaterialTheme import androidx.compose.runtime.Composable @@ -39,7 +40,11 @@ import com.google.aiedge.gallery.ui.common.chat.Histogram import com.google.aiedge.gallery.ui.common.chat.Stat import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.theme.customColors +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.json.Json import java.io.File +import java.net.HttpURLConnection +import java.net.URL import kotlin.math.abs import kotlin.math.ln import kotlin.math.max @@ -487,4 +492,39 @@ fun processLlmResponse(response: String): String { newContent = newContent.replace("\\n", "\n") return newContent -} \ No newline at end of file +} + +@OptIn(ExperimentalSerializationApi::class) +inline fun getJsonResponse(url: String): T? { + try { + val connection = URL(url).openConnection() as HttpURLConnection + connection.requestMethod = "GET" + connection.connect() + + val responseCode = connection.responseCode + if (responseCode == HttpURLConnection.HTTP_OK) { + val inputStream = connection.inputStream + val response = inputStream.bufferedReader().use { it.readText() } + + // Parse JSON using kotlinx.serialization + val json = Json { + // Handle potential extra fields + ignoreUnknownKeys = true + allowComments = true + allowTrailingComma = true + } + val jsonObj = json.decodeFromString(response) + return jsonObj + } else { + Log.e("AGUtils", "HTTP error: $responseCode") + } + } catch (e: Exception) { + Log.e( + "AGUtils", + "Error when getting json response: ${e.message}" + ) + e.printStackTrace() + } + + return null +} diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt index ea1c650..5ce0f4b 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt @@ -16,9 +16,14 @@ package com.google.aiedge.gallery.ui.common.chat +import androidx.compose.animation.AnimatedContent +import androidx.compose.animation.ExperimentalSharedTransitionApi +import androidx.compose.animation.SharedTransitionLayout +import androidx.compose.foundation.Image import androidx.compose.foundation.background import androidx.compose.foundation.clickable import androidx.compose.foundation.gestures.detectTapGestures +import androidx.compose.foundation.gestures.detectTransformGestures import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column @@ -28,6 +33,7 @@ import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.ime import androidx.compose.foundation.layout.imePadding +import androidx.compose.foundation.layout.offset import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.size import androidx.compose.foundation.layout.width @@ -39,12 +45,15 @@ import androidx.compose.foundation.lazy.rememberLazyListState import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.material.icons.Icons import androidx.compose.material.icons.outlined.Timer +import androidx.compose.material.icons.rounded.Close import androidx.compose.material.icons.rounded.ContentCopy import androidx.compose.material.icons.rounded.Refresh import androidx.compose.material3.Button import androidx.compose.material3.Card import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.Icon +import androidx.compose.material3.IconButton +import androidx.compose.material3.IconButtonDefaults import androidx.compose.material3.MaterialTheme import androidx.compose.material3.ModalBottomSheet import androidx.compose.material3.SnackbarHost @@ -55,6 +64,7 @@ import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.MutableState import androidx.compose.runtime.collectAsState import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableFloatStateOf import androidx.compose.runtime.mutableIntStateOf import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.remember @@ -65,12 +75,16 @@ import androidx.compose.ui.Modifier import androidx.compose.ui.draw.clip import androidx.compose.ui.geometry.Offset import androidx.compose.ui.graphics.Color +import androidx.compose.ui.graphics.RectangleShape import androidx.compose.ui.graphics.asImageBitmap +import androidx.compose.ui.graphics.graphicsLayer import androidx.compose.ui.hapticfeedback.HapticFeedbackType import androidx.compose.ui.input.nestedscroll.NestedScrollConnection import androidx.compose.ui.input.nestedscroll.NestedScrollSource import androidx.compose.ui.input.nestedscroll.nestedScroll import androidx.compose.ui.input.pointer.pointerInput +import androidx.compose.ui.layout.ContentScale +import androidx.compose.ui.layout.onSizeChanged import androidx.compose.ui.platform.LocalClipboardManager import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalDensity @@ -80,6 +94,7 @@ import androidx.compose.ui.res.dimensionResource import androidx.compose.ui.res.stringResource import androidx.compose.ui.text.AnnotatedString import androidx.compose.ui.tooling.preview.Preview +import androidx.compose.ui.unit.IntSize import androidx.compose.ui.unit.dp import androidx.compose.ui.window.Dialog import com.google.aiedge.gallery.R @@ -102,7 +117,7 @@ enum class ChatInputType { /** * Composable function for the main chat panel, displaying messages and handling user input. */ -@OptIn(ExperimentalMaterial3Api::class) +@OptIn(ExperimentalMaterial3Api::class, ExperimentalSharedTransitionApi::class) @Composable fun ChatPanel( modelManagerViewModel: ModelManagerViewModel, @@ -127,6 +142,7 @@ fun ChatPanel( val snackbarHostState = remember { SnackbarHostState() } val scope = rememberCoroutineScope() val haptic = LocalHapticFeedback.current + var selectedImageMessage by remember { mutableStateOf(null) } var curMessage by remember { mutableStateOf("") } // Correct state val focusManager = LocalFocusManager.current @@ -200,291 +216,378 @@ fun ChatPanel( } } - val modelInitializationStatus = - modelManagerUiState.modelInitializationStatus[selectedModel.name] + val modelInitializationStatus = modelManagerUiState.modelInitializationStatus[selectedModel.name] LaunchedEffect(modelInitializationStatus) { showErrorDialog = modelInitializationStatus?.status == ModelInitializationStatusType.ERROR } - Column( - modifier = modifier.imePadding() - ) { - Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.weight(1f)) { - LazyColumn( - modifier = Modifier - .fillMaxSize() - .nestedScroll(nestedScrollConnection), - state = listState, verticalArrangement = Arrangement.Top, + SharedTransitionLayout(modifier = Modifier.fillMaxSize()) { + AnimatedContent(targetState = selectedImageMessage) { targetSelectedImageMessage -> + Column( + modifier = modifier.imePadding() ) { - items(messages) { message -> - val imageHistoryCurIndex = remember { mutableIntStateOf(0) } - var hAlign: Alignment.Horizontal = Alignment.End - var backgroundColor: Color = MaterialTheme.customColors.userBubbleBgColor - var hardCornerAtLeftOrRight = false - var extraPaddingStart = 48.dp - var extraPaddingEnd = 0.dp - if (message.side == ChatSide.AGENT) { - hAlign = Alignment.Start - backgroundColor = MaterialTheme.customColors.agentBubbleBgColor - hardCornerAtLeftOrRight = true - extraPaddingStart = 0.dp - extraPaddingEnd = 48.dp - } else if (message.side == ChatSide.SYSTEM) { - extraPaddingStart = 24.dp - extraPaddingEnd = 24.dp - if (message.type == ChatMessageType.PROMPT_TEMPLATES) { - extraPaddingStart = 12.dp - extraPaddingEnd = 12.dp - } - } - if (message.type == ChatMessageType.IMAGE) { - backgroundColor = Color.Transparent - } - val bubbleBorderRadius = dimensionResource(R.dimen.chat_bubble_corner_radius) - - Column( + Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.weight(1f)) { + LazyColumn( modifier = Modifier - .fillMaxWidth() - .padding( - start = 12.dp + extraPaddingStart, - end = 12.dp + extraPaddingEnd, - top = 6.dp, - bottom = 6.dp, - ), - horizontalAlignment = hAlign, + .fillMaxSize() + .nestedScroll(nestedScrollConnection), + state = listState, verticalArrangement = Arrangement.Top, ) { - // Sender row. - MessageSender( - message = message, - agentNameRes = task.agentNameRes, - imageHistoryCurIndex = imageHistoryCurIndex.intValue - ) + items(messages) { message -> + val imageHistoryCurIndex = remember { mutableIntStateOf(0) } + var hAlign: Alignment.Horizontal = Alignment.End + var backgroundColor: Color = MaterialTheme.customColors.userBubbleBgColor + var hardCornerAtLeftOrRight = false + var extraPaddingStart = 48.dp + var extraPaddingEnd = 0.dp + if (message.side == ChatSide.AGENT) { + hAlign = Alignment.Start + backgroundColor = MaterialTheme.customColors.agentBubbleBgColor + hardCornerAtLeftOrRight = true + extraPaddingStart = 0.dp + extraPaddingEnd = 48.dp + } else if (message.side == ChatSide.SYSTEM) { + extraPaddingStart = 24.dp + extraPaddingEnd = 24.dp + if (message.type == ChatMessageType.PROMPT_TEMPLATES) { + extraPaddingStart = 12.dp + extraPaddingEnd = 12.dp + } + } + if (message.type == ChatMessageType.IMAGE) { + backgroundColor = Color.Transparent + } + val bubbleBorderRadius = dimensionResource(R.dimen.chat_bubble_corner_radius) - // Message body. - when (message) { - // Loading. - is ChatMessageLoading -> MessageBodyLoading() + Column( + modifier = Modifier + .fillMaxWidth() + .padding( + start = 12.dp + extraPaddingStart, + end = 12.dp + extraPaddingEnd, + top = 6.dp, + bottom = 6.dp, + ), + horizontalAlignment = hAlign, + ) { + // Sender row. + MessageSender( + message = message, + agentNameRes = task.agentNameRes, + imageHistoryCurIndex = imageHistoryCurIndex.intValue + ) - // Info. - is ChatMessageInfo -> MessageBodyInfo(message = message) + // Message body. + when (message) { + // Loading. + is ChatMessageLoading -> MessageBodyLoading() - // Warning - is ChatMessageWarning -> MessageBodyWarning(message = message) + // Info. + is ChatMessageInfo -> MessageBodyInfo(message = message) - // Config values change. - is ChatMessageConfigValuesChange -> MessageBodyConfigUpdate(message = message) + // Warning + is ChatMessageWarning -> MessageBodyWarning(message = message) - // Prompt templates. - is ChatMessagePromptTemplates -> MessageBodyPromptTemplates(message = message, - task = task, - onPromptClicked = { template -> - onSendMessage( - selectedModel, - listOf(ChatMessageText(content = template.prompt, side = ChatSide.USER)) - ) - }) + // Config values change. + is ChatMessageConfigValuesChange -> MessageBodyConfigUpdate(message = message) - // Non-system messages. - else -> { - // The bubble shape around the message body. - var messageBubbleModifier = Modifier - .clip( - MessageBubbleShape( - radius = bubbleBorderRadius, - hardCornerAtLeftOrRight = hardCornerAtLeftOrRight - ) - ) - .background(backgroundColor) - if (message is ChatMessageText) { - messageBubbleModifier = messageBubbleModifier - .pointerInput(Unit) { - detectTapGestures( - onLongPress = { - haptic.performHapticFeedback(HapticFeedbackType.LongPress) - longPressedMessage.value = message - showMessageLongPressedSheet = true - }, + // Prompt templates. + is ChatMessagePromptTemplates -> MessageBodyPromptTemplates(message = message, + task = task, + onPromptClicked = { template -> + onSendMessage( + selectedModel, + listOf(ChatMessageText(content = template.prompt, side = ChatSide.USER)) ) + }) + + // Non-system messages. + else -> { + // The bubble shape around the message body. + var messageBubbleModifier = Modifier + .clip( + MessageBubbleShape( + radius = bubbleBorderRadius, + hardCornerAtLeftOrRight = hardCornerAtLeftOrRight + ) + ) + .background(backgroundColor) + if (message is ChatMessageText) { + messageBubbleModifier = messageBubbleModifier.pointerInput(Unit) { + detectTapGestures( + onLongPress = { + haptic.performHapticFeedback(HapticFeedbackType.LongPress) + longPressedMessage.value = message + showMessageLongPressedSheet = true + }, + ) + } } - } - Box( - modifier = messageBubbleModifier, - ) { - when (message) { - // Text - is ChatMessageText -> MessageBodyText(message = message) - - // Image - is ChatMessageImage -> MessageBodyImage(message = message) - - // Image with history (for image gen) - is ChatMessageImageWithHistory -> MessageBodyImageWithHistory( - message = message, imageHistoryCurIndex = imageHistoryCurIndex - ) - - // Classification result - is ChatMessageClassification -> MessageBodyClassification( - message = message, - modifier = Modifier.width(message.maxBarWidth ?: CLASSIFICATION_BAR_MAX_WIDTH) - ) - - // Benchmark result. - is ChatMessageBenchmarkResult -> MessageBodyBenchmark(message = message) - - // Benchmark LLM result. - is ChatMessageBenchmarkLlmResult -> MessageBodyBenchmarkLlm( - message = message, - modifier = Modifier.wrapContentWidth() - ) - - else -> {} - } - } - if (message.side == ChatSide.AGENT) { - Row( - verticalAlignment = Alignment.CenterVertically, - horizontalArrangement = Arrangement.spacedBy(8.dp), - ) { - LatencyText(message = message) - // A button to show stats for the LLM message. - if (task.type == TaskType.LLM_CHAT && message is ChatMessageText - // This means we only want to show the action button when the message is done - // generating, at which point the latency will be set. - && message.latencyMs >= 0 + Box( + modifier = messageBubbleModifier, ) { - val showingStats = - viewModel.isShowingStats(model = selectedModel, message = message) - MessageActionButton( - label = if (showingStats) "Hide stats" else "Show stats", - icon = Icons.Outlined.Timer, - onClick = { - // Toggle showing stats. - viewModel.toggleShowingStats(selectedModel, message) + when (message) { + // Text + is ChatMessageText -> MessageBodyText(message = message) - // Add the stats message after the LLM message. - if (viewModel.isShowingStats(model = selectedModel, message = message)) { - val llmBenchmarkResult = message.llmBenchmarkResult - if (llmBenchmarkResult != null) { - viewModel.insertMessageAfter( - model = selectedModel, - anchorMessage = message, - messageToAdd = llmBenchmarkResult, - ) - } - } - // Remove the stats message. - else { - val curMessageIndex = - viewModel.getMessageIndex(model = selectedModel, message = message) - viewModel.removeMessageAt( - model = selectedModel, - index = curMessageIndex + 1 + // Image + is ChatMessageImage -> { + if (targetSelectedImageMessage != message) { + MessageBodyImage( + message = message, + modifier = Modifier + .clickable { + selectedImageMessage = message + } + .sharedElement( + sharedContentState = rememberSharedContentState(key = "selected_image"), + animatedVisibilityScope = this@AnimatedContent + ), ) } - }, - enabled = !uiState.inProgress - ) - } - } - } else if (message.side == ChatSide.USER) { - Row( - verticalAlignment = Alignment.CenterVertically, - horizontalArrangement = Arrangement.spacedBy(4.dp) - ) { - // Run again button. - if (selectedModel.showRunAgainButton) { - MessageActionButton( - label = stringResource(R.string.run_again), - icon = Icons.Rounded.Refresh, - onClick = { - onRunAgainClicked(selectedModel, message) - }, - enabled = !uiState.inProgress - ) - } + } - // Benchmark button - if (selectedModel.showBenchmarkButton) { - MessageActionButton( - label = stringResource(R.string.benchmark), - icon = Icons.Outlined.Timer, - onClick = { - showBenchmarkConfigsDialog = true - benchmarkMessage.value = message - }, - enabled = !uiState.inProgress - ) + // Image with history (for image gen) + is ChatMessageImageWithHistory -> MessageBodyImageWithHistory( + message = message, imageHistoryCurIndex = imageHistoryCurIndex + ) + + // Classification result + is ChatMessageClassification -> MessageBodyClassification( + message = message, modifier = Modifier.width( + message.maxBarWidth ?: CLASSIFICATION_BAR_MAX_WIDTH + ) + ) + + // Benchmark result. + is ChatMessageBenchmarkResult -> MessageBodyBenchmark(message = message) + + // Benchmark LLM result. + is ChatMessageBenchmarkLlmResult -> MessageBodyBenchmarkLlm( + message = message, modifier = Modifier.wrapContentWidth() + ) + + else -> {} + } + } + if (message.side == ChatSide.AGENT) { + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(8.dp), + ) { + LatencyText(message = message) + // A button to show stats for the LLM message. + if ((task.type == TaskType.LLM_CHAT || task.type == TaskType.LLM_ASK_IMAGE) && message is ChatMessageText + // This means we only want to show the action button when the message is done + // generating, at which point the latency will be set. + && message.latencyMs >= 0 + ) { + val showingStats = + viewModel.isShowingStats(model = selectedModel, message = message) + MessageActionButton( + label = if (showingStats) "Hide stats" else "Show stats", + icon = Icons.Outlined.Timer, + onClick = { + // Toggle showing stats. + viewModel.toggleShowingStats(selectedModel, message) + + // Add the stats message after the LLM message. + if (viewModel.isShowingStats( + model = selectedModel, message = message + ) + ) { + val llmBenchmarkResult = message.llmBenchmarkResult + if (llmBenchmarkResult != null) { + viewModel.insertMessageAfter( + model = selectedModel, + anchorMessage = message, + messageToAdd = llmBenchmarkResult, + ) + } + } + // Remove the stats message. + else { + val curMessageIndex = + viewModel.getMessageIndex( + model = selectedModel, + message = message + ) + viewModel.removeMessageAt( + model = selectedModel, index = curMessageIndex + 1 + ) + } + }, + enabled = !uiState.inProgress + ) + } + } + } else if (message.side == ChatSide.USER) { + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(4.dp) + ) { + // Run again button. + if (selectedModel.showRunAgainButton) { + MessageActionButton( + label = stringResource(R.string.run_again), + icon = Icons.Rounded.Refresh, + onClick = { + onRunAgainClicked(selectedModel, message) + }, + enabled = !uiState.inProgress + ) + } + + // Benchmark button + if (selectedModel.showBenchmarkButton) { + MessageActionButton( + label = stringResource(R.string.benchmark), + icon = Icons.Outlined.Timer, + onClick = { + showBenchmarkConfigsDialog = true + benchmarkMessage.value = message + }, + enabled = !uiState.inProgress + ) + } + } } } } } } } + + SnackbarHost(hostState = snackbarHostState, modifier = Modifier.padding(vertical = 4.dp)) + + // Show an info message for ask image task to get users started. + if (task.type == TaskType.LLM_ASK_IMAGE && messages.isEmpty()) { + Column( + modifier = Modifier + .padding(horizontal = 16.dp) + .fillMaxSize(), + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.Center + ) { + MessageBodyInfo( + ChatMessageInfo(content = "To get started, click + below to add an image and type a prompt to ask a question about it."), + smallFontSize = false + ) + } + } + } + + // Chat input + when (chatInputType) { + ChatInputType.TEXT -> { +// val isLlmTask = task.type == TaskType.LLM_CHAT +// val notLlmStartScreen = !(messages.size == 1 && messages[0] is ChatMessagePromptTemplates) + val hasImageMessage = messages.any { it is ChatMessageImage } + MessageInputText( + modelManagerViewModel = modelManagerViewModel, + curMessage = curMessage, + inProgress = uiState.inProgress, + isResettingSession = uiState.isResettingSession, + modelPreparing = uiState.preparing, + hasImageMessage = hasImageMessage, + modelInitializing = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING, + textFieldPlaceHolderRes = task.textInputPlaceHolderRes, + onValueChanged = { curMessage = it }, + onSendMessage = { + onSendMessage(selectedModel, it) + curMessage = "" + }, + onOpenPromptTemplatesClicked = { + onSendMessage( + selectedModel, listOf( + ChatMessagePromptTemplates( + templates = selectedModel.llmPromptTemplates, showMakeYourOwn = false + ) + ) + ) + }, + onStopButtonClicked = onStopButtonClicked, +// showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen, + showPromptTemplatesInMenu = false, + showImagePickerInMenu = selectedModel.llmSupportImage, + showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress, + ) + } + + ChatInputType.IMAGE -> MessageInputImage( + disableButtons = uiState.inProgress, + streamingMessage = streamingMessage, + onImageSelected = { bitmap -> + onSendMessage( + selectedModel, listOf( + ChatMessageImage( + bitmap = bitmap, imageBitMap = bitmap.asImageBitmap(), side = ChatSide.USER + ) + ) + ) + }, + onStreamImage = { bitmap -> + onStreamImageMessage( + selectedModel, ChatMessageImage( + bitmap = bitmap, imageBitMap = bitmap.asImageBitmap(), side = ChatSide.USER + ) + ) + }, + onStreamEnd = onStreamEnd, + ) } } - SnackbarHost(hostState = snackbarHostState, modifier = Modifier.padding(vertical = 4.dp)) - } - - - // Chat input - when (chatInputType) { - ChatInputType.TEXT -> { -// val isLlmTask = task.type == TaskType.LLM_CHAT -// val notLlmStartScreen = !(messages.size == 1 && messages[0] is ChatMessagePromptTemplates) - val hasImageMessage = messages.any { it is ChatMessageImage } - MessageInputText( - modelManagerViewModel = modelManagerViewModel, - curMessage = curMessage, - inProgress = uiState.inProgress, - isResettingSession = uiState.isResettingSession, - hasImageMessage = hasImageMessage, - modelInitializing = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING, - textFieldPlaceHolderRes = task.textInputPlaceHolderRes, - onValueChanged = { curMessage = it }, - onSendMessage = { - onSendMessage(selectedModel, it) - curMessage = "" - }, - onOpenPromptTemplatesClicked = { - onSendMessage( - selectedModel, listOf( - ChatMessagePromptTemplates( - templates = selectedModel.llmPromptTemplates, showMakeYourOwn = false - ) - ) + // A full-screen image viewer. + if (targetSelectedImageMessage != null) { + ZoomableBox( + modifier = Modifier + .fillMaxSize() + .background(Color.Black.copy(alpha = 0.9f)) + .sharedElement( + rememberSharedContentState(key = "bounds"), + animatedVisibilityScope = this, ) - }, - onStopButtonClicked = onStopButtonClicked, -// showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen, - showPromptTemplatesInMenu = false, - showImagePickerInMenu = selectedModel.llmSupportImage == true, - showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress, - ) + .skipToLookaheadSize(), + ) { + // Image. + Image( + bitmap = targetSelectedImageMessage.imageBitMap, + contentDescription = "", + modifier = modifier + .fillMaxSize() + .graphicsLayer( + scaleX = scale, + scaleY = scale, + translationX = offsetX, + translationY = offsetY + ) + .sharedElement( + sharedContentState = rememberSharedContentState(key = "selected_image"), + animatedVisibilityScope = this@AnimatedContent + ), + contentScale = ContentScale.Fit, + ) + + // Close button. + IconButton( + onClick = { + selectedImageMessage = null + }, + colors = IconButtonDefaults.iconButtonColors( + containerColor = MaterialTheme.colorScheme.surfaceVariant, + ), + modifier = Modifier.offset(x = (-8).dp, y = 8.dp) + ) { + Icon( + Icons.Rounded.Close, + contentDescription = "", + tint = MaterialTheme.colorScheme.primary + ) + } + } } - - ChatInputType.IMAGE -> MessageInputImage( - disableButtons = uiState.inProgress, - streamingMessage = streamingMessage, - onImageSelected = { bitmap -> - onSendMessage( - selectedModel, listOf( - ChatMessageImage( - bitmap = bitmap, imageBitMap = bitmap.asImageBitmap(), side = ChatSide.USER - ) - ) - ) - }, - onStreamImage = { bitmap -> - onStreamImageMessage( - selectedModel, ChatMessageImage( - bitmap = bitmap, imageBitMap = bitmap.asImageBitmap(), side = ChatSide.USER - ) - ) - }, - onStreamEnd = onStreamEnd, - ) } } @@ -498,9 +601,7 @@ fun ChatPanel( ) { Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) { Column( - modifier = Modifier - .padding(20.dp), - verticalArrangement = Arrangement.spacedBy(16.dp) + modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp) ) { // Title Text( @@ -568,13 +669,10 @@ fun ChatPanel( Row( verticalAlignment = Alignment.CenterVertically, horizontalArrangement = Arrangement.spacedBy(6.dp), - modifier = Modifier - .padding(vertical = 8.dp, horizontal = 16.dp) + modifier = Modifier.padding(vertical = 8.dp, horizontal = 16.dp) ) { Icon( - Icons.Rounded.ContentCopy, - contentDescription = "", - modifier = Modifier.size(18.dp) + Icons.Rounded.ContentCopy, contentDescription = "", modifier = Modifier.size(18.dp) ) Text("Copy text") } @@ -586,6 +684,51 @@ fun ChatPanel( } } +@Composable +fun ZoomableBox( + modifier: Modifier = Modifier, + minScale: Float = 1f, + maxScale: Float = 5f, + content: @Composable ZoomableBoxScope.() -> Unit +) { + var scale by remember { mutableFloatStateOf(1f) } + var offsetX by remember { mutableFloatStateOf(0f) } + var offsetY by remember { mutableFloatStateOf(0f) } + var size by remember { mutableStateOf(IntSize.Zero) } + Box( + modifier = modifier + .clip(RectangleShape) + .onSizeChanged { size = it } + .pointerInput(Unit) { + detectTransformGestures { _, pan, zoom, _ -> + scale = maxOf(minScale, minOf(scale * zoom, maxScale)) + val maxX = (size.width * (scale - 1)) / 2 + val minX = -maxX + offsetX = maxOf(minX, minOf(maxX, offsetX + pan.x)) + val maxY = (size.height * (scale - 1)) / 2 + val minY = -maxY + offsetY = maxOf(minY, minOf(maxY, offsetY + pan.y)) + } + }, + contentAlignment = Alignment.TopEnd + ) { + val scope = ZoomableBoxScopeImpl(scale, offsetX, offsetY) + scope.content() + } +} + +interface ZoomableBoxScope { + val scale: Float + val offsetX: Float + val offsetY: Float +} + +private data class ZoomableBoxScopeImpl( + override val scale: Float, + override val offsetX: Float, + override val offsetY: Float +) : ZoomableBoxScope + @Preview(showBackground = true) @Composable fun ChatPanelPreview() { diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatView.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatView.kt index 91bb5f7..c580572 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatView.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatView.kt @@ -19,6 +19,7 @@ package com.google.aiedge.gallery.ui.common.chat import android.util.Log import androidx.activity.compose.BackHandler import androidx.compose.foundation.background +import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.fillMaxSize @@ -27,6 +28,7 @@ import androidx.compose.foundation.pager.HorizontalPager import androidx.compose.foundation.pager.rememberPagerState import androidx.compose.material3.MaterialTheme import androidx.compose.material3.Scaffold +import androidx.compose.material3.Text import androidx.compose.runtime.Composable import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.collectAsState @@ -36,10 +38,13 @@ import androidx.compose.runtime.remember import androidx.compose.runtime.rememberCoroutineScope import androidx.compose.runtime.setValue import androidx.compose.runtime.snapshotFlow +import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier +import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.graphicsLayer import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.tooling.preview.Preview +import androidx.compose.ui.unit.dp import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.ModelDownloadStatusType import com.google.aiedge.gallery.data.Task @@ -81,7 +86,7 @@ fun ChatView( chatInputType: ChatInputType = ChatInputType.TEXT, showStopButtonInInputWhenInProgress: Boolean = false, ) { - val uiStat by viewModel.uiState.collectAsState() + val uiState by viewModel.uiState.collectAsState() val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() val selectedModel = modelManagerUiState.selectedModel @@ -158,7 +163,9 @@ fun ChatView( model = selectedModel, modelManagerViewModel = modelManagerViewModel, showResetSessionButton = true, - isResettingSession = uiStat.isResettingSession, + isResettingSession = uiState.isResettingSession, + inProgress = uiState.inProgress, + modelPreparing = uiState.preparing, onResetSessionClicked = onResetSessionClicked, onConfigChanged = { old, new -> viewModel.addConfigChangedMessage( diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatViewModel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatViewModel.kt index c9a2cff..b186305 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatViewModel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatViewModel.kt @@ -38,6 +38,11 @@ data class ChatUiState( */ val isResettingSession: Boolean = false, + /** + * Indicates whether the model is preparing (before outputting any result and after initializing). + */ + val preparing: Boolean = false, + /** * A map of model names to lists of chat messages. */ @@ -204,6 +209,10 @@ open class ChatViewModel(val task: Task) : ViewModel() { _uiState.update { _uiState.value.copy(isResettingSession = isResettingSession) } } + fun setPreparing(preparing: Boolean) { + _uiState.update { _uiState.value.copy(preparing = preparing) } + } + fun addConfigChangedMessage( oldConfigValues: Map, newConfigValues: Map, model: Model ) { diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ConfigDialog.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ConfigDialog.kt index ed592f7..9450383 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ConfigDialog.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ConfigDialog.kt @@ -18,6 +18,8 @@ package com.google.aiedge.gallery.ui.common.chat import android.util.Log import androidx.compose.foundation.border +import androidx.compose.foundation.clickable +import androidx.compose.foundation.interaction.MutableInteractionSource import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column @@ -28,9 +30,11 @@ import androidx.compose.foundation.layout.height import androidx.compose.foundation.layout.offset import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.width +import androidx.compose.foundation.rememberScrollState import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.foundation.text.BasicTextField import androidx.compose.foundation.text.KeyboardOptions +import androidx.compose.foundation.verticalScroll import androidx.compose.material3.Button import androidx.compose.material3.Card import androidx.compose.material3.MaterialTheme @@ -53,6 +57,9 @@ import androidx.compose.ui.Modifier import androidx.compose.ui.focus.FocusRequester import androidx.compose.ui.focus.focusRequester import androidx.compose.ui.focus.onFocusChanged +import androidx.compose.ui.graphics.SolidColor +import androidx.compose.ui.platform.LocalFocusManager +import androidx.compose.ui.text.TextStyle import androidx.compose.ui.text.input.KeyboardType import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.unit.dp @@ -89,9 +96,20 @@ fun ConfigDialog( putAll(initialValues) } } + val interactionSource = remember { MutableInteractionSource() } Dialog(onDismissRequest = onDismissed) { - Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) { + val focusManager = LocalFocusManager.current + Card( + modifier = Modifier + .fillMaxWidth() + .clickable( + interactionSource = interactionSource, indication = null // Disable the ripple effect + ) { + focusManager.clearFocus() + }, + shape = RoundedCornerShape(16.dp) + ) { Column( modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp) ) { @@ -114,7 +132,14 @@ fun ConfigDialog( } // List of config rows. - ConfigEditorsPanel(configs = configs, values = values) + Column( + modifier = Modifier + .verticalScroll(rememberScrollState()) + .weight(1f, fill = false), + verticalArrangement = Arrangement.spacedBy(16.dp) + ) { + ConfigEditorsPanel(configs = configs, values = values) + } // Button row. Row( @@ -264,6 +289,8 @@ fun NumberSliderRow(config: NumberSliderConfig, values: SnapshotStateMap Box( modifier = Modifier.border( diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageBodyImage.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageBodyImage.kt index b5a380f..e69f24c 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageBodyImage.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageBodyImage.kt @@ -25,17 +25,18 @@ import androidx.compose.ui.layout.ContentScale import androidx.compose.ui.unit.dp @Composable -fun MessageBodyImage(message: ChatMessageImage) { +fun MessageBodyImage(message: ChatMessageImage, modifier: Modifier = Modifier) { val bitmapWidth = message.bitmap.width val bitmapHeight = message.bitmap.height val imageWidth = if (bitmapWidth >= bitmapHeight) 200 else (200f / bitmapHeight * bitmapWidth).toInt() val imageHeight = if (bitmapHeight >= bitmapWidth) 200 else (200f / bitmapWidth * bitmapHeight).toInt() + Image( bitmap = message.imageBitMap, contentDescription = "", - modifier = Modifier + modifier = modifier .height(imageHeight.dp) .width(imageWidth.dp), contentScale = ContentScale.Fit, diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageBodyInfo.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageBodyInfo.kt index 03c8e3e..844dce6 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageBodyInfo.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageBodyInfo.kt @@ -38,7 +38,7 @@ import com.google.aiedge.gallery.ui.theme.customColors * Supports markdown. */ @Composable -fun MessageBodyInfo(message: ChatMessageInfo) { +fun MessageBodyInfo(message: ChatMessageInfo, smallFontSize: Boolean = true) { Row( modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center ) { @@ -47,7 +47,11 @@ fun MessageBodyInfo(message: ChatMessageInfo) { .clip(RoundedCornerShape(16.dp)) .background(MaterialTheme.customColors.agentBubbleBgColor) ) { - MarkdownText(text = message.content, modifier = Modifier.padding(12.dp), smallFontSize = true) + MarkdownText( + text = message.content, + modifier = Modifier.padding(12.dp), + smallFontSize = smallFontSize + ) } } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageInputText.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageInputText.kt index a24c34f..7359457 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageInputText.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageInputText.kt @@ -103,6 +103,7 @@ fun MessageInputText( @StringRes textFieldPlaceHolderRes: Int, onValueChanged: (String) -> Unit, onSendMessage: (List) -> Unit, + modelPreparing: Boolean = false, onOpenPromptTemplatesClicked: () -> Unit = {}, onStopButtonClicked: () -> Unit = {}, showPromptTemplatesInMenu: Boolean = false, @@ -173,7 +174,7 @@ fun MessageInputText( .height(80.dp) .shadow(2.dp, shape = RoundedCornerShape(8.dp)) .clip(RoundedCornerShape(8.dp)) - .border(1.dp, MaterialTheme.colorScheme.outlineVariant, RoundedCornerShape(8.dp)), + .border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp)), ) Box(modifier = Modifier .offset(x = 10.dp, y = (-10).dp) @@ -219,7 +220,7 @@ fun MessageInputText( expanded = showAddContentMenu, onDismissRequest = { showAddContentMenu = false }) { if (showImagePickerInMenu) { - // Take a photo. + // Take a picture. DropdownMenuItem( text = { Row( @@ -227,7 +228,7 @@ fun MessageInputText( horizontalArrangement = Arrangement.spacedBy(6.dp) ) { Icon(Icons.Rounded.PhotoCamera, contentDescription = "") - Text("Take a photo") + Text("Take a picture") } }, enabled = pickedImages.isEmpty() && !hasImageMessage, @@ -321,7 +322,7 @@ fun MessageInputText( Spacer(modifier = Modifier.width(8.dp)) if (inProgress && showStopButtonWhenInProgress) { - if (!modelInitializing) { + if (!modelInitializing && !modelPreparing) { IconButton( onClick = onStopButtonClicked, colors = IconButtonDefaults.iconButtonColors( diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItem.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItem.kt index c7eb4b3..3738be2 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItem.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItem.kt @@ -20,8 +20,9 @@ import android.content.Intent import android.net.Uri import androidx.activity.compose.rememberLauncherForActivityResult import androidx.activity.result.contract.ActivityResultContracts -import androidx.compose.animation.core.animateFloatAsState -import androidx.compose.animation.core.tween +import androidx.compose.animation.AnimatedContent +import androidx.compose.animation.ExperimentalSharedTransitionApi +import androidx.compose.animation.SharedTransitionLayout import androidx.compose.foundation.background import androidx.compose.foundation.clickable import androidx.compose.foundation.interaction.MutableInteractionSource @@ -30,17 +31,13 @@ import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.fillMaxWidth -import androidx.compose.foundation.layout.heightIn -import androidx.compose.foundation.layout.offset import androidx.compose.foundation.layout.padding import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.material.icons.Icons import androidx.compose.material.icons.rounded.ChevronRight -import androidx.compose.material.icons.rounded.Settings import androidx.compose.material.icons.rounded.UnfoldLess import androidx.compose.material.icons.rounded.UnfoldMore import androidx.compose.material3.Icon -import androidx.compose.material3.IconButton import androidx.compose.material3.OutlinedButton import androidx.compose.material3.Text import androidx.compose.material3.ripple @@ -48,16 +45,12 @@ import androidx.compose.runtime.Composable import androidx.compose.runtime.collectAsState import androidx.compose.runtime.derivedStateOf import androidx.compose.runtime.getValue -import androidx.compose.runtime.movableContentOf -import androidx.compose.runtime.movableContentWithReceiverOf import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.remember import androidx.compose.runtime.setValue import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier -import androidx.compose.ui.draw.alpha import androidx.compose.ui.draw.clip -import androidx.compose.ui.layout.LookaheadScope import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.unit.Dp @@ -91,6 +84,7 @@ private val DEFAULT_VERTICAL_PADDING = 16.dp * model description and buttons for learning more (opening a URL) and downloading/trying * the model. */ +@OptIn(ExperimentalSharedTransitionApi::class) @Composable fun ModelItem( model: Model, @@ -117,181 +111,170 @@ fun ModelItem( var isExpanded by remember { mutableStateOf(false) } - // Animate alpha for model description and button rows when switching between layouts. - val alphaAnimation by animateFloatAsState( - targetValue = if (isExpanded) 1f else 0f, - animationSpec = tween(durationMillis = LAYOUT_ANIMATION_DURATION - 50) - ) - - LookaheadScope { - // Task icon. - val taskIcon = remember { - movableContentOf { - TaskIcon( - task = task, modifier = Modifier.animateLayout() - ) + var boxModifier = modifier + .fillMaxWidth() + .clip(RoundedCornerShape(size = 42.dp)) + .background( + getTaskBgColor(task) + ) + boxModifier = if (canExpand) { + boxModifier.clickable(onClick = { + if (!model.imported) { + isExpanded = !isExpanded + } else { + onModelClicked(model) } - } + }, interactionSource = remember { MutableInteractionSource() }, indication = ripple( + bounded = true, + radius = 1000.dp, + ) + ) + } else { + boxModifier + } - // Model name and status. - val modelNameAndStatus = remember { - movableContentOf { - ModelNameAndStatus( - model = model, - task = task, - downloadStatus = downloadStatus, - isExpanded = isExpanded, - modifier = Modifier.animateLayout() - ) - } - } - - val actionButton = remember { - movableContentOf { - ModelItemActionButton( - context = context, - model = model, - task = task, - modelManagerViewModel = modelManagerViewModel, - downloadStatus = downloadStatus, - onDownloadClicked = { model -> - checkNotificationPermissionAndStartDownload( - context = context, - launcher = launcher, - modelManagerViewModel = modelManagerViewModel, - task = task, - model = model + Box( + modifier = boxModifier, + contentAlignment = Alignment.Center, + ) { + SharedTransitionLayout { + AnimatedContent( + isExpanded, label = "item_layout_transition", + ) { targetState -> + val taskIcon = @Composable { + TaskIcon( + task = task, modifier = Modifier.sharedElement( + sharedContentState = rememberSharedContentState(key = "task_icon"), + animatedVisibilityScope = this@AnimatedContent, ) - }, - showDeleteButton = showDeleteButton, - showDownloadButton = false, - ) - } - } + ) + } - // Expand/collapse icon, or the config icon. - val expandButton = remember { - movableContentOf { - if (showConfigButtonIfExisted) { - if (downloadStatus?.status === ModelDownloadStatusType.SUCCEEDED) { - if (model.configs.isNotEmpty()) { - IconButton(onClick = onConfigClicked) { - Icon( - Icons.Rounded.Settings, - contentDescription = "", - tint = getTaskIconColor(task) - ) - } - } - } - } else { + val modelNameAndStatus = @Composable { + ModelNameAndStatus( + model = model, + task = task, + downloadStatus = downloadStatus, + isExpanded = isExpanded, + animatedVisibilityScope = this@AnimatedContent, + sharedTransitionScope = this@SharedTransitionLayout + ) + } + + val actionButton = @Composable { + ModelItemActionButton( + context = context, + model = model, + task = task, + modelManagerViewModel = modelManagerViewModel, + downloadStatus = downloadStatus, + onDownloadClicked = { model -> + checkNotificationPermissionAndStartDownload( + context = context, + launcher = launcher, + modelManagerViewModel = modelManagerViewModel, + task = task, + model = model + ) + }, + showDeleteButton = showDeleteButton, + showDownloadButton = false, + modifier = Modifier.sharedElement( + sharedContentState = rememberSharedContentState(key = "action_button"), + animatedVisibilityScope = this@AnimatedContent, + ) + ) + } + + val expandButton = @Composable { Icon( // For imported model, show ">" directly indicating users can just tap the model item to // go into it without needing to expand it first. if (model.imported) Icons.Rounded.ChevronRight else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore, contentDescription = "", tint = getTaskIconColor(task), + modifier = Modifier.sharedElement( + sharedContentState = rememberSharedContentState(key = "expand_button"), + animatedVisibilityScope = this@AnimatedContent, + ) ) } - } - } - // Model description shown in expanded layout. - val modelDescription = remember { - movableContentOf { m: Modifier -> - if (model.info.isNotEmpty()) { - MarkdownText( - model.info, - modifier = Modifier - .heightIn(min = 0.dp, max = if (isExpanded) 1000.dp else 0.dp) - .animateLayout() - .then(m) - ) + val description = @Composable { + if (model.info.isNotEmpty()) { + MarkdownText( + model.info, modifier = Modifier + .sharedElement( + sharedContentState = rememberSharedContentState(key = "description"), + animatedVisibilityScope = this@AnimatedContent, + ) + .skipToLookaheadSize() + ) + } } - } - } - // Button rows shown in expanded layout. - val buttonRows = remember { - movableContentOf { m: Modifier -> - Row( - modifier = Modifier - .heightIn(min = 0.dp, max = if (isExpanded) 1000.dp else 0.dp) - .animateLayout() - .then(m), - horizontalArrangement = Arrangement.spacedBy(12.dp), - ) { - // The "learn more" button. Click to show related urls in a bottom sheet. - if (model.learnMoreUrl.isNotEmpty()) { - OutlinedButton( - onClick = { - if (isExpanded) { - val intent = Intent(Intent.ACTION_VIEW, Uri.parse(model.learnMoreUrl)) - context.startActivity(intent) - } - }, + val buttonsRow = @Composable { + Row( + horizontalArrangement = Arrangement.spacedBy(12.dp), modifier = Modifier + .sharedElement( + sharedContentState = rememberSharedContentState(key = "buttons_row"), + animatedVisibilityScope = this@AnimatedContent, + ) + .skipToLookaheadSize() + ) { + // The "learn more" button. Click to show related urls in a bottom sheet. + if (model.learnMoreUrl.isNotEmpty()) { + OutlinedButton( + onClick = { + if (isExpanded) { + val intent = Intent(Intent.ACTION_VIEW, Uri.parse(model.learnMoreUrl)) + context.startActivity(intent) + } + }, + ) { + Text("Learn More", maxLines = 1) + } + } + + // Button to start the download and start the chat session with the model. + val needToDownloadFirst = + downloadStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED || downloadStatus?.status == ModelDownloadStatusType.FAILED + DownloadAndTryButton(task = task, + model = model, + enabled = isExpanded, + needToDownloadFirst = needToDownloadFirst, + modelManagerViewModel = modelManagerViewModel, + onClicked = { onModelClicked(model) }) + } + } + + // Collapsed state. + if (!targetState) { + Column( + horizontalAlignment = Alignment.CenterHorizontally, + ) { + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(12.dp), + modifier = Modifier + .fillMaxWidth() + .padding(start = 18.dp, end = 18.dp) + .padding(vertical = verticalSpacing) ) { - Text("Learn More", maxLines = 1) + // Icon at the left. + taskIcon() + // Model name and status at the center. + Row(modifier = Modifier.weight(1f)) { + modelNameAndStatus() + } + // Action button and expand/collapse button at the right. + Row(verticalAlignment = Alignment.CenterVertically) { + actionButton() + expandButton() + } } } - - // Button to start the download and start the chat session with the model. - val needToDownloadFirst = - downloadStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED || downloadStatus?.status == ModelDownloadStatusType.FAILED - DownloadAndTryButton( - task = task, - model = model, - enabled = isExpanded, - needToDownloadFirst = needToDownloadFirst, - modelManagerViewModel = modelManagerViewModel, - onClicked = { onModelClicked(model) } - ) - } - } - } - - val container = remember { - movableContentWithReceiverOf Unit> { content -> - Box( - modifier = Modifier.animateLayout(), - contentAlignment = Alignment.TopEnd, - ) { - content() - } - } - } - - var boxModifier = modifier - .fillMaxWidth() - .clip(RoundedCornerShape(size = 42.dp)) - .background( - getTaskBgColor(task) - ) - boxModifier = if (canExpand) { - boxModifier.clickable( - onClick = { - if (!model.imported) { - isExpanded = !isExpanded - } else { - onModelClicked(model) - } - }, - interactionSource = remember { MutableInteractionSource() }, - indication = ripple( - bounded = true, - radius = 500.dp, - ) - ) - } else { - boxModifier - } - Box( - modifier = boxModifier, - contentAlignment = Alignment.Center - ) { - if (isExpanded) { - container { - // The main part (icon, model name, status, etc) + } else { Column( verticalArrangement = Arrangement.spacedBy(14.dp), horizontalAlignment = Alignment.CenterHorizontally, @@ -300,7 +283,9 @@ fun ModelItem( .padding(vertical = verticalSpacing, horizontal = 18.dp) ) { Box(contentAlignment = Alignment.Center) { + // Icon at the top-center. taskIcon() + // Action button and expand/collapse button at the right. Row( modifier = Modifier.fillMaxWidth(), verticalAlignment = Alignment.CenterVertically, @@ -310,39 +295,12 @@ fun ModelItem( expandButton() } } + // Name and status below the icon. modelNameAndStatus() - modelDescription(Modifier.alpha(alphaAnimation)) - buttonRows(Modifier.alpha(alphaAnimation)) // Apply alpha here - } - } - } else { - container { - Column(horizontalAlignment = Alignment.CenterHorizontally) { - // The main part (icon, model name, status, etc) - Row( - verticalAlignment = Alignment.CenterVertically, - horizontalArrangement = Arrangement.spacedBy(12.dp), - modifier = Modifier - .fillMaxWidth() - .padding(start = 18.dp, end = 18.dp) - .padding(vertical = verticalSpacing) - ) { - taskIcon() - Row(modifier = Modifier.weight(1f)) { - modelNameAndStatus() - } - Row(verticalAlignment = Alignment.CenterVertically) { - actionButton() - expandButton() - } - } - Column( - modifier = Modifier.offset(y = 30.dp), - horizontalAlignment = Alignment.CenterHorizontally - ) { - modelDescription(Modifier.alpha(alphaAnimation)) - buttonRows(Modifier.alpha(alphaAnimation)) - } + // Description. + description() + // Buttons + buttonsRow() } } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItemActionButton.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItemActionButton.kt index 0fa31bd..d852d8a 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItemActionButton.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItemActionButton.kt @@ -108,7 +108,8 @@ fun ModelItemActionButton( // Button to cancel the download when it is in progress. ModelDownloadStatusType.IN_PROGRESS, ModelDownloadStatusType.UNZIPPING -> IconButton(onClick = { modelManagerViewModel.cancelDownloadModel( - model + task = task, + model = model ) }) { Icon( diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelNameAndStatus.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelNameAndStatus.kt index da435f9..53115bd 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelNameAndStatus.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelNameAndStatus.kt @@ -16,6 +16,9 @@ package com.google.aiedge.gallery.ui.common.modelitem +import androidx.compose.animation.AnimatedVisibilityScope +import androidx.compose.animation.ExperimentalSharedTransitionApi +import androidx.compose.animation.SharedTransitionScope import androidx.compose.animation.core.Animatable import androidx.compose.animation.core.tween import androidx.compose.foundation.layout.Column @@ -30,6 +33,7 @@ import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.remember import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier +import androidx.compose.ui.focus.focusModifier import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.unit.dp @@ -54,134 +58,164 @@ import com.google.aiedge.gallery.ui.theme.labelSmallNarrow * - "Unzipping..." status for unzipping processes. * - Model size for successful downloads. */ +@OptIn(ExperimentalSharedTransitionApi::class) @Composable fun ModelNameAndStatus( model: Model, task: Task, downloadStatus: ModelDownloadStatus?, isExpanded: Boolean, + sharedTransitionScope: SharedTransitionScope, + animatedVisibilityScope: AnimatedVisibilityScope, modifier: Modifier = Modifier ) { val inProgress = downloadStatus?.status == ModelDownloadStatusType.IN_PROGRESS val isPartiallyDownloaded = downloadStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED var curDownloadProgress = 0f - Column( - horizontalAlignment = if (isExpanded) Alignment.CenterHorizontally else Alignment.Start - ) { - // Model name. - Row( - verticalAlignment = Alignment.CenterVertically, + with(sharedTransitionScope) { + Column( + horizontalAlignment = if (isExpanded) Alignment.CenterHorizontally else Alignment.Start ) { - Text( - model.name, - maxLines = 1, - overflow = TextOverflow.MiddleEllipsis, - style = MaterialTheme.typography.titleMedium, - modifier = modifier, - ) - } - - Row(verticalAlignment = Alignment.CenterVertically) { - // Status icon. - if (!inProgress && !isPartiallyDownloaded) { - StatusIcon( - downloadStatus = downloadStatus, - modifier = modifier.padding(end = 4.dp) + // Model name. + Row( + verticalAlignment = Alignment.CenterVertically, + ) { + Text( + model.name, + maxLines = 1, + overflow = TextOverflow.MiddleEllipsis, + style = MaterialTheme.typography.titleMedium, + modifier = Modifier.sharedElement( + rememberSharedContentState(key = "model_name"), + animatedVisibilityScope = animatedVisibilityScope + ) ) } - // Failure message. - if (downloadStatus != null && downloadStatus.status == ModelDownloadStatusType.FAILED) { - Row(verticalAlignment = Alignment.CenterVertically) { - Text( - downloadStatus.errorMessage, - color = MaterialTheme.colorScheme.error, - style = labelSmallNarrow, - overflow = TextOverflow.Ellipsis, - modifier = modifier, + Row(verticalAlignment = Alignment.CenterVertically) { + // Status icon. + if (!inProgress && !isPartiallyDownloaded) { + StatusIcon( + downloadStatus = downloadStatus, + modifier = modifier + .padding(end = 4.dp) + .sharedElement( + rememberSharedContentState(key = "download_status_icon"), + animatedVisibilityScope = animatedVisibilityScope + ) ) } - } - // Status label - else { - var sizeLabel = model.totalBytes.humanReadableSize() - var fontSize = 11.sp - - // Populate the status label. - if (downloadStatus != null) { - // For in-progress model, show {receivedSize} / {totalSize} - {rate} - {remainingTime} - if (inProgress || isPartiallyDownloaded) { - var totalSize = downloadStatus.totalBytes - if (totalSize == 0L) { - totalSize = model.totalBytes - } - sizeLabel = - "${downloadStatus.receivedBytes.humanReadableSize(extraDecimalForGbAndAbove = true)} of ${totalSize.humanReadableSize()}" - if (downloadStatus.bytesPerSecond > 0) { - sizeLabel = - "$sizeLabel · ${downloadStatus.bytesPerSecond.humanReadableSize()} / s" - if (downloadStatus.remainingMs >= 0) { - sizeLabel = - "$sizeLabel\n${downloadStatus.remainingMs.formatToHourMinSecond()} left" - } - } - if (isPartiallyDownloaded) { - sizeLabel = "$sizeLabel (resuming...)" - } - curDownloadProgress = - downloadStatus.receivedBytes.toFloat() / downloadStatus.totalBytes.toFloat() - if (curDownloadProgress.isNaN()) { - curDownloadProgress = 0f - } - fontSize = 9.sp - } - // Status for unzipping. - else if (downloadStatus.status == ModelDownloadStatusType.UNZIPPING) { - sizeLabel = "Unzipping..." - } - } - - Column( - horizontalAlignment = if (isExpanded) Alignment.CenterHorizontally else Alignment.Start, - ) { - for ((index, line) in sizeLabel.split("\n").withIndex()) { + // Failure message. + if (downloadStatus != null && downloadStatus.status == ModelDownloadStatusType.FAILED) { + Row(verticalAlignment = Alignment.CenterVertically) { Text( - line, - color = MaterialTheme.colorScheme.secondary, - maxLines = 1, - style = labelSmallNarrow.copy(fontSize = fontSize, lineHeight = 10.sp), - textAlign = if (isExpanded) TextAlign.Center else TextAlign.Start, - overflow = TextOverflow.Visible, - modifier = modifier.offset(y = if (index == 0) 0.dp else (-1).dp) + downloadStatus.errorMessage, + color = MaterialTheme.colorScheme.error, + style = labelSmallNarrow, + overflow = TextOverflow.Ellipsis, + modifier = Modifier.sharedElement( + rememberSharedContentState(key = "failure_messsage"), + animatedVisibilityScope = animatedVisibilityScope + ) ) } } - } - } - // Download progress bar. - if (inProgress || isPartiallyDownloaded) { - val animatedProgress = remember { Animatable(0f) } - LinearProgressIndicator( - progress = { animatedProgress.value }, - color = getTaskIconColor(task = task), - trackColor = MaterialTheme.colorScheme.surfaceContainerHighest, - modifier = modifier.padding(top = 2.dp) - ) - LaunchedEffect(curDownloadProgress) { - animatedProgress.animateTo(curDownloadProgress, animationSpec = tween(150)) + // Status label + else { + var sizeLabel = model.totalBytes.humanReadableSize() + var fontSize = 11.sp + + // Populate the status label. + if (downloadStatus != null) { + // For in-progress model, show {receivedSize} / {totalSize} - {rate} - {remainingTime} + if (inProgress || isPartiallyDownloaded) { + var totalSize = downloadStatus.totalBytes + if (totalSize == 0L) { + totalSize = model.totalBytes + } + sizeLabel = + "${downloadStatus.receivedBytes.humanReadableSize(extraDecimalForGbAndAbove = true)} of ${totalSize.humanReadableSize()}" + if (downloadStatus.bytesPerSecond > 0) { + sizeLabel = + "$sizeLabel · ${downloadStatus.bytesPerSecond.humanReadableSize()} / s" + if (downloadStatus.remainingMs >= 0) { + sizeLabel = + "$sizeLabel\n${downloadStatus.remainingMs.formatToHourMinSecond()} left" + } + } + if (isPartiallyDownloaded) { + sizeLabel = "$sizeLabel (resuming...)" + } + curDownloadProgress = + downloadStatus.receivedBytes.toFloat() / downloadStatus.totalBytes.toFloat() + if (curDownloadProgress.isNaN()) { + curDownloadProgress = 0f + } + fontSize = 9.sp + } + // Status for unzipping. + else if (downloadStatus.status == ModelDownloadStatusType.UNZIPPING) { + sizeLabel = "Unzipping..." + } + } + + Column( + horizontalAlignment = if (isExpanded) Alignment.CenterHorizontally else Alignment.Start, + ) { + for ((index, line) in sizeLabel.split("\n").withIndex()) { + Text( + line, + color = MaterialTheme.colorScheme.secondary, + maxLines = 1, + style = labelSmallNarrow.copy(fontSize = fontSize, lineHeight = 10.sp), + textAlign = if (isExpanded) TextAlign.Center else TextAlign.Start, + overflow = TextOverflow.Visible, + modifier = Modifier + .offset(y = if (index == 0) 0.dp else (-1).dp) + .sharedElement( + rememberSharedContentState(key = "status_label_${index}"), + animatedVisibilityScope = animatedVisibilityScope + ) + ) + } + } + } + } + + // Download progress bar. + if (inProgress || isPartiallyDownloaded) { + val animatedProgress = remember { Animatable(0f) } + LinearProgressIndicator( + progress = { animatedProgress.value }, + color = getTaskIconColor(task = task), + trackColor = MaterialTheme.colorScheme.surfaceContainerHighest, + modifier = Modifier + .padding(top = 2.dp) + .sharedElement( + rememberSharedContentState(key = "download_progress_bar"), + animatedVisibilityScope = animatedVisibilityScope + ) + ) + LaunchedEffect(curDownloadProgress) { + animatedProgress.animateTo(curDownloadProgress, animationSpec = tween(150)) + } + } + // Unzipping progress. + else if (downloadStatus?.status == ModelDownloadStatusType.UNZIPPING) { + LinearProgressIndicator( + color = getTaskIconColor(task = task), + trackColor = MaterialTheme.colorScheme.surfaceContainerHighest, + modifier = Modifier + .padding(top = 2.dp) + .sharedElement( + rememberSharedContentState(key = "unzip_progress_bar"), + animatedVisibilityScope = animatedVisibilityScope + ) + ) } - } - // Unzipping progress. - else if (downloadStatus?.status == ModelDownloadStatusType.UNZIPPING) { - LinearProgressIndicator( - color = getTaskIconColor(task = task), - trackColor = MaterialTheme.colorScheme.surfaceContainerHighest, - modifier = Modifier - .padding(top = 2.dp), - ) } } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/HomeScreen.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/HomeScreen.kt index 0461118..f3c369d 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/HomeScreen.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/HomeScreen.kt @@ -16,8 +16,11 @@ package com.google.aiedge.gallery.ui.home +import android.app.Activity +import android.content.Context import android.content.Intent import android.net.Uri +import android.provider.OpenableColumns import android.util.Log import androidx.activity.compose.rememberLauncherForActivityResult import androidx.activity.result.ActivityResultLauncher @@ -49,6 +52,7 @@ import androidx.compose.material.icons.automirrored.outlined.NoteAdd import androidx.compose.material.icons.filled.Add import androidx.compose.material.icons.rounded.Error import androidx.compose.material3.AlertDialog +import androidx.compose.material3.Button import androidx.compose.material3.Card import androidx.compose.material3.CardDefaults import androidx.compose.material3.CircularProgressIndicator @@ -80,12 +84,18 @@ import androidx.compose.ui.draw.clip import androidx.compose.ui.draw.scale import androidx.compose.ui.graphics.Brush import androidx.compose.ui.input.nestedscroll.nestedScroll -import androidx.compose.ui.layout.layout -import androidx.compose.ui.platform.LocalConfiguration import androidx.compose.ui.platform.LocalContext +import androidx.compose.ui.platform.LocalDensity +import androidx.compose.ui.platform.LocalWindowInfo import androidx.compose.ui.res.stringResource +import androidx.compose.ui.text.LinkAnnotation +import androidx.compose.ui.text.SpanStyle +import androidx.compose.ui.text.TextLinkStyles +import androidx.compose.ui.text.buildAnnotatedString import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.text.style.TextAlign +import androidx.compose.ui.text.style.TextDecoration +import androidx.compose.ui.text.withLink import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.sp @@ -93,7 +103,6 @@ import com.google.aiedge.gallery.GalleryTopAppBar import com.google.aiedge.gallery.R import com.google.aiedge.gallery.data.AppBarAction import com.google.aiedge.gallery.data.AppBarActionType -import com.google.aiedge.gallery.data.ConfigKey import com.google.aiedge.gallery.data.ImportedModelInfo import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.ui.common.TaskIcon @@ -101,7 +110,6 @@ import com.google.aiedge.gallery.ui.common.getTaskBgColor import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel import com.google.aiedge.gallery.ui.theme.GalleryTheme -import com.google.aiedge.gallery.ui.theme.ThemeSettings import com.google.aiedge.gallery.ui.theme.customColors import com.google.aiedge.gallery.ui.theme.titleMediumNarrow import kotlinx.coroutines.delay @@ -109,6 +117,12 @@ import kotlinx.coroutines.launch private const val TAG = "AGHomeScreen" private const val TASK_COUNT_ANIMATION_DURATION = 250 +private const val MAX_TASK_CARD_PADDING = 24 +private const val MIN_TASK_CARD_PADDING = 18 +private const val MAX_TASK_CARD_RADIUS = 43.5 +private const val MIN_TASK_CARD_RADIUS = 30 +private const val MAX_TASK_CARD_ICON_SIZE = 56 +private const val MIN_TASK_CARD_ICON_SIZE = 50 /** Navigation destination data */ object HomeScreenDestination { @@ -127,6 +141,7 @@ fun HomeScreen( val uiState by modelManagerViewModel.uiState.collectAsState() var showSettingsDialog by remember { mutableStateOf(false) } var showImportModelSheet by remember { mutableStateOf(false) } + var showUnsupportedFileTypeDialog by remember { mutableStateOf(false) } val sheetState = rememberModalBottomSheetState() var showImportDialog by remember { mutableStateOf(false) } var showImportingDialog by remember { mutableStateOf(false) } @@ -135,17 +150,21 @@ fun HomeScreen( val coroutineScope = rememberCoroutineScope() val snackbarHostState = remember { SnackbarHostState() } val scope = rememberCoroutineScope() - - val nonEmptyTasks = uiState.tasks.filter { it.models.size > 0 } - val loadingHfModels = uiState.loadingHfModels + val context = LocalContext.current val filePickerLauncher: ActivityResultLauncher = rememberLauncherForActivityResult( contract = ActivityResultContracts.StartActivityForResult() ) { result -> if (result.resultCode == android.app.Activity.RESULT_OK) { result.data?.data?.let { uri -> - selectedLocalModelFileUri.value = uri - showImportDialog = true + val fileName = getFileName(context = context, uri = uri) + Log.d(TAG, "Selected file: $fileName") + if (fileName != null && !fileName.endsWith(".task")) { + showUnsupportedFileTypeDialog = true + } else { + selectedLocalModelFileUri.value = uri + showImportDialog = true + } } ?: run { Log.d(TAG, "No file selected or URI is null.") } @@ -154,36 +173,29 @@ fun HomeScreen( } } - Scaffold( - modifier = modifier.nestedScroll(scrollBehavior.nestedScrollConnection), - topBar = { - GalleryTopAppBar( - title = stringResource(HomeScreenDestination.titleRes), - rightAction = AppBarAction( - actionType = AppBarActionType.APP_SETTING, actionFn = { - showSettingsDialog = true - } - ), - loadingHfModels = loadingHfModels, - scrollBehavior = scrollBehavior, - ) - }, - floatingActionButton = { - // A floating action button to show "import model" bottom sheet. - SmallFloatingActionButton( - onClick = { - showImportModelSheet = true - }, - containerColor = MaterialTheme.colorScheme.secondaryContainer, - contentColor = MaterialTheme.colorScheme.secondary, - ) { - Icon(Icons.Filled.Add, "") - } + Scaffold(modifier = modifier.nestedScroll(scrollBehavior.nestedScrollConnection), topBar = { + GalleryTopAppBar( + title = stringResource(HomeScreenDestination.titleRes), + rightAction = AppBarAction(actionType = AppBarActionType.APP_SETTING, actionFn = { + showSettingsDialog = true + }), + scrollBehavior = scrollBehavior, + ) + }, floatingActionButton = { + // A floating action button to show "import model" bottom sheet. + SmallFloatingActionButton( + onClick = { + showImportModelSheet = true + }, + containerColor = MaterialTheme.colorScheme.secondaryContainer, + contentColor = MaterialTheme.colorScheme.secondary, + ) { + Icon(Icons.Filled.Add, "") } - ) { innerPadding -> + }) { innerPadding -> Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.fillMaxSize()) { TaskList( - tasks = nonEmptyTasks, + tasks = uiState.tasks, navigateToTaskScreen = navigateToTaskScreen, loadingModelAllowlist = uiState.loadingModelAllowlist, modifier = Modifier.fillMaxSize(), @@ -198,16 +210,8 @@ fun HomeScreen( if (showSettingsDialog) { SettingsDialog( curThemeOverride = modelManagerViewModel.readThemeOverride(), + modelManagerViewModel = modelManagerViewModel, onDismissed = { showSettingsDialog = false }, - onOk = { curConfigValues -> - // Update theme settings. - // This will update app's theme. - val themeOverride = curConfigValues[ConfigKey.THEME.label] as String - ThemeSettings.themeOverride.value = themeOverride - - // Save to data store. - modelManagerViewModel.saveThemeOverride(themeOverride) - }, ) } @@ -232,10 +236,6 @@ fun HomeScreen( val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply { addCategory(Intent.CATEGORY_OPENABLE) type = "*/*" - putExtra( - Intent.EXTRA_MIME_TYPES, - arrayOf("application/x-binary", "application/octet-stream") - ) // Single select. putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false) } @@ -259,13 +259,11 @@ fun HomeScreen( // Import dialog if (showImportDialog) { selectedLocalModelFileUri.value?.let { uri -> - ModelImportDialog(uri = uri, - onDismiss = { showImportDialog = false }, - onDone = { info -> - selectedImportedModelInfo.value = info - showImportDialog = false - showImportingDialog = true - }) + ModelImportDialog(uri = uri, onDismiss = { showImportDialog = false }, onDone = { info -> + selectedImportedModelInfo.value = info + showImportDialog = false + showImportingDialog = true + }) } } @@ -273,8 +271,7 @@ fun HomeScreen( if (showImportingDialog) { selectedLocalModelFileUri.value?.let { uri -> selectedImportedModelInfo.value?.let { info -> - ModelImportingDialog( - uri = uri, + ModelImportingDialog(uri = uri, info = info, onDismiss = { showImportingDialog = false }, onDone = { @@ -292,6 +289,22 @@ fun HomeScreen( } } + // Alert dialog for unsupported file type. + if (showUnsupportedFileTypeDialog) { + AlertDialog( + onDismissRequest = { showUnsupportedFileTypeDialog = false }, + title = { Text("Unsupported file type") }, + text = { + Text("Only \".task\" file type is supported.") + }, + confirmButton = { + Button(onClick = { showUnsupportedFileTypeDialog = false }) { + Text(stringResource(R.string.ok)) + } + }, + ) + } + if (uiState.loadingModelAllowlistError.isNotEmpty()) { AlertDialog( icon = { @@ -307,11 +320,9 @@ fun HomeScreen( modelManagerViewModel.loadModelAllowlist() }, confirmButton = { - TextButton( - onClick = { - modelManagerViewModel.loadModelAllowlist() - } - ) { + TextButton(onClick = { + modelManagerViewModel.loadModelAllowlist() + }) { Text("Retry") } }, @@ -327,6 +338,38 @@ private fun TaskList( modifier: Modifier = Modifier, contentPadding: PaddingValues = PaddingValues(0.dp), ) { + val density = LocalDensity.current + val windowInfo = LocalWindowInfo.current + val screenWidthDp = remember { + with(density) { + windowInfo.containerSize.width.toDp() + } + } + val screenHeightDp = remember { + with(density) { + windowInfo.containerSize.height.toDp() + } + } + val sizeFraction = remember { ((screenWidthDp - 360.dp) / (410.dp - 360.dp)).coerceIn(0f, 1f) } + val linkColor = MaterialTheme.customColors.linkColor + + val introText = buildAnnotatedString { + append("Welcome to AI Edge Gallery! Explore a world of \namazing on-device models from ") + withLink( + link = LinkAnnotation.Url( + url = "https://huggingface.co/litert-community", // Replace with the actual URL + styles = TextLinkStyles( + style = SpanStyle( + color = linkColor, + textDecoration = TextDecoration.Underline, + ) + ) + ) + ) { + append("LiteRT community") + } + } + Box(modifier = modifier.fillMaxSize()) { LazyVerticalGrid( columns = GridCells.Fixed(count = 2), @@ -335,10 +378,15 @@ private fun TaskList( horizontalArrangement = Arrangement.spacedBy(8.dp), verticalArrangement = Arrangement.spacedBy(8.dp), ) { + // New rel + item(key = "newReleaseNotification", span = { GridItemSpan(2) }) { + NewReleaseNotification() + } + // Headline. item(key = "headline", span = { GridItemSpan(2) }) { Text( - "Welcome to AI Edge Gallery! Explore a world of \namazing on-device models from LiteRT community", + introText, textAlign = TextAlign.Center, style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold), modifier = Modifier.padding(bottom = 20.dp) @@ -364,14 +412,21 @@ private fun TaskList( } } } else { - // Cards. + // LLM Cards. + item(key = "llmCardsHeader", span = { GridItemSpan(2) }) { + Text( + "Example LLM Use Cases", + style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.SemiBold), + color = MaterialTheme.colorScheme.onSurfaceVariant, + modifier = Modifier.padding(bottom = 4.dp) + ) + } + items(tasks) { task -> TaskCard( - task = task, - onClick = { + sizeFraction = sizeFraction, task = task, onClick = { navigateToTaskScreen(task) - }, - modifier = Modifier + }, modifier = Modifier .fillMaxWidth() .aspectRatio(1f) ) @@ -388,7 +443,7 @@ private fun TaskList( Box( modifier = Modifier .fillMaxWidth() - .height(LocalConfiguration.current.screenHeightDp.dp * 0.25f) + .height(screenHeightDp * 0.25f) .background( Brush.verticalGradient( colors = MaterialTheme.customColors.homeBottomGradient, @@ -400,7 +455,15 @@ private fun TaskList( } @Composable -private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modifier) { +private fun TaskCard( + task: Task, onClick: () -> Unit, sizeFraction: Float, modifier: Modifier = Modifier +) { + val padding = + (MAX_TASK_CARD_PADDING - MIN_TASK_CARD_PADDING) * sizeFraction + MIN_TASK_CARD_PADDING + val radius = (MAX_TASK_CARD_RADIUS - MIN_TASK_CARD_RADIUS) * sizeFraction + MIN_TASK_CARD_RADIUS + val iconSize = + (MAX_TASK_CARD_ICON_SIZE - MIN_TASK_CARD_ICON_SIZE) * sizeFraction + MIN_TASK_CARD_ICON_SIZE + // Observes the model count and updates the model count label with a fade-in/fade-out animation // whenever the count changes. val modelCount by remember { @@ -445,7 +508,7 @@ private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modif Card( modifier = modifier - .clip(RoundedCornerShape(43.5.dp)) + .clip(RoundedCornerShape(radius.dp)) .clickable( onClick = onClick, ), @@ -456,39 +519,24 @@ private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modif Column( modifier = Modifier .fillMaxSize() - .padding(24.dp), + .padding(padding.dp), ) { // Icon. - TaskIcon(task = task) + TaskIcon(task = task, width = iconSize.dp) - Spacer(modifier = Modifier.weight(1f)) + Spacer(modifier = Modifier.weight(2f)) // Title. - val pair = task.type.label.splitByFirstSpace() Text( - pair.first, + task.type.label, color = MaterialTheme.colorScheme.primary, style = titleMediumNarrow.copy( fontSize = 20.sp, fontWeight = FontWeight.Bold, ), ) - if (pair.second.isNotEmpty()) { - Text( - pair.second, - color = MaterialTheme.colorScheme.primary, - style = titleMediumNarrow.copy( - fontSize = 18.sp, - fontWeight = FontWeight.Bold, - ), - modifier = Modifier.layout { measurable, constraints -> - val placeable = measurable.measure(constraints) - layout(placeable.width, placeable.height) { - placeable.placeRelative(0, -4.dp.roundToPx()) - } - } - ) - } + + Spacer(modifier = Modifier.weight(1f)) // Model count. Text( @@ -503,12 +551,21 @@ private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modif } } -private fun String.splitByFirstSpace(): Pair { - val spaceIndex = this.indexOf(' ') - if (spaceIndex == -1) { - return Pair(this, "") +// Helper function to get the file name from a URI +fun getFileName(context: Context, uri: Uri): String? { + if (uri.scheme == "content") { + context.contentResolver.query(uri, null, null, null, null)?.use { cursor -> + if (cursor.moveToFirst()) { + val nameIndex = cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME) + if (nameIndex != -1) { + return cursor.getString(nameIndex) + } + } + } + } else if (uri.scheme == "file") { + return uri.lastPathSegment } - return Pair(this.substring(0, spaceIndex), this.substring(spaceIndex + 1)) + return null } @Preview diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/ModelImportDialog.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/ModelImportDialog.kt index c5e9733..7708d19 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/ModelImportDialog.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/ModelImportDialog.kt @@ -22,12 +22,16 @@ import android.provider.OpenableColumns import android.util.Log import androidx.compose.animation.core.Animatable import androidx.compose.animation.core.tween +import androidx.compose.foundation.clickable +import androidx.compose.foundation.interaction.MutableInteractionSource import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.rememberScrollState import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.foundation.verticalScroll import androidx.compose.material.icons.Icons import androidx.compose.material.icons.rounded.Error import androidx.compose.material3.Button @@ -51,6 +55,7 @@ import androidx.compose.runtime.snapshots.SnapshotStateMap import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.platform.LocalContext +import androidx.compose.ui.platform.LocalFocusManager import androidx.compose.ui.unit.dp import androidx.compose.ui.window.Dialog import androidx.compose.ui.window.DialogProperties @@ -82,41 +87,34 @@ import java.nio.charset.StandardCharsets private const val TAG = "AGModelImportDialog" private val IMPORT_CONFIGS_LLM: List = listOf( - LabelConfig(key = ConfigKey.NAME), - LabelConfig(key = ConfigKey.MODEL_TYPE), - NumberSliderConfig( + LabelConfig(key = ConfigKey.NAME), LabelConfig(key = ConfigKey.MODEL_TYPE), NumberSliderConfig( key = ConfigKey.DEFAULT_MAX_TOKENS, sliderMin = 100f, sliderMax = 1024f, defaultValue = DEFAULT_MAX_TOKEN.toFloat(), valueType = ValueType.INT - ), - NumberSliderConfig( + ), NumberSliderConfig( key = ConfigKey.DEFAULT_TOPK, sliderMin = 5f, sliderMax = 40f, defaultValue = DEFAULT_TOPK.toFloat(), valueType = ValueType.INT - ), - NumberSliderConfig( + ), NumberSliderConfig( key = ConfigKey.DEFAULT_TOPP, sliderMin = 0.0f, sliderMax = 1.0f, defaultValue = DEFAULT_TOPP, valueType = ValueType.FLOAT - ), - NumberSliderConfig( + ), NumberSliderConfig( key = ConfigKey.DEFAULT_TEMPERATURE, sliderMin = 0.0f, sliderMax = 2.0f, defaultValue = DEFAULT_TEMPERATURE, valueType = ValueType.FLOAT - ), - BooleanSwitchConfig( + ), BooleanSwitchConfig( key = ConfigKey.SUPPORT_IMAGE, defaultValue = false, - ), - SegmentedButtonConfig( + ), SegmentedButtonConfig( key = ConfigKey.COMPATIBLE_ACCELERATORS, defaultValue = Accelerator.CPU.label, options = listOf(Accelerator.CPU.label, Accelerator.GPU.label), @@ -126,9 +124,7 @@ private val IMPORT_CONFIGS_LLM: List = listOf( @Composable fun ModelImportDialog( - uri: Uri, - onDismiss: () -> Unit, - onDone: (ImportedModelInfo) -> Unit + uri: Uri, onDismiss: () -> Unit, onDone: (ImportedModelInfo) -> Unit ) { val context = LocalContext.current val info = remember { getFileSizeAndDisplayNameFromUri(context = context, uri = uri) } @@ -150,15 +146,23 @@ fun ModelImportDialog( putAll(initialValues) } } + val interactionSource = remember { MutableInteractionSource() } Dialog( onDismissRequest = onDismiss, ) { - Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) { + val focusManager = LocalFocusManager.current + Card( + modifier = Modifier + .fillMaxWidth() + .clickable( + interactionSource = interactionSource, indication = null // Disable the ripple effect + ) { + focusManager.clearFocus() + }, shape = RoundedCornerShape(16.dp) + ) { Column( - modifier = Modifier - .padding(20.dp), - verticalArrangement = Arrangement.spacedBy(16.dp) + modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp) ) { // Title. Text( @@ -167,11 +171,18 @@ fun ModelImportDialog( modifier = Modifier.padding(bottom = 8.dp) ) - // Default configs for users to set. - ConfigEditorsPanel( - configs = IMPORT_CONFIGS_LLM, - values = values, - ) + Column( + modifier = Modifier + .verticalScroll(rememberScrollState()) + .weight(1f, fill = false), + verticalArrangement = Arrangement.spacedBy(16.dp) + ) { + // Default configs for users to set. + ConfigEditorsPanel( + configs = IMPORT_CONFIGS_LLM, + values = values, + ) + } // Button row. Row( @@ -210,10 +221,7 @@ fun ModelImportDialog( @Composable fun ModelImportingDialog( - uri: Uri, - info: ImportedModelInfo, - onDismiss: () -> Unit, - onDone: (ImportedModelInfo) -> Unit + uri: Uri, info: ImportedModelInfo, onDismiss: () -> Unit, onDone: (ImportedModelInfo) -> Unit ) { var error by remember { mutableStateOf("") } val context = LocalContext.current @@ -222,8 +230,7 @@ fun ModelImportingDialog( LaunchedEffect(Unit) { // Import. - importModel( - context = context, + importModel(context = context, coroutineScope = coroutineScope, fileName = info.fileName, fileSize = info.fileSize, @@ -236,8 +243,7 @@ fun ModelImportingDialog( }, onError = { error = it - } - ) + }) } Dialog( @@ -246,9 +252,7 @@ fun ModelImportingDialog( ) { Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) { Column( - modifier = Modifier - .padding(20.dp), - verticalArrangement = Arrangement.spacedBy(16.dp) + modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp) ) { // Title. Text( @@ -280,13 +284,10 @@ fun ModelImportingDialog( // Has error. else { Row( - verticalAlignment = Alignment.Top, - horizontalArrangement = Arrangement.spacedBy(6.dp) + verticalAlignment = Alignment.Top, horizontalArrangement = Arrangement.spacedBy(6.dp) ) { Icon( - Icons.Rounded.Error, - contentDescription = "", - tint = MaterialTheme.colorScheme.error + Icons.Rounded.Error, contentDescription = "", tint = MaterialTheme.colorScheme.error ) Text( error, diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/NewReleaseNotification.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/NewReleaseNotification.kt new file mode 100644 index 0000000..231a92b --- /dev/null +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/NewReleaseNotification.kt @@ -0,0 +1,121 @@ +package com.google.aiedge.gallery.ui.home + +import android.util.Log +import androidx.compose.animation.AnimatedVisibility +import androidx.compose.animation.expandVertically +import androidx.compose.animation.fadeIn +import androidx.compose.foundation.background +import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.shape.CircleShape +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.automirrored.rounded.OpenInNew +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Text +import androidx.compose.runtime.Composable +import androidx.compose.runtime.LaunchedEffect +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember +import androidx.compose.runtime.setValue +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.clip +import androidx.compose.ui.unit.dp +import com.google.aiedge.gallery.BuildConfig +import com.google.aiedge.gallery.ui.common.getJsonResponse +import com.google.aiedge.gallery.ui.modelmanager.ClickableLink +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext +import kotlinx.serialization.Serializable +import kotlin.math.max + +private const val TAG = "AGNewReleaseNotification" +private const val REPO = "google-ai-edge/gallery" + +@Serializable +data class ReleaseInfo( + val html_url: String, + val tag_name: String, +) + +@Composable +fun NewReleaseNotification() { + var newReleaseVersion by remember { mutableStateOf("") } + var newReleaseUrl by remember { mutableStateOf("") } + + LaunchedEffect(Unit) { + withContext(Dispatchers.IO) { + Log.d("AGNewReleaseNotification", "Checking for new release...") + val info = getJsonResponse("https://api.github.com/repos/$REPO/releases/latest") + if (info != null) { + val curRelease = BuildConfig.VERSION_NAME + val newRelease = info.tag_name + val isNewer = isNewerRelease(currentRelease = curRelease, newRelease = newRelease) + Log.d(TAG, "curRelease: $curRelease, newRelease: $newRelease, isNewer: $isNewer") + if (isNewer) { + newReleaseVersion = newRelease + newReleaseUrl = info.html_url + } + } + } + } + + AnimatedVisibility( + visible = newReleaseVersion.isNotEmpty(), + enter = fadeIn() + expandVertically() + ) { + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.SpaceBetween, + modifier = Modifier + .padding(horizontal = 16.dp) + .padding(bottom = 12.dp) + .clip( + CircleShape + ) + .background(MaterialTheme.colorScheme.tertiaryContainer) + .padding(4.dp) + ) { + Text( + "New release $newReleaseVersion available", + style = MaterialTheme.typography.bodyMedium, + modifier = Modifier.padding(start = 12.dp) + ) + Row( + modifier = Modifier.padding(end = 12.dp), + verticalAlignment = Alignment.CenterVertically + ) { + ClickableLink( + url = newReleaseUrl, + linkText = "View", + icon = Icons.AutoMirrored.Rounded.OpenInNew, + ) + } + } + } +} + +fun isNewerRelease(currentRelease: String, newRelease: String): Boolean { + // Split the version strings into their individual components (e.g., "0.9.0" -> ["0", "9", "0"]) + val currentComponents = currentRelease.split('.').map { it.toIntOrNull() ?: 0 } + val newComponents = newRelease.split('.').map { it.toIntOrNull() ?: 0 } + + // Determine the maximum number of components to iterate through + val maxComponents = max(currentComponents.size, newComponents.size) + + // Iterate through the components from left to right (major, minor, patch, etc.) + for (i in 0 until maxComponents) { + val currentComponent = currentComponents.getOrElse(i) { 0 } + val newComponent = newComponents.getOrElse(i) { 0 } + + if (newComponent > currentComponent) { + return true + } else if (newComponent < currentComponent) { + return false + } + } + + return false +} diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/SettingsDialog.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/SettingsDialog.kt index f8b4607..e5becbe 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/SettingsDialog.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/SettingsDialog.kt @@ -16,45 +16,266 @@ package com.google.aiedge.gallery.ui.home +import androidx.compose.foundation.border +import androidx.compose.foundation.clickable +import androidx.compose.foundation.interaction.MutableInteractionSource +import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.Box +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.height +import androidx.compose.foundation.layout.offset +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.wrapContentHeight +import androidx.compose.foundation.rememberScrollState +import androidx.compose.foundation.shape.CircleShape +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.foundation.text.BasicTextField +import androidx.compose.foundation.verticalScroll +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.rounded.CheckCircle +import androidx.compose.material3.Button +import androidx.compose.material3.Card +import androidx.compose.material3.Icon +import androidx.compose.material3.IconButton +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.MultiChoiceSegmentedButtonRow +import androidx.compose.material3.OutlinedButton +import androidx.compose.material3.SegmentedButton +import androidx.compose.material3.SegmentedButtonDefaults +import androidx.compose.material3.Text import androidx.compose.runtime.Composable +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember +import androidx.compose.runtime.setValue +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.focus.FocusRequester +import androidx.compose.ui.focus.focusRequester +import androidx.compose.ui.focus.onFocusChanged +import androidx.compose.ui.graphics.SolidColor +import androidx.compose.ui.platform.LocalFocusManager +import androidx.compose.ui.text.TextStyle +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.unit.dp +import androidx.compose.ui.window.Dialog import com.google.aiedge.gallery.BuildConfig -import com.google.aiedge.gallery.data.Config -import com.google.aiedge.gallery.data.ConfigKey -import com.google.aiedge.gallery.data.SegmentedButtonConfig -import com.google.aiedge.gallery.ui.common.chat.ConfigDialog +import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.theme.THEME_AUTO import com.google.aiedge.gallery.ui.theme.THEME_DARK import com.google.aiedge.gallery.ui.theme.THEME_LIGHT +import com.google.aiedge.gallery.ui.theme.ThemeSettings +import com.google.aiedge.gallery.ui.theme.labelSmallNarrow +import java.time.Instant +import java.time.ZoneId +import java.time.format.DateTimeFormatter +import java.util.Locale +import kotlin.math.min -private val CONFIGS: List = listOf( - SegmentedButtonConfig( - key = ConfigKey.THEME, - defaultValue = THEME_AUTO, - options = listOf(THEME_AUTO, THEME_LIGHT, THEME_DARK), - ) -) +private val THEME_OPTIONS = listOf(THEME_AUTO, THEME_LIGHT, THEME_DARK) @Composable fun SettingsDialog( curThemeOverride: String, + modelManagerViewModel: ModelManagerViewModel, onDismissed: () -> Unit, - onOk: (Map) -> Unit, ) { - val initialValues = mapOf( - ConfigKey.THEME.label to curThemeOverride - ) - ConfigDialog( - title = "Settings", - subtitle = "App version: ${BuildConfig.VERSION_NAME}", - okBtnLabel = "OK", - configs = CONFIGS, - initialValues = initialValues, - onDismissed = onDismissed, - onOk = { curConfigValues -> - onOk(curConfigValues) + var selectedTheme by remember { mutableStateOf(curThemeOverride) } + var hfToken by remember { mutableStateOf(modelManagerViewModel.getTokenStatusAndData().data) } + val dateFormatter = remember { + DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss").withZone(ZoneId.systemDefault()) + .withLocale(Locale.getDefault()) + } + var customHfToken by remember { mutableStateOf("") } + var isFocused by remember { mutableStateOf(false) } + val focusRequester = remember { FocusRequester() } + val interactionSource = remember { MutableInteractionSource() } - // Hide config dialog. - onDismissed() - }, - ) + Dialog(onDismissRequest = onDismissed) { + val focusManager = LocalFocusManager.current + Card( + modifier = Modifier + .fillMaxWidth() + .clickable( + interactionSource = interactionSource, indication = null // Disable the ripple effect + ) { + focusManager.clearFocus() + }, shape = RoundedCornerShape(16.dp) + ) { + Column( + modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp) + ) { + // Dialog title and subtitle. + Column { + Text( + "Settings", + style = MaterialTheme.typography.titleLarge, + modifier = Modifier.padding(bottom = 8.dp) + ) + // Subtitle. + Text( + "App version: ${BuildConfig.VERSION_NAME}", + style = labelSmallNarrow, + color = MaterialTheme.colorScheme.onSurfaceVariant, + modifier = Modifier.offset(y = (-6).dp) + ) + } + + Column( + modifier = Modifier + .verticalScroll(rememberScrollState()) + .weight(1f, fill = false), + verticalArrangement = Arrangement.spacedBy(16.dp) + ) { + // Theme switcher. + Column( + modifier = Modifier.fillMaxWidth() + ) { + Text( + "Theme", + style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.Bold) + ) + MultiChoiceSegmentedButtonRow { + THEME_OPTIONS.forEachIndexed { index, label -> + SegmentedButton(shape = SegmentedButtonDefaults.itemShape( + index = index, count = THEME_OPTIONS.size + ), onCheckedChange = { + selectedTheme = label + + // Update theme settings. + // This will update app's theme. + ThemeSettings.themeOverride.value = label + + // Save to data store. + modelManagerViewModel.saveThemeOverride(label) + }, checked = label == selectedTheme, label = { Text(label) }) + } + } + } + + // HF Token management. + Column( + modifier = Modifier.fillMaxWidth(), verticalArrangement = Arrangement.spacedBy(4.dp) + ) { + Text( + "HuggingFace access token", + style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.Bold) + ) + // Show the start of the token. + val curHfToken = hfToken + if (curHfToken != null) { + Text( + curHfToken.accessToken.substring(0, min(16, curHfToken.accessToken.length)) + "...", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + Text( + "Expired at: ${dateFormatter.format(Instant.ofEpochMilli(curHfToken.expiresAtMs))}", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } else { + Text( + "Not available", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + Text( + "The token will be automatically retrieved when a gated model is downloaded", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + Row(horizontalArrangement = Arrangement.spacedBy(4.dp)) { + OutlinedButton( + onClick = { + modelManagerViewModel.clearAccessToken() + hfToken = null + }, enabled = curHfToken != null + ) { + Text("Clear") + } + BasicTextField( + value = customHfToken, + singleLine = true, + modifier = Modifier + .fillMaxWidth() + .padding(top = 4.dp) + .focusRequester(focusRequester) + .onFocusChanged { + isFocused = it.isFocused + }, + onValueChange = { + customHfToken = it + }, + textStyle = TextStyle(color = MaterialTheme.colorScheme.onSurface), + cursorBrush = SolidColor(MaterialTheme.colorScheme.onSurface), + ) { innerTextField -> + Box( + modifier = Modifier + .border( + width = if (isFocused) 2.dp else 1.dp, + color = if (isFocused) MaterialTheme.colorScheme.primary else MaterialTheme.colorScheme.outline, + shape = CircleShape, + ) + .height(40.dp), contentAlignment = Alignment.CenterStart + ) { + Row(verticalAlignment = Alignment.CenterVertically) { + Box( + modifier = Modifier + .padding(start = 16.dp) + .weight(1f) + ) { + if (customHfToken.isEmpty()) { + Text( + "Enter token manually", + color = MaterialTheme.colorScheme.onSurfaceVariant, + style = MaterialTheme.typography.bodySmall + ) + } + innerTextField() + } + if (customHfToken.isNotEmpty()) { + IconButton( + modifier = Modifier.offset(x = 1.dp), + onClick = { + modelManagerViewModel.saveAccessToken( + accessToken = customHfToken, + refreshToken = "", + expiresAt = System.currentTimeMillis() + 1000L * 60 * 60 * 24 * 365 * 10, + ) + hfToken = modelManagerViewModel.getTokenStatusAndData().data + }) { + Icon(Icons.Rounded.CheckCircle, contentDescription = "") + } + } + } + } + } + } + } + } + + + // Button row. + Row( + modifier = Modifier + .fillMaxWidth() + .padding(top = 8.dp), + horizontalArrangement = Arrangement.End, + ) { + // Close button + Button( + onClick = { + onDismissed() + }, + ) { + Text("Close") + } + } + } + } + } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatScreen.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatScreen.kt index 740e4a1..2080ad7 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatScreen.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatScreen.kt @@ -34,9 +34,9 @@ object LlmChatDestination { val route = "LlmChatRoute" } -object LlmImageToTextDestination { +object LlmAskImageDestination { @Serializable - val route = "LlmImageToTextRoute" + val route = "LlmAskImageRoute" } @Composable @@ -57,11 +57,11 @@ fun LlmChatScreen( } @Composable -fun LlmImageToTextScreen( +fun LlmAskImageScreen( modelManagerViewModel: ModelManagerViewModel, navigateUp: () -> Unit, modifier: Modifier = Modifier, - viewModel: LlmImageToTextViewModel = viewModel( + viewModel: LlmAskImageViewModel = viewModel( factory = ViewModelProvider.Factory ), ) { @@ -129,12 +129,7 @@ fun ChatViewWrapper( }) } }, - onBenchmarkClicked = { model, message, warmUpIterations, benchmarkIterations -> - if (message is ChatMessageText) { - viewModel.benchmark( - model = model, message = message - ) - } + onBenchmarkClicked = { _, _, _, _ -> }, onResetSessionClicked = { model -> viewModel.resetSession(model = model) diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatViewModel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatViewModel.kt index 3f4a2c6..91fabb3 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatViewModel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatViewModel.kt @@ -22,7 +22,7 @@ import android.util.Log import androidx.lifecycle.viewModelScope import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.TASK_LLM_CHAT -import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT +import com.google.aiedge.gallery.data.TASK_LLM_ASK_IMAGE import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult import com.google.aiedge.gallery.ui.common.chat.ChatMessageLoading @@ -49,6 +49,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task fun generateResponse(model: Model, input: String, image: Bitmap? = null, onError: () -> Unit) { viewModelScope.launch(Dispatchers.Default) { setInProgress(true) + setPreparing(true) // Loading. addMessage( @@ -88,6 +89,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task timeToFirstToken = (firstTokenTs - start) / 1000f prefillSpeed = prefillTokens / timeToFirstToken firstRun = false + setPreparing(false) } else { decodeTokens++ } @@ -137,10 +139,12 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task }, cleanUpListener = { setInProgress(false) + setPreparing(false) }) } catch (e: Exception) { Log.e(TAG, "Error occurred while running inference", e) setInProgress(false) + setPreparing(false) onError() } } @@ -194,98 +198,6 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task } } - fun benchmark(model: Model, message: ChatMessageText) { - viewModelScope.launch(Dispatchers.Default) { - setInProgress(true) - - // Wait for model to be initialized. - while (model.instance == null) { - delay(100) - } - val instance = model.instance as LlmModelInstance - val prefillTokens = instance.session.sizeInTokens(message.content) - - // Add the message to show benchmark results. - val benchmarkLlmResult = ChatMessageBenchmarkLlmResult( - orderedStats = STATS, - statValues = mutableMapOf(), - running = true, - latencyMs = -1f, - ) - addMessage(model = model, message = benchmarkLlmResult) - - // Run inference. - val result = StringBuilder() - var firstRun = true - var timeToFirstToken = 0f - var firstTokenTs = 0L - var decodeTokens = 0 - var prefillSpeed = 0f - var decodeSpeed: Float - val start = System.currentTimeMillis() - var lastUpdateTime = 0L - LlmChatModelHelper.runInference(model = model, - input = message.content, - resultListener = { partialResult, done -> - val curTs = System.currentTimeMillis() - - if (firstRun) { - firstTokenTs = System.currentTimeMillis() - timeToFirstToken = (firstTokenTs - start) / 1000f - prefillSpeed = prefillTokens / timeToFirstToken - firstRun = false - - // Update message to show prefill speed. - replaceLastMessage( - model = model, - message = ChatMessageBenchmarkLlmResult( - orderedStats = STATS, - statValues = mutableMapOf( - "prefill_speed" to prefillSpeed, - "time_to_first_token" to timeToFirstToken, - "latency" to (curTs - start).toFloat() / 1000f, - ), - running = false, - latencyMs = -1f, - ), - type = ChatMessageType.BENCHMARK_LLM_RESULT, - ) - } else { - decodeTokens++ - } - result.append(partialResult) - - if (curTs - lastUpdateTime > 500 || done) { - decodeSpeed = decodeTokens / ((curTs - firstTokenTs) / 1000f) - if (decodeSpeed.isNaN()) { - decodeSpeed = 0f - } - replaceLastMessage( - model = model, message = ChatMessageBenchmarkLlmResult( - orderedStats = STATS, - statValues = mutableMapOf( - "prefill_speed" to prefillSpeed, - "decode_speed" to decodeSpeed, - "time_to_first_token" to timeToFirstToken, - "latency" to (curTs - start).toFloat() / 1000f, - ), - running = !done, - latencyMs = -1f, - ), type = ChatMessageType.BENCHMARK_LLM_RESULT - ) - lastUpdateTime = curTs - - if (done) { - setInProgress(false) - } - } - }, - cleanUpListener = { - setInProgress(false) - }) - } - } - fun handleError( context: Context, model: Model, @@ -320,15 +232,8 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task ) // Re-generate the response automatically. - generateResponse(model = model, input = triggeredMessage.content, onError = { - handleError( - context = context, - model = model, - modelManagerViewModel = modelManagerViewModel, - triggeredMessage = triggeredMessage - ) - }) + generateResponse(model = model, input = triggeredMessage.content, onError = {}) } } -class LlmImageToTextViewModel : LlmChatViewModel(curTask = TASK_LLM_IMAGE_TO_TEXT) \ No newline at end of file +class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE) \ No newline at end of file diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt index 1b96140..0748a8f 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt @@ -73,6 +73,7 @@ fun LlmSingleTurnScreen( ) { val task = viewModel.task val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() + val uiState by viewModel.uiState.collectAsState() val selectedModel = modelManagerUiState.selectedModel val scope = rememberCoroutineScope() val context = LocalContext.current @@ -114,6 +115,8 @@ fun LlmSingleTurnScreen( task = task, model = selectedModel, modelManagerViewModel = modelManagerViewModel, + inProgress = uiState.inProgress, + modelPreparing = uiState.preparing, onConfigChanged = { _, _ -> }, onBackClicked = { handleNavigateUp() }, onModelSelected = { newSelectedModel -> diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt index 21e1b3c..9f3b677 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt @@ -20,10 +20,9 @@ import android.util.Log import androidx.lifecycle.ViewModel import androidx.lifecycle.viewModelScope import com.google.aiedge.gallery.data.Model -import com.google.aiedge.gallery.data.TASK_LLM_USECASES +import com.google.aiedge.gallery.data.TASK_LLM_PROMPT_LAB import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult -import com.google.aiedge.gallery.ui.common.chat.ChatMessageLoading import com.google.aiedge.gallery.ui.common.chat.Stat import com.google.aiedge.gallery.ui.common.processLlmResponse import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper @@ -44,9 +43,9 @@ data class LlmSingleTurnUiState( val inProgress: Boolean = false, /** - * Indicates whether the model is currently being initialized. + * Indicates whether the model is preparing (before outputting any result and after initializing). */ - val initializing: Boolean = false, + val preparing: Boolean = false, // model ->