mirror of
https://github.com/google-ai-edge/gallery.git
synced 2025-07-12 17:32:30 -04:00
Make importing model functionality better.
- Allow users to specify default parameters before importing.
This commit is contained in:
parent
29b614355e
commit
604972fe23
15 changed files with 635 additions and 329 deletions
|
@ -23,6 +23,7 @@ package com.google.aiedge.gallery.data
|
|||
* Each type corresponds to a specific editor widget, such as a slider or a switch.
|
||||
*/
|
||||
enum class ConfigEditorType {
|
||||
LABEL,
|
||||
NUMBER_SLIDER,
|
||||
BOOLEAN_SWITCH,
|
||||
DROPDOWN,
|
||||
|
@ -57,6 +58,19 @@ open class Config(
|
|||
open val needReinitialization: Boolean = true,
|
||||
)
|
||||
|
||||
/**
|
||||
* Configuration setting for a label.
|
||||
*/
|
||||
class LabelConfig(
|
||||
override val key: ConfigKey,
|
||||
override val defaultValue: String = "",
|
||||
) : Config(
|
||||
type = ConfigEditorType.LABEL,
|
||||
key = key,
|
||||
defaultValue = defaultValue,
|
||||
valueType = ValueType.STRING
|
||||
)
|
||||
|
||||
/**
|
||||
* Configuration setting for a number slider.
|
||||
*
|
||||
|
@ -99,9 +113,11 @@ class SegmentedButtonConfig(
|
|||
override val key: ConfigKey,
|
||||
override val defaultValue: String,
|
||||
val options: List<String>,
|
||||
val allowMultiple: Boolean = false,
|
||||
) : Config(
|
||||
type = ConfigEditorType.DROPDOWN,
|
||||
key = key,
|
||||
defaultValue = defaultValue,
|
||||
// The emitted value will be comma-separated labels when allowMultiple=true.
|
||||
valueType = ValueType.STRING,
|
||||
)
|
|
@ -47,8 +47,8 @@ interface DataStoreRepository {
|
|||
fun readThemeOverride(): String
|
||||
fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long)
|
||||
fun readAccessTokenData(): AccessTokenData?
|
||||
fun saveLocalModels(localModels: List<LocalModelInfo>)
|
||||
fun readLocalModels(): List<LocalModelInfo>
|
||||
fun saveImportedModels(importedModels: List<ImportedModelInfo>)
|
||||
fun readImportedModels(): List<ImportedModelInfo>
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -82,8 +82,8 @@ class DefaultDataStoreRepository(
|
|||
|
||||
val ACCESS_TOKEN_EXPIRES_AT = longPreferencesKey("access_token_expires_at")
|
||||
|
||||
// Data for all imported local models.
|
||||
val LOCAL_MODELS = stringPreferencesKey("local_models")
|
||||
// Data for all imported models.
|
||||
val IMPORTED_MODELS = stringPreferencesKey("imported_models")
|
||||
}
|
||||
|
||||
private val keystoreAlias: String = "com_google_aiedge_gallery_access_token_key"
|
||||
|
@ -160,22 +160,22 @@ class DefaultDataStoreRepository(
|
|||
}
|
||||
}
|
||||
|
||||
override fun saveLocalModels(localModels: List<LocalModelInfo>) {
|
||||
override fun saveImportedModels(importedModels: List<ImportedModelInfo>) {
|
||||
runBlocking {
|
||||
dataStore.edit { preferences ->
|
||||
val gson = Gson()
|
||||
val jsonString = gson.toJson(localModels)
|
||||
preferences[PreferencesKeys.LOCAL_MODELS] = jsonString
|
||||
val jsonString = gson.toJson(importedModels)
|
||||
preferences[PreferencesKeys.IMPORTED_MODELS] = jsonString
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun readLocalModels(): List<LocalModelInfo> {
|
||||
override fun readImportedModels(): List<ImportedModelInfo> {
|
||||
return runBlocking {
|
||||
val preferences = dataStore.data.first()
|
||||
val infosStr = preferences[PreferencesKeys.LOCAL_MODELS] ?: "[]"
|
||||
val infosStr = preferences[PreferencesKeys.IMPORTED_MODELS] ?: "[]"
|
||||
val gson = Gson()
|
||||
val listType = object : TypeToken<List<LocalModelInfo>>() {}.type
|
||||
val listType = object : TypeToken<List<ImportedModelInfo>>() {}.type
|
||||
gson.fromJson(infosStr, listType)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
package com.google.aiedge.gallery.data
|
||||
|
||||
import com.google.aiedge.gallery.ui.common.ensureValidFileName
|
||||
import com.google.aiedge.gallery.ui.llmchat.createLLmChatConfig
|
||||
import kotlinx.serialization.KSerializer
|
||||
import kotlinx.serialization.Serializable
|
||||
import kotlinx.serialization.SerializationException
|
||||
|
@ -107,11 +106,13 @@ data class HfModel(
|
|||
val fileName = ensureValidFileName("${id}_${(parts.lastOrNull() ?: "")}")
|
||||
|
||||
// Generate configs based on the given default values.
|
||||
val configs: List<Config> = when (task) {
|
||||
TASK_LLM_CHAT.type.label -> createLLmChatConfig(defaults = configs)
|
||||
// todo: add configs for other types.
|
||||
else -> listOf()
|
||||
}
|
||||
// val configs: List<Config> = when (task) {
|
||||
// TASK_LLM_CHAT.type.label -> createLLmChatConfig(defaults = configs)
|
||||
// // todo: add configs for other types.
|
||||
// else -> listOf()
|
||||
// }
|
||||
// todo: fix when loading from models.json
|
||||
val configs: List<Config> = listOf()
|
||||
|
||||
// Construct url.
|
||||
var modelUrl = url
|
||||
|
|
|
@ -19,6 +19,7 @@ package com.google.aiedge.gallery.data
|
|||
import android.content.Context
|
||||
import com.google.aiedge.gallery.ui.common.chat.PromptTemplate
|
||||
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
|
||||
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_ACCELERATORS
|
||||
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
|
||||
|
||||
data class ModelDataFile(
|
||||
|
@ -28,8 +29,8 @@ data class ModelDataFile(
|
|||
val sizeInBytes: Long,
|
||||
)
|
||||
|
||||
enum class LlmBackend {
|
||||
CPU, GPU
|
||||
enum class Accelerator(val label: String) {
|
||||
CPU(label = "CPU"), GPU(label = "GPU")
|
||||
}
|
||||
|
||||
const val IMPORTS_DIR = "__imports"
|
||||
|
@ -81,14 +82,14 @@ data class Model(
|
|||
/** The name of the directory to unzip the model to (if it's a zip file). */
|
||||
val unzipDir: String = "",
|
||||
|
||||
/** The preferred backend of the model (only for LLM). */
|
||||
val llmBackend: LlmBackend = LlmBackend.GPU,
|
||||
/** The accelerators the the model can run with. */
|
||||
val accelerators: List<Accelerator> = DEFAULT_ACCELERATORS,
|
||||
|
||||
/** The prompt templates for the model (only for LLM). */
|
||||
val llmPromptTemplates: List<PromptTemplate> = listOf(),
|
||||
|
||||
/** Whether the model is imported as a local model. */
|
||||
val isLocalModel: Boolean = false,
|
||||
/** Whether the model is imported or not. */
|
||||
val imported: Boolean = false,
|
||||
|
||||
// The following fields are managed by the app. Don't need to set manually.
|
||||
var taskType: TaskType? = null,
|
||||
|
@ -135,6 +136,12 @@ data class Model(
|
|||
) as Boolean
|
||||
}
|
||||
|
||||
fun getStringConfigValue(key: ConfigKey, defaultValue: String = ""): String {
|
||||
return getTypedConfigValue(
|
||||
key = key, valueType = ValueType.STRING, defaultValue = defaultValue
|
||||
) as String
|
||||
}
|
||||
|
||||
fun getExtraDataFile(name: String): ModelDataFile? {
|
||||
return extraDataFiles.find { it.name == name }
|
||||
}
|
||||
|
@ -147,7 +154,11 @@ data class Model(
|
|||
}
|
||||
|
||||
/** Data for a imported local model. */
|
||||
data class LocalModelInfo(val fileName: String, val fileSize: Long)
|
||||
data class ImportedModelInfo(
|
||||
val fileName: String,
|
||||
val fileSize: Long,
|
||||
val defaultValues: Map<String, Any>
|
||||
)
|
||||
|
||||
enum class ModelDownloadStatusType {
|
||||
NOT_DOWNLOADED, PARTIALLY_DOWNLOADED, IN_PROGRESS, UNZIPPING, SUCCEEDED, FAILED,
|
||||
|
@ -165,29 +176,25 @@ data class ModelDownloadStatus(
|
|||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Configs.
|
||||
|
||||
enum class ConfigKey(val label: String, val id: String) {
|
||||
MAX_TOKENS("Max tokens", id = "max_token"),
|
||||
TOPK("TopK", id = "topk"),
|
||||
TOPP(
|
||||
"TopP",
|
||||
id = "topp"
|
||||
),
|
||||
TEMPERATURE("Temperature", id = "temperature"),
|
||||
MAX_RESULT_COUNT(
|
||||
"Max result count",
|
||||
id = "max_result_count"
|
||||
),
|
||||
USE_GPU("Use GPU", id = "use_gpu"),
|
||||
WARM_UP_ITERATIONS(
|
||||
"Warm up iterations",
|
||||
id = "warm_up_iterations"
|
||||
),
|
||||
BENCHMARK_ITERATIONS(
|
||||
"Benchmark iterations",
|
||||
id = "benchmark_iterations"
|
||||
),
|
||||
ITERATIONS("Iterations", id = "iterations"),
|
||||
THEME("Theme", id = "theme"),
|
||||
enum class ConfigKey(val label: String) {
|
||||
MAX_TOKENS("Max tokens"),
|
||||
TOPK("TopK"),
|
||||
TOPP("TopP"),
|
||||
TEMPERATURE("Temperature"),
|
||||
DEFAULT_MAX_TOKENS("Default max tokens"),
|
||||
DEFAULT_TOPK("Default TopK"),
|
||||
DEFAULT_TOPP("Default TopP"),
|
||||
DEFAULT_TEMPERATURE("Default temperature"),
|
||||
MAX_RESULT_COUNT("Max result count"),
|
||||
USE_GPU("Use GPU"),
|
||||
ACCELERATOR("Accelerator"),
|
||||
COMPATIBLE_ACCELERATORS("Compatible accelerators"),
|
||||
WARM_UP_ITERATIONS("Warm up iterations"),
|
||||
BENCHMARK_ITERATIONS("Benchmark iterations"),
|
||||
ITERATIONS("Iterations"),
|
||||
THEME("Theme"),
|
||||
NAME("Name"),
|
||||
MODEL_TYPE("Model type")
|
||||
}
|
||||
|
||||
val MOBILENET_CONFIGS: List<Config> = listOf(
|
||||
|
@ -258,7 +265,12 @@ val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model(
|
|||
downloadFileName = "gemma3-1b-it-int4.task",
|
||||
url = "https://huggingface.co/litert-community/Gemma3-1B-IT/resolve/main/gemma3-1b-it-int4.task?download=true",
|
||||
sizeInBytes = 554661243L,
|
||||
configs = createLlmChatConfigs(defaultTopK = 64, defaultTopP = 0.95f),
|
||||
accelerators = listOf(Accelerator.CPU, Accelerator.GPU),
|
||||
configs = createLlmChatConfigs(
|
||||
defaultTopK = 64,
|
||||
defaultTopP = 0.95f,
|
||||
accelerators = listOf(Accelerator.CPU, Accelerator.GPU)
|
||||
),
|
||||
info = LLM_CHAT_INFO,
|
||||
learnMoreUrl = "https://huggingface.co/litert-community/Gemma3-1B-IT",
|
||||
llmPromptTemplates = listOf(
|
||||
|
@ -280,8 +292,13 @@ val MODEL_LLM_DEEPSEEK: Model = Model(
|
|||
downloadFileName = "deepseek.task",
|
||||
url = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B/resolve/main/deepseek_q8_ekv1280.task?download=true",
|
||||
sizeInBytes = 1860686856L,
|
||||
llmBackend = LlmBackend.CPU,
|
||||
configs = createLlmChatConfigs(defaultTemperature = 0.6f, defaultTopK = 40, defaultTopP = 0.7f),
|
||||
accelerators = listOf(Accelerator.CPU),
|
||||
configs = createLlmChatConfigs(
|
||||
defaultTemperature = 0.6f,
|
||||
defaultTopK = 40,
|
||||
defaultTopP = 0.7f,
|
||||
accelerators = listOf(Accelerator.CPU)
|
||||
),
|
||||
info = LLM_CHAT_INFO,
|
||||
learnMoreUrl = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B",
|
||||
)
|
||||
|
|
|
@ -34,16 +34,15 @@ import androidx.compose.foundation.text.KeyboardOptions
|
|||
import androidx.compose.material3.Button
|
||||
import androidx.compose.material3.Card
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.MultiChoiceSegmentedButtonRow
|
||||
import androidx.compose.material3.SegmentedButton
|
||||
import androidx.compose.material3.SegmentedButtonDefaults
|
||||
import androidx.compose.material3.SingleChoiceSegmentedButtonRow
|
||||
import androidx.compose.material3.Slider
|
||||
import androidx.compose.material3.Switch
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.material3.TextButton
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableIntStateOf
|
||||
import androidx.compose.runtime.mutableStateMapOf
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
|
@ -60,6 +59,7 @@ import androidx.compose.ui.unit.dp
|
|||
import androidx.compose.ui.window.Dialog
|
||||
import com.google.aiedge.gallery.data.BooleanSwitchConfig
|
||||
import com.google.aiedge.gallery.data.Config
|
||||
import com.google.aiedge.gallery.data.LabelConfig
|
||||
import com.google.aiedge.gallery.data.NumberSliderConfig
|
||||
import com.google.aiedge.gallery.data.SegmentedButtonConfig
|
||||
import com.google.aiedge.gallery.data.ValueType
|
||||
|
@ -113,27 +113,10 @@ fun ConfigDialog(
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// List of config rows.
|
||||
for (config in configs) {
|
||||
when (config) {
|
||||
// Number slider.
|
||||
is NumberSliderConfig -> {
|
||||
NumberSliderRow(config = config, values = values)
|
||||
}
|
||||
ConfigEditorsPanel(configs = configs, values = values)
|
||||
|
||||
// Boolean switch.
|
||||
is BooleanSwitchConfig -> {
|
||||
BooleanSwitchRow(config = config, values = values)
|
||||
}
|
||||
|
||||
is SegmentedButtonConfig -> {
|
||||
SegmentedButtonRow(config = config, values = values)
|
||||
}
|
||||
|
||||
else -> {}
|
||||
}
|
||||
}
|
||||
// Button row.
|
||||
Row(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
|
@ -164,6 +147,53 @@ fun ConfigDialog(
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Composable function to display a list of config editor rows.
|
||||
*/
|
||||
@Composable
|
||||
fun ConfigEditorsPanel(configs: List<Config>, values: SnapshotStateMap<String, Any>) {
|
||||
for (config in configs) {
|
||||
when (config) {
|
||||
// Label.
|
||||
is LabelConfig -> {
|
||||
LabelRow(config = config, values = values)
|
||||
}
|
||||
|
||||
// Number slider.
|
||||
is NumberSliderConfig -> {
|
||||
NumberSliderRow(config = config, values = values)
|
||||
}
|
||||
|
||||
// Boolean switch.
|
||||
is BooleanSwitchConfig -> {
|
||||
BooleanSwitchRow(config = config, values = values)
|
||||
}
|
||||
|
||||
// Segmented button.
|
||||
is SegmentedButtonConfig -> {
|
||||
SegmentedButtonRow(config = config, values = values)
|
||||
}
|
||||
|
||||
else -> {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
fun LabelRow(config: LabelConfig, values: SnapshotStateMap<String, Any>) {
|
||||
Column(modifier = Modifier.fillMaxWidth()) {
|
||||
// Field label.
|
||||
Text(config.key.label, style = MaterialTheme.typography.titleSmall)
|
||||
// Content label.
|
||||
val label = try {
|
||||
values[config.key.label] as String
|
||||
} catch (e: Exception) {
|
||||
""
|
||||
}
|
||||
Text(label, style = MaterialTheme.typography.bodyMedium)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Composable function to display a number slider with an associated text input field.
|
||||
*
|
||||
|
@ -272,18 +302,41 @@ fun BooleanSwitchRow(config: BooleanSwitchConfig, values: SnapshotStateMap<Strin
|
|||
|
||||
@Composable
|
||||
fun SegmentedButtonRow(config: SegmentedButtonConfig, values: SnapshotStateMap<String, Any>) {
|
||||
var selectedIndex by remember { mutableIntStateOf(config.options.indexOf(values[config.key.label])) }
|
||||
val selectedOptions: List<String> = remember { (values[config.key.label] as String).split(",") }
|
||||
var selectionStates: List<Boolean> by remember {
|
||||
mutableStateOf(List(config.options.size) { index ->
|
||||
selectedOptions.contains(config.options[index])
|
||||
})
|
||||
}
|
||||
|
||||
Column(modifier = Modifier.fillMaxWidth()) {
|
||||
Text(config.key.label, style = MaterialTheme.typography.titleSmall)
|
||||
SingleChoiceSegmentedButtonRow {
|
||||
MultiChoiceSegmentedButtonRow {
|
||||
config.options.forEachIndexed { index, label ->
|
||||
SegmentedButton(shape = SegmentedButtonDefaults.itemShape(
|
||||
index = index, count = config.options.size
|
||||
), onClick = {
|
||||
selectedIndex = index
|
||||
values[config.key.label] = label
|
||||
}, selected = index == selectedIndex, label = { Text(label) })
|
||||
), onCheckedChange = {
|
||||
var newSelectionStates = selectionStates.toMutableList()
|
||||
val selectedCount = newSelectionStates.count { it }
|
||||
|
||||
// Single select.
|
||||
if (!config.allowMultiple) {
|
||||
if (!newSelectionStates[index]) {
|
||||
newSelectionStates = MutableList(config.options.size) { it == index }
|
||||
}
|
||||
}
|
||||
// Multiple select.
|
||||
else {
|
||||
if (!(selectedCount == 1 && newSelectionStates[index])) {
|
||||
newSelectionStates[index] = !newSelectionStates[index]
|
||||
}
|
||||
}
|
||||
selectionStates = newSelectionStates
|
||||
|
||||
values[config.key.label] =
|
||||
config.options.filterIndexed { index, option -> selectionStates[index] }
|
||||
.joinToString(",")
|
||||
}, checked = selectionStates[index], label = { Text(label) })
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -92,7 +92,6 @@ private val DEFAULT_VERTICAL_PADDING = 16.dp
|
|||
* model description and buttons for learning more (opening a URL) and downloading/trying
|
||||
* the model.
|
||||
*/
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@Composable
|
||||
fun ModelItem(
|
||||
model: Model,
|
||||
|
@ -188,9 +187,9 @@ fun ModelItem(
|
|||
}
|
||||
} else {
|
||||
Icon(
|
||||
// For local model, show ">" directly indicating users can just tap the model item to
|
||||
// For imported model, show ">" directly indicating users can just tap the model item to
|
||||
// go into it without needing to expand it first.
|
||||
if (model.isLocalModel) Icons.Rounded.ChevronRight else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore,
|
||||
if (model.imported) Icons.Rounded.ChevronRight else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore,
|
||||
contentDescription = "",
|
||||
tint = getTaskIconColor(task),
|
||||
)
|
||||
|
@ -272,7 +271,7 @@ fun ModelItem(
|
|||
boxModifier = if (canExpand) {
|
||||
boxModifier.clickable(
|
||||
onClick = {
|
||||
if (!model.isLocalModel) {
|
||||
if (!model.imported) {
|
||||
isExpanded = !isExpanded
|
||||
} else {
|
||||
onModelClicked(model)
|
||||
|
|
|
@ -16,13 +16,22 @@
|
|||
|
||||
package com.google.aiedge.gallery.ui.home
|
||||
|
||||
import android.content.Intent
|
||||
import android.net.Uri
|
||||
import android.util.Log
|
||||
import androidx.activity.compose.rememberLauncherForActivityResult
|
||||
import androidx.activity.result.ActivityResultLauncher
|
||||
import androidx.activity.result.contract.ActivityResultContracts
|
||||
import androidx.annotation.StringRes
|
||||
import androidx.compose.animation.core.animateFloatAsState
|
||||
import androidx.compose.animation.core.tween
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.PaddingValues
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.Spacer
|
||||
import androidx.compose.foundation.layout.aspectRatio
|
||||
import androidx.compose.foundation.layout.fillMaxSize
|
||||
|
@ -34,22 +43,34 @@ import androidx.compose.foundation.lazy.grid.GridItemSpan
|
|||
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
|
||||
import androidx.compose.foundation.lazy.grid.items
|
||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.automirrored.outlined.NoteAdd
|
||||
import androidx.compose.material.icons.filled.Add
|
||||
import androidx.compose.material3.Card
|
||||
import androidx.compose.material3.CardDefaults
|
||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.ModalBottomSheet
|
||||
import androidx.compose.material3.Scaffold
|
||||
import androidx.compose.material3.SmallFloatingActionButton
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.material3.TopAppBarDefaults
|
||||
import androidx.compose.material3.rememberModalBottomSheetState
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.derivedStateOf
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.rememberCoroutineScope
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.alpha
|
||||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.draw.scale
|
||||
import androidx.compose.ui.graphics.Brush
|
||||
import androidx.compose.ui.input.nestedscroll.nestedScroll
|
||||
import androidx.compose.ui.layout.layout
|
||||
|
@ -66,6 +87,8 @@ import com.google.aiedge.gallery.R
|
|||
import com.google.aiedge.gallery.data.AppBarAction
|
||||
import com.google.aiedge.gallery.data.AppBarActionType
|
||||
import com.google.aiedge.gallery.data.ConfigKey
|
||||
import com.google.aiedge.gallery.data.ImportedModelInfo
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.ui.common.TaskIcon
|
||||
import com.google.aiedge.gallery.ui.common.getTaskBgColor
|
||||
|
@ -75,6 +98,11 @@ import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
|||
import com.google.aiedge.gallery.ui.theme.ThemeSettings
|
||||
import com.google.aiedge.gallery.ui.theme.customColors
|
||||
import com.google.aiedge.gallery.ui.theme.titleMediumNarrow
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.launch
|
||||
|
||||
private const val TAG = "AGHomeScreen"
|
||||
private const val TASK_COUNT_ANIMATION_DURATION = 250
|
||||
|
||||
/** Navigation destination data */
|
||||
object HomeScreenDestination {
|
||||
|
@ -92,11 +120,35 @@ fun HomeScreen(
|
|||
val scrollBehavior = TopAppBarDefaults.pinnedScrollBehavior()
|
||||
val uiState by modelManagerViewModel.uiState.collectAsState()
|
||||
var showSettingsDialog by remember { mutableStateOf(false) }
|
||||
var showImportModelSheet by remember { mutableStateOf(false) }
|
||||
val sheetState = rememberModalBottomSheetState()
|
||||
var showImportDialog by remember { mutableStateOf(false) }
|
||||
var showImportingDialog by remember { mutableStateOf(false) }
|
||||
val selectedLocalModelFileUri = remember { mutableStateOf<Uri?>(null) }
|
||||
val selectedImportedModelInfo = remember { mutableStateOf<ImportedModelInfo?>(null) }
|
||||
val coroutineScope = rememberCoroutineScope()
|
||||
|
||||
val tasks = uiState.tasks
|
||||
val loadingHfModels = uiState.loadingHfModels
|
||||
|
||||
Scaffold(modifier = modifier.nestedScroll(scrollBehavior.nestedScrollConnection), topBar = {
|
||||
val filePickerLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult(
|
||||
contract = ActivityResultContracts.StartActivityForResult()
|
||||
) { result ->
|
||||
if (result.resultCode == android.app.Activity.RESULT_OK) {
|
||||
result.data?.data?.let { uri ->
|
||||
selectedLocalModelFileUri.value = uri
|
||||
showImportDialog = true
|
||||
} ?: run {
|
||||
Log.d(TAG, "No file selected or URI is null.")
|
||||
}
|
||||
} else {
|
||||
Log.d(TAG, "File picking cancelled.")
|
||||
}
|
||||
}
|
||||
|
||||
Scaffold(
|
||||
modifier = modifier.nestedScroll(scrollBehavior.nestedScrollConnection),
|
||||
topBar = {
|
||||
GalleryTopAppBar(
|
||||
title = stringResource(HomeScreenDestination.titleRes),
|
||||
rightAction = AppBarAction(
|
||||
|
@ -107,7 +159,20 @@ fun HomeScreen(
|
|||
loadingHfModels = loadingHfModels,
|
||||
scrollBehavior = scrollBehavior,
|
||||
)
|
||||
}) { innerPadding ->
|
||||
},
|
||||
floatingActionButton = {
|
||||
// A floating action button to show "import model" bottom sheet.
|
||||
SmallFloatingActionButton(
|
||||
onClick = {
|
||||
showImportModelSheet = true
|
||||
},
|
||||
containerColor = MaterialTheme.colorScheme.secondaryContainer,
|
||||
contentColor = MaterialTheme.colorScheme.secondary,
|
||||
) {
|
||||
Icon(Icons.Filled.Add, "")
|
||||
}
|
||||
}
|
||||
) { innerPadding ->
|
||||
TaskList(
|
||||
tasks = tasks,
|
||||
navigateToTaskScreen = navigateToTaskScreen,
|
||||
|
@ -132,6 +197,83 @@ fun HomeScreen(
|
|||
},
|
||||
)
|
||||
}
|
||||
|
||||
// Import model bottom sheet.
|
||||
if (showImportModelSheet) {
|
||||
ModalBottomSheet(
|
||||
onDismissRequest = { showImportModelSheet = false },
|
||||
sheetState = sheetState,
|
||||
) {
|
||||
Text(
|
||||
"Import model",
|
||||
style = MaterialTheme.typography.titleLarge,
|
||||
modifier = Modifier.padding(vertical = 4.dp, horizontal = 16.dp)
|
||||
)
|
||||
Box(modifier = Modifier.clickable {
|
||||
coroutineScope.launch {
|
||||
// Give it sometime to show the click effect.
|
||||
delay(200)
|
||||
showImportModelSheet = false
|
||||
|
||||
// Show file picker.
|
||||
val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
|
||||
addCategory(Intent.CATEGORY_OPENABLE)
|
||||
type = "*/*"
|
||||
putExtra(
|
||||
Intent.EXTRA_MIME_TYPES,
|
||||
arrayOf("application/x-binary", "application/octet-stream")
|
||||
)
|
||||
// Single select.
|
||||
putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
|
||||
}
|
||||
filePickerLauncher.launch(intent)
|
||||
}
|
||||
}) {
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(6.dp),
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(16.dp)
|
||||
) {
|
||||
Icon(Icons.AutoMirrored.Outlined.NoteAdd, contentDescription = "")
|
||||
Text("From local model file")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Import dialog
|
||||
if (showImportDialog) {
|
||||
selectedLocalModelFileUri.value?.let { uri ->
|
||||
ModelImportDialog(uri = uri,
|
||||
onDismiss = { showImportDialog = false },
|
||||
onDone = { info ->
|
||||
selectedImportedModelInfo.value = info
|
||||
showImportDialog = false
|
||||
showImportingDialog = true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Importing in progress dialog.
|
||||
if (showImportingDialog) {
|
||||
selectedLocalModelFileUri.value?.let { uri ->
|
||||
selectedImportedModelInfo.value?.let { info ->
|
||||
ModelImportingDialog(
|
||||
uri = uri,
|
||||
info = info,
|
||||
onDismiss = { showImportingDialog = false },
|
||||
onDone = {
|
||||
modelManagerViewModel.addImportedLlmModel(
|
||||
task = TASK_LLM_CHAT,
|
||||
info = it,
|
||||
)
|
||||
showImportingDialog = false
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
|
@ -150,7 +292,7 @@ private fun TaskList(
|
|||
verticalArrangement = Arrangement.spacedBy(8.dp),
|
||||
) {
|
||||
// Headline.
|
||||
item(span = { GridItemSpan(2) }) {
|
||||
item(key = "headline", span = { GridItemSpan(2) }) {
|
||||
Text(
|
||||
"Welcome to AI Edge Gallery! Explore a world of \namazing on-device models from LiteRT community",
|
||||
textAlign = TextAlign.Center,
|
||||
|
@ -171,6 +313,11 @@ private fun TaskList(
|
|||
.aspectRatio(1f)
|
||||
)
|
||||
}
|
||||
|
||||
// Bottom padding.
|
||||
item(key = "bottomPadding", span = { GridItemSpan(2) }) {
|
||||
Spacer(modifier = Modifier.height(60.dp))
|
||||
}
|
||||
}
|
||||
|
||||
// Gradient overlay at the bottom.
|
||||
|
@ -190,6 +337,48 @@ private fun TaskList(
|
|||
|
||||
@Composable
|
||||
private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modifier) {
|
||||
// Observes the model count and updates the model count label with a fade-in/fade-out animation
|
||||
// whenever the count changes.
|
||||
val modelCount by remember {
|
||||
derivedStateOf {
|
||||
val trigger = task.updateTrigger.value
|
||||
if (trigger >= 0) {
|
||||
task.models.size
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
}
|
||||
val modelCountLabel by remember {
|
||||
derivedStateOf {
|
||||
when (modelCount) {
|
||||
1 -> "1 Model"
|
||||
else -> "%d Models".format(modelCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
var curModelCountLabel by remember { mutableStateOf("") }
|
||||
var modelCountLabelVisible by remember { mutableStateOf(true) }
|
||||
val modelCountAlpha: Float by animateFloatAsState(
|
||||
targetValue = if (modelCountLabelVisible) 1f else 0f,
|
||||
animationSpec = tween(durationMillis = TASK_COUNT_ANIMATION_DURATION)
|
||||
)
|
||||
val modelCountScale: Float by animateFloatAsState(
|
||||
targetValue = if (modelCountLabelVisible) 1f else 0.7f,
|
||||
animationSpec = tween(durationMillis = TASK_COUNT_ANIMATION_DURATION)
|
||||
)
|
||||
|
||||
LaunchedEffect(modelCountLabel) {
|
||||
if (curModelCountLabel.isEmpty()) {
|
||||
curModelCountLabel = modelCountLabel
|
||||
} else {
|
||||
modelCountLabelVisible = false
|
||||
delay(TASK_COUNT_ANIMATION_DURATION.toLong())
|
||||
curModelCountLabel = modelCountLabel
|
||||
modelCountLabelVisible = true
|
||||
}
|
||||
}
|
||||
|
||||
Card(
|
||||
modifier = modifier
|
||||
.clip(RoundedCornerShape(43.5.dp))
|
||||
|
@ -238,14 +427,13 @@ private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modif
|
|||
}
|
||||
|
||||
// Model count.
|
||||
val modelCountLabel = when (task.models.size) {
|
||||
1 -> "1 Model"
|
||||
else -> "%d Models".format(task.models.size)
|
||||
}
|
||||
Text(
|
||||
modelCountLabel,
|
||||
curModelCountLabel,
|
||||
color = MaterialTheme.colorScheme.secondary,
|
||||
style = MaterialTheme.typography.bodyMedium
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
modifier = Modifier
|
||||
.alpha(modelCountAlpha)
|
||||
.scale(modelCountScale),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.google.aiedge.gallery.ui.modelmanager
|
||||
package com.google.aiedge.gallery.ui.home
|
||||
|
||||
import android.content.Context
|
||||
import android.net.Uri
|
||||
|
@ -36,24 +36,40 @@ import androidx.compose.material3.Icon
|
|||
import androidx.compose.material3.LinearProgressIndicator
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.material3.TextButton
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableFloatStateOf
|
||||
import androidx.compose.runtime.mutableLongStateOf
|
||||
import androidx.compose.runtime.mutableStateMapOf
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.rememberCoroutineScope
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.runtime.snapshots.SnapshotStateMap
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.window.Dialog
|
||||
import androidx.compose.ui.window.DialogProperties
|
||||
import com.google.aiedge.gallery.data.Accelerator
|
||||
import com.google.aiedge.gallery.data.Config
|
||||
import com.google.aiedge.gallery.data.ConfigKey
|
||||
import com.google.aiedge.gallery.data.IMPORTS_DIR
|
||||
import com.google.aiedge.gallery.data.LabelConfig
|
||||
import com.google.aiedge.gallery.data.ImportedModelInfo
|
||||
import com.google.aiedge.gallery.data.NumberSliderConfig
|
||||
import com.google.aiedge.gallery.data.SegmentedButtonConfig
|
||||
import com.google.aiedge.gallery.data.ValueType
|
||||
import com.google.aiedge.gallery.ui.common.chat.ConfigEditorsPanel
|
||||
import com.google.aiedge.gallery.ui.common.ensureValidFileName
|
||||
import com.google.aiedge.gallery.ui.common.humanReadableSize
|
||||
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_MAX_TOKEN
|
||||
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TEMPERATURE
|
||||
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TOPK
|
||||
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TOPP
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.launch
|
||||
|
@ -64,37 +80,151 @@ import java.nio.charset.StandardCharsets
|
|||
|
||||
private const val TAG = "AGModelImportDialog"
|
||||
|
||||
data class ModelImportInfo(val fileName: String, val fileSize: Long, val error: String = "")
|
||||
private val IMPORT_CONFIGS_LLM: List<Config> = listOf(
|
||||
LabelConfig(key = ConfigKey.NAME),
|
||||
LabelConfig(key = ConfigKey.MODEL_TYPE),
|
||||
NumberSliderConfig(
|
||||
key = ConfigKey.DEFAULT_MAX_TOKENS,
|
||||
sliderMin = 100f,
|
||||
sliderMax = 1024f,
|
||||
defaultValue = DEFAULT_MAX_TOKEN.toFloat(),
|
||||
valueType = ValueType.INT
|
||||
),
|
||||
NumberSliderConfig(
|
||||
key = ConfigKey.DEFAULT_TOPK,
|
||||
sliderMin = 5f,
|
||||
sliderMax = 40f,
|
||||
defaultValue = DEFAULT_TOPK.toFloat(),
|
||||
valueType = ValueType.INT
|
||||
),
|
||||
NumberSliderConfig(
|
||||
key = ConfigKey.DEFAULT_TOPP,
|
||||
sliderMin = 0.0f,
|
||||
sliderMax = 1.0f,
|
||||
defaultValue = DEFAULT_TOPP,
|
||||
valueType = ValueType.FLOAT
|
||||
),
|
||||
NumberSliderConfig(
|
||||
key = ConfigKey.DEFAULT_TEMPERATURE,
|
||||
sliderMin = 0.0f,
|
||||
sliderMax = 2.0f,
|
||||
defaultValue = DEFAULT_TEMPERATURE,
|
||||
valueType = ValueType.FLOAT
|
||||
),
|
||||
SegmentedButtonConfig(
|
||||
key = ConfigKey.COMPATIBLE_ACCELERATORS,
|
||||
defaultValue = Accelerator.CPU.label,
|
||||
options = listOf(Accelerator.CPU.label, Accelerator.GPU.label),
|
||||
allowMultiple = true,
|
||||
)
|
||||
)
|
||||
|
||||
@Composable
|
||||
fun ModelImportDialog(
|
||||
uri: Uri, onDone: (ModelImportInfo) -> Unit
|
||||
uri: Uri,
|
||||
onDismiss: () -> Unit,
|
||||
onDone: (ImportedModelInfo) -> Unit
|
||||
) {
|
||||
val context = LocalContext.current
|
||||
val coroutineScope = rememberCoroutineScope()
|
||||
val info = remember { getFileSizeAndDisplayNameFromUri(context = context, uri = uri) }
|
||||
val fileSize by remember { mutableLongStateOf(info.first) }
|
||||
val fileName by remember { mutableStateOf(ensureValidFileName(info.second)) }
|
||||
|
||||
var fileName by remember { mutableStateOf("") }
|
||||
var fileSize by remember { mutableLongStateOf(0L) }
|
||||
val initialValues: Map<String, Any> = remember {
|
||||
mutableMapOf<String, Any>().apply {
|
||||
for (config in IMPORT_CONFIGS_LLM) {
|
||||
put(config.key.label, config.defaultValue)
|
||||
}
|
||||
put(ConfigKey.NAME.label, fileName)
|
||||
// TODO: support other types.
|
||||
put(ConfigKey.MODEL_TYPE.label, "LLM")
|
||||
}
|
||||
}
|
||||
val values: SnapshotStateMap<String, Any> = remember {
|
||||
mutableStateMapOf<String, Any>().apply {
|
||||
putAll(initialValues)
|
||||
}
|
||||
}
|
||||
|
||||
Dialog(
|
||||
onDismissRequest = onDismiss,
|
||||
) {
|
||||
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.padding(20.dp),
|
||||
verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||
) {
|
||||
// Title.
|
||||
Text(
|
||||
"Import Model",
|
||||
style = MaterialTheme.typography.titleLarge,
|
||||
modifier = Modifier.padding(bottom = 8.dp)
|
||||
)
|
||||
|
||||
// Default configs for users to set.
|
||||
ConfigEditorsPanel(
|
||||
configs = IMPORT_CONFIGS_LLM,
|
||||
values = values,
|
||||
)
|
||||
|
||||
// Button row.
|
||||
Row(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(top = 8.dp),
|
||||
horizontalArrangement = Arrangement.End,
|
||||
) {
|
||||
// Cancel button.
|
||||
TextButton(
|
||||
onClick = { onDismiss() },
|
||||
) {
|
||||
Text("Cancel")
|
||||
}
|
||||
|
||||
// Import button
|
||||
Button(
|
||||
onClick = {
|
||||
onDone(
|
||||
ImportedModelInfo(
|
||||
fileName = fileName,
|
||||
fileSize = fileSize,
|
||||
defaultValues = values,
|
||||
)
|
||||
)
|
||||
},
|
||||
) {
|
||||
Text("Import")
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
fun ModelImportingDialog(
|
||||
uri: Uri,
|
||||
info: ImportedModelInfo,
|
||||
onDismiss: () -> Unit,
|
||||
onDone: (ImportedModelInfo) -> Unit
|
||||
) {
|
||||
var error by remember { mutableStateOf("") }
|
||||
val context = LocalContext.current
|
||||
val coroutineScope = rememberCoroutineScope()
|
||||
var progress by remember { mutableFloatStateOf(0f) }
|
||||
|
||||
LaunchedEffect(Unit) {
|
||||
error = ""
|
||||
|
||||
// Get basic info.
|
||||
val info = getFileSizeAndDisplayNameFromUri(context = context, uri = uri)
|
||||
fileSize = info.first
|
||||
fileName = ensureValidFileName(info.second)
|
||||
|
||||
// Import.
|
||||
importModel(
|
||||
context = context,
|
||||
coroutineScope = coroutineScope,
|
||||
fileName = fileName,
|
||||
fileSize = fileSize,
|
||||
fileName = info.fileName,
|
||||
fileSize = info.fileSize,
|
||||
uri = uri,
|
||||
onDone = {
|
||||
onDone(ModelImportInfo(fileName = fileName, fileSize = fileSize, error = error))
|
||||
onDone(info)
|
||||
},
|
||||
onProgress = {
|
||||
progress = it
|
||||
|
@ -107,7 +237,7 @@ fun ModelImportDialog(
|
|||
|
||||
Dialog(
|
||||
properties = DialogProperties(dismissOnBackPress = false, dismissOnClickOutside = false),
|
||||
onDismissRequest = {},
|
||||
onDismissRequest = onDismiss,
|
||||
) {
|
||||
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
|
||||
Column(
|
||||
|
@ -117,7 +247,7 @@ fun ModelImportDialog(
|
|||
) {
|
||||
// Title.
|
||||
Text(
|
||||
"Importing...",
|
||||
"Import Model",
|
||||
style = MaterialTheme.typography.titleLarge,
|
||||
modifier = Modifier.padding(bottom = 8.dp)
|
||||
)
|
||||
|
@ -127,7 +257,7 @@ fun ModelImportDialog(
|
|||
// Progress bar.
|
||||
Column(verticalArrangement = Arrangement.spacedBy(4.dp)) {
|
||||
Text(
|
||||
"$fileName (${fileSize.humanReadableSize()})",
|
||||
"${info.fileName} (${info.fileSize.humanReadableSize()})",
|
||||
style = MaterialTheme.typography.labelSmall,
|
||||
)
|
||||
val animatedProgress = remember { Animatable(0f) }
|
||||
|
@ -162,7 +292,7 @@ fun ModelImportDialog(
|
|||
}
|
||||
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.End) {
|
||||
Button(onClick = {
|
||||
onDone(ModelImportInfo(fileName = "", fileSize = 0L, error = error))
|
||||
onDismiss()
|
||||
}) {
|
||||
Text("Close")
|
||||
}
|
|
@ -30,7 +30,7 @@ private val CONFIGS: List<Config> = listOf(
|
|||
SegmentedButtonConfig(
|
||||
key = ConfigKey.THEME,
|
||||
defaultValue = THEME_AUTO,
|
||||
options = listOf(THEME_AUTO, THEME_LIGHT, THEME_DARK)
|
||||
options = listOf(THEME_AUTO, THEME_LIGHT, THEME_DARK),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -16,24 +16,25 @@
|
|||
|
||||
package com.google.aiedge.gallery.ui.llmchat
|
||||
|
||||
import com.google.aiedge.gallery.data.Accelerator
|
||||
import com.google.aiedge.gallery.data.Config
|
||||
import com.google.aiedge.gallery.data.ConfigKey
|
||||
import com.google.aiedge.gallery.data.ConfigValue
|
||||
import com.google.aiedge.gallery.data.NumberSliderConfig
|
||||
import com.google.aiedge.gallery.data.SegmentedButtonConfig
|
||||
import com.google.aiedge.gallery.data.ValueType
|
||||
import com.google.aiedge.gallery.data.getFloatConfigValue
|
||||
import com.google.aiedge.gallery.data.getIntConfigValue
|
||||
|
||||
private const val DEFAULT_MAX_TOKEN = 1024
|
||||
private const val DEFAULT_TOPK = 40
|
||||
private const val DEFAULT_TOPP = 0.9f
|
||||
private const val DEFAULT_TEMPERATURE = 1.0f
|
||||
const val DEFAULT_MAX_TOKEN = 1024
|
||||
const val DEFAULT_TOPK = 40
|
||||
const val DEFAULT_TOPP = 0.9f
|
||||
const val DEFAULT_TEMPERATURE = 1.0f
|
||||
val DEFAULT_ACCELERATORS = listOf(Accelerator.GPU)
|
||||
|
||||
fun createLlmChatConfigs(
|
||||
defaultMaxToken: Int = DEFAULT_MAX_TOKEN,
|
||||
defaultTopK: Int = DEFAULT_TOPK,
|
||||
defaultTopP: Float = DEFAULT_TOPP,
|
||||
defaultTemperature: Float = DEFAULT_TEMPERATURE
|
||||
defaultTemperature: Float = DEFAULT_TEMPERATURE,
|
||||
accelerators: List<Accelerator> = DEFAULT_ACCELERATORS,
|
||||
): List<Config> {
|
||||
return listOf(
|
||||
NumberSliderConfig(
|
||||
|
@ -64,21 +65,10 @@ fun createLlmChatConfigs(
|
|||
defaultValue = defaultTemperature,
|
||||
valueType = ValueType.FLOAT
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
fun createLLmChatConfig(defaults: Map<String, ConfigValue>): List<Config> {
|
||||
val defaultMaxToken =
|
||||
getIntConfigValue(defaults[ConfigKey.MAX_TOKENS.id], default = DEFAULT_MAX_TOKEN)
|
||||
val defaultTopK = getIntConfigValue(defaults[ConfigKey.TOPK.id], default = DEFAULT_TOPK)
|
||||
val defaultTopP = getFloatConfigValue(defaults[ConfigKey.TOPP.id], default = DEFAULT_TOPP)
|
||||
val defaultTemperature =
|
||||
getFloatConfigValue(defaults[ConfigKey.TEMPERATURE.id], default = DEFAULT_TEMPERATURE)
|
||||
|
||||
return createLlmChatConfigs(
|
||||
defaultMaxToken = defaultMaxToken,
|
||||
defaultTopK = defaultTopK,
|
||||
defaultTopP = defaultTopP,
|
||||
defaultTemperature = defaultTemperature
|
||||
SegmentedButtonConfig(
|
||||
key = ConfigKey.ACCELERATOR,
|
||||
defaultValue = if (accelerators.contains(Accelerator.GPU)) Accelerator.GPU.label else accelerators[0].label,
|
||||
options = accelerators.map { it.label }
|
||||
)
|
||||
)
|
||||
}
|
||||
|
|
|
@ -18,18 +18,14 @@ package com.google.aiedge.gallery.ui.llmchat
|
|||
|
||||
import android.content.Context
|
||||
import android.util.Log
|
||||
import com.google.aiedge.gallery.data.Accelerator
|
||||
import com.google.aiedge.gallery.data.ConfigKey
|
||||
import com.google.aiedge.gallery.data.LlmBackend
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage
|
||||
import com.google.mediapipe.tasks.genai.llminference.LlmInference
|
||||
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
|
||||
|
||||
private const val TAG = "AGLlmChatModelHelper"
|
||||
private const val DEFAULT_MAX_TOKEN = 1024
|
||||
private const val DEFAULT_TOPK = 40
|
||||
private const val DEFAULT_TOPP = 0.9f
|
||||
private const val DEFAULT_TEMPERATURE = 1.0f
|
||||
|
||||
typealias ResultListener = (partialResult: String, done: Boolean) -> Unit
|
||||
typealias CleanUpListener = () -> Unit
|
||||
|
@ -49,10 +45,13 @@ object LlmChatModelHelper {
|
|||
val topP = model.getFloatConfigValue(key = ConfigKey.TOPP, defaultValue = DEFAULT_TOPP)
|
||||
val temperature =
|
||||
model.getFloatConfigValue(key = ConfigKey.TEMPERATURE, defaultValue = DEFAULT_TEMPERATURE)
|
||||
val accelerator =
|
||||
model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = Accelerator.GPU.label)
|
||||
Log.d(TAG, "Initializing...")
|
||||
val preferredBackend = when (model.llmBackend) {
|
||||
LlmBackend.CPU -> LlmInference.Backend.CPU
|
||||
LlmBackend.GPU -> LlmInference.Backend.GPU
|
||||
val preferredBackend = when (accelerator) {
|
||||
Accelerator.CPU.label -> LlmInference.Backend.CPU
|
||||
Accelerator.GPU.label -> LlmInference.Backend.GPU
|
||||
else -> LlmInference.Backend.GPU
|
||||
}
|
||||
val options =
|
||||
LlmInference.LlmInferenceOptions.builder().setModelPath(model.getPath(context = context))
|
||||
|
|
|
@ -16,13 +16,7 @@
|
|||
|
||||
package com.google.aiedge.gallery.ui.modelmanager
|
||||
|
||||
import android.content.Intent
|
||||
import android.net.Uri
|
||||
import android.os.Build
|
||||
import android.util.Log
|
||||
import androidx.activity.compose.rememberLauncherForActivityResult
|
||||
import androidx.activity.result.ActivityResultLauncher
|
||||
import androidx.activity.result.contract.ActivityResultContracts
|
||||
import androidx.annotation.RequiresApi
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
|
@ -30,32 +24,21 @@ import androidx.compose.foundation.layout.Box
|
|||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.PaddingValues
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.Spacer
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.height
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.layout.size
|
||||
import androidx.compose.foundation.lazy.LazyColumn
|
||||
import androidx.compose.foundation.lazy.items
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.automirrored.outlined.NoteAdd
|
||||
import androidx.compose.material.icons.filled.Add
|
||||
import androidx.compose.material.icons.outlined.Code
|
||||
import androidx.compose.material.icons.outlined.Description
|
||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.ModalBottomSheet
|
||||
import androidx.compose.material3.SmallFloatingActionButton
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.material3.rememberModalBottomSheetState
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.derivedStateOf
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.rememberCoroutineScope
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.graphics.vector.ImageVector
|
||||
|
@ -75,14 +58,11 @@ import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
|||
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
|
||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||
import com.google.aiedge.gallery.ui.theme.customColors
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.launch
|
||||
|
||||
private const val TAG = "AGModelList"
|
||||
|
||||
/** The list of models in the model manager. */
|
||||
@RequiresApi(Build.VERSION_CODES.O)
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@Composable
|
||||
fun ModelList(
|
||||
task: Task,
|
||||
|
@ -91,50 +71,29 @@ fun ModelList(
|
|||
onModelClicked: (Model) -> Unit,
|
||||
modifier: Modifier = Modifier,
|
||||
) {
|
||||
var showAddModelSheet by remember { mutableStateOf(false) }
|
||||
var showImportingDialog by remember { mutableStateOf(false) }
|
||||
val curFileUri = remember { mutableStateOf<Uri?>(null) }
|
||||
val sheetState = rememberModalBottomSheetState()
|
||||
val coroutineScope = rememberCoroutineScope()
|
||||
|
||||
// This is just to update "models" list when task.updateTrigger is updated so that the UI can
|
||||
// be properly updated.
|
||||
val models by remember {
|
||||
derivedStateOf {
|
||||
val trigger = task.updateTrigger.value
|
||||
if (trigger >= 0) {
|
||||
task.models.toList().filter { !it.isLocalModel }
|
||||
task.models.toList().filter { !it.imported }
|
||||
} else {
|
||||
listOf()
|
||||
}
|
||||
}
|
||||
}
|
||||
val localModels by remember {
|
||||
val importedModels by remember {
|
||||
derivedStateOf {
|
||||
val trigger = task.updateTrigger.value
|
||||
if (trigger >= 0) {
|
||||
task.models.toList().filter { it.isLocalModel }
|
||||
task.models.toList().filter { it.imported }
|
||||
} else {
|
||||
listOf()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val filePickerLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult(
|
||||
contract = ActivityResultContracts.StartActivityForResult()
|
||||
) { result ->
|
||||
if (result.resultCode == android.app.Activity.RESULT_OK) {
|
||||
result.data?.data?.let { uri ->
|
||||
curFileUri.value = uri
|
||||
showImportingDialog = true
|
||||
} ?: run {
|
||||
Log.d(TAG, "No file selected or URI is null.")
|
||||
}
|
||||
} else {
|
||||
Log.d(TAG, "File picking cancelled.")
|
||||
}
|
||||
}
|
||||
|
||||
Box(contentAlignment = Alignment.BottomEnd) {
|
||||
LazyColumn(
|
||||
modifier = modifier.padding(top = 8.dp),
|
||||
|
@ -190,11 +149,11 @@ fun ModelList(
|
|||
}
|
||||
}
|
||||
|
||||
// Title for local models.
|
||||
if (localModels.isNotEmpty()) {
|
||||
item(key = "localModelsTitle") {
|
||||
// Title for imported models.
|
||||
if (importedModels.isNotEmpty()) {
|
||||
item(key = "importedModelsTitle") {
|
||||
Text(
|
||||
"Local models",
|
||||
"Imported models",
|
||||
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
|
||||
modifier = Modifier
|
||||
.padding(horizontal = 16.dp)
|
||||
|
@ -203,8 +162,8 @@ fun ModelList(
|
|||
}
|
||||
}
|
||||
|
||||
// List of local models within a task.
|
||||
items(items = localModels) { model ->
|
||||
// List of imported models within a task.
|
||||
items(items = importedModels) { model ->
|
||||
Box {
|
||||
ModelItem(
|
||||
model = model,
|
||||
|
@ -215,88 +174,6 @@ fun ModelList(
|
|||
)
|
||||
}
|
||||
}
|
||||
|
||||
item(key = "bottomPadding") {
|
||||
Spacer(modifier = Modifier.height(60.dp))
|
||||
}
|
||||
}
|
||||
|
||||
// Add model button at the bottom right.
|
||||
Box(
|
||||
modifier = Modifier
|
||||
.padding(end = 16.dp)
|
||||
.padding(bottom = contentPadding.calculateBottomPadding())
|
||||
) {
|
||||
SmallFloatingActionButton(
|
||||
onClick = {
|
||||
showAddModelSheet = true
|
||||
},
|
||||
containerColor = MaterialTheme.colorScheme.secondaryContainer,
|
||||
contentColor = MaterialTheme.colorScheme.secondary,
|
||||
) {
|
||||
Icon(Icons.Filled.Add, "")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (showAddModelSheet) {
|
||||
ModalBottomSheet(
|
||||
onDismissRequest = { showAddModelSheet = false },
|
||||
sheetState = sheetState,
|
||||
) {
|
||||
Text(
|
||||
"Add custom model",
|
||||
style = MaterialTheme.typography.titleLarge,
|
||||
modifier = Modifier.padding(vertical = 4.dp, horizontal = 16.dp)
|
||||
)
|
||||
Box(modifier = Modifier.clickable {
|
||||
coroutineScope.launch {
|
||||
// Give it sometime to show the click effect.
|
||||
delay(200)
|
||||
showAddModelSheet = false
|
||||
|
||||
// Show file picker.
|
||||
val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
|
||||
addCategory(Intent.CATEGORY_OPENABLE)
|
||||
type = "*/*"
|
||||
putExtra(
|
||||
Intent.EXTRA_MIME_TYPES,
|
||||
arrayOf("application/x-binary", "application/octet-stream")
|
||||
)
|
||||
// Single select.
|
||||
putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
|
||||
}
|
||||
filePickerLauncher.launch(intent)
|
||||
}
|
||||
}) {
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(6.dp),
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(16.dp)
|
||||
) {
|
||||
Icon(Icons.AutoMirrored.Outlined.NoteAdd, contentDescription = "")
|
||||
Text("Add local model")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (showImportingDialog) {
|
||||
curFileUri.value?.let { uri ->
|
||||
ModelImportDialog(uri = uri, onDone = { info ->
|
||||
showImportingDialog = false
|
||||
|
||||
if (info.error.isEmpty()) {
|
||||
// TODO: support other model types.
|
||||
modelManagerViewModel.addLocalLlmModel(
|
||||
task = task,
|
||||
fileName = info.fileName,
|
||||
fileSize = info.fileSize
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,8 +23,10 @@ import androidx.activity.result.ActivityResult
|
|||
import androidx.lifecycle.ViewModel
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import com.google.aiedge.gallery.data.AGWorkInfo
|
||||
import com.google.aiedge.gallery.data.Accelerator
|
||||
import com.google.aiedge.gallery.data.AccessTokenData
|
||||
import com.google.aiedge.gallery.data.Config
|
||||
import com.google.aiedge.gallery.data.ConfigKey
|
||||
import com.google.aiedge.gallery.data.DataStoreRepository
|
||||
import com.google.aiedge.gallery.data.DownloadRepository
|
||||
import com.google.aiedge.gallery.data.EMPTY_MODEL
|
||||
|
@ -32,7 +34,7 @@ import com.google.aiedge.gallery.data.HfModel
|
|||
import com.google.aiedge.gallery.data.HfModelDetails
|
||||
import com.google.aiedge.gallery.data.HfModelSummary
|
||||
import com.google.aiedge.gallery.data.IMPORTS_DIR
|
||||
import com.google.aiedge.gallery.data.LocalModelInfo
|
||||
import com.google.aiedge.gallery.data.ImportedModelInfo
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.ModelDownloadStatus
|
||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||
|
@ -40,12 +42,14 @@ import com.google.aiedge.gallery.data.TASKS
|
|||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.data.TaskType
|
||||
import com.google.aiedge.gallery.data.ValueType
|
||||
import com.google.aiedge.gallery.data.getModelByName
|
||||
import com.google.aiedge.gallery.ui.common.AuthConfig
|
||||
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
|
||||
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper
|
||||
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper
|
||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
|
||||
import com.google.aiedge.gallery.ui.llmchat.createLLmChatConfig
|
||||
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
|
||||
import com.google.aiedge.gallery.ui.textclassification.TextClassificationModelHelper
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.async
|
||||
|
@ -228,7 +232,7 @@ open class ModelManagerViewModel(
|
|||
ModelDownloadStatus(status = ModelDownloadStatusType.NOT_DOWNLOADED)
|
||||
|
||||
// Delete model from the list if model is imported as a local model.
|
||||
if (model.isLocalModel) {
|
||||
if (model.imported) {
|
||||
val index = task.models.indexOf(model)
|
||||
if (index >= 0) {
|
||||
task.models.removeAt(index)
|
||||
|
@ -237,12 +241,12 @@ open class ModelManagerViewModel(
|
|||
curModelDownloadStatus.remove(model.name)
|
||||
|
||||
// Update preference.
|
||||
val localModels = dataStoreRepository.readLocalModels().toMutableList()
|
||||
val localModelIndex = localModels.indexOfFirst { it.fileName == model.name }
|
||||
if (localModelIndex >= 0) {
|
||||
localModels.removeAt(localModelIndex)
|
||||
val importedModels = dataStoreRepository.readImportedModels().toMutableList()
|
||||
val importedModelIndex = importedModels.indexOfFirst { it.fileName == model.name }
|
||||
if (importedModelIndex >= 0) {
|
||||
importedModels.removeAt(importedModelIndex)
|
||||
}
|
||||
dataStoreRepository.saveLocalModels(localModels = localModels)
|
||||
dataStoreRepository.saveImportedModels(importedModels = importedModels)
|
||||
}
|
||||
val newUiState = uiState.value.copy(modelDownloadStatus = curModelDownloadStatus)
|
||||
_uiState.update { newUiState }
|
||||
|
@ -417,27 +421,20 @@ open class ModelManagerViewModel(
|
|||
return connection.responseCode
|
||||
}
|
||||
|
||||
fun addLocalLlmModel(task: Task, fileName: String, fileSize: Long) {
|
||||
Log.d(TAG, "adding local model: $fileName, $fileSize")
|
||||
fun addImportedLlmModel(task: Task, info: ImportedModelInfo) {
|
||||
Log.d(TAG, "adding imported llm model: $info")
|
||||
|
||||
// Create model.
|
||||
val configs: List<Config> = createLLmChatConfig(defaults = mapOf())
|
||||
val model = Model(
|
||||
name = fileName,
|
||||
url = "",
|
||||
configs = configs,
|
||||
sizeInBytes = fileSize,
|
||||
downloadFileName = "$IMPORTS_DIR/$fileName",
|
||||
isLocalModel = true,
|
||||
)
|
||||
model.preProcess(task = task)
|
||||
val model = createModelFromImportedModelInfo(info = info, task = task)
|
||||
task.models.add(model)
|
||||
|
||||
// Add initial status and states.
|
||||
val modelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap()
|
||||
val modelInstances = uiState.value.modelInitializationStatus.toMutableMap()
|
||||
modelDownloadStatus[model.name] = ModelDownloadStatus(
|
||||
status = ModelDownloadStatusType.SUCCEEDED, receivedBytes = fileSize, totalBytes = fileSize
|
||||
status = ModelDownloadStatusType.SUCCEEDED,
|
||||
receivedBytes = info.fileSize,
|
||||
totalBytes = info.fileSize
|
||||
)
|
||||
modelInstances[model.name] =
|
||||
ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED)
|
||||
|
@ -453,9 +450,9 @@ open class ModelManagerViewModel(
|
|||
task.updateTrigger.value = System.currentTimeMillis()
|
||||
|
||||
// Add to preference storage.
|
||||
val localModels = dataStoreRepository.readLocalModels().toMutableList()
|
||||
localModels.add(LocalModelInfo(fileName = fileName, fileSize = fileSize))
|
||||
dataStoreRepository.saveLocalModels(localModels = localModels)
|
||||
val importedModels = dataStoreRepository.readImportedModels().toMutableList()
|
||||
importedModels.add(info)
|
||||
dataStoreRepository.saveImportedModels(importedModels = importedModels)
|
||||
}
|
||||
|
||||
fun getTokenStatusAndData(): TokenStatusAndData {
|
||||
|
@ -589,31 +586,22 @@ open class ModelManagerViewModel(
|
|||
}
|
||||
}
|
||||
|
||||
// Load local models.
|
||||
for (localModel in dataStoreRepository.readLocalModels()) {
|
||||
Log.d(TAG, "stored local model: $localModel")
|
||||
// Load imported models.
|
||||
for (importedModel in dataStoreRepository.readImportedModels()) {
|
||||
Log.d(TAG, "stored imported model: $importedModel")
|
||||
|
||||
// Create model.
|
||||
val configs: List<Config> = createLLmChatConfig(defaults = mapOf())
|
||||
val model = Model(
|
||||
name = localModel.fileName,
|
||||
url = "",
|
||||
configs = configs,
|
||||
sizeInBytes = localModel.fileSize,
|
||||
downloadFileName = "$IMPORTS_DIR/${localModel.fileName}",
|
||||
isLocalModel = true,
|
||||
)
|
||||
val model = createModelFromImportedModelInfo(info = importedModel, task = TASK_LLM_CHAT)
|
||||
|
||||
// Add to task.
|
||||
val task = TASK_LLM_CHAT
|
||||
model.preProcess(task = task)
|
||||
task.models.add(model)
|
||||
|
||||
// Update status.
|
||||
modelDownloadStatus[model.name] = ModelDownloadStatus(
|
||||
status = ModelDownloadStatusType.SUCCEEDED,
|
||||
receivedBytes = localModel.fileSize,
|
||||
totalBytes = localModel.fileSize
|
||||
receivedBytes = importedModel.fileSize,
|
||||
totalBytes = importedModel.fileSize
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -628,6 +616,51 @@ open class ModelManagerViewModel(
|
|||
)
|
||||
}
|
||||
|
||||
private fun createModelFromImportedModelInfo(info: ImportedModelInfo, task: Task): Model {
|
||||
val accelerators: List<Accelerator> = (convertValueToTargetType(
|
||||
info.defaultValues[ConfigKey.COMPATIBLE_ACCELERATORS.label]!!,
|
||||
ValueType.STRING
|
||||
) as String)
|
||||
.split(",")
|
||||
.mapNotNull { acceleratorLabel ->
|
||||
when (acceleratorLabel.trim()) {
|
||||
Accelerator.GPU.label -> Accelerator.GPU
|
||||
Accelerator.CPU.label -> Accelerator.CPU
|
||||
else -> null // Ignore unknown accelerator labels
|
||||
}
|
||||
}
|
||||
val configs: List<Config> = createLlmChatConfigs(
|
||||
defaultMaxToken = convertValueToTargetType(
|
||||
info.defaultValues[ConfigKey.DEFAULT_MAX_TOKENS.label]!!,
|
||||
ValueType.INT
|
||||
) as Int,
|
||||
defaultTopK = convertValueToTargetType(
|
||||
info.defaultValues[ConfigKey.DEFAULT_TOPK.label]!!,
|
||||
ValueType.INT
|
||||
) as Int,
|
||||
defaultTopP = convertValueToTargetType(
|
||||
info.defaultValues[ConfigKey.DEFAULT_TOPP.label]!!,
|
||||
ValueType.FLOAT
|
||||
) as Float,
|
||||
defaultTemperature = convertValueToTargetType(
|
||||
info.defaultValues[ConfigKey.DEFAULT_TEMPERATURE.label]!!,
|
||||
ValueType.FLOAT
|
||||
) as Float,
|
||||
accelerators = accelerators,
|
||||
)
|
||||
val model = Model(
|
||||
name = info.fileName,
|
||||
url = "",
|
||||
configs = configs,
|
||||
sizeInBytes = info.fileSize,
|
||||
downloadFileName = "$IMPORTS_DIR/${info.fileName}",
|
||||
imported = true,
|
||||
)
|
||||
model.preProcess(task = task)
|
||||
|
||||
return model
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the download status of a model.
|
||||
*
|
||||
|
@ -771,9 +804,7 @@ open class ModelManagerViewModel(
|
|||
}
|
||||
|
||||
private fun updateModelInitializationStatus(
|
||||
model: Model,
|
||||
status: ModelInitializationStatusType,
|
||||
error: String = ""
|
||||
model: Model, status: ModelInitializationStatusType, error: String = ""
|
||||
) {
|
||||
val curModelInstance = uiState.value.modelInitializationStatus.toMutableMap()
|
||||
curModelInstance[model.name] = ModelInitializationStatus(status = status, error = error)
|
||||
|
|
|
@ -18,7 +18,7 @@ package com.google.aiedge.gallery.ui.preview
|
|||
|
||||
import com.google.aiedge.gallery.data.AccessTokenData
|
||||
import com.google.aiedge.gallery.data.DataStoreRepository
|
||||
import com.google.aiedge.gallery.data.LocalModelInfo
|
||||
import com.google.aiedge.gallery.data.ImportedModelInfo
|
||||
|
||||
class PreviewDataStoreRepository : DataStoreRepository {
|
||||
override fun saveTextInputHistory(history: List<String>) {
|
||||
|
@ -42,10 +42,10 @@ class PreviewDataStoreRepository : DataStoreRepository {
|
|||
return null
|
||||
}
|
||||
|
||||
override fun saveLocalModels(localModels: List<LocalModelInfo>) {
|
||||
override fun saveImportedModels(importedModels: List<ImportedModelInfo>) {
|
||||
}
|
||||
|
||||
override fun readLocalModels(): List<LocalModelInfo> {
|
||||
override fun readImportedModels(): List<ImportedModelInfo> {
|
||||
return listOf()
|
||||
}
|
||||
}
|
|
@ -22,6 +22,7 @@ import androidx.compose.material.icons.rounded.AutoAwesome
|
|||
import com.google.aiedge.gallery.data.BooleanSwitchConfig
|
||||
import com.google.aiedge.gallery.data.Config
|
||||
import com.google.aiedge.gallery.data.ConfigKey
|
||||
import com.google.aiedge.gallery.data.LabelConfig
|
||||
import com.google.aiedge.gallery.data.SegmentedButtonConfig
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.NumberSliderConfig
|
||||
|
@ -30,6 +31,10 @@ import com.google.aiedge.gallery.data.TaskType
|
|||
import com.google.aiedge.gallery.data.ValueType
|
||||
|
||||
val TEST_CONFIGS1: List<Config> = listOf(
|
||||
LabelConfig(
|
||||
key = ConfigKey.NAME,
|
||||
defaultValue = "Test name",
|
||||
),
|
||||
NumberSliderConfig(
|
||||
key = ConfigKey.MAX_RESULT_COUNT,
|
||||
sliderMin = 1f,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue