Add initial support for importing local model.

This commit is contained in:
Jing Jin 2025-04-17 17:54:57 -07:00
parent b2f35a86e7
commit 29b614355e
19 changed files with 789 additions and 126 deletions

View file

@ -47,6 +47,8 @@ interface DataStoreRepository {
fun readThemeOverride(): String fun readThemeOverride(): String
fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long) fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long)
fun readAccessTokenData(): AccessTokenData? fun readAccessTokenData(): AccessTokenData?
fun saveLocalModels(localModels: List<LocalModelInfo>)
fun readLocalModels(): List<LocalModelInfo>
} }
/** /**
@ -79,6 +81,9 @@ class DefaultDataStoreRepository(
val REFRESH_TOKEN_IV = stringPreferencesKey("refresh_token_iv") val REFRESH_TOKEN_IV = stringPreferencesKey("refresh_token_iv")
val ACCESS_TOKEN_EXPIRES_AT = longPreferencesKey("access_token_expires_at") val ACCESS_TOKEN_EXPIRES_AT = longPreferencesKey("access_token_expires_at")
// Data for all imported local models.
val LOCAL_MODELS = stringPreferencesKey("local_models")
} }
private val keystoreAlias: String = "com_google_aiedge_gallery_access_token_key" private val keystoreAlias: String = "com_google_aiedge_gallery_access_token_key"
@ -155,6 +160,26 @@ class DefaultDataStoreRepository(
} }
} }
override fun saveLocalModels(localModels: List<LocalModelInfo>) {
runBlocking {
dataStore.edit { preferences ->
val gson = Gson()
val jsonString = gson.toJson(localModels)
preferences[PreferencesKeys.LOCAL_MODELS] = jsonString
}
}
}
override fun readLocalModels(): List<LocalModelInfo> {
return runBlocking {
val preferences = dataStore.data.first()
val infosStr = preferences[PreferencesKeys.LOCAL_MODELS] ?: "[]"
val gson = Gson()
val listType = object : TypeToken<List<LocalModelInfo>>() {}.type
gson.fromJson(infosStr, listType)
}
}
private fun getTextInputHistory(preferences: Preferences): List<String> { private fun getTextInputHistory(preferences: Preferences): List<String> {
val infosStr = preferences[PreferencesKeys.TEXT_INPUT_HISTORY] ?: "[]" val infosStr = preferences[PreferencesKeys.TEXT_INPUT_HISTORY] ?: "[]"
val gson = Gson() val gson = Gson()

View file

@ -16,6 +16,7 @@
package com.google.aiedge.gallery.data package com.google.aiedge.gallery.data
import com.google.aiedge.gallery.ui.common.ensureValidFileName
import com.google.aiedge.gallery.ui.llmchat.createLLmChatConfig import com.google.aiedge.gallery.ui.llmchat.createLLmChatConfig
import kotlinx.serialization.KSerializer import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable import kotlinx.serialization.Serializable
@ -103,7 +104,7 @@ data class HfModel(
} else { } else {
listOf("") listOf("")
} }
val fileName = "${id}_${(parts.lastOrNull() ?: "")}".replace(Regex("[^a-zA-Z0-9._-]"), "_") val fileName = ensureValidFileName("${id}_${(parts.lastOrNull() ?: "")}")
// Generate configs based on the given default values. // Generate configs based on the given default values.
val configs: List<Config> = when (task) { val configs: List<Config> = when (task) {

View file

@ -32,6 +32,8 @@ enum class LlmBackend {
CPU, GPU CPU, GPU
} }
const val IMPORTS_DIR = "__imports"
/** A model for a task */ /** A model for a task */
data class Model( data class Model(
/** The Hugging Face model ID (if applicable). */ /** The Hugging Face model ID (if applicable). */
@ -85,6 +87,9 @@ data class Model(
/** The prompt templates for the model (only for LLM). */ /** The prompt templates for the model (only for LLM). */
val llmPromptTemplates: List<PromptTemplate> = listOf(), val llmPromptTemplates: List<PromptTemplate> = listOf(),
/** Whether the model is imported as a local model. */
val isLocalModel: Boolean = false,
// The following fields are managed by the app. Don't need to set manually. // The following fields are managed by the app. Don't need to set manually.
var taskType: TaskType? = null, var taskType: TaskType? = null,
var instance: Any? = null, var instance: Any? = null,
@ -104,10 +109,11 @@ data class Model(
} }
fun getPath(context: Context, fileName: String = downloadFileName): String { fun getPath(context: Context, fileName: String = downloadFileName): String {
val baseDir = "${context.getExternalFilesDir(null)}"
return if (this.isZip && this.unzipDir.isNotEmpty()) { return if (this.isZip && this.unzipDir.isNotEmpty()) {
"${context.getExternalFilesDir(null)}/${this.unzipDir}" "$baseDir/${this.unzipDir}"
} else { } else {
"${context.getExternalFilesDir(null)}/${fileName}" "$baseDir/${fileName}"
} }
} }
@ -140,6 +146,9 @@ data class Model(
} }
} }
/** Data for a imported local model. */
data class LocalModelInfo(val fileName: String, val fileSize: Long)
enum class ModelDownloadStatusType { enum class ModelDownloadStatusType {
NOT_DOWNLOADED, PARTIALLY_DOWNLOADED, IN_PROGRESS, UNZIPPING, SUCCEEDED, FAILED, NOT_DOWNLOADED, PARTIALLY_DOWNLOADED, IN_PROGRESS, UNZIPPING, SUCCEEDED, FAILED,
} }

View file

