mirror of
https://github.com/google-ai-edge/gallery.git
synced 2025-07-12 01:12:20 -04:00
Add initial support for importing local model.
This commit is contained in:
parent
b2f35a86e7
commit
29b614355e
19 changed files with 789 additions and 126 deletions
|
@ -47,6 +47,8 @@ interface DataStoreRepository {
|
|||
fun readThemeOverride(): String
|
||||
fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long)
|
||||
fun readAccessTokenData(): AccessTokenData?
|
||||
fun saveLocalModels(localModels: List<LocalModelInfo>)
|
||||
fun readLocalModels(): List<LocalModelInfo>
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -79,6 +81,9 @@ class DefaultDataStoreRepository(
|
|||
val REFRESH_TOKEN_IV = stringPreferencesKey("refresh_token_iv")
|
||||
|
||||
val ACCESS_TOKEN_EXPIRES_AT = longPreferencesKey("access_token_expires_at")
|
||||
|
||||
// Data for all imported local models.
|
||||
val LOCAL_MODELS = stringPreferencesKey("local_models")
|
||||
}
|
||||
|
||||
private val keystoreAlias: String = "com_google_aiedge_gallery_access_token_key"
|
||||
|
@ -155,6 +160,26 @@ class DefaultDataStoreRepository(
|
|||
}
|
||||
}
|
||||
|
||||
override fun saveLocalModels(localModels: List<LocalModelInfo>) {
|
||||
runBlocking {
|
||||
dataStore.edit { preferences ->
|
||||
val gson = Gson()
|
||||
val jsonString = gson.toJson(localModels)
|
||||
preferences[PreferencesKeys.LOCAL_MODELS] = jsonString
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun readLocalModels(): List<LocalModelInfo> {
|
||||
return runBlocking {
|
||||
val preferences = dataStore.data.first()
|
||||
val infosStr = preferences[PreferencesKeys.LOCAL_MODELS] ?: "[]"
|
||||
val gson = Gson()
|
||||
val listType = object : TypeToken<List<LocalModelInfo>>() {}.type
|
||||
gson.fromJson(infosStr, listType)
|
||||
}
|
||||
}
|
||||
|
||||
private fun getTextInputHistory(preferences: Preferences): List<String> {
|
||||
val infosStr = preferences[PreferencesKeys.TEXT_INPUT_HISTORY] ?: "[]"
|
||||
val gson = Gson()
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package com.google.aiedge.gallery.data
|
||||
|
||||
import com.google.aiedge.gallery.ui.common.ensureValidFileName
|
||||
import com.google.aiedge.gallery.ui.llmchat.createLLmChatConfig
|
||||
import kotlinx.serialization.KSerializer
|
||||
import kotlinx.serialization.Serializable
|
||||
|
@ -103,7 +104,7 @@ data class HfModel(
|
|||
} else {
|
||||
listOf("")
|
||||
}
|
||||
val fileName = "${id}_${(parts.lastOrNull() ?: "")}".replace(Regex("[^a-zA-Z0-9._-]"), "_")
|
||||
val fileName = ensureValidFileName("${id}_${(parts.lastOrNull() ?: "")}")
|
||||
|
||||
// Generate configs based on the given default values.
|
||||
val configs: List<Config> = when (task) {
|
||||
|
|
|
@ -32,6 +32,8 @@ enum class LlmBackend {
|
|||
CPU, GPU
|
||||
}
|
||||
|
||||
const val IMPORTS_DIR = "__imports"
|
||||
|
||||
/** A model for a task */
|
||||
data class Model(
|
||||
/** The Hugging Face model ID (if applicable). */
|
||||
|
@ -85,6 +87,9 @@ data class Model(
|
|||
/** The prompt templates for the model (only for LLM). */
|
||||
val llmPromptTemplates: List<PromptTemplate> = listOf(),
|
||||
|
||||
/** Whether the model is imported as a local model. */
|
||||
val isLocalModel: Boolean = false,
|
||||
|
||||
// The following fields are managed by the app. Don't need to set manually.
|
||||
var taskType: TaskType? = null,
|
||||
var instance: Any? = null,
|
||||
|
@ -104,10 +109,11 @@ data class Model(
|
|||
}
|
||||
|
||||
fun getPath(context: Context, fileName: String = downloadFileName): String {
|
||||
val baseDir = "${context.getExternalFilesDir(null)}"
|
||||
return if (this.isZip && this.unzipDir.isNotEmpty()) {
|
||||
"${context.getExternalFilesDir(null)}/${this.unzipDir}"
|
||||
"$baseDir/${this.unzipDir}"
|
||||
} else {
|
||||
"${context.getExternalFilesDir(null)}/${fileName}"
|
||||
"$baseDir/${fileName}"
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -140,6 +146,9 @@ data class Model(
|
|||
}
|
||||
}
|
||||
|
||||
/** Data for a imported local model. */
|
||||
data class LocalModelInfo(val fileName: String, val fileSize: Long)
|
||||
|
||||
enum class ModelDownloadStatusType {
|
||||
NOT_DOWNLOADED, PARTIALLY_DOWNLOADED, IN_PROGRESS, UNZIPPING, SUCCEEDED, FAILED,
|
||||
}
|
||||
|
|
|
@ -19,6 +19,8 @@ package com.google.aiedge.gallery.data
|
|||
import androidx.annotation.StringRes
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.rounded.ImageSearch
|
||||
import androidx.compose.runtime.MutableState
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.ui.graphics.vector.ImageVector
|
||||
import com.google.aiedge.gallery.R
|
||||
|
||||
|
@ -63,7 +65,9 @@ data class Task(
|
|||
@StringRes val textInputPlaceHolderRes: Int = R.string.chat_textinput_placeholder,
|
||||
|
||||
// The following fields are managed by the app. Don't need to set manually.
|
||||
var index: Int = -1
|
||||
var index: Int = -1,
|
||||
|
||||
val updateTrigger: MutableState<Long> = mutableStateOf(0)
|
||||
)
|
||||
|
||||
val TASK_TEXT_CLASSIFICATION = Task(
|
||||
|
|
|
@ -46,6 +46,7 @@ import androidx.compose.ui.Modifier
|
|||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.unit.dp
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.modelmanager.TokenRequestResultType
|
||||
import com.google.aiedge.gallery.ui.modelmanager.TokenStatus
|
||||
|
@ -90,6 +91,7 @@ private const val TAG = "AGDownloadAndTryButton"
|
|||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@Composable
|
||||
fun DownloadAndTryButton(
|
||||
task: Task,
|
||||
model: Model,
|
||||
enabled: Boolean,
|
||||
needToDownloadFirst: Boolean,
|
||||
|
@ -106,17 +108,18 @@ fun DownloadAndTryButton(
|
|||
val permissionLauncher = rememberLauncherForActivityResult(
|
||||
ActivityResultContracts.RequestPermission()
|
||||
) {
|
||||
modelManagerViewModel.downloadModel(model)
|
||||
modelManagerViewModel.downloadModel(task = task, model = model)
|
||||
}
|
||||
|
||||
// Function to kick off download.
|
||||
val startDownload: (accessToken: String?) -> Unit = { accessToken ->
|
||||
model.accessToken = accessToken
|
||||
onClicked()
|
||||
checkNotificationPermissonAndStartDownload(
|
||||
checkNotificationPermissionAndStartDownload(
|
||||
context = context,
|
||||
launcher = permissionLauncher,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
task = task,
|
||||
model = model
|
||||
)
|
||||
checkingToken = false
|
||||
|
|
|
@ -416,10 +416,11 @@ fun getTaskIconColor(index: Int): Color {
|
|||
return MaterialTheme.customColors.taskIconColors[colorIndex]
|
||||
}
|
||||
|
||||
fun checkNotificationPermissonAndStartDownload(
|
||||
fun checkNotificationPermissionAndStartDownload(
|
||||
context: Context,
|
||||
launcher: ManagedActivityResultLauncher<String, Boolean>,
|
||||
modelManagerViewModel: ModelManagerViewModel,
|
||||
task: Task,
|
||||
model: Model
|
||||
) {
|
||||
// Check permission
|
||||
|
@ -428,7 +429,7 @@ fun checkNotificationPermissonAndStartDownload(
|
|||
ContextCompat.checkSelfPermission(
|
||||
context, Manifest.permission.POST_NOTIFICATIONS
|
||||
) -> {
|
||||
modelManagerViewModel.downloadModel(model)
|
||||
modelManagerViewModel.downloadModel(task = task, model = model)
|
||||
}
|
||||
|
||||
// Otherwise, ask for permission
|
||||
|
@ -440,3 +441,14 @@ fun checkNotificationPermissonAndStartDownload(
|
|||
}
|
||||
}
|
||||
|
||||
fun ensureValidFileName(fileName: String): String {
|
||||
return fileName.replace(Regex("[^a-zA-Z0-9._-]"), "_")
|
||||
}
|
||||
|
||||
fun cleanUpMediapipeTaskErrorMessage(message: String): String {
|
||||
val index = message.indexOf("=== Source Location Trace")
|
||||
if (index >= 0) {
|
||||
return message.substring(0, index)
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
|
|
@ -41,10 +41,13 @@ import androidx.compose.foundation.layout.wrapContentHeight
|
|||
import androidx.compose.foundation.lazy.LazyColumn
|
||||
import androidx.compose.foundation.lazy.items
|
||||
import androidx.compose.foundation.lazy.rememberLazyListState
|
||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.outlined.Timer
|
||||
import androidx.compose.material.icons.rounded.ContentCopy
|
||||
import androidx.compose.material.icons.rounded.Refresh
|
||||
import androidx.compose.material3.Button
|
||||
import androidx.compose.material3.Card
|
||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
|
@ -83,11 +86,12 @@ import androidx.compose.ui.res.stringResource
|
|||
import androidx.compose.ui.text.AnnotatedString
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.window.Dialog
|
||||
import com.google.aiedge.gallery.R
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.data.TaskType
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatus
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.preview.PreviewChatModel
|
||||
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||
|
@ -113,6 +117,7 @@ fun ChatPanel(
|
|||
onSendMessage: (Model, ChatMessage) -> Unit,
|
||||
onRunAgainClicked: (Model, ChatMessage) -> Unit,
|
||||
onBenchmarkClicked: (Model, ChatMessage, warmUpIterations: Int, benchmarkIterations: Int) -> Unit,
|
||||
navigateUp: () -> Unit,
|
||||
modifier: Modifier = Modifier,
|
||||
onStreamImageMessage: (Model, ChatMessageImage) -> Unit = { _, _ -> },
|
||||
onStreamEnd: (Int) -> Unit = {},
|
||||
|
@ -140,6 +145,8 @@ fun ChatPanel(
|
|||
var showMessageLongPressedSheet by remember { mutableStateOf(false) }
|
||||
val longPressedMessage: MutableState<ChatMessage?> = remember { mutableStateOf(null) }
|
||||
|
||||
var showErrorDialog by remember { mutableStateOf(false) }
|
||||
|
||||
// Keep track of the last message and last message content.
|
||||
val lastMessage: MutableState<ChatMessage?> = remember { mutableStateOf(null) }
|
||||
val lastMessageContent: MutableState<String> = remember { mutableStateOf("") }
|
||||
|
@ -201,6 +208,10 @@ fun ChatPanel(
|
|||
val modelInitializationStatus =
|
||||
modelManagerUiState.modelInitializationStatus[selectedModel.name]
|
||||
|
||||
LaunchedEffect(modelInitializationStatus) {
|
||||
showErrorDialog = modelInitializationStatus?.status == ModelInitializationStatusType.ERROR
|
||||
}
|
||||
|
||||
Column(
|
||||
modifier = modifier.imePadding()
|
||||
) {
|
||||
|
@ -417,7 +428,7 @@ fun ChatPanel(
|
|||
|
||||
// Model initialization in-progress message.
|
||||
this@Column.AnimatedVisibility(
|
||||
visible = modelInitializationStatus == ModelInitializationStatus.INITIALIZING,
|
||||
visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
|
||||
enter = scaleIn() + fadeIn(),
|
||||
exit = scaleOut() + fadeOut(),
|
||||
modifier = Modifier.offset(y = 12.dp)
|
||||
|
@ -479,6 +490,47 @@ fun ChatPanel(
|
|||
}
|
||||
}
|
||||
|
||||
// Error dialog.
|
||||
if (showErrorDialog) {
|
||||
Dialog(
|
||||
onDismissRequest = {
|
||||
showErrorDialog = false
|
||||
navigateUp()
|
||||
},
|
||||
) {
|
||||
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.padding(20.dp),
|
||||
verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||
) {
|
||||
// Title
|
||||
Text(
|
||||
"Error",
|
||||
style = MaterialTheme.typography.titleLarge,
|
||||
modifier = Modifier.padding(bottom = 8.dp)
|
||||
)
|
||||
|
||||
// Error
|
||||
Text(
|
||||
modelInitializationStatus?.error ?: "",
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.error,
|
||||
)
|
||||
|
||||
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.End) {
|
||||
Button(onClick = {
|
||||
showErrorDialog = false
|
||||
navigateUp()
|
||||
}) {
|
||||
Text("Close")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark config dialog.
|
||||
if (showBenchmarkConfigsDialog) {
|
||||
BenchmarkConfigDialog(onDismissed = { showBenchmarkConfigsDialog = false },
|
||||
|
@ -547,6 +599,7 @@ fun ChatPanelPreview() {
|
|||
task = task,
|
||||
selectedModel = TASK_TEST1.models[1],
|
||||
viewModel = PreviewChatModel(context = context),
|
||||
navigateUp = {},
|
||||
onSendMessage = { _, _ -> },
|
||||
onRunAgainClicked = { _, _ -> },
|
||||
onBenchmarkClicked = { _, _, _, _ -> },
|
||||
|
|
|
@ -55,7 +55,7 @@ import com.google.aiedge.gallery.data.AppBarActionType
|
|||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.ui.common.checkNotificationPermissonAndStartDownload
|
||||
import com.google.aiedge.gallery.ui.common.checkNotificationPermissionAndStartDownload
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.preview.PreviewChatModel
|
||||
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||
|
@ -104,7 +104,7 @@ fun ChatView(
|
|||
val launcher = rememberLauncherForActivityResult(
|
||||
ActivityResultContracts.RequestPermission()
|
||||
) {
|
||||
modelManagerViewModel.downloadModel(selectedModel)
|
||||
modelManagerViewModel.downloadModel(task = task, model = selectedModel)
|
||||
}
|
||||
|
||||
val handleNavigateUp = {
|
||||
|
@ -245,10 +245,11 @@ fun ChatView(
|
|||
exit = fadeOut()
|
||||
) {
|
||||
ModelNotDownloaded(modifier = Modifier.weight(1f), onClicked = {
|
||||
checkNotificationPermissonAndStartDownload(
|
||||
checkNotificationPermissionAndStartDownload(
|
||||
context = context,
|
||||
launcher = launcher,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
task = task,
|
||||
model = curSelectedModel
|
||||
)
|
||||
})
|
||||
|
@ -261,6 +262,7 @@ fun ChatView(
|
|||
task = task,
|
||||
selectedModel = curSelectedModel,
|
||||
viewModel = viewModel,
|
||||
navigateUp = navigateUp,
|
||||
onSendMessage = onSendMessage,
|
||||
onRunAgainClicked = onRunAgainClicked,
|
||||
onBenchmarkClicked = onBenchmarkClicked,
|
||||
|
|
|
@ -35,6 +35,7 @@ import androidx.compose.foundation.layout.offset
|
|||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.rounded.ChevronRight
|
||||
import androidx.compose.material.icons.rounded.Settings
|
||||
import androidx.compose.material.icons.rounded.UnfoldLess
|
||||
import androidx.compose.material.icons.rounded.UnfoldMore
|
||||
|
@ -68,7 +69,7 @@ import com.google.aiedge.gallery.data.Task
|
|||
import com.google.aiedge.gallery.ui.common.DownloadAndTryButton
|
||||
import com.google.aiedge.gallery.ui.common.TaskIcon
|
||||
import com.google.aiedge.gallery.ui.common.chat.MarkdownText
|
||||
import com.google.aiedge.gallery.ui.common.checkNotificationPermissonAndStartDownload
|
||||
import com.google.aiedge.gallery.ui.common.checkNotificationPermissionAndStartDownload
|
||||
import com.google.aiedge.gallery.ui.common.getTaskBgColor
|
||||
import com.google.aiedge.gallery.ui.common.getTaskIconColor
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
|
@ -113,7 +114,7 @@ fun ModelItem(
|
|||
val launcher = rememberLauncherForActivityResult(
|
||||
ActivityResultContracts.RequestPermission()
|
||||
) {
|
||||
modelManagerViewModel.downloadModel(model)
|
||||
modelManagerViewModel.downloadModel(task = task, model = model)
|
||||
}
|
||||
|
||||
var isExpanded by remember { mutableStateOf(false) }
|
||||
|
@ -156,10 +157,11 @@ fun ModelItem(
|
|||
modelManagerViewModel = modelManagerViewModel,
|
||||
downloadStatus = downloadStatus,
|
||||
onDownloadClicked = { model ->
|
||||
checkNotificationPermissonAndStartDownload(
|
||||
checkNotificationPermissionAndStartDownload(
|
||||
context = context,
|
||||
launcher = launcher,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
task = task,
|
||||
model = model
|
||||
)
|
||||
},
|
||||
|
@ -186,7 +188,9 @@ fun ModelItem(
|
|||
}
|
||||
} else {
|
||||
Icon(
|
||||
if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore,
|
||||
// For local model, show ">" directly indicating users can just tap the model item to
|
||||
// go into it without needing to expand it first.
|
||||
if (model.isLocalModel) Icons.Rounded.ChevronRight else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore,
|
||||
contentDescription = "",
|
||||
tint = getTaskIconColor(task),
|
||||
)
|
||||
|
@ -237,6 +241,7 @@ fun ModelItem(
|
|||
val needToDownloadFirst =
|
||||
downloadStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED || downloadStatus?.status == ModelDownloadStatusType.FAILED
|
||||
DownloadAndTryButton(
|
||||
task = task,
|
||||
model = model,
|
||||
enabled = isExpanded,
|
||||
needToDownloadFirst = needToDownloadFirst,
|
||||
|
@ -266,7 +271,13 @@ fun ModelItem(
|
|||
)
|
||||
boxModifier = if (canExpand) {
|
||||
boxModifier.clickable(
|
||||
onClick = { isExpanded = !isExpanded },
|
||||
onClick = {
|
||||
if (!model.isLocalModel) {
|
||||
isExpanded = !isExpanded
|
||||
} else {
|
||||
onModelClicked(model)
|
||||
}
|
||||
},
|
||||
interactionSource = remember { MutableInteractionSource() },
|
||||
indication = ripple(
|
||||
bounded = true,
|
||||
|
|
|
@ -124,7 +124,7 @@ fun ModelItemActionButton(
|
|||
|
||||
if (showConfirmDeleteDialog) {
|
||||
ConfirmDeleteModelDialog(model = model, onConfirm = {
|
||||
modelManagerViewModel.deleteModel(model)
|
||||
modelManagerViewModel.deleteModel(task = task, model = model)
|
||||
showConfirmDeleteDialog = false
|
||||
}, onDismiss = {
|
||||
showConfirmDeleteDialog = false
|
||||
|
|
|
@ -48,7 +48,7 @@ class ImageClassificationInferenceResult(
|
|||
//TODO: handle error.
|
||||
|
||||
object ImageClassificationModelHelper {
|
||||
fun initialize(context: Context, model: Model, onDone: () -> Unit) {
|
||||
fun initialize(context: Context, model: Model, onDone: (String) -> Unit) {
|
||||
val useGpu = model.getBooleanConfigValue(key = ConfigKey.USE_GPU)
|
||||
TfLiteGpu.isGpuDelegateAvailable(context).continueWith { gpuTask ->
|
||||
val optionsBuilder = TfLiteInitializationOptions.builder()
|
||||
|
@ -69,7 +69,7 @@ object ImageClassificationModelHelper {
|
|||
File(model.getPath(context = context)), interpreterOption
|
||||
)
|
||||
model.instance = interpreter
|
||||
onDone()
|
||||
onDone("")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator
|
|||
import com.google.aiedge.gallery.data.ConfigKey
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.ui.common.LatencyProvider
|
||||
import com.google.aiedge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage
|
||||
import kotlin.random.Random
|
||||
|
||||
private const val TAG = "AGImageGenerationModelHelper"
|
||||
|
@ -33,12 +34,17 @@ class ImageGenerationInferenceResult(
|
|||
) : LatencyProvider
|
||||
|
||||
object ImageGenerationModelHelper {
|
||||
fun initialize(context: Context, model: Model, onDone: () -> Unit) {
|
||||
val options = ImageGenerator.ImageGeneratorOptions.builder()
|
||||
.setImageGeneratorModelDirectory(model.getPath(context = context))
|
||||
.build()
|
||||
model.instance = ImageGenerator.createFromOptions(context, options)
|
||||
onDone()
|
||||
fun initialize(context: Context, model: Model, onDone: (String) -> Unit) {
|
||||
try {
|
||||
val options = ImageGenerator.ImageGeneratorOptions.builder()
|
||||
.setImageGeneratorModelDirectory(model.getPath(context = context))
|
||||
.build()
|
||||
model.instance = ImageGenerator.createFromOptions(context, options)
|
||||
} catch (e: Exception) {
|
||||
onDone(cleanUpMediapipeTaskErrorMessage(e.message ?: "Unknown error"))
|
||||
return
|
||||
}
|
||||
onDone("")
|
||||
}
|
||||
|
||||
fun cleanUp(model: Model) {
|
||||
|
|
|
@ -21,6 +21,7 @@ import android.util.Log
|
|||
import com.google.aiedge.gallery.data.ConfigKey
|
||||
import com.google.aiedge.gallery.data.LlmBackend
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage
|
||||
import com.google.mediapipe.tasks.genai.llminference.LlmInference
|
||||
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
|
||||
|
||||
|
@ -40,7 +41,7 @@ object LlmChatModelHelper {
|
|||
private val cleanUpListeners: MutableMap<String, CleanUpListener> = mutableMapOf()
|
||||
|
||||
fun initialize(
|
||||
context: Context, model: Model, onDone: () -> Unit
|
||||
context: Context, model: Model, onDone: (String) -> Unit
|
||||
) {
|
||||
val maxTokens =
|
||||
model.getIntConfigValue(key = ConfigKey.MAX_TOKENS, defaultValue = DEFAULT_MAX_TOKEN)
|
||||
|
@ -68,9 +69,10 @@ object LlmChatModelHelper {
|
|||
)
|
||||
model.instance = LlmModelInstance(engine = llmInference, session = session)
|
||||
} catch (e: Exception) {
|
||||
e.printStackTrace()
|
||||
onDone(cleanUpMediapipeTaskErrorMessage(e.message ?: "Unknown error"))
|
||||
return
|
||||
}
|
||||
onDone()
|
||||
onDone("")
|
||||
}
|
||||
|
||||
fun cleanUp(model: Model) {
|
||||
|
|
|
@ -0,0 +1,260 @@
|
|||
/*
|
||||
* Copyright 2025 Google LLC
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.google.aiedge.gallery.ui.modelmanager
|
||||
|
||||
import android.content.Context
|
||||
import android.net.Uri
|
||||
import android.provider.OpenableColumns
|
||||
import android.util.Log
|
||||
import androidx.compose.animation.core.Animatable
|
||||
import androidx.compose.animation.core.tween
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.rounded.Error
|
||||
import androidx.compose.material3.Button
|
||||
import androidx.compose.material3.Card
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.LinearProgressIndicator
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableFloatStateOf
|
||||
import androidx.compose.runtime.mutableLongStateOf
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.rememberCoroutineScope
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.window.Dialog
|
||||
import androidx.compose.ui.window.DialogProperties
|
||||
import com.google.aiedge.gallery.data.IMPORTS_DIR
|
||||
import com.google.aiedge.gallery.ui.common.ensureValidFileName
|
||||
import com.google.aiedge.gallery.ui.common.humanReadableSize
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.launch
|
||||
import java.io.File
|
||||
import java.io.FileOutputStream
|
||||
import java.net.URLDecoder
|
||||
import java.nio.charset.StandardCharsets
|
||||
|
||||
private const val TAG = "AGModelImportDialog"
|
||||
|
||||
data class ModelImportInfo(val fileName: String, val fileSize: Long, val error: String = "")
|
||||
|
||||
@Composable
|
||||
fun ModelImportDialog(
|
||||
uri: Uri, onDone: (ModelImportInfo) -> Unit
|
||||
) {
|
||||
val context = LocalContext.current
|
||||
val coroutineScope = rememberCoroutineScope()
|
||||
|
||||
var fileName by remember { mutableStateOf("") }
|
||||
var fileSize by remember { mutableLongStateOf(0L) }
|
||||
var error by remember { mutableStateOf("") }
|
||||
var progress by remember { mutableFloatStateOf(0f) }
|
||||
|
||||
LaunchedEffect(Unit) {
|
||||
error = ""
|
||||
|
||||
// Get basic info.
|
||||
val info = getFileSizeAndDisplayNameFromUri(context = context, uri = uri)
|
||||
fileSize = info.first
|
||||
fileName = ensureValidFileName(info.second)
|
||||
|
||||
// Import.
|
||||
importModel(
|
||||
context = context,
|
||||
coroutineScope = coroutineScope,
|
||||
fileName = fileName,
|
||||
fileSize = fileSize,
|
||||
uri = uri,
|
||||
onDone = {
|
||||
onDone(ModelImportInfo(fileName = fileName, fileSize = fileSize, error = error))
|
||||
},
|
||||
onProgress = {
|
||||
progress = it
|
||||
},
|
||||
onError = {
|
||||
error = it
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
Dialog(
|
||||
properties = DialogProperties(dismissOnBackPress = false, dismissOnClickOutside = false),
|
||||
onDismissRequest = {},
|
||||
) {
|
||||
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.padding(20.dp),
|
||||
verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||
) {
|
||||
// Title.
|
||||
Text(
|
||||
"Importing...",
|
||||
style = MaterialTheme.typography.titleLarge,
|
||||
modifier = Modifier.padding(bottom = 8.dp)
|
||||
)
|
||||
|
||||
// No error.
|
||||
if (error.isEmpty()) {
|
||||
// Progress bar.
|
||||
Column(verticalArrangement = Arrangement.spacedBy(4.dp)) {
|
||||
Text(
|
||||
"$fileName (${fileSize.humanReadableSize()})",
|
||||
style = MaterialTheme.typography.labelSmall,
|
||||
)
|
||||
val animatedProgress = remember { Animatable(0f) }
|
||||
LinearProgressIndicator(
|
||||
progress = { animatedProgress.value },
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(bottom = 8.dp),
|
||||
)
|
||||
LaunchedEffect(progress) {
|
||||
animatedProgress.animateTo(progress, animationSpec = tween(150))
|
||||
}
|
||||
}
|
||||
}
|
||||
// Has error.
|
||||
else {
|
||||
Row(
|
||||
verticalAlignment = Alignment.Top,
|
||||
horizontalArrangement = Arrangement.spacedBy(6.dp)
|
||||
) {
|
||||
Icon(
|
||||
Icons.Rounded.Error,
|
||||
contentDescription = "",
|
||||
tint = MaterialTheme.colorScheme.error
|
||||
)
|
||||
Text(
|
||||
error,
|
||||
style = MaterialTheme.typography.labelSmall,
|
||||
color = MaterialTheme.colorScheme.error,
|
||||
modifier = Modifier.padding(top = 4.dp)
|
||||
)
|
||||
}
|
||||
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.End) {
|
||||
Button(onClick = {
|
||||
onDone(ModelImportInfo(fileName = "", fileSize = 0L, error = error))
|
||||
}) {
|
||||
Text("Close")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun importModel(
|
||||
context: Context,
|
||||
coroutineScope: CoroutineScope,
|
||||
fileName: String,
|
||||
fileSize: Long,
|
||||
uri: Uri,
|
||||
onDone: () -> Unit,
|
||||
onProgress: (Float) -> Unit,
|
||||
onError: (String) -> Unit,
|
||||
) {
|
||||
// TODO: handle error.
|
||||
coroutineScope.launch(Dispatchers.IO) {
|
||||
// Get the last component of the uri path as the imported file name.
|
||||
val decodedUri = URLDecoder.decode(uri.toString(), StandardCharsets.UTF_8.name())
|
||||
Log.d(TAG, "importing model from $decodedUri. File name: $fileName. File size: $fileSize")
|
||||
|
||||
// Create <app_external_dir>/imports if not exist.
|
||||
val importsDir = File(context.getExternalFilesDir(null), IMPORTS_DIR)
|
||||
if (!importsDir.exists()) {
|
||||
importsDir.mkdirs()
|
||||
}
|
||||
|
||||
// Import by copying the file over.
|
||||
val outputFile = File(context.getExternalFilesDir(null), "$IMPORTS_DIR/$fileName")
|
||||
val outputStream = FileOutputStream(outputFile)
|
||||
val buffer = ByteArray(DEFAULT_BUFFER_SIZE)
|
||||
var bytesRead: Int
|
||||
var lastSetProgressTs: Long = 0
|
||||
var importedBytes = 0L
|
||||
val inputStream = context.contentResolver.openInputStream(uri)
|
||||
try {
|
||||
if (inputStream != null) {
|
||||
while (inputStream.read(buffer).also { bytesRead = it } != -1) {
|
||||
outputStream.write(buffer, 0, bytesRead)
|
||||
importedBytes += bytesRead
|
||||
|
||||
// Report progress every 200 ms.
|
||||
val curTs = System.currentTimeMillis()
|
||||
if (curTs - lastSetProgressTs > 200) {
|
||||
Log.d(TAG, "importing progress: $importedBytes, $fileSize")
|
||||
lastSetProgressTs = curTs
|
||||
if (fileSize != 0L) {
|
||||
onProgress(importedBytes.toFloat() / fileSize.toFloat())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
e.printStackTrace()
|
||||
onError(e.message ?: "Failed to import")
|
||||
return@launch
|
||||
} finally {
|
||||
inputStream?.close()
|
||||
outputStream.close()
|
||||
}
|
||||
Log.d(TAG, "import done")
|
||||
onProgress(1f)
|
||||
onDone()
|
||||
}
|
||||
}
|
||||
|
||||
private fun getFileSizeAndDisplayNameFromUri(context: Context, uri: Uri): Pair<Long, String> {
|
||||
val contentResolver = context.contentResolver
|
||||
var fileSize = 0L
|
||||
var displayName = ""
|
||||
|
||||
try {
|
||||
contentResolver.query(
|
||||
uri, arrayOf(OpenableColumns.SIZE, OpenableColumns.DISPLAY_NAME), null, null, null
|
||||
)?.use { cursor ->
|
||||
if (cursor.moveToFirst()) {
|
||||
val sizeIndex = cursor.getColumnIndexOrThrow(OpenableColumns.SIZE)
|
||||
fileSize = cursor.getLong(sizeIndex)
|
||||
|
||||
val nameIndex = cursor.getColumnIndexOrThrow(OpenableColumns.DISPLAY_NAME)
|
||||
displayName = cursor.getString(nameIndex)
|
||||
}
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
e.printStackTrace()
|
||||
return Pair(0L, "")
|
||||
}
|
||||
|
||||
return Pair(fileSize, displayName)
|
||||
}
|
|
@ -16,24 +16,46 @@
|
|||
|
||||
package com.google.aiedge.gallery.ui.modelmanager
|
||||
|
||||
import android.content.Intent
|
||||
import android.net.Uri
|
||||
import android.os.Build
|
||||
import android.util.Log
|
||||
import androidx.activity.compose.rememberLauncherForActivityResult
|
||||
import androidx.activity.result.ActivityResultLauncher
|
||||
import androidx.activity.result.contract.ActivityResultContracts
|
||||
import androidx.annotation.RequiresApi
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.PaddingValues
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.Spacer
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.height
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.layout.size
|
||||
import androidx.compose.foundation.lazy.LazyColumn
|
||||
import androidx.compose.foundation.lazy.items
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.automirrored.outlined.NoteAdd
|
||||
import androidx.compose.material.icons.filled.Add
|
||||
import androidx.compose.material.icons.outlined.Code
|
||||
import androidx.compose.material.icons.outlined.Description
|
||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.ModalBottomSheet
|
||||
import androidx.compose.material3.SmallFloatingActionButton
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.material3.rememberModalBottomSheetState
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.derivedStateOf
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.rememberCoroutineScope
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.graphics.vector.ImageVector
|
||||
|
@ -53,8 +75,14 @@ import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
|||
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
|
||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||
import com.google.aiedge.gallery.ui.theme.customColors
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.launch
|
||||
|
||||
private const val TAG = "AGModelList"
|
||||
|
||||
/** The list of models in the model manager. */
|
||||
@RequiresApi(Build.VERSION_CODES.O)
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@Composable
|
||||
fun ModelList(
|
||||
task: Task,
|
||||
|
@ -63,65 +91,214 @@ fun ModelList(
|
|||
onModelClicked: (Model) -> Unit,
|
||||
modifier: Modifier = Modifier,
|
||||
) {
|
||||
LazyColumn(
|
||||
modifier = modifier.padding(top = 8.dp),
|
||||
contentPadding = contentPadding,
|
||||
verticalArrangement = Arrangement.spacedBy(8.dp),
|
||||
) {
|
||||
// Headline.
|
||||
item(key = "headline") {
|
||||
Text(
|
||||
task.description,
|
||||
textAlign = TextAlign.Center,
|
||||
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
)
|
||||
}
|
||||
var showAddModelSheet by remember { mutableStateOf(false) }
|
||||
var showImportingDialog by remember { mutableStateOf(false) }
|
||||
val curFileUri = remember { mutableStateOf<Uri?>(null) }
|
||||
val sheetState = rememberModalBottomSheetState()
|
||||
val coroutineScope = rememberCoroutineScope()
|
||||
|
||||
// URLs.
|
||||
item(key = "urls") {
|
||||
Row(
|
||||
horizontalArrangement = Arrangement.Center,
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(top = 12.dp, bottom = 16.dp),
|
||||
) {
|
||||
Column(
|
||||
horizontalAlignment = Alignment.Start,
|
||||
verticalArrangement = Arrangement.spacedBy(4.dp),
|
||||
// This is just to update "models" list when task.updateTrigger is updated so that the UI can
|
||||
// be properly updated.
|
||||
val models by remember {
|
||||
derivedStateOf {
|
||||
val trigger = task.updateTrigger.value
|
||||
if (trigger >= 0) {
|
||||
task.models.toList().filter { !it.isLocalModel }
|
||||
} else {
|
||||
listOf()
|
||||
}
|
||||
}
|
||||
}
|
||||
val localModels by remember {
|
||||
derivedStateOf {
|
||||
val trigger = task.updateTrigger.value
|
||||
if (trigger >= 0) {
|
||||
task.models.toList().filter { it.isLocalModel }
|
||||
} else {
|
||||
listOf()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val filePickerLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult(
|
||||
contract = ActivityResultContracts.StartActivityForResult()
|
||||
) { result ->
|
||||
if (result.resultCode == android.app.Activity.RESULT_OK) {
|
||||
result.data?.data?.let { uri ->
|
||||
curFileUri.value = uri
|
||||
showImportingDialog = true
|
||||
} ?: run {
|
||||
Log.d(TAG, "No file selected or URI is null.")
|
||||
}
|
||||
} else {
|
||||
Log.d(TAG, "File picking cancelled.")
|
||||
}
|
||||
}
|
||||
|
||||
Box(contentAlignment = Alignment.BottomEnd) {
|
||||
LazyColumn(
|
||||
modifier = modifier.padding(top = 8.dp),
|
||||
contentPadding = contentPadding,
|
||||
verticalArrangement = Arrangement.spacedBy(8.dp),
|
||||
) {
|
||||
// Headline.
|
||||
item(key = "headline") {
|
||||
Text(
|
||||
task.description,
|
||||
textAlign = TextAlign.Center,
|
||||
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
|
||||
modifier = Modifier.fillMaxWidth()
|
||||
)
|
||||
}
|
||||
|
||||
// URLs.
|
||||
item(key = "urls") {
|
||||
Row(
|
||||
horizontalArrangement = Arrangement.Center,
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(top = 12.dp, bottom = 16.dp),
|
||||
) {
|
||||
if (task.docUrl.isNotEmpty()) {
|
||||
ClickableLink(
|
||||
url = task.docUrl,
|
||||
linkText = "API Documentation",
|
||||
icon = Icons.Outlined.Description
|
||||
)
|
||||
}
|
||||
if (task.sourceCodeUrl.isNotEmpty()) {
|
||||
ClickableLink(
|
||||
url = task.sourceCodeUrl,
|
||||
linkText = "Example code",
|
||||
icon = Icons.Outlined.Code
|
||||
)
|
||||
Column(
|
||||
horizontalAlignment = Alignment.Start,
|
||||
verticalArrangement = Arrangement.spacedBy(4.dp),
|
||||
) {
|
||||
if (task.docUrl.isNotEmpty()) {
|
||||
ClickableLink(
|
||||
url = task.docUrl, linkText = "API Documentation", icon = Icons.Outlined.Description
|
||||
)
|
||||
}
|
||||
if (task.sourceCodeUrl.isNotEmpty()) {
|
||||
ClickableLink(
|
||||
url = task.sourceCodeUrl, linkText = "Example code", icon = Icons.Outlined.Code
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// List of models within a task.
|
||||
items(items = models) { model ->
|
||||
Box {
|
||||
ModelItem(
|
||||
model = model,
|
||||
task = task,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
onModelClicked = onModelClicked,
|
||||
modifier = Modifier.padding(horizontal = 12.dp)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Title for local models.
|
||||
if (localModels.isNotEmpty()) {
|
||||
item(key = "localModelsTitle") {
|
||||
Text(
|
||||
"Local models",
|
||||
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
|
||||
modifier = Modifier
|
||||
.padding(horizontal = 16.dp)
|
||||
.padding(top = 24.dp)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// List of local models within a task.
|
||||
items(items = localModels) { model ->
|
||||
Box {
|
||||
ModelItem(
|
||||
model = model,
|
||||
task = task,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
onModelClicked = onModelClicked,
|
||||
modifier = Modifier.padding(horizontal = 12.dp)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
item(key = "bottomPadding") {
|
||||
Spacer(modifier = Modifier.height(60.dp))
|
||||
}
|
||||
}
|
||||
|
||||
// List of models within a task.
|
||||
items(items = task.models) { model ->
|
||||
Box {
|
||||
ModelItem(
|
||||
model = model,
|
||||
task = task,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
onModelClicked = onModelClicked,
|
||||
modifier = Modifier.padding(start = 12.dp, end = 12.dp)
|
||||
)
|
||||
// Add model button at the bottom right.
|
||||
Box(
|
||||
modifier = Modifier
|
||||
.padding(end = 16.dp)
|
||||
.padding(bottom = contentPadding.calculateBottomPadding())
|
||||
) {
|
||||
SmallFloatingActionButton(
|
||||
onClick = {
|
||||
showAddModelSheet = true
|
||||
},
|
||||
containerColor = MaterialTheme.colorScheme.secondaryContainer,
|
||||
contentColor = MaterialTheme.colorScheme.secondary,
|
||||
) {
|
||||
Icon(Icons.Filled.Add, "")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (showAddModelSheet) {
|
||||
ModalBottomSheet(
|
||||
onDismissRequest = { showAddModelSheet = false },
|
||||
sheetState = sheetState,
|
||||
) {
|
||||
Text(
|
||||
"Add custom model",
|
||||
style = MaterialTheme.typography.titleLarge,
|
||||
modifier = Modifier.padding(vertical = 4.dp, horizontal = 16.dp)
|
||||
)
|
||||
Box(modifier = Modifier.clickable {
|
||||
coroutineScope.launch {
|
||||
// Give it sometime to show the click effect.
|
||||
delay(200)
|
||||
showAddModelSheet = false
|
||||
|
||||
// Show file picker.
|
||||
val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
|
||||
addCategory(Intent.CATEGORY_OPENABLE)
|
||||
type = "*/*"
|
||||
putExtra(
|
||||
Intent.EXTRA_MIME_TYPES,
|
||||
arrayOf("application/x-binary", "application/octet-stream")
|
||||
)
|
||||
// Single select.
|
||||
putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
|
||||
}
|
||||
filePickerLauncher.launch(intent)
|
||||
}
|
||||
}) {
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(6.dp),
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(16.dp)
|
||||
) {
|
||||
Icon(Icons.AutoMirrored.Outlined.NoteAdd, contentDescription = "")
|
||||
Text("Add local model")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (showImportingDialog) {
|
||||
curFileUri.value?.let { uri ->
|
||||
ModelImportDialog(uri = uri, onDone = { info ->
|
||||
showImportingDialog = false
|
||||
|
||||
if (info.error.isEmpty()) {
|
||||
// TODO: support other model types.
|
||||
modelManagerViewModel.addLocalLlmModel(
|
||||
task = task,
|
||||
fileName = info.fileName,
|
||||
fileSize = info.fileSize
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
|
@ -132,15 +309,11 @@ fun ClickableLink(
|
|||
) {
|
||||
val uriHandler = LocalUriHandler.current
|
||||
val annotatedText = AnnotatedString(
|
||||
text = linkText,
|
||||
spanStyles = listOf(
|
||||
text = linkText, spanStyles = listOf(
|
||||
AnnotatedString.Range(
|
||||
item = SpanStyle(
|
||||
color = MaterialTheme.customColors.linkColor,
|
||||
textDecoration = TextDecoration.Underline
|
||||
),
|
||||
start = 0,
|
||||
end = linkText.length
|
||||
color = MaterialTheme.customColors.linkColor, textDecoration = TextDecoration.Underline
|
||||
), start = 0, end = linkText.length
|
||||
)
|
||||
)
|
||||
)
|
||||
|
@ -163,6 +336,7 @@ fun ClickableLink(
|
|||
}
|
||||
}
|
||||
|
||||
@RequiresApi(Build.VERSION_CODES.O)
|
||||
@Preview(showBackground = true)
|
||||
@Composable
|
||||
fun ModelListPreview() {
|
||||
|
|
|
@ -24,16 +24,20 @@ import androidx.lifecycle.ViewModel
|
|||
import androidx.lifecycle.viewModelScope
|
||||
import com.google.aiedge.gallery.data.AGWorkInfo
|
||||
import com.google.aiedge.gallery.data.AccessTokenData
|
||||
import com.google.aiedge.gallery.data.Config
|
||||
import com.google.aiedge.gallery.data.DataStoreRepository
|
||||
import com.google.aiedge.gallery.data.DownloadRepository
|
||||
import com.google.aiedge.gallery.data.EMPTY_MODEL
|
||||
import com.google.aiedge.gallery.data.HfModel
|
||||
import com.google.aiedge.gallery.data.HfModelDetails
|
||||
import com.google.aiedge.gallery.data.HfModelSummary
|
||||
import com.google.aiedge.gallery.data.IMPORTS_DIR
|
||||
import com.google.aiedge.gallery.data.LocalModelInfo
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.ModelDownloadStatus
|
||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||
import com.google.aiedge.gallery.data.TASKS
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.data.TaskType
|
||||
import com.google.aiedge.gallery.data.getModelByName
|
||||
|
@ -41,6 +45,7 @@ import com.google.aiedge.gallery.ui.common.AuthConfig
|
|||
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper
|
||||
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper
|
||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
|
||||
import com.google.aiedge.gallery.ui.llmchat.createLLmChatConfig
|
||||
import com.google.aiedge.gallery.ui.textclassification.TextClassificationModelHelper
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.async
|
||||
|
@ -66,8 +71,12 @@ private const val TAG = "AGModelManagerViewModel"
|
|||
private const val HG_COMMUNITY = "jinjingforevercommunity"
|
||||
private const val TEXT_INPUT_HISTORY_MAX_SIZE = 50
|
||||
|
||||
enum class ModelInitializationStatus {
|
||||
NOT_INITIALIZED, INITIALIZING, INITIALIZED,
|
||||
data class ModelInitializationStatus(
|
||||
val status: ModelInitializationStatusType, var error: String = ""
|
||||
)
|
||||
|
||||
enum class ModelInitializationStatusType {
|
||||
NOT_INITIALIZED, INITIALIZING, INITIALIZED, ERROR
|
||||
}
|
||||
|
||||
enum class TokenStatus {
|
||||
|
@ -84,8 +93,7 @@ data class TokenStatusAndData(
|
|||
)
|
||||
|
||||
data class TokenRequestResult(
|
||||
val status: TokenRequestResultType,
|
||||
val errorMessage: String? = null
|
||||
val status: TokenRequestResultType, val errorMessage: String? = null
|
||||
)
|
||||
|
||||
data class ModelManagerUiState(
|
||||
|
@ -94,11 +102,6 @@ data class ModelManagerUiState(
|
|||
*/
|
||||
val tasks: List<Task>,
|
||||
|
||||
/**
|
||||
* A map that stores lists of models indexed by task name.
|
||||
*/
|
||||
val modelsByTaskName: Map<String, MutableList<Model>>,
|
||||
|
||||
/**
|
||||
* A map that tracks the download status of each model, indexed by model name.
|
||||
*/
|
||||
|
@ -191,14 +194,14 @@ open class ModelManagerViewModel(
|
|||
_uiState.update { _uiState.value.copy(selectedModel = model) }
|
||||
}
|
||||
|
||||
fun downloadModel(model: Model) {
|
||||
fun downloadModel(task: Task, model: Model) {
|
||||
// Update status.
|
||||
setDownloadStatus(
|
||||
curModel = model, status = ModelDownloadStatus(status = ModelDownloadStatusType.IN_PROGRESS)
|
||||
)
|
||||
|
||||
// Delete the model files first.
|
||||
deleteModel(model = model)
|
||||
deleteModel(task = task, model = model)
|
||||
|
||||
// Start to send download request.
|
||||
downloadRepository.downloadModel(
|
||||
|
@ -210,7 +213,7 @@ open class ModelManagerViewModel(
|
|||
downloadRepository.cancelDownloadModel(model)
|
||||
}
|
||||
|
||||
fun deleteModel(model: Model) {
|
||||
fun deleteModel(task: Task, model: Model) {
|
||||
deleteFileFromExternalFilesDir(model.downloadFileName)
|
||||
for (file in model.extraDataFiles) {
|
||||
deleteFileFromExternalFilesDir(file.downloadFileName)
|
||||
|
@ -223,6 +226,24 @@ open class ModelManagerViewModel(
|
|||
val curModelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap()
|
||||
curModelDownloadStatus[model.name] =
|
||||
ModelDownloadStatus(status = ModelDownloadStatusType.NOT_DOWNLOADED)
|
||||
|
||||
// Delete model from the list if model is imported as a local model.
|
||||
if (model.isLocalModel) {
|
||||
val index = task.models.indexOf(model)
|
||||
if (index >= 0) {
|
||||
task.models.removeAt(index)
|
||||
}
|
||||
task.updateTrigger.value = System.currentTimeMillis()
|
||||
curModelDownloadStatus.remove(model.name)
|
||||
|
||||
// Update preference.
|
||||
val localModels = dataStoreRepository.readLocalModels().toMutableList()
|
||||
val localModelIndex = localModels.indexOfFirst { it.fileName == model.name }
|
||||
if (localModelIndex >= 0) {
|
||||
localModels.removeAt(localModelIndex)
|
||||
}
|
||||
dataStoreRepository.saveLocalModels(localModels = localModels)
|
||||
}
|
||||
val newUiState = uiState.value.copy(modelDownloadStatus = curModelDownloadStatus)
|
||||
_uiState.update { newUiState }
|
||||
}
|
||||
|
@ -230,7 +251,7 @@ open class ModelManagerViewModel(
|
|||
fun initializeModel(context: Context, model: Model, force: Boolean = false) {
|
||||
viewModelScope.launch(Dispatchers.Default) {
|
||||
// Skip if initialized already.
|
||||
if (!force && uiState.value.modelInitializationStatus[model.name] == ModelInitializationStatus.INITIALIZED) {
|
||||
if (!force && uiState.value.modelInitializationStatus[model.name]?.status == ModelInitializationStatusType.INITIALIZED) {
|
||||
Log.d(TAG, "Model '${model.name}' has been initialized. Skipping.")
|
||||
return@launch
|
||||
}
|
||||
|
@ -252,20 +273,27 @@ open class ModelManagerViewModel(
|
|||
// been initialized or not. If so, skip.
|
||||
launch {
|
||||
delay(500)
|
||||
if (model.instance == null) {
|
||||
if (model.instance == null && model.initializing) {
|
||||
updateModelInitializationStatus(
|
||||
model = model, status = ModelInitializationStatus.INITIALIZING
|
||||
model = model, status = ModelInitializationStatusType.INITIALIZING
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
val onDone: () -> Unit = {
|
||||
val onDone: (error: String) -> Unit = { error ->
|
||||
model.initializing = false
|
||||
if (model.instance != null) {
|
||||
Log.d(TAG, "Model '${model.name}' initialized successfully")
|
||||
model.initializing = false
|
||||
updateModelInitializationStatus(
|
||||
model = model,
|
||||
status = ModelInitializationStatus.INITIALIZED,
|
||||
status = ModelInitializationStatusType.INITIALIZED,
|
||||
)
|
||||
} else if (error.isNotEmpty()) {
|
||||
Log.d(TAG, "Model '${model.name}' failed to initialize")
|
||||
updateModelInitializationStatus(
|
||||
model = model,
|
||||
status = ModelInitializationStatusType.ERROR,
|
||||
error = error,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -310,7 +338,7 @@ open class ModelManagerViewModel(
|
|||
model.instance = null
|
||||
model.initializing = false
|
||||
updateModelInitializationStatus(
|
||||
model = model, status = ModelInitializationStatus.NOT_INITIALIZED
|
||||
model = model, status = ModelInitializationStatusType.NOT_INITIALIZED
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -380,8 +408,7 @@ open class ModelManagerViewModel(
|
|||
val connection = url.openConnection() as HttpURLConnection
|
||||
if (accessToken != null) {
|
||||
connection.setRequestProperty(
|
||||
"Authorization",
|
||||
"Bearer $accessToken"
|
||||
"Authorization", "Bearer $accessToken"
|
||||
)
|
||||
}
|
||||
connection.connect()
|
||||
|
@ -390,6 +417,47 @@ open class ModelManagerViewModel(
|
|||
return connection.responseCode
|
||||
}
|
||||
|
||||
fun addLocalLlmModel(task: Task, fileName: String, fileSize: Long) {
|
||||
Log.d(TAG, "adding local model: $fileName, $fileSize")
|
||||
|
||||
// Create model.
|
||||
val configs: List<Config> = createLLmChatConfig(defaults = mapOf())
|
||||
val model = Model(
|
||||
name = fileName,
|
||||
url = "",
|
||||
configs = configs,
|
||||
sizeInBytes = fileSize,
|
||||
downloadFileName = "$IMPORTS_DIR/$fileName",
|
||||
isLocalModel = true,
|
||||
)
|
||||
model.preProcess(task = task)
|
||||
task.models.add(model)
|
||||
|
||||
// Add initial status and states.
|
||||
val modelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap()
|
||||
val modelInstances = uiState.value.modelInitializationStatus.toMutableMap()
|
||||
modelDownloadStatus[model.name] = ModelDownloadStatus(
|
||||
status = ModelDownloadStatusType.SUCCEEDED, receivedBytes = fileSize, totalBytes = fileSize
|
||||
)
|
||||
modelInstances[model.name] =
|
||||
ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED)
|
||||
|
||||
// Update ui state.
|
||||
_uiState.update {
|
||||
uiState.value.copy(
|
||||
tasks = uiState.value.tasks.toList(),
|
||||
modelDownloadStatus = modelDownloadStatus,
|
||||
modelInitializationStatus = modelInstances
|
||||
)
|
||||
}
|
||||
task.updateTrigger.value = System.currentTimeMillis()
|
||||
|
||||
// Add to preference storage.
|
||||
val localModels = dataStoreRepository.readLocalModels().toMutableList()
|
||||
localModels.add(LocalModelInfo(fileName = fileName, fileSize = fileSize))
|
||||
dataStoreRepository.saveLocalModels(localModels = localModels)
|
||||
}
|
||||
|
||||
fun getTokenStatusAndData(): TokenStatusAndData {
|
||||
// Try to load token data from DataStore.
|
||||
var tokenStatus = TokenStatus.NOT_STORED
|
||||
|
@ -436,8 +504,7 @@ open class ModelManagerViewModel(
|
|||
if (dataIntent == null) {
|
||||
onTokenRequested(
|
||||
TokenRequestResult(
|
||||
status = TokenRequestResultType.FAILED,
|
||||
errorMessage = "Empty auth result"
|
||||
status = TokenRequestResultType.FAILED, errorMessage = "Empty auth result"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
@ -481,8 +548,7 @@ open class ModelManagerViewModel(
|
|||
} else {
|
||||
onTokenRequested(
|
||||
TokenRequestResult(
|
||||
status = TokenRequestResultType.FAILED,
|
||||
errorMessage = errorMessage
|
||||
status = TokenRequestResultType.FAILED, errorMessage = errorMessage
|
||||
)
|
||||
)
|
||||
}
|
||||
|
@ -513,23 +579,49 @@ open class ModelManagerViewModel(
|
|||
}
|
||||
|
||||
private fun createUiState(): ModelManagerUiState {
|
||||
val modelsByTaskName: Map<String, MutableList<Model>> =
|
||||
TASKS.associate { task -> task.type.label to task.models }
|
||||
val modelDownloadStatus: MutableMap<String, ModelDownloadStatus> = mutableMapOf()
|
||||
val modelInstances: MutableMap<String, ModelInitializationStatus> = mutableMapOf()
|
||||
for ((_, models) in modelsByTaskName.entries) {
|
||||
for (model in models) {
|
||||
for (task in TASKS) {
|
||||
for (model in task.models) {
|
||||
modelDownloadStatus[model.name] = getModelDownloadStatus(model = model)
|
||||
modelInstances[model.name] = ModelInitializationStatus.NOT_INITIALIZED
|
||||
modelInstances[model.name] =
|
||||
ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED)
|
||||
}
|
||||
}
|
||||
|
||||
// Load local models.
|
||||
for (localModel in dataStoreRepository.readLocalModels()) {
|
||||
Log.d(TAG, "stored local model: $localModel")
|
||||
|
||||
// Create model.
|
||||
val configs: List<Config> = createLLmChatConfig(defaults = mapOf())
|
||||
val model = Model(
|
||||
name = localModel.fileName,
|
||||
url = "",
|
||||
configs = configs,
|
||||
sizeInBytes = localModel.fileSize,
|
||||
downloadFileName = "$IMPORTS_DIR/${localModel.fileName}",
|
||||
isLocalModel = true,
|
||||
)
|
||||
|
||||
// Add to task.
|
||||
val task = TASK_LLM_CHAT
|
||||
model.preProcess(task = task)
|
||||
task.models.add(model)
|
||||
|
||||
// Update status.
|
||||
modelDownloadStatus[model.name] = ModelDownloadStatus(
|
||||
status = ModelDownloadStatusType.SUCCEEDED,
|
||||
receivedBytes = localModel.fileSize,
|
||||
totalBytes = localModel.fileSize
|
||||
)
|
||||
}
|
||||
|
||||
val textInputHistory = dataStoreRepository.readTextInputHistory()
|
||||
Log.d(TAG, "text input history: $textInputHistory")
|
||||
|
||||
return ModelManagerUiState(
|
||||
tasks = TASKS,
|
||||
modelsByTaskName = modelsByTaskName,
|
||||
modelDownloadStatus = modelDownloadStatus,
|
||||
modelInitializationStatus = modelInstances,
|
||||
textInputHistory = textInputHistory,
|
||||
|
@ -610,7 +702,8 @@ open class ModelManagerViewModel(
|
|||
|
||||
// Add initial status and states.
|
||||
modelDownloadStatus[model.name] = getModelDownloadStatus(model = model)
|
||||
modelInstances[model.name] = ModelInitializationStatus.NOT_INITIALIZED
|
||||
modelInstances[model.name] =
|
||||
ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -677,9 +770,13 @@ open class ModelManagerViewModel(
|
|||
}
|
||||
}
|
||||
|
||||
private fun updateModelInitializationStatus(model: Model, status: ModelInitializationStatus) {
|
||||
private fun updateModelInitializationStatus(
|
||||
model: Model,
|
||||
status: ModelInitializationStatusType,
|
||||
error: String = ""
|
||||
) {
|
||||
val curModelInstance = uiState.value.modelInitializationStatus.toMutableMap()
|
||||
curModelInstance[model.name] = status
|
||||
curModelInstance[model.name] = ModelInitializationStatus(status = status, error = error)
|
||||
val newUiState = uiState.value.copy(modelInitializationStatus = curModelInstance)
|
||||
_uiState.update { newUiState }
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ package com.google.aiedge.gallery.ui.preview
|
|||
|
||||
import com.google.aiedge.gallery.data.AccessTokenData
|
||||
import com.google.aiedge.gallery.data.DataStoreRepository
|
||||
import com.google.aiedge.gallery.data.LocalModelInfo
|
||||
|
||||
class PreviewDataStoreRepository : DataStoreRepository {
|
||||
override fun saveTextInputHistory(history: List<String>) {
|
||||
|
@ -40,4 +41,11 @@ class PreviewDataStoreRepository : DataStoreRepository {
|
|||
override fun readAccessTokenData(): AccessTokenData? {
|
||||
return null
|
||||
}
|
||||
|
||||
override fun saveLocalModels(localModels: List<LocalModelInfo>) {
|
||||
}
|
||||
|
||||
override fun readLocalModels(): List<LocalModelInfo> {
|
||||
return listOf()
|
||||
}
|
||||
}
|
|
@ -17,7 +17,6 @@
|
|||
package com.google.aiedge.gallery.ui.preview
|
||||
|
||||
import android.content.Context
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.ModelDownloadStatus
|
||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerUiState
|
||||
|
@ -39,8 +38,6 @@ class PreviewModelManagerViewModel(context: Context) :
|
|||
}
|
||||
}
|
||||
|
||||
val modelsByTaskName: Map<String, MutableList<Model>> =
|
||||
ALL_PREVIEW_TASKS.associate { task -> task.type.label to task.models }
|
||||
val modelDownloadStatus = mapOf(
|
||||
MODEL_TEST1.name to ModelDownloadStatus(
|
||||
status = ModelDownloadStatusType.IN_PROGRESS,
|
||||
|
@ -61,7 +58,6 @@ class PreviewModelManagerViewModel(context: Context) :
|
|||
)
|
||||
val newUiState = ModelManagerUiState(
|
||||
tasks = ALL_PREVIEW_TASKS,
|
||||
modelsByTaskName = modelsByTaskName,
|
||||
modelDownloadStatus = modelDownloadStatus,
|
||||
modelInitializationStatus = mapOf(),
|
||||
selectedModel = MODEL_TEST2,
|
||||
|
|
|
@ -40,14 +40,14 @@ class TextClassificationInferenceResult(
|
|||
* Helper object for managing text classification models.
|
||||
*/
|
||||
object TextClassificationModelHelper {
|
||||
fun initialize(context: Context, model: Model, onDone: () -> Unit) {
|
||||
fun initialize(context: Context, model: Model, onDone: (String) -> Unit) {
|
||||
val modelByteBuffer = readFileToByteBuffer(File(model.getPath(context = context)))
|
||||
if (modelByteBuffer != null) {
|
||||
val options = TextClassifier.TextClassifierOptions.builder().setBaseOptions(
|
||||
BaseOptions.builder().setModelAssetBuffer(modelByteBuffer).build()
|
||||
).build()
|
||||
model.instance = TextClassifier.createFromOptions(context, options)
|
||||
onDone()
|
||||
onDone("")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue