Make importing model functionality better.

- Allow users to specify default parameters before importing.
This commit is contained in:
Jing Jin 2025-04-18 23:10:55 -07:00
parent 29b614355e
commit 604972fe23
15 changed files with 635 additions and 329 deletions

View file

@ -23,6 +23,7 @@ package com.google.aiedge.gallery.data
* Each type corresponds to a specific editor widget, such as a slider or a switch.
*/
enum class ConfigEditorType {
LABEL,
NUMBER_SLIDER,
BOOLEAN_SWITCH,
DROPDOWN,
@ -57,6 +58,19 @@ open class Config(
open val needReinitialization: Boolean = true,
)
/**
* Configuration setting for a label.
*/
class LabelConfig(
override val key: ConfigKey,
override val defaultValue: String = "",
) : Config(
type = ConfigEditorType.LABEL,
key = key,
defaultValue = defaultValue,
valueType = ValueType.STRING
)
/**
* Configuration setting for a number slider.
*
@ -99,9 +113,11 @@ class SegmentedButtonConfig(
override val key: ConfigKey,
override val defaultValue: String,
val options: List<String>,
val allowMultiple: Boolean = false,
) : Config(
type = ConfigEditorType.DROPDOWN,
key = key,
defaultValue = defaultValue,
// The emitted value will be comma-separated labels when allowMultiple=true.
valueType = ValueType.STRING,
)

View file

@ -47,8 +47,8 @@ interface DataStoreRepository {
fun readThemeOverride(): String
fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long)
fun readAccessTokenData(): AccessTokenData?
fun saveLocalModels(localModels: List<LocalModelInfo>)
fun readLocalModels(): List<LocalModelInfo>
fun saveImportedModels(importedModels: List<ImportedModelInfo>)
fun readImportedModels(): List<ImportedModelInfo>
}
/**
@ -82,8 +82,8 @@ class DefaultDataStoreRepository(
val ACCESS_TOKEN_EXPIRES_AT = longPreferencesKey("access_token_expires_at")
// Data for all imported local models.
val LOCAL_MODELS = stringPreferencesKey("local_models")
// Data for all imported models.
val IMPORTED_MODELS = stringPreferencesKey("imported_models")
}
private val keystoreAlias: String = "com_google_aiedge_gallery_access_token_key"
@ -160,22 +160,22 @@ class DefaultDataStoreRepository(
}
}
override fun saveLocalModels(localModels: List<LocalModelInfo>) {
override fun saveImportedModels(importedModels: List<ImportedModelInfo>) {
runBlocking {
dataStore.edit { preferences ->
val gson = Gson()
val jsonString = gson.toJson(localModels)
preferences[PreferencesKeys.LOCAL_MODELS] = jsonString
val jsonString = gson.toJson(importedModels)
preferences[PreferencesKeys.IMPORTED_MODELS] = jsonString
}
}
}
override fun readLocalModels(): List<LocalModelInfo> {
override fun readImportedModels(): List<ImportedModelInfo> {
return runBlocking {
val preferences = dataStore.data.first()
val infosStr = preferences[PreferencesKeys.LOCAL_MODELS] ?: "[]"
val infosStr = preferences[PreferencesKeys.IMPORTED_MODELS] ?: "[]"
val gson = Gson()
val listType = object : TypeToken<List<LocalModelInfo>>() {}.type
val listType = object : TypeToken<List<ImportedModelInfo>>() {}.type
gson.fromJson(infosStr, listType)
}
}

View file

@ -17,7 +17,6 @@
package com.google.aiedge.gallery.data
import com.google.aiedge.gallery.ui.common.ensureValidFileName
import com.google.aiedge.gallery.ui.llmchat.createLLmChatConfig
import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
import kotlinx.serialization.SerializationException
@ -107,11 +106,13 @@ data class HfModel(
val fileName = ensureValidFileName("${id}_${(parts.lastOrNull() ?: "")}")
// Generate configs based on the given default values.
val configs: List<Config> = when (task) {
TASK_LLM_CHAT.type.label -> createLLmChatConfig(defaults = configs)
// todo: add configs for other types.
else -> listOf()
}
// val configs: List<Config> = when (task) {
// TASK_LLM_CHAT.type.label -> createLLmChatConfig(defaults = configs)
// // todo: add configs for other types.
// else -> listOf()
// }
// todo: fix when loading from models.json
val configs: List<Config> = listOf()
// Construct url.
var modelUrl = url

View file

@ -19,6 +19,7 @@ package com.google.aiedge.gallery.data
import android.content.Context
import com.google.aiedge.gallery.ui.common.chat.PromptTemplate
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_ACCELERATORS
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
data class ModelDataFile(
@ -28,8 +29,8 @@ data class ModelDataFile(
val sizeInBytes: Long,
)
enum class LlmBackend {
CPU, GPU
enum class Accelerator(val label: String) {
CPU(label = "CPU"), GPU(label = "GPU")
}
const val IMPORTS_DIR = "__imports"
@ -81,14 +82,14 @@ data class Model(
/** The name of the directory to unzip the model to (if it's a zip file). */
val unzipDir: String = "",
/** The preferred backend of the model (only for LLM). */
val llmBackend: LlmBackend = LlmBackend.GPU,
/** The accelerators the the model can run with. */
val accelerators: List<Accelerator> = DEFAULT_ACCELERATORS,
/** The prompt templates for the model (only for LLM). */
val llmPromptTemplates: List<PromptTemplate> = listOf(),
/** Whether the model is imported as a local model. */
val isLocalModel: Boolean = false,
/** Whether the model is imported or not. */
val imported: Boolean = false,
// The following fields are managed by the app. Don't need to set manually.
var taskType: TaskType? = null,
@ -135,6 +136,12 @@ data class Model(
) as Boolean
}
fun getStringConfigValue(key: ConfigKey, defaultValue: String = ""): String {
return getTypedConfigValue(
key = key, valueType = ValueType.STRING, defaultValue = defaultValue
) as String
}
fun getExtraDataFile(name: String): ModelDataFile? {
return extraDataFiles.find { it.name == name }
}
@ -147,7 +154,11 @@ data class Model(
}
/** Data for a imported local model. */
data class LocalModelInfo(val fileName: String, val fileSize: Long)
data class ImportedModelInfo(
val fileName: String,
val fileSize: Long,
val defaultValues: Map<String, Any>
)
enum class ModelDownloadStatusType {
NOT_DOWNLOADED, PARTIALLY_DOWNLOADED, IN_PROGRESS, UNZIPPING, SUCCEEDED, FAILED,
@ -165,29 +176,25 @@ data class ModelDownloadStatus(
////////////////////////////////////////////////////////////////////////////////////////////////////
// Configs.
enum class ConfigKey(val label: String, val id: String) {
MAX_TOKENS("Max tokens", id = "max_token"),
TOPK("TopK", id = "topk"),
TOPP(
"TopP",
id = "topp"
),
TEMPERATURE("Temperature", id = "temperature"),
MAX_RESULT_COUNT(
"Max result count",
id = "max_result_count"
),
USE_GPU("Use GPU", id = "use_gpu"),
WARM_UP_ITERATIONS(
"Warm up iterations",
id = "warm_up_iterations"
),
BENCHMARK_ITERATIONS(
"Benchmark iterations",
id = "benchmark_iterations"
),
ITERATIONS("Iterations", id = "iterations"),
THEME("Theme", id = "theme"),
enum class ConfigKey(val label: String) {
MAX_TOKENS("Max tokens"),
TOPK("TopK"),
TOPP("TopP"),
TEMPERATURE("Temperature"),
DEFAULT_MAX_TOKENS("Default max tokens"),
DEFAULT_TOPK("Default TopK"),
DEFAULT_TOPP("Default TopP"),
DEFAULT_TEMPERATURE("Default temperature"),
MAX_RESULT_COUNT("Max result count"),
USE_GPU("Use GPU"),
ACCELERATOR("Accelerator"),
COMPATIBLE_ACCELERATORS("Compatible accelerators"),
WARM_UP_ITERATIONS("Warm up iterations"),
BENCHMARK_ITERATIONS("Benchmark iterations"),
ITERATIONS("Iterations"),
THEME("Theme"),
NAME("Name"),
MODEL_TYPE("Model type")
}
val MOBILENET_CONFIGS: List<Config> = listOf(
@ -258,7 +265,12 @@ val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model(
downloadFileName = "gemma3-1b-it-int4.task",
url = "https://huggingface.co/litert-community/Gemma3-1B-IT/resolve/main/gemma3-1b-it-int4.task?download=true",
sizeInBytes = 554661243L,
configs = createLlmChatConfigs(defaultTopK = 64, defaultTopP = 0.95f),
accelerators = listOf(Accelerator.CPU, Accelerator.GPU),
configs = createLlmChatConfigs(
defaultTopK = 64,
defaultTopP = 0.95f,
accelerators = listOf(Accelerator.CPU, Accelerator.GPU)
),
info = LLM_CHAT_INFO,
learnMoreUrl = "https://huggingface.co/litert-community/Gemma3-1B-IT",
llmPromptTemplates = listOf(
@ -280,8 +292,13 @@ val MODEL_LLM_DEEPSEEK: Model = Model(
downloadFileName = "deepseek.task",
url = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B/resolve/main/deepseek_q8_ekv1280.task?download=true",
sizeInBytes = 1860686856L,
llmBackend = LlmBackend.CPU,
configs = createLlmChatConfigs(defaultTemperature = 0.6f, defaultTopK = 40, defaultTopP = 0.7f),
accelerators = listOf(Accelerator.CPU),
configs = createLlmChatConfigs(
defaultTemperature = 0.6f,
defaultTopK = 40,
defaultTopP = 0.7f,
accelerators = listOf(Accelerator.CPU)
),
info = LLM_CHAT_INFO,
learnMoreUrl = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B",
)

View file

@ -34,16 +34,15 @@ import androidx.compose.foundation.text.KeyboardOptions
import androidx.compose.material3.Button
import androidx.compose.material3.Card
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.MultiChoiceSegmentedButtonRow
import androidx.compose.material3.SegmentedButton
import androidx.compose.material3.SegmentedButtonDefaults
import androidx.compose.material3.SingleChoiceSegmentedButtonRow
import androidx.compose.material3.Slider
import androidx.compose.material3.Switch
import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.mutableStateMapOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
@ -60,6 +59,7 @@ import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog
import com.google.aiedge.gallery.data.BooleanSwitchConfig
import com.google.aiedge.gallery.data.Config
import com.google.aiedge.gallery.data.LabelConfig
import com.google.aiedge.gallery.data.NumberSliderConfig
import com.google.aiedge.gallery.data.SegmentedButtonConfig
import com.google.aiedge.gallery.data.ValueType
@ -113,27 +113,10 @@ fun ConfigDialog(
}
}
// List of config rows.
for (config in configs) {
when (config) {
// Number slider.
is NumberSliderConfig -> {
NumberSliderRow(config = config, values = values)
}
ConfigEditorsPanel(configs = configs, values = values)
// Boolean switch.
is BooleanSwitchConfig -> {
BooleanSwitchRow(config = config, values = values)
}
is SegmentedButtonConfig -> {
SegmentedButtonRow(config = config, values = values)
}
else -> {}
}
}
// Button row.
Row(
modifier = Modifier
.fillMaxWidth()
@ -164,6 +147,53 @@ fun ConfigDialog(
}
}
/**
* Composable function to display a list of config editor rows.
*/
@Composable
fun ConfigEditorsPanel(configs: List<Config>, values: SnapshotStateMap<String, Any>) {
for (config in configs) {
when (config) {
// Label.
is LabelConfig -> {
LabelRow(config = config, values = values)
}
// Number slider.
is NumberSliderConfig -> {
NumberSliderRow(config = config, values = values)
}
// Boolean switch.
is BooleanSwitchConfig -> {
BooleanSwitchRow(config = config, values = values)
}
// Segmented button.
is SegmentedButtonConfig -> {
SegmentedButtonRow(config = config, values = values)
}
else -> {}
}
}
}
@Composable
fun LabelRow(config: LabelConfig, values: SnapshotStateMap<String, Any>) {
Column(modifier = Modifier.fillMaxWidth()) {
// Field label.
Text(config.key.label, style = MaterialTheme.typography.titleSmall)
// Content label.
val label = try {
values[config.key.label] as String
} catch (e: Exception) {
""
}
Text(label, style = MaterialTheme.typography.bodyMedium)
}
}
/**
* Composable function to display a number slider with an associated text input field.
*
@ -272,18 +302,41 @@ fun BooleanSwitchRow(config: BooleanSwitchConfig, values: SnapshotStateMap<Strin
@Composable
fun SegmentedButtonRow(config: SegmentedButtonConfig, values: SnapshotStateMap<String, Any>) {
var selectedIndex by remember { mutableIntStateOf(config.options.indexOf(values[config.key.label])) }
val selectedOptions: List<String> = remember { (values[config.key.label] as String).split(",") }
var selectionStates: List<Boolean> by remember {
mutableStateOf(List(config.options.size) { index ->
selectedOptions.contains(config.options[index])
})
}
Column(modifier = Modifier.fillMaxWidth()) {
Text(config.key.label, style = MaterialTheme.typography.titleSmall)
SingleChoiceSegmentedButtonRow {
MultiChoiceSegmentedButtonRow {
config.options.forEachIndexed { index, label ->
SegmentedButton(shape = SegmentedButtonDefaults.itemShape(
index = index, count = config.options.size
), onClick = {
selectedIndex = index
values[config.key.label] = label
}, selected = index == selectedIndex, label = { Text(label) })
), onCheckedChange = {
var newSelectionStates = selectionStates.toMutableList()
val selectedCount = newSelectionStates.count { it }
// Single select.
if (!config.allowMultiple) {
if (!newSelectionStates[index]) {
newSelectionStates = MutableList(config.options.size) { it == index }
}
}
// Multiple select.
else {
if (!(selectedCount == 1 && newSelectionStates[index])) {
newSelectionStates[index] = !newSelectionStates[index]
}
}
selectionStates = newSelectionStates
values[config.key.label] =
config.options.filterIndexed { index, option -> selectionStates[index] }
.joinToString(",")
}, checked = selectionStates[index], label = { Text(label) })
}
}

View file

@ -92,7 +92,6 @@ private val DEFAULT_VERTICAL_PADDING = 16.dp
* model description and buttons for learning more (opening a URL) and downloading/trying
* the model.
*/
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun ModelItem(
model: Model,
@ -188,9 +187,9 @@ fun ModelItem(
}
} else {
Icon(
// For local model, show ">" directly indicating users can just tap the model item to
// For imported model, show ">" directly indicating users can just tap the model item to
// go into it without needing to expand it first.
if (model.isLocalModel) Icons.Rounded.ChevronRight else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore,
if (model.imported) Icons.Rounded.ChevronRight else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore,
contentDescription = "",
tint = getTaskIconColor(task),
)
@ -272,7 +271,7 @@ fun ModelItem(
boxModifier = if (canExpand) {
boxModifier.clickable(
onClick = {
if (!model.isLocalModel) {
if (!model.imported) {
isExpanded = !isExpanded
} else {
onModelClicked(model)

View file

@ -16,13 +16,22 @@
package com.google.aiedge.gallery.ui.home
import android.content.Intent
import android.net.Uri
import android.util.Log
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.ActivityResultLauncher
import androidx.activity.result.contract.ActivityResultContracts
import androidx.annotation.StringRes
import androidx.compose.animation.core.animateFloatAsState
import androidx.compose.animation.core.tween
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.PaddingValues
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.aspectRatio
import androidx.compose.foundation.layout.fillMaxSize
@ -34,22 +43,34 @@ import androidx.compose.foundation.lazy.grid.GridItemSpan
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
import androidx.compose.foundation.lazy.grid.items
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.outlined.NoteAdd
import androidx.compose.material.icons.filled.Add
import androidx.compose.material3.Card
import androidx.compose.material3.CardDefaults
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ModalBottomSheet
import androidx.compose.material3.Scaffold
import androidx.compose.material3.SmallFloatingActionButton
import androidx.compose.material3.Text
import androidx.compose.material3.TopAppBarDefaults
import androidx.compose.material3.rememberModalBottomSheetState
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip
import androidx.compose.ui.draw.scale
import androidx.compose.ui.graphics.Brush
import androidx.compose.ui.input.nestedscroll.nestedScroll
import androidx.compose.ui.layout.layout
@ -66,6 +87,8 @@ import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.data.AppBarAction
import com.google.aiedge.gallery.data.AppBarActionType
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.ImportedModelInfo
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.TaskIcon
import com.google.aiedge.gallery.ui.common.getTaskBgColor
@ -75,6 +98,11 @@ import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.ThemeSettings
import com.google.aiedge.gallery.ui.theme.customColors
import com.google.aiedge.gallery.ui.theme.titleMediumNarrow
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
private const val TAG = "AGHomeScreen"
private const val TASK_COUNT_ANIMATION_DURATION = 250
/** Navigation destination data */
object HomeScreenDestination {
@ -92,22 +120,59 @@ fun HomeScreen(
val scrollBehavior = TopAppBarDefaults.pinnedScrollBehavior()
val uiState by modelManagerViewModel.uiState.collectAsState()
var showSettingsDialog by remember { mutableStateOf(false) }
var showImportModelSheet by remember { mutableStateOf(false) }
val sheetState = rememberModalBottomSheetState()
var showImportDialog by remember { mutableStateOf(false) }
var showImportingDialog by remember { mutableStateOf(false) }
val selectedLocalModelFileUri = remember { mutableStateOf<Uri?>(null) }
val selectedImportedModelInfo = remember { mutableStateOf<ImportedModelInfo?>(null) }
val coroutineScope = rememberCoroutineScope()
val tasks = uiState.tasks
val loadingHfModels = uiState.loadingHfModels
Scaffold(modifier = modifier.nestedScroll(scrollBehavior.nestedScrollConnection), topBar = {
GalleryTopAppBar(
title = stringResource(HomeScreenDestination.titleRes),
rightAction = AppBarAction(
actionType = AppBarActionType.APP_SETTING, actionFn = {
showSettingsDialog = true
}
),
loadingHfModels = loadingHfModels,
scrollBehavior = scrollBehavior,
)
}) { innerPadding ->
val filePickerLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult(
contract = ActivityResultContracts.StartActivityForResult()
) { result ->
if (result.resultCode == android.app.Activity.RESULT_OK) {
result.data?.data?.let { uri ->
selectedLocalModelFileUri.value = uri
showImportDialog = true
} ?: run {
Log.d(TAG, "No file selected or URI is null.")
}
} else {
Log.d(TAG, "File picking cancelled.")
}
}
Scaffold(
modifier = modifier.nestedScroll(scrollBehavior.nestedScrollConnection),
topBar = {
GalleryTopAppBar(
title = stringResource(HomeScreenDestination.titleRes),
rightAction = AppBarAction(
actionType = AppBarActionType.APP_SETTING, actionFn = {
showSettingsDialog = true
}
),
loadingHfModels = loadingHfModels,
scrollBehavior = scrollBehavior,
)
},
floatingActionButton = {
// A floating action button to show "import model" bottom sheet.
SmallFloatingActionButton(
onClick = {
showImportModelSheet = true
},
containerColor = MaterialTheme.colorScheme.secondaryContainer,
contentColor = MaterialTheme.colorScheme.secondary,
) {
Icon(Icons.Filled.Add, "")
}
}
) { innerPadding ->
TaskList(
tasks = tasks,
navigateToTaskScreen = navigateToTaskScreen,
@ -132,6 +197,83 @@ fun HomeScreen(
},
)
}
// Import model bottom sheet.
if (showImportModelSheet) {
ModalBottomSheet(
onDismissRequest = { showImportModelSheet = false },
sheetState = sheetState,
) {
Text(
"Import model",
style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(vertical = 4.dp, horizontal = 16.dp)
)
Box(modifier = Modifier.clickable {
coroutineScope.launch {
// Give it sometime to show the click effect.
delay(200)
showImportModelSheet = false
// Show file picker.
val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
addCategory(Intent.CATEGORY_OPENABLE)
type = "*/*"
putExtra(
Intent.EXTRA_MIME_TYPES,
arrayOf("application/x-binary", "application/octet-stream")
)
// Single select.
putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
}
filePickerLauncher.launch(intent)
}
}) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp),
modifier = Modifier
.fillMaxWidth()
.padding(16.dp)
) {
Icon(Icons.AutoMirrored.Outlined.NoteAdd, contentDescription = "")
Text("From local model file")
}
}
}
}
// Import dialog
if (showImportDialog) {
selectedLocalModelFileUri.value?.let { uri ->
ModelImportDialog(uri = uri,
onDismiss = { showImportDialog = false },
onDone = { info ->
selectedImportedModelInfo.value = info
showImportDialog = false
showImportingDialog = true
})
}
}
// Importing in progress dialog.
if (showImportingDialog) {
selectedLocalModelFileUri.value?.let { uri ->
selectedImportedModelInfo.value?.let { info ->
ModelImportingDialog(
uri = uri,
info = info,
onDismiss = { showImportingDialog = false },
onDone = {
modelManagerViewModel.addImportedLlmModel(
task = TASK_LLM_CHAT,
info = it,
)
showImportingDialog = false
})
}
}
}
}
@Composable
@ -150,7 +292,7 @@ private fun TaskList(
verticalArrangement = Arrangement.spacedBy(8.dp),
) {
// Headline.
item(span = { GridItemSpan(2) }) {
item(key = "headline", span = { GridItemSpan(2) }) {
Text(
"Welcome to AI Edge Gallery! Explore a world of \namazing on-device models from LiteRT community",
textAlign = TextAlign.Center,
@ -171,6 +313,11 @@ private fun TaskList(
.aspectRatio(1f)
)
}
// Bottom padding.
item(key = "bottomPadding", span = { GridItemSpan(2) }) {
Spacer(modifier = Modifier.height(60.dp))
}
}
// Gradient overlay at the bottom.
@ -190,6 +337,48 @@ private fun TaskList(
@Composable
private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modifier) {
// Observes the model count and updates the model count label with a fade-in/fade-out animation
// whenever the count changes.
val modelCount by remember {
derivedStateOf {
val trigger = task.updateTrigger.value
if (trigger >= 0) {
task.models.size
} else {
0
}
}
}
val modelCountLabel by remember {
derivedStateOf {
when (modelCount) {
1 -> "1 Model"
else -> "%d Models".format(modelCount)
}
}
}
var curModelCountLabel by remember { mutableStateOf("") }
var modelCountLabelVisible by remember { mutableStateOf(true) }
val modelCountAlpha: Float by animateFloatAsState(
targetValue = if (modelCountLabelVisible) 1f else 0f,
animationSpec = tween(durationMillis = TASK_COUNT_ANIMATION_DURATION)
)
val modelCountScale: Float by animateFloatAsState(
targetValue = if (modelCountLabelVisible) 1f else 0.7f,
animationSpec = tween(durationMillis = TASK_COUNT_ANIMATION_DURATION)
)
LaunchedEffect(modelCountLabel) {
if (curModelCountLabel.isEmpty()) {
curModelCountLabel = modelCountLabel
} else {
modelCountLabelVisible = false
delay(TASK_COUNT_ANIMATION_DURATION.toLong())
curModelCountLabel = modelCountLabel
modelCountLabelVisible = true
}
}
Card(
modifier = modifier
.clip(RoundedCornerShape(43.5.dp))
@ -238,14 +427,13 @@ private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modif
}
// Model count.
val modelCountLabel = when (task.models.size) {
1 -> "1 Model"
else -> "%d Models".format(task.models.size)
}
Text(
modelCountLabel,
curModelCountLabel,
color = MaterialTheme.colorScheme.secondary,
style = MaterialTheme.typography.bodyMedium
style = MaterialTheme.typography.bodyMedium,
modifier = Modifier
.alpha(modelCountAlpha)
.scale(modelCountScale),
)
}
}

View file

@ -14,7 +14,7 @@
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.modelmanager
package com.google.aiedge.gallery.ui.home
import android.content.Context
import android.net.Uri
@ -36,24 +36,40 @@ import androidx.compose.material3.Icon
import androidx.compose.material3.LinearProgressIndicator
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableFloatStateOf
import androidx.compose.runtime.mutableLongStateOf
import androidx.compose.runtime.mutableStateMapOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.runtime.snapshots.SnapshotStateMap
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog
import androidx.compose.ui.window.DialogProperties
import com.google.aiedge.gallery.data.Accelerator
import com.google.aiedge.gallery.data.Config
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.IMPORTS_DIR
import com.google.aiedge.gallery.data.LabelConfig
import com.google.aiedge.gallery.data.ImportedModelInfo
import com.google.aiedge.gallery.data.NumberSliderConfig
import com.google.aiedge.gallery.data.SegmentedButtonConfig
import com.google.aiedge.gallery.data.ValueType
import com.google.aiedge.gallery.ui.common.chat.ConfigEditorsPanel
import com.google.aiedge.gallery.ui.common.ensureValidFileName
import com.google.aiedge.gallery.ui.common.humanReadableSize
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_MAX_TOKEN
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TEMPERATURE
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TOPK
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TOPP
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
@ -64,37 +80,151 @@ import java.nio.charset.StandardCharsets
private const val TAG = "AGModelImportDialog"
data class ModelImportInfo(val fileName: String, val fileSize: Long, val error: String = "")
private val IMPORT_CONFIGS_LLM: List<Config> = listOf(
LabelConfig(key = ConfigKey.NAME),
LabelConfig(key = ConfigKey.MODEL_TYPE),
NumberSliderConfig(
key = ConfigKey.DEFAULT_MAX_TOKENS,
sliderMin = 100f,
sliderMax = 1024f,
defaultValue = DEFAULT_MAX_TOKEN.toFloat(),
valueType = ValueType.INT
),
NumberSliderConfig(
key = ConfigKey.DEFAULT_TOPK,
sliderMin = 5f,
sliderMax = 40f,
defaultValue = DEFAULT_TOPK.toFloat(),
valueType = ValueType.INT
),
NumberSliderConfig(
key = ConfigKey.DEFAULT_TOPP,
sliderMin = 0.0f,
sliderMax = 1.0f,
defaultValue = DEFAULT_TOPP,
valueType = ValueType.FLOAT
),
NumberSliderConfig(
key = ConfigKey.DEFAULT_TEMPERATURE,
sliderMin = 0.0f,
sliderMax = 2.0f,
defaultValue = DEFAULT_TEMPERATURE,
valueType = ValueType.FLOAT
),
SegmentedButtonConfig(
key = ConfigKey.COMPATIBLE_ACCELERATORS,
defaultValue = Accelerator.CPU.label,
options = listOf(Accelerator.CPU.label, Accelerator.GPU.label),
allowMultiple = true,
)
)
@Composable
fun ModelImportDialog(
uri: Uri, onDone: (ModelImportInfo) -> Unit
uri: Uri,
onDismiss: () -> Unit,
onDone: (ImportedModelInfo) -> Unit
) {
val context = LocalContext.current
val coroutineScope = rememberCoroutineScope()
val info = remember { getFileSizeAndDisplayNameFromUri(context = context, uri = uri) }
val fileSize by remember { mutableLongStateOf(info.first) }
val fileName by remember { mutableStateOf(ensureValidFileName(info.second)) }
var fileName by remember { mutableStateOf("") }
var fileSize by remember { mutableLongStateOf(0L) }
val initialValues: Map<String, Any> = remember {
mutableMapOf<String, Any>().apply {
for (config in IMPORT_CONFIGS_LLM) {
put(config.key.label, config.defaultValue)
}
put(ConfigKey.NAME.label, fileName)
// TODO: support other types.
put(ConfigKey.MODEL_TYPE.label, "LLM")
}
}
val values: SnapshotStateMap<String, Any> = remember {
mutableStateMapOf<String, Any>().apply {
putAll(initialValues)
}
}
Dialog(
onDismissRequest = onDismiss,
) {
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
Column(
modifier = Modifier
.padding(20.dp),
verticalArrangement = Arrangement.spacedBy(16.dp)
) {
// Title.
Text(
"Import Model",
style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(bottom = 8.dp)
)
// Default configs for users to set.
ConfigEditorsPanel(
configs = IMPORT_CONFIGS_LLM,
values = values,
)
// Button row.
Row(
modifier = Modifier
.fillMaxWidth()
.padding(top = 8.dp),
horizontalArrangement = Arrangement.End,
) {
// Cancel button.
TextButton(
onClick = { onDismiss() },
) {
Text("Cancel")
}
// Import button
Button(
onClick = {
onDone(
ImportedModelInfo(
fileName = fileName,
fileSize = fileSize,
defaultValues = values,
)
)
},
) {
Text("Import")
}
}
}
}
}
}
@Composable
fun ModelImportingDialog(
uri: Uri,
info: ImportedModelInfo,
onDismiss: () -> Unit,
onDone: (ImportedModelInfo) -> Unit
) {
var error by remember { mutableStateOf("") }
val context = LocalContext.current
val coroutineScope = rememberCoroutineScope()
var progress by remember { mutableFloatStateOf(0f) }
LaunchedEffect(Unit) {
error = ""
// Get basic info.
val info = getFileSizeAndDisplayNameFromUri(context = context, uri = uri)
fileSize = info.first
fileName = ensureValidFileName(info.second)
// Import.
importModel(
context = context,
coroutineScope = coroutineScope,
fileName = fileName,
fileSize = fileSize,
fileName = info.fileName,
fileSize = info.fileSize,
uri = uri,
onDone = {
onDone(ModelImportInfo(fileName = fileName, fileSize = fileSize, error = error))
onDone(info)
},
onProgress = {
progress = it
@ -107,7 +237,7 @@ fun ModelImportDialog(
Dialog(
properties = DialogProperties(dismissOnBackPress = false, dismissOnClickOutside = false),
onDismissRequest = {},
onDismissRequest = onDismiss,
) {
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
Column(
@ -117,7 +247,7 @@ fun ModelImportDialog(
) {
// Title.
Text(
"Importing...",
"Import Model",
style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(bottom = 8.dp)
)
@ -127,7 +257,7 @@ fun ModelImportDialog(
// Progress bar.
Column(verticalArrangement = Arrangement.spacedBy(4.dp)) {
Text(
"$fileName (${fileSize.humanReadableSize()})",
"${info.fileName} (${info.fileSize.humanReadableSize()})",
style = MaterialTheme.typography.labelSmall,
)
val animatedProgress = remember { Animatable(0f) }
@ -162,7 +292,7 @@ fun ModelImportDialog(
}
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.End) {
Button(onClick = {
onDone(ModelImportInfo(fileName = "", fileSize = 0L, error = error))
onDismiss()
}) {
Text("Close")
}

View file

@ -30,7 +30,7 @@ private val CONFIGS: List<Config> = listOf(
SegmentedButtonConfig(
key = ConfigKey.THEME,
defaultValue = THEME_AUTO,
options = listOf(THEME_AUTO, THEME_LIGHT, THEME_DARK)
options = listOf(THEME_AUTO, THEME_LIGHT, THEME_DARK),
)
)

View file

@ -16,24 +16,25 @@
package com.google.aiedge.gallery.ui.llmchat
import com.google.aiedge.gallery.data.Accelerator
import com.google.aiedge.gallery.data.Config
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.ConfigValue
import com.google.aiedge.gallery.data.NumberSliderConfig
import com.google.aiedge.gallery.data.SegmentedButtonConfig
import com.google.aiedge.gallery.data.ValueType
import com.google.aiedge.gallery.data.getFloatConfigValue
import com.google.aiedge.gallery.data.getIntConfigValue
private const val DEFAULT_MAX_TOKEN = 1024
private const val DEFAULT_TOPK = 40
private const val DEFAULT_TOPP = 0.9f
private const val DEFAULT_TEMPERATURE = 1.0f
const val DEFAULT_MAX_TOKEN = 1024
const val DEFAULT_TOPK = 40
const val DEFAULT_TOPP = 0.9f
const val DEFAULT_TEMPERATURE = 1.0f
val DEFAULT_ACCELERATORS = listOf(Accelerator.GPU)
fun createLlmChatConfigs(
defaultMaxToken: Int = DEFAULT_MAX_TOKEN,
defaultTopK: Int = DEFAULT_TOPK,
defaultTopP: Float = DEFAULT_TOPP,
defaultTemperature: Float = DEFAULT_TEMPERATURE
defaultTemperature: Float = DEFAULT_TEMPERATURE,
accelerators: List<Accelerator> = DEFAULT_ACCELERATORS,
): List<Config> {
return listOf(
NumberSliderConfig(
@ -64,21 +65,10 @@ fun createLlmChatConfigs(
defaultValue = defaultTemperature,
valueType = ValueType.FLOAT
),
)
}
fun createLLmChatConfig(defaults: Map<String, ConfigValue>): List<Config> {
val defaultMaxToken =
getIntConfigValue(defaults[ConfigKey.MAX_TOKENS.id], default = DEFAULT_MAX_TOKEN)
val defaultTopK = getIntConfigValue(defaults[ConfigKey.TOPK.id], default = DEFAULT_TOPK)
val defaultTopP = getFloatConfigValue(defaults[ConfigKey.TOPP.id], default = DEFAULT_TOPP)
val defaultTemperature =
getFloatConfigValue(defaults[ConfigKey.TEMPERATURE.id], default = DEFAULT_TEMPERATURE)
return createLlmChatConfigs(
defaultMaxToken = defaultMaxToken,
defaultTopK = defaultTopK,
defaultTopP = defaultTopP,
defaultTemperature = defaultTemperature
SegmentedButtonConfig(
key = ConfigKey.ACCELERATOR,
defaultValue = if (accelerators.contains(Accelerator.GPU)) Accelerator.GPU.label else accelerators[0].label,
options = accelerators.map { it.label }
)
)
}

View file

@ -18,18 +18,14 @@ package com.google.aiedge.gallery.ui.llmchat
import android.content.Context
import android.util.Log
import com.google.aiedge.gallery.data.Accelerator
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.LlmBackend
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage
import com.google.mediapipe.tasks.genai.llminference.LlmInference
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
private const val TAG = "AGLlmChatModelHelper"
private const val DEFAULT_MAX_TOKEN = 1024
private const val DEFAULT_TOPK = 40
private const val DEFAULT_TOPP = 0.9f
private const val DEFAULT_TEMPERATURE = 1.0f
typealias ResultListener = (partialResult: String, done: Boolean) -> Unit
typealias CleanUpListener = () -> Unit
@ -49,10 +45,13 @@ object LlmChatModelHelper {
val topP = model.getFloatConfigValue(key = ConfigKey.TOPP, defaultValue = DEFAULT_TOPP)
val temperature =
model.getFloatConfigValue(key = ConfigKey.TEMPERATURE, defaultValue = DEFAULT_TEMPERATURE)
val accelerator =
model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = Accelerator.GPU.label)
Log.d(TAG, "Initializing...")
val preferredBackend = when (model.llmBackend) {
LlmBackend.CPU -> LlmInference.Backend.CPU
LlmBackend.GPU -> LlmInference.Backend.GPU
val preferredBackend = when (accelerator) {
Accelerator.CPU.label -> LlmInference.Backend.CPU
Accelerator.GPU.label -> LlmInference.Backend.GPU
else -> LlmInference.Backend.GPU
}
val options =
LlmInference.LlmInferenceOptions.builder().setModelPath(model.getPath(context = context))

View file

@ -16,13 +16,7 @@
package com.google.aiedge.gallery.ui.modelmanager
import android.content.Intent
import android.net.Uri
import android.os.Build
import android.util.Log
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.ActivityResultLauncher
import androidx.activity.result.contract.ActivityResultContracts
import androidx.annotation.RequiresApi
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
@ -30,32 +24,21 @@ import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.PaddingValues
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.outlined.NoteAdd
import androidx.compose.material.icons.filled.Add
import androidx.compose.material.icons.outlined.Code
import androidx.compose.material.icons.outlined.Description
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ModalBottomSheet
import androidx.compose.material3.SmallFloatingActionButton
import androidx.compose.material3.Text
import androidx.compose.material3.rememberModalBottomSheetState
import androidx.compose.runtime.Composable
import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.vector.ImageVector
@ -75,14 +58,11 @@ import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.customColors
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
private const val TAG = "AGModelList"
/** The list of models in the model manager. */
@RequiresApi(Build.VERSION_CODES.O)
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun ModelList(
task: Task,
@ -91,50 +71,29 @@ fun ModelList(
onModelClicked: (Model) -> Unit,
modifier: Modifier = Modifier,
) {
var showAddModelSheet by remember { mutableStateOf(false) }
var showImportingDialog by remember { mutableStateOf(false) }
val curFileUri = remember { mutableStateOf<Uri?>(null) }
val sheetState = rememberModalBottomSheetState()
val coroutineScope = rememberCoroutineScope()
// This is just to update "models" list when task.updateTrigger is updated so that the UI can
// be properly updated.
val models by remember {
derivedStateOf {
val trigger = task.updateTrigger.value
if (trigger >= 0) {
task.models.toList().filter { !it.isLocalModel }
task.models.toList().filter { !it.imported }
} else {
listOf()
}
}
}
val localModels by remember {
val importedModels by remember {
derivedStateOf {
val trigger = task.updateTrigger.value
if (trigger >= 0) {
task.models.toList().filter { it.isLocalModel }
task.models.toList().filter { it.imported }
} else {
listOf()
}
}
}
val filePickerLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult(
contract = ActivityResultContracts.StartActivityForResult()
) { result ->
if (result.resultCode == android.app.Activity.RESULT_OK) {
result.data?.data?.let { uri ->
curFileUri.value = uri
showImportingDialog = true
} ?: run {
Log.d(TAG, "No file selected or URI is null.")
}
} else {
Log.d(TAG, "File picking cancelled.")
}
}
Box(contentAlignment = Alignment.BottomEnd) {
LazyColumn(
modifier = modifier.padding(top = 8.dp),
@ -190,11 +149,11 @@ fun ModelList(
}
}
// Title for local models.
if (localModels.isNotEmpty()) {
item(key = "localModelsTitle") {
// Title for imported models.
if (importedModels.isNotEmpty()) {
item(key = "importedModelsTitle") {
Text(
"Local models",
"Imported models",
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
modifier = Modifier
.padding(horizontal = 16.dp)
@ -203,8 +162,8 @@ fun ModelList(
}
}
// List of local models within a task.
items(items = localModels) { model ->
// List of imported models within a task.
items(items = importedModels) { model ->
Box {
ModelItem(
model = model,
@ -215,88 +174,6 @@ fun ModelList(
)
}
}
item(key = "bottomPadding") {
Spacer(modifier = Modifier.height(60.dp))
}
}
// Add model button at the bottom right.
Box(
modifier = Modifier
.padding(end = 16.dp)
.padding(bottom = contentPadding.calculateBottomPadding())
) {
SmallFloatingActionButton(
onClick = {
showAddModelSheet = true
},
containerColor = MaterialTheme.colorScheme.secondaryContainer,
contentColor = MaterialTheme.colorScheme.secondary,
) {
Icon(Icons.Filled.Add, "")
}
}
}
if (showAddModelSheet) {
ModalBottomSheet(
onDismissRequest = { showAddModelSheet = false },
sheetState = sheetState,
) {
Text(
"Add custom model",
style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(vertical = 4.dp, horizontal = 16.dp)
)
Box(modifier = Modifier.clickable {
coroutineScope.launch {
// Give it sometime to show the click effect.
delay(200)
showAddModelSheet = false
// Show file picker.
val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
addCategory(Intent.CATEGORY_OPENABLE)
type = "*/*"
putExtra(
Intent.EXTRA_MIME_TYPES,
arrayOf("application/x-binary", "application/octet-stream")
)
// Single select.
putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
}
filePickerLauncher.launch(intent)
}
}) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp),
modifier = Modifier
.fillMaxWidth()
.padding(16.dp)
) {
Icon(Icons.AutoMirrored.Outlined.NoteAdd, contentDescription = "")
Text("Add local model")
}
}
}
}
if (showImportingDialog) {
curFileUri.value?.let { uri ->
ModelImportDialog(uri = uri, onDone = { info ->
showImportingDialog = false
if (info.error.isEmpty()) {
// TODO: support other model types.
modelManagerViewModel.addLocalLlmModel(
task = task,
fileName = info.fileName,
fileSize = info.fileSize
)
}
})
}
}
}

View file

@ -23,8 +23,10 @@ import androidx.activity.result.ActivityResult
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.google.aiedge.gallery.data.AGWorkInfo
import com.google.aiedge.gallery.data.Accelerator
import com.google.aiedge.gallery.data.AccessTokenData
import com.google.aiedge.gallery.data.Config
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.DataStoreRepository
import com.google.aiedge.gallery.data.DownloadRepository
import com.google.aiedge.gallery.data.EMPTY_MODEL
@ -32,7 +34,7 @@ import com.google.aiedge.gallery.data.HfModel
import com.google.aiedge.gallery.data.HfModelDetails
import com.google.aiedge.gallery.data.HfModelSummary
import com.google.aiedge.gallery.data.IMPORTS_DIR
import com.google.aiedge.gallery.data.LocalModelInfo
import com.google.aiedge.gallery.data.ImportedModelInfo
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatus
import com.google.aiedge.gallery.data.ModelDownloadStatusType
@ -40,12 +42,14 @@ import com.google.aiedge.gallery.data.TASKS
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.TaskType
import com.google.aiedge.gallery.data.ValueType
import com.google.aiedge.gallery.data.getModelByName
import com.google.aiedge.gallery.ui.common.AuthConfig
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
import com.google.aiedge.gallery.ui.llmchat.createLLmChatConfig
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
import com.google.aiedge.gallery.ui.textclassification.TextClassificationModelHelper
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
@ -228,7 +232,7 @@ open class ModelManagerViewModel(
ModelDownloadStatus(status = ModelDownloadStatusType.NOT_DOWNLOADED)
// Delete model from the list if model is imported as a local model.
if (model.isLocalModel) {
if (model.imported) {
val index = task.models.indexOf(model)
if (index >= 0) {
task.models.removeAt(index)
@ -237,12 +241,12 @@ open class ModelManagerViewModel(
curModelDownloadStatus.remove(model.name)
// Update preference.
val localModels = dataStoreRepository.readLocalModels().toMutableList()
val localModelIndex = localModels.indexOfFirst { it.fileName == model.name }
if (localModelIndex >= 0) {
localModels.removeAt(localModelIndex)
val importedModels = dataStoreRepository.readImportedModels().toMutableList()
val importedModelIndex = importedModels.indexOfFirst { it.fileName == model.name }
if (importedModelIndex >= 0) {
importedModels.removeAt(importedModelIndex)
}
dataStoreRepository.saveLocalModels(localModels = localModels)
dataStoreRepository.saveImportedModels(importedModels = importedModels)
}
val newUiState = uiState.value.copy(modelDownloadStatus = curModelDownloadStatus)
_uiState.update { newUiState }
@ -417,27 +421,20 @@ open class ModelManagerViewModel(
return connection.responseCode
}
fun addLocalLlmModel(task: Task, fileName: String, fileSize: Long) {
Log.d(TAG, "adding local model: $fileName, $fileSize")
fun addImportedLlmModel(task: Task, info: ImportedModelInfo) {
Log.d(TAG, "adding imported llm model: $info")
// Create model.
val configs: List<Config> = createLLmChatConfig(defaults = mapOf())
val model = Model(
name = fileName,
url = "",
configs = configs,
sizeInBytes = fileSize,
downloadFileName = "$IMPORTS_DIR/$fileName",
isLocalModel = true,
)
model.preProcess(task = task)
val model = createModelFromImportedModelInfo(info = info, task = task)
task.models.add(model)
// Add initial status and states.
val modelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap()
val modelInstances = uiState.value.modelInitializationStatus.toMutableMap()
modelDownloadStatus[model.name] = ModelDownloadStatus(
status = ModelDownloadStatusType.SUCCEEDED, receivedBytes = fileSize, totalBytes = fileSize
status = ModelDownloadStatusType.SUCCEEDED,
receivedBytes = info.fileSize,
totalBytes = info.fileSize
)
modelInstances[model.name] =
ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED)
@ -453,9 +450,9 @@ open class ModelManagerViewModel(
task.updateTrigger.value = System.currentTimeMillis()
// Add to preference storage.
val localModels = dataStoreRepository.readLocalModels().toMutableList()
localModels.add(LocalModelInfo(fileName = fileName, fileSize = fileSize))
dataStoreRepository.saveLocalModels(localModels = localModels)
val importedModels = dataStoreRepository.readImportedModels().toMutableList()
importedModels.add(info)
dataStoreRepository.saveImportedModels(importedModels = importedModels)
}
fun getTokenStatusAndData(): TokenStatusAndData {
@ -589,31 +586,22 @@ open class ModelManagerViewModel(
}
}
// Load local models.
for (localModel in dataStoreRepository.readLocalModels()) {
Log.d(TAG, "stored local model: $localModel")
// Load imported models.
for (importedModel in dataStoreRepository.readImportedModels()) {
Log.d(TAG, "stored imported model: $importedModel")
// Create model.
val configs: List<Config> = createLLmChatConfig(defaults = mapOf())
val model = Model(
name = localModel.fileName,
url = "",
configs = configs,
sizeInBytes = localModel.fileSize,
downloadFileName = "$IMPORTS_DIR/${localModel.fileName}",
isLocalModel = true,
)
val model = createModelFromImportedModelInfo(info = importedModel, task = TASK_LLM_CHAT)
// Add to task.
val task = TASK_LLM_CHAT
model.preProcess(task = task)
task.models.add(model)
// Update status.
modelDownloadStatus[model.name] = ModelDownloadStatus(
status = ModelDownloadStatusType.SUCCEEDED,
receivedBytes = localModel.fileSize,
totalBytes = localModel.fileSize
receivedBytes = importedModel.fileSize,
totalBytes = importedModel.fileSize
)
}
@ -628,6 +616,51 @@ open class ModelManagerViewModel(
)
}
private fun createModelFromImportedModelInfo(info: ImportedModelInfo, task: Task): Model {
val accelerators: List<Accelerator> = (convertValueToTargetType(
info.defaultValues[ConfigKey.COMPATIBLE_ACCELERATORS.label]!!,
ValueType.STRING
) as String)
.split(",")
.mapNotNull { acceleratorLabel ->
when (acceleratorLabel.trim()) {
Accelerator.GPU.label -> Accelerator.GPU
Accelerator.CPU.label -> Accelerator.CPU
else -> null // Ignore unknown accelerator labels
}
}
val configs: List<Config> = createLlmChatConfigs(
defaultMaxToken = convertValueToTargetType(
info.defaultValues[ConfigKey.DEFAULT_MAX_TOKENS.label]!!,
ValueType.INT
) as Int,
defaultTopK = convertValueToTargetType(
info.defaultValues[ConfigKey.DEFAULT_TOPK.label]!!,
ValueType.INT
) as Int,
defaultTopP = convertValueToTargetType(
info.defaultValues[ConfigKey.DEFAULT_TOPP.label]!!,
ValueType.FLOAT
) as Float,
defaultTemperature = convertValueToTargetType(
info.defaultValues[ConfigKey.DEFAULT_TEMPERATURE.label]!!,
ValueType.FLOAT
) as Float,
accelerators = accelerators,
)
val model = Model(
name = info.fileName,
url = "",
configs = configs,
sizeInBytes = info.fileSize,
downloadFileName = "$IMPORTS_DIR/${info.fileName}",
imported = true,
)
model.preProcess(task = task)
return model
}
/**
* Retrieves the download status of a model.
*
@ -771,9 +804,7 @@ open class ModelManagerViewModel(
}
private fun updateModelInitializationStatus(
model: Model,
status: ModelInitializationStatusType,
error: String = ""
model: Model, status: ModelInitializationStatusType, error: String = ""
) {
val curModelInstance = uiState.value.modelInitializationStatus.toMutableMap()
curModelInstance[model.name] = ModelInitializationStatus(status = status, error = error)

View file

@ -18,7 +18,7 @@ package com.google.aiedge.gallery.ui.preview
import com.google.aiedge.gallery.data.AccessTokenData
import com.google.aiedge.gallery.data.DataStoreRepository
import com.google.aiedge.gallery.data.LocalModelInfo
import com.google.aiedge.gallery.data.ImportedModelInfo
class PreviewDataStoreRepository : DataStoreRepository {
override fun saveTextInputHistory(history: List<String>) {
@ -42,10 +42,10 @@ class PreviewDataStoreRepository : DataStoreRepository {
return null
}
override fun saveLocalModels(localModels: List<LocalModelInfo>) {
override fun saveImportedModels(importedModels: List<ImportedModelInfo>) {
}
override fun readLocalModels(): List<LocalModelInfo> {
override fun readImportedModels(): List<ImportedModelInfo> {
return listOf()
}
}

View file

@ -22,6 +22,7 @@ import androidx.compose.material.icons.rounded.AutoAwesome
import com.google.aiedge.gallery.data.BooleanSwitchConfig
import com.google.aiedge.gallery.data.Config
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.LabelConfig
import com.google.aiedge.gallery.data.SegmentedButtonConfig
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.NumberSliderConfig
@ -30,6 +31,10 @@ import com.google.aiedge.gallery.data.TaskType
import com.google.aiedge.gallery.data.ValueType
val TEST_CONFIGS1: List<Config> = listOf(
LabelConfig(
key = ConfigKey.NAME,
defaultValue = "Test name",
),
NumberSliderConfig(
key = ConfigKey.MAX_RESULT_COUNT,
sliderMin = 1f,