Merge branch 'main' into main

Commit 7123f37989 by Fattire, 2025-06-28 12:09:16 -07:00; committed by GitHub.
46 changed files with 1681 additions and 797 deletions

View file

@@ -12,6 +12,10 @@ A clear and concise description of what the bug is.
 
 **To Reproduce:**
 Steps to reproduce the behavior:
+1. Go to '...'
+2. Click on '....'
+3. Scroll down to '....'
+4. See error
 
 **Expected behavior:**
 A clear and concise description of what you expected to happen.
@@ -19,16 +23,10 @@ A clear and concise description of what you expected to happen.
 **Screenshots:**
 If applicable, add screenshots to help explain your problem.
 
-**Desktop (please complete the following information):**
-- OS: [e.g. iOS]
-- Browser [e.g. chrome, safari]
-- Version [e.g. 22]
-
-**Smartphone (please complete the following information):**
-- Device: [e.g. iPhone6]
-- OS: [e.g. iOS8.1]
-- Browser [e.g. stock browser, safari]
-- Version [e.g. 22]
+**Device & App Information (Please complete the following):**
+- Device: [e.g., Samsung Galaxy S23, Google Pixel 7]
+- Android Version: [e.g., Android 12, Android 13]
+- App Version: [e.g., 1.0.1, v1.0.2]
 
 **Additional context:**
 Add any other context about the problem here.

View file

@@ -21,5 +21,9 @@ jobs:
     steps:
       - name: Checkout the source code
        uses: actions/checkout@v3
+      - uses: actions/setup-java@v4
+        with:
+          distribution: 'temurin'
+          java-version: '21'
       - name: Build
        run: ./gradlew assembleRelease

View file

@@ -20,6 +20,8 @@ plugins {
   alias(libs.plugins.kotlin.compose)
   alias(libs.plugins.kotlin.serialization)
   alias(libs.plugins.protobuf)
+  alias(libs.plugins.hilt.application)
+  kotlin("kapt")
 }
 
 android {
@@ -91,11 +93,15 @@ dependencies {
   implementation(libs.openid.appauth)
   implementation(libs.androidx.splashscreen)
   implementation(libs.protobuf.javalite)
+  implementation(libs.hilt.android)
+  implementation(libs.hilt.navigation.compose)
+  kapt(libs.hilt.android.compiler)
   testImplementation(libs.junit)
   androidTestImplementation(libs.androidx.junit)
   androidTestImplementation(libs.androidx.espresso.core)
   androidTestImplementation(platform(libs.androidx.compose.bom))
   androidTestImplementation(libs.androidx.ui.test.junit4)
+  androidTestImplementation(libs.hilt.android.testing)
   debugImplementation(libs.androidx.ui.tooling)
   debugImplementation(libs.androidx.ui.test.manifest)
   implementation(libs.material)

View file

@@ -29,6 +29,7 @@
   <uses-permission android:name="android.permission.FOREGROUND_SERVICE_DATA_SYNC"/>
   <uses-permission android:name="android.permission.INTERNET" />
   <uses-permission android:name="android.permission.POST_NOTIFICATIONS" />
+  <uses-permission android:name="android.permission.RECORD_AUDIO" />
   <uses-permission android:name="android.permission.WAKE_LOCK"/>
 
   <uses-feature
@@ -70,17 +71,6 @@
       </intent-filter>
     </activity>
 
-    <!-- For LLM inference engine -->
-    <uses-native-library
-      android:name="libOpenCL.so"
-      android:required="false" />
-    <uses-native-library
-      android:name="libOpenCL-car.so"
-      android:required="false" />
-    <uses-native-library
-      android:name="libOpenCL-pixel.so"
-      android:required="false" />
-
     <provider
       android:name="androidx.core.content.FileProvider"
      android:authorities="${applicationId}.provider"

View file

@@ -17,48 +17,23 @@
 package com.google.ai.edge.gallery
 
 import android.app.Application
-import android.content.Context
-import androidx.datastore.core.CorruptionException
-import androidx.datastore.core.DataStore
-import androidx.datastore.core.Serializer
-import androidx.datastore.dataStore
 import com.google.ai.edge.gallery.common.writeLaunchInfo
-import com.google.ai.edge.gallery.data.AppContainer
-import com.google.ai.edge.gallery.data.DefaultAppContainer
-import com.google.ai.edge.gallery.proto.Settings
+import com.google.ai.edge.gallery.data.DataStoreRepository
 import com.google.ai.edge.gallery.ui.theme.ThemeSettings
-import com.google.protobuf.InvalidProtocolBufferException
-import java.io.InputStream
-import java.io.OutputStream
-
-object SettingsSerializer : Serializer<Settings> {
-  override val defaultValue: Settings = Settings.getDefaultInstance()
-
-  override suspend fun readFrom(input: InputStream): Settings {
-    try {
-      return Settings.parseFrom(input)
-    } catch (exception: InvalidProtocolBufferException) {
-      throw CorruptionException("Cannot read proto.", exception)
-    }
-  }
-
-  override suspend fun writeTo(t: Settings, output: OutputStream) = t.writeTo(output)
-}
-
-private val Context.dataStore: DataStore<Settings> by
-  dataStore(fileName = "settings.pb", serializer = SettingsSerializer)
+import dagger.hilt.android.HiltAndroidApp
+import javax.inject.Inject
 
+@HiltAndroidApp
 class GalleryApplication : Application() {
-  /** AppContainer instance used by the rest of classes to obtain dependencies */
-  lateinit var container: AppContainer
+  @Inject lateinit var dataStoreRepository: DataStoreRepository
 
   override fun onCreate() {
     super.onCreate()
 
     writeLaunchInfo(context = this)
-    container = DefaultAppContainer(this, dataStore)
 
     // Load saved theme.
-    ThemeSettings.themeOverride.value = container.dataStoreRepository.readTheme()
+    ThemeSettings.themeOverride.value = dataStoreRepository.readTheme()
   }
 }

View file

@@ -18,6 +18,7 @@ package com.google.ai.edge.gallery
 
 import android.os.Build
 import android.os.Bundle
+import android.view.WindowManager
 import androidx.activity.ComponentActivity
 import androidx.activity.compose.setContent
 import androidx.activity.enableEdgeToEdge
@@ -26,8 +27,11 @@ import androidx.compose.material3.Surface
 import androidx.compose.ui.Modifier
 import androidx.core.splashscreen.SplashScreen.Companion.installSplashScreen
 import com.google.ai.edge.gallery.ui.theme.GalleryTheme
+import dagger.hilt.android.AndroidEntryPoint
 
+@AndroidEntryPoint
 class MainActivity : ComponentActivity() {
   override fun onCreate(savedInstanceState: Bundle?) {
     installSplashScreen()
@@ -39,5 +43,7 @@ class MainActivity : ComponentActivity() {
       window.isNavigationBarContrastEnforced = false
     }
     setContent { GalleryTheme { Surface(modifier = Modifier.fillMaxSize()) { GalleryApp() } } }
+
+    // Keep the screen on while the app is running for better demo experience.
+    window.addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON)
   }
 }

View file

@@ -0,0 +1,38 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery

import androidx.datastore.core.CorruptionException
import androidx.datastore.core.Serializer
import com.google.ai.edge.gallery.proto.Settings
import com.google.protobuf.InvalidProtocolBufferException
import java.io.InputStream
import java.io.OutputStream

object SettingsSerializer : Serializer<Settings> {
  override val defaultValue: Settings = Settings.getDefaultInstance()

  override suspend fun readFrom(input: InputStream): Settings {
    try {
      return Settings.parseFrom(input)
    } catch (exception: InvalidProtocolBufferException) {
      throw CorruptionException("Cannot read proto.", exception)
    }
  }

  override suspend fun writeTo(t: Settings, output: OutputStream) = t.writeTo(output)
}

View file

@@ -25,3 +25,5 @@ interface LatencyProvider {
 data class Classification(val label: String, val score: Float, val color: Color)
 
 data class JsonObjAndTextContent<T>(val jsonObj: T, val textContent: String)
+
+class AudioClip(val audioData: ByteArray, val sampleRate: Int)

View file

@@ -17,12 +17,17 @@
 package com.google.ai.edge.gallery.common
 
 import android.content.Context
+import android.net.Uri
 import android.util.Log
+import com.google.ai.edge.gallery.data.SAMPLE_RATE
 import com.google.gson.Gson
 import com.google.gson.reflect.TypeToken
 import java.io.File
 import java.net.HttpURLConnection
 import java.net.URL
+import java.nio.ByteBuffer
+import java.nio.ByteOrder
+import kotlin.math.floor
 
 data class LaunchInfo(val ts: Long)
@@ -112,3 +117,135 @@ inline fun <reified T> getJsonResponse(url: String): JsonObjAndTextContent<T>? {
   return null
 }
fun convertWavToMonoWithMaxSeconds(
  context: Context,
  stereoUri: Uri,
  maxSeconds: Int = 30,
): AudioClip? {
  Log.d(TAG, "Start to convert wav file to mono channel")
  try {
    val inputStream = context.contentResolver.openInputStream(stereoUri) ?: return null
    val originalBytes = inputStream.readBytes()
    inputStream.close()

    // Read WAV header
    if (originalBytes.size < 44) {
      // Not a valid WAV file
      Log.e(TAG, "Not a valid wav file")
      return null
    }
    val headerBuffer = ByteBuffer.wrap(originalBytes, 0, 44).order(ByteOrder.LITTLE_ENDIAN)
    val channels = headerBuffer.getShort(22)
    var sampleRate = headerBuffer.getInt(24)
    val bitDepth = headerBuffer.getShort(34)
    Log.d(TAG, "File metadata: channels: $channels, sampleRate: $sampleRate, bitDepth: $bitDepth")

    // Normalize audio to 16-bit.
    val audioDataBytes = originalBytes.copyOfRange(fromIndex = 44, toIndex = originalBytes.size)
    var sixteenBitBytes: ByteArray =
      if (bitDepth.toInt() == 8) {
        Log.d(TAG, "Converting 8-bit audio to 16-bit.")
        convert8BitTo16Bit(audioDataBytes)
      } else {
        // Assume 16-bit or other format that can be handled directly
        audioDataBytes
      }

    // Convert byte array to short array for processing
    val shortBuffer =
      ByteBuffer.wrap(sixteenBitBytes).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer()
    var pcmSamples = ShortArray(shortBuffer.remaining())
    shortBuffer.get(pcmSamples)

    // Resample if sample rate is less than 16000 Hz ---
    if (sampleRate < SAMPLE_RATE) {
      Log.d(TAG, "Resampling from $sampleRate Hz to $SAMPLE_RATE Hz.")
      pcmSamples = resample(pcmSamples, sampleRate, SAMPLE_RATE, channels.toInt())
      sampleRate = SAMPLE_RATE
      Log.d(TAG, "Resampling complete. New sample count: ${pcmSamples.size}")
    }

    // Convert stereo to mono if necessary
    var monoSamples =
      if (channels.toInt() == 2) {
        Log.d(TAG, "Converting stereo to mono.")
        val mono = ShortArray(pcmSamples.size / 2)
        for (i in mono.indices) {
          val left = pcmSamples[i * 2]
          val right = pcmSamples[i * 2 + 1]
          mono[i] = ((left + right) / 2).toShort()
        }
        mono
      } else {
        Log.d(TAG, "Audio is already mono. No channel conversion needed.")
        pcmSamples
      }

    // Trim the audio to maxSeconds ---
    val maxSamples = maxSeconds * sampleRate
    if (monoSamples.size > maxSamples) {
      Log.d(TAG, "Trimming clip from ${monoSamples.size} samples to $maxSamples samples.")
      monoSamples = monoSamples.copyOfRange(0, maxSamples)
    }

    val monoByteBuffer = ByteBuffer.allocate(monoSamples.size * 2).order(ByteOrder.LITTLE_ENDIAN)
    monoByteBuffer.asShortBuffer().put(monoSamples)
    return AudioClip(audioData = monoByteBuffer.array(), sampleRate = sampleRate)
  } catch (e: Exception) {
    Log.e(TAG, "Failed to convert wav to mono", e)
    return null
  }
}

/** Converts 8-bit unsigned PCM audio data to 16-bit signed PCM. */
private fun convert8BitTo16Bit(eightBitData: ByteArray): ByteArray {
  // The new 16-bit data will be twice the size
  val sixteenBitData = ByteArray(eightBitData.size * 2)
  val buffer = ByteBuffer.wrap(sixteenBitData).order(ByteOrder.LITTLE_ENDIAN)
  for (byte in eightBitData) {
    // Convert the unsigned 8-bit byte (0-255) to a signed 16-bit short (-32768 to 32767)
    // 1. Get the unsigned value by masking with 0xFF
    // 2. Subtract 128 to center the waveform around 0 (range becomes -128 to 127)
    // 3. Scale by 256 to expand to the 16-bit range
    val unsignedByte = byte.toInt() and 0xFF
    val sixteenBitSample = ((unsignedByte - 128) * 256).toShort()
    buffer.putShort(sixteenBitSample)
  }
  return sixteenBitData
}

/** Resamples PCM audio data from an original sample rate to a target sample rate. */
private fun resample(
  inputSamples: ShortArray,
  originalSampleRate: Int,
  targetSampleRate: Int,
  channels: Int,
): ShortArray {
  if (originalSampleRate == targetSampleRate) {
    return inputSamples
  }

  val ratio = targetSampleRate.toDouble() / originalSampleRate
  val outputLength = (inputSamples.size * ratio).toInt()
  val resampledData = ShortArray(outputLength)

  if (channels == 1) { // Mono
    for (i in resampledData.indices) {
      val position = i / ratio
      val index1 = floor(position).toInt()
      val index2 = index1 + 1
      val fraction = position - index1
      val sample1 = if (index1 < inputSamples.size) inputSamples[index1].toDouble() else 0.0
      val sample2 = if (index2 < inputSamples.size) inputSamples[index2].toDouble() else 0.0
      resampledData[i] = (sample1 * (1 - fraction) + sample2 * fraction).toInt().toShort()
    }
  }
  return resampledData
}
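For reference, a minimal sketch of how this converter might be driven from a document-picker result. The `handlePickedWav` wrapper, its `onReady` callback, and the `AudioClip` import path are illustrative assumptions, not part of the commit:

```kotlin
import android.content.Context
import android.net.Uri
import com.google.ai.edge.gallery.common.AudioClip
import com.google.ai.edge.gallery.common.convertWavToMonoWithMaxSeconds

// Hypothetical call site: normalize a user-picked WAV to mono 16-bit PCM,
// trimmed to 30 seconds, before handing it to the model.
fun handlePickedWav(context: Context, uri: Uri, onReady: (AudioClip) -> Unit) {
  val clip = convertWavToMonoWithMaxSeconds(context = context, stereoUri = uri, maxSeconds = 30)
  // Null means the stream could not be opened or was not a parsable WAV.
  clip?.let(onReady)
}
```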

View file

@@ -50,6 +50,7 @@ enum class ConfigKey(val label: String) {
   DEFAULT_TOPP("Default TopP"),
   DEFAULT_TEMPERATURE("Default temperature"),
   SUPPORT_IMAGE("Support image"),
+  SUPPORT_AUDIO("Support audio"),
   MAX_RESULT_COUNT("Max result count"),
   USE_GPU("Use GPU"),
   ACCELERATOR("Choose accelerator"),

View file

@@ -44,3 +44,12 @@ val DEFAULT_ACCELERATORS = listOf(Accelerator.GPU)
 // Max number of images allowed in an "ask image" session.
 const val MAX_IMAGE_COUNT = 10
+
+// Max number of audio clips in an "ask audio" session.
+const val MAX_AUDIO_CLIP_COUNT = 1
+
+// Max audio clip duration in seconds.
+const val MAX_AUDIO_CLIP_DURATION_SEC = 30
+
+// Audio-recording related consts.
+const val SAMPLE_RATE = 16000

View file

@@ -87,6 +87,9 @@ data class Model(
   /** Whether the LLM model supports image input. */
   val llmSupportImage: Boolean = false,
 
+  /** Whether the LLM model supports audio input. */
+  val llmSupportAudio: Boolean = false,
+
   /** Whether the model is imported or not. */
   val imported: Boolean = false,

View file

@@ -38,6 +38,7 @@ data class AllowedModel(
   val taskTypes: List<String>,
   val disabled: Boolean? = null,
   val llmSupportImage: Boolean? = null,
+  val llmSupportAudio: Boolean? = null,
   val estimatedPeakMemoryInBytes: Long? = null,
 ) {
   fun toModel(): Model {
@@ -49,10 +50,10 @@ data class AllowedModel(
       taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_PROMPT_LAB.type.id)
 
     var configs: List<Config> = listOf()
     if (isLlmModel) {
-      var defaultTopK: Int = defaultConfig.topK ?: DEFAULT_TOPK
-      var defaultTopP: Float = defaultConfig.topP ?: DEFAULT_TOPP
-      var defaultTemperature: Float = defaultConfig.temperature ?: DEFAULT_TEMPERATURE
-      var defaultMaxToken = defaultConfig.maxTokens ?: 1024
+      val defaultTopK: Int = defaultConfig.topK ?: DEFAULT_TOPK
+      val defaultTopP: Float = defaultConfig.topP ?: DEFAULT_TOPP
+      val defaultTemperature: Float = defaultConfig.temperature ?: DEFAULT_TEMPERATURE
+      val defaultMaxToken = defaultConfig.maxTokens ?: 1024
       var accelerators: List<Accelerator> = DEFAULT_ACCELERATORS
       if (defaultConfig.accelerators != null) {
         val items = defaultConfig.accelerators.split(",")
@@ -96,6 +97,7 @@ data class AllowedModel(
       showRunAgainButton = showRunAgainButton,
       learnMoreUrl = "https://huggingface.co/${modelId}",
       llmSupportImage = llmSupportImage == true,
+      llmSupportAudio = llmSupportAudio == true,
     )
   }

View file

@@ -17,19 +17,22 @@
 package com.google.ai.edge.gallery.data
 
 import androidx.annotation.StringRes
+import androidx.compose.material.icons.Icons
+import androidx.compose.material.icons.outlined.Forum
+import androidx.compose.material.icons.outlined.Mic
+import androidx.compose.material.icons.outlined.Mms
+import androidx.compose.material.icons.outlined.Widgets
 import androidx.compose.runtime.MutableState
 import androidx.compose.runtime.mutableLongStateOf
 import androidx.compose.ui.graphics.vector.ImageVector
 import com.google.ai.edge.gallery.R
-import com.google.ai.edge.gallery.ui.icon.Forum
-import com.google.ai.edge.gallery.ui.icon.Mms
-import com.google.ai.edge.gallery.ui.icon.Widgets
 
 /** Type of task. */
 enum class TaskType(val label: String, val id: String) {
   LLM_CHAT(label = "AI Chat", id = "llm_chat"),
   LLM_PROMPT_LAB(label = "Prompt Lab", id = "llm_prompt_lab"),
   LLM_ASK_IMAGE(label = "Ask Image", id = "llm_ask_image"),
+  LLM_ASK_AUDIO(label = "Audio Scribe", id = "llm_ask_audio"),
   TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
   TEST_TASK_2(label = "Test task 2", id = "test_task_2"),
 }
@@ -71,7 +74,7 @@ data class Task(
 val TASK_LLM_CHAT =
   Task(
     type = TaskType.LLM_CHAT,
-    icon = Forum,
+    icon = Icons.Outlined.Forum,
     models = mutableListOf(),
     description = "Chat with on-device large language models",
     docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
@@ -83,7 +86,7 @@ val TASK_LLM_CHAT =
 val TASK_LLM_PROMPT_LAB =
   Task(
     type = TaskType.LLM_PROMPT_LAB,
-    icon = Widgets,
+    icon = Icons.Outlined.Widgets,
     models = mutableListOf(),
     description = "Single turn use cases with on-device large language model",
     docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
@@ -95,7 +98,7 @@ val TASK_LLM_PROMPT_LAB =
 val TASK_LLM_ASK_IMAGE =
   Task(
     type = TaskType.LLM_ASK_IMAGE,
-    icon = Mms,
+    icon = Icons.Outlined.Mms,
     models = mutableListOf(),
     description = "Ask questions about images with on-device large language models",
     docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
@@ -104,8 +107,23 @@ val TASK_LLM_ASK_IMAGE =
     textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
   )
 
+val TASK_LLM_ASK_AUDIO =
+  Task(
+    type = TaskType.LLM_ASK_AUDIO,
+    icon = Icons.Outlined.Mic,
+    models = mutableListOf(),
+    // TODO(do not submit)
+    description =
+      "Instantly transcribe and/or translate audio clips using on-device large language models",
+    docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
+    sourceCodeUrl =
+      "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt",
+    textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
+  )
+
 /** All tasks. */
-val TASKS: List<Task> = listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)
+val TASKS: List<Task> =
+  listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_ASK_AUDIO, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)
 
 fun getModelByName(name: String): Model? {
   for (task in TASKS) {

View file

@@ -0,0 +1,86 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.di

import android.content.Context
import androidx.datastore.core.DataStore
import androidx.datastore.core.DataStoreFactory
import androidx.datastore.core.Serializer
import androidx.datastore.dataStoreFile
import com.google.ai.edge.gallery.AppLifecycleProvider
import com.google.ai.edge.gallery.GalleryLifecycleProvider
import com.google.ai.edge.gallery.SettingsSerializer
import com.google.ai.edge.gallery.data.DataStoreRepository
import com.google.ai.edge.gallery.data.DefaultDataStoreRepository
import com.google.ai.edge.gallery.data.DefaultDownloadRepository
import com.google.ai.edge.gallery.data.DownloadRepository
import com.google.ai.edge.gallery.proto.Settings
import dagger.Module
import dagger.Provides
import dagger.hilt.InstallIn
import dagger.hilt.android.qualifiers.ApplicationContext
import dagger.hilt.components.SingletonComponent
import javax.inject.Singleton

@Module
@InstallIn(SingletonComponent::class)
internal object AppModule {

  // Provides the SettingsSerializer
  @Provides
  @Singleton
  fun provideSettingsSerializer(): Serializer<Settings> {
    return SettingsSerializer
  }

  // Provides DataStore<Settings>
  @Provides
  @Singleton
  fun provideSettingsDataStore(
    @ApplicationContext context: Context,
    settingsSerializer: Serializer<Settings>,
  ): DataStore<Settings> {
    return DataStoreFactory.create(
      serializer = settingsSerializer,
      produceFile = { context.dataStoreFile("settings.pb") },
    )
  }

  // Provides AppLifecycleProvider
  @Provides
  @Singleton
  fun provideAppLifecycleProvider(): AppLifecycleProvider {
    return GalleryLifecycleProvider()
  }

  // Provides DataStoreRepository
  @Provides
  @Singleton
  fun provideDataStoreRepository(dataStore: DataStore<Settings>): DataStoreRepository {
    return DefaultDataStoreRepository(dataStore)
  }

  // Provides DownloadRepository
  @Provides
  @Singleton
  fun provideDownloadRepository(
    @ApplicationContext context: Context,
    lifecycleProvider: AppLifecycleProvider,
  ): DownloadRepository {
    return DefaultDownloadRepository(context, lifecycleProvider)
  }
}
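With these bindings in place, consumers declare their dependencies instead of reaching into the old `AppContainer`. A minimal sketch of the pattern; `ThemePreviewViewModel` is an illustrative name, not a class in this commit, while `readTheme()` is the same call `GalleryApplication` makes at startup:

```kotlin
import androidx.lifecycle.ViewModel
import com.google.ai.edge.gallery.data.DataStoreRepository
import dagger.hilt.android.lifecycle.HiltViewModel
import javax.inject.Inject

// Illustrative only: Hilt constructs this ViewModel and supplies the
// repository from the @Provides methods in AppModule.
@HiltViewModel
class ThemePreviewViewModel @Inject constructor(
  private val dataStoreRepository: DataStoreRepository
) : ViewModel() {
  fun savedTheme() = dataStoreRepository.readTheme()
}
```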

View file

@@ -1,60 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui

import androidx.lifecycle.ViewModelProvider.AndroidViewModelFactory
import androidx.lifecycle.viewmodel.CreationExtras
import androidx.lifecycle.viewmodel.initializer
import androidx.lifecycle.viewmodel.viewModelFactory
import com.google.ai.edge.gallery.GalleryApplication
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel

object ViewModelProvider {
  val Factory = viewModelFactory {
    // Initializer for ModelManagerViewModel.
    initializer {
      val downloadRepository = galleryApplication().container.downloadRepository
      val dataStoreRepository = galleryApplication().container.dataStoreRepository
      val lifecycleProvider = galleryApplication().container.lifecycleProvider
      ModelManagerViewModel(
        downloadRepository = downloadRepository,
        dataStoreRepository = dataStoreRepository,
        lifecycleProvider = lifecycleProvider,
        context = galleryApplication().container.context,
      )
    }

    // Initializer for LlmChatViewModel.
    initializer { LlmChatViewModel() }

    // Initializer for LlmSingleTurnViewModel.
    initializer { LlmSingleTurnViewModel() }

    // Initializer for LlmAskImageViewModel.
    initializer { LlmAskImageViewModel() }
  }
}

/**
 * Extension function to query for the [Application] object and return an instance of
 * [GalleryApplication].
 */
fun CreationExtras.galleryApplication(): GalleryApplication =
  (this[AndroidViewModelFactory.APPLICATION_KEY] as GalleryApplication)

View file

@@ -64,6 +64,7 @@ import kotlinx.coroutines.launch
 import kotlinx.coroutines.withContext
 
 private const val TAG = "AGDownloadAndTryButton"
+private const val SYSTEM_RESERVED_MEMORY_IN_BYTES = 3 * (1L shl 30)
 
 // TODO:
 // - replace the download button in chat view page with this one, and add a flag to not "onclick"
@@ -290,7 +291,6 @@ fun DownloadAndTryButton(
             val activityManager =
               context.getSystemService(android.app.Activity.ACTIVITY_SERVICE) as? ActivityManager
             val estimatedPeakMemoryInBytes = model.estimatedPeakMemoryInBytes
-
             val isMemoryLow =
               if (activityManager != null && estimatedPeakMemoryInBytes != null) {
                 val memoryInfo = ActivityManager.MemoryInfo()
@@ -302,11 +302,10 @@ fun DownloadAndTryButton(
                 // The device should be able to run the model if `availMem` is larger than the
                 // estimated peak memory. Android also has a mechanism to kill background apps to
-                // free up memory for the foreground app. We believe that if half of the total
-                // memory on the device is larger than the estimated peak memory, it can run the
-                // model fine with this mechanism. For example, a phone with 12GB memory can have
-                // very few `availMem` but will have no problem running most models.
-                max(memoryInfo.availMem, memoryInfo.totalMem / 2) < estimatedPeakMemoryInBytes
+                // free up memory for the foreground app. Reserving 3G for system buffer memory to
+                // avoid the app being killed by the system.
+                max(memoryInfo.availMem, memoryInfo.totalMem - SYSTEM_RESERVED_MEMORY_IN_BYTES) <
+                  estimatedPeakMemoryInBytes
               } else {
                 false
               }
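To make the behavior change concrete, here is the check evaluated by hand for a hypothetical 12 GiB device that currently reports only 1 GiB available, running a model with an 8 GiB estimated peak:

```kotlin
// All values hypothetical; 1 GiB = 1L shl 30.
val totalMem = 12L * (1L shl 30)
val availMem = 1L * (1L shl 30)
val estimatedPeak = 8L * (1L shl 30)

// Old heuristic: max(1 GiB, 12 GiB / 2) = 6 GiB < 8 GiB -> flagged as low memory.
val oldIsMemoryLow = maxOf(availMem, totalMem / 2) < estimatedPeak // true

// New heuristic: max(1 GiB, 12 GiB - 3 GiB) = 9 GiB >= 8 GiB -> allowed to run,
// trusting Android to evict background apps when the foreground app needs memory.
val newIsMemoryLow = maxOf(availMem, totalMem - 3L * (1L shl 30)) < estimatedPeak // false
```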

View file

@@ -0,0 +1,344 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.common.chat
import android.media.AudioAttributes
import android.media.AudioFormat
import android.media.AudioTrack
import android.util.Log
import androidx.compose.foundation.Canvas
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.width
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.PlayArrow
import androidx.compose.material.icons.rounded.Stop
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.DisposableEffect
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.MutableState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableFloatStateOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.geometry.CornerRadius
import androidx.compose.ui.geometry.Offset
import androidx.compose.ui.geometry.Size
import androidx.compose.ui.geometry.toRect
import androidx.compose.ui.graphics.BlendMode
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.drawscope.drawIntoCanvas
import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.data.MAX_AUDIO_CLIP_DURATION_SEC
import com.google.ai.edge.gallery.ui.theme.customColors
import java.nio.ByteBuffer
import java.nio.ByteOrder
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
private const val TAG = "AGAudioPlaybackPanel"
private const val BAR_SPACE = 2
private const val BAR_WIDTH = 2
private const val MIN_BAR_COUNT = 16
private const val MAX_BAR_COUNT = 48
/**
* A Composable that displays an audio playback panel, including play/stop controls, a waveform
* visualization, and the duration of the audio clip.
*/
@Composable
fun AudioPlaybackPanel(
audioData: ByteArray,
sampleRate: Int,
isRecording: Boolean,
modifier: Modifier = Modifier,
onDarkBg: Boolean = false,
) {
val coroutineScope = rememberCoroutineScope()
var isPlaying by remember { mutableStateOf(false) }
val audioTrackState = remember { mutableStateOf<AudioTrack?>(null) }
val durationInSeconds =
remember(audioData) {
// PCM 16-bit
val bytesPerSample = 2
val bytesPerFrame = bytesPerSample * 1 // mono
val totalFrames = audioData.size.toDouble() / bytesPerFrame
totalFrames / sampleRate
}
val barCount =
remember(durationInSeconds) {
val f = durationInSeconds / MAX_AUDIO_CLIP_DURATION_SEC
((MAX_BAR_COUNT - MIN_BAR_COUNT) * f + MIN_BAR_COUNT).toInt()
}
val amplitudeLevels =
remember(audioData) { generateAmplitudeLevels(audioData = audioData, barCount = barCount) }
var playbackProgress by remember { mutableFloatStateOf(0f) }
// Reset when a new recording is started.
LaunchedEffect(isRecording) {
if (isRecording) {
val audioTrack = audioTrackState.value
audioTrack?.stop()
isPlaying = false
playbackProgress = 0f
}
}
// Cleanup on Composable Disposal.
DisposableEffect(Unit) {
onDispose {
val audioTrack = audioTrackState.value
audioTrack?.stop()
audioTrack?.release()
}
}
Row(verticalAlignment = Alignment.CenterVertically, modifier = modifier) {
// Button to play/stop the clip.
IconButton(
onClick = {
coroutineScope.launch {
if (!isPlaying) {
isPlaying = true
playAudio(
audioTrackState = audioTrackState,
audioData = audioData,
sampleRate = sampleRate,
onProgress = { playbackProgress = it },
onCompletion = {
playbackProgress = 0f
isPlaying = false
},
)
} else {
stopPlayAudio(audioTrackState = audioTrackState)
playbackProgress = 0f
isPlaying = false
}
}
}
) {
Icon(
if (isPlaying) Icons.Rounded.Stop else Icons.Rounded.PlayArrow,
contentDescription = "",
tint = if (onDarkBg) Color.White else MaterialTheme.colorScheme.primary,
)
}
// Visualization
AmplitudeBarGraph(
amplitudeLevels = amplitudeLevels,
progress = playbackProgress,
modifier =
Modifier.width((barCount * BAR_WIDTH + (barCount - 1) * BAR_SPACE).dp).height(24.dp),
onDarkBg = onDarkBg,
)
// Duration
Text(
"${"%.1f".format(durationInSeconds)}s",
style = MaterialTheme.typography.labelLarge,
color = if (onDarkBg) Color.White else MaterialTheme.colorScheme.primary,
modifier = Modifier.padding(start = 12.dp),
)
}
}
@Composable
private fun AmplitudeBarGraph(
amplitudeLevels: List<Float>,
progress: Float,
modifier: Modifier = Modifier,
onDarkBg: Boolean = false,
) {
val barColor = MaterialTheme.customColors.waveFormBgColor
val progressColor = if (onDarkBg) Color.White else MaterialTheme.colorScheme.primary
Canvas(modifier = modifier) {
val barCount = amplitudeLevels.size
val barWidth = (size.width - BAR_SPACE.dp.toPx() * (barCount - 1)) / barCount
val cornerRadius = CornerRadius(x = barWidth, y = barWidth)
// Use drawIntoCanvas for advanced blend mode operations
drawIntoCanvas { canvas ->
// 1. Save the current state of the canvas onto a temporary, offscreen layer
canvas.saveLayer(size.toRect(), androidx.compose.ui.graphics.Paint())
// 2. Draw the bars in grey.
amplitudeLevels.forEachIndexed { index, level ->
val barHeight = (level * size.height).coerceAtLeast(1.5f)
val left = index * (barWidth + BAR_SPACE.dp.toPx())
drawRoundRect(
color = barColor,
topLeft = Offset(x = left, y = size.height / 2 - barHeight / 2),
size = Size(barWidth, barHeight),
cornerRadius = cornerRadius,
)
}
// 3. Draw the progress rectangle using BlendMode.SrcIn to only draw where the bars already
// exists.
val progressWidth = size.width * progress
drawRect(
color = progressColor,
topLeft = Offset.Zero,
size = Size(progressWidth, size.height),
blendMode = BlendMode.SrcIn,
)
// 4. Restore the layer, merging it onto the main canvas
canvas.restore()
}
}
}
private suspend fun playAudio(
audioTrackState: MutableState<AudioTrack?>,
audioData: ByteArray,
sampleRate: Int,
onProgress: (Float) -> Unit,
onCompletion: () -> Unit,
) {
Log.d(TAG, "Start playing audio...")
try {
withContext(Dispatchers.IO) {
var lastProgressUpdateMs = 0L
audioTrackState.value?.release()
val audioTrack =
AudioTrack.Builder()
.setAudioAttributes(
AudioAttributes.Builder()
.setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
.setUsage(AudioAttributes.USAGE_MEDIA)
.build()
)
.setAudioFormat(
AudioFormat.Builder()
.setEncoding(AudioFormat.ENCODING_PCM_16BIT)
.setSampleRate(sampleRate)
.setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
.build()
)
.setTransferMode(AudioTrack.MODE_STATIC)
.setBufferSizeInBytes(audioData.size)
.build()
val bytesPerFrame = 2 // For PCM 16-bit Mono
val totalFrames = audioData.size / bytesPerFrame
audioTrackState.value = audioTrack
audioTrack.write(audioData, 0, audioData.size)
audioTrack.play()
// Coroutine to monitor progress
while (isActive && audioTrack.playState == AudioTrack.PLAYSTATE_PLAYING) {
val currentFrames = audioTrack.playbackHeadPosition
if (currentFrames >= totalFrames) {
break // Exit loop when playback is done
}
val progress = currentFrames.toFloat() / totalFrames
val curMs = System.currentTimeMillis()
if (curMs - lastProgressUpdateMs > 30) {
onProgress(progress)
lastProgressUpdateMs = curMs
}
}
if (isActive) {
audioTrackState.value?.stop()
}
}
} catch (e: Exception) {
// Ignore
} finally {
onProgress(1f)
onCompletion()
}
}
private fun stopPlayAudio(audioTrackState: MutableState<AudioTrack?>) {
Log.d(TAG, "Stopping playing audio...")
val audioTrack = audioTrackState.value
audioTrack?.stop()
audioTrack?.release()
audioTrackState.value = null
}
/**
* Processes a raw PCM 16-bit audio byte array to generate a list of normalized amplitude levels for
* visualization.
*/
private fun generateAmplitudeLevels(audioData: ByteArray, barCount: Int): List<Float> {
if (audioData.isEmpty()) {
return List(barCount) { 0f }
}
// 1. Parse bytes into 16-bit short samples (PCM 16-bit)
val shortBuffer = ByteBuffer.wrap(audioData).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer()
val samples = ShortArray(shortBuffer.remaining())
shortBuffer.get(samples)
if (samples.isEmpty()) {
return List(barCount) { 0f }
}
// 2. Determine the size of each chunk
val chunkSize = samples.size / barCount
val amplitudeLevels = mutableListOf<Float>()
// 3. Get the max value for each chunk
for (i in 0 until barCount) {
val chunkStart = i * chunkSize
val chunkEnd = (chunkStart + chunkSize).coerceAtMost(samples.size)
var maxAmplitudeInChunk = 0.0
for (j in chunkStart until chunkEnd) {
val sampleAbs = kotlin.math.abs(samples[j].toDouble())
if (sampleAbs > maxAmplitudeInChunk) {
maxAmplitudeInChunk = sampleAbs
}
}
// 4. Normalize the value (0 to 1)
// Short.MAX_VALUE is 32767.0, a good reference for max amplitude
val normalizedRms = (maxAmplitudeInChunk / Short.MAX_VALUE).toFloat().coerceIn(0f, 1f)
amplitudeLevels.add(normalizedRms)
}
// Normalize the resulting levels so that the max value becomes 0.9.
val maxVal = amplitudeLevels.max()
if (maxVal == 0f) {
return amplitudeLevels
}
val scaleFactor = 0.9f / maxVal
return amplitudeLevels.map { it * scaleFactor }
}
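As a worked example of the sizing math above, assuming a hypothetical 10-second clip at the app's 16 kHz sample rate:

```kotlin
// Hypothetical clip: 10 s of 16-bit mono PCM at 16 kHz.
val sampleRate = 16_000
val audioDataSize = 320_000 // bytes = 16_000 samples/s * 2 bytes * 10 s

val durationInSeconds = audioDataSize / 2.0 / sampleRate // 10.0, rendered as "10.0s"

// Bar count interpolates between MIN_BAR_COUNT (16) and MAX_BAR_COUNT (48)
// by the clip's fraction of the 30-second maximum.
val f = durationInSeconds / 30
val barCount = ((48 - 16) * f + 16).toInt() // 26 bars

// The graph is then sized at barCount * BAR_WIDTH + (barCount - 1) * BAR_SPACE
// = 26 * 2 + 25 * 2 = 102 dp wide.
```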

View file

@@ -0,0 +1,327 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.common.chat
import android.annotation.SuppressLint
import android.content.Context
import android.media.AudioFormat
import android.media.AudioRecord
import android.media.MediaRecorder
import android.util.Log
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.ArrowUpward
import androidx.compose.material.icons.rounded.Mic
import androidx.compose.material.icons.rounded.Stop
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.DisposableEffect
import androidx.compose.runtime.MutableLongState
import androidx.compose.runtime.MutableState
import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.mutableLongStateOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.data.MAX_AUDIO_CLIP_DURATION_SEC
import com.google.ai.edge.gallery.data.SAMPLE_RATE
import com.google.ai.edge.gallery.ui.theme.customColors
import java.io.ByteArrayOutputStream
import java.nio.ByteBuffer
import java.nio.ByteOrder
import kotlin.math.abs
import kotlin.math.pow
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.launch
private const val TAG = "AGAudioRecorderPanel"
private const val CHANNEL_CONFIG = AudioFormat.CHANNEL_IN_MONO
private const val AUDIO_FORMAT = AudioFormat.ENCODING_PCM_16BIT
/**
* A Composable that provides an audio recording panel. It allows users to record audio clips,
* displays the recording duration and a live amplitude visualization, and provides options to play
* back the recorded clip or send it.
*/
@Composable
fun AudioRecorderPanel(onSendAudioClip: (ByteArray) -> Unit) {
val context = LocalContext.current
val coroutineScope = rememberCoroutineScope()
var isRecording by remember { mutableStateOf(false) }
val elapsedMs = remember { mutableLongStateOf(0L) }
val audioRecordState = remember { mutableStateOf<AudioRecord?>(null) }
val audioStream = remember { ByteArrayOutputStream() }
val recordedBytes = remember { mutableStateOf<ByteArray?>(null) }
var currentAmplitude by remember { mutableIntStateOf(0) }
val elapsedSeconds by remember {
derivedStateOf { "%.1f".format(elapsedMs.value.toFloat() / 1000f) }
}
// Cleanup on Composable Disposal.
DisposableEffect(Unit) { onDispose { audioRecordState.value?.release() } }
Column(modifier = Modifier.padding(bottom = 12.dp)) {
// Title bar.
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.SpaceBetween,
) {
// Logo and state.
Row(
modifier = Modifier.padding(start = 16.dp),
horizontalArrangement = Arrangement.spacedBy(12.dp),
) {
Icon(
painterResource(R.drawable.logo),
modifier = Modifier.size(20.dp),
contentDescription = "",
tint = Color.Unspecified,
)
Text(
"Record audio clip (up to $MAX_AUDIO_CLIP_DURATION_SEC seconds)",
style = MaterialTheme.typography.labelLarge,
)
}
}
// Recorded clip.
Row(
modifier = Modifier.fillMaxWidth().padding(vertical = 12.dp).height(40.dp),
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.Center,
) {
val curRecordedBytes = recordedBytes.value
if (curRecordedBytes == null) {
// Info message when there is no recorded clip and the recording has not started yet.
if (!isRecording) {
Text(
"Tap the record button to start",
style = MaterialTheme.typography.labelLarge,
color = MaterialTheme.colorScheme.onSurfaceVariant,
)
}
// Visualization for clip being recorded.
else {
Row(
horizontalArrangement = Arrangement.spacedBy(12.dp),
verticalAlignment = Alignment.CenterVertically,
) {
Box(
modifier =
Modifier.size(8.dp)
.background(MaterialTheme.customColors.recordButtonBgColor, CircleShape)
)
Text("$elapsedSeconds s")
}
}
}
// Controls for recorded clip.
else {
Row {
// Clip player.
AudioPlaybackPanel(
audioData = curRecordedBytes,
sampleRate = SAMPLE_RATE,
isRecording = isRecording,
)
// Button to send the clip
IconButton(onClick = { onSendAudioClip(curRecordedBytes) }) {
Icon(Icons.Rounded.ArrowUpward, contentDescription = "")
}
}
}
}
// Buttons
Box(contentAlignment = Alignment.Center, modifier = Modifier.fillMaxWidth().height(40.dp)) {
// Visualization of the current amplitude.
if (isRecording) {
// Normalize the amplitude (0-32767) to a fraction (0.0-1.0)
// We use a power scale (exponent < 1) to make the pulse more visible for lower volumes.
val normalizedAmplitude = (currentAmplitude.toFloat() / 32767f).pow(0.35f)
// Define the min and max size of the circle
val minSize = 38.dp
val maxSize = 100.dp
// Map the normalized amplitude to our size range
val scale by
remember(normalizedAmplitude) {
derivedStateOf { (minSize + (maxSize - minSize) * normalizedAmplitude) / minSize }
}
Box(
modifier =
Modifier.size(minSize)
.graphicsLayer(scaleX = scale, scaleY = scale, clip = false, alpha = 0.3f)
.background(MaterialTheme.customColors.recordButtonBgColor, CircleShape)
)
}
// Record/stop button.
IconButton(
onClick = {
coroutineScope.launch {
if (!isRecording) {
isRecording = true
recordedBytes.value = null
startRecording(
context = context,
audioRecordState = audioRecordState,
audioStream = audioStream,
elapsedMs = elapsedMs,
onAmplitudeChanged = { currentAmplitude = it },
onMaxDurationReached = {
val curRecordedBytes =
stopRecording(audioRecordState = audioRecordState, audioStream = audioStream)
recordedBytes.value = curRecordedBytes
isRecording = false
},
)
} else {
val curRecordedBytes =
stopRecording(audioRecordState = audioRecordState, audioStream = audioStream)
recordedBytes.value = curRecordedBytes
isRecording = false
}
}
},
modifier =
Modifier.clip(CircleShape).background(MaterialTheme.customColors.recordButtonBgColor),
) {
Icon(
if (isRecording) Icons.Rounded.Stop else Icons.Rounded.Mic,
contentDescription = "",
tint = Color.White,
)
}
}
}
}
// Permission is checked in parent composable.
@SuppressLint("MissingPermission")
private suspend fun startRecording(
context: Context,
audioRecordState: MutableState<AudioRecord?>,
audioStream: ByteArrayOutputStream,
elapsedMs: MutableLongState,
onAmplitudeChanged: (Int) -> Unit,
onMaxDurationReached: () -> Unit,
) {
Log.d(TAG, "Start recording...")
val minBufferSize = AudioRecord.getMinBufferSize(SAMPLE_RATE, CHANNEL_CONFIG, AUDIO_FORMAT)
audioRecordState.value?.release()
val recorder =
AudioRecord(
MediaRecorder.AudioSource.MIC,
SAMPLE_RATE,
CHANNEL_CONFIG,
AUDIO_FORMAT,
minBufferSize,
)
audioRecordState.value = recorder
val buffer = ByteArray(minBufferSize)
// The function will only return when the recording is done (when stopRecording is called).
coroutineScope {
launch(Dispatchers.IO) {
recorder.startRecording()
val startMs = System.currentTimeMillis()
elapsedMs.value = 0L
while (audioRecordState.value?.recordingState == AudioRecord.RECORDSTATE_RECORDING) {
val bytesRead = recorder.read(buffer, 0, buffer.size)
if (bytesRead > 0) {
val currentAmplitude = calculatePeakAmplitude(buffer = buffer, bytesRead = bytesRead)
onAmplitudeChanged(currentAmplitude)
audioStream.write(buffer, 0, bytesRead)
}
elapsedMs.value = System.currentTimeMillis() - startMs
if (elapsedMs.value >= MAX_AUDIO_CLIP_DURATION_SEC * 1000) {
onMaxDurationReached()
break
}
}
}
}
}
private fun stopRecording(
audioRecordState: MutableState<AudioRecord?>,
audioStream: ByteArrayOutputStream,
): ByteArray {
Log.d(TAG, "Stopping recording...")
val recorder = audioRecordState.value
if (recorder?.recordingState == AudioRecord.RECORDSTATE_RECORDING) {
recorder.stop()
}
recorder?.release()
audioRecordState.value = null
val recordedBytes = audioStream.toByteArray()
audioStream.reset()
Log.d(TAG, "Stopped. Recorded ${recordedBytes.size} bytes.")
return recordedBytes
}
private fun calculatePeakAmplitude(buffer: ByteArray, bytesRead: Int): Int {
// Wrap the byte array in a ByteBuffer and set the order to little-endian
val shortBuffer =
ByteBuffer.wrap(buffer, 0, bytesRead).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer()
var maxAmplitude = 0
// Iterate through the short buffer to find the maximum absolute value
while (shortBuffer.hasRemaining()) {
val currentSample = abs(shortBuffer.get().toInt())
if (currentSample > maxAmplitude) {
maxAmplitude = currentSample
}
}
return maxAmplitude
}
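For intuition on the pulse visualization, the 0.35 exponent deliberately inflates quiet input so soft speech still animates the circle; a worked example with a hypothetical peak of 3000 (out of 32767):

```kotlin
import kotlin.math.pow

// Hypothetical values for one audio buffer.
val currentAmplitude = 3_000
val linear = 3_000f / 32767f                 // ≈ 0.09: barely visible
val normalizedAmplitude = linear.pow(0.35f)  // ≈ 0.43: a clearly visible pulse

// Mapped into the 38..100 dp circle range and expressed as a scale factor:
val scale = (38f + (100f - 38f) * normalizedAmplitude) / 38f // ≈ 1.7x
```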

View file

@@ -17,18 +17,22 @@
 package com.google.ai.edge.gallery.ui.common.chat
 
 import android.graphics.Bitmap
+import android.util.Log
 import androidx.compose.ui.graphics.ImageBitmap
 import androidx.compose.ui.unit.Dp
 import com.google.ai.edge.gallery.common.Classification
 import com.google.ai.edge.gallery.data.Model
 import com.google.ai.edge.gallery.data.PromptTemplate
 
+private const val TAG = "AGChatMessage"
+
 enum class ChatMessageType {
   INFO,
   WARNING,
   TEXT,
   IMAGE,
   IMAGE_WITH_HISTORY,
+  AUDIO_CLIP,
   LOADING,
   CLASSIFICATION,
   CONFIG_VALUES_CHANGE,
@@ -121,6 +125,90 @@ class ChatMessageImage(
  }
}

/** Chat message for audio clip. */
class ChatMessageAudioClip(
  val audioData: ByteArray,
  val sampleRate: Int,
  override val side: ChatSide,
  override val latencyMs: Float = 0f,
) : ChatMessage(type = ChatMessageType.AUDIO_CLIP, side = side, latencyMs = latencyMs) {
  override fun clone(): ChatMessageAudioClip {
    return ChatMessageAudioClip(
      audioData = audioData,
      sampleRate = sampleRate,
      side = side,
      latencyMs = latencyMs,
    )
  }

  fun genByteArrayForWav(): ByteArray {
    val header = ByteArray(44)
    val pcmDataSize = audioData.size
    val wavFileSize = pcmDataSize + 44 // 44 bytes for the header
    val channels = 1 // Mono
    val bitsPerSample: Short = 16
    val byteRate = sampleRate * channels * bitsPerSample / 8
    Log.d(TAG, "Wav metadata: sampleRate: $sampleRate")

    // RIFF/WAVE header
    header[0] = 'R'.code.toByte()
    header[1] = 'I'.code.toByte()
    header[2] = 'F'.code.toByte()
    header[3] = 'F'.code.toByte()
    header[4] = (wavFileSize and 0xff).toByte()
    header[5] = (wavFileSize shr 8 and 0xff).toByte()
    header[6] = (wavFileSize shr 16 and 0xff).toByte()
    header[7] = (wavFileSize shr 24 and 0xff).toByte()
    header[8] = 'W'.code.toByte()
    header[9] = 'A'.code.toByte()
    header[10] = 'V'.code.toByte()
    header[11] = 'E'.code.toByte()
    header[12] = 'f'.code.toByte()
    header[13] = 'm'.code.toByte()
    header[14] = 't'.code.toByte()
    header[15] = ' '.code.toByte()
    header[16] = 16
    header[17] = 0
    header[18] = 0
    header[19] = 0 // Sub-chunk size (16 for PCM)
    header[20] = 1
    header[21] = 0 // Audio format (1 for PCM)
    header[22] = channels.toByte()
    header[23] = 0 // Number of channels
    header[24] = (sampleRate and 0xff).toByte()
    header[25] = (sampleRate shr 8 and 0xff).toByte()
    header[26] = (sampleRate shr 16 and 0xff).toByte()
    header[27] = (sampleRate shr 24 and 0xff).toByte()
    header[28] = (byteRate and 0xff).toByte()
    header[29] = (byteRate shr 8 and 0xff).toByte()
    header[30] = (byteRate shr 16 and 0xff).toByte()
    header[31] = (byteRate shr 24 and 0xff).toByte()
    header[32] = (channels * bitsPerSample / 8).toByte()
    header[33] = 0 // Block align
    header[34] = bitsPerSample.toByte()
    header[35] = (bitsPerSample.toInt() shr 8 and 0xff).toByte() // Bits per sample
    header[36] = 'd'.code.toByte()
    header[37] = 'a'.code.toByte()
    header[38] = 't'.code.toByte()
    header[39] = 'a'.code.toByte()
    header[40] = (pcmDataSize and 0xff).toByte()
    header[41] = (pcmDataSize shr 8 and 0xff).toByte()
    header[42] = (pcmDataSize shr 16 and 0xff).toByte()
    header[43] = (pcmDataSize shr 24 and 0xff).toByte()
    return header + audioData
  }

  fun getDurationInSeconds(): Float {
    // PCM 16-bit
    val bytesPerSample = 2
    val bytesPerFrame = bytesPerSample * 1 // mono
    val totalFrames = audioData.size.toFloat() / bytesPerFrame
    return totalFrames / sampleRate
  }
}

/** Chat message for images with history. */
class ChatMessageImageWithHistory(
  val bitmaps: List<Bitmap>,
View file

@@ -137,6 +137,19 @@ fun ChatPanel(
       }
       imageMessageCount
     }
+  val audioClipMessageCountToLastConfigChange =
+    remember(messages) {
+      var audioClipMessageCount = 0
+      for (message in messages.reversed()) {
+        if (message is ChatMessageConfigValuesChange) {
+          break
+        }
+        if (message is ChatMessageAudioClip) {
+          audioClipMessageCount++
+        }
+      }
+      audioClipMessageCount
+    }
 
   var curMessage by remember { mutableStateOf("") } // Correct state
   val focusManager = LocalFocusManager.current
@@ -342,6 +355,9 @@ fun ChatPanel(
                   imageHistoryCurIndex = imageHistoryCurIndex,
                 )
 
+              // Audio clip.
+              is ChatMessageAudioClip -> MessageBodyAudioClip(message = message)
+
               // Classification result
               is ChatMessageClassification ->
                 MessageBodyClassification(
@@ -467,6 +483,22 @@ fun ChatPanel(
           )
         }
       }
+      // Show an info message for ask audio task to get users started.
+      else if (task.type == TaskType.LLM_ASK_AUDIO && messages.isEmpty()) {
+        Column(
+          modifier = Modifier.padding(horizontal = 16.dp).fillMaxSize(),
+          horizontalAlignment = Alignment.CenterHorizontally,
+          verticalArrangement = Arrangement.Center,
+        ) {
+          MessageBodyInfo(
+            ChatMessageInfo(
+              content =
+                "To get started, tap the + icon to add your audio clip. Limited to 1 clip up to 30 seconds long."
+            ),
+            smallFontSize = false,
+          )
+        }
+      }
     }
 
     // Chat input
@@ -482,6 +514,7 @@ fun ChatPanel(
       isResettingSession = uiState.isResettingSession,
       modelPreparing = uiState.preparing,
       imageMessageCount = imageMessageCountToLastConfigChange,
+      audioClipMessageCount = audioClipMessageCountToLastConfigChange,
       modelInitializing =
         modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
       textFieldPlaceHolderRes = task.textInputPlaceHolderRes,
@@ -504,7 +537,10 @@ fun ChatPanel(
       onStopButtonClicked = onStopButtonClicked,
       // showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen,
       showPromptTemplatesInMenu = false,
-      showImagePickerInMenu = selectedModel.llmSupportImage,
+      showImagePickerInMenu =
+        selectedModel.llmSupportImage && task.type === TaskType.LLM_ASK_IMAGE,
+      showAudioItemsInMenu =
+        selectedModel.llmSupportAudio && task.type === TaskType.LLM_ASK_AUDIO,
       showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress,
     )


@ -53,7 +53,7 @@ data class ChatUiState(
) )
/** ViewModel responsible for managing the chat UI state and handling chat-related operations. */ /** ViewModel responsible for managing the chat UI state and handling chat-related operations. */
open class ChatViewModel(val task: Task) : ViewModel() { abstract class ChatViewModel(val task: Task) : ViewModel() {
private val _uiState = MutableStateFlow(createUiState(task = task)) private val _uiState = MutableStateFlow(createUiState(task = task))
val uiState = _uiState.asStateFlow() val uiState = _uiState.asStateFlow()


@ -14,8 +14,20 @@
* limitations under the License. * limitations under the License.
*/ */
package com.google.ai.edge.gallery.ui.preview package com.google.ai.edge.gallery.ui.common.chat
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel import androidx.compose.foundation.layout.padding
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.dp
class PreviewLlmSingleTurnViewModel : LlmSingleTurnViewModel(task = TASK_TEST1) @Composable
fun MessageBodyAudioClip(message: ChatMessageAudioClip, modifier: Modifier = Modifier) {
AudioPlaybackPanel(
audioData = message.audioData,
sampleRate = message.sampleRate,
isRecording = false,
modifier = Modifier.padding(end = 16.dp),
onDarkBg = true,
)
}


@ -21,6 +21,7 @@ package com.google.ai.edge.gallery.ui.common.chat
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme // import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import android.Manifest import android.Manifest
import android.content.Context import android.content.Context
import android.content.Intent
import android.content.pm.PackageManager import android.content.pm.PackageManager
import android.graphics.Bitmap import android.graphics.Bitmap
import android.graphics.BitmapFactory import android.graphics.BitmapFactory
@ -65,9 +66,11 @@ import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.Send import androidx.compose.material.icons.automirrored.rounded.Send
import androidx.compose.material.icons.rounded.Add import androidx.compose.material.icons.rounded.Add
import androidx.compose.material.icons.rounded.AudioFile
import androidx.compose.material.icons.rounded.Close import androidx.compose.material.icons.rounded.Close
import androidx.compose.material.icons.rounded.FlipCameraAndroid import androidx.compose.material.icons.rounded.FlipCameraAndroid
import androidx.compose.material.icons.rounded.History import androidx.compose.material.icons.rounded.History
import androidx.compose.material.icons.rounded.Mic
import androidx.compose.material.icons.rounded.Photo import androidx.compose.material.icons.rounded.Photo
import androidx.compose.material.icons.rounded.PhotoCamera import androidx.compose.material.icons.rounded.PhotoCamera
import androidx.compose.material.icons.rounded.PostAdd import androidx.compose.material.icons.rounded.PostAdd
@ -107,9 +110,14 @@ import androidx.compose.ui.unit.dp
import androidx.compose.ui.viewinterop.AndroidView import androidx.compose.ui.viewinterop.AndroidView
import androidx.core.content.ContextCompat import androidx.core.content.ContextCompat
import androidx.lifecycle.compose.LocalLifecycleOwner import androidx.lifecycle.compose.LocalLifecycleOwner
import com.google.ai.edge.gallery.common.AudioClip
import com.google.ai.edge.gallery.common.convertWavToMonoWithMaxSeconds
import com.google.ai.edge.gallery.data.MAX_AUDIO_CLIP_COUNT
import com.google.ai.edge.gallery.data.MAX_IMAGE_COUNT import com.google.ai.edge.gallery.data.MAX_IMAGE_COUNT
import com.google.ai.edge.gallery.data.SAMPLE_RATE
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import java.util.concurrent.Executors import java.util.concurrent.Executors
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
private const val TAG = "AGMessageInputText" private const val TAG = "AGMessageInputText"
@ -128,6 +136,7 @@ fun MessageInputText(
isResettingSession: Boolean, isResettingSession: Boolean,
inProgress: Boolean, inProgress: Boolean,
imageMessageCount: Int, imageMessageCount: Int,
audioClipMessageCount: Int,
modelInitializing: Boolean, modelInitializing: Boolean,
@StringRes textFieldPlaceHolderRes: Int, @StringRes textFieldPlaceHolderRes: Int,
onValueChanged: (String) -> Unit, onValueChanged: (String) -> Unit,
@ -137,6 +146,7 @@ fun MessageInputText(
onStopButtonClicked: () -> Unit = {}, onStopButtonClicked: () -> Unit = {},
showPromptTemplatesInMenu: Boolean = false, showPromptTemplatesInMenu: Boolean = false,
showImagePickerInMenu: Boolean = false, showImagePickerInMenu: Boolean = false,
showAudioItemsInMenu: Boolean = false,
showStopButtonWhenInProgress: Boolean = false, showStopButtonWhenInProgress: Boolean = false,
) { ) {
val context = LocalContext.current val context = LocalContext.current
@ -146,7 +156,12 @@ fun MessageInputText(
var showTextInputHistorySheet by remember { mutableStateOf(false) } var showTextInputHistorySheet by remember { mutableStateOf(false) }
var showCameraCaptureBottomSheet by remember { mutableStateOf(false) } var showCameraCaptureBottomSheet by remember { mutableStateOf(false) }
val cameraCaptureSheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true) val cameraCaptureSheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
var showAudioRecorderBottomSheet by remember { mutableStateOf(false) }
val audioRecorderSheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
var pickedImages by remember { mutableStateOf<List<Bitmap>>(listOf()) } var pickedImages by remember { mutableStateOf<List<Bitmap>>(listOf()) }
var pickedAudioClips by remember { mutableStateOf<List<AudioClip>>(listOf()) }
var hasFrontCamera by remember { mutableStateOf(false) }
val updatePickedImages: (List<Bitmap>) -> Unit = { bitmaps -> val updatePickedImages: (List<Bitmap>) -> Unit = { bitmaps ->
var newPickedImages: MutableList<Bitmap> = mutableListOf() var newPickedImages: MutableList<Bitmap> = mutableListOf()
newPickedImages.addAll(pickedImages) newPickedImages.addAll(pickedImages)
@ -156,7 +171,16 @@ fun MessageInputText(
} }
pickedImages = newPickedImages.toList() pickedImages = newPickedImages.toList()
} }
var hasFrontCamera by remember { mutableStateOf(false) }
val updatePickedAudioClips: (List<AudioClip>) -> Unit = { audioDataList ->
var newAudioDataList: MutableList<AudioClip> = mutableListOf()
newAudioDataList.addAll(pickedAudioClips)
newAudioDataList.addAll(audioDataList)
if (newAudioDataList.size > MAX_AUDIO_CLIP_COUNT) {
newAudioDataList = newAudioDataList.subList(fromIndex = 0, toIndex = MAX_AUDIO_CLIP_COUNT)
}
pickedAudioClips = newAudioDataList.toList()
}
LaunchedEffect(Unit) { checkFrontCamera(context = context, callback = { hasFrontCamera = it }) } LaunchedEffect(Unit) { checkFrontCamera(context = context, callback = { hasFrontCamera = it }) }
@ -170,6 +194,16 @@ fun MessageInputText(
} }
} }
// Permission request when recording audio clips.
val recordAudioClipsPermissionLauncher =
rememberLauncherForActivityResult(ActivityResultContracts.RequestPermission()) {
permissionGranted ->
if (permissionGranted) {
showAddContentMenu = false
showAudioRecorderBottomSheet = true
}
}
// Registers a photo picker activity launcher in single-select mode. // Registers a photo picker activity launcher in single-select mode.
val pickMedia = val pickMedia =
rememberLauncherForActivityResult(ActivityResultContracts.PickMultipleVisualMedia()) { uris -> rememberLauncherForActivityResult(ActivityResultContracts.PickMultipleVisualMedia()) { uris ->
@ -184,9 +218,31 @@ fun MessageInputText(
} }
} }
val pickWav =
rememberLauncherForActivityResult(
contract = ActivityResultContracts.StartActivityForResult()
) { result ->
if (result.resultCode == android.app.Activity.RESULT_OK) {
result.data?.data?.let { uri ->
Log.d(TAG, "Picked wav file: $uri")
scope.launch(Dispatchers.IO) {
convertWavToMonoWithMaxSeconds(context = context, stereoUri = uri)?.let { audioClip ->
updatePickedAudioClips(
listOf(
AudioClip(audioData = audioClip.audioData, sampleRate = audioClip.sampleRate)
)
)
}
}
}
} else {
Log.d(TAG, "Wav picking cancelled.")
}
}
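The picked URI is handed off the main thread to convertWavToMonoWithMaxSeconds, whose implementation is not shown in this part of the diff. A plausible sketch of its core downmix-and-trim step, assuming the simple case of a canonical 44-byte header and 16-bit little-endian PCM (real WAV files can carry extra chunks, which a robust parser must skip):

import java.nio.ByteBuffer
import java.nio.ByteOrder

// Hedged sketch only; downmixAndTrim is a hypothetical name, not the
// helper used by this commit. Averages interleaved channels to mono and
// truncates the result to maxSeconds.
fun downmixAndTrim(wavBytes: ByteArray, channels: Int, sampleRate: Int, maxSeconds: Int): ByteArray {
  val pcm = ByteBuffer.wrap(wavBytes, 44, wavBytes.size - 44)
    .order(ByteOrder.LITTLE_ENDIAN)
    .asShortBuffer()
  val totalFrames = pcm.remaining() / channels
  val keptFrames = minOf(totalFrames, sampleRate * maxSeconds)
  val out = ByteBuffer.allocate(keptFrames * 2).order(ByteOrder.LITTLE_ENDIAN)
  for (frame in 0 until keptFrames) {
    var acc = 0
    for (c in 0 until channels) acc += pcm.get(frame * channels + c).toInt()
    out.putShort((acc / channels).toShort()) // average channels to mono
  }
  return out.array()
}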
Column { Column {
// A preview panel for the selected image. // A preview panel for the selected images and audio clips.
if (pickedImages.isNotEmpty()) { if (pickedImages.isNotEmpty() || pickedAudioClips.isNotEmpty()) {
Row( Row(
modifier = modifier =
Modifier.offset(x = 16.dp).fillMaxWidth().horizontalScroll(rememberScrollState()), Modifier.offset(x = 16.dp).fillMaxWidth().horizontalScroll(rememberScrollState()),
@ -203,20 +259,30 @@ fun MessageInputText(
.clip(RoundedCornerShape(8.dp)) .clip(RoundedCornerShape(8.dp))
.border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp)), .border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp)),
) )
MediaPanelCloseButton { pickedImages = pickedImages.filter { image != it } }
}
}
for ((index, audioClip) in pickedAudioClips.withIndex()) {
Box(contentAlignment = Alignment.TopEnd) {
Box( Box(
modifier = modifier =
Modifier.offset(x = 10.dp, y = (-10).dp) Modifier.shadow(2.dp, shape = RoundedCornerShape(8.dp))
.clip(CircleShape) .clip(RoundedCornerShape(8.dp))
.background(MaterialTheme.colorScheme.surface) .background(MaterialTheme.colorScheme.surface)
.border((1.5).dp, MaterialTheme.colorScheme.outline, CircleShape) .border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp))
.clickable { pickedImages = pickedImages.filter { image != it } }
) { ) {
Icon( AudioPlaybackPanel(
Icons.Rounded.Close, audioData = audioClip.audioData,
contentDescription = "", sampleRate = audioClip.sampleRate,
modifier = Modifier.padding(3.dp).size(16.dp), isRecording = false,
modifier = Modifier.padding(end = 16.dp),
) )
} }
MediaPanelCloseButton {
pickedAudioClips =
pickedAudioClips.filterIndexed { curIndex, curAudioData -> curIndex != index }
}
} }
} }
} }
@ -239,10 +305,13 @@ fun MessageInputText(
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
) { ) {
val enableAddImageMenuItems = (imageMessageCount + pickedImages.size) < MAX_IMAGE_COUNT val enableAddImageMenuItems = (imageMessageCount + pickedImages.size) < MAX_IMAGE_COUNT
val enableRecordAudioClipMenuItems =
(audioClipMessageCount + pickedAudioClips.size) < MAX_AUDIO_CLIP_COUNT
DropdownMenu( DropdownMenu(
expanded = showAddContentMenu, expanded = showAddContentMenu,
onDismissRequest = { showAddContentMenu = false }, onDismissRequest = { showAddContentMenu = false },
) { ) {
// Image-related menu items.
if (showImagePickerInMenu) { if (showImagePickerInMenu) {
// Take a picture. // Take a picture.
DropdownMenuItem( DropdownMenuItem(
@ -295,6 +364,70 @@ fun MessageInputText(
) )
} }
// Audio-related menu items.
if (showAudioItemsInMenu) {
DropdownMenuItem(
text = {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp),
) {
Icon(Icons.Rounded.Mic, contentDescription = "")
Text("Record audio clip")
}
},
enabled = enableRecordAudioClipMenuItems,
onClick = {
// Check permission
when (PackageManager.PERMISSION_GRANTED) {
// Permission already granted. Open the recorder sheet.
ContextCompat.checkSelfPermission(context, Manifest.permission.RECORD_AUDIO) -> {
showAddContentMenu = false
showAudioRecorderBottomSheet = true
}
// Otherwise, ask for permission
else -> {
recordAudioClipsPermissionLauncher.launch(Manifest.permission.RECORD_AUDIO)
}
}
},
)
DropdownMenuItem(
text = {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp),
) {
Icon(Icons.Rounded.AudioFile, contentDescription = "")
Text("Pick wav file")
}
},
enabled = enableRecordAudioClipMenuItems,
onClick = {
showAddContentMenu = false
// Show file picker.
val intent =
Intent(Intent.ACTION_GET_CONTENT).apply {
addCategory(Intent.CATEGORY_OPENABLE)
type = "audio/*"
// Provide a list of more specific MIME types to filter for.
val mimeTypes = arrayOf("audio/wav", "audio/x-wav")
putExtra(Intent.EXTRA_MIME_TYPES, mimeTypes)
// Single select.
putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
addFlags(Intent.FLAG_GRANT_PERSISTABLE_URI_PERMISSION)
addFlags(Intent.FLAG_GRANT_READ_URI_PERMISSION)
}
pickWav.launch(intent)
},
)
}
// Prompt templates. // Prompt templates.
if (showPromptTemplatesInMenu) { if (showPromptTemplatesInMenu) {
DropdownMenuItem( DropdownMenuItem(
@ -369,15 +502,22 @@ fun MessageInputText(
) )
} }
} }
} // Send button. Only shown when text is not empty. }
else if (curMessage.isNotEmpty()) { // Send button. Only shown when text is not empty, or there is at least one recorded
// audio clip.
else if (curMessage.isNotEmpty() || pickedAudioClips.isNotEmpty()) {
IconButton( IconButton(
enabled = !inProgress && !isResettingSession, enabled = !inProgress && !isResettingSession,
onClick = { onClick = {
onSendMessage( onSendMessage(
createMessagesToSend(pickedImages = pickedImages, text = curMessage.trim()) createMessagesToSend(
pickedImages = pickedImages,
audioClips = pickedAudioClips,
text = curMessage.trim(),
)
) )
pickedImages = listOf() pickedImages = listOf()
pickedAudioClips = listOf()
}, },
colors = colors =
IconButtonDefaults.iconButtonColors( IconButtonDefaults.iconButtonColors(
@ -403,8 +543,15 @@ fun MessageInputText(
history = modelManagerUiState.textInputHistory, history = modelManagerUiState.textInputHistory,
onDismissed = { showTextInputHistorySheet = false }, onDismissed = { showTextInputHistorySheet = false },
onHistoryItemClicked = { item -> onHistoryItemClicked = { item ->
onSendMessage(createMessagesToSend(pickedImages = pickedImages, text = item)) onSendMessage(
createMessagesToSend(
pickedImages = pickedImages,
audioClips = pickedAudioClips,
text = item,
)
)
pickedImages = listOf() pickedImages = listOf()
pickedAudioClips = listOf()
modelManagerViewModel.promoteTextInputHistoryItem(item) modelManagerViewModel.promoteTextInputHistoryItem(item)
}, },
onHistoryItemDeleted = { item -> modelManagerViewModel.deleteTextInputHistory(item) }, onHistoryItemDeleted = { item -> modelManagerViewModel.deleteTextInputHistory(item) },
@ -582,6 +729,43 @@ fun MessageInputText(
} }
} }
} }
if (showAudioRecorderBottomSheet) {
ModalBottomSheet(
sheetState = audioRecorderSheetState,
onDismissRequest = { showAudioRecorderBottomSheet = false },
) {
AudioRecorderPanel(
onSendAudioClip = { audioData ->
scope.launch {
updatePickedAudioClips(
listOf(AudioClip(audioData = audioData, sampleRate = SAMPLE_RATE))
)
audioRecorderSheetState.hide()
showAudioRecorderBottomSheet = false
}
}
)
}
}
}
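AudioRecorderPanel itself is defined elsewhere in this commit; the capture it wraps presumably boils down to an AudioRecord loop. A minimal sketch under the assumptions that SAMPLE_RATE is 16 kHz mono 16-bit PCM and that RECORD_AUDIO is already granted (the menu item above checks it before opening this sheet):

import android.annotation.SuppressLint
import android.media.AudioFormat
import android.media.AudioRecord
import android.media.MediaRecorder

// Hedged sketch, not the panel's actual implementation. SAMPLE_RATE is
// assumed to be 16_000 here; the real constant lives in
// com.google.ai.edge.gallery.data.
@SuppressLint("MissingPermission")
fun recordPcm(sampleRate: Int = 16_000, seconds: Int = 30): ByteArray {
  val minBuf = AudioRecord.getMinBufferSize(
    sampleRate, AudioFormat.CHANNEL_IN_MONO, AudioFormat.ENCODING_PCM_16BIT)
  val recorder = AudioRecord(
    MediaRecorder.AudioSource.MIC, sampleRate,
    AudioFormat.CHANNEL_IN_MONO, AudioFormat.ENCODING_PCM_16BIT, minBuf)
  val out = ByteArray(sampleRate * 2 * seconds) // 16-bit mono: 2 bytes/frame
  recorder.startRecording()
  var offset = 0
  while (offset < out.size) {
    val n = recorder.read(out, offset, out.size - offset)
    if (n <= 0) break // read error or end of stream
    offset += n
  }
  recorder.stop()
  recorder.release()
  return out.copyOf(offset)
}

A ByteArray like this is what onSendAudioClip receives above, paired with SAMPLE_RATE for WAV encoding.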
@Composable
private fun MediaPanelCloseButton(onClicked: () -> Unit) {
Box(
modifier =
Modifier.offset(x = 10.dp, y = (-10).dp)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.surface)
.border((1.5).dp, MaterialTheme.colorScheme.outline, CircleShape)
.clickable { onClicked() }
) {
Icon(
Icons.Rounded.Close,
contentDescription = "",
modifier = Modifier.padding(3.dp).size(16.dp),
)
}
} }
private fun handleImagesSelected( private fun handleImagesSelected(
@ -641,20 +825,50 @@ private fun checkFrontCamera(context: Context, callback: (Boolean) -> Unit) {
) )
} }
private fun createMessagesToSend(pickedImages: List<Bitmap>, text: String): List<ChatMessage> { private fun createMessagesToSend(
pickedImages: List<Bitmap>,
audioClips: List<AudioClip>,
text: String,
): List<ChatMessage> {
var messages: MutableList<ChatMessage> = mutableListOf() var messages: MutableList<ChatMessage> = mutableListOf()
// Add image messages.
var imageMessages: MutableList<ChatMessageImage> = mutableListOf()
if (pickedImages.isNotEmpty()) { if (pickedImages.isNotEmpty()) {
for (image in pickedImages) { for (image in pickedImages) {
messages.add( imageMessages.add(
ChatMessageImage(bitmap = image, imageBitMap = image.asImageBitmap(), side = ChatSide.USER) ChatMessageImage(bitmap = image, imageBitMap = image.asImageBitmap(), side = ChatSide.USER)
) )
} }
} }
// Cap the number of image messages. // Cap the number of image messages.
if (messages.size > MAX_IMAGE_COUNT) { if (imageMessages.size > MAX_IMAGE_COUNT) {
messages = messages.subList(fromIndex = 0, toIndex = MAX_IMAGE_COUNT) imageMessages = imageMessages.subList(fromIndex = 0, toIndex = MAX_IMAGE_COUNT)
}
messages.addAll(imageMessages)
// Add audio messages.
var audioMessages: MutableList<ChatMessageAudioClip> = mutableListOf()
if (audioClips.isNotEmpty()) {
for (audioClip in audioClips) {
audioMessages.add(
ChatMessageAudioClip(
audioData = audioClip.audioData,
sampleRate = audioClip.sampleRate,
side = ChatSide.USER,
)
)
}
}
// Cap the number of audio messages.
if (audioMessages.size > MAX_AUDIO_CLIP_COUNT) {
audioMessages = audioMessages.subList(fromIndex = 0, toIndex = MAX_AUDIO_CLIP_COUNT)
}
messages.addAll(audioMessages)
if (text.isNotEmpty()) {
messages.add(ChatMessageText(content = text, side = ChatSide.USER))
} }
messages.add(ChatMessageText(content = text, side = ChatSide.USER))
return messages return messages
} }


@ -121,6 +121,7 @@ private val IMPORT_CONFIGS_LLM: List<Config> =
valueType = ValueType.FLOAT, valueType = ValueType.FLOAT,
), ),
BooleanSwitchConfig(key = ConfigKey.SUPPORT_IMAGE, defaultValue = false), BooleanSwitchConfig(key = ConfigKey.SUPPORT_IMAGE, defaultValue = false),
BooleanSwitchConfig(key = ConfigKey.SUPPORT_AUDIO, defaultValue = false),
SegmentedButtonConfig( SegmentedButtonConfig(
key = ConfigKey.COMPATIBLE_ACCELERATORS, key = ConfigKey.COMPATIBLE_ACCELERATORS,
defaultValue = Accelerator.CPU.label, defaultValue = Accelerator.CPU.label,
@ -230,6 +231,12 @@ fun ModelImportDialog(uri: Uri, onDismiss: () -> Unit, onDone: (ImportedModel) -
valueType = ValueType.BOOLEAN, valueType = ValueType.BOOLEAN,
) )
as Boolean as Boolean
val supportAudio =
convertValueToTargetType(
value = values.get(ConfigKey.SUPPORT_AUDIO.label)!!,
valueType = ValueType.BOOLEAN,
)
as Boolean
val importedModel: ImportedModel = val importedModel: ImportedModel =
ImportedModel.newBuilder() ImportedModel.newBuilder()
.setFileName(fileName) .setFileName(fileName)
@ -242,6 +249,7 @@ fun ModelImportDialog(uri: Uri, onDismiss: () -> Unit, onDone: (ImportedModel) -
.setDefaultTopp(defaultTopp) .setDefaultTopp(defaultTopp)
.setDefaultTemperature(defaultTemperature) .setDefaultTemperature(defaultTemperature)
.setSupportImage(supportImage) .setSupportImage(supportImage)
.setSupportAudio(supportAudio)
.build() .build()
) )
.build() .build()


@ -173,7 +173,7 @@ fun SettingsDialog(
color = MaterialTheme.colorScheme.onSurfaceVariant, color = MaterialTheme.colorScheme.onSurfaceVariant,
) )
Text( Text(
"Expired at: ${dateFormatter.format(Instant.ofEpochMilli(curHfToken.expiresAtMs))}", "Expires at: ${dateFormatter.format(Instant.ofEpochMilli(curHfToken.expiresAtMs))}",
style = MaterialTheme.typography.bodyMedium, style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant, color = MaterialTheme.colorScheme.onSurfaceVariant,
) )


@ -1,78 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.icon
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.SolidColor
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.graphics.vector.path
import androidx.compose.ui.unit.dp
val Forum: ImageVector
get() {
if (_Forum != null) return _Forum!!
_Forum =
ImageVector.Builder(
name = "Forum",
defaultWidth = 24.dp,
defaultHeight = 24.dp,
viewportWidth = 960f,
viewportHeight = 960f,
)
.apply {
path(fill = SolidColor(Color(0xFF000000))) {
moveTo(280f, 720f)
quadToRelative(-17f, 0f, -28.5f, -11.5f)
reflectiveQuadTo(240f, 680f)
verticalLineToRelative(-80f)
horizontalLineToRelative(520f)
verticalLineToRelative(-360f)
horizontalLineToRelative(80f)
quadToRelative(17f, 0f, 28.5f, 11.5f)
reflectiveQuadTo(880f, 280f)
verticalLineToRelative(600f)
lineTo(720f, 720f)
close()
moveTo(80f, 680f)
verticalLineToRelative(-560f)
quadToRelative(0f, -17f, 11.5f, -28.5f)
reflectiveQuadTo(120f, 80f)
horizontalLineToRelative(520f)
quadToRelative(17f, 0f, 28.5f, 11.5f)
reflectiveQuadTo(680f, 120f)
verticalLineToRelative(360f)
quadToRelative(0f, 17f, -11.5f, 28.5f)
reflectiveQuadTo(640f, 520f)
horizontalLineTo(240f)
close()
moveToRelative(520f, -240f)
verticalLineToRelative(-280f)
horizontalLineTo(160f)
verticalLineToRelative(280f)
close()
moveToRelative(-440f, 0f)
verticalLineToRelative(-280f)
close()
}
}
.build()
return _Forum!!
}
private var _Forum: ImageVector? = null


@ -1,73 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.icon
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.SolidColor
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.graphics.vector.path
import androidx.compose.ui.unit.dp
val Mms: ImageVector
get() {
if (_Mms != null) return _Mms!!
_Mms =
ImageVector.Builder(
name = "Mms",
defaultWidth = 24.dp,
defaultHeight = 24.dp,
viewportWidth = 960f,
viewportHeight = 960f,
)
.apply {
path(fill = SolidColor(Color(0xFF000000))) {
moveTo(240f, 560f)
horizontalLineToRelative(480f)
lineTo(570f, 360f)
lineTo(450f, 520f)
lineToRelative(-90f, -120f)
close()
moveTo(80f, 880f)
verticalLineToRelative(-720f)
quadToRelative(0f, -33f, 23.5f, -56.5f)
reflectiveQuadTo(160f, 80f)
horizontalLineToRelative(640f)
quadToRelative(33f, 0f, 56.5f, 23.5f)
reflectiveQuadTo(880f, 160f)
verticalLineToRelative(480f)
quadToRelative(0f, 33f, -23.5f, 56.5f)
reflectiveQuadTo(800f, 720f)
horizontalLineTo(240f)
close()
moveToRelative(126f, -240f)
horizontalLineToRelative(594f)
verticalLineToRelative(-480f)
horizontalLineTo(160f)
verticalLineToRelative(525f)
close()
moveToRelative(-46f, 0f)
verticalLineToRelative(-480f)
close()
}
}
.build()
return _Mms!!
}
private var _Mms: ImageVector? = null


@ -1,87 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.icon
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.SolidColor
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.graphics.vector.path
import androidx.compose.ui.unit.dp
val Widgets: ImageVector
get() {
if (_Widgets != null) return _Widgets!!
_Widgets =
ImageVector.Builder(
name = "Widgets",
defaultWidth = 24.dp,
defaultHeight = 24.dp,
viewportWidth = 960f,
viewportHeight = 960f,
)
.apply {
path(fill = SolidColor(Color(0xFF000000))) {
moveTo(666f, 520f)
lineTo(440f, 294f)
lineToRelative(226f, -226f)
lineToRelative(226f, 226f)
close()
moveToRelative(-546f, -80f)
verticalLineToRelative(-320f)
horizontalLineToRelative(320f)
verticalLineToRelative(320f)
close()
moveToRelative(400f, 400f)
verticalLineToRelative(-320f)
horizontalLineToRelative(320f)
verticalLineToRelative(320f)
close()
moveToRelative(-400f, 0f)
verticalLineToRelative(-320f)
horizontalLineToRelative(320f)
verticalLineToRelative(320f)
close()
moveToRelative(80f, -480f)
horizontalLineToRelative(160f)
verticalLineToRelative(-160f)
horizontalLineTo(200f)
close()
moveToRelative(467f, 48f)
lineToRelative(113f, -113f)
lineToRelative(-113f, -113f)
lineToRelative(-113f, 113f)
close()
moveToRelative(-67f, 352f)
horizontalLineToRelative(160f)
verticalLineToRelative(-160f)
horizontalLineTo(600f)
close()
moveToRelative(-400f, 0f)
horizontalLineToRelative(160f)
verticalLineToRelative(-160f)
horizontalLineTo(200f)
close()
moveToRelative(400f, -160f)
}
}
.build()
return _Widgets!!
}
private var _Widgets: ImageVector? = null


@ -62,13 +62,13 @@ object LlmChatModelHelper {
Accelerator.GPU.label -> LlmInference.Backend.GPU Accelerator.GPU.label -> LlmInference.Backend.GPU
else -> LlmInference.Backend.GPU else -> LlmInference.Backend.GPU
} }
val options = val optionsBuilder =
LlmInference.LlmInferenceOptions.builder() LlmInference.LlmInferenceOptions.builder()
.setModelPath(model.getPath(context = context)) .setModelPath(model.getPath(context = context))
.setMaxTokens(maxTokens) .setMaxTokens(maxTokens)
.setPreferredBackend(preferredBackend) .setPreferredBackend(preferredBackend)
.setMaxNumImages(if (model.llmSupportImage) MAX_IMAGE_COUNT else 0) .setMaxNumImages(if (model.llmSupportImage) MAX_IMAGE_COUNT else 0)
.build() val options = optionsBuilder.build()
// Create an instance of the LLM Inference task and session. // Create an instance of the LLM Inference task and session.
try { try {
@ -82,7 +82,9 @@ object LlmChatModelHelper {
.setTopP(topP) .setTopP(topP)
.setTemperature(temperature) .setTemperature(temperature)
.setGraphOptions( .setGraphOptions(
GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build() GraphOptions.builder()
.setEnableVisionModality(model.llmSupportImage)
.build()
) )
.build(), .build(),
) )
@ -115,7 +117,9 @@ object LlmChatModelHelper {
.setTopP(topP) .setTopP(topP)
.setTemperature(temperature) .setTemperature(temperature)
.setGraphOptions( .setGraphOptions(
GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build() GraphOptions.builder()
.setEnableVisionModality(model.llmSupportImage)
.build()
) )
.build(), .build(),
) )
@ -159,6 +163,7 @@ object LlmChatModelHelper {
resultListener: ResultListener, resultListener: ResultListener,
cleanUpListener: CleanUpListener, cleanUpListener: CleanUpListener,
images: List<Bitmap> = listOf(), images: List<Bitmap> = listOf(),
audioClips: List<ByteArray> = listOf(),
) { ) {
val instance = model.instance as LlmModelInstance val instance = model.instance as LlmModelInstance
@ -172,10 +177,16 @@ object LlmChatModelHelper {
// For a model that supports image modality, we need to add the text query chunk before adding // For a model that supports image modality, we need to add the text query chunk before adding
// image. // image.
val session = instance.session val session = instance.session
session.addQueryChunk(input) if (input.trim().isNotEmpty()) {
session.addQueryChunk(input)
}
for (image in images) { for (image in images) {
session.addImage(BitmapImageBuilder(image).build()) session.addImage(BitmapImageBuilder(image).build())
} }
for (audioClip in audioClips) {
// Uncomment when audio is supported.
// session.addAudio(audioClip)
}
val unused = session.generateResponseAsync(resultListener) val unused = session.generateResponseAsync(resultListener)
} }
} }


@ -20,8 +20,7 @@ import android.graphics.Bitmap
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.lifecycle.viewmodel.compose.viewModel import com.google.ai.edge.gallery.ui.common.chat.ChatMessageAudioClip
import com.google.ai.edge.gallery.ui.ViewModelProvider
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
import com.google.ai.edge.gallery.ui.common.chat.ChatView import com.google.ai.edge.gallery.ui.common.chat.ChatView
@ -36,12 +35,16 @@ object LlmAskImageDestination {
val route = "LlmAskImageRoute" val route = "LlmAskImageRoute"
} }
object LlmAskAudioDestination {
val route = "LlmAskAudioRoute"
}
@Composable @Composable
fun LlmChatScreen( fun LlmChatScreen(
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit, navigateUp: () -> Unit,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
viewModel: LlmChatViewModel = viewModel(factory = ViewModelProvider.Factory), viewModel: LlmChatViewModel,
) { ) {
ChatViewWrapper( ChatViewWrapper(
viewModel = viewModel, viewModel = viewModel,
@ -56,7 +59,22 @@ fun LlmAskImageScreen(
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit, navigateUp: () -> Unit,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
viewModel: LlmAskImageViewModel = viewModel(factory = ViewModelProvider.Factory), viewModel: LlmAskImageViewModel,
) {
ChatViewWrapper(
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
navigateUp = navigateUp,
modifier = modifier,
)
}
@Composable
fun LlmAskAudioScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
viewModel: LlmAskAudioViewModel,
) { ) {
ChatViewWrapper( ChatViewWrapper(
viewModel = viewModel, viewModel = viewModel,
@ -68,7 +86,7 @@ fun LlmAskImageScreen(
@Composable @Composable
fun ChatViewWrapper( fun ChatViewWrapper(
viewModel: LlmChatViewModel, viewModel: LlmChatViewModelBase,
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit, navigateUp: () -> Unit,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
@ -86,6 +104,7 @@ fun ChatViewWrapper(
var text = "" var text = ""
val images: MutableList<Bitmap> = mutableListOf() val images: MutableList<Bitmap> = mutableListOf()
val audioMessages: MutableList<ChatMessageAudioClip> = mutableListOf()
var chatMessageText: ChatMessageText? = null var chatMessageText: ChatMessageText? = null
for (message in messages) { for (message in messages) {
if (message is ChatMessageText) { if (message is ChatMessageText) {
@ -93,14 +112,17 @@ fun ChatViewWrapper(
text = message.content text = message.content
} else if (message is ChatMessageImage) { } else if (message is ChatMessageImage) {
images.add(message.bitmap) images.add(message.bitmap)
} else if (message is ChatMessageAudioClip) {
audioMessages.add(message)
} }
} }
if (text.isNotEmpty() && chatMessageText != null) { if ((text.isNotEmpty() && chatMessageText != null) || audioMessages.isNotEmpty()) {
modelManagerViewModel.addTextInputHistory(text) modelManagerViewModel.addTextInputHistory(text)
viewModel.generateResponse( viewModel.generateResponse(
model = model, model = model,
input = text, input = text,
images = images, images = images,
audioMessages = audioMessages,
onError = { onError = {
viewModel.handleError( viewModel.handleError(
context = context, context = context,


@ -22,9 +22,11 @@ import android.util.Log
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import com.google.ai.edge.gallery.data.ConfigKey import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_AUDIO
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.ai.edge.gallery.data.TASK_LLM_CHAT import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
import com.google.ai.edge.gallery.data.Task import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageAudioClip
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageLoading import com.google.ai.edge.gallery.ui.common.chat.ChatMessageLoading
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
@ -34,6 +36,8 @@ import com.google.ai.edge.gallery.ui.common.chat.ChatSide
import com.google.ai.edge.gallery.ui.common.chat.ChatViewModel import com.google.ai.edge.gallery.ui.common.chat.ChatViewModel
import com.google.ai.edge.gallery.ui.common.chat.Stat import com.google.ai.edge.gallery.ui.common.chat.Stat
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import dagger.hilt.android.lifecycle.HiltViewModel
import javax.inject.Inject
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
@ -47,11 +51,12 @@ private val STATS =
Stat(id = "latency", label = "Latency", unit = "sec"), Stat(id = "latency", label = "Latency", unit = "sec"),
) )
open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task = curTask) { open class LlmChatViewModelBase(val curTask: Task) : ChatViewModel(task = curTask) {
fun generateResponse( fun generateResponse(
model: Model, model: Model,
input: String, input: String,
images: List<Bitmap> = listOf(), images: List<Bitmap> = listOf(),
audioMessages: List<ChatMessageAudioClip> = listOf(),
onError: () -> Unit, onError: () -> Unit,
) { ) {
val accelerator = model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = "") val accelerator = model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = "")
@ -72,6 +77,11 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
val instance = model.instance as LlmModelInstance val instance = model.instance as LlmModelInstance
var prefillTokens = instance.session.sizeInTokens(input) var prefillTokens = instance.session.sizeInTokens(input)
prefillTokens += images.size * 257 prefillTokens += images.size * 257
for (audioMessage in audioMessages) {
// 150ms = 1 audio token
val duration = audioMessage.getDurationInSeconds()
prefillTokens += (duration * 1000f / 150f).toInt()
}
var firstRun = true var firstRun = true
var timeToFirstToken = 0f var timeToFirstToken = 0f
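The prefill estimate above charges 257 tokens per image and one audio token per 150 ms of audio, so a maximal 30-second clip adds 30 000 / 150 = 200 tokens. A compact restatement (hypothetical helper, not in this commit):

// Rates taken from the code above: 257 tokens per image,
// one audio token per 150 ms of audio.
fun estimatePrefillTokens(textTokens: Int, imageCount: Int, audioSeconds: Float): Int =
  textTokens + imageCount * 257 + (audioSeconds * 1000f / 150f).toInt()

// estimatePrefillTokens(textTokens = 12, imageCount = 1, audioSeconds = 30f)
//   == 12 + 257 + 200 == 469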
@ -86,6 +96,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
model = model, model = model,
input = input, input = input,
images = images, images = images,
audioClips = audioMessages.map { it.genByteArrayForWav() },
resultListener = { partialResult, done -> resultListener = { partialResult, done ->
val curTs = System.currentTimeMillis() val curTs = System.currentTimeMillis()
@ -214,7 +225,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
context: Context, context: Context,
model: Model, model: Model,
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
triggeredMessage: ChatMessageText, triggeredMessage: ChatMessageText?,
) { ) {
// Clean up. // Clean up.
modelManagerViewModel.cleanupModel(task = task, model = model) modelManagerViewModel.cleanupModel(task = task, model = model)
@ -236,14 +247,27 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
) )
// Add the triggered message back. // Add the triggered message back.
addMessage(model = model, message = triggeredMessage) if (triggeredMessage != null) {
addMessage(model = model, message = triggeredMessage)
}
// Re-initialize the session/engine. // Re-initialize the session/engine.
modelManagerViewModel.initializeModel(context = context, task = task, model = model) modelManagerViewModel.initializeModel(context = context, task = task, model = model)
// Re-generate the response automatically. // Re-generate the response automatically.
generateResponse(model = model, input = triggeredMessage.content, onError = {}) if (triggeredMessage != null) {
generateResponse(model = model, input = triggeredMessage.content, onError = {})
}
} }
} }
class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE) @HiltViewModel
class LlmChatViewModel @Inject constructor() : LlmChatViewModelBase(curTask = TASK_LLM_CHAT)
@HiltViewModel
class LlmAskImageViewModel @Inject constructor() :
LlmChatViewModelBase(curTask = TASK_LLM_ASK_IMAGE)
@HiltViewModel
class LlmAskAudioViewModel @Inject constructor() :
LlmChatViewModelBase(curTask = TASK_LLM_ASK_AUDIO)


@ -43,9 +43,8 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha import androidx.compose.ui.draw.alpha
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalLayoutDirection import androidx.compose.ui.platform.LocalLayoutDirection
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.ai.edge.gallery.data.ModelDownloadStatusType import com.google.ai.edge.gallery.data.ModelDownloadStatusType
import com.google.ai.edge.gallery.ui.ViewModelProvider import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB
import com.google.ai.edge.gallery.ui.common.ErrorDialog import com.google.ai.edge.gallery.ui.common.ErrorDialog
import com.google.ai.edge.gallery.ui.common.ModelPageAppBar import com.google.ai.edge.gallery.ui.common.ModelPageAppBar
import com.google.ai.edge.gallery.ui.common.chat.ModelDownloadStatusInfoPanel import com.google.ai.edge.gallery.ui.common.chat.ModelDownloadStatusInfoPanel
@ -67,9 +66,9 @@ fun LlmSingleTurnScreen(
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit, navigateUp: () -> Unit,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
viewModel: LlmSingleTurnViewModel = viewModel(factory = ViewModelProvider.Factory), viewModel: LlmSingleTurnViewModel,
) { ) {
val task = viewModel.task val task = TASK_LLM_PROMPT_LAB
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val uiState by viewModel.uiState.collectAsState() val uiState by viewModel.uiState.collectAsState()
val selectedModel = modelManagerUiState.selectedModel val selectedModel = modelManagerUiState.selectedModel


@ -27,6 +27,8 @@ import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.ai.edge.gallery.ui.common.chat.Stat import com.google.ai.edge.gallery.ui.common.chat.Stat
import com.google.ai.edge.gallery.ui.llmchat.LlmChatModelHelper import com.google.ai.edge.gallery.ui.llmchat.LlmChatModelHelper
import com.google.ai.edge.gallery.ui.llmchat.LlmModelInstance import com.google.ai.edge.gallery.ui.llmchat.LlmModelInstance
import dagger.hilt.android.lifecycle.HiltViewModel
import javax.inject.Inject
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
@ -63,8 +65,9 @@ private val STATS =
Stat(id = "latency", label = "Latency", unit = "sec"), Stat(id = "latency", label = "Latency", unit = "sec"),
) )
open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewModel() { @HiltViewModel
private val _uiState = MutableStateFlow(createUiState(task = task)) class LlmSingleTurnViewModel @Inject constructor() : ViewModel() {
private val _uiState = MutableStateFlow(createUiState(task = TASK_LLM_PROMPT_LAB))
val uiState = _uiState.asStateFlow() val uiState = _uiState.asStateFlow()
fun generateResponse(model: Model, input: String) { fun generateResponse(model: Model, input: String) {


@ -36,6 +36,7 @@ import com.google.ai.edge.gallery.data.ModelAllowlist
import com.google.ai.edge.gallery.data.ModelDownloadStatus import com.google.ai.edge.gallery.data.ModelDownloadStatus
import com.google.ai.edge.gallery.data.ModelDownloadStatusType import com.google.ai.edge.gallery.data.ModelDownloadStatusType
import com.google.ai.edge.gallery.data.TASKS import com.google.ai.edge.gallery.data.TASKS
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_AUDIO
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.ai.edge.gallery.data.TASK_LLM_CHAT import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB
@ -51,9 +52,12 @@ import com.google.ai.edge.gallery.ui.common.AuthConfig
import com.google.ai.edge.gallery.ui.llmchat.LlmChatModelHelper import com.google.ai.edge.gallery.ui.llmchat.LlmChatModelHelper
import com.google.gson.Gson import com.google.gson.Gson
import com.google.gson.reflect.TypeToken import com.google.gson.reflect.TypeToken
import dagger.hilt.android.lifecycle.HiltViewModel
import dagger.hilt.android.qualifiers.ApplicationContext
import java.io.File import java.io.File
import java.net.HttpURLConnection import java.net.HttpURLConnection
import java.net.URL import java.net.URL
import javax.inject.Inject
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
@ -133,11 +137,14 @@ data class PagerScrollState(val page: Int = 0, val offset: Float = 0f)
* cleaning up models. It also manages the UI state for model management, including the list of * cleaning up models. It also manages the UI state for model management, including the list of
* tasks, models, download statuses, and initialization statuses. * tasks, models, download statuses, and initialization statuses.
*/ */
open class ModelManagerViewModel( @HiltViewModel
open class ModelManagerViewModel
@Inject
constructor(
private val downloadRepository: DownloadRepository, private val downloadRepository: DownloadRepository,
private val dataStoreRepository: DataStoreRepository, private val dataStoreRepository: DataStoreRepository,
private val lifecycleProvider: AppLifecycleProvider, private val lifecycleProvider: AppLifecycleProvider,
context: Context, @ApplicationContext private val context: Context,
) : ViewModel() { ) : ViewModel() {
private val externalFilesDir = context.getExternalFilesDir(null) private val externalFilesDir = context.getExternalFilesDir(null)
private val inProgressWorkInfos: List<AGWorkInfo> = private val inProgressWorkInfos: List<AGWorkInfo> =
@ -281,15 +288,12 @@ open class ModelManagerViewModel(
} }
} }
when (task.type) { when (task.type) {
TaskType.LLM_CHAT -> TaskType.LLM_CHAT,
LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone) TaskType.LLM_ASK_IMAGE,
TaskType.LLM_ASK_AUDIO,
TaskType.LLM_PROMPT_LAB -> TaskType.LLM_PROMPT_LAB ->
LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone) LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone)
TaskType.LLM_ASK_IMAGE ->
LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone)
TaskType.TEST_TASK_1 -> {} TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {} TaskType.TEST_TASK_2 -> {}
} }
@ -301,9 +305,11 @@ open class ModelManagerViewModel(
model.cleanUpAfterInit = false model.cleanUpAfterInit = false
Log.d(TAG, "Cleaning up model '${model.name}'...") Log.d(TAG, "Cleaning up model '${model.name}'...")
when (task.type) { when (task.type) {
TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model) TaskType.LLM_CHAT,
TaskType.LLM_PROMPT_LAB -> LlmChatModelHelper.cleanUp(model = model) TaskType.LLM_PROMPT_LAB,
TaskType.LLM_ASK_IMAGE -> LlmChatModelHelper.cleanUp(model = model) TaskType.LLM_ASK_IMAGE,
TaskType.LLM_ASK_AUDIO -> LlmChatModelHelper.cleanUp(model = model)
TaskType.TEST_TASK_1 -> {} TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {} TaskType.TEST_TASK_2 -> {}
} }
@ -410,14 +416,19 @@ open class ModelManagerViewModel(
// Create model. // Create model.
val model = createModelFromImportedModelInfo(info = info) val model = createModelFromImportedModelInfo(info = info)
for (task in listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)) { for (task in
listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_ASK_AUDIO, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)) {
// Remove duplicated imported model if existed. // Remove duplicated imported model if existed.
val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported } val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported }
if (modelIndex >= 0) { if (modelIndex >= 0) {
Log.d(TAG, "duplicated imported model found in task. Removing it first") Log.d(TAG, "duplicated imported model found in task. Removing it first")
task.models.removeAt(modelIndex) task.models.removeAt(modelIndex)
} }
if ((task == TASK_LLM_ASK_IMAGE && model.llmSupportImage) || task != TASK_LLM_ASK_IMAGE) { if (
(task == TASK_LLM_ASK_IMAGE && model.llmSupportImage) ||
(task == TASK_LLM_ASK_AUDIO && model.llmSupportAudio) ||
(task != TASK_LLM_ASK_IMAGE && task != TASK_LLM_ASK_AUDIO)
) {
task.models.add(model) task.models.add(model)
} }
task.updateTrigger.value = System.currentTimeMillis() task.updateTrigger.value = System.currentTimeMillis()
@ -657,6 +668,7 @@ open class ModelManagerViewModel(
TASK_LLM_CHAT.models.clear() TASK_LLM_CHAT.models.clear()
TASK_LLM_PROMPT_LAB.models.clear() TASK_LLM_PROMPT_LAB.models.clear()
TASK_LLM_ASK_IMAGE.models.clear() TASK_LLM_ASK_IMAGE.models.clear()
TASK_LLM_ASK_AUDIO.models.clear()
for (allowedModel in modelAllowlist.models) { for (allowedModel in modelAllowlist.models) {
if (allowedModel.disabled == true) { if (allowedModel.disabled == true) {
continue continue
@ -672,6 +684,9 @@ open class ModelManagerViewModel(
if (allowedModel.taskTypes.contains(TASK_LLM_ASK_IMAGE.type.id)) { if (allowedModel.taskTypes.contains(TASK_LLM_ASK_IMAGE.type.id)) {
TASK_LLM_ASK_IMAGE.models.add(model) TASK_LLM_ASK_IMAGE.models.add(model)
} }
if (allowedModel.taskTypes.contains(TASK_LLM_ASK_AUDIO.type.id)) {
TASK_LLM_ASK_AUDIO.models.add(model)
}
} }
// Pre-process all tasks. // Pre-process all tasks.
@ -760,6 +775,9 @@ open class ModelManagerViewModel(
if (model.llmSupportImage) { if (model.llmSupportImage) {
TASK_LLM_ASK_IMAGE.models.add(model) TASK_LLM_ASK_IMAGE.models.add(model)
} }
if (model.llmSupportAudio) {
TASK_LLM_ASK_AUDIO.models.add(model)
}
// Update status. // Update status.
modelDownloadStatus[model.name] = modelDownloadStatus[model.name] =
@ -800,6 +818,7 @@ open class ModelManagerViewModel(
accelerators = accelerators, accelerators = accelerators,
) )
val llmSupportImage = info.llmConfig.supportImage val llmSupportImage = info.llmConfig.supportImage
val llmSupportAudio = info.llmConfig.supportAudio
val model = val model =
Model( Model(
name = info.fileName, name = info.fileName,
@ -811,6 +830,7 @@ open class ModelManagerViewModel(
showRunAgainButton = false, showRunAgainButton = false,
imported = true, imported = true,
llmSupportImage = llmSupportImage, llmSupportImage = llmSupportImage,
llmSupportAudio = llmSupportAudio,
) )
model.preProcess() model.preProcess()


@ -36,10 +36,10 @@ import androidx.compose.runtime.setValue
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.IntOffset import androidx.compose.ui.unit.IntOffset
import androidx.compose.ui.zIndex import androidx.compose.ui.zIndex
import androidx.hilt.navigation.compose.hiltViewModel
import androidx.lifecycle.Lifecycle import androidx.lifecycle.Lifecycle
import androidx.lifecycle.LifecycleEventObserver import androidx.lifecycle.LifecycleEventObserver
import androidx.lifecycle.compose.LocalLifecycleOwner import androidx.lifecycle.compose.LocalLifecycleOwner
import androidx.lifecycle.viewmodel.compose.viewModel
import androidx.navigation.NavBackStackEntry import androidx.navigation.NavBackStackEntry
import androidx.navigation.NavHostController import androidx.navigation.NavHostController
import androidx.navigation.NavType import androidx.navigation.NavType
@ -47,20 +47,26 @@ import androidx.navigation.compose.NavHost
import androidx.navigation.compose.composable import androidx.navigation.compose.composable
import androidx.navigation.navArgument import androidx.navigation.navArgument
import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_AUDIO
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.ai.edge.gallery.data.TASK_LLM_CHAT import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB
import com.google.ai.edge.gallery.data.Task import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.data.TaskType import com.google.ai.edge.gallery.data.TaskType
import com.google.ai.edge.gallery.data.getModelByName import com.google.ai.edge.gallery.data.getModelByName
import com.google.ai.edge.gallery.ui.ViewModelProvider
import com.google.ai.edge.gallery.ui.home.HomeScreen import com.google.ai.edge.gallery.ui.home.HomeScreen
import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioDestination
import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioScreen
import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageDestination import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageDestination
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageScreen import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageScreen
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmChatDestination import com.google.ai.edge.gallery.ui.llmchat.LlmChatDestination
import com.google.ai.edge.gallery.ui.llmchat.LlmChatScreen import com.google.ai.edge.gallery.ui.llmchat.LlmChatScreen
import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnDestination import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnDestination
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnScreen import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnScreen
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
import com.google.ai.edge.gallery.ui.modelmanager.ModelManager import com.google.ai.edge.gallery.ui.modelmanager.ModelManager
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
@ -104,7 +110,7 @@ private fun AnimatedContentTransitionScope<*>.slideExit(): ExitTransition {
fun GalleryNavHost( fun GalleryNavHost(
navController: NavHostController, navController: NavHostController,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
modelManagerViewModel: ModelManagerViewModel = viewModel(factory = ViewModelProvider.Factory), modelManagerViewModel: ModelManagerViewModel = hiltViewModel(),
) { ) {
val lifecycleOwner = LocalLifecycleOwner.current val lifecycleOwner = LocalLifecycleOwner.current
var showModelManager by remember { mutableStateOf(false) } var showModelManager by remember { mutableStateOf(false) }
@ -181,11 +187,14 @@ fun GalleryNavHost(
arguments = listOf(navArgument("modelName") { type = NavType.StringType }), arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() }, enterTransition = { slideEnter() },
exitTransition = { slideExit() }, exitTransition = { slideExit() },
) { ) { backStackEntry ->
getModelFromNavigationParam(it, TASK_LLM_CHAT)?.let { defaultModel -> val viewModel: LlmChatViewModel = hiltViewModel(backStackEntry)
getModelFromNavigationParam(backStackEntry, TASK_LLM_CHAT)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel) modelManagerViewModel.selectModel(defaultModel)
LlmChatScreen( LlmChatScreen(
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() }, navigateUp = { navController.navigateUp() },
) )
@ -198,28 +207,54 @@ fun GalleryNavHost(
arguments = listOf(navArgument("modelName") { type = NavType.StringType }), arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() }, enterTransition = { slideEnter() },
exitTransition = { slideExit() }, exitTransition = { slideExit() },
) { ) { backStackEntry ->
getModelFromNavigationParam(it, TASK_LLM_PROMPT_LAB)?.let { defaultModel -> val viewModel: LlmSingleTurnViewModel = hiltViewModel(backStackEntry)
getModelFromNavigationParam(backStackEntry, TASK_LLM_PROMPT_LAB)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel) modelManagerViewModel.selectModel(defaultModel)
LlmSingleTurnScreen( LlmSingleTurnScreen(
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() }, navigateUp = { navController.navigateUp() },
) )
} }
} }
// LLM image to text. // Ask image.
composable( composable(
route = "${LlmAskImageDestination.route}/{modelName}", route = "${LlmAskImageDestination.route}/{modelName}",
arguments = listOf(navArgument("modelName") { type = NavType.StringType }), arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() }, enterTransition = { slideEnter() },
exitTransition = { slideExit() }, exitTransition = { slideExit() },
) { ) { backStackEntry ->
getModelFromNavigationParam(it, TASK_LLM_ASK_IMAGE)?.let { defaultModel -> val viewModel: LlmAskImageViewModel = hiltViewModel()
getModelFromNavigationParam(backStackEntry, TASK_LLM_ASK_IMAGE)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel) modelManagerViewModel.selectModel(defaultModel)
LlmAskImageScreen( LlmAskImageScreen(
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() },
)
}
}
// Ask audio.
composable(
route = "${LlmAskAudioDestination.route}/{modelName}",
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() },
exitTransition = { slideExit() },
) { backStackEntry ->
val viewModel: LlmAskAudioViewModel = hiltViewModel()
getModelFromNavigationParam(backStackEntry, TASK_LLM_ASK_AUDIO)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel)
LlmAskAudioScreen(
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() }, navigateUp = { navController.navigateUp() },
) )
@ -256,6 +291,7 @@ fun navigateToTaskScreen(
   when (taskType) {
     TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
     TaskType.LLM_ASK_IMAGE -> navController.navigate("${LlmAskImageDestination.route}/${modelName}")
+    TaskType.LLM_ASK_AUDIO -> navController.navigate("${LlmAskAudioDestination.route}/${modelName}")
     TaskType.LLM_PROMPT_LAB ->
       navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
     TaskType.TEST_TASK_1 -> {}
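The `hiltViewModel()` calls introduced in this file assume each screen ViewModel is Hilt-injectable. A minimal sketch of that pattern, with an illustrative (not actual) class body:

import androidx.lifecycle.ViewModel
import dagger.hilt.android.lifecycle.HiltViewModel
import javax.inject.Inject

// @HiltViewModel lets hiltViewModel() resolve this class scoped to the
// nav back stack entry it is called with.
@HiltViewModel
class LlmAskAudioViewModel @Inject constructor() : ViewModel() {
  // Audio recording and inference state would live here.
}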

View file

@ -1,91 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.preview
import android.content.Context
import android.graphics.Bitmap
import android.graphics.Canvas
import android.graphics.drawable.Drawable
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.asImageBitmap
import androidx.core.content.ContextCompat
import androidx.core.graphics.createBitmap
import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.common.Classification
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageClassification
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
import com.google.ai.edge.gallery.ui.common.chat.ChatSide
import com.google.ai.edge.gallery.ui.common.chat.ChatViewModel
class PreviewChatModel(context: Context) : ChatViewModel(task = TASK_TEST1) {
init {
val model = task.models[1]
addMessage(
model = model,
message =
ChatMessageText(
content =
"Thanks everyone for your enthusiasm on the team lunch, but people who can sign on the cheque is OOO next week \uD83D\uDE02,",
side = ChatSide.USER,
),
)
addMessage(
model = model,
message =
ChatMessageText(content = "Today is Wednesday!", side = ChatSide.AGENT, latencyMs = 1232f),
)
addMessage(
model = model,
message =
ChatMessageClassification(
classifications =
listOf(
Classification(label = "label1", score = 0.3f, color = Color.Red),
Classification(label = "label2", score = 0.7f, color = Color.Blue),
),
latencyMs = 12345f,
),
)
val bitmap =
getBitmapFromVectorDrawable(
context = context,
drawableId = R.drawable.ic_launcher_background,
)!!
addMessage(
model = model,
message =
ChatMessageImage(
bitmap = bitmap,
imageBitMap = bitmap.asImageBitmap(),
side = ChatSide.USER,
),
)
}
private fun getBitmapFromVectorDrawable(context: Context, drawableId: Int): Bitmap? {
val drawable: Drawable =
ContextCompat.getDrawable(context, drawableId) ?: return null // Drawable not found
val bitmap = createBitmap(drawable.intrinsicWidth, drawable.intrinsicHeight)
val canvas = Canvas(bitmap)
drawable.setBounds(0, 0, canvas.width, canvas.height)
drawable.draw(canvas)
return bitmap
}
}

View file

@ -1,57 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.preview
// TODO(migration)
//
// import com.google.ai.edge.gallery.data.AccessTokenData
// import com.google.ai.edge.gallery.data.DataStoreRepository
// import com.google.ai.edge.gallery.data.ImportedModelInfo
// class PreviewDataStoreRepository : DataStoreRepository
class PreviewDataStoreRepository {
// override fun saveTextInputHistory(history: List<String>) {
// }
// override fun readTextInputHistory(): List<String> {
// return listOf()
// }
// override fun saveThemeOverride(theme: String) {
// }
// override fun readThemeOverride(): String {
// return ""
// }
// override fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long) {
// }
// override fun readAccessTokenData(): AccessTokenData? {
// return null
// }
// override fun clearAccessTokenData() {
// }
// override fun saveImportedModels(importedModels: List<ImportedModelInfo>) {
// }
// override fun readImportedModels(): List<ImportedModelInfo> {
// return listOf()
// }
}

View file

@ -1,44 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.preview
import com.google.ai.edge.gallery.data.AGWorkInfo
import com.google.ai.edge.gallery.data.DownloadRepository
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.ModelDownloadStatus
import java.util.UUID
class PreviewDownloadRepository : DownloadRepository {
override fun downloadModel(
model: Model,
onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit,
) {}
override fun cancelDownloadModel(model: Model) {}
override fun cancelAll(models: List<Model>, onComplete: () -> Unit) {}
override fun observerWorkerProgress(
workerId: UUID,
model: Model,
onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit,
) {}
override fun getEnqueuedOrRunningWorkInfos(): List<AGWorkInfo> {
return listOf()
}
}

View file

@ -1,63 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.preview
class PreviewModelManagerViewModel {}
// class PreviewModelManagerViewModel(context: Context) :
// ModelManagerViewModel(
// downloadRepository = PreviewDownloadRepository(),
// // dataStoreRepository = PreviewDataStoreRepository(),
// context = context,
// ) {
// init {
// for ((index, task) in ALL_PREVIEW_TASKS.withIndex()) {
// task.index = index
// for (model in task.models) {
// model.preProcess()
// }
// }
// val modelDownloadStatus =
// mapOf(
// MODEL_TEST1.name to
// ModelDownloadStatus(
// status = ModelDownloadStatusType.IN_PROGRESS,
// receivedBytes = 1234,
// totalBytes = 3456,
// bytesPerSecond = 2333,
// remainingMs = 324,
// ),
// MODEL_TEST2.name to ModelDownloadStatus(status = ModelDownloadStatusType.SUCCEEDED),
// MODEL_TEST3.name to
// ModelDownloadStatus(
// status = ModelDownloadStatusType.FAILED,
// errorMessage = "Http code 404",
// ),
// MODEL_TEST4.name to ModelDownloadStatus(status = ModelDownloadStatusType.NOT_DOWNLOADED),
// )
// val newUiState =
// ModelManagerUiState(
// tasks = ALL_PREVIEW_TASKS,
// modelDownloadStatus = modelDownloadStatus,
// modelInitializationStatus = mapOf(),
// selectedModel = MODEL_TEST2,
// )
// _uiState.update { newUiState }
// }
// }

View file

@ -1,103 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.preview
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.AccountBox
import androidx.compose.material.icons.rounded.AutoAwesome
import com.google.ai.edge.gallery.data.BooleanSwitchConfig
import com.google.ai.edge.gallery.data.Config
import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.LabelConfig
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.NumberSliderConfig
import com.google.ai.edge.gallery.data.SegmentedButtonConfig
import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.data.TaskType
import com.google.ai.edge.gallery.data.ValueType
val TEST_CONFIGS1: List<Config> =
listOf(
LabelConfig(key = ConfigKey.NAME, defaultValue = "Test name"),
NumberSliderConfig(
key = ConfigKey.MAX_RESULT_COUNT,
sliderMin = 1f,
sliderMax = 5f,
defaultValue = 3f,
valueType = ValueType.INT,
),
BooleanSwitchConfig(key = ConfigKey.USE_GPU, defaultValue = false),
SegmentedButtonConfig(
key = ConfigKey.THEME,
defaultValue = "Auto",
options = listOf("Auto", "Light", "Dark"),
),
)
val MODEL_TEST1: Model =
Model(
name = "deterministic3",
downloadFileName = "deterministric3.json",
url = "https://storage.googleapis.com/tfweb/model-graph-vis-v2-test-models/deterministic3.json",
sizeInBytes = 40146048L,
configs = TEST_CONFIGS1,
)
val MODEL_TEST2: Model =
Model(
name = "isnet",
downloadFileName = "isnet.tflite",
url =
"https://storage.googleapis.com/tfweb/model-graph-vis-v2-test-models/isnet-general-use-int8.tflite",
sizeInBytes = 44366296L,
configs = TEST_CONFIGS1,
)
val MODEL_TEST3: Model =
Model(
name = "yolo",
downloadFileName = "yolo.json",
url = "https://storage.googleapis.com/tfweb/model-graph-vis-v2-test-models/yolo.json",
sizeInBytes = 40641364L,
)
val MODEL_TEST4: Model =
Model(
name = "mobilenet v3",
downloadFileName = "mobilenet_v3_large.pt2",
url =
"https://storage.googleapis.com/tfweb/model-graph-vis-v2-test-models/mobilenet_v3_large.pt2",
sizeInBytes = 277135998L,
)
val TASK_TEST1 =
Task(
type = TaskType.TEST_TASK_1,
icon = Icons.Rounded.AutoAwesome,
models = mutableListOf(MODEL_TEST1, MODEL_TEST2),
description = "This is a test task (1)",
)
val TASK_TEST2 =
Task(
type = TaskType.TEST_TASK_2,
icon = Icons.Rounded.AccountBox,
models = mutableListOf(MODEL_TEST3, MODEL_TEST4),
description = "This is a test task (2)",
)
val ALL_PREVIEW_TASKS: List<Task> = listOf(TASK_TEST1, TASK_TEST2)

View file

@ -120,6 +120,8 @@ data class CustomColors(
   val agentBubbleBgColor: Color = Color.Transparent,
   val linkColor: Color = Color.Transparent,
   val successColor: Color = Color.Transparent,
+  val recordButtonBgColor: Color = Color.Transparent,
+  val waveFormBgColor: Color = Color.Transparent,
 )

 val LocalCustomColors = staticCompositionLocalOf { CustomColors() }

@ -145,6 +147,8 @@ val lightCustomColors =
     userBubbleBgColor = Color(0xFF32628D),
     linkColor = Color(0xFF32628D),
     successColor = Color(0xff3d860b),
+    recordButtonBgColor = Color(0xFFEE675C),
+    waveFormBgColor = Color(0xFFaaaaaa),
   )

 val darkCustomColors =

@ -168,6 +172,8 @@ val darkCustomColors =
     userBubbleBgColor = Color(0xFF1f3760),
     linkColor = Color(0xFF9DCAFC),
     successColor = Color(0xFFA1CE83),
+    recordButtonBgColor = Color(0xFFEE675C),
+    waveFormBgColor = Color(0xFFaaaaaa),
   )

 val MaterialTheme.customColors: CustomColors
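Because the two new colors ride on the existing `MaterialTheme.customColors` extension, any composable can read them like the other custom colors. A minimal usage sketch (the composable itself is an illustrative assumption):

import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Box
import androidx.compose.material3.MaterialTheme
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier

// Hypothetical example: tint a record control with the new theme color.
// Assumes the MaterialTheme.customColors extension above is in scope.
@Composable
fun RecordButtonSurface(content: @Composable () -> Unit) {
  Box(Modifier.background(MaterialTheme.customColors.recordButtonBgColor)) { content() }
}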

View file

@ -55,6 +55,7 @@ message LlmConfig {
   float default_topp = 4;
   float default_temperature = 5;
   bool support_image = 6;
+  bool support_audio = 7;
 }

 message Settings {
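With protobuf-javalite codegen, the new field becomes a boolean accessor on the generated `LlmConfig` class. A sketch of how a caller might branch on it (standard javalite builder/getter naming assumed):

// Hypothetical usage of the generated accessor for support_audio.
val config = LlmConfig.newBuilder()
  .setSupportAudio(true)
  .build()
if (config.supportAudio) {
  // Surface the Ask Audio task for this model.
}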

View file

@ -0,0 +1,107 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.data
import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse
import org.junit.Assert.assertTrue
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
@RunWith(JUnit4::class)
class ModelAllowlistTest {
@Test
fun toModel_success() {
val modelName = "test_model"
val modelId = "test_model_id"
val modelFile = "test_model_file"
val description = "test description"
val sizeInBytes = 100L
val version = "20250623"
val topK = 10
val topP = 0.5f
val temperature = 0.1f
val maxTokens = 1000
val accelerators = "gpu,cpu"
val taskTypes = listOf("llm_chat", "ask_image")
val estimatedPeakMemoryInBytes = 300L
val allowedModel =
AllowedModel(
name = modelName,
modelId = modelId,
modelFile = modelFile,
description = description,
sizeInBytes = sizeInBytes,
version = version,
defaultConfig =
DefaultConfig(
topK = topK,
topP = topP,
temperature = temperature,
maxTokens = maxTokens,
accelerators = accelerators,
),
taskTypes = taskTypes,
llmSupportImage = true,
llmSupportAudio = true,
estimatedPeakMemoryInBytes = estimatedPeakMemoryInBytes,
)
val model = allowedModel.toModel()
// Check that basic fields are set correctly.
assertEquals(model.name, modelName)
assertEquals(model.version, version)
assertEquals(model.info, description)
assertEquals(
model.url,
"https://huggingface.co/test_model_id/resolve/main/test_model_file?download=true",
)
assertEquals(model.sizeInBytes, sizeInBytes)
assertEquals(model.estimatedPeakMemoryInBytes, estimatedPeakMemoryInBytes)
assertEquals(model.downloadFileName, modelFile)
assertFalse(model.showBenchmarkButton)
assertFalse(model.showRunAgainButton)
assertTrue(model.llmSupportImage)
assertTrue(model.llmSupportAudio)
// Check that configs are set correctly.
assertEquals(model.configs.size, 5)
// A label for showing max tokens (non-changeable).
assertTrue(model.configs[0] is LabelConfig)
assertEquals((model.configs[0] as LabelConfig).defaultValue, "$maxTokens")
// A slider for topK.
assertTrue(model.configs[1] is NumberSliderConfig)
assertEquals((model.configs[1] as NumberSliderConfig).defaultValue, topK.toFloat())
// A slider for topP.
assertTrue(model.configs[2] is NumberSliderConfig)
assertEquals((model.configs[2] as NumberSliderConfig).defaultValue, topP)
// A slider for temperature.
assertTrue(model.configs[3] is NumberSliderConfig)
assertEquals((model.configs[3] as NumberSliderConfig).defaultValue, temperature)
// A segmented button for accelerators.
assertTrue(model.configs[4] is SegmentedButtonConfig)
assertEquals((model.configs[4] as SegmentedButtonConfig).defaultValue, "GPU")
assertEquals((model.configs[4] as SegmentedButtonConfig).options, listOf("GPU", "CPU"))
}
}
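The URL assertion above pins down how `toModel()` builds the download link from the allowlist entry. Written out as a standalone helper for clarity (the function itself is illustrative, not part of the change):

// Mirrors the URL construction the test verifies:
// Hugging Face model id + file name, with a download query flag.
fun downloadUrl(modelId: String, modelFile: String): String =
  "https://huggingface.co/$modelId/resolve/main/$modelFile?download=true"

// downloadUrl("test_model_id", "test_model_file")
//   == "https://huggingface.co/test_model_id/resolve/main/test_model_file?download=true"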

View file

@ -19,4 +19,5 @@ plugins {
   alias(libs.plugins.android.application) apply false
   alias(libs.plugins.kotlin.android) apply false
   alias(libs.plugins.kotlin.compose) apply false
+  alias(libs.plugins.hilt.application) apply false
 }

View file

@ -30,6 +30,8 @@ playServicesTfliteGpu = "16.4.0"
 cameraX = "1.4.2"
 netOpenidAppauth = "0.11.1"
 splashscreen = "1.2.0-beta02"
+hilt = "2.56.2"
+hiltNavigation = "1.2.0"

 [libraries]
 androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" }

@ -69,6 +71,10 @@ camerax-view = { group = "androidx.camera", name = "camera-view", version.ref = "cameraX" }
 openid-appauth = { group = "net.openid", name = "appauth", version.ref = "netOpenidAppauth" }
 androidx-splashscreen = { group = "androidx.core", name = "core-splashscreen", version.ref = "splashscreen" }
 protobuf-javalite = { group = "com.google.protobuf", name = "protobuf-javalite", version.ref = "protobufJavaLite" }
+hilt-android = { module = "com.google.dagger:hilt-android", version.ref = "hilt" }
+hilt-navigation-compose = { module = "androidx.hilt:hilt-navigation-compose", version.ref = "hiltNavigation" }
+hilt-android-testing = { module = "com.google.dagger:hilt-android-testing", version.ref = "hilt" }
+hilt-android-compiler = { module = "com.google.dagger:hilt-android-compiler", version.ref = "hilt" }

 [plugins]
 android-application = { id = "com.android.application", version.ref = "agp" }

@ -76,3 +82,4 @@ kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" }
 kotlin-compose = { id = "org.jetbrains.kotlin.plugin.compose", version.ref = "kotlin" }
 kotlin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "serializationPlugin" }
 protobuf = { id = "com.google.protobuf", version.ref = "protobuf" }
+hilt-application = { id = "com.google.dagger.hilt.android", version.ref = "hilt" }
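Wiring the Hilt plugin and catalog entries as above also requires an Application class annotated with @HiltAndroidApp to root the dependency graph; a minimal sketch (the class name here is an assumption):

import android.app.Application
import dagger.hilt.android.HiltAndroidApp

// Root of the Hilt dependency graph; must be the android:name
// registered in AndroidManifest.xml.
@HiltAndroidApp
class GalleryApplication : Application()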