mirror of
https://github.com/google-ai-edge/gallery.git
synced 2025-07-06 06:30:30 -04:00
Add initial support for model allowlist, and stop generating response.
This commit is contained in:
parent
f9f1d71b38
commit
ef290cd7b0
28 changed files with 676 additions and 358 deletions
|
@ -32,7 +32,7 @@
|
||||||
android:dataExtractionRules="@xml/data_extraction_rules"
|
android:dataExtractionRules="@xml/data_extraction_rules"
|
||||||
android:fullBackupContent="@xml/backup_rules"
|
android:fullBackupContent="@xml/backup_rules"
|
||||||
android:icon="@mipmap/ic_launcher"
|
android:icon="@mipmap/ic_launcher"
|
||||||
android:label="@string/app_name"
|
android:label="Edge Gallery"
|
||||||
android:roundIcon="@mipmap/ic_launcher"
|
android:roundIcon="@mipmap/ic_launcher"
|
||||||
android:supportsRtl="true"
|
android:supportsRtl="true"
|
||||||
android:theme="@style/Theme.Gallery"
|
android:theme="@style/Theme.Gallery"
|
||||||
|
|
|
@ -23,7 +23,6 @@ import androidx.datastore.preferences.core.Preferences
|
||||||
import androidx.datastore.preferences.preferencesDataStore
|
import androidx.datastore.preferences.preferencesDataStore
|
||||||
import com.google.aiedge.gallery.data.AppContainer
|
import com.google.aiedge.gallery.data.AppContainer
|
||||||
import com.google.aiedge.gallery.data.DefaultAppContainer
|
import com.google.aiedge.gallery.data.DefaultAppContainer
|
||||||
import com.google.aiedge.gallery.ui.common.processTasks
|
|
||||||
import com.google.aiedge.gallery.ui.theme.ThemeSettings
|
import com.google.aiedge.gallery.ui.theme.ThemeSettings
|
||||||
|
|
||||||
private val Context.dataStore: DataStore<Preferences> by preferencesDataStore(name = "app_gallery_preferences")
|
private val Context.dataStore: DataStore<Preferences> by preferencesDataStore(name = "app_gallery_preferences")
|
||||||
|
@ -35,9 +34,6 @@ class GalleryApplication : Application() {
|
||||||
override fun onCreate() {
|
override fun onCreate() {
|
||||||
super.onCreate()
|
super.onCreate()
|
||||||
|
|
||||||
// Process tasks.
|
|
||||||
processTasks()
|
|
||||||
|
|
||||||
container = DefaultAppContainer(this, dataStore)
|
container = DefaultAppContainer(this, dataStore)
|
||||||
|
|
||||||
// Load theme.
|
// Load theme.
|
||||||
|
|
|
@ -16,7 +16,6 @@
|
||||||
|
|
||||||
package com.google.aiedge.gallery.data
|
package com.google.aiedge.gallery.data
|
||||||
|
|
||||||
import com.google.aiedge.gallery.ui.common.ensureValidFileName
|
|
||||||
import kotlinx.serialization.KSerializer
|
import kotlinx.serialization.KSerializer
|
||||||
import kotlinx.serialization.Serializable
|
import kotlinx.serialization.Serializable
|
||||||
import kotlinx.serialization.SerializationException
|
import kotlinx.serialization.SerializationException
|
||||||
|
@ -27,15 +26,6 @@ import kotlinx.serialization.encoding.Encoder
|
||||||
import kotlinx.serialization.json.JsonDecoder
|
import kotlinx.serialization.json.JsonDecoder
|
||||||
import kotlinx.serialization.json.JsonPrimitive
|
import kotlinx.serialization.json.JsonPrimitive
|
||||||
|
|
||||||
@Serializable
|
|
||||||
data class HfModelSummary(val modelId: String)
|
|
||||||
|
|
||||||
@Serializable
|
|
||||||
data class HfModelDetails(val id: String, val siblings: List<HfModelFile>)
|
|
||||||
|
|
||||||
@Serializable
|
|
||||||
data class HfModelFile(val rfilename: String)
|
|
||||||
|
|
||||||
@Serializable(with = ConfigValueSerializer::class)
|
@Serializable(with = ConfigValueSerializer::class)
|
||||||
sealed class ConfigValue {
|
sealed class ConfigValue {
|
||||||
@Serializable
|
@Serializable
|
||||||
|
@ -85,64 +75,6 @@ object ConfigValueSerializer : KSerializer<ConfigValue> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Serializable
|
|
||||||
data class HfModel(
|
|
||||||
var id: String = "",
|
|
||||||
val task: String,
|
|
||||||
val name: String,
|
|
||||||
val url: String = "",
|
|
||||||
val file: String = "",
|
|
||||||
val sizeInBytes: Long,
|
|
||||||
val configs: Map<String, ConfigValue>,
|
|
||||||
) {
|
|
||||||
fun toModel(): Model {
|
|
||||||
val parts = if (url.isNotEmpty()) {
|
|
||||||
url.split('/')
|
|
||||||
} else if (file.isNotEmpty()) {
|
|
||||||
listOf(file)
|
|
||||||
} else {
|
|
||||||
listOf("")
|
|
||||||
}
|
|
||||||
val fileName = ensureValidFileName("${id}_${(parts.lastOrNull() ?: "")}")
|
|
||||||
|
|
||||||
// Generate configs based on the given default values.
|
|
||||||
// val configs: List<Config> = when (task) {
|
|
||||||
// TASK_LLM_CHAT.type.label -> createLLmChatConfig(defaults = configs)
|
|
||||||
// // todo: add configs for other types.
|
|
||||||
// else -> listOf()
|
|
||||||
// }
|
|
||||||
// todo: fix when loading from models.json
|
|
||||||
val configs: List<Config> = listOf()
|
|
||||||
|
|
||||||
// Construct url.
|
|
||||||
var modelUrl = url
|
|
||||||
if (modelUrl.isEmpty() && file.isNotEmpty()) {
|
|
||||||
modelUrl = "https://huggingface.co/${id}/resolve/main/${file}?download=true"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Other parameters.
|
|
||||||
val showBenchmarkButton = when (task) {
|
|
||||||
TASK_LLM_CHAT.type.label -> false
|
|
||||||
else -> true
|
|
||||||
}
|
|
||||||
val showRunAgainButton = when (task) {
|
|
||||||
TASK_LLM_CHAT.type.label -> false
|
|
||||||
else -> true
|
|
||||||
}
|
|
||||||
|
|
||||||
return Model(
|
|
||||||
hfModelId = id,
|
|
||||||
name = name,
|
|
||||||
url = modelUrl,
|
|
||||||
sizeInBytes = sizeInBytes,
|
|
||||||
downloadFileName = fileName,
|
|
||||||
configs = configs,
|
|
||||||
showBenchmarkButton = showBenchmarkButton,
|
|
||||||
showRunAgainButton = showRunAgainButton,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fun getIntConfigValue(configValue: ConfigValue?, default: Int): Int {
|
fun getIntConfigValue(configValue: ConfigValue?, default: Int): Int {
|
||||||
if (configValue == null) {
|
if (configValue == null) {
|
||||||
return default
|
return default
|
|
@ -18,6 +18,8 @@ package com.google.aiedge.gallery.data
|
||||||
|
|
||||||
// Keys used to send/receive data to Work.
|
// Keys used to send/receive data to Work.
|
||||||
const val KEY_MODEL_URL = "KEY_MODEL_URL"
|
const val KEY_MODEL_URL = "KEY_MODEL_URL"
|
||||||
|
const val KEY_MODEL_VERSION = "KEY_MODEL_VERSION"
|
||||||
|
const val KEY_MODEL_DOWNLOAD_MODEL_DIR = "KEY_MODEL_DOWNLOAD_MODEL_DIR"
|
||||||
const val KEY_MODEL_DOWNLOAD_FILE_NAME = "KEY_MODEL_DOWNLOAD_FILE_NAME"
|
const val KEY_MODEL_DOWNLOAD_FILE_NAME = "KEY_MODEL_DOWNLOAD_FILE_NAME"
|
||||||
const val KEY_MODEL_TOTAL_BYTES = "KEY_MODEL_TOTAL_BYTES"
|
const val KEY_MODEL_TOTAL_BYTES = "KEY_MODEL_TOTAL_BYTES"
|
||||||
const val KEY_MODEL_DOWNLOAD_RECEIVED_BYTES = "KEY_MODEL_DOWNLOAD_RECEIVED_BYTES"
|
const val KEY_MODEL_DOWNLOAD_RECEIVED_BYTES = "KEY_MODEL_DOWNLOAD_RECEIVED_BYTES"
|
||||||
|
|
|
@ -37,13 +37,13 @@ import androidx.work.OutOfQuotaPolicy
|
||||||
import androidx.work.WorkInfo
|
import androidx.work.WorkInfo
|
||||||
import androidx.work.WorkManager
|
import androidx.work.WorkManager
|
||||||
import androidx.work.WorkQuery
|
import androidx.work.WorkQuery
|
||||||
|
import com.google.aiedge.gallery.AppLifecycleProvider
|
||||||
|
import com.google.aiedge.gallery.R
|
||||||
|
import com.google.aiedge.gallery.worker.DownloadWorker
|
||||||
import com.google.common.util.concurrent.FutureCallback
|
import com.google.common.util.concurrent.FutureCallback
|
||||||
import com.google.common.util.concurrent.Futures
|
import com.google.common.util.concurrent.Futures
|
||||||
import com.google.common.util.concurrent.ListenableFuture
|
import com.google.common.util.concurrent.ListenableFuture
|
||||||
import com.google.common.util.concurrent.MoreExecutors
|
import com.google.common.util.concurrent.MoreExecutors
|
||||||
import com.google.aiedge.gallery.AppLifecycleProvider
|
|
||||||
import com.google.aiedge.gallery.R
|
|
||||||
import com.google.aiedge.gallery.worker.DownloadWorker
|
|
||||||
import java.util.UUID
|
import java.util.UUID
|
||||||
|
|
||||||
private const val TAG = "AGDownloadRepository"
|
private const val TAG = "AGDownloadRepository"
|
||||||
|
@ -89,6 +89,8 @@ class DefaultDownloadRepository(
|
||||||
val builder = Data.Builder()
|
val builder = Data.Builder()
|
||||||
val totalBytes = model.totalBytes + model.extraDataFiles.sumOf { it.sizeInBytes }
|
val totalBytes = model.totalBytes + model.extraDataFiles.sumOf { it.sizeInBytes }
|
||||||
val inputDataBuilder = builder.putString(KEY_MODEL_URL, model.url)
|
val inputDataBuilder = builder.putString(KEY_MODEL_URL, model.url)
|
||||||
|
.putString(KEY_MODEL_VERSION, model.version)
|
||||||
|
.putString(KEY_MODEL_DOWNLOAD_MODEL_DIR, model.normalizedName)
|
||||||
.putString(KEY_MODEL_DOWNLOAD_FILE_NAME, model.downloadFileName)
|
.putString(KEY_MODEL_DOWNLOAD_FILE_NAME, model.downloadFileName)
|
||||||
.putBoolean(KEY_MODEL_IS_ZIP, model.isZip).putString(KEY_MODEL_UNZIPPED_DIR, model.unzipDir)
|
.putBoolean(KEY_MODEL_IS_ZIP, model.isZip).putString(KEY_MODEL_UNZIPPED_DIR, model.unzipDir)
|
||||||
.putLong(
|
.putLong(
|
||||||
|
|
|
@ -20,6 +20,7 @@ import android.content.Context
|
||||||
import com.google.aiedge.gallery.ui.common.chat.PromptTemplate
|
import com.google.aiedge.gallery.ui.common.chat.PromptTemplate
|
||||||
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
|
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
|
||||||
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
|
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
|
||||||
|
import java.io.File
|
||||||
|
|
||||||
data class ModelDataFile(
|
data class ModelDataFile(
|
||||||
val name: String,
|
val name: String,
|
||||||
|
@ -33,16 +34,22 @@ enum class Accelerator(val label: String) {
|
||||||
}
|
}
|
||||||
|
|
||||||
const val IMPORTS_DIR = "__imports"
|
const val IMPORTS_DIR = "__imports"
|
||||||
|
private val NORMALIZE_NAME_REGEX = Regex("[^a-zA-Z0-9]")
|
||||||
|
|
||||||
/** A model for a task */
|
/** A model for a task */
|
||||||
data class Model(
|
data class Model(
|
||||||
/** The Hugging Face model ID (if applicable). */
|
|
||||||
val hfModelId: String = "",
|
|
||||||
|
|
||||||
/** The name (for display purpose) of the model. */
|
/** The name (for display purpose) of the model. */
|
||||||
val name: String,
|
val name: String,
|
||||||
|
|
||||||
/** The name of the downloaded model file. */
|
/** The version of the model. */
|
||||||
|
val version: String = "_",
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The name of the downloaded model file.
|
||||||
|
*
|
||||||
|
* The final file path of the downloaded model will be:
|
||||||
|
* {context.getExternalFilesDir}/{normalizedName}/{version}/{downloadFileName}
|
||||||
|
*/
|
||||||
val downloadFileName: String,
|
val downloadFileName: String,
|
||||||
|
|
||||||
/** The URL to download the model from. */
|
/** The URL to download the model from. */
|
||||||
|
@ -88,6 +95,7 @@ data class Model(
|
||||||
val imported: Boolean = false,
|
val imported: Boolean = false,
|
||||||
|
|
||||||
// The following fields are managed by the app. Don't need to set manually.
|
// The following fields are managed by the app. Don't need to set manually.
|
||||||
|
var normalizedName: String = "",
|
||||||
var instance: Any? = null,
|
var instance: Any? = null,
|
||||||
var initializing: Boolean = false,
|
var initializing: Boolean = false,
|
||||||
// TODO(jingjin): use a "queue" system to manage model init and cleanup.
|
// TODO(jingjin): use a "queue" system to manage model init and cleanup.
|
||||||
|
@ -96,6 +104,10 @@ data class Model(
|
||||||
var totalBytes: Long = 0L,
|
var totalBytes: Long = 0L,
|
||||||
var accessToken: String? = null,
|
var accessToken: String? = null,
|
||||||
) {
|
) {
|
||||||
|
init {
|
||||||
|
normalizedName = NORMALIZE_NAME_REGEX.replace(name, "_")
|
||||||
|
}
|
||||||
|
|
||||||
fun preProcess() {
|
fun preProcess() {
|
||||||
val configValues: MutableMap<String, Any> = mutableMapOf()
|
val configValues: MutableMap<String, Any> = mutableMapOf()
|
||||||
for (config in this.configs) {
|
for (config in this.configs) {
|
||||||
|
@ -106,11 +118,22 @@ data class Model(
|
||||||
}
|
}
|
||||||
|
|
||||||
fun getPath(context: Context, fileName: String = downloadFileName): String {
|
fun getPath(context: Context, fileName: String = downloadFileName): String {
|
||||||
val baseDir = "${context.getExternalFilesDir(null)}"
|
if (imported) {
|
||||||
|
return listOf(context.getExternalFilesDir(null)?.absolutePath ?: "", fileName).joinToString(
|
||||||
|
File.separator
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
val baseDir =
|
||||||
|
listOf(
|
||||||
|
context.getExternalFilesDir(null)?.absolutePath ?: "",
|
||||||
|
normalizedName,
|
||||||
|
version
|
||||||
|
).joinToString(File.separator)
|
||||||
return if (this.isZip && this.unzipDir.isNotEmpty()) {
|
return if (this.isZip && this.unzipDir.isNotEmpty()) {
|
||||||
"$baseDir/${this.unzipDir}"
|
"$baseDir/${this.unzipDir}"
|
||||||
} else {
|
} else {
|
||||||
"$baseDir/${fileName}"
|
"$baseDir/$fileName"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,116 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.google.aiedge.gallery.data
|
||||||
|
|
||||||
|
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_ACCELERATORS
|
||||||
|
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TEMPERATURE
|
||||||
|
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TOPK
|
||||||
|
import com.google.aiedge.gallery.ui.llmchat.DEFAULT_TOPP
|
||||||
|
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
|
||||||
|
import kotlinx.serialization.Serializable
|
||||||
|
|
||||||
|
/** A model in the model allowlist. */
|
||||||
|
@Serializable
|
||||||
|
data class AllowedModel(
|
||||||
|
val name: String,
|
||||||
|
val modelId: String,
|
||||||
|
val modelFile: String,
|
||||||
|
val description: String,
|
||||||
|
val sizeInBytes: Long,
|
||||||
|
val version: String,
|
||||||
|
val defaultConfig: Map<String, ConfigValue>,
|
||||||
|
val taskTypes: List<String>,
|
||||||
|
val disabled: Boolean? = null,
|
||||||
|
) {
|
||||||
|
fun toModel(): Model {
|
||||||
|
// Construct HF download url.
|
||||||
|
val downloadUrl = "https://huggingface.co/$modelId/resolve/main/$modelFile?download=true"
|
||||||
|
|
||||||
|
// Config.
|
||||||
|
val isLlmModel =
|
||||||
|
taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_USECASES.type.id)
|
||||||
|
var configs: List<Config> = listOf()
|
||||||
|
if (isLlmModel) {
|
||||||
|
var defaultTopK: Int = DEFAULT_TOPK
|
||||||
|
var defaultTopP: Float = DEFAULT_TOPP
|
||||||
|
var defaultTemperature: Float = DEFAULT_TEMPERATURE
|
||||||
|
var defaultMaxToken: Int = 1024
|
||||||
|
var accelerators: List<Accelerator> = DEFAULT_ACCELERATORS
|
||||||
|
if (defaultConfig.containsKey("topK")) {
|
||||||
|
defaultTopK = getIntConfigValue(defaultConfig["topK"], defaultTopK)
|
||||||
|
}
|
||||||
|
if (defaultConfig.containsKey("topP")) {
|
||||||
|
defaultTopP = getFloatConfigValue(defaultConfig["topP"], defaultTopP)
|
||||||
|
}
|
||||||
|
if (defaultConfig.containsKey("temperature")) {
|
||||||
|
defaultTemperature = getFloatConfigValue(defaultConfig["temperature"], defaultTemperature)
|
||||||
|
}
|
||||||
|
if (defaultConfig.containsKey("maxTokens")) {
|
||||||
|
defaultMaxToken = getIntConfigValue(defaultConfig["maxTokens"], defaultMaxToken)
|
||||||
|
}
|
||||||
|
if (defaultConfig.containsKey("accelerators")) {
|
||||||
|
val items = getStringConfigValue(defaultConfig["accelerators"], "gpu").split(",")
|
||||||
|
accelerators = mutableListOf()
|
||||||
|
for (item in items) {
|
||||||
|
if (item == "cpu") {
|
||||||
|
accelerators.add(Accelerator.CPU)
|
||||||
|
} else if (item == "gpu") {
|
||||||
|
accelerators.add(Accelerator.GPU)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
configs = createLlmChatConfigs(
|
||||||
|
defaultTopK = defaultTopK,
|
||||||
|
defaultTopP = defaultTopP,
|
||||||
|
defaultTemperature = defaultTemperature,
|
||||||
|
defaultMaxToken = defaultMaxToken,
|
||||||
|
accelerators = accelerators,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Misc.
|
||||||
|
var showBenchmarkButton = true
|
||||||
|
val showRunAgainButton = true
|
||||||
|
if (taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_USECASES.type.id)) {
|
||||||
|
showBenchmarkButton = false
|
||||||
|
}
|
||||||
|
|
||||||
|
return Model(
|
||||||
|
name = name,
|
||||||
|
version = version,
|
||||||
|
info = description,
|
||||||
|
url = downloadUrl,
|
||||||
|
sizeInBytes = sizeInBytes,
|
||||||
|
configs = configs,
|
||||||
|
downloadFileName = modelFile,
|
||||||
|
showBenchmarkButton = showBenchmarkButton,
|
||||||
|
showRunAgainButton = showRunAgainButton,
|
||||||
|
learnMoreUrl = "https://huggingface.co/${modelId}"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun toString(): String {
|
||||||
|
return "$modelId/$modelFile"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** The model allowlist. */
|
||||||
|
@Serializable
|
||||||
|
data class ModelAllowlist(
|
||||||
|
val models: List<AllowedModel>,
|
||||||
|
)
|
||||||
|
|
|
@ -27,15 +27,15 @@ import androidx.compose.ui.graphics.vector.ImageVector
|
||||||
import com.google.aiedge.gallery.R
|
import com.google.aiedge.gallery.R
|
||||||
|
|
||||||
/** Type of task. */
|
/** Type of task. */
|
||||||
enum class TaskType(val label: String) {
|
enum class TaskType(val label: String, val id: String) {
|
||||||
TEXT_CLASSIFICATION("Text Classification"),
|
TEXT_CLASSIFICATION(label = "Text Classification", id = "text_classification"),
|
||||||
IMAGE_CLASSIFICATION("Image Classification"),
|
IMAGE_CLASSIFICATION(label = "Image Classification", id = "image_classification"),
|
||||||
IMAGE_GENERATION("Image Generation"),
|
IMAGE_GENERATION(label = "Image Generation", id = "image_generation"),
|
||||||
LLM_CHAT("LLM Chat"),
|
LLM_CHAT(label = "LLM Chat", id = "llm_chat"),
|
||||||
LLM_SINGLE_TURN("LLM Use Cases"),
|
LLM_USECASES(label = "LLM Use Cases", id = "llm_usecases"),
|
||||||
|
|
||||||
TEST_TASK_1("Test task 1"),
|
TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
|
||||||
TEST_TASK_2("Test task 2")
|
TEST_TASK_2(label = "Test task 2", id = "test_task_2")
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Data class for a task listed in home screen. */
|
/** Data class for a task listed in home screen. */
|
||||||
|
@ -91,17 +91,19 @@ val TASK_IMAGE_CLASSIFICATION = Task(
|
||||||
val TASK_LLM_CHAT = Task(
|
val TASK_LLM_CHAT = Task(
|
||||||
type = TaskType.LLM_CHAT,
|
type = TaskType.LLM_CHAT,
|
||||||
icon = Icons.Outlined.Forum,
|
icon = Icons.Outlined.Forum,
|
||||||
models = MODELS_LLM,
|
// models = MODELS_LLM,
|
||||||
|
models = mutableListOf(),
|
||||||
description = "Chat with a on-device large language model",
|
description = "Chat with a on-device large language model",
|
||||||
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
||||||
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
|
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
|
||||||
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
|
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
|
||||||
)
|
)
|
||||||
|
|
||||||
val TASK_LLM_SINGLE_TURN = Task(
|
val TASK_LLM_USECASES = Task(
|
||||||
type = TaskType.LLM_SINGLE_TURN,
|
type = TaskType.LLM_USECASES,
|
||||||
icon = Icons.Outlined.Widgets,
|
icon = Icons.Outlined.Widgets,
|
||||||
models = MODELS_LLM,
|
// models = MODELS_LLM,
|
||||||
|
models = mutableListOf(),
|
||||||
description = "Single turn use cases with on-device large language model",
|
description = "Single turn use cases with on-device large language model",
|
||||||
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
||||||
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
|
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmchat/LlmChatModelHelper.kt",
|
||||||
|
@ -123,7 +125,7 @@ val TASKS: List<Task> = listOf(
|
||||||
// TASK_TEXT_CLASSIFICATION,
|
// TASK_TEXT_CLASSIFICATION,
|
||||||
// TASK_IMAGE_CLASSIFICATION,
|
// TASK_IMAGE_CLASSIFICATION,
|
||||||
// TASK_IMAGE_GENERATION,
|
// TASK_IMAGE_GENERATION,
|
||||||
TASK_LLM_SINGLE_TURN,
|
TASK_LLM_USECASES,
|
||||||
TASK_LLM_CHAT,
|
TASK_LLM_CHAT,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -28,12 +28,15 @@ import androidx.compose.foundation.layout.padding
|
||||||
import androidx.compose.foundation.layout.wrapContentHeight
|
import androidx.compose.foundation.layout.wrapContentHeight
|
||||||
import androidx.compose.material.icons.Icons
|
import androidx.compose.material.icons.Icons
|
||||||
import androidx.compose.material.icons.automirrored.rounded.ArrowForward
|
import androidx.compose.material.icons.automirrored.rounded.ArrowForward
|
||||||
|
import androidx.compose.material.icons.rounded.Error
|
||||||
|
import androidx.compose.material3.AlertDialog
|
||||||
import androidx.compose.material3.Button
|
import androidx.compose.material3.Button
|
||||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||||
import androidx.compose.material3.Icon
|
import androidx.compose.material3.Icon
|
||||||
import androidx.compose.material3.MaterialTheme
|
import androidx.compose.material3.MaterialTheme
|
||||||
import androidx.compose.material3.ModalBottomSheet
|
import androidx.compose.material3.ModalBottomSheet
|
||||||
import androidx.compose.material3.Text
|
import androidx.compose.material3.Text
|
||||||
|
import androidx.compose.material3.TextButton
|
||||||
import androidx.compose.material3.rememberModalBottomSheetState
|
import androidx.compose.material3.rememberModalBottomSheetState
|
||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
import androidx.compose.runtime.getValue
|
import androidx.compose.runtime.getValue
|
||||||
|
@ -102,6 +105,7 @@ fun DownloadAndTryButton(
|
||||||
val context = LocalContext.current
|
val context = LocalContext.current
|
||||||
var checkingToken by remember { mutableStateOf(false) }
|
var checkingToken by remember { mutableStateOf(false) }
|
||||||
var showAgreementAckSheet by remember { mutableStateOf(false) }
|
var showAgreementAckSheet by remember { mutableStateOf(false) }
|
||||||
|
var showErrorDialog by remember { mutableStateOf(false) }
|
||||||
val sheetState = rememberModalBottomSheetState()
|
val sheetState = rememberModalBottomSheetState()
|
||||||
|
|
||||||
// A launcher for requesting notification permission.
|
// A launcher for requesting notification permission.
|
||||||
|
@ -208,12 +212,18 @@ fun DownloadAndTryButton(
|
||||||
TAG,
|
TAG,
|
||||||
"Model '${model.name}' is from HuggingFace. Checking if the url needs auth to download"
|
"Model '${model.name}' is from HuggingFace. Checking if the url needs auth to download"
|
||||||
)
|
)
|
||||||
if (modelManagerViewModel.getModelUrlResponse(model = model) == HttpURLConnection.HTTP_OK) {
|
val firstResponseCode = modelManagerViewModel.getModelUrlResponse(model = model)
|
||||||
|
if (firstResponseCode == HttpURLConnection.HTTP_OK) {
|
||||||
Log.d(TAG, "Model '${model.name}' doesn't need auth. Start downloading the model...")
|
Log.d(TAG, "Model '${model.name}' doesn't need auth. Start downloading the model...")
|
||||||
withContext(Dispatchers.Main) {
|
withContext(Dispatchers.Main) {
|
||||||
startDownload(null)
|
startDownload(null)
|
||||||
}
|
}
|
||||||
return@launch
|
return@launch
|
||||||
|
} else if (firstResponseCode < 0) {
|
||||||
|
checkingToken = false
|
||||||
|
Log.e(TAG, "Unknown network error")
|
||||||
|
showErrorDialog = true
|
||||||
|
return@launch
|
||||||
}
|
}
|
||||||
Log.d(TAG, "Model '${model.name}' needs auth. Start token exchange process...")
|
Log.d(TAG, "Model '${model.name}' needs auth. Start token exchange process...")
|
||||||
|
|
||||||
|
@ -334,4 +344,30 @@ fun DownloadAndTryButton(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (showErrorDialog) {
|
||||||
|
AlertDialog(
|
||||||
|
icon = {
|
||||||
|
Icon(Icons.Rounded.Error, contentDescription = "", tint = MaterialTheme.colorScheme.error)
|
||||||
|
},
|
||||||
|
title = {
|
||||||
|
Text("Unknown network error")
|
||||||
|
},
|
||||||
|
text = {
|
||||||
|
Text("Please check your internet connection.")
|
||||||
|
},
|
||||||
|
onDismissRequest = {
|
||||||
|
showErrorDialog = false
|
||||||
|
},
|
||||||
|
confirmButton = {
|
||||||
|
TextButton(
|
||||||
|
onClick = {
|
||||||
|
showErrorDialog = false
|
||||||
|
}
|
||||||
|
) {
|
||||||
|
Text("Close")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -484,5 +484,7 @@ fun processLlmResponse(response: String): String {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
newContent = newContent.replace("\\n", "\n")
|
||||||
|
|
||||||
return newContent
|
return newContent
|
||||||
}
|
}
|
|
@ -24,6 +24,7 @@ import com.google.aiedge.gallery.data.Model
|
||||||
|
|
||||||
enum class ChatMessageType {
|
enum class ChatMessageType {
|
||||||
INFO,
|
INFO,
|
||||||
|
WARNING,
|
||||||
TEXT,
|
TEXT,
|
||||||
IMAGE,
|
IMAGE,
|
||||||
IMAGE_WITH_HISTORY,
|
IMAGE_WITH_HISTORY,
|
||||||
|
@ -57,6 +58,10 @@ class ChatMessageLoading : ChatMessage(type = ChatMessageType.LOADING, side = Ch
|
||||||
class ChatMessageInfo(val content: String) :
|
class ChatMessageInfo(val content: String) :
|
||||||
ChatMessage(type = ChatMessageType.INFO, side = ChatSide.SYSTEM)
|
ChatMessage(type = ChatMessageType.INFO, side = ChatSide.SYSTEM)
|
||||||
|
|
||||||
|
/** Chat message for info (help). */
|
||||||
|
class ChatMessageWarning(val content: String) :
|
||||||
|
ChatMessage(type = ChatMessageType.WARNING, side = ChatSide.SYSTEM)
|
||||||
|
|
||||||
/** Chat message for config values change. */
|
/** Chat message for config values change. */
|
||||||
class ChatMessageConfigValuesChange(
|
class ChatMessageConfigValuesChange(
|
||||||
val model: Model,
|
val model: Model,
|
||||||
|
|
|
@ -269,6 +269,9 @@ fun ChatPanel(
|
||||||
// Info.
|
// Info.
|
||||||
is ChatMessageInfo -> MessageBodyInfo(message = message)
|
is ChatMessageInfo -> MessageBodyInfo(message = message)
|
||||||
|
|
||||||
|
// Warning
|
||||||
|
is ChatMessageWarning -> MessageBodyWarning(message = message)
|
||||||
|
|
||||||
// Config values change.
|
// Config values change.
|
||||||
is ChatMessageConfigValuesChange -> MessageBodyConfigUpdate(message = message)
|
is ChatMessageConfigValuesChange -> MessageBodyConfigUpdate(message = message)
|
||||||
|
|
||||||
|
@ -433,6 +436,7 @@ fun ChatPanel(
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
curMessage = curMessage,
|
curMessage = curMessage,
|
||||||
inProgress = uiState.inProgress,
|
inProgress = uiState.inProgress,
|
||||||
|
modelInitializing = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
|
||||||
textFieldPlaceHolderRes = task.textInputPlaceHolderRes,
|
textFieldPlaceHolderRes = task.textInputPlaceHolderRes,
|
||||||
onValueChanged = { curMessage = it },
|
onValueChanged = { curMessage = it },
|
||||||
onSendMessage = {
|
onSendMessage = {
|
||||||
|
|
|
@ -0,0 +1,62 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.google.aiedge.gallery.ui.common.chat
|
||||||
|
|
||||||
|
import androidx.compose.foundation.background
|
||||||
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
|
import androidx.compose.foundation.layout.Box
|
||||||
|
import androidx.compose.foundation.layout.Row
|
||||||
|
import androidx.compose.foundation.layout.fillMaxWidth
|
||||||
|
import androidx.compose.foundation.layout.padding
|
||||||
|
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||||
|
import androidx.compose.material3.MaterialTheme
|
||||||
|
import androidx.compose.runtime.Composable
|
||||||
|
import androidx.compose.ui.Modifier
|
||||||
|
import androidx.compose.ui.draw.clip
|
||||||
|
import androidx.compose.ui.tooling.preview.Preview
|
||||||
|
import androidx.compose.ui.unit.dp
|
||||||
|
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Composable function to display warning message content within a chat.
|
||||||
|
*
|
||||||
|
* Supports markdown.
|
||||||
|
*/
|
||||||
|
@Composable
|
||||||
|
fun MessageBodyWarning(message: ChatMessageWarning) {
|
||||||
|
Row(
|
||||||
|
modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center
|
||||||
|
) {
|
||||||
|
Box(
|
||||||
|
modifier = Modifier
|
||||||
|
.clip(RoundedCornerShape(16.dp))
|
||||||
|
.background(MaterialTheme.colorScheme.tertiaryContainer)
|
||||||
|
) {
|
||||||
|
MarkdownText(text = message.content, modifier = Modifier.padding(12.dp), smallFontSize = true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Preview(showBackground = true)
|
||||||
|
@Composable
|
||||||
|
fun MessageBodyWarningPreview() {
|
||||||
|
GalleryTheme {
|
||||||
|
Row(modifier = Modifier.padding(16.dp)) {
|
||||||
|
MessageBodyWarning(message = ChatMessageWarning(content = "This is a warning"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -75,6 +75,7 @@ fun MessageInputText(
|
||||||
modelManagerViewModel: ModelManagerViewModel,
|
modelManagerViewModel: ModelManagerViewModel,
|
||||||
curMessage: String,
|
curMessage: String,
|
||||||
inProgress: Boolean,
|
inProgress: Boolean,
|
||||||
|
modelInitializing: Boolean,
|
||||||
@StringRes textFieldPlaceHolderRes: Int,
|
@StringRes textFieldPlaceHolderRes: Int,
|
||||||
onValueChanged: (String) -> Unit,
|
onValueChanged: (String) -> Unit,
|
||||||
onSendMessage: (ChatMessage) -> Unit,
|
onSendMessage: (ChatMessage) -> Unit,
|
||||||
|
@ -162,17 +163,19 @@ fun MessageInputText(
|
||||||
Spacer(modifier = Modifier.width(8.dp))
|
Spacer(modifier = Modifier.width(8.dp))
|
||||||
|
|
||||||
if (inProgress && showStopButtonWhenInProgress) {
|
if (inProgress && showStopButtonWhenInProgress) {
|
||||||
IconButton(
|
if (!modelInitializing) {
|
||||||
onClick = onStopButtonClicked,
|
IconButton(
|
||||||
colors = IconButtonDefaults.iconButtonColors(
|
onClick = onStopButtonClicked,
|
||||||
containerColor = MaterialTheme.colorScheme.secondaryContainer,
|
colors = IconButtonDefaults.iconButtonColors(
|
||||||
),
|
containerColor = MaterialTheme.colorScheme.secondaryContainer,
|
||||||
) {
|
),
|
||||||
Icon(
|
) {
|
||||||
Icons.Rounded.Stop,
|
Icon(
|
||||||
contentDescription = "",
|
Icons.Rounded.Stop,
|
||||||
tint = MaterialTheme.colorScheme.primary
|
contentDescription = "",
|
||||||
)
|
tint = MaterialTheme.colorScheme.primary
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} // Send button. Only shown when text is not empty.
|
} // Send button. Only shown when text is not empty.
|
||||||
else if (curMessage.isNotEmpty()) {
|
else if (curMessage.isNotEmpty()) {
|
||||||
|
@ -230,6 +233,7 @@ fun MessageInputTextPreview() {
|
||||||
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
||||||
curMessage = "hello",
|
curMessage = "hello",
|
||||||
inProgress = false,
|
inProgress = false,
|
||||||
|
modelInitializing = false,
|
||||||
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
|
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
|
||||||
onValueChanged = {},
|
onValueChanged = {},
|
||||||
onSendMessage = {},
|
onSendMessage = {},
|
||||||
|
@ -239,6 +243,7 @@ fun MessageInputTextPreview() {
|
||||||
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
||||||
curMessage = "hello",
|
curMessage = "hello",
|
||||||
inProgress = true,
|
inProgress = true,
|
||||||
|
modelInitializing = false,
|
||||||
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
|
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
|
||||||
onValueChanged = {},
|
onValueChanged = {},
|
||||||
onSendMessage = {},
|
onSendMessage = {},
|
||||||
|
@ -247,6 +252,7 @@ fun MessageInputTextPreview() {
|
||||||
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
||||||
curMessage = "",
|
curMessage = "",
|
||||||
inProgress = false,
|
inProgress = false,
|
||||||
|
modelInitializing = false,
|
||||||
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
|
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
|
||||||
onValueChanged = {},
|
onValueChanged = {},
|
||||||
onSendMessage = {},
|
onSendMessage = {},
|
||||||
|
@ -255,6 +261,7 @@ fun MessageInputTextPreview() {
|
||||||
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
|
||||||
curMessage = "",
|
curMessage = "",
|
||||||
inProgress = true,
|
inProgress = true,
|
||||||
|
modelInitializing = false,
|
||||||
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
|
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
|
||||||
onValueChanged = {},
|
onValueChanged = {},
|
||||||
onSendMessage = {},
|
onSendMessage = {},
|
||||||
|
|
|
@ -39,7 +39,6 @@ import androidx.compose.material.icons.rounded.ChevronRight
|
||||||
import androidx.compose.material.icons.rounded.Settings
|
import androidx.compose.material.icons.rounded.Settings
|
||||||
import androidx.compose.material.icons.rounded.UnfoldLess
|
import androidx.compose.material.icons.rounded.UnfoldLess
|
||||||
import androidx.compose.material.icons.rounded.UnfoldMore
|
import androidx.compose.material.icons.rounded.UnfoldMore
|
||||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
|
||||||
import androidx.compose.material3.Icon
|
import androidx.compose.material3.Icon
|
||||||
import androidx.compose.material3.IconButton
|
import androidx.compose.material3.IconButton
|
||||||
import androidx.compose.material3.OutlinedButton
|
import androidx.compose.material3.OutlinedButton
|
||||||
|
|
|
@ -75,6 +75,8 @@ fun ModelNameAndStatus(
|
||||||
) {
|
) {
|
||||||
Text(
|
Text(
|
||||||
model.name,
|
model.name,
|
||||||
|
maxLines = 1,
|
||||||
|
overflow = TextOverflow.Ellipsis,
|
||||||
style = MaterialTheme.typography.titleMedium,
|
style = MaterialTheme.typography.titleMedium,
|
||||||
modifier = modifier,
|
modifier = modifier,
|
||||||
)
|
)
|
||||||
|
|
|
@ -38,6 +38,7 @@ import androidx.compose.foundation.layout.fillMaxSize
|
||||||
import androidx.compose.foundation.layout.fillMaxWidth
|
import androidx.compose.foundation.layout.fillMaxWidth
|
||||||
import androidx.compose.foundation.layout.height
|
import androidx.compose.foundation.layout.height
|
||||||
import androidx.compose.foundation.layout.padding
|
import androidx.compose.foundation.layout.padding
|
||||||
|
import androidx.compose.foundation.layout.size
|
||||||
import androidx.compose.foundation.lazy.grid.GridCells
|
import androidx.compose.foundation.lazy.grid.GridCells
|
||||||
import androidx.compose.foundation.lazy.grid.GridItemSpan
|
import androidx.compose.foundation.lazy.grid.GridItemSpan
|
||||||
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
|
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
|
||||||
|
@ -46,8 +47,11 @@ import androidx.compose.foundation.shape.RoundedCornerShape
|
||||||
import androidx.compose.material.icons.Icons
|
import androidx.compose.material.icons.Icons
|
||||||
import androidx.compose.material.icons.automirrored.outlined.NoteAdd
|
import androidx.compose.material.icons.automirrored.outlined.NoteAdd
|
||||||
import androidx.compose.material.icons.filled.Add
|
import androidx.compose.material.icons.filled.Add
|
||||||
|
import androidx.compose.material.icons.rounded.Error
|
||||||
|
import androidx.compose.material3.AlertDialog
|
||||||
import androidx.compose.material3.Card
|
import androidx.compose.material3.Card
|
||||||
import androidx.compose.material3.CardDefaults
|
import androidx.compose.material3.CardDefaults
|
||||||
|
import androidx.compose.material3.CircularProgressIndicator
|
||||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||||
import androidx.compose.material3.Icon
|
import androidx.compose.material3.Icon
|
||||||
import androidx.compose.material3.MaterialTheme
|
import androidx.compose.material3.MaterialTheme
|
||||||
|
@ -57,6 +61,7 @@ import androidx.compose.material3.SmallFloatingActionButton
|
||||||
import androidx.compose.material3.SnackbarHost
|
import androidx.compose.material3.SnackbarHost
|
||||||
import androidx.compose.material3.SnackbarHostState
|
import androidx.compose.material3.SnackbarHostState
|
||||||
import androidx.compose.material3.Text
|
import androidx.compose.material3.Text
|
||||||
|
import androidx.compose.material3.TextButton
|
||||||
import androidx.compose.material3.TopAppBarDefaults
|
import androidx.compose.material3.TopAppBarDefaults
|
||||||
import androidx.compose.material3.rememberModalBottomSheetState
|
import androidx.compose.material3.rememberModalBottomSheetState
|
||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
|
@ -180,6 +185,7 @@ fun HomeScreen(
|
||||||
TaskList(
|
TaskList(
|
||||||
tasks = tasks,
|
tasks = tasks,
|
||||||
navigateToTaskScreen = navigateToTaskScreen,
|
navigateToTaskScreen = navigateToTaskScreen,
|
||||||
|
loadingModelAllowlist = uiState.loadingModelAllowlist,
|
||||||
modifier = Modifier.fillMaxSize(),
|
modifier = Modifier.fillMaxSize(),
|
||||||
contentPadding = innerPadding,
|
contentPadding = innerPadding,
|
||||||
)
|
)
|
||||||
|
@ -285,12 +291,40 @@ fun HomeScreen(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (uiState.loadingModelAllowlistError.isNotEmpty()) {
|
||||||
|
AlertDialog(
|
||||||
|
icon = {
|
||||||
|
Icon(Icons.Rounded.Error, contentDescription = "", tint = MaterialTheme.colorScheme.error)
|
||||||
|
},
|
||||||
|
title = {
|
||||||
|
Text(uiState.loadingModelAllowlistError)
|
||||||
|
},
|
||||||
|
text = {
|
||||||
|
Text("Please check your internet connection and try again later.")
|
||||||
|
},
|
||||||
|
onDismissRequest = {
|
||||||
|
modelManagerViewModel.loadModelAllowlist()
|
||||||
|
},
|
||||||
|
confirmButton = {
|
||||||
|
TextButton(
|
||||||
|
onClick = {
|
||||||
|
modelManagerViewModel.loadModelAllowlist()
|
||||||
|
}
|
||||||
|
) {
|
||||||
|
Text("Retry")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
private fun TaskList(
|
private fun TaskList(
|
||||||
tasks: List<Task>,
|
tasks: List<Task>,
|
||||||
navigateToTaskScreen: (Task) -> Unit,
|
navigateToTaskScreen: (Task) -> Unit,
|
||||||
|
loadingModelAllowlist: Boolean,
|
||||||
modifier: Modifier = Modifier,
|
modifier: Modifier = Modifier,
|
||||||
contentPadding: PaddingValues = PaddingValues(0.dp),
|
contentPadding: PaddingValues = PaddingValues(0.dp),
|
||||||
) {
|
) {
|
||||||
|
@ -312,17 +346,37 @@ private fun TaskList(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cards.
|
if (loadingModelAllowlist) {
|
||||||
items(tasks) { task ->
|
item(key = "loading", span = { GridItemSpan(2) }) {
|
||||||
TaskCard(
|
Row(
|
||||||
task = task,
|
horizontalArrangement = Arrangement.Center,
|
||||||
onClick = {
|
modifier = Modifier
|
||||||
navigateToTaskScreen(task)
|
.fillMaxWidth()
|
||||||
},
|
.padding(top = 32.dp)
|
||||||
modifier = Modifier
|
) {
|
||||||
.fillMaxWidth()
|
CircularProgressIndicator(
|
||||||
.aspectRatio(1f)
|
trackColor = MaterialTheme.colorScheme.surfaceVariant,
|
||||||
)
|
strokeWidth = 3.dp,
|
||||||
|
modifier = Modifier
|
||||||
|
.padding(end = 8.dp)
|
||||||
|
.size(20.dp)
|
||||||
|
)
|
||||||
|
Text("Loading model list...", style = MaterialTheme.typography.bodyMedium)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Cards.
|
||||||
|
items(tasks) { task ->
|
||||||
|
TaskCard(
|
||||||
|
task = task,
|
||||||
|
onClick = {
|
||||||
|
navigateToTaskScreen(task)
|
||||||
|
},
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.aspectRatio(1f)
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bottom padding.
|
// Bottom padding.
|
||||||
|
|
|
@ -16,11 +16,15 @@
|
||||||
|
|
||||||
package com.google.aiedge.gallery.ui.llmchat
|
package com.google.aiedge.gallery.ui.llmchat
|
||||||
|
|
||||||
|
import android.util.Log
|
||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
|
import androidx.compose.ui.platform.LocalContext
|
||||||
import androidx.lifecycle.viewmodel.compose.viewModel
|
import androidx.lifecycle.viewmodel.compose.viewModel
|
||||||
import com.google.aiedge.gallery.ui.ViewModelProvider
|
import com.google.aiedge.gallery.ui.ViewModelProvider
|
||||||
|
import com.google.aiedge.gallery.ui.common.chat.ChatMessageInfo
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageText
|
import com.google.aiedge.gallery.ui.common.chat.ChatMessageText
|
||||||
|
import com.google.aiedge.gallery.ui.common.chat.ChatMessageWarning
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ChatView
|
import com.google.aiedge.gallery.ui.common.chat.ChatView
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
import kotlinx.serialization.Serializable
|
import kotlinx.serialization.Serializable
|
||||||
|
@ -40,6 +44,8 @@ fun LlmChatScreen(
|
||||||
factory = ViewModelProvider.Factory
|
factory = ViewModelProvider.Factory
|
||||||
),
|
),
|
||||||
) {
|
) {
|
||||||
|
val context = LocalContext.current
|
||||||
|
|
||||||
ChatView(
|
ChatView(
|
||||||
task = viewModel.task,
|
task = viewModel.task,
|
||||||
viewModel = viewModel,
|
viewModel = viewModel,
|
||||||
|
@ -51,25 +57,43 @@ fun LlmChatScreen(
|
||||||
)
|
)
|
||||||
if (message is ChatMessageText) {
|
if (message is ChatMessageText) {
|
||||||
modelManagerViewModel.addTextInputHistory(message.content)
|
modelManagerViewModel.addTextInputHistory(message.content)
|
||||||
viewModel.generateResponse(
|
viewModel.generateResponse(model = model, input = message.content, onError = {
|
||||||
model = model,
|
viewModel.addMessage(
|
||||||
input = message.content,
|
model = model,
|
||||||
)
|
message = ChatMessageWarning(content = "Error occurred. Re-initializing the engine.")
|
||||||
|
)
|
||||||
|
|
||||||
|
modelManagerViewModel.initializeModel(
|
||||||
|
context = context, task = viewModel.task, model = model, force = true
|
||||||
|
)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
onRunAgainClicked = { model, message ->
|
onRunAgainClicked = { model, message ->
|
||||||
if (message is ChatMessageText) {
|
if (message is ChatMessageText) {
|
||||||
viewModel.runAgain(model = model, message = message)
|
viewModel.runAgain(model = model, message = message, onError = {
|
||||||
|
viewModel.addMessage(
|
||||||
|
model = model,
|
||||||
|
message = ChatMessageWarning(content = "Error occurred. Re-initializing the engine.")
|
||||||
|
)
|
||||||
|
|
||||||
|
modelManagerViewModel.initializeModel(
|
||||||
|
context = context, task = viewModel.task, model = model, force = true
|
||||||
|
)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
onBenchmarkClicked = { model, message, warmUpIterations, benchmarkIterations ->
|
onBenchmarkClicked = { model, message, warmUpIterations, benchmarkIterations ->
|
||||||
if (message is ChatMessageText) {
|
if (message is ChatMessageText) {
|
||||||
viewModel.benchmark(
|
viewModel.benchmark(
|
||||||
model = model,
|
model = model, message = message
|
||||||
message = message
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
showStopButtonInInputWhenInProgress = true,
|
||||||
|
onStopButtonClicked = { model ->
|
||||||
|
viewModel.stopResponse(model = model)
|
||||||
|
},
|
||||||
navigateUp = navigateUp,
|
navigateUp = navigateUp,
|
||||||
modifier = modifier,
|
modifier = modifier,
|
||||||
)
|
)
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package com.google.aiedge.gallery.ui.llmchat
|
package com.google.aiedge.gallery.ui.llmchat
|
||||||
|
|
||||||
|
import android.util.Log
|
||||||
import androidx.lifecycle.viewModelScope
|
import androidx.lifecycle.viewModelScope
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||||
|
@ -26,6 +27,7 @@ import com.google.aiedge.gallery.ui.common.chat.ChatMessageType
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ChatSide
|
import com.google.aiedge.gallery.ui.common.chat.ChatSide
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ChatViewModel
|
import com.google.aiedge.gallery.ui.common.chat.ChatViewModel
|
||||||
import com.google.aiedge.gallery.ui.common.chat.Stat
|
import com.google.aiedge.gallery.ui.common.chat.Stat
|
||||||
|
import kotlinx.coroutines.CoroutineExceptionHandler
|
||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
||||||
import kotlinx.coroutines.delay
|
import kotlinx.coroutines.delay
|
||||||
import kotlinx.coroutines.launch
|
import kotlinx.coroutines.launch
|
||||||
|
@ -39,7 +41,7 @@ private val STATS = listOf(
|
||||||
)
|
)
|
||||||
|
|
||||||
class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
||||||
fun generateResponse(model: Model, input: String) {
|
fun generateResponse(model: Model, input: String, onError: () -> Unit) {
|
||||||
viewModelScope.launch(Dispatchers.Default) {
|
viewModelScope.launch(Dispatchers.Default) {
|
||||||
setInProgress(true)
|
setInProgress(true)
|
||||||
|
|
||||||
|
@ -65,75 +67,90 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
||||||
var prefillSpeed = 0f
|
var prefillSpeed = 0f
|
||||||
var decodeSpeed: Float
|
var decodeSpeed: Float
|
||||||
val start = System.currentTimeMillis()
|
val start = System.currentTimeMillis()
|
||||||
LlmChatModelHelper.runInference(
|
|
||||||
model = model,
|
|
||||||
input = input,
|
|
||||||
resultListener = { partialResult, done ->
|
|
||||||
val curTs = System.currentTimeMillis()
|
|
||||||
|
|
||||||
if (firstRun) {
|
try {
|
||||||
firstTokenTs = System.currentTimeMillis()
|
LlmChatModelHelper.runInference(
|
||||||
timeToFirstToken = (firstTokenTs - start) / 1000f
|
model = model,
|
||||||
prefillSpeed = prefillTokens / timeToFirstToken
|
input = input,
|
||||||
firstRun = false
|
resultListener = { partialResult, done ->
|
||||||
} else {
|
val curTs = System.currentTimeMillis()
|
||||||
decodeTokens++
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove the last message if it is a "loading" message.
|
if (firstRun) {
|
||||||
// This will only be done once.
|
firstTokenTs = System.currentTimeMillis()
|
||||||
val lastMessage = getLastMessage(model = model)
|
timeToFirstToken = (firstTokenTs - start) / 1000f
|
||||||
if (lastMessage?.type == ChatMessageType.LOADING) {
|
prefillSpeed = prefillTokens / timeToFirstToken
|
||||||
removeLastMessage(model = model)
|
firstRun = false
|
||||||
|
} else {
|
||||||
// Add an empty message that will receive streaming results.
|
decodeTokens++
|
||||||
addMessage(
|
|
||||||
model = model,
|
|
||||||
message = ChatMessageText(content = "", side = ChatSide.AGENT)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Incrementally update the streamed partial results.
|
|
||||||
val latencyMs: Long = if (done) System.currentTimeMillis() - start else -1
|
|
||||||
updateLastTextMessageContentIncrementally(
|
|
||||||
model = model,
|
|
||||||
partialContent = partialResult,
|
|
||||||
latencyMs = latencyMs.toFloat()
|
|
||||||
)
|
|
||||||
|
|
||||||
if (done) {
|
|
||||||
setInProgress(false)
|
|
||||||
|
|
||||||
decodeSpeed =
|
|
||||||
decodeTokens / ((curTs - firstTokenTs) / 1000f)
|
|
||||||
if (decodeSpeed.isNaN()) {
|
|
||||||
decodeSpeed = 0f
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (lastMessage is ChatMessageText) {
|
// Remove the last message if it is a "loading" message.
|
||||||
updateLastTextMessageLlmBenchmarkResult(
|
// This will only be done once.
|
||||||
model = model, llmBenchmarkResult =
|
val lastMessage = getLastMessage(model = model)
|
||||||
ChatMessageBenchmarkLlmResult(
|
if (lastMessage?.type == ChatMessageType.LOADING) {
|
||||||
orderedStats = STATS,
|
removeLastMessage(model = model)
|
||||||
statValues = mutableMapOf(
|
|
||||||
"prefill_speed" to prefillSpeed,
|
// Add an empty message that will receive streaming results.
|
||||||
"decode_speed" to decodeSpeed,
|
addMessage(
|
||||||
"time_to_first_token" to timeToFirstToken,
|
model = model,
|
||||||
"latency" to (curTs - start).toFloat() / 1000f,
|
message = ChatMessageText(content = "", side = ChatSide.AGENT)
|
||||||
),
|
|
||||||
running = false,
|
|
||||||
latencyMs = -1f,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}, cleanUpListener = {
|
// Incrementally update the streamed partial results.
|
||||||
setInProgress(false)
|
val latencyMs: Long = if (done) System.currentTimeMillis() - start else -1
|
||||||
})
|
updateLastTextMessageContentIncrementally(
|
||||||
|
model = model,
|
||||||
|
partialContent = partialResult,
|
||||||
|
latencyMs = latencyMs.toFloat()
|
||||||
|
)
|
||||||
|
|
||||||
|
if (done) {
|
||||||
|
setInProgress(false)
|
||||||
|
|
||||||
|
decodeSpeed =
|
||||||
|
decodeTokens / ((curTs - firstTokenTs) / 1000f)
|
||||||
|
if (decodeSpeed.isNaN()) {
|
||||||
|
decodeSpeed = 0f
|
||||||
|
}
|
||||||
|
|
||||||
|
if (lastMessage is ChatMessageText) {
|
||||||
|
updateLastTextMessageLlmBenchmarkResult(
|
||||||
|
model = model, llmBenchmarkResult =
|
||||||
|
ChatMessageBenchmarkLlmResult(
|
||||||
|
orderedStats = STATS,
|
||||||
|
statValues = mutableMapOf(
|
||||||
|
"prefill_speed" to prefillSpeed,
|
||||||
|
"decode_speed" to decodeSpeed,
|
||||||
|
"time_to_first_token" to timeToFirstToken,
|
||||||
|
"latency" to (curTs - start).toFloat() / 1000f,
|
||||||
|
),
|
||||||
|
running = false,
|
||||||
|
latencyMs = -1f,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, cleanUpListener = {
|
||||||
|
setInProgress(false)
|
||||||
|
})
|
||||||
|
} catch (e: Exception) {
|
||||||
|
setInProgress(false)
|
||||||
|
onError()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun runAgain(model: Model, message: ChatMessageText) {
|
fun stopResponse(model: Model) {
|
||||||
|
Log.d(TAG, "Stopping response for model ${model.name}...")
|
||||||
|
viewModelScope.launch(Dispatchers.Default) {
|
||||||
|
setInProgress(false)
|
||||||
|
val instance = model.instance as LlmModelInstance
|
||||||
|
instance.session.cancelGenerateResponseAsync()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun runAgain(model: Model, message: ChatMessageText, onError: () -> Unit) {
|
||||||
viewModelScope.launch(Dispatchers.Default) {
|
viewModelScope.launch(Dispatchers.Default) {
|
||||||
// Wait for model to be initialized.
|
// Wait for model to be initialized.
|
||||||
while (model.instance == null) {
|
while (model.instance == null) {
|
||||||
|
@ -147,6 +164,7 @@ class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
|
||||||
generateResponse(
|
generateResponse(
|
||||||
model = model,
|
model = model,
|
||||||
input = message.content,
|
input = message.content,
|
||||||
|
onError = onError
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,7 +20,7 @@ import android.util.Log
|
||||||
import androidx.lifecycle.ViewModel
|
import androidx.lifecycle.ViewModel
|
||||||
import androidx.lifecycle.viewModelScope
|
import androidx.lifecycle.viewModelScope
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
|
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
|
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
|
||||||
import com.google.aiedge.gallery.ui.common.chat.Stat
|
import com.google.aiedge.gallery.ui.common.chat.Stat
|
||||||
|
@ -64,7 +64,7 @@ private val STATS = listOf(
|
||||||
Stat(id = "latency", label = "Latency", unit = "sec")
|
Stat(id = "latency", label = "Latency", unit = "sec")
|
||||||
)
|
)
|
||||||
|
|
||||||
open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_SINGLE_TURN) : ViewModel() {
|
open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_USECASES) : ViewModel() {
|
||||||
private val _uiState = MutableStateFlow(createUiState(task = task))
|
private val _uiState = MutableStateFlow(createUiState(task = task))
|
||||||
val uiState = _uiState.asStateFlow()
|
val uiState = _uiState.asStateFlow()
|
||||||
|
|
||||||
|
|
|
@ -218,6 +218,7 @@ fun PromptTemplatesPanel(
|
||||||
.clip(MessageBubbleShape(radius = bubbleBorderRadius))
|
.clip(MessageBubbleShape(radius = bubbleBorderRadius))
|
||||||
.background(MaterialTheme.customColors.agentBubbleBgColor)
|
.background(MaterialTheme.customColors.agentBubbleBgColor)
|
||||||
.padding(16.dp)
|
.padding(16.dp)
|
||||||
|
.focusRequester(focusRequester)
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
TextField(
|
TextField(
|
||||||
|
|
|
@ -57,7 +57,7 @@ import androidx.compose.ui.platform.LocalClipboardManager
|
||||||
import androidx.compose.ui.text.AnnotatedString
|
import androidx.compose.ui.text.AnnotatedString
|
||||||
import androidx.compose.ui.unit.dp
|
import androidx.compose.ui.unit.dp
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
|
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
|
||||||
import com.google.aiedge.gallery.ui.common.chat.MarkdownText
|
import com.google.aiedge.gallery.ui.common.chat.MarkdownText
|
||||||
import com.google.aiedge.gallery.ui.common.chat.MessageBodyBenchmarkLlm
|
import com.google.aiedge.gallery.ui.common.chat.MessageBodyBenchmarkLlm
|
||||||
import com.google.aiedge.gallery.ui.common.chat.MessageBodyLoading
|
import com.google.aiedge.gallery.ui.common.chat.MessageBodyLoading
|
||||||
|
@ -76,7 +76,7 @@ fun ResponsePanel(
|
||||||
modelManagerViewModel: ModelManagerViewModel,
|
modelManagerViewModel: ModelManagerViewModel,
|
||||||
modifier: Modifier = Modifier,
|
modifier: Modifier = Modifier,
|
||||||
) {
|
) {
|
||||||
val task = TASK_LLM_SINGLE_TURN
|
val task = TASK_LLM_USECASES
|
||||||
val uiState by viewModel.uiState.collectAsState()
|
val uiState by viewModel.uiState.collectAsState()
|
||||||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||||
val inProgress = uiState.inProgress
|
val inProgress = uiState.inProgress
|
||||||
|
|
|
@ -16,8 +16,6 @@
|
||||||
|
|
||||||
package com.google.aiedge.gallery.ui.modelmanager
|
package com.google.aiedge.gallery.ui.modelmanager
|
||||||
|
|
||||||
import android.os.Build
|
|
||||||
import androidx.annotation.RequiresApi
|
|
||||||
import androidx.compose.foundation.clickable
|
import androidx.compose.foundation.clickable
|
||||||
import androidx.compose.foundation.layout.Arrangement
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
import androidx.compose.foundation.layout.Box
|
import androidx.compose.foundation.layout.Box
|
||||||
|
@ -62,7 +60,6 @@ import com.google.aiedge.gallery.ui.theme.customColors
|
||||||
private const val TAG = "AGModelList"
|
private const val TAG = "AGModelList"
|
||||||
|
|
||||||
/** The list of models in the model manager. */
|
/** The list of models in the model manager. */
|
||||||
@RequiresApi(Build.VERSION_CODES.O)
|
|
||||||
@Composable
|
@Composable
|
||||||
fun ModelList(
|
fun ModelList(
|
||||||
task: Task,
|
task: Task,
|
||||||
|
@ -213,7 +210,6 @@ fun ClickableLink(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@RequiresApi(Build.VERSION_CODES.O)
|
|
||||||
@Preview(showBackground = true)
|
@Preview(showBackground = true)
|
||||||
@Composable
|
@Composable
|
||||||
fun ModelListPreview() {
|
fun ModelListPreview() {
|
||||||
|
|
|
@ -30,32 +30,28 @@ import com.google.aiedge.gallery.data.ConfigKey
|
||||||
import com.google.aiedge.gallery.data.DataStoreRepository
|
import com.google.aiedge.gallery.data.DataStoreRepository
|
||||||
import com.google.aiedge.gallery.data.DownloadRepository
|
import com.google.aiedge.gallery.data.DownloadRepository
|
||||||
import com.google.aiedge.gallery.data.EMPTY_MODEL
|
import com.google.aiedge.gallery.data.EMPTY_MODEL
|
||||||
import com.google.aiedge.gallery.data.HfModel
|
|
||||||
import com.google.aiedge.gallery.data.HfModelDetails
|
|
||||||
import com.google.aiedge.gallery.data.HfModelSummary
|
|
||||||
import com.google.aiedge.gallery.data.IMPORTS_DIR
|
import com.google.aiedge.gallery.data.IMPORTS_DIR
|
||||||
import com.google.aiedge.gallery.data.ImportedModelInfo
|
import com.google.aiedge.gallery.data.ImportedModelInfo
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
|
import com.google.aiedge.gallery.data.ModelAllowlist
|
||||||
import com.google.aiedge.gallery.data.ModelDownloadStatus
|
import com.google.aiedge.gallery.data.ModelDownloadStatus
|
||||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||||
import com.google.aiedge.gallery.data.TASKS
|
import com.google.aiedge.gallery.data.TASKS
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
|
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.data.TaskType
|
import com.google.aiedge.gallery.data.TaskType
|
||||||
import com.google.aiedge.gallery.data.ValueType
|
import com.google.aiedge.gallery.data.ValueType
|
||||||
import com.google.aiedge.gallery.data.getModelByName
|
import com.google.aiedge.gallery.data.getModelByName
|
||||||
import com.google.aiedge.gallery.ui.common.AuthConfig
|
import com.google.aiedge.gallery.ui.common.AuthConfig
|
||||||
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
|
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
|
||||||
|
import com.google.aiedge.gallery.ui.common.processTasks
|
||||||
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper
|
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper
|
||||||
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper
|
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper
|
||||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
|
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
|
||||||
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
|
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
|
||||||
import com.google.aiedge.gallery.ui.textclassification.TextClassificationModelHelper
|
import com.google.aiedge.gallery.ui.textclassification.TextClassificationModelHelper
|
||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
||||||
import kotlinx.coroutines.async
|
|
||||||
import kotlinx.coroutines.awaitAll
|
|
||||||
import kotlinx.coroutines.coroutineScope
|
|
||||||
import kotlinx.coroutines.delay
|
import kotlinx.coroutines.delay
|
||||||
import kotlinx.coroutines.flow.MutableStateFlow
|
import kotlinx.coroutines.flow.MutableStateFlow
|
||||||
import kotlinx.coroutines.flow.asStateFlow
|
import kotlinx.coroutines.flow.asStateFlow
|
||||||
|
@ -73,8 +69,9 @@ import java.net.HttpURLConnection
|
||||||
import java.net.URL
|
import java.net.URL
|
||||||
|
|
||||||
private const val TAG = "AGModelManagerViewModel"
|
private const val TAG = "AGModelManagerViewModel"
|
||||||
private const val HG_COMMUNITY = "jinjingforevercommunity"
|
|
||||||
private const val TEXT_INPUT_HISTORY_MAX_SIZE = 50
|
private const val TEXT_INPUT_HISTORY_MAX_SIZE = 50
|
||||||
|
private const val MODEL_ALLOWLIST_URL =
|
||||||
|
"https://raw.githubusercontent.com/jinjingforever/kokoro-codelab-jingjin/refs/heads/main/model_allowlist.json"
|
||||||
|
|
||||||
data class ModelInitializationStatus(
|
data class ModelInitializationStatus(
|
||||||
val status: ModelInitializationStatusType, var error: String = ""
|
val status: ModelInitializationStatusType, var error: String = ""
|
||||||
|
@ -122,6 +119,14 @@ data class ModelManagerUiState(
|
||||||
*/
|
*/
|
||||||
val loadingHfModels: Boolean = false,
|
val loadingHfModels: Boolean = false,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Whether the app is loading and processing the model allowlist.
|
||||||
|
*/
|
||||||
|
val loadingModelAllowlist: Boolean = true,
|
||||||
|
|
||||||
|
/** The error message when loading the model allowlist. */
|
||||||
|
val loadingModelAllowlistError: String = "",
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The currently selected model.
|
* The currently selected model.
|
||||||
*/
|
*/
|
||||||
|
@ -153,7 +158,7 @@ open class ModelManagerViewModel(
|
||||||
private val externalFilesDir = context.getExternalFilesDir(null)
|
private val externalFilesDir = context.getExternalFilesDir(null)
|
||||||
private val inProgressWorkInfos: List<AGWorkInfo> =
|
private val inProgressWorkInfos: List<AGWorkInfo> =
|
||||||
downloadRepository.getEnqueuedOrRunningWorkInfos()
|
downloadRepository.getEnqueuedOrRunningWorkInfos()
|
||||||
protected val _uiState = MutableStateFlow(createUiState())
|
protected val _uiState = MutableStateFlow(createEmptyUiState())
|
||||||
val uiState = _uiState.asStateFlow()
|
val uiState = _uiState.asStateFlow()
|
||||||
|
|
||||||
val authService = AuthorizationService(context)
|
val authService = AuthorizationService(context)
|
||||||
|
@ -162,44 +167,7 @@ open class ModelManagerViewModel(
|
||||||
var pagerScrollState: MutableStateFlow<PagerScrollState> = MutableStateFlow(PagerScrollState())
|
var pagerScrollState: MutableStateFlow<PagerScrollState> = MutableStateFlow(PagerScrollState())
|
||||||
|
|
||||||
init {
|
init {
|
||||||
Log.d(TAG, "In-progress worker infos: $inProgressWorkInfos")
|
loadModelAllowlist()
|
||||||
|
|
||||||
// Iterate through the inProgressWorkInfos and retrieve the corresponding modes.
|
|
||||||
// Those models are the ones that have not finished downloading.
|
|
||||||
val models: MutableList<Model> = mutableListOf()
|
|
||||||
for (info in inProgressWorkInfos) {
|
|
||||||
getModelByName(info.modelName)?.let { model ->
|
|
||||||
models.add(model)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cancel all pending downloads for the retrieved models.
|
|
||||||
downloadRepository.cancelAll(models) {
|
|
||||||
Log.d(TAG, "All pending work is cancelled")
|
|
||||||
|
|
||||||
viewModelScope.launch(Dispatchers.IO) {
|
|
||||||
// Load models from hg community.
|
|
||||||
loadHfModels()
|
|
||||||
Log.d(TAG, "Done loading HF models")
|
|
||||||
|
|
||||||
// Kick off downloads for these models .
|
|
||||||
withContext(Dispatchers.Main) {
|
|
||||||
val tokenStatusAndData = getTokenStatusAndData()
|
|
||||||
for (info in inProgressWorkInfos) {
|
|
||||||
val model: Model? = getModelByName(info.modelName)
|
|
||||||
if (model != null) {
|
|
||||||
if (tokenStatusAndData.status == TokenStatus.NOT_EXPIRED && tokenStatusAndData.data != null) {
|
|
||||||
model.accessToken = tokenStatusAndData.data.accessToken
|
|
||||||
}
|
|
||||||
Log.d(TAG, "Sending a new download request for '${model.name}'")
|
|
||||||
downloadRepository.downloadModel(
|
|
||||||
model, onStatusUpdated = this@ModelManagerViewModel::setDownloadStatus
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun onCleared() {
|
override fun onCleared() {
|
||||||
|
@ -231,12 +199,10 @@ open class ModelManagerViewModel(
|
||||||
}
|
}
|
||||||
|
|
||||||
fun deleteModel(task: Task, model: Model) {
|
fun deleteModel(task: Task, model: Model) {
|
||||||
deleteFileFromExternalFilesDir(model.downloadFileName)
|
if (model.imported) {
|
||||||
for (file in model.extraDataFiles) {
|
deleteFileFromExternalFilesDir(model.downloadFileName)
|
||||||
deleteFileFromExternalFilesDir(file.downloadFileName)
|
} else {
|
||||||
}
|
deleteDirFromExternalFilesDir(model.normalizedName)
|
||||||
if (model.isZip && model.unzipDir.isNotEmpty()) {
|
|
||||||
deleteDirFromExternalFilesDir(model.unzipDir)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update model download status to NotDownloaded.
|
// Update model download status to NotDownloaded.
|
||||||
|
@ -340,7 +306,7 @@ open class ModelManagerViewModel(
|
||||||
onDone = onDone,
|
onDone = onDone,
|
||||||
)
|
)
|
||||||
|
|
||||||
TaskType.LLM_SINGLE_TURN -> LlmChatModelHelper.initialize(
|
TaskType.LLM_USECASES -> LlmChatModelHelper.initialize(
|
||||||
context = context,
|
context = context,
|
||||||
model = model,
|
model = model,
|
||||||
onDone = onDone,
|
onDone = onDone,
|
||||||
|
@ -364,7 +330,7 @@ open class ModelManagerViewModel(
|
||||||
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model)
|
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model)
|
||||||
TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.cleanUp(model = model)
|
TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.cleanUp(model = model)
|
||||||
TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model)
|
TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model)
|
||||||
TaskType.LLM_SINGLE_TURN -> LlmChatModelHelper.cleanUp(model = model)
|
TaskType.LLM_USECASES -> LlmChatModelHelper.cleanUp(model = model)
|
||||||
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.cleanUp(model = model)
|
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.cleanUp(model = model)
|
||||||
TaskType.TEST_TASK_1 -> {}
|
TaskType.TEST_TASK_1 -> {}
|
||||||
TaskType.TEST_TASK_2 -> {}
|
TaskType.TEST_TASK_2 -> {}
|
||||||
|
@ -444,33 +410,40 @@ open class ModelManagerViewModel(
|
||||||
}
|
}
|
||||||
|
|
||||||
fun getModelUrlResponse(model: Model, accessToken: String? = null): Int {
|
fun getModelUrlResponse(model: Model, accessToken: String? = null): Int {
|
||||||
val url = URL(model.url)
|
try {
|
||||||
val connection = url.openConnection() as HttpURLConnection
|
val url = URL(model.url)
|
||||||
if (accessToken != null) {
|
val connection = url.openConnection() as HttpURLConnection
|
||||||
connection.setRequestProperty(
|
if (accessToken != null) {
|
||||||
"Authorization", "Bearer $accessToken"
|
connection.setRequestProperty(
|
||||||
)
|
"Authorization", "Bearer $accessToken"
|
||||||
}
|
)
|
||||||
connection.connect()
|
}
|
||||||
|
connection.connect()
|
||||||
|
|
||||||
// Report the result.
|
// Report the result.
|
||||||
return connection.responseCode
|
return connection.responseCode
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.e(TAG, "$e")
|
||||||
|
return -1
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun addImportedLlmModel(info: ImportedModelInfo) {
|
fun addImportedLlmModel(info: ImportedModelInfo) {
|
||||||
Log.d(TAG, "adding imported llm model: $info")
|
Log.d(TAG, "adding imported llm model: $info")
|
||||||
|
|
||||||
// Remove duplicated imported model if existed.
|
|
||||||
val task = TASK_LLM_CHAT
|
|
||||||
val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported }
|
|
||||||
if (modelIndex >= 0) {
|
|
||||||
Log.d(TAG, "duplicated imported model found in task. Removing it first")
|
|
||||||
task.models.removeAt(modelIndex)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create model.
|
// Create model.
|
||||||
val model = createModelFromImportedModelInfo(info = info, task = task)
|
val model = createModelFromImportedModelInfo(info = info)
|
||||||
task.models.add(model)
|
|
||||||
|
// Remove duplicated imported model if existed.
|
||||||
|
for (task in listOf(TASK_LLM_CHAT, TASK_LLM_USECASES)) {
|
||||||
|
val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported }
|
||||||
|
if (modelIndex >= 0) {
|
||||||
|
Log.d(TAG, "duplicated imported model found in task. Removing it first")
|
||||||
|
task.models.removeAt(modelIndex)
|
||||||
|
}
|
||||||
|
task.models.add(model)
|
||||||
|
task.updateTrigger.value = System.currentTimeMillis()
|
||||||
|
}
|
||||||
|
|
||||||
// Add initial status and states.
|
// Add initial status and states.
|
||||||
val modelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap()
|
val modelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap()
|
||||||
|
@ -491,10 +464,6 @@ open class ModelManagerViewModel(
|
||||||
modelInitializationStatus = modelInstances
|
modelInitializationStatus = modelInstances
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
task.updateTrigger.value = System.currentTimeMillis()
|
|
||||||
// Also need to update single turn task.
|
|
||||||
TASK_LLM_SINGLE_TURN.updateTrigger.value = System.currentTimeMillis()
|
|
||||||
|
|
||||||
|
|
||||||
// Add to preference storage.
|
// Add to preference storage.
|
||||||
val importedModels = dataStoreRepository.readImportedModels().toMutableList()
|
val importedModels = dataStoreRepository.readImportedModels().toMutableList()
|
||||||
|
@ -623,10 +592,110 @@ open class ModelManagerViewModel(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun processPendingDownloads() {
|
||||||
|
Log.d(TAG, "In-progress worker infos: $inProgressWorkInfos")
|
||||||
|
|
||||||
|
// Iterate through the inProgressWorkInfos and retrieve the corresponding modes.
|
||||||
|
// Those models are the ones that have not finished downloading.
|
||||||
|
val models: MutableList<Model> = mutableListOf()
|
||||||
|
for (info in inProgressWorkInfos) {
|
||||||
|
getModelByName(info.modelName)?.let { model ->
|
||||||
|
models.add(model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cancel all pending downloads for the retrieved models.
|
||||||
|
downloadRepository.cancelAll(models) {
|
||||||
|
Log.d(TAG, "All pending work is cancelled")
|
||||||
|
|
||||||
|
viewModelScope.launch(Dispatchers.IO) {
|
||||||
|
// Kick off downloads for these models .
|
||||||
|
withContext(Dispatchers.Main) {
|
||||||
|
val tokenStatusAndData = getTokenStatusAndData()
|
||||||
|
for (info in inProgressWorkInfos) {
|
||||||
|
val model: Model? = getModelByName(info.modelName)
|
||||||
|
if (model != null) {
|
||||||
|
if (tokenStatusAndData.status == TokenStatus.NOT_EXPIRED && tokenStatusAndData.data != null) {
|
||||||
|
model.accessToken = tokenStatusAndData.data.accessToken
|
||||||
|
}
|
||||||
|
Log.d(TAG, "Sending a new download request for '${model.name}'")
|
||||||
|
downloadRepository.downloadModel(
|
||||||
|
model, onStatusUpdated = this@ModelManagerViewModel::setDownloadStatus
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun loadModelAllowlist() {
|
||||||
|
_uiState.update {
|
||||||
|
uiState.value.copy(
|
||||||
|
loadingModelAllowlist = true,
|
||||||
|
loadingModelAllowlistError = ""
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
viewModelScope.launch(Dispatchers.IO) {
|
||||||
|
try {
|
||||||
|
// Load model allowlist json.
|
||||||
|
val modelAllowlist: ModelAllowlist? =
|
||||||
|
getJsonResponse<ModelAllowlist>(url = MODEL_ALLOWLIST_URL)
|
||||||
|
|
||||||
|
if (modelAllowlist == null) {
|
||||||
|
_uiState.update { uiState.value.copy(loadingModelAllowlistError = "Failed to load model list") }
|
||||||
|
return@launch
|
||||||
|
}
|
||||||
|
|
||||||
|
Log.d(TAG, "Allowlist: $modelAllowlist")
|
||||||
|
|
||||||
|
// Convert models in the allowlist.
|
||||||
|
for (allowedModel in modelAllowlist.models) {
|
||||||
|
if (allowedModel.disabled == true) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
val model = allowedModel.toModel()
|
||||||
|
if (allowedModel.taskTypes.contains(TASK_LLM_CHAT.type.id)) {
|
||||||
|
TASK_LLM_CHAT.models.add(model)
|
||||||
|
}
|
||||||
|
if (allowedModel.taskTypes.contains(TASK_LLM_USECASES.type.id)) {
|
||||||
|
TASK_LLM_USECASES.models.add(model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pre-process all tasks.
|
||||||
|
processTasks()
|
||||||
|
|
||||||
|
// Update UI state.
|
||||||
|
val newUiState = createUiState()
|
||||||
|
_uiState.update {
|
||||||
|
newUiState.copy(
|
||||||
|
loadingModelAllowlist = false,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process pending downloads.
|
||||||
|
processPendingDownloads()
|
||||||
|
} catch (e: Exception) {
|
||||||
|
e.printStackTrace()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private fun isModelPartiallyDownloaded(model: Model): Boolean {
|
private fun isModelPartiallyDownloaded(model: Model): Boolean {
|
||||||
return inProgressWorkInfos.find { it.modelName == model.name } != null
|
return inProgressWorkInfos.find { it.modelName == model.name } != null
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun createEmptyUiState(): ModelManagerUiState {
|
||||||
|
return ModelManagerUiState(
|
||||||
|
tasks = listOf(),
|
||||||
|
modelDownloadStatus = mapOf(),
|
||||||
|
modelInitializationStatus = mapOf(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
private fun createUiState(): ModelManagerUiState {
|
private fun createUiState(): ModelManagerUiState {
|
||||||
val modelDownloadStatus: MutableMap<String, ModelDownloadStatus> = mutableMapOf()
|
val modelDownloadStatus: MutableMap<String, ModelDownloadStatus> = mutableMapOf()
|
||||||
val modelInstances: MutableMap<String, ModelInitializationStatus> = mutableMapOf()
|
val modelInstances: MutableMap<String, ModelInitializationStatus> = mutableMapOf()
|
||||||
|
@ -643,11 +712,11 @@ open class ModelManagerViewModel(
|
||||||
Log.d(TAG, "stored imported model: $importedModel")
|
Log.d(TAG, "stored imported model: $importedModel")
|
||||||
|
|
||||||
// Create model.
|
// Create model.
|
||||||
val model = createModelFromImportedModelInfo(info = importedModel, task = TASK_LLM_CHAT)
|
val model = createModelFromImportedModelInfo(info = importedModel)
|
||||||
|
|
||||||
// Add to task.
|
// Add to task.
|
||||||
val task = TASK_LLM_CHAT
|
TASK_LLM_CHAT.models.add(model)
|
||||||
task.models.add(model)
|
TASK_LLM_USECASES.models.add(model)
|
||||||
|
|
||||||
// Update status.
|
// Update status.
|
||||||
modelDownloadStatus[model.name] = ModelDownloadStatus(
|
modelDownloadStatus[model.name] = ModelDownloadStatus(
|
||||||
|
@ -660,6 +729,7 @@ open class ModelManagerViewModel(
|
||||||
val textInputHistory = dataStoreRepository.readTextInputHistory()
|
val textInputHistory = dataStoreRepository.readTextInputHistory()
|
||||||
Log.d(TAG, "text input history: $textInputHistory")
|
Log.d(TAG, "text input history: $textInputHistory")
|
||||||
|
|
||||||
|
Log.d(TAG, "model download status: $modelDownloadStatus")
|
||||||
return ModelManagerUiState(
|
return ModelManagerUiState(
|
||||||
tasks = TASKS,
|
tasks = TASKS,
|
||||||
modelDownloadStatus = modelDownloadStatus,
|
modelDownloadStatus = modelDownloadStatus,
|
||||||
|
@ -668,7 +738,7 @@ open class ModelManagerViewModel(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun createModelFromImportedModelInfo(info: ImportedModelInfo, task: Task): Model {
|
private fun createModelFromImportedModelInfo(info: ImportedModelInfo): Model {
|
||||||
val accelerators: List<Accelerator> = (convertValueToTargetType(
|
val accelerators: List<Accelerator> = (convertValueToTargetType(
|
||||||
info.defaultValues[ConfigKey.COMPATIBLE_ACCELERATORS.label]!!, ValueType.STRING
|
info.defaultValues[ConfigKey.COMPATIBLE_ACCELERATORS.label]!!, ValueType.STRING
|
||||||
) as String).split(",").mapNotNull { acceleratorLabel ->
|
) as String).split(",").mapNotNull { acceleratorLabel ->
|
||||||
|
@ -733,74 +803,6 @@ open class ModelManagerViewModel(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
suspend fun loadHfModels() {
|
|
||||||
// Update loading state shown in ui.
|
|
||||||
_uiState.update {
|
|
||||||
uiState.value.copy(
|
|
||||||
loadingHfModels = true,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
val modelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap()
|
|
||||||
val modelInstances = uiState.value.modelInitializationStatus.toMutableMap()
|
|
||||||
try {
|
|
||||||
// Load model summaries.
|
|
||||||
val modelSummaries =
|
|
||||||
getJsonResponse<List<HfModelSummary>>(url = "https://huggingface.co/api/models?search=$HG_COMMUNITY")
|
|
||||||
Log.d(TAG, "HF model summaries: $modelSummaries")
|
|
||||||
|
|
||||||
// Load individual models in parallel.
|
|
||||||
if (modelSummaries != null) {
|
|
||||||
coroutineScope {
|
|
||||||
val hfModels = modelSummaries.map { summary ->
|
|
||||||
async {
|
|
||||||
val details =
|
|
||||||
getJsonResponse<HfModelDetails>(url = "https://huggingface.co/api/models/${summary.modelId}")
|
|
||||||
if (details != null && details.siblings.find { it.rfilename == "app.json" } != null) {
|
|
||||||
val hfModel =
|
|
||||||
getJsonResponse<HfModel>(url = "https://huggingface.co/${summary.modelId}/resolve/main/app.json")
|
|
||||||
if (hfModel != null) {
|
|
||||||
hfModel.id = details.id
|
|
||||||
}
|
|
||||||
return@async hfModel
|
|
||||||
}
|
|
||||||
return@async null
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process loaded app.json
|
|
||||||
for (hfModel in hfModels.awaitAll()) {
|
|
||||||
if (hfModel != null) {
|
|
||||||
Log.d(TAG, "HF model: $hfModel")
|
|
||||||
val task = TASKS.find { it.type.label == hfModel.task }
|
|
||||||
val model = hfModel.toModel()
|
|
||||||
if (task != null && task.models.find { it.hfModelId == model.hfModelId } == null) {
|
|
||||||
model.preProcess()
|
|
||||||
Log.d(TAG, "AG model: $model")
|
|
||||||
task.models.add(model)
|
|
||||||
|
|
||||||
// Add initial status and states.
|
|
||||||
modelDownloadStatus[model.name] = getModelDownloadStatus(model = model)
|
|
||||||
modelInstances[model.name] =
|
|
||||||
ModelInitializationStatus(status = ModelInitializationStatusType.NOT_INITIALIZED)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_uiState.update {
|
|
||||||
uiState.value.copy(
|
|
||||||
loadingHfModels = false,
|
|
||||||
modelDownloadStatus = modelDownloadStatus,
|
|
||||||
modelInitializationStatus = modelInstances
|
|
||||||
)
|
|
||||||
}
|
|
||||||
} catch (e: Exception) {
|
|
||||||
e.printStackTrace()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private inline fun <reified T> getJsonResponse(url: String): T? {
|
private inline fun <reified T> getJsonResponse(url: String): T? {
|
||||||
try {
|
try {
|
||||||
val connection = URL(url).openConnection() as HttpURLConnection
|
val connection = URL(url).openConnection() as HttpURLConnection
|
||||||
|
@ -817,9 +819,10 @@ open class ModelManagerViewModel(
|
||||||
val jsonObj = json.decodeFromString<T>(response)
|
val jsonObj = json.decodeFromString<T>(response)
|
||||||
return jsonObj
|
return jsonObj
|
||||||
} else {
|
} else {
|
||||||
println("HTTP error: $responseCode")
|
Log.e(TAG, "HTTP error: $responseCode")
|
||||||
}
|
}
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
|
Log.e(TAG, "Error when getting json response: ${e.message}")
|
||||||
e.printStackTrace()
|
e.printStackTrace()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -859,11 +862,18 @@ open class ModelManagerViewModel(
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun isModelDownloaded(model: Model): Boolean {
|
private fun isModelDownloaded(model: Model): Boolean {
|
||||||
val downloadedFileExists =
|
val downloadedFileExists = model.downloadFileName.isNotEmpty() && isFileInExternalFilesDir(
|
||||||
model.downloadFileName.isNotEmpty() && isFileInExternalFilesDir(model.downloadFileName)
|
listOf(
|
||||||
|
model.normalizedName, model.version, model.downloadFileName
|
||||||
|
).joinToString(File.separator)
|
||||||
|
)
|
||||||
|
|
||||||
val unzippedDirectoryExists =
|
val unzippedDirectoryExists =
|
||||||
model.isZip && model.unzipDir.isNotEmpty() && isFileInExternalFilesDir(model.unzipDir)
|
model.isZip && model.unzipDir.isNotEmpty() && isFileInExternalFilesDir(
|
||||||
|
listOf(
|
||||||
|
model.normalizedName, model.version, model.unzipDir
|
||||||
|
).joinToString(File.separator)
|
||||||
|
)
|
||||||
|
|
||||||
// Will also return true if model is partially downloaded.
|
// Will also return true if model is partially downloaded.
|
||||||
return downloadedFileExists || unzippedDirectoryExists
|
return downloadedFileExists || unzippedDirectoryExists
|
||||||
|
|
|
@ -46,7 +46,7 @@ import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION
|
import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION
|
||||||
import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION
|
import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
|
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
|
||||||
import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION
|
import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.data.TaskType
|
import com.google.aiedge.gallery.data.TaskType
|
||||||
|
@ -231,7 +231,7 @@ fun GalleryNavHost(
|
||||||
enterTransition = { slideEnter() },
|
enterTransition = { slideEnter() },
|
||||||
exitTransition = { slideExit() },
|
exitTransition = { slideExit() },
|
||||||
) {
|
) {
|
||||||
getModelFromNavigationParam(it, TASK_LLM_SINGLE_TURN)?.let { defaultModel ->
|
getModelFromNavigationParam(it, TASK_LLM_USECASES)?.let { defaultModel ->
|
||||||
modelManagerViewModel.selectModel(defaultModel)
|
modelManagerViewModel.selectModel(defaultModel)
|
||||||
|
|
||||||
LlmSingleTurnScreen(
|
LlmSingleTurnScreen(
|
||||||
|
@ -271,7 +271,7 @@ fun navigateToTaskScreen(
|
||||||
TaskType.TEXT_CLASSIFICATION -> navController.navigate("${TextClassificationDestination.route}/${modelName}")
|
TaskType.TEXT_CLASSIFICATION -> navController.navigate("${TextClassificationDestination.route}/${modelName}")
|
||||||
TaskType.IMAGE_CLASSIFICATION -> navController.navigate("${ImageClassificationDestination.route}/${modelName}")
|
TaskType.IMAGE_CLASSIFICATION -> navController.navigate("${ImageClassificationDestination.route}/${modelName}")
|
||||||
TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
|
TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
|
||||||
TaskType.LLM_SINGLE_TURN -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
|
TaskType.LLM_USECASES -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
|
||||||
TaskType.IMAGE_GENERATION -> navController.navigate("${ImageGenerationDestination.route}/${modelName}")
|
TaskType.IMAGE_GENERATION -> navController.navigate("${ImageGenerationDestination.route}/${modelName}")
|
||||||
TaskType.TEST_TASK_1 -> {}
|
TaskType.TEST_TASK_1 -> {}
|
||||||
TaskType.TEST_TASK_2 -> {}
|
TaskType.TEST_TASK_2 -> {}
|
||||||
|
|
|
@ -24,6 +24,7 @@ import androidx.work.WorkerParameters
|
||||||
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_ACCESS_TOKEN
|
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_ACCESS_TOKEN
|
||||||
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_ERROR_MESSAGE
|
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_ERROR_MESSAGE
|
||||||
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_FILE_NAME
|
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_FILE_NAME
|
||||||
|
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_MODEL_DIR
|
||||||
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_RATE
|
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_RATE
|
||||||
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_RECEIVED_BYTES
|
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_RECEIVED_BYTES
|
||||||
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_REMAINING_MS
|
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_REMAINING_MS
|
||||||
|
@ -34,6 +35,7 @@ import com.google.aiedge.gallery.data.KEY_MODEL_START_UNZIPPING
|
||||||
import com.google.aiedge.gallery.data.KEY_MODEL_TOTAL_BYTES
|
import com.google.aiedge.gallery.data.KEY_MODEL_TOTAL_BYTES
|
||||||
import com.google.aiedge.gallery.data.KEY_MODEL_UNZIPPED_DIR
|
import com.google.aiedge.gallery.data.KEY_MODEL_UNZIPPED_DIR
|
||||||
import com.google.aiedge.gallery.data.KEY_MODEL_URL
|
import com.google.aiedge.gallery.data.KEY_MODEL_URL
|
||||||
|
import com.google.aiedge.gallery.data.KEY_MODEL_VERSION
|
||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
||||||
import kotlinx.coroutines.withContext
|
import kotlinx.coroutines.withContext
|
||||||
import java.io.BufferedInputStream
|
import java.io.BufferedInputStream
|
||||||
|
@ -48,7 +50,10 @@ import java.util.zip.ZipInputStream
|
||||||
|
|
||||||
private const val TAG = "AGDownloadWorker"
|
private const val TAG = "AGDownloadWorker"
|
||||||
|
|
||||||
data class UrlAndFileName(val url: String, val fileName: String)
|
data class UrlAndFileName(
|
||||||
|
val url: String,
|
||||||
|
val fileName: String,
|
||||||
|
)
|
||||||
|
|
||||||
class DownloadWorker(context: Context, params: WorkerParameters) :
|
class DownloadWorker(context: Context, params: WorkerParameters) :
|
||||||
CoroutineWorker(context, params) {
|
CoroutineWorker(context, params) {
|
||||||
|
@ -56,7 +61,9 @@ class DownloadWorker(context: Context, params: WorkerParameters) :
|
||||||
|
|
||||||
override suspend fun doWork(): Result {
|
override suspend fun doWork(): Result {
|
||||||
val fileUrl = inputData.getString(KEY_MODEL_URL)
|
val fileUrl = inputData.getString(KEY_MODEL_URL)
|
||||||
|
val version = inputData.getString(KEY_MODEL_VERSION)!!
|
||||||
val fileName = inputData.getString(KEY_MODEL_DOWNLOAD_FILE_NAME)
|
val fileName = inputData.getString(KEY_MODEL_DOWNLOAD_FILE_NAME)
|
||||||
|
val modelDir = inputData.getString(KEY_MODEL_DOWNLOAD_MODEL_DIR)!!
|
||||||
val isZip = inputData.getBoolean(KEY_MODEL_IS_ZIP, false)
|
val isZip = inputData.getBoolean(KEY_MODEL_IS_ZIP, false)
|
||||||
val unzippedDir = inputData.getString(KEY_MODEL_UNZIPPED_DIR)
|
val unzippedDir = inputData.getString(KEY_MODEL_UNZIPPED_DIR)
|
||||||
val extraDataFileUrls = inputData.getString(KEY_MODEL_EXTRA_DATA_URLS)?.split(",") ?: listOf()
|
val extraDataFileUrls = inputData.getString(KEY_MODEL_EXTRA_DATA_URLS)?.split(",") ?: listOf()
|
||||||
|
@ -96,8 +103,20 @@ class DownloadWorker(context: Context, params: WorkerParameters) :
|
||||||
connection.setRequestProperty("Authorization", "Bearer $accessToken")
|
connection.setRequestProperty("Authorization", "Bearer $accessToken")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Prepare output file's dir.
|
||||||
|
val outputDir = File(
|
||||||
|
applicationContext.getExternalFilesDir(null),
|
||||||
|
listOf(modelDir, version).joinToString(separator = File.separator)
|
||||||
|
)
|
||||||
|
if (!outputDir.exists()) {
|
||||||
|
outputDir.mkdirs()
|
||||||
|
}
|
||||||
|
|
||||||
// Read the file and see if it is partially downloaded.
|
// Read the file and see if it is partially downloaded.
|
||||||
val outputFile = File(applicationContext.getExternalFilesDir(null), file.fileName)
|
val outputFile = File(
|
||||||
|
applicationContext.getExternalFilesDir(null),
|
||||||
|
listOf(modelDir, version, file.fileName).joinToString(separator = File.separator)
|
||||||
|
)
|
||||||
val outputFileBytes = outputFile.length()
|
val outputFileBytes = outputFile.length()
|
||||||
if (outputFileBytes > 0) {
|
if (outputFileBytes > 0) {
|
||||||
Log.d(
|
Log.d(
|
||||||
|
@ -192,14 +211,19 @@ class DownloadWorker(context: Context, params: WorkerParameters) :
|
||||||
setProgress(Data.Builder().putBoolean(KEY_MODEL_START_UNZIPPING, true).build())
|
setProgress(Data.Builder().putBoolean(KEY_MODEL_START_UNZIPPING, true).build())
|
||||||
|
|
||||||
// Prepare target dir.
|
// Prepare target dir.
|
||||||
val destDir = File("${externalFilesDir}${File.separator}${unzippedDir}")
|
val destDir =
|
||||||
|
File(
|
||||||
|
externalFilesDir,
|
||||||
|
listOf(modelDir, version, unzippedDir).joinToString(File.separator)
|
||||||
|
)
|
||||||
if (!destDir.exists()) {
|
if (!destDir.exists()) {
|
||||||
destDir.mkdirs()
|
destDir.mkdirs()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unzip.
|
// Unzip.
|
||||||
val unzipBuffer = ByteArray(4096)
|
val unzipBuffer = ByteArray(4096)
|
||||||
val zipFilePath = "${externalFilesDir}${File.separator}${fileName}"
|
val zipFilePath =
|
||||||
|
"${externalFilesDir}${File.separator}$modelDir${File.separator}$version${File.separator}${fileName}"
|
||||||
val zipIn = ZipInputStream(BufferedInputStream(FileInputStream(zipFilePath)))
|
val zipIn = ZipInputStream(BufferedInputStream(FileInputStream(zipFilePath)))
|
||||||
var zipEntry: ZipEntry? = zipIn.nextEntry
|
var zipEntry: ZipEntry? = zipIn.nextEntry
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ gson = "2.12.1"
|
||||||
lifecycleProcess = "2.8.7"
|
lifecycleProcess = "2.8.7"
|
||||||
#noinspection GradleDependency
|
#noinspection GradleDependency
|
||||||
mediapipeTasksText = "0.10.21"
|
mediapipeTasksText = "0.10.21"
|
||||||
mediapipeTasksGenai = "0.10.22"
|
mediapipeTasksGenai = "0.10.24"
|
||||||
mediapipeTasksImageGenerator = "0.10.21"
|
mediapipeTasksImageGenerator = "0.10.21"
|
||||||
commonmark = "1.0.0-alpha02"
|
commonmark = "1.0.0-alpha02"
|
||||||
richtext = "1.0.0-alpha02"
|
richtext = "1.0.0-alpha02"
|
||||||
|
|
|
@ -30,6 +30,7 @@ pluginManagement {
|
||||||
dependencyResolutionManagement {
|
dependencyResolutionManagement {
|
||||||
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
|
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
|
||||||
repositories {
|
repositories {
|
||||||
|
mavenLocal()
|
||||||
google()
|
google()
|
||||||
mavenCentral()
|
mavenCentral()
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue