From 29b614355e3efb8906a0768a800f8b7290517621 Mon Sep 17 00:00:00 2001 From: Jing Jin <8752427+jinjingforever@users.noreply.github.com> Date: Thu, 17 Apr 2025 17:54:57 -0700 Subject: [PATCH] Add initial support for importing local model. --- .../gallery/data/DataStoreRepository.kt | 25 ++ .../google/aiedge/gallery/data/HuggingFace.kt | 3 +- .../com/google/aiedge/gallery/data/Model.kt | 13 +- .../com/google/aiedge/gallery/data/Tasks.kt | 6 +- .../gallery/ui/common/DownloadAndTryButton.kt | 7 +- .../google/aiedge/gallery/ui/common/Utils.kt | 16 +- .../gallery/ui/common/chat/ChatPanel.kt | 57 +++- .../aiedge/gallery/ui/common/chat/ChatView.kt | 8 +- .../gallery/ui/common/modelitem/ModelItem.kt | 21 +- .../common/modelitem/ModelItemActionButton.kt | 2 +- .../ImageClassificationModelHelper.kt | 4 +- .../ImageGenerationModelHelper.kt | 18 +- .../gallery/ui/llmchat/LlmChatModelHelper.kt | 8 +- .../ui/modelmanager/ModelImportDialog.kt | 260 ++++++++++++++++ .../gallery/ui/modelmanager/ModelList.kt | 286 ++++++++++++++---- .../ui/modelmanager/ModelManagerViewModel.kt | 165 +++++++--- .../ui/preview/PreviewDataStoreRepository.kt | 8 + .../preview/PreviewModelManagerViewModel.kt | 4 - .../TextClassificationModelHelper.kt | 4 +- 19 files changed, 789 insertions(+), 126 deletions(-) create mode 100644 Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelImportDialog.kt diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/DataStoreRepository.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/DataStoreRepository.kt index 0a61e1d..e65173a 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/DataStoreRepository.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/DataStoreRepository.kt @@ -47,6 +47,8 @@ interface DataStoreRepository { fun readThemeOverride(): String fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long) fun readAccessTokenData(): AccessTokenData? + fun saveLocalModels(localModels: List) + fun readLocalModels(): List } /** @@ -79,6 +81,9 @@ class DefaultDataStoreRepository( val REFRESH_TOKEN_IV = stringPreferencesKey("refresh_token_iv") val ACCESS_TOKEN_EXPIRES_AT = longPreferencesKey("access_token_expires_at") + + // Data for all imported local models. + val LOCAL_MODELS = stringPreferencesKey("local_models") } private val keystoreAlias: String = "com_google_aiedge_gallery_access_token_key" @@ -155,6 +160,26 @@ class DefaultDataStoreRepository( } } + override fun saveLocalModels(localModels: List) { + runBlocking { + dataStore.edit { preferences -> + val gson = Gson() + val jsonString = gson.toJson(localModels) + preferences[PreferencesKeys.LOCAL_MODELS] = jsonString + } + } + } + + override fun readLocalModels(): List { + return runBlocking { + val preferences = dataStore.data.first() + val infosStr = preferences[PreferencesKeys.LOCAL_MODELS] ?: "[]" + val gson = Gson() + val listType = object : TypeToken>() {}.type + gson.fromJson(infosStr, listType) + } + } + private fun getTextInputHistory(preferences: Preferences): List { val infosStr = preferences[PreferencesKeys.TEXT_INPUT_HISTORY] ?: "[]" val gson = Gson() diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/HuggingFace.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/HuggingFace.kt index b1fd347..6aa08ea 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/HuggingFace.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/HuggingFace.kt @@ -16,6 +16,7 @@ package com.google.aiedge.gallery.data +import com.google.aiedge.gallery.ui.common.ensureValidFileName import com.google.aiedge.gallery.ui.llmchat.createLLmChatConfig import kotlinx.serialization.KSerializer import kotlinx.serialization.Serializable @@ -103,7 +104,7 @@ data class HfModel( } else { listOf("") } - val fileName = "${id}_${(parts.lastOrNull() ?: "")}".replace(Regex("[^a-zA-Z0-9._-]"), "_") + val fileName = ensureValidFileName("${id}_${(parts.lastOrNull() ?: "")}") // Generate configs based on the given default values. val configs: List = when (task) { diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt index 54c8a4e..b72db89 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt @@ -32,6 +32,8 @@ enum class LlmBackend { CPU, GPU } +const val IMPORTS_DIR = "__imports" + /** A model for a task */ data class Model( /** The Hugging Face model ID (if applicable). */ @@ -85,6 +87,9 @@ data class Model( /** The prompt templates for the model (only for LLM). */ val llmPromptTemplates: List = listOf(), + /** Whether the model is imported as a local model. */ + val isLocalModel: Boolean = false, + // The following fields are managed by the app. Don't need to set manually. var taskType: TaskType? = null, var instance: Any? = null, @@ -104,10 +109,11 @@ data class Model( } fun getPath(context: Context, fileName: String = downloadFileName): String { + val baseDir = "${context.getExternalFilesDir(null)}" return if (this.isZip && this.unzipDir.isNotEmpty()) { - "${context.getExternalFilesDir(null)}/${this.unzipDir}" + "$baseDir/${this.unzipDir}" } else { - "${context.getExternalFilesDir(null)}/${fileName}" + "$baseDir/${fileName}" } } @@ -140,6 +146,9 @@ data class Model( } } +/** Data for a imported local model. */ +data class LocalModelInfo(val fileName: String, val fileSize: Long) + enum class ModelDownloadStatusType { NOT_DOWNLOADED, PARTIALLY_DOWNLOADED, IN_PROGRESS, UNZIPPING, SUCCEEDED, FAILED, } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Tasks.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Tasks.kt index 5b7c9b3..c1ca12e 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Tasks.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Tasks.kt @@ -19,6 +19,8 @@ package com.google.aiedge.gallery.data import androidx.annotation.StringRes import androidx.compose.material.icons.Icons import androidx.compose.material.icons.rounded.ImageSearch +import androidx.compose.runtime.MutableState +import androidx.compose.runtime.mutableStateOf import androidx.compose.ui.graphics.vector.ImageVector import com.google.aiedge.gallery.R @@ -63,7 +65,9 @@ data class Task( @StringRes val textInputPlaceHolderRes: Int = R.string.chat_textinput_placeholder, // The following fields are managed by the app. Don't need to set manually. - var index: Int = -1 + var index: Int = -1, + + val updateTrigger: MutableState = mutableStateOf(0) ) val TASK_TEXT_CLASSIFICATION = Task( diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/DownloadAndTryButton.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/DownloadAndTryButton.kt index ef2d334..6ec84a0 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/DownloadAndTryButton.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/DownloadAndTryButton.kt @@ -46,6 +46,7 @@ import androidx.compose.ui.Modifier import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.unit.dp import com.google.aiedge.gallery.data.Model +import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.modelmanager.TokenRequestResultType import com.google.aiedge.gallery.ui.modelmanager.TokenStatus @@ -90,6 +91,7 @@ private const val TAG = "AGDownloadAndTryButton" @OptIn(ExperimentalMaterial3Api::class) @Composable fun DownloadAndTryButton( + task: Task, model: Model, enabled: Boolean, needToDownloadFirst: Boolean, @@ -106,17 +108,18 @@ fun DownloadAndTryButton( val permissionLauncher = rememberLauncherForActivityResult( ActivityResultContracts.RequestPermission() ) { - modelManagerViewModel.downloadModel(model) + modelManagerViewModel.downloadModel(task = task, model = model) } // Function to kick off download. val startDownload: (accessToken: String?) -> Unit = { accessToken -> model.accessToken = accessToken onClicked() - checkNotificationPermissonAndStartDownload( + checkNotificationPermissionAndStartDownload( context = context, launcher = permissionLauncher, modelManagerViewModel = modelManagerViewModel, + task = task, model = model ) checkingToken = false diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/Utils.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/Utils.kt index 5059527..d9709fb 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/Utils.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/Utils.kt @@ -416,10 +416,11 @@ fun getTaskIconColor(index: Int): Color { return MaterialTheme.customColors.taskIconColors[colorIndex] } -fun checkNotificationPermissonAndStartDownload( +fun checkNotificationPermissionAndStartDownload( context: Context, launcher: ManagedActivityResultLauncher, modelManagerViewModel: ModelManagerViewModel, + task: Task, model: Model ) { // Check permission @@ -428,7 +429,7 @@ fun checkNotificationPermissonAndStartDownload( ContextCompat.checkSelfPermission( context, Manifest.permission.POST_NOTIFICATIONS ) -> { - modelManagerViewModel.downloadModel(model) + modelManagerViewModel.downloadModel(task = task, model = model) } // Otherwise, ask for permission @@ -440,3 +441,14 @@ fun checkNotificationPermissonAndStartDownload( } } +fun ensureValidFileName(fileName: String): String { + return fileName.replace(Regex("[^a-zA-Z0-9._-]"), "_") +} + +fun cleanUpMediapipeTaskErrorMessage(message: String): String { + val index = message.indexOf("=== Source Location Trace") + if (index >= 0) { + return message.substring(0, index) + } + return message +} diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt index 1ef6f68..15a4c2d 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt @@ -41,10 +41,13 @@ import androidx.compose.foundation.layout.wrapContentHeight import androidx.compose.foundation.lazy.LazyColumn import androidx.compose.foundation.lazy.items import androidx.compose.foundation.lazy.rememberLazyListState +import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.material.icons.Icons import androidx.compose.material.icons.outlined.Timer import androidx.compose.material.icons.rounded.ContentCopy import androidx.compose.material.icons.rounded.Refresh +import androidx.compose.material3.Button +import androidx.compose.material3.Card import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.Icon import androidx.compose.material3.MaterialTheme @@ -83,11 +86,12 @@ import androidx.compose.ui.res.stringResource import androidx.compose.ui.text.AnnotatedString import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.unit.dp +import androidx.compose.ui.window.Dialog import com.google.aiedge.gallery.R import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.TaskType -import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatus +import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.preview.PreviewChatModel import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel @@ -113,6 +117,7 @@ fun ChatPanel( onSendMessage: (Model, ChatMessage) -> Unit, onRunAgainClicked: (Model, ChatMessage) -> Unit, onBenchmarkClicked: (Model, ChatMessage, warmUpIterations: Int, benchmarkIterations: Int) -> Unit, + navigateUp: () -> Unit, modifier: Modifier = Modifier, onStreamImageMessage: (Model, ChatMessageImage) -> Unit = { _, _ -> }, onStreamEnd: (Int) -> Unit = {}, @@ -140,6 +145,8 @@ fun ChatPanel( var showMessageLongPressedSheet by remember { mutableStateOf(false) } val longPressedMessage: MutableState = remember { mutableStateOf(null) } + var showErrorDialog by remember { mutableStateOf(false) } + // Keep track of the last message and last message content. val lastMessage: MutableState = remember { mutableStateOf(null) } val lastMessageContent: MutableState = remember { mutableStateOf("") } @@ -201,6 +208,10 @@ fun ChatPanel( val modelInitializationStatus = modelManagerUiState.modelInitializationStatus[selectedModel.name] + LaunchedEffect(modelInitializationStatus) { + showErrorDialog = modelInitializationStatus?.status == ModelInitializationStatusType.ERROR + } + Column( modifier = modifier.imePadding() ) { @@ -417,7 +428,7 @@ fun ChatPanel( // Model initialization in-progress message. this@Column.AnimatedVisibility( - visible = modelInitializationStatus == ModelInitializationStatus.INITIALIZING, + visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING, enter = scaleIn() + fadeIn(), exit = scaleOut() + fadeOut(), modifier = Modifier.offset(y = 12.dp) @@ -479,6 +490,47 @@ fun ChatPanel( } } + // Error dialog. + if (showErrorDialog) { + Dialog( + onDismissRequest = { + showErrorDialog = false + navigateUp() + }, + ) { + Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) { + Column( + modifier = Modifier + .padding(20.dp), + verticalArrangement = Arrangement.spacedBy(16.dp) + ) { + // Title + Text( + "Error", + style = MaterialTheme.typography.titleLarge, + modifier = Modifier.padding(bottom = 8.dp) + ) + + // Error + Text( + modelInitializationStatus?.error ?: "", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.error, + ) + + Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.End) { + Button(onClick = { + showErrorDialog = false + navigateUp() + }) { + Text("Close") + } + } + } + } + } + } + // Benchmark config dialog. if (showBenchmarkConfigsDialog) { BenchmarkConfigDialog(onDismissed = { showBenchmarkConfigsDialog = false }, @@ -547,6 +599,7 @@ fun ChatPanelPreview() { task = task, selectedModel = TASK_TEST1.models[1], viewModel = PreviewChatModel(context = context), + navigateUp = {}, onSendMessage = { _, _ -> }, onRunAgainClicked = { _, _ -> }, onBenchmarkClicked = { _, _, _, _ -> }, diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatView.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatView.kt index fe7e5e0..22d1178 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatView.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatView.kt @@ -55,7 +55,7 @@ import com.google.aiedge.gallery.data.AppBarActionType import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.ModelDownloadStatusType import com.google.aiedge.gallery.data.Task -import com.google.aiedge.gallery.ui.common.checkNotificationPermissonAndStartDownload +import com.google.aiedge.gallery.ui.common.checkNotificationPermissionAndStartDownload import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.preview.PreviewChatModel import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel @@ -104,7 +104,7 @@ fun ChatView( val launcher = rememberLauncherForActivityResult( ActivityResultContracts.RequestPermission() ) { - modelManagerViewModel.downloadModel(selectedModel) + modelManagerViewModel.downloadModel(task = task, model = selectedModel) } val handleNavigateUp = { @@ -245,10 +245,11 @@ fun ChatView( exit = fadeOut() ) { ModelNotDownloaded(modifier = Modifier.weight(1f), onClicked = { - checkNotificationPermissonAndStartDownload( + checkNotificationPermissionAndStartDownload( context = context, launcher = launcher, modelManagerViewModel = modelManagerViewModel, + task = task, model = curSelectedModel ) }) @@ -261,6 +262,7 @@ fun ChatView( task = task, selectedModel = curSelectedModel, viewModel = viewModel, + navigateUp = navigateUp, onSendMessage = onSendMessage, onRunAgainClicked = onRunAgainClicked, onBenchmarkClicked = onBenchmarkClicked, diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItem.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItem.kt index 96f59c5..7e3d18f 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItem.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItem.kt @@ -35,6 +35,7 @@ import androidx.compose.foundation.layout.offset import androidx.compose.foundation.layout.padding import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.rounded.ChevronRight import androidx.compose.material.icons.rounded.Settings import androidx.compose.material.icons.rounded.UnfoldLess import androidx.compose.material.icons.rounded.UnfoldMore @@ -68,7 +69,7 @@ import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.ui.common.DownloadAndTryButton import com.google.aiedge.gallery.ui.common.TaskIcon import com.google.aiedge.gallery.ui.common.chat.MarkdownText -import com.google.aiedge.gallery.ui.common.checkNotificationPermissonAndStartDownload +import com.google.aiedge.gallery.ui.common.checkNotificationPermissionAndStartDownload import com.google.aiedge.gallery.ui.common.getTaskBgColor import com.google.aiedge.gallery.ui.common.getTaskIconColor import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel @@ -113,7 +114,7 @@ fun ModelItem( val launcher = rememberLauncherForActivityResult( ActivityResultContracts.RequestPermission() ) { - modelManagerViewModel.downloadModel(model) + modelManagerViewModel.downloadModel(task = task, model = model) } var isExpanded by remember { mutableStateOf(false) } @@ -156,10 +157,11 @@ fun ModelItem( modelManagerViewModel = modelManagerViewModel, downloadStatus = downloadStatus, onDownloadClicked = { model -> - checkNotificationPermissonAndStartDownload( + checkNotificationPermissionAndStartDownload( context = context, launcher = launcher, modelManagerViewModel = modelManagerViewModel, + task = task, model = model ) }, @@ -186,7 +188,9 @@ fun ModelItem( } } else { Icon( - if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore, + // For local model, show ">" directly indicating users can just tap the model item to + // go into it without needing to expand it first. + if (model.isLocalModel) Icons.Rounded.ChevronRight else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore, contentDescription = "", tint = getTaskIconColor(task), ) @@ -237,6 +241,7 @@ fun ModelItem( val needToDownloadFirst = downloadStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED || downloadStatus?.status == ModelDownloadStatusType.FAILED DownloadAndTryButton( + task = task, model = model, enabled = isExpanded, needToDownloadFirst = needToDownloadFirst, @@ -266,7 +271,13 @@ fun ModelItem( ) boxModifier = if (canExpand) { boxModifier.clickable( - onClick = { isExpanded = !isExpanded }, + onClick = { + if (!model.isLocalModel) { + isExpanded = !isExpanded + } else { + onModelClicked(model) + } + }, interactionSource = remember { MutableInteractionSource() }, indication = ripple( bounded = true, diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItemActionButton.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItemActionButton.kt index 4ec8295..0fa31bd 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItemActionButton.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItemActionButton.kt @@ -124,7 +124,7 @@ fun ModelItemActionButton( if (showConfirmDeleteDialog) { ConfirmDeleteModelDialog(model = model, onConfirm = { - modelManagerViewModel.deleteModel(model) + modelManagerViewModel.deleteModel(task = task, model = model) showConfirmDeleteDialog = false }, onDismiss = { showConfirmDeleteDialog = false diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/imageclassification/ImageClassificationModelHelper.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/imageclassification/ImageClassificationModelHelper.kt index b9415b2..0e05b07 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/imageclassification/ImageClassificationModelHelper.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/imageclassification/ImageClassificationModelHelper.kt @@ -48,7 +48,7 @@ class ImageClassificationInferenceResult( //TODO: handle error. object ImageClassificationModelHelper { - fun initialize(context: Context, model: Model, onDone: () -> Unit) { + fun initialize(context: Context, model: Model, onDone: (String) -> Unit) { val useGpu = model.getBooleanConfigValue(key = ConfigKey.USE_GPU) TfLiteGpu.isGpuDelegateAvailable(context).continueWith { gpuTask -> val optionsBuilder = TfLiteInitializationOptions.builder() @@ -69,7 +69,7 @@ object ImageClassificationModelHelper { File(model.getPath(context = context)), interpreterOption ) model.instance = interpreter - onDone() + onDone("") } } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/imagegeneration/ImageGenerationModelHelper.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/imagegeneration/ImageGenerationModelHelper.kt index dfd4d83..019930c 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/imagegeneration/ImageGenerationModelHelper.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/imagegeneration/ImageGenerationModelHelper.kt @@ -24,6 +24,7 @@ import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator import com.google.aiedge.gallery.data.ConfigKey import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.ui.common.LatencyProvider +import com.google.aiedge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage import kotlin.random.Random private const val TAG = "AGImageGenerationModelHelper" @@ -33,12 +34,17 @@ class ImageGenerationInferenceResult( ) : LatencyProvider object ImageGenerationModelHelper { - fun initialize(context: Context, model: Model, onDone: () -> Unit) { - val options = ImageGenerator.ImageGeneratorOptions.builder() - .setImageGeneratorModelDirectory(model.getPath(context = context)) - .build() - model.instance = ImageGenerator.createFromOptions(context, options) - onDone() + fun initialize(context: Context, model: Model, onDone: (String) -> Unit) { + try { + val options = ImageGenerator.ImageGeneratorOptions.builder() + .setImageGeneratorModelDirectory(model.getPath(context = context)) + .build() + model.instance = ImageGenerator.createFromOptions(context, options) + } catch (e: Exception) { + onDone(cleanUpMediapipeTaskErrorMessage(e.message ?: "Unknown error")) + return + } + onDone("") } fun cleanUp(model: Model) { diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt index 15cfe30..aa01ac1 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt @@ -21,6 +21,7 @@ import android.util.Log import com.google.aiedge.gallery.data.ConfigKey import com.google.aiedge.gallery.data.LlmBackend import com.google.aiedge.gallery.data.Model +import com.google.aiedge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage import com.google.mediapipe.tasks.genai.llminference.LlmInference import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession @@ -40,7 +41,7 @@ object LlmChatModelHelper { private val cleanUpListeners: MutableMap = mutableMapOf() fun initialize( - context: Context, model: Model, onDone: () -> Unit + context: Context, model: Model, onDone: (String) -> Unit ) { val maxTokens = model.getIntConfigValue(key = ConfigKey.MAX_TOKENS, defaultValue = DEFAULT_MAX_TOKEN) @@ -68,9 +69,10 @@ object LlmChatModelHelper { ) model.instance = LlmModelInstance(engine = llmInference, session = session) } catch (e: Exception) { - e.printStackTrace() + onDone(cleanUpMediapipeTaskErrorMessage(e.message ?: "Unknown error")) + return } - onDone() + onDone("") } fun cleanUp(model: Model) { diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelImportDialog.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelImportDialog.kt new file mode 100644 index 0000000..d8b26f9 --- /dev/null +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelImportDialog.kt @@ -0,0 +1,260 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.aiedge.gallery.ui.modelmanager + +import android.content.Context +import android.net.Uri +import android.provider.OpenableColumns +import android.util.Log +import androidx.compose.animation.core.Animatable +import androidx.compose.animation.core.tween +import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.rounded.Error +import androidx.compose.material3.Button +import androidx.compose.material3.Card +import androidx.compose.material3.Icon +import androidx.compose.material3.LinearProgressIndicator +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Text +import androidx.compose.runtime.Composable +import androidx.compose.runtime.LaunchedEffect +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableFloatStateOf +import androidx.compose.runtime.mutableLongStateOf +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember +import androidx.compose.runtime.rememberCoroutineScope +import androidx.compose.runtime.setValue +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.platform.LocalContext +import androidx.compose.ui.unit.dp +import androidx.compose.ui.window.Dialog +import androidx.compose.ui.window.DialogProperties +import com.google.aiedge.gallery.data.IMPORTS_DIR +import com.google.aiedge.gallery.ui.common.ensureValidFileName +import com.google.aiedge.gallery.ui.common.humanReadableSize +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.launch +import java.io.File +import java.io.FileOutputStream +import java.net.URLDecoder +import java.nio.charset.StandardCharsets + +private const val TAG = "AGModelImportDialog" + +data class ModelImportInfo(val fileName: String, val fileSize: Long, val error: String = "") + +@Composable +fun ModelImportDialog( + uri: Uri, onDone: (ModelImportInfo) -> Unit +) { + val context = LocalContext.current + val coroutineScope = rememberCoroutineScope() + + var fileName by remember { mutableStateOf("") } + var fileSize by remember { mutableLongStateOf(0L) } + var error by remember { mutableStateOf("") } + var progress by remember { mutableFloatStateOf(0f) } + + LaunchedEffect(Unit) { + error = "" + + // Get basic info. + val info = getFileSizeAndDisplayNameFromUri(context = context, uri = uri) + fileSize = info.first + fileName = ensureValidFileName(info.second) + + // Import. + importModel( + context = context, + coroutineScope = coroutineScope, + fileName = fileName, + fileSize = fileSize, + uri = uri, + onDone = { + onDone(ModelImportInfo(fileName = fileName, fileSize = fileSize, error = error)) + }, + onProgress = { + progress = it + }, + onError = { + error = it + } + ) + } + + Dialog( + properties = DialogProperties(dismissOnBackPress = false, dismissOnClickOutside = false), + onDismissRequest = {}, + ) { + Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) { + Column( + modifier = Modifier + .padding(20.dp), + verticalArrangement = Arrangement.spacedBy(16.dp) + ) { + // Title. + Text( + "Importing...", + style = MaterialTheme.typography.titleLarge, + modifier = Modifier.padding(bottom = 8.dp) + ) + + // No error. + if (error.isEmpty()) { + // Progress bar. + Column(verticalArrangement = Arrangement.spacedBy(4.dp)) { + Text( + "$fileName (${fileSize.humanReadableSize()})", + style = MaterialTheme.typography.labelSmall, + ) + val animatedProgress = remember { Animatable(0f) } + LinearProgressIndicator( + progress = { animatedProgress.value }, + modifier = Modifier + .fillMaxWidth() + .padding(bottom = 8.dp), + ) + LaunchedEffect(progress) { + animatedProgress.animateTo(progress, animationSpec = tween(150)) + } + } + } + // Has error. + else { + Row( + verticalAlignment = Alignment.Top, + horizontalArrangement = Arrangement.spacedBy(6.dp) + ) { + Icon( + Icons.Rounded.Error, + contentDescription = "", + tint = MaterialTheme.colorScheme.error + ) + Text( + error, + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.error, + modifier = Modifier.padding(top = 4.dp) + ) + } + Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.End) { + Button(onClick = { + onDone(ModelImportInfo(fileName = "", fileSize = 0L, error = error)) + }) { + Text("Close") + } + } + } + } + } + } +} + +private fun importModel( + context: Context, + coroutineScope: CoroutineScope, + fileName: String, + fileSize: Long, + uri: Uri, + onDone: () -> Unit, + onProgress: (Float) -> Unit, + onError: (String) -> Unit, +) { + // TODO: handle error. + coroutineScope.launch(Dispatchers.IO) { + // Get the last component of the uri path as the imported file name. + val decodedUri = URLDecoder.decode(uri.toString(), StandardCharsets.UTF_8.name()) + Log.d(TAG, "importing model from $decodedUri. File name: $fileName. File size: $fileSize") + + // Create /imports if not exist. + val importsDir = File(context.getExternalFilesDir(null), IMPORTS_DIR) + if (!importsDir.exists()) { + importsDir.mkdirs() + } + + // Import by copying the file over. + val outputFile = File(context.getExternalFilesDir(null), "$IMPORTS_DIR/$fileName") + val outputStream = FileOutputStream(outputFile) + val buffer = ByteArray(DEFAULT_BUFFER_SIZE) + var bytesRead: Int + var lastSetProgressTs: Long = 0 + var importedBytes = 0L + val inputStream = context.contentResolver.openInputStream(uri) + try { + if (inputStream != null) { + while (inputStream.read(buffer).also { bytesRead = it } != -1) { + outputStream.write(buffer, 0, bytesRead) + importedBytes += bytesRead + + // Report progress every 200 ms. + val curTs = System.currentTimeMillis() + if (curTs - lastSetProgressTs > 200) { + Log.d(TAG, "importing progress: $importedBytes, $fileSize") + lastSetProgressTs = curTs + if (fileSize != 0L) { + onProgress(importedBytes.toFloat() / fileSize.toFloat()) + } + } + } + } + } catch (e: Exception) { + e.printStackTrace() + onError(e.message ?: "Failed to import") + return@launch + } finally { + inputStream?.close() + outputStream.close() + } + Log.d(TAG, "import done") + onProgress(1f) + onDone() + } +} + +private fun getFileSizeAndDisplayNameFromUri(context: Context, uri: Uri): Pair { + val contentResolver = context.contentResolver + var fileSize = 0L + var displayName = "" + + try { + contentResolver.query( + uri, arrayOf(OpenableColumns.SIZE, OpenableColumns.DISPLAY_NAME), null, null, null + )?.use { cursor -> + if (cursor.moveToFirst()) { + val sizeIndex = cursor.getColumnIndexOrThrow(OpenableColumns.SIZE) + fileSize = cursor.getLong(sizeIndex) + + val nameIndex = cursor.getColumnIndexOrThrow(OpenableColumns.DISPLAY_NAME) + displayName = cursor.getString(nameIndex) + } + } + } catch (e: Exception) { + e.printStackTrace() + return Pair(0L, "") + } + + return Pair(fileSize, displayName) +} diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelList.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelList.kt index abcb0ee..ba1f8f3 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelList.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelList.kt @@ -16,24 +16,46 @@ package com.google.aiedge.gallery.ui.modelmanager +import android.content.Intent +import android.net.Uri +import android.os.Build +import android.util.Log +import androidx.activity.compose.rememberLauncherForActivityResult +import androidx.activity.result.ActivityResultLauncher +import androidx.activity.result.contract.ActivityResultContracts +import androidx.annotation.RequiresApi import androidx.compose.foundation.clickable import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.PaddingValues import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.height import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.size import androidx.compose.foundation.lazy.LazyColumn import androidx.compose.foundation.lazy.items import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.automirrored.outlined.NoteAdd +import androidx.compose.material.icons.filled.Add import androidx.compose.material.icons.outlined.Code import androidx.compose.material.icons.outlined.Description +import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.Icon import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.ModalBottomSheet +import androidx.compose.material3.SmallFloatingActionButton import androidx.compose.material3.Text +import androidx.compose.material3.rememberModalBottomSheetState import androidx.compose.runtime.Composable +import androidx.compose.runtime.derivedStateOf +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember +import androidx.compose.runtime.rememberCoroutineScope +import androidx.compose.runtime.setValue import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.graphics.vector.ImageVector @@ -53,8 +75,14 @@ import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel import com.google.aiedge.gallery.ui.preview.TASK_TEST1 import com.google.aiedge.gallery.ui.theme.GalleryTheme import com.google.aiedge.gallery.ui.theme.customColors +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch + +private const val TAG = "AGModelList" /** The list of models in the model manager. */ +@RequiresApi(Build.VERSION_CODES.O) +@OptIn(ExperimentalMaterial3Api::class) @Composable fun ModelList( task: Task, @@ -63,65 +91,214 @@ fun ModelList( onModelClicked: (Model) -> Unit, modifier: Modifier = Modifier, ) { - LazyColumn( - modifier = modifier.padding(top = 8.dp), - contentPadding = contentPadding, - verticalArrangement = Arrangement.spacedBy(8.dp), - ) { - // Headline. - item(key = "headline") { - Text( - task.description, - textAlign = TextAlign.Center, - style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold), - modifier = Modifier - .fillMaxWidth() - ) - } + var showAddModelSheet by remember { mutableStateOf(false) } + var showImportingDialog by remember { mutableStateOf(false) } + val curFileUri = remember { mutableStateOf(null) } + val sheetState = rememberModalBottomSheetState() + val coroutineScope = rememberCoroutineScope() - // URLs. - item(key = "urls") { - Row( - horizontalArrangement = Arrangement.Center, - modifier = Modifier - .fillMaxWidth() - .padding(top = 12.dp, bottom = 16.dp), - ) { - Column( - horizontalAlignment = Alignment.Start, - verticalArrangement = Arrangement.spacedBy(4.dp), + // This is just to update "models" list when task.updateTrigger is updated so that the UI can + // be properly updated. + val models by remember { + derivedStateOf { + val trigger = task.updateTrigger.value + if (trigger >= 0) { + task.models.toList().filter { !it.isLocalModel } + } else { + listOf() + } + } + } + val localModels by remember { + derivedStateOf { + val trigger = task.updateTrigger.value + if (trigger >= 0) { + task.models.toList().filter { it.isLocalModel } + } else { + listOf() + } + } + } + + val filePickerLauncher: ActivityResultLauncher = rememberLauncherForActivityResult( + contract = ActivityResultContracts.StartActivityForResult() + ) { result -> + if (result.resultCode == android.app.Activity.RESULT_OK) { + result.data?.data?.let { uri -> + curFileUri.value = uri + showImportingDialog = true + } ?: run { + Log.d(TAG, "No file selected or URI is null.") + } + } else { + Log.d(TAG, "File picking cancelled.") + } + } + + Box(contentAlignment = Alignment.BottomEnd) { + LazyColumn( + modifier = modifier.padding(top = 8.dp), + contentPadding = contentPadding, + verticalArrangement = Arrangement.spacedBy(8.dp), + ) { + // Headline. + item(key = "headline") { + Text( + task.description, + textAlign = TextAlign.Center, + style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold), + modifier = Modifier.fillMaxWidth() + ) + } + + // URLs. + item(key = "urls") { + Row( + horizontalArrangement = Arrangement.Center, + modifier = Modifier + .fillMaxWidth() + .padding(top = 12.dp, bottom = 16.dp), ) { - if (task.docUrl.isNotEmpty()) { - ClickableLink( - url = task.docUrl, - linkText = "API Documentation", - icon = Icons.Outlined.Description - ) - } - if (task.sourceCodeUrl.isNotEmpty()) { - ClickableLink( - url = task.sourceCodeUrl, - linkText = "Example code", - icon = Icons.Outlined.Code - ) + Column( + horizontalAlignment = Alignment.Start, + verticalArrangement = Arrangement.spacedBy(4.dp), + ) { + if (task.docUrl.isNotEmpty()) { + ClickableLink( + url = task.docUrl, linkText = "API Documentation", icon = Icons.Outlined.Description + ) + } + if (task.sourceCodeUrl.isNotEmpty()) { + ClickableLink( + url = task.sourceCodeUrl, linkText = "Example code", icon = Icons.Outlined.Code + ) + } } } } + + // List of models within a task. + items(items = models) { model -> + Box { + ModelItem( + model = model, + task = task, + modelManagerViewModel = modelManagerViewModel, + onModelClicked = onModelClicked, + modifier = Modifier.padding(horizontal = 12.dp) + ) + } + } + + // Title for local models. + if (localModels.isNotEmpty()) { + item(key = "localModelsTitle") { + Text( + "Local models", + style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold), + modifier = Modifier + .padding(horizontal = 16.dp) + .padding(top = 24.dp) + ) + } + } + + // List of local models within a task. + items(items = localModels) { model -> + Box { + ModelItem( + model = model, + task = task, + modelManagerViewModel = modelManagerViewModel, + onModelClicked = onModelClicked, + modifier = Modifier.padding(horizontal = 12.dp) + ) + } + } + + item(key = "bottomPadding") { + Spacer(modifier = Modifier.height(60.dp)) + } } - // List of models within a task. - items(items = task.models) { model -> - Box { - ModelItem( - model = model, - task = task, - modelManagerViewModel = modelManagerViewModel, - onModelClicked = onModelClicked, - modifier = Modifier.padding(start = 12.dp, end = 12.dp) - ) + // Add model button at the bottom right. + Box( + modifier = Modifier + .padding(end = 16.dp) + .padding(bottom = contentPadding.calculateBottomPadding()) + ) { + SmallFloatingActionButton( + onClick = { + showAddModelSheet = true + }, + containerColor = MaterialTheme.colorScheme.secondaryContainer, + contentColor = MaterialTheme.colorScheme.secondary, + ) { + Icon(Icons.Filled.Add, "") } } } + + if (showAddModelSheet) { + ModalBottomSheet( + onDismissRequest = { showAddModelSheet = false }, + sheetState = sheetState, + ) { + Text( + "Add custom model", + style = MaterialTheme.typography.titleLarge, + modifier = Modifier.padding(vertical = 4.dp, horizontal = 16.dp) + ) + Box(modifier = Modifier.clickable { + coroutineScope.launch { + // Give it sometime to show the click effect. + delay(200) + showAddModelSheet = false + + // Show file picker. + val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply { + addCategory(Intent.CATEGORY_OPENABLE) + type = "*/*" + putExtra( + Intent.EXTRA_MIME_TYPES, + arrayOf("application/x-binary", "application/octet-stream") + ) + // Single select. + putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false) + } + filePickerLauncher.launch(intent) + } + }) { + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(6.dp), + modifier = Modifier + .fillMaxWidth() + .padding(16.dp) + ) { + Icon(Icons.AutoMirrored.Outlined.NoteAdd, contentDescription = "") + Text("Add local model") + } + } + } + } + + if (showImportingDialog) { + curFileUri.value?.let { uri -> + ModelImportDialog(uri = uri, onDone = { info -> + showImportingDialog = false + + if (info.error.isEmpty()) { + // TODO: support other model types. + modelManagerViewModel.addLocalLlmModel( + task = task, + fileName = info.fileName, + fileSize = info.fileSize + ) + } + }) + } + } } @Composable @@ -132,15 +309,11 @@ fun ClickableLink( ) { val uriHandler = LocalUriHandler.current val annotatedText = AnnotatedString( - text = linkText, - spanStyles = listOf( + text = linkText, spanStyles = listOf( AnnotatedString.Range( item = SpanStyle( - color = MaterialTheme.customColors.linkColor, - textDecoration = TextDecoration.Underline - ), - start = 0, - end = linkText.length + color = MaterialTheme.customColors.linkColor, textDecoration = TextDecoration.Underline + ), start = 0, end = linkText.length ) ) ) @@ -163,6 +336,7 @@ fun ClickableLink( } } +@RequiresApi(Build.VERSION_CODES.O) @Preview(showBackground = true) @Composable fun ModelListPreview() { diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt index 6ed644c..c4d3411 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt @@ -24,16 +24,20 @@ import androidx.lifecycle.ViewModel import androidx.lifecycle.viewModelScope import com.google.aiedge.gallery.data.AGWorkInfo import com.google.aiedge.gallery.data.AccessTokenData +import com.google.aiedge.gallery.data.Config import com.google.aiedge.gallery.data.DataStoreRepository import com.google.aiedge.gallery.data.DownloadRepository import com.google.aiedge.gallery.data.EMPTY_MODEL import com.google.aiedge.gallery.data.HfModel import com.google.aiedge.gallery.data.HfModelDetails import com.google.aiedge.gallery.data.HfModelSummary +import com.google.aiedge.gallery.data.IMPORTS_DIR +import com.google.aiedge.gallery.data.LocalModelInfo import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.ModelDownloadStatus import com.google.aiedge.gallery.data.ModelDownloadStatusType import com.google.aiedge.gallery.data.TASKS +import com.google.aiedge.gallery.data.TASK_LLM_CHAT import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.TaskType import com.google.aiedge.gallery.data.getModelByName @@ -41,6 +45,7 @@ import com.google.aiedge.gallery.ui.common.AuthConfig import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper +import com.google.aiedge.gallery.ui.llmchat.createLLmChatConfig import com.google.aiedge.gallery.ui.textclassification.TextClassificationModelHelper import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.async @@ -66,8 +71,12 @@ private const val TAG = "AGModelManagerViewModel" private const val HG_COMMUNITY = "jinjingforevercommunity" private const val TEXT_INPUT_HISTORY_MAX_SIZE = 50 -enum class ModelInitializationStatus { - NOT_INITIALIZED, INITIALIZING, INITIALIZED, +data class ModelInitializationStatus( + val status: ModelInitializationStatusType, var error: String = "" +) + +enum class ModelInitializationStatusType { + NOT_INITIALIZED, INITIALIZING, INITIALIZED, ERROR } enum class TokenStatus { @@ -84,8 +93,7 @@ data class TokenStatusAndData( ) data class TokenRequestResult( - val status: TokenRequestResultType, - val errorMessage: String? = null + val status: TokenRequestResultType, val errorMessage: String? = null ) data class ModelManagerUiState( @@ -94,11 +102,6 @@ data class ModelManagerUiState( */ val tasks: List, - /** - * A map that stores lists of models indexed by task name. - */ - val modelsByTaskName: Map>, - /** * A map that tracks the download status of each model, indexed by model name. */ @@ -191,14 +194,14 @@ open class ModelManagerViewModel( _uiState.update { _uiState.value.copy(selectedModel = model) } } - fun downloadModel(model: Model) { + fun downloadModel(task: Task, model: Model) { // Update status. setDownloadStatus( curModel = model, status = ModelDownloadStatus(status = ModelDownloadStatusType.IN_PROGRESS) ) // Delete the model files first. - deleteModel(model = model) + deleteModel(task = task, model = model) // Start to send download request. downloadRepository.downloadModel( @@ -210,7 +213,7 @@ open class ModelManagerViewModel( downloadRepository.cancelDownloadModel(model) } - fun deleteModel(model: Model) { + fun deleteModel(task: Task, model: Model) { deleteFileFromExternalFilesDir(model.downloadFileName) for (file in model.extraDataFiles) { deleteFileFromExternalFilesDir(file.downloadFileName) @@ -223,6 +226,24 @@ open class ModelManagerViewModel( val curModelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap() curModelDownloadStatus[model.name] = ModelDownloadStatus(status = ModelDownloadStatusType.NOT_DOWNLOADED) + + // Delete model from the list if model is imported as a local model. + if (model.isLocalModel) { + val index = task.models.indexOf(model) + if (index >= 0) { + task.models.removeAt(index) + } + task.updateTrigger.value = System.currentTimeMillis() + curModelDownloadStatus.remove(model.name) + + // Update preference. + val localModels = dataStoreRepository.readLocalModels().toMutableList() + val localModelIndex = localModels.indexOfFirst { it.fileName == model.name } + if (localModelIndex >= 0) { + localModels.removeAt(localModelIndex) + } + dataStoreRepository.saveLocalModels(localModels = localModels) + } val newUiState = uiState.value.copy(modelDownloadStatus = curModelDownloadStatus) _uiState.update { newUiState } } @@ -230,7 +251,7 @@ open class ModelManagerViewModel( fun initializeModel(context: Context, model: Model, force: Boolean = false) { viewModelScope.launch(Dispatchers.Default) { // Skip if initialized already. - if (!force && uiState.value.modelInitializationStatus[model.name] == ModelInitializationStatus.INITIALIZED) { + if (!force && uiState.value.modelInitializationStatus[model.name]?.status == ModelInitializationStatusType.INITIALIZED) { Log.d(TAG, "Model '${model.name}' has been initialized. Skipping.") return@launch } @@ -252,20 +273,27 @@ open class ModelManagerViewModel( // been initialized or not. If so, skip. launch { delay(500) - if (model.instance == null) { + if (model.instance == null && model.initializing) { updateModelInitializationStatus( - model = model, status = ModelInitializationStatus.INITIALIZING + model = model, status = ModelInitializationStatusType.INITIALIZING ) } } - val onDone: () -> Unit = { + val onDone: (error: String) -> Unit = { error -> + model.initializing = false if (model.instance != null) { Log.d(TAG, "Model '${model.name}' initialized successfully") - model.initializing = false updateModelInitializationStatus( model = model, - status = ModelInitializationStatus.INITIALIZED, + status = ModelInitializationStatusType.INITIALIZED, + ) + } else if (error.isNotEmpty()) { + Log.d(TAG, "Model '${model.name}' failed to initialize") + updateModelInitializationStatus( + model = model, + status = ModelInitializationStatusType.ERROR, + error = error, ) } } @@ -310,7 +338,7 @@ open class ModelManagerViewModel( model.instance = null model.initializing = false updateModelInitializationStatus( - model = model, status = ModelInitializationStatus.NOT_INITIALIZED + model = model, status = ModelInitializationStatusType.NOT_INITIALIZED ) } } @@ -380,8 +408,7 @@ open class ModelManagerViewModel( val connection = url.openConnection() as HttpURLConnection if (accessToken != null) { connection.setRequestProperty( - "Authorization", - "Bearer $accessToken" + "Authorization", "Bearer $accessToken" ) } connection.connect() @@ -390,6 +417,47 @@ open class ModelManagerViewModel( return connection.responseCode } + fun addLocalLlmModel(task: Task, fileName: String, fileSize: Long) { + Log.d(TAG, "adding local model: $fileName, $fileSize") + + // Create model. + val configs: List = createLLmChatConfig(defaults = mapOf()) + val model = Model( + name = fileName, + url = "", + configs = configs, + sizeInBytes = fileSize, + downloadFileName = "$IMPORTS_DIR/$fileName", + isLocalModel = true, + ) + model.preProcess(task = task) + task.models.add(model) + + // Add initial status and states. + val modelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap() + val modelInstances = uiState.value.modelInitializationStatus.toMutableMap() + modelDownloadStatus[model.name] = ModelDownloadStatus( + status = ModelDownloadStatusType.SUCCEEDED, receivedBytes = fileSize, totalBytes = fileSize + ) + modelInstances[model.name] = + ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED) + + // Update ui state. + _uiState.update { + uiState.value.copy( + tasks = uiState.value.tasks.toList(), + modelDownloadStatus = modelDownloadStatus, + modelInitializationStatus = modelInstances + ) + } + task.updateTrigger.value = System.currentTimeMillis() + + // Add to preference storage. + val localModels = dataStoreRepository.readLocalModels().toMutableList() + localModels.add(LocalModelInfo(fileName = fileName, fileSize = fileSize)) + dataStoreRepository.saveLocalModels(localModels = localModels) + } + fun getTokenStatusAndData(): TokenStatusAndData { // Try to load token data from DataStore. var tokenStatus = TokenStatus.NOT_STORED @@ -436,8 +504,7 @@ open class ModelManagerViewModel( if (dataIntent == null) { onTokenRequested( TokenRequestResult( - status = TokenRequestResultType.FAILED, - errorMessage = "Empty auth result" + status = TokenRequestResultType.FAILED, errorMessage = "Empty auth result" ) ) return @@ -481,8 +548,7 @@ open class ModelManagerViewModel( } else { onTokenRequested( TokenRequestResult( - status = TokenRequestResultType.FAILED, - errorMessage = errorMessage + status = TokenRequestResultType.FAILED, errorMessage = errorMessage ) ) } @@ -513,23 +579,49 @@ open class ModelManagerViewModel( } private fun createUiState(): ModelManagerUiState { - val modelsByTaskName: Map> = - TASKS.associate { task -> task.type.label to task.models } val modelDownloadStatus: MutableMap = mutableMapOf() val modelInstances: MutableMap = mutableMapOf() - for ((_, models) in modelsByTaskName.entries) { - for (model in models) { + for (task in TASKS) { + for (model in task.models) { modelDownloadStatus[model.name] = getModelDownloadStatus(model = model) - modelInstances[model.name] = ModelInitializationStatus.NOT_INITIALIZED + modelInstances[model.name] = + ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED) } } + // Load local models. + for (localModel in dataStoreRepository.readLocalModels()) { + Log.d(TAG, "stored local model: $localModel") + + // Create model. + val configs: List = createLLmChatConfig(defaults = mapOf()) + val model = Model( + name = localModel.fileName, + url = "", + configs = configs, + sizeInBytes = localModel.fileSize, + downloadFileName = "$IMPORTS_DIR/${localModel.fileName}", + isLocalModel = true, + ) + + // Add to task. + val task = TASK_LLM_CHAT + model.preProcess(task = task) + task.models.add(model) + + // Update status. + modelDownloadStatus[model.name] = ModelDownloadStatus( + status = ModelDownloadStatusType.SUCCEEDED, + receivedBytes = localModel.fileSize, + totalBytes = localModel.fileSize + ) + } + val textInputHistory = dataStoreRepository.readTextInputHistory() Log.d(TAG, "text input history: $textInputHistory") return ModelManagerUiState( tasks = TASKS, - modelsByTaskName = modelsByTaskName, modelDownloadStatus = modelDownloadStatus, modelInitializationStatus = modelInstances, textInputHistory = textInputHistory, @@ -610,7 +702,8 @@ open class ModelManagerViewModel( // Add initial status and states. modelDownloadStatus[model.name] = getModelDownloadStatus(model = model) - modelInstances[model.name] = ModelInitializationStatus.NOT_INITIALIZED + modelInstances[model.name] = + ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED) } } } @@ -677,9 +770,13 @@ open class ModelManagerViewModel( } } - private fun updateModelInitializationStatus(model: Model, status: ModelInitializationStatus) { + private fun updateModelInitializationStatus( + model: Model, + status: ModelInitializationStatusType, + error: String = "" + ) { val curModelInstance = uiState.value.modelInitializationStatus.toMutableMap() - curModelInstance[model.name] = status + curModelInstance[model.name] = ModelInitializationStatus(status = status, error = error) val newUiState = uiState.value.copy(modelInitializationStatus = curModelInstance) _uiState.update { newUiState } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/preview/PreviewDataStoreRepository.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/preview/PreviewDataStoreRepository.kt index 1e41c38..d6941c7 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/preview/PreviewDataStoreRepository.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/preview/PreviewDataStoreRepository.kt @@ -18,6 +18,7 @@ package com.google.aiedge.gallery.ui.preview import com.google.aiedge.gallery.data.AccessTokenData import com.google.aiedge.gallery.data.DataStoreRepository +import com.google.aiedge.gallery.data.LocalModelInfo class PreviewDataStoreRepository : DataStoreRepository { override fun saveTextInputHistory(history: List) { @@ -40,4 +41,11 @@ class PreviewDataStoreRepository : DataStoreRepository { override fun readAccessTokenData(): AccessTokenData? { return null } + + override fun saveLocalModels(localModels: List) { + } + + override fun readLocalModels(): List { + return listOf() + } } \ No newline at end of file diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/preview/PreviewModelManagerViewModel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/preview/PreviewModelManagerViewModel.kt index 31db9ee..407c2b1 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/preview/PreviewModelManagerViewModel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/preview/PreviewModelManagerViewModel.kt @@ -17,7 +17,6 @@ package com.google.aiedge.gallery.ui.preview import android.content.Context -import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.ModelDownloadStatus import com.google.aiedge.gallery.data.ModelDownloadStatusType import com.google.aiedge.gallery.ui.modelmanager.ModelManagerUiState @@ -39,8 +38,6 @@ class PreviewModelManagerViewModel(context: Context) : } } - val modelsByTaskName: Map> = - ALL_PREVIEW_TASKS.associate { task -> task.type.label to task.models } val modelDownloadStatus = mapOf( MODEL_TEST1.name to ModelDownloadStatus( status = ModelDownloadStatusType.IN_PROGRESS, @@ -61,7 +58,6 @@ class PreviewModelManagerViewModel(context: Context) : ) val newUiState = ModelManagerUiState( tasks = ALL_PREVIEW_TASKS, - modelsByTaskName = modelsByTaskName, modelDownloadStatus = modelDownloadStatus, modelInitializationStatus = mapOf(), selectedModel = MODEL_TEST2, diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/textclassification/TextClassificationModelHelper.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/textclassification/TextClassificationModelHelper.kt index 91d5a81..3a3ab01 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/textclassification/TextClassificationModelHelper.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/textclassification/TextClassificationModelHelper.kt @@ -40,14 +40,14 @@ class TextClassificationInferenceResult( * Helper object for managing text classification models. */ object TextClassificationModelHelper { - fun initialize(context: Context, model: Model, onDone: () -> Unit) { + fun initialize(context: Context, model: Model, onDone: (String) -> Unit) { val modelByteBuffer = readFileToByteBuffer(File(model.getPath(context = context))) if (modelByteBuffer != null) { val options = TextClassifier.TextClassifierOptions.builder().setBaseOptions( BaseOptions.builder().setModelAssetBuffer(modelByteBuffer).build() ).build() model.instance = TextClassifier.createFromOptions(context, options) - onDone() + onDone("") } }