Add audio support.

- Add a new "Audio Scribe" task.
- Allow users to record audio clips or pick WAV files to interact with the model.
- Add support for importing models with audio capability.
- Fix a typo in Settings dialog (Thanks https://github.com/rhnvrm!)

PiperOrigin-RevId: 774832681
This commit is contained in:
Google AI Edge Gallery 2025-06-23 10:24:10 -07:00 committed by Copybara-Service
parent 33c3ee638e
commit d0989adce1
27 changed files with 1369 additions and 288 deletions

View file

@ -29,6 +29,7 @@
<uses-permission android:name="android.permission.FOREGROUND_SERVICE_DATA_SYNC"/>
<uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.POST_NOTIFICATIONS" />
<uses-permission android:name="android.permission.RECORD_AUDIO" />
<uses-permission android:name="android.permission.WAKE_LOCK"/>
<uses-feature

View file

@ -25,3 +25,5 @@ interface LatencyProvider {
data class Classification(val label: String, val score: Float, val color: Color)
data class JsonObjAndTextContent<T>(val jsonObj: T, val textContent: String)
class AudioClip(val audioData: ByteArray, val sampleRate: Int)

View file

@ -17,12 +17,17 @@
package com.google.ai.edge.gallery.common
import android.content.Context
import android.net.Uri
import android.util.Log
import com.google.ai.edge.gallery.data.SAMPLE_RATE
import com.google.gson.Gson
import com.google.gson.reflect.TypeToken
import java.io.File
import java.net.HttpURLConnection
import java.net.URL
import java.nio.ByteBuffer
import java.nio.ByteOrder
import kotlin.math.floor
data class LaunchInfo(val ts: Long)
@ -112,3 +117,135 @@ inline fun <reified T> getJsonResponse(url: String): JsonObjAndTextContent<T>? {
return null
}
fun convertWavToMonoWithMaxSeconds(
context: Context,
stereoUri: Uri,
maxSeconds: Int = 30,
): AudioClip? {
Log.d(TAG, "Start to convert wav file to mono channel")
try {
val inputStream = context.contentResolver.openInputStream(stereoUri) ?: return null
val originalBytes = inputStream.readBytes()
inputStream.close()
// Read WAV header
if (originalBytes.size < 44) {
// Not a valid WAV file
Log.e(TAG, "Not a valid wav file")
return null
}
val headerBuffer = ByteBuffer.wrap(originalBytes, 0, 44).order(ByteOrder.LITTLE_ENDIAN)
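// Canonical PCM WAV header layout (little-endian): channel count at byte offset 22,
// sample rate at offset 24, bits per sample at offset 34.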
val channels = headerBuffer.getShort(22)
var sampleRate = headerBuffer.getInt(24)
val bitDepth = headerBuffer.getShort(34)
Log.d(TAG, "File metadata: channels: $channels, sampleRate: $sampleRate, bitDepth: $bitDepth")
// Normalize audio to 16-bit.
val audioDataBytes = originalBytes.copyOfRange(fromIndex = 44, toIndex = originalBytes.size)
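// Note: this assumes the "data" chunk starts right at byte 44 (canonical header); files with
// extra chunks (e.g. LIST/INFO) before the data chunk are not handled here.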
val sixteenBitBytes: ByteArray =
if (bitDepth.toInt() == 8) {
Log.d(TAG, "Converting 8-bit audio to 16-bit.")
convert8BitTo16Bit(audioDataBytes)
} else {
// Assume 16-bit or other format that can be handled directly
audioDataBytes
}
// Convert byte array to short array for processing
val shortBuffer =
ByteBuffer.wrap(sixteenBitBytes).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer()
var pcmSamples = ShortArray(shortBuffer.remaining())
shortBuffer.get(pcmSamples)
// Resample if the sample rate is below SAMPLE_RATE (16000 Hz).
if (sampleRate < SAMPLE_RATE) {
Log.d(TAG, "Resampling from $sampleRate Hz to $SAMPLE_RATE Hz.")
pcmSamples = resample(pcmSamples, sampleRate, SAMPLE_RATE, channels.toInt())
sampleRate = SAMPLE_RATE
Log.d(TAG, "Resampling complete. New sample count: ${pcmSamples.size}")
}
// Convert stereo to mono if necessary
var monoSamples =
if (channels.toInt() == 2) {
Log.d(TAG, "Converting stereo to mono.")
val mono = ShortArray(pcmSamples.size / 2)
for (i in mono.indices) {
val left = pcmSamples[i * 2]
val right = pcmSamples[i * 2 + 1]
mono[i] = ((left + right) / 2).toShort()
}
mono
} else {
Log.d(TAG, "Audio is already mono. No channel conversion needed.")
pcmSamples
}
// Trim the audio to maxSeconds.
val maxSamples = maxSeconds * sampleRate
if (monoSamples.size > maxSamples) {
Log.d(TAG, "Trimming clip from ${monoSamples.size} samples to $maxSamples samples.")
monoSamples = monoSamples.copyOfRange(0, maxSamples)
}
val monoByteBuffer = ByteBuffer.allocate(monoSamples.size * 2).order(ByteOrder.LITTLE_ENDIAN)
monoByteBuffer.asShortBuffer().put(monoSamples)
return AudioClip(audioData = monoByteBuffer.array(), sampleRate = sampleRate)
} catch (e: Exception) {
Log.e(TAG, "Failed to convert wav to mono", e)
return null
}
}
/** Converts 8-bit unsigned PCM audio data to 16-bit signed PCM. */
private fun convert8BitTo16Bit(eightBitData: ByteArray): ByteArray {
// The new 16-bit data will be twice the size
val sixteenBitData = ByteArray(eightBitData.size * 2)
val buffer = ByteBuffer.wrap(sixteenBitData).order(ByteOrder.LITTLE_ENDIAN)
for (byte in eightBitData) {
// Convert the unsigned 8-bit byte (0-255) to a signed 16-bit short (-32768 to 32767)
// 1. Get the unsigned value by masking with 0xFF
// 2. Subtract 128 to center the waveform around 0 (range becomes -128 to 127)
// 3. Scale by 256 to expand to the 16-bit range
val unsignedByte = byte.toInt() and 0xFF
val sixteenBitSample = ((unsignedByte - 128) * 256).toShort()
buffer.putShort(sixteenBitSample)
}
return sixteenBitData
}
/** Resamples PCM audio data from an original sample rate to a target sample rate. */
private fun resample(
inputSamples: ShortArray,
originalSampleRate: Int,
targetSampleRate: Int,
channels: Int,
): ShortArray {
if (originalSampleRate == targetSampleRate) {
return inputSamples
}
val ratio = targetSampleRate.toDouble() / originalSampleRate
val outputLength = (inputSamples.size * ratio).toInt()
val resampledData = ShortArray(outputLength)
if (channels == 1) { // Mono
for (i in resampledData.indices) {
val position = i / ratio
val index1 = floor(position).toInt()
val index2 = index1 + 1
val fraction = position - index1
val sample1 = if (index1 < inputSamples.size) inputSamples[index1].toDouble() else 0.0
val sample2 = if (index2 < inputSamples.size) inputSamples[index2].toDouble() else 0.0
resampledData[i] = (sample1 * (1 - fraction) + sample2 * fraction).toInt().toShort()
}
} else if (channels == 2) { // Stereo: interpolate each interleaved channel separately.
val outputFrames = outputLength / 2
for (frame in 0 until outputFrames) {
val position = frame / ratio
val index1 = floor(position).toInt()
val index2 = index1 + 1
val fraction = position - index1
for (ch in 0 until 2) {
val in1 = index1 * 2 + ch
val in2 = index2 * 2 + ch
val sample1 = if (in1 < inputSamples.size) inputSamples[in1].toDouble() else 0.0
val sample2 = if (in2 < inputSamples.size) inputSamples[in2].toDouble() else 0.0
resampledData[frame * 2 + ch] = (sample1 * (1 - fraction) + sample2 * fraction).toInt().toShort()
}
}
}
return resampledData
}

View file

@ -50,6 +50,7 @@ enum class ConfigKey(val label: String) {
DEFAULT_TOPP("Default TopP"),
DEFAULT_TEMPERATURE("Default temperature"),
SUPPORT_IMAGE("Support image"),
SUPPORT_AUDIO("Support audio"),
MAX_RESULT_COUNT("Max result count"),
USE_GPU("Use GPU"),
ACCELERATOR("Choose accelerator"),

View file

@ -44,3 +44,12 @@ val DEFAULT_ACCELERATORS = listOf(Accelerator.GPU)
// Max number of images allowed in a "ask image" session.
const val MAX_IMAGE_COUNT = 10
// Max number of audio clips in an "ask audio" session.
const val MAX_AUDIO_CLIP_COUNT = 10
// Max audio clip duration in seconds.
const val MAX_AUDIO_CLIP_DURATION_SEC = 30
// Audio-recording related consts.
const val SAMPLE_RATE = 16000

View file

@ -87,6 +87,9 @@ data class Model(
/** Whether the LLM model supports image input. */
val llmSupportImage: Boolean = false,
/** Whether the LLM model supports audio input. */
val llmSupportAudio: Boolean = false,
/** Whether the model is imported or not. */
val imported: Boolean = false,

View file

@ -38,6 +38,7 @@ data class AllowedModel(
val taskTypes: List<String>,
val disabled: Boolean? = null,
val llmSupportImage: Boolean? = null,
val llmSupportAudio: Boolean? = null,
val estimatedPeakMemoryInBytes: Long? = null,
) {
fun toModel(): Model {
@ -96,6 +97,7 @@ data class AllowedModel(
showRunAgainButton = showRunAgainButton,
learnMoreUrl = "https://huggingface.co/${modelId}",
llmSupportImage = llmSupportImage == true,
llmSupportAudio = llmSupportAudio == true,
)
}

View file

@ -17,19 +17,22 @@
package com.google.ai.edge.gallery.data
import androidx.annotation.StringRes
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.outlined.Forum
import androidx.compose.material.icons.outlined.Mic
import androidx.compose.material.icons.outlined.Mms
import androidx.compose.material.icons.outlined.Widgets
import androidx.compose.runtime.MutableState
import androidx.compose.runtime.mutableLongStateOf
import androidx.compose.ui.graphics.vector.ImageVector
import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.ui.icon.Forum
import com.google.ai.edge.gallery.ui.icon.Mms
import com.google.ai.edge.gallery.ui.icon.Widgets
/** Type of task. */
enum class TaskType(val label: String, val id: String) {
LLM_CHAT(label = "AI Chat", id = "llm_chat"),
LLM_PROMPT_LAB(label = "Prompt Lab", id = "llm_prompt_lab"),
LLM_ASK_IMAGE(label = "Ask Image", id = "llm_ask_image"),
LLM_ASK_AUDIO(label = "Audio Scribe", id = "llm_ask_audio"),
TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
TEST_TASK_2(label = "Test task 2", id = "test_task_2"),
}
@ -71,7 +74,7 @@ data class Task(
val TASK_LLM_CHAT =
Task(
type = TaskType.LLM_CHAT,
icon = Forum,
icon = Icons.Outlined.Forum,
models = mutableListOf(),
description = "Chat with on-device large language models",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
@ -83,7 +86,7 @@ val TASK_LLM_CHAT =
val TASK_LLM_PROMPT_LAB =
Task(
type = TaskType.LLM_PROMPT_LAB,
icon = Widgets,
icon = Icons.Outlined.Widgets,
models = mutableListOf(),
description = "Single turn use cases with on-device large language model",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
@ -95,7 +98,7 @@ val TASK_LLM_PROMPT_LAB =
val TASK_LLM_ASK_IMAGE =
Task(
type = TaskType.LLM_ASK_IMAGE,
icon = Mms,
icon = Icons.Outlined.Mms,
models = mutableListOf(),
description = "Ask questions about images with on-device large language models",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
@ -104,8 +107,23 @@ val TASK_LLM_ASK_IMAGE =
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
)
val TASK_LLM_ASK_AUDIO =
Task(
type = TaskType.LLM_ASK_AUDIO,
icon = Icons.Outlined.Mic,
models = mutableListOf(),
// TODO(do not submit)
description =
"Instantly transcribe and/or translate audio clips using on-device large language models",
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
sourceCodeUrl =
"https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt",
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
)
/** All tasks. */
val TASKS: List<Task> = listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)
val TASKS: List<Task> =
listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_ASK_AUDIO, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)
fun getModelByName(name: String): Model? {
for (task in TASKS) {

View file

@ -21,6 +21,7 @@ import androidx.lifecycle.viewmodel.CreationExtras
import androidx.lifecycle.viewmodel.initializer
import androidx.lifecycle.viewmodel.viewModelFactory
import com.google.ai.edge.gallery.GalleryApplication
import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
@ -49,6 +50,9 @@ object ViewModelProvider {
// Initializer for LlmAskImageViewModel.
initializer { LlmAskImageViewModel() }
// Initializer for LlmAskAudioViewModel.
initializer { LlmAskAudioViewModel() }
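// LlmAskAudioViewModel itself is not part of this diff; presumably it mirrors
// LlmAskImageViewModel with the task set to TASK_LLM_ASK_AUDIO.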
}
}

View file

@ -0,0 +1,344 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.common.chat
import android.media.AudioAttributes
import android.media.AudioFormat
import android.media.AudioTrack
import android.util.Log
import androidx.compose.foundation.Canvas
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.width
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.PlayArrow
import androidx.compose.material.icons.rounded.Stop
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.DisposableEffect
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.MutableState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableFloatStateOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.geometry.CornerRadius
import androidx.compose.ui.geometry.Offset
import androidx.compose.ui.geometry.Size
import androidx.compose.ui.geometry.toRect
import androidx.compose.ui.graphics.BlendMode
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.drawscope.drawIntoCanvas
import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.data.MAX_AUDIO_CLIP_DURATION_SEC
import com.google.ai.edge.gallery.ui.theme.customColors
import java.nio.ByteBuffer
import java.nio.ByteOrder
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
private const val TAG = "AGAudioPlaybackPanel"
private const val BAR_SPACE = 2
private const val BAR_WIDTH = 2
private const val MIN_BAR_COUNT = 16
private const val MAX_BAR_COUNT = 48
/**
* A Composable that displays an audio playback panel, including play/stop controls, a waveform
* visualization, and the duration of the audio clip.
*/
@Composable
fun AudioPlaybackPanel(
audioData: ByteArray,
sampleRate: Int,
isRecording: Boolean,
modifier: Modifier = Modifier,
onDarkBg: Boolean = false,
) {
val coroutineScope = rememberCoroutineScope()
var isPlaying by remember { mutableStateOf(false) }
val audioTrackState = remember { mutableStateOf<AudioTrack?>(null) }
val durationInSeconds =
remember(audioData) {
// PCM 16-bit
val bytesPerSample = 2
val bytesPerFrame = bytesPerSample * 1 // mono
val totalFrames = audioData.size.toDouble() / bytesPerFrame
totalFrames / sampleRate
}
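// The bar count scales linearly with clip length: MIN_BAR_COUNT for a near-zero clip up to
// MAX_BAR_COUNT for a full MAX_AUDIO_CLIP_DURATION_SEC clip.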
val barCount =
remember(durationInSeconds) {
val f = durationInSeconds / MAX_AUDIO_CLIP_DURATION_SEC
((MAX_BAR_COUNT - MIN_BAR_COUNT) * f + MIN_BAR_COUNT).toInt()
}
val amplitudeLevels =
remember(audioData) { generateAmplitudeLevels(audioData = audioData, barCount = barCount) }
var playbackProgress by remember { mutableFloatStateOf(0f) }
// Reset when a new recording is started.
LaunchedEffect(isRecording) {
if (isRecording) {
val audioTrack = audioTrackState.value
audioTrack?.stop()
isPlaying = false
playbackProgress = 0f
}
}
// Cleanup on Composable Disposal.
DisposableEffect(Unit) {
onDispose {
val audioTrack = audioTrackState.value
audioTrack?.stop()
audioTrack?.release()
}
}
Row(verticalAlignment = Alignment.CenterVertically, modifier = modifier) {
// Button to play/stop the clip.
IconButton(
onClick = {
coroutineScope.launch {
if (!isPlaying) {
isPlaying = true
playAudio(
audioTrackState = audioTrackState,
audioData = audioData,
sampleRate = sampleRate,
onProgress = { playbackProgress = it },
onCompletion = {
playbackProgress = 0f
isPlaying = false
},
)
} else {
stopPlayAudio(audioTrackState = audioTrackState)
playbackProgress = 0f
isPlaying = false
}
}
}
) {
Icon(
if (isPlaying) Icons.Rounded.Stop else Icons.Rounded.PlayArrow,
contentDescription = "",
tint = if (onDarkBg) Color.White else MaterialTheme.colorScheme.primary,
)
}
// Visualization
AmplitudeBarGraph(
amplitudeLevels = amplitudeLevels,
progress = playbackProgress,
modifier =
Modifier.width((barCount * BAR_WIDTH + (barCount - 1) * BAR_SPACE).dp).height(24.dp),
onDarkBg = onDarkBg,
)
// Duration
Text(
"${"%.1f".format(durationInSeconds)}s",
style = MaterialTheme.typography.labelLarge,
color = if (onDarkBg) Color.White else MaterialTheme.colorScheme.primary,
modifier = Modifier.padding(start = 12.dp),
)
}
}
@Composable
private fun AmplitudeBarGraph(
amplitudeLevels: List<Float>,
progress: Float,
modifier: Modifier = Modifier,
onDarkBg: Boolean = false,
) {
val barColor = MaterialTheme.customColors.waveFormBgColor
val progressColor = if (onDarkBg) Color.White else MaterialTheme.colorScheme.primary
Canvas(modifier = modifier) {
val barCount = amplitudeLevels.size
val barWidth = (size.width - BAR_SPACE.dp.toPx() * (barCount - 1)) / barCount
val cornerRadius = CornerRadius(x = barWidth, y = barWidth)
// Use drawIntoCanvas for advanced blend mode operations
drawIntoCanvas { canvas ->
// 1. Save the current state of the canvas onto a temporary, offscreen layer
canvas.saveLayer(size.toRect(), androidx.compose.ui.graphics.Paint())
// 2. Draw the bars in grey.
amplitudeLevels.forEachIndexed { index, level ->
val barHeight = (level * size.height).coerceAtLeast(1.5f)
val left = index * (barWidth + BAR_SPACE.dp.toPx())
drawRoundRect(
color = barColor,
topLeft = Offset(x = left, y = size.height / 2 - barHeight / 2),
size = Size(barWidth, barHeight),
cornerRadius = cornerRadius,
)
}
// 3. Draw the progress rectangle using BlendMode.SrcIn to only draw where the bars already
// exist.
val progressWidth = size.width * progress
drawRect(
color = progressColor,
topLeft = Offset.Zero,
size = Size(progressWidth, size.height),
blendMode = BlendMode.SrcIn,
)
// 4. Restore the layer, merging it onto the main canvas
canvas.restore()
}
}
}
private suspend fun playAudio(
audioTrackState: MutableState<AudioTrack?>,
audioData: ByteArray,
sampleRate: Int,
onProgress: (Float) -> Unit,
onCompletion: () -> Unit,
) {
Log.d(TAG, "Start playing audio...")
try {
withContext(Dispatchers.IO) {
var lastProgressUpdateMs = 0L
audioTrackState.value?.release()
val audioTrack =
AudioTrack.Builder()
.setAudioAttributes(
AudioAttributes.Builder()
.setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
.setUsage(AudioAttributes.USAGE_MEDIA)
.build()
)
.setAudioFormat(
AudioFormat.Builder()
.setEncoding(AudioFormat.ENCODING_PCM_16BIT)
.setSampleRate(sampleRate)
.setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
.build()
)
.setTransferMode(AudioTrack.MODE_STATIC)
.setBufferSizeInBytes(audioData.size)
.build()
val bytesPerFrame = 2 // For PCM 16-bit Mono
val totalFrames = audioData.size / bytesPerFrame
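// AudioTrack.playbackHeadPosition below is reported in frames, so progress is
// currentFrames / totalFrames.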
audioTrackState.value = audioTrack
audioTrack.write(audioData, 0, audioData.size)
audioTrack.play()
// Coroutine to monitor progress
while (isActive && audioTrack.playState == AudioTrack.PLAYSTATE_PLAYING) {
val currentFrames = audioTrack.playbackHeadPosition
if (currentFrames >= totalFrames) {
break // Exit loop when playback is done
}
val progress = currentFrames.toFloat() / totalFrames
val curMs = System.currentTimeMillis()
if (curMs - lastProgressUpdateMs > 30) {
onProgress(progress)
lastProgressUpdateMs = curMs
}
}
if (isActive) {
audioTrackState.value?.stop()
}
}
} catch (e: Exception) {
// Ignore
} finally {
onProgress(1f)
onCompletion()
}
}
private fun stopPlayAudio(audioTrackState: MutableState<AudioTrack?>) {
Log.d(TAG, "Stopping playing audio...")
val audioTrack = audioTrackState.value
audioTrack?.stop()
audioTrack?.release()
audioTrackState.value = null
}
/**
* Processes a raw PCM 16-bit audio byte array to generate a list of normalized amplitude levels for
* visualization.
*/
private fun generateAmplitudeLevels(audioData: ByteArray, barCount: Int): List<Float> {
if (audioData.isEmpty()) {
return List(barCount) { 0f }
}
// 1. Parse bytes into 16-bit short samples (PCM 16-bit)
val shortBuffer = ByteBuffer.wrap(audioData).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer()
val samples = ShortArray(shortBuffer.remaining())
shortBuffer.get(samples)
if (samples.isEmpty()) {
return List(barCount) { 0f }
}
// 2. Determine the size of each chunk
val chunkSize = samples.size / barCount
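// Integer division: if the clip has fewer samples than bars, chunkSize is 0 and the
// affected bars simply stay at 0f.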
val amplitudeLevels = mutableListOf<Float>()
// 3. Get the max value for each chunk
for (i in 0 until barCount) {
val chunkStart = i * chunkSize
val chunkEnd = (chunkStart + chunkSize).coerceAtMost(samples.size)
var maxAmplitudeInChunk = 0.0
for (j in chunkStart until chunkEnd) {
val sampleAbs = kotlin.math.abs(samples[j].toDouble())
if (sampleAbs > maxAmplitudeInChunk) {
maxAmplitudeInChunk = sampleAbs
}
}
// 4. Normalize the value (0 to 1)
// Short.MAX_VALUE is 32767.0, a good reference for max amplitude
val normalizedPeak = (maxAmplitudeInChunk / Short.MAX_VALUE).toFloat().coerceIn(0f, 1f)
amplitudeLevels.add(normalizedPeak)
}
// Normalize the resulting levels so that the max value becomes 0.9.
val maxVal = amplitudeLevels.max()
if (maxVal == 0f) {
return amplitudeLevels
}
val scaleFactor = 0.9f / maxVal
return amplitudeLevels.map { it * scaleFactor }
}

View file

@ -0,0 +1,327 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.common.chat
import android.annotation.SuppressLint
import android.content.Context
import android.media.AudioFormat
import android.media.AudioRecord
import android.media.MediaRecorder
import android.util.Log
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.ArrowUpward
import androidx.compose.material.icons.rounded.Mic
import androidx.compose.material.icons.rounded.Stop
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.DisposableEffect
import androidx.compose.runtime.MutableLongState
import androidx.compose.runtime.MutableState
import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.mutableLongStateOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.data.MAX_AUDIO_CLIP_DURATION_SEC
import com.google.ai.edge.gallery.data.SAMPLE_RATE
import com.google.ai.edge.gallery.ui.theme.customColors
import java.io.ByteArrayOutputStream
import java.nio.ByteBuffer
import java.nio.ByteOrder
import kotlin.math.abs
import kotlin.math.pow
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.launch
private const val TAG = "AGAudioRecorderPanel"
private const val CHANNEL_CONFIG = AudioFormat.CHANNEL_IN_MONO
private const val AUDIO_FORMAT = AudioFormat.ENCODING_PCM_16BIT
/**
* A Composable that provides an audio recording panel. It allows users to record audio clips,
* displays the recording duration and a live amplitude visualization, and provides options to play
* back the recorded clip or send it.
*/
@Composable
fun AudioRecorderPanel(onSendAudioClip: (ByteArray) -> Unit) {
val context = LocalContext.current
val coroutineScope = rememberCoroutineScope()
var isRecording by remember { mutableStateOf(false) }
val elapsedMs = remember { mutableLongStateOf(0L) }
val audioRecordState = remember { mutableStateOf<AudioRecord?>(null) }
val audioStream = remember { ByteArrayOutputStream() }
val recordedBytes = remember { mutableStateOf<ByteArray?>(null) }
var currentAmplitude by remember { mutableIntStateOf(0) }
val elapsedSeconds by remember {
derivedStateOf { "%.1f".format(elapsedMs.value.toFloat() / 1000f) }
}
// Cleanup on Composable Disposal.
DisposableEffect(Unit) { onDispose { audioRecordState.value?.release() } }
Column(modifier = Modifier.padding(bottom = 12.dp)) {
// Title bar.
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.SpaceBetween,
) {
// Logo and state.
Row(
modifier = Modifier.padding(start = 16.dp),
horizontalArrangement = Arrangement.spacedBy(12.dp),
) {
Icon(
painterResource(R.drawable.logo),
modifier = Modifier.size(20.dp),
contentDescription = "",
tint = Color.Unspecified,
)
Text(
"Record audio clip (up to $MAX_AUDIO_CLIP_DURATION_SEC seconds)",
style = MaterialTheme.typography.labelLarge,
)
}
}
// Recorded clip.
Row(
modifier = Modifier.fillMaxWidth().padding(vertical = 12.dp).height(40.dp),
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.Center,
) {
val curRecordedBytes = recordedBytes.value
if (curRecordedBytes == null) {
// Info message when there is no recorded clip and the recording has not started yet.
if (!isRecording) {
Text(
"Tap the record button to start",
style = MaterialTheme.typography.labelLarge,
color = MaterialTheme.colorScheme.onSurfaceVariant,
)
}
// Visualization for clip being recorded.
else {
Row(
horizontalArrangement = Arrangement.spacedBy(12.dp),
verticalAlignment = Alignment.CenterVertically,
) {
Box(
modifier =
Modifier.size(8.dp)
.background(MaterialTheme.customColors.recordButtonBgColor, CircleShape)
)
Text("$elapsedSeconds s")
}
}
}
// Controls for recorded clip.
else {
Row {
// Clip player.
AudioPlaybackPanel(
audioData = curRecordedBytes,
sampleRate = SAMPLE_RATE,
isRecording = isRecording,
)
// Button to send the clip
IconButton(onClick = { onSendAudioClip(curRecordedBytes) }) {
Icon(Icons.Rounded.ArrowUpward, contentDescription = "")
}
}
}
}
// Buttons
Box(contentAlignment = Alignment.Center, modifier = Modifier.fillMaxWidth().height(40.dp)) {
// Visualization of the current amplitude.
if (isRecording) {
// Normalize the amplitude (0-32767) to a fraction (0.0-1.0)
// We use a power scale (exponent < 1) to make the pulse more visible for lower volumes.
val normalizedAmplitude = (currentAmplitude.toFloat() / 32767f).pow(0.35f)
// Define the min and max size of the circle
val minSize = 38.dp
val maxSize = 100.dp
// Map the normalized amplitude to our size range
val scale by
remember(normalizedAmplitude) {
derivedStateOf { (minSize + (maxSize - minSize) * normalizedAmplitude) / minSize }
}
Box(
modifier =
Modifier.size(minSize)
.graphicsLayer(scaleX = scale, scaleY = scale, clip = false, alpha = 0.3f)
.background(MaterialTheme.customColors.recordButtonBgColor, CircleShape)
)
}
// Record/stop button.
IconButton(
onClick = {
coroutineScope.launch {
if (!isRecording) {
isRecording = true
recordedBytes.value = null
startRecording(
context = context,
audioRecordState = audioRecordState,
audioStream = audioStream,
elapsedMs = elapsedMs,
onAmplitudeChanged = { currentAmplitude = it },
onMaxDurationReached = {
val curRecordedBytes =
stopRecording(audioRecordState = audioRecordState, audioStream = audioStream)
recordedBytes.value = curRecordedBytes
isRecording = false
},
)
} else {
val curRecordedBytes =
stopRecording(audioRecordState = audioRecordState, audioStream = audioStream)
recordedBytes.value = curRecordedBytes
isRecording = false
}
}
},
modifier =
Modifier.clip(CircleShape).background(MaterialTheme.customColors.recordButtonBgColor),
) {
Icon(
if (isRecording) Icons.Rounded.Stop else Icons.Rounded.Mic,
contentDescription = "",
tint = Color.White,
)
}
}
}
}
// Permission is checked in parent composable.
@SuppressLint("MissingPermission")
private suspend fun startRecording(
context: Context,
audioRecordState: MutableState<AudioRecord?>,
audioStream: ByteArrayOutputStream,
elapsedMs: MutableLongState,
onAmplitudeChanged: (Int) -> Unit,
onMaxDurationReached: () -> Unit,
) {
Log.d(TAG, "Start recording...")
val minBufferSize = AudioRecord.getMinBufferSize(SAMPLE_RATE, CHANNEL_CONFIG, AUDIO_FORMAT)
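// getMinBufferSize returns the smallest buffer (in bytes) that AudioRecord accepts for this
// sample rate / channel / format combination; audio is read in chunks of this size below.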
audioRecordState.value?.release()
val recorder =
AudioRecord(
MediaRecorder.AudioSource.MIC,
SAMPLE_RATE,
CHANNEL_CONFIG,
AUDIO_FORMAT,
minBufferSize,
)
audioRecordState.value = recorder
val buffer = ByteArray(minBufferSize)
// This function only returns once recording stops (either stopRecording is called or the max clip duration is reached).
coroutineScope {
launch(Dispatchers.IO) {
recorder.startRecording()
val startMs = System.currentTimeMillis()
elapsedMs.value = 0L
while (audioRecordState.value?.recordingState == AudioRecord.RECORDSTATE_RECORDING) {
val bytesRead = recorder.read(buffer, 0, buffer.size)
if (bytesRead > 0) {
val currentAmplitude = calculatePeakAmplitude(buffer = buffer, bytesRead = bytesRead)
onAmplitudeChanged(currentAmplitude)
audioStream.write(buffer, 0, bytesRead)
}
elapsedMs.value = System.currentTimeMillis() - startMs
if (elapsedMs.value >= MAX_AUDIO_CLIP_DURATION_SEC * 1000) {
onMaxDurationReached()
break
}
}
}
}
}
private fun stopRecording(
audioRecordState: MutableState<AudioRecord?>,
audioStream: ByteArrayOutputStream,
): ByteArray {
Log.d(TAG, "Stopping recording...")
val recorder = audioRecordState.value
if (recorder?.recordingState == AudioRecord.RECORDSTATE_RECORDING) {
recorder.stop()
}
recorder?.release()
audioRecordState.value = null
val recordedBytes = audioStream.toByteArray()
audioStream.reset()
Log.d(TAG, "Stopped. Recorded ${recordedBytes.size} bytes.")
return recordedBytes
}
private fun calculatePeakAmplitude(buffer: ByteArray, bytesRead: Int): Int {
// Wrap the byte array in a ByteBuffer and set the order to little-endian
val shortBuffer =
ByteBuffer.wrap(buffer, 0, bytesRead).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer()
var maxAmplitude = 0
// Iterate through the short buffer to find the maximum absolute value
while (shortBuffer.hasRemaining()) {
val currentSample = abs(shortBuffer.get().toInt())
if (currentSample > maxAmplitude) {
maxAmplitude = currentSample
}
}
return maxAmplitude
}

View file

@ -17,18 +17,22 @@
package com.google.ai.edge.gallery.ui.common.chat
import android.graphics.Bitmap
import android.util.Log
import androidx.compose.ui.graphics.ImageBitmap
import androidx.compose.ui.unit.Dp
import com.google.ai.edge.gallery.common.Classification
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.PromptTemplate
private const val TAG = "AGChatMessage"
enum class ChatMessageType {
INFO,
WARNING,
TEXT,
IMAGE,
IMAGE_WITH_HISTORY,
AUDIO_CLIP,
LOADING,
CLASSIFICATION,
CONFIG_VALUES_CHANGE,
@ -121,6 +125,90 @@ class ChatMessageImage(
}
}
/** Chat message for audio clip. */
class ChatMessageAudioClip(
val audioData: ByteArray,
val sampleRate: Int,
override val side: ChatSide,
override val latencyMs: Float = 0f,
) : ChatMessage(type = ChatMessageType.AUDIO_CLIP, side = side, latencyMs = latencyMs) {
override fun clone(): ChatMessageAudioClip {
return ChatMessageAudioClip(
audioData = audioData,
sampleRate = sampleRate,
side = side,
latencyMs = latencyMs,
)
}
fun genByteArrayForWav(): ByteArray {
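// Prepends a standard 44-byte RIFF/WAVE header to the raw 16-bit mono PCM samples so the
// clip can be handled as a regular .wav payload.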
val header = ByteArray(44)
val pcmDataSize = audioData.size
val riffChunkSize = pcmDataSize + 36 // RIFF chunk size = total file size minus the 8-byte "RIFF" prefix
val channels = 1 // Mono
val bitsPerSample: Short = 16
val byteRate = sampleRate * channels * bitsPerSample / 8
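// e.g. 16000 Hz * 1 channel * 16 bits / 8 = 32000 bytes per second.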
Log.d(TAG, "Wav metadata: sampleRate: $sampleRate")
// RIFF/WAVE header
header[0] = 'R'.code.toByte()
header[1] = 'I'.code.toByte()
header[2] = 'F'.code.toByte()
header[3] = 'F'.code.toByte()
header[4] = (riffChunkSize and 0xff).toByte()
header[5] = (riffChunkSize shr 8 and 0xff).toByte()
header[6] = (riffChunkSize shr 16 and 0xff).toByte()
header[7] = (riffChunkSize shr 24 and 0xff).toByte()
header[8] = 'W'.code.toByte()
header[9] = 'A'.code.toByte()
header[10] = 'V'.code.toByte()
header[11] = 'E'.code.toByte()
header[12] = 'f'.code.toByte()
header[13] = 'm'.code.toByte()
header[14] = 't'.code.toByte()
header[15] = ' '.code.toByte()
header[16] = 16
header[17] = 0
header[18] = 0
header[19] = 0 // Sub-chunk size (16 for PCM)
header[20] = 1
header[21] = 0 // Audio format (1 for PCM)
header[22] = channels.toByte()
header[23] = 0 // Number of channels
header[24] = (sampleRate and 0xff).toByte()
header[25] = (sampleRate shr 8 and 0xff).toByte()
header[26] = (sampleRate shr 16 and 0xff).toByte()
header[27] = (sampleRate shr 24 and 0xff).toByte()
header[28] = (byteRate and 0xff).toByte()
header[29] = (byteRate shr 8 and 0xff).toByte()
header[30] = (byteRate shr 16 and 0xff).toByte()
header[31] = (byteRate shr 24 and 0xff).toByte()
header[32] = (channels * bitsPerSample / 8).toByte()
header[33] = 0 // Block align
header[34] = bitsPerSample.toByte()
header[35] = (bitsPerSample.toInt() shr 8 and 0xff).toByte() // Bits per sample
header[36] = 'd'.code.toByte()
header[37] = 'a'.code.toByte()
header[38] = 't'.code.toByte()
header[39] = 'a'.code.toByte()
header[40] = (pcmDataSize and 0xff).toByte()
header[41] = (pcmDataSize shr 8 and 0xff).toByte()
header[42] = (pcmDataSize shr 16 and 0xff).toByte()
header[43] = (pcmDataSize shr 24 and 0xff).toByte()
return header + audioData
}
fun getDurationInSeconds(): Float {
// PCM 16-bit
val bytesPerSample = 2
val bytesPerFrame = bytesPerSample * 1 // mono
val totalFrames = audioData.size.toFloat() / bytesPerFrame
return totalFrames / sampleRate
}
}
/** Chat message for images with history. */
class ChatMessageImageWithHistory(
val bitmaps: List<Bitmap>,

View file

@ -137,6 +137,19 @@ fun ChatPanel(
}
imageMessageCount
}
val audioClipMessageCountToLastConfigChange =
remember(messages) {
var audioClipMessageCount = 0
for (message in messages.reversed()) {
if (message is ChatMessageConfigValuesChange) {
break
}
if (message is ChatMessageAudioClip) {
audioClipMessageCount++
}
}
audioClipMessageCount
}
var curMessage by remember { mutableStateOf("") } // Correct state
val focusManager = LocalFocusManager.current
@ -342,6 +355,9 @@ fun ChatPanel(
imageHistoryCurIndex = imageHistoryCurIndex,
)
// Audio clip.
is ChatMessageAudioClip -> MessageBodyAudioClip(message = message)
// Classification result
is ChatMessageClassification ->
MessageBodyClassification(
@ -467,6 +483,22 @@ fun ChatPanel(
)
}
}
// Show an info message for the ask audio task to get users started.
else if (task.type == TaskType.LLM_ASK_AUDIO && messages.isEmpty()) {
Column(
modifier = Modifier.padding(horizontal = 16.dp).fillMaxSize(),
horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.Center,
) {
MessageBodyInfo(
ChatMessageInfo(
content =
"To get started, tap the + icon to add your audio clips. You can add up to 10 clips, each up to 30 seconds long."
),
smallFontSize = false,
)
}
}
}
// Chat input
@ -482,6 +514,7 @@ fun ChatPanel(
isResettingSession = uiState.isResettingSession,
modelPreparing = uiState.preparing,
imageMessageCount = imageMessageCountToLastConfigChange,
audioClipMessageCount = audioClipMessageCountToLastConfigChange,
modelInitializing =
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
textFieldPlaceHolderRes = task.textInputPlaceHolderRes,
@ -504,7 +537,10 @@ fun ChatPanel(
onStopButtonClicked = onStopButtonClicked,
// showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen,
showPromptTemplatesInMenu = false,
showImagePickerInMenu = selectedModel.llmSupportImage,
showImagePickerInMenu =
selectedModel.llmSupportImage && task.type === TaskType.LLM_ASK_IMAGE,
showAudioItemsInMenu =
selectedModel.llmSupportAudio && task.type === TaskType.LLM_ASK_AUDIO,
showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress,
)
}

View file

@ -0,0 +1,33 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.common.chat
import androidx.compose.foundation.layout.padding
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.dp
@Composable
fun MessageBodyAudioClip(message: ChatMessageAudioClip, modifier: Modifier = Modifier) {
AudioPlaybackPanel(
audioData = message.audioData,
sampleRate = message.sampleRate,
isRecording = false,
modifier = modifier.padding(end = 16.dp),
onDarkBg = true,
)
}

View file

@ -21,6 +21,7 @@ package com.google.ai.edge.gallery.ui.common.chat
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import android.Manifest
import android.content.Context
import android.content.Intent
import android.content.pm.PackageManager
import android.graphics.Bitmap
import android.graphics.BitmapFactory
@ -65,9 +66,11 @@ import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.Send
import androidx.compose.material.icons.rounded.Add
import androidx.compose.material.icons.rounded.AudioFile
import androidx.compose.material.icons.rounded.Close
import androidx.compose.material.icons.rounded.FlipCameraAndroid
import androidx.compose.material.icons.rounded.History
import androidx.compose.material.icons.rounded.Mic
import androidx.compose.material.icons.rounded.Photo
import androidx.compose.material.icons.rounded.PhotoCamera
import androidx.compose.material.icons.rounded.PostAdd
@ -107,9 +110,14 @@ import androidx.compose.ui.unit.dp
import androidx.compose.ui.viewinterop.AndroidView
import androidx.core.content.ContextCompat
import androidx.lifecycle.compose.LocalLifecycleOwner
import com.google.ai.edge.gallery.common.AudioClip
import com.google.ai.edge.gallery.common.convertWavToMonoWithMaxSeconds
import com.google.ai.edge.gallery.data.MAX_AUDIO_CLIP_COUNT
import com.google.ai.edge.gallery.data.MAX_IMAGE_COUNT
import com.google.ai.edge.gallery.data.SAMPLE_RATE
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import java.util.concurrent.Executors
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
private const val TAG = "AGMessageInputText"
@ -128,6 +136,7 @@ fun MessageInputText(
isResettingSession: Boolean,
inProgress: Boolean,
imageMessageCount: Int,
audioClipMessageCount: Int,
modelInitializing: Boolean,
@StringRes textFieldPlaceHolderRes: Int,
onValueChanged: (String) -> Unit,
@ -137,6 +146,7 @@ fun MessageInputText(
onStopButtonClicked: () -> Unit = {},
showPromptTemplatesInMenu: Boolean = false,
showImagePickerInMenu: Boolean = false,
showAudioItemsInMenu: Boolean = false,
showStopButtonWhenInProgress: Boolean = false,
) {
val context = LocalContext.current
@ -146,7 +156,12 @@ fun MessageInputText(
var showTextInputHistorySheet by remember { mutableStateOf(false) }
var showCameraCaptureBottomSheet by remember { mutableStateOf(false) }
val cameraCaptureSheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
var showAudioRecorderBottomSheet by remember { mutableStateOf(false) }
val audioRecorderSheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
var pickedImages by remember { mutableStateOf<List<Bitmap>>(listOf()) }
var pickedAudioClips by remember { mutableStateOf<List<AudioClip>>(listOf()) }
var hasFrontCamera by remember { mutableStateOf(false) }
val updatePickedImages: (List<Bitmap>) -> Unit = { bitmaps ->
var newPickedImages: MutableList<Bitmap> = mutableListOf()
newPickedImages.addAll(pickedImages)
@ -156,7 +171,16 @@ fun MessageInputText(
}
pickedImages = newPickedImages.toList()
}
var hasFrontCamera by remember { mutableStateOf(false) }
val updatePickedAudioClips: (List<AudioClip>) -> Unit = { audioDataList ->
var newAudioDataList: MutableList<AudioClip> = mutableListOf()
newAudioDataList.addAll(pickedAudioClips)
newAudioDataList.addAll(audioDataList)
if (newAudioDataList.size > MAX_AUDIO_CLIP_COUNT) {
newAudioDataList = newAudioDataList.subList(fromIndex = 0, toIndex = MAX_AUDIO_CLIP_COUNT)
}
pickedAudioClips = newAudioDataList.toList()
}
LaunchedEffect(Unit) { checkFrontCamera(context = context, callback = { hasFrontCamera = it }) }
@ -170,6 +194,16 @@ fun MessageInputText(
}
}
// Permission request when recording audio clips.
val recordAudioClipsPermissionLauncher =
rememberLauncherForActivityResult(ActivityResultContracts.RequestPermission()) {
permissionGranted ->
if (permissionGranted) {
showAddContentMenu = false
showAudioRecorderBottomSheet = true
}
}
// Registers a photo picker activity launcher in single-select mode.
val pickMedia =
rememberLauncherForActivityResult(ActivityResultContracts.PickMultipleVisualMedia()) { uris ->
@ -184,9 +218,31 @@ fun MessageInputText(
}
}
val pickWav =
rememberLauncherForActivityResult(
contract = ActivityResultContracts.StartActivityForResult()
) { result ->
if (result.resultCode == android.app.Activity.RESULT_OK) {
result.data?.data?.let { uri ->
Log.d(TAG, "Picked wav file: $uri")
scope.launch(Dispatchers.IO) {
convertWavToMonoWithMaxSeconds(context = context, stereoUri = uri)?.let { audioClip ->
updatePickedAudioClips(
listOf(
AudioClip(audioData = audioClip.audioData, sampleRate = audioClip.sampleRate)
)
)
}
}
}
} else {
Log.d(TAG, "Wav picking cancelled.")
}
}
Column {
// A preview panel for the selected image.
if (pickedImages.isNotEmpty()) {
// A preview panel for the selected images and audio clips.
if (pickedImages.isNotEmpty() || pickedAudioClips.isNotEmpty()) {
Row(
modifier =
Modifier.offset(x = 16.dp).fillMaxWidth().horizontalScroll(rememberScrollState()),
@ -203,20 +259,30 @@ fun MessageInputText(
.clip(RoundedCornerShape(8.dp))
.border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp)),
)
MediaPanelCloseButton { pickedImages = pickedImages.filter { image != it } }
}
}
for ((index, audioClip) in pickedAudioClips.withIndex()) {
Box(contentAlignment = Alignment.TopEnd) {
Box(
modifier =
Modifier.offset(x = 10.dp, y = (-10).dp)
.clip(CircleShape)
Modifier.shadow(2.dp, shape = RoundedCornerShape(8.dp))
.clip(RoundedCornerShape(8.dp))
.background(MaterialTheme.colorScheme.surface)
.border((1.5).dp, MaterialTheme.colorScheme.outline, CircleShape)
.clickable { pickedImages = pickedImages.filter { image != it } }
.border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp))
) {
Icon(
Icons.Rounded.Close,
contentDescription = "",
modifier = Modifier.padding(3.dp).size(16.dp),
AudioPlaybackPanel(
audioData = audioClip.audioData,
sampleRate = audioClip.sampleRate,
isRecording = false,
modifier = Modifier.padding(end = 16.dp),
)
}
MediaPanelCloseButton {
pickedAudioClips =
pickedAudioClips.filterIndexed { curIndex, curAudioData -> curIndex != index }
}
}
}
}
@ -239,10 +305,13 @@ fun MessageInputText(
verticalAlignment = Alignment.CenterVertically,
) {
val enableAddImageMenuItems = (imageMessageCount + pickedImages.size) < MAX_IMAGE_COUNT
val enableRecordAudioClipMenuItems =
(audioClipMessageCount + pickedAudioClips.size) < MAX_AUDIO_CLIP_COUNT
DropdownMenu(
expanded = showAddContentMenu,
onDismissRequest = { showAddContentMenu = false },
) {
// Image related menu items.
if (showImagePickerInMenu) {
// Take a picture.
DropdownMenuItem(
@ -295,6 +364,70 @@ fun MessageInputText(
)
}
// Audio related menu items.
if (showAudioItemsInMenu) {
DropdownMenuItem(
text = {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp),
) {
Icon(Icons.Rounded.Mic, contentDescription = "")
Text("Record audio clip")
}
},
enabled = enableRecordAudioClipMenuItems,
onClick = {
// Check permission
when (PackageManager.PERMISSION_GRANTED) {
// Already got permission. Call the lambda.
ContextCompat.checkSelfPermission(context, Manifest.permission.RECORD_AUDIO) -> {
showAddContentMenu = false
showAudioRecorderBottomSheet = true
}
// Otherwise, ask for permission
else -> {
recordAudioClipsPermissionLauncher.launch(Manifest.permission.RECORD_AUDIO)
}
}
},
)
DropdownMenuItem(
text = {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(6.dp),
) {
Icon(Icons.Rounded.AudioFile, contentDescription = "")
Text("Pick wav file")
}
},
enabled = enableRecordAudioClipMenuItems,
onClick = {
showAddContentMenu = false
// Show file picker.
val intent =
Intent(Intent.ACTION_GET_CONTENT).apply {
addCategory(Intent.CATEGORY_OPENABLE)
type = "audio/*"
// Provide a list of more specific MIME types to filter for.
val mimeTypes = arrayOf("audio/wav", "audio/x-wav")
putExtra(Intent.EXTRA_MIME_TYPES, mimeTypes)
// Single select.
putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
.addFlags(Intent.FLAG_GRANT_PERSISTABLE_URI_PERMISSION)
.addFlags(Intent.FLAG_GRANT_READ_URI_PERMISSION)
}
pickWav.launch(intent)
},
)
}
// Prompt templates.
if (showPromptTemplatesInMenu) {
DropdownMenuItem(
@ -369,15 +502,22 @@ fun MessageInputText(
)
}
}
} // Send button. Only shown when text is not empty.
else if (curMessage.isNotEmpty()) {
}
// Send button. Only shown when text is not empty, or there is at least one recorded
// audio clip.
else if (curMessage.isNotEmpty() || pickedAudioClips.isNotEmpty()) {
IconButton(
enabled = !inProgress && !isResettingSession,
onClick = {
onSendMessage(
createMessagesToSend(pickedImages = pickedImages, text = curMessage.trim())
createMessagesToSend(
pickedImages = pickedImages,
audioClips = pickedAudioClips,
text = curMessage.trim(),
)
)
pickedImages = listOf()
pickedAudioClips = listOf()
},
colors =
IconButtonDefaults.iconButtonColors(
@ -403,8 +543,15 @@ fun MessageInputText(
history = modelManagerUiState.textInputHistory,
onDismissed = { showTextInputHistorySheet = false },
onHistoryItemClicked = { item ->
onSendMessage(createMessagesToSend(pickedImages = pickedImages, text = item))
onSendMessage(
createMessagesToSend(
pickedImages = pickedImages,
audioClips = pickedAudioClips,
text = item,
)
)
pickedImages = listOf()
pickedAudioClips = listOf()
modelManagerViewModel.promoteTextInputHistoryItem(item)
},
onHistoryItemDeleted = { item -> modelManagerViewModel.deleteTextInputHistory(item) },
@ -582,6 +729,43 @@ fun MessageInputText(
}
}
}
if (showAudioRecorderBottomSheet) {
ModalBottomSheet(
sheetState = audioRecorderSheetState,
onDismissRequest = { showAudioRecorderBottomSheet = false },
) {
AudioRecorderPanel(
onSendAudioClip = { audioData ->
scope.launch {
updatePickedAudioClips(
listOf(AudioClip(audioData = audioData, sampleRate = SAMPLE_RATE))
)
audioRecorderSheetState.hide()
showAudioRecorderBottomSheet = false
}
}
)
}
}
}
@Composable
private fun MediaPanelCloseButton(onClicked: () -> Unit) {
Box(
modifier =
Modifier.offset(x = 10.dp, y = (-10).dp)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.surface)
.border((1.5).dp, MaterialTheme.colorScheme.outline, CircleShape)
.clickable { onClicked() }
) {
Icon(
Icons.Rounded.Close,
contentDescription = "",
modifier = Modifier.padding(3.dp).size(16.dp),
)
}
}
private fun handleImagesSelected(
@ -641,20 +825,50 @@ private fun checkFrontCamera(context: Context, callback: (Boolean) -> Unit) {
)
}
private fun createMessagesToSend(pickedImages: List<Bitmap>, text: String): List<ChatMessage> {
private fun createMessagesToSend(
pickedImages: List<Bitmap>,
audioClips: List<AudioClip>,
text: String,
): List<ChatMessage> {
var messages: MutableList<ChatMessage> = mutableListOf()
// Add image messages.
var imageMessages: MutableList<ChatMessageImage> = mutableListOf()
if (pickedImages.isNotEmpty()) {
for (image in pickedImages) {
messages.add(
imageMessages.add(
ChatMessageImage(bitmap = image, imageBitMap = image.asImageBitmap(), side = ChatSide.USER)
)
}
}
// Cap the number of image messages.
if (messages.size > MAX_IMAGE_COUNT) {
messages = messages.subList(fromIndex = 0, toIndex = MAX_IMAGE_COUNT)
if (imageMessages.size > MAX_IMAGE_COUNT) {
imageMessages = imageMessages.subList(fromIndex = 0, toIndex = MAX_IMAGE_COUNT)
}
messages.addAll(imageMessages)
// Add audio messages.
var audioMessages: MutableList<ChatMessageAudioClip> = mutableListOf()
if (audioClips.isNotEmpty()) {
for (audioClip in audioClips) {
audioMessages.add(
ChatMessageAudioClip(
audioData = audioClip.audioData,
sampleRate = audioClip.sampleRate,
side = ChatSide.USER,
)
)
}
}
// Cap the number of audio messages.
if (audioMessages.size > MAX_AUDIO_CLIP_COUNT) {
audioMessages = audioMessages.subList(fromIndex = 0, toIndex = MAX_AUDIO_CLIP_COUNT)
}
messages.addAll(audioMessages)
if (text.isNotEmpty()) {
messages.add(ChatMessageText(content = text, side = ChatSide.USER))
}
messages.add(ChatMessageText(content = text, side = ChatSide.USER))
return messages
}

View file

@ -121,6 +121,7 @@ private val IMPORT_CONFIGS_LLM: List<Config> =
valueType = ValueType.FLOAT,
),
BooleanSwitchConfig(key = ConfigKey.SUPPORT_IMAGE, defaultValue = false),
BooleanSwitchConfig(key = ConfigKey.SUPPORT_AUDIO, defaultValue = false),
SegmentedButtonConfig(
key = ConfigKey.COMPATIBLE_ACCELERATORS,
defaultValue = Accelerator.CPU.label,
@ -230,6 +231,12 @@ fun ModelImportDialog(uri: Uri, onDismiss: () -> Unit, onDone: (ImportedModel) -
valueType = ValueType.BOOLEAN,
)
as Boolean
val supportAudio =
convertValueToTargetType(
value = values.get(ConfigKey.SUPPORT_AUDIO.label)!!,
valueType = ValueType.BOOLEAN,
)
as Boolean
val importedModel: ImportedModel =
ImportedModel.newBuilder()
.setFileName(fileName)
@ -242,6 +249,7 @@ fun ModelImportDialog(uri: Uri, onDismiss: () -> Unit, onDone: (ImportedModel) -
.setDefaultTopp(defaultTopp)
.setDefaultTemperature(defaultTemperature)
.setSupportImage(supportImage)
.setSupportAudio(supportAudio)
.build()
)
.build()

View file

@ -173,7 +173,7 @@ fun SettingsDialog(
color = MaterialTheme.colorScheme.onSurfaceVariant,
)
Text(
"Expired at: ${dateFormatter.format(Instant.ofEpochMilli(curHfToken.expiresAtMs))}",
"Expires at: ${dateFormatter.format(Instant.ofEpochMilli(curHfToken.expiresAtMs))}",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant,
)

View file

@ -1,78 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.icon
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.SolidColor
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.graphics.vector.path
import androidx.compose.ui.unit.dp
val Forum: ImageVector
get() {
if (_Forum != null) return _Forum!!
_Forum =
ImageVector.Builder(
name = "Forum",
defaultWidth = 24.dp,
defaultHeight = 24.dp,
viewportWidth = 960f,
viewportHeight = 960f,
)
.apply {
path(fill = SolidColor(Color(0xFF000000))) {
moveTo(280f, 720f)
quadToRelative(-17f, 0f, -28.5f, -11.5f)
reflectiveQuadTo(240f, 680f)
verticalLineToRelative(-80f)
horizontalLineToRelative(520f)
verticalLineToRelative(-360f)
horizontalLineToRelative(80f)
quadToRelative(17f, 0f, 28.5f, 11.5f)
reflectiveQuadTo(880f, 280f)
verticalLineToRelative(600f)
lineTo(720f, 720f)
close()
moveTo(80f, 680f)
verticalLineToRelative(-560f)
quadToRelative(0f, -17f, 11.5f, -28.5f)
reflectiveQuadTo(120f, 80f)
horizontalLineToRelative(520f)
quadToRelative(17f, 0f, 28.5f, 11.5f)
reflectiveQuadTo(680f, 120f)
verticalLineToRelative(360f)
quadToRelative(0f, 17f, -11.5f, 28.5f)
reflectiveQuadTo(640f, 520f)
horizontalLineTo(240f)
close()
moveToRelative(520f, -240f)
verticalLineToRelative(-280f)
horizontalLineTo(160f)
verticalLineToRelative(280f)
close()
moveToRelative(-440f, 0f)
verticalLineToRelative(-280f)
close()
}
}
.build()
return _Forum!!
}
private var _Forum: ImageVector? = null

View file

@ -1,73 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.icon
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.SolidColor
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.graphics.vector.path
import androidx.compose.ui.unit.dp
val Mms: ImageVector
get() {
if (_Mms != null) return _Mms!!
_Mms =
ImageVector.Builder(
name = "Mms",
defaultWidth = 24.dp,
defaultHeight = 24.dp,
viewportWidth = 960f,
viewportHeight = 960f,
)
.apply {
path(fill = SolidColor(Color(0xFF000000))) {
moveTo(240f, 560f)
horizontalLineToRelative(480f)
lineTo(570f, 360f)
lineTo(450f, 520f)
lineToRelative(-90f, -120f)
close()
moveTo(80f, 880f)
verticalLineToRelative(-720f)
quadToRelative(0f, -33f, 23.5f, -56.5f)
reflectiveQuadTo(160f, 80f)
horizontalLineToRelative(640f)
quadToRelative(33f, 0f, 56.5f, 23.5f)
reflectiveQuadTo(880f, 160f)
verticalLineToRelative(480f)
quadToRelative(0f, 33f, -23.5f, 56.5f)
reflectiveQuadTo(800f, 720f)
horizontalLineTo(240f)
close()
moveToRelative(126f, -240f)
horizontalLineToRelative(594f)
verticalLineToRelative(-480f)
horizontalLineTo(160f)
verticalLineToRelative(525f)
close()
moveToRelative(-46f, 0f)
verticalLineToRelative(-480f)
close()
}
}
.build()
return _Mms!!
}
private var _Mms: ImageVector? = null

View file

@ -1,87 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.icon
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.SolidColor
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.graphics.vector.path
import androidx.compose.ui.unit.dp
val Widgets: ImageVector
get() {
if (_Widgets != null) return _Widgets!!
_Widgets =
ImageVector.Builder(
name = "Widgets",
defaultWidth = 24.dp,
defaultHeight = 24.dp,
viewportWidth = 960f,
viewportHeight = 960f,
)
.apply {
path(fill = SolidColor(Color(0xFF000000))) {
moveTo(666f, 520f)
lineTo(440f, 294f)
lineToRelative(226f, -226f)
lineToRelative(226f, 226f)
close()
moveToRelative(-546f, -80f)
verticalLineToRelative(-320f)
horizontalLineToRelative(320f)
verticalLineToRelative(320f)
close()
moveToRelative(400f, 400f)
verticalLineToRelative(-320f)
horizontalLineToRelative(320f)
verticalLineToRelative(320f)
close()
moveToRelative(-400f, 0f)
verticalLineToRelative(-320f)
horizontalLineToRelative(320f)
verticalLineToRelative(320f)
close()
moveToRelative(80f, -480f)
horizontalLineToRelative(160f)
verticalLineToRelative(-160f)
horizontalLineTo(200f)
close()
moveToRelative(467f, 48f)
lineToRelative(113f, -113f)
lineToRelative(-113f, -113f)
lineToRelative(-113f, 113f)
close()
moveToRelative(-67f, 352f)
horizontalLineToRelative(160f)
verticalLineToRelative(-160f)
horizontalLineTo(600f)
close()
moveToRelative(-400f, 0f)
horizontalLineToRelative(160f)
verticalLineToRelative(-160f)
horizontalLineTo(200f)
close()
moveToRelative(400f, -160f)
}
}
.build()
return _Widgets!!
}
private var _Widgets: ImageVector? = null

View file

@ -62,13 +62,13 @@ object LlmChatModelHelper {
Accelerator.GPU.label -> LlmInference.Backend.GPU
else -> LlmInference.Backend.GPU
}
val options =
val optionsBuilder =
LlmInference.LlmInferenceOptions.builder()
.setModelPath(model.getPath(context = context))
.setMaxTokens(maxTokens)
.setPreferredBackend(preferredBackend)
.setMaxNumImages(if (model.llmSupportImage) MAX_IMAGE_COUNT else 0)
.build()
val options = optionsBuilder.build()
// Create an instance of the LLM Inference task and session.
try {
@ -82,7 +82,9 @@ object LlmChatModelHelper {
.setTopP(topP)
.setTemperature(temperature)
.setGraphOptions(
GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
GraphOptions.builder()
.setEnableVisionModality(model.llmSupportImage)
.build()
)
.build(),
)
@ -115,7 +117,9 @@ object LlmChatModelHelper {
.setTopP(topP)
.setTemperature(temperature)
.setGraphOptions(
GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
GraphOptions.builder()
.setEnableVisionModality(model.llmSupportImage)
.build()
)
.build(),
)
@ -159,6 +163,7 @@ object LlmChatModelHelper {
resultListener: ResultListener,
cleanUpListener: CleanUpListener,
images: List<Bitmap> = listOf(),
audioClips: List<ByteArray> = listOf(),
) {
val instance = model.instance as LlmModelInstance
@ -172,10 +177,16 @@ object LlmChatModelHelper {
// For a model that supports image modality, we need to add the text query chunk before adding
// images.
val session = instance.session
session.addQueryChunk(input)
if (input.trim().isNotEmpty()) {
session.addQueryChunk(input)
}
for (image in images) {
session.addImage(BitmapImageBuilder(image).build())
}
for (audioClip in audioClips) {
// Uncomment when audio is supported.
// session.addAudio(audioClip)
}
val unused = session.generateResponseAsync(resultListener)
}
}
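For reference, a minimal sketch of how the chat layer could call this helper once audio clips are wired through. The entry-point name (runInference) and the listener bodies are assumptions, not taken from this diff; clips are ByteArrays of 16 kHz mono 16-bit WAV data, fed via session.addAudio once the runtime supports it.
// Sketch only: function name, listener bodies, and wavBytes are placeholders.
LlmChatModelHelper.runInference(
  model = model,
  input = "Transcribe this clip",
  resultListener = { partialResult, done -> Log.d(TAG, "partial=$partialResult, done=$done") },
  cleanUpListener = { Log.d(TAG, "inference session cleaned up") },
  images = listOf(),             // no images for an audio-only request
  audioClips = listOf(wavBytes), // one ByteArray per recorded or imported clip
)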

View file

@ -22,6 +22,7 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.ai.edge.gallery.ui.ViewModelProvider
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageAudioClip
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
import com.google.ai.edge.gallery.ui.common.chat.ChatView
@ -36,6 +37,10 @@ object LlmAskImageDestination {
val route = "LlmAskImageRoute"
}
object LlmAskAudioDestination {
val route = "LlmAskAudioRoute"
}
@Composable
fun LlmChatScreen(
modelManagerViewModel: ModelManagerViewModel,
@ -66,6 +71,21 @@ fun LlmAskImageScreen(
)
}
@Composable
fun LlmAskAudioScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
viewModel: LlmAskAudioViewModel = viewModel(factory = ViewModelProvider.Factory),
) {
ChatViewWrapper(
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
navigateUp = navigateUp,
modifier = modifier,
)
}
@Composable
fun ChatViewWrapper(
viewModel: LlmChatViewModel,
@ -86,6 +106,7 @@ fun ChatViewWrapper(
var text = ""
val images: MutableList<Bitmap> = mutableListOf()
val audioMessages: MutableList<ChatMessageAudioClip> = mutableListOf()
var chatMessageText: ChatMessageText? = null
for (message in messages) {
if (message is ChatMessageText) {
@ -93,14 +114,17 @@ fun ChatViewWrapper(
text = message.content
} else if (message is ChatMessageImage) {
images.add(message.bitmap)
} else if (message is ChatMessageAudioClip) {
audioMessages.add(message)
}
}
if (text.isNotEmpty() && chatMessageText != null) {
if ((text.isNotEmpty() && chatMessageText != null) || audioMessages.isNotEmpty()) {
modelManagerViewModel.addTextInputHistory(text)
viewModel.generateResponse(
model = model,
input = text,
images = images,
audioMessages = audioMessages,
onError = {
viewModel.handleError(
context = context,

View file

@ -22,9 +22,11 @@ import android.util.Log
import androidx.lifecycle.viewModelScope
import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_AUDIO
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageAudioClip
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageLoading
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
@ -52,6 +54,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
model: Model,
input: String,
images: List<Bitmap> = listOf(),
audioMessages: List<ChatMessageAudioClip> = listOf(),
onError: () -> Unit,
) {
val accelerator = model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = "")
@ -72,6 +75,11 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
val instance = model.instance as LlmModelInstance
var prefillTokens = instance.session.sizeInTokens(input)
prefillTokens += images.size * 257
for (audioMessage in audioMessages) {
// 150ms = 1 audio token
val duration = audioMessage.getDurationInSeconds()
prefillTokens += (duration * 1000f / 150f).toInt()
}
var firstRun = true
var timeToFirstToken = 0f
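The prefill accounting above reduces to a small budget formula: each attached image costs a fixed 257 tokens, and audio is billed at one token per 150 ms of clip length. A rough helper mirroring that arithmetic (constants taken from the hunk above; the function name is illustrative only):
// Rough prefill-token estimate; 257 tokens/image and 150 ms per audio token come from the code above.
fun estimatePrefillTokens(textTokens: Int, imageCount: Int, audioSecondsPerClip: List<Float>): Int {
  var tokens = textTokens + imageCount * 257
  for (seconds in audioSecondsPerClip) {
    tokens += (seconds * 1000f / 150f).toInt() // ~6.7 tokens per second of audio
  }
  return tokens
}
// Example: a 100-token prompt with one image and a 15 s clip ≈ 100 + 257 + 100 = 457 prefill tokens.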
@ -86,6 +94,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
model = model,
input = input,
images = images,
audioClips = audioMessages.map { it.genByteArrayForWav() },
resultListener = { partialResult, done ->
val curTs = System.currentTimeMillis()
@ -214,7 +223,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
context: Context,
model: Model,
modelManagerViewModel: ModelManagerViewModel,
triggeredMessage: ChatMessageText,
triggeredMessage: ChatMessageText?,
) {
// Clean up.
modelManagerViewModel.cleanupModel(task = task, model = model)
@ -236,14 +245,20 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
)
// Add the triggered message back.
addMessage(model = model, message = triggeredMessage)
if (triggeredMessage != null) {
addMessage(model = model, message = triggeredMessage)
}
// Re-initialize the session/engine.
modelManagerViewModel.initializeModel(context = context, task = task, model = model)
// Re-generate the response automatically.
generateResponse(model = model, input = triggeredMessage.content, onError = {})
if (triggeredMessage != null) {
generateResponse(model = model, input = triggeredMessage.content, onError = {})
}
}
}
class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE)
class LlmAskAudioViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_AUDIO)

View file

@ -36,6 +36,7 @@ import com.google.ai.edge.gallery.data.ModelAllowlist
import com.google.ai.edge.gallery.data.ModelDownloadStatus
import com.google.ai.edge.gallery.data.ModelDownloadStatusType
import com.google.ai.edge.gallery.data.TASKS
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_AUDIO
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB
@ -281,15 +282,12 @@ open class ModelManagerViewModel(
}
}
when (task.type) {
TaskType.LLM_CHAT ->
LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone)
TaskType.LLM_CHAT,
TaskType.LLM_ASK_IMAGE,
TaskType.LLM_ASK_AUDIO,
TaskType.LLM_PROMPT_LAB ->
LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone)
TaskType.LLM_ASK_IMAGE ->
LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone)
TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {}
}
@ -301,9 +299,11 @@ open class ModelManagerViewModel(
model.cleanUpAfterInit = false
Log.d(TAG, "Cleaning up model '${model.name}'...")
when (task.type) {
TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model)
TaskType.LLM_PROMPT_LAB -> LlmChatModelHelper.cleanUp(model = model)
TaskType.LLM_ASK_IMAGE -> LlmChatModelHelper.cleanUp(model = model)
TaskType.LLM_CHAT,
TaskType.LLM_PROMPT_LAB,
TaskType.LLM_ASK_IMAGE,
TaskType.LLM_ASK_AUDIO -> LlmChatModelHelper.cleanUp(model = model)
TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {}
}
@ -410,14 +410,19 @@ open class ModelManagerViewModel(
// Create model.
val model = createModelFromImportedModelInfo(info = info)
for (task in listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)) {
for (task in
listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_ASK_AUDIO, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)) {
// Remove duplicated imported model if existed.
val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported }
if (modelIndex >= 0) {
Log.d(TAG, "duplicated imported model found in task. Removing it first")
task.models.removeAt(modelIndex)
}
if ((task == TASK_LLM_ASK_IMAGE && model.llmSupportImage) || task != TASK_LLM_ASK_IMAGE) {
if (
(task == TASK_LLM_ASK_IMAGE && model.llmSupportImage) ||
(task == TASK_LLM_ASK_AUDIO && model.llmSupportAudio) ||
(task != TASK_LLM_ASK_IMAGE && task != TASK_LLM_ASK_AUDIO)
) {
task.models.add(model)
}
task.updateTrigger.value = System.currentTimeMillis()
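The import condition above boils down to a per-task capability check: chat and Prompt Lab accept any imported LLM, while the ask-image and ask-audio tasks require the matching capability flag. Stated as a standalone predicate (a sketch reusing the task singletons and model flags from this change):
// Sketch: equivalent to the condition above.
fun shouldAddImportedModel(task: Task, model: Model): Boolean =
  when (task) {
    TASK_LLM_ASK_IMAGE -> model.llmSupportImage
    TASK_LLM_ASK_AUDIO -> model.llmSupportAudio
    else -> true
  }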
@ -657,6 +662,7 @@ open class ModelManagerViewModel(
TASK_LLM_CHAT.models.clear()
TASK_LLM_PROMPT_LAB.models.clear()
TASK_LLM_ASK_IMAGE.models.clear()
TASK_LLM_ASK_AUDIO.models.clear()
for (allowedModel in modelAllowlist.models) {
if (allowedModel.disabled == true) {
continue
@ -672,6 +678,9 @@ open class ModelManagerViewModel(
if (allowedModel.taskTypes.contains(TASK_LLM_ASK_IMAGE.type.id)) {
TASK_LLM_ASK_IMAGE.models.add(model)
}
if (allowedModel.taskTypes.contains(TASK_LLM_ASK_AUDIO.type.id)) {
TASK_LLM_ASK_AUDIO.models.add(model)
}
}
// Pre-process all tasks.
@ -760,6 +769,9 @@ open class ModelManagerViewModel(
if (model.llmSupportImage) {
TASK_LLM_ASK_IMAGE.models.add(model)
}
if (model.llmSupportAudio) {
TASK_LLM_ASK_AUDIO.models.add(model)
}
// Update status.
modelDownloadStatus[model.name] =
@ -800,6 +812,7 @@ open class ModelManagerViewModel(
accelerators = accelerators,
)
val llmSupportImage = info.llmConfig.supportImage
val llmSupportAudio = info.llmConfig.supportAudio
val model =
Model(
name = info.fileName,
@ -811,6 +824,7 @@ open class ModelManagerViewModel(
showRunAgainButton = false,
imported = true,
llmSupportImage = llmSupportImage,
llmSupportAudio = llmSupportAudio,
)
model.preProcess()

View file

@ -47,6 +47,7 @@ import androidx.navigation.compose.NavHost
import androidx.navigation.compose.composable
import androidx.navigation.navArgument
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_AUDIO
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB
@ -55,6 +56,8 @@ import com.google.ai.edge.gallery.data.TaskType
import com.google.ai.edge.gallery.data.getModelByName
import com.google.ai.edge.gallery.ui.ViewModelProvider
import com.google.ai.edge.gallery.ui.home.HomeScreen
import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioDestination
import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioScreen
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageDestination
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageScreen
import com.google.ai.edge.gallery.ui.llmchat.LlmChatDestination
@ -209,7 +212,7 @@ fun GalleryNavHost(
}
}
// LLM image to text.
// Ask image.
composable(
route = "${LlmAskImageDestination.route}/{modelName}",
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
@ -225,6 +228,23 @@ fun GalleryNavHost(
)
}
}
// Ask audio.
composable(
route = "${LlmAskAudioDestination.route}/{modelName}",
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() },
exitTransition = { slideExit() },
) {
getModelFromNavigationParam(it, TASK_LLM_ASK_AUDIO)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel)
LlmAskAudioScreen(
modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() },
)
}
}
}
// Handle incoming intents for deep links
@ -256,6 +276,7 @@ fun navigateToTaskScreen(
when (taskType) {
TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
TaskType.LLM_ASK_IMAGE -> navController.navigate("${LlmAskImageDestination.route}/${modelName}")
TaskType.LLM_ASK_AUDIO -> navController.navigate("${LlmAskAudioDestination.route}/${modelName}")
TaskType.LLM_PROMPT_LAB ->
navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
TaskType.TEST_TASK_1 -> {}
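With the new branch in place, the audio task resolves to the "LlmAskAudioRoute/{modelName}" pattern registered above; for example (model name hypothetical):
// Sketch: direct navigation to the ask-audio screen for a given model.
navController.navigate("${LlmAskAudioDestination.route}/gemma-3n-e2b-it")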

View file

@ -120,6 +120,8 @@ data class CustomColors(
val agentBubbleBgColor: Color = Color.Transparent,
val linkColor: Color = Color.Transparent,
val successColor: Color = Color.Transparent,
val recordButtonBgColor: Color = Color.Transparent,
val waveFormBgColor: Color = Color.Transparent,
)
val LocalCustomColors = staticCompositionLocalOf { CustomColors() }
@ -145,6 +147,8 @@ val lightCustomColors =
userBubbleBgColor = Color(0xFF32628D),
linkColor = Color(0xFF32628D),
successColor = Color(0xff3d860b),
recordButtonBgColor = Color(0xFFEE675C),
waveFormBgColor = Color(0xFFaaaaaa),
)
val darkCustomColors =
@ -168,6 +172,8 @@ val darkCustomColors =
userBubbleBgColor = Color(0xFF1f3760),
linkColor = Color(0xFF9DCAFC),
successColor = Color(0xFFA1CE83),
recordButtonBgColor = Color(0xFFEE675C),
waveFormBgColor = Color(0xFFaaaaaa),
)
val MaterialTheme.customColors: CustomColors

View file

@ -55,6 +55,7 @@ message LlmConfig {
float default_topp = 4;
float default_temperature = 5;
bool support_image = 6;
bool support_audio = 7;
}
message Settings {