@ -19,6 +19,8 @@ package com.google.aiedge.gallery.data
import androidx.annotation.StringRes import androidx.annotation.StringRes
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.ImageSearch import androidx.compose.material.icons.rounded.ImageSearch
import androidx.compose.runtime.MutableState
import androidx.compose.runtime.mutableStateOf
import androidx.compose.ui.graphics.vector.ImageVector import androidx.compose.ui.graphics.vector.ImageVector
import com.google.aiedge.gallery.R import com.google.aiedge.gallery.R
@ -63,7 +65,9 @@ data class Task(
@StringRes val textInputPlaceHolderRes: Int = R.string.chat_textinput_placeholder, @StringRes val textInputPlaceHolderRes: Int = R.string.chat_textinput_placeholder,
// The following fields are managed by the app. Don't need to set manually. // The following fields are managed by the app. Don't need to set manually.
var index: Int = -1 var index: Int = -1,
val updateTrigger: MutableState<Long> = mutableStateOf(0)
) )
val TASK_TEXT_CLASSIFICATION = Task( val TASK_TEXT_CLASSIFICATION = Task(

View file

@ -46,6 +46,7 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.modelmanager.TokenRequestResultType import com.google.aiedge.gallery.ui.modelmanager.TokenRequestResultType
import com.google.aiedge.gallery.ui.modelmanager.TokenStatus import com.google.aiedge.gallery.ui.modelmanager.TokenStatus
@ -90,6 +91,7 @@ private const val TAG = "AGDownloadAndTryButton"
@OptIn(ExperimentalMaterial3Api::class) @OptIn(ExperimentalMaterial3Api::class)
@Composable @Composable
fun DownloadAndTryButton( fun DownloadAndTryButton(
task: Task,
model: Model, model: Model,
enabled: Boolean, enabled: Boolean,
needToDownloadFirst: Boolean, needToDownloadFirst: Boolean,
@ -106,17 +108,18 @@ fun DownloadAndTryButton(
val permissionLauncher = rememberLauncherForActivityResult( val permissionLauncher = rememberLauncherForActivityResult(
ActivityResultContracts.RequestPermission() ActivityResultContracts.RequestPermission()
) { ) {
modelManagerViewModel.downloadModel(model) modelManagerViewModel.downloadModel(task = task, model = model)
} }
// Function to kick off download. // Function to kick off download.
val startDownload: (accessToken: String?) -> Unit = { accessToken -> val startDownload: (accessToken: String?) -> Unit = { accessToken ->
model.accessToken = accessToken model.accessToken = accessToken
onClicked() onClicked()
checkNotificationPermissonAndStartDownload( checkNotificationPermissionAndStartDownload(
context = context, context = context,
launcher = permissionLauncher, launcher = permissionLauncher,
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,
task = task,
model = model model = model
) )
checkingToken = false checkingToken = false

View file

@ -416,10 +416,11 @@ fun getTaskIconColor(index: Int): Color {
return MaterialTheme.customColors.taskIconColors[colorIndex] return MaterialTheme.customColors.taskIconColors[colorIndex]
} }
fun checkNotificationPermissonAndStartDownload( fun checkNotificationPermissionAndStartDownload(
context: Context, context: Context,
launcher: ManagedActivityResultLauncher<String, Boolean>, launcher: ManagedActivityResultLauncher<String, Boolean>,
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
task: Task,
model: Model model: Model
) { ) {
// Check permission // Check permission
@ -428,7 +429,7 @@ fun checkNotificationPermissonAndStartDownload(
ContextCompat.checkSelfPermission( ContextCompat.checkSelfPermission(
context, Manifest.permission.POST_NOTIFICATIONS context, Manifest.permission.POST_NOTIFICATIONS
) -> { ) -> {
modelManagerViewModel.downloadModel(model) modelManagerViewModel.downloadModel(task = task, model = model)
} }
// Otherwise, ask for permission // Otherwise, ask for permission
@ -440,3 +441,14 @@ fun checkNotificationPermissonAndStartDownload(
} }
} }
fun ensureValidFileName(fileName: String): String {
return fileName.replace(Regex("[^a-zA-Z0-9._-]"), "_")
}
fun cleanUpMediapipeTaskErrorMessage(message: String): String {
val index = message.indexOf("=== Source Location Trace")
if (index >= 0) {
return message.substring(0, index)
}
return message
}

View file

@ -41,10 +41,13 @@ import androidx.compose.foundation.layout.wrapContentHeight
import androidx.compose.foundation.lazy.LazyColumn import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items import androidx.compose.foundation.lazy.items
import androidx.compose.foundation.lazy.rememberLazyListState import androidx.compose.foundation.lazy.rememberLazyListState
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.outlined.Timer import androidx.compose.material.icons.outlined.Timer
import androidx.compose.material.icons.rounded.ContentCopy import androidx.compose.material.icons.rounded.ContentCopy
import androidx.compose.material.icons.rounded.Refresh import androidx.compose.material.icons.rounded.Refresh
import androidx.compose.material3.Button
import androidx.compose.material3.Card
import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
@ -83,11 +86,12 @@ import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.AnnotatedString import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog
import com.google.aiedge.gallery.R import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.TaskType import com.google.aiedge.gallery.data.TaskType
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatus import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.PreviewChatModel import com.google.aiedge.gallery.ui.preview.PreviewChatModel
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
@ -113,6 +117,7 @@ fun ChatPanel(
onSendMessage: (Model, ChatMessage) -> Unit, onSendMessage: (Model, ChatMessage) -> Unit,
onRunAgainClicked: (Model, ChatMessage) -> Unit, onRunAgainClicked: (Model, ChatMessage) -> Unit,
onBenchmarkClicked: (Model, ChatMessage, warmUpIterations: Int, benchmarkIterations: Int) -> Unit, onBenchmarkClicked: (Model, ChatMessage, warmUpIterations: Int, benchmarkIterations: Int) -> Unit,
navigateUp: () -> Unit,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
onStreamImageMessage: (Model, ChatMessageImage) -> Unit = { _, _ -> }, onStreamImageMessage: (Model, ChatMessageImage) -> Unit = { _, _ -> },
onStreamEnd: (Int) -> Unit = {}, onStreamEnd: (Int) -> Unit = {},
@ -140,6 +145,8 @@ fun ChatPanel(
var showMessageLongPressedSheet by remember { mutableStateOf(false) } var showMessageLongPressedSheet by remember { mutableStateOf(false) }
val longPressedMessage: MutableState<ChatMessage?> = remember { mutableStateOf(null) } val longPressedMessage: MutableState<ChatMessage?> = remember { mutableStateOf(null) }
var showErrorDialog by remember { mutableStateOf(false) }
// Keep track of the last message and last message content. // Keep track of the last message and last message content.
val lastMessage: MutableState<ChatMessage?> = remember { mutableStateOf(null) } val lastMessage: MutableState<ChatMessage?> = remember { mutableStateOf(null) }
val lastMessageContent: MutableState<String> = remember { mutableStateOf("") } val lastMessageContent: MutableState<String> = remember { mutableStateOf("") }
@ -201,6 +208,10 @@ fun ChatPanel(
val modelInitializationStatus = val modelInitializationStatus =
modelManagerUiState.modelInitializationStatus[selectedModel.name] modelManagerUiState.modelInitializationStatus[selectedModel.name]
LaunchedEffect(modelInitializationStatus) {
showErrorDialog = modelInitializationStatus?.status == ModelInitializationStatusType.ERROR
}
Column( Column(
modifier = modifier.imePadding() modifier = modifier.imePadding()
) { ) {
@ -417,7 +428,7 @@ fun ChatPanel(
// Model initialization in-progress message. // Model initialization in-progress message.
this@Column.AnimatedVisibility( this@Column.AnimatedVisibility(
visible = modelInitializationStatus == ModelInitializationStatus.INITIALIZING, visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
enter = scaleIn() + fadeIn(), enter = scaleIn() + fadeIn(),
exit = scaleOut() + fadeOut(), exit = scaleOut() + fadeOut(),
modifier = Modifier.offset(y = 12.dp) modifier = Modifier.offset(y = 12.dp)
@ -479,6 +490,47 @@ fun ChatPanel(
} }
} }
// Error dialog.
if (showErrorDialog) {
Dialog(
onDismissRequest = {
showErrorDialog = false
navigateUp()
},
) {
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
Column(
modifier = Modifier
.padding(20.dp),
verticalArrangement = Arrangement.spacedBy(16.dp)
) {
// Title
Text(
"Error",
style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(bottom = 8.dp)
)
// Error
Text(
modelInitializationStatus?.error ?: "",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.error,
)
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.End) {
Button(onClick = {
showErrorDialog = false
navigateUp()
}) {
Text("Close")
}
}
}
}
}
}
// Benchmark config dialog. // Benchmark config dialog.
if (showBenchmarkConfigsDialog) { if (showBenchmarkConfigsDialog) {
BenchmarkConfigDialog(onDismissed = { showBenchmarkConfigsDialog = false }, BenchmarkConfigDialog(onDismissed = { showBenchmarkConfigsDialog = false },
@ -547,6 +599,7 @@ fun ChatPanelPreview() {
task = task, task = task,
selectedModel = TASK_TEST1.models[1], selectedModel = TASK_TEST1.models[1],
viewModel = PreviewChatModel(context = context), viewModel = PreviewChatModel(context = context),
navigateUp = {},
onSendMessage = { _, _ -> }, onSendMessage = { _, _ -> },
onRunAgainClicked = { _, _ -> }, onRunAgainClicked = { _, _ -> },
onBenchmarkClicked = { _, _, _, _ -> }, onBenchmarkClicked = { _, _, _, _ -> },

View file

@ -55,7 +55,7 @@ import com.google.aiedge.gallery.data.AppBarActionType
import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatusType import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.checkNotificationPermissonAndStartDownload import com.google.aiedge.gallery.ui.common.checkNotificationPermissionAndStartDownload
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.PreviewChatModel import com.google.aiedge.gallery.ui.preview.PreviewChatModel
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
@ -104,7 +104,7 @@ fun ChatView(
val launcher = rememberLauncherForActivityResult( val launcher = rememberLauncherForActivityResult(
ActivityResultContracts.RequestPermission() ActivityResultContracts.RequestPermission()
) { ) {
modelManagerViewModel.downloadModel(selectedModel) modelManagerViewModel.downloadModel(task = task, model = selectedModel)
} }
val handleNavigateUp = { val handleNavigateUp = {
@ -245,10 +245,11 @@ fun ChatView(
exit = fadeOut() exit = fadeOut()
) { ) {
ModelNotDownloaded(modifier = Modifier.weight(1f), onClicked = { ModelNotDownloaded(modifier = Modifier.weight(1f), onClicked = {
checkNotificationPermissonAndStartDownload( checkNotificationPermissionAndStartDownload(
context = context, context = context,
launcher = launcher, launcher = launcher,
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,
task = task,
model = curSelectedModel model = curSelectedModel
) )
}) })
@ -261,6 +262,7 @@ fun ChatView(
task = task, task = task,
selectedModel = curSelectedModel, selectedModel = curSelectedModel,
viewModel = viewModel, viewModel = viewModel,
navigateUp = navigateUp,
onSendMessage = onSendMessage, onSendMessage = onSendMessage,
onRunAgainClicked = onRunAgainClicked, onRunAgainClicked = onRunAgainClicked,
onBenchmarkClicked = onBenchmarkClicked, onBenchmarkClicked = onBenchmarkClicked,

View file

@ -35,6 +35,7 @@ import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.ChevronRight
import androidx.compose.material.icons.rounded.Settings import androidx.compose.material.icons.rounded.Settings
import androidx.compose.material.icons.rounded.UnfoldLess import androidx.compose.material.icons.rounded.UnfoldLess
import androidx.compose.material.icons.rounded.UnfoldMore import androidx.compose.material.icons.rounded.UnfoldMore
@ -68,7 +69,7 @@ import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.DownloadAndTryButton import com.google.aiedge.gallery.ui.common.DownloadAndTryButton
import com.google.aiedge.gallery.ui.common.TaskIcon import com.google.aiedge.gallery.ui.common.TaskIcon
import com.google.aiedge.gallery.ui.common.chat.MarkdownText import com.google.aiedge.gallery.ui.common.chat.MarkdownText
import com.google.aiedge.gallery.ui.common.checkNotificationPermissonAndStartDownload import com.google.aiedge.gallery.ui.common.checkNotificationPermissionAndStartDownload
import com.google.aiedge.gallery.ui.common.getTaskBgColor import com.google.aiedge.gallery.ui.common.getTaskBgColor
import com.google.aiedge.gallery.ui.common.getTaskIconColor import com.google.aiedge.gallery.ui.common.getTaskIconColor
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
@ -113,7 +114,7 @@ fun ModelItem(
val launcher = rememberLauncherForActivityResult( val launcher = rememberLauncherForActivityResult(
ActivityResultContracts.RequestPermission() ActivityResultContracts.RequestPermission()
) { ) {
modelManagerViewModel.downloadModel(model) modelManagerViewModel.downloadModel(task = task, model = model)
} }
var isExpanded by remember { mutableStateOf(false) } var isExpanded by remember { mutableStateOf(false) }
@ -156,10 +157,11 @@ fun ModelItem(
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,
downloadStatus = downloadStatus, downloadStatus = downloadStatus,
onDownloadClicked = { model -> onDownloadClicked = { model ->
checkNotificationPermissonAndStartDownload( checkNotificationPermissionAndStartDownload(
context = context, context = context,
launcher = launcher, launcher = launcher,
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,
task = task,
model = model model = model
) )
}, },
@ -186,7 +188,9 @@ fun ModelItem(
} }
} else { } else {
Icon( Icon(
if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore, // For local model, show ">" directly indicating users can just tap the model item to
// go into it without needing to expand it first.
if (model.isLocalModel) Icons.Rounded.ChevronRight else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore,
contentDescription = "", contentDescription = "",
tint = getTaskIconColor(task), tint = getTaskIconColor(task),
) )
@ -237,6 +241,7 @@ fun ModelItem(
val needToDownloadFirst = val needToDownloadFirst =
downloadStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED || downloadStatus?.status == ModelDownloadStatusType.FAILED downloadStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED || downloadStatus?.status == ModelDownloadStatusType.FAILED
DownloadAndTryButton( DownloadAndTryButton(
task = task,
model = model, model = model,
enabled = isExpanded, enabled = isExpanded,
needToDownloadFirst = needToDownloadFirst, needToDownloadFirst = needToDownloadFirst,
@ -266,7 +271,13 @@ fun ModelItem(
) )
boxModifier = if (canExpand) { boxModifier = if (canExpand) {
boxModifier.clickable( boxModifier.clickable(
onClick = { isExpanded = !isExpanded }, onClick = {
if (!model.isLocalModel) {
isExpanded = !isExpanded
} else {
onModelClicked(model)
}
},
interactionSource = remember { MutableInteractionSource() }, interactionSource = remember { MutableInteractionSource() },
indication = ripple( indication = ripple(
bounded = true, bounded = true,

View file

@ -124,7 +124,7 @@ fun ModelItemActionButton(
if (showConfirmDeleteDialog) { if (showConfirmDeleteDialog) {
ConfirmDeleteModelDialog(model = model, onConfirm = { ConfirmDeleteModelDialog(model = model, onConfirm = {
modelManagerViewModel.deleteModel(model) modelManagerViewModel.deleteModel(task = task, model = model)
showConfirmDeleteDialog = false showConfirmDeleteDialog = false
}, onDismiss = { }, onDismiss = {
showConfirmDeleteDialog = false showConfirmDeleteDialog = false

View file

@ -48,7 +48,7 @@ class ImageClassificationInferenceResult(
//TODO: handle error. //TODO: handle error.
object ImageClassificationModelHelper { object ImageClassificationModelHelper {
fun initialize(context: Context, model: Model, onDone: () -> Unit) { fun initialize(context: Context, model: Model, onDone: (String) -> Unit) {
val useGpu = model.getBooleanConfigValue(key = ConfigKey.USE_GPU) val useGpu = model.getBooleanConfigValue(key = ConfigKey.USE_GPU)
TfLiteGpu.isGpuDelegateAvailable(context).continueWith { gpuTask -> TfLiteGpu.isGpuDelegateAvailable(context).continueWith { gpuTask ->
val optionsBuilder = TfLiteInitializationOptions.builder() val optionsBuilder = TfLiteInitializationOptions.builder()
@ -69,7 +69,7 @@ object ImageClassificationModelHelper {
File(model.getPath(context = context)), interpreterOption File(model.getPath(context = context)), interpreterOption
) )
model.instance = interpreter model.instance = interpreter
onDone() onDone("")
} }
} }
} }

View file

@ -24,6 +24,7 @@ import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator
import com.google.aiedge.gallery.data.ConfigKey import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.ui.common.LatencyProvider import com.google.aiedge.gallery.ui.common.LatencyProvider
import com.google.aiedge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage
import kotlin.random.Random import kotlin.random.Random
private const val TAG = "AGImageGenerationModelHelper" private const val TAG = "AGImageGenerationModelHelper"
@ -33,12 +34,17 @@ class ImageGenerationInferenceResult(
) : LatencyProvider ) : LatencyProvider
object ImageGenerationModelHelper { object ImageGenerationModelHelper {
fun initialize(context: Context, model: Model, onDone: () -> Unit) { fun initialize(context: Context, model: Model, onDone: (String) -> Unit) {
val options = ImageGenerator.ImageGeneratorOptions.builder() try {
.setImageGeneratorModelDirectory(model.getPath(context = context)) val options = ImageGenerator.ImageGeneratorOptions.builder()
.build() .setImageGeneratorModelDirectory(model.getPath(context = context))
model.instance = ImageGenerator.createFromOptions(context, options) .build()
onDone() model.instance = ImageGenerator.createFromOptions(context, options)
} catch (e: Exception) {
onDone(cleanUpMediapipeTaskErrorMessage(e.message ?: "Unknown error"))
return
}
onDone("")
} }
fun cleanUp(model: Model) { fun cleanUp(model: Model) {

View file

@ -21,6 +21,7 @@ import android.util.Log
import com.google.aiedge.gallery.data.ConfigKey import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.LlmBackend import com.google.aiedge.gallery.data.LlmBackend
import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage
import com.google.mediapipe.tasks.genai.llminference.LlmInference import com.google.mediapipe.tasks.genai.llminference.LlmInference
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
@ -40,7 +41,7 @@ object LlmChatModelHelper {
private val cleanUpListeners: MutableMap<String, CleanUpListener> = mutableMapOf() private val cleanUpListeners: MutableMap<String, CleanUpListener> = mutableMapOf()
fun initialize( fun initialize(
context: Context, model: Model, onDone: () -> Unit context: Context, model: Model, onDone: (String) -> Unit
) { ) {
val maxTokens = val maxTokens =
model.getIntConfigValue(key = ConfigKey.MAX_TOKENS, defaultValue = DEFAULT_MAX_TOKEN) model.getIntConfigValue(key = ConfigKey.MAX_TOKENS, defaultValue = DEFAULT_MAX_TOKEN)
@ -68,9 +69,10 @@ object LlmChatModelHelper {
) )
model.instance = LlmModelInstance(engine = llmInference, session = session) model.instance = LlmModelInstance(engine = llmInference, session = session)
} catch (e: Exception) { } catch (e: Exception) {
e.printStackTrace() onDone(cleanUpMediapipeTaskErrorMessage(e.message ?: "Unknown error"))
return
} }
onDone() onDone("")
} }
fun cleanUp(model: Model) { fun cleanUp(model: Model) {

View file

@ -0,0 +1,260 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.modelmanager
import android.content.Context
import android.net.Uri
import android.provider.OpenableColumns
import android.util.Log
import androidx.compose.animation.core.Animatable
import androidx.compose.animation.core.tween
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.Error
import androidx.compose.material3.Button
import androidx.compose.material3.Card
import androidx.compose.material3.Icon
import androidx.compose.material3.LinearProgressIndicator
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableFloatStateOf
import androidx.compose.runtime.mutableLongStateOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog
import androidx.compose.ui.window.DialogProperties
import com.google.aiedge.gallery.data.IMPORTS_DIR
import com.google.aiedge.gallery.ui.common.ensureValidFileName
import com.google.aiedge.gallery.ui.common.humanReadableSize
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import java.io.File
import java.io.FileOutputStream
import java.net.URLDecoder
import java.nio.charset.StandardCharsets
private const val TAG = "AGModelImportDialog"
data class ModelImportInfo(val fileName: String, val fileSize: Long, val error: String = "")
@Composable
fun ModelImportDialog(
uri: Uri, onDone: (ModelImportInfo) -> Unit
) {
val context = LocalContext.current
val coroutineScope = rememberCoroutineScope()
var fileName by remember { mutableStateOf("") }
var fileSize by remember { mutableLongStateOf(0L) }
var error by remember { mutableStateOf("") }
var progress by remember { mutableFloatStateOf(0f) }
LaunchedEffect(Unit) {
error = ""
// Get basic info.
val info = getFileSizeAndDisplayNameFromUri(context = context, uri = uri)
fileSize = info.first
fileName = ensureValidFileName(info.second)
// Import.
importModel(
context = context,
coroutineScope = coroutineScope,
fileName = fileName,
fileSize = fileSize,
uri = uri,
onDone = {
onDone(ModelImportInfo(fileName = fileName, fileSize = fileSize, error = error))
},
onProgress = {
progress = it
},
onError = {
error = it
}
)
}
Dialog(
properties = DialogProperties(dismissOnBackPress = false, dismissOnClickOutside = false),
onDismissRequest = {},
) {
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
Column(
modifier = Modifier
.padding(20.dp),
verticalArrangement = Arrangement.spacedBy(16.dp)
) {
// Title.
Text(
"Importing...",
style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(bottom = 8.dp)
)
// No error.
if (error.isEmpty()) {
// Progress bar.
Column(verticalArrangement = Arrangement.spacedBy(4.dp)) {
Text(
"$fileName (${fileSize.humanReadableSize()})",
style = MaterialTheme.typography.labelSmall,
)
val animatedProgress = remember { Animatable(0f) }
LinearProgressIndicator(
progress = { animatedProgress.value },
modifier = Modifier
.fillMaxWidth()
.padding(bottom = 8.dp),
)
LaunchedEffect(progress) {
animatedProgress.animateTo(progress, animationSpec = tween(150))
}
}
}
// Has error.
else {
Row(
verticalAlignment = Alignment.Top,
horizontalArrangement = Arrangement.spacedBy(6.dp)
) {
Icon(
Icons.Rounded.Error,
contentDescription = "",
tint = MaterialTheme.colorScheme.error
)
Text(
error,
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.error,
modifier = Modifier.padding(top = 4.dp)
)
}
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.End) {
Button(onClick = {
onDone(ModelImportInfo(fileName = "", fileSize = 0L, error = error))
}) {
Text("Close")
}
}
}
}
}
}
}
private fun importModel(
context: Context,
coroutineScope: CoroutineScope,
fileName: String,
fileSize: Long,
uri: Uri,
onDone: () -> Unit,
onProgress: (Float) -> Unit,
onError: (String) -> Unit,
) {
// TODO: handle error.
coroutineScope.launch(Dispatchers.IO) {
// Get the last component of the uri path as the imported file name.
val decodedUri = URLDecoder.decode(uri.toString(), StandardCharsets.UTF_8.name())
Log.d(TAG, "importing model from $decodedUri. File name: $fileName. File size: $fileSize")
// Create <app_external_dir>/imports if not exist.
val importsDir = File(context.getExternalFilesDir(null), IMPORTS_DIR)
if (!importsDir.exists()) {
importsDir.mkdirs()
}
// Import by copying the file over.
val outputFile = File(context.getExternalFilesDir(null), "$IMPORTS_DIR/$fileName")
val outputStream = FileOutputStream(outputFile)
val buffer = ByteArray(DEFAULT_BUFFER_SIZE)
var bytesRead: Int
var lastSetProgressTs: Long = 0
var importedBytes = 0L
val inputStream = context.contentResolver.openInputStream(uri)
try {
if (inputStream != null) {
while (inputStream.read(buffer).also { bytesRead = it } != -1) {
outputStream.write(buffer, 0, bytesRead)
importedBytes += bytesRead
// Report progress every 200 ms.
val curTs = System.currentTimeMillis()
if (curTs - lastSetProgressTs > 200) {
Log.d(TAG, "importing progress: $importedBytes, $fileSize")
lastSetProgressTs = curTs
if (fileSize != 0L) {
onProgress(importedBytes.toFloat() / fileSize.toFloat())
}
}
}
}
} catch (e: Exception) {
e.printStackTrace()
onError(e.message ?: "Failed to import")
return@launch
} finally {
inputStream?.close()
outputStream.close()
}
Log.d(TAG, "import done")
onProgress(1f)
onDone()
}
}
private fun getFileSizeAndDisplayNameFromUri(context: Context, uri: Uri): Pair<Long, String> {
val contentResolver = context.contentResolver
var fileSize = 0L
var displayName = ""
try {
contentResolver.query(
uri, arrayOf(OpenableColumns.SIZE, OpenableColumns.DISPLAY_NAME), null, null, null
)?.use { cursor ->
if (cursor.moveToFirst()) {
val sizeIndex = cursor.getColumnIndexOrThrow(OpenableColumns.SIZE)
fileSize = cursor.getLong(sizeIndex)
val nameIndex = cursor.getColumnIndexOrThrow(OpenableColumns.DISPLAY_NAME)
displayName = cursor.getString(nameIndex)
}
}
} catch (e: Exception) {
e.printStackTrace()
return Pair(0L, "")
}
return Pair(fileSize, displayName)
}

View file

@ -16,24 +16,46 @@
package com.google.aiedge.gallery.ui.modelmanager package com.google.aiedge.gallery.ui.modelmanager
import android.content.Intent
import android.net.Uri
import android.os.Build
import android.util.Log
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.ActivityResultLauncher
import androidx.activity.result.contract.ActivityResultContracts
import androidx.annotation.RequiresApi
import androidx.compose.foundation.clickable import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.PaddingValues import androidx.compose.foundation.layout.PaddingValues
import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size import androidx.compose.foundation.layout.size
import androidx.compose.foundation.lazy.LazyColumn import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items import androidx.compose.foundation.lazy.items
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.outlined.NoteAdd
import androidx.compose.material.icons.filled.Add
import androidx.compose.material.icons.outlined.Code import androidx.compose.material.icons.outlined.Code
import androidx.compose.material.icons.outlined.Description import androidx.compose.material.icons.outlined.Description
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ModalBottomSheet
import androidx.compose.material3.SmallFloatingActionButton
import androidx.compose.material3.Text import androidx.compose.material3.Text
import androidx.compose.material3.rememberModalBottomSheetState
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.vector.ImageVector import androidx.compose.ui.graphics.vector.ImageVector
@ -53,8 +75,14 @@ import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.TASK_TEST1 import com.google.aiedge.gallery.ui.preview.TASK_TEST1
import com.google.aiedge.gallery.ui.theme.GalleryTheme import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.customColors import com.google.aiedge.gallery.ui.theme.customColors
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
private const val TAG = "AGModelList"
/** The list of models in the model manager. */ /** The list of models in the model manager. */
@RequiresApi(Build.VERSION_CODES.O)
@OptIn(ExperimentalMaterial3Api::class)
@Composable @Composable
fun ModelList( fun ModelList(
task: Task, task: Task,
@ -63,65 +91,214 @@ fun ModelList(
onModelClicked: (Model) -> Unit, onModelClicked: (Model) -> Unit,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
) { ) {
LazyColumn( var showAddModelSheet by remember { mutableStateOf(false) }
modifier = modifier.padding(top = 8.dp), var showImportingDialog by remember { mutableStateOf(false) }
contentPadding = contentPadding, val curFileUri = remember { mutableStateOf<Uri?>(null) }
verticalArrangement = Arrangement.spacedBy(8.dp), val sheetState = rememberModalBottomSheetState()
) { val coroutineScope = rememberCoroutineScope()
// Headline.
item(key = "headline") {
Text(
task.description,
textAlign = TextAlign.Center,
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
modifier = Modifier
.fillMaxWidth()
)
}
// URLs. // This is just to update "models" list when task.updateTrigger is updated so that the UI can
item(key = "urls") { // be properly updated.
Row( val models by remember {
horizontalArrangement = Arrangement.Center, derivedStateOf {
modifier = Modifier val trigger = task.updateTrigger.value
.fillMaxWidth() if (trigger >= 0) {
.padding(top = 12.dp, bottom = 16.dp), task.models.toList().filter { !it.isLocalModel }
) { } else {
Column( listOf()
horizontalAlignment = Alignment.Start, }
verticalArrangement = Arrangement.spacedBy(4.dp), }
}
val localModels by remember {
derivedStateOf {
val trigger = task.updateTrigger.value
if (trigger >= 0) {
task.models.toList().filter { it.isLocalModel }
} else {
listOf()
}
}
}
val filePickerLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult(
contract = ActivityResultContracts.StartActivityForResult()
) { result ->
if (result.resultCode == android.app.Activity.RESULT_OK) {
result.data?.data?.let { uri ->
curFileUri.value = uri
showImportingDialog = true
} ?: run {
Log.d(TAG, "No file selected or URI is null.")
}
} else {
Log.d(TAG, "File picking cancelled.")
}
}
Box(contentAlignment = Alignment.BottomEnd) {
LazyColumn(
modifier = modifier.padding(top = 8.dp),
contentPadding = contentPadding,
verticalArrangement = Arrangement.spacedBy(8.dp),
) {
// Headline.
item(key = "headline") {
Text(
task.description,
textAlign = TextAlign.Center,
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
modifier = Modifier.fillMaxWidth()
)
}
// URLs.
item(key = "urls") {
Row(
horizontalArrangement = Arrangement.Center,
modifier = Modifier
.fillMaxWidth()
.padding(top = 12.dp, bottom = 16.dp),
) { ) {
if (task.docUrl.isNotEmpty()) { Column(
ClickableLink( horizontalAlignment = Alignment.Start,
url = task.docUrl, verticalArrangement = Arrangement.spacedBy(4.dp),
linkText = "API Documentation", ) {
icon = Icons.Outlined.Description if (task.docUrl.isNotEmpty()) {
) ClickableLink(
} url = task.docUrl, linkText = "API Documentation", icon = Icons.Outlined.Description
if (task.sourceCodeUrl.isNotEmpty()) { )
ClickableLink( }
url = task.sourceCodeUrl, if (task.sourceCodeUrl.isNotEmpty()) {
linkText = "Example code", ClickableLink(
icon = Icons.Outlined.Code url = task.sourceCodeUrl, linkText = "Example code", icon = Icons.Outlined.Code
) )
}
} }
} }
} }
// List of models within a task.
items(items = models) { model ->
Box {
ModelItem(
model = model,
task = task,
modelManagerViewModel = modelManagerViewModel,
onModelClicked = onModelClicked,
modifier = Modifier.padding(horizontal = 12.dp)
)
}
}
// Title for local models.
if (localModels.isNotEmpty()) {
item(key = "localModelsTitle") {
Text(
"Local models",
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
modifier = Modifier
.padding(horizontal = 16.dp)
.padding(top = 24.dp)
)
}
}
// List of local models within a task.
items(items = localModels) { model ->
Box {
ModelItem(
model = model,
task = task,
modelManagerViewModel = modelManagerViewModel,
onModelClicked = onModelClicked,
modifier = Modifier.padding(horizontal = 12.dp)
)
}
}
item(key = "bottomPadding") {
Spacer(modifier = Modifier.height(60.dp))
}
} }
// List of models within a task. // Add model button at the bottom right.
items(items = task.models) { model -> Box(
Box { modifier = Modifier
ModelItem( .padding(end = 16.dp)
model = model, .padding(bottom = contentPadding.calculateBottomPadding())
task = task, ) {
modelManagerViewModel = modelManagerViewModel, SmallFloatingActionButton(
onModelClicked = onModelClicked, onClick = {
modifier = Modifier.padding(start = 12.dp, end = 12.dp) showAddModelSheet = true
) },
containerColor = MaterialTheme.colorScheme.secondaryContainer,
contentColor = MaterialTheme.colorScheme.secondary,
) {
Icon(Icons.Filled.Add, "")
} }
} }
} }
if (showAddModelSheet) {
ModalBottomSheet(
onDismissRequest = { showAddModelSheet = false },
sheetState = sheetState,
) {
Text(
"Add custom model",
style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(vertical = 4.dp, horizontal = 16.dp)
)
Box(modifier = Modifier.clickable {
coroutineScope.launch {
// Give it sometime to show the click effect.
delay(200)
showAddModelSheet = false
// Show file picker.
val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
addCategory(Intent.CATEGORY_OPENABLE)
type = "*/*"
putExtra(
Intent.EXTRA_MIME_TYPES,
arrayOf("application/x-binary", "application/octet-stream")
)
// Single select.
putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
}
filePickerLauncher.launch(intent)
}
}) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp),
modifier = Modifier
.fillMaxWidth()
.padding(16.dp)
) {
Icon(Icons.AutoMirrored.Outlined.NoteAdd, contentDescription = "")
Text("Add local model")
}
}
}
}
if (showImportingDialog) {
curFileUri.value?.let { uri ->
ModelImportDialog(uri = uri, onDone = { info ->
showImportingDialog = false
if (info.error.isEmpty()) {
// TODO: support other model types.
modelManagerViewModel.addLocalLlmModel(
task = task,
fileName = info.fileName,
fileSize = info.fileSize
)
}
})
}
}
} }
@Composable @Composable
@ -132,15 +309,11 @@ fun ClickableLink(
) { ) {
val uriHandler = LocalUriHandler.current val uriHandler = LocalUriHandler.current
val annotatedText = AnnotatedString( val annotatedText = AnnotatedString(
text = linkText, text = linkText, spanStyles = listOf(
spanStyles = listOf(
AnnotatedString.Range( AnnotatedString.Range(
item = SpanStyle( item = SpanStyle(
color = MaterialTheme.customColors.linkColor, color = MaterialTheme.customColors.linkColor, textDecoration = TextDecoration.Underline
textDecoration = TextDecoration.Underline ), start = 0, end = linkText.length
),
start = 0,
end = linkText.length
) )
) )
) )
@ -163,6 +336,7 @@ fun ClickableLink(
} }
} }
@RequiresApi(Build.VERSION_CODES.O)
@Preview(showBackground = true) @Preview(showBackground = true)
@Composable @Composable
fun ModelListPreview() { fun ModelListPreview() {

View file

@ -24,16 +24,20 @@ import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import com.google.aiedge.gallery.data.AGWorkInfo import com.google.aiedge.gallery.data.AGWorkInfo
import com.google.aiedge.gallery.data.AccessTokenData import com.google.aiedge.gallery.data.AccessTokenData
import com.google.aiedge.gallery.data.Config
import com.google.aiedge.gallery.data.DataStoreRepository import com.google.aiedge.gallery.data.DataStoreRepository
import com.google.aiedge.gallery.data.DownloadRepository import com.google.aiedge.gallery.data.DownloadRepository
import com.google.aiedge.gallery.data.EMPTY_MODEL import com.google.aiedge.gallery.data.EMPTY_MODEL
import com.google.aiedge.gallery.data.HfModel import com.google.aiedge.gallery.data.HfModel
import com.google.aiedge.gallery.data.HfModelDetails import com.google.aiedge.gallery.data.HfModelDetails
import com.google.aiedge.gallery.data.HfModelSummary import com.google.aiedge.gallery.data.HfModelSummary
import com.google.aiedge.gallery.data.IMPORTS_DIR
import com.google.aiedge.gallery.data.LocalModelInfo
import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatus import com.google.aiedge.gallery.data.ModelDownloadStatus
import com.google.aiedge.gallery.data.ModelDownloadStatusType import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.TASKS import com.google.aiedge.gallery.data.TASKS
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.TaskType import com.google.aiedge.gallery.data.TaskType
import com.google.aiedge.gallery.data.getModelByName import com.google.aiedge.gallery.data.getModelByName
@ -41,6 +45,7 @@ import com.google.aiedge.gallery.ui.common.AuthConfig
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
import com.google.aiedge.gallery.ui.llmchat.createLLmChatConfig
import com.google.aiedge.gallery.ui.textclassification.TextClassificationModelHelper import com.google.aiedge.gallery.ui.textclassification.TextClassificationModelHelper
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async import kotlinx.coroutines.async
@ -66,8 +71,12 @@ private const val TAG = "AGModelManagerViewModel"
private const val HG_COMMUNITY = "jinjingforevercommunity" private const val HG_COMMUNITY = "jinjingforevercommunity"
private const val TEXT_INPUT_HISTORY_MAX_SIZE = 50 private const val TEXT_INPUT_HISTORY_MAX_SIZE = 50
enum class ModelInitializationStatus { data class ModelInitializationStatus(
NOT_INITIALIZED, INITIALIZING, INITIALIZED, val status: ModelInitializationStatusType, var error: String = ""
)
enum class ModelInitializationStatusType {
NOT_INITIALIZED, INITIALIZING, INITIALIZED, ERROR
} }
enum class TokenStatus { enum class TokenStatus {
@ -84,8 +93,7 @@ data class TokenStatusAndData(
) )
data class TokenRequestResult( data class TokenRequestResult(
val status: TokenRequestResultType, val status: TokenRequestResultType, val errorMessage: String? = null
val errorMessage: String? = null
) )
data class ModelManagerUiState( data class ModelManagerUiState(
@ -94,11 +102,6 @@ data class ModelManagerUiState(
*/ */
val tasks: List<Task>, val tasks: List<Task>,
/**
* A map that stores lists of models indexed by task name.
*/
val modelsByTaskName: Map<String, MutableList<Model>>,
/** /**
* A map that tracks the download status of each model, indexed by model name. * A map that tracks the download status of each model, indexed by model name.
*/ */
@ -191,14 +194,14 @@ open class ModelManagerViewModel(
_uiState.update { _uiState.value.copy(selectedModel = model) } _uiState.update { _uiState.value.copy(selectedModel = model) }
} }
fun downloadModel(model: Model) { fun downloadModel(task: Task, model: Model) {
// Update status. // Update status.
setDownloadStatus( setDownloadStatus(
curModel = model, status = ModelDownloadStatus(status = ModelDownloadStatusType.IN_PROGRESS) curModel = model, status = ModelDownloadStatus(status = ModelDownloadStatusType.IN_PROGRESS)
) )
// Delete the model files first. // Delete the model files first.
deleteModel(model = model) deleteModel(task = task, model = model)
// Start to send download request. // Start to send download request.
downloadRepository.downloadModel( downloadRepository.downloadModel(
@ -210,7 +213,7 @@ open class ModelManagerViewModel(
downloadRepository.cancelDownloadModel(model) downloadRepository.cancelDownloadModel(model)
} }
fun deleteModel(model: Model) { fun deleteModel(task: Task, model: Model) {
deleteFileFromExternalFilesDir(model.downloadFileName) deleteFileFromExternalFilesDir(model.downloadFileName)
for (file in model.extraDataFiles) { for (file in model.extraDataFiles) {
deleteFileFromExternalFilesDir(file.downloadFileName) deleteFileFromExternalFilesDir(file.downloadFileName)
@ -223,6 +226,24 @@ open class ModelManagerViewModel(
val curModelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap() val curModelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap()
curModelDownloadStatus[model.name] = curModelDownloadStatus[model.name] =
ModelDownloadStatus(status = ModelDownloadStatusType.NOT_DOWNLOADED) ModelDownloadStatus(status = ModelDownloadStatusType.NOT_DOWNLOADED)
// Delete model from the list if model is imported as a local model.
if (model.isLocalModel) {
val index = task.models.indexOf(model)
if (index >= 0) {
task.models.removeAt(index)
}
task.updateTrigger.value = System.currentTimeMillis()
curModelDownloadStatus.remove(model.name)
// Update preference.
val localModels = dataStoreRepository.readLocalModels().toMutableList()
val localModelIndex = localModels.indexOfFirst { it.fileName == model.name }
if (localModelIndex >= 0) {
localModels.removeAt(localModelIndex)
}
dataStoreRepository.saveLocalModels(localModels = localModels)
}
val newUiState = uiState.value.copy(modelDownloadStatus = curModelDownloadStatus) val newUiState = uiState.value.copy(modelDownloadStatus = curModelDownloadStatus)
_uiState.update { newUiState } _uiState.update { newUiState }
} }
@ -230,7 +251,7 @@ open class ModelManagerViewModel(
fun initializeModel(context: Context, model: Model, force: Boolean = false) { fun initializeModel(context: Context, model: Model, force: Boolean = false) {
viewModelScope.launch(Dispatchers.Default) { viewModelScope.launch(Dispatchers.Default) {
// Skip if initialized already. // Skip if initialized already.
if (!force && uiState.value.modelInitializationStatus[model.name] == ModelInitializationStatus.INITIALIZED) { if (!force && uiState.value.modelInitializationStatus[model.name]?.status == ModelInitializationStatusType.INITIALIZED) {
Log.d(TAG, "Model '${model.name}' has been initialized. Skipping.") Log.d(TAG, "Model '${model.name}' has been initialized. Skipping.")
return@launch return@launch
} }
@ -252,20 +273,27 @@ open class ModelManagerViewModel(
// been initialized or not. If so, skip. // been initialized or not. If so, skip.
launch { launch {
delay(500) delay(500)
if (model.instance == null) { if (model.instance == null && model.initializing) {
updateModelInitializationStatus( updateModelInitializationStatus(
model = model, status = ModelInitializationStatus.INITIALIZING model = model, status = ModelInitializationStatusType.INITIALIZING
) )
} }
} }
val onDone: () -> Unit = { val onDone: (error: String) -> Unit = { error ->
model.initializing = false
if (model.instance != null) { if (model.instance != null) {
Log.d(TAG, "Model '${model.name}' initialized successfully") Log.d(TAG, "Model '${model.name}' initialized successfully")
model.initializing = false
updateModelInitializationStatus( updateModelInitializationStatus(
model = model, model = model,
status = ModelInitializationStatus.INITIALIZED, status = ModelInitializationStatusType.INITIALIZED,
)
} else if (error.isNotEmpty()) {
Log.d(TAG, "Model '${model.name}' failed to initialize")
updateModelInitializationStatus(
model = model,
status = ModelInitializationStatusType.ERROR,
error = error,
) )
} }
} }
@ -310,7 +338,7 @@ open class ModelManagerViewModel(
model.instance = null model.instance = null
model.initializing = false model.initializing = false
updateModelInitializationStatus( updateModelInitializationStatus(
model = model, status = ModelInitializationStatus.NOT_INITIALIZED model = model, status = ModelInitializationStatusType.NOT_INITIALIZED
) )
} }
} }
@ -380,8 +408,7 @@ open class ModelManagerViewModel(
val connection = url.openConnection() as HttpURLConnection val connection = url.openConnection() as HttpURLConnection
if (accessToken != null) { if (accessToken != null) {
connection.setRequestProperty( connection.setRequestProperty(
"Authorization", "Authorization", "Bearer $accessToken"
"Bearer $accessToken"
) )
} }
connection.connect() connection.connect()
@ -390,6 +417,47 @@ open class ModelManagerViewModel(
return connection.responseCode return connection.responseCode
} }
fun addLocalLlmModel(task: Task, fileName: String, fileSize: Long) {
Log.d(TAG, "adding local model: $fileName, $fileSize")
// Create model.
val configs: List<Config> = createLLmChatConfig(defaults = mapOf())
val model = Model(
name = fileName,
url = "",
configs = configs,
sizeInBytes = fileSize,
downloadFileName = "$IMPORTS_DIR/$fileName",
isLocalModel = true,
)
model.preProcess(task = task)
task.models.add(model)
// Add initial status and states.
val modelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap()
val modelInstances = uiState.value.modelInitializationStatus.toMutableMap()
modelDownloadStatus[model.name] = ModelDownloadStatus(
status = ModelDownloadStatusType.SUCCEEDED, receivedBytes = fileSize, totalBytes = fileSize
)
modelInstances[model.name] =
ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED)
// Update ui state.
_uiState.update {
uiState.value.copy(
tasks = uiState.value.tasks.toList(),
modelDownloadStatus = modelDownloadStatus,
modelInitializationStatus = modelInstances
)
}
task.updateTrigger.value = System.currentTimeMillis()
// Add to preference storage.
val localModels = dataStoreRepository.readLocalModels().toMutableList()
localModels.add(LocalModelInfo(fileName = fileName, fileSize = fileSize))
dataStoreRepository.saveLocalModels(localModels = localModels)
}
fun getTokenStatusAndData(): TokenStatusAndData { fun getTokenStatusAndData(): TokenStatusAndData {
// Try to load token data from DataStore. // Try to load token data from DataStore.
var tokenStatus = TokenStatus.NOT_STORED var tokenStatus = TokenStatus.NOT_STORED
@ -436,8 +504,7 @@ open class ModelManagerViewModel(
if (dataIntent == null) { if (dataIntent == null) {
onTokenRequested( onTokenRequested(
TokenRequestResult( TokenRequestResult(
status = TokenRequestResultType.FAILED, status = TokenRequestResultType.FAILED, errorMessage = "Empty auth result"
errorMessage = "Empty auth result"
) )
) )
return return
@ -481,8 +548,7 @@ open class ModelManagerViewModel(
} else { } else {
onTokenRequested( onTokenRequested(
TokenRequestResult( TokenRequestResult(
status = TokenRequestResultType.FAILED, status = TokenRequestResultType.FAILED, errorMessage = errorMessage
errorMessage = errorMessage
) )
) )
} }
@ -513,23 +579,49 @@ open class ModelManagerViewModel(
} }
private fun createUiState(): ModelManagerUiState { private fun createUiState(): ModelManagerUiState {
val modelsByTaskName: Map<String, MutableList<Model>> =
TASKS.associate { task -> task.type.label to task.models }
val modelDownloadStatus: MutableMap<String, ModelDownloadStatus> = mutableMapOf() val modelDownloadStatus: MutableMap<String, ModelDownloadStatus> = mutableMapOf()
val modelInstances: MutableMap<String, ModelInitializationStatus> = mutableMapOf() val modelInstances: MutableMap<String, ModelInitializationStatus> = mutableMapOf()
for ((_, models) in modelsByTaskName.entries) { for (task in TASKS) {
for (model in models) { for (model in task.models) {
modelDownloadStatus[model.name] = getModelDownloadStatus(model = model) modelDownloadStatus[model.name] = getModelDownloadStatus(model = model)
modelInstances[model.name] = ModelInitializationStatus.NOT_INITIALIZED modelInstances[model.name] =
ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED)
} }
} }
// Load local models.
for (localModel in dataStoreRepository.readLocalModels()) {
Log.d(TAG, "stored local model: $localModel")
// Create model.
val configs: List<Config> = createLLmChatConfig(defaults = mapOf())
val model = Model(
name = localModel.fileName,
url = "",
configs = configs,
sizeInBytes = localModel.fileSize,
downloadFileName = "$IMPORTS_DIR/${localModel.fileName}",
isLocalModel = true,
)
// Add to task.
val task = TASK_LLM_CHAT
model.preProcess(task = task)
task.models.add(model)
// Update status.
modelDownloadStatus[model.name] = ModelDownloadStatus(
status = ModelDownloadStatusType.SUCCEEDED,
receivedBytes = localModel.fileSize,
totalBytes = localModel.fileSize
)
}
val textInputHistory = dataStoreRepository.readTextInputHistory() val textInputHistory = dataStoreRepository.readTextInputHistory()
Log.d(TAG, "text input history: $textInputHistory") Log.d(TAG, "text input history: $textInputHistory")
return ModelManagerUiState( return ModelManagerUiState(
tasks = TASKS, tasks = TASKS,
modelsByTaskName = modelsByTaskName,
modelDownloadStatus = modelDownloadStatus, modelDownloadStatus = modelDownloadStatus,
modelInitializationStatus = modelInstances, modelInitializationStatus = modelInstances,
textInputHistory = textInputHistory, textInputHistory = textInputHistory,
@ -610,7 +702,8 @@ open class ModelManagerViewModel(
// Add initial status and states. // Add initial status and states.
modelDownloadStatus[model.name] = getModelDownloadStatus(model = model) modelDownloadStatus[model.name] = getModelDownloadStatus(model = model)
modelInstances[model.name] = ModelInitializationStatus.NOT_INITIALIZED modelInstances[model.name] =
ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED)
} }
} }
} }
@ -677,9 +770,13 @@ open class ModelManagerViewModel(
} }
} }
private fun updateModelInitializationStatus(model: Model, status: ModelInitializationStatus) { private fun updateModelInitializationStatus(
model: Model,
status: ModelInitializationStatusType,
error: String = ""
) {
val curModelInstance = uiState.value.modelInitializationStatus.toMutableMap() val curModelInstance = uiState.value.modelInitializationStatus.toMutableMap()
curModelInstance[model.name] = status curModelInstance[model.name] = ModelInitializationStatus(status = status, error = error)
val newUiState = uiState.value.copy(modelInitializationStatus = curModelInstance) val newUiState = uiState.value.copy(modelInitializationStatus = curModelInstance)
_uiState.update { newUiState } _uiState.update { newUiState }
} }

View file

@ -18,6 +18,7 @@ package com.google.aiedge.gallery.ui.preview
import com.google.aiedge.gallery.data.AccessTokenData import com.google.aiedge.gallery.data.AccessTokenData
import com.google.aiedge.gallery.data.DataStoreRepository import com.google.aiedge.gallery.data.DataStoreRepository
import com.google.aiedge.gallery.data.LocalModelInfo
class PreviewDataStoreRepository : DataStoreRepository { class PreviewDataStoreRepository : DataStoreRepository {
override fun saveTextInputHistory(history: List<String>) { override fun saveTextInputHistory(history: List<String>) {
@ -40,4 +41,11 @@ class PreviewDataStoreRepository : DataStoreRepository {
override fun readAccessTokenData(): AccessTokenData? { override fun readAccessTokenData(): AccessTokenData? {
return null return null
} }
override fun saveLocalModels(localModels: List<LocalModelInfo>) {
}
override fun readLocalModels(): List<LocalModelInfo> {
return listOf()
}
} }

View file

@ -17,7 +17,6 @@
package com.google.aiedge.gallery.ui.preview package com.google.aiedge.gallery.ui.preview
import android.content.Context import android.content.Context
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatus import com.google.aiedge.gallery.data.ModelDownloadStatus
import com.google.aiedge.gallery.data.ModelDownloadStatusType import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerUiState import com.google.aiedge.gallery.ui.modelmanager.ModelManagerUiState
@ -39,8 +38,6 @@ class PreviewModelManagerViewModel(context: Context) :
} }
} }
val modelsByTaskName: Map<String, MutableList<Model>> =
ALL_PREVIEW_TASKS.associate { task -> task.type.label to task.models }
val modelDownloadStatus = mapOf( val modelDownloadStatus = mapOf(
MODEL_TEST1.name to ModelDownloadStatus( MODEL_TEST1.name to ModelDownloadStatus(
status = ModelDownloadStatusType.IN_PROGRESS, status = ModelDownloadStatusType.IN_PROGRESS,
@ -61,7 +58,6 @@ class PreviewModelManagerViewModel(context: Context) :
) )
val newUiState = ModelManagerUiState( val newUiState = ModelManagerUiState(
tasks = ALL_PREVIEW_TASKS, tasks = ALL_PREVIEW_TASKS,
modelsByTaskName = modelsByTaskName,
modelDownloadStatus = modelDownloadStatus, modelDownloadStatus = modelDownloadStatus,
modelInitializationStatus = mapOf(), modelInitializationStatus = mapOf(),
selectedModel = MODEL_TEST2, selectedModel = MODEL_TEST2,

View file

@ -40,14 +40,14 @@ class TextClassificationInferenceResult(
* Helper object for managing text classification models. * Helper object for managing text classification models.
*/ */
object TextClassificationModelHelper { object TextClassificationModelHelper {
fun initialize(context: Context, model: Model, onDone: () -> Unit) { fun initialize(context: Context, model: Model, onDone: (String) -> Unit) {
val modelByteBuffer = readFileToByteBuffer(File(model.getPath(context = context))) val modelByteBuffer = readFileToByteBuffer(File(model.getPath(context = context)))
if (modelByteBuffer != null) { if (modelByteBuffer != null) {
val options = TextClassifier.TextClassifierOptions.builder().setBaseOptions( val options = TextClassifier.TextClassifierOptions.builder().setBaseOptions(
BaseOptions.builder().setModelAssetBuffer(modelByteBuffer).build() BaseOptions.builder().setModelAssetBuffer(modelByteBuffer).build()
).build() ).build()
model.instance = TextClassifier.createFromOptions(context, options) model.instance = TextClassifier.createFromOptions(context, options)
onDone() onDone("")
} }
} }