Make importing model functionality better.

- Allow users to specify default parameters before importing.
This commit is contained in:
Jing Jin 2025-04-18 23:10:55 -07:00
parent 29b614355e
commit 604972fe23
15 changed files with 635 additions and 329 deletions

View file

@ -23,6 +23,7 @@ package com.google.aiedge.gallery.data
* Each type corresponds to a specific editor widget, such as a slider or a switch. * Each type corresponds to a specific editor widget, such as a slider or a switch.
*/ */
enum class ConfigEditorType { enum class ConfigEditorType {
LABEL,
NUMBER_SLIDER, NUMBER_SLIDER,
BOOLEAN_SWITCH, BOOLEAN_SWITCH,
DROPDOWN, DROPDOWN,
@ -57,6 +58,19 @@ open class Config(
open val needReinitialization: Boolean = true, open val needReinitialization: Boolean = true,
) )
/**
* Configuration setting for a label.
*/
class LabelConfig(
override val key: ConfigKey,
override val defaultValue: String = "",
) : Config(
type = ConfigEditorType.LABEL,
key = key,
defaultValue = defaultValue,
valueType = ValueType.STRING
)
/** /**
* Configuration setting for a number slider. * Configuration setting for a number slider.
* *
@ -99,9 +113,11 @@ class SegmentedButtonConfig(
override val key: ConfigKey, override val key: ConfigKey,
override val defaultValue: String, override val defaultValue: String,
val options: List<String>, val options: List<String>,
val allowMultiple: Boolean = false,
) : Config( ) : Config(
type = ConfigEditorType.DROPDOWN, type = ConfigEditorType.DROPDOWN,
key = key, key = key,
defaultValue = defaultValue, defaultValue = defaultValue,
// The emitted value will be comma-separated labels when allowMultiple=true.
valueType = ValueType.STRING, valueType = ValueType.STRING,
) )

View file

@ -47,8 +47,8 @@ interface DataStoreRepository {
fun readThemeOverride(): String fun readThemeOverride(): String
fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long) fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long)
fun readAccessTokenData(): AccessTokenData? fun readAccessTokenData(): AccessTokenData?
fun saveLocalModels(localModels: List<LocalModelInfo>) fun saveImportedModels(importedModels: List<ImportedModelInfo>)
fun readLocalModels(): List<LocalModelInfo> fun readImportedModels(): List<ImportedModelInfo>
} }
/** /**
@ -82,8 +82,8 @@ class DefaultDataStoreRepository(
val ACCESS_TOKEN_EXPIRES_AT = longPreferencesKey("access_token_expires_at") val ACCESS_TOKEN_EXPIRES_AT = longPreferencesKey("access_token_expires_at")
// Data for all imported local models. // Data for all imported models.
val LOCAL_MODELS = stringPreferencesKey("local_models") val IMPORTED_MODELS = stringPreferencesKey("imported_models")
} }
private val keystoreAlias: String = "com_google_aiedge_gallery_access_token_key" private val keystoreAlias: String = "com_google_aiedge_gallery_access_token_key"
@ -160,22 +160,22 @@ class DefaultDataStoreRepository(
} }
} }
override fun saveLocalModels(localModels: List<LocalModelInfo>) { override fun saveImportedModels(importedModels: List<ImportedModelInfo>) {
runBlocking { runBlocking {
dataStore.edit { preferences -> dataStore.edit { preferences ->
val gson = Gson() val gson = Gson()
val jsonString = gson.toJson(localModels) val jsonString = gson.toJson(importedModels)
preferences[PreferencesKeys.LOCAL_MODELS] = jsonString preferences[PreferencesKeys.IMPORTED_MODELS] = jsonString
} }
} }
} }
override fun readLocalModels(): List<LocalModelInfo> { override fun readImportedModels(): List<ImportedModelInfo> {
return runBlocking { return runBlocking {
val preferences = dataStore.data.first() val preferences = dataStore.data.first()
val infosStr = preferences[PreferencesKeys.LOCAL_MODELS] ?: "[]" val infosStr = preferences[PreferencesKeys.IMPORTED_MODELS] ?: "[]"
val gson = Gson() val gson = Gson()
val listType = object : TypeToken<List<LocalModelInfo>>() {}.type val listType = object : TypeToken<List<ImportedModelInfo>>() {}.type
gson.fromJson(infosStr, listType) gson.fromJson(infosStr, listType)
} }
} }

View file

@ -17,7 +17,6 @@
package com.google.aiedge.gallery.data package com.google.aiedge.gallery.data
import com.google.aiedge.gallery.ui.common.ensureValidFileName import com.google.aiedge.gallery.ui.common.ensureValidFileName
import com.google.aiedge.gallery.ui.llmchat.createLLmChatConfig
import kotlinx.serialization.KSerializer import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable import kotlinx.serialization.Serializable
import kotlinx.serialization.SerializationException import kotlinx.serialization.SerializationException
@ -107,11 +106,13 @@ data class HfModel(
val fileName = ensureValidFileName("${id}_${(parts.lastOrNull() ?: "")}") val fileName = ensureValidFileName("${id}_${(parts.lastOrNull() ?: "")}")
// Generate configs based on the given default values. // Generate configs based on the given default values.
val configs: List<Config> = when (task) { // val configs: List<Config> = when (task) {
TASK_LLM_CHAT.type.label -> createLLmChatConfig(defaults = configs) // TASK_LLM_CHAT.type.label -> createLLmChatConfig(defaults = configs)
// todo: add configs for other types. // // todo: add configs for other types.
else -> listOf() // else -> listOf()
} // }
// todo: fix when loading from models.json
val configs: List<Config> = listOf()
// Construct url. // Construct url.
var modelUrl = url var modelUrl = url

View file

@ -19,6 +19,7 @@ package com.google.aiedge.gallery.data
import android.content.Context import android.content.Context
import com.google.aiedge.gallery.ui.common.chat.PromptTemplate import com.google.aiedge.gallery.ui.common.chat.PromptTemplate
import com.google.aiedge.gallery.ui.common.convertValueToTargetType import com.google.aiedge.gallery.ui.common.convertValueToTargetType
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_ACCELERATORS
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
data class ModelDataFile( data class ModelDataFile(
@ -28,8 +29,8 @@ data class ModelDataFile(
val sizeInBytes: Long, val sizeInBytes: Long,
) )
enum class LlmBackend { enum class Accelerator(val label: String) {
CPU, GPU CPU(label = "CPU"), GPU(label = "GPU")
} }
const val IMPORTS_DIR = "__imports" const val IMPORTS_DIR = "__imports"
@ -81,14 +82,14 @@ data class Model(
/** The name of the directory to unzip the model to (if it's a zip file). */ /** The name of the directory to unzip the model to (if it's a zip file). */
val unzipDir: String = "", val unzipDir: String = "",
/** The preferred backend of the model (only for LLM). */ /** The accelerators the the model can run with. */
val llmBackend: LlmBackend = LlmBackend.GPU, val accelerators: List<Accelerator> = DEFAULT_ACCELERATORS,
/** The prompt templates for the model (only for LLM). */ /** The prompt templates for the model (only for LLM). */
val llmPromptTemplates: List<PromptTemplate> = listOf(), val llmPromptTemplates: List<PromptTemplate> = listOf(),
/** Whether the model is imported as a local model. */ /** Whether the model is imported or not. */
val isLocalModel: Boolean = false, val imported: Boolean = false,
// The following fields are managed by the app. Don't need to set manually. // The following fields are managed by the app. Don't need to set manually.
var taskType: TaskType? = null, var taskType: TaskType? = null,
@ -135,6 +136,12 @@ data class Model(
) as Boolean ) as Boolean
} }
fun getStringConfigValue(key: ConfigKey, defaultValue: String = ""): String {
return getTypedConfigValue(
key = key, valueType = ValueType.STRING, defaultValue = defaultValue
) as String
}
fun getExtraDataFile(name: String): ModelDataFile? { fun getExtraDataFile(name: String): ModelDataFile? {
return extraDataFiles.find { it.name == name } return extraDataFiles.find { it.name == name }
} }
@ -147,7 +154,11 @@ data class Model(
} }
/** Data for a imported local model. */ /** Data for a imported local model. */
data class LocalModelInfo(val fileName: String, val fileSize: Long) data class ImportedModelInfo(
val fileName: String,
val fileSize: Long,
val defaultValues: Map<String, Any>
)
enum class ModelDownloadStatusType { enum class ModelDownloadStatusType {
NOT_DOWNLOADED, PARTIALLY_DOWNLOADED, IN_PROGRESS, UNZIPPING, SUCCEEDED, FAILED, NOT_DOWNLOADED, PARTIALLY_DOWNLOADED, IN_PROGRESS, UNZIPPING, SUCCEEDED, FAILED,
@ -165,29 +176,25 @@ data class ModelDownloadStatus(
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
// Configs. // Configs.
enum class ConfigKey(val label: String, val id: String) { enum class ConfigKey(val label: String) {
MAX_TOKENS("Max tokens", id = "max_token"), MAX_TOKENS("Max tokens"),
TOPK("TopK", id = "topk"), TOPK("TopK"),
TOPP( TOPP("TopP"),
"TopP", TEMPERATURE("Temperature"),
id = "topp" DEFAULT_MAX_TOKENS("Default max tokens"),
), DEFAULT_TOPK("Default TopK"),
TEMPERATURE("Temperature", id = "temperature"), DEFAULT_TOPP("Default TopP"),
MAX_RESULT_COUNT( DEFAULT_TEMPERATURE("Default temperature"),
"Max result count", MAX_RESULT_COUNT("Max result count"),
id = "max_result_count" USE_GPU("Use GPU"),
), ACCELERATOR("Accelerator"),
USE_GPU("Use GPU", id = "use_gpu"), COMPATIBLE_ACCELERATORS("Compatible accelerators"),
WARM_UP_ITERATIONS( WARM_UP_ITERATIONS("Warm up iterations"),
"Warm up iterations", BENCHMARK_ITERATIONS("Benchmark iterations"),
id = "warm_up_iterations" ITERATIONS("Iterations"),
), THEME("Theme"),
BENCHMARK_ITERATIONS( NAME("Name"),
"Benchmark iterations", MODEL_TYPE("Model type")
id = "benchmark_iterations"
),
ITERATIONS("Iterations", id = "iterations"),
THEME("Theme", id = "theme"),
} }
val MOBILENET_CONFIGS: List<Config> = listOf( val MOBILENET_CONFIGS: List<Config> = listOf(
@ -258,7 +265,12 @@ val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model(
downloadFileName = "gemma3-1b-it-int4.task", downloadFileName = "gemma3-1b-it-int4.task",
url = "https://huggingface.co/litert-community/Gemma3-1B-IT/resolve/main/gemma3-1b-it-int4.task?download=true", url = "https://huggingface.co/litert-community/Gemma3-1B-IT/resolve/main/gemma3-1b-it-int4.task?download=true",
sizeInBytes = 554661243L, sizeInBytes = 554661243L,
configs = createLlmChatConfigs(defaultTopK = 64, defaultTopP = 0.95f), accelerators = listOf(Accelerator.CPU, Accelerator.GPU),
configs = createLlmChatConfigs(
defaultTopK = 64,
defaultTopP = 0.95f,
accelerators = listOf(Accelerator.CPU, Accelerator.GPU)
),
info = LLM_CHAT_INFO, info = LLM_CHAT_INFO,
learnMoreUrl = "https://huggingface.co/litert-community/Gemma3-1B-IT", learnMoreUrl = "https://huggingface.co/litert-community/Gemma3-1B-IT",
llmPromptTemplates = listOf( llmPromptTemplates = listOf(
@ -280,8 +292,13 @@ val MODEL_LLM_DEEPSEEK: Model = Model(
downloadFileName = "deepseek.task", downloadFileName = "deepseek.task",
url = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B/resolve/main/deepseek_q8_ekv1280.task?download=true", url = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B/resolve/main/deepseek_q8_ekv1280.task?download=true",
sizeInBytes = 1860686856L, sizeInBytes = 1860686856L,
llmBackend = LlmBackend.CPU, accelerators = listOf(Accelerator.CPU),
configs = createLlmChatConfigs(defaultTemperature = 0.6f, defaultTopK = 40, defaultTopP = 0.7f), configs = createLlmChatConfigs(
defaultTemperature = 0.6f,
defaultTopK = 40,
defaultTopP = 0.7f,
accelerators = listOf(Accelerator.CPU)
),
info = LLM_CHAT_INFO, info = LLM_CHAT_INFO,
learnMoreUrl = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B", learnMoreUrl = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B",
) )

View file

@ -34,16 +34,15 @@ import androidx.compose.foundation.text.KeyboardOptions
import androidx.compose.material3.Button import androidx.compose.material3.Button
import androidx.compose.material3.Card import androidx.compose.material3.Card
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.MultiChoiceSegmentedButtonRow
import androidx.compose.material3.SegmentedButton import androidx.compose.material3.SegmentedButton
import androidx.compose.material3.SegmentedButtonDefaults import androidx.compose.material3.SegmentedButtonDefaults
import androidx.compose.material3.SingleChoiceSegmentedButtonRow
import androidx.compose.material3.Slider import androidx.compose.material3.Slider
import androidx.compose.material3.Switch import androidx.compose.material3.Switch
import androidx.compose.material3.Text import androidx.compose.material3.Text
import androidx.compose.material3.TextButton import androidx.compose.material3.TextButton
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.mutableStateMapOf import androidx.compose.runtime.mutableStateMapOf
import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember import androidx.compose.runtime.remember
@ -60,6 +59,7 @@ import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog import androidx.compose.ui.window.Dialog
import com.google.aiedge.gallery.data.BooleanSwitchConfig import com.google.aiedge.gallery.data.BooleanSwitchConfig
import com.google.aiedge.gallery.data.Config import com.google.aiedge.gallery.data.Config
import com.google.aiedge.gallery.data.LabelConfig
import com.google.aiedge.gallery.data.NumberSliderConfig import com.google.aiedge.gallery.data.NumberSliderConfig
import com.google.aiedge.gallery.data.SegmentedButtonConfig import com.google.aiedge.gallery.data.SegmentedButtonConfig
import com.google.aiedge.gallery.data.ValueType import com.google.aiedge.gallery.data.ValueType
@ -113,27 +113,10 @@ fun ConfigDialog(
} }
} }
// List of config rows. // List of config rows.
for (config in configs) { ConfigEditorsPanel(configs = configs, values = values)
when (config) {
// Number slider.
is NumberSliderConfig -> {
NumberSliderRow(config = config, values = values)
}
// Boolean switch. // Button row.
is BooleanSwitchConfig -> {
BooleanSwitchRow(config = config, values = values)
}
is SegmentedButtonConfig -> {
SegmentedButtonRow(config = config, values = values)
}
else -> {}
}
}
Row( Row(
modifier = Modifier modifier = Modifier
.fillMaxWidth() .fillMaxWidth()
@ -164,6 +147,53 @@ fun ConfigDialog(
} }
} }
/**
* Composable function to display a list of config editor rows.
*/
@Composable
fun ConfigEditorsPanel(configs: List<Config>, values: SnapshotStateMap<String, Any>) {
for (config in configs) {
when (config) {
// Label.
is LabelConfig -> {
LabelRow(config = config, values = values)
}
// Number slider.
is NumberSliderConfig -> {
NumberSliderRow(config = config, values = values)
}
// Boolean switch.
is BooleanSwitchConfig -> {
BooleanSwitchRow(config = config, values = values)
}
// Segmented button.
is SegmentedButtonConfig -> {
SegmentedButtonRow(config = config, values = values)
}
else -> {}
}
}
}
@Composable
fun LabelRow(config: LabelConfig, values: SnapshotStateMap<String, Any>) {
Column(modifier = Modifier.fillMaxWidth()) {
// Field label.
Text(config.key.label, style = MaterialTheme.typography.titleSmall)
// Content label.
val label = try {
values[config.key.label] as String
} catch (e: Exception) {
""
}
Text(label, style = MaterialTheme.typography.bodyMedium)
}
}
/** /**
* Composable function to display a number slider with an associated text input field. * Composable function to display a number slider with an associated text input field.
* *
@ -272,18 +302,41 @@ fun BooleanSwitchRow(config: BooleanSwitchConfig, values: SnapshotStateMap<Strin
@Composable @Composable
fun SegmentedButtonRow(config: SegmentedButtonConfig, values: SnapshotStateMap<String, Any>) { fun SegmentedButtonRow(config: SegmentedButtonConfig, values: SnapshotStateMap<String, Any>) {
var selectedIndex by remember { mutableIntStateOf(config.options.indexOf(values[config.key.label])) } val selectedOptions: List<String> = remember { (values[config.key.label] as String).split(",") }
var selectionStates: List<Boolean> by remember {
mutableStateOf(List(config.options.size) { index ->
selectedOptions.contains(config.options[index])
})
}
Column(modifier = Modifier.fillMaxWidth()) { Column(modifier = Modifier.fillMaxWidth()) {
Text(config.key.label, style = MaterialTheme.typography.titleSmall) Text(config.key.label, style = MaterialTheme.typography.titleSmall)
SingleChoiceSegmentedButtonRow { MultiChoiceSegmentedButtonRow {
config.options.forEachIndexed { index, label -> config.options.forEachIndexed { index, label ->
SegmentedButton(shape = SegmentedButtonDefaults.itemShape( SegmentedButton(shape = SegmentedButtonDefaults.itemShape(
index = index, count = config.options.size index = index, count = config.options.size
), onClick = { ), onCheckedChange = {
selectedIndex = index var newSelectionStates = selectionStates.toMutableList()
values[config.key.label] = label val selectedCount = newSelectionStates.count { it }
}, selected = index == selectedIndex, label = { Text(label) })
// Single select.
if (!config.allowMultiple) {
if (!newSelectionStates[index]) {
newSelectionStates = MutableList(config.options.size) { it == index }
}
}
// Multiple select.
else {
if (!(selectedCount == 1 && newSelectionStates[index])) {
newSelectionStates[index] = !newSelectionStates[index]
}
}
selectionStates = newSelectionStates
values[config.key.label] =
config.options.filterIndexed { index, option -> selectionStates[index] }
.joinToString(",")
}, checked = selectionStates[index], label = { Text(label) })
} }
} }

View file

@ -92,7 +92,6 @@ private val DEFAULT_VERTICAL_PADDING = 16.dp
* model description and buttons for learning more (opening a URL) and downloading/trying * model description and buttons for learning more (opening a URL) and downloading/trying
* the model. * the model.
*/ */
@OptIn(ExperimentalMaterial3Api::class)
@Composable @Composable
fun ModelItem( fun ModelItem(
model: Model, model: Model,
@ -188,9 +187,9 @@ fun ModelItem(
} }
} else { } else {
Icon( Icon(
// For local model, show ">" directly indicating users can just tap the model item to // For imported model, show ">" directly indicating users can just tap the model item to
// go into it without needing to expand it first. // go into it without needing to expand it first.
if (model.isLocalModel) Icons.Rounded.ChevronRight else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore, if (model.imported) Icons.Rounded.ChevronRight else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore,
contentDescription = "", contentDescription = "",
tint = getTaskIconColor(task), tint = getTaskIconColor(task),
) )
@ -272,7 +271,7 @@ fun ModelItem(
boxModifier = if (canExpand) { boxModifier = if (canExpand) {
boxModifier.clickable( boxModifier.clickable(
onClick = { onClick = {
if (!model.isLocalModel) { if (!model.imported) {
isExpanded = !isExpanded isExpanded = !isExpanded
} else { } else {
onModelClicked(model) onModelClicked(model)

View file

@ -16,13 +16,22 @@
package com.google.aiedge.gallery.ui.home package com.google.aiedge.gallery.ui.home
import android.content.Intent
import android.net.Uri
import android.util.Log
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.ActivityResultLauncher
import androidx.activity.result.contract.ActivityResultContracts
import androidx.annotation.StringRes import androidx.annotation.StringRes
import androidx.compose.animation.core.animateFloatAsState
import androidx.compose.animation.core.tween
import androidx.compose.foundation.background import androidx.compose.foundation.background
import androidx.compose.foundation.clickable import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.PaddingValues import androidx.compose.foundation.layout.PaddingValues
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.aspectRatio import androidx.compose.foundation.layout.aspectRatio
import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxSize
@ -34,22 +43,34 @@ import androidx.compose.foundation.lazy.grid.GridItemSpan
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
import androidx.compose.foundation.lazy.grid.items import androidx.compose.foundation.lazy.grid.items
import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.outlined.NoteAdd
import androidx.compose.material.icons.filled.Add
import androidx.compose.material3.Card import androidx.compose.material3.Card
import androidx.compose.material3.CardDefaults import androidx.compose.material3.CardDefaults
import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ModalBottomSheet
import androidx.compose.material3.Scaffold import androidx.compose.material3.Scaffold
import androidx.compose.material3.SmallFloatingActionButton
import androidx.compose.material3.Text import androidx.compose.material3.Text
import androidx.compose.material3.TopAppBarDefaults import androidx.compose.material3.TopAppBarDefaults
import androidx.compose.material3.rememberModalBottomSheetState
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip import androidx.compose.ui.draw.clip
import androidx.compose.ui.draw.scale
import androidx.compose.ui.graphics.Brush import androidx.compose.ui.graphics.Brush
import androidx.compose.ui.input.nestedscroll.nestedScroll import androidx.compose.ui.input.nestedscroll.nestedScroll
import androidx.compose.ui.layout.layout import androidx.compose.ui.layout.layout
@ -66,6 +87,8 @@ import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.data.AppBarAction import com.google.aiedge.gallery.data.AppBarAction
import com.google.aiedge.gallery.data.AppBarActionType import com.google.aiedge.gallery.data.AppBarActionType
import com.google.aiedge.gallery.data.ConfigKey import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.ImportedModelInfo
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.TaskIcon import com.google.aiedge.gallery.ui.common.TaskIcon
import com.google.aiedge.gallery.ui.common.getTaskBgColor import com.google.aiedge.gallery.ui.common.getTaskBgColor
@ -75,6 +98,11 @@ import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.ThemeSettings import com.google.aiedge.gallery.ui.theme.ThemeSettings
import com.google.aiedge.gallery.ui.theme.customColors import com.google.aiedge.gallery.ui.theme.customColors
import com.google.aiedge.gallery.ui.theme.titleMediumNarrow import com.google.aiedge.gallery.ui.theme.titleMediumNarrow
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
private const val TAG = "AGHomeScreen"
private const val TASK_COUNT_ANIMATION_DURATION = 250
/** Navigation destination data */ /** Navigation destination data */
object HomeScreenDestination { object HomeScreenDestination {
@ -92,11 +120,35 @@ fun HomeScreen(
val scrollBehavior = TopAppBarDefaults.pinnedScrollBehavior() val scrollBehavior = TopAppBarDefaults.pinnedScrollBehavior()
val uiState by modelManagerViewModel.uiState.collectAsState() val uiState by modelManagerViewModel.uiState.collectAsState()
var showSettingsDialog by remember { mutableStateOf(false) } var showSettingsDialog by remember { mutableStateOf(false) }
var showImportModelSheet by remember { mutableStateOf(false) }
val sheetState = rememberModalBottomSheetState()
var showImportDialog by remember { mutableStateOf(false) }
var showImportingDialog by remember { mutableStateOf(false) }
val selectedLocalModelFileUri = remember { mutableStateOf<Uri?>(null) }
val selectedImportedModelInfo = remember { mutableStateOf<ImportedModelInfo?>(null) }
val coroutineScope = rememberCoroutineScope()
val tasks = uiState.tasks val tasks = uiState.tasks
val loadingHfModels = uiState.loadingHfModels val loadingHfModels = uiState.loadingHfModels
Scaffold(modifier = modifier.nestedScroll(scrollBehavior.nestedScrollConnection), topBar = { val filePickerLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult(
contract = ActivityResultContracts.StartActivityForResult()
) { result ->
if (result.resultCode == android.app.Activity.RESULT_OK) {
result.data?.data?.let { uri ->
selectedLocalModelFileUri.value = uri
showImportDialog = true
} ?: run {
Log.d(TAG, "No file selected or URI is null.")
}
} else {
Log.d(TAG, "File picking cancelled.")
}
}
Scaffold(
modifier = modifier.nestedScroll(scrollBehavior.nestedScrollConnection),
topBar = {
GalleryTopAppBar( GalleryTopAppBar(
title = stringResource(HomeScreenDestination.titleRes), title = stringResource(HomeScreenDestination.titleRes),
rightAction = AppBarAction( rightAction = AppBarAction(
@ -107,7 +159,20 @@ fun HomeScreen(
loadingHfModels = loadingHfModels, loadingHfModels = loadingHfModels,
scrollBehavior = scrollBehavior, scrollBehavior = scrollBehavior,
) )
}) { innerPadding -> },
floatingActionButton = {
// A floating action button to show "import model" bottom sheet.
SmallFloatingActionButton(
onClick = {
showImportModelSheet = true
},
containerColor = MaterialTheme.colorScheme.secondaryContainer,
contentColor = MaterialTheme.colorScheme.secondary,
) {
Icon(Icons.Filled.Add, "")
}
}
) { innerPadding ->
TaskList( TaskList(
tasks = tasks, tasks = tasks,
navigateToTaskScreen = navigateToTaskScreen, navigateToTaskScreen = navigateToTaskScreen,
@ -132,6 +197,83 @@ fun HomeScreen(
}, },
) )
} }
// Import model bottom sheet.
if (showImportModelSheet) {
ModalBottomSheet(
onDismissRequest = { showImportModelSheet = false },
sheetState = sheetState,
) {
Text(
"Import model",
style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(vertical = 4.dp, horizontal = 16.dp)
)
Box(modifier = Modifier.clickable {
coroutineScope.launch {
// Give it sometime to show the click effect.
delay(200)
showImportModelSheet = false
// Show file picker.
val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
addCategory(Intent.CATEGORY_OPENABLE)
type = "*/*"
putExtra(
Intent.EXTRA_MIME_TYPES,
arrayOf("application/x-binary", "application/octet-stream")
)
// Single select.
putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
}
filePickerLauncher.launch(intent)
}
}) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp),
modifier = Modifier
.fillMaxWidth()
.padding(16.dp)
) {
Icon(Icons.AutoMirrored.Outlined.NoteAdd, contentDescription = "")
Text("From local model file")
}
}
}
}
// Import dialog
if (showImportDialog) {
selectedLocalModelFileUri.value?.let { uri ->
ModelImportDialog(uri = uri,
onDismiss = { showImportDialog = false },
onDone = { info ->
selectedImportedModelInfo.value = info
showImportDialog = false
showImportingDialog = true
})
}
}
// Importing in progress dialog.
if (showImportingDialog) {
selectedLocalModelFileUri.value?.let { uri ->
selectedImportedModelInfo.value?.let { info ->
ModelImportingDialog(
uri = uri,
info = info,
onDismiss = { showImportingDialog = false },
onDone = {
modelManagerViewModel.addImportedLlmModel(
task = TASK_LLM_CHAT,
info = it,
)
showImportingDialog = false
})
}
}
}
} }
@Composable @Composable
@ -150,7 +292,7 @@ private fun TaskList(
verticalArrangement = Arrangement.spacedBy(8.dp), verticalArrangement = Arrangement.spacedBy(8.dp),
) { ) {
// Headline. // Headline.
item(span = { GridItemSpan(2) }) { item(key = "headline", span = { GridItemSpan(2) }) {
Text( Text(
"Welcome to AI Edge Gallery! Explore a world of \namazing on-device models from LiteRT community", "Welcome to AI Edge Gallery! Explore a world of \namazing on-device models from LiteRT community",
textAlign = TextAlign.Center, textAlign = TextAlign.Center,
@ -171,6 +313,11 @@ private fun TaskList(
.aspectRatio(1f) .aspectRatio(1f)
) )
} }
// Bottom padding.
item(key = "bottomPadding", span = { GridItemSpan(2) }) {
Spacer(modifier = Modifier.height(60.dp))
}
} }
// Gradient overlay at the bottom. // Gradient overlay at the bottom.
@ -190,6 +337,48 @@ private fun TaskList(
@Composable @Composable
private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modifier) { private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modifier) {
// Observes the model count and updates the model count label with a fade-in/fade-out animation
// whenever the count changes.
val modelCount by remember {
derivedStateOf {
val trigger = task.updateTrigger.value
if (trigger >= 0) {
task.models.size
} else {
0
}
}
}
val modelCountLabel by remember {
derivedStateOf {
when (modelCount) {
1 -> "1 Model"
else -> "%d Models".format(modelCount)
}
}
}
var curModelCountLabel by remember { mutableStateOf("") }
var modelCountLabelVisible by remember { mutableStateOf(true) }
val modelCountAlpha: Float by animateFloatAsState(
targetValue = if (modelCountLabelVisible) 1f else 0f,
animationSpec = tween(durationMillis = TASK_COUNT_ANIMATION_DURATION)
)
val modelCountScale: Float by animateFloatAsState(
targetValue = if (modelCountLabelVisible) 1f else 0.7f,
animationSpec = tween(durationMillis = TASK_COUNT_ANIMATION_DURATION)
)
LaunchedEffect(modelCountLabel) {
if (curModelCountLabel.isEmpty()) {
curModelCountLabel = modelCountLabel
} else {
modelCountLabelVisible = false
delay(TASK_COUNT_ANIMATION_DURATION.toLong())
curModelCountLabel = modelCountLabel
modelCountLabelVisible = true
}
}
Card( Card(
modifier = modifier modifier = modifier
.clip(RoundedCornerShape(43.5.dp)) .clip(RoundedCornerShape(43.5.dp))
@ -238,14 +427,13 @@ private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modif
} }
// Model count. // Model count.
val modelCountLabel = when (task.models.size) {
1 -> "1 Model"
else -> "%d Models".format(task.models.size)
}
Text( Text(
modelCountLabel, curModelCountLabel,
color = MaterialTheme.colorScheme.secondary, color = MaterialTheme.colorScheme.secondary,
style = MaterialTheme.typography.bodyMedium style = MaterialTheme.typography.bodyMedium,
modifier = Modifier
.alpha(modelCountAlpha)
.scale(modelCountScale),
) )
} }
} }

View file

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
package com.google.aiedge.gallery.ui.modelmanager package com.google.aiedge.gallery.ui.home
import android.content.Context import android.content.Context
import android.net.Uri import android.net.Uri
@ -36,24 +36,40 @@ import androidx.compose.material3.Icon
import androidx.compose.material3.LinearProgressIndicator import androidx.compose.material3.LinearProgressIndicator
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableFloatStateOf import androidx.compose.runtime.mutableFloatStateOf
import androidx.compose.runtime.mutableLongStateOf import androidx.compose.runtime.mutableLongStateOf
import androidx.compose.runtime.mutableStateMapOf
import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue import androidx.compose.runtime.setValue
import androidx.compose.runtime.snapshots.SnapshotStateMap
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog import androidx.compose.ui.window.Dialog
import androidx.compose.ui.window.DialogProperties import androidx.compose.ui.window.DialogProperties
import com.google.aiedge.gallery.data.Accelerator
import com.google.aiedge.gallery.data.Config
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.IMPORTS_DIR import com.google.aiedge.gallery.data.IMPORTS_DIR
import com.google.aiedge.gallery.data.LabelConfig
import com.google.aiedge.gallery.data.ImportedModelInfo
import com.google.aiedge.gallery.data.NumberSliderConfig
import com.google.aiedge.gallery.data.SegmentedButtonConfig
import com.google.aiedge.gallery.data.ValueType
import com.google.aiedge.gallery.ui.common.chat.ConfigEditorsPanel
import com.google.aiedge.gallery.ui.common.ensureValidFileName import com.google.aiedge.gallery.ui.common.ensureValidFileName
import com.google.aiedge.gallery.ui.common.humanReadableSize import com.google.aiedge.gallery.ui.common.humanReadableSize
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_MAX_TOKEN
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TEMPERATURE
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TOPK
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TOPP
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
@ -64,37 +80,151 @@ import java.nio.charset.StandardCharsets
private const val TAG = "AGModelImportDialog" private const val TAG = "AGModelImportDialog"
data class ModelImportInfo(val fileName: String, val fileSize: Long, val error: String = "") private val IMPORT_CONFIGS_LLM: List<Config> = listOf(
LabelConfig(key = ConfigKey.NAME),
LabelConfig(key = ConfigKey.MODEL_TYPE),
NumberSliderConfig(
key = ConfigKey.DEFAULT_MAX_TOKENS,
sliderMin = 100f,
sliderMax = 1024f,
defaultValue = DEFAULT_MAX_TOKEN.toFloat(),
valueType = ValueType.INT
),
NumberSliderConfig(
key = ConfigKey.DEFAULT_TOPK,
sliderMin = 5f,
sliderMax = 40f,
defaultValue = DEFAULT_TOPK.toFloat(),
valueType = ValueType.INT
),
NumberSliderConfig(
key = ConfigKey.DEFAULT_TOPP,
sliderMin = 0.0f,
sliderMax = 1.0f,
defaultValue = DEFAULT_TOPP,
valueType = ValueType.FLOAT
),
NumberSliderConfig(
key = ConfigKey.DEFAULT_TEMPERATURE,
sliderMin = 0.0f,
sliderMax = 2.0f,
defaultValue = DEFAULT_TEMPERATURE,
valueType = ValueType.FLOAT
),
SegmentedButtonConfig(
key = ConfigKey.COMPATIBLE_ACCELERATORS,
defaultValue = Accelerator.CPU.label,
options = listOf(Accelerator.CPU.label, Accelerator.GPU.label),
allowMultiple = true,
)
)
@Composable @Composable
fun ModelImportDialog( fun ModelImportDialog(
uri: Uri, onDone: (ModelImportInfo) -> Unit uri: Uri,
onDismiss: () -> Unit,
onDone: (ImportedModelInfo) -> Unit
) { ) {
val context = LocalContext.current val context = LocalContext.current
val coroutineScope = rememberCoroutineScope() val info = remember { getFileSizeAndDisplayNameFromUri(context = context, uri = uri) }
val fileSize by remember { mutableLongStateOf(info.first) }
val fileName by remember { mutableStateOf(ensureValidFileName(info.second)) }
var fileName by remember { mutableStateOf("") } val initialValues: Map<String, Any> = remember {
var fileSize by remember { mutableLongStateOf(0L) } mutableMapOf<String, Any>().apply {
for (config in IMPORT_CONFIGS_LLM) {
put(config.key.label, config.defaultValue)
}
put(ConfigKey.NAME.label, fileName)
// TODO: support other types.
put(ConfigKey.MODEL_TYPE.label, "LLM")
}
}
val values: SnapshotStateMap<String, Any> = remember {
mutableStateMapOf<String, Any>().apply {
putAll(initialValues)
}
}
Dialog(
onDismissRequest = onDismiss,
) {
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
Column(
modifier = Modifier
.padding(20.dp),
verticalArrangement = Arrangement.spacedBy(16.dp)
) {
// Title.
Text(
"Import Model",
style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(bottom = 8.dp)
)
// Default configs for users to set.
ConfigEditorsPanel(
configs = IMPORT_CONFIGS_LLM,
values = values,
)
// Button row.
Row(
modifier = Modifier
.fillMaxWidth()
.padding(top = 8.dp),
horizontalArrangement = Arrangement.End,
) {
// Cancel button.
TextButton(
onClick = { onDismiss() },
) {
Text("Cancel")
}
// Import button
Button(
onClick = {
onDone(
ImportedModelInfo(
fileName = fileName,
fileSize = fileSize,
defaultValues = values,
)
)
},
) {
Text("Import")
}
}
}
}
}
}
@Composable
fun ModelImportingDialog(
uri: Uri,
info: ImportedModelInfo,
onDismiss: () -> Unit,
onDone: (ImportedModelInfo) -> Unit
) {
var error by remember { mutableStateOf("") } var error by remember { mutableStateOf("") }
val context = LocalContext.current
val coroutineScope = rememberCoroutineScope()
var progress by remember { mutableFloatStateOf(0f) } var progress by remember { mutableFloatStateOf(0f) }
LaunchedEffect(Unit) { LaunchedEffect(Unit) {
error = ""
// Get basic info.
val info = getFileSizeAndDisplayNameFromUri(context = context, uri = uri)
fileSize = info.first
fileName = ensureValidFileName(info.second)
// Import. // Import.
importModel( importModel(
context = context, context = context,
coroutineScope = coroutineScope, coroutineScope = coroutineScope,
fileName = fileName, fileName = info.fileName,
fileSize = fileSize, fileSize = info.fileSize,
uri = uri, uri = uri,
onDone = { onDone = {
onDone(ModelImportInfo(fileName = fileName, fileSize = fileSize, error = error)) onDone(info)
}, },
onProgress = { onProgress = {
progress = it progress = it
@ -107,7 +237,7 @@ fun ModelImportDialog(
Dialog( Dialog(
properties = DialogProperties(dismissOnBackPress = false, dismissOnClickOutside = false), properties = DialogProperties(dismissOnBackPress = false, dismissOnClickOutside = false),
onDismissRequest = {}, onDismissRequest = onDismiss,
) { ) {
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) { Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
Column( Column(
@ -117,7 +247,7 @@ fun ModelImportDialog(
) { ) {
// Title. // Title.
Text( Text(
"Importing...", "Import Model",
style = MaterialTheme.typography.titleLarge, style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(bottom = 8.dp) modifier = Modifier.padding(bottom = 8.dp)
) )
@ -127,7 +257,7 @@ fun ModelImportDialog(
// Progress bar. // Progress bar.
Column(verticalArrangement = Arrangement.spacedBy(4.dp)) { Column(verticalArrangement = Arrangement.spacedBy(4.dp)) {
Text( Text(
"$fileName (${fileSize.humanReadableSize()})", "${info.fileName} (${info.fileSize.humanReadableSize()})",
style = MaterialTheme.typography.labelSmall, style = MaterialTheme.typography.labelSmall,
) )
val animatedProgress = remember { Animatable(0f) } val animatedProgress = remember { Animatable(0f) }
@ -162,7 +292,7 @@ fun ModelImportDialog(
} }
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.End) { Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.End) {
Button(onClick = { Button(onClick = {
onDone(ModelImportInfo(fileName = "", fileSize = 0L, error = error)) onDismiss()
}) { }) {
Text("Close") Text("Close")
} }

View file

@ -30,7 +30,7 @@ private val CONFIGS: List<Config> = listOf(
SegmentedButtonConfig( SegmentedButtonConfig(
key = ConfigKey.THEME, key = ConfigKey.THEME,
defaultValue = THEME_AUTO, defaultValue = THEME_AUTO,
options = listOf(THEME_AUTO, THEME_LIGHT, THEME_DARK) options = listOf(THEME_AUTO, THEME_LIGHT, THEME_DARK),
) )
) )

View file

@ -16,24 +16,25 @@
package com.google.aiedge.gallery.ui.llmchat package com.google.aiedge.gallery.ui.llmchat
import com.google.aiedge.gallery.data.Accelerator
import com.google.aiedge.gallery.data.Config import com.google.aiedge.gallery.data.Config
import com.google.aiedge.gallery.data.ConfigKey import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.ConfigValue
import com.google.aiedge.gallery.data.NumberSliderConfig import com.google.aiedge.gallery.data.NumberSliderConfig
import com.google.aiedge.gallery.data.SegmentedButtonConfig
import com.google.aiedge.gallery.data.ValueType import com.google.aiedge.gallery.data.ValueType
import com.google.aiedge.gallery.data.getFloatConfigValue
import com.google.aiedge.gallery.data.getIntConfigValue
private const val DEFAULT_MAX_TOKEN = 1024 const val DEFAULT_MAX_TOKEN = 1024
private const val DEFAULT_TOPK = 40 const val DEFAULT_TOPK = 40
private const val DEFAULT_TOPP = 0.9f const val DEFAULT_TOPP = 0.9f
private const val DEFAULT_TEMPERATURE = 1.0f const val DEFAULT_TEMPERATURE = 1.0f
val DEFAULT_ACCELERATORS = listOf(Accelerator.GPU)
fun createLlmChatConfigs( fun createLlmChatConfigs(
defaultMaxToken: Int = DEFAULT_MAX_TOKEN, defaultMaxToken: Int = DEFAULT_MAX_TOKEN,
defaultTopK: Int = DEFAULT_TOPK, defaultTopK: Int = DEFAULT_TOPK,
defaultTopP: Float = DEFAULT_TOPP, defaultTopP: Float = DEFAULT_TOPP,
defaultTemperature: Float = DEFAULT_TEMPERATURE defaultTemperature: Float = DEFAULT_TEMPERATURE,
accelerators: List<Accelerator> = DEFAULT_ACCELERATORS,
): List<Config> { ): List<Config> {
return listOf( return listOf(
NumberSliderConfig( NumberSliderConfig(
@ -64,21 +65,10 @@ fun createLlmChatConfigs(
defaultValue = defaultTemperature, defaultValue = defaultTemperature,
valueType = ValueType.FLOAT valueType = ValueType.FLOAT
), ),
) SegmentedButtonConfig(
} key = ConfigKey.ACCELERATOR,
defaultValue = if (accelerators.contains(Accelerator.GPU)) Accelerator.GPU.label else accelerators[0].label,
fun createLLmChatConfig(defaults: Map<String, ConfigValue>): List<Config> { options = accelerators.map { it.label }
val defaultMaxToken = )
getIntConfigValue(defaults[ConfigKey.MAX_TOKENS.id], default = DEFAULT_MAX_TOKEN)
val defaultTopK = getIntConfigValue(defaults[ConfigKey.TOPK.id], default = DEFAULT_TOPK)
val defaultTopP = getFloatConfigValue(defaults[ConfigKey.TOPP.id], default = DEFAULT_TOPP)
val defaultTemperature =
getFloatConfigValue(defaults[ConfigKey.TEMPERATURE.id], default = DEFAULT_TEMPERATURE)
return createLlmChatConfigs(
defaultMaxToken = defaultMaxToken,
defaultTopK = defaultTopK,
defaultTopP = defaultTopP,
defaultTemperature = defaultTemperature
) )
} }

View file

@ -18,18 +18,14 @@ package com.google.aiedge.gallery.ui.llmchat
import android.content.Context import android.content.Context
import android.util.Log import android.util.Log
import com.google.aiedge.gallery.data.Accelerator
import com.google.aiedge.gallery.data.ConfigKey import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.LlmBackend
import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage import com.google.aiedge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage
import com.google.mediapipe.tasks.genai.llminference.LlmInference import com.google.mediapipe.tasks.genai.llminference.LlmInference
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
private const val TAG = "AGLlmChatModelHelper" private const val TAG = "AGLlmChatModelHelper"
private const val DEFAULT_MAX_TOKEN = 1024
private const val DEFAULT_TOPK = 40
private const val DEFAULT_TOPP = 0.9f
private const val DEFAULT_TEMPERATURE = 1.0f
typealias ResultListener = (partialResult: String, done: Boolean) -> Unit typealias ResultListener = (partialResult: String, done: Boolean) -> Unit
typealias CleanUpListener = () -> Unit typealias CleanUpListener = () -> Unit
@ -49,10 +45,13 @@ object LlmChatModelHelper {
val topP = model.getFloatConfigValue(key = ConfigKey.TOPP, defaultValue = DEFAULT_TOPP) val topP = model.getFloatConfigValue(key = ConfigKey.TOPP, defaultValue = DEFAULT_TOPP)
val temperature = val temperature =
model.getFloatConfigValue(key = ConfigKey.TEMPERATURE, defaultValue = DEFAULT_TEMPERATURE) model.getFloatConfigValue(key = ConfigKey.TEMPERATURE, defaultValue = DEFAULT_TEMPERATURE)
val accelerator =
model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = Accelerator.GPU.label)
Log.d(TAG, "Initializing...") Log.d(TAG, "Initializing...")
val preferredBackend = when (model.llmBackend) { val preferredBackend = when (accelerator) {
LlmBackend.CPU -> LlmInference.Backend.CPU Accelerator.CPU.label -> LlmInference.Backend.CPU
LlmBackend.GPU -> LlmInference.Backend.GPU Accelerator.GPU.label -> LlmInference.Backend.GPU
else -> LlmInference.Backend.GPU
} }
val options = val options =
LlmInference.LlmInferenceOptions.builder().setModelPath(model.getPath(context = context)) LlmInference.LlmInferenceOptions.builder().setModelPath(model.getPath(context = context))

View file

@ -16,13 +16,7 @@
package com.google.aiedge.gallery.ui.modelmanager package com.google.aiedge.gallery.ui.modelmanager
import android.content.Intent
import android.net.Uri
import android.os.Build import android.os.Build
import android.util.Log
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.ActivityResultLauncher
import androidx.activity.result.contract.ActivityResultContracts
import androidx.annotation.RequiresApi import androidx.annotation.RequiresApi
import androidx.compose.foundation.clickable import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
@ -30,32 +24,21 @@ import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.PaddingValues import androidx.compose.foundation.layout.PaddingValues
import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size import androidx.compose.foundation.layout.size
import androidx.compose.foundation.lazy.LazyColumn import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items import androidx.compose.foundation.lazy.items
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.outlined.NoteAdd
import androidx.compose.material.icons.filled.Add
import androidx.compose.material.icons.outlined.Code import androidx.compose.material.icons.outlined.Code
import androidx.compose.material.icons.outlined.Description import androidx.compose.material.icons.outlined.Description
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ModalBottomSheet
import androidx.compose.material3.SmallFloatingActionButton
import androidx.compose.material3.Text import androidx.compose.material3.Text
import androidx.compose.material3.rememberModalBottomSheetState
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.derivedStateOf import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.vector.ImageVector import androidx.compose.ui.graphics.vector.ImageVector
@ -75,14 +58,11 @@ import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.TASK_TEST1 import com.google.aiedge.gallery.ui.preview.TASK_TEST1
import com.google.aiedge.gallery.ui.theme.GalleryTheme import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.customColors import com.google.aiedge.gallery.ui.theme.customColors
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
private const val TAG = "AGModelList" private const val TAG = "AGModelList"
/** The list of models in the model manager. */ /** The list of models in the model manager. */
@RequiresApi(Build.VERSION_CODES.O) @RequiresApi(Build.VERSION_CODES.O)
@OptIn(ExperimentalMaterial3Api::class)
@Composable @Composable
fun ModelList( fun ModelList(
task: Task, task: Task,
@ -91,50 +71,29 @@ fun ModelList(
onModelClicked: (Model) -> Unit, onModelClicked: (Model) -> Unit,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
) { ) {
var showAddModelSheet by remember { mutableStateOf(false) }
var showImportingDialog by remember { mutableStateOf(false) }
val curFileUri = remember { mutableStateOf<Uri?>(null) }
val sheetState = rememberModalBottomSheetState()
val coroutineScope = rememberCoroutineScope()
// This is just to update "models" list when task.updateTrigger is updated so that the UI can // This is just to update "models" list when task.updateTrigger is updated so that the UI can
// be properly updated. // be properly updated.
val models by remember { val models by remember {
derivedStateOf { derivedStateOf {
val trigger = task.updateTrigger.value val trigger = task.updateTrigger.value
if (trigger >= 0) { if (trigger >= 0) {
task.models.toList().filter { !it.isLocalModel } task.models.toList().filter { !it.imported }
} else { } else {
listOf() listOf()
} }
} }
} }
val localModels by remember { val importedModels by remember {
derivedStateOf { derivedStateOf {
val trigger = task.updateTrigger.value val trigger = task.updateTrigger.value
if (trigger >= 0) { if (trigger >= 0) {
task.models.toList().filter { it.isLocalModel } task.models.toList().filter { it.imported }
} else { } else {
listOf() listOf()
} }
} }
} }
val filePickerLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult(
contract = ActivityResultContracts.StartActivityForResult()
) { result ->
if (result.resultCode == android.app.Activity.RESULT_OK) {
result.data?.data?.let { uri ->
curFileUri.value = uri
showImportingDialog = true
} ?: run {
Log.d(TAG, "No file selected or URI is null.")
}
} else {
Log.d(TAG, "File picking cancelled.")
}
}
Box(contentAlignment = Alignment.BottomEnd) { Box(contentAlignment = Alignment.BottomEnd) {
LazyColumn( LazyColumn(
modifier = modifier.padding(top = 8.dp), modifier = modifier.padding(top = 8.dp),
@ -190,11 +149,11 @@ fun ModelList(
} }
} }
// Title for local models. // Title for imported models.
if (localModels.isNotEmpty()) { if (importedModels.isNotEmpty()) {
item(key = "localModelsTitle") { item(key = "importedModelsTitle") {
Text( Text(
"Local models", "Imported models",
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold), style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
modifier = Modifier modifier = Modifier
.padding(horizontal = 16.dp) .padding(horizontal = 16.dp)
@ -203,8 +162,8 @@ fun ModelList(
} }
} }
// List of local models within a task. // List of imported models within a task.
items(items = localModels) { model -> items(items = importedModels) { model ->
Box { Box {
ModelItem( ModelItem(
model = model, model = model,
@ -215,88 +174,6 @@ fun ModelList(
) )
} }
} }
item(key = "bottomPadding") {
Spacer(modifier = Modifier.height(60.dp))
}
}
// Add model button at the bottom right.
Box(
modifier = Modifier
.padding(end = 16.dp)
.padding(bottom = contentPadding.calculateBottomPadding())
) {
SmallFloatingActionButton(
onClick = {
showAddModelSheet = true
},
containerColor = MaterialTheme.colorScheme.secondaryContainer,
contentColor = MaterialTheme.colorScheme.secondary,
) {
Icon(Icons.Filled.Add, "")
}
}
}
if (showAddModelSheet) {
ModalBottomSheet(
onDismissRequest = { showAddModelSheet = false },
sheetState = sheetState,
) {
Text(
"Add custom model",
style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(vertical = 4.dp, horizontal = 16.dp)
)
Box(modifier = Modifier.clickable {
coroutineScope.launch {
// Give it sometime to show the click effect.
delay(200)
showAddModelSheet = false
// Show file picker.
val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
addCategory(Intent.CATEGORY_OPENABLE)
type = "*/*"
putExtra(
Intent.EXTRA_MIME_TYPES,
arrayOf("application/x-binary", "application/octet-stream")
)
// Single select.
putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
}
filePickerLauncher.launch(intent)
}
}) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp),
modifier = Modifier
.fillMaxWidth()
.padding(16.dp)
) {
Icon(Icons.AutoMirrored.Outlined.NoteAdd, contentDescription = "")
Text("Add local model")
}
}
}
}
if (showImportingDialog) {
curFileUri.value?.let { uri ->
ModelImportDialog(uri = uri, onDone = { info ->
showImportingDialog = false
if (info.error.isEmpty()) {
// TODO: support other model types.
modelManagerViewModel.addLocalLlmModel(
task = task,
fileName = info.fileName,
fileSize = info.fileSize
)
}
})
} }
} }
} }

View file

@ -23,8 +23,10 @@ import androidx.activity.result.ActivityResult
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import com.google.aiedge.gallery.data.AGWorkInfo import com.google.aiedge.gallery.data.AGWorkInfo
import com.google.aiedge.gallery.data.Accelerator
import com.google.aiedge.gallery.data.AccessTokenData import com.google.aiedge.gallery.data.AccessTokenData
import com.google.aiedge.gallery.data.Config import com.google.aiedge.gallery.data.Config
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.DataStoreRepository import com.google.aiedge.gallery.data.DataStoreRepository
import com.google.aiedge.gallery.data.DownloadRepository import com.google.aiedge.gallery.data.DownloadRepository
import com.google.aiedge.gallery.data.EMPTY_MODEL import com.google.aiedge.gallery.data.EMPTY_MODEL
@ -32,7 +34,7 @@ import com.google.aiedge.gallery.data.HfModel
import com.google.aiedge.gallery.data.HfModelDetails import com.google.aiedge.gallery.data.HfModelDetails
import com.google.aiedge.gallery.data.HfModelSummary import com.google.aiedge.gallery.data.HfModelSummary
import com.google.aiedge.gallery.data.IMPORTS_DIR import com.google.aiedge.gallery.data.IMPORTS_DIR
import com.google.aiedge.gallery.data.LocalModelInfo import com.google.aiedge.gallery.data.ImportedModelInfo
import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatus import com.google.aiedge.gallery.data.ModelDownloadStatus
import com.google.aiedge.gallery.data.ModelDownloadStatusType import com.google.aiedge.gallery.data.ModelDownloadStatusType
@ -40,12 +42,14 @@ import com.google.aiedge.gallery.data.TASKS
import com.google.aiedge.gallery.data.TASK_LLM_CHAT import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.TaskType import com.google.aiedge.gallery.data.TaskType
import com.google.aiedge.gallery.data.ValueType
import com.google.aiedge.gallery.data.getModelByName import com.google.aiedge.gallery.data.getModelByName
import com.google.aiedge.gallery.ui.common.AuthConfig import com.google.aiedge.gallery.ui.common.AuthConfig
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
import com.google.aiedge.gallery.ui.llmchat.createLLmChatConfig import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
import com.google.aiedge.gallery.ui.textclassification.TextClassificationModelHelper import com.google.aiedge.gallery.ui.textclassification.TextClassificationModelHelper
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async import kotlinx.coroutines.async
@ -228,7 +232,7 @@ open class ModelManagerViewModel(
ModelDownloadStatus(status = ModelDownloadStatusType.NOT_DOWNLOADED) ModelDownloadStatus(status = ModelDownloadStatusType.NOT_DOWNLOADED)
// Delete model from the list if model is imported as a local model. // Delete model from the list if model is imported as a local model.
if (model.isLocalModel) { if (model.imported) {
val index = task.models.indexOf(model) val index = task.models.indexOf(model)
if (index >= 0) { if (index >= 0) {
task.models.removeAt(index) task.models.removeAt(index)
@ -237,12 +241,12 @@ open class ModelManagerViewModel(
curModelDownloadStatus.remove(model.name) curModelDownloadStatus.remove(model.name)
// Update preference. // Update preference.
val localModels = dataStoreRepository.readLocalModels().toMutableList() val importedModels = dataStoreRepository.readImportedModels().toMutableList()
val localModelIndex = localModels.indexOfFirst { it.fileName == model.name } val importedModelIndex = importedModels.indexOfFirst { it.fileName == model.name }
if (localModelIndex >= 0) { if (importedModelIndex >= 0) {
localModels.removeAt(localModelIndex) importedModels.removeAt(importedModelIndex)
} }
dataStoreRepository.saveLocalModels(localModels = localModels) dataStoreRepository.saveImportedModels(importedModels = importedModels)
} }
val newUiState = uiState.value.copy(modelDownloadStatus = curModelDownloadStatus) val newUiState = uiState.value.copy(modelDownloadStatus = curModelDownloadStatus)
_uiState.update { newUiState } _uiState.update { newUiState }
@ -417,27 +421,20 @@ open class ModelManagerViewModel(
return connection.responseCode return connection.responseCode
} }
fun addLocalLlmModel(task: Task, fileName: String, fileSize: Long) { fun addImportedLlmModel(task: Task, info: ImportedModelInfo) {
Log.d(TAG, "adding local model: $fileName, $fileSize") Log.d(TAG, "adding imported llm model: $info")
// Create model. // Create model.
val configs: List<Config> = createLLmChatConfig(defaults = mapOf()) val model = createModelFromImportedModelInfo(info = info, task = task)
val model = Model(
name = fileName,
url = "",
configs = configs,
sizeInBytes = fileSize,
downloadFileName = "$IMPORTS_DIR/$fileName",
isLocalModel = true,
)
model.preProcess(task = task)
task.models.add(model) task.models.add(model)
// Add initial status and states. // Add initial status and states.
val modelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap() val modelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap()
val modelInstances = uiState.value.modelInitializationStatus.toMutableMap() val modelInstances = uiState.value.modelInitializationStatus.toMutableMap()
modelDownloadStatus[model.name] = ModelDownloadStatus( modelDownloadStatus[model.name] = ModelDownloadStatus(
status = ModelDownloadStatusType.SUCCEEDED, receivedBytes = fileSize, totalBytes = fileSize status = ModelDownloadStatusType.SUCCEEDED,
receivedBytes = info.fileSize,
totalBytes = info.fileSize
) )
modelInstances[model.name] = modelInstances[model.name] =
ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED) ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED)
@ -453,9 +450,9 @@ open class ModelManagerViewModel(
task.updateTrigger.value = System.currentTimeMillis() task.updateTrigger.value = System.currentTimeMillis()
// Add to preference storage. // Add to preference storage.
val localModels = dataStoreRepository.readLocalModels().toMutableList() val importedModels = dataStoreRepository.readImportedModels().toMutableList()
localModels.add(LocalModelInfo(fileName = fileName, fileSize = fileSize)) importedModels.add(info)
dataStoreRepository.saveLocalModels(localModels = localModels) dataStoreRepository.saveImportedModels(importedModels = importedModels)
} }
fun getTokenStatusAndData(): TokenStatusAndData { fun getTokenStatusAndData(): TokenStatusAndData {
@ -589,31 +586,22 @@ open class ModelManagerViewModel(
} }
} }
// Load local models. // Load imported models.
for (localModel in dataStoreRepository.readLocalModels()) { for (importedModel in dataStoreRepository.readImportedModels()) {
Log.d(TAG, "stored local model: $localModel") Log.d(TAG, "stored imported model: $importedModel")
// Create model. // Create model.
val configs: List<Config> = createLLmChatConfig(defaults = mapOf()) val model = createModelFromImportedModelInfo(info = importedModel, task = TASK_LLM_CHAT)
val model = Model(
name = localModel.fileName,
url = "",
configs = configs,
sizeInBytes = localModel.fileSize,
downloadFileName = "$IMPORTS_DIR/${localModel.fileName}",
isLocalModel = true,
)
// Add to task. // Add to task.
val task = TASK_LLM_CHAT val task = TASK_LLM_CHAT
model.preProcess(task = task)
task.models.add(model) task.models.add(model)
// Update status. // Update status.
modelDownloadStatus[model.name] = ModelDownloadStatus( modelDownloadStatus[model.name] = ModelDownloadStatus(
status = ModelDownloadStatusType.SUCCEEDED, status = ModelDownloadStatusType.SUCCEEDED,
receivedBytes = localModel.fileSize, receivedBytes = importedModel.fileSize,
totalBytes = localModel.fileSize totalBytes = importedModel.fileSize
) )
} }
@ -628,6 +616,51 @@ open class ModelManagerViewModel(
) )
} }
private fun createModelFromImportedModelInfo(info: ImportedModelInfo, task: Task): Model {
val accelerators: List<Accelerator> = (convertValueToTargetType(
info.defaultValues[ConfigKey.COMPATIBLE_ACCELERATORS.label]!!,
ValueType.STRING
) as String)
.split(",")
.mapNotNull { acceleratorLabel ->
when (acceleratorLabel.trim()) {
Accelerator.GPU.label -> Accelerator.GPU
Accelerator.CPU.label -> Accelerator.CPU
else -> null // Ignore unknown accelerator labels
}
}
val configs: List<Config> = createLlmChatConfigs(
defaultMaxToken = convertValueToTargetType(
info.defaultValues[ConfigKey.DEFAULT_MAX_TOKENS.label]!!,
ValueType.INT
) as Int,
defaultTopK = convertValueToTargetType(
info.defaultValues[ConfigKey.DEFAULT_TOPK.label]!!,
ValueType.INT
) as Int,
defaultTopP = convertValueToTargetType(
info.defaultValues[ConfigKey.DEFAULT_TOPP.label]!!,
ValueType.FLOAT
) as Float,
defaultTemperature = convertValueToTargetType(
info.defaultValues[ConfigKey.DEFAULT_TEMPERATURE.label]!!,
ValueType.FLOAT
) as Float,
accelerators = accelerators,
)
val model = Model(
name = info.fileName,
url = "",
configs = configs,
sizeInBytes = info.fileSize,
downloadFileName = "$IMPORTS_DIR/${info.fileName}",
imported = true,
)
model.preProcess(task = task)
return model
}
/** /**
* Retrieves the download status of a model. * Retrieves the download status of a model.
* *
@ -771,9 +804,7 @@ open class ModelManagerViewModel(
} }
private fun updateModelInitializationStatus( private fun updateModelInitializationStatus(
model: Model, model: Model, status: ModelInitializationStatusType, error: String = ""
status: ModelInitializationStatusType,
error: String = ""
) { ) {
val curModelInstance = uiState.value.modelInitializationStatus.toMutableMap() val curModelInstance = uiState.value.modelInitializationStatus.toMutableMap()
curModelInstance[model.name] = ModelInitializationStatus(status = status, error = error) curModelInstance[model.name] = ModelInitializationStatus(status = status, error = error)

View file

@ -18,7 +18,7 @@ package com.google.aiedge.gallery.ui.preview
import com.google.aiedge.gallery.data.AccessTokenData import com.google.aiedge.gallery.data.AccessTokenData
import com.google.aiedge.gallery.data.DataStoreRepository import com.google.aiedge.gallery.data.DataStoreRepository
import com.google.aiedge.gallery.data.LocalModelInfo import com.google.aiedge.gallery.data.ImportedModelInfo
class PreviewDataStoreRepository : DataStoreRepository { class PreviewDataStoreRepository : DataStoreRepository {
override fun saveTextInputHistory(history: List<String>) { override fun saveTextInputHistory(history: List<String>) {
@ -42,10 +42,10 @@ class PreviewDataStoreRepository : DataStoreRepository {
return null return null
} }
override fun saveLocalModels(localModels: List<LocalModelInfo>) { override fun saveImportedModels(importedModels: List<ImportedModelInfo>) {
} }
override fun readLocalModels(): List<LocalModelInfo> { override fun readImportedModels(): List<ImportedModelInfo> {
return listOf() return listOf()
} }
} }

View file

@ -22,6 +22,7 @@ import androidx.compose.material.icons.rounded.AutoAwesome
import com.google.aiedge.gallery.data.BooleanSwitchConfig import com.google.aiedge.gallery.data.BooleanSwitchConfig
import com.google.aiedge.gallery.data.Config import com.google.aiedge.gallery.data.Config
import com.google.aiedge.gallery.data.ConfigKey import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.LabelConfig
import com.google.aiedge.gallery.data.SegmentedButtonConfig import com.google.aiedge.gallery.data.SegmentedButtonConfig
import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.NumberSliderConfig import com.google.aiedge.gallery.data.NumberSliderConfig
@ -30,6 +31,10 @@ import com.google.aiedge.gallery.data.TaskType
import com.google.aiedge.gallery.data.ValueType import com.google.aiedge.gallery.data.ValueType
val TEST_CONFIGS1: List<Config> = listOf( val TEST_CONFIGS1: List<Config> = listOf(
LabelConfig(
key = ConfigKey.NAME,
defaultValue = "Test name",
),
NumberSliderConfig( NumberSliderConfig(
key = ConfigKey.MAX_RESULT_COUNT, key = ConfigKey.MAX_RESULT_COUNT,
sliderMin = 1f, sliderMin = 1f,