mirror of
https://github.com/google-ai-edge/gallery.git
synced 2025-07-15 10:46:38 -04:00
Various bug fixes.
This commit is contained in:
parent
0c49efc054
commit
37a58d1a41
35 changed files with 1517 additions and 995 deletions
|
@ -27,10 +27,10 @@ android {
|
||||||
|
|
||||||
defaultConfig {
|
defaultConfig {
|
||||||
applicationId = "com.google.aiedge.gallery"
|
applicationId = "com.google.aiedge.gallery"
|
||||||
minSdk = 24
|
minSdk = 26
|
||||||
targetSdk = 35
|
targetSdk = 35
|
||||||
versionCode = 1
|
versionCode = 1
|
||||||
versionName = "20250428"
|
versionName = "0.9.0"
|
||||||
|
|
||||||
// Needed for HuggingFace auth workflows.
|
// Needed for HuggingFace auth workflows.
|
||||||
manifestPlaceholders["appAuthRedirectScheme"] = "com.google.aiedge.gallery.oauth"
|
manifestPlaceholders["appAuthRedirectScheme"] = "com.google.aiedge.gallery.oauth"
|
||||||
|
|
|
@ -41,7 +41,9 @@
|
||||||
android:name=".MainActivity"
|
android:name=".MainActivity"
|
||||||
android:exported="true"
|
android:exported="true"
|
||||||
android:theme="@style/Theme.Gallery.SplashScreen"
|
android:theme="@style/Theme.Gallery.SplashScreen"
|
||||||
android:windowSoftInputMode="adjustResize">
|
android:screenOrientation="portrait"
|
||||||
|
android:windowSoftInputMode="adjustResize"
|
||||||
|
tools:ignore="DiscouragedApi,LockedOrientationActivity">
|
||||||
<!-- This is for putting the app into launcher -->
|
<!-- This is for putting the app into launcher -->
|
||||||
<intent-filter>
|
<intent-filter>
|
||||||
<action android:name="android.intent.action.MAIN" />
|
<action android:name="android.intent.action.MAIN" />
|
||||||
|
|
|
@ -68,7 +68,6 @@ fun GalleryTopAppBar(
|
||||||
leftAction: AppBarAction? = null,
|
leftAction: AppBarAction? = null,
|
||||||
rightAction: AppBarAction? = null,
|
rightAction: AppBarAction? = null,
|
||||||
scrollBehavior: TopAppBarScrollBehavior? = null,
|
scrollBehavior: TopAppBarScrollBehavior? = null,
|
||||||
loadingHfModels: Boolean = false,
|
|
||||||
subtitle: String = "",
|
subtitle: String = "",
|
||||||
) {
|
) {
|
||||||
CenterAlignedTopAppBar(
|
CenterAlignedTopAppBar(
|
||||||
|
@ -151,28 +150,6 @@ fun GalleryTopAppBar(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Click an icon to open "download manager".
|
|
||||||
AppBarActionType.DOWNLOAD_MANAGER -> {
|
|
||||||
if (loadingHfModels) {
|
|
||||||
CircularProgressIndicator(
|
|
||||||
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
|
|
||||||
strokeWidth = 3.dp,
|
|
||||||
modifier = Modifier
|
|
||||||
.padding(end = 12.dp)
|
|
||||||
.size(20.dp)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
// else {
|
|
||||||
// IconButton(onClick = rightAction.actionFn) {
|
|
||||||
// Icon(
|
|
||||||
// imageVector = Deployed_code,
|
|
||||||
// contentDescription = "",
|
|
||||||
// tint = MaterialTheme.colorScheme.primary
|
|
||||||
// )
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
}
|
|
||||||
|
|
||||||
AppBarActionType.MODEL_SELECTOR -> {
|
AppBarActionType.MODEL_SELECTOR -> {
|
||||||
Text("ms")
|
Text("ms")
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,7 +37,7 @@ import javax.crypto.SecretKey
|
||||||
data class AccessTokenData(
|
data class AccessTokenData(
|
||||||
val accessToken: String,
|
val accessToken: String,
|
||||||
val refreshToken: String,
|
val refreshToken: String,
|
||||||
val expiresAtSeconds: Long
|
val expiresAtMs: Long
|
||||||
)
|
)
|
||||||
|
|
||||||
interface DataStoreRepository {
|
interface DataStoreRepository {
|
||||||
|
@ -46,6 +46,7 @@ interface DataStoreRepository {
|
||||||
fun saveThemeOverride(theme: String)
|
fun saveThemeOverride(theme: String)
|
||||||
fun readThemeOverride(): String
|
fun readThemeOverride(): String
|
||||||
fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long)
|
fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long)
|
||||||
|
fun clearAccessTokenData()
|
||||||
fun readAccessTokenData(): AccessTokenData?
|
fun readAccessTokenData(): AccessTokenData?
|
||||||
fun saveImportedModels(importedModels: List<ImportedModelInfo>)
|
fun saveImportedModels(importedModels: List<ImportedModelInfo>)
|
||||||
fun readImportedModels(): List<ImportedModelInfo>
|
fun readImportedModels(): List<ImportedModelInfo>
|
||||||
|
@ -135,6 +136,18 @@ class DefaultDataStoreRepository(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun clearAccessTokenData() {
|
||||||
|
return runBlocking {
|
||||||
|
dataStore.edit { preferences ->
|
||||||
|
preferences.remove(PreferencesKeys.ENCRYPTED_ACCESS_TOKEN)
|
||||||
|
preferences.remove(PreferencesKeys.ACCESS_TOKEN_IV)
|
||||||
|
preferences.remove(PreferencesKeys.ENCRYPTED_REFRESH_TOKEN)
|
||||||
|
preferences.remove(PreferencesKeys.REFRESH_TOKEN_IV)
|
||||||
|
preferences.remove(PreferencesKeys.ACCESS_TOKEN_EXPIRES_AT)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
override fun readAccessTokenData(): AccessTokenData? {
|
override fun readAccessTokenData(): AccessTokenData? {
|
||||||
return runBlocking {
|
return runBlocking {
|
||||||
val preferences = dataStore.data.first()
|
val preferences = dataStore.data.first()
|
||||||
|
|
|
@ -43,7 +43,7 @@ data class AllowedModel(
|
||||||
|
|
||||||
// Config.
|
// Config.
|
||||||
val isLlmModel =
|
val isLlmModel =
|
||||||
taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_USECASES.type.id)
|
taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_PROMPT_LAB.type.id)
|
||||||
var configs: List<Config> = listOf()
|
var configs: List<Config> = listOf()
|
||||||
if (isLlmModel) {
|
if (isLlmModel) {
|
||||||
var defaultTopK: Int = DEFAULT_TOPK
|
var defaultTopK: Int = DEFAULT_TOPK
|
||||||
|
|
|
@ -32,9 +32,9 @@ enum class TaskType(val label: String, val id: String) {
|
||||||
TEXT_CLASSIFICATION(label = "Text Classification", id = "text_classification"),
|
TEXT_CLASSIFICATION(label = "Text Classification", id = "text_classification"),
|
||||||
IMAGE_CLASSIFICATION(label = "Image Classification", id = "image_classification"),
|
IMAGE_CLASSIFICATION(label = "Image Classification", id = "image_classification"),
|
||||||
IMAGE_GENERATION(label = "Image Generation", id = "image_generation"),
|
IMAGE_GENERATION(label = "Image Generation", id = "image_generation"),
|
||||||
LLM_CHAT(label = "LLM Chat", id = "llm_chat"),
|
LLM_CHAT(label = "AI Chat", id = "llm_chat"),
|
||||||
LLM_USECASES(label = "LLM Use Cases", id = "llm_usecases"),
|
LLM_PROMPT_LAB(label = "Prompt Lab", id = "llm_prompt_lab"),
|
||||||
LLM_IMAGE_TO_TEXT(label = "LLM Image to Text", id = "llm_image_to_text"),
|
LLM_ASK_IMAGE(label = "Ask Image", id = "llm_ask_image"),
|
||||||
|
|
||||||
TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
|
TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
|
||||||
TEST_TASK_2(label = "Test task 2", id = "test_task_2")
|
TEST_TASK_2(label = "Test task 2", id = "test_task_2")
|
||||||
|
@ -93,7 +93,6 @@ val TASK_IMAGE_CLASSIFICATION = Task(
|
||||||
val TASK_LLM_CHAT = Task(
|
val TASK_LLM_CHAT = Task(
|
||||||
type = TaskType.LLM_CHAT,
|
type = TaskType.LLM_CHAT,
|
||||||
icon = Icons.Outlined.Forum,
|
icon = Icons.Outlined.Forum,
|
||||||
// models = MODELS_LLM,
|
|
||||||
models = mutableListOf(),
|
models = mutableListOf(),
|
||||||
description = "Chat with on-device large language models",
|
description = "Chat with on-device large language models",
|
||||||
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
||||||
|
@ -101,10 +100,9 @@ val TASK_LLM_CHAT = Task(
|
||||||
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
|
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
|
||||||
)
|
)
|
||||||
|
|
||||||
val TASK_LLM_USECASES = Task(
|
val TASK_LLM_PROMPT_LAB = Task(
|
||||||
type = TaskType.LLM_USECASES,
|
type = TaskType.LLM_PROMPT_LAB,
|
||||||
icon = Icons.Outlined.Widgets,
|
icon = Icons.Outlined.Widgets,
|
||||||
// models = MODELS_LLM,
|
|
||||||
models = mutableListOf(),
|
models = mutableListOf(),
|
||||||
description = "Single turn use cases with on-device large language model",
|
description = "Single turn use cases with on-device large language model",
|
||||||
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
||||||
|
@ -112,10 +110,9 @@ val TASK_LLM_USECASES = Task(
|
||||||
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
|
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
|
||||||
)
|
)
|
||||||
|
|
||||||
val TASK_LLM_IMAGE_TO_TEXT = Task(
|
val TASK_LLM_ASK_IMAGE = Task(
|
||||||
type = TaskType.LLM_IMAGE_TO_TEXT,
|
type = TaskType.LLM_ASK_IMAGE,
|
||||||
icon = Icons.Outlined.Mms,
|
icon = Icons.Outlined.Mms,
|
||||||
// models = MODELS_LLM,
|
|
||||||
models = mutableListOf(),
|
models = mutableListOf(),
|
||||||
description = "Ask questions about images with on-device large language models",
|
description = "Ask questions about images with on-device large language models",
|
||||||
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
||||||
|
@ -135,12 +132,9 @@ val TASK_IMAGE_GENERATION = Task(
|
||||||
|
|
||||||
/** All tasks. */
|
/** All tasks. */
|
||||||
val TASKS: List<Task> = listOf(
|
val TASKS: List<Task> = listOf(
|
||||||
// TASK_TEXT_CLASSIFICATION,
|
TASK_LLM_ASK_IMAGE,
|
||||||
// TASK_IMAGE_CLASSIFICATION,
|
TASK_LLM_PROMPT_LAB,
|
||||||
// TASK_IMAGE_GENERATION,
|
|
||||||
TASK_LLM_USECASES,
|
|
||||||
TASK_LLM_CHAT,
|
TASK_LLM_CHAT,
|
||||||
TASK_LLM_IMAGE_TO_TEXT
|
|
||||||
)
|
)
|
||||||
|
|
||||||
fun getModelByName(name: String): Model? {
|
fun getModelByName(name: String): Model? {
|
||||||
|
|
|
@ -25,7 +25,7 @@ import com.google.aiedge.gallery.GalleryApplication
|
||||||
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationViewModel
|
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationViewModel
|
||||||
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationViewModel
|
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationViewModel
|
||||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatViewModel
|
import com.google.aiedge.gallery.ui.llmchat.LlmChatViewModel
|
||||||
import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextViewModel
|
import com.google.aiedge.gallery.ui.llmchat.LlmAskImageViewModel
|
||||||
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
|
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
import com.google.aiedge.gallery.ui.textclassification.TextClassificationViewModel
|
import com.google.aiedge.gallery.ui.textclassification.TextClassificationViewModel
|
||||||
|
@ -63,9 +63,9 @@ object ViewModelProvider {
|
||||||
LlmSingleTurnViewModel()
|
LlmSingleTurnViewModel()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initializer for LlmImageToTextViewModel.
|
// Initializer for LlmAskImageViewModel.
|
||||||
initializer {
|
initializer {
|
||||||
LlmImageToTextViewModel()
|
LlmAskImageViewModel()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initializer for ImageGenerationViewModel.
|
// Initializer for ImageGenerationViewModel.
|
||||||
|
|
|
@ -26,6 +26,8 @@ import androidx.browser.customtabs.CustomTabsIntent
|
||||||
import androidx.compose.foundation.layout.Column
|
import androidx.compose.foundation.layout.Column
|
||||||
import androidx.compose.foundation.layout.padding
|
import androidx.compose.foundation.layout.padding
|
||||||
import androidx.compose.foundation.layout.wrapContentHeight
|
import androidx.compose.foundation.layout.wrapContentHeight
|
||||||
|
import androidx.compose.foundation.text.BasicText
|
||||||
|
import androidx.compose.foundation.text.TextAutoSize
|
||||||
import androidx.compose.material.icons.Icons
|
import androidx.compose.material.icons.Icons
|
||||||
import androidx.compose.material.icons.automirrored.rounded.ArrowForward
|
import androidx.compose.material.icons.automirrored.rounded.ArrowForward
|
||||||
import androidx.compose.material.icons.rounded.Error
|
import androidx.compose.material.icons.rounded.Error
|
||||||
|
@ -48,6 +50,7 @@ import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.platform.LocalContext
|
import androidx.compose.ui.platform.LocalContext
|
||||||
import androidx.compose.ui.unit.dp
|
import androidx.compose.ui.unit.dp
|
||||||
|
import androidx.compose.ui.unit.sp
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
|
@ -291,11 +294,32 @@ fun DownloadAndTryButton(
|
||||||
modifier = Modifier.padding(end = 4.dp)
|
modifier = Modifier.padding(end = 4.dp)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
val textColor = MaterialTheme.colorScheme.onPrimary
|
||||||
if (checkingToken) {
|
if (checkingToken) {
|
||||||
Text("Checking access...")
|
BasicText(
|
||||||
|
text = "Checking access...",
|
||||||
|
maxLines = 1,
|
||||||
|
color = { textColor },
|
||||||
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
autoSize = TextAutoSize.StepBased(
|
||||||
|
minFontSize = 8.sp,
|
||||||
|
maxFontSize = 14.sp,
|
||||||
|
stepSize = 1.sp
|
||||||
|
)
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
if (needToDownloadFirst) {
|
if (needToDownloadFirst) {
|
||||||
Text("Download & Try it", maxLines = 1)
|
BasicText(
|
||||||
|
text = "Download & Try",
|
||||||
|
maxLines = 1,
|
||||||
|
color = { textColor },
|
||||||
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
autoSize = TextAutoSize.StepBased(
|
||||||
|
minFontSize = 8.sp,
|
||||||
|
maxFontSize = 14.sp,
|
||||||
|
stepSize = 1.sp
|
||||||
|
)
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
Text("Try it", maxLines = 1)
|
Text("Try it", maxLines = 1)
|
||||||
}
|
}
|
||||||
|
|
|
@ -63,6 +63,8 @@ fun ModelPageAppBar(
|
||||||
modelManagerViewModel: ModelManagerViewModel,
|
modelManagerViewModel: ModelManagerViewModel,
|
||||||
onBackClicked: () -> Unit,
|
onBackClicked: () -> Unit,
|
||||||
onModelSelected: (Model) -> Unit,
|
onModelSelected: (Model) -> Unit,
|
||||||
|
inProgress: Boolean,
|
||||||
|
modelPreparing: Boolean,
|
||||||
modifier: Modifier = Modifier,
|
modifier: Modifier = Modifier,
|
||||||
isResettingSession: Boolean = false,
|
isResettingSession: Boolean = false,
|
||||||
onResetSessionClicked: (Model) -> Unit = {},
|
onResetSessionClicked: (Model) -> Unit = {},
|
||||||
|
@ -129,15 +131,16 @@ fun ModelPageAppBar(
|
||||||
val isModelInitializing =
|
val isModelInitializing =
|
||||||
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING
|
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING
|
||||||
if (showConfigButton) {
|
if (showConfigButton) {
|
||||||
|
val enableConfigButton = !isModelInitializing && !inProgress
|
||||||
IconButton(
|
IconButton(
|
||||||
onClick = {
|
onClick = {
|
||||||
showConfigDialog = true
|
showConfigDialog = true
|
||||||
},
|
},
|
||||||
enabled = !isModelInitializing,
|
enabled = enableConfigButton,
|
||||||
modifier = Modifier
|
modifier = Modifier
|
||||||
.scale(0.75f)
|
.scale(0.75f)
|
||||||
.offset(x = configButtonOffset)
|
.offset(x = configButtonOffset)
|
||||||
.alpha(if (isModelInitializing) 0.5f else 1f)
|
.alpha(if (!enableConfigButton) 0.5f else 1f)
|
||||||
) {
|
) {
|
||||||
Icon(
|
Icon(
|
||||||
imageVector = Icons.Rounded.Tune,
|
imageVector = Icons.Rounded.Tune,
|
||||||
|
@ -154,14 +157,15 @@ fun ModelPageAppBar(
|
||||||
modifier = Modifier.size(16.dp)
|
modifier = Modifier.size(16.dp)
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
|
val enableResetButton = !isModelInitializing && !modelPreparing
|
||||||
IconButton(
|
IconButton(
|
||||||
onClick = {
|
onClick = {
|
||||||
onResetSessionClicked(model)
|
onResetSessionClicked(model)
|
||||||
},
|
},
|
||||||
enabled = !isModelInitializing,
|
enabled = enableResetButton,
|
||||||
modifier = Modifier
|
modifier = Modifier
|
||||||
.scale(0.75f)
|
.scale(0.75f)
|
||||||
.alpha(if (isModelInitializing) 0.5f else 1f)
|
.alpha(if (!enableResetButton) 0.5f else 1f)
|
||||||
) {
|
) {
|
||||||
Icon(
|
Icon(
|
||||||
imageVector = Icons.Rounded.MapsUgc,
|
imageVector = Icons.Rounded.MapsUgc,
|
||||||
|
|
|
@ -83,8 +83,11 @@ fun ModelPickerChipsPager(
|
||||||
val scope = rememberCoroutineScope()
|
val scope = rememberCoroutineScope()
|
||||||
val density = LocalDensity.current
|
val density = LocalDensity.current
|
||||||
val windowInfo = LocalWindowInfo.current
|
val windowInfo = LocalWindowInfo.current
|
||||||
val screenWidthDp =
|
val screenWidthDp = remember {
|
||||||
remember { with(density) { windowInfo.containerSize.width.toDp() } }
|
with(density) {
|
||||||
|
windowInfo.containerSize.width.toDp()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
val pagerState = rememberPagerState(initialPage = task.models.indexOf(initialModel),
|
val pagerState = rememberPagerState(initialPage = task.models.indexOf(initialModel),
|
||||||
pageCount = { task.models.size })
|
pageCount = { task.models.size })
|
||||||
|
@ -147,7 +150,7 @@ fun ModelPickerChipsPager(
|
||||||
}
|
}
|
||||||
Text(
|
Text(
|
||||||
model.name,
|
model.name,
|
||||||
style = MaterialTheme.typography.labelSmall,
|
style = MaterialTheme.typography.labelMedium,
|
||||||
modifier = Modifier
|
modifier = Modifier
|
||||||
.padding(start = 4.dp)
|
.padding(start = 4.dp)
|
||||||
.widthIn(0.dp, screenWidthDp - 250.dp),
|
.widthIn(0.dp, screenWidthDp - 250.dp),
|
||||||
|
|
|
@ -21,6 +21,7 @@ import android.content.Context
|
||||||
import android.content.pm.PackageManager
|
import android.content.pm.PackageManager
|
||||||
import android.net.Uri
|
import android.net.Uri
|
||||||
import android.os.Build
|
import android.os.Build
|
||||||
|
import android.util.Log
|
||||||
import androidx.activity.compose.ManagedActivityResultLauncher
|
import androidx.activity.compose.ManagedActivityResultLauncher
|
||||||
import androidx.compose.material3.MaterialTheme
|
import androidx.compose.material3.MaterialTheme
|
||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
|
@ -39,7 +40,11 @@ import com.google.aiedge.gallery.ui.common.chat.Histogram
|
||||||
import com.google.aiedge.gallery.ui.common.chat.Stat
|
import com.google.aiedge.gallery.ui.common.chat.Stat
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
import com.google.aiedge.gallery.ui.theme.customColors
|
import com.google.aiedge.gallery.ui.theme.customColors
|
||||||
|
import kotlinx.serialization.ExperimentalSerializationApi
|
||||||
|
import kotlinx.serialization.json.Json
|
||||||
import java.io.File
|
import java.io.File
|
||||||
|
import java.net.HttpURLConnection
|
||||||
|
import java.net.URL
|
||||||
import kotlin.math.abs
|
import kotlin.math.abs
|
||||||
import kotlin.math.ln
|
import kotlin.math.ln
|
||||||
import kotlin.math.max
|
import kotlin.math.max
|
||||||
|
@ -487,4 +492,39 @@ fun processLlmResponse(response: String): String {
|
||||||
newContent = newContent.replace("\\n", "\n")
|
newContent = newContent.replace("\\n", "\n")
|
||||||
|
|
||||||
return newContent
|
return newContent
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@OptIn(ExperimentalSerializationApi::class)
|
||||||
|
inline fun <reified T> getJsonResponse(url: String): T? {
|
||||||
|
try {
|
||||||
|
val connection = URL(url).openConnection() as HttpURLConnection
|
||||||
|
connection.requestMethod = "GET"
|
||||||
|
connection.connect()
|
||||||
|
|
||||||
|
val responseCode = connection.responseCode
|
||||||
|
if (responseCode == HttpURLConnection.HTTP_OK) {
|
||||||
|
val inputStream = connection.inputStream
|
||||||
|
val response = inputStream.bufferedReader().use { it.readText() }
|
||||||
|
|
||||||
|
// Parse JSON using kotlinx.serialization
|
||||||
|
val json = Json {
|
||||||
|
// Handle potential extra fields
|
||||||
|
ignoreUnknownKeys = true
|
||||||
|
allowComments = true
|
||||||
|
allowTrailingComma = true
|
||||||
|
}
|
||||||
|
val jsonObj = json.decodeFromString<T>(response)
|
||||||
|
return jsonObj
|
||||||
|
} else {
|
||||||
|
Log.e("AGUtils", "HTTP error: $responseCode")
|
||||||
|
}
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.e(
|
||||||
|
"AGUtils",
|
||||||
|
"Error when getting json response: ${e.message}"
|
||||||
|
)
|
||||||
|
e.printStackTrace()
|
||||||
|
}
|
||||||
|
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
|
@ -16,9 +16,14 @@
|
||||||
|
|
||||||
package com.google.aiedge.gallery.ui.common.chat
|
package com.google.aiedge.gallery.ui.common.chat
|
||||||
|
|
||||||
|
import androidx.compose.animation.AnimatedContent
|
||||||
|
import androidx.compose.animation.ExperimentalSharedTransitionApi
|
||||||
|
import androidx.compose.animation.SharedTransitionLayout
|
||||||
|
import androidx.compose.foundation.Image
|
||||||
import androidx.compose.foundation.background
|
import androidx.compose.foundation.background
|
||||||
import androidx.compose.foundation.clickable
|
import androidx.compose.foundation.clickable
|
||||||
import androidx.compose.foundation.gestures.detectTapGestures
|
import androidx.compose.foundation.gestures.detectTapGestures
|
||||||
|
import androidx.compose.foundation.gestures.detectTransformGestures
|
||||||
import androidx.compose.foundation.layout.Arrangement
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
import androidx.compose.foundation.layout.Box
|
import androidx.compose.foundation.layout.Box
|
||||||
import androidx.compose.foundation.layout.Column
|
import androidx.compose.foundation.layout.Column
|
||||||
|
@ -28,6 +33,7 @@ import androidx.compose.foundation.layout.fillMaxSize
|
||||||
import androidx.compose.foundation.layout.fillMaxWidth
|
import androidx.compose.foundation.layout.fillMaxWidth
|
||||||
import androidx.compose.foundation.layout.ime
|
import androidx.compose.foundation.layout.ime
|
||||||
import androidx.compose.foundation.layout.imePadding
|
import androidx.compose.foundation.layout.imePadding
|
||||||
|
import androidx.compose.foundation.layout.offset
|
||||||
import androidx.compose.foundation.layout.padding
|
import androidx.compose.foundation.layout.padding
|
||||||
import androidx.compose.foundation.layout.size
|
import androidx.compose.foundation.layout.size
|
||||||
import androidx.compose.foundation.layout.width
|
import androidx.compose.foundation.layout.width
|
||||||
|
@ -39,12 +45,15 @@ import androidx.compose.foundation.lazy.rememberLazyListState
|
||||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||||
import androidx.compose.material.icons.Icons
|
import androidx.compose.material.icons.Icons
|
||||||
import androidx.compose.material.icons.outlined.Timer
|
import androidx.compose.material.icons.outlined.Timer
|
||||||
|
import androidx.compose.material.icons.rounded.Close
|
||||||
import androidx.compose.material.icons.rounded.ContentCopy
|
import androidx.compose.material.icons.rounded.ContentCopy
|
||||||
import androidx.compose.material.icons.rounded.Refresh
|
import androidx.compose.material.icons.rounded.Refresh
|
||||||
import androidx.compose.material3.Button
|
import androidx.compose.material3.Button
|
||||||
import androidx.compose.material3.Card
|
import androidx.compose.material3.Card
|
||||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||||
import androidx.compose.material3.Icon
|
import androidx.compose.material3.Icon
|
||||||
|
import androidx.compose.material3.IconButton
|
||||||
|
import androidx.compose.material3.IconButtonDefaults
|
||||||
import androidx.compose.material3.MaterialTheme
|
import androidx.compose.material3.MaterialTheme
|
||||||
import androidx.compose.material3.ModalBottomSheet
|
import androidx.compose.material3.ModalBottomSheet
|
||||||
import androidx.compose.material3.SnackbarHost
|
import androidx.compose.material3.SnackbarHost
|
||||||
|
@ -55,6 +64,7 @@ import androidx.compose.runtime.LaunchedEffect
|
||||||
import androidx.compose.runtime.MutableState
|
import androidx.compose.runtime.MutableState
|
||||||
import androidx.compose.runtime.collectAsState
|
import androidx.compose.runtime.collectAsState
|
||||||
import androidx.compose.runtime.getValue
|
import androidx.compose.runtime.getValue
|
||||||
|
import androidx.compose.runtime.mutableFloatStateOf
|
||||||
import androidx.compose.runtime.mutableIntStateOf
|
import androidx.compose.runtime.mutableIntStateOf
|
||||||
import androidx.compose.runtime.mutableStateOf
|
import androidx.compose.runtime.mutableStateOf
|
||||||
import androidx.compose.runtime.remember
|
import androidx.compose.runtime.remember
|
||||||
|
@ -65,12 +75,16 @@ import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.draw.clip
|
import androidx.compose.ui.draw.clip
|
||||||
import androidx.compose.ui.geometry.Offset
|
import androidx.compose.ui.geometry.Offset
|
||||||
import androidx.compose.ui.graphics.Color
|
import androidx.compose.ui.graphics.Color
|
||||||
|
import androidx.compose.ui.graphics.RectangleShape
|
||||||
import androidx.compose.ui.graphics.asImageBitmap
|
import androidx.compose.ui.graphics.asImageBitmap
|
||||||
|
import androidx.compose.ui.graphics.graphicsLayer
|
||||||
import androidx.compose.ui.hapticfeedback.HapticFeedbackType
|
import androidx.compose.ui.hapticfeedback.HapticFeedbackType
|
||||||
import androidx.compose.ui.input.nestedscroll.NestedScrollConnection
|
import androidx.compose.ui.input.nestedscroll.NestedScrollConnection
|
||||||
import androidx.compose.ui.input.nestedscroll.NestedScrollSource
|
import androidx.compose.ui.input.nestedscroll.NestedScrollSource
|
||||||
import androidx.compose.ui.input.nestedscroll.nestedScroll
|
import androidx.compose.ui.input.nestedscroll.nestedScroll
|
||||||
import androidx.compose.ui.input.pointer.pointerInput
|
import androidx.compose.ui.input.pointer.pointerInput
|
||||||
|
import androidx.compose.ui.layout.ContentScale
|
||||||
|
import androidx.compose.ui.layout.onSizeChanged
|
||||||
import androidx.compose.ui.platform.LocalClipboardManager
|
import androidx.compose.ui.platform.LocalClipboardManager
|
||||||
import androidx.compose.ui.platform.LocalContext
|
import androidx.compose.ui.platform.LocalContext
|
||||||
import androidx.compose.ui.platform.LocalDensity
|
import androidx.compose.ui.platform.LocalDensity
|
||||||
|
@ -80,6 +94,7 @@ import androidx.compose.ui.res.dimensionResource
|
||||||
import androidx.compose.ui.res.stringResource
|
import androidx.compose.ui.res.stringResource
|
||||||
import androidx.compose.ui.text.AnnotatedString
|
import androidx.compose.ui.text.AnnotatedString
|
||||||
import androidx.compose.ui.tooling.preview.Preview
|
import androidx.compose.ui.tooling.preview.Preview
|
||||||
|
import androidx.compose.ui.unit.IntSize
|
||||||
import androidx.compose.ui.unit.dp
|
import androidx.compose.ui.unit.dp
|
||||||
import androidx.compose.ui.window.Dialog
|
import androidx.compose.ui.window.Dialog
|
||||||
import com.google.aiedge.gallery.R
|
import com.google.aiedge.gallery.R
|
||||||
|
@ -102,7 +117,7 @@ enum class ChatInputType {
|
||||||
/**
|
/**
|
||||||
* Composable function for the main chat panel, displaying messages and handling user input.
|
* Composable function for the main chat panel, displaying messages and handling user input.
|
||||||
*/
|
*/
|
||||||
@OptIn(ExperimentalMaterial3Api::class)
|
@OptIn(ExperimentalMaterial3Api::class, ExperimentalSharedTransitionApi::class)
|
||||||
@Composable
|
@Composable
|
||||||
fun ChatPanel(
|
fun ChatPanel(
|
||||||
modelManagerViewModel: ModelManagerViewModel,
|
modelManagerViewModel: ModelManagerViewModel,
|
||||||
|
@ -127,6 +142,7 @@ fun ChatPanel(
|
||||||
val snackbarHostState = remember { SnackbarHostState() }
|
val snackbarHostState = remember { SnackbarHostState() }
|
||||||
val scope = rememberCoroutineScope()
|
val scope = rememberCoroutineScope()
|
||||||
val haptic = LocalHapticFeedback.current
|
val haptic = LocalHapticFeedback.current
|
||||||
|
var selectedImageMessage by remember { mutableStateOf<ChatMessageImage?>(null) }
|
||||||
|
|
||||||
var curMessage by remember { mutableStateOf("") } // Correct state
|
var curMessage by remember { mutableStateOf("") } // Correct state
|
||||||
val focusManager = LocalFocusManager.current
|
val focusManager = LocalFocusManager.current
|
||||||
|
@ -200,291 +216,378 @@ fun ChatPanel(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val modelInitializationStatus =
|
val modelInitializationStatus = modelManagerUiState.modelInitializationStatus[selectedModel.name]
|
||||||
modelManagerUiState.modelInitializationStatus[selectedModel.name]
|
|
||||||
|
|
||||||
LaunchedEffect(modelInitializationStatus) {
|
LaunchedEffect(modelInitializationStatus) {
|
||||||
showErrorDialog = modelInitializationStatus?.status == ModelInitializationStatusType.ERROR
|
showErrorDialog = modelInitializationStatus?.status == ModelInitializationStatusType.ERROR
|
||||||
}
|
}
|
||||||
|
|
||||||
Column(
|
SharedTransitionLayout(modifier = Modifier.fillMaxSize()) {
|
||||||
modifier = modifier.imePadding()
|
AnimatedContent(targetState = selectedImageMessage) { targetSelectedImageMessage ->
|
||||||
) {
|
Column(
|
||||||
Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.weight(1f)) {
|
modifier = modifier.imePadding()
|
||||||
LazyColumn(
|
|
||||||
modifier = Modifier
|
|
||||||
.fillMaxSize()
|
|
||||||
.nestedScroll(nestedScrollConnection),
|
|
||||||
state = listState, verticalArrangement = Arrangement.Top,
|
|
||||||
) {
|
) {
|
||||||
items(messages) { message ->
|
Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.weight(1f)) {
|
||||||
val imageHistoryCurIndex = remember { mutableIntStateOf(0) }
|
LazyColumn(
|
||||||
var hAlign: Alignment.Horizontal = Alignment.End
|
|
||||||
var backgroundColor: Color = MaterialTheme.customColors.userBubbleBgColor
|
|
||||||
var hardCornerAtLeftOrRight = false
|
|
||||||
var extraPaddingStart = 48.dp
|
|
||||||
var extraPaddingEnd = 0.dp
|
|
||||||
if (message.side == ChatSide.AGENT) {
|
|
||||||
hAlign = Alignment.Start
|
|
||||||
backgroundColor = MaterialTheme.customColors.agentBubbleBgColor
|
|
||||||
hardCornerAtLeftOrRight = true
|
|
||||||
extraPaddingStart = 0.dp
|
|
||||||
extraPaddingEnd = 48.dp
|
|
||||||
} else if (message.side == ChatSide.SYSTEM) {
|
|
||||||
extraPaddingStart = 24.dp
|
|
||||||
extraPaddingEnd = 24.dp
|
|
||||||
if (message.type == ChatMessageType.PROMPT_TEMPLATES) {
|
|
||||||
extraPaddingStart = 12.dp
|
|
||||||
extraPaddingEnd = 12.dp
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (message.type == ChatMessageType.IMAGE) {
|
|
||||||
backgroundColor = Color.Transparent
|
|
||||||
}
|
|
||||||
val bubbleBorderRadius = dimensionResource(R.dimen.chat_bubble_corner_radius)
|
|
||||||
|
|
||||||
Column(
|
|
||||||
modifier = Modifier
|
modifier = Modifier
|
||||||
.fillMaxWidth()
|
.fillMaxSize()
|
||||||
.padding(
|
.nestedScroll(nestedScrollConnection),
|
||||||
start = 12.dp + extraPaddingStart,
|
state = listState, verticalArrangement = Arrangement.Top,
|
||||||
end = 12.dp + extraPaddingEnd,
|
|
||||||
top = 6.dp,
|
|
||||||
bottom = 6.dp,
|
|
||||||
),
|
|
||||||
horizontalAlignment = hAlign,
|
|
||||||
) {
|
) {
|
||||||
// Sender row.
|
items(messages) { message ->
|
||||||
MessageSender(
|
val imageHistoryCurIndex = remember { mutableIntStateOf(0) }
|
||||||
message = message,
|
var hAlign: Alignment.Horizontal = Alignment.End
|
||||||
agentNameRes = task.agentNameRes,
|
var backgroundColor: Color = MaterialTheme.customColors.userBubbleBgColor
|
||||||
imageHistoryCurIndex = imageHistoryCurIndex.intValue
|
var hardCornerAtLeftOrRight = false
|
||||||
)
|
var extraPaddingStart = 48.dp
|
||||||
|
var extraPaddingEnd = 0.dp
|
||||||
|
if (message.side == ChatSide.AGENT) {
|
||||||
|
hAlign = Alignment.Start
|
||||||
|
backgroundColor = MaterialTheme.customColors.agentBubbleBgColor
|
||||||
|
hardCornerAtLeftOrRight = true
|
||||||
|
extraPaddingStart = 0.dp
|
||||||
|
extraPaddingEnd = 48.dp
|
||||||
|
} else if (message.side == ChatSide.SYSTEM) {
|
||||||
|
extraPaddingStart = 24.dp
|
||||||
|
extraPaddingEnd = 24.dp
|
||||||
|
if (message.type == ChatMessageType.PROMPT_TEMPLATES) {
|
||||||
|
extraPaddingStart = 12.dp
|
||||||
|
extraPaddingEnd = 12.dp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (message.type == ChatMessageType.IMAGE) {
|
||||||
|
backgroundColor = Color.Transparent
|
||||||
|
}
|
||||||
|
val bubbleBorderRadius = dimensionResource(R.dimen.chat_bubble_corner_radius)
|
||||||
|
|
||||||
// Message body.
|
Column(
|
||||||
when (message) {
|
modifier = Modifier
|
||||||
// Loading.
|
.fillMaxWidth()
|
||||||
is ChatMessageLoading -> MessageBodyLoading()
|
.padding(
|
||||||
|
start = 12.dp + extraPaddingStart,
|
||||||
|
end = 12.dp + extraPaddingEnd,
|
||||||
|
top = 6.dp,
|
||||||
|
bottom = 6.dp,
|
||||||
|
),
|
||||||
|
horizontalAlignment = hAlign,
|
||||||
|
) {
|
||||||
|
// Sender row.
|
||||||
|
MessageSender(
|
||||||
|
message = message,
|
||||||
|
agentNameRes = task.agentNameRes,
|
||||||
|
imageHistoryCurIndex = imageHistoryCurIndex.intValue
|
||||||
|
)
|
||||||
|
|
||||||
// Info.
|
// Message body.
|
||||||
is ChatMessageInfo -> MessageBodyInfo(message = message)
|
when (message) {
|
||||||
|
// Loading.
|
||||||
|
is ChatMessageLoading -> MessageBodyLoading()
|
||||||
|
|
||||||
// Warning
|
// Info.
|
||||||
is ChatMessageWarning -> MessageBodyWarning(message = message)
|
is ChatMessageInfo -> MessageBodyInfo(message = message)
|
||||||
|
|
||||||
// Config values change.
|
// Warning
|
||||||
is ChatMessageConfigValuesChange -> MessageBodyConfigUpdate(message = message)
|
is ChatMessageWarning -> MessageBodyWarning(message = message)
|
||||||
|
|
||||||
// Prompt templates.
|
// Config values change.
|
||||||
is ChatMessagePromptTemplates -> MessageBodyPromptTemplates(message = message,
|
is ChatMessageConfigValuesChange -> MessageBodyConfigUpdate(message = message)
|
||||||
task = task,
|
|
||||||
onPromptClicked = { template ->
|
|
||||||
onSendMessage(
|
|
||||||
selectedModel,
|
|
||||||
listOf(ChatMessageText(content = template.prompt, side = ChatSide.USER))
|
|
||||||
)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Non-system messages.
|
// Prompt templates.
|
||||||
else -> {
|
is ChatMessagePromptTemplates -> MessageBodyPromptTemplates(message = message,
|
||||||
// The bubble shape around the message body.
|
task = task,
|
||||||
var messageBubbleModifier = Modifier
|
onPromptClicked = { template ->
|
||||||
.clip(
|
onSendMessage(
|
||||||
MessageBubbleShape(
|
selectedModel,
|
||||||
radius = bubbleBorderRadius,
|
listOf(ChatMessageText(content = template.prompt, side = ChatSide.USER))
|
||||||
hardCornerAtLeftOrRight = hardCornerAtLeftOrRight
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.background(backgroundColor)
|
|
||||||
if (message is ChatMessageText) {
|
|
||||||
messageBubbleModifier = messageBubbleModifier
|
|
||||||
.pointerInput(Unit) {
|
|
||||||
detectTapGestures(
|
|
||||||
onLongPress = {
|
|
||||||
haptic.performHapticFeedback(HapticFeedbackType.LongPress)
|
|
||||||
longPressedMessage.value = message
|
|
||||||
showMessageLongPressedSheet = true
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Non-system messages.
|
||||||
|
else -> {
|
||||||
|
// The bubble shape around the message body.
|
||||||
|
var messageBubbleModifier = Modifier
|
||||||
|
.clip(
|
||||||
|
MessageBubbleShape(
|
||||||
|
radius = bubbleBorderRadius,
|
||||||
|
hardCornerAtLeftOrRight = hardCornerAtLeftOrRight
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.background(backgroundColor)
|
||||||
|
if (message is ChatMessageText) {
|
||||||
|
messageBubbleModifier = messageBubbleModifier.pointerInput(Unit) {
|
||||||
|
detectTapGestures(
|
||||||
|
onLongPress = {
|
||||||
|
haptic.performHapticFeedback(HapticFeedbackType.LongPress)
|
||||||
|
longPressedMessage.value = message
|
||||||
|
showMessageLongPressedSheet = true
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
Box(
|
||||||
Box(
|
modifier = messageBubbleModifier,
|
||||||
modifier = messageBubbleModifier,
|
|
||||||
) {
|
|
||||||
when (message) {
|
|
||||||
// Text
|
|
||||||
is ChatMessageText -> MessageBodyText(message = message)
|
|
||||||
|
|
||||||
// Image
|
|
||||||
is ChatMessageImage -> MessageBodyImage(message = message)
|
|
||||||
|
|
||||||
// Image with history (for image gen)
|
|
||||||
is ChatMessageImageWithHistory -> MessageBodyImageWithHistory(
|
|
||||||
message = message, imageHistoryCurIndex = imageHistoryCurIndex
|
|
||||||
)
|
|
||||||
|
|
||||||
// Classification result
|
|
||||||
is ChatMessageClassification -> MessageBodyClassification(
|
|
||||||
message = message,
|
|
||||||
modifier = Modifier.width(message.maxBarWidth ?: CLASSIFICATION_BAR_MAX_WIDTH)
|
|
||||||
)
|
|
||||||
|
|
||||||
// Benchmark result.
|
|
||||||
is ChatMessageBenchmarkResult -> MessageBodyBenchmark(message = message)
|
|
||||||
|
|
||||||
// Benchmark LLM result.
|
|
||||||
is ChatMessageBenchmarkLlmResult -> MessageBodyBenchmarkLlm(
|
|
||||||
message = message,
|
|
||||||
modifier = Modifier.wrapContentWidth()
|
|
||||||
)
|
|
||||||
|
|
||||||
else -> {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (message.side == ChatSide.AGENT) {
|
|
||||||
Row(
|
|
||||||
verticalAlignment = Alignment.CenterVertically,
|
|
||||||
horizontalArrangement = Arrangement.spacedBy(8.dp),
|
|
||||||
) {
|
|
||||||
LatencyText(message = message)
|
|
||||||
// A button to show stats for the LLM message.
|
|
||||||
if (task.type == TaskType.LLM_CHAT && message is ChatMessageText
|
|
||||||
// This means we only want to show the action button when the message is done
|
|
||||||
// generating, at which point the latency will be set.
|
|
||||||
&& message.latencyMs >= 0
|
|
||||||
) {
|
) {
|
||||||
val showingStats =
|
when (message) {
|
||||||
viewModel.isShowingStats(model = selectedModel, message = message)
|
// Text
|
||||||
MessageActionButton(
|
is ChatMessageText -> MessageBodyText(message = message)
|
||||||
label = if (showingStats) "Hide stats" else "Show stats",
|
|
||||||
icon = Icons.Outlined.Timer,
|
|
||||||
onClick = {
|
|
||||||
// Toggle showing stats.
|
|
||||||
viewModel.toggleShowingStats(selectedModel, message)
|
|
||||||
|
|
||||||
// Add the stats message after the LLM message.
|
// Image
|
||||||
if (viewModel.isShowingStats(model = selectedModel, message = message)) {
|
is ChatMessageImage -> {
|
||||||
val llmBenchmarkResult = message.llmBenchmarkResult
|
if (targetSelectedImageMessage != message) {
|
||||||
if (llmBenchmarkResult != null) {
|
MessageBodyImage(
|
||||||
viewModel.insertMessageAfter(
|
message = message,
|
||||||
model = selectedModel,
|
modifier = Modifier
|
||||||
anchorMessage = message,
|
.clickable {
|
||||||
messageToAdd = llmBenchmarkResult,
|
selectedImageMessage = message
|
||||||
)
|
}
|
||||||
}
|
.sharedElement(
|
||||||
}
|
sharedContentState = rememberSharedContentState(key = "selected_image"),
|
||||||
// Remove the stats message.
|
animatedVisibilityScope = this@AnimatedContent
|
||||||
else {
|
),
|
||||||
val curMessageIndex =
|
|
||||||
viewModel.getMessageIndex(model = selectedModel, message = message)
|
|
||||||
viewModel.removeMessageAt(
|
|
||||||
model = selectedModel,
|
|
||||||
index = curMessageIndex + 1
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
enabled = !uiState.inProgress
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if (message.side == ChatSide.USER) {
|
|
||||||
Row(
|
|
||||||
verticalAlignment = Alignment.CenterVertically,
|
|
||||||
horizontalArrangement = Arrangement.spacedBy(4.dp)
|
|
||||||
) {
|
|
||||||
// Run again button.
|
|
||||||
if (selectedModel.showRunAgainButton) {
|
|
||||||
MessageActionButton(
|
|
||||||
label = stringResource(R.string.run_again),
|
|
||||||
icon = Icons.Rounded.Refresh,
|
|
||||||
onClick = {
|
|
||||||
onRunAgainClicked(selectedModel, message)
|
|
||||||
},
|
|
||||||
enabled = !uiState.inProgress
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Benchmark button
|
// Image with history (for image gen)
|
||||||
if (selectedModel.showBenchmarkButton) {
|
is ChatMessageImageWithHistory -> MessageBodyImageWithHistory(
|
||||||
MessageActionButton(
|
message = message, imageHistoryCurIndex = imageHistoryCurIndex
|
||||||
label = stringResource(R.string.benchmark),
|
)
|
||||||
icon = Icons.Outlined.Timer,
|
|
||||||
onClick = {
|
// Classification result
|
||||||
showBenchmarkConfigsDialog = true
|
is ChatMessageClassification -> MessageBodyClassification(
|
||||||
benchmarkMessage.value = message
|
message = message, modifier = Modifier.width(
|
||||||
},
|
message.maxBarWidth ?: CLASSIFICATION_BAR_MAX_WIDTH
|
||||||
enabled = !uiState.inProgress
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Benchmark result.
|
||||||
|
is ChatMessageBenchmarkResult -> MessageBodyBenchmark(message = message)
|
||||||
|
|
||||||
|
// Benchmark LLM result.
|
||||||
|
is ChatMessageBenchmarkLlmResult -> MessageBodyBenchmarkLlm(
|
||||||
|
message = message, modifier = Modifier.wrapContentWidth()
|
||||||
|
)
|
||||||
|
|
||||||
|
else -> {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (message.side == ChatSide.AGENT) {
|
||||||
|
Row(
|
||||||
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(8.dp),
|
||||||
|
) {
|
||||||
|
LatencyText(message = message)
|
||||||
|
// A button to show stats for the LLM message.
|
||||||
|
if ((task.type == TaskType.LLM_CHAT || task.type == TaskType.LLM_ASK_IMAGE) && message is ChatMessageText
|
||||||
|
// This means we only want to show the action button when the message is done
|
||||||
|
// generating, at which point the latency will be set.
|
||||||
|
&& message.latencyMs >= 0
|
||||||
|
) {
|
||||||
|
val showingStats =
|
||||||
|
viewModel.isShowingStats(model = selectedModel, message = message)
|
||||||
|
MessageActionButton(
|
||||||
|
label = if (showingStats) "Hide stats" else "Show stats",
|
||||||
|
icon = Icons.Outlined.Timer,
|
||||||
|
onClick = {
|
||||||
|
// Toggle showing stats.
|
||||||
|
viewModel.toggleShowingStats(selectedModel, message)
|
||||||
|
|
||||||
|
// Add the stats message after the LLM message.
|
||||||
|
if (viewModel.isShowingStats(
|
||||||
|
model = selectedModel, message = message
|
||||||
|
)
|
||||||
|
) {
|
||||||
|
val llmBenchmarkResult = message.llmBenchmarkResult
|
||||||
|
if (llmBenchmarkResult != null) {
|
||||||
|
viewModel.insertMessageAfter(
|
||||||
|
model = selectedModel,
|
||||||
|
anchorMessage = message,
|
||||||
|
messageToAdd = llmBenchmarkResult,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Remove the stats message.
|
||||||
|
else {
|
||||||
|
val curMessageIndex =
|
||||||
|
viewModel.getMessageIndex(
|
||||||
|
model = selectedModel,
|
||||||
|
message = message
|
||||||
|
)
|
||||||
|
viewModel.removeMessageAt(
|
||||||
|
model = selectedModel, index = curMessageIndex + 1
|
||||||
|
)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
enabled = !uiState.inProgress
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (message.side == ChatSide.USER) {
|
||||||
|
Row(
|
||||||
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(4.dp)
|
||||||
|
) {
|
||||||
|
// Run again button.
|
||||||
|
if (selectedModel.showRunAgainButton) {
|
||||||
|
MessageActionButton(
|
||||||
|
label = stringResource(R.string.run_again),
|
||||||
|
icon = Icons.Rounded.Refresh,
|
||||||
|
onClick = {
|
||||||
|
onRunAgainClicked(selectedModel, message)
|
||||||
|
},
|
||||||
|
enabled = !uiState.inProgress
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Benchmark button
|
||||||
|
if (selectedModel.showBenchmarkButton) {
|
||||||
|
MessageActionButton(
|
||||||
|
label = stringResource(R.string.benchmark),
|
||||||
|
icon = Icons.Outlined.Timer,
|
||||||
|
onClick = {
|
||||||
|
showBenchmarkConfigsDialog = true
|
||||||
|
benchmarkMessage.value = message
|
||||||
|
},
|
||||||
|
enabled = !uiState.inProgress
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SnackbarHost(hostState = snackbarHostState, modifier = Modifier.padding(vertical = 4.dp))
|
||||||
|
|
||||||
|
// Show an info message for ask image task to get users started.
|
||||||
|
if (task.type == TaskType.LLM_ASK_IMAGE && messages.isEmpty()) {
|
||||||
|
Column(
|
||||||
|
modifier = Modifier
|
||||||
|
.padding(horizontal = 16.dp)
|
||||||
|
.fillMaxSize(),
|
||||||
|
horizontalAlignment = Alignment.CenterHorizontally,
|
||||||
|
verticalArrangement = Arrangement.Center
|
||||||
|
) {
|
||||||
|
MessageBodyInfo(
|
||||||
|
ChatMessageInfo(content = "To get started, click + below to add an image and type a prompt to ask a question about it."),
|
||||||
|
smallFontSize = false
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chat input
|
||||||
|
when (chatInputType) {
|
||||||
|
ChatInputType.TEXT -> {
|
||||||
|
// val isLlmTask = task.type == TaskType.LLM_CHAT
|
||||||
|
// val notLlmStartScreen = !(messages.size == 1 && messages[0] is ChatMessagePromptTemplates)
|
||||||
|
val hasImageMessage = messages.any { it is ChatMessageImage }
|
||||||
|
MessageInputText(
|
||||||
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
|
curMessage = curMessage,
|
||||||
|
inProgress = uiState.inProgress,
|
||||||
|
isResettingSession = uiState.isResettingSession,
|
||||||
|
modelPreparing = uiState.preparing,
|
||||||
|
hasImageMessage = hasImageMessage,
|
||||||
|
modelInitializing = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
|
||||||
|
textFieldPlaceHolderRes = task.textInputPlaceHolderRes,
|
||||||
|
onValueChanged = { curMessage = it },
|
||||||
|
onSendMessage = {
|
||||||
|
onSendMessage(selectedModel, it)
|
||||||
|
curMessage = ""
|
||||||
|
},
|
||||||
|
onOpenPromptTemplatesClicked = {
|
||||||
|
onSendMessage(
|
||||||
|
selectedModel, listOf(
|
||||||
|
ChatMessagePromptTemplates(
|
||||||
|
templates = selectedModel.llmPromptTemplates, showMakeYourOwn = false
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
},
|
||||||
|
onStopButtonClicked = onStopButtonClicked,
|
||||||
|
// showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen,
|
||||||
|
showPromptTemplatesInMenu = false,
|
||||||
|
showImagePickerInMenu = selectedModel.llmSupportImage,
|
||||||
|
showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
ChatInputType.IMAGE -> MessageInputImage(
|
||||||
|
disableButtons = uiState.inProgress,
|
||||||
|
streamingMessage = streamingMessage,
|
||||||
|
onImageSelected = { bitmap ->
|
||||||
|
onSendMessage(
|
||||||
|
selectedModel, listOf(
|
||||||
|
ChatMessageImage(
|
||||||
|
bitmap = bitmap, imageBitMap = bitmap.asImageBitmap(), side = ChatSide.USER
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
},
|
||||||
|
onStreamImage = { bitmap ->
|
||||||
|
onStreamImageMessage(
|
||||||
|
selectedModel, ChatMessageImage(
|
||||||
|
bitmap = bitmap, imageBitMap = bitmap.asImageBitmap(), side = ChatSide.USER
|
||||||
|
)
|
||||||
|
)
|
||||||
|
},
|
||||||
|
onStreamEnd = onStreamEnd,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
SnackbarHost(hostState = snackbarHostState, modifier = Modifier.padding(vertical = 4.dp))
|
// A full-screen image viewer.
|
||||||
}
|
if (targetSelectedImageMessage != null) {
|
||||||
|
ZoomableBox(
|
||||||
|
modifier = Modifier
|
||||||
// Chat input
|
.fillMaxSize()
|
||||||
when (chatInputType) {
|
.background(Color.Black.copy(alpha = 0.9f))
|
||||||
ChatInputType.TEXT -> {
|
.sharedElement(
|
||||||
// val isLlmTask = task.type == TaskType.LLM_CHAT
|
rememberSharedContentState(key = "bounds"),
|
||||||
// val notLlmStartScreen = !(messages.size == 1 && messages[0] is ChatMessagePromptTemplates)
|
animatedVisibilityScope = this,
|
||||||
val hasImageMessage = messages.any { it is ChatMessageImage }
|
|
||||||
MessageInputText(
|
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
|
||||||
curMessage = curMessage,
|
|
||||||
inProgress = uiState.inProgress,
|
|
||||||
isResettingSession = uiState.isResettingSession,
|
|
||||||
hasImageMessage = hasImageMessage,
|
|
||||||
modelInitializing = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
|
|
||||||
textFieldPlaceHolderRes = task.textInputPlaceHolderRes,
|
|
||||||
onValueChanged = { curMessage = it },
|
|
||||||
onSendMessage = {
|
|
||||||
onSendMessage(selectedModel, it)
|
|
||||||
curMessage = ""
|
|
||||||
},
|
|
||||||
onOpenPromptTemplatesClicked = {
|
|
||||||
onSendMessage(
|
|
||||||
selectedModel, listOf(
|
|
||||||
ChatMessagePromptTemplates(
|
|
||||||
templates = selectedModel.llmPromptTemplates, showMakeYourOwn = false
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
},
|
.skipToLookaheadSize(),
|
||||||
onStopButtonClicked = onStopButtonClicked,
|
) {
|
||||||
// showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen,
|
// Image.
|
||||||
showPromptTemplatesInMenu = false,
|
Image(
|
||||||
showImagePickerInMenu = selectedModel.llmSupportImage == true,
|
bitmap = targetSelectedImageMessage.imageBitMap,
|
||||||
showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress,
|
contentDescription = "",
|
||||||
)
|
modifier = modifier
|
||||||
|
.fillMaxSize()
|
||||||
|
.graphicsLayer(
|
||||||
|
scaleX = scale,
|
||||||
|
scaleY = scale,
|
||||||
|
translationX = offsetX,
|
||||||
|
translationY = offsetY
|
||||||
|
)
|
||||||
|
.sharedElement(
|
||||||
|
sharedContentState = rememberSharedContentState(key = "selected_image"),
|
||||||
|
animatedVisibilityScope = this@AnimatedContent
|
||||||
|
),
|
||||||
|
contentScale = ContentScale.Fit,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Close button.
|
||||||
|
IconButton(
|
||||||
|
onClick = {
|
||||||
|
selectedImageMessage = null
|
||||||
|
},
|
||||||
|
colors = IconButtonDefaults.iconButtonColors(
|
||||||
|
containerColor = MaterialTheme.colorScheme.surfaceVariant,
|
||||||
|
),
|
||||||
|
modifier = Modifier.offset(x = (-8).dp, y = 8.dp)
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
Icons.Rounded.Close,
|
||||||
|
contentDescription = "",
|
||||||
|
tint = MaterialTheme.colorScheme.primary
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ChatInputType.IMAGE -> MessageInputImage(
|
|
||||||
disableButtons = uiState.inProgress,
|
|
||||||
streamingMessage = streamingMessage,
|
|
||||||
onImageSelected = { bitmap ->
|
|
||||||
onSendMessage(
|
|
||||||
selectedModel, listOf(
|
|
||||||
ChatMessageImage(
|
|
||||||
bitmap = bitmap, imageBitMap = bitmap.asImageBitmap(), side = ChatSide.USER
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
},
|
|
||||||
onStreamImage = { bitmap ->
|
|
||||||
onStreamImageMessage(
|
|
||||||
selectedModel, ChatMessageImage(
|
|
||||||
bitmap = bitmap, imageBitMap = bitmap.asImageBitmap(), side = ChatSide.USER
|
|
||||||
)
|
|
||||||
)
|
|
||||||
},
|
|
||||||
onStreamEnd = onStreamEnd,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -498,9 +601,7 @@ fun ChatPanel(
|
||||||
) {
|
) {
|
||||||
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
|
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
|
||||||
Column(
|
Column(
|
||||||
modifier = Modifier
|
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||||
.padding(20.dp),
|
|
||||||
verticalArrangement = Arrangement.spacedBy(16.dp)
|
|
||||||
) {
|
) {
|
||||||
// Title
|
// Title
|
||||||
Text(
|
Text(
|
||||||
|
@ -568,13 +669,10 @@ fun ChatPanel(
|
||||||
Row(
|
Row(
|
||||||
verticalAlignment = Alignment.CenterVertically,
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
horizontalArrangement = Arrangement.spacedBy(6.dp),
|
horizontalArrangement = Arrangement.spacedBy(6.dp),
|
||||||
modifier = Modifier
|
modifier = Modifier.padding(vertical = 8.dp, horizontal = 16.dp)
|
||||||
.padding(vertical = 8.dp, horizontal = 16.dp)
|
|
||||||
) {
|
) {
|
||||||
Icon(
|
Icon(
|
||||||
Icons.Rounded.ContentCopy,
|
Icons.Rounded.ContentCopy, contentDescription = "", modifier = Modifier.size(18.dp)
|
||||||
contentDescription = "",
|
|
||||||
modifier = Modifier.size(18.dp)
|
|
||||||
)
|
)
|
||||||
Text("Copy text")
|
Text("Copy text")
|
||||||
}
|
}
|
||||||
|
@ -586,6 +684,51 @@ fun ChatPanel(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
fun ZoomableBox(
|
||||||
|
modifier: Modifier = Modifier,
|
||||||
|
minScale: Float = 1f,
|
||||||
|
maxScale: Float = 5f,
|
||||||
|
content: @Composable ZoomableBoxScope.() -> Unit
|
||||||
|
) {
|
||||||
|
var scale by remember { mutableFloatStateOf(1f) }
|
||||||
|
var offsetX by remember { mutableFloatStateOf(0f) }
|
||||||
|
var offsetY by remember { mutableFloatStateOf(0f) }
|
||||||
|
var size by remember { mutableStateOf(IntSize.Zero) }
|
||||||
|
Box(
|
||||||
|
modifier = modifier
|
||||||
|
.clip(RectangleShape)
|
||||||
|
.onSizeChanged { size = it }
|
||||||
|
.pointerInput(Unit) {
|
||||||
|
detectTransformGestures { _, pan, zoom, _ ->
|
||||||
|
scale = maxOf(minScale, minOf(scale * zoom, maxScale))
|
||||||
|
val maxX = (size.width * (scale - 1)) / 2
|
||||||
|
val minX = -maxX
|
||||||
|
offsetX = maxOf(minX, minOf(maxX, offsetX + pan.x))
|
||||||
|
val maxY = (size.height * (scale - 1)) / 2
|
||||||
|
val minY = -maxY
|
||||||
|
offsetY = maxOf(minY, minOf(maxY, offsetY + pan.y))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
contentAlignment = Alignment.TopEnd
|
||||||
|
) {
|
||||||
|
val scope = ZoomableBoxScopeImpl(scale, offsetX, offsetY)
|
||||||
|
scope.content()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
interface ZoomableBoxScope {
|
||||||
|
val scale: Float
|
||||||
|
val offsetX: Float
|
||||||
|
val offsetY: Float
|
||||||
|
}
|
||||||
|
|
||||||
|
private data class ZoomableBoxScopeImpl(
|
||||||
|
override val scale: Float,
|
||||||
|
override val offsetX: Float,
|
||||||
|
override val offsetY: Float
|
||||||
|
) : ZoomableBoxScope
|
||||||
|
|
||||||
@Preview(showBackground = true)
|
@Preview(showBackground = true)
|
||||||
@Composable
|
@Composable
|
||||||
fun ChatPanelPreview() {
|
fun ChatPanelPreview() {
|
||||||
|
|
|
@ -19,6 +19,7 @@ package com.google.aiedge.gallery.ui.common.chat
|
||||||
import android.util.Log
|
import android.util.Log
|
||||||
import androidx.activity.compose.BackHandler
|
import androidx.activity.compose.BackHandler
|
||||||
import androidx.compose.foundation.background
|
import androidx.compose.foundation.background
|
||||||
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
import androidx.compose.foundation.layout.Box
|
import androidx.compose.foundation.layout.Box
|
||||||
import androidx.compose.foundation.layout.Column
|
import androidx.compose.foundation.layout.Column
|
||||||
import androidx.compose.foundation.layout.fillMaxSize
|
import androidx.compose.foundation.layout.fillMaxSize
|
||||||
|
@ -27,6 +28,7 @@ import androidx.compose.foundation.pager.HorizontalPager
|
||||||
import androidx.compose.foundation.pager.rememberPagerState
|
import androidx.compose.foundation.pager.rememberPagerState
|
||||||
import androidx.compose.material3.MaterialTheme
|
import androidx.compose.material3.MaterialTheme
|
||||||
import androidx.compose.material3.Scaffold
|
import androidx.compose.material3.Scaffold
|
||||||
|
import androidx.compose.material3.Text
|
||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
import androidx.compose.runtime.LaunchedEffect
|
import androidx.compose.runtime.LaunchedEffect
|
||||||
import androidx.compose.runtime.collectAsState
|
import androidx.compose.runtime.collectAsState
|
||||||
|
@ -36,10 +38,13 @@ import androidx.compose.runtime.remember
|
||||||
import androidx.compose.runtime.rememberCoroutineScope
|
import androidx.compose.runtime.rememberCoroutineScope
|
||||||
import androidx.compose.runtime.setValue
|
import androidx.compose.runtime.setValue
|
||||||
import androidx.compose.runtime.snapshotFlow
|
import androidx.compose.runtime.snapshotFlow
|
||||||
|
import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
|
import androidx.compose.ui.graphics.Color
|
||||||
import androidx.compose.ui.graphics.graphicsLayer
|
import androidx.compose.ui.graphics.graphicsLayer
|
||||||
import androidx.compose.ui.platform.LocalContext
|
import androidx.compose.ui.platform.LocalContext
|
||||||
import androidx.compose.ui.tooling.preview.Preview
|
import androidx.compose.ui.tooling.preview.Preview
|
||||||
|
import androidx.compose.ui.unit.dp
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
|
@ -81,7 +86,7 @@ fun ChatView(
|
||||||
chatInputType: ChatInputType = ChatInputType.TEXT,
|
chatInputType: ChatInputType = ChatInputType.TEXT,
|
||||||
showStopButtonInInputWhenInProgress: Boolean = false,
|
showStopButtonInInputWhenInProgress: Boolean = false,
|
||||||
) {
|
) {
|
||||||
val uiStat by viewModel.uiState.collectAsState()
|
val uiState by viewModel.uiState.collectAsState()
|
||||||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||||
val selectedModel = modelManagerUiState.selectedModel
|
val selectedModel = modelManagerUiState.selectedModel
|
||||||
|
|
||||||
|
@ -158,7 +163,9 @@ fun ChatView(
|
||||||
model = selectedModel,
|
model = selectedModel,
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
showResetSessionButton = true,
|
showResetSessionButton = true,
|
||||||
isResettingSession = uiStat.isResettingSession,
|
isResettingSession = uiState.isResettingSession,
|
||||||
|
inProgress = uiState.inProgress,
|
||||||
|
modelPreparing = uiState.preparing,
|
||||||
onResetSessionClicked = onResetSessionClicked,
|
onResetSessionClicked = onResetSessionClicked,
|
||||||
onConfigChanged = { old, new ->
|
onConfigChanged = { old, new ->
|
||||||
viewModel.addConfigChangedMessage(
|
viewModel.addConfigChangedMessage(
|
||||||
|
|
|
@ -38,6 +38,11 @@ data class ChatUiState(
|
||||||
*/
|
*/
|
||||||
val isResettingSession: Boolean = false,
|
val isResettingSession: Boolean = false,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Indicates whether the model is preparing (before outputting any result and after initializing).
|
||||||
|
*/
|
||||||
|
val preparing: Boolean = false,
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A map of model names to lists of chat messages.
|
* A map of model names to lists of chat messages.
|
||||||
*/
|
*/
|
||||||
|
@ -204,6 +209,10 @@ open class ChatViewModel(val task: Task) : ViewModel() {
|
||||||
_uiState.update { _uiState.value.copy(isResettingSession = isResettingSession) }
|
_uiState.update { _uiState.value.copy(isResettingSession = isResettingSession) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun setPreparing(preparing: Boolean) {
|
||||||
|
_uiState.update { _uiState.value.copy(preparing = preparing) }
|
||||||
|
}
|
||||||
|
|
||||||
fun addConfigChangedMessage(
|
fun addConfigChangedMessage(
|
||||||
oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>, model: Model
|
oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>, model: Model
|
||||||
) {
|
) {
|
||||||
|
|
|
@ -18,6 +18,8 @@ package com.google.aiedge.gallery.ui.common.chat
|
||||||
|
|
||||||
import android.util.Log
|
import android.util.Log
|
||||||
import androidx.compose.foundation.border
|
import androidx.compose.foundation.border
|
||||||
|
import androidx.compose.foundation.clickable
|
||||||
|
import androidx.compose.foundation.interaction.MutableInteractionSource
|
||||||
import androidx.compose.foundation.layout.Arrangement
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
import androidx.compose.foundation.layout.Box
|
import androidx.compose.foundation.layout.Box
|
||||||
import androidx.compose.foundation.layout.Column
|
import androidx.compose.foundation.layout.Column
|
||||||
|
@ -28,9 +30,11 @@ import androidx.compose.foundation.layout.height
|
||||||
import androidx.compose.foundation.layout.offset
|
import androidx.compose.foundation.layout.offset
|
||||||
import androidx.compose.foundation.layout.padding
|
import androidx.compose.foundation.layout.padding
|
||||||
import androidx.compose.foundation.layout.width
|
import androidx.compose.foundation.layout.width
|
||||||
|
import androidx.compose.foundation.rememberScrollState
|
||||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||||
import androidx.compose.foundation.text.BasicTextField
|
import androidx.compose.foundation.text.BasicTextField
|
||||||
import androidx.compose.foundation.text.KeyboardOptions
|
import androidx.compose.foundation.text.KeyboardOptions
|
||||||
|
import androidx.compose.foundation.verticalScroll
|
||||||
import androidx.compose.material3.Button
|
import androidx.compose.material3.Button
|
||||||
import androidx.compose.material3.Card
|
import androidx.compose.material3.Card
|
||||||
import androidx.compose.material3.MaterialTheme
|
import androidx.compose.material3.MaterialTheme
|
||||||
|
@ -53,6 +57,9 @@ import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.focus.FocusRequester
|
import androidx.compose.ui.focus.FocusRequester
|
||||||
import androidx.compose.ui.focus.focusRequester
|
import androidx.compose.ui.focus.focusRequester
|
||||||
import androidx.compose.ui.focus.onFocusChanged
|
import androidx.compose.ui.focus.onFocusChanged
|
||||||
|
import androidx.compose.ui.graphics.SolidColor
|
||||||
|
import androidx.compose.ui.platform.LocalFocusManager
|
||||||
|
import androidx.compose.ui.text.TextStyle
|
||||||
import androidx.compose.ui.text.input.KeyboardType
|
import androidx.compose.ui.text.input.KeyboardType
|
||||||
import androidx.compose.ui.tooling.preview.Preview
|
import androidx.compose.ui.tooling.preview.Preview
|
||||||
import androidx.compose.ui.unit.dp
|
import androidx.compose.ui.unit.dp
|
||||||
|
@ -89,9 +96,20 @@ fun ConfigDialog(
|
||||||
putAll(initialValues)
|
putAll(initialValues)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
val interactionSource = remember { MutableInteractionSource() }
|
||||||
|
|
||||||
Dialog(onDismissRequest = onDismissed) {
|
Dialog(onDismissRequest = onDismissed) {
|
||||||
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
|
val focusManager = LocalFocusManager.current
|
||||||
|
Card(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.clickable(
|
||||||
|
interactionSource = interactionSource, indication = null // Disable the ripple effect
|
||||||
|
) {
|
||||||
|
focusManager.clearFocus()
|
||||||
|
},
|
||||||
|
shape = RoundedCornerShape(16.dp)
|
||||||
|
) {
|
||||||
Column(
|
Column(
|
||||||
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
|
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||||
) {
|
) {
|
||||||
|
@ -114,7 +132,14 @@ fun ConfigDialog(
|
||||||
}
|
}
|
||||||
|
|
||||||
// List of config rows.
|
// List of config rows.
|
||||||
ConfigEditorsPanel(configs = configs, values = values)
|
Column(
|
||||||
|
modifier = Modifier
|
||||||
|
.verticalScroll(rememberScrollState())
|
||||||
|
.weight(1f, fill = false),
|
||||||
|
verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||||
|
) {
|
||||||
|
ConfigEditorsPanel(configs = configs, values = values)
|
||||||
|
}
|
||||||
|
|
||||||
// Button row.
|
// Button row.
|
||||||
Row(
|
Row(
|
||||||
|
@ -264,6 +289,8 @@ fun NumberSliderRow(config: NumberSliderConfig, values: SnapshotStateMap<String,
|
||||||
values[config.key.label] = NaN
|
values[config.key.label] = NaN
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
textStyle = TextStyle(color = MaterialTheme.colorScheme.onSurface),
|
||||||
|
cursorBrush = SolidColor(MaterialTheme.colorScheme.onSurface),
|
||||||
) { innerTextField ->
|
) { innerTextField ->
|
||||||
Box(
|
Box(
|
||||||
modifier = Modifier.border(
|
modifier = Modifier.border(
|
||||||
|
|
|
@ -25,17 +25,18 @@ import androidx.compose.ui.layout.ContentScale
|
||||||
import androidx.compose.ui.unit.dp
|
import androidx.compose.ui.unit.dp
|
||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
fun MessageBodyImage(message: ChatMessageImage) {
|
fun MessageBodyImage(message: ChatMessageImage, modifier: Modifier = Modifier) {
|
||||||
val bitmapWidth = message.bitmap.width
|
val bitmapWidth = message.bitmap.width
|
||||||
val bitmapHeight = message.bitmap.height
|
val bitmapHeight = message.bitmap.height
|
||||||
val imageWidth =
|
val imageWidth =
|
||||||
if (bitmapWidth >= bitmapHeight) 200 else (200f / bitmapHeight * bitmapWidth).toInt()
|
if (bitmapWidth >= bitmapHeight) 200 else (200f / bitmapHeight * bitmapWidth).toInt()
|
||||||
val imageHeight =
|
val imageHeight =
|
||||||
if (bitmapHeight >= bitmapWidth) 200 else (200f / bitmapWidth * bitmapHeight).toInt()
|
if (bitmapHeight >= bitmapWidth) 200 else (200f / bitmapWidth * bitmapHeight).toInt()
|
||||||
|
|
||||||
Image(
|
Image(
|
||||||
bitmap = message.imageBitMap,
|
bitmap = message.imageBitMap,
|
||||||
contentDescription = "",
|
contentDescription = "",
|
||||||
modifier = Modifier
|
modifier = modifier
|
||||||
.height(imageHeight.dp)
|
.height(imageHeight.dp)
|
||||||
.width(imageWidth.dp),
|
.width(imageWidth.dp),
|
||||||
contentScale = ContentScale.Fit,
|
contentScale = ContentScale.Fit,
|
||||||
|
|
|
@ -38,7 +38,7 @@ import com.google.aiedge.gallery.ui.theme.customColors
|
||||||
* Supports markdown.
|
* Supports markdown.
|
||||||
*/
|
*/
|
||||||
@Composable
|
@Composable
|
||||||
fun MessageBodyInfo(message: ChatMessageInfo) {
|
fun MessageBodyInfo(message: ChatMessageInfo, smallFontSize: Boolean = true) {
|
||||||
Row(
|
Row(
|
||||||
modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center
|
modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center
|
||||||
) {
|
) {
|
||||||
|
@ -47,7 +47,11 @@ fun MessageBodyInfo(message: ChatMessageInfo) {
|
||||||
.clip(RoundedCornerShape(16.dp))
|
.clip(RoundedCornerShape(16.dp))
|
||||||
.background(MaterialTheme.customColors.agentBubbleBgColor)
|
.background(MaterialTheme.customColors.agentBubbleBgColor)
|
||||||
) {
|
) {
|
||||||
MarkdownText(text = message.content, modifier = Modifier.padding(12.dp), smallFontSize = true)
|
MarkdownText(
|
||||||
|
text = message.content,
|
||||||
|
modifier = Modifier.padding(12.dp),
|
||||||
|
smallFontSize = smallFontSize
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -103,6 +103,7 @@ fun MessageInputText(
|
||||||
@StringRes textFieldPlaceHolderRes: Int,
|
@StringRes textFieldPlaceHolderRes: Int,
|
||||||
onValueChanged: (String) -> Unit,
|
onValueChanged: (String) -> Unit,
|
||||||
onSendMessage: (List<ChatMessage>) -> Unit,
|
onSendMessage: (List<ChatMessage>) -> Unit,
|
||||||
|
modelPreparing: Boolean = false,
|
||||||
onOpenPromptTemplatesClicked: () -> Unit = {},
|
onOpenPromptTemplatesClicked: () -> Unit = {},
|
||||||
onStopButtonClicked: () -> Unit = {},
|
onStopButtonClicked: () -> Unit = {},
|
||||||
showPromptTemplatesInMenu: Boolean = false,
|
showPromptTemplatesInMenu: Boolean = false,
|
||||||
|
@ -173,7 +174,7 @@ fun MessageInputText(
|
||||||
.height(80.dp)
|
.height(80.dp)
|
||||||
.shadow(2.dp, shape = RoundedCornerShape(8.dp))
|
.shadow(2.dp, shape = RoundedCornerShape(8.dp))
|
||||||
.clip(RoundedCornerShape(8.dp))
|
.clip(RoundedCornerShape(8.dp))
|
||||||
.border(1.dp, MaterialTheme.colorScheme.outlineVariant, RoundedCornerShape(8.dp)),
|
.border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp)),
|
||||||
)
|
)
|
||||||
Box(modifier = Modifier
|
Box(modifier = Modifier
|
||||||
.offset(x = 10.dp, y = (-10).dp)
|
.offset(x = 10.dp, y = (-10).dp)
|
||||||
|
@ -219,7 +220,7 @@ fun MessageInputText(
|
||||||
expanded = showAddContentMenu,
|
expanded = showAddContentMenu,
|
||||||
onDismissRequest = { showAddContentMenu = false }) {
|
onDismissRequest = { showAddContentMenu = false }) {
|
||||||
if (showImagePickerInMenu) {
|
if (showImagePickerInMenu) {
|
||||||
// Take a photo.
|
// Take a picture.
|
||||||
DropdownMenuItem(
|
DropdownMenuItem(
|
||||||
text = {
|
text = {
|
||||||
Row(
|
Row(
|
||||||
|
@ -227,7 +228,7 @@ fun MessageInputText(
|
||||||
horizontalArrangement = Arrangement.spacedBy(6.dp)
|
horizontalArrangement = Arrangement.spacedBy(6.dp)
|
||||||
) {
|
) {
|
||||||
Icon(Icons.Rounded.PhotoCamera, contentDescription = "")
|
Icon(Icons.Rounded.PhotoCamera, contentDescription = "")
|
||||||
Text("Take a photo")
|
Text("Take a picture")
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
enabled = pickedImages.isEmpty() && !hasImageMessage,
|
enabled = pickedImages.isEmpty() && !hasImageMessage,
|
||||||
|
@ -321,7 +322,7 @@ fun MessageInputText(
|
||||||
Spacer(modifier = Modifier.width(8.dp))
|
Spacer(modifier = Modifier.width(8.dp))
|
||||||
|
|
||||||
if (inProgress && showStopButtonWhenInProgress) {
|
if (inProgress && showStopButtonWhenInProgress) {
|
||||||
if (!modelInitializing) {
|
if (!modelInitializing && !modelPreparing) {
|
||||||
IconButton(
|
IconButton(
|
||||||
onClick = onStopButtonClicked,
|
onClick = onStopButtonClicked,
|
||||||
colors = IconButtonDefaults.iconButtonColors(
|
colors = IconButtonDefaults.iconButtonColors(
|
||||||
|
|
|
@ -20,8 +20,9 @@ import android.content.Intent
|
||||||
import android.net.Uri
|
import android.net.Uri
|
||||||
import androidx.activity.compose.rememberLauncherForActivityResult
|
import androidx.activity.compose.rememberLauncherForActivityResult
|
||||||
import androidx.activity.result.contract.ActivityResultContracts
|
import androidx.activity.result.contract.ActivityResultContracts
|
||||||
import androidx.compose.animation.core.animateFloatAsState
|
import androidx.compose.animation.AnimatedContent
|
||||||
import androidx.compose.animation.core.tween
|
import androidx.compose.animation.ExperimentalSharedTransitionApi
|
||||||
|
import androidx.compose.animation.SharedTransitionLayout
|
||||||
import androidx.compose.foundation.background
|
import androidx.compose.foundation.background
|
||||||
import androidx.compose.foundation.clickable
|
import androidx.compose.foundation.clickable
|
||||||
import androidx.compose.foundation.interaction.MutableInteractionSource
|
import androidx.compose.foundation.interaction.MutableInteractionSource
|
||||||
|
@ -30,17 +31,13 @@ import androidx.compose.foundation.layout.Box
|
||||||
import androidx.compose.foundation.layout.Column
|
import androidx.compose.foundation.layout.Column
|
||||||
import androidx.compose.foundation.layout.Row
|
import androidx.compose.foundation.layout.Row
|
||||||
import androidx.compose.foundation.layout.fillMaxWidth
|
import androidx.compose.foundation.layout.fillMaxWidth
|
||||||
import androidx.compose.foundation.layout.heightIn
|
|
||||||
import androidx.compose.foundation.layout.offset
|
|
||||||
import androidx.compose.foundation.layout.padding
|
import androidx.compose.foundation.layout.padding
|
||||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||||
import androidx.compose.material.icons.Icons
|
import androidx.compose.material.icons.Icons
|
||||||
import androidx.compose.material.icons.rounded.ChevronRight
|
import androidx.compose.material.icons.rounded.ChevronRight
|
||||||
import androidx.compose.material.icons.rounded.Settings
|
|
||||||
import androidx.compose.material.icons.rounded.UnfoldLess
|
import androidx.compose.material.icons.rounded.UnfoldLess
|
||||||
import androidx.compose.material.icons.rounded.UnfoldMore
|
import androidx.compose.material.icons.rounded.UnfoldMore
|
||||||
import androidx.compose.material3.Icon
|
import androidx.compose.material3.Icon
|
||||||
import androidx.compose.material3.IconButton
|
|
||||||
import androidx.compose.material3.OutlinedButton
|
import androidx.compose.material3.OutlinedButton
|
||||||
import androidx.compose.material3.Text
|
import androidx.compose.material3.Text
|
||||||
import androidx.compose.material3.ripple
|
import androidx.compose.material3.ripple
|
||||||
|
@ -48,16 +45,12 @@ import androidx.compose.runtime.Composable
|
||||||
import androidx.compose.runtime.collectAsState
|
import androidx.compose.runtime.collectAsState
|
||||||
import androidx.compose.runtime.derivedStateOf
|
import androidx.compose.runtime.derivedStateOf
|
||||||
import androidx.compose.runtime.getValue
|
import androidx.compose.runtime.getValue
|
||||||
import androidx.compose.runtime.movableContentOf
|
|
||||||
import androidx.compose.runtime.movableContentWithReceiverOf
|
|
||||||
import androidx.compose.runtime.mutableStateOf
|
import androidx.compose.runtime.mutableStateOf
|
||||||
import androidx.compose.runtime.remember
|
import androidx.compose.runtime.remember
|
||||||
import androidx.compose.runtime.setValue
|
import androidx.compose.runtime.setValue
|
||||||
import androidx.compose.ui.Alignment
|
import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.draw.alpha
|
|
||||||
import androidx.compose.ui.draw.clip
|
import androidx.compose.ui.draw.clip
|
||||||
import androidx.compose.ui.layout.LookaheadScope
|
|
||||||
import androidx.compose.ui.platform.LocalContext
|
import androidx.compose.ui.platform.LocalContext
|
||||||
import androidx.compose.ui.tooling.preview.Preview
|
import androidx.compose.ui.tooling.preview.Preview
|
||||||
import androidx.compose.ui.unit.Dp
|
import androidx.compose.ui.unit.Dp
|
||||||
|
@ -91,6 +84,7 @@ private val DEFAULT_VERTICAL_PADDING = 16.dp
|
||||||
* model description and buttons for learning more (opening a URL) and downloading/trying
|
* model description and buttons for learning more (opening a URL) and downloading/trying
|
||||||
* the model.
|
* the model.
|
||||||
*/
|
*/
|
||||||
|
@OptIn(ExperimentalSharedTransitionApi::class)
|
||||||
@Composable
|
@Composable
|
||||||
fun ModelItem(
|
fun ModelItem(
|
||||||
model: Model,
|
model: Model,
|
||||||
|
@ -117,181 +111,170 @@ fun ModelItem(
|
||||||
|
|
||||||
var isExpanded by remember { mutableStateOf(false) }
|
var isExpanded by remember { mutableStateOf(false) }
|
||||||
|
|
||||||
// Animate alpha for model description and button rows when switching between layouts.
|
var boxModifier = modifier
|
||||||
val alphaAnimation by animateFloatAsState(
|
.fillMaxWidth()
|
||||||
targetValue = if (isExpanded) 1f else 0f,
|
.clip(RoundedCornerShape(size = 42.dp))
|
||||||
animationSpec = tween(durationMillis = LAYOUT_ANIMATION_DURATION - 50)
|
.background(
|
||||||
)
|
getTaskBgColor(task)
|
||||||
|
)
|
||||||
LookaheadScope {
|
boxModifier = if (canExpand) {
|
||||||
// Task icon.
|
boxModifier.clickable(onClick = {
|
||||||
val taskIcon = remember {
|
if (!model.imported) {
|
||||||
movableContentOf {
|
isExpanded = !isExpanded
|
||||||
TaskIcon(
|
} else {
|
||||||
task = task, modifier = Modifier.animateLayout()
|
onModelClicked(model)
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}, interactionSource = remember { MutableInteractionSource() }, indication = ripple(
|
||||||
|
bounded = true,
|
||||||
|
radius = 1000.dp,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
boxModifier
|
||||||
|
}
|
||||||
|
|
||||||
// Model name and status.
|
Box(
|
||||||
val modelNameAndStatus = remember {
|
modifier = boxModifier,
|
||||||
movableContentOf {
|
contentAlignment = Alignment.Center,
|
||||||
ModelNameAndStatus(
|
) {
|
||||||
model = model,
|
SharedTransitionLayout {
|
||||||
task = task,
|
AnimatedContent(
|
||||||
downloadStatus = downloadStatus,
|
isExpanded, label = "item_layout_transition",
|
||||||
isExpanded = isExpanded,
|
) { targetState ->
|
||||||
modifier = Modifier.animateLayout()
|
val taskIcon = @Composable {
|
||||||
)
|
TaskIcon(
|
||||||
}
|
task = task, modifier = Modifier.sharedElement(
|
||||||
}
|
sharedContentState = rememberSharedContentState(key = "task_icon"),
|
||||||
|
animatedVisibilityScope = this@AnimatedContent,
|
||||||
val actionButton = remember {
|
|
||||||
movableContentOf {
|
|
||||||
ModelItemActionButton(
|
|
||||||
context = context,
|
|
||||||
model = model,
|
|
||||||
task = task,
|
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
|
||||||
downloadStatus = downloadStatus,
|
|
||||||
onDownloadClicked = { model ->
|
|
||||||
checkNotificationPermissionAndStartDownload(
|
|
||||||
context = context,
|
|
||||||
launcher = launcher,
|
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
|
||||||
task = task,
|
|
||||||
model = model
|
|
||||||
)
|
)
|
||||||
},
|
)
|
||||||
showDeleteButton = showDeleteButton,
|
}
|
||||||
showDownloadButton = false,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Expand/collapse icon, or the config icon.
|
val modelNameAndStatus = @Composable {
|
||||||
val expandButton = remember {
|
ModelNameAndStatus(
|
||||||
movableContentOf {
|
model = model,
|
||||||
if (showConfigButtonIfExisted) {
|
task = task,
|
||||||
if (downloadStatus?.status === ModelDownloadStatusType.SUCCEEDED) {
|
downloadStatus = downloadStatus,
|
||||||
if (model.configs.isNotEmpty()) {
|
isExpanded = isExpanded,
|
||||||
IconButton(onClick = onConfigClicked) {
|
animatedVisibilityScope = this@AnimatedContent,
|
||||||
Icon(
|
sharedTransitionScope = this@SharedTransitionLayout
|
||||||
Icons.Rounded.Settings,
|
)
|
||||||
contentDescription = "",
|
}
|
||||||
tint = getTaskIconColor(task)
|
|
||||||
)
|
val actionButton = @Composable {
|
||||||
}
|
ModelItemActionButton(
|
||||||
}
|
context = context,
|
||||||
}
|
model = model,
|
||||||
} else {
|
task = task,
|
||||||
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
|
downloadStatus = downloadStatus,
|
||||||
|
onDownloadClicked = { model ->
|
||||||
|
checkNotificationPermissionAndStartDownload(
|
||||||
|
context = context,
|
||||||
|
launcher = launcher,
|
||||||
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
|
task = task,
|
||||||
|
model = model
|
||||||
|
)
|
||||||
|
},
|
||||||
|
showDeleteButton = showDeleteButton,
|
||||||
|
showDownloadButton = false,
|
||||||
|
modifier = Modifier.sharedElement(
|
||||||
|
sharedContentState = rememberSharedContentState(key = "action_button"),
|
||||||
|
animatedVisibilityScope = this@AnimatedContent,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
val expandButton = @Composable {
|
||||||
Icon(
|
Icon(
|
||||||
// For imported model, show ">" directly indicating users can just tap the model item to
|
// For imported model, show ">" directly indicating users can just tap the model item to
|
||||||
// go into it without needing to expand it first.
|
// go into it without needing to expand it first.
|
||||||
if (model.imported) Icons.Rounded.ChevronRight else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore,
|
if (model.imported) Icons.Rounded.ChevronRight else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore,
|
||||||
contentDescription = "",
|
contentDescription = "",
|
||||||
tint = getTaskIconColor(task),
|
tint = getTaskIconColor(task),
|
||||||
|
modifier = Modifier.sharedElement(
|
||||||
|
sharedContentState = rememberSharedContentState(key = "expand_button"),
|
||||||
|
animatedVisibilityScope = this@AnimatedContent,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Model description shown in expanded layout.
|
val description = @Composable {
|
||||||
val modelDescription = remember {
|
if (model.info.isNotEmpty()) {
|
||||||
movableContentOf { m: Modifier ->
|
MarkdownText(
|
||||||
if (model.info.isNotEmpty()) {
|
model.info, modifier = Modifier
|
||||||
MarkdownText(
|
.sharedElement(
|
||||||
model.info,
|
sharedContentState = rememberSharedContentState(key = "description"),
|
||||||
modifier = Modifier
|
animatedVisibilityScope = this@AnimatedContent,
|
||||||
.heightIn(min = 0.dp, max = if (isExpanded) 1000.dp else 0.dp)
|
)
|
||||||
.animateLayout()
|
.skipToLookaheadSize()
|
||||||
.then(m)
|
)
|
||||||
)
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Button rows shown in expanded layout.
|
val buttonsRow = @Composable {
|
||||||
val buttonRows = remember {
|
Row(
|
||||||
movableContentOf { m: Modifier ->
|
horizontalArrangement = Arrangement.spacedBy(12.dp), modifier = Modifier
|
||||||
Row(
|
.sharedElement(
|
||||||
modifier = Modifier
|
sharedContentState = rememberSharedContentState(key = "buttons_row"),
|
||||||
.heightIn(min = 0.dp, max = if (isExpanded) 1000.dp else 0.dp)
|
animatedVisibilityScope = this@AnimatedContent,
|
||||||
.animateLayout()
|
)
|
||||||
.then(m),
|
.skipToLookaheadSize()
|
||||||
horizontalArrangement = Arrangement.spacedBy(12.dp),
|
) {
|
||||||
) {
|
// The "learn more" button. Click to show related urls in a bottom sheet.
|
||||||
// The "learn more" button. Click to show related urls in a bottom sheet.
|
if (model.learnMoreUrl.isNotEmpty()) {
|
||||||
if (model.learnMoreUrl.isNotEmpty()) {
|
OutlinedButton(
|
||||||
OutlinedButton(
|
onClick = {
|
||||||
onClick = {
|
if (isExpanded) {
|
||||||
if (isExpanded) {
|
val intent = Intent(Intent.ACTION_VIEW, Uri.parse(model.learnMoreUrl))
|
||||||
val intent = Intent(Intent.ACTION_VIEW, Uri.parse(model.learnMoreUrl))
|
context.startActivity(intent)
|
||||||
context.startActivity(intent)
|
}
|
||||||
}
|
},
|
||||||
},
|
) {
|
||||||
|
Text("Learn More", maxLines = 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Button to start the download and start the chat session with the model.
|
||||||
|
val needToDownloadFirst =
|
||||||
|
downloadStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED || downloadStatus?.status == ModelDownloadStatusType.FAILED
|
||||||
|
DownloadAndTryButton(task = task,
|
||||||
|
model = model,
|
||||||
|
enabled = isExpanded,
|
||||||
|
needToDownloadFirst = needToDownloadFirst,
|
||||||
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
|
onClicked = { onModelClicked(model) })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collapsed state.
|
||||||
|
if (!targetState) {
|
||||||
|
Column(
|
||||||
|
horizontalAlignment = Alignment.CenterHorizontally,
|
||||||
|
) {
|
||||||
|
Row(
|
||||||
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(12.dp),
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.padding(start = 18.dp, end = 18.dp)
|
||||||
|
.padding(vertical = verticalSpacing)
|
||||||
) {
|
) {
|
||||||
Text("Learn More", maxLines = 1)
|
// Icon at the left.
|
||||||
|
taskIcon()
|
||||||
|
// Model name and status at the center.
|
||||||
|
Row(modifier = Modifier.weight(1f)) {
|
||||||
|
modelNameAndStatus()
|
||||||
|
}
|
||||||
|
// Action button and expand/collapse button at the right.
|
||||||
|
Row(verticalAlignment = Alignment.CenterVertically) {
|
||||||
|
actionButton()
|
||||||
|
expandButton()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
// Button to start the download and start the chat session with the model.
|
|
||||||
val needToDownloadFirst =
|
|
||||||
downloadStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED || downloadStatus?.status == ModelDownloadStatusType.FAILED
|
|
||||||
DownloadAndTryButton(
|
|
||||||
task = task,
|
|
||||||
model = model,
|
|
||||||
enabled = isExpanded,
|
|
||||||
needToDownloadFirst = needToDownloadFirst,
|
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
|
||||||
onClicked = { onModelClicked(model) }
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
val container = remember {
|
|
||||||
movableContentWithReceiverOf<LookaheadScope, @Composable () -> Unit> { content ->
|
|
||||||
Box(
|
|
||||||
modifier = Modifier.animateLayout(),
|
|
||||||
contentAlignment = Alignment.TopEnd,
|
|
||||||
) {
|
|
||||||
content()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var boxModifier = modifier
|
|
||||||
.fillMaxWidth()
|
|
||||||
.clip(RoundedCornerShape(size = 42.dp))
|
|
||||||
.background(
|
|
||||||
getTaskBgColor(task)
|
|
||||||
)
|
|
||||||
boxModifier = if (canExpand) {
|
|
||||||
boxModifier.clickable(
|
|
||||||
onClick = {
|
|
||||||
if (!model.imported) {
|
|
||||||
isExpanded = !isExpanded
|
|
||||||
} else {
|
|
||||||
onModelClicked(model)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
interactionSource = remember { MutableInteractionSource() },
|
|
||||||
indication = ripple(
|
|
||||||
bounded = true,
|
|
||||||
radius = 500.dp,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
boxModifier
|
|
||||||
}
|
|
||||||
Box(
|
|
||||||
modifier = boxModifier,
|
|
||||||
contentAlignment = Alignment.Center
|
|
||||||
) {
|
|
||||||
if (isExpanded) {
|
|
||||||
container {
|
|
||||||
// The main part (icon, model name, status, etc)
|
|
||||||
Column(
|
Column(
|
||||||
verticalArrangement = Arrangement.spacedBy(14.dp),
|
verticalArrangement = Arrangement.spacedBy(14.dp),
|
||||||
horizontalAlignment = Alignment.CenterHorizontally,
|
horizontalAlignment = Alignment.CenterHorizontally,
|
||||||
|
@ -300,7 +283,9 @@ fun ModelItem(
|
||||||
.padding(vertical = verticalSpacing, horizontal = 18.dp)
|
.padding(vertical = verticalSpacing, horizontal = 18.dp)
|
||||||
) {
|
) {
|
||||||
Box(contentAlignment = Alignment.Center) {
|
Box(contentAlignment = Alignment.Center) {
|
||||||
|
// Icon at the top-center.
|
||||||
taskIcon()
|
taskIcon()
|
||||||
|
// Action button and expand/collapse button at the right.
|
||||||
Row(
|
Row(
|
||||||
modifier = Modifier.fillMaxWidth(),
|
modifier = Modifier.fillMaxWidth(),
|
||||||
verticalAlignment = Alignment.CenterVertically,
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
@ -310,39 +295,12 @@ fun ModelItem(
|
||||||
expandButton()
|
expandButton()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Name and status below the icon.
|
||||||
modelNameAndStatus()
|
modelNameAndStatus()
|
||||||
modelDescription(Modifier.alpha(alphaAnimation))
|
// Description.
|
||||||
buttonRows(Modifier.alpha(alphaAnimation)) // Apply alpha here
|
description()
|
||||||
}
|
// Buttons
|
||||||
}
|
buttonsRow()
|
||||||
} else {
|
|
||||||
container {
|
|
||||||
Column(horizontalAlignment = Alignment.CenterHorizontally) {
|
|
||||||
// The main part (icon, model name, status, etc)
|
|
||||||
Row(
|
|
||||||
verticalAlignment = Alignment.CenterVertically,
|
|
||||||
horizontalArrangement = Arrangement.spacedBy(12.dp),
|
|
||||||
modifier = Modifier
|
|
||||||
.fillMaxWidth()
|
|
||||||
.padding(start = 18.dp, end = 18.dp)
|
|
||||||
.padding(vertical = verticalSpacing)
|
|
||||||
) {
|
|
||||||
taskIcon()
|
|
||||||
Row(modifier = Modifier.weight(1f)) {
|
|
||||||
modelNameAndStatus()
|
|
||||||
}
|
|
||||||
Row(verticalAlignment = Alignment.CenterVertically) {
|
|
||||||
actionButton()
|
|
||||||
expandButton()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Column(
|
|
||||||
modifier = Modifier.offset(y = 30.dp),
|
|
||||||
horizontalAlignment = Alignment.CenterHorizontally
|
|
||||||
) {
|
|
||||||
modelDescription(Modifier.alpha(alphaAnimation))
|
|
||||||
buttonRows(Modifier.alpha(alphaAnimation))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -108,7 +108,8 @@ fun ModelItemActionButton(
|
||||||
// Button to cancel the download when it is in progress.
|
// Button to cancel the download when it is in progress.
|
||||||
ModelDownloadStatusType.IN_PROGRESS, ModelDownloadStatusType.UNZIPPING -> IconButton(onClick = {
|
ModelDownloadStatusType.IN_PROGRESS, ModelDownloadStatusType.UNZIPPING -> IconButton(onClick = {
|
||||||
modelManagerViewModel.cancelDownloadModel(
|
modelManagerViewModel.cancelDownloadModel(
|
||||||
model
|
task = task,
|
||||||
|
model = model
|
||||||
)
|
)
|
||||||
}) {
|
}) {
|
||||||
Icon(
|
Icon(
|
||||||
|
|
|
@ -16,6 +16,9 @@
|
||||||
|
|
||||||
package com.google.aiedge.gallery.ui.common.modelitem
|
package com.google.aiedge.gallery.ui.common.modelitem
|
||||||
|
|
||||||
|
import androidx.compose.animation.AnimatedVisibilityScope
|
||||||
|
import androidx.compose.animation.ExperimentalSharedTransitionApi
|
||||||
|
import androidx.compose.animation.SharedTransitionScope
|
||||||
import androidx.compose.animation.core.Animatable
|
import androidx.compose.animation.core.Animatable
|
||||||
import androidx.compose.animation.core.tween
|
import androidx.compose.animation.core.tween
|
||||||
import androidx.compose.foundation.layout.Column
|
import androidx.compose.foundation.layout.Column
|
||||||
|
@ -30,6 +33,7 @@ import androidx.compose.runtime.LaunchedEffect
|
||||||
import androidx.compose.runtime.remember
|
import androidx.compose.runtime.remember
|
||||||
import androidx.compose.ui.Alignment
|
import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
|
import androidx.compose.ui.focus.focusModifier
|
||||||
import androidx.compose.ui.text.style.TextAlign
|
import androidx.compose.ui.text.style.TextAlign
|
||||||
import androidx.compose.ui.text.style.TextOverflow
|
import androidx.compose.ui.text.style.TextOverflow
|
||||||
import androidx.compose.ui.unit.dp
|
import androidx.compose.ui.unit.dp
|
||||||
|
@ -54,134 +58,164 @@ import com.google.aiedge.gallery.ui.theme.labelSmallNarrow
|
||||||
* - "Unzipping..." status for unzipping processes.
|
* - "Unzipping..." status for unzipping processes.
|
||||||
* - Model size for successful downloads.
|
* - Model size for successful downloads.
|
||||||
*/
|
*/
|
||||||
|
@OptIn(ExperimentalSharedTransitionApi::class)
|
||||||
@Composable
|
@Composable
|
||||||
fun ModelNameAndStatus(
|
fun ModelNameAndStatus(
|
||||||
model: Model,
|
model: Model,
|
||||||
task: Task,
|
task: Task,
|
||||||
downloadStatus: ModelDownloadStatus?,
|
downloadStatus: ModelDownloadStatus?,
|
||||||
isExpanded: Boolean,
|
isExpanded: Boolean,
|
||||||
|
sharedTransitionScope: SharedTransitionScope,
|
||||||
|
animatedVisibilityScope: AnimatedVisibilityScope,
|
||||||
modifier: Modifier = Modifier
|
modifier: Modifier = Modifier
|
||||||
) {
|
) {
|
||||||
val inProgress = downloadStatus?.status == ModelDownloadStatusType.IN_PROGRESS
|
val inProgress = downloadStatus?.status == ModelDownloadStatusType.IN_PROGRESS
|
||||||
val isPartiallyDownloaded = downloadStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED
|
val isPartiallyDownloaded = downloadStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED
|
||||||
var curDownloadProgress = 0f
|
var curDownloadProgress = 0f
|
||||||
|
|
||||||
Column(
|
with(sharedTransitionScope) {
|
||||||
horizontalAlignment = if (isExpanded) Alignment.CenterHorizontally else Alignment.Start
|
Column(
|
||||||
) {
|
horizontalAlignment = if (isExpanded) Alignment.CenterHorizontally else Alignment.Start
|
||||||
// Model name.
|
|
||||||
Row(
|
|
||||||
verticalAlignment = Alignment.CenterVertically,
|
|
||||||
) {
|
) {
|
||||||
Text(
|
// Model name.
|
||||||
model.name,
|
Row(
|
||||||
maxLines = 1,
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
overflow = TextOverflow.MiddleEllipsis,
|
) {
|
||||||
style = MaterialTheme.typography.titleMedium,
|
Text(
|
||||||
modifier = modifier,
|
model.name,
|
||||||
)
|
maxLines = 1,
|
||||||
}
|
overflow = TextOverflow.MiddleEllipsis,
|
||||||
|
style = MaterialTheme.typography.titleMedium,
|
||||||
Row(verticalAlignment = Alignment.CenterVertically) {
|
modifier = Modifier.sharedElement(
|
||||||
// Status icon.
|
rememberSharedContentState(key = "model_name"),
|
||||||
if (!inProgress && !isPartiallyDownloaded) {
|
animatedVisibilityScope = animatedVisibilityScope
|
||||||
StatusIcon(
|
)
|
||||||
downloadStatus = downloadStatus,
|
|
||||||
modifier = modifier.padding(end = 4.dp)
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Failure message.
|
Row(verticalAlignment = Alignment.CenterVertically) {
|
||||||
if (downloadStatus != null && downloadStatus.status == ModelDownloadStatusType.FAILED) {
|
// Status icon.
|
||||||
Row(verticalAlignment = Alignment.CenterVertically) {
|
if (!inProgress && !isPartiallyDownloaded) {
|
||||||
Text(
|
StatusIcon(
|
||||||
downloadStatus.errorMessage,
|
downloadStatus = downloadStatus,
|
||||||
color = MaterialTheme.colorScheme.error,
|
modifier = modifier
|
||||||
style = labelSmallNarrow,
|
.padding(end = 4.dp)
|
||||||
overflow = TextOverflow.Ellipsis,
|
.sharedElement(
|
||||||
modifier = modifier,
|
rememberSharedContentState(key = "download_status_icon"),
|
||||||
|
animatedVisibilityScope = animatedVisibilityScope
|
||||||
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Status label
|
// Failure message.
|
||||||
else {
|
if (downloadStatus != null && downloadStatus.status == ModelDownloadStatusType.FAILED) {
|
||||||
var sizeLabel = model.totalBytes.humanReadableSize()
|
Row(verticalAlignment = Alignment.CenterVertically) {
|
||||||
var fontSize = 11.sp
|
|
||||||
|
|
||||||
// Populate the status label.
|
|
||||||
if (downloadStatus != null) {
|
|
||||||
// For in-progress model, show {receivedSize} / {totalSize} - {rate} - {remainingTime}
|
|
||||||
if (inProgress || isPartiallyDownloaded) {
|
|
||||||
var totalSize = downloadStatus.totalBytes
|
|
||||||
if (totalSize == 0L) {
|
|
||||||
totalSize = model.totalBytes
|
|
||||||
}
|
|
||||||
sizeLabel =
|
|
||||||
"${downloadStatus.receivedBytes.humanReadableSize(extraDecimalForGbAndAbove = true)} of ${totalSize.humanReadableSize()}"
|
|
||||||
if (downloadStatus.bytesPerSecond > 0) {
|
|
||||||
sizeLabel =
|
|
||||||
"$sizeLabel · ${downloadStatus.bytesPerSecond.humanReadableSize()} / s"
|
|
||||||
if (downloadStatus.remainingMs >= 0) {
|
|
||||||
sizeLabel =
|
|
||||||
"$sizeLabel\n${downloadStatus.remainingMs.formatToHourMinSecond()} left"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (isPartiallyDownloaded) {
|
|
||||||
sizeLabel = "$sizeLabel (resuming...)"
|
|
||||||
}
|
|
||||||
curDownloadProgress =
|
|
||||||
downloadStatus.receivedBytes.toFloat() / downloadStatus.totalBytes.toFloat()
|
|
||||||
if (curDownloadProgress.isNaN()) {
|
|
||||||
curDownloadProgress = 0f
|
|
||||||
}
|
|
||||||
fontSize = 9.sp
|
|
||||||
}
|
|
||||||
// Status for unzipping.
|
|
||||||
else if (downloadStatus.status == ModelDownloadStatusType.UNZIPPING) {
|
|
||||||
sizeLabel = "Unzipping..."
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Column(
|
|
||||||
horizontalAlignment = if (isExpanded) Alignment.CenterHorizontally else Alignment.Start,
|
|
||||||
) {
|
|
||||||
for ((index, line) in sizeLabel.split("\n").withIndex()) {
|
|
||||||
Text(
|
Text(
|
||||||
line,
|
downloadStatus.errorMessage,
|
||||||
color = MaterialTheme.colorScheme.secondary,
|
color = MaterialTheme.colorScheme.error,
|
||||||
maxLines = 1,
|
style = labelSmallNarrow,
|
||||||
style = labelSmallNarrow.copy(fontSize = fontSize, lineHeight = 10.sp),
|
overflow = TextOverflow.Ellipsis,
|
||||||
textAlign = if (isExpanded) TextAlign.Center else TextAlign.Start,
|
modifier = Modifier.sharedElement(
|
||||||
overflow = TextOverflow.Visible,
|
rememberSharedContentState(key = "failure_messsage"),
|
||||||
modifier = modifier.offset(y = if (index == 0) 0.dp else (-1).dp)
|
animatedVisibilityScope = animatedVisibilityScope
|
||||||
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Download progress bar.
|
// Status label
|
||||||
if (inProgress || isPartiallyDownloaded) {
|
else {
|
||||||
val animatedProgress = remember { Animatable(0f) }
|
var sizeLabel = model.totalBytes.humanReadableSize()
|
||||||
LinearProgressIndicator(
|
var fontSize = 11.sp
|
||||||
progress = { animatedProgress.value },
|
|
||||||
color = getTaskIconColor(task = task),
|
// Populate the status label.
|
||||||
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
|
if (downloadStatus != null) {
|
||||||
modifier = modifier.padding(top = 2.dp)
|
// For in-progress model, show {receivedSize} / {totalSize} - {rate} - {remainingTime}
|
||||||
)
|
if (inProgress || isPartiallyDownloaded) {
|
||||||
LaunchedEffect(curDownloadProgress) {
|
var totalSize = downloadStatus.totalBytes
|
||||||
animatedProgress.animateTo(curDownloadProgress, animationSpec = tween(150))
|
if (totalSize == 0L) {
|
||||||
|
totalSize = model.totalBytes
|
||||||
|
}
|
||||||
|
sizeLabel =
|
||||||
|
"${downloadStatus.receivedBytes.humanReadableSize(extraDecimalForGbAndAbove = true)} of ${totalSize.humanReadableSize()}"
|
||||||
|
if (downloadStatus.bytesPerSecond > 0) {
|
||||||
|
sizeLabel =
|
||||||
|
"$sizeLabel · ${downloadStatus.bytesPerSecond.humanReadableSize()} / s"
|
||||||
|
if (downloadStatus.remainingMs >= 0) {
|
||||||
|
sizeLabel =
|
||||||
|
"$sizeLabel\n${downloadStatus.remainingMs.formatToHourMinSecond()} left"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (isPartiallyDownloaded) {
|
||||||
|
sizeLabel = "$sizeLabel (resuming...)"
|
||||||
|
}
|
||||||
|
curDownloadProgress =
|
||||||
|
downloadStatus.receivedBytes.toFloat() / downloadStatus.totalBytes.toFloat()
|
||||||
|
if (curDownloadProgress.isNaN()) {
|
||||||
|
curDownloadProgress = 0f
|
||||||
|
}
|
||||||
|
fontSize = 9.sp
|
||||||
|
}
|
||||||
|
// Status for unzipping.
|
||||||
|
else if (downloadStatus.status == ModelDownloadStatusType.UNZIPPING) {
|
||||||
|
sizeLabel = "Unzipping..."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Column(
|
||||||
|
horizontalAlignment = if (isExpanded) Alignment.CenterHorizontally else Alignment.Start,
|
||||||
|
) {
|
||||||
|
for ((index, line) in sizeLabel.split("\n").withIndex()) {
|
||||||
|
Text(
|
||||||
|
line,
|
||||||
|
color = MaterialTheme.colorScheme.secondary,
|
||||||
|
maxLines = 1,
|
||||||
|
style = labelSmallNarrow.copy(fontSize = fontSize, lineHeight = 10.sp),
|
||||||
|
textAlign = if (isExpanded) TextAlign.Center else TextAlign.Start,
|
||||||
|
overflow = TextOverflow.Visible,
|
||||||
|
modifier = Modifier
|
||||||
|
.offset(y = if (index == 0) 0.dp else (-1).dp)
|
||||||
|
.sharedElement(
|
||||||
|
rememberSharedContentState(key = "status_label_${index}"),
|
||||||
|
animatedVisibilityScope = animatedVisibilityScope
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Download progress bar.
|
||||||
|
if (inProgress || isPartiallyDownloaded) {
|
||||||
|
val animatedProgress = remember { Animatable(0f) }
|
||||||
|
LinearProgressIndicator(
|
||||||
|
progress = { animatedProgress.value },
|
||||||
|
color = getTaskIconColor(task = task),
|
||||||
|
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
|
||||||
|
modifier = Modifier
|
||||||
|
.padding(top = 2.dp)
|
||||||
|
.sharedElement(
|
||||||
|
rememberSharedContentState(key = "download_progress_bar"),
|
||||||
|
animatedVisibilityScope = animatedVisibilityScope
|
||||||
|
)
|
||||||
|
)
|
||||||
|
LaunchedEffect(curDownloadProgress) {
|
||||||
|
animatedProgress.animateTo(curDownloadProgress, animationSpec = tween(150))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Unzipping progress.
|
||||||
|
else if (downloadStatus?.status == ModelDownloadStatusType.UNZIPPING) {
|
||||||
|
LinearProgressIndicator(
|
||||||
|
color = getTaskIconColor(task = task),
|
||||||
|
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
|
||||||
|
modifier = Modifier
|
||||||
|
.padding(top = 2.dp)
|
||||||
|
.sharedElement(
|
||||||
|
rememberSharedContentState(key = "unzip_progress_bar"),
|
||||||
|
animatedVisibilityScope = animatedVisibilityScope
|
||||||
|
)
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
// Unzipping progress.
|
|
||||||
else if (downloadStatus?.status == ModelDownloadStatusType.UNZIPPING) {
|
|
||||||
LinearProgressIndicator(
|
|
||||||
color = getTaskIconColor(task = task),
|
|
||||||
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
|
|
||||||
modifier = Modifier
|
|
||||||
.padding(top = 2.dp),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,8 +16,11 @@
|
||||||
|
|
||||||
package com.google.aiedge.gallery.ui.home
|
package com.google.aiedge.gallery.ui.home
|
||||||
|
|
||||||
|
import android.app.Activity
|
||||||
|
import android.content.Context
|
||||||
import android.content.Intent
|
import android.content.Intent
|
||||||
import android.net.Uri
|
import android.net.Uri
|
||||||
|
import android.provider.OpenableColumns
|
||||||
import android.util.Log
|
import android.util.Log
|
||||||
import androidx.activity.compose.rememberLauncherForActivityResult
|
import androidx.activity.compose.rememberLauncherForActivityResult
|
||||||
import androidx.activity.result.ActivityResultLauncher
|
import androidx.activity.result.ActivityResultLauncher
|
||||||
|
@ -49,6 +52,7 @@ import androidx.compose.material.icons.automirrored.outlined.NoteAdd
|
||||||
import androidx.compose.material.icons.filled.Add
|
import androidx.compose.material.icons.filled.Add
|
||||||
import androidx.compose.material.icons.rounded.Error
|
import androidx.compose.material.icons.rounded.Error
|
||||||
import androidx.compose.material3.AlertDialog
|
import androidx.compose.material3.AlertDialog
|
||||||
|
import androidx.compose.material3.Button
|
||||||
import androidx.compose.material3.Card
|
import androidx.compose.material3.Card
|
||||||
import androidx.compose.material3.CardDefaults
|
import androidx.compose.material3.CardDefaults
|
||||||
import androidx.compose.material3.CircularProgressIndicator
|
import androidx.compose.material3.CircularProgressIndicator
|
||||||
|
@ -80,12 +84,18 @@ import androidx.compose.ui.draw.clip
|
||||||
import androidx.compose.ui.draw.scale
|
import androidx.compose.ui.draw.scale
|
||||||
import androidx.compose.ui.graphics.Brush
|
import androidx.compose.ui.graphics.Brush
|
||||||
import androidx.compose.ui.input.nestedscroll.nestedScroll
|
import androidx.compose.ui.input.nestedscroll.nestedScroll
|
||||||
import androidx.compose.ui.layout.layout
|
|
||||||
import androidx.compose.ui.platform.LocalConfiguration
|
|
||||||
import androidx.compose.ui.platform.LocalContext
|
import androidx.compose.ui.platform.LocalContext
|
||||||
|
import androidx.compose.ui.platform.LocalDensity
|
||||||
|
import androidx.compose.ui.platform.LocalWindowInfo
|
||||||
import androidx.compose.ui.res.stringResource
|
import androidx.compose.ui.res.stringResource
|
||||||
|
import androidx.compose.ui.text.LinkAnnotation
|
||||||
|
import androidx.compose.ui.text.SpanStyle
|
||||||
|
import androidx.compose.ui.text.TextLinkStyles
|
||||||
|
import androidx.compose.ui.text.buildAnnotatedString
|
||||||
import androidx.compose.ui.text.font.FontWeight
|
import androidx.compose.ui.text.font.FontWeight
|
||||||
import androidx.compose.ui.text.style.TextAlign
|
import androidx.compose.ui.text.style.TextAlign
|
||||||
|
import androidx.compose.ui.text.style.TextDecoration
|
||||||
|
import androidx.compose.ui.text.withLink
|
||||||
import androidx.compose.ui.tooling.preview.Preview
|
import androidx.compose.ui.tooling.preview.Preview
|
||||||
import androidx.compose.ui.unit.dp
|
import androidx.compose.ui.unit.dp
|
||||||
import androidx.compose.ui.unit.sp
|
import androidx.compose.ui.unit.sp
|
||||||
|
@ -93,7 +103,6 @@ import com.google.aiedge.gallery.GalleryTopAppBar
|
||||||
import com.google.aiedge.gallery.R
|
import com.google.aiedge.gallery.R
|
||||||
import com.google.aiedge.gallery.data.AppBarAction
|
import com.google.aiedge.gallery.data.AppBarAction
|
||||||
import com.google.aiedge.gallery.data.AppBarActionType
|
import com.google.aiedge.gallery.data.AppBarActionType
|
||||||
import com.google.aiedge.gallery.data.ConfigKey
|
|
||||||
import com.google.aiedge.gallery.data.ImportedModelInfo
|
import com.google.aiedge.gallery.data.ImportedModelInfo
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.ui.common.TaskIcon
|
import com.google.aiedge.gallery.ui.common.TaskIcon
|
||||||
|
@ -101,7 +110,6 @@ import com.google.aiedge.gallery.ui.common.getTaskBgColor
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||||
import com.google.aiedge.gallery.ui.theme.ThemeSettings
|
|
||||||
import com.google.aiedge.gallery.ui.theme.customColors
|
import com.google.aiedge.gallery.ui.theme.customColors
|
||||||
import com.google.aiedge.gallery.ui.theme.titleMediumNarrow
|
import com.google.aiedge.gallery.ui.theme.titleMediumNarrow
|
||||||
import kotlinx.coroutines.delay
|
import kotlinx.coroutines.delay
|
||||||
|
@ -109,6 +117,12 @@ import kotlinx.coroutines.launch
|
||||||
|
|
||||||
private const val TAG = "AGHomeScreen"
|
private const val TAG = "AGHomeScreen"
|
||||||
private const val TASK_COUNT_ANIMATION_DURATION = 250
|
private const val TASK_COUNT_ANIMATION_DURATION = 250
|
||||||
|
private const val MAX_TASK_CARD_PADDING = 24
|
||||||
|
private const val MIN_TASK_CARD_PADDING = 18
|
||||||
|
private const val MAX_TASK_CARD_RADIUS = 43.5
|
||||||
|
private const val MIN_TASK_CARD_RADIUS = 30
|
||||||
|
private const val MAX_TASK_CARD_ICON_SIZE = 56
|
||||||
|
private const val MIN_TASK_CARD_ICON_SIZE = 50
|
||||||
|
|
||||||
/** Navigation destination data */
|
/** Navigation destination data */
|
||||||
object HomeScreenDestination {
|
object HomeScreenDestination {
|
||||||
|
@ -127,6 +141,7 @@ fun HomeScreen(
|
||||||
val uiState by modelManagerViewModel.uiState.collectAsState()
|
val uiState by modelManagerViewModel.uiState.collectAsState()
|
||||||
var showSettingsDialog by remember { mutableStateOf(false) }
|
var showSettingsDialog by remember { mutableStateOf(false) }
|
||||||
var showImportModelSheet by remember { mutableStateOf(false) }
|
var showImportModelSheet by remember { mutableStateOf(false) }
|
||||||
|
var showUnsupportedFileTypeDialog by remember { mutableStateOf(false) }
|
||||||
val sheetState = rememberModalBottomSheetState()
|
val sheetState = rememberModalBottomSheetState()
|
||||||
var showImportDialog by remember { mutableStateOf(false) }
|
var showImportDialog by remember { mutableStateOf(false) }
|
||||||
var showImportingDialog by remember { mutableStateOf(false) }
|
var showImportingDialog by remember { mutableStateOf(false) }
|
||||||
|
@ -135,17 +150,21 @@ fun HomeScreen(
|
||||||
val coroutineScope = rememberCoroutineScope()
|
val coroutineScope = rememberCoroutineScope()
|
||||||
val snackbarHostState = remember { SnackbarHostState() }
|
val snackbarHostState = remember { SnackbarHostState() }
|
||||||
val scope = rememberCoroutineScope()
|
val scope = rememberCoroutineScope()
|
||||||
|
val context = LocalContext.current
|
||||||
val nonEmptyTasks = uiState.tasks.filter { it.models.size > 0 }
|
|
||||||
val loadingHfModels = uiState.loadingHfModels
|
|
||||||
|
|
||||||
val filePickerLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult(
|
val filePickerLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult(
|
||||||
contract = ActivityResultContracts.StartActivityForResult()
|
contract = ActivityResultContracts.StartActivityForResult()
|
||||||
) { result ->
|
) { result ->
|
||||||
if (result.resultCode == android.app.Activity.RESULT_OK) {
|
if (result.resultCode == android.app.Activity.RESULT_OK) {
|
||||||
result.data?.data?.let { uri ->
|
result.data?.data?.let { uri ->
|
||||||
selectedLocalModelFileUri.value = uri
|
val fileName = getFileName(context = context, uri = uri)
|
||||||
showImportDialog = true
|
Log.d(TAG, "Selected file: $fileName")
|
||||||
|
if (fileName != null && !fileName.endsWith(".task")) {
|
||||||
|
showUnsupportedFileTypeDialog = true
|
||||||
|
} else {
|
||||||
|
selectedLocalModelFileUri.value = uri
|
||||||
|
showImportDialog = true
|
||||||
|
}
|
||||||
} ?: run {
|
} ?: run {
|
||||||
Log.d(TAG, "No file selected or URI is null.")
|
Log.d(TAG, "No file selected or URI is null.")
|
||||||
}
|
}
|
||||||
|
@ -154,36 +173,29 @@ fun HomeScreen(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Scaffold(
|
Scaffold(modifier = modifier.nestedScroll(scrollBehavior.nestedScrollConnection), topBar = {
|
||||||
modifier = modifier.nestedScroll(scrollBehavior.nestedScrollConnection),
|
GalleryTopAppBar(
|
||||||
topBar = {
|
title = stringResource(HomeScreenDestination.titleRes),
|
||||||
GalleryTopAppBar(
|
rightAction = AppBarAction(actionType = AppBarActionType.APP_SETTING, actionFn = {
|
||||||
title = stringResource(HomeScreenDestination.titleRes),
|
showSettingsDialog = true
|
||||||
rightAction = AppBarAction(
|
}),
|
||||||
actionType = AppBarActionType.APP_SETTING, actionFn = {
|
scrollBehavior = scrollBehavior,
|
||||||
showSettingsDialog = true
|
)
|
||||||
}
|
}, floatingActionButton = {
|
||||||
),
|
// A floating action button to show "import model" bottom sheet.
|
||||||
loadingHfModels = loadingHfModels,
|
SmallFloatingActionButton(
|
||||||
scrollBehavior = scrollBehavior,
|
onClick = {
|
||||||
)
|
showImportModelSheet = true
|
||||||
},
|
},
|
||||||
floatingActionButton = {
|
containerColor = MaterialTheme.colorScheme.secondaryContainer,
|
||||||
// A floating action button to show "import model" bottom sheet.
|
contentColor = MaterialTheme.colorScheme.secondary,
|
||||||
SmallFloatingActionButton(
|
) {
|
||||||
onClick = {
|
Icon(Icons.Filled.Add, "")
|
||||||
showImportModelSheet = true
|
|
||||||
},
|
|
||||||
containerColor = MaterialTheme.colorScheme.secondaryContainer,
|
|
||||||
contentColor = MaterialTheme.colorScheme.secondary,
|
|
||||||
) {
|
|
||||||
Icon(Icons.Filled.Add, "")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
) { innerPadding ->
|
}) { innerPadding ->
|
||||||
Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.fillMaxSize()) {
|
Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.fillMaxSize()) {
|
||||||
TaskList(
|
TaskList(
|
||||||
tasks = nonEmptyTasks,
|
tasks = uiState.tasks,
|
||||||
navigateToTaskScreen = navigateToTaskScreen,
|
navigateToTaskScreen = navigateToTaskScreen,
|
||||||
loadingModelAllowlist = uiState.loadingModelAllowlist,
|
loadingModelAllowlist = uiState.loadingModelAllowlist,
|
||||||
modifier = Modifier.fillMaxSize(),
|
modifier = Modifier.fillMaxSize(),
|
||||||
|
@ -198,16 +210,8 @@ fun HomeScreen(
|
||||||
if (showSettingsDialog) {
|
if (showSettingsDialog) {
|
||||||
SettingsDialog(
|
SettingsDialog(
|
||||||
curThemeOverride = modelManagerViewModel.readThemeOverride(),
|
curThemeOverride = modelManagerViewModel.readThemeOverride(),
|
||||||
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
onDismissed = { showSettingsDialog = false },
|
onDismissed = { showSettingsDialog = false },
|
||||||
onOk = { curConfigValues ->
|
|
||||||
// Update theme settings.
|
|
||||||
// This will update app's theme.
|
|
||||||
val themeOverride = curConfigValues[ConfigKey.THEME.label] as String
|
|
||||||
ThemeSettings.themeOverride.value = themeOverride
|
|
||||||
|
|
||||||
// Save to data store.
|
|
||||||
modelManagerViewModel.saveThemeOverride(themeOverride)
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -232,10 +236,6 @@ fun HomeScreen(
|
||||||
val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
|
val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
|
||||||
addCategory(Intent.CATEGORY_OPENABLE)
|
addCategory(Intent.CATEGORY_OPENABLE)
|
||||||
type = "*/*"
|
type = "*/*"
|
||||||
putExtra(
|
|
||||||
Intent.EXTRA_MIME_TYPES,
|
|
||||||
arrayOf("application/x-binary", "application/octet-stream")
|
|
||||||
)
|
|
||||||
// Single select.
|
// Single select.
|
||||||
putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
|
putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
|
||||||
}
|
}
|
||||||
|
@ -259,13 +259,11 @@ fun HomeScreen(
|
||||||
// Import dialog
|
// Import dialog
|
||||||
if (showImportDialog) {
|
if (showImportDialog) {
|
||||||
selectedLocalModelFileUri.value?.let { uri ->
|
selectedLocalModelFileUri.value?.let { uri ->
|
||||||
ModelImportDialog(uri = uri,
|
ModelImportDialog(uri = uri, onDismiss = { showImportDialog = false }, onDone = { info ->
|
||||||
onDismiss = { showImportDialog = false },
|
selectedImportedModelInfo.value = info
|
||||||
onDone = { info ->
|
showImportDialog = false
|
||||||
selectedImportedModelInfo.value = info
|
showImportingDialog = true
|
||||||
showImportDialog = false
|
})
|
||||||
showImportingDialog = true
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -273,8 +271,7 @@ fun HomeScreen(
|
||||||
if (showImportingDialog) {
|
if (showImportingDialog) {
|
||||||
selectedLocalModelFileUri.value?.let { uri ->
|
selectedLocalModelFileUri.value?.let { uri ->
|
||||||
selectedImportedModelInfo.value?.let { info ->
|
selectedImportedModelInfo.value?.let { info ->
|
||||||
ModelImportingDialog(
|
ModelImportingDialog(uri = uri,
|
||||||
uri = uri,
|
|
||||||
info = info,
|
info = info,
|
||||||
onDismiss = { showImportingDialog = false },
|
onDismiss = { showImportingDialog = false },
|
||||||
onDone = {
|
onDone = {
|
||||||
|
@ -292,6 +289,22 @@ fun HomeScreen(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Alert dialog for unsupported file type.
|
||||||
|
if (showUnsupportedFileTypeDialog) {
|
||||||
|
AlertDialog(
|
||||||
|
onDismissRequest = { showUnsupportedFileTypeDialog = false },
|
||||||
|
title = { Text("Unsupported file type") },
|
||||||
|
text = {
|
||||||
|
Text("Only \".task\" file type is supported.")
|
||||||
|
},
|
||||||
|
confirmButton = {
|
||||||
|
Button(onClick = { showUnsupportedFileTypeDialog = false }) {
|
||||||
|
Text(stringResource(R.string.ok))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
if (uiState.loadingModelAllowlistError.isNotEmpty()) {
|
if (uiState.loadingModelAllowlistError.isNotEmpty()) {
|
||||||
AlertDialog(
|
AlertDialog(
|
||||||
icon = {
|
icon = {
|
||||||
|
@ -307,11 +320,9 @@ fun HomeScreen(
|
||||||
modelManagerViewModel.loadModelAllowlist()
|
modelManagerViewModel.loadModelAllowlist()
|
||||||
},
|
},
|
||||||
confirmButton = {
|
confirmButton = {
|
||||||
TextButton(
|
TextButton(onClick = {
|
||||||
onClick = {
|
modelManagerViewModel.loadModelAllowlist()
|
||||||
modelManagerViewModel.loadModelAllowlist()
|
}) {
|
||||||
}
|
|
||||||
) {
|
|
||||||
Text("Retry")
|
Text("Retry")
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -327,6 +338,38 @@ private fun TaskList(
|
||||||
modifier: Modifier = Modifier,
|
modifier: Modifier = Modifier,
|
||||||
contentPadding: PaddingValues = PaddingValues(0.dp),
|
contentPadding: PaddingValues = PaddingValues(0.dp),
|
||||||
) {
|
) {
|
||||||
|
val density = LocalDensity.current
|
||||||
|
val windowInfo = LocalWindowInfo.current
|
||||||
|
val screenWidthDp = remember {
|
||||||
|
with(density) {
|
||||||
|
windowInfo.containerSize.width.toDp()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
val screenHeightDp = remember {
|
||||||
|
with(density) {
|
||||||
|
windowInfo.containerSize.height.toDp()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
val sizeFraction = remember { ((screenWidthDp - 360.dp) / (410.dp - 360.dp)).coerceIn(0f, 1f) }
|
||||||
|
val linkColor = MaterialTheme.customColors.linkColor
|
||||||
|
|
||||||
|
val introText = buildAnnotatedString {
|
||||||
|
append("Welcome to AI Edge Gallery! Explore a world of \namazing on-device models from ")
|
||||||
|
withLink(
|
||||||
|
link = LinkAnnotation.Url(
|
||||||
|
url = "https://huggingface.co/litert-community", // Replace with the actual URL
|
||||||
|
styles = TextLinkStyles(
|
||||||
|
style = SpanStyle(
|
||||||
|
color = linkColor,
|
||||||
|
textDecoration = TextDecoration.Underline,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
) {
|
||||||
|
append("LiteRT community")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Box(modifier = modifier.fillMaxSize()) {
|
Box(modifier = modifier.fillMaxSize()) {
|
||||||
LazyVerticalGrid(
|
LazyVerticalGrid(
|
||||||
columns = GridCells.Fixed(count = 2),
|
columns = GridCells.Fixed(count = 2),
|
||||||
|
@ -335,10 +378,15 @@ private fun TaskList(
|
||||||
horizontalArrangement = Arrangement.spacedBy(8.dp),
|
horizontalArrangement = Arrangement.spacedBy(8.dp),
|
||||||
verticalArrangement = Arrangement.spacedBy(8.dp),
|
verticalArrangement = Arrangement.spacedBy(8.dp),
|
||||||
) {
|
) {
|
||||||
|
// New rel
|
||||||
|
item(key = "newReleaseNotification", span = { GridItemSpan(2) }) {
|
||||||
|
NewReleaseNotification()
|
||||||
|
}
|
||||||
|
|
||||||
// Headline.
|
// Headline.
|
||||||
item(key = "headline", span = { GridItemSpan(2) }) {
|
item(key = "headline", span = { GridItemSpan(2) }) {
|
||||||
Text(
|
Text(
|
||||||
"Welcome to AI Edge Gallery! Explore a world of \namazing on-device models from LiteRT community",
|
introText,
|
||||||
textAlign = TextAlign.Center,
|
textAlign = TextAlign.Center,
|
||||||
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
|
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
|
||||||
modifier = Modifier.padding(bottom = 20.dp)
|
modifier = Modifier.padding(bottom = 20.dp)
|
||||||
|
@ -364,14 +412,21 @@ private fun TaskList(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Cards.
|
// LLM Cards.
|
||||||
|
item(key = "llmCardsHeader", span = { GridItemSpan(2) }) {
|
||||||
|
Text(
|
||||||
|
"Example LLM Use Cases",
|
||||||
|
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.SemiBold),
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant,
|
||||||
|
modifier = Modifier.padding(bottom = 4.dp)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
items(tasks) { task ->
|
items(tasks) { task ->
|
||||||
TaskCard(
|
TaskCard(
|
||||||
task = task,
|
sizeFraction = sizeFraction, task = task, onClick = {
|
||||||
onClick = {
|
|
||||||
navigateToTaskScreen(task)
|
navigateToTaskScreen(task)
|
||||||
},
|
}, modifier = Modifier
|
||||||
modifier = Modifier
|
|
||||||
.fillMaxWidth()
|
.fillMaxWidth()
|
||||||
.aspectRatio(1f)
|
.aspectRatio(1f)
|
||||||
)
|
)
|
||||||
|
@ -388,7 +443,7 @@ private fun TaskList(
|
||||||
Box(
|
Box(
|
||||||
modifier = Modifier
|
modifier = Modifier
|
||||||
.fillMaxWidth()
|
.fillMaxWidth()
|
||||||
.height(LocalConfiguration.current.screenHeightDp.dp * 0.25f)
|
.height(screenHeightDp * 0.25f)
|
||||||
.background(
|
.background(
|
||||||
Brush.verticalGradient(
|
Brush.verticalGradient(
|
||||||
colors = MaterialTheme.customColors.homeBottomGradient,
|
colors = MaterialTheme.customColors.homeBottomGradient,
|
||||||
|
@ -400,7 +455,15 @@ private fun TaskList(
|
||||||
}
|
}
|
||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modifier) {
|
private fun TaskCard(
|
||||||
|
task: Task, onClick: () -> Unit, sizeFraction: Float, modifier: Modifier = Modifier
|
||||||
|
) {
|
||||||
|
val padding =
|
||||||
|
(MAX_TASK_CARD_PADDING - MIN_TASK_CARD_PADDING) * sizeFraction + MIN_TASK_CARD_PADDING
|
||||||
|
val radius = (MAX_TASK_CARD_RADIUS - MIN_TASK_CARD_RADIUS) * sizeFraction + MIN_TASK_CARD_RADIUS
|
||||||
|
val iconSize =
|
||||||
|
(MAX_TASK_CARD_ICON_SIZE - MIN_TASK_CARD_ICON_SIZE) * sizeFraction + MIN_TASK_CARD_ICON_SIZE
|
||||||
|
|
||||||
// Observes the model count and updates the model count label with a fade-in/fade-out animation
|
// Observes the model count and updates the model count label with a fade-in/fade-out animation
|
||||||
// whenever the count changes.
|
// whenever the count changes.
|
||||||
val modelCount by remember {
|
val modelCount by remember {
|
||||||
|
@ -445,7 +508,7 @@ private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modif
|
||||||
|
|
||||||
Card(
|
Card(
|
||||||
modifier = modifier
|
modifier = modifier
|
||||||
.clip(RoundedCornerShape(43.5.dp))
|
.clip(RoundedCornerShape(radius.dp))
|
||||||
.clickable(
|
.clickable(
|
||||||
onClick = onClick,
|
onClick = onClick,
|
||||||
),
|
),
|
||||||
|
@ -456,39 +519,24 @@ private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modif
|
||||||
Column(
|
Column(
|
||||||
modifier = Modifier
|
modifier = Modifier
|
||||||
.fillMaxSize()
|
.fillMaxSize()
|
||||||
.padding(24.dp),
|
.padding(padding.dp),
|
||||||
) {
|
) {
|
||||||
// Icon.
|
// Icon.
|
||||||
TaskIcon(task = task)
|
TaskIcon(task = task, width = iconSize.dp)
|
||||||
|
|
||||||
Spacer(modifier = Modifier.weight(1f))
|
Spacer(modifier = Modifier.weight(2f))
|
||||||
|
|
||||||
// Title.
|
// Title.
|
||||||
val pair = task.type.label.splitByFirstSpace()
|
|
||||||
Text(
|
Text(
|
||||||
pair.first,
|
task.type.label,
|
||||||
color = MaterialTheme.colorScheme.primary,
|
color = MaterialTheme.colorScheme.primary,
|
||||||
style = titleMediumNarrow.copy(
|
style = titleMediumNarrow.copy(
|
||||||
fontSize = 20.sp,
|
fontSize = 20.sp,
|
||||||
fontWeight = FontWeight.Bold,
|
fontWeight = FontWeight.Bold,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if (pair.second.isNotEmpty()) {
|
|
||||||
Text(
|
Spacer(modifier = Modifier.weight(1f))
|
||||||
pair.second,
|
|
||||||
color = MaterialTheme.colorScheme.primary,
|
|
||||||
style = titleMediumNarrow.copy(
|
|
||||||
fontSize = 18.sp,
|
|
||||||
fontWeight = FontWeight.Bold,
|
|
||||||
),
|
|
||||||
modifier = Modifier.layout { measurable, constraints ->
|
|
||||||
val placeable = measurable.measure(constraints)
|
|
||||||
layout(placeable.width, placeable.height) {
|
|
||||||
placeable.placeRelative(0, -4.dp.roundToPx())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Model count.
|
// Model count.
|
||||||
Text(
|
Text(
|
||||||
|
@ -503,12 +551,21 @@ private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun String.splitByFirstSpace(): Pair<String, String> {
|
// Helper function to get the file name from a URI
|
||||||
val spaceIndex = this.indexOf(' ')
|
fun getFileName(context: Context, uri: Uri): String? {
|
||||||
if (spaceIndex == -1) {
|
if (uri.scheme == "content") {
|
||||||
return Pair(this, "")
|
context.contentResolver.query(uri, null, null, null, null)?.use { cursor ->
|
||||||
|
if (cursor.moveToFirst()) {
|
||||||
|
val nameIndex = cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME)
|
||||||
|
if (nameIndex != -1) {
|
||||||
|
return cursor.getString(nameIndex)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (uri.scheme == "file") {
|
||||||
|
return uri.lastPathSegment
|
||||||
}
|
}
|
||||||
return Pair(this.substring(0, spaceIndex), this.substring(spaceIndex + 1))
|
return null
|
||||||
}
|
}
|
||||||
|
|
||||||
@Preview
|
@Preview
|
||||||
|
|
|
@ -22,12 +22,16 @@ import android.provider.OpenableColumns
|
||||||
import android.util.Log
|
import android.util.Log
|
||||||
import androidx.compose.animation.core.Animatable
|
import androidx.compose.animation.core.Animatable
|
||||||
import androidx.compose.animation.core.tween
|
import androidx.compose.animation.core.tween
|
||||||
|
import androidx.compose.foundation.clickable
|
||||||
|
import androidx.compose.foundation.interaction.MutableInteractionSource
|
||||||
import androidx.compose.foundation.layout.Arrangement
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
import androidx.compose.foundation.layout.Column
|
import androidx.compose.foundation.layout.Column
|
||||||
import androidx.compose.foundation.layout.Row
|
import androidx.compose.foundation.layout.Row
|
||||||
import androidx.compose.foundation.layout.fillMaxWidth
|
import androidx.compose.foundation.layout.fillMaxWidth
|
||||||
import androidx.compose.foundation.layout.padding
|
import androidx.compose.foundation.layout.padding
|
||||||
|
import androidx.compose.foundation.rememberScrollState
|
||||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||||
|
import androidx.compose.foundation.verticalScroll
|
||||||
import androidx.compose.material.icons.Icons
|
import androidx.compose.material.icons.Icons
|
||||||
import androidx.compose.material.icons.rounded.Error
|
import androidx.compose.material.icons.rounded.Error
|
||||||
import androidx.compose.material3.Button
|
import androidx.compose.material3.Button
|
||||||
|
@ -51,6 +55,7 @@ import androidx.compose.runtime.snapshots.SnapshotStateMap
|
||||||
import androidx.compose.ui.Alignment
|
import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.platform.LocalContext
|
import androidx.compose.ui.platform.LocalContext
|
||||||
|
import androidx.compose.ui.platform.LocalFocusManager
|
||||||
import androidx.compose.ui.unit.dp
|
import androidx.compose.ui.unit.dp
|
||||||
import androidx.compose.ui.window.Dialog
|
import androidx.compose.ui.window.Dialog
|
||||||
import androidx.compose.ui.window.DialogProperties
|
import androidx.compose.ui.window.DialogProperties
|
||||||
|
@ -82,41 +87,34 @@ import java.nio.charset.StandardCharsets
|
||||||
private const val TAG = "AGModelImportDialog"
|
private const val TAG = "AGModelImportDialog"
|
||||||
|
|
||||||
private val IMPORT_CONFIGS_LLM: List<Config> = listOf(
|
private val IMPORT_CONFIGS_LLM: List<Config> = listOf(
|
||||||
LabelConfig(key = ConfigKey.NAME),
|
LabelConfig(key = ConfigKey.NAME), LabelConfig(key = ConfigKey.MODEL_TYPE), NumberSliderConfig(
|
||||||
LabelConfig(key = ConfigKey.MODEL_TYPE),
|
|
||||||
NumberSliderConfig(
|
|
||||||
key = ConfigKey.DEFAULT_MAX_TOKENS,
|
key = ConfigKey.DEFAULT_MAX_TOKENS,
|
||||||
sliderMin = 100f,
|
sliderMin = 100f,
|
||||||
sliderMax = 1024f,
|
sliderMax = 1024f,
|
||||||
defaultValue = DEFAULT_MAX_TOKEN.toFloat(),
|
defaultValue = DEFAULT_MAX_TOKEN.toFloat(),
|
||||||
valueType = ValueType.INT
|
valueType = ValueType.INT
|
||||||
),
|
), NumberSliderConfig(
|
||||||
NumberSliderConfig(
|
|
||||||
key = ConfigKey.DEFAULT_TOPK,
|
key = ConfigKey.DEFAULT_TOPK,
|
||||||
sliderMin = 5f,
|
sliderMin = 5f,
|
||||||
sliderMax = 40f,
|
sliderMax = 40f,
|
||||||
defaultValue = DEFAULT_TOPK.toFloat(),
|
defaultValue = DEFAULT_TOPK.toFloat(),
|
||||||
valueType = ValueType.INT
|
valueType = ValueType.INT
|
||||||
),
|
), NumberSliderConfig(
|
||||||
NumberSliderConfig(
|
|
||||||
key = ConfigKey.DEFAULT_TOPP,
|
key = ConfigKey.DEFAULT_TOPP,
|
||||||
sliderMin = 0.0f,
|
sliderMin = 0.0f,
|
||||||
sliderMax = 1.0f,
|
sliderMax = 1.0f,
|
||||||
defaultValue = DEFAULT_TOPP,
|
defaultValue = DEFAULT_TOPP,
|
||||||
valueType = ValueType.FLOAT
|
valueType = ValueType.FLOAT
|
||||||
),
|
), NumberSliderConfig(
|
||||||
NumberSliderConfig(
|
|
||||||
key = ConfigKey.DEFAULT_TEMPERATURE,
|
key = ConfigKey.DEFAULT_TEMPERATURE,
|
||||||
sliderMin = 0.0f,
|
sliderMin = 0.0f,
|
||||||
sliderMax = 2.0f,
|
sliderMax = 2.0f,
|
||||||
defaultValue = DEFAULT_TEMPERATURE,
|
defaultValue = DEFAULT_TEMPERATURE,
|
||||||
valueType = ValueType.FLOAT
|
valueType = ValueType.FLOAT
|
||||||
),
|
), BooleanSwitchConfig(
|
||||||
BooleanSwitchConfig(
|
|
||||||
key = ConfigKey.SUPPORT_IMAGE,
|
key = ConfigKey.SUPPORT_IMAGE,
|
||||||
defaultValue = false,
|
defaultValue = false,
|
||||||
),
|
), SegmentedButtonConfig(
|
||||||
SegmentedButtonConfig(
|
|
||||||
key = ConfigKey.COMPATIBLE_ACCELERATORS,
|
key = ConfigKey.COMPATIBLE_ACCELERATORS,
|
||||||
defaultValue = Accelerator.CPU.label,
|
defaultValue = Accelerator.CPU.label,
|
||||||
options = listOf(Accelerator.CPU.label, Accelerator.GPU.label),
|
options = listOf(Accelerator.CPU.label, Accelerator.GPU.label),
|
||||||
|
@ -126,9 +124,7 @@ private val IMPORT_CONFIGS_LLM: List<Config> = listOf(
|
||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
fun ModelImportDialog(
|
fun ModelImportDialog(
|
||||||
uri: Uri,
|
uri: Uri, onDismiss: () -> Unit, onDone: (ImportedModelInfo) -> Unit
|
||||||
onDismiss: () -> Unit,
|
|
||||||
onDone: (ImportedModelInfo) -> Unit
|
|
||||||
) {
|
) {
|
||||||
val context = LocalContext.current
|
val context = LocalContext.current
|
||||||
val info = remember { getFileSizeAndDisplayNameFromUri(context = context, uri = uri) }
|
val info = remember { getFileSizeAndDisplayNameFromUri(context = context, uri = uri) }
|
||||||
|
@ -150,15 +146,23 @@ fun ModelImportDialog(
|
||||||
putAll(initialValues)
|
putAll(initialValues)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
val interactionSource = remember { MutableInteractionSource() }
|
||||||
|
|
||||||
Dialog(
|
Dialog(
|
||||||
onDismissRequest = onDismiss,
|
onDismissRequest = onDismiss,
|
||||||
) {
|
) {
|
||||||
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
|
val focusManager = LocalFocusManager.current
|
||||||
|
Card(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.clickable(
|
||||||
|
interactionSource = interactionSource, indication = null // Disable the ripple effect
|
||||||
|
) {
|
||||||
|
focusManager.clearFocus()
|
||||||
|
}, shape = RoundedCornerShape(16.dp)
|
||||||
|
) {
|
||||||
Column(
|
Column(
|
||||||
modifier = Modifier
|
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||||
.padding(20.dp),
|
|
||||||
verticalArrangement = Arrangement.spacedBy(16.dp)
|
|
||||||
) {
|
) {
|
||||||
// Title.
|
// Title.
|
||||||
Text(
|
Text(
|
||||||
|
@ -167,11 +171,18 @@ fun ModelImportDialog(
|
||||||
modifier = Modifier.padding(bottom = 8.dp)
|
modifier = Modifier.padding(bottom = 8.dp)
|
||||||
)
|
)
|
||||||
|
|
||||||
// Default configs for users to set.
|
Column(
|
||||||
ConfigEditorsPanel(
|
modifier = Modifier
|
||||||
configs = IMPORT_CONFIGS_LLM,
|
.verticalScroll(rememberScrollState())
|
||||||
values = values,
|
.weight(1f, fill = false),
|
||||||
)
|
verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||||
|
) {
|
||||||
|
// Default configs for users to set.
|
||||||
|
ConfigEditorsPanel(
|
||||||
|
configs = IMPORT_CONFIGS_LLM,
|
||||||
|
values = values,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// Button row.
|
// Button row.
|
||||||
Row(
|
Row(
|
||||||
|
@ -210,10 +221,7 @@ fun ModelImportDialog(
|
||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
fun ModelImportingDialog(
|
fun ModelImportingDialog(
|
||||||
uri: Uri,
|
uri: Uri, info: ImportedModelInfo, onDismiss: () -> Unit, onDone: (ImportedModelInfo) -> Unit
|
||||||
info: ImportedModelInfo,
|
|
||||||
onDismiss: () -> Unit,
|
|
||||||
onDone: (ImportedModelInfo) -> Unit
|
|
||||||
) {
|
) {
|
||||||
var error by remember { mutableStateOf("") }
|
var error by remember { mutableStateOf("") }
|
||||||
val context = LocalContext.current
|
val context = LocalContext.current
|
||||||
|
@ -222,8 +230,7 @@ fun ModelImportingDialog(
|
||||||
|
|
||||||
LaunchedEffect(Unit) {
|
LaunchedEffect(Unit) {
|
||||||
// Import.
|
// Import.
|
||||||
importModel(
|
importModel(context = context,
|
||||||
context = context,
|
|
||||||
coroutineScope = coroutineScope,
|
coroutineScope = coroutineScope,
|
||||||
fileName = info.fileName,
|
fileName = info.fileName,
|
||||||
fileSize = info.fileSize,
|
fileSize = info.fileSize,
|
||||||
|
@ -236,8 +243,7 @@ fun ModelImportingDialog(
|
||||||
},
|
},
|
||||||
onError = {
|
onError = {
|
||||||
error = it
|
error = it
|
||||||
}
|
})
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Dialog(
|
Dialog(
|
||||||
|
@ -246,9 +252,7 @@ fun ModelImportingDialog(
|
||||||
) {
|
) {
|
||||||
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
|
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
|
||||||
Column(
|
Column(
|
||||||
modifier = Modifier
|
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||||
.padding(20.dp),
|
|
||||||
verticalArrangement = Arrangement.spacedBy(16.dp)
|
|
||||||
) {
|
) {
|
||||||
// Title.
|
// Title.
|
||||||
Text(
|
Text(
|
||||||
|
@ -280,13 +284,10 @@ fun ModelImportingDialog(
|
||||||
// Has error.
|
// Has error.
|
||||||
else {
|
else {
|
||||||
Row(
|
Row(
|
||||||
verticalAlignment = Alignment.Top,
|
verticalAlignment = Alignment.Top, horizontalArrangement = Arrangement.spacedBy(6.dp)
|
||||||
horizontalArrangement = Arrangement.spacedBy(6.dp)
|
|
||||||
) {
|
) {
|
||||||
Icon(
|
Icon(
|
||||||
Icons.Rounded.Error,
|
Icons.Rounded.Error, contentDescription = "", tint = MaterialTheme.colorScheme.error
|
||||||
contentDescription = "",
|
|
||||||
tint = MaterialTheme.colorScheme.error
|
|
||||||
)
|
)
|
||||||
Text(
|
Text(
|
||||||
error,
|
error,
|
||||||
|
|
|
@ -0,0 +1,121 @@
|
||||||
|
package com.google.aiedge.gallery.ui.home
|
||||||
|
|
||||||
|
import android.util.Log
|
||||||
|
import androidx.compose.animation.AnimatedVisibility
|
||||||
|
import androidx.compose.animation.expandVertically
|
||||||
|
import androidx.compose.animation.fadeIn
|
||||||
|
import androidx.compose.foundation.background
|
||||||
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
|
import androidx.compose.foundation.layout.Row
|
||||||
|
import androidx.compose.foundation.layout.padding
|
||||||
|
import androidx.compose.foundation.shape.CircleShape
|
||||||
|
import androidx.compose.material.icons.Icons
|
||||||
|
import androidx.compose.material.icons.automirrored.rounded.OpenInNew
|
||||||
|
import androidx.compose.material3.MaterialTheme
|
||||||
|
import androidx.compose.material3.Text
|
||||||
|
import androidx.compose.runtime.Composable
|
||||||
|
import androidx.compose.runtime.LaunchedEffect
|
||||||
|
import androidx.compose.runtime.getValue
|
||||||
|
import androidx.compose.runtime.mutableStateOf
|
||||||
|
import androidx.compose.runtime.remember
|
||||||
|
import androidx.compose.runtime.setValue
|
||||||
|
import androidx.compose.ui.Alignment
|
||||||
|
import androidx.compose.ui.Modifier
|
||||||
|
import androidx.compose.ui.draw.clip
|
||||||
|
import androidx.compose.ui.unit.dp
|
||||||
|
import com.google.aiedge.gallery.BuildConfig
|
||||||
|
import com.google.aiedge.gallery.ui.common.getJsonResponse
|
||||||
|
import com.google.aiedge.gallery.ui.modelmanager.ClickableLink
|
||||||
|
import kotlinx.coroutines.Dispatchers
|
||||||
|
import kotlinx.coroutines.withContext
|
||||||
|
import kotlinx.serialization.Serializable
|
||||||
|
import kotlin.math.max
|
||||||
|
|
||||||
|
private const val TAG = "AGNewReleaseNotification"
|
||||||
|
private const val REPO = "google-ai-edge/gallery"
|
||||||
|
|
||||||
|
@Serializable
|
||||||
|
data class ReleaseInfo(
|
||||||
|
val html_url: String,
|
||||||
|
val tag_name: String,
|
||||||
|
)
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
fun NewReleaseNotification() {
|
||||||
|
var newReleaseVersion by remember { mutableStateOf("") }
|
||||||
|
var newReleaseUrl by remember { mutableStateOf("") }
|
||||||
|
|
||||||
|
LaunchedEffect(Unit) {
|
||||||
|
withContext(Dispatchers.IO) {
|
||||||
|
Log.d("AGNewReleaseNotification", "Checking for new release...")
|
||||||
|
val info = getJsonResponse<ReleaseInfo>("https://api.github.com/repos/$REPO/releases/latest")
|
||||||
|
if (info != null) {
|
||||||
|
val curRelease = BuildConfig.VERSION_NAME
|
||||||
|
val newRelease = info.tag_name
|
||||||
|
val isNewer = isNewerRelease(currentRelease = curRelease, newRelease = newRelease)
|
||||||
|
Log.d(TAG, "curRelease: $curRelease, newRelease: $newRelease, isNewer: $isNewer")
|
||||||
|
if (isNewer) {
|
||||||
|
newReleaseVersion = newRelease
|
||||||
|
newReleaseUrl = info.html_url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
AnimatedVisibility(
|
||||||
|
visible = newReleaseVersion.isNotEmpty(),
|
||||||
|
enter = fadeIn() + expandVertically()
|
||||||
|
) {
|
||||||
|
Row(
|
||||||
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
horizontalArrangement = Arrangement.SpaceBetween,
|
||||||
|
modifier = Modifier
|
||||||
|
.padding(horizontal = 16.dp)
|
||||||
|
.padding(bottom = 12.dp)
|
||||||
|
.clip(
|
||||||
|
CircleShape
|
||||||
|
)
|
||||||
|
.background(MaterialTheme.colorScheme.tertiaryContainer)
|
||||||
|
.padding(4.dp)
|
||||||
|
) {
|
||||||
|
Text(
|
||||||
|
"New release $newReleaseVersion available",
|
||||||
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
modifier = Modifier.padding(start = 12.dp)
|
||||||
|
)
|
||||||
|
Row(
|
||||||
|
modifier = Modifier.padding(end = 12.dp),
|
||||||
|
verticalAlignment = Alignment.CenterVertically
|
||||||
|
) {
|
||||||
|
ClickableLink(
|
||||||
|
url = newReleaseUrl,
|
||||||
|
linkText = "View",
|
||||||
|
icon = Icons.AutoMirrored.Rounded.OpenInNew,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun isNewerRelease(currentRelease: String, newRelease: String): Boolean {
|
||||||
|
// Split the version strings into their individual components (e.g., "0.9.0" -> ["0", "9", "0"])
|
||||||
|
val currentComponents = currentRelease.split('.').map { it.toIntOrNull() ?: 0 }
|
||||||
|
val newComponents = newRelease.split('.').map { it.toIntOrNull() ?: 0 }
|
||||||
|
|
||||||
|
// Determine the maximum number of components to iterate through
|
||||||
|
val maxComponents = max(currentComponents.size, newComponents.size)
|
||||||
|
|
||||||
|
// Iterate through the components from left to right (major, minor, patch, etc.)
|
||||||
|
for (i in 0 until maxComponents) {
|
||||||
|
val currentComponent = currentComponents.getOrElse(i) { 0 }
|
||||||
|
val newComponent = newComponents.getOrElse(i) { 0 }
|
||||||
|
|
||||||
|
if (newComponent > currentComponent) {
|
||||||
|
return true
|
||||||
|
} else if (newComponent < currentComponent) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
|
@ -16,45 +16,266 @@
|
||||||
|
|
||||||
package com.google.aiedge.gallery.ui.home
|
package com.google.aiedge.gallery.ui.home
|
||||||
|
|
||||||
|
import androidx.compose.foundation.border
|
||||||
|
import androidx.compose.foundation.clickable
|
||||||
|
import androidx.compose.foundation.interaction.MutableInteractionSource
|
||||||
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
|
import androidx.compose.foundation.layout.Box
|
||||||
|
import androidx.compose.foundation.layout.Column
|
||||||
|
import androidx.compose.foundation.layout.Row
|
||||||
|
import androidx.compose.foundation.layout.fillMaxWidth
|
||||||
|
import androidx.compose.foundation.layout.height
|
||||||
|
import androidx.compose.foundation.layout.offset
|
||||||
|
import androidx.compose.foundation.layout.padding
|
||||||
|
import androidx.compose.foundation.layout.wrapContentHeight
|
||||||
|
import androidx.compose.foundation.rememberScrollState
|
||||||
|
import androidx.compose.foundation.shape.CircleShape
|
||||||
|
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||||
|
import androidx.compose.foundation.text.BasicTextField
|
||||||
|
import androidx.compose.foundation.verticalScroll
|
||||||
|
import androidx.compose.material.icons.Icons
|
||||||
|
import androidx.compose.material.icons.rounded.CheckCircle
|
||||||
|
import androidx.compose.material3.Button
|
||||||
|
import androidx.compose.material3.Card
|
||||||
|
import androidx.compose.material3.Icon
|
||||||
|
import androidx.compose.material3.IconButton
|
||||||
|
import androidx.compose.material3.MaterialTheme
|
||||||
|
import androidx.compose.material3.MultiChoiceSegmentedButtonRow
|
||||||
|
import androidx.compose.material3.OutlinedButton
|
||||||
|
import androidx.compose.material3.SegmentedButton
|
||||||
|
import androidx.compose.material3.SegmentedButtonDefaults
|
||||||
|
import androidx.compose.material3.Text
|
||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
|
import androidx.compose.runtime.getValue
|
||||||
|
import androidx.compose.runtime.mutableStateOf
|
||||||
|
import androidx.compose.runtime.remember
|
||||||
|
import androidx.compose.runtime.setValue
|
||||||
|
import androidx.compose.ui.Alignment
|
||||||
|
import androidx.compose.ui.Modifier
|
||||||
|
import androidx.compose.ui.focus.FocusRequester
|
||||||
|
import androidx.compose.ui.focus.focusRequester
|
||||||
|
import androidx.compose.ui.focus.onFocusChanged
|
||||||
|
import androidx.compose.ui.graphics.SolidColor
|
||||||
|
import androidx.compose.ui.platform.LocalFocusManager
|
||||||
|
import androidx.compose.ui.text.TextStyle
|
||||||
|
import androidx.compose.ui.text.font.FontWeight
|
||||||
|
import androidx.compose.ui.unit.dp
|
||||||
|
import androidx.compose.ui.window.Dialog
|
||||||
import com.google.aiedge.gallery.BuildConfig
|
import com.google.aiedge.gallery.BuildConfig
|
||||||
import com.google.aiedge.gallery.data.Config
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
import com.google.aiedge.gallery.data.ConfigKey
|
|
||||||
import com.google.aiedge.gallery.data.SegmentedButtonConfig
|
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ConfigDialog
|
|
||||||
import com.google.aiedge.gallery.ui.theme.THEME_AUTO
|
import com.google.aiedge.gallery.ui.theme.THEME_AUTO
|
||||||
import com.google.aiedge.gallery.ui.theme.THEME_DARK
|
import com.google.aiedge.gallery.ui.theme.THEME_DARK
|
||||||
import com.google.aiedge.gallery.ui.theme.THEME_LIGHT
|
import com.google.aiedge.gallery.ui.theme.THEME_LIGHT
|
||||||
|
import com.google.aiedge.gallery.ui.theme.ThemeSettings
|
||||||
|
import com.google.aiedge.gallery.ui.theme.labelSmallNarrow
|
||||||
|
import java.time.Instant
|
||||||
|
import java.time.ZoneId
|
||||||
|
import java.time.format.DateTimeFormatter
|
||||||
|
import java.util.Locale
|
||||||
|
import kotlin.math.min
|
||||||
|
|
||||||
private val CONFIGS: List<Config> = listOf(
|
private val THEME_OPTIONS = listOf(THEME_AUTO, THEME_LIGHT, THEME_DARK)
|
||||||
SegmentedButtonConfig(
|
|
||||||
key = ConfigKey.THEME,
|
|
||||||
defaultValue = THEME_AUTO,
|
|
||||||
options = listOf(THEME_AUTO, THEME_LIGHT, THEME_DARK),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
fun SettingsDialog(
|
fun SettingsDialog(
|
||||||
curThemeOverride: String,
|
curThemeOverride: String,
|
||||||
|
modelManagerViewModel: ModelManagerViewModel,
|
||||||
onDismissed: () -> Unit,
|
onDismissed: () -> Unit,
|
||||||
onOk: (Map<String, Any>) -> Unit,
|
|
||||||
) {
|
) {
|
||||||
val initialValues = mapOf(
|
var selectedTheme by remember { mutableStateOf(curThemeOverride) }
|
||||||
ConfigKey.THEME.label to curThemeOverride
|
var hfToken by remember { mutableStateOf(modelManagerViewModel.getTokenStatusAndData().data) }
|
||||||
)
|
val dateFormatter = remember {
|
||||||
ConfigDialog(
|
DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss").withZone(ZoneId.systemDefault())
|
||||||
title = "Settings",
|
.withLocale(Locale.getDefault())
|
||||||
subtitle = "App version: ${BuildConfig.VERSION_NAME}",
|
}
|
||||||
okBtnLabel = "OK",
|
var customHfToken by remember { mutableStateOf("") }
|
||||||
configs = CONFIGS,
|
var isFocused by remember { mutableStateOf(false) }
|
||||||
initialValues = initialValues,
|
val focusRequester = remember { FocusRequester() }
|
||||||
onDismissed = onDismissed,
|
val interactionSource = remember { MutableInteractionSource() }
|
||||||
onOk = { curConfigValues ->
|
|
||||||
onOk(curConfigValues)
|
|
||||||
|
|
||||||
// Hide config dialog.
|
Dialog(onDismissRequest = onDismissed) {
|
||||||
onDismissed()
|
val focusManager = LocalFocusManager.current
|
||||||
},
|
Card(
|
||||||
)
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.clickable(
|
||||||
|
interactionSource = interactionSource, indication = null // Disable the ripple effect
|
||||||
|
) {
|
||||||
|
focusManager.clearFocus()
|
||||||
|
}, shape = RoundedCornerShape(16.dp)
|
||||||
|
) {
|
||||||
|
Column(
|
||||||
|
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||||
|
) {
|
||||||
|
// Dialog title and subtitle.
|
||||||
|
Column {
|
||||||
|
Text(
|
||||||
|
"Settings",
|
||||||
|
style = MaterialTheme.typography.titleLarge,
|
||||||
|
modifier = Modifier.padding(bottom = 8.dp)
|
||||||
|
)
|
||||||
|
// Subtitle.
|
||||||
|
Text(
|
||||||
|
"App version: ${BuildConfig.VERSION_NAME}",
|
||||||
|
style = labelSmallNarrow,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant,
|
||||||
|
modifier = Modifier.offset(y = (-6).dp)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
Column(
|
||||||
|
modifier = Modifier
|
||||||
|
.verticalScroll(rememberScrollState())
|
||||||
|
.weight(1f, fill = false),
|
||||||
|
verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||||
|
) {
|
||||||
|
// Theme switcher.
|
||||||
|
Column(
|
||||||
|
modifier = Modifier.fillMaxWidth()
|
||||||
|
) {
|
||||||
|
Text(
|
||||||
|
"Theme",
|
||||||
|
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.Bold)
|
||||||
|
)
|
||||||
|
MultiChoiceSegmentedButtonRow {
|
||||||
|
THEME_OPTIONS.forEachIndexed { index, label ->
|
||||||
|
SegmentedButton(shape = SegmentedButtonDefaults.itemShape(
|
||||||
|
index = index, count = THEME_OPTIONS.size
|
||||||
|
), onCheckedChange = {
|
||||||
|
selectedTheme = label
|
||||||
|
|
||||||
|
// Update theme settings.
|
||||||
|
// This will update app's theme.
|
||||||
|
ThemeSettings.themeOverride.value = label
|
||||||
|
|
||||||
|
// Save to data store.
|
||||||
|
modelManagerViewModel.saveThemeOverride(label)
|
||||||
|
}, checked = label == selectedTheme, label = { Text(label) })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HF Token management.
|
||||||
|
Column(
|
||||||
|
modifier = Modifier.fillMaxWidth(), verticalArrangement = Arrangement.spacedBy(4.dp)
|
||||||
|
) {
|
||||||
|
Text(
|
||||||
|
"HuggingFace access token",
|
||||||
|
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.Bold)
|
||||||
|
)
|
||||||
|
// Show the start of the token.
|
||||||
|
val curHfToken = hfToken
|
||||||
|
if (curHfToken != null) {
|
||||||
|
Text(
|
||||||
|
curHfToken.accessToken.substring(0, min(16, curHfToken.accessToken.length)) + "...",
|
||||||
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
Text(
|
||||||
|
"Expired at: ${dateFormatter.format(Instant.ofEpochMilli(curHfToken.expiresAtMs))}",
|
||||||
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
Text(
|
||||||
|
"Not available",
|
||||||
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
Text(
|
||||||
|
"The token will be automatically retrieved when a gated model is downloaded",
|
||||||
|
style = MaterialTheme.typography.bodySmall,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
}
|
||||||
|
Row(horizontalArrangement = Arrangement.spacedBy(4.dp)) {
|
||||||
|
OutlinedButton(
|
||||||
|
onClick = {
|
||||||
|
modelManagerViewModel.clearAccessToken()
|
||||||
|
hfToken = null
|
||||||
|
}, enabled = curHfToken != null
|
||||||
|
) {
|
||||||
|
Text("Clear")
|
||||||
|
}
|
||||||
|
BasicTextField(
|
||||||
|
value = customHfToken,
|
||||||
|
singleLine = true,
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.padding(top = 4.dp)
|
||||||
|
.focusRequester(focusRequester)
|
||||||
|
.onFocusChanged {
|
||||||
|
isFocused = it.isFocused
|
||||||
|
},
|
||||||
|
onValueChange = {
|
||||||
|
customHfToken = it
|
||||||
|
},
|
||||||
|
textStyle = TextStyle(color = MaterialTheme.colorScheme.onSurface),
|
||||||
|
cursorBrush = SolidColor(MaterialTheme.colorScheme.onSurface),
|
||||||
|
) { innerTextField ->
|
||||||
|
Box(
|
||||||
|
modifier = Modifier
|
||||||
|
.border(
|
||||||
|
width = if (isFocused) 2.dp else 1.dp,
|
||||||
|
color = if (isFocused) MaterialTheme.colorScheme.primary else MaterialTheme.colorScheme.outline,
|
||||||
|
shape = CircleShape,
|
||||||
|
)
|
||||||
|
.height(40.dp), contentAlignment = Alignment.CenterStart
|
||||||
|
) {
|
||||||
|
Row(verticalAlignment = Alignment.CenterVertically) {
|
||||||
|
Box(
|
||||||
|
modifier = Modifier
|
||||||
|
.padding(start = 16.dp)
|
||||||
|
.weight(1f)
|
||||||
|
) {
|
||||||
|
if (customHfToken.isEmpty()) {
|
||||||
|
Text(
|
||||||
|
"Enter token manually",
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant,
|
||||||
|
style = MaterialTheme.typography.bodySmall
|
||||||
|
)
|
||||||
|
}
|
||||||
|
innerTextField()
|
||||||
|
}
|
||||||
|
if (customHfToken.isNotEmpty()) {
|
||||||
|
IconButton(
|
||||||
|
modifier = Modifier.offset(x = 1.dp),
|
||||||
|
onClick = {
|
||||||
|
modelManagerViewModel.saveAccessToken(
|
||||||
|
accessToken = customHfToken,
|
||||||
|
refreshToken = "",
|
||||||
|
expiresAt = System.currentTimeMillis() + 1000L * 60 * 60 * 24 * 365 * 10,
|
||||||
|
)
|
||||||
|
hfToken = modelManagerViewModel.getTokenStatusAndData().data
|
||||||
|
}) {
|
||||||
|
Icon(Icons.Rounded.CheckCircle, contentDescription = "")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Button row.
|
||||||
|
Row(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.padding(top = 8.dp),
|
||||||
|
horizontalArrangement = Arrangement.End,
|
||||||
|
) {
|
||||||
|
// Close button
|
||||||
|
Button(
|
||||||
|
onClick = {
|
||||||
|
onDismissed()
|
||||||
|
},
|
||||||
|
) {
|
||||||
|
Text("Close")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,9 +34,9 @@ object LlmChatDestination {
|
||||||
val route = "LlmChatRoute"
|
val route = "LlmChatRoute"
|
||||||
}
|
}
|
||||||
|
|
||||||
object LlmImageToTextDestination {
|
object LlmAskImageDestination {
|
||||||
@Serializable
|
@Serializable
|
||||||
val route = "LlmImageToTextRoute"
|
val route = "LlmAskImageRoute"
|
||||||
}
|
}
|
||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
|
@ -57,11 +57,11 @@ fun LlmChatScreen(
|
||||||
}
|
}
|
||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
fun LlmImageToTextScreen(
|
fun LlmAskImageScreen(
|
||||||
modelManagerViewModel: ModelManagerViewModel,
|
modelManagerViewModel: ModelManagerViewModel,
|
||||||
navigateUp: () -> Unit,
|
navigateUp: () -> Unit,
|
||||||
modifier: Modifier = Modifier,
|
modifier: Modifier = Modifier,
|
||||||
viewModel: LlmImageToTextViewModel = viewModel(
|
viewModel: LlmAskImageViewModel = viewModel(
|
||||||
factory = ViewModelProvider.Factory
|
factory = ViewModelProvider.Factory
|
||||||
),
|
),
|
||||||
) {
|
) {
|
||||||
|
@ -129,12 +129,7 @@ fun ChatViewWrapper(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
onBenchmarkClicked = { model, message, warmUpIterations, benchmarkIterations ->
|
onBenchmarkClicked = { _, _, _, _ ->
|
||||||
if (message is ChatMessageText) {
|
|
||||||
viewModel.benchmark(
|
|
||||||
model = model, message = message
|
|
||||||
)
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
onResetSessionClicked = { model ->
|
onResetSessionClicked = { model ->
|
||||||
viewModel.resetSession(model = model)
|
viewModel.resetSession(model = model)
|
||||||
|
|
|
@ -22,7 +22,7 @@ import android.util.Log
|
||||||
import androidx.lifecycle.viewModelScope
|
import androidx.lifecycle.viewModelScope
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT
|
import com.google.aiedge.gallery.data.TASK_LLM_ASK_IMAGE
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
|
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageLoading
|
import com.google.aiedge.gallery.ui.common.chat.ChatMessageLoading
|
||||||
|
@ -49,6 +49,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
|
||||||
fun generateResponse(model: Model, input: String, image: Bitmap? = null, onError: () -> Unit) {
|
fun generateResponse(model: Model, input: String, image: Bitmap? = null, onError: () -> Unit) {
|
||||||
viewModelScope.launch(Dispatchers.Default) {
|
viewModelScope.launch(Dispatchers.Default) {
|
||||||
setInProgress(true)
|
setInProgress(true)
|
||||||
|
setPreparing(true)
|
||||||
|
|
||||||
// Loading.
|
// Loading.
|
||||||
addMessage(
|
addMessage(
|
||||||
|
@ -88,6 +89,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
|
||||||
timeToFirstToken = (firstTokenTs - start) / 1000f
|
timeToFirstToken = (firstTokenTs - start) / 1000f
|
||||||
prefillSpeed = prefillTokens / timeToFirstToken
|
prefillSpeed = prefillTokens / timeToFirstToken
|
||||||
firstRun = false
|
firstRun = false
|
||||||
|
setPreparing(false)
|
||||||
} else {
|
} else {
|
||||||
decodeTokens++
|
decodeTokens++
|
||||||
}
|
}
|
||||||
|
@ -137,10 +139,12 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
|
||||||
},
|
},
|
||||||
cleanUpListener = {
|
cleanUpListener = {
|
||||||
setInProgress(false)
|
setInProgress(false)
|
||||||
|
setPreparing(false)
|
||||||
})
|
})
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
Log.e(TAG, "Error occurred while running inference", e)
|
Log.e(TAG, "Error occurred while running inference", e)
|
||||||
setInProgress(false)
|
setInProgress(false)
|
||||||
|
setPreparing(false)
|
||||||
onError()
|
onError()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -194,98 +198,6 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun benchmark(model: Model, message: ChatMessageText) {
|
|
||||||
viewModelScope.launch(Dispatchers.Default) {
|
|
||||||
setInProgress(true)
|
|
||||||
|
|
||||||
// Wait for model to be initialized.
|
|
||||||
while (model.instance == null) {
|
|
||||||
delay(100)
|
|
||||||
}
|
|
||||||
val instance = model.instance as LlmModelInstance
|
|
||||||
val prefillTokens = instance.session.sizeInTokens(message.content)
|
|
||||||
|
|
||||||
// Add the message to show benchmark results.
|
|
||||||
val benchmarkLlmResult = ChatMessageBenchmarkLlmResult(
|
|
||||||
orderedStats = STATS,
|
|
||||||
statValues = mutableMapOf(),
|
|
||||||
running = true,
|
|
||||||
latencyMs = -1f,
|
|
||||||
)
|
|
||||||
addMessage(model = model, message = benchmarkLlmResult)
|
|
||||||
|
|
||||||
// Run inference.
|
|
||||||
val result = StringBuilder()
|
|
||||||
var firstRun = true
|
|
||||||
var timeToFirstToken = 0f
|
|
||||||
var firstTokenTs = 0L
|
|
||||||
var decodeTokens = 0
|
|
||||||
var prefillSpeed = 0f
|
|
||||||
var decodeSpeed: Float
|
|
||||||
val start = System.currentTimeMillis()
|
|
||||||
var lastUpdateTime = 0L
|
|
||||||
LlmChatModelHelper.runInference(model = model,
|
|
||||||
input = message.content,
|
|
||||||
resultListener = { partialResult, done ->
|
|
||||||
val curTs = System.currentTimeMillis()
|
|
||||||
|
|
||||||
if (firstRun) {
|
|
||||||
firstTokenTs = System.currentTimeMillis()
|
|
||||||
timeToFirstToken = (firstTokenTs - start) / 1000f
|
|
||||||
prefillSpeed = prefillTokens / timeToFirstToken
|
|
||||||
firstRun = false
|
|
||||||
|
|
||||||
// Update message to show prefill speed.
|
|
||||||
replaceLastMessage(
|
|
||||||
model = model,
|
|
||||||
message = ChatMessageBenchmarkLlmResult(
|
|
||||||
orderedStats = STATS,
|
|
||||||
statValues = mutableMapOf(
|
|
||||||
"prefill_speed" to prefillSpeed,
|
|
||||||
"time_to_first_token" to timeToFirstToken,
|
|
||||||
"latency" to (curTs - start).toFloat() / 1000f,
|
|
||||||
),
|
|
||||||
running = false,
|
|
||||||
latencyMs = -1f,
|
|
||||||
),
|
|
||||||
type = ChatMessageType.BENCHMARK_LLM_RESULT,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
decodeTokens++
|
|
||||||
}
|
|
||||||
result.append(partialResult)
|
|
||||||
|
|
||||||
if (curTs - lastUpdateTime > 500 || done) {
|
|
||||||
decodeSpeed = decodeTokens / ((curTs - firstTokenTs) / 1000f)
|
|
||||||
if (decodeSpeed.isNaN()) {
|
|
||||||
decodeSpeed = 0f
|
|
||||||
}
|
|
||||||
replaceLastMessage(
|
|
||||||
model = model, message = ChatMessageBenchmarkLlmResult(
|
|
||||||
orderedStats = STATS,
|
|
||||||
statValues = mutableMapOf(
|
|
||||||
"prefill_speed" to prefillSpeed,
|
|
||||||
"decode_speed" to decodeSpeed,
|
|
||||||
"time_to_first_token" to timeToFirstToken,
|
|
||||||
"latency" to (curTs - start).toFloat() / 1000f,
|
|
||||||
),
|
|
||||||
running = !done,
|
|
||||||
latencyMs = -1f,
|
|
||||||
), type = ChatMessageType.BENCHMARK_LLM_RESULT
|
|
||||||
)
|
|
||||||
lastUpdateTime = curTs
|
|
||||||
|
|
||||||
if (done) {
|
|
||||||
setInProgress(false)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
cleanUpListener = {
|
|
||||||
setInProgress(false)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fun handleError(
|
fun handleError(
|
||||||
context: Context,
|
context: Context,
|
||||||
model: Model,
|
model: Model,
|
||||||
|
@ -320,15 +232,8 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
|
||||||
)
|
)
|
||||||
|
|
||||||
// Re-generate the response automatically.
|
// Re-generate the response automatically.
|
||||||
generateResponse(model = model, input = triggeredMessage.content, onError = {
|
generateResponse(model = model, input = triggeredMessage.content, onError = {})
|
||||||
handleError(
|
|
||||||
context = context,
|
|
||||||
model = model,
|
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
|
||||||
triggeredMessage = triggeredMessage
|
|
||||||
)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class LlmImageToTextViewModel : LlmChatViewModel(curTask = TASK_LLM_IMAGE_TO_TEXT)
|
class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE)
|
|
@ -73,6 +73,7 @@ fun LlmSingleTurnScreen(
|
||||||
) {
|
) {
|
||||||
val task = viewModel.task
|
val task = viewModel.task
|
||||||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||||
|
val uiState by viewModel.uiState.collectAsState()
|
||||||
val selectedModel = modelManagerUiState.selectedModel
|
val selectedModel = modelManagerUiState.selectedModel
|
||||||
val scope = rememberCoroutineScope()
|
val scope = rememberCoroutineScope()
|
||||||
val context = LocalContext.current
|
val context = LocalContext.current
|
||||||
|
@ -114,6 +115,8 @@ fun LlmSingleTurnScreen(
|
||||||
task = task,
|
task = task,
|
||||||
model = selectedModel,
|
model = selectedModel,
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
|
inProgress = uiState.inProgress,
|
||||||
|
modelPreparing = uiState.preparing,
|
||||||
onConfigChanged = { _, _ -> },
|
onConfigChanged = { _, _ -> },
|
||||||
onBackClicked = { handleNavigateUp() },
|
onBackClicked = { handleNavigateUp() },
|
||||||
onModelSelected = { newSelectedModel ->
|
onModelSelected = { newSelectedModel ->
|
||||||
|
|
|
@ -20,10 +20,9 @@ import android.util.Log
|
||||||
import androidx.lifecycle.ViewModel
|
import androidx.lifecycle.ViewModel
|
||||||
import androidx.lifecycle.viewModelScope
|
import androidx.lifecycle.viewModelScope
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
|
import com.google.aiedge.gallery.data.TASK_LLM_PROMPT_LAB
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
|
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ChatMessageLoading
|
|
||||||
import com.google.aiedge.gallery.ui.common.chat.Stat
|
import com.google.aiedge.gallery.ui.common.chat.Stat
|
||||||
import com.google.aiedge.gallery.ui.common.processLlmResponse
|
import com.google.aiedge.gallery.ui.common.processLlmResponse
|
||||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
|
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
|
||||||
|
@ -44,9 +43,9 @@ data class LlmSingleTurnUiState(
|
||||||
val inProgress: Boolean = false,
|
val inProgress: Boolean = false,
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Indicates whether the model is currently being initialized.
|
* Indicates whether the model is preparing (before outputting any result and after initializing).
|
||||||
*/
|
*/
|
||||||
val initializing: Boolean = false,
|
val preparing: Boolean = false,
|
||||||
|
|
||||||
// model -> <template label -> response>
|
// model -> <template label -> response>
|
||||||
val responsesByModel: Map<String, Map<String, String>>,
|
val responsesByModel: Map<String, Map<String, String>>,
|
||||||
|
@ -65,14 +64,14 @@ private val STATS = listOf(
|
||||||
Stat(id = "latency", label = "Latency", unit = "sec")
|
Stat(id = "latency", label = "Latency", unit = "sec")
|
||||||
)
|
)
|
||||||
|
|
||||||
open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_USECASES) : ViewModel() {
|
open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewModel() {
|
||||||
private val _uiState = MutableStateFlow(createUiState(task = task))
|
private val _uiState = MutableStateFlow(createUiState(task = task))
|
||||||
val uiState = _uiState.asStateFlow()
|
val uiState = _uiState.asStateFlow()
|
||||||
|
|
||||||
fun generateResponse(model: Model, input: String) {
|
fun generateResponse(model: Model, input: String) {
|
||||||
viewModelScope.launch(Dispatchers.Default) {
|
viewModelScope.launch(Dispatchers.Default) {
|
||||||
setInProgress(true)
|
setInProgress(true)
|
||||||
setInitializing(true)
|
setPreparing(true)
|
||||||
|
|
||||||
// Wait for instance to be initialized.
|
// Wait for instance to be initialized.
|
||||||
while (model.instance == null) {
|
while (model.instance == null) {
|
||||||
|
@ -98,7 +97,7 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_USECASES) : ViewMode
|
||||||
val curTs = System.currentTimeMillis()
|
val curTs = System.currentTimeMillis()
|
||||||
|
|
||||||
if (firstRun) {
|
if (firstRun) {
|
||||||
setInitializing(false)
|
setPreparing(false)
|
||||||
firstTokenTs = System.currentTimeMillis()
|
firstTokenTs = System.currentTimeMillis()
|
||||||
timeToFirstToken = (firstTokenTs - start) / 1000f
|
timeToFirstToken = (firstTokenTs - start) / 1000f
|
||||||
prefillSpeed = prefillTokens / timeToFirstToken
|
prefillSpeed = prefillTokens / timeToFirstToken
|
||||||
|
@ -148,7 +147,7 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_USECASES) : ViewMode
|
||||||
},
|
},
|
||||||
singleTurn = true,
|
singleTurn = true,
|
||||||
cleanUpListener = {
|
cleanUpListener = {
|
||||||
setInitializing(false)
|
setPreparing(false)
|
||||||
setInProgress(false)
|
setInProgress(false)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -167,8 +166,8 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_USECASES) : ViewMode
|
||||||
_uiState.update { _uiState.value.copy(inProgress = inProgress) }
|
_uiState.update { _uiState.value.copy(inProgress = inProgress) }
|
||||||
}
|
}
|
||||||
|
|
||||||
fun setInitializing(initializing: Boolean) {
|
fun setPreparing(preparing: Boolean) {
|
||||||
_uiState.update { _uiState.value.copy(initializing = initializing) }
|
_uiState.update { _uiState.value.copy(preparing = preparing) }
|
||||||
}
|
}
|
||||||
|
|
||||||
fun updateResponse(model: Model, promptTemplateType: PromptTemplateType, response: String) {
|
fun updateResponse(model: Model, promptTemplateType: PromptTemplateType, response: String) {
|
||||||
|
|
|
@ -339,7 +339,7 @@ fun PromptTemplatesPanel(
|
||||||
|
|
||||||
val modelInitializing =
|
val modelInitializing =
|
||||||
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING
|
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING
|
||||||
if (inProgress && !modelInitializing) {
|
if (inProgress && !modelInitializing && !uiState.preparing) {
|
||||||
IconButton(
|
IconButton(
|
||||||
onClick = {
|
onClick = {
|
||||||
onStopButtonClicked(model)
|
onStopButtonClicked(model)
|
||||||
|
|
|
@ -57,7 +57,7 @@ import androidx.compose.ui.platform.LocalClipboardManager
|
||||||
import androidx.compose.ui.text.AnnotatedString
|
import androidx.compose.ui.text.AnnotatedString
|
||||||
import androidx.compose.ui.unit.dp
|
import androidx.compose.ui.unit.dp
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
|
import com.google.aiedge.gallery.data.TASK_LLM_PROMPT_LAB
|
||||||
import com.google.aiedge.gallery.ui.common.chat.MarkdownText
|
import com.google.aiedge.gallery.ui.common.chat.MarkdownText
|
||||||
import com.google.aiedge.gallery.ui.common.chat.MessageBodyBenchmarkLlm
|
import com.google.aiedge.gallery.ui.common.chat.MessageBodyBenchmarkLlm
|
||||||
import com.google.aiedge.gallery.ui.common.chat.MessageBodyLoading
|
import com.google.aiedge.gallery.ui.common.chat.MessageBodyLoading
|
||||||
|
@ -76,11 +76,11 @@ fun ResponsePanel(
|
||||||
modelManagerViewModel: ModelManagerViewModel,
|
modelManagerViewModel: ModelManagerViewModel,
|
||||||
modifier: Modifier = Modifier,
|
modifier: Modifier = Modifier,
|
||||||
) {
|
) {
|
||||||
val task = TASK_LLM_USECASES
|
val task = TASK_LLM_PROMPT_LAB
|
||||||
val uiState by viewModel.uiState.collectAsState()
|
val uiState by viewModel.uiState.collectAsState()
|
||||||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||||
val inProgress = uiState.inProgress
|
val inProgress = uiState.inProgress
|
||||||
val initializing = uiState.initializing
|
val initializing = uiState.preparing
|
||||||
val selectedPromptTemplateType = uiState.selectedPromptTemplateType
|
val selectedPromptTemplateType = uiState.selectedPromptTemplateType
|
||||||
val responseScrollState = rememberScrollState()
|
val responseScrollState = rememberScrollState()
|
||||||
var selectedOptionIndex by remember { mutableIntStateOf(0) }
|
var selectedOptionIndex by remember { mutableIntStateOf(0) }
|
||||||
|
|
|
@ -70,7 +70,7 @@ fun ModelList(
|
||||||
) {
|
) {
|
||||||
// This is just to update "models" list when task.updateTrigger is updated so that the UI can
|
// This is just to update "models" list when task.updateTrigger is updated so that the UI can
|
||||||
// be properly updated.
|
// be properly updated.
|
||||||
val models by remember {
|
val models by remember(task) {
|
||||||
derivedStateOf {
|
derivedStateOf {
|
||||||
val trigger = task.updateTrigger.value
|
val trigger = task.updateTrigger.value
|
||||||
if (trigger >= 0) {
|
if (trigger >= 0) {
|
||||||
|
@ -80,7 +80,7 @@ fun ModelList(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
val importedModels by remember {
|
val importedModels by remember(task) {
|
||||||
derivedStateOf {
|
derivedStateOf {
|
||||||
val trigger = task.updateTrigger.value
|
val trigger = task.updateTrigger.value
|
||||||
if (trigger >= 0) {
|
if (trigger >= 0) {
|
||||||
|
|
|
@ -37,15 +37,16 @@ import com.google.aiedge.gallery.data.ModelAllowlist
|
||||||
import com.google.aiedge.gallery.data.ModelDownloadStatus
|
import com.google.aiedge.gallery.data.ModelDownloadStatus
|
||||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||||
import com.google.aiedge.gallery.data.TASKS
|
import com.google.aiedge.gallery.data.TASKS
|
||||||
|
import com.google.aiedge.gallery.data.TASK_LLM_ASK_IMAGE
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT
|
import com.google.aiedge.gallery.data.TASK_LLM_PROMPT_LAB
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
|
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.data.TaskType
|
import com.google.aiedge.gallery.data.TaskType
|
||||||
import com.google.aiedge.gallery.data.ValueType
|
import com.google.aiedge.gallery.data.ValueType
|
||||||
import com.google.aiedge.gallery.data.getModelByName
|
import com.google.aiedge.gallery.data.getModelByName
|
||||||
import com.google.aiedge.gallery.ui.common.AuthConfig
|
import com.google.aiedge.gallery.ui.common.AuthConfig
|
||||||
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
|
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
|
||||||
|
import com.google.aiedge.gallery.ui.common.getJsonResponse
|
||||||
import com.google.aiedge.gallery.ui.common.processTasks
|
import com.google.aiedge.gallery.ui.common.processTasks
|
||||||
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper
|
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper
|
||||||
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper
|
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper
|
||||||
|
@ -59,8 +60,6 @@ import kotlinx.coroutines.flow.asStateFlow
|
||||||
import kotlinx.coroutines.flow.update
|
import kotlinx.coroutines.flow.update
|
||||||
import kotlinx.coroutines.launch
|
import kotlinx.coroutines.launch
|
||||||
import kotlinx.coroutines.withContext
|
import kotlinx.coroutines.withContext
|
||||||
import kotlinx.serialization.ExperimentalSerializationApi
|
|
||||||
import kotlinx.serialization.json.Json
|
|
||||||
import net.openid.appauth.AuthorizationException
|
import net.openid.appauth.AuthorizationException
|
||||||
import net.openid.appauth.AuthorizationRequest
|
import net.openid.appauth.AuthorizationRequest
|
||||||
import net.openid.appauth.AuthorizationResponse
|
import net.openid.appauth.AuthorizationResponse
|
||||||
|
@ -73,7 +72,7 @@ import java.net.URL
|
||||||
private const val TAG = "AGModelManagerViewModel"
|
private const val TAG = "AGModelManagerViewModel"
|
||||||
private const val TEXT_INPUT_HISTORY_MAX_SIZE = 50
|
private const val TEXT_INPUT_HISTORY_MAX_SIZE = 50
|
||||||
private const val MODEL_ALLOWLIST_URL =
|
private const val MODEL_ALLOWLIST_URL =
|
||||||
"https://raw.githubusercontent.com/jinjingforever/kokoro-codelab-jingjin/refs/heads/main/model_allowlist.json"
|
"https://raw.githubusercontent.com/google-ai-edge/gallery/refs/heads/main/model_allowlist.json"
|
||||||
|
|
||||||
data class ModelInitializationStatus(
|
data class ModelInitializationStatus(
|
||||||
val status: ModelInitializationStatusType, var error: String = ""
|
val status: ModelInitializationStatusType, var error: String = ""
|
||||||
|
@ -116,11 +115,6 @@ data class ModelManagerUiState(
|
||||||
*/
|
*/
|
||||||
val modelInitializationStatus: Map<String, ModelInitializationStatus>,
|
val modelInitializationStatus: Map<String, ModelInitializationStatus>,
|
||||||
|
|
||||||
/**
|
|
||||||
* Whether Hugging Face models from the given community are currently being loaded.
|
|
||||||
*/
|
|
||||||
val loadingHfModels: Boolean = false,
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Whether the app is loading and processing the model allowlist.
|
* Whether the app is loading and processing the model allowlist.
|
||||||
*/
|
*/
|
||||||
|
@ -196,8 +190,9 @@ open class ModelManagerViewModel(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun cancelDownloadModel(model: Model) {
|
fun cancelDownloadModel(task: Task, model: Model) {
|
||||||
downloadRepository.cancelDownloadModel(model)
|
downloadRepository.cancelDownloadModel(model)
|
||||||
|
deleteModel(task = task, model = model)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun deleteModel(task: Task, model: Model) {
|
fun deleteModel(task: Task, model: Model) {
|
||||||
|
@ -311,13 +306,13 @@ open class ModelManagerViewModel(
|
||||||
onDone = onDone,
|
onDone = onDone,
|
||||||
)
|
)
|
||||||
|
|
||||||
TaskType.LLM_USECASES -> LlmChatModelHelper.initialize(
|
TaskType.LLM_PROMPT_LAB -> LlmChatModelHelper.initialize(
|
||||||
context = context,
|
context = context,
|
||||||
model = model,
|
model = model,
|
||||||
onDone = onDone,
|
onDone = onDone,
|
||||||
)
|
)
|
||||||
|
|
||||||
TaskType.LLM_IMAGE_TO_TEXT -> LlmChatModelHelper.initialize(
|
TaskType.LLM_ASK_IMAGE -> LlmChatModelHelper.initialize(
|
||||||
context = context,
|
context = context,
|
||||||
model = model,
|
model = model,
|
||||||
onDone = onDone,
|
onDone = onDone,
|
||||||
|
@ -341,8 +336,8 @@ open class ModelManagerViewModel(
|
||||||
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model)
|
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model)
|
||||||
TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.cleanUp(model = model)
|
TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.cleanUp(model = model)
|
||||||
TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model)
|
TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model)
|
||||||
TaskType.LLM_USECASES -> LlmChatModelHelper.cleanUp(model = model)
|
TaskType.LLM_PROMPT_LAB -> LlmChatModelHelper.cleanUp(model = model)
|
||||||
TaskType.LLM_IMAGE_TO_TEXT -> LlmChatModelHelper.cleanUp(model = model)
|
TaskType.LLM_ASK_IMAGE -> LlmChatModelHelper.cleanUp(model = model)
|
||||||
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.cleanUp(model = model)
|
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.cleanUp(model = model)
|
||||||
TaskType.TEST_TASK_1 -> {}
|
TaskType.TEST_TASK_1 -> {}
|
||||||
TaskType.TEST_TASK_2 -> {}
|
TaskType.TEST_TASK_2 -> {}
|
||||||
|
@ -446,14 +441,14 @@ open class ModelManagerViewModel(
|
||||||
// Create model.
|
// Create model.
|
||||||
val model = createModelFromImportedModelInfo(info = info)
|
val model = createModelFromImportedModelInfo(info = info)
|
||||||
|
|
||||||
for (task in listOf(TASK_LLM_CHAT, TASK_LLM_USECASES, TASK_LLM_IMAGE_TO_TEXT)) {
|
for (task in listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)) {
|
||||||
// Remove duplicated imported model if existed.
|
// Remove duplicated imported model if existed.
|
||||||
val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported }
|
val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported }
|
||||||
if (modelIndex >= 0) {
|
if (modelIndex >= 0) {
|
||||||
Log.d(TAG, "duplicated imported model found in task. Removing it first")
|
Log.d(TAG, "duplicated imported model found in task. Removing it first")
|
||||||
task.models.removeAt(modelIndex)
|
task.models.removeAt(modelIndex)
|
||||||
}
|
}
|
||||||
if (task == TASK_LLM_IMAGE_TO_TEXT && model.llmSupportImage || task != TASK_LLM_IMAGE_TO_TEXT) {
|
if (task == TASK_LLM_ASK_IMAGE && model.llmSupportImage || task != TASK_LLM_ASK_IMAGE) {
|
||||||
task.models.add(model)
|
task.models.add(model)
|
||||||
}
|
}
|
||||||
task.updateTrigger.value = System.currentTimeMillis()
|
task.updateTrigger.value = System.currentTimeMillis()
|
||||||
|
@ -502,7 +497,7 @@ open class ModelManagerViewModel(
|
||||||
|
|
||||||
// Check expiration (with 5-minute buffer).
|
// Check expiration (with 5-minute buffer).
|
||||||
val curTs = System.currentTimeMillis()
|
val curTs = System.currentTimeMillis()
|
||||||
val expirationTs = tokenData.expiresAtSeconds - 5 * 60
|
val expirationTs = tokenData.expiresAtMs - 5 * 60
|
||||||
Log.d(
|
Log.d(
|
||||||
TAG,
|
TAG,
|
||||||
"Checking whether token has expired or not. Current ts: $curTs, expires at: $expirationTs"
|
"Checking whether token has expired or not. Current ts: $curTs, expires at: $expirationTs"
|
||||||
|
@ -562,7 +557,7 @@ open class ModelManagerViewModel(
|
||||||
} else {
|
} else {
|
||||||
// Token exchange successful. Store the tokens securely
|
// Token exchange successful. Store the tokens securely
|
||||||
Log.d(TAG, "Token exchange successful. Storing tokens...")
|
Log.d(TAG, "Token exchange successful. Storing tokens...")
|
||||||
dataStoreRepository.saveAccessTokenData(
|
saveAccessToken(
|
||||||
accessToken = tokenResponse.accessToken!!,
|
accessToken = tokenResponse.accessToken!!,
|
||||||
refreshToken = tokenResponse.refreshToken!!,
|
refreshToken = tokenResponse.refreshToken!!,
|
||||||
expiresAt = tokenResponse.accessTokenExpirationTime!!
|
expiresAt = tokenResponse.accessTokenExpirationTime!!
|
||||||
|
@ -606,6 +601,18 @@ open class ModelManagerViewModel(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun saveAccessToken(accessToken: String, refreshToken: String, expiresAt: Long) {
|
||||||
|
dataStoreRepository.saveAccessTokenData(
|
||||||
|
accessToken = accessToken,
|
||||||
|
refreshToken = refreshToken,
|
||||||
|
expiresAt = expiresAt,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun clearAccessToken() {
|
||||||
|
dataStoreRepository.clearAccessTokenData()
|
||||||
|
}
|
||||||
|
|
||||||
private fun processPendingDownloads() {
|
private fun processPendingDownloads() {
|
||||||
Log.d(TAG, "In-progress worker infos: $inProgressWorkInfos")
|
Log.d(TAG, "In-progress worker infos: $inProgressWorkInfos")
|
||||||
|
|
||||||
|
@ -673,11 +680,11 @@ open class ModelManagerViewModel(
|
||||||
if (allowedModel.taskTypes.contains(TASK_LLM_CHAT.type.id)) {
|
if (allowedModel.taskTypes.contains(TASK_LLM_CHAT.type.id)) {
|
||||||
TASK_LLM_CHAT.models.add(model)
|
TASK_LLM_CHAT.models.add(model)
|
||||||
}
|
}
|
||||||
if (allowedModel.taskTypes.contains(TASK_LLM_USECASES.type.id)) {
|
if (allowedModel.taskTypes.contains(TASK_LLM_PROMPT_LAB.type.id)) {
|
||||||
TASK_LLM_USECASES.models.add(model)
|
TASK_LLM_PROMPT_LAB.models.add(model)
|
||||||
}
|
}
|
||||||
if (allowedModel.taskTypes.contains(TASK_LLM_IMAGE_TO_TEXT.type.id)) {
|
if (allowedModel.taskTypes.contains(TASK_LLM_ASK_IMAGE.type.id)) {
|
||||||
TASK_LLM_IMAGE_TO_TEXT.models.add(model)
|
TASK_LLM_ASK_IMAGE.models.add(model)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -732,9 +739,9 @@ open class ModelManagerViewModel(
|
||||||
|
|
||||||
// Add to task.
|
// Add to task.
|
||||||
TASK_LLM_CHAT.models.add(model)
|
TASK_LLM_CHAT.models.add(model)
|
||||||
TASK_LLM_USECASES.models.add(model)
|
TASK_LLM_PROMPT_LAB.models.add(model)
|
||||||
if (model.llmSupportImage) {
|
if (model.llmSupportImage) {
|
||||||
TASK_LLM_IMAGE_TO_TEXT.models.add(model)
|
TASK_LLM_ASK_IMAGE.models.add(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update status.
|
// Update status.
|
||||||
|
@ -827,38 +834,6 @@ open class ModelManagerViewModel(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@OptIn(ExperimentalSerializationApi::class)
|
|
||||||
private inline fun <reified T> getJsonResponse(url: String): T? {
|
|
||||||
try {
|
|
||||||
val connection = URL(url).openConnection() as HttpURLConnection
|
|
||||||
connection.requestMethod = "GET"
|
|
||||||
connection.connect()
|
|
||||||
|
|
||||||
val responseCode = connection.responseCode
|
|
||||||
if (responseCode == HttpURLConnection.HTTP_OK) {
|
|
||||||
val inputStream = connection.inputStream
|
|
||||||
val response = inputStream.bufferedReader().use { it.readText() }
|
|
||||||
|
|
||||||
// Parse JSON using kotlinx.serialization
|
|
||||||
val json = Json {
|
|
||||||
// Handle potential extra fields
|
|
||||||
ignoreUnknownKeys = true
|
|
||||||
allowComments = true
|
|
||||||
allowTrailingComma = true
|
|
||||||
}
|
|
||||||
val jsonObj = json.decodeFromString<T>(response)
|
|
||||||
return jsonObj
|
|
||||||
} else {
|
|
||||||
Log.e(TAG, "HTTP error: $responseCode")
|
|
||||||
}
|
|
||||||
} catch (e: Exception) {
|
|
||||||
Log.e(TAG, "Error when getting json response: ${e.message}")
|
|
||||||
e.printStackTrace()
|
|
||||||
}
|
|
||||||
|
|
||||||
return null
|
|
||||||
}
|
|
||||||
|
|
||||||
private fun isFileInExternalFilesDir(fileName: String): Boolean {
|
private fun isFileInExternalFilesDir(fileName: String): Boolean {
|
||||||
if (externalFilesDir != null) {
|
if (externalFilesDir != null) {
|
||||||
val file = File(externalFilesDir, fileName)
|
val file = File(externalFilesDir, fileName)
|
||||||
|
|
|
@ -46,8 +46,8 @@ import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION
|
import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION
|
||||||
import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION
|
import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_IMAGE_TO_TEXT
|
import com.google.aiedge.gallery.data.TASK_LLM_ASK_IMAGE
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_USECASES
|
import com.google.aiedge.gallery.data.TASK_LLM_PROMPT_LAB
|
||||||
import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION
|
import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.data.TaskType
|
import com.google.aiedge.gallery.data.TaskType
|
||||||
|
@ -60,8 +60,8 @@ import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationDestination
|
||||||
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationScreen
|
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationScreen
|
||||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatDestination
|
import com.google.aiedge.gallery.ui.llmchat.LlmChatDestination
|
||||||
import com.google.aiedge.gallery.ui.llmchat.LlmChatScreen
|
import com.google.aiedge.gallery.ui.llmchat.LlmChatScreen
|
||||||
import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextDestination
|
import com.google.aiedge.gallery.ui.llmchat.LlmAskImageDestination
|
||||||
import com.google.aiedge.gallery.ui.llmchat.LlmImageToTextScreen
|
import com.google.aiedge.gallery.ui.llmchat.LlmAskImageScreen
|
||||||
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnDestination
|
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnDestination
|
||||||
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnScreen
|
import com.google.aiedge.gallery.ui.llmsingleturn.LlmSingleTurnScreen
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManager
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManager
|
||||||
|
@ -233,7 +233,7 @@ fun GalleryNavHost(
|
||||||
enterTransition = { slideEnter() },
|
enterTransition = { slideEnter() },
|
||||||
exitTransition = { slideExit() },
|
exitTransition = { slideExit() },
|
||||||
) {
|
) {
|
||||||
getModelFromNavigationParam(it, TASK_LLM_USECASES)?.let { defaultModel ->
|
getModelFromNavigationParam(it, TASK_LLM_PROMPT_LAB)?.let { defaultModel ->
|
||||||
modelManagerViewModel.selectModel(defaultModel)
|
modelManagerViewModel.selectModel(defaultModel)
|
||||||
|
|
||||||
LlmSingleTurnScreen(
|
LlmSingleTurnScreen(
|
||||||
|
@ -245,15 +245,15 @@ fun GalleryNavHost(
|
||||||
|
|
||||||
// LLM image to text.
|
// LLM image to text.
|
||||||
composable(
|
composable(
|
||||||
route = "${LlmImageToTextDestination.route}/{modelName}",
|
route = "${LlmAskImageDestination.route}/{modelName}",
|
||||||
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
|
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
|
||||||
enterTransition = { slideEnter() },
|
enterTransition = { slideEnter() },
|
||||||
exitTransition = { slideExit() },
|
exitTransition = { slideExit() },
|
||||||
) {
|
) {
|
||||||
getModelFromNavigationParam(it, TASK_LLM_IMAGE_TO_TEXT)?.let { defaultModel ->
|
getModelFromNavigationParam(it, TASK_LLM_ASK_IMAGE)?.let { defaultModel ->
|
||||||
modelManagerViewModel.selectModel(defaultModel)
|
modelManagerViewModel.selectModel(defaultModel)
|
||||||
|
|
||||||
LlmImageToTextScreen(
|
LlmAskImageScreen(
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
navigateUp = { navController.navigateUp() },
|
navigateUp = { navController.navigateUp() },
|
||||||
)
|
)
|
||||||
|
@ -287,8 +287,8 @@ fun navigateToTaskScreen(
|
||||||
TaskType.TEXT_CLASSIFICATION -> navController.navigate("${TextClassificationDestination.route}/${modelName}")
|
TaskType.TEXT_CLASSIFICATION -> navController.navigate("${TextClassificationDestination.route}/${modelName}")
|
||||||
TaskType.IMAGE_CLASSIFICATION -> navController.navigate("${ImageClassificationDestination.route}/${modelName}")
|
TaskType.IMAGE_CLASSIFICATION -> navController.navigate("${ImageClassificationDestination.route}/${modelName}")
|
||||||
TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
|
TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
|
||||||
TaskType.LLM_IMAGE_TO_TEXT -> navController.navigate("${LlmImageToTextDestination.route}/${modelName}")
|
TaskType.LLM_ASK_IMAGE -> navController.navigate("${LlmAskImageDestination.route}/${modelName}")
|
||||||
TaskType.LLM_USECASES -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
|
TaskType.LLM_PROMPT_LAB -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
|
||||||
TaskType.IMAGE_GENERATION -> navController.navigate("${ImageGenerationDestination.route}/${modelName}")
|
TaskType.IMAGE_GENERATION -> navController.navigate("${ImageGenerationDestination.route}/${modelName}")
|
||||||
TaskType.TEST_TASK_1 -> {}
|
TaskType.TEST_TASK_1 -> {}
|
||||||
TaskType.TEST_TASK_2 -> {}
|
TaskType.TEST_TASK_2 -> {}
|
||||||
|
|
|
@ -42,6 +42,9 @@ class PreviewDataStoreRepository : DataStoreRepository {
|
||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun clearAccessTokenData() {
|
||||||
|
}
|
||||||
|
|
||||||
override fun saveImportedModels(importedModels: List<ImportedModelInfo>) {
|
override fun saveImportedModels(importedModels: List<ImportedModelInfo>) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue