Various bug fixes.

This commit is contained in:
Jing Jin 2025-05-17 15:11:52 -07:00
parent 0c49efc054
commit 37a58d1a41
35 changed files with 1517 additions and 995 deletions

View file

@ -27,10 +27,10 @@ android {
defaultConfig {
applicationId = "com.google.aiedge.gallery"
minSdk = 24
minSdk = 26
targetSdk = 35
versionCode = 1
versionName = "20250428"
versionName = "0.9.0"
// Needed for HuggingFace auth workflows.
manifestPlaceholders["appAuthRedirectScheme"] = "com.google.aiedge.gallery.oauth"

View file

@ -41,7 +41,9 @@
android:name=".MainActivity"
android:exported="true"
android:theme="@style/Theme.Gallery.SplashScreen"
android:windowSoftInputMode="adjustResize">
android:screenOrientation="portrait"
android:windowSoftInputMode="adjustResize"
tools:ignore="DiscouragedApi,LockedOrientationActivity">
<!-- This is for putting the app into launcher -->
<intent-filter>
<action android:name="android.intent.action.MAIN" />

View file

@ -68,7 +68,6 @@ fun GalleryTopAppBar(
leftAction: AppBarAction? = null,
rightAction: AppBarAction? = null,
scrollBehavior: TopAppBarScrollBehavior? = null,
loadingHfModels: Boolean = false,
subtitle: String = "",
) {
CenterAlignedTopAppBar(
@ -151,28 +150,6 @@ fun GalleryTopAppBar(
}
}
// Click an icon to open "download manager".
AppBarActionType.DOWNLOAD_MANAGER -> {
if (loadingHfModels) {
CircularProgressIndicator(
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
strokeWidth = 3.dp,
modifier = Modifier
.padding(end = 12.dp)
.size(20.dp)
)
}
// else {
// IconButton(onClick = rightAction.actionFn) {
// Icon(
// imageVector = Deployed_code,
// contentDescription = "",
// tint = MaterialTheme.colorScheme.primary
// )
// }
// }
}
AppBarActionType.MODEL_SELECTOR -> {
Text("ms")
}

View file

@ -37,7 +37,7 @@ import javax.crypto.SecretKey
data class AccessTokenData(
val accessToken: String,
val refreshToken: String,
val expiresAtSeconds: Long
val expiresAtMs: Long
)
interface DataStoreRepository {
@ -46,6 +46,7 @@ interface DataStoreRepository {
fun saveThemeOverride(theme: String)
fun readThemeOverride(): String
fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long)
fun clearAccessTokenData()
fun readAccessTokenData(): AccessTokenData?
fun saveImportedModels(importedModels: List<ImportedModelInfo>)
fun readImportedModels(): List<ImportedModelInfo>
@ -135,6 +136,18 @@ class DefaultDataStoreRepository(
}
}
override fun clearAccessTokenData() {
return runBlocking {
dataStore.edit { preferences ->
preferences.remove(PreferencesKeys.ENCRYPTED_ACCESS_TOKEN)
preferences.remove(PreferencesKeys.ACCESS_TOKEN_IV)
preferences.remove(PreferencesKeys.ENCRYPTED_REFRESH_TOKEN)
preferences.remove(PreferencesKeys.REFRESH_TOKEN_IV)
preferences.remove(PreferencesKeys.ACCESS_TOKEN_EXPIRES_AT)
}
}
}
override fun readAccessTokenData(): AccessTokenData? {
return runBlocking {
val preferences = dataStore.data.first()

View file

@ -43,7 +43,7 @@ data class AllowedModel(
// Config.
val isLlmModel =
taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_USECASES.type.id)
taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_PROMPT_LAB.type.id)
var configs: List<Config> = listOf()
if (isLlmModel) {
var defaultTopK: Int = DEFAULT_TOPK

View file

@ -32,9 +32,9 @@ enum class TaskType(val label: String, val id: String) {
TEXT_CLASSIFICATION(label = "Text Classification", id = "text_classification"),
IMAGE_CLASSIFICATION(label = "Image Classification", id = "image_classification"),
IMAGE_GENERATION(label = "Image Generation", id = "image_generation"),
LLM_CHAT(label = "LLM Chat", id = "llm_chat"),
LLM_USECASES(label = "LLM Use Cases", id = "llm_usecases"),
LLM_IMAGE_TO_TEXT(label = "LLM Image to Text", id = "llm_image_to_text"),
LLM_CHAT(label = "AI Chat", id = "llm_chat"),
LLM_PROMPT_LAB(label = "Prompt Lab", id = "llm_prompt_lab"),
LLM_ASK_IMAGE(label = "Ask Image", id = "llm_ask_image"),
TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
TEST_TASK_2(label = "Test task 2", id = "test_task_2")
@ -93,7 +93,6 @@ val TASK_IMAGE_CLASSIFICATION = Task(
val TASK_LLM_CHAT = Task(
type = TaskType.LLM_CHAT,
icon = Icons.Outlined.Forum,
// models = MODELS_LLM,
models = mutableListOf(),
description = "Chat with on-device large language models",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
@ -101,10 +100,9 @@ val TASK_LLM_CHAT = Task(
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
)
val TASK_LLM_USECASES = Task(
type = TaskType.LLM_USECASES,
val TASK_LLM_PROMPT_LAB = Task(
type = TaskType.LLM_PROMPT_LAB,
icon = Icons.Outlined.Widgets,
// models = MODELS_LLM,
models = mutableListOf(),
description = "Single turn use cases with on-device large language model",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
@ -112,10 +110,9 @@ val TASK_LLM_USECASES = Task(
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
)
val TASK_LLM_IMAGE_TO_TEXT = Task(
type = TaskType.LLM_IMAGE_TO_TEXT,
val TASK_LLM_ASK_IMAGE = Task(
type = TaskType.LLM_ASK_IMAGE,
icon = Icons.Outlined.Mms,
// models = MODELS_LLM,
models = mutableListOf(),
description = "Ask questions about images with on-device large language models",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
@ -135,12 +132,9 @@ val TASK_IMAGE_GENERATION = Task(
/** All tasks. */
val TASKS: List<Task> = listOf(
// TASK_TEXT_CLASSIFICATION,
// TASK_IMAGE_CLASSIFICATION,
// TASK_IMAGE_GENERATION,
TASK_LLM_USECASES,
TASK_LLM_ASK_IMAGE,
TASK_LLM_PROMPT_LAB,
TASK_LLM_CHAT,
TASK_LLM_IMAGE_TO_TEXT
)
fun getModelByName(name: String): Model? {

View file

@ -25,7 +25,7 @@ import com.google.aiedge.gallery.GalleryApplication
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationViewModel
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationViewModel
import com.google.aiedge.gallery.ui.llmchat.LlmChatViewModel
import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextViewModel
import com.google.aiedge.gallery.ui.llmchat.LlmAskImageViewModel
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.textclassification.TextClassificationViewModel
@ -63,9 +63,9 @@ object ViewModelProvider {
LlmSingleTurnViewModel()
}
// Initializer for LlmImageToTextViewModel.
// Initializer for LlmAskImageViewModel.
initializer {
LlmImageToTextViewModel()
LlmAskImageViewModel()
}
// Initializer for ImageGenerationViewModel.

View file

@ -26,6 +26,8 @@ import androidx.browser.customtabs.CustomTabsIntent
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.wrapContentHeight
import androidx.compose.foundation.text.BasicText
import androidx.compose.foundation.text.TextAutoSize
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.ArrowForward
import androidx.compose.material.icons.rounded.Error
@ -48,6 +50,7 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
@ -291,11 +294,32 @@ fun DownloadAndTryButton(
modifier = Modifier.padding(end = 4.dp)
)
val textColor = MaterialTheme.colorScheme.onPrimary
if (checkingToken) {
Text("Checking access...")
BasicText(
text = "Checking access...",
maxLines = 1,
color = { textColor },
style = MaterialTheme.typography.bodyMedium,
autoSize = TextAutoSize.StepBased(
minFontSize = 8.sp,
maxFontSize = 14.sp,
stepSize = 1.sp
)
)
} else {
if (needToDownloadFirst) {
Text("Download & Try it", maxLines = 1)
BasicText(
text = "Download & Try",
maxLines = 1,
color = { textColor },
style = MaterialTheme.typography.bodyMedium,
autoSize = TextAutoSize.StepBased(
minFontSize = 8.sp,
maxFontSize = 14.sp,
stepSize = 1.sp
)
)
} else {
Text("Try it", maxLines = 1)
}

View file

@ -63,6 +63,8 @@ fun ModelPageAppBar(
modelManagerViewModel: ModelManagerViewModel,
onBackClicked: () -> Unit,
onModelSelected: (Model) -> Unit,
inProgress: Boolean,
modelPreparing: Boolean,
modifier: Modifier = Modifier,
isResettingSession: Boolean = false,
onResetSessionClicked: (Model) -> Unit = {},
@ -129,15 +131,16 @@ fun ModelPageAppBar(
val isModelInitializing =
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING
if (showConfigButton) {
val enableConfigButton = !isModelInitializing && !inProgress
IconButton(
onClick = {
showConfigDialog = true
},
enabled = !isModelInitializing,
enabled = enableConfigButton,
modifier = Modifier
.scale(0.75f)
.offset(x = configButtonOffset)
.alpha(if (isModelInitializing) 0.5f else 1f)
.alpha(if (!enableConfigButton) 0.5f else 1f)
) {
Icon(
imageVector = Icons.Rounded.Tune,
@ -154,14 +157,15 @@ fun ModelPageAppBar(
modifier = Modifier.size(16.dp)
)
} else {
val enableResetButton = !isModelInitializing && !modelPreparing
IconButton(
onClick = {
onResetSessionClicked(model)
},
enabled = !isModelInitializing,
enabled = enableResetButton,
modifier = Modifier
.scale(0.75f)
.alpha(if (isModelInitializing) 0.5f else 1f)
.alpha(if (!enableResetButton) 0.5f else 1f)
) {
Icon(
imageVector = Icons.Rounded.MapsUgc,

View file

@ -83,8 +83,11 @@ fun ModelPickerChipsPager(
val scope = rememberCoroutineScope()
val density = LocalDensity.current
val windowInfo = LocalWindowInfo.current
val screenWidthDp =
remember { with(density) { windowInfo.containerSize.width.toDp() } }
val screenWidthDp = remember {
with(density) {
windowInfo.containerSize.width.toDp()
}
}
val pagerState = rememberPagerState(initialPage = task.models.indexOf(initialModel),
pageCount = { task.models.size })
@ -147,7 +150,7 @@ fun ModelPickerChipsPager(
}
Text(
model.name,
style = MaterialTheme.typography.labelSmall,
style = MaterialTheme.typography.labelMedium,
modifier = Modifier
.padding(start = 4.dp)
.widthIn(0.dp, screenWidthDp - 250.dp),

View file

@ -21,6 +21,7 @@ import android.content.Context
import android.content.pm.PackageManager
import android.net.Uri
import android.os.Build
import android.util.Log
import androidx.activity.compose.ManagedActivityResultLauncher
import androidx.compose.material3.MaterialTheme
import androidx.compose.runtime.Composable
@ -39,7 +40,11 @@ import com.google.aiedge.gallery.ui.common.chat.Histogram
import com.google.aiedge.gallery.ui.common.chat.Stat
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.theme.customColors
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.json.Json
import java.io.File
import java.net.HttpURLConnection
import java.net.URL
import kotlin.math.abs
import kotlin.math.ln
import kotlin.math.max
@ -488,3 +493,38 @@ fun processLlmResponse(response: String): String {
return newContent
}
@OptIn(ExperimentalSerializationApi::class)
inline fun <reified T> getJsonResponse(url: String): T? {
try {
val connection = URL(url).openConnection() as HttpURLConnection
connection.requestMethod = "GET"
connection.connect()
val responseCode = connection.responseCode
if (responseCode == HttpURLConnection.HTTP_OK) {
val inputStream = connection.inputStream
val response = inputStream.bufferedReader().use { it.readText() }
// Parse JSON using kotlinx.serialization
val json = Json {
// Handle potential extra fields
ignoreUnknownKeys = true
allowComments = true
allowTrailingComma = true
}
val jsonObj = json.decodeFromString<T>(response)
return jsonObj
} else {
Log.e("AGUtils", "HTTP error: $responseCode")
}
} catch (e: Exception) {
Log.e(
"AGUtils",
"Error when getting json response: ${e.message}"
)
e.printStackTrace()
}
return null
}

View file

@ -16,9 +16,14 @@
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.animation.AnimatedContent
import androidx.compose.animation.ExperimentalSharedTransitionApi
import androidx.compose.animation.SharedTransitionLayout
import androidx.compose.foundation.Image
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
import androidx.compose.foundation.gestures.detectTapGestures
import androidx.compose.foundation.gestures.detectTransformGestures
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
@ -28,6 +33,7 @@ import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.ime
import androidx.compose.foundation.layout.imePadding
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.width
@ -39,12 +45,15 @@ import androidx.compose.foundation.lazy.rememberLazyListState
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.outlined.Timer
import androidx.compose.material.icons.rounded.Close
import androidx.compose.material.icons.rounded.ContentCopy
import androidx.compose.material.icons.rounded.Refresh
import androidx.compose.material3.Button
import androidx.compose.material3.Card
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.IconButtonDefaults
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ModalBottomSheet
import androidx.compose.material3.SnackbarHost
@ -55,6 +64,7 @@ import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.MutableState
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableFloatStateOf
import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
@ -65,12 +75,16 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.geometry.Offset
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.RectangleShape
import androidx.compose.ui.graphics.asImageBitmap
import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.hapticfeedback.HapticFeedbackType
import androidx.compose.ui.input.nestedscroll.NestedScrollConnection
import androidx.compose.ui.input.nestedscroll.NestedScrollSource
import androidx.compose.ui.input.nestedscroll.nestedScroll
import androidx.compose.ui.input.pointer.pointerInput
import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.layout.onSizeChanged
import androidx.compose.ui.platform.LocalClipboardManager
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalDensity
@ -80,6 +94,7 @@ import androidx.compose.ui.res.dimensionResource
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.IntSize
import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog
import com.google.aiedge.gallery.R
@ -102,7 +117,7 @@ enum class ChatInputType {
/**
* Composable function for the main chat panel, displaying messages and handling user input.
*/
@OptIn(ExperimentalMaterial3Api::class)
@OptIn(ExperimentalMaterial3Api::class, ExperimentalSharedTransitionApi::class)
@Composable
fun ChatPanel(
modelManagerViewModel: ModelManagerViewModel,
@ -127,6 +142,7 @@ fun ChatPanel(
val snackbarHostState = remember { SnackbarHostState() }
val scope = rememberCoroutineScope()
val haptic = LocalHapticFeedback.current
var selectedImageMessage by remember { mutableStateOf<ChatMessageImage?>(null) }
var curMessage by remember { mutableStateOf("") } // Correct state
val focusManager = LocalFocusManager.current
@ -200,13 +216,14 @@ fun ChatPanel(
}
}
val modelInitializationStatus =
modelManagerUiState.modelInitializationStatus[selectedModel.name]
val modelInitializationStatus = modelManagerUiState.modelInitializationStatus[selectedModel.name]
LaunchedEffect(modelInitializationStatus) {
showErrorDialog = modelInitializationStatus?.status == ModelInitializationStatusType.ERROR
}
SharedTransitionLayout(modifier = Modifier.fillMaxSize()) {
AnimatedContent(targetState = selectedImageMessage) { targetSelectedImageMessage ->
Column(
modifier = modifier.imePadding()
) {
@ -297,8 +314,7 @@ fun ChatPanel(
)
.background(backgroundColor)
if (message is ChatMessageText) {
messageBubbleModifier = messageBubbleModifier
.pointerInput(Unit) {
messageBubbleModifier = messageBubbleModifier.pointerInput(Unit) {
detectTapGestures(
onLongPress = {
haptic.performHapticFeedback(HapticFeedbackType.LongPress)
@ -316,7 +332,21 @@ fun ChatPanel(
is ChatMessageText -> MessageBodyText(message = message)
// Image
is ChatMessageImage -> MessageBodyImage(message = message)
is ChatMessageImage -> {
if (targetSelectedImageMessage != message) {
MessageBodyImage(
message = message,
modifier = Modifier
.clickable {
selectedImageMessage = message
}
.sharedElement(
sharedContentState = rememberSharedContentState(key = "selected_image"),
animatedVisibilityScope = this@AnimatedContent
),
)
}
}
// Image with history (for image gen)
is ChatMessageImageWithHistory -> MessageBodyImageWithHistory(
@ -325,8 +355,9 @@ fun ChatPanel(
// Classification result
is ChatMessageClassification -> MessageBodyClassification(
message = message,
modifier = Modifier.width(message.maxBarWidth ?: CLASSIFICATION_BAR_MAX_WIDTH)
message = message, modifier = Modifier.width(
message.maxBarWidth ?: CLASSIFICATION_BAR_MAX_WIDTH
)
)
// Benchmark result.
@ -334,8 +365,7 @@ fun ChatPanel(
// Benchmark LLM result.
is ChatMessageBenchmarkLlmResult -> MessageBodyBenchmarkLlm(
message = message,
modifier = Modifier.wrapContentWidth()
message = message, modifier = Modifier.wrapContentWidth()
)
else -> {}
@ -348,7 +378,7 @@ fun ChatPanel(
) {
LatencyText(message = message)
// A button to show stats for the LLM message.
if (task.type == TaskType.LLM_CHAT && message is ChatMessageText
if ((task.type == TaskType.LLM_CHAT || task.type == TaskType.LLM_ASK_IMAGE) && message is ChatMessageText
// This means we only want to show the action button when the message is done
// generating, at which point the latency will be set.
&& message.latencyMs >= 0
@ -363,7 +393,10 @@ fun ChatPanel(
viewModel.toggleShowingStats(selectedModel, message)
// Add the stats message after the LLM message.
if (viewModel.isShowingStats(model = selectedModel, message = message)) {
if (viewModel.isShowingStats(
model = selectedModel, message = message
)
) {
val llmBenchmarkResult = message.llmBenchmarkResult
if (llmBenchmarkResult != null) {
viewModel.insertMessageAfter(
@ -376,10 +409,12 @@ fun ChatPanel(
// Remove the stats message.
else {
val curMessageIndex =
viewModel.getMessageIndex(model = selectedModel, message = message)
viewModel.removeMessageAt(
viewModel.getMessageIndex(
model = selectedModel,
index = curMessageIndex + 1
message = message
)
viewModel.removeMessageAt(
model = selectedModel, index = curMessageIndex + 1
)
}
},
@ -425,8 +460,23 @@ fun ChatPanel(
}
SnackbarHost(hostState = snackbarHostState, modifier = Modifier.padding(vertical = 4.dp))
}
// Show an info message for ask image task to get users started.
if (task.type == TaskType.LLM_ASK_IMAGE && messages.isEmpty()) {
Column(
modifier = Modifier
.padding(horizontal = 16.dp)
.fillMaxSize(),
horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.Center
) {
MessageBodyInfo(
ChatMessageInfo(content = "To get started, click + below to add an image and type a prompt to ask a question about it."),
smallFontSize = false
)
}
}
}
// Chat input
when (chatInputType) {
@ -439,6 +489,7 @@ fun ChatPanel(
curMessage = curMessage,
inProgress = uiState.inProgress,
isResettingSession = uiState.isResettingSession,
modelPreparing = uiState.preparing,
hasImageMessage = hasImageMessage,
modelInitializing = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
textFieldPlaceHolderRes = task.textInputPlaceHolderRes,
@ -459,7 +510,7 @@ fun ChatPanel(
onStopButtonClicked = onStopButtonClicked,
// showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen,
showPromptTemplatesInMenu = false,
showImagePickerInMenu = selectedModel.llmSupportImage == true,
showImagePickerInMenu = selectedModel.llmSupportImage,
showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress,
)
}
@ -488,6 +539,58 @@ fun ChatPanel(
}
}
// A full-screen image viewer.
if (targetSelectedImageMessage != null) {
ZoomableBox(
modifier = Modifier
.fillMaxSize()
.background(Color.Black.copy(alpha = 0.9f))
.sharedElement(
rememberSharedContentState(key = "bounds"),
animatedVisibilityScope = this,
)
.skipToLookaheadSize(),
) {
// Image.
Image(
bitmap = targetSelectedImageMessage.imageBitMap,
contentDescription = "",
modifier = modifier
.fillMaxSize()
.graphicsLayer(
scaleX = scale,
scaleY = scale,
translationX = offsetX,
translationY = offsetY
)
.sharedElement(
sharedContentState = rememberSharedContentState(key = "selected_image"),
animatedVisibilityScope = this@AnimatedContent
),
contentScale = ContentScale.Fit,
)
// Close button.
IconButton(
onClick = {
selectedImageMessage = null
},
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.surfaceVariant,
),
modifier = Modifier.offset(x = (-8).dp, y = 8.dp)
) {
Icon(
Icons.Rounded.Close,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary
)
}
}
}
}
}
// Error dialog.
if (showErrorDialog) {
Dialog(
@ -498,9 +601,7 @@ fun ChatPanel(
) {
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
Column(
modifier = Modifier
.padding(20.dp),
verticalArrangement = Arrangement.spacedBy(16.dp)
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
) {
// Title
Text(
@ -568,13 +669,10 @@ fun ChatPanel(
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp),
modifier = Modifier
.padding(vertical = 8.dp, horizontal = 16.dp)
modifier = Modifier.padding(vertical = 8.dp, horizontal = 16.dp)
) {
Icon(
Icons.Rounded.ContentCopy,
contentDescription = "",
modifier = Modifier.size(18.dp)
Icons.Rounded.ContentCopy, contentDescription = "", modifier = Modifier.size(18.dp)
)
Text("Copy text")
}
@ -586,6 +684,51 @@ fun ChatPanel(
}
}
@Composable
fun ZoomableBox(
modifier: Modifier = Modifier,
minScale: Float = 1f,
maxScale: Float = 5f,
content: @Composable ZoomableBoxScope.() -> Unit
) {
var scale by remember { mutableFloatStateOf(1f) }
var offsetX by remember { mutableFloatStateOf(0f) }
var offsetY by remember { mutableFloatStateOf(0f) }
var size by remember { mutableStateOf(IntSize.Zero) }
Box(
modifier = modifier
.clip(RectangleShape)
.onSizeChanged { size = it }
.pointerInput(Unit) {
detectTransformGestures { _, pan, zoom, _ ->
scale = maxOf(minScale, minOf(scale * zoom, maxScale))
val maxX = (size.width * (scale - 1)) / 2
val minX = -maxX
offsetX = maxOf(minX, minOf(maxX, offsetX + pan.x))
val maxY = (size.height * (scale - 1)) / 2
val minY = -maxY
offsetY = maxOf(minY, minOf(maxY, offsetY + pan.y))
}
},
contentAlignment = Alignment.TopEnd
) {
val scope = ZoomableBoxScopeImpl(scale, offsetX, offsetY)
scope.content()
}
}
interface ZoomableBoxScope {
val scale: Float
val offsetX: Float
val offsetY: Float
}
private data class ZoomableBoxScopeImpl(
override val scale: Float,
override val offsetX: Float,
override val offsetY: Float
) : ZoomableBoxScope
@Preview(showBackground = true)
@Composable
fun ChatPanelPreview() {

View file

@ -19,6 +19,7 @@ package com.google.aiedge.gallery.ui.common.chat
import android.util.Log
import androidx.activity.compose.BackHandler
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.fillMaxSize
@ -27,6 +28,7 @@ import androidx.compose.foundation.pager.HorizontalPager
import androidx.compose.foundation.pager.rememberPagerState
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Scaffold
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState
@ -36,10 +38,13 @@ import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.runtime.snapshotFlow
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.Task
@ -81,7 +86,7 @@ fun ChatView(
chatInputType: ChatInputType = ChatInputType.TEXT,
showStopButtonInInputWhenInProgress: Boolean = false,
) {
val uiStat by viewModel.uiState.collectAsState()
val uiState by viewModel.uiState.collectAsState()
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val selectedModel = modelManagerUiState.selectedModel
@ -158,7 +163,9 @@ fun ChatView(
model = selectedModel,
modelManagerViewModel = modelManagerViewModel,
showResetSessionButton = true,
isResettingSession = uiStat.isResettingSession,
isResettingSession = uiState.isResettingSession,
inProgress = uiState.inProgress,
modelPreparing = uiState.preparing,
onResetSessionClicked = onResetSessionClicked,
onConfigChanged = { old, new ->
viewModel.addConfigChangedMessage(

View file

@ -38,6 +38,11 @@ data class ChatUiState(
*/
val isResettingSession: Boolean = false,
/**
* Indicates whether the model is preparing (before outputting any result and after initializing).
*/
val preparing: Boolean = false,
/**
* A map of model names to lists of chat messages.
*/
@ -204,6 +209,10 @@ open class ChatViewModel(val task: Task) : ViewModel() {
_uiState.update { _uiState.value.copy(isResettingSession = isResettingSession) }
}
fun setPreparing(preparing: Boolean) {
_uiState.update { _uiState.value.copy(preparing = preparing) }
}
fun addConfigChangedMessage(
oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>, model: Model
) {

View file

@ -18,6 +18,8 @@ package com.google.aiedge.gallery.ui.common.chat
import android.util.Log
import androidx.compose.foundation.border
import androidx.compose.foundation.clickable
import androidx.compose.foundation.interaction.MutableInteractionSource
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
@ -28,9 +30,11 @@ import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.width
import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.foundation.text.BasicTextField
import androidx.compose.foundation.text.KeyboardOptions
import androidx.compose.foundation.verticalScroll
import androidx.compose.material3.Button
import androidx.compose.material3.Card
import androidx.compose.material3.MaterialTheme
@ -53,6 +57,9 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.focus.FocusRequester
import androidx.compose.ui.focus.focusRequester
import androidx.compose.ui.focus.onFocusChanged
import androidx.compose.ui.graphics.SolidColor
import androidx.compose.ui.platform.LocalFocusManager
import androidx.compose.ui.text.TextStyle
import androidx.compose.ui.text.input.KeyboardType
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
@ -89,9 +96,20 @@ fun ConfigDialog(
putAll(initialValues)
}
}
val interactionSource = remember { MutableInteractionSource() }
Dialog(onDismissRequest = onDismissed) {
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
val focusManager = LocalFocusManager.current
Card(
modifier = Modifier
.fillMaxWidth()
.clickable(
interactionSource = interactionSource, indication = null // Disable the ripple effect
) {
focusManager.clearFocus()
},
shape = RoundedCornerShape(16.dp)
) {
Column(
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
) {
@ -114,7 +132,14 @@ fun ConfigDialog(
}
// List of config rows.
Column(
modifier = Modifier
.verticalScroll(rememberScrollState())
.weight(1f, fill = false),
verticalArrangement = Arrangement.spacedBy(16.dp)
) {
ConfigEditorsPanel(configs = configs, values = values)
}
// Button row.
Row(
@ -264,6 +289,8 @@ fun NumberSliderRow(config: NumberSliderConfig, values: SnapshotStateMap<String,
values[config.key.label] = NaN
}
},
textStyle = TextStyle(color = MaterialTheme.colorScheme.onSurface),
cursorBrush = SolidColor(MaterialTheme.colorScheme.onSurface),
) { innerTextField ->
Box(
modifier = Modifier.border(

View file

@ -25,17 +25,18 @@ import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.unit.dp
@Composable
fun MessageBodyImage(message: ChatMessageImage) {
fun MessageBodyImage(message: ChatMessageImage, modifier: Modifier = Modifier) {
val bitmapWidth = message.bitmap.width
val bitmapHeight = message.bitmap.height
val imageWidth =
if (bitmapWidth >= bitmapHeight) 200 else (200f / bitmapHeight * bitmapWidth).toInt()
val imageHeight =
if (bitmapHeight >= bitmapWidth) 200 else (200f / bitmapWidth * bitmapHeight).toInt()
Image(
bitmap = message.imageBitMap,
contentDescription = "",
modifier = Modifier
modifier = modifier
.height(imageHeight.dp)
.width(imageWidth.dp),
contentScale = ContentScale.Fit,

View file

@ -38,7 +38,7 @@ import com.google.aiedge.gallery.ui.theme.customColors
* Supports markdown.
*/
@Composable
fun MessageBodyInfo(message: ChatMessageInfo) {
fun MessageBodyInfo(message: ChatMessageInfo, smallFontSize: Boolean = true) {
Row(
modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center
) {
@ -47,7 +47,11 @@ fun MessageBodyInfo(message: ChatMessageInfo) {
.clip(RoundedCornerShape(16.dp))
.background(MaterialTheme.customColors.agentBubbleBgColor)
) {
MarkdownText(text = message.content, modifier = Modifier.padding(12.dp), smallFontSize = true)
MarkdownText(
text = message.content,
modifier = Modifier.padding(12.dp),
smallFontSize = smallFontSize
)
}
}
}

View file

@ -103,6 +103,7 @@ fun MessageInputText(
@StringRes textFieldPlaceHolderRes: Int,
onValueChanged: (String) -> Unit,
onSendMessage: (List<ChatMessage>) -> Unit,
modelPreparing: Boolean = false,
onOpenPromptTemplatesClicked: () -> Unit = {},
onStopButtonClicked: () -> Unit = {},
showPromptTemplatesInMenu: Boolean = false,
@ -173,7 +174,7 @@ fun MessageInputText(
.height(80.dp)
.shadow(2.dp, shape = RoundedCornerShape(8.dp))
.clip(RoundedCornerShape(8.dp))
.border(1.dp, MaterialTheme.colorScheme.outlineVariant, RoundedCornerShape(8.dp)),
.border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp)),
)
Box(modifier = Modifier
.offset(x = 10.dp, y = (-10).dp)
@ -219,7 +220,7 @@ fun MessageInputText(
expanded = showAddContentMenu,
onDismissRequest = { showAddContentMenu = false }) {
if (showImagePickerInMenu) {
// Take a photo.
// Take a picture.
DropdownMenuItem(
text = {
Row(
@ -227,7 +228,7 @@ fun MessageInputText(
horizontalArrangement = Arrangement.spacedBy(6.dp)
) {
Icon(Icons.Rounded.PhotoCamera, contentDescription = "")
Text("Take a photo")
Text("Take a picture")
}
},
enabled = pickedImages.isEmpty() && !hasImageMessage,
@ -321,7 +322,7 @@ fun MessageInputText(
Spacer(modifier = Modifier.width(8.dp))
if (inProgress && showStopButtonWhenInProgress) {
if (!modelInitializing) {
if (!modelInitializing && !modelPreparing) {
IconButton(
onClick = onStopButtonClicked,
colors = IconButtonDefaults.iconButtonColors(

View file

@ -20,8 +20,9 @@ import android.content.Intent
import android.net.Uri
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.animation.core.animateFloatAsState
import androidx.compose.animation.core.tween
import androidx.compose.animation.AnimatedContent
import androidx.compose.animation.ExperimentalSharedTransitionApi
import androidx.compose.animation.SharedTransitionLayout
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
import androidx.compose.foundation.interaction.MutableInteractionSource
@ -30,17 +31,13 @@ import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.heightIn
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.ChevronRight
import androidx.compose.material.icons.rounded.Settings
import androidx.compose.material.icons.rounded.UnfoldLess
import androidx.compose.material.icons.rounded.UnfoldMore
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.OutlinedButton
import androidx.compose.material3.Text
import androidx.compose.material3.ripple
@ -48,16 +45,12 @@ import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.getValue
import androidx.compose.runtime.movableContentOf
import androidx.compose.runtime.movableContentWithReceiverOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip
import androidx.compose.ui.layout.LookaheadScope
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.Dp
@ -91,6 +84,7 @@ private val DEFAULT_VERTICAL_PADDING = 16.dp
* model description and buttons for learning more (opening a URL) and downloading/trying
* the model.
*/
@OptIn(ExperimentalSharedTransitionApi::class)
@Composable
fun ModelItem(
model: Model,
@ -117,37 +111,57 @@ fun ModelItem(
var isExpanded by remember { mutableStateOf(false) }
// Animate alpha for model description and button rows when switching between layouts.
val alphaAnimation by animateFloatAsState(
targetValue = if (isExpanded) 1f else 0f,
animationSpec = tween(durationMillis = LAYOUT_ANIMATION_DURATION - 50)
var boxModifier = modifier
.fillMaxWidth()
.clip(RoundedCornerShape(size = 42.dp))
.background(
getTaskBgColor(task)
)
boxModifier = if (canExpand) {
boxModifier.clickable(onClick = {
if (!model.imported) {
isExpanded = !isExpanded
} else {
onModelClicked(model)
}
}, interactionSource = remember { MutableInteractionSource() }, indication = ripple(
bounded = true,
radius = 1000.dp,
)
)
} else {
boxModifier
}
LookaheadScope {
// Task icon.
val taskIcon = remember {
movableContentOf {
Box(
modifier = boxModifier,
contentAlignment = Alignment.Center,
) {
SharedTransitionLayout {
AnimatedContent(
isExpanded, label = "item_layout_transition",
) { targetState ->
val taskIcon = @Composable {
TaskIcon(
task = task, modifier = Modifier.animateLayout()
task = task, modifier = Modifier.sharedElement(
sharedContentState = rememberSharedContentState(key = "task_icon"),
animatedVisibilityScope = this@AnimatedContent,
)
)
}
}
// Model name and status.
val modelNameAndStatus = remember {
movableContentOf {
val modelNameAndStatus = @Composable {
ModelNameAndStatus(
model = model,
task = task,
downloadStatus = downloadStatus,
isExpanded = isExpanded,
modifier = Modifier.animateLayout()
animatedVisibilityScope = this@AnimatedContent,
sharedTransitionScope = this@SharedTransitionLayout
)
}
}
val actionButton = remember {
movableContentOf {
val actionButton = @Composable {
ModelItemActionButton(
context = context,
model = model,
@ -165,61 +179,48 @@ fun ModelItem(
},
showDeleteButton = showDeleteButton,
showDownloadButton = false,
modifier = Modifier.sharedElement(
sharedContentState = rememberSharedContentState(key = "action_button"),
animatedVisibilityScope = this@AnimatedContent,
)
)
}
}
// Expand/collapse icon, or the config icon.
val expandButton = remember {
movableContentOf {
if (showConfigButtonIfExisted) {
if (downloadStatus?.status === ModelDownloadStatusType.SUCCEEDED) {
if (model.configs.isNotEmpty()) {
IconButton(onClick = onConfigClicked) {
Icon(
Icons.Rounded.Settings,
contentDescription = "",
tint = getTaskIconColor(task)
)
}
}
}
} else {
val expandButton = @Composable {
Icon(
// For imported model, show ">" directly indicating users can just tap the model item to
// go into it without needing to expand it first.
if (model.imported) Icons.Rounded.ChevronRight else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore,
contentDescription = "",
tint = getTaskIconColor(task),
modifier = Modifier.sharedElement(
sharedContentState = rememberSharedContentState(key = "expand_button"),
animatedVisibilityScope = this@AnimatedContent,
)
)
}
}
}
// Model description shown in expanded layout.
val modelDescription = remember {
movableContentOf { m: Modifier ->
val description = @Composable {
if (model.info.isNotEmpty()) {
MarkdownText(
model.info,
modifier = Modifier
.heightIn(min = 0.dp, max = if (isExpanded) 1000.dp else 0.dp)
.animateLayout()
.then(m)
model.info, modifier = Modifier
.sharedElement(
sharedContentState = rememberSharedContentState(key = "description"),
animatedVisibilityScope = this@AnimatedContent,
)
.skipToLookaheadSize()
)
}
}
}
// Button rows shown in expanded layout.
val buttonRows = remember {
movableContentOf { m: Modifier ->
val buttonsRow = @Composable {
Row(
modifier = Modifier
.heightIn(min = 0.dp, max = if (isExpanded) 1000.dp else 0.dp)
.animateLayout()
.then(m),
horizontalArrangement = Arrangement.spacedBy(12.dp),
horizontalArrangement = Arrangement.spacedBy(12.dp), modifier = Modifier
.sharedElement(
sharedContentState = rememberSharedContentState(key = "buttons_row"),
animatedVisibilityScope = this@AnimatedContent,
)
.skipToLookaheadSize()
) {
// The "learn more" button. Click to show related urls in a bottom sheet.
if (model.learnMoreUrl.isNotEmpty()) {
@ -238,60 +239,42 @@ fun ModelItem(
// Button to start the download and start the chat session with the model.
val needToDownloadFirst =
downloadStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED || downloadStatus?.status == ModelDownloadStatusType.FAILED
DownloadAndTryButton(
task = task,
DownloadAndTryButton(task = task,
model = model,
enabled = isExpanded,
needToDownloadFirst = needToDownloadFirst,
modelManagerViewModel = modelManagerViewModel,
onClicked = { onModelClicked(model) }
)
}
onClicked = { onModelClicked(model) })
}
}
val container = remember {
movableContentWithReceiverOf<LookaheadScope, @Composable () -> Unit> { content ->
Box(
modifier = Modifier.animateLayout(),
contentAlignment = Alignment.TopEnd,
// Collapsed state.
if (!targetState) {
Column(
horizontalAlignment = Alignment.CenterHorizontally,
) {
content()
}
}
}
var boxModifier = modifier
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(12.dp),
modifier = Modifier
.fillMaxWidth()
.clip(RoundedCornerShape(size = 42.dp))
.background(
getTaskBgColor(task)
)
boxModifier = if (canExpand) {
boxModifier.clickable(
onClick = {
if (!model.imported) {
isExpanded = !isExpanded
} else {
onModelClicked(model)
}
},
interactionSource = remember { MutableInteractionSource() },
indication = ripple(
bounded = true,
radius = 500.dp,
)
)
} else {
boxModifier
}
Box(
modifier = boxModifier,
contentAlignment = Alignment.Center
.padding(start = 18.dp, end = 18.dp)
.padding(vertical = verticalSpacing)
) {
if (isExpanded) {
container {
// The main part (icon, model name, status, etc)
// Icon at the left.
taskIcon()
// Model name and status at the center.
Row(modifier = Modifier.weight(1f)) {
modelNameAndStatus()
}
// Action button and expand/collapse button at the right.
Row(verticalAlignment = Alignment.CenterVertically) {
actionButton()
expandButton()
}
}
}
} else {
Column(
verticalArrangement = Arrangement.spacedBy(14.dp),
horizontalAlignment = Alignment.CenterHorizontally,
@ -300,7 +283,9 @@ fun ModelItem(
.padding(vertical = verticalSpacing, horizontal = 18.dp)
) {
Box(contentAlignment = Alignment.Center) {
// Icon at the top-center.
taskIcon()
// Action button and expand/collapse button at the right.
Row(
modifier = Modifier.fillMaxWidth(),
verticalAlignment = Alignment.CenterVertically,
@ -310,39 +295,12 @@ fun ModelItem(
expandButton()
}
}
// Name and status below the icon.
modelNameAndStatus()
modelDescription(Modifier.alpha(alphaAnimation))
buttonRows(Modifier.alpha(alphaAnimation)) // Apply alpha here
}
}
} else {
container {
Column(horizontalAlignment = Alignment.CenterHorizontally) {
// The main part (icon, model name, status, etc)
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(12.dp),
modifier = Modifier
.fillMaxWidth()
.padding(start = 18.dp, end = 18.dp)
.padding(vertical = verticalSpacing)
) {
taskIcon()
Row(modifier = Modifier.weight(1f)) {
modelNameAndStatus()
}
Row(verticalAlignment = Alignment.CenterVertically) {
actionButton()
expandButton()
}
}
Column(
modifier = Modifier.offset(y = 30.dp),
horizontalAlignment = Alignment.CenterHorizontally
) {
modelDescription(Modifier.alpha(alphaAnimation))
buttonRows(Modifier.alpha(alphaAnimation))
}
// Description.
description()
// Buttons
buttonsRow()
}
}
}

View file

@ -108,7 +108,8 @@ fun ModelItemActionButton(
// Button to cancel the download when it is in progress.
ModelDownloadStatusType.IN_PROGRESS, ModelDownloadStatusType.UNZIPPING -> IconButton(onClick = {
modelManagerViewModel.cancelDownloadModel(
model
task = task,
model = model
)
}) {
Icon(

View file

@ -16,6 +16,9 @@
package com.google.aiedge.gallery.ui.common.modelitem
import androidx.compose.animation.AnimatedVisibilityScope
import androidx.compose.animation.ExperimentalSharedTransitionApi
import androidx.compose.animation.SharedTransitionScope
import androidx.compose.animation.core.Animatable
import androidx.compose.animation.core.tween
import androidx.compose.foundation.layout.Column
@ -30,6 +33,7 @@ import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.remember
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.focus.focusModifier
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.unit.dp
@ -54,18 +58,22 @@ import com.google.aiedge.gallery.ui.theme.labelSmallNarrow
* - "Unzipping..." status for unzipping processes.
* - Model size for successful downloads.
*/
@OptIn(ExperimentalSharedTransitionApi::class)
@Composable
fun ModelNameAndStatus(
model: Model,
task: Task,
downloadStatus: ModelDownloadStatus?,
isExpanded: Boolean,
sharedTransitionScope: SharedTransitionScope,
animatedVisibilityScope: AnimatedVisibilityScope,
modifier: Modifier = Modifier
) {
val inProgress = downloadStatus?.status == ModelDownloadStatusType.IN_PROGRESS
val isPartiallyDownloaded = downloadStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED
var curDownloadProgress = 0f
with(sharedTransitionScope) {
Column(
horizontalAlignment = if (isExpanded) Alignment.CenterHorizontally else Alignment.Start
) {
@ -78,7 +86,10 @@ fun ModelNameAndStatus(
maxLines = 1,
overflow = TextOverflow.MiddleEllipsis,
style = MaterialTheme.typography.titleMedium,
modifier = modifier,
modifier = Modifier.sharedElement(
rememberSharedContentState(key = "model_name"),
animatedVisibilityScope = animatedVisibilityScope
)
)
}
@ -87,7 +98,12 @@ fun ModelNameAndStatus(
if (!inProgress && !isPartiallyDownloaded) {
StatusIcon(
downloadStatus = downloadStatus,
modifier = modifier.padding(end = 4.dp)
modifier = modifier
.padding(end = 4.dp)
.sharedElement(
rememberSharedContentState(key = "download_status_icon"),
animatedVisibilityScope = animatedVisibilityScope
)
)
}
@ -99,7 +115,10 @@ fun ModelNameAndStatus(
color = MaterialTheme.colorScheme.error,
style = labelSmallNarrow,
overflow = TextOverflow.Ellipsis,
modifier = modifier,
modifier = Modifier.sharedElement(
rememberSharedContentState(key = "failure_messsage"),
animatedVisibilityScope = animatedVisibilityScope
)
)
}
}
@ -154,7 +173,12 @@ fun ModelNameAndStatus(
style = labelSmallNarrow.copy(fontSize = fontSize, lineHeight = 10.sp),
textAlign = if (isExpanded) TextAlign.Center else TextAlign.Start,
overflow = TextOverflow.Visible,
modifier = modifier.offset(y = if (index == 0) 0.dp else (-1).dp)
modifier = Modifier
.offset(y = if (index == 0) 0.dp else (-1).dp)
.sharedElement(
rememberSharedContentState(key = "status_label_${index}"),
animatedVisibilityScope = animatedVisibilityScope
)
)
}
}
@ -168,7 +192,12 @@ fun ModelNameAndStatus(
progress = { animatedProgress.value },
color = getTaskIconColor(task = task),
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
modifier = modifier.padding(top = 2.dp)
modifier = Modifier
.padding(top = 2.dp)
.sharedElement(
rememberSharedContentState(key = "download_progress_bar"),
animatedVisibilityScope = animatedVisibilityScope
)
)
LaunchedEffect(curDownloadProgress) {
animatedProgress.animateTo(curDownloadProgress, animationSpec = tween(150))
@ -180,8 +209,13 @@ fun ModelNameAndStatus(
color = getTaskIconColor(task = task),
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
modifier = Modifier
.padding(top = 2.dp),
.padding(top = 2.dp)
.sharedElement(
rememberSharedContentState(key = "unzip_progress_bar"),
animatedVisibilityScope = animatedVisibilityScope
)
)
}
}
}
}

View file

@ -16,8 +16,11 @@
package com.google.aiedge.gallery.ui.home
import android.app.Activity
import android.content.Context
import android.content.Intent
import android.net.Uri
import android.provider.OpenableColumns
import android.util.Log
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.ActivityResultLauncher
@ -49,6 +52,7 @@ import androidx.compose.material.icons.automirrored.outlined.NoteAdd
import androidx.compose.material.icons.filled.Add
import androidx.compose.material.icons.rounded.Error
import androidx.compose.material3.AlertDialog
import androidx.compose.material3.Button
import androidx.compose.material3.Card
import androidx.compose.material3.CardDefaults
import androidx.compose.material3.CircularProgressIndicator
@ -80,12 +84,18 @@ import androidx.compose.ui.draw.clip
import androidx.compose.ui.draw.scale
import androidx.compose.ui.graphics.Brush
import androidx.compose.ui.input.nestedscroll.nestedScroll
import androidx.compose.ui.layout.layout
import androidx.compose.ui.platform.LocalConfiguration
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalDensity
import androidx.compose.ui.platform.LocalWindowInfo
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.LinkAnnotation
import androidx.compose.ui.text.SpanStyle
import androidx.compose.ui.text.TextLinkStyles
import androidx.compose.ui.text.buildAnnotatedString
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.text.style.TextDecoration
import androidx.compose.ui.text.withLink
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
@ -93,7 +103,6 @@ import com.google.aiedge.gallery.GalleryTopAppBar
import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.data.AppBarAction
import com.google.aiedge.gallery.data.AppBarActionType
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.ImportedModelInfo
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.TaskIcon
@ -101,7 +110,6 @@ import com.google.aiedge.gallery.ui.common.getTaskBgColor
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.ThemeSettings
import com.google.aiedge.gallery.ui.theme.customColors
import com.google.aiedge.gallery.ui.theme.titleMediumNarrow
import kotlinx.coroutines.delay
@ -109,6 +117,12 @@ import kotlinx.coroutines.launch
private const val TAG = "AGHomeScreen"
private const val TASK_COUNT_ANIMATION_DURATION = 250
private const val MAX_TASK_CARD_PADDING = 24
private const val MIN_TASK_CARD_PADDING = 18
private const val MAX_TASK_CARD_RADIUS = 43.5
private const val MIN_TASK_CARD_RADIUS = 30
private const val MAX_TASK_CARD_ICON_SIZE = 56
private const val MIN_TASK_CARD_ICON_SIZE = 50
/** Navigation destination data */
object HomeScreenDestination {
@ -127,6 +141,7 @@ fun HomeScreen(
val uiState by modelManagerViewModel.uiState.collectAsState()
var showSettingsDialog by remember { mutableStateOf(false) }
var showImportModelSheet by remember { mutableStateOf(false) }
var showUnsupportedFileTypeDialog by remember { mutableStateOf(false) }
val sheetState = rememberModalBottomSheetState()
var showImportDialog by remember { mutableStateOf(false) }
var showImportingDialog by remember { mutableStateOf(false) }
@ -135,17 +150,21 @@ fun HomeScreen(
val coroutineScope = rememberCoroutineScope()
val snackbarHostState = remember { SnackbarHostState() }
val scope = rememberCoroutineScope()
val nonEmptyTasks = uiState.tasks.filter { it.models.size > 0 }
val loadingHfModels = uiState.loadingHfModels
val context = LocalContext.current
val filePickerLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult(
contract = ActivityResultContracts.StartActivityForResult()
) { result ->
if (result.resultCode == android.app.Activity.RESULT_OK) {
result.data?.data?.let { uri ->
val fileName = getFileName(context = context, uri = uri)
Log.d(TAG, "Selected file: $fileName")
if (fileName != null && !fileName.endsWith(".task")) {
showUnsupportedFileTypeDialog = true
} else {
selectedLocalModelFileUri.value = uri
showImportDialog = true
}
} ?: run {
Log.d(TAG, "No file selected or URI is null.")
}
@ -154,21 +173,15 @@ fun HomeScreen(
}
}
Scaffold(
modifier = modifier.nestedScroll(scrollBehavior.nestedScrollConnection),
topBar = {
Scaffold(modifier = modifier.nestedScroll(scrollBehavior.nestedScrollConnection), topBar = {
GalleryTopAppBar(
title = stringResource(HomeScreenDestination.titleRes),
rightAction = AppBarAction(
actionType = AppBarActionType.APP_SETTING, actionFn = {
rightAction = AppBarAction(actionType = AppBarActionType.APP_SETTING, actionFn = {
showSettingsDialog = true
}
),
loadingHfModels = loadingHfModels,
}),
scrollBehavior = scrollBehavior,
)
},
floatingActionButton = {
}, floatingActionButton = {
// A floating action button to show "import model" bottom sheet.
SmallFloatingActionButton(
onClick = {
@ -179,11 +192,10 @@ fun HomeScreen(
) {
Icon(Icons.Filled.Add, "")
}
}
) { innerPadding ->
}) { innerPadding ->
Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.fillMaxSize()) {
TaskList(
tasks = nonEmptyTasks,
tasks = uiState.tasks,
navigateToTaskScreen = navigateToTaskScreen,
loadingModelAllowlist = uiState.loadingModelAllowlist,
modifier = Modifier.fillMaxSize(),
@ -198,16 +210,8 @@ fun HomeScreen(
if (showSettingsDialog) {
SettingsDialog(
curThemeOverride = modelManagerViewModel.readThemeOverride(),
modelManagerViewModel = modelManagerViewModel,
onDismissed = { showSettingsDialog = false },
onOk = { curConfigValues ->
// Update theme settings.
// This will update app's theme.
val themeOverride = curConfigValues[ConfigKey.THEME.label] as String
ThemeSettings.themeOverride.value = themeOverride
// Save to data store.
modelManagerViewModel.saveThemeOverride(themeOverride)
},
)
}
@ -232,10 +236,6 @@ fun HomeScreen(
val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
addCategory(Intent.CATEGORY_OPENABLE)
type = "*/*"
putExtra(
Intent.EXTRA_MIME_TYPES,
arrayOf("application/x-binary", "application/octet-stream")
)
// Single select.
putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
}
@ -259,9 +259,7 @@ fun HomeScreen(
// Import dialog
if (showImportDialog) {
selectedLocalModelFileUri.value?.let { uri ->
ModelImportDialog(uri = uri,
onDismiss = { showImportDialog = false },
onDone = { info ->
ModelImportDialog(uri = uri, onDismiss = { showImportDialog = false }, onDone = { info ->
selectedImportedModelInfo.value = info
showImportDialog = false
showImportingDialog = true
@ -273,8 +271,7 @@ fun HomeScreen(
if (showImportingDialog) {
selectedLocalModelFileUri.value?.let { uri ->
selectedImportedModelInfo.value?.let { info ->
ModelImportingDialog(
uri = uri,
ModelImportingDialog(uri = uri,
info = info,
onDismiss = { showImportingDialog = false },
onDone = {
@ -292,6 +289,22 @@ fun HomeScreen(
}
}
// Alert dialog for unsupported file type.
if (showUnsupportedFileTypeDialog) {
AlertDialog(
onDismissRequest = { showUnsupportedFileTypeDialog = false },
title = { Text("Unsupported file type") },
text = {
Text("Only \".task\" file type is supported.")
},
confirmButton = {
Button(onClick = { showUnsupportedFileTypeDialog = false }) {
Text(stringResource(R.string.ok))
}
},
)
}
if (uiState.loadingModelAllowlistError.isNotEmpty()) {
AlertDialog(
icon = {
@ -307,11 +320,9 @@ fun HomeScreen(
modelManagerViewModel.loadModelAllowlist()
},
confirmButton = {
TextButton(
onClick = {
TextButton(onClick = {
modelManagerViewModel.loadModelAllowlist()
}
) {
}) {
Text("Retry")
}
},
@ -327,6 +338,38 @@ private fun TaskList(
modifier: Modifier = Modifier,
contentPadding: PaddingValues = PaddingValues(0.dp),
) {
val density = LocalDensity.current
val windowInfo = LocalWindowInfo.current
val screenWidthDp = remember {
with(density) {
windowInfo.containerSize.width.toDp()
}
}
val screenHeightDp = remember {
with(density) {
windowInfo.containerSize.height.toDp()
}
}
val sizeFraction = remember { ((screenWidthDp - 360.dp) / (410.dp - 360.dp)).coerceIn(0f, 1f) }
val linkColor = MaterialTheme.customColors.linkColor
val introText = buildAnnotatedString {
append("Welcome to AI Edge Gallery! Explore a world of \namazing on-device models from ")
withLink(
link = LinkAnnotation.Url(
url = "https://huggingface.co/litert-community", // Replace with the actual URL
styles = TextLinkStyles(
style = SpanStyle(
color = linkColor,
textDecoration = TextDecoration.Underline,
)
)
)
) {
append("LiteRT community")
}
}
Box(modifier = modifier.fillMaxSize()) {
LazyVerticalGrid(
columns = GridCells.Fixed(count = 2),
@ -335,10 +378,15 @@ private fun TaskList(
horizontalArrangement = Arrangement.spacedBy(8.dp),
verticalArrangement = Arrangement.spacedBy(8.dp),
) {
// New rel
item(key = "newReleaseNotification", span = { GridItemSpan(2) }) {
NewReleaseNotification()
}
// Headline.
item(key = "headline", span = { GridItemSpan(2) }) {
Text(
"Welcome to AI Edge Gallery! Explore a world of \namazing on-device models from LiteRT community",
introText,
textAlign = TextAlign.Center,
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
modifier = Modifier.padding(bottom = 20.dp)
@ -364,14 +412,21 @@ private fun TaskList(
}
}
} else {
// Cards.
// LLM Cards.
item(key = "llmCardsHeader", span = { GridItemSpan(2) }) {
Text(
"Example LLM Use Cases",
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.SemiBold),
color = MaterialTheme.colorScheme.onSurfaceVariant,
modifier = Modifier.padding(bottom = 4.dp)
)
}
items(tasks) { task ->
TaskCard(
task = task,
onClick = {
sizeFraction = sizeFraction, task = task, onClick = {
navigateToTaskScreen(task)
},
modifier = Modifier
}, modifier = Modifier
.fillMaxWidth()
.aspectRatio(1f)
)
@ -388,7 +443,7 @@ private fun TaskList(
Box(
modifier = Modifier
.fillMaxWidth()
.height(LocalConfiguration.current.screenHeightDp.dp * 0.25f)
.height(screenHeightDp * 0.25f)
.background(
Brush.verticalGradient(
colors = MaterialTheme.customColors.homeBottomGradient,
@ -400,7 +455,15 @@ private fun TaskList(
}
@Composable
private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modifier) {
private fun TaskCard(
task: Task, onClick: () -> Unit, sizeFraction: Float, modifier: Modifier = Modifier
) {
val padding =
(MAX_TASK_CARD_PADDING - MIN_TASK_CARD_PADDING) * sizeFraction + MIN_TASK_CARD_PADDING
val radius = (MAX_TASK_CARD_RADIUS - MIN_TASK_CARD_RADIUS) * sizeFraction + MIN_TASK_CARD_RADIUS
val iconSize =
(MAX_TASK_CARD_ICON_SIZE - MIN_TASK_CARD_ICON_SIZE) * sizeFraction + MIN_TASK_CARD_ICON_SIZE
// Observes the model count and updates the model count label with a fade-in/fade-out animation
// whenever the count changes.
val modelCount by remember {
@ -445,7 +508,7 @@ private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modif
Card(
modifier = modifier
.clip(RoundedCornerShape(43.5.dp))
.clip(RoundedCornerShape(radius.dp))
.clickable(
onClick = onClick,
),
@ -456,39 +519,24 @@ private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modif
Column(
modifier = Modifier
.fillMaxSize()
.padding(24.dp),
.padding(padding.dp),
) {
// Icon.
TaskIcon(task = task)
TaskIcon(task = task, width = iconSize.dp)
Spacer(modifier = Modifier.weight(1f))
Spacer(modifier = Modifier.weight(2f))
// Title.
val pair = task.type.label.splitByFirstSpace()
Text(
pair.first,
task.type.label,
color = MaterialTheme.colorScheme.primary,
style = titleMediumNarrow.copy(
fontSize = 20.sp,
fontWeight = FontWeight.Bold,
),
)
if (pair.second.isNotEmpty()) {
Text(
pair.second,
color = MaterialTheme.colorScheme.primary,
style = titleMediumNarrow.copy(
fontSize = 18.sp,
fontWeight = FontWeight.Bold,
),
modifier = Modifier.layout { measurable, constraints ->
val placeable = measurable.measure(constraints)
layout(placeable.width, placeable.height) {
placeable.placeRelative(0, -4.dp.roundToPx())
}
}
)
}
Spacer(modifier = Modifier.weight(1f))
// Model count.
Text(
@ -503,12 +551,21 @@ private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modif
}
}
private fun String.splitByFirstSpace(): Pair<String, String> {
val spaceIndex = this.indexOf(' ')
if (spaceIndex == -1) {
return Pair(this, "")
// Helper function to get the file name from a URI
fun getFileName(context: Context, uri: Uri): String? {
if (uri.scheme == "content") {
context.contentResolver.query(uri, null, null, null, null)?.use { cursor ->
if (cursor.moveToFirst()) {
val nameIndex = cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME)
if (nameIndex != -1) {
return cursor.getString(nameIndex)
}
return Pair(this.substring(0, spaceIndex), this.substring(spaceIndex + 1))
}
}
} else if (uri.scheme == "file") {
return uri.lastPathSegment
}
return null
}
@Preview

View file

@ -22,12 +22,16 @@ import android.provider.OpenableColumns
import android.util.Log
import androidx.compose.animation.core.Animatable
import androidx.compose.animation.core.tween
import androidx.compose.foundation.clickable
import androidx.compose.foundation.interaction.MutableInteractionSource
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.foundation.verticalScroll
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.Error
import androidx.compose.material3.Button
@ -51,6 +55,7 @@ import androidx.compose.runtime.snapshots.SnapshotStateMap
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalFocusManager
import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog
import androidx.compose.ui.window.DialogProperties
@ -82,41 +87,34 @@ import java.nio.charset.StandardCharsets
private const val TAG = "AGModelImportDialog"
private val IMPORT_CONFIGS_LLM: List<Config> = listOf(
LabelConfig(key = ConfigKey.NAME),
LabelConfig(key = ConfigKey.MODEL_TYPE),
NumberSliderConfig(
LabelConfig(key = ConfigKey.NAME), LabelConfig(key = ConfigKey.MODEL_TYPE), NumberSliderConfig(
key = ConfigKey.DEFAULT_MAX_TOKENS,
sliderMin = 100f,
sliderMax = 1024f,
defaultValue = DEFAULT_MAX_TOKEN.toFloat(),
valueType = ValueType.INT
),
NumberSliderConfig(
), NumberSliderConfig(
key = ConfigKey.DEFAULT_TOPK,
sliderMin = 5f,
sliderMax = 40f,
defaultValue = DEFAULT_TOPK.toFloat(),
valueType = ValueType.INT
),
NumberSliderConfig(
), NumberSliderConfig(
key = ConfigKey.DEFAULT_TOPP,
sliderMin = 0.0f,
sliderMax = 1.0f,
defaultValue = DEFAULT_TOPP,
valueType = ValueType.FLOAT
),
NumberSliderConfig(
), NumberSliderConfig(
key = ConfigKey.DEFAULT_TEMPERATURE,
sliderMin = 0.0f,
sliderMax = 2.0f,
defaultValue = DEFAULT_TEMPERATURE,
valueType = ValueType.FLOAT
),
BooleanSwitchConfig(
), BooleanSwitchConfig(
key = ConfigKey.SUPPORT_IMAGE,
defaultValue = false,
),
SegmentedButtonConfig(
), SegmentedButtonConfig(
key = ConfigKey.COMPATIBLE_ACCELERATORS,
defaultValue = Accelerator.CPU.label,
options = listOf(Accelerator.CPU.label, Accelerator.GPU.label),
@ -126,9 +124,7 @@ private val IMPORT_CONFIGS_LLM: List<Config> = listOf(
@Composable
fun ModelImportDialog(
uri: Uri,
onDismiss: () -> Unit,
onDone: (ImportedModelInfo) -> Unit
uri: Uri, onDismiss: () -> Unit, onDone: (ImportedModelInfo) -> Unit
) {
val context = LocalContext.current
val info = remember { getFileSizeAndDisplayNameFromUri(context = context, uri = uri) }
@ -150,15 +146,23 @@ fun ModelImportDialog(
putAll(initialValues)
}
}
val interactionSource = remember { MutableInteractionSource() }
Dialog(
onDismissRequest = onDismiss,
) {
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
Column(
val focusManager = LocalFocusManager.current
Card(
modifier = Modifier
.padding(20.dp),
verticalArrangement = Arrangement.spacedBy(16.dp)
.fillMaxWidth()
.clickable(
interactionSource = interactionSource, indication = null // Disable the ripple effect
) {
focusManager.clearFocus()
}, shape = RoundedCornerShape(16.dp)
) {
Column(
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
) {
// Title.
Text(
@ -167,11 +171,18 @@ fun ModelImportDialog(
modifier = Modifier.padding(bottom = 8.dp)
)
Column(
modifier = Modifier
.verticalScroll(rememberScrollState())
.weight(1f, fill = false),
verticalArrangement = Arrangement.spacedBy(16.dp)
) {
// Default configs for users to set.
ConfigEditorsPanel(
configs = IMPORT_CONFIGS_LLM,
values = values,
)
}
// Button row.
Row(
@ -210,10 +221,7 @@ fun ModelImportDialog(
@Composable
fun ModelImportingDialog(
uri: Uri,
info: ImportedModelInfo,
onDismiss: () -> Unit,
onDone: (ImportedModelInfo) -> Unit
uri: Uri, info: ImportedModelInfo, onDismiss: () -> Unit, onDone: (ImportedModelInfo) -> Unit
) {
var error by remember { mutableStateOf("") }
val context = LocalContext.current
@ -222,8 +230,7 @@ fun ModelImportingDialog(
LaunchedEffect(Unit) {
// Import.
importModel(
context = context,
importModel(context = context,
coroutineScope = coroutineScope,
fileName = info.fileName,
fileSize = info.fileSize,
@ -236,8 +243,7 @@ fun ModelImportingDialog(
},
onError = {
error = it
}
)
})
}
Dialog(
@ -246,9 +252,7 @@ fun ModelImportingDialog(
) {
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
Column(
modifier = Modifier
.padding(20.dp),
verticalArrangement = Arrangement.spacedBy(16.dp)
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
) {
// Title.
Text(
@ -280,13 +284,10 @@ fun ModelImportingDialog(
// Has error.
else {
Row(
verticalAlignment = Alignment.Top,
horizontalArrangement = Arrangement.spacedBy(6.dp)
verticalAlignment = Alignment.Top, horizontalArrangement = Arrangement.spacedBy(6.dp)
) {
Icon(
Icons.Rounded.Error,
contentDescription = "",
tint = MaterialTheme.colorScheme.error
Icons.Rounded.Error, contentDescription = "", tint = MaterialTheme.colorScheme.error
)
Text(
error,

View file

@ -0,0 +1,121 @@
package com.google.aiedge.gallery.ui.home
import android.util.Log
import androidx.compose.animation.AnimatedVisibility
import androidx.compose.animation.expandVertically
import androidx.compose.animation.fadeIn
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.OpenInNew
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.BuildConfig
import com.google.aiedge.gallery.ui.common.getJsonResponse
import com.google.aiedge.gallery.ui.modelmanager.ClickableLink
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import kotlinx.serialization.Serializable
import kotlin.math.max
private const val TAG = "AGNewReleaseNotification"
private const val REPO = "google-ai-edge/gallery"
@Serializable
data class ReleaseInfo(
val html_url: String,
val tag_name: String,
)
@Composable
fun NewReleaseNotification() {
var newReleaseVersion by remember { mutableStateOf("") }
var newReleaseUrl by remember { mutableStateOf("") }
LaunchedEffect(Unit) {
withContext(Dispatchers.IO) {
Log.d("AGNewReleaseNotification", "Checking for new release...")
val info = getJsonResponse<ReleaseInfo>("https://api.github.com/repos/$REPO/releases/latest")
if (info != null) {
val curRelease = BuildConfig.VERSION_NAME
val newRelease = info.tag_name
val isNewer = isNewerRelease(currentRelease = curRelease, newRelease = newRelease)
Log.d(TAG, "curRelease: $curRelease, newRelease: $newRelease, isNewer: $isNewer")
if (isNewer) {
newReleaseVersion = newRelease
newReleaseUrl = info.html_url
}
}
}
}
AnimatedVisibility(
visible = newReleaseVersion.isNotEmpty(),
enter = fadeIn() + expandVertically()
) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.SpaceBetween,
modifier = Modifier
.padding(horizontal = 16.dp)
.padding(bottom = 12.dp)
.clip(
CircleShape
)
.background(MaterialTheme.colorScheme.tertiaryContainer)
.padding(4.dp)
) {
Text(
"New release $newReleaseVersion available",
style = MaterialTheme.typography.bodyMedium,
modifier = Modifier.padding(start = 12.dp)
)
Row(
modifier = Modifier.padding(end = 12.dp),
verticalAlignment = Alignment.CenterVertically
) {
ClickableLink(
url = newReleaseUrl,
linkText = "View",
icon = Icons.AutoMirrored.Rounded.OpenInNew,
)
}
}
}
}
fun isNewerRelease(currentRelease: String, newRelease: String): Boolean {
// Split the version strings into their individual components (e.g., "0.9.0" -> ["0", "9", "0"])
val currentComponents = currentRelease.split('.').map { it.toIntOrNull() ?: 0 }
val newComponents = newRelease.split('.').map { it.toIntOrNull() ?: 0 }
// Determine the maximum number of components to iterate through
val maxComponents = max(currentComponents.size, newComponents.size)
// Iterate through the components from left to right (major, minor, patch, etc.)
for (i in 0 until maxComponents) {
val currentComponent = currentComponents.getOrElse(i) { 0 }
val newComponent = newComponents.getOrElse(i) { 0 }
if (newComponent > currentComponent) {
return true
} else if (newComponent < currentComponent) {
return false
}
}
return false
}

View file

@ -16,45 +16,266 @@
package com.google.aiedge.gallery.ui.home
import androidx.compose.foundation.border
import androidx.compose.foundation.clickable
import androidx.compose.foundation.interaction.MutableInteractionSource
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.wrapContentHeight
import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.foundation.text.BasicTextField
import androidx.compose.foundation.verticalScroll
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.CheckCircle
import androidx.compose.material3.Button
import androidx.compose.material3.Card
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.MultiChoiceSegmentedButtonRow
import androidx.compose.material3.OutlinedButton
import androidx.compose.material3.SegmentedButton
import androidx.compose.material3.SegmentedButtonDefaults
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.focus.FocusRequester
import androidx.compose.ui.focus.focusRequester
import androidx.compose.ui.focus.onFocusChanged
import androidx.compose.ui.graphics.SolidColor
import androidx.compose.ui.platform.LocalFocusManager
import androidx.compose.ui.text.TextStyle
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog
import com.google.aiedge.gallery.BuildConfig
import com.google.aiedge.gallery.data.Config
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.SegmentedButtonConfig
import com.google.aiedge.gallery.ui.common.chat.ConfigDialog
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.theme.THEME_AUTO
import com.google.aiedge.gallery.ui.theme.THEME_DARK
import com.google.aiedge.gallery.ui.theme.THEME_LIGHT
import com.google.aiedge.gallery.ui.theme.ThemeSettings
import com.google.aiedge.gallery.ui.theme.labelSmallNarrow
import java.time.Instant
import java.time.ZoneId
import java.time.format.DateTimeFormatter
import java.util.Locale
import kotlin.math.min
private val CONFIGS: List<Config> = listOf(
SegmentedButtonConfig(
key = ConfigKey.THEME,
defaultValue = THEME_AUTO,
options = listOf(THEME_AUTO, THEME_LIGHT, THEME_DARK),
)
)
private val THEME_OPTIONS = listOf(THEME_AUTO, THEME_LIGHT, THEME_DARK)
@Composable
fun SettingsDialog(
curThemeOverride: String,
modelManagerViewModel: ModelManagerViewModel,
onDismissed: () -> Unit,
onOk: (Map<String, Any>) -> Unit,
) {
val initialValues = mapOf(
ConfigKey.THEME.label to curThemeOverride
)
ConfigDialog(
title = "Settings",
subtitle = "App version: ${BuildConfig.VERSION_NAME}",
okBtnLabel = "OK",
configs = CONFIGS,
initialValues = initialValues,
onDismissed = onDismissed,
onOk = { curConfigValues ->
onOk(curConfigValues)
var selectedTheme by remember { mutableStateOf(curThemeOverride) }
var hfToken by remember { mutableStateOf(modelManagerViewModel.getTokenStatusAndData().data) }
val dateFormatter = remember {
DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss").withZone(ZoneId.systemDefault())
.withLocale(Locale.getDefault())
}
var customHfToken by remember { mutableStateOf("") }
var isFocused by remember { mutableStateOf(false) }
val focusRequester = remember { FocusRequester() }
val interactionSource = remember { MutableInteractionSource() }
// Hide config dialog.
Dialog(onDismissRequest = onDismissed) {
val focusManager = LocalFocusManager.current
Card(
modifier = Modifier
.fillMaxWidth()
.clickable(
interactionSource = interactionSource, indication = null // Disable the ripple effect
) {
focusManager.clearFocus()
}, shape = RoundedCornerShape(16.dp)
) {
Column(
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
) {
// Dialog title and subtitle.
Column {
Text(
"Settings",
style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(bottom = 8.dp)
)
// Subtitle.
Text(
"App version: ${BuildConfig.VERSION_NAME}",
style = labelSmallNarrow,
color = MaterialTheme.colorScheme.onSurfaceVariant,
modifier = Modifier.offset(y = (-6).dp)
)
}
Column(
modifier = Modifier
.verticalScroll(rememberScrollState())
.weight(1f, fill = false),
verticalArrangement = Arrangement.spacedBy(16.dp)
) {
// Theme switcher.
Column(
modifier = Modifier.fillMaxWidth()
) {
Text(
"Theme",
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.Bold)
)
MultiChoiceSegmentedButtonRow {
THEME_OPTIONS.forEachIndexed { index, label ->
SegmentedButton(shape = SegmentedButtonDefaults.itemShape(
index = index, count = THEME_OPTIONS.size
), onCheckedChange = {
selectedTheme = label
// Update theme settings.
// This will update app's theme.
ThemeSettings.themeOverride.value = label
// Save to data store.
modelManagerViewModel.saveThemeOverride(label)
}, checked = label == selectedTheme, label = { Text(label) })
}
}
}
// HF Token management.
Column(
modifier = Modifier.fillMaxWidth(), verticalArrangement = Arrangement.spacedBy(4.dp)
) {
Text(
"HuggingFace access token",
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.Bold)
)
// Show the start of the token.
val curHfToken = hfToken
if (curHfToken != null) {
Text(
curHfToken.accessToken.substring(0, min(16, curHfToken.accessToken.length)) + "...",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
Text(
"Expired at: ${dateFormatter.format(Instant.ofEpochMilli(curHfToken.expiresAtMs))}",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
} else {
Text(
"Not available",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
Text(
"The token will be automatically retrieved when a gated model is downloaded",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
Row(horizontalArrangement = Arrangement.spacedBy(4.dp)) {
OutlinedButton(
onClick = {
modelManagerViewModel.clearAccessToken()
hfToken = null
}, enabled = curHfToken != null
) {
Text("Clear")
}
BasicTextField(
value = customHfToken,
singleLine = true,
modifier = Modifier
.fillMaxWidth()
.padding(top = 4.dp)
.focusRequester(focusRequester)
.onFocusChanged {
isFocused = it.isFocused
},
onValueChange = {
customHfToken = it
},
textStyle = TextStyle(color = MaterialTheme.colorScheme.onSurface),
cursorBrush = SolidColor(MaterialTheme.colorScheme.onSurface),
) { innerTextField ->
Box(
modifier = Modifier
.border(
width = if (isFocused) 2.dp else 1.dp,
color = if (isFocused) MaterialTheme.colorScheme.primary else MaterialTheme.colorScheme.outline,
shape = CircleShape,
)
.height(40.dp), contentAlignment = Alignment.CenterStart
) {
Row(verticalAlignment = Alignment.CenterVertically) {
Box(
modifier = Modifier
.padding(start = 16.dp)
.weight(1f)
) {
if (customHfToken.isEmpty()) {
Text(
"Enter token manually",
color = MaterialTheme.colorScheme.onSurfaceVariant,
style = MaterialTheme.typography.bodySmall
)
}
innerTextField()
}
if (customHfToken.isNotEmpty()) {
IconButton(
modifier = Modifier.offset(x = 1.dp),
onClick = {
modelManagerViewModel.saveAccessToken(
accessToken = customHfToken,
refreshToken = "",
expiresAt = System.currentTimeMillis() + 1000L * 60 * 60 * 24 * 365 * 10,
)
hfToken = modelManagerViewModel.getTokenStatusAndData().data
}) {
Icon(Icons.Rounded.CheckCircle, contentDescription = "")
}
}
}
}
}
}
}
}
// Button row.
Row(
modifier = Modifier
.fillMaxWidth()
.padding(top = 8.dp),
horizontalArrangement = Arrangement.End,
) {
// Close button
Button(
onClick = {
onDismissed()
},
)
) {
Text("Close")
}
}
}
}
}
}

View file

@ -34,9 +34,9 @@ object LlmChatDestination {
val route = "LlmChatRoute"
}
object LlmImageToTextDestination {
object LlmAskImageDestination {
@Serializable
val route = "LlmImageToTextRoute"
val route = "LlmAskImageRoute"
}
@Composable
@ -57,11 +57,11 @@ fun LlmChatScreen(
}
@Composable
fun LlmImageToTextScreen(
fun LlmAskImageScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
viewModel: LlmImageToTextViewModel = viewModel(
viewModel: LlmAskImageViewModel = viewModel(
factory = ViewModelProvider.Factory
),
) {
@ -129,12 +129,7 @@ fun ChatViewWrapper(
})
}
},
onBenchmarkClicked = { model, message, warmUpIterations, benchmarkIterations ->
if (message is ChatMessageText) {
viewModel.benchmark(
model = model, message = message
)
}
onBenchmarkClicked = { _, _, _, _ ->
},
onResetSessionClicked = { model ->
viewModel.resetSession(model = model)

View file

@ -22,7 +22,7 @@ import android.util.Log
import androidx.lifecycle.viewModelScope
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT
import com.google.aiedge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.aiedge.gallery.ui.common.chat.ChatMessageLoading
@ -49,6 +49,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
fun generateResponse(model: Model, input: String, image: Bitmap? = null, onError: () -> Unit) {
viewModelScope.launch(Dispatchers.Default) {
setInProgress(true)
setPreparing(true)
// Loading.
addMessage(
@ -88,6 +89,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
timeToFirstToken = (firstTokenTs - start) / 1000f
prefillSpeed = prefillTokens / timeToFirstToken
firstRun = false
setPreparing(false)
} else {
decodeTokens++
}
@ -137,10 +139,12 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
},
cleanUpListener = {
setInProgress(false)
setPreparing(false)
})
} catch (e: Exception) {
Log.e(TAG, "Error occurred while running inference", e)
setInProgress(false)
setPreparing(false)
onError()
}
}
@ -194,98 +198,6 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
}
}
fun benchmark(model: Model, message: ChatMessageText) {
viewModelScope.launch(Dispatchers.Default) {
setInProgress(true)
// Wait for model to be initialized.
while (model.instance == null) {
delay(100)
}
val instance = model.instance as LlmModelInstance
val prefillTokens = instance.session.sizeInTokens(message.content)
// Add the message to show benchmark results.
val benchmarkLlmResult = ChatMessageBenchmarkLlmResult(
orderedStats = STATS,
statValues = mutableMapOf(),
running = true,
latencyMs = -1f,
)
addMessage(model = model, message = benchmarkLlmResult)
// Run inference.
val result = StringBuilder()
var firstRun = true
var timeToFirstToken = 0f
var firstTokenTs = 0L
var decodeTokens = 0
var prefillSpeed = 0f
var decodeSpeed: Float
val start = System.currentTimeMillis()
var lastUpdateTime = 0L
LlmChatModelHelper.runInference(model = model,
input = message.content,
resultListener = { partialResult, done ->
val curTs = System.currentTimeMillis()
if (firstRun) {
firstTokenTs = System.currentTimeMillis()
timeToFirstToken = (firstTokenTs - start) / 1000f
prefillSpeed = prefillTokens / timeToFirstToken
firstRun = false
// Update message to show prefill speed.
replaceLastMessage(
model = model,
message = ChatMessageBenchmarkLlmResult(
orderedStats = STATS,
statValues = mutableMapOf(
"prefill_speed" to prefillSpeed,
"time_to_first_token" to timeToFirstToken,
"latency" to (curTs - start).toFloat() / 1000f,
),
running = false,
latencyMs = -1f,
),
type = ChatMessageType.BENCHMARK_LLM_RESULT,
)
} else {
decodeTokens++
}
result.append(partialResult)
if (curTs - lastUpdateTime > 500 || done) {
decodeSpeed = decodeTokens / ((curTs - firstTokenTs) / 1000f)
if (decodeSpeed.isNaN()) {
decodeSpeed = 0f
}
replaceLastMessage(
model = model, message = ChatMessageBenchmarkLlmResult(
orderedStats = STATS,
statValues = mutableMapOf(
"prefill_speed" to prefillSpeed,
"decode_speed" to decodeSpeed,
"time_to_first_token" to timeToFirstToken,
"latency" to (curTs - start).toFloat() / 1000f,
),
running = !done,
latencyMs = -1f,
), type = ChatMessageType.BENCHMARK_LLM_RESULT
)
lastUpdateTime = curTs
if (done) {
setInProgress(false)
}
}
},
cleanUpListener = {
setInProgress(false)
})
}
}
fun handleError(
context: Context,
model: Model,
@ -320,15 +232,8 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
)
// Re-generate the response automatically.
generateResponse(model = model, input = triggeredMessage.content, onError = {
handleError(
context = context,
model = model,
modelManagerViewModel = modelManagerViewModel,
triggeredMessage = triggeredMessage
)
})
generateResponse(model = model, input = triggeredMessage.content, onError = {})
}
}
class LlmImageToTextViewModel : LlmChatViewModel(curTask = TASK_LLM_IMAGE_TO_TEXT)
class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE)

View file

@ -73,6 +73,7 @@ fun LlmSingleTurnScreen(
) {
val task = viewModel.task
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val uiState by viewModel.uiState.collectAsState()
val selectedModel = modelManagerUiState.selectedModel
val scope = rememberCoroutineScope()
val context = LocalContext.current
@ -114,6 +115,8 @@ fun LlmSingleTurnScreen(
task = task,
model = selectedModel,
modelManagerViewModel = modelManagerViewModel,
inProgress = uiState.inProgress,
modelPreparing = uiState.preparing,
onConfigChanged = { _, _ -> },
onBackClicked = { handleNavigateUp() },
onModelSelected = { newSelectedModel ->

View file

@ -20,10 +20,9 @@ import android.util.Log
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
import com.google.aiedge.gallery.data.TASK_LLM_PROMPT_LAB
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.aiedge.gallery.ui.common.chat.ChatMessageLoading
import com.google.aiedge.gallery.ui.common.chat.Stat
import com.google.aiedge.gallery.ui.common.processLlmResponse
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
@ -44,9 +43,9 @@ data class LlmSingleTurnUiState(
val inProgress: Boolean = false,
/**
* Indicates whether the model is currently being initialized.
* Indicates whether the model is preparing (before outputting any result and after initializing).
*/
val initializing: Boolean = false,
val preparing: Boolean = false,
// model -> <template label -> response>
val responsesByModel: Map<String, Map<String, String>>,
@ -65,14 +64,14 @@ private val STATS = listOf(
Stat(id = "latency", label = "Latency", unit = "sec")
)
open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_USECASES) : ViewModel() {
open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewModel() {
private val _uiState = MutableStateFlow(createUiState(task = task))
val uiState = _uiState.asStateFlow()
fun generateResponse(model: Model, input: String) {
viewModelScope.launch(Dispatchers.Default) {
setInProgress(true)
setInitializing(true)
setPreparing(true)
// Wait for instance to be initialized.
while (model.instance == null) {
@ -98,7 +97,7 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_USECASES) : ViewMode
val curTs = System.currentTimeMillis()
if (firstRun) {
setInitializing(false)
setPreparing(false)
firstTokenTs = System.currentTimeMillis()
timeToFirstToken = (firstTokenTs - start) / 1000f
prefillSpeed = prefillTokens / timeToFirstToken
@ -148,7 +147,7 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_USECASES) : ViewMode
},
singleTurn = true,
cleanUpListener = {
setInitializing(false)
setPreparing(false)
setInProgress(false)
})
}
@ -167,8 +166,8 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_USECASES) : ViewMode
_uiState.update { _uiState.value.copy(inProgress = inProgress) }
}
fun setInitializing(initializing: Boolean) {
_uiState.update { _uiState.value.copy(initializing = initializing) }
fun setPreparing(preparing: Boolean) {
_uiState.update { _uiState.value.copy(preparing = preparing) }
}
fun updateResponse(model: Model, promptTemplateType: PromptTemplateType, response: String) {

View file

@ -339,7 +339,7 @@ fun PromptTemplatesPanel(
val modelInitializing =
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING
if (inProgress && !modelInitializing) {
if (inProgress && !modelInitializing && !uiState.preparing) {
IconButton(
onClick = {
onStopButtonClicked(model)

View file

@ -57,7 +57,7 @@ import androidx.compose.ui.platform.LocalClipboardManager
import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
import com.google.aiedge.gallery.data.TASK_LLM_PROMPT_LAB
import com.google.aiedge.gallery.ui.common.chat.MarkdownText
import com.google.aiedge.gallery.ui.common.chat.MessageBodyBenchmarkLlm
import com.google.aiedge.gallery.ui.common.chat.MessageBodyLoading
@ -76,11 +76,11 @@ fun ResponsePanel(
modelManagerViewModel: ModelManagerViewModel,
modifier: Modifier = Modifier,
) {
val task = TASK_LLM_USECASES
val task = TASK_LLM_PROMPT_LAB
val uiState by viewModel.uiState.collectAsState()
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val inProgress = uiState.inProgress
val initializing = uiState.initializing
val initializing = uiState.preparing
val selectedPromptTemplateType = uiState.selectedPromptTemplateType
val responseScrollState = rememberScrollState()
var selectedOptionIndex by remember { mutableIntStateOf(0) }

View file

@ -70,7 +70,7 @@ fun ModelList(
) {
// This is just to update "models" list when task.updateTrigger is updated so that the UI can
// be properly updated.
val models by remember {
val models by remember(task) {
derivedStateOf {
val trigger = task.updateTrigger.value
if (trigger >= 0) {
@ -80,7 +80,7 @@ fun ModelList(
}
}
}
val importedModels by remember {
val importedModels by remember(task) {
derivedStateOf {
val trigger = task.updateTrigger.value
if (trigger >= 0) {

View file

@ -37,15 +37,16 @@ import com.google.aiedge.gallery.data.ModelAllowlist
import com.google.aiedge.gallery.data.ModelDownloadStatus
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.TASKS
import com.google.aiedge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
import com.google.aiedge.gallery.data.TASK_LLM_PROMPT_LAB
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.TaskType
import com.google.aiedge.gallery.data.ValueType
import com.google.aiedge.gallery.data.getModelByName
import com.google.aiedge.gallery.ui.common.AuthConfig
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
import com.google.aiedge.gallery.ui.common.getJsonResponse
import com.google.aiedge.gallery.ui.common.processTasks
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper
@ -59,8 +60,6 @@ import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.json.Json
import net.openid.appauth.AuthorizationException
import net.openid.appauth.AuthorizationRequest
import net.openid.appauth.AuthorizationResponse
@ -73,7 +72,7 @@ import java.net.URL
private const val TAG = "AGModelManagerViewModel"
private const val TEXT_INPUT_HISTORY_MAX_SIZE = 50
private const val MODEL_ALLOWLIST_URL =
"https://raw.githubusercontent.com/jinjingforever/kokoro-codelab-jingjin/refs/heads/main/model_allowlist.json"
"https://raw.githubusercontent.com/google-ai-edge/gallery/refs/heads/main/model_allowlist.json"
data class ModelInitializationStatus(
val status: ModelInitializationStatusType, var error: String = ""
@ -116,11 +115,6 @@ data class ModelManagerUiState(
*/
val modelInitializationStatus: Map<String, ModelInitializationStatus>,
/**
* Whether Hugging Face models from the given community are currently being loaded.
*/
val loadingHfModels: Boolean = false,
/**
* Whether the app is loading and processing the model allowlist.
*/
@ -196,8 +190,9 @@ open class ModelManagerViewModel(
)
}
fun cancelDownloadModel(model: Model) {
fun cancelDownloadModel(task: Task, model: Model) {
downloadRepository.cancelDownloadModel(model)
deleteModel(task = task, model = model)
}
fun deleteModel(task: Task, model: Model) {
@ -311,13 +306,13 @@ open class ModelManagerViewModel(
onDone = onDone,
)
TaskType.LLM_USECASES -> LlmChatModelHelper.initialize(
TaskType.LLM_PROMPT_LAB -> LlmChatModelHelper.initialize(
context = context,
model = model,
onDone = onDone,
)
TaskType.LLM_IMAGE_TO_TEXT -> LlmChatModelHelper.initialize(
TaskType.LLM_ASK_IMAGE -> LlmChatModelHelper.initialize(
context = context,
model = model,
onDone = onDone,
@ -341,8 +336,8 @@ open class ModelManagerViewModel(
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model)
TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.cleanUp(model = model)
TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model)
TaskType.LLM_USECASES -> LlmChatModelHelper.cleanUp(model = model)
TaskType.LLM_IMAGE_TO_TEXT -> LlmChatModelHelper.cleanUp(model = model)
TaskType.LLM_PROMPT_LAB -> LlmChatModelHelper.cleanUp(model = model)
TaskType.LLM_ASK_IMAGE -> LlmChatModelHelper.cleanUp(model = model)
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.cleanUp(model = model)
TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {}
@ -446,14 +441,14 @@ open class ModelManagerViewModel(
// Create model.
val model = createModelFromImportedModelInfo(info = info)
for (task in listOf(TASK_LLM_CHAT, TASK_LLM_USECASES, TASK_LLM_IMAGE_TO_TEXT)) {
for (task in listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)) {
// Remove duplicated imported model if existed.
val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported }
if (modelIndex >= 0) {
Log.d(TAG, "duplicated imported model found in task. Removing it first")
task.models.removeAt(modelIndex)
}
if (task == TASK_LLM_IMAGE_TO_TEXT && model.llmSupportImage || task != TASK_LLM_IMAGE_TO_TEXT) {
if (task == TASK_LLM_ASK_IMAGE && model.llmSupportImage || task != TASK_LLM_ASK_IMAGE) {
task.models.add(model)
}
task.updateTrigger.value = System.currentTimeMillis()
@ -502,7 +497,7 @@ open class ModelManagerViewModel(
// Check expiration (with 5-minute buffer).
val curTs = System.currentTimeMillis()
val expirationTs = tokenData.expiresAtSeconds - 5 * 60
val expirationTs = tokenData.expiresAtMs - 5 * 60
Log.d(
TAG,
"Checking whether token has expired or not. Current ts: $curTs, expires at: $expirationTs"
@ -562,7 +557,7 @@ open class ModelManagerViewModel(
} else {
// Token exchange successful. Store the tokens securely
Log.d(TAG, "Token exchange successful. Storing tokens...")
dataStoreRepository.saveAccessTokenData(
saveAccessToken(
accessToken = tokenResponse.accessToken!!,
refreshToken = tokenResponse.refreshToken!!,
expiresAt = tokenResponse.accessTokenExpirationTime!!
@ -606,6 +601,18 @@ open class ModelManagerViewModel(
}
}
fun saveAccessToken(accessToken: String, refreshToken: String, expiresAt: Long) {
dataStoreRepository.saveAccessTokenData(
accessToken = accessToken,
refreshToken = refreshToken,
expiresAt = expiresAt,
)
}
fun clearAccessToken() {
dataStoreRepository.clearAccessTokenData()
}
private fun processPendingDownloads() {
Log.d(TAG, "In-progress worker infos: $inProgressWorkInfos")
@ -673,11 +680,11 @@ open class ModelManagerViewModel(
if (allowedModel.taskTypes.contains(TASK_LLM_CHAT.type.id)) {
TASK_LLM_CHAT.models.add(model)
}
if (allowedModel.taskTypes.contains(TASK_LLM_USECASES.type.id)) {
TASK_LLM_USECASES.models.add(model)
if (allowedModel.taskTypes.contains(TASK_LLM_PROMPT_LAB.type.id)) {
TASK_LLM_PROMPT_LAB.models.add(model)
}
if (allowedModel.taskTypes.contains(TASK_LLM_IMAGE_TO_TEXT.type.id)) {
TASK_LLM_IMAGE_TO_TEXT.models.add(model)
if (allowedModel.taskTypes.contains(TASK_LLM_ASK_IMAGE.type.id)) {
TASK_LLM_ASK_IMAGE.models.add(model)
}
}
@ -732,9 +739,9 @@ open class ModelManagerViewModel(
// Add to task.
TASK_LLM_CHAT.models.add(model)
TASK_LLM_USECASES.models.add(model)
TASK_LLM_PROMPT_LAB.models.add(model)
if (model.llmSupportImage) {
TASK_LLM_IMAGE_TO_TEXT.models.add(model)
TASK_LLM_ASK_IMAGE.models.add(model)
}
// Update status.
@ -827,38 +834,6 @@ open class ModelManagerViewModel(
)
}
@OptIn(ExperimentalSerializationApi::class)
private inline fun <reified T> getJsonResponse(url: String): T? {
try {
val connection = URL(url).openConnection() as HttpURLConnection
connection.requestMethod = "GET"
connection.connect()
val responseCode = connection.responseCode
if (responseCode == HttpURLConnection.HTTP_OK) {
val inputStream = connection.inputStream
val response = inputStream.bufferedReader().use { it.readText() }
// Parse JSON using kotlinx.serialization
val json = Json {
// Handle potential extra fields
ignoreUnknownKeys = true
allowComments = true
allowTrailingComma = true
}
val jsonObj = json.decodeFromString<T>(response)
return jsonObj
} else {
Log.e(TAG, "HTTP error: $responseCode")
}
} catch (e: Exception) {
Log.e(TAG, "Error when getting json response: ${e.message}")
e.printStackTrace()
}
return null
}
private fun isFileInExternalFilesDir(fileName: String): Boolean {
if (externalFilesDir != null) {
val file = File(externalFilesDir, fileName)

View file

@ -46,8 +46,8 @@ import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION
import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
import com.google.aiedge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.aiedge.gallery.data.TASK_LLM_PROMPT_LAB
import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.TaskType
@ -60,8 +60,8 @@ import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationDestination
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationScreen
import com.google.aiedge.gallery.ui.llmchat.LlmChatDestination
import com.google.aiedge.gallery.ui.llmchat.LlmChatScreen
import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextDestination
import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextScreen
import com.google.aiedge.gallery.ui.llmchat.LlmAskImageDestination
import com.google.aiedge.gallery.ui.llmchat.LlmAskImageScreen
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnDestination
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnScreen
import com.google.aiedge.gallery.ui.modelmanager.ModelManager
@ -233,7 +233,7 @@ fun GalleryNavHost(
enterTransition = { slideEnter() },
exitTransition = { slideExit() },
) {
getModelFromNavigationParam(it, TASK_LLM_USECASES)?.let { defaultModel ->
getModelFromNavigationParam(it, TASK_LLM_PROMPT_LAB)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel)
LlmSingleTurnScreen(
@ -245,15 +245,15 @@ fun GalleryNavHost(
// LLM image to text.
composable(
route = "${LlmImageToTextDestination.route}/{modelName}",
route = "${LlmAskImageDestination.route}/{modelName}",
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() },
exitTransition = { slideExit() },
) {
getModelFromNavigationParam(it, TASK_LLM_IMAGE_TO_TEXT)?.let { defaultModel ->
getModelFromNavigationParam(it, TASK_LLM_ASK_IMAGE)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel)
LlmImageToTextScreen(
LlmAskImageScreen(
modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() },
)
@ -287,8 +287,8 @@ fun navigateToTaskScreen(
TaskType.TEXT_CLASSIFICATION -> navController.navigate("${TextClassificationDestination.route}/${modelName}")
TaskType.IMAGE_CLASSIFICATION -> navController.navigate("${ImageClassificationDestination.route}/${modelName}")
TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
TaskType.LLM_IMAGE_TO_TEXT -> navController.navigate("${LlmImageToTextDestination.route}/${modelName}")
TaskType.LLM_USECASES -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
TaskType.LLM_ASK_IMAGE -> navController.navigate("${LlmAskImageDestination.route}/${modelName}")
TaskType.LLM_PROMPT_LAB -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
TaskType.IMAGE_GENERATION -> navController.navigate("${ImageGenerationDestination.route}/${modelName}")
TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {}

View file

@ -42,6 +42,9 @@ class PreviewDataStoreRepository : DataStoreRepository {
return null
}
override fun clearAccessTokenData() {
}
override fun saveImportedModels(importedModels: List<ImportedModelInfo>) {
}