Add initial support for the model allowlist, and the ability to stop response generation.

This commit is contained in:
Jing Jin 2025-05-15 00:31:45 -07:00
parent f9f1d71b38
commit ef290cd7b0
28 changed files with 676 additions and 358 deletions

View file

@ -32,7 +32,7 @@
android:dataExtractionRules="@xml/data_extraction_rules"
android:fullBackupContent="@xml/backup_rules"
android:icon="@mipmap/ic_launcher"
android:label="@string/app_name"
android:label="Edge Gallery"
android:roundIcon="@mipmap/ic_launcher"
android:supportsRtl="true"
android:theme="@style/Theme.Gallery"

View file

@ -23,7 +23,6 @@ import androidx.datastore.preferences.core.Preferences
import androidx.datastore.preferences.preferencesDataStore
import com.google.aiedge.gallery.data.AppContainer
import com.google.aiedge.gallery.data.DefaultAppContainer
import com.google.aiedge.gallery.ui.common.processTasks
import com.google.aiedge.gallery.ui.theme.ThemeSettings
private val Context.dataStore: DataStore<Preferences> by preferencesDataStore(name = "app_gallery_preferences")
@ -35,9 +34,6 @@ class GalleryApplication : Application() {
override fun onCreate() {
super.onCreate()
// Process tasks.
processTasks()
container = DefaultAppContainer(this, dataStore)
// Load theme.

View file

@ -16,7 +16,6 @@
package com.google.aiedge.gallery.data
import com.google.aiedge.gallery.ui.common.ensureValidFileName
import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
import kotlinx.serialization.SerializationException
@ -27,15 +26,6 @@ import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.json.JsonDecoder
import kotlinx.serialization.json.JsonPrimitive
@Serializable
data class HfModelSummary(val modelId: String)
@Serializable
data class HfModelDetails(val id: String, val siblings: List<HfModelFile>)
@Serializable
data class HfModelFile(val rfilename: String)
@Serializable(with = ConfigValueSerializer::class)
sealed class ConfigValue {
@Serializable
@ -85,64 +75,6 @@ object ConfigValueSerializer : KSerializer<ConfigValue> {
}
}
@Serializable
data class HfModel(
var id: String = "",
val task: String,
val name: String,
val url: String = "",
val file: String = "",
val sizeInBytes: Long,
val configs: Map<String, ConfigValue>,
) {
fun toModel(): Model {
val parts = if (url.isNotEmpty()) {
url.split('/')
} else if (file.isNotEmpty()) {
listOf(file)
} else {
listOf("")
}
val fileName = ensureValidFileName("${id}_${(parts.lastOrNull() ?: "")}")
// Generate configs based on the given default values.
// val configs: List<Config> = when (task) {
// TASK_LLM_CHAT.type.label -> createLLmChatConfig(defaults = configs)
// // todo: add configs for other types.
// else -> listOf()
// }
// todo: fix when loading from models.json
val configs: List<Config> = listOf()
// Construct url.
var modelUrl = url
if (modelUrl.isEmpty() && file.isNotEmpty()) {
modelUrl = "https://huggingface.co/${id}/resolve/main/${file}?download=true"
}
// Other parameters.
val showBenchmarkButton = when (task) {
TASK_LLM_CHAT.type.label -> false
else -> true
}
val showRunAgainButton = when (task) {
TASK_LLM_CHAT.type.label -> false
else -> true
}
return Model(
hfModelId = id,
name = name,
url = modelUrl,
sizeInBytes = sizeInBytes,
downloadFileName = fileName,
configs = configs,
showBenchmarkButton = showBenchmarkButton,
showRunAgainButton = showRunAgainButton,
)
}
}
fun getIntConfigValue(configValue: ConfigValue?, default: Int): Int {
if (configValue == null) {
return default

View file

@ -18,6 +18,8 @@ package com.google.aiedge.gallery.data
// Keys used to send/receive data to Work.
const val KEY_MODEL_URL = "KEY_MODEL_URL"
const val KEY_MODEL_VERSION = "KEY_MODEL_VERSION"
const val KEY_MODEL_DOWNLOAD_MODEL_DIR = "KEY_MODEL_DOWNLOAD_MODEL_DIR"
const val KEY_MODEL_DOWNLOAD_FILE_NAME = "KEY_MODEL_DOWNLOAD_FILE_NAME"
const val KEY_MODEL_TOTAL_BYTES = "KEY_MODEL_TOTAL_BYTES"
const val KEY_MODEL_DOWNLOAD_RECEIVED_BYTES = "KEY_MODEL_DOWNLOAD_RECEIVED_BYTES"

View file

@ -37,13 +37,13 @@ import androidx.work.OutOfQuotaPolicy
import androidx.work.WorkInfo
import androidx.work.WorkManager
import androidx.work.WorkQuery
import com.google.aiedge.gallery.AppLifecycleProvider
import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.worker.DownloadWorker
import com.google.common.util.concurrent.FutureCallback
import com.google.common.util.concurrent.Futures
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.MoreExecutors
import com.google.aiedge.gallery.AppLifecycleProvider
import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.worker.DownloadWorker
import java.util.UUID
private const val TAG = "AGDownloadRepository"
@ -89,6 +89,8 @@ class DefaultDownloadRepository(
val builder = Data.Builder()
val totalBytes = model.totalBytes + model.extraDataFiles.sumOf { it.sizeInBytes }
val inputDataBuilder = builder.putString(KEY_MODEL_URL, model.url)
.putString(KEY_MODEL_VERSION, model.version)
.putString(KEY_MODEL_DOWNLOAD_MODEL_DIR, model.normalizedName)
.putString(KEY_MODEL_DOWNLOAD_FILE_NAME, model.downloadFileName)
.putBoolean(KEY_MODEL_IS_ZIP, model.isZip).putString(KEY_MODEL_UNZIPPED_DIR, model.unzipDir)
.putLong(

View file

@ -20,6 +20,7 @@ import android.content.Context
import com.google.aiedge.gallery.ui.common.chat.PromptTemplate
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
import java.io.File
data class ModelDataFile(
val name: String,
@ -33,16 +34,22 @@ enum class Accelerator(val label: String) {
}
const val IMPORTS_DIR = "__imports"
private val NORMALIZE_NAME_REGEX = Regex("[^a-zA-Z0-9]")
/** A model for a task */
data class Model(
/** The Hugging Face model ID (if applicable). */
val hfModelId: String = "",
/** The name (for display purpose) of the model. */
val name: String,
/** The name of the downloaded model file. */
/** The version of the model. */
val version: String = "_",
/**
* The name of the downloaded model file.
*
* The final file path of the downloaded model will be:
* {context.getExternalFilesDir}/{normalizedName}/{version}/{downloadFileName}
*/
val downloadFileName: String,
/** The URL to download the model from. */
@ -88,6 +95,7 @@ data class Model(
val imported: Boolean = false,
// The following fields are managed by the app. Don't need to set manually.
var normalizedName: String = "",
var instance: Any? = null,
var initializing: Boolean = false,
// TODO(jingjin): use a "queue" system to manage model init and cleanup.
@ -96,6 +104,10 @@ data class Model(
var totalBytes: Long = 0L,
var accessToken: String? = null,
) {
init {
normalizedName = NORMALIZE_NAME_REGEX.replace(name, "_")
}
fun preProcess() {
val configValues: MutableMap<String, Any> = mutableMapOf()
for (config in this.configs) {
@ -106,11 +118,22 @@ data class Model(
}
fun getPath(context: Context, fileName: String = downloadFileName): String {
val baseDir = "${context.getExternalFilesDir(null)}"
if (imported) {
return listOf(context.getExternalFilesDir(null)?.absolutePath ?: "", fileName).joinToString(
File.separator
)
}
val baseDir =
listOf(
context.getExternalFilesDir(null)?.absolutePath ?: "",
normalizedName,
version
).joinToString(File.separator)
return if (this.isZip && this.unzipDir.isNotEmpty()) {
"$baseDir/${this.unzipDir}"
} else {
"$baseDir/${fileName}"
"$baseDir/$fileName"
}
}

View file

@ -0,0 +1,116 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.data
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_ACCELERATORS
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TEMPERATURE
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TOPK
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TOPP
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
import kotlinx.serialization.Serializable
/** A model in the model allowlist. */
@Serializable
data class AllowedModel(
  val name: String,
  val modelId: String,
  val modelFile: String,
  val description: String,
  val sizeInBytes: Long,
  val version: String,
  val defaultConfig: Map<String, ConfigValue>,
  val taskTypes: List<String>,
  val disabled: Boolean? = null,
) {
  /**
   * Converts this allowlist entry into the app's [Model] representation.
   *
   * The download URL is derived from the Hugging Face model id and file name. For LLM tasks,
   * chat configs are built from [defaultConfig]; other tasks get no configs.
   */
  fun toModel(): Model {
    // Construct HF download url.
    val downloadUrl = "https://huggingface.co/$modelId/resolve/main/$modelFile?download=true"

    // Whether this model belongs to one of the LLM tasks. Reused below both for config
    // generation and for hiding the benchmark button (the original computed this twice).
    val isLlmModel =
      taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_USECASES.type.id)

    var configs: List<Config> = listOf()
    if (isLlmModel) {
      var defaultTopK: Int = DEFAULT_TOPK
      var defaultTopP: Float = DEFAULT_TOPP
      var defaultTemperature: Float = DEFAULT_TEMPERATURE
      var defaultMaxToken: Int = 1024
      var accelerators: List<Accelerator> = DEFAULT_ACCELERATORS
      if (defaultConfig.containsKey("topK")) {
        defaultTopK = getIntConfigValue(defaultConfig["topK"], defaultTopK)
      }
      if (defaultConfig.containsKey("topP")) {
        defaultTopP = getFloatConfigValue(defaultConfig["topP"], defaultTopP)
      }
      if (defaultConfig.containsKey("temperature")) {
        defaultTemperature = getFloatConfigValue(defaultConfig["temperature"], defaultTemperature)
      }
      if (defaultConfig.containsKey("maxTokens")) {
        defaultMaxToken = getIntConfigValue(defaultConfig["maxTokens"], defaultMaxToken)
      }
      if (defaultConfig.containsKey("accelerators")) {
        // Parse a comma-separated list such as "cpu,gpu"; unknown items are ignored,
        // matching the original loop's behavior.
        accelerators =
          getStringConfigValue(defaultConfig["accelerators"], "gpu").split(",").mapNotNull {
            when (it) {
              "cpu" -> Accelerator.CPU
              "gpu" -> Accelerator.GPU
              else -> null
            }
          }
      }
      configs = createLlmChatConfigs(
        defaultTopK = defaultTopK,
        defaultTopP = defaultTopP,
        defaultTemperature = defaultTemperature,
        defaultMaxToken = defaultMaxToken,
        accelerators = accelerators,
      )
    }

    // Misc. LLM models don't show the benchmark button; every model shows "run again".
    val showBenchmarkButton = !isLlmModel
    val showRunAgainButton = true

    return Model(
      name = name,
      version = version,
      info = description,
      url = downloadUrl,
      sizeInBytes = sizeInBytes,
      configs = configs,
      downloadFileName = modelFile,
      showBenchmarkButton = showBenchmarkButton,
      showRunAgainButton = showRunAgainButton,
      learnMoreUrl = "https://huggingface.co/${modelId}"
    )
  }

  // Human-readable identity used in logs: "<modelId>/<modelFile>".
  override fun toString(): String = "$modelId/$modelFile"
}
/** The model allowlist. */
@Serializable
data class ModelAllowlist(
// All models the app is allowed to list; entries with `disabled == true` are
// expected to be filtered out by the consumer of this list.
val models: List<AllowedModel>,
)

View file

@ -27,15 +27,15 @@ import androidx.compose.ui.graphics.vector.ImageVector
import com.google.aiedge.gallery.R
/** Type of task. */
enum class TaskType(val label: String) {
TEXT_CLASSIFICATION("Text Classification"),
IMAGE_CLASSIFICATION("Image Classification"),
IMAGE_GENERATION("Image Generation"),
LLM_CHAT("LLM Chat"),
LLM_SINGLE_TURN("LLM Use Cases"),
enum class TaskType(val label: String, val id: String) {
TEXT_CLASSIFICATION(label = "Text Classification", id = "text_classification"),
IMAGE_CLASSIFICATION(label = "Image Classification", id = "image_classification"),
IMAGE_GENERATION(label = "Image Generation", id = "image_generation"),
LLM_CHAT(label = "LLM Chat", id = "llm_chat"),
LLM_USECASES(label = "LLM Use Cases", id = "llm_usecases"),
TEST_TASK_1("Test task 1"),
TEST_TASK_2("Test task 2")
TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
TEST_TASK_2(label = "Test task 2", id = "test_task_2")
}
/** Data class for a task listed in home screen. */
@ -91,17 +91,19 @@ val TASK_IMAGE_CLASSIFICATION = Task(
val TASK_LLM_CHAT = Task(
type = TaskType.LLM_CHAT,
icon = Icons.Outlined.Forum,
models = MODELS_LLM,
// models = MODELS_LLM,
models = mutableListOf(),
description = "Chat with a on-device large language model",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
)
val TASK_LLM_SINGLE_TURN = Task(
type = TaskType.LLM_SINGLE_TURN,
val TASK_LLM_USECASES = Task(
type = TaskType.LLM_USECASES,
icon = Icons.Outlined.Widgets,
models = MODELS_LLM,
// models = MODELS_LLM,
models = mutableListOf(),
description = "Single turn use cases with on-device large language model",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
@ -123,7 +125,7 @@ val TASKS: List<Task> = listOf(
// TASK_TEXT_CLASSIFICATION,
// TASK_IMAGE_CLASSIFICATION,
// TASK_IMAGE_GENERATION,
TASK_LLM_SINGLE_TURN,
TASK_LLM_USECASES,
TASK_LLM_CHAT,
)

View file

@ -28,12 +28,15 @@ import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.wrapContentHeight
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.ArrowForward
import androidx.compose.material.icons.rounded.Error
import androidx.compose.material3.AlertDialog
import androidx.compose.material3.Button
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ModalBottomSheet
import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.material3.rememberModalBottomSheetState
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
@ -102,6 +105,7 @@ fun DownloadAndTryButton(
val context = LocalContext.current
var checkingToken by remember { mutableStateOf(false) }
var showAgreementAckSheet by remember { mutableStateOf(false) }
var showErrorDialog by remember { mutableStateOf(false) }
val sheetState = rememberModalBottomSheetState()
// A launcher for requesting notification permission.
@ -208,12 +212,18 @@ fun DownloadAndTryButton(
TAG,
"Model '${model.name}' is from HuggingFace. Checking if the url needs auth to download"
)
if (modelManagerViewModel.getModelUrlResponse(model = model) == HttpURLConnection.HTTP_OK) {
val firstResponseCode = modelManagerViewModel.getModelUrlResponse(model = model)
if (firstResponseCode == HttpURLConnection.HTTP_OK) {
Log.d(TAG, "Model '${model.name}' doesn't need auth. Start downloading the model...")
withContext(Dispatchers.Main) {
startDownload(null)
}
return@launch
} else if (firstResponseCode < 0) {
checkingToken = false
Log.e(TAG, "Unknown network error")
showErrorDialog = true
return@launch
}
Log.d(TAG, "Model '${model.name}' needs auth. Start token exchange process...")
@ -334,4 +344,30 @@ fun DownloadAndTryButton(
}
}
}
if (showErrorDialog) {
AlertDialog(
icon = {
Icon(Icons.Rounded.Error, contentDescription = "", tint = MaterialTheme.colorScheme.error)
},
title = {
Text("Unknown network error")
},
text = {
Text("Please check your internet connection.")
},
onDismissRequest = {
showErrorDialog = false
},
confirmButton = {
TextButton(
onClick = {
showErrorDialog = false
}
) {
Text("Close")
}
},
)
}
}

View file

@ -484,5 +484,7 @@ fun processLlmResponse(response: String): String {
}
}
newContent = newContent.replace("\\n", "\n")
return newContent
}

View file

@ -24,6 +24,7 @@ import com.google.aiedge.gallery.data.Model
enum class ChatMessageType {
INFO,
WARNING,
TEXT,
IMAGE,
IMAGE_WITH_HISTORY,
@ -57,6 +58,10 @@ class ChatMessageLoading : ChatMessage(type = ChatMessageType.LOADING, side = Ch
class ChatMessageInfo(val content: String) :
ChatMessage(type = ChatMessageType.INFO, side = ChatSide.SYSTEM)
/** Chat message for info (help). */
class ChatMessageWarning(val content: String) :
ChatMessage(type = ChatMessageType.WARNING, side = ChatSide.SYSTEM)
/** Chat message for config values change. */
class ChatMessageConfigValuesChange(
val model: Model,

View file

@ -269,6 +269,9 @@ fun ChatPanel(
// Info.
is ChatMessageInfo -> MessageBodyInfo(message = message)
// Warning
is ChatMessageWarning -> MessageBodyWarning(message = message)
// Config values change.
is ChatMessageConfigValuesChange -> MessageBodyConfigUpdate(message = message)
@ -433,6 +436,7 @@ fun ChatPanel(
modelManagerViewModel = modelManagerViewModel,
curMessage = curMessage,
inProgress = uiState.inProgress,
modelInitializing = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
textFieldPlaceHolderRes = task.textInputPlaceHolderRes,
onValueChanged = { curMessage = it },
onSendMessage = {

View file

@ -0,0 +1,62 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material3.MaterialTheme
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.ui.theme.GalleryTheme
/**
 * Renders a warning message inside the chat transcript.
 *
 * The bubble is centered horizontally, clipped to a rounded shape, and filled with the
 * tertiary container color. The message content is rendered as markdown.
 */
@Composable
fun MessageBodyWarning(message: ChatMessageWarning) {
  Row(
    modifier = Modifier.fillMaxWidth(),
    horizontalArrangement = Arrangement.Center,
  ) {
    // Rounded "bubble" background for the warning content.
    val bubbleModifier = Modifier
      .clip(RoundedCornerShape(16.dp))
      .background(MaterialTheme.colorScheme.tertiaryContainer)
    Box(modifier = bubbleModifier) {
      MarkdownText(
        text = message.content,
        modifier = Modifier.padding(12.dp),
        smallFontSize = true,
      )
    }
  }
}
// Design-time preview of a warning bubble with sample content.
@Preview(showBackground = true)
@Composable
fun MessageBodyWarningPreview() {
  val sampleMessage = ChatMessageWarning(content = "This is a warning")
  GalleryTheme {
    Row(modifier = Modifier.padding(16.dp)) {
      MessageBodyWarning(message = sampleMessage)
    }
  }
}

View file

@ -75,6 +75,7 @@ fun MessageInputText(
modelManagerViewModel: ModelManagerViewModel,
curMessage: String,
inProgress: Boolean,
modelInitializing: Boolean,
@StringRes textFieldPlaceHolderRes: Int,
onValueChanged: (String) -> Unit,
onSendMessage: (ChatMessage) -> Unit,
@ -162,6 +163,7 @@ fun MessageInputText(
Spacer(modifier = Modifier.width(8.dp))
if (inProgress && showStopButtonWhenInProgress) {
if (!modelInitializing) {
IconButton(
onClick = onStopButtonClicked,
colors = IconButtonDefaults.iconButtonColors(
@ -174,6 +176,7 @@ fun MessageInputText(
tint = MaterialTheme.colorScheme.primary
)
}
}
} // Send button. Only shown when text is not empty.
else if (curMessage.isNotEmpty()) {
IconButton(
@ -230,6 +233,7 @@ fun MessageInputTextPreview() {
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "hello",
inProgress = false,
modelInitializing = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},
onSendMessage = {},
@ -239,6 +243,7 @@ fun MessageInputTextPreview() {
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "hello",
inProgress = true,
modelInitializing = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},
onSendMessage = {},
@ -247,6 +252,7 @@ fun MessageInputTextPreview() {
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "",
inProgress = false,
modelInitializing = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},
onSendMessage = {},
@ -255,6 +261,7 @@ fun MessageInputTextPreview() {
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "",
inProgress = true,
modelInitializing = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},
onSendMessage = {},

View file

@ -39,7 +39,6 @@ import androidx.compose.material.icons.rounded.ChevronRight
import androidx.compose.material.icons.rounded.Settings
import androidx.compose.material.icons.rounded.UnfoldLess
import androidx.compose.material.icons.rounded.UnfoldMore
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.OutlinedButton

View file

@ -75,6 +75,8 @@ fun ModelNameAndStatus(
) {
Text(
model.name,
maxLines = 1,
overflow = TextOverflow.Ellipsis,
style = MaterialTheme.typography.titleMedium,
modifier = modifier,
)

View file

@ -38,6 +38,7 @@ import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.lazy.grid.GridCells
import androidx.compose.foundation.lazy.grid.GridItemSpan
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
@ -46,8 +47,11 @@ import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.outlined.NoteAdd
import androidx.compose.material.icons.filled.Add
import androidx.compose.material.icons.rounded.Error
import androidx.compose.material3.AlertDialog
import androidx.compose.material3.Card
import androidx.compose.material3.CardDefaults
import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
@ -57,6 +61,7 @@ import androidx.compose.material3.SmallFloatingActionButton
import androidx.compose.material3.SnackbarHost
import androidx.compose.material3.SnackbarHostState
import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.material3.TopAppBarDefaults
import androidx.compose.material3.rememberModalBottomSheetState
import androidx.compose.runtime.Composable
@ -180,6 +185,7 @@ fun HomeScreen(
TaskList(
tasks = tasks,
navigateToTaskScreen = navigateToTaskScreen,
loadingModelAllowlist = uiState.loadingModelAllowlist,
modifier = Modifier.fillMaxSize(),
contentPadding = innerPadding,
)
@ -285,12 +291,40 @@ fun HomeScreen(
}
}
}
if (uiState.loadingModelAllowlistError.isNotEmpty()) {
AlertDialog(
icon = {
Icon(Icons.Rounded.Error, contentDescription = "", tint = MaterialTheme.colorScheme.error)
},
title = {
Text(uiState.loadingModelAllowlistError)
},
text = {
Text("Please check your internet connection and try again later.")
},
onDismissRequest = {
modelManagerViewModel.loadModelAllowlist()
},
confirmButton = {
TextButton(
onClick = {
modelManagerViewModel.loadModelAllowlist()
}
) {
Text("Retry")
}
},
)
}
}
@Composable
private fun TaskList(
tasks: List<Task>,
navigateToTaskScreen: (Task) -> Unit,
loadingModelAllowlist: Boolean,
modifier: Modifier = Modifier,
contentPadding: PaddingValues = PaddingValues(0.dp),
) {
@ -312,6 +346,25 @@ private fun TaskList(
)
}
if (loadingModelAllowlist) {
item(key = "loading", span = { GridItemSpan(2) }) {
Row(
horizontalArrangement = Arrangement.Center,
modifier = Modifier
.fillMaxWidth()
.padding(top = 32.dp)
) {
CircularProgressIndicator(
trackColor = MaterialTheme.colorScheme.surfaceVariant,
strokeWidth = 3.dp,
modifier = Modifier
.padding(end = 8.dp)
.size(20.dp)
)
Text("Loading model list...", style = MaterialTheme.typography.bodyMedium)
}
}
} else {
// Cards.
items(tasks) { task ->
TaskCard(
@ -324,6 +377,7 @@ private fun TaskList(
.aspectRatio(1f)
)
}
}
// Bottom padding.
item(key = "bottomPadding", span = { GridItemSpan(2) }) {

View file

@ -16,11 +16,15 @@
package com.google.aiedge.gallery.ui.llmchat
import android.util.Log
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.aiedge.gallery.ui.ViewModelProvider
import com.google.aiedge.gallery.ui.common.chat.ChatMessageInfo
import com.google.aiedge.gallery.ui.common.chat.ChatMessageText
import com.google.aiedge.gallery.ui.common.chat.ChatMessageWarning
import com.google.aiedge.gallery.ui.common.chat.ChatView
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import kotlinx.serialization.Serializable
@ -40,6 +44,8 @@ fun LlmChatScreen(
factory = ViewModelProvider.Factory
),
) {
val context = LocalContext.current
ChatView(
task = viewModel.task,
viewModel = viewModel,
@ -51,25 +57,43 @@ fun LlmChatScreen(
)
if (message is ChatMessageText) {
modelManagerViewModel.addTextInputHistory(message.content)
viewModel.generateResponse(
viewModel.generateResponse(model = model, input = message.content, onError = {
viewModel.addMessage(
model = model,
input = message.content,
message = ChatMessageWarning(content = "Error occurred. Re-initializing the engine.")
)
modelManagerViewModel.initializeModel(
context = context, task = viewModel.task, model = model, force = true
)
})
}
},
onRunAgainClicked = { model, message ->
if (message is ChatMessageText) {
viewModel.runAgain(model = model, message = message)
viewModel.runAgain(model = model, message = message, onError = {
viewModel.addMessage(
model = model,
message = ChatMessageWarning(content = "Error occurred. Re-initializing the engine.")
)
modelManagerViewModel.initializeModel(
context = context, task = viewModel.task, model = model, force = true
)
})
}
},
onBenchmarkClicked = { model, message, warmUpIterations, benchmarkIterations ->
if (message is ChatMessageText) {
viewModel.benchmark(
model = model,
message = message
model = model, message = message
)
}
},
showStopButtonInInputWhenInProgress = true,
onStopButtonClicked = { model ->
viewModel.stopResponse(model = model)
},
navigateUp = navigateUp,
modifier = modifier,
)

View file

@ -16,6 +16,7 @@
package com.google.aiedge.gallery.ui.llmchat
import android.util.Log
import androidx.lifecycle.viewModelScope
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
@ -26,6 +27,7 @@ import com.google.aiedge.gallery.ui.common.chat.ChatMessageType
import com.google.aiedge.gallery.ui.common.chat.ChatSide
import com.google.aiedge.gallery.ui.common.chat.ChatViewModel
import com.google.aiedge.gallery.ui.common.chat.Stat
import kotlinx.coroutines.CoroutineExceptionHandler
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
@ -39,7 +41,7 @@ private val STATS = listOf(
)
class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
fun generateResponse(model: Model, input: String) {
fun generateResponse(model: Model, input: String, onError: () -> Unit) {
viewModelScope.launch(Dispatchers.Default) {
setInProgress(true)
@ -65,6 +67,8 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
var prefillSpeed = 0f
var decodeSpeed: Float
val start = System.currentTimeMillis()
try {
LlmChatModelHelper.runInference(
model = model,
input = input,
@ -130,10 +134,23 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
}, cleanUpListener = {
setInProgress(false)
})
} catch (e: Exception) {
setInProgress(false)
onError()
}
}
}
fun runAgain(model: Model, message: ChatMessageText) {
fun stopResponse(model: Model) {
Log.d(TAG, "Stopping response for model ${model.name}...")
viewModelScope.launch(Dispatchers.Default) {
setInProgress(false)
val instance = model.instance as LlmModelInstance
instance.session.cancelGenerateResponseAsync()
}
}
fun runAgain(model: Model, message: ChatMessageText, onError: () -> Unit) {
viewModelScope.launch(Dispatchers.Default) {
// Wait for model to be initialized.
while (model.instance == null) {
@ -147,6 +164,7 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
generateResponse(
model = model,
input = message.content,
onError = onError
)
}
}

View file

@ -20,7 +20,7 @@ import android.util.Log
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.aiedge.gallery.ui.common.chat.Stat
@ -64,7 +64,7 @@ private val STATS = listOf(
Stat(id = "latency", label = "Latency", unit = "sec")
)
open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_SINGLE_TURN) : ViewModel() {
open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_USECASES) : ViewModel() {
private val _uiState = MutableStateFlow(createUiState(task = task))
val uiState = _uiState.asStateFlow()

View file

@ -218,6 +218,7 @@ fun PromptTemplatesPanel(
.clip(MessageBubbleShape(radius = bubbleBorderRadius))
.background(MaterialTheme.customColors.agentBubbleBgColor)
.padding(16.dp)
.focusRequester(focusRequester)
)
} else {
TextField(

View file

@ -57,7 +57,7 @@ import androidx.compose.ui.platform.LocalClipboardManager
import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
import com.google.aiedge.gallery.ui.common.chat.MarkdownText
import com.google.aiedge.gallery.ui.common.chat.MessageBodyBenchmarkLlm
import com.google.aiedge.gallery.ui.common.chat.MessageBodyLoading
@ -76,7 +76,7 @@ fun ResponsePanel(
modelManagerViewModel: ModelManagerViewModel,
modifier: Modifier = Modifier,
) {
val task = TASK_LLM_SINGLE_TURN
val task = TASK_LLM_USECASES
val uiState by viewModel.uiState.collectAsState()
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val inProgress = uiState.inProgress

View file

@ -16,8 +16,6 @@
package com.google.aiedge.gallery.ui.modelmanager
import android.os.Build
import androidx.annotation.RequiresApi
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
@ -62,7 +60,6 @@ import com.google.aiedge.gallery.ui.theme.customColors
private const val TAG = "AGModelList"
/** The list of models in the model manager. */
@RequiresApi(Build.VERSION_CODES.O)
@Composable
fun ModelList(
task: Task,
@ -213,7 +210,6 @@ fun ClickableLink(
}
}
@RequiresApi(Build.VERSION_CODES.O)
@Preview(showBackground = true)
@Composable
fun ModelListPreview() {

View file

@ -30,32 +30,28 @@ import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.DataStoreRepository
import com.google.aiedge.gallery.data.DownloadRepository
import com.google.aiedge.gallery.data.EMPTY_MODEL
import com.google.aiedge.gallery.data.HfModel
import com.google.aiedge.gallery.data.HfModelDetails
import com.google.aiedge.gallery.data.HfModelSummary
import com.google.aiedge.gallery.data.IMPORTS_DIR
import com.google.aiedge.gallery.data.ImportedModelInfo
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelAllowlist
import com.google.aiedge.gallery.data.ModelDownloadStatus
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.TASKS
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.TaskType
import com.google.aiedge.gallery.data.ValueType
import com.google.aiedge.gallery.data.getModelByName
import com.google.aiedge.gallery.ui.common.AuthConfig
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
import com.google.aiedge.gallery.ui.common.processTasks
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
import com.google.aiedge.gallery.ui.textclassification.TextClassificationModelHelper
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.asStateFlow
@ -73,8 +69,9 @@ import java.net.HttpURLConnection
import java.net.URL
private const val TAG = "AGModelManagerViewModel"
private const val HG_COMMUNITY = "jinjingforevercommunity"
private const val TEXT_INPUT_HISTORY_MAX_SIZE = 50
private const val MODEL_ALLOWLIST_URL =
"https://raw.githubusercontent.com/jinjingforever/kokoro-codelab-jingjin/refs/heads/main/model_allowlist.json"
data class ModelInitializationStatus(
val status: ModelInitializationStatusType, var error: String = ""
@ -122,6 +119,14 @@ data class ModelManagerUiState(
*/
val loadingHfModels: Boolean = false,
/**
* Whether the app is loading and processing the model allowlist.
*/
val loadingModelAllowlist: Boolean = true,
/** The error message when loading the model allowlist. */
val loadingModelAllowlistError: String = "",
/**
* The currently selected model.
*/
@ -153,7 +158,7 @@ open class ModelManagerViewModel(
private val externalFilesDir = context.getExternalFilesDir(null)
private val inProgressWorkInfos: List<AGWorkInfo> =
downloadRepository.getEnqueuedOrRunningWorkInfos()
protected val _uiState = MutableStateFlow(createUiState())
protected val _uiState = MutableStateFlow(createEmptyUiState())
val uiState = _uiState.asStateFlow()
val authService = AuthorizationService(context)
@ -162,44 +167,7 @@ open class ModelManagerViewModel(
var pagerScrollState: MutableStateFlow<PagerScrollState> = MutableStateFlow(PagerScrollState())
init {
Log.d(TAG, "In-progress worker infos: $inProgressWorkInfos")
// Iterate through the inProgressWorkInfos and retrieve the corresponding modes.
// Those models are the ones that have not finished downloading.
val models: MutableList<Model> = mutableListOf()
for (info in inProgressWorkInfos) {
getModelByName(info.modelName)?.let { model ->
models.add(model)
}
}
// Cancel all pending downloads for the retrieved models.
downloadRepository.cancelAll(models) {
Log.d(TAG, "All pending work is cancelled")
viewModelScope.launch(Dispatchers.IO) {
// Load models from hg community.
loadHfModels()
Log.d(TAG, "Done loading HF models")
// Kick off downloads for these models .
withContext(Dispatchers.Main) {
val tokenStatusAndData = getTokenStatusAndData()
for (info in inProgressWorkInfos) {
val model: Model? = getModelByName(info.modelName)
if (model != null) {
if (tokenStatusAndData.status == TokenStatus.NOT_EXPIRED && tokenStatusAndData.data != null) {
model.accessToken = tokenStatusAndData.data.accessToken
}
Log.d(TAG, "Sending a new download request for '${model.name}'")
downloadRepository.downloadModel(
model, onStatusUpdated = this@ModelManagerViewModel::setDownloadStatus
)
}
}
}
}
}
loadModelAllowlist()
}
override fun onCleared() {
@ -231,12 +199,10 @@ open class ModelManagerViewModel(
}
fun deleteModel(task: Task, model: Model) {
if (model.imported) {
deleteFileFromExternalFilesDir(model.downloadFileName)
for (file in model.extraDataFiles) {
deleteFileFromExternalFilesDir(file.downloadFileName)
}
if (model.isZip && model.unzipDir.isNotEmpty()) {
deleteDirFromExternalFilesDir(model.unzipDir)
} else {
deleteDirFromExternalFilesDir(model.normalizedName)
}
// Update model download status to NotDownloaded.
@ -340,7 +306,7 @@ open class ModelManagerViewModel(
onDone = onDone,
)
TaskType.LLM_SINGLE_TURN -> LlmChatModelHelper.initialize(
TaskType.LLM_USECASES -> LlmChatModelHelper.initialize(
context = context,
model = model,
onDone = onDone,
@ -364,7 +330,7 @@ open class ModelManagerViewModel(
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model)
TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.cleanUp(model = model)
TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model)
TaskType.LLM_SINGLE_TURN -> LlmChatModelHelper.cleanUp(model = model)
TaskType.LLM_USECASES -> LlmChatModelHelper.cleanUp(model = model)
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.cleanUp(model = model)
TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {}
@ -444,6 +410,7 @@ open class ModelManagerViewModel(
}
fun getModelUrlResponse(model: Model, accessToken: String? = null): Int {
try {
val url = URL(model.url)
val connection = url.openConnection() as HttpURLConnection
if (accessToken != null) {
@ -455,22 +422,28 @@ open class ModelManagerViewModel(
// Report the result.
return connection.responseCode
} catch (e: Exception) {
Log.e(TAG, "$e")
return -1
}
}
fun addImportedLlmModel(info: ImportedModelInfo) {
Log.d(TAG, "adding imported llm model: $info")
// Create model.
val model = createModelFromImportedModelInfo(info = info)
// Remove duplicated imported model if existed.
val task = TASK_LLM_CHAT
for (task in listOf(TASK_LLM_CHAT, TASK_LLM_USECASES)) {
val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported }
if (modelIndex >= 0) {
Log.d(TAG, "duplicated imported model found in task. Removing it first")
task.models.removeAt(modelIndex)
}
// Create model.
val model = createModelFromImportedModelInfo(info = info, task = task)
task.models.add(model)
task.updateTrigger.value = System.currentTimeMillis()
}
// Add initial status and states.
val modelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap()
@ -491,10 +464,6 @@ open class ModelManagerViewModel(
modelInitializationStatus = modelInstances
)
}
task.updateTrigger.value = System.currentTimeMillis()
// Also need to update single turn task.
TASK_LLM_SINGLE_TURN.updateTrigger.value = System.currentTimeMillis()
// Add to preference storage.
val importedModels = dataStoreRepository.readImportedModels().toMutableList()
@ -623,10 +592,110 @@ open class ModelManagerViewModel(
}
}
private fun processPendingDownloads() {
Log.d(TAG, "In-progress worker infos: $inProgressWorkInfos")
// Iterate through the inProgressWorkInfos and retrieve the corresponding modes.
// Those models are the ones that have not finished downloading.
val models: MutableList<Model> = mutableListOf()
for (info in inProgressWorkInfos) {
getModelByName(info.modelName)?.let { model ->
models.add(model)
}
}
// Cancel all pending downloads for the retrieved models.
downloadRepository.cancelAll(models) {
Log.d(TAG, "All pending work is cancelled")
viewModelScope.launch(Dispatchers.IO) {
// Kick off downloads for these models .
withContext(Dispatchers.Main) {
val tokenStatusAndData = getTokenStatusAndData()
for (info in inProgressWorkInfos) {
val model: Model? = getModelByName(info.modelName)
if (model != null) {
if (tokenStatusAndData.status == TokenStatus.NOT_EXPIRED && tokenStatusAndData.data != null) {
model.accessToken = tokenStatusAndData.data.accessToken
}
Log.d(TAG, "Sending a new download request for '${model.name}'")
downloadRepository.downloadModel(
model, onStatusUpdated = this@ModelManagerViewModel::setDownloadStatus
)
}
}
}
}
}
}
fun loadModelAllowlist() {
_uiState.update {
uiState.value.copy(
loadingModelAllowlist = true,
loadingModelAllowlistError = ""
)
}
viewModelScope.launch(Dispatchers.IO) {
try {
// Load model allowlist json.
val modelAllowlist: ModelAllowlist? =
getJsonResponse<ModelAllowlist>(url = MODEL_ALLOWLIST_URL)
if (modelAllowlist == null) {
_uiState.update { uiState.value.copy(loadingModelAllowlistError = "Failed to load model list") }
return@launch
}
Log.d(TAG, "Allowlist: $modelAllowlist")
// Convert models in the allowlist.
for (allowedModel in modelAllowlist.models) {
if (allowedModel.disabled == true) {
continue
}
val model = allowedModel.toModel()
if (allowedModel.taskTypes.contains(TASK_LLM_CHAT.type.id)) {
TASK_LLM_CHAT.models.add(model)
}
if (allowedModel.taskTypes.contains(TASK_LLM_USECASES.type.id)) {
TASK_LLM_USECASES.models.add(model)
}
}
// Pre-process all tasks.
processTasks()
// Update UI state.
val newUiState = createUiState()
_uiState.update {
newUiState.copy(
loadingModelAllowlist = false,
)
}
// Process pending downloads.
processPendingDownloads()
} catch (e: Exception) {
e.printStackTrace()
}
}
}
private fun isModelPartiallyDownloaded(model: Model): Boolean {
return inProgressWorkInfos.find { it.modelName == model.name } != null
}
private fun createEmptyUiState(): ModelManagerUiState {
return ModelManagerUiState(
tasks = listOf(),
modelDownloadStatus = mapOf(),
modelInitializationStatus = mapOf(),
)
}
private fun createUiState(): ModelManagerUiState {
val modelDownloadStatus: MutableMap<String, ModelDownloadStatus> = mutableMapOf()
val modelInstances: MutableMap<String, ModelInitializationStatus> = mutableMapOf()
@ -643,11 +712,11 @@ open class ModelManagerViewModel(
Log.d(TAG, "stored imported model: $importedModel")
// Create model.
val model = createModelFromImportedModelInfo(info = importedModel, task = TASK_LLM_CHAT)
val model = createModelFromImportedModelInfo(info = importedModel)
// Add to task.
val task = TASK_LLM_CHAT
task.models.add(model)
TASK_LLM_CHAT.models.add(model)
TASK_LLM_USECASES.models.add(model)
// Update status.
modelDownloadStatus[model.name] = ModelDownloadStatus(
@ -660,6 +729,7 @@ open class ModelManagerViewModel(
val textInputHistory = dataStoreRepository.readTextInputHistory()
Log.d(TAG, "text input history: $textInputHistory")
Log.d(TAG, "model download status: $modelDownloadStatus")
return ModelManagerUiState(
tasks = TASKS,
modelDownloadStatus = modelDownloadStatus,
@ -668,7 +738,7 @@ open class ModelManagerViewModel(
)
}
private fun createModelFromImportedModelInfo(info: ImportedModelInfo, task: Task): Model {
private fun createModelFromImportedModelInfo(info: ImportedModelInfo): Model {
val accelerators: List<Accelerator> = (convertValueToTargetType(
info.defaultValues[ConfigKey.COMPATIBLE_ACCELERATORS.label]!!, ValueType.STRING
) as String).split(",").mapNotNull { acceleratorLabel ->
@ -733,74 +803,6 @@ open class ModelManagerViewModel(
)
}
suspend fun loadHfModels() {
// Update loading state shown in ui.
_uiState.update {
uiState.value.copy(
loadingHfModels = true,
)
}
val modelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap()
val modelInstances = uiState.value.modelInitializationStatus.toMutableMap()
try {
// Load model summaries.
val modelSummaries =
getJsonResponse<List<HfModelSummary>>(url = "https://huggingface.co/api/models?search=$HG_COMMUNITY")
Log.d(TAG, "HF model summaries: $modelSummaries")
// Load individual models in parallel.
if (modelSummaries != null) {
coroutineScope {
val hfModels = modelSummaries.map { summary ->
async {
val details =
getJsonResponse<HfModelDetails>(url = "https://huggingface.co/api/models/${summary.modelId}")
if (details != null && details.siblings.find { it.rfilename == "app.json" } != null) {
val hfModel =
getJsonResponse<HfModel>(url = "https://huggingface.co/${summary.modelId}/resolve/main/app.json")
if (hfModel != null) {
hfModel.id = details.id
}
return@async hfModel
}
return@async null
}
}
// Process loaded app.json
for (hfModel in hfModels.awaitAll()) {
if (hfModel != null) {
Log.d(TAG, "HF model: $hfModel")
val task = TASKS.find { it.type.label == hfModel.task }
val model = hfModel.toModel()
if (task != null && task.models.find { it.hfModelId == model.hfModelId } == null) {
model.preProcess()
Log.d(TAG, "AG model: $model")
task.models.add(model)
// Add initial status and states.
modelDownloadStatus[model.name] = getModelDownloadStatus(model = model)
modelInstances[model.name] =
ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED)
}
}
}
}
}
_uiState.update {
uiState.value.copy(
loadingHfModels = false,
modelDownloadStatus = modelDownloadStatus,
modelInitializationStatus = modelInstances
)
}
} catch (e: Exception) {
e.printStackTrace()
}
}
private inline fun <reified T> getJsonResponse(url: String): T? {
try {
val connection = URL(url).openConnection() as HttpURLConnection
@ -817,9 +819,10 @@ open class ModelManagerViewModel(
val jsonObj = json.decodeFromString<T>(response)
return jsonObj
} else {
println("HTTP error: $responseCode")
Log.e(TAG, "HTTP error: $responseCode")
}
} catch (e: Exception) {
Log.e(TAG, "Error when getting json response: ${e.message}")
e.printStackTrace()
}
@ -859,11 +862,18 @@ open class ModelManagerViewModel(
}
private fun isModelDownloaded(model: Model): Boolean {
val downloadedFileExists =
model.downloadFileName.isNotEmpty() && isFileInExternalFilesDir(model.downloadFileName)
val downloadedFileExists = model.downloadFileName.isNotEmpty() && isFileInExternalFilesDir(
listOf(
model.normalizedName, model.version, model.downloadFileName
).joinToString(File.separator)
)
val unzippedDirectoryExists =
model.isZip && model.unzipDir.isNotEmpty() && isFileInExternalFilesDir(model.unzipDir)
model.isZip && model.unzipDir.isNotEmpty() && isFileInExternalFilesDir(
listOf(
model.normalizedName, model.version, model.unzipDir
).joinToString(File.separator)
)
// Will also return true if model is partially downloaded.
return downloadedFileExists || unzippedDirectoryExists

View file

@ -46,7 +46,7 @@ import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION
import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.TaskType
@ -231,7 +231,7 @@ fun GalleryNavHost(
enterTransition = { slideEnter() },
exitTransition = { slideExit() },
) {
getModelFromNavigationParam(it, TASK_LLM_SINGLE_TURN)?.let { defaultModel ->
getModelFromNavigationParam(it, TASK_LLM_USECASES)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel)
LlmSingleTurnScreen(
@ -271,7 +271,7 @@ fun navigateToTaskScreen(
TaskType.TEXT_CLASSIFICATION -> navController.navigate("${TextClassificationDestination.route}/${modelName}")
TaskType.IMAGE_CLASSIFICATION -> navController.navigate("${ImageClassificationDestination.route}/${modelName}")
TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
TaskType.LLM_SINGLE_TURN -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
TaskType.LLM_USECASES -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
TaskType.IMAGE_GENERATION -> navController.navigate("${ImageGenerationDestination.route}/${modelName}")
TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {}

View file

@ -24,6 +24,7 @@ import androidx.work.WorkerParameters
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_ACCESS_TOKEN
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_ERROR_MESSAGE
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_FILE_NAME
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_MODEL_DIR
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_RATE
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_RECEIVED_BYTES
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_REMAINING_MS
@ -34,6 +35,7 @@ import com.google.aiedge.gallery.data.KEY_MODEL_START_UNZIPPING
import com.google.aiedge.gallery.data.KEY_MODEL_TOTAL_BYTES
import com.google.aiedge.gallery.data.KEY_MODEL_UNZIPPED_DIR
import com.google.aiedge.gallery.data.KEY_MODEL_URL
import com.google.aiedge.gallery.data.KEY_MODEL_VERSION
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import java.io.BufferedInputStream
@ -48,7 +50,10 @@ import java.util.zip.ZipInputStream
private const val TAG = "AGDownloadWorker"
data class UrlAndFileName(val url: String, val fileName: String)
data class UrlAndFileName(
val url: String,
val fileName: String,
)
class DownloadWorker(context: Context, params: WorkerParameters) :
CoroutineWorker(context, params) {
@ -56,7 +61,9 @@ class DownloadWorker(context: Context, params: WorkerParameters) :
override suspend fun doWork(): Result {
val fileUrl = inputData.getString(KEY_MODEL_URL)
val version = inputData.getString(KEY_MODEL_VERSION)!!
val fileName = inputData.getString(KEY_MODEL_DOWNLOAD_FILE_NAME)
val modelDir = inputData.getString(KEY_MODEL_DOWNLOAD_MODEL_DIR)!!
val isZip = inputData.getBoolean(KEY_MODEL_IS_ZIP, false)
val unzippedDir = inputData.getString(KEY_MODEL_UNZIPPED_DIR)
val extraDataFileUrls = inputData.getString(KEY_MODEL_EXTRA_DATA_URLS)?.split(",") ?: listOf()
@ -96,8 +103,20 @@ class DownloadWorker(context: Context, params: WorkerParameters) :
connection.setRequestProperty("Authorization", "Bearer $accessToken")
}
// Prepare output file's dir.
val outputDir = File(
applicationContext.getExternalFilesDir(null),
listOf(modelDir, version).joinToString(separator = File.separator)
)
if (!outputDir.exists()) {
outputDir.mkdirs()
}
// Read the file and see if it is partially downloaded.
val outputFile = File(applicationContext.getExternalFilesDir(null), file.fileName)
val outputFile = File(
applicationContext.getExternalFilesDir(null),
listOf(modelDir, version, file.fileName).joinToString(separator = File.separator)
)
val outputFileBytes = outputFile.length()
if (outputFileBytes > 0) {
Log.d(
@ -192,14 +211,19 @@ class DownloadWorker(context: Context, params: WorkerParameters) :
setProgress(Data.Builder().putBoolean(KEY_MODEL_START_UNZIPPING, true).build())
// Prepare target dir.
val destDir = File("${externalFilesDir}${File.separator}${unzippedDir}")
val destDir =
File(
externalFilesDir,
listOf(modelDir, version, unzippedDir).joinToString(File.separator)
)
if (!destDir.exists()) {
destDir.mkdirs()
}
// Unzip.
val unzipBuffer = ByteArray(4096)
val zipFilePath = "${externalFilesDir}${File.separator}${fileName}"
val zipFilePath =
"${externalFilesDir}${File.separator}$modelDir${File.separator}$version${File.separator}${fileName}"
val zipIn = ZipInputStream(BufferedInputStream(FileInputStream(zipFilePath)))
var zipEntry: ZipEntry? = zipIn.nextEntry

View file

@ -18,7 +18,7 @@ gson = "2.12.1"
lifecycleProcess = "2.8.7"
#noinspection GradleDependency
mediapipeTasksText = "0.10.21"
mediapipeTasksGenai = "0.10.22"
mediapipeTasksGenai = "0.10.24"
mediapipeTasksImageGenerator = "0.10.21"
commonmark = "1.0.0-alpha02"
richtext = "1.0.0-alpha02"

View file

@ -30,6 +30,7 @@ pluginManagement {
dependencyResolutionManagement {
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
repositories {
mavenLocal()
google()
mavenCentral()
}