diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
index a7d15c7..9e38ca2 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.md
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -12,6 +12,10 @@ A clear and concise description of what the bug is.
**To Reproduce:**
Steps to reproduce the behavior:
+1. Go to '...'
+2. Click on '....'
+3. Scroll down to '....'
+4. See error
**Expected behavior:**
A clear and concise description of what you expected to happen.
@@ -19,16 +23,10 @@ A clear and concise description of what you expected to happen.
**Screenshots:**
If applicable, add screenshots to help explain your problem.
-**Desktop (please complete the following information):**
- - OS: [e.g. iOS]
- - Browser [e.g. chrome, safari]
- - Version [e.g. 22]
-
-**Smartphone (please complete the following information):**
- - Device: [e.g. iPhone6]
- - OS: [e.g. iOS8.1]
- - Browser [e.g. stock browser, safari]
- - Version [e.g. 22]
+**Device & App Information (Please complete the following):**
+- Device: [e.g., Samsung Galaxy S23, Google Pixel 7]
+- Android Version: [e.g., Android 12, Android 13]
+- App Version: [e.g., 1.0.1, v1.0.2]
**Additional context:**
Add any other context about the problem here.
diff --git a/.github/workflows/build_android.yaml b/.github/workflows/build_android.yaml
index 0ecbccb..1515255 100644
--- a/.github/workflows/build_android.yaml
+++ b/.github/workflows/build_android.yaml
@@ -21,5 +21,9 @@ jobs:
steps:
- name: Checkout the source code
uses: actions/checkout@v3
+ - uses: actions/setup-java@v4
+ with:
+ distribution: 'temurin'
+ java-version: '21'
- name: Build
run: ./gradlew assembleRelease
diff --git a/Android/src/app/build.gradle.kts b/Android/src/app/build.gradle.kts
index 756a3da..a8eab50 100644
--- a/Android/src/app/build.gradle.kts
+++ b/Android/src/app/build.gradle.kts
@@ -20,6 +20,8 @@ plugins {
alias(libs.plugins.kotlin.compose)
alias(libs.plugins.kotlin.serialization)
alias(libs.plugins.protobuf)
+ alias(libs.plugins.hilt.application)
+ kotlin("kapt")
}
android {
@@ -91,11 +93,15 @@ dependencies {
implementation(libs.openid.appauth)
implementation(libs.androidx.splashscreen)
implementation(libs.protobuf.javalite)
+ implementation(libs.hilt.android)
+ implementation(libs.hilt.navigation.compose)
+ kapt(libs.hilt.android.compiler)
testImplementation(libs.junit)
androidTestImplementation(libs.androidx.junit)
androidTestImplementation(libs.androidx.espresso.core)
androidTestImplementation(platform(libs.androidx.compose.bom))
androidTestImplementation(libs.androidx.ui.test.junit4)
+ androidTestImplementation(libs.hilt.android.testing)
debugImplementation(libs.androidx.ui.tooling)
debugImplementation(libs.androidx.ui.test.manifest)
implementation(libs.material)
diff --git a/Android/src/app/src/main/AndroidManifest.xml b/Android/src/app/src/main/AndroidManifest.xml
index a87aa98..6bacbd3 100644
--- a/Android/src/app/src/main/AndroidManifest.xml
+++ b/Android/src/app/src/main/AndroidManifest.xml
@@ -29,6 +29,7 @@
+
-
-
-
-
-
{
- override val defaultValue: Settings = Settings.getDefaultInstance()
-
- override suspend fun readFrom(input: InputStream): Settings {
- try {
- return Settings.parseFrom(input)
- } catch (exception: InvalidProtocolBufferException) {
- throw CorruptionException("Cannot read proto.", exception)
- }
- }
-
- override suspend fun writeTo(t: Settings, output: OutputStream) = t.writeTo(output)
-}
-
-private val Context.dataStore: DataStore<Settings> by
- dataStore(fileName = "settings.pb", serializer = SettingsSerializer)
+import dagger.hilt.android.HiltAndroidApp
+import javax.inject.Inject
+@HiltAndroidApp
class GalleryApplication : Application() {
- /** AppContainer instance used by the rest of classes to obtain dependencies */
- lateinit var container: AppContainer
+
+ @Inject lateinit var dataStoreRepository: DataStoreRepository
override fun onCreate() {
super.onCreate()
writeLaunchInfo(context = this)
- container = DefaultAppContainer(this, dataStore)
// Load saved theme.
- ThemeSettings.themeOverride.value = container.dataStoreRepository.readTheme()
+ ThemeSettings.themeOverride.value = dataStoreRepository.readTheme()
}
}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/MainActivity.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/MainActivity.kt
index dc172dc..cd3c9ea 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/MainActivity.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/MainActivity.kt
@@ -18,6 +18,7 @@ package com.google.ai.edge.gallery
import android.os.Build
import android.os.Bundle
+import android.view.WindowManager
import androidx.activity.ComponentActivity
import androidx.activity.compose.setContent
import androidx.activity.enableEdgeToEdge
@@ -26,8 +27,11 @@ import androidx.compose.material3.Surface
import androidx.compose.ui.Modifier
import androidx.core.splashscreen.SplashScreen.Companion.installSplashScreen
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
+import dagger.hilt.android.AndroidEntryPoint
+@AndroidEntryPoint
class MainActivity : ComponentActivity() {
+
override fun onCreate(savedInstanceState: Bundle?) {
installSplashScreen()
@@ -39,5 +43,7 @@ class MainActivity : ComponentActivity() {
window.isNavigationBarContrastEnforced = false
}
setContent { GalleryTheme { Surface(modifier = Modifier.fillMaxSize()) { GalleryApp() } } }
+ // Keep the screen on while the app is running for better demo experience.
+ window.addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON)
}
}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/SettingsSerializer.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/SettingsSerializer.kt
new file mode 100644
index 0000000..c7381f1
--- /dev/null
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/SettingsSerializer.kt
@@ -0,0 +1,38 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.ai.edge.gallery
+
+import androidx.datastore.core.CorruptionException
+import androidx.datastore.core.Serializer
+import com.google.ai.edge.gallery.proto.Settings
+import com.google.protobuf.InvalidProtocolBufferException
+import java.io.InputStream
+import java.io.OutputStream
+
+object SettingsSerializer : Serializer<Settings> {
+ override val defaultValue: Settings = Settings.getDefaultInstance()
+
+ override suspend fun readFrom(input: InputStream): Settings {
+ try {
+ return Settings.parseFrom(input)
+ } catch (exception: InvalidProtocolBufferException) {
+ throw CorruptionException("Cannot read proto.", exception)
+ }
+ }
+
+ override suspend fun writeTo(t: Settings, output: OutputStream) = t.writeTo(output)
+}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Types.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Types.kt
index f9e9a9f..1c7a824 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Types.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Types.kt
@@ -25,3 +25,5 @@ interface LatencyProvider {
data class Classification(val label: String, val score: Float, val color: Color)
data class JsonObjAndTextContent<T>(val jsonObj: T, val textContent: String)
+
+class AudioClip(val audioData: ByteArray, val sampleRate: Int)
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt
index 6d37f9d..1ada15d 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt
@@ -17,12 +17,17 @@
package com.google.ai.edge.gallery.common
import android.content.Context
+import android.net.Uri
import android.util.Log
+import com.google.ai.edge.gallery.data.SAMPLE_RATE
import com.google.gson.Gson
import com.google.gson.reflect.TypeToken
import java.io.File
import java.net.HttpURLConnection
import java.net.URL
+import java.nio.ByteBuffer
+import java.nio.ByteOrder
+import kotlin.math.floor
data class LaunchInfo(val ts: Long)
@@ -112,3 +117,135 @@ inline fun <reified T> getJsonResponse(url: String): JsonObjAndTextContent<T>? {
return null
}
+
+fun convertWavToMonoWithMaxSeconds(
+ context: Context,
+ stereoUri: Uri,
+ maxSeconds: Int = 30,
+): AudioClip? {
+ Log.d(TAG, "Start to convert wav file to mono channel")
+
+ try {
+ val inputStream = context.contentResolver.openInputStream(stereoUri) ?: return null
+ val originalBytes = inputStream.readBytes()
+ inputStream.close()
+
+ // Read WAV header
+ if (originalBytes.size < 44) {
+ // Not a valid WAV file
+ Log.e(TAG, "Not a valid wav file")
+ return null
+ }
+
+ val headerBuffer = ByteBuffer.wrap(originalBytes, 0, 44).order(ByteOrder.LITTLE_ENDIAN)
+ val channels = headerBuffer.getShort(22)
+ var sampleRate = headerBuffer.getInt(24)
+ val bitDepth = headerBuffer.getShort(34)
+ Log.d(TAG, "File metadata: channels: $channels, sampleRate: $sampleRate, bitDepth: $bitDepth")
+
+ // Normalize audio to 16-bit.
+ val audioDataBytes = originalBytes.copyOfRange(fromIndex = 44, toIndex = originalBytes.size)
+ var sixteenBitBytes: ByteArray =
+ if (bitDepth.toInt() == 8) {
+ Log.d(TAG, "Converting 8-bit audio to 16-bit.")
+ convert8BitTo16Bit(audioDataBytes)
+ } else {
+ // Assume 16-bit or other format that can be handled directly
+ audioDataBytes
+ }
+
+ // Convert byte array to short array for processing
+ val shortBuffer =
+ ByteBuffer.wrap(sixteenBitBytes).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer()
+ var pcmSamples = ShortArray(shortBuffer.remaining())
+ shortBuffer.get(pcmSamples)
+
+ // Resample if sample rate is less than 16000 Hz ---
+ if (sampleRate < SAMPLE_RATE) {
+ Log.d(TAG, "Resampling from $sampleRate Hz to $SAMPLE_RATE Hz.")
+ pcmSamples = resample(pcmSamples, sampleRate, SAMPLE_RATE, channels.toInt())
+ sampleRate = SAMPLE_RATE
+ Log.d(TAG, "Resampling complete. New sample count: ${pcmSamples.size}")
+ }
+
+ // Convert stereo to mono if necessary
+ var monoSamples =
+ if (channels.toInt() == 2) {
+ Log.d(TAG, "Converting stereo to mono.")
+ val mono = ShortArray(pcmSamples.size / 2)
+ for (i in mono.indices) {
+ val left = pcmSamples[i * 2]
+ val right = pcmSamples[i * 2 + 1]
+ mono[i] = ((left + right) / 2).toShort()
+ }
+ mono
+ } else {
+ Log.d(TAG, "Audio is already mono. No channel conversion needed.")
+ pcmSamples
+ }
+
+ // Trim the audio to maxSeconds ---
+ val maxSamples = maxSeconds * sampleRate
+ if (monoSamples.size > maxSamples) {
+ Log.d(TAG, "Trimming clip from ${monoSamples.size} samples to $maxSamples samples.")
+ monoSamples = monoSamples.copyOfRange(0, maxSamples)
+ }
+
+ val monoByteBuffer = ByteBuffer.allocate(monoSamples.size * 2).order(ByteOrder.LITTLE_ENDIAN)
+ monoByteBuffer.asShortBuffer().put(monoSamples)
+ return AudioClip(audioData = monoByteBuffer.array(), sampleRate = sampleRate)
+ } catch (e: Exception) {
+ Log.e(TAG, "Failed to convert wav to mono", e)
+ return null
+ }
+}
+
+/** Converts 8-bit unsigned PCM audio data to 16-bit signed PCM. */
+private fun convert8BitTo16Bit(eightBitData: ByteArray): ByteArray {
+ // The new 16-bit data will be twice the size
+ val sixteenBitData = ByteArray(eightBitData.size * 2)
+ val buffer = ByteBuffer.wrap(sixteenBitData).order(ByteOrder.LITTLE_ENDIAN)
+
+ for (byte in eightBitData) {
+ // Convert the unsigned 8-bit byte (0-255) to a signed 16-bit short (-32768 to 32767)
+ // 1. Get the unsigned value by masking with 0xFF
+ // 2. Subtract 128 to center the waveform around 0 (range becomes -128 to 127)
+ // 3. Scale by 256 to expand to the 16-bit range
+ val unsignedByte = byte.toInt() and 0xFF
+ val sixteenBitSample = ((unsignedByte - 128) * 256).toShort()
+ buffer.putShort(sixteenBitSample)
+ }
+ return sixteenBitData
+}
+
+/** Resamples PCM audio data from an original sample rate to a target sample rate. */
+private fun resample(
+ inputSamples: ShortArray,
+ originalSampleRate: Int,
+ targetSampleRate: Int,
+ channels: Int,
+): ShortArray {
+ if (originalSampleRate == targetSampleRate) {
+ return inputSamples
+ }
+
+ val ratio = targetSampleRate.toDouble() / originalSampleRate
+ val outputLength = (inputSamples.size * ratio).toInt()
+ val resampledData = ShortArray(outputLength)
+
+ if (channels == 1) { // Mono
+ for (i in resampledData.indices) {
+ val position = i / ratio
+ val index1 = floor(position).toInt()
+ val index2 = index1 + 1
+ val fraction = position - index1
+
+ val sample1 = if (index1 < inputSamples.size) inputSamples[index1].toDouble() else 0.0
+ val sample2 = if (index2 < inputSamples.size) inputSamples[index2].toDouble() else 0.0
+
+ resampledData[i] = (sample1 * (1 - fraction) + sample2 * fraction).toInt().toShort()
+ }
+ }
+
+ return resampledData
+}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt
index d961ecc..2b9fdf7 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt
@@ -50,6 +50,7 @@ enum class ConfigKey(val label: String) {
DEFAULT_TOPP("Default TopP"),
DEFAULT_TEMPERATURE("Default temperature"),
SUPPORT_IMAGE("Support image"),
+ SUPPORT_AUDIO("Support audio"),
MAX_RESULT_COUNT("Max result count"),
USE_GPU("Use GPU"),
ACCELERATOR("Choose accelerator"),
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Consts.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Consts.kt
index a2209af..2315099 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Consts.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Consts.kt
@@ -44,3 +44,12 @@ val DEFAULT_ACCELERATORS = listOf(Accelerator.GPU)
// Max number of images allowed in a "ask image" session.
const val MAX_IMAGE_COUNT = 10
+
+// Max number of audio clip in an "ask audio" session.
+const val MAX_AUDIO_CLIP_COUNT = 1
+
+// Max audio clip duration in seconds.
+const val MAX_AUDIO_CLIP_DURATION_SEC = 30
+
+// Audio-recording related consts.
+const val SAMPLE_RATE = 16000
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt
index 580bf2b..02e1368 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt
@@ -87,6 +87,9 @@ data class Model(
/** Whether the LLM model supports image input. */
val llmSupportImage: Boolean = false,
+ /** Whether the LLM model supports audio input. */
+ val llmSupportAudio: Boolean = false,
+
/** Whether the model is imported or not. */
val imported: Boolean = false,
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt
index f336638..03263ca 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt
@@ -38,6 +38,7 @@ data class AllowedModel(
val taskTypes: List<String>,
val disabled: Boolean? = null,
val llmSupportImage: Boolean? = null,
+ val llmSupportAudio: Boolean? = null,
val estimatedPeakMemoryInBytes: Long? = null,
) {
fun toModel(): Model {
@@ -49,10 +50,10 @@ data class AllowedModel(
taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_PROMPT_LAB.type.id)
var configs: List<Config> = listOf()
if (isLlmModel) {
- var defaultTopK: Int = defaultConfig.topK ?: DEFAULT_TOPK
- var defaultTopP: Float = defaultConfig.topP ?: DEFAULT_TOPP
- var defaultTemperature: Float = defaultConfig.temperature ?: DEFAULT_TEMPERATURE
- var defaultMaxToken = defaultConfig.maxTokens ?: 1024
+ val defaultTopK: Int = defaultConfig.topK ?: DEFAULT_TOPK
+ val defaultTopP: Float = defaultConfig.topP ?: DEFAULT_TOPP
+ val defaultTemperature: Float = defaultConfig.temperature ?: DEFAULT_TEMPERATURE
+ val defaultMaxToken = defaultConfig.maxTokens ?: 1024
var accelerators: List<Accelerator> = DEFAULT_ACCELERATORS
if (defaultConfig.accelerators != null) {
val items = defaultConfig.accelerators.split(",")
@@ -96,6 +97,7 @@ data class AllowedModel(
showRunAgainButton = showRunAgainButton,
learnMoreUrl = "https://huggingface.co/${modelId}",
llmSupportImage = llmSupportImage == true,
+ llmSupportAudio = llmSupportAudio == true,
)
}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Tasks.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Tasks.kt
index e95feab..c52d2c3 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Tasks.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Tasks.kt
@@ -17,19 +17,22 @@
package com.google.ai.edge.gallery.data
import androidx.annotation.StringRes
+import androidx.compose.material.icons.Icons
+import androidx.compose.material.icons.outlined.Forum
+import androidx.compose.material.icons.outlined.Mic
+import androidx.compose.material.icons.outlined.Mms
+import androidx.compose.material.icons.outlined.Widgets
import androidx.compose.runtime.MutableState
import androidx.compose.runtime.mutableLongStateOf
import androidx.compose.ui.graphics.vector.ImageVector
import com.google.ai.edge.gallery.R
-import com.google.ai.edge.gallery.ui.icon.Forum
-import com.google.ai.edge.gallery.ui.icon.Mms
-import com.google.ai.edge.gallery.ui.icon.Widgets
/** Type of task. */
enum class TaskType(val label: String, val id: String) {
LLM_CHAT(label = "AI Chat", id = "llm_chat"),
LLM_PROMPT_LAB(label = "Prompt Lab", id = "llm_prompt_lab"),
LLM_ASK_IMAGE(label = "Ask Image", id = "llm_ask_image"),
+ LLM_ASK_AUDIO(label = "Audio Scribe", id = "llm_ask_audio"),
TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
TEST_TASK_2(label = "Test task 2", id = "test_task_2"),
}
@@ -71,7 +74,7 @@ data class Task(
val TASK_LLM_CHAT =
Task(
type = TaskType.LLM_CHAT,
- icon = Forum,
+ icon = Icons.Outlined.Forum,
models = mutableListOf(),
description = "Chat with on-device large language models",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
@@ -83,7 +86,7 @@ val TASK_LLM_CHAT =
val TASK_LLM_PROMPT_LAB =
Task(
type = TaskType.LLM_PROMPT_LAB,
- icon = Widgets,
+ icon = Icons.Outlined.Widgets,
models = mutableListOf(),
description = "Single turn use cases with on-device large language model",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
@@ -95,7 +98,7 @@ val TASK_LLM_PROMPT_LAB =
val TASK_LLM_ASK_IMAGE =
Task(
type = TaskType.LLM_ASK_IMAGE,
- icon = Mms,
+ icon = Icons.Outlined.Mms,
models = mutableListOf(),
description = "Ask questions about images with on-device large language models",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
@@ -104,8 +107,23 @@ val TASK_LLM_ASK_IMAGE =
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
)
+val TASK_LLM_ASK_AUDIO =
+ Task(
+ type = TaskType.LLM_ASK_AUDIO,
+ icon = Icons.Outlined.Mic,
+ models = mutableListOf(),
+ // TODO(do not submit)
+ description =
+ "Instantly transcribe and/or translate audio clips using on-device large language models",
+ docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
+ sourceCodeUrl =
+ "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt",
+ textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
+ )
+
/** All tasks. */
-val TASKS: List<Task> = listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)
+val TASKS: List<Task> =
+ listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_ASK_AUDIO, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)
fun getModelByName(name: String): Model? {
for (task in TASKS) {
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/di/AppModule.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/di/AppModule.kt
new file mode 100644
index 0000000..a634822
--- /dev/null
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/di/AppModule.kt
@@ -0,0 +1,86 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.ai.edge.gallery.di
+
+import android.content.Context
+import androidx.datastore.core.DataStore
+import androidx.datastore.core.DataStoreFactory
+import androidx.datastore.core.Serializer
+import androidx.datastore.dataStoreFile
+import com.google.ai.edge.gallery.AppLifecycleProvider
+import com.google.ai.edge.gallery.GalleryLifecycleProvider
+import com.google.ai.edge.gallery.SettingsSerializer
+import com.google.ai.edge.gallery.data.DataStoreRepository
+import com.google.ai.edge.gallery.data.DefaultDataStoreRepository
+import com.google.ai.edge.gallery.data.DefaultDownloadRepository
+import com.google.ai.edge.gallery.data.DownloadRepository
+import com.google.ai.edge.gallery.proto.Settings
+import dagger.Module
+import dagger.Provides
+import dagger.hilt.InstallIn
+import dagger.hilt.android.qualifiers.ApplicationContext
+import dagger.hilt.components.SingletonComponent
+import javax.inject.Singleton
+
+@Module
+@InstallIn(SingletonComponent::class)
+internal object AppModule {
+
+ // Provides the SettingsSerializer
+ @Provides
+ @Singleton
+ fun provideSettingsSerializer(): Serializer<Settings> {
+ return SettingsSerializer
+ }
+
+ // Provides DataStore
+ @Provides
+ @Singleton
+ fun provideSettingsDataStore(
+ @ApplicationContext context: Context,
+ settingsSerializer: Serializer<Settings>,
+ ): DataStore<Settings> {
+ return DataStoreFactory.create(
+ serializer = settingsSerializer,
+ produceFile = { context.dataStoreFile("settings.pb") },
+ )
+ }
+
+ // Provides AppLifecycleProvider
+ @Provides
+ @Singleton
+ fun provideAppLifecycleProvider(): AppLifecycleProvider {
+ return GalleryLifecycleProvider()
+ }
+
+ // Provides DataStoreRepository
+ @Provides
+ @Singleton
+ fun provideDataStoreRepository(dataStore: DataStore<Settings>): DataStoreRepository {
+ return DefaultDataStoreRepository(dataStore)
+ }
+
+ // Provides DownloadRepository
+ @Provides
+ @Singleton
+ fun provideDownloadRepository(
+ @ApplicationContext context: Context,
+ lifecycleProvider: AppLifecycleProvider,
+ ): DownloadRepository {
+ return DefaultDownloadRepository(context, lifecycleProvider)
+ }
+}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt
deleted file mode 100644
index fdded53..0000000
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * Copyright 2025 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.ai.edge.gallery.ui
-
-import androidx.lifecycle.ViewModelProvider.AndroidViewModelFactory
-import androidx.lifecycle.viewmodel.CreationExtras
-import androidx.lifecycle.viewmodel.initializer
-import androidx.lifecycle.viewmodel.viewModelFactory
-import com.google.ai.edge.gallery.GalleryApplication
-import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageViewModel
-import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel
-import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
-import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
-
-object ViewModelProvider {
- val Factory = viewModelFactory {
- // Initializer for ModelManagerViewModel.
- initializer {
- val downloadRepository = galleryApplication().container.downloadRepository
- val dataStoreRepository = galleryApplication().container.dataStoreRepository
- val lifecycleProvider = galleryApplication().container.lifecycleProvider
- ModelManagerViewModel(
- downloadRepository = downloadRepository,
- dataStoreRepository = dataStoreRepository,
- lifecycleProvider = lifecycleProvider,
- context = galleryApplication().container.context,
- )
- }
-
- // Initializer for LlmChatViewModel.
- initializer { LlmChatViewModel() }
-
- // Initializer for LlmSingleTurnViewModel..
- initializer { LlmSingleTurnViewModel() }
-
- // Initializer for LlmAskImageViewModel.
- initializer { LlmAskImageViewModel() }
- }
-}
-
-/**
- * Extension function to queries for [Application] object and returns an instance of
- * [GalleryApplication].
- */
-fun CreationExtras.galleryApplication(): GalleryApplication =
- (this[AndroidViewModelFactory.APPLICATION_KEY] as GalleryApplication)
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/DownloadAndTryButton.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/DownloadAndTryButton.kt
index 8511c6e..039ed6d 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/DownloadAndTryButton.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/DownloadAndTryButton.kt
@@ -64,6 +64,7 @@ import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
private const val TAG = "AGDownloadAndTryButton"
+private const val SYSTEM_RESERVED_MEMORY_IN_BYTES = 3 * (1L shl 30)
// TODO:
// - replace the download button in chat view page with this one, and add a flag to not "onclick"
@@ -290,7 +291,6 @@ fun DownloadAndTryButton(
val activityManager =
context.getSystemService(android.app.Activity.ACTIVITY_SERVICE) as? ActivityManager
val estimatedPeakMemoryInBytes = model.estimatedPeakMemoryInBytes
-
val isMemoryLow =
if (activityManager != null && estimatedPeakMemoryInBytes != null) {
val memoryInfo = ActivityManager.MemoryInfo()
@@ -302,11 +302,10 @@ fun DownloadAndTryButton(
// The device should be able to run the model if `availMem` is larger than the
// estimated peak memory. Android also has a mechanism to kill background apps to
- // free up memory for the foreground app. We believe that if half of the total
- // memory on the device is larger than the estimated peak memory, it can run the
- // model fine with this mechanism. For example, a phone with 12GB memory can have
- // very few `availMem` but will have no problem running most models.
- max(memoryInfo.availMem, memoryInfo.totalMem / 2) < estimatedPeakMemoryInBytes
+ // free up memory for the foreground app. Reserving 3G for system buffer memory to
+ // avoid the app being killed by the system.
+ max(memoryInfo.availMem, memoryInfo.totalMem - SYSTEM_RESERVED_MEMORY_IN_BYTES) <
+ estimatedPeakMemoryInBytes
} else {
false
}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioPlaybackPanel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioPlaybackPanel.kt
new file mode 100644
index 0000000..9faebf0
--- /dev/null
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioPlaybackPanel.kt
@@ -0,0 +1,344 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.ai.edge.gallery.ui.common.chat
+
+import android.media.AudioAttributes
+import android.media.AudioFormat
+import android.media.AudioTrack
+import android.util.Log
+import androidx.compose.foundation.Canvas
+import androidx.compose.foundation.layout.Row
+import androidx.compose.foundation.layout.height
+import androidx.compose.foundation.layout.padding
+import androidx.compose.foundation.layout.width
+import androidx.compose.material.icons.Icons
+import androidx.compose.material.icons.rounded.PlayArrow
+import androidx.compose.material.icons.rounded.Stop
+import androidx.compose.material3.Icon
+import androidx.compose.material3.IconButton
+import androidx.compose.material3.MaterialTheme
+import androidx.compose.material3.Text
+import androidx.compose.runtime.Composable
+import androidx.compose.runtime.DisposableEffect
+import androidx.compose.runtime.LaunchedEffect
+import androidx.compose.runtime.MutableState
+import androidx.compose.runtime.getValue
+import androidx.compose.runtime.mutableFloatStateOf
+import androidx.compose.runtime.mutableStateOf
+import androidx.compose.runtime.remember
+import androidx.compose.runtime.rememberCoroutineScope
+import androidx.compose.runtime.setValue
+import androidx.compose.ui.Alignment
+import androidx.compose.ui.Modifier
+import androidx.compose.ui.geometry.CornerRadius
+import androidx.compose.ui.geometry.Offset
+import androidx.compose.ui.geometry.Size
+import androidx.compose.ui.geometry.toRect
+import androidx.compose.ui.graphics.BlendMode
+import androidx.compose.ui.graphics.Color
+import androidx.compose.ui.graphics.drawscope.drawIntoCanvas
+import androidx.compose.ui.unit.dp
+import com.google.ai.edge.gallery.data.MAX_AUDIO_CLIP_DURATION_SEC
+import com.google.ai.edge.gallery.ui.theme.customColors
+import java.nio.ByteBuffer
+import java.nio.ByteOrder
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.isActive
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.withContext
+
+private const val TAG = "AGAudioPlaybackPanel"
+private const val BAR_SPACE = 2
+private const val BAR_WIDTH = 2
+private const val MIN_BAR_COUNT = 16
+private const val MAX_BAR_COUNT = 48
+
+/**
+ * A Composable that displays an audio playback panel, including play/stop controls, a waveform
+ * visualization, and the duration of the audio clip.
+ */
+@Composable
+fun AudioPlaybackPanel(
+ audioData: ByteArray,
+ sampleRate: Int,
+ isRecording: Boolean,
+ modifier: Modifier = Modifier,
+ onDarkBg: Boolean = false,
+) {
+ val coroutineScope = rememberCoroutineScope()
+ var isPlaying by remember { mutableStateOf(false) }
+ val audioTrackState = remember { mutableStateOf<AudioTrack?>(null) }
+ val durationInSeconds =
+ remember(audioData) {
+ // PCM 16-bit
+ val bytesPerSample = 2
+ val bytesPerFrame = bytesPerSample * 1 // mono
+ val totalFrames = audioData.size.toDouble() / bytesPerFrame
+ totalFrames / sampleRate
+ }
+ val barCount =
+ remember(durationInSeconds) {
+ val f = durationInSeconds / MAX_AUDIO_CLIP_DURATION_SEC
+ ((MAX_BAR_COUNT - MIN_BAR_COUNT) * f + MIN_BAR_COUNT).toInt()
+ }
+ val amplitudeLevels =
+ remember(audioData) { generateAmplitudeLevels(audioData = audioData, barCount = barCount) }
+ var playbackProgress by remember { mutableFloatStateOf(0f) }
+
+ // Reset when a new recording is started.
+ LaunchedEffect(isRecording) {
+ if (isRecording) {
+ val audioTrack = audioTrackState.value
+ audioTrack?.stop()
+ isPlaying = false
+ playbackProgress = 0f
+ }
+ }
+
+ // Cleanup on Composable Disposal.
+ DisposableEffect(Unit) {
+ onDispose {
+ val audioTrack = audioTrackState.value
+ audioTrack?.stop()
+ audioTrack?.release()
+ }
+ }
+
+ Row(verticalAlignment = Alignment.CenterVertically, modifier = modifier) {
+ // Button to play/stop the clip.
+ IconButton(
+ onClick = {
+ coroutineScope.launch {
+ if (!isPlaying) {
+ isPlaying = true
+ playAudio(
+ audioTrackState = audioTrackState,
+ audioData = audioData,
+ sampleRate = sampleRate,
+ onProgress = { playbackProgress = it },
+ onCompletion = {
+ playbackProgress = 0f
+ isPlaying = false
+ },
+ )
+ } else {
+ stopPlayAudio(audioTrackState = audioTrackState)
+ playbackProgress = 0f
+ isPlaying = false
+ }
+ }
+ }
+ ) {
+ Icon(
+ if (isPlaying) Icons.Rounded.Stop else Icons.Rounded.PlayArrow,
+ contentDescription = "",
+ tint = if (onDarkBg) Color.White else MaterialTheme.colorScheme.primary,
+ )
+ }
+
+ // Visualization
+ AmplitudeBarGraph(
+ amplitudeLevels = amplitudeLevels,
+ progress = playbackProgress,
+ modifier =
+ Modifier.width((barCount * BAR_WIDTH + (barCount - 1) * BAR_SPACE).dp).height(24.dp),
+ onDarkBg = onDarkBg,
+ )
+
+ // Duration
+ Text(
+ "${"%.1f".format(durationInSeconds)}s",
+ style = MaterialTheme.typography.labelLarge,
+ color = if (onDarkBg) Color.White else MaterialTheme.colorScheme.primary,
+ modifier = Modifier.padding(start = 12.dp),
+ )
+ }
+}
+
+@Composable
+private fun AmplitudeBarGraph(
+ amplitudeLevels: List<Float>,
+ progress: Float,
+ modifier: Modifier = Modifier,
+ onDarkBg: Boolean = false,
+) {
+ val barColor = MaterialTheme.customColors.waveFormBgColor
+ val progressColor = if (onDarkBg) Color.White else MaterialTheme.colorScheme.primary
+
+ Canvas(modifier = modifier) {
+ val barCount = amplitudeLevels.size
+ val barWidth = (size.width - BAR_SPACE.dp.toPx() * (barCount - 1)) / barCount
+ val cornerRadius = CornerRadius(x = barWidth, y = barWidth)
+
+ // Use drawIntoCanvas for advanced blend mode operations
+ drawIntoCanvas { canvas ->
+
+ // 1. Save the current state of the canvas onto a temporary, offscreen layer
+ canvas.saveLayer(size.toRect(), androidx.compose.ui.graphics.Paint())
+
+ // 2. Draw the bars in grey.
+ amplitudeLevels.forEachIndexed { index, level ->
+ val barHeight = (level * size.height).coerceAtLeast(1.5f)
+ val left = index * (barWidth + BAR_SPACE.dp.toPx())
+ drawRoundRect(
+ color = barColor,
+ topLeft = Offset(x = left, y = size.height / 2 - barHeight / 2),
+ size = Size(barWidth, barHeight),
+ cornerRadius = cornerRadius,
+ )
+ }
+
+ // 3. Draw the progress rectangle using BlendMode.SrcIn to only draw where the bars already
+ // exist.
+ val progressWidth = size.width * progress
+ drawRect(
+ color = progressColor,
+ topLeft = Offset.Zero,
+ size = Size(progressWidth, size.height),
+ blendMode = BlendMode.SrcIn,
+ )
+
+ // 4. Restore the layer, merging it onto the main canvas
+ canvas.restore()
+ }
+ }
+}
+
+private suspend fun playAudio(
+ audioTrackState: MutableState<AudioTrack?>,
+ audioData: ByteArray,
+ sampleRate: Int,
+ onProgress: (Float) -> Unit,
+ onCompletion: () -> Unit,
+) {
+ Log.d(TAG, "Start playing audio...")
+
+ try {
+ withContext(Dispatchers.IO) {
+ var lastProgressUpdateMs = 0L
+ audioTrackState.value?.release()
+ val audioTrack =
+ AudioTrack.Builder()
+ .setAudioAttributes(
+ AudioAttributes.Builder()
+ .setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
+ .setUsage(AudioAttributes.USAGE_MEDIA)
+ .build()
+ )
+ .setAudioFormat(
+ AudioFormat.Builder()
+ .setEncoding(AudioFormat.ENCODING_PCM_16BIT)
+ .setSampleRate(sampleRate)
+ .setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
+ .build()
+ )
+ .setTransferMode(AudioTrack.MODE_STATIC)
+ .setBufferSizeInBytes(audioData.size)
+ .build()
+
+ val bytesPerFrame = 2 // For PCM 16-bit Mono
+ val totalFrames = audioData.size / bytesPerFrame
+
+ audioTrackState.value = audioTrack
+ audioTrack.write(audioData, 0, audioData.size)
+ audioTrack.play()
+
+ // Coroutine to monitor progress
+ while (isActive && audioTrack.playState == AudioTrack.PLAYSTATE_PLAYING) {
+ val currentFrames = audioTrack.playbackHeadPosition
+ if (currentFrames >= totalFrames) {
+ break // Exit loop when playback is done
+ }
+ val progress = currentFrames.toFloat() / totalFrames
+ val curMs = System.currentTimeMillis()
+ if (curMs - lastProgressUpdateMs > 30) {
+ onProgress(progress)
+ lastProgressUpdateMs = curMs
+ }
+ }
+
+ if (isActive) {
+ audioTrackState.value?.stop()
+ }
+ }
+ } catch (e: Exception) {
+ // Ignore
+ } finally {
+ onProgress(1f)
+ onCompletion()
+ }
+}
+
+private fun stopPlayAudio(audioTrackState: MutableState<AudioTrack?>) {
+ Log.d(TAG, "Stopping playing audio...")
+
+ val audioTrack = audioTrackState.value
+ audioTrack?.stop()
+ audioTrack?.release()
+ audioTrackState.value = null
+}
+
+/**
+ * Processes a raw PCM 16-bit audio byte array to generate a list of normalized amplitude levels for
+ * visualization.
+ */
+private fun generateAmplitudeLevels(audioData: ByteArray, barCount: Int): List<Float> {
+ if (audioData.isEmpty()) {
+ return List(barCount) { 0f }
+ }
+
+ // 1. Parse bytes into 16-bit short samples (PCM 16-bit)
+ val shortBuffer = ByteBuffer.wrap(audioData).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer()
+ val samples = ShortArray(shortBuffer.remaining())
+ shortBuffer.get(samples)
+
+ if (samples.isEmpty()) {
+ return List(barCount) { 0f }
+ }
+
+ // 2. Determine the size of each chunk
+ val chunkSize = samples.size / barCount
+ val amplitudeLevels = mutableListOf<Float>()
+
+ // 3. Get the max value for each chunk
+ for (i in 0 until barCount) {
+ val chunkStart = i * chunkSize
+ val chunkEnd = (chunkStart + chunkSize).coerceAtMost(samples.size)
+
+ var maxAmplitudeInChunk = 0.0
+
+ for (j in chunkStart until chunkEnd) {
+ val sampleAbs = kotlin.math.abs(samples[j].toDouble())
+ if (sampleAbs > maxAmplitudeInChunk) {
+ maxAmplitudeInChunk = sampleAbs
+ }
+ }
+
+ // 4. Normalize the value (0 to 1)
+ // Short.MAX_VALUE is 32767.0, a good reference for max amplitude
+ val normalizedRms = (maxAmplitudeInChunk / Short.MAX_VALUE).toFloat().coerceIn(0f, 1f)
+ amplitudeLevels.add(normalizedRms)
+ }
+
+ // Normalize the resulting levels so that the max value becomes 0.9.
+ val maxVal = amplitudeLevels.max()
+ if (maxVal == 0f) {
+ return amplitudeLevels
+ }
+ val scaleFactor = 0.9f / maxVal
+ return amplitudeLevels.map { it * scaleFactor }
+}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioRecorderPanel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioRecorderPanel.kt
new file mode 100644
index 0000000..9b018dc
--- /dev/null
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioRecorderPanel.kt
@@ -0,0 +1,327 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.ai.edge.gallery.ui.common.chat
+
+import android.annotation.SuppressLint
+import android.content.Context
+import android.media.AudioFormat
+import android.media.AudioRecord
+import android.media.MediaRecorder
+import android.util.Log
+import androidx.compose.foundation.background
+import androidx.compose.foundation.layout.Arrangement
+import androidx.compose.foundation.layout.Box
+import androidx.compose.foundation.layout.Column
+import androidx.compose.foundation.layout.Row
+import androidx.compose.foundation.layout.fillMaxWidth
+import androidx.compose.foundation.layout.height
+import androidx.compose.foundation.layout.padding
+import androidx.compose.foundation.layout.size
+import androidx.compose.foundation.shape.CircleShape
+import androidx.compose.material.icons.Icons
+import androidx.compose.material.icons.rounded.ArrowUpward
+import androidx.compose.material.icons.rounded.Mic
+import androidx.compose.material.icons.rounded.Stop
+import androidx.compose.material3.Icon
+import androidx.compose.material3.IconButton
+import androidx.compose.material3.MaterialTheme
+import androidx.compose.material3.Text
+import androidx.compose.runtime.Composable
+import androidx.compose.runtime.DisposableEffect
+import androidx.compose.runtime.MutableLongState
+import androidx.compose.runtime.MutableState
+import androidx.compose.runtime.derivedStateOf
+import androidx.compose.runtime.getValue
+import androidx.compose.runtime.mutableIntStateOf
+import androidx.compose.runtime.mutableLongStateOf
+import androidx.compose.runtime.mutableStateOf
+import androidx.compose.runtime.remember
+import androidx.compose.runtime.rememberCoroutineScope
+import androidx.compose.runtime.setValue
+import androidx.compose.ui.Alignment
+import androidx.compose.ui.Modifier
+import androidx.compose.ui.draw.clip
+import androidx.compose.ui.graphics.Color
+import androidx.compose.ui.graphics.graphicsLayer
+import androidx.compose.ui.platform.LocalContext
+import androidx.compose.ui.res.painterResource
+import androidx.compose.ui.unit.dp
+import com.google.ai.edge.gallery.R
+import com.google.ai.edge.gallery.data.MAX_AUDIO_CLIP_DURATION_SEC
+import com.google.ai.edge.gallery.data.SAMPLE_RATE
+import com.google.ai.edge.gallery.ui.theme.customColors
+import java.io.ByteArrayOutputStream
+import java.nio.ByteBuffer
+import java.nio.ByteOrder
+import kotlin.math.abs
+import kotlin.math.pow
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.coroutineScope
+import kotlinx.coroutines.launch
+
+private const val TAG = "AGAudioRecorderPanel"
+
+private const val CHANNEL_CONFIG = AudioFormat.CHANNEL_IN_MONO
+private const val AUDIO_FORMAT = AudioFormat.ENCODING_PCM_16BIT
+
+/**
+ * A Composable that provides an audio recording panel. It allows users to record audio clips,
+ * displays the recording duration and a live amplitude visualization, and provides options to play
+ * back the recorded clip or send it.
+ */
+@Composable
+fun AudioRecorderPanel(onSendAudioClip: (ByteArray) -> Unit) {
+ val context = LocalContext.current
+ val coroutineScope = rememberCoroutineScope()
+
+ var isRecording by remember { mutableStateOf(false) }
+ val elapsedMs = remember { mutableLongStateOf(0L) }
+ val audioRecordState = remember { mutableStateOf<AudioRecord?>(null) }
+ val audioStream = remember { ByteArrayOutputStream() }
+ val recordedBytes = remember { mutableStateOf<ByteArray?>(null) }
+ var currentAmplitude by remember { mutableIntStateOf(0) }
+
+ val elapsedSeconds by remember {
+ derivedStateOf { "%.1f".format(elapsedMs.value.toFloat() / 1000f) }
+ }
+
+ // Cleanup on Composable Disposal.
+ DisposableEffect(Unit) { onDispose { audioRecordState.value?.release() } }
+
+ Column(modifier = Modifier.padding(bottom = 12.dp)) {
+ // Title bar.
+ Row(
+ verticalAlignment = Alignment.CenterVertically,
+ horizontalArrangement = Arrangement.SpaceBetween,
+ ) {
+ // Logo and state.
+ Row(
+ modifier = Modifier.padding(start = 16.dp),
+ horizontalArrangement = Arrangement.spacedBy(12.dp),
+ ) {
+ Icon(
+ painterResource(R.drawable.logo),
+ modifier = Modifier.size(20.dp),
+ contentDescription = "",
+ tint = Color.Unspecified,
+ )
+ Text(
+ "Record audio clip (up to $MAX_AUDIO_CLIP_DURATION_SEC seconds)",
+ style = MaterialTheme.typography.labelLarge,
+ )
+ }
+ }
+
+ // Recorded clip.
+ Row(
+ modifier = Modifier.fillMaxWidth().padding(vertical = 12.dp).height(40.dp),
+ verticalAlignment = Alignment.CenterVertically,
+ horizontalArrangement = Arrangement.Center,
+ ) {
+ val curRecordedBytes = recordedBytes.value
+ if (curRecordedBytes == null) {
+ // Info message when there is no recorded clip and the recording has not started yet.
+ if (!isRecording) {
+ Text(
+ "Tap the record button to start",
+ style = MaterialTheme.typography.labelLarge,
+ color = MaterialTheme.colorScheme.onSurfaceVariant,
+ )
+ }
+ // Visualization for clip being recorded.
+ else {
+ Row(
+ horizontalArrangement = Arrangement.spacedBy(12.dp),
+ verticalAlignment = Alignment.CenterVertically,
+ ) {
+ Box(
+ modifier =
+ Modifier.size(8.dp)
+ .background(MaterialTheme.customColors.recordButtonBgColor, CircleShape)
+ )
+ Text("$elapsedSeconds s")
+ }
+ }
+ }
+ // Controls for recorded clip.
+ else {
+ Row {
+ // Clip player.
+ AudioPlaybackPanel(
+ audioData = curRecordedBytes,
+ sampleRate = SAMPLE_RATE,
+ isRecording = isRecording,
+ )
+
+ // Button to send the clip
+ IconButton(onClick = { onSendAudioClip(curRecordedBytes) }) {
+ Icon(Icons.Rounded.ArrowUpward, contentDescription = "")
+ }
+ }
+ }
+ }
+
+ // Buttons
+ Box(contentAlignment = Alignment.Center, modifier = Modifier.fillMaxWidth().height(40.dp)) {
+ // Visualization of the current amplitude.
+ if (isRecording) {
+ // Normalize the amplitude (0-32767) to a fraction (0.0-1.0)
+ // We use a power scale (exponent < 1) to make the pulse more visible for lower volumes.
+ val normalizedAmplitude = (currentAmplitude.toFloat() / 32767f).pow(0.35f)
+ // Define the min and max size of the circle
+ val minSize = 38.dp
+ val maxSize = 100.dp
+
+ // Map the normalized amplitude to our size range
+ val scale by
+ remember(normalizedAmplitude) {
+ derivedStateOf { (minSize + (maxSize - minSize) * normalizedAmplitude) / minSize }
+ }
+ Box(
+ modifier =
+ Modifier.size(minSize)
+ .graphicsLayer(scaleX = scale, scaleY = scale, clip = false, alpha = 0.3f)
+ .background(MaterialTheme.customColors.recordButtonBgColor, CircleShape)
+ )
+ }
+
+ // Record/stop button.
+ IconButton(
+ onClick = {
+ coroutineScope.launch {
+ if (!isRecording) {
+ isRecording = true
+ recordedBytes.value = null
+ startRecording(
+ context = context,
+ audioRecordState = audioRecordState,
+ audioStream = audioStream,
+ elapsedMs = elapsedMs,
+ onAmplitudeChanged = { currentAmplitude = it },
+ onMaxDurationReached = {
+ val curRecordedBytes =
+ stopRecording(audioRecordState = audioRecordState, audioStream = audioStream)
+ recordedBytes.value = curRecordedBytes
+ isRecording = false
+ },
+ )
+ } else {
+ val curRecordedBytes =
+ stopRecording(audioRecordState = audioRecordState, audioStream = audioStream)
+ recordedBytes.value = curRecordedBytes
+ isRecording = false
+ }
+ }
+ },
+ modifier =
+ Modifier.clip(CircleShape).background(MaterialTheme.customColors.recordButtonBgColor),
+ ) {
+ Icon(
+ if (isRecording) Icons.Rounded.Stop else Icons.Rounded.Mic,
+ contentDescription = "",
+ tint = Color.White,
+ )
+ }
+ }
+ }
+}
+
+// Permission is checked in parent composable.
+@SuppressLint("MissingPermission")
+private suspend fun startRecording(
+ context: Context,
+ audioRecordState: MutableState<AudioRecord?>,
+ audioStream: ByteArrayOutputStream,
+ elapsedMs: MutableLongState,
+ onAmplitudeChanged: (Int) -> Unit,
+ onMaxDurationReached: () -> Unit,
+) {
+ Log.d(TAG, "Start recording...")
+ val minBufferSize = AudioRecord.getMinBufferSize(SAMPLE_RATE, CHANNEL_CONFIG, AUDIO_FORMAT)
+
+ audioRecordState.value?.release()
+ val recorder =
+ AudioRecord(
+ MediaRecorder.AudioSource.MIC,
+ SAMPLE_RATE,
+ CHANNEL_CONFIG,
+ AUDIO_FORMAT,
+ minBufferSize,
+ )
+
+ audioRecordState.value = recorder
+ val buffer = ByteArray(minBufferSize)
+
+ // The function will only return when the recording is done (when stopRecording is called).
+ coroutineScope {
+ launch(Dispatchers.IO) {
+ recorder.startRecording()
+
+ val startMs = System.currentTimeMillis()
+ elapsedMs.value = 0L
+ while (audioRecordState.value?.recordingState == AudioRecord.RECORDSTATE_RECORDING) {
+ val bytesRead = recorder.read(buffer, 0, buffer.size)
+ if (bytesRead > 0) {
+ val currentAmplitude = calculatePeakAmplitude(buffer = buffer, bytesRead = bytesRead)
+ onAmplitudeChanged(currentAmplitude)
+ audioStream.write(buffer, 0, bytesRead)
+ }
+ elapsedMs.value = System.currentTimeMillis() - startMs
+ if (elapsedMs.value >= MAX_AUDIO_CLIP_DURATION_SEC * 1000) {
+ onMaxDurationReached()
+ break
+ }
+ }
+ }
+ }
+}
+
+private fun stopRecording(
+ audioRecordState: MutableState<AudioRecord?>,
+ audioStream: ByteArrayOutputStream,
+): ByteArray {
+ Log.d(TAG, "Stopping recording...")
+
+ val recorder = audioRecordState.value
+ if (recorder?.recordingState == AudioRecord.RECORDSTATE_RECORDING) {
+ recorder.stop()
+ }
+ recorder?.release()
+ audioRecordState.value = null
+
+ val recordedBytes = audioStream.toByteArray()
+ audioStream.reset()
+ Log.d(TAG, "Stopped. Recorded ${recordedBytes.size} bytes.")
+
+ return recordedBytes
+}
+
+private fun calculatePeakAmplitude(buffer: ByteArray, bytesRead: Int): Int {
+ // Wrap the byte array in a ByteBuffer and set the order to little-endian
+ val shortBuffer =
+ ByteBuffer.wrap(buffer, 0, bytesRead).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer()
+
+ var maxAmplitude = 0
+ // Iterate through the short buffer to find the maximum absolute value
+ while (shortBuffer.hasRemaining()) {
+ val currentSample = abs(shortBuffer.get().toInt())
+ if (currentSample > maxAmplitude) {
+ maxAmplitude = currentSample
+ }
+ }
+ return maxAmplitude
+}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatMessage.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatMessage.kt
index 7eba194..e1f45b0 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatMessage.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatMessage.kt
@@ -17,18 +17,22 @@
package com.google.ai.edge.gallery.ui.common.chat
import android.graphics.Bitmap
+import android.util.Log
import androidx.compose.ui.graphics.ImageBitmap
import androidx.compose.ui.unit.Dp
import com.google.ai.edge.gallery.common.Classification
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.PromptTemplate
+private const val TAG = "AGChatMessage"
+
enum class ChatMessageType {
INFO,
WARNING,
TEXT,
IMAGE,
IMAGE_WITH_HISTORY,
+ AUDIO_CLIP,
LOADING,
CLASSIFICATION,
CONFIG_VALUES_CHANGE,
@@ -121,6 +125,90 @@ class ChatMessageImage(
}
}
+/** Chat message for audio clip. */
+class ChatMessageAudioClip(
+ val audioData: ByteArray,
+ val sampleRate: Int,
+ override val side: ChatSide,
+ override val latencyMs: Float = 0f,
+) : ChatMessage(type = ChatMessageType.AUDIO_CLIP, side = side, latencyMs = latencyMs) {
+ override fun clone(): ChatMessageAudioClip {
+ return ChatMessageAudioClip(
+ audioData = audioData,
+ sampleRate = sampleRate,
+ side = side,
+ latencyMs = latencyMs,
+ )
+ }
+
+ fun genByteArrayForWav(): ByteArray {
+ val header = ByteArray(44)
+
+ val pcmDataSize = audioData.size
+ val wavFileSize = pcmDataSize + 44 // 44 bytes for the header
+ val channels = 1 // Mono
+ val bitsPerSample: Short = 16
+ val byteRate = sampleRate * channels * bitsPerSample / 8
+ Log.d(TAG, "Wav metadata: sampleRate: $sampleRate")
+
+ // RIFF/WAVE header
+ header[0] = 'R'.code.toByte()
+ header[1] = 'I'.code.toByte()
+ header[2] = 'F'.code.toByte()
+ header[3] = 'F'.code.toByte()
+ header[4] = (wavFileSize and 0xff).toByte()
+ header[5] = (wavFileSize shr 8 and 0xff).toByte()
+ header[6] = (wavFileSize shr 16 and 0xff).toByte()
+ header[7] = (wavFileSize shr 24 and 0xff).toByte()
+ header[8] = 'W'.code.toByte()
+ header[9] = 'A'.code.toByte()
+ header[10] = 'V'.code.toByte()
+ header[11] = 'E'.code.toByte()
+ header[12] = 'f'.code.toByte()
+ header[13] = 'm'.code.toByte()
+ header[14] = 't'.code.toByte()
+ header[15] = ' '.code.toByte()
+ header[16] = 16
+ header[17] = 0
+ header[18] = 0
+ header[19] = 0 // Sub-chunk size (16 for PCM)
+ header[20] = 1
+ header[21] = 0 // Audio format (1 for PCM)
+ header[22] = channels.toByte()
+ header[23] = 0 // Number of channels
+ header[24] = (sampleRate and 0xff).toByte()
+ header[25] = (sampleRate shr 8 and 0xff).toByte()
+ header[26] = (sampleRate shr 16 and 0xff).toByte()
+ header[27] = (sampleRate shr 24 and 0xff).toByte()
+ header[28] = (byteRate and 0xff).toByte()
+ header[29] = (byteRate shr 8 and 0xff).toByte()
+ header[30] = (byteRate shr 16 and 0xff).toByte()
+ header[31] = (byteRate shr 24 and 0xff).toByte()
+ header[32] = (channels * bitsPerSample / 8).toByte()
+ header[33] = 0 // Block align
+ header[34] = bitsPerSample.toByte()
+ header[35] = (bitsPerSample.toInt() shr 8 and 0xff).toByte() // Bits per sample
+ header[36] = 'd'.code.toByte()
+ header[37] = 'a'.code.toByte()
+ header[38] = 't'.code.toByte()
+ header[39] = 'a'.code.toByte()
+ header[40] = (pcmDataSize and 0xff).toByte()
+ header[41] = (pcmDataSize shr 8 and 0xff).toByte()
+ header[42] = (pcmDataSize shr 16 and 0xff).toByte()
+ header[43] = (pcmDataSize shr 24 and 0xff).toByte()
+
+ return header + audioData
+ }
+
+ fun getDurationInSeconds(): Float {
+ // PCM 16-bit
+ val bytesPerSample = 2
+ val bytesPerFrame = bytesPerSample * 1 // mono
+ val totalFrames = audioData.size.toFloat() / bytesPerFrame
+ return totalFrames / sampleRate
+ }
+}
+
/** Chat message for images with history. */
class ChatMessageImageWithHistory(
val bitmaps: List<Bitmap>,
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatPanel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatPanel.kt
index 2b8b168..06615e2 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatPanel.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatPanel.kt
@@ -137,6 +137,19 @@ fun ChatPanel(
}
imageMessageCount
}
+ val audioClipMesssageCountToLastconfigChange =
+ remember(messages) {
+ var audioClipMessageCount = 0
+ for (message in messages.reversed()) {
+ if (message is ChatMessageConfigValuesChange) {
+ break
+ }
+ if (message is ChatMessageAudioClip) {
+ audioClipMessageCount++
+ }
+ }
+ audioClipMessageCount
+ }
var curMessage by remember { mutableStateOf("") } // Correct state
val focusManager = LocalFocusManager.current
@@ -342,6 +355,9 @@ fun ChatPanel(
imageHistoryCurIndex = imageHistoryCurIndex,
)
+ // Audio clip.
+ is ChatMessageAudioClip -> MessageBodyAudioClip(message = message)
+
// Classification result
is ChatMessageClassification ->
MessageBodyClassification(
@@ -467,6 +483,22 @@ fun ChatPanel(
)
}
}
+ // Show an info message for ask audio task to get users started.
+ else if (task.type == TaskType.LLM_ASK_AUDIO && messages.isEmpty()) {
+ Column(
+ modifier = Modifier.padding(horizontal = 16.dp).fillMaxSize(),
+ horizontalAlignment = Alignment.CenterHorizontally,
+ verticalArrangement = Arrangement.Center,
+ ) {
+ MessageBodyInfo(
+ ChatMessageInfo(
+ content =
+ "To get started, tap the + icon to add your audio clip. Limited to 1 clip up to 30 seconds long."
+ ),
+ smallFontSize = false,
+ )
+ }
+ }
}
// Chat input
@@ -482,6 +514,7 @@ fun ChatPanel(
isResettingSession = uiState.isResettingSession,
modelPreparing = uiState.preparing,
imageMessageCount = imageMessageCountToLastConfigChange,
+ audioClipMessageCount = audioClipMesssageCountToLastconfigChange,
modelInitializing =
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
textFieldPlaceHolderRes = task.textInputPlaceHolderRes,
@@ -504,7 +537,10 @@ fun ChatPanel(
onStopButtonClicked = onStopButtonClicked,
// showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen,
showPromptTemplatesInMenu = false,
- showImagePickerInMenu = selectedModel.llmSupportImage,
+ showImagePickerInMenu =
+ selectedModel.llmSupportImage && task.type === TaskType.LLM_ASK_IMAGE,
+ showAudioItemsInMenu =
+ selectedModel.llmSupportAudio && task.type === TaskType.LLM_ASK_AUDIO,
showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress,
)
}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatViewModel.kt
index 95785aa..d2f64e2 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatViewModel.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatViewModel.kt
@@ -53,7 +53,7 @@ data class ChatUiState(
)
/** ViewModel responsible for managing the chat UI state and handling chat-related operations. */
-open class ChatViewModel(val task: Task) : ViewModel() {
+abstract class ChatViewModel(val task: Task) : ViewModel() {
private val _uiState = MutableStateFlow(createUiState(task = task))
val uiState = _uiState.asStateFlow()
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewLlmSingleTurnViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageBodyAudioClip.kt
similarity index 54%
rename from Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewLlmSingleTurnViewModel.kt
rename to Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageBodyAudioClip.kt
index 0c8d9dc..e4e3716 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewLlmSingleTurnViewModel.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageBodyAudioClip.kt
@@ -14,8 +14,20 @@
* limitations under the License.
*/
-package com.google.ai.edge.gallery.ui.preview
+package com.google.ai.edge.gallery.ui.common.chat
-import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
+import androidx.compose.foundation.layout.padding
+import androidx.compose.runtime.Composable
+import androidx.compose.ui.Modifier
+import androidx.compose.ui.unit.dp
-class PreviewLlmSingleTurnViewModel : LlmSingleTurnViewModel(task = TASK_TEST1)
+@Composable
+fun MessageBodyAudioClip(message: ChatMessageAudioClip, modifier: Modifier = Modifier) {
+ AudioPlaybackPanel(
+ audioData = message.audioData,
+ sampleRate = message.sampleRate,
+ isRecording = false,
+ modifier = Modifier.padding(end = 16.dp),
+ onDarkBg = true,
+ )
+}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageInputText.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageInputText.kt
index 1e2e056..d958815 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageInputText.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageInputText.kt
@@ -21,6 +21,7 @@ package com.google.ai.edge.gallery.ui.common.chat
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import android.Manifest
import android.content.Context
+import android.content.Intent
import android.content.pm.PackageManager
import android.graphics.Bitmap
import android.graphics.BitmapFactory
@@ -65,9 +66,11 @@ import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.Send
import androidx.compose.material.icons.rounded.Add
+import androidx.compose.material.icons.rounded.AudioFile
import androidx.compose.material.icons.rounded.Close
import androidx.compose.material.icons.rounded.FlipCameraAndroid
import androidx.compose.material.icons.rounded.History
+import androidx.compose.material.icons.rounded.Mic
import androidx.compose.material.icons.rounded.Photo
import androidx.compose.material.icons.rounded.PhotoCamera
import androidx.compose.material.icons.rounded.PostAdd
@@ -107,9 +110,14 @@ import androidx.compose.ui.unit.dp
import androidx.compose.ui.viewinterop.AndroidView
import androidx.core.content.ContextCompat
import androidx.lifecycle.compose.LocalLifecycleOwner
+import com.google.ai.edge.gallery.common.AudioClip
+import com.google.ai.edge.gallery.common.convertWavToMonoWithMaxSeconds
+import com.google.ai.edge.gallery.data.MAX_AUDIO_CLIP_COUNT
import com.google.ai.edge.gallery.data.MAX_IMAGE_COUNT
+import com.google.ai.edge.gallery.data.SAMPLE_RATE
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import java.util.concurrent.Executors
+import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
private const val TAG = "AGMessageInputText"
@@ -128,6 +136,7 @@ fun MessageInputText(
isResettingSession: Boolean,
inProgress: Boolean,
imageMessageCount: Int,
+ audioClipMessageCount: Int,
modelInitializing: Boolean,
@StringRes textFieldPlaceHolderRes: Int,
onValueChanged: (String) -> Unit,
@@ -137,6 +146,7 @@ fun MessageInputText(
onStopButtonClicked: () -> Unit = {},
showPromptTemplatesInMenu: Boolean = false,
showImagePickerInMenu: Boolean = false,
+ showAudioItemsInMenu: Boolean = false,
showStopButtonWhenInProgress: Boolean = false,
) {
val context = LocalContext.current
@@ -146,7 +156,12 @@ fun MessageInputText(
var showTextInputHistorySheet by remember { mutableStateOf(false) }
var showCameraCaptureBottomSheet by remember { mutableStateOf(false) }
val cameraCaptureSheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
+ var showAudioRecorderBottomSheet by remember { mutableStateOf(false) }
+ val audioRecorderSheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
 var pickedImages by remember { mutableStateOf<List<Bitmap>>(listOf()) }
+ var pickedAudioClips by remember { mutableStateOf<List<AudioClip>>(listOf()) }
+ var hasFrontCamera by remember { mutableStateOf(false) }
+
 val updatePickedImages: (List<Bitmap>) -> Unit = { bitmaps ->
 var newPickedImages: MutableList<Bitmap> = mutableListOf()
newPickedImages.addAll(pickedImages)
@@ -156,7 +171,16 @@ fun MessageInputText(
}
pickedImages = newPickedImages.toList()
}
- var hasFrontCamera by remember { mutableStateOf(false) }
+
+ val updatePickedAudioClips: (List<AudioClip>) -> Unit = { audioDataList ->
+ var newAudioDataList: MutableList<AudioClip> = mutableListOf()
+ newAudioDataList.addAll(pickedAudioClips)
+ newAudioDataList.addAll(audioDataList)
+ if (newAudioDataList.size > MAX_AUDIO_CLIP_COUNT) {
+ newAudioDataList = newAudioDataList.subList(fromIndex = 0, toIndex = MAX_AUDIO_CLIP_COUNT)
+ }
+ pickedAudioClips = newAudioDataList.toList()
+ }
LaunchedEffect(Unit) { checkFrontCamera(context = context, callback = { hasFrontCamera = it }) }
@@ -170,6 +194,16 @@ fun MessageInputText(
}
}
+ // Permission request when recording audio clips.
+ val recordAudioClipsPermissionLauncher =
+ rememberLauncherForActivityResult(ActivityResultContracts.RequestPermission()) {
+ permissionGranted ->
+ if (permissionGranted) {
+ showAddContentMenu = false
+ showAudioRecorderBottomSheet = true
+ }
+ }
+
// Registers a photo picker activity launcher in single-select mode.
val pickMedia =
rememberLauncherForActivityResult(ActivityResultContracts.PickMultipleVisualMedia()) { uris ->
@@ -184,9 +218,31 @@ fun MessageInputText(
}
}
+ val pickWav =
+ rememberLauncherForActivityResult(
+ contract = ActivityResultContracts.StartActivityForResult()
+ ) { result ->
+ if (result.resultCode == android.app.Activity.RESULT_OK) {
+ result.data?.data?.let { uri ->
+ Log.d(TAG, "Picked wav file: $uri")
+ scope.launch(Dispatchers.IO) {
+ convertWavToMonoWithMaxSeconds(context = context, stereoUri = uri)?.let { audioClip ->
+ updatePickedAudioClips(
+ listOf(
+ AudioClip(audioData = audioClip.audioData, sampleRate = audioClip.sampleRate)
+ )
+ )
+ }
+ }
+ }
+ } else {
+ Log.d(TAG, "Wav picking cancelled.")
+ }
+ }
+
Column {
- // A preview panel for the selected image.
- if (pickedImages.isNotEmpty()) {
+ // A preview panel for the selected images and audio clips.
+ if (pickedImages.isNotEmpty() || pickedAudioClips.isNotEmpty()) {
Row(
modifier =
Modifier.offset(x = 16.dp).fillMaxWidth().horizontalScroll(rememberScrollState()),
@@ -203,20 +259,30 @@ fun MessageInputText(
.clip(RoundedCornerShape(8.dp))
.border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp)),
)
+ MediaPanelCloseButton { pickedImages = pickedImages.filter { image != it } }
+ }
+ }
+
+ for ((index, audioClip) in pickedAudioClips.withIndex()) {
+ Box(contentAlignment = Alignment.TopEnd) {
Box(
modifier =
- Modifier.offset(x = 10.dp, y = (-10).dp)
- .clip(CircleShape)
+ Modifier.shadow(2.dp, shape = RoundedCornerShape(8.dp))
+ .clip(RoundedCornerShape(8.dp))
.background(MaterialTheme.colorScheme.surface)
- .border((1.5).dp, MaterialTheme.colorScheme.outline, CircleShape)
- .clickable { pickedImages = pickedImages.filter { image != it } }
+ .border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp))
) {
- Icon(
- Icons.Rounded.Close,
- contentDescription = "",
- modifier = Modifier.padding(3.dp).size(16.dp),
+ AudioPlaybackPanel(
+ audioData = audioClip.audioData,
+ sampleRate = audioClip.sampleRate,
+ isRecording = false,
+ modifier = Modifier.padding(end = 16.dp),
)
}
+ MediaPanelCloseButton {
+ pickedAudioClips =
+ pickedAudioClips.filterIndexed { curIndex, curAudioData -> curIndex != index }
+ }
}
}
}
@@ -239,10 +305,13 @@ fun MessageInputText(
verticalAlignment = Alignment.CenterVertically,
) {
val enableAddImageMenuItems = (imageMessageCount + pickedImages.size) < MAX_IMAGE_COUNT
+ val enableRecordAudioClipMenuItems =
+ (audioClipMessageCount + pickedAudioClips.size) < MAX_AUDIO_CLIP_COUNT
DropdownMenu(
expanded = showAddContentMenu,
onDismissRequest = { showAddContentMenu = false },
) {
+ // Image related menu items.
if (showImagePickerInMenu) {
// Take a picture.
DropdownMenuItem(
@@ -295,6 +364,70 @@ fun MessageInputText(
)
}
+ // Audio related menu items.
+ if (showAudioItemsInMenu) {
+ DropdownMenuItem(
+ text = {
+ Row(
+ verticalAlignment = Alignment.CenterVertically,
+ horizontalArrangement = Arrangement.spacedBy(6.dp),
+ ) {
+ Icon(Icons.Rounded.Mic, contentDescription = "")
+ Text("Record audio clip")
+ }
+ },
+ enabled = enableRecordAudioClipMenuItems,
+ onClick = {
+ // Check permission
+ when (PackageManager.PERMISSION_GRANTED) {
+ // Already got permission. Call the lambda.
+ ContextCompat.checkSelfPermission(context, Manifest.permission.RECORD_AUDIO) -> {
+ showAddContentMenu = false
+ showAudioRecorderBottomSheet = true
+ }
+
+ // Otherwise, ask for permission
+ else -> {
+ recordAudioClipsPermissionLauncher.launch(Manifest.permission.RECORD_AUDIO)
+ }
+ }
+ },
+ )
+
+ DropdownMenuItem(
+ text = {
+ Row(
+ verticalAlignment = Alignment.CenterVertically,
+ horizontalArrangement = Arrangement.spacedBy(6.dp),
+ ) {
+ Icon(Icons.Rounded.AudioFile, contentDescription = "")
+ Text("Pick wav file")
+ }
+ },
+ enabled = enableRecordAudioClipMenuItems,
+ onClick = {
+ showAddContentMenu = false
+
+ // Show file picker.
+ val intent =
+ Intent(Intent.ACTION_GET_CONTENT).apply {
+ addCategory(Intent.CATEGORY_OPENABLE)
+ type = "audio/*"
+
+ // Provide a list of more specific MIME types to filter for.
+ val mimeTypes = arrayOf("audio/wav", "audio/x-wav")
+ putExtra(Intent.EXTRA_MIME_TYPES, mimeTypes)
+
+ // Single select.
+ putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
+ .addFlags(Intent.FLAG_GRANT_PERSISTABLE_URI_PERMISSION)
+ .addFlags(Intent.FLAG_GRANT_READ_URI_PERMISSION)
+ }
+ pickWav.launch(intent)
+ },
+ )
+ }
+
// Prompt templates.
if (showPromptTemplatesInMenu) {
DropdownMenuItem(
@@ -369,15 +502,22 @@ fun MessageInputText(
)
}
}
- } // Send button. Only shown when text is not empty.
- else if (curMessage.isNotEmpty()) {
+ }
+ // Send button. Only shown when text is not empty, or there is at least one recorded
+ // audio clip.
+ else if (curMessage.isNotEmpty() || pickedAudioClips.isNotEmpty()) {
IconButton(
enabled = !inProgress && !isResettingSession,
onClick = {
onSendMessage(
- createMessagesToSend(pickedImages = pickedImages, text = curMessage.trim())
+ createMessagesToSend(
+ pickedImages = pickedImages,
+ audioClips = pickedAudioClips,
+ text = curMessage.trim(),
+ )
)
pickedImages = listOf()
+ pickedAudioClips = listOf()
},
colors =
IconButtonDefaults.iconButtonColors(
@@ -403,8 +543,15 @@ fun MessageInputText(
history = modelManagerUiState.textInputHistory,
onDismissed = { showTextInputHistorySheet = false },
onHistoryItemClicked = { item ->
- onSendMessage(createMessagesToSend(pickedImages = pickedImages, text = item))
+ onSendMessage(
+ createMessagesToSend(
+ pickedImages = pickedImages,
+ audioClips = pickedAudioClips,
+ text = item,
+ )
+ )
pickedImages = listOf()
+ pickedAudioClips = listOf()
modelManagerViewModel.promoteTextInputHistoryItem(item)
},
onHistoryItemDeleted = { item -> modelManagerViewModel.deleteTextInputHistory(item) },
@@ -582,6 +729,43 @@ fun MessageInputText(
}
}
}
+
+ if (showAudioRecorderBottomSheet) {
+ ModalBottomSheet(
+ sheetState = audioRecorderSheetState,
+ onDismissRequest = { showAudioRecorderBottomSheet = false },
+ ) {
+ AudioRecorderPanel(
+ onSendAudioClip = { audioData ->
+ scope.launch {
+ updatePickedAudioClips(
+ listOf(AudioClip(audioData = audioData, sampleRate = SAMPLE_RATE))
+ )
+ audioRecorderSheetState.hide()
+ showAudioRecorderBottomSheet = false
+ }
+ }
+ )
+ }
+ }
+}
+
+@Composable
+private fun MediaPanelCloseButton(onClicked: () -> Unit) {
+ Box(
+ modifier =
+ Modifier.offset(x = 10.dp, y = (-10).dp)
+ .clip(CircleShape)
+ .background(MaterialTheme.colorScheme.surface)
+ .border((1.5).dp, MaterialTheme.colorScheme.outline, CircleShape)
+ .clickable { onClicked() }
+ ) {
+ Icon(
+ Icons.Rounded.Close,
+ contentDescription = "",
+ modifier = Modifier.padding(3.dp).size(16.dp),
+ )
+ }
}
private fun handleImagesSelected(
@@ -641,20 +825,50 @@ private fun checkFrontCamera(context: Context, callback: (Boolean) -> Unit) {
)
}
-private fun createMessagesToSend(pickedImages: List<Bitmap>, text: String): List<ChatMessage> {
+private fun createMessagesToSend(
+ pickedImages: List<Bitmap>,
+ audioClips: List<AudioClip>,
+ text: String,
+): List<ChatMessage> {
 var messages: MutableList<ChatMessage> = mutableListOf()
+
+ // Add image messages.
+ var imageMessages: MutableList<ChatMessage> = mutableListOf()
if (pickedImages.isNotEmpty()) {
for (image in pickedImages) {
- messages.add(
+ imageMessages.add(
ChatMessageImage(bitmap = image, imageBitMap = image.asImageBitmap(), side = ChatSide.USER)
)
}
}
// Cap the number of image messages.
- if (messages.size > MAX_IMAGE_COUNT) {
- messages = messages.subList(fromIndex = 0, toIndex = MAX_IMAGE_COUNT)
+ if (imageMessages.size > MAX_IMAGE_COUNT) {
+ imageMessages = imageMessages.subList(fromIndex = 0, toIndex = MAX_IMAGE_COUNT)
+ }
+ messages.addAll(imageMessages)
+
+ // Add audio messages.
+ var audioMessages: MutableList<ChatMessage> = mutableListOf()
+ if (audioClips.isNotEmpty()) {
+ for (audioClip in audioClips) {
+ audioMessages.add(
+ ChatMessageAudioClip(
+ audioData = audioClip.audioData,
+ sampleRate = audioClip.sampleRate,
+ side = ChatSide.USER,
+ )
+ )
+ }
+ }
+ // Cap the number of audio messages.
+ if (audioMessages.size > MAX_AUDIO_CLIP_COUNT) {
+ audioMessages = audioMessages.subList(fromIndex = 0, toIndex = MAX_AUDIO_CLIP_COUNT)
+ }
+ messages.addAll(audioMessages)
+
+ if (text.isNotEmpty()) {
+ messages.add(ChatMessageText(content = text, side = ChatSide.USER))
}
- messages.add(ChatMessageText(content = text, side = ChatSide.USER))
return messages
}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/ModelImportDialog.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/ModelImportDialog.kt
index 1f0cd00..9c985e5 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/ModelImportDialog.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/ModelImportDialog.kt
@@ -121,6 +121,7 @@ private val IMPORT_CONFIGS_LLM: List<Config> =
valueType = ValueType.FLOAT,
),
BooleanSwitchConfig(key = ConfigKey.SUPPORT_IMAGE, defaultValue = false),
+ BooleanSwitchConfig(key = ConfigKey.SUPPORT_AUDIO, defaultValue = false),
SegmentedButtonConfig(
key = ConfigKey.COMPATIBLE_ACCELERATORS,
defaultValue = Accelerator.CPU.label,
@@ -230,6 +231,12 @@ fun ModelImportDialog(uri: Uri, onDismiss: () -> Unit, onDone: (ImportedModel) -
valueType = ValueType.BOOLEAN,
)
as Boolean
+ val supportAudio =
+ convertValueToTargetType(
+ value = values.get(ConfigKey.SUPPORT_AUDIO.label)!!,
+ valueType = ValueType.BOOLEAN,
+ )
+ as Boolean
val importedModel: ImportedModel =
ImportedModel.newBuilder()
.setFileName(fileName)
@@ -242,6 +249,7 @@ fun ModelImportDialog(uri: Uri, onDismiss: () -> Unit, onDone: (ImportedModel) -
.setDefaultTopp(defaultTopp)
.setDefaultTemperature(defaultTemperature)
.setSupportImage(supportImage)
+ .setSupportAudio(supportAudio)
.build()
)
.build()
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/SettingsDialog.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/SettingsDialog.kt
index ee418bb..dc0bba1 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/SettingsDialog.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/SettingsDialog.kt
@@ -173,7 +173,7 @@ fun SettingsDialog(
color = MaterialTheme.colorScheme.onSurfaceVariant,
)
Text(
- "Expired at: ${dateFormatter.format(Instant.ofEpochMilli(curHfToken.expiresAtMs))}",
+ "Expires at: ${dateFormatter.format(Instant.ofEpochMilli(curHfToken.expiresAtMs))}",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant,
)
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Forum.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Forum.kt
deleted file mode 100644
index 173ca4f..0000000
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Forum.kt
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Copyright 2025 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.ai.edge.gallery.ui.icon
-
-import androidx.compose.ui.graphics.Color
-import androidx.compose.ui.graphics.SolidColor
-import androidx.compose.ui.graphics.vector.ImageVector
-import androidx.compose.ui.graphics.vector.path
-import androidx.compose.ui.unit.dp
-
-val Forum: ImageVector
- get() {
- if (_Forum != null) return _Forum!!
-
- _Forum =
- ImageVector.Builder(
- name = "Forum",
- defaultWidth = 24.dp,
- defaultHeight = 24.dp,
- viewportWidth = 960f,
- viewportHeight = 960f,
- )
- .apply {
- path(fill = SolidColor(Color(0xFF000000))) {
- moveTo(280f, 720f)
- quadToRelative(-17f, 0f, -28.5f, -11.5f)
- reflectiveQuadTo(240f, 680f)
- verticalLineToRelative(-80f)
- horizontalLineToRelative(520f)
- verticalLineToRelative(-360f)
- horizontalLineToRelative(80f)
- quadToRelative(17f, 0f, 28.5f, 11.5f)
- reflectiveQuadTo(880f, 280f)
- verticalLineToRelative(600f)
- lineTo(720f, 720f)
- close()
- moveTo(80f, 680f)
- verticalLineToRelative(-560f)
- quadToRelative(0f, -17f, 11.5f, -28.5f)
- reflectiveQuadTo(120f, 80f)
- horizontalLineToRelative(520f)
- quadToRelative(17f, 0f, 28.5f, 11.5f)
- reflectiveQuadTo(680f, 120f)
- verticalLineToRelative(360f)
- quadToRelative(0f, 17f, -11.5f, 28.5f)
- reflectiveQuadTo(640f, 520f)
- horizontalLineTo(240f)
- close()
- moveToRelative(520f, -240f)
- verticalLineToRelative(-280f)
- horizontalLineTo(160f)
- verticalLineToRelative(280f)
- close()
- moveToRelative(-440f, 0f)
- verticalLineToRelative(-280f)
- close()
- }
- }
- .build()
-
- return _Forum!!
- }
-
-private var _Forum: ImageVector? = null
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Mms.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Mms.kt
deleted file mode 100644
index 56f9990..0000000
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Mms.kt
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Copyright 2025 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.ai.edge.gallery.ui.icon
-
-import androidx.compose.ui.graphics.Color
-import androidx.compose.ui.graphics.SolidColor
-import androidx.compose.ui.graphics.vector.ImageVector
-import androidx.compose.ui.graphics.vector.path
-import androidx.compose.ui.unit.dp
-
-val Mms: ImageVector
- get() {
- if (_Mms != null) return _Mms!!
-
- _Mms =
- ImageVector.Builder(
- name = "Mms",
- defaultWidth = 24.dp,
- defaultHeight = 24.dp,
- viewportWidth = 960f,
- viewportHeight = 960f,
- )
- .apply {
- path(fill = SolidColor(Color(0xFF000000))) {
- moveTo(240f, 560f)
- horizontalLineToRelative(480f)
- lineTo(570f, 360f)
- lineTo(450f, 520f)
- lineToRelative(-90f, -120f)
- close()
- moveTo(80f, 880f)
- verticalLineToRelative(-720f)
- quadToRelative(0f, -33f, 23.5f, -56.5f)
- reflectiveQuadTo(160f, 80f)
- horizontalLineToRelative(640f)
- quadToRelative(33f, 0f, 56.5f, 23.5f)
- reflectiveQuadTo(880f, 160f)
- verticalLineToRelative(480f)
- quadToRelative(0f, 33f, -23.5f, 56.5f)
- reflectiveQuadTo(800f, 720f)
- horizontalLineTo(240f)
- close()
- moveToRelative(126f, -240f)
- horizontalLineToRelative(594f)
- verticalLineToRelative(-480f)
- horizontalLineTo(160f)
- verticalLineToRelative(525f)
- close()
- moveToRelative(-46f, 0f)
- verticalLineToRelative(-480f)
- close()
- }
- }
- .build()
-
- return _Mms!!
- }
-
-private var _Mms: ImageVector? = null
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Widgets.kt.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Widgets.kt.kt
deleted file mode 100644
index c727a7b..0000000
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Widgets.kt.kt
+++ /dev/null
@@ -1,87 +0,0 @@
-/*
- * Copyright 2025 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.ai.edge.gallery.ui.icon
-
-import androidx.compose.ui.graphics.Color
-import androidx.compose.ui.graphics.SolidColor
-import androidx.compose.ui.graphics.vector.ImageVector
-import androidx.compose.ui.graphics.vector.path
-import androidx.compose.ui.unit.dp
-
-val Widgets: ImageVector
- get() {
- if (_Widgets != null) return _Widgets!!
-
- _Widgets =
- ImageVector.Builder(
- name = "Widgets",
- defaultWidth = 24.dp,
- defaultHeight = 24.dp,
- viewportWidth = 960f,
- viewportHeight = 960f,
- )
- .apply {
- path(fill = SolidColor(Color(0xFF000000))) {
- moveTo(666f, 520f)
- lineTo(440f, 294f)
- lineToRelative(226f, -226f)
- lineToRelative(226f, 226f)
- close()
- moveToRelative(-546f, -80f)
- verticalLineToRelative(-320f)
- horizontalLineToRelative(320f)
- verticalLineToRelative(320f)
- close()
- moveToRelative(400f, 400f)
- verticalLineToRelative(-320f)
- horizontalLineToRelative(320f)
- verticalLineToRelative(320f)
- close()
- moveToRelative(-400f, 0f)
- verticalLineToRelative(-320f)
- horizontalLineToRelative(320f)
- verticalLineToRelative(320f)
- close()
- moveToRelative(80f, -480f)
- horizontalLineToRelative(160f)
- verticalLineToRelative(-160f)
- horizontalLineTo(200f)
- close()
- moveToRelative(467f, 48f)
- lineToRelative(113f, -113f)
- lineToRelative(-113f, -113f)
- lineToRelative(-113f, 113f)
- close()
- moveToRelative(-67f, 352f)
- horizontalLineToRelative(160f)
- verticalLineToRelative(-160f)
- horizontalLineTo(600f)
- close()
- moveToRelative(-400f, 0f)
- horizontalLineToRelative(160f)
- verticalLineToRelative(-160f)
- horizontalLineTo(200f)
- close()
- moveToRelative(400f, -160f)
- }
- }
- .build()
-
- return _Widgets!!
- }
-
-private var _Widgets: ImageVector? = null
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt
index 128baa8..a330c35 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt
@@ -62,13 +62,13 @@ object LlmChatModelHelper {
Accelerator.GPU.label -> LlmInference.Backend.GPU
else -> LlmInference.Backend.GPU
}
- val options =
+ val optionsBuilder =
LlmInference.LlmInferenceOptions.builder()
.setModelPath(model.getPath(context = context))
.setMaxTokens(maxTokens)
.setPreferredBackend(preferredBackend)
.setMaxNumImages(if (model.llmSupportImage) MAX_IMAGE_COUNT else 0)
- .build()
+ val options = optionsBuilder.build()
// Create an instance of the LLM Inference task and session.
try {
@@ -82,7 +82,9 @@ object LlmChatModelHelper {
.setTopP(topP)
.setTemperature(temperature)
.setGraphOptions(
- GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
+ GraphOptions.builder()
+ .setEnableVisionModality(model.llmSupportImage)
+ .build()
)
.build(),
)
@@ -115,7 +117,9 @@ object LlmChatModelHelper {
.setTopP(topP)
.setTemperature(temperature)
.setGraphOptions(
- GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
+ GraphOptions.builder()
+ .setEnableVisionModality(model.llmSupportImage)
+ .build()
)
.build(),
)
@@ -159,6 +163,7 @@ object LlmChatModelHelper {
resultListener: ResultListener,
cleanUpListener: CleanUpListener,
 images: List<Bitmap> = listOf(),
+ audioClips: List<ByteArray> = listOf(),
) {
val instance = model.instance as LlmModelInstance
@@ -172,10 +177,16 @@ object LlmChatModelHelper {
// For a model that supports image modality, we need to add the text query chunk before adding
// image.
val session = instance.session
- session.addQueryChunk(input)
+ if (input.trim().isNotEmpty()) {
+ session.addQueryChunk(input)
+ }
for (image in images) {
session.addImage(BitmapImageBuilder(image).build())
}
+ for (audioClip in audioClips) {
+ // Uncomment when audio is supported.
+ // session.addAudio(audioClip)
+ }
val unused = session.generateResponseAsync(resultListener)
}
}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt
index 0a97f3f..e61e2eb 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt
@@ -20,8 +20,7 @@ import android.graphics.Bitmap
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
-import androidx.lifecycle.viewmodel.compose.viewModel
-import com.google.ai.edge.gallery.ui.ViewModelProvider
+import com.google.ai.edge.gallery.ui.common.chat.ChatMessageAudioClip
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
import com.google.ai.edge.gallery.ui.common.chat.ChatView
@@ -36,12 +35,16 @@ object LlmAskImageDestination {
val route = "LlmAskImageRoute"
}
+object LlmAskAudioDestination {
+ val route = "LlmAskAudioRoute"
+}
+
@Composable
fun LlmChatScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
- viewModel: LlmChatViewModel = viewModel(factory = ViewModelProvider.Factory),
+ viewModel: LlmChatViewModel,
) {
ChatViewWrapper(
viewModel = viewModel,
@@ -56,7 +59,22 @@ fun LlmAskImageScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
- viewModel: LlmAskImageViewModel = viewModel(factory = ViewModelProvider.Factory),
+ viewModel: LlmAskImageViewModel,
+) {
+ ChatViewWrapper(
+ viewModel = viewModel,
+ modelManagerViewModel = modelManagerViewModel,
+ navigateUp = navigateUp,
+ modifier = modifier,
+ )
+}
+
+@Composable
+fun LlmAskAudioScreen(
+ modelManagerViewModel: ModelManagerViewModel,
+ navigateUp: () -> Unit,
+ modifier: Modifier = Modifier,
+ viewModel: LlmAskAudioViewModel,
) {
ChatViewWrapper(
viewModel = viewModel,
@@ -68,7 +86,7 @@ fun LlmAskImageScreen(
@Composable
fun ChatViewWrapper(
- viewModel: LlmChatViewModel,
+ viewModel: LlmChatViewModelBase,
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
@@ -86,6 +104,7 @@ fun ChatViewWrapper(
var text = ""
 val images: MutableList<Bitmap> = mutableListOf()
+ val audioMessages: MutableList<ChatMessageAudioClip> = mutableListOf()
var chatMessageText: ChatMessageText? = null
for (message in messages) {
if (message is ChatMessageText) {
@@ -93,14 +112,17 @@ fun ChatViewWrapper(
text = message.content
} else if (message is ChatMessageImage) {
images.add(message.bitmap)
+ } else if (message is ChatMessageAudioClip) {
+ audioMessages.add(message)
}
}
- if (text.isNotEmpty() && chatMessageText != null) {
+ if ((text.isNotEmpty() && chatMessageText != null) || audioMessages.isNotEmpty()) {
modelManagerViewModel.addTextInputHistory(text)
viewModel.generateResponse(
model = model,
input = text,
images = images,
+ audioMessages = audioMessages,
onError = {
viewModel.handleError(
context = context,
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt
index 7b0a083..205d4e0 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt
@@ -22,9 +22,11 @@ import android.util.Log
import androidx.lifecycle.viewModelScope
import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.Model
+import com.google.ai.edge.gallery.data.TASK_LLM_ASK_AUDIO
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
import com.google.ai.edge.gallery.data.Task
+import com.google.ai.edge.gallery.ui.common.chat.ChatMessageAudioClip
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageLoading
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
@@ -34,6 +36,8 @@ import com.google.ai.edge.gallery.ui.common.chat.ChatSide
import com.google.ai.edge.gallery.ui.common.chat.ChatViewModel
import com.google.ai.edge.gallery.ui.common.chat.Stat
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
+import dagger.hilt.android.lifecycle.HiltViewModel
+import javax.inject.Inject
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
@@ -47,11 +51,12 @@ private val STATS =
Stat(id = "latency", label = "Latency", unit = "sec"),
)
-open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task = curTask) {
+open class LlmChatViewModelBase(val curTask: Task) : ChatViewModel(task = curTask) {
fun generateResponse(
model: Model,
input: String,
 images: List<Bitmap> = listOf(),
+ audioMessages: List<ChatMessageAudioClip> = listOf(),
onError: () -> Unit,
) {
val accelerator = model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = "")
@@ -72,6 +77,11 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
val instance = model.instance as LlmModelInstance
var prefillTokens = instance.session.sizeInTokens(input)
prefillTokens += images.size * 257
+ for (audioMessage in audioMessages) {
+ // 150ms = 1 audio token
+ val duration = audioMessage.getDurationInSeconds()
+ prefillTokens += (duration * 1000f / 150f).toInt()
+ }
var firstRun = true
var timeToFirstToken = 0f
@@ -86,6 +96,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
model = model,
input = input,
images = images,
+ audioClips = audioMessages.map { it.genByteArrayForWav() },
resultListener = { partialResult, done ->
val curTs = System.currentTimeMillis()
@@ -214,7 +225,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
context: Context,
model: Model,
modelManagerViewModel: ModelManagerViewModel,
- triggeredMessage: ChatMessageText,
+ triggeredMessage: ChatMessageText?,
) {
// Clean up.
modelManagerViewModel.cleanupModel(task = task, model = model)
@@ -236,14 +247,27 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
)
// Add the triggered message back.
- addMessage(model = model, message = triggeredMessage)
+ if (triggeredMessage != null) {
+ addMessage(model = model, message = triggeredMessage)
+ }
// Re-initialize the session/engine.
modelManagerViewModel.initializeModel(context = context, task = task, model = model)
// Re-generate the response automatically.
- generateResponse(model = model, input = triggeredMessage.content, onError = {})
+ if (triggeredMessage != null) {
+ generateResponse(model = model, input = triggeredMessage.content, onError = {})
+ }
}
}
-class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE)
+@HiltViewModel
+class LlmChatViewModel @Inject constructor() : LlmChatViewModelBase(curTask = TASK_LLM_CHAT)
+
+@HiltViewModel
+class LlmAskImageViewModel @Inject constructor() :
+ LlmChatViewModelBase(curTask = TASK_LLM_ASK_IMAGE)
+
+@HiltViewModel
+class LlmAskAudioViewModel @Inject constructor() :
+ LlmChatViewModelBase(curTask = TASK_LLM_ASK_AUDIO)
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt
index 2759ba4..f62e2f8 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt
@@ -43,9 +43,8 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalLayoutDirection
-import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.ai.edge.gallery.data.ModelDownloadStatusType
-import com.google.ai.edge.gallery.ui.ViewModelProvider
+import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB
import com.google.ai.edge.gallery.ui.common.ErrorDialog
import com.google.ai.edge.gallery.ui.common.ModelPageAppBar
import com.google.ai.edge.gallery.ui.common.chat.ModelDownloadStatusInfoPanel
@@ -67,9 +66,9 @@ fun LlmSingleTurnScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
- viewModel: LlmSingleTurnViewModel = viewModel(factory = ViewModelProvider.Factory),
+ viewModel: LlmSingleTurnViewModel,
) {
- val task = viewModel.task
+ val task = TASK_LLM_PROMPT_LAB
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val uiState by viewModel.uiState.collectAsState()
val selectedModel = modelManagerUiState.selectedModel
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt
index 88414d6..861d5fd 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt
@@ -27,6 +27,8 @@ import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.ai.edge.gallery.ui.common.chat.Stat
import com.google.ai.edge.gallery.ui.llmchat.LlmChatModelHelper
import com.google.ai.edge.gallery.ui.llmchat.LlmModelInstance
+import dagger.hilt.android.lifecycle.HiltViewModel
+import javax.inject.Inject
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow
@@ -63,8 +65,9 @@ private val STATS =
Stat(id = "latency", label = "Latency", unit = "sec"),
)
-open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewModel() {
- private val _uiState = MutableStateFlow(createUiState(task = task))
+@HiltViewModel
+class LlmSingleTurnViewModel @Inject constructor() : ViewModel() {
+ private val _uiState = MutableStateFlow(createUiState(task = TASK_LLM_PROMPT_LAB))
val uiState = _uiState.asStateFlow()
fun generateResponse(model: Model, input: String) {
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt
index f73b893..b8b4eee 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt
@@ -36,6 +36,7 @@ import com.google.ai.edge.gallery.data.ModelAllowlist
import com.google.ai.edge.gallery.data.ModelDownloadStatus
import com.google.ai.edge.gallery.data.ModelDownloadStatusType
import com.google.ai.edge.gallery.data.TASKS
+import com.google.ai.edge.gallery.data.TASK_LLM_ASK_AUDIO
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB
@@ -51,9 +52,12 @@ import com.google.ai.edge.gallery.ui.common.AuthConfig
import com.google.ai.edge.gallery.ui.llmchat.LlmChatModelHelper
import com.google.gson.Gson
import com.google.gson.reflect.TypeToken
+import dagger.hilt.android.lifecycle.HiltViewModel
+import dagger.hilt.android.qualifiers.ApplicationContext
import java.io.File
import java.net.HttpURLConnection
import java.net.URL
+import javax.inject.Inject
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow
@@ -133,11 +137,14 @@ data class PagerScrollState(val page: Int = 0, val offset: Float = 0f)
* cleaning up models. It also manages the UI state for model management, including the list of
* tasks, models, download statuses, and initialization statuses.
*/
-open class ModelManagerViewModel(
+@HiltViewModel
+open class ModelManagerViewModel
+@Inject
+constructor(
private val downloadRepository: DownloadRepository,
private val dataStoreRepository: DataStoreRepository,
private val lifecycleProvider: AppLifecycleProvider,
- context: Context,
+ @ApplicationContext private val context: Context,
) : ViewModel() {
private val externalFilesDir = context.getExternalFilesDir(null)
private val inProgressWorkInfos: List =
@@ -281,15 +288,12 @@ open class ModelManagerViewModel(
}
}
when (task.type) {
- TaskType.LLM_CHAT ->
- LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone)
-
+ TaskType.LLM_CHAT,
+ TaskType.LLM_ASK_IMAGE,
+ TaskType.LLM_ASK_AUDIO,
TaskType.LLM_PROMPT_LAB ->
LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone)
- TaskType.LLM_ASK_IMAGE ->
- LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone)
-
TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {}
}
@@ -301,9 +305,11 @@ open class ModelManagerViewModel(
model.cleanUpAfterInit = false
Log.d(TAG, "Cleaning up model '${model.name}'...")
when (task.type) {
- TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model)
- TaskType.LLM_PROMPT_LAB -> LlmChatModelHelper.cleanUp(model = model)
- TaskType.LLM_ASK_IMAGE -> LlmChatModelHelper.cleanUp(model = model)
+ TaskType.LLM_CHAT,
+ TaskType.LLM_PROMPT_LAB,
+ TaskType.LLM_ASK_IMAGE,
+ TaskType.LLM_ASK_AUDIO -> LlmChatModelHelper.cleanUp(model = model)
+
TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {}
}
@@ -410,14 +416,19 @@ open class ModelManagerViewModel(
// Create model.
val model = createModelFromImportedModelInfo(info = info)
- for (task in listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)) {
+ for (task in
+ listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_ASK_AUDIO, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)) {
// Remove duplicated imported model if existed.
val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported }
if (modelIndex >= 0) {
Log.d(TAG, "duplicated imported model found in task. Removing it first")
task.models.removeAt(modelIndex)
}
- if ((task == TASK_LLM_ASK_IMAGE && model.llmSupportImage) || task != TASK_LLM_ASK_IMAGE) {
+ if (
+ (task == TASK_LLM_ASK_IMAGE && model.llmSupportImage) ||
+ (task == TASK_LLM_ASK_AUDIO && model.llmSupportAudio) ||
+ (task != TASK_LLM_ASK_IMAGE && task != TASK_LLM_ASK_AUDIO)
+ ) {
task.models.add(model)
}
task.updateTrigger.value = System.currentTimeMillis()
@@ -657,6 +668,7 @@ open class ModelManagerViewModel(
TASK_LLM_CHAT.models.clear()
TASK_LLM_PROMPT_LAB.models.clear()
TASK_LLM_ASK_IMAGE.models.clear()
+ TASK_LLM_ASK_AUDIO.models.clear()
for (allowedModel in modelAllowlist.models) {
if (allowedModel.disabled == true) {
continue
@@ -672,6 +684,9 @@ open class ModelManagerViewModel(
if (allowedModel.taskTypes.contains(TASK_LLM_ASK_IMAGE.type.id)) {
TASK_LLM_ASK_IMAGE.models.add(model)
}
+ if (allowedModel.taskTypes.contains(TASK_LLM_ASK_AUDIO.type.id)) {
+ TASK_LLM_ASK_AUDIO.models.add(model)
+ }
}
// Pre-process all tasks.
@@ -760,6 +775,9 @@ open class ModelManagerViewModel(
if (model.llmSupportImage) {
TASK_LLM_ASK_IMAGE.models.add(model)
}
+ if (model.llmSupportAudio) {
+ TASK_LLM_ASK_AUDIO.models.add(model)
+ }
// Update status.
modelDownloadStatus[model.name] =
@@ -800,6 +818,7 @@ open class ModelManagerViewModel(
accelerators = accelerators,
)
val llmSupportImage = info.llmConfig.supportImage
+ val llmSupportAudio = info.llmConfig.supportAudio
val model =
Model(
name = info.fileName,
@@ -811,6 +830,7 @@ open class ModelManagerViewModel(
showRunAgainButton = false,
imported = true,
llmSupportImage = llmSupportImage,
+ llmSupportAudio = llmSupportAudio,
)
model.preProcess()
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/navigation/GalleryNavGraph.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/navigation/GalleryNavGraph.kt
index 5e8672d..e4ebd48 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/navigation/GalleryNavGraph.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/navigation/GalleryNavGraph.kt
@@ -36,10 +36,10 @@ import androidx.compose.runtime.setValue
import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.IntOffset
import androidx.compose.ui.zIndex
+import androidx.hilt.navigation.compose.hiltViewModel
import androidx.lifecycle.Lifecycle
import androidx.lifecycle.LifecycleEventObserver
import androidx.lifecycle.compose.LocalLifecycleOwner
-import androidx.lifecycle.viewmodel.compose.viewModel
import androidx.navigation.NavBackStackEntry
import androidx.navigation.NavHostController
import androidx.navigation.NavType
@@ -47,20 +47,26 @@ import androidx.navigation.compose.NavHost
import androidx.navigation.compose.composable
import androidx.navigation.navArgument
import com.google.ai.edge.gallery.data.Model
+import com.google.ai.edge.gallery.data.TASK_LLM_ASK_AUDIO
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB
import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.data.TaskType
import com.google.ai.edge.gallery.data.getModelByName
-import com.google.ai.edge.gallery.ui.ViewModelProvider
import com.google.ai.edge.gallery.ui.home.HomeScreen
+import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioDestination
+import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioScreen
+import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageDestination
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageScreen
+import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmChatDestination
import com.google.ai.edge.gallery.ui.llmchat.LlmChatScreen
+import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnDestination
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnScreen
+import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
import com.google.ai.edge.gallery.ui.modelmanager.ModelManager
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
@@ -104,7 +110,7 @@ private fun AnimatedContentTransitionScope<*>.slideExit(): ExitTransition {
fun GalleryNavHost(
navController: NavHostController,
modifier: Modifier = Modifier,
- modelManagerViewModel: ModelManagerViewModel = viewModel(factory = ViewModelProvider.Factory),
+ modelManagerViewModel: ModelManagerViewModel = hiltViewModel(),
) {
val lifecycleOwner = LocalLifecycleOwner.current
var showModelManager by remember { mutableStateOf(false) }
@@ -181,11 +187,14 @@ fun GalleryNavHost(
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() },
exitTransition = { slideExit() },
- ) {
- getModelFromNavigationParam(it, TASK_LLM_CHAT)?.let { defaultModel ->
+ ) { backStackEntry ->
+ val viewModel: LlmChatViewModel = hiltViewModel(backStackEntry)
+
+ getModelFromNavigationParam(backStackEntry, TASK_LLM_CHAT)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel)
LlmChatScreen(
+ viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() },
)
@@ -198,28 +207,54 @@ fun GalleryNavHost(
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() },
exitTransition = { slideExit() },
- ) {
- getModelFromNavigationParam(it, TASK_LLM_PROMPT_LAB)?.let { defaultModel ->
+ ) { backStackEntry ->
+ val viewModel: LlmSingleTurnViewModel = hiltViewModel(backStackEntry)
+
+ getModelFromNavigationParam(backStackEntry, TASK_LLM_PROMPT_LAB)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel)
LlmSingleTurnScreen(
+ viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() },
)
}
}
- // LLM image to text.
+ // Ask image.
composable(
route = "${LlmAskImageDestination.route}/{modelName}",
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() },
exitTransition = { slideExit() },
- ) {
- getModelFromNavigationParam(it, TASK_LLM_ASK_IMAGE)?.let { defaultModel ->
+ ) { backStackEntry ->
+ val viewModel: LlmAskImageViewModel = hiltViewModel()
+
+ getModelFromNavigationParam(backStackEntry, TASK_LLM_ASK_IMAGE)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel)
LlmAskImageScreen(
+ viewModel = viewModel,
+ modelManagerViewModel = modelManagerViewModel,
+ navigateUp = { navController.navigateUp() },
+ )
+ }
+ }
+
+ // Ask audio.
+ composable(
+ route = "${LlmAskAudioDestination.route}/{modelName}",
+ arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
+ enterTransition = { slideEnter() },
+ exitTransition = { slideExit() },
+ ) { backStackEntry ->
+ val viewModel: LlmAskAudioViewModel = hiltViewModel()
+
+ getModelFromNavigationParam(backStackEntry, TASK_LLM_ASK_AUDIO)?.let { defaultModel ->
+ modelManagerViewModel.selectModel(defaultModel)
+
+ LlmAskAudioScreen(
+ viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() },
)
@@ -256,6 +291,7 @@ fun navigateToTaskScreen(
when (taskType) {
TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
TaskType.LLM_ASK_IMAGE -> navController.navigate("${LlmAskImageDestination.route}/${modelName}")
+ TaskType.LLM_ASK_AUDIO -> navController.navigate("${LlmAskAudioDestination.route}/${modelName}")
TaskType.LLM_PROMPT_LAB ->
navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
TaskType.TEST_TASK_1 -> {}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewChatModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewChatModel.kt
deleted file mode 100644
index bab07cb..0000000
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewChatModel.kt
+++ /dev/null
@@ -1,91 +0,0 @@
-/*
- * Copyright 2025 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.ai.edge.gallery.ui.preview
-
-import android.content.Context
-import android.graphics.Bitmap
-import android.graphics.Canvas
-import android.graphics.drawable.Drawable
-import androidx.compose.ui.graphics.Color
-import androidx.compose.ui.graphics.asImageBitmap
-import androidx.core.content.ContextCompat
-import androidx.core.graphics.createBitmap
-import com.google.ai.edge.gallery.R
-import com.google.ai.edge.gallery.common.Classification
-import com.google.ai.edge.gallery.ui.common.chat.ChatMessageClassification
-import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage
-import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
-import com.google.ai.edge.gallery.ui.common.chat.ChatSide
-import com.google.ai.edge.gallery.ui.common.chat.ChatViewModel
-
-class PreviewChatModel(context: Context) : ChatViewModel(task = TASK_TEST1) {
- init {
- val model = task.models[1]
- addMessage(
- model = model,
- message =
- ChatMessageText(
- content =
- "Thanks everyone for your enthusiasm on the team lunch, but people who can sign on the cheque is OOO next week \uD83D\uDE02,",
- side = ChatSide.USER,
- ),
- )
- addMessage(
- model = model,
- message =
- ChatMessageText(content = "Today is Wednesday!", side = ChatSide.AGENT, latencyMs = 1232f),
- )
- addMessage(
- model = model,
- message =
- ChatMessageClassification(
- classifications =
- listOf(
- Classification(label = "label1", score = 0.3f, color = Color.Red),
- Classification(label = "label2", score = 0.7f, color = Color.Blue),
- ),
- latencyMs = 12345f,
- ),
- )
- val bitmap =
- getBitmapFromVectorDrawable(
- context = context,
- drawableId = R.drawable.ic_launcher_background,
- )!!
- addMessage(
- model = model,
- message =
- ChatMessageImage(
- bitmap = bitmap,
- imageBitMap = bitmap.asImageBitmap(),
- side = ChatSide.USER,
- ),
- )
- }
-
- private fun getBitmapFromVectorDrawable(context: Context, drawableId: Int): Bitmap? {
- val drawable: Drawable =
- ContextCompat.getDrawable(context, drawableId) ?: return null // Drawable not found
-
- val bitmap = createBitmap(drawable.intrinsicWidth, drawable.intrinsicHeight)
- val canvas = Canvas(bitmap)
- drawable.setBounds(0, 0, canvas.width, canvas.height)
- drawable.draw(canvas)
-
- return bitmap
- }
-}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewDataStoreRepository.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewDataStoreRepository.kt
deleted file mode 100644
index b5e46c9..0000000
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewDataStoreRepository.kt
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Copyright 2025 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.ai.edge.gallery.ui.preview
-
-// TODO(migration)
-//
-// import com.google.ai.edge.gallery.data.AccessTokenData
-// import com.google.ai.edge.gallery.data.DataStoreRepository
-// import com.google.ai.edge.gallery.data.ImportedModelInfo
-
-// class PreviewDataStoreRepository : DataStoreRepository
-class PreviewDataStoreRepository {
- // override fun saveTextInputHistory(history: List) {
- // }
-
- // override fun readTextInputHistory(): List {
- // return listOf()
- // }
-
- // override fun saveThemeOverride(theme: String) {
- // }
-
- // override fun readThemeOverride(): String {
- // return ""
- // }
-
- // override fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long) {
- // }
-
- // override fun readAccessTokenData(): AccessTokenData? {
- // return null
- // }
-
- // override fun clearAccessTokenData() {
- // }
-
- // override fun saveImportedModels(importedModels: List) {
- // }
-
- // override fun readImportedModels(): List {
- // return listOf()
- // }
-}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewDownloadRepository.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewDownloadRepository.kt
deleted file mode 100644
index 88e7dfa..0000000
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewDownloadRepository.kt
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Copyright 2025 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.ai.edge.gallery.ui.preview
-
-import com.google.ai.edge.gallery.data.AGWorkInfo
-import com.google.ai.edge.gallery.data.DownloadRepository
-import com.google.ai.edge.gallery.data.Model
-import com.google.ai.edge.gallery.data.ModelDownloadStatus
-import java.util.UUID
-
-class PreviewDownloadRepository : DownloadRepository {
- override fun downloadModel(
- model: Model,
- onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit,
- ) {}
-
- override fun cancelDownloadModel(model: Model) {}
-
- override fun cancelAll(models: List, onComplete: () -> Unit) {}
-
- override fun observerWorkerProgress(
- workerId: UUID,
- model: Model,
- onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit,
- ) {}
-
- override fun getEnqueuedOrRunningWorkInfos(): List {
- return listOf()
- }
-}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewModelManagerViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewModelManagerViewModel.kt
deleted file mode 100644
index 612cbe6..0000000
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewModelManagerViewModel.kt
+++ /dev/null
@@ -1,63 +0,0 @@
-/*
- * Copyright 2025 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.ai.edge.gallery.ui.preview
-
-class PreviewModelManagerViewModel {}
-
-// class PreviewModelManagerViewModel(context: Context) :
-// ModelManagerViewModel(
-// downloadRepository = PreviewDownloadRepository(),
-// // dataStoreRepository = PreviewDataStoreRepository(),
-// context = context,
-// ) {
-
-// init {
-// for ((index, task) in ALL_PREVIEW_TASKS.withIndex()) {
-// task.index = index
-// for (model in task.models) {
-// model.preProcess()
-// }
-// }
-
-// val modelDownloadStatus =
-// mapOf(
-// MODEL_TEST1.name to
-// ModelDownloadStatus(
-// status = ModelDownloadStatusType.IN_PROGRESS,
-// receivedBytes = 1234,
-// totalBytes = 3456,
-// bytesPerSecond = 2333,
-// remainingMs = 324,
-// ),
-// MODEL_TEST2.name to ModelDownloadStatus(status = ModelDownloadStatusType.SUCCEEDED),
-// MODEL_TEST3.name to
-// ModelDownloadStatus(
-// status = ModelDownloadStatusType.FAILED,
-// errorMessage = "Http code 404",
-// ),
-// MODEL_TEST4.name to ModelDownloadStatus(status = ModelDownloadStatusType.NOT_DOWNLOADED),
-// )
-// val newUiState =
-// ModelManagerUiState(
-// tasks = ALL_PREVIEW_TASKS,
-// modelDownloadStatus = modelDownloadStatus,
-// modelInitializationStatus = mapOf(),
-// selectedModel = MODEL_TEST2,
-// )
-// _uiState.update { newUiState }
-// }
-// }
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewTasks.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewTasks.kt
deleted file mode 100644
index 44e2e3d..0000000
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewTasks.kt
+++ /dev/null
@@ -1,103 +0,0 @@
-/*
- * Copyright 2025 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.ai.edge.gallery.ui.preview
-
-import androidx.compose.material.icons.Icons
-import androidx.compose.material.icons.rounded.AccountBox
-import androidx.compose.material.icons.rounded.AutoAwesome
-import com.google.ai.edge.gallery.data.BooleanSwitchConfig
-import com.google.ai.edge.gallery.data.Config
-import com.google.ai.edge.gallery.data.ConfigKey
-import com.google.ai.edge.gallery.data.LabelConfig
-import com.google.ai.edge.gallery.data.Model
-import com.google.ai.edge.gallery.data.NumberSliderConfig
-import com.google.ai.edge.gallery.data.SegmentedButtonConfig
-import com.google.ai.edge.gallery.data.Task
-import com.google.ai.edge.gallery.data.TaskType
-import com.google.ai.edge.gallery.data.ValueType
-
-val TEST_CONFIGS1: List =
- listOf(
- LabelConfig(key = ConfigKey.NAME, defaultValue = "Test name"),
- NumberSliderConfig(
- key = ConfigKey.MAX_RESULT_COUNT,
- sliderMin = 1f,
- sliderMax = 5f,
- defaultValue = 3f,
- valueType = ValueType.INT,
- ),
- BooleanSwitchConfig(key = ConfigKey.USE_GPU, defaultValue = false),
- SegmentedButtonConfig(
- key = ConfigKey.THEME,
- defaultValue = "Auto",
- options = listOf("Auto", "Light", "Dark"),
- ),
- )
-
-val MODEL_TEST1: Model =
- Model(
- name = "deterministic3",
- downloadFileName = "deterministric3.json",
- url = "https://storage.googleapis.com/tfweb/model-graph-vis-v2-test-models/deterministic3.json",
- sizeInBytes = 40146048L,
- configs = TEST_CONFIGS1,
- )
-
-val MODEL_TEST2: Model =
- Model(
- name = "isnet",
- downloadFileName = "isnet.tflite",
- url =
- "https://storage.googleapis.com/tfweb/model-graph-vis-v2-test-models/isnet-general-use-int8.tflite",
- sizeInBytes = 44366296L,
- configs = TEST_CONFIGS1,
- )
-
-val MODEL_TEST3: Model =
- Model(
- name = "yolo",
- downloadFileName = "yolo.json",
- url = "https://storage.googleapis.com/tfweb/model-graph-vis-v2-test-models/yolo.json",
- sizeInBytes = 40641364L,
- )
-
-val MODEL_TEST4: Model =
- Model(
- name = "mobilenet v3",
- downloadFileName = "mobilenet_v3_large.pt2",
- url =
- "https://storage.googleapis.com/tfweb/model-graph-vis-v2-test-models/mobilenet_v3_large.pt2",
- sizeInBytes = 277135998L,
- )
-
-val TASK_TEST1 =
- Task(
- type = TaskType.TEST_TASK_1,
- icon = Icons.Rounded.AutoAwesome,
- models = mutableListOf(MODEL_TEST1, MODEL_TEST2),
- description = "This is a test task (1)",
- )
-
-val TASK_TEST2 =
- Task(
- type = TaskType.TEST_TASK_2,
- icon = Icons.Rounded.AccountBox,
- models = mutableListOf(MODEL_TEST3, MODEL_TEST4),
- description = "This is a test task (2)",
- )
-
-val ALL_PREVIEW_TASKS: List = listOf(TASK_TEST1, TASK_TEST2)
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/theme/Theme.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/theme/Theme.kt
index 72ad1d5..a1a7ad1 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/theme/Theme.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/theme/Theme.kt
@@ -120,6 +120,8 @@ data class CustomColors(
val agentBubbleBgColor: Color = Color.Transparent,
val linkColor: Color = Color.Transparent,
val successColor: Color = Color.Transparent,
+ val recordButtonBgColor: Color = Color.Transparent,
+ val waveFormBgColor: Color = Color.Transparent,
)
val LocalCustomColors = staticCompositionLocalOf { CustomColors() }
@@ -145,6 +147,8 @@ val lightCustomColors =
userBubbleBgColor = Color(0xFF32628D),
linkColor = Color(0xFF32628D),
successColor = Color(0xff3d860b),
+ recordButtonBgColor = Color(0xFFEE675C),
+ waveFormBgColor = Color(0xFFaaaaaa),
)
val darkCustomColors =
@@ -168,6 +172,8 @@ val darkCustomColors =
userBubbleBgColor = Color(0xFF1f3760),
linkColor = Color(0xFF9DCAFC),
successColor = Color(0xFFA1CE83),
+ recordButtonBgColor = Color(0xFFEE675C),
+ waveFormBgColor = Color(0xFFaaaaaa),
)
val MaterialTheme.customColors: CustomColors
diff --git a/Android/src/app/src/main/proto/settings.proto b/Android/src/app/src/main/proto/settings.proto
index 5540eaf..3b5f6f3 100644
--- a/Android/src/app/src/main/proto/settings.proto
+++ b/Android/src/app/src/main/proto/settings.proto
@@ -55,6 +55,7 @@ message LlmConfig {
float default_topp = 4;
float default_temperature = 5;
bool support_image = 6;
+ bool support_audio = 7;
}
message Settings {
diff --git a/Android/src/app/src/test/java/com/google/ai/edge/gallery/data/ModelAllowlistTest.kt b/Android/src/app/src/test/java/com/google/ai/edge/gallery/data/ModelAllowlistTest.kt
new file mode 100644
index 0000000..bfc04bb
--- /dev/null
+++ b/Android/src/app/src/test/java/com/google/ai/edge/gallery/data/ModelAllowlistTest.kt
@@ -0,0 +1,107 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.ai.edge.gallery.data
+
+import org.junit.Assert.assertEquals
+import org.junit.Assert.assertFalse
+import org.junit.Assert.assertTrue
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+
+@RunWith(JUnit4::class)
+class ModelAllowlistTest {
+ @Test
+ fun toModel_success() {
+ val modelName = "test_model"
+ val modelId = "test_model_id"
+ val modelFile = "test_model_file"
+ val description = "test description"
+ val sizeInBytes = 100L
+ val version = "20250623"
+ val topK = 10
+ val topP = 0.5f
+ val temperature = 0.1f
+ val maxTokens = 1000
+ val accelerators = "gpu,cpu"
+ val taskTypes = listOf("llm_chat", "ask_image")
+ val estimatedPeakMemoryInBytes = 300L
+
+ val allowedModel =
+ AllowedModel(
+ name = modelName,
+ modelId = modelId,
+ modelFile = modelFile,
+ description = description,
+ sizeInBytes = sizeInBytes,
+ version = version,
+ defaultConfig =
+ DefaultConfig(
+ topK = topK,
+ topP = topP,
+ temperature = temperature,
+ maxTokens = maxTokens,
+ accelerators = accelerators,
+ ),
+ taskTypes = taskTypes,
+ llmSupportImage = true,
+ llmSupportAudio = true,
+ estimatedPeakMemoryInBytes = estimatedPeakMemoryInBytes,
+ )
+ val model = allowedModel.toModel()
+
+ // Check that basic fields are set correctly.
+ assertEquals(model.name, modelName)
+ assertEquals(model.version, version)
+ assertEquals(model.info, description)
+ assertEquals(
+ model.url,
+ "https://huggingface.co/test_model_id/resolve/main/test_model_file?download=true",
+ )
+ assertEquals(model.sizeInBytes, sizeInBytes)
+ assertEquals(model.estimatedPeakMemoryInBytes, estimatedPeakMemoryInBytes)
+ assertEquals(model.downloadFileName, modelFile)
+ assertFalse(model.showBenchmarkButton)
+ assertFalse(model.showRunAgainButton)
+ assertTrue(model.llmSupportImage)
+ assertTrue(model.llmSupportAudio)
+
+ // Check that configs are set correctly.
+ assertEquals(model.configs.size, 5)
+
+ // A label for showing max tokens (non-changeable).
+ assertTrue(model.configs[0] is LabelConfig)
+ assertEquals((model.configs[0] as LabelConfig).defaultValue, "$maxTokens")
+
+ // A slider for topK.
+ assertTrue(model.configs[1] is NumberSliderConfig)
+ assertEquals((model.configs[1] as NumberSliderConfig).defaultValue, topK.toFloat())
+
+ // A slider for topP.
+ assertTrue(model.configs[2] is NumberSliderConfig)
+ assertEquals((model.configs[2] as NumberSliderConfig).defaultValue, topP)
+
+ // A slider for temperature.
+ assertTrue(model.configs[3] is NumberSliderConfig)
+ assertEquals((model.configs[3] as NumberSliderConfig).defaultValue, temperature)
+
+ // A segmented button for accelerators.
+ assertTrue(model.configs[4] is SegmentedButtonConfig)
+ assertEquals((model.configs[4] as SegmentedButtonConfig).defaultValue, "GPU")
+ assertEquals((model.configs[4] as SegmentedButtonConfig).options, listOf("GPU", "CPU"))
+ }
+}
diff --git a/Android/src/build.gradle.kts b/Android/src/build.gradle.kts
index da93723..fe99361 100644
--- a/Android/src/build.gradle.kts
+++ b/Android/src/build.gradle.kts
@@ -19,4 +19,5 @@ plugins {
alias(libs.plugins.android.application) apply false
alias(libs.plugins.kotlin.android) apply false
alias(libs.plugins.kotlin.compose) apply false
+ alias(libs.plugins.hilt.application) apply false
}
diff --git a/Android/src/gradle/libs.versions.toml b/Android/src/gradle/libs.versions.toml
index 22bdb23..21d4559 100644
--- a/Android/src/gradle/libs.versions.toml
+++ b/Android/src/gradle/libs.versions.toml
@@ -30,6 +30,8 @@ playServicesTfliteGpu= "16.4.0"
cameraX = "1.4.2"
netOpenidAppauth = "0.11.1"
splashscreen = "1.2.0-beta02"
+hilt = "2.56.2"
+hiltNavigation = "1.2.0"
[libraries]
androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" }
@@ -69,6 +71,10 @@ camerax-view = { group = "androidx.camera", name = "camera-view", version.ref =
openid-appauth = { group = "net.openid", name = "appauth", version.ref = "netOpenidAppauth" }
androidx-splashscreen = { group = "androidx.core", name = "core-splashscreen", version.ref = "splashscreen" }
protobuf-javalite = { group = "com.google.protobuf", name = "protobuf-javalite", version.ref = "protobufJavaLite" }
+hilt-android = { module = "com.google.dagger:hilt-android", version.ref = "hilt" }
+hilt-navigation-compose = { module = "androidx.hilt:hilt-navigation-compose", version.ref = "hiltNavigation" }
+hilt-android-testing = { module = "com.google.dagger:hilt-android-testing", version.ref = "hilt" }
+hilt-android-compiler = { module = "com.google.dagger:hilt-android-compiler", version.ref = "hilt" }
[plugins]
android-application = { id = "com.android.application", version.ref = "agp" }
@@ -76,3 +82,4 @@ kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" }
kotlin-compose = { id = "org.jetbrains.kotlin.plugin.compose", version.ref = "kotlin" }
kotlin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "serializationPlugin" }
protobuf = {id = "com.google.protobuf", version.ref = "protobuf"}
+hilt-application = { id = "com.google.dagger.hilt.android", version.ref = "hilt" }