mirror of
https://github.com/google-ai-edge/gallery.git
synced 2025-07-12 17:32:30 -04:00
Make importing model functionality better.
- Allow users to specify default parameters before importing.
This commit is contained in:
parent
29b614355e
commit
604972fe23
15 changed files with 635 additions and 329 deletions
|
@ -23,6 +23,7 @@ package com.google.aiedge.gallery.data
|
||||||
* Each type corresponds to a specific editor widget, such as a slider or a switch.
|
* Each type corresponds to a specific editor widget, such as a slider or a switch.
|
||||||
*/
|
*/
|
||||||
enum class ConfigEditorType {
|
enum class ConfigEditorType {
|
||||||
|
LABEL,
|
||||||
NUMBER_SLIDER,
|
NUMBER_SLIDER,
|
||||||
BOOLEAN_SWITCH,
|
BOOLEAN_SWITCH,
|
||||||
DROPDOWN,
|
DROPDOWN,
|
||||||
|
@ -57,6 +58,19 @@ open class Config(
|
||||||
open val needReinitialization: Boolean = true,
|
open val needReinitialization: Boolean = true,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Configuration setting for a label.
|
||||||
|
*/
|
||||||
|
class LabelConfig(
|
||||||
|
override val key: ConfigKey,
|
||||||
|
override val defaultValue: String = "",
|
||||||
|
) : Config(
|
||||||
|
type = ConfigEditorType.LABEL,
|
||||||
|
key = key,
|
||||||
|
defaultValue = defaultValue,
|
||||||
|
valueType = ValueType.STRING
|
||||||
|
)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Configuration setting for a number slider.
|
* Configuration setting for a number slider.
|
||||||
*
|
*
|
||||||
|
@ -99,9 +113,11 @@ class SegmentedButtonConfig(
|
||||||
override val key: ConfigKey,
|
override val key: ConfigKey,
|
||||||
override val defaultValue: String,
|
override val defaultValue: String,
|
||||||
val options: List<String>,
|
val options: List<String>,
|
||||||
|
val allowMultiple: Boolean = false,
|
||||||
) : Config(
|
) : Config(
|
||||||
type = ConfigEditorType.DROPDOWN,
|
type = ConfigEditorType.DROPDOWN,
|
||||||
key = key,
|
key = key,
|
||||||
defaultValue = defaultValue,
|
defaultValue = defaultValue,
|
||||||
|
// The emitted value will be comma-separated labels when allowMultiple=true.
|
||||||
valueType = ValueType.STRING,
|
valueType = ValueType.STRING,
|
||||||
)
|
)
|
|
@ -47,8 +47,8 @@ interface DataStoreRepository {
|
||||||
fun readThemeOverride(): String
|
fun readThemeOverride(): String
|
||||||
fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long)
|
fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long)
|
||||||
fun readAccessTokenData(): AccessTokenData?
|
fun readAccessTokenData(): AccessTokenData?
|
||||||
fun saveLocalModels(localModels: List<LocalModelInfo>)
|
fun saveImportedModels(importedModels: List<ImportedModelInfo>)
|
||||||
fun readLocalModels(): List<LocalModelInfo>
|
fun readImportedModels(): List<ImportedModelInfo>
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -82,8 +82,8 @@ class DefaultDataStoreRepository(
|
||||||
|
|
||||||
val ACCESS_TOKEN_EXPIRES_AT = longPreferencesKey("access_token_expires_at")
|
val ACCESS_TOKEN_EXPIRES_AT = longPreferencesKey("access_token_expires_at")
|
||||||
|
|
||||||
// Data for all imported local models.
|
// Data for all imported models.
|
||||||
val LOCAL_MODELS = stringPreferencesKey("local_models")
|
val IMPORTED_MODELS = stringPreferencesKey("imported_models")
|
||||||
}
|
}
|
||||||
|
|
||||||
private val keystoreAlias: String = "com_google_aiedge_gallery_access_token_key"
|
private val keystoreAlias: String = "com_google_aiedge_gallery_access_token_key"
|
||||||
|
@ -160,22 +160,22 @@ class DefaultDataStoreRepository(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun saveLocalModels(localModels: List<LocalModelInfo>) {
|
override fun saveImportedModels(importedModels: List<ImportedModelInfo>) {
|
||||||
runBlocking {
|
runBlocking {
|
||||||
dataStore.edit { preferences ->
|
dataStore.edit { preferences ->
|
||||||
val gson = Gson()
|
val gson = Gson()
|
||||||
val jsonString = gson.toJson(localModels)
|
val jsonString = gson.toJson(importedModels)
|
||||||
preferences[PreferencesKeys.LOCAL_MODELS] = jsonString
|
preferences[PreferencesKeys.IMPORTED_MODELS] = jsonString
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun readLocalModels(): List<LocalModelInfo> {
|
override fun readImportedModels(): List<ImportedModelInfo> {
|
||||||
return runBlocking {
|
return runBlocking {
|
||||||
val preferences = dataStore.data.first()
|
val preferences = dataStore.data.first()
|
||||||
val infosStr = preferences[PreferencesKeys.LOCAL_MODELS] ?: "[]"
|
val infosStr = preferences[PreferencesKeys.IMPORTED_MODELS] ?: "[]"
|
||||||
val gson = Gson()
|
val gson = Gson()
|
||||||
val listType = object : TypeToken<List<LocalModelInfo>>() {}.type
|
val listType = object : TypeToken<List<ImportedModelInfo>>() {}.type
|
||||||
gson.fromJson(infosStr, listType)
|
gson.fromJson(infosStr, listType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
package com.google.aiedge.gallery.data
|
package com.google.aiedge.gallery.data
|
||||||
|
|
||||||
import com.google.aiedge.gallery.ui.common.ensureValidFileName
|
import com.google.aiedge.gallery.ui.common.ensureValidFileName
|
||||||
import com.google.aiedge.gallery.ui.llmchat.createLLmChatConfig
|
|
||||||
import kotlinx.serialization.KSerializer
|
import kotlinx.serialization.KSerializer
|
||||||
import kotlinx.serialization.Serializable
|
import kotlinx.serialization.Serializable
|
||||||
import kotlinx.serialization.SerializationException
|
import kotlinx.serialization.SerializationException
|
||||||
|
@ -107,11 +106,13 @@ data class HfModel(
|
||||||
val fileName = ensureValidFileName("${id}_${(parts.lastOrNull() ?: "")}")
|
val fileName = ensureValidFileName("${id}_${(parts.lastOrNull() ?: "")}")
|
||||||
|
|
||||||
// Generate configs based on the given default values.
|
// Generate configs based on the given default values.
|
||||||
val configs: List<Config> = when (task) {
|
// val configs: List<Config> = when (task) {
|
||||||
TASK_LLM_CHAT.type.label -> createLLmChatConfig(defaults = configs)
|
// TASK_LLM_CHAT.type.label -> createLLmChatConfig(defaults = configs)
|
||||||
// todo: add configs for other types.
|
// // todo: add configs for other types.
|
||||||
else -> listOf()
|
// else -> listOf()
|
||||||
}
|
// }
|
||||||
|
// todo: fix when loading from models.json
|
||||||
|
val configs: List<Config> = listOf()
|
||||||
|
|
||||||
// Construct url.
|
// Construct url.
|
||||||
var modelUrl = url
|
var modelUrl = url
|
||||||
|
|
|
@ -19,6 +19,7 @@ package com.google.aiedge.gallery.data
|
||||||
import android.content.Context
|
import android.content.Context
|
||||||
import com.google.aiedge.gallery.ui.common.chat.PromptTemplate
|
import com.google.aiedge.gallery.ui.common.chat.PromptTemplate
|
||||||
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
|
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
|
||||||
|
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_ACCELERATORS
|
||||||
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
|
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
|
||||||
|
|
||||||
data class ModelDataFile(
|
data class ModelDataFile(
|
||||||
|
@ -28,8 +29,8 @@ data class ModelDataFile(
|
||||||
val sizeInBytes: Long,
|
val sizeInBytes: Long,
|
||||||
)
|
)
|
||||||
|
|
||||||
enum class LlmBackend {
|
enum class Accelerator(val label: String) {
|
||||||
CPU, GPU
|
CPU(label = "CPU"), GPU(label = "GPU")
|
||||||
}
|
}
|
||||||
|
|
||||||
const val IMPORTS_DIR = "__imports"
|
const val IMPORTS_DIR = "__imports"
|
||||||
|
@ -81,14 +82,14 @@ data class Model(
|
||||||
/** The name of the directory to unzip the model to (if it's a zip file). */
|
/** The name of the directory to unzip the model to (if it's a zip file). */
|
||||||
val unzipDir: String = "",
|
val unzipDir: String = "",
|
||||||
|
|
||||||
/** The preferred backend of the model (only for LLM). */
|
/** The accelerators the the model can run with. */
|
||||||
val llmBackend: LlmBackend = LlmBackend.GPU,
|
val accelerators: List<Accelerator> = DEFAULT_ACCELERATORS,
|
||||||
|
|
||||||
/** The prompt templates for the model (only for LLM). */
|
/** The prompt templates for the model (only for LLM). */
|
||||||
val llmPromptTemplates: List<PromptTemplate> = listOf(),
|
val llmPromptTemplates: List<PromptTemplate> = listOf(),
|
||||||
|
|
||||||
/** Whether the model is imported as a local model. */
|
/** Whether the model is imported or not. */
|
||||||
val isLocalModel: Boolean = false,
|
val imported: Boolean = false,
|
||||||
|
|
||||||
// The following fields are managed by the app. Don't need to set manually.
|
// The following fields are managed by the app. Don't need to set manually.
|
||||||
var taskType: TaskType? = null,
|
var taskType: TaskType? = null,
|
||||||
|
@ -135,6 +136,12 @@ data class Model(
|
||||||
) as Boolean
|
) as Boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun getStringConfigValue(key: ConfigKey, defaultValue: String = ""): String {
|
||||||
|
return getTypedConfigValue(
|
||||||
|
key = key, valueType = ValueType.STRING, defaultValue = defaultValue
|
||||||
|
) as String
|
||||||
|
}
|
||||||
|
|
||||||
fun getExtraDataFile(name: String): ModelDataFile? {
|
fun getExtraDataFile(name: String): ModelDataFile? {
|
||||||
return extraDataFiles.find { it.name == name }
|
return extraDataFiles.find { it.name == name }
|
||||||
}
|
}
|
||||||
|
@ -147,7 +154,11 @@ data class Model(
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Data for a imported local model. */
|
/** Data for a imported local model. */
|
||||||
data class LocalModelInfo(val fileName: String, val fileSize: Long)
|
data class ImportedModelInfo(
|
||||||
|
val fileName: String,
|
||||||
|
val fileSize: Long,
|
||||||
|
val defaultValues: Map<String, Any>
|
||||||
|
)
|
||||||
|
|
||||||
enum class ModelDownloadStatusType {
|
enum class ModelDownloadStatusType {
|
||||||
NOT_DOWNLOADED, PARTIALLY_DOWNLOADED, IN_PROGRESS, UNZIPPING, SUCCEEDED, FAILED,
|
NOT_DOWNLOADED, PARTIALLY_DOWNLOADED, IN_PROGRESS, UNZIPPING, SUCCEEDED, FAILED,
|
||||||
|
@ -165,29 +176,25 @@ data class ModelDownloadStatus(
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
// Configs.
|
// Configs.
|
||||||
|
|
||||||
enum class ConfigKey(val label: String, val id: String) {
|
enum class ConfigKey(val label: String) {
|
||||||
MAX_TOKENS("Max tokens", id = "max_token"),
|
MAX_TOKENS("Max tokens"),
|
||||||
TOPK("TopK", id = "topk"),
|
TOPK("TopK"),
|
||||||
TOPP(
|
TOPP("TopP"),
|
||||||
"TopP",
|
TEMPERATURE("Temperature"),
|
||||||
id = "topp"
|
DEFAULT_MAX_TOKENS("Default max tokens"),
|
||||||
),
|
DEFAULT_TOPK("Default TopK"),
|
||||||
TEMPERATURE("Temperature", id = "temperature"),
|
DEFAULT_TOPP("Default TopP"),
|
||||||
MAX_RESULT_COUNT(
|
DEFAULT_TEMPERATURE("Default temperature"),
|
||||||
"Max result count",
|
MAX_RESULT_COUNT("Max result count"),
|
||||||
id = "max_result_count"
|
USE_GPU("Use GPU"),
|
||||||
),
|
ACCELERATOR("Accelerator"),
|
||||||
USE_GPU("Use GPU", id = "use_gpu"),
|
COMPATIBLE_ACCELERATORS("Compatible accelerators"),
|
||||||
WARM_UP_ITERATIONS(
|
WARM_UP_ITERATIONS("Warm up iterations"),
|
||||||
"Warm up iterations",
|
BENCHMARK_ITERATIONS("Benchmark iterations"),
|
||||||
id = "warm_up_iterations"
|
ITERATIONS("Iterations"),
|
||||||
),
|
THEME("Theme"),
|
||||||
BENCHMARK_ITERATIONS(
|
NAME("Name"),
|
||||||
"Benchmark iterations",
|
MODEL_TYPE("Model type")
|
||||||
id = "benchmark_iterations"
|
|
||||||
),
|
|
||||||
ITERATIONS("Iterations", id = "iterations"),
|
|
||||||
THEME("Theme", id = "theme"),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
val MOBILENET_CONFIGS: List<Config> = listOf(
|
val MOBILENET_CONFIGS: List<Config> = listOf(
|
||||||
|
@ -258,7 +265,12 @@ val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model(
|
||||||
downloadFileName = "gemma3-1b-it-int4.task",
|
downloadFileName = "gemma3-1b-it-int4.task",
|
||||||
url = "https://huggingface.co/litert-community/Gemma3-1B-IT/resolve/main/gemma3-1b-it-int4.task?download=true",
|
url = "https://huggingface.co/litert-community/Gemma3-1B-IT/resolve/main/gemma3-1b-it-int4.task?download=true",
|
||||||
sizeInBytes = 554661243L,
|
sizeInBytes = 554661243L,
|
||||||
configs = createLlmChatConfigs(defaultTopK = 64, defaultTopP = 0.95f),
|
accelerators = listOf(Accelerator.CPU, Accelerator.GPU),
|
||||||
|
configs = createLlmChatConfigs(
|
||||||
|
defaultTopK = 64,
|
||||||
|
defaultTopP = 0.95f,
|
||||||
|
accelerators = listOf(Accelerator.CPU, Accelerator.GPU)
|
||||||
|
),
|
||||||
info = LLM_CHAT_INFO,
|
info = LLM_CHAT_INFO,
|
||||||
learnMoreUrl = "https://huggingface.co/litert-community/Gemma3-1B-IT",
|
learnMoreUrl = "https://huggingface.co/litert-community/Gemma3-1B-IT",
|
||||||
llmPromptTemplates = listOf(
|
llmPromptTemplates = listOf(
|
||||||
|
@ -280,8 +292,13 @@ val MODEL_LLM_DEEPSEEK: Model = Model(
|
||||||
downloadFileName = "deepseek.task",
|
downloadFileName = "deepseek.task",
|
||||||
url = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B/resolve/main/deepseek_q8_ekv1280.task?download=true",
|
url = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B/resolve/main/deepseek_q8_ekv1280.task?download=true",
|
||||||
sizeInBytes = 1860686856L,
|
sizeInBytes = 1860686856L,
|
||||||
llmBackend = LlmBackend.CPU,
|
accelerators = listOf(Accelerator.CPU),
|
||||||
configs = createLlmChatConfigs(defaultTemperature = 0.6f, defaultTopK = 40, defaultTopP = 0.7f),
|
configs = createLlmChatConfigs(
|
||||||
|
defaultTemperature = 0.6f,
|
||||||
|
defaultTopK = 40,
|
||||||
|
defaultTopP = 0.7f,
|
||||||
|
accelerators = listOf(Accelerator.CPU)
|
||||||
|
),
|
||||||
info = LLM_CHAT_INFO,
|
info = LLM_CHAT_INFO,
|
||||||
learnMoreUrl = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B",
|
learnMoreUrl = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B",
|
||||||
)
|
)
|
||||||
|
|
|
@ -34,16 +34,15 @@ import androidx.compose.foundation.text.KeyboardOptions
|
||||||
import androidx.compose.material3.Button
|
import androidx.compose.material3.Button
|
||||||
import androidx.compose.material3.Card
|
import androidx.compose.material3.Card
|
||||||
import androidx.compose.material3.MaterialTheme
|
import androidx.compose.material3.MaterialTheme
|
||||||
|
import androidx.compose.material3.MultiChoiceSegmentedButtonRow
|
||||||
import androidx.compose.material3.SegmentedButton
|
import androidx.compose.material3.SegmentedButton
|
||||||
import androidx.compose.material3.SegmentedButtonDefaults
|
import androidx.compose.material3.SegmentedButtonDefaults
|
||||||
import androidx.compose.material3.SingleChoiceSegmentedButtonRow
|
|
||||||
import androidx.compose.material3.Slider
|
import androidx.compose.material3.Slider
|
||||||
import androidx.compose.material3.Switch
|
import androidx.compose.material3.Switch
|
||||||
import androidx.compose.material3.Text
|
import androidx.compose.material3.Text
|
||||||
import androidx.compose.material3.TextButton
|
import androidx.compose.material3.TextButton
|
||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
import androidx.compose.runtime.getValue
|
import androidx.compose.runtime.getValue
|
||||||
import androidx.compose.runtime.mutableIntStateOf
|
|
||||||
import androidx.compose.runtime.mutableStateMapOf
|
import androidx.compose.runtime.mutableStateMapOf
|
||||||
import androidx.compose.runtime.mutableStateOf
|
import androidx.compose.runtime.mutableStateOf
|
||||||
import androidx.compose.runtime.remember
|
import androidx.compose.runtime.remember
|
||||||
|
@ -60,6 +59,7 @@ import androidx.compose.ui.unit.dp
|
||||||
import androidx.compose.ui.window.Dialog
|
import androidx.compose.ui.window.Dialog
|
||||||
import com.google.aiedge.gallery.data.BooleanSwitchConfig
|
import com.google.aiedge.gallery.data.BooleanSwitchConfig
|
||||||
import com.google.aiedge.gallery.data.Config
|
import com.google.aiedge.gallery.data.Config
|
||||||
|
import com.google.aiedge.gallery.data.LabelConfig
|
||||||
import com.google.aiedge.gallery.data.NumberSliderConfig
|
import com.google.aiedge.gallery.data.NumberSliderConfig
|
||||||
import com.google.aiedge.gallery.data.SegmentedButtonConfig
|
import com.google.aiedge.gallery.data.SegmentedButtonConfig
|
||||||
import com.google.aiedge.gallery.data.ValueType
|
import com.google.aiedge.gallery.data.ValueType
|
||||||
|
@ -113,27 +113,10 @@ fun ConfigDialog(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// List of config rows.
|
// List of config rows.
|
||||||
for (config in configs) {
|
ConfigEditorsPanel(configs = configs, values = values)
|
||||||
when (config) {
|
|
||||||
// Number slider.
|
|
||||||
is NumberSliderConfig -> {
|
|
||||||
NumberSliderRow(config = config, values = values)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Boolean switch.
|
// Button row.
|
||||||
is BooleanSwitchConfig -> {
|
|
||||||
BooleanSwitchRow(config = config, values = values)
|
|
||||||
}
|
|
||||||
|
|
||||||
is SegmentedButtonConfig -> {
|
|
||||||
SegmentedButtonRow(config = config, values = values)
|
|
||||||
}
|
|
||||||
|
|
||||||
else -> {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Row(
|
Row(
|
||||||
modifier = Modifier
|
modifier = Modifier
|
||||||
.fillMaxWidth()
|
.fillMaxWidth()
|
||||||
|
@ -164,6 +147,53 @@ fun ConfigDialog(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Composable function to display a list of config editor rows.
|
||||||
|
*/
|
||||||
|
@Composable
|
||||||
|
fun ConfigEditorsPanel(configs: List<Config>, values: SnapshotStateMap<String, Any>) {
|
||||||
|
for (config in configs) {
|
||||||
|
when (config) {
|
||||||
|
// Label.
|
||||||
|
is LabelConfig -> {
|
||||||
|
LabelRow(config = config, values = values)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Number slider.
|
||||||
|
is NumberSliderConfig -> {
|
||||||
|
NumberSliderRow(config = config, values = values)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Boolean switch.
|
||||||
|
is BooleanSwitchConfig -> {
|
||||||
|
BooleanSwitchRow(config = config, values = values)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Segmented button.
|
||||||
|
is SegmentedButtonConfig -> {
|
||||||
|
SegmentedButtonRow(config = config, values = values)
|
||||||
|
}
|
||||||
|
|
||||||
|
else -> {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
fun LabelRow(config: LabelConfig, values: SnapshotStateMap<String, Any>) {
|
||||||
|
Column(modifier = Modifier.fillMaxWidth()) {
|
||||||
|
// Field label.
|
||||||
|
Text(config.key.label, style = MaterialTheme.typography.titleSmall)
|
||||||
|
// Content label.
|
||||||
|
val label = try {
|
||||||
|
values[config.key.label] as String
|
||||||
|
} catch (e: Exception) {
|
||||||
|
""
|
||||||
|
}
|
||||||
|
Text(label, style = MaterialTheme.typography.bodyMedium)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Composable function to display a number slider with an associated text input field.
|
* Composable function to display a number slider with an associated text input field.
|
||||||
*
|
*
|
||||||
|
@ -272,18 +302,41 @@ fun BooleanSwitchRow(config: BooleanSwitchConfig, values: SnapshotStateMap<Strin
|
||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
fun SegmentedButtonRow(config: SegmentedButtonConfig, values: SnapshotStateMap<String, Any>) {
|
fun SegmentedButtonRow(config: SegmentedButtonConfig, values: SnapshotStateMap<String, Any>) {
|
||||||
var selectedIndex by remember { mutableIntStateOf(config.options.indexOf(values[config.key.label])) }
|
val selectedOptions: List<String> = remember { (values[config.key.label] as String).split(",") }
|
||||||
|
var selectionStates: List<Boolean> by remember {
|
||||||
|
mutableStateOf(List(config.options.size) { index ->
|
||||||
|
selectedOptions.contains(config.options[index])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
Column(modifier = Modifier.fillMaxWidth()) {
|
Column(modifier = Modifier.fillMaxWidth()) {
|
||||||
Text(config.key.label, style = MaterialTheme.typography.titleSmall)
|
Text(config.key.label, style = MaterialTheme.typography.titleSmall)
|
||||||
SingleChoiceSegmentedButtonRow {
|
MultiChoiceSegmentedButtonRow {
|
||||||
config.options.forEachIndexed { index, label ->
|
config.options.forEachIndexed { index, label ->
|
||||||
SegmentedButton(shape = SegmentedButtonDefaults.itemShape(
|
SegmentedButton(shape = SegmentedButtonDefaults.itemShape(
|
||||||
index = index, count = config.options.size
|
index = index, count = config.options.size
|
||||||
), onClick = {
|
), onCheckedChange = {
|
||||||
selectedIndex = index
|
var newSelectionStates = selectionStates.toMutableList()
|
||||||
values[config.key.label] = label
|
val selectedCount = newSelectionStates.count { it }
|
||||||
}, selected = index == selectedIndex, label = { Text(label) })
|
|
||||||
|
// Single select.
|
||||||
|
if (!config.allowMultiple) {
|
||||||
|
if (!newSelectionStates[index]) {
|
||||||
|
newSelectionStates = MutableList(config.options.size) { it == index }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Multiple select.
|
||||||
|
else {
|
||||||
|
if (!(selectedCount == 1 && newSelectionStates[index])) {
|
||||||
|
newSelectionStates[index] = !newSelectionStates[index]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
selectionStates = newSelectionStates
|
||||||
|
|
||||||
|
values[config.key.label] =
|
||||||
|
config.options.filterIndexed { index, option -> selectionStates[index] }
|
||||||
|
.joinToString(",")
|
||||||
|
}, checked = selectionStates[index], label = { Text(label) })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -92,7 +92,6 @@ private val DEFAULT_VERTICAL_PADDING = 16.dp
|
||||||
* model description and buttons for learning more (opening a URL) and downloading/trying
|
* model description and buttons for learning more (opening a URL) and downloading/trying
|
||||||
* the model.
|
* the model.
|
||||||
*/
|
*/
|
||||||
@OptIn(ExperimentalMaterial3Api::class)
|
|
||||||
@Composable
|
@Composable
|
||||||
fun ModelItem(
|
fun ModelItem(
|
||||||
model: Model,
|
model: Model,
|
||||||
|
@ -188,9 +187,9 @@ fun ModelItem(
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Icon(
|
Icon(
|
||||||
// For local model, show ">" directly indicating users can just tap the model item to
|
// For imported model, show ">" directly indicating users can just tap the model item to
|
||||||
// go into it without needing to expand it first.
|
// go into it without needing to expand it first.
|
||||||
if (model.isLocalModel) Icons.Rounded.ChevronRight else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore,
|
if (model.imported) Icons.Rounded.ChevronRight else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore,
|
||||||
contentDescription = "",
|
contentDescription = "",
|
||||||
tint = getTaskIconColor(task),
|
tint = getTaskIconColor(task),
|
||||||
)
|
)
|
||||||
|
@ -272,7 +271,7 @@ fun ModelItem(
|
||||||
boxModifier = if (canExpand) {
|
boxModifier = if (canExpand) {
|
||||||
boxModifier.clickable(
|
boxModifier.clickable(
|
||||||
onClick = {
|
onClick = {
|
||||||
if (!model.isLocalModel) {
|
if (!model.imported) {
|
||||||
isExpanded = !isExpanded
|
isExpanded = !isExpanded
|
||||||
} else {
|
} else {
|
||||||
onModelClicked(model)
|
onModelClicked(model)
|
||||||
|
|
|
@ -16,13 +16,22 @@
|
||||||
|
|
||||||
package com.google.aiedge.gallery.ui.home
|
package com.google.aiedge.gallery.ui.home
|
||||||
|
|
||||||
|
import android.content.Intent
|
||||||
|
import android.net.Uri
|
||||||
|
import android.util.Log
|
||||||
|
import androidx.activity.compose.rememberLauncherForActivityResult
|
||||||
|
import androidx.activity.result.ActivityResultLauncher
|
||||||
|
import androidx.activity.result.contract.ActivityResultContracts
|
||||||
import androidx.annotation.StringRes
|
import androidx.annotation.StringRes
|
||||||
|
import androidx.compose.animation.core.animateFloatAsState
|
||||||
|
import androidx.compose.animation.core.tween
|
||||||
import androidx.compose.foundation.background
|
import androidx.compose.foundation.background
|
||||||
import androidx.compose.foundation.clickable
|
import androidx.compose.foundation.clickable
|
||||||
import androidx.compose.foundation.layout.Arrangement
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
import androidx.compose.foundation.layout.Box
|
import androidx.compose.foundation.layout.Box
|
||||||
import androidx.compose.foundation.layout.Column
|
import androidx.compose.foundation.layout.Column
|
||||||
import androidx.compose.foundation.layout.PaddingValues
|
import androidx.compose.foundation.layout.PaddingValues
|
||||||
|
import androidx.compose.foundation.layout.Row
|
||||||
import androidx.compose.foundation.layout.Spacer
|
import androidx.compose.foundation.layout.Spacer
|
||||||
import androidx.compose.foundation.layout.aspectRatio
|
import androidx.compose.foundation.layout.aspectRatio
|
||||||
import androidx.compose.foundation.layout.fillMaxSize
|
import androidx.compose.foundation.layout.fillMaxSize
|
||||||
|
@ -34,22 +43,34 @@ import androidx.compose.foundation.lazy.grid.GridItemSpan
|
||||||
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
|
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
|
||||||
import androidx.compose.foundation.lazy.grid.items
|
import androidx.compose.foundation.lazy.grid.items
|
||||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||||
|
import androidx.compose.material.icons.Icons
|
||||||
|
import androidx.compose.material.icons.automirrored.outlined.NoteAdd
|
||||||
|
import androidx.compose.material.icons.filled.Add
|
||||||
import androidx.compose.material3.Card
|
import androidx.compose.material3.Card
|
||||||
import androidx.compose.material3.CardDefaults
|
import androidx.compose.material3.CardDefaults
|
||||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||||
|
import androidx.compose.material3.Icon
|
||||||
import androidx.compose.material3.MaterialTheme
|
import androidx.compose.material3.MaterialTheme
|
||||||
|
import androidx.compose.material3.ModalBottomSheet
|
||||||
import androidx.compose.material3.Scaffold
|
import androidx.compose.material3.Scaffold
|
||||||
|
import androidx.compose.material3.SmallFloatingActionButton
|
||||||
import androidx.compose.material3.Text
|
import androidx.compose.material3.Text
|
||||||
import androidx.compose.material3.TopAppBarDefaults
|
import androidx.compose.material3.TopAppBarDefaults
|
||||||
|
import androidx.compose.material3.rememberModalBottomSheetState
|
||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
|
import androidx.compose.runtime.LaunchedEffect
|
||||||
import androidx.compose.runtime.collectAsState
|
import androidx.compose.runtime.collectAsState
|
||||||
|
import androidx.compose.runtime.derivedStateOf
|
||||||
import androidx.compose.runtime.getValue
|
import androidx.compose.runtime.getValue
|
||||||
import androidx.compose.runtime.mutableStateOf
|
import androidx.compose.runtime.mutableStateOf
|
||||||
import androidx.compose.runtime.remember
|
import androidx.compose.runtime.remember
|
||||||
|
import androidx.compose.runtime.rememberCoroutineScope
|
||||||
import androidx.compose.runtime.setValue
|
import androidx.compose.runtime.setValue
|
||||||
import androidx.compose.ui.Alignment
|
import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
|
import androidx.compose.ui.draw.alpha
|
||||||
import androidx.compose.ui.draw.clip
|
import androidx.compose.ui.draw.clip
|
||||||
|
import androidx.compose.ui.draw.scale
|
||||||
import androidx.compose.ui.graphics.Brush
|
import androidx.compose.ui.graphics.Brush
|
||||||
import androidx.compose.ui.input.nestedscroll.nestedScroll
|
import androidx.compose.ui.input.nestedscroll.nestedScroll
|
||||||
import androidx.compose.ui.layout.layout
|
import androidx.compose.ui.layout.layout
|
||||||
|
@ -66,6 +87,8 @@ import com.google.aiedge.gallery.R
|
||||||
import com.google.aiedge.gallery.data.AppBarAction
|
import com.google.aiedge.gallery.data.AppBarAction
|
||||||
import com.google.aiedge.gallery.data.AppBarActionType
|
import com.google.aiedge.gallery.data.AppBarActionType
|
||||||
import com.google.aiedge.gallery.data.ConfigKey
|
import com.google.aiedge.gallery.data.ConfigKey
|
||||||
|
import com.google.aiedge.gallery.data.ImportedModelInfo
|
||||||
|
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.ui.common.TaskIcon
|
import com.google.aiedge.gallery.ui.common.TaskIcon
|
||||||
import com.google.aiedge.gallery.ui.common.getTaskBgColor
|
import com.google.aiedge.gallery.ui.common.getTaskBgColor
|
||||||
|
@ -75,6 +98,11 @@ import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||||
import com.google.aiedge.gallery.ui.theme.ThemeSettings
|
import com.google.aiedge.gallery.ui.theme.ThemeSettings
|
||||||
import com.google.aiedge.gallery.ui.theme.customColors
|
import com.google.aiedge.gallery.ui.theme.customColors
|
||||||
import com.google.aiedge.gallery.ui.theme.titleMediumNarrow
|
import com.google.aiedge.gallery.ui.theme.titleMediumNarrow
|
||||||
|
import kotlinx.coroutines.delay
|
||||||
|
import kotlinx.coroutines.launch
|
||||||
|
|
||||||
|
private const val TAG = "AGHomeScreen"
|
||||||
|
private const val TASK_COUNT_ANIMATION_DURATION = 250
|
||||||
|
|
||||||
/** Navigation destination data */
|
/** Navigation destination data */
|
||||||
object HomeScreenDestination {
|
object HomeScreenDestination {
|
||||||
|
@ -92,22 +120,59 @@ fun HomeScreen(
|
||||||
val scrollBehavior = TopAppBarDefaults.pinnedScrollBehavior()
|
val scrollBehavior = TopAppBarDefaults.pinnedScrollBehavior()
|
||||||
val uiState by modelManagerViewModel.uiState.collectAsState()
|
val uiState by modelManagerViewModel.uiState.collectAsState()
|
||||||
var showSettingsDialog by remember { mutableStateOf(false) }
|
var showSettingsDialog by remember { mutableStateOf(false) }
|
||||||
|
var showImportModelSheet by remember { mutableStateOf(false) }
|
||||||
|
val sheetState = rememberModalBottomSheetState()
|
||||||
|
var showImportDialog by remember { mutableStateOf(false) }
|
||||||
|
var showImportingDialog by remember { mutableStateOf(false) }
|
||||||
|
val selectedLocalModelFileUri = remember { mutableStateOf<Uri?>(null) }
|
||||||
|
val selectedImportedModelInfo = remember { mutableStateOf<ImportedModelInfo?>(null) }
|
||||||
|
val coroutineScope = rememberCoroutineScope()
|
||||||
|
|
||||||
val tasks = uiState.tasks
|
val tasks = uiState.tasks
|
||||||
val loadingHfModels = uiState.loadingHfModels
|
val loadingHfModels = uiState.loadingHfModels
|
||||||
|
|
||||||
Scaffold(modifier = modifier.nestedScroll(scrollBehavior.nestedScrollConnection), topBar = {
|
val filePickerLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult(
|
||||||
GalleryTopAppBar(
|
contract = ActivityResultContracts.StartActivityForResult()
|
||||||
title = stringResource(HomeScreenDestination.titleRes),
|
) { result ->
|
||||||
rightAction = AppBarAction(
|
if (result.resultCode == android.app.Activity.RESULT_OK) {
|
||||||
actionType = AppBarActionType.APP_SETTING, actionFn = {
|
result.data?.data?.let { uri ->
|
||||||
showSettingsDialog = true
|
selectedLocalModelFileUri.value = uri
|
||||||
}
|
showImportDialog = true
|
||||||
),
|
} ?: run {
|
||||||
loadingHfModels = loadingHfModels,
|
Log.d(TAG, "No file selected or URI is null.")
|
||||||
scrollBehavior = scrollBehavior,
|
}
|
||||||
)
|
} else {
|
||||||
}) { innerPadding ->
|
Log.d(TAG, "File picking cancelled.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Scaffold(
|
||||||
|
modifier = modifier.nestedScroll(scrollBehavior.nestedScrollConnection),
|
||||||
|
topBar = {
|
||||||
|
GalleryTopAppBar(
|
||||||
|
title = stringResource(HomeScreenDestination.titleRes),
|
||||||
|
rightAction = AppBarAction(
|
||||||
|
actionType = AppBarActionType.APP_SETTING, actionFn = {
|
||||||
|
showSettingsDialog = true
|
||||||
|
}
|
||||||
|
),
|
||||||
|
loadingHfModels = loadingHfModels,
|
||||||
|
scrollBehavior = scrollBehavior,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
floatingActionButton = {
|
||||||
|
// A floating action button to show "import model" bottom sheet.
|
||||||
|
SmallFloatingActionButton(
|
||||||
|
onClick = {
|
||||||
|
showImportModelSheet = true
|
||||||
|
},
|
||||||
|
containerColor = MaterialTheme.colorScheme.secondaryContainer,
|
||||||
|
contentColor = MaterialTheme.colorScheme.secondary,
|
||||||
|
) {
|
||||||
|
Icon(Icons.Filled.Add, "")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
) { innerPadding ->
|
||||||
TaskList(
|
TaskList(
|
||||||
tasks = tasks,
|
tasks = tasks,
|
||||||
navigateToTaskScreen = navigateToTaskScreen,
|
navigateToTaskScreen = navigateToTaskScreen,
|
||||||
|
@ -132,6 +197,83 @@ fun HomeScreen(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Import model bottom sheet.
|
||||||
|
if (showImportModelSheet) {
|
||||||
|
ModalBottomSheet(
|
||||||
|
onDismissRequest = { showImportModelSheet = false },
|
||||||
|
sheetState = sheetState,
|
||||||
|
) {
|
||||||
|
Text(
|
||||||
|
"Import model",
|
||||||
|
style = MaterialTheme.typography.titleLarge,
|
||||||
|
modifier = Modifier.padding(vertical = 4.dp, horizontal = 16.dp)
|
||||||
|
)
|
||||||
|
Box(modifier = Modifier.clickable {
|
||||||
|
coroutineScope.launch {
|
||||||
|
// Give it sometime to show the click effect.
|
||||||
|
delay(200)
|
||||||
|
showImportModelSheet = false
|
||||||
|
|
||||||
|
// Show file picker.
|
||||||
|
val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
|
||||||
|
addCategory(Intent.CATEGORY_OPENABLE)
|
||||||
|
type = "*/*"
|
||||||
|
putExtra(
|
||||||
|
Intent.EXTRA_MIME_TYPES,
|
||||||
|
arrayOf("application/x-binary", "application/octet-stream")
|
||||||
|
)
|
||||||
|
// Single select.
|
||||||
|
putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
|
||||||
|
}
|
||||||
|
filePickerLauncher.launch(intent)
|
||||||
|
}
|
||||||
|
}) {
|
||||||
|
Row(
|
||||||
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(6.dp),
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.padding(16.dp)
|
||||||
|
) {
|
||||||
|
Icon(Icons.AutoMirrored.Outlined.NoteAdd, contentDescription = "")
|
||||||
|
Text("From local model file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Import dialog
|
||||||
|
if (showImportDialog) {
|
||||||
|
selectedLocalModelFileUri.value?.let { uri ->
|
||||||
|
ModelImportDialog(uri = uri,
|
||||||
|
onDismiss = { showImportDialog = false },
|
||||||
|
onDone = { info ->
|
||||||
|
selectedImportedModelInfo.value = info
|
||||||
|
showImportDialog = false
|
||||||
|
showImportingDialog = true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Importing in progress dialog.
|
||||||
|
if (showImportingDialog) {
|
||||||
|
selectedLocalModelFileUri.value?.let { uri ->
|
||||||
|
selectedImportedModelInfo.value?.let { info ->
|
||||||
|
ModelImportingDialog(
|
||||||
|
uri = uri,
|
||||||
|
info = info,
|
||||||
|
onDismiss = { showImportingDialog = false },
|
||||||
|
onDone = {
|
||||||
|
modelManagerViewModel.addImportedLlmModel(
|
||||||
|
task = TASK_LLM_CHAT,
|
||||||
|
info = it,
|
||||||
|
)
|
||||||
|
showImportingDialog = false
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
|
@ -150,7 +292,7 @@ private fun TaskList(
|
||||||
verticalArrangement = Arrangement.spacedBy(8.dp),
|
verticalArrangement = Arrangement.spacedBy(8.dp),
|
||||||
) {
|
) {
|
||||||
// Headline.
|
// Headline.
|
||||||
item(span = { GridItemSpan(2) }) {
|
item(key = "headline", span = { GridItemSpan(2) }) {
|
||||||
Text(
|
Text(
|
||||||
"Welcome to AI Edge Gallery! Explore a world of \namazing on-device models from LiteRT community",
|
"Welcome to AI Edge Gallery! Explore a world of \namazing on-device models from LiteRT community",
|
||||||
textAlign = TextAlign.Center,
|
textAlign = TextAlign.Center,
|
||||||
|
@ -171,6 +313,11 @@ private fun TaskList(
|
||||||
.aspectRatio(1f)
|
.aspectRatio(1f)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Bottom padding.
|
||||||
|
item(key = "bottomPadding", span = { GridItemSpan(2) }) {
|
||||||
|
Spacer(modifier = Modifier.height(60.dp))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Gradient overlay at the bottom.
|
// Gradient overlay at the bottom.
|
||||||
|
@ -190,6 +337,48 @@ private fun TaskList(
|
||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modifier) {
|
private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modifier) {
|
||||||
|
// Observes the model count and updates the model count label with a fade-in/fade-out animation
|
||||||
|
// whenever the count changes.
|
||||||
|
val modelCount by remember {
|
||||||
|
derivedStateOf {
|
||||||
|
val trigger = task.updateTrigger.value
|
||||||
|
if (trigger >= 0) {
|
||||||
|
task.models.size
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
val modelCountLabel by remember {
|
||||||
|
derivedStateOf {
|
||||||
|
when (modelCount) {
|
||||||
|
1 -> "1 Model"
|
||||||
|
else -> "%d Models".format(modelCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var curModelCountLabel by remember { mutableStateOf("") }
|
||||||
|
var modelCountLabelVisible by remember { mutableStateOf(true) }
|
||||||
|
val modelCountAlpha: Float by animateFloatAsState(
|
||||||
|
targetValue = if (modelCountLabelVisible) 1f else 0f,
|
||||||
|
animationSpec = tween(durationMillis = TASK_COUNT_ANIMATION_DURATION)
|
||||||
|
)
|
||||||
|
val modelCountScale: Float by animateFloatAsState(
|
||||||
|
targetValue = if (modelCountLabelVisible) 1f else 0.7f,
|
||||||
|
animationSpec = tween(durationMillis = TASK_COUNT_ANIMATION_DURATION)
|
||||||
|
)
|
||||||
|
|
||||||
|
LaunchedEffect(modelCountLabel) {
|
||||||
|
if (curModelCountLabel.isEmpty()) {
|
||||||
|
curModelCountLabel = modelCountLabel
|
||||||
|
} else {
|
||||||
|
modelCountLabelVisible = false
|
||||||
|
delay(TASK_COUNT_ANIMATION_DURATION.toLong())
|
||||||
|
curModelCountLabel = modelCountLabel
|
||||||
|
modelCountLabelVisible = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Card(
|
Card(
|
||||||
modifier = modifier
|
modifier = modifier
|
||||||
.clip(RoundedCornerShape(43.5.dp))
|
.clip(RoundedCornerShape(43.5.dp))
|
||||||
|
@ -238,14 +427,13 @@ private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modif
|
||||||
}
|
}
|
||||||
|
|
||||||
// Model count.
|
// Model count.
|
||||||
val modelCountLabel = when (task.models.size) {
|
|
||||||
1 -> "1 Model"
|
|
||||||
else -> "%d Models".format(task.models.size)
|
|
||||||
}
|
|
||||||
Text(
|
Text(
|
||||||
modelCountLabel,
|
curModelCountLabel,
|
||||||
color = MaterialTheme.colorScheme.secondary,
|
color = MaterialTheme.colorScheme.secondary,
|
||||||
style = MaterialTheme.typography.bodyMedium
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
modifier = Modifier
|
||||||
|
.alpha(modelCountAlpha)
|
||||||
|
.scale(modelCountScale),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package com.google.aiedge.gallery.ui.modelmanager
|
package com.google.aiedge.gallery.ui.home
|
||||||
|
|
||||||
import android.content.Context
|
import android.content.Context
|
||||||
import android.net.Uri
|
import android.net.Uri
|
||||||
|
@ -36,24 +36,40 @@ import androidx.compose.material3.Icon
|
||||||
import androidx.compose.material3.LinearProgressIndicator
|
import androidx.compose.material3.LinearProgressIndicator
|
||||||
import androidx.compose.material3.MaterialTheme
|
import androidx.compose.material3.MaterialTheme
|
||||||
import androidx.compose.material3.Text
|
import androidx.compose.material3.Text
|
||||||
|
import androidx.compose.material3.TextButton
|
||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
import androidx.compose.runtime.LaunchedEffect
|
import androidx.compose.runtime.LaunchedEffect
|
||||||
import androidx.compose.runtime.getValue
|
import androidx.compose.runtime.getValue
|
||||||
import androidx.compose.runtime.mutableFloatStateOf
|
import androidx.compose.runtime.mutableFloatStateOf
|
||||||
import androidx.compose.runtime.mutableLongStateOf
|
import androidx.compose.runtime.mutableLongStateOf
|
||||||
|
import androidx.compose.runtime.mutableStateMapOf
|
||||||
import androidx.compose.runtime.mutableStateOf
|
import androidx.compose.runtime.mutableStateOf
|
||||||
import androidx.compose.runtime.remember
|
import androidx.compose.runtime.remember
|
||||||
import androidx.compose.runtime.rememberCoroutineScope
|
import androidx.compose.runtime.rememberCoroutineScope
|
||||||
import androidx.compose.runtime.setValue
|
import androidx.compose.runtime.setValue
|
||||||
|
import androidx.compose.runtime.snapshots.SnapshotStateMap
|
||||||
import androidx.compose.ui.Alignment
|
import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.platform.LocalContext
|
import androidx.compose.ui.platform.LocalContext
|
||||||
import androidx.compose.ui.unit.dp
|
import androidx.compose.ui.unit.dp
|
||||||
import androidx.compose.ui.window.Dialog
|
import androidx.compose.ui.window.Dialog
|
||||||
import androidx.compose.ui.window.DialogProperties
|
import androidx.compose.ui.window.DialogProperties
|
||||||
|
import com.google.aiedge.gallery.data.Accelerator
|
||||||
|
import com.google.aiedge.gallery.data.Config
|
||||||
|
import com.google.aiedge.gallery.data.ConfigKey
|
||||||
import com.google.aiedge.gallery.data.IMPORTS_DIR
|
import com.google.aiedge.gallery.data.IMPORTS_DIR
|
||||||
|
import com.google.aiedge.gallery.data.LabelConfig
|
||||||
|
import com.google.aiedge.gallery.data.ImportedModelInfo
|
||||||
|
import com.google.aiedge.gallery.data.NumberSliderConfig
|
||||||
|
import com.google.aiedge.gallery.data.SegmentedButtonConfig
|
||||||
|
import com.google.aiedge.gallery.data.ValueType
|
||||||
|
import com.google.aiedge.gallery.ui.common.chat.ConfigEditorsPanel
|
||||||
import com.google.aiedge.gallery.ui.common.ensureValidFileName
|
import com.google.aiedge.gallery.ui.common.ensureValidFileName
|
||||||
import com.google.aiedge.gallery.ui.common.humanReadableSize
|
import com.google.aiedge.gallery.ui.common.humanReadableSize
|
||||||
|
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_MAX_TOKEN
|
||||||
|
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TEMPERATURE
|
||||||
|
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TOPK
|
||||||
|
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TOPP
|
||||||
import kotlinx.coroutines.CoroutineScope
|
import kotlinx.coroutines.CoroutineScope
|
||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
||||||
import kotlinx.coroutines.launch
|
import kotlinx.coroutines.launch
|
||||||
|
@ -64,37 +80,151 @@ import java.nio.charset.StandardCharsets
|
||||||
|
|
||||||
private const val TAG = "AGModelImportDialog"
|
private const val TAG = "AGModelImportDialog"
|
||||||
|
|
||||||
data class ModelImportInfo(val fileName: String, val fileSize: Long, val error: String = "")
|
private val IMPORT_CONFIGS_LLM: List<Config> = listOf(
|
||||||
|
LabelConfig(key = ConfigKey.NAME),
|
||||||
|
LabelConfig(key = ConfigKey.MODEL_TYPE),
|
||||||
|
NumberSliderConfig(
|
||||||
|
key = ConfigKey.DEFAULT_MAX_TOKENS,
|
||||||
|
sliderMin = 100f,
|
||||||
|
sliderMax = 1024f,
|
||||||
|
defaultValue = DEFAULT_MAX_TOKEN.toFloat(),
|
||||||
|
valueType = ValueType.INT
|
||||||
|
),
|
||||||
|
NumberSliderConfig(
|
||||||
|
key = ConfigKey.DEFAULT_TOPK,
|
||||||
|
sliderMin = 5f,
|
||||||
|
sliderMax = 40f,
|
||||||
|
defaultValue = DEFAULT_TOPK.toFloat(),
|
||||||
|
valueType = ValueType.INT
|
||||||
|
),
|
||||||
|
NumberSliderConfig(
|
||||||
|
key = ConfigKey.DEFAULT_TOPP,
|
||||||
|
sliderMin = 0.0f,
|
||||||
|
sliderMax = 1.0f,
|
||||||
|
defaultValue = DEFAULT_TOPP,
|
||||||
|
valueType = ValueType.FLOAT
|
||||||
|
),
|
||||||
|
NumberSliderConfig(
|
||||||
|
key = ConfigKey.DEFAULT_TEMPERATURE,
|
||||||
|
sliderMin = 0.0f,
|
||||||
|
sliderMax = 2.0f,
|
||||||
|
defaultValue = DEFAULT_TEMPERATURE,
|
||||||
|
valueType = ValueType.FLOAT
|
||||||
|
),
|
||||||
|
SegmentedButtonConfig(
|
||||||
|
key = ConfigKey.COMPATIBLE_ACCELERATORS,
|
||||||
|
defaultValue = Accelerator.CPU.label,
|
||||||
|
options = listOf(Accelerator.CPU.label, Accelerator.GPU.label),
|
||||||
|
allowMultiple = true,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
fun ModelImportDialog(
|
fun ModelImportDialog(
|
||||||
uri: Uri, onDone: (ModelImportInfo) -> Unit
|
uri: Uri,
|
||||||
|
onDismiss: () -> Unit,
|
||||||
|
onDone: (ImportedModelInfo) -> Unit
|
||||||
) {
|
) {
|
||||||
val context = LocalContext.current
|
val context = LocalContext.current
|
||||||
val coroutineScope = rememberCoroutineScope()
|
val info = remember { getFileSizeAndDisplayNameFromUri(context = context, uri = uri) }
|
||||||
|
val fileSize by remember { mutableLongStateOf(info.first) }
|
||||||
|
val fileName by remember { mutableStateOf(ensureValidFileName(info.second)) }
|
||||||
|
|
||||||
var fileName by remember { mutableStateOf("") }
|
val initialValues: Map<String, Any> = remember {
|
||||||
var fileSize by remember { mutableLongStateOf(0L) }
|
mutableMapOf<String, Any>().apply {
|
||||||
|
for (config in IMPORT_CONFIGS_LLM) {
|
||||||
|
put(config.key.label, config.defaultValue)
|
||||||
|
}
|
||||||
|
put(ConfigKey.NAME.label, fileName)
|
||||||
|
// TODO: support other types.
|
||||||
|
put(ConfigKey.MODEL_TYPE.label, "LLM")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
val values: SnapshotStateMap<String, Any> = remember {
|
||||||
|
mutableStateMapOf<String, Any>().apply {
|
||||||
|
putAll(initialValues)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Dialog(
|
||||||
|
onDismissRequest = onDismiss,
|
||||||
|
) {
|
||||||
|
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
|
||||||
|
Column(
|
||||||
|
modifier = Modifier
|
||||||
|
.padding(20.dp),
|
||||||
|
verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||||
|
) {
|
||||||
|
// Title.
|
||||||
|
Text(
|
||||||
|
"Import Model",
|
||||||
|
style = MaterialTheme.typography.titleLarge,
|
||||||
|
modifier = Modifier.padding(bottom = 8.dp)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Default configs for users to set.
|
||||||
|
ConfigEditorsPanel(
|
||||||
|
configs = IMPORT_CONFIGS_LLM,
|
||||||
|
values = values,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Button row.
|
||||||
|
Row(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.padding(top = 8.dp),
|
||||||
|
horizontalArrangement = Arrangement.End,
|
||||||
|
) {
|
||||||
|
// Cancel button.
|
||||||
|
TextButton(
|
||||||
|
onClick = { onDismiss() },
|
||||||
|
) {
|
||||||
|
Text("Cancel")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Import button
|
||||||
|
Button(
|
||||||
|
onClick = {
|
||||||
|
onDone(
|
||||||
|
ImportedModelInfo(
|
||||||
|
fileName = fileName,
|
||||||
|
fileSize = fileSize,
|
||||||
|
defaultValues = values,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
},
|
||||||
|
) {
|
||||||
|
Text("Import")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
fun ModelImportingDialog(
|
||||||
|
uri: Uri,
|
||||||
|
info: ImportedModelInfo,
|
||||||
|
onDismiss: () -> Unit,
|
||||||
|
onDone: (ImportedModelInfo) -> Unit
|
||||||
|
) {
|
||||||
var error by remember { mutableStateOf("") }
|
var error by remember { mutableStateOf("") }
|
||||||
|
val context = LocalContext.current
|
||||||
|
val coroutineScope = rememberCoroutineScope()
|
||||||
var progress by remember { mutableFloatStateOf(0f) }
|
var progress by remember { mutableFloatStateOf(0f) }
|
||||||
|
|
||||||
LaunchedEffect(Unit) {
|
LaunchedEffect(Unit) {
|
||||||
error = ""
|
|
||||||
|
|
||||||
// Get basic info.
|
|
||||||
val info = getFileSizeAndDisplayNameFromUri(context = context, uri = uri)
|
|
||||||
fileSize = info.first
|
|
||||||
fileName = ensureValidFileName(info.second)
|
|
||||||
|
|
||||||
// Import.
|
// Import.
|
||||||
importModel(
|
importModel(
|
||||||
context = context,
|
context = context,
|
||||||
coroutineScope = coroutineScope,
|
coroutineScope = coroutineScope,
|
||||||
fileName = fileName,
|
fileName = info.fileName,
|
||||||
fileSize = fileSize,
|
fileSize = info.fileSize,
|
||||||
uri = uri,
|
uri = uri,
|
||||||
onDone = {
|
onDone = {
|
||||||
onDone(ModelImportInfo(fileName = fileName, fileSize = fileSize, error = error))
|
onDone(info)
|
||||||
},
|
},
|
||||||
onProgress = {
|
onProgress = {
|
||||||
progress = it
|
progress = it
|
||||||
|
@ -107,7 +237,7 @@ fun ModelImportDialog(
|
||||||
|
|
||||||
Dialog(
|
Dialog(
|
||||||
properties = DialogProperties(dismissOnBackPress = false, dismissOnClickOutside = false),
|
properties = DialogProperties(dismissOnBackPress = false, dismissOnClickOutside = false),
|
||||||
onDismissRequest = {},
|
onDismissRequest = onDismiss,
|
||||||
) {
|
) {
|
||||||
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
|
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
|
||||||
Column(
|
Column(
|
||||||
|
@ -117,7 +247,7 @@ fun ModelImportDialog(
|
||||||
) {
|
) {
|
||||||
// Title.
|
// Title.
|
||||||
Text(
|
Text(
|
||||||
"Importing...",
|
"Import Model",
|
||||||
style = MaterialTheme.typography.titleLarge,
|
style = MaterialTheme.typography.titleLarge,
|
||||||
modifier = Modifier.padding(bottom = 8.dp)
|
modifier = Modifier.padding(bottom = 8.dp)
|
||||||
)
|
)
|
||||||
|
@ -127,7 +257,7 @@ fun ModelImportDialog(
|
||||||
// Progress bar.
|
// Progress bar.
|
||||||
Column(verticalArrangement = Arrangement.spacedBy(4.dp)) {
|
Column(verticalArrangement = Arrangement.spacedBy(4.dp)) {
|
||||||
Text(
|
Text(
|
||||||
"$fileName (${fileSize.humanReadableSize()})",
|
"${info.fileName} (${info.fileSize.humanReadableSize()})",
|
||||||
style = MaterialTheme.typography.labelSmall,
|
style = MaterialTheme.typography.labelSmall,
|
||||||
)
|
)
|
||||||
val animatedProgress = remember { Animatable(0f) }
|
val animatedProgress = remember { Animatable(0f) }
|
||||||
|
@ -162,7 +292,7 @@ fun ModelImportDialog(
|
||||||
}
|
}
|
||||||
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.End) {
|
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.End) {
|
||||||
Button(onClick = {
|
Button(onClick = {
|
||||||
onDone(ModelImportInfo(fileName = "", fileSize = 0L, error = error))
|
onDismiss()
|
||||||
}) {
|
}) {
|
||||||
Text("Close")
|
Text("Close")
|
||||||
}
|
}
|
|
@ -30,7 +30,7 @@ private val CONFIGS: List<Config> = listOf(
|
||||||
SegmentedButtonConfig(
|
SegmentedButtonConfig(
|
||||||
key = ConfigKey.THEME,
|
key = ConfigKey.THEME,
|
||||||
defaultValue = THEME_AUTO,
|
defaultValue = THEME_AUTO,
|
||||||
options = listOf(THEME_AUTO, THEME_LIGHT, THEME_DARK)
|
options = listOf(THEME_AUTO, THEME_LIGHT, THEME_DARK),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -16,24 +16,25 @@
|
||||||
|
|
||||||
package com.google.aiedge.gallery.ui.llmchat
|
package com.google.aiedge.gallery.ui.llmchat
|
||||||
|
|
||||||
|
import com.google.aiedge.gallery.data.Accelerator
|
||||||
import com.google.aiedge.gallery.data.Config
|
import com.google.aiedge.gallery.data.Config
|
||||||
import com.google.aiedge.gallery.data.ConfigKey
|
import com.google.aiedge.gallery.data.ConfigKey
|
||||||
import com.google.aiedge.gallery.data.ConfigValue
|
|
||||||
import com.google.aiedge.gallery.data.NumberSliderConfig
|
import com.google.aiedge.gallery.data.NumberSliderConfig
|
||||||
|
import com.google.aiedge.gallery.data.SegmentedButtonConfig
|
||||||
import com.google.aiedge.gallery.data.ValueType
|
import com.google.aiedge.gallery.data.ValueType
|
||||||
import com.google.aiedge.gallery.data.getFloatConfigValue
|
|
||||||
import com.google.aiedge.gallery.data.getIntConfigValue
|
|
||||||
|
|
||||||
private const val DEFAULT_MAX_TOKEN = 1024
|
const val DEFAULT_MAX_TOKEN = 1024
|
||||||
private const val DEFAULT_TOPK = 40
|
const val DEFAULT_TOPK = 40
|
||||||
private const val DEFAULT_TOPP = 0.9f
|
const val DEFAULT_TOPP = 0.9f
|
||||||
private const val DEFAULT_TEMPERATURE = 1.0f
|
const val DEFAULT_TEMPERATURE = 1.0f
|
||||||
|
val DEFAULT_ACCELERATORS = listOf(Accelerator.GPU)
|
||||||
|
|
||||||
fun createLlmChatConfigs(
|
fun createLlmChatConfigs(
|
||||||
defaultMaxToken: Int = DEFAULT_MAX_TOKEN,
|
defaultMaxToken: Int = DEFAULT_MAX_TOKEN,
|
||||||
defaultTopK: Int = DEFAULT_TOPK,
|
defaultTopK: Int = DEFAULT_TOPK,
|
||||||
defaultTopP: Float = DEFAULT_TOPP,
|
defaultTopP: Float = DEFAULT_TOPP,
|
||||||
defaultTemperature: Float = DEFAULT_TEMPERATURE
|
defaultTemperature: Float = DEFAULT_TEMPERATURE,
|
||||||
|
accelerators: List<Accelerator> = DEFAULT_ACCELERATORS,
|
||||||
): List<Config> {
|
): List<Config> {
|
||||||
return listOf(
|
return listOf(
|
||||||
NumberSliderConfig(
|
NumberSliderConfig(
|
||||||
|
@ -64,21 +65,10 @@ fun createLlmChatConfigs(
|
||||||
defaultValue = defaultTemperature,
|
defaultValue = defaultTemperature,
|
||||||
valueType = ValueType.FLOAT
|
valueType = ValueType.FLOAT
|
||||||
),
|
),
|
||||||
)
|
SegmentedButtonConfig(
|
||||||
}
|
key = ConfigKey.ACCELERATOR,
|
||||||
|
defaultValue = if (accelerators.contains(Accelerator.GPU)) Accelerator.GPU.label else accelerators[0].label,
|
||||||
fun createLLmChatConfig(defaults: Map<String, ConfigValue>): List<Config> {
|
options = accelerators.map { it.label }
|
||||||
val defaultMaxToken =
|
)
|
||||||
getIntConfigValue(defaults[ConfigKey.MAX_TOKENS.id], default = DEFAULT_MAX_TOKEN)
|
|
||||||
val defaultTopK = getIntConfigValue(defaults[ConfigKey.TOPK.id], default = DEFAULT_TOPK)
|
|
||||||
val defaultTopP = getFloatConfigValue(defaults[ConfigKey.TOPP.id], default = DEFAULT_TOPP)
|
|
||||||
val defaultTemperature =
|
|
||||||
getFloatConfigValue(defaults[ConfigKey.TEMPERATURE.id], default = DEFAULT_TEMPERATURE)
|
|
||||||
|
|
||||||
return createLlmChatConfigs(
|
|
||||||
defaultMaxToken = defaultMaxToken,
|
|
||||||
defaultTopK = defaultTopK,
|
|
||||||
defaultTopP = defaultTopP,
|
|
||||||
defaultTemperature = defaultTemperature
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,18 +18,14 @@ package com.google.aiedge.gallery.ui.llmchat
|
||||||
|
|
||||||
import android.content.Context
|
import android.content.Context
|
||||||
import android.util.Log
|
import android.util.Log
|
||||||
|
import com.google.aiedge.gallery.data.Accelerator
|
||||||
import com.google.aiedge.gallery.data.ConfigKey
|
import com.google.aiedge.gallery.data.ConfigKey
|
||||||
import com.google.aiedge.gallery.data.LlmBackend
|
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage
|
import com.google.aiedge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage
|
||||||
import com.google.mediapipe.tasks.genai.llminference.LlmInference
|
import com.google.mediapipe.tasks.genai.llminference.LlmInference
|
||||||
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
|
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
|
||||||
|
|
||||||
private const val TAG = "AGLlmChatModelHelper"
|
private const val TAG = "AGLlmChatModelHelper"
|
||||||
private const val DEFAULT_MAX_TOKEN = 1024
|
|
||||||
private const val DEFAULT_TOPK = 40
|
|
||||||
private const val DEFAULT_TOPP = 0.9f
|
|
||||||
private const val DEFAULT_TEMPERATURE = 1.0f
|
|
||||||
|
|
||||||
typealias ResultListener = (partialResult: String, done: Boolean) -> Unit
|
typealias ResultListener = (partialResult: String, done: Boolean) -> Unit
|
||||||
typealias CleanUpListener = () -> Unit
|
typealias CleanUpListener = () -> Unit
|
||||||
|
@ -49,10 +45,13 @@ object LlmChatModelHelper {
|
||||||
val topP = model.getFloatConfigValue(key = ConfigKey.TOPP, defaultValue = DEFAULT_TOPP)
|
val topP = model.getFloatConfigValue(key = ConfigKey.TOPP, defaultValue = DEFAULT_TOPP)
|
||||||
val temperature =
|
val temperature =
|
||||||
model.getFloatConfigValue(key = ConfigKey.TEMPERATURE, defaultValue = DEFAULT_TEMPERATURE)
|
model.getFloatConfigValue(key = ConfigKey.TEMPERATURE, defaultValue = DEFAULT_TEMPERATURE)
|
||||||
|
val accelerator =
|
||||||
|
model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = Accelerator.GPU.label)
|
||||||
Log.d(TAG, "Initializing...")
|
Log.d(TAG, "Initializing...")
|
||||||
val preferredBackend = when (model.llmBackend) {
|
val preferredBackend = when (accelerator) {
|
||||||
LlmBackend.CPU -> LlmInference.Backend.CPU
|
Accelerator.CPU.label -> LlmInference.Backend.CPU
|
||||||
LlmBackend.GPU -> LlmInference.Backend.GPU
|
Accelerator.GPU.label -> LlmInference.Backend.GPU
|
||||||
|
else -> LlmInference.Backend.GPU
|
||||||
}
|
}
|
||||||
val options =
|
val options =
|
||||||
LlmInference.LlmInferenceOptions.builder().setModelPath(model.getPath(context = context))
|
LlmInference.LlmInferenceOptions.builder().setModelPath(model.getPath(context = context))
|
||||||
|
|
|
@ -16,13 +16,7 @@
|
||||||
|
|
||||||
package com.google.aiedge.gallery.ui.modelmanager
|
package com.google.aiedge.gallery.ui.modelmanager
|
||||||
|
|
||||||
import android.content.Intent
|
|
||||||
import android.net.Uri
|
|
||||||
import android.os.Build
|
import android.os.Build
|
||||||
import android.util.Log
|
|
||||||
import androidx.activity.compose.rememberLauncherForActivityResult
|
|
||||||
import androidx.activity.result.ActivityResultLauncher
|
|
||||||
import androidx.activity.result.contract.ActivityResultContracts
|
|
||||||
import androidx.annotation.RequiresApi
|
import androidx.annotation.RequiresApi
|
||||||
import androidx.compose.foundation.clickable
|
import androidx.compose.foundation.clickable
|
||||||
import androidx.compose.foundation.layout.Arrangement
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
|
@ -30,32 +24,21 @@ import androidx.compose.foundation.layout.Box
|
||||||
import androidx.compose.foundation.layout.Column
|
import androidx.compose.foundation.layout.Column
|
||||||
import androidx.compose.foundation.layout.PaddingValues
|
import androidx.compose.foundation.layout.PaddingValues
|
||||||
import androidx.compose.foundation.layout.Row
|
import androidx.compose.foundation.layout.Row
|
||||||
import androidx.compose.foundation.layout.Spacer
|
|
||||||
import androidx.compose.foundation.layout.fillMaxWidth
|
import androidx.compose.foundation.layout.fillMaxWidth
|
||||||
import androidx.compose.foundation.layout.height
|
|
||||||
import androidx.compose.foundation.layout.padding
|
import androidx.compose.foundation.layout.padding
|
||||||
import androidx.compose.foundation.layout.size
|
import androidx.compose.foundation.layout.size
|
||||||
import androidx.compose.foundation.lazy.LazyColumn
|
import androidx.compose.foundation.lazy.LazyColumn
|
||||||
import androidx.compose.foundation.lazy.items
|
import androidx.compose.foundation.lazy.items
|
||||||
import androidx.compose.material.icons.Icons
|
import androidx.compose.material.icons.Icons
|
||||||
import androidx.compose.material.icons.automirrored.outlined.NoteAdd
|
|
||||||
import androidx.compose.material.icons.filled.Add
|
|
||||||
import androidx.compose.material.icons.outlined.Code
|
import androidx.compose.material.icons.outlined.Code
|
||||||
import androidx.compose.material.icons.outlined.Description
|
import androidx.compose.material.icons.outlined.Description
|
||||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
|
||||||
import androidx.compose.material3.Icon
|
import androidx.compose.material3.Icon
|
||||||
import androidx.compose.material3.MaterialTheme
|
import androidx.compose.material3.MaterialTheme
|
||||||
import androidx.compose.material3.ModalBottomSheet
|
|
||||||
import androidx.compose.material3.SmallFloatingActionButton
|
|
||||||
import androidx.compose.material3.Text
|
import androidx.compose.material3.Text
|
||||||
import androidx.compose.material3.rememberModalBottomSheetState
|
|
||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
import androidx.compose.runtime.derivedStateOf
|
import androidx.compose.runtime.derivedStateOf
|
||||||
import androidx.compose.runtime.getValue
|
import androidx.compose.runtime.getValue
|
||||||
import androidx.compose.runtime.mutableStateOf
|
|
||||||
import androidx.compose.runtime.remember
|
import androidx.compose.runtime.remember
|
||||||
import androidx.compose.runtime.rememberCoroutineScope
|
|
||||||
import androidx.compose.runtime.setValue
|
|
||||||
import androidx.compose.ui.Alignment
|
import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.graphics.vector.ImageVector
|
import androidx.compose.ui.graphics.vector.ImageVector
|
||||||
|
@ -75,14 +58,11 @@ import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||||
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
|
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
|
||||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||||
import com.google.aiedge.gallery.ui.theme.customColors
|
import com.google.aiedge.gallery.ui.theme.customColors
|
||||||
import kotlinx.coroutines.delay
|
|
||||||
import kotlinx.coroutines.launch
|
|
||||||
|
|
||||||
private const val TAG = "AGModelList"
|
private const val TAG = "AGModelList"
|
||||||
|
|
||||||
/** The list of models in the model manager. */
|
/** The list of models in the model manager. */
|
||||||
@RequiresApi(Build.VERSION_CODES.O)
|
@RequiresApi(Build.VERSION_CODES.O)
|
||||||
@OptIn(ExperimentalMaterial3Api::class)
|
|
||||||
@Composable
|
@Composable
|
||||||
fun ModelList(
|
fun ModelList(
|
||||||
task: Task,
|
task: Task,
|
||||||
|
@ -91,50 +71,29 @@ fun ModelList(
|
||||||
onModelClicked: (Model) -> Unit,
|
onModelClicked: (Model) -> Unit,
|
||||||
modifier: Modifier = Modifier,
|
modifier: Modifier = Modifier,
|
||||||
) {
|
) {
|
||||||
var showAddModelSheet by remember { mutableStateOf(false) }
|
|
||||||
var showImportingDialog by remember { mutableStateOf(false) }
|
|
||||||
val curFileUri = remember { mutableStateOf<Uri?>(null) }
|
|
||||||
val sheetState = rememberModalBottomSheetState()
|
|
||||||
val coroutineScope = rememberCoroutineScope()
|
|
||||||
|
|
||||||
// This is just to update "models" list when task.updateTrigger is updated so that the UI can
|
// This is just to update "models" list when task.updateTrigger is updated so that the UI can
|
||||||
// be properly updated.
|
// be properly updated.
|
||||||
val models by remember {
|
val models by remember {
|
||||||
derivedStateOf {
|
derivedStateOf {
|
||||||
val trigger = task.updateTrigger.value
|
val trigger = task.updateTrigger.value
|
||||||
if (trigger >= 0) {
|
if (trigger >= 0) {
|
||||||
task.models.toList().filter { !it.isLocalModel }
|
task.models.toList().filter { !it.imported }
|
||||||
} else {
|
} else {
|
||||||
listOf()
|
listOf()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
val localModels by remember {
|
val importedModels by remember {
|
||||||
derivedStateOf {
|
derivedStateOf {
|
||||||
val trigger = task.updateTrigger.value
|
val trigger = task.updateTrigger.value
|
||||||
if (trigger >= 0) {
|
if (trigger >= 0) {
|
||||||
task.models.toList().filter { it.isLocalModel }
|
task.models.toList().filter { it.imported }
|
||||||
} else {
|
} else {
|
||||||
listOf()
|
listOf()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val filePickerLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult(
|
|
||||||
contract = ActivityResultContracts.StartActivityForResult()
|
|
||||||
) { result ->
|
|
||||||
if (result.resultCode == android.app.Activity.RESULT_OK) {
|
|
||||||
result.data?.data?.let { uri ->
|
|
||||||
curFileUri.value = uri
|
|
||||||
showImportingDialog = true
|
|
||||||
} ?: run {
|
|
||||||
Log.d(TAG, "No file selected or URI is null.")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Log.d(TAG, "File picking cancelled.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Box(contentAlignment = Alignment.BottomEnd) {
|
Box(contentAlignment = Alignment.BottomEnd) {
|
||||||
LazyColumn(
|
LazyColumn(
|
||||||
modifier = modifier.padding(top = 8.dp),
|
modifier = modifier.padding(top = 8.dp),
|
||||||
|
@ -190,11 +149,11 @@ fun ModelList(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Title for local models.
|
// Title for imported models.
|
||||||
if (localModels.isNotEmpty()) {
|
if (importedModels.isNotEmpty()) {
|
||||||
item(key = "localModelsTitle") {
|
item(key = "importedModelsTitle") {
|
||||||
Text(
|
Text(
|
||||||
"Local models",
|
"Imported models",
|
||||||
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
|
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
|
||||||
modifier = Modifier
|
modifier = Modifier
|
||||||
.padding(horizontal = 16.dp)
|
.padding(horizontal = 16.dp)
|
||||||
|
@ -203,8 +162,8 @@ fun ModelList(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// List of local models within a task.
|
// List of imported models within a task.
|
||||||
items(items = localModels) { model ->
|
items(items = importedModels) { model ->
|
||||||
Box {
|
Box {
|
||||||
ModelItem(
|
ModelItem(
|
||||||
model = model,
|
model = model,
|
||||||
|
@ -215,88 +174,6 @@ fun ModelList(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
item(key = "bottomPadding") {
|
|
||||||
Spacer(modifier = Modifier.height(60.dp))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add model button at the bottom right.
|
|
||||||
Box(
|
|
||||||
modifier = Modifier
|
|
||||||
.padding(end = 16.dp)
|
|
||||||
.padding(bottom = contentPadding.calculateBottomPadding())
|
|
||||||
) {
|
|
||||||
SmallFloatingActionButton(
|
|
||||||
onClick = {
|
|
||||||
showAddModelSheet = true
|
|
||||||
},
|
|
||||||
containerColor = MaterialTheme.colorScheme.secondaryContainer,
|
|
||||||
contentColor = MaterialTheme.colorScheme.secondary,
|
|
||||||
) {
|
|
||||||
Icon(Icons.Filled.Add, "")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (showAddModelSheet) {
|
|
||||||
ModalBottomSheet(
|
|
||||||
onDismissRequest = { showAddModelSheet = false },
|
|
||||||
sheetState = sheetState,
|
|
||||||
) {
|
|
||||||
Text(
|
|
||||||
"Add custom model",
|
|
||||||
style = MaterialTheme.typography.titleLarge,
|
|
||||||
modifier = Modifier.padding(vertical = 4.dp, horizontal = 16.dp)
|
|
||||||
)
|
|
||||||
Box(modifier = Modifier.clickable {
|
|
||||||
coroutineScope.launch {
|
|
||||||
// Give it sometime to show the click effect.
|
|
||||||
delay(200)
|
|
||||||
showAddModelSheet = false
|
|
||||||
|
|
||||||
// Show file picker.
|
|
||||||
val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
|
|
||||||
addCategory(Intent.CATEGORY_OPENABLE)
|
|
||||||
type = "*/*"
|
|
||||||
putExtra(
|
|
||||||
Intent.EXTRA_MIME_TYPES,
|
|
||||||
arrayOf("application/x-binary", "application/octet-stream")
|
|
||||||
)
|
|
||||||
// Single select.
|
|
||||||
putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
|
|
||||||
}
|
|
||||||
filePickerLauncher.launch(intent)
|
|
||||||
}
|
|
||||||
}) {
|
|
||||||
Row(
|
|
||||||
verticalAlignment = Alignment.CenterVertically,
|
|
||||||
horizontalArrangement = Arrangement.spacedBy(6.dp),
|
|
||||||
modifier = Modifier
|
|
||||||
.fillMaxWidth()
|
|
||||||
.padding(16.dp)
|
|
||||||
) {
|
|
||||||
Icon(Icons.AutoMirrored.Outlined.NoteAdd, contentDescription = "")
|
|
||||||
Text("Add local model")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (showImportingDialog) {
|
|
||||||
curFileUri.value?.let { uri ->
|
|
||||||
ModelImportDialog(uri = uri, onDone = { info ->
|
|
||||||
showImportingDialog = false
|
|
||||||
|
|
||||||
if (info.error.isEmpty()) {
|
|
||||||
// TODO: support other model types.
|
|
||||||
modelManagerViewModel.addLocalLlmModel(
|
|
||||||
task = task,
|
|
||||||
fileName = info.fileName,
|
|
||||||
fileSize = info.fileSize
|
|
||||||
)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,8 +23,10 @@ import androidx.activity.result.ActivityResult
|
||||||
import androidx.lifecycle.ViewModel
|
import androidx.lifecycle.ViewModel
|
||||||
import androidx.lifecycle.viewModelScope
|
import androidx.lifecycle.viewModelScope
|
||||||
import com.google.aiedge.gallery.data.AGWorkInfo
|
import com.google.aiedge.gallery.data.AGWorkInfo
|
||||||
|
import com.google.aiedge.gallery.data.Accelerator
|
||||||
import com.google.aiedge.gallery.data.AccessTokenData
|
import com.google.aiedge.gallery.data.AccessTokenData
|
||||||
import com.google.aiedge.gallery.data.Config
|
import com.google.aiedge.gallery.data.Config
|
||||||
|
import com.google.aiedge.gallery.data.ConfigKey
|
||||||
import com.google.aiedge.gallery.data.DataStoreRepository
|
import com.google.aiedge.gallery.data.DataStoreRepository
|
||||||
import com.google.aiedge.gallery.data.DownloadRepository
|
import com.google.aiedge.gallery.data.DownloadRepository
|
||||||
import com.google.aiedge.gallery.data.EMPTY_MODEL
|
import com.google.aiedge.gallery.data.EMPTY_MODEL
|
||||||
|
@ -32,7 +34,7 @@ import com.google.aiedge.gallery.data.HfModel
|
||||||
import com.google.aiedge.gallery.data.HfModelDetails
|
import com.google.aiedge.gallery.data.HfModelDetails
|
||||||
import com.google.aiedge.gallery.data.HfModelSummary
|
import com.google.aiedge.gallery.data.HfModelSummary
|
||||||
import com.google.aiedge.gallery.data.IMPORTS_DIR
|
import com.google.aiedge.gallery.data.IMPORTS_DIR
|
||||||
import com.google.aiedge.gallery.data.LocalModelInfo
|
import com.google.aiedge.gallery.data.ImportedModelInfo
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.ModelDownloadStatus
|
import com.google.aiedge.gallery.data.ModelDownloadStatus
|
||||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||||
|
@ -40,12 +42,14 @@ import com.google.aiedge.gallery.data.TASKS
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.data.TaskType
|
import com.google.aiedge.gallery.data.TaskType
|
||||||
|
import com.google.aiedge.gallery.data.ValueType
|
||||||
import com.google.aiedge.gallery.data.getModelByName
|
import com.google.aiedge.gallery.data.getModelByName
|
||||||
import com.google.aiedge.gallery.ui.common.AuthConfig
|
import com.google.aiedge.gallery.ui.common.AuthConfig
|
||||||
|
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
|
||||||
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper
|
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper
|
||||||
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper
|
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper
|
||||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
|
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
|
||||||
import com.google.aiedge.gallery.ui.llmchat.createLLmChatConfig
|
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
|
||||||
import com.google.aiedge.gallery.ui.textclassification.TextClassificationModelHelper
|
import com.google.aiedge.gallery.ui.textclassification.TextClassificationModelHelper
|
||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
||||||
import kotlinx.coroutines.async
|
import kotlinx.coroutines.async
|
||||||
|
@ -228,7 +232,7 @@ open class ModelManagerViewModel(
|
||||||
ModelDownloadStatus(status = ModelDownloadStatusType.NOT_DOWNLOADED)
|
ModelDownloadStatus(status = ModelDownloadStatusType.NOT_DOWNLOADED)
|
||||||
|
|
||||||
// Delete model from the list if model is imported as a local model.
|
// Delete model from the list if model is imported as a local model.
|
||||||
if (model.isLocalModel) {
|
if (model.imported) {
|
||||||
val index = task.models.indexOf(model)
|
val index = task.models.indexOf(model)
|
||||||
if (index >= 0) {
|
if (index >= 0) {
|
||||||
task.models.removeAt(index)
|
task.models.removeAt(index)
|
||||||
|
@ -237,12 +241,12 @@ open class ModelManagerViewModel(
|
||||||
curModelDownloadStatus.remove(model.name)
|
curModelDownloadStatus.remove(model.name)
|
||||||
|
|
||||||
// Update preference.
|
// Update preference.
|
||||||
val localModels = dataStoreRepository.readLocalModels().toMutableList()
|
val importedModels = dataStoreRepository.readImportedModels().toMutableList()
|
||||||
val localModelIndex = localModels.indexOfFirst { it.fileName == model.name }
|
val importedModelIndex = importedModels.indexOfFirst { it.fileName == model.name }
|
||||||
if (localModelIndex >= 0) {
|
if (importedModelIndex >= 0) {
|
||||||
localModels.removeAt(localModelIndex)
|
importedModels.removeAt(importedModelIndex)
|
||||||
}
|
}
|
||||||
dataStoreRepository.saveLocalModels(localModels = localModels)
|
dataStoreRepository.saveImportedModels(importedModels = importedModels)
|
||||||
}
|
}
|
||||||
val newUiState = uiState.value.copy(modelDownloadStatus = curModelDownloadStatus)
|
val newUiState = uiState.value.copy(modelDownloadStatus = curModelDownloadStatus)
|
||||||
_uiState.update { newUiState }
|
_uiState.update { newUiState }
|
||||||
|
@ -417,27 +421,20 @@ open class ModelManagerViewModel(
|
||||||
return connection.responseCode
|
return connection.responseCode
|
||||||
}
|
}
|
||||||
|
|
||||||
fun addLocalLlmModel(task: Task, fileName: String, fileSize: Long) {
|
fun addImportedLlmModel(task: Task, info: ImportedModelInfo) {
|
||||||
Log.d(TAG, "adding local model: $fileName, $fileSize")
|
Log.d(TAG, "adding imported llm model: $info")
|
||||||
|
|
||||||
// Create model.
|
// Create model.
|
||||||
val configs: List<Config> = createLLmChatConfig(defaults = mapOf())
|
val model = createModelFromImportedModelInfo(info = info, task = task)
|
||||||
val model = Model(
|
|
||||||
name = fileName,
|
|
||||||
url = "",
|
|
||||||
configs = configs,
|
|
||||||
sizeInBytes = fileSize,
|
|
||||||
downloadFileName = "$IMPORTS_DIR/$fileName",
|
|
||||||
isLocalModel = true,
|
|
||||||
)
|
|
||||||
model.preProcess(task = task)
|
|
||||||
task.models.add(model)
|
task.models.add(model)
|
||||||
|
|
||||||
// Add initial status and states.
|
// Add initial status and states.
|
||||||
val modelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap()
|
val modelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap()
|
||||||
val modelInstances = uiState.value.modelInitializationStatus.toMutableMap()
|
val modelInstances = uiState.value.modelInitializationStatus.toMutableMap()
|
||||||
modelDownloadStatus[model.name] = ModelDownloadStatus(
|
modelDownloadStatus[model.name] = ModelDownloadStatus(
|
||||||
status = ModelDownloadStatusType.SUCCEEDED, receivedBytes = fileSize, totalBytes = fileSize
|
status = ModelDownloadStatusType.SUCCEEDED,
|
||||||
|
receivedBytes = info.fileSize,
|
||||||
|
totalBytes = info.fileSize
|
||||||
)
|
)
|
||||||
modelInstances[model.name] =
|
modelInstances[model.name] =
|
||||||
ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED)
|
ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED)
|
||||||
|
@ -453,9 +450,9 @@ open class ModelManagerViewModel(
|
||||||
task.updateTrigger.value = System.currentTimeMillis()
|
task.updateTrigger.value = System.currentTimeMillis()
|
||||||
|
|
||||||
// Add to preference storage.
|
// Add to preference storage.
|
||||||
val localModels = dataStoreRepository.readLocalModels().toMutableList()
|
val importedModels = dataStoreRepository.readImportedModels().toMutableList()
|
||||||
localModels.add(LocalModelInfo(fileName = fileName, fileSize = fileSize))
|
importedModels.add(info)
|
||||||
dataStoreRepository.saveLocalModels(localModels = localModels)
|
dataStoreRepository.saveImportedModels(importedModels = importedModels)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun getTokenStatusAndData(): TokenStatusAndData {
|
fun getTokenStatusAndData(): TokenStatusAndData {
|
||||||
|
@ -589,31 +586,22 @@ open class ModelManagerViewModel(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load local models.
|
// Load imported models.
|
||||||
for (localModel in dataStoreRepository.readLocalModels()) {
|
for (importedModel in dataStoreRepository.readImportedModels()) {
|
||||||
Log.d(TAG, "stored local model: $localModel")
|
Log.d(TAG, "stored imported model: $importedModel")
|
||||||
|
|
||||||
// Create model.
|
// Create model.
|
||||||
val configs: List<Config> = createLLmChatConfig(defaults = mapOf())
|
val model = createModelFromImportedModelInfo(info = importedModel, task = TASK_LLM_CHAT)
|
||||||
val model = Model(
|
|
||||||
name = localModel.fileName,
|
|
||||||
url = "",
|
|
||||||
configs = configs,
|
|
||||||
sizeInBytes = localModel.fileSize,
|
|
||||||
downloadFileName = "$IMPORTS_DIR/${localModel.fileName}",
|
|
||||||
isLocalModel = true,
|
|
||||||
)
|
|
||||||
|
|
||||||
// Add to task.
|
// Add to task.
|
||||||
val task = TASK_LLM_CHAT
|
val task = TASK_LLM_CHAT
|
||||||
model.preProcess(task = task)
|
|
||||||
task.models.add(model)
|
task.models.add(model)
|
||||||
|
|
||||||
// Update status.
|
// Update status.
|
||||||
modelDownloadStatus[model.name] = ModelDownloadStatus(
|
modelDownloadStatus[model.name] = ModelDownloadStatus(
|
||||||
status = ModelDownloadStatusType.SUCCEEDED,
|
status = ModelDownloadStatusType.SUCCEEDED,
|
||||||
receivedBytes = localModel.fileSize,
|
receivedBytes = importedModel.fileSize,
|
||||||
totalBytes = localModel.fileSize
|
totalBytes = importedModel.fileSize
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -628,6 +616,51 @@ open class ModelManagerViewModel(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun createModelFromImportedModelInfo(info: ImportedModelInfo, task: Task): Model {
|
||||||
|
val accelerators: List<Accelerator> = (convertValueToTargetType(
|
||||||
|
info.defaultValues[ConfigKey.COMPATIBLE_ACCELERATORS.label]!!,
|
||||||
|
ValueType.STRING
|
||||||
|
) as String)
|
||||||
|
.split(",")
|
||||||
|
.mapNotNull { acceleratorLabel ->
|
||||||
|
when (acceleratorLabel.trim()) {
|
||||||
|
Accelerator.GPU.label -> Accelerator.GPU
|
||||||
|
Accelerator.CPU.label -> Accelerator.CPU
|
||||||
|
else -> null // Ignore unknown accelerator labels
|
||||||
|
}
|
||||||
|
}
|
||||||
|
val configs: List<Config> = createLlmChatConfigs(
|
||||||
|
defaultMaxToken = convertValueToTargetType(
|
||||||
|
info.defaultValues[ConfigKey.DEFAULT_MAX_TOKENS.label]!!,
|
||||||
|
ValueType.INT
|
||||||
|
) as Int,
|
||||||
|
defaultTopK = convertValueToTargetType(
|
||||||
|
info.defaultValues[ConfigKey.DEFAULT_TOPK.label]!!,
|
||||||
|
ValueType.INT
|
||||||
|
) as Int,
|
||||||
|
defaultTopP = convertValueToTargetType(
|
||||||
|
info.defaultValues[ConfigKey.DEFAULT_TOPP.label]!!,
|
||||||
|
ValueType.FLOAT
|
||||||
|
) as Float,
|
||||||
|
defaultTemperature = convertValueToTargetType(
|
||||||
|
info.defaultValues[ConfigKey.DEFAULT_TEMPERATURE.label]!!,
|
||||||
|
ValueType.FLOAT
|
||||||
|
) as Float,
|
||||||
|
accelerators = accelerators,
|
||||||
|
)
|
||||||
|
val model = Model(
|
||||||
|
name = info.fileName,
|
||||||
|
url = "",
|
||||||
|
configs = configs,
|
||||||
|
sizeInBytes = info.fileSize,
|
||||||
|
downloadFileName = "$IMPORTS_DIR/${info.fileName}",
|
||||||
|
imported = true,
|
||||||
|
)
|
||||||
|
model.preProcess(task = task)
|
||||||
|
|
||||||
|
return model
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Retrieves the download status of a model.
|
* Retrieves the download status of a model.
|
||||||
*
|
*
|
||||||
|
@ -771,9 +804,7 @@ open class ModelManagerViewModel(
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun updateModelInitializationStatus(
|
private fun updateModelInitializationStatus(
|
||||||
model: Model,
|
model: Model, status: ModelInitializationStatusType, error: String = ""
|
||||||
status: ModelInitializationStatusType,
|
|
||||||
error: String = ""
|
|
||||||
) {
|
) {
|
||||||
val curModelInstance = uiState.value.modelInitializationStatus.toMutableMap()
|
val curModelInstance = uiState.value.modelInitializationStatus.toMutableMap()
|
||||||
curModelInstance[model.name] = ModelInitializationStatus(status = status, error = error)
|
curModelInstance[model.name] = ModelInitializationStatus(status = status, error = error)
|
||||||
|
|
|
@ -18,7 +18,7 @@ package com.google.aiedge.gallery.ui.preview
|
||||||
|
|
||||||
import com.google.aiedge.gallery.data.AccessTokenData
|
import com.google.aiedge.gallery.data.AccessTokenData
|
||||||
import com.google.aiedge.gallery.data.DataStoreRepository
|
import com.google.aiedge.gallery.data.DataStoreRepository
|
||||||
import com.google.aiedge.gallery.data.LocalModelInfo
|
import com.google.aiedge.gallery.data.ImportedModelInfo
|
||||||
|
|
||||||
class PreviewDataStoreRepository : DataStoreRepository {
|
class PreviewDataStoreRepository : DataStoreRepository {
|
||||||
override fun saveTextInputHistory(history: List<String>) {
|
override fun saveTextInputHistory(history: List<String>) {
|
||||||
|
@ -42,10 +42,10 @@ class PreviewDataStoreRepository : DataStoreRepository {
|
||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun saveLocalModels(localModels: List<LocalModelInfo>) {
|
override fun saveImportedModels(importedModels: List<ImportedModelInfo>) {
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun readLocalModels(): List<LocalModelInfo> {
|
override fun readImportedModels(): List<ImportedModelInfo> {
|
||||||
return listOf()
|
return listOf()
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -22,6 +22,7 @@ import androidx.compose.material.icons.rounded.AutoAwesome
|
||||||
import com.google.aiedge.gallery.data.BooleanSwitchConfig
|
import com.google.aiedge.gallery.data.BooleanSwitchConfig
|
||||||
import com.google.aiedge.gallery.data.Config
|
import com.google.aiedge.gallery.data.Config
|
||||||
import com.google.aiedge.gallery.data.ConfigKey
|
import com.google.aiedge.gallery.data.ConfigKey
|
||||||
|
import com.google.aiedge.gallery.data.LabelConfig
|
||||||
import com.google.aiedge.gallery.data.SegmentedButtonConfig
|
import com.google.aiedge.gallery.data.SegmentedButtonConfig
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.NumberSliderConfig
|
import com.google.aiedge.gallery.data.NumberSliderConfig
|
||||||
|
@ -30,6 +31,10 @@ import com.google.aiedge.gallery.data.TaskType
|
||||||
import com.google.aiedge.gallery.data.ValueType
|
import com.google.aiedge.gallery.data.ValueType
|
||||||
|
|
||||||
val TEST_CONFIGS1: List<Config> = listOf(
|
val TEST_CONFIGS1: List<Config> = listOf(
|
||||||
|
LabelConfig(
|
||||||
|
key = ConfigKey.NAME,
|
||||||
|
defaultValue = "Test name",
|
||||||
|
),
|
||||||
NumberSliderConfig(
|
NumberSliderConfig(
|
||||||
key = ConfigKey.MAX_RESULT_COUNT,
|
key = ConfigKey.MAX_RESULT_COUNT,
|
||||||
sliderMin = 1f,
|
sliderMin = 1f,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue