From 604972fe232778765f5ead07f891a715ad6c1c44 Mon Sep 17 00:00:00 2001 From: Jing Jin <8752427+jinjingforever@users.noreply.github.com> Date: Fri, 18 Apr 2025 23:10:55 -0700 Subject: [PATCH] Make importing model functionality better. - Allow users to specify default parameters before importing. --- .../com/google/aiedge/gallery/data/Config.kt | 16 ++ .../gallery/data/DataStoreRepository.kt | 20 +- .../google/aiedge/gallery/data/HuggingFace.kt | 13 +- .../com/google/aiedge/gallery/data/Model.kt | 83 ++++--- .../gallery/ui/common/chat/ConfigDialog.kt | 107 ++++++--- .../gallery/ui/common/modelitem/ModelItem.kt | 7 +- .../aiedge/gallery/ui/home/HomeScreen.kt | 226 ++++++++++++++++-- .../ModelImportDialog.kt | 170 +++++++++++-- .../aiedge/gallery/ui/home/SettingsDialog.kt | 2 +- .../gallery/ui/llmchat/LlmChatConfigs.kt | 38 ++- .../gallery/ui/llmchat/LlmChatModelHelper.kt | 15 +- .../gallery/ui/modelmanager/ModelList.kt | 141 +---------- .../ui/modelmanager/ModelManagerViewModel.kt | 115 +++++---- .../ui/preview/PreviewDataStoreRepository.kt | 6 +- .../aiedge/gallery/ui/preview/PreviewTasks.kt | 5 + 15 files changed, 635 insertions(+), 329 deletions(-) rename Android/src/app/src/main/java/com/google/aiedge/gallery/ui/{modelmanager => home}/ModelImportDialog.kt (63%) diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Config.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Config.kt index 73814e2..4ebde26 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Config.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Config.kt @@ -23,6 +23,7 @@ package com.google.aiedge.gallery.data * Each type corresponds to a specific editor widget, such as a slider or a switch. */ enum class ConfigEditorType { + LABEL, NUMBER_SLIDER, BOOLEAN_SWITCH, DROPDOWN, @@ -57,6 +58,19 @@ open class Config( open val needReinitialization: Boolean = true, ) +/** + * Configuration setting for a label. + */ +class LabelConfig( + override val key: ConfigKey, + override val defaultValue: String = "", +) : Config( + type = ConfigEditorType.LABEL, + key = key, + defaultValue = defaultValue, + valueType = ValueType.STRING +) + /** * Configuration setting for a number slider. * @@ -99,9 +113,11 @@ class SegmentedButtonConfig( override val key: ConfigKey, override val defaultValue: String, val options: List, + val allowMultiple: Boolean = false, ) : Config( type = ConfigEditorType.DROPDOWN, key = key, defaultValue = defaultValue, + // The emitted value will be comma-separated labels when allowMultiple=true. valueType = ValueType.STRING, ) \ No newline at end of file diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/DataStoreRepository.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/DataStoreRepository.kt index e65173a..d1cc28d 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/DataStoreRepository.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/DataStoreRepository.kt @@ -47,8 +47,8 @@ interface DataStoreRepository { fun readThemeOverride(): String fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long) fun readAccessTokenData(): AccessTokenData? - fun saveLocalModels(localModels: List) - fun readLocalModels(): List + fun saveImportedModels(importedModels: List) + fun readImportedModels(): List } /** @@ -82,8 +82,8 @@ class DefaultDataStoreRepository( val ACCESS_TOKEN_EXPIRES_AT = longPreferencesKey("access_token_expires_at") - // Data for all imported local models. - val LOCAL_MODELS = stringPreferencesKey("local_models") + // Data for all imported models. + val IMPORTED_MODELS = stringPreferencesKey("imported_models") } private val keystoreAlias: String = "com_google_aiedge_gallery_access_token_key" @@ -160,22 +160,22 @@ class DefaultDataStoreRepository( } } - override fun saveLocalModels(localModels: List) { + override fun saveImportedModels(importedModels: List) { runBlocking { dataStore.edit { preferences -> val gson = Gson() - val jsonString = gson.toJson(localModels) - preferences[PreferencesKeys.LOCAL_MODELS] = jsonString + val jsonString = gson.toJson(importedModels) + preferences[PreferencesKeys.IMPORTED_MODELS] = jsonString } } } - override fun readLocalModels(): List { + override fun readImportedModels(): List { return runBlocking { val preferences = dataStore.data.first() - val infosStr = preferences[PreferencesKeys.LOCAL_MODELS] ?: "[]" + val infosStr = preferences[PreferencesKeys.IMPORTED_MODELS] ?: "[]" val gson = Gson() - val listType = object : TypeToken>() {}.type + val listType = object : TypeToken>() {}.type gson.fromJson(infosStr, listType) } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/HuggingFace.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/HuggingFace.kt index 6aa08ea..cf92212 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/HuggingFace.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/HuggingFace.kt @@ -17,7 +17,6 @@ package com.google.aiedge.gallery.data import com.google.aiedge.gallery.ui.common.ensureValidFileName -import com.google.aiedge.gallery.ui.llmchat.createLLmChatConfig import kotlinx.serialization.KSerializer import kotlinx.serialization.Serializable import kotlinx.serialization.SerializationException @@ -107,11 +106,13 @@ data class HfModel( val fileName = ensureValidFileName("${id}_${(parts.lastOrNull() ?: "")}") // Generate configs based on the given default values. - val configs: List = when (task) { - TASK_LLM_CHAT.type.label -> createLLmChatConfig(defaults = configs) - // todo: add configs for other types. - else -> listOf() - } +// val configs: List = when (task) { +// TASK_LLM_CHAT.type.label -> createLLmChatConfig(defaults = configs) +// // todo: add configs for other types. +// else -> listOf() +// } + // todo: fix when loading from models.json + val configs: List = listOf() // Construct url. var modelUrl = url diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt index b72db89..b617ea1 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt @@ -19,6 +19,7 @@ package com.google.aiedge.gallery.data import android.content.Context import com.google.aiedge.gallery.ui.common.chat.PromptTemplate import com.google.aiedge.gallery.ui.common.convertValueToTargetType +import com.google.aiedge.gallery.ui.llmchat.DEFAULT_ACCELERATORS import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs data class ModelDataFile( @@ -28,8 +29,8 @@ data class ModelDataFile( val sizeInBytes: Long, ) -enum class LlmBackend { - CPU, GPU +enum class Accelerator(val label: String) { + CPU(label = "CPU"), GPU(label = "GPU") } const val IMPORTS_DIR = "__imports" @@ -81,14 +82,14 @@ data class Model( /** The name of the directory to unzip the model to (if it's a zip file). */ val unzipDir: String = "", - /** The preferred backend of the model (only for LLM). */ - val llmBackend: LlmBackend = LlmBackend.GPU, + /** The accelerators the the model can run with. */ + val accelerators: List = DEFAULT_ACCELERATORS, /** The prompt templates for the model (only for LLM). */ val llmPromptTemplates: List = listOf(), - /** Whether the model is imported as a local model. */ - val isLocalModel: Boolean = false, + /** Whether the model is imported or not. */ + val imported: Boolean = false, // The following fields are managed by the app. Don't need to set manually. var taskType: TaskType? = null, @@ -135,6 +136,12 @@ data class Model( ) as Boolean } + fun getStringConfigValue(key: ConfigKey, defaultValue: String = ""): String { + return getTypedConfigValue( + key = key, valueType = ValueType.STRING, defaultValue = defaultValue + ) as String + } + fun getExtraDataFile(name: String): ModelDataFile? { return extraDataFiles.find { it.name == name } } @@ -147,7 +154,11 @@ data class Model( } /** Data for a imported local model. */ -data class LocalModelInfo(val fileName: String, val fileSize: Long) +data class ImportedModelInfo( + val fileName: String, + val fileSize: Long, + val defaultValues: Map +) enum class ModelDownloadStatusType { NOT_DOWNLOADED, PARTIALLY_DOWNLOADED, IN_PROGRESS, UNZIPPING, SUCCEEDED, FAILED, @@ -165,29 +176,25 @@ data class ModelDownloadStatus( //////////////////////////////////////////////////////////////////////////////////////////////////// // Configs. -enum class ConfigKey(val label: String, val id: String) { - MAX_TOKENS("Max tokens", id = "max_token"), - TOPK("TopK", id = "topk"), - TOPP( - "TopP", - id = "topp" - ), - TEMPERATURE("Temperature", id = "temperature"), - MAX_RESULT_COUNT( - "Max result count", - id = "max_result_count" - ), - USE_GPU("Use GPU", id = "use_gpu"), - WARM_UP_ITERATIONS( - "Warm up iterations", - id = "warm_up_iterations" - ), - BENCHMARK_ITERATIONS( - "Benchmark iterations", - id = "benchmark_iterations" - ), - ITERATIONS("Iterations", id = "iterations"), - THEME("Theme", id = "theme"), +enum class ConfigKey(val label: String) { + MAX_TOKENS("Max tokens"), + TOPK("TopK"), + TOPP("TopP"), + TEMPERATURE("Temperature"), + DEFAULT_MAX_TOKENS("Default max tokens"), + DEFAULT_TOPK("Default TopK"), + DEFAULT_TOPP("Default TopP"), + DEFAULT_TEMPERATURE("Default temperature"), + MAX_RESULT_COUNT("Max result count"), + USE_GPU("Use GPU"), + ACCELERATOR("Accelerator"), + COMPATIBLE_ACCELERATORS("Compatible accelerators"), + WARM_UP_ITERATIONS("Warm up iterations"), + BENCHMARK_ITERATIONS("Benchmark iterations"), + ITERATIONS("Iterations"), + THEME("Theme"), + NAME("Name"), + MODEL_TYPE("Model type") } val MOBILENET_CONFIGS: List = listOf( @@ -258,7 +265,12 @@ val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model( downloadFileName = "gemma3-1b-it-int4.task", url = "https://huggingface.co/litert-community/Gemma3-1B-IT/resolve/main/gemma3-1b-it-int4.task?download=true", sizeInBytes = 554661243L, - configs = createLlmChatConfigs(defaultTopK = 64, defaultTopP = 0.95f), + accelerators = listOf(Accelerator.CPU, Accelerator.GPU), + configs = createLlmChatConfigs( + defaultTopK = 64, + defaultTopP = 0.95f, + accelerators = listOf(Accelerator.CPU, Accelerator.GPU) + ), info = LLM_CHAT_INFO, learnMoreUrl = "https://huggingface.co/litert-community/Gemma3-1B-IT", llmPromptTemplates = listOf( @@ -280,8 +292,13 @@ val MODEL_LLM_DEEPSEEK: Model = Model( downloadFileName = "deepseek.task", url = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B/resolve/main/deepseek_q8_ekv1280.task?download=true", sizeInBytes = 1860686856L, - llmBackend = LlmBackend.CPU, - configs = createLlmChatConfigs(defaultTemperature = 0.6f, defaultTopK = 40, defaultTopP = 0.7f), + accelerators = listOf(Accelerator.CPU), + configs = createLlmChatConfigs( + defaultTemperature = 0.6f, + defaultTopK = 40, + defaultTopP = 0.7f, + accelerators = listOf(Accelerator.CPU) + ), info = LLM_CHAT_INFO, learnMoreUrl = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B", ) diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ConfigDialog.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ConfigDialog.kt index 84990b3..ed592f7 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ConfigDialog.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ConfigDialog.kt @@ -34,16 +34,15 @@ import androidx.compose.foundation.text.KeyboardOptions import androidx.compose.material3.Button import androidx.compose.material3.Card import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.MultiChoiceSegmentedButtonRow import androidx.compose.material3.SegmentedButton import androidx.compose.material3.SegmentedButtonDefaults -import androidx.compose.material3.SingleChoiceSegmentedButtonRow import androidx.compose.material3.Slider import androidx.compose.material3.Switch import androidx.compose.material3.Text import androidx.compose.material3.TextButton import androidx.compose.runtime.Composable import androidx.compose.runtime.getValue -import androidx.compose.runtime.mutableIntStateOf import androidx.compose.runtime.mutableStateMapOf import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.remember @@ -60,6 +59,7 @@ import androidx.compose.ui.unit.dp import androidx.compose.ui.window.Dialog import com.google.aiedge.gallery.data.BooleanSwitchConfig import com.google.aiedge.gallery.data.Config +import com.google.aiedge.gallery.data.LabelConfig import com.google.aiedge.gallery.data.NumberSliderConfig import com.google.aiedge.gallery.data.SegmentedButtonConfig import com.google.aiedge.gallery.data.ValueType @@ -113,27 +113,10 @@ fun ConfigDialog( } } - // List of config rows. - for (config in configs) { - when (config) { - // Number slider. - is NumberSliderConfig -> { - NumberSliderRow(config = config, values = values) - } + ConfigEditorsPanel(configs = configs, values = values) - // Boolean switch. - is BooleanSwitchConfig -> { - BooleanSwitchRow(config = config, values = values) - } - - is SegmentedButtonConfig -> { - SegmentedButtonRow(config = config, values = values) - } - - else -> {} - } - } + // Button row. Row( modifier = Modifier .fillMaxWidth() @@ -164,6 +147,53 @@ fun ConfigDialog( } } +/** + * Composable function to display a list of config editor rows. + */ +@Composable +fun ConfigEditorsPanel(configs: List, values: SnapshotStateMap) { + for (config in configs) { + when (config) { + // Label. + is LabelConfig -> { + LabelRow(config = config, values = values) + } + + // Number slider. + is NumberSliderConfig -> { + NumberSliderRow(config = config, values = values) + } + + // Boolean switch. + is BooleanSwitchConfig -> { + BooleanSwitchRow(config = config, values = values) + } + + // Segmented button. + is SegmentedButtonConfig -> { + SegmentedButtonRow(config = config, values = values) + } + + else -> {} + } + } +} + +@Composable +fun LabelRow(config: LabelConfig, values: SnapshotStateMap) { + Column(modifier = Modifier.fillMaxWidth()) { + // Field label. + Text(config.key.label, style = MaterialTheme.typography.titleSmall) + // Content label. + val label = try { + values[config.key.label] as String + } catch (e: Exception) { + "" + } + Text(label, style = MaterialTheme.typography.bodyMedium) + } +} + /** * Composable function to display a number slider with an associated text input field. * @@ -272,18 +302,41 @@ fun BooleanSwitchRow(config: BooleanSwitchConfig, values: SnapshotStateMap) { - var selectedIndex by remember { mutableIntStateOf(config.options.indexOf(values[config.key.label])) } + val selectedOptions: List = remember { (values[config.key.label] as String).split(",") } + var selectionStates: List by remember { + mutableStateOf(List(config.options.size) { index -> + selectedOptions.contains(config.options[index]) + }) + } Column(modifier = Modifier.fillMaxWidth()) { Text(config.key.label, style = MaterialTheme.typography.titleSmall) - SingleChoiceSegmentedButtonRow { + MultiChoiceSegmentedButtonRow { config.options.forEachIndexed { index, label -> SegmentedButton(shape = SegmentedButtonDefaults.itemShape( index = index, count = config.options.size - ), onClick = { - selectedIndex = index - values[config.key.label] = label - }, selected = index == selectedIndex, label = { Text(label) }) + ), onCheckedChange = { + var newSelectionStates = selectionStates.toMutableList() + val selectedCount = newSelectionStates.count { it } + + // Single select. + if (!config.allowMultiple) { + if (!newSelectionStates[index]) { + newSelectionStates = MutableList(config.options.size) { it == index } + } + } + // Multiple select. + else { + if (!(selectedCount == 1 && newSelectionStates[index])) { + newSelectionStates[index] = !newSelectionStates[index] + } + } + selectionStates = newSelectionStates + + values[config.key.label] = + config.options.filterIndexed { index, option -> selectionStates[index] } + .joinToString(",") + }, checked = selectionStates[index], label = { Text(label) }) } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItem.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItem.kt index 7e3d18f..00dc4fc 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItem.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/modelitem/ModelItem.kt @@ -92,7 +92,6 @@ private val DEFAULT_VERTICAL_PADDING = 16.dp * model description and buttons for learning more (opening a URL) and downloading/trying * the model. */ -@OptIn(ExperimentalMaterial3Api::class) @Composable fun ModelItem( model: Model, @@ -188,9 +187,9 @@ fun ModelItem( } } else { Icon( - // For local model, show ">" directly indicating users can just tap the model item to + // For imported model, show ">" directly indicating users can just tap the model item to // go into it without needing to expand it first. - if (model.isLocalModel) Icons.Rounded.ChevronRight else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore, + if (model.imported) Icons.Rounded.ChevronRight else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore, contentDescription = "", tint = getTaskIconColor(task), ) @@ -272,7 +271,7 @@ fun ModelItem( boxModifier = if (canExpand) { boxModifier.clickable( onClick = { - if (!model.isLocalModel) { + if (!model.imported) { isExpanded = !isExpanded } else { onModelClicked(model) diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/HomeScreen.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/HomeScreen.kt index 78bf841..23921fd 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/HomeScreen.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/HomeScreen.kt @@ -16,13 +16,22 @@ package com.google.aiedge.gallery.ui.home +import android.content.Intent +import android.net.Uri +import android.util.Log +import androidx.activity.compose.rememberLauncherForActivityResult +import androidx.activity.result.ActivityResultLauncher +import androidx.activity.result.contract.ActivityResultContracts import androidx.annotation.StringRes +import androidx.compose.animation.core.animateFloatAsState +import androidx.compose.animation.core.tween import androidx.compose.foundation.background import androidx.compose.foundation.clickable import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.PaddingValues +import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.aspectRatio import androidx.compose.foundation.layout.fillMaxSize @@ -34,22 +43,34 @@ import androidx.compose.foundation.lazy.grid.GridItemSpan import androidx.compose.foundation.lazy.grid.LazyVerticalGrid import androidx.compose.foundation.lazy.grid.items import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.automirrored.outlined.NoteAdd +import androidx.compose.material.icons.filled.Add import androidx.compose.material3.Card import androidx.compose.material3.CardDefaults import androidx.compose.material3.ExperimentalMaterial3Api +import androidx.compose.material3.Icon import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.ModalBottomSheet import androidx.compose.material3.Scaffold +import androidx.compose.material3.SmallFloatingActionButton import androidx.compose.material3.Text import androidx.compose.material3.TopAppBarDefaults +import androidx.compose.material3.rememberModalBottomSheetState import androidx.compose.runtime.Composable +import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.collectAsState +import androidx.compose.runtime.derivedStateOf import androidx.compose.runtime.getValue import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.remember +import androidx.compose.runtime.rememberCoroutineScope import androidx.compose.runtime.setValue import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.alpha import androidx.compose.ui.draw.clip +import androidx.compose.ui.draw.scale import androidx.compose.ui.graphics.Brush import androidx.compose.ui.input.nestedscroll.nestedScroll import androidx.compose.ui.layout.layout @@ -66,6 +87,8 @@ import com.google.aiedge.gallery.R import com.google.aiedge.gallery.data.AppBarAction import com.google.aiedge.gallery.data.AppBarActionType import com.google.aiedge.gallery.data.ConfigKey +import com.google.aiedge.gallery.data.ImportedModelInfo +import com.google.aiedge.gallery.data.TASK_LLM_CHAT import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.ui.common.TaskIcon import com.google.aiedge.gallery.ui.common.getTaskBgColor @@ -75,6 +98,11 @@ import com.google.aiedge.gallery.ui.theme.GalleryTheme import com.google.aiedge.gallery.ui.theme.ThemeSettings import com.google.aiedge.gallery.ui.theme.customColors import com.google.aiedge.gallery.ui.theme.titleMediumNarrow +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch + +private const val TAG = "AGHomeScreen" +private const val TASK_COUNT_ANIMATION_DURATION = 250 /** Navigation destination data */ object HomeScreenDestination { @@ -92,22 +120,59 @@ fun HomeScreen( val scrollBehavior = TopAppBarDefaults.pinnedScrollBehavior() val uiState by modelManagerViewModel.uiState.collectAsState() var showSettingsDialog by remember { mutableStateOf(false) } + var showImportModelSheet by remember { mutableStateOf(false) } + val sheetState = rememberModalBottomSheetState() + var showImportDialog by remember { mutableStateOf(false) } + var showImportingDialog by remember { mutableStateOf(false) } + val selectedLocalModelFileUri = remember { mutableStateOf(null) } + val selectedImportedModelInfo = remember { mutableStateOf(null) } + val coroutineScope = rememberCoroutineScope() val tasks = uiState.tasks val loadingHfModels = uiState.loadingHfModels - Scaffold(modifier = modifier.nestedScroll(scrollBehavior.nestedScrollConnection), topBar = { - GalleryTopAppBar( - title = stringResource(HomeScreenDestination.titleRes), - rightAction = AppBarAction( - actionType = AppBarActionType.APP_SETTING, actionFn = { - showSettingsDialog = true - } - ), - loadingHfModels = loadingHfModels, - scrollBehavior = scrollBehavior, - ) - }) { innerPadding -> + val filePickerLauncher: ActivityResultLauncher = rememberLauncherForActivityResult( + contract = ActivityResultContracts.StartActivityForResult() + ) { result -> + if (result.resultCode == android.app.Activity.RESULT_OK) { + result.data?.data?.let { uri -> + selectedLocalModelFileUri.value = uri + showImportDialog = true + } ?: run { + Log.d(TAG, "No file selected or URI is null.") + } + } else { + Log.d(TAG, "File picking cancelled.") + } + } + + Scaffold( + modifier = modifier.nestedScroll(scrollBehavior.nestedScrollConnection), + topBar = { + GalleryTopAppBar( + title = stringResource(HomeScreenDestination.titleRes), + rightAction = AppBarAction( + actionType = AppBarActionType.APP_SETTING, actionFn = { + showSettingsDialog = true + } + ), + loadingHfModels = loadingHfModels, + scrollBehavior = scrollBehavior, + ) + }, + floatingActionButton = { + // A floating action button to show "import model" bottom sheet. + SmallFloatingActionButton( + onClick = { + showImportModelSheet = true + }, + containerColor = MaterialTheme.colorScheme.secondaryContainer, + contentColor = MaterialTheme.colorScheme.secondary, + ) { + Icon(Icons.Filled.Add, "") + } + } + ) { innerPadding -> TaskList( tasks = tasks, navigateToTaskScreen = navigateToTaskScreen, @@ -132,6 +197,83 @@ fun HomeScreen( }, ) } + + // Import model bottom sheet. + if (showImportModelSheet) { + ModalBottomSheet( + onDismissRequest = { showImportModelSheet = false }, + sheetState = sheetState, + ) { + Text( + "Import model", + style = MaterialTheme.typography.titleLarge, + modifier = Modifier.padding(vertical = 4.dp, horizontal = 16.dp) + ) + Box(modifier = Modifier.clickable { + coroutineScope.launch { + // Give it sometime to show the click effect. + delay(200) + showImportModelSheet = false + + // Show file picker. + val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply { + addCategory(Intent.CATEGORY_OPENABLE) + type = "*/*" + putExtra( + Intent.EXTRA_MIME_TYPES, + arrayOf("application/x-binary", "application/octet-stream") + ) + // Single select. + putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false) + } + filePickerLauncher.launch(intent) + } + }) { + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(6.dp), + modifier = Modifier + .fillMaxWidth() + .padding(16.dp) + ) { + Icon(Icons.AutoMirrored.Outlined.NoteAdd, contentDescription = "") + Text("From local model file") + } + } + } + } + + // Import dialog + if (showImportDialog) { + selectedLocalModelFileUri.value?.let { uri -> + ModelImportDialog(uri = uri, + onDismiss = { showImportDialog = false }, + onDone = { info -> + selectedImportedModelInfo.value = info + showImportDialog = false + showImportingDialog = true + }) + } + } + + // Importing in progress dialog. + if (showImportingDialog) { + selectedLocalModelFileUri.value?.let { uri -> + selectedImportedModelInfo.value?.let { info -> + ModelImportingDialog( + uri = uri, + info = info, + onDismiss = { showImportingDialog = false }, + onDone = { + modelManagerViewModel.addImportedLlmModel( + task = TASK_LLM_CHAT, + info = it, + ) + showImportingDialog = false + }) + } + } + } } @Composable @@ -150,7 +292,7 @@ private fun TaskList( verticalArrangement = Arrangement.spacedBy(8.dp), ) { // Headline. - item(span = { GridItemSpan(2) }) { + item(key = "headline", span = { GridItemSpan(2) }) { Text( "Welcome to AI Edge Gallery! Explore a world of \namazing on-device models from LiteRT community", textAlign = TextAlign.Center, @@ -171,6 +313,11 @@ private fun TaskList( .aspectRatio(1f) ) } + + // Bottom padding. + item(key = "bottomPadding", span = { GridItemSpan(2) }) { + Spacer(modifier = Modifier.height(60.dp)) + } } // Gradient overlay at the bottom. @@ -190,6 +337,48 @@ private fun TaskList( @Composable private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modifier) { + // Observes the model count and updates the model count label with a fade-in/fade-out animation + // whenever the count changes. + val modelCount by remember { + derivedStateOf { + val trigger = task.updateTrigger.value + if (trigger >= 0) { + task.models.size + } else { + 0 + } + } + } + val modelCountLabel by remember { + derivedStateOf { + when (modelCount) { + 1 -> "1 Model" + else -> "%d Models".format(modelCount) + } + } + } + var curModelCountLabel by remember { mutableStateOf("") } + var modelCountLabelVisible by remember { mutableStateOf(true) } + val modelCountAlpha: Float by animateFloatAsState( + targetValue = if (modelCountLabelVisible) 1f else 0f, + animationSpec = tween(durationMillis = TASK_COUNT_ANIMATION_DURATION) + ) + val modelCountScale: Float by animateFloatAsState( + targetValue = if (modelCountLabelVisible) 1f else 0.7f, + animationSpec = tween(durationMillis = TASK_COUNT_ANIMATION_DURATION) + ) + + LaunchedEffect(modelCountLabel) { + if (curModelCountLabel.isEmpty()) { + curModelCountLabel = modelCountLabel + } else { + modelCountLabelVisible = false + delay(TASK_COUNT_ANIMATION_DURATION.toLong()) + curModelCountLabel = modelCountLabel + modelCountLabelVisible = true + } + } + Card( modifier = modifier .clip(RoundedCornerShape(43.5.dp)) @@ -238,14 +427,13 @@ private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modif } // Model count. - val modelCountLabel = when (task.models.size) { - 1 -> "1 Model" - else -> "%d Models".format(task.models.size) - } Text( - modelCountLabel, + curModelCountLabel, color = MaterialTheme.colorScheme.secondary, - style = MaterialTheme.typography.bodyMedium + style = MaterialTheme.typography.bodyMedium, + modifier = Modifier + .alpha(modelCountAlpha) + .scale(modelCountScale), ) } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelImportDialog.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/ModelImportDialog.kt similarity index 63% rename from Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelImportDialog.kt rename to Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/ModelImportDialog.kt index d8b26f9..aefbbe9 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelImportDialog.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/ModelImportDialog.kt @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.google.aiedge.gallery.ui.modelmanager +package com.google.aiedge.gallery.ui.home import android.content.Context import android.net.Uri @@ -36,24 +36,40 @@ import androidx.compose.material3.Icon import androidx.compose.material3.LinearProgressIndicator import androidx.compose.material3.MaterialTheme import androidx.compose.material3.Text +import androidx.compose.material3.TextButton import androidx.compose.runtime.Composable import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.getValue import androidx.compose.runtime.mutableFloatStateOf import androidx.compose.runtime.mutableLongStateOf +import androidx.compose.runtime.mutableStateMapOf import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.remember import androidx.compose.runtime.rememberCoroutineScope import androidx.compose.runtime.setValue +import androidx.compose.runtime.snapshots.SnapshotStateMap import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.unit.dp import androidx.compose.ui.window.Dialog import androidx.compose.ui.window.DialogProperties +import com.google.aiedge.gallery.data.Accelerator +import com.google.aiedge.gallery.data.Config +import com.google.aiedge.gallery.data.ConfigKey import com.google.aiedge.gallery.data.IMPORTS_DIR +import com.google.aiedge.gallery.data.LabelConfig +import com.google.aiedge.gallery.data.ImportedModelInfo +import com.google.aiedge.gallery.data.NumberSliderConfig +import com.google.aiedge.gallery.data.SegmentedButtonConfig +import com.google.aiedge.gallery.data.ValueType +import com.google.aiedge.gallery.ui.common.chat.ConfigEditorsPanel import com.google.aiedge.gallery.ui.common.ensureValidFileName import com.google.aiedge.gallery.ui.common.humanReadableSize +import com.google.aiedge.gallery.ui.llmchat.DEFAULT_MAX_TOKEN +import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TEMPERATURE +import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TOPK +import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TOPP import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.launch @@ -64,37 +80,151 @@ import java.nio.charset.StandardCharsets private const val TAG = "AGModelImportDialog" -data class ModelImportInfo(val fileName: String, val fileSize: Long, val error: String = "") +private val IMPORT_CONFIGS_LLM: List = listOf( + LabelConfig(key = ConfigKey.NAME), + LabelConfig(key = ConfigKey.MODEL_TYPE), + NumberSliderConfig( + key = ConfigKey.DEFAULT_MAX_TOKENS, + sliderMin = 100f, + sliderMax = 1024f, + defaultValue = DEFAULT_MAX_TOKEN.toFloat(), + valueType = ValueType.INT + ), + NumberSliderConfig( + key = ConfigKey.DEFAULT_TOPK, + sliderMin = 5f, + sliderMax = 40f, + defaultValue = DEFAULT_TOPK.toFloat(), + valueType = ValueType.INT + ), + NumberSliderConfig( + key = ConfigKey.DEFAULT_TOPP, + sliderMin = 0.0f, + sliderMax = 1.0f, + defaultValue = DEFAULT_TOPP, + valueType = ValueType.FLOAT + ), + NumberSliderConfig( + key = ConfigKey.DEFAULT_TEMPERATURE, + sliderMin = 0.0f, + sliderMax = 2.0f, + defaultValue = DEFAULT_TEMPERATURE, + valueType = ValueType.FLOAT + ), + SegmentedButtonConfig( + key = ConfigKey.COMPATIBLE_ACCELERATORS, + defaultValue = Accelerator.CPU.label, + options = listOf(Accelerator.CPU.label, Accelerator.GPU.label), + allowMultiple = true, + ) +) @Composable fun ModelImportDialog( - uri: Uri, onDone: (ModelImportInfo) -> Unit + uri: Uri, + onDismiss: () -> Unit, + onDone: (ImportedModelInfo) -> Unit ) { val context = LocalContext.current - val coroutineScope = rememberCoroutineScope() + val info = remember { getFileSizeAndDisplayNameFromUri(context = context, uri = uri) } + val fileSize by remember { mutableLongStateOf(info.first) } + val fileName by remember { mutableStateOf(ensureValidFileName(info.second)) } - var fileName by remember { mutableStateOf("") } - var fileSize by remember { mutableLongStateOf(0L) } + val initialValues: Map = remember { + mutableMapOf().apply { + for (config in IMPORT_CONFIGS_LLM) { + put(config.key.label, config.defaultValue) + } + put(ConfigKey.NAME.label, fileName) + // TODO: support other types. + put(ConfigKey.MODEL_TYPE.label, "LLM") + } + } + val values: SnapshotStateMap = remember { + mutableStateMapOf().apply { + putAll(initialValues) + } + } + + Dialog( + onDismissRequest = onDismiss, + ) { + Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) { + Column( + modifier = Modifier + .padding(20.dp), + verticalArrangement = Arrangement.spacedBy(16.dp) + ) { + // Title. + Text( + "Import Model", + style = MaterialTheme.typography.titleLarge, + modifier = Modifier.padding(bottom = 8.dp) + ) + + // Default configs for users to set. + ConfigEditorsPanel( + configs = IMPORT_CONFIGS_LLM, + values = values, + ) + + // Button row. + Row( + modifier = Modifier + .fillMaxWidth() + .padding(top = 8.dp), + horizontalArrangement = Arrangement.End, + ) { + // Cancel button. + TextButton( + onClick = { onDismiss() }, + ) { + Text("Cancel") + } + + // Import button + Button( + onClick = { + onDone( + ImportedModelInfo( + fileName = fileName, + fileSize = fileSize, + defaultValues = values, + ) + ) + }, + ) { + Text("Import") + } + } + + } + } + } +} + +@Composable +fun ModelImportingDialog( + uri: Uri, + info: ImportedModelInfo, + onDismiss: () -> Unit, + onDone: (ImportedModelInfo) -> Unit +) { var error by remember { mutableStateOf("") } + val context = LocalContext.current + val coroutineScope = rememberCoroutineScope() var progress by remember { mutableFloatStateOf(0f) } LaunchedEffect(Unit) { - error = "" - - // Get basic info. - val info = getFileSizeAndDisplayNameFromUri(context = context, uri = uri) - fileSize = info.first - fileName = ensureValidFileName(info.second) - // Import. importModel( context = context, coroutineScope = coroutineScope, - fileName = fileName, - fileSize = fileSize, + fileName = info.fileName, + fileSize = info.fileSize, uri = uri, onDone = { - onDone(ModelImportInfo(fileName = fileName, fileSize = fileSize, error = error)) + onDone(info) }, onProgress = { progress = it @@ -107,7 +237,7 @@ fun ModelImportDialog( Dialog( properties = DialogProperties(dismissOnBackPress = false, dismissOnClickOutside = false), - onDismissRequest = {}, + onDismissRequest = onDismiss, ) { Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) { Column( @@ -117,7 +247,7 @@ fun ModelImportDialog( ) { // Title. Text( - "Importing...", + "Import Model", style = MaterialTheme.typography.titleLarge, modifier = Modifier.padding(bottom = 8.dp) ) @@ -127,7 +257,7 @@ fun ModelImportDialog( // Progress bar. Column(verticalArrangement = Arrangement.spacedBy(4.dp)) { Text( - "$fileName (${fileSize.humanReadableSize()})", + "${info.fileName} (${info.fileSize.humanReadableSize()})", style = MaterialTheme.typography.labelSmall, ) val animatedProgress = remember { Animatable(0f) } @@ -162,7 +292,7 @@ fun ModelImportDialog( } Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.End) { Button(onClick = { - onDone(ModelImportInfo(fileName = "", fileSize = 0L, error = error)) + onDismiss() }) { Text("Close") } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/SettingsDialog.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/SettingsDialog.kt index 5d111fc..f8b4607 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/SettingsDialog.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/SettingsDialog.kt @@ -30,7 +30,7 @@ private val CONFIGS: List = listOf( SegmentedButtonConfig( key = ConfigKey.THEME, defaultValue = THEME_AUTO, - options = listOf(THEME_AUTO, THEME_LIGHT, THEME_DARK) + options = listOf(THEME_AUTO, THEME_LIGHT, THEME_DARK), ) ) diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatConfigs.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatConfigs.kt index 788a5c6..58759fe 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatConfigs.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatConfigs.kt @@ -16,24 +16,25 @@ package com.google.aiedge.gallery.ui.llmchat +import com.google.aiedge.gallery.data.Accelerator import com.google.aiedge.gallery.data.Config import com.google.aiedge.gallery.data.ConfigKey -import com.google.aiedge.gallery.data.ConfigValue import com.google.aiedge.gallery.data.NumberSliderConfig +import com.google.aiedge.gallery.data.SegmentedButtonConfig import com.google.aiedge.gallery.data.ValueType -import com.google.aiedge.gallery.data.getFloatConfigValue -import com.google.aiedge.gallery.data.getIntConfigValue -private const val DEFAULT_MAX_TOKEN = 1024 -private const val DEFAULT_TOPK = 40 -private const val DEFAULT_TOPP = 0.9f -private const val DEFAULT_TEMPERATURE = 1.0f +const val DEFAULT_MAX_TOKEN = 1024 +const val DEFAULT_TOPK = 40 +const val DEFAULT_TOPP = 0.9f +const val DEFAULT_TEMPERATURE = 1.0f +val DEFAULT_ACCELERATORS = listOf(Accelerator.GPU) fun createLlmChatConfigs( defaultMaxToken: Int = DEFAULT_MAX_TOKEN, defaultTopK: Int = DEFAULT_TOPK, defaultTopP: Float = DEFAULT_TOPP, - defaultTemperature: Float = DEFAULT_TEMPERATURE + defaultTemperature: Float = DEFAULT_TEMPERATURE, + accelerators: List = DEFAULT_ACCELERATORS, ): List { return listOf( NumberSliderConfig( @@ -64,21 +65,10 @@ fun createLlmChatConfigs( defaultValue = defaultTemperature, valueType = ValueType.FLOAT ), - ) -} - -fun createLLmChatConfig(defaults: Map): List { - val defaultMaxToken = - getIntConfigValue(defaults[ConfigKey.MAX_TOKENS.id], default = DEFAULT_MAX_TOKEN) - val defaultTopK = getIntConfigValue(defaults[ConfigKey.TOPK.id], default = DEFAULT_TOPK) - val defaultTopP = getFloatConfigValue(defaults[ConfigKey.TOPP.id], default = DEFAULT_TOPP) - val defaultTemperature = - getFloatConfigValue(defaults[ConfigKey.TEMPERATURE.id], default = DEFAULT_TEMPERATURE) - - return createLlmChatConfigs( - defaultMaxToken = defaultMaxToken, - defaultTopK = defaultTopK, - defaultTopP = defaultTopP, - defaultTemperature = defaultTemperature + SegmentedButtonConfig( + key = ConfigKey.ACCELERATOR, + defaultValue = if (accelerators.contains(Accelerator.GPU)) Accelerator.GPU.label else accelerators[0].label, + options = accelerators.map { it.label } + ) ) } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt index aa01ac1..fe2be1b 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt @@ -18,18 +18,14 @@ package com.google.aiedge.gallery.ui.llmchat import android.content.Context import android.util.Log +import com.google.aiedge.gallery.data.Accelerator import com.google.aiedge.gallery.data.ConfigKey -import com.google.aiedge.gallery.data.LlmBackend import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage import com.google.mediapipe.tasks.genai.llminference.LlmInference import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession private const val TAG = "AGLlmChatModelHelper" -private const val DEFAULT_MAX_TOKEN = 1024 -private const val DEFAULT_TOPK = 40 -private const val DEFAULT_TOPP = 0.9f -private const val DEFAULT_TEMPERATURE = 1.0f typealias ResultListener = (partialResult: String, done: Boolean) -> Unit typealias CleanUpListener = () -> Unit @@ -49,10 +45,13 @@ object LlmChatModelHelper { val topP = model.getFloatConfigValue(key = ConfigKey.TOPP, defaultValue = DEFAULT_TOPP) val temperature = model.getFloatConfigValue(key = ConfigKey.TEMPERATURE, defaultValue = DEFAULT_TEMPERATURE) + val accelerator = + model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = Accelerator.GPU.label) Log.d(TAG, "Initializing...") - val preferredBackend = when (model.llmBackend) { - LlmBackend.CPU -> LlmInference.Backend.CPU - LlmBackend.GPU -> LlmInference.Backend.GPU + val preferredBackend = when (accelerator) { + Accelerator.CPU.label -> LlmInference.Backend.CPU + Accelerator.GPU.label -> LlmInference.Backend.GPU + else -> LlmInference.Backend.GPU } val options = LlmInference.LlmInferenceOptions.builder().setModelPath(model.getPath(context = context)) diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelList.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelList.kt index ba1f8f3..60dc32c 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelList.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelList.kt @@ -16,13 +16,7 @@ package com.google.aiedge.gallery.ui.modelmanager -import android.content.Intent -import android.net.Uri import android.os.Build -import android.util.Log -import androidx.activity.compose.rememberLauncherForActivityResult -import androidx.activity.result.ActivityResultLauncher -import androidx.activity.result.contract.ActivityResultContracts import androidx.annotation.RequiresApi import androidx.compose.foundation.clickable import androidx.compose.foundation.layout.Arrangement @@ -30,32 +24,21 @@ import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.PaddingValues import androidx.compose.foundation.layout.Row -import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.fillMaxWidth -import androidx.compose.foundation.layout.height import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.size import androidx.compose.foundation.lazy.LazyColumn import androidx.compose.foundation.lazy.items import androidx.compose.material.icons.Icons -import androidx.compose.material.icons.automirrored.outlined.NoteAdd -import androidx.compose.material.icons.filled.Add import androidx.compose.material.icons.outlined.Code import androidx.compose.material.icons.outlined.Description -import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.Icon import androidx.compose.material3.MaterialTheme -import androidx.compose.material3.ModalBottomSheet -import androidx.compose.material3.SmallFloatingActionButton import androidx.compose.material3.Text -import androidx.compose.material3.rememberModalBottomSheetState import androidx.compose.runtime.Composable import androidx.compose.runtime.derivedStateOf import androidx.compose.runtime.getValue -import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.remember -import androidx.compose.runtime.rememberCoroutineScope -import androidx.compose.runtime.setValue import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.graphics.vector.ImageVector @@ -75,14 +58,11 @@ import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel import com.google.aiedge.gallery.ui.preview.TASK_TEST1 import com.google.aiedge.gallery.ui.theme.GalleryTheme import com.google.aiedge.gallery.ui.theme.customColors -import kotlinx.coroutines.delay -import kotlinx.coroutines.launch private const val TAG = "AGModelList" /** The list of models in the model manager. */ @RequiresApi(Build.VERSION_CODES.O) -@OptIn(ExperimentalMaterial3Api::class) @Composable fun ModelList( task: Task, @@ -91,50 +71,29 @@ fun ModelList( onModelClicked: (Model) -> Unit, modifier: Modifier = Modifier, ) { - var showAddModelSheet by remember { mutableStateOf(false) } - var showImportingDialog by remember { mutableStateOf(false) } - val curFileUri = remember { mutableStateOf(null) } - val sheetState = rememberModalBottomSheetState() - val coroutineScope = rememberCoroutineScope() - // This is just to update "models" list when task.updateTrigger is updated so that the UI can // be properly updated. val models by remember { derivedStateOf { val trigger = task.updateTrigger.value if (trigger >= 0) { - task.models.toList().filter { !it.isLocalModel } + task.models.toList().filter { !it.imported } } else { listOf() } } } - val localModels by remember { + val importedModels by remember { derivedStateOf { val trigger = task.updateTrigger.value if (trigger >= 0) { - task.models.toList().filter { it.isLocalModel } + task.models.toList().filter { it.imported } } else { listOf() } } } - val filePickerLauncher: ActivityResultLauncher = rememberLauncherForActivityResult( - contract = ActivityResultContracts.StartActivityForResult() - ) { result -> - if (result.resultCode == android.app.Activity.RESULT_OK) { - result.data?.data?.let { uri -> - curFileUri.value = uri - showImportingDialog = true - } ?: run { - Log.d(TAG, "No file selected or URI is null.") - } - } else { - Log.d(TAG, "File picking cancelled.") - } - } - Box(contentAlignment = Alignment.BottomEnd) { LazyColumn( modifier = modifier.padding(top = 8.dp), @@ -190,11 +149,11 @@ fun ModelList( } } - // Title for local models. - if (localModels.isNotEmpty()) { - item(key = "localModelsTitle") { + // Title for imported models. + if (importedModels.isNotEmpty()) { + item(key = "importedModelsTitle") { Text( - "Local models", + "Imported models", style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold), modifier = Modifier .padding(horizontal = 16.dp) @@ -203,8 +162,8 @@ fun ModelList( } } - // List of local models within a task. - items(items = localModels) { model -> + // List of imported models within a task. + items(items = importedModels) { model -> Box { ModelItem( model = model, @@ -215,88 +174,6 @@ fun ModelList( ) } } - - item(key = "bottomPadding") { - Spacer(modifier = Modifier.height(60.dp)) - } - } - - // Add model button at the bottom right. - Box( - modifier = Modifier - .padding(end = 16.dp) - .padding(bottom = contentPadding.calculateBottomPadding()) - ) { - SmallFloatingActionButton( - onClick = { - showAddModelSheet = true - }, - containerColor = MaterialTheme.colorScheme.secondaryContainer, - contentColor = MaterialTheme.colorScheme.secondary, - ) { - Icon(Icons.Filled.Add, "") - } - } - } - - if (showAddModelSheet) { - ModalBottomSheet( - onDismissRequest = { showAddModelSheet = false }, - sheetState = sheetState, - ) { - Text( - "Add custom model", - style = MaterialTheme.typography.titleLarge, - modifier = Modifier.padding(vertical = 4.dp, horizontal = 16.dp) - ) - Box(modifier = Modifier.clickable { - coroutineScope.launch { - // Give it sometime to show the click effect. - delay(200) - showAddModelSheet = false - - // Show file picker. - val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply { - addCategory(Intent.CATEGORY_OPENABLE) - type = "*/*" - putExtra( - Intent.EXTRA_MIME_TYPES, - arrayOf("application/x-binary", "application/octet-stream") - ) - // Single select. - putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false) - } - filePickerLauncher.launch(intent) - } - }) { - Row( - verticalAlignment = Alignment.CenterVertically, - horizontalArrangement = Arrangement.spacedBy(6.dp), - modifier = Modifier - .fillMaxWidth() - .padding(16.dp) - ) { - Icon(Icons.AutoMirrored.Outlined.NoteAdd, contentDescription = "") - Text("Add local model") - } - } - } - } - - if (showImportingDialog) { - curFileUri.value?.let { uri -> - ModelImportDialog(uri = uri, onDone = { info -> - showImportingDialog = false - - if (info.error.isEmpty()) { - // TODO: support other model types. - modelManagerViewModel.addLocalLlmModel( - task = task, - fileName = info.fileName, - fileSize = info.fileSize - ) - } - }) } } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt index c4d3411..d82add5 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt @@ -23,8 +23,10 @@ import androidx.activity.result.ActivityResult import androidx.lifecycle.ViewModel import androidx.lifecycle.viewModelScope import com.google.aiedge.gallery.data.AGWorkInfo +import com.google.aiedge.gallery.data.Accelerator import com.google.aiedge.gallery.data.AccessTokenData import com.google.aiedge.gallery.data.Config +import com.google.aiedge.gallery.data.ConfigKey import com.google.aiedge.gallery.data.DataStoreRepository import com.google.aiedge.gallery.data.DownloadRepository import com.google.aiedge.gallery.data.EMPTY_MODEL @@ -32,7 +34,7 @@ import com.google.aiedge.gallery.data.HfModel import com.google.aiedge.gallery.data.HfModelDetails import com.google.aiedge.gallery.data.HfModelSummary import com.google.aiedge.gallery.data.IMPORTS_DIR -import com.google.aiedge.gallery.data.LocalModelInfo +import com.google.aiedge.gallery.data.ImportedModelInfo import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.ModelDownloadStatus import com.google.aiedge.gallery.data.ModelDownloadStatusType @@ -40,12 +42,14 @@ import com.google.aiedge.gallery.data.TASKS import com.google.aiedge.gallery.data.TASK_LLM_CHAT import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.TaskType +import com.google.aiedge.gallery.data.ValueType import com.google.aiedge.gallery.data.getModelByName import com.google.aiedge.gallery.ui.common.AuthConfig +import com.google.aiedge.gallery.ui.common.convertValueToTargetType import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper -import com.google.aiedge.gallery.ui.llmchat.createLLmChatConfig +import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs import com.google.aiedge.gallery.ui.textclassification.TextClassificationModelHelper import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.async @@ -228,7 +232,7 @@ open class ModelManagerViewModel( ModelDownloadStatus(status = ModelDownloadStatusType.NOT_DOWNLOADED) // Delete model from the list if model is imported as a local model. - if (model.isLocalModel) { + if (model.imported) { val index = task.models.indexOf(model) if (index >= 0) { task.models.removeAt(index) @@ -237,12 +241,12 @@ open class ModelManagerViewModel( curModelDownloadStatus.remove(model.name) // Update preference. - val localModels = dataStoreRepository.readLocalModels().toMutableList() - val localModelIndex = localModels.indexOfFirst { it.fileName == model.name } - if (localModelIndex >= 0) { - localModels.removeAt(localModelIndex) + val importedModels = dataStoreRepository.readImportedModels().toMutableList() + val importedModelIndex = importedModels.indexOfFirst { it.fileName == model.name } + if (importedModelIndex >= 0) { + importedModels.removeAt(importedModelIndex) } - dataStoreRepository.saveLocalModels(localModels = localModels) + dataStoreRepository.saveImportedModels(importedModels = importedModels) } val newUiState = uiState.value.copy(modelDownloadStatus = curModelDownloadStatus) _uiState.update { newUiState } @@ -417,27 +421,20 @@ open class ModelManagerViewModel( return connection.responseCode } - fun addLocalLlmModel(task: Task, fileName: String, fileSize: Long) { - Log.d(TAG, "adding local model: $fileName, $fileSize") + fun addImportedLlmModel(task: Task, info: ImportedModelInfo) { + Log.d(TAG, "adding imported llm model: $info") // Create model. - val configs: List = createLLmChatConfig(defaults = mapOf()) - val model = Model( - name = fileName, - url = "", - configs = configs, - sizeInBytes = fileSize, - downloadFileName = "$IMPORTS_DIR/$fileName", - isLocalModel = true, - ) - model.preProcess(task = task) + val model = createModelFromImportedModelInfo(info = info, task = task) task.models.add(model) // Add initial status and states. val modelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap() val modelInstances = uiState.value.modelInitializationStatus.toMutableMap() modelDownloadStatus[model.name] = ModelDownloadStatus( - status = ModelDownloadStatusType.SUCCEEDED, receivedBytes = fileSize, totalBytes = fileSize + status = ModelDownloadStatusType.SUCCEEDED, + receivedBytes = info.fileSize, + totalBytes = info.fileSize ) modelInstances[model.name] = ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED) @@ -453,9 +450,9 @@ open class ModelManagerViewModel( task.updateTrigger.value = System.currentTimeMillis() // Add to preference storage. - val localModels = dataStoreRepository.readLocalModels().toMutableList() - localModels.add(LocalModelInfo(fileName = fileName, fileSize = fileSize)) - dataStoreRepository.saveLocalModels(localModels = localModels) + val importedModels = dataStoreRepository.readImportedModels().toMutableList() + importedModels.add(info) + dataStoreRepository.saveImportedModels(importedModels = importedModels) } fun getTokenStatusAndData(): TokenStatusAndData { @@ -589,31 +586,22 @@ open class ModelManagerViewModel( } } - // Load local models. - for (localModel in dataStoreRepository.readLocalModels()) { - Log.d(TAG, "stored local model: $localModel") + // Load imported models. + for (importedModel in dataStoreRepository.readImportedModels()) { + Log.d(TAG, "stored imported model: $importedModel") // Create model. - val configs: List = createLLmChatConfig(defaults = mapOf()) - val model = Model( - name = localModel.fileName, - url = "", - configs = configs, - sizeInBytes = localModel.fileSize, - downloadFileName = "$IMPORTS_DIR/${localModel.fileName}", - isLocalModel = true, - ) + val model = createModelFromImportedModelInfo(info = importedModel, task = TASK_LLM_CHAT) // Add to task. val task = TASK_LLM_CHAT - model.preProcess(task = task) task.models.add(model) // Update status. modelDownloadStatus[model.name] = ModelDownloadStatus( status = ModelDownloadStatusType.SUCCEEDED, - receivedBytes = localModel.fileSize, - totalBytes = localModel.fileSize + receivedBytes = importedModel.fileSize, + totalBytes = importedModel.fileSize ) } @@ -628,6 +616,51 @@ open class ModelManagerViewModel( ) } + private fun createModelFromImportedModelInfo(info: ImportedModelInfo, task: Task): Model { + val accelerators: List = (convertValueToTargetType( + info.defaultValues[ConfigKey.COMPATIBLE_ACCELERATORS.label]!!, + ValueType.STRING + ) as String) + .split(",") + .mapNotNull { acceleratorLabel -> + when (acceleratorLabel.trim()) { + Accelerator.GPU.label -> Accelerator.GPU + Accelerator.CPU.label -> Accelerator.CPU + else -> null // Ignore unknown accelerator labels + } + } + val configs: List = createLlmChatConfigs( + defaultMaxToken = convertValueToTargetType( + info.defaultValues[ConfigKey.DEFAULT_MAX_TOKENS.label]!!, + ValueType.INT + ) as Int, + defaultTopK = convertValueToTargetType( + info.defaultValues[ConfigKey.DEFAULT_TOPK.label]!!, + ValueType.INT + ) as Int, + defaultTopP = convertValueToTargetType( + info.defaultValues[ConfigKey.DEFAULT_TOPP.label]!!, + ValueType.FLOAT + ) as Float, + defaultTemperature = convertValueToTargetType( + info.defaultValues[ConfigKey.DEFAULT_TEMPERATURE.label]!!, + ValueType.FLOAT + ) as Float, + accelerators = accelerators, + ) + val model = Model( + name = info.fileName, + url = "", + configs = configs, + sizeInBytes = info.fileSize, + downloadFileName = "$IMPORTS_DIR/${info.fileName}", + imported = true, + ) + model.preProcess(task = task) + + return model + } + /** * Retrieves the download status of a model. * @@ -771,9 +804,7 @@ open class ModelManagerViewModel( } private fun updateModelInitializationStatus( - model: Model, - status: ModelInitializationStatusType, - error: String = "" + model: Model, status: ModelInitializationStatusType, error: String = "" ) { val curModelInstance = uiState.value.modelInitializationStatus.toMutableMap() curModelInstance[model.name] = ModelInitializationStatus(status = status, error = error) diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/preview/PreviewDataStoreRepository.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/preview/PreviewDataStoreRepository.kt index d6941c7..606132a 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/preview/PreviewDataStoreRepository.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/preview/PreviewDataStoreRepository.kt @@ -18,7 +18,7 @@ package com.google.aiedge.gallery.ui.preview import com.google.aiedge.gallery.data.AccessTokenData import com.google.aiedge.gallery.data.DataStoreRepository -import com.google.aiedge.gallery.data.LocalModelInfo +import com.google.aiedge.gallery.data.ImportedModelInfo class PreviewDataStoreRepository : DataStoreRepository { override fun saveTextInputHistory(history: List) { @@ -42,10 +42,10 @@ class PreviewDataStoreRepository : DataStoreRepository { return null } - override fun saveLocalModels(localModels: List) { + override fun saveImportedModels(importedModels: List) { } - override fun readLocalModels(): List { + override fun readImportedModels(): List { return listOf() } } \ No newline at end of file diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/preview/PreviewTasks.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/preview/PreviewTasks.kt index eec2aa2..80e6010 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/preview/PreviewTasks.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/preview/PreviewTasks.kt @@ -22,6 +22,7 @@ import androidx.compose.material.icons.rounded.AutoAwesome import com.google.aiedge.gallery.data.BooleanSwitchConfig import com.google.aiedge.gallery.data.Config import com.google.aiedge.gallery.data.ConfigKey +import com.google.aiedge.gallery.data.LabelConfig import com.google.aiedge.gallery.data.SegmentedButtonConfig import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.NumberSliderConfig @@ -30,6 +31,10 @@ import com.google.aiedge.gallery.data.TaskType import com.google.aiedge.gallery.data.ValueType val TEST_CONFIGS1: List = listOf( + LabelConfig( + key = ConfigKey.NAME, + defaultValue = "Test name", + ), NumberSliderConfig( key = ConfigKey.MAX_RESULT_COUNT, sliderMin = 1f,