Initial checkin

This commit is contained in:
Jing Jin 2025-04-14 16:42:40 -07:00
parent 9f4376e65c
commit ea31fd0544
152 changed files with 14171 additions and 1 deletions

1
.gitignore vendored Normal file
View file

@ -0,0 +1 @@
.DS_Store

35
Android/.gitignore vendored Normal file
View file

@ -0,0 +1,35 @@
# Gradle files
.gradle/
build/
# Local configuration file (sdk path, etc)
local.properties
# Log/OS Files
*.log
# Android Studio generated files and folders
captures/
.externalNativeBuild/
.cxx/
*.apk
output.json
# IntelliJ
*.iml
.idea/
misc.xml
deploymentTargetDropDown.xml
render.experimental.xml
# Keystore files
*.jks
*.keystore
# Google Services (e.g. APIs or Firebase)
google-services.json
# Android Profiling
*.hprof
.DS_Store

1
Android/README.md Normal file
View file

@ -0,0 +1 @@
# AI Edge Gallery (Android)

15
Android/src/.gitignore vendored Normal file
View file

@ -0,0 +1,15 @@
*.iml
.gradle
/local.properties
/.idea/caches
/.idea/libraries
/.idea/modules.xml
/.idea/workspace.xml
/.idea/navEditor.xml
/.idea/assetWizardSettings.xml
.DS_Store
/build
/captures
.externalNativeBuild
.cxx
local.properties

1
Android/src/app/.gitignore vendored Normal file
View file

@ -0,0 +1 @@
/build

View file

@ -0,0 +1,101 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
plugins {
alias(libs.plugins.android.application)
alias(libs.plugins.kotlin.android)
alias(libs.plugins.kotlin.compose)
alias(libs.plugins.kotlin.serialization)
}
android {
namespace = "com.google.aiedge.gallery"
compileSdk = 35
defaultConfig {
applicationId = "com.google.aiedge.gallery"
minSdk = 24
targetSdk = 35
versionCode = 1
versionName = "1.0"
// Needed for HuggingFace auth workflows.
manifestPlaceholders["appAuthRedirectScheme"] = "com.google.aiedge.gallery.oauth"
testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
}
buildTypes {
release {
isMinifyEnabled = false
proguardFiles(
getDefaultProguardFile("proguard-android-optimize.txt"),
"proguard-rules.pro"
)
signingConfig = signingConfigs.getByName("debug")
}
}
compileOptions {
sourceCompatibility = JavaVersion.VERSION_11
targetCompatibility = JavaVersion.VERSION_11
}
kotlinOptions {
jvmTarget = "11"
freeCompilerArgs += "-Xcontext-receivers"
}
buildFeatures {
compose = true
}
}
dependencies {
implementation(libs.androidx.core.ktx)
implementation(libs.androidx.lifecycle.runtime.ktx)
implementation(libs.androidx.activity.compose)
implementation(platform(libs.androidx.compose.bom))
implementation(libs.androidx.ui)
implementation(libs.androidx.ui.graphics)
implementation(libs.androidx.ui.tooling.preview)
implementation(libs.androidx.material3)
implementation(libs.androidx.compose.navigation)
implementation(libs.kotlinx.serialization.json)
implementation(libs.material.icon.extended)
implementation(libs.androidx.work.runtime)
implementation(libs.androidx.datastore.preferences)
implementation(libs.com.google.code.gson)
implementation(libs.androidx.lifecycle.process)
implementation(libs.mediapipe.tasks.text)
implementation(libs.mediapipe.tasks.genai)
implementation(libs.mediapipe.tasks.imagegen)
implementation(libs.commonmark)
implementation(libs.richtext)
implementation(libs.tflite)
implementation(libs.tflite.gpu)
implementation(libs.tflite.support)
implementation(libs.camerax.core)
implementation(libs.camerax.camera2)
implementation(libs.camerax.lifecycle)
implementation(libs.camerax.view)
implementation(libs.openid.appauth)
implementation(libs.androidx.splashscreen)
testImplementation(libs.junit)
androidTestImplementation(libs.androidx.junit)
androidTestImplementation(libs.androidx.espresso.core)
androidTestImplementation(platform(libs.androidx.compose.bom))
androidTestImplementation(libs.androidx.ui.test.junit4)
debugImplementation(libs.androidx.ui.tooling)
debugImplementation(libs.androidx.ui.test.manifest)
}

21
Android/src/app/proguard-rules.pro vendored Normal file
View file

@ -0,0 +1,21 @@
# Add project specific ProGuard rules here.
# You can control the set of applied configuration files using the
# proguardFiles setting in build.gradle.
#
# For more details, see
# http://developer.android.com/guide/developing/tools/proguard.html
# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
# public *;
#}
# Uncomment this to preserve the line number information for
# debugging stack traces.
#-keepattributes SourceFile,LineNumberTable
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile

View file

@ -0,0 +1,40 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery
import androidx.test.platform.app.InstrumentationRegistry
import androidx.test.ext.junit.runners.AndroidJUnit4
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.Assert.*
/**
* Instrumented test, which will execute on an Android device.
*
* See [testing documentation](http://d.android.com/tools/testing).
*/
@RunWith(AndroidJUnit4::class)
class ExampleInstrumentedTest {
@Test
fun useAppContext() {
// Context of the app under test.
val appContext = InstrumentationRegistry.getInstrumentation().targetContext
assertEquals("com.google.aiedge.gallery", appContext.packageName)
}
}

View file

@ -0,0 +1,83 @@
<?xml version="1.0" encoding="utf-8"?>
<!--
Copyright 2025 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools">
<uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.POST_NOTIFICATIONS" />
<uses-permission android:name="android.permission.CAMERA" />
<uses-feature
android:name="android.hardware.camera"
android:required="false" />
<application
android:name=".GalleryApplication"
android:allowBackup="true"
android:dataExtractionRules="@xml/data_extraction_rules"
android:fullBackupContent="@xml/backup_rules"
android:icon="@mipmap/ic_launcher"
android:label="@string/app_name"
android:roundIcon="@mipmap/ic_launcher"
android:supportsRtl="true"
android:theme="@style/Theme.Gallery"
tools:targetApi="31">
<activity
android:name=".MainActivity"
android:exported="true"
android:theme="@style/Theme.Gallery.SplashScreen"
android:windowSoftInputMode="adjustResize">
<!-- This is for putting the app into launcher -->
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
<!-- This is for deep linking -->
<intent-filter>
<action android:name="android.intent.action.VIEW" />
<category android:name="android.intent.category.DEFAULT" />
<category android:name="android.intent.category.BROWSABLE" />
<data android:scheme="com.google.aiedge.gallery" />
</intent-filter>
</activity>
<uses-native-library
android:name="libOpenCL.so"
android:required="false" />
<uses-native-library
android:name="libOpenCL-car.so"
android:required="false" />
<uses-native-library
android:name="libOpenCL-pixel.so"
android:required="false" />
<provider
android:name="androidx.core.content.FileProvider"
android:authorities="${applicationId}.provider"
android:exported="false"
android:grantUriPermissions="true">
<meta-data
android:name="android.support.FILE_PROVIDER_PATHS"
android:resource="@xml/file_paths" />
</provider>
</application>
</manifest>

View file

@ -0,0 +1,192 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
@file:OptIn(ExperimentalMaterial3Api::class)
package com.google.aiedge.gallery
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.ArrowBack
import androidx.compose.material.icons.rounded.Refresh
import androidx.compose.material.icons.rounded.Settings
import androidx.compose.material3.CenterAlignedTopAppBar
import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.material3.TopAppBarScrollBehavior
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.dp
import androidx.navigation.NavHostController
import androidx.navigation.compose.rememberNavController
import com.google.aiedge.gallery.data.AppBarAction
import com.google.aiedge.gallery.data.AppBarActionType
import com.google.aiedge.gallery.ui.navigation.GalleryNavHost
/**
* Top level composable representing the main screen of the application.
*/
@Composable
fun GalleryApp(navController: NavHostController = rememberNavController()) {
GalleryNavHost(navController = navController)
}
/**
* The top app bar.
*/
@Composable
fun GalleryTopAppBar(
title: String,
modifier: Modifier = Modifier,
leftAction: AppBarAction? = null,
rightAction: AppBarAction? = null,
scrollBehavior: TopAppBarScrollBehavior? = null,
loadingHfModels: Boolean = false,
subtitle: String = "",
) {
CenterAlignedTopAppBar(
title = {
Column(horizontalAlignment = Alignment.CenterHorizontally) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(12.dp)
) {
if (title == stringResource(R.string.app_name)) {
Icon(
painterResource(R.drawable.logo),
modifier = Modifier.size(20.dp),
contentDescription = "",
tint = Color.Unspecified,
)
}
Text(
title,
style = MaterialTheme.typography.titleLarge.copy(fontWeight = FontWeight.SemiBold)
)
}
if (subtitle.isNotEmpty()) {
Text(
subtitle,
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.secondary
)
}
}
},
modifier = modifier,
scrollBehavior = scrollBehavior,
// The button at the left.
navigationIcon = {
when (leftAction?.actionType) {
AppBarActionType.NAVIGATE_UP -> {
IconButton(onClick = leftAction.actionFn) {
Icon(
imageVector = Icons.AutoMirrored.Rounded.ArrowBack,
contentDescription = "",
)
}
}
AppBarActionType.REFRESH_MODELS -> {
IconButton(onClick = leftAction.actionFn) {
Icon(
imageVector = Icons.Rounded.Refresh,
contentDescription = "",
tint = MaterialTheme.colorScheme.secondary
)
}
}
AppBarActionType.REFRESHING_MODELS -> {
CircularProgressIndicator(
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
strokeWidth = 3.dp,
modifier = Modifier
.padding(start = 16.dp)
.size(20.dp)
)
}
else -> {}
}
},
// The "action" component at the right.
actions = {
when (rightAction?.actionType) {
// Click an icon to open "app setting".
AppBarActionType.APP_SETTING -> {
IconButton(onClick = rightAction.actionFn) {
Icon(
imageVector = Icons.Rounded.Settings,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary
)
}
}
// Click an icon to open "download manager".
AppBarActionType.DOWNLOAD_MANAGER -> {
if (loadingHfModels) {
CircularProgressIndicator(
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
strokeWidth = 3.dp,
modifier = Modifier
.padding(end = 12.dp)
.size(20.dp)
)
}
// else {
// IconButton(onClick = rightAction.actionFn) {
// Icon(
// imageVector = Deployed_code,
// contentDescription = "",
// tint = MaterialTheme.colorScheme.primary
// )
// }
// }
}
AppBarActionType.MODEL_SELECTOR -> {
Text("ms")
}
// Click a button to navigate up.
AppBarActionType.NAVIGATE_UP -> {
TextButton(onClick = rightAction.actionFn) {
Text("Done")
}
}
else -> {}
}
}
)
}

View file

@ -0,0 +1,51 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery
import android.app.Application
import android.content.Context
import androidx.datastore.core.DataStore
import androidx.datastore.preferences.core.Preferences
import androidx.datastore.preferences.preferencesDataStore
import com.google.aiedge.gallery.data.AppContainer
import com.google.aiedge.gallery.data.DefaultAppContainer
import com.google.aiedge.gallery.data.TASKS
import com.google.aiedge.gallery.ui.theme.ThemeSettings
private val Context.dataStore: DataStore<Preferences> by preferencesDataStore(name = "app_gallery_preferences")
class GalleryApplication : Application() {
/** AppContainer instance used by the rest of classes to obtain dependencies */
lateinit var container: AppContainer
override fun onCreate() {
super.onCreate()
// Process tasks.
for ((index, task) in TASKS.withIndex()) {
task.index = index
for (model in task.models) {
model.preProcess(task = task)
}
}
container = DefaultAppContainer(this, dataStore)
// Load theme.
ThemeSettings.themeOverride.value = container.dataStoreRepository.readThemeOverride()
}
}

View file

@ -0,0 +1,44 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery
import androidx.lifecycle.DefaultLifecycleObserver
import androidx.lifecycle.LifecycleOwner
import androidx.lifecycle.ProcessLifecycleOwner
interface AppLifecycleProvider {
val isAppInForeground: Boolean
}
class GalleryLifecycleProvider : AppLifecycleProvider, DefaultLifecycleObserver {
private var _isAppInForeground = false
init {
ProcessLifecycleOwner.get().lifecycle.addObserver(this)
}
override val isAppInForeground: Boolean
get() = _isAppInForeground
override fun onResume(owner: LifecycleOwner) {
_isAppInForeground = true
}
override fun onPause(owner: LifecycleOwner) {
_isAppInForeground = false
}
}

View file

@ -0,0 +1,45 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery
import android.os.Bundle
import androidx.activity.ComponentActivity
import androidx.activity.compose.setContent
import androidx.activity.enableEdgeToEdge
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.material3.Surface
import androidx.compose.ui.Modifier
import androidx.core.splashscreen.SplashScreen.Companion.installSplashScreen
import com.google.aiedge.gallery.ui.theme.GalleryTheme
class MainActivity : ComponentActivity() {
override fun onCreate(savedInstanceState: Bundle?) {
installSplashScreen()
super.onCreate(savedInstanceState)
enableEdgeToEdge()
setContent {
GalleryTheme {
Surface(
modifier = Modifier.fillMaxSize()
) {
GalleryApp()
}
}
}
}
}

View file

@ -0,0 +1,19 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery
const val VERSION = "20250413"

View file

@ -0,0 +1,30 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.data
/** Possible action for app bar. */
enum class AppBarActionType {
NO_ACTION,
APP_SETTING,
DOWNLOAD_MANAGER,
MODEL_SELECTOR,
NAVIGATE_UP,
REFRESH_MODELS,
REFRESHING_MODELS,
}
class AppBarAction(val actionType: AppBarActionType, val actionFn: () -> Unit)

View file

@ -0,0 +1,47 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.data
import android.content.Context
import androidx.datastore.core.DataStore
import androidx.datastore.preferences.core.Preferences
import com.google.aiedge.gallery.GalleryLifecycleProvider
import com.google.aiedge.gallery.AppLifecycleProvider
/**
* App container for Dependency injection.
*
* This interface defines the dependencies required by the application.
*/
interface AppContainer {
val context: Context
val lifecycleProvider: AppLifecycleProvider
val dataStoreRepository: DataStoreRepository
val downloadRepository: DownloadRepository
}
/**
* Default implementation of the AppContainer interface.
*
* This class provides concrete implementations for the application's dependencies,
*/
class DefaultAppContainer(ctx: Context, dataStore: DataStore<Preferences>) : AppContainer {
override val context = ctx
override val lifecycleProvider = GalleryLifecycleProvider()
override val dataStoreRepository = DefaultDataStoreRepository(dataStore)
override val downloadRepository = DefaultDownloadRepository(ctx, lifecycleProvider)
}

View file

@ -0,0 +1,107 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.data
/**
* The types of configuration editors available.
*
* This enum defines the different UI components used to edit configuration values.
* Each type corresponds to a specific editor widget, such as a slider or a switch.
*/
enum class ConfigEditorType {
NUMBER_SLIDER,
BOOLEAN_SWITCH,
DROPDOWN,
}
/**
* The data types of configuration values.
*/
enum class ValueType {
INT,
FLOAT,
DOUBLE,
STRING,
BOOLEAN,
}
/**
* Base class for configuration settings.
*
* @param type The type of configuration editor.
* @param key The unique key for the configuration setting.
* @param defaultValue The default value for the configuration setting.
* @param valueType The data type of the configuration value.
* @param needReinitialization Indicates whether the model needs to be reinitialized after changing
* this config.
*/
open class Config(
val type: ConfigEditorType,
open val key: ConfigKey,
open val defaultValue: Any,
open val valueType: ValueType,
open val needReinitialization: Boolean = true,
)
/**
* Configuration setting for a number slider.
*
* @param sliderMin The minimum value of the slider.
* @param sliderMax The maximum value of the slider.
*/
class NumberSliderConfig(
override val key: ConfigKey,
val sliderMin: Float,
val sliderMax: Float,
override val defaultValue: Float,
override val valueType: ValueType,
override val needReinitialization: Boolean = true,
) :
Config(
type = ConfigEditorType.NUMBER_SLIDER,
key = key,
defaultValue = defaultValue,
valueType = valueType
)
/**
* Configuration setting for a boolean switch.
*/
class BooleanSwitchConfig(
override val key: ConfigKey,
override val defaultValue: Boolean,
override val needReinitialization: Boolean = true,
) : Config(
type = ConfigEditorType.BOOLEAN_SWITCH,
key = key,
defaultValue = defaultValue,
valueType = ValueType.BOOLEAN,
)
/**
* Configuration setting for a dropdown.
*/
class SegmentedButtonConfig(
override val key: ConfigKey,
override val defaultValue: String,
val options: List<String>,
) : Config(
type = ConfigEditorType.DROPDOWN,
key = key,
defaultValue = defaultValue,
valueType = ValueType.STRING,
)

View file

@ -0,0 +1,32 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.data
// Keys used to send/receive data to Work.
const val KEY_MODEL_URL = "KEY_MODEL_URL"
const val KEY_MODEL_DOWNLOAD_FILE_NAME = "KEY_MODEL_DOWNLOAD_FILE_NAME"
const val KEY_MODEL_TOTAL_BYTES = "KEY_MODEL_TOTAL_BYTES"
const val KEY_MODEL_DOWNLOAD_RECEIVED_BYTES = "KEY_MODEL_DOWNLOAD_RECEIVED_BYTES"
const val KEY_MODEL_DOWNLOAD_RATE = "KEY_MODEL_DOWNLOAD_RATE"
const val KEY_MODEL_DOWNLOAD_REMAINING_MS = "KEY_MODEL_DOWNLOAD_REMAINING_SECONDS"
const val KEY_MODEL_DOWNLOAD_ERROR_MESSAGE = "KEY_MODEL_DOWNLOAD_ERROR_MESSAGE"
const val KEY_MODEL_DOWNLOAD_ACCESS_TOKEN = "KEY_MODEL_DOWNLOAD_ACCESS_TOKEN"
const val KEY_MODEL_EXTRA_DATA_URLS = "KEY_MODEL_EXTRA_DATA_URLS"
const val KEY_MODEL_EXTRA_DATA_DOWNLOAD_FILE_NAMES = "KEY_MODEL_EXTRA_DATA_DOWNLOAD_FILE_NAMES"
const val KEY_MODEL_IS_ZIP = "KEY_MODEL_IS_ZIP"
const val KEY_MODEL_UNZIPPED_DIR = "KEY_MODEL_UNZIPPED_DIR"
const val KEY_MODEL_START_UNZIPPING = "KEY_MODEL_START_UNZIPPING"

View file

@ -0,0 +1,208 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.data
import android.security.keystore.KeyGenParameterSpec
import android.security.keystore.KeyProperties
import android.util.Base64
import androidx.datastore.core.DataStore
import androidx.datastore.preferences.core.Preferences
import androidx.datastore.preferences.core.edit
import androidx.datastore.preferences.core.longPreferencesKey
import androidx.datastore.preferences.core.stringPreferencesKey
import com.google.gson.Gson
import com.google.gson.reflect.TypeToken
import com.google.aiedge.gallery.ui.theme.THEME_AUTO
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.runBlocking
import java.security.KeyStore
import javax.crypto.Cipher
import javax.crypto.KeyGenerator
import javax.crypto.SecretKey
data class AccessTokenData(
val accessToken: String,
val refreshToken: String,
val expiresAtSeconds: Long
)
interface DataStoreRepository {
fun saveTextInputHistory(history: List<String>)
fun readTextInputHistory(): List<String>
fun saveThemeOverride(theme: String)
fun readThemeOverride(): String
fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long)
fun readAccessTokenData(): AccessTokenData?
}
/**
* Repository for managing data using DataStore, with JSON serialization.
*
* This class provides methods to read, add, remove, and clear data stored in DataStore,
* using JSON serialization for complex objects. It uses Gson for serializing and deserializing
* lists of objects to/from JSON strings.
*
* DataStore is used to persist data as JSON strings under specified keys.
*/
class DefaultDataStoreRepository(
private val dataStore: DataStore<Preferences>
) :
DataStoreRepository {
private object PreferencesKeys {
val TEXT_INPUT_HISTORY = stringPreferencesKey("text_input_history")
val THEME_OVERRIDE = stringPreferencesKey("theme_override")
val ENCRYPTED_ACCESS_TOKEN = stringPreferencesKey("encrypted_access_token")
// Store Initialization Vector
val ACCESS_TOKEN_IV = stringPreferencesKey("access_token_iv")
val ENCRYPTED_REFRESH_TOKEN = stringPreferencesKey("encrypted_refresh_token")
// Store Initialization Vector
val REFRESH_TOKEN_IV = stringPreferencesKey("refresh_token_iv")
val ACCESS_TOKEN_EXPIRES_AT = longPreferencesKey("access_token_expires_at")
}
private val keystoreAlias: String = "com_google_aiedge_gallery_access_token_key"
private val keyStore: KeyStore = KeyStore.getInstance("AndroidKeyStore").apply { load(null) }
override fun saveTextInputHistory(history: List<String>) {
runBlocking {
dataStore.edit { preferences ->
val gson = Gson()
val jsonString = gson.toJson(history)
preferences[PreferencesKeys.TEXT_INPUT_HISTORY] = jsonString
}
}
}
override fun readTextInputHistory(): List<String> {
return runBlocking {
val preferences = dataStore.data.first()
getTextInputHistory(preferences)
}
}
override fun saveThemeOverride(theme: String) {
runBlocking {
dataStore.edit { preferences ->
preferences[PreferencesKeys.THEME_OVERRIDE] = theme
}
}
}
override fun readThemeOverride(): String {
return runBlocking {
val preferences = dataStore.data.first()
preferences[PreferencesKeys.THEME_OVERRIDE] ?: THEME_AUTO
}
}
override fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long) {
runBlocking {
val (encryptedAccessToken, accessTokenIv) = encrypt(accessToken)
val (encryptedRefreshToken, refreshTokenIv) = encrypt(refreshToken)
dataStore.edit { preferences ->
preferences[PreferencesKeys.ENCRYPTED_ACCESS_TOKEN] = encryptedAccessToken
preferences[PreferencesKeys.ACCESS_TOKEN_IV] = accessTokenIv
preferences[PreferencesKeys.ENCRYPTED_REFRESH_TOKEN] = encryptedRefreshToken
preferences[PreferencesKeys.REFRESH_TOKEN_IV] = refreshTokenIv
preferences[PreferencesKeys.ACCESS_TOKEN_EXPIRES_AT] = expiresAt
}
}
}
override fun readAccessTokenData(): AccessTokenData? {
return runBlocking {
val preferences = dataStore.data.first()
val encryptedAccessToken = preferences[PreferencesKeys.ENCRYPTED_ACCESS_TOKEN]
val encryptedRefreshToken = preferences[PreferencesKeys.ENCRYPTED_REFRESH_TOKEN]
val accessTokenIv = preferences[PreferencesKeys.ACCESS_TOKEN_IV]
val refreshTokenIv = preferences[PreferencesKeys.REFRESH_TOKEN_IV]
val expiresAt = preferences[PreferencesKeys.ACCESS_TOKEN_EXPIRES_AT]
var decryptedAccessToken: String? = null
var decryptedRefreshToken: String? = null
if (encryptedAccessToken != null && accessTokenIv != null) {
decryptedAccessToken = decrypt(encryptedAccessToken, accessTokenIv)
}
if (encryptedRefreshToken != null && refreshTokenIv != null) {
decryptedRefreshToken = decrypt(encryptedRefreshToken, refreshTokenIv)
}
if (decryptedAccessToken != null && decryptedRefreshToken != null && expiresAt != null) {
AccessTokenData(decryptedAccessToken, decryptedRefreshToken, expiresAt)
} else {
null
}
}
}
private fun getTextInputHistory(preferences: Preferences): List<String> {
val infosStr = preferences[PreferencesKeys.TEXT_INPUT_HISTORY] ?: "[]"
val gson = Gson()
val listType = object : TypeToken<List<String>>() {}.type
return gson.fromJson(infosStr, listType)
}
private fun getOrCreateSecretKey(): SecretKey {
return (keyStore.getKey(keystoreAlias, null) as? SecretKey) ?: run {
val keyGenerator =
KeyGenerator.getInstance(KeyProperties.KEY_ALGORITHM_AES, "AndroidKeyStore")
val keyGenParameterSpec = KeyGenParameterSpec.Builder(
keystoreAlias,
KeyProperties.PURPOSE_ENCRYPT or KeyProperties.PURPOSE_DECRYPT
)
.setBlockModes(KeyProperties.BLOCK_MODE_GCM)
.setEncryptionPaddings(KeyProperties.ENCRYPTION_PADDING_NONE)
.setUserAuthenticationRequired(false) // Consider setting to true for added security
.build()
keyGenerator.init(keyGenParameterSpec)
keyGenerator.generateKey()
}
}
private fun encrypt(plainText: String): Pair<String, String> {
val secretKey = getOrCreateSecretKey()
val cipher = Cipher.getInstance("AES/GCM/NoPadding")
cipher.init(Cipher.ENCRYPT_MODE, secretKey)
val iv = cipher.iv
val encryptedBytes = cipher.doFinal(plainText.toByteArray(Charsets.UTF_8))
return Base64.encodeToString(encryptedBytes, Base64.DEFAULT) to Base64.encodeToString(
iv,
Base64.DEFAULT
)
}
private fun decrypt(encryptedText: String, ivText: String): String? {
val secretKey = getOrCreateSecretKey()
val cipher = Cipher.getInstance("AES/GCM/NoPadding")
val ivBytes = Base64.decode(ivText, Base64.DEFAULT)
val spec = javax.crypto.spec.GCMParameterSpec(128, ivBytes) // 128 bit tag length
cipher.init(Cipher.DECRYPT_MODE, secretKey, spec)
val encryptedBytes = Base64.decode(encryptedText, Base64.DEFAULT)
return try {
String(cipher.doFinal(encryptedBytes), Charsets.UTF_8)
} catch (e: Exception) {
// Handle decryption errors (e.g., key not found)
null
}
}
}

View file

@ -0,0 +1,312 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.data
import android.Manifest
import android.app.NotificationChannel
import android.app.NotificationManager
import android.app.PendingIntent
import android.content.Context
import android.content.Intent
import android.content.pm.PackageManager
import android.net.Uri
import android.os.Build
import android.util.Log
import androidx.core.app.ActivityCompat
import androidx.core.app.NotificationCompat
import androidx.core.app.NotificationManagerCompat
import androidx.work.Data
import androidx.work.ExistingWorkPolicy
import androidx.work.OneTimeWorkRequestBuilder
import androidx.work.Operation
import androidx.work.OutOfQuotaPolicy
import androidx.work.WorkInfo
import androidx.work.WorkManager
import androidx.work.WorkQuery
import com.google.common.util.concurrent.FutureCallback
import com.google.common.util.concurrent.Futures
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.MoreExecutors
import com.google.aiedge.gallery.AppLifecycleProvider
import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.worker.DownloadWorker
import java.util.UUID
private const val TAG = "AGDownloadRepository"
private const val MODEL_NAME_TAG = "modelName"
data class AGWorkInfo(val modelName: String, val workId: String)
interface DownloadRepository {
fun downloadModel(
model: Model, onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit
)
fun cancelDownloadModel(model: Model)
fun cancelAll(models: List<Model>, onComplete: () -> Unit)
fun observerWorkerProgress(
workerId: UUID,
model: Model,
onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit,
)
fun getEnqueuedOrRunningWorkInfos(): List<AGWorkInfo>
}
/**
* Repository for managing model downloads using WorkManager.
*
* This class provides methods to initiate model downloads, cancel downloads, observe download
* progress, and retrieve information about enqueued or running download tasks. It utilizes
* WorkManager to handle background download operations.
*/
class DefaultDownloadRepository(
private val context: Context,
private val lifecycleProvider: AppLifecycleProvider,
) : DownloadRepository {
private val workManager = WorkManager.getInstance(context)
override fun downloadModel(
model: Model, onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit
) {
// Create input data.
val builder = Data.Builder()
val totalBytes = model.totalBytes + model.extraDataFiles.sumOf { it.sizeInBytes }
val inputDataBuilder = builder.putString(KEY_MODEL_URL, model.url)
.putString(KEY_MODEL_DOWNLOAD_FILE_NAME, model.downloadFileName)
.putBoolean(KEY_MODEL_IS_ZIP, model.isZip).putString(KEY_MODEL_UNZIPPED_DIR, model.unzipDir)
.putLong(
KEY_MODEL_TOTAL_BYTES, totalBytes
)
if (model.extraDataFiles.isNotEmpty()) {
inputDataBuilder.putString(
KEY_MODEL_EXTRA_DATA_URLS, model.extraDataFiles.joinToString(",") { it.url }
).putString(
KEY_MODEL_EXTRA_DATA_DOWNLOAD_FILE_NAMES,
model.extraDataFiles.joinToString(",") { it.downloadFileName }
)
}
if (model.accessToken != null) {
inputDataBuilder.putString(KEY_MODEL_DOWNLOAD_ACCESS_TOKEN, model.accessToken)
}
val inputData = inputDataBuilder.build()
// Create worker request.
val downloadWorkRequest =
OneTimeWorkRequestBuilder<DownloadWorker>().setExpedited(OutOfQuotaPolicy.RUN_AS_NON_EXPEDITED_WORK_REQUEST)
.setInputData(inputData).addTag("${MODEL_NAME_TAG}:${model.name}").build()
val workerId = downloadWorkRequest.id
// Start!
workManager.enqueueUniqueWork(
model.name, ExistingWorkPolicy.REPLACE, downloadWorkRequest
)
// Observe progress.
observerWorkerProgress(
workerId = workerId, model = model, onStatusUpdated = onStatusUpdated
)
}
override fun cancelDownloadModel(model: Model) {
workManager.cancelAllWorkByTag("${MODEL_NAME_TAG}:${model.name}")
}
override fun cancelAll(models: List<Model>, onComplete: () -> Unit) {
if (models.isEmpty()) {
onComplete()
return
}
val futures = mutableListOf<ListenableFuture<Operation.State.SUCCESS>>()
for (tag in models.map { "${MODEL_NAME_TAG}:${it.name}" }) {
futures.add(workManager.cancelAllWorkByTag(tag).result)
}
val combinedFuture: ListenableFuture<List<Operation.State.SUCCESS>> = Futures.allAsList(futures)
Futures.addCallback(
combinedFuture, object : FutureCallback<List<Operation.State.SUCCESS>> {
override fun onSuccess(result: List<Operation.State.SUCCESS>?) {
// All cancellations are complete
onComplete()
}
override fun onFailure(t: Throwable) {
// At least one cancellation failed
t.printStackTrace()
onComplete()
}
}, MoreExecutors.directExecutor()
)
}
override fun observerWorkerProgress(
workerId: UUID,
model: Model,
onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit,
) {
workManager.getWorkInfoByIdLiveData(workerId).observeForever { workInfo ->
if (workInfo != null) {
when (workInfo.state) {
WorkInfo.State.RUNNING -> {
val receivedBytes = workInfo.progress.getLong(KEY_MODEL_DOWNLOAD_RECEIVED_BYTES, 0L)
val downloadRate = workInfo.progress.getLong(KEY_MODEL_DOWNLOAD_RATE, 0L)
val remainingSeconds = workInfo.progress.getLong(KEY_MODEL_DOWNLOAD_REMAINING_MS, 0L)
val startUnzipping = workInfo.progress.getBoolean(KEY_MODEL_START_UNZIPPING, false)
if (!startUnzipping) {
if (receivedBytes != 0L) {
onStatusUpdated(
model, ModelDownloadStatus(
status = ModelDownloadStatusType.IN_PROGRESS,
totalBytes = model.totalBytes,
receivedBytes = receivedBytes,
bytesPerSecond = downloadRate,
remainingMs = remainingSeconds,
)
)
}
} else {
onStatusUpdated(
model, ModelDownloadStatus(
status = ModelDownloadStatusType.UNZIPPING,
)
)
}
}
WorkInfo.State.SUCCEEDED -> {
Log.d("repo", "worker %s success".format(workerId.toString()))
onStatusUpdated(
model, ModelDownloadStatus(
status = ModelDownloadStatusType.SUCCEEDED,
)
)
sendNotification(
title = context.getString(
R.string.notification_title_success
),
text = context.getString(R.string.notification_content_success).format(model.name),
modelName = model.name,
)
}
WorkInfo.State.FAILED, WorkInfo.State.CANCELLED -> {
var status = ModelDownloadStatusType.FAILED
val errorMessage = workInfo.outputData.getString(KEY_MODEL_DOWNLOAD_ERROR_MESSAGE) ?: ""
Log.d(
"repo", "worker %s FAILED or CANCELLED: %s".format(workerId.toString(), errorMessage)
)
if (workInfo.state == WorkInfo.State.CANCELLED) {
status = ModelDownloadStatusType.NOT_DOWNLOADED
} else {
sendNotification(
title = context.getString(R.string.notification_title_fail),
text = context.getString(R.string.notification_content_success).format(model.name),
modelName = "",
)
}
onStatusUpdated(
model, ModelDownloadStatus(status = status, errorMessage = errorMessage)
)
}
else -> {}
}
}
}
}
/**
* Retrieves a list of AGWorkInfo objects representing WorkManager work items that are either
* enqueued or currently running.
*/
override fun getEnqueuedOrRunningWorkInfos(): List<AGWorkInfo> {
val workQuery =
WorkQuery.Builder.fromStates(listOf(WorkInfo.State.ENQUEUED, WorkInfo.State.RUNNING)).build()
return workManager.getWorkInfos(workQuery).get().map { info ->
val tags = info.tags
var modelName = ""
Log.d(TAG, "work: ${info.id}, tags: $tags")
for (tag in tags) {
if (tag.startsWith("${MODEL_NAME_TAG}:")) {
val index = tag.indexOf(':')
if (index >= 0) {
modelName = tag.substring(index + 1)
break
}
}
}
return@map AGWorkInfo(modelName = modelName, workId = info.id.toString())
}
}
private fun sendNotification(title: String, text: String, modelName: String) {
// Don't send notification if app is in foreground.
if (lifecycleProvider.isAppInForeground) {
return
}
val channelId = "download_notification"
val channelName = "AI Edge Gallery download notification"
// Create the NotificationChannel, but only on API 26+ because
// the NotificationChannel class is new and not in the support library
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
val importance = NotificationManager.IMPORTANCE_HIGH
val channel = NotificationChannel(channelId, channelName, importance)
val notificationManager: NotificationManager =
context.getSystemService(Context.NOTIFICATION_SERVICE) as NotificationManager
notificationManager.createNotificationChannel(channel)
}
// Create an Intent to open your app with a deep link.
val intent = Intent(
Intent.ACTION_VIEW,
Uri.parse("com.google.aiedge.gallery://model/${modelName}")
).apply {
flags = Intent.FLAG_ACTIVITY_NEW_TASK
}
// Create a PendingIntent
val pendingIntent: PendingIntent = PendingIntent.getActivity(
context, 0, intent, PendingIntent.FLAG_UPDATE_CURRENT or PendingIntent.FLAG_IMMUTABLE
)
val builder = NotificationCompat.Builder(context, channelId)
// TODO: replace icon.
.setSmallIcon(android.R.drawable.ic_dialog_info).setContentTitle(title).setContentText(text)
.setPriority(NotificationCompat.PRIORITY_HIGH).setContentIntent(pendingIntent)
.setAutoCancel(true)
with(NotificationManagerCompat.from(context)) {
// notificationId is a unique int for each notification that you must define
if (ActivityCompat.checkSelfPermission(
context, Manifest.permission.POST_NOTIFICATIONS
) != PackageManager.PERMISSION_GRANTED
) {
// Permission not granted, return or handle accordingly. In real app, request permission.
return
}
notify(1, builder.build())
}
}
}

View file

@ -0,0 +1,175 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.data
import com.google.aiedge.gallery.ui.llmchat.createLLmChatConfig
import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
import kotlinx.serialization.SerializationException
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.buildClassSerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.json.JsonDecoder
import kotlinx.serialization.json.JsonPrimitive
@Serializable
data class HfModelSummary(val modelId: String)
@Serializable
data class HfModelDetails(val id: String, val siblings: List<HfModelFile>)
@Serializable
data class HfModelFile(val rfilename: String)
@Serializable(with = ConfigValueSerializer::class)
sealed class ConfigValue {
@Serializable
data class IntValue(val value: Int) : ConfigValue()
@Serializable
data class FloatValue(val value: Float) : ConfigValue()
@Serializable
data class StringValue(val value: String) : ConfigValue()
}
/**
* Custom serializer for the ConfigValue class.
*
* This object implements the KSerializer interface to provide custom serialization and
* deserialization logic for the ConfigValue class. It handles different types of ConfigValue
* (IntValue, FloatValue, StringValue) and supports JSON format.
*/
object ConfigValueSerializer : KSerializer<ConfigValue> {
override val descriptor: SerialDescriptor = buildClassSerialDescriptor("ConfigValue")
override fun serialize(encoder: Encoder, value: ConfigValue) {
when (value) {
is ConfigValue.IntValue -> encoder.encodeInt(value.value)
is ConfigValue.FloatValue -> encoder.encodeFloat(value.value)
is ConfigValue.StringValue -> encoder.encodeString(value.value)
}
}
override fun deserialize(decoder: Decoder): ConfigValue {
val input = decoder as? JsonDecoder
?: throw SerializationException("This serializer only works with Json")
return when (val element = input.decodeJsonElement()) {
is JsonPrimitive -> {
if (element.isString) {
ConfigValue.StringValue(element.content)
} else if (element.content.contains('.')) {
ConfigValue.FloatValue(element.content.toFloat())
} else {
ConfigValue.IntValue(element.content.toInt())
}
}
else -> throw SerializationException("Expected JsonPrimitive")
}
}
}
@Serializable
data class HfModel(
var id: String = "",
val task: String,
val name: String,
val url: String = "",
val file: String = "",
val sizeInBytes: Long,
val configs: Map<String, ConfigValue>,
) {
fun toModel(): Model {
val parts = if (url.isNotEmpty()) {
url.split('/')
} else if (file.isNotEmpty()) {
listOf(file)
} else {
listOf("")
}
val fileName = "${id}_${(parts.lastOrNull() ?: "")}".replace(Regex("[^a-zA-Z0-9._-]"), "_")
// Generate configs based on the given default values.
val configs: List<Config> = when (task) {
TASK_LLM_CHAT.type.label -> createLLmChatConfig(defaults = configs)
// todo: add configs for other types.
else -> listOf()
}
// Construct url.
var modelUrl = url
if (modelUrl.isEmpty() && file.isNotEmpty()) {
modelUrl = "https://huggingface.co/${id}/resolve/main/${file}?download=true"
}
// Other parameters.
val showBenchmarkButton = when (task) {
TASK_LLM_CHAT.type.label -> false
else -> true
}
val showRunAgainButton = when (task) {
TASK_LLM_CHAT.type.label -> false
else -> true
}
return Model(
hfModelId = id,
name = name,
url = modelUrl,
sizeInBytes = sizeInBytes,
downloadFileName = fileName,
configs = configs,
showBenchmarkButton = showBenchmarkButton,
showRunAgainButton = showRunAgainButton,
)
}
}
fun getIntConfigValue(configValue: ConfigValue?, default: Int): Int {
if (configValue == null) {
return default
}
return when (configValue) {
is ConfigValue.IntValue -> configValue.value
is ConfigValue.FloatValue -> configValue.value.toInt()
is ConfigValue.StringValue -> 0
}
}
fun getFloatConfigValue(configValue: ConfigValue?, default: Float): Float {
if (configValue == null) {
return default
}
return when (configValue) {
is ConfigValue.IntValue -> configValue.value.toFloat()
is ConfigValue.FloatValue -> configValue.value
is ConfigValue.StringValue -> 0f
}
}
fun getStringConfigValue(configValue: ConfigValue?, default: String): String {
if (configValue == null) {
return default
}
return when (configValue) {
is ConfigValue.IntValue -> "${configValue.value}"
is ConfigValue.FloatValue -> "${configValue.value}"
is ConfigValue.StringValue -> configValue.value
}
}

View file

@ -0,0 +1,376 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.data
import android.content.Context
import com.google.aiedge.gallery.ui.common.chat.PromptTemplate
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
import com.google.aiedge.gallery.ui.llmchat.createLlmChatConfigs
data class ModelDataFile(
val name: String,
val url: String,
val downloadFileName: String,
val sizeInBytes: Long,
)
enum class LlmBackend {
CPU, GPU
}
/** A model for a task */
data class Model(
/** The Hugging Face model ID (if applicable). */
val hfModelId: String = "",
/** The name (for display purpose) of the model. */
val name: String,
/** The name of the downloaded model file. */
val downloadFileName: String,
/** The URL to download the model from. */
val url: String,
/** The size of the model file in bytes. */
val sizeInBytes: Long,
/** A list of additional data files required by the model. */
val extraDataFiles: List<ModelDataFile> = listOf(),
/**
* A description or information about the model.
*
* Will be shown at the start of the chat session and in the expanded model item.
*/
val info: String = "",
/**
* The url to jump to when clicking "learn more" in expanded model item.
*/
val learnMoreUrl: String = "",
/** A list of configurable parameters for the model. */
val configs: List<Config> = listOf(),
/** Whether to show the "run again" button in the UI. */
val showRunAgainButton: Boolean = true,
/** Whether to show the "benchmark" button in the UI. */
val showBenchmarkButton: Boolean = true,
/** Indicates whether the model is a zip file. */
val isZip: Boolean = false,
/** The name of the directory to unzip the model to (if it's a zip file). */
val unzipDir: String = "",
/** The preferred backend of the model (only for LLM). */
val llmBackend: LlmBackend = LlmBackend.GPU,
/** The prompt templates for the model (only for LLM). */
val llmPromptTemplates: List<PromptTemplate> = listOf(),
// The following fields are managed by the app. Don't need to set manually.
var taskType: TaskType? = null,
var instance: Any? = null,
var initializing: Boolean = false,
var configValues: Map<String, Any> = mapOf(),
var totalBytes: Long = 0L,
var accessToken: String? = null,
) {
fun preProcess(task: Task) {
this.taskType = task.type
val configValues: MutableMap<String, Any> = mutableMapOf()
for (config in this.configs) {
configValues[config.key.label] = config.defaultValue
}
this.configValues = configValues
this.totalBytes = this.sizeInBytes + this.extraDataFiles.sumOf { it.sizeInBytes }
}
fun getPath(context: Context, fileName: String = downloadFileName): String {
return if (this.isZip && this.unzipDir.isNotEmpty()) {
"${context.getExternalFilesDir(null)}/${this.unzipDir}"
} else {
"${context.getExternalFilesDir(null)}/${fileName}"
}
}
fun getIntConfigValue(key: ConfigKey, defaultValue: Int = 0): Int {
return getTypedConfigValue(
key = key, valueType = ValueType.INT, defaultValue = defaultValue
) as Int
}
fun getFloatConfigValue(key: ConfigKey, defaultValue: Float = 0.0f): Float {
return getTypedConfigValue(
key = key, valueType = ValueType.FLOAT, defaultValue = defaultValue
) as Float
}
fun getBooleanConfigValue(key: ConfigKey, defaultValue: Boolean = false): Boolean {
return getTypedConfigValue(
key = key, valueType = ValueType.BOOLEAN, defaultValue = defaultValue
) as Boolean
}
fun getExtraDataFile(name: String): ModelDataFile? {
return extraDataFiles.find { it.name == name }
}
private fun getTypedConfigValue(key: ConfigKey, valueType: ValueType, defaultValue: Any): Any {
return convertValueToTargetType(
value = configValues.getOrDefault(key.label, defaultValue), valueType = valueType
)
}
}
enum class ModelDownloadStatusType {
NOT_DOWNLOADED, PARTIALLY_DOWNLOADED, IN_PROGRESS, UNZIPPING, SUCCEEDED, FAILED,
}
data class ModelDownloadStatus(
val status: ModelDownloadStatusType,
val totalBytes: Long = 0,
val receivedBytes: Long = 0,
val errorMessage: String = "",
val bytesPerSecond: Long = 0,
val remainingMs: Long = 0,
)
////////////////////////////////////////////////////////////////////////////////////////////////////
// Configs.
enum class ConfigKey(val label: String, val id: String) {
MAX_TOKENS("Max tokens", id = "max_token"),
TOPK("TopK", id = "topk"),
TOPP(
"TopP",
id = "topp"
),
TEMPERATURE("Temperature", id = "temperature"),
MAX_RESULT_COUNT(
"Max result count",
id = "max_result_count"
),
USE_GPU("Use GPU", id = "use_gpu"),
WARM_UP_ITERATIONS(
"Warm up iterations",
id = "warm_up_iterations"
),
BENCHMARK_ITERATIONS(
"Benchmark iterations",
id = "benchmark_iterations"
),
ITERATIONS("Iterations", id = "iterations"),
THEME("Theme", id = "theme"),
}
val MOBILENET_CONFIGS: List<Config> = listOf(
NumberSliderConfig(
key = ConfigKey.MAX_RESULT_COUNT,
sliderMin = 1f,
sliderMax = 5f,
defaultValue = 3f,
valueType = ValueType.INT
), BooleanSwitchConfig(
key = ConfigKey.USE_GPU,
defaultValue = false,
)
)
val IMAGE_GENERATION_CONFIGS: List<Config> = listOf(
NumberSliderConfig(
key = ConfigKey.ITERATIONS,
sliderMin = 5f,
sliderMax = 50f,
defaultValue = 10f,
valueType = ValueType.INT,
needReinitialization = false,
)
)
const val TEXT_CLASSIFICATION_INFO =
"Model is trained on movie reviews dataset. Type a movie review below and see the scores of positive or negative sentiment."
const val TEXT_CLASSIFICATION_LEARN_MORE_URL =
"https://ai.google.dev/edge/mediapipe/solutions/text/text_classifier"
const val IMAGE_CLASSIFICATION_INFO = ""
const val IMAGE_CLASSIFICATION_LEARN_MORE_URL = "https://ai.google.dev/edge/litert/android"
const val LLM_CHAT_INFO =
"Some description about this large language model. A community org for developers to discover models that are ready for deployment to edge platforms"
const val LLM_CHAT_LEARN_MORE_URL =
"https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android"
const val IMAGE_GENERATION_INFO =
"Powered by [MediaPipe Image Generation API](https://ai.google.dev/edge/mediapipe/solutions/vision/image_generator/android)"
////////////////////////////////////////////////////////////////////////////////////////////////////
// Model spec.
val MODEL_LLM_GEMMA_2B_GPU_INT4: Model = Model(
name = "Gemma 2B (GPU int4)",
downloadFileName = "gemma-2b-it-gpu-int4.bin",
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma-2b-it-gpu-int4.bin",
sizeInBytes = 1354301440L,
configs = createLlmChatConfigs(),
info = LLM_CHAT_INFO,
learnMoreUrl = LLM_CHAT_LEARN_MORE_URL,
)
val MODEL_LLM_GEMMA_2_2B_GPU_INT8: Model = Model(
name = "Gemma 2 2B (GPU int8)",
downloadFileName = "gemma2-2b-it-gpu-int8.bin",
url = "https://storage.googleapis.com/tfweb/app_gallery_models/gemma2-2b-it-gpu-int8.bin",
sizeInBytes = 2627141632L,
configs = createLlmChatConfigs(),
info = LLM_CHAT_INFO,
learnMoreUrl = LLM_CHAT_LEARN_MORE_URL,
)
val MODEL_LLM_GEMMA_3_1B_INT4: Model = Model(
name = "Gemma 3 1B (int4)",
downloadFileName = "gemma3-1b-it-int4.task",
url = "https://huggingface.co/litert-community/Gemma3-1B-IT/resolve/main/gemma3-1b-it-int4.task?download=true",
sizeInBytes = 554661243L,
configs = createLlmChatConfigs(defaultTopK = 64, defaultTopP = 0.95f),
info = LLM_CHAT_INFO,
learnMoreUrl = LLM_CHAT_LEARN_MORE_URL,
llmPromptTemplates = listOf(
PromptTemplate(
title = "Emoji Fun",
description = "Generate emojis by emotions",
prompt = "Show me emojis grouped by emotions"
),
PromptTemplate(
title = "Trip Planner",
description = "Plan a trip to a destination",
prompt = "Plan a two-day trip to San Francisco"
),
)
)
val MODEL_LLM_DEEPSEEK: Model = Model(
name = "Deepseek",
downloadFileName = "deepseek.task",
url = "https://huggingface.co/litert-community/DeepSeek-R1-Distill-Qwen-1.5B/resolve/main/deepseek_q8_ekv1280.task?download=true",
sizeInBytes = 1860686856L,
llmBackend = LlmBackend.CPU,
configs = createLlmChatConfigs(defaultTemperature = 0.6f, defaultTopK = 40, defaultTopP = 0.7f),
info = LLM_CHAT_INFO,
learnMoreUrl = LLM_CHAT_LEARN_MORE_URL,
)
val MODEL_TEXT_CLASSIFICATION_MOBILEBERT: Model = Model(
name = "MobileBert",
downloadFileName = "bert_classifier.tflite",
url = "https://storage.googleapis.com/mediapipe-models/text_classifier/bert_classifier/float32/latest/bert_classifier.tflite",
sizeInBytes = 25707538L,
info = TEXT_CLASSIFICATION_INFO,
learnMoreUrl = TEXT_CLASSIFICATION_LEARN_MORE_URL,
)
val MODEL_TEXT_CLASSIFICATION_AVERAGE_WORD_EMBEDDING: Model = Model(
name = "Average word embedding",
downloadFileName = "average_word_classifier.tflite",
url = "https://storage.googleapis.com/mediapipe-models/text_classifier/average_word_classifier/float32/latest/average_word_classifier.tflite",
sizeInBytes = 775708L,
info = TEXT_CLASSIFICATION_INFO,
)
val MODEL_IMAGE_CLASSIFICATION_MOBILENET_V1: Model = Model(
name = "Mobilenet V1",
downloadFileName = "mobilenet_v1.tflite",
url = "https://storage.googleapis.com/tfweb/app_gallery_models/mobilenet_v1.tflite",
sizeInBytes = 16900760L,
extraDataFiles = listOf(
ModelDataFile(
name = "labels",
url = "https://raw.githubusercontent.com/leferrad/tensorflow-mobilenet/refs/heads/master/imagenet/labels.txt",
downloadFileName = "mobilenet_labels_v1.txt",
sizeInBytes = 21685L
),
),
configs = MOBILENET_CONFIGS,
info = IMAGE_CLASSIFICATION_INFO,
learnMoreUrl = IMAGE_CLASSIFICATION_LEARN_MORE_URL,
)
val MODEL_IMAGE_CLASSIFICATION_MOBILENET_V2: Model = Model(
name = "Mobilenet V2",
downloadFileName = "mobilenet_v2.tflite",
url = "https://storage.googleapis.com/tfweb/app_gallery_models/mobilenet_v2.tflite",
sizeInBytes = 13978596L,
extraDataFiles = listOf(
ModelDataFile(
name = "labels",
url = "https://raw.githubusercontent.com/leferrad/tensorflow-mobilenet/refs/heads/master/imagenet/labels.txt",
downloadFileName = "mobilenet_labels_v2.txt",
sizeInBytes = 21685L
),
),
configs = MOBILENET_CONFIGS,
info = IMAGE_CLASSIFICATION_INFO,
)
val MODEL_IMAGE_GENERATION_STABLE_DIFFUSION: Model = Model(
name = "Stable diffusion",
downloadFileName = "sd15.zip",
isZip = true,
unzipDir = "sd15",
url = "https://storage.googleapis.com/tfweb/app_gallery_models/sd15.zip",
sizeInBytes = 1906219565L,
showRunAgainButton = false,
showBenchmarkButton = false,
info = IMAGE_GENERATION_INFO,
configs = IMAGE_GENERATION_CONFIGS,
)
val EMPTY_MODEL: Model = Model(
name = "empty",
downloadFileName = "empty.tflite",
url = "",
sizeInBytes = 0L,
)
////////////////////////////////////////////////////////////////////////////////////////////////////
// Model collections for different tasks.
val MODELS_TEXT_CLASSIFICATION: MutableList<Model> = mutableListOf(
MODEL_TEXT_CLASSIFICATION_MOBILEBERT,
MODEL_TEXT_CLASSIFICATION_AVERAGE_WORD_EMBEDDING,
)
val MODELS_IMAGE_CLASSIFICATION: MutableList<Model> = mutableListOf(
MODEL_IMAGE_CLASSIFICATION_MOBILENET_V1,
MODEL_IMAGE_CLASSIFICATION_MOBILENET_V2,
)
val MODELS_LLM_CHAT: MutableList<Model> = mutableListOf(
MODEL_LLM_GEMMA_2B_GPU_INT4,
MODEL_LLM_GEMMA_2_2B_GPU_INT8,
MODEL_LLM_GEMMA_3_1B_INT4,
MODEL_LLM_DEEPSEEK,
)
val MODELS_IMAGE_GENERATION: MutableList<Model> =
mutableListOf(MODEL_IMAGE_GENERATION_STABLE_DIFFUSION)

View file

@ -0,0 +1,111 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.data
import androidx.annotation.StringRes
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.ImageSearch
import androidx.compose.ui.graphics.vector.ImageVector
import com.google.aiedge.gallery.R
/** Type of task. */
enum class TaskType(val label: String) {
TEXT_CLASSIFICATION("Text Classification"),
IMAGE_CLASSIFICATION("Image Classification"),
IMAGE_GENERATION("Image Generation"),
LLM_CHAT("LLM Chat"),
TEST_TASK_1("Test task 1"),
TEST_TASK_2("Test task 2")
}
/** Data class for a task listed in home screen. */
data class Task(
/** Type of the task. */
val type: TaskType,
/** Icon to be shown in the task tile. */
val icon: ImageVector? = null,
/** Vector resource id for the icon. This precedes the icon if both are set. */
val iconVectorResourceId: Int? = null,
/** List of models for the task. */
val models: MutableList<Model>,
/** Description of the task. */
val description: String,
/** Placeholder text for the name of the agent shown above chat messages. */
@StringRes val agentNameRes: Int = R.string.chat_generic_agent_name,
/** Placeholder text for the text input field. */
@StringRes val textInputPlaceHolderRes: Int = R.string.chat_textinput_placeholder,
// The following fields are managed by the app. Don't need to set manually.
var index: Int = -1
)
val TASK_TEXT_CLASSIFICATION = Task(
type = TaskType.TEXT_CLASSIFICATION,
iconVectorResourceId = R.drawable.text_spark,
models = MODELS_TEXT_CLASSIFICATION,
description = "Classify text into different categories",
textInputPlaceHolderRes = R.string.text_input_placeholder_text_classification
)
val TASK_IMAGE_CLASSIFICATION = Task(
type = TaskType.IMAGE_CLASSIFICATION,
icon = Icons.Rounded.ImageSearch,
description = "Classify images into different categories",
models = MODELS_IMAGE_CLASSIFICATION
)
val TASK_LLM_CHAT = Task(
type = TaskType.LLM_CHAT,
iconVectorResourceId = R.drawable.chat_spark,
models = MODELS_LLM_CHAT,
description = "Chat? with a on-device large language model",
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat
)
val TASK_IMAGE_GENERATION = Task(
type = TaskType.IMAGE_GENERATION,
iconVectorResourceId = R.drawable.image_spark,
models = MODELS_IMAGE_GENERATION,
description = "Generate images from text",
textInputPlaceHolderRes = R.string.text_image_generation_text_field_placeholder
)
/** All tasks. */
val TASKS: List<Task> = listOf(
TASK_TEXT_CLASSIFICATION,
TASK_IMAGE_CLASSIFICATION,
TASK_IMAGE_GENERATION,
TASK_LLM_CHAT,
)
fun getModelByName(name: String): Model? {
for (task in TASKS) {
for (model in task.models) {
if (model.name == name) {
return model
}
}
}
return null
}

View file

@ -0,0 +1,70 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui
import android.app.Application
import androidx.lifecycle.ViewModelProvider.AndroidViewModelFactory
import androidx.lifecycle.viewmodel.CreationExtras
import androidx.lifecycle.viewmodel.initializer
import androidx.lifecycle.viewmodel.viewModelFactory
import com.google.aiedge.gallery.GalleryApplication
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationViewModel
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationViewModel
import com.google.aiedge.gallery.ui.llmchat.LlmChatViewModel
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.textclassification.TextClassificationViewModel
object ViewModelProvider {
val Factory = viewModelFactory {
// Initializer for ModelManagerViewModel.
initializer {
val downloadRepository = galleryApplication().container.downloadRepository
val dataStoreRepository = galleryApplication().container.dataStoreRepository
ModelManagerViewModel(
downloadRepository = downloadRepository,
dataStoreRepository = dataStoreRepository,
context = galleryApplication().container.context,
)
}
// Initializer for TextClassificationViewModel
initializer {
TextClassificationViewModel()
}
// Initializer for ImageClassificationViewModel
initializer {
ImageClassificationViewModel()
}
// Initializer for LlmChatViewModel.
initializer {
LlmChatViewModel()
}
initializer {
ImageGenerationViewModel()
}
}
}
/**
* Extension function to queries for [Application] object and returns an instance of
* [GalleryApplication].
*/
fun CreationExtras.galleryApplication(): GalleryApplication =
(this[AndroidViewModelFactory.APPLICATION_KEY] as GalleryApplication)

View file

@ -0,0 +1,41 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common
import android.net.Uri
import net.openid.appauth.AuthorizationServiceConfiguration
object AuthConfig {
// Hugging Face Client ID.
const val clientId = "88a0ac25-fcf4-467b-b8cd-ebcc2aec9bd0"
// Registered redirect URI.
//
// The scheme needs to match the
// "android.defaultConfig.manifestPlaceholders["appAuthRedirectScheme"]" field in
// "build.gradle.kts".
const val redirectUri = "com.google.aiedge.gallery.oauth://oauthredirect"
// OAuth 2.0 Endpoints (Authorization + Token Exchange)
private const val authEndpoint = "https://huggingface.co/oauth/authorize"
private const val tokenEndpoint = "https://huggingface.co/oauth/token"
// OAuth service configuration (AppAuth library requires this)
val authServiceConfig = AuthorizationServiceConfiguration(
Uri.parse(authEndpoint), // Authorization endpoint
Uri.parse(tokenEndpoint) // Token exchange endpoint
)
}

View file

@ -0,0 +1,334 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common
import android.content.Intent
import android.net.Uri
import android.util.Log
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.ActivityResultLauncher
import androidx.activity.result.contract.ActivityResultContracts
import androidx.browser.customtabs.CustomTabsIntent
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.wrapContentHeight
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.ArrowForward
import androidx.compose.material3.Button
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ModalBottomSheet
import androidx.compose.material3.Text
import androidx.compose.material3.rememberModalBottomSheetState
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.modelmanager.TokenRequestResultType
import com.google.aiedge.gallery.ui.modelmanager.TokenStatus
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import java.net.HttpURLConnection
private const val TAG = "AGDownloadAndTryButton"
// TODO:
// - replace the download button in chat view page with this one, and add a flag to not "onclick"
// just download
/**
* Handles the "Download & Try it" button click, managing the model download process based on
* various conditions.
*
* If the button is enabled and not currently checking the token, it initiates a coroutine to
* handle the download logic.
*
* For models requiring download first, it specifically addresses HuggingFace URLs by first
* checking if authentication is necessary. If no authentication is needed, the download starts
* directly. Otherwise, it checks the current token status; if the token is invalid or expired,
* a token exchange flow is initiated. If a valid token exists, it attempts to access the
* download URL. If access is granted, the download begins; if not, a new token is requested.
*
* For non-HuggingFace URLs that need downloading, the download starts directly.
*
* If the model doesn't need to be downloaded first, the provided `onClicked` callback is executed.
*
* Additionally, for gated HuggingFace models, if accessing the model after token exchange results
* in a forbidden error, a modal bottom sheet is displayed, prompting the user to acknowledge the
* user agreement by opening it in a custom tab. Upon closing the tab, the download process is
* retried.
*
* The composable also manages UI states for indicating token checking and displaying the agreement
* acknowledgement sheet, and it handles requesting notification permissions before initiating the
* actual download.
*/
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun DownloadAndTryButton(
model: Model,
enabled: Boolean,
needToDownloadFirst: Boolean,
modelManagerViewModel: ModelManagerViewModel,
onClicked: () -> Unit
) {
val scope = rememberCoroutineScope()
val context = LocalContext.current
var checkingToken by remember { mutableStateOf(false) }
var showAgreementAckSheet by remember { mutableStateOf(false) }
val sheetState = rememberModalBottomSheetState()
// A launcher for requesting notification permission.
val permissionLauncher = rememberLauncherForActivityResult(
ActivityResultContracts.RequestPermission()
) {
modelManagerViewModel.downloadModel(model)
}
// Function to kick off download.
val startDownload: (accessToken: String?) -> Unit = { accessToken ->
model.accessToken = accessToken
onClicked()
checkNotificationPermissonAndStartDownload(
context = context,
launcher = permissionLauncher,
modelManagerViewModel = modelManagerViewModel,
model = model
)
checkingToken = false
}
// A launcher for opening the custom tabs intent for requesting user agreement ack.
// Once the tab is closed, try starting the download process.
val agreementAckLauncher: ActivityResultLauncher<Intent> = rememberLauncherForActivityResult(
contract = ActivityResultContracts.StartActivityForResult()
) { result ->
Log.d(TAG, "User closes the browser tab. Try to start downloading.")
startDownload(modelManagerViewModel.curAccessToken)
}
// A launcher for handling the authentication flow.
// It processes the result of the authentication activity and then checks if a user agreement
// acknowledgement is needed before proceeding with the model download.
val authResultLauncher = rememberLauncherForActivityResult(
contract = ActivityResultContracts.StartActivityForResult()
) { result ->
modelManagerViewModel.handleAuthResult(result, onTokenRequested = { tokenRequestResult ->
when (tokenRequestResult.status) {
TokenRequestResultType.SUCCEEDED -> {
Log.d(TAG, "Token request succeeded. Checking if we need user to ack user agreement")
scope.launch(Dispatchers.IO) {
// Check if we can use the current token to access model. If not, we might need to
// acknowledge the user agreement.
if (modelManagerViewModel.getModelUrlResponse(
model = model,
accessToken = modelManagerViewModel.curAccessToken
) == HttpURLConnection.HTTP_FORBIDDEN
) {
Log.d(TAG, "Model '${model.name}' needs user agreement ack.")
showAgreementAckSheet = true
} else {
Log.d(
TAG,
"Model '${model.name}' does NOT need user agreement ack. Start downloading..."
)
withContext(Dispatchers.Main) {
startDownload(modelManagerViewModel.curAccessToken)
}
}
}
}
TokenRequestResultType.FAILED -> {
Log.d(TAG, "Token request done. Error message: ${tokenRequestResult.errorMessage ?: ""}")
checkingToken = false
}
TokenRequestResultType.USER_CANCELLED -> {
Log.d(TAG, "User cancelled. Do nothing")
checkingToken = false
}
}
})
}
// Function to kick off the authentication and token exchange flow.
val startTokenExchange = {
val authRequest = modelManagerViewModel.getAuthorizationRequest()
val authIntent = modelManagerViewModel.authService.getAuthorizationRequestIntent(authRequest)
authResultLauncher.launch(authIntent)
}
Button(
onClick = {
if (!enabled || checkingToken) {
return@Button
}
// Launches a coroutine to handle the initial check and potential authentication flow
// before downloading the model. It checks if the model needs to be downloaded first,
// handles HuggingFace URLs by verifying the need for authentication, and initiates
// the token exchange process if required or proceeds with the download if no auth is needed
// or a valid token is available.
scope.launch(Dispatchers.IO) {
if (needToDownloadFirst) {
// For HuggingFace urls
if (model.url.startsWith("https://huggingface.co")) {
checkingToken = true
// Check if the url needs auth.
Log.d(
TAG,
"Model '${model.name}' is from HuggingFace. Checking if the url needs auth to download"
)
if (modelManagerViewModel.getModelUrlResponse(model = model) == HttpURLConnection.HTTP_OK) {
Log.d(TAG, "Model '${model.name}' doesn't need auth. Start downloading the model...")
withContext(Dispatchers.Main) {
startDownload(null)
}
return@launch
}
Log.d(TAG, "Model '${model.name}' needs auth. Start token exchange process...")
// Get current token status
val tokenStatusAndData = modelManagerViewModel.getTokenStatusAndData()
when (tokenStatusAndData.status) {
// If token is not stored or expired, log in and request a new token.
TokenStatus.NOT_STORED, TokenStatus.EXPIRED -> {
withContext(Dispatchers.Main) {
startTokenExchange()
}
}
// If token is still valid...
TokenStatus.NOT_EXPIRED -> {
// Use the current token to check the download url.
Log.d(TAG, "Checking the download url '${model.url}' with the current token...")
val responseCode = modelManagerViewModel.getModelUrlResponse(
model = model, accessToken = tokenStatusAndData.data!!.accessToken
)
if (responseCode == HttpURLConnection.HTTP_OK) {
// Download url is accessible. Download the model.
Log.d(TAG, "Download url is accessible with the current token.")
withContext(Dispatchers.Main) {
startDownload(tokenStatusAndData.data.accessToken)
}
}
// Download url is NOT accessible. Request a new token.
else {
Log.d(
TAG,
"Download url is NOT accessible. Response code: ${responseCode}. Trying to request a new token."
)
withContext(Dispatchers.Main) {
startTokenExchange()
}
}
}
}
}
// For other urls, just download the model.
else {
Log.d(
TAG,
"Model '${model.name}' is not from huggingface. Start downloading the model..."
)
withContext(Dispatchers.Main) {
startDownload(null)
}
}
} else {
withContext(Dispatchers.Main) {
onClicked()
}
}
}
},
) {
Icon(
Icons.AutoMirrored.Rounded.ArrowForward,
contentDescription = "",
modifier = Modifier.padding(end = 4.dp)
)
if (checkingToken) {
Text("Checking access...")
} else {
if (needToDownloadFirst) {
Text("Download & Try it", maxLines = 1)
} else {
Text("Try it", maxLines = 1)
}
}
}
// A ModalBottomSheet composable that displays information about the user agreement
// for a gated model and provides a button to open the agreement in a custom tab.
// Upon clicking the button, it constructs the agreement URL, launches it using a
// custom tab, and then dismisses the bottom sheet.
if (showAgreementAckSheet) {
ModalBottomSheet(
onDismissRequest = {
showAgreementAckSheet = false
checkingToken = false
},
sheetState = sheetState,
modifier = Modifier.wrapContentHeight(),
) {
Column(
horizontalAlignment = Alignment.CenterHorizontally,
modifier = Modifier.padding(horizontal = 16.dp)
) {
Text("Acknowledge user agreement", style = MaterialTheme.typography.titleLarge)
Text(
"This is a gated model. Please click the button below to view and agree to the user agreement. After accepting, simply close that tab to proceed with the model download.",
style = MaterialTheme.typography.bodyMedium,
modifier = Modifier.padding(vertical = 16.dp)
)
Button(onClick = {
// Get agreement url from model url.
val index = model.url.indexOf("/resolve/")
// Show it in a tab.
if (index >= 0) {
val agreementUrl = model.url.substring(0, index)
val customTabsIntent = CustomTabsIntent.Builder().build()
customTabsIntent.intent.setData(Uri.parse(agreementUrl))
agreementAckLauncher.launch(customTabsIntent.intent)
}
// Dismiss the sheet.
showAgreementAckSheet = false
}) {
Text("Open user agreement")
}
}
}
}
}

View file

@ -0,0 +1,105 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common
import androidx.compose.foundation.Image
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.aspectRatio
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.width
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.graphics.BlendMode
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.ColorFilter
import androidx.compose.ui.graphics.painter.Painter
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.res.vectorResource
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.Dp
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.data.TASKS
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.customColors
private val SHAPES: List<Int> =
listOf(R.drawable.pantegon, R.drawable.double_circle, R.drawable.circle, R.drawable.four_circle)
/**
* Composable that displays an icon representing a task. It consists of a background
* image and a foreground icon, both centered within a square box.
*/
@Composable
fun TaskIcon(task: Task, modifier: Modifier = Modifier, width: Dp = 56.dp) {
Box(
modifier = modifier
.width(width)
.aspectRatio(1f),
contentAlignment = Alignment.Center,
) {
Image(
painter = getTaskIconBgShape(task = task),
contentDescription = "",
modifier = Modifier
.fillMaxSize()
.alpha(0.6f),
contentScale = ContentScale.Fit,
colorFilter = ColorFilter.tint(
MaterialTheme.customColors.taskIconShapeBgColor,
blendMode = BlendMode.SrcIn
)
)
Icon(
task.icon ?: ImageVector.vectorResource(task.iconVectorResourceId!!),
tint = getTaskIconColor(task = task),
modifier = Modifier.size(width * 0.6f),
contentDescription = "",
)
}
}
@Composable
private fun getTaskIconBgShape(task: Task): Painter {
val colorIndex: Int = task.index % SHAPES.size
return painterResource(SHAPES[colorIndex])
}
@Preview(showBackground = true)
@Composable
fun TaskIconPreview() {
for ((index, task) in TASKS.withIndex()) {
task.index = index
}
GalleryTheme {
Column(modifier = Modifier.background(Color.Gray)) {
TaskIcon(task = TASK_LLM_CHAT, width = 80.dp)
}
}
}

View file

@ -0,0 +1,442 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common
import android.Manifest
import android.content.Context
import android.content.pm.PackageManager
import android.net.Uri
import android.os.Build
import androidx.activity.compose.ManagedActivityResultLauncher
import androidx.compose.material3.MaterialTheme
import androidx.compose.runtime.Composable
import androidx.compose.ui.graphics.Color
import androidx.core.content.ContextCompat
import androidx.core.content.FileProvider
import com.google.aiedge.gallery.data.Config
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.ValueType
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkResult
import com.google.aiedge.gallery.ui.common.chat.ChatMessageType
import com.google.aiedge.gallery.ui.common.chat.ChatViewModel
import com.google.aiedge.gallery.ui.common.chat.Histogram
import com.google.aiedge.gallery.ui.common.chat.Stat
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.theme.customColors
import java.io.File
import kotlin.math.abs
import kotlin.math.ln
import kotlin.math.max
import kotlin.math.min
import kotlin.math.pow
import kotlin.math.sqrt
private val STATS = listOf(
Stat(id = "min", label = "Min", unit = "ms"),
Stat(id = "max", label = "Max", unit = "ms"),
Stat(id = "avg", label = "Avg", unit = "ms"),
Stat(id = "stddev", label = "Stddev", unit = "ms")
)
interface LatencyProvider {
val latencyMs: Float
}
/** Format the bytes into a human-readable format. */
fun Long.humanReadableSize(si: Boolean = true, extraDecimalForGbAndAbove: Boolean = false): String {
val bytes = this
val unit = if (si) 1000 else 1024
if (bytes < unit) return "$bytes B"
val exp = (ln(bytes.toDouble()) / ln(unit.toDouble())).toInt()
val pre = (if (si) "kMGTPE" else "KMGTPE")[exp - 1] + if (si) "" else "i"
var formatString = "%.1f %sB"
if (extraDecimalForGbAndAbove && pre.lowercase() != "k" && pre != "M") {
formatString = "%.2f %sB"
}
return formatString.format(bytes / unit.toDouble().pow(exp.toDouble()), pre)
}
fun Float.humanReadableDuration(): String {
val milliseconds = this
if (milliseconds < 1000) {
return "$milliseconds ms"
}
val seconds = milliseconds / 1000f
if (seconds < 60) {
return "%.1f s".format(seconds)
}
val minutes = seconds / 60f
if (minutes < 60) {
return "%.1f min".format(minutes)
}
val hours = minutes / 60f
return "%.1f h".format(hours)
}
fun Long.formatToHourMinSecond(): String {
val ms = this
if (ms < 0) {
return "-"
}
val seconds = ms / 1000
val hours = seconds / 3600
val minutes = (seconds % 3600) / 60
val remainingSeconds = seconds % 60
val parts = mutableListOf<String>()
if (hours > 0) {
parts.add("$hours h")
}
if (minutes > 0) {
parts.add("$minutes min")
}
if (remainingSeconds > 0 || (hours == 0L && minutes == 0L)) {
parts.add("$remainingSeconds sec")
}
return parts.joinToString(" ")
}
fun convertValueToTargetType(value: Any, valueType: ValueType): Any {
return when (valueType) {
ValueType.INT -> when (value) {
is Int -> value
is Float -> value.toInt()
is Double -> value.toInt()
is String -> value.toIntOrNull() ?: ""
is Boolean -> if (value) 1 else 0
else -> ""
}
ValueType.FLOAT -> when (value) {
is Int -> value.toFloat()
is Float -> value
is Double -> value.toFloat()
is String -> value.toFloatOrNull() ?: ""
is Boolean -> if (value) 1f else 0f
else -> ""
}
ValueType.DOUBLE -> when (value) {
is Int -> value.toDouble()
is Float -> value.toDouble()
is Double -> value
is String -> value.toDoubleOrNull() ?: ""
is Boolean -> if (value) 1.0 else 0.0
else -> ""
}
ValueType.BOOLEAN -> when (value) {
is Int -> value == 0
is Boolean -> value
is Float -> abs(value) > 1e-6
is Double -> abs(value) > 1e-6
is String -> value.isNotEmpty()
else -> false
}
ValueType.STRING -> value.toString()
}
}
fun getDistinctiveColor(index: Int): Color {
val colors = listOf(
// Color(0xffe6194b),
Color(0xff3cb44b),
Color(0xffffe119),
Color(0xff4363d8),
Color(0xfff58231),
Color(0xff911eb4),
Color(0xff46f0f0),
Color(0xfff032e6),
Color(0xffbcf60c),
Color(0xfffabebe),
Color(0xff008080),
Color(0xffe6beff),
Color(0xff9a6324),
Color(0xfffffac8),
Color(0xff800000),
Color(0xffaaffc3),
Color(0xff808000),
Color(0xffffd8b1),
Color(0xff000075)
)
return colors[index % colors.size]
}
fun Context.createTempPictureUri(
fileName: String = "picture_${System.currentTimeMillis()}", fileExtension: String = ".png"
): Uri {
val tempFile = File.createTempFile(
fileName, fileExtension, cacheDir
).apply {
createNewFile()
}
return FileProvider.getUriForFile(
applicationContext, "com.google.aiedge.gallery.provider", tempFile
)
}
fun runBasicBenchmark(
model: Model,
warmupCount: Int,
iterations: Int,
chatViewModel: ChatViewModel,
inferenceFn: () -> LatencyProvider,
chatMessageType: ChatMessageType,
) {
val start = System.currentTimeMillis()
var lastUpdateTs = 0L
val update: (ChatMessageBenchmarkResult) -> Unit = { message ->
if (lastUpdateTs == 0L) {
chatViewModel.addMessage(
model = model,
message = message,
)
lastUpdateTs = System.currentTimeMillis()
} else {
val curTs = System.currentTimeMillis()
if (curTs - lastUpdateTs > 500) {
chatViewModel.replaceLastMessage(model = model, message = message, type = chatMessageType)
lastUpdateTs = curTs
}
}
}
// Warmup.
val latencies: MutableList<Float> = mutableListOf()
for (count in 1..warmupCount) {
inferenceFn()
update(
ChatMessageBenchmarkResult(
orderedStats = STATS,
statValues = calculateStats(min = 0f, max = 0f, sum = 0f, latencies = latencies),
histogram = calculateLatencyHistogram(
latencies = latencies, min = 0f, max = 0f, avg = 0f
),
values = latencies,
warmupCurrent = count,
warmupTotal = warmupCount,
iterationCurrent = 0,
iterationTotal = iterations,
latencyMs = (System.currentTimeMillis() - start).toFloat(),
highlightStat = "avg"
)
)
}
// Benchmark iterations.
var min = Float.MAX_VALUE
var max = 0f
var sum = 0f
for (count in 1..iterations) {
val result = inferenceFn()
val latency = result.latencyMs
min = min(min, latency)
max = max(max, latency)
sum += latency
latencies.add(latency)
val curTs = System.currentTimeMillis()
if (curTs - lastUpdateTs > 500 || count == iterations) {
lastUpdateTs = curTs
val stats = calculateStats(min = min, max = max, sum = sum, latencies = latencies)
chatViewModel.replaceLastMessage(
model = model,
message = ChatMessageBenchmarkResult(
orderedStats = STATS,
statValues = stats,
histogram = calculateLatencyHistogram(
latencies = latencies,
min = min,
max = max,
avg = stats["avg"] ?: 0f,
),
values = latencies,
warmupCurrent = warmupCount,
warmupTotal = warmupCount,
iterationCurrent = count,
iterationTotal = iterations,
latencyMs = (System.currentTimeMillis() - start).toFloat(),
highlightStat = "avg"
),
type = chatMessageType,
)
}
// Go through other benchmark messages and update their buckets for the common min/max values.
var allMin = Float.MAX_VALUE
var allMax = 0f
val allMessages = chatViewModel.uiState.value.messagesByModel[model.name] ?: listOf()
for (message in allMessages) {
if (message is ChatMessageBenchmarkResult) {
val curMin = message.statValues["min"] ?: 0f
val curMax = message.statValues["max"] ?: 0f
allMin = min(allMin, curMin)
allMax = max(allMax, curMax)
}
}
for ((index, message) in allMessages.withIndex()) {
if (message === allMessages.last()) {
break
}
if (message is ChatMessageBenchmarkResult) {
val updatedMessage = ChatMessageBenchmarkResult(
orderedStats = STATS,
statValues = message.statValues,
histogram = calculateLatencyHistogram(
latencies = message.values,
min = allMin,
max = allMax,
avg = message.statValues["avg"] ?: 0f,
),
values = message.values,
warmupCurrent = message.warmupCurrent,
warmupTotal = message.warmupTotal,
iterationCurrent = message.iterationCurrent,
iterationTotal = message.iterationTotal,
latencyMs = message.latencyMs,
highlightStat = "avg"
)
chatViewModel.replaceMessage(model = model, index = index, message = updatedMessage)
}
}
}
}
private fun calculateStats(
min: Float, max: Float, sum: Float, latencies: MutableList<Float>
): MutableMap<String, Float> {
val avg = if (latencies.size == 0) 0f else sum / latencies.size
val squaredDifferences = latencies.map { (it - avg).pow(2) }
val variance = squaredDifferences.average()
val stddev = if (latencies.size == 0) 0f else sqrt(variance).toFloat()
var medium = 0f
if (latencies.size == 1) {
medium = latencies[0]
} else if (latencies.size > 1) {
latencies.sort()
val middle = latencies.size / 2
medium =
if (latencies.size % 2 == 0) (latencies[middle - 1] + latencies[middle]) / 2.0f else latencies[middle]
}
return mutableMapOf(
"min" to min, "max" to max, "avg" to avg, "stddev" to stddev, "medium" to medium
)
}
fun calculateLatencyHistogram(
latencies: List<Float>, min: Float, max: Float, avg: Float, numBuckets: Int = 20
): Histogram {
if (latencies.isEmpty() || numBuckets <= 0) {
return Histogram(
buckets = List(numBuckets) { 0 }, maxCount = 0
)
}
if (min == max) {
// All latencies are the same.
val result = MutableList(numBuckets) { 0 }
result[0] = latencies.size
return Histogram(buckets = result, maxCount = result[0], highlightBucketIndex = 0)
}
val bucketSize = (max - min) / numBuckets
val histogram = MutableList(numBuckets) { 0 }
val getBucketIndex: (value: Float) -> Int = {
var bucketIndex = ((it - min) / bucketSize).toInt()
// Handle the case where latency equals maxLatency
if (bucketIndex == numBuckets) {
bucketIndex = numBuckets - 1
}
bucketIndex
}
for (latency in latencies) {
val bucketIndex = getBucketIndex(latency)
histogram[bucketIndex]++
}
val avgBucketIndex = getBucketIndex(avg)
return Histogram(
buckets = histogram,
maxCount = histogram.maxOrNull() ?: 0,
highlightBucketIndex = avgBucketIndex
)
}
fun getConfigValueString(value: Any, config: Config): String {
var strNewValue = "$value"
if (config.valueType == ValueType.FLOAT) {
strNewValue = "%.2f".format(value)
}
return strNewValue
}
@Composable
fun getTaskBgColor(task: Task): Color {
val colorIndex: Int = task.index % MaterialTheme.customColors.taskBgColors.size
return MaterialTheme.customColors.taskBgColors[colorIndex]
}
@Composable
fun getTaskIconColor(task: Task): Color {
val colorIndex: Int = task.index % MaterialTheme.customColors.taskIconColors.size
return MaterialTheme.customColors.taskIconColors[colorIndex]
}
@Composable
fun getTaskIconColor(index: Int): Color {
val colorIndex: Int = index % MaterialTheme.customColors.taskIconColors.size
return MaterialTheme.customColors.taskIconColors[colorIndex]
}
fun checkNotificationPermissonAndStartDownload(
context: Context,
launcher: ManagedActivityResultLauncher<String, Boolean>,
modelManagerViewModel: ModelManagerViewModel,
model: Model
) {
// Check permission
when (PackageManager.PERMISSION_GRANTED) {
// Already got permission. Call the lambda.
ContextCompat.checkSelfPermission(
context, Manifest.permission.POST_NOTIFICATIONS
) -> {
modelManagerViewModel.downloadModel(model)
}
// Otherwise, ask for permission
else -> {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
launcher.launch(Manifest.permission.POST_NOTIFICATIONS)
}
}
}
}

View file

@ -0,0 +1,102 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.runtime.Composable
import androidx.compose.ui.tooling.preview.Preview
import com.google.aiedge.gallery.data.Config
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.NumberSliderConfig
import com.google.aiedge.gallery.data.ValueType
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
import com.google.aiedge.gallery.ui.theme.GalleryTheme
private const val DEFAULT_BENCHMARK_WARM_UP_ITERATIONS = 50f
private const val DEFAULT_BENCHMARK_ITERATIONS = 200f
private val BENCHMARK_CONFIGS: List<Config> = listOf(
NumberSliderConfig(
key = ConfigKey.WARM_UP_ITERATIONS,
sliderMin = 10f,
sliderMax = 200f,
defaultValue = DEFAULT_BENCHMARK_WARM_UP_ITERATIONS,
valueType = ValueType.INT
),
NumberSliderConfig(
key = ConfigKey.BENCHMARK_ITERATIONS,
sliderMin = 50f,
sliderMax = 500f,
defaultValue = DEFAULT_BENCHMARK_ITERATIONS,
valueType = ValueType.INT
),
)
private val BENCHMARK_CONFIGS_INITIAL_VALUES = mapOf(
ConfigKey.WARM_UP_ITERATIONS.label to DEFAULT_BENCHMARK_WARM_UP_ITERATIONS,
ConfigKey.BENCHMARK_ITERATIONS.label to DEFAULT_BENCHMARK_ITERATIONS
)
/**
* Composable function to display a configuration dialog for benchmarking a chat message.
*
* This function renders a configuration dialog specifically tailored for setting up
* benchmark parameters. It allows users to specify warm-up and benchmark iterations
* before running a benchmark test on a given chat message.
*/
@Composable
fun BenchmarkConfigDialog(
onDismissed: () -> Unit,
messageToBenchmark: ChatMessage?,
onBenchmarkClicked: (ChatMessage, warmUpIterations: Int, benchmarkIterations: Int) -> Unit
) {
ConfigDialog(
title = "Benchmark configs",
okBtnLabel = "Start",
configs = BENCHMARK_CONFIGS,
initialValues = BENCHMARK_CONFIGS_INITIAL_VALUES,
onDismissed = onDismissed,
onOk = { curConfigValues ->
// Hide config dialog.
onDismissed()
// Start benchmark.
messageToBenchmark?.let { message ->
val warmUpIterations = convertValueToTargetType(
value = curConfigValues.getValue(ConfigKey.WARM_UP_ITERATIONS.label),
valueType = ValueType.INT
) as Int
val benchmarkIterations = convertValueToTargetType(
value = curConfigValues.getValue(ConfigKey.BENCHMARK_ITERATIONS.label),
valueType = ValueType.INT
) as Int
onBenchmarkClicked(message, warmUpIterations, benchmarkIterations)
}
},
)
}
@Preview(showBackground = true)
@Composable
fun BenchmarkConfigDialogPreview() {
GalleryTheme {
BenchmarkConfigDialog(
onDismissed = {},
messageToBenchmark = null,
onBenchmarkClicked = { _, _, _ -> }
)
}
}

View file

@ -0,0 +1,180 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import android.graphics.Bitmap
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.ImageBitmap
import androidx.compose.ui.unit.Dp
import com.google.aiedge.gallery.data.Model
enum class ChatMessageType {
INFO,
TEXT,
IMAGE,
IMAGE_WITH_HISTORY,
LOADING,
CLASSIFICATION,
CONFIG_VALUES_CHANGE,
BENCHMARK_RESULT,
BENCHMARK_LLM_RESULT,
PROMPT_TEMPLATES
}
enum class ChatSide {
USER, AGENT, SYSTEM
}
data class Classification(val label: String, val score: Float, val color: Color)
/** Base class for a chat message. */
open class ChatMessage(
open val type: ChatMessageType, open val side: ChatSide, open val latencyMs: Float = -1f
) {
open fun clone(): ChatMessage {
return ChatMessage(type = type, side = side, latencyMs = latencyMs)
}
}
/** Chat message for showing loading status. */
class ChatMessageLoading : ChatMessage(type = ChatMessageType.LOADING, side = ChatSide.AGENT)
/** Chat message for info (help). */
class ChatMessageInfo(val content: String) :
ChatMessage(type = ChatMessageType.INFO, side = ChatSide.SYSTEM)
/** Chat message for config values change. */
class ChatMessageConfigValuesChange(
val model: Model,
val oldValues: Map<String, Any>,
val newValues: Map<String, Any>
) : ChatMessage(type = ChatMessageType.CONFIG_VALUES_CHANGE, side = ChatSide.SYSTEM)
/** Chat message for plain text. */
open class ChatMessageText(
val content: String,
override val side: ChatSide,
// Negative numbers will hide the latency display.
override val latencyMs: Float = 0f,
val isMarkdown: Boolean = true,
) : ChatMessage(type = ChatMessageType.TEXT, side = side, latencyMs = latencyMs) {
override fun clone(): ChatMessageText {
return ChatMessageText(
content = content,
side = side,
latencyMs = latencyMs,
isMarkdown = isMarkdown
)
}
}
/** Chat message for images. */
class ChatMessageImage(
val bitmap: Bitmap,
val imageBitMap: ImageBitmap,
override val side: ChatSide,
override val latencyMs: Float = 0f
) :
ChatMessage(type = ChatMessageType.IMAGE, side = side, latencyMs = latencyMs) {
override fun clone(): ChatMessageImage {
return ChatMessageImage(
bitmap = bitmap,
imageBitMap = imageBitMap,
side = side,
latencyMs = latencyMs
)
}
}
/** Chat message for images with history. */
class ChatMessageImageWithHistory(
val bitmaps: List<Bitmap>,
val imageBitMaps: List<ImageBitmap>,
val totalIterations: Int,
override val side: ChatSide,
override val latencyMs: Float = 0f,
var curIteration: Int = 0, // 0-based
) :
ChatMessage(type = ChatMessageType.IMAGE_WITH_HISTORY, side = side, latencyMs = latencyMs) {
fun isRunning(): Boolean {
return curIteration < totalIterations - 1
}
}
/** Chat message for showing classification result. */
class ChatMessageClassification(
val classifications: List<Classification>,
override val latencyMs: Float = 0f,
// Typical android phone width is > 320dp
val maxBarWidth: Dp? = null,
) : ChatMessage(type = ChatMessageType.CLASSIFICATION, side = ChatSide.AGENT, latencyMs = latencyMs)
/** A stat used in benchmark result. */
data class Stat(val id: String, val label: String, val unit: String)
/** Chat message for showing benchmark result. */
class ChatMessageBenchmarkResult(
val orderedStats: List<Stat>,
val statValues: MutableMap<String, Float>,
val values: List<Float>,
val histogram: Histogram,
val warmupCurrent: Int,
val warmupTotal: Int,
val iterationCurrent: Int,
val iterationTotal: Int,
override val latencyMs: Float = 0f,
val highlightStat: String = "",
) :
ChatMessage(
type = ChatMessageType.BENCHMARK_RESULT,
side = ChatSide.AGENT,
latencyMs = latencyMs
) {
fun isWarmingUp(): Boolean {
return warmupCurrent < warmupTotal
}
fun isRunning(): Boolean {
return iterationCurrent < iterationTotal
}
}
/** Chat message for showing LLM benchmark result. */
class ChatMessageBenchmarkLlmResult(
val orderedStats: List<Stat>,
val statValues: MutableMap<String, Float>,
val running: Boolean,
override val latencyMs: Float = 0f,
) : ChatMessage(
type = ChatMessageType.BENCHMARK_LLM_RESULT,
side = ChatSide.AGENT,
latencyMs = latencyMs
)
data class Histogram(
val buckets: List<Int>,
val maxCount: Int,
val highlightBucketIndex: Int = -1
)
/** Chat message for showing prompt templates. */
class ChatMessagePromptTemplates(
val templates: List<PromptTemplate>,
val showMakeYourOwn: Boolean = true,
) : ChatMessage(type = ChatMessageType.PROMPT_TEMPLATES, side = ChatSide.SYSTEM)
data class PromptTemplate(val title: String, val description: String, val prompt: String)

View file

@ -0,0 +1,491 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.animation.AnimatedVisibility
import androidx.compose.animation.fadeIn
import androidx.compose.animation.fadeOut
import androidx.compose.animation.scaleIn
import androidx.compose.animation.scaleOut
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
import androidx.compose.foundation.gestures.detectTapGestures
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.WindowInsets
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.ime
import androidx.compose.foundation.layout.imePadding
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.width
import androidx.compose.foundation.layout.wrapContentHeight
import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items
import androidx.compose.foundation.lazy.rememberLazyListState
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.outlined.Timer
import androidx.compose.material.icons.rounded.ContentCopy
import androidx.compose.material.icons.rounded.Refresh
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ModalBottomSheet
import androidx.compose.material3.SnackbarHost
import androidx.compose.material3.SnackbarHostState
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.MutableState
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.geometry.Offset
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.asImageBitmap
import androidx.compose.ui.hapticfeedback.HapticFeedbackType
import androidx.compose.ui.input.nestedscroll.NestedScrollConnection
import androidx.compose.ui.input.nestedscroll.NestedScrollSource
import androidx.compose.ui.input.nestedscroll.nestedScroll
import androidx.compose.ui.input.pointer.pointerInput
import androidx.compose.ui.platform.LocalClipboardManager
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalDensity
import androidx.compose.ui.platform.LocalFocusManager
import androidx.compose.ui.platform.LocalHapticFeedback
import androidx.compose.ui.res.dimensionResource
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.TaskType
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatus
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.PreviewChatModel
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.customColors
import kotlinx.coroutines.launch
enum class ChatInputType {
TEXT, IMAGE,
}
/**
* Composable function for the main chat panel, displaying messages and handling user input.
*/
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun ChatPanel(
modelManagerViewModel: ModelManagerViewModel,
task: Task,
selectedModel: Model,
viewModel: ChatViewModel,
onSendMessage: (Model, ChatMessage) -> Unit,
onRunAgainClicked: (Model, ChatMessage) -> Unit,
onBenchmarkClicked: (Model, ChatMessage, warmUpIterations: Int, benchmarkIterations: Int) -> Unit,
modifier: Modifier = Modifier,
onStreamImageMessage: (Model, ChatMessageImage) -> Unit = { _, _ -> },
onStreamEnd: (Int) -> Unit = {},
onStopButtonClicked: () -> Unit = {},
chatInputType: ChatInputType = ChatInputType.TEXT,
showStopButtonInInputWhenInProgress: Boolean = false,
) {
val uiState by viewModel.uiState.collectAsState()
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val messages = uiState.messagesByModel[selectedModel.name] ?: listOf()
val streamingMessage = uiState.streamingMessagesByModel[selectedModel.name]
val snackbarHostState = remember { SnackbarHostState() }
val scope = rememberCoroutineScope()
val haptic = LocalHapticFeedback.current
var curMessage by remember { mutableStateOf("") } // Correct state
val focusManager = LocalFocusManager.current
// Remember the LazyListState to control scrolling
val listState = rememberLazyListState()
val density = LocalDensity.current
var showBenchmarkConfigsDialog by remember { mutableStateOf(false) }
val benchmarkMessage: MutableState<ChatMessage?> = remember { mutableStateOf(null) }
var showMessageLongPressedSheet by remember { mutableStateOf(false) }
val longPressedMessage: MutableState<ChatMessage?> = remember { mutableStateOf(null) }
// Keep track of the last message and last message content.
val lastMessage: MutableState<ChatMessage?> = remember { mutableStateOf(null) }
val lastMessageContent: MutableState<String> = remember { mutableStateOf("") }
if (messages.isNotEmpty()) {
val tmpLastMessage = messages.last()
lastMessage.value = tmpLastMessage
if (tmpLastMessage is ChatMessageText) {
lastMessageContent.value = tmpLastMessage.content
}
}
// Scroll the content to the bottom when any of these changes.
LaunchedEffect(
messages.size,
lastMessage.value,
lastMessageContent.value,
WindowInsets.ime.getBottom(density),
) {
if (messages.isNotEmpty()) {
listState.animateScrollToItem(messages.lastIndex, scrollOffset = 10000)
}
}
val nestedScrollConnection = remember {
object : NestedScrollConnection {
override fun onPreScroll(available: Offset, source: NestedScrollSource): Offset {
// If downward scroll, clear the focus from any currently focused composable.
// This is useful for dismissing software keyboards or hiding text input fields
// when the user starts scrolling down a list.
if (available.y > 0) {
focusManager.clearFocus()
}
// Let LazyColumn handle the scroll
return Offset.Zero
}
}
}
val modelInitializationStatus =
modelManagerUiState.modelInitializationStatus[selectedModel.name]
Column(
modifier = modifier.imePadding()
) {
Box(contentAlignment = Alignment.BottomCenter, modifier = Modifier.weight(1f)) {
LazyColumn(
modifier = Modifier
.fillMaxSize()
.nestedScroll(nestedScrollConnection),
state = listState, verticalArrangement = Arrangement.Top,
) {
items(messages) { message ->
val imageHistoryCurIndex = remember { mutableIntStateOf(0) }
var hAlign: Alignment.Horizontal = Alignment.End
var backgroundColor: Color = MaterialTheme.customColors.userBubbleBgColor
var hardCornerAtLeftOrRight = false
var extraPaddingStart = 48.dp
var extraPaddingEnd = 0.dp
if (message.side == ChatSide.AGENT) {
hAlign = Alignment.Start
backgroundColor = MaterialTheme.customColors.agentBubbleBgColor
hardCornerAtLeftOrRight = true
extraPaddingStart = 0.dp
extraPaddingEnd = 48.dp
} else if (message.side == ChatSide.SYSTEM) {
extraPaddingStart = 24.dp
extraPaddingEnd = 24.dp
if (message.type == ChatMessageType.PROMPT_TEMPLATES) {
extraPaddingStart = 12.dp
extraPaddingEnd = 12.dp
}
}
if (message.type == ChatMessageType.IMAGE) {
backgroundColor = Color.Transparent
}
val bubbleBorderRadius = dimensionResource(R.dimen.chat_bubble_corner_radius)
Column(
modifier = Modifier
.fillMaxWidth()
.padding(
start = 12.dp + extraPaddingStart,
end = 12.dp + extraPaddingEnd,
top = 6.dp,
bottom = 6.dp,
),
horizontalAlignment = hAlign,
) {
// Sender row.
MessageSender(
message = message,
agentNameRes = task.agentNameRes,
imageHistoryCurIndex = imageHistoryCurIndex.intValue
)
// Message body.
when (message) {
// Loading.
is ChatMessageLoading -> MessageBodyLoading()
// Info.
is ChatMessageInfo -> MessageBodyInfo(message = message)
// Config values change.
is ChatMessageConfigValuesChange -> MessageBodyConfigUpdate(message = message)
// Prompt templates.
is ChatMessagePromptTemplates -> MessageBodyPromptTemplates(message = message,
task = task,
onPromptClicked = { template ->
onSendMessage(
selectedModel, ChatMessageText(content = template.prompt, side = ChatSide.USER)
)
})
// Non-system messages.
else -> {
// The bubble shape around the message body.
var messageBubbleModifier = Modifier
.clip(
MessageBubbleShape(
radius = bubbleBorderRadius,
hardCornerAtLeftOrRight = hardCornerAtLeftOrRight
)
)
.background(backgroundColor)
if (message is ChatMessageText) {
messageBubbleModifier = messageBubbleModifier
.pointerInput(Unit) {
detectTapGestures(
onLongPress = {
haptic.performHapticFeedback(HapticFeedbackType.LongPress)
longPressedMessage.value = message
showMessageLongPressedSheet = true
},
)
}
}
Box(
modifier = messageBubbleModifier,
) {
when (message) {
// Text
is ChatMessageText -> MessageBodyText(message = message)
// Image
is ChatMessageImage -> MessageBodyImage(message = message)
// Image with history (for image gen)
is ChatMessageImageWithHistory -> MessageBodyImageWithHistory(
message = message, imageHistoryCurIndex = imageHistoryCurIndex
)
// Classification result
is ChatMessageClassification -> MessageBodyClassification(
message = message,
modifier = Modifier.width(message.maxBarWidth ?: CLASSIFICATION_BAR_MAX_WIDTH)
)
// Benchmark result.
is ChatMessageBenchmarkResult -> MessageBodyBenchmark(message = message)
// Benchmark LLM result.
is ChatMessageBenchmarkLlmResult -> MessageBodyBenchmarkLlm(message = message)
else -> {}
}
}
if (message.side == ChatSide.AGENT) {
LatencyText(message = message)
} else if (message.side == ChatSide.USER) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(4.dp)
) {
// Run again button.
if (selectedModel.showRunAgainButton) {
MessageActionButton(
label = stringResource(R.string.run_again),
icon = Icons.Rounded.Refresh,
onClick = {
onRunAgainClicked(selectedModel, message)
},
enabled = !uiState.inProgress
)
}
// Benchmark button
if (selectedModel.showBenchmarkButton) {
MessageActionButton(
label = stringResource(R.string.benchmark),
icon = Icons.Outlined.Timer,
onClick = {
if (selectedModel.taskType == TaskType.LLM_CHAT) {
onBenchmarkClicked(selectedModel, message, 0, 0)
} else {
showBenchmarkConfigsDialog = true
benchmarkMessage.value = message
}
},
enabled = !uiState.inProgress
)
}
}
}
}
}
}
}
}
// Model initialization in-progress message.
this@Column.AnimatedVisibility(
visible = modelInitializationStatus == ModelInitializationStatus.INITIALIZING,
enter = scaleIn() + fadeIn(),
exit = scaleOut() + fadeOut(),
modifier = Modifier.offset(y = 12.dp)
) {
ModelInitializationStatusChip()
}
SnackbarHost(hostState = snackbarHostState, modifier = Modifier.padding(vertical = 4.dp))
}
// Chat input
when (chatInputType) {
ChatInputType.TEXT -> {
val isLlmTask = selectedModel.taskType == TaskType.LLM_CHAT
val notLlmStartScreen = !(messages.size == 1 && messages[0] is ChatMessagePromptTemplates)
MessageInputText(
modelManagerViewModel = modelManagerViewModel,
curMessage = curMessage,
inProgress = uiState.inProgress,
textFieldPlaceHolderRes = task.textInputPlaceHolderRes,
onValueChanged = { curMessage = it },
onSendMessage = {
onSendMessage(selectedModel, it)
curMessage = ""
},
onOpenPromptTemplatesClicked = {
onSendMessage(
selectedModel, ChatMessagePromptTemplates(
templates = selectedModel.llmPromptTemplates, showMakeYourOwn = false
)
)
},
onStopButtonClicked = onStopButtonClicked,
showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen,
showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress,
)
}
ChatInputType.IMAGE -> MessageInputImage(
disableButtons = uiState.inProgress,
streamingMessage = streamingMessage,
onImageSelected = { bitmap ->
onSendMessage(
selectedModel, ChatMessageImage(
bitmap = bitmap, imageBitMap = bitmap.asImageBitmap(), side = ChatSide.USER
)
)
},
onStreamImage = { bitmap ->
onStreamImageMessage(
selectedModel, ChatMessageImage(
bitmap = bitmap, imageBitMap = bitmap.asImageBitmap(), side = ChatSide.USER
)
)
},
onStreamEnd = onStreamEnd,
)
}
}
// Benchmark config dialog.
if (showBenchmarkConfigsDialog) {
BenchmarkConfigDialog(onDismissed = { showBenchmarkConfigsDialog = false },
messageToBenchmark = benchmarkMessage.value,
onBenchmarkClicked = { message, warmUpIterations, benchmarkIterations ->
onBenchmarkClicked(selectedModel, message, warmUpIterations, benchmarkIterations)
})
}
// Sheet to show when a message is long-pressed.
if (showMessageLongPressedSheet) {
val message = longPressedMessage.value
if (message != null && message is ChatMessageText) {
val clipboardManager = LocalClipboardManager.current
ModalBottomSheet(
onDismissRequest = { showMessageLongPressedSheet = false },
modifier = Modifier.wrapContentHeight(),
) {
Column {
// Copy text.
Box(modifier = Modifier
.fillMaxWidth()
.clickable {
// Copy text.
val clipData = AnnotatedString(message.content)
clipboardManager.setText(clipData)
// Hide sheet.
showMessageLongPressedSheet = false
// Show a snack bar.
scope.launch {
snackbarHostState.showSnackbar("Text copied to clipboard")
}
}) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp),
modifier = Modifier
.padding(vertical = 8.dp, horizontal = 16.dp)
) {
Icon(
Icons.Rounded.ContentCopy,
contentDescription = "",
modifier = Modifier.size(18.dp)
)
Text("Copy text")
}
}
}
}
}
}
}
@Preview(showBackground = true)
@Composable
fun ChatPanelPreview() {
GalleryTheme {
val context = LocalContext.current
val task = TASK_TEST1
ChatPanel(
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
task = task,
selectedModel = TASK_TEST1.models[1],
viewModel = PreviewChatModel(context = context),
onSendMessage = { _, _ -> },
onRunAgainClicked = { _, _ -> },
onBenchmarkClicked = { _, _, _, _ -> },
)
}
}

View file

@ -0,0 +1,306 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import android.util.Log
import androidx.activity.compose.BackHandler
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.animation.AnimatedVisibility
import androidx.compose.animation.fadeIn
import androidx.compose.animation.fadeOut
import androidx.compose.animation.scaleIn
import androidx.compose.animation.scaleOut
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.pager.HorizontalPager
import androidx.compose.foundation.pager.rememberPagerState
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Scaffold
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.tooling.preview.Preview
import com.google.aiedge.gallery.GalleryTopAppBar
import com.google.aiedge.gallery.data.AppBarAction
import com.google.aiedge.gallery.data.AppBarActionType
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.checkNotificationPermissonAndStartDownload
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.PreviewChatModel
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlin.math.absoluteValue
private const val TAG = "AGChatView"
/**
* A composable that displays a chat interface, allowing users to interact with different models
* associated with a given task.
*
* This composable provides a horizontal pager for switching between models, a model selector
* for configuring the selected model, and a chat panel for sending and receiving messages. It also
* manages model initialization, cleanup, and download status, and handles navigation and system
* back gestures.
*/
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun ChatView(
task: Task,
viewModel: ChatViewModel,
modelManagerViewModel: ModelManagerViewModel,
onSendMessage: (Model, ChatMessage) -> Unit,
onRunAgainClicked: (Model, ChatMessage) -> Unit,
onBenchmarkClicked: (Model, ChatMessage, Int, Int) -> Unit,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
onStreamImageMessage: (Model, ChatMessageImage) -> Unit = { _, _ -> },
onStopButtonClicked: (Model) -> Unit = {},
chatInputType: ChatInputType = ChatInputType.TEXT,
showStopButtonInInputWhenInProgress: Boolean = false,
) {
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val selectedModel = modelManagerUiState.selectedModel
val pagerState = rememberPagerState(initialPage = task.models.indexOf(selectedModel),
pageCount = { task.models.size })
val context = LocalContext.current
val scope = rememberCoroutineScope()
val launcher = rememberLauncherForActivityResult(
ActivityResultContracts.RequestPermission()
) {
modelManagerViewModel.downloadModel(selectedModel)
}
val handleNavigateUp = {
navigateUp()
// clean up all models.
scope.launch(Dispatchers.Default) {
for (model in task.models) {
modelManagerViewModel.cleanupModel(model = model)
}
}
}
// Initialize model when model/download state changes.
val status = modelManagerUiState.modelDownloadStatus[selectedModel.name]
LaunchedEffect(status, selectedModel.name) {
if (status?.status == ModelDownloadStatusType.SUCCEEDED) {
Log.d(TAG, "Initializing model '${selectedModel.name}' from ChatView launched effect")
modelManagerViewModel.initializeModel(context, model = selectedModel)
}
}
// Update selected model and clean up previous model when page is settled on a model page.
LaunchedEffect(pagerState.settledPage) {
val curSelectedModel = task.models[pagerState.settledPage]
Log.d(
TAG,
"Pager settled on model '${curSelectedModel.name}' from '${selectedModel.name}'. Updating selected model."
)
if (curSelectedModel.name != selectedModel.name) {
modelManagerViewModel.cleanupModel(model = selectedModel)
}
modelManagerViewModel.selectModel(curSelectedModel)
}
// Handle system's edge swipe.
BackHandler {
handleNavigateUp()
}
Scaffold(modifier = modifier, topBar = {
GalleryTopAppBar(
title = task.type.label,
leftAction = AppBarAction(actionType = AppBarActionType.NAVIGATE_UP, actionFn = {
handleNavigateUp()
}),
rightAction = AppBarAction(actionType = AppBarActionType.NO_ACTION, actionFn = {}),
)
}) { innerPadding ->
Box {
// A horizontal scrollable pager to switch between models.
HorizontalPager(state = pagerState) { pageIndex ->
val curSelectedModel = task.models[pageIndex]
// Calculate the alpha of the current page based on how far they are from the center.
val pageOffset = (
(pagerState.currentPage - pageIndex) + pagerState
.currentPageOffsetFraction
).absoluteValue
val curAlpha = 1f - pageOffset.coerceIn(0f, 1f)
Column(
modifier = Modifier
.padding(innerPadding)
.fillMaxSize()
.background(MaterialTheme.colorScheme.surface)
) {
// Model selector at the top.
ModelSelector(
model = curSelectedModel,
task = task,
modelManagerViewModel = modelManagerViewModel,
onConfigChanged = { old, new ->
viewModel.addConfigChangedMessage(
oldConfigValues = old,
newConfigValues = new,
model = curSelectedModel
)
},
modifier = Modifier.fillMaxWidth(),
contentAlpha = curAlpha,
)
// Manages the conditional display of UI elements (download model button and downloading
// animation) based on the corresponding download status.
//
// It uses delayed visibility ensuring they are shown only after a short delay if their
// respective conditions remain true. This prevents UI flickering and provides a smoother
// user experience.
val curStatus = modelManagerUiState.modelDownloadStatus[curSelectedModel.name]
var shouldShowDownloadingAnimation by remember { mutableStateOf(false) }
var downloadingAnimationConditionMet by remember { mutableStateOf(false) }
var shouldShowDownloadModelButton by remember { mutableStateOf(false) }
var downloadModelButtonConditionMet by remember { mutableStateOf(false) }
downloadingAnimationConditionMet =
curStatus?.status == ModelDownloadStatusType.IN_PROGRESS ||
curStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED ||
curStatus?.status == ModelDownloadStatusType.UNZIPPING
downloadModelButtonConditionMet =
curStatus?.status == ModelDownloadStatusType.FAILED ||
curStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED
LaunchedEffect(downloadingAnimationConditionMet) {
if (downloadingAnimationConditionMet) {
delay(100)
shouldShowDownloadingAnimation = true
} else {
shouldShowDownloadingAnimation = false
}
}
LaunchedEffect(downloadModelButtonConditionMet) {
if (downloadModelButtonConditionMet) {
delay(700)
shouldShowDownloadModelButton = true
} else {
shouldShowDownloadModelButton = false
}
}
AnimatedVisibility(
visible = shouldShowDownloadingAnimation,
enter = scaleIn(initialScale = 0.9f) + fadeIn(),
exit = scaleOut(targetScale = 0.9f) + fadeOut()
) {
Box(
modifier = Modifier.fillMaxSize(),
contentAlignment = Alignment.Center
) {
ModelDownloadingAnimation()
}
}
AnimatedVisibility(
visible = shouldShowDownloadModelButton,
enter = fadeIn(),
exit = fadeOut()
) {
ModelNotDownloaded(modifier = Modifier.weight(1f), onClicked = {
checkNotificationPermissonAndStartDownload(
context = context,
launcher = launcher,
modelManagerViewModel = modelManagerViewModel,
model = curSelectedModel
)
})
}
// The main messages panel.
if (curStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
ChatPanel(
modelManagerViewModel = modelManagerViewModel,
task = task,
selectedModel = curSelectedModel,
viewModel = viewModel,
onSendMessage = onSendMessage,
onRunAgainClicked = onRunAgainClicked,
onBenchmarkClicked = onBenchmarkClicked,
onStreamImageMessage = onStreamImageMessage,
onStreamEnd = { averageFps ->
viewModel.addMessage(
model = curSelectedModel,
message = ChatMessageInfo(content = "Live camera session ended. Average FPS: $averageFps")
)
},
onStopButtonClicked = {
onStopButtonClicked(curSelectedModel)
},
modifier = Modifier
.weight(1f)
.graphicsLayer { alpha = curAlpha },
chatInputType = chatInputType,
showStopButtonInInputWhenInProgress = showStopButtonInInputWhenInProgress,
)
}
}
}
}
}
}
@Preview
@Composable
fun ChatScreenPreview() {
GalleryTheme {
val context = LocalContext.current
val task = TASK_TEST1
ChatView(
task = task,
viewModel = PreviewChatModel(context = context),
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
onSendMessage = { _, _ -> },
onRunAgainClicked = { _, _ -> },
onBenchmarkClicked = { _, _, _, _ -> },
navigateUp = {},
)
}
}

View file

@ -0,0 +1,189 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import android.util.Log
import androidx.lifecycle.ViewModel
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.Task
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.update
private const val TAG = "AGChatViewModel"
private const val START_THINKING = "***Thinking...***"
private const val DONE_THINKING = "***Done thinking***"
data class ChatUiState(
/**
* Indicates whether the runtime is currently processing a message.
*/
val inProgress: Boolean = false,
/**
* A map of model names to lists of chat messages.
*/
val messagesByModel: Map<String, MutableList<ChatMessage>> = mapOf(),
/**
* A map of model names to the currently streaming chat message.
*/
val streamingMessagesByModel: Map<String, ChatMessage> = mapOf(),
)
/**
* ViewModel responsible for managing the chat UI state and handling chat-related operations.
*/
open class ChatViewModel(val task: Task) : ViewModel() {
private val _uiState = MutableStateFlow(createUiState(task = task))
val uiState = _uiState.asStateFlow()
fun addMessage(model: Model, message: ChatMessage) {
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
val newMessages = newMessagesByModel[model.name]?.toMutableList()
if (newMessages != null) {
newMessagesByModel[model.name] = newMessages
// Remove prompt template message if it is the current last message.
if (newMessages.size > 0 && newMessages.last().type == ChatMessageType.PROMPT_TEMPLATES) {
newMessages.removeAt(newMessages.size - 1)
}
newMessages.add(message)
}
_uiState.update { _uiState.value.copy(messagesByModel = newMessagesByModel) }
}
fun removeLastMessage(model: Model) {
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
val newMessages = newMessagesByModel[model.name]?.toMutableList() ?: mutableListOf()
if (newMessages.size > 0) {
newMessages.removeAt(newMessages.size - 1)
}
newMessagesByModel[model.name] = newMessages
_uiState.update { _uiState.value.copy(messagesByModel = newMessagesByModel) }
}
fun getLastMessage(model: Model): ChatMessage? {
return (_uiState.value.messagesByModel[model.name] ?: listOf()).lastOrNull()
}
fun updateLastMessageContentIncrementally(
model: Model,
partialContent: String,
latencyMs: Float,
) {
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
val newMessages = newMessagesByModel[model.name]?.toMutableList() ?: mutableListOf()
if (newMessages.size > 0) {
val lastMessage = newMessages.last()
if (lastMessage is ChatMessageText) {
var newContent = "${lastMessage.content}${partialContent}"
// TODO: special handling for deepseek to remove the <think> tag.
// Add "thinking" and "done thinking" around the thinking content.
newContent = newContent
.replace("<think>", "$START_THINKING\n")
.replace("</think>", "\n$DONE_THINKING")
// Remove empty thinking content.
val endThinkingIndex = newContent.indexOf(DONE_THINKING)
if (endThinkingIndex >= 0) {
val thinkingContent =
newContent.substring(0, endThinkingIndex + DONE_THINKING.length)
.replace(START_THINKING, "")
.replace(DONE_THINKING, "")
if (thinkingContent.isBlank()) {
newContent = newContent.substring(endThinkingIndex + DONE_THINKING.length)
}
}
val newLastMessage = ChatMessageText(
content = newContent,
side = lastMessage.side,
latencyMs = latencyMs,
)
newMessages.removeAt(newMessages.size - 1)
newMessages.add(newLastMessage)
}
}
newMessagesByModel[model.name] = newMessages
val newUiState = _uiState.value.copy(messagesByModel = newMessagesByModel)
_uiState.update { newUiState }
}
fun replaceLastMessage(model: Model, message: ChatMessage, type: ChatMessageType) {
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
val newMessages = newMessagesByModel[model.name]?.toMutableList() ?: mutableListOf()
if (newMessages.size > 0) {
val index = newMessages.indexOfLast { it.type == type }
if (index >= 0) {
newMessages[index] = message
}
}
newMessagesByModel[model.name] = newMessages
val newUiState = _uiState.value.copy(messagesByModel = newMessagesByModel)
_uiState.update { newUiState }
}
fun replaceMessage(model: Model, index: Int, message: ChatMessage) {
val newMessagesByModel = _uiState.value.messagesByModel.toMutableMap()
val newMessages = newMessagesByModel[model.name]?.toMutableList() ?: mutableListOf()
if (newMessages.size > 0) {
newMessages[index] = message
}
newMessagesByModel[model.name] = newMessages
val newUiState = _uiState.value.copy(messagesByModel = newMessagesByModel)
_uiState.update { newUiState }
}
fun updateStreamingMessage(model: Model, message: ChatMessage) {
val newStreamingMessagesByModel = _uiState.value.streamingMessagesByModel.toMutableMap()
newStreamingMessagesByModel[model.name] = message
_uiState.update { _uiState.value.copy(streamingMessagesByModel = newStreamingMessagesByModel) }
}
fun setInProgress(inProgress: Boolean) {
_uiState.update { _uiState.value.copy(inProgress = inProgress) }
}
fun isInProgress(): Boolean {
return _uiState.value.inProgress
}
fun addConfigChangedMessage(
oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>, model: Model
) {
Log.d(TAG, "Adding config changed message. Old: ${oldConfigValues}, new: $newConfigValues")
val message = ChatMessageConfigValuesChange(
model = model, oldValues = oldConfigValues, newValues = newConfigValues
)
addMessage(message = message, model = model)
}
private fun createUiState(task: Task): ChatUiState {
val messagesByModel: MutableMap<String, MutableList<ChatMessage>> = mutableMapOf()
for (model in task.models) {
val messages: MutableList<ChatMessage> = mutableListOf()
if (model.llmPromptTemplates.isNotEmpty()) {
messages.add(ChatMessagePromptTemplates(templates = model.llmPromptTemplates))
}
messagesByModel[model.name] = messages
}
return ChatUiState(
messagesByModel = messagesByModel
)
}
}

View file

@ -0,0 +1,313 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import android.util.Log
import androidx.compose.foundation.border
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.width
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.foundation.text.BasicTextField
import androidx.compose.foundation.text.KeyboardOptions
import androidx.compose.material3.Button
import androidx.compose.material3.Card
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.SegmentedButton
import androidx.compose.material3.SegmentedButtonDefaults
import androidx.compose.material3.SingleChoiceSegmentedButtonRow
import androidx.compose.material3.Slider
import androidx.compose.material3.Switch
import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.mutableStateMapOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.runtime.snapshots.SnapshotStateMap
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.focus.FocusRequester
import androidx.compose.ui.focus.focusRequester
import androidx.compose.ui.focus.onFocusChanged
import androidx.compose.ui.text.input.KeyboardType
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog
import com.google.aiedge.gallery.data.BooleanSwitchConfig
import com.google.aiedge.gallery.data.Config
import com.google.aiedge.gallery.data.NumberSliderConfig
import com.google.aiedge.gallery.data.SegmentedButtonConfig
import com.google.aiedge.gallery.data.ValueType
import com.google.aiedge.gallery.ui.preview.MODEL_TEST1
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.labelSmallNarrow
import kotlin.Double.Companion.NaN
private const val TAG = "AGConfigDialog"
/**
* Displays a configuration dialog allowing users to modify settings through various input controls.
*/
@Composable
fun ConfigDialog(
title: String,
configs: List<Config>,
initialValues: Map<String, Any>,
onDismissed: () -> Unit,
onOk: (Map<String, Any>) -> Unit,
okBtnLabel: String = "OK",
subtitle: String = "",
showCancel: Boolean = true,
) {
val values: SnapshotStateMap<String, Any> = remember {
mutableStateMapOf<String, Any>().apply {
putAll(initialValues)
}
}
Dialog(onDismissRequest = onDismissed) {
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
Column(
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
) {
// Dialog title and subtitle.
Column {
Text(
title,
style = MaterialTheme.typography.titleLarge,
modifier = Modifier.padding(bottom = 8.dp)
)
// Subtitle.
if (subtitle.isNotEmpty()) {
Text(
subtitle,
style = labelSmallNarrow,
color = MaterialTheme.colorScheme.onSurfaceVariant,
modifier = Modifier.offset(y = (-6).dp)
)
}
}
// List of config rows.
for (config in configs) {
when (config) {
// Number slider.
is NumberSliderConfig -> {
NumberSliderRow(config = config, values = values)
}
// Boolean switch.
is BooleanSwitchConfig -> {
BooleanSwitchRow(config = config, values = values)
}
is SegmentedButtonConfig -> {
SegmentedButtonRow(config = config, values = values)
}
else -> {}
}
}
Row(
modifier = Modifier
.fillMaxWidth()
.padding(top = 8.dp),
horizontalArrangement = Arrangement.End,
) {
// Cancel button.
if (showCancel) {
TextButton(
onClick = { onDismissed() },
) {
Text("Cancel")
}
}
// Ok button
Button(
onClick = {
Log.d(TAG, "Values from dialog: $values")
onOk(values.toMap())
},
) {
Text(okBtnLabel)
}
}
}
}
}
}
/**
* Composable function to display a number slider with an associated text input field.
*
* This function renders a row containing a slider and a text field, both used to modify
* a numeric value. The slider allows users to visually adjust the value within a specified range,
* while the text field provides precise numeric input.
*/
@Composable
fun NumberSliderRow(config: NumberSliderConfig, values: SnapshotStateMap<String, Any>) {
Column(modifier = Modifier.fillMaxWidth()) {
// Field label.
Text(config.key.label, style = MaterialTheme.typography.titleSmall)
// Controls row.
Row(
modifier = Modifier.fillMaxWidth(), verticalAlignment = Alignment.CenterVertically
) {
var isFocused by remember { mutableStateOf(false) }
val focusRequester = remember { FocusRequester() }
// Number slider.
val sliderValue = try {
values[config.key.label] as Float
} catch (e: Exception) {
0f
}
Slider(modifier = Modifier
.height(24.dp)
.weight(1f),
value = sliderValue,
valueRange = config.sliderMin..config.sliderMax,
onValueChange = { values[config.key.label] = it })
Spacer(modifier = Modifier.width(8.dp))
// Text field.
val textFieldValue = try {
when (config.valueType) {
ValueType.FLOAT -> {
"%.2f".format(values[config.key.label] as Float)
}
ValueType.INT -> {
"${(values[config.key.label] as Float).toInt()}"
}
else -> {
""
}
}
} catch (e: Exception) {
""
}
// A smaller text field.
BasicTextField(
value = textFieldValue,
modifier = Modifier
.width(80.dp)
.focusRequester(focusRequester)
.onFocusChanged {
isFocused = it.isFocused
},
keyboardOptions = KeyboardOptions(keyboardType = KeyboardType.Number),
onValueChange = {
if (it.isNotEmpty()) {
values[config.key.label] = it.toFloatOrNull() ?: NaN
} else {
values[config.key.label] = NaN
}
},
) { innerTextField ->
Box(
modifier = Modifier.border(
width = if (isFocused) 2.dp else 1.dp,
color = if (isFocused) MaterialTheme.colorScheme.primary else MaterialTheme.colorScheme.outline,
shape = RoundedCornerShape(4.dp)
)
) {
Box(modifier = Modifier.padding(8.dp)) {
innerTextField()
}
}
}
}
}
}
/**
* Composable function to display a row with a boolean switch.
*
* This function renders a row containing a label and a switch, allowing users to toggle
* a boolean value.
*/
@Composable
fun BooleanSwitchRow(config: BooleanSwitchConfig, values: SnapshotStateMap<String, Any>) {
val switchValue = try {
values[config.key.label] as Boolean
} catch (e: Exception) {
false
}
Column(modifier = Modifier.fillMaxWidth()) {
Text(config.key.label, style = MaterialTheme.typography.titleSmall)
Switch(checked = switchValue, onCheckedChange = { values[config.key.label] = it })
}
}
@Composable
fun SegmentedButtonRow(config: SegmentedButtonConfig, values: SnapshotStateMap<String, Any>) {
var selectedIndex by remember { mutableIntStateOf(config.options.indexOf(values[config.key.label])) }
Column(modifier = Modifier.fillMaxWidth()) {
Text(config.key.label, style = MaterialTheme.typography.titleSmall)
SingleChoiceSegmentedButtonRow {
config.options.forEachIndexed { index, label ->
SegmentedButton(shape = SegmentedButtonDefaults.itemShape(
index = index, count = config.options.size
), onClick = {
selectedIndex = index
values[config.key.label] = label
}, selected = index == selectedIndex, label = { Text(label) })
}
}
}
}
@Composable
@Preview(showBackground = true)
fun ConfigDialogPreview() {
GalleryTheme {
val defaultValues: MutableMap<String, Any> = mutableMapOf()
for (config in MODEL_TEST1.configs) {
defaultValues[config.key.label] = config.defaultValue
}
Column {
ConfigDialog(
title = "Dialog title",
subtitle = "20250413",
configs = MODEL_TEST1.configs,
initialValues = defaultValues,
onDismissed = {},
onOk = {},
)
}
}
}

View file

@ -0,0 +1,93 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.bodySmallMediumNarrow
import com.google.aiedge.gallery.ui.theme.bodySmallMediumNarrowBold
import com.google.aiedge.gallery.ui.theme.labelSmallNarrow
import com.google.aiedge.gallery.ui.theme.labelSmallNarrowMedium
/**
* Composable function to display a data card with a label and a numeric value.
*
* This function renders a column containing a label and a formatted numeric value.
* It provides options for highlighting the value and displaying a placeholder when the value is not
* available.
*/
@Composable
fun DataCard(
label: String,
value: Float?,
unit: String,
highlight: Boolean = false,
showPlaceholder: Boolean = false
) {
var strValue = "-"
Column {
Text(label, style = labelSmallNarrowMedium)
if (showPlaceholder) {
Text("-", style = bodySmallMediumNarrow)
} else {
strValue = if (value == null) "-" else "%.2f".format(value)
if (highlight) {
Text(
strValue, style = bodySmallMediumNarrowBold, color = MaterialTheme.colorScheme.primary
)
} else {
Text(strValue, style = bodySmallMediumNarrow)
}
}
if (strValue != "-") {
Text(
unit, style = labelSmallNarrow, modifier = Modifier
.alpha(0.5f)
.offset(y = (-1).dp)
)
}
}
}
@Preview(showBackground = true)
@Composable
fun DataCardPreview() {
GalleryTheme {
Row(modifier = Modifier.padding(16.dp), horizontalArrangement = Arrangement.spacedBy(16.dp)) {
DataCard(
label = "sum", value = 123.45f, unit = "ms", highlight = true, showPlaceholder = false
)
DataCard(
label = "average", value = 12.3f, unit = "ms", highlight = false, showPlaceholder = false
)
DataCard(
label = "test", value = null, unit = "ms", highlight = false, showPlaceholder = false
)
}
}
}

View file

@ -0,0 +1,226 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import android.graphics.Bitmap
import android.graphics.Matrix
import android.util.Size
import androidx.camera.core.CameraSelector
import androidx.camera.core.ImageAnalysis
import androidx.camera.core.resolutionselector.ResolutionSelector
import androidx.camera.core.resolutionselector.ResolutionStrategy
import androidx.camera.lifecycle.ProcessCameraProvider
import androidx.compose.foundation.Image
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.aspectRatio
import androidx.compose.foundation.layout.fillMaxHeight
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material3.Card
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableLongStateOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.ImageBitmap
import androidx.compose.ui.graphics.asImageBitmap
import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog
import androidx.core.content.ContextCompat
import androidx.lifecycle.compose.LocalLifecycleOwner
import java.util.concurrent.Executors
import kotlin.coroutines.resume
import kotlin.coroutines.suspendCoroutine
/**
* Composable function to display a live camera feed in a dialog.
*
* This function renders a dialog that displays a live camera preview, along with optional
* classification results and FPS information. It manages camera initialization, frame capture,
* and dialog dismissal.
*/
@Composable
fun LiveCameraDialog(
onDismissed: (averageFps: Int) -> Unit,
onBitmap: (Bitmap) -> Unit,
streamingMessage: ChatMessage? = null,
) {
val context = LocalContext.current
val lifecycleOwner = LocalLifecycleOwner.current
var imageBitmap by remember { mutableStateOf<ImageBitmap?>(null) }
var cameraProvider: ProcessCameraProvider? by remember { mutableStateOf(null) }
var sumFps by remember { mutableLongStateOf(0L) }
var fpsCount by remember { mutableLongStateOf(0L) }
LaunchedEffect(key1 = true) {
cameraProvider = startCamera(
context,
lifecycleOwner,
onBitmap = onBitmap,
onImageBitmap = { b -> imageBitmap = b })
}
Dialog(onDismissRequest = {
cameraProvider?.unbindAll()
onDismissed((sumFps / fpsCount).toInt())
}) {
Card(modifier = Modifier.fillMaxWidth(), shape = RoundedCornerShape(16.dp)) {
Column(
modifier = Modifier.padding(20.dp), verticalArrangement = Arrangement.spacedBy(16.dp)
) {
// Title
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.SpaceBetween,
modifier = Modifier
.fillMaxWidth()
.padding(bottom = 8.dp)
) {
Text(
"Live camera",
style = MaterialTheme.typography.titleLarge,
)
if (streamingMessage != null) {
val fps = (1000f / streamingMessage.latencyMs).toInt()
sumFps += fps.toLong()
fpsCount += 1
Text(
"%d FPS".format(fps),
style = MaterialTheme.typography.titleLarge,
)
}
}
// Camera live view.
Row(
modifier = Modifier
.fillMaxWidth()
.aspectRatio(1f),
horizontalArrangement = Arrangement.Center
) {
val ib = imageBitmap
if (ib != null) {
Image(
bitmap = ib,
contentDescription = "",
modifier = Modifier
.fillMaxHeight()
.clip(RoundedCornerShape(8.dp)),
contentScale = ContentScale.Inside
)
}
}
// Result.
if (streamingMessage != null && streamingMessage is ChatMessageClassification) {
MessageBodyClassification(
message = streamingMessage,
modifier = Modifier.fillMaxWidth(),
oneLineLabel = true
)
}
// Button.
Row(
modifier = Modifier
.fillMaxWidth()
.padding(top = 8.dp),
horizontalArrangement = Arrangement.End,
) {
TextButton(
onClick = {
cameraProvider?.unbindAll()
onDismissed((sumFps / fpsCount).toInt())
},
) {
Text("OK")
}
}
}
}
}
}
/**
* Asynchronously initializes and starts the camera for image capture and analysis.
*
* This function sets up the camera using CameraX, configures image analysis, and binds
* the camera lifecycle to the provided LifecycleOwner. It captures frames from the camera,
* converts them to Bitmaps and ImageBitmaps, and invokes the provided callbacks.
*/
private suspend fun startCamera(
context: android.content.Context,
lifecycleOwner: androidx.lifecycle.LifecycleOwner,
onBitmap: (Bitmap) -> Unit,
onImageBitmap: (ImageBitmap) -> Unit
): ProcessCameraProvider? = suspendCoroutine { continuation ->
val cameraProviderFuture = ProcessCameraProvider.getInstance(context)
cameraProviderFuture.addListener({
val cameraProvider = cameraProviderFuture.get()
val resolutionSelector = ResolutionSelector.Builder().setResolutionStrategy(
ResolutionStrategy(
Size(1080, 1080),
ResolutionStrategy.FALLBACK_RULE_CLOSEST_LOWER_THEN_HIGHER
)
).build()
val imageAnalysis =
ImageAnalysis.Builder().setResolutionSelector(resolutionSelector)
.setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST).build().also {
it.setAnalyzer(Executors.newSingleThreadExecutor()) { imageProxy ->
var bitmap = imageProxy.toBitmap()
val rotation = imageProxy.imageInfo.rotationDegrees
bitmap = if (rotation != 0) {
val matrix = Matrix().apply {
postRotate(rotation.toFloat())
}
Bitmap.createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, matrix, true)
} else bitmap
onBitmap(bitmap)
onImageBitmap(bitmap.asImageBitmap())
imageProxy.close()
}
}
val cameraSelector = CameraSelector.DEFAULT_BACK_CAMERA
try {
cameraProvider?.unbindAll()
cameraProvider?.bindToLifecycle(
lifecycleOwner, cameraSelector, imageAnalysis
)
// Resume with the provider
continuation.resume(cameraProvider)
} catch (exc: Exception) {
// todo: Handle exceptions (e.g., camera initialization failure)
}
}, ContextCompat.getMainExecutor(context))
}

View file

@ -0,0 +1,76 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ProvideTextStyle
import androidx.compose.runtime.Composable
import androidx.compose.runtime.CompositionLocalProvider
import androidx.compose.ui.Modifier
import androidx.compose.ui.text.TextStyle
import androidx.compose.ui.text.font.FontFamily
import androidx.compose.ui.tooling.preview.Preview
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.halilibo.richtext.commonmark.Markdown
import com.halilibo.richtext.ui.CodeBlockStyle
import com.halilibo.richtext.ui.RichTextStyle
import com.halilibo.richtext.ui.material3.RichText
/**
* Composable function to display Markdown-formatted text.
*/
@Composable
fun MarkdownText(
text: String,
modifier: Modifier = Modifier,
smallFontSize: Boolean = false
) {
val fontSize =
if (smallFontSize) MaterialTheme.typography.bodySmall.fontSize else MaterialTheme.typography.bodyMedium.fontSize
CompositionLocalProvider {
ProvideTextStyle(
value = TextStyle(
fontSize = fontSize,
lineHeight = fontSize * 1.2,
)
) {
RichText(
modifier = modifier,
style = RichTextStyle(
codeBlockStyle = CodeBlockStyle(
textStyle = TextStyle(
fontSize = MaterialTheme.typography.bodySmall.fontSize,
fontFamily = FontFamily.Monospace,
)
)
),
) {
Markdown(
content = text
)
}
}
}
}
@Preview(showBackground = true)
@Composable
fun MarkdownTextPreview() {
GalleryTheme {
MarkdownText(text = "*Hello World*\n**Good morning!!**")
}
}

View file

@ -0,0 +1,95 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.PlayArrow
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.bodySmallNarrow
/**
* Composable function to display an action button below a chat message.
*/
@Composable
fun MessageActionButton(
label: String,
icon: ImageVector,
onClick: () -> Unit,
enabled: Boolean = true
) {
val modifier = Modifier
.padding(top = 4.dp)
.clip(CircleShape)
.background(if (enabled) MaterialTheme.colorScheme.secondaryContainer else MaterialTheme.colorScheme.surfaceContainerHigh)
val alpha: Float = if (enabled) 1.0f else 0.3f
Row(
modifier = if (enabled) modifier.clickable { onClick() } else modifier,
verticalAlignment = Alignment.CenterVertically,
) {
Icon(
icon, contentDescription = "", modifier = Modifier
.size(16.dp)
.offset(x = 6.dp)
.alpha(alpha)
)
Text(
label,
color = MaterialTheme.colorScheme.onSecondaryContainer,
style = bodySmallNarrow,
modifier = Modifier
.padding(
start = 10.dp, end = 8.dp, top = 4.dp, bottom = 4.dp
)
.alpha(alpha)
)
}
}
@Preview(showBackground = true)
@Composable
fun MessageActionButtonPreview() {
GalleryTheme {
Column {
MessageActionButton(label = "run", icon = Icons.Default.PlayArrow, onClick = {})
MessageActionButton(
label = "run",
icon = Icons.Default.PlayArrow,
enabled = false,
onClick = {})
}
}
}

View file

@ -0,0 +1,140 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.width
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material3.MaterialTheme
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import kotlin.math.max
private const val DEFAULT_HISTOGRAM_BAR_HEIGHT = 50f
/**
* Composable function to display benchmark results within a chat message.
*
* This function renders benchmark statistics (e.g., average latency) in data cards and
* visualizes the latency distribution using a histogram.
*/
@Composable
fun MessageBodyBenchmark(message: ChatMessageBenchmarkResult) {
Column(
modifier = Modifier
.padding(12.dp)
.fillMaxWidth(),
verticalArrangement = Arrangement.spacedBy(8.dp)
) {
// Data cards.
Row(
modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.SpaceBetween
) {
for (stat in message.orderedStats) {
DataCard(
label = stat.label,
unit = stat.unit,
value = message.statValues[stat.id],
highlight = stat.id == message.highlightStat,
showPlaceholder = message.isWarmingUp()
)
}
}
// Histogram
if (message.histogram.buckets.isNotEmpty()) {
Row(
horizontalArrangement = Arrangement.spacedBy(2.dp)
) {
for ((index, count) in message.histogram.buckets.withIndex()) {
var barBgColor = MaterialTheme.colorScheme.onSurfaceVariant
var alpha = 0.3f
if (count != 0) {
alpha = 0.5f
}
if (index == message.histogram.highlightBucketIndex) {
barBgColor = MaterialTheme.colorScheme.primary
alpha = 0.8f
}
// Bar container.
Column(
modifier = Modifier
.height(DEFAULT_HISTOGRAM_BAR_HEIGHT.dp)
.width(4.dp),
verticalArrangement = Arrangement.Bottom,
) {
// Bar content.
Box(
modifier = Modifier
.height(
max(
1f,
count.toFloat() / message.histogram.maxCount.toFloat() * DEFAULT_HISTOGRAM_BAR_HEIGHT
).dp
)
.fillMaxWidth()
.clip(RoundedCornerShape(20.dp, 20.dp, 0.dp, 0.dp))
.alpha(alpha)
.background(barBgColor)
)
}
}
}
}
}
}
@Preview(showBackground = true)
@Composable
fun MessageBodyBenchmarkPreview() {
GalleryTheme {
MessageBodyBenchmark(
message = ChatMessageBenchmarkResult(
orderedStats = listOf(
Stat(id = "stat1", label = "Stat1", unit = "ms"),
Stat(id = "stat2", label = "Stat2", unit = "ms"),
Stat(id = "stat3", label = "Stat3", unit = "ms"),
Stat(id = "stat4", label = "Stat4", unit = "ms")
),
statValues = mutableMapOf(
"stat1" to 0.3f,
"stat2" to 0.4f,
"stat3" to 0.5f,
),
values = listOf(),
histogram = Histogram(listOf(), 0),
warmupCurrent = 0,
warmupTotal = 0,
iterationCurrent = 0,
iterationTotal = 0,
highlightStat = "stat2"
)
)
}
}

View file

@ -0,0 +1,76 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.wrapContentWidth
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.ui.theme.GalleryTheme
/**
* Composable function to display benchmark LLM results within a chat message.
*
* This function renders benchmark statistics (e.g., various token speed) in data cards
*/
@Composable
fun MessageBodyBenchmarkLlm(message: ChatMessageBenchmarkLlmResult) {
Column(
modifier = Modifier
.padding(12.dp)
.wrapContentWidth(),
verticalArrangement = Arrangement.spacedBy(8.dp)
) {
// Data cards.
Row(
modifier = Modifier.wrapContentWidth(), horizontalArrangement = Arrangement.spacedBy(16.dp)
) {
for (stat in message.orderedStats) {
DataCard(
label = stat.label,
unit = stat.unit,
value = message.statValues[stat.id],
)
}
}
}
}
@Preview(showBackground = true)
@Composable
fun MessageBodyBenchmarkLlmPreview() {
GalleryTheme {
MessageBodyBenchmarkLlm(
message = ChatMessageBenchmarkLlmResult(
orderedStats = listOf(
Stat(id = "stat1", label = "Stat1", unit = "tokens/s"),
Stat(id = "stat2", label = "Stat2", unit = "tokens/s")
),
statValues = mutableMapOf(
"stat1" to 0.3f,
"stat2" to 0.4f,
),
running = false,
)
)
}
}

View file

@ -0,0 +1,115 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.ui.theme.GalleryTheme
val CLASSIFICATION_BAR_HEIGHT = 8.dp
val CLASSIFICATION_BAR_MAX_WIDTH = 200.dp
/**
* Composable function to display classification results.
*
* This function renders a list of classifications, each with its label, score, and a visual score bar.
*/
@Composable
fun MessageBodyClassification(
message: ChatMessageClassification,
modifier: Modifier = Modifier,
oneLineLabel: Boolean = false,
) {
Column(
modifier = modifier.padding(12.dp)
) {
for (classification in message.classifications) {
Row(
modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.SpaceBetween
) {
// Classification label.
Text(
classification.label,
maxLines = if (oneLineLabel) 1 else Int.MAX_VALUE,
overflow = TextOverflow.Ellipsis,
style = MaterialTheme.typography.bodySmall,
modifier = Modifier.weight(1f)
)
// Classification score.
Text(
"%.2f".format(classification.score),
style = MaterialTheme.typography.bodySmall,
modifier = Modifier
.align(Alignment.Bottom),
)
}
Spacer(modifier = Modifier.height(2.dp))
// Score bar.
Box {
Box(
modifier = Modifier
.fillMaxWidth()
.height(CLASSIFICATION_BAR_HEIGHT)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.surfaceDim)
)
Box(
modifier = Modifier
.fillMaxWidth(classification.score)
.height(CLASSIFICATION_BAR_HEIGHT)
.clip(CircleShape)
.background(classification.color)
)
}
Spacer(modifier = Modifier.height(6.dp))
}
}
}
@Preview(showBackground = true)
@Composable
fun MessageBodyClassificationPreview() {
GalleryTheme {
MessageBodyClassification(
message = ChatMessageClassification(
classifications = listOf(
Classification(label = "label1", score = 0.3f, color = Color.Red),
Classification(label = "label2", score = 0.7f, color = Color.Blue)
),
latencyMs = 12345f,
),
)
}
}

View file

@ -0,0 +1,144 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.width
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
import com.google.aiedge.gallery.ui.common.getConfigValueString
import com.google.aiedge.gallery.ui.preview.MODEL_TEST1
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.bodySmallNarrow
import com.google.aiedge.gallery.ui.theme.titleSmaller
/**
* Composable function to display a message indicating configuration value changes.
*
* This function renders a centered row containing a box that displays the old and new
* values of configuration settings that have been updated.
*/
@Composable
fun MessageBodyConfigUpdate(message: ChatMessageConfigValuesChange) {
Row(
modifier = Modifier.fillMaxWidth(),
horizontalArrangement = Arrangement.Center,
) {
Box(
modifier = Modifier
.clip(RoundedCornerShape(4.dp))
.background(MaterialTheme.colorScheme.tertiaryContainer)
) {
Column(modifier = Modifier.padding(8.dp)) {
// Title.
Text(
"Configs updated",
color = MaterialTheme.colorScheme.onTertiaryContainer,
style = titleSmaller,
)
Row(modifier = Modifier.padding(top = 8.dp)) {
// Keys
Column {
for (config in message.model.configs) {
Text(
"${config.key.label}:",
style = bodySmallNarrow,
modifier = Modifier.alpha(0.6f),
)
}
}
Spacer(modifier = Modifier.width(4.dp))
// Values
Column {
for (config in message.model.configs) {
val key = config.key.label
val oldValue: Any = convertValueToTargetType(
value = message.oldValues.getValue(key), valueType = config.valueType
)
val newValue: Any = convertValueToTargetType(
value = message.newValues.getValue(key), valueType = config.valueType
)
if (oldValue == newValue) {
Text("$newValue", style = bodySmallNarrow)
} else {
Row(verticalAlignment = Alignment.CenterVertically) {
Text(
getConfigValueString(oldValue, config), style = bodySmallNarrow
)
Text(
"",
style = bodySmallNarrow.copy(fontSize = 12.sp),
modifier = Modifier.padding(start = 4.dp, end = 4.dp)
)
Text(
getConfigValueString(newValue, config),
style = bodySmallNarrow.copy(fontWeight = FontWeight.Bold),
color = MaterialTheme.colorScheme.primary,
)
}
}
}
}
}
}
}
}
}
@Preview(showBackground = true)
@Composable
fun MessageBodyConfigUpdatePreview() {
GalleryTheme {
Row(modifier = Modifier.padding(16.dp)) {
MessageBodyConfigUpdate(
message = ChatMessageConfigValuesChange(
model = MODEL_TEST1,
oldValues = mapOf(
ConfigKey.MAX_RESULT_COUNT.label to 100,
ConfigKey.USE_GPU.label to false
),
newValues = mapOf(
ConfigKey.MAX_RESULT_COUNT.label to 200,
ConfigKey.USE_GPU.label to true
)
)
)
}
}
}

View file

@ -0,0 +1,43 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.foundation.Image
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.width
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.unit.dp
@Composable
fun MessageBodyImage(message: ChatMessageImage) {
val bitmapWidth = message.bitmap.width
val bitmapHeight = message.bitmap.height
val imageWidth =
if (bitmapWidth >= bitmapHeight) 200 else (200f / bitmapHeight * bitmapWidth).toInt()
val imageHeight =
if (bitmapHeight >= bitmapWidth) 200 else (200f / bitmapWidth * bitmapHeight).toInt()
Image(
bitmap = message.imageBitMap,
contentDescription = "",
modifier = Modifier
.height(imageHeight.dp)
.width(imageWidth.dp),
contentScale = ContentScale.Fit,
)
}

View file

@ -0,0 +1,91 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.foundation.Image
import androidx.compose.foundation.gestures.detectHorizontalDragGestures
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.width
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.MutableIntState
import androidx.compose.runtime.MutableState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableFloatStateOf
import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Modifier
import androidx.compose.ui.input.pointer.pointerInput
import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.unit.dp
/**
* Composable function to display an image message with history, allowing users to navigate through
* different versions by sliding on the image.
*/
@Composable
fun MessageBodyImageWithHistory(
message: ChatMessageImageWithHistory,
imageHistoryCurIndex: MutableIntState
) {
val prevMessage: MutableState<ChatMessageImageWithHistory?> = remember { mutableStateOf(null) }
LaunchedEffect(message) {
imageHistoryCurIndex.intValue = message.bitmaps.size - 1
prevMessage.value = message
}
Column {
val curImage = message.bitmaps[imageHistoryCurIndex.intValue]
val curImageBitmap = message.imageBitMaps[imageHistoryCurIndex.intValue]
val bitmapWidth = curImage.width
val bitmapHeight = curImage.height
val imageWidth =
if (bitmapWidth >= bitmapHeight) 200 else (200f / bitmapHeight * bitmapWidth).toInt()
val imageHeight =
if (bitmapHeight >= bitmapWidth) 200 else (200f / bitmapWidth * bitmapHeight).toInt()
var value by remember { mutableFloatStateOf(0f) }
var savedIndex by remember { mutableIntStateOf(0) }
Image(
bitmap = curImageBitmap,
contentDescription = "",
modifier = Modifier
.height(imageHeight.dp)
.width(imageWidth.dp)
.pointerInput(Unit) {
detectHorizontalDragGestures(onDragStart = {
value = 0f
savedIndex = imageHistoryCurIndex.intValue
}) { _, dragAmount ->
value += (dragAmount / 20f)// Adjust sensitivity here
imageHistoryCurIndex.intValue = (savedIndex + value).toInt()
if (imageHistoryCurIndex.intValue < 0) {
imageHistoryCurIndex.intValue = 0
} else if (imageHistoryCurIndex.intValue > message.bitmaps.size - 1) {
imageHistoryCurIndex.intValue = message.bitmaps.size - 1
}
}
},
contentScale = ContentScale.Fit,
)
}
}

View file

@ -0,0 +1,63 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material3.MaterialTheme
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.customColors
/**
* Composable function to display informational message content within a chat.
*
* Supports markdown.
*/
@Composable
fun MessageBodyInfo(message: ChatMessageInfo) {
Row(
modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center
) {
Box(
modifier = Modifier
.clip(RoundedCornerShape(16.dp))
.background(MaterialTheme.customColors.agentBubbleBgColor)
) {
MarkdownText(text = message.content, modifier = Modifier.padding(12.dp), smallFontSize = true)
}
}
}
@Preview(showBackground = true)
@Composable
fun MessageBodyInfoPreview() {
GalleryTheme {
Row(modifier = Modifier.padding(16.dp)) {
MessageBodyInfo(message = ChatMessageInfo(content = "This is a model"))
}
}
}

View file

@ -0,0 +1,142 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.animation.core.Animatable
import androidx.compose.animation.core.tween
import androidx.compose.foundation.Image
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.remember
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.ColorFilter
import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.ui.common.getTaskIconColor
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
private val IMAGE_RESOURCES = listOf(
R.drawable.pantegon,
R.drawable.double_circle,
R.drawable.circle,
R.drawable.four_circle
)
private const val ANIMATION_DURATION = 300
private const val ANIMATION_DURATION2 = 300
private const val PAUSE_DURATION = 200
private const val PAUSE_DURATION2 = 0
/**
* Composable function to display a loading indicator.
*/
@Composable
fun MessageBodyLoading() {
val progress = remember { Animatable(0f) }
val alphaAnim = remember { Animatable(0f) }
val activeImageIndex = remember { mutableIntStateOf(0) }
LaunchedEffect(Unit) { // Run this once
while (true) {
var progressJob = launch {
progress.animateTo(
targetValue = 1f,
animationSpec = tween(
durationMillis = ANIMATION_DURATION,
easing = multiBounceEasing(bounces = 3, decay = 0.02f)
)
)
}
var alphaJob = launch {
alphaAnim.animateTo(
targetValue = 1f,
animationSpec = tween(
durationMillis = ANIMATION_DURATION / 2,
)
)
}
progressJob.join()
alphaJob.join()
delay((PAUSE_DURATION).toLong())
progressJob = launch {
progress.animateTo(
targetValue = 0f,
animationSpec = tween(
durationMillis = ANIMATION_DURATION2,
easing = multiBounceEasing(bounces = 3, decay = 0.02f)
)
)
}
alphaJob = launch {
alphaAnim.animateTo(
targetValue = 0f,
animationSpec = tween(
durationMillis = ANIMATION_DURATION2 / 2,
)
)
}
progressJob.join()
alphaJob.join()
delay(PAUSE_DURATION2.toLong())
activeImageIndex.intValue = (activeImageIndex.intValue + 1) % IMAGE_RESOURCES.size
}
}
Box(contentAlignment = Alignment.Center) {
for ((index, imageResource) in IMAGE_RESOURCES.withIndex()) {
Image(
painter = painterResource(id = imageResource),
contentDescription = "",
contentScale = ContentScale.Fit,
colorFilter = ColorFilter.tint(getTaskIconColor(index = index)),
modifier = Modifier
.graphicsLayer {
scaleX = progress.value * 0.2f + 0.8f
scaleY = progress.value * 0.2f + 0.8f
rotationZ = progress.value * 100
alpha = if (index != activeImageIndex.intValue) 0f else alphaAnim.value
}
.size(24.dp)
)
}
}
}
@Preview(showBackground = true)
@Composable
fun MessageBodyLoadingPreview() {
GalleryTheme {
Row(modifier = Modifier.padding(16.dp)) {
MessageBodyLoading()
}
}
}

View file

@ -0,0 +1,168 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.foundation.background
import androidx.compose.foundation.border
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.shadow
import androidx.compose.ui.graphics.Brush
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.getTaskIconColor
import com.google.aiedge.gallery.ui.preview.ALL_PREVIEW_TASKS
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
import com.google.aiedge.gallery.ui.theme.GalleryTheme
private const val CARD_HEIGHT = 100
@Composable
fun MessageBodyPromptTemplates(
message: ChatMessagePromptTemplates,
task: Task,
onPromptClicked: (PromptTemplate) -> Unit = {},
) {
val rowCount = message.templates.size.toFloat()
val color = getTaskIconColor(task)
val gradientColors = listOf(color.copy(alpha = 0.5f), color)
Column(
modifier = Modifier.padding(top = 12.dp),
verticalArrangement = Arrangement.spacedBy(8.dp)
) {
Text(
"Try an example prompt",
style = MaterialTheme.typography.titleLarge.copy(
fontWeight = FontWeight.Bold,
brush = Brush.linearGradient(
colors = gradientColors,
)
),
modifier = Modifier.fillMaxWidth(),
textAlign = TextAlign.Center,
)
if (message.showMakeYourOwn) {
Text(
"Or make your own",
style = MaterialTheme.typography.titleSmall,
modifier = Modifier
.fillMaxWidth()
.offset(y = -4.dp),
textAlign = TextAlign.Center,
)
}
LazyColumn(
modifier = Modifier
.height((rowCount * (CARD_HEIGHT + 8)).dp),
verticalArrangement = Arrangement.spacedBy(8.dp),
) {
// Cards.
items(message.templates) { template ->
Box(
modifier = Modifier
.border(
width = 1.dp,
color = color.copy(alpha = 0.3f),
shape = RoundedCornerShape(24.dp)
)
.height(CARD_HEIGHT.dp)
.shadow(
elevation = 2.dp,
shape = RoundedCornerShape(24.dp),
spotColor = color
)
.background(MaterialTheme.colorScheme.surface)
.clickable {
onPromptClicked(template)
}
) {
Column(
modifier = Modifier
.padding(horizontal = 12.dp, vertical = 20.dp)
.fillMaxSize(),
horizontalAlignment = Alignment.CenterHorizontally,
) {
Text(
template.title,
style = MaterialTheme.typography.titleMedium.copy(fontWeight = FontWeight.Bold),
)
Spacer(modifier = Modifier.weight(1f))
Text(
template.description,
style = MaterialTheme.typography.bodyMedium,
textAlign = TextAlign.Center,
)
}
}
}
}
}
}
@Preview(showBackground = true)
@Composable
fun MessageBodyPromptTemplatesPreview() {
for ((index, task) in ALL_PREVIEW_TASKS.withIndex()) {
task.index = index
for (model in task.models) {
model.preProcess(task = task)
}
}
GalleryTheme {
Row(modifier = Modifier.padding(16.dp)) {
MessageBodyPromptTemplates(
message = ChatMessagePromptTemplates(
templates = listOf(
PromptTemplate(
title = "Math Worksheets",
description = "Create a set of math worksheets for parents",
prompt = ""
),
PromptTemplate(
title = "Shape Sequencer",
description = "Find the next shape in a sequence",
prompt = ""
)
)
),
task = TASK_TEST1,
)
}
}
}

View file

@ -0,0 +1,80 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.padding
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.ui.theme.GalleryTheme
/**
* Composable function to display the text content of a ChatMessageText.
*/
@Composable
fun MessageBodyText(message: ChatMessageText) {
if (message.side == ChatSide.USER) {
Text(
message.content,
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.Medium),
color = Color.White,
modifier = Modifier.padding(12.dp)
)
} else if (message.side == ChatSide.AGENT) {
if (message.isMarkdown) {
MarkdownText(text = message.content, modifier = Modifier.padding(12.dp))
} else {
Text(
message.content,
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.Medium),
color = MaterialTheme.colorScheme.onSurface,
modifier = Modifier.padding(12.dp)
)
}
}
}
@Preview(showBackground = true)
@Composable
fun MessageBodyTextPreview() {
GalleryTheme {
Column {
Row(
modifier = Modifier
.padding(16.dp)
.background(MaterialTheme.colorScheme.primary),
) {
MessageBodyText(ChatMessageText(content = "Hello world", side = ChatSide.USER))
}
Row(
modifier = Modifier
.padding(16.dp)
.background(MaterialTheme.colorScheme.surfaceContainer),
) {
MessageBodyText(ChatMessageText(content = "yes hello world", side = ChatSide.AGENT))
}
}
}
}

View file

@ -0,0 +1,69 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.ui.geometry.CornerRadius
import androidx.compose.ui.geometry.RoundRect
import androidx.compose.ui.geometry.Size
import androidx.compose.ui.graphics.Outline
import androidx.compose.ui.graphics.Path
import androidx.compose.ui.graphics.Shape
import androidx.compose.ui.unit.Density
import androidx.compose.ui.unit.Dp
import androidx.compose.ui.unit.LayoutDirection
/**
* Custom Shape for creating message bubble outlines with configurable corner radii.
*
* This class defines a custom Shape that generates a rounded rectangle outline,
* suitable for message bubbles. It allows specifying a uniform corner radius for
* most corners, but also provides the option to have a hard (non-rounded) corner
* on either the left or right side.
*/
class MessageBubbleShape(
private val radius: Dp,
private val hardCornerAtLeftOrRight: Boolean = false
) : Shape {
override fun createOutline(
size: Size,
layoutDirection: LayoutDirection,
density: Density
): Outline {
val radiusPx = with(density) { radius.toPx() }
val path = Path().apply {
addRoundRect(
RoundRect(
left = 0f,
top = 0f,
right = size.width,
bottom = size.height,
topLeftCornerRadius = if (hardCornerAtLeftOrRight) CornerRadius(0f, 0f) else CornerRadius(
radiusPx,
radiusPx
),
topRightCornerRadius = if (hardCornerAtLeftOrRight) CornerRadius(
radiusPx,
radiusPx
) else CornerRadius(0f, 0f), // No rounding here
bottomLeftCornerRadius = CornerRadius(radiusPx, radiusPx),
bottomRightCornerRadius = CornerRadius(radiusPx, radiusPx)
)
)
}
return Outline.Generic(path)
}
}

View file

@ -0,0 +1,269 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import android.Manifest
import android.content.Context
import android.content.pm.PackageManager
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.graphics.Matrix
import android.net.Uri
import android.util.Log
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.PickVisualMediaRequest
import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.Photo
import androidx.compose.material.icons.rounded.PhotoCamera
import androidx.compose.material.icons.rounded.Videocam
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.IconButtonDefaults
import androidx.compose.material3.MaterialTheme
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import androidx.core.content.ContextCompat
import com.google.aiedge.gallery.ui.common.createTempPictureUri
import com.google.aiedge.gallery.ui.theme.GalleryTheme
private const val TAG = "AGMessageInputImage"
/**
* Composable function to display image input options for chat messages.
*
* This function renders a row containing buttons that allow the user to select images from albums,
* take photos using the camera, or initiate a live camera stream. It handles permission requests,
* image selection, and launching camera activities.
*/
@Composable
fun MessageInputImage(
onImageSelected: (Bitmap) -> Unit,
streamingMessage: ChatMessage? = null,
onStreamImage: (Bitmap) -> Unit = {},
onStreamEnd: (Int) -> Unit = {},
disableButtons: Boolean = false,
) {
val context = LocalContext.current
var tempPhotoUri by remember { mutableStateOf(value = Uri.EMPTY) }
var showLiveCameraDialog by remember { mutableStateOf(false) }
// Registers a photo picker activity launcher in single-select mode.
val pickMedia =
rememberLauncherForActivityResult(ActivityResultContracts.PickVisualMedia()) { uri ->
// Callback is invoked after the user selects a media item or closes the
// photo picker.
if (uri != null) {
handleImageSelected(context = context, uri = uri, onImageSelected = onImageSelected)
} else {
Log.d(TAG, "No media selected")
}
}
// launches camera
val cameraLauncher =
rememberLauncherForActivityResult(ActivityResultContracts.TakePicture()) { isImageSaved ->
if (isImageSaved) {
handleImageSelected(
context = context,
uri = tempPhotoUri,
onImageSelected = onImageSelected,
rotateForPortrait = true,
)
}
}
// Permission request when taking picture.
val takePicturePermissionLauncher = rememberLauncherForActivityResult(
ActivityResultContracts.RequestPermission()
) { permissionGranted ->
if (permissionGranted) {
tempPhotoUri = context.createTempPictureUri()
cameraLauncher.launch(tempPhotoUri)
}
}
// Permission request when using live camera.
val liveCameraPermissionLauncher = rememberLauncherForActivityResult(
ActivityResultContracts.RequestPermission()
) { permissionGranted ->
if (permissionGranted) {
showLiveCameraDialog = true
}
}
val buttonAlpha = if (disableButtons) 0.3f else 1f
Row(
modifier = Modifier
.fillMaxWidth()
.padding(12.dp),
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.End,
) {
// Pick from albums.
IconButton(
onClick = {
if (disableButtons) {
return@IconButton
}
// Launch the photo picker and let the user choose only images.
pickMedia.launch(PickVisualMediaRequest(ActivityResultContracts.PickVisualMedia.ImageOnly))
},
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.primary,
),
modifier = Modifier.alpha(buttonAlpha),
) {
Icon(Icons.Rounded.Photo, contentDescription = "", tint = MaterialTheme.colorScheme.onPrimary)
}
// Take picture
IconButton(
onClick = {
if (disableButtons) {
return@IconButton
}
// Check permission
when (PackageManager.PERMISSION_GRANTED) {
// Already got permission. Call the lambda.
ContextCompat.checkSelfPermission(
context, Manifest.permission.CAMERA
) -> {
tempPhotoUri = context.createTempPictureUri()
cameraLauncher.launch(tempPhotoUri)
}
// Otherwise, ask for permission
else -> {
takePicturePermissionLauncher.launch(Manifest.permission.CAMERA)
}
}
},
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.primary,
),
modifier = Modifier.alpha(buttonAlpha),
) {
Icon(
Icons.Rounded.PhotoCamera,
contentDescription = "",
tint = MaterialTheme.colorScheme.onPrimary
)
}
// Video stream.
IconButton(
onClick = {
if (disableButtons) {
return@IconButton
}
// Check permission
when (PackageManager.PERMISSION_GRANTED) {
// Already got permission. Call the lambda.
ContextCompat.checkSelfPermission(
context, Manifest.permission.CAMERA
) -> {
showLiveCameraDialog = true
}
// Otherwise, ask for permission
else -> {
liveCameraPermissionLauncher.launch(Manifest.permission.CAMERA)
}
}
},
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.primary,
),
modifier = Modifier.alpha(buttonAlpha),
) {
Icon(
Icons.Rounded.Videocam, contentDescription = "", tint = MaterialTheme.colorScheme.onPrimary
)
}
}
// Live camera stream dialog.
if (showLiveCameraDialog) {
LiveCameraDialog(
streamingMessage = streamingMessage, onDismissed = { averageFps ->
onStreamEnd(averageFps)
showLiveCameraDialog = false
}, onBitmap = onStreamImage
)
}
}
private fun handleImageSelected(
context: Context,
uri: Uri,
onImageSelected: (Bitmap) -> Unit,
// For some reason, some Android phone would store the picture taken by the camera rotated
// horizontally. Use this flag to rotate the image back to portrait if the picture's width
// is bigger than height.
rotateForPortrait: Boolean = false,
) {
Log.d(TAG, "Selected URI: $uri")
val bitmap: Bitmap? = try {
val inputStream = context.contentResolver.openInputStream(uri)
val tmpBitmap = BitmapFactory.decodeStream(inputStream)
if (rotateForPortrait && tmpBitmap.width > tmpBitmap.height) {
val matrix = Matrix()
matrix.postRotate(90f)
Bitmap.createBitmap(tmpBitmap, 0, 0, tmpBitmap.width, tmpBitmap.height, matrix, true)
} else {
tmpBitmap
}
} catch (e: Exception) {
e.printStackTrace()
null
}
if (bitmap != null) {
onImageSelected(bitmap)
}
}
@Preview(showBackground = true)
@Composable
fun MessageInputImagePreview() {
GalleryTheme {
Column {
MessageInputImage(onImageSelected = {})
MessageInputImage(disableButtons = true, onImageSelected = {})
}
}
}

View file

@ -0,0 +1,268 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.annotation.StringRes
import androidx.compose.foundation.border
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.width
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.Send
import androidx.compose.material.icons.rounded.Add
import androidx.compose.material.icons.rounded.History
import androidx.compose.material.icons.rounded.PostAdd
import androidx.compose.material.icons.rounded.Stop
import androidx.compose.material3.DropdownMenu
import androidx.compose.material3.DropdownMenuItem
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.IconButtonDefaults
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.material3.TextField
import androidx.compose.material3.TextFieldDefaults
import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.theme.GalleryTheme
/**
* Composable function to display a text input field for composing chat messages.
*
* This function renders a row containing a text field for message input and a send button.
* It handles message composition, input validation, and sending messages.
*/
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun MessageInputText(
modelManagerViewModel: ModelManagerViewModel,
curMessage: String,
inProgress: Boolean,
@StringRes textFieldPlaceHolderRes: Int,
onValueChanged: (String) -> Unit,
onSendMessage: (ChatMessage) -> Unit,
onOpenPromptTemplatesClicked: () -> Unit = {},
onStopButtonClicked: () -> Unit = {},
showPromptTemplatesInMenu: Boolean = true,
showStopButtonWhenInProgress: Boolean = false,
) {
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
var showAddContentMenu by remember { mutableStateOf(false) }
var showTextInputHistorySheet by remember { mutableStateOf(false) }
Box(contentAlignment = Alignment.CenterStart) {
// A plus button to show a popup menu to add stuff to the chat.
IconButton(
enabled = !inProgress,
onClick = { showAddContentMenu = true },
modifier = Modifier
.offset(x = 16.dp)
.alpha(0.8f)
) {
Icon(
Icons.Rounded.Add,
contentDescription = "",
modifier = Modifier.size(28.dp),
)
}
Row(
modifier = Modifier
.fillMaxWidth()
.padding(12.dp)
.border(1.dp, MaterialTheme.colorScheme.outlineVariant, RoundedCornerShape(28.dp)),
verticalAlignment = Alignment.CenterVertically,
) {
DropdownMenu(
expanded = showAddContentMenu,
onDismissRequest = { showAddContentMenu = false }) {
if (showPromptTemplatesInMenu) {
DropdownMenuItem(text = {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp)
) {
Icon(Icons.Rounded.PostAdd, contentDescription = "")
Text("Prompt templates")
}
}, onClick = {
onOpenPromptTemplatesClicked()
showAddContentMenu = false
})
}
DropdownMenuItem(text = {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp)
) {
Icon(Icons.Rounded.History, contentDescription = "")
Text("Input history")
}
}, onClick = {
showAddContentMenu = false
showTextInputHistorySheet = true
})
}
// Text field.
TextField(value = curMessage,
minLines = 1,
maxLines = 3,
onValueChange = onValueChanged,
colors = TextFieldDefaults.colors(
unfocusedContainerColor = Color.Transparent,
focusedContainerColor = Color.Transparent,
focusedIndicatorColor = Color.Transparent,
unfocusedIndicatorColor = Color.Transparent,
disabledIndicatorColor = Color.Transparent,
disabledContainerColor = Color.Transparent,
),
textStyle = MaterialTheme.typography.bodyMedium,
modifier = Modifier
.weight(1f)
.padding(start = 36.dp),
placeholder = { Text(stringResource(textFieldPlaceHolderRes)) })
Spacer(modifier = Modifier.width(8.dp))
if (inProgress && showStopButtonWhenInProgress) {
IconButton(
onClick = onStopButtonClicked,
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.secondaryContainer,
),
) {
Icon(
Icons.Rounded.Stop,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary
)
}
} // Send button. Only shown when text is not empty.
else if (curMessage.isNotEmpty()) {
IconButton(
enabled = !inProgress,
onClick = {
onSendMessage(ChatMessageText(content = curMessage.trim(), side = ChatSide.USER))
},
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.secondaryContainer,
),
) {
Icon(
Icons.AutoMirrored.Rounded.Send,
contentDescription = "",
modifier = Modifier.offset(x = 2.dp),
tint = if (inProgress) MaterialTheme.colorScheme.surfaceContainerHigh else MaterialTheme.colorScheme.primary
)
}
}
Spacer(modifier = Modifier.width(4.dp))
}
}
// A bottom sheet to show the text input history to pick from.
if (showTextInputHistorySheet) {
TextInputHistorySheet(
history = modelManagerUiState.textInputHistory,
onDismissed = {
showTextInputHistorySheet = false
},
onHistoryItemClicked = { item ->
onSendMessage(ChatMessageText(content = item, side = ChatSide.USER))
modelManagerViewModel.promoteTextInputHistoryItem(item)
},
onHistoryItemDeleted = { item ->
modelManagerViewModel.deleteTextInputHistory(item)
},
onHistoryItemsDeleteAll = {
modelManagerViewModel.clearTextInputHistory()
}
)
}
}
@Preview(showBackground = true)
@Composable
fun MessageInputTextPreview() {
val context = LocalContext.current
GalleryTheme {
Column {
MessageInputText(
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "hello",
inProgress = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},
onSendMessage = {},
showStopButtonWhenInProgress = true,
)
MessageInputText(
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "hello",
inProgress = true,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},
onSendMessage = {},
)
MessageInputText(
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "",
inProgress = false,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},
onSendMessage = {},
)
MessageInputText(
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
curMessage = "",
inProgress = true,
textFieldPlaceHolderRes = R.string.chat_textinput_placeholder,
onValueChanged = {},
onSendMessage = {},
showStopButtonWhenInProgress = true,
)
}
}
}

View file

@ -0,0 +1,63 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.padding
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.ui.common.humanReadableDuration
import com.google.aiedge.gallery.ui.theme.GalleryTheme
/**
* Composable function to display the latency of a chat message, if available.
*/
@Composable
fun LatencyText(message: ChatMessage) {
if (message.latencyMs >= 0) {
Text(
message.latencyMs.humanReadableDuration(),
modifier = Modifier.alpha(0.5f),
style = MaterialTheme.typography.labelSmall,
)
}
}
@Preview(showBackground = true)
@Composable
fun LatencyTextPreview() {
GalleryTheme {
Column(modifier = Modifier.padding(16.dp), verticalArrangement = Arrangement.spacedBy(16.dp)) {
for (latencyMs in listOf(123f, 1234f, 123456f, 7234567f)) {
LatencyText(
message = ChatMessage(
latencyMs = latencyMs,
type = ChatMessageType.TEXT,
side = ChatSide.AGENT
)
)
}
}
}
}

View file

@ -0,0 +1,256 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import android.graphics.Bitmap
import androidx.annotation.StringRes
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.width
import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.bodySmallNarrow
import com.google.aiedge.gallery.ui.theme.bodySmallSemiBold
data class MessageLayoutConfig(
val horizontalArrangement: Arrangement.Horizontal,
val modifier: Modifier,
val userLabel: String,
val rightSideLabel: String
)
/**
* Composable function to display the sender information for a chat message.
*
* This function handles different types of chat messages, including system messages,
* benchmark results, and image generation results, and displays the appropriate sender label
* and status information.
*/
@Composable
fun MessageSender(
message: ChatMessage, @StringRes agentNameRes: Int, imageHistoryCurIndex: Int = 0
) {
// No user label for system messages.
if (message.side == ChatSide.SYSTEM) {
return
}
val (horizontalArrangement, modifier, userLabel, rightSideLabel) = getMessageLayoutConfig(
message = message, agentNameRes = agentNameRes, imageHistoryCurIndex = imageHistoryCurIndex
)
Row(
modifier = modifier,
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = horizontalArrangement,
) {
Row(verticalAlignment = Alignment.CenterVertically) {
// Sender label.
Text(
userLabel,
style = bodySmallSemiBold,
)
when (message) {
// Benchmark running status.
is ChatMessageBenchmarkResult -> {
if (message.isRunning()) {
Spacer(modifier = Modifier.width(8.dp))
CircularProgressIndicator(
modifier = Modifier.size(10.dp),
strokeWidth = 1.5.dp,
color = MaterialTheme.colorScheme.secondary
)
Spacer(modifier = Modifier.width(4.dp))
}
val statusLabel = if (message.isWarmingUp()) {
stringResource(R.string.warming_up)
} else if (message.isRunning()) {
stringResource(R.string.running)
} else ""
if (statusLabel.isNotEmpty()) {
Text(
statusLabel,
color = MaterialTheme.colorScheme.secondary,
style = bodySmallNarrow,
)
}
}
// Benchmark LLM running status.
is ChatMessageBenchmarkLlmResult -> {
if (message.running) {
Spacer(modifier = Modifier.width(8.dp))
CircularProgressIndicator(
modifier = Modifier.size(10.dp),
strokeWidth = 1.5.dp,
color = MaterialTheme.colorScheme.secondary
)
}
}
// Image generation running status.
is ChatMessageImageWithHistory -> {
if (message.isRunning()) {
Spacer(modifier = Modifier.width(8.dp))
CircularProgressIndicator(
modifier = Modifier.size(10.dp),
strokeWidth = 1.5.dp,
color = MaterialTheme.colorScheme.secondary
)
Spacer(modifier = Modifier.width(4.dp))
Text(
stringResource(R.string.running),
color = MaterialTheme.colorScheme.secondary,
style = bodySmallNarrow,
)
}
}
}
}
// Right-side text.
when (message) {
is ChatMessageBenchmarkResult,
is ChatMessageImageWithHistory,
is ChatMessageBenchmarkLlmResult,
-> {
Text(rightSideLabel, style = MaterialTheme.typography.bodySmall)
}
}
}
}
@Composable
private fun getMessageLayoutConfig(
message: ChatMessage,
@StringRes agentNameRes: Int,
imageHistoryCurIndex: Int,
): MessageLayoutConfig {
var userLabel = stringResource(R.string.chat_you)
var rightSideLabel = ""
var horizontalArrangement = Arrangement.End
var modifier = Modifier.padding(bottom = 2.dp)
if (message.side == ChatSide.AGENT) {
userLabel = stringResource(agentNameRes)
}
when (message) {
is ChatMessageBenchmarkResult -> {
horizontalArrangement = Arrangement.SpaceBetween
modifier = modifier.fillMaxWidth()
userLabel = "Benchmark"
rightSideLabel = if (message.isWarmingUp()) {
"${message.warmupCurrent}/${message.warmupTotal}"
} else {
"${message.iterationCurrent}/${message.iterationTotal}"
}
}
is ChatMessageBenchmarkLlmResult -> {
horizontalArrangement = Arrangement.SpaceBetween
modifier = modifier.fillMaxWidth()
userLabel = "Benchmark"
}
is ChatMessageImageWithHistory -> {
horizontalArrangement = Arrangement.SpaceBetween
if (message.bitmaps.isNotEmpty()) {
modifier = modifier.width(200.dp)
}
rightSideLabel = "${imageHistoryCurIndex + 1}/${message.totalIterations}"
}
}
return MessageLayoutConfig(
horizontalArrangement = horizontalArrangement,
modifier = modifier,
userLabel = userLabel,
rightSideLabel = rightSideLabel
)
}
@Preview(showBackground = true)
@Composable
fun MessageSenderPreview() {
GalleryTheme {
Column(modifier = Modifier.padding(16.dp), verticalArrangement = Arrangement.spacedBy(16.dp)) {
// Agent message.
MessageSender(
message = ChatMessageText(content = "hello world", side = ChatSide.AGENT),
agentNameRes = R.string.chat_generic_agent_name
)
// User message.
MessageSender(
message = ChatMessageText(content = "hello world", side = ChatSide.USER),
agentNameRes = R.string.chat_generic_agent_name
)
// Benchmark during warmup.
MessageSender(
message = ChatMessageBenchmarkResult(
orderedStats = listOf(),
statValues = mutableMapOf(),
values = listOf(),
histogram = Histogram(listOf(), 0),
warmupCurrent = 10,
warmupTotal = 50,
iterationCurrent = 0,
iterationTotal = 200
), agentNameRes = R.string.chat_generic_agent_name
)
// Benchmark during running.
MessageSender(
message = ChatMessageBenchmarkResult(
orderedStats = listOf(),
statValues = mutableMapOf(),
values = listOf(),
histogram = Histogram(listOf(), 0),
warmupCurrent = 50,
warmupTotal = 50,
iterationCurrent = 123,
iterationTotal = 200
), agentNameRes = R.string.chat_generic_agent_name
)
// Image generation during running.
MessageSender(
message = ChatMessageImageWithHistory(
bitmaps = listOf(Bitmap.createBitmap(100, 100, Bitmap.Config.ARGB_8888)),
imageBitMaps = listOf(),
totalIterations = 10,
ChatSide.AGENT
),
agentNameRes = R.string.chat_generic_agent_name,
imageHistoryCurIndex = 4,
)
}
}
}

View file

@ -0,0 +1,176 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.animation.core.Animatable
import androidx.compose.animation.core.Easing
import androidx.compose.animation.core.tween
import androidx.compose.foundation.Image
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.width
import androidx.compose.foundation.lazy.grid.GridCells
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
import androidx.compose.foundation.lazy.grid.itemsIndexed
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.remember
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.ColorFilter
import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.ui.common.getTaskIconColor
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import kotlinx.coroutines.delay
import kotlin.math.cos
import kotlin.math.pow
private val GRID_SIZE = 240.dp
private val GRID_SPACING = 0.dp
private const val PAUSE_DURATION = 200
private const val ANIMATION_DURATION = 500
private const val START_SCALE = 0.9f
private const val END_SCALE = 0.6f
/**
* Composable function to display a loading animation using a 2x2 grid of images with a synchronized
* scaling and rotation effect.
*/
@Composable
fun ModelDownloadingAnimation() {
val scale = remember { Animatable(END_SCALE) }
LaunchedEffect(Unit) { // Run this once
while (true) {
// Phase 1: Scale up
scale.animateTo(
targetValue = START_SCALE,
animationSpec = tween(
durationMillis = ANIMATION_DURATION,
easing = multiBounceEasing(bounces = 3, decay = 0.02f)
)
)
delay(PAUSE_DURATION.toLong())
// Phase 2: Scale down
scale.animateTo(
targetValue = END_SCALE,
animationSpec = tween(
durationMillis = ANIMATION_DURATION,
easing = multiBounceEasing(bounces = 3, decay = 0.02f)
)
)
delay(PAUSE_DURATION.toLong())
}
}
Column(
horizontalAlignment = Alignment.CenterHorizontally,
modifier = Modifier.offset(y = -GRID_SIZE / 8)
) {
LazyVerticalGrid(
columns = GridCells.Fixed(2),
horizontalArrangement = Arrangement.spacedBy(GRID_SPACING),
verticalArrangement = Arrangement.spacedBy(GRID_SPACING),
modifier = Modifier
.width(GRID_SIZE)
.height(GRID_SIZE)
) {
itemsIndexed(
listOf(
R.drawable.pantegon,
R.drawable.double_circle,
R.drawable.circle,
R.drawable.four_circle
)
) { index, imageResource ->
val currentScale =
if (index == 0 || index == 3) scale.value else START_SCALE + END_SCALE - scale.value
Box(
modifier = Modifier
.width((GRID_SIZE - GRID_SPACING) / 2)
.height((GRID_SIZE - GRID_SPACING) / 2),
contentAlignment = when (index) {
0 -> Alignment.BottomEnd
1 -> Alignment.BottomStart
2 -> Alignment.TopEnd
3 -> Alignment.TopStart
else -> Alignment.Center
}
) {
Image(
painter = painterResource(id = imageResource),
contentDescription = "",
contentScale = ContentScale.Fit,
colorFilter = ColorFilter.tint(getTaskIconColor(index = index)),
modifier = Modifier
.graphicsLayer {
scaleX = currentScale
scaleY = currentScale
rotationZ = currentScale * 120
alpha = 0.8f
}
.size(70.dp)
)
}
}
}
Text(
"Feel free to switch apps or lock your device.\n"
+ "The download will continue in the background.\n"
+ "We'll send a notification when it's done.",
style = MaterialTheme.typography.bodyMedium,
textAlign = TextAlign.Center
)
}
}
// Custom Easing function for a multi-bounce effect
fun multiBounceEasing(bounces: Int, decay: Float): Easing = Easing { x ->
if (x == 1f) {
1f
} else {
-decay.pow(x) * cos((x * (bounces + 0.9f) * Math.PI / 1.3f)).toFloat() + 1f
}
}
@Preview(showBackground = true)
@Composable
fun ModelDownloadingAnimationPreview() {
GalleryTheme {
Row(modifier = Modifier.padding(16.dp)) {
ModelDownloadingAnimation()
}
}
}

View file

@ -0,0 +1,89 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.width
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.ui.theme.GalleryTheme
/**
* Composable function to display a visual indicator for model initialization status.
*
* This function renders a row containing a circular progress indicator and a message
* indicating that the model is currently initializing. It provides a visual cue to the
* user that the model is in a loading state.
*/
@Composable
fun ModelInitializationStatusChip() {
Row(modifier = Modifier.fillMaxWidth(), horizontalArrangement = Arrangement.Center) {
Box(
modifier = Modifier
.padding(8.dp)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.secondaryContainer)
) {
Row(
modifier = Modifier.padding(top = 4.dp, bottom = 4.dp, start = 8.dp, end = 8.dp),
horizontalArrangement = Arrangement.Center,
verticalAlignment = Alignment.CenterVertically
) {
// Circular progress indicator.
CircularProgressIndicator(
modifier = Modifier.size(14.dp),
strokeWidth = 2.dp,
color = MaterialTheme.colorScheme.onSecondaryContainer
)
Spacer(modifier = Modifier.width(8.dp))
// Text message.
Text(
stringResource(R.string.model_is_initializing_msg),
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSecondaryContainer,
)
}
}
}
}
@Preview(showBackground = true)
@Composable
fun ModelInitializationStatusPreview() {
GalleryTheme {
ModelInitializationStatusChip()
}
}

View file

@ -0,0 +1,54 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.material3.Button
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.tooling.preview.Preview
import com.google.aiedge.gallery.ui.theme.GalleryTheme
/**
* Composable function to display a button to download model if the model has not been downloaded.
*/
@Composable
fun ModelNotDownloaded(modifier: Modifier = Modifier, onClicked: () -> Unit) {
Column(
modifier = modifier.fillMaxSize(),
verticalArrangement = Arrangement.Center,
horizontalAlignment = Alignment.CenterHorizontally
) {
Button(
onClick = onClicked,
) {
Text("Download & Try it", maxLines = 1)
}
}
}
@Preview(showBackground = true)
@Composable
fun Preview() {
GalleryTheme {
ModelNotDownloaded(onClicked = {})
}
}

View file

@ -0,0 +1,170 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.convertValueToTargetType
import com.google.aiedge.gallery.ui.common.modelitem.ModelItem
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
import com.google.aiedge.gallery.ui.preview.TASK_TEST2
import com.google.aiedge.gallery.ui.theme.GalleryTheme
/**
* Composable function to display a selectable model item with an option to configure its settings.
*/
@Composable
fun ModelSelector(
model: Model,
task: Task,
modelManagerViewModel: ModelManagerViewModel,
modifier: Modifier = Modifier,
contentAlpha: Float = 1f,
onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit = { _, _ -> },
) {
var showConfigDialog by remember { mutableStateOf(false) }
val context = LocalContext.current
Column(
modifier = modifier
) {
Box(
modifier = Modifier
.fillMaxWidth().padding(bottom = 8.dp),
contentAlignment = Alignment.Center
) {
// Model row.
Row(
modifier = Modifier
.fillMaxWidth()
.graphicsLayer { alpha = contentAlpha },
verticalAlignment = Alignment.CenterVertically
) {
ModelItem(
model = model,
task = task,
modelManagerViewModel = modelManagerViewModel,
onModelClicked = {},
onConfigClicked = {
showConfigDialog = true
},
verticalSpacing = 10.dp,
modifier = Modifier
.weight(1f)
.padding(horizontal = 16.dp),
showDeleteButton = false,
showConfigButtonIfExisted = true,
canExpand = false,
)
}
}
}
// Config dialog.
if (showConfigDialog) {
ConfigDialog(
title = "Model configs",
configs = model.configs,
initialValues = model.configValues,
onDismissed = { showConfigDialog = false },
onOk = { curConfigValues ->
// Hide config dialog.
showConfigDialog = false
// Check if the configs are changed or not. Also check if the model needs to be
// re-initialized.
var same = true
var needReinitialization = false
for (config in model.configs) {
val key = config.key.label
val oldValue = convertValueToTargetType(
value = model.configValues.getValue(key), valueType = config.valueType
)
val newValue = convertValueToTargetType(
value = curConfigValues.getValue(key), valueType = config.valueType
)
if (oldValue != newValue) {
same = false
if (config.needReinitialization) {
needReinitialization = true
}
break
}
}
if (same) {
return@ConfigDialog
}
// Save the config values to Model.
val oldConfigValues = model.configValues
model.configValues = curConfigValues
// Force to re-initialize the model with the new configs.
if (needReinitialization) {
modelManagerViewModel.initializeModel(context = context, model = model, force = true)
}
// Notify.
onConfigChanged(oldConfigValues, model.configValues)
},
)
}
}
@Preview(showBackground = true)
@Composable
fun ModelSelectorPreview(
) {
GalleryTheme {
Column(verticalArrangement = Arrangement.spacedBy(16.dp)) {
ModelSelector(
model = TASK_TEST1.models[0],
task = TASK_TEST1,
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
)
ModelSelector(
model = TASK_TEST1.models[1],
task = TASK_TEST1,
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
)
ModelSelector(
model = TASK_TEST2.models[1],
task = TASK_TEST2,
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
)
}
}
}

View file

@ -0,0 +1,211 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.wrapContentHeight
import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.Delete
import androidx.compose.material.icons.rounded.DeleteSweep
import androidx.compose.material3.AlertDialog
import androidx.compose.material3.Button
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ModalBottomSheet
import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.material3.rememberModalBottomSheetState
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.customColors
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun TextInputHistorySheet(
history: List<String>,
onHistoryItemClicked: (String) -> Unit,
onHistoryItemDeleted: (String) -> Unit,
onHistoryItemsDeleteAll: () -> Unit,
onDismissed: () -> Unit
) {
val sheetState = rememberModalBottomSheetState()
val scope = rememberCoroutineScope()
ModalBottomSheet(
onDismissRequest = onDismissed,
sheetState = sheetState,
modifier = Modifier.wrapContentHeight(),
) {
SheetContent(
history = history,
onHistoryItemClicked = { item ->
scope.launch {
sheetState.hide()
delay(100)
onHistoryItemClicked(item)
onDismissed()
}
},
onHistoryItemDeleted = onHistoryItemDeleted,
onHistoryItemsDeleteAll = {
scope.launch {
sheetState.hide()
onDismissed()
onHistoryItemsDeleteAll()
}
},
onDismissed = {
scope.launch {
sheetState.hide()
onDismissed()
}
}
)
}
}
@Composable
private fun SheetContent(
history: List<String>,
onHistoryItemClicked: (String) -> Unit,
onHistoryItemDeleted: (String) -> Unit,
onHistoryItemsDeleteAll: () -> Unit,
onDismissed: () -> Unit
) {
val scope = rememberCoroutineScope()
var showConfirmDeleteDialog by remember { mutableStateOf(false) }
Column {
Box(contentAlignment = Alignment.CenterEnd) {
Text(
"Text input history",
style = MaterialTheme.typography.titleLarge,
modifier = Modifier
.fillMaxWidth()
.padding(8.dp),
textAlign = TextAlign.Center
)
IconButton(modifier = Modifier.padding(end = 12.dp), onClick = {
showConfirmDeleteDialog = true
}) {
Icon(Icons.Rounded.DeleteSweep, contentDescription = "")
}
}
LazyColumn(modifier = Modifier.weight(1f)) {
items(history, key = { it }) { item ->
Row(
modifier = Modifier
.fillMaxWidth()
.padding(horizontal = 8.dp, vertical = 2.dp)
.clip(RoundedCornerShape(24.dp))
.background(MaterialTheme.customColors.agentBubbleBgColor)
.clickable {
onHistoryItemClicked(item)
},
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(8.dp)
) {
Text(
item,
style = MaterialTheme.typography.bodyMedium,
modifier = Modifier
.padding(vertical = 16.dp)
.padding(start = 16.dp)
.weight(1f)
)
IconButton(modifier = Modifier.padding(end = 8.dp), onClick = {
scope.launch {
delay(400)
onHistoryItemDeleted(item)
}
}) {
Icon(Icons.Rounded.Delete, contentDescription = "")
}
}
}
}
}
if (showConfirmDeleteDialog) {
AlertDialog(onDismissRequest = { showConfirmDeleteDialog = false },
title = { Text("Clear history?") },
text = {
Text(
"Are you sure you want to clear the history? This action cannot be undone."
)
},
confirmButton = {
Button(onClick = {
showConfirmDeleteDialog = false
onHistoryItemsDeleteAll()
}) {
Text(stringResource(R.string.ok))
}
},
dismissButton = {
TextButton(onClick = { showConfirmDeleteDialog = false }) {
Text(stringResource(R.string.cancel))
}
})
}
}
@Preview(showBackground = true)
@Composable
fun TextInputHistorySheetContentPreview() {
GalleryTheme {
SheetContent(
history = listOf(
"Analyze the sentiment of the following Tweets and classify them as POSITIVE, NEGATIVE, or NEUTRAL. \"It's so beautiful today!\"",
"I have the ingredients above. Not sure what to cook for lunch. Show me a list of foods with the recipes.",
"You are Santa Claus, write a letter back for this kid.",
"Generate a list of cookie recipes. Make the outputs in JSON format."
),
onHistoryItemClicked = {},
onHistoryItemDeleted = {},
onHistoryItemsDeleteAll = {},
onDismissed = {},
)
}
}

View file

@ -0,0 +1,73 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.modelitem
import androidx.compose.animation.core.DeferredTargetAnimation
import androidx.compose.animation.core.ExperimentalAnimatableApi
import androidx.compose.animation.core.VectorConverter
import androidx.compose.animation.core.tween
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.ui.Modifier
import androidx.compose.ui.composed
import androidx.compose.ui.geometry.Offset
import androidx.compose.ui.layout.LookaheadScope
import androidx.compose.ui.layout.approachLayout
import androidx.compose.ui.unit.Constraints
import androidx.compose.ui.unit.IntOffset
import androidx.compose.ui.unit.IntSize
import androidx.compose.ui.unit.round
const val LAYOUT_ANIMATION_DURATION = 250
context(LookaheadScope)
@OptIn(ExperimentalAnimatableApi::class)
fun Modifier.animateLayout(): Modifier = composed {
val sizeAnim = remember { DeferredTargetAnimation(IntSize.VectorConverter) }
val offsetAnim = remember { DeferredTargetAnimation(IntOffset.VectorConverter) }
val scope = rememberCoroutineScope()
this.approachLayout(
isMeasurementApproachInProgress = { lookaheadSize ->
sizeAnim.updateTarget(lookaheadSize, scope, tween(LAYOUT_ANIMATION_DURATION))
!sizeAnim.isIdle
},
isPlacementApproachInProgress = { lookaheadCoordinates ->
val target = lookaheadScopeCoordinates.localLookaheadPositionOf(lookaheadCoordinates)
offsetAnim.updateTarget(target.round(), scope, tween(LAYOUT_ANIMATION_DURATION))
!offsetAnim.isIdle
}
) { measurable, _ ->
val (animWidth, animHeight) = sizeAnim.updateTarget(
lookaheadSize,
scope,
tween(LAYOUT_ANIMATION_DURATION)
)
measurable.measure(Constraints.fixed(animWidth, animHeight))
.run {
layout(width, height) {
coordinates?.let {
val target = lookaheadScopeCoordinates.localLookaheadPositionOf(it).round()
val animOffset = offsetAnim.updateTarget(target, scope, tween(LAYOUT_ANIMATION_DURATION))
val current = lookaheadScopeCoordinates.localPositionOf(it, Offset.Zero).round()
val (x, y) = animOffset - current
place(x, y)
} ?: place(0, 0)
}
}
}
}

View file

@ -0,0 +1,52 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.modelitem
import androidx.compose.material3.AlertDialog
import androidx.compose.material3.Button
import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.runtime.Composable
import androidx.compose.ui.res.stringResource
import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.data.Model
/**
* Composable function to display a confirmation dialog for deleting a model.
*/
@Composable
fun ConfirmDeleteModelDialog(model: Model, onConfirm: () -> Unit, onDismiss: () -> Unit) {
AlertDialog(onDismissRequest = onDismiss,
title = { Text(stringResource(R.string.confirm_delete_model_dialog_title)) },
text = {
Text(
stringResource(R.string.confirm_delete_model_dialog_content).format(
model.name
)
)
},
confirmButton = {
Button(onClick = onConfirm) {
Text(stringResource(R.string.ok))
}
},
dismissButton = {
TextButton(onClick = onDismiss) {
Text(stringResource(R.string.cancel))
}
})
}

View file

@ -0,0 +1,405 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.modelitem
import android.content.Intent
import android.net.Uri
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.animation.core.animateFloatAsState
import androidx.compose.animation.core.tween
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
import androidx.compose.foundation.interaction.MutableInteractionSource
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.heightIn
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.Settings
import androidx.compose.material.icons.rounded.UnfoldLess
import androidx.compose.material.icons.rounded.UnfoldMore
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.OutlinedButton
import androidx.compose.material3.Text
import androidx.compose.material3.ripple
import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.getValue
import androidx.compose.runtime.movableContentOf
import androidx.compose.runtime.movableContentWithReceiverOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip
import androidx.compose.ui.layout.LookaheadScope
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.Dp
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.DownloadAndTryButton
import com.google.aiedge.gallery.ui.common.TaskIcon
import com.google.aiedge.gallery.ui.common.chat.MarkdownText
import com.google.aiedge.gallery.ui.common.checkNotificationPermissonAndStartDownload
import com.google.aiedge.gallery.ui.common.getTaskBgColor
import com.google.aiedge.gallery.ui.common.getTaskIconColor
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.MODEL_TEST1
import com.google.aiedge.gallery.ui.preview.MODEL_TEST2
import com.google.aiedge.gallery.ui.preview.MODEL_TEST3
import com.google.aiedge.gallery.ui.preview.MODEL_TEST4
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
import com.google.aiedge.gallery.ui.preview.TASK_TEST2
import com.google.aiedge.gallery.ui.theme.GalleryTheme
private val DEFAULT_VERTICAL_PADDING = 16.dp
/**
* Composable function to display a model item in the model manager list.
*
* This function renders a card representing a model, displaying its task icon, name,
* download status, and providing action buttons. It supports expanding to show a
* model description and buttons for learning more (opening a URL) and downloading/trying
* the model.
*/
@Composable
fun ModelItem(
model: Model,
task: Task,
modelManagerViewModel: ModelManagerViewModel,
onModelClicked: (Model) -> Unit,
modifier: Modifier = Modifier,
onConfigClicked: () -> Unit = {},
verticalSpacing: Dp = DEFAULT_VERTICAL_PADDING,
showDeleteButton: Boolean = true,
showConfigButtonIfExisted: Boolean = false,
canExpand: Boolean = true,
) {
val context = LocalContext.current
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val downloadStatus by remember {
derivedStateOf { modelManagerUiState.modelDownloadStatus[model.name] }
}
val launcher = rememberLauncherForActivityResult(
ActivityResultContracts.RequestPermission()
) {
modelManagerViewModel.downloadModel(model)
}
var isExpanded by remember { mutableStateOf(false) }
// Animate alpha for model description and button rows when switching between layouts.
val alphaAnimation by animateFloatAsState(
targetValue = if (isExpanded) 1f else 0f,
animationSpec = tween(durationMillis = LAYOUT_ANIMATION_DURATION - 50)
)
LookaheadScope {
// Task icon.
val taskIcon = remember {
movableContentOf {
TaskIcon(
task = task, modifier = Modifier.animateLayout()
)
}
}
// Model name and status.
val modelNameAndStatus = remember {
movableContentOf {
ModelNameAndStatus(
model = model,
task = task,
downloadStatus = downloadStatus,
isExpanded = isExpanded,
modifier = Modifier.animateLayout()
)
}
}
val actionButton = remember {
movableContentOf {
ModelItemActionButton(
context = context,
model = model,
task = task,
modelManagerViewModel = modelManagerViewModel,
downloadStatus = downloadStatus,
onDownloadClicked = { model ->
checkNotificationPermissonAndStartDownload(
context = context,
launcher = launcher,
modelManagerViewModel = modelManagerViewModel,
model = model
)
},
showDeleteButton = showDeleteButton,
showDownloadButton = false,
)
}
}
// Expand/collapse icon, or the config icon.
val expandButton = remember {
movableContentOf {
if (showConfigButtonIfExisted) {
if (downloadStatus?.status === ModelDownloadStatusType.SUCCEEDED) {
if (model.configs.isNotEmpty()) {
IconButton(onClick = onConfigClicked) {
Icon(
Icons.Rounded.Settings,
contentDescription = "",
tint = getTaskIconColor(task)
)
}
}
}
} else {
Icon(
if (isExpanded) Icons.Rounded.UnfoldLess else Icons.Rounded.UnfoldMore,
contentDescription = "",
tint = getTaskIconColor(task),
)
}
}
}
// Model description shown in expanded layout.
val modelDescription = remember {
movableContentOf { m: Modifier ->
if (model.info.isNotEmpty()) {
MarkdownText(
model.info,
modifier = Modifier
.heightIn(min = 0.dp, max = if (isExpanded) 1000.dp else 0.dp)
.animateLayout()
.then(m)
)
}
}
}
// Button rows shown in expanded layout.
val buttonRows = remember {
movableContentOf { m: Modifier ->
Row(
modifier = Modifier
.heightIn(min = 0.dp, max = if (isExpanded) 1000.dp else 0.dp)
.animateLayout()
.then(m),
horizontalArrangement = Arrangement.spacedBy(12.dp),
) {
// The "learn more" button. Click to show url in default browser.
if (model.learnMoreUrl.isNotEmpty()) {
OutlinedButton(
onClick = {
if (isExpanded) {
val intent = Intent(Intent.ACTION_VIEW, Uri.parse(model.learnMoreUrl))
context.startActivity(intent)
}
},
) {
Text("Learn More", maxLines = 1)
}
}
// Button to start the download and start the chat session with the model.
val needToDownloadFirst =
downloadStatus?.status == ModelDownloadStatusType.NOT_DOWNLOADED || downloadStatus?.status == ModelDownloadStatusType.FAILED
DownloadAndTryButton(
model = model,
enabled = isExpanded,
needToDownloadFirst = needToDownloadFirst,
modelManagerViewModel = modelManagerViewModel,
onClicked = { onModelClicked(model) }
)
// Button(
// onClick = {
// if (isExpanded) {
// onModelClicked(model)
// if (needToDownloadFirst) {
// scope.launch {
// delay(80)
// checkNotificationPermissonAndStartDownload(
// context = context,
// launcher = launcher,
// modelManagerViewModel = modelManagerViewModel,
// model = model
// )
// }
// }
// }
// },
// ) {
// Icon(
// Icons.AutoMirrored.Rounded.ArrowForward,
// contentDescription = "",
// modifier = Modifier.padding(end = 4.dp)
// )
// if (needToDownloadFirst) {
// Text("Download & Try it", maxLines = 1)
// } else {
// Text("Try it", maxLines = 1)
// }
// }
}
}
}
val container = remember {
movableContentWithReceiverOf<LookaheadScope, @Composable () -> Unit> { content ->
Box(
modifier = Modifier.animateLayout(),
contentAlignment = Alignment.TopEnd,
) {
content()
}
}
}
var boxModifier = modifier
.fillMaxWidth()
.clip(RoundedCornerShape(size = 42.dp))
.background(
getTaskBgColor(task)
)
boxModifier = if (canExpand) {
boxModifier.clickable(
onClick = { isExpanded = !isExpanded },
interactionSource = remember { MutableInteractionSource() },
indication = ripple(
bounded = true,
radius = 500.dp,
)
)
} else {
boxModifier
}
Box(
modifier = boxModifier,
contentAlignment = Alignment.Center
) {
if (isExpanded) {
container {
// The main part (icon, model name, status, etc)
Column(
verticalArrangement = Arrangement.spacedBy(14.dp),
horizontalAlignment = Alignment.CenterHorizontally,
modifier = Modifier
.fillMaxWidth()
.padding(vertical = verticalSpacing, horizontal = 18.dp)
) {
Box(contentAlignment = Alignment.Center) {
taskIcon()
Row(
modifier = Modifier.fillMaxWidth(),
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.End
) {
actionButton()
expandButton()
}
}
modelNameAndStatus()
modelDescription(Modifier.alpha(alphaAnimation))
buttonRows(Modifier.alpha(alphaAnimation)) // Apply alpha here
}
}
} else {
container {
Column(horizontalAlignment = Alignment.CenterHorizontally) {
// The main part (icon, model name, status, etc)
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(12.dp),
modifier = Modifier
.fillMaxWidth()
.padding(start = 18.dp, end = 18.dp)
.padding(vertical = verticalSpacing)
) {
taskIcon()
Row(modifier = Modifier.weight(1f)) {
modelNameAndStatus()
}
Row(verticalAlignment = Alignment.CenterVertically) {
actionButton()
expandButton()
}
}
Column(
modifier = Modifier.offset(y = 30.dp),
horizontalAlignment = Alignment.CenterHorizontally
) {
modelDescription(Modifier.alpha(alphaAnimation))
buttonRows(Modifier.alpha(alphaAnimation))
}
}
}
}
}
}
}
@Preview(showBackground = true)
@Composable
fun PreviewModelItem() {
GalleryTheme {
Column(
verticalArrangement = Arrangement.spacedBy(16.dp), modifier = Modifier.padding(16.dp)
) {
ModelItem(
model = MODEL_TEST1,
task = TASK_TEST1,
onModelClicked = { },
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
)
ModelItem(
model = MODEL_TEST2,
task = TASK_TEST1,
onModelClicked = { },
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
)
ModelItem(
model = MODEL_TEST3,
task = TASK_TEST2,
onModelClicked = { },
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
)
ModelItem(
model = MODEL_TEST4,
task = TASK_TEST2,
onModelClicked = { },
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
)
}
}
}

View file

@ -0,0 +1,133 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.modelitem
import android.content.Context
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.Cancel
import androidx.compose.material.icons.rounded.Delete
import androidx.compose.material.icons.rounded.FileDownload
import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatus
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.getTaskIconColor
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
/**
* Composable function to display action buttons for a model item, based on its download status.
*
* This function renders the appropriate action button (download, delete, cancel) based on the
* provided ModelDownloadStatus. It also handles notification permission requests for downloading
* and displays a confirmation dialog for deleting models.
*/
@Composable
fun ModelItemActionButton(
context: Context,
model: Model,
task: Task,
modelManagerViewModel: ModelManagerViewModel,
downloadStatus: ModelDownloadStatus?,
onDownloadClicked: (Model) -> Unit,
modifier: Modifier = Modifier,
showDeleteButton: Boolean = true,
showDownloadButton: Boolean = true,
) {
var showConfirmDeleteDialog by remember { mutableStateOf(false) }
Row(verticalAlignment = Alignment.CenterVertically, modifier = modifier) {
when (downloadStatus?.status) {
// Button to start the download.
ModelDownloadStatusType.NOT_DOWNLOADED, ModelDownloadStatusType.FAILED ->
if (showDownloadButton) {
IconButton(onClick = {
onDownloadClicked(model)
}) {
Icon(
Icons.Rounded.FileDownload,
contentDescription = "",
tint = getTaskIconColor(task),
)
}
}
// Button to delete the download.
ModelDownloadStatusType.SUCCEEDED -> {
if (showDeleteButton) {
IconButton(onClick = {
showConfirmDeleteDialog = true
}) {
Icon(
Icons.Rounded.Delete,
contentDescription = "",
tint = getTaskIconColor(task),
)
}
}
}
// Show spinner when the model is partially downloaded because it might some time for
// background task to be started by Android.
ModelDownloadStatusType.PARTIALLY_DOWNLOADED -> {
CircularProgressIndicator(
modifier = Modifier
.padding(end = 12.dp)
.size(24.dp)
)
}
// Button to cancel the download when it is in progress.
ModelDownloadStatusType.IN_PROGRESS, ModelDownloadStatusType.UNZIPPING -> IconButton(onClick = {
modelManagerViewModel.cancelDownloadModel(
model
)
}) {
Icon(
Icons.Rounded.Cancel,
contentDescription = "",
tint = getTaskIconColor(task),
)
}
else -> {}
}
}
if (showConfirmDeleteDialog) {
ConfirmDeleteModelDialog(model = model, onConfirm = {
modelManagerViewModel.deleteModel(model)
showConfirmDeleteDialog = false
}, onDismiss = {
showConfirmDeleteDialog = false
})
}
}

View file

@ -0,0 +1,187 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.modelitem
import androidx.compose.animation.core.Animatable
import androidx.compose.animation.core.tween
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
import androidx.compose.material3.LinearProgressIndicator
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.remember
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatus
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.formatToHourMinSecond
import com.google.aiedge.gallery.ui.common.getTaskIconColor
import com.google.aiedge.gallery.ui.common.humanReadableSize
import com.google.aiedge.gallery.ui.theme.labelSmallNarrow
/**
* Composable function to display the model name and its download status information.
*
* This function renders the model's name and its current download status, including:
* - Model name.
* - Failure message (if download failed).
* - Download progress (received size, total size, download rate, remaining time) for
* in-progress downloads.
* - "Unzipping..." status for unzipping processes.
* - Model size for successful downloads.
*/
@Composable
fun ModelNameAndStatus(
model: Model,
task: Task,
downloadStatus: ModelDownloadStatus?,
isExpanded: Boolean,
modifier: Modifier = Modifier
) {
val inProgress = downloadStatus?.status == ModelDownloadStatusType.IN_PROGRESS
val isPartiallyDownloaded = downloadStatus?.status == ModelDownloadStatusType.PARTIALLY_DOWNLOADED
var curDownloadProgress = 0f
Column(
horizontalAlignment = if (isExpanded) Alignment.CenterHorizontally else Alignment.Start
) {
// Model name.
Row(
verticalAlignment = Alignment.CenterVertically,
) {
Text(
model.name,
style = MaterialTheme.typography.titleMedium,
modifier = modifier,
)
}
Row(verticalAlignment = Alignment.CenterVertically) {
// Status icon.
if (!inProgress && !isPartiallyDownloaded) {
StatusIcon(
downloadStatus = downloadStatus,
modifier = modifier.padding(end = 4.dp)
)
}
// Failure message.
if (downloadStatus != null && downloadStatus.status == ModelDownloadStatusType.FAILED) {
Row(verticalAlignment = Alignment.CenterVertically) {
Text(
downloadStatus.errorMessage,
color = MaterialTheme.colorScheme.error,
style = labelSmallNarrow,
overflow = TextOverflow.Ellipsis,
modifier = modifier,
)
}
}
// Status label
else {
var sizeLabel = model.totalBytes.humanReadableSize()
var fontSize = 11.sp
// Populate the status label.
if (downloadStatus != null) {
// For in-progress model, show {receivedSize} / {totalSize} - {rate} - {remainingTime}
if (inProgress || isPartiallyDownloaded) {
var totalSize = downloadStatus.totalBytes
if (totalSize == 0L) {
totalSize = model.totalBytes
}
sizeLabel =
"${downloadStatus.receivedBytes.humanReadableSize(extraDecimalForGbAndAbove = true)} of ${totalSize.humanReadableSize()}"
if (downloadStatus.bytesPerSecond > 0) {
sizeLabel =
"$sizeLabel · ${downloadStatus.bytesPerSecond.humanReadableSize()} / s"
if (downloadStatus.remainingMs >= 0) {
sizeLabel =
"$sizeLabel\n${downloadStatus.remainingMs.formatToHourMinSecond()} left"
}
}
if (isPartiallyDownloaded) {
sizeLabel = "$sizeLabel (resuming...)"
}
curDownloadProgress =
downloadStatus.receivedBytes.toFloat() / downloadStatus.totalBytes.toFloat()
if (curDownloadProgress.isNaN()) {
curDownloadProgress = 0f
}
fontSize = 9.sp
}
// Status for unzipping.
else if (downloadStatus.status == ModelDownloadStatusType.UNZIPPING) {
sizeLabel = "Unzipping..."
}
}
Column(
horizontalAlignment = if (isExpanded) Alignment.CenterHorizontally else Alignment.Start,
) {
for ((index, line) in sizeLabel.split("\n").withIndex()) {
Text(
line,
color = MaterialTheme.colorScheme.secondary,
maxLines = 1,
style = labelSmallNarrow.copy(fontSize = fontSize, lineHeight = 10.sp),
textAlign = if (isExpanded) TextAlign.Center else TextAlign.Start,
overflow = TextOverflow.Visible,
modifier = modifier.offset(y = if (index == 0) 0.dp else (-1).dp)
)
}
}
}
}
// Download progress bar.
if (inProgress || isPartiallyDownloaded) {
val animatedProgress = remember { Animatable(0f) }
LinearProgressIndicator(
progress = { animatedProgress.value },
color = getTaskIconColor(task = task),
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
modifier = modifier.padding(top = 2.dp)
)
LaunchedEffect(curDownloadProgress) {
animatedProgress.animateTo(curDownloadProgress, animationSpec = tween(150))
}
}
// Unzipping progress.
else if (downloadStatus?.status == ModelDownloadStatusType.UNZIPPING) {
LinearProgressIndicator(
color = getTaskIconColor(task = task),
trackColor = MaterialTheme.colorScheme.surfaceContainerHighest,
modifier = Modifier
.padding(top = 2.dp),
)
}
}
}

View file

@ -0,0 +1,94 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common.modelitem
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.size
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.outlined.HelpOutline
import androidx.compose.material.icons.filled.DownloadForOffline
import androidx.compose.material.icons.rounded.Error
import androidx.compose.material3.Icon
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.ModelDownloadStatus
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.ui.theme.GalleryTheme
/**
* Composable function to display an icon representing the download status of a model.
*/
@Composable
fun StatusIcon(downloadStatus: ModelDownloadStatus?, modifier: Modifier = Modifier) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.Center,
modifier = modifier
) {
when (downloadStatus?.status) {
ModelDownloadStatusType.NOT_DOWNLOADED -> Icon(
Icons.AutoMirrored.Outlined.HelpOutline,
tint = Color(0xFFCCCCCC),
contentDescription = "",
modifier = Modifier.size(14.dp)
)
ModelDownloadStatusType.SUCCEEDED -> {
Icon(
Icons.Filled.DownloadForOffline,
tint = Color(0xff3d860b),
contentDescription = "",
modifier = Modifier.size(14.dp)
)
}
ModelDownloadStatusType.FAILED -> Icon(
Icons.Rounded.Error,
tint = Color(0xFFAA0000),
contentDescription = "",
modifier = Modifier.size(14.dp)
)
else -> {}
}
}
}
@Preview(showBackground = true)
@Composable
fun StatusIconPreview() {
GalleryTheme {
Column {
for (downloadStatus in listOf(
ModelDownloadStatus(status = ModelDownloadStatusType.NOT_DOWNLOADED),
ModelDownloadStatus(status = ModelDownloadStatusType.IN_PROGRESS),
ModelDownloadStatus(status = ModelDownloadStatusType.SUCCEEDED),
ModelDownloadStatus(status = ModelDownloadStatusType.FAILED),
ModelDownloadStatus(status = ModelDownloadStatusType.UNZIPPING),
ModelDownloadStatus(status = ModelDownloadStatusType.PARTIALLY_DOWNLOADED),
)) {
StatusIcon(downloadStatus = downloadStatus)
}
}
}
}

View file

@ -0,0 +1,273 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.home
import androidx.annotation.StringRes
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.PaddingValues
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.aspectRatio
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.lazy.grid.GridCells
import androidx.compose.foundation.lazy.grid.GridItemSpan
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
import androidx.compose.foundation.lazy.grid.items
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material3.Card
import androidx.compose.material3.CardDefaults
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Scaffold
import androidx.compose.material3.Text
import androidx.compose.material3.TopAppBarDefaults
import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.Brush
import androidx.compose.ui.input.nestedscroll.nestedScroll
import androidx.compose.ui.layout.layout
import androidx.compose.ui.platform.LocalConfiguration
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import com.google.aiedge.gallery.GalleryTopAppBar
import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.data.AppBarAction
import com.google.aiedge.gallery.data.AppBarActionType
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.TaskIcon
import com.google.aiedge.gallery.ui.common.getTaskBgColor
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.theme.ThemeSettings
import com.google.aiedge.gallery.ui.theme.customColors
import com.google.aiedge.gallery.ui.theme.titleMediumNarrow
/** Navigation destination data */
object HomeScreenDestination {
@StringRes
val titleRes = R.string.app_name
}
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun HomeScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateToTaskScreen: (Task) -> Unit,
modifier: Modifier = Modifier
) {
val scrollBehavior = TopAppBarDefaults.pinnedScrollBehavior()
val uiState by modelManagerViewModel.uiState.collectAsState()
var showSettingsDialog by remember { mutableStateOf(false) }
val tasks = uiState.tasks
val loadingHfModels = uiState.loadingHfModels
Scaffold(modifier = modifier.nestedScroll(scrollBehavior.nestedScrollConnection), topBar = {
GalleryTopAppBar(
title = stringResource(HomeScreenDestination.titleRes),
rightAction = AppBarAction(
actionType = AppBarActionType.APP_SETTING, actionFn = {
showSettingsDialog = true
}
),
loadingHfModels = loadingHfModels,
scrollBehavior = scrollBehavior,
)
}) { innerPadding ->
TaskList(
tasks = tasks,
navigateToTaskScreen = navigateToTaskScreen,
modifier = Modifier.fillMaxSize(),
contentPadding = innerPadding,
)
}
// Settings dialog.
if (showSettingsDialog) {
SettingsDialog(
curThemeOverride = modelManagerViewModel.readThemeOverride(),
onDismissed = { showSettingsDialog = false },
onOk = { curConfigValues ->
// Update theme settings.
// This will update app's theme.
val themeOverride = curConfigValues[ConfigKey.THEME.label] as String
ThemeSettings.themeOverride.value = themeOverride
// Save to data store.
modelManagerViewModel.saveThemeOverride(themeOverride)
},
)
}
}
@Composable
private fun TaskList(
tasks: List<Task>,
navigateToTaskScreen: (Task) -> Unit,
modifier: Modifier = Modifier,
contentPadding: PaddingValues = PaddingValues(0.dp),
) {
Box(modifier = modifier.fillMaxSize()) {
LazyVerticalGrid(
columns = GridCells.Fixed(count = 2),
contentPadding = contentPadding,
modifier = modifier.padding(12.dp),
horizontalArrangement = Arrangement.spacedBy(8.dp),
verticalArrangement = Arrangement.spacedBy(8.dp),
) {
// Headline.
item(span = { GridItemSpan(2) }) {
Text(
"Welcome to AI Edge Gallery! Explore a world of \namazing on-device models from LiteRT community",
textAlign = TextAlign.Center,
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
modifier = Modifier.padding(bottom = 20.dp)
)
}
// Cards.
items(tasks) { task ->
TaskCard(
task = task,
onClick = {
navigateToTaskScreen(task)
},
modifier = Modifier
.fillMaxWidth()
.aspectRatio(1f)
)
}
}
// Gradient overlay at the bottom.
Box(
modifier = Modifier
.fillMaxWidth()
.height(LocalConfiguration.current.screenHeightDp.dp * 0.25f)
.background(
Brush.verticalGradient(
colors = MaterialTheme.customColors.homeBottomGradient,
)
)
.align(Alignment.BottomCenter)
)
}
}
@Composable
private fun TaskCard(task: Task, onClick: () -> Unit, modifier: Modifier = Modifier) {
Card(
modifier = modifier
.clip(RoundedCornerShape(43.5.dp))
.clickable(
onClick = onClick,
),
colors = CardDefaults.cardColors(
containerColor = getTaskBgColor(task = task)
),
) {
Column(
modifier = Modifier
.fillMaxSize()
.padding(24.dp),
) {
// Icon.
TaskIcon(task = task)
Spacer(modifier = Modifier.weight(1f))
// Title.
val pair = task.type.label.splitByFirstSpace()
Text(
pair.first,
color = MaterialTheme.colorScheme.primary,
style = titleMediumNarrow.copy(
fontSize = 20.sp,
fontWeight = FontWeight.Bold,
),
)
if (pair.second.isNotEmpty()) {
Text(
pair.second,
color = MaterialTheme.colorScheme.primary,
style = titleMediumNarrow.copy(
fontSize = 18.sp,
fontWeight = FontWeight.Bold,
),
modifier = Modifier.layout { measurable, constraints ->
val placeable = measurable.measure(constraints)
layout(placeable.width, placeable.height) {
placeable.placeRelative(0, -4.dp.roundToPx())
}
}
)
}
// Model count.
val modelCountLabel = when (task.models.size) {
1 -> "1 Model"
else -> "%d Models".format(task.models.size)
}
Text(
modelCountLabel,
color = MaterialTheme.colorScheme.secondary,
style = MaterialTheme.typography.bodyMedium
)
}
}
}
private fun String.splitByFirstSpace(): Pair<String, String> {
val spaceIndex = this.indexOf(' ')
if (spaceIndex == -1) {
return Pair(this, "")
}
return Pair(this.substring(0, spaceIndex), this.substring(spaceIndex + 1))
}
@Preview
@Composable
fun HomeScreenPreview(
) {
GalleryTheme {
HomeScreen(
modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current),
navigateToTaskScreen = {},
)
}
}

View file

@ -0,0 +1,60 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.home
import androidx.compose.runtime.Composable
import com.google.aiedge.gallery.VERSION
import com.google.aiedge.gallery.data.Config
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.SegmentedButtonConfig
import com.google.aiedge.gallery.ui.common.chat.ConfigDialog
import com.google.aiedge.gallery.ui.theme.THEME_AUTO
import com.google.aiedge.gallery.ui.theme.THEME_DARK
import com.google.aiedge.gallery.ui.theme.THEME_LIGHT
private val CONFIGS: List<Config> = listOf(
SegmentedButtonConfig(
key = ConfigKey.THEME,
defaultValue = THEME_AUTO,
options = listOf(THEME_AUTO, THEME_LIGHT, THEME_DARK)
)
)
@Composable
fun SettingsDialog(
curThemeOverride: String,
onDismissed: () -> Unit,
onOk: (Map<String, Any>) -> Unit,
) {
val initialValues = mapOf(
ConfigKey.THEME.label to curThemeOverride
)
ConfigDialog(
title = "Settings",
subtitle = "App version: $VERSION",
okBtnLabel = "OK",
configs = CONFIGS,
initialValues = initialValues,
onDismissed = onDismissed,
onOk = { curConfigValues ->
onOk(curConfigValues)
// Hide config dialog.
onDismissed()
},
)
}

View file

@ -0,0 +1,91 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.icon
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.PathFillType
import androidx.compose.ui.graphics.SolidColor
import androidx.compose.ui.graphics.StrokeCap
import androidx.compose.ui.graphics.StrokeJoin
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.graphics.vector.path
import androidx.compose.ui.unit.dp
val Deployed_code: ImageVector
get() {
if (internal_Deployed_code != null) {
return internal_Deployed_code!!
}
internal_Deployed_code = ImageVector.Builder(
name = "Deployed_code",
defaultWidth = 24.dp,
defaultHeight = 24.dp,
viewportWidth = 960f,
viewportHeight = 960f
).apply {
path(
fill = SolidColor(Color.Black),
fillAlpha = 1.0f,
stroke = null,
strokeAlpha = 1.0f,
strokeLineWidth = 1.0f,
strokeLineCap = StrokeCap.Butt,
strokeLineJoin = StrokeJoin.Miter,
strokeLineMiter = 1.0f,
pathFillType = PathFillType.NonZero
) {
moveTo(440f, 777f)
verticalLineToRelative(-274f)
lineTo(200f, 364f)
verticalLineToRelative(274f)
close()
moveToRelative(80f, 0f)
lineToRelative(240f, -139f)
verticalLineToRelative(-274f)
lineTo(520f, 503f)
close()
moveToRelative(-40f, -343f)
lineToRelative(237f, -137f)
lineToRelative(-237f, -137f)
lineToRelative(-237f, 137f)
close()
moveTo(160f, 708f)
quadToRelative(-19f, -11f, -29.5f, -29f)
reflectiveQuadTo(120f, 639f)
verticalLineToRelative(-318f)
quadToRelative(0f, -22f, 10.5f, -40f)
reflectiveQuadToRelative(29.5f, -29f)
lineToRelative(280f, -161f)
quadToRelative(19f, -11f, 40f, -11f)
reflectiveQuadToRelative(40f, 11f)
lineToRelative(280f, 161f)
quadToRelative(19f, 11f, 29.5f, 29f)
reflectiveQuadToRelative(10.5f, 40f)
verticalLineToRelative(318f)
quadToRelative(0f, 22f, -10.5f, 40f)
reflectiveQuadTo(800f, 708f)
lineTo(520f, 869f)
quadToRelative(-19f, 11f, -40f, 11f)
reflectiveQuadToRelative(-40f, -11f)
close()
moveToRelative(320f, -228f)
}
}.build()
return internal_Deployed_code!!
}
private var internal_Deployed_code: ImageVector? = null

View file

@ -0,0 +1,154 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.imageclassification
import android.content.Context
import android.graphics.Bitmap
import android.util.Log
import androidx.compose.ui.graphics.Color
import com.google.android.gms.tflite.client.TfLiteInitializationOptions
import com.google.android.gms.tflite.gpu.support.TfLiteGpu
import com.google.android.gms.tflite.java.TfLite
import com.google.aiedge.gallery.ui.common.chat.Classification
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.ui.common.LatencyProvider
import org.tensorflow.lite.DataType
import org.tensorflow.lite.InterpreterApi
import org.tensorflow.lite.gpu.GpuDelegateFactory
import org.tensorflow.lite.support.common.FileUtil
import org.tensorflow.lite.support.common.ops.NormalizeOp
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.ResizeOp
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.io.File
import java.io.FileInputStream
private const val TAG = "AGImageClassificationModelHelper"
class ImageClassificationInferenceResult(
val categories: List<Classification>, override val latencyMs: Float
) : LatencyProvider
//TODO: handle error.
object ImageClassificationModelHelper {
fun initialize(context: Context, model: Model, onDone: () -> Unit) {
val useGpu = model.getBooleanConfigValue(key = ConfigKey.USE_GPU)
TfLiteGpu.isGpuDelegateAvailable(context).continueWith { gpuTask ->
val optionsBuilder = TfLiteInitializationOptions.builder()
if (gpuTask.result) {
optionsBuilder.setEnableGpuDelegateSupport(true)
}
val task = TfLite.initialize(
context,
optionsBuilder.build()
)
task.addOnSuccessListener {
val interpreterOption =
InterpreterApi.Options().setRuntime(InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY)
if (useGpu) {
interpreterOption.addDelegateFactory(GpuDelegateFactory())
}
val interpreter = InterpreterApi.create(
File(model.getPath(context = context)), interpreterOption
)
model.instance = interpreter
onDone()
}
}
}
fun cleanUp(model: Model) {
if (model.instance == null) {
return
}
val instance = model.instance as InterpreterApi
instance.close()
}
fun runInference(
context: Context,
model: Model,
input: Bitmap,
primaryColor: Color,
): ImageClassificationInferenceResult {
val instance = model.instance
if (instance == null) {
Log.d(
TAG, "Model '${model.name}' has not been initialized"
)
return ImageClassificationInferenceResult(categories = listOf(), latencyMs = 0f)
}
// Pre-process image.
val start = System.currentTimeMillis()
val interpreter = model.instance as InterpreterApi
val inputShape = interpreter.getInputTensor(0).shape()
val imageProcessor = ImageProcessor.Builder()
.add(ResizeOp(inputShape[1], inputShape[2], ResizeOp.ResizeMethod.BILINEAR))
.add(NormalizeOp(127.5f, 127.5f)) // Normalize pixel values
.build()
val tensorImage = TensorImage(DataType.FLOAT32)
tensorImage.load(input)
val inputTensorBuffer = imageProcessor.process(tensorImage).tensorBuffer
// Output buffer
val outputBuffer =
TensorBuffer.createFixedSize(interpreter.getOutputTensor(0).shape(), DataType.FLOAT32)
// Run inference
interpreter.run(inputTensorBuffer.buffer, outputBuffer.buffer)
// Post-process result.
val output = outputBuffer.floatArray
val labelsFilePath = model.getPath(
context = context,
fileName = model.getExtraDataFile(name = "labels")?.downloadFileName ?: "_"
)
val labelsFileInputStream = FileInputStream(File(labelsFilePath))
val labels = FileUtil.loadLabels(labelsFileInputStream)
labelsFileInputStream.close()
val topIndices =
getTopKMaxIndices(output = output, k = model.getIntConfigValue(ConfigKey.MAX_RESULT_COUNT))
val categories: List<Classification> =
topIndices.map { i ->
Classification(
label = labels[i],
score = output[i],
color = primaryColor
)
}
return ImageClassificationInferenceResult(
categories = categories,
latencyMs = (System.currentTimeMillis() - start).toFloat()
)
}
private fun getTopKMaxIndices(output: FloatArray, k: Int): List<Int> {
if (k <= 0 || output.isEmpty()) {
return emptyList()
}
val indexedValues = output.withIndex().toList()
val sortedIndexedValues =
indexedValues.sortedByDescending { it.value }
return sortedIndexedValues.take(k).map { it.index } // Take the top k and extract indices
}
}

View file

@ -0,0 +1,97 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.imageclassification
import androidx.compose.material3.MaterialTheme
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.aiedge.gallery.ui.ViewModelProvider
import com.google.aiedge.gallery.ui.common.chat.ChatInputType
import com.google.aiedge.gallery.ui.common.chat.ChatMessageImage
import com.google.aiedge.gallery.ui.common.chat.ChatView
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import kotlinx.serialization.Serializable
/** Navigation destination data */
object ImageClassificationDestination {
@Serializable
val route = "ImageClassificationRoute"
}
@Composable
fun ImageClassificationScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
viewModel: ImageClassificationViewModel = viewModel(
factory = ViewModelProvider.Factory
),
) {
val context = LocalContext.current
val primaryColor = MaterialTheme.colorScheme.primary
ChatView(
task = viewModel.task,
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
onSendMessage = { model, message ->
viewModel.addMessage(
model = model,
message = message,
)
if (message is ChatMessageImage) {
viewModel.generateResponse(
context = context,
model = model,
input = message.bitmap,
primaryColor = primaryColor,
)
}
},
onStreamImageMessage = { model, message ->
viewModel.generateStreamingResponse(
context = context,
model = model,
input = message.bitmap,
primaryColor = primaryColor,
)
},
onRunAgainClicked = { model, message ->
viewModel.runAgain(
context = context,
model = model,
message = message,
primaryColor = primaryColor,
)
},
onBenchmarkClicked = { model, message, warmUpIterations, benchmarkIterations ->
viewModel.benchmark(
context = context,
model = model,
message = message,
warmupCount = warmUpIterations,
iterations = benchmarkIterations,
primaryColor = primaryColor,
)
},
navigateUp = navigateUp,
modifier = modifier,
chatInputType = ChatInputType.IMAGE,
)
}

View file

@ -0,0 +1,165 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.imageclassification
import android.content.Context
import android.graphics.Bitmap
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.unit.dp
import androidx.lifecycle.viewModelScope
import com.google.aiedge.gallery.ui.common.chat.ChatMessage
import com.google.aiedge.gallery.ui.common.chat.ChatMessageClassification
import com.google.aiedge.gallery.ui.common.chat.ChatMessageImage
import com.google.aiedge.gallery.ui.common.chat.ChatMessageType
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION
import com.google.aiedge.gallery.ui.common.chat.ChatViewModel
import com.google.aiedge.gallery.ui.common.runBasicBenchmark
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.sync.Mutex
class ImageClassificationViewModel : ChatViewModel(task = TASK_IMAGE_CLASSIFICATION) {
private val mutex = Mutex()
fun generateResponse(context: Context, model: Model, input: Bitmap, primaryColor: Color) {
viewModelScope.launch(Dispatchers.Default) {
// Wait for model to be initialized.
while (model.instance == null) {
delay(100)
}
val result = ImageClassificationModelHelper.runInference(
context = context, model = model, input = input, primaryColor = primaryColor
)
super.addMessage(
model = model,
message = ChatMessageClassification(
classifications = result.categories,
latencyMs = result.latencyMs,
maxBarWidth = 300.dp,
),
)
}
}
fun generateStreamingResponse(
context: Context,
model: Model,
input: Bitmap,
primaryColor: Color
) {
viewModelScope.launch(Dispatchers.Default) {
// Wait for model to be initialized.
while (model.instance == null) {
delay(100)
}
if (mutex.tryLock()) {
try {
val result = ImageClassificationModelHelper.runInference(
context = context, model = model, input = input, primaryColor = primaryColor
)
updateStreamingMessage(
model = model,
message = ChatMessageClassification(
classifications = result.categories,
latencyMs = result.latencyMs
)
)
} finally {
mutex.unlock()
}
} else {
// skip call if the previous call has not been finished (mutex is still locked).
}
}
}
fun benchmark(
context: Context,
model: Model,
message: ChatMessage,
warmupCount: Int,
iterations: Int,
primaryColor: Color
) {
viewModelScope.launch(Dispatchers.Default) {
// Wait for model to be initialized.
while (model.instance == null) {
delay(100)
}
if (message is ChatMessageImage) {
setInProgress(true)
runBasicBenchmark(
model = model,
warmupCount = warmupCount,
iterations = iterations,
chatViewModel = this@ImageClassificationViewModel,
inferenceFn = {
ImageClassificationModelHelper.runInference(
context = context,
model = model,
input = message.bitmap,
primaryColor = primaryColor
)
},
chatMessageType = ChatMessageType.BENCHMARK_RESULT,
)
setInProgress(false)
}
}
}
fun runAgain(context: Context, model: Model, message: ChatMessage, primaryColor: Color) {
viewModelScope.launch(Dispatchers.Default) {
// Wait for model to be initialized.
while (model.instance == null) {
delay(100)
}
if (message is ChatMessageImage) {
// Clone the clicked message and add it.
addMessage(model = model, message = message.clone())
// Run inference.
val result =
ImageClassificationModelHelper.runInference(
context = context,
model = model,
input = message.bitmap,
primaryColor = primaryColor
)
// Add response message.
val newMessage = generateClassificationMessage(result = result)
addMessage(model = model, message = newMessage)
}
}
}
private fun generateClassificationMessage(result: ImageClassificationInferenceResult): ChatMessageClassification {
return ChatMessageClassification(
classifications = result.categories,
latencyMs = result.latencyMs,
maxBarWidth = 300.dp,
)
}
}

View file

@ -0,0 +1,77 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.imagegeneration
import android.content.Context
import android.graphics.Bitmap
import android.util.Log
import com.google.mediapipe.framework.image.BitmapExtractor
import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.ui.common.LatencyProvider
import kotlin.random.Random
private const val TAG = "AGImageGenerationModelHelper"
class ImageGenerationInferenceResult(
val bitmap: Bitmap, override val latencyMs: Float
) : LatencyProvider
object ImageGenerationModelHelper {
fun initialize(context: Context, model: Model, onDone: () -> Unit) {
val options = ImageGenerator.ImageGeneratorOptions.builder()
.setImageGeneratorModelDirectory(model.getPath(context = context))
.build()
model.instance = ImageGenerator.createFromOptions(context, options)
onDone()
}
fun cleanUp(model: Model) {
if (model.instance == null) {
return
}
val instance = model.instance as ImageGenerator
try {
instance.close()
} catch (e: Exception) {
// ignore
}
model.instance = null
Log.d(TAG, "Clean up done.")
}
fun runInference(
model: Model,
input: String,
onStep: (curIteration: Int, totalIterations: Int, ImageGenerationInferenceResult, isLast: Boolean) -> Unit
) {
val start = System.currentTimeMillis()
val instance = model.instance as ImageGenerator
val iterations = model.getIntConfigValue(ConfigKey.ITERATIONS)
instance.setInputs(input, iterations, Random.nextInt())
for (i in 0..<iterations) {
val result = ImageGenerationInferenceResult(
bitmap = BitmapExtractor.extract(
instance.execute(true)?.generatedImage(),
),
latencyMs = (System.currentTimeMillis() - start).toFloat(),
)
onStep(i, iterations, result, i == iterations - 1)
}
}
}

View file

@ -0,0 +1,64 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.imagegeneration
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.aiedge.gallery.ui.ViewModelProvider
import com.google.aiedge.gallery.ui.common.chat.ChatMessageText
import com.google.aiedge.gallery.ui.common.chat.ChatView
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import kotlinx.serialization.Serializable
/** Navigation destination data */
object ImageGenerationDestination {
@Serializable
val route = "ImageGenerationRoute"
}
@Composable
fun ImageGenerationScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
viewModel: ImageGenerationViewModel = viewModel(
factory = ViewModelProvider.Factory
),
) {
ChatView(
task = viewModel.task,
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
onSendMessage = { model, message ->
viewModel.addMessage(
model = model,
message = message,
)
if (message is ChatMessageText) {
viewModel.generateResponse(
model = model,
input = message.content,
)
}
},
onRunAgainClicked = { _, _ -> },
onBenchmarkClicked = { _, _, _, _ -> },
navigateUp = navigateUp,
modifier = modifier,
)
}

View file

@ -0,0 +1,87 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.imagegeneration
import android.graphics.Bitmap
import androidx.compose.ui.graphics.ImageBitmap
import androidx.compose.ui.graphics.asImageBitmap
import androidx.lifecycle.viewModelScope
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION
import com.google.aiedge.gallery.ui.common.chat.ChatMessageImageWithHistory
import com.google.aiedge.gallery.ui.common.chat.ChatMessageLoading
import com.google.aiedge.gallery.ui.common.chat.ChatMessageType
import com.google.aiedge.gallery.ui.common.chat.ChatSide
import com.google.aiedge.gallery.ui.common.chat.ChatViewModel
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
class ImageGenerationViewModel : ChatViewModel(task = TASK_IMAGE_GENERATION) {
fun generateResponse(model: Model, input: String) {
viewModelScope.launch(Dispatchers.Default) {
setInProgress(true)
// Loading.
addMessage(
model = model,
message = ChatMessageLoading(),
)
// Wait for model to be initialized.
while (model.instance == null) {
delay(100)
}
// Run inference.
val bitmaps: MutableList<Bitmap> = mutableListOf()
val imageBitmaps: MutableList<ImageBitmap> = mutableListOf()
ImageGenerationModelHelper.runInference(
model = model, input = input
) { step, totalIterations, result, isLast ->
bitmaps.add(result.bitmap)
imageBitmaps.add(result.bitmap.asImageBitmap())
val message = ChatMessageImageWithHistory(
bitmaps = bitmaps,
imageBitMaps = imageBitmaps,
totalIterations = totalIterations,
side = ChatSide.AGENT,
latencyMs = result.latencyMs,
curIteration = step,
)
if (step == 0) {
removeLastMessage(model = model)
super.addMessage(
model = model,
message = message,
)
} else {
super.replaceLastMessage(
model = model,
message = message,
type = ChatMessageType.IMAGE_WITH_HISTORY
)
}
if (isLast) {
setInProgress(false)
}
}
}
}
}

View file

@ -0,0 +1,84 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.llmchat
import com.google.aiedge.gallery.data.Config
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.ConfigValue
import com.google.aiedge.gallery.data.NumberSliderConfig
import com.google.aiedge.gallery.data.ValueType
import com.google.aiedge.gallery.data.getFloatConfigValue
import com.google.aiedge.gallery.data.getIntConfigValue
private const val DEFAULT_MAX_TOKEN = 1024
private const val DEFAULT_TOPK = 40
private const val DEFAULT_TOPP = 0.9f
private const val DEFAULT_TEMPERATURE = 1.0f
fun createLlmChatConfigs(
defaultMaxToken: Int = DEFAULT_MAX_TOKEN,
defaultTopK: Int = DEFAULT_TOPK,
defaultTopP: Float = DEFAULT_TOPP,
defaultTemperature: Float = DEFAULT_TEMPERATURE
): List<Config> {
return listOf(
NumberSliderConfig(
key = ConfigKey.MAX_TOKENS,
sliderMin = 100f,
sliderMax = 1024f,
defaultValue = defaultMaxToken.toFloat(),
valueType = ValueType.INT
),
NumberSliderConfig(
key = ConfigKey.TOPK,
sliderMin = 5f,
sliderMax = 40f,
defaultValue = defaultTopK.toFloat(),
valueType = ValueType.INT
),
NumberSliderConfig(
key = ConfigKey.TOPP,
sliderMin = 0.0f,
sliderMax = 1.0f,
defaultValue = defaultTopP,
valueType = ValueType.FLOAT
),
NumberSliderConfig(
key = ConfigKey.TEMPERATURE,
sliderMin = 0.0f,
sliderMax = 2.0f,
defaultValue = defaultTemperature,
valueType = ValueType.FLOAT
),
)
}
fun createLLmChatConfig(defaults: Map<String, ConfigValue>): List<Config> {
val defaultMaxToken =
getIntConfigValue(defaults[ConfigKey.MAX_TOKENS.id], default = DEFAULT_MAX_TOKEN)
val defaultTopK = getIntConfigValue(defaults[ConfigKey.TOPK.id], default = DEFAULT_TOPK)
val defaultTopP = getFloatConfigValue(defaults[ConfigKey.TOPP.id], default = DEFAULT_TOPP)
val defaultTemperature =
getFloatConfigValue(defaults[ConfigKey.TEMPERATURE.id], default = DEFAULT_TEMPERATURE)
return createLlmChatConfigs(
defaultMaxToken = defaultMaxToken,
defaultTopK = defaultTopK,
defaultTopP = defaultTopP,
defaultTemperature = defaultTemperature
)
}

View file

@ -0,0 +1,135 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.llmchat
import android.content.Context
import android.util.Log
import com.google.common.util.concurrent.ListenableFuture
import com.google.mediapipe.tasks.genai.llminference.LlmInference
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.LlmBackend
import com.google.aiedge.gallery.data.Model
private const val TAG = "AGLlmChatModelHelper"
private const val DEFAULT_MAX_TOKEN = 1024
private const val DEFAULT_TOPK = 40
private const val DEFAULT_TOPP = 0.9f
private const val DEFAULT_TEMPERATURE = 1.0f
typealias ResultListener = (partialResult: String, done: Boolean) -> Unit
typealias CleanUpListener = () -> Unit
data class LlmModelInstance(val engine: LlmInference, val session: LlmInferenceSession)
object LlmChatModelHelper {
// Indexed by model name.
private val cleanUpListeners: MutableMap<String, CleanUpListener> = mutableMapOf()
private val generateResponseListenableFutures: MutableMap<String, ListenableFuture<String>> =
mutableMapOf()
fun initialize(
context: Context, model: Model, onDone: () -> Unit
) {
val maxTokens =
model.getIntConfigValue(key = ConfigKey.MAX_TOKENS, defaultValue = DEFAULT_MAX_TOKEN)
val topK = model.getIntConfigValue(key = ConfigKey.TOPK, defaultValue = DEFAULT_TOPK)
val topP = model.getFloatConfigValue(key = ConfigKey.TOPP, defaultValue = DEFAULT_TOPP)
val temperature =
model.getFloatConfigValue(key = ConfigKey.TEMPERATURE, defaultValue = DEFAULT_TEMPERATURE)
Log.d(TAG, "Initializing...")
val preferredBackend = when (model.llmBackend) {
LlmBackend.CPU -> LlmInference.Backend.CPU
LlmBackend.GPU -> LlmInference.Backend.GPU
}
val options =
LlmInference.LlmInferenceOptions.builder().setModelPath(model.getPath(context = context))
.setMaxTokens(maxTokens).setPreferredBackend(preferredBackend).build()
// Create an instance of the LLM Inference task
try {
val llmInference = LlmInference.createFromOptions(context, options)
// val session = LlmInferenceSession.createFromOptions(
// llmInference,
// LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
// .setTemperature(temperature).build()
// )
model.instance = llmInference
// LlmModelInstance(engine = llmInference, session = session)
} catch (e: Exception) {
e.printStackTrace()
}
onDone()
}
fun cleanUp(model: Model) {
if (model.instance == null) {
return
}
val instance = model.instance as LlmInference
try {
instance.close()
// instance.session.close()
// instance.engine.close()
} catch (e: Exception) {
// ignore
}
val onCleanUp = cleanUpListeners.remove(model.name)
if (onCleanUp != null) {
onCleanUp()
}
model.instance = null
Log.d(TAG, "Clean up done.")
}
fun runInference(
model: Model,
input: String,
resultListener: ResultListener,
cleanUpListener: CleanUpListener,
) {
val instance = model.instance as LlmInference
// Set listener.
if (!cleanUpListeners.containsKey(model.name)) {
cleanUpListeners[model.name] = cleanUpListener
}
// Start async inference.
val future = instance.generateResponseAsync(input, resultListener)
generateResponseListenableFutures[model.name] = future
// val session = instance.session
// TODO: need to count token and reset session.
// session.addQueryChunk(input)
// session.generateResponseAsync(resultListener)
}
fun stopInference(model: Model) {
val instance = model.instance as LlmInference
if (instance != null) {
instance.close()
}
// val future = generateResponseListenableFutures[model.name]
// if (future != null) {
// future.cancel(true)
// generateResponseListenableFutures.remove(model.name)
// }
}
}

View file

@ -0,0 +1,77 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.llmchat
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.aiedge.gallery.ui.ViewModelProvider
import com.google.aiedge.gallery.ui.common.chat.ChatMessageText
import com.google.aiedge.gallery.ui.common.chat.ChatView
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import kotlinx.serialization.Serializable
/** Navigation destination data */
object LlmChatDestination {
@Serializable
val route = "LlmChatRoute"
}
@Composable
fun LlmChatScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
viewModel: LlmChatViewModel = viewModel(
factory = ViewModelProvider.Factory
),
) {
ChatView(
task = viewModel.task,
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
onSendMessage = { model, message ->
viewModel.addMessage(
model = model,
message = message,
)
if (message is ChatMessageText) {
modelManagerViewModel.addTextInputHistory(message.content)
viewModel.generateResponse(
model = model,
input = message.content,
)
}
},
onRunAgainClicked = { model, message ->
if (message is ChatMessageText) {
viewModel.runAgain(model = model, message = message)
}
},
onBenchmarkClicked = { model, message, warmUpIterations, benchmarkIterations ->
if (message is ChatMessageText) {
viewModel.benchmark(
model = model,
message = message
)
}
},
navigateUp = navigateUp,
modifier = modifier,
)
}

View file

@ -0,0 +1,209 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.llmchat
import androidx.lifecycle.viewModelScope
import com.google.mediapipe.tasks.genai.llminference.LlmInference
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.aiedge.gallery.ui.common.chat.ChatMessageLoading
import com.google.aiedge.gallery.ui.common.chat.ChatMessageText
import com.google.aiedge.gallery.ui.common.chat.ChatMessageType
import com.google.aiedge.gallery.ui.common.chat.ChatSide
import com.google.aiedge.gallery.ui.common.chat.ChatViewModel
import com.google.aiedge.gallery.ui.common.chat.Stat
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
private const val TAG = "AGLlmChatViewModel"
private val STATS = listOf(
Stat(id = "time_to_first_token", label = "Time to 1st token", unit = "sec"),
Stat(id = "prefill_speed", label = "Prefill speed", unit = "tokens/s"),
Stat(id = "decode_speed", label = "Decode speed", unit = "tokens/s"),
Stat(id = "latency", label = "Latency", unit = "sec")
)
class LlmChatViewModel : ChatViewModel(task = TASK_LLM_CHAT) {
fun generateResponse(model: Model, input: String) {
viewModelScope.launch(Dispatchers.Default) {
setInProgress(true)
// Loading.
addMessage(
model = model,
message = ChatMessageLoading(),
)
// Wait for instance to be initialized.
while (model.instance == null) {
delay(100)
}
// Run inference.
val start = System.currentTimeMillis()
LlmChatModelHelper.runInference(
model = model,
input = input,
resultListener = { partialResult, done ->
// Remove the last message if it is a "loading" message.
// This will only be done once.
val lastMessage = getLastMessage(model = model)
if (lastMessage?.type == ChatMessageType.LOADING) {
removeLastMessage(model = model)
// Add an empty message that will receive streaming results.
addMessage(
model = model,
message = ChatMessageText(content = "", side = ChatSide.AGENT)
)
}
// Incrementally update the streamed partial results.
val latencyMs: Long = if (done) System.currentTimeMillis() - start else -1
updateLastMessageContentIncrementally(
model = model,
partialContent = partialResult,
latencyMs = latencyMs.toFloat()
)
if (done) {
setInProgress(false)
}
}, cleanUpListener = {
setInProgress(false)
})
}
}
fun runAgain(model: Model, message: ChatMessageText) {
viewModelScope.launch(Dispatchers.Default) {
// Wait for model to be initialized.
while (model.instance == null) {
delay(100)
}
// Clone the clicked message and add it.
addMessage(model = model, message = message.clone())
// Run inference.
generateResponse(
model = model,
input = message.content,
)
}
}
fun benchmark(model: Model, message: ChatMessageText) {
viewModelScope.launch(Dispatchers.Default) {
setInProgress(true)
// Wait for model to be initialized.
while (model.instance == null) {
delay(100)
}
val instance = model.instance as LlmInference
val prefillTokens = instance.sizeInTokens(message.content)
// Add the message to show benchmark results.
val benchmarkLlmResult = ChatMessageBenchmarkLlmResult(
orderedStats = STATS,
statValues = mutableMapOf(),
running = true,
latencyMs = -1f,
)
addMessage(model = model, message = benchmarkLlmResult)
// Run inference.
val result = StringBuilder()
var firstRun = true
var timeToFirstToken = 0f
var firstTokenTs = 0L
var decodeTokens = 0
var prefillSpeed = 0f
var decodeSpeed: Float
val start = System.currentTimeMillis()
var lastUpdateTime = 0L
LlmChatModelHelper.runInference(
model = model,
input = message.content,
resultListener = { partialResult, done ->
val curTs = System.currentTimeMillis()
if (firstRun) {
firstTokenTs = System.currentTimeMillis()
timeToFirstToken = (firstTokenTs - start) / 1000f
prefillSpeed = prefillTokens / timeToFirstToken
firstRun = false
// Update message to show prefill speed.
replaceLastMessage(
model = model,
message = ChatMessageBenchmarkLlmResult(
orderedStats = STATS,
statValues = mutableMapOf(
"prefill_speed" to prefillSpeed,
"time_to_first_token" to timeToFirstToken,
"latency" to (curTs - start).toFloat() / 1000f,
),
running = false,
latencyMs = -1f,
),
type = ChatMessageType.BENCHMARK_LLM_RESULT,
)
} else {
decodeTokens++
}
result.append(partialResult)
if (curTs - lastUpdateTime > 500 || done) {
decodeSpeed =
decodeTokens / ((curTs - firstTokenTs) / 1000f)
if (decodeSpeed.isNaN()) {
decodeSpeed = 0f
}
replaceLastMessage(
model = model,
message = ChatMessageBenchmarkLlmResult(
orderedStats = STATS,
statValues = mutableMapOf(
"prefill_speed" to prefillSpeed,
"decode_speed" to decodeSpeed,
"time_to_first_token" to timeToFirstToken,
"latency" to (curTs - start).toFloat() / 1000f,
),
running = !done,
latencyMs = -1f,
),
type = ChatMessageType.BENCHMARK_LLM_RESULT
)
lastUpdateTime = curTs
if (done) {
setInProgress(false)
}
}
},
cleanUpListener = {
setInProgress(false)
}
)
}
}
}

View file

@ -0,0 +1,98 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.modelmanager
import androidx.compose.foundation.ExperimentalFoundationApi
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.PaddingValues
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.modelitem.ModelItem
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
import com.google.aiedge.gallery.ui.theme.GalleryTheme
/** The list of models in the model manager. */
@OptIn(ExperimentalFoundationApi::class)
@Composable
fun ModelList(
task: Task,
modelManagerViewModel: ModelManagerViewModel,
contentPadding: PaddingValues,
onModelClicked: (Model) -> Unit,
modifier: Modifier = Modifier,
) {
LazyColumn(
modifier = modifier.padding(top = 8.dp),
contentPadding = contentPadding,
verticalArrangement = Arrangement.spacedBy(8.dp),
) {
// Headline.
item(key = "headline") {
Text(
task.description,
textAlign = TextAlign.Center,
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
modifier = Modifier
.padding(bottom = 20.dp)
.fillMaxWidth()
)
}
// List of models within a task.
items(items = task.models) { model ->
Box {
ModelItem(
model = model,
task = task,
modelManagerViewModel = modelManagerViewModel,
onModelClicked = onModelClicked,
modifier = Modifier.padding(start = 12.dp, end = 12.dp)
)
}
}
}
}
@Preview(showBackground = true)
@Composable
fun ModelListPreview() {
val context = LocalContext.current
GalleryTheme {
ModelList(
task = TASK_TEST1,
modelManagerViewModel = PreviewModelManagerViewModel(context = context),
onModelClicked = {},
contentPadding = PaddingValues(all = 16.dp),
)
}
}

View file

@ -0,0 +1,130 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.modelmanager
import androidx.activity.compose.BackHandler
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Scaffold
import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.tooling.preview.Preview
import com.google.aiedge.gallery.GalleryTopAppBar
import com.google.aiedge.gallery.data.AppBarAction
import com.google.aiedge.gallery.data.AppBarActionType
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.getModelByName
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
import com.google.aiedge.gallery.ui.theme.GalleryTheme
/** A screen to manage models. */
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun ModelManager(
task: Task,
viewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
onModelClicked: (Model) -> Unit,
modifier: Modifier = Modifier,
) {
val uiState by viewModel.uiState.collectAsState()
val coroutineScope = rememberCoroutineScope()
// Set title based on the task.
var title = "${task.type.label} model"
if (task.models.size != 1) {
title += "s"
}
// Handle system's edge swipe.
BackHandler {
navigateUp()
}
Scaffold(
modifier = modifier,
topBar = {
GalleryTopAppBar(
title = title,
// subtitle = String.format(
// stringResource(R.string.downloaded_size),
// totalSizeInBytes.humanReadableSize()
// ),
// Refresh model list button at the left side of the app bar.
// leftAction = AppBarAction(actionType = if (uiState.loadingHfModels) {
// AppBarActionType.REFRESHING_MODELS
// } else {
// AppBarActionType.REFRESH_MODELS
// }, actionFn = {
// coroutineScope.launch(Dispatchers.IO) {
// viewModel.loadHfModels()
// }
// }),
leftAction = AppBarAction(actionType = AppBarActionType.NAVIGATE_UP, actionFn = navigateUp)
// "Done" button at the right side of the app bar to navigate up.
// rightAction = AppBarAction(
// actionType = AppBarActionType.NAVIGATE_UP, actionFn = navigateUp
// ),
)
},
) { innerPadding ->
ModelList(
task = task,
modelManagerViewModel = viewModel,
contentPadding = innerPadding,
onModelClicked = onModelClicked,
modifier = Modifier.fillMaxSize()
)
}
}
private fun getTotalDownloadedFileSize(uiState: ModelManagerUiState): Long {
var totalSizeInBytes = 0L
for ((name, status) in uiState.modelDownloadStatus.entries) {
if (status.status == ModelDownloadStatusType.SUCCEEDED) {
totalSizeInBytes += getModelByName(name)?.totalBytes ?: 0L
} else if (status.status == ModelDownloadStatusType.IN_PROGRESS) {
totalSizeInBytes += status.receivedBytes
}
}
return totalSizeInBytes
}
@Preview
@Composable
fun ModelManagerPreview() {
val context = LocalContext.current
GalleryTheme {
ModelManager(
viewModel = PreviewModelManagerViewModel(context = context),
onModelClicked = {},
task = TASK_TEST1,
navigateUp = {},
)
}
}

View file

@ -0,0 +1,697 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.modelmanager
import android.content.Context
import android.net.Uri
import android.util.Log
import androidx.activity.result.ActivityResult
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.google.aiedge.gallery.data.AGWorkInfo
import com.google.aiedge.gallery.data.AccessTokenData
import com.google.aiedge.gallery.data.DataStoreRepository
import com.google.aiedge.gallery.data.DownloadRepository
import com.google.aiedge.gallery.data.EMPTY_MODEL
import com.google.aiedge.gallery.data.HfModel
import com.google.aiedge.gallery.data.HfModelDetails
import com.google.aiedge.gallery.data.HfModelSummary
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatus
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.TASKS
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.TaskType
import com.google.aiedge.gallery.data.getModelByName
import com.google.aiedge.gallery.ui.common.AuthConfig
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationModelHelper
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationModelHelper
import com.google.aiedge.gallery.ui.llmchat.LlmChatModelHelper
import com.google.aiedge.gallery.ui.textclassification.TextClassificationModelHelper
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import kotlinx.serialization.json.Json
import net.openid.appauth.AuthorizationException
import net.openid.appauth.AuthorizationRequest
import net.openid.appauth.AuthorizationResponse
import net.openid.appauth.AuthorizationService
import net.openid.appauth.ResponseTypeValues
import java.io.File
import java.net.HttpURLConnection
import java.net.URL
private const val TAG = "AGModelManagerViewModel"
private const val HG_COMMUNITY = "jinjingforevercommunity"
private const val TEXT_INPUT_HISTORY_MAX_SIZE = 50
enum class ModelInitializationStatus {
NOT_INITIALIZED, INITIALIZING, INITIALIZED,
}
enum class TokenStatus {
NOT_STORED, EXPIRED, NOT_EXPIRED,
}
enum class TokenRequestResultType {
FAILED, SUCCEEDED, USER_CANCELLED
}
data class TokenStatusAndData(
val status: TokenStatus,
val data: AccessTokenData?,
)
data class TokenRequestResult(
val status: TokenRequestResultType,
val errorMessage: String? = null
)
data class ModelManagerUiState(
/**
* A list of tasks available in the application.
*/
val tasks: List<Task>,
/**
* A map that stores lists of models indexed by task name.
*/
val modelsByTaskName: Map<String, MutableList<Model>>,
/**
* A map that tracks the download status of each model, indexed by model name.
*/
val modelDownloadStatus: Map<String, ModelDownloadStatus>,
/**
* A map that tracks the initialization status of each model, indexed by model name.
*/
val modelInitializationStatus: Map<String, ModelInitializationStatus>,
/**
* Whether Hugging Face models from the given community are currently being loaded.
*/
val loadingHfModels: Boolean = false,
/**
* The currently selected model.
*/
val selectedModel: Model = EMPTY_MODEL,
/**
* The history of text inputs entered by the user.
*/
val textInputHistory: List<String> = listOf(),
)
/**
* ViewModel responsible for managing models, their download status, and initialization.
*
* This ViewModel handles model-related operations such as downloading, deleting, initializing,
* and cleaning up models. It also manages the UI state for model management, including the
* list of tasks, models, download statuses, and initialization statuses.
*/
open class ModelManagerViewModel(
private val downloadRepository: DownloadRepository,
private val dataStoreRepository: DataStoreRepository,
context: Context,
) : ViewModel() {
private val externalFilesDir = context.getExternalFilesDir(null)
private val inProgressWorkInfos: List<AGWorkInfo> =
downloadRepository.getEnqueuedOrRunningWorkInfos()
protected val _uiState = MutableStateFlow(createUiState())
val uiState = _uiState.asStateFlow()
val authService = AuthorizationService(context)
var curAccessToken: String = ""
init {
Log.d(TAG, "In-progress worker infos: $inProgressWorkInfos")
// Iterate through the inProgressWorkInfos and retrieve the corresponding modes.
// Those models are the ones that have not finished downloading.
val models: MutableList<Model> = mutableListOf()
for (info in inProgressWorkInfos) {
getModelByName(info.modelName)?.let { model ->
models.add(model)
}
}
// Cancel all pending downloads for the retrieved models.
downloadRepository.cancelAll(models) {
Log.d(TAG, "All pending work is cancelled")
viewModelScope.launch(Dispatchers.IO) {
// Load models from hg community.
loadHfModels()
Log.d(TAG, "Done loading HF models")
// Kick off downloads for these models .
withContext(Dispatchers.Main) {
for (info in inProgressWorkInfos) {
val model: Model? = getModelByName(info.modelName)
if (model != null) {
Log.d(TAG, "Sending a new download request for '${model.name}'")
downloadRepository.downloadModel(
model, onStatusUpdated = this@ModelManagerViewModel::setDownloadStatus
)
}
}
}
}
}
}
override fun onCleared() {
super.onCleared()
authService.dispose()
}
fun selectModel(model: Model) {
_uiState.update { _uiState.value.copy(selectedModel = model) }
}
fun downloadModel(model: Model) {
// Update status.
setDownloadStatus(
curModel = model, status = ModelDownloadStatus(status = ModelDownloadStatusType.IN_PROGRESS)
)
// Delete the model files first.
deleteModel(model = model)
// Start to send download request.
downloadRepository.downloadModel(
model, onStatusUpdated = this::setDownloadStatus
)
}
fun cancelDownloadModel(model: Model) {
downloadRepository.cancelDownloadModel(model)
}
fun deleteModel(model: Model) {
deleteFileFromExternalFilesDir(model.downloadFileName)
for (file in model.extraDataFiles) {
deleteFileFromExternalFilesDir(file.downloadFileName)
}
if (model.isZip && model.unzipDir.isNotEmpty()) {
deleteDirFromExternalFilesDir(model.unzipDir)
}
// Update model download status to NotDownloaded.
val curModelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap()
curModelDownloadStatus[model.name] =
ModelDownloadStatus(status = ModelDownloadStatusType.NOT_DOWNLOADED)
val newUiState = uiState.value.copy(modelDownloadStatus = curModelDownloadStatus)
_uiState.update { newUiState }
}
fun initializeModel(context: Context, model: Model, force: Boolean = false) {
viewModelScope.launch(Dispatchers.Default) {
// Skip if initialized already.
if (!force && uiState.value.modelInitializationStatus[model.name] == ModelInitializationStatus.INITIALIZED) {
Log.d(TAG, "Model '${model.name}' has been initialized. Skipping.")
return@launch
}
// Skip if initialization is in progress.
if (model.initializing) {
Log.d(TAG, "Model '${model.name}' is being initialized. Skipping.")
return@launch
}
// Clean up.
cleanupModel(model = model)
// Start initialization.
Log.d(TAG, "Initializing model '${model.name}'...")
model.initializing = true
// Show initializing status after a delay. When the delay expires, check if the model has
// been initialized or not. If so, skip.
launch {
delay(500)
if (model.instance == null) {
updateModelInitializationStatus(
model = model, status = ModelInitializationStatus.INITIALIZING
)
}
}
val onDone: () -> Unit = {
if (model.instance != null) {
Log.d(TAG, "Model '${model.name}' initialized successfully")
model.initializing = false
updateModelInitializationStatus(
model = model,
status = ModelInitializationStatus.INITIALIZED,
)
}
}
when (model.taskType) {
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.initialize(
context = context,
model = model,
onDone = onDone,
)
TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.initialize(
context = context,
model = model,
onDone = onDone,
)
TaskType.LLM_CHAT -> LlmChatModelHelper.initialize(
context = context,
model = model,
onDone = onDone,
)
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.initialize(
context = context, model = model, onDone = onDone
)
else -> {}
}
}
}
fun cleanupModel(model: Model) {
if (model.instance != null) {
Log.d(TAG, "Cleaning up model '${model.name}'...")
when (model.taskType) {
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model)
TaskType.IMAGE_CLASSIFICATION -> ImageClassificationModelHelper.cleanUp(model = model)
TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model)
TaskType.IMAGE_GENERATION -> ImageGenerationModelHelper.cleanUp(model = model)
else -> {}
}
model.instance = null
model.initializing = false
updateModelInitializationStatus(
model = model, status = ModelInitializationStatus.NOT_INITIALIZED
)
}
}
fun setDownloadStatus(curModel: Model, status: ModelDownloadStatus) {
// Update model download progress.
val curModelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap()
curModelDownloadStatus[curModel.name] = status
val newUiState = uiState.value.copy(modelDownloadStatus = curModelDownloadStatus)
// Delete downloaded file if status is failed or not_downloaded.
if (status.status == ModelDownloadStatusType.FAILED || status.status == ModelDownloadStatusType.NOT_DOWNLOADED) {
deleteFileFromExternalFilesDir(curModel.downloadFileName)
}
_uiState.update { newUiState }
}
fun addTextInputHistory(text: String) {
if (uiState.value.textInputHistory.indexOf(text) < 0) {
val newHistory = uiState.value.textInputHistory.toMutableList()
newHistory.add(0, text)
if (newHistory.size > TEXT_INPUT_HISTORY_MAX_SIZE) {
newHistory.removeAt(newHistory.size - 1)
}
_uiState.update { _uiState.value.copy(textInputHistory = newHistory) }
dataStoreRepository.saveTextInputHistory(_uiState.value.textInputHistory)
}
}
fun promoteTextInputHistoryItem(text: String) {
val index = uiState.value.textInputHistory.indexOf(text)
if (index >= 0) {
val newHistory = uiState.value.textInputHistory.toMutableList()
newHistory.removeAt(index)
newHistory.add(0, text)
_uiState.update { _uiState.value.copy(textInputHistory = newHistory) }
dataStoreRepository.saveTextInputHistory(_uiState.value.textInputHistory)
}
}
fun deleteTextInputHistory(text: String) {
val index = uiState.value.textInputHistory.indexOf(text)
if (index >= 0) {
val newHistory = uiState.value.textInputHistory.toMutableList()
newHistory.removeAt(index)
_uiState.update { _uiState.value.copy(textInputHistory = newHistory) }
dataStoreRepository.saveTextInputHistory(_uiState.value.textInputHistory)
}
}
fun clearTextInputHistory() {
_uiState.update { _uiState.value.copy(textInputHistory = mutableListOf()) }
dataStoreRepository.saveTextInputHistory(_uiState.value.textInputHistory)
}
fun readThemeOverride(): String {
return dataStoreRepository.readThemeOverride()
}
fun saveThemeOverride(theme: String) {
dataStoreRepository.saveThemeOverride(theme = theme)
}
fun getModelUrlResponse(model: Model, accessToken: String? = null): Int {
val url = URL(model.url)
val connection = url.openConnection() as HttpURLConnection
if (accessToken != null) {
connection.setRequestProperty(
"Authorization",
"Bearer $accessToken"
)
}
connection.connect()
// Report the result.
return connection.responseCode
}
fun getTokenStatusAndData(): TokenStatusAndData {
// Try to load token data from DataStore.
var tokenStatus = TokenStatus.NOT_STORED
Log.d(TAG, "Reading token data from data store...")
val tokenData = dataStoreRepository.readAccessTokenData()
// Token exists.
if (tokenData != null) {
Log.d(TAG, "Token exists and loaded.")
// Check expiration (with 5-minute buffer).
val curTs = System.currentTimeMillis()
val expirationTs = tokenData.expiresAtSeconds - 5 * 60
Log.d(
TAG,
"Checking whether token has expired or not. Current ts: $curTs, expires at: $expirationTs"
)
if (curTs >= expirationTs) {
Log.d(TAG, "Token expired!")
tokenStatus = TokenStatus.EXPIRED
} else {
Log.d(TAG, "Token not expired.")
tokenStatus = TokenStatus.NOT_EXPIRED
curAccessToken = tokenData.accessToken
}
} else {
Log.d(TAG, "Token doesn't exists.")
}
return TokenStatusAndData(status = tokenStatus, data = tokenData)
}
fun getAuthorizationRequest(): AuthorizationRequest {
return AuthorizationRequest.Builder(
AuthConfig.authServiceConfig,
AuthConfig.clientId,
ResponseTypeValues.CODE,
Uri.parse(AuthConfig.redirectUri)
).setScope("read-repos").build()
}
fun handleAuthResult(result: ActivityResult, onTokenRequested: (TokenRequestResult) -> Unit) {
val dataIntent = result.data
if (dataIntent == null) {
onTokenRequested(
TokenRequestResult(
status = TokenRequestResultType.FAILED,
errorMessage = "Empty auth result"
)
)
return
}
val response = AuthorizationResponse.fromIntent(dataIntent)
val exception = AuthorizationException.fromIntent(dataIntent)
when {
response?.authorizationCode != null -> {
// Authorization successful, exchange the code for tokens
var errorMessage: String? = null
authService.performTokenRequest(
response.createTokenExchangeRequest()
) { tokenResponse, tokenEx ->
if (tokenResponse != null) {
if (tokenResponse.accessToken == null) {
errorMessage = "Empty access token"
} else if (tokenResponse.refreshToken == null) {
errorMessage = "Empty refresh token"
} else if (tokenResponse.accessTokenExpirationTime == null) {
errorMessage = "Empty expiration time"
} else {
// Token exchange successful. Store the tokens securely
Log.d(TAG, "Token exchange successful. Storing tokens...")
dataStoreRepository.saveAccessTokenData(
accessToken = tokenResponse.accessToken!!,
refreshToken = tokenResponse.refreshToken!!,
expiresAt = tokenResponse.accessTokenExpirationTime!!
)
curAccessToken = tokenResponse.accessToken!!
Log.d(TAG, "Token successfully saved.")
}
} else if (tokenEx != null) {
errorMessage = "Token exchange failed: ${tokenEx.message}"
} else {
errorMessage = "Token exchange failed"
}
if (errorMessage == null) {
onTokenRequested(TokenRequestResult(status = TokenRequestResultType.SUCCEEDED))
} else {
onTokenRequested(
TokenRequestResult(
status = TokenRequestResultType.FAILED,
errorMessage = errorMessage
)
)
}
}
}
exception != null -> {
onTokenRequested(
TokenRequestResult(
status = if (exception.message == "User cancelled flow") TokenRequestResultType.USER_CANCELLED else TokenRequestResultType.FAILED,
errorMessage = "${exception.message}"
)
)
}
else -> {
onTokenRequested(
TokenRequestResult(
status = TokenRequestResultType.USER_CANCELLED,
)
)
}
}
}
private fun isModelPartiallyDownloaded(model: Model): Boolean {
return inProgressWorkInfos.find { it.modelName == model.name } != null
}
private fun createUiState(): ModelManagerUiState {
val modelsByTaskName: Map<String, MutableList<Model>> =
TASKS.associate { task -> task.type.label to task.models }
val modelDownloadStatus: MutableMap<String, ModelDownloadStatus> = mutableMapOf()
val modelInstances: MutableMap<String, ModelInitializationStatus> = mutableMapOf()
for ((_, models) in modelsByTaskName.entries) {
for (model in models) {
modelDownloadStatus[model.name] = getModelDownloadStatus(model = model)
modelInstances[model.name] = ModelInitializationStatus.NOT_INITIALIZED
}
}
val textInputHistory = dataStoreRepository.readTextInputHistory()
Log.d(TAG, "text input history: $textInputHistory")
return ModelManagerUiState(
tasks = TASKS,
modelsByTaskName = modelsByTaskName,
modelDownloadStatus = modelDownloadStatus,
modelInitializationStatus = modelInstances,
textInputHistory = textInputHistory,
)
}
/**
* Retrieves the download status of a model.
*
* This function determines the download status of a given model by checking if it's fully
* downloaded, partially downloaded, or not downloaded at all. It also retrieves the received and
* total bytes for partially downloaded models.
*/
private fun getModelDownloadStatus(model: Model): ModelDownloadStatus {
var status = ModelDownloadStatusType.NOT_DOWNLOADED
var receivedBytes = 0L
var totalBytes = 0L
if (isModelDownloaded(model = model)) {
if (isModelPartiallyDownloaded(model = model)) {
status = ModelDownloadStatusType.PARTIALLY_DOWNLOADED
val file = File(externalFilesDir, model.downloadFileName)
receivedBytes = file.length()
totalBytes = model.totalBytes
} else {
status = ModelDownloadStatusType.SUCCEEDED
}
}
return ModelDownloadStatus(
status = status, receivedBytes = receivedBytes, totalBytes = totalBytes
)
}
suspend fun loadHfModels() {
// Update loading state shown in ui.
_uiState.update {
uiState.value.copy(
loadingHfModels = true,
)
}
val modelDownloadStatus = uiState.value.modelDownloadStatus.toMutableMap()
val modelInstances = uiState.value.modelInitializationStatus.toMutableMap()
try {
// Load model summaries.
val modelSummaries =
getJsonResponse<List<HfModelSummary>>(url = "https://huggingface.co/api/models?search=$HG_COMMUNITY")
Log.d(TAG, "HF model summaries: $modelSummaries")
// Load individual models in parallel.
if (modelSummaries != null) {
coroutineScope {
val hfModels = modelSummaries.map { summary ->
async {
val details =
getJsonResponse<HfModelDetails>(url = "https://huggingface.co/api/models/${summary.modelId}")
if (details != null && details.siblings.find { it.rfilename == "app.json" } != null) {
val hfModel =
getJsonResponse<HfModel>(url = "https://huggingface.co/${summary.modelId}/resolve/main/app.json")
if (hfModel != null) {
hfModel.id = details.id
}
return@async hfModel
}
return@async null
}
}
// Process loaded app.json
for (hfModel in hfModels.awaitAll()) {
if (hfModel != null) {
Log.d(TAG, "HF model: $hfModel")
val task = TASKS.find { it.type.label == hfModel.task }
val model = hfModel.toModel()
if (task != null && task.models.find { it.hfModelId == model.hfModelId } == null) {
model.preProcess(task = task)
Log.d(TAG, "AG model: $model")
task.models.add(model)
// Add initial status and states.
modelDownloadStatus[model.name] = getModelDownloadStatus(model = model)
modelInstances[model.name] = ModelInitializationStatus.NOT_INITIALIZED
}
}
}
}
}
_uiState.update {
uiState.value.copy(
loadingHfModels = false,
modelDownloadStatus = modelDownloadStatus,
modelInitializationStatus = modelInstances
)
}
} catch (e: Exception) {
e.printStackTrace()
}
}
private inline fun <reified T> getJsonResponse(url: String): T? {
try {
val connection = URL(url).openConnection() as HttpURLConnection
connection.requestMethod = "GET"
connection.connect()
val responseCode = connection.responseCode
if (responseCode == HttpURLConnection.HTTP_OK) {
val inputStream = connection.inputStream
val response = inputStream.bufferedReader().use { it.readText() }
// Parse JSON using kotlinx.serialization
val json = Json { ignoreUnknownKeys = true } // Handle potential extra fields
val jsonObj = json.decodeFromString<T>(response)
return jsonObj
} else {
println("HTTP error: $responseCode")
}
} catch (e: Exception) {
e.printStackTrace()
}
return null
}
private fun isFileInExternalFilesDir(fileName: String): Boolean {
if (externalFilesDir != null) {
val file = File(externalFilesDir, fileName)
return file.exists()
} else {
return false
}
}
private fun deleteFileFromExternalFilesDir(fileName: String) {
if (isFileInExternalFilesDir(fileName)) {
val file = File(externalFilesDir, fileName)
file.delete()
}
}
private fun deleteDirFromExternalFilesDir(dir: String) {
if (isFileInExternalFilesDir(dir)) {
val file = File(externalFilesDir, dir)
file.deleteRecursively()
}
}
private fun updateModelInitializationStatus(model: Model, status: ModelInitializationStatus) {
val curModelInstance = uiState.value.modelInitializationStatus.toMutableMap()
curModelInstance[model.name] = status
val newUiState = uiState.value.copy(modelInitializationStatus = curModelInstance)
_uiState.update { newUiState }
}
private fun isModelDownloaded(model: Model): Boolean {
val downloadedFileExists =
model.downloadFileName.isNotEmpty() && isFileInExternalFilesDir(model.downloadFileName)
val unzippedDirectoryExists =
model.isZip && model.unzipDir.isNotEmpty() && isFileInExternalFilesDir(model.unzipDir)
// Will also return true if model is partially downloaded.
return downloadedFileExists || unzippedDirectoryExists
}
}

View file

@ -0,0 +1,265 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.navigation
import android.util.Log
import androidx.compose.animation.AnimatedContentTransitionScope
import androidx.compose.animation.AnimatedVisibility
import androidx.compose.animation.EnterTransition
import androidx.compose.animation.ExitTransition
import androidx.compose.animation.core.EaseOutExpo
import androidx.compose.animation.core.FiniteAnimationSpec
import androidx.compose.animation.core.tween
import androidx.compose.animation.slideInHorizontally
import androidx.compose.animation.slideOutHorizontally
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.IntOffset
import androidx.compose.ui.zIndex
import androidx.lifecycle.viewmodel.compose.viewModel
import androidx.navigation.NavBackStackEntry
import androidx.navigation.NavHostController
import androidx.navigation.NavType
import androidx.navigation.compose.NavHost
import androidx.navigation.compose.composable
import androidx.navigation.navArgument
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_IMAGE_CLASSIFICATION
import com.google.aiedge.gallery.data.TASK_IMAGE_GENERATION
import com.google.aiedge.gallery.data.TASK_LLM_CHAT
import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.TaskType
import com.google.aiedge.gallery.data.getModelByName
import com.google.aiedge.gallery.ui.ViewModelProvider
import com.google.aiedge.gallery.ui.home.HomeScreen
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationDestination
import com.google.aiedge.gallery.ui.imageclassification.ImageClassificationScreen
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationDestination
import com.google.aiedge.gallery.ui.imagegeneration.ImageGenerationScreen
import com.google.aiedge.gallery.ui.llmchat.LlmChatDestination
import com.google.aiedge.gallery.ui.llmchat.LlmChatScreen
import com.google.aiedge.gallery.ui.modelmanager.ModelManager
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.textclassification.TextClassificationDestination
import com.google.aiedge.gallery.ui.textclassification.TextClassificationScreen
private const val TAG = "AGGalleryNavGraph"
private const val ROUTE_PLACEHOLDER = "placeholder"
private const val ENTER_ANIMATION_DURATION_MS = 500
private val ENTER_ANIMATION_EASING = EaseOutExpo
private const val ENTER_ANIMATION_DELAY_MS = 100
private const val EXIT_ANIMATION_DURATION_MS = 500
private val EXIT_ANIMATION_EASING = EaseOutExpo
private fun enterTween(): FiniteAnimationSpec<IntOffset> {
return tween(
ENTER_ANIMATION_DURATION_MS,
easing = ENTER_ANIMATION_EASING,
delayMillis = ENTER_ANIMATION_DELAY_MS
)
}
private fun exitTween(): FiniteAnimationSpec<IntOffset> {
return tween(EXIT_ANIMATION_DURATION_MS, easing = EXIT_ANIMATION_EASING)
}
private fun AnimatedContentTransitionScope<*>.slideEnter(): EnterTransition {
return slideIntoContainer(
animationSpec = enterTween(),
towards = AnimatedContentTransitionScope.SlideDirection.Left,
)
}
private fun AnimatedContentTransitionScope<*>.slideExit(): ExitTransition {
return slideOutOfContainer(
animationSpec = exitTween(),
towards = AnimatedContentTransitionScope.SlideDirection.Right,
)
}
/**
* Navigation routes.
*/
@Composable
fun GalleryNavHost(
navController: NavHostController,
modifier: Modifier = Modifier,
modelManagerViewModel: ModelManagerViewModel = viewModel(factory = ViewModelProvider.Factory)
) {
var showModelManager by remember { mutableStateOf(false) }
var pickedTask by remember { mutableStateOf<Task?>(null) }
HomeScreen(
modelManagerViewModel = modelManagerViewModel,
navigateToTaskScreen = { task ->
pickedTask = task
showModelManager = true
},
)
// Model manager.
AnimatedVisibility(
visible = showModelManager,
enter = slideInHorizontally(initialOffsetX = { it }),
exit = slideOutHorizontally(targetOffsetX = { it }),
) {
val curPickedTask = pickedTask
if (curPickedTask != null) {
ModelManager(
viewModel = modelManagerViewModel,
task = curPickedTask,
onModelClicked = { model ->
navigateToTaskScreen(
navController = navController, taskType = model.taskType!!, model = model
)
},
navigateUp = { showModelManager = false })
}
}
NavHost(
navController = navController,
// Default to open home screen.
startDestination = ROUTE_PLACEHOLDER,
enterTransition = { EnterTransition.None },
exitTransition = { ExitTransition.None },
modifier = modifier.zIndex(1f)
) {
// Placeholder root screen
composable(
route = ROUTE_PLACEHOLDER,
) {
Text("")
}
// Text classification.
composable(
route = "${TextClassificationDestination.route}/{modelName}",
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() },
exitTransition = { slideExit() },
) {
getModelFromNavigationParam(it, TASK_TEXT_CLASSIFICATION)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel)
TextClassificationScreen(
modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() },
)
}
}
// Image classification.
composable(
route = "${ImageClassificationDestination.route}/{modelName}",
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() },
exitTransition = { slideExit() },
) {
getModelFromNavigationParam(it, TASK_IMAGE_CLASSIFICATION)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel)
ImageClassificationScreen(
modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() },
)
}
}
// Image generation.
composable(
route = "${ImageGenerationDestination.route}/{modelName}",
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() },
exitTransition = { slideExit() },
) {
getModelFromNavigationParam(it, TASK_IMAGE_GENERATION)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel)
ImageGenerationScreen(
modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() },
)
}
}
// LLMm chat demos.
composable(
route = "${LlmChatDestination.route}/{modelName}",
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() },
exitTransition = { slideExit() },
) {
getModelFromNavigationParam(it, TASK_LLM_CHAT)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel)
LlmChatScreen(
modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() },
)
}
}
}
// Handle incoming intents for deep links
val intent = androidx.activity.compose.LocalActivity.current?.intent
val data = intent?.data
if (data != null) {
intent.data = null
Log.d(TAG, "navigation link clicked: $data")
if (data.toString().startsWith("com.google.aiedge.gallery://model/")) {
val modelName = data.pathSegments.last()
getModelByName(modelName)?.let { model ->
navigateToTaskScreen(
navController = navController,
taskType = model.taskType!!,
model = model
)
}
}
}
}
fun navigateToTaskScreen(
navController: NavHostController, taskType: TaskType, model: Model? = null
) {
val modelName = model?.name ?: ""
when (taskType) {
TaskType.TEXT_CLASSIFICATION -> navController.navigate("${TextClassificationDestination.route}/${modelName}")
TaskType.IMAGE_CLASSIFICATION -> navController.navigate("${ImageClassificationDestination.route}/${modelName}")
TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
TaskType.IMAGE_GENERATION -> navController.navigate("${ImageGenerationDestination.route}/${modelName}")
TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {}
}
}
fun getModelFromNavigationParam(entry: NavBackStackEntry, task: Task): Model? {
var modelName = entry.arguments?.getString("modelName") ?: ""
if (modelName.isEmpty()) {
modelName = task.models[0].name
}
val model = getModelByName(modelName)
return model
}

View file

@ -0,0 +1,90 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.preview
import android.content.Context
import android.graphics.Bitmap
import android.graphics.Canvas
import android.graphics.drawable.Drawable
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.asImageBitmap
import androidx.core.content.ContextCompat
import com.google.aiedge.gallery.R
import com.google.aiedge.gallery.ui.common.chat.ChatMessageClassification
import com.google.aiedge.gallery.ui.common.chat.ChatMessageImage
import com.google.aiedge.gallery.ui.common.chat.ChatMessageText
import com.google.aiedge.gallery.ui.common.chat.ChatSide
import com.google.aiedge.gallery.ui.common.chat.ChatViewModel
import com.google.aiedge.gallery.ui.common.chat.Classification
class PreviewChatModel(context: Context) : ChatViewModel(task = TASK_TEST1) {
init {
val model = task.models[1]
addMessage(
model = model,
message = ChatMessageText(
content = "Thanks everyone for your enthusiasm on the team lunch, but people who can sign on the cheque is OOO next week \uD83D\uDE02,",
side = ChatSide.USER
),
)
addMessage(
model = model,
message = ChatMessageText(
content = "Today is Wednesday!", side = ChatSide.AGENT, latencyMs = 1232f
),
)
addMessage(
model = model,
message = ChatMessageClassification(
classifications = listOf(
Classification(label = "label1", score = 0.3f, color = Color.Red),
Classification(label = "label2", score = 0.7f, color = Color.Blue)
),
latencyMs = 12345f,
),
)
val bitmap =
getBitmapFromVectorDrawable(
context = context,
drawableId = R.drawable.ic_launcher_background
)!!
addMessage(
model = model,
message = ChatMessageImage(
bitmap = bitmap,
imageBitMap = bitmap.asImageBitmap(),
side = ChatSide.USER,
),
)
}
private fun getBitmapFromVectorDrawable(context: Context, drawableId: Int): Bitmap? {
val drawable: Drawable = ContextCompat.getDrawable(context, drawableId)
?: return null // Drawable not found
val bitmap = Bitmap.createBitmap(
drawable.intrinsicWidth,
drawable.intrinsicHeight,
Bitmap.Config.ARGB_8888
)
val canvas = Canvas(bitmap)
drawable.setBounds(0, 0, canvas.width, canvas.height)
drawable.draw(canvas)
return bitmap
}
}

View file

@ -0,0 +1,43 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.preview
import com.google.aiedge.gallery.data.AccessTokenData
import com.google.aiedge.gallery.data.DataStoreRepository
class PreviewDataStoreRepository : DataStoreRepository {
override fun saveTextInputHistory(history: List<String>) {
}
override fun readTextInputHistory(): List<String> {
return listOf()
}
override fun saveThemeOverride(theme: String) {
}
override fun readThemeOverride(): String {
return ""
}
override fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long) {
}
override fun readAccessTokenData(): AccessTokenData? {
return null
}
}

View file

@ -0,0 +1,47 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.preview
import com.google.aiedge.gallery.data.AGWorkInfo
import com.google.aiedge.gallery.data.DownloadRepository
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatus
import java.util.UUID
class PreviewDownloadRepository : DownloadRepository {
override fun downloadModel(
model: Model, onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit
) {
}
override fun cancelDownloadModel(model: Model) {
}
override fun cancelAll(models: List<Model>, onComplete: () -> Unit) {
}
override fun observerWorkerProgress(
workerId: UUID,
model: Model,
onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit
) {
}
override fun getEnqueuedOrRunningWorkInfos(): List<AGWorkInfo> {
return listOf()
}
}

View file

@ -0,0 +1,71 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.preview
import android.content.Context
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatus
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerUiState
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import kotlinx.coroutines.flow.update
class PreviewModelManagerViewModel(context: Context) :
ModelManagerViewModel(
downloadRepository = PreviewDownloadRepository(),
dataStoreRepository = PreviewDataStoreRepository(),
context = context
) {
init {
for ((index, task) in ALL_PREVIEW_TASKS.withIndex()) {
task.index = index
for (model in task.models) {
model.preProcess(task = task)
}
}
val modelsByTaskName: Map<String, MutableList<Model>> =
ALL_PREVIEW_TASKS.associate { task -> task.type.label to task.models }
val modelDownloadStatus = mapOf(
MODEL_TEST1.name to ModelDownloadStatus(
status = ModelDownloadStatusType.IN_PROGRESS,
receivedBytes = 1234,
totalBytes = 3456,
bytesPerSecond = 2333,
remainingMs = 324,
),
MODEL_TEST2.name to ModelDownloadStatus(
status = ModelDownloadStatusType.SUCCEEDED
),
MODEL_TEST3.name to ModelDownloadStatus(
status = ModelDownloadStatusType.FAILED, errorMessage = "Http code 404"
),
MODEL_TEST4.name to ModelDownloadStatus(
status = ModelDownloadStatusType.NOT_DOWNLOADED
),
)
val newUiState = ModelManagerUiState(
tasks = ALL_PREVIEW_TASKS,
modelsByTaskName = modelsByTaskName,
modelDownloadStatus = modelDownloadStatus,
modelInitializationStatus = mapOf(),
selectedModel = MODEL_TEST2,
)
_uiState.update { newUiState }
}
}

View file

@ -0,0 +1,96 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.preview
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.AccountBox
import androidx.compose.material.icons.rounded.AutoAwesome
import com.google.aiedge.gallery.data.BooleanSwitchConfig
import com.google.aiedge.gallery.data.Config
import com.google.aiedge.gallery.data.ConfigKey
import com.google.aiedge.gallery.data.SegmentedButtonConfig
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.NumberSliderConfig
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.data.TaskType
import com.google.aiedge.gallery.data.ValueType
val TEST_CONFIGS1: List<Config> = listOf(
NumberSliderConfig(
key = ConfigKey.MAX_RESULT_COUNT,
sliderMin = 1f,
sliderMax = 5f,
defaultValue = 3f,
valueType = ValueType.INT
), BooleanSwitchConfig(
key = ConfigKey.USE_GPU,
defaultValue = false,
), SegmentedButtonConfig(
key = ConfigKey.THEME,
defaultValue = "Auto",
options = listOf("Auto", "Light", "Dark")
)
)
val MODEL_TEST1: Model = Model(
name = "deterministic3",
downloadFileName = "deterministric3.json",
url = "https://storage.googleapis.com/tfweb/model-graph-vis-v2-test-models/deterministic3.json",
sizeInBytes = 40146048L,
configs = TEST_CONFIGS1,
)
val MODEL_TEST2: Model = Model(
name = "isnet",
downloadFileName = "isnet.tflite",
url = "https://storage.googleapis.com/tfweb/model-graph-vis-v2-test-models/isnet-general-use-int8.tflite",
sizeInBytes = 44366296L,
configs = TEST_CONFIGS1,
)
val MODEL_TEST3: Model = Model(
name = "yolo",
downloadFileName = "yolo.json",
url = "https://storage.googleapis.com/tfweb/model-graph-vis-v2-test-models/yolo.json",
sizeInBytes = 40641364L
)
val MODEL_TEST4: Model = Model(
name = "mobilenet v3",
downloadFileName = "mobilenet_v3_large.pt2",
url = "https://storage.googleapis.com/tfweb/model-graph-vis-v2-test-models/mobilenet_v3_large.pt2",
sizeInBytes = 277135998L
)
val TASK_TEST1 = Task(
type = TaskType.TEST_TASK_1,
icon = Icons.Rounded.AutoAwesome,
models = mutableListOf(MODEL_TEST1, MODEL_TEST2),
description = "This is a test task (1)"
)
val TASK_TEST2 = Task(
type = TaskType.TEST_TASK_2,
icon = Icons.Rounded.AccountBox,
models = mutableListOf(MODEL_TEST3, MODEL_TEST4),
description = "This is a test task (2)"
)
val ALL_PREVIEW_TASKS: List<Task> = listOf(
TASK_TEST1,
TASK_TEST2,
)

View file

@ -0,0 +1,95 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.textclassification
import android.content.Context
import android.util.Log
import com.google.mediapipe.tasks.components.containers.Category
import com.google.mediapipe.tasks.core.BaseOptions
import com.google.mediapipe.tasks.text.textclassifier.TextClassifier
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.ui.common.LatencyProvider
import java.io.File
import java.io.FileInputStream
import java.nio.ByteBuffer
import java.nio.channels.FileChannel
private const val TAG = "AGTextClassificationModelHelper"
class TextClassificationInferenceResult(
val categories: List<Category>, override val latencyMs: Float
) : LatencyProvider
// TODO: handle error.
/**
* Helper object for managing text classification models.
*/
object TextClassificationModelHelper {
fun initialize(context: Context, model: Model, onDone: () -> Unit) {
val modelByteBuffer = readFileToByteBuffer(File(model.getPath(context = context)))
if (modelByteBuffer != null) {
val options = TextClassifier.TextClassifierOptions.builder().setBaseOptions(
BaseOptions.builder().setModelAssetBuffer(modelByteBuffer).build()
).build()
model.instance = TextClassifier.createFromOptions(context, options)
onDone()
}
}
fun runInference(model: Model, input: String): TextClassificationInferenceResult {
val instance = model.instance
val start = System.currentTimeMillis()
val classifier: TextClassifier = instance as TextClassifier
val result = classifier.classify(input)
val categories = result.classificationResult().classifications().first().categories()
val latencyMs = (System.currentTimeMillis() - start).toFloat()
return TextClassificationInferenceResult(categories = categories, latencyMs = latencyMs)
}
fun cleanUp(model: Model) {
if (model.instance == null) {
return
}
val instance = model.instance as TextClassifier
try {
instance.close()
} catch (e: Exception) {
// ignore
}
model.instance = null
Log.d(TAG, "Clean up done.")
}
private fun readFileToByteBuffer(file: File): ByteBuffer? {
return try {
val fileInputStream = FileInputStream(file)
val fileChannel: FileChannel = fileInputStream.channel
val byteBuffer = ByteBuffer.allocateDirect(fileChannel.size().toInt())
fileChannel.read(byteBuffer)
byteBuffer.rewind()
fileInputStream.close()
byteBuffer
} catch (e: Exception) {
e.printStackTrace()
null
}
}
}

View file

@ -0,0 +1,74 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.textclassification
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.aiedge.gallery.ui.ViewModelProvider
import com.google.aiedge.gallery.ui.common.chat.ChatMessageText
import com.google.aiedge.gallery.ui.common.chat.ChatView
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import kotlinx.serialization.Serializable
/** Navigation destination data */
object TextClassificationDestination {
@Serializable
val route = "TextClassificationRoute"
}
@Composable
fun TextClassificationScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
viewModel: TextClassificationViewModel = viewModel(
factory = ViewModelProvider.Factory
),
) {
ChatView(
task = viewModel.task,
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
onSendMessage = { model, message ->
viewModel.addMessage(
model = model,
message = message,
)
if (message is ChatMessageText) {
modelManagerViewModel.addTextInputHistory(message.content)
viewModel.generateResponse(
model = model,
input = message.content,
)
}
},
onRunAgainClicked = { model, message ->
viewModel.runAgain(model = model, message = message)
},
onBenchmarkClicked = { model, message, warmUpIterations, benchmarkIterations ->
viewModel.benchmark(
model = model,
message = message,
warmupCount = warmUpIterations,
itertations = benchmarkIterations,
)
},
navigateUp = navigateUp,
modifier = modifier
)
}

View file

@ -0,0 +1,128 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.textclassification
import android.util.Log
import androidx.compose.ui.graphics.Color
import androidx.lifecycle.viewModelScope
import com.google.mediapipe.tasks.components.containers.Category
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_TEXT_CLASSIFICATION
import com.google.aiedge.gallery.ui.common.chat.ChatMessage
import com.google.aiedge.gallery.ui.common.chat.ChatMessageClassification
import com.google.aiedge.gallery.ui.common.chat.ChatMessageText
import com.google.aiedge.gallery.ui.common.chat.ChatMessageType
import com.google.aiedge.gallery.ui.common.chat.ChatViewModel
import com.google.aiedge.gallery.ui.common.chat.Classification
import com.google.aiedge.gallery.ui.common.getDistinctiveColor
import com.google.aiedge.gallery.ui.common.runBasicBenchmark
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
private const val TAG = "AGTextClassificationViewModel"
class TextClassificationViewModel : ChatViewModel(task = TASK_TEXT_CLASSIFICATION) {
fun generateResponse(model: Model, input: String) {
viewModelScope.launch(Dispatchers.Default) {
// Wait for model to be initialized.
while (model.instance == null) {
delay(100)
}
val result = TextClassificationModelHelper.runInference(model = model, input = input)
Log.d(TAG, "$result")
addMessage(
model = model,
message = generateClassificationMessage(result = result),
)
}
}
fun runAgain(model: Model, message: ChatMessage) {
viewModelScope.launch(Dispatchers.Default) {
// Wait for model to be initialized.
while (model.instance == null) {
delay(100)
}
if (message is ChatMessageText) {
// Clone the clicked message and add it.
addMessage(model = model, message = message.clone())
// Run inference.
val result =
TextClassificationModelHelper.runInference(model = model, input = message.content)
// Add response message.
val newMessage = generateClassificationMessage(result = result)
addMessage(
model = model,
message = newMessage,
)
}
}
}
fun benchmark(
model: Model, message: ChatMessage, warmupCount: Int, itertations: Int
) {
viewModelScope.launch(Dispatchers.Default) {
// Wait for model to be initialized.
while (model.instance == null) {
delay(100)
}
if (message is ChatMessageText) {
setInProgress(true)
runBasicBenchmark(
model = model,
warmupCount = warmupCount,
iterations = itertations,
chatViewModel = this@TextClassificationViewModel,
inferenceFn = {
TextClassificationModelHelper.runInference(model = model, input = message.content)
},
chatMessageType = ChatMessageType.BENCHMARK_RESULT,
)
setInProgress(false)
}
}
}
private fun generateClassificationMessage(result: TextClassificationInferenceResult): ChatMessageClassification {
return ChatMessageClassification(classifications = result.categories.mapIndexed { index, category ->
val color = when (category.categoryName().lowercase()) {
"negative", "0" -> Color(0xffe6194B)
"positive", "1" -> Color(0xff3cb44b)
else -> getDistinctiveColor(index)
}
category.toClassification(color = color)
}.sortedBy { it.label }, latencyMs = result.latencyMs)
}
}
fun Category.toClassification(color: Color): Classification {
var categoryName = this.categoryName()
if (categoryName == "0") {
categoryName = "negative"
} else if (categoryName == "1") {
categoryName = "positive"
}
return Classification(label = categoryName, score = this.score(), color = color)
}

View file

@ -0,0 +1,92 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.theme
import androidx.compose.ui.graphics.Color
//val primaryLight = Color(0xFF32628D)
val primaryLight = Color(0xFF1F1F1F)
val onPrimaryLight = Color(0xFFFFFFFF)
val primaryContainerLight = Color(0xFFD0E4FF)
val onPrimaryContainerLight = Color(0xFF144A74)
val secondaryLight = Color(0xFF526070)
val onSecondaryLight = Color(0xFFFFFFFF)
val secondaryContainerLight = Color(0xFFD6E4F7)
val onSecondaryContainerLight = Color(0xFF3B4857)
val tertiaryLight = Color(0xFF775A0B)
val onTertiaryLight = Color(0xFFFFFFFF)
val tertiaryContainerLight = Color(0xFFFFDF9B)
val onTertiaryContainerLight = Color(0xFF5B4300)
val errorLight = Color(0xFF904A43)
val onErrorLight = Color(0xFFFFFFFF)
val errorContainerLight = Color(0xFFFFDAD5)
val onErrorContainerLight = Color(0xFF73342D)
val backgroundLight = Color(0xFFF8F9FF)
val onBackgroundLight = Color(0xFF191C20)
val surfaceLight = Color(0xFFF8F9FF)
val onSurfaceLight = Color(0xFF191C20)
val surfaceVariantLight = Color(0xFFDEE3EB)
val onSurfaceVariantLight = Color(0xFF42474E)
val outlineLight = Color(0xFF73777F)
val outlineVariantLight = Color(0xFFC2C7CF)
val scrimLight = Color(0xFF000000)
val inverseSurfaceLight = Color(0xFF2D3135)
val inverseOnSurfaceLight = Color(0xFFEFF1F6)
val inversePrimaryLight = Color(0xFF9DCAFC)
val surfaceDimLight = Color(0xFFD8DAE0)
val surfaceBrightLight = Color(0xFFF8F9FF)
val surfaceContainerLowestLight = Color(0xFFFFFFFF)
val surfaceContainerLowLight = Color(0xFFF2F3F9)
val surfaceContainerLight = Color(0xFFECEEF4)
val surfaceContainerHighLight = Color(0xFFE6E8EE)
val surfaceContainerHighestLight = Color(0xFFE0E2E8)
val primaryDark = Color(0xFF9DCAFC)
val onPrimaryDark = Color(0xFF003355)
val primaryContainerDark = Color(0xFF144A74)
val onPrimaryContainerDark = Color(0xFFD0E4FF)
val secondaryDark = Color(0xFFBAC8DA)
val onSecondaryDark = Color(0xFF243240)
val secondaryContainerDark = Color(0xFF3B4857)
val onSecondaryContainerDark = Color(0xFFD6E4F7)
val tertiaryDark = Color(0xFFE8C26C)
val onTertiaryDark = Color(0xFF3F2E00)
val tertiaryContainerDark = Color(0xFF5B4300)
val onTertiaryContainerDark = Color(0xFFFFDF9B)
val errorDark = Color(0xFFFFB4AB)
val onErrorDark = Color(0xFF561E19)
val errorContainerDark = Color(0xFF73342D)
val onErrorContainerDark = Color(0xFFFFDAD5)
val backgroundDark = Color(0xFF101418)
val onBackgroundDark = Color(0xFFE0E2E8)
val surfaceDark = Color(0xFF101418)
val onSurfaceDark = Color(0xFFE0E2E8)
val surfaceVariantDark = Color(0xFF42474E)
val onSurfaceVariantDark = Color(0xFFC2C7CF)
val outlineDark = Color(0xFF8C9199)
val outlineVariantDark = Color(0xFF42474E)
val scrimDark = Color(0xFF000000)
val inverseSurfaceDark = Color(0xFFE0E2E8)
val inverseOnSurfaceDark = Color(0xFF2D3135)
val inversePrimaryDark = Color(0xFF32628D)
val surfaceDimDark = Color(0xFF101418)
val surfaceBrightDark = Color(0xFF36393E)
val surfaceContainerLowestDark = Color(0xFF0B0E12)
val surfaceContainerLowDark = Color(0xFF191C20)
val surfaceContainerDark = Color(0xFF1D2024)
val surfaceContainerHighDark = Color(0xFF272A2F)
val surfaceContainerHighestDark = Color(0xFF32353A)

View file

@ -0,0 +1,223 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.theme
import android.app.Activity
import androidx.compose.foundation.isSystemInDarkTheme
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.darkColorScheme
import androidx.compose.material3.lightColorScheme
import androidx.compose.runtime.Composable
import androidx.compose.runtime.CompositionLocalProvider
import androidx.compose.runtime.Immutable
import androidx.compose.runtime.ReadOnlyComposable
import androidx.compose.runtime.SideEffect
import androidx.compose.runtime.staticCompositionLocalOf
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.platform.LocalView
import androidx.core.view.WindowCompat
private val lightScheme = lightColorScheme(
primary = primaryLight,
onPrimary = onPrimaryLight,
primaryContainer = primaryContainerLight,
onPrimaryContainer = onPrimaryContainerLight,
secondary = secondaryLight,
onSecondary = onSecondaryLight,
secondaryContainer = secondaryContainerLight,
onSecondaryContainer = onSecondaryContainerLight,
tertiary = tertiaryLight,
onTertiary = onTertiaryLight,
tertiaryContainer = tertiaryContainerLight,
onTertiaryContainer = onTertiaryContainerLight,
error = errorLight,
onError = onErrorLight,
errorContainer = errorContainerLight,
onErrorContainer = onErrorContainerLight,
background = backgroundLight,
onBackground = onBackgroundLight,
surface = surfaceLight,
onSurface = onSurfaceLight,
surfaceVariant = surfaceVariantLight,
onSurfaceVariant = onSurfaceVariantLight,
outline = outlineLight,
outlineVariant = outlineVariantLight,
scrim = scrimLight,
inverseSurface = inverseSurfaceLight,
inverseOnSurface = inverseOnSurfaceLight,
inversePrimary = inversePrimaryLight,
surfaceDim = surfaceDimLight,
surfaceBright = surfaceBrightLight,
surfaceContainerLowest = surfaceContainerLowestLight,
surfaceContainerLow = surfaceContainerLowLight,
surfaceContainer = surfaceContainerLight,
surfaceContainerHigh = surfaceContainerHighLight,
surfaceContainerHighest = surfaceContainerHighestLight,
)
private val darkScheme = darkColorScheme(
primary = primaryDark,
onPrimary = onPrimaryDark,
primaryContainer = primaryContainerDark,
onPrimaryContainer = onPrimaryContainerDark,
secondary = secondaryDark,
onSecondary = onSecondaryDark,
secondaryContainer = secondaryContainerDark,
onSecondaryContainer = onSecondaryContainerDark,
tertiary = tertiaryDark,
onTertiary = onTertiaryDark,
tertiaryContainer = tertiaryContainerDark,
onTertiaryContainer = onTertiaryContainerDark,
error = errorDark,
onError = onErrorDark,
errorContainer = errorContainerDark,
onErrorContainer = onErrorContainerDark,
background = backgroundDark,
onBackground = onBackgroundDark,
surface = surfaceDark,
onSurface = onSurfaceDark,
surfaceVariant = surfaceVariantDark,
onSurfaceVariant = onSurfaceVariantDark,
outline = outlineDark,
outlineVariant = outlineVariantDark,
scrim = scrimDark,
inverseSurface = inverseSurfaceDark,
inverseOnSurface = inverseOnSurfaceDark,
inversePrimary = inversePrimaryDark,
surfaceDim = surfaceDimDark,
surfaceBright = surfaceBrightDark,
surfaceContainerLowest = surfaceContainerLowestDark,
surfaceContainerLow = surfaceContainerLowDark,
surfaceContainer = surfaceContainerDark,
surfaceContainerHigh = surfaceContainerHighDark,
surfaceContainerHighest = surfaceContainerHighestDark,
)
@Immutable
data class CustomColors(
val taskBgColors: List<Color> = listOf(),
val taskIconColors: List<Color> = listOf(),
val taskIconShapeBgColor: Color = Color.Transparent,
val homeBottomGradient: List<Color> = listOf(),
val userBubbleBgColor: Color = Color.Transparent,
val agentBubbleBgColor: Color = Color.Transparent,
)
val LocalCustomColors = staticCompositionLocalOf { CustomColors() }
val lightCustomColors = CustomColors(
taskBgColors = listOf(
// yellow
Color(0xFFFFEFC9),
// red
Color(0xFFFFEDE6),
// green
Color(0xFFE1F6DE),
// blue
Color(0xFFEDF0FF)
),
taskIconColors = listOf(
Color(0xFFE37400),
Color(0xFFD93025),
Color(0xFF34A853),
Color(0xFF1967D2),
),
taskIconShapeBgColor = Color.White,
homeBottomGradient = listOf(
Color(0x00F8F9FF),
Color(0xffFFEFC9)
),
agentBubbleBgColor = Color(0xFFe9eef6),
userBubbleBgColor = Color(0xFF32628D),
)
val darkCustomColors = CustomColors(
taskBgColors = listOf(
// yellow
Color(0xFF33302A),
// red
Color(0xFF362F2D),
// green
Color(0xFF2E312D),
// blue
Color(0xFF303033)
),
taskIconColors = listOf(
Color(0xFFFFB955),
Color(0xFFFFB4AB),
Color(0xFF6DD58C),
Color(0xFFAAC7FF),
),
taskIconShapeBgColor = Color(0xFF202124),
homeBottomGradient = listOf(
Color(0x00F8F9FF),
Color(0x1AF6AD01)
),
agentBubbleBgColor = Color(0xFF1b1c1d),
userBubbleBgColor = Color(0xFF1f3760),
)
val MaterialTheme.customColors: CustomColors
@Composable
@ReadOnlyComposable
get() = LocalCustomColors.current
/**
* Controls the color of the phone's status bar icons based on whether the app is using a dark
* theme.
*/
@Composable
fun StatusBarColorController(useDarkTheme: Boolean) {
val view = LocalView.current
val currentWindow = (view.context as? Activity)?.window
if (currentWindow != null) {
SideEffect {
WindowCompat.setDecorFitsSystemWindows(currentWindow, false)
val controller = WindowCompat.getInsetsController(currentWindow, view)
controller.isAppearanceLightStatusBars = !useDarkTheme // Set to true for light icons
}
}
}
@Composable
fun GalleryTheme(
content: @Composable () -> Unit
) {
val themeOverride = ThemeSettings.themeOverride
val darkTheme: Boolean = isSystemInDarkTheme() || themeOverride.value == THEME_DARK
StatusBarColorController(useDarkTheme = darkTheme)
val colorScheme = when {
darkTheme -> darkScheme
else -> lightScheme
}
val customColorsPalette = if (darkTheme) darkCustomColors else lightCustomColors
CompositionLocalProvider(
LocalCustomColors provides customColorsPalette
) {
MaterialTheme(
colorScheme = colorScheme,
typography = AppTypography,
content = content
)
}
}

View file

@ -0,0 +1,27 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.theme
import androidx.compose.runtime.mutableStateOf
const val THEME_AUTO = "Auto"
const val THEME_LIGHT = "Light"
const val THEME_DARK = "Dark"
object ThemeSettings {
val themeOverride = mutableStateOf("")
}

View file

@ -0,0 +1,91 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.theme
import androidx.compose.material3.Typography
import androidx.compose.ui.text.font.Font
import androidx.compose.ui.text.font.FontFamily
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.sp
import com.google.aiedge.gallery.R
val nunitoFontFamily = FontFamily(
Font(R.font.nunito_regular, FontWeight.Normal),
Font(R.font.nunito_extralight, FontWeight.ExtraLight),
Font(R.font.nunito_light, FontWeight.Light),
Font(R.font.nunito_medium, FontWeight.Medium),
Font(R.font.nunito_semibold, FontWeight.SemiBold),
Font(R.font.nunito_bold, FontWeight.Bold),
Font(R.font.nunito_extrabold, FontWeight.ExtraBold),
Font(R.font.nunito_black, FontWeight.Black),
)
val baseline = Typography()
val AppTypography = Typography(
displayLarge = baseline.displayLarge.copy(fontFamily = nunitoFontFamily),
displayMedium = baseline.displayMedium.copy(fontFamily = nunitoFontFamily),
displaySmall = baseline.displaySmall.copy(fontFamily = nunitoFontFamily),
headlineLarge = baseline.headlineLarge.copy(fontFamily = nunitoFontFamily),
headlineMedium = baseline.headlineMedium.copy(fontFamily = nunitoFontFamily),
headlineSmall = baseline.headlineSmall.copy(fontFamily = nunitoFontFamily),
titleLarge = baseline.titleLarge.copy(fontFamily = nunitoFontFamily),
titleMedium = baseline.titleMedium.copy(fontFamily = nunitoFontFamily),
titleSmall = baseline.titleSmall.copy(fontFamily = nunitoFontFamily),
bodyLarge = baseline.bodyLarge.copy(fontFamily = nunitoFontFamily),
bodyMedium = baseline.bodyMedium.copy(fontFamily = nunitoFontFamily),
bodySmall = baseline.bodySmall.copy(fontFamily = nunitoFontFamily),
labelLarge = baseline.labelLarge.copy(fontFamily = nunitoFontFamily),
labelMedium = baseline.labelMedium.copy(fontFamily = nunitoFontFamily),
labelSmall = baseline.labelSmall.copy(fontFamily = nunitoFontFamily),
)
val titleMediumNarrow =
baseline.titleMedium.copy(fontFamily = nunitoFontFamily, letterSpacing = 0.0.sp)
val titleSmaller = baseline.titleSmall.copy(
fontFamily = nunitoFontFamily,
fontSize = 12.sp,
fontWeight = FontWeight.Bold
)
val labelSmallNarrow =
baseline.labelSmall.copy(fontFamily = nunitoFontFamily, letterSpacing = 0.0.sp)
val labelSmallNarrowMedium =
baseline.labelSmall.copy(
fontFamily = nunitoFontFamily,
fontWeight = FontWeight.Medium,
letterSpacing = 0.0.sp
)
val bodySmallNarrow =
baseline.bodySmall.copy(fontFamily = nunitoFontFamily, letterSpacing = 0.0.sp)
val bodySmallSemiBold =
baseline.bodySmall.copy(fontFamily = nunitoFontFamily, fontWeight = FontWeight.SemiBold)
val bodySmallMediumNarrow =
baseline.bodySmall.copy(fontFamily = nunitoFontFamily, letterSpacing = 0.0.sp, fontSize = 14.sp)
val bodySmallMediumNarrowBold =
baseline.bodySmall.copy(
fontFamily = nunitoFontFamily,
letterSpacing = 0.0.sp,
fontSize = 14.sp,
fontWeight = FontWeight.Bold
)

View file

@ -0,0 +1,243 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.worker
import android.content.Context
import android.util.Log
import androidx.work.CoroutineWorker
import androidx.work.Data
import androidx.work.WorkerParameters
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_ACCESS_TOKEN
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_ERROR_MESSAGE
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_FILE_NAME
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_RATE
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_RECEIVED_BYTES
import com.google.aiedge.gallery.data.KEY_MODEL_DOWNLOAD_REMAINING_MS
import com.google.aiedge.gallery.data.KEY_MODEL_EXTRA_DATA_DOWNLOAD_FILE_NAMES
import com.google.aiedge.gallery.data.KEY_MODEL_EXTRA_DATA_URLS
import com.google.aiedge.gallery.data.KEY_MODEL_IS_ZIP
import com.google.aiedge.gallery.data.KEY_MODEL_START_UNZIPPING
import com.google.aiedge.gallery.data.KEY_MODEL_TOTAL_BYTES
import com.google.aiedge.gallery.data.KEY_MODEL_UNZIPPED_DIR
import com.google.aiedge.gallery.data.KEY_MODEL_URL
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import java.io.BufferedInputStream
import java.io.File
import java.io.FileInputStream
import java.io.FileOutputStream
import java.io.IOException
import java.net.HttpURLConnection
import java.net.URL
import java.util.zip.ZipEntry
import java.util.zip.ZipInputStream
private const val TAG = "AGDownloadWorker"
data class UrlAndFileName(val url: String, val fileName: String)
class DownloadWorker(context: Context, params: WorkerParameters) :
CoroutineWorker(context, params) {
private val externalFilesDir = context.getExternalFilesDir(null)
override suspend fun doWork(): Result {
val fileUrl = inputData.getString(KEY_MODEL_URL)
val fileName = inputData.getString(KEY_MODEL_DOWNLOAD_FILE_NAME)
val isZip = inputData.getBoolean(KEY_MODEL_IS_ZIP, false)
val unzippedDir = inputData.getString(KEY_MODEL_UNZIPPED_DIR)
val extraDataFileUrls = inputData.getString(KEY_MODEL_EXTRA_DATA_URLS)?.split(",") ?: listOf()
val extraDataFileNames =
inputData.getString(KEY_MODEL_EXTRA_DATA_DOWNLOAD_FILE_NAMES)?.split(",") ?: listOf()
val totalBytes = inputData.getLong(KEY_MODEL_TOTAL_BYTES, 0L)
val accessToken = inputData.getString(KEY_MODEL_DOWNLOAD_ACCESS_TOKEN)
return withContext(Dispatchers.IO) {
if (fileUrl == null || fileName == null) {
Result.failure()
} else {
return@withContext try {
// Collect data for all files.
val allFiles: MutableList<UrlAndFileName> = mutableListOf()
allFiles.add(UrlAndFileName(url = fileUrl, fileName = fileName))
for (index in extraDataFileUrls.indices) {
allFiles.add(
UrlAndFileName(
url = extraDataFileUrls[index], fileName = extraDataFileNames[index]
)
)
}
Log.d(TAG, "About to download: $allFiles")
// Download them in sequence.
// TODO: maybe consider downloading them in parallel.
var downloadedBytes = 0L
val bytesReadSizeBuffer: MutableList<Long> = mutableListOf()
val bytesReadLatencyBuffer: MutableList<Long> = mutableListOf()
for (file in allFiles) {
val url = URL(file.url)
val connection = url.openConnection() as HttpURLConnection
if (accessToken != null) {
connection.setRequestProperty("Authorization", "Bearer $accessToken")
}
// Read the file and see if it is partially downloaded.
val outputFile = File(applicationContext.getExternalFilesDir(null), file.fileName)
val outputFileBytes = outputFile.length()
if (outputFileBytes > 0) {
Log.d(
TAG,
"File '${file.fileName}' partial size: ${outputFileBytes}. Trying to resume download"
)
connection.setRequestProperty(
"Range", "bytes=${outputFileBytes}-"
)
}
connection.connect()
Log.d(TAG, "response code: ${connection.responseCode}")
if (connection.responseCode == HttpURLConnection.HTTP_OK || connection.responseCode == HttpURLConnection.HTTP_PARTIAL) {
val contentRange = connection.getHeaderField("Content-Range")
if (contentRange != null) {
// Parse the Content-Range header
val rangeParts = contentRange.substringAfter("bytes ").split("/")
val byteRange = rangeParts[0].split("-")
val startByte = byteRange[0].toLong()
val endByte = byteRange[1].toLong()
Log.d(
TAG,
"Content-Range: $contentRange. Start bytes: ${startByte}, end bytes: $endByte"
)
downloadedBytes += startByte
} else {
Log.d(TAG, "Download starts from beginning.")
}
} else {
throw IOException("HTTP error code: ${connection.responseCode}")
}
val inputStream = connection.inputStream
val outputStream = FileOutputStream(outputFile, true /* append */)
val buffer = ByteArray(DEFAULT_BUFFER_SIZE)
var bytesRead: Int
var lastSetProgressTs: Long = 0
var deltaBytes = 0L
while (inputStream.read(buffer).also { bytesRead = it } != -1) {
outputStream.write(buffer, 0, bytesRead)
downloadedBytes += bytesRead
deltaBytes += bytesRead
// Report progress every 200 ms.
val curTs = System.currentTimeMillis()
if (curTs - lastSetProgressTs > 200) {
// Calculate download rate.
var bytesPerMs = 0f
if (lastSetProgressTs != 0L) {
if (bytesReadSizeBuffer.size == 5) {
bytesReadSizeBuffer.removeAt(bytesReadLatencyBuffer.lastIndex)
}
bytesReadSizeBuffer.add(deltaBytes)
if (bytesReadLatencyBuffer.size == 5) {
bytesReadLatencyBuffer.removeAt(bytesReadLatencyBuffer.lastIndex)
}
bytesReadLatencyBuffer.add(curTs - lastSetProgressTs)
deltaBytes = 0L
bytesPerMs = bytesReadSizeBuffer.sum().toFloat() / bytesReadLatencyBuffer.sum()
}
// Calculate remaining seconds
var remainingMs = 0f
if (bytesPerMs > 0f && totalBytes > 0L) {
remainingMs = (totalBytes - downloadedBytes) / bytesPerMs
}
setProgress(
Data.Builder().putLong(
KEY_MODEL_DOWNLOAD_RECEIVED_BYTES, downloadedBytes
).putLong(KEY_MODEL_DOWNLOAD_RATE, (bytesPerMs * 1000).toLong()).putLong(
KEY_MODEL_DOWNLOAD_REMAINING_MS, remainingMs.toLong()
).build()
)
lastSetProgressTs = curTs
}
}
outputStream.close()
inputStream.close()
Log.d(TAG, "Download done")
// Unzip if the downloaded file is a zip.
if (isZip && unzippedDir != null) {
setProgress(Data.Builder().putBoolean(KEY_MODEL_START_UNZIPPING, true).build())
// Prepare target dir.
val destDir = File("${externalFilesDir}${File.separator}${unzippedDir}")
if (!destDir.exists()) {
destDir.mkdirs()
}
// Unzip.
val unzipBuffer = ByteArray(4096)
val zipFilePath = "${externalFilesDir}${File.separator}${fileName}"
val zipIn = ZipInputStream(BufferedInputStream(FileInputStream(zipFilePath)))
var zipEntry: ZipEntry? = zipIn.nextEntry
while (zipEntry != null) {
val filePath = destDir.absolutePath + File.separator + zipEntry.name
// Extract files.
if (!zipEntry.isDirectory) {
// extract file
val bos = FileOutputStream(filePath)
bos.use { curBos ->
var len: Int
while (zipIn.read(unzipBuffer).also { len = it } > 0) {
curBos.write(unzipBuffer, 0, len)
}
}
}
// Create dir.
else {
val dir = File(filePath)
dir.mkdirs()
}
zipIn.closeEntry()
zipEntry = zipIn.nextEntry
}
zipIn.close()
// Delete the original file.
val zipFile = File(zipFilePath)
zipFile.delete()
}
}
Result.success()
} catch (e: IOException) {
Result.failure(
Data.Builder().putString(KEY_MODEL_DOWNLOAD_ERROR_MESSAGE, e.message).build()
)
}
}
}
}
}

View file

@ -0,0 +1,27 @@
<!--
Copyright 2025 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="38dp"
android:height="38dp"
android:viewportWidth="38"
android:viewportHeight="38">
<group>
<path
android:fillColor="#FF1967D2"
android:pathData="M9.32 21.5h15.3v-1.7c-0.44-0.1-0.87-0.24-1.28-0.4-0.41-0.17-0.82-0.37-1.2-0.57H9.32v2.67Zm0-4.89h10.1c-0.34-0.39-0.65-0.8-0.93-1.24-0.29-0.44-0.54-0.92-0.74-1.44H9.32v2.68Zm0-4.9h7.53V9.09H9.32v2.64Zm17.86 6.42c0-2.3-0.81-4.26-2.44-5.87-1.6-1.63-3.56-2.44-5.87-2.44 2.3 0 4.26-0.8 5.87-2.41 1.63-1.63 2.44-3.6 2.44-5.9 0 2.3 0.8 4.27 2.41 5.9 1.63 1.6 3.6 2.4 5.9 2.4-2.3 0-4.27 0.82-5.9 2.45-1.6 1.6-2.4 3.56-2.4 5.87Zm-21.4 1.04V25v2.49V5.5v2.1 2.22 7.34 2.99-0.24-0.74ZM3.12 33.9V5.5c0-0.72 0.26-1.34 0.77-1.86 0.55-0.54 1.18-0.81 1.9-0.81h13.83c-0.36 0.38-0.7 0.8-1 1.24-0.32 0.44-0.58 0.92-0.78 1.43H5.79V27.5L8.16 25H31.5v-5.83c0.51-0.2 1-0.46 1.43-0.77 0.44-0.31 0.86-0.65 1.24-1.01V25c0 0.72-0.27 1.36-0.81 1.9-0.52 0.52-1.14 0.78-1.86 0.78H9.32l-6.21 6.21Z"/>
</group>
</vector>

View file

@ -0,0 +1,26 @@
<!--
Copyright 2025 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<vector xmlns:android="http://schemas.android.com/apk/res/android" xmlns:aapt="http://schemas.android.com/aapt"
android:viewportWidth="63"
android:viewportHeight="63"
android:width="63dp"
android:height="63dp">
<path
android:pathData="M63 31.5C63 14.103 48.897 0 31.5 0C14.103 0 0 14.103 0 31.5C0 48.897 14.103 63 31.5 63C48.897 63 63 48.897 63 31.5Z"
android:fillColor="#FFFFFF"
android:fillAlpha="1" />
</vector>

View file

@ -0,0 +1,26 @@
<!--
Copyright 2025 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<vector xmlns:android="http://schemas.android.com/apk/res/android" xmlns:aapt="http://schemas.android.com/aapt"
android:viewportWidth="62"
android:viewportHeight="62"
android:width="62dp"
android:height="62dp">
<path
android:pathData="M7.78628 54.2125C-2.59542 43.8291 -2.59542 26.9941 7.78628 16.6107L16.6114 7.78754C26.9931 -2.59584 43.8287 -2.59584 54.2137 7.78754C64.5954 18.1709 64.5954 35.0059 54.2137 45.3893L45.3886 54.2125C35.0069 64.5958 18.1713 64.5958 7.78628 54.2125Z"
android:fillColor="#FFFFFF"
android:fillAlpha="1" />
</vector>

View file

@ -0,0 +1,26 @@
<!--
Copyright 2025 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<vector xmlns:android="http://schemas.android.com/apk/res/android" xmlns:aapt="http://schemas.android.com/aapt"
android:viewportWidth="59"
android:viewportHeight="59"
android:width="59dp"
android:height="59dp">
<path
android:pathData="M1.3077 21.3393C-4.19593 8.66644 8.66702 -4.19605 21.3396 1.30784L23.4364 2.21791C27.3032 3.89848 31.6966 3.89848 35.5666 2.21791L37.6602 1.30784C50.3359 -4.19605 63.1957 8.66644 57.6921 21.3393L56.7817 23.4344C55.1036 27.3037 55.1036 31.6964 56.7817 35.5654L57.6921 37.6609C63.1957 50.3337 50.3359 63.1962 37.6602 57.692L35.5666 56.782C31.6966 55.1017 27.3032 55.1017 23.4364 56.782L21.3396 57.692C8.66702 63.1962 -4.19593 50.3337 1.3077 37.6609L2.21809 35.5654C3.89932 31.6964 3.89932 27.3037 2.21809 23.4344L1.3077 21.3393Z"
android:fillColor="#FFFFFF"
android:fillAlpha="1" />
</vector>

View file

@ -0,0 +1,186 @@
<?xml version="1.0" encoding="utf-8"?>
<!--
Copyright 2025 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="108dp"
android:height="108dp"
android:viewportWidth="108"
android:viewportHeight="108">
<path
android:fillColor="#3DDC84"
android:pathData="M0,0h108v108h-108z" />
<path
android:fillColor="#00000000"
android:pathData="M9,0L9,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,0L19,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M29,0L29,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M39,0L39,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M49,0L49,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M59,0L59,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M69,0L69,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M79,0L79,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M89,0L89,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M99,0L99,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,9L108,9"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,19L108,19"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,29L108,29"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,39L108,39"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,49L108,49"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,59L108,59"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,69L108,69"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,79L108,79"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,89L108,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,99L108,99"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,29L89,29"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,39L89,39"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,49L89,49"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,59L89,59"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,69L89,69"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,79L89,79"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M29,19L29,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M39,19L39,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M49,19L49,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M59,19L59,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M69,19L69,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M79,19L79,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
</vector>

View file

@ -0,0 +1,46 @@
<!--
Copyright 2025 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<vector xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:aapt="http://schemas.android.com/aapt"
android:width="108dp"
android:height="108dp"
android:viewportWidth="108"
android:viewportHeight="108">
<path android:pathData="M31,63.928c0,0 6.4,-11 12.1,-13.1c7.2,-2.6 26,-1.4 26,-1.4l38.1,38.1L107,108.928l-32,-1L31,63.928z">
<aapt:attr name="android:fillColor">
<gradient
android:endX="85.84757"
android:endY="92.4963"
android:startX="42.9492"
android:startY="49.59793"
android:type="linear">
<item
android:color="#44000000"
android:offset="0.0" />
<item
android:color="#00000000"
android:offset="1.0" />
</gradient>
</aapt:attr>
</path>
<path
android:fillColor="#FFFFFF"
android:fillType="nonZero"
android:pathData="M65.3,45.828l3.8,-6.6c0.2,-0.4 0.1,-0.9 -0.3,-1.1c-0.4,-0.2 -0.9,-0.1 -1.1,0.3l-3.9,6.7c-6.3,-2.8 -13.4,-2.8 -19.7,0l-3.9,-6.7c-0.2,-0.4 -0.7,-0.5 -1.1,-0.3C38.8,38.328 38.7,38.828 38.9,39.228l3.8,6.6C36.2,49.428 31.7,56.028 31,63.928h46C76.3,56.028 71.8,49.428 65.3,45.828zM43.4,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2c-0.3,-0.7 -0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C45.3,56.528 44.5,57.328 43.4,57.328L43.4,57.328zM64.6,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2s-0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C66.5,56.528 65.6,57.328 64.6,57.328L64.6,57.328z"
android:strokeWidth="1"
android:strokeColor="#00000000" />
</vector>

Some files were not shown because too many files have changed in this diff Show more