mirror of
https://github.com/google-ai-edge/gallery.git
synced 2025-07-15 10:46:38 -04:00
Various bug fixes.
This commit is contained in:
parent
0c49efc054
commit
37a58d1a41
35 changed files with 1517 additions and 995 deletions
|
@ -27,10 +27,10 @@ android {
|
|||
|
||||
defaultConfig {
|
||||
applicationId = "com.google.aiedge.gallery"
|
||||
minSdk = 24
|
||||
minSdk = 26
|
||||
targetSdk = 35
|
||||
versionCode = 1
|
||||
versionName = "20250428"
|
||||
versionName = "0.9.0"
|
||||
|
||||
// Needed for HuggingFace auth workflows.
|
||||
manifestPlaceholders["appAuthRedirectScheme"] = "com.google.aiedge.gallery.oauth"
|
||||
|
|
|
@ -41,7 +41,9 @@
|
|||
android:name=".MainActivity"
|
||||
android:exported="true"
|
||||
android:theme="@style/Theme.Gallery.SplashScreen"
|
||||
android:windowSoftInputMode="adjustResize">
|
||||
android:screenOrientation="portrait"
|
||||
android:windowSoftInputMode="adjustResize"
|
||||
tools:ignore="DiscouragedApi,LockedOrientationActivity">
|
||||
<!-- This is for putting the app into launcher -->
|
||||
<intent-filter>
|
||||
<action android:name="android.intent.action.MAIN" />
|
||||
|
|
|
@ -68,7 +68,6 @@ fun GalleryTopAppBar(
|
|||
leftAction: AppBarAction? = null,
|
||||
rightAction: AppBarAction? = null,
|
||||
scrollBehavior: TopAppBarScrollBehavior? = null,
|
||||
loadingHfModels: Boolean = false,
|
||||
subtitle: String = "",
|
||||
) {
|
||||
CenterAlignedTopAppBar(
|
||||
|
@ -151,28 +150,6 @@ fun GalleryTopAppBar(
|
|||
}
|
||||
}
|
||||
|
||||
// Click an icon to open "download manager".
|
||||
AppBarActionType.DOWNLOAD_MANAGER -> {
|
||||
if (loadingHfModels) {
|
||||
CircularProgressIndicator(
|
||||
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
|
||||
strokeWidth = 3.dp,
|
||||
modifier = Modifier
|
||||
.padding(end = 12.dp)
|
||||
.size(20.dp)
|
||||
)
|
||||
}
|
||||
// else {
|
||||
// IconButton(onClick = rightAction.actionFn) {
|
||||
// Icon(
|
||||
// imageVector = Deployed_code,
|
||||
// contentDescription = "",
|
||||
// tint = MaterialTheme.colorScheme.primary
|
||||
// )
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
AppBarActionType.MODEL_SELECTOR -> {
|
||||
Text("ms")
|
||||
}
|
||||
|
|
|
@ -37,7 +37,7 @@ import javax.crypto.SecretKey
|
|||
data class AccessTokenData(
|
||||
val accessToken: String,
|
||||
val refreshToken: String,
|
||||
val expiresAtSeconds: Long
|
||||
val expiresAtMs: Long
|
||||
)
|
||||
|
||||
interface DataStoreRepository {
|
||||
|
@ -46,6 +46,7 @@ interface DataStoreRepository {
|
|||
fun saveThemeOverride(theme: String)
|
||||
fun readThemeOverride(): String
|
||||
fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long)
|
||||
fun clearAccessTokenData()
|
||||
fun readAccessTokenData(): AccessTokenData?
|
||||
fun saveImportedModels(importedModels: List<ImportedModelInfo>)
|
||||
fun readImportedModels(): List<ImportedModelInfo>
|
||||
|
@ -135,6 +136,18 @@ class DefaultDataStoreRepository(
|
|||
}
|
||||
}
|
||||
|
||||
override fun clearAccessTokenData() {
|
||||
return runBlocking {
|
||||
dataStore.edit { preferences ->
|
||||
preferences.remove(PreferencesKeys.ENCRYPTED_ACCESS_TOKEN)
|
||||
preferences.remove(PreferencesKeys.ACCESS_TOKEN_IV)
|
||||
preferences.remove(PreferencesKeys.ENCRYPTED_REFRESH_TOKEN)
|
||||
preferences.remove(PreferencesKeys.REFRESH_TOKEN_IV)
|
||||
preferences.remove(PreferencesKeys.ACCESS_TOKEN_EXPIRES_AT)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun readAccessTokenData(): AccessTokenData? {
|
||||
return runBlocking {
|
||||
val preferences = dataStore.data.first()
|
||||
|
|
|
@ -43,7 +43,7 @@ data class AllowedModel(
|
|||
|
||||
// Config.
|
||||
val isLlmModel =
|
||||
taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_USECASES.type.id)
|
||||
taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_PROMPT_LAB.type.id)
|
||||
var configs: List<Config> = listOf()
|
||||
if (isLlmModel) {
|
||||
var defaultTopK: Int = DEFAULT_TOPK
|
||||
|
|
|
@ -32,9 +32,9 @@ enum class TaskType(val label: String, val id: String) {
|
|||
TEXT_CLASSIFICATION(label = "Text Classification", id = "text_classification"),
|
||||
IMAGE_CLASSIFICATION(label = "Image Classification", id = "image_classification"),
|
||||
IMAGE_GENERATION(label = "Image Generation", id = "image_generation"),
|
||||
LLM_CHAT(label = "LLM Chat", id = "llm_chat"),
|
||||
LLM_USECASES(label = "LLM Use Cases", id = "llm_usecases"),
|
||||
LLM_IMAGE_TO_TEXT(label = "LLM Image to Text", id = "llm_image_to_text"),
|
||||
LLM_CHAT(label = "AI Chat", id = "llm_chat"),
|
||||
LLM_PROMPT_LAB(label = "Prompt Lab", id = "llm_prompt_lab"),
|
||||
LLM_ASK_IMAGE(label = "Ask Image", id = "llm_ask_image"),
|
||||
|
||||
TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
|
||||
TEST_TASK_2(label = "Test task 2", id = "test_task_2")
|
||||
|
@ -93,7 +93,6 @@ val TASK_IMAGE_CLASSIFICATION = Task(
|
|||
val TASK_LLM_CHAT = Task(
|
||||
type = TaskType.LLM_CHAT,
|
||||
icon = Icons.Outlined.Forum,
|
||||
// models = MODELS_LLM,
|
||||
models = mutableListOf(),
|
||||
description = "Chat with on-device large language models",
|
||||
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
||||
|
@ -101,10 +100,9 @@ val TASK_LLM_CHAT = Task(
|
|||
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
|
||||
)
|
||||
|
||||
val TASK_LLM_USECASES = Task(
|
||||
type = TaskType.LLM_USECASES,
|
||||
val TASK_LLM_PROMPT_LAB = Task(
|
||||
type = TaskType.LLM_PROMPT_LAB,
|
||||
icon = Icons.Outlined.Widgets,
|
||||
// models = MODELS_LLM,
|
||||
models = mutableListOf(),
|
||||
description = "Single turn use cases with on-device large language model",
|
||||
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
||||
|
@ -112,10 +110,9 @@ val TASK_LLM_USECASES = Task(
|
|||
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
|
||||
)
|
||||
|
||||
val TASK_LLM_IMAGE_TO_TEXT = Task(
|
||||
type = TaskType.LLM_IMAGE_TO_TEXT,
|
||||
val TASK_LLM_ASK_IMAGE = Task(
|
||||
type = TaskType.LLM_ASK_IMAGE,
|
||||
icon = Icons.Outlined.Mms,
|
||||
// models = MODELS_LLM,
|
||||
models = mutableListOf(),
|
||||
description = "Ask questions about images with on-device large language models",
|
||||
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
||||
|
@ -135,12 +132,9 @@ val TASK_IMAGE_GENERATION = Task(
|
|||
|
||||
/** All tasks. */
|
||||
val TASKS: List<Task> = listOf(
|
||||
// TASK_TEXT_CLASSIFICATION,
|
||||
// TASK_IMAGE_CLASSIFICATION,
|
||||
// TASK_IMAGE_GENERATION,
|
||||
TASK_LLM_USECASES,
|
||||
TASK_LLM_ASK_IMAGE,
|
||||
TASK_LLM_PROMPT_LAB,
|
||||
TASK_LLM_CHAT,
|
||||
TASK_LLM_IMAGE_TO_TEXT
|
||||
)
|
||||
|
||||
fun getModelByName(name: String): Model? {
|
||||
|
|
|
@ -25,7 +25,7 @@ import com.google.aiedge.gallery.GalleryApplication
|
|||
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationViewModel
|
||||
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationViewModel
|
||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatViewModel
|
||||
import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextViewModel
|
||||
import com.google.aiedge.gallery.ui.llmchat.LlmAskImageViewModel
|
||||
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.textclassification.TextClassificationViewModel
|
||||
|
@ -63,9 +63,9 @@ object ViewModelProvider {
|
|||
LlmSingleTurnViewModel()
|
||||
}
|
||||
|
||||
// Initializer for LlmImageToTextViewModel.
|
||||
// Initializer for LlmAskImageViewModel.
|
||||
initializer {
|
||||
LlmImageToTextViewModel()
|
||||
LlmAskImageViewModel()
|
||||
}
|
||||
|
||||
// Initializer for ImageGenerationViewModel.
|
||||
|
|
|
@ -26,6 +26,8 @@ import androidx.browser.customtabs.CustomTabsIntent
|
|||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.layout.wrapContentHeight
|
||||
import androidx.compose.foundation.text.BasicText
|
||||
import androidx.compose.foundation.text.TextAutoSize
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.automirrored.rounded.ArrowForward
|
||||
import androidx.compose.material.icons.rounded.Error
|
||||
|
@ -48,6 +50,7 @@ import androidx.compose.ui.Alignment
|
|||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.unit.sp
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
|
@ -291,11 +294,32 @@ fun DownloadAndTryButton(
|
|||
modifier = Modifier.padding(end = 4.dp)
|
||||
)
|
||||
|
||||
val textColor = MaterialTheme.colorScheme.onPrimary
|
||||
if (checkingToken) {
|
||||
Text("Checking access...")
|
||||
BasicText(
|
||||
text = "Checking access...",
|
||||
maxLines = 1,
|
||||
color = { textColor },
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
autoSize = TextAutoSize.StepBased(
|
||||
minFontSize = 8.sp,
|
||||
maxFontSize = 14.sp,
|
||||
stepSize = 1.sp
|
||||
)
|
||||
)
|
||||
} else {
|
||||
if (needToDownloadFirst) {
|
||||
Text("Download & Try it", maxLines = 1)
|
||||
BasicText(
|
||||
text = "Download & Try",
|
||||
maxLines = 1,
|
||||
color = { textColor },
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
autoSize = TextAutoSize.StepBased(
|
||||
minFontSize = 8.sp,
|
||||
maxFontSize = 14.sp,
|
||||
stepSize = 1.sp
|
||||
)
|
||||
)
|
||||
} else {
|
||||
Text("Try it", maxLines = 1)
|
||||
}
|
||||
|
|
|
@ -63,6 +63,8 @@ fun ModelPageAppBar(
|
|||
modelManagerViewModel: ModelManagerViewModel,
|
||||
onBackClicked: () -> Unit,
|
||||
onModelSelected: (Model) -> Unit,
|
||||
inProgress: Boolean,
|
||||
modelPreparing: Boolean,
|
||||
modifier: Modifier = Modifier,
|
||||
isResettingSession: Boolean = false,
|
||||
onResetSessionClicked: (Model) -> Unit = {},
|
||||
|
@ -129,15 +131,16 @@ fun ModelPageAppBar(
|
|||
val isModelInitializing =
|
||||
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING
|
||||
if (showConfigButton) {
|
||||
val enableConfigButton = !isModelInitializing && !inProgress
|
||||
IconButton(
|
||||
onClick = {
|
||||
showConfigDialog = true
|
||||
},
|
||||
enabled = !isModelInitializing,
|
||||
enabled = enableConfigButton,
|
||||
modifier = Modifier
|
||||
.scale(0.75f)
|
||||
.offset(x = configButtonOffset)
|
||||
.alpha(if (isModelInitializing) 0.5f else 1f)
|
||||
.alpha(if (!enableConfigButton) 0.5f else 1f)
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Rounded.Tune,
|
||||
|
@ -154,14 +157,15 @@ fun ModelPageAppBar(
|
|||
modifier = Modifier.size(16.dp)
|
||||
)
|
||||
} else {
|
||||
val enableResetButton = !isModelInitializing && !modelPreparing
|
||||
IconButton(
|
||||
onClick = {
|
||||
onResetSessionClicked(model)
|
||||
},
|
||||
enabled = !isModelInitializing,
|
||||
enabled = enableResetButton,
|
||||
modifier = Modifier
|
||||
.scale(0.75f)
|
||||
.alpha(if (isModelInitializing) 0.5f else 1f)
|
||||
.alpha(if (!enableResetButton) 0.5f else 1f)
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Rounded.MapsUgc,
|
||||
|
|
|
@ -83,8 +83,11 @@ fun ModelPickerChipsPager(
|
|||
val scope = rememberCoroutineScope()
|
||||
val density = LocalDensity.current
|
||||
val windowInfo = LocalWindowInfo.current
|
||||
val screenWidthDp =
|
||||
remember { with(density) { windowInfo.containerSize.width.toDp() } }
|
||||
val screenWidthDp = remember {
|
||||
with(density) {
|
||||
windowInfo.containerSize.width.toDp()
|
||||
}
|
||||
}
|
||||
|
||||
val pagerState = rememberPagerState(initialPage = task.models.indexOf(initialModel),
|
||||
pageCount = { task.models.size })
|
||||
|
@ -147,7 +150,7 @@ fun ModelPickerChipsPager(
|
|||
}
|
||||
Text(
|
||||
model.name,
|
||||
style = MaterialTheme.typography.labelSmall,
|
||||
style = MaterialTheme.typography.labelMedium,
|
||||
modifier = Modifier
|
||||
.padding(start = 4.dp)
|
||||
.widthIn(0.dp, screenWidthDp - 250.dp),
|
||||
|
|
|
@ -21,6 +21,7 @@ import android.content.Context
|
|||
import android.content.pm.PackageManager
|
||||
import android.net.Uri
|
||||
import android.os.Build
|
||||
import android.util.Log
|
||||
import androidx.activity.compose.ManagedActivityResultLauncher
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.runtime.Composable
|
||||
|
@ -39,7 +40,11 @@ import com.google.aiedge.gallery.ui.common.chat.Histogram
|
|||
import com.google.aiedge.gallery.ui.common.chat.Stat
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.theme.customColors
|
||||
import kotlinx.serialization.ExperimentalSerializationApi
|
||||
import kotlinx.serialization.json.Json
|
||||
import java.io.File
|
||||
import java.net.HttpURLConnection
|
||||
import java.net.URL
|
||||
import kotlin.math.abs
|
||||
import kotlin.math.ln
|
||||
import kotlin.math.max
|
||||
|
@ -488,3 +493,38 @@ fun processLlmResponse(response: String): String {
|
|||
|
||||
return newContent
|
||||
}
|
||||
|
||||
@OptIn(ExperimentalSerializationApi::class)
|
||||
inline fun <reified T> getJsonResponse(url: String): T? {
|
||||
try {
|
||||
val connection = URL(url).openConnection() as HttpURLConnection
|
||||
connection.requestMethod = "GET"
|
||||
connection.connect()
|
||||
|
||||
val responseCode = connection.responseCode
|
||||
if (responseCode == HttpURLConnection.HTTP_OK) {
|
||||
val inputStream = connection.inputStream
|
||||
val response = inputStream.bufferedReader().use { it.readText() }
|
||||
|
||||
// Parse JSON using kotlinx.serialization
|
||||
val json = Json {
|
||||
// Handle potential extra fields
|
||||
ignoreUnknownKeys = true
|
||||
allowComments = true
|
||||
allowTrailingComma = true
|
||||
}
|
||||
val jsonObj = json.decodeFromString<T>(response)
|
||||
return jsonObj
|
||||
} else {
|
||||
Log.e("AGUtils", "HTTP error: $responseCode")
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.e(
|
||||
"AGUtils",
|
||||
"Error when getting json response: ${e.message}"
|
||||
)
|
||||
e.printStackTrace()
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
|
|
@ -16,9 +16,14 @@
|
|||
|
||||
package com.google.aiedge.gallery.ui.common.chat
|
||||
|
||||
import androidx.compose.animation.AnimatedContent
|
||||
import androidx.compose.animation.ExperimentalSharedTransitionApi
|
||||
import androidx.compose.animation.SharedTransitionLayout
|
||||
import androidx.compose.foundation.Image
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.gestures.detectTapGestures
|
||||
import androidx.compose.foundation.gestures.detectTransformGestures
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
|
@ -28,6 +33,7 @@ import androidx.compose.foundation.layout.fillMaxSize
|
|||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.ime
|
||||
import androidx.compose.foundation.layout.imePadding
|
||||
import androidx.compose.foundation.layout.offset
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.layout.size
|
||||
import androidx.compose.foundation.layout.width
|
||||
|
@ -39,12 +45,15 @@ import androidx.compose.foundation.lazy.rememberLazyListState
|
|||
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.outlined.Timer
|
||||
import androidx.compose.material.icons.rounded.Close
|
||||
import androidx.compose.material.icons.rounded.ContentCopy
|
||||
import androidx.compose.material.icons.rounded.Refresh
|
||||
import androidx.compose.material3.Button
|
||||
import androidx.compose.material3.Card
|
||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.IconButton
|
||||
import androidx.compose.material3.IconButtonDefaults
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.ModalBottomSheet
|
||||
import androidx.compose.material3.SnackbarHost
|
||||
|
@ -55,6 +64,7 @@ import androidx.compose.runtime.LaunchedEffect
|
|||
import androidx.compose.runtime.MutableState
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableFloatStateOf
|
||||
import androidx.compose.runtime.mutableIntStateOf
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
|
@ -65,12 +75,16 @@ import androidx.compose.ui.Modifier
|
|||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.geometry.Offset
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.graphics.RectangleShape
|
||||
import androidx.compose.ui.graphics.asImageBitmap
|
||||
import androidx.compose.ui.graphics.graphicsLayer
|
||||
import androidx.compose.ui.hapticfeedback.HapticFeedbackType
|
||||
import androidx.compose.ui.input.nestedscroll.NestedScrollConnection
|
||||
import androidx.compose.ui.input.nestedscroll.NestedScrollSource
|
||||
import androidx.compose.ui.input.nestedscroll.nestedScroll
|
||||
import androidx.compose.ui.input.pointer.pointerInput
|
||||
import androidx.compose.ui.layout.ContentScale
|
||||
import androidx.compose.ui.layout.onSizeChanged
|
||||
import androidx.compose.ui.platform.LocalClipboardManager
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.platform.LocalDensity
|
||||
|
@ -80,6 +94,7 @@ import androidx.compose.ui.res.dimensionResource
|
|||
import androidx.compose.ui.res.stringResource
|
||||
import androidx.compose.ui.text.AnnotatedString
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
import androidx.compose.ui.unit.IntSize
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.window.Dialog
|
||||
import com.google.aiedge.gallery.R
|
||||
|
@ -102,7 +117,7 @@ enum class ChatInputType {
|
|||
/**
|
||||
* Composable function for the main chat panel, displaying messages and handling user input.
|
||||
*/
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@OptIn(ExperimentalMaterial3Api::class, ExperimentalSharedTransitionApi::class)
|
||||
@Composable
|
||||
fun ChatPanel(
|
||||
modelManagerViewModel: ModelManagerViewModel,
|
||||
|
@ -127,6 +142,7 @@ fun ChatPanel(
|
|||
val snackbarHostState = remember { SnackbarHostState() }
|
||||
val scope = rememberCoroutineScope()
|
||||
val haptic = LocalHapticFeedback.current
|
||||
var selectedImageMessage by remember { mutableStateOf<ChatMessageImage?>(null) }
|
||||
|
||||
var curMessage by remember { mutableStateOf("") } // Correct state
|
||||
val focusManager = LocalFocusManager.current
|
||||
|
@ -200,13 +216,14 @@ fun ChatPanel(
|
|||
}
|
||||
}
|
||||
|
||||
val modelInitializationStatus =
|
||||
modelManagerUiState.modelInitializationStatus[selectedModel.name]
|
||||
val modelInitializationStatus = modelManagerUiState.modelInitializationStatus[selectedModel.name]
|
||||
|
||||
LaunchedEffect(modelInitializationStatus) {
|
||||
showErrorDialog = modelInitializationStatus?.status == ModelInitializationStatusType.ERROR
|
||||
}
|
||||
|
||||
SharedTransitionLayout(modifier = Modifier.fillMaxSize()) {
|
||||
AnimatedContent(targetState = selectedImageMessage) { targetSelectedImageMessage ->
|
||||
Column(
|
||||
modifier = modifier.imePadding()
|
||||
) {
|
||||
|
@ -297,8 +314,7 @@ fun ChatPanel(
|
|||
)
|
||||
.background(backgroundColor)
|
||||
if (message is ChatMessageText) {
|
||||
messageBubbleModifier = messageBubbleModifier
|
||||
.pointerInput(Unit) {
|
||||
messageBubbleModifier = messageBubbleModifier.pointerInput(Unit) {
|
||||
detectTapGestures(
|
||||
onLongPress = {
|
||||
haptic.performHapticFeedback(HapticFeedbackType.LongPress)
|
||||
|
@ -316,7 +332,21 @@ fun ChatPanel(
|
|||
is ChatMessageText -> MessageBodyText(message = message)
|
||||
|
||||
// Image
|
||||
is ChatMessageImage -> MessageBodyImage(message = message)
|
||||
is ChatMessageImage -> {
|
||||
if (targetSelectedImageMessage != message) {
|
||||
MessageBodyImage(
|
||||
message = message,
|
||||
modifier = Modifier
|
||||
.clickable {
|
||||
selectedImageMessage = message
|
||||
}
|
||||
.sharedElement(
|
||||
sharedContentState = rememberSharedContentState(key = "selected_image"),
|
||||
animatedVisibilityScope = this@AnimatedContent
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Image with history (for image gen)
|
||||
is ChatMessageImageWithHistory -> MessageBodyImageWithHistory(
|
||||
|
@ -325,8 +355,9 @@ fun ChatPanel(
|
|||
|
||||
// Classification result
|
||||
is ChatMessageClassification -> MessageBodyClassification(
|
||||
message = message,
|
||||
modifier = Modifier.width(message.maxBarWidth ?: CLASSIFICATION_BAR_MAX_WIDTH)
|
||||
message = message, modifier = Modifier.width(
|
||||
message.maxBarWidth ?: CLASSIFICATION_BAR_MAX_WIDTH
|
||||
)
|
||||
)
|
||||
|
||||
// Benchmark result.
|
||||
|
@ -334,8 +365,7 @@ fun ChatPanel(
|
|||
|
||||
// Benchmark LLM result.
|
||||
is ChatMessageBenchmarkLlmResult -> MessageBodyBenchmarkLlm(
|
||||
message = message,
|
||||
modifier = Modifier.wrapContentWidth()
|
||||
message = message, modifier = Modifier.wrapContentWidth()
|
||||
)
|
||||
|
||||
else -> {}
|
||||
|
@ -348,7 +378,7 @@ fun ChatPanel(
|
|||
) {
|
||||
LatencyText(message = message)
|
||||
// A button to show stats for the LLM message.
|
||||
if (task.type == TaskType.LLM_CHAT && message is ChatMessageText
|
||||
if ((task.type == TaskType.LLM_CHAT || task.type == TaskType.LLM_ASK_IMAGE) && message is ChatMessageText
|
||||
// This means we only want to show the action button when the message is done
|
||||
// generating, at which point the latency will be set.
|
||||
&& message.latencyMs >= 0
|
||||
|
@ -363,7 +393,10 @@ fun ChatPanel(
|
|||
viewModel.toggleShowingStats(selectedModel, message)
|
||||
|
||||
// Add the stats message after the LLM message.
|
||||
if (viewModel.isShowingStats(model = selectedModel, message = message)) {
|
||||
if (viewModel.isShowingStats(
|
||||
model = selectedModel, message = message
|
||||
)
|
||||
) {
|
||||
val llmBenchmarkResult = message.llmBenchmarkResult
|
||||
if (llmBenchmarkResult != null) {
|
||||
viewModel.insertMessageAfter(
|
||||
|
@ -376,10 +409,12 @@ fun ChatPanel(
|
|||
// Remove the stats message.
|
||||
else {
|
||||
val curMessageIndex =
|
||||
viewModel.getMessageIndex(model = selectedModel, message = message)
|
||||
viewModel.removeMessageAt(
|
||||
viewModel.getMessageIndex(
|
||||
model = selectedModel,
|
||||
index = curMessageIndex + 1
|
||||
message = message
|
||||
)
|
||||
viewModel.removeMessageAt(
|
||||
model = selectedModel, index = curMessageIndex + 1
|
||||
)
|
||||
}
|
||||
},
|
||||
|
@ -425,8 +460,23 @@ fun ChatPanel(
|
|||
}
|
||||
|
||||
SnackbarHost(hostState = snackbarHostState, modifier = Modifier.padding(vertical = 4.dp))
|
||||
}
|
||||
|
||||
// Show an info message for ask image task to get users started.
|
||||
if (task.type == TaskType.LLM_ASK_IMAGE && messages.isEmpty()) {
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.padding(horizontal = 16.dp)
|
||||
.fillMaxSize(),
|
||||
horizontalAlignment = Alignment.CenterHorizontally,
|
||||
verticalArrangement = Arrangement.Center
|
||||
) {
|
||||
MessageBodyInfo(
|
||||
ChatMessageInfo(content = "To get started, click + below to add an image and type a prompt to ask a question about it."),
|
||||
smallFontSize = false
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Chat input
|
||||
when (chatInputType) {
|
||||
|
@ -439,6 +489,7 @@ fun ChatPanel(
|
|||
curMessage = curMessage,
|
||||
inProgress = uiState.inProgress,
|
||||
isResettingSession = uiState.isResettingSession,
|
||||
modelPreparing = uiState.preparing,
|
||||
hasImageMessage = hasImageMessage,
|
||||
modelInitializing = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
|
||||
textFieldPlaceHolderRes = task.textInputPlaceHolderRes,
|
||||
|
@ -459,7 +510,7 @@ fun ChatPanel(
|
|||
onStopButtonClicked = onStopButtonClicked,
|
||||
// showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen,
|
||||
showPromptTemplatesInMenu = false,
|
||||
showImagePickerInMenu = selectedModel.llmSupportImage == true,
|
||||
showImagePickerInMenu = selectedModel.llmSupportImage,
|
||||
showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress,
|
||||
)
|
||||
}
|
||||
|
@ -488,6 +539,58 @@ fun ChatPanel(
|
|||
}
|
||||
}
|
||||
|
||||
// A full-screen image viewer.
|
||||
if (targetSelectedImageMessage != null) {
|
||||
ZoomableBox(
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.background(Color.Black.copy(alpha = 0.9f))
|
||||
.sharedElement(
|
||||
rememberSharedContentState(key = "bounds"),
|
||||
animatedVisibilityScope = this,
|
||||
)
|
||||
.skipToLookaheadSize(),
|
||||
) {
|
||||
// Image.
|
||||
Image(
|
||||
bitmap = targetSelectedImageMessage.imageBitMap,
|
||||
contentDescription = "",
|
||||
modifier = modifier
|
||||
.fillMaxSize()
|
||||
.graphicsLayer(
|
||||
scaleX = scale,
|
||||
scaleY = scale,
|
||||
translationX = offsetX,
|
||||
translationY = offsetY
|
||||
)
|
||||
.sharedElement(
|
||||
sharedContentState = rememberSharedContentState(key = "selected_image"),
|
||||
animatedVisibilityScope = this@AnimatedContent
|
||||
),
|
||||
contentScale = ContentScale.Fit,
|
||||
)
|
||||
|
||||
// Close button.
|
||||
IconButton(
|
||||
onClick = {
|
||||
selectedImageMessage = null
|
||||
},
|
||||
colors = IconButtonDefaults.iconButtonColors(
|
||||
containerColor = MaterialTheme.colorScheme.surfaceVariant,
|
||||
),
|
||||
modifier = Modifier.offset(x = (-8).dp, y = 8.dp)
|
||||
) {
|
||||
Icon(
|
||||
Icons.Rounded.Close,
|
||||
contentDescription = "",
|
||||
tint = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Error dialog.
|
||||
if (showErrorDialog) {
|
||||
Dialog(
|
||||
|
@ -498,9 +601,7 @@ fun ChatPanel(
|
|||
) {
|
||||
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.padding(20.dp),
|
||||
verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||
) {
|
||||
// Title
|
||||
Text(
|
||||
|
@ -568,13 +669,10 @@ fun ChatPanel(
|
|||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(6.dp),
|
||||
modifier = Modifier
|
||||
.padding(vertical = 8.dp, horizontal = 16.dp)
|
||||
modifier = Modifier.padding(vertical = 8.dp, horizontal = 16.dp)
|
||||
) {
|
||||
Icon(
|
||||
Icons.Rounded.ContentCopy,
|
||||
contentDescription = "",
|
||||
modifier = Modifier.size(18.dp)
|
||||
Icons.Rounded.ContentCopy, contentDescription = "", modifier = Modifier.size(18.dp)
|
||||
)
|
||||
Text("Copy text")
|
||||
}
|
||||
|
@ -586,6 +684,51 @@ fun ChatPanel(
|
|||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
fun ZoomableBox(
|
||||
modifier: Modifier = Modifier,
|
||||
minScale: Float = 1f,
|
||||
maxScale: Float = 5f,
|
||||
content: @Composable ZoomableBoxScope.() -> Unit
|
||||
) {
|
||||
var scale by remember { mutableFloatStateOf(1f) }
|
||||
var offsetX by remember { mutableFloatStateOf(0f) }
|
||||
var offsetY by remember { mutableFloatStateOf(0f) }
|
||||
var size by remember { mutableStateOf(IntSize.Zero) }
|
||||
Box(
|
||||
modifier = modifier
|
||||
.clip(RectangleShape)
|
||||
.onSizeChanged { size = it }
|
||||
.pointerInput(Unit) {
|
||||
detectTransformGestures { _, pan, zoom, _ ->
|
||||
scale = maxOf(minScale, minOf(scale * zoom, maxScale))
|
||||
val maxX = (size.width * (scale - 1)) / 2
|
||||
val minX = -maxX
|
||||
offsetX = maxOf(minX, minOf(maxX, offsetX + pan.x))
|
||||
val maxY = (size.height * (scale - 1)) / 2
|
||||
val minY = -maxY
|
||||
offsetY = maxOf(minY, minOf(maxY, offsetY + pan.y))
|
||||
}
|
||||
},
|
||||
contentAlignment = Alignment.TopEnd
|
||||
) {
|
||||
val scope = ZoomableBoxScopeImpl(scale, offsetX, offsetY)
|
||||
scope.content()
|
||||
}
|
||||
}
|
||||
|
||||
interface ZoomableBoxScope {
|
||||
val scale: Float
|
||||
val offsetX: Float
|
||||
val offsetY: Float
|
||||
}
|
||||
|
||||
private data class ZoomableBoxScopeImpl(
|
||||
override val scale: Float,
|
||||
override val offsetX: Float,
|
||||
override val offsetY: Float
|
||||
) : ZoomableBoxScope
|
||||
|
||||
@Preview(showBackground = true)
|
||||
@Composable
|
||||
fun ChatPanelPreview() {
|
||||
|
|
|
@ -19,6 +19,7 @@ package com.google.aiedge.gallery.ui.common.chat
|
|||
import android.util.Log
|
||||
import androidx.activity.compose.BackHandler
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.fillMaxSize
|
||||
|
@ -27,6 +28,7 @@ import androidx.compose.foundation.pager.HorizontalPager
|
|||
import androidx.compose.foundation.pager.rememberPagerState
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.Scaffold
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.collectAsState
|
||||
|
@ -36,10 +38,13 @@ import androidx.compose.runtime.remember
|
|||
import androidx.compose.runtime.rememberCoroutineScope
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.runtime.snapshotFlow
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.graphics.graphicsLayer
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
import androidx.compose.ui.unit.dp
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
|
@ -81,7 +86,7 @@ fun ChatView(
|
|||
chatInputType: ChatInputType = ChatInputType.TEXT,
|
||||
showStopButtonInInputWhenInProgress: Boolean = false,
|
||||
) {
|
||||
val uiStat by viewModel.uiState.collectAsState()
|
||||
val uiState by viewModel.uiState.collectAsState()
|
||||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||
val selectedModel = modelManagerUiState.selectedModel
|
||||
|
||||
|
@ -158,7 +163,9 @@ fun ChatView(
|
|||
model = selectedModel,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
showResetSessionButton = true,
|
||||
isResettingSession = uiStat.isResettingSession,
|
||||
isResettingSession = uiState.isResettingSession,
|
||||
inProgress = uiState.inProgress,
|
||||
modelPreparing = uiState.preparing,
|
||||
onResetSessionClicked = onResetSessionClicked,
|
||||
onConfigChanged = { old, new ->
|
||||
viewModel.addConfigChangedMessage(
|
||||
|
|
|
@ -38,6 +38,11 @@ data class ChatUiState(
|
|||
*/
|
||||
val isResettingSession: Boolean = false,
|
||||
|
||||
/**
|
||||
* Indicates whether the model is preparing (before outputting any result and after initializing).
|
||||
*/
|
||||
val preparing: Boolean = false,
|
||||
|
||||
/**
|
||||
* A map of model names to lists of chat messages.
|
||||
*/
|
||||
|
@ -204,6 +209,10 @@ open class ChatViewModel(val task: Task) : ViewModel() {
|
|||
_uiState.update { _uiState.value.copy(isResettingSession = isResettingSession) }
|
||||
}
|
||||
|
||||
fun setPreparing(preparing: Boolean) {
|
||||
_uiState.update { _uiState.value.copy(preparing = preparing) }
|
||||
}
|
||||
|
||||
fun addConfigChangedMessage(
|
||||
oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>, model: Model
|
||||
) {
|
||||
|
|
|
@ -18,6 +18,8 @@ package com.google.aiedge.gallery.ui.common.chat
|
|||
|
||||
import android.util.Log
|
||||
import androidx.compose.foundation.border
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.interaction.MutableInteractionSource
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
|
@ -28,9 +30,11 @@ import androidx.compose.foundation.layout.height
|
|||
import androidx.compose.foundation.layout.offset
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.layout.width
|
||||
import androidx.compose.foundation.rememberScrollState
|
||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||
import androidx.compose.foundation.text.BasicTextField
|
||||
import androidx.compose.foundation.text.KeyboardOptions
|
||||
import androidx.compose.foundation.verticalScroll
|
||||
import androidx.compose.material3.Button
|
||||
import androidx.compose.material3.Card
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
|
@ -53,6 +57,9 @@ import androidx.compose.ui.Modifier
|
|||
import androidx.compose.ui.focus.FocusRequester
|
||||
import androidx.compose.ui.focus.focusRequester
|
||||
import androidx.compose.ui.focus.onFocusChanged
|
||||
import androidx.compose.ui.graphics.SolidColor
|
||||
import androidx.compose.ui.platform.LocalFocusManager
|
||||
import androidx.compose.ui.text.TextStyle
|
||||
import androidx.compose.ui.text.input.KeyboardType
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
import androidx.compose.ui.unit.dp
|
||||
|
@ -89,9 +96,20 @@ fun ConfigDialog(
|
|||
putAll(initialValues)
|
||||
}
|
||||
}
|
||||
val interactionSource = remember { MutableInteractionSource() }
|
||||
|
||||
Dialog(onDismissRequest = onDismissed) {
|
||||
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
|
||||
val focusManager = LocalFocusManager.current
|
||||
Card(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.clickable(
|
||||
interactionSource = interactionSource, indication = null // Disable the ripple effect
|
||||
) {
|
||||
focusManager.clearFocus()
|
||||
},
|
||||
shape = RoundedCornerShape(16.dp)
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||
) {
|
||||
|
@ -114,7 +132,14 @@ fun ConfigDialog(
|
|||
}
|
||||
|
||||
// List of config rows.
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.verticalScroll(rememberScrollState())
|
||||
.weight(1f, fill = false),
|
||||
verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||
) {
|
||||
ConfigEditorsPanel(configs = configs, values = values)
|
||||
}
|
||||
|
||||
// Button row.
|
||||
Row(
|
||||
|
@ -264,6 +289,8 @@ fun NumberSliderRow(config: NumberSliderConfig, values: SnapshotStateMap<String,
|
|||
values[config.key.label] = NaN
|
||||
}
|
||||
},
|
||||
textStyle = TextStyle(color = MaterialTheme.colorScheme.onSurface),
|
||||
cursorBrush = SolidColor(MaterialTheme.colorScheme.onSurface),
|
||||
) { innerTextField ->
|
||||
Box(
|
||||
modifier = Modifier.border(
|
||||
|
|
|
@ -25,17 +25,18 @@ import androidx.compose.ui.layout.ContentScale
|
|||
import androidx.compose.ui.unit.dp
|
||||
|
||||
@Composable
|
||||
fun MessageBodyImage(message: ChatMessageImage) {
|
||||
fun MessageBodyImage(message: ChatMessageImage, modifier: Modifier = Modifier) {
|
||||
val bitmapWidth = message.bitmap.width
|
||||
val bitmapHeight = message.bitmap.height
|
||||
val imageWidth =
|
||||
if (bitmapWidth >= bitmapHeight) 200 else (200f / bitmapHeight * bitmapWidth).toInt()
|
||||
val imageHeight =
|
||||
if (bitmapHeight >= bitmapWidth) 200 else (200f / bitmapWidth * bitmapHeight).toInt()
|
||||
|
||||
Image(
|
||||
bitmap = message.imageBitMap,
|
||||
contentDescription = "",
|
||||
modifier = Modifier
|
||||
modifier = modifier
|
||||
.height(imageHeight.dp)
|
||||
.width(imageWidth.dp),
|
||||
contentScale = ContentScale.Fit,
|
||||
|
|
|
@ -38,7 +38,7 @@ import com.google.aiedge.gallery.ui.theme.customColors
|
|||
* Supports markdown.
|
||||
*/
|
||||
@Composable
|
||||
fun MessageBodyInfo(message: ChatMessageInfo) {
|
||||
fun MessageBodyInfo(message: ChatMessageInfo, smallFontSize: Boolean = true) {
|
||||
Row(
|
||||
modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center
|
||||
) {
|
||||
|
@ -47,7 +47,11 @@ fun MessageBodyInfo(message: ChatMessageInfo) {
|
|||
.clip(RoundedCornerShape(16.dp))
|
||||
.background(MaterialTheme.customColors.agentBubbleBgColor)
|
||||
) {
|
||||
MarkdownText(text = message.content, modifier = Modifier.padding(12.dp), smallFontSize = true)
|
||||
MarkdownText(
|
||||
text = message.content,
|
||||
modifier = Modifier.padding(12.dp),
|
||||
smallFontSize = smallFontSize
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -103,6 +103,7 @@ fun MessageInputText(
|
|||
@StringRes textFieldPlaceHolderRes: Int,
|
||||
onValueChanged: (String) -> Unit,
|
||||
onSendMessage: (List<ChatMessage>) -> Unit,
|
||||
modelPreparing: Boolean = false,
|
||||
onOpenPromptTemplatesClicked: () -> Unit = {},
|
||||
onStopButtonClicked: () -> Unit = {},
|
||||
showPromptTemplatesInMenu: Boolean = false,
|
||||
|
@ -173,7 +174,7 @@ fun MessageInputText(
|
|||
.height(80.dp)
|
||||
.shadow(2.dp, shape = RoundedCornerShape(8.dp))
|
||||
.clip(RoundedCornerShape(8.dp))
|
||||
.border(1.dp, MaterialTheme.colorScheme.outlineVariant, RoundedCornerShape(8.dp)),
|
||||
.border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp)),
|
||||
)
|
||||
Box(modifier = Modifier
|
||||
.offset(x = 10.dp, y = (-10).dp)
|
||||
|
@ -219,7 +220,7 @@ fun MessageInputText(
|
|||
expanded = showAddContentMenu,
|
||||
onDismissRequest = { showAddContentMenu = false }) {
|
||||
if (showImagePickerInMenu) {
|
||||
// Take a photo.
|
||||
// Take a picture.
|
||||
DropdownMenuItem(
|
||||
text = {
|
||||
Row(
|
||||
|
@ -227,7 +228,7 @@ fun MessageInputText(
|
|||
horizontalArrangement = Arrangement.spacedBy(6.dp)
|
||||
) {
|
||||
Icon(Icons.Rounded.PhotoCamera, contentDescription = "")
|
||||
Text("Take a photo")
|
||||
Text("Take a picture")
|
||||
}
|
||||
},
|
||||
enabled = pickedImages.isEmpty() && !hasImageMessage,
|
||||
|
@ -321,7 +322,7 @@ fun MessageInputText(
|
|||
Spacer(modifier = Modifier.width(8.dp))
|
||||
|
||||
if (inProgress && showStopButtonWhenInProgress) {
|
||||
if (!modelInitializing) {
|
||||
if (!modelInitializing && !modelPreparing) {
|
||||
IconButton(
|
||||
onClick = onStopButtonClicked,
|
||||
colors = IconButtonDefaults.iconButtonColors(
|
||||
|
|
|
@ -20,8 +20,9 @@ import android.content.Intent
|
|||
import android.net.Uri
|
||||
import androidx.activity.compose.rememberLauncherForActivityResult
|
||||
import androidx.activity.result.contract.ActivityResultContracts
|
||||
import androidx.compose.animation.core.animateFloatAsState
|
||||
import androidx.compose.animation.core.tween
|
||||
import androidx.compose.animation.AnimatedContent
|
||||
import androidx.compose.animation.ExperimentalSharedTransitionApi
|
||||
import androidx.compose.animation.SharedTransitionLayout
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.interaction.MutableInteractionSource
|
||||
|
@ -30,17 +31,13 @@ import androidx.compose.foundation.layout.Box
|
|||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.heightIn
|
||||
import androidx.compose.foundation.layout.offset
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.rounded.ChevronRight
|
||||
import androidx.compose.material.icons.rounded.Settings
|
||||
import androidx.compose.material.icons.rounded.UnfoldLess
|
||||
import androidx.compose.material.icons.rounded.UnfoldMore
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.IconButton
|
||||
import androidx.compose.material3.OutlinedButton
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.material3.ripple
|
||||
|
@ -48,16 +45,12 @@ import androidx.compose.runtime.Composable
|
|||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.derivedStateOf
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.movableContentOf
|
||||
import androidx.compose.runtime.movableContentWithReceiverOf
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.alpha
|
||||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.layout.LookaheadScope
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
import androidx.compose.ui.unit.Dp
|
||||
|
@ -91,6 +84,7 @@ private val DEFAULT_VERTICAL_PADDING = 16.dp
|
|||
* model description and buttons for learning more (opening a URL) and downloading/trying
|
||||
* the model.
|
||||
*/
|
||||
@OptIn(ExperimentalSharedTransitionApi::class)
|
||||
@Composable
|
||||
fun ModelItem(
|
||||
model: Model,
|
||||
|
@ -117,37 +111,57 @@ fun ModelItem(
|
|||
|
||||
var isExpanded by remember { mutableStateOf(false) }
|
||||
|
||||
// Animate alpha for model description and button rows when switching between layouts.
|
||||
val alphaAnimation by animateFloatAsState(
|
||||
targetValue = if (isExpanded) 1f else 0f,
|
||||
animationSpec = tween(durationMillis = LAYOUT_ANIMATION_DURATION - 50)
|
||||
var boxModifier = modifier
|
||||
.fillMaxWidth()
|
||||
.clip(RoundedCornerShape(size = 42.dp))
|
||||
.background(
|
||||
getTaskBgColor(task)
|
||||
)
|
||||
boxModifier = if (canExpand) {
|
||||
boxModifier.clickable(onClick = {
|
||||
if (!model.imported) {
|
||||
isExpanded = !isExpanded
|
||||
} else {
|
||||
onModelClicked(model)
|
||||
}
|
||||
}, interactionSource = remember { MutableInteractionSource() }, indication = ripple(
|
||||
bounded = true,
|
||||
radius = 1000.dp,
|
||||
)
|
||||
)
|
||||
} else {
|
||||
boxModifier
|
||||
}
|
||||
|
||||
LookaheadScope {
|
||||
// Task icon.
|
||||
val taskIcon = remember {
|
||||
movableContentOf {
|
||||
Box(
|
||||
modifier = boxModifier,
|
||||
contentAlignment = Alignment.Center,
|
||||
) {
|
||||
SharedTransitionLayout {
|
||||
AnimatedContent(
|
||||
isExpanded, label = "item_layout_transition",
|
||||
) { targetState ->
|
||||
val taskIcon = @Composable {
|
||||
TaskIcon(
|
||||
task = task, modifier = Modifier.animateLayout()
|
||||
task = task, modifier = Modifier.sharedElement(
|
||||
sharedContentState = rememberSharedContentState(key = "task_icon"),
|
||||
animatedVisibilityScope = this@AnimatedContent,
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Model name and status.
|
||||
val modelNameAndStatus = remember {
|
||||
movableContentOf {
|
||||
val modelNameAndStatus = @Composable {
|
||||
ModelNameAndStatus(
|
||||
model = model,
|
||||
task = task,
|
||||
downloadStatus = downloadStatus,
|
||||
isExpanded = isExpanded,
|
||||
modifier = Modifier.animateLayout()
|
||||
animatedVisibilityScope = this@AnimatedContent,
|
||||
sharedTransitionScope = this@SharedTransitionLayout
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
val actionButton = remember {
|
||||
movableContentOf {
|
||||
val actionButton = @Composable {
|
||||
ModelItemActionButton(
|
||||
context = context,
|
||||
model = model,
|
||||
|
@ -165,61 +179,48 @@ fun ModelItem(
|
|||
},
|
||||
showDeleteButton = showDeleteButton,
|
||||
showDownloadButton = false,
|
||||
modifier = Modifier.sharedElement(
|
||||
sharedContentState = rememberSharedContentState(key = "action_button"),
|
||||
animatedVisibilityScope = this@AnimatedContent,
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Expand/collapse icon, or the config icon.
|
||||
val expandButton = remember {
|
||||
movableContentOf {
|
||||
if (showConfigButtonIfExisted) {
|
||||
if (downloadStatus?.status === ModelDownloadStatusType.SUCCEEDED) {
|
||||
if (model.configs.isNotEmpty()) {
|
||||
IconButton(onClick = onConfigClicked) {
|
||||
Icon(
|
||||
Icons.Rounded.Settings,
|
||||
contentDescription = "",
|
||||
tint = getTaskIconColor(task)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
val expandButton = @Composable {
|
||||
Icon(
|
||||
// For imported model, show ">" directly indicating users can just tap the model item to
|
||||
// go into it without needing to expand it first.
|
||||
if (model.imported) Icons.Rounded.ChevronRight else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore,
|
||||
contentDescription = "",
|
||||
tint = getTaskIconColor(task),
|
||||
modifier = Modifier.sharedElement(
|
||||
sharedContentState = rememberSharedContentState(key = "expand_button"),
|
||||
animatedVisibilityScope = this@AnimatedContent,
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Model description shown in expanded layout.
|
||||
val modelDescription = remember {
|
||||
movableContentOf { m: Modifier ->
|
||||
val description = @Composable {
|
||||
if (model.info.isNotEmpty()) {
|
||||
MarkdownText(
|
||||
model.info,
|
||||
modifier = Modifier
|
||||
.heightIn(min = 0.dp, max = if (isExpanded) 1000.dp else 0.dp)
|
||||
.animateLayout()
|
||||
.then(m)
|
||||
model.info, modifier = Modifier
|
||||
.sharedElement(
|
||||
sharedContentState = rememberSharedContentState(key = "description"),
|
||||
animatedVisibilityScope = this@AnimatedContent,
|
||||
)
|
||||
.skipToLookaheadSize()
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Button rows shown in expanded layout.
|
||||
val buttonRows = remember {
|
||||
movableContentOf { m: Modifier ->
|
||||
val buttonsRow = @Composable {
|
||||
Row(
|
||||
modifier = Modifier
|
||||
.heightIn(min = 0.dp, max = if (isExpanded) 1000.dp else 0.dp)
|
||||
.animateLayout()
|
||||
.then(m),
|
||||
horizontalArrangement = Arrangement.spacedBy(12.dp),
|
||||
horizontalArrangement = Arrangement.spacedBy(12.dp), modifier = Modifier
|
||||
.sharedElement(
|
||||
sharedContentState = rememberSharedContentState(key = "buttons_row"),
|
||||
animatedVisibilityScope = this@AnimatedContent,
|
||||
)
|
||||
.skipToLookaheadSize()
|
||||
) {
|
||||
// The "learn more" button. Click to show related urls in a bottom sheet.
|
||||
if (model.learnMoreUrl.isNotEmpty()) {
|
||||
|
@ -238,60 +239,42 @@ fun ModelItem(
|
|||
// Button to start the download and start the chat session with the model.
|
||||
val needToDownloadFirst =
|
||||
downloadStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED || downloadStatus?.status == ModelDownloadStatusType.FAILED
|
||||
DownloadAndTryButton(
|
||||
task = task,
|
||||
DownloadAndTryButton(task = task,
|
||||
model = model,
|
||||
enabled = isExpanded,
|
||||
needToDownloadFirst = needToDownloadFirst,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
onClicked = { onModelClicked(model) }
|
||||
)
|
||||
}
|
||||
onClicked = { onModelClicked(model) })
|
||||
}
|
||||
}
|
||||
|
||||
val container = remember {
|
||||
movableContentWithReceiverOf<LookaheadScope, @Composable () -> Unit> { content ->
|
||||
Box(
|
||||
modifier = Modifier.animateLayout(),
|
||||
contentAlignment = Alignment.TopEnd,
|
||||
// Collapsed state.
|
||||
if (!targetState) {
|
||||
Column(
|
||||
horizontalAlignment = Alignment.CenterHorizontally,
|
||||
) {
|
||||
content()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var boxModifier = modifier
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(12.dp),
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.clip(RoundedCornerShape(size = 42.dp))
|
||||
.background(
|
||||
getTaskBgColor(task)
|
||||
)
|
||||
boxModifier = if (canExpand) {
|
||||
boxModifier.clickable(
|
||||
onClick = {
|
||||
if (!model.imported) {
|
||||
isExpanded = !isExpanded
|
||||
} else {
|
||||
onModelClicked(model)
|
||||
}
|
||||
},
|
||||
interactionSource = remember { MutableInteractionSource() },
|
||||
indication = ripple(
|
||||
bounded = true,
|
||||
radius = 500.dp,
|
||||
)
|
||||
)
|
||||
} else {
|
||||
boxModifier
|
||||
}
|
||||
Box(
|
||||
modifier = boxModifier,
|
||||
contentAlignment = Alignment.Center
|
||||
.padding(start = 18.dp, end = 18.dp)
|
||||
.padding(vertical = verticalSpacing)
|
||||
) {
|
||||
if (isExpanded) {
|
||||
container {
|
||||
// The main part (icon, model name, status, etc)
|
||||
// Icon at the left.
|
||||
taskIcon()
|
||||
// Model name and status at the center.
|
||||
Row(modifier = Modifier.weight(1f)) {
|
||||
modelNameAndStatus()
|
||||
}
|
||||
// Action button and expand/collapse button at the right.
|
||||
Row(verticalAlignment = Alignment.CenterVertically) {
|
||||
actionButton()
|
||||
expandButton()
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Column(
|
||||
verticalArrangement = Arrangement.spacedBy(14.dp),
|
||||
horizontalAlignment = Alignment.CenterHorizontally,
|
||||
|
@ -300,7 +283,9 @@ fun ModelItem(
|
|||
.padding(vertical = verticalSpacing, horizontal = 18.dp)
|
||||
) {
|
||||
Box(contentAlignment = Alignment.Center) {
|
||||
// Icon at the top-center.
|
||||
taskIcon()
|
||||
// Action button and expand/collapse button at the right.
|
||||
Row(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
|
@ -310,39 +295,12 @@ fun ModelItem(
|
|||
expandButton()
|
||||
}
|
||||
}
|
||||
// Name and status below the icon.
|
||||
modelNameAndStatus()
|
||||
modelDescription(Modifier.alpha(alphaAnimation))
|
||||
buttonRows(Modifier.alpha(alphaAnimation)) // Apply alpha here
|
||||
}
|
||||
}
|
||||
} else {
|
||||
container {
|
||||
Column(horizontalAlignment = Alignment.CenterHorizontally) {
|
||||
// The main part (icon, model name, status, etc)
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(12.dp),
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(start = 18.dp, end = 18.dp)
|
||||
.padding(vertical = verticalSpacing)
|
||||
) {
|
||||
taskIcon()
|
||||
Row(modifier = Modifier.weight(1f)) {
|
||||
modelNameAndStatus()
|
||||
}
|
||||
Row(verticalAlignment = Alignment.CenterVertically) {
|
||||
actionButton()
|
||||
expandButton()
|
||||
}
|
||||
}
|
||||
Column(
|
||||
modifier = Modifier.offset(y = 30.dp),
|
||||
horizontalAlignment = Alignment.CenterHorizontally
|
||||
) {
|
||||
modelDescription(Modifier.alpha(alphaAnimation))
|
||||
buttonRows(Modifier.alpha(alphaAnimation))
|
||||
}
|
||||
// Description.
|
||||
description()
|
||||
// Buttons
|
||||
buttonsRow()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -108,7 +108,8 @@ fun ModelItemActionButton(
|
|||
// Button to cancel the download when it is in progress.
|
||||
ModelDownloadStatusType.IN_PROGRESS, ModelDownloadStatusType.UNZIPPING -> IconButton(onClick = {
|
||||
modelManagerViewModel.cancelDownloadModel(
|
||||
model
|
||||
task = task,
|
||||
model = model
|
||||
)
|
||||
}) {
|
||||
Icon(
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
package com.google.aiedge.gallery.ui.common.modelitem
|
||||
|
||||
import androidx.compose.animation.AnimatedVisibilityScope
|
||||
import androidx.compose.animation.ExperimentalSharedTransitionApi
|
||||
import androidx.compose.animation.SharedTransitionScope
|
||||
import androidx.compose.animation.core.Animatable
|
||||
import androidx.compose.animation.core.tween
|
||||
import androidx.compose.foundation.layout.Column
|
||||
|
@ -30,6 +33,7 @@ import androidx.compose.runtime.LaunchedEffect
|
|||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.focus.focusModifier
|
||||
import androidx.compose.ui.text.style.TextAlign
|
||||
import androidx.compose.ui.text.style.TextOverflow
|
||||
import androidx.compose.ui.unit.dp
|
||||
|
@ -54,18 +58,22 @@ import com.google.aiedge.gallery.ui.theme.labelSmallNarrow
|
|||
* - "Unzipping..." status for unzipping processes.
|
||||
* - Model size for successful downloads.
|
||||
*/
|
||||
@OptIn(ExperimentalSharedTransitionApi::class)
|
||||
@Composable
|
||||
fun ModelNameAndStatus(
|
||||
model: Model,
|
||||
task: Task,
|
||||
downloadStatus: ModelDownloadStatus?,
|
||||
isExpanded: Boolean,
|
||||
sharedTransitionScope: SharedTransitionScope,
|
||||
animatedVisibilityScope: AnimatedVisibilityScope,
|
||||
modifier: Modifier = Modifier
|
||||
) {
|
||||
val inProgress = downloadStatus?.status == ModelDownloadStatusType.IN_PROGRESS
|
||||
val isPartiallyDownloaded = downloadStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED
|
||||
var curDownloadProgress = 0f
|
||||
|
||||
with(sharedTransitionScope) {
|
||||
Column(
|
||||
horizontalAlignment = if (isExpanded) Alignment.CenterHorizontally else Alignment.Start
|
||||
) {
|
||||
|
@ -78,7 +86,10 @@ fun ModelNameAndStatus(
|
|||
maxLines = 1,
|
||||
overflow = TextOverflow.MiddleEllipsis,
|
||||
style = MaterialTheme.typography.titleMedium,
|
||||
modifier = modifier,
|
||||
modifier = Modifier.sharedElement(
|
||||
rememberSharedContentState(key = "model_name"),
|
||||
animatedVisibilityScope = animatedVisibilityScope
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -87,7 +98,12 @@ fun ModelNameAndStatus(
|
|||
if (!inProgress && !isPartiallyDownloaded) {
|
||||
StatusIcon(
|
||||
downloadStatus = downloadStatus,
|
||||
modifier = modifier.padding(end = 4.dp)
|
||||
modifier = modifier
|
||||
.padding(end = 4.dp)
|
||||
.sharedElement(
|
||||
rememberSharedContentState(key = "download_status_icon"),
|
||||
animatedVisibilityScope = animatedVisibilityScope
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -99,7 +115,10 @@ fun ModelNameAndStatus(
|
|||
color = MaterialTheme.colorScheme.error,
|
||||
style = labelSmallNarrow,
|
||||
overflow = TextOverflow.Ellipsis,
|
||||
modifier = modifier,
|
||||
modifier = Modifier.sharedElement(
|
||||
rememberSharedContentState(key = "failure_messsage"),
|
||||
animatedVisibilityScope = animatedVisibilityScope
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -154,7 +173,12 @@ fun ModelNameAndStatus(
|
|||
style = labelSmallNarrow.copy(fontSize = fontSize, lineHeight = 10.sp),
|
||||
textAlign = if (isExpanded) TextAlign.Center else TextAlign.Start,
|
||||
overflow = TextOverflow.Visible,
|
||||
modifier = modifier.offset(y = if (index == 0) 0.dp else (-1).dp)
|
||||
modifier = Modifier
|
||||
.offset(y = if (index == 0) 0.dp else (-1).dp)
|
||||
.sharedElement(
|
||||
rememberSharedContentState(key = "status_label_${index}"),
|
||||
animatedVisibilityScope = animatedVisibilityScope
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -168,7 +192,12 @@ fun ModelNameAndStatus(
|
|||
progress = { animatedProgress.value },
|
||||
color = getTaskIconColor(task = task),
|
||||
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
|
||||
modifier = modifier.padding(top = 2.dp)
|
||||
modifier = Modifier
|
||||
.padding(top = 2.dp)
|
||||
.sharedElement(
|
||||
rememberSharedContentState(key = "download_progress_bar"),
|
||||
animatedVisibilityScope = animatedVisibilityScope
|
||||
)
|
||||
)
|
||||
LaunchedEffect(curDownloadProgress) {
|
||||
animatedProgress.animateTo(curDownloadProgress, animationSpec = tween(150))
|
||||
|
@ -180,8 +209,13 @@ fun ModelNameAndStatus(
|
|||
color = getTaskIconColor(task = task),
|
||||
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
|
||||
modifier = Modifier
|
||||
.padding(top = 2.dp),
|
||||
.padding(top = 2.dp)
|
||||
.sharedElement(
|
||||
rememberSharedContentState(key = "unzip_progress_bar"),
|
||||
animatedVisibilityScope = animatedVisibilityScope
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,8 +16,11 @@
|
|||
|
||||
package com.google.aiedge.gallery.ui.home
|
||||
|
||||
import android.app.Activity
|
||||
import android.content.Context
|
||||
import android.content.Intent
|
||||
import android.net.Uri
|
||||
import android.provider.OpenableColumns
|
||||
import android.util.Log
|
||||
import androidx.activity.compose.rememberLauncherForActivityResult
|
||||
import androidx.activity.result.ActivityResultLauncher
|
||||
|
@ -49,6 +52,7 @@ import androidx.compose.material.icons.automirrored.outlined.NoteAdd
|
|||
import androidx.compose.material.icons.filled.Add
|
||||
import androidx.compose.material.icons.rounded.Error
|
||||
import androidx.compose.material3.AlertDialog
|
||||
import androidx.compose.material3.Button
|
||||
import androidx.compose.material3.Card
|
||||
import androidx.compose.material3.CardDefaults
|
||||
import androidx.compose.material3.CircularProgressIndicator
|
||||
|
@ -80,12 +84,18 @@ import androidx.compose.ui.draw.clip
|
|||
import androidx.compose.ui.draw.scale
|
||||
import androidx.compose.ui.graphics.Brush
|
||||
import androidx.compose.ui.input.nestedscroll.nestedScroll
|
||||
import androidx.compose.ui.layout.layout
|
||||
import androidx.compose.ui.platform.LocalConfiguration
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.platform.LocalDensity
|
||||
import androidx.compose.ui.platform.LocalWindowInfo
|
||||
import androidx.compose.ui.res.stringResource
|
||||
import androidx.compose.ui.text.LinkAnnotation
|
||||
import androidx.compose.ui.text.SpanStyle
|
||||
import androidx.compose.ui.text.TextLinkStyles
|
||||
import androidx.compose.ui.text.buildAnnotatedString
|
||||
import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.compose.ui.text.style.TextAlign
|
||||
import androidx.compose.ui.text.style.TextDecoration
|
||||
import androidx.compose.ui.text.withLink
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.unit.sp
|
||||
|
@ -93,7 +103,6 @@ import com.google.aiedge.gallery.GalleryTopAppBar
|
|||
import com.google.aiedge.gallery.R
|
||||
import com.google.aiedge.gallery.data.AppBarAction
|
||||
import com.google.aiedge.gallery.data.AppBarActionType
|
||||
import com.google.aiedge.gallery.data.ConfigKey
|
||||
import com.google.aiedge.gallery.data.ImportedModelInfo
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.ui.common.TaskIcon
|
||||
|
@ -101,7 +110,6 @@ import com.google.aiedge.gallery.ui.common.getTaskBgColor
|
|||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||
import com.google.aiedge.gallery.ui.theme.ThemeSettings
|
||||
import com.google.aiedge.gallery.ui.theme.customColors
|
||||
import com.google.aiedge.gallery.ui.theme.titleMediumNarrow
|
||||
import kotlinx.coroutines.delay
|
||||
|
@ -109,6 +117,12 @@ import kotlinx.coroutines.launch
|
|||
|
||||
private const val TAG = "AGHomeScreen"
|
||||
private const val TASK_COUNT_ANIMATION_DURATION = 250
|
||||
private const val MAX_TASK_CARD_PADDING = 24
|
||||
private const val MIN_TASK_CARD_PADDING = 18
|
||||
private const val MAX_TASK_CARD_RADIUS = 43.5
|
||||
private const val MIN_TASK_CARD_RADIUS = 30
|
||||
private const val MAX_TASK_CARD_ICON_SIZE = 56
|
||||
private const val MIN_TASK_CARD_ICON_SIZE = 50
|
||||
|
||||
/** Navigation destination data */
|
||||
object HomeScreenDestination {
|
||||
|
@ -127,6 +141,7 @@ fun HomeScreen(
|
|||
val uiState by modelManagerViewModel.uiState.collectAsState()
|
||||
var showSettingsDialog by remember { mutableStateOf(false) }
|
||||
var showImportModelSheet by remember { mutableStateOf(false) }
|
||||
var showUnsupportedFileTypeDialog by remember { mutableStateOf(false) }
|
||||
val sheetState = rememberModalBottomSheetState()
|
||||
var showImportDialog by remember { mutableStateOf(false) }
|
||||
var showImportingDialog by remember { mutableStateOf(false) }
|
||||
|
@ -135,17 +150,21 @@ fun HomeScreen(
|
|||
val coroutineScope = rememberCoroutineScope()
|
||||
val snackbarHostState = remember { SnackbarHostState() }
|
||||
val scope = rememberCoroutineScope()
|
||||
|
||||
val nonEmptyTasks = uiState.tasks.filter { it.models.size > 0 }
|
||||
val loadingHfModels = uiState.loadingHfModels
|
||||
val context = LocalContext.current
|
||||
|
||||
val filePickerLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult(
|
||||
contract = ActivityResultContracts.StartActivityForResult()
|
||||
) { result ->
|
||||
if (result.resultCode == android.app.Activity.RESULT_OK) {
|
||||
result.data?.data?.let { uri ->
|
||||
val fileName = getFileName(context = context, uri = uri)
|
||||
Log.d(TAG, "Selected file: $fileName")
|
||||
if (fileName != null && !fileName.endsWith(".task")) {
|
||||
showUnsupportedFileTypeDialog = true
|
||||
} else {
|
||||
selectedLocalModelFileUri.value = uri
|
||||
showImportDialog = true
|
||||
}
|
||||
} ?: run {
|
||||
Log.d(TAG, "No file selected or URI is null.")
|
||||
}
|
||||
|
@ -154,21 +173,15 @@ fun HomeScreen(
|
|||
}
|
||||
}
|
||||
|
||||
Scaffold(
|
||||
modifier = modifier.nestedScroll(scrollBehavior.nestedScrollConnection),
|
||||
topBar = {
|
||||
Scaffold(modifier = modifier.nestedScroll(scrollBehavior.nestedScrollConnection), topBar = {
|
||||
GalleryTopAppBar(
|
||||
title = stringResource(HomeScreenDestination.titleRes),
|
||||
rightAction = AppBarAction(
|
||||
actionType = AppBarActionType.APP_SETTING, actionFn = {
|
||||
rightAction = AppBarAction(actionType = AppBarActionType.APP_SETTING, actionFn = {
|
||||
showSettingsDialog = true
|
||||
}
|
||||
),
|
||||
loadingHfModels = loadingHfModels,
|
||||
}),
|
||||
scrollBehavior = scrollBehavior,
|
||||
)
|
||||
},
|
||||
floatingActionButton = {
|
||||
}, floatingActionButton = {
|
||||
// A floating action button to show "import model" bottom sheet.
|
||||
SmallFloatingActionButton(
|
||||
onClick = {
|
||||
|
@ -179,11 +192,10 @@ fun HomeScreen(
|
|||
) {
|
||||
Icon(Icons.Filled.Add, "")
|
||||
}
|
||||
}
|
||||
) { innerPadding ->
|
||||
}) { innerPadding ->
|
||||
Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.fillMaxSize()) {
|
||||
TaskList(
|
||||
tasks = nonEmptyTasks,
|
||||
tasks = uiState.tasks,
|
||||
navigateToTaskScreen = navigateToTaskScreen,
|
||||
loadingModelAllowlist = uiState.loadingModelAllowlist,
|
||||
modifier = Modifier.fillMaxSize(),
|
||||
|
@ -198,16 +210,8 @@ fun HomeScreen(
|
|||
if (showSettingsDialog) {
|
||||
SettingsDialog(
|
||||
curThemeOverride = modelManagerViewModel.readThemeOverride(),
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
onDismissed = { showSettingsDialog = false },
|
||||
onOk = { curConfigValues ->
|
||||
// Update theme settings.
|
||||
// This will update app's theme.
|
||||
val themeOverride = curConfigValues[ConfigKey.THEME.label] as String
|
||||
ThemeSettings.themeOverride.value = themeOverride
|
||||
|
||||
// Save to data store.
|
||||
modelManagerViewModel.saveThemeOverride(themeOverride)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -232,10 +236,6 @@ fun HomeScreen(
|
|||
val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
|
||||
addCategory(Intent.CATEGORY_OPENABLE)
|
||||
type = "*/*"
|
||||
putExtra(
|
||||
Intent.EXTRA_MIME_TYPES,
|
||||
arrayOf("application/x-binary", "application/octet-stream")
|
||||
)
|
||||
// Single select.
|
||||
putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
|
||||
}
|
||||
|
@ -259,9 +259,7 @@ fun HomeScreen(
|
|||
// Import dialog
|
||||
if (showImportDialog) {
|
||||
selectedLocalModelFileUri.value?.let { uri ->
|
||||
ModelImportDialog(uri = uri,
|
||||
onDismiss = { showImportDialog = false },
|
||||
onDone = { info ->
|
||||
ModelImportDialog(uri = uri, onDismiss = { showImportDialog = false }, onDone = { info ->
|
||||
selectedImportedModelInfo.value = info
|
||||
showImportDialog = false
|
||||
showImportingDialog = true
|
||||
|
@ -273,8 +271,7 @@ fun HomeScreen(
|
|||
if (showImportingDialog) {
|
||||
selectedLocalModelFileUri.value?.let { uri ->
|
||||
selectedImportedModelInfo.value?.let { info ->
|
||||
ModelImportingDialog(
|
||||
uri = uri,
|
||||
ModelImportingDialog(uri = uri,
|
||||
info = info,
|
||||
onDismiss = { showImportingDialog = false },
|
||||
onDone = {
|
||||
|
@ -292,6 +289,22 @@ fun HomeScreen(
|
|||
}
|
||||
}
|
||||
|
||||
// Alert dialog for unsupported file type.
|
||||
if (showUnsupportedFileTypeDialog) {
|
||||
AlertDialog(
|
||||
onDismissRequest = { showUnsupportedFileTypeDialog = false },
|
||||
title = { Text("Unsupported file type") },
|
||||
text = {
|
||||
Text("Only \".task\" file type is supported.")
|
||||
},
|
||||
confirmButton = {
|
||||
Button(onClick = { showUnsupportedFileTypeDialog = false }) {
|
||||
Text(stringResource(R.string.ok))
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
if (uiState.loadingModelAllowlistError.isNotEmpty()) {
|
||||
AlertDialog(
|
||||
icon = {
|
||||
|
@ -307,11 +320,9 @@ fun HomeScreen(
|
|||
modelManagerViewModel.loadModelAllowlist()
|
||||
},
|
||||
confirmButton = {
|
||||
TextButton(
|
||||
onClick = {
|
||||
TextButton(onClick = {
|
||||
modelManagerViewModel.loadModelAllowlist()
|
||||
}
|
||||
) {
|
||||
}) {
|
||||
Text("Retry")
|
||||
}
|
||||
},
|
||||
|
@ -327,6 +338,38 @@ private fun TaskList(
|
|||
modifier: Modifier = Modifier,
|
||||
contentPadding: PaddingValues = PaddingValues(0.dp),
|
||||
) {
|
||||
val density = LocalDensity.current
|
||||
val windowInfo = LocalWindowInfo.current
|
||||
val screenWidthDp = remember {
|
||||
with(density) {
|
||||
windowInfo.containerSize.width.toDp()
|
||||
}
|
||||
}
|
||||
val screenHeightDp = remember {
|
||||
with(density) {
|
||||
windowInfo.containerSize.height.toDp()
|
||||
}
|
||||
}
|
||||
val sizeFraction = remember { ((screenWidthDp - 360.dp) / (410.dp - 360.dp)).coerceIn(0f, 1f) }
|
||||
val linkColor = MaterialTheme.customColors.linkColor
|
||||
|
||||
val introText = buildAnnotatedString {
|
||||
append("Welcome to AI Edge Gallery! Explore a world of \namazing on-device models from ")
|
||||
withLink(
|
||||
link = LinkAnnotation.Url(
|
||||
url = "https://huggingface.co/litert-community", // Replace with the actual URL
|
||||
styles = TextLinkStyles(
|
||||
style = SpanStyle(
|
||||
color = linkColor,
|
||||
textDecoration = TextDecoration.Underline,
|
||||
)
|
||||
)
|
||||
)
|
||||
) {
|
||||
append("LiteRT community")
|
||||
}
|
||||
}
|
||||
|
||||
Box(modifier = modifier.fillMaxSize()) {
|
||||
LazyVerticalGrid(
|
||||
columns = GridCells.Fixed(count = 2),
|
||||
|
@ -335,10 +378,15 @@ private fun TaskList(
|
|||
horizontalArrangement = Arrangement.spacedBy(8.dp),
|
||||
verticalArrangement = Arrangement.spacedBy(8.dp),
|
||||
) {
|
||||
// New rel
|
||||
item(key = "newReleaseNotification", span = { GridItemSpan(2) }) {
|
||||
NewReleaseNotification()
|
||||
}
|
||||
|
||||
// Headline.
|
||||
item(key = "headline", span = { GridItemSpan(2) }) {
|
||||
Text(
|
||||
"Welcome to AI Edge Gallery! Explore a world of \namazing on-device models from LiteRT community",
|
||||
introText,
|
||||
textAlign = TextAlign.Center,
|
||||
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
|
||||
modifier = Modifier.padding(bottom = 20.dp)
|
||||
|
@ -364,14 +412,21 @@ private fun TaskList(
|
|||
}
|
||||
}
|
||||
} else {
|
||||
// Cards.
|
||||
// LLM Cards.
|
||||
item(key = "llmCardsHeader", span = { GridItemSpan(2) }) {
|
||||
Text(
|
||||
"Example LLM Use Cases",
|
||||
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.SemiBold),
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant,
|
||||
modifier = Modifier.padding(bottom = 4.dp)
|
||||
)
|
||||
}
|
||||
|
||||
items(tasks) { task ->
|
||||
TaskCard(
|
||||
task = task,
|
||||
onClick = {
|
||||
sizeFraction = sizeFraction, task = task, onClick = {
|
||||
navigateToTaskScreen(task)
|
||||
},
|
||||
modifier = Modifier
|
||||
}, modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.aspectRatio(1f)
|
||||
)
|
||||
|
@ -388,7 +443,7 @@ private fun TaskList(
|
|||
Box(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.height(LocalConfiguration.current.screenHeightDp.dp * 0.25f)
|
||||
.height(screenHeightDp * 0.25f)
|
||||
.background(
|
||||
Brush.verticalGradient(
|
||||
colors = MaterialTheme.customColors.homeBottomGradient,
|
||||
|
@ -400,7 +455,15 @@ private fun TaskList(
|
|||
}
|
||||
|
||||
@Composable
|
||||
private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modifier) {
|
||||
private fun TaskCard(
|
||||
task: Task, onClick: () -> Unit, sizeFraction: Float, modifier: Modifier = Modifier
|
||||
) {
|
||||
val padding =
|
||||
(MAX_TASK_CARD_PADDING - MIN_TASK_CARD_PADDING) * sizeFraction + MIN_TASK_CARD_PADDING
|
||||
val radius = (MAX_TASK_CARD_RADIUS - MIN_TASK_CARD_RADIUS) * sizeFraction + MIN_TASK_CARD_RADIUS
|
||||
val iconSize =
|
||||
(MAX_TASK_CARD_ICON_SIZE - MIN_TASK_CARD_ICON_SIZE) * sizeFraction + MIN_TASK_CARD_ICON_SIZE
|
||||
|
||||
// Observes the model count and updates the model count label with a fade-in/fade-out animation
|
||||
// whenever the count changes.
|
||||
val modelCount by remember {
|
||||
|
@ -445,7 +508,7 @@ private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modif
|
|||
|
||||
Card(
|
||||
modifier = modifier
|
||||
.clip(RoundedCornerShape(43.5.dp))
|
||||
.clip(RoundedCornerShape(radius.dp))
|
||||
.clickable(
|
||||
onClick = onClick,
|
||||
),
|
||||
|
@ -456,39 +519,24 @@ private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modif
|
|||
Column(
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.padding(24.dp),
|
||||
.padding(padding.dp),
|
||||
) {
|
||||
// Icon.
|
||||
TaskIcon(task = task)
|
||||
TaskIcon(task = task, width = iconSize.dp)
|
||||
|
||||
Spacer(modifier = Modifier.weight(1f))
|
||||
Spacer(modifier = Modifier.weight(2f))
|
||||
|
||||
// Title.
|
||||
val pair = task.type.label.splitByFirstSpace()
|
||||
Text(
|
||||
pair.first,
|
||||
task.type.label,
|
||||
color = MaterialTheme.colorScheme.primary,
|
||||
style = titleMediumNarrow.copy(
|
||||
fontSize = 20.sp,
|
||||
fontWeight = FontWeight.Bold,
|
||||
),
|
||||
)
|
||||
if (pair.second.isNotEmpty()) {
|
||||
Text(
|
||||
pair.second,
|
||||
color = MaterialTheme.colorScheme.primary,
|
||||
style = titleMediumNarrow.copy(
|
||||
fontSize = 18.sp,
|
||||
fontWeight = FontWeight.Bold,
|
||||
),
|
||||
modifier = Modifier.layout { measurable, constraints ->
|
||||
val placeable = measurable.measure(constraints)
|
||||
layout(placeable.width, placeable.height) {
|
||||
placeable.placeRelative(0, -4.dp.roundToPx())
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
Spacer(modifier = Modifier.weight(1f))
|
||||
|
||||
// Model count.
|
||||
Text(
|
||||
|
@ -503,12 +551,21 @@ private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modif
|
|||
}
|
||||
}
|
||||
|
||||
private fun String.splitByFirstSpace(): Pair<String, String> {
|
||||
val spaceIndex = this.indexOf(' ')
|
||||
if (spaceIndex == -1) {
|
||||
return Pair(this, "")
|
||||
// Helper function to get the file name from a URI
|
||||
fun getFileName(context: Context, uri: Uri): String? {
|
||||
if (uri.scheme == "content") {
|
||||
context.contentResolver.query(uri, null, null, null, null)?.use { cursor ->
|
||||
if (cursor.moveToFirst()) {
|
||||
val nameIndex = cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME)
|
||||
if (nameIndex != -1) {
|
||||
return cursor.getString(nameIndex)
|
||||
}
|
||||
return Pair(this.substring(0, spaceIndex), this.substring(spaceIndex + 1))
|
||||
}
|
||||
}
|
||||
} else if (uri.scheme == "file") {
|
||||
return uri.lastPathSegment
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
@Preview
|
||||
|
|
|
@ -22,12 +22,16 @@ import android.provider.OpenableColumns
|
|||
import android.util.Log
|
||||
import androidx.compose.animation.core.Animatable
|
||||
import androidx.compose.animation.core.tween
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.interaction.MutableInteractionSource
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.rememberScrollState
|
||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||
import androidx.compose.foundation.verticalScroll
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.rounded.Error
|
||||
import androidx.compose.material3.Button
|
||||
|
@ -51,6 +55,7 @@ import androidx.compose.runtime.snapshots.SnapshotStateMap
|
|||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.platform.LocalFocusManager
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.window.Dialog
|
||||
import androidx.compose.ui.window.DialogProperties
|
||||
|
@ -82,41 +87,34 @@ import java.nio.charset.StandardCharsets
|
|||
private const val TAG = "AGModelImportDialog"
|
||||
|
||||
private val IMPORT_CONFIGS_LLM: List<Config> = listOf(
|
||||
LabelConfig(key = ConfigKey.NAME),
|
||||
LabelConfig(key = ConfigKey.MODEL_TYPE),
|
||||
NumberSliderConfig(
|
||||
LabelConfig(key = ConfigKey.NAME), LabelConfig(key = ConfigKey.MODEL_TYPE), NumberSliderConfig(
|
||||
key = ConfigKey.DEFAULT_MAX_TOKENS,
|
||||
sliderMin = 100f,
|
||||
sliderMax = 1024f,
|
||||
defaultValue = DEFAULT_MAX_TOKEN.toFloat(),
|
||||
valueType = ValueType.INT
|
||||
),
|
||||
NumberSliderConfig(
|
||||
), NumberSliderConfig(
|
||||
key = ConfigKey.DEFAULT_TOPK,
|
||||
sliderMin = 5f,
|
||||
sliderMax = 40f,
|
||||
defaultValue = DEFAULT_TOPK.toFloat(),
|
||||
valueType = ValueType.INT
|
||||
),
|
||||
NumberSliderConfig(
|
||||
), NumberSliderConfig(
|
||||
key = ConfigKey.DEFAULT_TOPP,
|
||||
sliderMin = 0.0f,
|
||||
sliderMax = 1.0f,
|
||||
defaultValue = DEFAULT_TOPP,
|
||||
valueType = ValueType.FLOAT
|
||||
),
|
||||
NumberSliderConfig(
|
||||
), NumberSliderConfig(
|
||||
key = ConfigKey.DEFAULT_TEMPERATURE,
|
||||
sliderMin = 0.0f,
|
||||
sliderMax = 2.0f,
|
||||
defaultValue = DEFAULT_TEMPERATURE,
|
||||
valueType = ValueType.FLOAT
|
||||
),
|
||||
BooleanSwitchConfig(
|
||||
), BooleanSwitchConfig(
|
||||
key = ConfigKey.SUPPORT_IMAGE,
|
||||
defaultValue = false,
|
||||
),
|
||||
SegmentedButtonConfig(
|
||||
), SegmentedButtonConfig(
|
||||
key = ConfigKey.COMPATIBLE_ACCELERATORS,
|
||||
defaultValue = Accelerator.CPU.label,
|
||||
options = listOf(Accelerator.CPU.label, Accelerator.GPU.label),
|
||||
|
@ -126,9 +124,7 @@ private val IMPORT_CONFIGS_LLM: List<Config> = listOf(
|
|||
|
||||
@Composable
|
||||
fun ModelImportDialog(
|
||||
uri: Uri,
|
||||
onDismiss: () -> Unit,
|
||||
onDone: (ImportedModelInfo) -> Unit
|
||||
uri: Uri, onDismiss: () -> Unit, onDone: (ImportedModelInfo) -> Unit
|
||||
) {
|
||||
val context = LocalContext.current
|
||||
val info = remember { getFileSizeAndDisplayNameFromUri(context = context, uri = uri) }
|
||||
|
@ -150,15 +146,23 @@ fun ModelImportDialog(
|
|||
putAll(initialValues)
|
||||
}
|
||||
}
|
||||
val interactionSource = remember { MutableInteractionSource() }
|
||||
|
||||
Dialog(
|
||||
onDismissRequest = onDismiss,
|
||||
) {
|
||||
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
|
||||
Column(
|
||||
val focusManager = LocalFocusManager.current
|
||||
Card(
|
||||
modifier = Modifier
|
||||
.padding(20.dp),
|
||||
verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||
.fillMaxWidth()
|
||||
.clickable(
|
||||
interactionSource = interactionSource, indication = null // Disable the ripple effect
|
||||
) {
|
||||
focusManager.clearFocus()
|
||||
}, shape = RoundedCornerShape(16.dp)
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||
) {
|
||||
// Title.
|
||||
Text(
|
||||
|
@ -167,11 +171,18 @@ fun ModelImportDialog(
|
|||
modifier = Modifier.padding(bottom = 8.dp)
|
||||
)
|
||||
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.verticalScroll(rememberScrollState())
|
||||
.weight(1f, fill = false),
|
||||
verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||
) {
|
||||
// Default configs for users to set.
|
||||
ConfigEditorsPanel(
|
||||
configs = IMPORT_CONFIGS_LLM,
|
||||
values = values,
|
||||
)
|
||||
}
|
||||
|
||||
// Button row.
|
||||
Row(
|
||||
|
@ -210,10 +221,7 @@ fun ModelImportDialog(
|
|||
|
||||
@Composable
|
||||
fun ModelImportingDialog(
|
||||
uri: Uri,
|
||||
info: ImportedModelInfo,
|
||||
onDismiss: () -> Unit,
|
||||
onDone: (ImportedModelInfo) -> Unit
|
||||
uri: Uri, info: ImportedModelInfo, onDismiss: () -> Unit, onDone: (ImportedModelInfo) -> Unit
|
||||
) {
|
||||
var error by remember { mutableStateOf("") }
|
||||
val context = LocalContext.current
|
||||
|
@ -222,8 +230,7 @@ fun ModelImportingDialog(
|
|||
|
||||
LaunchedEffect(Unit) {
|
||||
// Import.
|
||||
importModel(
|
||||
context = context,
|
||||
importModel(context = context,
|
||||
coroutineScope = coroutineScope,
|
||||
fileName = info.fileName,
|
||||
fileSize = info.fileSize,
|
||||
|
@ -236,8 +243,7 @@ fun ModelImportingDialog(
|
|||
},
|
||||
onError = {
|
||||
error = it
|
||||
}
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
Dialog(
|
||||
|
@ -246,9 +252,7 @@ fun ModelImportingDialog(
|
|||
) {
|
||||
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.padding(20.dp),
|
||||
verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||
) {
|
||||
// Title.
|
||||
Text(
|
||||
|
@ -280,13 +284,10 @@ fun ModelImportingDialog(
|
|||
// Has error.
|
||||
else {
|
||||
Row(
|
||||
verticalAlignment = Alignment.Top,
|
||||
horizontalArrangement = Arrangement.spacedBy(6.dp)
|
||||
verticalAlignment = Alignment.Top, horizontalArrangement = Arrangement.spacedBy(6.dp)
|
||||
) {
|
||||
Icon(
|
||||
Icons.Rounded.Error,
|
||||
contentDescription = "",
|
||||
tint = MaterialTheme.colorScheme.error
|
||||
Icons.Rounded.Error, contentDescription = "", tint = MaterialTheme.colorScheme.error
|
||||
)
|
||||
Text(
|
||||
error,
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
package com.google.aiedge.gallery.ui.home
|
||||
|
||||
import android.util.Log
|
||||
import androidx.compose.animation.AnimatedVisibility
|
||||
import androidx.compose.animation.expandVertically
|
||||
import androidx.compose.animation.fadeIn
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.shape.CircleShape
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.automirrored.rounded.OpenInNew
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.unit.dp
|
||||
import com.google.aiedge.gallery.BuildConfig
|
||||
import com.google.aiedge.gallery.ui.common.getJsonResponse
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ClickableLink
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.withContext
|
||||
import kotlinx.serialization.Serializable
|
||||
import kotlin.math.max
|
||||
|
||||
private const val TAG = "AGNewReleaseNotification"
|
||||
private const val REPO = "google-ai-edge/gallery"
|
||||
|
||||
@Serializable
|
||||
data class ReleaseInfo(
|
||||
val html_url: String,
|
||||
val tag_name: String,
|
||||
)
|
||||
|
||||
@Composable
|
||||
fun NewReleaseNotification() {
|
||||
var newReleaseVersion by remember { mutableStateOf("") }
|
||||
var newReleaseUrl by remember { mutableStateOf("") }
|
||||
|
||||
LaunchedEffect(Unit) {
|
||||
withContext(Dispatchers.IO) {
|
||||
Log.d("AGNewReleaseNotification", "Checking for new release...")
|
||||
val info = getJsonResponse<ReleaseInfo>("https://api.github.com/repos/$REPO/releases/latest")
|
||||
if (info != null) {
|
||||
val curRelease = BuildConfig.VERSION_NAME
|
||||
val newRelease = info.tag_name
|
||||
val isNewer = isNewerRelease(currentRelease = curRelease, newRelease = newRelease)
|
||||
Log.d(TAG, "curRelease: $curRelease, newRelease: $newRelease, isNewer: $isNewer")
|
||||
if (isNewer) {
|
||||
newReleaseVersion = newRelease
|
||||
newReleaseUrl = info.html_url
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
AnimatedVisibility(
|
||||
visible = newReleaseVersion.isNotEmpty(),
|
||||
enter = fadeIn() + expandVertically()
|
||||
) {
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.SpaceBetween,
|
||||
modifier = Modifier
|
||||
.padding(horizontal = 16.dp)
|
||||
.padding(bottom = 12.dp)
|
||||
.clip(
|
||||
CircleShape
|
||||
)
|
||||
.background(MaterialTheme.colorScheme.tertiaryContainer)
|
||||
.padding(4.dp)
|
||||
) {
|
||||
Text(
|
||||
"New release $newReleaseVersion available",
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
modifier = Modifier.padding(start = 12.dp)
|
||||
)
|
||||
Row(
|
||||
modifier = Modifier.padding(end = 12.dp),
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
ClickableLink(
|
||||
url = newReleaseUrl,
|
||||
linkText = "View",
|
||||
icon = Icons.AutoMirrored.Rounded.OpenInNew,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun isNewerRelease(currentRelease: String, newRelease: String): Boolean {
|
||||
// Split the version strings into their individual components (e.g., "0.9.0" -> ["0", "9", "0"])
|
||||
val currentComponents = currentRelease.split('.').map { it.toIntOrNull() ?: 0 }
|
||||
val newComponents = newRelease.split('.').map { it.toIntOrNull() ?: 0 }
|
||||
|
||||
// Determine the maximum number of components to iterate through
|
||||
val maxComponents = max(currentComponents.size, newComponents.size)
|
||||
|
||||
// Iterate through the components from left to right (major, minor, patch, etc.)
|
||||
for (i in 0 until maxComponents) {
|
||||
val currentComponent = currentComponents.getOrElse(i) { 0 }
|
||||
val newComponent = newComponents.getOrElse(i) { 0 }
|
||||
|
||||
if (newComponent > currentComponent) {
|
||||
return true
|
||||
} else if (newComponent < currentComponent) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
|
@ -16,45 +16,266 @@
|
|||
|
||||
package com.google.aiedge.gallery.ui.home
|
||||
|
||||
import androidx.compose.foundation.border
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.interaction.MutableInteractionSource
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.height
|
||||
import androidx.compose.foundation.layout.offset
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.layout.wrapContentHeight
|
||||
import androidx.compose.foundation.rememberScrollState
|
||||
import androidx.compose.foundation.shape.CircleShape
|
||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||
import androidx.compose.foundation.text.BasicTextField
|
||||
import androidx.compose.foundation.verticalScroll
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.rounded.CheckCircle
|
||||
import androidx.compose.material3.Button
|
||||
import androidx.compose.material3.Card
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.IconButton
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.MultiChoiceSegmentedButtonRow
|
||||
import androidx.compose.material3.OutlinedButton
|
||||
import androidx.compose.material3.SegmentedButton
|
||||
import androidx.compose.material3.SegmentedButtonDefaults
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.focus.FocusRequester
|
||||
import androidx.compose.ui.focus.focusRequester
|
||||
import androidx.compose.ui.focus.onFocusChanged
|
||||
import androidx.compose.ui.graphics.SolidColor
|
||||
import androidx.compose.ui.platform.LocalFocusManager
|
||||
import androidx.compose.ui.text.TextStyle
|
||||
import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.window.Dialog
|
||||
import com.google.aiedge.gallery.BuildConfig
|
||||
import com.google.aiedge.gallery.data.Config
|
||||
import com.google.aiedge.gallery.data.ConfigKey
|
||||
import com.google.aiedge.gallery.data.SegmentedButtonConfig
|
||||
import com.google.aiedge.gallery.ui.common.chat.ConfigDialog
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.theme.THEME_AUTO
|
||||
import com.google.aiedge.gallery.ui.theme.THEME_DARK
|
||||
import com.google.aiedge.gallery.ui.theme.THEME_LIGHT
|
||||
import com.google.aiedge.gallery.ui.theme.ThemeSettings
|
||||
import com.google.aiedge.gallery.ui.theme.labelSmallNarrow
|
||||
import java.time.Instant
|
||||
import java.time.ZoneId
|
||||
import java.time.format.DateTimeFormatter
|
||||
import java.util.Locale
|
||||
import kotlin.math.min
|
||||
|
||||
private val CONFIGS: List<Config> = listOf(
|
||||
SegmentedButtonConfig(
|
||||
key = ConfigKey.THEME,
|
||||
defaultValue = THEME_AUTO,
|
||||
options = listOf(THEME_AUTO, THEME_LIGHT, THEME_DARK),
|
||||
)
|
||||
)
|
||||
private val THEME_OPTIONS = listOf(THEME_AUTO, THEME_LIGHT, THEME_DARK)
|
||||
|
||||
@Composable
|
||||
fun SettingsDialog(
|
||||
curThemeOverride: String,
|
||||
modelManagerViewModel: ModelManagerViewModel,
|
||||
onDismissed: () -> Unit,
|
||||
onOk: (Map<String, Any>) -> Unit,
|
||||
) {
|
||||
val initialValues = mapOf(
|
||||
ConfigKey.THEME.label to curThemeOverride
|
||||
)
|
||||
ConfigDialog(
|
||||
title = "Settings",
|
||||
subtitle = "App version: ${BuildConfig.VERSION_NAME}",
|
||||
okBtnLabel = "OK",
|
||||
configs = CONFIGS,
|
||||
initialValues = initialValues,
|
||||
onDismissed = onDismissed,
|
||||
onOk = { curConfigValues ->
|
||||
onOk(curConfigValues)
|
||||
var selectedTheme by remember { mutableStateOf(curThemeOverride) }
|
||||
var hfToken by remember { mutableStateOf(modelManagerViewModel.getTokenStatusAndData().data) }
|
||||
val dateFormatter = remember {
|
||||
DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss").withZone(ZoneId.systemDefault())
|
||||
.withLocale(Locale.getDefault())
|
||||
}
|
||||
var customHfToken by remember { mutableStateOf("") }
|
||||
var isFocused by remember { mutableStateOf(false) }
|
||||
val focusRequester = remember { FocusRequester() }
|
||||
val interactionSource = remember { MutableInteractionSource() }
|
||||
|
||||
// Hide config dialog.
|
||||
Dialog(onDismissRequest = onDismissed) {
|
||||
val focusManager = LocalFocusManager.current
|
||||
Card(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.clickable(
|
||||
interactionSource = interactionSource, indication = null // Disable the ripple effect
|
||||
) {
|
||||
focusManager.clearFocus()
|
||||
}, shape = RoundedCornerShape(16.dp)
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||
) {
|
||||
// Dialog title and subtitle.
|
||||
Column {
|
||||
Text(
|
||||
"Settings",
|
||||
style = MaterialTheme.typography.titleLarge,
|
||||
modifier = Modifier.padding(bottom = 8.dp)
|
||||
)
|
||||
// Subtitle.
|
||||
Text(
|
||||
"App version: ${BuildConfig.VERSION_NAME}",
|
||||
style = labelSmallNarrow,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant,
|
||||
modifier = Modifier.offset(y = (-6).dp)
|
||||
)
|
||||
}
|
||||
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.verticalScroll(rememberScrollState())
|
||||
.weight(1f, fill = false),
|
||||
verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||
) {
|
||||
// Theme switcher.
|
||||
Column(
|
||||
modifier = Modifier.fillMaxWidth()
|
||||
) {
|
||||
Text(
|
||||
"Theme",
|
||||
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.Bold)
|
||||
)
|
||||
MultiChoiceSegmentedButtonRow {
|
||||
THEME_OPTIONS.forEachIndexed { index, label ->
|
||||
SegmentedButton(shape = SegmentedButtonDefaults.itemShape(
|
||||
index = index, count = THEME_OPTIONS.size
|
||||
), onCheckedChange = {
|
||||
selectedTheme = label
|
||||
|
||||
// Update theme settings.
|
||||
// This will update app's theme.
|
||||
ThemeSettings.themeOverride.value = label
|
||||
|
||||
// Save to data store.
|
||||
modelManagerViewModel.saveThemeOverride(label)
|
||||
}, checked = label == selectedTheme, label = { Text(label) })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HF Token management.
|
||||
Column(
|
||||
modifier = Modifier.fillMaxWidth(), verticalArrangement = Arrangement.spacedBy(4.dp)
|
||||
) {
|
||||
Text(
|
||||
"HuggingFace access token",
|
||||
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.Bold)
|
||||
)
|
||||
// Show the start of the token.
|
||||
val curHfToken = hfToken
|
||||
if (curHfToken != null) {
|
||||
Text(
|
||||
curHfToken.accessToken.substring(0, min(16, curHfToken.accessToken.length)) + "...",
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
Text(
|
||||
"Expired at: ${dateFormatter.format(Instant.ofEpochMilli(curHfToken.expiresAtMs))}",
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
} else {
|
||||
Text(
|
||||
"Not available",
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
Text(
|
||||
"The token will be automatically retrieved when a gated model is downloaded",
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
}
|
||||
Row(horizontalArrangement = Arrangement.spacedBy(4.dp)) {
|
||||
OutlinedButton(
|
||||
onClick = {
|
||||
modelManagerViewModel.clearAccessToken()
|
||||
hfToken = null
|
||||
}, enabled = curHfToken != null
|
||||
) {
|
||||
Text("Clear")
|
||||
}
|
||||
BasicTextField(
|
||||
value = customHfToken,
|
||||
singleLine = true,
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(top = 4.dp)
|
||||
.focusRequester(focusRequester)
|
||||
.onFocusChanged {
|
||||
isFocused = it.isFocused
|
||||
},
|
||||
onValueChange = {
|
||||
customHfToken = it
|
||||
},
|
||||
textStyle = TextStyle(color = MaterialTheme.colorScheme.onSurface),
|
||||
cursorBrush = SolidColor(MaterialTheme.colorScheme.onSurface),
|
||||
) { innerTextField ->
|
||||
Box(
|
||||
modifier = Modifier
|
||||
.border(
|
||||
width = if (isFocused) 2.dp else 1.dp,
|
||||
color = if (isFocused) MaterialTheme.colorScheme.primary else MaterialTheme.colorScheme.outline,
|
||||
shape = CircleShape,
|
||||
)
|
||||
.height(40.dp), contentAlignment = Alignment.CenterStart
|
||||
) {
|
||||
Row(verticalAlignment = Alignment.CenterVertically) {
|
||||
Box(
|
||||
modifier = Modifier
|
||||
.padding(start = 16.dp)
|
||||
.weight(1f)
|
||||
) {
|
||||
if (customHfToken.isEmpty()) {
|
||||
Text(
|
||||
"Enter token manually",
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant,
|
||||
style = MaterialTheme.typography.bodySmall
|
||||
)
|
||||
}
|
||||
innerTextField()
|
||||
}
|
||||
if (customHfToken.isNotEmpty()) {
|
||||
IconButton(
|
||||
modifier = Modifier.offset(x = 1.dp),
|
||||
onClick = {
|
||||
modelManagerViewModel.saveAccessToken(
|
||||
accessToken = customHfToken,
|
||||
refreshToken = "",
|
||||
expiresAt = System.currentTimeMillis() + 1000L * 60 * 60 * 24 * 365 * 10,
|
||||
)
|
||||
hfToken = modelManagerViewModel.getTokenStatusAndData().data
|
||||
}) {
|
||||
Icon(Icons.Rounded.CheckCircle, contentDescription = "")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Button row.
|
||||
Row(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(top = 8.dp),
|
||||
horizontalArrangement = Arrangement.End,
|
||||
) {
|
||||
// Close button
|
||||
Button(
|
||||
onClick = {
|
||||
onDismissed()
|
||||
},
|
||||
)
|
||||
) {
|
||||
Text("Close")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,9 +34,9 @@ object LlmChatDestination {
|
|||
val route = "LlmChatRoute"
|
||||
}
|
||||
|
||||
object LlmImageToTextDestination {
|
||||
object LlmAskImageDestination {
|
||||
@Serializable
|
||||
val route = "LlmImageToTextRoute"
|
||||
val route = "LlmAskImageRoute"
|
||||
}
|
||||
|
||||
@Composable
|
||||
|
@ -57,11 +57,11 @@ fun LlmChatScreen(
|
|||
}
|
||||
|
||||
@Composable
|
||||
fun LlmImageToTextScreen(
|
||||
fun LlmAskImageScreen(
|
||||
modelManagerViewModel: ModelManagerViewModel,
|
||||
navigateUp: () -> Unit,
|
||||
modifier: Modifier = Modifier,
|
||||
viewModel: LlmImageToTextViewModel = viewModel(
|
||||
viewModel: LlmAskImageViewModel = viewModel(
|
||||
factory = ViewModelProvider.Factory
|
||||
),
|
||||
) {
|
||||
|
@ -129,12 +129,7 @@ fun ChatViewWrapper(
|
|||
})
|
||||
}
|
||||
},
|
||||
onBenchmarkClicked = { model, message, warmUpIterations, benchmarkIterations ->
|
||||
if (message is ChatMessageText) {
|
||||
viewModel.benchmark(
|
||||
model = model, message = message
|
||||
)
|
||||
}
|
||||
onBenchmarkClicked = { _, _, _, _ ->
|
||||
},
|
||||
onResetSessionClicked = { model ->
|
||||
viewModel.resetSession(model = model)
|
||||
|
|
|
@ -22,7 +22,7 @@ import android.util.Log
|
|||
import androidx.lifecycle.viewModelScope
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_ASK_IMAGE
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
|
||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageLoading
|
||||
|
@ -49,6 +49,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
|
|||
fun generateResponse(model: Model, input: String, image: Bitmap? = null, onError: () -> Unit) {
|
||||
viewModelScope.launch(Dispatchers.Default) {
|
||||
setInProgress(true)
|
||||
setPreparing(true)
|
||||
|
||||
// Loading.
|
||||
addMessage(
|
||||
|
@ -88,6 +89,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
|
|||
timeToFirstToken = (firstTokenTs - start) / 1000f
|
||||
prefillSpeed = prefillTokens / timeToFirstToken
|
||||
firstRun = false
|
||||
setPreparing(false)
|
||||
} else {
|
||||
decodeTokens++
|
||||
}
|
||||
|
@ -137,10 +139,12 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
|
|||
},
|
||||
cleanUpListener = {
|
||||
setInProgress(false)
|
||||
setPreparing(false)
|
||||
})
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error occurred while running inference", e)
|
||||
setInProgress(false)
|
||||
setPreparing(false)
|
||||
onError()
|
||||
}
|
||||
}
|
||||
|
@ -194,98 +198,6 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
|
|||
}
|
||||
}
|
||||
|
||||
fun benchmark(model: Model, message: ChatMessageText) {
|
||||
viewModelScope.launch(Dispatchers.Default) {
|
||||
setInProgress(true)
|
||||
|
||||
// Wait for model to be initialized.
|
||||
while (model.instance == null) {
|
||||
delay(100)
|
||||
}
|
||||
val instance = model.instance as LlmModelInstance
|
||||
val prefillTokens = instance.session.sizeInTokens(message.content)
|
||||
|
||||
// Add the message to show benchmark results.
|
||||
val benchmarkLlmResult = ChatMessageBenchmarkLlmResult(
|
||||
orderedStats = STATS,
|
||||
statValues = mutableMapOf(),
|
||||
running = true,
|
||||
latencyMs = -1f,
|
||||
)
|
||||
addMessage(model = model, message = benchmarkLlmResult)
|
||||
|
||||
// Run inference.
|
||||
val result = StringBuilder()
|
||||
var firstRun = true
|
||||
var timeToFirstToken = 0f
|
||||
var firstTokenTs = 0L
|
||||
var decodeTokens = 0
|
||||
var prefillSpeed = 0f
|
||||
var decodeSpeed: Float
|
||||
val start = System.currentTimeMillis()
|
||||
var lastUpdateTime = 0L
|
||||
LlmChatModelHelper.runInference(model = model,
|
||||
input = message.content,
|
||||
resultListener = { partialResult, done ->
|
||||
val curTs = System.currentTimeMillis()
|
||||
|
||||
if (firstRun) {
|
||||
firstTokenTs = System.currentTimeMillis()
|
||||
timeToFirstToken = (firstTokenTs - start) / 1000f
|
||||
prefillSpeed = prefillTokens / timeToFirstToken
|
||||
firstRun = false
|
||||
|
||||
// Update message to show prefill speed.
|
||||
replaceLastMessage(
|
||||
model = model,
|
||||
message = ChatMessageBenchmarkLlmResult(
|
||||
orderedStats = STATS,
|
||||
statValues = mutableMapOf(
|
||||
"prefill_speed" to prefillSpeed,
|
||||
"time_to_first_token" to timeToFirstToken,
|
||||
"latency" to (curTs - start).toFloat() / 1000f,
|
||||
),
|
||||
running = false,
|
||||
latencyMs = -1f,
|
||||
),
|
||||
type = ChatMessageType.BENCHMARK_LLM_RESULT,
|
||||
)
|
||||
} else {
|
||||
decodeTokens++
|
||||
}
|
||||
result.append(partialResult)
|
||||
|
||||
if (curTs - lastUpdateTime > 500 || done) {
|
||||
decodeSpeed = decodeTokens / ((curTs - firstTokenTs) / 1000f)
|
||||
if (decodeSpeed.isNaN()) {
|
||||
decodeSpeed = 0f
|
||||
}
|
||||
replaceLastMessage(
|
||||
model = model, message = ChatMessageBenchmarkLlmResult(
|
||||
orderedStats = STATS,
|
||||
statValues = mutableMapOf(
|
||||
"prefill_speed" to prefillSpeed,
|
||||
"decode_speed" to decodeSpeed,
|
||||
"time_to_first_token" to timeToFirstToken,
|
||||
"latency" to (curTs - start).toFloat() / 1000f,
|
||||
),
|
||||
running = !done,
|
||||
latencyMs = -1f,
|
||||
), type = ChatMessageType.BENCHMARK_LLM_RESULT
|
||||
)
|
||||
lastUpdateTime = curTs
|
||||
|
||||
if (done) {
|
||||
setInProgress(false)
|
||||
}
|
||||
}
|
||||
},
|
||||
cleanUpListener = {
|
||||
setInProgress(false)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fun handleError(
|
||||
context: Context,
|
||||
model: Model,
|
||||
|
@ -320,15 +232,8 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
|
|||
)
|
||||
|
||||
// Re-generate the response automatically.
|
||||
generateResponse(model = model, input = triggeredMessage.content, onError = {
|
||||
handleError(
|
||||
context = context,
|
||||
model = model,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
triggeredMessage = triggeredMessage
|
||||
)
|
||||
})
|
||||
generateResponse(model = model, input = triggeredMessage.content, onError = {})
|
||||
}
|
||||
}
|
||||
|
||||
class LlmImageToTextViewModel : LlmChatViewModel(curTask = TASK_LLM_IMAGE_TO_TEXT)
|
||||
class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE)
|
|
@ -73,6 +73,7 @@ fun LlmSingleTurnScreen(
|
|||
) {
|
||||
val task = viewModel.task
|
||||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||
val uiState by viewModel.uiState.collectAsState()
|
||||
val selectedModel = modelManagerUiState.selectedModel
|
||||
val scope = rememberCoroutineScope()
|
||||
val context = LocalContext.current
|
||||
|
@ -114,6 +115,8 @@ fun LlmSingleTurnScreen(
|
|||
task = task,
|
||||
model = selectedModel,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
inProgress = uiState.inProgress,
|
||||
modelPreparing = uiState.preparing,
|
||||
onConfigChanged = { _, _ -> },
|
||||
onBackClicked = { handleNavigateUp() },
|
||||
onModelSelected = { newSelectedModel ->
|
||||
|
|
|
@ -20,10 +20,9 @@ import android.util.Log
|
|||
import androidx.lifecycle.ViewModel
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_PROMPT_LAB
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
|
||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageLoading
|
||||
import com.google.aiedge.gallery.ui.common.chat.Stat
|
||||
import com.google.aiedge.gallery.ui.common.processLlmResponse
|
||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
|
||||
|
@ -44,9 +43,9 @@ data class LlmSingleTurnUiState(
|
|||
val inProgress: Boolean = false,
|
||||
|
||||
/**
|
||||
* Indicates whether the model is currently being initialized.
|
||||
* Indicates whether the model is preparing (before outputting any result and after initializing).
|
||||
*/
|
||||
val initializing: Boolean = false,
|
||||
val preparing: Boolean = false,
|
||||
|
||||
// model -> <template label -> response>
|
||||
val responsesByModel: Map<String, Map<String, String>>,
|
||||
|
@ -65,14 +64,14 @@ private val STATS = listOf(
|
|||
Stat(id = "latency", label = "Latency", unit = "sec")
|
||||
)
|
||||
|
||||
open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_USECASES) : ViewModel() {
|
||||
open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewModel() {
|
||||
private val _uiState = MutableStateFlow(createUiState(task = task))
|
||||
val uiState = _uiState.asStateFlow()
|
||||
|
||||
fun generateResponse(model: Model, input: String) {
|
||||
viewModelScope.launch(Dispatchers.Default) {
|
||||
setInProgress(true)
|
||||
setInitializing(true)
|
||||
setPreparing(true)
|
||||
|
||||
// Wait for instance to be initialized.
|
||||
while (model.instance == null) {
|
||||
|
@ -98,7 +97,7 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_USECASES) : ViewMode
|
|||
val curTs = System.currentTimeMillis()
|
||||
|
||||
if (firstRun) {
|
||||
setInitializing(false)
|
||||
setPreparing(false)
|
||||
firstTokenTs = System.currentTimeMillis()
|
||||
timeToFirstToken = (firstTokenTs - start) / 1000f
|
||||
prefillSpeed = prefillTokens / timeToFirstToken
|
||||
|
@ -148,7 +147,7 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_USECASES) : ViewMode
|
|||
},
|
||||
singleTurn = true,
|
||||
cleanUpListener = {
|
||||
setInitializing(false)
|
||||
setPreparing(false)
|
||||
setInProgress(false)
|
||||
})
|
||||
}
|
||||
|
@ -167,8 +166,8 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_USECASES) : ViewMode
|
|||
_uiState.update { _uiState.value.copy(inProgress = inProgress) }
|
||||
}
|
||||
|
||||
fun setInitializing(initializing: Boolean) {
|
||||
_uiState.update { _uiState.value.copy(initializing = initializing) }
|
||||
fun setPreparing(preparing: Boolean) {
|
||||
_uiState.update { _uiState.value.copy(preparing = preparing) }
|
||||
}
|
||||
|
||||
fun updateResponse(model: Model, promptTemplateType: PromptTemplateType, response: String) {
|
||||
|
|
|
@ -339,7 +339,7 @@ fun PromptTemplatesPanel(
|
|||
|
||||
val modelInitializing =
|
||||
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING
|
||||
if (inProgress && !modelInitializing) {
|
||||
if (inProgress && !modelInitializing && !uiState.preparing) {
|
||||
IconButton(
|
||||
onClick = {
|
||||
onStopButtonClicked(model)
|
||||
|
|
|
@ -57,7 +57,7 @@ import androidx.compose.ui.platform.LocalClipboardManager
|
|||
import androidx.compose.ui.text.AnnotatedString
|
||||
import androidx.compose.ui.unit.dp
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_PROMPT_LAB
|
||||
import com.google.aiedge.gallery.ui.common.chat.MarkdownText
|
||||
import com.google.aiedge.gallery.ui.common.chat.MessageBodyBenchmarkLlm
|
||||
import com.google.aiedge.gallery.ui.common.chat.MessageBodyLoading
|
||||
|
@ -76,11 +76,11 @@ fun ResponsePanel(
|
|||
modelManagerViewModel: ModelManagerViewModel,
|
||||
modifier: Modifier = Modifier,
|
||||
) {
|
||||
val task = TASK_LLM_USECASES
|
||||
val task = TASK_LLM_PROMPT_LAB
|
||||
val uiState by viewModel.uiState.collectAsState()
|
||||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||
val inProgress = uiState.inProgress
|
||||
val initializing = uiState.initializing
|
||||
val initializing = uiState.preparing
|
||||
val selectedPromptTemplateType = uiState.selectedPromptTemplateType
|
||||
val responseScrollState = rememberScrollState()
|
||||
var selectedOptionIndex by remember { mutableIntStateOf(0) }
|
||||
|
|
|
@ -70,7 +70,7 @@ fun ModelList(
|
|||
) {
|
||||
// This is just to update "models" list when task.updateTrigger is updated so that the UI can
|
||||
// be properly updated.
|
||||
val models by remember {
|
||||
val models by remember(task) {
|
||||
derivedStateOf {
|
||||
val trigger = task.updateTrigger.value
|
||||
if (trigger >= 0) {
|
||||
|
@ -80,7 +80,7 @@ fun ModelList(
|
|||
}
|
||||
}
|
||||
}
|
||||
val importedModels by remember {
|
||||
val importedModels by remember(task) {
|
||||
derivedStateOf {
|
||||
val trigger = task.updateTrigger.value
|
||||
if (trigger >= 0) {
|
||||
|
|
|
@ -37,15 +37,16 @@ import com.google.aiedge.gallery.data.ModelAllowlist
|
|||
import com.google.aiedge.gallery.data.ModelDownloadStatus
|
||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||
import com.google.aiedge.gallery.data.TASKS
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_ASK_IMAGE
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_PROMPT_LAB
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.data.TaskType
|
||||
import com.google.aiedge.gallery.data.ValueType
|
||||
import com.google.aiedge.gallery.data.getModelByName
|
||||
import com.google.aiedge.gallery.ui.common.AuthConfig
|
||||
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
|
||||
import com.google.aiedge.gallery.ui.common.getJsonResponse
|
||||
import com.google.aiedge.gallery.ui.common.processTasks
|
||||
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper
|
||||
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper
|
||||
|
@ -59,8 +60,6 @@ import kotlinx.coroutines.flow.asStateFlow
|
|||
import kotlinx.coroutines.flow.update
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.withContext
|
||||
import kotlinx.serialization.ExperimentalSerializationApi
|
||||
import kotlinx.serialization.json.Json
|
||||
import net.openid.appauth.AuthorizationException
|
||||
import net.openid.appauth.AuthorizationRequest
|
||||
import net.openid.appauth.AuthorizationResponse
|
||||
|
@ -73,7 +72,7 @@ import java.net.URL
|
|||
private const val TAG = "AGModelManagerViewModel"
|
||||
private const val TEXT_INPUT_HISTORY_MAX_SIZE = 50
|
||||
private const val MODEL_ALLOWLIST_URL =
|
||||
"https://raw.githubusercontent.com/jinjingforever/kokoro-codelab-jingjin/refs/heads/main/model_allowlist.json"
|
||||
"https://raw.githubusercontent.com/google-ai-edge/gallery/refs/heads/main/model_allowlist.json"
|
||||
|
||||
data class ModelInitializationStatus(
|
||||
val status: ModelInitializationStatusType, var error: String = ""
|
||||
|
@ -116,11 +115,6 @@ data class ModelManagerUiState(
|
|||
*/
|
||||
val modelInitializationStatus: Map<String, ModelInitializationStatus>,
|
||||
|
||||
/**
|
||||
* Whether Hugging Face models from the given community are currently being loaded.
|
||||
*/
|
||||
val loadingHfModels: Boolean = false,
|
||||
|
||||
/**
|
||||
* Whether the app is loading and processing the model allowlist.
|
||||
*/
|
||||
|
@ -196,8 +190,9 @@ open class ModelManagerViewModel(
|
|||
)
|
||||
}
|
||||
|
||||
fun cancelDownloadModel(model: Model) {
|
||||
fun cancelDownloadModel(task: Task, model: Model) {
|
||||
downloadRepository.cancelDownloadModel(model)
|
||||
deleteModel(task = task, model = model)
|
||||
}
|
||||
|
||||
fun deleteModel(task: Task, model: Model) {
|
||||
|
@ -311,13 +306,13 @@ open class ModelManagerViewModel(
|
|||
onDone = onDone,
|
||||
)
|
||||
|
||||
TaskType.LLM_USECASES -> LlmChatModelHelper.initialize(
|
||||
TaskType.LLM_PROMPT_LAB -> LlmChatModelHelper.initialize(
|
||||
context = context,
|
||||
model = model,
|
||||
onDone = onDone,
|
||||
)
|
||||
|
||||
TaskType.LLM_IMAGE_TO_TEXT -> LlmChatModelHelper.initialize(
|
||||
TaskType.LLM_ASK_IMAGE -> LlmChatModelHelper.initialize(
|
||||
context = context,
|
||||
model = model,
|
||||
onDone = onDone,
|
||||
|
@ -341,8 +336,8 @@ open class ModelManagerViewModel(
|
|||
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model)
|
||||
TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.cleanUp(model = model)
|
||||
TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model)
|
||||
TaskType.LLM_USECASES -> LlmChatModelHelper.cleanUp(model = model)
|
||||
TaskType.LLM_IMAGE_TO_TEXT -> LlmChatModelHelper.cleanUp(model = model)
|
||||
TaskType.LLM_PROMPT_LAB -> LlmChatModelHelper.cleanUp(model = model)
|
||||
TaskType.LLM_ASK_IMAGE -> LlmChatModelHelper.cleanUp(model = model)
|
||||
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.cleanUp(model = model)
|
||||
TaskType.TEST_TASK_1 -> {}
|
||||
TaskType.TEST_TASK_2 -> {}
|
||||
|
@ -446,14 +441,14 @@ open class ModelManagerViewModel(
|
|||
// Create model.
|
||||
val model = createModelFromImportedModelInfo(info = info)
|
||||
|
||||
for (task in listOf(TASK_LLM_CHAT, TASK_LLM_USECASES, TASK_LLM_IMAGE_TO_TEXT)) {
|
||||
for (task in listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)) {
|
||||
// Remove duplicated imported model if existed.
|
||||
val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported }
|
||||
if (modelIndex >= 0) {
|
||||
Log.d(TAG, "duplicated imported model found in task. Removing it first")
|
||||
task.models.removeAt(modelIndex)
|
||||
}
|
||||
if (task == TASK_LLM_IMAGE_TO_TEXT && model.llmSupportImage || task != TASK_LLM_IMAGE_TO_TEXT) {
|
||||
if (task == TASK_LLM_ASK_IMAGE && model.llmSupportImage || task != TASK_LLM_ASK_IMAGE) {
|
||||
task.models.add(model)
|
||||
}
|
||||
task.updateTrigger.value = System.currentTimeMillis()
|
||||
|
@ -502,7 +497,7 @@ open class ModelManagerViewModel(
|
|||
|
||||
// Check expiration (with 5-minute buffer).
|
||||
val curTs = System.currentTimeMillis()
|
||||
val expirationTs = tokenData.expiresAtSeconds - 5 * 60
|
||||
val expirationTs = tokenData.expiresAtMs - 5 * 60
|
||||
Log.d(
|
||||
TAG,
|
||||
"Checking whether token has expired or not. Current ts: $curTs, expires at: $expirationTs"
|
||||
|
@ -562,7 +557,7 @@ open class ModelManagerViewModel(
|
|||
} else {
|
||||
// Token exchange successful. Store the tokens securely
|
||||
Log.d(TAG, "Token exchange successful. Storing tokens...")
|
||||
dataStoreRepository.saveAccessTokenData(
|
||||
saveAccessToken(
|
||||
accessToken = tokenResponse.accessToken!!,
|
||||
refreshToken = tokenResponse.refreshToken!!,
|
||||
expiresAt = tokenResponse.accessTokenExpirationTime!!
|
||||
|
@ -606,6 +601,18 @@ open class ModelManagerViewModel(
|
|||
}
|
||||
}
|
||||
|
||||
fun saveAccessToken(accessToken: String, refreshToken: String, expiresAt: Long) {
|
||||
dataStoreRepository.saveAccessTokenData(
|
||||
accessToken = accessToken,
|
||||
refreshToken = refreshToken,
|
||||
expiresAt = expiresAt,
|
||||
)
|
||||
}
|
||||
|
||||
fun clearAccessToken() {
|
||||
dataStoreRepository.clearAccessTokenData()
|
||||
}
|
||||
|
||||
private fun processPendingDownloads() {
|
||||
Log.d(TAG, "In-progress worker infos: $inProgressWorkInfos")
|
||||
|
||||
|
@ -673,11 +680,11 @@ open class ModelManagerViewModel(
|
|||
if (allowedModel.taskTypes.contains(TASK_LLM_CHAT.type.id)) {
|
||||
TASK_LLM_CHAT.models.add(model)
|
||||
}
|
||||
if (allowedModel.taskTypes.contains(TASK_LLM_USECASES.type.id)) {
|
||||
TASK_LLM_USECASES.models.add(model)
|
||||
if (allowedModel.taskTypes.contains(TASK_LLM_PROMPT_LAB.type.id)) {
|
||||
TASK_LLM_PROMPT_LAB.models.add(model)
|
||||
}
|
||||
if (allowedModel.taskTypes.contains(TASK_LLM_IMAGE_TO_TEXT.type.id)) {
|
||||
TASK_LLM_IMAGE_TO_TEXT.models.add(model)
|
||||
if (allowedModel.taskTypes.contains(TASK_LLM_ASK_IMAGE.type.id)) {
|
||||
TASK_LLM_ASK_IMAGE.models.add(model)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -732,9 +739,9 @@ open class ModelManagerViewModel(
|
|||
|
||||
// Add to task.
|
||||
TASK_LLM_CHAT.models.add(model)
|
||||
TASK_LLM_USECASES.models.add(model)
|
||||
TASK_LLM_PROMPT_LAB.models.add(model)
|
||||
if (model.llmSupportImage) {
|
||||
TASK_LLM_IMAGE_TO_TEXT.models.add(model)
|
||||
TASK_LLM_ASK_IMAGE.models.add(model)
|
||||
}
|
||||
|
||||
// Update status.
|
||||
|
@ -827,38 +834,6 @@ open class ModelManagerViewModel(
|
|||
)
|
||||
}
|
||||
|
||||
@OptIn(ExperimentalSerializationApi::class)
|
||||
private inline fun <reified T> getJsonResponse(url: String): T? {
|
||||
try {
|
||||
val connection = URL(url).openConnection() as HttpURLConnection
|
||||
connection.requestMethod = "GET"
|
||||
connection.connect()
|
||||
|
||||
val responseCode = connection.responseCode
|
||||
if (responseCode == HttpURLConnection.HTTP_OK) {
|
||||
val inputStream = connection.inputStream
|
||||
val response = inputStream.bufferedReader().use { it.readText() }
|
||||
|
||||
// Parse JSON using kotlinx.serialization
|
||||
val json = Json {
|
||||
// Handle potential extra fields
|
||||
ignoreUnknownKeys = true
|
||||
allowComments = true
|
||||
allowTrailingComma = true
|
||||
}
|
||||
val jsonObj = json.decodeFromString<T>(response)
|
||||
return jsonObj
|
||||
} else {
|
||||
Log.e(TAG, "HTTP error: $responseCode")
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error when getting json response: ${e.message}")
|
||||
e.printStackTrace()
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
private fun isFileInExternalFilesDir(fileName: String): Boolean {
|
||||
if (externalFilesDir != null) {
|
||||
val file = File(externalFilesDir, fileName)
|
||||
|
|
|
@ -46,8 +46,8 @@ import com.google.aiedge.gallery.data.Model
|
|||
import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION
|
||||
import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_ASK_IMAGE
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_PROMPT_LAB
|
||||
import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.data.TaskType
|
||||
|
@ -60,8 +60,8 @@ import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationDestination
|
|||
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationScreen
|
||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatDestination
|
||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatScreen
|
||||
import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextDestination
|
||||
import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextScreen
|
||||
import com.google.aiedge.gallery.ui.llmchat.LlmAskImageDestination
|
||||
import com.google.aiedge.gallery.ui.llmchat.LlmAskImageScreen
|
||||
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnDestination
|
||||
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnScreen
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManager
|
||||
|
@ -233,7 +233,7 @@ fun GalleryNavHost(
|
|||
enterTransition = { slideEnter() },
|
||||
exitTransition = { slideExit() },
|
||||
) {
|
||||
getModelFromNavigationParam(it, TASK_LLM_USECASES)?.let { defaultModel ->
|
||||
getModelFromNavigationParam(it, TASK_LLM_PROMPT_LAB)?.let { defaultModel ->
|
||||
modelManagerViewModel.selectModel(defaultModel)
|
||||
|
||||
LlmSingleTurnScreen(
|
||||
|
@ -245,15 +245,15 @@ fun GalleryNavHost(
|
|||
|
||||
// LLM image to text.
|
||||
composable(
|
||||
route = "${LlmImageToTextDestination.route}/{modelName}",
|
||||
route = "${LlmAskImageDestination.route}/{modelName}",
|
||||
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
|
||||
enterTransition = { slideEnter() },
|
||||
exitTransition = { slideExit() },
|
||||
) {
|
||||
getModelFromNavigationParam(it, TASK_LLM_IMAGE_TO_TEXT)?.let { defaultModel ->
|
||||
getModelFromNavigationParam(it, TASK_LLM_ASK_IMAGE)?.let { defaultModel ->
|
||||
modelManagerViewModel.selectModel(defaultModel)
|
||||
|
||||
LlmImageToTextScreen(
|
||||
LlmAskImageScreen(
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
navigateUp = { navController.navigateUp() },
|
||||
)
|
||||
|
@ -287,8 +287,8 @@ fun navigateToTaskScreen(
|
|||
TaskType.TEXT_CLASSIFICATION -> navController.navigate("${TextClassificationDestination.route}/${modelName}")
|
||||
TaskType.IMAGE_CLASSIFICATION -> navController.navigate("${ImageClassificationDestination.route}/${modelName}")
|
||||
TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
|
||||
TaskType.LLM_IMAGE_TO_TEXT -> navController.navigate("${LlmImageToTextDestination.route}/${modelName}")
|
||||
TaskType.LLM_USECASES -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
|
||||
TaskType.LLM_ASK_IMAGE -> navController.navigate("${LlmAskImageDestination.route}/${modelName}")
|
||||
TaskType.LLM_PROMPT_LAB -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
|
||||
TaskType.IMAGE_GENERATION -> navController.navigate("${ImageGenerationDestination.route}/${modelName}")
|
||||
TaskType.TEST_TASK_1 -> {}
|
||||
TaskType.TEST_TASK_2 -> {}
|
||||
|
|
|
@ -42,6 +42,9 @@ class PreviewDataStoreRepository : DataStoreRepository {
|
|||
return null
|
||||
}
|
||||
|
||||
override fun clearAccessTokenData() {
|
||||
}
|
||||
|
||||
override fun saveImportedModels(importedModels: List<ImportedModelInfo>) {
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue