mirror of
https://github.com/google-ai-edge/gallery.git
synced 2025-07-12 09:22:23 -04:00
Add initial support for importing local model.
This commit is contained in:
parent
b2f35a86e7
commit
29b614355e
19 changed files with 789 additions and 126 deletions
|
@ -47,6 +47,8 @@ interface DataStoreRepository {
|
||||||
fun readThemeOverride(): String
|
fun readThemeOverride(): String
|
||||||
fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long)
|
fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long)
|
||||||
fun readAccessTokenData(): AccessTokenData?
|
fun readAccessTokenData(): AccessTokenData?
|
||||||
|
fun saveLocalModels(localModels: List<LocalModelInfo>)
|
||||||
|
fun readLocalModels(): List<LocalModelInfo>
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -79,6 +81,9 @@ class DefaultDataStoreRepository(
|
||||||
val REFRESH_TOKEN_IV = stringPreferencesKey("refresh_token_iv")
|
val REFRESH_TOKEN_IV = stringPreferencesKey("refresh_token_iv")
|
||||||
|
|
||||||
val ACCESS_TOKEN_EXPIRES_AT = longPreferencesKey("access_token_expires_at")
|
val ACCESS_TOKEN_EXPIRES_AT = longPreferencesKey("access_token_expires_at")
|
||||||
|
|
||||||
|
// Data for all imported local models.
|
||||||
|
val LOCAL_MODELS = stringPreferencesKey("local_models")
|
||||||
}
|
}
|
||||||
|
|
||||||
private val keystoreAlias: String = "com_google_aiedge_gallery_access_token_key"
|
private val keystoreAlias: String = "com_google_aiedge_gallery_access_token_key"
|
||||||
|
@ -155,6 +160,26 @@ class DefaultDataStoreRepository(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun saveLocalModels(localModels: List<LocalModelInfo>) {
|
||||||
|
runBlocking {
|
||||||
|
dataStore.edit { preferences ->
|
||||||
|
val gson = Gson()
|
||||||
|
val jsonString = gson.toJson(localModels)
|
||||||
|
preferences[PreferencesKeys.LOCAL_MODELS] = jsonString
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun readLocalModels(): List<LocalModelInfo> {
|
||||||
|
return runBlocking {
|
||||||
|
val preferences = dataStore.data.first()
|
||||||
|
val infosStr = preferences[PreferencesKeys.LOCAL_MODELS] ?: "[]"
|
||||||
|
val gson = Gson()
|
||||||
|
val listType = object : TypeToken<List<LocalModelInfo>>() {}.type
|
||||||
|
gson.fromJson(infosStr, listType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private fun getTextInputHistory(preferences: Preferences): List<String> {
|
private fun getTextInputHistory(preferences: Preferences): List<String> {
|
||||||
val infosStr = preferences[PreferencesKeys.TEXT_INPUT_HISTORY] ?: "[]"
|
val infosStr = preferences[PreferencesKeys.TEXT_INPUT_HISTORY] ?: "[]"
|
||||||
val gson = Gson()
|
val gson = Gson()
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package com.google.aiedge.gallery.data
|
package com.google.aiedge.gallery.data
|
||||||
|
|
||||||
|
import com.google.aiedge.gallery.ui.common.ensureValidFileName
|
||||||
import com.google.aiedge.gallery.ui.llmchat.createLLmChatConfig
|
import com.google.aiedge.gallery.ui.llmchat.createLLmChatConfig
|
||||||
import kotlinx.serialization.KSerializer
|
import kotlinx.serialization.KSerializer
|
||||||
import kotlinx.serialization.Serializable
|
import kotlinx.serialization.Serializable
|
||||||
|
@ -103,7 +104,7 @@ data class HfModel(
|
||||||
} else {
|
} else {
|
||||||
listOf("")
|
listOf("")
|
||||||
}
|
}
|
||||||
val fileName = "${id}_${(parts.lastOrNull() ?: "")}".replace(Regex("[^a-zA-Z0-9._-]"), "_")
|
val fileName = ensureValidFileName("${id}_${(parts.lastOrNull() ?: "")}")
|
||||||
|
|
||||||
// Generate configs based on the given default values.
|
// Generate configs based on the given default values.
|
||||||
val configs: List<Config> = when (task) {
|
val configs: List<Config> = when (task) {
|
||||||
|
|
|
@ -32,6 +32,8 @@ enum class LlmBackend {
|
||||||
CPU, GPU
|
CPU, GPU
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const val IMPORTS_DIR = "__imports"
|
||||||
|
|
||||||
/** A model for a task */
|
/** A model for a task */
|
||||||
data class Model(
|
data class Model(
|
||||||
/** The Hugging Face model ID (if applicable). */
|
/** The Hugging Face model ID (if applicable). */
|
||||||
|
@ -85,6 +87,9 @@ data class Model(
|
||||||
/** The prompt templates for the model (only for LLM). */
|
/** The prompt templates for the model (only for LLM). */
|
||||||
val llmPromptTemplates: List<PromptTemplate> = listOf(),
|
val llmPromptTemplates: List<PromptTemplate> = listOf(),
|
||||||
|
|
||||||
|
/** Whether the model is imported as a local model. */
|
||||||
|
val isLocalModel: Boolean = false,
|
||||||
|
|
||||||
// The following fields are managed by the app. Don't need to set manually.
|
// The following fields are managed by the app. Don't need to set manually.
|
||||||
var taskType: TaskType? = null,
|
var taskType: TaskType? = null,
|
||||||
var instance: Any? = null,
|
var instance: Any? = null,
|
||||||
|
@ -104,10 +109,11 @@ data class Model(
|
||||||
}
|
}
|
||||||
|
|
||||||
fun getPath(context: Context, fileName: String = downloadFileName): String {
|
fun getPath(context: Context, fileName: String = downloadFileName): String {
|
||||||
|
val baseDir = "${context.getExternalFilesDir(null)}"
|
||||||
return if (this.isZip && this.unzipDir.isNotEmpty()) {
|
return if (this.isZip && this.unzipDir.isNotEmpty()) {
|
||||||
"${context.getExternalFilesDir(null)}/${this.unzipDir}"
|
"$baseDir/${this.unzipDir}"
|
||||||
} else {
|
} else {
|
||||||
"${context.getExternalFilesDir(null)}/${fileName}"
|
"$baseDir/${fileName}"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -140,6 +146,9 @@ data class Model(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Data for a imported local model. */
|
||||||
|
data class LocalModelInfo(val fileName: String, val fileSize: Long)
|
||||||
|
|
||||||
enum class ModelDownloadStatusType {
|
enum class ModelDownloadStatusType {
|
||||||
NOT_DOWNLOADED, PARTIALLY_DOWNLOADED, IN_PROGRESS, UNZIPPING, SUCCEEDED, FAILED,
|
NOT_DOWNLOADED, PARTIALLY_DOWNLOADED, IN_PROGRESS, UNZIPPING, SUCCEEDED, FAILED,
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,8 @@ package com.google.aiedge.gallery.data
|
||||||
import androidx.annotation.StringRes
|
import androidx.annotation.StringRes
|
||||||
import androidx.compose.material.icons.Icons
|
import androidx.compose.material.icons.Icons
|
||||||
import androidx.compose.material.icons.rounded.ImageSearch
|
import androidx.compose.material.icons.rounded.ImageSearch
|
||||||
|
import androidx.compose.runtime.MutableState
|
||||||
|
import androidx.compose.runtime.mutableStateOf
|
||||||
import androidx.compose.ui.graphics.vector.ImageVector
|
import androidx.compose.ui.graphics.vector.ImageVector
|
||||||
import com.google.aiedge.gallery.R
|
import com.google.aiedge.gallery.R
|
||||||
|
|
||||||
|
@ -63,7 +65,9 @@ data class Task(
|
||||||
@StringRes val textInputPlaceHolderRes: Int = R.string.chat_textinput_placeholder,
|
@StringRes val textInputPlaceHolderRes: Int = R.string.chat_textinput_placeholder,
|
||||||
|
|
||||||
// The following fields are managed by the app. Don't need to set manually.
|
// The following fields are managed by the app. Don't need to set manually.
|
||||||
var index: Int = -1
|
var index: Int = -1,
|
||||||
|
|
||||||
|
val updateTrigger: MutableState<Long> = mutableStateOf(0)
|
||||||
)
|
)
|
||||||
|
|
||||||
val TASK_TEXT_CLASSIFICATION = Task(
|
val TASK_TEXT_CLASSIFICATION = Task(
|
||||||
|
|
|
@ -46,6 +46,7 @@ import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.platform.LocalContext
|
import androidx.compose.ui.platform.LocalContext
|
||||||
import androidx.compose.ui.unit.dp
|
import androidx.compose.ui.unit.dp
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.TokenRequestResultType
|
import com.google.aiedge.gallery.ui.modelmanager.TokenRequestResultType
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.TokenStatus
|
import com.google.aiedge.gallery.ui.modelmanager.TokenStatus
|
||||||
|
@ -90,6 +91,7 @@ private const val TAG = "AGDownloadAndTryButton"
|
||||||
@OptIn(ExperimentalMaterial3Api::class)
|
@OptIn(ExperimentalMaterial3Api::class)
|
||||||
@Composable
|
@Composable
|
||||||
fun DownloadAndTryButton(
|
fun DownloadAndTryButton(
|
||||||
|
task: Task,
|
||||||
model: Model,
|
model: Model,
|
||||||
enabled: Boolean,
|
enabled: Boolean,
|
||||||
needToDownloadFirst: Boolean,
|
needToDownloadFirst: Boolean,
|
||||||
|
@ -106,17 +108,18 @@ fun DownloadAndTryButton(
|
||||||
val permissionLauncher = rememberLauncherForActivityResult(
|
val permissionLauncher = rememberLauncherForActivityResult(
|
||||||
ActivityResultContracts.RequestPermission()
|
ActivityResultContracts.RequestPermission()
|
||||||
) {
|
) {
|
||||||
modelManagerViewModel.downloadModel(model)
|
modelManagerViewModel.downloadModel(task = task, model = model)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Function to kick off download.
|
// Function to kick off download.
|
||||||
val startDownload: (accessToken: String?) -> Unit = { accessToken ->
|
val startDownload: (accessToken: String?) -> Unit = { accessToken ->
|
||||||
model.accessToken = accessToken
|
model.accessToken = accessToken
|
||||||
onClicked()
|
onClicked()
|
||||||
checkNotificationPermissonAndStartDownload(
|
checkNotificationPermissionAndStartDownload(
|
||||||
context = context,
|
context = context,
|
||||||
launcher = permissionLauncher,
|
launcher = permissionLauncher,
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
|
task = task,
|
||||||
model = model
|
model = model
|
||||||
)
|
)
|
||||||
checkingToken = false
|
checkingToken = false
|
||||||
|
|
|
@ -416,10 +416,11 @@ fun getTaskIconColor(index: Int): Color {
|
||||||
return MaterialTheme.customColors.taskIconColors[colorIndex]
|
return MaterialTheme.customColors.taskIconColors[colorIndex]
|
||||||
}
|
}
|
||||||
|
|
||||||
fun checkNotificationPermissonAndStartDownload(
|
fun checkNotificationPermissionAndStartDownload(
|
||||||
context: Context,
|
context: Context,
|
||||||
launcher: ManagedActivityResultLauncher<String, Boolean>,
|
launcher: ManagedActivityResultLauncher<String, Boolean>,
|
||||||
modelManagerViewModel: ModelManagerViewModel,
|
modelManagerViewModel: ModelManagerViewModel,
|
||||||
|
task: Task,
|
||||||
model: Model
|
model: Model
|
||||||
) {
|
) {
|
||||||
// Check permission
|
// Check permission
|
||||||
|
@ -428,7 +429,7 @@ fun checkNotificationPermissonAndStartDownload(
|
||||||
ContextCompat.checkSelfPermission(
|
ContextCompat.checkSelfPermission(
|
||||||
context, Manifest.permission.POST_NOTIFICATIONS
|
context, Manifest.permission.POST_NOTIFICATIONS
|
||||||
) -> {
|
) -> {
|
||||||
modelManagerViewModel.downloadModel(model)
|
modelManagerViewModel.downloadModel(task = task, model = model)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Otherwise, ask for permission
|
// Otherwise, ask for permission
|
||||||
|
@ -440,3 +441,14 @@ fun checkNotificationPermissonAndStartDownload(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun ensureValidFileName(fileName: String): String {
|
||||||
|
return fileName.replace(Regex("[^a-zA-Z0-9._-]"), "_")
|
||||||
|
}
|
||||||
|
|
||||||
|
fun cleanUpMediapipeTaskErrorMessage(message: String): String {
|
||||||
|
val index = message.indexOf("=== Source Location Trace")
|
||||||
|
if (index >= 0) {
|
||||||
|
return message.substring(0, index)
|
||||||
|
}
|
||||||
|
return message
|
||||||
|
}
|
||||||
|
|
|
@ -41,10 +41,13 @@ import androidx.compose.foundation.layout.wrapContentHeight
|
||||||
import androidx.compose.foundation.lazy.LazyColumn
|
import androidx.compose.foundation.lazy.LazyColumn
|
||||||
import androidx.compose.foundation.lazy.items
|
import androidx.compose.foundation.lazy.items
|
||||||
import androidx.compose.foundation.lazy.rememberLazyListState
|
import androidx.compose.foundation.lazy.rememberLazyListState
|
||||||
|
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||||
import androidx.compose.material.icons.Icons
|
import androidx.compose.material.icons.Icons
|
||||||
import androidx.compose.material.icons.outlined.Timer
|
import androidx.compose.material.icons.outlined.Timer
|
||||||
import androidx.compose.material.icons.rounded.ContentCopy
|
import androidx.compose.material.icons.rounded.ContentCopy
|
||||||
import androidx.compose.material.icons.rounded.Refresh
|
import androidx.compose.material.icons.rounded.Refresh
|
||||||
|
import androidx.compose.material3.Button
|
||||||
|
import androidx.compose.material3.Card
|
||||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||||
import androidx.compose.material3.Icon
|
import androidx.compose.material3.Icon
|
||||||
import androidx.compose.material3.MaterialTheme
|
import androidx.compose.material3.MaterialTheme
|
||||||
|
@ -83,11 +86,12 @@ import androidx.compose.ui.res.stringResource
|
||||||
import androidx.compose.ui.text.AnnotatedString
|
import androidx.compose.ui.text.AnnotatedString
|
||||||
import androidx.compose.ui.tooling.preview.Preview
|
import androidx.compose.ui.tooling.preview.Preview
|
||||||
import androidx.compose.ui.unit.dp
|
import androidx.compose.ui.unit.dp
|
||||||
|
import androidx.compose.ui.window.Dialog
|
||||||
import com.google.aiedge.gallery.R
|
import com.google.aiedge.gallery.R
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.data.TaskType
|
import com.google.aiedge.gallery.data.TaskType
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatus
|
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
import com.google.aiedge.gallery.ui.preview.PreviewChatModel
|
import com.google.aiedge.gallery.ui.preview.PreviewChatModel
|
||||||
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||||
|
@ -113,6 +117,7 @@ fun ChatPanel(
|
||||||
onSendMessage: (Model, ChatMessage) -> Unit,
|
onSendMessage: (Model, ChatMessage) -> Unit,
|
||||||
onRunAgainClicked: (Model, ChatMessage) -> Unit,
|
onRunAgainClicked: (Model, ChatMessage) -> Unit,
|
||||||
onBenchmarkClicked: (Model, ChatMessage, warmUpIterations: Int, benchmarkIterations: Int) -> Unit,
|
onBenchmarkClicked: (Model, ChatMessage, warmUpIterations: Int, benchmarkIterations: Int) -> Unit,
|
||||||
|
navigateUp: () -> Unit,
|
||||||
modifier: Modifier = Modifier,
|
modifier: Modifier = Modifier,
|
||||||
onStreamImageMessage: (Model, ChatMessageImage) -> Unit = { _, _ -> },
|
onStreamImageMessage: (Model, ChatMessageImage) -> Unit = { _, _ -> },
|
||||||
onStreamEnd: (Int) -> Unit = {},
|
onStreamEnd: (Int) -> Unit = {},
|
||||||
|
@ -140,6 +145,8 @@ fun ChatPanel(
|
||||||
var showMessageLongPressedSheet by remember { mutableStateOf(false) }
|
var showMessageLongPressedSheet by remember { mutableStateOf(false) }
|
||||||
val longPressedMessage: MutableState<ChatMessage?> = remember { mutableStateOf(null) }
|
val longPressedMessage: MutableState<ChatMessage?> = remember { mutableStateOf(null) }
|
||||||
|
|
||||||
|
var showErrorDialog by remember { mutableStateOf(false) }
|
||||||
|
|
||||||
// Keep track of the last message and last message content.
|
// Keep track of the last message and last message content.
|
||||||
val lastMessage: MutableState<ChatMessage?> = remember { mutableStateOf(null) }
|
val lastMessage: MutableState<ChatMessage?> = remember { mutableStateOf(null) }
|
||||||
val lastMessageContent: MutableState<String> = remember { mutableStateOf("") }
|
val lastMessageContent: MutableState<String> = remember { mutableStateOf("") }
|
||||||
|
@ -201,6 +208,10 @@ fun ChatPanel(
|
||||||
val modelInitializationStatus =
|
val modelInitializationStatus =
|
||||||
modelManagerUiState.modelInitializationStatus[selectedModel.name]
|
modelManagerUiState.modelInitializationStatus[selectedModel.name]
|
||||||
|
|
||||||
|
LaunchedEffect(modelInitializationStatus) {
|
||||||
|
showErrorDialog = modelInitializationStatus?.status == ModelInitializationStatusType.ERROR
|
||||||
|
}
|
||||||
|
|
||||||
Column(
|
Column(
|
||||||
modifier = modifier.imePadding()
|
modifier = modifier.imePadding()
|
||||||
) {
|
) {
|
||||||
|
@ -417,7 +428,7 @@ fun ChatPanel(
|
||||||
|
|
||||||
// Model initialization in-progress message.
|
// Model initialization in-progress message.
|
||||||
this@Column.AnimatedVisibility(
|
this@Column.AnimatedVisibility(
|
||||||
visible = modelInitializationStatus == ModelInitializationStatus.INITIALIZING,
|
visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
|
||||||
enter = scaleIn() + fadeIn(),
|
enter = scaleIn() + fadeIn(),
|
||||||
exit = scaleOut() + fadeOut(),
|
exit = scaleOut() + fadeOut(),
|
||||||
modifier = Modifier.offset(y = 12.dp)
|
modifier = Modifier.offset(y = 12.dp)
|
||||||
|
@ -479,6 +490,47 @@ fun ChatPanel(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Error dialog.
|
||||||
|
if (showErrorDialog) {
|
||||||
|
Dialog(
|
||||||
|
onDismissRequest = {
|
||||||
|
showErrorDialog = false
|
||||||
|
navigateUp()
|
||||||
|
},
|
||||||
|
) {
|
||||||
|
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
|
||||||
|
Column(
|
||||||
|
modifier = Modifier
|
||||||
|
.padding(20.dp),
|
||||||
|
verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||||
|
) {
|
||||||
|
// Title
|
||||||
|
Text(
|
||||||
|
"Error",
|
||||||
|
style = MaterialTheme.typography.titleLarge,
|
||||||
|
modifier = Modifier.padding(bottom = 8.dp)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Error
|
||||||
|
Text(
|
||||||
|
modelInitializationStatus?.error ?: "",
|
||||||
|
style = MaterialTheme.typography.bodySmall,
|
||||||
|
color = MaterialTheme.colorScheme.error,
|
||||||
|
)
|
||||||
|
|
||||||
|
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.End) {
|
||||||
|
Button(onClick = {
|
||||||
|
showErrorDialog = false
|
||||||
|
navigateUp()
|
||||||
|
}) {
|
||||||
|
Text("Close")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Benchmark config dialog.
|
// Benchmark config dialog.
|
||||||
if (showBenchmarkConfigsDialog) {
|
if (showBenchmarkConfigsDialog) {
|
||||||
BenchmarkConfigDialog(onDismissed = { showBenchmarkConfigsDialog = false },
|
BenchmarkConfigDialog(onDismissed = { showBenchmarkConfigsDialog = false },
|
||||||
|
@ -547,6 +599,7 @@ fun ChatPanelPreview() {
|
||||||
task = task,
|
task = task,
|
||||||
selectedModel = TASK_TEST1.models[1],
|
selectedModel = TASK_TEST1.models[1],
|
||||||
viewModel = PreviewChatModel(context = context),
|
viewModel = PreviewChatModel(context = context),
|
||||||
|
navigateUp = {},
|
||||||
onSendMessage = { _, _ -> },
|
onSendMessage = { _, _ -> },
|
||||||
onRunAgainClicked = { _, _ -> },
|
onRunAgainClicked = { _, _ -> },
|
||||||
onBenchmarkClicked = { _, _, _, _ -> },
|
onBenchmarkClicked = { _, _, _, _ -> },
|
||||||
|
|
|
@ -55,7 +55,7 @@ import com.google.aiedge.gallery.data.AppBarActionType
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.ui.common.checkNotificationPermissonAndStartDownload
|
import com.google.aiedge.gallery.ui.common.checkNotificationPermissionAndStartDownload
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
import com.google.aiedge.gallery.ui.preview.PreviewChatModel
|
import com.google.aiedge.gallery.ui.preview.PreviewChatModel
|
||||||
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||||
|
@ -104,7 +104,7 @@ fun ChatView(
|
||||||
val launcher = rememberLauncherForActivityResult(
|
val launcher = rememberLauncherForActivityResult(
|
||||||
ActivityResultContracts.RequestPermission()
|
ActivityResultContracts.RequestPermission()
|
||||||
) {
|
) {
|
||||||
modelManagerViewModel.downloadModel(selectedModel)
|
modelManagerViewModel.downloadModel(task = task, model = selectedModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
val handleNavigateUp = {
|
val handleNavigateUp = {
|
||||||
|
@ -245,10 +245,11 @@ fun ChatView(
|
||||||
exit = fadeOut()
|
exit = fadeOut()
|
||||||
) {
|
) {
|
||||||
ModelNotDownloaded(modifier = Modifier.weight(1f), onClicked = {
|
ModelNotDownloaded(modifier = Modifier.weight(1f), onClicked = {
|
||||||
checkNotificationPermissonAndStartDownload(
|
checkNotificationPermissionAndStartDownload(
|
||||||
context = context,
|
context = context,
|
||||||
launcher = launcher,
|
launcher = launcher,
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
|
task = task,
|
||||||
model = curSelectedModel
|
model = curSelectedModel
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
@ -261,6 +262,7 @@ fun ChatView(
|
||||||
task = task,
|
task = task,
|
||||||
selectedModel = curSelectedModel,
|
selectedModel = curSelectedModel,
|
||||||
viewModel = viewModel,
|
viewModel = viewModel,
|
||||||
|
navigateUp = navigateUp,
|
||||||
onSendMessage = onSendMessage,
|
onSendMessage = onSendMessage,
|
||||||
onRunAgainClicked = onRunAgainClicked,
|
onRunAgainClicked = onRunAgainClicked,
|
||||||
onBenchmarkClicked = onBenchmarkClicked,
|
onBenchmarkClicked = onBenchmarkClicked,
|
||||||
|
|
|
@ -35,6 +35,7 @@ import androidx.compose.foundation.layout.offset
|
||||||
import androidx.compose.foundation.layout.padding
|
import androidx.compose.foundation.layout.padding
|
||||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||||
import androidx.compose.material.icons.Icons
|
import androidx.compose.material.icons.Icons
|
||||||
|
import androidx.compose.material.icons.rounded.ChevronRight
|
||||||
import androidx.compose.material.icons.rounded.Settings
|
import androidx.compose.material.icons.rounded.Settings
|
||||||
import androidx.compose.material.icons.rounded.UnfoldLess
|
import androidx.compose.material.icons.rounded.UnfoldLess
|
||||||
import androidx.compose.material.icons.rounded.UnfoldMore
|
import androidx.compose.material.icons.rounded.UnfoldMore
|
||||||
|
@ -68,7 +69,7 @@ import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.ui.common.DownloadAndTryButton
|
import com.google.aiedge.gallery.ui.common.DownloadAndTryButton
|
||||||
import com.google.aiedge.gallery.ui.common.TaskIcon
|
import com.google.aiedge.gallery.ui.common.TaskIcon
|
||||||
import com.google.aiedge.gallery.ui.common.chat.MarkdownText
|
import com.google.aiedge.gallery.ui.common.chat.MarkdownText
|
||||||
import com.google.aiedge.gallery.ui.common.checkNotificationPermissonAndStartDownload
|
import com.google.aiedge.gallery.ui.common.checkNotificationPermissionAndStartDownload
|
||||||
import com.google.aiedge.gallery.ui.common.getTaskBgColor
|
import com.google.aiedge.gallery.ui.common.getTaskBgColor
|
||||||
import com.google.aiedge.gallery.ui.common.getTaskIconColor
|
import com.google.aiedge.gallery.ui.common.getTaskIconColor
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
|
@ -113,7 +114,7 @@ fun ModelItem(
|
||||||
val launcher = rememberLauncherForActivityResult(
|
val launcher = rememberLauncherForActivityResult(
|
||||||
ActivityResultContracts.RequestPermission()
|
ActivityResultContracts.RequestPermission()
|
||||||
) {
|
) {
|
||||||
modelManagerViewModel.downloadModel(model)
|
modelManagerViewModel.downloadModel(task = task, model = model)
|
||||||
}
|
}
|
||||||
|
|
||||||
var isExpanded by remember { mutableStateOf(false) }
|
var isExpanded by remember { mutableStateOf(false) }
|
||||||
|
@ -156,10 +157,11 @@ fun ModelItem(
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
downloadStatus = downloadStatus,
|
downloadStatus = downloadStatus,
|
||||||
onDownloadClicked = { model ->
|
onDownloadClicked = { model ->
|
||||||
checkNotificationPermissonAndStartDownload(
|
checkNotificationPermissionAndStartDownload(
|
||||||
context = context,
|
context = context,
|
||||||
launcher = launcher,
|
launcher = launcher,
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
|
task = task,
|
||||||
model = model
|
model = model
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
|
@ -186,7 +188,9 @@ fun ModelItem(
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Icon(
|
Icon(
|
||||||
if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore,
|
// For local model, show ">" directly indicating users can just tap the model item to
|
||||||
|
// go into it without needing to expand it first.
|
||||||
|
if (model.isLocalModel) Icons.Rounded.ChevronRight else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore,
|
||||||
contentDescription = "",
|
contentDescription = "",
|
||||||
tint = getTaskIconColor(task),
|
tint = getTaskIconColor(task),
|
||||||
)
|
)
|
||||||
|
@ -237,6 +241,7 @@ fun ModelItem(
|
||||||
val needToDownloadFirst =
|
val needToDownloadFirst =
|
||||||
downloadStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED || downloadStatus?.status == ModelDownloadStatusType.FAILED
|
downloadStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED || downloadStatus?.status == ModelDownloadStatusType.FAILED
|
||||||
DownloadAndTryButton(
|
DownloadAndTryButton(
|
||||||
|
task = task,
|
||||||
model = model,
|
model = model,
|
||||||
enabled = isExpanded,
|
enabled = isExpanded,
|
||||||
needToDownloadFirst = needToDownloadFirst,
|
needToDownloadFirst = needToDownloadFirst,
|
||||||
|
@ -266,7 +271,13 @@ fun ModelItem(
|
||||||
)
|
)
|
||||||
boxModifier = if (canExpand) {
|
boxModifier = if (canExpand) {
|
||||||
boxModifier.clickable(
|
boxModifier.clickable(
|
||||||
onClick = { isExpanded = !isExpanded },
|
onClick = {
|
||||||
|
if (!model.isLocalModel) {
|
||||||
|
isExpanded = !isExpanded
|
||||||
|
} else {
|
||||||
|
onModelClicked(model)
|
||||||
|
}
|
||||||
|
},
|
||||||
interactionSource = remember { MutableInteractionSource() },
|
interactionSource = remember { MutableInteractionSource() },
|
||||||
indication = ripple(
|
indication = ripple(
|
||||||
bounded = true,
|
bounded = true,
|
||||||
|
|
|
@ -124,7 +124,7 @@ fun ModelItemActionButton(
|
||||||
|
|
||||||
if (showConfirmDeleteDialog) {
|
if (showConfirmDeleteDialog) {
|
||||||
ConfirmDeleteModelDialog(model = model, onConfirm = {
|
ConfirmDeleteModelDialog(model = model, onConfirm = {
|
||||||
modelManagerViewModel.deleteModel(model)
|
modelManagerViewModel.deleteModel(task = task, model = model)
|
||||||
showConfirmDeleteDialog = false
|
showConfirmDeleteDialog = false
|
||||||
}, onDismiss = {
|
}, onDismiss = {
|
||||||
showConfirmDeleteDialog = false
|
showConfirmDeleteDialog = false
|
||||||
|
|
|
@ -48,7 +48,7 @@ class ImageClassificationInferenceResult(
|
||||||
//TODO: handle error.
|
//TODO: handle error.
|
||||||
|
|
||||||
object ImageClassificationModelHelper {
|
object ImageClassificationModelHelper {
|
||||||
fun initialize(context: Context, model: Model, onDone: () -> Unit) {
|
fun initialize(context: Context, model: Model, onDone: (String) -> Unit) {
|
||||||
val useGpu = model.getBooleanConfigValue(key = ConfigKey.USE_GPU)
|
val useGpu = model.getBooleanConfigValue(key = ConfigKey.USE_GPU)
|
||||||
TfLiteGpu.isGpuDelegateAvailable(context).continueWith { gpuTask ->
|
TfLiteGpu.isGpuDelegateAvailable(context).continueWith { gpuTask ->
|
||||||
val optionsBuilder = TfLiteInitializationOptions.builder()
|
val optionsBuilder = TfLiteInitializationOptions.builder()
|
||||||
|
@ -69,7 +69,7 @@ object ImageClassificationModelHelper {
|
||||||
File(model.getPath(context = context)), interpreterOption
|
File(model.getPath(context = context)), interpreterOption
|
||||||
)
|
)
|
||||||
model.instance = interpreter
|
model.instance = interpreter
|
||||||
onDone()
|
onDone("")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,7 @@ import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator
|
||||||
import com.google.aiedge.gallery.data.ConfigKey
|
import com.google.aiedge.gallery.data.ConfigKey
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.ui.common.LatencyProvider
|
import com.google.aiedge.gallery.ui.common.LatencyProvider
|
||||||
|
import com.google.aiedge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage
|
||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
|
|
||||||
private const val TAG = "AGImageGenerationModelHelper"
|
private const val TAG = "AGImageGenerationModelHelper"
|
||||||
|
@ -33,12 +34,17 @@ class ImageGenerationInferenceResult(
|
||||||
) : LatencyProvider
|
) : LatencyProvider
|
||||||
|
|
||||||
object ImageGenerationModelHelper {
|
object ImageGenerationModelHelper {
|
||||||
fun initialize(context: Context, model: Model, onDone: () -> Unit) {
|
fun initialize(context: Context, model: Model, onDone: (String) -> Unit) {
|
||||||
val options = ImageGenerator.ImageGeneratorOptions.builder()
|
try {
|
||||||
.setImageGeneratorModelDirectory(model.getPath(context = context))
|
val options = ImageGenerator.ImageGeneratorOptions.builder()
|
||||||
.build()
|
.setImageGeneratorModelDirectory(model.getPath(context = context))
|
||||||
model.instance = ImageGenerator.createFromOptions(context, options)
|
.build()
|
||||||
onDone()
|
model.instance = ImageGenerator.createFromOptions(context, options)
|
||||||
|
} catch (e: Exception) {
|
||||||
|
onDone(cleanUpMediapipeTaskErrorMessage(e.message ?: "Unknown error"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
onDone("")
|
||||||
}
|
}
|
||||||
|
|
||||||
fun cleanUp(model: Model) {
|
fun cleanUp(model: Model) {
|
||||||
|
|
|
@ -21,6 +21,7 @@ import android.util.Log
|
||||||
import com.google.aiedge.gallery.data.ConfigKey
|
import com.google.aiedge.gallery.data.ConfigKey
|
||||||
import com.google.aiedge.gallery.data.LlmBackend
|
import com.google.aiedge.gallery.data.LlmBackend
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
|
import com.google.aiedge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage
|
||||||
import com.google.mediapipe.tasks.genai.llminference.LlmInference
|
import com.google.mediapipe.tasks.genai.llminference.LlmInference
|
||||||
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
|
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
|
||||||
|
|
||||||
|
@ -40,7 +41,7 @@ object LlmChatModelHelper {
|
||||||
private val cleanUpListeners: MutableMap<String, CleanUpListener> = mutableMapOf()
|
private val cleanUpListeners: MutableMap<String, CleanUpListener> = mutableMapOf()
|
||||||
|
|
||||||
fun initialize(
|
fun initialize(
|
||||||
context: Context, model: Model, onDone: () -> Unit
|
context: Context, model: Model, onDone: (String) -> Unit
|
||||||
) {
|
) {
|
||||||
val maxTokens =
|
val maxTokens =
|
||||||
model.getIntConfigValue(key = ConfigKey.MAX_TOKENS, defaultValue = DEFAULT_MAX_TOKEN)
|
model.getIntConfigValue(key = ConfigKey.MAX_TOKENS, defaultValue = DEFAULT_MAX_TOKEN)
|
||||||
|
@ -68,9 +69,10 @@ object LlmChatModelHelper {
|
||||||
)
|
)
|
||||||
model.instance = LlmModelInstance(engine = llmInference, session = session)
|
model.instance = LlmModelInstance(engine = llmInference, session = session)
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
e.printStackTrace()
|
onDone(cleanUpMediapipeTaskErrorMessage(e.message ?: "Unknown error"))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
onDone()
|
onDone("")
|
||||||
}
|
}
|
||||||
|
|
||||||
fun cleanUp(model: Model) {
|
fun cleanUp(model: Model) {
|
||||||
|
|
|
@ -0,0 +1,260 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.google.aiedge.gallery.ui.modelmanager
|
||||||
|
|
||||||
|
import android.content.Context
|
||||||
|
import android.net.Uri
|
||||||
|
import android.provider.OpenableColumns
|
||||||
|
import android.util.Log
|
||||||
|
import androidx.compose.animation.core.Animatable
|
||||||
|
import androidx.compose.animation.core.tween
|
||||||
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
|
import androidx.compose.foundation.layout.Column
|
||||||
|
import androidx.compose.foundation.layout.Row
|
||||||
|
import androidx.compose.foundation.layout.fillMaxWidth
|
||||||
|
import androidx.compose.foundation.layout.padding
|
||||||
|
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||||
|
import androidx.compose.material.icons.Icons
|
||||||
|
import androidx.compose.material.icons.rounded.Error
|
||||||
|
import androidx.compose.material3.Button
|
||||||
|
import androidx.compose.material3.Card
|
||||||
|
import androidx.compose.material3.Icon
|
||||||
|
import androidx.compose.material3.LinearProgressIndicator
|
||||||
|
import androidx.compose.material3.MaterialTheme
|
||||||
|
import androidx.compose.material3.Text
|
||||||
|
import androidx.compose.runtime.Composable
|
||||||
|
import androidx.compose.runtime.LaunchedEffect
|
||||||
|
import androidx.compose.runtime.getValue
|
||||||
|
import androidx.compose.runtime.mutableFloatStateOf
|
||||||
|
import androidx.compose.runtime.mutableLongStateOf
|
||||||
|
import androidx.compose.runtime.mutableStateOf
|
||||||
|
import androidx.compose.runtime.remember
|
||||||
|
import androidx.compose.runtime.rememberCoroutineScope
|
||||||
|
import androidx.compose.runtime.setValue
|
||||||
|
import androidx.compose.ui.Alignment
|
||||||
|
import androidx.compose.ui.Modifier
|
||||||
|
import androidx.compose.ui.platform.LocalContext
|
||||||
|
import androidx.compose.ui.unit.dp
|
||||||
|
import androidx.compose.ui.window.Dialog
|
||||||
|
import androidx.compose.ui.window.DialogProperties
|
||||||
|
import com.google.aiedge.gallery.data.IMPORTS_DIR
|
||||||
|
import com.google.aiedge.gallery.ui.common.ensureValidFileName
|
||||||
|
import com.google.aiedge.gallery.ui.common.humanReadableSize
|
||||||
|
import kotlinx.coroutines.CoroutineScope
|
||||||
|
import kotlinx.coroutines.Dispatchers
|
||||||
|
import kotlinx.coroutines.launch
|
||||||
|
import java.io.File
|
||||||
|
import java.io.FileOutputStream
|
||||||
|
import java.net.URLDecoder
|
||||||
|
import java.nio.charset.StandardCharsets
|
||||||
|
|
||||||
|
private const val TAG = "AGModelImportDialog"
|
||||||
|
|
||||||
|
data class ModelImportInfo(val fileName: String, val fileSize: Long, val error: String = "")
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
fun ModelImportDialog(
|
||||||
|
uri: Uri, onDone: (ModelImportInfo) -> Unit
|
||||||
|
) {
|
||||||
|
val context = LocalContext.current
|
||||||
|
val coroutineScope = rememberCoroutineScope()
|
||||||
|
|
||||||
|
var fileName by remember { mutableStateOf("") }
|
||||||
|
var fileSize by remember { mutableLongStateOf(0L) }
|
||||||
|
var error by remember { mutableStateOf("") }
|
||||||
|
var progress by remember { mutableFloatStateOf(0f) }
|
||||||
|
|
||||||
|
LaunchedEffect(Unit) {
|
||||||
|
error = ""
|
||||||
|
|
||||||
|
// Get basic info.
|
||||||
|
val info = getFileSizeAndDisplayNameFromUri(context = context, uri = uri)
|
||||||
|
fileSize = info.first
|
||||||
|
fileName = ensureValidFileName(info.second)
|
||||||
|
|
||||||
|
// Import.
|
||||||
|
importModel(
|
||||||
|
context = context,
|
||||||
|
coroutineScope = coroutineScope,
|
||||||
|
fileName = fileName,
|
||||||
|
fileSize = fileSize,
|
||||||
|
uri = uri,
|
||||||
|
onDone = {
|
||||||
|
onDone(ModelImportInfo(fileName = fileName, fileSize = fileSize, error = error))
|
||||||
|
},
|
||||||
|
onProgress = {
|
||||||
|
progress = it
|
||||||
|
},
|
||||||
|
onError = {
|
||||||
|
error = it
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
Dialog(
|
||||||
|
properties = DialogProperties(dismissOnBackPress = false, dismissOnClickOutside = false),
|
||||||
|
onDismissRequest = {},
|
||||||
|
) {
|
||||||
|
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
|
||||||
|
Column(
|
||||||
|
modifier = Modifier
|
||||||
|
.padding(20.dp),
|
||||||
|
verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||||
|
) {
|
||||||
|
// Title.
|
||||||
|
Text(
|
||||||
|
"Importing...",
|
||||||
|
style = MaterialTheme.typography.titleLarge,
|
||||||
|
modifier = Modifier.padding(bottom = 8.dp)
|
||||||
|
)
|
||||||
|
|
||||||
|
// No error.
|
||||||
|
if (error.isEmpty()) {
|
||||||
|
// Progress bar.
|
||||||
|
Column(verticalArrangement = Arrangement.spacedBy(4.dp)) {
|
||||||
|
Text(
|
||||||
|
"$fileName (${fileSize.humanReadableSize()})",
|
||||||
|
style = MaterialTheme.typography.labelSmall,
|
||||||
|
)
|
||||||
|
val animatedProgress = remember { Animatable(0f) }
|
||||||
|
LinearProgressIndicator(
|
||||||
|
progress = { animatedProgress.value },
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.padding(bottom = 8.dp),
|
||||||
|
)
|
||||||
|
LaunchedEffect(progress) {
|
||||||
|
animatedProgress.animateTo(progress, animationSpec = tween(150))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Has error.
|
||||||
|
else {
|
||||||
|
Row(
|
||||||
|
verticalAlignment = Alignment.Top,
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(6.dp)
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
Icons.Rounded.Error,
|
||||||
|
contentDescription = "",
|
||||||
|
tint = MaterialTheme.colorScheme.error
|
||||||
|
)
|
||||||
|
Text(
|
||||||
|
error,
|
||||||
|
style = MaterialTheme.typography.labelSmall,
|
||||||
|
color = MaterialTheme.colorScheme.error,
|
||||||
|
modifier = Modifier.padding(top = 4.dp)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.End) {
|
||||||
|
Button(onClick = {
|
||||||
|
onDone(ModelImportInfo(fileName = "", fileSize = 0L, error = error))
|
||||||
|
}) {
|
||||||
|
Text("Close")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun importModel(
|
||||||
|
context: Context,
|
||||||
|
coroutineScope: CoroutineScope,
|
||||||
|
fileName: String,
|
||||||
|
fileSize: Long,
|
||||||
|
uri: Uri,
|
||||||
|
onDone: () -> Unit,
|
||||||
|
onProgress: (Float) -> Unit,
|
||||||
|
onError: (String) -> Unit,
|
||||||
|
) {
|
||||||
|
// TODO: handle error.
|
||||||
|
coroutineScope.launch(Dispatchers.IO) {
|
||||||
|
// Get the last component of the uri path as the imported file name.
|
||||||
|
val decodedUri = URLDecoder.decode(uri.toString(), StandardCharsets.UTF_8.name())
|
||||||
|
Log.d(TAG, "importing model from $decodedUri. File name: $fileName. File size: $fileSize")
|
||||||
|
|
||||||
|
// Create <app_external_dir>/imports if not exist.
|
||||||
|
val importsDir = File(context.getExternalFilesDir(null), IMPORTS_DIR)
|
||||||
|
if (!importsDir.exists()) {
|
||||||
|
importsDir.mkdirs()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Import by copying the file over.
|
||||||
|
val outputFile = File(context.getExternalFilesDir(null), "$IMPORTS_DIR/$fileName")
|
||||||
|
val outputStream = FileOutputStream(outputFile)
|
||||||
|
val buffer = ByteArray(DEFAULT_BUFFER_SIZE)
|
||||||
|
var bytesRead: Int
|
||||||
|
var lastSetProgressTs: Long = 0
|
||||||
|
var importedBytes = 0L
|
||||||
|
val inputStream = context.contentResolver.openInputStream(uri)
|
||||||
|
try {
|
||||||
|
if (inputStream != null) {
|
||||||
|
while (inputStream.read(buffer).also { bytesRead = it } != -1) {
|
||||||
|
outputStream.write(buffer, 0, bytesRead)
|
||||||
|
importedBytes += bytesRead
|
||||||
|
|
||||||
|
// Report progress every 200 ms.
|
||||||
|
val curTs = System.currentTimeMillis()
|
||||||
|
if (curTs - lastSetProgressTs > 200) {
|
||||||
|
Log.d(TAG, "importing progress: $importedBytes, $fileSize")
|
||||||
|
lastSetProgressTs = curTs
|
||||||
|
if (fileSize != 0L) {
|
||||||
|
onProgress(importedBytes.toFloat() / fileSize.toFloat())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (e: Exception) {
|
||||||
|
e.printStackTrace()
|
||||||
|
onError(e.message ?: "Failed to import")
|
||||||
|
return@launch
|
||||||
|
} finally {
|
||||||
|
inputStream?.close()
|
||||||
|
outputStream.close()
|
||||||
|
}
|
||||||
|
Log.d(TAG, "import done")
|
||||||
|
onProgress(1f)
|
||||||
|
onDone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun getFileSizeAndDisplayNameFromUri(context: Context, uri: Uri): Pair<Long, String> {
|
||||||
|
val contentResolver = context.contentResolver
|
||||||
|
var fileSize = 0L
|
||||||
|
var displayName = ""
|
||||||
|
|
||||||
|
try {
|
||||||
|
contentResolver.query(
|
||||||
|
uri, arrayOf(OpenableColumns.SIZE, OpenableColumns.DISPLAY_NAME), null, null, null
|
||||||
|
)?.use { cursor ->
|
||||||
|
if (cursor.moveToFirst()) {
|
||||||
|
val sizeIndex = cursor.getColumnIndexOrThrow(OpenableColumns.SIZE)
|
||||||
|
fileSize = cursor.getLong(sizeIndex)
|
||||||
|
|
||||||
|
val nameIndex = cursor.getColumnIndexOrThrow(OpenableColumns.DISPLAY_NAME)
|
||||||
|
displayName = cursor.getString(nameIndex)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (e: Exception) {
|
||||||
|
e.printStackTrace()
|
||||||
|
return Pair(0L, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
return Pair(fileSize, displayName)
|
||||||
|
}
|
|
@ -16,24 +16,46 @@
|
||||||
|
|
||||||
package com.google.aiedge.gallery.ui.modelmanager
|
package com.google.aiedge.gallery.ui.modelmanager
|
||||||
|
|
||||||
|
import android.content.Intent
|
||||||
|
import android.net.Uri
|
||||||
|
import android.os.Build
|
||||||
|
import android.util.Log
|
||||||
|
import androidx.activity.compose.rememberLauncherForActivityResult
|
||||||
|
import androidx.activity.result.ActivityResultLauncher
|
||||||
|
import androidx.activity.result.contract.ActivityResultContracts
|
||||||
|
import androidx.annotation.RequiresApi
|
||||||
import androidx.compose.foundation.clickable
|
import androidx.compose.foundation.clickable
|
||||||
import androidx.compose.foundation.layout.Arrangement
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
import androidx.compose.foundation.layout.Box
|
import androidx.compose.foundation.layout.Box
|
||||||
import androidx.compose.foundation.layout.Column
|
import androidx.compose.foundation.layout.Column
|
||||||
import androidx.compose.foundation.layout.PaddingValues
|
import androidx.compose.foundation.layout.PaddingValues
|
||||||
import androidx.compose.foundation.layout.Row
|
import androidx.compose.foundation.layout.Row
|
||||||
|
import androidx.compose.foundation.layout.Spacer
|
||||||
import androidx.compose.foundation.layout.fillMaxWidth
|
import androidx.compose.foundation.layout.fillMaxWidth
|
||||||
|
import androidx.compose.foundation.layout.height
|
||||||
import androidx.compose.foundation.layout.padding
|
import androidx.compose.foundation.layout.padding
|
||||||
import androidx.compose.foundation.layout.size
|
import androidx.compose.foundation.layout.size
|
||||||
import androidx.compose.foundation.lazy.LazyColumn
|
import androidx.compose.foundation.lazy.LazyColumn
|
||||||
import androidx.compose.foundation.lazy.items
|
import androidx.compose.foundation.lazy.items
|
||||||
import androidx.compose.material.icons.Icons
|
import androidx.compose.material.icons.Icons
|
||||||
|
import androidx.compose.material.icons.automirrored.outlined.NoteAdd
|
||||||
|
import androidx.compose.material.icons.filled.Add
|
||||||
import androidx.compose.material.icons.outlined.Code
|
import androidx.compose.material.icons.outlined.Code
|
||||||
import androidx.compose.material.icons.outlined.Description
|
import androidx.compose.material.icons.outlined.Description
|
||||||
|
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||||
import androidx.compose.material3.Icon
|
import androidx.compose.material3.Icon
|
||||||
import androidx.compose.material3.MaterialTheme
|
import androidx.compose.material3.MaterialTheme
|
||||||
|
import androidx.compose.material3.ModalBottomSheet
|
||||||
|
import androidx.compose.material3.SmallFloatingActionButton
|
||||||
import androidx.compose.material3.Text
|
import androidx.compose.material3.Text
|
||||||
|
import androidx.compose.material3.rememberModalBottomSheetState
|
||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
|
import androidx.compose.runtime.derivedStateOf
|
||||||
|
import androidx.compose.runtime.getValue
|
||||||
|
import androidx.compose.runtime.mutableStateOf
|
||||||
|
import androidx.compose.runtime.remember
|
||||||
|
import androidx.compose.runtime.rememberCoroutineScope
|
||||||
|
import androidx.compose.runtime.setValue
|
||||||
import androidx.compose.ui.Alignment
|
import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.graphics.vector.ImageVector
|
import androidx.compose.ui.graphics.vector.ImageVector
|
||||||
|
@ -53,8 +75,14 @@ import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||||
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
|
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
|
||||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||||
import com.google.aiedge.gallery.ui.theme.customColors
|
import com.google.aiedge.gallery.ui.theme.customColors
|
||||||
|
import kotlinx.coroutines.delay
|
||||||
|
import kotlinx.coroutines.launch
|
||||||
|
|
||||||
|
private const val TAG = "AGModelList"
|
||||||
|
|
||||||
/** The list of models in the model manager. */
|
/** The list of models in the model manager. */
|
||||||
|
@RequiresApi(Build.VERSION_CODES.O)
|
||||||
|
@OptIn(ExperimentalMaterial3Api::class)
|
||||||
@Composable
|
@Composable
|
||||||
fun ModelList(
|
fun ModelList(
|
||||||
task: Task,
|
task: Task,
|
||||||
|
@ -63,65 +91,214 @@ fun ModelList(
|
||||||
onModelClicked: (Model) -> Unit,
|
onModelClicked: (Model) -> Unit,
|
||||||
modifier: Modifier = Modifier,
|
modifier: Modifier = Modifier,
|
||||||
) {
|
) {
|
||||||
LazyColumn(
|
var showAddModelSheet by remember { mutableStateOf(false) }
|
||||||
modifier = modifier.padding(top = 8.dp),
|
var showImportingDialog by remember { mutableStateOf(false) }
|
||||||
contentPadding = contentPadding,
|
val curFileUri = remember { mutableStateOf<Uri?>(null) }
|
||||||
verticalArrangement = Arrangement.spacedBy(8.dp),
|
val sheetState = rememberModalBottomSheetState()
|
||||||
) {
|
val coroutineScope = rememberCoroutineScope()
|
||||||
// Headline.
|
|
||||||
item(key = "headline") {
|
|
||||||
Text(
|
|
||||||
task.description,
|
|
||||||
textAlign = TextAlign.Center,
|
|
||||||
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
|
|
||||||
modifier = Modifier
|
|
||||||
.fillMaxWidth()
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// URLs.
|
// This is just to update "models" list when task.updateTrigger is updated so that the UI can
|
||||||
item(key = "urls") {
|
// be properly updated.
|
||||||
Row(
|
val models by remember {
|
||||||
horizontalArrangement = Arrangement.Center,
|
derivedStateOf {
|
||||||
modifier = Modifier
|
val trigger = task.updateTrigger.value
|
||||||
.fillMaxWidth()
|
if (trigger >= 0) {
|
||||||
.padding(top = 12.dp, bottom = 16.dp),
|
task.models.toList().filter { !it.isLocalModel }
|
||||||
) {
|
} else {
|
||||||
Column(
|
listOf()
|
||||||
horizontalAlignment = Alignment.Start,
|
}
|
||||||
verticalArrangement = Arrangement.spacedBy(4.dp),
|
}
|
||||||
|
}
|
||||||
|
val localModels by remember {
|
||||||
|
derivedStateOf {
|
||||||
|
val trigger = task.updateTrigger.value
|
||||||
|
if (trigger >= 0) {
|
||||||
|
task.models.toList().filter { it.isLocalModel }
|
||||||
|
} else {
|
||||||
|
listOf()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val filePickerLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult(
|
||||||
|
contract = ActivityResultContracts.StartActivityForResult()
|
||||||
|
) { result ->
|
||||||
|
if (result.resultCode == android.app.Activity.RESULT_OK) {
|
||||||
|
result.data?.data?.let { uri ->
|
||||||
|
curFileUri.value = uri
|
||||||
|
showImportingDialog = true
|
||||||
|
} ?: run {
|
||||||
|
Log.d(TAG, "No file selected or URI is null.")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Log.d(TAG, "File picking cancelled.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Box(contentAlignment = Alignment.BottomEnd) {
|
||||||
|
LazyColumn(
|
||||||
|
modifier = modifier.padding(top = 8.dp),
|
||||||
|
contentPadding = contentPadding,
|
||||||
|
verticalArrangement = Arrangement.spacedBy(8.dp),
|
||||||
|
) {
|
||||||
|
// Headline.
|
||||||
|
item(key = "headline") {
|
||||||
|
Text(
|
||||||
|
task.description,
|
||||||
|
textAlign = TextAlign.Center,
|
||||||
|
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
|
||||||
|
modifier = Modifier.fillMaxWidth()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// URLs.
|
||||||
|
item(key = "urls") {
|
||||||
|
Row(
|
||||||
|
horizontalArrangement = Arrangement.Center,
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.padding(top = 12.dp, bottom = 16.dp),
|
||||||
) {
|
) {
|
||||||
if (task.docUrl.isNotEmpty()) {
|
Column(
|
||||||
ClickableLink(
|
horizontalAlignment = Alignment.Start,
|
||||||
url = task.docUrl,
|
verticalArrangement = Arrangement.spacedBy(4.dp),
|
||||||
linkText = "API Documentation",
|
) {
|
||||||
icon = Icons.Outlined.Description
|
if (task.docUrl.isNotEmpty()) {
|
||||||
)
|
ClickableLink(
|
||||||
}
|
url = task.docUrl, linkText = "API Documentation", icon = Icons.Outlined.Description
|
||||||
if (task.sourceCodeUrl.isNotEmpty()) {
|
)
|
||||||
ClickableLink(
|
}
|
||||||
url = task.sourceCodeUrl,
|
if (task.sourceCodeUrl.isNotEmpty()) {
|
||||||
linkText = "Example code",
|
ClickableLink(
|
||||||
icon = Icons.Outlined.Code
|
url = task.sourceCodeUrl, linkText = "Example code", icon = Icons.Outlined.Code
|
||||||
)
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// List of models within a task.
|
||||||
|
items(items = models) { model ->
|
||||||
|
Box {
|
||||||
|
ModelItem(
|
||||||
|
model = model,
|
||||||
|
task = task,
|
||||||
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
|
onModelClicked = onModelClicked,
|
||||||
|
modifier = Modifier.padding(horizontal = 12.dp)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Title for local models.
|
||||||
|
if (localModels.isNotEmpty()) {
|
||||||
|
item(key = "localModelsTitle") {
|
||||||
|
Text(
|
||||||
|
"Local models",
|
||||||
|
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
|
||||||
|
modifier = Modifier
|
||||||
|
.padding(horizontal = 16.dp)
|
||||||
|
.padding(top = 24.dp)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// List of local models within a task.
|
||||||
|
items(items = localModels) { model ->
|
||||||
|
Box {
|
||||||
|
ModelItem(
|
||||||
|
model = model,
|
||||||
|
task = task,
|
||||||
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
|
onModelClicked = onModelClicked,
|
||||||
|
modifier = Modifier.padding(horizontal = 12.dp)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
item(key = "bottomPadding") {
|
||||||
|
Spacer(modifier = Modifier.height(60.dp))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// List of models within a task.
|
// Add model button at the bottom right.
|
||||||
items(items = task.models) { model ->
|
Box(
|
||||||
Box {
|
modifier = Modifier
|
||||||
ModelItem(
|
.padding(end = 16.dp)
|
||||||
model = model,
|
.padding(bottom = contentPadding.calculateBottomPadding())
|
||||||
task = task,
|
) {
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
SmallFloatingActionButton(
|
||||||
onModelClicked = onModelClicked,
|
onClick = {
|
||||||
modifier = Modifier.padding(start = 12.dp, end = 12.dp)
|
showAddModelSheet = true
|
||||||
)
|
},
|
||||||
|
containerColor = MaterialTheme.colorScheme.secondaryContainer,
|
||||||
|
contentColor = MaterialTheme.colorScheme.secondary,
|
||||||
|
) {
|
||||||
|
Icon(Icons.Filled.Add, "")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (showAddModelSheet) {
|
||||||
|
ModalBottomSheet(
|
||||||
|
onDismissRequest = { showAddModelSheet = false },
|
||||||
|
sheetState = sheetState,
|
||||||
|
) {
|
||||||
|
Text(
|
||||||
|
"Add custom model",
|
||||||
|
style = MaterialTheme.typography.titleLarge,
|
||||||
|
modifier = Modifier.padding(vertical = 4.dp, horizontal = 16.dp)
|
||||||
|
)
|
||||||
|
Box(modifier = Modifier.clickable {
|
||||||
|
coroutineScope.launch {
|
||||||
|
// Give it sometime to show the click effect.
|
||||||
|
delay(200)
|
||||||
|
showAddModelSheet = false
|
||||||
|
|
||||||
|
// Show file picker.
|
||||||
|
val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
|
||||||
|
addCategory(Intent.CATEGORY_OPENABLE)
|
||||||
|
type = "*/*"
|
||||||
|
putExtra(
|
||||||
|
Intent.EXTRA_MIME_TYPES,
|
||||||
|
arrayOf("application/x-binary", "application/octet-stream")
|
||||||
|
)
|
||||||
|
// Single select.
|
||||||
|
putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
|
||||||
|
}
|
||||||
|
filePickerLauncher.launch(intent)
|
||||||
|
}
|
||||||
|
}) {
|
||||||
|
Row(
|
||||||
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(6.dp),
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.padding(16.dp)
|
||||||
|
) {
|
||||||
|
Icon(Icons.AutoMirrored.Outlined.NoteAdd, contentDescription = "")
|
||||||
|
Text("Add local model")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (showImportingDialog) {
|
||||||
|
curFileUri.value?.let { uri ->
|
||||||
|
ModelImportDialog(uri = uri, onDone = { info ->
|
||||||
|
showImportingDialog = false
|
||||||
|
|
||||||
|
if (info.error.isEmpty()) {
|
||||||
|
// TODO: support other model types.
|
||||||
|
modelManagerViewModel.addLocalLlmModel(
|
||||||
|
task = task,
|
||||||
|
fileName = info.fileName,
|
||||||
|
fileSize = info.fileSize
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
|
@ -132,15 +309,11 @@ fun ClickableLink(
|
||||||
) {
|
) {
|
||||||
val uriHandler = LocalUriHandler.current
|
val uriHandler = LocalUriHandler.current
|
||||||
val annotatedText = AnnotatedString(
|
val annotatedText = AnnotatedString(
|
||||||
text = linkText,
|
text = linkText, spanStyles = listOf(
|
||||||
spanStyles = listOf(
|
|
||||||
AnnotatedString.Range(
|
AnnotatedString.Range(
|
||||||
item = SpanStyle(
|
item = SpanStyle(
|
||||||
color = MaterialTheme.customColors.linkColor,
|
color = MaterialTheme.customColors.linkColor, textDecoration = TextDecoration.Underline
|
||||||
textDecoration = TextDecoration.Underline
|
), start = 0, end = linkText.length
|
||||||
),
|
|
||||||
start = 0,
|
|
||||||
end = linkText.length
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -163,6 +336,7 @@ fun ClickableLink(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@RequiresApi(Build.VERSION_CODES.O)
|
||||||
@Preview(showBackground = true)
|
@Preview(showBackground = true)
|
||||||
@Composable
|
@Composable
|
||||||
fun ModelListPreview() {
|
fun ModelListPreview() {
|
||||||
|
|
|
@ -24,16 +24,20 @@ import androidx.lifecycle.ViewModel
|
||||||
import androidx.lifecycle.viewModelScope
|
import androidx.lifecycle.viewModelScope
|
||||||
import com.google.aiedge.gallery.data.AGWorkInfo
|
import com.google.aiedge.gallery.data.AGWorkInfo
|
||||||
import com.google.aiedge.gallery.data.AccessTokenData
|
import com.google.aiedge.gallery.data.AccessTokenData
|
||||||
|
import com.google.aiedge.gallery.data.Config
|
||||||
import com.google.aiedge.gallery.data.DataStoreRepository
|
import com.google.aiedge.gallery.data.DataStoreRepository
|
||||||
import com.google.aiedge.gallery.data.DownloadRepository
|
import com.google.aiedge.gallery.data.DownloadRepository
|
||||||
import com.google.aiedge.gallery.data.EMPTY_MODEL
|
import com.google.aiedge.gallery.data.EMPTY_MODEL
|
||||||
import com.google.aiedge.gallery.data.HfModel
|
import com.google.aiedge.gallery.data.HfModel
|
||||||
import com.google.aiedge.gallery.data.HfModelDetails
|
import com.google.aiedge.gallery.data.HfModelDetails
|
||||||
import com.google.aiedge.gallery.data.HfModelSummary
|
import com.google.aiedge.gallery.data.HfModelSummary
|
||||||
|
import com.google.aiedge.gallery.data.IMPORTS_DIR
|
||||||
|
import com.google.aiedge.gallery.data.LocalModelInfo
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.ModelDownloadStatus
|
import com.google.aiedge.gallery.data.ModelDownloadStatus
|
||||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||||
import com.google.aiedge.gallery.data.TASKS
|
import com.google.aiedge.gallery.data.TASKS
|
||||||
|
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.data.TaskType
|
import com.google.aiedge.gallery.data.TaskType
|
||||||
import com.google.aiedge.gallery.data.getModelByName
|
import com.google.aiedge.gallery.data.getModelByName
|
||||||
|
@ -41,6 +45,7 @@ import com.google.aiedge.gallery.ui.common.AuthConfig
|
||||||
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper
|
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper
|
||||||
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper
|
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper
|
||||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
|
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
|
||||||
|
import com.google.aiedge.gallery.ui.llmchat.createLLmChatConfig
|
||||||
import com.google.aiedge.gallery.ui.textclassification.TextClassificationModelHelper
|
import com.google.aiedge.gallery.ui.textclassification.TextClassificationModelHelper
|
||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
||||||
import kotlinx.coroutines.async
|
import kotlinx.coroutines.async
|
||||||
|
@ -66,8 +71,12 @@ private const val TAG = "AGModelManagerViewModel"
|
||||||
private const val HG_COMMUNITY = "jinjingforevercommunity"
|
private const val HG_COMMUNITY = "jinjingforevercommunity"
|
||||||
private const val TEXT_INPUT_HISTORY_MAX_SIZE = 50
|
private const val TEXT_INPUT_HISTORY_MAX_SIZE = 50
|
||||||
|
|
||||||
enum class ModelInitializationStatus {
|
data class ModelInitializationStatus(
|
||||||
NOT_INITIALIZED, INITIALIZING, INITIALIZED,
|
val status: ModelInitializationStatusType, var error: String = ""
|
||||||
|
)
|
||||||
|
|
||||||
|
enum class ModelInitializationStatusType {
|
||||||
|
NOT_INITIALIZED, INITIALIZING, INITIALIZED, ERROR
|
||||||
}
|
}
|
||||||
|
|
||||||
enum class TokenStatus {
|
enum class TokenStatus {
|
||||||
|
@ -84,8 +93,7 @@ data class TokenStatusAndData(
|
||||||
)
|
)
|
||||||
|
|
||||||
data class TokenRequestResult(
|
data class TokenRequestResult(
|
||||||
val status: TokenRequestResultType,
|
val status: TokenRequestResultType, val errorMessage: String? = null
|
||||||
val errorMessage: String? = null
|
|
||||||
)
|
)
|
||||||
|
|
||||||
data class ModelManagerUiState(
|
data class ModelManagerUiState(
|
||||||
|
@ -94,11 +102,6 @@ data class ModelManagerUiState(
|
||||||
*/
|
*/
|
||||||
val tasks: List<Task>,
|
val tasks: List<Task>,
|
||||||
|
|
||||||
/**
|
|
||||||
* A map that stores lists of models indexed by task name.
|
|
||||||
*/
|
|
||||||
val modelsByTaskName: Map<String, MutableList<Model>>,
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A map that tracks the download status of each model, indexed by model name.
|
* A map that tracks the download status of each model, indexed by model name.
|
||||||
*/
|
*/
|
||||||
|
@ -191,14 +194,14 @@ open class ModelManagerViewModel(
|
||||||
_uiState.update { _uiState.value.copy(selectedModel = model) }
|
_uiState.update { _uiState.value.copy(selectedModel = model) }
|
||||||
}
|
}
|
||||||
|
|
||||||
fun downloadModel(model: Model) {
|
fun downloadModel(task: Task, model: Model) {
|
||||||
// Update status.
|
// Update status.
|
||||||
setDownloadStatus(
|
setDownloadStatus(
|
||||||
curModel = model, status = ModelDownloadStatus(status = ModelDownloadStatusType.IN_PROGRESS)
|
curModel = model, status = ModelDownloadStatus(status = ModelDownloadStatusType.IN_PROGRESS)
|
||||||
)
|
)
|
||||||
|
|
||||||
// Delete the model files first.
|
// Delete the model files first.
|
||||||
deleteModel(model = model)
|
deleteModel(task = task, model = model)
|
||||||
|
|
||||||
// Start to send download request.
|
// Start to send download request.
|
||||||
downloadRepository.downloadModel(
|
downloadRepository.downloadModel(
|
||||||
|
@ -210,7 +213,7 @@ open class ModelManagerViewModel(
|
||||||
downloadRepository.cancelDownloadModel(model)
|
downloadRepository.cancelDownloadModel(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun deleteModel(model: Model) {
|
fun deleteModel(task: Task, model: Model) {
|
||||||
deleteFileFromExternalFilesDir(model.downloadFileName)
|
deleteFileFromExternalFilesDir(model.downloadFileName)
|
||||||
for (file in model.extraDataFiles) {
|
for (file in model.extraDataFiles) {
|
||||||
deleteFileFromExternalFilesDir(file.downloadFileName)
|
deleteFileFromExternalFilesDir(file.downloadFileName)
|
||||||
|
@ -223,6 +226,24 @@ open class ModelManagerViewModel(
|
||||||
val curModelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap()
|
val curModelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap()
|
||||||
curModelDownloadStatus[model.name] =
|
curModelDownloadStatus[model.name] =
|
||||||
ModelDownloadStatus(status = ModelDownloadStatusType.NOT_DOWNLOADED)
|
ModelDownloadStatus(status = ModelDownloadStatusType.NOT_DOWNLOADED)
|
||||||
|
|
||||||
|
// Delete model from the list if model is imported as a local model.
|
||||||
|
if (model.isLocalModel) {
|
||||||
|
val index = task.models.indexOf(model)
|
||||||
|
if (index >= 0) {
|
||||||
|
task.models.removeAt(index)
|
||||||
|
}
|
||||||
|
task.updateTrigger.value = System.currentTimeMillis()
|
||||||
|
curModelDownloadStatus.remove(model.name)
|
||||||
|
|
||||||
|
// Update preference.
|
||||||
|
val localModels = dataStoreRepository.readLocalModels().toMutableList()
|
||||||
|
val localModelIndex = localModels.indexOfFirst { it.fileName == model.name }
|
||||||
|
if (localModelIndex >= 0) {
|
||||||
|
localModels.removeAt(localModelIndex)
|
||||||
|
}
|
||||||
|
dataStoreRepository.saveLocalModels(localModels = localModels)
|
||||||
|
}
|
||||||
val newUiState = uiState.value.copy(modelDownloadStatus = curModelDownloadStatus)
|
val newUiState = uiState.value.copy(modelDownloadStatus = curModelDownloadStatus)
|
||||||
_uiState.update { newUiState }
|
_uiState.update { newUiState }
|
||||||
}
|
}
|
||||||
|
@ -230,7 +251,7 @@ open class ModelManagerViewModel(
|
||||||
fun initializeModel(context: Context, model: Model, force: Boolean = false) {
|
fun initializeModel(context: Context, model: Model, force: Boolean = false) {
|
||||||
viewModelScope.launch(Dispatchers.Default) {
|
viewModelScope.launch(Dispatchers.Default) {
|
||||||
// Skip if initialized already.
|
// Skip if initialized already.
|
||||||
if (!force && uiState.value.modelInitializationStatus[model.name] == ModelInitializationStatus.INITIALIZED) {
|
if (!force && uiState.value.modelInitializationStatus[model.name]?.status == ModelInitializationStatusType.INITIALIZED) {
|
||||||
Log.d(TAG, "Model '${model.name}' has been initialized. Skipping.")
|
Log.d(TAG, "Model '${model.name}' has been initialized. Skipping.")
|
||||||
return@launch
|
return@launch
|
||||||
}
|
}
|
||||||
|
@ -252,20 +273,27 @@ open class ModelManagerViewModel(
|
||||||
// been initialized or not. If so, skip.
|
// been initialized or not. If so, skip.
|
||||||
launch {
|
launch {
|
||||||
delay(500)
|
delay(500)
|
||||||
if (model.instance == null) {
|
if (model.instance == null && model.initializing) {
|
||||||
updateModelInitializationStatus(
|
updateModelInitializationStatus(
|
||||||
model = model, status = ModelInitializationStatus.INITIALIZING
|
model = model, status = ModelInitializationStatusType.INITIALIZING
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val onDone: () -> Unit = {
|
val onDone: (error: String) -> Unit = { error ->
|
||||||
|
model.initializing = false
|
||||||
if (model.instance != null) {
|
if (model.instance != null) {
|
||||||
Log.d(TAG, "Model '${model.name}' initialized successfully")
|
Log.d(TAG, "Model '${model.name}' initialized successfully")
|
||||||
model.initializing = false
|
|
||||||
updateModelInitializationStatus(
|
updateModelInitializationStatus(
|
||||||
model = model,
|
model = model,
|
||||||
status = ModelInitializationStatus.INITIALIZED,
|
status = ModelInitializationStatusType.INITIALIZED,
|
||||||
|
)
|
||||||
|
} else if (error.isNotEmpty()) {
|
||||||
|
Log.d(TAG, "Model '${model.name}' failed to initialize")
|
||||||
|
updateModelInitializationStatus(
|
||||||
|
model = model,
|
||||||
|
status = ModelInitializationStatusType.ERROR,
|
||||||
|
error = error,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -310,7 +338,7 @@ open class ModelManagerViewModel(
|
||||||
model.instance = null
|
model.instance = null
|
||||||
model.initializing = false
|
model.initializing = false
|
||||||
updateModelInitializationStatus(
|
updateModelInitializationStatus(
|
||||||
model = model, status = ModelInitializationStatus.NOT_INITIALIZED
|
model = model, status = ModelInitializationStatusType.NOT_INITIALIZED
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -380,8 +408,7 @@ open class ModelManagerViewModel(
|
||||||
val connection = url.openConnection() as HttpURLConnection
|
val connection = url.openConnection() as HttpURLConnection
|
||||||
if (accessToken != null) {
|
if (accessToken != null) {
|
||||||
connection.setRequestProperty(
|
connection.setRequestProperty(
|
||||||
"Authorization",
|
"Authorization", "Bearer $accessToken"
|
||||||
"Bearer $accessToken"
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
connection.connect()
|
connection.connect()
|
||||||
|
@ -390,6 +417,47 @@ open class ModelManagerViewModel(
|
||||||
return connection.responseCode
|
return connection.responseCode
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun addLocalLlmModel(task: Task, fileName: String, fileSize: Long) {
|
||||||
|
Log.d(TAG, "adding local model: $fileName, $fileSize")
|
||||||
|
|
||||||
|
// Create model.
|
||||||
|
val configs: List<Config> = createLLmChatConfig(defaults = mapOf())
|
||||||
|
val model = Model(
|
||||||
|
name = fileName,
|
||||||
|
url = "",
|
||||||
|
configs = configs,
|
||||||
|
sizeInBytes = fileSize,
|
||||||
|
downloadFileName = "$IMPORTS_DIR/$fileName",
|
||||||
|
isLocalModel = true,
|
||||||
|
)
|
||||||
|
model.preProcess(task = task)
|
||||||
|
task.models.add(model)
|
||||||
|
|
||||||
|
// Add initial status and states.
|
||||||
|
val modelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap()
|
||||||
|
val modelInstances = uiState.value.modelInitializationStatus.toMutableMap()
|
||||||
|
modelDownloadStatus[model.name] = ModelDownloadStatus(
|
||||||
|
status = ModelDownloadStatusType.SUCCEEDED, receivedBytes = fileSize, totalBytes = fileSize
|
||||||
|
)
|
||||||
|
modelInstances[model.name] =
|
||||||
|
ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED)
|
||||||
|
|
||||||
|
// Update ui state.
|
||||||
|
_uiState.update {
|
||||||
|
uiState.value.copy(
|
||||||
|
tasks = uiState.value.tasks.toList(),
|
||||||
|
modelDownloadStatus = modelDownloadStatus,
|
||||||
|
modelInitializationStatus = modelInstances
|
||||||
|
)
|
||||||
|
}
|
||||||
|
task.updateTrigger.value = System.currentTimeMillis()
|
||||||
|
|
||||||
|
// Add to preference storage.
|
||||||
|
val localModels = dataStoreRepository.readLocalModels().toMutableList()
|
||||||
|
localModels.add(LocalModelInfo(fileName = fileName, fileSize = fileSize))
|
||||||
|
dataStoreRepository.saveLocalModels(localModels = localModels)
|
||||||
|
}
|
||||||
|
|
||||||
fun getTokenStatusAndData(): TokenStatusAndData {
|
fun getTokenStatusAndData(): TokenStatusAndData {
|
||||||
// Try to load token data from DataStore.
|
// Try to load token data from DataStore.
|
||||||
var tokenStatus = TokenStatus.NOT_STORED
|
var tokenStatus = TokenStatus.NOT_STORED
|
||||||
|
@ -436,8 +504,7 @@ open class ModelManagerViewModel(
|
||||||
if (dataIntent == null) {
|
if (dataIntent == null) {
|
||||||
onTokenRequested(
|
onTokenRequested(
|
||||||
TokenRequestResult(
|
TokenRequestResult(
|
||||||
status = TokenRequestResultType.FAILED,
|
status = TokenRequestResultType.FAILED, errorMessage = "Empty auth result"
|
||||||
errorMessage = "Empty auth result"
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
@ -481,8 +548,7 @@ open class ModelManagerViewModel(
|
||||||
} else {
|
} else {
|
||||||
onTokenRequested(
|
onTokenRequested(
|
||||||
TokenRequestResult(
|
TokenRequestResult(
|
||||||
status = TokenRequestResultType.FAILED,
|
status = TokenRequestResultType.FAILED, errorMessage = errorMessage
|
||||||
errorMessage = errorMessage
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -513,23 +579,49 @@ open class ModelManagerViewModel(
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun createUiState(): ModelManagerUiState {
|
private fun createUiState(): ModelManagerUiState {
|
||||||
val modelsByTaskName: Map<String, MutableList<Model>> =
|
|
||||||
TASKS.associate { task -> task.type.label to task.models }
|
|
||||||
val modelDownloadStatus: MutableMap<String, ModelDownloadStatus> = mutableMapOf()
|
val modelDownloadStatus: MutableMap<String, ModelDownloadStatus> = mutableMapOf()
|
||||||
val modelInstances: MutableMap<String, ModelInitializationStatus> = mutableMapOf()
|
val modelInstances: MutableMap<String, ModelInitializationStatus> = mutableMapOf()
|
||||||
for ((_, models) in modelsByTaskName.entries) {
|
for (task in TASKS) {
|
||||||
for (model in models) {
|
for (model in task.models) {
|
||||||
modelDownloadStatus[model.name] = getModelDownloadStatus(model = model)
|
modelDownloadStatus[model.name] = getModelDownloadStatus(model = model)
|
||||||
modelInstances[model.name] = ModelInitializationStatus.NOT_INITIALIZED
|
modelInstances[model.name] =
|
||||||
|
ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Load local models.
|
||||||
|
for (localModel in dataStoreRepository.readLocalModels()) {
|
||||||
|
Log.d(TAG, "stored local model: $localModel")
|
||||||
|
|
||||||
|
// Create model.
|
||||||
|
val configs: List<Config> = createLLmChatConfig(defaults = mapOf())
|
||||||
|
val model = Model(
|
||||||
|
name = localModel.fileName,
|
||||||
|
url = "",
|
||||||
|
configs = configs,
|
||||||
|
sizeInBytes = localModel.fileSize,
|
||||||
|
downloadFileName = "$IMPORTS_DIR/${localModel.fileName}",
|
||||||
|
isLocalModel = true,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Add to task.
|
||||||
|
val task = TASK_LLM_CHAT
|
||||||
|
model.preProcess(task = task)
|
||||||
|
task.models.add(model)
|
||||||
|
|
||||||
|
// Update status.
|
||||||
|
modelDownloadStatus[model.name] = ModelDownloadStatus(
|
||||||
|
status = ModelDownloadStatusType.SUCCEEDED,
|
||||||
|
receivedBytes = localModel.fileSize,
|
||||||
|
totalBytes = localModel.fileSize
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
val textInputHistory = dataStoreRepository.readTextInputHistory()
|
val textInputHistory = dataStoreRepository.readTextInputHistory()
|
||||||
Log.d(TAG, "text input history: $textInputHistory")
|
Log.d(TAG, "text input history: $textInputHistory")
|
||||||
|
|
||||||
return ModelManagerUiState(
|
return ModelManagerUiState(
|
||||||
tasks = TASKS,
|
tasks = TASKS,
|
||||||
modelsByTaskName = modelsByTaskName,
|
|
||||||
modelDownloadStatus = modelDownloadStatus,
|
modelDownloadStatus = modelDownloadStatus,
|
||||||
modelInitializationStatus = modelInstances,
|
modelInitializationStatus = modelInstances,
|
||||||
textInputHistory = textInputHistory,
|
textInputHistory = textInputHistory,
|
||||||
|
@ -610,7 +702,8 @@ open class ModelManagerViewModel(
|
||||||
|
|
||||||
// Add initial status and states.
|
// Add initial status and states.
|
||||||
modelDownloadStatus[model.name] = getModelDownloadStatus(model = model)
|
modelDownloadStatus[model.name] = getModelDownloadStatus(model = model)
|
||||||
modelInstances[model.name] = ModelInitializationStatus.NOT_INITIALIZED
|
modelInstances[model.name] =
|
||||||
|
ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -677,9 +770,13 @@ open class ModelManagerViewModel(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun updateModelInitializationStatus(model: Model, status: ModelInitializationStatus) {
|
private fun updateModelInitializationStatus(
|
||||||
|
model: Model,
|
||||||
|
status: ModelInitializationStatusType,
|
||||||
|
error: String = ""
|
||||||
|
) {
|
||||||
val curModelInstance = uiState.value.modelInitializationStatus.toMutableMap()
|
val curModelInstance = uiState.value.modelInitializationStatus.toMutableMap()
|
||||||
curModelInstance[model.name] = status
|
curModelInstance[model.name] = ModelInitializationStatus(status = status, error = error)
|
||||||
val newUiState = uiState.value.copy(modelInitializationStatus = curModelInstance)
|
val newUiState = uiState.value.copy(modelInitializationStatus = curModelInstance)
|
||||||
_uiState.update { newUiState }
|
_uiState.update { newUiState }
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@ package com.google.aiedge.gallery.ui.preview
|
||||||
|
|
||||||
import com.google.aiedge.gallery.data.AccessTokenData
|
import com.google.aiedge.gallery.data.AccessTokenData
|
||||||
import com.google.aiedge.gallery.data.DataStoreRepository
|
import com.google.aiedge.gallery.data.DataStoreRepository
|
||||||
|
import com.google.aiedge.gallery.data.LocalModelInfo
|
||||||
|
|
||||||
class PreviewDataStoreRepository : DataStoreRepository {
|
class PreviewDataStoreRepository : DataStoreRepository {
|
||||||
override fun saveTextInputHistory(history: List<String>) {
|
override fun saveTextInputHistory(history: List<String>) {
|
||||||
|
@ -40,4 +41,11 @@ class PreviewDataStoreRepository : DataStoreRepository {
|
||||||
override fun readAccessTokenData(): AccessTokenData? {
|
override fun readAccessTokenData(): AccessTokenData? {
|
||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun saveLocalModels(localModels: List<LocalModelInfo>) {
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun readLocalModels(): List<LocalModelInfo> {
|
||||||
|
return listOf()
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -17,7 +17,6 @@
|
||||||
package com.google.aiedge.gallery.ui.preview
|
package com.google.aiedge.gallery.ui.preview
|
||||||
|
|
||||||
import android.content.Context
|
import android.content.Context
|
||||||
import com.google.aiedge.gallery.data.Model
|
|
||||||
import com.google.aiedge.gallery.data.ModelDownloadStatus
|
import com.google.aiedge.gallery.data.ModelDownloadStatus
|
||||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerUiState
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerUiState
|
||||||
|
@ -39,8 +38,6 @@ class PreviewModelManagerViewModel(context: Context) :
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val modelsByTaskName: Map<String, MutableList<Model>> =
|
|
||||||
ALL_PREVIEW_TASKS.associate { task -> task.type.label to task.models }
|
|
||||||
val modelDownloadStatus = mapOf(
|
val modelDownloadStatus = mapOf(
|
||||||
MODEL_TEST1.name to ModelDownloadStatus(
|
MODEL_TEST1.name to ModelDownloadStatus(
|
||||||
status = ModelDownloadStatusType.IN_PROGRESS,
|
status = ModelDownloadStatusType.IN_PROGRESS,
|
||||||
|
@ -61,7 +58,6 @@ class PreviewModelManagerViewModel(context: Context) :
|
||||||
)
|
)
|
||||||
val newUiState = ModelManagerUiState(
|
val newUiState = ModelManagerUiState(
|
||||||
tasks = ALL_PREVIEW_TASKS,
|
tasks = ALL_PREVIEW_TASKS,
|
||||||
modelsByTaskName = modelsByTaskName,
|
|
||||||
modelDownloadStatus = modelDownloadStatus,
|
modelDownloadStatus = modelDownloadStatus,
|
||||||
modelInitializationStatus = mapOf(),
|
modelInitializationStatus = mapOf(),
|
||||||
selectedModel = MODEL_TEST2,
|
selectedModel = MODEL_TEST2,
|
||||||
|
|
|
@ -40,14 +40,14 @@ class TextClassificationInferenceResult(
|
||||||
* Helper object for managing text classification models.
|
* Helper object for managing text classification models.
|
||||||
*/
|
*/
|
||||||
object TextClassificationModelHelper {
|
object TextClassificationModelHelper {
|
||||||
fun initialize(context: Context, model: Model, onDone: () -> Unit) {
|
fun initialize(context: Context, model: Model, onDone: (String) -> Unit) {
|
||||||
val modelByteBuffer = readFileToByteBuffer(File(model.getPath(context = context)))
|
val modelByteBuffer = readFileToByteBuffer(File(model.getPath(context = context)))
|
||||||
if (modelByteBuffer != null) {
|
if (modelByteBuffer != null) {
|
||||||
val options = TextClassifier.TextClassifierOptions.builder().setBaseOptions(
|
val options = TextClassifier.TextClassifierOptions.builder().setBaseOptions(
|
||||||
BaseOptions.builder().setModelAssetBuffer(modelByteBuffer).build()
|
BaseOptions.builder().setModelAssetBuffer(modelByteBuffer).build()
|
||||||
).build()
|
).build()
|
||||||
model.instance = TextClassifier.createFromOptions(context, options)
|
model.instance = TextClassifier.createFromOptions(context, options)
|
||||||
onDone()
|
onDone("")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue