Add initial support for importing local model.

This commit is contained in:
Jing Jin 2025-04-17 17:54:57 -07:00
parent b2f35a86e7
commit 29b614355e
19 changed files with 789 additions and 126 deletions

View file

@ -47,6 +47,8 @@ interface DataStoreRepository {
fun readThemeOverride(): String
fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long)
fun readAccessTokenData(): AccessTokenData?
fun saveLocalModels(localModels: List<LocalModelInfo>)
fun readLocalModels(): List<LocalModelInfo>
}
/**
@ -79,6 +81,9 @@ class DefaultDataStoreRepository(
val REFRESH_TOKEN_IV = stringPreferencesKey("refresh_token_iv")
val ACCESS_TOKEN_EXPIRES_AT = longPreferencesKey("access_token_expires_at")
// Data for all imported local models.
val LOCAL_MODELS = stringPreferencesKey("local_models")
}
private val keystoreAlias: String = "com_google_aiedge_gallery_access_token_key"
@ -155,6 +160,26 @@ class DefaultDataStoreRepository(
}
}
override fun saveLocalModels(localModels: List<LocalModelInfo>) {
runBlocking {
dataStore.edit { preferences ->
val gson = Gson()
val jsonString = gson.toJson(localModels)
preferences[PreferencesKeys.LOCAL_MODELS] = jsonString
}
}
}
override fun readLocalModels(): List<LocalModelInfo> {
return runBlocking {
val preferences = dataStore.data.first()
val infosStr = preferences[PreferencesKeys.LOCAL_MODELS] ?: "[]"
val gson = Gson()
val listType = object : TypeToken<List<LocalModelInfo>>() {}.type
gson.fromJson(infosStr, listType)
}
}
private fun getTextInputHistory(preferences: Preferences): List<String> {
val infosStr = preferences[PreferencesKeys.TEXT_INPUT_HISTORY] ?: "[]"
val gson = Gson()

View file

@ -16,6 +16,7 @@
package com.google.aiedge.gallery.data
import com.google.aiedge.gallery.ui.common.ensureValidFileName
import com.google.aiedge.gallery.ui.llmchat.createLLmChatConfig
import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
@ -103,7 +104,7 @@ data class HfModel(
} else {
listOf("")
}
val fileName = "${id}_${(parts.lastOrNull() ?: "")}".replace(Regex("[^a-zA-Z0-9._-]"), "_")
val fileName = ensureValidFileName("${id}_${(parts.lastOrNull() ?: "")}")
// Generate configs based on the given default values.
val configs: List<Config> = when (task) {

View file

@ -32,6 +32,8 @@ enum class LlmBackend {
CPU, GPU
}
const val IMPORTS_DIR = "__imports"
/** A model for a task */
data class Model(
/** The Hugging Face model ID (if applicable). */
@ -85,6 +87,9 @@ data class Model(
/** The prompt templates for the model (only for LLM). */
val llmPromptTemplates: List<PromptTemplate> = listOf(),
/** Whether the model is imported as a local model. */
val isLocalModel: Boolean = false,
// The following fields are managed by the app. Don't need to set manually.
var taskType: TaskType? = null,
var instance: Any? = null,
@ -104,10 +109,11 @@ data class Model(
}
fun getPath(context: Context, fileName: String = downloadFileName): String {
val baseDir = "${context.getExternalFilesDir(null)}"
return if (this.isZip && this.unzipDir.isNotEmpty()) {
"${context.getExternalFilesDir(null)}/${this.unzipDir}"
"$baseDir/${this.unzipDir}"
} else {
"${context.getExternalFilesDir(null)}/${fileName}"
"$baseDir/${fileName}"
}
}
@ -140,6 +146,9 @@ data class Model(
}
}
/** Data for a imported local model. */
data class LocalModelInfo(val fileName: String, val fileSize: Long)
enum class ModelDownloadStatusType {
NOT_DOWNLOADED, PARTIALLY_DOWNLOADED, IN_PROGRESS, UNZIPPING, SUCCEEDED, FAILED,
}

View file

@ -19,6 +19,8 @@ package com.google.aiedge.gallery.data
import androidx.annotation.StringRes
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.ImageSearch
import androidx.compose.runtime.MutableState
import androidx.compose.runtime.mutableStateOf
import androidx.compose.ui.graphics.vector.ImageVector
import com.google.aiedge.gallery.R
@ -63,7 +65,9 @@ data class Task(
@StringRes val textInputPlaceHolderRes: Int = R.string.chat_textinput_placeholder,
// The following fields are managed by the app. Don't need to set manually.
var index: Int = -1
var index: Int = -1,
val updateTrigger: MutableState<Long> = mutableStateOf(0)
)
val TASK_TEXT_CLASSIFICATION = Task(

View file

@ -46,6 +46,7 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.modelmanager.TokenRequestResultType
import com.google.aiedge.gallery.ui.modelmanager.TokenStatus
@ -90,6 +91,7 @@ private const val TAG = "AGDownloadAndTryButton"
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun DownloadAndTryButton(
task: Task,
model: Model,
enabled: Boolean,
needToDownloadFirst: Boolean,
@ -106,17 +108,18 @@ fun DownloadAndTryButton(
val permissionLauncher = rememberLauncherForActivityResult(
ActivityResultContracts.RequestPermission()
) {
modelManagerViewModel.downloadModel(model)
modelManagerViewModel.downloadModel(task = task, model = model)
}
// Function to kick off download.
val startDownload: (accessToken: String?) -> Unit = { accessToken ->
model.accessToken = accessToken
onClicked()
checkNotificationPermissonAndStartDownload(
checkNotificationPermissionAndStartDownload(
context = context,
launcher = permissionLauncher,
modelManagerViewModel = modelManagerViewModel,
task = task,
model = model
)
checkingToken = false

View file

@ -416,10 +416,11 @@ fun getTaskIconColor(index: Int): Color {
return MaterialTheme.customColors.taskIconColors[colorIndex]
}
fun checkNotificationPermissonAndStartDownload(
fun checkNotificationPermissionAndStartDownload(
context: Context,
launcher: ManagedActivityResultLauncher<String, Boolean>,
modelManagerViewModel: ModelManagerViewModel,
task: Task,
model: Model
) {
// Check permission
@ -428,7 +429,7 @@ fun checkNotificationPermissonAndStartDownload(
ContextCompat.checkSelfPermission(
context, Manifest.permission.POST_NOTIFICATIONS
) -> {
modelManagerViewModel.downloadModel(model)
modelManagerViewModel.downloadModel(task = task, model = model)
}
// Otherwise, ask for permission
@ -440,3 +441,14 @@ fun checkNotificationPermissonAndStartDownload(
}
}
fun ensureValidFileName(fileName: String): String {
return fileName.replace(Regex("[^a-zA-Z0-9._-]"), "_")
}
fun cleanUpMediapipeTaskErrorMessage(message: String): String {
val index = message.indexOf("=== Source Location Trace")
if (index >= 0) {
return message.substring(0, index)
}
return message
}

View file

@ -41,10 +41,13 @@ import androidx.compose.foundation.layout.wrapContentHeight
import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items
import androidx.compose.foundation.lazy.rememberLazyListState
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.outlined.Timer
import androidx.compose.material.icons.rounded.ContentCopy
import androidx.compose.material.icons.rounded.Refresh
import androidx.compose.material3.Button
import androidx.compose.material3.Card
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
@ -83,11 +86,12 @@ import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog
import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.TaskType
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatus
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.PreviewChatModel
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
@ -113,6 +117,7 @@ fun ChatPanel(
onSendMessage: (Model, ChatMessage) -> Unit,
onRunAgainClicked: (Model, ChatMessage) -> Unit,
onBenchmarkClicked: (Model, ChatMessage, warmUpIterations: Int, benchmarkIterations: Int) -> Unit,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
onStreamImageMessage: (Model, ChatMessageImage) -> Unit = { _, _ -> },
onStreamEnd: (Int) -> Unit = {},
@ -140,6 +145,8 @@ fun ChatPanel(
var showMessageLongPressedSheet by remember { mutableStateOf(false) }
val longPressedMessage: MutableState<ChatMessage?> = remember { mutableStateOf(null) }
var showErrorDialog by remember { mutableStateOf(false) }
// Keep track of the last message and last message content.
val lastMessage: MutableState<ChatMessage?> = remember { mutableStateOf(null) }
val lastMessageContent: MutableState<String> = remember { mutableStateOf("") }
@ -201,6 +208,10 @@ fun ChatPanel(
val modelInitializationStatus =
modelManagerUiState.modelInitializationStatus[selectedModel.name]
LaunchedEffect(modelInitializationStatus) {
showErrorDialog = modelInitializationStatus?.status == ModelInitializationStatusType.ERROR
}
Column(
modifier = modifier.imePadding()
) {
@ -417,7 +428,7 @@ fun ChatPanel(
// Model initialization in-progress message.
this@Column.AnimatedVisibility(
visible = modelInitializationStatus == ModelInitializationStatus.INITIALIZING,
visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
enter = scaleIn() + fadeIn(),
exit = scaleOut() + fadeOut(),
modifier = Modifier.offset(y = 12.dp)
@ -479,6 +490,47 @@ fun ChatPanel(
}
}
// Error dialog.
if (showErrorDialog) {
Dialog(
onDismissRequest = {
showErrorDialog = false
navigateUp()
},
) {
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
Column(
modifier = Modifier
.padding(20.dp),
verticalArrangement = Arrangement.spacedBy(16.dp)
) {
// Title
Text(
"Error",
style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(bottom = 8.dp)
)
// Error
Text(
modelInitializationStatus?.error ?: "",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.error,
)
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.End) {
Button(onClick = {
showErrorDialog = false
navigateUp()
}) {
Text("Close")
}
}
}
}
}
}
// Benchmark config dialog.
if (showBenchmarkConfigsDialog) {
BenchmarkConfigDialog(onDismissed = { showBenchmarkConfigsDialog = false },
@ -547,6 +599,7 @@ fun ChatPanelPreview() {
task = task,
selectedModel = TASK_TEST1.models[1],
viewModel = PreviewChatModel(context = context),
navigateUp = {},
onSendMessage = { _, _ -> },
onRunAgainClicked = { _, _ -> },
onBenchmarkClicked = { _, _, _, _ -> },

View file

@ -55,7 +55,7 @@ import com.google.aiedge.gallery.data.AppBarActionType
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.checkNotificationPermissonAndStartDownload
import com.google.aiedge.gallery.ui.common.checkNotificationPermissionAndStartDownload
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.PreviewChatModel
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
@ -104,7 +104,7 @@ fun ChatView(
val launcher = rememberLauncherForActivityResult(
ActivityResultContracts.RequestPermission()
) {
modelManagerViewModel.downloadModel(selectedModel)
modelManagerViewModel.downloadModel(task = task, model = selectedModel)
}
val handleNavigateUp = {
@ -245,10 +245,11 @@ fun ChatView(
exit = fadeOut()
) {
ModelNotDownloaded(modifier = Modifier.weight(1f), onClicked = {
checkNotificationPermissonAndStartDownload(
checkNotificationPermissionAndStartDownload(
context = context,
launcher = launcher,
modelManagerViewModel = modelManagerViewModel,
task = task,
model = curSelectedModel
)
})
@ -261,6 +262,7 @@ fun ChatView(
task = task,
selectedModel = curSelectedModel,
viewModel = viewModel,
navigateUp = navigateUp,
onSendMessage = onSendMessage,
onRunAgainClicked = onRunAgainClicked,
onBenchmarkClicked = onBenchmarkClicked,

View file

@ -35,6 +35,7 @@ import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.ChevronRight
import androidx.compose.material.icons.rounded.Settings
import androidx.compose.material.icons.rounded.UnfoldLess
import androidx.compose.material.icons.rounded.UnfoldMore
@ -68,7 +69,7 @@ import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.DownloadAndTryButton
import com.google.aiedge.gallery.ui.common.TaskIcon
import com.google.aiedge.gallery.ui.common.chat.MarkdownText
import com.google.aiedge.gallery.ui.common.checkNotificationPermissonAndStartDownload
import com.google.aiedge.gallery.ui.common.checkNotificationPermissionAndStartDownload
import com.google.aiedge.gallery.ui.common.getTaskBgColor
import com.google.aiedge.gallery.ui.common.getTaskIconColor
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
@ -113,7 +114,7 @@ fun ModelItem(
val launcher = rememberLauncherForActivityResult(
ActivityResultContracts.RequestPermission()
) {
modelManagerViewModel.downloadModel(model)
modelManagerViewModel.downloadModel(task = task, model = model)
}
var isExpanded by remember { mutableStateOf(false) }
@ -156,10 +157,11 @@ fun ModelItem(
modelManagerViewModel = modelManagerViewModel,
downloadStatus = downloadStatus,
onDownloadClicked = { model ->
checkNotificationPermissonAndStartDownload(
checkNotificationPermissionAndStartDownload(
context = context,
launcher = launcher,
modelManagerViewModel = modelManagerViewModel,
task = task,
model = model
)
},
@ -186,7 +188,9 @@ fun ModelItem(
}
} else {
Icon(
if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore,
// For local model, show ">" directly indicating users can just tap the model item to
// go into it without needing to expand it first.
if (model.isLocalModel) Icons.Rounded.ChevronRight else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore,
contentDescription = "",
tint = getTaskIconColor(task),
)
@ -237,6 +241,7 @@ fun ModelItem(
val needToDownloadFirst =
downloadStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED || downloadStatus?.status == ModelDownloadStatusType.FAILED
DownloadAndTryButton(
task = task,
model = model,
enabled = isExpanded,
needToDownloadFirst = needToDownloadFirst,
@ -266,7 +271,13 @@ fun ModelItem(
)
boxModifier = if (canExpand) {
boxModifier.clickable(
onClick = { isExpanded = !isExpanded },
onClick = {
if (!model.isLocalModel) {
isExpanded = !isExpanded
} else {
onModelClicked(model)
}
},
interactionSource = remember { MutableInteractionSource() },
indication = ripple(
bounded = true,

View file

@ -124,7 +124,7 @@ fun ModelItemActionButton(
if (showConfirmDeleteDialog) {
ConfirmDeleteModelDialog(model = model, onConfirm = {
modelManagerViewModel.deleteModel(model)
modelManagerViewModel.deleteModel(task = task, model = model)
showConfirmDeleteDialog = false
}, onDismiss = {
showConfirmDeleteDialog = false

View file

@ -48,7 +48,7 @@ class ImageClassificationInferenceResult(
//TODO: handle error.
object ImageClassificationModelHelper {
fun initialize(context: Context, model: Model, onDone: () -> Unit) {
fun initialize(context: Context, model: Model, onDone: (String) -> Unit) {
val useGpu = model.getBooleanConfigValue(key = ConfigKey.USE_GPU)
TfLiteGpu.isGpuDelegateAvailable(context).continueWith { gpuTask ->
val optionsBuilder = TfLiteInitializationOptions.builder()
@ -69,7 +69,7 @@ object ImageClassificationModelHelper {
File(model.getPath(context = context)), interpreterOption
)
model.instance = interpreter
onDone()
onDone("")
}
}
}

View file

@ -24,6 +24,7 @@ import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.ui.common.LatencyProvider
import com.google.aiedge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage
import kotlin.random.Random
private const val TAG = "AGImageGenerationModelHelper"
@ -33,12 +34,17 @@ class ImageGenerationInferenceResult(
) : LatencyProvider
object ImageGenerationModelHelper {
fun initialize(context: Context, model: Model, onDone: () -> Unit) {
val options = ImageGenerator.ImageGeneratorOptions.builder()
.setImageGeneratorModelDirectory(model.getPath(context = context))
.build()
model.instance = ImageGenerator.createFromOptions(context, options)
onDone()
fun initialize(context: Context, model: Model, onDone: (String) -> Unit) {
try {
val options = ImageGenerator.ImageGeneratorOptions.builder()
.setImageGeneratorModelDirectory(model.getPath(context = context))
.build()
model.instance = ImageGenerator.createFromOptions(context, options)
} catch (e: Exception) {
onDone(cleanUpMediapipeTaskErrorMessage(e.message ?: "Unknown error"))
return
}
onDone("")
}
fun cleanUp(model: Model) {

View file

@ -21,6 +21,7 @@ import android.util.Log
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.LlmBackend
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage
import com.google.mediapipe.tasks.genai.llminference.LlmInference
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
@ -40,7 +41,7 @@ object LlmChatModelHelper {
private val cleanUpListeners: MutableMap<String, CleanUpListener> = mutableMapOf()
fun initialize(
context: Context, model: Model, onDone: () -> Unit
context: Context, model: Model, onDone: (String) -> Unit
) {
val maxTokens =
model.getIntConfigValue(key = ConfigKey.MAX_TOKENS, defaultValue = DEFAULT_MAX_TOKEN)
@ -68,9 +69,10 @@ object LlmChatModelHelper {
)
model.instance = LlmModelInstance(engine = llmInference, session = session)
} catch (e: Exception) {
e.printStackTrace()
onDone(cleanUpMediapipeTaskErrorMessage(e.message ?: "Unknown error"))
return
}
onDone()
onDone("")
}
fun cleanUp(model: Model) {

View file

@ -0,0 +1,260 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.modelmanager
import android.content.Context
import android.net.Uri
import android.provider.OpenableColumns
import android.util.Log
import androidx.compose.animation.core.Animatable
import androidx.compose.animation.core.tween
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.Error
import androidx.compose.material3.Button
import androidx.compose.material3.Card
import androidx.compose.material3.Icon
import androidx.compose.material3.LinearProgressIndicator
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableFloatStateOf
import androidx.compose.runtime.mutableLongStateOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog
import androidx.compose.ui.window.DialogProperties
import com.google.aiedge.gallery.data.IMPORTS_DIR
import com.google.aiedge.gallery.ui.common.ensureValidFileName
import com.google.aiedge.gallery.ui.common.humanReadableSize
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import java.io.File
import java.io.FileOutputStream
import java.net.URLDecoder
import java.nio.charset.StandardCharsets
private const val TAG = "AGModelImportDialog"
data class ModelImportInfo(val fileName: String, val fileSize: Long, val error: String = "")
@Composable
fun ModelImportDialog(
uri: Uri, onDone: (ModelImportInfo) -> Unit
) {
val context = LocalContext.current
val coroutineScope = rememberCoroutineScope()
var fileName by remember { mutableStateOf("") }
var fileSize by remember { mutableLongStateOf(0L) }
var error by remember { mutableStateOf("") }
var progress by remember { mutableFloatStateOf(0f) }
LaunchedEffect(Unit) {
error = ""
// Get basic info.
val info = getFileSizeAndDisplayNameFromUri(context = context, uri = uri)
fileSize = info.first
fileName = ensureValidFileName(info.second)
// Import.
importModel(
context = context,
coroutineScope = coroutineScope,
fileName = fileName,
fileSize = fileSize,
uri = uri,
onDone = {
onDone(ModelImportInfo(fileName = fileName, fileSize = fileSize, error = error))
},
onProgress = {
progress = it
},
onError = {
error = it
}
)
}
Dialog(
properties = DialogProperties(dismissOnBackPress = false, dismissOnClickOutside = false),
onDismissRequest = {},
) {
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
Column(
modifier = Modifier
.padding(20.dp),
verticalArrangement = Arrangement.spacedBy(16.dp)
) {
// Title.
Text(
"Importing...",
style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(bottom = 8.dp)
)
// No error.
if (error.isEmpty()) {
// Progress bar.
Column(verticalArrangement = Arrangement.spacedBy(4.dp)) {
Text(
"$fileName (${fileSize.humanReadableSize()})",
style = MaterialTheme.typography.labelSmall,
)
val animatedProgress = remember { Animatable(0f) }
LinearProgressIndicator(
progress = { animatedProgress.value },
modifier = Modifier
.fillMaxWidth()
.padding(bottom = 8.dp),
)
LaunchedEffect(progress) {
animatedProgress.animateTo(progress, animationSpec = tween(150))
}
}
}
// Has error.
else {
Row(
verticalAlignment = Alignment.Top,
horizontalArrangement = Arrangement.spacedBy(6.dp)
) {
Icon(
Icons.Rounded.Error,
contentDescription = "",
tint = MaterialTheme.colorScheme.error
)
Text(
error,
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.error,
modifier = Modifier.padding(top = 4.dp)
)
}
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.End) {
Button(onClick = {
onDone(ModelImportInfo(fileName = "", fileSize = 0L, error = error))
}) {
Text("Close")
}
}
}
}
}
}
}
private fun importModel(
context: Context,
coroutineScope: CoroutineScope,
fileName: String,
fileSize: Long,
uri: Uri,
onDone: () -> Unit,
onProgress: (Float) -> Unit,
onError: (String) -> Unit,
) {
// TODO: handle error.
coroutineScope.launch(Dispatchers.IO) {
// Get the last component of the uri path as the imported file name.
val decodedUri = URLDecoder.decode(uri.toString(), StandardCharsets.UTF_8.name())
Log.d(TAG, "importing model from $decodedUri. File name: $fileName. File size: $fileSize")
// Create <app_external_dir>/imports if not exist.
val importsDir = File(context.getExternalFilesDir(null), IMPORTS_DIR)
if (!importsDir.exists()) {
importsDir.mkdirs()
}
// Import by copying the file over.
val outputFile = File(context.getExternalFilesDir(null), "$IMPORTS_DIR/$fileName")
val outputStream = FileOutputStream(outputFile)
val buffer = ByteArray(DEFAULT_BUFFER_SIZE)
var bytesRead: Int
var lastSetProgressTs: Long = 0
var importedBytes = 0L
val inputStream = context.contentResolver.openInputStream(uri)
try {
if (inputStream != null) {
while (inputStream.read(buffer).also { bytesRead = it } != -1) {
outputStream.write(buffer, 0, bytesRead)
importedBytes += bytesRead
// Report progress every 200 ms.
val curTs = System.currentTimeMillis()
if (curTs - lastSetProgressTs > 200) {
Log.d(TAG, "importing progress: $importedBytes, $fileSize")
lastSetProgressTs = curTs
if (fileSize != 0L) {
onProgress(importedBytes.toFloat() / fileSize.toFloat())
}
}
}
}
} catch (e: Exception) {
e.printStackTrace()
onError(e.message ?: "Failed to import")
return@launch
} finally {
inputStream?.close()
outputStream.close()
}
Log.d(TAG, "import done")
onProgress(1f)
onDone()
}
}
private fun getFileSizeAndDisplayNameFromUri(context: Context, uri: Uri): Pair<Long, String> {
val contentResolver = context.contentResolver
var fileSize = 0L
var displayName = ""
try {
contentResolver.query(
uri, arrayOf(OpenableColumns.SIZE, OpenableColumns.DISPLAY_NAME), null, null, null
)?.use { cursor ->
if (cursor.moveToFirst()) {
val sizeIndex = cursor.getColumnIndexOrThrow(OpenableColumns.SIZE)
fileSize = cursor.getLong(sizeIndex)
val nameIndex = cursor.getColumnIndexOrThrow(OpenableColumns.DISPLAY_NAME)
displayName = cursor.getString(nameIndex)
}
}
} catch (e: Exception) {
e.printStackTrace()
return Pair(0L, "")
}
return Pair(fileSize, displayName)
}

View file

@ -16,24 +16,46 @@
package com.google.aiedge.gallery.ui.modelmanager
import android.content.Intent
import android.net.Uri
import android.os.Build
import android.util.Log
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.ActivityResultLauncher
import androidx.activity.result.contract.ActivityResultContracts
import androidx.annotation.RequiresApi
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.PaddingValues
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.outlined.NoteAdd
import androidx.compose.material.icons.filled.Add
import androidx.compose.material.icons.outlined.Code
import androidx.compose.material.icons.outlined.Description
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ModalBottomSheet
import androidx.compose.material3.SmallFloatingActionButton
import androidx.compose.material3.Text
import androidx.compose.material3.rememberModalBottomSheetState
import androidx.compose.runtime.Composable
import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.vector.ImageVector
@ -53,8 +75,14 @@ import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.customColors
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
private const val TAG = "AGModelList"
/** The list of models in the model manager. */
@RequiresApi(Build.VERSION_CODES.O)
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun ModelList(
task: Task,
@ -63,65 +91,214 @@ fun ModelList(
onModelClicked: (Model) -> Unit,
modifier: Modifier = Modifier,
) {
LazyColumn(
modifier = modifier.padding(top = 8.dp),
contentPadding = contentPadding,
verticalArrangement = Arrangement.spacedBy(8.dp),
) {
// Headline.
item(key = "headline") {
Text(
task.description,
textAlign = TextAlign.Center,
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
modifier = Modifier
.fillMaxWidth()
)
}
var showAddModelSheet by remember { mutableStateOf(false) }
var showImportingDialog by remember { mutableStateOf(false) }
val curFileUri = remember { mutableStateOf<Uri?>(null) }
val sheetState = rememberModalBottomSheetState()
val coroutineScope = rememberCoroutineScope()
// URLs.
item(key = "urls") {
Row(
horizontalArrangement = Arrangement.Center,
modifier = Modifier
.fillMaxWidth()
.padding(top = 12.dp, bottom = 16.dp),
) {
Column(
horizontalAlignment = Alignment.Start,
verticalArrangement = Arrangement.spacedBy(4.dp),
// This is just to update "models" list when task.updateTrigger is updated so that the UI can
// be properly updated.
val models by remember {
derivedStateOf {
val trigger = task.updateTrigger.value
if (trigger >= 0) {
task.models.toList().filter { !it.isLocalModel }
} else {
listOf()
}
}
}
val localModels by remember {
derivedStateOf {
val trigger = task.updateTrigger.value
if (trigger >= 0) {
task.models.toList().filter { it.isLocalModel }
} else {
listOf()
}
}
}
val filePickerLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult(
contract = ActivityResultContracts.StartActivityForResult()
) { result ->
if (result.resultCode == android.app.Activity.RESULT_OK) {
result.data?.data?.let { uri ->
curFileUri.value = uri
showImportingDialog = true
} ?: run {
Log.d(TAG, "No file selected or URI is null.")
}
} else {
Log.d(TAG, "File picking cancelled.")
}
}
Box(contentAlignment = Alignment.BottomEnd) {
LazyColumn(
modifier = modifier.padding(top = 8.dp),
contentPadding = contentPadding,
verticalArrangement = Arrangement.spacedBy(8.dp),
) {
// Headline.
item(key = "headline") {
Text(
task.description,
textAlign = TextAlign.Center,
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
modifier = Modifier.fillMaxWidth()
)
}
// URLs.
item(key = "urls") {
Row(
horizontalArrangement = Arrangement.Center,
modifier = Modifier
.fillMaxWidth()
.padding(top = 12.dp, bottom = 16.dp),
) {
if (task.docUrl.isNotEmpty()) {
ClickableLink(
url = task.docUrl,
linkText = "API Documentation",
icon = Icons.Outlined.Description
)
}
if (task.sourceCodeUrl.isNotEmpty()) {
ClickableLink(
url = task.sourceCodeUrl,
linkText = "Example code",
icon = Icons.Outlined.Code
)
Column(
horizontalAlignment = Alignment.Start,
verticalArrangement = Arrangement.spacedBy(4.dp),
) {
if (task.docUrl.isNotEmpty()) {
ClickableLink(
url = task.docUrl, linkText = "API Documentation", icon = Icons.Outlined.Description
)
}
if (task.sourceCodeUrl.isNotEmpty()) {
ClickableLink(
url = task.sourceCodeUrl, linkText = "Example code", icon = Icons.Outlined.Code
)
}
}
}
}
// List of models within a task.
items(items = models) { model ->
Box {
ModelItem(
model = model,
task = task,
modelManagerViewModel = modelManagerViewModel,
onModelClicked = onModelClicked,
modifier = Modifier.padding(horizontal = 12.dp)
)
}
}
// Title for local models.
if (localModels.isNotEmpty()) {
item(key = "localModelsTitle") {
Text(
"Local models",
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
modifier = Modifier
.padding(horizontal = 16.dp)
.padding(top = 24.dp)
)
}
}
// List of local models within a task.
items(items = localModels) { model ->
Box {
ModelItem(
model = model,
task = task,
modelManagerViewModel = modelManagerViewModel,
onModelClicked = onModelClicked,
modifier = Modifier.padding(horizontal = 12.dp)
)
}
}
item(key = "bottomPadding") {
Spacer(modifier = Modifier.height(60.dp))
}
}
// List of models within a task.
items(items = task.models) { model ->
Box {
ModelItem(
model = model,
task = task,
modelManagerViewModel = modelManagerViewModel,
onModelClicked = onModelClicked,
modifier = Modifier.padding(start = 12.dp, end = 12.dp)
)
// Add model button at the bottom right.
Box(
modifier = Modifier
.padding(end = 16.dp)
.padding(bottom = contentPadding.calculateBottomPadding())
) {
SmallFloatingActionButton(
onClick = {
showAddModelSheet = true
},
containerColor = MaterialTheme.colorScheme.secondaryContainer,
contentColor = MaterialTheme.colorScheme.secondary,
) {
Icon(Icons.Filled.Add, "")
}
}
}
if (showAddModelSheet) {
ModalBottomSheet(
onDismissRequest = { showAddModelSheet = false },
sheetState = sheetState,
) {
Text(
"Add custom model",
style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(vertical = 4.dp, horizontal = 16.dp)
)
Box(modifier = Modifier.clickable {
coroutineScope.launch {
// Give it sometime to show the click effect.
delay(200)
showAddModelSheet = false
// Show file picker.
val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
addCategory(Intent.CATEGORY_OPENABLE)
type = "*/*"
putExtra(
Intent.EXTRA_MIME_TYPES,
arrayOf("application/x-binary", "application/octet-stream")
)
// Single select.
putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
}
filePickerLauncher.launch(intent)
}
}) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp),
modifier = Modifier
.fillMaxWidth()
.padding(16.dp)
) {
Icon(Icons.AutoMirrored.Outlined.NoteAdd, contentDescription = "")
Text("Add local model")
}
}
}
}
if (showImportingDialog) {
curFileUri.value?.let { uri ->
ModelImportDialog(uri = uri, onDone = { info ->
showImportingDialog = false
if (info.error.isEmpty()) {
// TODO: support other model types.
modelManagerViewModel.addLocalLlmModel(
task = task,
fileName = info.fileName,
fileSize = info.fileSize
)
}
})
}
}
}
@Composable
@ -132,15 +309,11 @@ fun ClickableLink(
) {
val uriHandler = LocalUriHandler.current
val annotatedText = AnnotatedString(
text = linkText,
spanStyles = listOf(
text = linkText, spanStyles = listOf(
AnnotatedString.Range(
item = SpanStyle(
color = MaterialTheme.customColors.linkColor,
textDecoration = TextDecoration.Underline
),
start = 0,
end = linkText.length
color = MaterialTheme.customColors.linkColor, textDecoration = TextDecoration.Underline
), start = 0, end = linkText.length
)
)
)
@ -163,6 +336,7 @@ fun ClickableLink(
}
}
@RequiresApi(Build.VERSION_CODES.O)
@Preview(showBackground = true)
@Composable
fun ModelListPreview() {

View file

@ -24,16 +24,20 @@ import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.google.aiedge.gallery.data.AGWorkInfo
import com.google.aiedge.gallery.data.AccessTokenData
import com.google.aiedge.gallery.data.Config
import com.google.aiedge.gallery.data.DataStoreRepository
import com.google.aiedge.gallery.data.DownloadRepository
import com.google.aiedge.gallery.data.EMPTY_MODEL
import com.google.aiedge.gallery.data.HfModel
import com.google.aiedge.gallery.data.HfModelDetails
import com.google.aiedge.gallery.data.HfModelSummary
import com.google.aiedge.gallery.data.IMPORTS_DIR
import com.google.aiedge.gallery.data.LocalModelInfo
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatus
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.TASKS
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.TaskType
import com.google.aiedge.gallery.data.getModelByName
@ -41,6 +45,7 @@ import com.google.aiedge.gallery.ui.common.AuthConfig
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
import com.google.aiedge.gallery.ui.llmchat.createLLmChatConfig
import com.google.aiedge.gallery.ui.textclassification.TextClassificationModelHelper
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
@ -66,8 +71,12 @@ private const val TAG = "AGModelManagerViewModel"
private const val HG_COMMUNITY = "jinjingforevercommunity"
private const val TEXT_INPUT_HISTORY_MAX_SIZE = 50
enum class ModelInitializationStatus {
NOT_INITIALIZED, INITIALIZING, INITIALIZED,
data class ModelInitializationStatus(
val status: ModelInitializationStatusType, var error: String = ""
)
enum class ModelInitializationStatusType {
NOT_INITIALIZED, INITIALIZING, INITIALIZED, ERROR
}
enum class TokenStatus {
@ -84,8 +93,7 @@ data class TokenStatusAndData(
)
data class TokenRequestResult(
val status: TokenRequestResultType,
val errorMessage: String? = null
val status: TokenRequestResultType, val errorMessage: String? = null
)
data class ModelManagerUiState(
@ -94,11 +102,6 @@ data class ModelManagerUiState(
*/
val tasks: List<Task>,
/**
* A map that stores lists of models indexed by task name.
*/
val modelsByTaskName: Map<String, MutableList<Model>>,
/**
* A map that tracks the download status of each model, indexed by model name.
*/
@ -191,14 +194,14 @@ open class ModelManagerViewModel(
_uiState.update { _uiState.value.copy(selectedModel = model) }
}
fun downloadModel(model: Model) {
fun downloadModel(task: Task, model: Model) {
// Update status.
setDownloadStatus(
curModel = model, status = ModelDownloadStatus(status = ModelDownloadStatusType.IN_PROGRESS)
)
// Delete the model files first.
deleteModel(model = model)
deleteModel(task = task, model = model)
// Start to send download request.
downloadRepository.downloadModel(
@ -210,7 +213,7 @@ open class ModelManagerViewModel(
downloadRepository.cancelDownloadModel(model)
}
fun deleteModel(model: Model) {
fun deleteModel(task: Task, model: Model) {
deleteFileFromExternalFilesDir(model.downloadFileName)
for (file in model.extraDataFiles) {
deleteFileFromExternalFilesDir(file.downloadFileName)
@ -223,6 +226,24 @@ open class ModelManagerViewModel(
val curModelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap()
curModelDownloadStatus[model.name] =
ModelDownloadStatus(status = ModelDownloadStatusType.NOT_DOWNLOADED)
// Delete model from the list if model is imported as a local model.
if (model.isLocalModel) {
val index = task.models.indexOf(model)
if (index >= 0) {
task.models.removeAt(index)
}
task.updateTrigger.value = System.currentTimeMillis()
curModelDownloadStatus.remove(model.name)
// Update preference.
val localModels = dataStoreRepository.readLocalModels().toMutableList()
val localModelIndex = localModels.indexOfFirst { it.fileName == model.name }
if (localModelIndex >= 0) {
localModels.removeAt(localModelIndex)
}
dataStoreRepository.saveLocalModels(localModels = localModels)
}
val newUiState = uiState.value.copy(modelDownloadStatus = curModelDownloadStatus)
_uiState.update { newUiState }
}
@ -230,7 +251,7 @@ open class ModelManagerViewModel(
fun initializeModel(context: Context, model: Model, force: Boolean = false) {
viewModelScope.launch(Dispatchers.Default) {
// Skip if initialized already.
if (!force && uiState.value.modelInitializationStatus[model.name] == ModelInitializationStatus.INITIALIZED) {
if (!force && uiState.value.modelInitializationStatus[model.name]?.status == ModelInitializationStatusType.INITIALIZED) {
Log.d(TAG, "Model '${model.name}' has been initialized. Skipping.")
return@launch
}
@ -252,20 +273,27 @@ open class ModelManagerViewModel(
// been initialized or not. If so, skip.
launch {
delay(500)
if (model.instance == null) {
if (model.instance == null && model.initializing) {
updateModelInitializationStatus(
model = model, status = ModelInitializationStatus.INITIALIZING
model = model, status = ModelInitializationStatusType.INITIALIZING
)
}
}
val onDone: () -> Unit = {
val onDone: (error: String) -> Unit = { error ->
model.initializing = false
if (model.instance != null) {
Log.d(TAG, "Model '${model.name}' initialized successfully")
model.initializing = false
updateModelInitializationStatus(
model = model,
status = ModelInitializationStatus.INITIALIZED,
status = ModelInitializationStatusType.INITIALIZED,
)
} else if (error.isNotEmpty()) {
Log.d(TAG, "Model '${model.name}' failed to initialize")
updateModelInitializationStatus(
model = model,
status = ModelInitializationStatusType.ERROR,
error = error,
)
}
}
@ -310,7 +338,7 @@ open class ModelManagerViewModel(
model.instance = null
model.initializing = false
updateModelInitializationStatus(
model = model, status = ModelInitializationStatus.NOT_INITIALIZED
model = model, status = ModelInitializationStatusType.NOT_INITIALIZED
)
}
}
@ -380,8 +408,7 @@ open class ModelManagerViewModel(
val connection = url.openConnection() as HttpURLConnection
if (accessToken != null) {
connection.setRequestProperty(
"Authorization",
"Bearer $accessToken"
"Authorization", "Bearer $accessToken"
)
}
connection.connect()
@ -390,6 +417,47 @@ open class ModelManagerViewModel(
return connection.responseCode
}
fun addLocalLlmModel(task: Task, fileName: String, fileSize: Long) {
Log.d(TAG, "adding local model: $fileName, $fileSize")
// Create model.
val configs: List<Config> = createLLmChatConfig(defaults = mapOf())
val model = Model(
name = fileName,
url = "",
configs = configs,
sizeInBytes = fileSize,
downloadFileName = "$IMPORTS_DIR/$fileName",
isLocalModel = true,
)
model.preProcess(task = task)
task.models.add(model)
// Add initial status and states.
val modelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap()
val modelInstances = uiState.value.modelInitializationStatus.toMutableMap()
modelDownloadStatus[model.name] = ModelDownloadStatus(
status = ModelDownloadStatusType.SUCCEEDED, receivedBytes = fileSize, totalBytes = fileSize
)
modelInstances[model.name] =
ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED)
// Update ui state.
_uiState.update {
uiState.value.copy(
tasks = uiState.value.tasks.toList(),
modelDownloadStatus = modelDownloadStatus,
modelInitializationStatus = modelInstances
)
}
task.updateTrigger.value = System.currentTimeMillis()
// Add to preference storage.
val localModels = dataStoreRepository.readLocalModels().toMutableList()
localModels.add(LocalModelInfo(fileName = fileName, fileSize = fileSize))
dataStoreRepository.saveLocalModels(localModels = localModels)
}
fun getTokenStatusAndData(): TokenStatusAndData {
// Try to load token data from DataStore.
var tokenStatus = TokenStatus.NOT_STORED
@ -436,8 +504,7 @@ open class ModelManagerViewModel(
if (dataIntent == null) {
onTokenRequested(
TokenRequestResult(
status = TokenRequestResultType.FAILED,
errorMessage = "Empty auth result"
status = TokenRequestResultType.FAILED, errorMessage = "Empty auth result"
)
)
return
@ -481,8 +548,7 @@ open class ModelManagerViewModel(
} else {
onTokenRequested(
TokenRequestResult(
status = TokenRequestResultType.FAILED,
errorMessage = errorMessage
status = TokenRequestResultType.FAILED, errorMessage = errorMessage
)
)
}
@ -513,23 +579,49 @@ open class ModelManagerViewModel(
}
private fun createUiState(): ModelManagerUiState {
val modelsByTaskName: Map<String, MutableList<Model>> =
TASKS.associate { task -> task.type.label to task.models }
val modelDownloadStatus: MutableMap<String, ModelDownloadStatus> = mutableMapOf()
val modelInstances: MutableMap<String, ModelInitializationStatus> = mutableMapOf()
for ((_, models) in modelsByTaskName.entries) {
for (model in models) {
for (task in TASKS) {
for (model in task.models) {
modelDownloadStatus[model.name] = getModelDownloadStatus(model = model)
modelInstances[model.name] = ModelInitializationStatus.NOT_INITIALIZED
modelInstances[model.name] =
ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED)
}
}
// Load local models.
for (localModel in dataStoreRepository.readLocalModels()) {
Log.d(TAG, "stored local model: $localModel")
// Create model.
val configs: List<Config> = createLLmChatConfig(defaults = mapOf())
val model = Model(
name = localModel.fileName,
url = "",
configs = configs,
sizeInBytes = localModel.fileSize,
downloadFileName = "$IMPORTS_DIR/${localModel.fileName}",
isLocalModel = true,
)
// Add to task.
val task = TASK_LLM_CHAT
model.preProcess(task = task)
task.models.add(model)
// Update status.
modelDownloadStatus[model.name] = ModelDownloadStatus(
status = ModelDownloadStatusType.SUCCEEDED,
receivedBytes = localModel.fileSize,
totalBytes = localModel.fileSize
)
}
val textInputHistory = dataStoreRepository.readTextInputHistory()
Log.d(TAG, "text input history: $textInputHistory")
return ModelManagerUiState(
tasks = TASKS,
modelsByTaskName = modelsByTaskName,
modelDownloadStatus = modelDownloadStatus,
modelInitializationStatus = modelInstances,
textInputHistory = textInputHistory,
@ -610,7 +702,8 @@ open class ModelManagerViewModel(
// Add initial status and states.
modelDownloadStatus[model.name] = getModelDownloadStatus(model = model)
modelInstances[model.name] = ModelInitializationStatus.NOT_INITIALIZED
modelInstances[model.name] =
ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED)
}
}
}
@ -677,9 +770,13 @@ open class ModelManagerViewModel(
}
}
private fun updateModelInitializationStatus(model: Model, status: ModelInitializationStatus) {
private fun updateModelInitializationStatus(
model: Model,
status: ModelInitializationStatusType,
error: String = ""
) {
val curModelInstance = uiState.value.modelInitializationStatus.toMutableMap()
curModelInstance[model.name] = status
curModelInstance[model.name] = ModelInitializationStatus(status = status, error = error)
val newUiState = uiState.value.copy(modelInitializationStatus = curModelInstance)
_uiState.update { newUiState }
}

View file

@ -18,6 +18,7 @@ package com.google.aiedge.gallery.ui.preview
import com.google.aiedge.gallery.data.AccessTokenData
import com.google.aiedge.gallery.data.DataStoreRepository
import com.google.aiedge.gallery.data.LocalModelInfo
class PreviewDataStoreRepository : DataStoreRepository {
override fun saveTextInputHistory(history: List<String>) {
@ -40,4 +41,11 @@ class PreviewDataStoreRepository : DataStoreRepository {
override fun readAccessTokenData(): AccessTokenData? {
return null
}
override fun saveLocalModels(localModels: List<LocalModelInfo>) {
}
override fun readLocalModels(): List<LocalModelInfo> {
return listOf()
}
}

View file

@ -17,7 +17,6 @@
package com.google.aiedge.gallery.ui.preview
import android.content.Context
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatus
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerUiState
@ -39,8 +38,6 @@ class PreviewModelManagerViewModel(context: Context) :
}
}
val modelsByTaskName: Map<String, MutableList<Model>> =
ALL_PREVIEW_TASKS.associate { task -> task.type.label to task.models }
val modelDownloadStatus = mapOf(
MODEL_TEST1.name to ModelDownloadStatus(
status = ModelDownloadStatusType.IN_PROGRESS,
@ -61,7 +58,6 @@ class PreviewModelManagerViewModel(context: Context) :
)
val newUiState = ModelManagerUiState(
tasks = ALL_PREVIEW_TASKS,
modelsByTaskName = modelsByTaskName,
modelDownloadStatus = modelDownloadStatus,
modelInitializationStatus = mapOf(),
selectedModel = MODEL_TEST2,

View file

@ -40,14 +40,14 @@ class TextClassificationInferenceResult(
* Helper object for managing text classification models.
*/
object TextClassificationModelHelper {
fun initialize(context: Context, model: Model, onDone: () -> Unit) {
fun initialize(context: Context, model: Model, onDone: (String) -> Unit) {
val modelByteBuffer = readFileToByteBuffer(File(model.getPath(context = context)))
if (modelByteBuffer != null) {
val options = TextClassifier.TextClassifierOptions.builder().setBaseOptions(
BaseOptions.builder().setModelAssetBuffer(modelByteBuffer).build()
).build()
model.instance = TextClassifier.createFromOptions(context, options)
onDone()
onDone("")
}
}