Add initial support for model allowlist, and stop generating response.

This commit is contained in:
Jing Jin 2025-05-15 00:31:45 -07:00
parent f9f1d71b38
commit ef290cd7b0
28 changed files with 676 additions and 358 deletions

View file

@ -32,7 +32,7 @@
android:dataExtractionRules="@xml/data_extraction_rules" android:dataExtractionRules="@xml/data_extraction_rules"
android:fullBackupContent="@xml/backup_rules" android:fullBackupContent="@xml/backup_rules"
android:icon="@mipmap/ic_launcher" android:icon="@mipmap/ic_launcher"
android:label="@string/app_name" android:label="Edge Gallery"
android:roundIcon="@mipmap/ic_launcher" android:roundIcon="@mipmap/ic_launcher"
android:supportsRtl="true" android:supportsRtl="true"
android:theme="@style/Theme.Gallery" android:theme="@style/Theme.Gallery"

View file

@ -23,7 +23,6 @@ import androidx.datastore.preferences.core.Preferences
import androidx.datastore.preferences.preferencesDataStore import androidx.datastore.preferences.preferencesDataStore
import com.google.aiedge.gallery.data.AppContainer import com.google.aiedge.gallery.data.AppContainer
import com.google.aiedge.gallery.data.DefaultAppContainer import com.google.aiedge.gallery.data.DefaultAppContainer
import com.google.aiedge.gallery.ui.common.processTasks
import com.google.aiedge.gallery.ui.theme.ThemeSettings import com.google.aiedge.gallery.ui.theme.ThemeSettings
private val Context.dataStore: DataStore<Preferences> by preferencesDataStore(name = "app_gallery_preferences") private val Context.dataStore: DataStore<Preferences> by preferencesDataStore(name = "app_gallery_preferences")
@ -35,9 +34,6 @@ class GalleryApplication : Application() {
override fun onCreate() { override fun onCreate() {
super.onCreate() super.onCreate()
// Process tasks.
processTasks()
container = DefaultAppContainer(this, dataStore) container = DefaultAppContainer(this, dataStore)
// Load theme. // Load theme.

View file

@ -16,7 +16,6 @@
package com.google.aiedge.gallery.data package com.google.aiedge.gallery.data
import com.google.aiedge.gallery.ui.common.ensureValidFileName
import kotlinx.serialization.KSerializer import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable import kotlinx.serialization.Serializable
import kotlinx.serialization.SerializationException import kotlinx.serialization.SerializationException
@ -27,15 +26,6 @@ import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.json.JsonDecoder import kotlinx.serialization.json.JsonDecoder
import kotlinx.serialization.json.JsonPrimitive import kotlinx.serialization.json.JsonPrimitive
@Serializable
data class HfModelSummary(val modelId: String)
@Serializable
data class HfModelDetails(val id: String, val siblings: List<HfModelFile>)
@Serializable
data class HfModelFile(val rfilename: String)
@Serializable(with = ConfigValueSerializer::class) @Serializable(with = ConfigValueSerializer::class)
sealed class ConfigValue { sealed class ConfigValue {
@Serializable @Serializable
@ -85,64 +75,6 @@ object ConfigValueSerializer : KSerializer<ConfigValue> {
} }
} }
@Serializable
data class HfModel(
var id: String = "",
val task: String,
val name: String,
val url: String = "",
val file: String = "",
val sizeInBytes: Long,
val configs: Map<String, ConfigValue>,
) {
fun toModel(): Model {
val parts = if (url.isNotEmpty()) {
url.split('/')
} else if (file.isNotEmpty()) {
listOf(file)
} else {
listOf("")
}
val fileName = ensureValidFileName("${id}_${(parts.lastOrNull() ?: "")}")
// Generate configs based on the given default values.
// val configs: List<Config> = when (task) {
// TASK_LLM_CHAT.type.label -> createLLmChatConfig(defaults = configs)
// // todo: add configs for other types.
// else -> listOf()
// }
// todo: fix when loading from models.json
val configs: List<Config> = listOf()
// Construct url.
var modelUrl = url
if (modelUrl.isEmpty() && file.isNotEmpty()) {
modelUrl = "https://huggingface.co/${id}/resolve/main/${file}?download=true"
}
// Other parameters.
val showBenchmarkButton = when (task) {
TASK_LLM_CHAT.type.label -> false
else -> true
}
val showRunAgainButton = when (task) {
TASK_LLM_CHAT.type.label -> false
else -> true
}
return Model(
hfModelId = id,
name = name,
url = modelUrl,
sizeInBytes = sizeInBytes,
downloadFileName = fileName,
configs = configs,
showBenchmarkButton = showBenchmarkButton,
showRunAgainButton = showRunAgainButton,
)
}
}
fun getIntConfigValue(configValue: ConfigValue?, default: Int): Int { fun getIntConfigValue(configValue: ConfigValue?, default: Int): Int {
if (configValue == null) { if (configValue == null) {
return default return default

View file

@ -18,6 +18,8 @@ package com.google.aiedge.gallery.data
// Keys used to send/receive data to Work. // Keys used to send/receive data to Work.
const val KEY_MODEL_URL = "KEY_MODEL_URL" const val KEY_MODEL_URL = "KEY_MODEL_URL"
const val KEY_MODEL_VERSION = "KEY_MODEL_VERSION"
const val KEY_MODEL_DOWNLOAD_MODEL_DIR = "KEY_MODEL_DOWNLOAD_MODEL_DIR"
const val KEY_MODEL_DOWNLOAD_FILE_NAME = "KEY_MODEL_DOWNLOAD_FILE_NAME" const val KEY_MODEL_DOWNLOAD_FILE_NAME = "KEY_MODEL_DOWNLOAD_FILE_NAME"
const val KEY_MODEL_TOTAL_BYTES = "KEY_MODEL_TOTAL_BYTES" const val KEY_MODEL_TOTAL_BYTES = "KEY_MODEL_TOTAL_BYTES"
const val KEY_MODEL_DOWNLOAD_RECEIVED_BYTES = "KEY_MODEL_DOWNLOAD_RECEIVED_BYTES" const val KEY_MODEL_DOWNLOAD_RECEIVED_BYTES = "KEY_MODEL_DOWNLOAD_RECEIVED_BYTES"

View file

@ -37,13 +37,13 @@ import androidx.work.OutOfQuotaPolicy
import androidx.work.WorkInfo import androidx.work.WorkInfo
import androidx.work.WorkManager import androidx.work.WorkManager
import androidx.work.WorkQuery import androidx.work.WorkQuery
import com.google.aiedge.gallery.AppLifecycleProvider
import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.worker.DownloadWorker
import com.google.common.util.concurrent.FutureCallback import com.google.common.util.concurrent.FutureCallback
import com.google.common.util.concurrent.Futures import com.google.common.util.concurrent.Futures
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.MoreExecutors import com.google.common.util.concurrent.MoreExecutors
import com.google.aiedge.gallery.AppLifecycleProvider
import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.worker.DownloadWorker
import java.util.UUID import java.util.UUID
private const val TAG = "AGDownloadRepository" private const val TAG = "AGDownloadRepository"
@ -89,6 +89,8 @@ class DefaultDownloadRepository(
val builder = Data.Builder() val builder = Data.Builder()
val totalBytes = model.totalBytes + model.extraDataFiles.sumOf { it.sizeInBytes } val totalBytes = model.totalBytes + model.extraDataFiles.sumOf { it.sizeInBytes }
val inputDataBuilder = builder.putString(KEY_MODEL_URL, model.url) val inputDataBuilder = builder.putString(KEY_MODEL_URL, model.url)
.putString(KEY_MODEL_VERSION, model.version)
.putString(KEY_MODEL_DOWNLOAD_MODEL_DIR, model.normalizedName)
.putString(KEY_MODEL_DOWNLOAD_FILE_NAME, model.downloadFileName) .putString(KEY_MODEL_DOWNLOAD_FILE_NAME, model.downloadFileName)
.putBoolean(KEY_MODEL_IS_ZIP, model.isZip).putString(KEY_MODEL_UNZIPPED_DIR, model.unzipDir) .putBoolean(KEY_MODEL_IS_ZIP, model.isZip).putString(KEY_MODEL_UNZIPPED_DIR, model.unzipDir)
.putLong( .putLong(

View file

@ -20,6 +20,7 @@ import android.content.Context
import com.google.aiedge.gallery.ui.common.chat.PromptTemplate import com.google.aiedge.gallery.ui.common.chat.PromptTemplate
import com.google.aiedge.gallery.ui.common.convertValueToTargetType import com.google.aiedge.gallery.ui.common.convertValueToTargetType
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
import java.io.File
data class ModelDataFile( data class ModelDataFile(
val name: String, val name: String,
@ -33,16 +34,22 @@ enum class Accelerator(val label: String) {
} }
const val IMPORTS_DIR = "__imports" const val IMPORTS_DIR = "__imports"
private val NORMALIZE_NAME_REGEX = Regex("[^a-zA-Z0-9]")
/** A model for a task */ /** A model for a task */
data class Model( data class Model(
/** The Hugging Face model ID (if applicable). */
val hfModelId: String = "",
/** The name (for display purpose) of the model. */ /** The name (for display purpose) of the model. */
val name: String, val name: String,
/** The name of the downloaded model file. */ /** The version of the model. */
val version: String = "_",
/**
* The name of the downloaded model file.
*
* The final file path of the downloaded model will be:
* {context.getExternalFilesDir}/{normalizedName}/{version}/{downloadFileName}
*/
val downloadFileName: String, val downloadFileName: String,
/** The URL to download the model from. */ /** The URL to download the model from. */
@ -88,6 +95,7 @@ data class Model(
val imported: Boolean = false, val imported: Boolean = false,
// The following fields are managed by the app. Don't need to set manually. // The following fields are managed by the app. Don't need to set manually.
var normalizedName: String = "",
var instance: Any? = null, var instance: Any? = null,
var initializing: Boolean = false, var initializing: Boolean = false,
// TODO(jingjin): use a "queue" system to manage model init and cleanup. // TODO(jingjin): use a "queue" system to manage model init and cleanup.
@ -96,6 +104,10 @@ data class Model(
var totalBytes: Long = 0L, var totalBytes: Long = 0L,
var accessToken: String? = null, var accessToken: String? = null,
) { ) {
init {
normalizedName = NORMALIZE_NAME_REGEX.replace(name, "_")
}
fun preProcess() { fun preProcess() {
val configValues: MutableMap<String, Any> = mutableMapOf() val configValues: MutableMap<String, Any> = mutableMapOf()
for (config in this.configs) { for (config in this.configs) {
@ -106,11 +118,22 @@ data class Model(
} }
fun getPath(context: Context, fileName: String = downloadFileName): String { fun getPath(context: Context, fileName: String = downloadFileName): String {
val baseDir = "${context.getExternalFilesDir(null)}" if (imported) {
return listOf(context.getExternalFilesDir(null)?.absolutePath ?: "", fileName).joinToString(
File.separator
)
}
val baseDir =
listOf(
context.getExternalFilesDir(null)?.absolutePath ?: "",
normalizedName,
version
).joinToString(File.separator)
return if (this.isZip && this.unzipDir.isNotEmpty()) { return if (this.isZip && this.unzipDir.isNotEmpty()) {
"$baseDir/${this.unzipDir}" "$baseDir/${this.unzipDir}"
} else { } else {
"$baseDir/${fileName}" "$baseDir/$fileName"
} }
} }

View file

@ -0,0 +1,116 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.data
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_ACCELERATORS
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TEMPERATURE
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TOPK
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TOPP
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
import kotlinx.serialization.Serializable
/** A model in the model allowlist. */
@Serializable
data class AllowedModel(
val name: String,
val modelId: String,
val modelFile: String,
val description: String,
val sizeInBytes: Long,
val version: String,
val defaultConfig: Map<String, ConfigValue>,
val taskTypes: List<String>,
val disabled: Boolean? = null,
) {
fun toModel(): Model {
// Construct HF download url.
val downloadUrl = "https://huggingface.co/$modelId/resolve/main/$modelFile?download=true"
// Config.
val isLlmModel =
taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_USECASES.type.id)
var configs: List<Config> = listOf()
if (isLlmModel) {
var defaultTopK: Int = DEFAULT_TOPK
var defaultTopP: Float = DEFAULT_TOPP
var defaultTemperature: Float = DEFAULT_TEMPERATURE
var defaultMaxToken: Int = 1024
var accelerators: List<Accelerator> = DEFAULT_ACCELERATORS
if (defaultConfig.containsKey("topK")) {
defaultTopK = getIntConfigValue(defaultConfig["topK"], defaultTopK)
}
if (defaultConfig.containsKey("topP")) {
defaultTopP = getFloatConfigValue(defaultConfig["topP"], defaultTopP)
}
if (defaultConfig.containsKey("temperature")) {
defaultTemperature = getFloatConfigValue(defaultConfig["temperature"], defaultTemperature)
}
if (defaultConfig.containsKey("maxTokens")) {
defaultMaxToken = getIntConfigValue(defaultConfig["maxTokens"], defaultMaxToken)
}
if (defaultConfig.containsKey("accelerators")) {
val items = getStringConfigValue(defaultConfig["accelerators"], "gpu").split(",")
accelerators = mutableListOf()
for (item in items) {
if (item == "cpu") {
accelerators.add(Accelerator.CPU)
} else if (item == "gpu") {
accelerators.add(Accelerator.GPU)
}
}
}
configs = createLlmChatConfigs(
defaultTopK = defaultTopK,
defaultTopP = defaultTopP,
defaultTemperature = defaultTemperature,
defaultMaxToken = defaultMaxToken,
accelerators = accelerators,
)
}
// Misc.
var showBenchmarkButton = true
val showRunAgainButton = true
if (taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_USECASES.type.id)) {
showBenchmarkButton = false
}
return Model(
name = name,
version = version,
info = description,
url = downloadUrl,
sizeInBytes = sizeInBytes,
configs = configs,
downloadFileName = modelFile,
showBenchmarkButton = showBenchmarkButton,
showRunAgainButton = showRunAgainButton,
learnMoreUrl = "https://huggingface.co/${modelId}"
)
}
override fun toString(): String {
return "$modelId/$modelFile"
}
}
/** The model allowlist. */
@Serializable
data class ModelAllowlist(
val models: List<AllowedModel>,
)

View file

@ -27,15 +27,15 @@ import androidx.compose.ui.graphics.vector.ImageVector
import com.google.aiedge.gallery.R import com.google.aiedge.gallery.R
/** Type of task. */ /** Type of task. */
enum class TaskType(val label: String) { enum class TaskType(val label: String, val id: String) {
TEXT_CLASSIFICATION("Text Classification"), TEXT_CLASSIFICATION(label = "Text Classification", id = "text_classification"),
IMAGE_CLASSIFICATION("Image Classification"), IMAGE_CLASSIFICATION(label = "Image Classification", id = "image_classification"),
IMAGE_GENERATION("Image Generation"), IMAGE_GENERATION(label = "Image Generation", id = "image_generation"),
LLM_CHAT("LLM Chat"), LLM_CHAT(label = "LLM Chat", id = "llm_chat"),
LLM_SINGLE_TURN("LLM Use Cases"), LLM_USECASES(label = "LLM Use Cases", id = "llm_usecases"),
TEST_TASK_1("Test task 1"), TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
TEST_TASK_2("Test task 2") TEST_TASK_2(label = "Test task 2", id = "test_task_2")
} }
/** Data class for a task listed in home screen. */ /** Data class for a task listed in home screen. */
@ -91,17 +91,19 @@ val TASK_IMAGE_CLASSIFICATION = Task(
val TASK_LLM_CHAT = Task( val TASK_LLM_CHAT = Task(
type = TaskType.LLM_CHAT, type = TaskType.LLM_CHAT,
icon = Icons.Outlined.Forum, icon = Icons.Outlined.Forum,
models = MODELS_LLM, // models = MODELS_LLM,
models = mutableListOf(),
description = "Chat with a on-device large language model", description = "Chat with a on-device large language model",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt", sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
) )
val TASK_LLM_SINGLE_TURN = Task( val TASK_LLM_USECASES = Task(
type = TaskType.LLM_SINGLE_TURN, type = TaskType.LLM_USECASES,
icon = Icons.Outlined.Widgets, icon = Icons.Outlined.Widgets,
models = MODELS_LLM, // models = MODELS_LLM,
models = mutableListOf(),
description = "Single turn use cases with on-device large language model", description = "Single turn use cases with on-device large language model",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt", sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
@ -123,7 +125,7 @@ val TASKS: List<Task> = listOf(
// TASK_TEXT_CLASSIFICATION, // TASK_TEXT_CLASSIFICATION,
// TASK_IMAGE_CLASSIFICATION, // TASK_IMAGE_CLASSIFICATION,
// TASK_IMAGE_GENERATION, // TASK_IMAGE_GENERATION,
TASK_LLM_SINGLE_TURN, TASK_LLM_USECASES,
TASK_LLM_CHAT, TASK_LLM_CHAT,
) )

View file

@ -28,12 +28,15 @@ import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.wrapContentHeight import androidx.compose.foundation.layout.wrapContentHeight
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.ArrowForward import androidx.compose.material.icons.automirrored.rounded.ArrowForward
import androidx.compose.material.icons.rounded.Error
import androidx.compose.material3.AlertDialog
import androidx.compose.material3.Button import androidx.compose.material3.Button
import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ModalBottomSheet import androidx.compose.material3.ModalBottomSheet
import androidx.compose.material3.Text import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.material3.rememberModalBottomSheetState import androidx.compose.material3.rememberModalBottomSheetState
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
@ -102,6 +105,7 @@ fun DownloadAndTryButton(
val context = LocalContext.current val context = LocalContext.current
var checkingToken by remember { mutableStateOf(false) } var checkingToken by remember { mutableStateOf(false) }
var showAgreementAckSheet by remember { mutableStateOf(false) } var showAgreementAckSheet by remember { mutableStateOf(false) }
var showErrorDialog by remember { mutableStateOf(false) }
val sheetState = rememberModalBottomSheetState() val sheetState = rememberModalBottomSheetState()
// A launcher for requesting notification permission. // A launcher for requesting notification permission.
@ -208,12 +212,18 @@ fun DownloadAndTryButton(
TAG, TAG,
"Model '${model.name}' is from HuggingFace. Checking if the url needs auth to download" "Model '${model.name}' is from HuggingFace. Checking if the url needs auth to download"
) )
if (modelManagerViewModel.getModelUrlResponse(model = model) == HttpURLConnection.HTTP_OK) { val firstResponseCode = modelManagerViewModel.getModelUrlResponse(model = model)
if (firstResponseCode == HttpURLConnection.HTTP_OK) {
Log.d(TAG, "Model '${model.name}' doesn't need auth. Start downloading the model...") Log.d(TAG, "Model '${model.name}' doesn't need auth. Start downloading the model...")
withContext(Dispatchers.Main) { withContext(Dispatchers.Main) {
startDownload(null) startDownload(null)
} }
return@launch return@launch
} else if (firstResponseCode < 0) {
checkingToken = false
Log.e(TAG, "Unknown network error")
showErrorDialog = true
return@launch
} }
Log.d(TAG, "Model '${model.name}' needs auth. Start token exchange process...") Log.d(TAG, "Model '${model.name}' needs auth. Start token exchange process...")
@ -334,4 +344,30 @@ fun DownloadAndTryButton(
} }
} }
} }
if (showErrorDialog) {
AlertDialog(
icon = {
Icon(Icons.Rounded.Error, contentDescription = "", tint = MaterialTheme.colorScheme.error)
},
title = {
Text("Unknown network error")
},
text = {
Text("Please check your internet connection.")
},
onDismissRequest = {
showErrorDialog = false
},
confirmButton = {
TextButton(
onClick = {
showErrorDialog = false
}
) {
Text("Close")
}
},
)
}
} }

View file

@ -484,5 +484,7 @@ fun processLlmResponse(response: String): String {
} }
} }
newContent = newContent.replace("\\n", "\n")
return newContent return newContent
} }

View file

@ -24,6 +24,7 @@ import com.google.aiedge.gallery.data.Model
enum class ChatMessageType { enum class ChatMessageType {
INFO, INFO,
WARNING,
TEXT, TEXT,
IMAGE, IMAGE,
IMAGE_WITH_HISTORY, IMAGE_WITH_HISTORY,
@ -57,6 +58,10 @@ class ChatMessageLoading : ChatMessage(type = ChatMessageType.LOADING, side = Ch
class ChatMessageInfo(val content: String) : class ChatMessageInfo(val content: String) :
ChatMessage(type = ChatMessageType.INFO, side = ChatSide.SYSTEM) ChatMessage(type = ChatMessageType.INFO, side = ChatSide.SYSTEM)
/** Chat message for info (help). */
class ChatMessageWarning(val content: String) :
ChatMessage(type = ChatMessageType.WARNING, side = ChatSide.SYSTEM)
/** Chat message for config values change. */ /** Chat message for config values change. */
class ChatMessageConfigValuesChange( class ChatMessageConfigValuesChange(
val model: Model, val model: Model,

View file

@ -269,6 +269,9 @@ fun ChatPanel(
// Info. // Info.
is ChatMessageInfo -> MessageBodyInfo(message = message) is ChatMessageInfo -> MessageBodyInfo(message = message)
// Warning
is ChatMessageWarning -> MessageBodyWarning(message = message)
// Config values change. // Config values change.
is ChatMessageConfigValuesChange -> MessageBodyConfigUpdate(message = message) is ChatMessageConfigValuesChange -> MessageBodyConfigUpdate(message = message)
@ -433,6 +436,7 @@ fun ChatPanel(
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,
curMessage = curMessage, curMessage = curMessage,
inProgress = uiState.inProgress, inProgress = uiState.inProgress,
modelInitializing = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
textFieldPlaceHolderRes = task.textInputPlaceHolderRes, textFieldPlaceHolderRes = task.textInputPlaceHolderRes,
onValueChanged = { curMessage = it }, onValueChanged = { curMessage = it },
onSendMessage = { onSendMessage = {

View file

@ -0,0 +1,62 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material3.MaterialTheme
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.ui.theme.GalleryTheme
/**
* Composable function to display warning message content within a chat.
*
* Supports markdown.
*/
@Composable
fun MessageBodyWarning(message: ChatMessageWarning) {
Row(
modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center
) {
Box(
modifier = Modifier
.clip(RoundedCornerShape(16.dp))
.background(MaterialTheme.colorScheme.tertiaryContainer)
) {
MarkdownText(text = message.content, modifier = Modifier.padding(12.dp), smallFontSize = true)
}
}
}
@Preview(showBackground = true)
@Composable
fun MessageBodyWarningPreview() {
GalleryTheme {
Row(modifier = Modifier.padding(16.dp)) {
MessageBodyWarning(message = ChatMessageWarning(content = "This is a warning"))
}
}
}

View file

@ -75,6 +75,7 @@ fun MessageInputText(
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
curMessage: String, curMessage: String,
inProgress: Boolean, inProgress: Boolean,
modelInitializing: Boolean,
@StringRes textFieldPlaceHolderRes: Int, @StringRes textFieldPlaceHolderRes: Int,
onValueChanged: (String) -> Unit, onValueChanged: (String) -> Unit,
onSendMessage: (ChatMessage) -> Unit, onSendMessage: (ChatMessage) -> Unit,
@ -162,6 +163,7 @@ fun MessageInputText(
Spacer(modifier = Modifier.width(8.dp)) Spacer(modifier = Modifier.width(8.dp))
if (inProgress && showStopButtonWhenInProgress) { if (inProgress && showStopButtonWhenInProgress) {
if (!modelInitializing) {
IconButton( IconButton(
onClick = onStopButtonClicked, onClick = onStopButtonClicked,
colors = IconButtonDefaults.iconButtonColors( colors = IconButtonDefaults.iconButtonColors(
@ -174,6 +176,7 @@ fun MessageInputText(
tint = MaterialTheme.colorScheme.primary tint = MaterialTheme.colorScheme.primary
) )
} }
}
} // Send button. Only shown when text is not empty. } // Send button. Only shown when text is not empty.
else if (curMessage.isNotEmpty()) { else if (curMessage.isNotEmpty()) {
IconButton( IconButton(
@ -230,6 +233,7 @@ fun MessageInputTextPreview() {
modelManagerViewModel = PreviewModelManagerViewModel(context = context), modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "hello", curMessage = "hello",
inProgress = false, inProgress = false,
modelInitializing = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder, textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {}, onValueChanged = {},
onSendMessage = {}, onSendMessage = {},
@ -239,6 +243,7 @@ fun MessageInputTextPreview() {
modelManagerViewModel = PreviewModelManagerViewModel(context = context), modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "hello", curMessage = "hello",
inProgress = true, inProgress = true,
modelInitializing = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder, textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {}, onValueChanged = {},
onSendMessage = {}, onSendMessage = {},
@ -247,6 +252,7 @@ fun MessageInputTextPreview() {
modelManagerViewModel = PreviewModelManagerViewModel(context = context), modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "", curMessage = "",
inProgress = false, inProgress = false,
modelInitializing = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder, textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {}, onValueChanged = {},
onSendMessage = {}, onSendMessage = {},
@ -255,6 +261,7 @@ fun MessageInputTextPreview() {
modelManagerViewModel = PreviewModelManagerViewModel(context = context), modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "", curMessage = "",
inProgress = true, inProgress = true,
modelInitializing = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder, textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {}, onValueChanged = {},
onSendMessage = {}, onSendMessage = {},

View file

@ -39,7 +39,6 @@ import androidx.compose.material.icons.rounded.ChevronRight
import androidx.compose.material.icons.rounded.Settings import androidx.compose.material.icons.rounded.Settings
import androidx.compose.material.icons.rounded.UnfoldLess import androidx.compose.material.icons.rounded.UnfoldLess
import androidx.compose.material.icons.rounded.UnfoldMore import androidx.compose.material.icons.rounded.UnfoldMore
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton import androidx.compose.material3.IconButton
import androidx.compose.material3.OutlinedButton import androidx.compose.material3.OutlinedButton

View file

@ -75,6 +75,8 @@ fun ModelNameAndStatus(
) { ) {
Text( Text(
model.name, model.name,
maxLines = 1,
overflow = TextOverflow.Ellipsis,
style = MaterialTheme.typography.titleMedium, style = MaterialTheme.typography.titleMedium,
modifier = modifier, modifier = modifier,
) )

View file

@ -38,6 +38,7 @@ import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.lazy.grid.GridCells import androidx.compose.foundation.lazy.grid.GridCells
import androidx.compose.foundation.lazy.grid.GridItemSpan import androidx.compose.foundation.lazy.grid.GridItemSpan
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
@ -46,8 +47,11 @@ import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.outlined.NoteAdd import androidx.compose.material.icons.automirrored.outlined.NoteAdd
import androidx.compose.material.icons.filled.Add import androidx.compose.material.icons.filled.Add
import androidx.compose.material.icons.rounded.Error
import androidx.compose.material3.AlertDialog
import androidx.compose.material3.Card import androidx.compose.material3.Card
import androidx.compose.material3.CardDefaults import androidx.compose.material3.CardDefaults
import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
@ -57,6 +61,7 @@ import androidx.compose.material3.SmallFloatingActionButton
import androidx.compose.material3.SnackbarHost import androidx.compose.material3.SnackbarHost
import androidx.compose.material3.SnackbarHostState import androidx.compose.material3.SnackbarHostState
import androidx.compose.material3.Text import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.material3.TopAppBarDefaults import androidx.compose.material3.TopAppBarDefaults
import androidx.compose.material3.rememberModalBottomSheetState import androidx.compose.material3.rememberModalBottomSheetState
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
@ -180,6 +185,7 @@ fun HomeScreen(
TaskList( TaskList(
tasks = tasks, tasks = tasks,
navigateToTaskScreen = navigateToTaskScreen, navigateToTaskScreen = navigateToTaskScreen,
loadingModelAllowlist = uiState.loadingModelAllowlist,
modifier = Modifier.fillMaxSize(), modifier = Modifier.fillMaxSize(),
contentPadding = innerPadding, contentPadding = innerPadding,
) )
@ -285,12 +291,40 @@ fun HomeScreen(
} }
} }
} }
if (uiState.loadingModelAllowlistError.isNotEmpty()) {
AlertDialog(
icon = {
Icon(Icons.Rounded.Error, contentDescription = "", tint = MaterialTheme.colorScheme.error)
},
title = {
Text(uiState.loadingModelAllowlistError)
},
text = {
Text("Please check your internet connection and try again later.")
},
onDismissRequest = {
modelManagerViewModel.loadModelAllowlist()
},
confirmButton = {
TextButton(
onClick = {
modelManagerViewModel.loadModelAllowlist()
}
) {
Text("Retry")
}
},
)
}
} }
@Composable @Composable
private fun TaskList( private fun TaskList(
tasks: List<Task>, tasks: List<Task>,
navigateToTaskScreen: (Task) -> Unit, navigateToTaskScreen: (Task) -> Unit,
loadingModelAllowlist: Boolean,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
contentPadding: PaddingValues = PaddingValues(0.dp), contentPadding: PaddingValues = PaddingValues(0.dp),
) { ) {
@ -312,6 +346,25 @@ private fun TaskList(
) )
} }
if (loadingModelAllowlist) {
item(key = "loading", span = { GridItemSpan(2) }) {
Row(
horizontalArrangement = Arrangement.Center,
modifier = Modifier
.fillMaxWidth()
.padding(top = 32.dp)
) {
CircularProgressIndicator(
trackColor = MaterialTheme.colorScheme.surfaceVariant,
strokeWidth = 3.dp,
modifier = Modifier
.padding(end = 8.dp)
.size(20.dp)
)
Text("Loading model list...", style = MaterialTheme.typography.bodyMedium)
}
}
} else {
// Cards. // Cards.
items(tasks) { task -> items(tasks) { task ->
TaskCard( TaskCard(
@ -324,6 +377,7 @@ private fun TaskList(
.aspectRatio(1f) .aspectRatio(1f)
) )
} }
}
// Bottom padding. // Bottom padding.
item(key = "bottomPadding", span = { GridItemSpan(2) }) { item(key = "bottomPadding", span = { GridItemSpan(2) }) {

View file

@ -16,11 +16,15 @@
package com.google.aiedge.gallery.ui.llmchat package com.google.aiedge.gallery.ui.llmchat
import android.util.Log
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.lifecycle.viewmodel.compose.viewModel import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.aiedge.gallery.ui.ViewModelProvider import com.google.aiedge.gallery.ui.ViewModelProvider
import com.google.aiedge.gallery.ui.common.chat.ChatMessageInfo
import com.google.aiedge.gallery.ui.common.chat.ChatMessageText import com.google.aiedge.gallery.ui.common.chat.ChatMessageText
import com.google.aiedge.gallery.ui.common.chat.ChatMessageWarning
import com.google.aiedge.gallery.ui.common.chat.ChatView import com.google.aiedge.gallery.ui.common.chat.ChatView
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import kotlinx.serialization.Serializable import kotlinx.serialization.Serializable
@ -40,6 +44,8 @@ fun LlmChatScreen(
factory = ViewModelProvider.Factory factory = ViewModelProvider.Factory
), ),
) { ) {
val context = LocalContext.current
ChatView( ChatView(
task = viewModel.task, task = viewModel.task,
viewModel = viewModel, viewModel = viewModel,
@ -51,25 +57,43 @@ fun LlmChatScreen(
) )
if (message is ChatMessageText) { if (message is ChatMessageText) {
modelManagerViewModel.addTextInputHistory(message.content) modelManagerViewModel.addTextInputHistory(message.content)
viewModel.generateResponse( viewModel.generateResponse(model = model, input = message.content, onError = {
viewModel.addMessage(
model = model, model = model,
input = message.content, message = ChatMessageWarning(content = "Error occurred. Re-initializing the engine.")
) )
modelManagerViewModel.initializeModel(
context = context, task = viewModel.task, model = model, force = true
)
})
} }
}, },
onRunAgainClicked = { model, message -> onRunAgainClicked = { model, message ->
if (message is ChatMessageText) { if (message is ChatMessageText) {
viewModel.runAgain(model = model, message = message) viewModel.runAgain(model = model, message = message, onError = {
viewModel.addMessage(
model = model,
message = ChatMessageWarning(content = "Error occurred. Re-initializing the engine.")
)
modelManagerViewModel.initializeModel(
context = context, task = viewModel.task, model = model, force = true
)
})
} }
}, },
onBenchmarkClicked = { model, message, warmUpIterations, benchmarkIterations -> onBenchmarkClicked = { model, message, warmUpIterations, benchmarkIterations ->
if (message is ChatMessageText) { if (message is ChatMessageText) {
viewModel.benchmark( viewModel.benchmark(
model = model, model = model, message = message
message = message
) )
} }
}, },
showStopButtonInInputWhenInProgress = true,
onStopButtonClicked = { model ->
viewModel.stopResponse(model = model)
},
navigateUp = navigateUp, navigateUp = navigateUp,
modifier = modifier, modifier = modifier,
) )

View file

@ -16,6 +16,7 @@
package com.google.aiedge.gallery.ui.llmchat package com.google.aiedge.gallery.ui.llmchat
import android.util.Log
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_LLM_CHAT import com.google.aiedge.gallery.data.TASK_LLM_CHAT
@ -26,6 +27,7 @@ import com.google.aiedge.gallery.ui.common.chat.ChatMessageType
import com.google.aiedge.gallery.ui.common.chat.ChatSide import com.google.aiedge.gallery.ui.common.chat.ChatSide
import com.google.aiedge.gallery.ui.common.chat.ChatViewModel import com.google.aiedge.gallery.ui.common.chat.ChatViewModel
import com.google.aiedge.gallery.ui.common.chat.Stat import com.google.aiedge.gallery.ui.common.chat.Stat
import kotlinx.coroutines.CoroutineExceptionHandler
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
@ -39,7 +41,7 @@ private val STATS = listOf(
) )
class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) { class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
fun generateResponse(model: Model, input: String) { fun generateResponse(model: Model, input: String, onError: () -> Unit) {
viewModelScope.launch(Dispatchers.Default) { viewModelScope.launch(Dispatchers.Default) {
setInProgress(true) setInProgress(true)
@ -65,6 +67,8 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
var prefillSpeed = 0f var prefillSpeed = 0f
var decodeSpeed: Float var decodeSpeed: Float
val start = System.currentTimeMillis() val start = System.currentTimeMillis()
try {
LlmChatModelHelper.runInference( LlmChatModelHelper.runInference(
model = model, model = model,
input = input, input = input,
@ -130,10 +134,23 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
}, cleanUpListener = { }, cleanUpListener = {
setInProgress(false) setInProgress(false)
}) })
} catch (e: Exception) {
setInProgress(false)
onError()
}
} }
} }
fun runAgain(model: Model, message: ChatMessageText) { fun stopResponse(model: Model) {
Log.d(TAG, "Stopping response for model ${model.name}...")
viewModelScope.launch(Dispatchers.Default) {
setInProgress(false)
val instance = model.instance as LlmModelInstance
instance.session.cancelGenerateResponseAsync()
}
}
fun runAgain(model: Model, message: ChatMessageText, onError: () -> Unit) {
viewModelScope.launch(Dispatchers.Default) { viewModelScope.launch(Dispatchers.Default) {
// Wait for model to be initialized. // Wait for model to be initialized.
while (model.instance == null) { while (model.instance == null) {
@ -147,6 +164,7 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
generateResponse( generateResponse(
model = model, model = model,
input = message.content, input = message.content,
onError = onError
) )
} }
} }

View file

@ -20,7 +20,7 @@ import android.util.Log
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN import com.google.aiedge.gallery.data.TASK_LLM_USECASES
import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.aiedge.gallery.ui.common.chat.Stat import com.google.aiedge.gallery.ui.common.chat.Stat
@ -64,7 +64,7 @@ private val STATS = listOf(
Stat(id = "latency", label = "Latency", unit = "sec") Stat(id = "latency", label = "Latency", unit = "sec")
) )
open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_SINGLE_TURN) : ViewModel() { open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_USECASES) : ViewModel() {
private val _uiState = MutableStateFlow(createUiState(task = task)) private val _uiState = MutableStateFlow(createUiState(task = task))
val uiState = _uiState.asStateFlow() val uiState = _uiState.asStateFlow()

View file

@ -218,6 +218,7 @@ fun PromptTemplatesPanel(
.clip(MessageBubbleShape(radius = bubbleBorderRadius)) .clip(MessageBubbleShape(radius = bubbleBorderRadius))
.background(MaterialTheme.customColors.agentBubbleBgColor) .background(MaterialTheme.customColors.agentBubbleBgColor)
.padding(16.dp) .padding(16.dp)
.focusRequester(focusRequester)
) )
} else { } else {
TextField( TextField(

View file

@ -57,7 +57,7 @@ import androidx.compose.ui.platform.LocalClipboardManager
import androidx.compose.ui.text.AnnotatedString import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN import com.google.aiedge.gallery.data.TASK_LLM_USECASES
import com.google.aiedge.gallery.ui.common.chat.MarkdownText import com.google.aiedge.gallery.ui.common.chat.MarkdownText
import com.google.aiedge.gallery.ui.common.chat.MessageBodyBenchmarkLlm import com.google.aiedge.gallery.ui.common.chat.MessageBodyBenchmarkLlm
import com.google.aiedge.gallery.ui.common.chat.MessageBodyLoading import com.google.aiedge.gallery.ui.common.chat.MessageBodyLoading
@ -76,7 +76,7 @@ fun ResponsePanel(
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
) { ) {
val task = TASK_LLM_SINGLE_TURN val task = TASK_LLM_USECASES
val uiState by viewModel.uiState.collectAsState() val uiState by viewModel.uiState.collectAsState()
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val inProgress = uiState.inProgress val inProgress = uiState.inProgress

View file

@ -16,8 +16,6 @@
package com.google.aiedge.gallery.ui.modelmanager package com.google.aiedge.gallery.ui.modelmanager
import android.os.Build
import androidx.annotation.RequiresApi
import androidx.compose.foundation.clickable import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
@ -62,7 +60,6 @@ import com.google.aiedge.gallery.ui.theme.customColors
private const val TAG = "AGModelList" private const val TAG = "AGModelList"
/** The list of models in the model manager. */ /** The list of models in the model manager. */
@RequiresApi(Build.VERSION_CODES.O)
@Composable @Composable
fun ModelList( fun ModelList(
task: Task, task: Task,
@ -213,7 +210,6 @@ fun ClickableLink(
} }
} }
@RequiresApi(Build.VERSION_CODES.O)
@Preview(showBackground = true) @Preview(showBackground = true)
@Composable @Composable
fun ModelListPreview() { fun ModelListPreview() {

View file

@ -30,32 +30,28 @@ import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.DataStoreRepository import com.google.aiedge.gallery.data.DataStoreRepository
import com.google.aiedge.gallery.data.DownloadRepository import com.google.aiedge.gallery.data.DownloadRepository
import com.google.aiedge.gallery.data.EMPTY_MODEL import com.google.aiedge.gallery.data.EMPTY_MODEL
import com.google.aiedge.gallery.data.HfModel
import com.google.aiedge.gallery.data.HfModelDetails
import com.google.aiedge.gallery.data.HfModelSummary
import com.google.aiedge.gallery.data.IMPORTS_DIR import com.google.aiedge.gallery.data.IMPORTS_DIR
import com.google.aiedge.gallery.data.ImportedModelInfo import com.google.aiedge.gallery.data.ImportedModelInfo
import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelAllowlist
import com.google.aiedge.gallery.data.ModelDownloadStatus import com.google.aiedge.gallery.data.ModelDownloadStatus
import com.google.aiedge.gallery.data.ModelDownloadStatusType import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.TASKS import com.google.aiedge.gallery.data.TASKS
import com.google.aiedge.gallery.data.TASK_LLM_CHAT import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN import com.google.aiedge.gallery.data.TASK_LLM_USECASES
import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.TaskType import com.google.aiedge.gallery.data.TaskType
import com.google.aiedge.gallery.data.ValueType import com.google.aiedge.gallery.data.ValueType
import com.google.aiedge.gallery.data.getModelByName import com.google.aiedge.gallery.data.getModelByName
import com.google.aiedge.gallery.ui.common.AuthConfig import com.google.aiedge.gallery.ui.common.AuthConfig
import com.google.aiedge.gallery.ui.common.convertValueToTargetType import com.google.aiedge.gallery.ui.common.convertValueToTargetType
import com.google.aiedge.gallery.ui.common.processTasks
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
import com.google.aiedge.gallery.ui.textclassification.TextClassificationModelHelper import com.google.aiedge.gallery.ui.textclassification.TextClassificationModelHelper
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.asStateFlow
@ -73,8 +69,9 @@ import java.net.HttpURLConnection
import java.net.URL import java.net.URL
private const val TAG = "AGModelManagerViewModel" private const val TAG = "AGModelManagerViewModel"
private const val HG_COMMUNITY = "jinjingforevercommunity"
private const val TEXT_INPUT_HISTORY_MAX_SIZE = 50 private const val TEXT_INPUT_HISTORY_MAX_SIZE = 50
private const val MODEL_ALLOWLIST_URL =
"https://raw.githubusercontent.com/jinjingforever/kokoro-codelab-jingjin/refs/heads/main/model_allowlist.json"
data class ModelInitializationStatus( data class ModelInitializationStatus(
val status: ModelInitializationStatusType, var error: String = "" val status: ModelInitializationStatusType, var error: String = ""
@ -122,6 +119,14 @@ data class ModelManagerUiState(
*/ */
val loadingHfModels: Boolean = false, val loadingHfModels: Boolean = false,
/**
* Whether the app is loading and processing the model allowlist.
*/
val loadingModelAllowlist: Boolean = true,
/** The error message when loading the model allowlist. */
val loadingModelAllowlistError: String = "",
/** /**
* The currently selected model. * The currently selected model.
*/ */
@ -153,7 +158,7 @@ open class ModelManagerViewModel(
private val externalFilesDir = context.getExternalFilesDir(null) private val externalFilesDir = context.getExternalFilesDir(null)
private val inProgressWorkInfos: List<AGWorkInfo> = private val inProgressWorkInfos: List<AGWorkInfo> =
downloadRepository.getEnqueuedOrRunningWorkInfos() downloadRepository.getEnqueuedOrRunningWorkInfos()
protected val _uiState = MutableStateFlow(createUiState()) protected val _uiState = MutableStateFlow(createEmptyUiState())
val uiState = _uiState.asStateFlow() val uiState = _uiState.asStateFlow()
val authService = AuthorizationService(context) val authService = AuthorizationService(context)
@ -162,44 +167,7 @@ open class ModelManagerViewModel(
var pagerScrollState: MutableStateFlow<PagerScrollState> = MutableStateFlow(PagerScrollState()) var pagerScrollState: MutableStateFlow<PagerScrollState> = MutableStateFlow(PagerScrollState())
init { init {
Log.d(TAG, "In-progress worker infos: $inProgressWorkInfos") loadModelAllowlist()
// Iterate through the inProgressWorkInfos and retrieve the corresponding modes.
// Those models are the ones that have not finished downloading.
val models: MutableList<Model> = mutableListOf()
for (info in inProgressWorkInfos) {
getModelByName(info.modelName)?.let { model ->
models.add(model)
}
}
// Cancel all pending downloads for the retrieved models.
downloadRepository.cancelAll(models) {
Log.d(TAG, "All pending work is cancelled")
viewModelScope.launch(Dispatchers.IO) {
// Load models from hg community.
loadHfModels()
Log.d(TAG, "Done loading HF models")
// Kick off downloads for these models .
withContext(Dispatchers.Main) {
val tokenStatusAndData = getTokenStatusAndData()
for (info in inProgressWorkInfos) {
val model: Model? = getModelByName(info.modelName)
if (model != null) {
if (tokenStatusAndData.status == TokenStatus.NOT_EXPIRED && tokenStatusAndData.data != null) {
model.accessToken = tokenStatusAndData.data.accessToken
}
Log.d(TAG, "Sending a new download request for '${model.name}'")
downloadRepository.downloadModel(
model, onStatusUpdated = this@ModelManagerViewModel::setDownloadStatus
)
}
}
}
}
}
} }
override fun onCleared() { override fun onCleared() {
@ -231,12 +199,10 @@ open class ModelManagerViewModel(
} }
fun deleteModel(task: Task, model: Model) { fun deleteModel(task: Task, model: Model) {
if (model.imported) {
deleteFileFromExternalFilesDir(model.downloadFileName) deleteFileFromExternalFilesDir(model.downloadFileName)
for (file in model.extraDataFiles) { } else {
deleteFileFromExternalFilesDir(file.downloadFileName) deleteDirFromExternalFilesDir(model.normalizedName)
}
if (model.isZip && model.unzipDir.isNotEmpty()) {
deleteDirFromExternalFilesDir(model.unzipDir)
} }
// Update model download status to NotDownloaded. // Update model download status to NotDownloaded.
@ -340,7 +306,7 @@ open class ModelManagerViewModel(
onDone = onDone, onDone = onDone,
) )
TaskType.LLM_SINGLE_TURN -> LlmChatModelHelper.initialize( TaskType.LLM_USECASES -> LlmChatModelHelper.initialize(
context = context, context = context,
model = model, model = model,
onDone = onDone, onDone = onDone,
@ -364,7 +330,7 @@ open class ModelManagerViewModel(
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model) TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model)
TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.cleanUp(model = model) TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.cleanUp(model = model)
TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model) TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model)
TaskType.LLM_SINGLE_TURN -> LlmChatModelHelper.cleanUp(model = model) TaskType.LLM_USECASES -> LlmChatModelHelper.cleanUp(model = model)
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.cleanUp(model = model) TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.cleanUp(model = model)
TaskType.TEST_TASK_1 -> {} TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {} TaskType.TEST_TASK_2 -> {}
@ -444,6 +410,7 @@ open class ModelManagerViewModel(
} }
fun getModelUrlResponse(model: Model, accessToken: String? = null): Int { fun getModelUrlResponse(model: Model, accessToken: String? = null): Int {
try {
val url = URL(model.url) val url = URL(model.url)
val connection = url.openConnection() as HttpURLConnection val connection = url.openConnection() as HttpURLConnection
if (accessToken != null) { if (accessToken != null) {
@ -455,22 +422,28 @@ open class ModelManagerViewModel(
// Report the result. // Report the result.
return connection.responseCode return connection.responseCode
} catch (e: Exception) {
Log.e(TAG, "$e")
return -1
}
} }
fun addImportedLlmModel(info: ImportedModelInfo) { fun addImportedLlmModel(info: ImportedModelInfo) {
Log.d(TAG, "adding imported llm model: $info") Log.d(TAG, "adding imported llm model: $info")
// Create model.
val model = createModelFromImportedModelInfo(info = info)
// Remove duplicated imported model if existed. // Remove duplicated imported model if existed.
val task = TASK_LLM_CHAT for (task in listOf(TASK_LLM_CHAT, TASK_LLM_USECASES)) {
val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported } val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported }
if (modelIndex >= 0) { if (modelIndex >= 0) {
Log.d(TAG, "duplicated imported model found in task. Removing it first") Log.d(TAG, "duplicated imported model found in task. Removing it first")
task.models.removeAt(modelIndex) task.models.removeAt(modelIndex)
} }
// Create model.
val model = createModelFromImportedModelInfo(info = info, task = task)
task.models.add(model) task.models.add(model)
task.updateTrigger.value = System.currentTimeMillis()
}
// Add initial status and states. // Add initial status and states.
val modelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap() val modelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap()
@ -491,10 +464,6 @@ open class ModelManagerViewModel(
modelInitializationStatus = modelInstances modelInitializationStatus = modelInstances
) )
} }
task.updateTrigger.value = System.currentTimeMillis()
// Also need to update single turn task.
TASK_LLM_SINGLE_TURN.updateTrigger.value = System.currentTimeMillis()
// Add to preference storage. // Add to preference storage.
val importedModels = dataStoreRepository.readImportedModels().toMutableList() val importedModels = dataStoreRepository.readImportedModels().toMutableList()
@ -623,10 +592,110 @@ open class ModelManagerViewModel(
} }
} }
private fun processPendingDownloads() {
Log.d(TAG, "In-progress worker infos: $inProgressWorkInfos")
// Iterate through the inProgressWorkInfos and retrieve the corresponding modes.
// Those models are the ones that have not finished downloading.
val models: MutableList<Model> = mutableListOf()
for (info in inProgressWorkInfos) {
getModelByName(info.modelName)?.let { model ->
models.add(model)
}
}
// Cancel all pending downloads for the retrieved models.
downloadRepository.cancelAll(models) {
Log.d(TAG, "All pending work is cancelled")
viewModelScope.launch(Dispatchers.IO) {
// Kick off downloads for these models .
withContext(Dispatchers.Main) {
val tokenStatusAndData = getTokenStatusAndData()
for (info in inProgressWorkInfos) {
val model: Model? = getModelByName(info.modelName)
if (model != null) {
if (tokenStatusAndData.status == TokenStatus.NOT_EXPIRED && tokenStatusAndData.data != null) {
model.accessToken = tokenStatusAndData.data.accessToken
}
Log.d(TAG, "Sending a new download request for '${model.name}'")
downloadRepository.downloadModel(
model, onStatusUpdated = this@ModelManagerViewModel::setDownloadStatus
)
}
}
}
}
}
}
fun loadModelAllowlist() {
_uiState.update {
uiState.value.copy(
loadingModelAllowlist = true,
loadingModelAllowlistError = ""
)
}
viewModelScope.launch(Dispatchers.IO) {
try {
// Load model allowlist json.
val modelAllowlist: ModelAllowlist? =
getJsonResponse<ModelAllowlist>(url = MODEL_ALLOWLIST_URL)
if (modelAllowlist == null) {
_uiState.update { uiState.value.copy(loadingModelAllowlistError = "Failed to load model list") }
return@launch
}
Log.d(TAG, "Allowlist: $modelAllowlist")
// Convert models in the allowlist.
for (allowedModel in modelAllowlist.models) {
if (allowedModel.disabled == true) {
continue
}
val model = allowedModel.toModel()
if (allowedModel.taskTypes.contains(TASK_LLM_CHAT.type.id)) {
TASK_LLM_CHAT.models.add(model)
}
if (allowedModel.taskTypes.contains(TASK_LLM_USECASES.type.id)) {
TASK_LLM_USECASES.models.add(model)
}
}
// Pre-process all tasks.
processTasks()
// Update UI state.
val newUiState = createUiState()
_uiState.update {
newUiState.copy(
loadingModelAllowlist = false,
)
}
// Process pending downloads.
processPendingDownloads()
} catch (e: Exception) {
e.printStackTrace()
}
}
}
private fun isModelPartiallyDownloaded(model: Model): Boolean { private fun isModelPartiallyDownloaded(model: Model): Boolean {
return inProgressWorkInfos.find { it.modelName == model.name } != null return inProgressWorkInfos.find { it.modelName == model.name } != null
} }
private fun createEmptyUiState(): ModelManagerUiState {
return ModelManagerUiState(
tasks = listOf(),
modelDownloadStatus = mapOf(),
modelInitializationStatus = mapOf(),
)
}
private fun createUiState(): ModelManagerUiState { private fun createUiState(): ModelManagerUiState {
val modelDownloadStatus: MutableMap<String, ModelDownloadStatus> = mutableMapOf() val modelDownloadStatus: MutableMap<String, ModelDownloadStatus> = mutableMapOf()
val modelInstances: MutableMap<String, ModelInitializationStatus> = mutableMapOf() val modelInstances: MutableMap<String, ModelInitializationStatus> = mutableMapOf()
@ -643,11 +712,11 @@ open class ModelManagerViewModel(
Log.d(TAG, "stored imported model: $importedModel") Log.d(TAG, "stored imported model: $importedModel")
// Create model. // Create model.
val model = createModelFromImportedModelInfo(info = importedModel, task = TASK_LLM_CHAT) val model = createModelFromImportedModelInfo(info = importedModel)
// Add to task. // Add to task.
val task = TASK_LLM_CHAT TASK_LLM_CHAT.models.add(model)
task.models.add(model) TASK_LLM_USECASES.models.add(model)
// Update status. // Update status.
modelDownloadStatus[model.name] = ModelDownloadStatus( modelDownloadStatus[model.name] = ModelDownloadStatus(
@ -660,6 +729,7 @@ open class ModelManagerViewModel(
val textInputHistory = dataStoreRepository.readTextInputHistory() val textInputHistory = dataStoreRepository.readTextInputHistory()
Log.d(TAG, "text input history: $textInputHistory") Log.d(TAG, "text input history: $textInputHistory")
Log.d(TAG, "model download status: $modelDownloadStatus")
return ModelManagerUiState( return ModelManagerUiState(
tasks = TASKS, tasks = TASKS,
modelDownloadStatus = modelDownloadStatus, modelDownloadStatus = modelDownloadStatus,
@ -668,7 +738,7 @@ open class ModelManagerViewModel(
) )
} }
private fun createModelFromImportedModelInfo(info: ImportedModelInfo, task: Task): Model { private fun createModelFromImportedModelInfo(info: ImportedModelInfo): Model {
val accelerators: List<Accelerator> = (convertValueToTargetType( val accelerators: List<Accelerator> = (convertValueToTargetType(
info.defaultValues[ConfigKey.COMPATIBLE_ACCELERATORS.label]!!, ValueType.STRING info.defaultValues[ConfigKey.COMPATIBLE_ACCELERATORS.label]!!, ValueType.STRING
) as String).split(",").mapNotNull { acceleratorLabel -> ) as String).split(",").mapNotNull { acceleratorLabel ->
@ -733,74 +803,6 @@ open class ModelManagerViewModel(
) )
} }
suspend fun loadHfModels() {
// Update loading state shown in ui.
_uiState.update {
uiState.value.copy(
loadingHfModels = true,
)
}
val modelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap()
val modelInstances = uiState.value.modelInitializationStatus.toMutableMap()
try {
// Load model summaries.
val modelSummaries =
getJsonResponse<List<HfModelSummary>>(url = "https://huggingface.co/api/models?search=$HG_COMMUNITY")
Log.d(TAG, "HF model summaries: $modelSummaries")
// Load individual models in parallel.
if (modelSummaries != null) {
coroutineScope {
val hfModels = modelSummaries.map { summary ->
async {
val details =
getJsonResponse<HfModelDetails>(url = "https://huggingface.co/api/models/${summary.modelId}")
if (details != null && details.siblings.find { it.rfilename == "app.json" } != null) {
val hfModel =
getJsonResponse<HfModel>(url = "https://huggingface.co/${summary.modelId}/resolve/main/app.json")
if (hfModel != null) {
hfModel.id = details.id
}
return@async hfModel
}
return@async null
}
}
// Process loaded app.json
for (hfModel in hfModels.awaitAll()) {
if (hfModel != null) {
Log.d(TAG, "HF model: $hfModel")
val task = TASKS.find { it.type.label == hfModel.task }
val model = hfModel.toModel()
if (task != null && task.models.find { it.hfModelId == model.hfModelId } == null) {
model.preProcess()
Log.d(TAG, "AG model: $model")
task.models.add(model)
// Add initial status and states.
modelDownloadStatus[model.name] = getModelDownloadStatus(model = model)
modelInstances[model.name] =
ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED)
}
}
}
}
}
_uiState.update {
uiState.value.copy(
loadingHfModels = false,
modelDownloadStatus = modelDownloadStatus,
modelInitializationStatus = modelInstances
)
}
} catch (e: Exception) {
e.printStackTrace()
}
}
private inline fun <reified T> getJsonResponse(url: String): T? { private inline fun <reified T> getJsonResponse(url: String): T? {
try { try {
val connection = URL(url).openConnection() as HttpURLConnection val connection = URL(url).openConnection() as HttpURLConnection
@ -817,9 +819,10 @@ open class ModelManagerViewModel(
val jsonObj = json.decodeFromString<T>(response) val jsonObj = json.decodeFromString<T>(response)
return jsonObj return jsonObj
} else { } else {
println("HTTP error: $responseCode") Log.e(TAG, "HTTP error: $responseCode")
} }
} catch (e: Exception) { } catch (e: Exception) {
Log.e(TAG, "Error when getting json response: ${e.message}")
e.printStackTrace() e.printStackTrace()
} }
@ -859,11 +862,18 @@ open class ModelManagerViewModel(
} }
private fun isModelDownloaded(model: Model): Boolean { private fun isModelDownloaded(model: Model): Boolean {
val downloadedFileExists = val downloadedFileExists = model.downloadFileName.isNotEmpty() && isFileInExternalFilesDir(
model.downloadFileName.isNotEmpty() && isFileInExternalFilesDir(model.downloadFileName) listOf(
model.normalizedName, model.version, model.downloadFileName
).joinToString(File.separator)
)
val unzippedDirectoryExists = val unzippedDirectoryExists =
model.isZip && model.unzipDir.isNotEmpty() && isFileInExternalFilesDir(model.unzipDir) model.isZip && model.unzipDir.isNotEmpty() && isFileInExternalFilesDir(
listOf(
model.normalizedName, model.version, model.unzipDir
).joinToString(File.separator)
)
// Will also return true if model is partially downloaded. // Will also return true if model is partially downloaded.
return downloadedFileExists || unzippedDirectoryExists return downloadedFileExists || unzippedDirectoryExists

View file

@ -46,7 +46,7 @@ import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION
import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION
import com.google.aiedge.gallery.data.TASK_LLM_CHAT import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN import com.google.aiedge.gallery.data.TASK_LLM_USECASES
import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION
import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.TaskType import com.google.aiedge.gallery.data.TaskType
@ -231,7 +231,7 @@ fun GalleryNavHost(
enterTransition = { slideEnter() }, enterTransition = { slideEnter() },
exitTransition = { slideExit() }, exitTransition = { slideExit() },
) { ) {
getModelFromNavigationParam(it, TASK_LLM_SINGLE_TURN)?.let { defaultModel -> getModelFromNavigationParam(it, TASK_LLM_USECASES)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel) modelManagerViewModel.selectModel(defaultModel)
LlmSingleTurnScreen( LlmSingleTurnScreen(
@ -271,7 +271,7 @@ fun navigateToTaskScreen(
TaskType.TEXT_CLASSIFICATION -> navController.navigate("${TextClassificationDestination.route}/${modelName}") TaskType.TEXT_CLASSIFICATION -> navController.navigate("${TextClassificationDestination.route}/${modelName}")
TaskType.IMAGE_CLASSIFICATION -> navController.navigate("${ImageClassificationDestination.route}/${modelName}") TaskType.IMAGE_CLASSIFICATION -> navController.navigate("${ImageClassificationDestination.route}/${modelName}")
TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}") TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
TaskType.LLM_SINGLE_TURN -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}") TaskType.LLM_USECASES -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
TaskType.IMAGE_GENERATION -> navController.navigate("${ImageGenerationDestination.route}/${modelName}") TaskType.IMAGE_GENERATION -> navController.navigate("${ImageGenerationDestination.route}/${modelName}")
TaskType.TEST_TASK_1 -> {} TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {} TaskType.TEST_TASK_2 -> {}

View file

@ -24,6 +24,7 @@ import androidx.work.WorkerParameters
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_ACCESS_TOKEN import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_ACCESS_TOKEN
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_ERROR_MESSAGE import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_ERROR_MESSAGE
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_FILE_NAME import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_FILE_NAME
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_MODEL_DIR
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_RATE import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_RATE
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_RECEIVED_BYTES import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_RECEIVED_BYTES
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_REMAINING_MS import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_REMAINING_MS
@ -34,6 +35,7 @@ import com.google.aiedge.gallery.data.KEY_MODEL_START_UNZIPPING
import com.google.aiedge.gallery.data.KEY_MODEL_TOTAL_BYTES import com.google.aiedge.gallery.data.KEY_MODEL_TOTAL_BYTES
import com.google.aiedge.gallery.data.KEY_MODEL_UNZIPPED_DIR import com.google.aiedge.gallery.data.KEY_MODEL_UNZIPPED_DIR
import com.google.aiedge.gallery.data.KEY_MODEL_URL import com.google.aiedge.gallery.data.KEY_MODEL_URL
import com.google.aiedge.gallery.data.KEY_MODEL_VERSION
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext import kotlinx.coroutines.withContext
import java.io.BufferedInputStream import java.io.BufferedInputStream
@ -48,7 +50,10 @@ import java.util.zip.ZipInputStream
private const val TAG = "AGDownloadWorker" private const val TAG = "AGDownloadWorker"
data class UrlAndFileName(val url: String, val fileName: String) data class UrlAndFileName(
val url: String,
val fileName: String,
)
class DownloadWorker(context: Context, params: WorkerParameters) : class DownloadWorker(context: Context, params: WorkerParameters) :
CoroutineWorker(context, params) { CoroutineWorker(context, params) {
@ -56,7 +61,9 @@ class DownloadWorker(context: Context, params: WorkerParameters) :
override suspend fun doWork(): Result { override suspend fun doWork(): Result {
val fileUrl = inputData.getString(KEY_MODEL_URL) val fileUrl = inputData.getString(KEY_MODEL_URL)
val version = inputData.getString(KEY_MODEL_VERSION)!!
val fileName = inputData.getString(KEY_MODEL_DOWNLOAD_FILE_NAME) val fileName = inputData.getString(KEY_MODEL_DOWNLOAD_FILE_NAME)
val modelDir = inputData.getString(KEY_MODEL_DOWNLOAD_MODEL_DIR)!!
val isZip = inputData.getBoolean(KEY_MODEL_IS_ZIP, false) val isZip = inputData.getBoolean(KEY_MODEL_IS_ZIP, false)
val unzippedDir = inputData.getString(KEY_MODEL_UNZIPPED_DIR) val unzippedDir = inputData.getString(KEY_MODEL_UNZIPPED_DIR)
val extraDataFileUrls = inputData.getString(KEY_MODEL_EXTRA_DATA_URLS)?.split(",") ?: listOf() val extraDataFileUrls = inputData.getString(KEY_MODEL_EXTRA_DATA_URLS)?.split(",") ?: listOf()
@ -96,8 +103,20 @@ class DownloadWorker(context: Context, params: WorkerParameters) :
connection.setRequestProperty("Authorization", "Bearer $accessToken") connection.setRequestProperty("Authorization", "Bearer $accessToken")
} }
// Prepare output file's dir.
val outputDir = File(
applicationContext.getExternalFilesDir(null),
listOf(modelDir, version).joinToString(separator = File.separator)
)
if (!outputDir.exists()) {
outputDir.mkdirs()
}
// Read the file and see if it is partially downloaded. // Read the file and see if it is partially downloaded.
val outputFile = File(applicationContext.getExternalFilesDir(null), file.fileName) val outputFile = File(
applicationContext.getExternalFilesDir(null),
listOf(modelDir, version, file.fileName).joinToString(separator = File.separator)
)
val outputFileBytes = outputFile.length() val outputFileBytes = outputFile.length()
if (outputFileBytes > 0) { if (outputFileBytes > 0) {
Log.d( Log.d(
@ -192,14 +211,19 @@ class DownloadWorker(context: Context, params: WorkerParameters) :
setProgress(Data.Builder().putBoolean(KEY_MODEL_START_UNZIPPING, true).build()) setProgress(Data.Builder().putBoolean(KEY_MODEL_START_UNZIPPING, true).build())
// Prepare target dir. // Prepare target dir.
val destDir = File("${externalFilesDir}${File.separator}${unzippedDir}") val destDir =
File(
externalFilesDir,
listOf(modelDir, version, unzippedDir).joinToString(File.separator)
)
if (!destDir.exists()) { if (!destDir.exists()) {
destDir.mkdirs() destDir.mkdirs()
} }
// Unzip. // Unzip.
val unzipBuffer = ByteArray(4096) val unzipBuffer = ByteArray(4096)
val zipFilePath = "${externalFilesDir}${File.separator}${fileName}" val zipFilePath =
"${externalFilesDir}${File.separator}$modelDir${File.separator}$version${File.separator}${fileName}"
val zipIn = ZipInputStream(BufferedInputStream(FileInputStream(zipFilePath))) val zipIn = ZipInputStream(BufferedInputStream(FileInputStream(zipFilePath)))
var zipEntry: ZipEntry? = zipIn.nextEntry var zipEntry: ZipEntry? = zipIn.nextEntry

View file

@ -18,7 +18,7 @@ gson = "2.12.1"
lifecycleProcess = "2.8.7" lifecycleProcess = "2.8.7"
#noinspection GradleDependency #noinspection GradleDependency
mediapipeTasksText = "0.10.21" mediapipeTasksText = "0.10.21"
mediapipeTasksGenai = "0.10.22" mediapipeTasksGenai = "0.10.24"
mediapipeTasksImageGenerator = "0.10.21" mediapipeTasksImageGenerator = "0.10.21"
commonmark = "1.0.0-alpha02" commonmark = "1.0.0-alpha02"
richtext = "1.0.0-alpha02" richtext = "1.0.0-alpha02"

View file

@ -30,6 +30,7 @@ pluginManagement {
dependencyResolutionManagement { dependencyResolutionManagement {
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS) repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
repositories { repositories {
mavenLocal()
google() google()
mavenCentral() mavenCentral()
} }