No public description

PiperOrigin-RevId: 770859221
This commit is contained in:
Google AI Edge Gallery 2025-06-12 17:32:43 -07:00 committed by Jing Jin
parent 9b45d4904c
commit f4006b35b0
126 changed files with 5557 additions and 6977 deletions

16
Android/.gitignore vendored
View file

@ -1,3 +1,19 @@
# @license
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Gradle files # Gradle files
.gradle/ .gradle/
build/ build/

View file

@ -1 +1 @@
# AI Edge Gallery (Android) # Google AI Edge Gallery (Android)

View file

@ -1,4 +1,20 @@
# @license
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
*.iml *.iml
.gradle .gradle
/local.properties /local.properties
/.idea/caches /.idea/caches

View file

@ -1,2 +1,18 @@
# @license
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
/build /build
/release /release

View file

@ -19,6 +19,7 @@ plugins {
alias(libs.plugins.kotlin.android) alias(libs.plugins.kotlin.android)
alias(libs.plugins.kotlin.compose) alias(libs.plugins.kotlin.compose)
alias(libs.plugins.kotlin.serialization) alias(libs.plugins.kotlin.serialization)
alias(libs.plugins.protobuf)
} }
android { android {
@ -26,7 +27,6 @@ android {
compileSdk = 35 compileSdk = 35
defaultConfig { defaultConfig {
// Don't change to com.google.ai.edge.gallery yet.
applicationId = "com.google.aiedge.gallery" applicationId = "com.google.aiedge.gallery"
minSdk = 26 minSdk = 26
targetSdk = 35 targetSdk = 35
@ -42,10 +42,7 @@ android {
buildTypes { buildTypes {
release { release {
isMinifyEnabled = false isMinifyEnabled = false
proguardFiles( proguardFiles(getDefaultProguardFile("proguard-android-optimize.txt"), "proguard-rules.pro")
getDefaultProguardFile("proguard-android-optimize.txt"),
"proguard-rules.pro"
)
signingConfig = signingConfigs.getByName("debug") signingConfig = signingConfigs.getByName("debug")
} }
} }
@ -76,7 +73,7 @@ dependencies {
implementation(libs.kotlinx.serialization.json) implementation(libs.kotlinx.serialization.json)
implementation(libs.material.icon.extended) implementation(libs.material.icon.extended)
implementation(libs.androidx.work.runtime) implementation(libs.androidx.work.runtime)
implementation(libs.androidx.datastore.preferences) implementation(libs.androidx.datastore)
implementation(libs.com.google.code.gson) implementation(libs.com.google.code.gson)
implementation(libs.androidx.lifecycle.process) implementation(libs.androidx.lifecycle.process)
implementation(libs.mediapipe.tasks.text) implementation(libs.mediapipe.tasks.text)
@ -93,6 +90,7 @@ dependencies {
implementation(libs.camerax.view) implementation(libs.camerax.view)
implementation(libs.openid.appauth) implementation(libs.openid.appauth)
implementation(libs.androidx.splashscreen) implementation(libs.androidx.splashscreen)
implementation(libs.protobuf.javalite)
testImplementation(libs.junit) testImplementation(libs.junit)
androidTestImplementation(libs.androidx.junit) androidTestImplementation(libs.androidx.junit)
androidTestImplementation(libs.androidx.espresso.core) androidTestImplementation(libs.androidx.espresso.core)
@ -100,4 +98,9 @@ dependencies {
androidTestImplementation(libs.androidx.ui.test.junit4) androidTestImplementation(libs.androidx.ui.test.junit4)
debugImplementation(libs.androidx.ui.tooling) debugImplementation(libs.androidx.ui.tooling)
debugImplementation(libs.androidx.ui.test.manifest) debugImplementation(libs.androidx.ui.test.manifest)
} }
protobuf {
protoc { artifact = "com.google.protobuf:protoc:4.26.1" }
generateProtoTasks { all().forEach { it.plugins { create("java") { option("lite") } } } }
}

View file

@ -1,21 +0,0 @@
# Add project specific ProGuard rules here.
# You can control the set of applied configuration files using the
# proguardFiles setting in build.gradle.
#
# For more details, see
# http://developer.android.com/guide/developing/tools/proguard.html
# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
# public *;
#}
# Uncomment this to preserve the line number information for
# debugging stack traces.
#-keepattributes SourceFile,LineNumberTable
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile

View file

@ -16,13 +16,20 @@
--> -->
<manifest xmlns:android="http://schemas.android.com/apk/res/android" <manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.ai.edge.gallery"
xmlns:tools="http://schemas.android.com/tools"> xmlns:tools="http://schemas.android.com/tools">
<uses-sdk
android:minSdkVersion="26"
android:compileSdkVersion ="35"
android:targetSdkVersion="35" />
<uses-permission android:name="android.permission.CAMERA" />
<uses-permission android:name="android.permission.FOREGROUND_SERVICE"/> <uses-permission android:name="android.permission.FOREGROUND_SERVICE"/>
<uses-permission android:name="android.permission.FOREGROUND_SERVICE_DATA_SYNC"/> <uses-permission android:name="android.permission.FOREGROUND_SERVICE_DATA_SYNC"/>
<uses-permission android:name="android.permission.INTERNET" /> <uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.POST_NOTIFICATIONS" /> <uses-permission android:name="android.permission.POST_NOTIFICATIONS" />
<uses-permission android:name="android.permission.CAMERA" /> <uses-permission android:name="android.permission.WAKE_LOCK"/>
<uses-feature <uses-feature
android:name="android.hardware.camera" android:name="android.hardware.camera"
@ -63,6 +70,7 @@
</intent-filter> </intent-filter>
</activity> </activity>
<!-- For LLM inference engine -->
<uses-native-library <uses-native-library
android:name="libOpenCL.so" android:name="libOpenCL.so"
android:required="false" /> android:required="false" />

View file

@ -14,167 +14,15 @@
* limitations under the License. * limitations under the License.
*/ */
@file:OptIn(ExperimentalMaterial3Api::class)
package com.google.ai.edge.gallery package com.google.ai.edge.gallery
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.text.BasicText
import androidx.compose.foundation.text.TextAutoSize
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.ArrowBack
import androidx.compose.material.icons.rounded.Refresh
import androidx.compose.material.icons.rounded.Settings
import androidx.compose.material3.CenterAlignedTopAppBar
import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.material3.TopAppBarScrollBehavior
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import androidx.navigation.NavHostController import androidx.navigation.NavHostController
import androidx.navigation.compose.rememberNavController import androidx.navigation.compose.rememberNavController
import com.google.ai.edge.gallery.data.AppBarAction
import com.google.ai.edge.gallery.data.AppBarActionType
import com.google.ai.edge.gallery.ui.navigation.GalleryNavHost import com.google.ai.edge.gallery.ui.navigation.GalleryNavHost
/** /** Top level composable representing the main screen of the application. */
* Top level composable representing the main screen of the application.
*/
@Composable @Composable
fun GalleryApp(navController: NavHostController = rememberNavController()) { fun GalleryApp(navController: NavHostController = rememberNavController()) {
GalleryNavHost(navController = navController) GalleryNavHost(navController = navController)
} }
/**
* The top app bar.
*/
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun GalleryTopAppBar(
title: String,
modifier: Modifier = Modifier,
leftAction: AppBarAction? = null,
rightAction: AppBarAction? = null,
scrollBehavior: TopAppBarScrollBehavior? = null,
subtitle: String = "",
) {
val titleColor = MaterialTheme.colorScheme.primary
CenterAlignedTopAppBar(
title = {
Column(horizontalAlignment = Alignment.CenterHorizontally) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(12.dp)
) {
if (title == stringResource(R.string.app_name)) {
Icon(
painterResource(R.drawable.logo),
modifier = Modifier.size(20.dp),
contentDescription = "",
tint = Color.Unspecified,
)
}
BasicText(
text = title,
maxLines = 1,
color = { titleColor },
style = MaterialTheme.typography.titleLarge.copy(fontWeight = FontWeight.SemiBold),
autoSize = TextAutoSize.StepBased(
minFontSize = 14.sp,
maxFontSize = 22.sp,
stepSize = 1.sp
)
)
}
if (subtitle.isNotEmpty()) {
Text(
subtitle,
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.secondary
)
}
}
},
modifier = modifier,
scrollBehavior = scrollBehavior,
// The button at the left.
navigationIcon = {
when (leftAction?.actionType) {
AppBarActionType.NAVIGATE_UP -> {
IconButton(onClick = leftAction.actionFn) {
Icon(
imageVector = Icons.AutoMirrored.Rounded.ArrowBack,
contentDescription = "",
)
}
}
AppBarActionType.REFRESH_MODELS -> {
IconButton(onClick = leftAction.actionFn) {
Icon(
imageVector = Icons.Rounded.Refresh,
contentDescription = "",
tint = MaterialTheme.colorScheme.secondary
)
}
}
AppBarActionType.REFRESHING_MODELS -> {
CircularProgressIndicator(
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
strokeWidth = 3.dp,
modifier = Modifier
.padding(start = 16.dp)
.size(20.dp)
)
}
else -> {}
}
},
// The "action" component at the right.
actions = {
when (rightAction?.actionType) {
// Click an icon to open "app setting".
AppBarActionType.APP_SETTING -> {
IconButton(onClick = rightAction.actionFn) {
Icon(
imageVector = Icons.Rounded.Settings,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary
)
}
}
AppBarActionType.MODEL_SELECTOR -> {
Text("ms")
}
// Click a button to navigate up.
AppBarActionType.NAVIGATE_UP -> {
TextButton(onClick = rightAction.actionFn) {
Text("Done")
}
}
else -> {}
}
}
)
}

View file

@ -0,0 +1,157 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
@file:OptIn(ExperimentalMaterial3Api::class)
package com.google.ai.edge.gallery
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.text.BasicText
import androidx.compose.foundation.text.TextAutoSize
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.ArrowBack
import androidx.compose.material.icons.rounded.Refresh
import androidx.compose.material.icons.rounded.Settings
import androidx.compose.material3.CenterAlignedTopAppBar
import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.material3.TopAppBarScrollBehavior
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import com.google.ai.edge.gallery.data.AppBarAction
import com.google.ai.edge.gallery.data.AppBarActionType
/** The top app bar. */
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun GalleryTopAppBar(
title: String,
modifier: Modifier = Modifier,
leftAction: AppBarAction? = null,
rightAction: AppBarAction? = null,
scrollBehavior: TopAppBarScrollBehavior? = null,
subtitle: String = "",
) {
val titleColor = MaterialTheme.colorScheme.primary
CenterAlignedTopAppBar(
title = {
Column(horizontalAlignment = Alignment.CenterHorizontally) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(12.dp),
) {
if (title == stringResource(R.string.app_name)) {
Icon(
painterResource(R.drawable.logo),
modifier = Modifier.size(20.dp),
contentDescription = "",
tint = Color.Unspecified,
)
}
BasicText(
text = title,
maxLines = 1,
color = { titleColor },
style = MaterialTheme.typography.titleLarge.copy(fontWeight = FontWeight.SemiBold),
autoSize =
TextAutoSize.StepBased(minFontSize = 14.sp, maxFontSize = 22.sp, stepSize = 1.sp),
)
}
if (subtitle.isNotEmpty()) {
Text(
subtitle,
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.secondary,
)
}
}
},
modifier = modifier,
scrollBehavior = scrollBehavior,
// The button at the left.
navigationIcon = {
when (leftAction?.actionType) {
AppBarActionType.NAVIGATE_UP -> {
IconButton(onClick = leftAction.actionFn) {
Icon(imageVector = Icons.AutoMirrored.Rounded.ArrowBack, contentDescription = "")
}
}
AppBarActionType.REFRESH_MODELS -> {
IconButton(onClick = leftAction.actionFn) {
Icon(
imageVector = Icons.Rounded.Refresh,
contentDescription = "",
tint = MaterialTheme.colorScheme.secondary,
)
}
}
AppBarActionType.REFRESHING_MODELS -> {
CircularProgressIndicator(
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
strokeWidth = 3.dp,
modifier = Modifier.padding(start = 16.dp).size(20.dp),
)
}
else -> {}
}
},
// The "action" component at the right.
actions = {
when (rightAction?.actionType) {
// Click an icon to open "app setting".
AppBarActionType.APP_SETTING -> {
IconButton(onClick = rightAction.actionFn) {
Icon(
imageVector = Icons.Rounded.Settings,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary,
)
}
}
AppBarActionType.MODEL_SELECTOR -> {
Text("ms")
}
// Click a button to navigate up.
AppBarActionType.NAVIGATE_UP -> {
TextButton(onClick = rightAction.actionFn) { Text("Done") }
}
else -> {}
}
},
)
}

View file

@ -18,15 +18,35 @@ package com.google.ai.edge.gallery
import android.app.Application import android.app.Application
import android.content.Context import android.content.Context
import androidx.datastore.core.CorruptionException
import androidx.datastore.core.DataStore import androidx.datastore.core.DataStore
import androidx.datastore.preferences.core.Preferences import androidx.datastore.core.Serializer
import androidx.datastore.preferences.preferencesDataStore import androidx.datastore.dataStore
import com.google.ai.edge.gallery.common.writeLaunchInfo
import com.google.ai.edge.gallery.data.AppContainer import com.google.ai.edge.gallery.data.AppContainer
import com.google.ai.edge.gallery.data.DefaultAppContainer import com.google.ai.edge.gallery.data.DefaultAppContainer
import com.google.ai.edge.gallery.ui.common.writeLaunchInfo import com.google.ai.edge.gallery.proto.Settings
import com.google.ai.edge.gallery.ui.theme.ThemeSettings import com.google.ai.edge.gallery.ui.theme.ThemeSettings
import com.google.protobuf.InvalidProtocolBufferException
import java.io.InputStream
import java.io.OutputStream
private val Context.dataStore: DataStore<Preferences> by preferencesDataStore(name = "app_gallery_preferences") object SettingsSerializer : Serializer<Settings> {
override val defaultValue: Settings = Settings.getDefaultInstance()
override suspend fun readFrom(input: InputStream): Settings {
try {
return Settings.parseFrom(input)
} catch (exception: InvalidProtocolBufferException) {
throw CorruptionException("Cannot read proto.", exception)
}
}
override suspend fun writeTo(t: Settings, output: OutputStream) = t.writeTo(output)
}
private val Context.dataStore: DataStore<Settings> by
dataStore(fileName = "settings.pb", serializer = SettingsSerializer)
class GalleryApplication : Application() { class GalleryApplication : Application() {
/** AppContainer instance used by the rest of classes to obtain dependencies */ /** AppContainer instance used by the rest of classes to obtain dependencies */
@ -35,11 +55,10 @@ class GalleryApplication : Application() {
override fun onCreate() { override fun onCreate() {
super.onCreate() super.onCreate()
writeLaunchInfo(context = this) writeLaunchInfo(context = this)
container = DefaultAppContainer(this, dataStore) container = DefaultAppContainer(this, dataStore)
// Load theme. // Load saved theme.
ThemeSettings.themeOverride.value = container.dataStoreRepository.readThemeOverride() ThemeSettings.themeOverride.value = container.dataStoreRepository.readTheme()
} }
} }

View file

@ -16,29 +16,16 @@
package com.google.ai.edge.gallery package com.google.ai.edge.gallery
import androidx.lifecycle.DefaultLifecycleObserver
import androidx.lifecycle.LifecycleOwner
import androidx.lifecycle.ProcessLifecycleOwner
interface AppLifecycleProvider { interface AppLifecycleProvider {
val isAppInForeground: Boolean var isAppInForeground: Boolean
} }
class GalleryLifecycleProvider : AppLifecycleProvider, DefaultLifecycleObserver { class GalleryLifecycleProvider : AppLifecycleProvider {
private var _isAppInForeground = false private var _isAppInForeground = false
init { override var isAppInForeground: Boolean
ProcessLifecycleOwner.get().lifecycle.addObserver(this)
}
override val isAppInForeground: Boolean
get() = _isAppInForeground get() = _isAppInForeground
set(value) {
override fun onResume(owner: LifecycleOwner) { _isAppInForeground = value
_isAppInForeground = true }
}
override fun onPause(owner: LifecycleOwner) {
_isAppInForeground = false
}
} }

View file

@ -1,12 +0,0 @@
package com.google.ai.edge.gallery
import android.app.Service
import android.content.Intent
import android.os.IBinder
// TODO(jingjin): implement foreground service.
class GalleryService : Service() {
override fun onBind(p0: Intent?): IBinder? {
return null
}
}

View file

@ -32,14 +32,6 @@ class MainActivity : ComponentActivity() {
super.onCreate(savedInstanceState) super.onCreate(savedInstanceState)
enableEdgeToEdge() enableEdgeToEdge()
setContent { setContent { GalleryTheme { Surface(modifier = Modifier.fillMaxSize()) { GalleryApp() } } }
GalleryTheme {
Surface(
modifier = Modifier.fillMaxSize()
) {
GalleryApp()
}
}
}
} }
} }

View file

@ -0,0 +1,27 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.common
import androidx.compose.ui.graphics.Color
interface LatencyProvider {
val latencyMs: Float
}
data class Classification(val label: String, val score: Float, val color: Color)
data class JsonObjAndTextContent<T>(val jsonObj: T, val textContent: String)

View file

@ -0,0 +1,114 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.common
import android.content.Context
import android.util.Log
import com.google.gson.Gson
import com.google.gson.reflect.TypeToken
import java.io.File
import java.net.HttpURLConnection
import java.net.URL
data class LaunchInfo(val ts: Long)
private const val TAG = "AGUtils"
private const val LAUNCH_INFO_FILE_NAME = "launch_info"
private const val START_THINKING = "***Thinking...***"
private const val DONE_THINKING = "***Done thinking***"
fun readLaunchInfo(context: Context): LaunchInfo? {
try {
val gson = Gson()
val type = object : TypeToken<LaunchInfo>() {}.type
val file = File(context.getExternalFilesDir(null), LAUNCH_INFO_FILE_NAME)
val content = file.readText()
return gson.fromJson(content, type)
} catch (e: Exception) {
Log.e(TAG, "Failed to read launch info", e)
return null
}
}
fun cleanUpMediapipeTaskErrorMessage(message: String): String {
val index = message.indexOf("=== Source Location Trace")
if (index >= 0) {
return message.substring(0, index)
}
return message
}
fun processLlmResponse(response: String): String {
// Add "thinking" and "done thinking" around the thinking content.
var newContent =
response.replace("<think>", "$START_THINKING\n").replace("</think>", "\n$DONE_THINKING")
// Remove empty thinking content.
val endThinkingIndex = newContent.indexOf(DONE_THINKING)
if (endThinkingIndex >= 0) {
val thinkingContent =
newContent
.substring(0, endThinkingIndex + DONE_THINKING.length)
.replace(START_THINKING, "")
.replace(DONE_THINKING, "")
if (thinkingContent.isBlank()) {
newContent = newContent.substring(endThinkingIndex + DONE_THINKING.length)
}
}
newContent = newContent.replace("\\n", "\n")
return newContent
}
fun writeLaunchInfo(context: Context) {
try {
val gson = Gson()
val launchInfo = LaunchInfo(ts = System.currentTimeMillis())
val jsonString = gson.toJson(launchInfo)
val file = File(context.getExternalFilesDir(null), LAUNCH_INFO_FILE_NAME)
file.writeText(jsonString)
} catch (e: Exception) {
Log.e(TAG, "Failed to write launch info", e)
}
}
inline fun <reified T> getJsonResponse(url: String): JsonObjAndTextContent<T>? {
try {
val connection = URL(url).openConnection() as HttpURLConnection
connection.requestMethod = "GET"
connection.connect()
val responseCode = connection.responseCode
if (responseCode == HttpURLConnection.HTTP_OK) {
val inputStream = connection.inputStream
val response = inputStream.bufferedReader().use { it.readText() }
val gson = Gson()
val type = object : TypeToken<T>() {}.type
val jsonObj = gson.fromJson<T>(response, type)
return JsonObjAndTextContent(jsonObj = jsonObj, textContent = response)
} else {
Log.e("AGUtils", "HTTP error: $responseCode")
}
} catch (e: Exception) {
Log.e("AGUtils", "Error when getting json response: ${e.message}")
e.printStackTrace()
}
return null
}

View file

@ -27,4 +27,4 @@ enum class AppBarActionType {
REFRESHING_MODELS, REFRESHING_MODELS,
} }
class AppBarAction(val actionType: AppBarActionType, val actionFn: () -> Unit) class AppBarAction(val actionType: AppBarActionType, val actionFn: () -> Unit)

View file

@ -18,9 +18,9 @@ package com.google.ai.edge.gallery.data
import android.content.Context import android.content.Context
import androidx.datastore.core.DataStore import androidx.datastore.core.DataStore
import androidx.datastore.preferences.core.Preferences
import com.google.ai.edge.gallery.GalleryLifecycleProvider
import com.google.ai.edge.gallery.AppLifecycleProvider import com.google.ai.edge.gallery.AppLifecycleProvider
import com.google.ai.edge.gallery.GalleryLifecycleProvider
import com.google.ai.edge.gallery.proto.Settings
/** /**
* App container for Dependency injection. * App container for Dependency injection.
@ -39,9 +39,9 @@ interface AppContainer {
* *
* This class provides concrete implementations for the application's dependencies, * This class provides concrete implementations for the application's dependencies,
*/ */
class DefaultAppContainer(ctx: Context, dataStore: DataStore<Preferences>) : AppContainer { class DefaultAppContainer(ctx: Context, dataStore: DataStore<Settings>) : AppContainer {
override val context = ctx override val context = ctx
override val lifecycleProvider = GalleryLifecycleProvider() override val lifecycleProvider = GalleryLifecycleProvider()
override val dataStoreRepository = DefaultDataStoreRepository(dataStore) override val dataStoreRepository = DefaultDataStoreRepository(dataStore)
override val downloadRepository = DefaultDownloadRepository(ctx, lifecycleProvider) override val downloadRepository = DefaultDownloadRepository(ctx, lifecycleProvider)
} }

View file

@ -16,11 +16,13 @@
package com.google.ai.edge.gallery.data package com.google.ai.edge.gallery.data
import kotlin.math.abs
/** /**
* The types of configuration editors available. * The types of configuration editors available.
* *
* This enum defines the different UI components used to edit configuration values. * This enum defines the different UI components used to edit configuration values. Each type
* Each type corresponds to a specific editor widget, such as a slider or a switch. * corresponds to a specific editor widget, such as a slider or a switch.
*/ */
enum class ConfigEditorType { enum class ConfigEditorType {
LABEL, LABEL,
@ -29,9 +31,7 @@ enum class ConfigEditorType {
DROPDOWN, DROPDOWN,
} }
/** /** The data types of configuration values. */
* The data types of configuration values.
*/
enum class ValueType { enum class ValueType {
INT, INT,
FLOAT, FLOAT,
@ -40,6 +40,28 @@ enum class ValueType {
BOOLEAN, BOOLEAN,
} }
enum class ConfigKey(val label: String) {
MAX_TOKENS("Max tokens"),
TOPK("TopK"),
TOPP("TopP"),
TEMPERATURE("Temperature"),
DEFAULT_MAX_TOKENS("Default max tokens"),
DEFAULT_TOPK("Default TopK"),
DEFAULT_TOPP("Default TopP"),
DEFAULT_TEMPERATURE("Default temperature"),
SUPPORT_IMAGE("Support image"),
MAX_RESULT_COUNT("Max result count"),
USE_GPU("Use GPU"),
ACCELERATOR("Choose accelerator"),
COMPATIBLE_ACCELERATORS("Compatible accelerators"),
WARM_UP_ITERATIONS("Warm up iterations"),
BENCHMARK_ITERATIONS("Benchmark iterations"),
ITERATIONS("Iterations"),
THEME("Theme"),
NAME("Name"),
MODEL_TYPE("Model type"),
}
/** /**
* Base class for configuration settings. * Base class for configuration settings.
* *
@ -58,18 +80,14 @@ open class Config(
open val needReinitialization: Boolean = true, open val needReinitialization: Boolean = true,
) )
/** /** Configuration setting for a label. */
* Configuration setting for a label. class LabelConfig(override val key: ConfigKey, override val defaultValue: String = "") :
*/ Config(
class LabelConfig( type = ConfigEditorType.LABEL,
override val key: ConfigKey, key = key,
override val defaultValue: String = "", defaultValue = defaultValue,
) : Config( valueType = ValueType.STRING,
type = ConfigEditorType.LABEL, )
key = key,
defaultValue = defaultValue,
valueType = ValueType.STRING
)
/** /**
* Configuration setting for a number slider. * Configuration setting for a number slider.
@ -92,32 +110,122 @@ class NumberSliderConfig(
valueType = valueType, valueType = valueType,
) )
/** /** Configuration setting for a boolean switch. */
* Configuration setting for a boolean switch.
*/
class BooleanSwitchConfig( class BooleanSwitchConfig(
override val key: ConfigKey, override val key: ConfigKey,
override val defaultValue: Boolean, override val defaultValue: Boolean,
override val needReinitialization: Boolean = true, override val needReinitialization: Boolean = true,
) : Config( ) :
type = ConfigEditorType.BOOLEAN_SWITCH, Config(
key = key, type = ConfigEditorType.BOOLEAN_SWITCH,
defaultValue = defaultValue, key = key,
valueType = ValueType.BOOLEAN, defaultValue = defaultValue,
) valueType = ValueType.BOOLEAN,
)
/** /** Configuration setting for a dropdown. */
* Configuration setting for a dropdown.
*/
class SegmentedButtonConfig( class SegmentedButtonConfig(
override val key: ConfigKey, override val key: ConfigKey,
override val defaultValue: String, override val defaultValue: String,
val options: List<String>, val options: List<String>,
val allowMultiple: Boolean = false, val allowMultiple: Boolean = false,
) : Config( ) :
type = ConfigEditorType.DROPDOWN, Config(
key = key, type = ConfigEditorType.DROPDOWN,
defaultValue = defaultValue, key = key,
// The emitted value will be comma-separated labels when allowMultiple=true. defaultValue = defaultValue,
valueType = ValueType.STRING, // The emitted value will be comma-separated labels when allowMultiple=true.
) valueType = ValueType.STRING,
)
fun convertValueToTargetType(value: Any, valueType: ValueType): Any {
return when (valueType) {
ValueType.INT ->
when (value) {
is Int -> value
is Float -> value.toInt()
is Double -> value.toInt()
is String -> value.toIntOrNull() ?: ""
is Boolean -> if (value) 1 else 0
else -> ""
}
ValueType.FLOAT ->
when (value) {
is Int -> value.toFloat()
is Float -> value
is Double -> value.toFloat()
is String -> value.toFloatOrNull() ?: ""
is Boolean -> if (value) 1f else 0f
else -> ""
}
ValueType.DOUBLE ->
when (value) {
is Int -> value.toDouble()
is Float -> value.toDouble()
is Double -> value
is String -> value.toDoubleOrNull() ?: ""
is Boolean -> if (value) 1.0 else 0.0
else -> ""
}
ValueType.BOOLEAN ->
when (value) {
is Int -> value == 0
is Boolean -> value
is Float -> abs(value) > 1e-6
is Double -> abs(value) > 1e-6
is String -> value.isNotEmpty()
else -> false
}
ValueType.STRING -> value.toString()
}
}
fun createLlmChatConfigs(
defaultMaxToken: Int = DEFAULT_MAX_TOKEN,
defaultTopK: Int = DEFAULT_TOPK,
defaultTopP: Float = DEFAULT_TOPP,
defaultTemperature: Float = DEFAULT_TEMPERATURE,
accelerators: List<Accelerator> = DEFAULT_ACCELERATORS,
): List<Config> {
return listOf(
LabelConfig(key = ConfigKey.MAX_TOKENS, defaultValue = "$defaultMaxToken"),
NumberSliderConfig(
key = ConfigKey.TOPK,
sliderMin = 5f,
sliderMax = 40f,
defaultValue = defaultTopK.toFloat(),
valueType = ValueType.INT,
),
NumberSliderConfig(
key = ConfigKey.TOPP,
sliderMin = 0.0f,
sliderMax = 1.0f,
defaultValue = defaultTopP,
valueType = ValueType.FLOAT,
),
NumberSliderConfig(
key = ConfigKey.TEMPERATURE,
sliderMin = 0.0f,
sliderMax = 2.0f,
defaultValue = defaultTemperature,
valueType = ValueType.FLOAT,
),
SegmentedButtonConfig(
key = ConfigKey.ACCELERATOR,
defaultValue = accelerators[0].label,
options = accelerators.map { it.label },
),
)
}
fun getConfigValueString(value: Any, config: Config): String {
var strNewValue = "$value"
if (config.valueType == ValueType.FLOAT) {
strNewValue = "%.2f".format(value)
}
return strNewValue
}

View file

@ -16,64 +16,55 @@
package com.google.ai.edge.gallery.data package com.google.ai.edge.gallery.data
import kotlinx.serialization.KSerializer // @Serializable(with = ConfigValueSerializer::class)
import kotlinx.serialization.Serializable
import kotlinx.serialization.SerializationException
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.buildClassSerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.json.JsonDecoder
import kotlinx.serialization.json.JsonPrimitive
@Serializable(with = ConfigValueSerializer::class)
sealed class ConfigValue { sealed class ConfigValue {
@Serializable // @Serializable
data class IntValue(val value: Int) : ConfigValue() data class IntValue(val value: Int) : ConfigValue()
@Serializable // @Serializable
data class FloatValue(val value: Float) : ConfigValue() data class FloatValue(val value: Float) : ConfigValue()
@Serializable // @Serializable
data class StringValue(val value: String) : ConfigValue() data class StringValue(val value: String) : ConfigValue()
} }
/** // /**
* Custom serializer for the ConfigValue class. // * Custom serializer for the ConfigValue class.
* // *
* This object implements the KSerializer interface to provide custom serialization and // * This object implements the KSerializer interface to provide custom serialization and
* deserialization logic for the ConfigValue class. It handles different types of ConfigValue // * deserialization logic for the ConfigValue class. It handles different types of ConfigValue
* (IntValue, FloatValue, StringValue) and supports JSON format. // * (IntValue, FloatValue, StringValue) and supports JSON format.
*/ // */
object ConfigValueSerializer : KSerializer<ConfigValue> { // object ConfigValueSerializer : KSerializer<ConfigValue> {
override val descriptor: SerialDescriptor = buildClassSerialDescriptor("ConfigValue") // override val descriptor: SerialDescriptor = buildClassSerialDescriptor("ConfigValue")
override fun serialize(encoder: Encoder, value: ConfigValue) { // override fun serialize(encoder: Encoder, value: ConfigValue) {
when (value) { // when (value) {
is ConfigValue.IntValue -> encoder.encodeInt(value.value) // is ConfigValue.IntValue -> encoder.encodeInt(value.value)
is ConfigValue.FloatValue -> encoder.encodeFloat(value.value) // is ConfigValue.FloatValue -> encoder.encodeFloat(value.value)
is ConfigValue.StringValue -> encoder.encodeString(value.value) // is ConfigValue.StringValue -> encoder.encodeString(value.value)
} // }
} // }
override fun deserialize(decoder: Decoder): ConfigValue { // override fun deserialize(decoder: Decoder): ConfigValue {
val input = decoder as? JsonDecoder // val input =
?: throw SerializationException("This serializer only works with Json") // decoder as? JsonDecoder
return when (val element = input.decodeJsonElement()) { // ?: throw SerializationException("This serializer only works with Json")
is JsonPrimitive -> { // return when (val element = input.decodeJsonElement()) {
if (element.isString) { // is JsonPrimitive -> {
ConfigValue.StringValue(element.content) // if (element.isString) {
} else if (element.content.contains('.')) { // ConfigValue.StringValue(element.content)
ConfigValue.FloatValue(element.content.toFloat()) // } else if (element.content.contains('.')) {
} else { // ConfigValue.FloatValue(element.content.toFloat())
ConfigValue.IntValue(element.content.toInt()) // } else {
} // ConfigValue.IntValue(element.content.toInt())
} // }
// }
else -> throw SerializationException("Expected JsonPrimitive") // else -> throw SerializationException("Expected JsonPrimitive")
} // }
} // }
} // }
fun getIntConfigValue(configValue: ConfigValue?, default: Int): Int { fun getIntConfigValue(configValue: ConfigValue?, default: Int): Int {
if (configValue == null) { if (configValue == null) {

View file

@ -34,3 +34,13 @@ const val KEY_MODEL_EXTRA_DATA_DOWNLOAD_FILE_NAMES = "KEY_MODEL_EXTRA_DATA_DOWNL
const val KEY_MODEL_IS_ZIP = "KEY_MODEL_IS_ZIP" const val KEY_MODEL_IS_ZIP = "KEY_MODEL_IS_ZIP"
const val KEY_MODEL_UNZIPPED_DIR = "KEY_MODEL_UNZIPPED_DIR" const val KEY_MODEL_UNZIPPED_DIR = "KEY_MODEL_UNZIPPED_DIR"
const val KEY_MODEL_START_UNZIPPING = "KEY_MODEL_START_UNZIPPING" const val KEY_MODEL_START_UNZIPPING = "KEY_MODEL_START_UNZIPPING"
// Default values for LLM models.
const val DEFAULT_MAX_TOKEN = 1024
const val DEFAULT_TOPK = 40
const val DEFAULT_TOPP = 0.9f
const val DEFAULT_TEMPERATURE = 1.0f
val DEFAULT_ACCELERATORS = listOf(Accelerator.GPU)
// Max number of images allowed in a "ask image" session.
const val MAX_IMAGE_COUNT = 10

View file

@ -16,231 +16,109 @@
package com.google.ai.edge.gallery.data package com.google.ai.edge.gallery.data
import android.security.keystore.KeyGenParameterSpec
import android.security.keystore.KeyProperties
import android.util.Base64
import androidx.datastore.core.DataStore import androidx.datastore.core.DataStore
import androidx.datastore.preferences.core.Preferences import com.google.ai.edge.gallery.proto.AccessTokenData
import androidx.datastore.preferences.core.edit import com.google.ai.edge.gallery.proto.ImportedModel
import androidx.datastore.preferences.core.longPreferencesKey import com.google.ai.edge.gallery.proto.Settings
import androidx.datastore.preferences.core.stringPreferencesKey import com.google.ai.edge.gallery.proto.Theme
import com.google.gson.Gson
import com.google.gson.reflect.TypeToken
import com.google.ai.edge.gallery.ui.theme.THEME_AUTO
import kotlinx.coroutines.flow.first import kotlinx.coroutines.flow.first
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import java.security.KeyStore
import javax.crypto.Cipher
import javax.crypto.KeyGenerator
import javax.crypto.SecretKey
data class AccessTokenData(
val accessToken: String,
val refreshToken: String,
val expiresAtMs: Long
)
// TODO(b/423700720): Change to async (suspend) functions
interface DataStoreRepository { interface DataStoreRepository {
fun saveTextInputHistory(history: List<String>) fun saveTextInputHistory(history: List<String>)
fun readTextInputHistory(): List<String> fun readTextInputHistory(): List<String>
fun saveThemeOverride(theme: String)
fun readThemeOverride(): String fun saveTheme(theme: Theme)
fun readTheme(): Theme
fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long) fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long)
fun clearAccessTokenData() fun clearAccessTokenData()
fun readAccessTokenData(): AccessTokenData? fun readAccessTokenData(): AccessTokenData?
fun saveImportedModels(importedModels: List<ImportedModelInfo>)
fun readImportedModels(): List<ImportedModelInfo> fun saveImportedModels(importedModels: List<ImportedModel>)
fun readImportedModels(): List<ImportedModel>
} }
/** /** Repository for managing data using Proto DataStore. */
* Repository for managing data using DataStore, with JSON serialization. class DefaultDataStoreRepository(private val dataStore: DataStore<Settings>) : DataStoreRepository {
*
* This class provides methods to read, add, remove, and clear data stored in DataStore,
* using JSON serialization for complex objects. It uses Gson for serializing and deserializing
* lists of objects to/from JSON strings.
*
* DataStore is used to persist data as JSON strings under specified keys.
*/
class DefaultDataStoreRepository(
private val dataStore: DataStore<Preferences>
) :
DataStoreRepository {
private object PreferencesKeys {
val TEXT_INPUT_HISTORY = stringPreferencesKey("text_input_history")
val THEME_OVERRIDE = stringPreferencesKey("theme_override")
val ENCRYPTED_ACCESS_TOKEN = stringPreferencesKey("encrypted_access_token")
// Store Initialization Vector
val ACCESS_TOKEN_IV = stringPreferencesKey("access_token_iv")
val ENCRYPTED_REFRESH_TOKEN = stringPreferencesKey("encrypted_refresh_token")
// Store Initialization Vector
val REFRESH_TOKEN_IV = stringPreferencesKey("refresh_token_iv")
val ACCESS_TOKEN_EXPIRES_AT = longPreferencesKey("access_token_expires_at")
// Data for all imported models.
val IMPORTED_MODELS = stringPreferencesKey("imported_models")
}
private val keystoreAlias: String = "com_google_aiedge_gallery_access_token_key"
private val keyStore: KeyStore = KeyStore.getInstance("AndroidKeyStore").apply { load(null) }
override fun saveTextInputHistory(history: List<String>) { override fun saveTextInputHistory(history: List<String>) {
runBlocking { runBlocking {
dataStore.edit { preferences -> dataStore.updateData { settings ->
val gson = Gson() settings.toBuilder().clearTextInputHistory().addAllTextInputHistory(history).build()
val jsonString = gson.toJson(history)
preferences[PreferencesKeys.TEXT_INPUT_HISTORY] = jsonString
} }
} }
} }
override fun readTextInputHistory(): List<String> { override fun readTextInputHistory(): List<String> {
return runBlocking { return runBlocking {
val preferences = dataStore.data.first() val settings = dataStore.data.first()
getTextInputHistory(preferences) settings.textInputHistoryList
} }
} }
override fun saveThemeOverride(theme: String) { override fun saveTheme(theme: Theme) {
runBlocking { runBlocking {
dataStore.edit { preferences -> dataStore.updateData { settings -> settings.toBuilder().setTheme(theme).build() }
preferences[PreferencesKeys.THEME_OVERRIDE] = theme
}
} }
} }
override fun readThemeOverride(): String { override fun readTheme(): Theme {
return runBlocking { return runBlocking {
val preferences = dataStore.data.first() val settings = dataStore.data.first()
preferences[PreferencesKeys.THEME_OVERRIDE] ?: THEME_AUTO val curTheme = settings.theme
// Use "auto" as the default theme.
if (curTheme == Theme.THEME_UNSPECIFIED) Theme.THEME_AUTO else curTheme
} }
} }
override fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long) { override fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long) {
runBlocking { runBlocking {
val (encryptedAccessToken, accessTokenIv) = encrypt(accessToken) dataStore.updateData { settings ->
val (encryptedRefreshToken, refreshTokenIv) = encrypt(refreshToken) settings
dataStore.edit { preferences -> .toBuilder()
preferences[PreferencesKeys.ENCRYPTED_ACCESS_TOKEN] = encryptedAccessToken .setAccessTokenData(
preferences[PreferencesKeys.ACCESS_TOKEN_IV] = accessTokenIv AccessTokenData.newBuilder()
preferences[PreferencesKeys.ENCRYPTED_REFRESH_TOKEN] = encryptedRefreshToken .setAccessToken(accessToken)
preferences[PreferencesKeys.REFRESH_TOKEN_IV] = refreshTokenIv .setRefreshToken(refreshToken)
preferences[PreferencesKeys.ACCESS_TOKEN_EXPIRES_AT] = expiresAt .setExpiresAtMs(expiresAt)
.build()
)
.build()
} }
} }
} }
override fun clearAccessTokenData() { override fun clearAccessTokenData() {
return runBlocking { runBlocking {
dataStore.edit { preferences -> dataStore.updateData { settings -> settings.toBuilder().clearAccessTokenData().build() }
preferences.remove(PreferencesKeys.ENCRYPTED_ACCESS_TOKEN)
preferences.remove(PreferencesKeys.ACCESS_TOKEN_IV)
preferences.remove(PreferencesKeys.ENCRYPTED_REFRESH_TOKEN)
preferences.remove(PreferencesKeys.REFRESH_TOKEN_IV)
preferences.remove(PreferencesKeys.ACCESS_TOKEN_EXPIRES_AT)
}
} }
} }
override fun readAccessTokenData(): AccessTokenData? { override fun readAccessTokenData(): AccessTokenData? {
return runBlocking { return runBlocking {
val preferences = dataStore.data.first() val settings = dataStore.data.first()
val encryptedAccessToken = preferences[PreferencesKeys.ENCRYPTED_ACCESS_TOKEN] settings.accessTokenData
val encryptedRefreshToken = preferences[PreferencesKeys.ENCRYPTED_REFRESH_TOKEN]
val accessTokenIv = preferences[PreferencesKeys.ACCESS_TOKEN_IV]
val refreshTokenIv = preferences[PreferencesKeys.REFRESH_TOKEN_IV]
val expiresAt = preferences[PreferencesKeys.ACCESS_TOKEN_EXPIRES_AT]
var decryptedAccessToken: String? = null
var decryptedRefreshToken: String? = null
if (encryptedAccessToken != null && accessTokenIv != null) {
decryptedAccessToken = decrypt(encryptedAccessToken, accessTokenIv)
}
if (encryptedRefreshToken != null && refreshTokenIv != null) {
decryptedRefreshToken = decrypt(encryptedRefreshToken, refreshTokenIv)
}
if (decryptedAccessToken != null && decryptedRefreshToken != null && expiresAt != null) {
AccessTokenData(decryptedAccessToken, decryptedRefreshToken, expiresAt)
} else {
null
}
} }
} }
override fun saveImportedModels(importedModels: List<ImportedModelInfo>) { override fun saveImportedModels(importedModels: List<ImportedModel>) {
runBlocking { runBlocking {
dataStore.edit { preferences -> dataStore.updateData { settings ->
val gson = Gson() settings.toBuilder().clearImportedModel().addAllImportedModel(importedModels).build()
val jsonString = gson.toJson(importedModels)
preferences[PreferencesKeys.IMPORTED_MODELS] = jsonString
} }
} }
} }
override fun readImportedModels(): List<ImportedModelInfo> { override fun readImportedModels(): List<ImportedModel> {
return runBlocking { return runBlocking {
val preferences = dataStore.data.first() val settings = dataStore.data.first()
val infosStr = preferences[PreferencesKeys.IMPORTED_MODELS] ?: "[]" settings.importedModelList
val gson = Gson()
val listType = object : TypeToken<List<ImportedModelInfo>>() {}.type
gson.fromJson(infosStr, listType)
}
}
private fun getTextInputHistory(preferences: Preferences): List<String> {
val infosStr = preferences[PreferencesKeys.TEXT_INPUT_HISTORY] ?: "[]"
val gson = Gson()
val listType = object : TypeToken<List<String>>() {}.type
return gson.fromJson(infosStr, listType)
}
private fun getOrCreateSecretKey(): SecretKey {
return (keyStore.getKey(keystoreAlias, null) as? SecretKey) ?: run {
val keyGenerator =
KeyGenerator.getInstance(KeyProperties.KEY_ALGORITHM_AES, "AndroidKeyStore")
val keyGenParameterSpec = KeyGenParameterSpec.Builder(
keystoreAlias,
KeyProperties.PURPOSE_ENCRYPT or KeyProperties.PURPOSE_DECRYPT
)
.setBlockModes(KeyProperties.BLOCK_MODE_GCM)
.setEncryptionPaddings(KeyProperties.ENCRYPTION_PADDING_NONE)
.setUserAuthenticationRequired(false) // Consider setting to true for added security
.build()
keyGenerator.init(keyGenParameterSpec)
keyGenerator.generateKey()
}
}
private fun encrypt(plainText: String): Pair<String, String> {
val secretKey = getOrCreateSecretKey()
val cipher = Cipher.getInstance("AES/GCM/NoPadding")
cipher.init(Cipher.ENCRYPT_MODE, secretKey)
val iv = cipher.iv
val encryptedBytes = cipher.doFinal(plainText.toByteArray(Charsets.UTF_8))
return Base64.encodeToString(encryptedBytes, Base64.DEFAULT) to Base64.encodeToString(
iv,
Base64.DEFAULT
)
}
private fun decrypt(encryptedText: String, ivText: String): String? {
val secretKey = getOrCreateSecretKey()
val cipher = Cipher.getInstance("AES/GCM/NoPadding")
val ivBytes = Base64.decode(ivText, Base64.DEFAULT)
val spec = javax.crypto.spec.GCMParameterSpec(128, ivBytes) // 128 bit tag length
cipher.init(Cipher.DECRYPT_MODE, secretKey, spec)
val encryptedBytes = Base64.decode(encryptedText, Base64.DEFAULT)
return try {
String(cipher.doFinal(encryptedBytes), Charsets.UTF_8)
} catch (e: Exception) {
// Handle decryption errors (e.g., key not found)
null
} }
} }
} }

View file

@ -23,11 +23,11 @@ import android.app.PendingIntent
import android.content.Context import android.content.Context
import android.content.Intent import android.content.Intent
import android.content.pm.PackageManager import android.content.pm.PackageManager
import android.net.Uri
import android.util.Log import android.util.Log
import androidx.core.app.ActivityCompat import androidx.core.app.ActivityCompat
import androidx.core.app.NotificationCompat import androidx.core.app.NotificationCompat
import androidx.core.app.NotificationManagerCompat import androidx.core.app.NotificationManagerCompat
import androidx.core.net.toUri
import androidx.work.Data import androidx.work.Data
import androidx.work.ExistingWorkPolicy import androidx.work.ExistingWorkPolicy
import androidx.work.OneTimeWorkRequestBuilder import androidx.work.OneTimeWorkRequestBuilder
@ -38,7 +38,7 @@ import androidx.work.WorkManager
import androidx.work.WorkQuery import androidx.work.WorkQuery
import com.google.ai.edge.gallery.AppLifecycleProvider import com.google.ai.edge.gallery.AppLifecycleProvider
import com.google.ai.edge.gallery.R import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.ui.common.readLaunchInfo import com.google.ai.edge.gallery.common.readLaunchInfo
import com.google.ai.edge.gallery.worker.DownloadWorker import com.google.ai.edge.gallery.worker.DownloadWorker
import com.google.common.util.concurrent.FutureCallback import com.google.common.util.concurrent.FutureCallback
import com.google.common.util.concurrent.Futures import com.google.common.util.concurrent.Futures
@ -53,7 +53,8 @@ data class AGWorkInfo(val modelName: String, val workId: String)
interface DownloadRepository { interface DownloadRepository {
fun downloadModel( fun downloadModel(
model: Model, onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit model: Model,
onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit,
) )
fun cancelDownloadModel(model: Model) fun cancelDownloadModel(model: Model)
@ -83,7 +84,8 @@ class DefaultDownloadRepository(
private val workManager = WorkManager.getInstance(context) private val workManager = WorkManager.getInstance(context)
override fun downloadModel( override fun downloadModel(
model: Model, onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit model: Model,
onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit,
) { ) {
val appTs = readLaunchInfo(context = context)?.ts ?: 0 val appTs = readLaunchInfo(context = context)?.ts ?: 0
@ -91,18 +93,24 @@ class DefaultDownloadRepository(
val builder = Data.Builder() val builder = Data.Builder()
val totalBytes = model.totalBytes + model.extraDataFiles.sumOf { it.sizeInBytes } val totalBytes = model.totalBytes + model.extraDataFiles.sumOf { it.sizeInBytes }
val inputDataBuilder = val inputDataBuilder =
builder.putString(KEY_MODEL_NAME, model.name).putString(KEY_MODEL_URL, model.url) builder
.putString(KEY_MODEL_NAME, model.name)
.putString(KEY_MODEL_URL, model.url)
.putString(KEY_MODEL_VERSION, model.version) .putString(KEY_MODEL_VERSION, model.version)
.putString(KEY_MODEL_DOWNLOAD_MODEL_DIR, model.normalizedName) .putString(KEY_MODEL_DOWNLOAD_MODEL_DIR, model.normalizedName)
.putString(KEY_MODEL_DOWNLOAD_FILE_NAME, model.downloadFileName) .putString(KEY_MODEL_DOWNLOAD_FILE_NAME, model.downloadFileName)
.putBoolean(KEY_MODEL_IS_ZIP, model.isZip).putString(KEY_MODEL_UNZIPPED_DIR, model.unzipDir) .putBoolean(KEY_MODEL_IS_ZIP, model.isZip)
.putLong(KEY_MODEL_TOTAL_BYTES, totalBytes).putLong(KEY_MODEL_DOWNLOAD_APP_TS, appTs) .putString(KEY_MODEL_UNZIPPED_DIR, model.unzipDir)
.putLong(KEY_MODEL_TOTAL_BYTES, totalBytes)
.putLong(KEY_MODEL_DOWNLOAD_APP_TS, appTs)
if (model.extraDataFiles.isNotEmpty()) { if (model.extraDataFiles.isNotEmpty()) {
inputDataBuilder.putString(KEY_MODEL_EXTRA_DATA_URLS, inputDataBuilder
model.extraDataFiles.joinToString(",") { it.url }).putString( .putString(KEY_MODEL_EXTRA_DATA_URLS, model.extraDataFiles.joinToString(",") { it.url })
KEY_MODEL_EXTRA_DATA_DOWNLOAD_FILE_NAMES, .putString(
model.extraDataFiles.joinToString(",") { it.downloadFileName }) KEY_MODEL_EXTRA_DATA_DOWNLOAD_FILE_NAMES,
model.extraDataFiles.joinToString(",") { it.downloadFileName },
)
} }
if (model.accessToken != null) { if (model.accessToken != null) {
inputDataBuilder.putString(KEY_MODEL_DOWNLOAD_ACCESS_TOKEN, model.accessToken) inputDataBuilder.putString(KEY_MODEL_DOWNLOAD_ACCESS_TOKEN, model.accessToken)
@ -111,20 +119,19 @@ class DefaultDownloadRepository(
// Create worker request. // Create worker request.
val downloadWorkRequest = val downloadWorkRequest =
OneTimeWorkRequestBuilder<DownloadWorker>().setExpedited(OutOfQuotaPolicy.RUN_AS_NON_EXPEDITED_WORK_REQUEST) OneTimeWorkRequestBuilder<DownloadWorker>()
.setInputData(inputData).addTag("$MODEL_NAME_TAG:${model.name}").build() .setExpedited(OutOfQuotaPolicy.RUN_AS_NON_EXPEDITED_WORK_REQUEST)
.setInputData(inputData)
.addTag("$MODEL_NAME_TAG:${model.name}")
.build()
val workerId = downloadWorkRequest.id val workerId = downloadWorkRequest.id
// Start! // Start!
workManager.enqueueUniqueWork( workManager.enqueueUniqueWork(model.name, ExistingWorkPolicy.REPLACE, downloadWorkRequest)
model.name, ExistingWorkPolicy.REPLACE, downloadWorkRequest
)
// Observe progress. // Observe progress.
observerWorkerProgress( observerWorkerProgress(workerId = workerId, model = model, onStatusUpdated = onStatusUpdated)
workerId = workerId, model = model, onStatusUpdated = onStatusUpdated
)
} }
override fun cancelDownloadModel(model: Model) { override fun cancelDownloadModel(model: Model) {
@ -143,7 +150,8 @@ class DefaultDownloadRepository(
} }
val combinedFuture: ListenableFuture<List<Operation.State.SUCCESS>> = Futures.allAsList(futures) val combinedFuture: ListenableFuture<List<Operation.State.SUCCESS>> = Futures.allAsList(futures)
Futures.addCallback( Futures.addCallback(
combinedFuture, object : FutureCallback<List<Operation.State.SUCCESS>> { combinedFuture,
object : FutureCallback<List<Operation.State.SUCCESS>> {
override fun onSuccess(result: List<Operation.State.SUCCESS>?) { override fun onSuccess(result: List<Operation.State.SUCCESS>?) {
// All cancellations are complete // All cancellations are complete
onComplete() onComplete()
@ -154,7 +162,8 @@ class DefaultDownloadRepository(
t.printStackTrace() t.printStackTrace()
onComplete() onComplete()
} }
}, MoreExecutors.directExecutor() },
MoreExecutors.directExecutor(),
) )
} }
@ -175,45 +184,41 @@ class DefaultDownloadRepository(
if (!startUnzipping) { if (!startUnzipping) {
if (receivedBytes != 0L) { if (receivedBytes != 0L) {
onStatusUpdated( onStatusUpdated(
model, ModelDownloadStatus( model,
ModelDownloadStatus(
status = ModelDownloadStatusType.IN_PROGRESS, status = ModelDownloadStatusType.IN_PROGRESS,
totalBytes = model.totalBytes, totalBytes = model.totalBytes,
receivedBytes = receivedBytes, receivedBytes = receivedBytes,
bytesPerSecond = downloadRate, bytesPerSecond = downloadRate,
remainingMs = remainingSeconds, remainingMs = remainingSeconds,
) ),
) )
} }
} else { } else {
onStatusUpdated( onStatusUpdated(
model, ModelDownloadStatus( model,
status = ModelDownloadStatusType.UNZIPPING, ModelDownloadStatus(status = ModelDownloadStatusType.UNZIPPING),
)
) )
} }
} }
WorkInfo.State.SUCCEEDED -> { WorkInfo.State.SUCCEEDED -> {
Log.d("repo", "worker %s success".format(workerId.toString())) Log.d("repo", "worker %s success".format(workerId.toString()))
onStatusUpdated( onStatusUpdated(model, ModelDownloadStatus(status = ModelDownloadStatusType.SUCCEEDED))
model, ModelDownloadStatus(
status = ModelDownloadStatusType.SUCCEEDED,
)
)
sendNotification( sendNotification(
title = context.getString( title = context.getString(R.string.notification_title_success),
R.string.notification_title_success
),
text = context.getString(R.string.notification_content_success).format(model.name), text = context.getString(R.string.notification_content_success).format(model.name),
modelName = model.name, modelName = model.name,
) )
} }
WorkInfo.State.FAILED, WorkInfo.State.CANCELLED -> { WorkInfo.State.FAILED,
WorkInfo.State.CANCELLED -> {
var status = ModelDownloadStatusType.FAILED var status = ModelDownloadStatusType.FAILED
val errorMessage = workInfo.outputData.getString(KEY_MODEL_DOWNLOAD_ERROR_MESSAGE) ?: "" val errorMessage = workInfo.outputData.getString(KEY_MODEL_DOWNLOAD_ERROR_MESSAGE) ?: ""
Log.d( Log.d(
"repo", "worker %s FAILED or CANCELLED: %s".format(workerId.toString(), errorMessage) "repo",
"worker %s FAILED or CANCELLED: %s".format(workerId.toString(), errorMessage),
) )
if (workInfo.state == WorkInfo.State.CANCELLED) { if (workInfo.state == WorkInfo.State.CANCELLED) {
status = ModelDownloadStatusType.NOT_DOWNLOADED status = ModelDownloadStatusType.NOT_DOWNLOADED
@ -225,7 +230,8 @@ class DefaultDownloadRepository(
) )
} }
onStatusUpdated( onStatusUpdated(
model, ModelDownloadStatus(status = status, errorMessage = errorMessage) model,
ModelDownloadStatus(status = status, errorMessage = errorMessage),
) )
} }
@ -278,29 +284,35 @@ class DefaultDownloadRepository(
notificationManager.createNotificationChannel(channel) notificationManager.createNotificationChannel(channel)
// Create an Intent to open your app with a deep link. // Create an Intent to open your app with a deep link.
val intent = Intent( val intent =
Intent.ACTION_VIEW, Uri.parse("com.google.ai.edge.gallery://model/${modelName}") Intent(Intent.ACTION_VIEW, "com.google.ai.edge.gallery://model/${modelName}".toUri()).apply {
).apply { flags = Intent.FLAG_ACTIVITY_NEW_TASK
flags = Intent.FLAG_ACTIVITY_NEW_TASK }
}
// Create a PendingIntent // Create a PendingIntent
val pendingIntent: PendingIntent = PendingIntent.getActivity( val pendingIntent: PendingIntent =
context, 0, intent, PendingIntent.FLAG_UPDATE_CURRENT or PendingIntent.FLAG_IMMUTABLE PendingIntent.getActivity(
) context,
0,
intent,
PendingIntent.FLAG_UPDATE_CURRENT or PendingIntent.FLAG_IMMUTABLE,
)
val builder =
val builder = NotificationCompat.Builder(context, channelId) NotificationCompat.Builder(context, channelId)
// TODO: replace icon. // TODO: replace icon.
.setSmallIcon(android.R.drawable.ic_dialog_info).setContentTitle(title).setContentText(text) .setSmallIcon(android.R.drawable.ic_dialog_info)
.setPriority(NotificationCompat.PRIORITY_HIGH).setContentIntent(pendingIntent) .setContentTitle(title)
.setAutoCancel(true) .setContentText(text)
.setPriority(NotificationCompat.PRIORITY_HIGH)
.setContentIntent(pendingIntent)
.setAutoCancel(true)
with(NotificationManagerCompat.from(context)) { with(NotificationManagerCompat.from(context)) {
// notificationId is a unique int for each notification that you must define // notificationId is a unique int for each notification that you must define
if (ActivityCompat.checkSelfPermission( if (
context, Manifest.permission.POST_NOTIFICATIONS ActivityCompat.checkSelfPermission(context, Manifest.permission.POST_NOTIFICATIONS) !=
) != PackageManager.PERMISSION_GRANTED PackageManager.PERMISSION_GRANTED
) { ) {
// Permission not granted, return or handle accordingly. In real app, request permission. // Permission not granted, return or handle accordingly. In real app, request permission.
return return

View file

@ -17,8 +17,6 @@
package com.google.ai.edge.gallery.data package com.google.ai.edge.gallery.data
import android.content.Context import android.content.Context
import com.google.ai.edge.gallery.ui.common.chat.PromptTemplate
import com.google.ai.edge.gallery.ui.common.convertValueToTargetType
import java.io.File import java.io.File
data class ModelDataFile( data class ModelDataFile(
@ -28,13 +26,11 @@ data class ModelDataFile(
val sizeInBytes: Long, val sizeInBytes: Long,
) )
enum class Accelerator(val label: String) {
CPU(label = "CPU"), GPU(label = "GPU")
}
const val IMPORTS_DIR = "__imports" const val IMPORTS_DIR = "__imports"
private val NORMALIZE_NAME_REGEX = Regex("[^a-zA-Z0-9]") private val NORMALIZE_NAME_REGEX = Regex("[^a-zA-Z0-9]")
data class PromptTemplate(val title: String, val description: String, val prompt: String)
/** A model for a task */ /** A model for a task */
data class Model( data class Model(
/** The name (for display purpose) of the model. */ /** The name (for display purpose) of the model. */
@ -67,9 +63,7 @@ data class Model(
*/ */
val info: String = "", val info: String = "",
/** /** The url to jump to when clicking "learn more" in expanded model item. */
* The url to jump to when clicking "learn more" in expanded model item.
*/
val learnMoreUrl: String = "", val learnMoreUrl: String = "",
/** A list of configurable parameters for the model. */ /** A list of configurable parameters for the model. */
@ -105,6 +99,9 @@ data class Model(
var configValues: Map<String, Any> = mapOf(), var configValues: Map<String, Any> = mapOf(),
var totalBytes: Long = 0L, var totalBytes: Long = 0L,
var accessToken: String? = null, var accessToken: String? = null,
/** The estimated peak memory in byte to run the model. */
val estimatedPeakMemoryInBytes: Long? = null,
) { ) {
init { init {
normalizedName = NORMALIZE_NAME_REGEX.replace(name, "_") normalizedName = NORMALIZE_NAME_REGEX.replace(name, "_")
@ -121,17 +118,13 @@ data class Model(
fun getPath(context: Context, fileName: String = downloadFileName): String { fun getPath(context: Context, fileName: String = downloadFileName): String {
if (imported) { if (imported) {
return listOf(context.getExternalFilesDir(null)?.absolutePath ?: "", fileName).joinToString( return listOf(context.getExternalFilesDir(null)?.absolutePath ?: "", fileName)
File.separator .joinToString(File.separator)
)
} }
val baseDir = val baseDir =
listOf( listOf(context.getExternalFilesDir(null)?.absolutePath ?: "", normalizedName, version)
context.getExternalFilesDir(null)?.absolutePath ?: "", .joinToString(File.separator)
normalizedName,
version
).joinToString(File.separator)
return if (this.isZip && this.unzipDir.isNotEmpty()) { return if (this.isZip && this.unzipDir.isNotEmpty()) {
"$baseDir/${this.unzipDir}" "$baseDir/${this.unzipDir}"
} else { } else {
@ -140,27 +133,27 @@ data class Model(
} }
fun getIntConfigValue(key: ConfigKey, defaultValue: Int = 0): Int { fun getIntConfigValue(key: ConfigKey, defaultValue: Int = 0): Int {
return getTypedConfigValue( return getTypedConfigValue(key = key, valueType = ValueType.INT, defaultValue = defaultValue)
key = key, valueType = ValueType.INT, defaultValue = defaultValue as Int
) as Int
} }
fun getFloatConfigValue(key: ConfigKey, defaultValue: Float = 0.0f): Float { fun getFloatConfigValue(key: ConfigKey, defaultValue: Float = 0.0f): Float {
return getTypedConfigValue( return getTypedConfigValue(key = key, valueType = ValueType.FLOAT, defaultValue = defaultValue)
key = key, valueType = ValueType.FLOAT, defaultValue = defaultValue as Float
) as Float
} }
fun getBooleanConfigValue(key: ConfigKey, defaultValue: Boolean = false): Boolean { fun getBooleanConfigValue(key: ConfigKey, defaultValue: Boolean = false): Boolean {
return getTypedConfigValue( return getTypedConfigValue(
key = key, valueType = ValueType.BOOLEAN, defaultValue = defaultValue key = key,
) as Boolean valueType = ValueType.BOOLEAN,
defaultValue = defaultValue,
)
as Boolean
} }
fun getStringConfigValue(key: ConfigKey, defaultValue: String = ""): String { fun getStringConfigValue(key: ConfigKey, defaultValue: String = ""): String {
return getTypedConfigValue( return getTypedConfigValue(key = key, valueType = ValueType.STRING, defaultValue = defaultValue)
key = key, valueType = ValueType.STRING, defaultValue = defaultValue as String
) as String
} }
fun getExtraDataFile(name: String): ModelDataFile? { fun getExtraDataFile(name: String): ModelDataFile? {
@ -169,20 +162,19 @@ data class Model(
private fun getTypedConfigValue(key: ConfigKey, valueType: ValueType, defaultValue: Any): Any { private fun getTypedConfigValue(key: ConfigKey, valueType: ValueType, defaultValue: Any): Any {
return convertValueToTargetType( return convertValueToTargetType(
value = configValues.getOrDefault(key.label, defaultValue), valueType = valueType value = configValues.getOrDefault(key.label, defaultValue),
valueType = valueType,
) )
} }
} }
/** Data for a imported local model. */
data class ImportedModelInfo(
val fileName: String,
val fileSize: Long,
val defaultValues: Map<String, Any>
)
enum class ModelDownloadStatusType { enum class ModelDownloadStatusType {
NOT_DOWNLOADED, PARTIALLY_DOWNLOADED, IN_PROGRESS, UNZIPPING, SUCCEEDED, FAILED, NOT_DOWNLOADED,
PARTIALLY_DOWNLOADED,
IN_PROGRESS,
UNZIPPING,
SUCCEEDED,
FAILED,
} }
data class ModelDownloadStatus( data class ModelDownloadStatus(
@ -197,51 +189,29 @@ data class ModelDownloadStatus(
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
// Configs. // Configs.
enum class ConfigKey(val label: String) { val MOBILENET_CONFIGS: List<Config> =
MAX_TOKENS("Max tokens"), listOf(
TOPK("TopK"), NumberSliderConfig(
TOPP("TopP"), key = ConfigKey.MAX_RESULT_COUNT,
TEMPERATURE("Temperature"), sliderMin = 1f,
DEFAULT_MAX_TOKENS("Default max tokens"), sliderMax = 5f,
DEFAULT_TOPK("Default TopK"), defaultValue = 3f,
DEFAULT_TOPP("Default TopP"), valueType = ValueType.INT,
DEFAULT_TEMPERATURE("Default temperature"), ),
SUPPORT_IMAGE("Support image"), BooleanSwitchConfig(key = ConfigKey.USE_GPU, defaultValue = false),
MAX_RESULT_COUNT("Max result count"),
USE_GPU("Use GPU"),
ACCELERATOR("Choose accelerator"),
COMPATIBLE_ACCELERATORS("Compatible accelerators"),
WARM_UP_ITERATIONS("Warm up iterations"),
BENCHMARK_ITERATIONS("Benchmark iterations"),
ITERATIONS("Iterations"),
THEME("Theme"),
NAME("Name"),
MODEL_TYPE("Model type")
}
val MOBILENET_CONFIGS: List<Config> = listOf(
NumberSliderConfig(
key = ConfigKey.MAX_RESULT_COUNT,
sliderMin = 1f,
sliderMax = 5f,
defaultValue = 3f,
valueType = ValueType.INT
), BooleanSwitchConfig(
key = ConfigKey.USE_GPU,
defaultValue = false,
) )
)
val IMAGE_GENERATION_CONFIGS: List<Config> = listOf( val IMAGE_GENERATION_CONFIGS: List<Config> =
NumberSliderConfig( listOf(
key = ConfigKey.ITERATIONS, NumberSliderConfig(
sliderMin = 5f, key = ConfigKey.ITERATIONS,
sliderMax = 50f, sliderMin = 5f,
defaultValue = 10f, sliderMax = 50f,
valueType = ValueType.INT, defaultValue = 10f,
needReinitialization = false, valueType = ValueType.INT,
needReinitialization = false,
)
) )
)
const val TEXT_CLASSIFICATION_INFO = const val TEXT_CLASSIFICATION_INFO =
"Model is trained on movie reviews dataset. Type a movie review below and see the scores of positive or negative sentiment." "Model is trained on movie reviews dataset. Type a movie review below and see the scores of positive or negative sentiment."
@ -256,92 +226,97 @@ const val IMAGE_CLASSIFICATION_LEARN_MORE_URL = "https://ai.google.dev/edge/lite
const val IMAGE_GENERATION_INFO = const val IMAGE_GENERATION_INFO =
"Powered by [MediaPipe Image Generation API](https://ai.google.dev/edge/mediapipe/solutions/vision/image_generator/android)" "Powered by [MediaPipe Image Generation API](https://ai.google.dev/edge/mediapipe/solutions/vision/image_generator/android)"
val MODEL_TEXT_CLASSIFICATION_MOBILEBERT: Model =
Model(
name = "MobileBert",
downloadFileName = "bert_classifier.tflite",
url =
"https://storage.googleapis.com/mediapipe-models/text_classifier/bert_classifier/float32/latest/bert_classifier.tflite",
sizeInBytes = 25707538L,
info = TEXT_CLASSIFICATION_INFO,
learnMoreUrl = TEXT_CLASSIFICATION_LEARN_MORE_URL,
)
val MODEL_TEXT_CLASSIFICATION_MOBILEBERT: Model = Model( val MODEL_TEXT_CLASSIFICATION_AVERAGE_WORD_EMBEDDING: Model =
name = "MobileBert", Model(
downloadFileName = "bert_classifier.tflite", name = "Average word embedding",
url = "https://storage.googleapis.com/mediapipe-models/text_classifier/bert_classifier/float32/latest/bert_classifier.tflite", downloadFileName = "average_word_classifier.tflite",
sizeInBytes = 25707538L, url =
info = TEXT_CLASSIFICATION_INFO, "https://storage.googleapis.com/mediapipe-models/text_classifier/average_word_classifier/float32/latest/average_word_classifier.tflite",
learnMoreUrl = TEXT_CLASSIFICATION_LEARN_MORE_URL, sizeInBytes = 775708L,
) info = TEXT_CLASSIFICATION_INFO,
)
val MODEL_TEXT_CLASSIFICATION_AVERAGE_WORD_EMBEDDING: Model = Model( val MODEL_IMAGE_CLASSIFICATION_MOBILENET_V1: Model =
name = "Average word embedding", Model(
downloadFileName = "average_word_classifier.tflite", name = "Mobilenet V1",
url = "https://storage.googleapis.com/mediapipe-models/text_classifier/average_word_classifier/float32/latest/average_word_classifier.tflite", downloadFileName = "mobilenet_v1.tflite",
sizeInBytes = 775708L, url = "https://storage.googleapis.com/tfweb/app_gallery_models/mobilenet_v1.tflite",
info = TEXT_CLASSIFICATION_INFO, sizeInBytes = 16900760L,
) extraDataFiles =
listOf(
ModelDataFile(
name = "labels",
url =
"https://raw.githubusercontent.com/leferrad/tensorflow-mobilenet/refs/heads/master/imagenet/labels.txt",
downloadFileName = "mobilenet_labels_v1.txt",
sizeInBytes = 21685L,
)
),
configs = MOBILENET_CONFIGS,
info = IMAGE_CLASSIFICATION_INFO,
learnMoreUrl = IMAGE_CLASSIFICATION_LEARN_MORE_URL,
)
val MODEL_IMAGE_CLASSIFICATION_MOBILENET_V1: Model = Model( val MODEL_IMAGE_CLASSIFICATION_MOBILENET_V2: Model =
name = "Mobilenet V1", Model(
downloadFileName = "mobilenet_v1.tflite", name = "Mobilenet V2",
url = "https://storage.googleapis.com/tfweb/app_gallery_models/mobilenet_v1.tflite", downloadFileName = "mobilenet_v2.tflite",
sizeInBytes = 16900760L, url = "https://storage.googleapis.com/tfweb/app_gallery_models/mobilenet_v2.tflite",
extraDataFiles = listOf( sizeInBytes = 13978596L,
ModelDataFile( extraDataFiles =
name = "labels", listOf(
url = "https://raw.githubusercontent.com/leferrad/tensorflow-mobilenet/refs/heads/master/imagenet/labels.txt", ModelDataFile(
downloadFileName = "mobilenet_labels_v1.txt", name = "labels",
sizeInBytes = 21685L url =
), "https://raw.githubusercontent.com/leferrad/tensorflow-mobilenet/refs/heads/master/imagenet/labels.txt",
), downloadFileName = "mobilenet_labels_v2.txt",
configs = MOBILENET_CONFIGS, sizeInBytes = 21685L,
info = IMAGE_CLASSIFICATION_INFO, )
learnMoreUrl = IMAGE_CLASSIFICATION_LEARN_MORE_URL, ),
) configs = MOBILENET_CONFIGS,
info = IMAGE_CLASSIFICATION_INFO,
)
val MODEL_IMAGE_CLASSIFICATION_MOBILENET_V2: Model = Model( val MODEL_IMAGE_GENERATION_STABLE_DIFFUSION: Model =
name = "Mobilenet V2", Model(
downloadFileName = "mobilenet_v2.tflite", name = "Stable diffusion",
url = "https://storage.googleapis.com/tfweb/app_gallery_models/mobilenet_v2.tflite", downloadFileName = "sd15.zip",
sizeInBytes = 13978596L, isZip = true,
extraDataFiles = listOf( unzipDir = "sd15",
ModelDataFile( url = "https://storage.googleapis.com/tfweb/app_gallery_models/sd15.zip",
name = "labels", sizeInBytes = 1906219565L,
url = "https://raw.githubusercontent.com/leferrad/tensorflow-mobilenet/refs/heads/master/imagenet/labels.txt", showRunAgainButton = false,
downloadFileName = "mobilenet_labels_v2.txt", showBenchmarkButton = false,
sizeInBytes = 21685L info = IMAGE_GENERATION_INFO,
), configs = IMAGE_GENERATION_CONFIGS,
), learnMoreUrl = "https://huggingface.co/litert-community",
configs = MOBILENET_CONFIGS, )
info = IMAGE_CLASSIFICATION_INFO,
)
val MODEL_IMAGE_GENERATION_STABLE_DIFFUSION: Model = Model( val EMPTY_MODEL: Model =
name = "Stable diffusion", Model(name = "empty", downloadFileName = "empty.tflite", url = "", sizeInBytes = 0L)
downloadFileName = "sd15.zip",
isZip = true,
unzipDir = "sd15",
url = "https://storage.googleapis.com/tfweb/app_gallery_models/sd15.zip",
sizeInBytes = 1906219565L,
showRunAgainButton = false,
showBenchmarkButton = false,
info = IMAGE_GENERATION_INFO,
configs = IMAGE_GENERATION_CONFIGS,
learnMoreUrl = "https://huggingface.co/litert-community",
)
val EMPTY_MODEL: Model = Model(
name = "empty",
downloadFileName = "empty.tflite",
url = "",
sizeInBytes = 0L,
)
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
// Model collections for different tasks. // Model collections for different tasks.
val MODELS_TEXT_CLASSIFICATION: MutableList<Model> = mutableListOf( val MODELS_TEXT_CLASSIFICATION: MutableList<Model> =
MODEL_TEXT_CLASSIFICATION_MOBILEBERT, mutableListOf(
MODEL_TEXT_CLASSIFICATION_AVERAGE_WORD_EMBEDDING, MODEL_TEXT_CLASSIFICATION_MOBILEBERT,
) MODEL_TEXT_CLASSIFICATION_AVERAGE_WORD_EMBEDDING,
)
val MODELS_IMAGE_CLASSIFICATION: MutableList<Model> = mutableListOf( val MODELS_IMAGE_CLASSIFICATION: MutableList<Model> =
MODEL_IMAGE_CLASSIFICATION_MOBILENET_V1, mutableListOf(MODEL_IMAGE_CLASSIFICATION_MOBILENET_V1, MODEL_IMAGE_CLASSIFICATION_MOBILENET_V2)
MODEL_IMAGE_CLASSIFICATION_MOBILENET_V2,
)
val MODELS_IMAGE_GENERATION: MutableList<Model> = val MODELS_IMAGE_GENERATION: MutableList<Model> =
mutableListOf(MODEL_IMAGE_GENERATION_STABLE_DIFFUSION) mutableListOf(MODEL_IMAGE_GENERATION_STABLE_DIFFUSION)

View file

@ -16,15 +16,17 @@
package com.google.ai.edge.gallery.data package com.google.ai.edge.gallery.data
import com.google.ai.edge.gallery.ui.llmchat.DEFAULT_ACCELERATORS import com.google.gson.annotations.SerializedName
import com.google.ai.edge.gallery.ui.llmchat.DEFAULT_TEMPERATURE
import com.google.ai.edge.gallery.ui.llmchat.DEFAULT_TOPK data class DefaultConfig(
import com.google.ai.edge.gallery.ui.llmchat.DEFAULT_TOPP @SerializedName("topK") val topK: Int?,
import com.google.ai.edge.gallery.ui.llmchat.createLlmChatConfigs @SerializedName("topP") val topP: Float?,
import kotlinx.serialization.Serializable @SerializedName("temperature") val temperature: Float?,
@SerializedName("accelerators") val accelerators: String?,
@SerializedName("maxTokens") val maxTokens: Int?,
)
/** A model in the model allowlist. */ /** A model in the model allowlist. */
@Serializable
data class AllowedModel( data class AllowedModel(
val name: String, val name: String,
val modelId: String, val modelId: String,
@ -32,10 +34,11 @@ data class AllowedModel(
val description: String, val description: String,
val sizeInBytes: Long, val sizeInBytes: Long,
val version: String, val version: String,
val defaultConfig: Map<String, ConfigValue>, val defaultConfig: DefaultConfig,
val taskTypes: List<String>, val taskTypes: List<String>,
val disabled: Boolean? = null, val disabled: Boolean? = null,
val llmSupportImage: Boolean? = null, val llmSupportImage: Boolean? = null,
val estimatedPeakMemoryInBytes: Long? = null,
) { ) {
fun toModel(): Model { fun toModel(): Model {
// Construct HF download url. // Construct HF download url.
@ -46,25 +49,13 @@ data class AllowedModel(
taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_PROMPT_LAB.type.id) taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_PROMPT_LAB.type.id)
var configs: List<Config> = listOf() var configs: List<Config> = listOf()
if (isLlmModel) { if (isLlmModel) {
var defaultTopK: Int = DEFAULT_TOPK var defaultTopK: Int = defaultConfig.topK ?: DEFAULT_TOPK
var defaultTopP: Float = DEFAULT_TOPP var defaultTopP: Float = defaultConfig.topP ?: DEFAULT_TOPP
var defaultTemperature: Float = DEFAULT_TEMPERATURE var defaultTemperature: Float = defaultConfig.temperature ?: DEFAULT_TEMPERATURE
var defaultMaxToken = 1024 var defaultMaxToken = defaultConfig.maxTokens ?: 1024
var accelerators: List<Accelerator> = DEFAULT_ACCELERATORS var accelerators: List<Accelerator> = DEFAULT_ACCELERATORS
if (defaultConfig.containsKey("topK")) { if (defaultConfig.accelerators != null) {
defaultTopK = getIntConfigValue(defaultConfig["topK"], defaultTopK) val items = defaultConfig.accelerators.split(",")
}
if (defaultConfig.containsKey("topP")) {
defaultTopP = getFloatConfigValue(defaultConfig["topP"], defaultTopP)
}
if (defaultConfig.containsKey("temperature")) {
defaultTemperature = getFloatConfigValue(defaultConfig["temperature"], defaultTemperature)
}
if (defaultConfig.containsKey("maxTokens")) {
defaultMaxToken = getIntConfigValue(defaultConfig["maxTokens"], defaultMaxToken)
}
if (defaultConfig.containsKey("accelerators")) {
val items = getStringConfigValue(defaultConfig["accelerators"], "gpu").split(",")
accelerators = mutableListOf() accelerators = mutableListOf()
for (item in items) { for (item in items) {
if (item == "cpu") { if (item == "cpu") {
@ -74,13 +65,14 @@ data class AllowedModel(
} }
} }
} }
configs = createLlmChatConfigs( configs =
defaultTopK = defaultTopK, createLlmChatConfigs(
defaultTopP = defaultTopP, defaultTopK = defaultTopK,
defaultTemperature = defaultTemperature, defaultTopP = defaultTopP,
defaultMaxToken = defaultMaxToken, defaultTemperature = defaultTemperature,
accelerators = accelerators, defaultMaxToken = defaultMaxToken,
) accelerators = accelerators,
)
} }
// Misc. // Misc.
@ -97,6 +89,7 @@ data class AllowedModel(
info = description, info = description,
url = downloadUrl, url = downloadUrl,
sizeInBytes = sizeInBytes, sizeInBytes = sizeInBytes,
estimatedPeakMemoryInBytes = estimatedPeakMemoryInBytes,
configs = configs, configs = configs,
downloadFileName = modelFile, downloadFileName = modelFile,
showBenchmarkButton = showBenchmarkButton, showBenchmarkButton = showBenchmarkButton,
@ -112,8 +105,4 @@ data class AllowedModel(
} }
/** The model allowlist. */ /** The model allowlist. */
@Serializable data class ModelAllowlist(val models: List<AllowedModel>)
data class ModelAllowlist(
val models: List<AllowedModel>,
)

View file

@ -17,27 +17,21 @@
package com.google.ai.edge.gallery.data package com.google.ai.edge.gallery.data
import androidx.annotation.StringRes import androidx.annotation.StringRes
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.outlined.Forum
import androidx.compose.material.icons.outlined.Mms
import androidx.compose.material.icons.outlined.Widgets
import androidx.compose.material.icons.rounded.ImageSearch
import androidx.compose.runtime.MutableState import androidx.compose.runtime.MutableState
import androidx.compose.runtime.mutableLongStateOf import androidx.compose.runtime.mutableLongStateOf
import androidx.compose.ui.graphics.vector.ImageVector import androidx.compose.ui.graphics.vector.ImageVector
import com.google.ai.edge.gallery.R import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.ui.icon.Forum
import com.google.ai.edge.gallery.ui.icon.Mms
import com.google.ai.edge.gallery.ui.icon.Widgets
/** Type of task. */ /** Type of task. */
enum class TaskType(val label: String, val id: String) { enum class TaskType(val label: String, val id: String) {
TEXT_CLASSIFICATION(label = "Text Classification", id = "text_classification"),
IMAGE_CLASSIFICATION(label = "Image Classification", id = "image_classification"),
IMAGE_GENERATION(label = "Image Generation", id = "image_generation"),
LLM_CHAT(label = "AI Chat", id = "llm_chat"), LLM_CHAT(label = "AI Chat", id = "llm_chat"),
LLM_PROMPT_LAB(label = "Prompt Lab", id = "llm_prompt_lab"), LLM_PROMPT_LAB(label = "Prompt Lab", id = "llm_prompt_lab"),
LLM_ASK_IMAGE(label = "Ask Image", id = "llm_ask_image"), LLM_ASK_IMAGE(label = "Ask Image", id = "llm_ask_image"),
TEST_TASK_1(label = "Test task 1", id = "test_task_1"), TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
TEST_TASK_2(label = "Test task 2", id = "test_task_2") TEST_TASK_2(label = "Test task 2", id = "test_task_2"),
} }
/** Data class for a task listed in home screen. */ /** Data class for a task listed in home screen. */
@ -71,71 +65,47 @@ data class Task(
// The following fields are managed by the app. Don't need to set manually. // The following fields are managed by the app. Don't need to set manually.
var index: Int = -1, var index: Int = -1,
val updateTrigger: MutableState<Long> = mutableLongStateOf(0),
val updateTrigger: MutableState<Long> = mutableLongStateOf(0)
) )
val TASK_TEXT_CLASSIFICATION = Task( val TASK_LLM_CHAT =
type = TaskType.TEXT_CLASSIFICATION, Task(
iconVectorResourceId = R.drawable.text_spark, type = TaskType.LLM_CHAT,
models = MODELS_TEXT_CLASSIFICATION, icon = Forum,
description = "Classify text into different categories", models = mutableListOf(),
textInputPlaceHolderRes = R.string.text_input_placeholder_text_classification description = "Chat with on-device large language models",
) docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
sourceCodeUrl =
"https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt",
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
)
val TASK_IMAGE_CLASSIFICATION = Task( val TASK_LLM_PROMPT_LAB =
type = TaskType.IMAGE_CLASSIFICATION, Task(
icon = Icons.Rounded.ImageSearch, type = TaskType.LLM_PROMPT_LAB,
description = "Classify images into different categories", icon = Widgets,
models = MODELS_IMAGE_CLASSIFICATION models = mutableListOf(),
) description = "Single turn use cases with on-device large language model",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
sourceCodeUrl =
"https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt",
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
)
val TASK_LLM_CHAT = Task( val TASK_LLM_ASK_IMAGE =
type = TaskType.LLM_CHAT, Task(
icon = Icons.Outlined.Forum, type = TaskType.LLM_ASK_IMAGE,
models = mutableListOf(), icon = Mms,
description = "Chat with on-device large language models", models = mutableListOf(),
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", description = "Ask questions about images with on-device large language models",
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt", docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat sourceCodeUrl =
) "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt",
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
val TASK_LLM_PROMPT_LAB = Task( )
type = TaskType.LLM_PROMPT_LAB,
icon = Icons.Outlined.Widgets,
models = mutableListOf(),
description = "Single turn use cases with on-device large language model",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt",
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
)
val TASK_LLM_ASK_IMAGE = Task(
type = TaskType.LLM_ASK_IMAGE,
icon = Icons.Outlined.Mms,
models = mutableListOf(),
description = "Ask questions about images with on-device large language models",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt",
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
)
val TASK_IMAGE_GENERATION = Task(
type = TaskType.IMAGE_GENERATION,
iconVectorResourceId = R.drawable.image_spark,
models = MODELS_IMAGE_GENERATION,
description = "Generate images from text",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/vision/image_generator/android",
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/imagegeneration/ImageGenerationModelHelper.kt",
textInputPlaceHolderRes = R.string.text_image_generation_text_field_placeholder
)
/** All tasks. */ /** All tasks. */
val TASKS: List<Task> = listOf( val TASKS: List<Task> = listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)
TASK_LLM_ASK_IMAGE,
TASK_LLM_PROMPT_LAB,
TASK_LLM_CHAT,
)
fun getModelByName(name: String): Model? { fun getModelByName(name: String): Model? {
for (task in TASKS) { for (task in TASKS) {
@ -147,3 +117,12 @@ fun getModelByName(name: String): Model? {
} }
return null return null
} }
fun processTasks() {
for ((index, task) in TASKS.withIndex()) {
task.index = index
for (model in task.models) {
model.preProcess()
}
}
}

View file

@ -0,0 +1,22 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.data
enum class Accelerator(val label: String) {
CPU(label = "CPU"),
GPU(label = "GPU"),
}

View file

@ -16,19 +16,15 @@
package com.google.ai.edge.gallery.ui package com.google.ai.edge.gallery.ui
import android.app.Application
import androidx.lifecycle.ViewModelProvider.AndroidViewModelFactory import androidx.lifecycle.ViewModelProvider.AndroidViewModelFactory
import androidx.lifecycle.viewmodel.CreationExtras import androidx.lifecycle.viewmodel.CreationExtras
import androidx.lifecycle.viewmodel.initializer import androidx.lifecycle.viewmodel.initializer
import androidx.lifecycle.viewmodel.viewModelFactory import androidx.lifecycle.viewmodel.viewModelFactory
import com.google.ai.edge.gallery.GalleryApplication import com.google.ai.edge.gallery.GalleryApplication
import com.google.ai.edge.gallery.ui.imageclassification.ImageClassificationViewModel
import com.google.ai.edge.gallery.ui.imagegeneration.ImageGenerationViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageViewModel import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.textclassification.TextClassificationViewModel
object ViewModelProvider { object ViewModelProvider {
val Factory = viewModelFactory { val Factory = viewModelFactory {
@ -36,42 +32,23 @@ object ViewModelProvider {
initializer { initializer {
val downloadRepository = galleryApplication().container.downloadRepository val downloadRepository = galleryApplication().container.downloadRepository
val dataStoreRepository = galleryApplication().container.dataStoreRepository val dataStoreRepository = galleryApplication().container.dataStoreRepository
val lifecycleProvider = galleryApplication().container.lifecycleProvider
ModelManagerViewModel( ModelManagerViewModel(
downloadRepository = downloadRepository, downloadRepository = downloadRepository,
dataStoreRepository = dataStoreRepository, dataStoreRepository = dataStoreRepository,
lifecycleProvider = lifecycleProvider,
context = galleryApplication().container.context, context = galleryApplication().container.context,
) )
} }
// Initializer for TextClassificationViewModel
initializer {
TextClassificationViewModel()
}
// Initializer for ImageClassificationViewModel
initializer {
ImageClassificationViewModel()
}
// Initializer for LlmChatViewModel. // Initializer for LlmChatViewModel.
initializer { initializer { LlmChatViewModel() }
LlmChatViewModel()
}
// Initializer for LlmSingleTurnViewModel.. // Initializer for LlmSingleTurnViewModel..
initializer { initializer { LlmSingleTurnViewModel() }
LlmSingleTurnViewModel()
}
// Initializer for LlmAskImageViewModel. // Initializer for LlmAskImageViewModel.
initializer { initializer { LlmAskImageViewModel() }
LlmAskImageViewModel()
}
// Initializer for ImageGenerationViewModel.
initializer {
ImageGenerationViewModel()
}
} }
} }

View file

@ -16,7 +16,7 @@
package com.google.ai.edge.gallery.ui.common package com.google.ai.edge.gallery.ui.common
import android.net.Uri import androidx.core.net.toUri
import net.openid.appauth.AuthorizationServiceConfiguration import net.openid.appauth.AuthorizationServiceConfiguration
object AuthConfig { object AuthConfig {
@ -34,8 +34,9 @@ object AuthConfig {
private const val tokenEndpoint = "https://huggingface.co/oauth/token" private const val tokenEndpoint = "https://huggingface.co/oauth/token"
// OAuth service configuration (AppAuth library requires this) // OAuth service configuration (AppAuth library requires this)
val authServiceConfig = AuthorizationServiceConfiguration( val authServiceConfig =
Uri.parse(authEndpoint), // Authorization endpoint AuthorizationServiceConfiguration(
Uri.parse(tokenEndpoint) // Token exchange endpoint authEndpoint.toUri(), // Authorization endpoint
) tokenEndpoint.toUri(), // Token exchange endpoint
} )
}

View file

@ -0,0 +1,68 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.common
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.platform.LocalUriHandler
import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.text.SpanStyle
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.text.style.TextDecoration
import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.ui.theme.customColors
@Composable
fun ClickableLink(url: String, linkText: String, icon: ImageVector) {
val uriHandler = LocalUriHandler.current
val annotatedText =
AnnotatedString(
text = linkText,
spanStyles =
listOf(
AnnotatedString.Range(
item =
SpanStyle(
color = MaterialTheme.customColors.linkColor,
textDecoration = TextDecoration.Underline,
),
start = 0,
end = linkText.length,
)
),
)
Row(verticalAlignment = Alignment.CenterVertically, horizontalArrangement = Arrangement.Center) {
Icon(icon, contentDescription = "", modifier = Modifier.size(16.dp))
Text(
text = annotatedText,
textAlign = TextAlign.Center,
style = MaterialTheme.typography.bodyLarge,
modifier = Modifier.padding(start = 6.dp).clickable { uriHandler.openUri(url) },
)
}
}

View file

@ -0,0 +1,41 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.common
import androidx.compose.material3.MaterialTheme
import androidx.compose.runtime.Composable
import androidx.compose.ui.graphics.Color
import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.theme.customColors
@Composable
fun getTaskBgColor(task: Task): Color {
val colorIndex: Int = task.index % MaterialTheme.customColors.taskBgColors.size
return MaterialTheme.customColors.taskBgColors[colorIndex]
}
@Composable
fun getTaskIconColor(task: Task): Color {
val colorIndex: Int = task.index % MaterialTheme.customColors.taskIconColors.size
return MaterialTheme.customColors.taskIconColors[colorIndex]
}
@Composable
fun getTaskIconColor(index: Int): Color {
val colorIndex: Int = index % MaterialTheme.customColors.taskIconColors.size
return MaterialTheme.customColors.taskIconColors[colorIndex]
}

View file

@ -14,8 +14,11 @@
* limitations under the License. * limitations under the License.
*/ */
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.preview.MODEL_TEST1
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import android.util.Log import android.util.Log
import androidx.compose.foundation.border import androidx.compose.foundation.border
import androidx.compose.foundation.clickable import androidx.compose.foundation.clickable
@ -61,7 +64,6 @@ import androidx.compose.ui.graphics.SolidColor
import androidx.compose.ui.platform.LocalFocusManager import androidx.compose.ui.platform.LocalFocusManager
import androidx.compose.ui.text.TextStyle import androidx.compose.ui.text.TextStyle
import androidx.compose.ui.text.input.KeyboardType import androidx.compose.ui.text.input.KeyboardType
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog import androidx.compose.ui.window.Dialog
import com.google.ai.edge.gallery.data.BooleanSwitchConfig import com.google.ai.edge.gallery.data.BooleanSwitchConfig
@ -70,8 +72,6 @@ import com.google.ai.edge.gallery.data.LabelConfig
import com.google.ai.edge.gallery.data.NumberSliderConfig import com.google.ai.edge.gallery.data.NumberSliderConfig
import com.google.ai.edge.gallery.data.SegmentedButtonConfig import com.google.ai.edge.gallery.data.SegmentedButtonConfig
import com.google.ai.edge.gallery.data.ValueType import com.google.ai.edge.gallery.data.ValueType
import com.google.ai.edge.gallery.ui.preview.MODEL_TEST1
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.labelSmallNarrow import com.google.ai.edge.gallery.ui.theme.labelSmallNarrow
import kotlin.Double.Companion.NaN import kotlin.Double.Companion.NaN
@ -92,33 +92,32 @@ fun ConfigDialog(
showCancel: Boolean = true, showCancel: Boolean = true,
) { ) {
val values: SnapshotStateMap<String, Any> = remember { val values: SnapshotStateMap<String, Any> = remember {
mutableStateMapOf<String, Any>().apply { mutableStateMapOf<String, Any>().apply { putAll(initialValues) }
putAll(initialValues)
}
} }
val interactionSource = remember { MutableInteractionSource() } val interactionSource = remember { MutableInteractionSource() }
Dialog(onDismissRequest = onDismissed) { Dialog(onDismissRequest = onDismissed) {
val focusManager = LocalFocusManager.current val focusManager = LocalFocusManager.current
Card( Card(
modifier = Modifier modifier =
.fillMaxWidth() Modifier.fillMaxWidth().clickable(
.clickable( interactionSource = interactionSource,
interactionSource = interactionSource, indication = null // Disable the ripple effect indication = null, // Disable the ripple effect
) { ) {
focusManager.clearFocus() focusManager.clearFocus()
}, },
shape = RoundedCornerShape(16.dp) shape = RoundedCornerShape(16.dp),
) { ) {
Column( Column(
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp) modifier = Modifier.padding(20.dp),
verticalArrangement = Arrangement.spacedBy(16.dp),
) { ) {
// Dialog title and subtitle. // Dialog title and subtitle.
Column { Column {
Text( Text(
title, title,
style = MaterialTheme.typography.titleLarge, style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(bottom = 8.dp) modifier = Modifier.padding(bottom = 8.dp),
) )
// Subtitle. // Subtitle.
if (subtitle.isNotEmpty()) { if (subtitle.isNotEmpty()) {
@ -126,35 +125,27 @@ fun ConfigDialog(
subtitle, subtitle,
style = labelSmallNarrow, style = labelSmallNarrow,
color = MaterialTheme.colorScheme.onSurfaceVariant, color = MaterialTheme.colorScheme.onSurfaceVariant,
modifier = Modifier.offset(y = (-6).dp) modifier = Modifier.offset(y = (-6).dp),
) )
} }
} }
// List of config rows. // List of config rows.
Column( Column(
modifier = Modifier modifier = Modifier.verticalScroll(rememberScrollState()).weight(1f, fill = false),
.verticalScroll(rememberScrollState()) verticalArrangement = Arrangement.spacedBy(16.dp),
.weight(1f, fill = false),
verticalArrangement = Arrangement.spacedBy(16.dp)
) { ) {
ConfigEditorsPanel(configs = configs, values = values) ConfigEditorsPanel(configs = configs, values = values)
} }
// Button row. // Button row.
Row( Row(
modifier = Modifier modifier = Modifier.fillMaxWidth().padding(top = 8.dp),
.fillMaxWidth()
.padding(top = 8.dp),
horizontalArrangement = Arrangement.End, horizontalArrangement = Arrangement.End,
) { ) {
// Cancel button. // Cancel button.
if (showCancel) { if (showCancel) {
TextButton( TextButton(onClick = { onDismissed() }) { Text("Cancel") }
onClick = { onDismissed() },
) {
Text("Cancel")
}
} }
// Ok button // Ok button
@ -162,7 +153,7 @@ fun ConfigDialog(
onClick = { onClick = {
Log.d(TAG, "Values from dialog: $values") Log.d(TAG, "Values from dialog: $values")
onOk(values.toMap()) onOk(values.toMap())
}, }
) { ) {
Text(okBtnLabel) Text(okBtnLabel)
} }
@ -172,9 +163,7 @@ fun ConfigDialog(
} }
} }
/** /** Composable function to display a list of config editor rows. */
* Composable function to display a list of config editor rows.
*/
@Composable @Composable
fun ConfigEditorsPanel(configs: List<Config>, values: SnapshotStateMap<String, Any>) { fun ConfigEditorsPanel(configs: List<Config>, values: SnapshotStateMap<String, Any>) {
for (config in configs) { for (config in configs) {
@ -210,11 +199,12 @@ fun LabelRow(config: LabelConfig, values: SnapshotStateMap<String, Any>) {
// Field label. // Field label.
Text(config.key.label, style = MaterialTheme.typography.titleSmall) Text(config.key.label, style = MaterialTheme.typography.titleSmall)
// Content label. // Content label.
val label = try { val label =
values[config.key.label] as String try {
} catch (e: Exception) { values[config.key.label] as String
"" } catch (e: Exception) {
} ""
}
Text(label, style = MaterialTheme.typography.bodyMedium) Text(label, style = MaterialTheme.typography.bodyMedium)
} }
} }
@ -222,9 +212,9 @@ fun LabelRow(config: LabelConfig, values: SnapshotStateMap<String, Any>) {
/** /**
* Composable function to display a number slider with an associated text input field. * Composable function to display a number slider with an associated text input field.
* *
* This function renders a row containing a slider and a text field, both used to modify * This function renders a row containing a slider and a text field, both used to modify a numeric
* a numeric value. The slider allows users to visually adjust the value within a specified range, * value. The slider allows users to visually adjust the value within a specified range, while the
* while the text field provides precise numeric input. * text field provides precise numeric input.
*/ */
@Composable @Composable
fun NumberSliderRow(config: NumberSliderConfig, values: SnapshotStateMap<String, Any>) { fun NumberSliderRow(config: NumberSliderConfig, values: SnapshotStateMap<String, Any>) {
@ -233,52 +223,50 @@ fun NumberSliderRow(config: NumberSliderConfig, values: SnapshotStateMap<String,
Text(config.key.label, style = MaterialTheme.typography.titleSmall) Text(config.key.label, style = MaterialTheme.typography.titleSmall)
// Controls row. // Controls row.
Row( Row(modifier = Modifier.fillMaxWidth(), verticalAlignment = Alignment.CenterVertically) {
modifier = Modifier.fillMaxWidth(), verticalAlignment = Alignment.CenterVertically
) {
var isFocused by remember { mutableStateOf(false) } var isFocused by remember { mutableStateOf(false) }
val focusRequester = remember { FocusRequester() } val focusRequester = remember { FocusRequester() }
// Number slider. // Number slider.
val sliderValue = try { val sliderValue =
values[config.key.label] as Float try {
} catch (e: Exception) { values[config.key.label] as Float
0f } catch (e: Exception) {
} 0f
Slider(modifier = Modifier }
.height(24.dp) Slider(
.weight(1f), modifier = Modifier.height(24.dp).weight(1f),
value = sliderValue, value = sliderValue,
valueRange = config.sliderMin..config.sliderMax, valueRange = config.sliderMin..config.sliderMax,
onValueChange = { values[config.key.label] = it }) onValueChange = { values[config.key.label] = it },
)
Spacer(modifier = Modifier.width(8.dp)) Spacer(modifier = Modifier.width(8.dp))
// Text field. // Text field.
val textFieldValue = try { val textFieldValue =
when (config.valueType) { try {
ValueType.FLOAT -> { when (config.valueType) {
"%.2f".format(values[config.key.label] as Float) ValueType.FLOAT -> {
} "%.2f".format(values[config.key.label] as Float)
}
ValueType.INT -> { ValueType.INT -> {
"${(values[config.key.label] as Float).toInt()}" "${(values[config.key.label] as Float).toInt()}"
} }
else -> { else -> {
"" ""
}
} }
} catch (e: Exception) {
""
} }
} catch (e: Exception) {
""
}
// A smaller text field. // A smaller text field.
BasicTextField( BasicTextField(
value = textFieldValue, value = textFieldValue,
modifier = Modifier modifier =
.width(80.dp) Modifier.width(80.dp).focusRequester(focusRequester).onFocusChanged {
.focusRequester(focusRequester)
.onFocusChanged {
isFocused = it.isFocused isFocused = it.isFocused
}, },
keyboardOptions = KeyboardOptions(keyboardType = KeyboardType.Number), keyboardOptions = KeyboardOptions(keyboardType = KeyboardType.Number),
@ -293,15 +281,16 @@ fun NumberSliderRow(config: NumberSliderConfig, values: SnapshotStateMap<String,
cursorBrush = SolidColor(MaterialTheme.colorScheme.onSurface), cursorBrush = SolidColor(MaterialTheme.colorScheme.onSurface),
) { innerTextField -> ) { innerTextField ->
Box( Box(
modifier = Modifier.border( modifier =
width = if (isFocused) 2.dp else 1.dp, Modifier.border(
color = if (isFocused) MaterialTheme.colorScheme.primary else MaterialTheme.colorScheme.outline, width = if (isFocused) 2.dp else 1.dp,
shape = RoundedCornerShape(4.dp) color =
) if (isFocused) MaterialTheme.colorScheme.primary
else MaterialTheme.colorScheme.outline,
shape = RoundedCornerShape(4.dp),
)
) { ) {
Box(modifier = Modifier.padding(8.dp)) { Box(modifier = Modifier.padding(8.dp)) { innerTextField() }
innerTextField()
}
} }
} }
} }
@ -311,16 +300,17 @@ fun NumberSliderRow(config: NumberSliderConfig, values: SnapshotStateMap<String,
/** /**
* Composable function to display a row with a boolean switch. * Composable function to display a row with a boolean switch.
* *
* This function renders a row containing a label and a switch, allowing users to toggle * This function renders a row containing a label and a switch, allowing users to toggle a boolean
* a boolean value. * value.
*/ */
@Composable @Composable
fun BooleanSwitchRow(config: BooleanSwitchConfig, values: SnapshotStateMap<String, Any>) { fun BooleanSwitchRow(config: BooleanSwitchConfig, values: SnapshotStateMap<String, Any>) {
val switchValue = try { val switchValue =
values[config.key.label] as Boolean try {
} catch (e: Exception) { values[config.key.label] as Boolean
false } catch (e: Exception) {
} false
}
Column(modifier = Modifier.fillMaxWidth()) { Column(modifier = Modifier.fillMaxWidth()) {
Text(config.key.label, style = MaterialTheme.typography.titleSmall) Text(config.key.label, style = MaterialTheme.typography.titleSmall)
Switch(checked = switchValue, onCheckedChange = { values[config.key.label] = it }) Switch(checked = switchValue, onCheckedChange = { values[config.key.label] = it })
@ -331,63 +321,66 @@ fun BooleanSwitchRow(config: BooleanSwitchConfig, values: SnapshotStateMap<Strin
fun SegmentedButtonRow(config: SegmentedButtonConfig, values: SnapshotStateMap<String, Any>) { fun SegmentedButtonRow(config: SegmentedButtonConfig, values: SnapshotStateMap<String, Any>) {
val selectedOptions: List<String> = remember { (values[config.key.label] as String).split(",") } val selectedOptions: List<String> = remember { (values[config.key.label] as String).split(",") }
var selectionStates: List<Boolean> by remember { var selectionStates: List<Boolean> by remember {
mutableStateOf(List(config.options.size) { index -> mutableStateOf(
selectedOptions.contains(config.options[index]) List(config.options.size) { index -> selectedOptions.contains(config.options[index]) }
}) )
} }
Column(modifier = Modifier.fillMaxWidth()) { Column(modifier = Modifier.fillMaxWidth()) {
Text(config.key.label, style = MaterialTheme.typography.titleSmall) Text(config.key.label, style = MaterialTheme.typography.titleSmall)
MultiChoiceSegmentedButtonRow { MultiChoiceSegmentedButtonRow {
config.options.forEachIndexed { index, label -> config.options.forEachIndexed { index, label ->
SegmentedButton(shape = SegmentedButtonDefaults.itemShape( SegmentedButton(
index = index, count = config.options.size shape = SegmentedButtonDefaults.itemShape(index = index, count = config.options.size),
), onCheckedChange = { onCheckedChange = {
var newSelectionStates = selectionStates.toMutableList() var newSelectionStates = selectionStates.toMutableList()
val selectedCount = newSelectionStates.count { it } val selectedCount = newSelectionStates.count { it }
// Single select. // Single select.
if (!config.allowMultiple) { if (!config.allowMultiple) {
if (!newSelectionStates[index]) { if (!newSelectionStates[index]) {
newSelectionStates = MutableList(config.options.size) { it == index } newSelectionStates = MutableList(config.options.size) { it == index }
}
} }
} // Multiple select.
// Multiple select. else {
else { if (!(selectedCount == 1 && newSelectionStates[index])) {
if (!(selectedCount == 1 && newSelectionStates[index])) { newSelectionStates[index] = !newSelectionStates[index]
newSelectionStates[index] = !newSelectionStates[index] }
} }
} selectionStates = newSelectionStates
selectionStates = newSelectionStates
values[config.key.label] = values[config.key.label] =
config.options.filterIndexed { index, option -> selectionStates[index] } config.options
.joinToString(",") .filterIndexed { index, option -> selectionStates[index] }
}, checked = selectionStates[index], label = { Text(label) }) .joinToString(",")
},
checked = selectionStates[index],
label = { Text(label) },
)
} }
} }
} }
} }
@Composable // @Composable
@Preview(showBackground = true) // @Preview(showBackground = true)
fun ConfigDialogPreview() { // fun ConfigDialogPreview() {
GalleryTheme { // GalleryTheme {
val defaultValues: MutableMap<String, Any> = mutableMapOf() // val defaultValues: MutableMap<String, Any> = mutableMapOf()
for (config in MODEL_TEST1.configs) { // for (config in MODEL_TEST1.configs) {
defaultValues[config.key.label] = config.defaultValue // defaultValues[config.key.label] = config.defaultValue
} // }
Column { // Column {
ConfigDialog( // ConfigDialog(
title = "Dialog title", // title = "Dialog title",
subtitle = "20250413", // subtitle = "20250413",
configs = MODEL_TEST1.configs, // configs = MODEL_TEST1.configs,
initialValues = defaultValues, // initialValues = defaultValues,
onDismissed = {}, // onDismissed = {},
onOk = {}, // onOk = {},
) // )
} // }
} // }
} // }

View file

@ -16,8 +16,8 @@
package com.google.ai.edge.gallery.ui.common package com.google.ai.edge.gallery.ui.common
import android.app.ActivityManager
import android.content.Intent import android.content.Intent
import android.net.Uri
import android.util.Log import android.util.Log
import androidx.activity.compose.rememberLauncherForActivityResult import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.ActivityResultLauncher import androidx.activity.result.ActivityResultLauncher
@ -51,16 +51,17 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp import androidx.compose.ui.unit.sp
import androidx.core.net.toUri
import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.Task import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.modelmanager.TokenRequestResultType import com.google.ai.edge.gallery.ui.modelmanager.TokenRequestResultType
import com.google.ai.edge.gallery.ui.modelmanager.TokenStatus import com.google.ai.edge.gallery.ui.modelmanager.TokenStatus
import java.net.HttpURLConnection
import kotlin.math.max
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext import kotlinx.coroutines.withContext
import java.net.HttpURLConnection
private const val TAG = "AGDownloadAndTryButton" private const val TAG = "AGDownloadAndTryButton"
@ -72,14 +73,14 @@ private const val TAG = "AGDownloadAndTryButton"
* Handles the "Download & Try it" button click, managing the model download process based on * Handles the "Download & Try it" button click, managing the model download process based on
* various conditions. * various conditions.
* *
* If the button is enabled and not currently checking the token, it initiates a coroutine to * If the button is enabled and not currently checking the token, it initiates a coroutine to handle
* handle the download logic. * the download logic.
* *
* For models requiring download first, it specifically addresses HuggingFace URLs by first * For models requiring download first, it specifically addresses HuggingFace URLs by first checking
* checking if authentication is necessary. If no authentication is needed, the download starts * if authentication is necessary. If no authentication is needed, the download starts directly.
* directly. Otherwise, it checks the current token status; if the token is invalid or expired, * Otherwise, it checks the current token status; if the token is invalid or expired, a token
* a token exchange flow is initiated. If a valid token exists, it attempts to access the * exchange flow is initiated. If a valid token exists, it attempts to access the download URL. If
* download URL. If access is granted, the download begins; if not, a new token is requested. * access is granted, the download begins; if not, a new token is requested.
* *
* For non-HuggingFace URLs that need downloading, the download starts directly. * For non-HuggingFace URLs that need downloading, the download starts directly.
* *
@ -102,21 +103,21 @@ fun DownloadAndTryButton(
enabled: Boolean, enabled: Boolean,
needToDownloadFirst: Boolean, needToDownloadFirst: Boolean,
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
onClicked: () -> Unit onClicked: () -> Unit,
) { ) {
val scope = rememberCoroutineScope() val scope = rememberCoroutineScope()
val context = LocalContext.current val context = LocalContext.current
var checkingToken by remember { mutableStateOf(false) } var checkingToken by remember { mutableStateOf(false) }
var showAgreementAckSheet by remember { mutableStateOf(false) } var showAgreementAckSheet by remember { mutableStateOf(false) }
var showErrorDialog by remember { mutableStateOf(false) } var showErrorDialog by remember { mutableStateOf(false) }
var showMemoryWarning by remember { mutableStateOf(false) }
val sheetState = rememberModalBottomSheetState() val sheetState = rememberModalBottomSheetState()
// A launcher for requesting notification permission. // A launcher for requesting notification permission.
val permissionLauncher = rememberLauncherForActivityResult( val permissionLauncher =
ActivityResultContracts.RequestPermission() rememberLauncherForActivityResult(ActivityResultContracts.RequestPermission()) {
) { modelManagerViewModel.downloadModel(task = task, model = model)
modelManagerViewModel.downloadModel(task = task, model = model) }
}
// Function to kick off download. // Function to kick off download.
val startDownload: (accessToken: String?) -> Unit = { accessToken -> val startDownload: (accessToken: String?) -> Unit = { accessToken ->
@ -127,64 +128,73 @@ fun DownloadAndTryButton(
launcher = permissionLauncher, launcher = permissionLauncher,
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,
task = task, task = task,
model = model model = model,
) )
checkingToken = false checkingToken = false
} }
// A launcher for opening the custom tabs intent for requesting user agreement ack. // A launcher for opening the custom tabs intent for requesting user agreement ack.
// Once the tab is closed, try starting the download process. // Once the tab is closed, try starting the download process.
val agreementAckLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult( val agreementAckLauncher: ActivityResultLauncher<Intent> =
contract = ActivityResultContracts.StartActivityForResult() rememberLauncherForActivityResult(
) { result -> contract = ActivityResultContracts.StartActivityForResult()
Log.d(TAG, "User closes the browser tab. Try to start downloading.") ) { result ->
startDownload(modelManagerViewModel.curAccessToken) Log.d(TAG, "User closes the browser tab. Try to start downloading.")
} startDownload(modelManagerViewModel.curAccessToken)
}
// A launcher for handling the authentication flow. // A launcher for handling the authentication flow.
// It processes the result of the authentication activity and then checks if a user agreement // It processes the result of the authentication activity and then checks if a user agreement
// acknowledgement is needed before proceeding with the model download. // acknowledgement is needed before proceeding with the model download.
val authResultLauncher = rememberLauncherForActivityResult( val authResultLauncher =
contract = ActivityResultContracts.StartActivityForResult() rememberLauncherForActivityResult(
) { result -> contract = ActivityResultContracts.StartActivityForResult()
modelManagerViewModel.handleAuthResult(result, onTokenRequested = { tokenRequestResult -> ) { result ->
when (tokenRequestResult.status) { modelManagerViewModel.handleAuthResult(
TokenRequestResultType.SUCCEEDED -> { result,
Log.d(TAG, "Token request succeeded. Checking if we need user to ack user agreement") onTokenRequested = { tokenRequestResult ->
scope.launch(Dispatchers.IO) { when (tokenRequestResult.status) {
// Check if we can use the current token to access model. If not, we might need to TokenRequestResultType.SUCCEEDED -> {
// acknowledge the user agreement. Log.d(TAG, "Token request succeeded. Checking if we need user to ack user agreement")
if (modelManagerViewModel.getModelUrlResponse( scope.launch(Dispatchers.IO) {
model = model, // Check if we can use the current token to access model. If not, we might need to
accessToken = modelManagerViewModel.curAccessToken // acknowledge the user agreement.
) == HttpURLConnection.HTTP_FORBIDDEN if (
) { modelManagerViewModel.getModelUrlResponse(
Log.d(TAG, "Model '${model.name}' needs user agreement ack.") model = model,
showAgreementAckSheet = true accessToken = modelManagerViewModel.curAccessToken,
} else { ) == HttpURLConnection.HTTP_FORBIDDEN
Log.d( ) {
TAG, Log.d(TAG, "Model '${model.name}' needs user agreement ack.")
"Model '${model.name}' does NOT need user agreement ack. Start downloading..." showAgreementAckSheet = true
) } else {
withContext(Dispatchers.Main) { Log.d(
startDownload(modelManagerViewModel.curAccessToken) TAG,
"Model '${model.name}' does NOT need user agreement ack. Start downloading...",
)
withContext(Dispatchers.Main) {
startDownload(modelManagerViewModel.curAccessToken)
}
}
} }
} }
TokenRequestResultType.FAILED -> {
Log.d(
TAG,
"Token request done. Error message: ${tokenRequestResult.errorMessage ?: ""}",
)
checkingToken = false
}
TokenRequestResultType.USER_CANCELLED -> {
Log.d(TAG, "User cancelled. Do nothing")
checkingToken = false
}
} }
} },
)
TokenRequestResultType.FAILED -> { }
Log.d(TAG, "Token request done. Error message: ${tokenRequestResult.errorMessage ?: ""}")
checkingToken = false
}
TokenRequestResultType.USER_CANCELLED -> {
Log.d(TAG, "User cancelled. Do nothing")
checkingToken = false
}
}
})
}
// Function to kick off the authentication and token exchange flow. // Function to kick off the authentication and token exchange flow.
val startTokenExchange = { val startTokenExchange = {
@ -213,14 +223,12 @@ fun DownloadAndTryButton(
// Check if the url needs auth. // Check if the url needs auth.
Log.d( Log.d(
TAG, TAG,
"Model '${model.name}' is from HuggingFace. Checking if the url needs auth to download" "Model '${model.name}' is from HuggingFace. Checking if the url needs auth to download",
) )
val firstResponseCode = modelManagerViewModel.getModelUrlResponse(model = model) val firstResponseCode = modelManagerViewModel.getModelUrlResponse(model = model)
if (firstResponseCode == HttpURLConnection.HTTP_OK) { if (firstResponseCode == HttpURLConnection.HTTP_OK) {
Log.d(TAG, "Model '${model.name}' doesn't need auth. Start downloading the model...") Log.d(TAG, "Model '${model.name}' doesn't need auth. Start downloading the model...")
withContext(Dispatchers.Main) { withContext(Dispatchers.Main) { startDownload(null) }
startDownload(null)
}
return@launch return@launch
} else if (firstResponseCode < 0) { } else if (firstResponseCode < 0) {
checkingToken = false checkingToken = false
@ -235,37 +243,36 @@ fun DownloadAndTryButton(
when (tokenStatusAndData.status) { when (tokenStatusAndData.status) {
// If token is not stored or expired, log in and request a new token. // If token is not stored or expired, log in and request a new token.
TokenStatus.NOT_STORED, TokenStatus.EXPIRED -> { TokenStatus.NOT_STORED,
withContext(Dispatchers.Main) { TokenStatus.EXPIRED -> {
startTokenExchange() withContext(Dispatchers.Main) { startTokenExchange() }
}
} }
// If token is still valid... // If token is still valid...
TokenStatus.NOT_EXPIRED -> { TokenStatus.NOT_EXPIRED -> {
// Use the current token to check the download url. // Use the current token to check the download url.
Log.d(TAG, "Checking the download url '${model.url}' with the current token...") Log.d(TAG, "Checking the download url '${model.url}' with the current token...")
val responseCode = modelManagerViewModel.getModelUrlResponse( val responseCode =
model = model, accessToken = tokenStatusAndData.data!!.accessToken modelManagerViewModel.getModelUrlResponse(
) model = model,
accessToken = tokenStatusAndData.data!!.accessToken,
)
if (responseCode == HttpURLConnection.HTTP_OK) { if (responseCode == HttpURLConnection.HTTP_OK) {
// Download url is accessible. Download the model. // Download url is accessible. Download the model.
Log.d(TAG, "Download url is accessible with the current token.") Log.d(TAG, "Download url is accessible with the current token.")
withContext(Dispatchers.Main) { withContext(Dispatchers.Main) {
startDownload(tokenStatusAndData.data.accessToken) startDownload(tokenStatusAndData.data!!.accessToken)
} }
} }
// Download url is NOT accessible. Request a new token. // Download url is NOT accessible. Request a new token.
else { else {
Log.d( Log.d(
TAG, TAG,
"Download url is NOT accessible. Response code: ${responseCode}. Trying to request a new token." "Download url is NOT accessible. Response code: ${responseCode}. Trying to request a new token.",
) )
withContext(Dispatchers.Main) { withContext(Dispatchers.Main) { startTokenExchange() }
startTokenExchange()
}
} }
} }
} }
@ -274,24 +281,50 @@ fun DownloadAndTryButton(
else { else {
Log.d( Log.d(
TAG, TAG,
"Model '${model.name}' is not from huggingface. Start downloading the model..." "Model '${model.name}' is not from huggingface. Start downloading the model...",
) )
withContext(Dispatchers.Main) { withContext(Dispatchers.Main) { startDownload(null) }
startDownload(null)
}
} }
} else { } else {
withContext(Dispatchers.Main) { withContext(Dispatchers.Main) {
onClicked() val activityManager =
context.getSystemService(android.app.Activity.ACTIVITY_SERVICE) as? ActivityManager
val estimatedPeakMemoryInBytes = model.estimatedPeakMemoryInBytes
val isMemoryLow =
if (activityManager != null && estimatedPeakMemoryInBytes != null) {
val memoryInfo = ActivityManager.MemoryInfo()
activityManager.getMemoryInfo(memoryInfo)
Log.d(
TAG,
"AvailMem: ${memoryInfo.availMem}. TotalMem: ${memoryInfo.totalMem}. Estimated peak memory: ${estimatedPeakMemoryInBytes}.",
)
// The device should be able to run the model if `availMem` is larger than the
// estimated peak memory. Android also has a mechanism to kill background apps to
// free up memory for the foreground app. We believe that if half of the total
// memory on the device is larger than the estimated peak memory, it can run the
// model fine with this mechanism. For example, a phone with 12GB memory can have
// very few `availMem` but will have no problem running most models.
max(memoryInfo.availMem, memoryInfo.totalMem / 2) < estimatedPeakMemoryInBytes
} else {
false
}
if (isMemoryLow) {
showMemoryWarning = true
} else {
onClicked()
}
} }
} }
} }
}, }
) { ) {
Icon( Icon(
Icons.AutoMirrored.Rounded.ArrowForward, Icons.AutoMirrored.Rounded.ArrowForward,
contentDescription = "", contentDescription = "",
modifier = Modifier.padding(end = 4.dp) modifier = Modifier.padding(end = 4.dp),
) )
val textColor = MaterialTheme.colorScheme.onPrimary val textColor = MaterialTheme.colorScheme.onPrimary
@ -301,11 +334,7 @@ fun DownloadAndTryButton(
maxLines = 1, maxLines = 1,
color = { textColor }, color = { textColor },
style = MaterialTheme.typography.bodyMedium, style = MaterialTheme.typography.bodyMedium,
autoSize = TextAutoSize.StepBased( autoSize = TextAutoSize.StepBased(minFontSize = 8.sp, maxFontSize = 14.sp, stepSize = 1.sp),
minFontSize = 8.sp,
maxFontSize = 14.sp,
stepSize = 1.sp
)
) )
} else { } else {
if (needToDownloadFirst) { if (needToDownloadFirst) {
@ -314,11 +343,8 @@ fun DownloadAndTryButton(
maxLines = 1, maxLines = 1,
color = { textColor }, color = { textColor },
style = MaterialTheme.typography.bodyMedium, style = MaterialTheme.typography.bodyMedium,
autoSize = TextAutoSize.StepBased( autoSize =
minFontSize = 8.sp, TextAutoSize.StepBased(minFontSize = 8.sp, maxFontSize = 14.sp, stepSize = 1.sp),
maxFontSize = 14.sp,
stepSize = 1.sp
)
) )
} else { } else {
Text("Try it", maxLines = 1) Text("Try it", maxLines = 1)
@ -341,28 +367,30 @@ fun DownloadAndTryButton(
) { ) {
Column( Column(
horizontalAlignment = Alignment.CenterHorizontally, horizontalAlignment = Alignment.CenterHorizontally,
modifier = Modifier.padding(horizontal = 16.dp) modifier = Modifier.padding(horizontal = 16.dp),
) { ) {
Text("Acknowledge user agreement", style = MaterialTheme.typography.titleLarge) Text("Acknowledge user agreement", style = MaterialTheme.typography.titleLarge)
Text( Text(
"This is a gated model. Please click the button below to view and agree to the user agreement. After accepting, simply close that tab to proceed with the model download.", "This is a gated model. Please click the button below to view and agree to the user agreement. After accepting, simply close that tab to proceed with the model download.",
style = MaterialTheme.typography.bodyMedium, style = MaterialTheme.typography.bodyMedium,
modifier = Modifier.padding(vertical = 16.dp) modifier = Modifier.padding(vertical = 16.dp),
) )
Button(onClick = { Button(
// Get agreement url from model url. onClick = {
val index = model.url.indexOf("/resolve/") // Get agreement url from model url.
// Show it in a tab. val index = model.url.indexOf("/resolve/")
if (index >= 0) { // Show it in a tab.
val agreementUrl = model.url.substring(0, index) if (index >= 0) {
val agreementUrl = model.url.substring(0, index)
val customTabsIntent = CustomTabsIntent.Builder().build() val customTabsIntent = CustomTabsIntent.Builder().build()
customTabsIntent.intent.setData(Uri.parse(agreementUrl)) customTabsIntent.intent.setData(agreementUrl.toUri())
agreementAckLauncher.launch(customTabsIntent.intent) agreementAckLauncher.launch(customTabsIntent.intent)
}
// Dismiss the sheet.
showAgreementAckSheet = false
} }
// Dismiss the sheet. ) {
showAgreementAckSheet = false
}) {
Text("Open user agreement") Text("Open user agreement")
} }
} }
@ -374,24 +402,34 @@ fun DownloadAndTryButton(
icon = { icon = {
Icon(Icons.Rounded.Error, contentDescription = "", tint = MaterialTheme.colorScheme.error) Icon(Icons.Rounded.Error, contentDescription = "", tint = MaterialTheme.colorScheme.error)
}, },
title = { title = { Text("Unknown network error") },
Text("Unknown network error") text = { Text("Please check your internet connection.") },
}, onDismissRequest = { showErrorDialog = false },
confirmButton = { TextButton(onClick = { showErrorDialog = false }) { Text("Close") } },
)
}
if (showMemoryWarning) {
AlertDialog(
title = { Text("Memory Warning") },
text = { text = {
Text("Please check your internet connection.") Text(
}, "This model might need more memory than your device has available. " +
onDismissRequest = { "Running it could cause the app to crash."
showErrorDialog = false )
}, },
onDismissRequest = { showMemoryWarning = false },
confirmButton = { confirmButton = {
TextButton( TextButton(
onClick = { onClick = {
showErrorDialog = false onClicked()
showMemoryWarning = false
} }
) { ) {
Text("Close") Text("Continue")
} }
}, },
dismissButton = { TextButton(onClick = { showMemoryWarning = false }) { Text("Cancel") } },
) )
} }
} }

View file

@ -33,18 +33,17 @@ import androidx.compose.ui.window.Dialog
@Composable @Composable
fun ErrorDialog(error: String, onDismiss: () -> Unit) { fun ErrorDialog(error: String, onDismiss: () -> Unit) {
Dialog( Dialog(onDismissRequest = onDismiss) {
onDismissRequest = onDismiss
) {
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) { Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
Column( Column(
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp) modifier = Modifier.padding(20.dp),
verticalArrangement = Arrangement.spacedBy(16.dp),
) { ) {
// Title // Title
Text( Text(
"Error", "Error",
style = MaterialTheme.typography.titleLarge, style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(bottom = 8.dp) modifier = Modifier.padding(bottom = 8.dp),
) )
// Error // Error
@ -55,11 +54,7 @@ fun ErrorDialog(error: String, onDismiss: () -> Unit) {
) )
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.End) { Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.End) {
Button( Button(onClick = onDismiss) { Text("Close") }
onClick = onDismiss
) {
Text("Close")
}
} }
} }
} }

View file

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ProvideTextStyle import androidx.compose.material3.ProvideTextStyle
@ -25,8 +25,6 @@ import androidx.compose.ui.text.SpanStyle
import androidx.compose.ui.text.TextLinkStyles import androidx.compose.ui.text.TextLinkStyles
import androidx.compose.ui.text.TextStyle import androidx.compose.ui.text.TextStyle
import androidx.compose.ui.text.font.FontFamily import androidx.compose.ui.text.font.FontFamily
import androidx.compose.ui.tooling.preview.Preview
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.customColors import com.google.ai.edge.gallery.ui.theme.customColors
import com.halilibo.richtext.commonmark.Markdown import com.halilibo.richtext.commonmark.Markdown
import com.halilibo.richtext.ui.CodeBlockStyle import com.halilibo.richtext.ui.CodeBlockStyle
@ -34,52 +32,43 @@ import com.halilibo.richtext.ui.RichTextStyle
import com.halilibo.richtext.ui.material3.RichText import com.halilibo.richtext.ui.material3.RichText
import com.halilibo.richtext.ui.string.RichTextStringStyle import com.halilibo.richtext.ui.string.RichTextStringStyle
/** /** Composable function to display Markdown-formatted text. */
* Composable function to display Markdown-formatted text.
*/
@Composable @Composable
fun MarkdownText( fun MarkdownText(text: String, modifier: Modifier = Modifier, smallFontSize: Boolean = false) {
text: String,
modifier: Modifier = Modifier,
smallFontSize: Boolean = false
) {
val fontSize = val fontSize =
if (smallFontSize) MaterialTheme.typography.bodyMedium.fontSize else MaterialTheme.typography.bodyLarge.fontSize if (smallFontSize) MaterialTheme.typography.bodyMedium.fontSize
else MaterialTheme.typography.bodyLarge.fontSize
CompositionLocalProvider { CompositionLocalProvider {
ProvideTextStyle( ProvideTextStyle(value = TextStyle(fontSize = fontSize, lineHeight = fontSize * 1.3)) {
value = TextStyle(
fontSize = fontSize,
lineHeight = fontSize * 1.3,
)
) {
RichText( RichText(
modifier = modifier, modifier = modifier,
style = RichTextStyle( style =
codeBlockStyle = CodeBlockStyle( RichTextStyle(
textStyle = TextStyle( codeBlockStyle =
fontSize = MaterialTheme.typography.bodySmall.fontSize, CodeBlockStyle(
fontFamily = FontFamily.Monospace, textStyle =
) TextStyle(
fontSize = MaterialTheme.typography.bodySmall.fontSize,
fontFamily = FontFamily.Monospace,
)
),
stringStyle =
RichTextStringStyle(
linkStyle =
TextLinkStyles(style = SpanStyle(color = MaterialTheme.customColors.linkColor))
),
), ),
stringStyle = RichTextStringStyle(
linkStyle = TextLinkStyles(
style = SpanStyle(color = MaterialTheme.customColors.linkColor)
)
)
),
) { ) {
Markdown( Markdown(content = text)
content = text
)
} }
} }
} }
} }
@Preview(showBackground = true) // @Preview(showBackground = true)
@Composable // @Composable
fun MarkdownTextPreview() { // fun MarkdownTextPreview() {
GalleryTheme { // GalleryTheme {
MarkdownText(text = "*Hello World*\n**Good morning!!**") // MarkdownText(text = "*Hello World*\n**Good morning!!**")
} // }
} // }

View file

@ -53,7 +53,7 @@ import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.ModelDownloadStatusType import com.google.ai.edge.gallery.data.ModelDownloadStatusType
import com.google.ai.edge.gallery.data.Task import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.common.chat.ConfigDialog import com.google.ai.edge.gallery.data.convertValueToTargetType
import com.google.ai.edge.gallery.ui.modelmanager.ModelInitializationStatusType import com.google.ai.edge.gallery.ui.modelmanager.ModelInitializationStatusType
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
@ -71,54 +71,54 @@ fun ModelPageAppBar(
isResettingSession: Boolean = false, isResettingSession: Boolean = false,
onResetSessionClicked: (Model) -> Unit = {}, onResetSessionClicked: (Model) -> Unit = {},
canShowResetSessionButton: Boolean = false, canShowResetSessionButton: Boolean = false,
onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit = { _, _ -> }, onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit =
{ _, _ ->
},
) { ) {
var showConfigDialog by remember { mutableStateOf(false) } var showConfigDialog by remember { mutableStateOf(false) }
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val context = LocalContext.current val context = LocalContext.current
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[model.name] val curDownloadStatus = modelManagerUiState.modelDownloadStatus[model.name]
val modelInitializationStatus = val modelInitializationStatus = modelManagerUiState.modelInitializationStatus[model.name]
modelManagerUiState.modelInitializationStatus[model.name]
CenterAlignedTopAppBar(title = { CenterAlignedTopAppBar(
Column( title = {
horizontalAlignment = Alignment.CenterHorizontally, Column(
verticalArrangement = Arrangement.spacedBy(4.dp) horizontalAlignment = Alignment.CenterHorizontally,
) { verticalArrangement = Arrangement.spacedBy(4.dp),
// Task type.
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp)
) { ) {
Icon( // Task type.
task.icon ?: ImageVector.vectorResource(task.iconVectorResourceId!!), Row(
tint = getTaskIconColor(task = task), verticalAlignment = Alignment.CenterVertically,
modifier = Modifier.size(16.dp), horizontalArrangement = Arrangement.spacedBy(6.dp),
contentDescription = "", ) {
) Icon(
Text( task.icon ?: ImageVector.vectorResource(task.iconVectorResourceId!!),
task.type.label, tint = getTaskIconColor(task = task),
style = MaterialTheme.typography.titleMedium.copy(fontWeight = FontWeight.SemiBold), modifier = Modifier.size(16.dp),
color = getTaskIconColor(task = task) contentDescription = "",
)
Text(
task.type.label,
style = MaterialTheme.typography.titleMedium.copy(fontWeight = FontWeight.SemiBold),
color = getTaskIconColor(task = task),
)
}
// Model chips pager.
ModelPickerChipsPager(
task = task,
initialModel = model,
modelManagerViewModel = modelManagerViewModel,
onModelSelected = onModelSelected,
) )
} }
},
// Model chips pager. modifier = modifier,
ModelPickerChipsPager(
task = task,
initialModel = model,
modelManagerViewModel = modelManagerViewModel,
onModelSelected = onModelSelected,
)
}
}, modifier = modifier,
// The back button. // The back button.
navigationIcon = { navigationIcon = {
IconButton(onClick = onBackClicked) { IconButton(onClick = onBackClicked) {
Icon( Icon(imageVector = Icons.AutoMirrored.Rounded.ArrowBack, contentDescription = "")
imageVector = Icons.AutoMirrored.Rounded.ArrowBack,
contentDescription = "",
)
} }
}, },
// The config button for the model (if existed). // The config button for the model (if existed).
@ -136,19 +136,16 @@ fun ModelPageAppBar(
if (showConfigButton) { if (showConfigButton) {
val enableConfigButton = !isModelInitializing && !inProgress val enableConfigButton = !isModelInitializing && !inProgress
IconButton( IconButton(
onClick = { onClick = { showConfigDialog = true },
showConfigDialog = true
},
enabled = enableConfigButton, enabled = enableConfigButton,
modifier = Modifier modifier =
.offset(x = configButtonOffset) Modifier.offset(x = configButtonOffset).alpha(if (!enableConfigButton) 0.5f else 1f),
.alpha(if (!enableConfigButton) 0.5f else 1f)
) { ) {
Icon( Icon(
imageVector = Icons.Rounded.Tune, imageVector = Icons.Rounded.Tune,
contentDescription = "", contentDescription = "",
tint = MaterialTheme.colorScheme.primary, tint = MaterialTheme.colorScheme.primary,
modifier = Modifier.size(20.dp) modifier = Modifier.size(20.dp),
) )
} }
} }
@ -157,39 +154,35 @@ fun ModelPageAppBar(
CircularProgressIndicator( CircularProgressIndicator(
trackColor = MaterialTheme.colorScheme.surfaceVariant, trackColor = MaterialTheme.colorScheme.surfaceVariant,
strokeWidth = 2.dp, strokeWidth = 2.dp,
modifier = Modifier.size(16.dp) modifier = Modifier.size(16.dp),
) )
} else { } else {
val enableResetButton = !isModelInitializing && !modelPreparing val enableResetButton = !isModelInitializing && !modelPreparing
IconButton( IconButton(
onClick = { onClick = { onResetSessionClicked(model) },
onResetSessionClicked(model)
},
enabled = enableResetButton, enabled = enableResetButton,
modifier = Modifier modifier = Modifier.alpha(if (!enableResetButton) 0.5f else 1f),
.alpha(if (!enableResetButton) 0.5f else 1f)
) { ) {
Box( Box(
modifier = Modifier modifier =
.size(32.dp) Modifier.size(32.dp)
.clip(CircleShape) .clip(CircleShape)
.background(MaterialTheme.colorScheme.surfaceContainer), .background(MaterialTheme.colorScheme.surfaceContainer),
contentAlignment = Alignment.Center contentAlignment = Alignment.Center,
) { ) {
Icon( Icon(
imageVector = Icons.Rounded.MapsUgc, imageVector = Icons.Rounded.MapsUgc,
contentDescription = "", contentDescription = "",
tint = MaterialTheme.colorScheme.primary, tint = MaterialTheme.colorScheme.primary,
modifier = Modifier modifier = Modifier.size(20.dp),
.size(20.dp)
) )
} }
} }
} }
} }
} }
},
}) )
// Config dialog. // Config dialog.
if (showConfigDialog) { if (showConfigDialog) {
@ -208,12 +201,16 @@ fun ModelPageAppBar(
var needReinitialization = false var needReinitialization = false
for (config in model.configs) { for (config in model.configs) {
val key = config.key.label val key = config.key.label
val oldValue = convertValueToTargetType( val oldValue =
value = model.configValues.getValue(key), valueType = config.valueType convertValueToTargetType(
) value = model.configValues.getValue(key),
val newValue = convertValueToTargetType( valueType = config.valueType,
value = curConfigValues.getValue(key), valueType = config.valueType )
) val newValue =
convertValueToTargetType(
value = curConfigValues.getValue(key),
valueType = config.valueType,
)
if (oldValue != newValue) { if (oldValue != newValue) {
same = false same = false
if (config.needReinitialization) { if (config.needReinitialization) {
@ -233,7 +230,10 @@ fun ModelPageAppBar(
// Force to re-initialize the model with the new configs. // Force to re-initialize the model with the new configs.
if (needReinitialization) { if (needReinitialization) {
modelManagerViewModel.initializeModel( modelManagerViewModel.initializeModel(
context = context, task = task, model = model, force = true context = context,
task = task,
model = model,
force = true,
) )
} }
@ -242,4 +242,4 @@ fun ModelPageAppBar(
}, },
) )
} }
} }

View file

@ -16,6 +16,11 @@
package com.google.ai.edge.gallery.ui.common package com.google.ai.edge.gallery.ui.common
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
// import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.background import androidx.compose.foundation.background
import androidx.compose.foundation.clickable import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
@ -28,7 +33,6 @@ import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.width import androidx.compose.foundation.layout.width
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.CheckCircle import androidx.compose.material.icons.filled.CheckCircle
import androidx.compose.material.icons.outlined.CheckCircle
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text import androidx.compose.material3.Text
@ -39,34 +43,27 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.vector.ImageVector import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.vectorResource import androidx.compose.ui.res.vectorResource
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp import androidx.compose.ui.unit.sp
import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.Task import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.common.modelitem.StatusIcon import com.google.ai.edge.gallery.ui.common.modelitem.StatusIcon
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.labelSmallNarrow import com.google.ai.edge.gallery.ui.theme.labelSmallNarrow
@Composable @Composable
fun ModelPicker( fun ModelPicker(
task: Task, task: Task,
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
onModelSelected: (Model) -> Unit onModelSelected: (Model) -> Unit,
) { ) {
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
Column(modifier = Modifier.padding(bottom = 8.dp)) { Column(modifier = Modifier.padding(bottom = 8.dp)) {
// Title // Title
Row( Row(
modifier = Modifier modifier = Modifier.padding(horizontal = 16.dp).padding(top = 4.dp, bottom = 4.dp),
.padding(horizontal = 16.dp)
.padding(top = 4.dp, bottom = 4.dp),
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(8.dp), horizontalArrangement = Arrangement.spacedBy(8.dp),
) { ) {
@ -90,51 +87,47 @@ fun ModelPicker(
Row( Row(
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.SpaceBetween, horizontalArrangement = Arrangement.SpaceBetween,
modifier = Modifier modifier =
.fillMaxWidth() Modifier.fillMaxWidth()
.clickable { .clickable { onModelSelected(model) }
onModelSelected(model) .background(
} if (selected) MaterialTheme.colorScheme.surfaceContainer else Color.Transparent
.background(if (selected) MaterialTheme.colorScheme.surfaceContainer else Color.Transparent) )
.padding(horizontal = 16.dp, vertical = 8.dp), .padding(horizontal = 16.dp, vertical = 8.dp),
) { ) {
Spacer(modifier = Modifier.width(24.dp)) Spacer(modifier = Modifier.width(24.dp))
Column(modifier = Modifier.weight(1f)) { Column(modifier = Modifier.weight(1f)) {
Text(model.name, style = MaterialTheme.typography.bodyMedium) Text(model.name, style = MaterialTheme.typography.bodyMedium)
Row( Row(
horizontalArrangement = Arrangement.spacedBy(4.dp), horizontalArrangement = Arrangement.spacedBy(4.dp),
verticalAlignment = Alignment.CenterVertically verticalAlignment = Alignment.CenterVertically,
) { ) {
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name]) StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
Text( Text(
model.sizeInBytes.humanReadableSize(), model.sizeInBytes.humanReadableSize(),
color = MaterialTheme.colorScheme.secondary, color = MaterialTheme.colorScheme.secondary,
style = labelSmallNarrow.copy(lineHeight = 10.sp) style = labelSmallNarrow.copy(lineHeight = 10.sp),
) )
} }
} }
if (selected) { if (selected) {
Icon( Icon(Icons.Filled.CheckCircle, modifier = Modifier.size(16.dp), contentDescription = "")
Icons.Filled.CheckCircle,
modifier = Modifier.size(16.dp),
contentDescription = ""
)
} }
} }
} }
} }
} }
@Preview(showBackground = true) // @Preview(showBackground = true)
@Composable // @Composable
fun ModelPickerPreview() { // fun ModelPickerPreview() {
val context = LocalContext.current // val context = LocalContext.current
GalleryTheme { // GalleryTheme {
ModelPicker( // ModelPicker(
task = TASK_TEST1, // task = TASK_TEST1,
modelManagerViewModel = PreviewModelManagerViewModel(context = context), // modelManagerViewModel = PreviewModelManagerViewModel(context = context),
onModelSelected = {}, // onModelSelected = {},
) // )
} // }
} // }

View file

@ -64,9 +64,9 @@ import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.common.modelitem.StatusIcon import com.google.ai.edge.gallery.ui.common.modelitem.StatusIcon
import com.google.ai.edge.gallery.ui.modelmanager.ModelInitializationStatusType import com.google.ai.edge.gallery.ui.modelmanager.ModelInitializationStatusType
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import kotlin.math.absoluteValue
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlin.math.absoluteValue
@OptIn(ExperimentalMaterial3Api::class) @OptIn(ExperimentalMaterial3Api::class)
@Composable @Composable
@ -83,14 +83,13 @@ fun ModelPickerChipsPager(
val scope = rememberCoroutineScope() val scope = rememberCoroutineScope()
val density = LocalDensity.current val density = LocalDensity.current
val windowInfo = LocalWindowInfo.current val windowInfo = LocalWindowInfo.current
val screenWidthDp = remember { val screenWidthDp = remember { with(density) { windowInfo.containerSize.width.toDp() } }
with(density) {
windowInfo.containerSize.width.toDp()
}
}
val pagerState = rememberPagerState(initialPage = task.models.indexOf(initialModel), val pagerState =
pageCount = { task.models.size }) rememberPagerState(
initialPage = task.models.indexOf(initialModel),
pageCount = { task.models.size },
)
// Sync scrolling. // Sync scrolling.
LaunchedEffect(modelManagerViewModel.pagerScrollState) { LaunchedEffect(modelManagerViewModel.pagerScrollState) {
@ -107,56 +106,51 @@ fun ModelPickerChipsPager(
((pagerState.currentPage - pageIndex) + pagerState.currentPageOffsetFraction).absoluteValue ((pagerState.currentPage - pageIndex) + pagerState.currentPageOffsetFraction).absoluteValue
val curAlpha = 1f - (pageOffset * 1.5f).coerceIn(0f, 1f) val curAlpha = 1f - (pageOffset * 1.5f).coerceIn(0f, 1f)
val modelInitializationStatus = val modelInitializationStatus = modelManagerUiState.modelInitializationStatus[model.name]
modelManagerUiState.modelInitializationStatus[model.name]
Box( Box(
modifier = Modifier modifier = Modifier.fillMaxWidth().graphicsLayer { alpha = curAlpha },
.fillMaxWidth() contentAlignment = Alignment.Center,
.graphicsLayer { alpha = curAlpha },
contentAlignment = Alignment.Center
) { ) {
Row( Row(
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(2.dp) horizontalArrangement = Arrangement.spacedBy(2.dp),
) { ) {
Row(verticalAlignment = Alignment.CenterVertically, Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(2.dp), horizontalArrangement = Arrangement.spacedBy(2.dp),
modifier = Modifier modifier =
.clip(CircleShape) Modifier.clip(CircleShape)
.background(MaterialTheme.colorScheme.surfaceContainerHigh) .background(MaterialTheme.colorScheme.surfaceContainerHigh)
.clickable { .clickable {
modelPickerModel = model modelPickerModel = model
showModelPicker = true showModelPicker = true
} }
.padding(start = 8.dp, end = 2.dp) .padding(start = 8.dp, end = 2.dp)
.padding(vertical = 4.dp)) Inner@{ .padding(vertical = 4.dp),
) Inner@{
Box(contentAlignment = Alignment.Center, modifier = Modifier.size(21.dp)) { Box(contentAlignment = Alignment.Center, modifier = Modifier.size(21.dp)) {
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name]) StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
this@Inner.AnimatedVisibility( this@Inner.AnimatedVisibility(
visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING, visible =
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
enter = scaleIn() + fadeIn(), enter = scaleIn() + fadeIn(),
exit = scaleOut() + fadeOut(), exit = scaleOut() + fadeOut(),
) { ) {
// Circular progress indicator. // Circular progress indicator.
CircularProgressIndicator( CircularProgressIndicator(
modifier = Modifier modifier = Modifier.size(24.dp).alpha(0.5f),
.size(24.dp)
.alpha(0.5f),
strokeWidth = 2.dp, strokeWidth = 2.dp,
color = MaterialTheme.colorScheme.onSurfaceVariant color = MaterialTheme.colorScheme.onSurfaceVariant,
) )
} }
} }
Text( Text(
model.name, model.name,
style = MaterialTheme.typography.labelLarge, style = MaterialTheme.typography.labelLarge,
modifier = Modifier modifier = Modifier.padding(start = 4.dp).widthIn(0.dp, screenWidthDp - 250.dp),
.padding(start = 4.dp)
.widthIn(0.dp, screenWidthDp - 250.dp),
maxLines = 1, maxLines = 1,
overflow = TextOverflow.MiddleEllipsis overflow = TextOverflow.MiddleEllipsis,
) )
Icon( Icon(
Icons.Rounded.ArrowDropDown, Icons.Rounded.ArrowDropDown,
@ -171,10 +165,7 @@ fun ModelPickerChipsPager(
// Model picker. // Model picker.
val curModelPickerModel = modelPickerModel val curModelPickerModel = modelPickerModel
if (showModelPicker && curModelPickerModel != null) { if (showModelPicker && curModelPickerModel != null) {
ModalBottomSheet( ModalBottomSheet(onDismissRequest = { showModelPicker = false }, sheetState = sheetState) {
onDismissRequest = { showModelPicker = false },
sheetState = sheetState,
) {
ModelPicker( ModelPicker(
task = task, task = task,
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,
@ -187,8 +178,8 @@ fun ModelPickerChipsPager(
} }
onModelSelected(selectedModel) onModelSelected(selectedModel)
} },
) )
} }
} }
} }

View file

@ -52,28 +52,22 @@ private val SHAPES: List<Int> =
listOf(R.drawable.pantegon, R.drawable.double_circle, R.drawable.circle, R.drawable.four_circle) listOf(R.drawable.pantegon, R.drawable.double_circle, R.drawable.circle, R.drawable.four_circle)
/** /**
* Composable that displays an icon representing a task. It consists of a background * Composable that displays an icon representing a task. It consists of a background image and a
* image and a foreground icon, both centered within a square box. * foreground icon, both centered within a square box.
*/ */
@Composable @Composable
fun TaskIcon(task: Task, modifier: Modifier = Modifier, width: Dp = 56.dp) { fun TaskIcon(task: Task, modifier: Modifier = Modifier, width: Dp = 56.dp) {
Box( Box(modifier = modifier.width(width).aspectRatio(1f), contentAlignment = Alignment.Center) {
modifier = modifier
.width(width)
.aspectRatio(1f),
contentAlignment = Alignment.Center,
) {
Image( Image(
painter = getTaskIconBgShape(task = task), painter = getTaskIconBgShape(task = task),
contentDescription = "", contentDescription = "",
modifier = Modifier modifier = Modifier.fillMaxSize().alpha(0.6f),
.fillMaxSize()
.alpha(0.6f),
contentScale = ContentScale.Fit, contentScale = ContentScale.Fit,
colorFilter = ColorFilter.tint( colorFilter =
MaterialTheme.customColors.taskIconShapeBgColor, ColorFilter.tint(
blendMode = BlendMode.SrcIn MaterialTheme.customColors.taskIconShapeBgColor,
) blendMode = BlendMode.SrcIn,
),
) )
Icon( Icon(
task.icon ?: ImageVector.vectorResource(task.iconVectorResourceId!!), task.icon ?: ImageVector.vectorResource(task.iconVectorResourceId!!),
@ -102,4 +96,4 @@ fun TaskIconPreview() {
TaskIcon(task = TASK_LLM_CHAT, width = 80.dp) TaskIcon(task = TASK_LLM_CHAT, width = 80.dp)
} }
} }
} }

View file

@ -21,63 +21,18 @@ import android.content.Context
import android.content.pm.PackageManager import android.content.pm.PackageManager
import android.net.Uri import android.net.Uri
import android.os.Build import android.os.Build
import android.util.Log
import androidx.activity.compose.ManagedActivityResultLauncher import androidx.activity.compose.ManagedActivityResultLauncher
import androidx.compose.material3.MaterialTheme
import androidx.compose.runtime.Composable
import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.Color
import androidx.core.content.ContextCompat import androidx.core.content.ContextCompat
import androidx.core.content.FileProvider import androidx.core.content.FileProvider
import com.google.ai.edge.gallery.data.Config
import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.TASKS
import com.google.ai.edge.gallery.data.Task import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.data.ValueType
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkResult
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageType
import com.google.ai.edge.gallery.ui.common.chat.ChatViewModel
import com.google.ai.edge.gallery.ui.common.chat.Histogram
import com.google.ai.edge.gallery.ui.common.chat.Stat
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.theme.customColors
import com.google.gson.Gson
import com.google.gson.reflect.TypeToken
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.json.Json
import java.io.File import java.io.File
import java.net.HttpURLConnection
import java.net.URL
import kotlin.math.abs
import kotlin.math.ln import kotlin.math.ln
import kotlin.math.max
import kotlin.math.min
import kotlin.math.pow import kotlin.math.pow
import kotlin.math.sqrt
private const val TAG = "AGUtils" private const val TAG = "AGUtils"
private const val LAUNCH_INFO_FILE_NAME = "launch_info"
private val STATS = listOf(
Stat(id = "min", label = "Min", unit = "ms"),
Stat(id = "max", label = "Max", unit = "ms"),
Stat(id = "avg", label = "Avg", unit = "ms"),
Stat(id = "stddev", label = "Stddev", unit = "ms")
)
interface LatencyProvider {
val latencyMs: Float
}
private const val START_THINKING = "***Thinking...***"
private const val DONE_THINKING = "***Done thinking***"
data class JsonObjAndTextContent<T>(
val jsonObj: T, val textContent: String,
)
data class LaunchInfo(
val ts: Long
)
/** Format the bytes into a human-readable format. */ /** Format the bytes into a human-readable format. */
fun Long.humanReadableSize(si: Boolean = true, extraDecimalForGbAndAbove: Boolean = false): String { fun Long.humanReadableSize(si: Boolean = true, extraDecimalForGbAndAbove: Boolean = false): String {
@ -139,320 +94,56 @@ fun Long.formatToHourMinSecond(): String {
return parts.joinToString(" ") return parts.joinToString(" ")
} }
fun convertValueToTargetType(value: Any, valueType: ValueType): Any {
return when (valueType) {
ValueType.INT -> when (value) {
is Int -> value
is Float -> value.toInt()
is Double -> value.toInt()
is String -> value.toIntOrNull() ?: ""
is Boolean -> if (value) 1 else 0
else -> ""
}
ValueType.FLOAT -> when (value) {
is Int -> value.toFloat()
is Float -> value
is Double -> value.toFloat()
is String -> value.toFloatOrNull() ?: ""
is Boolean -> if (value) 1f else 0f
else -> ""
}
ValueType.DOUBLE -> when (value) {
is Int -> value.toDouble()
is Float -> value.toDouble()
is Double -> value
is String -> value.toDoubleOrNull() ?: ""
is Boolean -> if (value) 1.0 else 0.0
else -> ""
}
ValueType.BOOLEAN -> when (value) {
is Int -> value == 0
is Boolean -> value
is Float -> abs(value) > 1e-6
is Double -> abs(value) > 1e-6
is String -> value.isNotEmpty()
else -> false
}
ValueType.STRING -> value.toString()
}
}
fun getDistinctiveColor(index: Int): Color { fun getDistinctiveColor(index: Int): Color {
val colors = listOf( val colors =
// Color(0xffe6194b), listOf(
Color(0xff3cb44b), // Color(0xffe6194b),
Color(0xffffe119), Color(0xff3cb44b),
Color(0xff4363d8), Color(0xffffe119),
Color(0xfff58231), Color(0xff4363d8),
Color(0xff911eb4), Color(0xfff58231),
Color(0xff46f0f0), Color(0xff911eb4),
Color(0xfff032e6), Color(0xff46f0f0),
Color(0xffbcf60c), Color(0xfff032e6),
Color(0xfffabebe), Color(0xffbcf60c),
Color(0xff008080), Color(0xfffabebe),
Color(0xffe6beff), Color(0xff008080),
Color(0xff9a6324), Color(0xffe6beff),
Color(0xfffffac8), Color(0xff9a6324),
Color(0xff800000), Color(0xfffffac8),
Color(0xffaaffc3), Color(0xff800000),
Color(0xff808000), Color(0xffaaffc3),
Color(0xffffd8b1), Color(0xff808000),
Color(0xff000075) Color(0xffffd8b1),
) Color(0xff000075),
)
return colors[index % colors.size] return colors[index % colors.size]
} }
fun Context.createTempPictureUri( fun Context.createTempPictureUri(
fileName: String = "picture_${System.currentTimeMillis()}", fileExtension: String = ".png" fileName: String = "picture_${System.currentTimeMillis()}",
fileExtension: String = ".png",
): Uri { ): Uri {
val tempFile = File.createTempFile( val tempFile = File.createTempFile(fileName, fileExtension, cacheDir).apply { createNewFile() }
fileName, fileExtension, cacheDir
).apply {
createNewFile()
}
return FileProvider.getUriForFile( return FileProvider.getUriForFile(
applicationContext, applicationContext,
"com.google.aiedge.gallery.provider" /* {applicationId}.provider */, "com.google.aiedge.gallery.provider" /* {applicationId}.provider */,
tempFile tempFile,
) )
} }
fun runBasicBenchmark(
model: Model,
warmupCount: Int,
iterations: Int,
chatViewModel: ChatViewModel,
inferenceFn: () -> LatencyProvider,
chatMessageType: ChatMessageType,
) {
val start = System.currentTimeMillis()
var lastUpdateTs = 0L
val update: (ChatMessageBenchmarkResult) -> Unit = { message ->
if (lastUpdateTs == 0L) {
chatViewModel.addMessage(
model = model,
message = message,
)
lastUpdateTs = System.currentTimeMillis()
} else {
val curTs = System.currentTimeMillis()
if (curTs - lastUpdateTs > 500) {
chatViewModel.replaceLastMessage(model = model, message = message, type = chatMessageType)
lastUpdateTs = curTs
}
}
}
// Warmup.
val latencies: MutableList<Float> = mutableListOf()
for (count in 1..warmupCount) {
inferenceFn()
update(
ChatMessageBenchmarkResult(
orderedStats = STATS,
statValues = calculateStats(min = 0f, max = 0f, sum = 0f, latencies = latencies),
histogram = calculateLatencyHistogram(
latencies = latencies, min = 0f, max = 0f, avg = 0f
),
values = latencies,
warmupCurrent = count,
warmupTotal = warmupCount,
iterationCurrent = 0,
iterationTotal = iterations,
latencyMs = (System.currentTimeMillis() - start).toFloat(),
highlightStat = "avg"
)
)
}
// Benchmark iterations.
var min = Float.MAX_VALUE
var max = 0f
var sum = 0f
for (count in 1..iterations) {
val result = inferenceFn()
val latency = result.latencyMs
min = min(min, latency)
max = max(max, latency)
sum += latency
latencies.add(latency)
val curTs = System.currentTimeMillis()
if (curTs - lastUpdateTs > 500 || count == iterations) {
lastUpdateTs = curTs
val stats = calculateStats(min = min, max = max, sum = sum, latencies = latencies)
chatViewModel.replaceLastMessage(
model = model,
message = ChatMessageBenchmarkResult(
orderedStats = STATS,
statValues = stats,
histogram = calculateLatencyHistogram(
latencies = latencies,
min = min,
max = max,
avg = stats["avg"] ?: 0f,
),
values = latencies,
warmupCurrent = warmupCount,
warmupTotal = warmupCount,
iterationCurrent = count,
iterationTotal = iterations,
latencyMs = (System.currentTimeMillis() - start).toFloat(),
highlightStat = "avg"
),
type = chatMessageType,
)
}
// Go through other benchmark messages and update their buckets for the common min/max values.
var allMin = Float.MAX_VALUE
var allMax = 0f
val allMessages = chatViewModel.uiState.value.messagesByModel[model.name] ?: listOf()
for (message in allMessages) {
if (message is ChatMessageBenchmarkResult) {
val curMin = message.statValues["min"] ?: 0f
val curMax = message.statValues["max"] ?: 0f
allMin = min(allMin, curMin)
allMax = max(allMax, curMax)
}
}
for ((index, message) in allMessages.withIndex()) {
if (message === allMessages.last()) {
break
}
if (message is ChatMessageBenchmarkResult) {
val updatedMessage = ChatMessageBenchmarkResult(
orderedStats = STATS,
statValues = message.statValues,
histogram = calculateLatencyHistogram(
latencies = message.values,
min = allMin,
max = allMax,
avg = message.statValues["avg"] ?: 0f,
),
values = message.values,
warmupCurrent = message.warmupCurrent,
warmupTotal = message.warmupTotal,
iterationCurrent = message.iterationCurrent,
iterationTotal = message.iterationTotal,
latencyMs = message.latencyMs,
highlightStat = "avg"
)
chatViewModel.replaceMessage(model = model, index = index, message = updatedMessage)
}
}
}
}
private fun calculateStats(
min: Float, max: Float, sum: Float, latencies: MutableList<Float>
): MutableMap<String, Float> {
val avg = if (latencies.size == 0) 0f else sum / latencies.size
val squaredDifferences = latencies.map { (it - avg).pow(2) }
val variance = squaredDifferences.average()
val stddev = if (latencies.size == 0) 0f else sqrt(variance).toFloat()
var medium = 0f
if (latencies.size == 1) {
medium = latencies[0]
} else if (latencies.size > 1) {
latencies.sort()
val middle = latencies.size / 2
medium =
if (latencies.size % 2 == 0) (latencies[middle - 1] + latencies[middle]) / 2.0f else latencies[middle]
}
return mutableMapOf(
"min" to min, "max" to max, "avg" to avg, "stddev" to stddev, "medium" to medium
)
}
fun calculateLatencyHistogram(
latencies: List<Float>, min: Float, max: Float, avg: Float, numBuckets: Int = 20
): Histogram {
if (latencies.isEmpty() || numBuckets <= 0) {
return Histogram(
buckets = List(numBuckets) { 0 }, maxCount = 0
)
}
if (min == max) {
// All latencies are the same.
val result = MutableList(numBuckets) { 0 }
result[0] = latencies.size
return Histogram(buckets = result, maxCount = result[0], highlightBucketIndex = 0)
}
val bucketSize = (max - min) / numBuckets
val histogram = MutableList(numBuckets) { 0 }
val getBucketIndex: (value: Float) -> Int = {
var bucketIndex = ((it - min) / bucketSize).toInt()
// Handle the case where latency equals maxLatency
if (bucketIndex == numBuckets) {
bucketIndex = numBuckets - 1
}
bucketIndex
}
for (latency in latencies) {
val bucketIndex = getBucketIndex(latency)
histogram[bucketIndex]++
}
val avgBucketIndex = getBucketIndex(avg)
return Histogram(
buckets = histogram,
maxCount = histogram.maxOrNull() ?: 0,
highlightBucketIndex = avgBucketIndex
)
}
fun getConfigValueString(value: Any, config: Config): String {
var strNewValue = "$value"
if (config.valueType == ValueType.FLOAT) {
strNewValue = "%.2f".format(value)
}
return strNewValue
}
@Composable
fun getTaskBgColor(task: Task): Color {
val colorIndex: Int = task.index % MaterialTheme.customColors.taskBgColors.size
return MaterialTheme.customColors.taskBgColors[colorIndex]
}
@Composable
fun getTaskIconColor(task: Task): Color {
val colorIndex: Int = task.index % MaterialTheme.customColors.taskIconColors.size
return MaterialTheme.customColors.taskIconColors[colorIndex]
}
@Composable
fun getTaskIconColor(index: Int): Color {
val colorIndex: Int = index % MaterialTheme.customColors.taskIconColors.size
return MaterialTheme.customColors.taskIconColors[colorIndex]
}
fun checkNotificationPermissionAndStartDownload( fun checkNotificationPermissionAndStartDownload(
context: Context, context: Context,
launcher: ManagedActivityResultLauncher<String, Boolean>, launcher: ManagedActivityResultLauncher<String, Boolean>,
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
task: Task, task: Task,
model: Model model: Model,
) { ) {
// Check permission // Check permission
when (PackageManager.PERMISSION_GRANTED) { when (PackageManager.PERMISSION_GRANTED) {
// Already got permission. Call the lambda. // Already got permission. Call the lambda.
ContextCompat.checkSelfPermission( ContextCompat.checkSelfPermission(context, Manifest.permission.POST_NOTIFICATIONS) -> {
context, Manifest.permission.POST_NOTIFICATIONS
) -> {
modelManagerViewModel.downloadModel(task = task, model = model) modelManagerViewModel.downloadModel(task = task, model = model)
} }
@ -468,100 +159,3 @@ fun checkNotificationPermissionAndStartDownload(
fun ensureValidFileName(fileName: String): String { fun ensureValidFileName(fileName: String): String {
return fileName.replace(Regex("[^a-zA-Z0-9._-]"), "_") return fileName.replace(Regex("[^a-zA-Z0-9._-]"), "_")
} }
fun cleanUpMediapipeTaskErrorMessage(message: String): String {
val index = message.indexOf("=== Source Location Trace")
if (index >= 0) {
return message.substring(0, index)
}
return message
}
fun processTasks() {
for ((index, task) in TASKS.withIndex()) {
task.index = index
for (model in task.models) {
model.preProcess()
}
}
}
fun processLlmResponse(response: String): String {
// Add "thinking" and "done thinking" around the thinking content.
var newContent =
response.replace("<think>", "$START_THINKING\n").replace("</think>", "\n$DONE_THINKING")
// Remove empty thinking content.
val endThinkingIndex = newContent.indexOf(DONE_THINKING)
if (endThinkingIndex >= 0) {
val thinkingContent =
newContent.substring(0, endThinkingIndex + DONE_THINKING.length).replace(START_THINKING, "")
.replace(DONE_THINKING, "")
if (thinkingContent.isBlank()) {
newContent = newContent.substring(endThinkingIndex + DONE_THINKING.length)
}
}
newContent = newContent.replace("\\n", "\n")
return newContent
}
@OptIn(ExperimentalSerializationApi::class)
inline fun <reified T> getJsonResponse(url: String): JsonObjAndTextContent<T>? {
try {
val connection = URL(url).openConnection() as HttpURLConnection
connection.requestMethod = "GET"
connection.connect()
val responseCode = connection.responseCode
if (responseCode == HttpURLConnection.HTTP_OK) {
val inputStream = connection.inputStream
val response = inputStream.bufferedReader().use { it.readText() }
// Parse JSON using kotlinx.serialization
val json = Json {
// Handle potential extra fields
ignoreUnknownKeys = true
allowComments = true
allowTrailingComma = true
}
val jsonObj = json.decodeFromString<T>(response)
return JsonObjAndTextContent(jsonObj = jsonObj, textContent = response)
} else {
Log.e("AGUtils", "HTTP error: $responseCode")
}
} catch (e: Exception) {
Log.e(
"AGUtils", "Error when getting json response: ${e.message}"
)
e.printStackTrace()
}
return null
}
fun writeLaunchInfo(context: Context) {
try {
val gson = Gson()
val launchInfo = LaunchInfo(ts = System.currentTimeMillis())
val jsonString = gson.toJson(launchInfo)
val file = File(context.getExternalFilesDir(null), LAUNCH_INFO_FILE_NAME)
file.writeText(jsonString)
} catch (e: Exception) {
Log.e(TAG, "Failed to write launch info", e)
}
}
fun readLaunchInfo(context: Context): LaunchInfo? {
try {
val gson = Gson()
val type = object : TypeToken<LaunchInfo>() {}.type
val file = File(context.getExternalFilesDir(null), LAUNCH_INFO_FILE_NAME)
val content = file.readText()
return gson.fromJson(content, type)
} catch (e: Exception) {
Log.e(TAG, "Failed to read launch info", e)
return null
}
}

View file

@ -16,52 +16,55 @@
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.ui.tooling.preview.Preview
import com.google.ai.edge.gallery.data.Config import com.google.ai.edge.gallery.data.Config
import com.google.ai.edge.gallery.data.ConfigKey import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.NumberSliderConfig import com.google.ai.edge.gallery.data.NumberSliderConfig
import com.google.ai.edge.gallery.data.ValueType import com.google.ai.edge.gallery.data.ValueType
import com.google.ai.edge.gallery.ui.common.convertValueToTargetType import com.google.ai.edge.gallery.data.convertValueToTargetType
import com.google.ai.edge.gallery.ui.theme.GalleryTheme import com.google.ai.edge.gallery.ui.common.ConfigDialog
private const val DEFAULT_BENCHMARK_WARM_UP_ITERATIONS = 50f private const val DEFAULT_BENCHMARK_WARM_UP_ITERATIONS = 50f
private const val DEFAULT_BENCHMARK_ITERATIONS = 200f private const val DEFAULT_BENCHMARK_ITERATIONS = 200f
private val BENCHMARK_CONFIGS: List<Config> = listOf( private val BENCHMARK_CONFIGS: List<Config> =
NumberSliderConfig( listOf(
key = ConfigKey.WARM_UP_ITERATIONS, NumberSliderConfig(
sliderMin = 10f, key = ConfigKey.WARM_UP_ITERATIONS,
sliderMax = 200f, sliderMin = 10f,
defaultValue = DEFAULT_BENCHMARK_WARM_UP_ITERATIONS, sliderMax = 200f,
valueType = ValueType.INT defaultValue = DEFAULT_BENCHMARK_WARM_UP_ITERATIONS,
), valueType = ValueType.INT,
NumberSliderConfig( ),
key = ConfigKey.BENCHMARK_ITERATIONS, NumberSliderConfig(
sliderMin = 50f, key = ConfigKey.BENCHMARK_ITERATIONS,
sliderMax = 500f, sliderMin = 50f,
defaultValue = DEFAULT_BENCHMARK_ITERATIONS, sliderMax = 500f,
valueType = ValueType.INT defaultValue = DEFAULT_BENCHMARK_ITERATIONS,
), valueType = ValueType.INT,
) ),
)
private val BENCHMARK_CONFIGS_INITIAL_VALUES = mapOf( private val BENCHMARK_CONFIGS_INITIAL_VALUES =
ConfigKey.WARM_UP_ITERATIONS.label to DEFAULT_BENCHMARK_WARM_UP_ITERATIONS, mapOf(
ConfigKey.BENCHMARK_ITERATIONS.label to DEFAULT_BENCHMARK_ITERATIONS ConfigKey.WARM_UP_ITERATIONS.label to DEFAULT_BENCHMARK_WARM_UP_ITERATIONS,
) ConfigKey.BENCHMARK_ITERATIONS.label to DEFAULT_BENCHMARK_ITERATIONS,
)
/** /**
* Composable function to display a configuration dialog for benchmarking a chat message. * Composable function to display a configuration dialog for benchmarking a chat message.
* *
* This function renders a configuration dialog specifically tailored for setting up * This function renders a configuration dialog specifically tailored for setting up benchmark
* benchmark parameters. It allows users to specify warm-up and benchmark iterations * parameters. It allows users to specify warm-up and benchmark iterations before running a
* before running a benchmark test on a given chat message. * benchmark test on a given chat message.
*/ */
@Composable @Composable
fun BenchmarkConfigDialog( fun BenchmarkConfigDialog(
onDismissed: () -> Unit, onDismissed: () -> Unit,
messageToBenchmark: ChatMessage?, messageToBenchmark: ChatMessage?,
onBenchmarkClicked: (ChatMessage, warmUpIterations: Int, benchmarkIterations: Int) -> Unit onBenchmarkClicked: (ChatMessage, warmUpIterations: Int, benchmarkIterations: Int) -> Unit,
) { ) {
ConfigDialog( ConfigDialog(
title = "Benchmark configs", title = "Benchmark configs",
@ -75,28 +78,32 @@ fun BenchmarkConfigDialog(
// Start benchmark. // Start benchmark.
messageToBenchmark?.let { message -> messageToBenchmark?.let { message ->
val warmUpIterations = convertValueToTargetType( val warmUpIterations =
value = curConfigValues.getValue(ConfigKey.WARM_UP_ITERATIONS.label), convertValueToTargetType(
valueType = ValueType.INT value = curConfigValues.getValue(ConfigKey.WARM_UP_ITERATIONS.label),
) as Int valueType = ValueType.INT,
val benchmarkIterations = convertValueToTargetType( )
value = curConfigValues.getValue(ConfigKey.BENCHMARK_ITERATIONS.label), as Int
valueType = ValueType.INT val benchmarkIterations =
) as Int convertValueToTargetType(
value = curConfigValues.getValue(ConfigKey.BENCHMARK_ITERATIONS.label),
valueType = ValueType.INT,
)
as Int
onBenchmarkClicked(message, warmUpIterations, benchmarkIterations) onBenchmarkClicked(message, warmUpIterations, benchmarkIterations)
} }
}, },
) )
} }
@Preview(showBackground = true) // @Preview(showBackground = true)
@Composable // @Composable
fun BenchmarkConfigDialogPreview() { // fun BenchmarkConfigDialogPreview() {
GalleryTheme { // GalleryTheme {
BenchmarkConfigDialog( // BenchmarkConfigDialog(
onDismissed = {}, // onDismissed = {},
messageToBenchmark = null, // messageToBenchmark = null,
onBenchmarkClicked = { _, _, _ -> } // onBenchmarkClicked = { _, _, _ -> },
) // )
} // }
} // }

View file

@ -17,10 +17,11 @@
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common.chat
import android.graphics.Bitmap import android.graphics.Bitmap
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.ImageBitmap import androidx.compose.ui.graphics.ImageBitmap
import androidx.compose.ui.unit.Dp import androidx.compose.ui.unit.Dp
import com.google.ai.edge.gallery.common.Classification
import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.PromptTemplate
enum class ChatMessageType { enum class ChatMessageType {
INFO, INFO,
@ -33,15 +34,15 @@ enum class ChatMessageType {
CONFIG_VALUES_CHANGE, CONFIG_VALUES_CHANGE,
BENCHMARK_RESULT, BENCHMARK_RESULT,
BENCHMARK_LLM_RESULT, BENCHMARK_LLM_RESULT,
PROMPT_TEMPLATES PROMPT_TEMPLATES,
} }
enum class ChatSide { enum class ChatSide {
USER, AGENT, SYSTEM USER,
AGENT,
SYSTEM,
} }
data class Classification(val label: String, val score: Float, val color: Color)
/** Base class for a chat message. */ /** Base class for a chat message. */
open class ChatMessage( open class ChatMessage(
open val type: ChatMessageType, open val type: ChatMessageType,
@ -70,7 +71,7 @@ class ChatMessageWarning(val content: String) :
class ChatMessageConfigValuesChange( class ChatMessageConfigValuesChange(
val model: Model, val model: Model,
val oldValues: Map<String, Any>, val oldValues: Map<String, Any>,
val newValues: Map<String, Any> val newValues: Map<String, Any>,
) : ChatMessage(type = ChatMessageType.CONFIG_VALUES_CHANGE, side = ChatSide.SYSTEM) ) : ChatMessage(type = ChatMessageType.CONFIG_VALUES_CHANGE, side = ChatSide.SYSTEM)
/** Chat message for plain text. */ /** Chat message for plain text. */
@ -84,12 +85,13 @@ open class ChatMessageText(
// Benchmark result for LLM response. // Benchmark result for LLM response.
var llmBenchmarkResult: ChatMessageBenchmarkLlmResult? = null, var llmBenchmarkResult: ChatMessageBenchmarkLlmResult? = null,
override val accelerator: String = "", override val accelerator: String = "",
) : ChatMessage( ) :
type = ChatMessageType.TEXT, ChatMessage(
side = side, type = ChatMessageType.TEXT,
latencyMs = latencyMs, side = side,
accelerator = accelerator latencyMs = latencyMs,
) { accelerator = accelerator,
) {
override fun clone(): ChatMessageText { override fun clone(): ChatMessageText {
return ChatMessageText( return ChatMessageText(
content = content, content = content,
@ -107,15 +109,14 @@ class ChatMessageImage(
val bitmap: Bitmap, val bitmap: Bitmap,
val imageBitMap: ImageBitmap, val imageBitMap: ImageBitmap,
override val side: ChatSide, override val side: ChatSide,
override val latencyMs: Float = 0f override val latencyMs: Float = 0f,
) : ) : ChatMessage(type = ChatMessageType.IMAGE, side = side, latencyMs = latencyMs) {
ChatMessage(type = ChatMessageType.IMAGE, side = side, latencyMs = latencyMs) {
override fun clone(): ChatMessageImage { override fun clone(): ChatMessageImage {
return ChatMessageImage( return ChatMessageImage(
bitmap = bitmap, bitmap = bitmap,
imageBitMap = imageBitMap, imageBitMap = imageBitMap,
side = side, side = side,
latencyMs = latencyMs latencyMs = latencyMs,
) )
} }
} }
@ -128,8 +129,7 @@ class ChatMessageImageWithHistory(
override val side: ChatSide, override val side: ChatSide,
override val latencyMs: Float = 0f, override val latencyMs: Float = 0f,
var curIteration: Int = 0, // 0-based var curIteration: Int = 0, // 0-based
) : ) : ChatMessage(type = ChatMessageType.IMAGE_WITH_HISTORY, side = side, latencyMs = latencyMs) {
ChatMessage(type = ChatMessageType.IMAGE_WITH_HISTORY, side = side, latencyMs = latencyMs) {
fun isRunning(): Boolean { fun isRunning(): Boolean {
return curIteration < totalIterations - 1 return curIteration < totalIterations - 1
} }
@ -141,7 +141,8 @@ class ChatMessageClassification(
override val latencyMs: Float = 0f, override val latencyMs: Float = 0f,
// Typical android phone width is > 320dp // Typical android phone width is > 320dp
val maxBarWidth: Dp? = null, val maxBarWidth: Dp? = null,
) : ChatMessage(type = ChatMessageType.CLASSIFICATION, side = ChatSide.AGENT, latencyMs = latencyMs) ) :
ChatMessage(type = ChatMessageType.CLASSIFICATION, side = ChatSide.AGENT, latencyMs = latencyMs)
/** A stat used in benchmark result. */ /** A stat used in benchmark result. */
data class Stat(val id: String, val label: String, val unit: String) data class Stat(val id: String, val label: String, val unit: String)
@ -162,7 +163,7 @@ class ChatMessageBenchmarkResult(
ChatMessage( ChatMessage(
type = ChatMessageType.BENCHMARK_RESULT, type = ChatMessageType.BENCHMARK_RESULT,
side = ChatSide.AGENT, side = ChatSide.AGENT,
latencyMs = latencyMs latencyMs = latencyMs,
) { ) {
fun isWarmingUp(): Boolean { fun isWarmingUp(): Boolean {
return warmupCurrent < warmupTotal return warmupCurrent < warmupTotal
@ -180,23 +181,18 @@ class ChatMessageBenchmarkLlmResult(
val running: Boolean, val running: Boolean,
override val latencyMs: Float = 0f, override val latencyMs: Float = 0f,
override val accelerator: String = "", override val accelerator: String = "",
) : ChatMessage( ) :
type = ChatMessageType.BENCHMARK_LLM_RESULT, ChatMessage(
side = ChatSide.AGENT, type = ChatMessageType.BENCHMARK_LLM_RESULT,
latencyMs = latencyMs, side = ChatSide.AGENT,
accelerator = accelerator, latencyMs = latencyMs,
) accelerator = accelerator,
)
data class Histogram( data class Histogram(val buckets: List<Int>, val maxCount: Int, val highlightBucketIndex: Int = -1)
val buckets: List<Int>,
val maxCount: Int,
val highlightBucketIndex: Int = -1
)
/** Chat message for showing prompt templates. */ /** Chat message for showing prompt templates. */
class ChatMessagePromptTemplates( class ChatMessagePromptTemplates(
val templates: List<PromptTemplate>, val templates: List<PromptTemplate>,
val showMakeYourOwn: Boolean = true, val showMakeYourOwn: Boolean = true,
) : ChatMessage(type = ChatMessageType.PROMPT_TEMPLATES, side = ChatSide.SYSTEM) ) : ChatMessage(type = ChatMessageType.PROMPT_TEMPLATES, side = ChatSide.SYSTEM)
data class PromptTemplate(val title: String, val description: String, val prompt: String)

View file

@ -16,6 +16,11 @@
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common.chat
// import com.google.ai.edge.gallery.ui.preview.PreviewChatModel
// import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
// import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import android.content.ClipData
import android.graphics.Bitmap import android.graphics.Bitmap
import androidx.compose.foundation.background import androidx.compose.foundation.background
import androidx.compose.foundation.clickable import androidx.compose.foundation.clickable
@ -69,15 +74,13 @@ import androidx.compose.ui.input.nestedscroll.NestedScrollConnection
import androidx.compose.ui.input.nestedscroll.NestedScrollSource import androidx.compose.ui.input.nestedscroll.NestedScrollSource
import androidx.compose.ui.input.nestedscroll.nestedScroll import androidx.compose.ui.input.nestedscroll.nestedScroll
import androidx.compose.ui.input.pointer.pointerInput import androidx.compose.ui.input.pointer.pointerInput
import androidx.compose.ui.platform.LocalClipboardManager import androidx.compose.ui.platform.ClipEntry
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalClipboard
import androidx.compose.ui.platform.LocalDensity import androidx.compose.ui.platform.LocalDensity
import androidx.compose.ui.platform.LocalFocusManager import androidx.compose.ui.platform.LocalFocusManager
import androidx.compose.ui.platform.LocalHapticFeedback import androidx.compose.ui.platform.LocalHapticFeedback
import androidx.compose.ui.res.dimensionResource import androidx.compose.ui.res.dimensionResource
import androidx.compose.ui.res.stringResource import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.R import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.gallery.data.Model
@ -86,20 +89,15 @@ import com.google.ai.edge.gallery.data.TaskType
import com.google.ai.edge.gallery.ui.common.ErrorDialog import com.google.ai.edge.gallery.ui.common.ErrorDialog
import com.google.ai.edge.gallery.ui.modelmanager.ModelInitializationStatusType import com.google.ai.edge.gallery.ui.modelmanager.ModelInitializationStatusType
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.PreviewChatModel
import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.customColors import com.google.ai.edge.gallery.ui.theme.customColors
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
enum class ChatInputType { enum class ChatInputType {
TEXT, IMAGE, TEXT,
IMAGE,
} }
/** /** Composable function for the main chat panel, displaying messages and handling user input. */
* Composable function for the main chat panel, displaying messages and handling user input.
*/
@OptIn(ExperimentalMaterial3Api::class) @OptIn(ExperimentalMaterial3Api::class)
@Composable @Composable
fun ChatPanel( fun ChatPanel(
@ -126,18 +124,19 @@ fun ChatPanel(
val snackbarHostState = remember { SnackbarHostState() } val snackbarHostState = remember { SnackbarHostState() }
val scope = rememberCoroutineScope() val scope = rememberCoroutineScope()
val haptic = LocalHapticFeedback.current val haptic = LocalHapticFeedback.current
val hasImageMessageToLastConfigChange = remember(messages) { val imageMessageCountToLastConfigChange =
var foundImageMessage = false remember(messages) {
for (message in messages.reversed()) { var imageMessageCount = 0
if (message is ChatMessageConfigValuesChange) { for (message in messages.reversed()) {
break if (message is ChatMessageConfigValuesChange) {
} break
if (message is ChatMessageImage) { }
foundImageMessage = true if (message is ChatMessageImage) {
imageMessageCount++
}
} }
imageMessageCount
} }
foundImageMessage
}
var curMessage by remember { mutableStateOf("") } // Correct state var curMessage by remember { mutableStateOf("") } // Correct state
val focusManager = LocalFocusManager.current val focusManager = LocalFocusManager.current
@ -163,8 +162,9 @@ fun ChatPanel(
lastMessageContent.value = tmpLastMessage.content lastMessageContent.value = tmpLastMessage.content
} }
} }
val lastShowingStatsByModel: MutableState<Map<String, MutableSet<ChatMessage>>> = val lastShowingStatsByModel: MutableState<Map<String, MutableSet<ChatMessage>>> = remember {
remember { mutableStateOf(mapOf()) } mutableStateOf(mapOf())
}
// Scroll the content to the bottom when any of these changes. // Scroll the content to the bottom when any of these changes.
LaunchedEffect( LaunchedEffect(
@ -217,15 +217,12 @@ fun ChatPanel(
showErrorDialog = modelInitializationStatus?.status == ModelInitializationStatusType.ERROR showErrorDialog = modelInitializationStatus?.status == ModelInitializationStatusType.ERROR
} }
Column( Column(modifier = modifier.imePadding()) {
modifier = modifier.imePadding()
) {
Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.weight(1f)) { Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.weight(1f)) {
LazyColumn( LazyColumn(
modifier = Modifier modifier = Modifier.fillMaxSize().nestedScroll(nestedScrollConnection),
.fillMaxSize() state = listState,
.nestedScroll(nestedScrollConnection), verticalArrangement = Arrangement.Top,
state = listState, verticalArrangement = Arrangement.Top,
) { ) {
items(messages) { message -> items(messages) { message ->
val imageHistoryCurIndex = remember { mutableIntStateOf(0) } val imageHistoryCurIndex = remember { mutableIntStateOf(0) }
@ -254,14 +251,14 @@ fun ChatPanel(
val bubbleBorderRadius = dimensionResource(R.dimen.chat_bubble_corner_radius) val bubbleBorderRadius = dimensionResource(R.dimen.chat_bubble_corner_radius)
Column( Column(
modifier = Modifier modifier =
.fillMaxWidth() Modifier.fillMaxWidth()
.padding( .padding(
start = 12.dp + extraPaddingStart, start = 12.dp + extraPaddingStart,
end = 12.dp + extraPaddingEnd, end = 12.dp + extraPaddingEnd,
top = 6.dp, top = 6.dp,
bottom = 6.dp, bottom = 6.dp,
), ),
horizontalAlignment = hAlign, horizontalAlignment = hAlign,
) messageColumn@{ ) messageColumn@{
// Sender row. // Sender row.
@ -272,7 +269,7 @@ fun ChatPanel(
MessageSender( MessageSender(
message = message, message = message,
agentName = agentName, agentName = agentName,
imageHistoryCurIndex = imageHistoryCurIndex.intValue imageHistoryCurIndex = imageHistoryCurIndex.intValue,
) )
// Message body. // Message body.
@ -290,40 +287,42 @@ fun ChatPanel(
is ChatMessageConfigValuesChange -> MessageBodyConfigUpdate(message = message) is ChatMessageConfigValuesChange -> MessageBodyConfigUpdate(message = message)
// Prompt templates. // Prompt templates.
is ChatMessagePromptTemplates -> MessageBodyPromptTemplates(message = message, is ChatMessagePromptTemplates ->
task = task, MessageBodyPromptTemplates(
onPromptClicked = { template -> message = message,
onSendMessage( task = task,
selectedModel, onPromptClicked = { template ->
listOf(ChatMessageText(content = template.prompt, side = ChatSide.USER)) onSendMessage(
) selectedModel,
}) listOf(ChatMessageText(content = template.prompt, side = ChatSide.USER)),
)
},
)
// Non-system messages. // Non-system messages.
else -> { else -> {
// The bubble shape around the message body. // The bubble shape around the message body.
var messageBubbleModifier = Modifier var messageBubbleModifier =
.clip( Modifier.clip(
MessageBubbleShape( MessageBubbleShape(
radius = bubbleBorderRadius, radius = bubbleBorderRadius,
hardCornerAtLeftOrRight = hardCornerAtLeftOrRight hardCornerAtLeftOrRight = hardCornerAtLeftOrRight,
)
) )
) .background(backgroundColor)
.background(backgroundColor)
if (message is ChatMessageText) { if (message is ChatMessageText) {
messageBubbleModifier = messageBubbleModifier.pointerInput(Unit) { messageBubbleModifier =
detectTapGestures( messageBubbleModifier.pointerInput(Unit) {
onLongPress = { detectTapGestures(
haptic.performHapticFeedback(HapticFeedbackType.LongPress) onLongPress = {
longPressedMessage.value = message haptic.performHapticFeedback(HapticFeedbackType.LongPress)
showMessageLongPressedSheet = true longPressedMessage.value = message
}, showMessageLongPressedSheet = true
) }
} )
}
} }
Box( Box(modifier = messageBubbleModifier) {
modifier = messageBubbleModifier,
) {
when (message) { when (message) {
// Text // Text
is ChatMessageText -> MessageBodyText(message = message) is ChatMessageText -> MessageBodyText(message = message)
@ -331,32 +330,35 @@ fun ChatPanel(
// Image // Image
is ChatMessageImage -> { is ChatMessageImage -> {
MessageBodyImage( MessageBodyImage(
message = message, modifier = Modifier message = message,
.clickable { modifier = Modifier.clickable { onImageSelected(message.bitmap) },
onImageSelected(message.bitmap)
}
) )
} }
// Image with history (for image gen) // Image with history (for image gen)
is ChatMessageImageWithHistory -> MessageBodyImageWithHistory( is ChatMessageImageWithHistory ->
message = message, imageHistoryCurIndex = imageHistoryCurIndex MessageBodyImageWithHistory(
) message = message,
imageHistoryCurIndex = imageHistoryCurIndex,
)
// Classification result // Classification result
is ChatMessageClassification -> MessageBodyClassification( is ChatMessageClassification ->
message = message, modifier = Modifier.width( MessageBodyClassification(
message.maxBarWidth ?: CLASSIFICATION_BAR_MAX_WIDTH message = message,
modifier =
Modifier.width(message.maxBarWidth ?: CLASSIFICATION_BAR_MAX_WIDTH),
) )
)
// Benchmark result. // Benchmark result.
is ChatMessageBenchmarkResult -> MessageBodyBenchmark(message = message) is ChatMessageBenchmarkResult -> MessageBodyBenchmark(message = message)
// Benchmark LLM result. // Benchmark LLM result.
is ChatMessageBenchmarkLlmResult -> MessageBodyBenchmarkLlm( is ChatMessageBenchmarkLlmResult ->
message = message, modifier = Modifier.wrapContentWidth() MessageBodyBenchmarkLlm(
) message = message,
modifier = Modifier.wrapContentWidth(),
)
else -> {} else -> {}
} }
@ -369,10 +371,14 @@ fun ChatPanel(
) { ) {
LatencyText(message = message) LatencyText(message = message)
// A button to show stats for the LLM message. // A button to show stats for the LLM message.
if (task.type.id.startsWith("llm_") && message is ChatMessageText if (
// This means we only want to show the action button when the message is done task.type.id.startsWith("llm_") &&
// generating, at which point the latency will be set. message is ChatMessageText
&& message.latencyMs >= 0 // This means we only want to show the action button when the message is
// done
// generating, at which point the latency will be set.
&&
message.latencyMs >= 0
) { ) {
val showingStats = val showingStats =
viewModel.isShowingStats(model = selectedModel, message = message) viewModel.isShowingStats(model = selectedModel, message = message)
@ -384,10 +390,7 @@ fun ChatPanel(
viewModel.toggleShowingStats(selectedModel, message) viewModel.toggleShowingStats(selectedModel, message)
// Add the stats message after the LLM message. // Add the stats message after the LLM message.
if (viewModel.isShowingStats( if (viewModel.isShowingStats(model = selectedModel, message = message)) {
model = selectedModel, message = message
)
) {
val llmBenchmarkResult = message.llmBenchmarkResult val llmBenchmarkResult = message.llmBenchmarkResult
if (llmBenchmarkResult != null) { if (llmBenchmarkResult != null) {
viewModel.insertMessageAfter( viewModel.insertMessageAfter(
@ -399,32 +402,30 @@ fun ChatPanel(
} }
// Remove the stats message. // Remove the stats message.
else { else {
val curMessageIndex = viewModel.getMessageIndex( val curMessageIndex =
model = selectedModel, message = message viewModel.getMessageIndex(model = selectedModel, message = message)
)
viewModel.removeMessageAt( viewModel.removeMessageAt(
model = selectedModel, index = curMessageIndex + 1 model = selectedModel,
index = curMessageIndex + 1,
) )
} }
}, },
enabled = !uiState.inProgress enabled = !uiState.inProgress,
) )
} }
} }
} else if (message.side == ChatSide.USER) { } else if (message.side == ChatSide.USER) {
Row( Row(
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(4.dp) horizontalArrangement = Arrangement.spacedBy(4.dp),
) { ) {
// Run again button. // Run again button.
if (selectedModel.showRunAgainButton) { if (selectedModel.showRunAgainButton) {
MessageActionButton( MessageActionButton(
label = stringResource(R.string.run_again), label = stringResource(R.string.run_again),
icon = Icons.Rounded.Refresh, icon = Icons.Rounded.Refresh,
onClick = { onClick = { onRunAgainClicked(selectedModel, message) },
onRunAgainClicked(selectedModel, message) enabled = !uiState.inProgress,
},
enabled = !uiState.inProgress
) )
} }
@ -437,7 +438,7 @@ fun ChatPanel(
showBenchmarkConfigsDialog = true showBenchmarkConfigsDialog = true
benchmarkMessage.value = message benchmarkMessage.value = message
}, },
enabled = !uiState.inProgress enabled = !uiState.inProgress,
) )
} }
} }
@ -453,15 +454,16 @@ fun ChatPanel(
// Show an info message for ask image task to get users started. // Show an info message for ask image task to get users started.
if (task.type == TaskType.LLM_ASK_IMAGE && messages.isEmpty()) { if (task.type == TaskType.LLM_ASK_IMAGE && messages.isEmpty()) {
Column( Column(
modifier = Modifier modifier = Modifier.padding(horizontal = 16.dp).fillMaxSize(),
.padding(horizontal = 16.dp)
.fillMaxSize(),
horizontalAlignment = Alignment.CenterHorizontally, horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.Center verticalArrangement = Arrangement.Center,
) { ) {
MessageBodyInfo( MessageBodyInfo(
ChatMessageInfo(content = "To get started, click + below to add an image and type a prompt to ask a question about it."), ChatMessageInfo(
smallFontSize = false content =
"To get started, click + below to add images (up to 10 in a single session) and type a prompt to ask a question about it."
),
smallFontSize = false,
) )
} }
} }
@ -470,16 +472,18 @@ fun ChatPanel(
// Chat input // Chat input
when (chatInputType) { when (chatInputType) {
ChatInputType.TEXT -> { ChatInputType.TEXT -> {
// val isLlmTask = task.type == TaskType.LLM_CHAT // val isLlmTask = task.type == TaskType.LLM_CHAT
// val notLlmStartScreen = !(messages.size == 1 && messages[0] is ChatMessagePromptTemplates) // val notLlmStartScreen = !(messages.size == 1 && messages[0] is
// ChatMessagePromptTemplates)
MessageInputText( MessageInputText(
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,
curMessage = curMessage, curMessage = curMessage,
inProgress = uiState.inProgress, inProgress = uiState.inProgress,
isResettingSession = uiState.isResettingSession, isResettingSession = uiState.isResettingSession,
modelPreparing = uiState.preparing, modelPreparing = uiState.preparing,
hasImageMessage = hasImageMessageToLastConfigChange, imageMessageCount = imageMessageCountToLastConfigChange,
modelInitializing = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING, modelInitializing =
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
textFieldPlaceHolderRes = task.textInputPlaceHolderRes, textFieldPlaceHolderRes = task.textInputPlaceHolderRes,
onValueChanged = { curMessage = it }, onValueChanged = { curMessage = it },
onSendMessage = { onSendMessage = {
@ -488,66 +492,78 @@ fun ChatPanel(
}, },
onOpenPromptTemplatesClicked = { onOpenPromptTemplatesClicked = {
onSendMessage( onSendMessage(
selectedModel, listOf( selectedModel,
listOf(
ChatMessagePromptTemplates( ChatMessagePromptTemplates(
templates = selectedModel.llmPromptTemplates, showMakeYourOwn = false templates = selectedModel.llmPromptTemplates,
showMakeYourOwn = false,
) )
) ),
) )
}, },
onStopButtonClicked = onStopButtonClicked, onStopButtonClicked = onStopButtonClicked,
// showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen, // showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen,
showPromptTemplatesInMenu = false, showPromptTemplatesInMenu = false,
showImagePickerInMenu = selectedModel.llmSupportImage, showImagePickerInMenu = selectedModel.llmSupportImage,
showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress, showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress,
) )
} }
ChatInputType.IMAGE -> MessageInputImage( ChatInputType.IMAGE ->
disableButtons = uiState.inProgress, MessageInputImage(
streamingMessage = streamingMessage, disableButtons = uiState.inProgress,
onImageSelected = { bitmap -> streamingMessage = streamingMessage,
onSendMessage( onImageSelected = { bitmap ->
selectedModel, listOf( onSendMessage(
selectedModel,
listOf(
ChatMessageImage(
bitmap = bitmap,
imageBitMap = bitmap.asImageBitmap(),
side = ChatSide.USER,
)
),
)
},
onStreamImage = { bitmap ->
onStreamImageMessage(
selectedModel,
ChatMessageImage( ChatMessageImage(
bitmap = bitmap, imageBitMap = bitmap.asImageBitmap(), side = ChatSide.USER bitmap = bitmap,
) imageBitMap = bitmap.asImageBitmap(),
side = ChatSide.USER,
),
) )
) },
}, onStreamEnd = onStreamEnd,
onStreamImage = { bitmap -> )
onStreamImageMessage(
selectedModel, ChatMessageImage(
bitmap = bitmap, imageBitMap = bitmap.asImageBitmap(), side = ChatSide.USER
)
)
},
onStreamEnd = onStreamEnd,
)
} }
} }
// Error dialog. // Error dialog.
if (showErrorDialog) { if (showErrorDialog) {
ErrorDialog(error = modelInitializationStatus?.error ?: "", onDismiss = { ErrorDialog(
showErrorDialog = false error = modelInitializationStatus?.error ?: "",
}) onDismiss = { showErrorDialog = false },
)
} }
// Benchmark config dialog. // Benchmark config dialog.
if (showBenchmarkConfigsDialog) { if (showBenchmarkConfigsDialog) {
BenchmarkConfigDialog(onDismissed = { showBenchmarkConfigsDialog = false }, BenchmarkConfigDialog(
onDismissed = { showBenchmarkConfigsDialog = false },
messageToBenchmark = benchmarkMessage.value, messageToBenchmark = benchmarkMessage.value,
onBenchmarkClicked = { message, warmUpIterations, benchmarkIterations -> onBenchmarkClicked = { message, warmUpIterations, benchmarkIterations ->
onBenchmarkClicked(selectedModel, message, warmUpIterations, benchmarkIterations) onBenchmarkClicked(selectedModel, message, warmUpIterations, benchmarkIterations)
}) },
)
} }
// Sheet to show when a message is long-pressed. // Sheet to show when a message is long-pressed.
if (showMessageLongPressedSheet) { if (showMessageLongPressedSheet) {
val message = longPressedMessage.value val message = longPressedMessage.value
if (message != null && message is ChatMessageText) { if (message != null && message is ChatMessageText) {
val clipboardManager = LocalClipboardManager.current val clipboard = LocalClipboard.current
ModalBottomSheet( ModalBottomSheet(
onDismissRequest = { showMessageLongPressedSheet = false }, onDismissRequest = { showMessageLongPressedSheet = false },
@ -555,28 +571,32 @@ fun ChatPanel(
) { ) {
Column { Column {
// Copy text. // Copy text.
Box(modifier = Modifier Box(
.fillMaxWidth() modifier =
.clickable { Modifier.fillMaxWidth().clickable {
// Copy text. // Copy text.
val clipData = AnnotatedString(message.content) scope.launch {
clipboardManager.setText(clipData) val clipData = ClipData.newPlainText("message content", message.content)
val clipEntry = ClipEntry(clipData = clipData)
clipboard.setClipEntry(clipEntry = clipEntry)
}
// Hide sheet. // Hide sheet.
showMessageLongPressedSheet = false showMessageLongPressedSheet = false
// Show a snack bar. // Show a snack bar.
scope.launch { scope.launch { snackbarHostState.showSnackbar("Text copied to clipboard") }
snackbarHostState.showSnackbar("Text copied to clipboard")
} }
}) { ) {
Row( Row(
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp), horizontalArrangement = Arrangement.spacedBy(6.dp),
modifier = Modifier.padding(vertical = 8.dp, horizontal = 16.dp) modifier = Modifier.padding(vertical = 8.dp, horizontal = 16.dp),
) { ) {
Icon( Icon(
Icons.Rounded.ContentCopy, contentDescription = "", modifier = Modifier.size(18.dp) Icons.Rounded.ContentCopy,
contentDescription = "",
modifier = Modifier.size(18.dp),
) )
Text("Copy text") Text("Copy text")
} }
@ -584,25 +604,24 @@ fun ChatPanel(
} }
} }
} }
} }
} }
@Preview(showBackground = true) // @Preview(showBackground = true)
@Composable // @Composable
fun ChatPanelPreview() { // fun ChatPanelPreview() {
GalleryTheme { // GalleryTheme {
val context = LocalContext.current // val context = LocalContext.current
val task = TASK_TEST1 // val task = TASK_TEST1
ChatPanel( // ChatPanel(
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current), // modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
task = task, // task = task,
selectedModel = TASK_TEST1.models[1], // selectedModel = TASK_TEST1.models[1],
viewModel = PreviewChatModel(context = context), // viewModel = PreviewChatModel(context = context),
navigateUp = {}, // navigateUp = {},
onSendMessage = { _, _ -> }, // onSendMessage = { _, _ -> },
onRunAgainClicked = { _, _ -> }, // onRunAgainClicked = { _, _ -> },
onBenchmarkClicked = { _, _, _, _ -> }, // onBenchmarkClicked = { _, _, _, _ -> },
) // )
} // }
} // }

View file

@ -16,6 +16,10 @@
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common.chat
// import com.google.ai.edge.gallery.ui.preview.PreviewChatModel
// import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
// import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import android.graphics.Bitmap import android.graphics.Bitmap
import android.util.Log import android.util.Log
import androidx.activity.compose.BackHandler import androidx.activity.compose.BackHandler
@ -55,7 +59,6 @@ import androidx.compose.ui.graphics.asImageBitmap
import androidx.compose.ui.graphics.graphicsLayer import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.layout.ContentScale import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.ModelDownloadStatusType import com.google.ai.edge.gallery.data.ModelDownloadStatusType
@ -63,13 +66,9 @@ import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.common.ModelPageAppBar import com.google.ai.edge.gallery.ui.common.ModelPageAppBar
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.modelmanager.PagerScrollState import com.google.ai.edge.gallery.ui.modelmanager.PagerScrollState
import com.google.ai.edge.gallery.ui.preview.PreviewChatModel import kotlin.math.absoluteValue
import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlin.math.absoluteValue
private const val TAG = "AGChatView" private const val TAG = "AGChatView"
@ -77,8 +76,8 @@ private const val TAG = "AGChatView"
* A composable that displays a chat interface, allowing users to interact with different models * A composable that displays a chat interface, allowing users to interact with different models
* associated with a given task. * associated with a given task.
* *
* This composable provides a horizontal pager for switching between models, a model selector * This composable provides a horizontal pager for switching between models, a model selector for
* for configuring the selected model, and a chat panel for sending and receiving messages. It also * configuring the selected model, and a chat panel for sending and receiving messages. It also
* manages model initialization, cleanup, and download status, and handles navigation and system * manages model initialization, cleanup, and download status, and handles navigation and system
* back gestures. * back gestures.
*/ */
@ -104,8 +103,11 @@ fun ChatView(
var selectedImage by remember { mutableStateOf<Bitmap?>(null) } var selectedImage by remember { mutableStateOf<Bitmap?>(null) }
var showImageViewer by remember { mutableStateOf(false) } var showImageViewer by remember { mutableStateOf(false) }
val pagerState = rememberPagerState(initialPage = task.models.indexOf(selectedModel), val pagerState =
pageCount = { task.models.size }) rememberPagerState(
initialPage = task.models.indexOf(selectedModel),
pageCount = { task.models.size },
)
val context = LocalContext.current val context = LocalContext.current
val scope = rememberCoroutineScope() val scope = rememberCoroutineScope()
var navigatingUp by remember { mutableStateOf(false) } var navigatingUp by remember { mutableStateOf(false) }
@ -138,7 +140,7 @@ fun ChatView(
val curSelectedModel = task.models[pagerState.settledPage] val curSelectedModel = task.models[pagerState.settledPage]
Log.d( Log.d(
TAG, TAG,
"Pager settled on model '${curSelectedModel.name}' from '${selectedModel.name}'. Updating selected model." "Pager settled on model '${curSelectedModel.name}' from '${selectedModel.name}'. Updating selected model.",
) )
if (curSelectedModel.name != selectedModel.name) { if (curSelectedModel.name != selectedModel.name) {
modelManagerViewModel.cleanupModel(task = task, model = selectedModel) modelManagerViewModel.cleanupModel(task = task, model = selectedModel)
@ -148,52 +150,49 @@ fun ChatView(
LaunchedEffect(pagerState) { LaunchedEffect(pagerState) {
// Collect from the a snapshotFlow reading the currentPage // Collect from the a snapshotFlow reading the currentPage
snapshotFlow { pagerState.currentPage }.collect { page -> snapshotFlow { pagerState.currentPage }.collect { page -> Log.d(TAG, "Page changed to $page") }
Log.d(TAG, "Page changed to $page")
}
} }
// Trigger scroll sync. // Trigger scroll sync.
LaunchedEffect(pagerState) { LaunchedEffect(pagerState) {
snapshotFlow { snapshotFlow {
PagerScrollState( PagerScrollState(
page = pagerState.currentPage, offset = pagerState.currentPageOffsetFraction page = pagerState.currentPage,
) offset = pagerState.currentPageOffsetFraction,
}.collect { scrollState -> )
modelManagerViewModel.pagerScrollState.value = scrollState }
} .collect { scrollState -> modelManagerViewModel.pagerScrollState.value = scrollState }
} }
// Handle system's edge swipe. // Handle system's edge swipe.
BackHandler { BackHandler { handleNavigateUp() }
handleNavigateUp()
}
Scaffold(modifier = modifier, topBar = { Scaffold(
ModelPageAppBar( modifier = modifier,
task = task, topBar = {
model = selectedModel, ModelPageAppBar(
modelManagerViewModel = modelManagerViewModel, task = task,
canShowResetSessionButton = true, model = selectedModel,
isResettingSession = uiState.isResettingSession, modelManagerViewModel = modelManagerViewModel,
inProgress = uiState.inProgress, canShowResetSessionButton = true,
modelPreparing = uiState.preparing, isResettingSession = uiState.isResettingSession,
onResetSessionClicked = onResetSessionClicked, inProgress = uiState.inProgress,
onConfigChanged = { old, new -> modelPreparing = uiState.preparing,
viewModel.addConfigChangedMessage( onResetSessionClicked = onResetSessionClicked,
oldConfigValues = old, newConfigValues = new, model = selectedModel onConfigChanged = { old, new ->
) viewModel.addConfigChangedMessage(
}, oldConfigValues = old,
onBackClicked = { newConfigValues = new,
handleNavigateUp() model = selectedModel,
}, )
onModelSelected = { model -> },
scope.launch { onBackClicked = { handleNavigateUp() },
pagerState.animateScrollToPage(task.models.indexOf(model)) onModelSelected = { model ->
} scope.launch { pagerState.animateScrollToPage(task.models.indexOf(model)) }
}, },
) )
}) { innerPadding -> },
) { innerPadding ->
Box { Box {
// A horizontal scrollable pager to switch between models. // A horizontal scrollable pager to switch between models.
HorizontalPager(state = pagerState) { pageIndex -> HorizontalPager(state = pagerState) { pageIndex ->
@ -202,17 +201,20 @@ fun ChatView(
// Calculate the alpha of the current page based on how far they are from the center. // Calculate the alpha of the current page based on how far they are from the center.
val pageOffset = val pageOffset =
((pagerState.currentPage - pageIndex) + pagerState.currentPageOffsetFraction).absoluteValue ((pagerState.currentPage - pageIndex) + pagerState.currentPageOffsetFraction)
.absoluteValue
val curAlpha = 1f - pageOffset.coerceIn(0f, 1f) val curAlpha = 1f - pageOffset.coerceIn(0f, 1f)
Column( Column(
modifier = Modifier modifier =
.padding(innerPadding) Modifier.padding(innerPadding)
.fillMaxSize() .fillMaxSize()
.background(MaterialTheme.colorScheme.surface) .background(MaterialTheme.colorScheme.surface)
) { ) {
ModelDownloadStatusInfoPanel( ModelDownloadStatusInfoPanel(
model = curSelectedModel, task = task, modelManagerViewModel = modelManagerViewModel model = curSelectedModel,
task = task,
modelManagerViewModel = modelManagerViewModel,
) )
// The main messages panel. // The main messages panel.
@ -230,19 +232,16 @@ fun ChatView(
onStreamEnd = { averageFps -> onStreamEnd = { averageFps ->
viewModel.addMessage( viewModel.addMessage(
model = curSelectedModel, model = curSelectedModel,
message = ChatMessageInfo(content = "Live camera session ended. Average FPS: $averageFps") message =
ChatMessageInfo(content = "Live camera session ended. Average FPS: $averageFps"),
) )
}, },
onStopButtonClicked = { onStopButtonClicked = { onStopButtonClicked(curSelectedModel) },
onStopButtonClicked(curSelectedModel)
},
onImageSelected = { bitmap -> onImageSelected = { bitmap ->
selectedImage = bitmap selectedImage = bitmap
showImageViewer = true showImageViewer = true
}, },
modifier = Modifier modifier = Modifier.weight(1f).graphicsLayer { alpha = curAlpha },
.weight(1f)
.graphicsLayer { alpha = curAlpha },
chatInputType = chatInputType, chatInputType = chatInputType,
showStopButtonInInputWhenInProgress = showStopButtonInInputWhenInProgress, showStopButtonInInputWhenInProgress = showStopButtonInInputWhenInProgress,
) )
@ -254,39 +253,43 @@ fun ChatView(
AnimatedVisibility( AnimatedVisibility(
visible = showImageViewer, visible = showImageViewer,
enter = slideInVertically(initialOffsetY = { fullHeight -> fullHeight }) + fadeIn(), enter = slideInVertically(initialOffsetY = { fullHeight -> fullHeight }) + fadeIn(),
exit = slideOutVertically( exit = slideOutVertically(targetOffsetY = { fullHeight -> fullHeight }) + fadeOut(),
targetOffsetY = { fullHeight -> fullHeight },
) + fadeOut()
) { ) {
selectedImage?.let { image -> selectedImage?.let { image ->
ZoomableBox( ZoomableBox(
modifier = Modifier modifier =
.fillMaxSize() Modifier.fillMaxSize()
.padding(top = innerPadding.calculateTopPadding()) .padding(top = innerPadding.calculateTopPadding())
.background(Color.Black.copy(alpha = 0.95f)), .background(Color.Black.copy(alpha = 0.95f))
) { ) {
Image( Image(
bitmap = image.asImageBitmap(), contentDescription = "", bitmap = image.asImageBitmap(),
modifier = modifier contentDescription = "",
.fillMaxSize() modifier =
.graphicsLayer( modifier
scaleX = scale, scaleY = scale, translationX = offsetX, translationY = offsetY .fillMaxSize()
), .graphicsLayer(
scaleX = scale,
scaleY = scale,
translationX = offsetX,
translationY = offsetY,
),
contentScale = ContentScale.Fit, contentScale = ContentScale.Fit,
) )
// Close button. // Close button.
IconButton( IconButton(
onClick = { onClick = { showImageViewer = false },
showImageViewer = false colors =
}, colors = IconButtonDefaults.iconButtonColors( IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.surfaceVariant, containerColor = MaterialTheme.colorScheme.surfaceVariant
), modifier = Modifier.offset(x = (-8).dp, y = 8.dp) ),
modifier = Modifier.offset(x = (-8).dp, y = 8.dp),
) { ) {
Icon( Icon(
Icons.Rounded.Close, Icons.Rounded.Close,
contentDescription = "", contentDescription = "",
tint = MaterialTheme.colorScheme.primary tint = MaterialTheme.colorScheme.primary,
) )
} }
} }
@ -296,20 +299,20 @@ fun ChatView(
} }
} }
@Preview // @Preview
@Composable // @Composable
fun ChatScreenPreview() { // fun ChatScreenPreview() {
GalleryTheme { // GalleryTheme {
val context = LocalContext.current // val context = LocalContext.current
val task = TASK_TEST1 // val task = TASK_TEST1
ChatView( // ChatView(
task = task, // task = task,
viewModel = PreviewChatModel(context = context), // viewModel = PreviewChatModel(context = context),
modelManagerViewModel = PreviewModelManagerViewModel(context = context), // modelManagerViewModel = PreviewModelManagerViewModel(context = context),
onSendMessage = { _, _ -> }, // onSendMessage = { _, _ -> },
onRunAgainClicked = { _, _ -> }, // onRunAgainClicked = { _, _ -> },
onBenchmarkClicked = { _, _, _, _ -> }, // onBenchmarkClicked = { _, _, _, _ -> },
navigateUp = {}, // navigateUp = {},
) // )
} // }
} // }

View file

@ -18,9 +18,9 @@ package com.google.ai.edge.gallery.ui.common.chat
import android.util.Log import android.util.Log
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import com.google.ai.edge.gallery.common.processLlmResponse
import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.Task import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.common.processLlmResponse
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.update import kotlinx.coroutines.flow.update
@ -28,14 +28,10 @@ import kotlinx.coroutines.flow.update
private const val TAG = "AGChatViewModel" private const val TAG = "AGChatViewModel"
data class ChatUiState( data class ChatUiState(
/** /** Indicates whether the runtime is currently processing a message. */
* Indicates whether the runtime is currently processing a message.
*/
val inProgress: Boolean = false, val inProgress: Boolean = false,
/** /** Indicates whether the session is being reset. */
* Indicates whether the session is being reset.
*/
val isResettingSession: Boolean = false, val isResettingSession: Boolean = false,
/** /**
@ -43,14 +39,10 @@ data class ChatUiState(
*/ */
val preparing: Boolean = false, val preparing: Boolean = false,
/** /** A map of model names to lists of chat messages. */
* A map of model names to lists of chat messages.
*/
val messagesByModel: Map<String, MutableList<ChatMessage>> = mapOf(), val messagesByModel: Map<String, MutableList<ChatMessage>> = mapOf(),
/** /** A map of model names to the currently streaming chat message. */
* A map of model names to the currently streaming chat message.
*/
val streamingMessagesByModel: Map<String, ChatMessage> = mapOf(), val streamingMessagesByModel: Map<String, ChatMessage> = mapOf(),
/* /*
@ -60,9 +52,7 @@ data class ChatUiState(
val showingStatsByModel: Map<String, MutableSet<ChatMessage>> = mapOf(), val showingStatsByModel: Map<String, MutableSet<ChatMessage>> = mapOf(),
) )
/** /** ViewModel responsible for managing the chat UI state and handling chat-related operations. */
* ViewModel responsible for managing the chat UI state and handling chat-related operations.
*/
open class ChatViewModel(val task: Task) : ViewModel() { open class ChatViewModel(val task: Task) : ViewModel() {
private val _uiState = MutableStateFlow(createUiState(task = task)) private val _uiState = MutableStateFlow(createUiState(task = task))
val uiState = _uiState.asStateFlow() val uiState = _uiState.asStateFlow()
@ -137,12 +127,13 @@ open class ChatViewModel(val task: Task) : ViewModel() {
val lastMessage = newMessages.last() val lastMessage = newMessages.last()
if (lastMessage is ChatMessageText) { if (lastMessage is ChatMessageText) {
val newContent = processLlmResponse(response = "${lastMessage.content}${partialContent}") val newContent = processLlmResponse(response = "${lastMessage.content}${partialContent}")
val newLastMessage = ChatMessageText( val newLastMessage =
content = newContent, ChatMessageText(
side = lastMessage.side, content = newContent,
latencyMs = latencyMs, side = lastMessage.side,
accelerator = lastMessage.accelerator, latencyMs = latencyMs,
) accelerator = lastMessage.accelerator,
)
newMessages.removeAt(newMessages.size - 1) newMessages.removeAt(newMessages.size - 1)
newMessages.add(newLastMessage) newMessages.add(newLastMessage)
} }
@ -154,7 +145,7 @@ open class ChatViewModel(val task: Task) : ViewModel() {
fun updateLastTextMessageLlmBenchmarkResult( fun updateLastTextMessageLlmBenchmarkResult(
model: Model, model: Model,
llmBenchmarkResult: ChatMessageBenchmarkLlmResult llmBenchmarkResult: ChatMessageBenchmarkLlmResult,
) { ) {
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap() val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
val newMessages = newMessagesByModel[model.name]?.toMutableList() ?: mutableListOf() val newMessages = newMessagesByModel[model.name]?.toMutableList() ?: mutableListOf()
@ -215,12 +206,17 @@ open class ChatViewModel(val task: Task) : ViewModel() {
} }
fun addConfigChangedMessage( fun addConfigChangedMessage(
oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>, model: Model oldConfigValues: Map<String, Any>,
newConfigValues: Map<String, Any>,
model: Model,
) { ) {
Log.d(TAG, "Adding config changed message. Old: ${oldConfigValues}, new: $newConfigValues") Log.d(TAG, "Adding config changed message. Old: ${oldConfigValues}, new: $newConfigValues")
val message = ChatMessageConfigValuesChange( val message =
model = model, oldValues = oldConfigValues, newValues = newConfigValues ChatMessageConfigValuesChange(
) model = model,
oldValues = oldConfigValues,
newValues = newConfigValues,
)
addMessage(message = message, model = model) addMessage(message = message, model = model)
} }
@ -253,8 +249,6 @@ open class ChatViewModel(val task: Task) : ViewModel() {
} }
messagesByModel[model.name] = messages messagesByModel[model.name] = messages
} }
return ChatUiState( return ChatUiState(messagesByModel = messagesByModel)
messagesByModel = messagesByModel
)
} }
} }

View file

@ -37,9 +37,8 @@ import com.google.ai.edge.gallery.ui.theme.labelSmallNarrowMedium
/** /**
* Composable function to display a data card with a label and a numeric value. * Composable function to display a data card with a label and a numeric value.
* *
* This function renders a column containing a label and a formatted numeric value. * This function renders a column containing a label and a formatted numeric value. It provides
* It provides options for highlighting the value and displaying a placeholder when the value is not * options for highlighting the value and displaying a placeholder when the value is not available.
* available.
*/ */
@Composable @Composable
fun DataCard( fun DataCard(
@ -47,7 +46,7 @@ fun DataCard(
value: Float?, value: Float?,
unit: String, unit: String,
highlight: Boolean = false, highlight: Boolean = false,
showPlaceholder: Boolean = false showPlaceholder: Boolean = false,
) { ) {
var strValue = "-" var strValue = "-"
Column { Column {
@ -57,19 +56,13 @@ fun DataCard(
} else { } else {
strValue = if (value == null) "-" else "%.2f".format(value) strValue = if (value == null) "-" else "%.2f".format(value)
if (highlight) { if (highlight) {
Text( Text(strValue, style = bodySmallMediumNarrowBold, color = MaterialTheme.colorScheme.primary)
strValue, style = bodySmallMediumNarrowBold, color = MaterialTheme.colorScheme.primary
)
} else { } else {
Text(strValue, style = bodySmallMediumNarrow) Text(strValue, style = bodySmallMediumNarrow)
} }
} }
if (strValue != "-") { if (strValue != "-") {
Text( Text(unit, style = labelSmallNarrow, modifier = Modifier.alpha(0.5f).offset(y = (-1).dp))
unit, style = labelSmallNarrow, modifier = Modifier
.alpha(0.5f)
.offset(y = (-1).dp)
)
} }
} }
} }
@ -80,14 +73,26 @@ fun DataCardPreview() {
GalleryTheme { GalleryTheme {
Row(modifier = Modifier.padding(16.dp), horizontalArrangement = Arrangement.spacedBy(16.dp)) { Row(modifier = Modifier.padding(16.dp), horizontalArrangement = Arrangement.spacedBy(16.dp)) {
DataCard( DataCard(
label = "sum", value = 123.45f, unit = "ms", highlight = true, showPlaceholder = false label = "sum",
value = 123.45f,
unit = "ms",
highlight = true,
showPlaceholder = false,
) )
DataCard( DataCard(
label = "average", value = 12.3f, unit = "ms", highlight = false, showPlaceholder = false label = "average",
value = 12.3f,
unit = "ms",
highlight = false,
showPlaceholder = false,
) )
DataCard( DataCard(
label = "test", value = null, unit = "ms", highlight = false, showPlaceholder = false label = "test",
value = null,
unit = "ms",
highlight = false,
showPlaceholder = false,
) )
} }
} }
} }

View file

@ -1,226 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.common.chat
import android.graphics.Bitmap
import android.graphics.Matrix
import android.util.Size
import androidx.camera.core.CameraSelector
import androidx.camera.core.ImageAnalysis
import androidx.camera.core.resolutionselector.ResolutionSelector
import androidx.camera.core.resolutionselector.ResolutionStrategy
import androidx.camera.lifecycle.ProcessCameraProvider
import androidx.compose.foundation.Image
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.aspectRatio
import androidx.compose.foundation.layout.fillMaxHeight
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material3.Card
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableLongStateOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.ImageBitmap
import androidx.compose.ui.graphics.asImageBitmap
import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog
import androidx.core.content.ContextCompat
import androidx.lifecycle.compose.LocalLifecycleOwner
import java.util.concurrent.Executors
import kotlin.coroutines.resume
import kotlin.coroutines.suspendCoroutine
/**
* Composable function to display a live camera feed in a dialog.
*
* This function renders a dialog that displays a live camera preview, along with optional
* classification results and FPS information. It manages camera initialization, frame capture,
* and dialog dismissal.
*/
@Composable
fun LiveCameraDialog(
onDismissed: (averageFps: Int) -> Unit,
onBitmap: (Bitmap) -> Unit,
streamingMessage: ChatMessage? = null,
) {
val context = LocalContext.current
val lifecycleOwner = LocalLifecycleOwner.current
var imageBitmap by remember { mutableStateOf<ImageBitmap?>(null) }
var cameraProvider: ProcessCameraProvider? by remember { mutableStateOf(null) }
var sumFps by remember { mutableLongStateOf(0L) }
var fpsCount by remember { mutableLongStateOf(0L) }
LaunchedEffect(key1 = true) {
cameraProvider = startCamera(
context,
lifecycleOwner,
onBitmap = onBitmap,
onImageBitmap = { b -> imageBitmap = b })
}
Dialog(onDismissRequest = {
cameraProvider?.unbindAll()
onDismissed((sumFps / fpsCount).toInt())
}) {
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
Column(
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
) {
// Title
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.SpaceBetween,
modifier = Modifier
.fillMaxWidth()
.padding(bottom = 8.dp)
) {
Text(
"Live camera",
style = MaterialTheme.typography.titleLarge,
)
if (streamingMessage != null) {
val fps = (1000f / streamingMessage.latencyMs).toInt()
sumFps += fps.toLong()
fpsCount += 1
Text(
"%d FPS".format(fps),
style = MaterialTheme.typography.titleLarge,
)
}
}
// Camera live view.
Row(
modifier = Modifier
.fillMaxWidth()
.aspectRatio(1f),
horizontalArrangement = Arrangement.Center
) {
val ib = imageBitmap
if (ib != null) {
Image(
bitmap = ib,
contentDescription = "",
modifier = Modifier
.fillMaxHeight()
.clip(RoundedCornerShape(8.dp)),
contentScale = ContentScale.Inside
)
}
}
// Result.
if (streamingMessage != null && streamingMessage is ChatMessageClassification) {
MessageBodyClassification(
message = streamingMessage,
modifier = Modifier.fillMaxWidth(),
oneLineLabel = true
)
}
// Button.
Row(
modifier = Modifier
.fillMaxWidth()
.padding(top = 8.dp),
horizontalArrangement = Arrangement.End,
) {
TextButton(
onClick = {
cameraProvider?.unbindAll()
onDismissed((sumFps / fpsCount).toInt())
},
) {
Text("OK")
}
}
}
}
}
}
/**
* Asynchronously initializes and starts the camera for image capture and analysis.
*
* This function sets up the camera using CameraX, configures image analysis, and binds
* the camera lifecycle to the provided LifecycleOwner. It captures frames from the camera,
* converts them to Bitmaps and ImageBitmaps, and invokes the provided callbacks.
*/
private suspend fun startCamera(
context: android.content.Context,
lifecycleOwner: androidx.lifecycle.LifecycleOwner,
onBitmap: (Bitmap) -> Unit,
onImageBitmap: (ImageBitmap) -> Unit
): ProcessCameraProvider? = suspendCoroutine { continuation ->
val cameraProviderFuture = ProcessCameraProvider.getInstance(context)
cameraProviderFuture.addListener({
val cameraProvider = cameraProviderFuture.get()
val resolutionSelector = ResolutionSelector.Builder().setResolutionStrategy(
ResolutionStrategy(
Size(1080, 1080),
ResolutionStrategy.FALLBACK_RULE_CLOSEST_LOWER_THEN_HIGHER
)
).build()
val imageAnalysis =
ImageAnalysis.Builder().setResolutionSelector(resolutionSelector)
.setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST).build().also {
it.setAnalyzer(Executors.newSingleThreadExecutor()) { imageProxy ->
var bitmap = imageProxy.toBitmap()
val rotation = imageProxy.imageInfo.rotationDegrees
bitmap = if (rotation != 0) {
val matrix = Matrix().apply {
postRotate(rotation.toFloat())
}
Bitmap.createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, matrix, true)
} else bitmap
onBitmap(bitmap)
onImageBitmap(bitmap.asImageBitmap())
imageProxy.close()
}
}
val cameraSelector = CameraSelector.DEFAULT_BACK_CAMERA
try {
cameraProvider?.unbindAll()
cameraProvider?.bindToLifecycle(
lifecycleOwner, cameraSelector, imageAnalysis
)
// Resume with the provider
continuation.resume(cameraProvider)
} catch (exc: Exception) {
// todo: Handle exceptions (e.g., camera initialization failure)
}
}, ContextCompat.getMainExecutor(context))
}

View file

@ -16,16 +16,15 @@
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.background import androidx.compose.foundation.background
import androidx.compose.foundation.clickable import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.offset import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size import androidx.compose.foundation.layout.size
import androidx.compose.foundation.shape.CircleShape import androidx.compose.foundation.shape.CircleShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.PlayArrow
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text import androidx.compose.material3.Text
@ -35,62 +34,56 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.vector.ImageVector import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.bodySmallNarrow import com.google.ai.edge.gallery.ui.theme.bodySmallNarrow
/** /** Composable function to display an action button below a chat message. */
* Composable function to display an action button below a chat message.
*/
@Composable @Composable
fun MessageActionButton( fun MessageActionButton(
label: String, label: String,
icon: ImageVector, icon: ImageVector,
onClick: () -> Unit, onClick: () -> Unit,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
enabled: Boolean = true enabled: Boolean = true,
) { ) {
val curModifier = modifier val curModifier =
.padding(top = 4.dp) modifier
.clip(CircleShape) .padding(top = 4.dp)
.background(if (enabled) MaterialTheme.colorScheme.secondaryContainer else MaterialTheme.colorScheme.surfaceContainerHigh) .clip(CircleShape)
.background(
if (enabled) MaterialTheme.colorScheme.secondaryContainer
else MaterialTheme.colorScheme.surfaceContainerHigh
)
val alpha: Float = if (enabled) 1.0f else 0.3f val alpha: Float = if (enabled) 1.0f else 0.3f
Row( Row(
modifier = if (enabled) curModifier.clickable { onClick() } else modifier, modifier = if (enabled) curModifier.clickable { onClick() } else modifier,
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
) { ) {
Icon( Icon(
icon, contentDescription = "", modifier = Modifier icon,
.size(16.dp) contentDescription = "",
.offset(x = 6.dp) modifier = Modifier.size(16.dp).offset(x = 6.dp).alpha(alpha),
.alpha(alpha)
) )
Text( Text(
label, label,
color = MaterialTheme.colorScheme.onSecondaryContainer, color = MaterialTheme.colorScheme.onSecondaryContainer,
style = bodySmallNarrow, style = bodySmallNarrow,
modifier = Modifier modifier = Modifier.padding(start = 10.dp, end = 8.dp, top = 4.dp, bottom = 4.dp).alpha(alpha),
.padding(
start = 10.dp, end = 8.dp, top = 4.dp, bottom = 4.dp
)
.alpha(alpha)
) )
} }
} }
@Preview(showBackground = true) // @Preview(showBackground = true)
@Composable // @Composable
fun MessageActionButtonPreview() { // fun MessageActionButtonPreview() {
GalleryTheme { // GalleryTheme {
Column { // Column {
MessageActionButton(label = "run", icon = Icons.Default.PlayArrow, onClick = {}) // MessageActionButton(label = "run", icon = Icons.Default.PlayArrow, onClick = {})
MessageActionButton( // MessageActionButton(
label = "run", // label = "run",
icon = Icons.Default.PlayArrow, // icon = Icons.Default.PlayArrow,
enabled = false, // enabled = false,
onClick = {}) // onClick = {})
} // }
} // }
} // }

View file

@ -16,6 +16,8 @@
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.background import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
@ -31,9 +33,7 @@ import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip import androidx.compose.ui.draw.clip
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import kotlin.math.max import kotlin.math.max
private const val DEFAULT_HISTOGRAM_BAR_HEIGHT = 50f private const val DEFAULT_HISTOGRAM_BAR_HEIGHT = 50f
@ -41,37 +41,31 @@ private const val DEFAULT_HISTOGRAM_BAR_HEIGHT = 50f
/** /**
* Composable function to display benchmark results within a chat message. * Composable function to display benchmark results within a chat message.
* *
* This function renders benchmark statistics (e.g., average latency) in data cards and * This function renders benchmark statistics (e.g., average latency) in data cards and visualizes
* visualizes the latency distribution using a histogram. * the latency distribution using a histogram.
*/ */
@Composable @Composable
fun MessageBodyBenchmark(message: ChatMessageBenchmarkResult) { fun MessageBodyBenchmark(message: ChatMessageBenchmarkResult) {
Column( Column(
modifier = Modifier modifier = Modifier.padding(12.dp).fillMaxWidth(),
.padding(12.dp) verticalArrangement = Arrangement.spacedBy(8.dp),
.fillMaxWidth(),
verticalArrangement = Arrangement.spacedBy(8.dp)
) { ) {
// Data cards. // Data cards.
Row( Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.SpaceBetween) {
modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.SpaceBetween
) {
for (stat in message.orderedStats) { for (stat in message.orderedStats) {
DataCard( DataCard(
label = stat.label, label = stat.label,
unit = stat.unit, unit = stat.unit,
value = message.statValues[stat.id], value = message.statValues[stat.id],
highlight = stat.id == message.highlightStat, highlight = stat.id == message.highlightStat,
showPlaceholder = message.isWarmingUp() showPlaceholder = message.isWarmingUp(),
) )
} }
} }
// Histogram // Histogram
if (message.histogram.buckets.isNotEmpty()) { if (message.histogram.buckets.isNotEmpty()) {
Row( Row(horizontalArrangement = Arrangement.spacedBy(2.dp)) {
horizontalArrangement = Arrangement.spacedBy(2.dp)
) {
for ((index, count) in message.histogram.buckets.withIndex()) { for ((index, count) in message.histogram.buckets.withIndex()) {
var barBgColor = MaterialTheme.colorScheme.onSurfaceVariant var barBgColor = MaterialTheme.colorScheme.onSurfaceVariant
var alpha = 0.3f var alpha = 0.3f
@ -84,24 +78,24 @@ fun MessageBodyBenchmark(message: ChatMessageBenchmarkResult) {
} }
// Bar container. // Bar container.
Column( Column(
modifier = Modifier modifier = Modifier.height(DEFAULT_HISTOGRAM_BAR_HEIGHT.dp).width(4.dp),
.height(DEFAULT_HISTOGRAM_BAR_HEIGHT.dp)
.width(4.dp),
verticalArrangement = Arrangement.Bottom, verticalArrangement = Arrangement.Bottom,
) { ) {
// Bar content. // Bar content.
Box( Box(
modifier = Modifier modifier =
.height( Modifier.height(
max( max(
1f, 1f,
count.toFloat() / message.histogram.maxCount.toFloat() * DEFAULT_HISTOGRAM_BAR_HEIGHT count.toFloat() / message.histogram.maxCount.toFloat() *
).dp DEFAULT_HISTOGRAM_BAR_HEIGHT,
) )
.fillMaxWidth() .dp
.clip(RoundedCornerShape(20.dp, 20.dp, 0.dp, 0.dp)) )
.alpha(alpha) .fillMaxWidth()
.background(barBgColor) .clip(RoundedCornerShape(20.dp, 20.dp, 0.dp, 0.dp))
.alpha(alpha)
.background(barBgColor)
) )
} }
} }
@ -110,31 +104,31 @@ fun MessageBodyBenchmark(message: ChatMessageBenchmarkResult) {
} }
} }
@Preview(showBackground = true) // @Preview(showBackground = true)
@Composable // @Composable
fun MessageBodyBenchmarkPreview() { // fun MessageBodyBenchmarkPreview() {
GalleryTheme { // GalleryTheme {
MessageBodyBenchmark( // MessageBodyBenchmark(
message = ChatMessageBenchmarkResult( // message = ChatMessageBenchmarkResult(
orderedStats = listOf( // orderedStats = listOf(
Stat(id = "stat1", label = "Stat1", unit = "ms"), // Stat(id = "stat1", label = "Stat1", unit = "ms"),
Stat(id = "stat2", label = "Stat2", unit = "ms"), // Stat(id = "stat2", label = "Stat2", unit = "ms"),
Stat(id = "stat3", label = "Stat3", unit = "ms"), // Stat(id = "stat3", label = "Stat3", unit = "ms"),
Stat(id = "stat4", label = "Stat4", unit = "ms") // Stat(id = "stat4", label = "Stat4", unit = "ms")
), // ),
statValues = mutableMapOf( // statValues = mutableMapOf(
"stat1" to 0.3f, // "stat1" to 0.3f,
"stat2" to 0.4f, // "stat2" to 0.4f,
"stat3" to 0.5f, // "stat3" to 0.5f,
), // ),
values = listOf(), // values = listOf(),
histogram = Histogram(listOf(), 0), // histogram = Histogram(listOf(), 0),
warmupCurrent = 0, // warmupCurrent = 0,
warmupTotal = 0, // warmupTotal = 0,
iterationCurrent = 0, // iterationCurrent = 0,
iterationTotal = 0, // iterationTotal = 0,
highlightStat = "stat2" // highlightStat = "stat2"
) // )
) // )
} // }
} // }

View file

@ -16,6 +16,8 @@
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Row
@ -23,9 +25,7 @@ import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
/** /**
* Composable function to display benchmark LLM results within a chat message. * Composable function to display benchmark LLM results within a chat message.
@ -34,41 +34,32 @@ import com.google.ai.edge.gallery.ui.theme.GalleryTheme
*/ */
@Composable @Composable
fun MessageBodyBenchmarkLlm(message: ChatMessageBenchmarkLlmResult, modifier: Modifier = Modifier) { fun MessageBodyBenchmarkLlm(message: ChatMessageBenchmarkLlmResult, modifier: Modifier = Modifier) {
Column( Column(modifier = modifier.padding(12.dp), verticalArrangement = Arrangement.spacedBy(8.dp)) {
modifier = modifier.padding(12.dp),
verticalArrangement = Arrangement.spacedBy(8.dp)
) {
// Data cards. // Data cards.
Row( Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.SpaceBetween) {
modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.SpaceBetween
) {
for (stat in message.orderedStats) { for (stat in message.orderedStats) {
DataCard( DataCard(label = stat.label, unit = stat.unit, value = message.statValues[stat.id])
label = stat.label,
unit = stat.unit,
value = message.statValues[stat.id],
)
} }
} }
} }
} }
@Preview(showBackground = true) // @Preview(showBackground = true)
@Composable // @Composable
fun MessageBodyBenchmarkLlmPreview() { // fun MessageBodyBenchmarkLlmPreview() {
GalleryTheme { // GalleryTheme {
MessageBodyBenchmarkLlm( // MessageBodyBenchmarkLlm(
message = ChatMessageBenchmarkLlmResult( // message = ChatMessageBenchmarkLlmResult(
orderedStats = listOf( // orderedStats = listOf(
Stat(id = "stat1", label = "Stat1", unit = "tokens/s"), // Stat(id = "stat1", label = "Stat1", unit = "tokens/s"),
Stat(id = "stat2", label = "Stat2", unit = "tokens/s") // Stat(id = "stat2", label = "Stat2", unit = "tokens/s")
), // ),
statValues = mutableMapOf( // statValues = mutableMapOf(
"stat1" to 0.3f, // "stat1" to 0.3f,
"stat2" to 0.4f, // "stat2" to 0.4f,
), // ),
running = false, // running = false,
) // )
) // )
} // }
} // }

View file

@ -16,6 +16,8 @@
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.background import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
@ -32,11 +34,8 @@ import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
val CLASSIFICATION_BAR_HEIGHT = 8.dp val CLASSIFICATION_BAR_HEIGHT = 8.dp
val CLASSIFICATION_BAR_MAX_WIDTH = 200.dp val CLASSIFICATION_BAR_MAX_WIDTH = 200.dp
@ -44,7 +43,8 @@ val CLASSIFICATION_BAR_MAX_WIDTH = 200.dp
/** /**
* Composable function to display classification results. * Composable function to display classification results.
* *
* This function renders a list of classifications, each with its label, score, and a visual score bar. * This function renders a list of classifications, each with its label, score, and a visual score
* bar.
*/ */
@Composable @Composable
fun MessageBodyClassification( fun MessageBodyClassification(
@ -52,45 +52,40 @@ fun MessageBodyClassification(
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
oneLineLabel: Boolean = false, oneLineLabel: Boolean = false,
) { ) {
Column( Column(modifier = modifier.padding(12.dp)) {
modifier = modifier.padding(12.dp)
) {
for (classification in message.classifications) { for (classification in message.classifications) {
Row( Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.SpaceBetween) {
modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.SpaceBetween
) {
// Classification label. // Classification label.
Text( Text(
classification.label, classification.label,
maxLines = if (oneLineLabel) 1 else Int.MAX_VALUE, maxLines = if (oneLineLabel) 1 else Int.MAX_VALUE,
overflow = TextOverflow.Ellipsis, overflow = TextOverflow.Ellipsis,
style = MaterialTheme.typography.bodySmall, style = MaterialTheme.typography.bodySmall,
modifier = Modifier.weight(1f) modifier = Modifier.weight(1f),
) )
// Classification score. // Classification score.
Text( Text(
"%.2f".format(classification.score), "%.2f".format(classification.score),
style = MaterialTheme.typography.bodySmall, style = MaterialTheme.typography.bodySmall,
modifier = Modifier modifier = Modifier.align(Alignment.Bottom),
.align(Alignment.Bottom),
) )
} }
Spacer(modifier = Modifier.height(2.dp)) Spacer(modifier = Modifier.height(2.dp))
// Score bar. // Score bar.
Box { Box {
Box( Box(
modifier = Modifier modifier =
.fillMaxWidth() Modifier.fillMaxWidth()
.height(CLASSIFICATION_BAR_HEIGHT) .height(CLASSIFICATION_BAR_HEIGHT)
.clip(CircleShape) .clip(CircleShape)
.background(MaterialTheme.colorScheme.surfaceDim) .background(MaterialTheme.colorScheme.surfaceDim)
) )
Box( Box(
modifier = Modifier modifier =
.fillMaxWidth(classification.score) Modifier.fillMaxWidth(classification.score)
.height(CLASSIFICATION_BAR_HEIGHT) .height(CLASSIFICATION_BAR_HEIGHT)
.clip(CircleShape) .clip(CircleShape)
.background(classification.color) .background(classification.color)
) )
} }
Spacer(modifier = Modifier.height(6.dp)) Spacer(modifier = Modifier.height(6.dp))
@ -98,18 +93,20 @@ fun MessageBodyClassification(
} }
} }
@Preview(showBackground = true) // @Preview(showBackground = true)
@Composable // @Composable
fun MessageBodyClassificationPreview() { // fun MessageBodyClassificationPreview() {
GalleryTheme { // GalleryTheme {
MessageBodyClassification( // MessageBodyClassification(
message = ChatMessageClassification( // message =
classifications = listOf( // ChatMessageClassification(
Classification(label = "label1", score = 0.3f, color = Color.Red), // classifications =
Classification(label = "label2", score = 0.7f, color = Color.Blue) // listOf(
), // Classification(label = "label1", score = 0.3f, color = Color.Red),
latencyMs = 12345f, // Classification(label = "label2", score = 0.7f, color = Color.Blue),
), // ),
) // latencyMs = 12345f,
} // )
} // )
// }
// }

View file

@ -16,6 +16,9 @@
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
// import com.google.ai.edge.gallery.ui.preview.MODEL_TEST1
import androidx.compose.foundation.background import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
@ -34,33 +37,26 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip import androidx.compose.ui.draw.clip
import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp import androidx.compose.ui.unit.sp
import com.google.ai.edge.gallery.data.ConfigKey import com.google.ai.edge.gallery.data.convertValueToTargetType
import com.google.ai.edge.gallery.ui.common.convertValueToTargetType import com.google.ai.edge.gallery.data.getConfigValueString
import com.google.ai.edge.gallery.ui.common.getConfigValueString
import com.google.ai.edge.gallery.ui.preview.MODEL_TEST1
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.bodySmallNarrow import com.google.ai.edge.gallery.ui.theme.bodySmallNarrow
import com.google.ai.edge.gallery.ui.theme.titleSmaller import com.google.ai.edge.gallery.ui.theme.titleSmaller
/** /**
* Composable function to display a message indicating configuration value changes. * Composable function to display a message indicating configuration value changes.
* *
* This function renders a centered row containing a box that displays the old and new * This function renders a centered row containing a box that displays the old and new values of
* values of configuration settings that have been updated. * configuration settings that have been updated.
*/ */
@Composable @Composable
fun MessageBodyConfigUpdate(message: ChatMessageConfigValuesChange) { fun MessageBodyConfigUpdate(message: ChatMessageConfigValuesChange) {
Row( Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center) {
modifier = Modifier.fillMaxWidth(),
horizontalArrangement = Arrangement.Center,
) {
Box( Box(
modifier = Modifier modifier =
.clip(RoundedCornerShape(4.dp)) Modifier.clip(RoundedCornerShape(4.dp))
.background(MaterialTheme.colorScheme.tertiaryContainer) .background(MaterialTheme.colorScheme.tertiaryContainer)
) { ) {
Column(modifier = Modifier.padding(8.dp)) { Column(modifier = Modifier.padding(8.dp)) {
// Title. // Title.
@ -74,11 +70,7 @@ fun MessageBodyConfigUpdate(message: ChatMessageConfigValuesChange) {
// Keys // Keys
Column { Column {
for (config in message.model.configs) { for (config in message.model.configs) {
Text( Text("${config.key.label}:", style = bodySmallNarrow, modifier = Modifier.alpha(0.6f))
"${config.key.label}:",
style = bodySmallNarrow,
modifier = Modifier.alpha(0.6f),
)
} }
} }
@ -88,23 +80,25 @@ fun MessageBodyConfigUpdate(message: ChatMessageConfigValuesChange) {
Column { Column {
for (config in message.model.configs) { for (config in message.model.configs) {
val key = config.key.label val key = config.key.label
val oldValue: Any = convertValueToTargetType( val oldValue: Any =
value = message.oldValues.getValue(key), valueType = config.valueType convertValueToTargetType(
) value = message.oldValues.getValue(key),
val newValue: Any = convertValueToTargetType( valueType = config.valueType,
value = message.newValues.getValue(key), valueType = config.valueType )
) val newValue: Any =
convertValueToTargetType(
value = message.newValues.getValue(key),
valueType = config.valueType,
)
if (oldValue == newValue) { if (oldValue == newValue) {
Text("$newValue", style = bodySmallNarrow) Text("$newValue", style = bodySmallNarrow)
} else { } else {
Row(verticalAlignment = Alignment.CenterVertically) { Row(verticalAlignment = Alignment.CenterVertically) {
Text( Text(getConfigValueString(oldValue, config), style = bodySmallNarrow)
getConfigValueString(oldValue, config), style = bodySmallNarrow
)
Text( Text(
"", "",
style = bodySmallNarrow.copy(fontSize = 12.sp), style = bodySmallNarrow.copy(fontSize = 12.sp),
modifier = Modifier.padding(start = 4.dp, end = 4.dp) modifier = Modifier.padding(start = 4.dp, end = 4.dp),
) )
Text( Text(
getConfigValueString(newValue, config), getConfigValueString(newValue, config),
@ -121,24 +115,24 @@ fun MessageBodyConfigUpdate(message: ChatMessageConfigValuesChange) {
} }
} }
@Preview(showBackground = true) // @Preview(showBackground = true)
@Composable // @Composable
fun MessageBodyConfigUpdatePreview() { // fun MessageBodyConfigUpdatePreview() {
GalleryTheme { // GalleryTheme {
Row(modifier = Modifier.padding(16.dp)) { // Row(modifier = Modifier.padding(16.dp)) {
MessageBodyConfigUpdate( // MessageBodyConfigUpdate(
message = ChatMessageConfigValuesChange( // message = ChatMessageConfigValuesChange(
model = MODEL_TEST1, // model = MODEL_TEST1,
oldValues = mapOf( // oldValues = mapOf(
ConfigKey.MAX_RESULT_COUNT.label to 100, // ConfigKey.MAX_RESULT_COUNT.label to 100,
ConfigKey.USE_GPU.label to false // ConfigKey.USE_GPU.label to false
), // ),
newValues = mapOf( // newValues = mapOf(
ConfigKey.MAX_RESULT_COUNT.label to 200, // ConfigKey.MAX_RESULT_COUNT.label to 200,
ConfigKey.USE_GPU.label to true // ConfigKey.USE_GPU.label to true
) // )
) // )
) // )
} // }
} // }
} // }

View file

@ -36,9 +36,7 @@ fun MessageBodyImage(message: ChatMessageImage, modifier: Modifier = Modifier) {
Image( Image(
bitmap = message.imageBitMap, bitmap = message.imageBitMap,
contentDescription = "", contentDescription = "",
modifier = modifier modifier = modifier.height(imageHeight.dp).width(imageWidth.dp),
.height(imageHeight.dp)
.width(imageWidth.dp),
contentScale = ContentScale.Fit, contentScale = ContentScale.Fit,
) )
} }

View file

@ -43,7 +43,7 @@ import androidx.compose.ui.unit.dp
@Composable @Composable
fun MessageBodyImageWithHistory( fun MessageBodyImageWithHistory(
message: ChatMessageImageWithHistory, message: ChatMessageImageWithHistory,
imageHistoryCurIndex: MutableIntState imageHistoryCurIndex: MutableIntState,
) { ) {
val prevMessage: MutableState<ChatMessageImageWithHistory?> = remember { mutableStateOf(null) } val prevMessage: MutableState<ChatMessageImageWithHistory?> = remember { mutableStateOf(null) }
@ -68,15 +68,15 @@ fun MessageBodyImageWithHistory(
Image( Image(
bitmap = curImageBitmap, bitmap = curImageBitmap,
contentDescription = "", contentDescription = "",
modifier = Modifier modifier =
.height(imageHeight.dp) Modifier.height(imageHeight.dp).width(imageWidth.dp).pointerInput(Unit) {
.width(imageWidth.dp) detectHorizontalDragGestures(
.pointerInput(Unit) { onDragStart = {
detectHorizontalDragGestures(onDragStart = { value = 0f
value = 0f savedIndex = imageHistoryCurIndex.intValue
savedIndex = imageHistoryCurIndex.intValue }
}) { _, dragAmount -> ) { _, dragAmount ->
value += (dragAmount / 20f)// Adjust sensitivity here value += (dragAmount / 20f) // Adjust sensitivity here
imageHistoryCurIndex.intValue = (savedIndex + value).toInt() imageHistoryCurIndex.intValue = (savedIndex + value).toInt()
if (imageHistoryCurIndex.intValue < 0) { if (imageHistoryCurIndex.intValue < 0) {
imageHistoryCurIndex.intValue = 0 imageHistoryCurIndex.intValue = 0
@ -88,4 +88,4 @@ fun MessageBodyImageWithHistory(
contentScale = ContentScale.Fit, contentScale = ContentScale.Fit,
) )
} }
} }

View file

@ -16,6 +16,8 @@
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.background import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
@ -27,9 +29,8 @@ import androidx.compose.material3.MaterialTheme
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip import androidx.compose.ui.draw.clip
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.ui.theme.GalleryTheme import com.google.ai.edge.gallery.ui.common.MarkdownText
import com.google.ai.edge.gallery.ui.theme.customColors import com.google.ai.edge.gallery.ui.theme.customColors
/** /**
@ -39,29 +40,27 @@ import com.google.ai.edge.gallery.ui.theme.customColors
*/ */
@Composable @Composable
fun MessageBodyInfo(message: ChatMessageInfo, smallFontSize: Boolean = true) { fun MessageBodyInfo(message: ChatMessageInfo, smallFontSize: Boolean = true) {
Row( Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center) {
modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center
) {
Box( Box(
modifier = Modifier modifier =
.clip(RoundedCornerShape(16.dp)) Modifier.clip(RoundedCornerShape(16.dp))
.background(MaterialTheme.customColors.agentBubbleBgColor) .background(MaterialTheme.customColors.agentBubbleBgColor)
) { ) {
MarkdownText( MarkdownText(
text = message.content, text = message.content,
modifier = Modifier.padding(12.dp), modifier = Modifier.padding(12.dp),
smallFontSize = smallFontSize smallFontSize = smallFontSize,
) )
} }
} }
} }
@Preview(showBackground = true) // @Preview(showBackground = true)
@Composable // @Composable
fun MessageBodyInfoPreview() { // fun MessageBodyInfoPreview() {
GalleryTheme { // GalleryTheme {
Row(modifier = Modifier.padding(16.dp)) { // Row(modifier = Modifier.padding(16.dp)) {
MessageBodyInfo(message = ChatMessageInfo(content = "This is a model")) // MessageBodyInfo(message = ChatMessageInfo(content = "This is a model"))
} // }
} // }
} // }

View file

@ -16,12 +16,12 @@
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.animation.core.Animatable import androidx.compose.animation.core.Animatable
import androidx.compose.animation.core.tween import androidx.compose.animation.core.tween
import androidx.compose.foundation.Image import androidx.compose.foundation.Image
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size import androidx.compose.foundation.layout.size
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.LaunchedEffect
@ -33,29 +33,21 @@ import androidx.compose.ui.graphics.ColorFilter
import androidx.compose.ui.graphics.graphicsLayer import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.layout.ContentScale import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.res.painterResource import androidx.compose.ui.res.painterResource
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.R import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.ui.common.getTaskIconColor import com.google.ai.edge.gallery.ui.common.getTaskIconColor
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
private val IMAGE_RESOURCES = listOf( private val IMAGE_RESOURCES =
R.drawable.pantegon, listOf(R.drawable.pantegon, R.drawable.double_circle, R.drawable.circle, R.drawable.four_circle)
R.drawable.double_circle,
R.drawable.circle,
R.drawable.four_circle
)
private const val ANIMATION_DURATION = 300 private const val ANIMATION_DURATION = 300
private const val ANIMATION_DURATION2 = 300 private const val ANIMATION_DURATION2 = 300
private const val PAUSE_DURATION = 200 private const val PAUSE_DURATION = 200
private const val PAUSE_DURATION2 = 0 private const val PAUSE_DURATION2 = 0
/** /** Composable function to display a loading indicator. */
* Composable function to display a loading indicator.
*/
@Composable @Composable
fun MessageBodyLoading() { fun MessageBodyLoading() {
val progress = remember { Animatable(0f) } val progress = remember { Animatable(0f) }
@ -67,18 +59,17 @@ fun MessageBodyLoading() {
var progressJob = launch { var progressJob = launch {
progress.animateTo( progress.animateTo(
targetValue = 1f, targetValue = 1f,
animationSpec = tween( animationSpec =
durationMillis = ANIMATION_DURATION, tween(
easing = multiBounceEasing(bounces = 3, decay = 0.02f) durationMillis = ANIMATION_DURATION,
) easing = multiBounceEasing(bounces = 3, decay = 0.02f),
),
) )
} }
var alphaJob = launch { var alphaJob = launch {
alphaAnim.animateTo( alphaAnim.animateTo(
targetValue = 1f, targetValue = 1f,
animationSpec = tween( animationSpec = tween(durationMillis = ANIMATION_DURATION / 2),
durationMillis = ANIMATION_DURATION / 2,
)
) )
} }
progressJob.join() progressJob.join()
@ -88,18 +79,17 @@ fun MessageBodyLoading() {
progressJob = launch { progressJob = launch {
progress.animateTo( progress.animateTo(
targetValue = 0f, targetValue = 0f,
animationSpec = tween( animationSpec =
durationMillis = ANIMATION_DURATION2, tween(
easing = multiBounceEasing(bounces = 3, decay = 0.02f) durationMillis = ANIMATION_DURATION2,
) easing = multiBounceEasing(bounces = 3, decay = 0.02f),
),
) )
} }
alphaJob = launch { alphaJob = launch {
alphaAnim.animateTo( alphaAnim.animateTo(
targetValue = 0f, targetValue = 0f,
animationSpec = tween( animationSpec = tween(durationMillis = ANIMATION_DURATION2 / 2),
durationMillis = ANIMATION_DURATION2 / 2,
)
) )
} }
@ -118,25 +108,21 @@ fun MessageBodyLoading() {
contentDescription = "", contentDescription = "",
contentScale = ContentScale.Fit, contentScale = ContentScale.Fit,
colorFilter = ColorFilter.tint(getTaskIconColor(index = index)), colorFilter = ColorFilter.tint(getTaskIconColor(index = index)),
modifier = Modifier modifier =
.graphicsLayer { Modifier.graphicsLayer {
scaleX = progress.value * 0.2f + 0.8f scaleX = progress.value * 0.2f + 0.8f
scaleY = progress.value * 0.2f + 0.8f scaleY = progress.value * 0.2f + 0.8f
rotationZ = progress.value * 100 rotationZ = progress.value * 100
alpha = if (index != activeImageIndex.intValue) 0f else alphaAnim.value alpha = if (index != activeImageIndex.intValue) 0f else alphaAnim.value
} }
.size(24.dp) .size(24.dp),
) )
} }
} }
} }
@Preview(showBackground = true) // @Preview(showBackground = true)
@Composable // @Composable
fun MessageBodyLoadingPreview() { // fun MessageBodyLoadingPreview() {
GalleryTheme { // GalleryTheme { Row(modifier = Modifier.padding(16.dp)) { MessageBodyLoading() } }
Row(modifier = Modifier.padding(16.dp)) { // }
MessageBodyLoading()
}
}
}

View file

@ -16,13 +16,17 @@
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.preview.ALL_PREVIEW_TASKS
// import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.background import androidx.compose.foundation.background
import androidx.compose.foundation.border import androidx.compose.foundation.border
import androidx.compose.foundation.clickable import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.fillMaxWidth
@ -41,13 +45,10 @@ import androidx.compose.ui.draw.shadow
import androidx.compose.ui.graphics.Brush import androidx.compose.ui.graphics.Brush
import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.data.PromptTemplate
import com.google.ai.edge.gallery.data.Task import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.common.getTaskIconColor import com.google.ai.edge.gallery.ui.common.getTaskIconColor
import com.google.ai.edge.gallery.ui.preview.ALL_PREVIEW_TASKS
import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
private const val CARD_HEIGHT = 100 private const val CARD_HEIGHT = 100
@ -63,16 +64,15 @@ fun MessageBodyPromptTemplates(
Column( Column(
modifier = Modifier.padding(top = 12.dp), modifier = Modifier.padding(top = 12.dp),
verticalArrangement = Arrangement.spacedBy(8.dp) verticalArrangement = Arrangement.spacedBy(8.dp),
) { ) {
Text( Text(
"Try an example prompt", "Try an example prompt",
style = MaterialTheme.typography.titleLarge.copy( style =
fontWeight = FontWeight.Bold, MaterialTheme.typography.titleLarge.copy(
brush = Brush.linearGradient( fontWeight = FontWeight.Bold,
colors = gradientColors, brush = Brush.linearGradient(colors = gradientColors),
) ),
),
modifier = Modifier.fillMaxWidth(), modifier = Modifier.fillMaxWidth(),
textAlign = TextAlign.Center, textAlign = TextAlign.Center,
) )
@ -80,41 +80,30 @@ fun MessageBodyPromptTemplates(
Text( Text(
"Or make your own", "Or make your own",
style = MaterialTheme.typography.titleSmall, style = MaterialTheme.typography.titleSmall,
modifier = Modifier modifier = Modifier.fillMaxWidth().offset(y = (-4).dp),
.fillMaxWidth()
.offset(y = (-4).dp),
textAlign = TextAlign.Center, textAlign = TextAlign.Center,
) )
} }
LazyColumn( LazyColumn(
modifier = Modifier modifier = Modifier.height((rowCount * (CARD_HEIGHT + 8)).dp),
.height((rowCount * (CARD_HEIGHT + 8)).dp),
verticalArrangement = Arrangement.spacedBy(8.dp), verticalArrangement = Arrangement.spacedBy(8.dp),
) { ) {
// Cards. // Cards.
items(message.templates) { template -> items(message.templates) { template ->
Box( Box(
modifier = Modifier modifier =
.border( Modifier.border(
width = 1.dp, width = 1.dp,
color = color.copy(alpha = 0.3f), color = color.copy(alpha = 0.3f),
shape = RoundedCornerShape(24.dp) shape = RoundedCornerShape(24.dp),
) )
.height(CARD_HEIGHT.dp) .height(CARD_HEIGHT.dp)
.shadow( .shadow(elevation = 2.dp, shape = RoundedCornerShape(24.dp), spotColor = color)
elevation = 2.dp, .background(MaterialTheme.colorScheme.surface)
shape = RoundedCornerShape(24.dp), .clickable { onPromptClicked(template) }
spotColor = color
)
.background(MaterialTheme.colorScheme.surface)
.clickable {
onPromptClicked(template)
}
) { ) {
Column( Column(
modifier = Modifier modifier = Modifier.padding(horizontal = 12.dp, vertical = 20.dp).fillMaxSize(),
.padding(horizontal = 12.dp, vertical = 20.dp)
.fillMaxSize(),
horizontalAlignment = Alignment.CenterHorizontally, horizontalAlignment = Alignment.CenterHorizontally,
) { ) {
Text( Text(
@ -134,35 +123,37 @@ fun MessageBodyPromptTemplates(
} }
} }
@Preview(showBackground = true) // @Preview(showBackground = true)
@Composable // @Composable
fun MessageBodyPromptTemplatesPreview() { // fun MessageBodyPromptTemplatesPreview() {
for ((index, task) in ALL_PREVIEW_TASKS.withIndex()) { // for ((index, task) in ALL_PREVIEW_TASKS.withIndex()) {
task.index = index // task.index = index
for (model in task.models) { // for (model in task.models) {
model.preProcess() // model.preProcess()
} // }
} // }
GalleryTheme { // GalleryTheme {
Row(modifier = Modifier.padding(16.dp)) { // Row(modifier = Modifier.padding(16.dp)) {
MessageBodyPromptTemplates( // MessageBodyPromptTemplates(
message = ChatMessagePromptTemplates( // message =
templates = listOf( // ChatMessagePromptTemplates(
PromptTemplate( // templates =
title = "Math Worksheets", // listOf(
description = "Create a set of math worksheets for parents", // PromptTemplate(
prompt = "" // title = "Math Worksheets",
), // description = "Create a set of math worksheets for parents",
PromptTemplate( // prompt = "",
title = "Shape Sequencer", // ),
description = "Find the next shape in a sequence", // PromptTemplate(
prompt = "" // title = "Shape Sequencer",
) // description = "Find the next shape in a sequence",
) // prompt = "",
), // ),
task = TASK_TEST1, // )
) // ),
} // task = TASK_TEST1,
} // )
} // }
// }
// }

View file

@ -16,9 +16,9 @@
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common.chat
import androidx.compose.foundation.background // import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.layout.Column // import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text import androidx.compose.material3.Text
@ -26,13 +26,10 @@ import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.Color
import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.ui.theme.GalleryTheme import com.google.ai.edge.gallery.ui.common.MarkdownText
/** /** Composable function to display the text content of a ChatMessageText. */
* Composable function to display the text content of a ChatMessageText.
*/
@Composable @Composable
fun MessageBodyText(message: ChatMessageText) { fun MessageBodyText(message: ChatMessageText) {
if (message.side == ChatSide.USER) { if (message.side == ChatSide.USER) {
@ -40,7 +37,7 @@ fun MessageBodyText(message: ChatMessageText) {
message.content, message.content,
style = MaterialTheme.typography.bodyLarge.copy(fontWeight = FontWeight.Medium), style = MaterialTheme.typography.bodyLarge.copy(fontWeight = FontWeight.Medium),
color = Color.White, color = Color.White,
modifier = Modifier.padding(12.dp) modifier = Modifier.padding(12.dp),
) )
} else if (message.side == ChatSide.AGENT) { } else if (message.side == ChatSide.AGENT) {
if (message.isMarkdown) { if (message.isMarkdown) {
@ -50,31 +47,25 @@ fun MessageBodyText(message: ChatMessageText) {
message.content, message.content,
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.Medium), style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.Medium),
color = MaterialTheme.colorScheme.onSurface, color = MaterialTheme.colorScheme.onSurface,
modifier = Modifier.padding(12.dp) modifier = Modifier.padding(12.dp),
) )
} }
} }
} }
@Preview(showBackground = true) // @Preview(showBackground = true)
@Composable // @Composable
fun MessageBodyTextPreview() { // fun MessageBodyTextPreview() {
GalleryTheme { // GalleryTheme {
Column { // Column {
Row( // Row(modifier = Modifier.padding(16.dp).background(MaterialTheme.colorScheme.primary)) {
modifier = Modifier // MessageBodyText(ChatMessageText(content = "Hello world", side = ChatSide.USER))
.padding(16.dp) // }
.background(MaterialTheme.colorScheme.primary), // Row(
) { // modifier = Modifier.padding(16.dp).background(MaterialTheme.colorScheme.surfaceContainer)
MessageBodyText(ChatMessageText(content = "Hello world", side = ChatSide.USER)) // ) {
} // MessageBodyText(ChatMessageText(content = "yes hello world", side = ChatSide.AGENT))
Row( // }
modifier = Modifier // }
.padding(16.dp) // }
.background(MaterialTheme.colorScheme.surfaceContainer), // }
) {
MessageBodyText(ChatMessageText(content = "yes hello world", side = ChatSide.AGENT))
}
}
}
}

View file

@ -16,6 +16,8 @@
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.background import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
@ -27,9 +29,8 @@ import androidx.compose.material3.MaterialTheme
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip import androidx.compose.ui.draw.clip
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.ui.theme.GalleryTheme import com.google.ai.edge.gallery.ui.common.MarkdownText
/** /**
* Composable function to display warning message content within a chat. * Composable function to display warning message content within a chat.
@ -38,29 +39,27 @@ import com.google.ai.edge.gallery.ui.theme.GalleryTheme
*/ */
@Composable @Composable
fun MessageBodyWarning(message: ChatMessageWarning) { fun MessageBodyWarning(message: ChatMessageWarning) {
Row( Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center) {
modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center
) {
Box( Box(
modifier = Modifier modifier =
.clip(RoundedCornerShape(16.dp)) Modifier.clip(RoundedCornerShape(16.dp))
.background(MaterialTheme.colorScheme.tertiaryContainer) .background(MaterialTheme.colorScheme.tertiaryContainer)
) { ) {
MarkdownText( MarkdownText(
text = message.content, text = message.content,
modifier = Modifier.padding(horizontal = 16.dp, vertical = 6.dp), modifier = Modifier.padding(horizontal = 16.dp, vertical = 6.dp),
smallFontSize = true smallFontSize = true,
) )
} }
} }
} }
@Preview(showBackground = true) // @Preview(showBackground = true)
@Composable // @Composable
fun MessageBodyWarningPreview() { // fun MessageBodyWarningPreview() {
GalleryTheme { // GalleryTheme {
Row(modifier = Modifier.padding(16.dp)) { // Row(modifier = Modifier.padding(16.dp)) {
MessageBodyWarning(message = ChatMessageWarning(content = "This is a warning")) // MessageBodyWarning(message = ChatMessageWarning(content = "This is a warning"))
} // }
} // }
} // }

View file

@ -29,41 +29,39 @@ import androidx.compose.ui.unit.LayoutDirection
/** /**
* Custom Shape for creating message bubble outlines with configurable corner radii. * Custom Shape for creating message bubble outlines with configurable corner radii.
* *
* This class defines a custom Shape that generates a rounded rectangle outline, * This class defines a custom Shape that generates a rounded rectangle outline, suitable for
* suitable for message bubbles. It allows specifying a uniform corner radius for * message bubbles. It allows specifying a uniform corner radius for most corners, but also provides
* most corners, but also provides the option to have a hard (non-rounded) corner * the option to have a hard (non-rounded) corner on either the left or right side.
* on either the left or right side.
*/ */
class MessageBubbleShape( class MessageBubbleShape(
private val radius: Dp, private val radius: Dp,
private val hardCornerAtLeftOrRight: Boolean = false private val hardCornerAtLeftOrRight: Boolean = false,
) : Shape { ) : Shape {
override fun createOutline( override fun createOutline(
size: Size, size: Size,
layoutDirection: LayoutDirection, layoutDirection: LayoutDirection,
density: Density density: Density,
): Outline { ): Outline {
val radiusPx = with(density) { radius.toPx() } val radiusPx = with(density) { radius.toPx() }
val path = Path().apply { val path =
addRoundRect( Path().apply {
RoundRect( addRoundRect(
left = 0f, RoundRect(
top = 0f, left = 0f,
right = size.width, top = 0f,
bottom = size.height, right = size.width,
topLeftCornerRadius = if (hardCornerAtLeftOrRight) CornerRadius(0f, 0f) else CornerRadius( bottom = size.height,
radiusPx, topLeftCornerRadius =
radiusPx if (hardCornerAtLeftOrRight) CornerRadius(0f, 0f)
), else CornerRadius(radiusPx, radiusPx),
topRightCornerRadius = if (hardCornerAtLeftOrRight) CornerRadius( topRightCornerRadius =
radiusPx, if (hardCornerAtLeftOrRight) CornerRadius(radiusPx, radiusPx)
radiusPx else CornerRadius(0f, 0f), // No rounding here
) else CornerRadius(0f, 0f), // No rounding here bottomLeftCornerRadius = CornerRadius(radiusPx, radiusPx),
bottomLeftCornerRadius = CornerRadius(radiusPx, radiusPx), bottomRightCornerRadius = CornerRadius(radiusPx, radiusPx),
bottomRightCornerRadius = CornerRadius(radiusPx, radiusPx) )
) )
) }
}
return Outline.Generic(path) return Outline.Generic(path)
} }
} }

View file

@ -16,6 +16,9 @@
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import android.Manifest import android.Manifest
import android.content.Context import android.content.Context
import android.content.pm.PackageManager import android.content.pm.PackageManager
@ -28,7 +31,6 @@ import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.PickVisualMediaRequest import androidx.activity.result.PickVisualMediaRequest
import androidx.activity.result.contract.ActivityResultContracts import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
@ -49,11 +51,9 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha import androidx.compose.ui.draw.alpha
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.core.content.ContextCompat import androidx.core.content.ContextCompat
import com.google.ai.edge.gallery.ui.common.createTempPictureUri import com.google.ai.edge.gallery.ui.common.createTempPictureUri
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
private const val TAG = "AGMessageInputImage" private const val TAG = "AGMessageInputImage"
@ -102,30 +102,28 @@ fun MessageInputImage(
} }
// Permission request when taking picture. // Permission request when taking picture.
val takePicturePermissionLauncher = rememberLauncherForActivityResult( val takePicturePermissionLauncher =
ActivityResultContracts.RequestPermission() rememberLauncherForActivityResult(ActivityResultContracts.RequestPermission()) {
) { permissionGranted -> permissionGranted ->
if (permissionGranted) { if (permissionGranted) {
tempPhotoUri = context.createTempPictureUri() tempPhotoUri = context.createTempPictureUri()
cameraLauncher.launch(tempPhotoUri) cameraLauncher.launch(tempPhotoUri)
}
} }
}
// Permission request when using live camera. // Permission request when using live camera.
val liveCameraPermissionLauncher = rememberLauncherForActivityResult( val liveCameraPermissionLauncher =
ActivityResultContracts.RequestPermission() rememberLauncherForActivityResult(ActivityResultContracts.RequestPermission()) {
) { permissionGranted -> permissionGranted ->
if (permissionGranted) { if (permissionGranted) {
showLiveCameraDialog = true showLiveCameraDialog = true
}
} }
}
val buttonAlpha = if (disableButtons) 0.3f else 1f val buttonAlpha = if (disableButtons) 0.3f else 1f
Row( Row(
modifier = Modifier modifier = Modifier.fillMaxWidth().padding(12.dp),
.fillMaxWidth()
.padding(12.dp),
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.End, horizontalArrangement = Arrangement.End,
) { ) {
@ -139,9 +137,8 @@ fun MessageInputImage(
// Launch the photo picker and let the user choose only images. // Launch the photo picker and let the user choose only images.
pickMedia.launch(PickVisualMediaRequest(ActivityResultContracts.PickVisualMedia.ImageOnly)) pickMedia.launch(PickVisualMediaRequest(ActivityResultContracts.PickVisualMedia.ImageOnly))
}, },
colors = IconButtonDefaults.iconButtonColors( colors =
containerColor = MaterialTheme.colorScheme.primary, IconButtonDefaults.iconButtonColors(containerColor = MaterialTheme.colorScheme.primary),
),
modifier = Modifier.alpha(buttonAlpha), modifier = Modifier.alpha(buttonAlpha),
) { ) {
Icon(Icons.Rounded.Photo, contentDescription = "", tint = MaterialTheme.colorScheme.onPrimary) Icon(Icons.Rounded.Photo, contentDescription = "", tint = MaterialTheme.colorScheme.onPrimary)
@ -157,9 +154,7 @@ fun MessageInputImage(
// Check permission // Check permission
when (PackageManager.PERMISSION_GRANTED) { when (PackageManager.PERMISSION_GRANTED) {
// Already got permission. Call the lambda. // Already got permission. Call the lambda.
ContextCompat.checkSelfPermission( ContextCompat.checkSelfPermission(context, Manifest.permission.CAMERA) -> {
context, Manifest.permission.CAMERA
) -> {
tempPhotoUri = context.createTempPictureUri() tempPhotoUri = context.createTempPictureUri()
cameraLauncher.launch(tempPhotoUri) cameraLauncher.launch(tempPhotoUri)
} }
@ -170,15 +165,14 @@ fun MessageInputImage(
} }
} }
}, },
colors = IconButtonDefaults.iconButtonColors( colors =
containerColor = MaterialTheme.colorScheme.primary, IconButtonDefaults.iconButtonColors(containerColor = MaterialTheme.colorScheme.primary),
),
modifier = Modifier.alpha(buttonAlpha), modifier = Modifier.alpha(buttonAlpha),
) { ) {
Icon( Icon(
Icons.Rounded.PhotoCamera, Icons.Rounded.PhotoCamera,
contentDescription = "", contentDescription = "",
tint = MaterialTheme.colorScheme.onPrimary tint = MaterialTheme.colorScheme.onPrimary,
) )
} }
@ -192,9 +186,7 @@ fun MessageInputImage(
// Check permission // Check permission
when (PackageManager.PERMISSION_GRANTED) { when (PackageManager.PERMISSION_GRANTED) {
// Already got permission. Call the lambda. // Already got permission. Call the lambda.
ContextCompat.checkSelfPermission( ContextCompat.checkSelfPermission(context, Manifest.permission.CAMERA) -> {
context, Manifest.permission.CAMERA
) -> {
showLiveCameraDialog = true showLiveCameraDialog = true
} }
@ -204,25 +196,30 @@ fun MessageInputImage(
} }
} }
}, },
colors = IconButtonDefaults.iconButtonColors( colors =
containerColor = MaterialTheme.colorScheme.primary, IconButtonDefaults.iconButtonColors(containerColor = MaterialTheme.colorScheme.primary),
),
modifier = Modifier.alpha(buttonAlpha), modifier = Modifier.alpha(buttonAlpha),
) { ) {
Icon( Icon(
Icons.Rounded.Videocam, contentDescription = "", tint = MaterialTheme.colorScheme.onPrimary Icons.Rounded.Videocam,
contentDescription = "",
tint = MaterialTheme.colorScheme.onPrimary,
) )
} }
} }
// Live camera stream dialog. // Live camera stream dialog.
if (showLiveCameraDialog) { if (showLiveCameraDialog) {
LiveCameraDialog( // TODO(migration)
streamingMessage = streamingMessage, onDismissed = { averageFps -> //
onStreamEnd(averageFps) // LiveCameraDialog(
showLiveCameraDialog = false // streamingMessage = streamingMessage,
}, onBitmap = onStreamImage // onDismissed = { averageFps ->
) // onStreamEnd(averageFps)
// showLiveCameraDialog = false
// },
// onBitmap = onStreamImage,
// )
} }
} }
@ -237,33 +234,33 @@ private fun handleImageSelected(
) { ) {
Log.d(TAG, "Selected URI: $uri") Log.d(TAG, "Selected URI: $uri")
val bitmap: Bitmap? = try { val bitmap: Bitmap? =
val inputStream = context.contentResolver.openInputStream(uri) try {
val tmpBitmap = BitmapFactory.decodeStream(inputStream) val inputStream = context.contentResolver.openInputStream(uri)
if (rotateForPortrait && tmpBitmap.width > tmpBitmap.height) { val tmpBitmap = BitmapFactory.decodeStream(inputStream)
val matrix = Matrix() if (rotateForPortrait && tmpBitmap.width > tmpBitmap.height) {
matrix.postRotate(90f) val matrix = Matrix()
Bitmap.createBitmap(tmpBitmap, 0, 0, tmpBitmap.width, tmpBitmap.height, matrix, true) matrix.postRotate(90f)
} else { Bitmap.createBitmap(tmpBitmap, 0, 0, tmpBitmap.width, tmpBitmap.height, matrix, true)
tmpBitmap } else {
tmpBitmap
}
} catch (e: Exception) {
e.printStackTrace()
null
} }
} catch (e: Exception) {
e.printStackTrace()
null
}
if (bitmap != null) { if (bitmap != null) {
onImageSelected(bitmap) onImageSelected(bitmap)
} }
} }
@Preview(showBackground = true) // @Preview(showBackground = true)
@Composable // @Composable
fun MessageInputImagePreview() { // fun MessageInputImagePreview() {
GalleryTheme { // GalleryTheme {
Column { // Column {
MessageInputImage(onImageSelected = {}) // MessageInputImage(onImageSelected = {})
MessageInputImage(disableButtons = true, onImageSelected = {}) // MessageInputImage(disableButtons = true, onImageSelected = {})
} // }
} // }
} // }

View file

@ -16,6 +16,9 @@
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import android.Manifest import android.Manifest
import android.content.Context import android.content.Context
import android.content.pm.PackageManager import android.content.pm.PackageManager
@ -43,9 +46,9 @@ import androidx.compose.foundation.Image
import androidx.compose.foundation.background import androidx.compose.foundation.background
import androidx.compose.foundation.border import androidx.compose.foundation.border
import androidx.compose.foundation.clickable import androidx.compose.foundation.clickable
import androidx.compose.foundation.horizontalScroll
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxSize
@ -55,6 +58,7 @@ import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.width import androidx.compose.foundation.layout.width
import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.shape.CircleShape import androidx.compose.foundation.shape.CircleShape
import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
@ -98,26 +102,22 @@ import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.asImageBitmap import androidx.compose.ui.graphics.asImageBitmap
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.stringResource import androidx.compose.ui.res.stringResource
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.compose.ui.viewinterop.AndroidView import androidx.compose.ui.viewinterop.AndroidView
import androidx.core.content.ContextCompat import androidx.core.content.ContextCompat
import androidx.lifecycle.compose.LocalLifecycleOwner import androidx.lifecycle.compose.LocalLifecycleOwner
import com.google.ai.edge.gallery.R import com.google.ai.edge.gallery.data.MAX_IMAGE_COUNT
import com.google.ai.edge.gallery.ui.common.createTempPictureUri
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import kotlinx.coroutines.launch
import java.util.concurrent.Executors import java.util.concurrent.Executors
import kotlinx.coroutines.launch
private const val TAG = "AGMessageInputText" private const val TAG = "AGMessageInputText"
/** /**
* Composable function to display a text input field for composing chat messages. * Composable function to display a text input field for composing chat messages.
* *
* This function renders a row containing a text field for message input and a send button. * This function renders a row containing a text field for message input and a send button. It
* It handles message composition, input validation, and sending messages. * handles message composition, input validation, and sending messages.
*/ */
@OptIn(ExperimentalMaterial3Api::class) @OptIn(ExperimentalMaterial3Api::class)
@Composable @Composable
@ -126,7 +126,7 @@ fun MessageInputText(
curMessage: String, curMessage: String,
isResettingSession: Boolean, isResettingSession: Boolean,
inProgress: Boolean, inProgress: Boolean,
hasImageMessage: Boolean, imageMessageCount: Int,
modelInitializing: Boolean, modelInitializing: Boolean,
@StringRes textFieldPlaceHolderRes: Int, @StringRes textFieldPlaceHolderRes: Int,
onValueChanged: (String) -> Unit, onValueChanged: (String) -> Unit,
@ -145,73 +145,80 @@ fun MessageInputText(
var showTextInputHistorySheet by remember { mutableStateOf(false) } var showTextInputHistorySheet by remember { mutableStateOf(false) }
var showCameraCaptureBottomSheet by remember { mutableStateOf(false) } var showCameraCaptureBottomSheet by remember { mutableStateOf(false) }
val cameraCaptureSheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true) val cameraCaptureSheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
var tempPhotoUri by remember { mutableStateOf(value = Uri.EMPTY) }
var pickedImages by remember { mutableStateOf<List<Bitmap>>(listOf()) } var pickedImages by remember { mutableStateOf<List<Bitmap>>(listOf()) }
val updatePickedImages: (Bitmap) -> Unit = { bitmap -> val updatePickedImages: (List<Bitmap>) -> Unit = { bitmaps ->
val newPickedImages: MutableList<Bitmap> = mutableListOf() var newPickedImages: MutableList<Bitmap> = mutableListOf()
newPickedImages.addAll(pickedImages) newPickedImages.addAll(pickedImages)
newPickedImages.add(bitmap) newPickedImages.addAll(bitmaps)
if (newPickedImages.size > MAX_IMAGE_COUNT) {
newPickedImages = newPickedImages.subList(fromIndex = 0, toIndex = MAX_IMAGE_COUNT)
}
pickedImages = newPickedImages.toList() pickedImages = newPickedImages.toList()
} }
var hasFrontCamera by remember { mutableStateOf(false) } var hasFrontCamera by remember { mutableStateOf(false) }
LaunchedEffect(Unit) { LaunchedEffect(Unit) { checkFrontCamera(context = context, callback = { hasFrontCamera = it }) }
checkFrontCamera(context = context, callback = { hasFrontCamera = it })
}
// Permission request when taking picture. // Permission request when taking picture.
val takePicturePermissionLauncher = rememberLauncherForActivityResult( val takePicturePermissionLauncher =
ActivityResultContracts.RequestPermission() rememberLauncherForActivityResult(ActivityResultContracts.RequestPermission()) {
) { permissionGranted -> permissionGranted ->
if (permissionGranted) { if (permissionGranted) {
showAddContentMenu = false showAddContentMenu = false
tempPhotoUri = context.createTempPictureUri() showCameraCaptureBottomSheet = true
showCameraCaptureBottomSheet = true }
} }
}
// Registers a photo picker activity launcher in single-select mode. // Registers a photo picker activity launcher in single-select mode.
val pickMedia = val pickMedia =
rememberLauncherForActivityResult(ActivityResultContracts.PickVisualMedia()) { uri -> rememberLauncherForActivityResult(ActivityResultContracts.PickMultipleVisualMedia()) { uris ->
// Callback is invoked after the user selects a media item or closes the // Callback is invoked after the user selects media items or closes the
// photo picker. // photo picker.
if (uri != null) { if (uris.isNotEmpty()) {
handleImageSelected(context = context, uri = uri, onImageSelected = { bitmap -> handleImagesSelected(
updatePickedImages(bitmap) context = context,
}) uris = uris,
onImagesSelected = { bitmaps -> updatePickedImages(bitmaps) },
)
} }
} }
Box(contentAlignment = Alignment.CenterStart) { Box(contentAlignment = Alignment.CenterStart) {
// A preview panel for the selected image. // A preview panel for the selected image.
if (pickedImages.isNotEmpty()) { if (pickedImages.isNotEmpty()) {
Box( Row(
contentAlignment = Alignment.TopEnd, modifier = Modifier.offset(x = 16.dp, y = (-80).dp) modifier =
Modifier.offset(x = 16.dp, y = (-80).dp)
.fillMaxWidth()
.horizontalScroll(rememberScrollState()),
horizontalArrangement = Arrangement.spacedBy(16.dp),
) { ) {
Image( for (image in pickedImages) {
bitmap = pickedImages.last().asImageBitmap(), Box(contentAlignment = Alignment.TopEnd) {
contentDescription = "", Image(
modifier = Modifier bitmap = image.asImageBitmap(),
.height(80.dp) contentDescription = "",
.shadow(2.dp, shape = RoundedCornerShape(8.dp)) modifier =
.clip(RoundedCornerShape(8.dp)) Modifier.height(80.dp)
.border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp)), .shadow(2.dp, shape = RoundedCornerShape(8.dp))
) .clip(RoundedCornerShape(8.dp))
Box(modifier = Modifier .border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp)),
.offset(x = 10.dp, y = (-10).dp) )
.clip(CircleShape) Box(
.background(MaterialTheme.colorScheme.surface) modifier =
.border((1.5).dp, MaterialTheme.colorScheme.outline, CircleShape) Modifier.offset(x = 10.dp, y = (-10).dp)
.clickable { .clip(CircleShape)
pickedImages = listOf() .background(MaterialTheme.colorScheme.surface)
}) { .border((1.5).dp, MaterialTheme.colorScheme.outline, CircleShape)
Icon( .clickable { pickedImages = pickedImages.filter { image != it } }
Icons.Rounded.Close, ) {
contentDescription = "", Icon(
modifier = Modifier Icons.Rounded.Close,
.padding(3.dp) contentDescription = "",
.size(16.dp) modifier = Modifier.padding(3.dp).size(16.dp),
) )
}
}
} }
} }
} }
@ -220,48 +227,41 @@ fun MessageInputText(
IconButton( IconButton(
enabled = !inProgress && !isResettingSession, enabled = !inProgress && !isResettingSession,
onClick = { showAddContentMenu = true }, onClick = { showAddContentMenu = true },
modifier = Modifier modifier = Modifier.offset(x = 16.dp).alpha(0.8f),
.offset(x = 16.dp)
.alpha(0.8f)
) { ) {
Icon( Icon(Icons.Rounded.Add, contentDescription = "", modifier = Modifier.size(28.dp))
Icons.Rounded.Add,
contentDescription = "",
modifier = Modifier.size(28.dp),
)
} }
Row( Row(
modifier = Modifier modifier =
.fillMaxWidth() Modifier.fillMaxWidth()
.padding(12.dp) .padding(12.dp)
.border(1.dp, MaterialTheme.colorScheme.outlineVariant, RoundedCornerShape(28.dp)), .border(1.dp, MaterialTheme.colorScheme.outlineVariant, RoundedCornerShape(28.dp)),
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
) { ) {
val enableAddImageMenuItems = (imageMessageCount + pickedImages.size) < MAX_IMAGE_COUNT
DropdownMenu( DropdownMenu(
expanded = showAddContentMenu, expanded = showAddContentMenu,
onDismissRequest = { showAddContentMenu = false }) { onDismissRequest = { showAddContentMenu = false },
) {
if (showImagePickerInMenu) { if (showImagePickerInMenu) {
// Take a picture. // Take a picture.
DropdownMenuItem( DropdownMenuItem(
text = { text = {
Row( Row(
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp) horizontalArrangement = Arrangement.spacedBy(6.dp),
) { ) {
Icon(Icons.Rounded.PhotoCamera, contentDescription = "") Icon(Icons.Rounded.PhotoCamera, contentDescription = "")
Text("Take a picture") Text("Take a picture")
} }
}, },
enabled = pickedImages.isEmpty() && !hasImageMessage, enabled = enableAddImageMenuItems,
onClick = { onClick = {
// Check permission // Check permission
when (PackageManager.PERMISSION_GRANTED) { when (PackageManager.PERMISSION_GRANTED) {
// Already got permission. Call the lambda. // Already got permission. Call the lambda.
ContextCompat.checkSelfPermission( ContextCompat.checkSelfPermission(context, Manifest.permission.CAMERA) -> {
context, Manifest.permission.CAMERA
) -> {
showAddContentMenu = false showAddContentMenu = false
tempPhotoUri = context.createTempPictureUri()
showCameraCaptureBottomSheet = true showCameraCaptureBottomSheet = true
} }
@ -270,75 +270,86 @@ fun MessageInputText(
takePicturePermissionLauncher.launch(Manifest.permission.CAMERA) takePicturePermissionLauncher.launch(Manifest.permission.CAMERA)
} }
} }
}) },
)
// Pick an image from album. // Pick an image from album.
DropdownMenuItem( DropdownMenuItem(
text = { text = {
Row( Row(
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp) horizontalArrangement = Arrangement.spacedBy(6.dp),
) { ) {
Icon(Icons.Rounded.Photo, contentDescription = "") Icon(Icons.Rounded.Photo, contentDescription = "")
Text("Pick from album") Text("Pick from album")
} }
}, },
enabled = pickedImages.isEmpty() && !hasImageMessage, enabled = enableAddImageMenuItems,
onClick = { onClick = {
// Launch the photo picker and let the user choose only images. // Launch the photo picker and let the user choose only images.
pickMedia.launch(PickVisualMediaRequest(ActivityResultContracts.PickVisualMedia.ImageOnly)) pickMedia.launch(
PickVisualMediaRequest(ActivityResultContracts.PickVisualMedia.ImageOnly)
)
showAddContentMenu = false showAddContentMenu = false
}) },
)
} }
// Prompt templates. // Prompt templates.
if (showPromptTemplatesInMenu) { if (showPromptTemplatesInMenu) {
DropdownMenuItem(text = { DropdownMenuItem(
Row( text = {
verticalAlignment = Alignment.CenterVertically, Row(
horizontalArrangement = Arrangement.spacedBy(6.dp) verticalAlignment = Alignment.CenterVertically,
) { horizontalArrangement = Arrangement.spacedBy(6.dp),
Icon(Icons.Rounded.PostAdd, contentDescription = "") ) {
Text("Prompt templates") Icon(Icons.Rounded.PostAdd, contentDescription = "")
} Text("Prompt templates")
}, onClick = { }
onOpenPromptTemplatesClicked() },
showAddContentMenu = false onClick = {
}) onOpenPromptTemplatesClicked()
showAddContentMenu = false
},
)
} }
// Prompt history. // Prompt history.
DropdownMenuItem(text = { DropdownMenuItem(
Row( text = {
verticalAlignment = Alignment.CenterVertically, Row(
horizontalArrangement = Arrangement.spacedBy(6.dp) verticalAlignment = Alignment.CenterVertically,
) { horizontalArrangement = Arrangement.spacedBy(6.dp),
Icon(Icons.Rounded.History, contentDescription = "") ) {
Text("Input history") Icon(Icons.Rounded.History, contentDescription = "")
} Text("Input history")
}, onClick = { }
showAddContentMenu = false },
showTextInputHistorySheet = true onClick = {
}) showAddContentMenu = false
showTextInputHistorySheet = true
},
)
} }
// Text field. // Text field.
TextField(value = curMessage, TextField(
value = curMessage,
minLines = 1, minLines = 1,
maxLines = 3, maxLines = 3,
onValueChange = onValueChanged, onValueChange = onValueChanged,
colors = TextFieldDefaults.colors( colors =
unfocusedContainerColor = Color.Transparent, TextFieldDefaults.colors(
focusedContainerColor = Color.Transparent, unfocusedContainerColor = Color.Transparent,
focusedIndicatorColor = Color.Transparent, focusedContainerColor = Color.Transparent,
unfocusedIndicatorColor = Color.Transparent, focusedIndicatorColor = Color.Transparent,
disabledIndicatorColor = Color.Transparent, unfocusedIndicatorColor = Color.Transparent,
disabledContainerColor = Color.Transparent, disabledIndicatorColor = Color.Transparent,
), disabledContainerColor = Color.Transparent,
),
textStyle = MaterialTheme.typography.bodyLarge, textStyle = MaterialTheme.typography.bodyLarge,
modifier = Modifier modifier = Modifier.weight(1f).padding(start = 36.dp),
.weight(1f) placeholder = { Text(stringResource(textFieldPlaceHolderRes)) },
.padding(start = 36.dp), )
placeholder = { Text(stringResource(textFieldPlaceHolderRes)) })
Spacer(modifier = Modifier.width(8.dp)) Spacer(modifier = Modifier.width(8.dp))
@ -346,12 +357,15 @@ fun MessageInputText(
if (!modelInitializing && !modelPreparing) { if (!modelInitializing && !modelPreparing) {
IconButton( IconButton(
onClick = onStopButtonClicked, onClick = onStopButtonClicked,
colors = IconButtonDefaults.iconButtonColors( colors =
containerColor = MaterialTheme.colorScheme.secondaryContainer, IconButtonDefaults.iconButtonColors(
), containerColor = MaterialTheme.colorScheme.secondaryContainer
),
) { ) {
Icon( Icon(
Icons.Rounded.Stop, contentDescription = "", tint = MaterialTheme.colorScheme.primary Icons.Rounded.Stop,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary,
) )
} }
} }
@ -365,15 +379,16 @@ fun MessageInputText(
) )
pickedImages = listOf() pickedImages = listOf()
}, },
colors = IconButtonDefaults.iconButtonColors( colors =
containerColor = MaterialTheme.colorScheme.secondaryContainer, IconButtonDefaults.iconButtonColors(
), containerColor = MaterialTheme.colorScheme.secondaryContainer
),
) { ) {
Icon( Icon(
Icons.AutoMirrored.Rounded.Send, Icons.AutoMirrored.Rounded.Send,
contentDescription = "", contentDescription = "",
modifier = Modifier.offset(x = 2.dp), modifier = Modifier.offset(x = 2.dp),
tint = MaterialTheme.colorScheme.primary tint = MaterialTheme.colorScheme.primary,
) )
} }
} }
@ -385,40 +400,37 @@ fun MessageInputText(
if (showTextInputHistorySheet) { if (showTextInputHistorySheet) {
TextInputHistorySheet( TextInputHistorySheet(
history = modelManagerUiState.textInputHistory, history = modelManagerUiState.textInputHistory,
onDismissed = { onDismissed = { showTextInputHistorySheet = false },
showTextInputHistorySheet = false
},
onHistoryItemClicked = { item -> onHistoryItemClicked = { item ->
onSendMessage(createMessagesToSend(pickedImages = pickedImages, text = item)) onSendMessage(createMessagesToSend(pickedImages = pickedImages, text = item))
pickedImages = listOf() pickedImages = listOf()
modelManagerViewModel.promoteTextInputHistoryItem(item) modelManagerViewModel.promoteTextInputHistoryItem(item)
}, },
onHistoryItemDeleted = { item -> onHistoryItemDeleted = { item -> modelManagerViewModel.deleteTextInputHistory(item) },
modelManagerViewModel.deleteTextInputHistory(item) onHistoryItemsDeleteAll = { modelManagerViewModel.clearTextInputHistory() },
}, )
onHistoryItemsDeleteAll = {
modelManagerViewModel.clearTextInputHistory()
})
} }
if (showCameraCaptureBottomSheet) { if (showCameraCaptureBottomSheet) {
ModalBottomSheet( ModalBottomSheet(
sheetState = cameraCaptureSheetState, sheetState = cameraCaptureSheetState,
onDismissRequest = { showCameraCaptureBottomSheet = false }) { onDismissRequest = { showCameraCaptureBottomSheet = false },
) {
val lifecycleOwner = LocalLifecycleOwner.current val lifecycleOwner = LocalLifecycleOwner.current
val previewUseCase = remember { androidx.camera.core.Preview.Builder().build() } val previewUseCase = remember { androidx.camera.core.Preview.Builder().build() }
val imageCaptureUseCase = remember { val imageCaptureUseCase = remember {
// Try to limit the image size. // Try to limit the image size.
val preferredSize = Size(512, 512) val preferredSize = Size(512, 512)
val resolutionStrategy = ResolutionStrategy( val resolutionStrategy =
preferredSize, ResolutionStrategy(
ResolutionStrategy.FALLBACK_RULE_CLOSEST_HIGHER_THEN_LOWER preferredSize,
) ResolutionStrategy.FALLBACK_RULE_CLOSEST_HIGHER_THEN_LOWER,
val resolutionSelector = ResolutionSelector.Builder() )
.setResolutionStrategy(resolutionStrategy) val resolutionSelector =
.setAspectRatioStrategy(AspectRatioStrategy.RATIO_4_3_FALLBACK_AUTO_STRATEGY) ResolutionSelector.Builder()
.build() .setResolutionStrategy(resolutionStrategy)
.setAspectRatioStrategy(AspectRatioStrategy.RATIO_4_3_FALLBACK_AUTO_STRATEGY)
.build()
ImageCapture.Builder().setResolutionSelector(resolutionSelector).build() ImageCapture.Builder().setResolutionSelector(resolutionSelector).build()
} }
@ -430,17 +442,16 @@ fun MessageInputText(
fun rebindCameraProvider() { fun rebindCameraProvider() {
cameraProvider?.let { cameraProvider -> cameraProvider?.let { cameraProvider ->
val cameraSelector = CameraSelector.Builder() val cameraSelector = CameraSelector.Builder().requireLensFacing(cameraSide).build()
.requireLensFacing(cameraSide)
.build()
try { try {
cameraProvider.unbindAll() cameraProvider.unbindAll()
val camera = cameraProvider.bindToLifecycle( val camera =
lifecycleOwner = lifecycleOwner, cameraProvider.bindToLifecycle(
cameraSelector = cameraSelector, lifecycleOwner = lifecycleOwner,
previewUseCase, cameraSelector = cameraSelector,
imageCaptureUseCase previewUseCase,
) imageCaptureUseCase,
)
cameraControl = camera.cameraControl cameraControl = camera.cameraControl
} catch (e: Exception) { } catch (e: Exception) {
Log.d(TAG, "Failed to bind camera", e) Log.d(TAG, "Failed to bind camera", e)
@ -453,15 +464,13 @@ fun MessageInputText(
rebindCameraProvider() rebindCameraProvider()
} }
LaunchedEffect(cameraSide) { LaunchedEffect(cameraSide) { rebindCameraProvider() }
rebindCameraProvider()
}
DisposableEffect(Unit) { // Or key on lifecycleOwner if it makes more sense DisposableEffect(Unit) { // Or key on lifecycleOwner if it makes more sense
onDispose { onDispose {
cameraProvider?.unbindAll() // Unbind all use cases from the camera provider cameraProvider?.unbindAll() // Unbind all use cases from the camera provider
if (!executor.isShutdown) { if (!executor.isShutdown) {
executor.shutdown() // Shut down the executor service executor.shutdown() // Shut down the executor service
} }
} }
} }
@ -485,54 +494,54 @@ fun MessageInputText(
cameraCaptureSheetState.hide() cameraCaptureSheetState.hide()
showCameraCaptureBottomSheet = false showCameraCaptureBottomSheet = false
} }
}, colors = IconButtonDefaults.iconButtonColors( },
containerColor = MaterialTheme.colorScheme.surfaceVariant, colors =
), modifier = Modifier IconButtonDefaults.iconButtonColors(
.offset(x = (-8).dp, y = 8.dp) containerColor = MaterialTheme.colorScheme.surfaceVariant
.align(Alignment.TopEnd) ),
modifier = Modifier.offset(x = (-8).dp, y = 8.dp).align(Alignment.TopEnd),
) { ) {
Icon( Icon(
Icons.Rounded.Close, Icons.Rounded.Close,
contentDescription = "", contentDescription = "",
tint = MaterialTheme.colorScheme.primary tint = MaterialTheme.colorScheme.primary,
) )
} }
// Button that triggers the image capture process // Button that triggers the image capture process
IconButton( IconButton(
colors = IconButtonDefaults.iconButtonColors( colors =
containerColor = MaterialTheme.colorScheme.primary, IconButtonDefaults.iconButtonColors(containerColor = MaterialTheme.colorScheme.primary),
), modifier =
modifier = Modifier Modifier.align(Alignment.BottomCenter)
.align(Alignment.BottomCenter) .padding(bottom = 32.dp)
.padding(bottom = 32.dp) .size(64.dp)
.size(64.dp) .border(2.dp, MaterialTheme.colorScheme.onPrimary, CircleShape),
.border(2.dp, MaterialTheme.colorScheme.onPrimary, CircleShape),
onClick = { onClick = {
val callback = object : ImageCapture.OnImageCapturedCallback() { val callback =
override fun onCaptureSuccess(image: ImageProxy) { object : ImageCapture.OnImageCapturedCallback() {
try { override fun onCaptureSuccess(image: ImageProxy) {
var bitmap = image.toBitmap() try {
val rotation = image.imageInfo.rotationDegrees var bitmap = image.toBitmap()
bitmap = if (rotation != 0) { val rotation = image.imageInfo.rotationDegrees
val matrix = Matrix().apply { bitmap =
postRotate(rotation.toFloat()) if (rotation != 0) {
val matrix = Matrix().apply { postRotate(rotation.toFloat()) }
Log.d(TAG, "image size: ${bitmap.width}, ${bitmap.height}")
Bitmap.createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, matrix, true)
} else bitmap
updatePickedImages(listOf(bitmap))
} catch (e: Exception) {
Log.e(TAG, "Failed to process image", e)
} finally {
image.close()
scope.launch {
cameraCaptureSheetState.hide()
showCameraCaptureBottomSheet = false
} }
Log.d(TAG, "image size: ${bitmap.width}, ${bitmap.height}")
Bitmap.createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, matrix, true)
} else bitmap
updatePickedImages(bitmap)
} catch (e: Exception) {
Log.e(TAG, "Failed to process image", e)
} finally {
image.close()
scope.launch {
cameraCaptureSheetState.hide()
showCameraCaptureBottomSheet = false
} }
} }
} }
}
imageCaptureUseCase.takePicture(executor, callback) imageCaptureUseCase.takePicture(executor, callback)
}, },
) { ) {
@ -540,32 +549,32 @@ fun MessageInputText(
Icons.Rounded.PhotoCamera, Icons.Rounded.PhotoCamera,
contentDescription = "", contentDescription = "",
tint = MaterialTheme.colorScheme.onPrimary, tint = MaterialTheme.colorScheme.onPrimary,
modifier = Modifier.size(36.dp) modifier = Modifier.size(36.dp),
) )
} }
// Button that toggles the front and back camera. // Button that toggles the front and back camera.
if (hasFrontCamera) { if (hasFrontCamera) {
IconButton( IconButton(
colors = IconButtonDefaults.iconButtonColors( colors =
containerColor = MaterialTheme.colorScheme.secondaryContainer, IconButtonDefaults.iconButtonColors(
), containerColor = MaterialTheme.colorScheme.secondaryContainer
modifier = Modifier ),
.align(Alignment.BottomEnd) modifier =
.padding(bottom = 40.dp, end = 32.dp) Modifier.align(Alignment.BottomEnd).padding(bottom = 40.dp, end = 32.dp).size(48.dp),
.size(48.dp),
onClick = { onClick = {
cameraSide = when (cameraSide) { cameraSide =
CameraSelector.LENS_FACING_BACK -> CameraSelector.LENS_FACING_FRONT when (cameraSide) {
else -> CameraSelector.LENS_FACING_BACK CameraSelector.LENS_FACING_BACK -> CameraSelector.LENS_FACING_FRONT
} else -> CameraSelector.LENS_FACING_BACK
}
}, },
) { ) {
Icon( Icon(
Icons.Rounded.FlipCameraAndroid, Icons.Rounded.FlipCameraAndroid,
contentDescription = "", contentDescription = "",
tint = MaterialTheme.colorScheme.onSecondaryContainer, tint = MaterialTheme.colorScheme.onSecondaryContainer,
modifier = Modifier.size(24.dp) modifier = Modifier.size(24.dp),
) )
} }
} }
@ -574,25 +583,32 @@ fun MessageInputText(
} }
} }
private fun handleImageSelected( private fun handleImagesSelected(
context: Context, context: Context,
uri: Uri, uris: List<Uri>,
onImageSelected: (Bitmap) -> Unit, onImagesSelected: (List<Bitmap>) -> Unit,
// For some reason, some Android phone would store the picture taken by the camera rotated // For some reason, some Android phone would store the picture taken by the camera rotated
// horizontally. Use this flag to rotate the image back to portrait if the picture's width // horizontally. Use this flag to rotate the image back to portrait if the picture's width
// is bigger than height. // is bigger than height.
rotateForPortrait: Boolean = false, rotateForPortrait: Boolean = false,
) { ) {
val bitmap: Bitmap? = try { val images: MutableList<Bitmap> = mutableListOf()
val inputStream = context.contentResolver.openInputStream(uri) for (uri in uris) {
val tmpBitmap = BitmapFactory.decodeStream(inputStream) val bitmap: Bitmap? =
rotateImageIfNecessary(bitmap = tmpBitmap, rotateForPortrait = rotateForPortrait) try {
} catch (e: Exception) { val inputStream = context.contentResolver.openInputStream(uri)
e.printStackTrace() val tmpBitmap = BitmapFactory.decodeStream(inputStream)
null rotateImageIfNecessary(bitmap = tmpBitmap, rotateForPortrait = rotateForPortrait)
} catch (e: Exception) {
e.printStackTrace()
null
}
if (bitmap != null) {
images.add(bitmap)
}
} }
if (bitmap != null) { if (images.isNotEmpty()) {
onImageSelected(bitmap) onImagesSelected(images)
} }
} }
@ -608,106 +624,106 @@ private fun rotateImageIfNecessary(bitmap: Bitmap, rotateForPortrait: Boolean =
private fun checkFrontCamera(context: Context, callback: (Boolean) -> Unit) { private fun checkFrontCamera(context: Context, callback: (Boolean) -> Unit) {
val cameraProviderFuture = ProcessCameraProvider.getInstance(context) val cameraProviderFuture = ProcessCameraProvider.getInstance(context)
cameraProviderFuture.addListener({ cameraProviderFuture.addListener(
val cameraProvider = cameraProviderFuture.get() {
try { val cameraProvider = cameraProviderFuture.get()
// Attempt to select the default front camera try {
val hasFront = cameraProvider.hasCamera(CameraSelector.DEFAULT_FRONT_CAMERA) // Attempt to select the default front camera
callback(hasFront) val hasFront = cameraProvider.hasCamera(CameraSelector.DEFAULT_FRONT_CAMERA)
} catch (e: Exception) { callback(hasFront)
e.printStackTrace() } catch (e: Exception) {
callback(false) e.printStackTrace()
} callback(false)
}, ContextCompat.getMainExecutor(context)) }
},
ContextCompat.getMainExecutor(context),
)
} }
private fun createMessagesToSend(pickedImages: List<Bitmap>, text: String): List<ChatMessage> { private fun createMessagesToSend(pickedImages: List<Bitmap>, text: String): List<ChatMessage> {
val messages: MutableList<ChatMessage> = mutableListOf() var messages: MutableList<ChatMessage> = mutableListOf()
if (pickedImages.isNotEmpty()) { if (pickedImages.isNotEmpty()) {
val lastImage = pickedImages.last() for (image in pickedImages) {
messages.add( messages.add(
ChatMessageImage( ChatMessageImage(bitmap = image, imageBitMap = image.asImageBitmap(), side = ChatSide.USER)
bitmap = lastImage, imageBitMap = lastImage.asImageBitmap(), side = ChatSide.USER
) )
) }
} }
messages.add( // Cap the number of image messages.
ChatMessageText( if (messages.size > MAX_IMAGE_COUNT) {
content = text, side = ChatSide.USER messages = messages.subList(fromIndex = 0, toIndex = MAX_IMAGE_COUNT)
) }
) messages.add(ChatMessageText(content = text, side = ChatSide.USER))
return messages return messages
} }
@Preview(showBackground = true) // @Preview(showBackground = true)
@Composable // @Composable
fun MessageInputTextPreview() { // fun MessageInputTextPreview() {
val context = LocalContext.current // val context = LocalContext.current
GalleryTheme {
Column {
MessageInputText(
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "hello",
inProgress = false,
isResettingSession = false,
modelInitializing = false,
hasImageMessage = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},
onSendMessage = {},
showStopButtonWhenInProgress = true,
showImagePickerInMenu = true,
)
MessageInputText(
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "hello",
inProgress = false,
isResettingSession = false,
hasImageMessage = false,
modelInitializing = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},
onSendMessage = {},
showStopButtonWhenInProgress = true,
)
MessageInputText(
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "hello",
inProgress = true,
isResettingSession = false,
hasImageMessage = false,
modelInitializing = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},
onSendMessage = {},
)
MessageInputText(
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "",
inProgress = false,
isResettingSession = false,
hasImageMessage = false,
modelInitializing = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},
onSendMessage = {},
)
MessageInputText(
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "",
inProgress = true,
isResettingSession = false,
hasImageMessage = false,
modelInitializing = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},
onSendMessage = {},
showStopButtonWhenInProgress = true,
)
}
}
}
// GalleryTheme {
// Column {
// MessageInputText(
// modelManagerViewModel = PreviewModelManagerViewModel(context = context),
// curMessage = "hello",
// inProgress = false,
// isResettingSession = false,
// modelInitializing = false,
// hasImageMessage = false,
// textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
// onValueChanged = {},
// onSendMessage = {},
// showStopButtonWhenInProgress = true,
// showImagePickerInMenu = true,
// )
// MessageInputText(
// modelManagerViewModel = PreviewModelManagerViewModel(context = context),
// curMessage = "hello",
// inProgress = false,
// isResettingSession = false,
// hasImageMessage = false,
// modelInitializing = false,
// textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
// onValueChanged = {},
// onSendMessage = {},
// showStopButtonWhenInProgress = true,
// )
// MessageInputText(
// modelManagerViewModel = PreviewModelManagerViewModel(context = context),
// curMessage = "hello",
// inProgress = true,
// isResettingSession = false,
// hasImageMessage = false,
// modelInitializing = false,
// textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
// onValueChanged = {},
// onSendMessage = {},
// )
// MessageInputText(
// modelManagerViewModel = PreviewModelManagerViewModel(context = context),
// curMessage = "",
// inProgress = false,
// isResettingSession = false,
// hasImageMessage = false,
// modelInitializing = false,
// textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
// onValueChanged = {},
// onSendMessage = {},
// )
// MessageInputText(
// modelManagerViewModel = PreviewModelManagerViewModel(context = context),
// curMessage = "",
// inProgress = true,
// isResettingSession = false,
// hasImageMessage = false,
// modelInitializing = false,
// textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
// onValueChanged = {},
// onSendMessage = {},
// showStopButtonWhenInProgress = true,
// )
// }
// }
// }

View file

@ -16,22 +16,17 @@
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common.chat
import androidx.compose.foundation.layout.Arrangement // import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.foundation.layout.Column // import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.layout.padding
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text import androidx.compose.material3.Text
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha import androidx.compose.ui.draw.alpha
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.ui.common.humanReadableDuration import com.google.ai.edge.gallery.ui.common.humanReadableDuration
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
/** /** Composable function to display the latency of a chat message, if available. */
* Composable function to display the latency of a chat message, if available.
*/
@Composable @Composable
fun LatencyText(message: ChatMessage) { fun LatencyText(message: ChatMessage) {
if (message.latencyMs >= 0) { if (message.latencyMs >= 0) {
@ -43,21 +38,19 @@ fun LatencyText(message: ChatMessage) {
} }
} }
// @Preview(showBackground = true)
@Preview(showBackground = true) // @Composable
@Composable // fun LatencyTextPreview() {
fun LatencyTextPreview() { // GalleryTheme {
GalleryTheme { // Column(modifier = Modifier.padding(16.dp), verticalArrangement = Arrangement.spacedBy(16.dp))
Column(modifier = Modifier.padding(16.dp), verticalArrangement = Arrangement.spacedBy(16.dp)) { // {
for (latencyMs in listOf(123f, 1234f, 123456f, 7234567f)) { // for (latencyMs in listOf(123f, 1234f, 123456f, 7234567f)) {
LatencyText( // LatencyText(
message = ChatMessage( // message =
latencyMs = latencyMs, // ChatMessage(latencyMs = latencyMs, type = ChatMessageType.TEXT, side =
type = ChatMessageType.TEXT, // ChatSide.AGENT)
side = ChatSide.AGENT // )
) // }
) // }
} // }
} // }
}
}

View file

@ -16,9 +16,10 @@
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common.chat
import android.graphics.Bitmap // import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.fillMaxWidth
@ -32,40 +33,37 @@ import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.res.stringResource import androidx.compose.ui.res.stringResource
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.R import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.bodySmallNarrow import com.google.ai.edge.gallery.ui.theme.bodySmallNarrow
data class MessageLayoutConfig( data class MessageLayoutConfig(
val horizontalArrangement: Arrangement.Horizontal, val horizontalArrangement: Arrangement.Horizontal,
val modifier: Modifier, val modifier: Modifier,
val userLabel: String, val userLabel: String,
val rightSideLabel: String val rightSideLabel: String,
) )
/** /**
* Composable function to display the sender information for a chat message. * Composable function to display the sender information for a chat message.
* *
* This function handles different types of chat messages, including system messages, * This function handles different types of chat messages, including system messages, benchmark
* benchmark results, and image generation results, and displays the appropriate sender label * results, and image generation results, and displays the appropriate sender label and status
* and status information. * information.
*/ */
@Composable @Composable
fun MessageSender( fun MessageSender(message: ChatMessage, agentName: String = "", imageHistoryCurIndex: Int = 0) {
message: ChatMessage,
agentName: String = "",
imageHistoryCurIndex: Int = 0
) {
// No user label for system messages. // No user label for system messages.
if (message.side == ChatSide.SYSTEM) { if (message.side == ChatSide.SYSTEM) {
return return
} }
val (horizontalArrangement, modifier, userLabel, rightSideLabel) = getMessageLayoutConfig( val (horizontalArrangement, modifier, userLabel, rightSideLabel) =
message = message, agentName = agentName, imageHistoryCurIndex = imageHistoryCurIndex getMessageLayoutConfig(
) message = message,
agentName = agentName,
imageHistoryCurIndex = imageHistoryCurIndex,
)
Row( Row(
modifier = modifier, modifier = modifier,
@ -74,10 +72,7 @@ fun MessageSender(
) { ) {
Row(verticalAlignment = Alignment.CenterVertically) { Row(verticalAlignment = Alignment.CenterVertically) {
// Sender label. // Sender label.
Text( Text(userLabel, style = MaterialTheme.typography.titleSmall)
userLabel,
style = MaterialTheme.typography.titleSmall,
)
when (message) { when (message) {
// Benchmark running status. // Benchmark running status.
@ -87,21 +82,18 @@ fun MessageSender(
CircularProgressIndicator( CircularProgressIndicator(
modifier = Modifier.size(10.dp), modifier = Modifier.size(10.dp),
strokeWidth = 1.5.dp, strokeWidth = 1.5.dp,
color = MaterialTheme.colorScheme.secondary color = MaterialTheme.colorScheme.secondary,
) )
Spacer(modifier = Modifier.width(4.dp)) Spacer(modifier = Modifier.width(4.dp))
} }
val statusLabel = if (message.isWarmingUp()) { val statusLabel =
stringResource(R.string.warming_up) if (message.isWarmingUp()) {
} else if (message.isRunning()) { stringResource(R.string.warming_up)
stringResource(R.string.running) } else if (message.isRunning()) {
} else "" stringResource(R.string.running)
} else ""
if (statusLabel.isNotEmpty()) { if (statusLabel.isNotEmpty()) {
Text( Text(statusLabel, color = MaterialTheme.colorScheme.secondary, style = bodySmallNarrow)
statusLabel,
color = MaterialTheme.colorScheme.secondary,
style = bodySmallNarrow,
)
} }
} }
@ -112,7 +104,7 @@ fun MessageSender(
CircularProgressIndicator( CircularProgressIndicator(
modifier = Modifier.size(10.dp), modifier = Modifier.size(10.dp),
strokeWidth = 1.5.dp, strokeWidth = 1.5.dp,
color = MaterialTheme.colorScheme.secondary color = MaterialTheme.colorScheme.secondary,
) )
} }
} }
@ -124,7 +116,7 @@ fun MessageSender(
CircularProgressIndicator( CircularProgressIndicator(
modifier = Modifier.size(10.dp), modifier = Modifier.size(10.dp),
strokeWidth = 1.5.dp, strokeWidth = 1.5.dp,
color = MaterialTheme.colorScheme.secondary color = MaterialTheme.colorScheme.secondary,
) )
Spacer(modifier = Modifier.width(4.dp)) Spacer(modifier = Modifier.width(4.dp))
Text( Text(
@ -141,8 +133,7 @@ fun MessageSender(
when (message) { when (message) {
is ChatMessageBenchmarkResult, is ChatMessageBenchmarkResult,
is ChatMessageImageWithHistory, is ChatMessageImageWithHistory,
is ChatMessageBenchmarkLlmResult, is ChatMessageBenchmarkLlmResult -> {
-> {
Text(rightSideLabel, style = MaterialTheme.typography.bodySmall) Text(rightSideLabel, style = MaterialTheme.typography.bodySmall)
} }
} }
@ -169,11 +160,12 @@ private fun getMessageLayoutConfig(
horizontalArrangement = Arrangement.SpaceBetween horizontalArrangement = Arrangement.SpaceBetween
modifier = modifier.fillMaxWidth() modifier = modifier.fillMaxWidth()
userLabel = "Benchmark" userLabel = "Benchmark"
rightSideLabel = if (message.isWarmingUp()) { rightSideLabel =
"${message.warmupCurrent}/${message.warmupTotal}" if (message.isWarmingUp()) {
} else { "${message.warmupCurrent}/${message.warmupTotal}"
"${message.iterationCurrent}/${message.iterationTotal}" } else {
} "${message.iterationCurrent}/${message.iterationTotal}"
}
} }
is ChatMessageBenchmarkLlmResult -> { is ChatMessageBenchmarkLlmResult -> {
@ -198,64 +190,68 @@ private fun getMessageLayoutConfig(
horizontalArrangement = horizontalArrangement, horizontalArrangement = horizontalArrangement,
modifier = modifier, modifier = modifier,
userLabel = userLabel, userLabel = userLabel,
rightSideLabel = rightSideLabel rightSideLabel = rightSideLabel,
) )
} }
@Preview(showBackground = true) // @Preview(showBackground = true)
@Composable // @Composable
fun MessageSenderPreview() { // fun MessageSenderPreview() {
GalleryTheme { // GalleryTheme {
Column(modifier = Modifier.padding(16.dp), verticalArrangement = Arrangement.spacedBy(16.dp)) { // Column(modifier = Modifier.padding(16.dp), verticalArrangement = Arrangement.spacedBy(16.dp))
// Agent message. // {
MessageSender( // // Agent message.
message = ChatMessageText(content = "hello world", side = ChatSide.AGENT), // MessageSender(
agentName = stringResource(R.string.chat_generic_agent_name) // message = ChatMessageText(content = "hello world", side = ChatSide.AGENT),
) // agentName = stringResource(R.string.chat_generic_agent_name),
// User message. // )
MessageSender( // // User message.
message = ChatMessageText(content = "hello world", side = ChatSide.USER), // MessageSender(
agentName = stringResource(R.string.chat_generic_agent_name) // message = ChatMessageText(content = "hello world", side = ChatSide.USER),
) // agentName = stringResource(R.string.chat_generic_agent_name),
// Benchmark during warmup. // )
MessageSender( // // Benchmark during warmup.
message = ChatMessageBenchmarkResult( // MessageSender(
orderedStats = listOf(), // message =
statValues = mutableMapOf(), // ChatMessageBenchmarkResult(
values = listOf(), // orderedStats = listOf(),
histogram = Histogram(listOf(), 0), // statValues = mutableMapOf(),
warmupCurrent = 10, // values = listOf(),
warmupTotal = 50, // histogram = Histogram(listOf(), 0),
iterationCurrent = 0, // warmupCurrent = 10,
iterationTotal = 200 // warmupTotal = 50,
), // iterationCurrent = 0,
agentName = stringResource(R.string.chat_generic_agent_name) // iterationTotal = 200,
) // ),
// Benchmark during running. // agentName = stringResource(R.string.chat_generic_agent_name),
MessageSender( // )
message = ChatMessageBenchmarkResult( // // Benchmark during running.
orderedStats = listOf(), // MessageSender(
statValues = mutableMapOf(), // message =
values = listOf(), // ChatMessageBenchmarkResult(
histogram = Histogram(listOf(), 0), // orderedStats = listOf(),
warmupCurrent = 50, // statValues = mutableMapOf(),
warmupTotal = 50, // values = listOf(),
iterationCurrent = 123, // histogram = Histogram(listOf(), 0),
iterationTotal = 200 // warmupCurrent = 50,
), // warmupTotal = 50,
agentName = stringResource(R.string.chat_generic_agent_name) // iterationCurrent = 123,
) // iterationTotal = 200,
// Image generation during running. // ),
MessageSender( // agentName = stringResource(R.string.chat_generic_agent_name),
message = ChatMessageImageWithHistory( // )
bitmaps = listOf(Bitmap.createBitmap(100, 100, Bitmap.Config.ARGB_8888)), // // Image generation during running.
imageBitMaps = listOf(), // MessageSender(
totalIterations = 10, // message =
ChatSide.AGENT // ChatMessageImageWithHistory(
), // bitmaps = listOf(Bitmap.createBitmap(100, 100, Bitmap.Config.ARGB_8888)),
agentName = stringResource(R.string.chat_generic_agent_name), // imageBitMaps = listOf(),
imageHistoryCurIndex = 4, // totalIterations = 10,
) // ChatSide.AGENT,
} // ),
} // agentName = stringResource(R.string.chat_generic_agent_name),
} // imageHistoryCurIndex = 4,
// )
// }
// }
// }

View file

@ -45,7 +45,7 @@ import kotlinx.coroutines.delay
fun ModelDownloadStatusInfoPanel( fun ModelDownloadStatusInfoPanel(
model: Model, model: Model,
task: Task, task: Task,
modelManagerViewModel: ModelManagerViewModel modelManagerViewModel: ModelManagerViewModel,
) { ) {
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
@ -62,9 +62,12 @@ fun ModelDownloadStatusInfoPanel(
var downloadModelButtonConditionMet by remember { mutableStateOf(false) } var downloadModelButtonConditionMet by remember { mutableStateOf(false) }
downloadingAnimationConditionMet = downloadingAnimationConditionMet =
curStatus?.status == ModelDownloadStatusType.IN_PROGRESS || curStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED || curStatus?.status == ModelDownloadStatusType.UNZIPPING curStatus?.status == ModelDownloadStatusType.IN_PROGRESS ||
curStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED ||
curStatus?.status == ModelDownloadStatusType.UNZIPPING
downloadModelButtonConditionMet = downloadModelButtonConditionMet =
curStatus?.status == ModelDownloadStatusType.FAILED || curStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED curStatus?.status == ModelDownloadStatusType.FAILED ||
curStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED
LaunchedEffect(downloadingAnimationConditionMet) { LaunchedEffect(downloadingAnimationConditionMet) {
if (downloadingAnimationConditionMet) { if (downloadingAnimationConditionMet) {
@ -87,24 +90,22 @@ fun ModelDownloadStatusInfoPanel(
AnimatedVisibility( AnimatedVisibility(
visible = shouldShowDownloadingAnimation, visible = shouldShowDownloadingAnimation,
enter = scaleIn(initialScale = 0.9f) + fadeIn(), enter = scaleIn(initialScale = 0.9f) + fadeIn(),
exit = scaleOut(targetScale = 0.9f) + fadeOut() exit = scaleOut(targetScale = 0.9f) + fadeOut(),
) { ) {
Box( Box(modifier = Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
modifier = Modifier.fillMaxSize(), contentAlignment = Alignment.Center
) {
ModelDownloadingAnimation( ModelDownloadingAnimation(
model = model, task = task, modelManagerViewModel = modelManagerViewModel model = model,
task = task,
modelManagerViewModel = modelManagerViewModel,
) )
} }
} }
AnimatedVisibility( AnimatedVisibility(visible = shouldShowDownloadModelButton, enter = fadeIn(), exit = fadeOut()) {
visible = shouldShowDownloadModelButton, enter = fadeIn(), exit = fadeOut()
) {
Column( Column(
modifier = Modifier.fillMaxSize(), modifier = Modifier.fillMaxSize(),
verticalArrangement = Arrangement.Center, verticalArrangement = Arrangement.Center,
horizontalAlignment = Alignment.CenterHorizontally horizontalAlignment = Alignment.CenterHorizontally,
) { ) {
DownloadAndTryButton( DownloadAndTryButton(
task = task, task = task,
@ -112,8 +113,8 @@ fun ModelDownloadStatusInfoPanel(
enabled = true, enabled = true,
needToDownloadFirst = true, needToDownloadFirst = true,
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,
onClicked = {} onClicked = {},
) )
} }
} }
} }

View file

@ -16,6 +16,11 @@
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common.chat
// import com.google.ai.edge.gallery.ui.preview.MODEL_TEST1
// import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
// import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.animation.core.Animatable import androidx.compose.animation.core.Animatable
import androidx.compose.animation.core.Easing import androidx.compose.animation.core.Easing
import androidx.compose.animation.core.tween import androidx.compose.animation.core.tween
@ -47,13 +52,10 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.ColorFilter import androidx.compose.ui.graphics.ColorFilter
import androidx.compose.ui.graphics.graphicsLayer import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.layout.ContentScale import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.painterResource import androidx.compose.ui.res.painterResource
import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import com.google.ai.edge.gallery.R import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.ModelDownloadStatusType import com.google.ai.edge.gallery.data.ModelDownloadStatusType
@ -62,14 +64,10 @@ import com.google.ai.edge.gallery.ui.common.formatToHourMinSecond
import com.google.ai.edge.gallery.ui.common.getTaskIconColor import com.google.ai.edge.gallery.ui.common.getTaskIconColor
import com.google.ai.edge.gallery.ui.common.humanReadableSize import com.google.ai.edge.gallery.ui.common.humanReadableSize
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.MODEL_TEST1
import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.labelSmallNarrow import com.google.ai.edge.gallery.ui.theme.labelSmallNarrow
import kotlinx.coroutines.delay
import kotlin.math.cos import kotlin.math.cos
import kotlin.math.pow import kotlin.math.pow
import kotlinx.coroutines.delay
private val GRID_SIZE = 240.dp private val GRID_SIZE = 240.dp
private val GRID_SPACING = 0.dp private val GRID_SPACING = 0.dp
@ -78,7 +76,6 @@ private const val ANIMATION_DURATION = 500
private const val START_SCALE = 0.9f private const val START_SCALE = 0.9f
private const val END_SCALE = 0.6f private const val END_SCALE = 0.6f
/** /**
* Composable function to display a loading animation using a 2x2 grid of images with a synchronized * Composable function to display a loading animation using a 2x2 grid of images with a synchronized
* scaling and rotation effect. * scaling and rotation effect.
@ -87,7 +84,7 @@ private const val END_SCALE = 0.6f
fun ModelDownloadingAnimation( fun ModelDownloadingAnimation(
model: Model, model: Model,
task: Task, task: Task,
modelManagerViewModel: ModelManagerViewModel modelManagerViewModel: ModelManagerViewModel,
) { ) {
val scale = remember { Animatable(END_SCALE) } val scale = remember { Animatable(END_SCALE) }
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
@ -103,26 +100,27 @@ fun ModelDownloadingAnimation(
// Phase 1: Scale up // Phase 1: Scale up
scale.animateTo( scale.animateTo(
targetValue = START_SCALE, targetValue = START_SCALE,
animationSpec = tween( animationSpec =
durationMillis = ANIMATION_DURATION, tween(
easing = multiBounceEasing(bounces = 3, decay = 0.02f) durationMillis = ANIMATION_DURATION,
) easing = multiBounceEasing(bounces = 3, decay = 0.02f),
),
) )
delay(PAUSE_DURATION.toLong()) delay(PAUSE_DURATION.toLong())
// Phase 2: Scale down // Phase 2: Scale down
scale.animateTo( scale.animateTo(
targetValue = END_SCALE, targetValue = END_SCALE,
animationSpec = tween( animationSpec =
durationMillis = ANIMATION_DURATION, tween(
easing = multiBounceEasing(bounces = 3, decay = 0.02f) durationMillis = ANIMATION_DURATION,
) easing = multiBounceEasing(bounces = 3, decay = 0.02f),
),
) )
delay(PAUSE_DURATION.toLong()) delay(PAUSE_DURATION.toLong())
} }
} }
// Failure message. // Failure message.
val curDownloadStatus = downloadStatus val curDownloadStatus = downloadStatus
if (curDownloadStatus != null && curDownloadStatus.status == ModelDownloadStatusType.FAILED) { if (curDownloadStatus != null && curDownloadStatus.status == ModelDownloadStatusType.FAILED) {
@ -139,58 +137,55 @@ fun ModelDownloadingAnimation(
else { else {
Column( Column(
horizontalAlignment = Alignment.CenterHorizontally, horizontalAlignment = Alignment.CenterHorizontally,
modifier = Modifier.offset(y = -GRID_SIZE / 8) modifier = Modifier.offset(y = -GRID_SIZE / 8),
) { ) {
LazyVerticalGrid( LazyVerticalGrid(
columns = GridCells.Fixed(2), columns = GridCells.Fixed(2),
horizontalArrangement = Arrangement.spacedBy(GRID_SPACING), horizontalArrangement = Arrangement.spacedBy(GRID_SPACING),
verticalArrangement = Arrangement.spacedBy(GRID_SPACING), verticalArrangement = Arrangement.spacedBy(GRID_SPACING),
modifier = Modifier modifier = Modifier.width(GRID_SIZE).height(GRID_SIZE),
.width(GRID_SIZE)
.height(GRID_SIZE)
) { ) {
itemsIndexed( itemsIndexed(
listOf( listOf(
R.drawable.pantegon, R.drawable.pantegon,
R.drawable.double_circle, R.drawable.double_circle,
R.drawable.circle, R.drawable.circle,
R.drawable.four_circle R.drawable.four_circle,
) )
) { index, imageResource -> ) { index, imageResource ->
val currentScale = val currentScale =
if (index == 0 || index == 3) scale.value else START_SCALE + END_SCALE - scale.value if (index == 0 || index == 3) scale.value else START_SCALE + END_SCALE - scale.value
Box( Box(
modifier = Modifier modifier =
.width((GRID_SIZE - GRID_SPACING) / 2) Modifier.width((GRID_SIZE - GRID_SPACING) / 2).height((GRID_SIZE - GRID_SPACING) / 2),
.height((GRID_SIZE - GRID_SPACING) / 2), contentAlignment =
contentAlignment = when (index) { when (index) {
0 -> Alignment.BottomEnd 0 -> Alignment.BottomEnd
1 -> Alignment.BottomStart 1 -> Alignment.BottomStart
2 -> Alignment.TopEnd 2 -> Alignment.TopEnd
3 -> Alignment.TopStart 3 -> Alignment.TopStart
else -> Alignment.Center else -> Alignment.Center
} },
) { ) {
Image( Image(
painter = painterResource(id = imageResource), painter = painterResource(id = imageResource),
contentDescription = "", contentDescription = "",
contentScale = ContentScale.Fit, contentScale = ContentScale.Fit,
colorFilter = ColorFilter.tint(getTaskIconColor(index = index)), colorFilter = ColorFilter.tint(getTaskIconColor(index = index)),
modifier = Modifier modifier =
.graphicsLayer { Modifier.graphicsLayer {
scaleX = currentScale scaleX = currentScale
scaleY = currentScale scaleY = currentScale
rotationZ = currentScale * 120 rotationZ = currentScale * 120
alpha = 0.8f alpha = 0.8f
} }
.size(70.dp) .size(70.dp),
) )
} }
} }
} }
// Download stats // Download stats
var sizeLabel = model.totalBytes.humanReadableSize() var sizeLabel = model.totalBytes.humanReadableSize()
if (curDownloadStatus != null) { if (curDownloadStatus != null) {
@ -203,8 +198,7 @@ fun ModelDownloadingAnimation(
sizeLabel = sizeLabel =
"${curDownloadStatus.receivedBytes.humanReadableSize(extraDecimalForGbAndAbove = true)} of ${totalSize.humanReadableSize()}" "${curDownloadStatus.receivedBytes.humanReadableSize(extraDecimalForGbAndAbove = true)} of ${totalSize.humanReadableSize()}"
if (curDownloadStatus.bytesPerSecond > 0) { if (curDownloadStatus.bytesPerSecond > 0) {
sizeLabel = sizeLabel = "$sizeLabel · ${curDownloadStatus.bytesPerSecond.humanReadableSize()} / s"
"$sizeLabel · ${curDownloadStatus.bytesPerSecond.humanReadableSize()} / s"
if (curDownloadStatus.remainingMs >= 0) { if (curDownloadStatus.remainingMs >= 0) {
sizeLabel = sizeLabel =
"$sizeLabel · ${curDownloadStatus.remainingMs.formatToHourMinSecond()} left" "$sizeLabel · ${curDownloadStatus.remainingMs.formatToHourMinSecond()} left"
@ -229,8 +223,7 @@ fun ModelDownloadingAnimation(
style = MaterialTheme.typography.labelMedium, style = MaterialTheme.typography.labelMedium,
textAlign = TextAlign.Center, textAlign = TextAlign.Center,
overflow = TextOverflow.Visible, overflow = TextOverflow.Visible,
modifier = Modifier modifier = Modifier.padding(bottom = 4.dp),
.padding(bottom = 4.dp)
) )
} }
@ -241,10 +234,7 @@ fun ModelDownloadingAnimation(
progress = { animatedProgress.value }, progress = { animatedProgress.value },
color = getTaskIconColor(task = task), color = getTaskIconColor(task = task),
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest, trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
modifier = Modifier modifier = Modifier.fillMaxWidth().padding(bottom = 36.dp).padding(horizontal = 36.dp),
.fillMaxWidth()
.padding(bottom = 36.dp)
.padding(horizontal = 36.dp)
) )
LaunchedEffect(curDownloadProgress) { LaunchedEffect(curDownloadProgress) {
animatedProgress.animateTo(curDownloadProgress, animationSpec = tween(150)) animatedProgress.animateTo(curDownloadProgress, animationSpec = tween(150))
@ -255,23 +245,19 @@ fun ModelDownloadingAnimation(
LinearProgressIndicator( LinearProgressIndicator(
color = getTaskIconColor(task = task), color = getTaskIconColor(task = task),
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest, trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
modifier = Modifier modifier = Modifier.fillMaxWidth().padding(bottom = 36.dp).padding(horizontal = 36.dp),
.fillMaxWidth()
.padding(bottom = 36.dp)
.padding(horizontal = 36.dp)
) )
} }
Text( Text(
"Feel free to switch apps or lock your device.\n" "Feel free to switch apps or lock your device.\n" +
+ "The download will continue in the background.\n" "The download will continue in the background.\n" +
+ "We'll send a notification when it's done.", "We'll send a notification when it's done.",
style = MaterialTheme.typography.bodyLarge, style = MaterialTheme.typography.bodyLarge,
textAlign = TextAlign.Center textAlign = TextAlign.Center,
) )
} }
} }
} }
// Custom Easing function for a multi-bounce effect // Custom Easing function for a multi-bounce effect
@ -283,18 +269,18 @@ fun multiBounceEasing(bounces: Int, decay: Float): Easing = Easing { x ->
} }
} }
@Preview(showBackground = true) // @Preview(showBackground = true)
@Composable // @Composable
fun ModelDownloadingAnimationPreview() { // fun ModelDownloadingAnimationPreview() {
val context = LocalContext.current // val context = LocalContext.current
GalleryTheme { // GalleryTheme {
Row(modifier = Modifier.padding(16.dp)) { // Row(modifier = Modifier.padding(16.dp)) {
ModelDownloadingAnimation( // ModelDownloadingAnimation(
model = MODEL_TEST1, // model = MODEL_TEST1,
task = TASK_TEST1, // task = TASK_TEST1,
modelManagerViewModel = PreviewModelManagerViewModel(context = context) // modelManagerViewModel = PreviewModelManagerViewModel(context = context),
) // )
} // }
} // }
} // }

View file

@ -16,6 +16,9 @@
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.background import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
@ -34,37 +37,35 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip import androidx.compose.ui.draw.clip
import androidx.compose.ui.res.stringResource import androidx.compose.ui.res.stringResource
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.R import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
/** /**
* Composable function to display a visual indicator for model initialization status. * Composable function to display a visual indicator for model initialization status.
* *
* This function renders a row containing a circular progress indicator and a message * This function renders a row containing a circular progress indicator and a message indicating
* indicating that the model is currently initializing. It provides a visual cue to the * that the model is currently initializing. It provides a visual cue to the user that the model is
* user that the model is in a loading state. * in a loading state.
*/ */
@Composable @Composable
fun ModelInitializationStatusChip() { fun ModelInitializationStatusChip() {
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center) { Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center) {
Box( Box(
modifier = Modifier modifier =
.padding(8.dp) Modifier.padding(8.dp)
.clip(CircleShape) .clip(CircleShape)
.background(MaterialTheme.colorScheme.secondaryContainer) .background(MaterialTheme.colorScheme.secondaryContainer)
) { ) {
Row( Row(
modifier = Modifier.padding(top = 4.dp, bottom = 4.dp, start = 8.dp, end = 8.dp), modifier = Modifier.padding(top = 4.dp, bottom = 4.dp, start = 8.dp, end = 8.dp),
horizontalArrangement = Arrangement.Center, horizontalArrangement = Arrangement.Center,
verticalAlignment = Alignment.CenterVertically verticalAlignment = Alignment.CenterVertically,
) { ) {
// Circular progress indicator. // Circular progress indicator.
CircularProgressIndicator( CircularProgressIndicator(
modifier = Modifier.size(14.dp), modifier = Modifier.size(14.dp),
strokeWidth = 2.dp, strokeWidth = 2.dp,
color = MaterialTheme.colorScheme.onSecondaryContainer color = MaterialTheme.colorScheme.onSecondaryContainer,
) )
Spacer(modifier = Modifier.width(8.dp)) Spacer(modifier = Modifier.width(8.dp))
@ -80,10 +81,8 @@ fun ModelInitializationStatusChip() {
} }
} }
@Preview(showBackground = true) // @Preview(showBackground = true)
@Composable // @Composable
fun ModelInitializationStatusPreview() { // fun ModelInitializationStatusPreview() {
GalleryTheme { // GalleryTheme { ModelInitializationStatusChip() }
ModelInitializationStatusChip() // }
}
}

View file

@ -16,6 +16,9 @@
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxSize
@ -24,8 +27,6 @@ import androidx.compose.material3.Text
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.tooling.preview.Preview
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
/** /**
* Composable function to display a button to download model if the model has not been downloaded. * Composable function to display a button to download model if the model has not been downloaded.
@ -35,20 +36,14 @@ fun ModelNotDownloaded(modifier: Modifier = Modifier, onClicked: () -> Unit) {
Column( Column(
modifier = modifier.fillMaxSize(), modifier = modifier.fillMaxSize(),
verticalArrangement = Arrangement.Center, verticalArrangement = Arrangement.Center,
horizontalAlignment = Alignment.CenterHorizontally horizontalAlignment = Alignment.CenterHorizontally,
) { ) {
Button( Button(onClick = onClicked) { Text("Download & Try it", maxLines = 1) }
onClick = onClicked,
) {
Text("Download & Try it", maxLines = 1)
}
} }
} }
@Preview(showBackground = true) // @Preview(showBackground = true)
@Composable // @Composable
fun Preview() { // fun Preview() {
GalleryTheme { // GalleryTheme { ModelNotDownloaded(onClicked = {}) }
ModelNotDownloaded(onClicked = {}) // }
}
}

View file

@ -16,7 +16,12 @@
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common.chat
import androidx.compose.foundation.layout.Arrangement // import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
// import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
// import com.google.ai.edge.gallery.ui.preview.TASK_TEST2
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Row
@ -31,17 +36,13 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.graphicsLayer import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.Task import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.common.convertValueToTargetType import com.google.ai.edge.gallery.data.convertValueToTargetType
import com.google.ai.edge.gallery.ui.common.ConfigDialog
import com.google.ai.edge.gallery.ui.common.modelitem.ModelItem import com.google.ai.edge.gallery.ui.common.modelitem.ModelItem
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
import com.google.ai.edge.gallery.ui.preview.TASK_TEST2
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
/** /**
* Composable function to display a selectable model item with an option to configure its settings. * Composable function to display a selectable model item with an option to configure its settings.
@ -53,39 +54,31 @@ fun ModelSelector(
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
contentAlpha: Float = 1f, contentAlpha: Float = 1f,
onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit = { _, _ -> }, onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit =
{ _, _ ->
},
) { ) {
var showConfigDialog by remember { mutableStateOf(false) } var showConfigDialog by remember { mutableStateOf(false) }
val context = LocalContext.current val context = LocalContext.current
Column( Column(modifier = modifier) {
modifier = modifier
) {
Box( Box(
modifier = Modifier modifier = Modifier.fillMaxWidth().padding(bottom = 8.dp),
.fillMaxWidth() contentAlignment = Alignment.Center,
.padding(bottom = 8.dp),
contentAlignment = Alignment.Center
) { ) {
// Model row. // Model row.
Row( Row(
modifier = Modifier modifier = Modifier.fillMaxWidth().graphicsLayer { alpha = contentAlpha },
.fillMaxWidth() verticalAlignment = Alignment.CenterVertically,
.graphicsLayer { alpha = contentAlpha },
verticalAlignment = Alignment.CenterVertically
) { ) {
ModelItem( ModelItem(
model = model, model = model,
task = task, task = task,
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,
onModelClicked = {}, onModelClicked = {},
onConfigClicked = { onConfigClicked = { showConfigDialog = true },
showConfigDialog = true
},
verticalSpacing = 10.dp, verticalSpacing = 10.dp,
modifier = Modifier modifier = Modifier.weight(1f).padding(horizontal = 16.dp),
.weight(1f)
.padding(horizontal = 16.dp),
showDeleteButton = false, showDeleteButton = false,
showConfigButtonIfExisted = true, showConfigButtonIfExisted = true,
canExpand = false, canExpand = false,
@ -111,12 +104,16 @@ fun ModelSelector(
var needReinitialization = false var needReinitialization = false
for (config in model.configs) { for (config in model.configs) {
val key = config.key.label val key = config.key.label
val oldValue = convertValueToTargetType( val oldValue =
value = model.configValues.getValue(key), valueType = config.valueType convertValueToTargetType(
) value = model.configValues.getValue(key),
val newValue = convertValueToTargetType( valueType = config.valueType,
value = curConfigValues.getValue(key), valueType = config.valueType )
) val newValue =
convertValueToTargetType(
value = curConfigValues.getValue(key),
valueType = config.valueType,
)
if (oldValue != newValue) { if (oldValue != newValue) {
same = false same = false
if (config.needReinitialization) { if (config.needReinitialization) {
@ -139,7 +136,7 @@ fun ModelSelector(
context = context, context = context,
task = task, task = task,
model = model, model = model,
force = true force = true,
) )
} }
@ -150,27 +147,26 @@ fun ModelSelector(
} }
} }
@Preview(showBackground = true) // @Preview(showBackground = true)
@Composable // @Composable
fun ModelSelectorPreview( // fun ModelSelectorPreview() {
) { // GalleryTheme {
GalleryTheme { // Column(verticalArrangement = Arrangement.spacedBy(16.dp)) {
Column(verticalArrangement = Arrangement.spacedBy(16.dp)) { // ModelSelector(
ModelSelector( // model = TASK_TEST1.models[0],
model = TASK_TEST1.models[0], // task = TASK_TEST1,
task = TASK_TEST1, // modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current), // )
) // ModelSelector(
ModelSelector( // model = TASK_TEST1.models[1],
model = TASK_TEST1.models[1], // task = TASK_TEST1,
task = TASK_TEST1, // modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current), // )
) // ModelSelector(
ModelSelector( // model = TASK_TEST2.models[1],
model = TASK_TEST2.models[1], // task = TASK_TEST2,
task = TASK_TEST2, // modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current), // )
) // }
} // }
} // }
}

View file

@ -16,6 +16,8 @@
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.background import androidx.compose.foundation.background
import androidx.compose.foundation.clickable import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
@ -53,10 +55,8 @@ import androidx.compose.ui.draw.clip
import androidx.compose.ui.res.stringResource import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.R import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.customColors import com.google.ai.edge.gallery.ui.theme.customColors
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
@ -68,7 +68,7 @@ fun TextInputHistorySheet(
onHistoryItemClicked: (String) -> Unit, onHistoryItemClicked: (String) -> Unit,
onHistoryItemDeleted: (String) -> Unit, onHistoryItemDeleted: (String) -> Unit,
onHistoryItemsDeleteAll: () -> Unit, onHistoryItemsDeleteAll: () -> Unit,
onDismissed: () -> Unit onDismissed: () -> Unit,
) { ) {
val sheetState = rememberModalBottomSheetState() val sheetState = rememberModalBottomSheetState()
val scope = rememberCoroutineScope() val scope = rememberCoroutineScope()
@ -101,7 +101,7 @@ fun TextInputHistorySheet(
sheetState.hide() sheetState.hide()
onDismissed() onDismissed()
} }
} },
) )
} }
} }
@ -112,7 +112,7 @@ private fun SheetContent(
onHistoryItemClicked: (String) -> Unit, onHistoryItemClicked: (String) -> Unit,
onHistoryItemDeleted: (String) -> Unit, onHistoryItemDeleted: (String) -> Unit,
onHistoryItemsDeleteAll: () -> Unit, onHistoryItemsDeleteAll: () -> Unit,
onDismissed: () -> Unit onDismissed: () -> Unit,
) { ) {
val scope = rememberCoroutineScope() val scope = rememberCoroutineScope()
var showConfirmDeleteDialog by remember { mutableStateOf(false) } var showConfirmDeleteDialog by remember { mutableStateOf(false) }
@ -122,47 +122,44 @@ private fun SheetContent(
Text( Text(
"Text input history", "Text input history",
style = MaterialTheme.typography.titleLarge, style = MaterialTheme.typography.titleLarge,
modifier = Modifier modifier = Modifier.fillMaxWidth().padding(8.dp),
.fillMaxWidth() textAlign = TextAlign.Center,
.padding(8.dp),
textAlign = TextAlign.Center
) )
IconButton(modifier = Modifier.padding(end = 12.dp), onClick = { IconButton(
showConfirmDeleteDialog = true modifier = Modifier.padding(end = 12.dp),
}) { onClick = { showConfirmDeleteDialog = true },
) {
Icon(Icons.Rounded.DeleteSweep, contentDescription = "") Icon(Icons.Rounded.DeleteSweep, contentDescription = "")
} }
} }
LazyColumn(modifier = Modifier.weight(1f)) { LazyColumn(modifier = Modifier.weight(1f)) {
items(history, key = { it }) { item -> items(history, key = { it }) { item ->
Row( Row(
modifier = Modifier modifier =
.fillMaxWidth() Modifier.fillMaxWidth()
.padding(horizontal = 8.dp, vertical = 2.dp) .padding(horizontal = 8.dp, vertical = 2.dp)
.clip(RoundedCornerShape(24.dp)) .clip(RoundedCornerShape(24.dp))
.background(MaterialTheme.customColors.agentBubbleBgColor) .background(MaterialTheme.customColors.agentBubbleBgColor)
.clickable { .clickable { onHistoryItemClicked(item) },
onHistoryItemClicked(item)
},
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(8.dp) horizontalArrangement = Arrangement.spacedBy(8.dp),
) { ) {
Text( Text(
item, item,
style = MaterialTheme.typography.bodyMedium, style = MaterialTheme.typography.bodyMedium,
maxLines = 3, maxLines = 3,
overflow = TextOverflow.Ellipsis, overflow = TextOverflow.Ellipsis,
modifier = Modifier modifier = Modifier.padding(vertical = 16.dp).padding(start = 16.dp).weight(1f),
.padding(vertical = 16.dp)
.padding(start = 16.dp)
.weight(1f)
) )
IconButton(modifier = Modifier.padding(end = 8.dp), onClick = { IconButton(
scope.launch { modifier = Modifier.padding(end = 8.dp),
delay(400) onClick = {
onHistoryItemDeleted(item) scope.launch {
} delay(400)
}) { onHistoryItemDeleted(item)
}
},
) {
Icon(Icons.Rounded.Delete, contentDescription = "") Icon(Icons.Rounded.Delete, contentDescription = "")
} }
} }
@ -171,18 +168,17 @@ private fun SheetContent(
} }
if (showConfirmDeleteDialog) { if (showConfirmDeleteDialog) {
AlertDialog(onDismissRequest = { showConfirmDeleteDialog = false }, AlertDialog(
onDismissRequest = { showConfirmDeleteDialog = false },
title = { Text("Clear history?") }, title = { Text("Clear history?") },
text = { text = { Text("Are you sure you want to clear the history? This action cannot be undone.") },
Text(
"Are you sure you want to clear the history? This action cannot be undone."
)
},
confirmButton = { confirmButton = {
Button(onClick = { Button(
showConfirmDeleteDialog = false onClick = {
onHistoryItemsDeleteAll() showConfirmDeleteDialog = false
}) { onHistoryItemsDeleteAll()
}
) {
Text(stringResource(R.string.ok)) Text(stringResource(R.string.ok))
} }
}, },
@ -190,25 +186,29 @@ private fun SheetContent(
TextButton(onClick = { showConfirmDeleteDialog = false }) { TextButton(onClick = { showConfirmDeleteDialog = false }) {
Text(stringResource(R.string.cancel)) Text(stringResource(R.string.cancel))
} }
}) },
)
} }
} }
@Preview(showBackground = true) // @Preview(showBackground = true)
@Composable // @Composable
fun TextInputHistorySheetContentPreview() { // fun TextInputHistorySheetContentPreview() {
GalleryTheme { // GalleryTheme {
SheetContent( // SheetContent(
history = listOf( // history =
"Analyze the sentiment of the following Tweets and classify them as POSITIVE, NEGATIVE, or NEUTRAL. \"It's so beautiful today!\"", // listOf(
"I have the ingredients above. Not sure what to cook for lunch. Show me a list of foods with the recipes.", // "Analyze the sentiment of the following Tweets and classify them as POSITIVE, NEGATIVE,
"You are Santa Claus, write a letter back for this kid.", // or NEUTRAL. \"It's so beautiful today!\"",
"Generate a list of cookie recipes. Make the outputs in JSON format." // "I have the ingredients above. Not sure what to cook for lunch. Show me a list of foods
), // with the recipes.",
onHistoryItemClicked = {}, // "You are Santa Claus, write a letter back for this kid.",
onHistoryItemDeleted = {}, // "Generate a list of cookie recipes. Make the outputs in JSON format.",
onHistoryItemsDeleteAll = {}, // ),
onDismissed = {}, // onHistoryItemClicked = {},
) // onHistoryItemDeleted = {},
} // onHistoryItemsDeleteAll = {},
} // onDismissed = {},
// )
// }
// }

View file

@ -35,25 +35,28 @@ fun ZoomableBox(
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
minScale: Float = 1f, minScale: Float = 1f,
maxScale: Float = 5f, maxScale: Float = 5f,
content: @Composable ZoomableBoxScope.() -> Unit content: @Composable ZoomableBoxScope.() -> Unit,
) { ) {
var scale by remember { mutableFloatStateOf(1f) } var scale by remember { mutableFloatStateOf(1f) }
var offsetX by remember { mutableFloatStateOf(0f) } var offsetX by remember { mutableFloatStateOf(0f) }
var offsetY by remember { mutableFloatStateOf(0f) } var offsetY by remember { mutableFloatStateOf(0f) }
var size by remember { mutableStateOf(IntSize.Zero) } var size by remember { mutableStateOf(IntSize.Zero) }
Box(modifier = modifier Box(
.onSizeChanged { size = it } modifier =
.pointerInput(Unit) { modifier
detectTransformGestures { _, pan, zoom, _ -> .onSizeChanged { size = it }
scale = maxOf(minScale, minOf(scale * zoom, maxScale)) .pointerInput(Unit) {
val maxX = (size.width * (scale - 1)) / 2 detectTransformGestures { _, pan, zoom, _ ->
val minX = -maxX scale = maxOf(minScale, minOf(scale * zoom, maxScale))
offsetX = maxOf(minX, minOf(maxX, offsetX + pan.x)) val maxX = (size.width * (scale - 1)) / 2
val maxY = (size.height * (scale - 1)) / 2 val minX = -maxX
val minY = -maxY offsetX = maxOf(minX, minOf(maxX, offsetX + pan.x))
offsetY = maxOf(minY, minOf(maxY, offsetY + pan.y)) val maxY = (size.height * (scale - 1)) / 2
} val minY = -maxY
}, contentAlignment = Alignment.TopEnd offsetY = maxOf(minY, minOf(maxY, offsetY + pan.y))
}
},
contentAlignment = Alignment.TopEnd,
) { ) {
val scope = ZoomableBoxScopeImpl(scale, offsetX, offsetY) val scope = ZoomableBoxScopeImpl(scale, offsetX, offsetY)
scope.content() scope.content()
@ -67,5 +70,7 @@ interface ZoomableBoxScope {
} }
private data class ZoomableBoxScopeImpl( private data class ZoomableBoxScopeImpl(
override val scale: Float, override val offsetX: Float, override val offsetY: Float override val scale: Float,
override val offsetX: Float,
override val offsetY: Float,
) : ZoomableBoxScope ) : ZoomableBoxScope

View file

@ -1,73 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.common.modelitem
import androidx.compose.animation.core.DeferredTargetAnimation
import androidx.compose.animation.core.ExperimentalAnimatableApi
import androidx.compose.animation.core.VectorConverter
import androidx.compose.animation.core.tween
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.ui.Modifier
import androidx.compose.ui.composed
import androidx.compose.ui.geometry.Offset
import androidx.compose.ui.layout.LookaheadScope
import androidx.compose.ui.layout.approachLayout
import androidx.compose.ui.unit.Constraints
import androidx.compose.ui.unit.IntOffset
import androidx.compose.ui.unit.IntSize
import androidx.compose.ui.unit.round
const val LAYOUT_ANIMATION_DURATION = 250
context(LookaheadScope)
@OptIn(ExperimentalAnimatableApi::class)
fun Modifier.animateLayout(): Modifier = composed {
val sizeAnim = remember { DeferredTargetAnimation(IntSize.VectorConverter) }
val offsetAnim = remember { DeferredTargetAnimation(IntOffset.VectorConverter) }
val scope = rememberCoroutineScope()
this.approachLayout(
isMeasurementApproachInProgress = { lookaheadSize ->
sizeAnim.updateTarget(lookaheadSize, scope, tween(LAYOUT_ANIMATION_DURATION))
!sizeAnim.isIdle
},
isPlacementApproachInProgress = { lookaheadCoordinates ->
val target = lookaheadScopeCoordinates.localLookaheadPositionOf(lookaheadCoordinates)
offsetAnim.updateTarget(target.round(), scope, tween(LAYOUT_ANIMATION_DURATION))
!offsetAnim.isIdle
}
) { measurable, _ ->
val (animWidth, animHeight) = sizeAnim.updateTarget(
lookaheadSize,
scope,
tween(LAYOUT_ANIMATION_DURATION)
)
measurable.measure(Constraints.fixed(animWidth, animHeight))
.run {
layout(width, height) {
coordinates?.let {
val target = lookaheadScopeCoordinates.localLookaheadPositionOf(it).round()
val animOffset = offsetAnim.updateTarget(target, scope, tween(LAYOUT_ANIMATION_DURATION))
val current = lookaheadScopeCoordinates.localPositionOf(it, Offset.Zero).round()
val (x, y) = animOffset - current
place(x, y)
} ?: place(0, 0)
}
}
}
}

View file

@ -25,28 +25,16 @@ import androidx.compose.ui.res.stringResource
import com.google.ai.edge.gallery.R import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.gallery.data.Model
/** /** Composable function to display a confirmation dialog for deleting a model. */
* Composable function to display a confirmation dialog for deleting a model.
*/
@Composable @Composable
fun ConfirmDeleteModelDialog(model: Model, onConfirm: () -> Unit, onDismiss: () -> Unit) { fun ConfirmDeleteModelDialog(model: Model, onConfirm: () -> Unit, onDismiss: () -> Unit) {
AlertDialog(onDismissRequest = onDismiss, AlertDialog(
onDismissRequest = onDismiss,
title = { Text(stringResource(R.string.confirm_delete_model_dialog_title)) }, title = { Text(stringResource(R.string.confirm_delete_model_dialog_title)) },
text = { text = {
Text( Text(stringResource(R.string.confirm_delete_model_dialog_content).format(model.name))
stringResource(R.string.confirm_delete_model_dialog_content).format(
model.name
)
)
}, },
confirmButton = { confirmButton = { Button(onClick = onConfirm) { Text(stringResource(R.string.ok)) } },
Button(onClick = onConfirm) { dismissButton = { TextButton(onClick = onDismiss) { Text(stringResource(R.string.cancel)) } },
Text(stringResource(R.string.ok)) )
}
},
dismissButton = {
TextButton(onClick = onDismiss) {
Text(stringResource(R.string.cancel))
}
})
} }

View file

@ -16,8 +16,17 @@
package com.google.ai.edge.gallery.ui.common.modelitem package com.google.ai.edge.gallery.ui.common.modelitem
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.preview.MODEL_TEST1
// import com.google.ai.edge.gallery.ui.preview.MODEL_TEST2
// import com.google.ai.edge.gallery.ui.preview.MODEL_TEST3
// import com.google.ai.edge.gallery.ui.preview.MODEL_TEST4
// import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
// import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
// import com.google.ai.edge.gallery.ui.preview.TASK_TEST2
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import android.content.Intent import android.content.Intent
import android.net.Uri
import androidx.activity.compose.rememberLauncherForActivityResult import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.contract.ActivityResultContracts import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.animation.AnimatedContent import androidx.compose.animation.AnimatedContent
@ -52,37 +61,28 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip import androidx.compose.ui.draw.clip
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.Dp import androidx.compose.ui.unit.Dp
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.core.net.toUri
import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.ModelDownloadStatusType import com.google.ai.edge.gallery.data.ModelDownloadStatusType
import com.google.ai.edge.gallery.data.Task import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.common.DownloadAndTryButton import com.google.ai.edge.gallery.ui.common.DownloadAndTryButton
import com.google.ai.edge.gallery.ui.common.MarkdownText
import com.google.ai.edge.gallery.ui.common.TaskIcon import com.google.ai.edge.gallery.ui.common.TaskIcon
import com.google.ai.edge.gallery.ui.common.chat.MarkdownText
import com.google.ai.edge.gallery.ui.common.checkNotificationPermissionAndStartDownload import com.google.ai.edge.gallery.ui.common.checkNotificationPermissionAndStartDownload
import com.google.ai.edge.gallery.ui.common.getTaskBgColor import com.google.ai.edge.gallery.ui.common.getTaskBgColor
import com.google.ai.edge.gallery.ui.common.getTaskIconColor import com.google.ai.edge.gallery.ui.common.getTaskIconColor
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.MODEL_TEST1
import com.google.ai.edge.gallery.ui.preview.MODEL_TEST2
import com.google.ai.edge.gallery.ui.preview.MODEL_TEST3
import com.google.ai.edge.gallery.ui.preview.MODEL_TEST4
import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
import com.google.ai.edge.gallery.ui.preview.TASK_TEST2
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
private val DEFAULT_VERTICAL_PADDING = 16.dp private val DEFAULT_VERTICAL_PADDING = 16.dp
/** /**
* Composable function to display a model item in the model manager list. * Composable function to display a model item in the model manager list.
* *
* This function renders a card representing a model, displaying its task icon, name, * This function renders a card representing a model, displaying its task icon, name, download
* download status, and providing action buttons. It supports expanding to show a * status, and providing action buttons. It supports expanding to show a model description and
* model description and buttons for learning more (opening a URL) and downloading/trying * buttons for learning more (opening a URL) and downloading/trying the model.
* the model.
*/ */
@OptIn(ExperimentalSharedTransitionApi::class) @OptIn(ExperimentalSharedTransitionApi::class)
@Composable @Composable
@ -103,170 +103,174 @@ fun ModelItem(
val downloadStatus by remember { val downloadStatus by remember {
derivedStateOf { modelManagerUiState.modelDownloadStatus[model.name] } derivedStateOf { modelManagerUiState.modelDownloadStatus[model.name] }
} }
val launcher = rememberLauncherForActivityResult( val launcher =
ActivityResultContracts.RequestPermission() rememberLauncherForActivityResult(ActivityResultContracts.RequestPermission()) {
) { modelManagerViewModel.downloadModel(task = task, model = model)
modelManagerViewModel.downloadModel(task = task, model = model) }
}
var isExpanded by remember { mutableStateOf(false) } var isExpanded by remember { mutableStateOf(false) }
var boxModifier = modifier var boxModifier =
.fillMaxWidth() modifier.fillMaxWidth().clip(RoundedCornerShape(size = 42.dp)).background(getTaskBgColor(task))
.clip(RoundedCornerShape(size = 42.dp)) boxModifier =
.background( if (canExpand) {
getTaskBgColor(task) boxModifier.clickable(
) onClick = {
boxModifier = if (canExpand) { if (!model.imported) {
boxModifier.clickable(onClick = { isExpanded = !isExpanded
if (!model.imported) { } else {
isExpanded = !isExpanded onModelClicked(model)
} else { }
onModelClicked(model) },
} interactionSource = remember { MutableInteractionSource() },
}, interactionSource = remember { MutableInteractionSource() }, indication = ripple( indication = ripple(bounded = true, radius = 1000.dp),
bounded = true, )
radius = 1000.dp, } else {
) boxModifier
) }
} else {
boxModifier
}
Box( Box(modifier = boxModifier, contentAlignment = Alignment.Center) {
modifier = boxModifier,
contentAlignment = Alignment.Center,
) {
SharedTransitionLayout { SharedTransitionLayout {
AnimatedContent( AnimatedContent(isExpanded, label = "item_layout_transition") { targetState ->
isExpanded, label = "item_layout_transition", val taskIcon =
) { targetState -> @Composable {
val taskIcon = @Composable { TaskIcon(
TaskIcon( task = task,
task = task, modifier = Modifier.sharedElement( modifier =
sharedContentState = rememberSharedContentState(key = "task_icon"), Modifier.sharedElement(
animatedVisibilityScope = this@AnimatedContent, sharedContentState = rememberSharedContentState(key = "task_icon"),
)
)
}
val modelNameAndStatus = @Composable {
ModelNameAndStatus(
model = model,
task = task,
downloadStatus = downloadStatus,
isExpanded = isExpanded,
animatedVisibilityScope = this@AnimatedContent,
sharedTransitionScope = this@SharedTransitionLayout
)
}
val actionButton = @Composable {
ModelItemActionButton(
context = context,
model = model,
task = task,
modelManagerViewModel = modelManagerViewModel,
downloadStatus = downloadStatus,
onDownloadClicked = { model ->
checkNotificationPermissionAndStartDownload(
context = context,
launcher = launcher,
modelManagerViewModel = modelManagerViewModel,
task = task,
model = model
)
},
showDeleteButton = showDeleteButton,
showDownloadButton = false,
modifier = Modifier.sharedElement(
sharedContentState = rememberSharedContentState(key = "action_button"),
animatedVisibilityScope = this@AnimatedContent,
)
)
}
val expandButton = @Composable {
Icon(
// For imported model, show ">" directly indicating users can just tap the model item to
// go into it without needing to expand it first.
if (model.imported) Icons.Rounded.ChevronRight else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore,
contentDescription = "",
tint = getTaskIconColor(task),
modifier = Modifier.sharedElement(
sharedContentState = rememberSharedContentState(key = "expand_button"),
animatedVisibilityScope = this@AnimatedContent,
)
)
}
val description = @Composable {
if (model.info.isNotEmpty()) {
MarkdownText(
model.info, modifier = Modifier
.sharedElement(
sharedContentState = rememberSharedContentState(key = "description"),
animatedVisibilityScope = this@AnimatedContent, animatedVisibilityScope = this@AnimatedContent,
) ),
.skipToLookaheadSize()
) )
} }
}
val buttonsRow = @Composable { val modelNameAndStatus =
Row( @Composable {
horizontalArrangement = Arrangement.spacedBy(12.dp), modifier = Modifier ModelNameAndStatus(
.sharedElement(
sharedContentState = rememberSharedContentState(key = "buttons_row"),
animatedVisibilityScope = this@AnimatedContent,
)
.skipToLookaheadSize()
) {
// The "learn more" button. Click to show related urls in a bottom sheet.
if (model.learnMoreUrl.isNotEmpty()) {
OutlinedButton(
onClick = {
if (isExpanded) {
val intent = Intent(Intent.ACTION_VIEW, Uri.parse(model.learnMoreUrl))
context.startActivity(intent)
}
},
) {
Text("Learn More", maxLines = 1)
}
}
// Button to start the download and start the chat session with the model.
val needToDownloadFirst =
downloadStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED || downloadStatus?.status == ModelDownloadStatusType.FAILED
DownloadAndTryButton(task = task,
model = model, model = model,
enabled = isExpanded, task = task,
needToDownloadFirst = needToDownloadFirst, downloadStatus = downloadStatus,
modelManagerViewModel = modelManagerViewModel, isExpanded = isExpanded,
onClicked = { onModelClicked(model) }) animatedVisibilityScope = this@AnimatedContent,
sharedTransitionScope = this@SharedTransitionLayout,
)
}
val actionButton =
@Composable {
ModelItemActionButton(
context = context,
model = model,
task = task,
modelManagerViewModel = modelManagerViewModel,
downloadStatus = downloadStatus,
onDownloadClicked = { model ->
checkNotificationPermissionAndStartDownload(
context = context,
launcher = launcher,
modelManagerViewModel = modelManagerViewModel,
task = task,
model = model,
)
},
showDeleteButton = showDeleteButton,
showDownloadButton = false,
modifier =
Modifier.sharedElement(
sharedContentState = rememberSharedContentState(key = "action_button"),
animatedVisibilityScope = this@AnimatedContent,
),
)
}
val expandButton =
@Composable {
Icon(
// For imported model, show ">" directly indicating users can just tap the model item
// to
// go into it without needing to expand it first.
if (model.imported) Icons.Rounded.ChevronRight
else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore,
contentDescription = "",
tint = getTaskIconColor(task),
modifier =
Modifier.sharedElement(
sharedContentState = rememberSharedContentState(key = "expand_button"),
animatedVisibilityScope = this@AnimatedContent,
),
)
}
val description =
@Composable {
if (model.info.isNotEmpty()) {
MarkdownText(
model.info,
modifier =
Modifier.sharedElement(
sharedContentState = rememberSharedContentState(key = "description"),
animatedVisibilityScope = this@AnimatedContent,
)
.skipToLookaheadSize(),
)
}
}
val buttonsRow =
@Composable {
Row(
horizontalArrangement = Arrangement.spacedBy(12.dp),
modifier =
Modifier.sharedElement(
sharedContentState = rememberSharedContentState(key = "buttons_row"),
animatedVisibilityScope = this@AnimatedContent,
)
.skipToLookaheadSize(),
) {
// The "learn more" button. Click to show related urls in a bottom sheet.
if (model.learnMoreUrl.isNotEmpty()) {
OutlinedButton(
onClick = {
if (isExpanded) {
val intent = Intent(Intent.ACTION_VIEW, model.learnMoreUrl.toUri())
context.startActivity(intent)
}
}
) {
Text("Learn More", maxLines = 1)
}
}
// Button to start the download and start the chat session with the model.
val needToDownloadFirst =
downloadStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED ||
downloadStatus?.status == ModelDownloadStatusType.FAILED
DownloadAndTryButton(
task = task,
model = model,
enabled = isExpanded,
needToDownloadFirst = needToDownloadFirst,
modelManagerViewModel = modelManagerViewModel,
onClicked = { onModelClicked(model) },
)
}
} }
}
// Collapsed state. // Collapsed state.
if (!targetState) { if (!targetState) {
Column( Column(horizontalAlignment = Alignment.CenterHorizontally) {
horizontalAlignment = Alignment.CenterHorizontally,
) {
Row( Row(
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(12.dp), horizontalArrangement = Arrangement.spacedBy(12.dp),
modifier = Modifier modifier =
.fillMaxWidth() Modifier.fillMaxWidth()
.padding(start = 18.dp, end = 18.dp) .padding(start = 18.dp, end = 18.dp)
.padding(vertical = verticalSpacing) .padding(vertical = verticalSpacing),
) { ) {
// Icon at the left. // Icon at the left.
taskIcon() taskIcon()
// Model name and status at the center. // Model name and status at the center.
Row(modifier = Modifier.weight(1f)) { Row(modifier = Modifier.weight(1f)) { modelNameAndStatus() }
modelNameAndStatus()
}
// Action button and expand/collapse button at the right. // Action button and expand/collapse button at the right.
Row(verticalAlignment = Alignment.CenterVertically) { Row(verticalAlignment = Alignment.CenterVertically) {
actionButton() actionButton()
@ -278,9 +282,8 @@ fun ModelItem(
Column( Column(
verticalArrangement = Arrangement.spacedBy(14.dp), verticalArrangement = Arrangement.spacedBy(14.dp),
horizontalAlignment = Alignment.CenterHorizontally, horizontalAlignment = Alignment.CenterHorizontally,
modifier = Modifier modifier =
.fillMaxWidth() Modifier.fillMaxWidth().padding(vertical = verticalSpacing, horizontal = 18.dp),
.padding(vertical = verticalSpacing, horizontal = 18.dp)
) { ) {
Box(contentAlignment = Alignment.Center) { Box(contentAlignment = Alignment.Center) {
// Icon at the top-center. // Icon at the top-center.
@ -289,7 +292,7 @@ fun ModelItem(
Row( Row(
modifier = Modifier.fillMaxWidth(), modifier = Modifier.fillMaxWidth(),
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.End horizontalArrangement = Arrangement.End,
) { ) {
actionButton() actionButton()
expandButton() expandButton()
@ -308,37 +311,37 @@ fun ModelItem(
} }
} }
@Preview(showBackground = true) // @Preview(showBackground = true)
@Composable // @Composable
fun PreviewModelItem() { // fun PreviewModelItem() {
GalleryTheme { // GalleryTheme {
Column( // Column(
verticalArrangement = Arrangement.spacedBy(16.dp), modifier = Modifier.padding(16.dp) // verticalArrangement = Arrangement.spacedBy(16.dp), modifier = Modifier.padding(16.dp)
) { // ) {
ModelItem( // ModelItem(
model = MODEL_TEST1, // model = MODEL_TEST1,
task = TASK_TEST1, // task = TASK_TEST1,
onModelClicked = { }, // onModelClicked = { },
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current), // modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
) // )
ModelItem( // ModelItem(
model = MODEL_TEST2, // model = MODEL_TEST2,
task = TASK_TEST1, // task = TASK_TEST1,
onModelClicked = { }, // onModelClicked = { },
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current), // modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
) // )
ModelItem( // ModelItem(
model = MODEL_TEST3, // model = MODEL_TEST3,
task = TASK_TEST2, // task = TASK_TEST2,
onModelClicked = { }, // onModelClicked = { },
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current), // modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
) // )
ModelItem( // ModelItem(
model = MODEL_TEST4, // model = MODEL_TEST4,
task = TASK_TEST2, // task = TASK_TEST2,
onModelClicked = { }, // onModelClicked = { },
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current), // modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
) // )
} // }
} // }
} // }

View file

@ -66,69 +66,50 @@ fun ModelItemActionButton(
Row(verticalAlignment = Alignment.CenterVertically, modifier = modifier) { Row(verticalAlignment = Alignment.CenterVertically, modifier = modifier) {
when (downloadStatus?.status) { when (downloadStatus?.status) {
// Button to start the download. // Button to start the download.
ModelDownloadStatusType.NOT_DOWNLOADED, ModelDownloadStatusType.FAILED -> ModelDownloadStatusType.NOT_DOWNLOADED,
ModelDownloadStatusType.FAILED ->
if (showDownloadButton) { if (showDownloadButton) {
IconButton(onClick = { IconButton(onClick = { onDownloadClicked(model) }) {
onDownloadClicked(model) Icon(Icons.Rounded.FileDownload, contentDescription = "", tint = getTaskIconColor(task))
}) {
Icon(
Icons.Rounded.FileDownload,
contentDescription = "",
tint = getTaskIconColor(task),
)
} }
} }
// Button to delete the download. // Button to delete the download.
ModelDownloadStatusType.SUCCEEDED -> { ModelDownloadStatusType.SUCCEEDED -> {
if (showDeleteButton) { if (showDeleteButton) {
IconButton(onClick = { IconButton(onClick = { showConfirmDeleteDialog = true }) {
showConfirmDeleteDialog = true Icon(Icons.Rounded.Delete, contentDescription = "", tint = getTaskIconColor(task))
}) {
Icon(
Icons.Rounded.Delete,
contentDescription = "",
tint = getTaskIconColor(task),
)
} }
} }
} }
// Show spinner when the model is partially downloaded because it might some time for // Show spinner when the model is partially downloaded because it might some time for
// background task to be started by Android. // background task to be started by Android.
ModelDownloadStatusType.PARTIALLY_DOWNLOADED -> { ModelDownloadStatusType.PARTIALLY_DOWNLOADED -> {
CircularProgressIndicator( CircularProgressIndicator(modifier = Modifier.padding(end = 12.dp).size(24.dp))
modifier = Modifier
.padding(end = 12.dp)
.size(24.dp)
)
} }
// Button to cancel the download when it is in progress. // Button to cancel the download when it is in progress.
ModelDownloadStatusType.IN_PROGRESS, ModelDownloadStatusType.UNZIPPING -> IconButton(onClick = { ModelDownloadStatusType.IN_PROGRESS,
modelManagerViewModel.cancelDownloadModel( ModelDownloadStatusType.UNZIPPING ->
task = task, IconButton(
model = model onClick = { modelManagerViewModel.cancelDownloadModel(task = task, model = model) }
) ) {
}) { Icon(Icons.Rounded.Cancel, contentDescription = "", tint = getTaskIconColor(task))
Icon( }
Icons.Rounded.Cancel,
contentDescription = "",
tint = getTaskIconColor(task),
)
}
else -> {} else -> {}
} }
} }
if (showConfirmDeleteDialog) { if (showConfirmDeleteDialog) {
ConfirmDeleteModelDialog(model = model, onConfirm = { ConfirmDeleteModelDialog(
modelManagerViewModel.deleteModel(task = task, model = model) model = model,
showConfirmDeleteDialog = false onConfirm = {
}, onDismiss = { modelManagerViewModel.deleteModel(task = task, model = model)
showConfirmDeleteDialog = false showConfirmDeleteDialog = false
}) },
onDismiss = { showConfirmDeleteDialog = false },
)
} }
} }

View file

@ -52,8 +52,8 @@ import com.google.ai.edge.gallery.ui.theme.labelSmallNarrow
* This function renders the model's name and its current download status, including: * This function renders the model's name and its current download status, including:
* - Model name. * - Model name.
* - Failure message (if download failed). * - Failure message (if download failed).
* - Download progress (received size, total size, download rate, remaining time) for * - Download progress (received size, total size, download rate, remaining time) for in-progress
* in-progress downloads. * downloads.
* - "Unzipping..." status for unzipping processes. * - "Unzipping..." status for unzipping processes.
* - Model size for successful downloads. * - Model size for successful downloads.
*/ */
@ -66,7 +66,7 @@ fun ModelNameAndStatus(
isExpanded: Boolean, isExpanded: Boolean,
sharedTransitionScope: SharedTransitionScope, sharedTransitionScope: SharedTransitionScope,
animatedVisibilityScope: AnimatedVisibilityScope, animatedVisibilityScope: AnimatedVisibilityScope,
modifier: Modifier = Modifier modifier: Modifier = Modifier,
) { ) {
val inProgress = downloadStatus?.status == ModelDownloadStatusType.IN_PROGRESS val inProgress = downloadStatus?.status == ModelDownloadStatusType.IN_PROGRESS
val isPartiallyDownloaded = downloadStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED val isPartiallyDownloaded = downloadStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED
@ -77,18 +77,17 @@ fun ModelNameAndStatus(
horizontalAlignment = if (isExpanded) Alignment.CenterHorizontally else Alignment.Start horizontalAlignment = if (isExpanded) Alignment.CenterHorizontally else Alignment.Start
) { ) {
// Model name. // Model name.
Row( Row(verticalAlignment = Alignment.CenterVertically) {
verticalAlignment = Alignment.CenterVertically,
) {
Text( Text(
model.name, model.name,
maxLines = 1, maxLines = 1,
overflow = TextOverflow.MiddleEllipsis, overflow = TextOverflow.MiddleEllipsis,
style = MaterialTheme.typography.titleMedium, style = MaterialTheme.typography.titleMedium,
modifier = Modifier.sharedElement( modifier =
rememberSharedContentState(key = "model_name"), Modifier.sharedElement(
animatedVisibilityScope = animatedVisibilityScope rememberSharedContentState(key = "model_name"),
) animatedVisibilityScope = animatedVisibilityScope,
),
) )
} }
@ -97,12 +96,13 @@ fun ModelNameAndStatus(
if (!inProgress && !isPartiallyDownloaded) { if (!inProgress && !isPartiallyDownloaded) {
StatusIcon( StatusIcon(
downloadStatus = downloadStatus, downloadStatus = downloadStatus,
modifier = modifier modifier =
.padding(end = 4.dp) modifier
.sharedElement( .padding(end = 4.dp)
rememberSharedContentState(key = "download_status_icon"), .sharedElement(
animatedVisibilityScope = animatedVisibilityScope rememberSharedContentState(key = "download_status_icon"),
) animatedVisibilityScope = animatedVisibilityScope,
),
) )
} }
@ -114,10 +114,11 @@ fun ModelNameAndStatus(
color = MaterialTheme.colorScheme.error, color = MaterialTheme.colorScheme.error,
style = labelSmallNarrow, style = labelSmallNarrow,
overflow = TextOverflow.Ellipsis, overflow = TextOverflow.Ellipsis,
modifier = Modifier.sharedElement( modifier =
rememberSharedContentState(key = "failure_messsage"), Modifier.sharedElement(
animatedVisibilityScope = animatedVisibilityScope rememberSharedContentState(key = "failure_messsage"),
) animatedVisibilityScope = animatedVisibilityScope,
),
) )
} }
} }
@ -138,8 +139,7 @@ fun ModelNameAndStatus(
sizeLabel = sizeLabel =
"${downloadStatus.receivedBytes.humanReadableSize(extraDecimalForGbAndAbove = true)} of ${totalSize.humanReadableSize()}" "${downloadStatus.receivedBytes.humanReadableSize(extraDecimalForGbAndAbove = true)} of ${totalSize.humanReadableSize()}"
if (downloadStatus.bytesPerSecond > 0) { if (downloadStatus.bytesPerSecond > 0) {
sizeLabel = sizeLabel = "$sizeLabel · ${downloadStatus.bytesPerSecond.humanReadableSize()} / s"
"$sizeLabel · ${downloadStatus.bytesPerSecond.humanReadableSize()} / s"
if (downloadStatus.remainingMs >= 0) { if (downloadStatus.remainingMs >= 0) {
sizeLabel = sizeLabel =
"$sizeLabel\n${downloadStatus.remainingMs.formatToHourMinSecond()} left" "$sizeLabel\n${downloadStatus.remainingMs.formatToHourMinSecond()} left"
@ -162,7 +162,7 @@ fun ModelNameAndStatus(
} }
Column( Column(
horizontalAlignment = if (isExpanded) Alignment.CenterHorizontally else Alignment.Start, horizontalAlignment = if (isExpanded) Alignment.CenterHorizontally else Alignment.Start
) { ) {
for ((index, line) in sizeLabel.split("\n").withIndex()) { for ((index, line) in sizeLabel.split("\n").withIndex()) {
Text( Text(
@ -172,12 +172,12 @@ fun ModelNameAndStatus(
style = labelSmallNarrow.copy(fontSize = fontSize, lineHeight = 10.sp), style = labelSmallNarrow.copy(fontSize = fontSize, lineHeight = 10.sp),
textAlign = if (isExpanded) TextAlign.Center else TextAlign.Start, textAlign = if (isExpanded) TextAlign.Center else TextAlign.Start,
overflow = TextOverflow.Visible, overflow = TextOverflow.Visible,
modifier = Modifier modifier =
.offset(y = if (index == 0) 0.dp else (-1).dp) Modifier.offset(y = if (index == 0) 0.dp else (-1).dp)
.sharedElement( .sharedElement(
rememberSharedContentState(key = "status_label_${index}"), rememberSharedContentState(key = "status_label_${index}"),
animatedVisibilityScope = animatedVisibilityScope animatedVisibilityScope = animatedVisibilityScope,
) ),
) )
} }
} }
@ -191,12 +191,12 @@ fun ModelNameAndStatus(
progress = { animatedProgress.value }, progress = { animatedProgress.value },
color = getTaskIconColor(task = task), color = getTaskIconColor(task = task),
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest, trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
modifier = Modifier modifier =
.padding(top = 2.dp) Modifier.padding(top = 2.dp)
.sharedElement( .sharedElement(
rememberSharedContentState(key = "download_progress_bar"), rememberSharedContentState(key = "download_progress_bar"),
animatedVisibilityScope = animatedVisibilityScope animatedVisibilityScope = animatedVisibilityScope,
) ),
) )
LaunchedEffect(curDownloadProgress) { LaunchedEffect(curDownloadProgress) {
animatedProgress.animateTo(curDownloadProgress, animationSpec = tween(150)) animatedProgress.animateTo(curDownloadProgress, animationSpec = tween(150))
@ -207,12 +207,12 @@ fun ModelNameAndStatus(
LinearProgressIndicator( LinearProgressIndicator(
color = getTaskIconColor(task = task), color = getTaskIconColor(task = task),
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest, trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
modifier = Modifier modifier =
.padding(top = 2.dp) Modifier.padding(top = 2.dp)
.sharedElement( .sharedElement(
rememberSharedContentState(key = "unzip_progress_bar"), rememberSharedContentState(key = "unzip_progress_bar"),
animatedVisibilityScope = animatedVisibilityScope animatedVisibilityScope = animatedVisibilityScope,
) ),
) )
} }
} }

View file

@ -16,8 +16,9 @@
package com.google.ai.edge.gallery.ui.common.modelitem package com.google.ai.edge.gallery.ui.common.modelitem
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.size import androidx.compose.foundation.layout.size
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
@ -31,75 +32,71 @@ import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.Color
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.data.ModelDownloadStatus import com.google.ai.edge.gallery.data.ModelDownloadStatus
import com.google.ai.edge.gallery.data.ModelDownloadStatusType import com.google.ai.edge.gallery.data.ModelDownloadStatusType
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.customColors import com.google.ai.edge.gallery.ui.theme.customColors
private val SIZE = 18.dp private val SIZE = 18.dp
/** /** Composable function to display an icon representing the download status of a model. */
* Composable function to display an icon representing the download status of a model.
*/
@Composable @Composable
fun StatusIcon(downloadStatus: ModelDownloadStatus?, modifier: Modifier = Modifier) { fun StatusIcon(downloadStatus: ModelDownloadStatus?, modifier: Modifier = Modifier) {
Row( Row(
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.Center, horizontalArrangement = Arrangement.Center,
modifier = modifier modifier = modifier,
) { ) {
when (downloadStatus?.status) { when (downloadStatus?.status) {
ModelDownloadStatusType.NOT_DOWNLOADED -> Icon( ModelDownloadStatusType.NOT_DOWNLOADED ->
Icons.AutoMirrored.Outlined.HelpOutline, Icon(
tint = Color(0xFFCCCCCC), Icons.AutoMirrored.Outlined.HelpOutline,
contentDescription = "", tint = Color(0xFFCCCCCC),
modifier = Modifier.size(SIZE) contentDescription = "",
) modifier = Modifier.size(SIZE),
)
ModelDownloadStatusType.SUCCEEDED -> { ModelDownloadStatusType.SUCCEEDED -> {
Icon( Icon(
Icons.Filled.DownloadForOffline, Icons.Filled.DownloadForOffline,
tint = MaterialTheme.customColors.successColor, tint = MaterialTheme.customColors.successColor,
contentDescription = "", contentDescription = "",
modifier = Modifier.size(SIZE) modifier = Modifier.size(SIZE),
) )
} }
ModelDownloadStatusType.FAILED -> Icon( ModelDownloadStatusType.FAILED ->
Icons.Rounded.Error, Icon(
tint = Color(0xFFAA0000), Icons.Rounded.Error,
contentDescription = "", tint = Color(0xFFAA0000),
modifier = Modifier.size(SIZE) contentDescription = "",
) modifier = Modifier.size(SIZE),
)
ModelDownloadStatusType.IN_PROGRESS -> Icon( ModelDownloadStatusType.IN_PROGRESS ->
Icons.Rounded.Downloading, Icon(Icons.Rounded.Downloading, contentDescription = "", modifier = Modifier.size(SIZE))
contentDescription = "",
modifier = Modifier.size(SIZE)
)
else -> {} else -> {}
} }
} }
} }
@Preview(showBackground = true) // @Preview(showBackground = true)
@Composable // @Composable
fun StatusIconPreview() { // fun StatusIconPreview() {
GalleryTheme { // GalleryTheme {
Column { // Column {
for (downloadStatus in listOf( // for (downloadStatus in
ModelDownloadStatus(status = ModelDownloadStatusType.NOT_DOWNLOADED), // listOf(
ModelDownloadStatus(status = ModelDownloadStatusType.IN_PROGRESS), // ModelDownloadStatus(status = ModelDownloadStatusType.NOT_DOWNLOADED),
ModelDownloadStatus(status = ModelDownloadStatusType.SUCCEEDED), // ModelDownloadStatus(status = ModelDownloadStatusType.IN_PROGRESS),
ModelDownloadStatus(status = ModelDownloadStatusType.FAILED), // ModelDownloadStatus(status = ModelDownloadStatusType.SUCCEEDED),
ModelDownloadStatus(status = ModelDownloadStatusType.UNZIPPING), // ModelDownloadStatus(status = ModelDownloadStatusType.FAILED),
ModelDownloadStatus(status = ModelDownloadStatusType.PARTIALLY_DOWNLOADED), // ModelDownloadStatus(status = ModelDownloadStatusType.UNZIPPING),
)) { // ModelDownloadStatus(status = ModelDownloadStatusType.PARTIALLY_DOWNLOADED),
StatusIcon(downloadStatus = downloadStatus) // )) {
} // StatusIcon(downloadStatus = downloadStatus)
} // }
} // }
} // }
// }

View file

@ -16,6 +16,9 @@
package com.google.ai.edge.gallery.ui.home package com.google.ai.edge.gallery.ui.home
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
// import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
import android.content.Context import android.content.Context
import android.content.Intent import android.content.Intent
import android.net.Uri import android.net.Uri
@ -95,20 +98,17 @@ import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.text.style.TextDecoration import androidx.compose.ui.text.style.TextDecoration
import androidx.compose.ui.text.withLink import androidx.compose.ui.text.withLink
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp import androidx.compose.ui.unit.sp
import com.google.ai.edge.gallery.GalleryTopAppBar import com.google.ai.edge.gallery.GalleryTopAppBar
import com.google.ai.edge.gallery.R import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.data.AppBarAction import com.google.ai.edge.gallery.data.AppBarAction
import com.google.ai.edge.gallery.data.AppBarActionType import com.google.ai.edge.gallery.data.AppBarActionType
import com.google.ai.edge.gallery.data.ImportedModelInfo
import com.google.ai.edge.gallery.data.Task import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.proto.ImportedModel
import com.google.ai.edge.gallery.ui.common.TaskIcon import com.google.ai.edge.gallery.ui.common.TaskIcon
import com.google.ai.edge.gallery.ui.common.getTaskBgColor import com.google.ai.edge.gallery.ui.common.getTaskBgColor
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.customColors import com.google.ai.edge.gallery.ui.theme.customColors
import com.google.ai.edge.gallery.ui.theme.titleMediumNarrow import com.google.ai.edge.gallery.ui.theme.titleMediumNarrow
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
@ -125,8 +125,7 @@ private const val MIN_TASK_CARD_ICON_SIZE = 50
/** Navigation destination data */ /** Navigation destination data */
object HomeScreenDestination { object HomeScreenDestination {
@StringRes @StringRes val titleRes = R.string.app_name
val titleRes = R.string.app_name
} }
@OptIn(ExperimentalMaterial3Api::class) @OptIn(ExperimentalMaterial3Api::class)
@ -134,7 +133,7 @@ object HomeScreenDestination {
fun HomeScreen( fun HomeScreen(
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
navigateToTaskScreen: (Task) -> Unit, navigateToTaskScreen: (Task) -> Unit,
modifier: Modifier = Modifier modifier: Modifier = Modifier,
) { ) {
val scrollBehavior = TopAppBarDefaults.pinnedScrollBehavior() val scrollBehavior = TopAppBarDefaults.pinnedScrollBehavior()
val uiState by modelManagerViewModel.uiState.collectAsState() val uiState by modelManagerViewModel.uiState.collectAsState()
@ -145,53 +144,56 @@ fun HomeScreen(
var showImportDialog by remember { mutableStateOf(false) } var showImportDialog by remember { mutableStateOf(false) }
var showImportingDialog by remember { mutableStateOf(false) } var showImportingDialog by remember { mutableStateOf(false) }
val selectedLocalModelFileUri = remember { mutableStateOf<Uri?>(null) } val selectedLocalModelFileUri = remember { mutableStateOf<Uri?>(null) }
val selectedImportedModelInfo = remember { mutableStateOf<ImportedModelInfo?>(null) } val selectedImportedModelInfo = remember { mutableStateOf<ImportedModel?>(null) }
val coroutineScope = rememberCoroutineScope() val coroutineScope = rememberCoroutineScope()
val snackbarHostState = remember { SnackbarHostState() } val snackbarHostState = remember { SnackbarHostState() }
val scope = rememberCoroutineScope() val scope = rememberCoroutineScope()
val context = LocalContext.current val context = LocalContext.current
val filePickerLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult( val filePickerLauncher: ActivityResultLauncher<Intent> =
contract = ActivityResultContracts.StartActivityForResult() rememberLauncherForActivityResult(
) { result -> contract = ActivityResultContracts.StartActivityForResult()
if (result.resultCode == android.app.Activity.RESULT_OK) { ) { result ->
result.data?.data?.let { uri -> if (result.resultCode == android.app.Activity.RESULT_OK) {
val fileName = getFileName(context = context, uri = uri) result.data?.data?.let { uri ->
Log.d(TAG, "Selected file: $fileName") val fileName = getFileName(context = context, uri = uri)
if (fileName != null && !fileName.endsWith(".task")) { Log.d(TAG, "Selected file: $fileName")
showUnsupportedFileTypeDialog = true if (fileName != null && !fileName.endsWith(".task")) {
} else { showUnsupportedFileTypeDialog = true
selectedLocalModelFileUri.value = uri } else {
showImportDialog = true selectedLocalModelFileUri.value = uri
} showImportDialog = true
} ?: run { }
Log.d(TAG, "No file selected or URI is null.") } ?: run { Log.d(TAG, "No file selected or URI is null.") }
} else {
Log.d(TAG, "File picking cancelled.")
} }
} else {
Log.d(TAG, "File picking cancelled.")
} }
}
Scaffold(modifier = modifier.nestedScroll(scrollBehavior.nestedScrollConnection), topBar = { Scaffold(
GalleryTopAppBar( modifier = modifier.nestedScroll(scrollBehavior.nestedScrollConnection),
title = stringResource(HomeScreenDestination.titleRes), topBar = {
rightAction = AppBarAction(actionType = AppBarActionType.APP_SETTING, actionFn = { GalleryTopAppBar(
showSettingsDialog = true title = stringResource(HomeScreenDestination.titleRes),
}), rightAction =
scrollBehavior = scrollBehavior, AppBarAction(
) actionType = AppBarActionType.APP_SETTING,
}, floatingActionButton = { actionFn = { showSettingsDialog = true },
// A floating action button to show "import model" bottom sheet. ),
SmallFloatingActionButton( scrollBehavior = scrollBehavior,
onClick = { )
showImportModelSheet = true },
}, floatingActionButton = {
containerColor = MaterialTheme.colorScheme.secondaryContainer, // A floating action button to show "import model" bottom sheet.
contentColor = MaterialTheme.colorScheme.secondary, SmallFloatingActionButton(
) { onClick = { showImportModelSheet = true },
Icon(Icons.Filled.Add, "") containerColor = MaterialTheme.colorScheme.secondaryContainer,
} contentColor = MaterialTheme.colorScheme.secondary,
}) { innerPadding -> ) {
Icon(Icons.Filled.Add, "")
}
},
) { innerPadding ->
Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.fillMaxSize()) { Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.fillMaxSize()) {
TaskList( TaskList(
tasks = uiState.tasks, tasks = uiState.tasks,
@ -216,37 +218,36 @@ fun HomeScreen(
// Import model bottom sheet. // Import model bottom sheet.
if (showImportModelSheet) { if (showImportModelSheet) {
ModalBottomSheet( ModalBottomSheet(onDismissRequest = { showImportModelSheet = false }, sheetState = sheetState) {
onDismissRequest = { showImportModelSheet = false },
sheetState = sheetState,
) {
Text( Text(
"Import model", "Import model",
style = MaterialTheme.typography.titleLarge, style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(vertical = 4.dp, horizontal = 16.dp) modifier = Modifier.padding(vertical = 4.dp, horizontal = 16.dp),
) )
Box(modifier = Modifier.clickable { Box(
coroutineScope.launch { modifier =
// Give it sometime to show the click effect. Modifier.clickable {
delay(200) coroutineScope.launch {
showImportModelSheet = false // Give it sometime to show the click effect.
delay(200)
showImportModelSheet = false
// Show file picker. // Show file picker.
val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply { val intent =
addCategory(Intent.CATEGORY_OPENABLE) Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
type = "*/*" addCategory(Intent.CATEGORY_OPENABLE)
// Single select. type = "*/*"
putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false) // Single select.
putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
}
filePickerLauncher.launch(intent)
}
} }
filePickerLauncher.launch(intent) ) {
}
}) {
Row( Row(
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp), horizontalArrangement = Arrangement.spacedBy(6.dp),
modifier = Modifier modifier = Modifier.fillMaxWidth().padding(16.dp),
.fillMaxWidth()
.padding(16.dp)
) { ) {
Icon(Icons.AutoMirrored.Outlined.NoteAdd, contentDescription = "") Icon(Icons.AutoMirrored.Outlined.NoteAdd, contentDescription = "")
Text("From local model file") Text("From local model file")
@ -258,11 +259,15 @@ fun HomeScreen(
// Import dialog // Import dialog
if (showImportDialog) { if (showImportDialog) {
selectedLocalModelFileUri.value?.let { uri -> selectedLocalModelFileUri.value?.let { uri ->
ModelImportDialog(uri = uri, onDismiss = { showImportDialog = false }, onDone = { info -> ModelImportDialog(
selectedImportedModelInfo.value = info uri = uri,
showImportDialog = false onDismiss = { showImportDialog = false },
showImportingDialog = true onDone = { info ->
}) selectedImportedModelInfo.value = info
showImportDialog = false
showImportingDialog = true
},
)
} }
} }
@ -270,20 +275,18 @@ fun HomeScreen(
if (showImportingDialog) { if (showImportingDialog) {
selectedLocalModelFileUri.value?.let { uri -> selectedLocalModelFileUri.value?.let { uri ->
selectedImportedModelInfo.value?.let { info -> selectedImportedModelInfo.value?.let { info ->
ModelImportingDialog(uri = uri, ModelImportingDialog(
uri = uri,
info = info, info = info,
onDismiss = { showImportingDialog = false }, onDismiss = { showImportingDialog = false },
onDone = { onDone = {
modelManagerViewModel.addImportedLlmModel( modelManagerViewModel.addImportedLlmModel(info = it)
info = it,
)
showImportingDialog = false showImportingDialog = false
// Show a snack bar for successful import. // Show a snack bar for successful import.
scope.launch { scope.launch { snackbarHostState.showSnackbar("Model imported successfully") }
snackbarHostState.showSnackbar("Model imported successfully") },
} )
})
} }
} }
} }
@ -293,9 +296,7 @@ fun HomeScreen(
AlertDialog( AlertDialog(
onDismissRequest = { showUnsupportedFileTypeDialog = false }, onDismissRequest = { showUnsupportedFileTypeDialog = false },
title = { Text("Unsupported file type") }, title = { Text("Unsupported file type") },
text = { text = { Text("Only \".task\" file type is supported.") },
Text("Only \".task\" file type is supported.")
},
confirmButton = { confirmButton = {
Button(onClick = { showUnsupportedFileTypeDialog = false }) { Button(onClick = { showUnsupportedFileTypeDialog = false }) {
Text(stringResource(R.string.ok)) Text(stringResource(R.string.ok))
@ -309,21 +310,11 @@ fun HomeScreen(
icon = { icon = {
Icon(Icons.Rounded.Error, contentDescription = "", tint = MaterialTheme.colorScheme.error) Icon(Icons.Rounded.Error, contentDescription = "", tint = MaterialTheme.colorScheme.error)
}, },
title = { title = { Text(uiState.loadingModelAllowlistError) },
Text(uiState.loadingModelAllowlistError) text = { Text("Please check your internet connection and try again later.") },
}, onDismissRequest = { modelManagerViewModel.loadModelAllowlist() },
text = {
Text("Please check your internet connection and try again later.")
},
onDismissRequest = {
modelManagerViewModel.loadModelAllowlist()
},
confirmButton = { confirmButton = {
TextButton(onClick = { TextButton(onClick = { modelManagerViewModel.loadModelAllowlist() }) { Text("Retry") }
modelManagerViewModel.loadModelAllowlist()
}) {
Text("Retry")
}
}, },
) )
} }
@ -339,31 +330,22 @@ private fun TaskList(
) { ) {
val density = LocalDensity.current val density = LocalDensity.current
val windowInfo = LocalWindowInfo.current val windowInfo = LocalWindowInfo.current
val screenWidthDp = remember { val screenWidthDp = remember { with(density) { windowInfo.containerSize.width.toDp() } }
with(density) { val screenHeightDp = remember { with(density) { windowInfo.containerSize.height.toDp() } }
windowInfo.containerSize.width.toDp()
}
}
val screenHeightDp = remember {
with(density) {
windowInfo.containerSize.height.toDp()
}
}
val sizeFraction = remember { ((screenWidthDp - 360.dp) / (410.dp - 360.dp)).coerceIn(0f, 1f) } val sizeFraction = remember { ((screenWidthDp - 360.dp) / (410.dp - 360.dp)).coerceIn(0f, 1f) }
val linkColor = MaterialTheme.customColors.linkColor val linkColor = MaterialTheme.customColors.linkColor
val introText = buildAnnotatedString { val introText = buildAnnotatedString {
append("Welcome to Google AI Edge Gallery! Explore a world of amazing on-device models from ") append("Welcome to Google AI Edge Gallery! Explore a world of amazing on-device models from ")
withLink( withLink(
link = LinkAnnotation.Url( link =
url = "https://huggingface.co/litert-community", // Replace with the actual URL LinkAnnotation.Url(
styles = TextLinkStyles( url = "https://huggingface.co/litert-community", // Replace with the actual URL
style = SpanStyle( styles =
color = linkColor, TextLinkStyles(
textDecoration = TextDecoration.Underline, style = SpanStyle(color = linkColor, textDecoration = TextDecoration.Underline)
) ),
) )
)
) { ) {
append("LiteRT community") append("LiteRT community")
} }
@ -378,9 +360,7 @@ private fun TaskList(
verticalArrangement = Arrangement.spacedBy(8.dp), verticalArrangement = Arrangement.spacedBy(8.dp),
) { ) {
// New rel // New rel
item(key = "newReleaseNotification", span = { GridItemSpan(2) }) { item(key = "newReleaseNotification", span = { GridItemSpan(2) }) { NewReleaseNotification() }
NewReleaseNotification()
}
// Headline. // Headline.
item(key = "headline", span = { GridItemSpan(2) }) { item(key = "headline", span = { GridItemSpan(2) }) {
@ -388,7 +368,7 @@ private fun TaskList(
introText, introText,
textAlign = TextAlign.Center, textAlign = TextAlign.Center,
style = MaterialTheme.typography.bodyLarge.copy(fontWeight = FontWeight.SemiBold), style = MaterialTheme.typography.bodyLarge.copy(fontWeight = FontWeight.SemiBold),
modifier = Modifier.padding(bottom = 20.dp).padding(horizontal = 16.dp) modifier = Modifier.padding(bottom = 20.dp).padding(horizontal = 16.dp),
) )
} }
@ -396,16 +376,12 @@ private fun TaskList(
item(key = "loading", span = { GridItemSpan(2) }) { item(key = "loading", span = { GridItemSpan(2) }) {
Row( Row(
horizontalArrangement = Arrangement.Center, horizontalArrangement = Arrangement.Center,
modifier = Modifier modifier = Modifier.fillMaxWidth().padding(top = 32.dp),
.fillMaxWidth()
.padding(top = 32.dp)
) { ) {
CircularProgressIndicator( CircularProgressIndicator(
trackColor = MaterialTheme.colorScheme.surfaceVariant, trackColor = MaterialTheme.colorScheme.surfaceVariant,
strokeWidth = 3.dp, strokeWidth = 3.dp,
modifier = Modifier modifier = Modifier.padding(end = 8.dp).size(20.dp),
.padding(end = 8.dp)
.size(20.dp)
) )
Text("Loading model list...", style = MaterialTheme.typography.bodyMedium) Text("Loading model list...", style = MaterialTheme.typography.bodyMedium)
} }
@ -417,17 +393,16 @@ private fun TaskList(
"Example LLM Use Cases", "Example LLM Use Cases",
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.SemiBold), style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.SemiBold),
color = MaterialTheme.colorScheme.onSurfaceVariant, color = MaterialTheme.colorScheme.onSurfaceVariant,
modifier = Modifier.padding(bottom = 4.dp) modifier = Modifier.padding(bottom = 4.dp),
) )
} }
items(tasks) { task -> items(tasks) { task ->
TaskCard( TaskCard(
sizeFraction = sizeFraction, task = task, onClick = { sizeFraction = sizeFraction,
navigateToTaskScreen(task) task = task,
}, modifier = Modifier onClick = { navigateToTaskScreen(task) },
.fillMaxWidth() modifier = Modifier.fillMaxWidth().aspectRatio(1f),
.aspectRatio(1f)
) )
} }
} }
@ -440,22 +415,23 @@ private fun TaskList(
// Gradient overlay at the bottom. // Gradient overlay at the bottom.
Box( Box(
modifier = Modifier modifier =
.fillMaxWidth() Modifier.fillMaxWidth()
.height(screenHeightDp * 0.25f) .height(screenHeightDp * 0.25f)
.background( .background(
Brush.verticalGradient( Brush.verticalGradient(colors = MaterialTheme.customColors.homeBottomGradient)
colors = MaterialTheme.customColors.homeBottomGradient,
) )
) .align(Alignment.BottomCenter)
.align(Alignment.BottomCenter)
) )
} }
} }
@Composable @Composable
private fun TaskCard( private fun TaskCard(
task: Task, onClick: () -> Unit, sizeFraction: Float, modifier: Modifier = Modifier task: Task,
onClick: () -> Unit,
sizeFraction: Float,
modifier: Modifier = Modifier,
) { ) {
val padding = val padding =
(MAX_TASK_CARD_PADDING - MIN_TASK_CARD_PADDING) * sizeFraction + MIN_TASK_CARD_PADDING (MAX_TASK_CARD_PADDING - MIN_TASK_CARD_PADDING) * sizeFraction + MIN_TASK_CARD_PADDING
@ -485,14 +461,16 @@ private fun TaskCard(
} }
var curModelCountLabel by remember { mutableStateOf("") } var curModelCountLabel by remember { mutableStateOf("") }
var modelCountLabelVisible by remember { mutableStateOf(true) } var modelCountLabelVisible by remember { mutableStateOf(true) }
val modelCountAlpha: Float by animateFloatAsState( val modelCountAlpha: Float by
targetValue = if (modelCountLabelVisible) 1f else 0f, animateFloatAsState(
animationSpec = tween(durationMillis = TASK_COUNT_ANIMATION_DURATION) targetValue = if (modelCountLabelVisible) 1f else 0f,
) animationSpec = tween(durationMillis = TASK_COUNT_ANIMATION_DURATION),
val modelCountScale: Float by animateFloatAsState( )
targetValue = if (modelCountLabelVisible) 1f else 0.7f, val modelCountScale: Float by
animationSpec = tween(durationMillis = TASK_COUNT_ANIMATION_DURATION) animateFloatAsState(
) targetValue = if (modelCountLabelVisible) 1f else 0.7f,
animationSpec = tween(durationMillis = TASK_COUNT_ANIMATION_DURATION),
)
LaunchedEffect(modelCountLabel) { LaunchedEffect(modelCountLabel) {
if (curModelCountLabel.isEmpty()) { if (curModelCountLabel.isEmpty()) {
@ -506,20 +484,10 @@ private fun TaskCard(
} }
Card( Card(
modifier = modifier modifier = modifier.clip(RoundedCornerShape(radius.dp)).clickable(onClick = onClick),
.clip(RoundedCornerShape(radius.dp)) colors = CardDefaults.cardColors(containerColor = getTaskBgColor(task = task)),
.clickable(
onClick = onClick,
),
colors = CardDefaults.cardColors(
containerColor = getTaskBgColor(task = task)
),
) { ) {
Column( Column(modifier = Modifier.fillMaxSize().padding(padding.dp)) {
modifier = Modifier
.fillMaxSize()
.padding(padding.dp),
) {
// Icon. // Icon.
TaskIcon(task = task, width = iconSize.dp) TaskIcon(task = task, width = iconSize.dp)
@ -529,10 +497,7 @@ private fun TaskCard(
Text( Text(
task.type.label, task.type.label,
color = MaterialTheme.colorScheme.primary, color = MaterialTheme.colorScheme.primary,
style = titleMediumNarrow.copy( style = titleMediumNarrow.copy(fontSize = 20.sp, fontWeight = FontWeight.Bold),
fontSize = 20.sp,
fontWeight = FontWeight.Bold,
),
) )
Spacer(modifier = Modifier.weight(1f)) Spacer(modifier = Modifier.weight(1f))
@ -542,9 +507,7 @@ private fun TaskCard(
curModelCountLabel, curModelCountLabel,
color = MaterialTheme.colorScheme.secondary, color = MaterialTheme.colorScheme.secondary,
style = MaterialTheme.typography.bodyMedium, style = MaterialTheme.typography.bodyMedium,
modifier = Modifier modifier = Modifier.alpha(modelCountAlpha).scale(modelCountScale),
.alpha(modelCountAlpha)
.scale(modelCountScale),
) )
} }
} }
@ -567,15 +530,13 @@ fun getFileName(context: Context, uri: Uri): String? {
return null return null
} }
@Preview // @Preview
@Composable // @Composable
fun HomeScreenPreview( // fun HomeScreenPreview() {
) { // GalleryTheme {
GalleryTheme { // HomeScreen(
HomeScreen( // modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current), // navigateToTaskScreen = {},
navigateToTaskScreen = {}, // )
) // }
} // }
}

View file

@ -63,69 +63,74 @@ import com.google.ai.edge.gallery.data.Accelerator
import com.google.ai.edge.gallery.data.BooleanSwitchConfig import com.google.ai.edge.gallery.data.BooleanSwitchConfig
import com.google.ai.edge.gallery.data.Config import com.google.ai.edge.gallery.data.Config
import com.google.ai.edge.gallery.data.ConfigKey import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.DEFAULT_MAX_TOKEN
import com.google.ai.edge.gallery.data.DEFAULT_TEMPERATURE
import com.google.ai.edge.gallery.data.DEFAULT_TOPK
import com.google.ai.edge.gallery.data.DEFAULT_TOPP
import com.google.ai.edge.gallery.data.IMPORTS_DIR import com.google.ai.edge.gallery.data.IMPORTS_DIR
import com.google.ai.edge.gallery.data.LabelConfig import com.google.ai.edge.gallery.data.LabelConfig
import com.google.ai.edge.gallery.data.ImportedModelInfo
import com.google.ai.edge.gallery.data.NumberSliderConfig import com.google.ai.edge.gallery.data.NumberSliderConfig
import com.google.ai.edge.gallery.data.SegmentedButtonConfig import com.google.ai.edge.gallery.data.SegmentedButtonConfig
import com.google.ai.edge.gallery.data.ValueType import com.google.ai.edge.gallery.data.ValueType
import com.google.ai.edge.gallery.ui.common.chat.ConfigEditorsPanel import com.google.ai.edge.gallery.data.convertValueToTargetType
import com.google.ai.edge.gallery.proto.ImportedModel
import com.google.ai.edge.gallery.proto.LlmConfig
import com.google.ai.edge.gallery.ui.common.ConfigEditorsPanel
import com.google.ai.edge.gallery.ui.common.ensureValidFileName import com.google.ai.edge.gallery.ui.common.ensureValidFileName
import com.google.ai.edge.gallery.ui.common.humanReadableSize import com.google.ai.edge.gallery.ui.common.humanReadableSize
import com.google.ai.edge.gallery.ui.llmchat.DEFAULT_MAX_TOKEN
import com.google.ai.edge.gallery.ui.llmchat.DEFAULT_TEMPERATURE
import com.google.ai.edge.gallery.ui.llmchat.DEFAULT_TOPK
import com.google.ai.edge.gallery.ui.llmchat.DEFAULT_TOPP
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import java.io.File import java.io.File
import java.io.FileOutputStream import java.io.FileOutputStream
import java.net.URLDecoder import java.net.URLDecoder
import java.nio.charset.StandardCharsets import java.nio.charset.StandardCharsets
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
private const val TAG = "AGModelImportDialog" private const val TAG = "AGModelImportDialog"
private val IMPORT_CONFIGS_LLM: List<Config> = listOf( private val IMPORT_CONFIGS_LLM: List<Config> =
LabelConfig(key = ConfigKey.NAME), LabelConfig(key = ConfigKey.MODEL_TYPE), NumberSliderConfig( listOf(
key = ConfigKey.DEFAULT_MAX_TOKENS, LabelConfig(key = ConfigKey.NAME),
sliderMin = 100f, LabelConfig(key = ConfigKey.MODEL_TYPE),
sliderMax = 1024f, NumberSliderConfig(
defaultValue = DEFAULT_MAX_TOKEN.toFloat(), key = ConfigKey.DEFAULT_MAX_TOKENS,
valueType = ValueType.INT sliderMin = 100f,
), NumberSliderConfig( sliderMax = 1024f,
key = ConfigKey.DEFAULT_TOPK, defaultValue = DEFAULT_MAX_TOKEN.toFloat(),
sliderMin = 5f, valueType = ValueType.INT,
sliderMax = 40f, ),
defaultValue = DEFAULT_TOPK.toFloat(), NumberSliderConfig(
valueType = ValueType.INT key = ConfigKey.DEFAULT_TOPK,
), NumberSliderConfig( sliderMin = 5f,
key = ConfigKey.DEFAULT_TOPP, sliderMax = 40f,
sliderMin = 0.0f, defaultValue = DEFAULT_TOPK.toFloat(),
sliderMax = 1.0f, valueType = ValueType.INT,
defaultValue = DEFAULT_TOPP, ),
valueType = ValueType.FLOAT NumberSliderConfig(
), NumberSliderConfig( key = ConfigKey.DEFAULT_TOPP,
key = ConfigKey.DEFAULT_TEMPERATURE, sliderMin = 0.0f,
sliderMin = 0.0f, sliderMax = 1.0f,
sliderMax = 2.0f, defaultValue = DEFAULT_TOPP,
defaultValue = DEFAULT_TEMPERATURE, valueType = ValueType.FLOAT,
valueType = ValueType.FLOAT ),
), BooleanSwitchConfig( NumberSliderConfig(
key = ConfigKey.SUPPORT_IMAGE, key = ConfigKey.DEFAULT_TEMPERATURE,
defaultValue = false, sliderMin = 0.0f,
), SegmentedButtonConfig( sliderMax = 2.0f,
key = ConfigKey.COMPATIBLE_ACCELERATORS, defaultValue = DEFAULT_TEMPERATURE,
defaultValue = Accelerator.CPU.label, valueType = ValueType.FLOAT,
options = listOf(Accelerator.CPU.label, Accelerator.GPU.label), ),
allowMultiple = true, BooleanSwitchConfig(key = ConfigKey.SUPPORT_IMAGE, defaultValue = false),
SegmentedButtonConfig(
key = ConfigKey.COMPATIBLE_ACCELERATORS,
defaultValue = Accelerator.CPU.label,
options = listOf(Accelerator.CPU.label, Accelerator.GPU.label),
allowMultiple = true,
),
) )
)
@Composable @Composable
fun ModelImportDialog( fun ModelImportDialog(uri: Uri, onDismiss: () -> Unit, onDone: (ImportedModel) -> Unit) {
uri: Uri, onDismiss: () -> Unit, onDone: (ImportedModelInfo) -> Unit
) {
val context = LocalContext.current val context = LocalContext.current
val info = remember { getFileSizeAndDisplayNameFromUri(context = context, uri = uri) } val info = remember { getFileSizeAndDisplayNameFromUri(context = context, uri = uri) }
val fileSize by remember { mutableLongStateOf(info.first) } val fileSize by remember { mutableLongStateOf(info.first) }
@ -142,78 +147,110 @@ fun ModelImportDialog(
} }
} }
val values: SnapshotStateMap<String, Any> = remember { val values: SnapshotStateMap<String, Any> = remember {
mutableStateMapOf<String, Any>().apply { mutableStateMapOf<String, Any>().apply { putAll(initialValues) }
putAll(initialValues)
}
} }
val interactionSource = remember { MutableInteractionSource() } val interactionSource = remember { MutableInteractionSource() }
Dialog( Dialog(onDismissRequest = onDismiss) {
onDismissRequest = onDismiss,
) {
val focusManager = LocalFocusManager.current val focusManager = LocalFocusManager.current
Card( Card(
modifier = Modifier modifier =
.fillMaxWidth() Modifier.fillMaxWidth().clickable(
.clickable( interactionSource = interactionSource,
interactionSource = interactionSource, indication = null // Disable the ripple effect indication = null, // Disable the ripple effect
) { ) {
focusManager.clearFocus() focusManager.clearFocus()
}, shape = RoundedCornerShape(16.dp) },
shape = RoundedCornerShape(16.dp),
) { ) {
Column( Column(
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp) modifier = Modifier.padding(20.dp),
verticalArrangement = Arrangement.spacedBy(16.dp),
) { ) {
// Title. // Title.
Text( Text(
"Import Model", "Import Model",
style = MaterialTheme.typography.titleLarge, style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(bottom = 8.dp) modifier = Modifier.padding(bottom = 8.dp),
) )
Column( Column(
modifier = Modifier modifier = Modifier.verticalScroll(rememberScrollState()).weight(1f, fill = false),
.verticalScroll(rememberScrollState()) verticalArrangement = Arrangement.spacedBy(16.dp),
.weight(1f, fill = false),
verticalArrangement = Arrangement.spacedBy(16.dp)
) { ) {
// Default configs for users to set. // Default configs for users to set.
ConfigEditorsPanel( ConfigEditorsPanel(configs = IMPORT_CONFIGS_LLM, values = values)
configs = IMPORT_CONFIGS_LLM,
values = values,
)
} }
// Button row. // Button row.
Row( Row(
modifier = Modifier modifier = Modifier.fillMaxWidth().padding(top = 8.dp),
.fillMaxWidth()
.padding(top = 8.dp),
horizontalArrangement = Arrangement.End, horizontalArrangement = Arrangement.End,
) { ) {
// Cancel button. // Cancel button.
TextButton( TextButton(onClick = { onDismiss() }) { Text("Cancel") }
onClick = { onDismiss() },
) {
Text("Cancel")
}
// Import button // Import button
Button( Button(
onClick = { onClick = {
onDone( val supportedAccelerators =
ImportedModelInfo( (convertValueToTargetType(
fileName = fileName, value = values.get(ConfigKey.COMPATIBLE_ACCELERATORS.label)!!,
fileSize = fileSize, valueType = ValueType.STRING,
defaultValues = values, )
as String)
.split(",")
val defaultMaxTokens =
convertValueToTargetType(
value = values.get(ConfigKey.DEFAULT_MAX_TOKENS.label)!!,
valueType = ValueType.INT,
) )
) as Int
}, val defaultTopk =
convertValueToTargetType(
value = values.get(ConfigKey.DEFAULT_TOPK.label)!!,
valueType = ValueType.INT,
)
as Int
val defaultTopp =
convertValueToTargetType(
value = values.get(ConfigKey.DEFAULT_TOPP.label)!!,
valueType = ValueType.FLOAT,
)
as Float
val defaultTemperature =
convertValueToTargetType(
value = values.get(ConfigKey.DEFAULT_TEMPERATURE.label)!!,
valueType = ValueType.FLOAT,
)
as Float
val supportImage =
convertValueToTargetType(
value = values.get(ConfigKey.SUPPORT_IMAGE.label)!!,
valueType = ValueType.BOOLEAN,
)
as Boolean
val importedModel: ImportedModel =
ImportedModel.newBuilder()
.setFileName(fileName)
.setFileSize(fileSize)
.setLlmConfig(
LlmConfig.newBuilder()
.addAllCompatibleAccelerators(supportedAccelerators)
.setDefaultMaxTokens(defaultMaxTokens)
.setDefaultTopk(defaultTopk)
.setDefaultTopp(defaultTopp)
.setDefaultTemperature(defaultTemperature)
.setSupportImage(supportImage)
.build()
)
.build()
onDone(importedModel)
}
) { ) {
Text("Import") Text("Import")
} }
} }
} }
} }
} }
@ -221,7 +258,10 @@ fun ModelImportDialog(
@Composable @Composable
fun ModelImportingDialog( fun ModelImportingDialog(
uri: Uri, info: ImportedModelInfo, onDismiss: () -> Unit, onDone: (ImportedModelInfo) -> Unit uri: Uri,
info: ImportedModel,
onDismiss: () -> Unit,
onDone: (ImportedModel) -> Unit,
) { ) {
var error by remember { mutableStateOf("") } var error by remember { mutableStateOf("") }
val context = LocalContext.current val context = LocalContext.current
@ -230,20 +270,16 @@ fun ModelImportingDialog(
LaunchedEffect(Unit) { LaunchedEffect(Unit) {
// Import. // Import.
importModel(context = context, importModel(
context = context,
coroutineScope = coroutineScope, coroutineScope = coroutineScope,
fileName = info.fileName, fileName = info.fileName,
fileSize = info.fileSize, fileSize = info.fileSize,
uri = uri, uri = uri,
onDone = { onDone = { onDone(info) },
onDone(info) onProgress = { progress = it },
}, onError = { error = it },
onProgress = { )
progress = it
},
onError = {
error = it
})
} }
Dialog( Dialog(
@ -252,13 +288,14 @@ fun ModelImportingDialog(
) { ) {
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) { Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
Column( Column(
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp) modifier = Modifier.padding(20.dp),
verticalArrangement = Arrangement.spacedBy(16.dp),
) { ) {
// Title. // Title.
Text( Text(
"Import Model", "Import Model",
style = MaterialTheme.typography.titleLarge, style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(bottom = 8.dp) modifier = Modifier.padding(bottom = 8.dp),
) )
// No error. // No error.
@ -272,9 +309,7 @@ fun ModelImportingDialog(
val animatedProgress = remember { Animatable(0f) } val animatedProgress = remember { Animatable(0f) }
LinearProgressIndicator( LinearProgressIndicator(
progress = { animatedProgress.value }, progress = { animatedProgress.value },
modifier = Modifier modifier = Modifier.fillMaxWidth().padding(bottom = 8.dp),
.fillMaxWidth()
.padding(bottom = 8.dp),
) )
LaunchedEffect(progress) { LaunchedEffect(progress) {
animatedProgress.animateTo(progress, animationSpec = tween(150)) animatedProgress.animateTo(progress, animationSpec = tween(150))
@ -284,24 +319,23 @@ fun ModelImportingDialog(
// Has error. // Has error.
else { else {
Row( Row(
verticalAlignment = Alignment.Top, horizontalArrangement = Arrangement.spacedBy(6.dp) verticalAlignment = Alignment.Top,
horizontalArrangement = Arrangement.spacedBy(6.dp),
) { ) {
Icon( Icon(
Icons.Rounded.Error, contentDescription = "", tint = MaterialTheme.colorScheme.error Icons.Rounded.Error,
contentDescription = "",
tint = MaterialTheme.colorScheme.error,
) )
Text( Text(
error, error,
style = MaterialTheme.typography.labelSmall, style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.error, color = MaterialTheme.colorScheme.error,
modifier = Modifier.padding(top = 4.dp) modifier = Modifier.padding(top = 4.dp),
) )
} }
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.End) { Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.End) {
Button(onClick = { Button(onClick = { onDismiss() }) { Text("Close") }
onDismiss()
}) {
Text("Close")
}
} }
} }
} }
@ -376,17 +410,17 @@ private fun getFileSizeAndDisplayNameFromUri(context: Context, uri: Uri): Pair<L
var displayName = "" var displayName = ""
try { try {
contentResolver.query( contentResolver
uri, arrayOf(OpenableColumns.SIZE, OpenableColumns.DISPLAY_NAME), null, null, null .query(uri, arrayOf(OpenableColumns.SIZE, OpenableColumns.DISPLAY_NAME), null, null, null)
)?.use { cursor -> ?.use { cursor ->
if (cursor.moveToFirst()) { if (cursor.moveToFirst()) {
val sizeIndex = cursor.getColumnIndexOrThrow(OpenableColumns.SIZE) val sizeIndex = cursor.getColumnIndexOrThrow(OpenableColumns.SIZE)
fileSize = cursor.getLong(sizeIndex) fileSize = cursor.getLong(sizeIndex)
val nameIndex = cursor.getColumnIndexOrThrow(OpenableColumns.DISPLAY_NAME) val nameIndex = cursor.getColumnIndexOrThrow(OpenableColumns.DISPLAY_NAME)
displayName = cursor.getString(nameIndex) displayName = cursor.getString(nameIndex)
}
} }
}
} catch (e: Exception) { } catch (e: Exception) {
e.printStackTrace() e.printStackTrace()
return Pair(0L, "") return Pair(0L, "")

View file

@ -1,3 +1,19 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.home package com.google.ai.edge.gallery.ui.home
import android.util.Log import android.util.Log
@ -29,22 +45,17 @@ import androidx.lifecycle.LifecycleEventObserver
import androidx.lifecycle.LifecycleOwner import androidx.lifecycle.LifecycleOwner
import androidx.lifecycle.compose.LocalLifecycleOwner import androidx.lifecycle.compose.LocalLifecycleOwner
import com.google.ai.edge.gallery.BuildConfig import com.google.ai.edge.gallery.BuildConfig
import com.google.ai.edge.gallery.ui.common.getJsonResponse import com.google.ai.edge.gallery.common.getJsonResponse
import com.google.ai.edge.gallery.ui.modelmanager.ClickableLink import com.google.ai.edge.gallery.ui.common.ClickableLink
import kotlin.math.max
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext import kotlinx.coroutines.withContext
import kotlinx.serialization.Serializable
import kotlin.math.max
private const val TAG = "AGNewReleaseNotification" private const val TAG = "AGNewReleaseNotifi"
private const val REPO = "google-ai-edge/gallery" private const val REPO = "google-ai-edge/gallery"
@Serializable data class ReleaseInfo(val html_url: String, val tag_name: String)
data class ReleaseInfo(
val html_url: String,
val tag_name: String,
)
@Composable @Composable
fun NewReleaseNotification() { fun NewReleaseNotification() {
@ -84,35 +95,31 @@ fun NewReleaseNotification() {
lifecycleOwner.lifecycle.addObserver(observer) lifecycleOwner.lifecycle.addObserver(observer)
onDispose { onDispose { lifecycleOwner.lifecycle.removeObserver(observer) }
lifecycleOwner.lifecycle.removeObserver(observer)
}
} }
AnimatedVisibility( AnimatedVisibility(
visible = newReleaseVersion.isNotEmpty(), visible = newReleaseVersion.isNotEmpty(),
enter = fadeIn() + expandVertically() enter = fadeIn() + expandVertically(),
) { ) {
Row( Row(
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.SpaceBetween, horizontalArrangement = Arrangement.SpaceBetween,
modifier = Modifier modifier =
.padding(horizontal = 16.dp) Modifier.padding(horizontal = 16.dp)
.padding(bottom = 12.dp) .padding(bottom = 12.dp)
.clip( .clip(CircleShape)
CircleShape .background(MaterialTheme.colorScheme.tertiaryContainer)
) .padding(4.dp),
.background(MaterialTheme.colorScheme.tertiaryContainer)
.padding(4.dp)
) { ) {
Text( Text(
"New release $newReleaseVersion available", "New release $newReleaseVersion available",
style = MaterialTheme.typography.bodyMedium, style = MaterialTheme.typography.bodyMedium,
modifier = Modifier.padding(start = 12.dp) modifier = Modifier.padding(start = 12.dp),
) )
Row( Row(
modifier = Modifier.padding(end = 12.dp), modifier = Modifier.padding(end = 12.dp),
verticalAlignment = Alignment.CenterVertically verticalAlignment = Alignment.CenterVertically,
) { ) {
ClickableLink( ClickableLink(
url = newReleaseUrl, url = newReleaseUrl,

View file

@ -61,10 +61,8 @@ import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog import androidx.compose.ui.window.Dialog
import com.google.ai.edge.gallery.BuildConfig import com.google.ai.edge.gallery.BuildConfig
import com.google.ai.edge.gallery.proto.Theme
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.theme.THEME_AUTO
import com.google.ai.edge.gallery.ui.theme.THEME_DARK
import com.google.ai.edge.gallery.ui.theme.THEME_LIGHT
import com.google.ai.edge.gallery.ui.theme.ThemeSettings import com.google.ai.edge.gallery.ui.theme.ThemeSettings
import com.google.ai.edge.gallery.ui.theme.labelSmallNarrow import com.google.ai.edge.gallery.ui.theme.labelSmallNarrow
import java.time.Instant import java.time.Instant
@ -73,18 +71,19 @@ import java.time.format.DateTimeFormatter
import java.util.Locale import java.util.Locale
import kotlin.math.min import kotlin.math.min
private val THEME_OPTIONS = listOf(THEME_AUTO, THEME_LIGHT, THEME_DARK) private val THEME_OPTIONS = listOf(Theme.THEME_AUTO, Theme.THEME_LIGHT, Theme.THEME_DARK)
@Composable @Composable
fun SettingsDialog( fun SettingsDialog(
curThemeOverride: String, curThemeOverride: Theme,
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
onDismissed: () -> Unit, onDismissed: () -> Unit,
) { ) {
var selectedTheme by remember { mutableStateOf(curThemeOverride) } var selectedTheme by remember { mutableStateOf(curThemeOverride) }
var hfToken by remember { mutableStateOf(modelManagerViewModel.getTokenStatusAndData().data) } var hfToken by remember { mutableStateOf(modelManagerViewModel.getTokenStatusAndData().data) }
val dateFormatter = remember { val dateFormatter = remember {
DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss").withZone(ZoneId.systemDefault()) DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")
.withZone(ZoneId.systemDefault())
.withLocale(Locale.getDefault()) .withLocale(Locale.getDefault())
} }
var customHfToken by remember { mutableStateOf("") } var customHfToken by remember { mutableStateOf("") }
@ -95,72 +94,75 @@ fun SettingsDialog(
Dialog(onDismissRequest = onDismissed) { Dialog(onDismissRequest = onDismissed) {
val focusManager = LocalFocusManager.current val focusManager = LocalFocusManager.current
Card( Card(
modifier = Modifier modifier =
.fillMaxWidth() Modifier.fillMaxWidth().clickable(
.clickable( interactionSource = interactionSource,
interactionSource = interactionSource, indication = null // Disable the ripple effect indication = null, // Disable the ripple effect
) { ) {
focusManager.clearFocus() focusManager.clearFocus()
}, shape = RoundedCornerShape(16.dp) },
shape = RoundedCornerShape(16.dp),
) { ) {
Column( Column(
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp) modifier = Modifier.padding(20.dp),
verticalArrangement = Arrangement.spacedBy(16.dp),
) { ) {
// Dialog title and subtitle. // Dialog title and subtitle.
Column { Column {
Text( Text(
"Settings", "Settings",
style = MaterialTheme.typography.titleLarge, style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(bottom = 8.dp) modifier = Modifier.padding(bottom = 8.dp),
) )
// Subtitle. // Subtitle.
Text( Text(
"App version: ${BuildConfig.VERSION_NAME}", "App version: ${BuildConfig.VERSION_NAME}",
style = labelSmallNarrow, style = labelSmallNarrow,
color = MaterialTheme.colorScheme.onSurfaceVariant, color = MaterialTheme.colorScheme.onSurfaceVariant,
modifier = Modifier.offset(y = (-6).dp) modifier = Modifier.offset(y = (-6).dp),
) )
} }
Column( Column(
modifier = Modifier modifier = Modifier.verticalScroll(rememberScrollState()).weight(1f, fill = false),
.verticalScroll(rememberScrollState()) verticalArrangement = Arrangement.spacedBy(16.dp),
.weight(1f, fill = false),
verticalArrangement = Arrangement.spacedBy(16.dp)
) { ) {
// Theme switcher. // Theme switcher.
Column( Column(modifier = Modifier.fillMaxWidth()) {
modifier = Modifier.fillMaxWidth()
) {
Text( Text(
"Theme", "Theme",
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.Bold) style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.Bold),
) )
MultiChoiceSegmentedButtonRow { MultiChoiceSegmentedButtonRow {
THEME_OPTIONS.forEachIndexed { index, label -> THEME_OPTIONS.forEachIndexed { index, theme ->
SegmentedButton(shape = SegmentedButtonDefaults.itemShape( SegmentedButton(
index = index, count = THEME_OPTIONS.size shape =
), onCheckedChange = { SegmentedButtonDefaults.itemShape(index = index, count = THEME_OPTIONS.size),
selectedTheme = label onCheckedChange = {
selectedTheme = theme
// Update theme settings. // Update theme settings.
// This will update app's theme. // This will update app's theme.
ThemeSettings.themeOverride.value = label ThemeSettings.themeOverride.value = theme
// Save to data store. // Save to data store.
modelManagerViewModel.saveThemeOverride(label) modelManagerViewModel.saveThemeOverride(theme)
}, checked = label == selectedTheme, label = { Text(label) }) },
checked = theme == selectedTheme,
label = { Text(themeLabel(theme)) },
)
} }
} }
} }
// HF Token management. // HF Token management.
Column( Column(
modifier = Modifier.fillMaxWidth(), verticalArrangement = Arrangement.spacedBy(4.dp) modifier = Modifier.fillMaxWidth(),
verticalArrangement = Arrangement.spacedBy(4.dp),
) { ) {
Text( Text(
"HuggingFace access token", "HuggingFace access token",
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.Bold) style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.Bold),
) )
// Show the start of the token. // Show the start of the token.
val curHfToken = hfToken val curHfToken = hfToken
@ -168,23 +170,23 @@ fun SettingsDialog(
Text( Text(
curHfToken.accessToken.substring(0, min(16, curHfToken.accessToken.length)) + "...", curHfToken.accessToken.substring(0, min(16, curHfToken.accessToken.length)) + "...",
style = MaterialTheme.typography.bodyMedium, style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant color = MaterialTheme.colorScheme.onSurfaceVariant,
) )
Text( Text(
"Expired at: ${dateFormatter.format(Instant.ofEpochMilli(curHfToken.expiresAtMs))}", "Expired at: ${dateFormatter.format(Instant.ofEpochMilli(curHfToken.expiresAtMs))}",
style = MaterialTheme.typography.bodyMedium, style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant color = MaterialTheme.colorScheme.onSurfaceVariant,
) )
} else { } else {
Text( Text(
"Not available", "Not available",
style = MaterialTheme.typography.bodyMedium, style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant color = MaterialTheme.colorScheme.onSurfaceVariant,
) )
Text( Text(
"The token will be automatically retrieved when a gated model is downloaded", "The token will be automatically retrieved when a gated model is downloaded",
style = MaterialTheme.typography.bodySmall, style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant color = MaterialTheme.colorScheme.onSurfaceVariant,
) )
} }
Row(horizontalArrangement = Arrangement.spacedBy(4.dp)) { Row(horizontalArrangement = Arrangement.spacedBy(4.dp)) {
@ -192,46 +194,42 @@ fun SettingsDialog(
onClick = { onClick = {
modelManagerViewModel.clearAccessToken() modelManagerViewModel.clearAccessToken()
hfToken = null hfToken = null
}, enabled = curHfToken != null },
enabled = curHfToken != null,
) { ) {
Text("Clear") Text("Clear")
} }
BasicTextField( BasicTextField(
value = customHfToken, value = customHfToken,
singleLine = true, singleLine = true,
modifier = Modifier modifier =
.fillMaxWidth() Modifier.fillMaxWidth()
.padding(top = 4.dp) .padding(top = 4.dp)
.focusRequester(focusRequester) .focusRequester(focusRequester)
.onFocusChanged { .onFocusChanged { isFocused = it.isFocused },
isFocused = it.isFocused onValueChange = { customHfToken = it },
},
onValueChange = {
customHfToken = it
},
textStyle = TextStyle(color = MaterialTheme.colorScheme.onSurface), textStyle = TextStyle(color = MaterialTheme.colorScheme.onSurface),
cursorBrush = SolidColor(MaterialTheme.colorScheme.onSurface), cursorBrush = SolidColor(MaterialTheme.colorScheme.onSurface),
) { innerTextField -> ) { innerTextField ->
Box( Box(
modifier = Modifier modifier =
.border( Modifier.border(
width = if (isFocused) 2.dp else 1.dp, width = if (isFocused) 2.dp else 1.dp,
color = if (isFocused) MaterialTheme.colorScheme.primary else MaterialTheme.colorScheme.outline, color =
shape = CircleShape, if (isFocused) MaterialTheme.colorScheme.primary
) else MaterialTheme.colorScheme.outline,
.height(40.dp), contentAlignment = Alignment.CenterStart shape = CircleShape,
)
.height(40.dp),
contentAlignment = Alignment.CenterStart,
) { ) {
Row(verticalAlignment = Alignment.CenterVertically) { Row(verticalAlignment = Alignment.CenterVertically) {
Box( Box(modifier = Modifier.padding(start = 16.dp).weight(1f)) {
modifier = Modifier
.padding(start = 16.dp)
.weight(1f)
) {
if (customHfToken.isEmpty()) { if (customHfToken.isEmpty()) {
Text( Text(
"Enter token manually", "Enter token manually",
color = MaterialTheme.colorScheme.onSurfaceVariant, color = MaterialTheme.colorScheme.onSurfaceVariant,
style = MaterialTheme.typography.bodySmall style = MaterialTheme.typography.bodySmall,
) )
} }
innerTextField() innerTextField()
@ -246,7 +244,8 @@ fun SettingsDialog(
expiresAt = System.currentTimeMillis() + 1000L * 60 * 60 * 24 * 365 * 10, expiresAt = System.currentTimeMillis() + 1000L * 60 * 60 * 24 * 365 * 10,
) )
hfToken = modelManagerViewModel.getTokenStatusAndData().data hfToken = modelManagerViewModel.getTokenStatusAndData().data
}) { },
) {
Icon(Icons.Rounded.CheckCircle, contentDescription = "") Icon(Icons.Rounded.CheckCircle, contentDescription = "")
} }
} }
@ -257,24 +256,24 @@ fun SettingsDialog(
} }
} }
// Button row. // Button row.
Row( Row(
modifier = Modifier modifier = Modifier.fillMaxWidth().padding(top = 8.dp),
.fillMaxWidth()
.padding(top = 8.dp),
horizontalArrangement = Arrangement.End, horizontalArrangement = Arrangement.End,
) { ) {
// Close button // Close button
Button( Button(onClick = { onDismissed() }) { Text("Close") }
onClick = {
onDismissed()
},
) {
Text("Close")
}
} }
} }
} }
} }
} }
private fun themeLabel(theme: Theme): String {
return when (theme) {
Theme.THEME_AUTO -> "Auto"
Theme.THEME_LIGHT -> "Light"
Theme.THEME_DARK -> "Dark"
else -> "Unknown"
}
}

View file

@ -30,61 +30,64 @@ val Deployed_code: ImageVector
if (internal_Deployed_code != null) { if (internal_Deployed_code != null) {
return internal_Deployed_code!! return internal_Deployed_code!!
} }
internal_Deployed_code = ImageVector.Builder( internal_Deployed_code =
name = "Deployed_code", ImageVector.Builder(
defaultWidth = 24.dp, name = "Deployed_code",
defaultHeight = 24.dp, defaultWidth = 24.dp,
viewportWidth = 960f, defaultHeight = 24.dp,
viewportHeight = 960f viewportWidth = 960f,
).apply { viewportHeight = 960f,
path( )
fill = SolidColor(Color.Black), .apply {
fillAlpha = 1.0f, path(
stroke = null, fill = SolidColor(Color.Black),
strokeAlpha = 1.0f, fillAlpha = 1.0f,
strokeLineWidth = 1.0f, stroke = null,
strokeLineCap = StrokeCap.Butt, strokeAlpha = 1.0f,
strokeLineJoin = StrokeJoin.Miter, strokeLineWidth = 1.0f,
strokeLineMiter = 1.0f, strokeLineCap = StrokeCap.Butt,
pathFillType = PathFillType.NonZero strokeLineJoin = StrokeJoin.Miter,
) { strokeLineMiter = 1.0f,
moveTo(440f, 777f) pathFillType = PathFillType.NonZero,
verticalLineToRelative(-274f) ) {
lineTo(200f, 364f) moveTo(440f, 777f)
verticalLineToRelative(274f) verticalLineToRelative(-274f)
close() lineTo(200f, 364f)
moveToRelative(80f, 0f) verticalLineToRelative(274f)
lineToRelative(240f, -139f) close()
verticalLineToRelative(-274f) moveToRelative(80f, 0f)
lineTo(520f, 503f) lineToRelative(240f, -139f)
close() verticalLineToRelative(-274f)
moveToRelative(-40f, -343f) lineTo(520f, 503f)
lineToRelative(237f, -137f) close()
lineToRelative(-237f, -137f) moveToRelative(-40f, -343f)
lineToRelative(-237f, 137f) lineToRelative(237f, -137f)
close() lineToRelative(-237f, -137f)
moveTo(160f, 708f) lineToRelative(-237f, 137f)
quadToRelative(-19f, -11f, -29.5f, -29f) close()
reflectiveQuadTo(120f, 639f) moveTo(160f, 708f)
verticalLineToRelative(-318f) quadToRelative(-19f, -11f, -29.5f, -29f)
quadToRelative(0f, -22f, 10.5f, -40f) reflectiveQuadTo(120f, 639f)
reflectiveQuadToRelative(29.5f, -29f) verticalLineToRelative(-318f)
lineToRelative(280f, -161f) quadToRelative(0f, -22f, 10.5f, -40f)
quadToRelative(19f, -11f, 40f, -11f) reflectiveQuadToRelative(29.5f, -29f)
reflectiveQuadToRelative(40f, 11f) lineToRelative(280f, -161f)
lineToRelative(280f, 161f) quadToRelative(19f, -11f, 40f, -11f)
quadToRelative(19f, 11f, 29.5f, 29f) reflectiveQuadToRelative(40f, 11f)
reflectiveQuadToRelative(10.5f, 40f) lineToRelative(280f, 161f)
verticalLineToRelative(318f) quadToRelative(19f, 11f, 29.5f, 29f)
quadToRelative(0f, 22f, -10.5f, 40f) reflectiveQuadToRelative(10.5f, 40f)
reflectiveQuadTo(800f, 708f) verticalLineToRelative(318f)
lineTo(520f, 869f) quadToRelative(0f, 22f, -10.5f, 40f)
quadToRelative(-19f, 11f, -40f, 11f) reflectiveQuadTo(800f, 708f)
reflectiveQuadToRelative(-40f, -11f) lineTo(520f, 869f)
close() quadToRelative(-19f, 11f, -40f, 11f)
moveToRelative(320f, -228f) reflectiveQuadToRelative(-40f, -11f)
} close()
}.build() moveToRelative(320f, -228f)
}
}
.build()
return internal_Deployed_code!! return internal_Deployed_code!!
} }

View file

@ -0,0 +1,78 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.icon
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.SolidColor
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.graphics.vector.path
import androidx.compose.ui.unit.dp
val Forum: ImageVector
get() {
if (_Forum != null) return _Forum!!
_Forum =
ImageVector.Builder(
name = "Forum",
defaultWidth = 24.dp,
defaultHeight = 24.dp,
viewportWidth = 960f,
viewportHeight = 960f,
)
.apply {
path(fill = SolidColor(Color(0xFF000000))) {
moveTo(280f, 720f)
quadToRelative(-17f, 0f, -28.5f, -11.5f)
reflectiveQuadTo(240f, 680f)
verticalLineToRelative(-80f)
horizontalLineToRelative(520f)
verticalLineToRelative(-360f)
horizontalLineToRelative(80f)
quadToRelative(17f, 0f, 28.5f, 11.5f)
reflectiveQuadTo(880f, 280f)
verticalLineToRelative(600f)
lineTo(720f, 720f)
close()
moveTo(80f, 680f)
verticalLineToRelative(-560f)
quadToRelative(0f, -17f, 11.5f, -28.5f)
reflectiveQuadTo(120f, 80f)
horizontalLineToRelative(520f)
quadToRelative(17f, 0f, 28.5f, 11.5f)
reflectiveQuadTo(680f, 120f)
verticalLineToRelative(360f)
quadToRelative(0f, 17f, -11.5f, 28.5f)
reflectiveQuadTo(640f, 520f)
horizontalLineTo(240f)
close()
moveToRelative(520f, -240f)
verticalLineToRelative(-280f)
horizontalLineTo(160f)
verticalLineToRelative(280f)
close()
moveToRelative(-440f, 0f)
verticalLineToRelative(-280f)
close()
}
}
.build()
return _Forum!!
}
private var _Forum: ImageVector? = null

View file

@ -0,0 +1,73 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.icon
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.SolidColor
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.graphics.vector.path
import androidx.compose.ui.unit.dp
val Mms: ImageVector
get() {
if (_Mms != null) return _Mms!!
_Mms =
ImageVector.Builder(
name = "Mms",
defaultWidth = 24.dp,
defaultHeight = 24.dp,
viewportWidth = 960f,
viewportHeight = 960f,
)
.apply {
path(fill = SolidColor(Color(0xFF000000))) {
moveTo(240f, 560f)
horizontalLineToRelative(480f)
lineTo(570f, 360f)
lineTo(450f, 520f)
lineToRelative(-90f, -120f)
close()
moveTo(80f, 880f)
verticalLineToRelative(-720f)
quadToRelative(0f, -33f, 23.5f, -56.5f)
reflectiveQuadTo(160f, 80f)
horizontalLineToRelative(640f)
quadToRelative(33f, 0f, 56.5f, 23.5f)
reflectiveQuadTo(880f, 160f)
verticalLineToRelative(480f)
quadToRelative(0f, 33f, -23.5f, 56.5f)
reflectiveQuadTo(800f, 720f)
horizontalLineTo(240f)
close()
moveToRelative(126f, -240f)
horizontalLineToRelative(594f)
verticalLineToRelative(-480f)
horizontalLineTo(160f)
verticalLineToRelative(525f)
close()
moveToRelative(-46f, 0f)
verticalLineToRelative(-480f)
close()
}
}
.build()
return _Mms!!
}
private var _Mms: ImageVector? = null

View file

@ -0,0 +1,87 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.icon
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.SolidColor
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.graphics.vector.path
import androidx.compose.ui.unit.dp
val Widgets: ImageVector
get() {
if (_Widgets != null) return _Widgets!!
_Widgets =
ImageVector.Builder(
name = "Widgets",
defaultWidth = 24.dp,
defaultHeight = 24.dp,
viewportWidth = 960f,
viewportHeight = 960f,
)
.apply {
path(fill = SolidColor(Color(0xFF000000))) {
moveTo(666f, 520f)
lineTo(440f, 294f)
lineToRelative(226f, -226f)
lineToRelative(226f, 226f)
close()
moveToRelative(-546f, -80f)
verticalLineToRelative(-320f)
horizontalLineToRelative(320f)
verticalLineToRelative(320f)
close()
moveToRelative(400f, 400f)
verticalLineToRelative(-320f)
horizontalLineToRelative(320f)
verticalLineToRelative(320f)
close()
moveToRelative(-400f, 0f)
verticalLineToRelative(-320f)
horizontalLineToRelative(320f)
verticalLineToRelative(320f)
close()
moveToRelative(80f, -480f)
horizontalLineToRelative(160f)
verticalLineToRelative(-160f)
horizontalLineTo(200f)
close()
moveToRelative(467f, 48f)
lineToRelative(113f, -113f)
lineToRelative(-113f, -113f)
lineToRelative(-113f, 113f)
close()
moveToRelative(-67f, 352f)
horizontalLineToRelative(160f)
verticalLineToRelative(-160f)
horizontalLineTo(600f)
close()
moveToRelative(-400f, 0f)
horizontalLineToRelative(160f)
verticalLineToRelative(-160f)
horizontalLineTo(200f)
close()
moveToRelative(400f, -160f)
}
}
.build()
return _Widgets!!
}
private var _Widgets: ImageVector? = null

View file

@ -1,154 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.imageclassification
import android.content.Context
import android.graphics.Bitmap
import android.util.Log
import androidx.compose.ui.graphics.Color
import com.google.android.gms.tflite.client.TfLiteInitializationOptions
import com.google.android.gms.tflite.gpu.support.TfLiteGpu
import com.google.android.gms.tflite.java.TfLite
import com.google.ai.edge.gallery.ui.common.chat.Classification
import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.ui.common.LatencyProvider
import org.tensorflow.lite.DataType
import org.tensorflow.lite.InterpreterApi
import org.tensorflow.lite.gpu.GpuDelegateFactory
import org.tensorflow.lite.support.common.FileUtil
import org.tensorflow.lite.support.common.ops.NormalizeOp
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.ResizeOp
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.io.File
import java.io.FileInputStream
private const val TAG = "AGImageClassificationModelHelper"
class ImageClassificationInferenceResult(
val categories: List<Classification>, override val latencyMs: Float
) : LatencyProvider
//TODO: handle error.
object ImageClassificationModelHelper {
fun initialize(context: Context, model: Model, onDone: (String) -> Unit) {
val useGpu = model.getBooleanConfigValue(key = ConfigKey.USE_GPU)
TfLiteGpu.isGpuDelegateAvailable(context).continueWith { gpuTask ->
val optionsBuilder = TfLiteInitializationOptions.builder()
if (gpuTask.result) {
optionsBuilder.setEnableGpuDelegateSupport(true)
}
val task = TfLite.initialize(
context,
optionsBuilder.build()
)
task.addOnSuccessListener {
val interpreterOption =
InterpreterApi.Options().setRuntime(InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY)
if (useGpu) {
interpreterOption.addDelegateFactory(GpuDelegateFactory())
}
val interpreter = InterpreterApi.create(
File(model.getPath(context = context)), interpreterOption
)
model.instance = interpreter
onDone("")
}
}
}
fun cleanUp(model: Model) {
if (model.instance == null) {
return
}
val instance = model.instance as InterpreterApi
instance.close()
}
fun runInference(
context: Context,
model: Model,
input: Bitmap,
primaryColor: Color,
): ImageClassificationInferenceResult {
val instance = model.instance
if (instance == null) {
Log.d(
TAG, "Model '${model.name}' has not been initialized"
)
return ImageClassificationInferenceResult(categories = listOf(), latencyMs = 0f)
}
// Pre-process image.
val start = System.currentTimeMillis()
val interpreter = model.instance as InterpreterApi
val inputShape = interpreter.getInputTensor(0).shape()
val imageProcessor = ImageProcessor.Builder()
.add(ResizeOp(inputShape[1], inputShape[2], ResizeOp.ResizeMethod.BILINEAR))
.add(NormalizeOp(127.5f, 127.5f)) // Normalize pixel values
.build()
val tensorImage = TensorImage(DataType.FLOAT32)
tensorImage.load(input)
val inputTensorBuffer = imageProcessor.process(tensorImage).tensorBuffer
// Output buffer
val outputBuffer =
TensorBuffer.createFixedSize(interpreter.getOutputTensor(0).shape(), DataType.FLOAT32)
// Run inference
interpreter.run(inputTensorBuffer.buffer, outputBuffer.buffer)
// Post-process result.
val output = outputBuffer.floatArray
val labelsFilePath = model.getPath(
context = context,
fileName = model.getExtraDataFile(name = "labels")?.downloadFileName ?: "_"
)
val labelsFileInputStream = FileInputStream(File(labelsFilePath))
val labels = FileUtil.loadLabels(labelsFileInputStream)
labelsFileInputStream.close()
val topIndices =
getTopKMaxIndices(output = output, k = model.getIntConfigValue(ConfigKey.MAX_RESULT_COUNT))
val categories: List<Classification> =
topIndices.map { i ->
Classification(
label = labels[i],
score = output[i],
color = primaryColor
)
}
return ImageClassificationInferenceResult(
categories = categories,
latencyMs = (System.currentTimeMillis() - start).toFloat()
)
}
private fun getTopKMaxIndices(output: FloatArray, k: Int): List<Int> {
if (k <= 0 || output.isEmpty()) {
return emptyList()
}
val indexedValues = output.withIndex().toList()
val sortedIndexedValues =
indexedValues.sortedByDescending { it.value }
return sortedIndexedValues.take(k).map { it.index } // Take the top k and extract indices
}
}

View file

@ -1,98 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.imageclassification
import androidx.compose.material3.MaterialTheme
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.ai.edge.gallery.ui.ViewModelProvider
import com.google.ai.edge.gallery.ui.common.chat.ChatInputType
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage
import com.google.ai.edge.gallery.ui.common.chat.ChatView
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import kotlinx.serialization.Serializable
/** Navigation destination data */
object ImageClassificationDestination {
@Serializable
val route = "ImageClassificationRoute"
}
@Composable
fun ImageClassificationScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
viewModel: ImageClassificationViewModel = viewModel(
factory = ViewModelProvider.Factory
),
) {
val context = LocalContext.current
val primaryColor = MaterialTheme.colorScheme.primary
ChatView(
task = viewModel.task,
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
onSendMessage = { model, messages ->
val message = messages[0]
viewModel.addMessage(
model = model,
message = message,
)
if (message is ChatMessageImage) {
viewModel.generateResponse(
context = context,
model = model,
input = message.bitmap,
primaryColor = primaryColor,
)
}
},
onStreamImageMessage = { model, message ->
viewModel.generateStreamingResponse(
context = context,
model = model,
input = message.bitmap,
primaryColor = primaryColor,
)
},
onRunAgainClicked = { model, message ->
viewModel.runAgain(
context = context,
model = model,
message = message,
primaryColor = primaryColor,
)
},
onBenchmarkClicked = { model, message, warmUpIterations, benchmarkIterations ->
viewModel.benchmark(
context = context,
model = model,
message = message,
warmupCount = warmUpIterations,
iterations = benchmarkIterations,
primaryColor = primaryColor,
)
},
navigateUp = navigateUp,
modifier = modifier,
chatInputType = ChatInputType.IMAGE,
)
}

View file

@ -1,165 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.imageclassification
import android.content.Context
import android.graphics.Bitmap
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.unit.dp
import androidx.lifecycle.viewModelScope
import com.google.ai.edge.gallery.ui.common.chat.ChatMessage
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageClassification
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageType
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.TASK_IMAGE_CLASSIFICATION
import com.google.ai.edge.gallery.ui.common.chat.ChatViewModel
import com.google.ai.edge.gallery.ui.common.runBasicBenchmark
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.sync.Mutex
class ImageClassificationViewModel : ChatViewModel(task = TASK_IMAGE_CLASSIFICATION) {
private val mutex = Mutex()
fun generateResponse(context: Context, model: Model, input: Bitmap, primaryColor: Color) {
viewModelScope.launch(Dispatchers.Default) {
// Wait for model to be initialized.
while (model.instance == null) {
delay(100)
}
val result = ImageClassificationModelHelper.runInference(
context = context, model = model, input = input, primaryColor = primaryColor
)
super.addMessage(
model = model,
message = ChatMessageClassification(
classifications = result.categories,
latencyMs = result.latencyMs,
maxBarWidth = 300.dp,
),
)
}
}
fun generateStreamingResponse(
context: Context,
model: Model,
input: Bitmap,
primaryColor: Color
) {
viewModelScope.launch(Dispatchers.Default) {
// Wait for model to be initialized.
while (model.instance == null) {
delay(100)
}
if (mutex.tryLock()) {
try {
val result = ImageClassificationModelHelper.runInference(
context = context, model = model, input = input, primaryColor = primaryColor
)
updateStreamingMessage(
model = model,
message = ChatMessageClassification(
classifications = result.categories,
latencyMs = result.latencyMs
)
)
} finally {
mutex.unlock()
}
} else {
// skip call if the previous call has not been finished (mutex is still locked).
}
}
}
fun benchmark(
context: Context,
model: Model,
message: ChatMessage,
warmupCount: Int,
iterations: Int,
primaryColor: Color
) {
viewModelScope.launch(Dispatchers.Default) {
// Wait for model to be initialized.
while (model.instance == null) {
delay(100)
}
if (message is ChatMessageImage) {
setInProgress(true)
runBasicBenchmark(
model = model,
warmupCount = warmupCount,
iterations = iterations,
chatViewModel = this@ImageClassificationViewModel,
inferenceFn = {
ImageClassificationModelHelper.runInference(
context = context,
model = model,
input = message.bitmap,
primaryColor = primaryColor
)
},
chatMessageType = ChatMessageType.BENCHMARK_RESULT,
)
setInProgress(false)
}
}
}
fun runAgain(context: Context, model: Model, message: ChatMessage, primaryColor: Color) {
viewModelScope.launch(Dispatchers.Default) {
// Wait for model to be initialized.
while (model.instance == null) {
delay(100)
}
if (message is ChatMessageImage) {
// Clone the clicked message and add it.
addMessage(model = model, message = message.clone())
// Run inference.
val result =
ImageClassificationModelHelper.runInference(
context = context,
model = model,
input = message.bitmap,
primaryColor = primaryColor
)
// Add response message.
val newMessage = generateClassificationMessage(result = result)
addMessage(model = model, message = newMessage)
}
}
}
private fun generateClassificationMessage(result: ImageClassificationInferenceResult): ChatMessageClassification {
return ChatMessageClassification(
classifications = result.categories,
latencyMs = result.latencyMs,
maxBarWidth = 300.dp,
)
}
}

View file

@ -1,83 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.imagegeneration
import android.content.Context
import android.graphics.Bitmap
import android.util.Log
import com.google.mediapipe.framework.image.BitmapExtractor
import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator
import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.ui.common.LatencyProvider
import com.google.ai.edge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage
import kotlin.random.Random
private const val TAG = "AGImageGenerationModelHelper"
class ImageGenerationInferenceResult(
val bitmap: Bitmap, override val latencyMs: Float
) : LatencyProvider
object ImageGenerationModelHelper {
fun initialize(context: Context, model: Model, onDone: (String) -> Unit) {
try {
val options = ImageGenerator.ImageGeneratorOptions.builder()
.setImageGeneratorModelDirectory(model.getPath(context = context))
.build()
model.instance = ImageGenerator.createFromOptions(context, options)
} catch (e: Exception) {
onDone(cleanUpMediapipeTaskErrorMessage(e.message ?: "Unknown error"))
return
}
onDone("")
}
fun cleanUp(model: Model) {
if (model.instance == null) {
return
}
val instance = model.instance as ImageGenerator
try {
instance.close()
} catch (e: Exception) {
// ignore
}
model.instance = null
Log.d(TAG, "Clean up done.")
}
fun runInference(
model: Model,
input: String,
onStep: (curIteration: Int, totalIterations: Int, ImageGenerationInferenceResult, isLast: Boolean) -> Unit
) {
val start = System.currentTimeMillis()
val instance = model.instance as ImageGenerator
val iterations = model.getIntConfigValue(ConfigKey.ITERATIONS)
instance.setInputs(input, iterations, Random.nextInt())
for (i in 0..<iterations) {
val result = ImageGenerationInferenceResult(
bitmap = BitmapExtractor.extract(
instance.execute(true)?.generatedImage(),
),
latencyMs = (System.currentTimeMillis() - start).toFloat(),
)
onStep(i, iterations, result, i == iterations - 1)
}
}
}

View file

@ -1,65 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.imagegeneration
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.ai.edge.gallery.ui.ViewModelProvider
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
import com.google.ai.edge.gallery.ui.common.chat.ChatView
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import kotlinx.serialization.Serializable
/** Navigation destination data */
object ImageGenerationDestination {
@Serializable
val route = "ImageGenerationRoute"
}
@Composable
fun ImageGenerationScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
viewModel: ImageGenerationViewModel = viewModel(
factory = ViewModelProvider.Factory
),
) {
ChatView(
task = viewModel.task,
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
onSendMessage = { model, messages ->
val message = messages[0]
viewModel.addMessage(
model = model,
message = message,
)
if (message is ChatMessageText) {
viewModel.generateResponse(
model = model,
input = message.content,
)
}
},
onRunAgainClicked = { _, _ -> },
onBenchmarkClicked = { _, _, _, _ -> },
navigateUp = navigateUp,
modifier = modifier,
)
}

View file

@ -1,87 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.imagegeneration
import android.graphics.Bitmap
import androidx.compose.ui.graphics.ImageBitmap
import androidx.compose.ui.graphics.asImageBitmap
import androidx.lifecycle.viewModelScope
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.TASK_IMAGE_GENERATION
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImageWithHistory
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageLoading
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageType
import com.google.ai.edge.gallery.ui.common.chat.ChatSide
import com.google.ai.edge.gallery.ui.common.chat.ChatViewModel
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
class ImageGenerationViewModel : ChatViewModel(task = TASK_IMAGE_GENERATION) {
fun generateResponse(model: Model, input: String) {
viewModelScope.launch(Dispatchers.Default) {
setInProgress(true)
// Loading.
addMessage(
model = model,
message = ChatMessageLoading(),
)
// Wait for model to be initialized.
while (model.instance == null) {
delay(100)
}
// Run inference.
val bitmaps: MutableList<Bitmap> = mutableListOf()
val imageBitmaps: MutableList<ImageBitmap> = mutableListOf()
ImageGenerationModelHelper.runInference(
model = model, input = input
) { step, totalIterations, result, isLast ->
bitmaps.add(result.bitmap)
imageBitmaps.add(result.bitmap.asImageBitmap())
val message = ChatMessageImageWithHistory(
bitmaps = bitmaps,
imageBitMaps = imageBitmaps,
totalIterations = totalIterations,
side = ChatSide.AGENT,
latencyMs = result.latencyMs,
curIteration = step,
)
if (step == 0) {
removeLastMessage(model = model)
super.addMessage(
model = model,
message = message,
)
} else {
super.replaceLastMessage(
model = model,
message = message,
type = ChatMessageType.IMAGE_WITH_HISTORY
)
}
if (isLast) {
setInProgress(false)
}
}
}
}
}

View file

@ -1,72 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.llmchat
import com.google.ai.edge.gallery.data.Accelerator
import com.google.ai.edge.gallery.data.Config
import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.LabelConfig
import com.google.ai.edge.gallery.data.NumberSliderConfig
import com.google.ai.edge.gallery.data.SegmentedButtonConfig
import com.google.ai.edge.gallery.data.ValueType
const val DEFAULT_MAX_TOKEN = 1024
const val DEFAULT_TOPK = 40
const val DEFAULT_TOPP = 0.9f
const val DEFAULT_TEMPERATURE = 1.0f
val DEFAULT_ACCELERATORS = listOf(Accelerator.GPU)
fun createLlmChatConfigs(
defaultMaxToken: Int = DEFAULT_MAX_TOKEN,
defaultTopK: Int = DEFAULT_TOPK,
defaultTopP: Float = DEFAULT_TOPP,
defaultTemperature: Float = DEFAULT_TEMPERATURE,
accelerators: List<Accelerator> = DEFAULT_ACCELERATORS,
): List<Config> {
return listOf(
LabelConfig(
key = ConfigKey.MAX_TOKENS,
defaultValue = "$defaultMaxToken",
),
NumberSliderConfig(
key = ConfigKey.TOPK,
sliderMin = 5f,
sliderMax = 40f,
defaultValue = defaultTopK.toFloat(),
valueType = ValueType.INT
),
NumberSliderConfig(
key = ConfigKey.TOPP,
sliderMin = 0.0f,
sliderMax = 1.0f,
defaultValue = defaultTopP,
valueType = ValueType.FLOAT
),
NumberSliderConfig(
key = ConfigKey.TEMPERATURE,
sliderMin = 0.0f,
sliderMax = 2.0f,
defaultValue = defaultTemperature,
valueType = ValueType.FLOAT
),
SegmentedButtonConfig(
key = ConfigKey.ACCELERATOR,
defaultValue = accelerators[0].label,
options = accelerators.map { it.label }
)
)
}

View file

@ -19,10 +19,15 @@ package com.google.ai.edge.gallery.ui.llmchat
import android.content.Context import android.content.Context
import android.graphics.Bitmap import android.graphics.Bitmap
import android.util.Log import android.util.Log
import com.google.ai.edge.gallery.common.cleanUpMediapipeTaskErrorMessage
import com.google.ai.edge.gallery.data.Accelerator import com.google.ai.edge.gallery.data.Accelerator
import com.google.ai.edge.gallery.data.ConfigKey import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.DEFAULT_MAX_TOKEN
import com.google.ai.edge.gallery.data.DEFAULT_TEMPERATURE
import com.google.ai.edge.gallery.data.DEFAULT_TOPK
import com.google.ai.edge.gallery.data.DEFAULT_TOPP
import com.google.ai.edge.gallery.data.MAX_IMAGE_COUNT
import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage
import com.google.mediapipe.framework.image.BitmapImageBuilder import com.google.mediapipe.framework.image.BitmapImageBuilder
import com.google.mediapipe.tasks.genai.llminference.GraphOptions import com.google.mediapipe.tasks.genai.llminference.GraphOptions
import com.google.mediapipe.tasks.genai.llminference.LlmInference import com.google.mediapipe.tasks.genai.llminference.LlmInference
@ -31,6 +36,7 @@ import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
private const val TAG = "AGLlmChatModelHelper" private const val TAG = "AGLlmChatModelHelper"
typealias ResultListener = (partialResult: String, done: Boolean) -> Unit typealias ResultListener = (partialResult: String, done: Boolean) -> Unit
typealias CleanUpListener = () -> Unit typealias CleanUpListener = () -> Unit
data class LlmModelInstance(val engine: LlmInference, var session: LlmInferenceSession) data class LlmModelInstance(val engine: LlmInference, var session: LlmInferenceSession)
@ -39,9 +45,7 @@ object LlmChatModelHelper {
// Indexed by model name. // Indexed by model name.
private val cleanUpListeners: MutableMap<String, CleanUpListener> = mutableMapOf() private val cleanUpListeners: MutableMap<String, CleanUpListener> = mutableMapOf()
fun initialize( fun initialize(context: Context, model: Model, onDone: (String) -> Unit) {
context: Context, model: Model, onDone: (String) -> Unit
) {
// Prepare options. // Prepare options.
val maxTokens = val maxTokens =
model.getIntConfigValue(key = ConfigKey.MAX_TOKENS, defaultValue = DEFAULT_MAX_TOKEN) model.getIntConfigValue(key = ConfigKey.MAX_TOKENS, defaultValue = DEFAULT_MAX_TOKEN)
@ -52,29 +56,36 @@ object LlmChatModelHelper {
val accelerator = val accelerator =
model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = Accelerator.GPU.label) model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = Accelerator.GPU.label)
Log.d(TAG, "Initializing...") Log.d(TAG, "Initializing...")
val preferredBackend = when (accelerator) { val preferredBackend =
Accelerator.CPU.label -> LlmInference.Backend.CPU when (accelerator) {
Accelerator.GPU.label -> LlmInference.Backend.GPU Accelerator.CPU.label -> LlmInference.Backend.CPU
else -> LlmInference.Backend.GPU Accelerator.GPU.label -> LlmInference.Backend.GPU
} else -> LlmInference.Backend.GPU
}
val options = val options =
LlmInference.LlmInferenceOptions.builder().setModelPath(model.getPath(context = context)) LlmInference.LlmInferenceOptions.builder()
.setMaxTokens(maxTokens).setPreferredBackend(preferredBackend) .setModelPath(model.getPath(context = context))
.setMaxNumImages(if (model.llmSupportImage) 1 else 0) .setMaxTokens(maxTokens)
.setPreferredBackend(preferredBackend)
.setMaxNumImages(if (model.llmSupportImage) MAX_IMAGE_COUNT else 0)
.build() .build()
// Create an instance of the LLM Inference task and session. // Create an instance of the LLM Inference task and session.
try { try {
val llmInference = LlmInference.createFromOptions(context, options) val llmInference = LlmInference.createFromOptions(context, options)
val session = LlmInferenceSession.createFromOptions( val session =
llmInference, LlmInferenceSession.createFromOptions(
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP) llmInference,
.setTemperature(temperature) LlmInferenceSession.LlmInferenceSessionOptions.builder()
.setGraphOptions( .setTopK(topK)
GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build() .setTopP(topP)
).build() .setTemperature(temperature)
) .setGraphOptions(
GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
)
.build(),
)
model.instance = LlmModelInstance(engine = llmInference, session = session) model.instance = LlmModelInstance(engine = llmInference, session = session)
} catch (e: Exception) { } catch (e: Exception) {
onDone(cleanUpMediapipeTaskErrorMessage(e.message ?: "Unknown error")) onDone(cleanUpMediapipeTaskErrorMessage(e.message ?: "Unknown error"))
@ -96,14 +107,18 @@ object LlmChatModelHelper {
val topP = model.getFloatConfigValue(key = ConfigKey.TOPP, defaultValue = DEFAULT_TOPP) val topP = model.getFloatConfigValue(key = ConfigKey.TOPP, defaultValue = DEFAULT_TOPP)
val temperature = val temperature =
model.getFloatConfigValue(key = ConfigKey.TEMPERATURE, defaultValue = DEFAULT_TEMPERATURE) model.getFloatConfigValue(key = ConfigKey.TEMPERATURE, defaultValue = DEFAULT_TEMPERATURE)
val newSession = LlmInferenceSession.createFromOptions( val newSession =
inference, LlmInferenceSession.createFromOptions(
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP) inference,
.setTemperature(temperature) LlmInferenceSession.LlmInferenceSessionOptions.builder()
.setGraphOptions( .setTopK(topK)
GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build() .setTopP(topP)
).build() .setTemperature(temperature)
) .setGraphOptions(
GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
)
.build(),
)
instance.session = newSession instance.session = newSession
Log.d(TAG, "Resetting done") Log.d(TAG, "Resetting done")
} catch (e: Exception) { } catch (e: Exception) {
@ -117,12 +132,19 @@ object LlmChatModelHelper {
} }
val instance = model.instance as LlmModelInstance val instance = model.instance as LlmModelInstance
try {
instance.session.close()
} catch (e: Exception) {
Log.e(TAG, "Failed to close the LLM Inference session: ${e.message}")
}
try { try {
// This will also close the session. Do not call session.close manually.
instance.engine.close() instance.engine.close()
} catch (e: Exception) { } catch (e: Exception) {
// ignore Log.e(TAG, "Failed to close the LLM Inference engine: ${e.message}")
} }
val onCleanUp = cleanUpListeners.remove(model.name) val onCleanUp = cleanUpListeners.remove(model.name)
if (onCleanUp != null) { if (onCleanUp != null) {
onCleanUp() onCleanUp()
@ -136,7 +158,7 @@ object LlmChatModelHelper {
input: String, input: String,
resultListener: ResultListener, resultListener: ResultListener,
cleanUpListener: CleanUpListener, cleanUpListener: CleanUpListener,
image: Bitmap? = null, images: List<Bitmap> = listOf(),
) { ) {
val instance = model.instance as LlmModelInstance val instance = model.instance as LlmModelInstance
@ -151,9 +173,9 @@ object LlmChatModelHelper {
// image. // image.
val session = instance.session val session = instance.session
session.addQueryChunk(input) session.addQueryChunk(input)
if (image != null) { for (image in images) {
session.addImage(BitmapImageBuilder(image).build()) session.addImage(BitmapImageBuilder(image).build())
} }
session.generateResponseAsync(resultListener) val unused = session.generateResponseAsync(resultListener)
} }
} }

View file

@ -26,16 +26,13 @@ import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
import com.google.ai.edge.gallery.ui.common.chat.ChatView import com.google.ai.edge.gallery.ui.common.chat.ChatView
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import kotlinx.serialization.Serializable
/** Navigation destination data */ /** Navigation destination data */
object LlmChatDestination { object LlmChatDestination {
@Serializable
val route = "LlmChatRoute" val route = "LlmChatRoute"
} }
object LlmAskImageDestination { object LlmAskImageDestination {
@Serializable
val route = "LlmAskImageRoute" val route = "LlmAskImageRoute"
} }
@ -44,9 +41,7 @@ fun LlmChatScreen(
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit, navigateUp: () -> Unit,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
viewModel: LlmChatViewModel = viewModel( viewModel: LlmChatViewModel = viewModel(factory = ViewModelProvider.Factory),
factory = ViewModelProvider.Factory
),
) { ) {
ChatViewWrapper( ChatViewWrapper(
viewModel = viewModel, viewModel = viewModel,
@ -61,9 +56,7 @@ fun LlmAskImageScreen(
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit, navigateUp: () -> Unit,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
viewModel: LlmAskImageViewModel = viewModel( viewModel: LlmAskImageViewModel = viewModel(factory = ViewModelProvider.Factory),
factory = ViewModelProvider.Factory
),
) { ) {
ChatViewWrapper( ChatViewWrapper(
viewModel = viewModel, viewModel = viewModel,
@ -78,7 +71,7 @@ fun ChatViewWrapper(
viewModel: LlmChatViewModel, viewModel: LlmChatViewModel,
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit, navigateUp: () -> Unit,
modifier: Modifier = Modifier modifier: Modifier = Modifier,
) { ) {
val context = LocalContext.current val context = LocalContext.current
@ -88,57 +81,58 @@ fun ChatViewWrapper(
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,
onSendMessage = { model, messages -> onSendMessage = { model, messages ->
for (message in messages) { for (message in messages) {
viewModel.addMessage( viewModel.addMessage(model = model, message = message)
model = model,
message = message,
)
} }
var text = "" var text = ""
var image: Bitmap? = null val images: MutableList<Bitmap> = mutableListOf()
var chatMessageText: ChatMessageText? = null var chatMessageText: ChatMessageText? = null
for (message in messages) { for (message in messages) {
if (message is ChatMessageText) { if (message is ChatMessageText) {
chatMessageText = message chatMessageText = message
text = message.content text = message.content
} else if (message is ChatMessageImage) { } else if (message is ChatMessageImage) {
image = message.bitmap images.add(message.bitmap)
} }
} }
if (text.isNotEmpty() && chatMessageText != null) { if (text.isNotEmpty() && chatMessageText != null) {
modelManagerViewModel.addTextInputHistory(text) modelManagerViewModel.addTextInputHistory(text)
viewModel.generateResponse(model = model, input = text, image = image, onError = { viewModel.generateResponse(
viewModel.handleError( model = model,
context = context, input = text,
model = model, images = images,
modelManagerViewModel = modelManagerViewModel, onError = {
triggeredMessage = chatMessageText, viewModel.handleError(
) context = context,
}) model = model,
modelManagerViewModel = modelManagerViewModel,
triggeredMessage = chatMessageText,
)
},
)
} }
}, },
onRunAgainClicked = { model, message -> onRunAgainClicked = { model, message ->
if (message is ChatMessageText) { if (message is ChatMessageText) {
viewModel.runAgain(model = model, message = message, onError = { viewModel.runAgain(
viewModel.handleError( model = model,
context = context, message = message,
model = model, onError = {
modelManagerViewModel = modelManagerViewModel, viewModel.handleError(
triggeredMessage = message, context = context,
) model = model,
}) modelManagerViewModel = modelManagerViewModel,
triggeredMessage = message,
)
},
)
} }
}, },
onBenchmarkClicked = { _, _, _, _ -> onBenchmarkClicked = { _, _, _, _ -> },
}, onResetSessionClicked = { model -> viewModel.resetSession(model = model) },
onResetSessionClicked = { model ->
viewModel.resetSession(model = model)
},
showStopButtonInInputWhenInProgress = true, showStopButtonInInputWhenInProgress = true,
onStopButtonClicked = { model -> onStopButtonClicked = { model -> viewModel.stopResponse(model = model) },
viewModel.stopResponse(model = model)
},
navigateUp = navigateUp, navigateUp = navigateUp,
modifier = modifier, modifier = modifier,
) )
} }

View file

@ -22,8 +22,8 @@ import android.util.Log
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import com.google.ai.edge.gallery.data.ConfigKey import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
import com.google.ai.edge.gallery.data.Task import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageLoading import com.google.ai.edge.gallery.ui.common.chat.ChatMessageLoading
@ -39,25 +39,28 @@ import kotlinx.coroutines.delay
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
private const val TAG = "AGLlmChatViewModel" private const val TAG = "AGLlmChatViewModel"
private val STATS = listOf( private val STATS =
Stat(id = "time_to_first_token", label = "1st token", unit = "sec"), listOf(
Stat(id = "prefill_speed", label = "Prefill speed", unit = "tokens/s"), Stat(id = "time_to_first_token", label = "1st token", unit = "sec"),
Stat(id = "decode_speed", label = "Decode speed", unit = "tokens/s"), Stat(id = "prefill_speed", label = "Prefill speed", unit = "tokens/s"),
Stat(id = "latency", label = "Latency", unit = "sec") Stat(id = "decode_speed", label = "Decode speed", unit = "tokens/s"),
) Stat(id = "latency", label = "Latency", unit = "sec"),
)
open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task = curTask) { open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task = curTask) {
fun generateResponse(model: Model, input: String, image: Bitmap? = null, onError: () -> Unit) { fun generateResponse(
model: Model,
input: String,
images: List<Bitmap> = listOf(),
onError: () -> Unit,
) {
val accelerator = model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = "") val accelerator = model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = "")
viewModelScope.launch(Dispatchers.Default) { viewModelScope.launch(Dispatchers.Default) {
setInProgress(true) setInProgress(true)
setPreparing(true) setPreparing(true)
// Loading. // Loading.
addMessage( addMessage(model = model, message = ChatMessageLoading(accelerator = accelerator))
model = model,
message = ChatMessageLoading(accelerator = accelerator),
)
// Wait for instance to be initialized. // Wait for instance to be initialized.
while (model.instance == null) { while (model.instance == null) {
@ -68,9 +71,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
// Run inference. // Run inference.
val instance = model.instance as LlmModelInstance val instance = model.instance as LlmModelInstance
var prefillTokens = instance.session.sizeInTokens(input) var prefillTokens = instance.session.sizeInTokens(input)
if (image != null) { prefillTokens += images.size * 257
prefillTokens += 257
}
var firstRun = true var firstRun = true
var timeToFirstToken = 0f var timeToFirstToken = 0f
@ -81,9 +82,10 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
val start = System.currentTimeMillis() val start = System.currentTimeMillis()
try { try {
LlmChatModelHelper.runInference(model = model, LlmChatModelHelper.runInference(
model = model,
input = input, input = input,
image = image, images = images,
resultListener = { partialResult, done -> resultListener = { partialResult, done ->
val curTs = System.currentTimeMillis() val curTs = System.currentTimeMillis()
@ -106,18 +108,17 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
// Add an empty message that will receive streaming results. // Add an empty message that will receive streaming results.
addMessage( addMessage(
model = model, model = model,
message = ChatMessageText( message =
content = "", ChatMessageText(content = "", side = ChatSide.AGENT, accelerator = accelerator),
side = ChatSide.AGENT,
accelerator = accelerator
)
) )
} }
// Incrementally update the streamed partial results. // Incrementally update the streamed partial results.
val latencyMs: Long = if (done) System.currentTimeMillis() - start else -1 val latencyMs: Long = if (done) System.currentTimeMillis() - start else -1
updateLastTextMessageContentIncrementally( updateLastTextMessageContentIncrementally(
model = model, partialContent = partialResult, latencyMs = latencyMs.toFloat() model = model,
partialContent = partialResult,
latencyMs = latencyMs.toFloat(),
) )
if (done) { if (done) {
@ -130,18 +131,21 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
if (lastMessage is ChatMessageText) { if (lastMessage is ChatMessageText) {
updateLastTextMessageLlmBenchmarkResult( updateLastTextMessageLlmBenchmarkResult(
model = model, llmBenchmarkResult = ChatMessageBenchmarkLlmResult( model = model,
orderedStats = STATS, llmBenchmarkResult =
statValues = mutableMapOf( ChatMessageBenchmarkLlmResult(
"prefill_speed" to prefillSpeed, orderedStats = STATS,
"decode_speed" to decodeSpeed, statValues =
"time_to_first_token" to timeToFirstToken, mutableMapOf(
"latency" to (curTs - start).toFloat() / 1000f, "prefill_speed" to prefillSpeed,
"decode_speed" to decodeSpeed,
"time_to_first_token" to timeToFirstToken,
"latency" to (curTs - start).toFloat() / 1000f,
),
running = false,
latencyMs = -1f,
accelerator = accelerator,
), ),
running = false,
latencyMs = -1f,
accelerator = accelerator,
)
) )
} }
} }
@ -149,7 +153,8 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
cleanUpListener = { cleanUpListener = {
setInProgress(false) setInProgress(false)
setPreparing(false) setPreparing(false)
}) },
)
} catch (e: Exception) { } catch (e: Exception) {
Log.e(TAG, "Error occurred while running inference", e) Log.e(TAG, "Error occurred while running inference", e)
setInProgress(false) setInProgress(false)
@ -201,9 +206,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
addMessage(model = model, message = message.clone()) addMessage(model = model, message = message.clone())
// Run inference. // Run inference.
generateResponse( generateResponse(model = model, input = message.content, onError = onError)
model = model, input = message.content, onError = onError
)
} }
} }
@ -229,20 +232,18 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
// Add a warning message for re-initializing the session. // Add a warning message for re-initializing the session.
addMessage( addMessage(
model = model, model = model,
message = ChatMessageWarning(content = "Error occurred. Re-initializing the session.") message = ChatMessageWarning(content = "Error occurred. Re-initializing the session."),
) )
// Add the triggered message back. // Add the triggered message back.
addMessage(model = model, message = triggeredMessage) addMessage(model = model, message = triggeredMessage)
// Re-initialize the session/engine. // Re-initialize the session/engine.
modelManagerViewModel.initializeModel( modelManagerViewModel.initializeModel(context = context, task = task, model = model)
context = context, task = task, model = model
)
// Re-generate the response automatically. // Re-generate the response automatically.
generateResponse(model = model, input = triggeredMessage.content, onError = {}) generateResponse(model = model, input = triggeredMessage.content, onError = {})
} }
} }
class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE) class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE)

View file

@ -16,6 +16,10 @@
package com.google.ai.edge.gallery.ui.llmsingleturn package com.google.ai.edge.gallery.ui.llmsingleturn
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.preview.PreviewLlmSingleTurnViewModel
// import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import android.util.Log import android.util.Log
import androidx.activity.compose.BackHandler import androidx.activity.compose.BackHandler
import androidx.compose.foundation.background import androidx.compose.foundation.background
@ -39,7 +43,6 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha import androidx.compose.ui.draw.alpha
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalLayoutDirection import androidx.compose.ui.platform.LocalLayoutDirection
import androidx.compose.ui.tooling.preview.Preview
import androidx.lifecycle.viewmodel.compose.viewModel import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.ai.edge.gallery.data.ModelDownloadStatusType import com.google.ai.edge.gallery.data.ModelDownloadStatusType
import com.google.ai.edge.gallery.ui.ViewModelProvider import com.google.ai.edge.gallery.ui.ViewModelProvider
@ -48,17 +51,12 @@ import com.google.ai.edge.gallery.ui.common.ModelPageAppBar
import com.google.ai.edge.gallery.ui.common.chat.ModelDownloadStatusInfoPanel import com.google.ai.edge.gallery.ui.common.chat.ModelDownloadStatusInfoPanel
import com.google.ai.edge.gallery.ui.modelmanager.ModelInitializationStatusType import com.google.ai.edge.gallery.ui.modelmanager.ModelInitializationStatusType
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.PreviewLlmSingleTurnViewModel
import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.customColors import com.google.ai.edge.gallery.ui.theme.customColors
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.serialization.Serializable
/** Navigation destination data */ /** Navigation destination data */
object LlmSingleTurnDestination { object LlmSingleTurnDestination {
@Serializable
val route = "LlmSingleTurnRoute" val route = "LlmSingleTurnRoute"
} }
@ -69,9 +67,7 @@ fun LlmSingleTurnScreen(
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit, navigateUp: () -> Unit,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
viewModel: LlmSingleTurnViewModel = viewModel( viewModel: LlmSingleTurnViewModel = viewModel(factory = ViewModelProvider.Factory),
factory = ViewModelProvider.Factory
),
) { ) {
val task = viewModel.task val task = viewModel.task
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
@ -95,9 +91,7 @@ fun LlmSingleTurnScreen(
} }
// Handle system's edge swipe. // Handle system's edge swipe.
BackHandler { BackHandler { handleNavigateUp() }
handleNavigateUp()
}
// Initialize model when model/download state changes. // Initialize model when model/download state changes.
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name] val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name]
@ -106,7 +100,7 @@ fun LlmSingleTurnScreen(
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) { if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
Log.d( Log.d(
TAG, TAG,
"Initializing model '${selectedModel.name}' from LlmsingleTurnScreen launched effect" "Initializing model '${selectedModel.name}' from LlmsingleTurnScreen launched effect",
) )
modelManagerViewModel.initializeModel(context, task = task, model = selectedModel) modelManagerViewModel.initializeModel(context, task = task, model = selectedModel)
} }
@ -118,50 +112,55 @@ fun LlmSingleTurnScreen(
showErrorDialog = modelInitializationStatus?.status == ModelInitializationStatusType.ERROR showErrorDialog = modelInitializationStatus?.status == ModelInitializationStatusType.ERROR
} }
Scaffold(modifier = modifier, topBar = { Scaffold(
ModelPageAppBar( modifier = modifier,
task = task, topBar = {
model = selectedModel, ModelPageAppBar(
modelManagerViewModel = modelManagerViewModel, task = task,
inProgress = uiState.inProgress, model = selectedModel,
modelPreparing = uiState.preparing, modelManagerViewModel = modelManagerViewModel,
onConfigChanged = { _, _ -> }, inProgress = uiState.inProgress,
onBackClicked = { handleNavigateUp() }, modelPreparing = uiState.preparing,
onModelSelected = { newSelectedModel -> onConfigChanged = { _, _ -> },
scope.launch(Dispatchers.Default) { onBackClicked = { handleNavigateUp() },
// Clean up current model. onModelSelected = { newSelectedModel ->
modelManagerViewModel.cleanupModel(task = task, model = selectedModel) scope.launch(Dispatchers.Default) {
// Clean up current model.
modelManagerViewModel.cleanupModel(task = task, model = selectedModel)
// Update selected model. // Update selected model.
modelManagerViewModel.selectModel(model = newSelectedModel) modelManagerViewModel.selectModel(model = newSelectedModel)
} }
} },
)
}) { innerPadding ->
Column(
modifier = Modifier.padding(
top = innerPadding.calculateTopPadding(),
start = innerPadding.calculateStartPadding(LocalLayoutDirection.current),
end = innerPadding.calculateStartPadding(LocalLayoutDirection.current),
) )
},
) { innerPadding ->
Column(
modifier =
Modifier.padding(
top = innerPadding.calculateTopPadding(),
start = innerPadding.calculateStartPadding(LocalLayoutDirection.current),
end = innerPadding.calculateStartPadding(LocalLayoutDirection.current),
)
) { ) {
ModelDownloadStatusInfoPanel( ModelDownloadStatusInfoPanel(
model = selectedModel, model = selectedModel,
task = task, task = task,
modelManagerViewModel = modelManagerViewModel modelManagerViewModel = modelManagerViewModel,
) )
// Main UI after model is downloaded. // Main UI after model is downloaded.
val modelDownloaded = curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED val modelDownloaded = curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED
Box( Box(
contentAlignment = Alignment.BottomCenter, contentAlignment = Alignment.BottomCenter,
modifier = Modifier modifier =
.weight(1f) Modifier.weight(1f)
// Just hide the UI without removing it from the screen so that the scroll syncing // Just hide the UI without removing it from the screen so that the scroll syncing
// from ResponsePanel still works. // from ResponsePanel still works.
.alpha(if (modelDownloaded) 1.0f else 0.0f) .alpha(if (modelDownloaded) 1.0f else 0.0f),
) { ) {
VerticalSplitView(modifier = Modifier.fillMaxSize(), VerticalSplitView(
modifier = Modifier.fillMaxSize(),
topView = { topView = {
PromptTemplatesPanel( PromptTemplatesPanel(
model = selectedModel, model = selectedModel,
@ -170,49 +169,47 @@ fun LlmSingleTurnScreen(
onSend = { fullPrompt -> onSend = { fullPrompt ->
viewModel.generateResponse(model = selectedModel, input = fullPrompt) viewModel.generateResponse(model = selectedModel, input = fullPrompt)
}, },
onStopButtonClicked = { model -> onStopButtonClicked = { model -> viewModel.stopResponse(model = model) },
viewModel.stopResponse(model = model) modifier = Modifier.fillMaxSize(),
},
modifier = Modifier.fillMaxSize()
) )
}, },
bottomView = { bottomView = {
Box( Box(
contentAlignment = Alignment.BottomCenter, contentAlignment = Alignment.BottomCenter,
modifier = Modifier modifier =
.fillMaxSize() Modifier.fillMaxSize().background(MaterialTheme.customColors.agentBubbleBgColor),
.background(MaterialTheme.customColors.agentBubbleBgColor)
) { ) {
ResponsePanel( ResponsePanel(
model = selectedModel, model = selectedModel,
viewModel = viewModel, viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,
modifier = Modifier modifier =
.fillMaxSize() Modifier.fillMaxSize().padding(bottom = innerPadding.calculateBottomPadding()),
.padding(bottom = innerPadding.calculateBottomPadding())
) )
} }
}) },
)
} }
if (showErrorDialog) { if (showErrorDialog) {
ErrorDialog(error = modelInitializationStatus?.error ?: "", onDismiss = { ErrorDialog(
showErrorDialog = false error = modelInitializationStatus?.error ?: "",
}) onDismiss = { showErrorDialog = false },
)
} }
} }
} }
} }
@Preview(showBackground = true) // @Preview(showBackground = true)
@Composable // @Composable
fun LlmSingleTurnScreenPreview() { // fun LlmSingleTurnScreenPreview() {
val context = LocalContext.current // val context = LocalContext.current
GalleryTheme { // GalleryTheme {
LlmSingleTurnScreen( // LlmSingleTurnScreen(
modelManagerViewModel = PreviewModelManagerViewModel(context = context), // modelManagerViewModel = PreviewModelManagerViewModel(context = context),
viewModel = PreviewLlmSingleTurnViewModel(), // viewModel = PreviewLlmSingleTurnViewModel(),
navigateUp = {}, // navigateUp = {},
) // )
} // }
} // }

View file

@ -19,12 +19,12 @@ package com.google.ai.edge.gallery.ui.llmsingleturn
import android.util.Log import android.util.Log
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import com.google.ai.edge.gallery.common.processLlmResponse
import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB
import com.google.ai.edge.gallery.data.Task import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.ai.edge.gallery.ui.common.chat.Stat import com.google.ai.edge.gallery.ui.common.chat.Stat
import com.google.ai.edge.gallery.ui.common.processLlmResponse
import com.google.ai.edge.gallery.ui.llmchat.LlmChatModelHelper import com.google.ai.edge.gallery.ui.llmchat.LlmChatModelHelper
import com.google.ai.edge.gallery.ui.llmchat.LlmModelInstance import com.google.ai.edge.gallery.ui.llmchat.LlmModelInstance
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
@ -34,12 +34,10 @@ import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.update import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
private const val TAG = "AGLlmSingleTurnViewModel" private const val TAG = "AGLlmSingleTurnVM"
data class LlmSingleTurnUiState( data class LlmSingleTurnUiState(
/** /** Indicates whether the runtime is currently processing a message. */
* Indicates whether the runtime is currently processing a message.
*/
val inProgress: Boolean = false, val inProgress: Boolean = false,
/** /**
@ -57,12 +55,13 @@ data class LlmSingleTurnUiState(
val selectedPromptTemplateType: PromptTemplateType = PromptTemplateType.entries[0], val selectedPromptTemplateType: PromptTemplateType = PromptTemplateType.entries[0],
) )
private val STATS = listOf( private val STATS =
Stat(id = "time_to_first_token", label = "1st token", unit = "sec"), listOf(
Stat(id = "prefill_speed", label = "Prefill speed", unit = "tokens/s"), Stat(id = "time_to_first_token", label = "1st token", unit = "sec"),
Stat(id = "decode_speed", label = "Decode speed", unit = "tokens/s"), Stat(id = "prefill_speed", label = "Prefill speed", unit = "tokens/s"),
Stat(id = "latency", label = "Latency", unit = "sec") Stat(id = "decode_speed", label = "Decode speed", unit = "tokens/s"),
) Stat(id = "latency", label = "Latency", unit = "sec"),
)
open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewModel() { open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewModel() {
private val _uiState = MutableStateFlow(createUiState(task = task)) private val _uiState = MutableStateFlow(createUiState(task = task))
@ -94,7 +93,8 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewMo
val start = System.currentTimeMillis() val start = System.currentTimeMillis()
var response = "" var response = ""
var lastBenchmarkUpdateTs = 0L var lastBenchmarkUpdateTs = 0L
LlmChatModelHelper.runInference(model = model, LlmChatModelHelper.runInference(
model = model,
input = input, input = input,
resultListener = { partialResult, done -> resultListener = { partialResult, done ->
val curTs = System.currentTimeMillis() val curTs = System.currentTimeMillis()
@ -116,7 +116,7 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewMo
updateResponse( updateResponse(
model = model, model = model,
promptTemplateType = uiState.value.selectedPromptTemplateType, promptTemplateType = uiState.value.selectedPromptTemplateType,
response = response response = response,
) )
// Update benchmark (with throttling). // Update benchmark (with throttling).
@ -125,21 +125,23 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewMo
if (decodeSpeed.isNaN()) { if (decodeSpeed.isNaN()) {
decodeSpeed = 0f decodeSpeed = 0f
} }
val benchmark = ChatMessageBenchmarkLlmResult( val benchmark =
orderedStats = STATS, ChatMessageBenchmarkLlmResult(
statValues = mutableMapOf( orderedStats = STATS,
"prefill_speed" to prefillSpeed, statValues =
"decode_speed" to decodeSpeed, mutableMapOf(
"time_to_first_token" to timeToFirstToken, "prefill_speed" to prefillSpeed,
"latency" to (curTs - start).toFloat() / 1000f, "decode_speed" to decodeSpeed,
), "time_to_first_token" to timeToFirstToken,
running = !done, "latency" to (curTs - start).toFloat() / 1000f,
latencyMs = -1f, ),
) running = !done,
latencyMs = -1f,
)
updateBenchmark( updateBenchmark(
model = model, model = model,
promptTemplateType = uiState.value.selectedPromptTemplateType, promptTemplateType = uiState.value.selectedPromptTemplateType,
benchmark = benchmark benchmark = benchmark,
) )
lastBenchmarkUpdateTs = curTs lastBenchmarkUpdateTs = curTs
} }
@ -151,7 +153,8 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewMo
cleanUpListener = { cleanUpListener = {
setPreparing(false) setPreparing(false)
setInProgress(false) setInProgress(false)
}) },
)
} }
} }
@ -161,7 +164,9 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewMo
// Clear response. // Clear response.
updateResponse(model = model, promptTemplateType = promptTemplateType, response = "") updateResponse(model = model, promptTemplateType = promptTemplateType, response = "")
this._uiState.update { this.uiState.value.copy(selectedPromptTemplateType = promptTemplateType) } this._uiState.update {
this.uiState.value.copy(selectedPromptTemplateType = promptTemplateType)
}
} }
fun setInProgress(inProgress: Boolean) { fun setInProgress(inProgress: Boolean) {
@ -184,7 +189,9 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewMo
} }
fun updateBenchmark( fun updateBenchmark(
model: Model, promptTemplateType: PromptTemplateType, benchmark: ChatMessageBenchmarkLlmResult model: Model,
promptTemplateType: PromptTemplateType,
benchmark: ChatMessageBenchmarkLlmResult,
) { ) {
_uiState.update { currentState -> _uiState.update { currentState ->
val currentBenchmark = currentState.benchmarkByModel val currentBenchmark = currentState.benchmarkByModel
@ -218,4 +225,4 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewMo
benchmarkByModel = benchmarkByModel, benchmarkByModel = benchmarkByModel,
) )
} }
} }

View file

@ -16,21 +16,23 @@
package com.google.ai.edge.gallery.ui.llmsingleturn package com.google.ai.edge.gallery.ui.llmsingleturn
import androidx.compose.ui.graphics.Brush.Companion.linearGradient
import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.Color
import androidx.compose.ui.text.AnnotatedString import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.text.SpanStyle
import androidx.compose.ui.text.buildAnnotatedString import androidx.compose.ui.text.buildAnnotatedString
import androidx.compose.ui.text.withStyle import androidx.compose.ui.text.withStyle
import androidx.compose.ui.text.SpanStyle
import androidx.compose.ui.graphics.Brush.Companion.linearGradient
enum class PromptTemplateInputEditorType { enum class PromptTemplateInputEditorType {
SINGLE_SELECT SINGLE_SELECT
} }
enum class RewriteToneType(val label: String) { enum class RewriteToneType(val label: String) {
FORMAL(label = "Formal"), CASUAL(label = "Casual"), FRIENDLY(label = "Friendly"), POLITE(label = "Polite"), ENTHUSIASTIC( FORMAL(label = "Formal"),
label = "Enthusiastic" CASUAL(label = "Casual"),
), FRIENDLY(label = "Friendly"),
POLITE(label = "Polite"),
ENTHUSIASTIC(label = "Enthusiastic"),
CONCISE(label = "Concise"), CONCISE(label = "Concise"),
} }
@ -69,51 +71,60 @@ class PromptTemplateSingleSelectInputEditor(
override val label: String, override val label: String,
val options: List<String> = listOf(), val options: List<String> = listOf(),
override val defaultOption: String = "", override val defaultOption: String = "",
) : PromptTemplateInputEditor( ) :
label = label, type = PromptTemplateInputEditorType.SINGLE_SELECT, defaultOption = defaultOption PromptTemplateInputEditor(
) label = label,
type = PromptTemplateInputEditorType.SINGLE_SELECT,
defaultOption = defaultOption,
)
data class PromptTemplateConfig(val inputEditors: List<PromptTemplateInputEditor> = listOf()) data class PromptTemplateConfig(val inputEditors: List<PromptTemplateInputEditor> = listOf())
private val GEMINI_GRADIENT_STYLE = SpanStyle( private val GEMINI_GRADIENT_STYLE =
brush = linearGradient( SpanStyle(
colors = listOf(Color(0xFF4285f4), Color(0xFF9b72cb), Color(0xFFd96570)) brush = linearGradient(colors = listOf(Color(0xFF4285f4), Color(0xFF9b72cb), Color(0xFFd96570)))
) )
)
@Suppress("ImmutableEnum")
enum class PromptTemplateType( enum class PromptTemplateType(
val label: String, val label: String,
val config: PromptTemplateConfig, val config: PromptTemplateConfig,
val genFullPrompt: (userInput: String, inputEditorValues: Map<String, Any>) -> AnnotatedString = { _, _ -> val genFullPrompt: (userInput: String, inputEditorValues: Map<String, Any>) -> AnnotatedString =
AnnotatedString("") { _, _ ->
}, AnnotatedString("")
},
val examplePrompts: List<String> = listOf(), val examplePrompts: List<String> = listOf(),
) { ) {
FREE_FORM( FREE_FORM(
label = "Free form", label = "Free form",
config = PromptTemplateConfig(), config = PromptTemplateConfig(),
genFullPrompt = { userInput, _ -> AnnotatedString(userInput) }, genFullPrompt = { userInput, _ -> AnnotatedString(userInput) },
examplePrompts = listOf( examplePrompts =
"Suggest 3 topics for a podcast about \"Friendships in your 20s\".", listOf(
"Outline the key sections needed in a basic logo design brief.", "Suggest 3 topics for a podcast about \"Friendships in your 20s\".",
"List 3 pros and 3 cons to consider before buying a smart watch.", "Outline the key sections needed in a basic logo design brief.",
"Write a short, optimistic quote about the future of technology.", "List 3 pros and 3 cons to consider before buying a smart watch.",
"Generate 3 potential names for a mobile app that helps users identify plants.", "Write a short, optimistic quote about the future of technology.",
"Explain the difference between AI and machine learning in 2 sentences.", "Generate 3 potential names for a mobile app that helps users identify plants.",
"Create a simple haiku about a cat sleeping in the sun.", "Explain the difference between AI and machine learning in 2 sentences.",
"List 3 ways to make instant noodles taste better using common kitchen ingredients." "Create a simple haiku about a cat sleeping in the sun.",
) "List 3 ways to make instant noodles taste better using common kitchen ingredients.",
),
), ),
REWRITE_TONE( REWRITE_TONE(
label = "Rewrite tone", config = PromptTemplateConfig( label = "Rewrite tone",
inputEditors = listOf( config =
PromptTemplateSingleSelectInputEditor( PromptTemplateConfig(
label = InputEditorLabel.TONE.label, inputEditors =
options = RewriteToneType.entries.map { it.label }, listOf(
defaultOption = RewriteToneType.FORMAL.label PromptTemplateSingleSelectInputEditor(
) label = InputEditorLabel.TONE.label,
) options = RewriteToneType.entries.map { it.label },
), genFullPrompt = { userInput, inputEditorValues -> defaultOption = RewriteToneType.FORMAL.label,
)
)
),
genFullPrompt = { userInput, inputEditorValues ->
val tone = inputEditorValues[InputEditorLabel.TONE.label] as String val tone = inputEditorValues[InputEditorLabel.TONE.label] as String
buildAnnotatedString { buildAnnotatedString {
withStyle(GEMINI_GRADIENT_STYLE) { withStyle(GEMINI_GRADIENT_STYLE) {
@ -121,25 +132,29 @@ enum class PromptTemplateType(
} }
append(userInput) append(userInput)
} }
}, examplePrompts = listOf( },
"Hey team, just wanted to remind everyone about the meeting tomorrow @ 10. Be there!", examplePrompts =
"Our new software update includes several bug fixes and performance improvements.", listOf(
"Due to the fact that the weather was bad, we decided to postpone the event.", "Hey team, just wanted to remind everyone about the meeting tomorrow @ 10. Be there!",
"Please find attached the requested documentation for your perusal.", "Our new software update includes several bug fixes and performance improvements.",
"Welcome to the team. Review the onboarding materials.", "Due to the fact that the weather was bad, we decided to postpone the event.",
) "Please find attached the requested documentation for your perusal.",
"Welcome to the team. Review the onboarding materials.",
),
), ),
SUMMARIZE_TEXT( SUMMARIZE_TEXT(
label = "Summarize text", label = "Summarize text",
config = PromptTemplateConfig( config =
inputEditors = listOf( PromptTemplateConfig(
PromptTemplateSingleSelectInputEditor( inputEditors =
label = InputEditorLabel.STYLE.label, listOf(
options = SummarizationType.entries.map { it.label }, PromptTemplateSingleSelectInputEditor(
defaultOption = SummarizationType.KEY_BULLET_POINT.label label = InputEditorLabel.STYLE.label,
) options = SummarizationType.entries.map { it.label },
) defaultOption = SummarizationType.KEY_BULLET_POINT.label,
), )
)
),
genFullPrompt = { userInput, inputEditorValues -> genFullPrompt = { userInput, inputEditorValues ->
val style = inputEditorValues[InputEditorLabel.STYLE.label] as String val style = inputEditorValues[InputEditorLabel.STYLE.label] as String
buildAnnotatedString { buildAnnotatedString {
@ -149,37 +164,38 @@ enum class PromptTemplateType(
append(userInput) append(userInput)
} }
}, },
examplePrompts = listOf( examplePrompts =
"The new Pixel phone features an advanced camera system with improved low-light performance and AI-powered editing tools. The display is brighter and more energy-efficient. It runs on the latest Tensor chip, offering faster processing and enhanced security features. Battery life has also been extended, providing all-day power for most users.", listOf(
"Beginning this Friday, January 24, giant pandas Bao Li and Qing Bao are officially on view to the public at the Smithsonians National Zoo and Conservation Biology Institute (NZCBI). The 3-year-old bears arrived in Washington this past October, undergoing a quarantine period before making their debut. Under NZCBIs new agreement with the CWCA, Qing Bao and Bao Li will remain in the United States for ten years, until April 2034, in exchange for an annual fee of \$1 million. The pair are still too young to breed, as pandas only reach sexual maturity between ages 4 and 7. “Kind of picture them as like awkward teenagers right now,” Lally told WUSA9. “We still have about two years before we would probably even see signs that theyre ready to start mating.”", "The new Pixel phone features an advanced camera system with improved low-light performance and AI-powered editing tools. The display is brighter and more energy-efficient. It runs on the latest Tensor chip, offering faster processing and enhanced security features. Battery life has also been extended, providing all-day power for most users.",
), "Beginning this Friday, January 24, giant pandas Bao Li and Qing Bao are officially on view to the public at the Smithsonians National Zoo and Conservation Biology Institute (NZCBI). The 3-year-old bears arrived in Washington this past October, undergoing a quarantine period before making their debut. Under NZCBIs new agreement with the CWCA, Qing Bao and Bao Li will remain in the United States for ten years, until April 2034, in exchange for an annual fee of \$1 million. The pair are still too young to breed, as pandas only reach sexual maturity between ages 4 and 7. “Kind of picture them as like awkward teenagers right now,” Lally told WUSA9. “We still have about two years before we would probably even see signs that theyre ready to start mating.”",
),
), ),
CODE_SNIPPET( CODE_SNIPPET(
label = "Code snippet", label = "Code snippet",
config = PromptTemplateConfig( config =
inputEditors = listOf( PromptTemplateConfig(
PromptTemplateSingleSelectInputEditor( inputEditors =
label = InputEditorLabel.LANGUAGE.label, listOf(
options = LanguageType.entries.map { it.label }, PromptTemplateSingleSelectInputEditor(
defaultOption = LanguageType.JAVASCRIPT.label label = InputEditorLabel.LANGUAGE.label,
) options = LanguageType.entries.map { it.label },
) defaultOption = LanguageType.JAVASCRIPT.label,
), )
)
),
genFullPrompt = { userInput, inputEditorValues -> genFullPrompt = { userInput, inputEditorValues ->
val language = inputEditorValues[InputEditorLabel.LANGUAGE.label] as String val language = inputEditorValues[InputEditorLabel.LANGUAGE.label] as String
buildAnnotatedString { buildAnnotatedString {
withStyle(GEMINI_GRADIENT_STYLE) { withStyle(GEMINI_GRADIENT_STYLE) { append("Write a $language code snippet to ") }
append("Write a $language code snippet to ")
}
append(userInput) append(userInput)
} }
}, },
examplePrompts = listOf( examplePrompts =
"Create an alert box that says \"Hello, World!\"", listOf(
"Declare an immutable variable named 'appName' with the value \"AI Gallery\"", "Create an alert box that says \"Hello, World!\"",
"Print the numbers from 1 to 5 using a for loop.", "Declare an immutable variable named 'appName' with the value \"AI Gallery\"",
"Write a function that returns the square of an integer input.", "Print the numbers from 1 to 5 using a for loop.",
), "Write a function that returns the square of an integer input.",
),
), ),
} }

View file

@ -16,6 +16,7 @@
package com.google.ai.edge.gallery.ui.llmsingleturn package com.google.ai.edge.gallery.ui.llmsingleturn
import android.content.ClipData
import androidx.compose.foundation.BorderStroke import androidx.compose.foundation.BorderStroke
import androidx.compose.foundation.background import androidx.compose.foundation.background
import androidx.compose.foundation.border import androidx.compose.foundation.border
@ -79,7 +80,8 @@ import androidx.compose.ui.draw.clip
import androidx.compose.ui.focus.FocusRequester import androidx.compose.ui.focus.FocusRequester
import androidx.compose.ui.focus.focusRequester import androidx.compose.ui.focus.focusRequester
import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.Color
import androidx.compose.ui.platform.LocalClipboardManager import androidx.compose.ui.platform.ClipEntry
import androidx.compose.ui.platform.LocalClipboard
import androidx.compose.ui.platform.LocalFocusManager import androidx.compose.ui.platform.LocalFocusManager
import androidx.compose.ui.res.dimensionResource import androidx.compose.ui.res.dimensionResource
import androidx.compose.ui.text.TextLayoutResult import androidx.compose.ui.text.TextLayoutResult
@ -108,7 +110,7 @@ fun PromptTemplatesPanel(
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
onSend: (fullPrompt: String) -> Unit, onSend: (fullPrompt: String) -> Unit,
onStopButtonClicked: (Model) -> Unit, onStopButtonClicked: (Model) -> Unit,
modifier: Modifier = Modifier modifier: Modifier = Modifier,
) { ) {
val scope = rememberCoroutineScope() val scope = rememberCoroutineScope()
val uiState by viewModel.uiState.collectAsState() val uiState by viewModel.uiState.collectAsState()
@ -125,13 +127,12 @@ fun PromptTemplatesPanel(
uiState.selectedPromptTemplateType.genFullPrompt(curTextInputContent, inputEditorValues) uiState.selectedPromptTemplateType.genFullPrompt(curTextInputContent, inputEditorValues)
} }
} }
val clipboardManager = LocalClipboardManager.current val clipboard = LocalClipboard.current
val focusRequester = remember { FocusRequester() } val focusRequester = remember { FocusRequester() }
val focusManager = LocalFocusManager.current val focusManager = LocalFocusManager.current
val interactionSource = remember { MutableInteractionSource() } val interactionSource = remember { MutableInteractionSource() }
val expandedStates = remember { mutableStateMapOf<String, Boolean>() } val expandedStates = remember { mutableStateMapOf<String, Boolean>() }
val modelInitializationStatus = val modelInitializationStatus = modelManagerUiState.modelInitializationStatus[model.name]
modelManagerUiState.modelInitializationStatus[model.name]
// Update input editor values when prompt template changes. // Update input editor values when prompt template changes.
LaunchedEffect(selectedPromptTemplateType) { LaunchedEffect(selectedPromptTemplateType) {
@ -147,11 +148,10 @@ fun PromptTemplatesPanel(
Column(modifier = modifier) { Column(modifier = modifier) {
// Scrollable tab row for all prompt templates. // Scrollable tab row for all prompt templates.
PrimaryScrollableTabRow( PrimaryScrollableTabRow(selectedTabIndex = selectedTabIndex) {
selectedTabIndex = selectedTabIndex
) {
TAB_TITLES.forEachIndexed { index, title -> TAB_TITLES.forEachIndexed { index, title ->
Tab(selected = selectedTabIndex == index, Tab(
selected = selectedTabIndex == index,
enabled = !inProgress, enabled = !inProgress,
onClick = { onClick = {
// Clear input when tab changes. // Clear input when tab changes.
@ -162,41 +162,34 @@ fun PromptTemplatesPanel(
selectedTabIndex = index selectedTabIndex = index
viewModel.selectPromptTemplate( viewModel.selectPromptTemplate(
model = model, model = model,
promptTemplateType = promptTemplateTypes[index] promptTemplateType = promptTemplateTypes[index],
) )
}, },
text = { text = { Text(text = title, modifier = Modifier.alpha(if (inProgress) 0.5f else 1f)) },
Text( )
text = title,
modifier = Modifier.alpha(if (inProgress) 0.5f else 1f)
)
})
} }
} }
// Content. // Content.
Column( Column(modifier = Modifier.weight(1f).fillMaxWidth()) {
modifier = Modifier
.weight(1f)
.fillMaxWidth()
) {
// Input editor row. // Input editor row.
if (selectedPromptTemplateType.config.inputEditors.isNotEmpty()) { if (selectedPromptTemplateType.config.inputEditors.isNotEmpty()) {
Row( Row(
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(8.dp), horizontalArrangement = Arrangement.spacedBy(8.dp),
modifier = Modifier modifier =
.fillMaxWidth() Modifier.fillMaxWidth()
.background(MaterialTheme.colorScheme.surfaceContainerLow) .background(MaterialTheme.colorScheme.surfaceContainerLow)
.padding(horizontal = 16.dp, vertical = 10.dp) .padding(horizontal = 16.dp, vertical = 10.dp),
) { ) {
// Input editors. // Input editors.
for (inputEditor in selectedPromptTemplateType.config.inputEditors) { for (inputEditor in selectedPromptTemplateType.config.inputEditors) {
when (inputEditor.type) { when (inputEditor.type) {
PromptTemplateInputEditorType.SINGLE_SELECT -> SingleSelectButton(config = inputEditor as PromptTemplateSingleSelectInputEditor, PromptTemplateInputEditorType.SINGLE_SELECT ->
onSelected = { option -> SingleSelectButton(
inputEditorValues[inputEditor.label] = option config = inputEditor as PromptTemplateSingleSelectInputEditor,
}) onSelected = { option -> inputEditorValues[inputEditor.label] = option },
)
} }
} }
} }
@ -205,12 +198,10 @@ fun PromptTemplatesPanel(
// Text input box. // Text input box.
Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.weight(1f)) { Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.weight(1f)) {
Column( Column(
modifier = Modifier modifier =
.fillMaxSize() Modifier.fillMaxSize().verticalScroll(rememberScrollState()).clickable(
.verticalScroll(rememberScrollState())
.clickable(
interactionSource = interactionSource, interactionSource = interactionSource,
indication = null // Disable the ripple effect indication = null, // Disable the ripple effect
) { ) {
// Request focus on the TextField when the Column is clicked // Request focus on the TextField when the Column is clicked
focusRequester.requestFocus() focusRequester.requestFocus()
@ -220,32 +211,31 @@ fun PromptTemplatesPanel(
Text( Text(
fullPrompt, fullPrompt,
style = MaterialTheme.typography.bodyMedium, style = MaterialTheme.typography.bodyMedium,
modifier = Modifier modifier =
.fillMaxWidth() Modifier.fillMaxWidth()
.padding(16.dp) .padding(16.dp)
.padding(bottom = 40.dp) .padding(bottom = 40.dp)
.clip(MessageBubbleShape(radius = bubbleBorderRadius)) .clip(MessageBubbleShape(radius = bubbleBorderRadius))
.background(MaterialTheme.customColors.agentBubbleBgColor) .background(MaterialTheme.customColors.agentBubbleBgColor)
.padding(16.dp) .padding(16.dp)
.focusRequester(focusRequester) .focusRequester(focusRequester),
) )
} else { } else {
TextField( TextField(
value = curTextInputContent, value = curTextInputContent,
onValueChange = { curTextInputContent = it }, onValueChange = { curTextInputContent = it },
colors = TextFieldDefaults.colors( colors =
unfocusedContainerColor = Color.Transparent, TextFieldDefaults.colors(
focusedContainerColor = Color.Transparent, unfocusedContainerColor = Color.Transparent,
focusedIndicatorColor = Color.Transparent, focusedContainerColor = Color.Transparent,
unfocusedIndicatorColor = Color.Transparent, focusedIndicatorColor = Color.Transparent,
disabledIndicatorColor = Color.Transparent, unfocusedIndicatorColor = Color.Transparent,
disabledContainerColor = Color.Transparent, disabledIndicatorColor = Color.Transparent,
), disabledContainerColor = Color.Transparent,
),
textStyle = MaterialTheme.typography.bodyLarge, textStyle = MaterialTheme.typography.bodyLarge,
placeholder = { Text("Enter content") }, placeholder = { Text("Enter content") },
modifier = Modifier modifier = Modifier.padding(bottom = 40.dp).focusRequester(focusRequester),
.padding(bottom = 40.dp)
.focusRequester(focusRequester)
) )
} }
} }
@ -254,26 +244,35 @@ fun PromptTemplatesPanel(
Row( Row(
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(4.dp), horizontalArrangement = Arrangement.spacedBy(4.dp),
modifier = Modifier modifier = Modifier.fillMaxWidth().padding(vertical = 4.dp, horizontal = 16.dp),
.fillMaxWidth()
.padding(vertical = 4.dp, horizontal = 16.dp)
) { ) {
// Full prompt switch. // Full prompt switch.
if (selectedPromptTemplateType != PromptTemplateType.FREE_FORM && curTextInputContent.isNotEmpty()) { if (
Row(verticalAlignment = Alignment.CenterVertically, selectedPromptTemplateType != PromptTemplateType.FREE_FORM &&
curTextInputContent.isNotEmpty()
) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(4.dp), horizontalArrangement = Arrangement.spacedBy(4.dp),
modifier = Modifier modifier =
.clip(CircleShape) Modifier.clip(CircleShape)
.background(if (inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean) MaterialTheme.colorScheme.secondaryContainer else MaterialTheme.customColors.agentBubbleBgColor) .background(
.clickable { if (inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean)
inputEditorValues[FULL_PROMPT_SWITCH_KEY] = MaterialTheme.colorScheme.secondaryContainer
!(inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean) else MaterialTheme.customColors.agentBubbleBgColor
} )
.height(40.dp) .clickable {
.border( inputEditorValues[FULL_PROMPT_SWITCH_KEY] =
width = 1.dp, color = MaterialTheme.colorScheme.surface, shape = CircleShape !(inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean)
) }
.padding(horizontal = 12.dp)) { .height(40.dp)
.border(
width = 1.dp,
color = MaterialTheme.colorScheme.surface,
shape = CircleShape,
)
.padding(horizontal = 12.dp),
) {
if (inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean) { if (inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean) {
Icon( Icon(
imageVector = Icons.Rounded.Visibility, imageVector = Icons.Rounded.Visibility,
@ -284,9 +283,7 @@ fun PromptTemplatesPanel(
Icon( Icon(
imageVector = Icons.Rounded.VisibilityOff, imageVector = Icons.Rounded.VisibilityOff,
contentDescription = "", contentDescription = "",
modifier = Modifier modifier = Modifier.size(FilterChipDefaults.IconSize).alpha(0.3f),
.size(FilterChipDefaults.IconSize)
.alpha(0.3f),
) )
} }
Text("Preview prompt", style = MaterialTheme.typography.labelMedium) Text("Preview prompt", style = MaterialTheme.typography.labelMedium)
@ -299,20 +296,27 @@ fun PromptTemplatesPanel(
if (curTextInputContent.isNotEmpty()) { if (curTextInputContent.isNotEmpty()) {
OutlinedIconButton( OutlinedIconButton(
onClick = { onClick = {
val clipData = fullPrompt scope.launch {
clipboardManager.setText(clipData) val clipData = ClipData.newPlainText("prompt", fullPrompt)
val clipEntry = ClipEntry(clipData = clipData)
clipboard.setClipEntry(clipEntry = clipEntry)
}
}, },
colors = IconButtonDefaults.iconButtonColors( colors =
containerColor = MaterialTheme.customColors.agentBubbleBgColor, IconButtonDefaults.iconButtonColors(
disabledContainerColor = MaterialTheme.customColors.agentBubbleBgColor.copy(alpha = 0.4f), containerColor = MaterialTheme.customColors.agentBubbleBgColor,
contentColor = MaterialTheme.colorScheme.primary, disabledContainerColor =
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.2f), MaterialTheme.customColors.agentBubbleBgColor.copy(alpha = 0.4f),
), contentColor = MaterialTheme.colorScheme.primary,
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.2f),
),
border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface), border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface),
modifier = Modifier.size(ICON_BUTTON_SIZE) modifier = Modifier.size(ICON_BUTTON_SIZE),
) { ) {
Icon( Icon(
Icons.Outlined.ContentCopy, contentDescription = "", modifier = Modifier.size(20.dp) Icons.Outlined.ContentCopy,
contentDescription = "",
modifier = Modifier.size(20.dp),
) )
} }
} }
@ -321,38 +325,35 @@ fun PromptTemplatesPanel(
OutlinedIconButton( OutlinedIconButton(
enabled = !inProgress, enabled = !inProgress,
onClick = { showExamplePromptBottomSheet = true }, onClick = { showExamplePromptBottomSheet = true },
colors = IconButtonDefaults.iconButtonColors( colors =
containerColor = MaterialTheme.customColors.agentBubbleBgColor, IconButtonDefaults.iconButtonColors(
disabledContainerColor = MaterialTheme.customColors.agentBubbleBgColor.copy(alpha = 0.4f), containerColor = MaterialTheme.customColors.agentBubbleBgColor,
contentColor = MaterialTheme.colorScheme.primary, disabledContainerColor =
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.2f), MaterialTheme.customColors.agentBubbleBgColor.copy(alpha = 0.4f),
), contentColor = MaterialTheme.colorScheme.primary,
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.2f),
),
border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface), border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface),
modifier = Modifier.size(ICON_BUTTON_SIZE) modifier = Modifier.size(ICON_BUTTON_SIZE),
) { ) {
Icon( Icon(Icons.Rounded.Add, contentDescription = "", modifier = Modifier.size(20.dp))
Icons.Rounded.Add,
contentDescription = "",
modifier = Modifier.size(20.dp),
)
} }
val modelInitializing = val modelInitializing =
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING
if (inProgress && !modelInitializing && !uiState.preparing) { if (inProgress && !modelInitializing && !uiState.preparing) {
IconButton( IconButton(
onClick = { onClick = { onStopButtonClicked(model) },
onStopButtonClicked(model) colors =
}, IconButtonDefaults.iconButtonColors(
colors = IconButtonDefaults.iconButtonColors( containerColor = MaterialTheme.colorScheme.secondaryContainer
containerColor = MaterialTheme.colorScheme.secondaryContainer, ),
), modifier = Modifier.size(ICON_BUTTON_SIZE),
modifier = Modifier.size(ICON_BUTTON_SIZE)
) { ) {
Icon( Icon(
Icons.Rounded.Stop, Icons.Rounded.Stop,
contentDescription = "", contentDescription = "",
tint = MaterialTheme.colorScheme.primary tint = MaterialTheme.colorScheme.primary,
) )
} }
} else { } else {
@ -363,21 +364,21 @@ fun PromptTemplatesPanel(
focusManager.clearFocus() focusManager.clearFocus()
onSend(fullPrompt.text) onSend(fullPrompt.text)
}, },
colors = IconButtonDefaults.iconButtonColors( colors =
containerColor = MaterialTheme.colorScheme.secondaryContainer, IconButtonDefaults.iconButtonColors(
disabledContainerColor = MaterialTheme.colorScheme.secondaryContainer.copy(alpha = 0.3f), containerColor = MaterialTheme.colorScheme.secondaryContainer,
contentColor = MaterialTheme.colorScheme.primary, disabledContainerColor =
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.1f), MaterialTheme.colorScheme.secondaryContainer.copy(alpha = 0.3f),
), contentColor = MaterialTheme.colorScheme.primary,
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.1f),
),
border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface), border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface),
modifier = Modifier.size(ICON_BUTTON_SIZE) modifier = Modifier.size(ICON_BUTTON_SIZE),
) { ) {
Icon( Icon(
Icons.AutoMirrored.Rounded.Send, Icons.AutoMirrored.Rounded.Send,
contentDescription = "", contentDescription = "",
modifier = Modifier modifier = Modifier.size(20.dp).offset(x = 2.dp),
.size(20.dp)
.offset(x = 2.dp),
) )
} }
} }
@ -396,89 +397,82 @@ fun PromptTemplatesPanel(
// Title // Title
Text( Text(
"Select an example", "Select an example",
modifier = Modifier modifier = Modifier.fillMaxWidth().padding(16.dp),
.fillMaxWidth() style = MaterialTheme.typography.titleLarge,
.padding(16.dp),
style = MaterialTheme.typography.titleLarge
) )
// Examples // Examples
for (prompt in selectedPromptTemplateType.examplePrompts) { for (prompt in selectedPromptTemplateType.examplePrompts) {
var textLayoutResultState by remember { mutableStateOf<TextLayoutResult?>(null) } var textLayoutResultState by remember { mutableStateOf<TextLayoutResult?>(null) }
val hasOverflow = remember(textLayoutResultState) { val hasOverflow =
textLayoutResultState?.hasVisualOverflow ?: false remember(textLayoutResultState) { textLayoutResultState?.hasVisualOverflow ?: false }
}
val isExpanded = expandedStates[prompt] ?: false val isExpanded = expandedStates[prompt] ?: false
Column( Column(
modifier = Modifier modifier =
.fillMaxWidth() Modifier.fillMaxWidth()
.clickable { .clickable {
curTextInputContent = prompt curTextInputContent = prompt
scope.launch { scope.launch {
// Give it sometime to show the click effect. // Give it sometime to show the click effect.
delay(200) delay(200)
showExamplePromptBottomSheet = false showExamplePromptBottomSheet = false
}
} }
} .padding(horizontal = 16.dp, vertical = 8.dp)
.padding(horizontal = 16.dp, vertical = 8.dp),
) { ) {
Row( Row(
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(8.dp), horizontalArrangement = Arrangement.spacedBy(8.dp),
) { ) {
Icon(Icons.Outlined.Description, contentDescription = "") Icon(Icons.Outlined.Description, contentDescription = "")
Text(prompt, Text(
prompt,
maxLines = if (isExpanded) Int.MAX_VALUE else 3, maxLines = if (isExpanded) Int.MAX_VALUE else 3,
overflow = TextOverflow.Ellipsis, overflow = TextOverflow.Ellipsis,
style = MaterialTheme.typography.bodySmall, style = MaterialTheme.typography.bodySmall,
modifier = Modifier.weight(1f), modifier = Modifier.weight(1f),
onTextLayout = { textLayoutResultState = it } onTextLayout = { textLayoutResultState = it },
) )
} }
if (hasOverflow && !isExpanded) { if (hasOverflow && !isExpanded) {
Row( Row(
modifier = Modifier modifier = Modifier.fillMaxWidth().padding(top = 2.dp),
.fillMaxWidth() horizontalArrangement = Arrangement.End,
.padding(top = 2.dp),
horizontalArrangement = Arrangement.End
) { ) {
Box(modifier = Modifier Box(
.padding(end = 16.dp) modifier =
.clip(CircleShape) Modifier.padding(end = 16.dp)
.background(MaterialTheme.colorScheme.surfaceContainerHighest) .clip(CircleShape)
.clickable { .background(MaterialTheme.colorScheme.surfaceContainerHighest)
expandedStates[prompt] = true .clickable { expandedStates[prompt] = true }
} .padding(vertical = 1.dp, horizontal = 6.dp)
.padding(vertical = 1.dp, horizontal = 6.dp)) { ) {
Icon( Icon(
Icons.Outlined.ExpandMore, Icons.Outlined.ExpandMore,
contentDescription = "", contentDescription = "",
modifier = Modifier.size(12.dp) modifier = Modifier.size(12.dp),
) )
} }
} }
} else if (isExpanded) { } else if (isExpanded) {
Row( Row(
modifier = Modifier modifier = Modifier.fillMaxWidth().padding(top = 2.dp),
.fillMaxWidth() horizontalArrangement = Arrangement.End,
.padding(top = 2.dp),
horizontalArrangement = Arrangement.End
) { ) {
Box(modifier = Modifier Box(
.padding(end = 16.dp) modifier =
.clip(CircleShape) Modifier.padding(end = 16.dp)
.background(MaterialTheme.colorScheme.surfaceContainerHighest) .clip(CircleShape)
.clickable { .background(MaterialTheme.colorScheme.surfaceContainerHighest)
expandedStates[prompt] = false .clickable { expandedStates[prompt] = false }
} .padding(vertical = 1.dp, horizontal = 6.dp)
.padding(vertical = 1.dp, horizontal = 6.dp)) { ) {
Icon( Icon(
Icons.Outlined.ExpandLess, Icons.Outlined.ExpandLess,
contentDescription = "", contentDescription = "",
modifier = Modifier.size(12.dp) modifier = Modifier.size(12.dp),
) )
} }
} }

View file

@ -16,6 +16,7 @@
package com.google.ai.edge.gallery.ui.llmsingleturn package com.google.ai.edge.gallery.ui.llmsingleturn
import android.content.ClipData
import android.util.Log import android.util.Log
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
@ -49,24 +50,26 @@ import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableIntStateOf import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.remember import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue import androidx.compose.runtime.setValue
import androidx.compose.runtime.snapshotFlow import androidx.compose.runtime.snapshotFlow
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha import androidx.compose.ui.draw.alpha
import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.Color
import androidx.compose.ui.platform.LocalClipboardManager import androidx.compose.ui.platform.ClipEntry
import androidx.compose.ui.text.AnnotatedString import androidx.compose.ui.platform.LocalClipboard
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp import androidx.compose.ui.unit.sp
import com.google.ai.edge.gallery.data.ConfigKey import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB
import com.google.ai.edge.gallery.ui.common.chat.MarkdownText import com.google.ai.edge.gallery.ui.common.MarkdownText
import com.google.ai.edge.gallery.ui.common.chat.MessageBodyBenchmarkLlm import com.google.ai.edge.gallery.ui.common.chat.MessageBodyBenchmarkLlm
import com.google.ai.edge.gallery.ui.common.chat.MessageBodyLoading import com.google.ai.edge.gallery.ui.common.chat.MessageBodyLoading
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.modelmanager.PagerScrollState import com.google.ai.edge.gallery.ui.modelmanager.PagerScrollState
import kotlinx.coroutines.launch
private val OPTIONS = listOf("Response", "Benchmark") private val OPTIONS = listOf("Response", "Benchmark")
private val ICONS = listOf(Icons.Outlined.AutoAwesome, Icons.Outlined.Timer) private val ICONS = listOf(Icons.Outlined.AutoAwesome, Icons.Outlined.Timer)
@ -88,23 +91,21 @@ fun ResponsePanel(
val selectedPromptTemplateType = uiState.selectedPromptTemplateType val selectedPromptTemplateType = uiState.selectedPromptTemplateType
val responseScrollState = rememberScrollState() val responseScrollState = rememberScrollState()
var selectedOptionIndex by remember { mutableIntStateOf(0) } var selectedOptionIndex by remember { mutableIntStateOf(0) }
val clipboardManager = LocalClipboardManager.current val clipboard = LocalClipboard.current
val pagerState = rememberPagerState( val scope = rememberCoroutineScope()
initialPage = task.models.indexOf(model), val pagerState =
pageCount = { task.models.size }) rememberPagerState(initialPage = task.models.indexOf(model), pageCount = { task.models.size })
val accelerator = model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = "") val accelerator = model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = "")
// Select the "response" tab when prompt template changes. // Select the "response" tab when prompt template changes.
LaunchedEffect(selectedPromptTemplateType) { LaunchedEffect(selectedPromptTemplateType) { selectedOptionIndex = 0 }
selectedOptionIndex = 0
}
// Update selected model and clean up previous model when page is settled on a model page. // Update selected model and clean up previous model when page is settled on a model page.
LaunchedEffect(pagerState.settledPage) { LaunchedEffect(pagerState.settledPage) {
val curSelectedModel = task.models[pagerState.settledPage] val curSelectedModel = task.models[pagerState.settledPage]
Log.d( Log.d(
TAG, TAG,
"Pager settled on model '${curSelectedModel.name}' from '${model.name}'. Updating selected model." "Pager settled on model '${curSelectedModel.name}' from '${model.name}'. Updating selected model.",
) )
if (curSelectedModel.name != model.name) { if (curSelectedModel.name != model.name) {
modelManagerViewModel.cleanupModel(task = task, model = model) modelManagerViewModel.cleanupModel(task = task, model = model)
@ -115,13 +116,12 @@ fun ResponsePanel(
// Trigger scroll sync. // Trigger scroll sync.
LaunchedEffect(pagerState) { LaunchedEffect(pagerState) {
snapshotFlow { snapshotFlow {
PagerScrollState( PagerScrollState(
page = pagerState.currentPage, page = pagerState.currentPage,
offset = pagerState.currentPageOffsetFraction offset = pagerState.currentPageOffsetFraction,
) )
}.collect { scrollState -> }
modelManagerViewModel.pagerScrollState.value = scrollState .collect { scrollState -> modelManagerViewModel.pagerScrollState.value = scrollState }
}
} }
// Scroll pager when selected model changes. // Scroll pager when selected model changes.
@ -147,9 +147,7 @@ fun ResponsePanel(
if (initializing) { if (initializing) {
Box( Box(
contentAlignment = Alignment.TopStart, contentAlignment = Alignment.TopStart,
modifier = modifier modifier = modifier.fillMaxSize().padding(horizontal = 16.dp),
.fillMaxSize()
.padding(horizontal = 16.dp)
) { ) {
MessageBodyLoading() MessageBodyLoading()
} }
@ -159,7 +157,7 @@ fun ResponsePanel(
Row( Row(
modifier = Modifier.fillMaxSize(), modifier = Modifier.fillMaxSize(),
horizontalArrangement = Arrangement.Center, horizontalArrangement = Arrangement.Center,
verticalAlignment = Alignment.CenterVertically verticalAlignment = Alignment.CenterVertically,
) { ) {
Text( Text(
"Response will appear here", "Response will appear here",
@ -170,11 +168,7 @@ fun ResponsePanel(
} }
// Response markdown. // Response markdown.
else { else {
Column( Column(modifier = modifier.padding(horizontal = 16.dp).padding(bottom = 4.dp)) {
modifier = modifier
.padding(horizontal = 16.dp)
.padding(bottom = 4.dp)
) {
// Response/benchmark switch. // Response/benchmark switch.
Row(modifier = Modifier.fillMaxWidth()) { Row(modifier = Modifier.fillMaxWidth()) {
PrimaryTabRow( PrimaryTabRow(
@ -182,66 +176,64 @@ fun ResponsePanel(
containerColor = Color.Transparent, containerColor = Color.Transparent,
) { ) {
OPTIONS.forEachIndexed { index, title -> OPTIONS.forEachIndexed { index, title ->
Tab(selected = selectedOptionIndex == index, onClick = { Tab(
selectedOptionIndex = index selected = selectedOptionIndex == index,
}, text = { onClick = { selectedOptionIndex = index },
Row( text = {
verticalAlignment = Alignment.CenterVertically, Row(
horizontalArrangement = Arrangement.spacedBy(4.dp) verticalAlignment = Alignment.CenterVertically,
) { horizontalArrangement = Arrangement.spacedBy(4.dp),
Icon( ) {
ICONS[index], Icon(
contentDescription = "", ICONS[index],
modifier = Modifier contentDescription = "",
.size(16.dp) modifier = Modifier.size(16.dp).alpha(0.7f),
.alpha(0.7f)
)
var curTitle = title
if (accelerator.isNotEmpty()) {
curTitle = "$curTitle on $accelerator"
}
val titleColor = MaterialTheme.colorScheme.primary
BasicText(
text = curTitle,
maxLines = 1,
color = { titleColor },
style = MaterialTheme.typography.bodyMedium,
autoSize = TextAutoSize.StepBased(
minFontSize = 9.sp,
maxFontSize = 14.sp,
stepSize = 1.sp
) )
) var curTitle = title
} if (accelerator.isNotEmpty()) {
}) curTitle = "$curTitle on $accelerator"
}
val titleColor = MaterialTheme.colorScheme.primary
BasicText(
text = curTitle,
maxLines = 1,
color = { titleColor },
style = MaterialTheme.typography.bodyMedium,
autoSize =
TextAutoSize.StepBased(
minFontSize = 9.sp,
maxFontSize = 14.sp,
stepSize = 1.sp,
),
)
}
},
)
} }
} }
} }
if (selectedOptionIndex == 0) { if (selectedOptionIndex == 0) {
Box( Box(contentAlignment = Alignment.BottomEnd, modifier = Modifier.weight(1f)) {
contentAlignment = Alignment.BottomEnd, Column(modifier = Modifier.fillMaxSize().verticalScroll(responseScrollState)) {
modifier = Modifier.weight(1f)
) {
Column(
modifier = Modifier
.fillMaxSize()
.verticalScroll(responseScrollState)
) {
MarkdownText( MarkdownText(
text = response, text = response,
modifier = Modifier.padding(top = 8.dp, bottom = 40.dp) modifier = Modifier.padding(top = 8.dp, bottom = 40.dp),
) )
} }
// Copy button. // Copy button.
IconButton( IconButton(
onClick = { onClick = {
val clipData = AnnotatedString(response) scope.launch {
clipboardManager.setText(clipData) val clipData = ClipData.newPlainText("response", response)
val clipEntry = ClipEntry(clipData = clipData)
clipboard.setClipEntry(clipEntry = clipEntry)
}
}, },
colors = IconButtonDefaults.iconButtonColors( colors =
containerColor = MaterialTheme.colorScheme.surfaceContainerHighest, IconButtonDefaults.iconButtonColors(
contentColor = MaterialTheme.colorScheme.primary, containerColor = MaterialTheme.colorScheme.surfaceContainerHighest,
), contentColor = MaterialTheme.colorScheme.primary,
),
) { ) {
Icon( Icon(
Icons.Outlined.ContentCopy, Icons.Outlined.ContentCopy,

View file

@ -44,36 +44,29 @@ import androidx.compose.ui.unit.dp
@Composable @Composable
fun SingleSelectButton( fun SingleSelectButton(
config: PromptTemplateSingleSelectInputEditor, config: PromptTemplateSingleSelectInputEditor,
onSelected: (String) -> Unit onSelected: (String) -> Unit,
) { ) {
var showMenu by remember { mutableStateOf(false) } var showMenu by remember { mutableStateOf(false) }
var selectedOption by remember { mutableStateOf(config.defaultOption) } var selectedOption by remember { mutableStateOf(config.defaultOption) }
LaunchedEffect(config) { LaunchedEffect(config) { selectedOption = config.defaultOption }
selectedOption = config.defaultOption
}
Box { Box {
Row( Row(
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(2.dp), horizontalArrangement = Arrangement.spacedBy(2.dp),
modifier = Modifier modifier =
.clip(RoundedCornerShape(8.dp)) Modifier.clip(RoundedCornerShape(8.dp))
.background(MaterialTheme.colorScheme.secondaryContainer) .background(MaterialTheme.colorScheme.secondaryContainer)
.clickable { .clickable { showMenu = true }
showMenu = true .padding(vertical = 4.dp, horizontal = 6.dp)
} .padding(start = 8.dp),
.padding(vertical = 4.dp, horizontal = 6.dp)
.padding(start = 8.dp)
) { ) {
Text("${config.label}: $selectedOption", style = MaterialTheme.typography.labelLarge) Text("${config.label}: $selectedOption", style = MaterialTheme.typography.labelLarge)
Icon(Icons.Rounded.ArrowDropDown, contentDescription = "") Icon(Icons.Rounded.ArrowDropDown, contentDescription = "")
} }
DropdownMenu( DropdownMenu(expanded = showMenu, onDismissRequest = { showMenu = false }) {
expanded = showMenu,
onDismissRequest = { showMenu = false }
) {
// Options // Options
for (option in config.options) { for (option in config.options) {
DropdownMenuItem( DropdownMenuItem(
@ -82,9 +75,9 @@ fun SingleSelectButton(
selectedOption = option selectedOption = option
showMenu = false showMenu = false
onSelected(option) onSelected(option)
} },
) )
} }
} }
} }
} }

Some files were not shown because too many files have changed in this diff Show more