Merge branch 'main' into main

This commit is contained in:
Fattire 2025-06-16 12:28:36 -07:00 committed by GitHub
commit 6cad1f60ee
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
128 changed files with 5679 additions and 7063 deletions

25
.github/workflows/build_android.yaml vendored Normal file
View file

@ -0,0 +1,25 @@
name: Build Android APK
on:
workflow_dispatch:
push:
branches: [ "main" ]
paths:
- 'Android/**'
pull_request:
branches: [ "main" ]
paths:
- 'Android/**'
jobs:
build_apk:
name: Build Android APK
runs-on: ubuntu-latest
defaults:
run:
working-directory: ./Android/src
steps:
- name: Checkout the source code
uses: actions/checkout@v3
- name: Build
run: ./gradlew assembleRelease

16
Android/.gitignore vendored
View file

@ -1,3 +1,19 @@
# @license
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Gradle files
.gradle/
build/

View file

@ -1 +1 @@
# AI Edge Gallery (Android)
# Google AI Edge Gallery (Android)

View file

@ -1,4 +1,20 @@
# @license
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
*.iml
.gradle
/local.properties
/.idea/caches

View file

@ -1,2 +1,18 @@
# @license
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
/build
/release

View file

@ -19,6 +19,7 @@ plugins {
alias(libs.plugins.kotlin.android)
alias(libs.plugins.kotlin.compose)
alias(libs.plugins.kotlin.serialization)
alias(libs.plugins.protobuf)
}
android {
@ -26,12 +27,11 @@ android {
compileSdk = 36
defaultConfig {
// Don't change to com.google.ai.edge.gallery yet.
applicationId = "com.google.aiedge.gallery"
minSdk = 26
targetSdk = 36
versionCode = 1
versionName = "1.0.3"
versionName = "1.0.4"
// Needed for HuggingFace auth workflows.
manifestPlaceholders["appAuthRedirectScheme"] = "com.google.ai.edge.gallery.oauth"
@ -42,10 +42,7 @@ android {
buildTypes {
release {
isMinifyEnabled = false
proguardFiles(
getDefaultProguardFile("proguard-android-optimize.txt"),
"proguard-rules.pro"
)
proguardFiles(getDefaultProguardFile("proguard-android-optimize.txt"), "proguard-rules.pro")
signingConfig = signingConfigs.getByName("debug")
}
}
@ -76,7 +73,7 @@ dependencies {
implementation(libs.kotlinx.serialization.json)
implementation(libs.material.icon.extended)
implementation(libs.androidx.work.runtime)
implementation(libs.androidx.datastore.preferences)
implementation(libs.androidx.datastore)
implementation(libs.com.google.code.gson)
implementation(libs.androidx.lifecycle.process)
implementation(libs.mediapipe.tasks.text)
@ -93,6 +90,7 @@ dependencies {
implementation(libs.camerax.view)
implementation(libs.openid.appauth)
implementation(libs.androidx.splashscreen)
implementation(libs.protobuf.javalite)
testImplementation(libs.junit)
androidTestImplementation(libs.androidx.junit)
androidTestImplementation(libs.androidx.espresso.core)
@ -101,4 +99,9 @@ dependencies {
debugImplementation(libs.androidx.ui.tooling)
debugImplementation(libs.androidx.ui.test.manifest)
implementation(libs.material)
}
protobuf {
protoc { artifact = "com.google.protobuf:protoc:4.26.1" }
generateProtoTasks { all().forEach { it.plugins { create("java") { option("lite") } } } }
}

View file

@ -1,21 +0,0 @@
# Add project specific ProGuard rules here.
# You can control the set of applied configuration files using the
# proguardFiles setting in build.gradle.
#
# For more details, see
# http://developer.android.com/guide/developing/tools/proguard.html
# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
# public *;
#}
# Uncomment this to preserve the line number information for
# debugging stack traces.
#-keepattributes SourceFile,LineNumberTable
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile

View file

@ -16,13 +16,20 @@
-->
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.ai.edge.gallery"
xmlns:tools="http://schemas.android.com/tools">
<uses-sdk
android:minSdkVersion="26"
android:compileSdkVersion ="35"
android:targetSdkVersion="35" />
<uses-permission android:name="android.permission.CAMERA" />
<uses-permission android:name="android.permission.FOREGROUND_SERVICE"/>
<uses-permission android:name="android.permission.FOREGROUND_SERVICE_DATA_SYNC"/>
<uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.POST_NOTIFICATIONS" />
<uses-permission android:name="android.permission.CAMERA" />
<uses-permission android:name="android.permission.WAKE_LOCK"/>
<uses-feature
android:name="android.hardware.camera"
@ -63,6 +70,7 @@
</intent-filter>
</activity>
<!-- For LLM inference engine -->
<uses-native-library
android:name="libOpenCL.so"
android:required="false" />

View file

@ -14,167 +14,15 @@
* limitations under the License.
*/
@file:OptIn(ExperimentalMaterial3Api::class)
package com.google.ai.edge.gallery
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.text.BasicText
import androidx.compose.foundation.text.TextAutoSize
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.ArrowBack
import androidx.compose.material.icons.rounded.Refresh
import androidx.compose.material.icons.rounded.Settings
import androidx.compose.material3.CenterAlignedTopAppBar
import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.material3.TopAppBarScrollBehavior
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import androidx.navigation.NavHostController
import androidx.navigation.compose.rememberNavController
import com.google.ai.edge.gallery.data.AppBarAction
import com.google.ai.edge.gallery.data.AppBarActionType
import com.google.ai.edge.gallery.ui.navigation.GalleryNavHost
/**
* Top level composable representing the main screen of the application.
*/
/** Top level composable representing the main screen of the application. */
@Composable
fun GalleryApp(navController: NavHostController = rememberNavController()) {
GalleryNavHost(navController = navController)
}
/**
* The top app bar.
*/
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun GalleryTopAppBar(
title: String,
modifier: Modifier = Modifier,
leftAction: AppBarAction? = null,
rightAction: AppBarAction? = null,
scrollBehavior: TopAppBarScrollBehavior? = null,
subtitle: String = "",
) {
val titleColor = MaterialTheme.colorScheme.primary
CenterAlignedTopAppBar(
title = {
Column(horizontalAlignment = Alignment.CenterHorizontally) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(12.dp)
) {
if (title == stringResource(R.string.app_name)) {
Icon(
painterResource(R.drawable.logo),
modifier = Modifier.size(20.dp),
contentDescription = "",
tint = Color.Unspecified,
)
}
BasicText(
text = title,
maxLines = 1,
color = { titleColor },
style = MaterialTheme.typography.titleLarge.copy(fontWeight = FontWeight.SemiBold),
autoSize = TextAutoSize.StepBased(
minFontSize = 14.sp,
maxFontSize = 22.sp,
stepSize = 1.sp
)
)
}
if (subtitle.isNotEmpty()) {
Text(
subtitle,
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.secondary
)
}
}
},
modifier = modifier,
scrollBehavior = scrollBehavior,
// The button at the left.
navigationIcon = {
when (leftAction?.actionType) {
AppBarActionType.NAVIGATE_UP -> {
IconButton(onClick = leftAction.actionFn) {
Icon(
imageVector = Icons.AutoMirrored.Rounded.ArrowBack,
contentDescription = "",
)
}
}
AppBarActionType.REFRESH_MODELS -> {
IconButton(onClick = leftAction.actionFn) {
Icon(
imageVector = Icons.Rounded.Refresh,
contentDescription = "",
tint = MaterialTheme.colorScheme.secondary
)
}
}
AppBarActionType.REFRESHING_MODELS -> {
CircularProgressIndicator(
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
strokeWidth = 3.dp,
modifier = Modifier
.padding(start = 16.dp)
.size(20.dp)
)
}
else -> {}
}
},
// The "action" component at the right.
actions = {
when (rightAction?.actionType) {
// Click an icon to open "app setting".
AppBarActionType.APP_SETTING -> {
IconButton(onClick = rightAction.actionFn) {
Icon(
imageVector = Icons.Rounded.Settings,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary
)
}
}
AppBarActionType.MODEL_SELECTOR -> {
Text("ms")
}
// Click a button to navigate up.
AppBarActionType.NAVIGATE_UP -> {
TextButton(onClick = rightAction.actionFn) {
Text("Done")
}
}
else -> {}
}
}
)
}

View file

@ -0,0 +1,157 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
@file:OptIn(ExperimentalMaterial3Api::class)
package com.google.ai.edge.gallery
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.text.BasicText
import androidx.compose.foundation.text.TextAutoSize
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.ArrowBack
import androidx.compose.material.icons.rounded.Refresh
import androidx.compose.material.icons.rounded.Settings
import androidx.compose.material3.CenterAlignedTopAppBar
import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.material3.TopAppBarScrollBehavior
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import com.google.ai.edge.gallery.data.AppBarAction
import com.google.ai.edge.gallery.data.AppBarActionType
/** The top app bar. */
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun GalleryTopAppBar(
title: String,
modifier: Modifier = Modifier,
leftAction: AppBarAction? = null,
rightAction: AppBarAction? = null,
scrollBehavior: TopAppBarScrollBehavior? = null,
subtitle: String = "",
) {
val titleColor = MaterialTheme.colorScheme.primary
CenterAlignedTopAppBar(
title = {
Column(horizontalAlignment = Alignment.CenterHorizontally) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(12.dp),
) {
if (title == stringResource(R.string.app_name)) {
Icon(
painterResource(R.drawable.logo),
modifier = Modifier.size(20.dp),
contentDescription = "",
tint = Color.Unspecified,
)
}
BasicText(
text = title,
maxLines = 1,
color = { titleColor },
style = MaterialTheme.typography.titleLarge.copy(fontWeight = FontWeight.SemiBold),
autoSize =
TextAutoSize.StepBased(minFontSize = 14.sp, maxFontSize = 22.sp, stepSize = 1.sp),
)
}
if (subtitle.isNotEmpty()) {
Text(
subtitle,
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.secondary,
)
}
}
},
modifier = modifier,
scrollBehavior = scrollBehavior,
// The button at the left.
navigationIcon = {
when (leftAction?.actionType) {
AppBarActionType.NAVIGATE_UP -> {
IconButton(onClick = leftAction.actionFn) {
Icon(imageVector = Icons.AutoMirrored.Rounded.ArrowBack, contentDescription = "")
}
}
AppBarActionType.REFRESH_MODELS -> {
IconButton(onClick = leftAction.actionFn) {
Icon(
imageVector = Icons.Rounded.Refresh,
contentDescription = "",
tint = MaterialTheme.colorScheme.secondary,
)
}
}
AppBarActionType.REFRESHING_MODELS -> {
CircularProgressIndicator(
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
strokeWidth = 3.dp,
modifier = Modifier.padding(start = 16.dp).size(20.dp),
)
}
else -> {}
}
},
// The "action" component at the right.
actions = {
when (rightAction?.actionType) {
// Click an icon to open "app setting".
AppBarActionType.APP_SETTING -> {
IconButton(onClick = rightAction.actionFn) {
Icon(
imageVector = Icons.Rounded.Settings,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary,
)
}
}
AppBarActionType.MODEL_SELECTOR -> {
Text("ms")
}
// Click a button to navigate up.
AppBarActionType.NAVIGATE_UP -> {
TextButton(onClick = rightAction.actionFn) { Text("Done") }
}
else -> {}
}
},
)
}

View file

@ -18,15 +18,35 @@ package com.google.ai.edge.gallery
import android.app.Application
import android.content.Context
import androidx.datastore.core.CorruptionException
import androidx.datastore.core.DataStore
import androidx.datastore.preferences.core.Preferences
import androidx.datastore.preferences.preferencesDataStore
import androidx.datastore.core.Serializer
import androidx.datastore.dataStore
import com.google.ai.edge.gallery.common.writeLaunchInfo
import com.google.ai.edge.gallery.data.AppContainer
import com.google.ai.edge.gallery.data.DefaultAppContainer
import com.google.ai.edge.gallery.ui.common.writeLaunchInfo
import com.google.ai.edge.gallery.proto.Settings
import com.google.ai.edge.gallery.ui.theme.ThemeSettings
import com.google.protobuf.InvalidProtocolBufferException
import java.io.InputStream
import java.io.OutputStream
private val Context.dataStore: DataStore<Preferences> by preferencesDataStore(name = "app_gallery_preferences")
object SettingsSerializer : Serializer<Settings> {
override val defaultValue: Settings = Settings.getDefaultInstance()
override suspend fun readFrom(input: InputStream): Settings {
try {
return Settings.parseFrom(input)
} catch (exception: InvalidProtocolBufferException) {
throw CorruptionException("Cannot read proto.", exception)
}
}
override suspend fun writeTo(t: Settings, output: OutputStream) = t.writeTo(output)
}
private val Context.dataStore: DataStore<Settings> by
dataStore(fileName = "settings.pb", serializer = SettingsSerializer)
class GalleryApplication : Application() {
/** AppContainer instance used by the rest of classes to obtain dependencies */
@ -35,11 +55,10 @@ class GalleryApplication : Application() {
override fun onCreate() {
super.onCreate()
writeLaunchInfo(context = this)
container = DefaultAppContainer(this, dataStore)
// Load theme.
ThemeSettings.themeOverride.value = container.dataStoreRepository.readThemeOverride()
// Load saved theme.
ThemeSettings.themeOverride.value = container.dataStoreRepository.readTheme()
}
}
}

View file

@ -16,29 +16,16 @@
package com.google.ai.edge.gallery
import androidx.lifecycle.DefaultLifecycleObserver
import androidx.lifecycle.LifecycleOwner
import androidx.lifecycle.ProcessLifecycleOwner
interface AppLifecycleProvider {
val isAppInForeground: Boolean
var isAppInForeground: Boolean
}
class GalleryLifecycleProvider : AppLifecycleProvider, DefaultLifecycleObserver {
class GalleryLifecycleProvider : AppLifecycleProvider {
private var _isAppInForeground = false
init {
ProcessLifecycleOwner.get().lifecycle.addObserver(this)
}
override val isAppInForeground: Boolean
override var isAppInForeground: Boolean
get() = _isAppInForeground
override fun onResume(owner: LifecycleOwner) {
_isAppInForeground = true
}
override fun onPause(owner: LifecycleOwner) {
_isAppInForeground = false
}
set(value) {
_isAppInForeground = value
}
}

View file

@ -1,12 +0,0 @@
package com.google.ai.edge.gallery
import android.app.Service
import android.content.Intent
import android.os.IBinder
// TODO(jingjin): implement foreground service.
class GalleryService : Service() {
override fun onBind(p0: Intent?): IBinder? {
return null
}
}

View file

@ -16,6 +16,7 @@
package com.google.ai.edge.gallery
import android.os.Build
import android.os.Bundle
import androidx.activity.ComponentActivity
import androidx.activity.compose.setContent
@ -32,14 +33,11 @@ class MainActivity : ComponentActivity() {
super.onCreate(savedInstanceState)
enableEdgeToEdge()
setContent {
GalleryTheme {
Surface(
modifier = Modifier.fillMaxSize()
) {
GalleryApp()
}
}
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
// Fix for three-button nav not properly going edge-to-edge.
// See: https://issuetracker.google.com/issues/298296168
window.isNavigationBarContrastEnforced = false
}
setContent { GalleryTheme { Surface(modifier = Modifier.fillMaxSize()) { GalleryApp() } } }
}
}

View file

@ -0,0 +1,27 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.common
import androidx.compose.ui.graphics.Color
interface LatencyProvider {
val latencyMs: Float
}
data class Classification(val label: String, val score: Float, val color: Color)
data class JsonObjAndTextContent<T>(val jsonObj: T, val textContent: String)

View file

@ -0,0 +1,114 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.common
import android.content.Context
import android.util.Log
import com.google.gson.Gson
import com.google.gson.reflect.TypeToken
import java.io.File
import java.net.HttpURLConnection
import java.net.URL
data class LaunchInfo(val ts: Long)
private const val TAG = "AGUtils"
private const val LAUNCH_INFO_FILE_NAME = "launch_info"
private const val START_THINKING = "***Thinking...***"
private const val DONE_THINKING = "***Done thinking***"
fun readLaunchInfo(context: Context): LaunchInfo? {
try {
val gson = Gson()
val type = object : TypeToken<LaunchInfo>() {}.type
val file = File(context.getExternalFilesDir(null), LAUNCH_INFO_FILE_NAME)
val content = file.readText()
return gson.fromJson(content, type)
} catch (e: Exception) {
Log.e(TAG, "Failed to read launch info", e)
return null
}
}
fun cleanUpMediapipeTaskErrorMessage(message: String): String {
val index = message.indexOf("=== Source Location Trace")
if (index >= 0) {
return message.substring(0, index)
}
return message
}
fun processLlmResponse(response: String): String {
// Add "thinking" and "done thinking" around the thinking content.
var newContent =
response.replace("<think>", "$START_THINKING\n").replace("</think>", "\n$DONE_THINKING")
// Remove empty thinking content.
val endThinkingIndex = newContent.indexOf(DONE_THINKING)
if (endThinkingIndex >= 0) {
val thinkingContent =
newContent
.substring(0, endThinkingIndex + DONE_THINKING.length)
.replace(START_THINKING, "")
.replace(DONE_THINKING, "")
if (thinkingContent.isBlank()) {
newContent = newContent.substring(endThinkingIndex + DONE_THINKING.length)
}
}
newContent = newContent.replace("\\n", "\n")
return newContent
}
fun writeLaunchInfo(context: Context) {
try {
val gson = Gson()
val launchInfo = LaunchInfo(ts = System.currentTimeMillis())
val jsonString = gson.toJson(launchInfo)
val file = File(context.getExternalFilesDir(null), LAUNCH_INFO_FILE_NAME)
file.writeText(jsonString)
} catch (e: Exception) {
Log.e(TAG, "Failed to write launch info", e)
}
}
inline fun <reified T> getJsonResponse(url: String): JsonObjAndTextContent<T>? {
try {
val connection = URL(url).openConnection() as HttpURLConnection
connection.requestMethod = "GET"
connection.connect()
val responseCode = connection.responseCode
if (responseCode == HttpURLConnection.HTTP_OK) {
val inputStream = connection.inputStream
val response = inputStream.bufferedReader().use { it.readText() }
val gson = Gson()
val type = object : TypeToken<T>() {}.type
val jsonObj = gson.fromJson<T>(response, type)
return JsonObjAndTextContent(jsonObj = jsonObj, textContent = response)
} else {
Log.e("AGUtils", "HTTP error: $responseCode")
}
} catch (e: Exception) {
Log.e("AGUtils", "Error when getting json response: ${e.message}")
e.printStackTrace()
}
return null
}

View file

@ -27,4 +27,4 @@ enum class AppBarActionType {
REFRESHING_MODELS,
}
class AppBarAction(val actionType: AppBarActionType, val actionFn: () -> Unit)
class AppBarAction(val actionType: AppBarActionType, val actionFn: () -> Unit)

View file

@ -18,9 +18,9 @@ package com.google.ai.edge.gallery.data
import android.content.Context
import androidx.datastore.core.DataStore
import androidx.datastore.preferences.core.Preferences
import com.google.ai.edge.gallery.GalleryLifecycleProvider
import com.google.ai.edge.gallery.AppLifecycleProvider
import com.google.ai.edge.gallery.GalleryLifecycleProvider
import com.google.ai.edge.gallery.proto.Settings
/**
* App container for Dependency injection.
@ -39,9 +39,9 @@ interface AppContainer {
*
* This class provides concrete implementations for the application's dependencies,
*/
class DefaultAppContainer(ctx: Context, dataStore: DataStore<Preferences>) : AppContainer {
class DefaultAppContainer(ctx: Context, dataStore: DataStore<Settings>) : AppContainer {
override val context = ctx
override val lifecycleProvider = GalleryLifecycleProvider()
override val dataStoreRepository = DefaultDataStoreRepository(dataStore)
override val downloadRepository = DefaultDownloadRepository(ctx, lifecycleProvider)
}
}

View file

@ -16,11 +16,13 @@
package com.google.ai.edge.gallery.data
import kotlin.math.abs
/**
* The types of configuration editors available.
*
* This enum defines the different UI components used to edit configuration values.
* Each type corresponds to a specific editor widget, such as a slider or a switch.
* This enum defines the different UI components used to edit configuration values. Each type
* corresponds to a specific editor widget, such as a slider or a switch.
*/
enum class ConfigEditorType {
LABEL,
@ -29,9 +31,7 @@ enum class ConfigEditorType {
DROPDOWN,
}
/**
* The data types of configuration values.
*/
/** The data types of configuration values. */
enum class ValueType {
INT,
FLOAT,
@ -40,6 +40,28 @@ enum class ValueType {
BOOLEAN,
}
enum class ConfigKey(val label: String) {
MAX_TOKENS("Max tokens"),
TOPK("TopK"),
TOPP("TopP"),
TEMPERATURE("Temperature"),
DEFAULT_MAX_TOKENS("Default max tokens"),
DEFAULT_TOPK("Default TopK"),
DEFAULT_TOPP("Default TopP"),
DEFAULT_TEMPERATURE("Default temperature"),
SUPPORT_IMAGE("Support image"),
MAX_RESULT_COUNT("Max result count"),
USE_GPU("Use GPU"),
ACCELERATOR("Choose accelerator"),
COMPATIBLE_ACCELERATORS("Compatible accelerators"),
WARM_UP_ITERATIONS("Warm up iterations"),
BENCHMARK_ITERATIONS("Benchmark iterations"),
ITERATIONS("Iterations"),
THEME("Theme"),
NAME("Name"),
MODEL_TYPE("Model type"),
}
/**
* Base class for configuration settings.
*
@ -58,18 +80,14 @@ open class Config(
open val needReinitialization: Boolean = true,
)
/**
* Configuration setting for a label.
*/
class LabelConfig(
override val key: ConfigKey,
override val defaultValue: String = "",
) : Config(
type = ConfigEditorType.LABEL,
key = key,
defaultValue = defaultValue,
valueType = ValueType.STRING
)
/** Configuration setting for a label. */
class LabelConfig(override val key: ConfigKey, override val defaultValue: String = "") :
Config(
type = ConfigEditorType.LABEL,
key = key,
defaultValue = defaultValue,
valueType = ValueType.STRING,
)
/**
* Configuration setting for a number slider.
@ -92,32 +110,122 @@ class NumberSliderConfig(
valueType = valueType,
)
/**
* Configuration setting for a boolean switch.
*/
/** Configuration setting for a boolean switch. */
class BooleanSwitchConfig(
override val key: ConfigKey,
override val defaultValue: Boolean,
override val needReinitialization: Boolean = true,
) : Config(
type = ConfigEditorType.BOOLEAN_SWITCH,
key = key,
defaultValue = defaultValue,
valueType = ValueType.BOOLEAN,
)
) :
Config(
type = ConfigEditorType.BOOLEAN_SWITCH,
key = key,
defaultValue = defaultValue,
valueType = ValueType.BOOLEAN,
)
/**
* Configuration setting for a dropdown.
*/
/** Configuration setting for a dropdown. */
class SegmentedButtonConfig(
override val key: ConfigKey,
override val defaultValue: String,
val options: List<String>,
val allowMultiple: Boolean = false,
) : Config(
type = ConfigEditorType.DROPDOWN,
key = key,
defaultValue = defaultValue,
// The emitted value will be comma-separated labels when allowMultiple=true.
valueType = ValueType.STRING,
)
) :
Config(
type = ConfigEditorType.DROPDOWN,
key = key,
defaultValue = defaultValue,
// The emitted value will be comma-separated labels when allowMultiple=true.
valueType = ValueType.STRING,
)
fun convertValueToTargetType(value: Any, valueType: ValueType): Any {
return when (valueType) {
ValueType.INT ->
when (value) {
is Int -> value
is Float -> value.toInt()
is Double -> value.toInt()
is String -> value.toIntOrNull() ?: ""
is Boolean -> if (value) 1 else 0
else -> ""
}
ValueType.FLOAT ->
when (value) {
is Int -> value.toFloat()
is Float -> value
is Double -> value.toFloat()
is String -> value.toFloatOrNull() ?: ""
is Boolean -> if (value) 1f else 0f
else -> ""
}
ValueType.DOUBLE ->
when (value) {
is Int -> value.toDouble()
is Float -> value.toDouble()
is Double -> value
is String -> value.toDoubleOrNull() ?: ""
is Boolean -> if (value) 1.0 else 0.0
else -> ""
}
ValueType.BOOLEAN ->
when (value) {
is Int -> value == 0
is Boolean -> value
is Float -> abs(value) > 1e-6
is Double -> abs(value) > 1e-6
is String -> value.isNotEmpty()
else -> false
}
ValueType.STRING -> value.toString()
}
}
fun createLlmChatConfigs(
defaultMaxToken: Int = DEFAULT_MAX_TOKEN,
defaultTopK: Int = DEFAULT_TOPK,
defaultTopP: Float = DEFAULT_TOPP,
defaultTemperature: Float = DEFAULT_TEMPERATURE,
accelerators: List<Accelerator> = DEFAULT_ACCELERATORS,
): List<Config> {
return listOf(
LabelConfig(key = ConfigKey.MAX_TOKENS, defaultValue = "$defaultMaxToken"),
NumberSliderConfig(
key = ConfigKey.TOPK,
sliderMin = 5f,
sliderMax = 40f,
defaultValue = defaultTopK.toFloat(),
valueType = ValueType.INT,
),
NumberSliderConfig(
key = ConfigKey.TOPP,
sliderMin = 0.0f,
sliderMax = 1.0f,
defaultValue = defaultTopP,
valueType = ValueType.FLOAT,
),
NumberSliderConfig(
key = ConfigKey.TEMPERATURE,
sliderMin = 0.0f,
sliderMax = 2.0f,
defaultValue = defaultTemperature,
valueType = ValueType.FLOAT,
),
SegmentedButtonConfig(
key = ConfigKey.ACCELERATOR,
defaultValue = accelerators[0].label,
options = accelerators.map { it.label },
),
)
}
fun getConfigValueString(value: Any, config: Config): String {
var strNewValue = "$value"
if (config.valueType == ValueType.FLOAT) {
strNewValue = "%.2f".format(value)
}
return strNewValue
}

View file

@ -16,64 +16,55 @@
package com.google.ai.edge.gallery.data
import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
import kotlinx.serialization.SerializationException
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.buildClassSerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.json.JsonDecoder
import kotlinx.serialization.json.JsonPrimitive
@Serializable(with = ConfigValueSerializer::class)
// @Serializable(with = ConfigValueSerializer::class)
sealed class ConfigValue {
@Serializable
// @Serializable
data class IntValue(val value: Int) : ConfigValue()
@Serializable
// @Serializable
data class FloatValue(val value: Float) : ConfigValue()
@Serializable
// @Serializable
data class StringValue(val value: String) : ConfigValue()
}
/**
* Custom serializer for the ConfigValue class.
*
* This object implements the KSerializer interface to provide custom serialization and
* deserialization logic for the ConfigValue class. It handles different types of ConfigValue
* (IntValue, FloatValue, StringValue) and supports JSON format.
*/
object ConfigValueSerializer : KSerializer<ConfigValue> {
override val descriptor: SerialDescriptor = buildClassSerialDescriptor("ConfigValue")
// /**
// * Custom serializer for the ConfigValue class.
// *
// * This object implements the KSerializer interface to provide custom serialization and
// * deserialization logic for the ConfigValue class. It handles different types of ConfigValue
// * (IntValue, FloatValue, StringValue) and supports JSON format.
// */
// object ConfigValueSerializer : KSerializer<ConfigValue> {
// override val descriptor: SerialDescriptor = buildClassSerialDescriptor("ConfigValue")
override fun serialize(encoder: Encoder, value: ConfigValue) {
when (value) {
is ConfigValue.IntValue -> encoder.encodeInt(value.value)
is ConfigValue.FloatValue -> encoder.encodeFloat(value.value)
is ConfigValue.StringValue -> encoder.encodeString(value.value)
}
}
// override fun serialize(encoder: Encoder, value: ConfigValue) {
// when (value) {
// is ConfigValue.IntValue -> encoder.encodeInt(value.value)
// is ConfigValue.FloatValue -> encoder.encodeFloat(value.value)
// is ConfigValue.StringValue -> encoder.encodeString(value.value)
// }
// }
override fun deserialize(decoder: Decoder): ConfigValue {
val input = decoder as? JsonDecoder
?: throw SerializationException("This serializer only works with Json")
return when (val element = input.decodeJsonElement()) {
is JsonPrimitive -> {
if (element.isString) {
ConfigValue.StringValue(element.content)
} else if (element.content.contains('.')) {
ConfigValue.FloatValue(element.content.toFloat())
} else {
ConfigValue.IntValue(element.content.toInt())
}
}
// override fun deserialize(decoder: Decoder): ConfigValue {
// val input =
// decoder as? JsonDecoder
// ?: throw SerializationException("This serializer only works with Json")
// return when (val element = input.decodeJsonElement()) {
// is JsonPrimitive -> {
// if (element.isString) {
// ConfigValue.StringValue(element.content)
// } else if (element.content.contains('.')) {
// ConfigValue.FloatValue(element.content.toFloat())
// } else {
// ConfigValue.IntValue(element.content.toInt())
// }
// }
else -> throw SerializationException("Expected JsonPrimitive")
}
}
}
// else -> throw SerializationException("Expected JsonPrimitive")
// }
// }
// }
fun getIntConfigValue(configValue: ConfigValue?, default: Int): Int {
if (configValue == null) {

View file

@ -34,3 +34,13 @@ const val KEY_MODEL_EXTRA_DATA_DOWNLOAD_FILE_NAMES = "KEY_MODEL_EXTRA_DATA_DOWNL
const val KEY_MODEL_IS_ZIP = "KEY_MODEL_IS_ZIP"
const val KEY_MODEL_UNZIPPED_DIR = "KEY_MODEL_UNZIPPED_DIR"
const val KEY_MODEL_START_UNZIPPING = "KEY_MODEL_START_UNZIPPING"
// Default values for LLM models.
const val DEFAULT_MAX_TOKEN = 1024
const val DEFAULT_TOPK = 40
const val DEFAULT_TOPP = 0.9f
const val DEFAULT_TEMPERATURE = 1.0f
val DEFAULT_ACCELERATORS = listOf(Accelerator.GPU)
// Max number of images allowed in a "ask image" session.
const val MAX_IMAGE_COUNT = 10

View file

@ -16,231 +16,109 @@
package com.google.ai.edge.gallery.data
import android.security.keystore.KeyGenParameterSpec
import android.security.keystore.KeyProperties
import android.util.Base64
import androidx.datastore.core.DataStore
import androidx.datastore.preferences.core.Preferences
import androidx.datastore.preferences.core.edit
import androidx.datastore.preferences.core.longPreferencesKey
import androidx.datastore.preferences.core.stringPreferencesKey
import com.google.gson.Gson
import com.google.gson.reflect.TypeToken
import com.google.ai.edge.gallery.ui.theme.THEME_AUTO
import com.google.ai.edge.gallery.proto.AccessTokenData
import com.google.ai.edge.gallery.proto.ImportedModel
import com.google.ai.edge.gallery.proto.Settings
import com.google.ai.edge.gallery.proto.Theme
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.runBlocking
import java.security.KeyStore
import javax.crypto.Cipher
import javax.crypto.KeyGenerator
import javax.crypto.SecretKey
data class AccessTokenData(
val accessToken: String,
val refreshToken: String,
val expiresAtMs: Long
)
// TODO(b/423700720): Change to async (suspend) functions
interface DataStoreRepository {
fun saveTextInputHistory(history: List<String>)
fun readTextInputHistory(): List<String>
fun saveThemeOverride(theme: String)
fun readThemeOverride(): String
fun saveTheme(theme: Theme)
fun readTheme(): Theme
fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long)
fun clearAccessTokenData()
fun readAccessTokenData(): AccessTokenData?
fun saveImportedModels(importedModels: List<ImportedModelInfo>)
fun readImportedModels(): List<ImportedModelInfo>
fun saveImportedModels(importedModels: List<ImportedModel>)
fun readImportedModels(): List<ImportedModel>
}
/**
* Repository for managing data using DataStore, with JSON serialization.
*
* This class provides methods to read, add, remove, and clear data stored in DataStore,
* using JSON serialization for complex objects. It uses Gson for serializing and deserializing
* lists of objects to/from JSON strings.
*
* DataStore is used to persist data as JSON strings under specified keys.
*/
class DefaultDataStoreRepository(
private val dataStore: DataStore<Preferences>
) :
DataStoreRepository {
private object PreferencesKeys {
val TEXT_INPUT_HISTORY = stringPreferencesKey("text_input_history")
val THEME_OVERRIDE = stringPreferencesKey("theme_override")
val ENCRYPTED_ACCESS_TOKEN = stringPreferencesKey("encrypted_access_token")
// Store Initialization Vector
val ACCESS_TOKEN_IV = stringPreferencesKey("access_token_iv")
val ENCRYPTED_REFRESH_TOKEN = stringPreferencesKey("encrypted_refresh_token")
// Store Initialization Vector
val REFRESH_TOKEN_IV = stringPreferencesKey("refresh_token_iv")
val ACCESS_TOKEN_EXPIRES_AT = longPreferencesKey("access_token_expires_at")
// Data for all imported models.
val IMPORTED_MODELS = stringPreferencesKey("imported_models")
}
private val keystoreAlias: String = "com_google_aiedge_gallery_access_token_key"
private val keyStore: KeyStore = KeyStore.getInstance("AndroidKeyStore").apply { load(null) }
/** Repository for managing data using Proto DataStore. */
class DefaultDataStoreRepository(private val dataStore: DataStore<Settings>) : DataStoreRepository {
override fun saveTextInputHistory(history: List<String>) {
runBlocking {
dataStore.edit { preferences ->
val gson = Gson()
val jsonString = gson.toJson(history)
preferences[PreferencesKeys.TEXT_INPUT_HISTORY] = jsonString
dataStore.updateData { settings ->
settings.toBuilder().clearTextInputHistory().addAllTextInputHistory(history).build()
}
}
}
override fun readTextInputHistory(): List<String> {
return runBlocking {
val preferences = dataStore.data.first()
getTextInputHistory(preferences)
val settings = dataStore.data.first()
settings.textInputHistoryList
}
}
override fun saveThemeOverride(theme: String) {
override fun saveTheme(theme: Theme) {
runBlocking {
dataStore.edit { preferences ->
preferences[PreferencesKeys.THEME_OVERRIDE] = theme
}
dataStore.updateData { settings -> settings.toBuilder().setTheme(theme).build() }
}
}
override fun readThemeOverride(): String {
override fun readTheme(): Theme {
return runBlocking {
val preferences = dataStore.data.first()
preferences[PreferencesKeys.THEME_OVERRIDE] ?: THEME_AUTO
val settings = dataStore.data.first()
val curTheme = settings.theme
// Use "auto" as the default theme.
if (curTheme == Theme.THEME_UNSPECIFIED) Theme.THEME_AUTO else curTheme
}
}
override fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long) {
runBlocking {
val (encryptedAccessToken, accessTokenIv) = encrypt(accessToken)
val (encryptedRefreshToken, refreshTokenIv) = encrypt(refreshToken)
dataStore.edit { preferences ->
preferences[PreferencesKeys.ENCRYPTED_ACCESS_TOKEN] = encryptedAccessToken
preferences[PreferencesKeys.ACCESS_TOKEN_IV] = accessTokenIv
preferences[PreferencesKeys.ENCRYPTED_REFRESH_TOKEN] = encryptedRefreshToken
preferences[PreferencesKeys.REFRESH_TOKEN_IV] = refreshTokenIv
preferences[PreferencesKeys.ACCESS_TOKEN_EXPIRES_AT] = expiresAt
dataStore.updateData { settings ->
settings
.toBuilder()
.setAccessTokenData(
AccessTokenData.newBuilder()
.setAccessToken(accessToken)
.setRefreshToken(refreshToken)
.setExpiresAtMs(expiresAt)
.build()
)
.build()
}
}
}
override fun clearAccessTokenData() {
return runBlocking {
dataStore.edit { preferences ->
preferences.remove(PreferencesKeys.ENCRYPTED_ACCESS_TOKEN)
preferences.remove(PreferencesKeys.ACCESS_TOKEN_IV)
preferences.remove(PreferencesKeys.ENCRYPTED_REFRESH_TOKEN)
preferences.remove(PreferencesKeys.REFRESH_TOKEN_IV)
preferences.remove(PreferencesKeys.ACCESS_TOKEN_EXPIRES_AT)
}
runBlocking {
dataStore.updateData { settings -> settings.toBuilder().clearAccessTokenData().build() }
}
}
override fun readAccessTokenData(): AccessTokenData? {
return runBlocking {
val preferences = dataStore.data.first()
val encryptedAccessToken = preferences[PreferencesKeys.ENCRYPTED_ACCESS_TOKEN]
val encryptedRefreshToken = preferences[PreferencesKeys.ENCRYPTED_REFRESH_TOKEN]
val accessTokenIv = preferences[PreferencesKeys.ACCESS_TOKEN_IV]
val refreshTokenIv = preferences[PreferencesKeys.REFRESH_TOKEN_IV]
val expiresAt = preferences[PreferencesKeys.ACCESS_TOKEN_EXPIRES_AT]
var decryptedAccessToken: String? = null
var decryptedRefreshToken: String? = null
if (encryptedAccessToken != null && accessTokenIv != null) {
decryptedAccessToken = decrypt(encryptedAccessToken, accessTokenIv)
}
if (encryptedRefreshToken != null && refreshTokenIv != null) {
decryptedRefreshToken = decrypt(encryptedRefreshToken, refreshTokenIv)
}
if (decryptedAccessToken != null && decryptedRefreshToken != null && expiresAt != null) {
AccessTokenData(decryptedAccessToken, decryptedRefreshToken, expiresAt)
} else {
null
}
val settings = dataStore.data.first()
settings.accessTokenData
}
}
override fun saveImportedModels(importedModels: List<ImportedModelInfo>) {
override fun saveImportedModels(importedModels: List<ImportedModel>) {
runBlocking {
dataStore.edit { preferences ->
val gson = Gson()
val jsonString = gson.toJson(importedModels)
preferences[PreferencesKeys.IMPORTED_MODELS] = jsonString
dataStore.updateData { settings ->
settings.toBuilder().clearImportedModel().addAllImportedModel(importedModels).build()
}
}
}
override fun readImportedModels(): List<ImportedModelInfo> {
override fun readImportedModels(): List<ImportedModel> {
return runBlocking {
val preferences = dataStore.data.first()
val infosStr = preferences[PreferencesKeys.IMPORTED_MODELS] ?: "[]"
val gson = Gson()
val listType = object : TypeToken<List<ImportedModelInfo>>() {}.type
gson.fromJson(infosStr, listType)
}
}
private fun getTextInputHistory(preferences: Preferences): List<String> {
val infosStr = preferences[PreferencesKeys.TEXT_INPUT_HISTORY] ?: "[]"
val gson = Gson()
val listType = object : TypeToken<List<String>>() {}.type
return gson.fromJson(infosStr, listType)
}
private fun getOrCreateSecretKey(): SecretKey {
return (keyStore.getKey(keystoreAlias, null) as? SecretKey) ?: run {
val keyGenerator =
KeyGenerator.getInstance(KeyProperties.KEY_ALGORITHM_AES, "AndroidKeyStore")
val keyGenParameterSpec = KeyGenParameterSpec.Builder(
keystoreAlias,
KeyProperties.PURPOSE_ENCRYPT or KeyProperties.PURPOSE_DECRYPT
)
.setBlockModes(KeyProperties.BLOCK_MODE_GCM)
.setEncryptionPaddings(KeyProperties.ENCRYPTION_PADDING_NONE)
.setUserAuthenticationRequired(false) // Consider setting to true for added security
.build()
keyGenerator.init(keyGenParameterSpec)
keyGenerator.generateKey()
}
}
private fun encrypt(plainText: String): Pair<String, String> {
val secretKey = getOrCreateSecretKey()
val cipher = Cipher.getInstance("AES/GCM/NoPadding")
cipher.init(Cipher.ENCRYPT_MODE, secretKey)
val iv = cipher.iv
val encryptedBytes = cipher.doFinal(plainText.toByteArray(Charsets.UTF_8))
return Base64.encodeToString(encryptedBytes, Base64.DEFAULT) to Base64.encodeToString(
iv,
Base64.DEFAULT
)
}
private fun decrypt(encryptedText: String, ivText: String): String? {
val secretKey = getOrCreateSecretKey()
val cipher = Cipher.getInstance("AES/GCM/NoPadding")
val ivBytes = Base64.decode(ivText, Base64.DEFAULT)
val spec = javax.crypto.spec.GCMParameterSpec(128, ivBytes) // 128 bit tag length
cipher.init(Cipher.DECRYPT_MODE, secretKey, spec)
val encryptedBytes = Base64.decode(encryptedText, Base64.DEFAULT)
return try {
String(cipher.doFinal(encryptedBytes), Charsets.UTF_8)
} catch (e: Exception) {
// Handle decryption errors (e.g., key not found)
null
val settings = dataStore.data.first()
settings.importedModelList
}
}
}

View file

@ -23,11 +23,11 @@ import android.app.PendingIntent
import android.content.Context
import android.content.Intent
import android.content.pm.PackageManager
import android.net.Uri
import android.util.Log
import androidx.core.app.ActivityCompat
import androidx.core.app.NotificationCompat
import androidx.core.app.NotificationManagerCompat
import androidx.core.net.toUri
import androidx.work.Data
import androidx.work.ExistingWorkPolicy
import androidx.work.OneTimeWorkRequestBuilder
@ -38,7 +38,7 @@ import androidx.work.WorkManager
import androidx.work.WorkQuery
import com.google.ai.edge.gallery.AppLifecycleProvider
import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.ui.common.readLaunchInfo
import com.google.ai.edge.gallery.common.readLaunchInfo
import com.google.ai.edge.gallery.worker.DownloadWorker
import com.google.common.util.concurrent.FutureCallback
import com.google.common.util.concurrent.Futures
@ -53,7 +53,8 @@ data class AGWorkInfo(val modelName: String, val workId: String)
interface DownloadRepository {
fun downloadModel(
model: Model, onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit
model: Model,
onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit,
)
fun cancelDownloadModel(model: Model)
@ -83,7 +84,8 @@ class DefaultDownloadRepository(
private val workManager = WorkManager.getInstance(context)
override fun downloadModel(
model: Model, onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit
model: Model,
onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit,
) {
val appTs = readLaunchInfo(context = context)?.ts ?: 0
@ -91,18 +93,24 @@ class DefaultDownloadRepository(
val builder = Data.Builder()
val totalBytes = model.totalBytes + model.extraDataFiles.sumOf { it.sizeInBytes }
val inputDataBuilder =
builder.putString(KEY_MODEL_NAME, model.name).putString(KEY_MODEL_URL, model.url)
builder
.putString(KEY_MODEL_NAME, model.name)
.putString(KEY_MODEL_URL, model.url)
.putString(KEY_MODEL_VERSION, model.version)
.putString(KEY_MODEL_DOWNLOAD_MODEL_DIR, model.normalizedName)
.putString(KEY_MODEL_DOWNLOAD_FILE_NAME, model.downloadFileName)
.putBoolean(KEY_MODEL_IS_ZIP, model.isZip).putString(KEY_MODEL_UNZIPPED_DIR, model.unzipDir)
.putLong(KEY_MODEL_TOTAL_BYTES, totalBytes).putLong(KEY_MODEL_DOWNLOAD_APP_TS, appTs)
.putBoolean(KEY_MODEL_IS_ZIP, model.isZip)
.putString(KEY_MODEL_UNZIPPED_DIR, model.unzipDir)
.putLong(KEY_MODEL_TOTAL_BYTES, totalBytes)
.putLong(KEY_MODEL_DOWNLOAD_APP_TS, appTs)
if (model.extraDataFiles.isNotEmpty()) {
inputDataBuilder.putString(KEY_MODEL_EXTRA_DATA_URLS,
model.extraDataFiles.joinToString(",") { it.url }).putString(
KEY_MODEL_EXTRA_DATA_DOWNLOAD_FILE_NAMES,
model.extraDataFiles.joinToString(",") { it.downloadFileName })
inputDataBuilder
.putString(KEY_MODEL_EXTRA_DATA_URLS, model.extraDataFiles.joinToString(",") { it.url })
.putString(
KEY_MODEL_EXTRA_DATA_DOWNLOAD_FILE_NAMES,
model.extraDataFiles.joinToString(",") { it.downloadFileName },
)
}
if (model.accessToken != null) {
inputDataBuilder.putString(KEY_MODEL_DOWNLOAD_ACCESS_TOKEN, model.accessToken)
@ -111,20 +119,19 @@ class DefaultDownloadRepository(
// Create worker request.
val downloadWorkRequest =
OneTimeWorkRequestBuilder<DownloadWorker>().setExpedited(OutOfQuotaPolicy.RUN_AS_NON_EXPEDITED_WORK_REQUEST)
.setInputData(inputData).addTag("$MODEL_NAME_TAG:${model.name}").build()
OneTimeWorkRequestBuilder<DownloadWorker>()
.setExpedited(OutOfQuotaPolicy.RUN_AS_NON_EXPEDITED_WORK_REQUEST)
.setInputData(inputData)
.addTag("$MODEL_NAME_TAG:${model.name}")
.build()
val workerId = downloadWorkRequest.id
// Start!
workManager.enqueueUniqueWork(
model.name, ExistingWorkPolicy.REPLACE, downloadWorkRequest
)
workManager.enqueueUniqueWork(model.name, ExistingWorkPolicy.REPLACE, downloadWorkRequest)
// Observe progress.
observerWorkerProgress(
workerId = workerId, model = model, onStatusUpdated = onStatusUpdated
)
observerWorkerProgress(workerId = workerId, model = model, onStatusUpdated = onStatusUpdated)
}
override fun cancelDownloadModel(model: Model) {
@ -143,7 +150,8 @@ class DefaultDownloadRepository(
}
val combinedFuture: ListenableFuture<List<Operation.State.SUCCESS>> = Futures.allAsList(futures)
Futures.addCallback(
combinedFuture, object : FutureCallback<List<Operation.State.SUCCESS>> {
combinedFuture,
object : FutureCallback<List<Operation.State.SUCCESS>> {
override fun onSuccess(result: List<Operation.State.SUCCESS>?) {
// All cancellations are complete
onComplete()
@ -154,7 +162,8 @@ class DefaultDownloadRepository(
t.printStackTrace()
onComplete()
}
}, MoreExecutors.directExecutor()
},
MoreExecutors.directExecutor(),
)
}
@ -175,45 +184,41 @@ class DefaultDownloadRepository(
if (!startUnzipping) {
if (receivedBytes != 0L) {
onStatusUpdated(
model, ModelDownloadStatus(
model,
ModelDownloadStatus(
status = ModelDownloadStatusType.IN_PROGRESS,
totalBytes = model.totalBytes,
receivedBytes = receivedBytes,
bytesPerSecond = downloadRate,
remainingMs = remainingSeconds,
)
),
)
}
} else {
onStatusUpdated(
model, ModelDownloadStatus(
status = ModelDownloadStatusType.UNZIPPING,
)
model,
ModelDownloadStatus(status = ModelDownloadStatusType.UNZIPPING),
)
}
}
WorkInfo.State.SUCCEEDED -> {
Log.d("repo", "worker %s success".format(workerId.toString()))
onStatusUpdated(
model, ModelDownloadStatus(
status = ModelDownloadStatusType.SUCCEEDED,
)
)
onStatusUpdated(model, ModelDownloadStatus(status = ModelDownloadStatusType.SUCCEEDED))
sendNotification(
title = context.getString(
R.string.notification_title_success
),
title = context.getString(R.string.notification_title_success),
text = context.getString(R.string.notification_content_success).format(model.name),
modelName = model.name,
)
}
WorkInfo.State.FAILED, WorkInfo.State.CANCELLED -> {
WorkInfo.State.FAILED,
WorkInfo.State.CANCELLED -> {
var status = ModelDownloadStatusType.FAILED
val errorMessage = workInfo.outputData.getString(KEY_MODEL_DOWNLOAD_ERROR_MESSAGE) ?: ""
Log.d(
"repo", "worker %s FAILED or CANCELLED: %s".format(workerId.toString(), errorMessage)
"repo",
"worker %s FAILED or CANCELLED: %s".format(workerId.toString(), errorMessage),
)
if (workInfo.state == WorkInfo.State.CANCELLED) {
status = ModelDownloadStatusType.NOT_DOWNLOADED
@ -225,7 +230,8 @@ class DefaultDownloadRepository(
)
}
onStatusUpdated(
model, ModelDownloadStatus(status = status, errorMessage = errorMessage)
model,
ModelDownloadStatus(status = status, errorMessage = errorMessage),
)
}
@ -278,29 +284,35 @@ class DefaultDownloadRepository(
notificationManager.createNotificationChannel(channel)
// Create an Intent to open your app with a deep link.
val intent = Intent(
Intent.ACTION_VIEW, Uri.parse("com.google.ai.edge.gallery://model/${modelName}")
).apply {
flags = Intent.FLAG_ACTIVITY_NEW_TASK
}
val intent =
Intent(Intent.ACTION_VIEW, "com.google.ai.edge.gallery://model/${modelName}".toUri()).apply {
flags = Intent.FLAG_ACTIVITY_NEW_TASK
}
// Create a PendingIntent
val pendingIntent: PendingIntent = PendingIntent.getActivity(
context, 0, intent, PendingIntent.FLAG_UPDATE_CURRENT or PendingIntent.FLAG_IMMUTABLE
)
val pendingIntent: PendingIntent =
PendingIntent.getActivity(
context,
0,
intent,
PendingIntent.FLAG_UPDATE_CURRENT or PendingIntent.FLAG_IMMUTABLE,
)
val builder = NotificationCompat.Builder(context, channelId)
// TODO: replace icon.
.setSmallIcon(android.R.drawable.ic_dialog_info).setContentTitle(title).setContentText(text)
.setPriority(NotificationCompat.PRIORITY_HIGH).setContentIntent(pendingIntent)
.setAutoCancel(true)
val builder =
NotificationCompat.Builder(context, channelId)
// TODO: replace icon.
.setSmallIcon(android.R.drawable.ic_dialog_info)
.setContentTitle(title)
.setContentText(text)
.setPriority(NotificationCompat.PRIORITY_HIGH)
.setContentIntent(pendingIntent)
.setAutoCancel(true)
with(NotificationManagerCompat.from(context)) {
// notificationId is a unique int for each notification that you must define
if (ActivityCompat.checkSelfPermission(
context, Manifest.permission.POST_NOTIFICATIONS
) != PackageManager.PERMISSION_GRANTED
if (
ActivityCompat.checkSelfPermission(context, Manifest.permission.POST_NOTIFICATIONS) !=
PackageManager.PERMISSION_GRANTED
) {
// Permission not granted, return or handle accordingly. In real app, request permission.
return

View file

@ -17,8 +17,6 @@
package com.google.ai.edge.gallery.data
import android.content.Context
import com.google.ai.edge.gallery.ui.common.chat.PromptTemplate
import com.google.ai.edge.gallery.ui.common.convertValueToTargetType
import java.io.File
data class ModelDataFile(
@ -28,13 +26,11 @@ data class ModelDataFile(
val sizeInBytes: Long,
)
enum class Accelerator(val label: String) {
CPU(label = "CPU"), GPU(label = "GPU")
}
const val IMPORTS_DIR = "__imports"
private val NORMALIZE_NAME_REGEX = Regex("[^a-zA-Z0-9]")
data class PromptTemplate(val title: String, val description: String, val prompt: String)
/** A model for a task */
data class Model(
/** The name (for display purpose) of the model. */
@ -67,9 +63,7 @@ data class Model(
*/
val info: String = "",
/**
* The url to jump to when clicking "learn more" in expanded model item.
*/
/** The url to jump to when clicking "learn more" in expanded model item. */
val learnMoreUrl: String = "",
/** A list of configurable parameters for the model. */
@ -105,6 +99,9 @@ data class Model(
var configValues: Map<String, Any> = mapOf(),
var totalBytes: Long = 0L,
var accessToken: String? = null,
/** The estimated peak memory in byte to run the model. */
val estimatedPeakMemoryInBytes: Long? = null,
) {
init {
normalizedName = NORMALIZE_NAME_REGEX.replace(name, "_")
@ -121,17 +118,13 @@ data class Model(
fun getPath(context: Context, fileName: String = downloadFileName): String {
if (imported) {
return listOf(context.getExternalFilesDir(null)?.absolutePath ?: "", fileName).joinToString(
File.separator
)
return listOf(context.getExternalFilesDir(null)?.absolutePath ?: "", fileName)
.joinToString(File.separator)
}
val baseDir =
listOf(
context.getExternalFilesDir(null)?.absolutePath ?: "",
normalizedName,
version
).joinToString(File.separator)
listOf(context.getExternalFilesDir(null)?.absolutePath ?: "", normalizedName, version)
.joinToString(File.separator)
return if (this.isZip && this.unzipDir.isNotEmpty()) {
"$baseDir/${this.unzipDir}"
} else {
@ -140,27 +133,27 @@ data class Model(
}
fun getIntConfigValue(key: ConfigKey, defaultValue: Int = 0): Int {
return getTypedConfigValue(
key = key, valueType = ValueType.INT, defaultValue = defaultValue
) as Int
return getTypedConfigValue(key = key, valueType = ValueType.INT, defaultValue = defaultValue)
as Int
}
fun getFloatConfigValue(key: ConfigKey, defaultValue: Float = 0.0f): Float {
return getTypedConfigValue(
key = key, valueType = ValueType.FLOAT, defaultValue = defaultValue
) as Float
return getTypedConfigValue(key = key, valueType = ValueType.FLOAT, defaultValue = defaultValue)
as Float
}
fun getBooleanConfigValue(key: ConfigKey, defaultValue: Boolean = false): Boolean {
return getTypedConfigValue(
key = key, valueType = ValueType.BOOLEAN, defaultValue = defaultValue
) as Boolean
key = key,
valueType = ValueType.BOOLEAN,
defaultValue = defaultValue,
)
as Boolean
}
fun getStringConfigValue(key: ConfigKey, defaultValue: String = ""): String {
return getTypedConfigValue(
key = key, valueType = ValueType.STRING, defaultValue = defaultValue
) as String
return getTypedConfigValue(key = key, valueType = ValueType.STRING, defaultValue = defaultValue)
as String
}
fun getExtraDataFile(name: String): ModelDataFile? {
@ -169,20 +162,19 @@ data class Model(
private fun getTypedConfigValue(key: ConfigKey, valueType: ValueType, defaultValue: Any): Any {
return convertValueToTargetType(
value = configValues.getOrDefault(key.label, defaultValue), valueType = valueType
value = configValues.getOrDefault(key.label, defaultValue),
valueType = valueType,
)
}
}
/** Data for a imported local model. */
data class ImportedModelInfo(
val fileName: String,
val fileSize: Long,
val defaultValues: Map<String, Any>
)
enum class ModelDownloadStatusType {
NOT_DOWNLOADED, PARTIALLY_DOWNLOADED, IN_PROGRESS, UNZIPPING, SUCCEEDED, FAILED,
NOT_DOWNLOADED,
PARTIALLY_DOWNLOADED,
IN_PROGRESS,
UNZIPPING,
SUCCEEDED,
FAILED,
}
data class ModelDownloadStatus(
@ -197,51 +189,29 @@ data class ModelDownloadStatus(
////////////////////////////////////////////////////////////////////////////////////////////////////
// Configs.
enum class ConfigKey(val label: String) {
MAX_TOKENS("Max tokens"),
TOPK("TopK"),
TOPP("TopP"),
TEMPERATURE("Temperature"),
DEFAULT_MAX_TOKENS("Default max tokens"),
DEFAULT_TOPK("Default TopK"),
DEFAULT_TOPP("Default TopP"),
DEFAULT_TEMPERATURE("Default temperature"),
SUPPORT_IMAGE("Support image"),
MAX_RESULT_COUNT("Max result count"),
USE_GPU("Use GPU"),
ACCELERATOR("Choose accelerator"),
COMPATIBLE_ACCELERATORS("Compatible accelerators"),
WARM_UP_ITERATIONS("Warm up iterations"),
BENCHMARK_ITERATIONS("Benchmark iterations"),
ITERATIONS("Iterations"),
THEME("Theme"),
NAME("Name"),
MODEL_TYPE("Model type")
}
val MOBILENET_CONFIGS: List<Config> = listOf(
NumberSliderConfig(
key = ConfigKey.MAX_RESULT_COUNT,
sliderMin = 1f,
sliderMax = 5f,
defaultValue = 3f,
valueType = ValueType.INT
), BooleanSwitchConfig(
key = ConfigKey.USE_GPU,
defaultValue = false,
val MOBILENET_CONFIGS: List<Config> =
listOf(
NumberSliderConfig(
key = ConfigKey.MAX_RESULT_COUNT,
sliderMin = 1f,
sliderMax = 5f,
defaultValue = 3f,
valueType = ValueType.INT,
),
BooleanSwitchConfig(key = ConfigKey.USE_GPU, defaultValue = false),
)
)
val IMAGE_GENERATION_CONFIGS: List<Config> = listOf(
NumberSliderConfig(
key = ConfigKey.ITERATIONS,
sliderMin = 5f,
sliderMax = 50f,
defaultValue = 10f,
valueType = ValueType.INT,
needReinitialization = false,
val IMAGE_GENERATION_CONFIGS: List<Config> =
listOf(
NumberSliderConfig(
key = ConfigKey.ITERATIONS,
sliderMin = 5f,
sliderMax = 50f,
defaultValue = 10f,
valueType = ValueType.INT,
needReinitialization = false,
)
)
)
const val TEXT_CLASSIFICATION_INFO =
"Model is trained on movie reviews dataset. Type a movie review below and see the scores of positive or negative sentiment."
@ -256,92 +226,97 @@ const val IMAGE_CLASSIFICATION_LEARN_MORE_URL = "https://ai.google.dev/edge/lite
const val IMAGE_GENERATION_INFO =
"Powered by [MediaPipe Image Generation API](https://ai.google.dev/edge/mediapipe/solutions/vision/image_generator/android)"
val MODEL_TEXT_CLASSIFICATION_MOBILEBERT: Model =
Model(
name = "MobileBert",
downloadFileName = "bert_classifier.tflite",
url =
"https://storage.googleapis.com/mediapipe-models/text_classifier/bert_classifier/float32/latest/bert_classifier.tflite",
sizeInBytes = 25707538L,
info = TEXT_CLASSIFICATION_INFO,
learnMoreUrl = TEXT_CLASSIFICATION_LEARN_MORE_URL,
)
val MODEL_TEXT_CLASSIFICATION_MOBILEBERT: Model = Model(
name = "MobileBert",
downloadFileName = "bert_classifier.tflite",
url = "https://storage.googleapis.com/mediapipe-models/text_classifier/bert_classifier/float32/latest/bert_classifier.tflite",
sizeInBytes = 25707538L,
info = TEXT_CLASSIFICATION_INFO,
learnMoreUrl = TEXT_CLASSIFICATION_LEARN_MORE_URL,
)
val MODEL_TEXT_CLASSIFICATION_AVERAGE_WORD_EMBEDDING: Model =
Model(
name = "Average word embedding",
downloadFileName = "average_word_classifier.tflite",
url =
"https://storage.googleapis.com/mediapipe-models/text_classifier/average_word_classifier/float32/latest/average_word_classifier.tflite",
sizeInBytes = 775708L,
info = TEXT_CLASSIFICATION_INFO,
)
val MODEL_TEXT_CLASSIFICATION_AVERAGE_WORD_EMBEDDING: Model = Model(
name = "Average word embedding",
downloadFileName = "average_word_classifier.tflite",
url = "https://storage.googleapis.com/mediapipe-models/text_classifier/average_word_classifier/float32/latest/average_word_classifier.tflite",
sizeInBytes = 775708L,
info = TEXT_CLASSIFICATION_INFO,
)
val MODEL_IMAGE_CLASSIFICATION_MOBILENET_V1: Model =
Model(
name = "Mobilenet V1",
downloadFileName = "mobilenet_v1.tflite",
url = "https://storage.googleapis.com/tfweb/app_gallery_models/mobilenet_v1.tflite",
sizeInBytes = 16900760L,
extraDataFiles =
listOf(
ModelDataFile(
name = "labels",
url =
"https://raw.githubusercontent.com/leferrad/tensorflow-mobilenet/refs/heads/master/imagenet/labels.txt",
downloadFileName = "mobilenet_labels_v1.txt",
sizeInBytes = 21685L,
)
),
configs = MOBILENET_CONFIGS,
info = IMAGE_CLASSIFICATION_INFO,
learnMoreUrl = IMAGE_CLASSIFICATION_LEARN_MORE_URL,
)
val MODEL_IMAGE_CLASSIFICATION_MOBILENET_V1: Model = Model(
name = "Mobilenet V1",
downloadFileName = "mobilenet_v1.tflite",
url = "https://storage.googleapis.com/tfweb/app_gallery_models/mobilenet_v1.tflite",
sizeInBytes = 16900760L,
extraDataFiles = listOf(
ModelDataFile(
name = "labels",
url = "https://raw.githubusercontent.com/leferrad/tensorflow-mobilenet/refs/heads/master/imagenet/labels.txt",
downloadFileName = "mobilenet_labels_v1.txt",
sizeInBytes = 21685L
),
),
configs = MOBILENET_CONFIGS,
info = IMAGE_CLASSIFICATION_INFO,
learnMoreUrl = IMAGE_CLASSIFICATION_LEARN_MORE_URL,
)
val MODEL_IMAGE_CLASSIFICATION_MOBILENET_V2: Model =
Model(
name = "Mobilenet V2",
downloadFileName = "mobilenet_v2.tflite",
url = "https://storage.googleapis.com/tfweb/app_gallery_models/mobilenet_v2.tflite",
sizeInBytes = 13978596L,
extraDataFiles =
listOf(
ModelDataFile(
name = "labels",
url =
"https://raw.githubusercontent.com/leferrad/tensorflow-mobilenet/refs/heads/master/imagenet/labels.txt",
downloadFileName = "mobilenet_labels_v2.txt",
sizeInBytes = 21685L,
)
),
configs = MOBILENET_CONFIGS,
info = IMAGE_CLASSIFICATION_INFO,
)
val MODEL_IMAGE_CLASSIFICATION_MOBILENET_V2: Model = Model(
name = "Mobilenet V2",
downloadFileName = "mobilenet_v2.tflite",
url = "https://storage.googleapis.com/tfweb/app_gallery_models/mobilenet_v2.tflite",
sizeInBytes = 13978596L,
extraDataFiles = listOf(
ModelDataFile(
name = "labels",
url = "https://raw.githubusercontent.com/leferrad/tensorflow-mobilenet/refs/heads/master/imagenet/labels.txt",
downloadFileName = "mobilenet_labels_v2.txt",
sizeInBytes = 21685L
),
),
configs = MOBILENET_CONFIGS,
info = IMAGE_CLASSIFICATION_INFO,
)
val MODEL_IMAGE_GENERATION_STABLE_DIFFUSION: Model =
Model(
name = "Stable diffusion",
downloadFileName = "sd15.zip",
isZip = true,
unzipDir = "sd15",
url = "https://storage.googleapis.com/tfweb/app_gallery_models/sd15.zip",
sizeInBytes = 1906219565L,
showRunAgainButton = false,
showBenchmarkButton = false,
info = IMAGE_GENERATION_INFO,
configs = IMAGE_GENERATION_CONFIGS,
learnMoreUrl = "https://huggingface.co/litert-community",
)
val MODEL_IMAGE_GENERATION_STABLE_DIFFUSION: Model = Model(
name = "Stable diffusion",
downloadFileName = "sd15.zip",
isZip = true,
unzipDir = "sd15",
url = "https://storage.googleapis.com/tfweb/app_gallery_models/sd15.zip",
sizeInBytes = 1906219565L,
showRunAgainButton = false,
showBenchmarkButton = false,
info = IMAGE_GENERATION_INFO,
configs = IMAGE_GENERATION_CONFIGS,
learnMoreUrl = "https://huggingface.co/litert-community",
)
val EMPTY_MODEL: Model = Model(
name = "empty",
downloadFileName = "empty.tflite",
url = "",
sizeInBytes = 0L,
)
val EMPTY_MODEL: Model =
Model(name = "empty", downloadFileName = "empty.tflite", url = "", sizeInBytes = 0L)
////////////////////////////////////////////////////////////////////////////////////////////////////
// Model collections for different tasks.
val MODELS_TEXT_CLASSIFICATION: MutableList<Model> = mutableListOf(
MODEL_TEXT_CLASSIFICATION_MOBILEBERT,
MODEL_TEXT_CLASSIFICATION_AVERAGE_WORD_EMBEDDING,
)
val MODELS_TEXT_CLASSIFICATION: MutableList<Model> =
mutableListOf(
MODEL_TEXT_CLASSIFICATION_MOBILEBERT,
MODEL_TEXT_CLASSIFICATION_AVERAGE_WORD_EMBEDDING,
)
val MODELS_IMAGE_CLASSIFICATION: MutableList<Model> = mutableListOf(
MODEL_IMAGE_CLASSIFICATION_MOBILENET_V1,
MODEL_IMAGE_CLASSIFICATION_MOBILENET_V2,
)
val MODELS_IMAGE_CLASSIFICATION: MutableList<Model> =
mutableListOf(MODEL_IMAGE_CLASSIFICATION_MOBILENET_V1, MODEL_IMAGE_CLASSIFICATION_MOBILENET_V2)
val MODELS_IMAGE_GENERATION: MutableList<Model> =
mutableListOf(MODEL_IMAGE_GENERATION_STABLE_DIFFUSION)

View file

@ -16,15 +16,17 @@
package com.google.ai.edge.gallery.data
import com.google.ai.edge.gallery.ui.llmchat.DEFAULT_ACCELERATORS
import com.google.ai.edge.gallery.ui.llmchat.DEFAULT_TEMPERATURE
import com.google.ai.edge.gallery.ui.llmchat.DEFAULT_TOPK
import com.google.ai.edge.gallery.ui.llmchat.DEFAULT_TOPP
import com.google.ai.edge.gallery.ui.llmchat.createLlmChatConfigs
import kotlinx.serialization.Serializable
import com.google.gson.annotations.SerializedName
data class DefaultConfig(
@SerializedName("topK") val topK: Int?,
@SerializedName("topP") val topP: Float?,
@SerializedName("temperature") val temperature: Float?,
@SerializedName("accelerators") val accelerators: String?,
@SerializedName("maxTokens") val maxTokens: Int?,
)
/** A model in the model allowlist. */
@Serializable
data class AllowedModel(
val name: String,
val modelId: String,
@ -32,10 +34,11 @@ data class AllowedModel(
val description: String,
val sizeInBytes: Long,
val version: String,
val defaultConfig: Map<String, ConfigValue>,
val defaultConfig: DefaultConfig,
val taskTypes: List<String>,
val disabled: Boolean? = null,
val llmSupportImage: Boolean? = null,
val estimatedPeakMemoryInBytes: Long? = null,
) {
fun toModel(): Model {
// Construct HF download url.
@ -46,25 +49,13 @@ data class AllowedModel(
taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_PROMPT_LAB.type.id)
var configs: List<Config> = listOf()
if (isLlmModel) {
var defaultTopK: Int = DEFAULT_TOPK
var defaultTopP: Float = DEFAULT_TOPP
var defaultTemperature: Float = DEFAULT_TEMPERATURE
var defaultMaxToken = 1024
var defaultTopK: Int = defaultConfig.topK ?: DEFAULT_TOPK
var defaultTopP: Float = defaultConfig.topP ?: DEFAULT_TOPP
var defaultTemperature: Float = defaultConfig.temperature ?: DEFAULT_TEMPERATURE
var defaultMaxToken = defaultConfig.maxTokens ?: 1024
var accelerators: List<Accelerator> = DEFAULT_ACCELERATORS
if (defaultConfig.containsKey("topK")) {
defaultTopK = getIntConfigValue(defaultConfig["topK"], defaultTopK)
}
if (defaultConfig.containsKey("topP")) {
defaultTopP = getFloatConfigValue(defaultConfig["topP"], defaultTopP)
}
if (defaultConfig.containsKey("temperature")) {
defaultTemperature = getFloatConfigValue(defaultConfig["temperature"], defaultTemperature)
}
if (defaultConfig.containsKey("maxTokens")) {
defaultMaxToken = getIntConfigValue(defaultConfig["maxTokens"], defaultMaxToken)
}
if (defaultConfig.containsKey("accelerators")) {
val items = getStringConfigValue(defaultConfig["accelerators"], "gpu").split(",")
if (defaultConfig.accelerators != null) {
val items = defaultConfig.accelerators.split(",")
accelerators = mutableListOf()
for (item in items) {
if (item == "cpu") {
@ -74,13 +65,14 @@ data class AllowedModel(
}
}
}
configs = createLlmChatConfigs(
defaultTopK = defaultTopK,
defaultTopP = defaultTopP,
defaultTemperature = defaultTemperature,
defaultMaxToken = defaultMaxToken,
accelerators = accelerators,
)
configs =
createLlmChatConfigs(
defaultTopK = defaultTopK,
defaultTopP = defaultTopP,
defaultTemperature = defaultTemperature,
defaultMaxToken = defaultMaxToken,
accelerators = accelerators,
)
}
// Misc.
@ -97,6 +89,7 @@ data class AllowedModel(
info = description,
url = downloadUrl,
sizeInBytes = sizeInBytes,
estimatedPeakMemoryInBytes = estimatedPeakMemoryInBytes,
configs = configs,
downloadFileName = modelFile,
showBenchmarkButton = showBenchmarkButton,
@ -112,8 +105,4 @@ data class AllowedModel(
}
/** The model allowlist. */
@Serializable
data class ModelAllowlist(
val models: List<AllowedModel>,
)
data class ModelAllowlist(val models: List<AllowedModel>)

View file

@ -17,27 +17,21 @@
package com.google.ai.edge.gallery.data
import androidx.annotation.StringRes
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.outlined.Forum
import androidx.compose.material.icons.outlined.Mms
import androidx.compose.material.icons.outlined.Widgets
import androidx.compose.material.icons.rounded.ImageSearch
import androidx.compose.runtime.MutableState
import androidx.compose.runtime.mutableLongStateOf
import androidx.compose.ui.graphics.vector.ImageVector
import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.ui.icon.Forum
import com.google.ai.edge.gallery.ui.icon.Mms
import com.google.ai.edge.gallery.ui.icon.Widgets
/** Type of task. */
enum class TaskType(val label: String, val id: String) {
TEXT_CLASSIFICATION(label = "Text Classification", id = "text_classification"),
IMAGE_CLASSIFICATION(label = "Image Classification", id = "image_classification"),
IMAGE_GENERATION(label = "Image Generation", id = "image_generation"),
LLM_CHAT(label = "AI Chat", id = "llm_chat"),
LLM_PROMPT_LAB(label = "Prompt Lab", id = "llm_prompt_lab"),
LLM_ASK_IMAGE(label = "Ask Image", id = "llm_ask_image"),
TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
TEST_TASK_2(label = "Test task 2", id = "test_task_2")
TEST_TASK_2(label = "Test task 2", id = "test_task_2"),
}
/** Data class for a task listed in home screen. */
@ -71,71 +65,47 @@ data class Task(
// The following fields are managed by the app. Don't need to set manually.
var index: Int = -1,
val updateTrigger: MutableState<Long> = mutableLongStateOf(0)
val updateTrigger: MutableState<Long> = mutableLongStateOf(0),
)
val TASK_TEXT_CLASSIFICATION = Task(
type = TaskType.TEXT_CLASSIFICATION,
iconVectorResourceId = R.drawable.text_spark,
models = MODELS_TEXT_CLASSIFICATION,
description = "Classify text into different categories",
textInputPlaceHolderRes = R.string.text_input_placeholder_text_classification
)
val TASK_LLM_CHAT =
Task(
type = TaskType.LLM_CHAT,
icon = Forum,
models = mutableListOf(),
description = "Chat with on-device large language models",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
sourceCodeUrl =
"https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt",
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
)
val TASK_IMAGE_CLASSIFICATION = Task(
type = TaskType.IMAGE_CLASSIFICATION,
icon = Icons.Rounded.ImageSearch,
description = "Classify images into different categories",
models = MODELS_IMAGE_CLASSIFICATION
)
val TASK_LLM_PROMPT_LAB =
Task(
type = TaskType.LLM_PROMPT_LAB,
icon = Widgets,
models = mutableListOf(),
description = "Single turn use cases with on-device large language model",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
sourceCodeUrl =
"https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt",
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
)
val TASK_LLM_CHAT = Task(
type = TaskType.LLM_CHAT,
icon = Icons.Outlined.Forum,
models = mutableListOf(),
description = "Chat with on-device large language models",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt",
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
)
val TASK_LLM_PROMPT_LAB = Task(
type = TaskType.LLM_PROMPT_LAB,
icon = Icons.Outlined.Widgets,
models = mutableListOf(),
description = "Single turn use cases with on-device large language model",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt",
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
)
val TASK_LLM_ASK_IMAGE = Task(
type = TaskType.LLM_ASK_IMAGE,
icon = Icons.Outlined.Mms,
models = mutableListOf(),
description = "Ask questions about images with on-device large language models",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt",
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
)
val TASK_IMAGE_GENERATION = Task(
type = TaskType.IMAGE_GENERATION,
iconVectorResourceId = R.drawable.image_spark,
models = MODELS_IMAGE_GENERATION,
description = "Generate images from text",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/vision/image_generator/android",
sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/imagegeneration/ImageGenerationModelHelper.kt",
textInputPlaceHolderRes = R.string.text_image_generation_text_field_placeholder
)
val TASK_LLM_ASK_IMAGE =
Task(
type = TaskType.LLM_ASK_IMAGE,
icon = Mms,
models = mutableListOf(),
description = "Ask questions about images with on-device large language models",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
sourceCodeUrl =
"https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt",
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
)
/** All tasks. */
val TASKS: List<Task> = listOf(
TASK_LLM_ASK_IMAGE,
TASK_LLM_PROMPT_LAB,
TASK_LLM_CHAT,
)
val TASKS: List<Task> = listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)
fun getModelByName(name: String): Model? {
for (task in TASKS) {
@ -147,3 +117,12 @@ fun getModelByName(name: String): Model? {
}
return null
}
fun processTasks() {
for ((index, task) in TASKS.withIndex()) {
task.index = index
for (model in task.models) {
model.preProcess()
}
}
}

View file

@ -0,0 +1,22 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.data
enum class Accelerator(val label: String) {
CPU(label = "CPU"),
GPU(label = "GPU"),
}

View file

@ -16,19 +16,15 @@
package com.google.ai.edge.gallery.ui
import android.app.Application
import androidx.lifecycle.ViewModelProvider.AndroidViewModelFactory
import androidx.lifecycle.viewmodel.CreationExtras
import androidx.lifecycle.viewmodel.initializer
import androidx.lifecycle.viewmodel.viewModelFactory
import com.google.ai.edge.gallery.GalleryApplication
import com.google.ai.edge.gallery.ui.imageclassification.ImageClassificationViewModel
import com.google.ai.edge.gallery.ui.imagegeneration.ImageGenerationViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.textclassification.TextClassificationViewModel
object ViewModelProvider {
val Factory = viewModelFactory {
@ -36,42 +32,23 @@ object ViewModelProvider {
initializer {
val downloadRepository = galleryApplication().container.downloadRepository
val dataStoreRepository = galleryApplication().container.dataStoreRepository
val lifecycleProvider = galleryApplication().container.lifecycleProvider
ModelManagerViewModel(
downloadRepository = downloadRepository,
dataStoreRepository = dataStoreRepository,
lifecycleProvider = lifecycleProvider,
context = galleryApplication().container.context,
)
}
// Initializer for TextClassificationViewModel
initializer {
TextClassificationViewModel()
}
// Initializer for ImageClassificationViewModel
initializer {
ImageClassificationViewModel()
}
// Initializer for LlmChatViewModel.
initializer {
LlmChatViewModel()
}
initializer { LlmChatViewModel() }
// Initializer for LlmSingleTurnViewModel..
initializer {
LlmSingleTurnViewModel()
}
initializer { LlmSingleTurnViewModel() }
// Initializer for LlmAskImageViewModel.
initializer {
LlmAskImageViewModel()
}
// Initializer for ImageGenerationViewModel.
initializer {
ImageGenerationViewModel()
}
initializer { LlmAskImageViewModel() }
}
}

View file

@ -16,7 +16,7 @@
package com.google.ai.edge.gallery.ui.common
import android.net.Uri
import androidx.core.net.toUri
import net.openid.appauth.AuthorizationServiceConfiguration
object AuthConfig {
@ -34,8 +34,9 @@ object AuthConfig {
private const val tokenEndpoint = "https://huggingface.co/oauth/token"
// OAuth service configuration (AppAuth library requires this)
val authServiceConfig = AuthorizationServiceConfiguration(
Uri.parse(authEndpoint), // Authorization endpoint
Uri.parse(tokenEndpoint) // Token exchange endpoint
)
}
val authServiceConfig =
AuthorizationServiceConfiguration(
authEndpoint.toUri(), // Authorization endpoint
tokenEndpoint.toUri(), // Token exchange endpoint
)
}

View file

@ -0,0 +1,68 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.common
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.platform.LocalUriHandler
import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.text.SpanStyle
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.text.style.TextDecoration
import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.ui.theme.customColors
@Composable
fun ClickableLink(url: String, linkText: String, icon: ImageVector) {
val uriHandler = LocalUriHandler.current
val annotatedText =
AnnotatedString(
text = linkText,
spanStyles =
listOf(
AnnotatedString.Range(
item =
SpanStyle(
color = MaterialTheme.customColors.linkColor,
textDecoration = TextDecoration.Underline,
),
start = 0,
end = linkText.length,
)
),
)
Row(verticalAlignment = Alignment.CenterVertically, horizontalArrangement = Arrangement.Center) {
Icon(icon, contentDescription = "", modifier = Modifier.size(16.dp))
Text(
text = annotatedText,
textAlign = TextAlign.Center,
style = MaterialTheme.typography.bodyLarge,
modifier = Modifier.padding(start = 6.dp).clickable { uriHandler.openUri(url) },
)
}
}

View file

@ -0,0 +1,41 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.common
import androidx.compose.material3.MaterialTheme
import androidx.compose.runtime.Composable
import androidx.compose.ui.graphics.Color
import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.theme.customColors
@Composable
fun getTaskBgColor(task: Task): Color {
val colorIndex: Int = task.index % MaterialTheme.customColors.taskBgColors.size
return MaterialTheme.customColors.taskBgColors[colorIndex]
}
@Composable
fun getTaskIconColor(task: Task): Color {
val colorIndex: Int = task.index % MaterialTheme.customColors.taskIconColors.size
return MaterialTheme.customColors.taskIconColors[colorIndex]
}
@Composable
fun getTaskIconColor(index: Int): Color {
val colorIndex: Int = index % MaterialTheme.customColors.taskIconColors.size
return MaterialTheme.customColors.taskIconColors[colorIndex]
}

View file

@ -14,8 +14,11 @@
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.common.chat
package com.google.ai.edge.gallery.ui.common
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.preview.MODEL_TEST1
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import android.util.Log
import androidx.compose.foundation.border
import androidx.compose.foundation.clickable
@ -61,7 +64,6 @@ import androidx.compose.ui.graphics.SolidColor
import androidx.compose.ui.platform.LocalFocusManager
import androidx.compose.ui.text.TextStyle
import androidx.compose.ui.text.input.KeyboardType
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog
import com.google.ai.edge.gallery.data.BooleanSwitchConfig
@ -70,8 +72,6 @@ import com.google.ai.edge.gallery.data.LabelConfig
import com.google.ai.edge.gallery.data.NumberSliderConfig
import com.google.ai.edge.gallery.data.SegmentedButtonConfig
import com.google.ai.edge.gallery.data.ValueType
import com.google.ai.edge.gallery.ui.preview.MODEL_TEST1
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.labelSmallNarrow
import kotlin.Double.Companion.NaN
@ -92,33 +92,32 @@ fun ConfigDialog(
showCancel: Boolean = true,
) {
val values: SnapshotStateMap<String, Any> = remember {
mutableStateMapOf<String, Any>().apply {
putAll(initialValues)
}
mutableStateMapOf<String, Any>().apply { putAll(initialValues) }
}
val interactionSource = remember { MutableInteractionSource() }
Dialog(onDismissRequest = onDismissed) {
val focusManager = LocalFocusManager.current
Card(
modifier = Modifier
.fillMaxWidth()
.clickable(
interactionSource = interactionSource, indication = null // Disable the ripple effect
modifier =
Modifier.fillMaxWidth().clickable(
interactionSource = interactionSource,
indication = null, // Disable the ripple effect
) {
focusManager.clearFocus()
},
shape = RoundedCornerShape(16.dp)
shape = RoundedCornerShape(16.dp),
) {
Column(
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
modifier = Modifier.padding(20.dp),
verticalArrangement = Arrangement.spacedBy(16.dp),
) {
// Dialog title and subtitle.
Column {
Text(
title,
style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(bottom = 8.dp)
modifier = Modifier.padding(bottom = 8.dp),
)
// Subtitle.
if (subtitle.isNotEmpty()) {
@ -126,35 +125,27 @@ fun ConfigDialog(
subtitle,
style = labelSmallNarrow,
color = MaterialTheme.colorScheme.onSurfaceVariant,
modifier = Modifier.offset(y = (-6).dp)
modifier = Modifier.offset(y = (-6).dp),
)
}
}
// List of config rows.
Column(
modifier = Modifier
.verticalScroll(rememberScrollState())
.weight(1f, fill = false),
verticalArrangement = Arrangement.spacedBy(16.dp)
modifier = Modifier.verticalScroll(rememberScrollState()).weight(1f, fill = false),
verticalArrangement = Arrangement.spacedBy(16.dp),
) {
ConfigEditorsPanel(configs = configs, values = values)
}
// Button row.
Row(
modifier = Modifier
.fillMaxWidth()
.padding(top = 8.dp),
modifier = Modifier.fillMaxWidth().padding(top = 8.dp),
horizontalArrangement = Arrangement.End,
) {
// Cancel button.
if (showCancel) {
TextButton(
onClick = { onDismissed() },
) {
Text("Cancel")
}
TextButton(onClick = { onDismissed() }) { Text("Cancel") }
}
// Ok button
@ -162,7 +153,7 @@ fun ConfigDialog(
onClick = {
Log.d(TAG, "Values from dialog: $values")
onOk(values.toMap())
},
}
) {
Text(okBtnLabel)
}
@ -172,9 +163,7 @@ fun ConfigDialog(
}
}
/**
* Composable function to display a list of config editor rows.
*/
/** Composable function to display a list of config editor rows. */
@Composable
fun ConfigEditorsPanel(configs: List<Config>, values: SnapshotStateMap<String, Any>) {
for (config in configs) {
@ -210,11 +199,12 @@ fun LabelRow(config: LabelConfig, values: SnapshotStateMap<String, Any>) {
// Field label.
Text(config.key.label, style = MaterialTheme.typography.titleSmall)
// Content label.
val label = try {
values[config.key.label] as String
} catch (e: Exception) {
""
}
val label =
try {
values[config.key.label] as String
} catch (e: Exception) {
""
}
Text(label, style = MaterialTheme.typography.bodyMedium)
}
}
@ -222,9 +212,9 @@ fun LabelRow(config: LabelConfig, values: SnapshotStateMap<String, Any>) {
/**
* Composable function to display a number slider with an associated text input field.
*
* This function renders a row containing a slider and a text field, both used to modify
* a numeric value. The slider allows users to visually adjust the value within a specified range,
* while the text field provides precise numeric input.
* This function renders a row containing a slider and a text field, both used to modify a numeric
* value. The slider allows users to visually adjust the value within a specified range, while the
* text field provides precise numeric input.
*/
@Composable
fun NumberSliderRow(config: NumberSliderConfig, values: SnapshotStateMap<String, Any>) {
@ -233,52 +223,50 @@ fun NumberSliderRow(config: NumberSliderConfig, values: SnapshotStateMap<String,
Text(config.key.label, style = MaterialTheme.typography.titleSmall)
// Controls row.
Row(
modifier = Modifier.fillMaxWidth(), verticalAlignment = Alignment.CenterVertically
) {
Row(modifier = Modifier.fillMaxWidth(), verticalAlignment = Alignment.CenterVertically) {
var isFocused by remember { mutableStateOf(false) }
val focusRequester = remember { FocusRequester() }
// Number slider.
val sliderValue = try {
values[config.key.label] as Float
} catch (e: Exception) {
0f
}
Slider(modifier = Modifier
.height(24.dp)
.weight(1f),
val sliderValue =
try {
values[config.key.label] as Float
} catch (e: Exception) {
0f
}
Slider(
modifier = Modifier.height(24.dp).weight(1f),
value = sliderValue,
valueRange = config.sliderMin..config.sliderMax,
onValueChange = { values[config.key.label] = it })
onValueChange = { values[config.key.label] = it },
)
Spacer(modifier = Modifier.width(8.dp))
// Text field.
val textFieldValue = try {
when (config.valueType) {
ValueType.FLOAT -> {
"%.2f".format(values[config.key.label] as Float)
}
val textFieldValue =
try {
when (config.valueType) {
ValueType.FLOAT -> {
"%.2f".format(values[config.key.label] as Float)
}
ValueType.INT -> {
"${(values[config.key.label] as Float).toInt()}"
}
ValueType.INT -> {
"${(values[config.key.label] as Float).toInt()}"
}
else -> {
""
else -> {
""
}
}
} catch (e: Exception) {
""
}
} catch (e: Exception) {
""
}
// A smaller text field.
BasicTextField(
value = textFieldValue,
modifier = Modifier
.width(80.dp)
.focusRequester(focusRequester)
.onFocusChanged {
modifier =
Modifier.width(80.dp).focusRequester(focusRequester).onFocusChanged {
isFocused = it.isFocused
},
keyboardOptions = KeyboardOptions(keyboardType = KeyboardType.Number),
@ -293,15 +281,16 @@ fun NumberSliderRow(config: NumberSliderConfig, values: SnapshotStateMap<String,
cursorBrush = SolidColor(MaterialTheme.colorScheme.onSurface),
) { innerTextField ->
Box(
modifier = Modifier.border(
width = if (isFocused) 2.dp else 1.dp,
color = if (isFocused) MaterialTheme.colorScheme.primary else MaterialTheme.colorScheme.outline,
shape = RoundedCornerShape(4.dp)
)
modifier =
Modifier.border(
width = if (isFocused) 2.dp else 1.dp,
color =
if (isFocused) MaterialTheme.colorScheme.primary
else MaterialTheme.colorScheme.outline,
shape = RoundedCornerShape(4.dp),
)
) {
Box(modifier = Modifier.padding(8.dp)) {
innerTextField()
}
Box(modifier = Modifier.padding(8.dp)) { innerTextField() }
}
}
}
@ -311,16 +300,17 @@ fun NumberSliderRow(config: NumberSliderConfig, values: SnapshotStateMap<String,
/**
* Composable function to display a row with a boolean switch.
*
* This function renders a row containing a label and a switch, allowing users to toggle
* a boolean value.
* This function renders a row containing a label and a switch, allowing users to toggle a boolean
* value.
*/
@Composable
fun BooleanSwitchRow(config: BooleanSwitchConfig, values: SnapshotStateMap<String, Any>) {
val switchValue = try {
values[config.key.label] as Boolean
} catch (e: Exception) {
false
}
val switchValue =
try {
values[config.key.label] as Boolean
} catch (e: Exception) {
false
}
Column(modifier = Modifier.fillMaxWidth()) {
Text(config.key.label, style = MaterialTheme.typography.titleSmall)
Switch(checked = switchValue, onCheckedChange = { values[config.key.label] = it })
@ -331,63 +321,66 @@ fun BooleanSwitchRow(config: BooleanSwitchConfig, values: SnapshotStateMap<Strin
fun SegmentedButtonRow(config: SegmentedButtonConfig, values: SnapshotStateMap<String, Any>) {
val selectedOptions: List<String> = remember { (values[config.key.label] as String).split(",") }
var selectionStates: List<Boolean> by remember {
mutableStateOf(List(config.options.size) { index ->
selectedOptions.contains(config.options[index])
})
mutableStateOf(
List(config.options.size) { index -> selectedOptions.contains(config.options[index]) }
)
}
Column(modifier = Modifier.fillMaxWidth()) {
Text(config.key.label, style = MaterialTheme.typography.titleSmall)
MultiChoiceSegmentedButtonRow {
config.options.forEachIndexed { index, label ->
SegmentedButton(shape = SegmentedButtonDefaults.itemShape(
index = index, count = config.options.size
), onCheckedChange = {
var newSelectionStates = selectionStates.toMutableList()
val selectedCount = newSelectionStates.count { it }
SegmentedButton(
shape = SegmentedButtonDefaults.itemShape(index = index, count = config.options.size),
onCheckedChange = {
var newSelectionStates = selectionStates.toMutableList()
val selectedCount = newSelectionStates.count { it }
// Single select.
if (!config.allowMultiple) {
if (!newSelectionStates[index]) {
newSelectionStates = MutableList(config.options.size) { it == index }
// Single select.
if (!config.allowMultiple) {
if (!newSelectionStates[index]) {
newSelectionStates = MutableList(config.options.size) { it == index }
}
}
}
// Multiple select.
else {
if (!(selectedCount == 1 && newSelectionStates[index])) {
newSelectionStates[index] = !newSelectionStates[index]
// Multiple select.
else {
if (!(selectedCount == 1 && newSelectionStates[index])) {
newSelectionStates[index] = !newSelectionStates[index]
}
}
}
selectionStates = newSelectionStates
selectionStates = newSelectionStates
values[config.key.label] =
config.options.filterIndexed { index, option -> selectionStates[index] }
.joinToString(",")
}, checked = selectionStates[index], label = { Text(label) })
values[config.key.label] =
config.options
.filterIndexed { index, option -> selectionStates[index] }
.joinToString(",")
},
checked = selectionStates[index],
label = { Text(label) },
)
}
}
}
}
@Composable
@Preview(showBackground = true)
fun ConfigDialogPreview() {
GalleryTheme {
val defaultValues: MutableMap<String, Any> = mutableMapOf()
for (config in MODEL_TEST1.configs) {
defaultValues[config.key.label] = config.defaultValue
}
// @Composable
// @Preview(showBackground = true)
// fun ConfigDialogPreview() {
// GalleryTheme {
// val defaultValues: MutableMap<String, Any> = mutableMapOf()
// for (config in MODEL_TEST1.configs) {
// defaultValues[config.key.label] = config.defaultValue
// }
Column {
ConfigDialog(
title = "Dialog title",
subtitle = "20250413",
configs = MODEL_TEST1.configs,
initialValues = defaultValues,
onDismissed = {},
onOk = {},
)
}
}
}
// Column {
// ConfigDialog(
// title = "Dialog title",
// subtitle = "20250413",
// configs = MODEL_TEST1.configs,
// initialValues = defaultValues,
// onDismissed = {},
// onOk = {},
// )
// }
// }
// }

View file

@ -16,8 +16,8 @@
package com.google.ai.edge.gallery.ui.common
import android.app.ActivityManager
import android.content.Intent
import android.net.Uri
import android.util.Log
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.ActivityResultLauncher
@ -51,16 +51,17 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import androidx.core.net.toUri
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.modelmanager.TokenRequestResultType
import com.google.ai.edge.gallery.ui.modelmanager.TokenStatus
import java.net.HttpURLConnection
import kotlin.math.max
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import java.net.HttpURLConnection
private const val TAG = "AGDownloadAndTryButton"
@ -72,14 +73,14 @@ private const val TAG = "AGDownloadAndTryButton"
* Handles the "Download & Try it" button click, managing the model download process based on
* various conditions.
*
* If the button is enabled and not currently checking the token, it initiates a coroutine to
* handle the download logic.
* If the button is enabled and not currently checking the token, it initiates a coroutine to handle
* the download logic.
*
* For models requiring download first, it specifically addresses HuggingFace URLs by first
* checking if authentication is necessary. If no authentication is needed, the download starts
* directly. Otherwise, it checks the current token status; if the token is invalid or expired,
* a token exchange flow is initiated. If a valid token exists, it attempts to access the
* download URL. If access is granted, the download begins; if not, a new token is requested.
* For models requiring download first, it specifically addresses HuggingFace URLs by first checking
* if authentication is necessary. If no authentication is needed, the download starts directly.
* Otherwise, it checks the current token status; if the token is invalid or expired, a token
* exchange flow is initiated. If a valid token exists, it attempts to access the download URL. If
* access is granted, the download begins; if not, a new token is requested.
*
* For non-HuggingFace URLs that need downloading, the download starts directly.
*
@ -102,21 +103,21 @@ fun DownloadAndTryButton(
enabled: Boolean,
needToDownloadFirst: Boolean,
modelManagerViewModel: ModelManagerViewModel,
onClicked: () -> Unit
onClicked: () -> Unit,
) {
val scope = rememberCoroutineScope()
val context = LocalContext.current
var checkingToken by remember { mutableStateOf(false) }
var showAgreementAckSheet by remember { mutableStateOf(false) }
var showErrorDialog by remember { mutableStateOf(false) }
var showMemoryWarning by remember { mutableStateOf(false) }
val sheetState = rememberModalBottomSheetState()
// A launcher for requesting notification permission.
val permissionLauncher = rememberLauncherForActivityResult(
ActivityResultContracts.RequestPermission()
) {
modelManagerViewModel.downloadModel(task = task, model = model)
}
val permissionLauncher =
rememberLauncherForActivityResult(ActivityResultContracts.RequestPermission()) {
modelManagerViewModel.downloadModel(task = task, model = model)
}
// Function to kick off download.
val startDownload: (accessToken: String?) -> Unit = { accessToken ->
@ -127,64 +128,73 @@ fun DownloadAndTryButton(
launcher = permissionLauncher,
modelManagerViewModel = modelManagerViewModel,
task = task,
model = model
model = model,
)
checkingToken = false
}
// A launcher for opening the custom tabs intent for requesting user agreement ack.
// Once the tab is closed, try starting the download process.
val agreementAckLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult(
contract = ActivityResultContracts.StartActivityForResult()
) { result ->
Log.d(TAG, "User closes the browser tab. Try to start downloading.")
startDownload(modelManagerViewModel.curAccessToken)
}
val agreementAckLauncher: ActivityResultLauncher<Intent> =
rememberLauncherForActivityResult(
contract = ActivityResultContracts.StartActivityForResult()
) { result ->
Log.d(TAG, "User closes the browser tab. Try to start downloading.")
startDownload(modelManagerViewModel.curAccessToken)
}
// A launcher for handling the authentication flow.
// It processes the result of the authentication activity and then checks if a user agreement
// acknowledgement is needed before proceeding with the model download.
val authResultLauncher = rememberLauncherForActivityResult(
contract = ActivityResultContracts.StartActivityForResult()
) { result ->
modelManagerViewModel.handleAuthResult(result, onTokenRequested = { tokenRequestResult ->
when (tokenRequestResult.status) {
TokenRequestResultType.SUCCEEDED -> {
Log.d(TAG, "Token request succeeded. Checking if we need user to ack user agreement")
scope.launch(Dispatchers.IO) {
// Check if we can use the current token to access model. If not, we might need to
// acknowledge the user agreement.
if (modelManagerViewModel.getModelUrlResponse(
model = model,
accessToken = modelManagerViewModel.curAccessToken
) == HttpURLConnection.HTTP_FORBIDDEN
) {
Log.d(TAG, "Model '${model.name}' needs user agreement ack.")
showAgreementAckSheet = true
} else {
Log.d(
TAG,
"Model '${model.name}' does NOT need user agreement ack. Start downloading..."
)
withContext(Dispatchers.Main) {
startDownload(modelManagerViewModel.curAccessToken)
val authResultLauncher =
rememberLauncherForActivityResult(
contract = ActivityResultContracts.StartActivityForResult()
) { result ->
modelManagerViewModel.handleAuthResult(
result,
onTokenRequested = { tokenRequestResult ->
when (tokenRequestResult.status) {
TokenRequestResultType.SUCCEEDED -> {
Log.d(TAG, "Token request succeeded. Checking if we need user to ack user agreement")
scope.launch(Dispatchers.IO) {
// Check if we can use the current token to access model. If not, we might need to
// acknowledge the user agreement.
if (
modelManagerViewModel.getModelUrlResponse(
model = model,
accessToken = modelManagerViewModel.curAccessToken,
) == HttpURLConnection.HTTP_FORBIDDEN
) {
Log.d(TAG, "Model '${model.name}' needs user agreement ack.")
showAgreementAckSheet = true
} else {
Log.d(
TAG,
"Model '${model.name}' does NOT need user agreement ack. Start downloading...",
)
withContext(Dispatchers.Main) {
startDownload(modelManagerViewModel.curAccessToken)
}
}
}
}
TokenRequestResultType.FAILED -> {
Log.d(
TAG,
"Token request done. Error message: ${tokenRequestResult.errorMessage ?: ""}",
)
checkingToken = false
}
TokenRequestResultType.USER_CANCELLED -> {
Log.d(TAG, "User cancelled. Do nothing")
checkingToken = false
}
}
}
TokenRequestResultType.FAILED -> {
Log.d(TAG, "Token request done. Error message: ${tokenRequestResult.errorMessage ?: ""}")
checkingToken = false
}
TokenRequestResultType.USER_CANCELLED -> {
Log.d(TAG, "User cancelled. Do nothing")
checkingToken = false
}
}
})
}
},
)
}
// Function to kick off the authentication and token exchange flow.
val startTokenExchange = {
@ -213,14 +223,12 @@ fun DownloadAndTryButton(
// Check if the url needs auth.
Log.d(
TAG,
"Model '${model.name}' is from HuggingFace. Checking if the url needs auth to download"
"Model '${model.name}' is from HuggingFace. Checking if the url needs auth to download",
)
val firstResponseCode = modelManagerViewModel.getModelUrlResponse(model = model)
if (firstResponseCode == HttpURLConnection.HTTP_OK) {
Log.d(TAG, "Model '${model.name}' doesn't need auth. Start downloading the model...")
withContext(Dispatchers.Main) {
startDownload(null)
}
withContext(Dispatchers.Main) { startDownload(null) }
return@launch
} else if (firstResponseCode < 0) {
checkingToken = false
@ -235,37 +243,36 @@ fun DownloadAndTryButton(
when (tokenStatusAndData.status) {
// If token is not stored or expired, log in and request a new token.
TokenStatus.NOT_STORED, TokenStatus.EXPIRED -> {
withContext(Dispatchers.Main) {
startTokenExchange()
}
TokenStatus.NOT_STORED,
TokenStatus.EXPIRED -> {
withContext(Dispatchers.Main) { startTokenExchange() }
}
// If token is still valid...
TokenStatus.NOT_EXPIRED -> {
// Use the current token to check the download url.
Log.d(TAG, "Checking the download url '${model.url}' with the current token...")
val responseCode = modelManagerViewModel.getModelUrlResponse(
model = model, accessToken = tokenStatusAndData.data!!.accessToken
)
val responseCode =
modelManagerViewModel.getModelUrlResponse(
model = model,
accessToken = tokenStatusAndData.data!!.accessToken,
)
if (responseCode == HttpURLConnection.HTTP_OK) {
// Download url is accessible. Download the model.
Log.d(TAG, "Download url is accessible with the current token.")
withContext(Dispatchers.Main) {
startDownload(tokenStatusAndData.data.accessToken)
startDownload(tokenStatusAndData.data!!.accessToken)
}
}
// Download url is NOT accessible. Request a new token.
else {
Log.d(
TAG,
"Download url is NOT accessible. Response code: ${responseCode}. Trying to request a new token."
"Download url is NOT accessible. Response code: ${responseCode}. Trying to request a new token.",
)
withContext(Dispatchers.Main) {
startTokenExchange()
}
withContext(Dispatchers.Main) { startTokenExchange() }
}
}
}
@ -274,24 +281,50 @@ fun DownloadAndTryButton(
else {
Log.d(
TAG,
"Model '${model.name}' is not from huggingface. Start downloading the model..."
"Model '${model.name}' is not from huggingface. Start downloading the model...",
)
withContext(Dispatchers.Main) {
startDownload(null)
}
withContext(Dispatchers.Main) { startDownload(null) }
}
} else {
withContext(Dispatchers.Main) {
onClicked()
val activityManager =
context.getSystemService(android.app.Activity.ACTIVITY_SERVICE) as? ActivityManager
val estimatedPeakMemoryInBytes = model.estimatedPeakMemoryInBytes
val isMemoryLow =
if (activityManager != null && estimatedPeakMemoryInBytes != null) {
val memoryInfo = ActivityManager.MemoryInfo()
activityManager.getMemoryInfo(memoryInfo)
Log.d(
TAG,
"AvailMem: ${memoryInfo.availMem}. TotalMem: ${memoryInfo.totalMem}. Estimated peak memory: ${estimatedPeakMemoryInBytes}.",
)
// The device should be able to run the model if `availMem` is larger than the
// estimated peak memory. Android also has a mechanism to kill background apps to
// free up memory for the foreground app. We believe that if half of the total
// memory on the device is larger than the estimated peak memory, it can run the
// model fine with this mechanism. For example, a phone with 12GB memory can have
// very few `availMem` but will have no problem running most models.
max(memoryInfo.availMem, memoryInfo.totalMem / 2) < estimatedPeakMemoryInBytes
} else {
false
}
if (isMemoryLow) {
showMemoryWarning = true
} else {
onClicked()
}
}
}
}
},
}
) {
Icon(
Icons.AutoMirrored.Rounded.ArrowForward,
contentDescription = "",
modifier = Modifier.padding(end = 4.dp)
modifier = Modifier.padding(end = 4.dp),
)
val textColor = MaterialTheme.colorScheme.onPrimary
@ -301,11 +334,7 @@ fun DownloadAndTryButton(
maxLines = 1,
color = { textColor },
style = MaterialTheme.typography.bodyMedium,
autoSize = TextAutoSize.StepBased(
minFontSize = 8.sp,
maxFontSize = 14.sp,
stepSize = 1.sp
)
autoSize = TextAutoSize.StepBased(minFontSize = 8.sp, maxFontSize = 14.sp, stepSize = 1.sp),
)
} else {
if (needToDownloadFirst) {
@ -314,11 +343,8 @@ fun DownloadAndTryButton(
maxLines = 1,
color = { textColor },
style = MaterialTheme.typography.bodyMedium,
autoSize = TextAutoSize.StepBased(
minFontSize = 8.sp,
maxFontSize = 14.sp,
stepSize = 1.sp
)
autoSize =
TextAutoSize.StepBased(minFontSize = 8.sp, maxFontSize = 14.sp, stepSize = 1.sp),
)
} else {
Text("Try it", maxLines = 1)
@ -341,28 +367,30 @@ fun DownloadAndTryButton(
) {
Column(
horizontalAlignment = Alignment.CenterHorizontally,
modifier = Modifier.padding(horizontal = 16.dp)
modifier = Modifier.padding(horizontal = 16.dp),
) {
Text("Acknowledge user agreement", style = MaterialTheme.typography.titleLarge)
Text(
"This is a gated model. Please click the button below to view and agree to the user agreement. After accepting, simply close that tab to proceed with the model download.",
style = MaterialTheme.typography.bodyMedium,
modifier = Modifier.padding(vertical = 16.dp)
modifier = Modifier.padding(vertical = 16.dp),
)
Button(onClick = {
// Get agreement url from model url.
val index = model.url.indexOf("/resolve/")
// Show it in a tab.
if (index >= 0) {
val agreementUrl = model.url.substring(0, index)
Button(
onClick = {
// Get agreement url from model url.
val index = model.url.indexOf("/resolve/")
// Show it in a tab.
if (index >= 0) {
val agreementUrl = model.url.substring(0, index)
val customTabsIntent = CustomTabsIntent.Builder().build()
customTabsIntent.intent.setData(Uri.parse(agreementUrl))
agreementAckLauncher.launch(customTabsIntent.intent)
val customTabsIntent = CustomTabsIntent.Builder().build()
customTabsIntent.intent.setData(agreementUrl.toUri())
agreementAckLauncher.launch(customTabsIntent.intent)
}
// Dismiss the sheet.
showAgreementAckSheet = false
}
// Dismiss the sheet.
showAgreementAckSheet = false
}) {
) {
Text("Open user agreement")
}
}
@ -374,24 +402,34 @@ fun DownloadAndTryButton(
icon = {
Icon(Icons.Rounded.Error, contentDescription = "", tint = MaterialTheme.colorScheme.error)
},
title = {
Text("Unknown network error")
},
title = { Text("Unknown network error") },
text = { Text("Please check your internet connection.") },
onDismissRequest = { showErrorDialog = false },
confirmButton = { TextButton(onClick = { showErrorDialog = false }) { Text("Close") } },
)
}
if (showMemoryWarning) {
AlertDialog(
title = { Text("Memory Warning") },
text = {
Text("Please check your internet connection.")
},
onDismissRequest = {
showErrorDialog = false
Text(
"This model might need more memory than your device has available. " +
"Running it could cause the app to crash."
)
},
onDismissRequest = { showMemoryWarning = false },
confirmButton = {
TextButton(
onClick = {
showErrorDialog = false
onClicked()
showMemoryWarning = false
}
) {
Text("Close")
Text("Continue")
}
},
dismissButton = { TextButton(onClick = { showMemoryWarning = false }) { Text("Cancel") } },
)
}
}

View file

@ -33,18 +33,17 @@ import androidx.compose.ui.window.Dialog
@Composable
fun ErrorDialog(error: String, onDismiss: () -> Unit) {
Dialog(
onDismissRequest = onDismiss
) {
Dialog(onDismissRequest = onDismiss) {
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
Column(
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
modifier = Modifier.padding(20.dp),
verticalArrangement = Arrangement.spacedBy(16.dp),
) {
// Title
Text(
"Error",
style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(bottom = 8.dp)
modifier = Modifier.padding(bottom = 8.dp),
)
// Error
@ -55,11 +54,7 @@ fun ErrorDialog(error: String, onDismiss: () -> Unit) {
)
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.End) {
Button(
onClick = onDismiss
) {
Text("Close")
}
Button(onClick = onDismiss) { Text("Close") }
}
}
}

View file

@ -14,7 +14,7 @@
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.common.chat
package com.google.ai.edge.gallery.ui.common
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ProvideTextStyle
@ -25,8 +25,6 @@ import androidx.compose.ui.text.SpanStyle
import androidx.compose.ui.text.TextLinkStyles
import androidx.compose.ui.text.TextStyle
import androidx.compose.ui.text.font.FontFamily
import androidx.compose.ui.tooling.preview.Preview
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.customColors
import com.halilibo.richtext.commonmark.Markdown
import com.halilibo.richtext.ui.CodeBlockStyle
@ -34,52 +32,43 @@ import com.halilibo.richtext.ui.RichTextStyle
import com.halilibo.richtext.ui.material3.RichText
import com.halilibo.richtext.ui.string.RichTextStringStyle
/**
* Composable function to display Markdown-formatted text.
*/
/** Composable function to display Markdown-formatted text. */
@Composable
fun MarkdownText(
text: String,
modifier: Modifier = Modifier,
smallFontSize: Boolean = false
) {
fun MarkdownText(text: String, modifier: Modifier = Modifier, smallFontSize: Boolean = false) {
val fontSize =
if (smallFontSize) MaterialTheme.typography.bodyMedium.fontSize else MaterialTheme.typography.bodyLarge.fontSize
if (smallFontSize) MaterialTheme.typography.bodyMedium.fontSize
else MaterialTheme.typography.bodyLarge.fontSize
CompositionLocalProvider {
ProvideTextStyle(
value = TextStyle(
fontSize = fontSize,
lineHeight = fontSize * 1.3,
)
) {
ProvideTextStyle(value = TextStyle(fontSize = fontSize, lineHeight = fontSize * 1.3)) {
RichText(
modifier = modifier,
style = RichTextStyle(
codeBlockStyle = CodeBlockStyle(
textStyle = TextStyle(
fontSize = MaterialTheme.typography.bodySmall.fontSize,
fontFamily = FontFamily.Monospace,
)
style =
RichTextStyle(
codeBlockStyle =
CodeBlockStyle(
textStyle =
TextStyle(
fontSize = MaterialTheme.typography.bodySmall.fontSize,
fontFamily = FontFamily.Monospace,
)
),
stringStyle =
RichTextStringStyle(
linkStyle =
TextLinkStyles(style = SpanStyle(color = MaterialTheme.customColors.linkColor))
),
),
stringStyle = RichTextStringStyle(
linkStyle = TextLinkStyles(
style = SpanStyle(color = MaterialTheme.customColors.linkColor)
)
)
),
) {
Markdown(
content = text
)
Markdown(content = text)
}
}
}
}
@Preview(showBackground = true)
@Composable
fun MarkdownTextPreview() {
GalleryTheme {
MarkdownText(text = "*Hello World*\n**Good morning!!**")
}
}
// @Preview(showBackground = true)
// @Composable
// fun MarkdownTextPreview() {
// GalleryTheme {
// MarkdownText(text = "*Hello World*\n**Good morning!!**")
// }
// }

View file

@ -53,7 +53,7 @@ import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.ModelDownloadStatusType
import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.common.chat.ConfigDialog
import com.google.ai.edge.gallery.data.convertValueToTargetType
import com.google.ai.edge.gallery.ui.modelmanager.ModelInitializationStatusType
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
@ -71,54 +71,54 @@ fun ModelPageAppBar(
isResettingSession: Boolean = false,
onResetSessionClicked: (Model) -> Unit = {},
canShowResetSessionButton: Boolean = false,
onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit = { _, _ -> },
onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit =
{ _, _ ->
},
) {
var showConfigDialog by remember { mutableStateOf(false) }
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val context = LocalContext.current
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[model.name]
val modelInitializationStatus =
modelManagerUiState.modelInitializationStatus[model.name]
val modelInitializationStatus = modelManagerUiState.modelInitializationStatus[model.name]
CenterAlignedTopAppBar(title = {
Column(
horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.spacedBy(4.dp)
) {
// Task type.
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp)
CenterAlignedTopAppBar(
title = {
Column(
horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.spacedBy(4.dp),
) {
Icon(
task.icon ?: ImageVector.vectorResource(task.iconVectorResourceId!!),
tint = getTaskIconColor(task = task),
modifier = Modifier.size(16.dp),
contentDescription = "",
)
Text(
task.type.label,
style = MaterialTheme.typography.titleMedium.copy(fontWeight = FontWeight.SemiBold),
color = getTaskIconColor(task = task)
// Task type.
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp),
) {
Icon(
task.icon ?: ImageVector.vectorResource(task.iconVectorResourceId!!),
tint = getTaskIconColor(task = task),
modifier = Modifier.size(16.dp),
contentDescription = "",
)
Text(
task.type.label,
style = MaterialTheme.typography.titleMedium.copy(fontWeight = FontWeight.SemiBold),
color = getTaskIconColor(task = task),
)
}
// Model chips pager.
ModelPickerChipsPager(
task = task,
initialModel = model,
modelManagerViewModel = modelManagerViewModel,
onModelSelected = onModelSelected,
)
}
// Model chips pager.
ModelPickerChipsPager(
task = task,
initialModel = model,
modelManagerViewModel = modelManagerViewModel,
onModelSelected = onModelSelected,
)
}
}, modifier = modifier,
},
modifier = modifier,
// The back button.
navigationIcon = {
IconButton(onClick = onBackClicked) {
Icon(
imageVector = Icons.AutoMirrored.Rounded.ArrowBack,
contentDescription = "",
)
Icon(imageVector = Icons.AutoMirrored.Rounded.ArrowBack, contentDescription = "")
}
},
// The config button for the model (if existed).
@ -136,19 +136,16 @@ fun ModelPageAppBar(
if (showConfigButton) {
val enableConfigButton = !isModelInitializing && !inProgress
IconButton(
onClick = {
showConfigDialog = true
},
onClick = { showConfigDialog = true },
enabled = enableConfigButton,
modifier = Modifier
.offset(x = configButtonOffset)
.alpha(if (!enableConfigButton) 0.5f else 1f)
modifier =
Modifier.offset(x = configButtonOffset).alpha(if (!enableConfigButton) 0.5f else 1f),
) {
Icon(
imageVector = Icons.Rounded.Tune,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary,
modifier = Modifier.size(20.dp)
modifier = Modifier.size(20.dp),
)
}
}
@ -157,39 +154,35 @@ fun ModelPageAppBar(
CircularProgressIndicator(
trackColor = MaterialTheme.colorScheme.surfaceVariant,
strokeWidth = 2.dp,
modifier = Modifier.size(16.dp)
modifier = Modifier.size(16.dp),
)
} else {
val enableResetButton = !isModelInitializing && !modelPreparing
IconButton(
onClick = {
onResetSessionClicked(model)
},
onClick = { onResetSessionClicked(model) },
enabled = enableResetButton,
modifier = Modifier
.alpha(if (!enableResetButton) 0.5f else 1f)
modifier = Modifier.alpha(if (!enableResetButton) 0.5f else 1f),
) {
Box(
modifier = Modifier
.size(32.dp)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.surfaceContainer),
contentAlignment = Alignment.Center
modifier =
Modifier.size(32.dp)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.surfaceContainer),
contentAlignment = Alignment.Center,
) {
Icon(
imageVector = Icons.Rounded.MapsUgc,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary,
modifier = Modifier
.size(20.dp)
modifier = Modifier.size(20.dp),
)
}
}
}
}
}
})
},
)
// Config dialog.
if (showConfigDialog) {
@ -208,12 +201,16 @@ fun ModelPageAppBar(
var needReinitialization = false
for (config in model.configs) {
val key = config.key.label
val oldValue = convertValueToTargetType(
value = model.configValues.getValue(key), valueType = config.valueType
)
val newValue = convertValueToTargetType(
value = curConfigValues.getValue(key), valueType = config.valueType
)
val oldValue =
convertValueToTargetType(
value = model.configValues.getValue(key),
valueType = config.valueType,
)
val newValue =
convertValueToTargetType(
value = curConfigValues.getValue(key),
valueType = config.valueType,
)
if (oldValue != newValue) {
same = false
if (config.needReinitialization) {
@ -233,7 +230,10 @@ fun ModelPageAppBar(
// Force to re-initialize the model with the new configs.
if (needReinitialization) {
modelManagerViewModel.initializeModel(
context = context, task = task, model = model, force = true
context = context,
task = task,
model = model,
force = true,
)
}
@ -242,4 +242,4 @@ fun ModelPageAppBar(
},
)
}
}
}

View file

@ -16,6 +16,11 @@
package com.google.ai.edge.gallery.ui.common
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
// import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
@ -28,7 +33,6 @@ import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.width
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.CheckCircle
import androidx.compose.material.icons.outlined.CheckCircle
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
@ -39,34 +43,27 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.vectorResource
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.common.modelitem.StatusIcon
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.labelSmallNarrow
@Composable
fun ModelPicker(
task: Task,
modelManagerViewModel: ModelManagerViewModel,
onModelSelected: (Model) -> Unit
onModelSelected: (Model) -> Unit,
) {
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
Column(modifier = Modifier.padding(bottom = 8.dp)) {
// Title
Row(
modifier = Modifier
.padding(horizontal = 16.dp)
.padding(top = 4.dp, bottom = 4.dp),
modifier = Modifier.padding(horizontal = 16.dp).padding(top = 4.dp, bottom = 4.dp),
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(8.dp),
) {
@ -90,51 +87,47 @@ fun ModelPicker(
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.SpaceBetween,
modifier = Modifier
.fillMaxWidth()
.clickable {
onModelSelected(model)
}
.background(if (selected) MaterialTheme.colorScheme.surfaceContainer else Color.Transparent)
.padding(horizontal = 16.dp, vertical = 8.dp),
modifier =
Modifier.fillMaxWidth()
.clickable { onModelSelected(model) }
.background(
if (selected) MaterialTheme.colorScheme.surfaceContainer else Color.Transparent
)
.padding(horizontal = 16.dp, vertical = 8.dp),
) {
Spacer(modifier = Modifier.width(24.dp))
Column(modifier = Modifier.weight(1f)) {
Text(model.name, style = MaterialTheme.typography.bodyMedium)
Row(
horizontalArrangement = Arrangement.spacedBy(4.dp),
verticalAlignment = Alignment.CenterVertically
verticalAlignment = Alignment.CenterVertically,
) {
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
Text(
model.sizeInBytes.humanReadableSize(),
color = MaterialTheme.colorScheme.secondary,
style = labelSmallNarrow.copy(lineHeight = 10.sp)
style = labelSmallNarrow.copy(lineHeight = 10.sp),
)
}
}
if (selected) {
Icon(
Icons.Filled.CheckCircle,
modifier = Modifier.size(16.dp),
contentDescription = ""
)
Icon(Icons.Filled.CheckCircle, modifier = Modifier.size(16.dp), contentDescription = "")
}
}
}
}
}
@Preview(showBackground = true)
@Composable
fun ModelPickerPreview() {
val context = LocalContext.current
// @Preview(showBackground = true)
// @Composable
// fun ModelPickerPreview() {
// val context = LocalContext.current
GalleryTheme {
ModelPicker(
task = TASK_TEST1,
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
onModelSelected = {},
)
}
}
// GalleryTheme {
// ModelPicker(
// task = TASK_TEST1,
// modelManagerViewModel = PreviewModelManagerViewModel(context = context),
// onModelSelected = {},
// )
// }
// }

View file

@ -64,9 +64,9 @@ import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.common.modelitem.StatusIcon
import com.google.ai.edge.gallery.ui.modelmanager.ModelInitializationStatusType
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import kotlin.math.absoluteValue
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlin.math.absoluteValue
@OptIn(ExperimentalMaterial3Api::class)
@Composable
@ -83,14 +83,13 @@ fun ModelPickerChipsPager(
val scope = rememberCoroutineScope()
val density = LocalDensity.current
val windowInfo = LocalWindowInfo.current
val screenWidthDp = remember {
with(density) {
windowInfo.containerSize.width.toDp()
}
}
val screenWidthDp = remember { with(density) { windowInfo.containerSize.width.toDp() } }
val pagerState = rememberPagerState(initialPage = task.models.indexOf(initialModel),
pageCount = { task.models.size })
val pagerState =
rememberPagerState(
initialPage = task.models.indexOf(initialModel),
pageCount = { task.models.size },
)
// Sync scrolling.
LaunchedEffect(modelManagerViewModel.pagerScrollState) {
@ -107,56 +106,51 @@ fun ModelPickerChipsPager(
((pagerState.currentPage - pageIndex) + pagerState.currentPageOffsetFraction).absoluteValue
val curAlpha = 1f - (pageOffset * 1.5f).coerceIn(0f, 1f)
val modelInitializationStatus =
modelManagerUiState.modelInitializationStatus[model.name]
val modelInitializationStatus = modelManagerUiState.modelInitializationStatus[model.name]
Box(
modifier = Modifier
.fillMaxWidth()
.graphicsLayer { alpha = curAlpha },
contentAlignment = Alignment.Center
modifier = Modifier.fillMaxWidth().graphicsLayer { alpha = curAlpha },
contentAlignment = Alignment.Center,
) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(2.dp)
horizontalArrangement = Arrangement.spacedBy(2.dp),
) {
Row(verticalAlignment = Alignment.CenterVertically,
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(2.dp),
modifier = Modifier
.clip(CircleShape)
.background(MaterialTheme.colorScheme.surfaceContainerHigh)
.clickable {
modelPickerModel = model
showModelPicker = true
}
.padding(start = 8.dp, end = 2.dp)
.padding(vertical = 4.dp)) Inner@{
modifier =
Modifier.clip(CircleShape)
.background(MaterialTheme.colorScheme.surfaceContainerHigh)
.clickable {
modelPickerModel = model
showModelPicker = true
}
.padding(start = 8.dp, end = 2.dp)
.padding(vertical = 4.dp),
) Inner@{
Box(contentAlignment = Alignment.Center, modifier = Modifier.size(21.dp)) {
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
this@Inner.AnimatedVisibility(
visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
visible =
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
enter = scaleIn() + fadeIn(),
exit = scaleOut() + fadeOut(),
) {
// Circular progress indicator.
CircularProgressIndicator(
modifier = Modifier
.size(24.dp)
.alpha(0.5f),
modifier = Modifier.size(24.dp).alpha(0.5f),
strokeWidth = 2.dp,
color = MaterialTheme.colorScheme.onSurfaceVariant
color = MaterialTheme.colorScheme.onSurfaceVariant,
)
}
}
Text(
model.name,
style = MaterialTheme.typography.labelLarge,
modifier = Modifier
.padding(start = 4.dp)
.widthIn(0.dp, screenWidthDp - 250.dp),
modifier = Modifier.padding(start = 4.dp).widthIn(0.dp, screenWidthDp - 250.dp),
maxLines = 1,
overflow = TextOverflow.MiddleEllipsis
overflow = TextOverflow.MiddleEllipsis,
)
Icon(
Icons.Rounded.ArrowDropDown,
@ -171,10 +165,7 @@ fun ModelPickerChipsPager(
// Model picker.
val curModelPickerModel = modelPickerModel
if (showModelPicker && curModelPickerModel != null) {
ModalBottomSheet(
onDismissRequest = { showModelPicker = false },
sheetState = sheetState,
) {
ModalBottomSheet(onDismissRequest = { showModelPicker = false }, sheetState = sheetState) {
ModelPicker(
task = task,
modelManagerViewModel = modelManagerViewModel,
@ -187,8 +178,8 @@ fun ModelPickerChipsPager(
}
onModelSelected(selectedModel)
}
},
)
}
}
}
}

View file

@ -52,28 +52,22 @@ private val SHAPES: List<Int> =
listOf(R.drawable.pantegon, R.drawable.double_circle, R.drawable.circle, R.drawable.four_circle)
/**
* Composable that displays an icon representing a task. It consists of a background
* image and a foreground icon, both centered within a square box.
* Composable that displays an icon representing a task. It consists of a background image and a
* foreground icon, both centered within a square box.
*/
@Composable
fun TaskIcon(task: Task, modifier: Modifier = Modifier, width: Dp = 56.dp) {
Box(
modifier = modifier
.width(width)
.aspectRatio(1f),
contentAlignment = Alignment.Center,
) {
Box(modifier = modifier.width(width).aspectRatio(1f), contentAlignment = Alignment.Center) {
Image(
painter = getTaskIconBgShape(task = task),
contentDescription = "",
modifier = Modifier
.fillMaxSize()
.alpha(0.6f),
modifier = Modifier.fillMaxSize().alpha(0.6f),
contentScale = ContentScale.Fit,
colorFilter = ColorFilter.tint(
MaterialTheme.customColors.taskIconShapeBgColor,
blendMode = BlendMode.SrcIn
)
colorFilter =
ColorFilter.tint(
MaterialTheme.customColors.taskIconShapeBgColor,
blendMode = BlendMode.SrcIn,
),
)
Icon(
task.icon ?: ImageVector.vectorResource(task.iconVectorResourceId!!),
@ -102,4 +96,4 @@ fun TaskIconPreview() {
TaskIcon(task = TASK_LLM_CHAT, width = 80.dp)
}
}
}
}

View file

@ -21,63 +21,18 @@ import android.content.Context
import android.content.pm.PackageManager
import android.net.Uri
import android.os.Build
import android.util.Log
import androidx.activity.compose.ManagedActivityResultLauncher
import androidx.compose.material3.MaterialTheme
import androidx.compose.runtime.Composable
import androidx.compose.ui.graphics.Color
import androidx.core.content.ContextCompat
import androidx.core.content.FileProvider
import com.google.ai.edge.gallery.data.Config
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.TASKS
import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.data.ValueType
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkResult
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageType
import com.google.ai.edge.gallery.ui.common.chat.ChatViewModel
import com.google.ai.edge.gallery.ui.common.chat.Histogram
import com.google.ai.edge.gallery.ui.common.chat.Stat
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.theme.customColors
import com.google.gson.Gson
import com.google.gson.reflect.TypeToken
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.json.Json
import java.io.File
import java.net.HttpURLConnection
import java.net.URL
import kotlin.math.abs
import kotlin.math.ln
import kotlin.math.max
import kotlin.math.min
import kotlin.math.pow
import kotlin.math.sqrt
private const val TAG = "AGUtils"
private const val LAUNCH_INFO_FILE_NAME = "launch_info"
private val STATS = listOf(
Stat(id = "min", label = "Min", unit = "ms"),
Stat(id = "max", label = "Max", unit = "ms"),
Stat(id = "avg", label = "Avg", unit = "ms"),
Stat(id = "stddev", label = "Stddev", unit = "ms")
)
interface LatencyProvider {
val latencyMs: Float
}
private const val START_THINKING = "***Thinking...***"
private const val DONE_THINKING = "***Done thinking***"
data class JsonObjAndTextContent<T>(
val jsonObj: T, val textContent: String,
)
data class LaunchInfo(
val ts: Long
)
/** Format the bytes into a human-readable format. */
fun Long.humanReadableSize(si: Boolean = true, extraDecimalForGbAndAbove: Boolean = false): String {
@ -139,320 +94,56 @@ fun Long.formatToHourMinSecond(): String {
return parts.joinToString(" ")
}
fun convertValueToTargetType(value: Any, valueType: ValueType): Any {
return when (valueType) {
ValueType.INT -> when (value) {
is Int -> value
is Float -> value.toInt()
is Double -> value.toInt()
is String -> value.toIntOrNull() ?: ""
is Boolean -> if (value) 1 else 0
else -> ""
}
ValueType.FLOAT -> when (value) {
is Int -> value.toFloat()
is Float -> value
is Double -> value.toFloat()
is String -> value.toFloatOrNull() ?: ""
is Boolean -> if (value) 1f else 0f
else -> ""
}
ValueType.DOUBLE -> when (value) {
is Int -> value.toDouble()
is Float -> value.toDouble()
is Double -> value
is String -> value.toDoubleOrNull() ?: ""
is Boolean -> if (value) 1.0 else 0.0
else -> ""
}
ValueType.BOOLEAN -> when (value) {
is Int -> value == 0
is Boolean -> value
is Float -> abs(value) > 1e-6
is Double -> abs(value) > 1e-6
is String -> value.isNotEmpty()
else -> false
}
ValueType.STRING -> value.toString()
}
}
fun getDistinctiveColor(index: Int): Color {
val colors = listOf(
// Color(0xffe6194b),
Color(0xff3cb44b),
Color(0xffffe119),
Color(0xff4363d8),
Color(0xfff58231),
Color(0xff911eb4),
Color(0xff46f0f0),
Color(0xfff032e6),
Color(0xffbcf60c),
Color(0xfffabebe),
Color(0xff008080),
Color(0xffe6beff),
Color(0xff9a6324),
Color(0xfffffac8),
Color(0xff800000),
Color(0xffaaffc3),
Color(0xff808000),
Color(0xffffd8b1),
Color(0xff000075)
)
val colors =
listOf(
// Color(0xffe6194b),
Color(0xff3cb44b),
Color(0xffffe119),
Color(0xff4363d8),
Color(0xfff58231),
Color(0xff911eb4),
Color(0xff46f0f0),
Color(0xfff032e6),
Color(0xffbcf60c),
Color(0xfffabebe),
Color(0xff008080),
Color(0xffe6beff),
Color(0xff9a6324),
Color(0xfffffac8),
Color(0xff800000),
Color(0xffaaffc3),
Color(0xff808000),
Color(0xffffd8b1),
Color(0xff000075),
)
return colors[index % colors.size]
}
fun Context.createTempPictureUri(
fileName: String = "picture_${System.currentTimeMillis()}", fileExtension: String = ".png"
fileName: String = "picture_${System.currentTimeMillis()}",
fileExtension: String = ".png",
): Uri {
val tempFile = File.createTempFile(
fileName, fileExtension, cacheDir
).apply {
createNewFile()
}
val tempFile = File.createTempFile(fileName, fileExtension, cacheDir).apply { createNewFile() }
return FileProvider.getUriForFile(
applicationContext,
"com.google.aiedge.gallery.provider" /* {applicationId}.provider */,
tempFile
tempFile,
)
}
fun runBasicBenchmark(
model: Model,
warmupCount: Int,
iterations: Int,
chatViewModel: ChatViewModel,
inferenceFn: () -> LatencyProvider,
chatMessageType: ChatMessageType,
) {
val start = System.currentTimeMillis()
var lastUpdateTs = 0L
val update: (ChatMessageBenchmarkResult) -> Unit = { message ->
if (lastUpdateTs == 0L) {
chatViewModel.addMessage(
model = model,
message = message,
)
lastUpdateTs = System.currentTimeMillis()
} else {
val curTs = System.currentTimeMillis()
if (curTs - lastUpdateTs > 500) {
chatViewModel.replaceLastMessage(model = model, message = message, type = chatMessageType)
lastUpdateTs = curTs
}
}
}
// Warmup.
val latencies: MutableList<Float> = mutableListOf()
for (count in 1..warmupCount) {
inferenceFn()
update(
ChatMessageBenchmarkResult(
orderedStats = STATS,
statValues = calculateStats(min = 0f, max = 0f, sum = 0f, latencies = latencies),
histogram = calculateLatencyHistogram(
latencies = latencies, min = 0f, max = 0f, avg = 0f
),
values = latencies,
warmupCurrent = count,
warmupTotal = warmupCount,
iterationCurrent = 0,
iterationTotal = iterations,
latencyMs = (System.currentTimeMillis() - start).toFloat(),
highlightStat = "avg"
)
)
}
// Benchmark iterations.
var min = Float.MAX_VALUE
var max = 0f
var sum = 0f
for (count in 1..iterations) {
val result = inferenceFn()
val latency = result.latencyMs
min = min(min, latency)
max = max(max, latency)
sum += latency
latencies.add(latency)
val curTs = System.currentTimeMillis()
if (curTs - lastUpdateTs > 500 || count == iterations) {
lastUpdateTs = curTs
val stats = calculateStats(min = min, max = max, sum = sum, latencies = latencies)
chatViewModel.replaceLastMessage(
model = model,
message = ChatMessageBenchmarkResult(
orderedStats = STATS,
statValues = stats,
histogram = calculateLatencyHistogram(
latencies = latencies,
min = min,
max = max,
avg = stats["avg"] ?: 0f,
),
values = latencies,
warmupCurrent = warmupCount,
warmupTotal = warmupCount,
iterationCurrent = count,
iterationTotal = iterations,
latencyMs = (System.currentTimeMillis() - start).toFloat(),
highlightStat = "avg"
),
type = chatMessageType,
)
}
// Go through other benchmark messages and update their buckets for the common min/max values.
var allMin = Float.MAX_VALUE
var allMax = 0f
val allMessages = chatViewModel.uiState.value.messagesByModel[model.name] ?: listOf()
for (message in allMessages) {
if (message is ChatMessageBenchmarkResult) {
val curMin = message.statValues["min"] ?: 0f
val curMax = message.statValues["max"] ?: 0f
allMin = min(allMin, curMin)
allMax = max(allMax, curMax)
}
}
for ((index, message) in allMessages.withIndex()) {
if (message === allMessages.last()) {
break
}
if (message is ChatMessageBenchmarkResult) {
val updatedMessage = ChatMessageBenchmarkResult(
orderedStats = STATS,
statValues = message.statValues,
histogram = calculateLatencyHistogram(
latencies = message.values,
min = allMin,
max = allMax,
avg = message.statValues["avg"] ?: 0f,
),
values = message.values,
warmupCurrent = message.warmupCurrent,
warmupTotal = message.warmupTotal,
iterationCurrent = message.iterationCurrent,
iterationTotal = message.iterationTotal,
latencyMs = message.latencyMs,
highlightStat = "avg"
)
chatViewModel.replaceMessage(model = model, index = index, message = updatedMessage)
}
}
}
}
private fun calculateStats(
min: Float, max: Float, sum: Float, latencies: MutableList<Float>
): MutableMap<String, Float> {
val avg = if (latencies.size == 0) 0f else sum / latencies.size
val squaredDifferences = latencies.map { (it - avg).pow(2) }
val variance = squaredDifferences.average()
val stddev = if (latencies.size == 0) 0f else sqrt(variance).toFloat()
var medium = 0f
if (latencies.size == 1) {
medium = latencies[0]
} else if (latencies.size > 1) {
latencies.sort()
val middle = latencies.size / 2
medium =
if (latencies.size % 2 == 0) (latencies[middle - 1] + latencies[middle]) / 2.0f else latencies[middle]
}
return mutableMapOf(
"min" to min, "max" to max, "avg" to avg, "stddev" to stddev, "medium" to medium
)
}
fun calculateLatencyHistogram(
latencies: List<Float>, min: Float, max: Float, avg: Float, numBuckets: Int = 20
): Histogram {
if (latencies.isEmpty() || numBuckets <= 0) {
return Histogram(
buckets = List(numBuckets) { 0 }, maxCount = 0
)
}
if (min == max) {
// All latencies are the same.
val result = MutableList(numBuckets) { 0 }
result[0] = latencies.size
return Histogram(buckets = result, maxCount = result[0], highlightBucketIndex = 0)
}
val bucketSize = (max - min) / numBuckets
val histogram = MutableList(numBuckets) { 0 }
val getBucketIndex: (value: Float) -> Int = {
var bucketIndex = ((it - min) / bucketSize).toInt()
// Handle the case where latency equals maxLatency
if (bucketIndex == numBuckets) {
bucketIndex = numBuckets - 1
}
bucketIndex
}
for (latency in latencies) {
val bucketIndex = getBucketIndex(latency)
histogram[bucketIndex]++
}
val avgBucketIndex = getBucketIndex(avg)
return Histogram(
buckets = histogram,
maxCount = histogram.maxOrNull() ?: 0,
highlightBucketIndex = avgBucketIndex
)
}
fun getConfigValueString(value: Any, config: Config): String {
var strNewValue = "$value"
if (config.valueType == ValueType.FLOAT) {
strNewValue = "%.2f".format(value)
}
return strNewValue
}
@Composable
fun getTaskBgColor(task: Task): Color {
val colorIndex: Int = task.index % MaterialTheme.customColors.taskBgColors.size
return MaterialTheme.customColors.taskBgColors[colorIndex]
}
@Composable
fun getTaskIconColor(task: Task): Color {
val colorIndex: Int = task.index % MaterialTheme.customColors.taskIconColors.size
return MaterialTheme.customColors.taskIconColors[colorIndex]
}
@Composable
fun getTaskIconColor(index: Int): Color {
val colorIndex: Int = index % MaterialTheme.customColors.taskIconColors.size
return MaterialTheme.customColors.taskIconColors[colorIndex]
}
fun checkNotificationPermissionAndStartDownload(
context: Context,
launcher: ManagedActivityResultLauncher<String, Boolean>,
modelManagerViewModel: ModelManagerViewModel,
task: Task,
model: Model
model: Model,
) {
// Check permission
when (PackageManager.PERMISSION_GRANTED) {
// Already got permission. Call the lambda.
ContextCompat.checkSelfPermission(
context, Manifest.permission.POST_NOTIFICATIONS
) -> {
ContextCompat.checkSelfPermission(context, Manifest.permission.POST_NOTIFICATIONS) -> {
modelManagerViewModel.downloadModel(task = task, model = model)
}
@ -468,100 +159,3 @@ fun checkNotificationPermissionAndStartDownload(
fun ensureValidFileName(fileName: String): String {
return fileName.replace(Regex("[^a-zA-Z0-9._-]"), "_")
}
fun cleanUpMediapipeTaskErrorMessage(message: String): String {
val index = message.indexOf("=== Source Location Trace")
if (index >= 0) {
return message.substring(0, index)
}
return message
}
fun processTasks() {
for ((index, task) in TASKS.withIndex()) {
task.index = index
for (model in task.models) {
model.preProcess()
}
}
}
fun processLlmResponse(response: String): String {
// Add "thinking" and "done thinking" around the thinking content.
var newContent =
response.replace("<think>", "$START_THINKING\n").replace("</think>", "\n$DONE_THINKING")
// Remove empty thinking content.
val endThinkingIndex = newContent.indexOf(DONE_THINKING)
if (endThinkingIndex >= 0) {
val thinkingContent =
newContent.substring(0, endThinkingIndex + DONE_THINKING.length).replace(START_THINKING, "")
.replace(DONE_THINKING, "")
if (thinkingContent.isBlank()) {
newContent = newContent.substring(endThinkingIndex + DONE_THINKING.length)
}
}
newContent = newContent.replace("\\n", "\n")
return newContent
}
@OptIn(ExperimentalSerializationApi::class)
inline fun <reified T> getJsonResponse(url: String): JsonObjAndTextContent<T>? {
try {
val connection = URL(url).openConnection() as HttpURLConnection
connection.requestMethod = "GET"
connection.connect()
val responseCode = connection.responseCode
if (responseCode == HttpURLConnection.HTTP_OK) {
val inputStream = connection.inputStream
val response = inputStream.bufferedReader().use { it.readText() }
// Parse JSON using kotlinx.serialization
val json = Json {
// Handle potential extra fields
ignoreUnknownKeys = true
allowComments = true
allowTrailingComma = true
}
val jsonObj = json.decodeFromString<T>(response)
return JsonObjAndTextContent(jsonObj = jsonObj, textContent = response)
} else {
Log.e("AGUtils", "HTTP error: $responseCode")
}
} catch (e: Exception) {
Log.e(
"AGUtils", "Error when getting json response: ${e.message}"
)
e.printStackTrace()
}
return null
}
fun writeLaunchInfo(context: Context) {
try {
val gson = Gson()
val launchInfo = LaunchInfo(ts = System.currentTimeMillis())
val jsonString = gson.toJson(launchInfo)
val file = File(context.getExternalFilesDir(null), LAUNCH_INFO_FILE_NAME)
file.writeText(jsonString)
} catch (e: Exception) {
Log.e(TAG, "Failed to write launch info", e)
}
}
fun readLaunchInfo(context: Context): LaunchInfo? {
try {
val gson = Gson()
val type = object : TypeToken<LaunchInfo>() {}.type
val file = File(context.getExternalFilesDir(null), LAUNCH_INFO_FILE_NAME)
val content = file.readText()
return gson.fromJson(content, type)
} catch (e: Exception) {
Log.e(TAG, "Failed to read launch info", e)
return null
}
}

View file

@ -16,52 +16,55 @@
package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.runtime.Composable
import androidx.compose.ui.tooling.preview.Preview
import com.google.ai.edge.gallery.data.Config
import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.NumberSliderConfig
import com.google.ai.edge.gallery.data.ValueType
import com.google.ai.edge.gallery.ui.common.convertValueToTargetType
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.data.convertValueToTargetType
import com.google.ai.edge.gallery.ui.common.ConfigDialog
private const val DEFAULT_BENCHMARK_WARM_UP_ITERATIONS = 50f
private const val DEFAULT_BENCHMARK_ITERATIONS = 200f
private val BENCHMARK_CONFIGS: List<Config> = listOf(
NumberSliderConfig(
key = ConfigKey.WARM_UP_ITERATIONS,
sliderMin = 10f,
sliderMax = 200f,
defaultValue = DEFAULT_BENCHMARK_WARM_UP_ITERATIONS,
valueType = ValueType.INT
),
NumberSliderConfig(
key = ConfigKey.BENCHMARK_ITERATIONS,
sliderMin = 50f,
sliderMax = 500f,
defaultValue = DEFAULT_BENCHMARK_ITERATIONS,
valueType = ValueType.INT
),
)
private val BENCHMARK_CONFIGS: List<Config> =
listOf(
NumberSliderConfig(
key = ConfigKey.WARM_UP_ITERATIONS,
sliderMin = 10f,
sliderMax = 200f,
defaultValue = DEFAULT_BENCHMARK_WARM_UP_ITERATIONS,
valueType = ValueType.INT,
),
NumberSliderConfig(
key = ConfigKey.BENCHMARK_ITERATIONS,
sliderMin = 50f,
sliderMax = 500f,
defaultValue = DEFAULT_BENCHMARK_ITERATIONS,
valueType = ValueType.INT,
),
)
private val BENCHMARK_CONFIGS_INITIAL_VALUES = mapOf(
ConfigKey.WARM_UP_ITERATIONS.label to DEFAULT_BENCHMARK_WARM_UP_ITERATIONS,
ConfigKey.BENCHMARK_ITERATIONS.label to DEFAULT_BENCHMARK_ITERATIONS
)
private val BENCHMARK_CONFIGS_INITIAL_VALUES =
mapOf(
ConfigKey.WARM_UP_ITERATIONS.label to DEFAULT_BENCHMARK_WARM_UP_ITERATIONS,
ConfigKey.BENCHMARK_ITERATIONS.label to DEFAULT_BENCHMARK_ITERATIONS,
)
/**
* Composable function to display a configuration dialog for benchmarking a chat message.
*
* This function renders a configuration dialog specifically tailored for setting up
* benchmark parameters. It allows users to specify warm-up and benchmark iterations
* before running a benchmark test on a given chat message.
* This function renders a configuration dialog specifically tailored for setting up benchmark
* parameters. It allows users to specify warm-up and benchmark iterations before running a
* benchmark test on a given chat message.
*/
@Composable
fun BenchmarkConfigDialog(
onDismissed: () -> Unit,
messageToBenchmark: ChatMessage?,
onBenchmarkClicked: (ChatMessage, warmUpIterations: Int, benchmarkIterations: Int) -> Unit
onBenchmarkClicked: (ChatMessage, warmUpIterations: Int, benchmarkIterations: Int) -> Unit,
) {
ConfigDialog(
title = "Benchmark configs",
@ -75,28 +78,32 @@ fun BenchmarkConfigDialog(
// Start benchmark.
messageToBenchmark?.let { message ->
val warmUpIterations = convertValueToTargetType(
value = curConfigValues.getValue(ConfigKey.WARM_UP_ITERATIONS.label),
valueType = ValueType.INT
) as Int
val benchmarkIterations = convertValueToTargetType(
value = curConfigValues.getValue(ConfigKey.BENCHMARK_ITERATIONS.label),
valueType = ValueType.INT
) as Int
val warmUpIterations =
convertValueToTargetType(
value = curConfigValues.getValue(ConfigKey.WARM_UP_ITERATIONS.label),
valueType = ValueType.INT,
)
as Int
val benchmarkIterations =
convertValueToTargetType(
value = curConfigValues.getValue(ConfigKey.BENCHMARK_ITERATIONS.label),
valueType = ValueType.INT,
)
as Int
onBenchmarkClicked(message, warmUpIterations, benchmarkIterations)
}
},
)
}
@Preview(showBackground = true)
@Composable
fun BenchmarkConfigDialogPreview() {
GalleryTheme {
BenchmarkConfigDialog(
onDismissed = {},
messageToBenchmark = null,
onBenchmarkClicked = { _, _, _ -> }
)
}
}
// @Preview(showBackground = true)
// @Composable
// fun BenchmarkConfigDialogPreview() {
// GalleryTheme {
// BenchmarkConfigDialog(
// onDismissed = {},
// messageToBenchmark = null,
// onBenchmarkClicked = { _, _, _ -> },
// )
// }
// }

View file

@ -17,10 +17,11 @@
package com.google.ai.edge.gallery.ui.common.chat
import android.graphics.Bitmap
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.ImageBitmap
import androidx.compose.ui.unit.Dp
import com.google.ai.edge.gallery.common.Classification
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.PromptTemplate
enum class ChatMessageType {
INFO,
@ -33,15 +34,15 @@ enum class ChatMessageType {
CONFIG_VALUES_CHANGE,
BENCHMARK_RESULT,
BENCHMARK_LLM_RESULT,
PROMPT_TEMPLATES
PROMPT_TEMPLATES,
}
enum class ChatSide {
USER, AGENT, SYSTEM
USER,
AGENT,
SYSTEM,
}
data class Classification(val label: String, val score: Float, val color: Color)
/** Base class for a chat message. */
open class ChatMessage(
open val type: ChatMessageType,
@ -70,7 +71,7 @@ class ChatMessageWarning(val content: String) :
class ChatMessageConfigValuesChange(
val model: Model,
val oldValues: Map<String, Any>,
val newValues: Map<String, Any>
val newValues: Map<String, Any>,
) : ChatMessage(type = ChatMessageType.CONFIG_VALUES_CHANGE, side = ChatSide.SYSTEM)
/** Chat message for plain text. */
@ -84,12 +85,13 @@ open class ChatMessageText(
// Benchmark result for LLM response.
var llmBenchmarkResult: ChatMessageBenchmarkLlmResult? = null,
override val accelerator: String = "",
) : ChatMessage(
type = ChatMessageType.TEXT,
side = side,
latencyMs = latencyMs,
accelerator = accelerator
) {
) :
ChatMessage(
type = ChatMessageType.TEXT,
side = side,
latencyMs = latencyMs,
accelerator = accelerator,
) {
override fun clone(): ChatMessageText {
return ChatMessageText(
content = content,
@ -107,15 +109,14 @@ class ChatMessageImage(
val bitmap: Bitmap,
val imageBitMap: ImageBitmap,
override val side: ChatSide,
override val latencyMs: Float = 0f
) :
ChatMessage(type = ChatMessageType.IMAGE, side = side, latencyMs = latencyMs) {
override val latencyMs: Float = 0f,
) : ChatMessage(type = ChatMessageType.IMAGE, side = side, latencyMs = latencyMs) {
override fun clone(): ChatMessageImage {
return ChatMessageImage(
bitmap = bitmap,
imageBitMap = imageBitMap,
side = side,
latencyMs = latencyMs
latencyMs = latencyMs,
)
}
}
@ -128,8 +129,7 @@ class ChatMessageImageWithHistory(
override val side: ChatSide,
override val latencyMs: Float = 0f,
var curIteration: Int = 0, // 0-based
) :
ChatMessage(type = ChatMessageType.IMAGE_WITH_HISTORY, side = side, latencyMs = latencyMs) {
) : ChatMessage(type = ChatMessageType.IMAGE_WITH_HISTORY, side = side, latencyMs = latencyMs) {
fun isRunning(): Boolean {
return curIteration < totalIterations - 1
}
@ -141,7 +141,8 @@ class ChatMessageClassification(
override val latencyMs: Float = 0f,
// Typical android phone width is > 320dp
val maxBarWidth: Dp? = null,
) : ChatMessage(type = ChatMessageType.CLASSIFICATION, side = ChatSide.AGENT, latencyMs = latencyMs)
) :
ChatMessage(type = ChatMessageType.CLASSIFICATION, side = ChatSide.AGENT, latencyMs = latencyMs)
/** A stat used in benchmark result. */
data class Stat(val id: String, val label: String, val unit: String)
@ -162,7 +163,7 @@ class ChatMessageBenchmarkResult(
ChatMessage(
type = ChatMessageType.BENCHMARK_RESULT,
side = ChatSide.AGENT,
latencyMs = latencyMs
latencyMs = latencyMs,
) {
fun isWarmingUp(): Boolean {
return warmupCurrent < warmupTotal
@ -180,23 +181,18 @@ class ChatMessageBenchmarkLlmResult(
val running: Boolean,
override val latencyMs: Float = 0f,
override val accelerator: String = "",
) : ChatMessage(
type = ChatMessageType.BENCHMARK_LLM_RESULT,
side = ChatSide.AGENT,
latencyMs = latencyMs,
accelerator = accelerator,
)
) :
ChatMessage(
type = ChatMessageType.BENCHMARK_LLM_RESULT,
side = ChatSide.AGENT,
latencyMs = latencyMs,
accelerator = accelerator,
)
data class Histogram(
val buckets: List<Int>,
val maxCount: Int,
val highlightBucketIndex: Int = -1
)
data class Histogram(val buckets: List<Int>, val maxCount: Int, val highlightBucketIndex: Int = -1)
/** Chat message for showing prompt templates. */
class ChatMessagePromptTemplates(
val templates: List<PromptTemplate>,
val showMakeYourOwn: Boolean = true,
) : ChatMessage(type = ChatMessageType.PROMPT_TEMPLATES, side = ChatSide.SYSTEM)
data class PromptTemplate(val title: String, val description: String, val prompt: String)

View file

@ -16,6 +16,11 @@
package com.google.ai.edge.gallery.ui.common.chat
// import com.google.ai.edge.gallery.ui.preview.PreviewChatModel
// import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
// import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import android.content.ClipData
import android.graphics.Bitmap
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
@ -69,15 +74,13 @@ import androidx.compose.ui.input.nestedscroll.NestedScrollConnection
import androidx.compose.ui.input.nestedscroll.NestedScrollSource
import androidx.compose.ui.input.nestedscroll.nestedScroll
import androidx.compose.ui.input.pointer.pointerInput
import androidx.compose.ui.platform.LocalClipboardManager
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.ClipEntry
import androidx.compose.ui.platform.LocalClipboard
import androidx.compose.ui.platform.LocalDensity
import androidx.compose.ui.platform.LocalFocusManager
import androidx.compose.ui.platform.LocalHapticFeedback
import androidx.compose.ui.res.dimensionResource
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.data.Model
@ -86,20 +89,15 @@ import com.google.ai.edge.gallery.data.TaskType
import com.google.ai.edge.gallery.ui.common.ErrorDialog
import com.google.ai.edge.gallery.ui.modelmanager.ModelInitializationStatusType
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.PreviewChatModel
import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.customColors
import kotlinx.coroutines.launch
enum class ChatInputType {
TEXT, IMAGE,
TEXT,
IMAGE,
}
/**
* Composable function for the main chat panel, displaying messages and handling user input.
*/
/** Composable function for the main chat panel, displaying messages and handling user input. */
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun ChatPanel(
@ -126,18 +124,19 @@ fun ChatPanel(
val snackbarHostState = remember { SnackbarHostState() }
val scope = rememberCoroutineScope()
val haptic = LocalHapticFeedback.current
val hasImageMessageToLastConfigChange = remember(messages) {
var foundImageMessage = false
for (message in messages.reversed()) {
if (message is ChatMessageConfigValuesChange) {
break
}
if (message is ChatMessageImage) {
foundImageMessage = true
val imageMessageCountToLastConfigChange =
remember(messages) {
var imageMessageCount = 0
for (message in messages.reversed()) {
if (message is ChatMessageConfigValuesChange) {
break
}
if (message is ChatMessageImage) {
imageMessageCount++
}
}
imageMessageCount
}
foundImageMessage
}
var curMessage by remember { mutableStateOf("") } // Correct state
val focusManager = LocalFocusManager.current
@ -163,8 +162,9 @@ fun ChatPanel(
lastMessageContent.value = tmpLastMessage.content
}
}
val lastShowingStatsByModel: MutableState<Map<String, MutableSet<ChatMessage>>> =
remember { mutableStateOf(mapOf()) }
val lastShowingStatsByModel: MutableState<Map<String, MutableSet<ChatMessage>>> = remember {
mutableStateOf(mapOf())
}
// Scroll the content to the bottom when any of these changes.
LaunchedEffect(
@ -217,15 +217,12 @@ fun ChatPanel(
showErrorDialog = modelInitializationStatus?.status == ModelInitializationStatusType.ERROR
}
Column(
modifier = modifier.imePadding()
) {
Column(modifier = modifier.imePadding()) {
Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.weight(1f)) {
LazyColumn(
modifier = Modifier
.fillMaxSize()
.nestedScroll(nestedScrollConnection),
state = listState, verticalArrangement = Arrangement.Top,
modifier = Modifier.fillMaxSize().nestedScroll(nestedScrollConnection),
state = listState,
verticalArrangement = Arrangement.Top,
) {
items(messages) { message ->
val imageHistoryCurIndex = remember { mutableIntStateOf(0) }
@ -254,14 +251,14 @@ fun ChatPanel(
val bubbleBorderRadius = dimensionResource(R.dimen.chat_bubble_corner_radius)
Column(
modifier = Modifier
.fillMaxWidth()
.padding(
start = 12.dp + extraPaddingStart,
end = 12.dp + extraPaddingEnd,
top = 6.dp,
bottom = 6.dp,
),
modifier =
Modifier.fillMaxWidth()
.padding(
start = 12.dp + extraPaddingStart,
end = 12.dp + extraPaddingEnd,
top = 6.dp,
bottom = 6.dp,
),
horizontalAlignment = hAlign,
) messageColumn@{
// Sender row.
@ -272,7 +269,7 @@ fun ChatPanel(
MessageSender(
message = message,
agentName = agentName,
imageHistoryCurIndex = imageHistoryCurIndex.intValue
imageHistoryCurIndex = imageHistoryCurIndex.intValue,
)
// Message body.
@ -290,40 +287,42 @@ fun ChatPanel(
is ChatMessageConfigValuesChange -> MessageBodyConfigUpdate(message = message)
// Prompt templates.
is ChatMessagePromptTemplates -> MessageBodyPromptTemplates(message = message,
task = task,
onPromptClicked = { template ->
onSendMessage(
selectedModel,
listOf(ChatMessageText(content = template.prompt, side = ChatSide.USER))
)
})
is ChatMessagePromptTemplates ->
MessageBodyPromptTemplates(
message = message,
task = task,
onPromptClicked = { template ->
onSendMessage(
selectedModel,
listOf(ChatMessageText(content = template.prompt, side = ChatSide.USER)),
)
},
)
// Non-system messages.
else -> {
// The bubble shape around the message body.
var messageBubbleModifier = Modifier
.clip(
MessageBubbleShape(
radius = bubbleBorderRadius,
hardCornerAtLeftOrRight = hardCornerAtLeftOrRight
var messageBubbleModifier =
Modifier.clip(
MessageBubbleShape(
radius = bubbleBorderRadius,
hardCornerAtLeftOrRight = hardCornerAtLeftOrRight,
)
)
)
.background(backgroundColor)
.background(backgroundColor)
if (message is ChatMessageText) {
messageBubbleModifier = messageBubbleModifier.pointerInput(Unit) {
detectTapGestures(
onLongPress = {
haptic.performHapticFeedback(HapticFeedbackType.LongPress)
longPressedMessage.value = message
showMessageLongPressedSheet = true
},
)
}
messageBubbleModifier =
messageBubbleModifier.pointerInput(Unit) {
detectTapGestures(
onLongPress = {
haptic.performHapticFeedback(HapticFeedbackType.LongPress)
longPressedMessage.value = message
showMessageLongPressedSheet = true
}
)
}
}
Box(
modifier = messageBubbleModifier,
) {
Box(modifier = messageBubbleModifier) {
when (message) {
// Text
is ChatMessageText -> MessageBodyText(message = message)
@ -331,32 +330,35 @@ fun ChatPanel(
// Image
is ChatMessageImage -> {
MessageBodyImage(
message = message, modifier = Modifier
.clickable {
onImageSelected(message.bitmap)
}
message = message,
modifier = Modifier.clickable { onImageSelected(message.bitmap) },
)
}
// Image with history (for image gen)
is ChatMessageImageWithHistory -> MessageBodyImageWithHistory(
message = message, imageHistoryCurIndex = imageHistoryCurIndex
)
is ChatMessageImageWithHistory ->
MessageBodyImageWithHistory(
message = message,
imageHistoryCurIndex = imageHistoryCurIndex,
)
// Classification result
is ChatMessageClassification -> MessageBodyClassification(
message = message, modifier = Modifier.width(
message.maxBarWidth ?: CLASSIFICATION_BAR_MAX_WIDTH
is ChatMessageClassification ->
MessageBodyClassification(
message = message,
modifier =
Modifier.width(message.maxBarWidth ?: CLASSIFICATION_BAR_MAX_WIDTH),
)
)
// Benchmark result.
is ChatMessageBenchmarkResult -> MessageBodyBenchmark(message = message)
// Benchmark LLM result.
is ChatMessageBenchmarkLlmResult -> MessageBodyBenchmarkLlm(
message = message, modifier = Modifier.wrapContentWidth()
)
is ChatMessageBenchmarkLlmResult ->
MessageBodyBenchmarkLlm(
message = message,
modifier = Modifier.wrapContentWidth(),
)
else -> {}
}
@ -369,10 +371,14 @@ fun ChatPanel(
) {
LatencyText(message = message)
// A button to show stats for the LLM message.
if (task.type.id.startsWith("llm_") && message is ChatMessageText
// This means we only want to show the action button when the message is done
// generating, at which point the latency will be set.
&& message.latencyMs >= 0
if (
task.type.id.startsWith("llm_") &&
message is ChatMessageText
// This means we only want to show the action button when the message is
// done
// generating, at which point the latency will be set.
&&
message.latencyMs >= 0
) {
val showingStats =
viewModel.isShowingStats(model = selectedModel, message = message)
@ -384,10 +390,7 @@ fun ChatPanel(
viewModel.toggleShowingStats(selectedModel, message)
// Add the stats message after the LLM message.
if (viewModel.isShowingStats(
model = selectedModel, message = message
)
) {
if (viewModel.isShowingStats(model = selectedModel, message = message)) {
val llmBenchmarkResult = message.llmBenchmarkResult
if (llmBenchmarkResult != null) {
viewModel.insertMessageAfter(
@ -399,32 +402,30 @@ fun ChatPanel(
}
// Remove the stats message.
else {
val curMessageIndex = viewModel.getMessageIndex(
model = selectedModel, message = message
)
val curMessageIndex =
viewModel.getMessageIndex(model = selectedModel, message = message)
viewModel.removeMessageAt(
model = selectedModel, index = curMessageIndex + 1
model = selectedModel,
index = curMessageIndex + 1,
)
}
},
enabled = !uiState.inProgress
enabled = !uiState.inProgress,
)
}
}
} else if (message.side == ChatSide.USER) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(4.dp)
horizontalArrangement = Arrangement.spacedBy(4.dp),
) {
// Run again button.
if (selectedModel.showRunAgainButton) {
MessageActionButton(
label = stringResource(R.string.run_again),
icon = Icons.Rounded.Refresh,
onClick = {
onRunAgainClicked(selectedModel, message)
},
enabled = !uiState.inProgress
onClick = { onRunAgainClicked(selectedModel, message) },
enabled = !uiState.inProgress,
)
}
@ -437,7 +438,7 @@ fun ChatPanel(
showBenchmarkConfigsDialog = true
benchmarkMessage.value = message
},
enabled = !uiState.inProgress
enabled = !uiState.inProgress,
)
}
}
@ -453,15 +454,16 @@ fun ChatPanel(
// Show an info message for ask image task to get users started.
if (task.type == TaskType.LLM_ASK_IMAGE && messages.isEmpty()) {
Column(
modifier = Modifier
.padding(horizontal = 16.dp)
.fillMaxSize(),
modifier = Modifier.padding(horizontal = 16.dp).fillMaxSize(),
horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.Center
verticalArrangement = Arrangement.Center,
) {
MessageBodyInfo(
ChatMessageInfo(content = "To get started, click + below to add an image and type a prompt to ask a question about it."),
smallFontSize = false
ChatMessageInfo(
content =
"To get started, click + below to add images (up to 10 in a single session) and type a prompt to ask a question about it."
),
smallFontSize = false,
)
}
}
@ -470,16 +472,18 @@ fun ChatPanel(
// Chat input
when (chatInputType) {
ChatInputType.TEXT -> {
// val isLlmTask = task.type == TaskType.LLM_CHAT
// val notLlmStartScreen = !(messages.size == 1 && messages[0] is ChatMessagePromptTemplates)
// val isLlmTask = task.type == TaskType.LLM_CHAT
// val notLlmStartScreen = !(messages.size == 1 && messages[0] is
// ChatMessagePromptTemplates)
MessageInputText(
modelManagerViewModel = modelManagerViewModel,
curMessage = curMessage,
inProgress = uiState.inProgress,
isResettingSession = uiState.isResettingSession,
modelPreparing = uiState.preparing,
hasImageMessage = hasImageMessageToLastConfigChange,
modelInitializing = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
imageMessageCount = imageMessageCountToLastConfigChange,
modelInitializing =
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
textFieldPlaceHolderRes = task.textInputPlaceHolderRes,
onValueChanged = { curMessage = it },
onSendMessage = {
@ -488,66 +492,78 @@ fun ChatPanel(
},
onOpenPromptTemplatesClicked = {
onSendMessage(
selectedModel, listOf(
selectedModel,
listOf(
ChatMessagePromptTemplates(
templates = selectedModel.llmPromptTemplates, showMakeYourOwn = false
templates = selectedModel.llmPromptTemplates,
showMakeYourOwn = false,
)
)
),
)
},
onStopButtonClicked = onStopButtonClicked,
// showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen,
// showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen,
showPromptTemplatesInMenu = false,
showImagePickerInMenu = selectedModel.llmSupportImage,
showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress,
)
}
ChatInputType.IMAGE -> MessageInputImage(
disableButtons = uiState.inProgress,
streamingMessage = streamingMessage,
onImageSelected = { bitmap ->
onSendMessage(
selectedModel, listOf(
ChatInputType.IMAGE ->
MessageInputImage(
disableButtons = uiState.inProgress,
streamingMessage = streamingMessage,
onImageSelected = { bitmap ->
onSendMessage(
selectedModel,
listOf(
ChatMessageImage(
bitmap = bitmap,
imageBitMap = bitmap.asImageBitmap(),
side = ChatSide.USER,
)
),
)
},
onStreamImage = { bitmap ->
onStreamImageMessage(
selectedModel,
ChatMessageImage(
bitmap = bitmap, imageBitMap = bitmap.asImageBitmap(), side = ChatSide.USER
)
bitmap = bitmap,
imageBitMap = bitmap.asImageBitmap(),
side = ChatSide.USER,
),
)
)
},
onStreamImage = { bitmap ->
onStreamImageMessage(
selectedModel, ChatMessageImage(
bitmap = bitmap, imageBitMap = bitmap.asImageBitmap(), side = ChatSide.USER
)
)
},
onStreamEnd = onStreamEnd,
)
},
onStreamEnd = onStreamEnd,
)
}
}
// Error dialog.
if (showErrorDialog) {
ErrorDialog(error = modelInitializationStatus?.error ?: "", onDismiss = {
showErrorDialog = false
})
ErrorDialog(
error = modelInitializationStatus?.error ?: "",
onDismiss = { showErrorDialog = false },
)
}
// Benchmark config dialog.
if (showBenchmarkConfigsDialog) {
BenchmarkConfigDialog(onDismissed = { showBenchmarkConfigsDialog = false },
BenchmarkConfigDialog(
onDismissed = { showBenchmarkConfigsDialog = false },
messageToBenchmark = benchmarkMessage.value,
onBenchmarkClicked = { message, warmUpIterations, benchmarkIterations ->
onBenchmarkClicked(selectedModel, message, warmUpIterations, benchmarkIterations)
})
},
)
}
// Sheet to show when a message is long-pressed.
if (showMessageLongPressedSheet) {
val message = longPressedMessage.value
if (message != null && message is ChatMessageText) {
val clipboardManager = LocalClipboardManager.current
val clipboard = LocalClipboard.current
ModalBottomSheet(
onDismissRequest = { showMessageLongPressedSheet = false },
@ -555,28 +571,32 @@ fun ChatPanel(
) {
Column {
// Copy text.
Box(modifier = Modifier
.fillMaxWidth()
.clickable {
// Copy text.
val clipData = AnnotatedString(message.content)
clipboardManager.setText(clipData)
Box(
modifier =
Modifier.fillMaxWidth().clickable {
// Copy text.
scope.launch {
val clipData = ClipData.newPlainText("message content", message.content)
val clipEntry = ClipEntry(clipData = clipData)
clipboard.setClipEntry(clipEntry = clipEntry)
}
// Hide sheet.
showMessageLongPressedSheet = false
// Hide sheet.
showMessageLongPressedSheet = false
// Show a snack bar.
scope.launch {
snackbarHostState.showSnackbar("Text copied to clipboard")
// Show a snack bar.
scope.launch { snackbarHostState.showSnackbar("Text copied to clipboard") }
}
}) {
) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp),
modifier = Modifier.padding(vertical = 8.dp, horizontal = 16.dp)
modifier = Modifier.padding(vertical = 8.dp, horizontal = 16.dp),
) {
Icon(
Icons.Rounded.ContentCopy, contentDescription = "", modifier = Modifier.size(18.dp)
Icons.Rounded.ContentCopy,
contentDescription = "",
modifier = Modifier.size(18.dp),
)
Text("Copy text")
}
@ -584,25 +604,24 @@ fun ChatPanel(
}
}
}
}
}
@Preview(showBackground = true)
@Composable
fun ChatPanelPreview() {
GalleryTheme {
val context = LocalContext.current
val task = TASK_TEST1
ChatPanel(
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
task = task,
selectedModel = TASK_TEST1.models[1],
viewModel = PreviewChatModel(context = context),
navigateUp = {},
onSendMessage = { _, _ -> },
onRunAgainClicked = { _, _ -> },
onBenchmarkClicked = { _, _, _, _ -> },
)
}
}
// @Preview(showBackground = true)
// @Composable
// fun ChatPanelPreview() {
// GalleryTheme {
// val context = LocalContext.current
// val task = TASK_TEST1
// ChatPanel(
// modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
// task = task,
// selectedModel = TASK_TEST1.models[1],
// viewModel = PreviewChatModel(context = context),
// navigateUp = {},
// onSendMessage = { _, _ -> },
// onRunAgainClicked = { _, _ -> },
// onBenchmarkClicked = { _, _, _, _ -> },
// )
// }
// }

View file

@ -16,6 +16,10 @@
package com.google.ai.edge.gallery.ui.common.chat
// import com.google.ai.edge.gallery.ui.preview.PreviewChatModel
// import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
// import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import android.graphics.Bitmap
import android.util.Log
import androidx.activity.compose.BackHandler
@ -55,7 +59,6 @@ import androidx.compose.ui.graphics.asImageBitmap
import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.ModelDownloadStatusType
@ -63,13 +66,9 @@ import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.common.ModelPageAppBar
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.modelmanager.PagerScrollState
import com.google.ai.edge.gallery.ui.preview.PreviewChatModel
import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import kotlin.math.absoluteValue
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlin.math.absoluteValue
private const val TAG = "AGChatView"
@ -77,8 +76,8 @@ private const val TAG = "AGChatView"
* A composable that displays a chat interface, allowing users to interact with different models
* associated with a given task.
*
* This composable provides a horizontal pager for switching between models, a model selector
* for configuring the selected model, and a chat panel for sending and receiving messages. It also
* This composable provides a horizontal pager for switching between models, a model selector for
* configuring the selected model, and a chat panel for sending and receiving messages. It also
* manages model initialization, cleanup, and download status, and handles navigation and system
* back gestures.
*/
@ -104,8 +103,11 @@ fun ChatView(
var selectedImage by remember { mutableStateOf<Bitmap?>(null) }
var showImageViewer by remember { mutableStateOf(false) }
val pagerState = rememberPagerState(initialPage = task.models.indexOf(selectedModel),
pageCount = { task.models.size })
val pagerState =
rememberPagerState(
initialPage = task.models.indexOf(selectedModel),
pageCount = { task.models.size },
)
val context = LocalContext.current
val scope = rememberCoroutineScope()
var navigatingUp by remember { mutableStateOf(false) }
@ -138,7 +140,7 @@ fun ChatView(
val curSelectedModel = task.models[pagerState.settledPage]
Log.d(
TAG,
"Pager settled on model '${curSelectedModel.name}' from '${selectedModel.name}'. Updating selected model."
"Pager settled on model '${curSelectedModel.name}' from '${selectedModel.name}'. Updating selected model.",
)
if (curSelectedModel.name != selectedModel.name) {
modelManagerViewModel.cleanupModel(task = task, model = selectedModel)
@ -148,52 +150,49 @@ fun ChatView(
LaunchedEffect(pagerState) {
// Collect from the a snapshotFlow reading the currentPage
snapshotFlow { pagerState.currentPage }.collect { page ->
Log.d(TAG, "Page changed to $page")
}
snapshotFlow { pagerState.currentPage }.collect { page -> Log.d(TAG, "Page changed to $page") }
}
// Trigger scroll sync.
LaunchedEffect(pagerState) {
snapshotFlow {
PagerScrollState(
page = pagerState.currentPage, offset = pagerState.currentPageOffsetFraction
)
}.collect { scrollState ->
modelManagerViewModel.pagerScrollState.value = scrollState
}
PagerScrollState(
page = pagerState.currentPage,
offset = pagerState.currentPageOffsetFraction,
)
}
.collect { scrollState -> modelManagerViewModel.pagerScrollState.value = scrollState }
}
// Handle system's edge swipe.
BackHandler {
handleNavigateUp()
}
BackHandler { handleNavigateUp() }
Scaffold(modifier = modifier, topBar = {
ModelPageAppBar(
task = task,
model = selectedModel,
modelManagerViewModel = modelManagerViewModel,
canShowResetSessionButton = true,
isResettingSession = uiState.isResettingSession,
inProgress = uiState.inProgress,
modelPreparing = uiState.preparing,
onResetSessionClicked = onResetSessionClicked,
onConfigChanged = { old, new ->
viewModel.addConfigChangedMessage(
oldConfigValues = old, newConfigValues = new, model = selectedModel
)
},
onBackClicked = {
handleNavigateUp()
},
onModelSelected = { model ->
scope.launch {
pagerState.animateScrollToPage(task.models.indexOf(model))
}
},
)
}) { innerPadding ->
Scaffold(
modifier = modifier,
topBar = {
ModelPageAppBar(
task = task,
model = selectedModel,
modelManagerViewModel = modelManagerViewModel,
canShowResetSessionButton = true,
isResettingSession = uiState.isResettingSession,
inProgress = uiState.inProgress,
modelPreparing = uiState.preparing,
onResetSessionClicked = onResetSessionClicked,
onConfigChanged = { old, new ->
viewModel.addConfigChangedMessage(
oldConfigValues = old,
newConfigValues = new,
model = selectedModel,
)
},
onBackClicked = { handleNavigateUp() },
onModelSelected = { model ->
scope.launch { pagerState.animateScrollToPage(task.models.indexOf(model)) }
},
)
},
) { innerPadding ->
Box {
// A horizontal scrollable pager to switch between models.
HorizontalPager(state = pagerState) { pageIndex ->
@ -202,17 +201,20 @@ fun ChatView(
// Calculate the alpha of the current page based on how far they are from the center.
val pageOffset =
((pagerState.currentPage - pageIndex) + pagerState.currentPageOffsetFraction).absoluteValue
((pagerState.currentPage - pageIndex) + pagerState.currentPageOffsetFraction)
.absoluteValue
val curAlpha = 1f - pageOffset.coerceIn(0f, 1f)
Column(
modifier = Modifier
.padding(innerPadding)
.fillMaxSize()
.background(MaterialTheme.colorScheme.surface)
modifier =
Modifier.padding(innerPadding)
.fillMaxSize()
.background(MaterialTheme.colorScheme.surface)
) {
ModelDownloadStatusInfoPanel(
model = curSelectedModel, task = task, modelManagerViewModel = modelManagerViewModel
model = curSelectedModel,
task = task,
modelManagerViewModel = modelManagerViewModel,
)
// The main messages panel.
@ -230,19 +232,16 @@ fun ChatView(
onStreamEnd = { averageFps ->
viewModel.addMessage(
model = curSelectedModel,
message = ChatMessageInfo(content = "Live camera session ended. Average FPS: $averageFps")
message =
ChatMessageInfo(content = "Live camera session ended. Average FPS: $averageFps"),
)
},
onStopButtonClicked = {
onStopButtonClicked(curSelectedModel)
},
onStopButtonClicked = { onStopButtonClicked(curSelectedModel) },
onImageSelected = { bitmap ->
selectedImage = bitmap
showImageViewer = true
},
modifier = Modifier
.weight(1f)
.graphicsLayer { alpha = curAlpha },
modifier = Modifier.weight(1f).graphicsLayer { alpha = curAlpha },
chatInputType = chatInputType,
showStopButtonInInputWhenInProgress = showStopButtonInInputWhenInProgress,
)
@ -254,39 +253,43 @@ fun ChatView(
AnimatedVisibility(
visible = showImageViewer,
enter = slideInVertically(initialOffsetY = { fullHeight -> fullHeight }) + fadeIn(),
exit = slideOutVertically(
targetOffsetY = { fullHeight -> fullHeight },
) + fadeOut()
exit = slideOutVertically(targetOffsetY = { fullHeight -> fullHeight }) + fadeOut(),
) {
selectedImage?.let { image ->
ZoomableBox(
modifier = Modifier
.fillMaxSize()
.padding(top = innerPadding.calculateTopPadding())
.background(Color.Black.copy(alpha = 0.95f)),
modifier =
Modifier.fillMaxSize()
.padding(top = innerPadding.calculateTopPadding())
.background(Color.Black.copy(alpha = 0.95f))
) {
Image(
bitmap = image.asImageBitmap(), contentDescription = "",
modifier = modifier
.fillMaxSize()
.graphicsLayer(
scaleX = scale, scaleY = scale, translationX = offsetX, translationY = offsetY
),
bitmap = image.asImageBitmap(),
contentDescription = "",
modifier =
modifier
.fillMaxSize()
.graphicsLayer(
scaleX = scale,
scaleY = scale,
translationX = offsetX,
translationY = offsetY,
),
contentScale = ContentScale.Fit,
)
// Close button.
IconButton(
onClick = {
showImageViewer = false
}, colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.surfaceVariant,
), modifier = Modifier.offset(x = (-8).dp, y = 8.dp)
onClick = { showImageViewer = false },
colors =
IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.surfaceVariant
),
modifier = Modifier.offset(x = (-8).dp, y = 8.dp),
) {
Icon(
Icons.Rounded.Close,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary
tint = MaterialTheme.colorScheme.primary,
)
}
}
@ -296,20 +299,20 @@ fun ChatView(
}
}
@Preview
@Composable
fun ChatScreenPreview() {
GalleryTheme {
val context = LocalContext.current
val task = TASK_TEST1
ChatView(
task = task,
viewModel = PreviewChatModel(context = context),
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
onSendMessage = { _, _ -> },
onRunAgainClicked = { _, _ -> },
onBenchmarkClicked = { _, _, _, _ -> },
navigateUp = {},
)
}
}
// @Preview
// @Composable
// fun ChatScreenPreview() {
// GalleryTheme {
// val context = LocalContext.current
// val task = TASK_TEST1
// ChatView(
// task = task,
// viewModel = PreviewChatModel(context = context),
// modelManagerViewModel = PreviewModelManagerViewModel(context = context),
// onSendMessage = { _, _ -> },
// onRunAgainClicked = { _, _ -> },
// onBenchmarkClicked = { _, _, _, _ -> },
// navigateUp = {},
// )
// }
// }

View file

@ -18,9 +18,9 @@ package com.google.ai.edge.gallery.ui.common.chat
import android.util.Log
import androidx.lifecycle.ViewModel
import com.google.ai.edge.gallery.common.processLlmResponse
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.common.processLlmResponse
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.update
@ -28,14 +28,10 @@ import kotlinx.coroutines.flow.update
private const val TAG = "AGChatViewModel"
data class ChatUiState(
/**
* Indicates whether the runtime is currently processing a message.
*/
/** Indicates whether the runtime is currently processing a message. */
val inProgress: Boolean = false,
/**
* Indicates whether the session is being reset.
*/
/** Indicates whether the session is being reset. */
val isResettingSession: Boolean = false,
/**
@ -43,14 +39,10 @@ data class ChatUiState(
*/
val preparing: Boolean = false,
/**
* A map of model names to lists of chat messages.
*/
/** A map of model names to lists of chat messages. */
val messagesByModel: Map<String, MutableList<ChatMessage>> = mapOf(),
/**
* A map of model names to the currently streaming chat message.
*/
/** A map of model names to the currently streaming chat message. */
val streamingMessagesByModel: Map<String, ChatMessage> = mapOf(),
/*
@ -60,9 +52,7 @@ data class ChatUiState(
val showingStatsByModel: Map<String, MutableSet<ChatMessage>> = mapOf(),
)
/**
* ViewModel responsible for managing the chat UI state and handling chat-related operations.
*/
/** ViewModel responsible for managing the chat UI state and handling chat-related operations. */
open class ChatViewModel(val task: Task) : ViewModel() {
private val _uiState = MutableStateFlow(createUiState(task = task))
val uiState = _uiState.asStateFlow()
@ -137,12 +127,13 @@ open class ChatViewModel(val task: Task) : ViewModel() {
val lastMessage = newMessages.last()
if (lastMessage is ChatMessageText) {
val newContent = processLlmResponse(response = "${lastMessage.content}${partialContent}")
val newLastMessage = ChatMessageText(
content = newContent,
side = lastMessage.side,
latencyMs = latencyMs,
accelerator = lastMessage.accelerator,
)
val newLastMessage =
ChatMessageText(
content = newContent,
side = lastMessage.side,
latencyMs = latencyMs,
accelerator = lastMessage.accelerator,
)
newMessages.removeAt(newMessages.size - 1)
newMessages.add(newLastMessage)
}
@ -154,7 +145,7 @@ open class ChatViewModel(val task: Task) : ViewModel() {
fun updateLastTextMessageLlmBenchmarkResult(
model: Model,
llmBenchmarkResult: ChatMessageBenchmarkLlmResult
llmBenchmarkResult: ChatMessageBenchmarkLlmResult,
) {
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
val newMessages = newMessagesByModel[model.name]?.toMutableList() ?: mutableListOf()
@ -215,12 +206,17 @@ open class ChatViewModel(val task: Task) : ViewModel() {
}
fun addConfigChangedMessage(
oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>, model: Model
oldConfigValues: Map<String, Any>,
newConfigValues: Map<String, Any>,
model: Model,
) {
Log.d(TAG, "Adding config changed message. Old: ${oldConfigValues}, new: $newConfigValues")
val message = ChatMessageConfigValuesChange(
model = model, oldValues = oldConfigValues, newValues = newConfigValues
)
val message =
ChatMessageConfigValuesChange(
model = model,
oldValues = oldConfigValues,
newValues = newConfigValues,
)
addMessage(message = message, model = model)
}
@ -253,8 +249,6 @@ open class ChatViewModel(val task: Task) : ViewModel() {
}
messagesByModel[model.name] = messages
}
return ChatUiState(
messagesByModel = messagesByModel
)
return ChatUiState(messagesByModel = messagesByModel)
}
}

View file

@ -37,9 +37,8 @@ import com.google.ai.edge.gallery.ui.theme.labelSmallNarrowMedium
/**
* Composable function to display a data card with a label and a numeric value.
*
* This function renders a column containing a label and a formatted numeric value.
* It provides options for highlighting the value and displaying a placeholder when the value is not
* available.
* This function renders a column containing a label and a formatted numeric value. It provides
* options for highlighting the value and displaying a placeholder when the value is not available.
*/
@Composable
fun DataCard(
@ -47,7 +46,7 @@ fun DataCard(
value: Float?,
unit: String,
highlight: Boolean = false,
showPlaceholder: Boolean = false
showPlaceholder: Boolean = false,
) {
var strValue = "-"
Column {
@ -57,19 +56,13 @@ fun DataCard(
} else {
strValue = if (value == null) "-" else "%.2f".format(value)
if (highlight) {
Text(
strValue, style = bodySmallMediumNarrowBold, color = MaterialTheme.colorScheme.primary
)
Text(strValue, style = bodySmallMediumNarrowBold, color = MaterialTheme.colorScheme.primary)
} else {
Text(strValue, style = bodySmallMediumNarrow)
}
}
if (strValue != "-") {
Text(
unit, style = labelSmallNarrow, modifier = Modifier
.alpha(0.5f)
.offset(y = (-1).dp)
)
Text(unit, style = labelSmallNarrow, modifier = Modifier.alpha(0.5f).offset(y = (-1).dp))
}
}
}
@ -80,14 +73,26 @@ fun DataCardPreview() {
GalleryTheme {
Row(modifier = Modifier.padding(16.dp), horizontalArrangement = Arrangement.spacedBy(16.dp)) {
DataCard(
label = "sum", value = 123.45f, unit = "ms", highlight = true, showPlaceholder = false
label = "sum",
value = 123.45f,
unit = "ms",
highlight = true,
showPlaceholder = false,
)
DataCard(
label = "average", value = 12.3f, unit = "ms", highlight = false, showPlaceholder = false
label = "average",
value = 12.3f,
unit = "ms",
highlight = false,
showPlaceholder = false,
)
DataCard(
label = "test", value = null, unit = "ms", highlight = false, showPlaceholder = false
label = "test",
value = null,
unit = "ms",
highlight = false,
showPlaceholder = false,
)
}
}
}
}

View file

@ -1,226 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.common.chat
import android.graphics.Bitmap
import android.graphics.Matrix
import android.util.Size
import androidx.camera.core.CameraSelector
import androidx.camera.core.ImageAnalysis
import androidx.camera.core.resolutionselector.ResolutionSelector
import androidx.camera.core.resolutionselector.ResolutionStrategy
import androidx.camera.lifecycle.ProcessCameraProvider
import androidx.compose.foundation.Image
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.aspectRatio
import androidx.compose.foundation.layout.fillMaxHeight
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material3.Card
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableLongStateOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.ImageBitmap
import androidx.compose.ui.graphics.asImageBitmap
import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog
import androidx.core.content.ContextCompat
import androidx.lifecycle.compose.LocalLifecycleOwner
import java.util.concurrent.Executors
import kotlin.coroutines.resume
import kotlin.coroutines.suspendCoroutine
/**
* Composable function to display a live camera feed in a dialog.
*
* This function renders a dialog that displays a live camera preview, along with optional
* classification results and FPS information. It manages camera initialization, frame capture,
* and dialog dismissal.
*/
@Composable
fun LiveCameraDialog(
onDismissed: (averageFps: Int) -> Unit,
onBitmap: (Bitmap) -> Unit,
streamingMessage: ChatMessage? = null,
) {
val context = LocalContext.current
val lifecycleOwner = LocalLifecycleOwner.current
var imageBitmap by remember { mutableStateOf<ImageBitmap?>(null) }
var cameraProvider: ProcessCameraProvider? by remember { mutableStateOf(null) }
var sumFps by remember { mutableLongStateOf(0L) }
var fpsCount by remember { mutableLongStateOf(0L) }
LaunchedEffect(key1 = true) {
cameraProvider = startCamera(
context,
lifecycleOwner,
onBitmap = onBitmap,
onImageBitmap = { b -> imageBitmap = b })
}
Dialog(onDismissRequest = {
cameraProvider?.unbindAll()
onDismissed((sumFps / fpsCount).toInt())
}) {
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
Column(
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
) {
// Title
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.SpaceBetween,
modifier = Modifier
.fillMaxWidth()
.padding(bottom = 8.dp)
) {
Text(
"Live camera",
style = MaterialTheme.typography.titleLarge,
)
if (streamingMessage != null) {
val fps = (1000f / streamingMessage.latencyMs).toInt()
sumFps += fps.toLong()
fpsCount += 1
Text(
"%d FPS".format(fps),
style = MaterialTheme.typography.titleLarge,
)
}
}
// Camera live view.
Row(
modifier = Modifier
.fillMaxWidth()
.aspectRatio(1f),
horizontalArrangement = Arrangement.Center
) {
val ib = imageBitmap
if (ib != null) {
Image(
bitmap = ib,
contentDescription = "",
modifier = Modifier
.fillMaxHeight()
.clip(RoundedCornerShape(8.dp)),
contentScale = ContentScale.Inside
)
}
}
// Result.
if (streamingMessage != null && streamingMessage is ChatMessageClassification) {
MessageBodyClassification(
message = streamingMessage,
modifier = Modifier.fillMaxWidth(),
oneLineLabel = true
)
}
// Button.
Row(
modifier = Modifier
.fillMaxWidth()
.padding(top = 8.dp),
horizontalArrangement = Arrangement.End,
) {
TextButton(
onClick = {
cameraProvider?.unbindAll()
onDismissed((sumFps / fpsCount).toInt())
},
) {
Text("OK")
}
}
}
}
}
}
/**
* Asynchronously initializes and starts the camera for image capture and analysis.
*
* This function sets up the camera using CameraX, configures image analysis, and binds
* the camera lifecycle to the provided LifecycleOwner. It captures frames from the camera,
* converts them to Bitmaps and ImageBitmaps, and invokes the provided callbacks.
*/
private suspend fun startCamera(
context: android.content.Context,
lifecycleOwner: androidx.lifecycle.LifecycleOwner,
onBitmap: (Bitmap) -> Unit,
onImageBitmap: (ImageBitmap) -> Unit
): ProcessCameraProvider? = suspendCoroutine { continuation ->
val cameraProviderFuture = ProcessCameraProvider.getInstance(context)
cameraProviderFuture.addListener({
val cameraProvider = cameraProviderFuture.get()
val resolutionSelector = ResolutionSelector.Builder().setResolutionStrategy(
ResolutionStrategy(
Size(1080, 1080),
ResolutionStrategy.FALLBACK_RULE_CLOSEST_LOWER_THEN_HIGHER
)
).build()
val imageAnalysis =
ImageAnalysis.Builder().setResolutionSelector(resolutionSelector)
.setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST).build().also {
it.setAnalyzer(Executors.newSingleThreadExecutor()) { imageProxy ->
var bitmap = imageProxy.toBitmap()
val rotation = imageProxy.imageInfo.rotationDegrees
bitmap = if (rotation != 0) {
val matrix = Matrix().apply {
postRotate(rotation.toFloat())
}
Bitmap.createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, matrix, true)
} else bitmap
onBitmap(bitmap)
onImageBitmap(bitmap.asImageBitmap())
imageProxy.close()
}
}
val cameraSelector = CameraSelector.DEFAULT_BACK_CAMERA
try {
cameraProvider?.unbindAll()
cameraProvider?.bindToLifecycle(
lifecycleOwner, cameraSelector, imageAnalysis
)
// Resume with the provider
continuation.resume(cameraProvider)
} catch (exc: Exception) {
// todo: Handle exceptions (e.g., camera initialization failure)
}
}, ContextCompat.getMainExecutor(context))
}

View file

@ -16,16 +16,15 @@
package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.PlayArrow
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
@ -35,62 +34,56 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.bodySmallNarrow
/**
* Composable function to display an action button below a chat message.
*/
/** Composable function to display an action button below a chat message. */
@Composable
fun MessageActionButton(
label: String,
icon: ImageVector,
onClick: () -> Unit,
modifier: Modifier = Modifier,
enabled: Boolean = true
enabled: Boolean = true,
) {
val curModifier = modifier
.padding(top = 4.dp)
.clip(CircleShape)
.background(if (enabled) MaterialTheme.colorScheme.secondaryContainer else MaterialTheme.colorScheme.surfaceContainerHigh)
val curModifier =
modifier
.padding(top = 4.dp)
.clip(CircleShape)
.background(
if (enabled) MaterialTheme.colorScheme.secondaryContainer
else MaterialTheme.colorScheme.surfaceContainerHigh
)
val alpha: Float = if (enabled) 1.0f else 0.3f
Row(
modifier = if (enabled) curModifier.clickable { onClick() } else modifier,
verticalAlignment = Alignment.CenterVertically,
) {
Icon(
icon, contentDescription = "", modifier = Modifier
.size(16.dp)
.offset(x = 6.dp)
.alpha(alpha)
icon,
contentDescription = "",
modifier = Modifier.size(16.dp).offset(x = 6.dp).alpha(alpha),
)
Text(
label,
color = MaterialTheme.colorScheme.onSecondaryContainer,
style = bodySmallNarrow,
modifier = Modifier
.padding(
start = 10.dp, end = 8.dp, top = 4.dp, bottom = 4.dp
)
.alpha(alpha)
modifier = Modifier.padding(start = 10.dp, end = 8.dp, top = 4.dp, bottom = 4.dp).alpha(alpha),
)
}
}
@Preview(showBackground = true)
@Composable
fun MessageActionButtonPreview() {
GalleryTheme {
Column {
MessageActionButton(label = "run", icon = Icons.Default.PlayArrow, onClick = {})
MessageActionButton(
label = "run",
icon = Icons.Default.PlayArrow,
enabled = false,
onClick = {})
}
}
}
// @Preview(showBackground = true)
// @Composable
// fun MessageActionButtonPreview() {
// GalleryTheme {
// Column {
// MessageActionButton(label = "run", icon = Icons.Default.PlayArrow, onClick = {})
// MessageActionButton(
// label = "run",
// icon = Icons.Default.PlayArrow,
// enabled = false,
// onClick = {})
// }
// }
// }

View file

@ -16,6 +16,8 @@
package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
@ -31,9 +33,7 @@ import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import kotlin.math.max
private const val DEFAULT_HISTOGRAM_BAR_HEIGHT = 50f
@ -41,37 +41,31 @@ private const val DEFAULT_HISTOGRAM_BAR_HEIGHT = 50f
/**
* Composable function to display benchmark results within a chat message.
*
* This function renders benchmark statistics (e.g., average latency) in data cards and
* visualizes the latency distribution using a histogram.
* This function renders benchmark statistics (e.g., average latency) in data cards and visualizes
* the latency distribution using a histogram.
*/
@Composable
fun MessageBodyBenchmark(message: ChatMessageBenchmarkResult) {
Column(
modifier = Modifier
.padding(12.dp)
.fillMaxWidth(),
verticalArrangement = Arrangement.spacedBy(8.dp)
modifier = Modifier.padding(12.dp).fillMaxWidth(),
verticalArrangement = Arrangement.spacedBy(8.dp),
) {
// Data cards.
Row(
modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.SpaceBetween
) {
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.SpaceBetween) {
for (stat in message.orderedStats) {
DataCard(
label = stat.label,
unit = stat.unit,
value = message.statValues[stat.id],
highlight = stat.id == message.highlightStat,
showPlaceholder = message.isWarmingUp()
showPlaceholder = message.isWarmingUp(),
)
}
}
// Histogram
if (message.histogram.buckets.isNotEmpty()) {
Row(
horizontalArrangement = Arrangement.spacedBy(2.dp)
) {
Row(horizontalArrangement = Arrangement.spacedBy(2.dp)) {
for ((index, count) in message.histogram.buckets.withIndex()) {
var barBgColor = MaterialTheme.colorScheme.onSurfaceVariant
var alpha = 0.3f
@ -84,24 +78,24 @@ fun MessageBodyBenchmark(message: ChatMessageBenchmarkResult) {
}
// Bar container.
Column(
modifier = Modifier
.height(DEFAULT_HISTOGRAM_BAR_HEIGHT.dp)
.width(4.dp),
modifier = Modifier.height(DEFAULT_HISTOGRAM_BAR_HEIGHT.dp).width(4.dp),
verticalArrangement = Arrangement.Bottom,
) {
// Bar content.
Box(
modifier = Modifier
.height(
max(
1f,
count.toFloat() / message.histogram.maxCount.toFloat() * DEFAULT_HISTOGRAM_BAR_HEIGHT
).dp
)
.fillMaxWidth()
.clip(RoundedCornerShape(20.dp, 20.dp, 0.dp, 0.dp))
.alpha(alpha)
.background(barBgColor)
modifier =
Modifier.height(
max(
1f,
count.toFloat() / message.histogram.maxCount.toFloat() *
DEFAULT_HISTOGRAM_BAR_HEIGHT,
)
.dp
)
.fillMaxWidth()
.clip(RoundedCornerShape(20.dp, 20.dp, 0.dp, 0.dp))
.alpha(alpha)
.background(barBgColor)
)
}
}
@ -110,31 +104,31 @@ fun MessageBodyBenchmark(message: ChatMessageBenchmarkResult) {
}
}
@Preview(showBackground = true)
@Composable
fun MessageBodyBenchmarkPreview() {
GalleryTheme {
MessageBodyBenchmark(
message = ChatMessageBenchmarkResult(
orderedStats = listOf(
Stat(id = "stat1", label = "Stat1", unit = "ms"),
Stat(id = "stat2", label = "Stat2", unit = "ms"),
Stat(id = "stat3", label = "Stat3", unit = "ms"),
Stat(id = "stat4", label = "Stat4", unit = "ms")
),
statValues = mutableMapOf(
"stat1" to 0.3f,
"stat2" to 0.4f,
"stat3" to 0.5f,
),
values = listOf(),
histogram = Histogram(listOf(), 0),
warmupCurrent = 0,
warmupTotal = 0,
iterationCurrent = 0,
iterationTotal = 0,
highlightStat = "stat2"
)
)
}
}
// @Preview(showBackground = true)
// @Composable
// fun MessageBodyBenchmarkPreview() {
// GalleryTheme {
// MessageBodyBenchmark(
// message = ChatMessageBenchmarkResult(
// orderedStats = listOf(
// Stat(id = "stat1", label = "Stat1", unit = "ms"),
// Stat(id = "stat2", label = "Stat2", unit = "ms"),
// Stat(id = "stat3", label = "Stat3", unit = "ms"),
// Stat(id = "stat4", label = "Stat4", unit = "ms")
// ),
// statValues = mutableMapOf(
// "stat1" to 0.3f,
// "stat2" to 0.4f,
// "stat3" to 0.5f,
// ),
// values = listOf(),
// histogram = Histogram(listOf(), 0),
// warmupCurrent = 0,
// warmupTotal = 0,
// iterationCurrent = 0,
// iterationTotal = 0,
// highlightStat = "stat2"
// )
// )
// }
// }

View file

@ -16,6 +16,8 @@
package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
@ -23,9 +25,7 @@ import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
/**
* Composable function to display benchmark LLM results within a chat message.
@ -34,41 +34,32 @@ import com.google.ai.edge.gallery.ui.theme.GalleryTheme
*/
@Composable
fun MessageBodyBenchmarkLlm(message: ChatMessageBenchmarkLlmResult, modifier: Modifier = Modifier) {
Column(
modifier = modifier.padding(12.dp),
verticalArrangement = Arrangement.spacedBy(8.dp)
) {
Column(modifier = modifier.padding(12.dp), verticalArrangement = Arrangement.spacedBy(8.dp)) {
// Data cards.
Row(
modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.SpaceBetween
) {
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.SpaceBetween) {
for (stat in message.orderedStats) {
DataCard(
label = stat.label,
unit = stat.unit,
value = message.statValues[stat.id],
)
DataCard(label = stat.label, unit = stat.unit, value = message.statValues[stat.id])
}
}
}
}
@Preview(showBackground = true)
@Composable
fun MessageBodyBenchmarkLlmPreview() {
GalleryTheme {
MessageBodyBenchmarkLlm(
message = ChatMessageBenchmarkLlmResult(
orderedStats = listOf(
Stat(id = "stat1", label = "Stat1", unit = "tokens/s"),
Stat(id = "stat2", label = "Stat2", unit = "tokens/s")
),
statValues = mutableMapOf(
"stat1" to 0.3f,
"stat2" to 0.4f,
),
running = false,
)
)
}
}
// @Preview(showBackground = true)
// @Composable
// fun MessageBodyBenchmarkLlmPreview() {
// GalleryTheme {
// MessageBodyBenchmarkLlm(
// message = ChatMessageBenchmarkLlmResult(
// orderedStats = listOf(
// Stat(id = "stat1", label = "Stat1", unit = "tokens/s"),
// Stat(id = "stat2", label = "Stat2", unit = "tokens/s")
// ),
// statValues = mutableMapOf(
// "stat1" to 0.3f,
// "stat2" to 0.4f,
// ),
// running = false,
// )
// )
// }
// }

View file

@ -16,6 +16,8 @@
package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
@ -32,11 +34,8 @@ import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
val CLASSIFICATION_BAR_HEIGHT = 8.dp
val CLASSIFICATION_BAR_MAX_WIDTH = 200.dp
@ -44,7 +43,8 @@ val CLASSIFICATION_BAR_MAX_WIDTH = 200.dp
/**
* Composable function to display classification results.
*
* This function renders a list of classifications, each with its label, score, and a visual score bar.
* This function renders a list of classifications, each with its label, score, and a visual score
* bar.
*/
@Composable
fun MessageBodyClassification(
@ -52,45 +52,40 @@ fun MessageBodyClassification(
modifier: Modifier = Modifier,
oneLineLabel: Boolean = false,
) {
Column(
modifier = modifier.padding(12.dp)
) {
Column(modifier = modifier.padding(12.dp)) {
for (classification in message.classifications) {
Row(
modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.SpaceBetween
) {
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.SpaceBetween) {
// Classification label.
Text(
classification.label,
maxLines = if (oneLineLabel) 1 else Int.MAX_VALUE,
overflow = TextOverflow.Ellipsis,
style = MaterialTheme.typography.bodySmall,
modifier = Modifier.weight(1f)
modifier = Modifier.weight(1f),
)
// Classification score.
Text(
"%.2f".format(classification.score),
style = MaterialTheme.typography.bodySmall,
modifier = Modifier
.align(Alignment.Bottom),
modifier = Modifier.align(Alignment.Bottom),
)
}
Spacer(modifier = Modifier.height(2.dp))
// Score bar.
Box {
Box(
modifier = Modifier
.fillMaxWidth()
.height(CLASSIFICATION_BAR_HEIGHT)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.surfaceDim)
modifier =
Modifier.fillMaxWidth()
.height(CLASSIFICATION_BAR_HEIGHT)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.surfaceDim)
)
Box(
modifier = Modifier
.fillMaxWidth(classification.score)
.height(CLASSIFICATION_BAR_HEIGHT)
.clip(CircleShape)
.background(classification.color)
modifier =
Modifier.fillMaxWidth(classification.score)
.height(CLASSIFICATION_BAR_HEIGHT)
.clip(CircleShape)
.background(classification.color)
)
}
Spacer(modifier = Modifier.height(6.dp))
@ -98,18 +93,20 @@ fun MessageBodyClassification(
}
}
@Preview(showBackground = true)
@Composable
fun MessageBodyClassificationPreview() {
GalleryTheme {
MessageBodyClassification(
message = ChatMessageClassification(
classifications = listOf(
Classification(label = "label1", score = 0.3f, color = Color.Red),
Classification(label = "label2", score = 0.7f, color = Color.Blue)
),
latencyMs = 12345f,
),
)
}
}
// @Preview(showBackground = true)
// @Composable
// fun MessageBodyClassificationPreview() {
// GalleryTheme {
// MessageBodyClassification(
// message =
// ChatMessageClassification(
// classifications =
// listOf(
// Classification(label = "label1", score = 0.3f, color = Color.Red),
// Classification(label = "label2", score = 0.7f, color = Color.Blue),
// ),
// latencyMs = 12345f,
// )
// )
// }
// }

View file

@ -16,6 +16,9 @@
package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
// import com.google.ai.edge.gallery.ui.preview.MODEL_TEST1
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
@ -34,33 +37,26 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.ui.common.convertValueToTargetType
import com.google.ai.edge.gallery.ui.common.getConfigValueString
import com.google.ai.edge.gallery.ui.preview.MODEL_TEST1
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.data.convertValueToTargetType
import com.google.ai.edge.gallery.data.getConfigValueString
import com.google.ai.edge.gallery.ui.theme.bodySmallNarrow
import com.google.ai.edge.gallery.ui.theme.titleSmaller
/**
* Composable function to display a message indicating configuration value changes.
*
* This function renders a centered row containing a box that displays the old and new
* values of configuration settings that have been updated.
* This function renders a centered row containing a box that displays the old and new values of
* configuration settings that have been updated.
*/
@Composable
fun MessageBodyConfigUpdate(message: ChatMessageConfigValuesChange) {
Row(
modifier = Modifier.fillMaxWidth(),
horizontalArrangement = Arrangement.Center,
) {
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center) {
Box(
modifier = Modifier
.clip(RoundedCornerShape(4.dp))
.background(MaterialTheme.colorScheme.tertiaryContainer)
modifier =
Modifier.clip(RoundedCornerShape(4.dp))
.background(MaterialTheme.colorScheme.tertiaryContainer)
) {
Column(modifier = Modifier.padding(8.dp)) {
// Title.
@ -74,11 +70,7 @@ fun MessageBodyConfigUpdate(message: ChatMessageConfigValuesChange) {
// Keys
Column {
for (config in message.model.configs) {
Text(
"${config.key.label}:",
style = bodySmallNarrow,
modifier = Modifier.alpha(0.6f),
)
Text("${config.key.label}:", style = bodySmallNarrow, modifier = Modifier.alpha(0.6f))
}
}
@ -88,23 +80,25 @@ fun MessageBodyConfigUpdate(message: ChatMessageConfigValuesChange) {
Column {
for (config in message.model.configs) {
val key = config.key.label
val oldValue: Any = convertValueToTargetType(
value = message.oldValues.getValue(key), valueType = config.valueType
)
val newValue: Any = convertValueToTargetType(
value = message.newValues.getValue(key), valueType = config.valueType
)
val oldValue: Any =
convertValueToTargetType(
value = message.oldValues.getValue(key),
valueType = config.valueType,
)
val newValue: Any =
convertValueToTargetType(
value = message.newValues.getValue(key),
valueType = config.valueType,
)
if (oldValue == newValue) {
Text("$newValue", style = bodySmallNarrow)
} else {
Row(verticalAlignment = Alignment.CenterVertically) {
Text(
getConfigValueString(oldValue, config), style = bodySmallNarrow
)
Text(getConfigValueString(oldValue, config), style = bodySmallNarrow)
Text(
"",
style = bodySmallNarrow.copy(fontSize = 12.sp),
modifier = Modifier.padding(start = 4.dp, end = 4.dp)
modifier = Modifier.padding(start = 4.dp, end = 4.dp),
)
Text(
getConfigValueString(newValue, config),
@ -121,24 +115,24 @@ fun MessageBodyConfigUpdate(message: ChatMessageConfigValuesChange) {
}
}
@Preview(showBackground = true)
@Composable
fun MessageBodyConfigUpdatePreview() {
GalleryTheme {
Row(modifier = Modifier.padding(16.dp)) {
MessageBodyConfigUpdate(
message = ChatMessageConfigValuesChange(
model = MODEL_TEST1,
oldValues = mapOf(
ConfigKey.MAX_RESULT_COUNT.label to 100,
ConfigKey.USE_GPU.label to false
),
newValues = mapOf(
ConfigKey.MAX_RESULT_COUNT.label to 200,
ConfigKey.USE_GPU.label to true
)
)
)
}
}
}
// @Preview(showBackground = true)
// @Composable
// fun MessageBodyConfigUpdatePreview() {
// GalleryTheme {
// Row(modifier = Modifier.padding(16.dp)) {
// MessageBodyConfigUpdate(
// message = ChatMessageConfigValuesChange(
// model = MODEL_TEST1,
// oldValues = mapOf(
// ConfigKey.MAX_RESULT_COUNT.label to 100,
// ConfigKey.USE_GPU.label to false
// ),
// newValues = mapOf(
// ConfigKey.MAX_RESULT_COUNT.label to 200,
// ConfigKey.USE_GPU.label to true
// )
// )
// )
// }
// }
// }

View file

@ -36,9 +36,7 @@ fun MessageBodyImage(message: ChatMessageImage, modifier: Modifier = Modifier) {
Image(
bitmap = message.imageBitMap,
contentDescription = "",
modifier = modifier
.height(imageHeight.dp)
.width(imageWidth.dp),
modifier = modifier.height(imageHeight.dp).width(imageWidth.dp),
contentScale = ContentScale.Fit,
)
}
}

View file

@ -43,7 +43,7 @@ import androidx.compose.ui.unit.dp
@Composable
fun MessageBodyImageWithHistory(
message: ChatMessageImageWithHistory,
imageHistoryCurIndex: MutableIntState
imageHistoryCurIndex: MutableIntState,
) {
val prevMessage: MutableState<ChatMessageImageWithHistory?> = remember { mutableStateOf(null) }
@ -68,15 +68,15 @@ fun MessageBodyImageWithHistory(
Image(
bitmap = curImageBitmap,
contentDescription = "",
modifier = Modifier
.height(imageHeight.dp)
.width(imageWidth.dp)
.pointerInput(Unit) {
detectHorizontalDragGestures(onDragStart = {
value = 0f
savedIndex = imageHistoryCurIndex.intValue
}) { _, dragAmount ->
value += (dragAmount / 20f)// Adjust sensitivity here
modifier =
Modifier.height(imageHeight.dp).width(imageWidth.dp).pointerInput(Unit) {
detectHorizontalDragGestures(
onDragStart = {
value = 0f
savedIndex = imageHistoryCurIndex.intValue
}
) { _, dragAmount ->
value += (dragAmount / 20f) // Adjust sensitivity here
imageHistoryCurIndex.intValue = (savedIndex + value).toInt()
if (imageHistoryCurIndex.intValue < 0) {
imageHistoryCurIndex.intValue = 0
@ -88,4 +88,4 @@ fun MessageBodyImageWithHistory(
contentScale = ContentScale.Fit,
)
}
}
}

View file

@ -16,6 +16,8 @@
package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
@ -27,9 +29,8 @@ import androidx.compose.material3.MaterialTheme
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.common.MarkdownText
import com.google.ai.edge.gallery.ui.theme.customColors
/**
@ -39,29 +40,27 @@ import com.google.ai.edge.gallery.ui.theme.customColors
*/
@Composable
fun MessageBodyInfo(message: ChatMessageInfo, smallFontSize: Boolean = true) {
Row(
modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center
) {
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center) {
Box(
modifier = Modifier
.clip(RoundedCornerShape(16.dp))
.background(MaterialTheme.customColors.agentBubbleBgColor)
modifier =
Modifier.clip(RoundedCornerShape(16.dp))
.background(MaterialTheme.customColors.agentBubbleBgColor)
) {
MarkdownText(
text = message.content,
modifier = Modifier.padding(12.dp),
smallFontSize = smallFontSize
smallFontSize = smallFontSize,
)
}
}
}
@Preview(showBackground = true)
@Composable
fun MessageBodyInfoPreview() {
GalleryTheme {
Row(modifier = Modifier.padding(16.dp)) {
MessageBodyInfo(message = ChatMessageInfo(content = "This is a model"))
}
}
}
// @Preview(showBackground = true)
// @Composable
// fun MessageBodyInfoPreview() {
// GalleryTheme {
// Row(modifier = Modifier.padding(16.dp)) {
// MessageBodyInfo(message = ChatMessageInfo(content = "This is a model"))
// }
// }
// }

View file

@ -16,12 +16,12 @@
package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.animation.core.Animatable
import androidx.compose.animation.core.tween
import androidx.compose.foundation.Image
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
@ -33,29 +33,21 @@ import androidx.compose.ui.graphics.ColorFilter
import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.ui.common.getTaskIconColor
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
private val IMAGE_RESOURCES = listOf(
R.drawable.pantegon,
R.drawable.double_circle,
R.drawable.circle,
R.drawable.four_circle
)
private val IMAGE_RESOURCES =
listOf(R.drawable.pantegon, R.drawable.double_circle, R.drawable.circle, R.drawable.four_circle)
private const val ANIMATION_DURATION = 300
private const val ANIMATION_DURATION2 = 300
private const val PAUSE_DURATION = 200
private const val PAUSE_DURATION2 = 0
/**
* Composable function to display a loading indicator.
*/
/** Composable function to display a loading indicator. */
@Composable
fun MessageBodyLoading() {
val progress = remember { Animatable(0f) }
@ -67,18 +59,17 @@ fun MessageBodyLoading() {
var progressJob = launch {
progress.animateTo(
targetValue = 1f,
animationSpec = tween(
durationMillis = ANIMATION_DURATION,
easing = multiBounceEasing(bounces = 3, decay = 0.02f)
)
animationSpec =
tween(
durationMillis = ANIMATION_DURATION,
easing = multiBounceEasing(bounces = 3, decay = 0.02f),
),
)
}
var alphaJob = launch {
alphaAnim.animateTo(
targetValue = 1f,
animationSpec = tween(
durationMillis = ANIMATION_DURATION / 2,
)
animationSpec = tween(durationMillis = ANIMATION_DURATION / 2),
)
}
progressJob.join()
@ -88,18 +79,17 @@ fun MessageBodyLoading() {
progressJob = launch {
progress.animateTo(
targetValue = 0f,
animationSpec = tween(
durationMillis = ANIMATION_DURATION2,
easing = multiBounceEasing(bounces = 3, decay = 0.02f)
)
animationSpec =
tween(
durationMillis = ANIMATION_DURATION2,
easing = multiBounceEasing(bounces = 3, decay = 0.02f),
),
)
}
alphaJob = launch {
alphaAnim.animateTo(
targetValue = 0f,
animationSpec = tween(
durationMillis = ANIMATION_DURATION2 / 2,
)
animationSpec = tween(durationMillis = ANIMATION_DURATION2 / 2),
)
}
@ -118,25 +108,21 @@ fun MessageBodyLoading() {
contentDescription = "",
contentScale = ContentScale.Fit,
colorFilter = ColorFilter.tint(getTaskIconColor(index = index)),
modifier = Modifier
.graphicsLayer {
scaleX = progress.value * 0.2f + 0.8f
scaleY = progress.value * 0.2f + 0.8f
rotationZ = progress.value * 100
alpha = if (index != activeImageIndex.intValue) 0f else alphaAnim.value
}
.size(24.dp)
modifier =
Modifier.graphicsLayer {
scaleX = progress.value * 0.2f + 0.8f
scaleY = progress.value * 0.2f + 0.8f
rotationZ = progress.value * 100
alpha = if (index != activeImageIndex.intValue) 0f else alphaAnim.value
}
.size(24.dp),
)
}
}
}
@Preview(showBackground = true)
@Composable
fun MessageBodyLoadingPreview() {
GalleryTheme {
Row(modifier = Modifier.padding(16.dp)) {
MessageBodyLoading()
}
}
}
// @Preview(showBackground = true)
// @Composable
// fun MessageBodyLoadingPreview() {
// GalleryTheme { Row(modifier = Modifier.padding(16.dp)) { MessageBodyLoading() } }
// }

View file

@ -16,13 +16,17 @@
package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.preview.ALL_PREVIEW_TASKS
// import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.background
import androidx.compose.foundation.border
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
@ -41,13 +45,10 @@ import androidx.compose.ui.draw.shadow
import androidx.compose.ui.graphics.Brush
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.data.PromptTemplate
import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.common.getTaskIconColor
import com.google.ai.edge.gallery.ui.preview.ALL_PREVIEW_TASKS
import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
private const val CARD_HEIGHT = 100
@ -63,16 +64,15 @@ fun MessageBodyPromptTemplates(
Column(
modifier = Modifier.padding(top = 12.dp),
verticalArrangement = Arrangement.spacedBy(8.dp)
verticalArrangement = Arrangement.spacedBy(8.dp),
) {
Text(
"Try an example prompt",
style = MaterialTheme.typography.titleLarge.copy(
fontWeight = FontWeight.Bold,
brush = Brush.linearGradient(
colors = gradientColors,
)
),
style =
MaterialTheme.typography.titleLarge.copy(
fontWeight = FontWeight.Bold,
brush = Brush.linearGradient(colors = gradientColors),
),
modifier = Modifier.fillMaxWidth(),
textAlign = TextAlign.Center,
)
@ -80,41 +80,30 @@ fun MessageBodyPromptTemplates(
Text(
"Or make your own",
style = MaterialTheme.typography.titleSmall,
modifier = Modifier
.fillMaxWidth()
.offset(y = (-4).dp),
modifier = Modifier.fillMaxWidth().offset(y = (-4).dp),
textAlign = TextAlign.Center,
)
}
LazyColumn(
modifier = Modifier
.height((rowCount * (CARD_HEIGHT + 8)).dp),
modifier = Modifier.height((rowCount * (CARD_HEIGHT + 8)).dp),
verticalArrangement = Arrangement.spacedBy(8.dp),
) {
// Cards.
items(message.templates) { template ->
Box(
modifier = Modifier
.border(
width = 1.dp,
color = color.copy(alpha = 0.3f),
shape = RoundedCornerShape(24.dp)
)
.height(CARD_HEIGHT.dp)
.shadow(
elevation = 2.dp,
shape = RoundedCornerShape(24.dp),
spotColor = color
)
.background(MaterialTheme.colorScheme.surface)
.clickable {
onPromptClicked(template)
}
modifier =
Modifier.border(
width = 1.dp,
color = color.copy(alpha = 0.3f),
shape = RoundedCornerShape(24.dp),
)
.height(CARD_HEIGHT.dp)
.shadow(elevation = 2.dp, shape = RoundedCornerShape(24.dp), spotColor = color)
.background(MaterialTheme.colorScheme.surface)
.clickable { onPromptClicked(template) }
) {
Column(
modifier = Modifier
.padding(horizontal = 12.dp, vertical = 20.dp)
.fillMaxSize(),
modifier = Modifier.padding(horizontal = 12.dp, vertical = 20.dp).fillMaxSize(),
horizontalAlignment = Alignment.CenterHorizontally,
) {
Text(
@ -134,35 +123,37 @@ fun MessageBodyPromptTemplates(
}
}
@Preview(showBackground = true)
@Composable
fun MessageBodyPromptTemplatesPreview() {
for ((index, task) in ALL_PREVIEW_TASKS.withIndex()) {
task.index = index
for (model in task.models) {
model.preProcess()
}
}
// @Preview(showBackground = true)
// @Composable
// fun MessageBodyPromptTemplatesPreview() {
// for ((index, task) in ALL_PREVIEW_TASKS.withIndex()) {
// task.index = index
// for (model in task.models) {
// model.preProcess()
// }
// }
GalleryTheme {
Row(modifier = Modifier.padding(16.dp)) {
MessageBodyPromptTemplates(
message = ChatMessagePromptTemplates(
templates = listOf(
PromptTemplate(
title = "Math Worksheets",
description = "Create a set of math worksheets for parents",
prompt = ""
),
PromptTemplate(
title = "Shape Sequencer",
description = "Find the next shape in a sequence",
prompt = ""
)
)
),
task = TASK_TEST1,
)
}
}
}
// GalleryTheme {
// Row(modifier = Modifier.padding(16.dp)) {
// MessageBodyPromptTemplates(
// message =
// ChatMessagePromptTemplates(
// templates =
// listOf(
// PromptTemplate(
// title = "Math Worksheets",
// description = "Create a set of math worksheets for parents",
// prompt = "",
// ),
// PromptTemplate(
// title = "Shape Sequencer",
// description = "Find the next shape in a sequence",
// prompt = "",
// ),
// )
// ),
// task = TASK_TEST1,
// )
// }
// }
// }

View file

@ -16,9 +16,9 @@
package com.google.ai.edge.gallery.ui.common.chat
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
// import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.foundation.layout.padding
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
@ -26,13 +26,10 @@ import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.common.MarkdownText
/**
* Composable function to display the text content of a ChatMessageText.
*/
/** Composable function to display the text content of a ChatMessageText. */
@Composable
fun MessageBodyText(message: ChatMessageText) {
if (message.side == ChatSide.USER) {
@ -40,7 +37,7 @@ fun MessageBodyText(message: ChatMessageText) {
message.content,
style = MaterialTheme.typography.bodyLarge.copy(fontWeight = FontWeight.Medium),
color = Color.White,
modifier = Modifier.padding(12.dp)
modifier = Modifier.padding(12.dp),
)
} else if (message.side == ChatSide.AGENT) {
if (message.isMarkdown) {
@ -50,31 +47,25 @@ fun MessageBodyText(message: ChatMessageText) {
message.content,
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.Medium),
color = MaterialTheme.colorScheme.onSurface,
modifier = Modifier.padding(12.dp)
modifier = Modifier.padding(12.dp),
)
}
}
}
@Preview(showBackground = true)
@Composable
fun MessageBodyTextPreview() {
GalleryTheme {
Column {
Row(
modifier = Modifier
.padding(16.dp)
.background(MaterialTheme.colorScheme.primary),
) {
MessageBodyText(ChatMessageText(content = "Hello world", side = ChatSide.USER))
}
Row(
modifier = Modifier
.padding(16.dp)
.background(MaterialTheme.colorScheme.surfaceContainer),
) {
MessageBodyText(ChatMessageText(content = "yes hello world", side = ChatSide.AGENT))
}
}
}
}
// @Preview(showBackground = true)
// @Composable
// fun MessageBodyTextPreview() {
// GalleryTheme {
// Column {
// Row(modifier = Modifier.padding(16.dp).background(MaterialTheme.colorScheme.primary)) {
// MessageBodyText(ChatMessageText(content = "Hello world", side = ChatSide.USER))
// }
// Row(
// modifier = Modifier.padding(16.dp).background(MaterialTheme.colorScheme.surfaceContainer)
// ) {
// MessageBodyText(ChatMessageText(content = "yes hello world", side = ChatSide.AGENT))
// }
// }
// }
// }

View file

@ -16,6 +16,8 @@
package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
@ -27,9 +29,8 @@ import androidx.compose.material3.MaterialTheme
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.common.MarkdownText
/**
* Composable function to display warning message content within a chat.
@ -38,29 +39,27 @@ import com.google.ai.edge.gallery.ui.theme.GalleryTheme
*/
@Composable
fun MessageBodyWarning(message: ChatMessageWarning) {
Row(
modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center
) {
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center) {
Box(
modifier = Modifier
.clip(RoundedCornerShape(16.dp))
.background(MaterialTheme.colorScheme.tertiaryContainer)
modifier =
Modifier.clip(RoundedCornerShape(16.dp))
.background(MaterialTheme.colorScheme.tertiaryContainer)
) {
MarkdownText(
text = message.content,
modifier = Modifier.padding(horizontal = 16.dp, vertical = 6.dp),
smallFontSize = true
smallFontSize = true,
)
}
}
}
@Preview(showBackground = true)
@Composable
fun MessageBodyWarningPreview() {
GalleryTheme {
Row(modifier = Modifier.padding(16.dp)) {
MessageBodyWarning(message = ChatMessageWarning(content = "This is a warning"))
}
}
}
// @Preview(showBackground = true)
// @Composable
// fun MessageBodyWarningPreview() {
// GalleryTheme {
// Row(modifier = Modifier.padding(16.dp)) {
// MessageBodyWarning(message = ChatMessageWarning(content = "This is a warning"))
// }
// }
// }

View file

@ -29,41 +29,39 @@ import androidx.compose.ui.unit.LayoutDirection
/**
* Custom Shape for creating message bubble outlines with configurable corner radii.
*
* This class defines a custom Shape that generates a rounded rectangle outline,
* suitable for message bubbles. It allows specifying a uniform corner radius for
* most corners, but also provides the option to have a hard (non-rounded) corner
* on either the left or right side.
* This class defines a custom Shape that generates a rounded rectangle outline, suitable for
* message bubbles. It allows specifying a uniform corner radius for most corners, but also provides
* the option to have a hard (non-rounded) corner on either the left or right side.
*/
class MessageBubbleShape(
private val radius: Dp,
private val hardCornerAtLeftOrRight: Boolean = false
private val hardCornerAtLeftOrRight: Boolean = false,
) : Shape {
override fun createOutline(
size: Size,
layoutDirection: LayoutDirection,
density: Density
density: Density,
): Outline {
val radiusPx = with(density) { radius.toPx() }
val path = Path().apply {
addRoundRect(
RoundRect(
left = 0f,
top = 0f,
right = size.width,
bottom = size.height,
topLeftCornerRadius = if (hardCornerAtLeftOrRight) CornerRadius(0f, 0f) else CornerRadius(
radiusPx,
radiusPx
),
topRightCornerRadius = if (hardCornerAtLeftOrRight) CornerRadius(
radiusPx,
radiusPx
) else CornerRadius(0f, 0f), // No rounding here
bottomLeftCornerRadius = CornerRadius(radiusPx, radiusPx),
bottomRightCornerRadius = CornerRadius(radiusPx, radiusPx)
val path =
Path().apply {
addRoundRect(
RoundRect(
left = 0f,
top = 0f,
right = size.width,
bottom = size.height,
topLeftCornerRadius =
if (hardCornerAtLeftOrRight) CornerRadius(0f, 0f)
else CornerRadius(radiusPx, radiusPx),
topRightCornerRadius =
if (hardCornerAtLeftOrRight) CornerRadius(radiusPx, radiusPx)
else CornerRadius(0f, 0f), // No rounding here
bottomLeftCornerRadius = CornerRadius(radiusPx, radiusPx),
bottomRightCornerRadius = CornerRadius(radiusPx, radiusPx),
)
)
)
}
}
return Outline.Generic(path)
}
}

View file

@ -16,6 +16,9 @@
package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import android.Manifest
import android.content.Context
import android.content.pm.PackageManager
@ -28,7 +31,6 @@ import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.PickVisualMediaRequest
import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
@ -49,11 +51,9 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import androidx.core.content.ContextCompat
import com.google.ai.edge.gallery.ui.common.createTempPictureUri
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
private const val TAG = "AGMessageInputImage"
@ -102,30 +102,28 @@ fun MessageInputImage(
}
// Permission request when taking picture.
val takePicturePermissionLauncher = rememberLauncherForActivityResult(
ActivityResultContracts.RequestPermission()
) { permissionGranted ->
if (permissionGranted) {
tempPhotoUri = context.createTempPictureUri()
cameraLauncher.launch(tempPhotoUri)
val takePicturePermissionLauncher =
rememberLauncherForActivityResult(ActivityResultContracts.RequestPermission()) {
permissionGranted ->
if (permissionGranted) {
tempPhotoUri = context.createTempPictureUri()
cameraLauncher.launch(tempPhotoUri)
}
}
}
// Permission request when using live camera.
val liveCameraPermissionLauncher = rememberLauncherForActivityResult(
ActivityResultContracts.RequestPermission()
) { permissionGranted ->
if (permissionGranted) {
showLiveCameraDialog = true
val liveCameraPermissionLauncher =
rememberLauncherForActivityResult(ActivityResultContracts.RequestPermission()) {
permissionGranted ->
if (permissionGranted) {
showLiveCameraDialog = true
}
}
}
val buttonAlpha = if (disableButtons) 0.3f else 1f
Row(
modifier = Modifier
.fillMaxWidth()
.padding(12.dp),
modifier = Modifier.fillMaxWidth().padding(12.dp),
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.End,
) {
@ -139,9 +137,8 @@ fun MessageInputImage(
// Launch the photo picker and let the user choose only images.
pickMedia.launch(PickVisualMediaRequest(ActivityResultContracts.PickVisualMedia.ImageOnly))
},
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.primary,
),
colors =
IconButtonDefaults.iconButtonColors(containerColor = MaterialTheme.colorScheme.primary),
modifier = Modifier.alpha(buttonAlpha),
) {
Icon(Icons.Rounded.Photo, contentDescription = "", tint = MaterialTheme.colorScheme.onPrimary)
@ -157,9 +154,7 @@ fun MessageInputImage(
// Check permission
when (PackageManager.PERMISSION_GRANTED) {
// Already got permission. Call the lambda.
ContextCompat.checkSelfPermission(
context, Manifest.permission.CAMERA
) -> {
ContextCompat.checkSelfPermission(context, Manifest.permission.CAMERA) -> {
tempPhotoUri = context.createTempPictureUri()
cameraLauncher.launch(tempPhotoUri)
}
@ -170,15 +165,14 @@ fun MessageInputImage(
}
}
},
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.primary,
),
colors =
IconButtonDefaults.iconButtonColors(containerColor = MaterialTheme.colorScheme.primary),
modifier = Modifier.alpha(buttonAlpha),
) {
Icon(
Icons.Rounded.PhotoCamera,
contentDescription = "",
tint = MaterialTheme.colorScheme.onPrimary
tint = MaterialTheme.colorScheme.onPrimary,
)
}
@ -192,9 +186,7 @@ fun MessageInputImage(
// Check permission
when (PackageManager.PERMISSION_GRANTED) {
// Already got permission. Call the lambda.
ContextCompat.checkSelfPermission(
context, Manifest.permission.CAMERA
) -> {
ContextCompat.checkSelfPermission(context, Manifest.permission.CAMERA) -> {
showLiveCameraDialog = true
}
@ -204,25 +196,30 @@ fun MessageInputImage(
}
}
},
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.primary,
),
colors =
IconButtonDefaults.iconButtonColors(containerColor = MaterialTheme.colorScheme.primary),
modifier = Modifier.alpha(buttonAlpha),
) {
Icon(
Icons.Rounded.Videocam, contentDescription = "", tint = MaterialTheme.colorScheme.onPrimary
Icons.Rounded.Videocam,
contentDescription = "",
tint = MaterialTheme.colorScheme.onPrimary,
)
}
}
// Live camera stream dialog.
if (showLiveCameraDialog) {
LiveCameraDialog(
streamingMessage = streamingMessage, onDismissed = { averageFps ->
onStreamEnd(averageFps)
showLiveCameraDialog = false
}, onBitmap = onStreamImage
)
// TODO(migration)
//
// LiveCameraDialog(
// streamingMessage = streamingMessage,
// onDismissed = { averageFps ->
// onStreamEnd(averageFps)
// showLiveCameraDialog = false
// },
// onBitmap = onStreamImage,
// )
}
}
@ -237,33 +234,33 @@ private fun handleImageSelected(
) {
Log.d(TAG, "Selected URI: $uri")
val bitmap: Bitmap? = try {
val inputStream = context.contentResolver.openInputStream(uri)
val tmpBitmap = BitmapFactory.decodeStream(inputStream)
if (rotateForPortrait && tmpBitmap.width > tmpBitmap.height) {
val matrix = Matrix()
matrix.postRotate(90f)
Bitmap.createBitmap(tmpBitmap, 0, 0, tmpBitmap.width, tmpBitmap.height, matrix, true)
} else {
tmpBitmap
val bitmap: Bitmap? =
try {
val inputStream = context.contentResolver.openInputStream(uri)
val tmpBitmap = BitmapFactory.decodeStream(inputStream)
if (rotateForPortrait && tmpBitmap.width > tmpBitmap.height) {
val matrix = Matrix()
matrix.postRotate(90f)
Bitmap.createBitmap(tmpBitmap, 0, 0, tmpBitmap.width, tmpBitmap.height, matrix, true)
} else {
tmpBitmap
}
} catch (e: Exception) {
e.printStackTrace()
null
}
} catch (e: Exception) {
e.printStackTrace()
null
}
if (bitmap != null) {
onImageSelected(bitmap)
}
}
@Preview(showBackground = true)
@Composable
fun MessageInputImagePreview() {
GalleryTheme {
Column {
MessageInputImage(onImageSelected = {})
MessageInputImage(disableButtons = true, onImageSelected = {})
}
}
}
// @Preview(showBackground = true)
// @Composable
// fun MessageInputImagePreview() {
// GalleryTheme {
// Column {
// MessageInputImage(onImageSelected = {})
// MessageInputImage(disableButtons = true, onImageSelected = {})
// }
// }
// }

View file

@ -16,6 +16,9 @@
package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import android.Manifest
import android.content.Context
import android.content.pm.PackageManager
@ -43,6 +46,7 @@ import androidx.compose.foundation.Image
import androidx.compose.foundation.background
import androidx.compose.foundation.border
import androidx.compose.foundation.clickable
import androidx.compose.foundation.horizontalScroll
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
@ -55,6 +59,7 @@ import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.width
import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
@ -98,26 +103,22 @@ import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.asImageBitmap
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import androidx.compose.ui.viewinterop.AndroidView
import androidx.core.content.ContextCompat
import androidx.lifecycle.compose.LocalLifecycleOwner
import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.ui.common.createTempPictureUri
import com.google.ai.edge.gallery.data.MAX_IMAGE_COUNT
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import kotlinx.coroutines.launch
import java.util.concurrent.Executors
import kotlinx.coroutines.launch
private const val TAG = "AGMessageInputText"
/**
* Composable function to display a text input field for composing chat messages.
*
* This function renders a row containing a text field for message input and a send button.
* It handles message composition, input validation, and sending messages.
* This function renders a row containing a text field for message input and a send button. It
* handles message composition, input validation, and sending messages.
*/
@OptIn(ExperimentalMaterial3Api::class)
@Composable
@ -126,7 +127,7 @@ fun MessageInputText(
curMessage: String,
isResettingSession: Boolean,
inProgress: Boolean,
hasImageMessage: Boolean,
imageMessageCount: Int,
modelInitializing: Boolean,
@StringRes textFieldPlaceHolderRes: Int,
onValueChanged: (String) -> Unit,
@ -145,239 +146,254 @@ fun MessageInputText(
var showTextInputHistorySheet by remember { mutableStateOf(false) }
var showCameraCaptureBottomSheet by remember { mutableStateOf(false) }
val cameraCaptureSheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
var tempPhotoUri by remember { mutableStateOf(value = Uri.EMPTY) }
var pickedImages by remember { mutableStateOf<List<Bitmap>>(listOf()) }
val updatePickedImages: (Bitmap) -> Unit = { bitmap ->
val newPickedImages: MutableList<Bitmap> = mutableListOf()
val updatePickedImages: (List<Bitmap>) -> Unit = { bitmaps ->
var newPickedImages: MutableList<Bitmap> = mutableListOf()
newPickedImages.addAll(pickedImages)
newPickedImages.add(bitmap)
newPickedImages.addAll(bitmaps)
if (newPickedImages.size > MAX_IMAGE_COUNT) {
newPickedImages = newPickedImages.subList(fromIndex = 0, toIndex = MAX_IMAGE_COUNT)
}
pickedImages = newPickedImages.toList()
}
var hasFrontCamera by remember { mutableStateOf(false) }
LaunchedEffect(Unit) {
checkFrontCamera(context = context, callback = { hasFrontCamera = it })
}
LaunchedEffect(Unit) { checkFrontCamera(context = context, callback = { hasFrontCamera = it }) }
// Permission request when taking picture.
val takePicturePermissionLauncher = rememberLauncherForActivityResult(
ActivityResultContracts.RequestPermission()
) { permissionGranted ->
if (permissionGranted) {
showAddContentMenu = false
tempPhotoUri = context.createTempPictureUri()
showCameraCaptureBottomSheet = true
val takePicturePermissionLauncher =
rememberLauncherForActivityResult(ActivityResultContracts.RequestPermission()) {
permissionGranted ->
if (permissionGranted) {
showAddContentMenu = false
showCameraCaptureBottomSheet = true
}
}
}
// Registers a photo picker activity launcher in single-select mode.
val pickMedia =
rememberLauncherForActivityResult(ActivityResultContracts.PickVisualMedia()) { uri ->
// Callback is invoked after the user selects a media item or closes the
rememberLauncherForActivityResult(ActivityResultContracts.PickMultipleVisualMedia()) { uris ->
// Callback is invoked after the user selects media items or closes the
// photo picker.
if (uri != null) {
handleImageSelected(context = context, uri = uri, onImageSelected = { bitmap ->
updatePickedImages(bitmap)
})
if (uris.isNotEmpty()) {
handleImagesSelected(
context = context,
uris = uris,
onImagesSelected = { bitmaps -> updatePickedImages(bitmaps) },
)
}
}
Box(contentAlignment = Alignment.CenterStart) {
Column {
// A preview panel for the selected image.
if (pickedImages.isNotEmpty()) {
Box(
contentAlignment = Alignment.TopEnd, modifier = Modifier.offset(x = 16.dp, y = (-80).dp)
Row(
modifier =
Modifier.offset(x = 16.dp).fillMaxWidth().horizontalScroll(rememberScrollState()),
horizontalArrangement = Arrangement.spacedBy(16.dp),
) {
Image(
bitmap = pickedImages.last().asImageBitmap(),
contentDescription = "",
modifier = Modifier
.height(80.dp)
.shadow(2.dp, shape = RoundedCornerShape(8.dp))
.clip(RoundedCornerShape(8.dp))
.border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp)),
)
Box(modifier = Modifier
.offset(x = 10.dp, y = (-10).dp)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.surface)
.border((1.5).dp, MaterialTheme.colorScheme.outline, CircleShape)
.clickable {
pickedImages = listOf()
}) {
Icon(
Icons.Rounded.Close,
contentDescription = "",
modifier = Modifier
.padding(3.dp)
.size(16.dp)
for (image in pickedImages) {
Box(contentAlignment = Alignment.TopEnd) {
Image(
bitmap = image.asImageBitmap(),
contentDescription = "",
modifier =
Modifier.height(80.dp)
.shadow(2.dp, shape = RoundedCornerShape(8.dp))
.clip(RoundedCornerShape(8.dp))
.border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp)),
)
Box(
modifier =
Modifier.offset(x = 10.dp, y = (-10).dp)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.surface)
.border((1.5).dp, MaterialTheme.colorScheme.outline, CircleShape)
.clickable { pickedImages = pickedImages.filter { image != it } }
) {
Icon(
Icons.Rounded.Close,
contentDescription = "",
modifier = Modifier.padding(3.dp).size(16.dp),
)
}
}
}
}
}
Box(contentAlignment = Alignment.CenterStart) {
// A plus button to show a popup menu to add stuff to the chat.
IconButton(
enabled = !inProgress && !isResettingSession,
onClick = { showAddContentMenu = true },
modifier = Modifier.offset(x = 16.dp).alpha(0.8f),
) {
Icon(Icons.Rounded.Add, contentDescription = "", modifier = Modifier.size(28.dp))
}
Row(
modifier =
Modifier.fillMaxWidth()
.padding(12.dp)
.border(1.dp, MaterialTheme.colorScheme.outlineVariant, RoundedCornerShape(28.dp)),
verticalAlignment = Alignment.CenterVertically,
) {
val enableAddImageMenuItems = (imageMessageCount + pickedImages.size) < MAX_IMAGE_COUNT
DropdownMenu(
expanded = showAddContentMenu,
onDismissRequest = { showAddContentMenu = false },
) {
if (showImagePickerInMenu) {
// Take a picture.
DropdownMenuItem(
text = {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp),
) {
Icon(Icons.Rounded.PhotoCamera, contentDescription = "")
Text("Take a picture")
}
},
enabled = enableAddImageMenuItems,
onClick = {
// Check permission
when (PackageManager.PERMISSION_GRANTED) {
// Already got permission. Call the lambda.
ContextCompat.checkSelfPermission(context, Manifest.permission.CAMERA) -> {
showAddContentMenu = false
showCameraCaptureBottomSheet = true
}
// Otherwise, ask for permission
else -> {
takePicturePermissionLauncher.launch(Manifest.permission.CAMERA)
}
}
},
)
// Pick an image from album.
DropdownMenuItem(
text = {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp),
) {
Icon(Icons.Rounded.Photo, contentDescription = "")
Text("Pick from album")
}
},
enabled = enableAddImageMenuItems,
onClick = {
// Launch the photo picker and let the user choose only images.
pickMedia.launch(
PickVisualMediaRequest(ActivityResultContracts.PickVisualMedia.ImageOnly)
)
showAddContentMenu = false
},
)
}
// Prompt templates.
if (showPromptTemplatesInMenu) {
DropdownMenuItem(
text = {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp),
) {
Icon(Icons.Rounded.PostAdd, contentDescription = "")
Text("Prompt templates")
}
},
onClick = {
onOpenPromptTemplatesClicked()
showAddContentMenu = false
},
)
}
// Prompt history.
DropdownMenuItem(
text = {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp),
) {
Icon(Icons.Rounded.History, contentDescription = "")
Text("Input history")
}
},
onClick = {
showAddContentMenu = false
showTextInputHistorySheet = true
},
)
}
}
}
// A plus button to show a popup menu to add stuff to the chat.
IconButton(
enabled = !inProgress && !isResettingSession,
onClick = { showAddContentMenu = true },
modifier = Modifier
.offset(x = 16.dp)
.alpha(0.8f)
) {
Icon(
Icons.Rounded.Add,
contentDescription = "",
modifier = Modifier.size(28.dp),
)
}
Row(
modifier = Modifier
.fillMaxWidth()
.padding(12.dp)
.border(1.dp, MaterialTheme.colorScheme.outlineVariant, RoundedCornerShape(28.dp)),
verticalAlignment = Alignment.CenterVertically,
) {
DropdownMenu(
expanded = showAddContentMenu,
onDismissRequest = { showAddContentMenu = false }) {
if (showImagePickerInMenu) {
// Take a picture.
DropdownMenuItem(
text = {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp)
) {
Icon(Icons.Rounded.PhotoCamera, contentDescription = "")
Text("Take a picture")
}
},
enabled = pickedImages.isEmpty() && !hasImageMessage,
onClick = {
// Check permission
when (PackageManager.PERMISSION_GRANTED) {
// Already got permission. Call the lambda.
ContextCompat.checkSelfPermission(
context, Manifest.permission.CAMERA
) -> {
showAddContentMenu = false
tempPhotoUri = context.createTempPictureUri()
showCameraCaptureBottomSheet = true
}
// Otherwise, ask for permission
else -> {
takePicturePermissionLauncher.launch(Manifest.permission.CAMERA)
}
}
})
// Pick an image from album.
DropdownMenuItem(
text = {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp)
) {
Icon(Icons.Rounded.Photo, contentDescription = "")
Text("Pick from album")
}
},
enabled = pickedImages.isEmpty() && !hasImageMessage,
onClick = {
// Launch the photo picker and let the user choose only images.
pickMedia.launch(PickVisualMediaRequest(ActivityResultContracts.PickVisualMedia.ImageOnly))
showAddContentMenu = false
})
}
// Prompt templates.
if (showPromptTemplatesInMenu) {
DropdownMenuItem(text = {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp)
) {
Icon(Icons.Rounded.PostAdd, contentDescription = "")
Text("Prompt templates")
}
}, onClick = {
onOpenPromptTemplatesClicked()
showAddContentMenu = false
})
}
// Prompt history.
DropdownMenuItem(text = {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp)
) {
Icon(Icons.Rounded.History, contentDescription = "")
Text("Input history")
}
}, onClick = {
showAddContentMenu = false
showTextInputHistorySheet = true
})
}
// Text field.
TextField(value = curMessage,
minLines = 1,
maxLines = 3,
onValueChange = onValueChanged,
colors = TextFieldDefaults.colors(
unfocusedContainerColor = Color.Transparent,
focusedContainerColor = Color.Transparent,
focusedIndicatorColor = Color.Transparent,
unfocusedIndicatorColor = Color.Transparent,
disabledIndicatorColor = Color.Transparent,
disabledContainerColor = Color.Transparent,
),
textStyle = MaterialTheme.typography.bodyLarge,
modifier = Modifier
.weight(1f)
.padding(start = 36.dp),
placeholder = { Text(stringResource(textFieldPlaceHolderRes)) })
Spacer(modifier = Modifier.width(8.dp))
if (inProgress && showStopButtonWhenInProgress) {
if (!modelInitializing && !modelPreparing) {
IconButton(
onClick = onStopButtonClicked,
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.secondaryContainer,
// Text field.
TextField(
value = curMessage,
minLines = 1,
maxLines = 3,
onValueChange = onValueChanged,
colors =
TextFieldDefaults.colors(
unfocusedContainerColor = Color.Transparent,
focusedContainerColor = Color.Transparent,
focusedIndicatorColor = Color.Transparent,
unfocusedIndicatorColor = Color.Transparent,
disabledIndicatorColor = Color.Transparent,
disabledContainerColor = Color.Transparent,
),
textStyle = MaterialTheme.typography.bodyLarge,
modifier = Modifier.weight(1f).padding(start = 36.dp),
placeholder = { Text(stringResource(textFieldPlaceHolderRes)) },
)
Spacer(modifier = Modifier.width(8.dp))
if (inProgress && showStopButtonWhenInProgress) {
if (!modelInitializing && !modelPreparing) {
IconButton(
onClick = onStopButtonClicked,
colors =
IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.secondaryContainer
),
) {
Icon(
Icons.Rounded.Stop,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary,
)
}
}
} // Send button. Only shown when text is not empty.
else if (curMessage.isNotEmpty()) {
IconButton(
enabled = !inProgress && !isResettingSession,
onClick = {
onSendMessage(
createMessagesToSend(pickedImages = pickedImages, text = curMessage.trim())
)
pickedImages = listOf()
},
colors =
IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.secondaryContainer
),
) {
Icon(
Icons.Rounded.Stop, contentDescription = "", tint = MaterialTheme.colorScheme.primary
Icons.AutoMirrored.Rounded.Send,
contentDescription = "",
modifier = Modifier.offset(x = 2.dp),
tint = MaterialTheme.colorScheme.primary,
)
}
}
} // Send button. Only shown when text is not empty.
else if (curMessage.isNotEmpty()) {
IconButton(
enabled = !inProgress && !isResettingSession,
onClick = {
onSendMessage(
createMessagesToSend(pickedImages = pickedImages, text = curMessage.trim())
)
pickedImages = listOf()
},
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.secondaryContainer,
),
) {
Icon(
Icons.AutoMirrored.Rounded.Send,
contentDescription = "",
modifier = Modifier.offset(x = 2.dp),
tint = MaterialTheme.colorScheme.primary
)
}
Spacer(modifier = Modifier.width(4.dp))
}
Spacer(modifier = Modifier.width(4.dp))
}
}
@ -385,40 +401,37 @@ fun MessageInputText(
if (showTextInputHistorySheet) {
TextInputHistorySheet(
history = modelManagerUiState.textInputHistory,
onDismissed = {
showTextInputHistorySheet = false
},
onDismissed = { showTextInputHistorySheet = false },
onHistoryItemClicked = { item ->
onSendMessage(createMessagesToSend(pickedImages = pickedImages, text = item))
pickedImages = listOf()
modelManagerViewModel.promoteTextInputHistoryItem(item)
},
onHistoryItemDeleted = { item ->
modelManagerViewModel.deleteTextInputHistory(item)
},
onHistoryItemsDeleteAll = {
modelManagerViewModel.clearTextInputHistory()
})
onHistoryItemDeleted = { item -> modelManagerViewModel.deleteTextInputHistory(item) },
onHistoryItemsDeleteAll = { modelManagerViewModel.clearTextInputHistory() },
)
}
if (showCameraCaptureBottomSheet) {
ModalBottomSheet(
sheetState = cameraCaptureSheetState,
onDismissRequest = { showCameraCaptureBottomSheet = false }) {
onDismissRequest = { showCameraCaptureBottomSheet = false },
) {
val lifecycleOwner = LocalLifecycleOwner.current
val previewUseCase = remember { androidx.camera.core.Preview.Builder().build() }
val imageCaptureUseCase = remember {
// Try to limit the image size.
val preferredSize = Size(512, 512)
val resolutionStrategy = ResolutionStrategy(
preferredSize,
ResolutionStrategy.FALLBACK_RULE_CLOSEST_HIGHER_THEN_LOWER
)
val resolutionSelector = ResolutionSelector.Builder()
.setResolutionStrategy(resolutionStrategy)
.setAspectRatioStrategy(AspectRatioStrategy.RATIO_4_3_FALLBACK_AUTO_STRATEGY)
.build()
val resolutionStrategy =
ResolutionStrategy(
preferredSize,
ResolutionStrategy.FALLBACK_RULE_CLOSEST_HIGHER_THEN_LOWER,
)
val resolutionSelector =
ResolutionSelector.Builder()
.setResolutionStrategy(resolutionStrategy)
.setAspectRatioStrategy(AspectRatioStrategy.RATIO_4_3_FALLBACK_AUTO_STRATEGY)
.build()
ImageCapture.Builder().setResolutionSelector(resolutionSelector).build()
}
@ -430,17 +443,16 @@ fun MessageInputText(
fun rebindCameraProvider() {
cameraProvider?.let { cameraProvider ->
val cameraSelector = CameraSelector.Builder()
.requireLensFacing(cameraSide)
.build()
val cameraSelector = CameraSelector.Builder().requireLensFacing(cameraSide).build()
try {
cameraProvider.unbindAll()
val camera = cameraProvider.bindToLifecycle(
lifecycleOwner = lifecycleOwner,
cameraSelector = cameraSelector,
previewUseCase,
imageCaptureUseCase
)
val camera =
cameraProvider.bindToLifecycle(
lifecycleOwner = lifecycleOwner,
cameraSelector = cameraSelector,
previewUseCase,
imageCaptureUseCase,
)
cameraControl = camera.cameraControl
} catch (e: Exception) {
Log.d(TAG, "Failed to bind camera", e)
@ -453,15 +465,13 @@ fun MessageInputText(
rebindCameraProvider()
}
LaunchedEffect(cameraSide) {
rebindCameraProvider()
}
LaunchedEffect(cameraSide) { rebindCameraProvider() }
DisposableEffect(Unit) { // Or key on lifecycleOwner if it makes more sense
onDispose {
cameraProvider?.unbindAll() // Unbind all use cases from the camera provider
if (!executor.isShutdown) {
executor.shutdown() // Shut down the executor service
executor.shutdown() // Shut down the executor service
}
}
}
@ -485,54 +495,54 @@ fun MessageInputText(
cameraCaptureSheetState.hide()
showCameraCaptureBottomSheet = false
}
}, colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.surfaceVariant,
), modifier = Modifier
.offset(x = (-8).dp, y = 8.dp)
.align(Alignment.TopEnd)
},
colors =
IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.surfaceVariant
),
modifier = Modifier.offset(x = (-8).dp, y = 8.dp).align(Alignment.TopEnd),
) {
Icon(
Icons.Rounded.Close,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary
tint = MaterialTheme.colorScheme.primary,
)
}
// Button that triggers the image capture process
IconButton(
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.primary,
),
modifier = Modifier
.align(Alignment.BottomCenter)
.padding(bottom = 32.dp)
.size(64.dp)
.border(2.dp, MaterialTheme.colorScheme.onPrimary, CircleShape),
colors =
IconButtonDefaults.iconButtonColors(containerColor = MaterialTheme.colorScheme.primary),
modifier =
Modifier.align(Alignment.BottomCenter)
.padding(bottom = 32.dp)
.size(64.dp)
.border(2.dp, MaterialTheme.colorScheme.onPrimary, CircleShape),
onClick = {
val callback = object : ImageCapture.OnImageCapturedCallback() {
override fun onCaptureSuccess(image: ImageProxy) {
try {
var bitmap = image.toBitmap()
val rotation = image.imageInfo.rotationDegrees
bitmap = if (rotation != 0) {
val matrix = Matrix().apply {
postRotate(rotation.toFloat())
val callback =
object : ImageCapture.OnImageCapturedCallback() {
override fun onCaptureSuccess(image: ImageProxy) {
try {
var bitmap = image.toBitmap()
val rotation = image.imageInfo.rotationDegrees
bitmap =
if (rotation != 0) {
val matrix = Matrix().apply { postRotate(rotation.toFloat()) }
Log.d(TAG, "image size: ${bitmap.width}, ${bitmap.height}")
Bitmap.createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, matrix, true)
} else bitmap
updatePickedImages(listOf(bitmap))
} catch (e: Exception) {
Log.e(TAG, "Failed to process image", e)
} finally {
image.close()
scope.launch {
cameraCaptureSheetState.hide()
showCameraCaptureBottomSheet = false
}
Log.d(TAG, "image size: ${bitmap.width}, ${bitmap.height}")
Bitmap.createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, matrix, true)
} else bitmap
updatePickedImages(bitmap)
} catch (e: Exception) {
Log.e(TAG, "Failed to process image", e)
} finally {
image.close()
scope.launch {
cameraCaptureSheetState.hide()
showCameraCaptureBottomSheet = false
}
}
}
}
imageCaptureUseCase.takePicture(executor, callback)
},
) {
@ -540,32 +550,32 @@ fun MessageInputText(
Icons.Rounded.PhotoCamera,
contentDescription = "",
tint = MaterialTheme.colorScheme.onPrimary,
modifier = Modifier.size(36.dp)
modifier = Modifier.size(36.dp),
)
}
// Button that toggles the front and back camera.
if (hasFrontCamera) {
IconButton(
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.secondaryContainer,
),
modifier = Modifier
.align(Alignment.BottomEnd)
.padding(bottom = 40.dp, end = 32.dp)
.size(48.dp),
colors =
IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.secondaryContainer
),
modifier =
Modifier.align(Alignment.BottomEnd).padding(bottom = 40.dp, end = 32.dp).size(48.dp),
onClick = {
cameraSide = when (cameraSide) {
CameraSelector.LENS_FACING_BACK -> CameraSelector.LENS_FACING_FRONT
else -> CameraSelector.LENS_FACING_BACK
}
cameraSide =
when (cameraSide) {
CameraSelector.LENS_FACING_BACK -> CameraSelector.LENS_FACING_FRONT
else -> CameraSelector.LENS_FACING_BACK
}
},
) {
Icon(
Icons.Rounded.FlipCameraAndroid,
contentDescription = "",
tint = MaterialTheme.colorScheme.onSecondaryContainer,
modifier = Modifier.size(24.dp)
modifier = Modifier.size(24.dp),
)
}
}
@ -574,25 +584,32 @@ fun MessageInputText(
}
}
private fun handleImageSelected(
private fun handleImagesSelected(
context: Context,
uri: Uri,
onImageSelected: (Bitmap) -> Unit,
uris: List<Uri>,
onImagesSelected: (List<Bitmap>) -> Unit,
// For some reason, some Android phone would store the picture taken by the camera rotated
// horizontally. Use this flag to rotate the image back to portrait if the picture's width
// is bigger than height.
rotateForPortrait: Boolean = false,
) {
val bitmap: Bitmap? = try {
val inputStream = context.contentResolver.openInputStream(uri)
val tmpBitmap = BitmapFactory.decodeStream(inputStream)
rotateImageIfNecessary(bitmap = tmpBitmap, rotateForPortrait = rotateForPortrait)
} catch (e: Exception) {
e.printStackTrace()
null
val images: MutableList<Bitmap> = mutableListOf()
for (uri in uris) {
val bitmap: Bitmap? =
try {
val inputStream = context.contentResolver.openInputStream(uri)
val tmpBitmap = BitmapFactory.decodeStream(inputStream)
rotateImageIfNecessary(bitmap = tmpBitmap, rotateForPortrait = rotateForPortrait)
} catch (e: Exception) {
e.printStackTrace()
null
}
if (bitmap != null) {
images.add(bitmap)
}
}
if (bitmap != null) {
onImageSelected(bitmap)
if (images.isNotEmpty()) {
onImagesSelected(images)
}
}
@ -608,106 +625,106 @@ private fun rotateImageIfNecessary(bitmap: Bitmap, rotateForPortrait: Boolean =
private fun checkFrontCamera(context: Context, callback: (Boolean) -> Unit) {
val cameraProviderFuture = ProcessCameraProvider.getInstance(context)
cameraProviderFuture.addListener({
val cameraProvider = cameraProviderFuture.get()
try {
// Attempt to select the default front camera
val hasFront = cameraProvider.hasCamera(CameraSelector.DEFAULT_FRONT_CAMERA)
callback(hasFront)
} catch (e: Exception) {
e.printStackTrace()
callback(false)
}
}, ContextCompat.getMainExecutor(context))
cameraProviderFuture.addListener(
{
val cameraProvider = cameraProviderFuture.get()
try {
// Attempt to select the default front camera
val hasFront = cameraProvider.hasCamera(CameraSelector.DEFAULT_FRONT_CAMERA)
callback(hasFront)
} catch (e: Exception) {
e.printStackTrace()
callback(false)
}
},
ContextCompat.getMainExecutor(context),
)
}
private fun createMessagesToSend(pickedImages: List<Bitmap>, text: String): List<ChatMessage> {
val messages: MutableList<ChatMessage> = mutableListOf()
var messages: MutableList<ChatMessage> = mutableListOf()
if (pickedImages.isNotEmpty()) {
val lastImage = pickedImages.last()
messages.add(
ChatMessageImage(
bitmap = lastImage, imageBitMap = lastImage.asImageBitmap(), side = ChatSide.USER
for (image in pickedImages) {
messages.add(
ChatMessageImage(bitmap = image, imageBitMap = image.asImageBitmap(), side = ChatSide.USER)
)
)
}
}
messages.add(
ChatMessageText(
content = text, side = ChatSide.USER
)
)
// Cap the number of image messages.
if (messages.size > MAX_IMAGE_COUNT) {
messages = messages.subList(fromIndex = 0, toIndex = MAX_IMAGE_COUNT)
}
messages.add(ChatMessageText(content = text, side = ChatSide.USER))
return messages
}
@Preview(showBackground = true)
@Composable
fun MessageInputTextPreview() {
val context = LocalContext.current
GalleryTheme {
Column {
MessageInputText(
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "hello",
inProgress = false,
isResettingSession = false,
modelInitializing = false,
hasImageMessage = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},
onSendMessage = {},
showStopButtonWhenInProgress = true,
showImagePickerInMenu = true,
)
MessageInputText(
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "hello",
inProgress = false,
isResettingSession = false,
hasImageMessage = false,
modelInitializing = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},
onSendMessage = {},
showStopButtonWhenInProgress = true,
)
MessageInputText(
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "hello",
inProgress = true,
isResettingSession = false,
hasImageMessage = false,
modelInitializing = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},
onSendMessage = {},
)
MessageInputText(
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "",
inProgress = false,
isResettingSession = false,
hasImageMessage = false,
modelInitializing = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},
onSendMessage = {},
)
MessageInputText(
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "",
inProgress = true,
isResettingSession = false,
hasImageMessage = false,
modelInitializing = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},
onSendMessage = {},
showStopButtonWhenInProgress = true,
)
}
}
}
// @Preview(showBackground = true)
// @Composable
// fun MessageInputTextPreview() {
// val context = LocalContext.current
// GalleryTheme {
// Column {
// MessageInputText(
// modelManagerViewModel = PreviewModelManagerViewModel(context = context),
// curMessage = "hello",
// inProgress = false,
// isResettingSession = false,
// modelInitializing = false,
// hasImageMessage = false,
// textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
// onValueChanged = {},
// onSendMessage = {},
// showStopButtonWhenInProgress = true,
// showImagePickerInMenu = true,
// )
// MessageInputText(
// modelManagerViewModel = PreviewModelManagerViewModel(context = context),
// curMessage = "hello",
// inProgress = false,
// isResettingSession = false,
// hasImageMessage = false,
// modelInitializing = false,
// textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
// onValueChanged = {},
// onSendMessage = {},
// showStopButtonWhenInProgress = true,
// )
// MessageInputText(
// modelManagerViewModel = PreviewModelManagerViewModel(context = context),
// curMessage = "hello",
// inProgress = true,
// isResettingSession = false,
// hasImageMessage = false,
// modelInitializing = false,
// textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
// onValueChanged = {},
// onSendMessage = {},
// )
// MessageInputText(
// modelManagerViewModel = PreviewModelManagerViewModel(context = context),
// curMessage = "",
// inProgress = false,
// isResettingSession = false,
// hasImageMessage = false,
// modelInitializing = false,
// textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
// onValueChanged = {},
// onSendMessage = {},
// )
// MessageInputText(
// modelManagerViewModel = PreviewModelManagerViewModel(context = context),
// curMessage = "",
// inProgress = true,
// isResettingSession = false,
// hasImageMessage = false,
// modelInitializing = false,
// textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
// onValueChanged = {},
// onSendMessage = {},
// showStopButtonWhenInProgress = true,
// )
// }
// }
// }

View file

@ -16,22 +16,17 @@
package com.google.ai.edge.gallery.ui.common.chat
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.padding
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.ui.common.humanReadableDuration
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
/**
* Composable function to display the latency of a chat message, if available.
*/
/** Composable function to display the latency of a chat message, if available. */
@Composable
fun LatencyText(message: ChatMessage) {
if (message.latencyMs >= 0) {
@ -43,21 +38,19 @@ fun LatencyText(message: ChatMessage) {
}
}
@Preview(showBackground = true)
@Composable
fun LatencyTextPreview() {
GalleryTheme {
Column(modifier = Modifier.padding(16.dp), verticalArrangement = Arrangement.spacedBy(16.dp)) {
for (latencyMs in listOf(123f, 1234f, 123456f, 7234567f)) {
LatencyText(
message = ChatMessage(
latencyMs = latencyMs,
type = ChatMessageType.TEXT,
side = ChatSide.AGENT
)
)
}
}
}
}
// @Preview(showBackground = true)
// @Composable
// fun LatencyTextPreview() {
// GalleryTheme {
// Column(modifier = Modifier.padding(16.dp), verticalArrangement = Arrangement.spacedBy(16.dp))
// {
// for (latencyMs in listOf(123f, 1234f, 123456f, 7234567f)) {
// LatencyText(
// message =
// ChatMessage(latencyMs = latencyMs, type = ChatMessageType.TEXT, side =
// ChatSide.AGENT)
// )
// }
// }
// }
// }

View file

@ -16,9 +16,10 @@
package com.google.ai.edge.gallery.ui.common.chat
import android.graphics.Bitmap
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxWidth
@ -32,40 +33,37 @@ import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.bodySmallNarrow
data class MessageLayoutConfig(
val horizontalArrangement: Arrangement.Horizontal,
val modifier: Modifier,
val userLabel: String,
val rightSideLabel: String
val rightSideLabel: String,
)
/**
* Composable function to display the sender information for a chat message.
*
* This function handles different types of chat messages, including system messages,
* benchmark results, and image generation results, and displays the appropriate sender label
* and status information.
* This function handles different types of chat messages, including system messages, benchmark
* results, and image generation results, and displays the appropriate sender label and status
* information.
*/
@Composable
fun MessageSender(
message: ChatMessage,
agentName: String = "",
imageHistoryCurIndex: Int = 0
) {
fun MessageSender(message: ChatMessage, agentName: String = "", imageHistoryCurIndex: Int = 0) {
// No user label for system messages.
if (message.side == ChatSide.SYSTEM) {
return
}
val (horizontalArrangement, modifier, userLabel, rightSideLabel) = getMessageLayoutConfig(
message = message, agentName = agentName, imageHistoryCurIndex = imageHistoryCurIndex
)
val (horizontalArrangement, modifier, userLabel, rightSideLabel) =
getMessageLayoutConfig(
message = message,
agentName = agentName,
imageHistoryCurIndex = imageHistoryCurIndex,
)
Row(
modifier = modifier,
@ -74,10 +72,7 @@ fun MessageSender(
) {
Row(verticalAlignment = Alignment.CenterVertically) {
// Sender label.
Text(
userLabel,
style = MaterialTheme.typography.titleSmall,
)
Text(userLabel, style = MaterialTheme.typography.titleSmall)
when (message) {
// Benchmark running status.
@ -87,21 +82,18 @@ fun MessageSender(
CircularProgressIndicator(
modifier = Modifier.size(10.dp),
strokeWidth = 1.5.dp,
color = MaterialTheme.colorScheme.secondary
color = MaterialTheme.colorScheme.secondary,
)
Spacer(modifier = Modifier.width(4.dp))
}
val statusLabel = if (message.isWarmingUp()) {
stringResource(R.string.warming_up)
} else if (message.isRunning()) {
stringResource(R.string.running)
} else ""
val statusLabel =
if (message.isWarmingUp()) {
stringResource(R.string.warming_up)
} else if (message.isRunning()) {
stringResource(R.string.running)
} else ""
if (statusLabel.isNotEmpty()) {
Text(
statusLabel,
color = MaterialTheme.colorScheme.secondary,
style = bodySmallNarrow,
)
Text(statusLabel, color = MaterialTheme.colorScheme.secondary, style = bodySmallNarrow)
}
}
@ -112,7 +104,7 @@ fun MessageSender(
CircularProgressIndicator(
modifier = Modifier.size(10.dp),
strokeWidth = 1.5.dp,
color = MaterialTheme.colorScheme.secondary
color = MaterialTheme.colorScheme.secondary,
)
}
}
@ -124,7 +116,7 @@ fun MessageSender(
CircularProgressIndicator(
modifier = Modifier.size(10.dp),
strokeWidth = 1.5.dp,
color = MaterialTheme.colorScheme.secondary
color = MaterialTheme.colorScheme.secondary,
)
Spacer(modifier = Modifier.width(4.dp))
Text(
@ -141,8 +133,7 @@ fun MessageSender(
when (message) {
is ChatMessageBenchmarkResult,
is ChatMessageImageWithHistory,
is ChatMessageBenchmarkLlmResult,
-> {
is ChatMessageBenchmarkLlmResult -> {
Text(rightSideLabel, style = MaterialTheme.typography.bodySmall)
}
}
@ -169,11 +160,12 @@ private fun getMessageLayoutConfig(
horizontalArrangement = Arrangement.SpaceBetween
modifier = modifier.fillMaxWidth()
userLabel = "Benchmark"
rightSideLabel = if (message.isWarmingUp()) {
"${message.warmupCurrent}/${message.warmupTotal}"
} else {
"${message.iterationCurrent}/${message.iterationTotal}"
}
rightSideLabel =
if (message.isWarmingUp()) {
"${message.warmupCurrent}/${message.warmupTotal}"
} else {
"${message.iterationCurrent}/${message.iterationTotal}"
}
}
is ChatMessageBenchmarkLlmResult -> {
@ -198,64 +190,68 @@ private fun getMessageLayoutConfig(
horizontalArrangement = horizontalArrangement,
modifier = modifier,
userLabel = userLabel,
rightSideLabel = rightSideLabel
rightSideLabel = rightSideLabel,
)
}
@Preview(showBackground = true)
@Composable
fun MessageSenderPreview() {
GalleryTheme {
Column(modifier = Modifier.padding(16.dp), verticalArrangement = Arrangement.spacedBy(16.dp)) {
// Agent message.
MessageSender(
message = ChatMessageText(content = "hello world", side = ChatSide.AGENT),
agentName = stringResource(R.string.chat_generic_agent_name)
)
// User message.
MessageSender(
message = ChatMessageText(content = "hello world", side = ChatSide.USER),
agentName = stringResource(R.string.chat_generic_agent_name)
)
// Benchmark during warmup.
MessageSender(
message = ChatMessageBenchmarkResult(
orderedStats = listOf(),
statValues = mutableMapOf(),
values = listOf(),
histogram = Histogram(listOf(), 0),
warmupCurrent = 10,
warmupTotal = 50,
iterationCurrent = 0,
iterationTotal = 200
),
agentName = stringResource(R.string.chat_generic_agent_name)
)
// Benchmark during running.
MessageSender(
message = ChatMessageBenchmarkResult(
orderedStats = listOf(),
statValues = mutableMapOf(),
values = listOf(),
histogram = Histogram(listOf(), 0),
warmupCurrent = 50,
warmupTotal = 50,
iterationCurrent = 123,
iterationTotal = 200
),
agentName = stringResource(R.string.chat_generic_agent_name)
)
// Image generation during running.
MessageSender(
message = ChatMessageImageWithHistory(
bitmaps = listOf(Bitmap.createBitmap(100, 100, Bitmap.Config.ARGB_8888)),
imageBitMaps = listOf(),
totalIterations = 10,
ChatSide.AGENT
),
agentName = stringResource(R.string.chat_generic_agent_name),
imageHistoryCurIndex = 4,
)
}
}
}
// @Preview(showBackground = true)
// @Composable
// fun MessageSenderPreview() {
// GalleryTheme {
// Column(modifier = Modifier.padding(16.dp), verticalArrangement = Arrangement.spacedBy(16.dp))
// {
// // Agent message.
// MessageSender(
// message = ChatMessageText(content = "hello world", side = ChatSide.AGENT),
// agentName = stringResource(R.string.chat_generic_agent_name),
// )
// // User message.
// MessageSender(
// message = ChatMessageText(content = "hello world", side = ChatSide.USER),
// agentName = stringResource(R.string.chat_generic_agent_name),
// )
// // Benchmark during warmup.
// MessageSender(
// message =
// ChatMessageBenchmarkResult(
// orderedStats = listOf(),
// statValues = mutableMapOf(),
// values = listOf(),
// histogram = Histogram(listOf(), 0),
// warmupCurrent = 10,
// warmupTotal = 50,
// iterationCurrent = 0,
// iterationTotal = 200,
// ),
// agentName = stringResource(R.string.chat_generic_agent_name),
// )
// // Benchmark during running.
// MessageSender(
// message =
// ChatMessageBenchmarkResult(
// orderedStats = listOf(),
// statValues = mutableMapOf(),
// values = listOf(),
// histogram = Histogram(listOf(), 0),
// warmupCurrent = 50,
// warmupTotal = 50,
// iterationCurrent = 123,
// iterationTotal = 200,
// ),
// agentName = stringResource(R.string.chat_generic_agent_name),
// )
// // Image generation during running.
// MessageSender(
// message =
// ChatMessageImageWithHistory(
// bitmaps = listOf(Bitmap.createBitmap(100, 100, Bitmap.Config.ARGB_8888)),
// imageBitMaps = listOf(),
// totalIterations = 10,
// ChatSide.AGENT,
// ),
// agentName = stringResource(R.string.chat_generic_agent_name),
// imageHistoryCurIndex = 4,
// )
// }
// }
// }

View file

@ -45,7 +45,7 @@ import kotlinx.coroutines.delay
fun ModelDownloadStatusInfoPanel(
model: Model,
task: Task,
modelManagerViewModel: ModelManagerViewModel
modelManagerViewModel: ModelManagerViewModel,
) {
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
@ -62,9 +62,12 @@ fun ModelDownloadStatusInfoPanel(
var downloadModelButtonConditionMet by remember { mutableStateOf(false) }
downloadingAnimationConditionMet =
curStatus?.status == ModelDownloadStatusType.IN_PROGRESS || curStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED || curStatus?.status == ModelDownloadStatusType.UNZIPPING
curStatus?.status == ModelDownloadStatusType.IN_PROGRESS ||
curStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED ||
curStatus?.status == ModelDownloadStatusType.UNZIPPING
downloadModelButtonConditionMet =
curStatus?.status == ModelDownloadStatusType.FAILED || curStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED
curStatus?.status == ModelDownloadStatusType.FAILED ||
curStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED
LaunchedEffect(downloadingAnimationConditionMet) {
if (downloadingAnimationConditionMet) {
@ -87,24 +90,22 @@ fun ModelDownloadStatusInfoPanel(
AnimatedVisibility(
visible = shouldShowDownloadingAnimation,
enter = scaleIn(initialScale = 0.9f) + fadeIn(),
exit = scaleOut(targetScale = 0.9f) + fadeOut()
exit = scaleOut(targetScale = 0.9f) + fadeOut(),
) {
Box(
modifier = Modifier.fillMaxSize(), contentAlignment = Alignment.Center
) {
Box(modifier = Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
ModelDownloadingAnimation(
model = model, task = task, modelManagerViewModel = modelManagerViewModel
model = model,
task = task,
modelManagerViewModel = modelManagerViewModel,
)
}
}
AnimatedVisibility(
visible = shouldShowDownloadModelButton, enter = fadeIn(), exit = fadeOut()
) {
AnimatedVisibility(visible = shouldShowDownloadModelButton, enter = fadeIn(), exit = fadeOut()) {
Column(
modifier = Modifier.fillMaxSize(),
verticalArrangement = Arrangement.Center,
horizontalAlignment = Alignment.CenterHorizontally
horizontalAlignment = Alignment.CenterHorizontally,
) {
DownloadAndTryButton(
task = task,
@ -112,8 +113,8 @@ fun ModelDownloadStatusInfoPanel(
enabled = true,
needToDownloadFirst = true,
modelManagerViewModel = modelManagerViewModel,
onClicked = {}
onClicked = {},
)
}
}
}
}

View file

@ -16,6 +16,11 @@
package com.google.ai.edge.gallery.ui.common.chat
// import com.google.ai.edge.gallery.ui.preview.MODEL_TEST1
// import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
// import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.animation.core.Animatable
import androidx.compose.animation.core.Easing
import androidx.compose.animation.core.tween
@ -47,13 +52,10 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.ColorFilter
import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.ModelDownloadStatusType
@ -62,14 +64,10 @@ import com.google.ai.edge.gallery.ui.common.formatToHourMinSecond
import com.google.ai.edge.gallery.ui.common.getTaskIconColor
import com.google.ai.edge.gallery.ui.common.humanReadableSize
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.MODEL_TEST1
import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.labelSmallNarrow
import kotlinx.coroutines.delay
import kotlin.math.cos
import kotlin.math.pow
import kotlinx.coroutines.delay
private val GRID_SIZE = 240.dp
private val GRID_SPACING = 0.dp
@ -78,7 +76,6 @@ private const val ANIMATION_DURATION = 500
private const val START_SCALE = 0.9f
private const val END_SCALE = 0.6f
/**
* Composable function to display a loading animation using a 2x2 grid of images with a synchronized
* scaling and rotation effect.
@ -87,7 +84,7 @@ private const val END_SCALE = 0.6f
fun ModelDownloadingAnimation(
model: Model,
task: Task,
modelManagerViewModel: ModelManagerViewModel
modelManagerViewModel: ModelManagerViewModel,
) {
val scale = remember { Animatable(END_SCALE) }
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
@ -103,26 +100,27 @@ fun ModelDownloadingAnimation(
// Phase 1: Scale up
scale.animateTo(
targetValue = START_SCALE,
animationSpec = tween(
durationMillis = ANIMATION_DURATION,
easing = multiBounceEasing(bounces = 3, decay = 0.02f)
)
animationSpec =
tween(
durationMillis = ANIMATION_DURATION,
easing = multiBounceEasing(bounces = 3, decay = 0.02f),
),
)
delay(PAUSE_DURATION.toLong())
// Phase 2: Scale down
scale.animateTo(
targetValue = END_SCALE,
animationSpec = tween(
durationMillis = ANIMATION_DURATION,
easing = multiBounceEasing(bounces = 3, decay = 0.02f)
)
animationSpec =
tween(
durationMillis = ANIMATION_DURATION,
easing = multiBounceEasing(bounces = 3, decay = 0.02f),
),
)
delay(PAUSE_DURATION.toLong())
}
}
// Failure message.
val curDownloadStatus = downloadStatus
if (curDownloadStatus != null && curDownloadStatus.status == ModelDownloadStatusType.FAILED) {
@ -139,58 +137,55 @@ fun ModelDownloadingAnimation(
else {
Column(
horizontalAlignment = Alignment.CenterHorizontally,
modifier = Modifier.offset(y = -GRID_SIZE / 8)
modifier = Modifier.offset(y = -GRID_SIZE / 8),
) {
LazyVerticalGrid(
columns = GridCells.Fixed(2),
horizontalArrangement = Arrangement.spacedBy(GRID_SPACING),
verticalArrangement = Arrangement.spacedBy(GRID_SPACING),
modifier = Modifier
.width(GRID_SIZE)
.height(GRID_SIZE)
modifier = Modifier.width(GRID_SIZE).height(GRID_SIZE),
) {
itemsIndexed(
listOf(
R.drawable.pantegon,
R.drawable.double_circle,
R.drawable.circle,
R.drawable.four_circle
R.drawable.four_circle,
)
) { index, imageResource ->
val currentScale =
if (index == 0 || index == 3) scale.value else START_SCALE + END_SCALE - scale.value
Box(
modifier = Modifier
.width((GRID_SIZE - GRID_SPACING) / 2)
.height((GRID_SIZE - GRID_SPACING) / 2),
contentAlignment = when (index) {
0 -> Alignment.BottomEnd
1 -> Alignment.BottomStart
2 -> Alignment.TopEnd
3 -> Alignment.TopStart
else -> Alignment.Center
}
modifier =
Modifier.width((GRID_SIZE - GRID_SPACING) / 2).height((GRID_SIZE - GRID_SPACING) / 2),
contentAlignment =
when (index) {
0 -> Alignment.BottomEnd
1 -> Alignment.BottomStart
2 -> Alignment.TopEnd
3 -> Alignment.TopStart
else -> Alignment.Center
},
) {
Image(
painter = painterResource(id = imageResource),
contentDescription = "",
contentScale = ContentScale.Fit,
colorFilter = ColorFilter.tint(getTaskIconColor(index = index)),
modifier = Modifier
.graphicsLayer {
scaleX = currentScale
scaleY = currentScale
rotationZ = currentScale * 120
alpha = 0.8f
}
.size(70.dp)
modifier =
Modifier.graphicsLayer {
scaleX = currentScale
scaleY = currentScale
rotationZ = currentScale * 120
alpha = 0.8f
}
.size(70.dp),
)
}
}
}
// Download stats
var sizeLabel = model.totalBytes.humanReadableSize()
if (curDownloadStatus != null) {
@ -203,8 +198,7 @@ fun ModelDownloadingAnimation(
sizeLabel =
"${curDownloadStatus.receivedBytes.humanReadableSize(extraDecimalForGbAndAbove = true)} of ${totalSize.humanReadableSize()}"
if (curDownloadStatus.bytesPerSecond > 0) {
sizeLabel =
"$sizeLabel · ${curDownloadStatus.bytesPerSecond.humanReadableSize()} / s"
sizeLabel = "$sizeLabel · ${curDownloadStatus.bytesPerSecond.humanReadableSize()} / s"
if (curDownloadStatus.remainingMs >= 0) {
sizeLabel =
"$sizeLabel · ${curDownloadStatus.remainingMs.formatToHourMinSecond()} left"
@ -229,8 +223,7 @@ fun ModelDownloadingAnimation(
style = MaterialTheme.typography.labelMedium,
textAlign = TextAlign.Center,
overflow = TextOverflow.Visible,
modifier = Modifier
.padding(bottom = 4.dp)
modifier = Modifier.padding(bottom = 4.dp),
)
}
@ -241,10 +234,7 @@ fun ModelDownloadingAnimation(
progress = { animatedProgress.value },
color = getTaskIconColor(task = task),
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
modifier = Modifier
.fillMaxWidth()
.padding(bottom = 36.dp)
.padding(horizontal = 36.dp)
modifier = Modifier.fillMaxWidth().padding(bottom = 36.dp).padding(horizontal = 36.dp),
)
LaunchedEffect(curDownloadProgress) {
animatedProgress.animateTo(curDownloadProgress, animationSpec = tween(150))
@ -255,23 +245,19 @@ fun ModelDownloadingAnimation(
LinearProgressIndicator(
color = getTaskIconColor(task = task),
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
modifier = Modifier
.fillMaxWidth()
.padding(bottom = 36.dp)
.padding(horizontal = 36.dp)
modifier = Modifier.fillMaxWidth().padding(bottom = 36.dp).padding(horizontal = 36.dp),
)
}
Text(
"Feel free to switch apps or lock your device.\n"
+ "The download will continue in the background.\n"
+ "We'll send a notification when it's done.",
"Feel free to switch apps or lock your device.\n" +
"The download will continue in the background.\n" +
"We'll send a notification when it's done.",
style = MaterialTheme.typography.bodyLarge,
textAlign = TextAlign.Center
textAlign = TextAlign.Center,
)
}
}
}
// Custom Easing function for a multi-bounce effect
@ -283,18 +269,18 @@ fun multiBounceEasing(bounces: Int, decay: Float): Easing = Easing { x ->
}
}
@Preview(showBackground = true)
@Composable
fun ModelDownloadingAnimationPreview() {
val context = LocalContext.current
// @Preview(showBackground = true)
// @Composable
// fun ModelDownloadingAnimationPreview() {
// val context = LocalContext.current
GalleryTheme {
Row(modifier = Modifier.padding(16.dp)) {
ModelDownloadingAnimation(
model = MODEL_TEST1,
task = TASK_TEST1,
modelManagerViewModel = PreviewModelManagerViewModel(context = context)
)
}
}
}
// GalleryTheme {
// Row(modifier = Modifier.padding(16.dp)) {
// ModelDownloadingAnimation(
// model = MODEL_TEST1,
// task = TASK_TEST1,
// modelManagerViewModel = PreviewModelManagerViewModel(context = context),
// )
// }
// }
// }

View file

@ -16,6 +16,9 @@
package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
@ -34,37 +37,35 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
/**
* Composable function to display a visual indicator for model initialization status.
*
* This function renders a row containing a circular progress indicator and a message
* indicating that the model is currently initializing. It provides a visual cue to the
* user that the model is in a loading state.
* This function renders a row containing a circular progress indicator and a message indicating
* that the model is currently initializing. It provides a visual cue to the user that the model is
* in a loading state.
*/
@Composable
fun ModelInitializationStatusChip() {
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center) {
Box(
modifier = Modifier
.padding(8.dp)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.secondaryContainer)
modifier =
Modifier.padding(8.dp)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.secondaryContainer)
) {
Row(
modifier = Modifier.padding(top = 4.dp, bottom = 4.dp, start = 8.dp, end = 8.dp),
horizontalArrangement = Arrangement.Center,
verticalAlignment = Alignment.CenterVertically
verticalAlignment = Alignment.CenterVertically,
) {
// Circular progress indicator.
CircularProgressIndicator(
modifier = Modifier.size(14.dp),
strokeWidth = 2.dp,
color = MaterialTheme.colorScheme.onSecondaryContainer
color = MaterialTheme.colorScheme.onSecondaryContainer,
)
Spacer(modifier = Modifier.width(8.dp))
@ -80,10 +81,8 @@ fun ModelInitializationStatusChip() {
}
}
@Preview(showBackground = true)
@Composable
fun ModelInitializationStatusPreview() {
GalleryTheme {
ModelInitializationStatusChip()
}
}
// @Preview(showBackground = true)
// @Composable
// fun ModelInitializationStatusPreview() {
// GalleryTheme { ModelInitializationStatusChip() }
// }

View file

@ -16,6 +16,9 @@
package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.fillMaxSize
@ -24,8 +27,6 @@ import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.tooling.preview.Preview
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
/**
* Composable function to display a button to download model if the model has not been downloaded.
@ -35,20 +36,14 @@ fun ModelNotDownloaded(modifier: Modifier = Modifier, onClicked: () -> Unit) {
Column(
modifier = modifier.fillMaxSize(),
verticalArrangement = Arrangement.Center,
horizontalAlignment = Alignment.CenterHorizontally
horizontalAlignment = Alignment.CenterHorizontally,
) {
Button(
onClick = onClicked,
) {
Text("Download & Try it", maxLines = 1)
}
Button(onClick = onClicked) { Text("Download & Try it", maxLines = 1) }
}
}
@Preview(showBackground = true)
@Composable
fun Preview() {
GalleryTheme {
ModelNotDownloaded(onClicked = {})
}
}
// @Preview(showBackground = true)
// @Composable
// fun Preview() {
// GalleryTheme { ModelNotDownloaded(onClicked = {}) }
// }

View file

@ -16,7 +16,12 @@
package com.google.ai.edge.gallery.ui.common.chat
import androidx.compose.foundation.layout.Arrangement
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
// import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
// import com.google.ai.edge.gallery.ui.preview.TASK_TEST2
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
@ -31,17 +36,13 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.common.convertValueToTargetType
import com.google.ai.edge.gallery.data.convertValueToTargetType
import com.google.ai.edge.gallery.ui.common.ConfigDialog
import com.google.ai.edge.gallery.ui.common.modelitem.ModelItem
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
import com.google.ai.edge.gallery.ui.preview.TASK_TEST2
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
/**
* Composable function to display a selectable model item with an option to configure its settings.
@ -53,39 +54,31 @@ fun ModelSelector(
modelManagerViewModel: ModelManagerViewModel,
modifier: Modifier = Modifier,
contentAlpha: Float = 1f,
onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit = { _, _ -> },
onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit =
{ _, _ ->
},
) {
var showConfigDialog by remember { mutableStateOf(false) }
val context = LocalContext.current
Column(
modifier = modifier
) {
Column(modifier = modifier) {
Box(
modifier = Modifier
.fillMaxWidth()
.padding(bottom = 8.dp),
contentAlignment = Alignment.Center
modifier = Modifier.fillMaxWidth().padding(bottom = 8.dp),
contentAlignment = Alignment.Center,
) {
// Model row.
Row(
modifier = Modifier
.fillMaxWidth()
.graphicsLayer { alpha = contentAlpha },
verticalAlignment = Alignment.CenterVertically
modifier = Modifier.fillMaxWidth().graphicsLayer { alpha = contentAlpha },
verticalAlignment = Alignment.CenterVertically,
) {
ModelItem(
model = model,
task = task,
modelManagerViewModel = modelManagerViewModel,
onModelClicked = {},
onConfigClicked = {
showConfigDialog = true
},
onConfigClicked = { showConfigDialog = true },
verticalSpacing = 10.dp,
modifier = Modifier
.weight(1f)
.padding(horizontal = 16.dp),
modifier = Modifier.weight(1f).padding(horizontal = 16.dp),
showDeleteButton = false,
showConfigButtonIfExisted = true,
canExpand = false,
@ -111,12 +104,16 @@ fun ModelSelector(
var needReinitialization = false
for (config in model.configs) {
val key = config.key.label
val oldValue = convertValueToTargetType(
value = model.configValues.getValue(key), valueType = config.valueType
)
val newValue = convertValueToTargetType(
value = curConfigValues.getValue(key), valueType = config.valueType
)
val oldValue =
convertValueToTargetType(
value = model.configValues.getValue(key),
valueType = config.valueType,
)
val newValue =
convertValueToTargetType(
value = curConfigValues.getValue(key),
valueType = config.valueType,
)
if (oldValue != newValue) {
same = false
if (config.needReinitialization) {
@ -139,7 +136,7 @@ fun ModelSelector(
context = context,
task = task,
model = model,
force = true
force = true,
)
}
@ -150,27 +147,26 @@ fun ModelSelector(
}
}
@Preview(showBackground = true)
@Composable
fun ModelSelectorPreview(
) {
GalleryTheme {
Column(verticalArrangement = Arrangement.spacedBy(16.dp)) {
ModelSelector(
model = TASK_TEST1.models[0],
task = TASK_TEST1,
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
)
ModelSelector(
model = TASK_TEST1.models[1],
task = TASK_TEST1,
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
)
ModelSelector(
model = TASK_TEST2.models[1],
task = TASK_TEST2,
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
)
}
}
}
// @Preview(showBackground = true)
// @Composable
// fun ModelSelectorPreview() {
// GalleryTheme {
// Column(verticalArrangement = Arrangement.spacedBy(16.dp)) {
// ModelSelector(
// model = TASK_TEST1.models[0],
// task = TASK_TEST1,
// modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
// )
// ModelSelector(
// model = TASK_TEST1.models[1],
// task = TASK_TEST1,
// modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
// )
// ModelSelector(
// model = TASK_TEST2.models[1],
// task = TASK_TEST2,
// modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
// )
// }
// }
// }

View file

@ -16,6 +16,8 @@
package com.google.ai.edge.gallery.ui.common.chat
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
@ -53,10 +55,8 @@ import androidx.compose.ui.draw.clip
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.customColors
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
@ -68,7 +68,7 @@ fun TextInputHistorySheet(
onHistoryItemClicked: (String) -> Unit,
onHistoryItemDeleted: (String) -> Unit,
onHistoryItemsDeleteAll: () -> Unit,
onDismissed: () -> Unit
onDismissed: () -> Unit,
) {
val sheetState = rememberModalBottomSheetState()
val scope = rememberCoroutineScope()
@ -101,7 +101,7 @@ fun TextInputHistorySheet(
sheetState.hide()
onDismissed()
}
}
},
)
}
}
@ -112,7 +112,7 @@ private fun SheetContent(
onHistoryItemClicked: (String) -> Unit,
onHistoryItemDeleted: (String) -> Unit,
onHistoryItemsDeleteAll: () -> Unit,
onDismissed: () -> Unit
onDismissed: () -> Unit,
) {
val scope = rememberCoroutineScope()
var showConfirmDeleteDialog by remember { mutableStateOf(false) }
@ -122,47 +122,44 @@ private fun SheetContent(
Text(
"Text input history",
style = MaterialTheme.typography.titleLarge,
modifier = Modifier
.fillMaxWidth()
.padding(8.dp),
textAlign = TextAlign.Center
modifier = Modifier.fillMaxWidth().padding(8.dp),
textAlign = TextAlign.Center,
)
IconButton(modifier = Modifier.padding(end = 12.dp), onClick = {
showConfirmDeleteDialog = true
}) {
IconButton(
modifier = Modifier.padding(end = 12.dp),
onClick = { showConfirmDeleteDialog = true },
) {
Icon(Icons.Rounded.DeleteSweep, contentDescription = "")
}
}
LazyColumn(modifier = Modifier.weight(1f)) {
items(history, key = { it }) { item ->
Row(
modifier = Modifier
.fillMaxWidth()
.padding(horizontal = 8.dp, vertical = 2.dp)
.clip(RoundedCornerShape(24.dp))
.background(MaterialTheme.customColors.agentBubbleBgColor)
.clickable {
onHistoryItemClicked(item)
},
modifier =
Modifier.fillMaxWidth()
.padding(horizontal = 8.dp, vertical = 2.dp)
.clip(RoundedCornerShape(24.dp))
.background(MaterialTheme.customColors.agentBubbleBgColor)
.clickable { onHistoryItemClicked(item) },
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(8.dp)
horizontalArrangement = Arrangement.spacedBy(8.dp),
) {
Text(
item,
style = MaterialTheme.typography.bodyMedium,
maxLines = 3,
overflow = TextOverflow.Ellipsis,
modifier = Modifier
.padding(vertical = 16.dp)
.padding(start = 16.dp)
.weight(1f)
modifier = Modifier.padding(vertical = 16.dp).padding(start = 16.dp).weight(1f),
)
IconButton(modifier = Modifier.padding(end = 8.dp), onClick = {
scope.launch {
delay(400)
onHistoryItemDeleted(item)
}
}) {
IconButton(
modifier = Modifier.padding(end = 8.dp),
onClick = {
scope.launch {
delay(400)
onHistoryItemDeleted(item)
}
},
) {
Icon(Icons.Rounded.Delete, contentDescription = "")
}
}
@ -171,18 +168,17 @@ private fun SheetContent(
}
if (showConfirmDeleteDialog) {
AlertDialog(onDismissRequest = { showConfirmDeleteDialog = false },
AlertDialog(
onDismissRequest = { showConfirmDeleteDialog = false },
title = { Text("Clear history?") },
text = {
Text(
"Are you sure you want to clear the history? This action cannot be undone."
)
},
text = { Text("Are you sure you want to clear the history? This action cannot be undone.") },
confirmButton = {
Button(onClick = {
showConfirmDeleteDialog = false
onHistoryItemsDeleteAll()
}) {
Button(
onClick = {
showConfirmDeleteDialog = false
onHistoryItemsDeleteAll()
}
) {
Text(stringResource(R.string.ok))
}
},
@ -190,25 +186,29 @@ private fun SheetContent(
TextButton(onClick = { showConfirmDeleteDialog = false }) {
Text(stringResource(R.string.cancel))
}
})
},
)
}
}
@Preview(showBackground = true)
@Composable
fun TextInputHistorySheetContentPreview() {
GalleryTheme {
SheetContent(
history = listOf(
"Analyze the sentiment of the following Tweets and classify them as POSITIVE, NEGATIVE, or NEUTRAL. \"It's so beautiful today!\"",
"I have the ingredients above. Not sure what to cook for lunch. Show me a list of foods with the recipes.",
"You are Santa Claus, write a letter back for this kid.",
"Generate a list of cookie recipes. Make the outputs in JSON format."
),
onHistoryItemClicked = {},
onHistoryItemDeleted = {},
onHistoryItemsDeleteAll = {},
onDismissed = {},
)
}
}
// @Preview(showBackground = true)
// @Composable
// fun TextInputHistorySheetContentPreview() {
// GalleryTheme {
// SheetContent(
// history =
// listOf(
// "Analyze the sentiment of the following Tweets and classify them as POSITIVE, NEGATIVE,
// or NEUTRAL. \"It's so beautiful today!\"",
// "I have the ingredients above. Not sure what to cook for lunch. Show me a list of foods
// with the recipes.",
// "You are Santa Claus, write a letter back for this kid.",
// "Generate a list of cookie recipes. Make the outputs in JSON format.",
// ),
// onHistoryItemClicked = {},
// onHistoryItemDeleted = {},
// onHistoryItemsDeleteAll = {},
// onDismissed = {},
// )
// }
// }

View file

@ -35,25 +35,28 @@ fun ZoomableBox(
modifier: Modifier = Modifier,
minScale: Float = 1f,
maxScale: Float = 5f,
content: @Composable ZoomableBoxScope.() -> Unit
content: @Composable ZoomableBoxScope.() -> Unit,
) {
var scale by remember { mutableFloatStateOf(1f) }
var offsetX by remember { mutableFloatStateOf(0f) }
var offsetY by remember { mutableFloatStateOf(0f) }
var size by remember { mutableStateOf(IntSize.Zero) }
Box(modifier = modifier
.onSizeChanged { size = it }
.pointerInput(Unit) {
detectTransformGestures { _, pan, zoom, _ ->
scale = maxOf(minScale, minOf(scale * zoom, maxScale))
val maxX = (size.width * (scale - 1)) / 2
val minX = -maxX
offsetX = maxOf(minX, minOf(maxX, offsetX + pan.x))
val maxY = (size.height * (scale - 1)) / 2
val minY = -maxY
offsetY = maxOf(minY, minOf(maxY, offsetY + pan.y))
}
}, contentAlignment = Alignment.TopEnd
Box(
modifier =
modifier
.onSizeChanged { size = it }
.pointerInput(Unit) {
detectTransformGestures { _, pan, zoom, _ ->
scale = maxOf(minScale, minOf(scale * zoom, maxScale))
val maxX = (size.width * (scale - 1)) / 2
val minX = -maxX
offsetX = maxOf(minX, minOf(maxX, offsetX + pan.x))
val maxY = (size.height * (scale - 1)) / 2
val minY = -maxY
offsetY = maxOf(minY, minOf(maxY, offsetY + pan.y))
}
},
contentAlignment = Alignment.TopEnd,
) {
val scope = ZoomableBoxScopeImpl(scale, offsetX, offsetY)
scope.content()
@ -67,5 +70,7 @@ interface ZoomableBoxScope {
}
private data class ZoomableBoxScopeImpl(
override val scale: Float, override val offsetX: Float, override val offsetY: Float
override val scale: Float,
override val offsetX: Float,
override val offsetY: Float,
) : ZoomableBoxScope

View file

@ -1,73 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.common.modelitem
import androidx.compose.animation.core.DeferredTargetAnimation
import androidx.compose.animation.core.ExperimentalAnimatableApi
import androidx.compose.animation.core.VectorConverter
import androidx.compose.animation.core.tween
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.ui.Modifier
import androidx.compose.ui.composed
import androidx.compose.ui.geometry.Offset
import androidx.compose.ui.layout.LookaheadScope
import androidx.compose.ui.layout.approachLayout
import androidx.compose.ui.unit.Constraints
import androidx.compose.ui.unit.IntOffset
import androidx.compose.ui.unit.IntSize
import androidx.compose.ui.unit.round
const val LAYOUT_ANIMATION_DURATION = 250
context(LookaheadScope)
@OptIn(ExperimentalAnimatableApi::class)
fun Modifier.animateLayout(): Modifier = composed {
val sizeAnim = remember { DeferredTargetAnimation(IntSize.VectorConverter) }
val offsetAnim = remember { DeferredTargetAnimation(IntOffset.VectorConverter) }
val scope = rememberCoroutineScope()
this.approachLayout(
isMeasurementApproachInProgress = { lookaheadSize ->
sizeAnim.updateTarget(lookaheadSize, scope, tween(LAYOUT_ANIMATION_DURATION))
!sizeAnim.isIdle
},
isPlacementApproachInProgress = { lookaheadCoordinates ->
val target = lookaheadScopeCoordinates.localLookaheadPositionOf(lookaheadCoordinates)
offsetAnim.updateTarget(target.round(), scope, tween(LAYOUT_ANIMATION_DURATION))
!offsetAnim.isIdle
}
) { measurable, _ ->
val (animWidth, animHeight) = sizeAnim.updateTarget(
lookaheadSize,
scope,
tween(LAYOUT_ANIMATION_DURATION)
)
measurable.measure(Constraints.fixed(animWidth, animHeight))
.run {
layout(width, height) {
coordinates?.let {
val target = lookaheadScopeCoordinates.localLookaheadPositionOf(it).round()
val animOffset = offsetAnim.updateTarget(target, scope, tween(LAYOUT_ANIMATION_DURATION))
val current = lookaheadScopeCoordinates.localPositionOf(it, Offset.Zero).round()
val (x, y) = animOffset - current
place(x, y)
} ?: place(0, 0)
}
}
}
}

View file

@ -25,28 +25,16 @@ import androidx.compose.ui.res.stringResource
import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.data.Model
/**
* Composable function to display a confirmation dialog for deleting a model.
*/
/** Composable function to display a confirmation dialog for deleting a model. */
@Composable
fun ConfirmDeleteModelDialog(model: Model, onConfirm: () -> Unit, onDismiss: () -> Unit) {
AlertDialog(onDismissRequest = onDismiss,
AlertDialog(
onDismissRequest = onDismiss,
title = { Text(stringResource(R.string.confirm_delete_model_dialog_title)) },
text = {
Text(
stringResource(R.string.confirm_delete_model_dialog_content).format(
model.name
)
)
Text(stringResource(R.string.confirm_delete_model_dialog_content).format(model.name))
},
confirmButton = {
Button(onClick = onConfirm) {
Text(stringResource(R.string.ok))
}
},
dismissButton = {
TextButton(onClick = onDismiss) {
Text(stringResource(R.string.cancel))
}
})
confirmButton = { Button(onClick = onConfirm) { Text(stringResource(R.string.ok)) } },
dismissButton = { TextButton(onClick = onDismiss) { Text(stringResource(R.string.cancel)) } },
)
}

View file

@ -16,8 +16,17 @@
package com.google.ai.edge.gallery.ui.common.modelitem
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.preview.MODEL_TEST1
// import com.google.ai.edge.gallery.ui.preview.MODEL_TEST2
// import com.google.ai.edge.gallery.ui.preview.MODEL_TEST3
// import com.google.ai.edge.gallery.ui.preview.MODEL_TEST4
// import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
// import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
// import com.google.ai.edge.gallery.ui.preview.TASK_TEST2
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import android.content.Intent
import android.net.Uri
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.animation.AnimatedContent
@ -52,37 +61,28 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.Dp
import androidx.compose.ui.unit.dp
import androidx.core.net.toUri
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.ModelDownloadStatusType
import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.common.DownloadAndTryButton
import com.google.ai.edge.gallery.ui.common.MarkdownText
import com.google.ai.edge.gallery.ui.common.TaskIcon
import com.google.ai.edge.gallery.ui.common.chat.MarkdownText
import com.google.ai.edge.gallery.ui.common.checkNotificationPermissionAndStartDownload
import com.google.ai.edge.gallery.ui.common.getTaskBgColor
import com.google.ai.edge.gallery.ui.common.getTaskIconColor
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.MODEL_TEST1
import com.google.ai.edge.gallery.ui.preview.MODEL_TEST2
import com.google.ai.edge.gallery.ui.preview.MODEL_TEST3
import com.google.ai.edge.gallery.ui.preview.MODEL_TEST4
import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.TASK_TEST1
import com.google.ai.edge.gallery.ui.preview.TASK_TEST2
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
private val DEFAULT_VERTICAL_PADDING = 16.dp
/**
* Composable function to display a model item in the model manager list.
*
* This function renders a card representing a model, displaying its task icon, name,
* download status, and providing action buttons. It supports expanding to show a
* model description and buttons for learning more (opening a URL) and downloading/trying
* the model.
* This function renders a card representing a model, displaying its task icon, name, download
* status, and providing action buttons. It supports expanding to show a model description and
* buttons for learning more (opening a URL) and downloading/trying the model.
*/
@OptIn(ExperimentalSharedTransitionApi::class)
@Composable
@ -103,170 +103,174 @@ fun ModelItem(
val downloadStatus by remember {
derivedStateOf { modelManagerUiState.modelDownloadStatus[model.name] }
}
val launcher = rememberLauncherForActivityResult(
ActivityResultContracts.RequestPermission()
) {
modelManagerViewModel.downloadModel(task = task, model = model)
}
val launcher =
rememberLauncherForActivityResult(ActivityResultContracts.RequestPermission()) {
modelManagerViewModel.downloadModel(task = task, model = model)
}
var isExpanded by remember { mutableStateOf(false) }
var boxModifier = modifier
.fillMaxWidth()
.clip(RoundedCornerShape(size = 42.dp))
.background(
getTaskBgColor(task)
)
boxModifier = if (canExpand) {
boxModifier.clickable(onClick = {
if (!model.imported) {
isExpanded = !isExpanded
} else {
onModelClicked(model)
}
}, interactionSource = remember { MutableInteractionSource() }, indication = ripple(
bounded = true,
radius = 1000.dp,
)
)
} else {
boxModifier
}
var boxModifier =
modifier.fillMaxWidth().clip(RoundedCornerShape(size = 42.dp)).background(getTaskBgColor(task))
boxModifier =
if (canExpand) {
boxModifier.clickable(
onClick = {
if (!model.imported) {
isExpanded = !isExpanded
} else {
onModelClicked(model)
}
},
interactionSource = remember { MutableInteractionSource() },
indication = ripple(bounded = true, radius = 1000.dp),
)
} else {
boxModifier
}
Box(
modifier = boxModifier,
contentAlignment = Alignment.Center,
) {
Box(modifier = boxModifier, contentAlignment = Alignment.Center) {
SharedTransitionLayout {
AnimatedContent(
isExpanded, label = "item_layout_transition",
) { targetState ->
val taskIcon = @Composable {
TaskIcon(
task = task, modifier = Modifier.sharedElement(
sharedContentState = rememberSharedContentState(key = "task_icon"),
animatedVisibilityScope = this@AnimatedContent,
)
)
}
val modelNameAndStatus = @Composable {
ModelNameAndStatus(
model = model,
task = task,
downloadStatus = downloadStatus,
isExpanded = isExpanded,
animatedVisibilityScope = this@AnimatedContent,
sharedTransitionScope = this@SharedTransitionLayout
)
}
val actionButton = @Composable {
ModelItemActionButton(
context = context,
model = model,
task = task,
modelManagerViewModel = modelManagerViewModel,
downloadStatus = downloadStatus,
onDownloadClicked = { model ->
checkNotificationPermissionAndStartDownload(
context = context,
launcher = launcher,
modelManagerViewModel = modelManagerViewModel,
task = task,
model = model
)
},
showDeleteButton = showDeleteButton,
showDownloadButton = false,
modifier = Modifier.sharedElement(
sharedContentState = rememberSharedContentState(key = "action_button"),
animatedVisibilityScope = this@AnimatedContent,
)
)
}
val expandButton = @Composable {
Icon(
// For imported model, show ">" directly indicating users can just tap the model item to
// go into it without needing to expand it first.
if (model.imported) Icons.Rounded.ChevronRight else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore,
contentDescription = "",
tint = getTaskIconColor(task),
modifier = Modifier.sharedElement(
sharedContentState = rememberSharedContentState(key = "expand_button"),
animatedVisibilityScope = this@AnimatedContent,
)
)
}
val description = @Composable {
if (model.info.isNotEmpty()) {
MarkdownText(
model.info, modifier = Modifier
.sharedElement(
sharedContentState = rememberSharedContentState(key = "description"),
AnimatedContent(isExpanded, label = "item_layout_transition") { targetState ->
val taskIcon =
@Composable {
TaskIcon(
task = task,
modifier =
Modifier.sharedElement(
sharedContentState = rememberSharedContentState(key = "task_icon"),
animatedVisibilityScope = this@AnimatedContent,
)
.skipToLookaheadSize()
),
)
}
}
val buttonsRow = @Composable {
Row(
horizontalArrangement = Arrangement.spacedBy(12.dp), modifier = Modifier
.sharedElement(
sharedContentState = rememberSharedContentState(key = "buttons_row"),
animatedVisibilityScope = this@AnimatedContent,
)
.skipToLookaheadSize()
) {
// The "learn more" button. Click to show related urls in a bottom sheet.
if (model.learnMoreUrl.isNotEmpty()) {
OutlinedButton(
onClick = {
if (isExpanded) {
val intent = Intent(Intent.ACTION_VIEW, Uri.parse(model.learnMoreUrl))
context.startActivity(intent)
}
},
) {
Text("Learn More", maxLines = 1)
}
}
// Button to start the download and start the chat session with the model.
val needToDownloadFirst =
downloadStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED || downloadStatus?.status == ModelDownloadStatusType.FAILED
DownloadAndTryButton(task = task,
val modelNameAndStatus =
@Composable {
ModelNameAndStatus(
model = model,
enabled = isExpanded,
needToDownloadFirst = needToDownloadFirst,
modelManagerViewModel = modelManagerViewModel,
onClicked = { onModelClicked(model) })
task = task,
downloadStatus = downloadStatus,
isExpanded = isExpanded,
animatedVisibilityScope = this@AnimatedContent,
sharedTransitionScope = this@SharedTransitionLayout,
)
}
val actionButton =
@Composable {
ModelItemActionButton(
context = context,
model = model,
task = task,
modelManagerViewModel = modelManagerViewModel,
downloadStatus = downloadStatus,
onDownloadClicked = { model ->
checkNotificationPermissionAndStartDownload(
context = context,
launcher = launcher,
modelManagerViewModel = modelManagerViewModel,
task = task,
model = model,
)
},
showDeleteButton = showDeleteButton,
showDownloadButton = false,
modifier =
Modifier.sharedElement(
sharedContentState = rememberSharedContentState(key = "action_button"),
animatedVisibilityScope = this@AnimatedContent,
),
)
}
val expandButton =
@Composable {
Icon(
// For imported model, show ">" directly indicating users can just tap the model item
// to
// go into it without needing to expand it first.
if (model.imported) Icons.Rounded.ChevronRight
else if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore,
contentDescription = "",
tint = getTaskIconColor(task),
modifier =
Modifier.sharedElement(
sharedContentState = rememberSharedContentState(key = "expand_button"),
animatedVisibilityScope = this@AnimatedContent,
),
)
}
val description =
@Composable {
if (model.info.isNotEmpty()) {
MarkdownText(
model.info,
modifier =
Modifier.sharedElement(
sharedContentState = rememberSharedContentState(key = "description"),
animatedVisibilityScope = this@AnimatedContent,
)
.skipToLookaheadSize(),
)
}
}
val buttonsRow =
@Composable {
Row(
horizontalArrangement = Arrangement.spacedBy(12.dp),
modifier =
Modifier.sharedElement(
sharedContentState = rememberSharedContentState(key = "buttons_row"),
animatedVisibilityScope = this@AnimatedContent,
)
.skipToLookaheadSize(),
) {
// The "learn more" button. Click to show related urls in a bottom sheet.
if (model.learnMoreUrl.isNotEmpty()) {
OutlinedButton(
onClick = {
if (isExpanded) {
val intent = Intent(Intent.ACTION_VIEW, model.learnMoreUrl.toUri())
context.startActivity(intent)
}
}
) {
Text("Learn More", maxLines = 1)
}
}
// Button to start the download and start the chat session with the model.
val needToDownloadFirst =
downloadStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED ||
downloadStatus?.status == ModelDownloadStatusType.FAILED
DownloadAndTryButton(
task = task,
model = model,
enabled = isExpanded,
needToDownloadFirst = needToDownloadFirst,
modelManagerViewModel = modelManagerViewModel,
onClicked = { onModelClicked(model) },
)
}
}
}
// Collapsed state.
if (!targetState) {
Column(
horizontalAlignment = Alignment.CenterHorizontally,
) {
Column(horizontalAlignment = Alignment.CenterHorizontally) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(12.dp),
modifier = Modifier
.fillMaxWidth()
.padding(start = 18.dp, end = 18.dp)
.padding(vertical = verticalSpacing)
modifier =
Modifier.fillMaxWidth()
.padding(start = 18.dp, end = 18.dp)
.padding(vertical = verticalSpacing),
) {
// Icon at the left.
taskIcon()
// Model name and status at the center.
Row(modifier = Modifier.weight(1f)) {
modelNameAndStatus()
}
Row(modifier = Modifier.weight(1f)) { modelNameAndStatus() }
// Action button and expand/collapse button at the right.
Row(verticalAlignment = Alignment.CenterVertically) {
actionButton()
@ -278,9 +282,8 @@ fun ModelItem(
Column(
verticalArrangement = Arrangement.spacedBy(14.dp),
horizontalAlignment = Alignment.CenterHorizontally,
modifier = Modifier
.fillMaxWidth()
.padding(vertical = verticalSpacing, horizontal = 18.dp)
modifier =
Modifier.fillMaxWidth().padding(vertical = verticalSpacing, horizontal = 18.dp),
) {
Box(contentAlignment = Alignment.Center) {
// Icon at the top-center.
@ -289,7 +292,7 @@ fun ModelItem(
Row(
modifier = Modifier.fillMaxWidth(),
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.End
horizontalArrangement = Arrangement.End,
) {
actionButton()
expandButton()
@ -308,37 +311,37 @@ fun ModelItem(
}
}
@Preview(showBackground = true)
@Composable
fun PreviewModelItem() {
GalleryTheme {
Column(
verticalArrangement = Arrangement.spacedBy(16.dp), modifier = Modifier.padding(16.dp)
) {
ModelItem(
model = MODEL_TEST1,
task = TASK_TEST1,
onModelClicked = { },
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
)
ModelItem(
model = MODEL_TEST2,
task = TASK_TEST1,
onModelClicked = { },
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
)
ModelItem(
model = MODEL_TEST3,
task = TASK_TEST2,
onModelClicked = { },
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
)
ModelItem(
model = MODEL_TEST4,
task = TASK_TEST2,
onModelClicked = { },
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
)
}
}
}
// @Preview(showBackground = true)
// @Composable
// fun PreviewModelItem() {
// GalleryTheme {
// Column(
// verticalArrangement = Arrangement.spacedBy(16.dp), modifier = Modifier.padding(16.dp)
// ) {
// ModelItem(
// model = MODEL_TEST1,
// task = TASK_TEST1,
// onModelClicked = { },
// modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
// )
// ModelItem(
// model = MODEL_TEST2,
// task = TASK_TEST1,
// onModelClicked = { },
// modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
// )
// ModelItem(
// model = MODEL_TEST3,
// task = TASK_TEST2,
// onModelClicked = { },
// modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
// )
// ModelItem(
// model = MODEL_TEST4,
// task = TASK_TEST2,
// onModelClicked = { },
// modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
// )
// }
// }
// }

View file

@ -66,69 +66,50 @@ fun ModelItemActionButton(
Row(verticalAlignment = Alignment.CenterVertically, modifier = modifier) {
when (downloadStatus?.status) {
// Button to start the download.
ModelDownloadStatusType.NOT_DOWNLOADED, ModelDownloadStatusType.FAILED ->
ModelDownloadStatusType.NOT_DOWNLOADED,
ModelDownloadStatusType.FAILED ->
if (showDownloadButton) {
IconButton(onClick = {
onDownloadClicked(model)
}) {
Icon(
Icons.Rounded.FileDownload,
contentDescription = "",
tint = getTaskIconColor(task),
)
IconButton(onClick = { onDownloadClicked(model) }) {
Icon(Icons.Rounded.FileDownload, contentDescription = "", tint = getTaskIconColor(task))
}
}
// Button to delete the download.
ModelDownloadStatusType.SUCCEEDED -> {
if (showDeleteButton) {
IconButton(onClick = {
showConfirmDeleteDialog = true
}) {
Icon(
Icons.Rounded.Delete,
contentDescription = "",
tint = getTaskIconColor(task),
)
IconButton(onClick = { showConfirmDeleteDialog = true }) {
Icon(Icons.Rounded.Delete, contentDescription = "", tint = getTaskIconColor(task))
}
}
}
// Show spinner when the model is partially downloaded because it might some time for
// background task to be started by Android.
ModelDownloadStatusType.PARTIALLY_DOWNLOADED -> {
CircularProgressIndicator(
modifier = Modifier
.padding(end = 12.dp)
.size(24.dp)
)
CircularProgressIndicator(modifier = Modifier.padding(end = 12.dp).size(24.dp))
}
// Button to cancel the download when it is in progress.
ModelDownloadStatusType.IN_PROGRESS, ModelDownloadStatusType.UNZIPPING -> IconButton(onClick = {
modelManagerViewModel.cancelDownloadModel(
task = task,
model = model
)
}) {
Icon(
Icons.Rounded.Cancel,
contentDescription = "",
tint = getTaskIconColor(task),
)
}
ModelDownloadStatusType.IN_PROGRESS,
ModelDownloadStatusType.UNZIPPING ->
IconButton(
onClick = { modelManagerViewModel.cancelDownloadModel(task = task, model = model) }
) {
Icon(Icons.Rounded.Cancel, contentDescription = "", tint = getTaskIconColor(task))
}
else -> {}
}
}
if (showConfirmDeleteDialog) {
ConfirmDeleteModelDialog(model = model, onConfirm = {
modelManagerViewModel.deleteModel(task = task, model = model)
showConfirmDeleteDialog = false
}, onDismiss = {
showConfirmDeleteDialog = false
})
ConfirmDeleteModelDialog(
model = model,
onConfirm = {
modelManagerViewModel.deleteModel(task = task, model = model)
showConfirmDeleteDialog = false
},
onDismiss = { showConfirmDeleteDialog = false },
)
}
}
}

View file

@ -52,8 +52,8 @@ import com.google.ai.edge.gallery.ui.theme.labelSmallNarrow
* This function renders the model's name and its current download status, including:
* - Model name.
* - Failure message (if download failed).
* - Download progress (received size, total size, download rate, remaining time) for
* in-progress downloads.
* - Download progress (received size, total size, download rate, remaining time) for in-progress
* downloads.
* - "Unzipping..." status for unzipping processes.
* - Model size for successful downloads.
*/
@ -66,7 +66,7 @@ fun ModelNameAndStatus(
isExpanded: Boolean,
sharedTransitionScope: SharedTransitionScope,
animatedVisibilityScope: AnimatedVisibilityScope,
modifier: Modifier = Modifier
modifier: Modifier = Modifier,
) {
val inProgress = downloadStatus?.status == ModelDownloadStatusType.IN_PROGRESS
val isPartiallyDownloaded = downloadStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED
@ -77,18 +77,17 @@ fun ModelNameAndStatus(
horizontalAlignment = if (isExpanded) Alignment.CenterHorizontally else Alignment.Start
) {
// Model name.
Row(
verticalAlignment = Alignment.CenterVertically,
) {
Row(verticalAlignment = Alignment.CenterVertically) {
Text(
model.name,
maxLines = 1,
overflow = TextOverflow.MiddleEllipsis,
style = MaterialTheme.typography.titleMedium,
modifier = Modifier.sharedElement(
rememberSharedContentState(key = "model_name"),
animatedVisibilityScope = animatedVisibilityScope
)
modifier =
Modifier.sharedElement(
rememberSharedContentState(key = "model_name"),
animatedVisibilityScope = animatedVisibilityScope,
),
)
}
@ -97,12 +96,13 @@ fun ModelNameAndStatus(
if (!inProgress && !isPartiallyDownloaded) {
StatusIcon(
downloadStatus = downloadStatus,
modifier = modifier
.padding(end = 4.dp)
.sharedElement(
rememberSharedContentState(key = "download_status_icon"),
animatedVisibilityScope = animatedVisibilityScope
)
modifier =
modifier
.padding(end = 4.dp)
.sharedElement(
rememberSharedContentState(key = "download_status_icon"),
animatedVisibilityScope = animatedVisibilityScope,
),
)
}
@ -114,10 +114,11 @@ fun ModelNameAndStatus(
color = MaterialTheme.colorScheme.error,
style = labelSmallNarrow,
overflow = TextOverflow.Ellipsis,
modifier = Modifier.sharedElement(
rememberSharedContentState(key = "failure_messsage"),
animatedVisibilityScope = animatedVisibilityScope
)
modifier =
Modifier.sharedElement(
rememberSharedContentState(key = "failure_messsage"),
animatedVisibilityScope = animatedVisibilityScope,
),
)
}
}
@ -138,8 +139,7 @@ fun ModelNameAndStatus(
sizeLabel =
"${downloadStatus.receivedBytes.humanReadableSize(extraDecimalForGbAndAbove = true)} of ${totalSize.humanReadableSize()}"
if (downloadStatus.bytesPerSecond > 0) {
sizeLabel =
"$sizeLabel · ${downloadStatus.bytesPerSecond.humanReadableSize()} / s"
sizeLabel = "$sizeLabel · ${downloadStatus.bytesPerSecond.humanReadableSize()} / s"
if (downloadStatus.remainingMs >= 0) {
sizeLabel =
"$sizeLabel\n${downloadStatus.remainingMs.formatToHourMinSecond()} left"
@ -162,7 +162,7 @@ fun ModelNameAndStatus(
}
Column(
horizontalAlignment = if (isExpanded) Alignment.CenterHorizontally else Alignment.Start,
horizontalAlignment = if (isExpanded) Alignment.CenterHorizontally else Alignment.Start
) {
for ((index, line) in sizeLabel.split("\n").withIndex()) {
Text(
@ -172,12 +172,12 @@ fun ModelNameAndStatus(
style = labelSmallNarrow.copy(fontSize = fontSize, lineHeight = 10.sp),
textAlign = if (isExpanded) TextAlign.Center else TextAlign.Start,
overflow = TextOverflow.Visible,
modifier = Modifier
.offset(y = if (index == 0) 0.dp else (-1).dp)
.sharedElement(
rememberSharedContentState(key = "status_label_${index}"),
animatedVisibilityScope = animatedVisibilityScope
)
modifier =
Modifier.offset(y = if (index == 0) 0.dp else (-1).dp)
.sharedElement(
rememberSharedContentState(key = "status_label_${index}"),
animatedVisibilityScope = animatedVisibilityScope,
),
)
}
}
@ -191,12 +191,12 @@ fun ModelNameAndStatus(
progress = { animatedProgress.value },
color = getTaskIconColor(task = task),
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
modifier = Modifier
.padding(top = 2.dp)
.sharedElement(
rememberSharedContentState(key = "download_progress_bar"),
animatedVisibilityScope = animatedVisibilityScope
)
modifier =
Modifier.padding(top = 2.dp)
.sharedElement(
rememberSharedContentState(key = "download_progress_bar"),
animatedVisibilityScope = animatedVisibilityScope,
),
)
LaunchedEffect(curDownloadProgress) {
animatedProgress.animateTo(curDownloadProgress, animationSpec = tween(150))
@ -207,12 +207,12 @@ fun ModelNameAndStatus(
LinearProgressIndicator(
color = getTaskIconColor(task = task),
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
modifier = Modifier
.padding(top = 2.dp)
.sharedElement(
rememberSharedContentState(key = "unzip_progress_bar"),
animatedVisibilityScope = animatedVisibilityScope
)
modifier =
Modifier.padding(top = 2.dp)
.sharedElement(
rememberSharedContentState(key = "unzip_progress_bar"),
animatedVisibilityScope = animatedVisibilityScope,
),
)
}
}

View file

@ -16,8 +16,9 @@
package com.google.ai.edge.gallery.ui.common.modelitem
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.size
import androidx.compose.material.icons.Icons
@ -31,75 +32,71 @@ import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.data.ModelDownloadStatus
import com.google.ai.edge.gallery.data.ModelDownloadStatusType
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.customColors
private val SIZE = 18.dp
/**
* Composable function to display an icon representing the download status of a model.
*/
/** Composable function to display an icon representing the download status of a model. */
@Composable
fun StatusIcon(downloadStatus: ModelDownloadStatus?, modifier: Modifier = Modifier) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.Center,
modifier = modifier
modifier = modifier,
) {
when (downloadStatus?.status) {
ModelDownloadStatusType.NOT_DOWNLOADED -> Icon(
Icons.AutoMirrored.Outlined.HelpOutline,
tint = Color(0xFFCCCCCC),
contentDescription = "",
modifier = Modifier.size(SIZE)
)
ModelDownloadStatusType.NOT_DOWNLOADED ->
Icon(
Icons.AutoMirrored.Outlined.HelpOutline,
tint = Color(0xFFCCCCCC),
contentDescription = "",
modifier = Modifier.size(SIZE),
)
ModelDownloadStatusType.SUCCEEDED -> {
Icon(
Icons.Filled.DownloadForOffline,
tint = MaterialTheme.customColors.successColor,
contentDescription = "",
modifier = Modifier.size(SIZE)
modifier = Modifier.size(SIZE),
)
}
ModelDownloadStatusType.FAILED -> Icon(
Icons.Rounded.Error,
tint = Color(0xFFAA0000),
contentDescription = "",
modifier = Modifier.size(SIZE)
)
ModelDownloadStatusType.FAILED ->
Icon(
Icons.Rounded.Error,
tint = Color(0xFFAA0000),
contentDescription = "",
modifier = Modifier.size(SIZE),
)
ModelDownloadStatusType.IN_PROGRESS -> Icon(
Icons.Rounded.Downloading,
contentDescription = "",
modifier = Modifier.size(SIZE)
)
ModelDownloadStatusType.IN_PROGRESS ->
Icon(Icons.Rounded.Downloading, contentDescription = "", modifier = Modifier.size(SIZE))
else -> {}
}
}
}
@Preview(showBackground = true)
@Composable
fun StatusIconPreview() {
GalleryTheme {
Column {
for (downloadStatus in listOf(
ModelDownloadStatus(status = ModelDownloadStatusType.NOT_DOWNLOADED),
ModelDownloadStatus(status = ModelDownloadStatusType.IN_PROGRESS),
ModelDownloadStatus(status = ModelDownloadStatusType.SUCCEEDED),
ModelDownloadStatus(status = ModelDownloadStatusType.FAILED),
ModelDownloadStatus(status = ModelDownloadStatusType.UNZIPPING),
ModelDownloadStatus(status = ModelDownloadStatusType.PARTIALLY_DOWNLOADED),
)) {
StatusIcon(downloadStatus = downloadStatus)
}
}
}
}
// @Preview(showBackground = true)
// @Composable
// fun StatusIconPreview() {
// GalleryTheme {
// Column {
// for (downloadStatus in
// listOf(
// ModelDownloadStatus(status = ModelDownloadStatusType.NOT_DOWNLOADED),
// ModelDownloadStatus(status = ModelDownloadStatusType.IN_PROGRESS),
// ModelDownloadStatus(status = ModelDownloadStatusType.SUCCEEDED),
// ModelDownloadStatus(status = ModelDownloadStatusType.FAILED),
// ModelDownloadStatus(status = ModelDownloadStatusType.UNZIPPING),
// ModelDownloadStatus(status = ModelDownloadStatusType.PARTIALLY_DOWNLOADED),
// )) {
// StatusIcon(downloadStatus = downloadStatus)
// }
// }
// }
// }

View file

@ -16,6 +16,9 @@
package com.google.ai.edge.gallery.ui.home
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
// import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
import android.content.Context
import android.content.Intent
import android.net.Uri
@ -95,20 +98,17 @@ import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.text.style.TextDecoration
import androidx.compose.ui.text.withLink
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import com.google.ai.edge.gallery.GalleryTopAppBar
import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.data.AppBarAction
import com.google.ai.edge.gallery.data.AppBarActionType
import com.google.ai.edge.gallery.data.ImportedModelInfo
import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.proto.ImportedModel
import com.google.ai.edge.gallery.ui.common.TaskIcon
import com.google.ai.edge.gallery.ui.common.getTaskBgColor
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.customColors
import com.google.ai.edge.gallery.ui.theme.titleMediumNarrow
import kotlinx.coroutines.delay
@ -125,8 +125,7 @@ private const val MIN_TASK_CARD_ICON_SIZE = 50
/** Navigation destination data */
object HomeScreenDestination {
@StringRes
val titleRes = R.string.app_name
@StringRes val titleRes = R.string.app_name
}
@OptIn(ExperimentalMaterial3Api::class)
@ -134,7 +133,7 @@ object HomeScreenDestination {
fun HomeScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateToTaskScreen: (Task) -> Unit,
modifier: Modifier = Modifier
modifier: Modifier = Modifier,
) {
val scrollBehavior = TopAppBarDefaults.pinnedScrollBehavior()
val uiState by modelManagerViewModel.uiState.collectAsState()
@ -145,53 +144,56 @@ fun HomeScreen(
var showImportDialog by remember { mutableStateOf(false) }
var showImportingDialog by remember { mutableStateOf(false) }
val selectedLocalModelFileUri = remember { mutableStateOf<Uri?>(null) }
val selectedImportedModelInfo = remember { mutableStateOf<ImportedModelInfo?>(null) }
val selectedImportedModelInfo = remember { mutableStateOf<ImportedModel?>(null) }
val coroutineScope = rememberCoroutineScope()
val snackbarHostState = remember { SnackbarHostState() }
val scope = rememberCoroutineScope()
val context = LocalContext.current
val filePickerLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult(
contract = ActivityResultContracts.StartActivityForResult()
) { result ->
if (result.resultCode == android.app.Activity.RESULT_OK) {
result.data?.data?.let { uri ->
val fileName = getFileName(context = context, uri = uri)
Log.d(TAG, "Selected file: $fileName")
if (fileName != null && !fileName.endsWith(".task")) {
showUnsupportedFileTypeDialog = true
} else {
selectedLocalModelFileUri.value = uri
showImportDialog = true
}
} ?: run {
Log.d(TAG, "No file selected or URI is null.")
val filePickerLauncher: ActivityResultLauncher<Intent> =
rememberLauncherForActivityResult(
contract = ActivityResultContracts.StartActivityForResult()
) { result ->
if (result.resultCode == android.app.Activity.RESULT_OK) {
result.data?.data?.let { uri ->
val fileName = getFileName(context = context, uri = uri)
Log.d(TAG, "Selected file: $fileName")
if (fileName != null && !fileName.endsWith(".task")) {
showUnsupportedFileTypeDialog = true
} else {
selectedLocalModelFileUri.value = uri
showImportDialog = true
}
} ?: run { Log.d(TAG, "No file selected or URI is null.") }
} else {
Log.d(TAG, "File picking cancelled.")
}
} else {
Log.d(TAG, "File picking cancelled.")
}
}
Scaffold(modifier = modifier.nestedScroll(scrollBehavior.nestedScrollConnection), topBar = {
GalleryTopAppBar(
title = stringResource(HomeScreenDestination.titleRes),
rightAction = AppBarAction(actionType = AppBarActionType.APP_SETTING, actionFn = {
showSettingsDialog = true
}),
scrollBehavior = scrollBehavior,
)
}, floatingActionButton = {
// A floating action button to show "import model" bottom sheet.
SmallFloatingActionButton(
onClick = {
showImportModelSheet = true
},
containerColor = MaterialTheme.colorScheme.secondaryContainer,
contentColor = MaterialTheme.colorScheme.secondary,
) {
Icon(Icons.Filled.Add, "")
}
}) { innerPadding ->
Scaffold(
modifier = modifier.nestedScroll(scrollBehavior.nestedScrollConnection),
topBar = {
GalleryTopAppBar(
title = stringResource(HomeScreenDestination.titleRes),
rightAction =
AppBarAction(
actionType = AppBarActionType.APP_SETTING,
actionFn = { showSettingsDialog = true },
),
scrollBehavior = scrollBehavior,
)
},
floatingActionButton = {
// A floating action button to show "import model" bottom sheet.
SmallFloatingActionButton(
onClick = { showImportModelSheet = true },
containerColor = MaterialTheme.colorScheme.secondaryContainer,
contentColor = MaterialTheme.colorScheme.secondary,
) {
Icon(Icons.Filled.Add, "")
}
},
) { innerPadding ->
Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.fillMaxSize()) {
TaskList(
tasks = uiState.tasks,
@ -216,37 +218,36 @@ fun HomeScreen(
// Import model bottom sheet.
if (showImportModelSheet) {
ModalBottomSheet(
onDismissRequest = { showImportModelSheet = false },
sheetState = sheetState,
) {
ModalBottomSheet(onDismissRequest = { showImportModelSheet = false }, sheetState = sheetState) {
Text(
"Import model",
style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(vertical = 4.dp, horizontal = 16.dp)
modifier = Modifier.padding(vertical = 4.dp, horizontal = 16.dp),
)
Box(modifier = Modifier.clickable {
coroutineScope.launch {
// Give it sometime to show the click effect.
delay(200)
showImportModelSheet = false
Box(
modifier =
Modifier.clickable {
coroutineScope.launch {
// Give it sometime to show the click effect.
delay(200)
showImportModelSheet = false
// Show file picker.
val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
addCategory(Intent.CATEGORY_OPENABLE)
type = "*/*"
// Single select.
putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
// Show file picker.
val intent =
Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
addCategory(Intent.CATEGORY_OPENABLE)
type = "*/*"
// Single select.
putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
}
filePickerLauncher.launch(intent)
}
}
filePickerLauncher.launch(intent)
}
}) {
) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp),
modifier = Modifier
.fillMaxWidth()
.padding(16.dp)
modifier = Modifier.fillMaxWidth().padding(16.dp),
) {
Icon(Icons.AutoMirrored.Outlined.NoteAdd, contentDescription = "")
Text("From local model file")
@ -258,11 +259,15 @@ fun HomeScreen(
// Import dialog
if (showImportDialog) {
selectedLocalModelFileUri.value?.let { uri ->
ModelImportDialog(uri = uri, onDismiss = { showImportDialog = false }, onDone = { info ->
selectedImportedModelInfo.value = info
showImportDialog = false
showImportingDialog = true
})
ModelImportDialog(
uri = uri,
onDismiss = { showImportDialog = false },
onDone = { info ->
selectedImportedModelInfo.value = info
showImportDialog = false
showImportingDialog = true
},
)
}
}
@ -270,20 +275,18 @@ fun HomeScreen(
if (showImportingDialog) {
selectedLocalModelFileUri.value?.let { uri ->
selectedImportedModelInfo.value?.let { info ->
ModelImportingDialog(uri = uri,
ModelImportingDialog(
uri = uri,
info = info,
onDismiss = { showImportingDialog = false },
onDone = {
modelManagerViewModel.addImportedLlmModel(
info = it,
)
modelManagerViewModel.addImportedLlmModel(info = it)
showImportingDialog = false
// Show a snack bar for successful import.
scope.launch {
snackbarHostState.showSnackbar("Model imported successfully")
}
})
scope.launch { snackbarHostState.showSnackbar("Model imported successfully") }
},
)
}
}
}
@ -293,9 +296,7 @@ fun HomeScreen(
AlertDialog(
onDismissRequest = { showUnsupportedFileTypeDialog = false },
title = { Text("Unsupported file type") },
text = {
Text("Only \".task\" file type is supported.")
},
text = { Text("Only \".task\" file type is supported.") },
confirmButton = {
Button(onClick = { showUnsupportedFileTypeDialog = false }) {
Text(stringResource(R.string.ok))
@ -309,21 +310,11 @@ fun HomeScreen(
icon = {
Icon(Icons.Rounded.Error, contentDescription = "", tint = MaterialTheme.colorScheme.error)
},
title = {
Text(uiState.loadingModelAllowlistError)
},
text = {
Text("Please check your internet connection and try again later.")
},
onDismissRequest = {
modelManagerViewModel.loadModelAllowlist()
},
title = { Text(uiState.loadingModelAllowlistError) },
text = { Text("Please check your internet connection and try again later.") },
onDismissRequest = { modelManagerViewModel.loadModelAllowlist() },
confirmButton = {
TextButton(onClick = {
modelManagerViewModel.loadModelAllowlist()
}) {
Text("Retry")
}
TextButton(onClick = { modelManagerViewModel.loadModelAllowlist() }) { Text("Retry") }
},
)
}
@ -339,31 +330,22 @@ private fun TaskList(
) {
val density = LocalDensity.current
val windowInfo = LocalWindowInfo.current
val screenWidthDp = remember {
with(density) {
windowInfo.containerSize.width.toDp()
}
}
val screenHeightDp = remember {
with(density) {
windowInfo.containerSize.height.toDp()
}
}
val screenWidthDp = remember { with(density) { windowInfo.containerSize.width.toDp() } }
val screenHeightDp = remember { with(density) { windowInfo.containerSize.height.toDp() } }
val sizeFraction = remember { ((screenWidthDp - 360.dp) / (410.dp - 360.dp)).coerceIn(0f, 1f) }
val linkColor = MaterialTheme.customColors.linkColor
val introText = buildAnnotatedString {
append("Welcome to Google AI Edge Gallery! Explore a world of amazing on-device models from ")
withLink(
link = LinkAnnotation.Url(
url = "https://huggingface.co/litert-community", // Replace with the actual URL
styles = TextLinkStyles(
style = SpanStyle(
color = linkColor,
textDecoration = TextDecoration.Underline,
)
link =
LinkAnnotation.Url(
url = "https://huggingface.co/litert-community", // Replace with the actual URL
styles =
TextLinkStyles(
style = SpanStyle(color = linkColor, textDecoration = TextDecoration.Underline)
),
)
)
) {
append("LiteRT community")
}
@ -378,9 +360,7 @@ private fun TaskList(
verticalArrangement = Arrangement.spacedBy(8.dp),
) {
// New rel
item(key = "newReleaseNotification", span = { GridItemSpan(2) }) {
NewReleaseNotification()
}
item(key = "newReleaseNotification", span = { GridItemSpan(2) }) { NewReleaseNotification() }
// Headline.
item(key = "headline", span = { GridItemSpan(2) }) {
@ -388,7 +368,7 @@ private fun TaskList(
introText,
textAlign = TextAlign.Center,
style = MaterialTheme.typography.bodyLarge.copy(fontWeight = FontWeight.SemiBold),
modifier = Modifier.padding(bottom = 20.dp).padding(horizontal = 16.dp)
modifier = Modifier.padding(bottom = 20.dp).padding(horizontal = 16.dp),
)
}
@ -396,16 +376,12 @@ private fun TaskList(
item(key = "loading", span = { GridItemSpan(2) }) {
Row(
horizontalArrangement = Arrangement.Center,
modifier = Modifier
.fillMaxWidth()
.padding(top = 32.dp)
modifier = Modifier.fillMaxWidth().padding(top = 32.dp),
) {
CircularProgressIndicator(
trackColor = MaterialTheme.colorScheme.surfaceVariant,
strokeWidth = 3.dp,
modifier = Modifier
.padding(end = 8.dp)
.size(20.dp)
modifier = Modifier.padding(end = 8.dp).size(20.dp),
)
Text("Loading model list...", style = MaterialTheme.typography.bodyMedium)
}
@ -417,17 +393,16 @@ private fun TaskList(
"Example LLM Use Cases",
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.SemiBold),
color = MaterialTheme.colorScheme.onSurfaceVariant,
modifier = Modifier.padding(bottom = 4.dp)
modifier = Modifier.padding(bottom = 4.dp),
)
}
items(tasks) { task ->
TaskCard(
sizeFraction = sizeFraction, task = task, onClick = {
navigateToTaskScreen(task)
}, modifier = Modifier
.fillMaxWidth()
.aspectRatio(1f)
sizeFraction = sizeFraction,
task = task,
onClick = { navigateToTaskScreen(task) },
modifier = Modifier.fillMaxWidth().aspectRatio(1f),
)
}
}
@ -440,22 +415,23 @@ private fun TaskList(
// Gradient overlay at the bottom.
Box(
modifier = Modifier
.fillMaxWidth()
.height(screenHeightDp * 0.25f)
.background(
Brush.verticalGradient(
colors = MaterialTheme.customColors.homeBottomGradient,
modifier =
Modifier.fillMaxWidth()
.height(screenHeightDp * 0.25f)
.background(
Brush.verticalGradient(colors = MaterialTheme.customColors.homeBottomGradient)
)
)
.align(Alignment.BottomCenter)
.align(Alignment.BottomCenter)
)
}
}
@Composable
private fun TaskCard(
task: Task, onClick: () -> Unit, sizeFraction: Float, modifier: Modifier = Modifier
task: Task,
onClick: () -> Unit,
sizeFraction: Float,
modifier: Modifier = Modifier,
) {
val padding =
(MAX_TASK_CARD_PADDING - MIN_TASK_CARD_PADDING) * sizeFraction + MIN_TASK_CARD_PADDING
@ -485,14 +461,16 @@ private fun TaskCard(
}
var curModelCountLabel by remember { mutableStateOf("") }
var modelCountLabelVisible by remember { mutableStateOf(true) }
val modelCountAlpha: Float by animateFloatAsState(
targetValue = if (modelCountLabelVisible) 1f else 0f,
animationSpec = tween(durationMillis = TASK_COUNT_ANIMATION_DURATION)
)
val modelCountScale: Float by animateFloatAsState(
targetValue = if (modelCountLabelVisible) 1f else 0.7f,
animationSpec = tween(durationMillis = TASK_COUNT_ANIMATION_DURATION)
)
val modelCountAlpha: Float by
animateFloatAsState(
targetValue = if (modelCountLabelVisible) 1f else 0f,
animationSpec = tween(durationMillis = TASK_COUNT_ANIMATION_DURATION),
)
val modelCountScale: Float by
animateFloatAsState(
targetValue = if (modelCountLabelVisible) 1f else 0.7f,
animationSpec = tween(durationMillis = TASK_COUNT_ANIMATION_DURATION),
)
LaunchedEffect(modelCountLabel) {
if (curModelCountLabel.isEmpty()) {
@ -506,20 +484,10 @@ private fun TaskCard(
}
Card(
modifier = modifier
.clip(RoundedCornerShape(radius.dp))
.clickable(
onClick = onClick,
),
colors = CardDefaults.cardColors(
containerColor = getTaskBgColor(task = task)
),
modifier = modifier.clip(RoundedCornerShape(radius.dp)).clickable(onClick = onClick),
colors = CardDefaults.cardColors(containerColor = getTaskBgColor(task = task)),
) {
Column(
modifier = Modifier
.fillMaxSize()
.padding(padding.dp),
) {
Column(modifier = Modifier.fillMaxSize().padding(padding.dp)) {
// Icon.
TaskIcon(task = task, width = iconSize.dp)
@ -529,10 +497,7 @@ private fun TaskCard(
Text(
task.type.label,
color = MaterialTheme.colorScheme.primary,
style = titleMediumNarrow.copy(
fontSize = 20.sp,
fontWeight = FontWeight.Bold,
),
style = titleMediumNarrow.copy(fontSize = 20.sp, fontWeight = FontWeight.Bold),
)
Spacer(modifier = Modifier.weight(1f))
@ -542,9 +507,7 @@ private fun TaskCard(
curModelCountLabel,
color = MaterialTheme.colorScheme.secondary,
style = MaterialTheme.typography.bodyMedium,
modifier = Modifier
.alpha(modelCountAlpha)
.scale(modelCountScale),
modifier = Modifier.alpha(modelCountAlpha).scale(modelCountScale),
)
}
}
@ -567,15 +530,13 @@ fun getFileName(context: Context, uri: Uri): String? {
return null
}
@Preview
@Composable
fun HomeScreenPreview(
) {
GalleryTheme {
HomeScreen(
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
navigateToTaskScreen = {},
)
}
}
// @Preview
// @Composable
// fun HomeScreenPreview() {
// GalleryTheme {
// HomeScreen(
// modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
// navigateToTaskScreen = {},
// )
// }
// }

View file

@ -63,69 +63,74 @@ import com.google.ai.edge.gallery.data.Accelerator
import com.google.ai.edge.gallery.data.BooleanSwitchConfig
import com.google.ai.edge.gallery.data.Config
import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.DEFAULT_MAX_TOKEN
import com.google.ai.edge.gallery.data.DEFAULT_TEMPERATURE
import com.google.ai.edge.gallery.data.DEFAULT_TOPK
import com.google.ai.edge.gallery.data.DEFAULT_TOPP
import com.google.ai.edge.gallery.data.IMPORTS_DIR
import com.google.ai.edge.gallery.data.LabelConfig
import com.google.ai.edge.gallery.data.ImportedModelInfo
import com.google.ai.edge.gallery.data.NumberSliderConfig
import com.google.ai.edge.gallery.data.SegmentedButtonConfig
import com.google.ai.edge.gallery.data.ValueType
import com.google.ai.edge.gallery.ui.common.chat.ConfigEditorsPanel
import com.google.ai.edge.gallery.data.convertValueToTargetType
import com.google.ai.edge.gallery.proto.ImportedModel
import com.google.ai.edge.gallery.proto.LlmConfig
import com.google.ai.edge.gallery.ui.common.ConfigEditorsPanel
import com.google.ai.edge.gallery.ui.common.ensureValidFileName
import com.google.ai.edge.gallery.ui.common.humanReadableSize
import com.google.ai.edge.gallery.ui.llmchat.DEFAULT_MAX_TOKEN
import com.google.ai.edge.gallery.ui.llmchat.DEFAULT_TEMPERATURE
import com.google.ai.edge.gallery.ui.llmchat.DEFAULT_TOPK
import com.google.ai.edge.gallery.ui.llmchat.DEFAULT_TOPP
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import java.io.File
import java.io.FileOutputStream
import java.net.URLDecoder
import java.nio.charset.StandardCharsets
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
private const val TAG = "AGModelImportDialog"
private val IMPORT_CONFIGS_LLM: List<Config> = listOf(
LabelConfig(key = ConfigKey.NAME), LabelConfig(key = ConfigKey.MODEL_TYPE), NumberSliderConfig(
key = ConfigKey.DEFAULT_MAX_TOKENS,
sliderMin = 100f,
sliderMax = 1024f,
defaultValue = DEFAULT_MAX_TOKEN.toFloat(),
valueType = ValueType.INT
), NumberSliderConfig(
key = ConfigKey.DEFAULT_TOPK,
sliderMin = 5f,
sliderMax = 40f,
defaultValue = DEFAULT_TOPK.toFloat(),
valueType = ValueType.INT
), NumberSliderConfig(
key = ConfigKey.DEFAULT_TOPP,
sliderMin = 0.0f,
sliderMax = 1.0f,
defaultValue = DEFAULT_TOPP,
valueType = ValueType.FLOAT
), NumberSliderConfig(
key = ConfigKey.DEFAULT_TEMPERATURE,
sliderMin = 0.0f,
sliderMax = 2.0f,
defaultValue = DEFAULT_TEMPERATURE,
valueType = ValueType.FLOAT
), BooleanSwitchConfig(
key = ConfigKey.SUPPORT_IMAGE,
defaultValue = false,
), SegmentedButtonConfig(
key = ConfigKey.COMPATIBLE_ACCELERATORS,
defaultValue = Accelerator.CPU.label,
options = listOf(Accelerator.CPU.label, Accelerator.GPU.label),
allowMultiple = true,
private val IMPORT_CONFIGS_LLM: List<Config> =
listOf(
LabelConfig(key = ConfigKey.NAME),
LabelConfig(key = ConfigKey.MODEL_TYPE),
NumberSliderConfig(
key = ConfigKey.DEFAULT_MAX_TOKENS,
sliderMin = 100f,
sliderMax = 1024f,
defaultValue = DEFAULT_MAX_TOKEN.toFloat(),
valueType = ValueType.INT,
),
NumberSliderConfig(
key = ConfigKey.DEFAULT_TOPK,
sliderMin = 5f,
sliderMax = 40f,
defaultValue = DEFAULT_TOPK.toFloat(),
valueType = ValueType.INT,
),
NumberSliderConfig(
key = ConfigKey.DEFAULT_TOPP,
sliderMin = 0.0f,
sliderMax = 1.0f,
defaultValue = DEFAULT_TOPP,
valueType = ValueType.FLOAT,
),
NumberSliderConfig(
key = ConfigKey.DEFAULT_TEMPERATURE,
sliderMin = 0.0f,
sliderMax = 2.0f,
defaultValue = DEFAULT_TEMPERATURE,
valueType = ValueType.FLOAT,
),
BooleanSwitchConfig(key = ConfigKey.SUPPORT_IMAGE, defaultValue = false),
SegmentedButtonConfig(
key = ConfigKey.COMPATIBLE_ACCELERATORS,
defaultValue = Accelerator.CPU.label,
options = listOf(Accelerator.CPU.label, Accelerator.GPU.label),
allowMultiple = true,
),
)
)
@Composable
fun ModelImportDialog(
uri: Uri, onDismiss: () -> Unit, onDone: (ImportedModelInfo) -> Unit
) {
fun ModelImportDialog(uri: Uri, onDismiss: () -> Unit, onDone: (ImportedModel) -> Unit) {
val context = LocalContext.current
val info = remember { getFileSizeAndDisplayNameFromUri(context = context, uri = uri) }
val fileSize by remember { mutableLongStateOf(info.first) }
@ -142,78 +147,110 @@ fun ModelImportDialog(
}
}
val values: SnapshotStateMap<String, Any> = remember {
mutableStateMapOf<String, Any>().apply {
putAll(initialValues)
}
mutableStateMapOf<String, Any>().apply { putAll(initialValues) }
}
val interactionSource = remember { MutableInteractionSource() }
Dialog(
onDismissRequest = onDismiss,
) {
Dialog(onDismissRequest = onDismiss) {
val focusManager = LocalFocusManager.current
Card(
modifier = Modifier
.fillMaxWidth()
.clickable(
interactionSource = interactionSource, indication = null // Disable the ripple effect
modifier =
Modifier.fillMaxWidth().clickable(
interactionSource = interactionSource,
indication = null, // Disable the ripple effect
) {
focusManager.clearFocus()
}, shape = RoundedCornerShape(16.dp)
},
shape = RoundedCornerShape(16.dp),
) {
Column(
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
modifier = Modifier.padding(20.dp),
verticalArrangement = Arrangement.spacedBy(16.dp),
) {
// Title.
Text(
"Import Model",
style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(bottom = 8.dp)
modifier = Modifier.padding(bottom = 8.dp),
)
Column(
modifier = Modifier
.verticalScroll(rememberScrollState())
.weight(1f, fill = false),
verticalArrangement = Arrangement.spacedBy(16.dp)
modifier = Modifier.verticalScroll(rememberScrollState()).weight(1f, fill = false),
verticalArrangement = Arrangement.spacedBy(16.dp),
) {
// Default configs for users to set.
ConfigEditorsPanel(
configs = IMPORT_CONFIGS_LLM,
values = values,
)
ConfigEditorsPanel(configs = IMPORT_CONFIGS_LLM, values = values)
}
// Button row.
Row(
modifier = Modifier
.fillMaxWidth()
.padding(top = 8.dp),
modifier = Modifier.fillMaxWidth().padding(top = 8.dp),
horizontalArrangement = Arrangement.End,
) {
// Cancel button.
TextButton(
onClick = { onDismiss() },
) {
Text("Cancel")
}
TextButton(onClick = { onDismiss() }) { Text("Cancel") }
// Import button
Button(
onClick = {
onDone(
ImportedModelInfo(
fileName = fileName,
fileSize = fileSize,
defaultValues = values,
val supportedAccelerators =
(convertValueToTargetType(
value = values.get(ConfigKey.COMPATIBLE_ACCELERATORS.label)!!,
valueType = ValueType.STRING,
)
as String)
.split(",")
val defaultMaxTokens =
convertValueToTargetType(
value = values.get(ConfigKey.DEFAULT_MAX_TOKENS.label)!!,
valueType = ValueType.INT,
)
)
},
as Int
val defaultTopk =
convertValueToTargetType(
value = values.get(ConfigKey.DEFAULT_TOPK.label)!!,
valueType = ValueType.INT,
)
as Int
val defaultTopp =
convertValueToTargetType(
value = values.get(ConfigKey.DEFAULT_TOPP.label)!!,
valueType = ValueType.FLOAT,
)
as Float
val defaultTemperature =
convertValueToTargetType(
value = values.get(ConfigKey.DEFAULT_TEMPERATURE.label)!!,
valueType = ValueType.FLOAT,
)
as Float
val supportImage =
convertValueToTargetType(
value = values.get(ConfigKey.SUPPORT_IMAGE.label)!!,
valueType = ValueType.BOOLEAN,
)
as Boolean
val importedModel: ImportedModel =
ImportedModel.newBuilder()
.setFileName(fileName)
.setFileSize(fileSize)
.setLlmConfig(
LlmConfig.newBuilder()
.addAllCompatibleAccelerators(supportedAccelerators)
.setDefaultMaxTokens(defaultMaxTokens)
.setDefaultTopk(defaultTopk)
.setDefaultTopp(defaultTopp)
.setDefaultTemperature(defaultTemperature)
.setSupportImage(supportImage)
.build()
)
.build()
onDone(importedModel)
}
) {
Text("Import")
}
}
}
}
}
@ -221,7 +258,10 @@ fun ModelImportDialog(
@Composable
fun ModelImportingDialog(
uri: Uri, info: ImportedModelInfo, onDismiss: () -> Unit, onDone: (ImportedModelInfo) -> Unit
uri: Uri,
info: ImportedModel,
onDismiss: () -> Unit,
onDone: (ImportedModel) -> Unit,
) {
var error by remember { mutableStateOf("") }
val context = LocalContext.current
@ -230,20 +270,16 @@ fun ModelImportingDialog(
LaunchedEffect(Unit) {
// Import.
importModel(context = context,
importModel(
context = context,
coroutineScope = coroutineScope,
fileName = info.fileName,
fileSize = info.fileSize,
uri = uri,
onDone = {
onDone(info)
},
onProgress = {
progress = it
},
onError = {
error = it
})
onDone = { onDone(info) },
onProgress = { progress = it },
onError = { error = it },
)
}
Dialog(
@ -252,13 +288,14 @@ fun ModelImportingDialog(
) {
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
Column(
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
modifier = Modifier.padding(20.dp),
verticalArrangement = Arrangement.spacedBy(16.dp),
) {
// Title.
Text(
"Import Model",
style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(bottom = 8.dp)
modifier = Modifier.padding(bottom = 8.dp),
)
// No error.
@ -272,9 +309,7 @@ fun ModelImportingDialog(
val animatedProgress = remember { Animatable(0f) }
LinearProgressIndicator(
progress = { animatedProgress.value },
modifier = Modifier
.fillMaxWidth()
.padding(bottom = 8.dp),
modifier = Modifier.fillMaxWidth().padding(bottom = 8.dp),
)
LaunchedEffect(progress) {
animatedProgress.animateTo(progress, animationSpec = tween(150))
@ -284,24 +319,23 @@ fun ModelImportingDialog(
// Has error.
else {
Row(
verticalAlignment = Alignment.Top, horizontalArrangement = Arrangement.spacedBy(6.dp)
verticalAlignment = Alignment.Top,
horizontalArrangement = Arrangement.spacedBy(6.dp),
) {
Icon(
Icons.Rounded.Error, contentDescription = "", tint = MaterialTheme.colorScheme.error
Icons.Rounded.Error,
contentDescription = "",
tint = MaterialTheme.colorScheme.error,
)
Text(
error,
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.error,
modifier = Modifier.padding(top = 4.dp)
modifier = Modifier.padding(top = 4.dp),
)
}
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.End) {
Button(onClick = {
onDismiss()
}) {
Text("Close")
}
Button(onClick = { onDismiss() }) { Text("Close") }
}
}
}
@ -376,17 +410,17 @@ private fun getFileSizeAndDisplayNameFromUri(context: Context, uri: Uri): Pair<L
var displayName = ""
try {
contentResolver.query(
uri, arrayOf(OpenableColumns.SIZE, OpenableColumns.DISPLAY_NAME), null, null, null
)?.use { cursor ->
if (cursor.moveToFirst()) {
val sizeIndex = cursor.getColumnIndexOrThrow(OpenableColumns.SIZE)
fileSize = cursor.getLong(sizeIndex)
contentResolver
.query(uri, arrayOf(OpenableColumns.SIZE, OpenableColumns.DISPLAY_NAME), null, null, null)
?.use { cursor ->
if (cursor.moveToFirst()) {
val sizeIndex = cursor.getColumnIndexOrThrow(OpenableColumns.SIZE)
fileSize = cursor.getLong(sizeIndex)
val nameIndex = cursor.getColumnIndexOrThrow(OpenableColumns.DISPLAY_NAME)
displayName = cursor.getString(nameIndex)
val nameIndex = cursor.getColumnIndexOrThrow(OpenableColumns.DISPLAY_NAME)
displayName = cursor.getString(nameIndex)
}
}
}
} catch (e: Exception) {
e.printStackTrace()
return Pair(0L, "")

View file

@ -1,3 +1,19 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.home
import android.util.Log
@ -29,22 +45,17 @@ import androidx.lifecycle.LifecycleEventObserver
import androidx.lifecycle.LifecycleOwner
import androidx.lifecycle.compose.LocalLifecycleOwner
import com.google.ai.edge.gallery.BuildConfig
import com.google.ai.edge.gallery.ui.common.getJsonResponse
import com.google.ai.edge.gallery.ui.modelmanager.ClickableLink
import com.google.ai.edge.gallery.common.getJsonResponse
import com.google.ai.edge.gallery.ui.common.ClickableLink
import kotlin.math.max
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import kotlinx.serialization.Serializable
import kotlin.math.max
private const val TAG = "AGNewReleaseNotification"
private const val TAG = "AGNewReleaseNotifi"
private const val REPO = "google-ai-edge/gallery"
@Serializable
data class ReleaseInfo(
val html_url: String,
val tag_name: String,
)
data class ReleaseInfo(val html_url: String, val tag_name: String)
@Composable
fun NewReleaseNotification() {
@ -84,35 +95,31 @@ fun NewReleaseNotification() {
lifecycleOwner.lifecycle.addObserver(observer)
onDispose {
lifecycleOwner.lifecycle.removeObserver(observer)
}
onDispose { lifecycleOwner.lifecycle.removeObserver(observer) }
}
AnimatedVisibility(
visible = newReleaseVersion.isNotEmpty(),
enter = fadeIn() + expandVertically()
enter = fadeIn() + expandVertically(),
) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.SpaceBetween,
modifier = Modifier
.padding(horizontal = 16.dp)
.padding(bottom = 12.dp)
.clip(
CircleShape
)
.background(MaterialTheme.colorScheme.tertiaryContainer)
.padding(4.dp)
modifier =
Modifier.padding(horizontal = 16.dp)
.padding(bottom = 12.dp)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.tertiaryContainer)
.padding(4.dp),
) {
Text(
"New release $newReleaseVersion available",
style = MaterialTheme.typography.bodyMedium,
modifier = Modifier.padding(start = 12.dp)
modifier = Modifier.padding(start = 12.dp),
)
Row(
modifier = Modifier.padding(end = 12.dp),
verticalAlignment = Alignment.CenterVertically
verticalAlignment = Alignment.CenterVertically,
) {
ClickableLink(
url = newReleaseUrl,

View file

@ -61,10 +61,8 @@ import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog
import com.google.ai.edge.gallery.BuildConfig
import com.google.ai.edge.gallery.proto.Theme
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.theme.THEME_AUTO
import com.google.ai.edge.gallery.ui.theme.THEME_DARK
import com.google.ai.edge.gallery.ui.theme.THEME_LIGHT
import com.google.ai.edge.gallery.ui.theme.ThemeSettings
import com.google.ai.edge.gallery.ui.theme.labelSmallNarrow
import java.time.Instant
@ -73,18 +71,19 @@ import java.time.format.DateTimeFormatter
import java.util.Locale
import kotlin.math.min
private val THEME_OPTIONS = listOf(THEME_AUTO, THEME_LIGHT, THEME_DARK)
private val THEME_OPTIONS = listOf(Theme.THEME_AUTO, Theme.THEME_LIGHT, Theme.THEME_DARK)
@Composable
fun SettingsDialog(
curThemeOverride: String,
curThemeOverride: Theme,
modelManagerViewModel: ModelManagerViewModel,
onDismissed: () -> Unit,
) {
var selectedTheme by remember { mutableStateOf(curThemeOverride) }
var hfToken by remember { mutableStateOf(modelManagerViewModel.getTokenStatusAndData().data) }
val dateFormatter = remember {
DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss").withZone(ZoneId.systemDefault())
DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")
.withZone(ZoneId.systemDefault())
.withLocale(Locale.getDefault())
}
var customHfToken by remember { mutableStateOf("") }
@ -95,72 +94,75 @@ fun SettingsDialog(
Dialog(onDismissRequest = onDismissed) {
val focusManager = LocalFocusManager.current
Card(
modifier = Modifier
.fillMaxWidth()
.clickable(
interactionSource = interactionSource, indication = null // Disable the ripple effect
modifier =
Modifier.fillMaxWidth().clickable(
interactionSource = interactionSource,
indication = null, // Disable the ripple effect
) {
focusManager.clearFocus()
}, shape = RoundedCornerShape(16.dp)
},
shape = RoundedCornerShape(16.dp),
) {
Column(
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
modifier = Modifier.padding(20.dp),
verticalArrangement = Arrangement.spacedBy(16.dp),
) {
// Dialog title and subtitle.
Column {
Text(
"Settings",
style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(bottom = 8.dp)
modifier = Modifier.padding(bottom = 8.dp),
)
// Subtitle.
Text(
"App version: ${BuildConfig.VERSION_NAME}",
style = labelSmallNarrow,
color = MaterialTheme.colorScheme.onSurfaceVariant,
modifier = Modifier.offset(y = (-6).dp)
modifier = Modifier.offset(y = (-6).dp),
)
}
Column(
modifier = Modifier
.verticalScroll(rememberScrollState())
.weight(1f, fill = false),
verticalArrangement = Arrangement.spacedBy(16.dp)
modifier = Modifier.verticalScroll(rememberScrollState()).weight(1f, fill = false),
verticalArrangement = Arrangement.spacedBy(16.dp),
) {
// Theme switcher.
Column(
modifier = Modifier.fillMaxWidth()
) {
Column(modifier = Modifier.fillMaxWidth()) {
Text(
"Theme",
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.Bold)
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.Bold),
)
MultiChoiceSegmentedButtonRow {
THEME_OPTIONS.forEachIndexed { index, label ->
SegmentedButton(shape = SegmentedButtonDefaults.itemShape(
index = index, count = THEME_OPTIONS.size
), onCheckedChange = {
selectedTheme = label
THEME_OPTIONS.forEachIndexed { index, theme ->
SegmentedButton(
shape =
SegmentedButtonDefaults.itemShape(index = index, count = THEME_OPTIONS.size),
onCheckedChange = {
selectedTheme = theme
// Update theme settings.
// This will update app's theme.
ThemeSettings.themeOverride.value = label
// Update theme settings.
// This will update app's theme.
ThemeSettings.themeOverride.value = theme
// Save to data store.
modelManagerViewModel.saveThemeOverride(label)
}, checked = label == selectedTheme, label = { Text(label) })
// Save to data store.
modelManagerViewModel.saveThemeOverride(theme)
},
checked = theme == selectedTheme,
label = { Text(themeLabel(theme)) },
)
}
}
}
// HF Token management.
Column(
modifier = Modifier.fillMaxWidth(), verticalArrangement = Arrangement.spacedBy(4.dp)
modifier = Modifier.fillMaxWidth(),
verticalArrangement = Arrangement.spacedBy(4.dp),
) {
Text(
"HuggingFace access token",
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.Bold)
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.Bold),
)
// Show the start of the token.
val curHfToken = hfToken
@ -168,23 +170,23 @@ fun SettingsDialog(
Text(
curHfToken.accessToken.substring(0, min(16, curHfToken.accessToken.length)) + "...",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
color = MaterialTheme.colorScheme.onSurfaceVariant,
)
Text(
"Expired at: ${dateFormatter.format(Instant.ofEpochMilli(curHfToken.expiresAtMs))}",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
color = MaterialTheme.colorScheme.onSurfaceVariant,
)
} else {
Text(
"Not available",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
color = MaterialTheme.colorScheme.onSurfaceVariant,
)
Text(
"The token will be automatically retrieved when a gated model is downloaded",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant
color = MaterialTheme.colorScheme.onSurfaceVariant,
)
}
Row(horizontalArrangement = Arrangement.spacedBy(4.dp)) {
@ -192,46 +194,42 @@ fun SettingsDialog(
onClick = {
modelManagerViewModel.clearAccessToken()
hfToken = null
}, enabled = curHfToken != null
},
enabled = curHfToken != null,
) {
Text("Clear")
}
BasicTextField(
value = customHfToken,
singleLine = true,
modifier = Modifier
.fillMaxWidth()
.padding(top = 4.dp)
.focusRequester(focusRequester)
.onFocusChanged {
isFocused = it.isFocused
},
onValueChange = {
customHfToken = it
},
modifier =
Modifier.fillMaxWidth()
.padding(top = 4.dp)
.focusRequester(focusRequester)
.onFocusChanged { isFocused = it.isFocused },
onValueChange = { customHfToken = it },
textStyle = TextStyle(color = MaterialTheme.colorScheme.onSurface),
cursorBrush = SolidColor(MaterialTheme.colorScheme.onSurface),
) { innerTextField ->
Box(
modifier = Modifier
.border(
width = if (isFocused) 2.dp else 1.dp,
color = if (isFocused) MaterialTheme.colorScheme.primary else MaterialTheme.colorScheme.outline,
shape = CircleShape,
)
.height(40.dp), contentAlignment = Alignment.CenterStart
modifier =
Modifier.border(
width = if (isFocused) 2.dp else 1.dp,
color =
if (isFocused) MaterialTheme.colorScheme.primary
else MaterialTheme.colorScheme.outline,
shape = CircleShape,
)
.height(40.dp),
contentAlignment = Alignment.CenterStart,
) {
Row(verticalAlignment = Alignment.CenterVertically) {
Box(
modifier = Modifier
.padding(start = 16.dp)
.weight(1f)
) {
Box(modifier = Modifier.padding(start = 16.dp).weight(1f)) {
if (customHfToken.isEmpty()) {
Text(
"Enter token manually",
color = MaterialTheme.colorScheme.onSurfaceVariant,
style = MaterialTheme.typography.bodySmall
style = MaterialTheme.typography.bodySmall,
)
}
innerTextField()
@ -246,7 +244,8 @@ fun SettingsDialog(
expiresAt = System.currentTimeMillis() + 1000L * 60 * 60 * 24 * 365 * 10,
)
hfToken = modelManagerViewModel.getTokenStatusAndData().data
}) {
},
) {
Icon(Icons.Rounded.CheckCircle, contentDescription = "")
}
}
@ -257,24 +256,24 @@ fun SettingsDialog(
}
}
// Button row.
Row(
modifier = Modifier
.fillMaxWidth()
.padding(top = 8.dp),
modifier = Modifier.fillMaxWidth().padding(top = 8.dp),
horizontalArrangement = Arrangement.End,
) {
// Close button
Button(
onClick = {
onDismissed()
},
) {
Text("Close")
}
Button(onClick = { onDismissed() }) { Text("Close") }
}
}
}
}
}
private fun themeLabel(theme: Theme): String {
return when (theme) {
Theme.THEME_AUTO -> "Auto"
Theme.THEME_LIGHT -> "Light"
Theme.THEME_DARK -> "Dark"
else -> "Unknown"
}
}

View file

@ -30,61 +30,64 @@ val Deployed_code: ImageVector
if (internal_Deployed_code != null) {
return internal_Deployed_code!!
}
internal_Deployed_code = ImageVector.Builder(
name = "Deployed_code",
defaultWidth = 24.dp,
defaultHeight = 24.dp,
viewportWidth = 960f,
viewportHeight = 960f
).apply {
path(
fill = SolidColor(Color.Black),
fillAlpha = 1.0f,
stroke = null,
strokeAlpha = 1.0f,
strokeLineWidth = 1.0f,
strokeLineCap = StrokeCap.Butt,
strokeLineJoin = StrokeJoin.Miter,
strokeLineMiter = 1.0f,
pathFillType = PathFillType.NonZero
) {
moveTo(440f, 777f)
verticalLineToRelative(-274f)
lineTo(200f, 364f)
verticalLineToRelative(274f)
close()
moveToRelative(80f, 0f)
lineToRelative(240f, -139f)
verticalLineToRelative(-274f)
lineTo(520f, 503f)
close()
moveToRelative(-40f, -343f)
lineToRelative(237f, -137f)
lineToRelative(-237f, -137f)
lineToRelative(-237f, 137f)
close()
moveTo(160f, 708f)
quadToRelative(-19f, -11f, -29.5f, -29f)
reflectiveQuadTo(120f, 639f)
verticalLineToRelative(-318f)
quadToRelative(0f, -22f, 10.5f, -40f)
reflectiveQuadToRelative(29.5f, -29f)
lineToRelative(280f, -161f)
quadToRelative(19f, -11f, 40f, -11f)
reflectiveQuadToRelative(40f, 11f)
lineToRelative(280f, 161f)
quadToRelative(19f, 11f, 29.5f, 29f)
reflectiveQuadToRelative(10.5f, 40f)
verticalLineToRelative(318f)
quadToRelative(0f, 22f, -10.5f, 40f)
reflectiveQuadTo(800f, 708f)
lineTo(520f, 869f)
quadToRelative(-19f, 11f, -40f, 11f)
reflectiveQuadToRelative(-40f, -11f)
close()
moveToRelative(320f, -228f)
}
}.build()
internal_Deployed_code =
ImageVector.Builder(
name = "Deployed_code",
defaultWidth = 24.dp,
defaultHeight = 24.dp,
viewportWidth = 960f,
viewportHeight = 960f,
)
.apply {
path(
fill = SolidColor(Color.Black),
fillAlpha = 1.0f,
stroke = null,
strokeAlpha = 1.0f,
strokeLineWidth = 1.0f,
strokeLineCap = StrokeCap.Butt,
strokeLineJoin = StrokeJoin.Miter,
strokeLineMiter = 1.0f,
pathFillType = PathFillType.NonZero,
) {
moveTo(440f, 777f)
verticalLineToRelative(-274f)
lineTo(200f, 364f)
verticalLineToRelative(274f)
close()
moveToRelative(80f, 0f)
lineToRelative(240f, -139f)
verticalLineToRelative(-274f)
lineTo(520f, 503f)
close()
moveToRelative(-40f, -343f)
lineToRelative(237f, -137f)
lineToRelative(-237f, -137f)
lineToRelative(-237f, 137f)
close()
moveTo(160f, 708f)
quadToRelative(-19f, -11f, -29.5f, -29f)
reflectiveQuadTo(120f, 639f)
verticalLineToRelative(-318f)
quadToRelative(0f, -22f, 10.5f, -40f)
reflectiveQuadToRelative(29.5f, -29f)
lineToRelative(280f, -161f)
quadToRelative(19f, -11f, 40f, -11f)
reflectiveQuadToRelative(40f, 11f)
lineToRelative(280f, 161f)
quadToRelative(19f, 11f, 29.5f, 29f)
reflectiveQuadToRelative(10.5f, 40f)
verticalLineToRelative(318f)
quadToRelative(0f, 22f, -10.5f, 40f)
reflectiveQuadTo(800f, 708f)
lineTo(520f, 869f)
quadToRelative(-19f, 11f, -40f, 11f)
reflectiveQuadToRelative(-40f, -11f)
close()
moveToRelative(320f, -228f)
}
}
.build()
return internal_Deployed_code!!
}

View file

@ -0,0 +1,78 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.icon
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.SolidColor
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.graphics.vector.path
import androidx.compose.ui.unit.dp
val Forum: ImageVector
get() {
if (_Forum != null) return _Forum!!
_Forum =
ImageVector.Builder(
name = "Forum",
defaultWidth = 24.dp,
defaultHeight = 24.dp,
viewportWidth = 960f,
viewportHeight = 960f,
)
.apply {
path(fill = SolidColor(Color(0xFF000000))) {
moveTo(280f, 720f)
quadToRelative(-17f, 0f, -28.5f, -11.5f)
reflectiveQuadTo(240f, 680f)
verticalLineToRelative(-80f)
horizontalLineToRelative(520f)
verticalLineToRelative(-360f)
horizontalLineToRelative(80f)
quadToRelative(17f, 0f, 28.5f, 11.5f)
reflectiveQuadTo(880f, 280f)
verticalLineToRelative(600f)
lineTo(720f, 720f)
close()
moveTo(80f, 680f)
verticalLineToRelative(-560f)
quadToRelative(0f, -17f, 11.5f, -28.5f)
reflectiveQuadTo(120f, 80f)
horizontalLineToRelative(520f)
quadToRelative(17f, 0f, 28.5f, 11.5f)
reflectiveQuadTo(680f, 120f)
verticalLineToRelative(360f)
quadToRelative(0f, 17f, -11.5f, 28.5f)
reflectiveQuadTo(640f, 520f)
horizontalLineTo(240f)
close()
moveToRelative(520f, -240f)
verticalLineToRelative(-280f)
horizontalLineTo(160f)
verticalLineToRelative(280f)
close()
moveToRelative(-440f, 0f)
verticalLineToRelative(-280f)
close()
}
}
.build()
return _Forum!!
}
private var _Forum: ImageVector? = null

View file

@ -0,0 +1,73 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.icon
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.SolidColor
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.graphics.vector.path
import androidx.compose.ui.unit.dp
val Mms: ImageVector
get() {
if (_Mms != null) return _Mms!!
_Mms =
ImageVector.Builder(
name = "Mms",
defaultWidth = 24.dp,
defaultHeight = 24.dp,
viewportWidth = 960f,
viewportHeight = 960f,
)
.apply {
path(fill = SolidColor(Color(0xFF000000))) {
moveTo(240f, 560f)
horizontalLineToRelative(480f)
lineTo(570f, 360f)
lineTo(450f, 520f)
lineToRelative(-90f, -120f)
close()
moveTo(80f, 880f)
verticalLineToRelative(-720f)
quadToRelative(0f, -33f, 23.5f, -56.5f)
reflectiveQuadTo(160f, 80f)
horizontalLineToRelative(640f)
quadToRelative(33f, 0f, 56.5f, 23.5f)
reflectiveQuadTo(880f, 160f)
verticalLineToRelative(480f)
quadToRelative(0f, 33f, -23.5f, 56.5f)
reflectiveQuadTo(800f, 720f)
horizontalLineTo(240f)
close()
moveToRelative(126f, -240f)
horizontalLineToRelative(594f)
verticalLineToRelative(-480f)
horizontalLineTo(160f)
verticalLineToRelative(525f)
close()
moveToRelative(-46f, 0f)
verticalLineToRelative(-480f)
close()
}
}
.build()
return _Mms!!
}
private var _Mms: ImageVector? = null

View file

@ -0,0 +1,87 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.icon
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.SolidColor
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.graphics.vector.path
import androidx.compose.ui.unit.dp
val Widgets: ImageVector
get() {
if (_Widgets != null) return _Widgets!!
_Widgets =
ImageVector.Builder(
name = "Widgets",
defaultWidth = 24.dp,
defaultHeight = 24.dp,
viewportWidth = 960f,
viewportHeight = 960f,
)
.apply {
path(fill = SolidColor(Color(0xFF000000))) {
moveTo(666f, 520f)
lineTo(440f, 294f)
lineToRelative(226f, -226f)
lineToRelative(226f, 226f)
close()
moveToRelative(-546f, -80f)
verticalLineToRelative(-320f)
horizontalLineToRelative(320f)
verticalLineToRelative(320f)
close()
moveToRelative(400f, 400f)
verticalLineToRelative(-320f)
horizontalLineToRelative(320f)
verticalLineToRelative(320f)
close()
moveToRelative(-400f, 0f)
verticalLineToRelative(-320f)
horizontalLineToRelative(320f)
verticalLineToRelative(320f)
close()
moveToRelative(80f, -480f)
horizontalLineToRelative(160f)
verticalLineToRelative(-160f)
horizontalLineTo(200f)
close()
moveToRelative(467f, 48f)
lineToRelative(113f, -113f)
lineToRelative(-113f, -113f)
lineToRelative(-113f, 113f)
close()
moveToRelative(-67f, 352f)
horizontalLineToRelative(160f)
verticalLineToRelative(-160f)
horizontalLineTo(600f)
close()
moveToRelative(-400f, 0f)
horizontalLineToRelative(160f)
verticalLineToRelative(-160f)
horizontalLineTo(200f)
close()
moveToRelative(400f, -160f)
}
}
.build()
return _Widgets!!
}
private var _Widgets: ImageVector? = null

View file

@ -1,154 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.imageclassification
import android.content.Context
import android.graphics.Bitmap
import android.util.Log
import androidx.compose.ui.graphics.Color
import com.google.android.gms.tflite.client.TfLiteInitializationOptions
import com.google.android.gms.tflite.gpu.support.TfLiteGpu
import com.google.android.gms.tflite.java.TfLite
import com.google.ai.edge.gallery.ui.common.chat.Classification
import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.ui.common.LatencyProvider
import org.tensorflow.lite.DataType
import org.tensorflow.lite.InterpreterApi
import org.tensorflow.lite.gpu.GpuDelegateFactory
import org.tensorflow.lite.support.common.FileUtil
import org.tensorflow.lite.support.common.ops.NormalizeOp
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.ResizeOp
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.io.File
import java.io.FileInputStream
private const val TAG = "AGImageClassificationModelHelper"
class ImageClassificationInferenceResult(
val categories: List<Classification>, override val latencyMs: Float
) : LatencyProvider
//TODO: handle error.
object ImageClassificationModelHelper {
fun initialize(context: Context, model: Model, onDone: (String) -> Unit) {
val useGpu = model.getBooleanConfigValue(key = ConfigKey.USE_GPU)
TfLiteGpu.isGpuDelegateAvailable(context).continueWith { gpuTask ->
val optionsBuilder = TfLiteInitializationOptions.builder()
if (gpuTask.result) {
optionsBuilder.setEnableGpuDelegateSupport(true)
}
val task = TfLite.initialize(
context,
optionsBuilder.build()
)
task.addOnSuccessListener {
val interpreterOption =
InterpreterApi.Options().setRuntime(InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY)
if (useGpu) {
interpreterOption.addDelegateFactory(GpuDelegateFactory())
}
val interpreter = InterpreterApi.create(
File(model.getPath(context = context)), interpreterOption
)
model.instance = interpreter
onDone("")
}
}
}
fun cleanUp(model: Model) {
if (model.instance == null) {
return
}
val instance = model.instance as InterpreterApi
instance.close()
}
fun runInference(
context: Context,
model: Model,
input: Bitmap,
primaryColor: Color,
): ImageClassificationInferenceResult {
val instance = model.instance
if (instance == null) {
Log.d(
TAG, "Model '${model.name}' has not been initialized"
)
return ImageClassificationInferenceResult(categories = listOf(), latencyMs = 0f)
}
// Pre-process image.
val start = System.currentTimeMillis()
val interpreter = model.instance as InterpreterApi
val inputShape = interpreter.getInputTensor(0).shape()
val imageProcessor = ImageProcessor.Builder()
.add(ResizeOp(inputShape[1], inputShape[2], ResizeOp.ResizeMethod.BILINEAR))
.add(NormalizeOp(127.5f, 127.5f)) // Normalize pixel values
.build()
val tensorImage = TensorImage(DataType.FLOAT32)
tensorImage.load(input)
val inputTensorBuffer = imageProcessor.process(tensorImage).tensorBuffer
// Output buffer
val outputBuffer =
TensorBuffer.createFixedSize(interpreter.getOutputTensor(0).shape(), DataType.FLOAT32)
// Run inference
interpreter.run(inputTensorBuffer.buffer, outputBuffer.buffer)
// Post-process result.
val output = outputBuffer.floatArray
val labelsFilePath = model.getPath(
context = context,
fileName = model.getExtraDataFile(name = "labels")?.downloadFileName ?: "_"
)
val labelsFileInputStream = FileInputStream(File(labelsFilePath))
val labels = FileUtil.loadLabels(labelsFileInputStream)
labelsFileInputStream.close()
val topIndices =
getTopKMaxIndices(output = output, k = model.getIntConfigValue(ConfigKey.MAX_RESULT_COUNT))
val categories: List<Classification> =
topIndices.map { i ->
Classification(
label = labels[i],
score = output[i],
color = primaryColor
)
}
return ImageClassificationInferenceResult(
categories = categories,
latencyMs = (System.currentTimeMillis() - start).toFloat()
)
}
private fun getTopKMaxIndices(output: FloatArray, k: Int): List<Int> {
if (k <= 0 || output.isEmpty()) {
return emptyList()
}
val indexedValues = output.withIndex().toList()
val sortedIndexedValues =
indexedValues.sortedByDescending { it.value }
return sortedIndexedValues.take(k).map { it.index } // Take the top k and extract indices
}
}

View file

@ -1,98 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.imageclassification
import androidx.compose.material3.MaterialTheme
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.ai.edge.gallery.ui.ViewModelProvider
import com.google.ai.edge.gallery.ui.common.chat.ChatInputType
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage
import com.google.ai.edge.gallery.ui.common.chat.ChatView
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import kotlinx.serialization.Serializable
/** Navigation destination data */
object ImageClassificationDestination {
@Serializable
val route = "ImageClassificationRoute"
}
@Composable
fun ImageClassificationScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
viewModel: ImageClassificationViewModel = viewModel(
factory = ViewModelProvider.Factory
),
) {
val context = LocalContext.current
val primaryColor = MaterialTheme.colorScheme.primary
ChatView(
task = viewModel.task,
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
onSendMessage = { model, messages ->
val message = messages[0]
viewModel.addMessage(
model = model,
message = message,
)
if (message is ChatMessageImage) {
viewModel.generateResponse(
context = context,
model = model,
input = message.bitmap,
primaryColor = primaryColor,
)
}
},
onStreamImageMessage = { model, message ->
viewModel.generateStreamingResponse(
context = context,
model = model,
input = message.bitmap,
primaryColor = primaryColor,
)
},
onRunAgainClicked = { model, message ->
viewModel.runAgain(
context = context,
model = model,
message = message,
primaryColor = primaryColor,
)
},
onBenchmarkClicked = { model, message, warmUpIterations, benchmarkIterations ->
viewModel.benchmark(
context = context,
model = model,
message = message,
warmupCount = warmUpIterations,
iterations = benchmarkIterations,
primaryColor = primaryColor,
)
},
navigateUp = navigateUp,
modifier = modifier,
chatInputType = ChatInputType.IMAGE,
)
}

View file

@ -1,165 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.imageclassification
import android.content.Context
import android.graphics.Bitmap
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.unit.dp
import androidx.lifecycle.viewModelScope
import com.google.ai.edge.gallery.ui.common.chat.ChatMessage
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageClassification
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageType
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.TASK_IMAGE_CLASSIFICATION
import com.google.ai.edge.gallery.ui.common.chat.ChatViewModel
import com.google.ai.edge.gallery.ui.common.runBasicBenchmark
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.sync.Mutex
class ImageClassificationViewModel : ChatViewModel(task = TASK_IMAGE_CLASSIFICATION) {
private val mutex = Mutex()
fun generateResponse(context: Context, model: Model, input: Bitmap, primaryColor: Color) {
viewModelScope.launch(Dispatchers.Default) {
// Wait for model to be initialized.
while (model.instance == null) {
delay(100)
}
val result = ImageClassificationModelHelper.runInference(
context = context, model = model, input = input, primaryColor = primaryColor
)
super.addMessage(
model = model,
message = ChatMessageClassification(
classifications = result.categories,
latencyMs = result.latencyMs,
maxBarWidth = 300.dp,
),
)
}
}
fun generateStreamingResponse(
context: Context,
model: Model,
input: Bitmap,
primaryColor: Color
) {
viewModelScope.launch(Dispatchers.Default) {
// Wait for model to be initialized.
while (model.instance == null) {
delay(100)
}
if (mutex.tryLock()) {
try {
val result = ImageClassificationModelHelper.runInference(
context = context, model = model, input = input, primaryColor = primaryColor
)
updateStreamingMessage(
model = model,
message = ChatMessageClassification(
classifications = result.categories,
latencyMs = result.latencyMs
)
)
} finally {
mutex.unlock()
}
} else {
// skip call if the previous call has not been finished (mutex is still locked).
}
}
}
fun benchmark(
context: Context,
model: Model,
message: ChatMessage,
warmupCount: Int,
iterations: Int,
primaryColor: Color
) {
viewModelScope.launch(Dispatchers.Default) {
// Wait for model to be initialized.
while (model.instance == null) {
delay(100)
}
if (message is ChatMessageImage) {
setInProgress(true)
runBasicBenchmark(
model = model,
warmupCount = warmupCount,
iterations = iterations,
chatViewModel = this@ImageClassificationViewModel,
inferenceFn = {
ImageClassificationModelHelper.runInference(
context = context,
model = model,
input = message.bitmap,
primaryColor = primaryColor
)
},
chatMessageType = ChatMessageType.BENCHMARK_RESULT,
)
setInProgress(false)
}
}
}
fun runAgain(context: Context, model: Model, message: ChatMessage, primaryColor: Color) {
viewModelScope.launch(Dispatchers.Default) {
// Wait for model to be initialized.
while (model.instance == null) {
delay(100)
}
if (message is ChatMessageImage) {
// Clone the clicked message and add it.
addMessage(model = model, message = message.clone())
// Run inference.
val result =
ImageClassificationModelHelper.runInference(
context = context,
model = model,
input = message.bitmap,
primaryColor = primaryColor
)
// Add response message.
val newMessage = generateClassificationMessage(result = result)
addMessage(model = model, message = newMessage)
}
}
}
private fun generateClassificationMessage(result: ImageClassificationInferenceResult): ChatMessageClassification {
return ChatMessageClassification(
classifications = result.categories,
latencyMs = result.latencyMs,
maxBarWidth = 300.dp,
)
}
}

View file

@ -1,83 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.imagegeneration
import android.content.Context
import android.graphics.Bitmap
import android.util.Log
import com.google.mediapipe.framework.image.BitmapExtractor
import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator
import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.ui.common.LatencyProvider
import com.google.ai.edge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage
import kotlin.random.Random
private const val TAG = "AGImageGenerationModelHelper"
class ImageGenerationInferenceResult(
val bitmap: Bitmap, override val latencyMs: Float
) : LatencyProvider
object ImageGenerationModelHelper {
fun initialize(context: Context, model: Model, onDone: (String) -> Unit) {
try {
val options = ImageGenerator.ImageGeneratorOptions.builder()
.setImageGeneratorModelDirectory(model.getPath(context = context))
.build()
model.instance = ImageGenerator.createFromOptions(context, options)
} catch (e: Exception) {
onDone(cleanUpMediapipeTaskErrorMessage(e.message ?: "Unknown error"))
return
}
onDone("")
}
fun cleanUp(model: Model) {
if (model.instance == null) {
return
}
val instance = model.instance as ImageGenerator
try {
instance.close()
} catch (e: Exception) {
// ignore
}
model.instance = null
Log.d(TAG, "Clean up done.")
}
fun runInference(
model: Model,
input: String,
onStep: (curIteration: Int, totalIterations: Int, ImageGenerationInferenceResult, isLast: Boolean) -> Unit
) {
val start = System.currentTimeMillis()
val instance = model.instance as ImageGenerator
val iterations = model.getIntConfigValue(ConfigKey.ITERATIONS)
instance.setInputs(input, iterations, Random.nextInt())
for (i in 0..<iterations) {
val result = ImageGenerationInferenceResult(
bitmap = BitmapExtractor.extract(
instance.execute(true)?.generatedImage(),
),
latencyMs = (System.currentTimeMillis() - start).toFloat(),
)
onStep(i, iterations, result, i == iterations - 1)
}
}
}

View file

@ -1,65 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.imagegeneration
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.ai.edge.gallery.ui.ViewModelProvider
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
import com.google.ai.edge.gallery.ui.common.chat.ChatView
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import kotlinx.serialization.Serializable
/** Navigation destination data */
object ImageGenerationDestination {
@Serializable
val route = "ImageGenerationRoute"
}
@Composable
fun ImageGenerationScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
viewModel: ImageGenerationViewModel = viewModel(
factory = ViewModelProvider.Factory
),
) {
ChatView(
task = viewModel.task,
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
onSendMessage = { model, messages ->
val message = messages[0]
viewModel.addMessage(
model = model,
message = message,
)
if (message is ChatMessageText) {
viewModel.generateResponse(
model = model,
input = message.content,
)
}
},
onRunAgainClicked = { _, _ -> },
onBenchmarkClicked = { _, _, _, _ -> },
navigateUp = navigateUp,
modifier = modifier,
)
}

View file

@ -1,87 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.imagegeneration
import android.graphics.Bitmap
import androidx.compose.ui.graphics.ImageBitmap
import androidx.compose.ui.graphics.asImageBitmap
import androidx.lifecycle.viewModelScope
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.TASK_IMAGE_GENERATION
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImageWithHistory
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageLoading
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageType
import com.google.ai.edge.gallery.ui.common.chat.ChatSide
import com.google.ai.edge.gallery.ui.common.chat.ChatViewModel
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
class ImageGenerationViewModel : ChatViewModel(task = TASK_IMAGE_GENERATION) {
fun generateResponse(model: Model, input: String) {
viewModelScope.launch(Dispatchers.Default) {
setInProgress(true)
// Loading.
addMessage(
model = model,
message = ChatMessageLoading(),
)
// Wait for model to be initialized.
while (model.instance == null) {
delay(100)
}
// Run inference.
val bitmaps: MutableList<Bitmap> = mutableListOf()
val imageBitmaps: MutableList<ImageBitmap> = mutableListOf()
ImageGenerationModelHelper.runInference(
model = model, input = input
) { step, totalIterations, result, isLast ->
bitmaps.add(result.bitmap)
imageBitmaps.add(result.bitmap.asImageBitmap())
val message = ChatMessageImageWithHistory(
bitmaps = bitmaps,
imageBitMaps = imageBitmaps,
totalIterations = totalIterations,
side = ChatSide.AGENT,
latencyMs = result.latencyMs,
curIteration = step,
)
if (step == 0) {
removeLastMessage(model = model)
super.addMessage(
model = model,
message = message,
)
} else {
super.replaceLastMessage(
model = model,
message = message,
type = ChatMessageType.IMAGE_WITH_HISTORY
)
}
if (isLast) {
setInProgress(false)
}
}
}
}
}

View file

@ -1,72 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.llmchat
import com.google.ai.edge.gallery.data.Accelerator
import com.google.ai.edge.gallery.data.Config
import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.LabelConfig
import com.google.ai.edge.gallery.data.NumberSliderConfig
import com.google.ai.edge.gallery.data.SegmentedButtonConfig
import com.google.ai.edge.gallery.data.ValueType
const val DEFAULT_MAX_TOKEN = 1024
const val DEFAULT_TOPK = 40
const val DEFAULT_TOPP = 0.9f
const val DEFAULT_TEMPERATURE = 1.0f
val DEFAULT_ACCELERATORS = listOf(Accelerator.GPU)
fun createLlmChatConfigs(
defaultMaxToken: Int = DEFAULT_MAX_TOKEN,
defaultTopK: Int = DEFAULT_TOPK,
defaultTopP: Float = DEFAULT_TOPP,
defaultTemperature: Float = DEFAULT_TEMPERATURE,
accelerators: List<Accelerator> = DEFAULT_ACCELERATORS,
): List<Config> {
return listOf(
LabelConfig(
key = ConfigKey.MAX_TOKENS,
defaultValue = "$defaultMaxToken",
),
NumberSliderConfig(
key = ConfigKey.TOPK,
sliderMin = 5f,
sliderMax = 40f,
defaultValue = defaultTopK.toFloat(),
valueType = ValueType.INT
),
NumberSliderConfig(
key = ConfigKey.TOPP,
sliderMin = 0.0f,
sliderMax = 1.0f,
defaultValue = defaultTopP,
valueType = ValueType.FLOAT
),
NumberSliderConfig(
key = ConfigKey.TEMPERATURE,
sliderMin = 0.0f,
sliderMax = 2.0f,
defaultValue = defaultTemperature,
valueType = ValueType.FLOAT
),
SegmentedButtonConfig(
key = ConfigKey.ACCELERATOR,
defaultValue = accelerators[0].label,
options = accelerators.map { it.label }
)
)
}

View file

@ -19,10 +19,15 @@ package com.google.ai.edge.gallery.ui.llmchat
import android.content.Context
import android.graphics.Bitmap
import android.util.Log
import com.google.ai.edge.gallery.common.cleanUpMediapipeTaskErrorMessage
import com.google.ai.edge.gallery.data.Accelerator
import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.DEFAULT_MAX_TOKEN
import com.google.ai.edge.gallery.data.DEFAULT_TEMPERATURE
import com.google.ai.edge.gallery.data.DEFAULT_TOPK
import com.google.ai.edge.gallery.data.DEFAULT_TOPP
import com.google.ai.edge.gallery.data.MAX_IMAGE_COUNT
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.ui.common.cleanUpMediapipeTaskErrorMessage
import com.google.mediapipe.framework.image.BitmapImageBuilder
import com.google.mediapipe.tasks.genai.llminference.GraphOptions
import com.google.mediapipe.tasks.genai.llminference.LlmInference
@ -31,6 +36,7 @@ import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
private const val TAG = "AGLlmChatModelHelper"
typealias ResultListener = (partialResult: String, done: Boolean) -> Unit
typealias CleanUpListener = () -> Unit
data class LlmModelInstance(val engine: LlmInference, var session: LlmInferenceSession)
@ -39,9 +45,7 @@ object LlmChatModelHelper {
// Indexed by model name.
private val cleanUpListeners: MutableMap<String, CleanUpListener> = mutableMapOf()
fun initialize(
context: Context, model: Model, onDone: (String) -> Unit
) {
fun initialize(context: Context, model: Model, onDone: (String) -> Unit) {
// Prepare options.
val maxTokens =
model.getIntConfigValue(key = ConfigKey.MAX_TOKENS, defaultValue = DEFAULT_MAX_TOKEN)
@ -52,29 +56,36 @@ object LlmChatModelHelper {
val accelerator =
model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = Accelerator.GPU.label)
Log.d(TAG, "Initializing...")
val preferredBackend = when (accelerator) {
Accelerator.CPU.label -> LlmInference.Backend.CPU
Accelerator.GPU.label -> LlmInference.Backend.GPU
else -> LlmInference.Backend.GPU
}
val preferredBackend =
when (accelerator) {
Accelerator.CPU.label -> LlmInference.Backend.CPU
Accelerator.GPU.label -> LlmInference.Backend.GPU
else -> LlmInference.Backend.GPU
}
val options =
LlmInference.LlmInferenceOptions.builder().setModelPath(model.getPath(context = context))
.setMaxTokens(maxTokens).setPreferredBackend(preferredBackend)
.setMaxNumImages(if (model.llmSupportImage) 1 else 0)
LlmInference.LlmInferenceOptions.builder()
.setModelPath(model.getPath(context = context))
.setMaxTokens(maxTokens)
.setPreferredBackend(preferredBackend)
.setMaxNumImages(if (model.llmSupportImage) MAX_IMAGE_COUNT else 0)
.build()
// Create an instance of the LLM Inference task and session.
try {
val llmInference = LlmInference.createFromOptions(context, options)
val session = LlmInferenceSession.createFromOptions(
llmInference,
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
.setTemperature(temperature)
.setGraphOptions(
GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
).build()
)
val session =
LlmInferenceSession.createFromOptions(
llmInference,
LlmInferenceSession.LlmInferenceSessionOptions.builder()
.setTopK(topK)
.setTopP(topP)
.setTemperature(temperature)
.setGraphOptions(
GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
)
.build(),
)
model.instance = LlmModelInstance(engine = llmInference, session = session)
} catch (e: Exception) {
onDone(cleanUpMediapipeTaskErrorMessage(e.message ?: "Unknown error"))
@ -96,14 +107,18 @@ object LlmChatModelHelper {
val topP = model.getFloatConfigValue(key = ConfigKey.TOPP, defaultValue = DEFAULT_TOPP)
val temperature =
model.getFloatConfigValue(key = ConfigKey.TEMPERATURE, defaultValue = DEFAULT_TEMPERATURE)
val newSession = LlmInferenceSession.createFromOptions(
inference,
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
.setTemperature(temperature)
.setGraphOptions(
GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
).build()
)
val newSession =
LlmInferenceSession.createFromOptions(
inference,
LlmInferenceSession.LlmInferenceSessionOptions.builder()
.setTopK(topK)
.setTopP(topP)
.setTemperature(temperature)
.setGraphOptions(
GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
)
.build(),
)
instance.session = newSession
Log.d(TAG, "Resetting done")
} catch (e: Exception) {
@ -117,12 +132,19 @@ object LlmChatModelHelper {
}
val instance = model.instance as LlmModelInstance
try {
instance.session.close()
} catch (e: Exception) {
Log.e(TAG, "Failed to close the LLM Inference session: ${e.message}")
}
try {
// This will also close the session. Do not call session.close manually.
instance.engine.close()
} catch (e: Exception) {
// ignore
Log.e(TAG, "Failed to close the LLM Inference engine: ${e.message}")
}
val onCleanUp = cleanUpListeners.remove(model.name)
if (onCleanUp != null) {
onCleanUp()
@ -136,7 +158,7 @@ object LlmChatModelHelper {
input: String,
resultListener: ResultListener,
cleanUpListener: CleanUpListener,
image: Bitmap? = null,
images: List<Bitmap> = listOf(),
) {
val instance = model.instance as LlmModelInstance
@ -151,9 +173,9 @@ object LlmChatModelHelper {
// image.
val session = instance.session
session.addQueryChunk(input)
if (image != null) {
for (image in images) {
session.addImage(BitmapImageBuilder(image).build())
}
session.generateResponseAsync(resultListener)
val unused = session.generateResponseAsync(resultListener)
}
}

View file

@ -26,16 +26,13 @@ import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
import com.google.ai.edge.gallery.ui.common.chat.ChatView
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import kotlinx.serialization.Serializable
/** Navigation destination data */
object LlmChatDestination {
@Serializable
val route = "LlmChatRoute"
}
object LlmAskImageDestination {
@Serializable
val route = "LlmAskImageRoute"
}
@ -44,9 +41,7 @@ fun LlmChatScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
viewModel: LlmChatViewModel = viewModel(
factory = ViewModelProvider.Factory
),
viewModel: LlmChatViewModel = viewModel(factory = ViewModelProvider.Factory),
) {
ChatViewWrapper(
viewModel = viewModel,
@ -61,9 +56,7 @@ fun LlmAskImageScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
viewModel: LlmAskImageViewModel = viewModel(
factory = ViewModelProvider.Factory
),
viewModel: LlmAskImageViewModel = viewModel(factory = ViewModelProvider.Factory),
) {
ChatViewWrapper(
viewModel = viewModel,
@ -78,7 +71,7 @@ fun ChatViewWrapper(
viewModel: LlmChatViewModel,
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier
modifier: Modifier = Modifier,
) {
val context = LocalContext.current
@ -88,57 +81,58 @@ fun ChatViewWrapper(
modelManagerViewModel = modelManagerViewModel,
onSendMessage = { model, messages ->
for (message in messages) {
viewModel.addMessage(
model = model,
message = message,
)
viewModel.addMessage(model = model, message = message)
}
var text = ""
var image: Bitmap? = null
val images: MutableList<Bitmap> = mutableListOf()
var chatMessageText: ChatMessageText? = null
for (message in messages) {
if (message is ChatMessageText) {
chatMessageText = message
text = message.content
} else if (message is ChatMessageImage) {
image = message.bitmap
images.add(message.bitmap)
}
}
if (text.isNotEmpty() && chatMessageText != null) {
modelManagerViewModel.addTextInputHistory(text)
viewModel.generateResponse(model = model, input = text, image = image, onError = {
viewModel.handleError(
context = context,
model = model,
modelManagerViewModel = modelManagerViewModel,
triggeredMessage = chatMessageText,
)
})
viewModel.generateResponse(
model = model,
input = text,
images = images,
onError = {
viewModel.handleError(
context = context,
model = model,
modelManagerViewModel = modelManagerViewModel,
triggeredMessage = chatMessageText,
)
},
)
}
},
onRunAgainClicked = { model, message ->
if (message is ChatMessageText) {
viewModel.runAgain(model = model, message = message, onError = {
viewModel.handleError(
context = context,
model = model,
modelManagerViewModel = modelManagerViewModel,
triggeredMessage = message,
)
})
viewModel.runAgain(
model = model,
message = message,
onError = {
viewModel.handleError(
context = context,
model = model,
modelManagerViewModel = modelManagerViewModel,
triggeredMessage = message,
)
},
)
}
},
onBenchmarkClicked = { _, _, _, _ ->
},
onResetSessionClicked = { model ->
viewModel.resetSession(model = model)
},
onBenchmarkClicked = { _, _, _, _ -> },
onResetSessionClicked = { model -> viewModel.resetSession(model = model) },
showStopButtonInInputWhenInProgress = true,
onStopButtonClicked = { model ->
viewModel.stopResponse(model = model)
},
onStopButtonClicked = { model -> viewModel.stopResponse(model = model) },
navigateUp = navigateUp,
modifier = modifier,
)
}
}

View file

@ -22,8 +22,8 @@ import android.util.Log
import androidx.lifecycle.viewModelScope
import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageLoading
@ -39,25 +39,28 @@ import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
private const val TAG = "AGLlmChatViewModel"
private val STATS = listOf(
Stat(id = "time_to_first_token", label = "1st token", unit = "sec"),
Stat(id = "prefill_speed", label = "Prefill speed", unit = "tokens/s"),
Stat(id = "decode_speed", label = "Decode speed", unit = "tokens/s"),
Stat(id = "latency", label = "Latency", unit = "sec")
)
private val STATS =
listOf(
Stat(id = "time_to_first_token", label = "1st token", unit = "sec"),
Stat(id = "prefill_speed", label = "Prefill speed", unit = "tokens/s"),
Stat(id = "decode_speed", label = "Decode speed", unit = "tokens/s"),
Stat(id = "latency", label = "Latency", unit = "sec"),
)
open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task = curTask) {
fun generateResponse(model: Model, input: String, image: Bitmap? = null, onError: () -> Unit) {
fun generateResponse(
model: Model,
input: String,
images: List<Bitmap> = listOf(),
onError: () -> Unit,
) {
val accelerator = model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = "")
viewModelScope.launch(Dispatchers.Default) {
setInProgress(true)
setPreparing(true)
// Loading.
addMessage(
model = model,
message = ChatMessageLoading(accelerator = accelerator),
)
addMessage(model = model, message = ChatMessageLoading(accelerator = accelerator))
// Wait for instance to be initialized.
while (model.instance == null) {
@ -68,9 +71,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
// Run inference.
val instance = model.instance as LlmModelInstance
var prefillTokens = instance.session.sizeInTokens(input)
if (image != null) {
prefillTokens += 257
}
prefillTokens += images.size * 257
var firstRun = true
var timeToFirstToken = 0f
@ -81,9 +82,10 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
val start = System.currentTimeMillis()
try {
LlmChatModelHelper.runInference(model = model,
LlmChatModelHelper.runInference(
model = model,
input = input,
image = image,
images = images,
resultListener = { partialResult, done ->
val curTs = System.currentTimeMillis()
@ -106,18 +108,17 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
// Add an empty message that will receive streaming results.
addMessage(
model = model,
message = ChatMessageText(
content = "",
side = ChatSide.AGENT,
accelerator = accelerator
)
message =
ChatMessageText(content = "", side = ChatSide.AGENT, accelerator = accelerator),
)
}
// Incrementally update the streamed partial results.
val latencyMs: Long = if (done) System.currentTimeMillis() - start else -1
updateLastTextMessageContentIncrementally(
model = model, partialContent = partialResult, latencyMs = latencyMs.toFloat()
model = model,
partialContent = partialResult,
latencyMs = latencyMs.toFloat(),
)
if (done) {
@ -130,18 +131,21 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
if (lastMessage is ChatMessageText) {
updateLastTextMessageLlmBenchmarkResult(
model = model, llmBenchmarkResult = ChatMessageBenchmarkLlmResult(
orderedStats = STATS,
statValues = mutableMapOf(
"prefill_speed" to prefillSpeed,
"decode_speed" to decodeSpeed,
"time_to_first_token" to timeToFirstToken,
"latency" to (curTs - start).toFloat() / 1000f,
model = model,
llmBenchmarkResult =
ChatMessageBenchmarkLlmResult(
orderedStats = STATS,
statValues =
mutableMapOf(
"prefill_speed" to prefillSpeed,
"decode_speed" to decodeSpeed,
"time_to_first_token" to timeToFirstToken,
"latency" to (curTs - start).toFloat() / 1000f,
),
running = false,
latencyMs = -1f,
accelerator = accelerator,
),
running = false,
latencyMs = -1f,
accelerator = accelerator,
)
)
}
}
@ -149,7 +153,8 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
cleanUpListener = {
setInProgress(false)
setPreparing(false)
})
},
)
} catch (e: Exception) {
Log.e(TAG, "Error occurred while running inference", e)
setInProgress(false)
@ -201,9 +206,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
addMessage(model = model, message = message.clone())
// Run inference.
generateResponse(
model = model, input = message.content, onError = onError
)
generateResponse(model = model, input = message.content, onError = onError)
}
}
@ -229,20 +232,18 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
// Add a warning message for re-initializing the session.
addMessage(
model = model,
message = ChatMessageWarning(content = "Error occurred. Re-initializing the session.")
message = ChatMessageWarning(content = "Error occurred. Re-initializing the session."),
)
// Add the triggered message back.
addMessage(model = model, message = triggeredMessage)
// Re-initialize the session/engine.
modelManagerViewModel.initializeModel(
context = context, task = task, model = model
)
modelManagerViewModel.initializeModel(context = context, task = task, model = model)
// Re-generate the response automatically.
generateResponse(model = model, input = triggeredMessage.content, onError = {})
}
}
class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE)
class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE)

View file

@ -16,6 +16,10 @@
package com.google.ai.edge.gallery.ui.llmsingleturn
// import androidx.compose.ui.tooling.preview.Preview
// import com.google.ai.edge.gallery.ui.preview.PreviewLlmSingleTurnViewModel
// import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import android.util.Log
import androidx.activity.compose.BackHandler
import androidx.compose.foundation.background
@ -39,7 +43,6 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalLayoutDirection
import androidx.compose.ui.tooling.preview.Preview
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.ai.edge.gallery.data.ModelDownloadStatusType
import com.google.ai.edge.gallery.ui.ViewModelProvider
@ -48,17 +51,12 @@ import com.google.ai.edge.gallery.ui.common.ModelPageAppBar
import com.google.ai.edge.gallery.ui.common.chat.ModelDownloadStatusInfoPanel
import com.google.ai.edge.gallery.ui.modelmanager.ModelInitializationStatusType
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.PreviewLlmSingleTurnViewModel
import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.customColors
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.serialization.Serializable
/** Navigation destination data */
object LlmSingleTurnDestination {
@Serializable
val route = "LlmSingleTurnRoute"
}
@ -69,9 +67,7 @@ fun LlmSingleTurnScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
viewModel: LlmSingleTurnViewModel = viewModel(
factory = ViewModelProvider.Factory
),
viewModel: LlmSingleTurnViewModel = viewModel(factory = ViewModelProvider.Factory),
) {
val task = viewModel.task
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
@ -95,9 +91,7 @@ fun LlmSingleTurnScreen(
}
// Handle system's edge swipe.
BackHandler {
handleNavigateUp()
}
BackHandler { handleNavigateUp() }
// Initialize model when model/download state changes.
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name]
@ -106,7 +100,7 @@ fun LlmSingleTurnScreen(
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
Log.d(
TAG,
"Initializing model '${selectedModel.name}' from LlmsingleTurnScreen launched effect"
"Initializing model '${selectedModel.name}' from LlmsingleTurnScreen launched effect",
)
modelManagerViewModel.initializeModel(context, task = task, model = selectedModel)
}
@ -118,50 +112,55 @@ fun LlmSingleTurnScreen(
showErrorDialog = modelInitializationStatus?.status == ModelInitializationStatusType.ERROR
}
Scaffold(modifier = modifier, topBar = {
ModelPageAppBar(
task = task,
model = selectedModel,
modelManagerViewModel = modelManagerViewModel,
inProgress = uiState.inProgress,
modelPreparing = uiState.preparing,
onConfigChanged = { _, _ -> },
onBackClicked = { handleNavigateUp() },
onModelSelected = { newSelectedModel ->
scope.launch(Dispatchers.Default) {
// Clean up current model.
modelManagerViewModel.cleanupModel(task = task, model = selectedModel)
Scaffold(
modifier = modifier,
topBar = {
ModelPageAppBar(
task = task,
model = selectedModel,
modelManagerViewModel = modelManagerViewModel,
inProgress = uiState.inProgress,
modelPreparing = uiState.preparing,
onConfigChanged = { _, _ -> },
onBackClicked = { handleNavigateUp() },
onModelSelected = { newSelectedModel ->
scope.launch(Dispatchers.Default) {
// Clean up current model.
modelManagerViewModel.cleanupModel(task = task, model = selectedModel)
// Update selected model.
modelManagerViewModel.selectModel(model = newSelectedModel)
}
}
)
}) { innerPadding ->
Column(
modifier = Modifier.padding(
top = innerPadding.calculateTopPadding(),
start = innerPadding.calculateStartPadding(LocalLayoutDirection.current),
end = innerPadding.calculateStartPadding(LocalLayoutDirection.current),
// Update selected model.
modelManagerViewModel.selectModel(model = newSelectedModel)
}
},
)
},
) { innerPadding ->
Column(
modifier =
Modifier.padding(
top = innerPadding.calculateTopPadding(),
start = innerPadding.calculateStartPadding(LocalLayoutDirection.current),
end = innerPadding.calculateStartPadding(LocalLayoutDirection.current),
)
) {
ModelDownloadStatusInfoPanel(
model = selectedModel,
task = task,
modelManagerViewModel = modelManagerViewModel
modelManagerViewModel = modelManagerViewModel,
)
// Main UI after model is downloaded.
val modelDownloaded = curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED
Box(
contentAlignment = Alignment.BottomCenter,
modifier = Modifier
.weight(1f)
// Just hide the UI without removing it from the screen so that the scroll syncing
// from ResponsePanel still works.
.alpha(if (modelDownloaded) 1.0f else 0.0f)
modifier =
Modifier.weight(1f)
// Just hide the UI without removing it from the screen so that the scroll syncing
// from ResponsePanel still works.
.alpha(if (modelDownloaded) 1.0f else 0.0f),
) {
VerticalSplitView(modifier = Modifier.fillMaxSize(),
VerticalSplitView(
modifier = Modifier.fillMaxSize(),
topView = {
PromptTemplatesPanel(
model = selectedModel,
@ -170,49 +169,47 @@ fun LlmSingleTurnScreen(
onSend = { fullPrompt ->
viewModel.generateResponse(model = selectedModel, input = fullPrompt)
},
onStopButtonClicked = { model ->
viewModel.stopResponse(model = model)
},
modifier = Modifier.fillMaxSize()
onStopButtonClicked = { model -> viewModel.stopResponse(model = model) },
modifier = Modifier.fillMaxSize(),
)
},
bottomView = {
Box(
contentAlignment = Alignment.BottomCenter,
modifier = Modifier
.fillMaxSize()
.background(MaterialTheme.customColors.agentBubbleBgColor)
modifier =
Modifier.fillMaxSize().background(MaterialTheme.customColors.agentBubbleBgColor),
) {
ResponsePanel(
model = selectedModel,
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
modifier = Modifier
.fillMaxSize()
.padding(bottom = innerPadding.calculateBottomPadding())
modifier =
Modifier.fillMaxSize().padding(bottom = innerPadding.calculateBottomPadding()),
)
}
})
},
)
}
if (showErrorDialog) {
ErrorDialog(error = modelInitializationStatus?.error ?: "", onDismiss = {
showErrorDialog = false
})
ErrorDialog(
error = modelInitializationStatus?.error ?: "",
onDismiss = { showErrorDialog = false },
)
}
}
}
}
@Preview(showBackground = true)
@Composable
fun LlmSingleTurnScreenPreview() {
val context = LocalContext.current
GalleryTheme {
LlmSingleTurnScreen(
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
viewModel = PreviewLlmSingleTurnViewModel(),
navigateUp = {},
)
}
}
// @Preview(showBackground = true)
// @Composable
// fun LlmSingleTurnScreenPreview() {
// val context = LocalContext.current
// GalleryTheme {
// LlmSingleTurnScreen(
// modelManagerViewModel = PreviewModelManagerViewModel(context = context),
// viewModel = PreviewLlmSingleTurnViewModel(),
// navigateUp = {},
// )
// }
// }

View file

@ -19,12 +19,12 @@ package com.google.ai.edge.gallery.ui.llmsingleturn
import android.util.Log
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.google.ai.edge.gallery.common.processLlmResponse
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB
import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.ai.edge.gallery.ui.common.chat.Stat
import com.google.ai.edge.gallery.ui.common.processLlmResponse
import com.google.ai.edge.gallery.ui.llmchat.LlmChatModelHelper
import com.google.ai.edge.gallery.ui.llmchat.LlmModelInstance
import kotlinx.coroutines.Dispatchers
@ -34,12 +34,10 @@ import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch
private const val TAG = "AGLlmSingleTurnViewModel"
private const val TAG = "AGLlmSingleTurnVM"
data class LlmSingleTurnUiState(
/**
* Indicates whether the runtime is currently processing a message.
*/
/** Indicates whether the runtime is currently processing a message. */
val inProgress: Boolean = false,
/**
@ -57,12 +55,13 @@ data class LlmSingleTurnUiState(
val selectedPromptTemplateType: PromptTemplateType = PromptTemplateType.entries[0],
)
private val STATS = listOf(
Stat(id = "time_to_first_token", label = "1st token", unit = "sec"),
Stat(id = "prefill_speed", label = "Prefill speed", unit = "tokens/s"),
Stat(id = "decode_speed", label = "Decode speed", unit = "tokens/s"),
Stat(id = "latency", label = "Latency", unit = "sec")
)
private val STATS =
listOf(
Stat(id = "time_to_first_token", label = "1st token", unit = "sec"),
Stat(id = "prefill_speed", label = "Prefill speed", unit = "tokens/s"),
Stat(id = "decode_speed", label = "Decode speed", unit = "tokens/s"),
Stat(id = "latency", label = "Latency", unit = "sec"),
)
open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewModel() {
private val _uiState = MutableStateFlow(createUiState(task = task))
@ -94,7 +93,8 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewMo
val start = System.currentTimeMillis()
var response = ""
var lastBenchmarkUpdateTs = 0L
LlmChatModelHelper.runInference(model = model,
LlmChatModelHelper.runInference(
model = model,
input = input,
resultListener = { partialResult, done ->
val curTs = System.currentTimeMillis()
@ -116,7 +116,7 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewMo
updateResponse(
model = model,
promptTemplateType = uiState.value.selectedPromptTemplateType,
response = response
response = response,
)
// Update benchmark (with throttling).
@ -125,21 +125,23 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewMo
if (decodeSpeed.isNaN()) {
decodeSpeed = 0f
}
val benchmark = ChatMessageBenchmarkLlmResult(
orderedStats = STATS,
statValues = mutableMapOf(
"prefill_speed" to prefillSpeed,
"decode_speed" to decodeSpeed,
"time_to_first_token" to timeToFirstToken,
"latency" to (curTs - start).toFloat() / 1000f,
),
running = !done,
latencyMs = -1f,
)
val benchmark =
ChatMessageBenchmarkLlmResult(
orderedStats = STATS,
statValues =
mutableMapOf(
"prefill_speed" to prefillSpeed,
"decode_speed" to decodeSpeed,
"time_to_first_token" to timeToFirstToken,
"latency" to (curTs - start).toFloat() / 1000f,
),
running = !done,
latencyMs = -1f,
)
updateBenchmark(
model = model,
promptTemplateType = uiState.value.selectedPromptTemplateType,
benchmark = benchmark
benchmark = benchmark,
)
lastBenchmarkUpdateTs = curTs
}
@ -151,7 +153,8 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewMo
cleanUpListener = {
setPreparing(false)
setInProgress(false)
})
},
)
}
}
@ -161,7 +164,9 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewMo
// Clear response.
updateResponse(model = model, promptTemplateType = promptTemplateType, response = "")
this._uiState.update { this.uiState.value.copy(selectedPromptTemplateType = promptTemplateType) }
this._uiState.update {
this.uiState.value.copy(selectedPromptTemplateType = promptTemplateType)
}
}
fun setInProgress(inProgress: Boolean) {
@ -184,7 +189,9 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewMo
}
fun updateBenchmark(
model: Model, promptTemplateType: PromptTemplateType, benchmark: ChatMessageBenchmarkLlmResult
model: Model,
promptTemplateType: PromptTemplateType,
benchmark: ChatMessageBenchmarkLlmResult,
) {
_uiState.update { currentState ->
val currentBenchmark = currentState.benchmarkByModel
@ -218,4 +225,4 @@ open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewMo
benchmarkByModel = benchmarkByModel,
)
}
}
}

View file

@ -16,21 +16,23 @@
package com.google.ai.edge.gallery.ui.llmsingleturn
import androidx.compose.ui.graphics.Brush.Companion.linearGradient
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.text.SpanStyle
import androidx.compose.ui.text.buildAnnotatedString
import androidx.compose.ui.text.withStyle
import androidx.compose.ui.text.SpanStyle
import androidx.compose.ui.graphics.Brush.Companion.linearGradient
enum class PromptTemplateInputEditorType {
SINGLE_SELECT
}
enum class RewriteToneType(val label: String) {
FORMAL(label = "Formal"), CASUAL(label = "Casual"), FRIENDLY(label = "Friendly"), POLITE(label = "Polite"), ENTHUSIASTIC(
label = "Enthusiastic"
),
FORMAL(label = "Formal"),
CASUAL(label = "Casual"),
FRIENDLY(label = "Friendly"),
POLITE(label = "Polite"),
ENTHUSIASTIC(label = "Enthusiastic"),
CONCISE(label = "Concise"),
}
@ -69,51 +71,60 @@ class PromptTemplateSingleSelectInputEditor(
override val label: String,
val options: List<String> = listOf(),
override val defaultOption: String = "",
) : PromptTemplateInputEditor(
label = label, type = PromptTemplateInputEditorType.SINGLE_SELECT, defaultOption = defaultOption
)
) :
PromptTemplateInputEditor(
label = label,
type = PromptTemplateInputEditorType.SINGLE_SELECT,
defaultOption = defaultOption,
)
data class PromptTemplateConfig(val inputEditors: List<PromptTemplateInputEditor> = listOf())
private val GEMINI_GRADIENT_STYLE = SpanStyle(
brush = linearGradient(
colors = listOf(Color(0xFF4285f4), Color(0xFF9b72cb), Color(0xFFd96570))
private val GEMINI_GRADIENT_STYLE =
SpanStyle(
brush = linearGradient(colors = listOf(Color(0xFF4285f4), Color(0xFF9b72cb), Color(0xFFd96570)))
)
)
@Suppress("ImmutableEnum")
enum class PromptTemplateType(
val label: String,
val config: PromptTemplateConfig,
val genFullPrompt: (userInput: String, inputEditorValues: Map<String, Any>) -> AnnotatedString = { _, _ ->
AnnotatedString("")
},
val genFullPrompt: (userInput: String, inputEditorValues: Map<String, Any>) -> AnnotatedString =
{ _, _ ->
AnnotatedString("")
},
val examplePrompts: List<String> = listOf(),
) {
FREE_FORM(
label = "Free form",
config = PromptTemplateConfig(),
genFullPrompt = { userInput, _ -> AnnotatedString(userInput) },
examplePrompts = listOf(
"Suggest 3 topics for a podcast about \"Friendships in your 20s\".",
"Outline the key sections needed in a basic logo design brief.",
"List 3 pros and 3 cons to consider before buying a smart watch.",
"Write a short, optimistic quote about the future of technology.",
"Generate 3 potential names for a mobile app that helps users identify plants.",
"Explain the difference between AI and machine learning in 2 sentences.",
"Create a simple haiku about a cat sleeping in the sun.",
"List 3 ways to make instant noodles taste better using common kitchen ingredients."
)
examplePrompts =
listOf(
"Suggest 3 topics for a podcast about \"Friendships in your 20s\".",
"Outline the key sections needed in a basic logo design brief.",
"List 3 pros and 3 cons to consider before buying a smart watch.",
"Write a short, optimistic quote about the future of technology.",
"Generate 3 potential names for a mobile app that helps users identify plants.",
"Explain the difference between AI and machine learning in 2 sentences.",
"Create a simple haiku about a cat sleeping in the sun.",
"List 3 ways to make instant noodles taste better using common kitchen ingredients.",
),
),
REWRITE_TONE(
label = "Rewrite tone", config = PromptTemplateConfig(
inputEditors = listOf(
PromptTemplateSingleSelectInputEditor(
label = InputEditorLabel.TONE.label,
options = RewriteToneType.entries.map { it.label },
defaultOption = RewriteToneType.FORMAL.label
)
)
), genFullPrompt = { userInput, inputEditorValues ->
label = "Rewrite tone",
config =
PromptTemplateConfig(
inputEditors =
listOf(
PromptTemplateSingleSelectInputEditor(
label = InputEditorLabel.TONE.label,
options = RewriteToneType.entries.map { it.label },
defaultOption = RewriteToneType.FORMAL.label,
)
)
),
genFullPrompt = { userInput, inputEditorValues ->
val tone = inputEditorValues[InputEditorLabel.TONE.label] as String
buildAnnotatedString {
withStyle(GEMINI_GRADIENT_STYLE) {
@ -121,25 +132,29 @@ enum class PromptTemplateType(
}
append(userInput)
}
}, examplePrompts = listOf(
"Hey team, just wanted to remind everyone about the meeting tomorrow @ 10. Be there!",
"Our new software update includes several bug fixes and performance improvements.",
"Due to the fact that the weather was bad, we decided to postpone the event.",
"Please find attached the requested documentation for your perusal.",
"Welcome to the team. Review the onboarding materials.",
)
},
examplePrompts =
listOf(
"Hey team, just wanted to remind everyone about the meeting tomorrow @ 10. Be there!",
"Our new software update includes several bug fixes and performance improvements.",
"Due to the fact that the weather was bad, we decided to postpone the event.",
"Please find attached the requested documentation for your perusal.",
"Welcome to the team. Review the onboarding materials.",
),
),
SUMMARIZE_TEXT(
label = "Summarize text",
config = PromptTemplateConfig(
inputEditors = listOf(
PromptTemplateSingleSelectInputEditor(
label = InputEditorLabel.STYLE.label,
options = SummarizationType.entries.map { it.label },
defaultOption = SummarizationType.KEY_BULLET_POINT.label
)
)
),
config =
PromptTemplateConfig(
inputEditors =
listOf(
PromptTemplateSingleSelectInputEditor(
label = InputEditorLabel.STYLE.label,
options = SummarizationType.entries.map { it.label },
defaultOption = SummarizationType.KEY_BULLET_POINT.label,
)
)
),
genFullPrompt = { userInput, inputEditorValues ->
val style = inputEditorValues[InputEditorLabel.STYLE.label] as String
buildAnnotatedString {
@ -149,37 +164,38 @@ enum class PromptTemplateType(
append(userInput)
}
},
examplePrompts = listOf(
"The new Pixel phone features an advanced camera system with improved low-light performance and AI-powered editing tools. The display is brighter and more energy-efficient. It runs on the latest Tensor chip, offering faster processing and enhanced security features. Battery life has also been extended, providing all-day power for most users.",
"Beginning this Friday, January 24, giant pandas Bao Li and Qing Bao are officially on view to the public at the Smithsonians National Zoo and Conservation Biology Institute (NZCBI). The 3-year-old bears arrived in Washington this past October, undergoing a quarantine period before making their debut. Under NZCBIs new agreement with the CWCA, Qing Bao and Bao Li will remain in the United States for ten years, until April 2034, in exchange for an annual fee of \$1 million. The pair are still too young to breed, as pandas only reach sexual maturity between ages 4 and 7. “Kind of picture them as like awkward teenagers right now,” Lally told WUSA9. “We still have about two years before we would probably even see signs that theyre ready to start mating.”",
),
examplePrompts =
listOf(
"The new Pixel phone features an advanced camera system with improved low-light performance and AI-powered editing tools. The display is brighter and more energy-efficient. It runs on the latest Tensor chip, offering faster processing and enhanced security features. Battery life has also been extended, providing all-day power for most users.",
"Beginning this Friday, January 24, giant pandas Bao Li and Qing Bao are officially on view to the public at the Smithsonians National Zoo and Conservation Biology Institute (NZCBI). The 3-year-old bears arrived in Washington this past October, undergoing a quarantine period before making their debut. Under NZCBIs new agreement with the CWCA, Qing Bao and Bao Li will remain in the United States for ten years, until April 2034, in exchange for an annual fee of \$1 million. The pair are still too young to breed, as pandas only reach sexual maturity between ages 4 and 7. “Kind of picture them as like awkward teenagers right now,” Lally told WUSA9. “We still have about two years before we would probably even see signs that theyre ready to start mating.”",
),
),
CODE_SNIPPET(
label = "Code snippet",
config = PromptTemplateConfig(
inputEditors = listOf(
PromptTemplateSingleSelectInputEditor(
label = InputEditorLabel.LANGUAGE.label,
options = LanguageType.entries.map { it.label },
defaultOption = LanguageType.JAVASCRIPT.label
)
)
),
config =
PromptTemplateConfig(
inputEditors =
listOf(
PromptTemplateSingleSelectInputEditor(
label = InputEditorLabel.LANGUAGE.label,
options = LanguageType.entries.map { it.label },
defaultOption = LanguageType.JAVASCRIPT.label,
)
)
),
genFullPrompt = { userInput, inputEditorValues ->
val language = inputEditorValues[InputEditorLabel.LANGUAGE.label] as String
buildAnnotatedString {
withStyle(GEMINI_GRADIENT_STYLE) {
append("Write a $language code snippet to ")
}
withStyle(GEMINI_GRADIENT_STYLE) { append("Write a $language code snippet to ") }
append(userInput)
}
},
examplePrompts = listOf(
"Create an alert box that says \"Hello, World!\"",
"Declare an immutable variable named 'appName' with the value \"AI Gallery\"",
"Print the numbers from 1 to 5 using a for loop.",
"Write a function that returns the square of an integer input.",
),
examplePrompts =
listOf(
"Create an alert box that says \"Hello, World!\"",
"Declare an immutable variable named 'appName' with the value \"AI Gallery\"",
"Print the numbers from 1 to 5 using a for loop.",
"Write a function that returns the square of an integer input.",
),
),
}

View file

@ -16,6 +16,7 @@
package com.google.ai.edge.gallery.ui.llmsingleturn
import android.content.ClipData
import androidx.compose.foundation.BorderStroke
import androidx.compose.foundation.background
import androidx.compose.foundation.border
@ -79,7 +80,8 @@ import androidx.compose.ui.draw.clip
import androidx.compose.ui.focus.FocusRequester
import androidx.compose.ui.focus.focusRequester
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.platform.LocalClipboardManager
import androidx.compose.ui.platform.ClipEntry
import androidx.compose.ui.platform.LocalClipboard
import androidx.compose.ui.platform.LocalFocusManager
import androidx.compose.ui.res.dimensionResource
import androidx.compose.ui.text.TextLayoutResult
@ -108,7 +110,7 @@ fun PromptTemplatesPanel(
modelManagerViewModel: ModelManagerViewModel,
onSend: (fullPrompt: String) -> Unit,
onStopButtonClicked: (Model) -> Unit,
modifier: Modifier = Modifier
modifier: Modifier = Modifier,
) {
val scope = rememberCoroutineScope()
val uiState by viewModel.uiState.collectAsState()
@ -125,13 +127,12 @@ fun PromptTemplatesPanel(
uiState.selectedPromptTemplateType.genFullPrompt(curTextInputContent, inputEditorValues)
}
}
val clipboardManager = LocalClipboardManager.current
val clipboard = LocalClipboard.current
val focusRequester = remember { FocusRequester() }
val focusManager = LocalFocusManager.current
val interactionSource = remember { MutableInteractionSource() }
val expandedStates = remember { mutableStateMapOf<String, Boolean>() }
val modelInitializationStatus =
modelManagerUiState.modelInitializationStatus[model.name]
val modelInitializationStatus = modelManagerUiState.modelInitializationStatus[model.name]
// Update input editor values when prompt template changes.
LaunchedEffect(selectedPromptTemplateType) {
@ -147,11 +148,10 @@ fun PromptTemplatesPanel(
Column(modifier = modifier) {
// Scrollable tab row for all prompt templates.
PrimaryScrollableTabRow(
selectedTabIndex = selectedTabIndex
) {
PrimaryScrollableTabRow(selectedTabIndex = selectedTabIndex) {
TAB_TITLES.forEachIndexed { index, title ->
Tab(selected = selectedTabIndex == index,
Tab(
selected = selectedTabIndex == index,
enabled = !inProgress,
onClick = {
// Clear input when tab changes.
@ -162,41 +162,34 @@ fun PromptTemplatesPanel(
selectedTabIndex = index
viewModel.selectPromptTemplate(
model = model,
promptTemplateType = promptTemplateTypes[index]
promptTemplateType = promptTemplateTypes[index],
)
},
text = {
Text(
text = title,
modifier = Modifier.alpha(if (inProgress) 0.5f else 1f)
)
})
text = { Text(text = title, modifier = Modifier.alpha(if (inProgress) 0.5f else 1f)) },
)
}
}
// Content.
Column(
modifier = Modifier
.weight(1f)
.fillMaxWidth()
) {
Column(modifier = Modifier.weight(1f).fillMaxWidth()) {
// Input editor row.
if (selectedPromptTemplateType.config.inputEditors.isNotEmpty()) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(8.dp),
modifier = Modifier
.fillMaxWidth()
.background(MaterialTheme.colorScheme.surfaceContainerLow)
.padding(horizontal = 16.dp, vertical = 10.dp)
modifier =
Modifier.fillMaxWidth()
.background(MaterialTheme.colorScheme.surfaceContainerLow)
.padding(horizontal = 16.dp, vertical = 10.dp),
) {
// Input editors.
for (inputEditor in selectedPromptTemplateType.config.inputEditors) {
when (inputEditor.type) {
PromptTemplateInputEditorType.SINGLE_SELECT -> SingleSelectButton(config = inputEditor as PromptTemplateSingleSelectInputEditor,
onSelected = { option ->
inputEditorValues[inputEditor.label] = option
})
PromptTemplateInputEditorType.SINGLE_SELECT ->
SingleSelectButton(
config = inputEditor as PromptTemplateSingleSelectInputEditor,
onSelected = { option -> inputEditorValues[inputEditor.label] = option },
)
}
}
}
@ -205,12 +198,10 @@ fun PromptTemplatesPanel(
// Text input box.
Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.weight(1f)) {
Column(
modifier = Modifier
.fillMaxSize()
.verticalScroll(rememberScrollState())
.clickable(
modifier =
Modifier.fillMaxSize().verticalScroll(rememberScrollState()).clickable(
interactionSource = interactionSource,
indication = null // Disable the ripple effect
indication = null, // Disable the ripple effect
) {
// Request focus on the TextField when the Column is clicked
focusRequester.requestFocus()
@ -220,32 +211,31 @@ fun PromptTemplatesPanel(
Text(
fullPrompt,
style = MaterialTheme.typography.bodyMedium,
modifier = Modifier
.fillMaxWidth()
.padding(16.dp)
.padding(bottom = 40.dp)
.clip(MessageBubbleShape(radius = bubbleBorderRadius))
.background(MaterialTheme.customColors.agentBubbleBgColor)
.padding(16.dp)
.focusRequester(focusRequester)
modifier =
Modifier.fillMaxWidth()
.padding(16.dp)
.padding(bottom = 40.dp)
.clip(MessageBubbleShape(radius = bubbleBorderRadius))
.background(MaterialTheme.customColors.agentBubbleBgColor)
.padding(16.dp)
.focusRequester(focusRequester),
)
} else {
TextField(
value = curTextInputContent,
onValueChange = { curTextInputContent = it },
colors = TextFieldDefaults.colors(
unfocusedContainerColor = Color.Transparent,
focusedContainerColor = Color.Transparent,
focusedIndicatorColor = Color.Transparent,
unfocusedIndicatorColor = Color.Transparent,
disabledIndicatorColor = Color.Transparent,
disabledContainerColor = Color.Transparent,
),
colors =
TextFieldDefaults.colors(
unfocusedContainerColor = Color.Transparent,
focusedContainerColor = Color.Transparent,
focusedIndicatorColor = Color.Transparent,
unfocusedIndicatorColor = Color.Transparent,
disabledIndicatorColor = Color.Transparent,
disabledContainerColor = Color.Transparent,
),
textStyle = MaterialTheme.typography.bodyLarge,
placeholder = { Text("Enter content") },
modifier = Modifier
.padding(bottom = 40.dp)
.focusRequester(focusRequester)
modifier = Modifier.padding(bottom = 40.dp).focusRequester(focusRequester),
)
}
}
@ -254,26 +244,35 @@ fun PromptTemplatesPanel(
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(4.dp),
modifier = Modifier
.fillMaxWidth()
.padding(vertical = 4.dp, horizontal = 16.dp)
modifier = Modifier.fillMaxWidth().padding(vertical = 4.dp, horizontal = 16.dp),
) {
// Full prompt switch.
if (selectedPromptTemplateType != PromptTemplateType.FREE_FORM && curTextInputContent.isNotEmpty()) {
Row(verticalAlignment = Alignment.CenterVertically,
if (
selectedPromptTemplateType != PromptTemplateType.FREE_FORM &&
curTextInputContent.isNotEmpty()
) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(4.dp),
modifier = Modifier
.clip(CircleShape)
.background(if (inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean) MaterialTheme.colorScheme.secondaryContainer else MaterialTheme.customColors.agentBubbleBgColor)
.clickable {
inputEditorValues[FULL_PROMPT_SWITCH_KEY] =
!(inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean)
}
.height(40.dp)
.border(
width = 1.dp, color = MaterialTheme.colorScheme.surface, shape = CircleShape
)
.padding(horizontal = 12.dp)) {
modifier =
Modifier.clip(CircleShape)
.background(
if (inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean)
MaterialTheme.colorScheme.secondaryContainer
else MaterialTheme.customColors.agentBubbleBgColor
)
.clickable {
inputEditorValues[FULL_PROMPT_SWITCH_KEY] =
!(inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean)
}
.height(40.dp)
.border(
width = 1.dp,
color = MaterialTheme.colorScheme.surface,
shape = CircleShape,
)
.padding(horizontal = 12.dp),
) {
if (inputEditorValues[FULL_PROMPT_SWITCH_KEY] as Boolean) {
Icon(
imageVector = Icons.Rounded.Visibility,
@ -284,9 +283,7 @@ fun PromptTemplatesPanel(
Icon(
imageVector = Icons.Rounded.VisibilityOff,
contentDescription = "",
modifier = Modifier
.size(FilterChipDefaults.IconSize)
.alpha(0.3f),
modifier = Modifier.size(FilterChipDefaults.IconSize).alpha(0.3f),
)
}
Text("Preview prompt", style = MaterialTheme.typography.labelMedium)
@ -299,20 +296,27 @@ fun PromptTemplatesPanel(
if (curTextInputContent.isNotEmpty()) {
OutlinedIconButton(
onClick = {
val clipData = fullPrompt
clipboardManager.setText(clipData)
scope.launch {
val clipData = ClipData.newPlainText("prompt", fullPrompt)
val clipEntry = ClipEntry(clipData = clipData)
clipboard.setClipEntry(clipEntry = clipEntry)
}
},
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.customColors.agentBubbleBgColor,
disabledContainerColor = MaterialTheme.customColors.agentBubbleBgColor.copy(alpha = 0.4f),
contentColor = MaterialTheme.colorScheme.primary,
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.2f),
),
colors =
IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.customColors.agentBubbleBgColor,
disabledContainerColor =
MaterialTheme.customColors.agentBubbleBgColor.copy(alpha = 0.4f),
contentColor = MaterialTheme.colorScheme.primary,
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.2f),
),
border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface),
modifier = Modifier.size(ICON_BUTTON_SIZE)
modifier = Modifier.size(ICON_BUTTON_SIZE),
) {
Icon(
Icons.Outlined.ContentCopy, contentDescription = "", modifier = Modifier.size(20.dp)
Icons.Outlined.ContentCopy,
contentDescription = "",
modifier = Modifier.size(20.dp),
)
}
}
@ -321,38 +325,35 @@ fun PromptTemplatesPanel(
OutlinedIconButton(
enabled = !inProgress,
onClick = { showExamplePromptBottomSheet = true },
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.customColors.agentBubbleBgColor,
disabledContainerColor = MaterialTheme.customColors.agentBubbleBgColor.copy(alpha = 0.4f),
contentColor = MaterialTheme.colorScheme.primary,
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.2f),
),
colors =
IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.customColors.agentBubbleBgColor,
disabledContainerColor =
MaterialTheme.customColors.agentBubbleBgColor.copy(alpha = 0.4f),
contentColor = MaterialTheme.colorScheme.primary,
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.2f),
),
border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface),
modifier = Modifier.size(ICON_BUTTON_SIZE)
modifier = Modifier.size(ICON_BUTTON_SIZE),
) {
Icon(
Icons.Rounded.Add,
contentDescription = "",
modifier = Modifier.size(20.dp),
)
Icon(Icons.Rounded.Add, contentDescription = "", modifier = Modifier.size(20.dp))
}
val modelInitializing =
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING
if (inProgress && !modelInitializing && !uiState.preparing) {
IconButton(
onClick = {
onStopButtonClicked(model)
},
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.secondaryContainer,
),
modifier = Modifier.size(ICON_BUTTON_SIZE)
onClick = { onStopButtonClicked(model) },
colors =
IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.secondaryContainer
),
modifier = Modifier.size(ICON_BUTTON_SIZE),
) {
Icon(
Icons.Rounded.Stop,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary
tint = MaterialTheme.colorScheme.primary,
)
}
} else {
@ -363,21 +364,21 @@ fun PromptTemplatesPanel(
focusManager.clearFocus()
onSend(fullPrompt.text)
},
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.secondaryContainer,
disabledContainerColor = MaterialTheme.colorScheme.secondaryContainer.copy(alpha = 0.3f),
contentColor = MaterialTheme.colorScheme.primary,
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.1f),
),
colors =
IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.secondaryContainer,
disabledContainerColor =
MaterialTheme.colorScheme.secondaryContainer.copy(alpha = 0.3f),
contentColor = MaterialTheme.colorScheme.primary,
disabledContentColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.1f),
),
border = BorderStroke(width = 1.dp, color = MaterialTheme.colorScheme.surface),
modifier = Modifier.size(ICON_BUTTON_SIZE)
modifier = Modifier.size(ICON_BUTTON_SIZE),
) {
Icon(
Icons.AutoMirrored.Rounded.Send,
contentDescription = "",
modifier = Modifier
.size(20.dp)
.offset(x = 2.dp),
modifier = Modifier.size(20.dp).offset(x = 2.dp),
)
}
}
@ -396,89 +397,82 @@ fun PromptTemplatesPanel(
// Title
Text(
"Select an example",
modifier = Modifier
.fillMaxWidth()
.padding(16.dp),
style = MaterialTheme.typography.titleLarge
modifier = Modifier.fillMaxWidth().padding(16.dp),
style = MaterialTheme.typography.titleLarge,
)
// Examples
for (prompt in selectedPromptTemplateType.examplePrompts) {
var textLayoutResultState by remember { mutableStateOf<TextLayoutResult?>(null) }
val hasOverflow = remember(textLayoutResultState) {
textLayoutResultState?.hasVisualOverflow ?: false
}
val hasOverflow =
remember(textLayoutResultState) { textLayoutResultState?.hasVisualOverflow ?: false }
val isExpanded = expandedStates[prompt] ?: false
Column(
modifier = Modifier
.fillMaxWidth()
.clickable {
curTextInputContent = prompt
scope.launch {
// Give it sometime to show the click effect.
delay(200)
showExamplePromptBottomSheet = false
modifier =
Modifier.fillMaxWidth()
.clickable {
curTextInputContent = prompt
scope.launch {
// Give it sometime to show the click effect.
delay(200)
showExamplePromptBottomSheet = false
}
}
}
.padding(horizontal = 16.dp, vertical = 8.dp),
.padding(horizontal = 16.dp, vertical = 8.dp)
) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(8.dp),
) {
Icon(Icons.Outlined.Description, contentDescription = "")
Text(prompt,
Text(
prompt,
maxLines = if (isExpanded) Int.MAX_VALUE else 3,
overflow = TextOverflow.Ellipsis,
style = MaterialTheme.typography.bodySmall,
modifier = Modifier.weight(1f),
onTextLayout = { textLayoutResultState = it }
onTextLayout = { textLayoutResultState = it },
)
}
if (hasOverflow && !isExpanded) {
Row(
modifier = Modifier
.fillMaxWidth()
.padding(top = 2.dp),
horizontalArrangement = Arrangement.End
modifier = Modifier.fillMaxWidth().padding(top = 2.dp),
horizontalArrangement = Arrangement.End,
) {
Box(modifier = Modifier
.padding(end = 16.dp)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.surfaceContainerHighest)
.clickable {
expandedStates[prompt] = true
}
.padding(vertical = 1.dp, horizontal = 6.dp)) {
Box(
modifier =
Modifier.padding(end = 16.dp)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.surfaceContainerHighest)
.clickable { expandedStates[prompt] = true }
.padding(vertical = 1.dp, horizontal = 6.dp)
) {
Icon(
Icons.Outlined.ExpandMore,
contentDescription = "",
modifier = Modifier.size(12.dp)
modifier = Modifier.size(12.dp),
)
}
}
} else if (isExpanded) {
Row(
modifier = Modifier
.fillMaxWidth()
.padding(top = 2.dp),
horizontalArrangement = Arrangement.End
modifier = Modifier.fillMaxWidth().padding(top = 2.dp),
horizontalArrangement = Arrangement.End,
) {
Box(modifier = Modifier
.padding(end = 16.dp)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.surfaceContainerHighest)
.clickable {
expandedStates[prompt] = false
}
.padding(vertical = 1.dp, horizontal = 6.dp)) {
Box(
modifier =
Modifier.padding(end = 16.dp)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.surfaceContainerHighest)
.clickable { expandedStates[prompt] = false }
.padding(vertical = 1.dp, horizontal = 6.dp)
) {
Icon(
Icons.Outlined.ExpandLess,
contentDescription = "",
modifier = Modifier.size(12.dp)
modifier = Modifier.size(12.dp),
)
}
}

View file

@ -16,6 +16,7 @@
package com.google.ai.edge.gallery.ui.llmsingleturn
import android.content.ClipData
import android.util.Log
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
@ -49,24 +50,26 @@ import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.runtime.snapshotFlow
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.platform.LocalClipboardManager
import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.platform.ClipEntry
import androidx.compose.ui.platform.LocalClipboard
import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB
import com.google.ai.edge.gallery.ui.common.chat.MarkdownText
import com.google.ai.edge.gallery.ui.common.MarkdownText
import com.google.ai.edge.gallery.ui.common.chat.MessageBodyBenchmarkLlm
import com.google.ai.edge.gallery.ui.common.chat.MessageBodyLoading
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.modelmanager.PagerScrollState
import kotlinx.coroutines.launch
private val OPTIONS = listOf("Response", "Benchmark")
private val ICONS = listOf(Icons.Outlined.AutoAwesome, Icons.Outlined.Timer)
@ -88,23 +91,21 @@ fun ResponsePanel(
val selectedPromptTemplateType = uiState.selectedPromptTemplateType
val responseScrollState = rememberScrollState()
var selectedOptionIndex by remember { mutableIntStateOf(0) }
val clipboardManager = LocalClipboardManager.current
val pagerState = rememberPagerState(
initialPage = task.models.indexOf(model),
pageCount = { task.models.size })
val clipboard = LocalClipboard.current
val scope = rememberCoroutineScope()
val pagerState =
rememberPagerState(initialPage = task.models.indexOf(model), pageCount = { task.models.size })
val accelerator = model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = "")
// Select the "response" tab when prompt template changes.
LaunchedEffect(selectedPromptTemplateType) {
selectedOptionIndex = 0
}
LaunchedEffect(selectedPromptTemplateType) { selectedOptionIndex = 0 }
// Update selected model and clean up previous model when page is settled on a model page.
LaunchedEffect(pagerState.settledPage) {
val curSelectedModel = task.models[pagerState.settledPage]
Log.d(
TAG,
"Pager settled on model '${curSelectedModel.name}' from '${model.name}'. Updating selected model."
"Pager settled on model '${curSelectedModel.name}' from '${model.name}'. Updating selected model.",
)
if (curSelectedModel.name != model.name) {
modelManagerViewModel.cleanupModel(task = task, model = model)
@ -115,13 +116,12 @@ fun ResponsePanel(
// Trigger scroll sync.
LaunchedEffect(pagerState) {
snapshotFlow {
PagerScrollState(
page = pagerState.currentPage,
offset = pagerState.currentPageOffsetFraction
)
}.collect { scrollState ->
modelManagerViewModel.pagerScrollState.value = scrollState
}
PagerScrollState(
page = pagerState.currentPage,
offset = pagerState.currentPageOffsetFraction,
)
}
.collect { scrollState -> modelManagerViewModel.pagerScrollState.value = scrollState }
}
// Scroll pager when selected model changes.
@ -147,9 +147,7 @@ fun ResponsePanel(
if (initializing) {
Box(
contentAlignment = Alignment.TopStart,
modifier = modifier
.fillMaxSize()
.padding(horizontal = 16.dp)
modifier = modifier.fillMaxSize().padding(horizontal = 16.dp),
) {
MessageBodyLoading()
}
@ -159,7 +157,7 @@ fun ResponsePanel(
Row(
modifier = Modifier.fillMaxSize(),
horizontalArrangement = Arrangement.Center,
verticalAlignment = Alignment.CenterVertically
verticalAlignment = Alignment.CenterVertically,
) {
Text(
"Response will appear here",
@ -170,11 +168,7 @@ fun ResponsePanel(
}
// Response markdown.
else {
Column(
modifier = modifier
.padding(horizontal = 16.dp)
.padding(bottom = 4.dp)
) {
Column(modifier = modifier.padding(horizontal = 16.dp).padding(bottom = 4.dp)) {
// Response/benchmark switch.
Row(modifier = Modifier.fillMaxWidth()) {
PrimaryTabRow(
@ -182,66 +176,64 @@ fun ResponsePanel(
containerColor = Color.Transparent,
) {
OPTIONS.forEachIndexed { index, title ->
Tab(selected = selectedOptionIndex == index, onClick = {
selectedOptionIndex = index
}, text = {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(4.dp)
) {
Icon(
ICONS[index],
contentDescription = "",
modifier = Modifier
.size(16.dp)
.alpha(0.7f)
)
var curTitle = title
if (accelerator.isNotEmpty()) {
curTitle = "$curTitle on $accelerator"
}
val titleColor = MaterialTheme.colorScheme.primary
BasicText(
text = curTitle,
maxLines = 1,
color = { titleColor },
style = MaterialTheme.typography.bodyMedium,
autoSize = TextAutoSize.StepBased(
minFontSize = 9.sp,
maxFontSize = 14.sp,
stepSize = 1.sp
Tab(
selected = selectedOptionIndex == index,
onClick = { selectedOptionIndex = index },
text = {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(4.dp),
) {
Icon(
ICONS[index],
contentDescription = "",
modifier = Modifier.size(16.dp).alpha(0.7f),
)
)
}
})
var curTitle = title
if (accelerator.isNotEmpty()) {
curTitle = "$curTitle on $accelerator"
}
val titleColor = MaterialTheme.colorScheme.primary
BasicText(
text = curTitle,
maxLines = 1,
color = { titleColor },
style = MaterialTheme.typography.bodyMedium,
autoSize =
TextAutoSize.StepBased(
minFontSize = 9.sp,
maxFontSize = 14.sp,
stepSize = 1.sp,
),
)
}
},
)
}
}
}
if (selectedOptionIndex == 0) {
Box(
contentAlignment = Alignment.BottomEnd,
modifier = Modifier.weight(1f)
) {
Column(
modifier = Modifier
.fillMaxSize()
.verticalScroll(responseScrollState)
) {
Box(contentAlignment = Alignment.BottomEnd, modifier = Modifier.weight(1f)) {
Column(modifier = Modifier.fillMaxSize().verticalScroll(responseScrollState)) {
MarkdownText(
text = response,
modifier = Modifier.padding(top = 8.dp, bottom = 40.dp)
modifier = Modifier.padding(top = 8.dp, bottom = 40.dp),
)
}
// Copy button.
IconButton(
onClick = {
val clipData = AnnotatedString(response)
clipboardManager.setText(clipData)
scope.launch {
val clipData = ClipData.newPlainText("response", response)
val clipEntry = ClipEntry(clipData = clipData)
clipboard.setClipEntry(clipEntry = clipEntry)
}
},
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.surfaceContainerHighest,
contentColor = MaterialTheme.colorScheme.primary,
),
colors =
IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.surfaceContainerHighest,
contentColor = MaterialTheme.colorScheme.primary,
),
) {
Icon(
Icons.Outlined.ContentCopy,

Some files were not shown because too many files have changed in this diff Show more