diff --git a/Android/src/app/src/main/AndroidManifest.xml b/Android/src/app/src/main/AndroidManifest.xml
index 2f04e88..6bacbd3 100644
--- a/Android/src/app/src/main/AndroidManifest.xml
+++ b/Android/src/app/src/main/AndroidManifest.xml
@@ -29,6 +29,7 @@
+  <uses-permission android:name="android.permission.RECORD_AUDIO" />
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Types.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Types.kt
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Types.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Types.kt
class JsonObjAndTextContent<T>(val jsonObj: T, val textContent: String)
+
+class AudioClip(val audioData: ByteArray, val sampleRate: Int)
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt
index 6d37f9d..1ada15d 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt
@@ -17,12 +17,17 @@
package com.google.ai.edge.gallery.common
import android.content.Context
+import android.net.Uri
import android.util.Log
+import com.google.ai.edge.gallery.data.SAMPLE_RATE
import com.google.gson.Gson
import com.google.gson.reflect.TypeToken
import java.io.File
import java.net.HttpURLConnection
import java.net.URL
+import java.nio.ByteBuffer
+import java.nio.ByteOrder
+import kotlin.math.floor
data class LaunchInfo(val ts: Long)
@@ -112,3 +117,135 @@ inline fun <reified T> getJsonResponse(url: String): JsonObjAndTextContent<T>? {
return null
}
+
+fun convertWavToMonoWithMaxSeconds(
+ context: Context,
+ stereoUri: Uri,
+ maxSeconds: Int = 30,
+): AudioClip? {
+ Log.d(TAG, "Start to convert wav file to mono channel")
+
+ try {
+ val inputStream = context.contentResolver.openInputStream(stereoUri) ?: return null
+ val originalBytes = inputStream.readBytes()
+ inputStream.close()
+
+ // Read WAV header
+ if (originalBytes.size < 44) {
+ // Not a valid WAV file
+ Log.e(TAG, "Not a valid wav file")
+ return null
+ }
+
+ val headerBuffer = ByteBuffer.wrap(originalBytes, 0, 44).order(ByteOrder.LITTLE_ENDIAN)
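+    // These fixed offsets assume the canonical 44-byte WAV layout ("RIFF" descriptor, then
+    // "fmt ", then "data"): channel count at byte 22, sample rate at 24, bits per sample at 34.
+    // Files with extra chunks (e.g. LIST/INFO) would need real chunk scanning.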
+ val channels = headerBuffer.getShort(22)
+ var sampleRate = headerBuffer.getInt(24)
+ val bitDepth = headerBuffer.getShort(34)
+ Log.d(TAG, "File metadata: channels: $channels, sampleRate: $sampleRate, bitDepth: $bitDepth")
+
+ // Normalize audio to 16-bit.
+ val audioDataBytes = originalBytes.copyOfRange(fromIndex = 44, toIndex = originalBytes.size)
+    val sixteenBitBytes: ByteArray =
+ if (bitDepth.toInt() == 8) {
+ Log.d(TAG, "Converting 8-bit audio to 16-bit.")
+ convert8BitTo16Bit(audioDataBytes)
+ } else {
+ // Assume 16-bit or other format that can be handled directly
+ audioDataBytes
+ }
+
+ // Convert byte array to short array for processing
+ val shortBuffer =
+ ByteBuffer.wrap(sixteenBitBytes).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer()
+ var pcmSamples = ShortArray(shortBuffer.remaining())
+ shortBuffer.get(pcmSamples)
+
+    // Resample if the sample rate is less than 16000 Hz.
+ if (sampleRate < SAMPLE_RATE) {
+ Log.d(TAG, "Resampling from $sampleRate Hz to $SAMPLE_RATE Hz.")
+ pcmSamples = resample(pcmSamples, sampleRate, SAMPLE_RATE, channels.toInt())
+ sampleRate = SAMPLE_RATE
+ Log.d(TAG, "Resampling complete. New sample count: ${pcmSamples.size}")
+ }
+
+ // Convert stereo to mono if necessary
+ var monoSamples =
+ if (channels.toInt() == 2) {
+ Log.d(TAG, "Converting stereo to mono.")
+ val mono = ShortArray(pcmSamples.size / 2)
+ for (i in mono.indices) {
+ val left = pcmSamples[i * 2]
+ val right = pcmSamples[i * 2 + 1]
+ mono[i] = ((left + right) / 2).toShort()
+ }
+ mono
+ } else {
+ Log.d(TAG, "Audio is already mono. No channel conversion needed.")
+ pcmSamples
+ }
+
+    // Trim the audio to maxSeconds.
+ val maxSamples = maxSeconds * sampleRate
+ if (monoSamples.size > maxSamples) {
+ Log.d(TAG, "Trimming clip from ${monoSamples.size} samples to $maxSamples samples.")
+ monoSamples = monoSamples.copyOfRange(0, maxSamples)
+ }
+
+ val monoByteBuffer = ByteBuffer.allocate(monoSamples.size * 2).order(ByteOrder.LITTLE_ENDIAN)
+ monoByteBuffer.asShortBuffer().put(monoSamples)
+ return AudioClip(audioData = monoByteBuffer.array(), sampleRate = sampleRate)
+ } catch (e: Exception) {
+ Log.e(TAG, "Failed to convert wav to mono", e)
+ return null
+ }
+}
+
+/** Converts 8-bit unsigned PCM audio data to 16-bit signed PCM. */
+private fun convert8BitTo16Bit(eightBitData: ByteArray): ByteArray {
+ // The new 16-bit data will be twice the size
+ val sixteenBitData = ByteArray(eightBitData.size * 2)
+ val buffer = ByteBuffer.wrap(sixteenBitData).order(ByteOrder.LITTLE_ENDIAN)
+
+ for (byte in eightBitData) {
+ // Convert the unsigned 8-bit byte (0-255) to a signed 16-bit short (-32768 to 32767)
+ // 1. Get the unsigned value by masking with 0xFF
+ // 2. Subtract 128 to center the waveform around 0 (range becomes -128 to 127)
+ // 3. Scale by 256 to expand to the 16-bit range
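+    // Example: input byte 0xFF (255) becomes (255 - 128) * 256 = 32512, close to Short.MAX_VALUE.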
+ val unsignedByte = byte.toInt() and 0xFF
+ val sixteenBitSample = ((unsignedByte - 128) * 256).toShort()
+ buffer.putShort(sixteenBitSample)
+ }
+ return sixteenBitData
+}
+
+/** Resamples PCM audio data from an original sample rate to a target sample rate. */
+private fun resample(
+ inputSamples: ShortArray,
+ originalSampleRate: Int,
+ targetSampleRate: Int,
+ channels: Int,
+): ShortArray {
+ if (originalSampleRate == targetSampleRate) {
+ return inputSamples
+ }
+
+ val ratio = targetSampleRate.toDouble() / originalSampleRate
+ val outputLength = (inputSamples.size * ratio).toInt()
+ val resampledData = ShortArray(outputLength)
+
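+  // Linear interpolation: each output sample sits at fractional input position i / ratio and is
+  // a distance-weighted average of the two nearest input samples. No low-pass filter is applied,
+  // which is acceptable here because this path only upsamples (original rate < target rate).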
+  if (channels == 1) { // Mono
+    for (i in resampledData.indices) {
+      val position = i / ratio
+      val index1 = floor(position).toInt()
+      val index2 = index1 + 1
+      val fraction = position - index1
+
+      val sample1 = if (index1 < inputSamples.size) inputSamples[index1].toDouble() else 0.0
+      val sample2 = if (index2 < inputSamples.size) inputSamples[index2].toDouble() else 0.0
+
+      resampledData[i] = (sample1 * (1 - fraction) + sample2 * fraction).toInt().toShort()
+    }
+  } else if (channels == 2) { // Stereo: interpolate each channel across interleaved frames.
+    // Without this branch, stereo input at a low sample rate would come back as silence,
+    // since only the mono path filled the output array.
+    val inputFrames = inputSamples.size / 2
+    val outputFrames = outputLength / 2
+    for (i in 0 until outputFrames) {
+      val position = i / ratio
+      val frame1 = floor(position).toInt()
+      val frame2 = frame1 + 1
+      val fraction = position - frame1
+
+      for (ch in 0 until 2) {
+        val s1 = if (frame1 < inputFrames) inputSamples[frame1 * 2 + ch].toDouble() else 0.0
+        val s2 = if (frame2 < inputFrames) inputSamples[frame2 * 2 + ch].toDouble() else 0.0
+        resampledData[i * 2 + ch] = (s1 * (1 - fraction) + s2 * fraction).toInt().toShort()
+      }
+    }
+  }
+
+ return resampledData
+}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt
index d961ecc..2b9fdf7 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt
@@ -50,6 +50,7 @@ enum class ConfigKey(val label: String) {
DEFAULT_TOPP("Default TopP"),
DEFAULT_TEMPERATURE("Default temperature"),
SUPPORT_IMAGE("Support image"),
+ SUPPORT_AUDIO("Support audio"),
MAX_RESULT_COUNT("Max result count"),
USE_GPU("Use GPU"),
ACCELERATOR("Choose accelerator"),
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Consts.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Consts.kt
index a2209af..85fd71e 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Consts.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Consts.kt
@@ -44,3 +44,12 @@ val DEFAULT_ACCELERATORS = listOf(Accelerator.GPU)
// Max number of images allowed in a "ask image" session.
const val MAX_IMAGE_COUNT = 10
+
+// Max number of audio clips in an "ask audio" session.
+const val MAX_AUDIO_CLIP_COUNT = 10
+
+// Max audio clip duration in seconds.
+const val MAX_AUDIO_CLIP_DURATION_SEC = 30
+
+// Audio-recording related consts.
+const val SAMPLE_RATE = 16000
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt
index 580bf2b..02e1368 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt
@@ -87,6 +87,9 @@ data class Model(
/** Whether the LLM model supports image input. */
val llmSupportImage: Boolean = false,
+ /** Whether the LLM model supports audio input. */
+ val llmSupportAudio: Boolean = false,
+
/** Whether the model is imported or not. */
val imported: Boolean = false,
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt
index f336638..5cf11ca 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt
@@ -38,6 +38,7 @@ data class AllowedModel(
val taskTypes: List<String>,
val disabled: Boolean? = null,
val llmSupportImage: Boolean? = null,
+ val llmSupportAudio: Boolean? = null,
val estimatedPeakMemoryInBytes: Long? = null,
) {
fun toModel(): Model {
@@ -96,6 +97,7 @@ data class AllowedModel(
showRunAgainButton = showRunAgainButton,
learnMoreUrl = "https://huggingface.co/${modelId}",
llmSupportImage = llmSupportImage == true,
+ llmSupportAudio = llmSupportAudio == true,
)
}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Tasks.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Tasks.kt
index e95feab..c52d2c3 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Tasks.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Tasks.kt
@@ -17,19 +17,22 @@
package com.google.ai.edge.gallery.data
import androidx.annotation.StringRes
+import androidx.compose.material.icons.Icons
+import androidx.compose.material.icons.outlined.Forum
+import androidx.compose.material.icons.outlined.Mic
+import androidx.compose.material.icons.outlined.Mms
+import androidx.compose.material.icons.outlined.Widgets
import androidx.compose.runtime.MutableState
import androidx.compose.runtime.mutableLongStateOf
import androidx.compose.ui.graphics.vector.ImageVector
import com.google.ai.edge.gallery.R
-import com.google.ai.edge.gallery.ui.icon.Forum
-import com.google.ai.edge.gallery.ui.icon.Mms
-import com.google.ai.edge.gallery.ui.icon.Widgets
/** Type of task. */
enum class TaskType(val label: String, val id: String) {
LLM_CHAT(label = "AI Chat", id = "llm_chat"),
LLM_PROMPT_LAB(label = "Prompt Lab", id = "llm_prompt_lab"),
LLM_ASK_IMAGE(label = "Ask Image", id = "llm_ask_image"),
+ LLM_ASK_AUDIO(label = "Audio Scribe", id = "llm_ask_audio"),
TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
TEST_TASK_2(label = "Test task 2", id = "test_task_2"),
}
@@ -71,7 +74,7 @@ data class Task(
val TASK_LLM_CHAT =
Task(
type = TaskType.LLM_CHAT,
- icon = Forum,
+ icon = Icons.Outlined.Forum,
models = mutableListOf(),
description = "Chat with on-device large language models",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
@@ -83,7 +86,7 @@ val TASK_LLM_CHAT =
val TASK_LLM_PROMPT_LAB =
Task(
type = TaskType.LLM_PROMPT_LAB,
- icon = Widgets,
+ icon = Icons.Outlined.Widgets,
models = mutableListOf(),
description = "Single turn use cases with on-device large language model",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
@@ -95,7 +98,7 @@ val TASK_LLM_PROMPT_LAB =
val TASK_LLM_ASK_IMAGE =
Task(
type = TaskType.LLM_ASK_IMAGE,
- icon = Mms,
+ icon = Icons.Outlined.Mms,
models = mutableListOf(),
description = "Ask questions about images with on-device large language models",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
@@ -104,8 +107,23 @@ val TASK_LLM_ASK_IMAGE =
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
)
+val TASK_LLM_ASK_AUDIO =
+ Task(
+ type = TaskType.LLM_ASK_AUDIO,
+ icon = Icons.Outlined.Mic,
+ models = mutableListOf(),
+ // TODO(do not submit)
+ description =
+ "Instantly transcribe and/or translate audio clips using on-device large language models",
+ docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
+ sourceCodeUrl =
+ "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt",
+ textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
+ )
+
/** All tasks. */
-val TASKS: List<Task> = listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)
+val TASKS: List<Task> =
+  listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_ASK_AUDIO, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)
fun getModelByName(name: String): Model? {
for (task in TASKS) {
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt
index fdded53..6ceb148 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt
@@ -21,6 +21,7 @@ import androidx.lifecycle.viewmodel.CreationExtras
import androidx.lifecycle.viewmodel.initializer
import androidx.lifecycle.viewmodel.viewModelFactory
import com.google.ai.edge.gallery.GalleryApplication
+import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
@@ -49,6 +50,9 @@ object ViewModelProvider {
// Initializer for LlmAskImageViewModel.
initializer { LlmAskImageViewModel() }
+
+ // Initializer for LlmAskAudioViewModel.
+ initializer { LlmAskAudioViewModel() }
}
}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioPlaybackPanel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioPlaybackPanel.kt
new file mode 100644
index 0000000..9faebf0
--- /dev/null
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioPlaybackPanel.kt
@@ -0,0 +1,344 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.ai.edge.gallery.ui.common.chat
+
+import android.media.AudioAttributes
+import android.media.AudioFormat
+import android.media.AudioTrack
+import android.util.Log
+import androidx.compose.foundation.Canvas
+import androidx.compose.foundation.layout.Row
+import androidx.compose.foundation.layout.height
+import androidx.compose.foundation.layout.padding
+import androidx.compose.foundation.layout.width
+import androidx.compose.material.icons.Icons
+import androidx.compose.material.icons.rounded.PlayArrow
+import androidx.compose.material.icons.rounded.Stop
+import androidx.compose.material3.Icon
+import androidx.compose.material3.IconButton
+import androidx.compose.material3.MaterialTheme
+import androidx.compose.material3.Text
+import androidx.compose.runtime.Composable
+import androidx.compose.runtime.DisposableEffect
+import androidx.compose.runtime.LaunchedEffect
+import androidx.compose.runtime.MutableState
+import androidx.compose.runtime.getValue
+import androidx.compose.runtime.mutableFloatStateOf
+import androidx.compose.runtime.mutableStateOf
+import androidx.compose.runtime.remember
+import androidx.compose.runtime.rememberCoroutineScope
+import androidx.compose.runtime.setValue
+import androidx.compose.ui.Alignment
+import androidx.compose.ui.Modifier
+import androidx.compose.ui.geometry.CornerRadius
+import androidx.compose.ui.geometry.Offset
+import androidx.compose.ui.geometry.Size
+import androidx.compose.ui.geometry.toRect
+import androidx.compose.ui.graphics.BlendMode
+import androidx.compose.ui.graphics.Color
+import androidx.compose.ui.graphics.drawscope.drawIntoCanvas
+import androidx.compose.ui.unit.dp
+import com.google.ai.edge.gallery.data.MAX_AUDIO_CLIP_DURATION_SEC
+import com.google.ai.edge.gallery.ui.theme.customColors
+import java.nio.ByteBuffer
+import java.nio.ByteOrder
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.isActive
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.withContext
+
+private const val TAG = "AGAudioPlaybackPanel"
+private const val BAR_SPACE = 2
+private const val BAR_WIDTH = 2
+private const val MIN_BAR_COUNT = 16
+private const val MAX_BAR_COUNT = 48
+
+/**
+ * A Composable that displays an audio playback panel, including play/stop controls, a waveform
+ * visualization, and the duration of the audio clip.
+ */
+@Composable
+fun AudioPlaybackPanel(
+ audioData: ByteArray,
+ sampleRate: Int,
+ isRecording: Boolean,
+ modifier: Modifier = Modifier,
+ onDarkBg: Boolean = false,
+) {
+ val coroutineScope = rememberCoroutineScope()
+ var isPlaying by remember { mutableStateOf(false) }
+  val audioTrackState = remember { mutableStateOf<AudioTrack?>(null) }
+ val durationInSeconds =
+ remember(audioData) {
+ // PCM 16-bit
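+      // e.g. at 16 kHz mono PCM-16, one second of audio occupies 32,000 bytes, so a
+      // 96,000-byte clip plays for 3.0 seconds.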
+ val bytesPerSample = 2
+ val bytesPerFrame = bytesPerSample * 1 // mono
+ val totalFrames = audioData.size.toDouble() / bytesPerFrame
+ totalFrames / sampleRate
+ }
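+  // Scale the number of waveform bars with the clip length: a full 30-second clip gets
+  // MAX_BAR_COUNT bars, shorter clips proportionally fewer, never below MIN_BAR_COUNT.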
+ val barCount =
+ remember(durationInSeconds) {
+ val f = durationInSeconds / MAX_AUDIO_CLIP_DURATION_SEC
+ ((MAX_BAR_COUNT - MIN_BAR_COUNT) * f + MIN_BAR_COUNT).toInt()
+ }
+ val amplitudeLevels =
+ remember(audioData) { generateAmplitudeLevels(audioData = audioData, barCount = barCount) }
+ var playbackProgress by remember { mutableFloatStateOf(0f) }
+
+ // Reset when a new recording is started.
+ LaunchedEffect(isRecording) {
+ if (isRecording) {
+ val audioTrack = audioTrackState.value
+ audioTrack?.stop()
+ isPlaying = false
+ playbackProgress = 0f
+ }
+ }
+
+ // Cleanup on Composable Disposal.
+ DisposableEffect(Unit) {
+ onDispose {
+ val audioTrack = audioTrackState.value
+ audioTrack?.stop()
+ audioTrack?.release()
+ }
+ }
+
+ Row(verticalAlignment = Alignment.CenterVertically, modifier = modifier) {
+ // Button to play/stop the clip.
+ IconButton(
+ onClick = {
+ coroutineScope.launch {
+ if (!isPlaying) {
+ isPlaying = true
+ playAudio(
+ audioTrackState = audioTrackState,
+ audioData = audioData,
+ sampleRate = sampleRate,
+ onProgress = { playbackProgress = it },
+ onCompletion = {
+ playbackProgress = 0f
+ isPlaying = false
+ },
+ )
+ } else {
+ stopPlayAudio(audioTrackState = audioTrackState)
+ playbackProgress = 0f
+ isPlaying = false
+ }
+ }
+ }
+ ) {
+ Icon(
+ if (isPlaying) Icons.Rounded.Stop else Icons.Rounded.PlayArrow,
+ contentDescription = "",
+ tint = if (onDarkBg) Color.White else MaterialTheme.colorScheme.primary,
+ )
+ }
+
+ // Visualization
+ AmplitudeBarGraph(
+ amplitudeLevels = amplitudeLevels,
+ progress = playbackProgress,
+ modifier =
+ Modifier.width((barCount * BAR_WIDTH + (barCount - 1) * BAR_SPACE).dp).height(24.dp),
+ onDarkBg = onDarkBg,
+ )
+
+ // Duration
+ Text(
+ "${"%.1f".format(durationInSeconds)}s",
+ style = MaterialTheme.typography.labelLarge,
+ color = if (onDarkBg) Color.White else MaterialTheme.colorScheme.primary,
+ modifier = Modifier.padding(start = 12.dp),
+ )
+ }
+}
+
+@Composable
+private fun AmplitudeBarGraph(
+  amplitudeLevels: List<Float>,
+ progress: Float,
+ modifier: Modifier = Modifier,
+ onDarkBg: Boolean = false,
+) {
+ val barColor = MaterialTheme.customColors.waveFormBgColor
+ val progressColor = if (onDarkBg) Color.White else MaterialTheme.colorScheme.primary
+
+ Canvas(modifier = modifier) {
+ val barCount = amplitudeLevels.size
+ val barWidth = (size.width - BAR_SPACE.dp.toPx() * (barCount - 1)) / barCount
+ val cornerRadius = CornerRadius(x = barWidth, y = barWidth)
+
+ // Use drawIntoCanvas for advanced blend mode operations
+ drawIntoCanvas { canvas ->
+
+ // 1. Save the current state of the canvas onto a temporary, offscreen layer
+ canvas.saveLayer(size.toRect(), androidx.compose.ui.graphics.Paint())
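+      // The offscreen layer matters: BlendMode.SrcIn keys off the alpha already drawn into the
+      // layer, so the progress rect below tints only the pixels covered by the bars.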
+
+ // 2. Draw the bars in grey.
+ amplitudeLevels.forEachIndexed { index, level ->
+ val barHeight = (level * size.height).coerceAtLeast(1.5f)
+ val left = index * (barWidth + BAR_SPACE.dp.toPx())
+ drawRoundRect(
+ color = barColor,
+ topLeft = Offset(x = left, y = size.height / 2 - barHeight / 2),
+ size = Size(barWidth, barHeight),
+ cornerRadius = cornerRadius,
+ )
+ }
+
+      // 3. Draw the progress rectangle using BlendMode.SrcIn so it only draws where the bars
+      // already exist.
+ val progressWidth = size.width * progress
+ drawRect(
+ color = progressColor,
+ topLeft = Offset.Zero,
+ size = Size(progressWidth, size.height),
+ blendMode = BlendMode.SrcIn,
+ )
+
+ // 4. Restore the layer, merging it onto the main canvas
+ canvas.restore()
+ }
+ }
+}
+
+private suspend fun playAudio(
+  audioTrackState: MutableState<AudioTrack?>,
+ audioData: ByteArray,
+ sampleRate: Int,
+ onProgress: (Float) -> Unit,
+ onCompletion: () -> Unit,
+) {
+ Log.d(TAG, "Start playing audio...")
+
+ try {
+ withContext(Dispatchers.IO) {
+ var lastProgressUpdateMs = 0L
+ audioTrackState.value?.release()
+ val audioTrack =
+ AudioTrack.Builder()
+ .setAudioAttributes(
+ AudioAttributes.Builder()
+ .setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
+ .setUsage(AudioAttributes.USAGE_MEDIA)
+ .build()
+ )
+ .setAudioFormat(
+ AudioFormat.Builder()
+ .setEncoding(AudioFormat.ENCODING_PCM_16BIT)
+ .setSampleRate(sampleRate)
+ .setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
+ .build()
+ )
+ .setTransferMode(AudioTrack.MODE_STATIC)
+ .setBufferSizeInBytes(audioData.size)
+ .build()
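+      // MODE_STATIC uploads the whole clip into the track buffer before play() (hence
+      // setBufferSizeInBytes(audioData.size)); suitable for clips capped at 30 seconds, whereas
+      // longer audio would normally use MODE_STREAM.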
+
+ val bytesPerFrame = 2 // For PCM 16-bit Mono
+ val totalFrames = audioData.size / bytesPerFrame
+
+ audioTrackState.value = audioTrack
+ audioTrack.write(audioData, 0, audioData.size)
+ audioTrack.play()
+
+ // Coroutine to monitor progress
+ while (isActive && audioTrack.playState == AudioTrack.PLAYSTATE_PLAYING) {
+ val currentFrames = audioTrack.playbackHeadPosition
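+        // playbackHeadPosition is measured in frames (one 16-bit sample per frame for mono),
+        // matching totalFrames computed above.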
+ if (currentFrames >= totalFrames) {
+ break // Exit loop when playback is done
+ }
+ val progress = currentFrames.toFloat() / totalFrames
+ val curMs = System.currentTimeMillis()
+ if (curMs - lastProgressUpdateMs > 30) {
+ onProgress(progress)
+ lastProgressUpdateMs = curMs
+ }
+ }
+
+ if (isActive) {
+ audioTrackState.value?.stop()
+ }
+ }
+ } catch (e: Exception) {
+    Log.e(TAG, "Failed to play audio clip", e)
+ } finally {
+ onProgress(1f)
+ onCompletion()
+ }
+}
+
+private fun stopPlayAudio(audioTrackState: MutableState<AudioTrack?>) {
+ Log.d(TAG, "Stopping playing audio...")
+
+ val audioTrack = audioTrackState.value
+ audioTrack?.stop()
+ audioTrack?.release()
+ audioTrackState.value = null
+}
+
+/**
+ * Processes a raw PCM 16-bit audio byte array to generate a list of normalized amplitude levels for
+ * visualization.
+ */
+private fun generateAmplitudeLevels(audioData: ByteArray, barCount: Int): List<Float> {
+ if (audioData.isEmpty()) {
+ return List(barCount) { 0f }
+ }
+
+ // 1. Parse bytes into 16-bit short samples (PCM 16-bit)
+ val shortBuffer = ByteBuffer.wrap(audioData).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer()
+ val samples = ShortArray(shortBuffer.remaining())
+ shortBuffer.get(samples)
+
+ if (samples.isEmpty()) {
+ return List(barCount) { 0f }
+ }
+
+ // 2. Determine the size of each chunk
+ val chunkSize = samples.size / barCount
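+  // Integer division: if the clip has fewer samples than bars, chunkSize becomes 0 and the
+  // corresponding levels simply stay at 0.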
+  val amplitudeLevels = mutableListOf<Float>()
+
+ // 3. Get the max value for each chunk
+ for (i in 0 until barCount) {
+ val chunkStart = i * chunkSize
+ val chunkEnd = (chunkStart + chunkSize).coerceAtMost(samples.size)
+
+ var maxAmplitudeInChunk = 0.0
+
+ for (j in chunkStart until chunkEnd) {
+ val sampleAbs = kotlin.math.abs(samples[j].toDouble())
+ if (sampleAbs > maxAmplitudeInChunk) {
+ maxAmplitudeInChunk = sampleAbs
+ }
+ }
+
+ // 4. Normalize the value (0 to 1)
+ // Short.MAX_VALUE is 32767.0, a good reference for max amplitude
+    val normalizedPeak = (maxAmplitudeInChunk / Short.MAX_VALUE).toFloat().coerceIn(0f, 1f)
+    amplitudeLevels.add(normalizedPeak)
+ }
+
+ // Normalize the resulting levels so that the max value becomes 0.9.
+ val maxVal = amplitudeLevels.max()
+ if (maxVal == 0f) {
+ return amplitudeLevels
+ }
+ val scaleFactor = 0.9f / maxVal
+ return amplitudeLevels.map { it * scaleFactor }
+}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioRecorderPanel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioRecorderPanel.kt
new file mode 100644
index 0000000..9b018dc
--- /dev/null
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioRecorderPanel.kt
@@ -0,0 +1,327 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.ai.edge.gallery.ui.common.chat
+
+import android.annotation.SuppressLint
+import android.content.Context
+import android.media.AudioFormat
+import android.media.AudioRecord
+import android.media.MediaRecorder
+import android.util.Log
+import androidx.compose.foundation.background
+import androidx.compose.foundation.layout.Arrangement
+import androidx.compose.foundation.layout.Box
+import androidx.compose.foundation.layout.Column
+import androidx.compose.foundation.layout.Row
+import androidx.compose.foundation.layout.fillMaxWidth
+import androidx.compose.foundation.layout.height
+import androidx.compose.foundation.layout.padding
+import androidx.compose.foundation.layout.size
+import androidx.compose.foundation.shape.CircleShape
+import androidx.compose.material.icons.Icons
+import androidx.compose.material.icons.rounded.ArrowUpward
+import androidx.compose.material.icons.rounded.Mic
+import androidx.compose.material.icons.rounded.Stop
+import androidx.compose.material3.Icon
+import androidx.compose.material3.IconButton
+import androidx.compose.material3.MaterialTheme
+import androidx.compose.material3.Text
+import androidx.compose.runtime.Composable
+import androidx.compose.runtime.DisposableEffect
+import androidx.compose.runtime.MutableLongState
+import androidx.compose.runtime.MutableState
+import androidx.compose.runtime.derivedStateOf
+import androidx.compose.runtime.getValue
+import androidx.compose.runtime.mutableIntStateOf
+import androidx.compose.runtime.mutableLongStateOf
+import androidx.compose.runtime.mutableStateOf
+import androidx.compose.runtime.remember
+import androidx.compose.runtime.rememberCoroutineScope
+import androidx.compose.runtime.setValue
+import androidx.compose.ui.Alignment
+import androidx.compose.ui.Modifier
+import androidx.compose.ui.draw.clip
+import androidx.compose.ui.graphics.Color
+import androidx.compose.ui.graphics.graphicsLayer
+import androidx.compose.ui.platform.LocalContext
+import androidx.compose.ui.res.painterResource
+import androidx.compose.ui.unit.dp
+import com.google.ai.edge.gallery.R
+import com.google.ai.edge.gallery.data.MAX_AUDIO_CLIP_DURATION_SEC
+import com.google.ai.edge.gallery.data.SAMPLE_RATE
+import com.google.ai.edge.gallery.ui.theme.customColors
+import java.io.ByteArrayOutputStream
+import java.nio.ByteBuffer
+import java.nio.ByteOrder
+import kotlin.math.abs
+import kotlin.math.pow
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.coroutineScope
+import kotlinx.coroutines.launch
+
+private const val TAG = "AGAudioRecorderPanel"
+
+private const val CHANNEL_CONFIG = AudioFormat.CHANNEL_IN_MONO
+private const val AUDIO_FORMAT = AudioFormat.ENCODING_PCM_16BIT
+
+/**
+ * A Composable that provides an audio recording panel. It allows users to record audio clips,
+ * displays the recording duration and a live amplitude visualization, and provides options to play
+ * back the recorded clip or send it.
+ */
+@Composable
+fun AudioRecorderPanel(onSendAudioClip: (ByteArray) -> Unit) {
+ val context = LocalContext.current
+ val coroutineScope = rememberCoroutineScope()
+
+ var isRecording by remember { mutableStateOf(false) }
+ val elapsedMs = remember { mutableLongStateOf(0L) }
+  val audioRecordState = remember { mutableStateOf<AudioRecord?>(null) }
+  val audioStream = remember { ByteArrayOutputStream() }
+  val recordedBytes = remember { mutableStateOf<ByteArray?>(null) }
+ var currentAmplitude by remember { mutableIntStateOf(0) }
+
+ val elapsedSeconds by remember {
+ derivedStateOf { "%.1f".format(elapsedMs.value.toFloat() / 1000f) }
+ }
+
+ // Cleanup on Composable Disposal.
+ DisposableEffect(Unit) { onDispose { audioRecordState.value?.release() } }
+
+ Column(modifier = Modifier.padding(bottom = 12.dp)) {
+ // Title bar.
+ Row(
+ verticalAlignment = Alignment.CenterVertically,
+ horizontalArrangement = Arrangement.SpaceBetween,
+ ) {
+ // Logo and state.
+ Row(
+ modifier = Modifier.padding(start = 16.dp),
+ horizontalArrangement = Arrangement.spacedBy(12.dp),
+ ) {
+ Icon(
+ painterResource(R.drawable.logo),
+ modifier = Modifier.size(20.dp),
+ contentDescription = "",
+ tint = Color.Unspecified,
+ )
+ Text(
+ "Record audio clip (up to $MAX_AUDIO_CLIP_DURATION_SEC seconds)",
+ style = MaterialTheme.typography.labelLarge,
+ )
+ }
+ }
+
+ // Recorded clip.
+ Row(
+ modifier = Modifier.fillMaxWidth().padding(vertical = 12.dp).height(40.dp),
+ verticalAlignment = Alignment.CenterVertically,
+ horizontalArrangement = Arrangement.Center,
+ ) {
+ val curRecordedBytes = recordedBytes.value
+ if (curRecordedBytes == null) {
+ // Info message when there is no recorded clip and the recording has not started yet.
+ if (!isRecording) {
+ Text(
+ "Tap the record button to start",
+ style = MaterialTheme.typography.labelLarge,
+ color = MaterialTheme.colorScheme.onSurfaceVariant,
+ )
+ }
+ // Visualization for clip being recorded.
+ else {
+ Row(
+ horizontalArrangement = Arrangement.spacedBy(12.dp),
+ verticalAlignment = Alignment.CenterVertically,
+ ) {
+ Box(
+ modifier =
+ Modifier.size(8.dp)
+ .background(MaterialTheme.customColors.recordButtonBgColor, CircleShape)
+ )
+ Text("$elapsedSeconds s")
+ }
+ }
+ }
+ // Controls for recorded clip.
+ else {
+ Row {
+ // Clip player.
+ AudioPlaybackPanel(
+ audioData = curRecordedBytes,
+ sampleRate = SAMPLE_RATE,
+ isRecording = isRecording,
+ )
+
+ // Button to send the clip
+ IconButton(onClick = { onSendAudioClip(curRecordedBytes) }) {
+ Icon(Icons.Rounded.ArrowUpward, contentDescription = "")
+ }
+ }
+ }
+ }
+
+ // Buttons
+ Box(contentAlignment = Alignment.Center, modifier = Modifier.fillMaxWidth().height(40.dp)) {
+ // Visualization of the current amplitude.
+ if (isRecording) {
+ // Normalize the amplitude (0-32767) to a fraction (0.0-1.0)
+ // We use a power scale (exponent < 1) to make the pulse more visible for lower volumes.
+ val normalizedAmplitude = (currentAmplitude.toFloat() / 32767f).pow(0.35f)
+ // Define the min and max size of the circle
+ val minSize = 38.dp
+ val maxSize = 100.dp
+
+ // Map the normalized amplitude to our size range
+ val scale by
+ remember(normalizedAmplitude) {
+ derivedStateOf { (minSize + (maxSize - minSize) * normalizedAmplitude) / minSize }
+ }
+ Box(
+ modifier =
+ Modifier.size(minSize)
+ .graphicsLayer(scaleX = scale, scaleY = scale, clip = false, alpha = 0.3f)
+ .background(MaterialTheme.customColors.recordButtonBgColor, CircleShape)
+ )
+ }
+
+ // Record/stop button.
+ IconButton(
+ onClick = {
+ coroutineScope.launch {
+ if (!isRecording) {
+ isRecording = true
+ recordedBytes.value = null
+ startRecording(
+ context = context,
+ audioRecordState = audioRecordState,
+ audioStream = audioStream,
+ elapsedMs = elapsedMs,
+ onAmplitudeChanged = { currentAmplitude = it },
+ onMaxDurationReached = {
+ val curRecordedBytes =
+ stopRecording(audioRecordState = audioRecordState, audioStream = audioStream)
+ recordedBytes.value = curRecordedBytes
+ isRecording = false
+ },
+ )
+ } else {
+ val curRecordedBytes =
+ stopRecording(audioRecordState = audioRecordState, audioStream = audioStream)
+ recordedBytes.value = curRecordedBytes
+ isRecording = false
+ }
+ }
+ },
+ modifier =
+ Modifier.clip(CircleShape).background(MaterialTheme.customColors.recordButtonBgColor),
+ ) {
+ Icon(
+ if (isRecording) Icons.Rounded.Stop else Icons.Rounded.Mic,
+ contentDescription = "",
+ tint = Color.White,
+ )
+ }
+ }
+ }
+}
+
+// Permission is checked in parent composable.
+@SuppressLint("MissingPermission")
+private suspend fun startRecording(
+ context: Context,
+  audioRecordState: MutableState<AudioRecord?>,
+ audioStream: ByteArrayOutputStream,
+ elapsedMs: MutableLongState,
+ onAmplitudeChanged: (Int) -> Unit,
+ onMaxDurationReached: () -> Unit,
+) {
+ Log.d(TAG, "Start recording...")
+ val minBufferSize = AudioRecord.getMinBufferSize(SAMPLE_RATE, CHANNEL_CONFIG, AUDIO_FORMAT)
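+  // getMinBufferSize returns the smallest internal buffer AudioRecord accepts for this
+  // rate/format combination; reading in chunks of that size keeps the amplitude meter responsive.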
+
+ audioRecordState.value?.release()
+ val recorder =
+ AudioRecord(
+ MediaRecorder.AudioSource.MIC,
+ SAMPLE_RATE,
+ CHANNEL_CONFIG,
+ AUDIO_FORMAT,
+ minBufferSize,
+ )
+
+ audioRecordState.value = recorder
+ val buffer = ByteArray(minBufferSize)
+
+ // The function will only return when the recording is done (when stopRecording is called).
+ coroutineScope {
+ launch(Dispatchers.IO) {
+ recorder.startRecording()
+
+ val startMs = System.currentTimeMillis()
+ elapsedMs.value = 0L
+ while (audioRecordState.value?.recordingState == AudioRecord.RECORDSTATE_RECORDING) {
+ val bytesRead = recorder.read(buffer, 0, buffer.size)
+ if (bytesRead > 0) {
+ val currentAmplitude = calculatePeakAmplitude(buffer = buffer, bytesRead = bytesRead)
+ onAmplitudeChanged(currentAmplitude)
+ audioStream.write(buffer, 0, bytesRead)
+ }
+ elapsedMs.value = System.currentTimeMillis() - startMs
+ if (elapsedMs.value >= MAX_AUDIO_CLIP_DURATION_SEC * 1000) {
+ onMaxDurationReached()
+ break
+ }
+ }
+ }
+ }
+}
+
+private fun stopRecording(
+  audioRecordState: MutableState<AudioRecord?>,
+ audioStream: ByteArrayOutputStream,
+): ByteArray {
+ Log.d(TAG, "Stopping recording...")
+
+ val recorder = audioRecordState.value
+ if (recorder?.recordingState == AudioRecord.RECORDSTATE_RECORDING) {
+ recorder.stop()
+ }
+ recorder?.release()
+ audioRecordState.value = null
+
+ val recordedBytes = audioStream.toByteArray()
+ audioStream.reset()
+ Log.d(TAG, "Stopped. Recorded ${recordedBytes.size} bytes.")
+
+ return recordedBytes
+}
+
+private fun calculatePeakAmplitude(buffer: ByteArray, bytesRead: Int): Int {
+ // Wrap the byte array in a ByteBuffer and set the order to little-endian
+ val shortBuffer =
+ ByteBuffer.wrap(buffer, 0, bytesRead).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer()
+
+ var maxAmplitude = 0
+ // Iterate through the short buffer to find the maximum absolute value
+ while (shortBuffer.hasRemaining()) {
+ val currentSample = abs(shortBuffer.get().toInt())
+ if (currentSample > maxAmplitude) {
+ maxAmplitude = currentSample
+ }
+ }
+ return maxAmplitude
+}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatMessage.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatMessage.kt
index 7eba194..e1f45b0 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatMessage.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatMessage.kt
@@ -17,18 +17,22 @@
package com.google.ai.edge.gallery.ui.common.chat
import android.graphics.Bitmap
+import android.util.Log
import androidx.compose.ui.graphics.ImageBitmap
import androidx.compose.ui.unit.Dp
import com.google.ai.edge.gallery.common.Classification
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.PromptTemplate
+private const val TAG = "AGChatMessage"
+
enum class ChatMessageType {
INFO,
WARNING,
TEXT,
IMAGE,
IMAGE_WITH_HISTORY,
+ AUDIO_CLIP,
LOADING,
CLASSIFICATION,
CONFIG_VALUES_CHANGE,
@@ -121,6 +125,90 @@ class ChatMessageImage(
}
}
+/** Chat message for audio clip. */
+class ChatMessageAudioClip(
+ val audioData: ByteArray,
+ val sampleRate: Int,
+ override val side: ChatSide,
+ override val latencyMs: Float = 0f,
+) : ChatMessage(type = ChatMessageType.AUDIO_CLIP, side = side, latencyMs = latencyMs) {
+ override fun clone(): ChatMessageAudioClip {
+ return ChatMessageAudioClip(
+ audioData = audioData,
+ sampleRate = sampleRate,
+ side = side,
+ latencyMs = latencyMs,
+ )
+ }
+
+ fun genByteArrayForWav(): ByteArray {
+ val header = ByteArray(44)
+
+ val pcmDataSize = audioData.size
+    // RIFF chunk size is the total file size minus the 8 bytes of the "RIFF" id and the size
+    // field itself: 44-byte header + data - 8 = pcmDataSize + 36.
+    val wavFileSize = pcmDataSize + 36
+ val channels = 1 // Mono
+ val bitsPerSample: Short = 16
+ val byteRate = sampleRate * channels * bitsPerSample / 8
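+    // e.g. 16000 Hz * 1 channel * 16 bits / 8 = 32000 bytes per second.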
+ Log.d(TAG, "Wav metadata: sampleRate: $sampleRate")
+
+ // RIFF/WAVE header
+ header[0] = 'R'.code.toByte()
+ header[1] = 'I'.code.toByte()
+ header[2] = 'F'.code.toByte()
+ header[3] = 'F'.code.toByte()
+ header[4] = (wavFileSize and 0xff).toByte()
+ header[5] = (wavFileSize shr 8 and 0xff).toByte()
+ header[6] = (wavFileSize shr 16 and 0xff).toByte()
+ header[7] = (wavFileSize shr 24 and 0xff).toByte()
+ header[8] = 'W'.code.toByte()
+ header[9] = 'A'.code.toByte()
+ header[10] = 'V'.code.toByte()
+ header[11] = 'E'.code.toByte()
+ header[12] = 'f'.code.toByte()
+ header[13] = 'm'.code.toByte()
+ header[14] = 't'.code.toByte()
+ header[15] = ' '.code.toByte()
+ header[16] = 16
+ header[17] = 0
+ header[18] = 0
+ header[19] = 0 // Sub-chunk size (16 for PCM)
+ header[20] = 1
+ header[21] = 0 // Audio format (1 for PCM)
+ header[22] = channels.toByte()
+ header[23] = 0 // Number of channels
+ header[24] = (sampleRate and 0xff).toByte()
+ header[25] = (sampleRate shr 8 and 0xff).toByte()
+ header[26] = (sampleRate shr 16 and 0xff).toByte()
+ header[27] = (sampleRate shr 24 and 0xff).toByte()
+ header[28] = (byteRate and 0xff).toByte()
+ header[29] = (byteRate shr 8 and 0xff).toByte()
+ header[30] = (byteRate shr 16 and 0xff).toByte()
+ header[31] = (byteRate shr 24 and 0xff).toByte()
+ header[32] = (channels * bitsPerSample / 8).toByte()
+ header[33] = 0 // Block align
+ header[34] = bitsPerSample.toByte()
+ header[35] = (bitsPerSample.toInt() shr 8 and 0xff).toByte() // Bits per sample
+ header[36] = 'd'.code.toByte()
+ header[37] = 'a'.code.toByte()
+ header[38] = 't'.code.toByte()
+ header[39] = 'a'.code.toByte()
+ header[40] = (pcmDataSize and 0xff).toByte()
+ header[41] = (pcmDataSize shr 8 and 0xff).toByte()
+ header[42] = (pcmDataSize shr 16 and 0xff).toByte()
+ header[43] = (pcmDataSize shr 24 and 0xff).toByte()
+
+ return header + audioData
+ }
+
+ fun getDurationInSeconds(): Float {
+ // PCM 16-bit
+ val bytesPerSample = 2
+ val bytesPerFrame = bytesPerSample * 1 // mono
+ val totalFrames = audioData.size.toFloat() / bytesPerFrame
+ return totalFrames / sampleRate
+ }
+}
+
/** Chat message for images with history. */
class ChatMessageImageWithHistory(
val bitmaps: List<Bitmap>,
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatPanel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatPanel.kt
index 2b8b168..80f2cfa 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatPanel.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatPanel.kt
@@ -137,6 +137,19 @@ fun ChatPanel(
}
imageMessageCount
}
+  val audioClipMessageCountToLastConfigChange =
+ remember(messages) {
+ var audioClipMessageCount = 0
+ for (message in messages.reversed()) {
+ if (message is ChatMessageConfigValuesChange) {
+ break
+ }
+ if (message is ChatMessageAudioClip) {
+ audioClipMessageCount++
+ }
+ }
+ audioClipMessageCount
+ }
var curMessage by remember { mutableStateOf("") } // Correct state
val focusManager = LocalFocusManager.current
@@ -342,6 +355,9 @@ fun ChatPanel(
imageHistoryCurIndex = imageHistoryCurIndex,
)
+ // Audio clip.
+ is ChatMessageAudioClip -> MessageBodyAudioClip(message = message)
+
// Classification result
is ChatMessageClassification ->
MessageBodyClassification(
@@ -467,6 +483,22 @@ fun ChatPanel(
)
}
}
+      // Show an info message for the ask audio task to get users started.
+ else if (task.type == TaskType.LLM_ASK_AUDIO && messages.isEmpty()) {
+ Column(
+ modifier = Modifier.padding(horizontal = 16.dp).fillMaxSize(),
+ horizontalAlignment = Alignment.CenterHorizontally,
+ verticalArrangement = Arrangement.Center,
+ ) {
+ MessageBodyInfo(
+ ChatMessageInfo(
+ content =
+ "To get started, tap the + icon to add your audio clips. You can add up to 10 clips, each up to 30 seconds long."
+ ),
+ smallFontSize = false,
+ )
+ }
+ }
}
// Chat input
@@ -482,6 +514,7 @@ fun ChatPanel(
isResettingSession = uiState.isResettingSession,
modelPreparing = uiState.preparing,
imageMessageCount = imageMessageCountToLastConfigChange,
+      audioClipMessageCount = audioClipMessageCountToLastConfigChange,
modelInitializing =
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
textFieldPlaceHolderRes = task.textInputPlaceHolderRes,
@@ -504,7 +537,10 @@ fun ChatPanel(
onStopButtonClicked = onStopButtonClicked,
// showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen,
showPromptTemplatesInMenu = false,
- showImagePickerInMenu = selectedModel.llmSupportImage,
+ showImagePickerInMenu =
+ selectedModel.llmSupportImage && task.type === TaskType.LLM_ASK_IMAGE,
+ showAudioItemsInMenu =
+ selectedModel.llmSupportAudio && task.type === TaskType.LLM_ASK_AUDIO,
showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress,
)
}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageBodyAudioClip.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageBodyAudioClip.kt
new file mode 100644
index 0000000..e4e3716
--- /dev/null
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageBodyAudioClip.kt
@@ -0,0 +1,33 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.ai.edge.gallery.ui.common.chat
+
+import androidx.compose.foundation.layout.padding
+import androidx.compose.runtime.Composable
+import androidx.compose.ui.Modifier
+import androidx.compose.ui.unit.dp
+
+@Composable
+fun MessageBodyAudioClip(message: ChatMessageAudioClip, modifier: Modifier = Modifier) {
+ AudioPlaybackPanel(
+ audioData = message.audioData,
+ sampleRate = message.sampleRate,
+ isRecording = false,
+ modifier = Modifier.padding(end = 16.dp),
+ onDarkBg = true,
+ )
+}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageInputText.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageInputText.kt
index 1e2e056..d958815 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageInputText.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageInputText.kt
@@ -21,6 +21,7 @@ package com.google.ai.edge.gallery.ui.common.chat
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import android.Manifest
import android.content.Context
+import android.content.Intent
import android.content.pm.PackageManager
import android.graphics.Bitmap
import android.graphics.BitmapFactory
@@ -65,9 +66,11 @@ import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.Send
import androidx.compose.material.icons.rounded.Add
+import androidx.compose.material.icons.rounded.AudioFile
import androidx.compose.material.icons.rounded.Close
import androidx.compose.material.icons.rounded.FlipCameraAndroid
import androidx.compose.material.icons.rounded.History
+import androidx.compose.material.icons.rounded.Mic
import androidx.compose.material.icons.rounded.Photo
import androidx.compose.material.icons.rounded.PhotoCamera
import androidx.compose.material.icons.rounded.PostAdd
@@ -107,9 +110,14 @@ import androidx.compose.ui.unit.dp
import androidx.compose.ui.viewinterop.AndroidView
import androidx.core.content.ContextCompat
import androidx.lifecycle.compose.LocalLifecycleOwner
+import com.google.ai.edge.gallery.common.AudioClip
+import com.google.ai.edge.gallery.common.convertWavToMonoWithMaxSeconds
+import com.google.ai.edge.gallery.data.MAX_AUDIO_CLIP_COUNT
import com.google.ai.edge.gallery.data.MAX_IMAGE_COUNT
+import com.google.ai.edge.gallery.data.SAMPLE_RATE
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import java.util.concurrent.Executors
+import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
private const val TAG = "AGMessageInputText"
@@ -128,6 +136,7 @@ fun MessageInputText(
isResettingSession: Boolean,
inProgress: Boolean,
imageMessageCount: Int,
+ audioClipMessageCount: Int,
modelInitializing: Boolean,
@StringRes textFieldPlaceHolderRes: Int,
onValueChanged: (String) -> Unit,
@@ -137,6 +146,7 @@ fun MessageInputText(
onStopButtonClicked: () -> Unit = {},
showPromptTemplatesInMenu: Boolean = false,
showImagePickerInMenu: Boolean = false,
+ showAudioItemsInMenu: Boolean = false,
showStopButtonWhenInProgress: Boolean = false,
) {
val context = LocalContext.current
@@ -146,7 +156,12 @@ fun MessageInputText(
var showTextInputHistorySheet by remember { mutableStateOf(false) }
var showCameraCaptureBottomSheet by remember { mutableStateOf(false) }
val cameraCaptureSheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
+ var showAudioRecorderBottomSheet by remember { mutableStateOf(false) }
+ val audioRecorderSheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
var pickedImages by remember { mutableStateOf<List<Bitmap>>(listOf()) }
+  var pickedAudioClips by remember { mutableStateOf<List<AudioClip>>(listOf()) }
+ var hasFrontCamera by remember { mutableStateOf(false) }
+
val updatePickedImages: (List<Bitmap>) -> Unit = { bitmaps ->
var newPickedImages: MutableList<Bitmap> = mutableListOf()
newPickedImages.addAll(pickedImages)
@@ -156,7 +171,16 @@ fun MessageInputText(
}
pickedImages = newPickedImages.toList()
}
- var hasFrontCamera by remember { mutableStateOf(false) }
+
+  val updatePickedAudioClips: (List<AudioClip>) -> Unit = { audioDataList ->
+    var newAudioDataList: MutableList<AudioClip> = mutableListOf()
+ newAudioDataList.addAll(pickedAudioClips)
+ newAudioDataList.addAll(audioDataList)
+ if (newAudioDataList.size > MAX_AUDIO_CLIP_COUNT) {
+ newAudioDataList = newAudioDataList.subList(fromIndex = 0, toIndex = MAX_AUDIO_CLIP_COUNT)
+ }
+ pickedAudioClips = newAudioDataList.toList()
+ }
LaunchedEffect(Unit) { checkFrontCamera(context = context, callback = { hasFrontCamera = it }) }
@@ -170,6 +194,16 @@ fun MessageInputText(
}
}
+ // Permission request when recording audio clips.
+ val recordAudioClipsPermissionLauncher =
+ rememberLauncherForActivityResult(ActivityResultContracts.RequestPermission()) {
+ permissionGranted ->
+ if (permissionGranted) {
+ showAddContentMenu = false
+ showAudioRecorderBottomSheet = true
+ }
+ }
+
// Registers a photo picker activity launcher in single-select mode.
val pickMedia =
rememberLauncherForActivityResult(ActivityResultContracts.PickMultipleVisualMedia()) { uris ->
@@ -184,9 +218,31 @@ fun MessageInputText(
}
}
+ val pickWav =
+ rememberLauncherForActivityResult(
+ contract = ActivityResultContracts.StartActivityForResult()
+ ) { result ->
+ if (result.resultCode == android.app.Activity.RESULT_OK) {
+ result.data?.data?.let { uri ->
+ Log.d(TAG, "Picked wav file: $uri")
+ scope.launch(Dispatchers.IO) {
+ convertWavToMonoWithMaxSeconds(context = context, stereoUri = uri)?.let { audioClip ->
+ updatePickedAudioClips(
+ listOf(
+ AudioClip(audioData = audioClip.audioData, sampleRate = audioClip.sampleRate)
+ )
+ )
+ }
+ }
+ }
+ } else {
+ Log.d(TAG, "Wav picking cancelled.")
+ }
+ }
+
Column {
- // A preview panel for the selected image.
- if (pickedImages.isNotEmpty()) {
+ // A preview panel for the selected images and audio clips.
+ if (pickedImages.isNotEmpty() || pickedAudioClips.isNotEmpty()) {
Row(
modifier =
Modifier.offset(x = 16.dp).fillMaxWidth().horizontalScroll(rememberScrollState()),
@@ -203,20 +259,30 @@ fun MessageInputText(
.clip(RoundedCornerShape(8.dp))
.border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp)),
)
+ MediaPanelCloseButton { pickedImages = pickedImages.filter { image != it } }
+ }
+ }
+
+ for ((index, audioClip) in pickedAudioClips.withIndex()) {
+ Box(contentAlignment = Alignment.TopEnd) {
Box(
modifier =
- Modifier.offset(x = 10.dp, y = (-10).dp)
- .clip(CircleShape)
+ Modifier.shadow(2.dp, shape = RoundedCornerShape(8.dp))
+ .clip(RoundedCornerShape(8.dp))
.background(MaterialTheme.colorScheme.surface)
- .border((1.5).dp, MaterialTheme.colorScheme.outline, CircleShape)
- .clickable { pickedImages = pickedImages.filter { image != it } }
+ .border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp))
) {
- Icon(
- Icons.Rounded.Close,
- contentDescription = "",
- modifier = Modifier.padding(3.dp).size(16.dp),
+ AudioPlaybackPanel(
+ audioData = audioClip.audioData,
+ sampleRate = audioClip.sampleRate,
+ isRecording = false,
+ modifier = Modifier.padding(end = 16.dp),
)
}
+ MediaPanelCloseButton {
+ pickedAudioClips =
+ pickedAudioClips.filterIndexed { curIndex, curAudioData -> curIndex != index }
+ }
}
}
}
@@ -239,10 +305,13 @@ fun MessageInputText(
verticalAlignment = Alignment.CenterVertically,
) {
val enableAddImageMenuItems = (imageMessageCount + pickedImages.size) < MAX_IMAGE_COUNT
+ val enableRecordAudioClipMenuItems =
+ (audioClipMessageCount + pickedAudioClips.size) < MAX_AUDIO_CLIP_COUNT
DropdownMenu(
expanded = showAddContentMenu,
onDismissRequest = { showAddContentMenu = false },
) {
+ // Image related menu items.
if (showImagePickerInMenu) {
// Take a picture.
DropdownMenuItem(
@@ -295,6 +364,70 @@ fun MessageInputText(
)
}
+ // Audio related menu items.
+ if (showAudioItemsInMenu) {
+ DropdownMenuItem(
+ text = {
+ Row(
+ verticalAlignment = Alignment.CenterVertically,
+ horizontalArrangement = Arrangement.spacedBy(6.dp),
+ ) {
+ Icon(Icons.Rounded.Mic, contentDescription = "")
+ Text("Record audio clip")
+ }
+ },
+ enabled = enableRecordAudioClipMenuItems,
+ onClick = {
+ // Check permission
+ when (PackageManager.PERMISSION_GRANTED) {
+ // Already got permission. Call the lambda.
+ ContextCompat.checkSelfPermission(context, Manifest.permission.RECORD_AUDIO) -> {
+ showAddContentMenu = false
+ showAudioRecorderBottomSheet = true
+ }
+
+ // Otherwise, ask for permission
+ else -> {
+ recordAudioClipsPermissionLauncher.launch(Manifest.permission.RECORD_AUDIO)
+ }
+ }
+ },
+ )
+
+ DropdownMenuItem(
+ text = {
+ Row(
+ verticalAlignment = Alignment.CenterVertically,
+ horizontalArrangement = Arrangement.spacedBy(6.dp),
+ ) {
+ Icon(Icons.Rounded.AudioFile, contentDescription = "")
+ Text("Pick wav file")
+ }
+ },
+ enabled = enableRecordAudioClipMenuItems,
+ onClick = {
+ showAddContentMenu = false
+
+ // Show file picker.
+ val intent =
+ Intent(Intent.ACTION_GET_CONTENT).apply {
+ addCategory(Intent.CATEGORY_OPENABLE)
+ type = "audio/*"
+
+ // Provide a list of more specific MIME types to filter for.
+ val mimeTypes = arrayOf("audio/wav", "audio/x-wav")
+ putExtra(Intent.EXTRA_MIME_TYPES, mimeTypes)
+
+ // Single select.
+ putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
+ .addFlags(Intent.FLAG_GRANT_PERSISTABLE_URI_PERMISSION)
+ .addFlags(Intent.FLAG_GRANT_READ_URI_PERMISSION)
+ }
+ pickWav.launch(intent)
+ },
+ )
+ }
+
// Prompt templates.
if (showPromptTemplatesInMenu) {
DropdownMenuItem(
@@ -369,15 +502,22 @@ fun MessageInputText(
)
}
}
- } // Send button. Only shown when text is not empty.
- else if (curMessage.isNotEmpty()) {
+ }
+ // Send button. Only shown when text is not empty, or there is at least one recorded
+ // audio clip.
+ else if (curMessage.isNotEmpty() || pickedAudioClips.isNotEmpty()) {
IconButton(
enabled = !inProgress && !isResettingSession,
onClick = {
onSendMessage(
- createMessagesToSend(pickedImages = pickedImages, text = curMessage.trim())
+ createMessagesToSend(
+ pickedImages = pickedImages,
+ audioClips = pickedAudioClips,
+ text = curMessage.trim(),
+ )
)
pickedImages = listOf()
+ pickedAudioClips = listOf()
},
colors =
IconButtonDefaults.iconButtonColors(
@@ -403,8 +543,15 @@ fun MessageInputText(
history = modelManagerUiState.textInputHistory,
onDismissed = { showTextInputHistorySheet = false },
onHistoryItemClicked = { item ->
- onSendMessage(createMessagesToSend(pickedImages = pickedImages, text = item))
+ onSendMessage(
+ createMessagesToSend(
+ pickedImages = pickedImages,
+ audioClips = pickedAudioClips,
+ text = item,
+ )
+ )
pickedImages = listOf()
+ pickedAudioClips = listOf()
modelManagerViewModel.promoteTextInputHistoryItem(item)
},
onHistoryItemDeleted = { item -> modelManagerViewModel.deleteTextInputHistory(item) },
@@ -582,6 +729,43 @@ fun MessageInputText(
}
}
}
+
+ if (showAudioRecorderBottomSheet) {
+ ModalBottomSheet(
+ sheetState = audioRecorderSheetState,
+ onDismissRequest = { showAudioRecorderBottomSheet = false },
+ ) {
+ AudioRecorderPanel(
+ onSendAudioClip = { audioData ->
+ scope.launch {
+ updatePickedAudioClips(
+ listOf(AudioClip(audioData = audioData, sampleRate = SAMPLE_RATE))
+ )
+ audioRecorderSheetState.hide()
+ showAudioRecorderBottomSheet = false
+ }
+ }
+ )
+ }
+ }
+}
+
+@Composable
+private fun MediaPanelCloseButton(onClicked: () -> Unit) {
+ Box(
+ modifier =
+ Modifier.offset(x = 10.dp, y = (-10).dp)
+ .clip(CircleShape)
+ .background(MaterialTheme.colorScheme.surface)
+ .border((1.5).dp, MaterialTheme.colorScheme.outline, CircleShape)
+ .clickable { onClicked() }
+ ) {
+ Icon(
+ Icons.Rounded.Close,
+ contentDescription = "",
+ modifier = Modifier.padding(3.dp).size(16.dp),
+ )
+ }
}
private fun handleImagesSelected(
@@ -641,20 +825,50 @@ private fun checkFrontCamera(context: Context, callback: (Boolean) -> Unit) {
)
}
-private fun createMessagesToSend(pickedImages: List<Bitmap>, text: String): List<ChatMessage> {
+private fun createMessagesToSend(
+  pickedImages: List<Bitmap>,
+  audioClips: List<AudioClip>,
+  text: String,
+): List<ChatMessage> {
var messages: MutableList<ChatMessage> = mutableListOf()
+
+ // Add image messages.
+  var imageMessages: MutableList<ChatMessage> = mutableListOf()
if (pickedImages.isNotEmpty()) {
for (image in pickedImages) {
- messages.add(
+ imageMessages.add(
ChatMessageImage(bitmap = image, imageBitMap = image.asImageBitmap(), side = ChatSide.USER)
)
}
}
// Cap the number of image messages.
- if (messages.size > MAX_IMAGE_COUNT) {
- messages = messages.subList(fromIndex = 0, toIndex = MAX_IMAGE_COUNT)
+ if (imageMessages.size > MAX_IMAGE_COUNT) {
+ imageMessages = imageMessages.subList(fromIndex = 0, toIndex = MAX_IMAGE_COUNT)
+ }
+ messages.addAll(imageMessages)
+
+ // Add audio messages.
+  var audioMessages: MutableList<ChatMessage> = mutableListOf()
+ if (audioClips.isNotEmpty()) {
+ for (audioClip in audioClips) {
+ audioMessages.add(
+ ChatMessageAudioClip(
+ audioData = audioClip.audioData,
+ sampleRate = audioClip.sampleRate,
+ side = ChatSide.USER,
+ )
+ )
+ }
+ }
+ // Cap the number of audio messages.
+ if (audioMessages.size > MAX_AUDIO_CLIP_COUNT) {
+ audioMessages = audioMessages.subList(fromIndex = 0, toIndex = MAX_AUDIO_CLIP_COUNT)
+ }
+ messages.addAll(audioMessages)
+
+ if (text.isNotEmpty()) {
+ messages.add(ChatMessageText(content = text, side = ChatSide.USER))
}
- messages.add(ChatMessageText(content = text, side = ChatSide.USER))
return messages
}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/ModelImportDialog.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/ModelImportDialog.kt
index 1f0cd00..9c985e5 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/ModelImportDialog.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/ModelImportDialog.kt
@@ -121,6 +121,7 @@ private val IMPORT_CONFIGS_LLM: List<Config> =
valueType = ValueType.FLOAT,
),
BooleanSwitchConfig(key = ConfigKey.SUPPORT_IMAGE, defaultValue = false),
+ BooleanSwitchConfig(key = ConfigKey.SUPPORT_AUDIO, defaultValue = false),
SegmentedButtonConfig(
key = ConfigKey.COMPATIBLE_ACCELERATORS,
defaultValue = Accelerator.CPU.label,
@@ -230,6 +231,12 @@ fun ModelImportDialog(uri: Uri, onDismiss: () -> Unit, onDone: (ImportedModel) -
valueType = ValueType.BOOLEAN,
)
as Boolean
+ val supportAudio =
+ convertValueToTargetType(
+ value = values.get(ConfigKey.SUPPORT_AUDIO.label)!!,
+ valueType = ValueType.BOOLEAN,
+ )
+ as Boolean
val importedModel: ImportedModel =
ImportedModel.newBuilder()
.setFileName(fileName)
@@ -242,6 +249,7 @@ fun ModelImportDialog(uri: Uri, onDismiss: () -> Unit, onDone: (ImportedModel) -
.setDefaultTopp(defaultTopp)
.setDefaultTemperature(defaultTemperature)
.setSupportImage(supportImage)
+ .setSupportAudio(supportAudio)
.build()
)
.build()
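
The new SUPPORT_AUDIO switch flows through the same read-and-cast path as SUPPORT_IMAGE and lands in field 7 of the LlmConfig proto (see the settings.proto hunk below). A tiny round-trip sketch, assuming the protobuf-lite codegen the app already uses:

  // supportAudio is the Boolean read from the dialog above;
  // importedModel is the ImportedModel built in this hunk.
  check(importedModel.llmConfig.supportAudio == supportAudio)
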
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/SettingsDialog.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/SettingsDialog.kt
index ee418bb..dc0bba1 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/SettingsDialog.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/SettingsDialog.kt
@@ -173,7 +173,7 @@ fun SettingsDialog(
color = MaterialTheme.colorScheme.onSurfaceVariant,
)
Text(
- "Expired at: ${dateFormatter.format(Instant.ofEpochMilli(curHfToken.expiresAtMs))}",
+ "Expires at: ${dateFormatter.format(Instant.ofEpochMilli(curHfToken.expiresAtMs))}",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant,
)
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Forum.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Forum.kt
deleted file mode 100644
index 173ca4f..0000000
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Forum.kt
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Copyright 2025 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.ai.edge.gallery.ui.icon
-
-import androidx.compose.ui.graphics.Color
-import androidx.compose.ui.graphics.SolidColor
-import androidx.compose.ui.graphics.vector.ImageVector
-import androidx.compose.ui.graphics.vector.path
-import androidx.compose.ui.unit.dp
-
-val Forum: ImageVector
- get() {
- if (_Forum != null) return _Forum!!
-
- _Forum =
- ImageVector.Builder(
- name = "Forum",
- defaultWidth = 24.dp,
- defaultHeight = 24.dp,
- viewportWidth = 960f,
- viewportHeight = 960f,
- )
- .apply {
- path(fill = SolidColor(Color(0xFF000000))) {
- moveTo(280f, 720f)
- quadToRelative(-17f, 0f, -28.5f, -11.5f)
- reflectiveQuadTo(240f, 680f)
- verticalLineToRelative(-80f)
- horizontalLineToRelative(520f)
- verticalLineToRelative(-360f)
- horizontalLineToRelative(80f)
- quadToRelative(17f, 0f, 28.5f, 11.5f)
- reflectiveQuadTo(880f, 280f)
- verticalLineToRelative(600f)
- lineTo(720f, 720f)
- close()
- moveTo(80f, 680f)
- verticalLineToRelative(-560f)
- quadToRelative(0f, -17f, 11.5f, -28.5f)
- reflectiveQuadTo(120f, 80f)
- horizontalLineToRelative(520f)
- quadToRelative(17f, 0f, 28.5f, 11.5f)
- reflectiveQuadTo(680f, 120f)
- verticalLineToRelative(360f)
- quadToRelative(0f, 17f, -11.5f, 28.5f)
- reflectiveQuadTo(640f, 520f)
- horizontalLineTo(240f)
- close()
- moveToRelative(520f, -240f)
- verticalLineToRelative(-280f)
- horizontalLineTo(160f)
- verticalLineToRelative(280f)
- close()
- moveToRelative(-440f, 0f)
- verticalLineToRelative(-280f)
- close()
- }
- }
- .build()
-
- return _Forum!!
- }
-
-private var _Forum: ImageVector? = null
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Mms.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Mms.kt
deleted file mode 100644
index 56f9990..0000000
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Mms.kt
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Copyright 2025 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.ai.edge.gallery.ui.icon
-
-import androidx.compose.ui.graphics.Color
-import androidx.compose.ui.graphics.SolidColor
-import androidx.compose.ui.graphics.vector.ImageVector
-import androidx.compose.ui.graphics.vector.path
-import androidx.compose.ui.unit.dp
-
-val Mms: ImageVector
- get() {
- if (_Mms != null) return _Mms!!
-
- _Mms =
- ImageVector.Builder(
- name = "Mms",
- defaultWidth = 24.dp,
- defaultHeight = 24.dp,
- viewportWidth = 960f,
- viewportHeight = 960f,
- )
- .apply {
- path(fill = SolidColor(Color(0xFF000000))) {
- moveTo(240f, 560f)
- horizontalLineToRelative(480f)
- lineTo(570f, 360f)
- lineTo(450f, 520f)
- lineToRelative(-90f, -120f)
- close()
- moveTo(80f, 880f)
- verticalLineToRelative(-720f)
- quadToRelative(0f, -33f, 23.5f, -56.5f)
- reflectiveQuadTo(160f, 80f)
- horizontalLineToRelative(640f)
- quadToRelative(33f, 0f, 56.5f, 23.5f)
- reflectiveQuadTo(880f, 160f)
- verticalLineToRelative(480f)
- quadToRelative(0f, 33f, -23.5f, 56.5f)
- reflectiveQuadTo(800f, 720f)
- horizontalLineTo(240f)
- close()
- moveToRelative(126f, -240f)
- horizontalLineToRelative(594f)
- verticalLineToRelative(-480f)
- horizontalLineTo(160f)
- verticalLineToRelative(525f)
- close()
- moveToRelative(-46f, 0f)
- verticalLineToRelative(-480f)
- close()
- }
- }
- .build()
-
- return _Mms!!
- }
-
-private var _Mms: ImageVector? = null
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Widgets.kt.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Widgets.kt.kt
deleted file mode 100644
index c727a7b..0000000
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Widgets.kt.kt
+++ /dev/null
@@ -1,87 +0,0 @@
-/*
- * Copyright 2025 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.ai.edge.gallery.ui.icon
-
-import androidx.compose.ui.graphics.Color
-import androidx.compose.ui.graphics.SolidColor
-import androidx.compose.ui.graphics.vector.ImageVector
-import androidx.compose.ui.graphics.vector.path
-import androidx.compose.ui.unit.dp
-
-val Widgets: ImageVector
- get() {
- if (_Widgets != null) return _Widgets!!
-
- _Widgets =
- ImageVector.Builder(
- name = "Widgets",
- defaultWidth = 24.dp,
- defaultHeight = 24.dp,
- viewportWidth = 960f,
- viewportHeight = 960f,
- )
- .apply {
- path(fill = SolidColor(Color(0xFF000000))) {
- moveTo(666f, 520f)
- lineTo(440f, 294f)
- lineToRelative(226f, -226f)
- lineToRelative(226f, 226f)
- close()
- moveToRelative(-546f, -80f)
- verticalLineToRelative(-320f)
- horizontalLineToRelative(320f)
- verticalLineToRelative(320f)
- close()
- moveToRelative(400f, 400f)
- verticalLineToRelative(-320f)
- horizontalLineToRelative(320f)
- verticalLineToRelative(320f)
- close()
- moveToRelative(-400f, 0f)
- verticalLineToRelative(-320f)
- horizontalLineToRelative(320f)
- verticalLineToRelative(320f)
- close()
- moveToRelative(80f, -480f)
- horizontalLineToRelative(160f)
- verticalLineToRelative(-160f)
- horizontalLineTo(200f)
- close()
- moveToRelative(467f, 48f)
- lineToRelative(113f, -113f)
- lineToRelative(-113f, -113f)
- lineToRelative(-113f, 113f)
- close()
- moveToRelative(-67f, 352f)
- horizontalLineToRelative(160f)
- verticalLineToRelative(-160f)
- horizontalLineTo(600f)
- close()
- moveToRelative(-400f, 0f)
- horizontalLineToRelative(160f)
- verticalLineToRelative(-160f)
- horizontalLineTo(200f)
- close()
- moveToRelative(400f, -160f)
- }
- }
- .build()
-
- return _Widgets!!
- }
-
-private var _Widgets: ImageVector? = null
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt
index 128baa8..a330c35 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt
@@ -62,13 +62,13 @@ object LlmChatModelHelper {
Accelerator.GPU.label -> LlmInference.Backend.GPU
else -> LlmInference.Backend.GPU
}
- val options =
+ val optionsBuilder =
LlmInference.LlmInferenceOptions.builder()
.setModelPath(model.getPath(context = context))
.setMaxTokens(maxTokens)
.setPreferredBackend(preferredBackend)
.setMaxNumImages(if (model.llmSupportImage) MAX_IMAGE_COUNT else 0)
- .build()
+ val options = optionsBuilder.build()
// Create an instance of the LLM Inference task and session.
try {
@@ -82,7 +82,9 @@ object LlmChatModelHelper {
.setTopP(topP)
.setTemperature(temperature)
.setGraphOptions(
- GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
+ GraphOptions.builder()
+ .setEnableVisionModality(model.llmSupportImage)
+ .build()
)
.build(),
)
@@ -115,7 +117,9 @@ object LlmChatModelHelper {
.setTopP(topP)
.setTemperature(temperature)
.setGraphOptions(
- GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
+ GraphOptions.builder()
+ .setEnableVisionModality(model.llmSupportImage)
+ .build()
)
.build(),
)
@@ -159,6 +163,7 @@ object LlmChatModelHelper {
resultListener: ResultListener,
cleanUpListener: CleanUpListener,
images: List<Bitmap> = listOf(),
+ audioClips: List<ByteArray> = listOf(),
) {
val instance = model.instance as LlmModelInstance
@@ -172,10 +177,16 @@ object LlmChatModelHelper {
// For a model that supports image modality, we need to add the text query chunk before adding
// image.
val session = instance.session
- session.addQueryChunk(input)
+ if (input.trim().isNotEmpty()) {
+ session.addQueryChunk(input)
+ }
for (image in images) {
session.addImage(BitmapImageBuilder(image).build())
}
+ for (audioClip in audioClips) {
+ // Uncomment when audio is supported.
+ // session.addAudio(audioClip)
+ }
val unused = session.generateResponseAsync(resultListener)
}
}
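
The empty-input guard above matters because audio-only prompts can now reach runInference with a blank text query, and adding an empty query chunk to the session would be pointless. A hedged calling sketch for the new signature; wavBytes and both listener bodies are placeholders:

  LlmChatModelHelper.runInference(
    model = model,                  // a Model whose instance is an initialized LlmModelInstance
    input = "",                     // blank text is allowed now...
    images = listOf(),
    audioClips = listOf(wavBytes),  // ...when the prompt is carried by 16 kHz mono WAV bytes
    resultListener = { partialResult, done -> /* stream tokens to the UI */ },
    cleanUpListener = { /* release per-request state */ },
  )

The audio loop itself stays a no-op until session.addAudio (commented out above) ships in the LLM Inference API.
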
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt
index 0a97f3f..23b5777 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt
@@ -22,6 +22,7 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.ai.edge.gallery.ui.ViewModelProvider
+import com.google.ai.edge.gallery.ui.common.chat.ChatMessageAudioClip
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
import com.google.ai.edge.gallery.ui.common.chat.ChatView
@@ -36,6 +37,10 @@ object LlmAskImageDestination {
val route = "LlmAskImageRoute"
}
+object LlmAskAudioDestination {
+ val route = "LlmAskAudioRoute"
+}
+
@Composable
fun LlmChatScreen(
modelManagerViewModel: ModelManagerViewModel,
@@ -66,6 +71,21 @@ fun LlmAskImageScreen(
)
}
+@Composable
+fun LlmAskAudioScreen(
+ modelManagerViewModel: ModelManagerViewModel,
+ navigateUp: () -> Unit,
+ modifier: Modifier = Modifier,
+ viewModel: LlmAskAudioViewModel = viewModel(factory = ViewModelProvider.Factory),
+) {
+ ChatViewWrapper(
+ viewModel = viewModel,
+ modelManagerViewModel = modelManagerViewModel,
+ navigateUp = navigateUp,
+ modifier = modifier,
+ )
+}
+
@Composable
fun ChatViewWrapper(
viewModel: LlmChatViewModel,
@@ -86,6 +106,7 @@ fun ChatViewWrapper(
var text = ""
val images: MutableList<Bitmap> = mutableListOf()
+ val audioMessages: MutableList<ChatMessageAudioClip> = mutableListOf()
var chatMessageText: ChatMessageText? = null
for (message in messages) {
if (message is ChatMessageText) {
@@ -93,14 +114,17 @@ fun ChatViewWrapper(
text = message.content
} else if (message is ChatMessageImage) {
images.add(message.bitmap)
+ } else if (message is ChatMessageAudioClip) {
+ audioMessages.add(message)
}
}
- if (text.isNotEmpty() && chatMessageText != null) {
+ if ((text.isNotEmpty() && chatMessageText != null) || audioMessages.isNotEmpty()) {
modelManagerViewModel.addTextInputHistory(text)
viewModel.generateResponse(
model = model,
input = text,
images = images,
+ audioMessages = audioMessages,
onError = {
viewModel.handleError(
context = context,
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt
index 7b0a083..d97a9df 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt
@@ -22,9 +22,11 @@ import android.util.Log
import androidx.lifecycle.viewModelScope
import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.Model
+import com.google.ai.edge.gallery.data.TASK_LLM_ASK_AUDIO
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
import com.google.ai.edge.gallery.data.Task
+import com.google.ai.edge.gallery.ui.common.chat.ChatMessageAudioClip
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageLoading
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
@@ -52,6 +54,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
model: Model,
input: String,
images: List<Bitmap> = listOf(),
+ audioMessages: List<ChatMessageAudioClip> = listOf(),
onError: () -> Unit,
) {
val accelerator = model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = "")
@@ -72,6 +75,11 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
val instance = model.instance as LlmModelInstance
var prefillTokens = instance.session.sizeInTokens(input)
prefillTokens += images.size * 257
+ for (audioMessage in audioMessages) {
+ // 150ms = 1 audio token
+ val duration = audioMessage.getDurationInSeconds()
+ prefillTokens += (duration * 1000f / 150f).toInt()
+ }
var firstRun = true
var timeToFirstToken = 0f
@@ -86,6 +94,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
model = model,
input = input,
images = images,
+ audioClips = audioMessages.map { it.genByteArrayForWav() },
resultListener = { partialResult, done ->
val curTs = System.currentTimeMillis()
@@ -214,7 +223,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
context: Context,
model: Model,
modelManagerViewModel: ModelManagerViewModel,
- triggeredMessage: ChatMessageText,
+ triggeredMessage: ChatMessageText?,
) {
// Clean up.
modelManagerViewModel.cleanupModel(task = task, model = model)
@@ -236,14 +245,20 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
)
// Add the triggered message back.
- addMessage(model = model, message = triggeredMessage)
+ if (triggeredMessage != null) {
+ addMessage(model = model, message = triggeredMessage)
+ }
// Re-initialize the session/engine.
modelManagerViewModel.initializeModel(context = context, task = task, model = model)
// Re-generate the response automatically.
- generateResponse(model = model, input = triggeredMessage.content, onError = {})
+ if (triggeredMessage != null) {
+ generateResponse(model = model, input = triggeredMessage.content, onError = {})
+ }
}
}
class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE)
+
+class LlmAskAudioViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_AUDIO)
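
The prefill estimate now sums three sources: the session's sizeInTokens(input) for text, a flat 257 tokens per image, and one token per 150 ms of audio. A worked sketch with illustrative numbers (constants mirror the code above):

  fun estimatePrefillTokens(textTokens: Int, imageCount: Int, audioSeconds: Float): Int {
    val imageTokens = imageCount * 257                        // 257 tokens per image
    val audioTokens = (audioSeconds * 1000f / 150f).toInt()   // 1 token per 150 ms
    return textTokens + imageTokens + audioTokens
  }

  // 12 text tokens + 1 image + a 3-second clip:
  // 12 + 257 + (3000 / 150) = 12 + 257 + 20 = 289 prefill tokens
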
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt
index f73b893..6e2fb9e 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt
@@ -36,6 +36,7 @@ import com.google.ai.edge.gallery.data.ModelAllowlist
import com.google.ai.edge.gallery.data.ModelDownloadStatus
import com.google.ai.edge.gallery.data.ModelDownloadStatusType
import com.google.ai.edge.gallery.data.TASKS
+import com.google.ai.edge.gallery.data.TASK_LLM_ASK_AUDIO
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB
@@ -281,15 +282,12 @@ open class ModelManagerViewModel(
}
}
when (task.type) {
- TaskType.LLM_CHAT ->
- LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone)
-
+ TaskType.LLM_CHAT,
+ TaskType.LLM_ASK_IMAGE,
+ TaskType.LLM_ASK_AUDIO,
TaskType.LLM_PROMPT_LAB ->
LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone)
- TaskType.LLM_ASK_IMAGE ->
- LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone)
-
TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {}
}
@@ -301,9 +299,11 @@ open class ModelManagerViewModel(
model.cleanUpAfterInit = false
Log.d(TAG, "Cleaning up model '${model.name}'...")
when (task.type) {
- TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model)
- TaskType.LLM_PROMPT_LAB -> LlmChatModelHelper.cleanUp(model = model)
- TaskType.LLM_ASK_IMAGE -> LlmChatModelHelper.cleanUp(model = model)
+ TaskType.LLM_CHAT,
+ TaskType.LLM_PROMPT_LAB,
+ TaskType.LLM_ASK_IMAGE,
+ TaskType.LLM_ASK_AUDIO -> LlmChatModelHelper.cleanUp(model = model)
+
TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {}
}
@@ -410,14 +410,19 @@ open class ModelManagerViewModel(
// Create model.
val model = createModelFromImportedModelInfo(info = info)
- for (task in listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)) {
+ for (task in
+ listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_ASK_AUDIO, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)) {
// Remove the duplicated imported model if it exists.
val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported }
if (modelIndex >= 0) {
Log.d(TAG, "duplicated imported model found in task. Removing it first")
task.models.removeAt(modelIndex)
}
- if ((task == TASK_LLM_ASK_IMAGE && model.llmSupportImage) || task != TASK_LLM_ASK_IMAGE) {
+ if (
+ (task == TASK_LLM_ASK_IMAGE && model.llmSupportImage) ||
+ (task == TASK_LLM_ASK_AUDIO && model.llmSupportAudio) ||
+ (task != TASK_LLM_ASK_IMAGE && task != TASK_LLM_ASK_AUDIO)
+ ) {
task.models.add(model)
}
task.updateTrigger.value = System.currentTimeMillis()
@@ -657,6 +662,7 @@ open class ModelManagerViewModel(
TASK_LLM_CHAT.models.clear()
TASK_LLM_PROMPT_LAB.models.clear()
TASK_LLM_ASK_IMAGE.models.clear()
+ TASK_LLM_ASK_AUDIO.models.clear()
for (allowedModel in modelAllowlist.models) {
if (allowedModel.disabled == true) {
continue
@@ -672,6 +678,9 @@ open class ModelManagerViewModel(
if (allowedModel.taskTypes.contains(TASK_LLM_ASK_IMAGE.type.id)) {
TASK_LLM_ASK_IMAGE.models.add(model)
}
+ if (allowedModel.taskTypes.contains(TASK_LLM_ASK_AUDIO.type.id)) {
+ TASK_LLM_ASK_AUDIO.models.add(model)
+ }
}
// Pre-process all tasks.
@@ -760,6 +769,9 @@ open class ModelManagerViewModel(
if (model.llmSupportImage) {
TASK_LLM_ASK_IMAGE.models.add(model)
}
+ if (model.llmSupportAudio) {
+ TASK_LLM_ASK_AUDIO.models.add(model)
+ }
// Update status.
modelDownloadStatus[model.name] =
@@ -800,6 +812,7 @@ open class ModelManagerViewModel(
accelerators = accelerators,
)
val llmSupportImage = info.llmConfig.supportImage
+ val llmSupportAudio = info.llmConfig.supportAudio
val model =
Model(
name = info.fileName,
@@ -811,6 +824,7 @@ open class ModelManagerViewModel(
showRunAgainButton = false,
imported = true,
llmSupportImage = llmSupportImage,
+ llmSupportAudio = llmSupportAudio,
)
model.preProcess()
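
The import-time membership condition reads more clearly as a predicate: modality-gated tasks require the matching capability flag, and every other task accepts any imported LLM. A restating sketch (not a refactor this patch performs):

  fun belongsToTask(task: Task, model: Model): Boolean =
    when (task) {
      TASK_LLM_ASK_IMAGE -> model.llmSupportImage
      TASK_LLM_ASK_AUDIO -> model.llmSupportAudio
      else -> true  // TASK_LLM_CHAT and TASK_LLM_PROMPT_LAB take any imported model
    }
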
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/navigation/GalleryNavGraph.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/navigation/GalleryNavGraph.kt
index 5e8672d..138edc7 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/navigation/GalleryNavGraph.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/navigation/GalleryNavGraph.kt
@@ -47,6 +47,7 @@ import androidx.navigation.compose.NavHost
import androidx.navigation.compose.composable
import androidx.navigation.navArgument
import com.google.ai.edge.gallery.data.Model
+import com.google.ai.edge.gallery.data.TASK_LLM_ASK_AUDIO
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB
@@ -55,6 +56,8 @@ import com.google.ai.edge.gallery.data.TaskType
import com.google.ai.edge.gallery.data.getModelByName
import com.google.ai.edge.gallery.ui.ViewModelProvider
import com.google.ai.edge.gallery.ui.home.HomeScreen
+import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioDestination
+import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioScreen
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageDestination
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageScreen
import com.google.ai.edge.gallery.ui.llmchat.LlmChatDestination
@@ -209,7 +212,7 @@ fun GalleryNavHost(
}
}
- // LLM image to text.
+ // Ask image.
composable(
route = "${LlmAskImageDestination.route}/{modelName}",
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
@@ -225,6 +228,23 @@ fun GalleryNavHost(
)
}
}
+
+ // Ask audio.
+ composable(
+ route = "${LlmAskAudioDestination.route}/{modelName}",
+ arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
+ enterTransition = { slideEnter() },
+ exitTransition = { slideExit() },
+ ) {
+ getModelFromNavigationParam(it, TASK_LLM_ASK_AUDIO)?.let { defaultModel ->
+ modelManagerViewModel.selectModel(defaultModel)
+
+ LlmAskAudioScreen(
+ modelManagerViewModel = modelManagerViewModel,
+ navigateUp = { navController.navigateUp() },
+ )
+ }
+ }
}
// Handle incoming intents for deep links
@@ -256,6 +276,7 @@ fun navigateToTaskScreen(
when (taskType) {
TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
TaskType.LLM_ASK_IMAGE -> navController.navigate("${LlmAskImageDestination.route}/${modelName}")
+ TaskType.LLM_ASK_AUDIO -> navController.navigate("${LlmAskAudioDestination.route}/${modelName}")
TaskType.LLM_PROMPT_LAB ->
navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
TaskType.TEST_TASK_1 -> {}
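
Each destination route embeds the model name as a trailing path argument, so navigating to the new audio task looks like the sketch below (model name illustrative; names containing path separators would need encoding first):

  navController.navigate("${LlmAskAudioDestination.route}/Gemma-3n-E2B")
  // matches the "LlmAskAudioRoute/{modelName}" composable registered above
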
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/theme/Theme.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/theme/Theme.kt
index 72ad1d5..a1a7ad1 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/theme/Theme.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/theme/Theme.kt
@@ -120,6 +120,8 @@ data class CustomColors(
val agentBubbleBgColor: Color = Color.Transparent,
val linkColor: Color = Color.Transparent,
val successColor: Color = Color.Transparent,
+ val recordButtonBgColor: Color = Color.Transparent,
+ val waveFormBgColor: Color = Color.Transparent,
)
val LocalCustomColors = staticCompositionLocalOf { CustomColors() }
@@ -145,6 +147,8 @@ val lightCustomColors =
userBubbleBgColor = Color(0xFF32628D),
linkColor = Color(0xFF32628D),
successColor = Color(0xff3d860b),
+ recordButtonBgColor = Color(0xFFEE675C),
+ waveFormBgColor = Color(0xFFaaaaaa),
)
val darkCustomColors =
@@ -168,6 +172,8 @@ val darkCustomColors =
userBubbleBgColor = Color(0xFF1f3760),
linkColor = Color(0xFF9DCAFC),
successColor = Color(0xFFA1CE83),
+ recordButtonBgColor = Color(0xFFEE675C),
+ waveFormBgColor = Color(0xFFaaaaaa),
)
val MaterialTheme.customColors: CustomColors
diff --git a/Android/src/app/src/main/proto/settings.proto b/Android/src/app/src/main/proto/settings.proto
index 5540eaf..3b5f6f3 100644
--- a/Android/src/app/src/main/proto/settings.proto
+++ b/Android/src/app/src/main/proto/settings.proto
@@ -55,6 +55,7 @@ message LlmConfig {
float default_topp = 4;
float default_temperature = 5;
bool support_image = 6;
+ bool support_audio = 7;
}
message Settings {