From ef290cd7b0b8fac1266687539a458003ecd76f6d Mon Sep 17 00:00:00 2001 From: Jing Jin <8752427+jinjingforever@users.noreply.github.com> Date: Thu, 15 May 2025 00:31:45 -0700 Subject: [PATCH] Add initial support for model allowlist, and stop generating response. --- Android/src/app/src/main/AndroidManifest.xml | 2 +- .../aiedge/gallery/GalleryApplication.kt | 4 - .../data/{HuggingFace.kt => ConfigValue.kt} | 68 ---- .../com/google/aiedge/gallery/data/Consts.kt | 2 + .../aiedge/gallery/data/DownloadRepository.kt | 8 +- .../com/google/aiedge/gallery/data/Model.kt | 35 +- .../aiedge/gallery/data/ModelAllowlist.kt | 116 +++++++ .../com/google/aiedge/gallery/data/Tasks.kt | 28 +- .../gallery/ui/common/DownloadAndTryButton.kt | 38 ++- .../google/aiedge/gallery/ui/common/Utils.kt | 2 + .../gallery/ui/common/chat/ChatMessage.kt | 5 + .../gallery/ui/common/chat/ChatPanel.kt | 4 + .../ui/common/chat/MessageBodyWarning.kt | 62 ++++ .../ui/common/chat/MessageInputText.kt | 29 +- .../gallery/ui/common/modelitem/ModelItem.kt | 1 - .../ui/common/modelitem/ModelNameAndStatus.kt | 2 + .../aiedge/gallery/ui/home/HomeScreen.kt | 76 ++++- .../gallery/ui/llmchat/LlmChatScreen.kt | 38 ++- .../gallery/ui/llmchat/LlmChatViewModel.kt | 140 ++++---- .../llmsingleturn/LlmSingleTurnViewModel.kt | 4 +- .../ui/llmsingleturn/PromptTemplatesPanel.kt | 1 + .../gallery/ui/llmsingleturn/ResponsePanel.kt | 4 +- .../gallery/ui/modelmanager/ModelList.kt | 4 - .../ui/modelmanager/ModelManagerViewModel.kt | 320 +++++++++--------- .../gallery/ui/navigation/GalleryNavGraph.kt | 6 +- .../aiedge/gallery/worker/DownloadWorker.kt | 32 +- Android/src/gradle/libs.versions.toml | 2 +- Android/src/settings.gradle.kts | 1 + 28 files changed, 676 insertions(+), 358 deletions(-) rename Android/src/app/src/main/java/com/google/aiedge/gallery/data/{HuggingFace.kt => ConfigValue.kt} (67%) create mode 100644 Android/src/app/src/main/java/com/google/aiedge/gallery/data/ModelAllowlist.kt create mode 100644 Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageBodyWarning.kt diff --git a/Android/src/app/src/main/AndroidManifest.xml b/Android/src/app/src/main/AndroidManifest.xml index 6c366fb..4d3c832 100644 --- a/Android/src/app/src/main/AndroidManifest.xml +++ b/Android/src/app/src/main/AndroidManifest.xml @@ -32,7 +32,7 @@ android:dataExtractionRules="@xml/data_extraction_rules" android:fullBackupContent="@xml/backup_rules" android:icon="@mipmap/ic_launcher" - android:label="@string/app_name" + android:label="Edge Gallery" android:roundIcon="@mipmap/ic_launcher" android:supportsRtl="true" android:theme="@style/Theme.Gallery" diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/GalleryApplication.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/GalleryApplication.kt index b422851..49a7122 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/GalleryApplication.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/GalleryApplication.kt @@ -23,7 +23,6 @@ import androidx.datastore.preferences.core.Preferences import androidx.datastore.preferences.preferencesDataStore import com.google.aiedge.gallery.data.AppContainer import com.google.aiedge.gallery.data.DefaultAppContainer -import com.google.aiedge.gallery.ui.common.processTasks import com.google.aiedge.gallery.ui.theme.ThemeSettings private val Context.dataStore: DataStore by preferencesDataStore(name = "app_gallery_preferences") @@ -35,9 +34,6 @@ class GalleryApplication : Application() { override fun onCreate() { super.onCreate() - // Process tasks. - processTasks() - container = DefaultAppContainer(this, dataStore) // Load theme. diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/HuggingFace.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/ConfigValue.kt similarity index 67% rename from Android/src/app/src/main/java/com/google/aiedge/gallery/data/HuggingFace.kt rename to Android/src/app/src/main/java/com/google/aiedge/gallery/data/ConfigValue.kt index cf92212..e73edb1 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/HuggingFace.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/ConfigValue.kt @@ -16,7 +16,6 @@ package com.google.aiedge.gallery.data -import com.google.aiedge.gallery.ui.common.ensureValidFileName import kotlinx.serialization.KSerializer import kotlinx.serialization.Serializable import kotlinx.serialization.SerializationException @@ -27,15 +26,6 @@ import kotlinx.serialization.encoding.Encoder import kotlinx.serialization.json.JsonDecoder import kotlinx.serialization.json.JsonPrimitive -@Serializable -data class HfModelSummary(val modelId: String) - -@Serializable -data class HfModelDetails(val id: String, val siblings: List) - -@Serializable -data class HfModelFile(val rfilename: String) - @Serializable(with = ConfigValueSerializer::class) sealed class ConfigValue { @Serializable @@ -85,64 +75,6 @@ object ConfigValueSerializer : KSerializer { } } -@Serializable -data class HfModel( - var id: String = "", - val task: String, - val name: String, - val url: String = "", - val file: String = "", - val sizeInBytes: Long, - val configs: Map, -) { - fun toModel(): Model { - val parts = if (url.isNotEmpty()) { - url.split('/') - } else if (file.isNotEmpty()) { - listOf(file) - } else { - listOf("") - } - val fileName = ensureValidFileName("${id}_${(parts.lastOrNull() ?: "")}") - - // Generate configs based on the given default values. -// val configs: List = when (task) { -// TASK_LLM_CHAT.type.label -> createLLmChatConfig(defaults = configs) -// // todo: add configs for other types. -// else -> listOf() -// } - // todo: fix when loading from models.json - val configs: List = listOf() - - // Construct url. - var modelUrl = url - if (modelUrl.isEmpty() && file.isNotEmpty()) { - modelUrl = "https://huggingface.co/${id}/resolve/main/${file}?download=true" - } - - // Other parameters. - val showBenchmarkButton = when (task) { - TASK_LLM_CHAT.type.label -> false - else -> true - } - val showRunAgainButton = when (task) { - TASK_LLM_CHAT.type.label -> false - else -> true - } - - return Model( - hfModelId = id, - name = name, - url = modelUrl, - sizeInBytes = sizeInBytes, - downloadFileName = fileName, - configs = configs, - showBenchmarkButton = showBenchmarkButton, - showRunAgainButton = showRunAgainButton, - ) - } -} - fun getIntConfigValue(configValue: ConfigValue?, default: Int): Int { if (configValue == null) { return default diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Consts.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Consts.kt index ed666e7..4082be3 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Consts.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Consts.kt @@ -18,6 +18,8 @@ package com.google.aiedge.gallery.data // Keys used to send/receive data to Work. const val KEY_MODEL_URL = "KEY_MODEL_URL" +const val KEY_MODEL_VERSION = "KEY_MODEL_VERSION" +const val KEY_MODEL_DOWNLOAD_MODEL_DIR = "KEY_MODEL_DOWNLOAD_MODEL_DIR" const val KEY_MODEL_DOWNLOAD_FILE_NAME = "KEY_MODEL_DOWNLOAD_FILE_NAME" const val KEY_MODEL_TOTAL_BYTES = "KEY_MODEL_TOTAL_BYTES" const val KEY_MODEL_DOWNLOAD_RECEIVED_BYTES = "KEY_MODEL_DOWNLOAD_RECEIVED_BYTES" diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/DownloadRepository.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/DownloadRepository.kt index eea9306..96dc1f0 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/DownloadRepository.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/DownloadRepository.kt @@ -37,13 +37,13 @@ import androidx.work.OutOfQuotaPolicy import androidx.work.WorkInfo import androidx.work.WorkManager import androidx.work.WorkQuery +import com.google.aiedge.gallery.AppLifecycleProvider +import com.google.aiedge.gallery.R +import com.google.aiedge.gallery.worker.DownloadWorker import com.google.common.util.concurrent.FutureCallback import com.google.common.util.concurrent.Futures import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.MoreExecutors -import com.google.aiedge.gallery.AppLifecycleProvider -import com.google.aiedge.gallery.R -import com.google.aiedge.gallery.worker.DownloadWorker import java.util.UUID private const val TAG = "AGDownloadRepository" @@ -89,6 +89,8 @@ class DefaultDownloadRepository( val builder = Data.Builder() val totalBytes = model.totalBytes + model.extraDataFiles.sumOf { it.sizeInBytes } val inputDataBuilder = builder.putString(KEY_MODEL_URL, model.url) + .putString(KEY_MODEL_VERSION, model.version) + .putString(KEY_MODEL_DOWNLOAD_MODEL_DIR, model.normalizedName) .putString(KEY_MODEL_DOWNLOAD_FILE_NAME, model.downloadFileName) .putBoolean(KEY_MODEL_IS_ZIP, model.isZip).putString(KEY_MODEL_UNZIPPED_DIR, model.unzipDir) .putLong( diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt index 8a8d2a1..2fe0b03 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt @@ -20,6 +20,7 @@ import android.content.Context import com.google.aiedge.gallery.ui.common.chat.PromptTemplate import com.google.aiedge.gallery.ui.common.convertValueToTargetType import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs +import java.io.File data class ModelDataFile( val name: String, @@ -33,16 +34,22 @@ enum class Accelerator(val label: String) { } const val IMPORTS_DIR = "__imports" +private val NORMALIZE_NAME_REGEX = Regex("[^a-zA-Z0-9]") /** A model for a task */ data class Model( - /** The Hugging Face model ID (if applicable). */ - val hfModelId: String = "", - /** The name (for display purpose) of the model. */ val name: String, - /** The name of the downloaded model file. */ + /** The version of the model. */ + val version: String = "_", + + /** + * The name of the downloaded model file. + * + * The final file path of the downloaded model will be: + * {context.getExternalFilesDir}/{normalizedName}/{version}/{downloadFileName} + */ val downloadFileName: String, /** The URL to download the model from. */ @@ -88,6 +95,7 @@ data class Model( val imported: Boolean = false, // The following fields are managed by the app. Don't need to set manually. + var normalizedName: String = "", var instance: Any? = null, var initializing: Boolean = false, // TODO(jingjin): use a "queue" system to manage model init and cleanup. @@ -96,6 +104,10 @@ data class Model( var totalBytes: Long = 0L, var accessToken: String? = null, ) { + init { + normalizedName = NORMALIZE_NAME_REGEX.replace(name, "_") + } + fun preProcess() { val configValues: MutableMap = mutableMapOf() for (config in this.configs) { @@ -106,11 +118,22 @@ data class Model( } fun getPath(context: Context, fileName: String = downloadFileName): String { - val baseDir = "${context.getExternalFilesDir(null)}" + if (imported) { + return listOf(context.getExternalFilesDir(null)?.absolutePath ?: "", fileName).joinToString( + File.separator + ) + } + + val baseDir = + listOf( + context.getExternalFilesDir(null)?.absolutePath ?: "", + normalizedName, + version + ).joinToString(File.separator) return if (this.isZip && this.unzipDir.isNotEmpty()) { "$baseDir/${this.unzipDir}" } else { - "$baseDir/${fileName}" + "$baseDir/$fileName" } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/ModelAllowlist.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/ModelAllowlist.kt new file mode 100644 index 0000000..1ddfd0f --- /dev/null +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/ModelAllowlist.kt @@ -0,0 +1,116 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.aiedge.gallery.data + +import com.google.aiedge.gallery.ui.llmchat.DEFAULT_ACCELERATORS +import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TEMPERATURE +import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TOPK +import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TOPP +import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs +import kotlinx.serialization.Serializable + +/** A model in the model allowlist. */ +@Serializable +data class AllowedModel( + val name: String, + val modelId: String, + val modelFile: String, + val description: String, + val sizeInBytes: Long, + val version: String, + val defaultConfig: Map, + val taskTypes: List, + val disabled: Boolean? = null, +) { + fun toModel(): Model { + // Construct HF download url. + val downloadUrl = "https://huggingface.co/$modelId/resolve/main/$modelFile?download=true" + + // Config. + val isLlmModel = + taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_USECASES.type.id) + var configs: List = listOf() + if (isLlmModel) { + var defaultTopK: Int = DEFAULT_TOPK + var defaultTopP: Float = DEFAULT_TOPP + var defaultTemperature: Float = DEFAULT_TEMPERATURE + var defaultMaxToken: Int = 1024 + var accelerators: List = DEFAULT_ACCELERATORS + if (defaultConfig.containsKey("topK")) { + defaultTopK = getIntConfigValue(defaultConfig["topK"], defaultTopK) + } + if (defaultConfig.containsKey("topP")) { + defaultTopP = getFloatConfigValue(defaultConfig["topP"], defaultTopP) + } + if (defaultConfig.containsKey("temperature")) { + defaultTemperature = getFloatConfigValue(defaultConfig["temperature"], defaultTemperature) + } + if (defaultConfig.containsKey("maxTokens")) { + defaultMaxToken = getIntConfigValue(defaultConfig["maxTokens"], defaultMaxToken) + } + if (defaultConfig.containsKey("accelerators")) { + val items = getStringConfigValue(defaultConfig["accelerators"], "gpu").split(",") + accelerators = mutableListOf() + for (item in items) { + if (item == "cpu") { + accelerators.add(Accelerator.CPU) + } else if (item == "gpu") { + accelerators.add(Accelerator.GPU) + } + } + } + configs = createLlmChatConfigs( + defaultTopK = defaultTopK, + defaultTopP = defaultTopP, + defaultTemperature = defaultTemperature, + defaultMaxToken = defaultMaxToken, + accelerators = accelerators, + ) + } + + // Misc. + var showBenchmarkButton = true + val showRunAgainButton = true + if (taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_USECASES.type.id)) { + showBenchmarkButton = false + } + + return Model( + name = name, + version = version, + info = description, + url = downloadUrl, + sizeInBytes = sizeInBytes, + configs = configs, + downloadFileName = modelFile, + showBenchmarkButton = showBenchmarkButton, + showRunAgainButton = showRunAgainButton, + learnMoreUrl = "https://huggingface.co/${modelId}" + ) + } + + override fun toString(): String { + return "$modelId/$modelFile" + } +} + +/** The model allowlist. */ +@Serializable +data class ModelAllowlist( + val models: List, +) + diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Tasks.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Tasks.kt index 6caac37..e50fb47 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Tasks.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Tasks.kt @@ -27,15 +27,15 @@ import androidx.compose.ui.graphics.vector.ImageVector import com.google.aiedge.gallery.R /** Type of task. */ -enum class TaskType(val label: String) { - TEXT_CLASSIFICATION("Text Classification"), - IMAGE_CLASSIFICATION("Image Classification"), - IMAGE_GENERATION("Image Generation"), - LLM_CHAT("LLM Chat"), - LLM_SINGLE_TURN("LLM Use Cases"), +enum class TaskType(val label: String, val id: String) { + TEXT_CLASSIFICATION(label = "Text Classification", id = "text_classification"), + IMAGE_CLASSIFICATION(label = "Image Classification", id = "image_classification"), + IMAGE_GENERATION(label = "Image Generation", id = "image_generation"), + LLM_CHAT(label = "LLM Chat", id = "llm_chat"), + LLM_USECASES(label = "LLM Use Cases", id = "llm_usecases"), - TEST_TASK_1("Test task 1"), - TEST_TASK_2("Test task 2") + TEST_TASK_1(label = "Test task 1", id = "test_task_1"), + TEST_TASK_2(label = "Test task 2", id = "test_task_2") } /** Data class for a task listed in home screen. */ @@ -91,17 +91,19 @@ val TASK_IMAGE_CLASSIFICATION = Task( val TASK_LLM_CHAT = Task( type = TaskType.LLM_CHAT, icon = Icons.Outlined.Forum, - models = MODELS_LLM, +// models = MODELS_LLM, + models = mutableListOf(), description = "Chat with a on-device large language model", docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt", textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat ) -val TASK_LLM_SINGLE_TURN = Task( - type = TaskType.LLM_SINGLE_TURN, +val TASK_LLM_USECASES = Task( + type = TaskType.LLM_USECASES, icon = Icons.Outlined.Widgets, - models = MODELS_LLM, +// models = MODELS_LLM, + models = mutableListOf(), description = "Single turn use cases with on-device large language model", docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt", @@ -123,7 +125,7 @@ val TASKS: List = listOf( // TASK_TEXT_CLASSIFICATION, // TASK_IMAGE_CLASSIFICATION, // TASK_IMAGE_GENERATION, - TASK_LLM_SINGLE_TURN, + TASK_LLM_USECASES, TASK_LLM_CHAT, ) diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/DownloadAndTryButton.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/DownloadAndTryButton.kt index 6ec84a0..ba156ca 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/DownloadAndTryButton.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/DownloadAndTryButton.kt @@ -28,12 +28,15 @@ import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.wrapContentHeight import androidx.compose.material.icons.Icons import androidx.compose.material.icons.automirrored.rounded.ArrowForward +import androidx.compose.material.icons.rounded.Error +import androidx.compose.material3.AlertDialog import androidx.compose.material3.Button import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.Icon import androidx.compose.material3.MaterialTheme import androidx.compose.material3.ModalBottomSheet import androidx.compose.material3.Text +import androidx.compose.material3.TextButton import androidx.compose.material3.rememberModalBottomSheetState import androidx.compose.runtime.Composable import androidx.compose.runtime.getValue @@ -102,6 +105,7 @@ fun DownloadAndTryButton( val context = LocalContext.current var checkingToken by remember { mutableStateOf(false) } var showAgreementAckSheet by remember { mutableStateOf(false) } + var showErrorDialog by remember { mutableStateOf(false) } val sheetState = rememberModalBottomSheetState() // A launcher for requesting notification permission. @@ -208,12 +212,18 @@ fun DownloadAndTryButton( TAG, "Model '${model.name}' is from HuggingFace. Checking if the url needs auth to download" ) - if (modelManagerViewModel.getModelUrlResponse(model = model) == HttpURLConnection.HTTP_OK) { + val firstResponseCode = modelManagerViewModel.getModelUrlResponse(model = model) + if (firstResponseCode == HttpURLConnection.HTTP_OK) { Log.d(TAG, "Model '${model.name}' doesn't need auth. Start downloading the model...") withContext(Dispatchers.Main) { startDownload(null) } return@launch + } else if (firstResponseCode < 0) { + checkingToken = false + Log.e(TAG, "Unknown network error") + showErrorDialog = true + return@launch } Log.d(TAG, "Model '${model.name}' needs auth. Start token exchange process...") @@ -334,4 +344,30 @@ fun DownloadAndTryButton( } } } + + if (showErrorDialog) { + AlertDialog( + icon = { + Icon(Icons.Rounded.Error, contentDescription = "", tint = MaterialTheme.colorScheme.error) + }, + title = { + Text("Unknown network error") + }, + text = { + Text("Please check your internet connection.") + }, + onDismissRequest = { + showErrorDialog = false + }, + confirmButton = { + TextButton( + onClick = { + showErrorDialog = false + } + ) { + Text("Close") + } + }, + ) + } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/Utils.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/Utils.kt index d614b3c..da4d98d 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/Utils.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/Utils.kt @@ -484,5 +484,7 @@ fun processLlmResponse(response: String): String { } } + newContent = newContent.replace("\\n", "\n") + return newContent } \ No newline at end of file diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatMessage.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatMessage.kt index 17938e1..9260435 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatMessage.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatMessage.kt @@ -24,6 +24,7 @@ import com.google.aiedge.gallery.data.Model enum class ChatMessageType { INFO, + WARNING, TEXT, IMAGE, IMAGE_WITH_HISTORY, @@ -57,6 +58,10 @@ class ChatMessageLoading : ChatMessage(type = ChatMessageType.LOADING, side = Ch class ChatMessageInfo(val content: String) : ChatMessage(type = ChatMessageType.INFO, side = ChatSide.SYSTEM) +/** Chat message for info (help). */ +class ChatMessageWarning(val content: String) : + ChatMessage(type = ChatMessageType.WARNING, side = ChatSide.SYSTEM) + /** Chat message for config values change. */ class ChatMessageConfigValuesChange( val model: Model, diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt index 8e66bf2..20c967d 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt @@ -269,6 +269,9 @@ fun ChatPanel( // Info. is ChatMessageInfo -> MessageBodyInfo(message = message) + // Warning + is ChatMessageWarning -> MessageBodyWarning(message = message) + // Config values change. is ChatMessageConfigValuesChange -> MessageBodyConfigUpdate(message = message) @@ -433,6 +436,7 @@ fun ChatPanel( modelManagerViewModel = modelManagerViewModel, curMessage = curMessage, inProgress = uiState.inProgress, + modelInitializing = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING, textFieldPlaceHolderRes = task.textInputPlaceHolderRes, onValueChanged = { curMessage = it }, onSendMessage = { diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageBodyWarning.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageBodyWarning.kt new file mode 100644 index 0000000..1c79ad2 --- /dev/null +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageBodyWarning.kt @@ -0,0 +1,62 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.aiedge.gallery.ui.common.chat + +import androidx.compose.foundation.background +import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.Box +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.material3.MaterialTheme +import androidx.compose.runtime.Composable +import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.clip +import androidx.compose.ui.tooling.preview.Preview +import androidx.compose.ui.unit.dp +import com.google.aiedge.gallery.ui.theme.GalleryTheme + +/** + * Composable function to display warning message content within a chat. + * + * Supports markdown. + */ +@Composable +fun MessageBodyWarning(message: ChatMessageWarning) { + Row( + modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center + ) { + Box( + modifier = Modifier + .clip(RoundedCornerShape(16.dp)) + .background(MaterialTheme.colorScheme.tertiaryContainer) + ) { + MarkdownText(text = message.content, modifier = Modifier.padding(12.dp), smallFontSize = true) + } + } +} + +@Preview(showBackground = true) +@Composable +fun MessageBodyWarningPreview() { + GalleryTheme { + Row(modifier = Modifier.padding(16.dp)) { + MessageBodyWarning(message = ChatMessageWarning(content = "This is a warning")) + } + } +} \ No newline at end of file diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageInputText.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageInputText.kt index c550645..2029882 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageInputText.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/MessageInputText.kt @@ -75,6 +75,7 @@ fun MessageInputText( modelManagerViewModel: ModelManagerViewModel, curMessage: String, inProgress: Boolean, + modelInitializing: Boolean, @StringRes textFieldPlaceHolderRes: Int, onValueChanged: (String) -> Unit, onSendMessage: (ChatMessage) -> Unit, @@ -162,17 +163,19 @@ fun MessageInputText( Spacer(modifier = Modifier.width(8.dp)) if (inProgress && showStopButtonWhenInProgress) { - IconButton( - onClick = onStopButtonClicked, - colors = IconButtonDefaults.iconButtonColors( - containerColor = MaterialTheme.colorScheme.secondaryContainer, - ), - ) { - Icon( - Icons.Rounded.Stop, - contentDescription = "", - tint = MaterialTheme.colorScheme.primary - ) + if (!modelInitializing) { + IconButton( + onClick = onStopButtonClicked, + colors = IconButtonDefaults.iconButtonColors( + containerColor = MaterialTheme.colorScheme.secondaryContainer, + ), + ) { + Icon( + Icons.Rounded.Stop, + contentDescription = "", + tint = MaterialTheme.colorScheme.primary + ) + } } } // Send button. Only shown when text is not empty. else if (curMessage.isNotEmpty()) { @@ -230,6 +233,7 @@ fun MessageInputTextPreview() { modelManagerViewModel = PreviewModelManagerViewModel(context = context), curMessage = "hello", inProgress = false, + modelInitializing = false, textFieldPlaceHolderRes = R.string.chat_textinput_placeholder, onValueChanged = {}, onSendMessage = {}, @@ -239,6 +243,7 @@ fun MessageInputTextPreview() { modelManagerViewModel = PreviewModelManagerViewModel(context = context), curMessage = "hello", inProgress = true, + modelInitializing = false, textFieldPlaceHolderRes = R.string.chat_textinput_placeholder, onValueChanged = {}, onSendMessage = {}, @@ -247,6 +252,7 @@ fun MessageInputTextPreview() { modelManagerViewModel = PreviewModelManagerViewModel(context = context), curMessage = "", inProgress = false, + modelInitializing = false, textFieldPlaceHolderRes = R.string.chat_textinput_placeholder, onValueChanged = {}, onSendMessage = {}, @@ -255,6 +261,7 @@ fun MessageInputTextPreview() { modelManagerViewModel = PreviewModelManagerViewModel(context = context), curMessage = "", inProgress = true, + modelInitializing = false, textFieldPlaceHolderRes = R.string.chat_textinput_placeholder, onValueChanged = {}, onSendMessage = {}, diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItem.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItem.kt index 00dc4fc..c7eb4b3 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItem.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItem.kt @@ -39,7 +39,6 @@ import androidx.compose.material.icons.rounded.ChevronRight import androidx.compose.material.icons.rounded.Settings import androidx.compose.material.icons.rounded.UnfoldLess import androidx.compose.material.icons.rounded.UnfoldMore -import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.Icon import androidx.compose.material3.IconButton import androidx.compose.material3.OutlinedButton diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelNameAndStatus.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelNameAndStatus.kt index 182d016..d60afff 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelNameAndStatus.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelNameAndStatus.kt @@ -75,6 +75,8 @@ fun ModelNameAndStatus( ) { Text( model.name, + maxLines = 1, + overflow = TextOverflow.Ellipsis, style = MaterialTheme.typography.titleMedium, modifier = modifier, ) diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/HomeScreen.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/HomeScreen.kt index 1e21e1e..4438e1f 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/HomeScreen.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/HomeScreen.kt @@ -38,6 +38,7 @@ import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.height import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.size import androidx.compose.foundation.lazy.grid.GridCells import androidx.compose.foundation.lazy.grid.GridItemSpan import androidx.compose.foundation.lazy.grid.LazyVerticalGrid @@ -46,8 +47,11 @@ import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.material.icons.Icons import androidx.compose.material.icons.automirrored.outlined.NoteAdd import androidx.compose.material.icons.filled.Add +import androidx.compose.material.icons.rounded.Error +import androidx.compose.material3.AlertDialog import androidx.compose.material3.Card import androidx.compose.material3.CardDefaults +import androidx.compose.material3.CircularProgressIndicator import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.Icon import androidx.compose.material3.MaterialTheme @@ -57,6 +61,7 @@ import androidx.compose.material3.SmallFloatingActionButton import androidx.compose.material3.SnackbarHost import androidx.compose.material3.SnackbarHostState import androidx.compose.material3.Text +import androidx.compose.material3.TextButton import androidx.compose.material3.TopAppBarDefaults import androidx.compose.material3.rememberModalBottomSheetState import androidx.compose.runtime.Composable @@ -180,6 +185,7 @@ fun HomeScreen( TaskList( tasks = tasks, navigateToTaskScreen = navigateToTaskScreen, + loadingModelAllowlist = uiState.loadingModelAllowlist, modifier = Modifier.fillMaxSize(), contentPadding = innerPadding, ) @@ -285,12 +291,40 @@ fun HomeScreen( } } } + + if (uiState.loadingModelAllowlistError.isNotEmpty()) { + AlertDialog( + icon = { + Icon(Icons.Rounded.Error, contentDescription = "", tint = MaterialTheme.colorScheme.error) + }, + title = { + Text(uiState.loadingModelAllowlistError) + }, + text = { + Text("Please check your internet connection and try again later.") + }, + onDismissRequest = { + modelManagerViewModel.loadModelAllowlist() + }, + confirmButton = { + TextButton( + onClick = { + modelManagerViewModel.loadModelAllowlist() + } + ) { + Text("Retry") + } + }, + ) + + } } @Composable private fun TaskList( tasks: List, navigateToTaskScreen: (Task) -> Unit, + loadingModelAllowlist: Boolean, modifier: Modifier = Modifier, contentPadding: PaddingValues = PaddingValues(0.dp), ) { @@ -312,17 +346,37 @@ private fun TaskList( ) } - // Cards. - items(tasks) { task -> - TaskCard( - task = task, - onClick = { - navigateToTaskScreen(task) - }, - modifier = Modifier - .fillMaxWidth() - .aspectRatio(1f) - ) + if (loadingModelAllowlist) { + item(key = "loading", span = { GridItemSpan(2) }) { + Row( + horizontalArrangement = Arrangement.Center, + modifier = Modifier + .fillMaxWidth() + .padding(top = 32.dp) + ) { + CircularProgressIndicator( + trackColor = MaterialTheme.colorScheme.surfaceVariant, + strokeWidth = 3.dp, + modifier = Modifier + .padding(end = 8.dp) + .size(20.dp) + ) + Text("Loading model list...", style = MaterialTheme.typography.bodyMedium) + } + } + } else { + // Cards. + items(tasks) { task -> + TaskCard( + task = task, + onClick = { + navigateToTaskScreen(task) + }, + modifier = Modifier + .fillMaxWidth() + .aspectRatio(1f) + ) + } } // Bottom padding. diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatScreen.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatScreen.kt index dbc54e5..a3db07d 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatScreen.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatScreen.kt @@ -16,11 +16,15 @@ package com.google.aiedge.gallery.ui.llmchat +import android.util.Log import androidx.compose.runtime.Composable import androidx.compose.ui.Modifier +import androidx.compose.ui.platform.LocalContext import androidx.lifecycle.viewmodel.compose.viewModel import com.google.aiedge.gallery.ui.ViewModelProvider +import com.google.aiedge.gallery.ui.common.chat.ChatMessageInfo import com.google.aiedge.gallery.ui.common.chat.ChatMessageText +import com.google.aiedge.gallery.ui.common.chat.ChatMessageWarning import com.google.aiedge.gallery.ui.common.chat.ChatView import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import kotlinx.serialization.Serializable @@ -40,6 +44,8 @@ fun LlmChatScreen( factory = ViewModelProvider.Factory ), ) { + val context = LocalContext.current + ChatView( task = viewModel.task, viewModel = viewModel, @@ -51,25 +57,43 @@ fun LlmChatScreen( ) if (message is ChatMessageText) { modelManagerViewModel.addTextInputHistory(message.content) - viewModel.generateResponse( - model = model, - input = message.content, - ) + viewModel.generateResponse(model = model, input = message.content, onError = { + viewModel.addMessage( + model = model, + message = ChatMessageWarning(content = "Error occurred. Re-initializing the engine.") + ) + + modelManagerViewModel.initializeModel( + context = context, task = viewModel.task, model = model, force = true + ) + }) } }, onRunAgainClicked = { model, message -> if (message is ChatMessageText) { - viewModel.runAgain(model = model, message = message) + viewModel.runAgain(model = model, message = message, onError = { + viewModel.addMessage( + model = model, + message = ChatMessageWarning(content = "Error occurred. Re-initializing the engine.") + ) + + modelManagerViewModel.initializeModel( + context = context, task = viewModel.task, model = model, force = true + ) + }) } }, onBenchmarkClicked = { model, message, warmUpIterations, benchmarkIterations -> if (message is ChatMessageText) { viewModel.benchmark( - model = model, - message = message + model = model, message = message ) } }, + showStopButtonInInputWhenInProgress = true, + onStopButtonClicked = { model -> + viewModel.stopResponse(model = model) + }, navigateUp = navigateUp, modifier = modifier, ) diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatViewModel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatViewModel.kt index d2936d2..f994978 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatViewModel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatViewModel.kt @@ -16,6 +16,7 @@ package com.google.aiedge.gallery.ui.llmchat +import android.util.Log import androidx.lifecycle.viewModelScope import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.TASK_LLM_CHAT @@ -26,6 +27,7 @@ import com.google.aiedge.gallery.ui.common.chat.ChatMessageType import com.google.aiedge.gallery.ui.common.chat.ChatSide import com.google.aiedge.gallery.ui.common.chat.ChatViewModel import com.google.aiedge.gallery.ui.common.chat.Stat +import kotlinx.coroutines.CoroutineExceptionHandler import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.delay import kotlinx.coroutines.launch @@ -39,7 +41,7 @@ private val STATS = listOf( ) class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) { - fun generateResponse(model: Model, input: String) { + fun generateResponse(model: Model, input: String, onError: () -> Unit) { viewModelScope.launch(Dispatchers.Default) { setInProgress(true) @@ -65,75 +67,90 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) { var prefillSpeed = 0f var decodeSpeed: Float val start = System.currentTimeMillis() - LlmChatModelHelper.runInference( - model = model, - input = input, - resultListener = { partialResult, done -> - val curTs = System.currentTimeMillis() - if (firstRun) { - firstTokenTs = System.currentTimeMillis() - timeToFirstToken = (firstTokenTs - start) / 1000f - prefillSpeed = prefillTokens / timeToFirstToken - firstRun = false - } else { - decodeTokens++ - } + try { + LlmChatModelHelper.runInference( + model = model, + input = input, + resultListener = { partialResult, done -> + val curTs = System.currentTimeMillis() - // Remove the last message if it is a "loading" message. - // This will only be done once. - val lastMessage = getLastMessage(model = model) - if (lastMessage?.type == ChatMessageType.LOADING) { - removeLastMessage(model = model) - - // Add an empty message that will receive streaming results. - addMessage( - model = model, - message = ChatMessageText(content = "", side = ChatSide.AGENT) - ) - } - - // Incrementally update the streamed partial results. - val latencyMs: Long = if (done) System.currentTimeMillis() - start else -1 - updateLastTextMessageContentIncrementally( - model = model, - partialContent = partialResult, - latencyMs = latencyMs.toFloat() - ) - - if (done) { - setInProgress(false) - - decodeSpeed = - decodeTokens / ((curTs - firstTokenTs) / 1000f) - if (decodeSpeed.isNaN()) { - decodeSpeed = 0f + if (firstRun) { + firstTokenTs = System.currentTimeMillis() + timeToFirstToken = (firstTokenTs - start) / 1000f + prefillSpeed = prefillTokens / timeToFirstToken + firstRun = false + } else { + decodeTokens++ } - if (lastMessage is ChatMessageText) { - updateLastTextMessageLlmBenchmarkResult( - model = model, llmBenchmarkResult = - ChatMessageBenchmarkLlmResult( - orderedStats = STATS, - statValues = mutableMapOf( - "prefill_speed" to prefillSpeed, - "decode_speed" to decodeSpeed, - "time_to_first_token" to timeToFirstToken, - "latency" to (curTs - start).toFloat() / 1000f, - ), - running = false, - latencyMs = -1f, - ) + // Remove the last message if it is a "loading" message. + // This will only be done once. + val lastMessage = getLastMessage(model = model) + if (lastMessage?.type == ChatMessageType.LOADING) { + removeLastMessage(model = model) + + // Add an empty message that will receive streaming results. + addMessage( + model = model, + message = ChatMessageText(content = "", side = ChatSide.AGENT) ) } - } - }, cleanUpListener = { - setInProgress(false) - }) + + // Incrementally update the streamed partial results. + val latencyMs: Long = if (done) System.currentTimeMillis() - start else -1 + updateLastTextMessageContentIncrementally( + model = model, + partialContent = partialResult, + latencyMs = latencyMs.toFloat() + ) + + if (done) { + setInProgress(false) + + decodeSpeed = + decodeTokens / ((curTs - firstTokenTs) / 1000f) + if (decodeSpeed.isNaN()) { + decodeSpeed = 0f + } + + if (lastMessage is ChatMessageText) { + updateLastTextMessageLlmBenchmarkResult( + model = model, llmBenchmarkResult = + ChatMessageBenchmarkLlmResult( + orderedStats = STATS, + statValues = mutableMapOf( + "prefill_speed" to prefillSpeed, + "decode_speed" to decodeSpeed, + "time_to_first_token" to timeToFirstToken, + "latency" to (curTs - start).toFloat() / 1000f, + ), + running = false, + latencyMs = -1f, + ) + ) + } + } + }, cleanUpListener = { + setInProgress(false) + }) + } catch (e: Exception) { + setInProgress(false) + onError() + } } } - fun runAgain(model: Model, message: ChatMessageText) { + fun stopResponse(model: Model) { + Log.d(TAG, "Stopping response for model ${model.name}...") + viewModelScope.launch(Dispatchers.Default) { + setInProgress(false) + val instance = model.instance as LlmModelInstance + instance.session.cancelGenerateResponseAsync() + } + } + + fun runAgain(model: Model, message: ChatMessageText, onError: () -> Unit) { viewModelScope.launch(Dispatchers.Default) { // Wait for model to be initialized. while (model.instance == null) { @@ -147,6 +164,7 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) { generateResponse( model = model, input = message.content, + onError = onError ) } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt index fe5ffed..cc045cf 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt @@ -20,7 +20,7 @@ import android.util.Log import androidx.lifecycle.ViewModel import androidx.lifecycle.viewModelScope import com.google.aiedge.gallery.data.Model -import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN +import com.google.aiedge.gallery.data.TASK_LLM_USECASES import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult import com.google.aiedge.gallery.ui.common.chat.Stat @@ -64,7 +64,7 @@ private val STATS = listOf( Stat(id = "latency", label = "Latency", unit = "sec") ) -open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_SINGLE_TURN) : ViewModel() { +open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_USECASES) : ViewModel() { private val _uiState = MutableStateFlow(createUiState(task = task)) val uiState = _uiState.asStateFlow() diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/PromptTemplatesPanel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/PromptTemplatesPanel.kt index 2d36fbb..661daaa 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/PromptTemplatesPanel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/PromptTemplatesPanel.kt @@ -218,6 +218,7 @@ fun PromptTemplatesPanel( .clip(MessageBubbleShape(radius = bubbleBorderRadius)) .background(MaterialTheme.customColors.agentBubbleBgColor) .padding(16.dp) + .focusRequester(focusRequester) ) } else { TextField( diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/ResponsePanel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/ResponsePanel.kt index ebc79ac..8d90ff4 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/ResponsePanel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/ResponsePanel.kt @@ -57,7 +57,7 @@ import androidx.compose.ui.platform.LocalClipboardManager import androidx.compose.ui.text.AnnotatedString import androidx.compose.ui.unit.dp import com.google.aiedge.gallery.data.Model -import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN +import com.google.aiedge.gallery.data.TASK_LLM_USECASES import com.google.aiedge.gallery.ui.common.chat.MarkdownText import com.google.aiedge.gallery.ui.common.chat.MessageBodyBenchmarkLlm import com.google.aiedge.gallery.ui.common.chat.MessageBodyLoading @@ -76,7 +76,7 @@ fun ResponsePanel( modelManagerViewModel: ModelManagerViewModel, modifier: Modifier = Modifier, ) { - val task = TASK_LLM_SINGLE_TURN + val task = TASK_LLM_USECASES val uiState by viewModel.uiState.collectAsState() val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() val inProgress = uiState.inProgress diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelList.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelList.kt index 60dc32c..29e25d8 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelList.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelList.kt @@ -16,8 +16,6 @@ package com.google.aiedge.gallery.ui.modelmanager -import android.os.Build -import androidx.annotation.RequiresApi import androidx.compose.foundation.clickable import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Box @@ -62,7 +60,6 @@ import com.google.aiedge.gallery.ui.theme.customColors private const val TAG = "AGModelList" /** The list of models in the model manager. */ -@RequiresApi(Build.VERSION_CODES.O) @Composable fun ModelList( task: Task, @@ -213,7 +210,6 @@ fun ClickableLink( } } -@RequiresApi(Build.VERSION_CODES.O) @Preview(showBackground = true) @Composable fun ModelListPreview() { diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt index 79bf5d9..d9598cf 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt @@ -30,32 +30,28 @@ import com.google.aiedge.gallery.data.ConfigKey import com.google.aiedge.gallery.data.DataStoreRepository import com.google.aiedge.gallery.data.DownloadRepository import com.google.aiedge.gallery.data.EMPTY_MODEL -import com.google.aiedge.gallery.data.HfModel -import com.google.aiedge.gallery.data.HfModelDetails -import com.google.aiedge.gallery.data.HfModelSummary import com.google.aiedge.gallery.data.IMPORTS_DIR import com.google.aiedge.gallery.data.ImportedModelInfo import com.google.aiedge.gallery.data.Model +import com.google.aiedge.gallery.data.ModelAllowlist import com.google.aiedge.gallery.data.ModelDownloadStatus import com.google.aiedge.gallery.data.ModelDownloadStatusType import com.google.aiedge.gallery.data.TASKS import com.google.aiedge.gallery.data.TASK_LLM_CHAT -import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN +import com.google.aiedge.gallery.data.TASK_LLM_USECASES import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.TaskType import com.google.aiedge.gallery.data.ValueType import com.google.aiedge.gallery.data.getModelByName import com.google.aiedge.gallery.ui.common.AuthConfig import com.google.aiedge.gallery.ui.common.convertValueToTargetType +import com.google.aiedge.gallery.ui.common.processTasks import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs import com.google.aiedge.gallery.ui.textclassification.TextClassificationModelHelper import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.async -import kotlinx.coroutines.awaitAll -import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.delay import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.asStateFlow @@ -73,8 +69,9 @@ import java.net.HttpURLConnection import java.net.URL private const val TAG = "AGModelManagerViewModel" -private const val HG_COMMUNITY = "jinjingforevercommunity" private const val TEXT_INPUT_HISTORY_MAX_SIZE = 50 +private const val MODEL_ALLOWLIST_URL = + "https://raw.githubusercontent.com/jinjingforever/kokoro-codelab-jingjin/refs/heads/main/model_allowlist.json" data class ModelInitializationStatus( val status: ModelInitializationStatusType, var error: String = "" @@ -122,6 +119,14 @@ data class ModelManagerUiState( */ val loadingHfModels: Boolean = false, + /** + * Whether the app is loading and processing the model allowlist. + */ + val loadingModelAllowlist: Boolean = true, + + /** The error message when loading the model allowlist. */ + val loadingModelAllowlistError: String = "", + /** * The currently selected model. */ @@ -153,7 +158,7 @@ open class ModelManagerViewModel( private val externalFilesDir = context.getExternalFilesDir(null) private val inProgressWorkInfos: List = downloadRepository.getEnqueuedOrRunningWorkInfos() - protected val _uiState = MutableStateFlow(createUiState()) + protected val _uiState = MutableStateFlow(createEmptyUiState()) val uiState = _uiState.asStateFlow() val authService = AuthorizationService(context) @@ -162,44 +167,7 @@ open class ModelManagerViewModel( var pagerScrollState: MutableStateFlow = MutableStateFlow(PagerScrollState()) init { - Log.d(TAG, "In-progress worker infos: $inProgressWorkInfos") - - // Iterate through the inProgressWorkInfos and retrieve the corresponding modes. - // Those models are the ones that have not finished downloading. - val models: MutableList = mutableListOf() - for (info in inProgressWorkInfos) { - getModelByName(info.modelName)?.let { model -> - models.add(model) - } - } - - // Cancel all pending downloads for the retrieved models. - downloadRepository.cancelAll(models) { - Log.d(TAG, "All pending work is cancelled") - - viewModelScope.launch(Dispatchers.IO) { - // Load models from hg community. - loadHfModels() - Log.d(TAG, "Done loading HF models") - - // Kick off downloads for these models . - withContext(Dispatchers.Main) { - val tokenStatusAndData = getTokenStatusAndData() - for (info in inProgressWorkInfos) { - val model: Model? = getModelByName(info.modelName) - if (model != null) { - if (tokenStatusAndData.status == TokenStatus.NOT_EXPIRED && tokenStatusAndData.data != null) { - model.accessToken = tokenStatusAndData.data.accessToken - } - Log.d(TAG, "Sending a new download request for '${model.name}'") - downloadRepository.downloadModel( - model, onStatusUpdated = this@ModelManagerViewModel::setDownloadStatus - ) - } - } - } - } - } + loadModelAllowlist() } override fun onCleared() { @@ -231,12 +199,10 @@ open class ModelManagerViewModel( } fun deleteModel(task: Task, model: Model) { - deleteFileFromExternalFilesDir(model.downloadFileName) - for (file in model.extraDataFiles) { - deleteFileFromExternalFilesDir(file.downloadFileName) - } - if (model.isZip && model.unzipDir.isNotEmpty()) { - deleteDirFromExternalFilesDir(model.unzipDir) + if (model.imported) { + deleteFileFromExternalFilesDir(model.downloadFileName) + } else { + deleteDirFromExternalFilesDir(model.normalizedName) } // Update model download status to NotDownloaded. @@ -340,7 +306,7 @@ open class ModelManagerViewModel( onDone = onDone, ) - TaskType.LLM_SINGLE_TURN -> LlmChatModelHelper.initialize( + TaskType.LLM_USECASES -> LlmChatModelHelper.initialize( context = context, model = model, onDone = onDone, @@ -364,7 +330,7 @@ open class ModelManagerViewModel( TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model) TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.cleanUp(model = model) TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model) - TaskType.LLM_SINGLE_TURN -> LlmChatModelHelper.cleanUp(model = model) + TaskType.LLM_USECASES -> LlmChatModelHelper.cleanUp(model = model) TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.cleanUp(model = model) TaskType.TEST_TASK_1 -> {} TaskType.TEST_TASK_2 -> {} @@ -444,33 +410,40 @@ open class ModelManagerViewModel( } fun getModelUrlResponse(model: Model, accessToken: String? = null): Int { - val url = URL(model.url) - val connection = url.openConnection() as HttpURLConnection - if (accessToken != null) { - connection.setRequestProperty( - "Authorization", "Bearer $accessToken" - ) - } - connection.connect() + try { + val url = URL(model.url) + val connection = url.openConnection() as HttpURLConnection + if (accessToken != null) { + connection.setRequestProperty( + "Authorization", "Bearer $accessToken" + ) + } + connection.connect() - // Report the result. - return connection.responseCode + // Report the result. + return connection.responseCode + } catch (e: Exception) { + Log.e(TAG, "$e") + return -1 + } } fun addImportedLlmModel(info: ImportedModelInfo) { Log.d(TAG, "adding imported llm model: $info") - // Remove duplicated imported model if existed. - val task = TASK_LLM_CHAT - val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported } - if (modelIndex >= 0) { - Log.d(TAG, "duplicated imported model found in task. Removing it first") - task.models.removeAt(modelIndex) - } - // Create model. - val model = createModelFromImportedModelInfo(info = info, task = task) - task.models.add(model) + val model = createModelFromImportedModelInfo(info = info) + + // Remove duplicated imported model if existed. + for (task in listOf(TASK_LLM_CHAT, TASK_LLM_USECASES)) { + val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported } + if (modelIndex >= 0) { + Log.d(TAG, "duplicated imported model found in task. Removing it first") + task.models.removeAt(modelIndex) + } + task.models.add(model) + task.updateTrigger.value = System.currentTimeMillis() + } // Add initial status and states. val modelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap() @@ -491,10 +464,6 @@ open class ModelManagerViewModel( modelInitializationStatus = modelInstances ) } - task.updateTrigger.value = System.currentTimeMillis() - // Also need to update single turn task. - TASK_LLM_SINGLE_TURN.updateTrigger.value = System.currentTimeMillis() - // Add to preference storage. val importedModels = dataStoreRepository.readImportedModels().toMutableList() @@ -623,10 +592,110 @@ open class ModelManagerViewModel( } } + private fun processPendingDownloads() { + Log.d(TAG, "In-progress worker infos: $inProgressWorkInfos") + + // Iterate through the inProgressWorkInfos and retrieve the corresponding modes. + // Those models are the ones that have not finished downloading. + val models: MutableList = mutableListOf() + for (info in inProgressWorkInfos) { + getModelByName(info.modelName)?.let { model -> + models.add(model) + } + } + + // Cancel all pending downloads for the retrieved models. + downloadRepository.cancelAll(models) { + Log.d(TAG, "All pending work is cancelled") + + viewModelScope.launch(Dispatchers.IO) { + // Kick off downloads for these models . + withContext(Dispatchers.Main) { + val tokenStatusAndData = getTokenStatusAndData() + for (info in inProgressWorkInfos) { + val model: Model? = getModelByName(info.modelName) + if (model != null) { + if (tokenStatusAndData.status == TokenStatus.NOT_EXPIRED && tokenStatusAndData.data != null) { + model.accessToken = tokenStatusAndData.data.accessToken + } + Log.d(TAG, "Sending a new download request for '${model.name}'") + downloadRepository.downloadModel( + model, onStatusUpdated = this@ModelManagerViewModel::setDownloadStatus + ) + } + } + } + } + } + } + + fun loadModelAllowlist() { + _uiState.update { + uiState.value.copy( + loadingModelAllowlist = true, + loadingModelAllowlistError = "" + ) + } + + viewModelScope.launch(Dispatchers.IO) { + try { + // Load model allowlist json. + val modelAllowlist: ModelAllowlist? = + getJsonResponse(url = MODEL_ALLOWLIST_URL) + + if (modelAllowlist == null) { + _uiState.update { uiState.value.copy(loadingModelAllowlistError = "Failed to load model list") } + return@launch + } + + Log.d(TAG, "Allowlist: $modelAllowlist") + + // Convert models in the allowlist. + for (allowedModel in modelAllowlist.models) { + if (allowedModel.disabled == true) { + continue + } + + val model = allowedModel.toModel() + if (allowedModel.taskTypes.contains(TASK_LLM_CHAT.type.id)) { + TASK_LLM_CHAT.models.add(model) + } + if (allowedModel.taskTypes.contains(TASK_LLM_USECASES.type.id)) { + TASK_LLM_USECASES.models.add(model) + } + } + + // Pre-process all tasks. + processTasks() + + // Update UI state. + val newUiState = createUiState() + _uiState.update { + newUiState.copy( + loadingModelAllowlist = false, + ) + } + + // Process pending downloads. + processPendingDownloads() + } catch (e: Exception) { + e.printStackTrace() + } + } + } + private fun isModelPartiallyDownloaded(model: Model): Boolean { return inProgressWorkInfos.find { it.modelName == model.name } != null } + private fun createEmptyUiState(): ModelManagerUiState { + return ModelManagerUiState( + tasks = listOf(), + modelDownloadStatus = mapOf(), + modelInitializationStatus = mapOf(), + ) + } + private fun createUiState(): ModelManagerUiState { val modelDownloadStatus: MutableMap = mutableMapOf() val modelInstances: MutableMap = mutableMapOf() @@ -643,11 +712,11 @@ open class ModelManagerViewModel( Log.d(TAG, "stored imported model: $importedModel") // Create model. - val model = createModelFromImportedModelInfo(info = importedModel, task = TASK_LLM_CHAT) + val model = createModelFromImportedModelInfo(info = importedModel) // Add to task. - val task = TASK_LLM_CHAT - task.models.add(model) + TASK_LLM_CHAT.models.add(model) + TASK_LLM_USECASES.models.add(model) // Update status. modelDownloadStatus[model.name] = ModelDownloadStatus( @@ -660,6 +729,7 @@ open class ModelManagerViewModel( val textInputHistory = dataStoreRepository.readTextInputHistory() Log.d(TAG, "text input history: $textInputHistory") + Log.d(TAG, "model download status: $modelDownloadStatus") return ModelManagerUiState( tasks = TASKS, modelDownloadStatus = modelDownloadStatus, @@ -668,7 +738,7 @@ open class ModelManagerViewModel( ) } - private fun createModelFromImportedModelInfo(info: ImportedModelInfo, task: Task): Model { + private fun createModelFromImportedModelInfo(info: ImportedModelInfo): Model { val accelerators: List = (convertValueToTargetType( info.defaultValues[ConfigKey.COMPATIBLE_ACCELERATORS.label]!!, ValueType.STRING ) as String).split(",").mapNotNull { acceleratorLabel -> @@ -733,74 +803,6 @@ open class ModelManagerViewModel( ) } - suspend fun loadHfModels() { - // Update loading state shown in ui. - _uiState.update { - uiState.value.copy( - loadingHfModels = true, - ) - } - - val modelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap() - val modelInstances = uiState.value.modelInitializationStatus.toMutableMap() - try { - // Load model summaries. - val modelSummaries = - getJsonResponse>(url = "https://huggingface.co/api/models?search=$HG_COMMUNITY") - Log.d(TAG, "HF model summaries: $modelSummaries") - - // Load individual models in parallel. - if (modelSummaries != null) { - coroutineScope { - val hfModels = modelSummaries.map { summary -> - async { - val details = - getJsonResponse(url = "https://huggingface.co/api/models/${summary.modelId}") - if (details != null && details.siblings.find { it.rfilename == "app.json" } != null) { - val hfModel = - getJsonResponse(url = "https://huggingface.co/${summary.modelId}/resolve/main/app.json") - if (hfModel != null) { - hfModel.id = details.id - } - return@async hfModel - } - return@async null - } - } - - // Process loaded app.json - for (hfModel in hfModels.awaitAll()) { - if (hfModel != null) { - Log.d(TAG, "HF model: $hfModel") - val task = TASKS.find { it.type.label == hfModel.task } - val model = hfModel.toModel() - if (task != null && task.models.find { it.hfModelId == model.hfModelId } == null) { - model.preProcess() - Log.d(TAG, "AG model: $model") - task.models.add(model) - - // Add initial status and states. - modelDownloadStatus[model.name] = getModelDownloadStatus(model = model) - modelInstances[model.name] = - ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED) - } - } - } - } - } - - _uiState.update { - uiState.value.copy( - loadingHfModels = false, - modelDownloadStatus = modelDownloadStatus, - modelInitializationStatus = modelInstances - ) - } - } catch (e: Exception) { - e.printStackTrace() - } - } - private inline fun getJsonResponse(url: String): T? { try { val connection = URL(url).openConnection() as HttpURLConnection @@ -817,9 +819,10 @@ open class ModelManagerViewModel( val jsonObj = json.decodeFromString(response) return jsonObj } else { - println("HTTP error: $responseCode") + Log.e(TAG, "HTTP error: $responseCode") } } catch (e: Exception) { + Log.e(TAG, "Error when getting json response: ${e.message}") e.printStackTrace() } @@ -859,11 +862,18 @@ open class ModelManagerViewModel( } private fun isModelDownloaded(model: Model): Boolean { - val downloadedFileExists = - model.downloadFileName.isNotEmpty() && isFileInExternalFilesDir(model.downloadFileName) + val downloadedFileExists = model.downloadFileName.isNotEmpty() && isFileInExternalFilesDir( + listOf( + model.normalizedName, model.version, model.downloadFileName + ).joinToString(File.separator) + ) val unzippedDirectoryExists = - model.isZip && model.unzipDir.isNotEmpty() && isFileInExternalFilesDir(model.unzipDir) + model.isZip && model.unzipDir.isNotEmpty() && isFileInExternalFilesDir( + listOf( + model.normalizedName, model.version, model.unzipDir + ).joinToString(File.separator) + ) // Will also return true if model is partially downloaded. return downloadedFileExists || unzippedDirectoryExists diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/navigation/GalleryNavGraph.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/navigation/GalleryNavGraph.kt index 7f8464e..844500e 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/navigation/GalleryNavGraph.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/navigation/GalleryNavGraph.kt @@ -46,7 +46,7 @@ import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION import com.google.aiedge.gallery.data.TASK_LLM_CHAT -import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN +import com.google.aiedge.gallery.data.TASK_LLM_USECASES import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.TaskType @@ -231,7 +231,7 @@ fun GalleryNavHost( enterTransition = { slideEnter() }, exitTransition = { slideExit() }, ) { - getModelFromNavigationParam(it, TASK_LLM_SINGLE_TURN)?.let { defaultModel -> + getModelFromNavigationParam(it, TASK_LLM_USECASES)?.let { defaultModel -> modelManagerViewModel.selectModel(defaultModel) LlmSingleTurnScreen( @@ -271,7 +271,7 @@ fun navigateToTaskScreen( TaskType.TEXT_CLASSIFICATION -> navController.navigate("${TextClassificationDestination.route}/${modelName}") TaskType.IMAGE_CLASSIFICATION -> navController.navigate("${ImageClassificationDestination.route}/${modelName}") TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}") - TaskType.LLM_SINGLE_TURN -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}") + TaskType.LLM_USECASES -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}") TaskType.IMAGE_GENERATION -> navController.navigate("${ImageGenerationDestination.route}/${modelName}") TaskType.TEST_TASK_1 -> {} TaskType.TEST_TASK_2 -> {} diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/worker/DownloadWorker.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/worker/DownloadWorker.kt index 1f0b5d9..69f73eb 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/worker/DownloadWorker.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/worker/DownloadWorker.kt @@ -24,6 +24,7 @@ import androidx.work.WorkerParameters import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_ACCESS_TOKEN import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_ERROR_MESSAGE import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_FILE_NAME +import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_MODEL_DIR import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_RATE import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_RECEIVED_BYTES import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_REMAINING_MS @@ -34,6 +35,7 @@ import com.google.aiedge.gallery.data.KEY_MODEL_START_UNZIPPING import com.google.aiedge.gallery.data.KEY_MODEL_TOTAL_BYTES import com.google.aiedge.gallery.data.KEY_MODEL_UNZIPPED_DIR import com.google.aiedge.gallery.data.KEY_MODEL_URL +import com.google.aiedge.gallery.data.KEY_MODEL_VERSION import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext import java.io.BufferedInputStream @@ -48,7 +50,10 @@ import java.util.zip.ZipInputStream private const val TAG = "AGDownloadWorker" -data class UrlAndFileName(val url: String, val fileName: String) +data class UrlAndFileName( + val url: String, + val fileName: String, +) class DownloadWorker(context: Context, params: WorkerParameters) : CoroutineWorker(context, params) { @@ -56,7 +61,9 @@ class DownloadWorker(context: Context, params: WorkerParameters) : override suspend fun doWork(): Result { val fileUrl = inputData.getString(KEY_MODEL_URL) + val version = inputData.getString(KEY_MODEL_VERSION)!! val fileName = inputData.getString(KEY_MODEL_DOWNLOAD_FILE_NAME) + val modelDir = inputData.getString(KEY_MODEL_DOWNLOAD_MODEL_DIR)!! val isZip = inputData.getBoolean(KEY_MODEL_IS_ZIP, false) val unzippedDir = inputData.getString(KEY_MODEL_UNZIPPED_DIR) val extraDataFileUrls = inputData.getString(KEY_MODEL_EXTRA_DATA_URLS)?.split(",") ?: listOf() @@ -96,8 +103,20 @@ class DownloadWorker(context: Context, params: WorkerParameters) : connection.setRequestProperty("Authorization", "Bearer $accessToken") } + // Prepare output file's dir. + val outputDir = File( + applicationContext.getExternalFilesDir(null), + listOf(modelDir, version).joinToString(separator = File.separator) + ) + if (!outputDir.exists()) { + outputDir.mkdirs() + } + // Read the file and see if it is partially downloaded. - val outputFile = File(applicationContext.getExternalFilesDir(null), file.fileName) + val outputFile = File( + applicationContext.getExternalFilesDir(null), + listOf(modelDir, version, file.fileName).joinToString(separator = File.separator) + ) val outputFileBytes = outputFile.length() if (outputFileBytes > 0) { Log.d( @@ -192,14 +211,19 @@ class DownloadWorker(context: Context, params: WorkerParameters) : setProgress(Data.Builder().putBoolean(KEY_MODEL_START_UNZIPPING, true).build()) // Prepare target dir. - val destDir = File("${externalFilesDir}${File.separator}${unzippedDir}") + val destDir = + File( + externalFilesDir, + listOf(modelDir, version, unzippedDir).joinToString(File.separator) + ) if (!destDir.exists()) { destDir.mkdirs() } // Unzip. val unzipBuffer = ByteArray(4096) - val zipFilePath = "${externalFilesDir}${File.separator}${fileName}" + val zipFilePath = + "${externalFilesDir}${File.separator}$modelDir${File.separator}$version${File.separator}${fileName}" val zipIn = ZipInputStream(BufferedInputStream(FileInputStream(zipFilePath))) var zipEntry: ZipEntry? = zipIn.nextEntry diff --git a/Android/src/gradle/libs.versions.toml b/Android/src/gradle/libs.versions.toml index 4b056d6..66c5fc7 100644 --- a/Android/src/gradle/libs.versions.toml +++ b/Android/src/gradle/libs.versions.toml @@ -18,7 +18,7 @@ gson = "2.12.1" lifecycleProcess = "2.8.7" #noinspection GradleDependency mediapipeTasksText = "0.10.21" -mediapipeTasksGenai = "0.10.22" +mediapipeTasksGenai = "0.10.24" mediapipeTasksImageGenerator = "0.10.21" commonmark = "1.0.0-alpha02" richtext = "1.0.0-alpha02" diff --git a/Android/src/settings.gradle.kts b/Android/src/settings.gradle.kts index e574fc2..229f038 100644 --- a/Android/src/settings.gradle.kts +++ b/Android/src/settings.gradle.kts @@ -30,6 +30,7 @@ pluginManagement { dependencyResolutionManagement { repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS) repositories { + mavenLocal() google() mavenCentral() }