From d0989adce15d73e6d3221a9d7ed54fdf352c074c Mon Sep 17 00:00:00 2001 From: Google AI Edge Gallery Date: Mon, 23 Jun 2025 10:24:10 -0700 Subject: [PATCH] Add audio support. - Add a new task "audio scribe". - Allow users to record audio clips or pick wav files to interact with model. - Add support for importing models with audio capability. - Fix a typo in Settings dialog (Thanks https://github.com/rhnvrm!) PiperOrigin-RevId: 774832681 --- Android/src/app/src/main/AndroidManifest.xml | 1 + .../google/ai/edge/gallery/common/Types.kt | 2 + .../google/ai/edge/gallery/common/Utils.kt | 137 +++++++ .../com/google/ai/edge/gallery/data/Config.kt | 1 + .../com/google/ai/edge/gallery/data/Consts.kt | 9 + .../com/google/ai/edge/gallery/data/Model.kt | 3 + .../ai/edge/gallery/data/ModelAllowlist.kt | 2 + .../com/google/ai/edge/gallery/data/Tasks.kt | 32 +- .../ai/edge/gallery/ui/ViewModelProvider.kt | 4 + .../ui/common/chat/AudioPlaybackPanel.kt | 344 ++++++++++++++++++ .../ui/common/chat/AudioRecorderPanel.kt | 327 +++++++++++++++++ .../gallery/ui/common/chat/ChatMessage.kt | 88 +++++ .../edge/gallery/ui/common/chat/ChatPanel.kt | 38 +- .../ui/common/chat/MessageBodyAudioClip.kt | 33 ++ .../ui/common/chat/MessageInputText.kt | 254 ++++++++++++- .../edge/gallery/ui/home/ModelImportDialog.kt | 8 + .../ai/edge/gallery/ui/home/SettingsDialog.kt | 2 +- .../google/ai/edge/gallery/ui/icon/Forum.kt | 78 ---- .../com/google/ai/edge/gallery/ui/icon/Mms.kt | 73 ---- .../ai/edge/gallery/ui/icon/Widgets.kt.kt | 87 ----- .../gallery/ui/llmchat/LlmChatModelHelper.kt | 21 +- .../edge/gallery/ui/llmchat/LlmChatScreen.kt | 26 +- .../gallery/ui/llmchat/LlmChatViewModel.kt | 21 +- .../ui/modelmanager/ModelManagerViewModel.kt | 36 +- .../gallery/ui/navigation/GalleryNavGraph.kt | 23 +- .../google/ai/edge/gallery/ui/theme/Theme.kt | 6 + Android/src/app/src/main/proto/settings.proto | 1 + 27 files changed, 1369 insertions(+), 288 deletions(-) create mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioPlaybackPanel.kt create mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioRecorderPanel.kt create mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageBodyAudioClip.kt delete mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Forum.kt delete mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Mms.kt delete mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Widgets.kt.kt diff --git a/Android/src/app/src/main/AndroidManifest.xml b/Android/src/app/src/main/AndroidManifest.xml index 2f04e88..6bacbd3 100644 --- a/Android/src/app/src/main/AndroidManifest.xml +++ b/Android/src/app/src/main/AndroidManifest.xml @@ -29,6 +29,7 @@ + (val jsonObj: T, val textContent: String) + +class AudioClip(val audioData: ByteArray, val sampleRate: Int) diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt index 6d37f9d..1ada15d 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt @@ -17,12 +17,17 @@ package com.google.ai.edge.gallery.common import android.content.Context +import android.net.Uri import android.util.Log +import com.google.ai.edge.gallery.data.SAMPLE_RATE import com.google.gson.Gson import com.google.gson.reflect.TypeToken import 
java.io.File import java.net.HttpURLConnection import java.net.URL +import java.nio.ByteBuffer +import java.nio.ByteOrder +import kotlin.math.floor data class LaunchInfo(val ts: Long) @@ -112,3 +117,135 @@ inline fun <reified T> getJsonResponse(url: String): JsonObjAndTextContent<T>? { return null } +
+fun convertWavToMonoWithMaxSeconds(
+  context: Context,
+  stereoUri: Uri,
+  maxSeconds: Int = 30,
+): AudioClip? {
+  Log.d(TAG, "Start to convert wav file to mono channel")
+
+  try {
+    val inputStream = context.contentResolver.openInputStream(stereoUri) ?: return null
+    val originalBytes = inputStream.readBytes()
+    inputStream.close()
+
+    // Read WAV header
+    if (originalBytes.size < 44) {
+      // Not a valid WAV file
+      Log.e(TAG, "Not a valid wav file")
+      return null
+    }
+
+    val headerBuffer = ByteBuffer.wrap(originalBytes, 0, 44).order(ByteOrder.LITTLE_ENDIAN)
+    val channels = headerBuffer.getShort(22)
+    var sampleRate = headerBuffer.getInt(24)
+    val bitDepth = headerBuffer.getShort(34)
+    Log.d(TAG, "File metadata: channels: $channels, sampleRate: $sampleRate, bitDepth: $bitDepth")
+
+    // Normalize audio to 16-bit.
+    val audioDataBytes = originalBytes.copyOfRange(fromIndex = 44, toIndex = originalBytes.size)
+    val sixteenBitBytes: ByteArray =
+      if (bitDepth.toInt() == 8) {
+        Log.d(TAG, "Converting 8-bit audio to 16-bit.")
+        convert8BitTo16Bit(audioDataBytes)
+      } else {
+        // Assume 16-bit or other format that can be handled directly
+        audioDataBytes
+      }
+
+    // Convert byte array to short array for processing
+    val shortBuffer =
+      ByteBuffer.wrap(sixteenBitBytes).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer()
+    var pcmSamples = ShortArray(shortBuffer.remaining())
+    shortBuffer.get(pcmSamples)
+
+    // Resample if sample rate is less than 16000 Hz
+    if (sampleRate < SAMPLE_RATE) {
+      Log.d(TAG, "Resampling from $sampleRate Hz to $SAMPLE_RATE Hz.")
+      pcmSamples = resample(pcmSamples, sampleRate, SAMPLE_RATE, channels.toInt())
+      sampleRate = SAMPLE_RATE
+      Log.d(TAG, "Resampling complete. New sample count: ${pcmSamples.size}")
+    }
+
+    // Convert stereo to mono if necessary
+    var monoSamples =
+      if (channels.toInt() == 2) {
+        Log.d(TAG, "Converting stereo to mono.")
+        val mono = ShortArray(pcmSamples.size / 2)
+        for (i in mono.indices) {
+          val left = pcmSamples[i * 2]
+          val right = pcmSamples[i * 2 + 1]
+          mono[i] = ((left + right) / 2).toShort()
+        }
+        mono
+      } else {
+        Log.d(TAG, "Audio is already mono. No channel conversion needed.")
+        pcmSamples
+      }
+
+    // Trim the audio to maxSeconds
+    val maxSamples = maxSeconds * sampleRate
+    if (monoSamples.size > maxSamples) {
+      Log.d(TAG, "Trimming clip from ${monoSamples.size} samples to $maxSamples samples.")
+      monoSamples = monoSamples.copyOfRange(0, maxSamples)
+    }
+
+    val monoByteBuffer = ByteBuffer.allocate(monoSamples.size * 2).order(ByteOrder.LITTLE_ENDIAN)
+    monoByteBuffer.asShortBuffer().put(monoSamples)
+    return AudioClip(audioData = monoByteBuffer.array(), sampleRate = sampleRate)
+  } catch (e: Exception) {
+    Log.e(TAG, "Failed to convert wav to mono", e)
+    return null
+  }
+}
+
+/** Converts 8-bit unsigned PCM audio data to 16-bit signed PCM. */
+private fun convert8BitTo16Bit(eightBitData: ByteArray): ByteArray {
+  // The new 16-bit data will be twice the size
+  val sixteenBitData = ByteArray(eightBitData.size * 2)
+  val buffer = ByteBuffer.wrap(sixteenBitData).order(ByteOrder.LITTLE_ENDIAN)
+
+  for (byte in eightBitData) {
+    // Convert the unsigned 8-bit byte (0-255) to a signed 16-bit short (-32768 to 32767)
+    // 1. 
Get the unsigned value by masking with 0xFF
+    // 2. Subtract 128 to center the waveform around 0 (range becomes -128 to 127)
+    // 3. Scale by 256 to expand to the 16-bit range
+    val unsignedByte = byte.toInt() and 0xFF
+    val sixteenBitSample = ((unsignedByte - 128) * 256).toShort()
+    buffer.putShort(sixteenBitSample)
+  }
+  return sixteenBitData
+}
+
+/** Resamples PCM audio data from an original sample rate to a target sample rate. */
+private fun resample(
+  inputSamples: ShortArray,
+  originalSampleRate: Int,
+  targetSampleRate: Int,
+  channels: Int,
+): ShortArray {
+  if (originalSampleRate == targetSampleRate) {
+    return inputSamples
+  }
+
+  val ratio = targetSampleRate.toDouble() / originalSampleRate
+  val outputLength = (inputSamples.size * ratio).toInt()
+  val resampledData = ShortArray(outputLength)
+
+  if (channels == 1) { // Mono
+    for (i in resampledData.indices) {
+      val position = i / ratio
+      val index1 = floor(position).toInt()
+      val index2 = index1 + 1
+      val fraction = position - index1
+
+      val sample1 = if (index1 < inputSamples.size) inputSamples[index1].toDouble() else 0.0
+      val sample2 = if (index2 < inputSamples.size) inputSamples[index2].toDouble() else 0.0
+
+      resampledData[i] = (sample1 * (1 - fraction) + sample2 * fraction).toInt().toShort()
+    }
+  } else if (channels == 2) { // Stereo: interleaved frames, resample each channel separately.
+    // Without this branch, stereo input reaching resample() would come back as silence,
+    // since only the mono case filled the output array.
+    val inputFrames = inputSamples.size / 2
+    val outputFrames = outputLength / 2
+    for (frame in 0 until outputFrames) {
+      val position = frame / ratio
+      val index1 = floor(position).toInt()
+      val index2 = index1 + 1
+      val fraction = position - index1
+
+      for (channel in 0 until 2) {
+        val sample1 =
+          if (index1 < inputFrames) inputSamples[index1 * 2 + channel].toDouble() else 0.0
+        val sample2 =
+          if (index2 < inputFrames) inputSamples[index2 * 2 + channel].toDouble() else 0.0
+        resampledData[frame * 2 + channel] =
+          (sample1 * (1 - fraction) + sample2 * fraction).toInt().toShort()
+      }
+    }
+  }
+
+  return resampledData
+} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt index d961ecc..2b9fdf7 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt @@ -50,6 +50,7 @@ enum class ConfigKey(val label: String) { DEFAULT_TOPP("Default TopP"), DEFAULT_TEMPERATURE("Default temperature"), SUPPORT_IMAGE("Support image"), + SUPPORT_AUDIO("Support audio"), MAX_RESULT_COUNT("Max result count"), USE_GPU("Use GPU"), ACCELERATOR("Choose accelerator"), diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Consts.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Consts.kt index a2209af..85fd71e 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Consts.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Consts.kt @@ -44,3 +44,12 @@ val DEFAULT_ACCELERATORS = listOf(Accelerator.GPU) // Max number of images allowed in a "ask image" session. const val MAX_IMAGE_COUNT = 10 +
+// Max number of audio clips in an "ask audio" session.
+const val MAX_AUDIO_CLIP_COUNT = 10
+
+// Max audio clip duration in seconds.
+const val MAX_AUDIO_CLIP_DURATION_SEC = 30
+
+// Audio-recording related consts.
+const val SAMPLE_RATE = 16000 diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt index 580bf2b..02e1368 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt @@ -87,6 +87,9 @@ data class Model( /** Whether the LLM model supports image input. */ val llmSupportImage: Boolean = false, + /** Whether the LLM model supports audio input. */ + val llmSupportAudio: Boolean = false, + /** Whether the model is imported or not. 
*/ val imported: Boolean = false, diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt index f336638..5cf11ca 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt @@ -38,6 +38,7 @@ data class AllowedModel( val taskTypes: List, val disabled: Boolean? = null, val llmSupportImage: Boolean? = null, + val llmSupportAudio: Boolean? = null, val estimatedPeakMemoryInBytes: Long? = null, ) { fun toModel(): Model { @@ -96,6 +97,7 @@ data class AllowedModel( showRunAgainButton = showRunAgainButton, learnMoreUrl = "https://huggingface.co/${modelId}", llmSupportImage = llmSupportImage == true, + llmSupportAudio = llmSupportAudio == true, ) } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Tasks.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Tasks.kt index e95feab..c52d2c3 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Tasks.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Tasks.kt @@ -17,19 +17,22 @@ package com.google.ai.edge.gallery.data import androidx.annotation.StringRes +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.outlined.Forum +import androidx.compose.material.icons.outlined.Mic +import androidx.compose.material.icons.outlined.Mms +import androidx.compose.material.icons.outlined.Widgets import androidx.compose.runtime.MutableState import androidx.compose.runtime.mutableLongStateOf import androidx.compose.ui.graphics.vector.ImageVector import com.google.ai.edge.gallery.R -import com.google.ai.edge.gallery.ui.icon.Forum -import com.google.ai.edge.gallery.ui.icon.Mms -import com.google.ai.edge.gallery.ui.icon.Widgets /** Type of task. 
*/ enum class TaskType(val label: String, val id: String) { LLM_CHAT(label = "AI Chat", id = "llm_chat"), LLM_PROMPT_LAB(label = "Prompt Lab", id = "llm_prompt_lab"), LLM_ASK_IMAGE(label = "Ask Image", id = "llm_ask_image"), + LLM_ASK_AUDIO(label = "Audio Scribe", id = "llm_ask_audio"), TEST_TASK_1(label = "Test task 1", id = "test_task_1"), TEST_TASK_2(label = "Test task 2", id = "test_task_2"), } @@ -71,7 +74,7 @@ data class Task( val TASK_LLM_CHAT = Task( type = TaskType.LLM_CHAT, - icon = Forum, + icon = Icons.Outlined.Forum, models = mutableListOf(), description = "Chat with on-device large language models", docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", @@ -83,7 +86,7 @@ val TASK_LLM_PROMPT_LAB = Task( type = TaskType.LLM_PROMPT_LAB, - icon = Widgets, + icon = Icons.Outlined.Widgets, models = mutableListOf(), description = "Single turn use cases with on-device large language model", docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", @@ -95,7 +98,7 @@ val TASK_LLM_ASK_IMAGE = Task( type = TaskType.LLM_ASK_IMAGE, - icon = Mms, + icon = Icons.Outlined.Mms, models = mutableListOf(), description = "Ask questions about images with on-device large language models", docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", @@ -104,8 +107,23 @@ val TASK_LLM_ASK_IMAGE = textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat, ) +val TASK_LLM_ASK_AUDIO =
+  Task(
+    type = TaskType.LLM_ASK_AUDIO,
+    icon = Icons.Outlined.Mic,
+    models = mutableListOf(),
+    description =
+      "Instantly transcribe and/or translate audio clips using on-device large language models",
+    docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
+    sourceCodeUrl =
+      "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt",
+    textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
+  )
+ /** All tasks. */ -val TASKS: List<Task> = listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT) +val TASKS: List<Task> =
+  listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_ASK_AUDIO, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT) fun getModelByName(name: String): Model? { for (task in TASKS) { diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt index fdded53..6ceb148 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt @@ -21,6 +21,7 @@ import androidx.lifecycle.viewmodel.CreationExtras import androidx.lifecycle.viewmodel.initializer import androidx.lifecycle.viewmodel.viewModelFactory import com.google.ai.edge.gallery.GalleryApplication +import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioViewModel import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageViewModel import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel @@ -49,6 +50,9 @@ object ViewModelProvider { // Initializer for LlmAskImageViewModel. initializer { LlmAskImageViewModel() } + + // Initializer for LlmAskAudioViewModel. 
+ initializer { LlmAskAudioViewModel() } } } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioPlaybackPanel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioPlaybackPanel.kt new file mode 100644 index 0000000..9faebf0 --- /dev/null +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioPlaybackPanel.kt @@ -0,0 +1,344 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.edge.gallery.ui.common.chat + +import android.media.AudioAttributes +import android.media.AudioFormat +import android.media.AudioTrack +import android.util.Log +import androidx.compose.foundation.Canvas +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.height +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.width +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.rounded.PlayArrow +import androidx.compose.material.icons.rounded.Stop +import androidx.compose.material3.Icon +import androidx.compose.material3.IconButton +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Text +import androidx.compose.runtime.Composable +import androidx.compose.runtime.DisposableEffect +import androidx.compose.runtime.LaunchedEffect +import androidx.compose.runtime.MutableState +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableFloatStateOf +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember +import androidx.compose.runtime.rememberCoroutineScope +import androidx.compose.runtime.setValue +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.geometry.CornerRadius +import androidx.compose.ui.geometry.Offset +import androidx.compose.ui.geometry.Size +import androidx.compose.ui.geometry.toRect +import androidx.compose.ui.graphics.BlendMode +import androidx.compose.ui.graphics.Color +import androidx.compose.ui.graphics.drawscope.drawIntoCanvas +import androidx.compose.ui.unit.dp +import com.google.ai.edge.gallery.data.MAX_AUDIO_CLIP_DURATION_SEC +import com.google.ai.edge.gallery.ui.theme.customColors +import java.nio.ByteBuffer +import java.nio.ByteOrder +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.isActive +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext + +private const val TAG = "AGAudioPlaybackPanel" +private const val BAR_SPACE = 2 +private const val BAR_WIDTH = 2 +private const val MIN_BAR_COUNT = 16 +private const val MAX_BAR_COUNT = 48 + +/** + * A Composable that displays an audio playback panel, including play/stop controls, a waveform + * visualization, and the duration of the audio clip. 
+ */ +@Composable +fun AudioPlaybackPanel( + audioData: ByteArray, + sampleRate: Int, + isRecording: Boolean, + modifier: Modifier = Modifier, + onDarkBg: Boolean = false, +) { + val coroutineScope = rememberCoroutineScope() + var isPlaying by remember { mutableStateOf(false) } + val audioTrackState = remember { mutableStateOf(null) } + val durationInSeconds = + remember(audioData) { + // PCM 16-bit + val bytesPerSample = 2 + val bytesPerFrame = bytesPerSample * 1 // mono + val totalFrames = audioData.size.toDouble() / bytesPerFrame + totalFrames / sampleRate + } + val barCount = + remember(durationInSeconds) { + val f = durationInSeconds / MAX_AUDIO_CLIP_DURATION_SEC + ((MAX_BAR_COUNT - MIN_BAR_COUNT) * f + MIN_BAR_COUNT).toInt() + } + val amplitudeLevels = + remember(audioData) { generateAmplitudeLevels(audioData = audioData, barCount = barCount) } + var playbackProgress by remember { mutableFloatStateOf(0f) } + + // Reset when a new recording is started. + LaunchedEffect(isRecording) { + if (isRecording) { + val audioTrack = audioTrackState.value + audioTrack?.stop() + isPlaying = false + playbackProgress = 0f + } + } + + // Cleanup on Composable Disposal. + DisposableEffect(Unit) { + onDispose { + val audioTrack = audioTrackState.value + audioTrack?.stop() + audioTrack?.release() + } + } + + Row(verticalAlignment = Alignment.CenterVertically, modifier = modifier) { + // Button to play/stop the clip. + IconButton( + onClick = { + coroutineScope.launch { + if (!isPlaying) { + isPlaying = true + playAudio( + audioTrackState = audioTrackState, + audioData = audioData, + sampleRate = sampleRate, + onProgress = { playbackProgress = it }, + onCompletion = { + playbackProgress = 0f + isPlaying = false + }, + ) + } else { + stopPlayAudio(audioTrackState = audioTrackState) + playbackProgress = 0f + isPlaying = false + } + } + } + ) { + Icon( + if (isPlaying) Icons.Rounded.Stop else Icons.Rounded.PlayArrow, + contentDescription = "", + tint = if (onDarkBg) Color.White else MaterialTheme.colorScheme.primary, + ) + } + + // Visualization + AmplitudeBarGraph( + amplitudeLevels = amplitudeLevels, + progress = playbackProgress, + modifier = + Modifier.width((barCount * BAR_WIDTH + (barCount - 1) * BAR_SPACE).dp).height(24.dp), + onDarkBg = onDarkBg, + ) + + // Duration + Text( + "${"%.1f".format(durationInSeconds)}s", + style = MaterialTheme.typography.labelLarge, + color = if (onDarkBg) Color.White else MaterialTheme.colorScheme.primary, + modifier = Modifier.padding(start = 12.dp), + ) + } +} + +@Composable +private fun AmplitudeBarGraph( + amplitudeLevels: List, + progress: Float, + modifier: Modifier = Modifier, + onDarkBg: Boolean = false, +) { + val barColor = MaterialTheme.customColors.waveFormBgColor + val progressColor = if (onDarkBg) Color.White else MaterialTheme.colorScheme.primary + + Canvas(modifier = modifier) { + val barCount = amplitudeLevels.size + val barWidth = (size.width - BAR_SPACE.dp.toPx() * (barCount - 1)) / barCount + val cornerRadius = CornerRadius(x = barWidth, y = barWidth) + + // Use drawIntoCanvas for advanced blend mode operations + drawIntoCanvas { canvas -> + + // 1. Save the current state of the canvas onto a temporary, offscreen layer + canvas.saveLayer(size.toRect(), androidx.compose.ui.graphics.Paint()) + + // 2. Draw the bars in grey. 
+ amplitudeLevels.forEachIndexed { index, level -> + val barHeight = (level * size.height).coerceAtLeast(1.5f) + val left = index * (barWidth + BAR_SPACE.dp.toPx()) + drawRoundRect( + color = barColor, + topLeft = Offset(x = left, y = size.height / 2 - barHeight / 2), + size = Size(barWidth, barHeight), + cornerRadius = cornerRadius, + ) + } + + // 3. Draw the progress rectangle using BlendMode.SrcIn to only draw where the bars already + // exists. + val progressWidth = size.width * progress + drawRect( + color = progressColor, + topLeft = Offset.Zero, + size = Size(progressWidth, size.height), + blendMode = BlendMode.SrcIn, + ) + + // 4. Restore the layer, merging it onto the main canvas + canvas.restore() + } + } +} + +private suspend fun playAudio( + audioTrackState: MutableState, + audioData: ByteArray, + sampleRate: Int, + onProgress: (Float) -> Unit, + onCompletion: () -> Unit, +) { + Log.d(TAG, "Start playing audio...") + + try { + withContext(Dispatchers.IO) { + var lastProgressUpdateMs = 0L + audioTrackState.value?.release() + val audioTrack = + AudioTrack.Builder() + .setAudioAttributes( + AudioAttributes.Builder() + .setContentType(AudioAttributes.CONTENT_TYPE_SPEECH) + .setUsage(AudioAttributes.USAGE_MEDIA) + .build() + ) + .setAudioFormat( + AudioFormat.Builder() + .setEncoding(AudioFormat.ENCODING_PCM_16BIT) + .setSampleRate(sampleRate) + .setChannelMask(AudioFormat.CHANNEL_OUT_MONO) + .build() + ) + .setTransferMode(AudioTrack.MODE_STATIC) + .setBufferSizeInBytes(audioData.size) + .build() + + val bytesPerFrame = 2 // For PCM 16-bit Mono + val totalFrames = audioData.size / bytesPerFrame + + audioTrackState.value = audioTrack + audioTrack.write(audioData, 0, audioData.size) + audioTrack.play() + + // Coroutine to monitor progress + while (isActive && audioTrack.playState == AudioTrack.PLAYSTATE_PLAYING) { + val currentFrames = audioTrack.playbackHeadPosition + if (currentFrames >= totalFrames) { + break // Exit loop when playback is done + } + val progress = currentFrames.toFloat() / totalFrames + val curMs = System.currentTimeMillis() + if (curMs - lastProgressUpdateMs > 30) { + onProgress(progress) + lastProgressUpdateMs = curMs + } + } + + if (isActive) { + audioTrackState.value?.stop() + } + } + } catch (e: Exception) { + // Ignore + } finally { + onProgress(1f) + onCompletion() + } +} + +private fun stopPlayAudio(audioTrackState: MutableState) { + Log.d(TAG, "Stopping playing audio...") + + val audioTrack = audioTrackState.value + audioTrack?.stop() + audioTrack?.release() + audioTrackState.value = null +} + +/** + * Processes a raw PCM 16-bit audio byte array to generate a list of normalized amplitude levels for + * visualization. + */ +private fun generateAmplitudeLevels(audioData: ByteArray, barCount: Int): List { + if (audioData.isEmpty()) { + return List(barCount) { 0f } + } + + // 1. Parse bytes into 16-bit short samples (PCM 16-bit) + val shortBuffer = ByteBuffer.wrap(audioData).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer() + val samples = ShortArray(shortBuffer.remaining()) + shortBuffer.get(samples) + + if (samples.isEmpty()) { + return List(barCount) { 0f } + } + + // 2. Determine the size of each chunk + val chunkSize = samples.size / barCount + val amplitudeLevels = mutableListOf() + + // 3. 
Get the max value for each chunk + for (i in 0 until barCount) { + val chunkStart = i * chunkSize + val chunkEnd = (chunkStart + chunkSize).coerceAtMost(samples.size) + + var maxAmplitudeInChunk = 0.0 + + for (j in chunkStart until chunkEnd) { + val sampleAbs = kotlin.math.abs(samples[j].toDouble()) + if (sampleAbs > maxAmplitudeInChunk) { + maxAmplitudeInChunk = sampleAbs + } + } + + // 4. Normalize the value (0 to 1) + // Short.MAX_VALUE is 32767.0, a good reference for max amplitude + val normalizedRms = (maxAmplitudeInChunk / Short.MAX_VALUE).toFloat().coerceIn(0f, 1f) + amplitudeLevels.add(normalizedRms) + } + + // Normalize the resulting levels so that the max value becomes 0.9. + val maxVal = amplitudeLevels.max() + if (maxVal == 0f) { + return amplitudeLevels + } + val scaleFactor = 0.9f / maxVal + return amplitudeLevels.map { it * scaleFactor } +} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioRecorderPanel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioRecorderPanel.kt new file mode 100644 index 0000000..9b018dc --- /dev/null +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioRecorderPanel.kt @@ -0,0 +1,327 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.ai.edge.gallery.ui.common.chat + +import android.annotation.SuppressLint +import android.content.Context +import android.media.AudioFormat +import android.media.AudioRecord +import android.media.MediaRecorder +import android.util.Log +import androidx.compose.foundation.background +import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.Box +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.height +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.size +import androidx.compose.foundation.shape.CircleShape +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.rounded.ArrowUpward +import androidx.compose.material.icons.rounded.Mic +import androidx.compose.material.icons.rounded.Stop +import androidx.compose.material3.Icon +import androidx.compose.material3.IconButton +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Text +import androidx.compose.runtime.Composable +import androidx.compose.runtime.DisposableEffect +import androidx.compose.runtime.MutableLongState +import androidx.compose.runtime.MutableState +import androidx.compose.runtime.derivedStateOf +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableIntStateOf +import androidx.compose.runtime.mutableLongStateOf +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember +import androidx.compose.runtime.rememberCoroutineScope +import androidx.compose.runtime.setValue +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.clip +import androidx.compose.ui.graphics.Color +import androidx.compose.ui.graphics.graphicsLayer +import androidx.compose.ui.platform.LocalContext +import androidx.compose.ui.res.painterResource +import androidx.compose.ui.unit.dp +import com.google.ai.edge.gallery.R +import com.google.ai.edge.gallery.data.MAX_AUDIO_CLIP_DURATION_SEC +import com.google.ai.edge.gallery.data.SAMPLE_RATE +import com.google.ai.edge.gallery.ui.theme.customColors +import java.io.ByteArrayOutputStream +import java.nio.ByteBuffer +import java.nio.ByteOrder +import kotlin.math.abs +import kotlin.math.pow +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.launch + +private const val TAG = "AGAudioRecorderPanel" + +private const val CHANNEL_CONFIG = AudioFormat.CHANNEL_IN_MONO +private const val AUDIO_FORMAT = AudioFormat.ENCODING_PCM_16BIT + +/** + * A Composable that provides an audio recording panel. It allows users to record audio clips, + * displays the recording duration and a live amplitude visualization, and provides options to play + * back the recorded clip or send it. 
+ */ +@Composable +fun AudioRecorderPanel(onSendAudioClip: (ByteArray) -> Unit) { + val context = LocalContext.current + val coroutineScope = rememberCoroutineScope() + + var isRecording by remember { mutableStateOf(false) } + val elapsedMs = remember { mutableLongStateOf(0L) } + val audioRecordState = remember { mutableStateOf(null) } + val audioStream = remember { ByteArrayOutputStream() } + val recordedBytes = remember { mutableStateOf(null) } + var currentAmplitude by remember { mutableIntStateOf(0) } + + val elapsedSeconds by remember { + derivedStateOf { "%.1f".format(elapsedMs.value.toFloat() / 1000f) } + } + + // Cleanup on Composable Disposal. + DisposableEffect(Unit) { onDispose { audioRecordState.value?.release() } } + + Column(modifier = Modifier.padding(bottom = 12.dp)) { + // Title bar. + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.SpaceBetween, + ) { + // Logo and state. + Row( + modifier = Modifier.padding(start = 16.dp), + horizontalArrangement = Arrangement.spacedBy(12.dp), + ) { + Icon( + painterResource(R.drawable.logo), + modifier = Modifier.size(20.dp), + contentDescription = "", + tint = Color.Unspecified, + ) + Text( + "Record audio clip (up to $MAX_AUDIO_CLIP_DURATION_SEC seconds)", + style = MaterialTheme.typography.labelLarge, + ) + } + } + + // Recorded clip. + Row( + modifier = Modifier.fillMaxWidth().padding(vertical = 12.dp).height(40.dp), + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.Center, + ) { + val curRecordedBytes = recordedBytes.value + if (curRecordedBytes == null) { + // Info message when there is no recorded clip and the recording has not started yet. + if (!isRecording) { + Text( + "Tap the record button to start", + style = MaterialTheme.typography.labelLarge, + color = MaterialTheme.colorScheme.onSurfaceVariant, + ) + } + // Visualization for clip being recorded. + else { + Row( + horizontalArrangement = Arrangement.spacedBy(12.dp), + verticalAlignment = Alignment.CenterVertically, + ) { + Box( + modifier = + Modifier.size(8.dp) + .background(MaterialTheme.customColors.recordButtonBgColor, CircleShape) + ) + Text("$elapsedSeconds s") + } + } + } + // Controls for recorded clip. + else { + Row { + // Clip player. + AudioPlaybackPanel( + audioData = curRecordedBytes, + sampleRate = SAMPLE_RATE, + isRecording = isRecording, + ) + + // Button to send the clip + IconButton(onClick = { onSendAudioClip(curRecordedBytes) }) { + Icon(Icons.Rounded.ArrowUpward, contentDescription = "") + } + } + } + } + + // Buttons + Box(contentAlignment = Alignment.Center, modifier = Modifier.fillMaxWidth().height(40.dp)) { + // Visualization of the current amplitude. + if (isRecording) { + // Normalize the amplitude (0-32767) to a fraction (0.0-1.0) + // We use a power scale (exponent < 1) to make the pulse more visible for lower volumes. + val normalizedAmplitude = (currentAmplitude.toFloat() / 32767f).pow(0.35f) + // Define the min and max size of the circle + val minSize = 38.dp + val maxSize = 100.dp + + // Map the normalized amplitude to our size range + val scale by + remember(normalizedAmplitude) { + derivedStateOf { (minSize + (maxSize - minSize) * normalizedAmplitude) / minSize } + } + Box( + modifier = + Modifier.size(minSize) + .graphicsLayer(scaleX = scale, scaleY = scale, clip = false, alpha = 0.3f) + .background(MaterialTheme.customColors.recordButtonBgColor, CircleShape) + ) + } + + // Record/stop button. 
+ IconButton( + onClick = { + coroutineScope.launch { + if (!isRecording) { + isRecording = true + recordedBytes.value = null + startRecording( + context = context, + audioRecordState = audioRecordState, + audioStream = audioStream, + elapsedMs = elapsedMs, + onAmplitudeChanged = { currentAmplitude = it }, + onMaxDurationReached = { + val curRecordedBytes = + stopRecording(audioRecordState = audioRecordState, audioStream = audioStream) + recordedBytes.value = curRecordedBytes + isRecording = false + }, + ) + } else { + val curRecordedBytes = + stopRecording(audioRecordState = audioRecordState, audioStream = audioStream) + recordedBytes.value = curRecordedBytes + isRecording = false + } + } + }, + modifier = + Modifier.clip(CircleShape).background(MaterialTheme.customColors.recordButtonBgColor), + ) { + Icon( + if (isRecording) Icons.Rounded.Stop else Icons.Rounded.Mic, + contentDescription = "", + tint = Color.White, + ) + } + } + } +} + +// Permission is checked in parent composable. +@SuppressLint("MissingPermission") +private suspend fun startRecording( + context: Context, + audioRecordState: MutableState, + audioStream: ByteArrayOutputStream, + elapsedMs: MutableLongState, + onAmplitudeChanged: (Int) -> Unit, + onMaxDurationReached: () -> Unit, +) { + Log.d(TAG, "Start recording...") + val minBufferSize = AudioRecord.getMinBufferSize(SAMPLE_RATE, CHANNEL_CONFIG, AUDIO_FORMAT) + + audioRecordState.value?.release() + val recorder = + AudioRecord( + MediaRecorder.AudioSource.MIC, + SAMPLE_RATE, + CHANNEL_CONFIG, + AUDIO_FORMAT, + minBufferSize, + ) + + audioRecordState.value = recorder + val buffer = ByteArray(minBufferSize) + + // The function will only return when the recording is done (when stopRecording is called). + coroutineScope { + launch(Dispatchers.IO) { + recorder.startRecording() + + val startMs = System.currentTimeMillis() + elapsedMs.value = 0L + while (audioRecordState.value?.recordingState == AudioRecord.RECORDSTATE_RECORDING) { + val bytesRead = recorder.read(buffer, 0, buffer.size) + if (bytesRead > 0) { + val currentAmplitude = calculatePeakAmplitude(buffer = buffer, bytesRead = bytesRead) + onAmplitudeChanged(currentAmplitude) + audioStream.write(buffer, 0, bytesRead) + } + elapsedMs.value = System.currentTimeMillis() - startMs + if (elapsedMs.value >= MAX_AUDIO_CLIP_DURATION_SEC * 1000) { + onMaxDurationReached() + break + } + } + } + } +} + +private fun stopRecording( + audioRecordState: MutableState, + audioStream: ByteArrayOutputStream, +): ByteArray { + Log.d(TAG, "Stopping recording...") + + val recorder = audioRecordState.value + if (recorder?.recordingState == AudioRecord.RECORDSTATE_RECORDING) { + recorder.stop() + } + recorder?.release() + audioRecordState.value = null + + val recordedBytes = audioStream.toByteArray() + audioStream.reset() + Log.d(TAG, "Stopped. 
Recorded ${recordedBytes.size} bytes.") + + return recordedBytes +} + +private fun calculatePeakAmplitude(buffer: ByteArray, bytesRead: Int): Int { + // Wrap the byte array in a ByteBuffer and set the order to little-endian + val shortBuffer = + ByteBuffer.wrap(buffer, 0, bytesRead).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer() + + var maxAmplitude = 0 + // Iterate through the short buffer to find the maximum absolute value + while (shortBuffer.hasRemaining()) { + val currentSample = abs(shortBuffer.get().toInt()) + if (currentSample > maxAmplitude) { + maxAmplitude = currentSample + } + } + return maxAmplitude +} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatMessage.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatMessage.kt index 7eba194..e1f45b0 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatMessage.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatMessage.kt @@ -17,18 +17,22 @@ package com.google.ai.edge.gallery.ui.common.chat import android.graphics.Bitmap +import android.util.Log import androidx.compose.ui.graphics.ImageBitmap import androidx.compose.ui.unit.Dp import com.google.ai.edge.gallery.common.Classification import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.gallery.data.PromptTemplate +private const val TAG = "AGChatMessage" + enum class ChatMessageType { INFO, WARNING, TEXT, IMAGE, IMAGE_WITH_HISTORY, + AUDIO_CLIP, LOADING, CLASSIFICATION, CONFIG_VALUES_CHANGE, @@ -121,6 +125,90 @@ class ChatMessageImage( } } +/** Chat message for audio clip. */ +class ChatMessageAudioClip( + val audioData: ByteArray, + val sampleRate: Int, + override val side: ChatSide, + override val latencyMs: Float = 0f, +) : ChatMessage(type = ChatMessageType.AUDIO_CLIP, side = side, latencyMs = latencyMs) { + override fun clone(): ChatMessageAudioClip { + return ChatMessageAudioClip( + audioData = audioData, + sampleRate = sampleRate, + side = side, + latencyMs = latencyMs, + ) + } + + fun genByteArrayForWav(): ByteArray { + val header = ByteArray(44) + + val pcmDataSize = audioData.size + val wavFileSize = pcmDataSize + 44 // 44 bytes for the header + val channels = 1 // Mono + val bitsPerSample: Short = 16 + val byteRate = sampleRate * channels * bitsPerSample / 8 + Log.d(TAG, "Wav metadata: sampleRate: $sampleRate") + + // RIFF/WAVE header + header[0] = 'R'.code.toByte() + header[1] = 'I'.code.toByte() + header[2] = 'F'.code.toByte() + header[3] = 'F'.code.toByte() + header[4] = (wavFileSize and 0xff).toByte() + header[5] = (wavFileSize shr 8 and 0xff).toByte() + header[6] = (wavFileSize shr 16 and 0xff).toByte() + header[7] = (wavFileSize shr 24 and 0xff).toByte() + header[8] = 'W'.code.toByte() + header[9] = 'A'.code.toByte() + header[10] = 'V'.code.toByte() + header[11] = 'E'.code.toByte() + header[12] = 'f'.code.toByte() + header[13] = 'm'.code.toByte() + header[14] = 't'.code.toByte() + header[15] = ' '.code.toByte() + header[16] = 16 + header[17] = 0 + header[18] = 0 + header[19] = 0 // Sub-chunk size (16 for PCM) + header[20] = 1 + header[21] = 0 // Audio format (1 for PCM) + header[22] = channels.toByte() + header[23] = 0 // Number of channels + header[24] = (sampleRate and 0xff).toByte() + header[25] = (sampleRate shr 8 and 0xff).toByte() + header[26] = (sampleRate shr 16 and 0xff).toByte() + header[27] = (sampleRate shr 24 and 0xff).toByte() + header[28] = (byteRate and 0xff).toByte() + header[29] = (byteRate shr 8 and 
0xff).toByte()
+    header[30] = (byteRate shr 16 and 0xff).toByte()
+    header[31] = (byteRate shr 24 and 0xff).toByte()
+    header[32] = (channels * bitsPerSample / 8).toByte()
+    header[33] = 0 // Block align
+    header[34] = bitsPerSample.toByte()
+    header[35] = (bitsPerSample.toInt() shr 8 and 0xff).toByte() // Bits per sample
+    header[36] = 'd'.code.toByte()
+    header[37] = 'a'.code.toByte()
+    header[38] = 't'.code.toByte()
+    header[39] = 'a'.code.toByte()
+    header[40] = (pcmDataSize and 0xff).toByte()
+    header[41] = (pcmDataSize shr 8 and 0xff).toByte()
+    header[42] = (pcmDataSize shr 16 and 0xff).toByte()
+    header[43] = (pcmDataSize shr 24 and 0xff).toByte()
+
+    return header + audioData
+  }
+
+  fun getDurationInSeconds(): Float {
+    // PCM 16-bit
+    val bytesPerSample = 2
+    val bytesPerFrame = bytesPerSample * 1 // mono
+    val totalFrames = audioData.size.toFloat() / bytesPerFrame
+    return totalFrames / sampleRate
+  }
+} /** Chat message for images with history. */ class ChatMessageImageWithHistory( val bitmaps: List<Bitmap>, diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatPanel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatPanel.kt index 2b8b168..80f2cfa 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatPanel.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatPanel.kt @@ -137,6 +137,19 @@ fun ChatPanel( } imageMessageCount } + val audioClipMessageCountToLastConfigChange =
+    remember(messages) {
+      var audioClipMessageCount = 0
+      for (message in messages.reversed()) {
+        if (message is ChatMessageConfigValuesChange) {
+          break
+        }
+        if (message is ChatMessageAudioClip) {
+          audioClipMessageCount++
+        }
+      }
+      audioClipMessageCount
+    } var curMessage by remember { mutableStateOf("") } // Correct state val focusManager = LocalFocusManager.current @@ -342,6 +355,9 @@ fun ChatPanel( imageHistoryCurIndex = imageHistoryCurIndex, ) + // Audio clip.
+  is ChatMessageAudioClip -> MessageBodyAudioClip(message = message)
+ // Classification result is ChatMessageClassification -> MessageBodyClassification( @@ -467,6 +483,22 @@ fun ChatPanel( ) } } + // Show an info message for the ask audio task to get users started.
+  else if (task.type == TaskType.LLM_ASK_AUDIO && messages.isEmpty()) {
+    Column(
+      modifier = Modifier.padding(horizontal = 16.dp).fillMaxSize(),
+      horizontalAlignment = Alignment.CenterHorizontally,
+      verticalArrangement = Arrangement.Center,
+    ) {
+      MessageBodyInfo(
+        ChatMessageInfo(
+          content =
+            "To get started, tap the + icon to add your audio clips. You can add up to 10 clips, each up to 30 seconds long."
+        ),
+        smallFontSize = false,
+      )
+    }
+  } } // Chat input @@ -482,6 +514,7 @@ fun ChatPanel( isResettingSession = uiState.isResettingSession, modelPreparing = uiState.preparing, imageMessageCount = imageMessageCountToLastConfigChange, + audioClipMessageCount = audioClipMessageCountToLastConfigChange, modelInitializing = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING, textFieldPlaceHolderRes = task.textInputPlaceHolderRes, @@ -504,7 +537,10 @@ fun ChatPanel( onStopButtonClicked = onStopButtonClicked, // showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen, showPromptTemplatesInMenu = false, - showImagePickerInMenu = selectedModel.llmSupportImage, + showImagePickerInMenu =
+    selectedModel.llmSupportImage && task.type === TaskType.LLM_ASK_IMAGE,
+  showAudioItemsInMenu =
+    selectedModel.llmSupportAudio && task.type === TaskType.LLM_ASK_AUDIO, showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress, ) } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageBodyAudioClip.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageBodyAudioClip.kt new file mode 100644 index 0000000..e4e3716 --- /dev/null +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageBodyAudioClip.kt @@ -0,0 +1,33 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.ai.edge.gallery.ui.common.chat + +import androidx.compose.foundation.layout.padding +import androidx.compose.runtime.Composable +import androidx.compose.ui.Modifier +import androidx.compose.ui.unit.dp + +@Composable +fun MessageBodyAudioClip(message: ChatMessageAudioClip, modifier: Modifier = Modifier) { + AudioPlaybackPanel( + audioData = message.audioData, + sampleRate = message.sampleRate, + isRecording = false, + modifier = Modifier.padding(end = 16.dp), + onDarkBg = true, + ) +} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageInputText.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageInputText.kt index 1e2e056..d958815 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageInputText.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageInputText.kt @@ -21,6 +21,7 @@ package com.google.ai.edge.gallery.ui.common.chat // import com.google.ai.edge.gallery.ui.theme.GalleryTheme import android.Manifest import android.content.Context +import android.content.Intent import android.content.pm.PackageManager import android.graphics.Bitmap import android.graphics.BitmapFactory @@ -65,9 +66,11 @@ import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.material.icons.Icons import androidx.compose.material.icons.automirrored.rounded.Send import androidx.compose.material.icons.rounded.Add +import androidx.compose.material.icons.rounded.AudioFile import androidx.compose.material.icons.rounded.Close import androidx.compose.material.icons.rounded.FlipCameraAndroid import androidx.compose.material.icons.rounded.History +import androidx.compose.material.icons.rounded.Mic import androidx.compose.material.icons.rounded.Photo import androidx.compose.material.icons.rounded.PhotoCamera import androidx.compose.material.icons.rounded.PostAdd @@ -107,9 +110,14 @@ import androidx.compose.ui.unit.dp import androidx.compose.ui.viewinterop.AndroidView import androidx.core.content.ContextCompat import androidx.lifecycle.compose.LocalLifecycleOwner +import com.google.ai.edge.gallery.common.AudioClip +import com.google.ai.edge.gallery.common.convertWavToMonoWithMaxSeconds +import com.google.ai.edge.gallery.data.MAX_AUDIO_CLIP_COUNT import com.google.ai.edge.gallery.data.MAX_IMAGE_COUNT +import com.google.ai.edge.gallery.data.SAMPLE_RATE import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import java.util.concurrent.Executors +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.launch private const val TAG = "AGMessageInputText" @@ -128,6 +136,7 @@ fun MessageInputText( isResettingSession: Boolean, inProgress: Boolean, imageMessageCount: Int, + audioClipMessageCount: Int, modelInitializing: Boolean, @StringRes textFieldPlaceHolderRes: Int, onValueChanged: (String) -> Unit, @@ -137,6 +146,7 @@ fun MessageInputText( onStopButtonClicked: () -> Unit = {}, showPromptTemplatesInMenu: Boolean = false, showImagePickerInMenu: Boolean = false, + showAudioItemsInMenu: Boolean = false, showStopButtonWhenInProgress: Boolean = false, ) { val context = LocalContext.current @@ -146,7 +156,12 @@ fun MessageInputText( var showTextInputHistorySheet by remember { mutableStateOf(false) } var showCameraCaptureBottomSheet by remember { mutableStateOf(false) } val cameraCaptureSheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true) + var showAudioRecorderBottomSheet by remember { 
mutableStateOf(false) } + val audioRecorderSheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true) var pickedImages by remember { mutableStateOf<List<Bitmap>>(listOf()) } + var pickedAudioClips by remember { mutableStateOf<List<AudioClip>>(listOf()) } + var hasFrontCamera by remember { mutableStateOf(false) } + val updatePickedImages: (List<Bitmap>) -> Unit = { bitmaps -> var newPickedImages: MutableList<Bitmap> = mutableListOf() newPickedImages.addAll(pickedImages) @@ -156,7 +171,16 @@ fun MessageInputText( } pickedImages = newPickedImages.toList() } - var hasFrontCamera by remember { mutableStateOf(false) } +
+  val updatePickedAudioClips: (List<AudioClip>) -> Unit = { audioDataList ->
+    var newAudioDataList: MutableList<AudioClip> = mutableListOf()
+    newAudioDataList.addAll(pickedAudioClips)
+    newAudioDataList.addAll(audioDataList)
+    if (newAudioDataList.size > MAX_AUDIO_CLIP_COUNT) {
+      newAudioDataList = newAudioDataList.subList(fromIndex = 0, toIndex = MAX_AUDIO_CLIP_COUNT)
+    }
+    pickedAudioClips = newAudioDataList.toList()
+  } LaunchedEffect(Unit) { checkFrontCamera(context = context, callback = { hasFrontCamera = it }) } @@ -170,6 +194,16 @@ fun MessageInputText( } } + // Permission request when recording audio clips.
+  val recordAudioClipsPermissionLauncher =
+    rememberLauncherForActivityResult(ActivityResultContracts.RequestPermission()) {
+      permissionGranted ->
+      if (permissionGranted) {
+        showAddContentMenu = false
+        showAudioRecorderBottomSheet = true
+      }
+    }
+ // Registers a photo picker activity launcher in single-select mode. val pickMedia = rememberLauncherForActivityResult(ActivityResultContracts.PickMultipleVisualMedia()) { uris -> @@ -184,9 +218,31 @@ fun MessageInputText( } } + val pickWav =
+    rememberLauncherForActivityResult(
+      contract = ActivityResultContracts.StartActivityForResult()
+    ) { result ->
+      if (result.resultCode == android.app.Activity.RESULT_OK) {
+        result.data?.data?.let { uri ->
+          Log.d(TAG, "Picked wav file: $uri")
+          scope.launch(Dispatchers.IO) {
+            convertWavToMonoWithMaxSeconds(context = context, stereoUri = uri)?.let { audioClip ->
+              updatePickedAudioClips(
+                listOf(
+                  AudioClip(audioData = audioClip.audioData, sampleRate = audioClip.sampleRate)
+                )
+              )
+            }
+          }
+        }
+      } else {
+        Log.d(TAG, "Wav picking cancelled.")
+      }
+    }
+ Column { - // A preview panel for the selected image. - if (pickedImages.isNotEmpty()) { + // A preview panel for the selected images and audio clips. 
+ if (pickedImages.isNotEmpty() || pickedAudioClips.isNotEmpty()) { Row( modifier = Modifier.offset(x = 16.dp).fillMaxWidth().horizontalScroll(rememberScrollState()), @@ -203,20 +259,30 @@ fun MessageInputText( .clip(RoundedCornerShape(8.dp)) .border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp)), ) + MediaPanelCloseButton { pickedImages = pickedImages.filter { image != it } } + } + } + + for ((index, audioClip) in pickedAudioClips.withIndex()) { + Box(contentAlignment = Alignment.TopEnd) { Box( modifier = - Modifier.offset(x = 10.dp, y = (-10).dp) - .clip(CircleShape) + Modifier.shadow(2.dp, shape = RoundedCornerShape(8.dp)) + .clip(RoundedCornerShape(8.dp)) .background(MaterialTheme.colorScheme.surface) - .border((1.5).dp, MaterialTheme.colorScheme.outline, CircleShape) - .clickable { pickedImages = pickedImages.filter { image != it } } + .border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp)) ) { - Icon( - Icons.Rounded.Close, - contentDescription = "", - modifier = Modifier.padding(3.dp).size(16.dp), + AudioPlaybackPanel( + audioData = audioClip.audioData, + sampleRate = audioClip.sampleRate, + isRecording = false, + modifier = Modifier.padding(end = 16.dp), ) } + MediaPanelCloseButton { + pickedAudioClips = + pickedAudioClips.filterIndexed { curIndex, curAudioData -> curIndex != index } + } } } } @@ -239,10 +305,13 @@ fun MessageInputText( verticalAlignment = Alignment.CenterVertically, ) { val enableAddImageMenuItems = (imageMessageCount + pickedImages.size) < MAX_IMAGE_COUNT + val enableRecordAudioClipMenuItems = + (audioClipMessageCount + pickedAudioClips.size) < MAX_AUDIO_CLIP_COUNT DropdownMenu( expanded = showAddContentMenu, onDismissRequest = { showAddContentMenu = false }, ) { + // Image related menu items. if (showImagePickerInMenu) { // Take a picture. DropdownMenuItem( @@ -295,6 +364,70 @@ fun MessageInputText( ) } + // Audio related menu items. + if (showAudioItemsInMenu) { + DropdownMenuItem( + text = { + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(6.dp), + ) { + Icon(Icons.Rounded.Mic, contentDescription = "") + Text("Record audio clip") + } + }, + enabled = enableRecordAudioClipMenuItems, + onClick = { + // Check permission + when (PackageManager.PERMISSION_GRANTED) { + // Already got permission. Call the lambda. + ContextCompat.checkSelfPermission(context, Manifest.permission.RECORD_AUDIO) -> { + showAddContentMenu = false + showAudioRecorderBottomSheet = true + } + + // Otherwise, ask for permission + else -> { + recordAudioClipsPermissionLauncher.launch(Manifest.permission.RECORD_AUDIO) + } + } + }, + ) + + DropdownMenuItem( + text = { + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(6.dp), + ) { + Icon(Icons.Rounded.AudioFile, contentDescription = "") + Text("Pick wav file") + } + }, + enabled = enableRecordAudioClipMenuItems, + onClick = { + showAddContentMenu = false + + // Show file picker. + val intent = + Intent(Intent.ACTION_GET_CONTENT).apply { + addCategory(Intent.CATEGORY_OPENABLE) + type = "audio/*" + + // Provide a list of more specific MIME types to filter for. + val mimeTypes = arrayOf("audio/wav", "audio/x-wav") + putExtra(Intent.EXTRA_MIME_TYPES, mimeTypes) + + // Single select. + putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false) + .addFlags(Intent.FLAG_GRANT_PERSISTABLE_URI_PERMISSION) + .addFlags(Intent.FLAG_GRANT_READ_URI_PERMISSION) + } + pickWav.launch(intent) + }, + ) + } + // Prompt templates. 
if (showPromptTemplatesInMenu) { DropdownMenuItem( @@ -369,15 +502,22 @@ fun MessageInputText( ) } } - } // Send button. Only shown when text is not empty. - else if (curMessage.isNotEmpty()) { + }
+  // Send button. Only shown when text is not empty, or there is at least one recorded
+  // audio clip.
+  else if (curMessage.isNotEmpty() || pickedAudioClips.isNotEmpty()) { IconButton( enabled = !inProgress && !isResettingSession, onClick = { onSendMessage( - createMessagesToSend(pickedImages = pickedImages, text = curMessage.trim()) + createMessagesToSend(
+    pickedImages = pickedImages,
+    audioClips = pickedAudioClips,
+    text = curMessage.trim(),
+  ) ) pickedImages = listOf() + pickedAudioClips = listOf() }, colors = IconButtonDefaults.iconButtonColors( @@ -403,8 +543,15 @@ fun MessageInputText( history = modelManagerUiState.textInputHistory, onDismissed = { showTextInputHistorySheet = false }, onHistoryItemClicked = { item -> - onSendMessage(createMessagesToSend(pickedImages = pickedImages, text = item)) + onSendMessage(
+    createMessagesToSend(
+      pickedImages = pickedImages,
+      audioClips = pickedAudioClips,
+      text = item,
+    )
+  ) pickedImages = listOf() + pickedAudioClips = listOf() modelManagerViewModel.promoteTextInputHistoryItem(item) }, onHistoryItemDeleted = { item -> modelManagerViewModel.deleteTextInputHistory(item) }, @@ -582,6 +729,43 @@ fun MessageInputText( } } } +
+  if (showAudioRecorderBottomSheet) {
+    ModalBottomSheet(
+      sheetState = audioRecorderSheetState,
+      onDismissRequest = { showAudioRecorderBottomSheet = false },
+    ) {
+      AudioRecorderPanel(
+        onSendAudioClip = { audioData ->
+          scope.launch {
+            updatePickedAudioClips(
+              listOf(AudioClip(audioData = audioData, sampleRate = SAMPLE_RATE))
+            )
+            audioRecorderSheetState.hide()
+            showAudioRecorderBottomSheet = false
+          }
+        }
+      )
+    }
+  }
+}
+
+@Composable
+private fun MediaPanelCloseButton(onClicked: () -> Unit) {
+  Box(
+    modifier =
+      Modifier.offset(x = 10.dp, y = (-10).dp)
+        .clip(CircleShape)
+        .background(MaterialTheme.colorScheme.surface)
+        .border((1.5).dp, MaterialTheme.colorScheme.outline, CircleShape)
+        .clickable { onClicked() }
+  ) {
+    Icon(
+      Icons.Rounded.Close,
+      contentDescription = "",
+      modifier = Modifier.padding(3.dp).size(16.dp),
+    )
+  } } private fun handleImagesSelected( @@ -641,20 +825,50 @@ private fun checkFrontCamera(context: Context, callback: (Boolean) -> Unit) { ) } -private fun createMessagesToSend(pickedImages: List<Bitmap>, text: String): List<ChatMessage> { +private fun createMessagesToSend(
+  pickedImages: List<Bitmap>,
+  audioClips: List<AudioClip>,
+  text: String,
+): List<ChatMessage> { var messages: MutableList<ChatMessage> = mutableListOf() +
+  // Add image messages.
+  var imageMessages: MutableList<ChatMessage> = mutableListOf() if (pickedImages.isNotEmpty()) { for (image in pickedImages) { - messages.add( + imageMessages.add( ChatMessageImage(bitmap = image, imageBitMap = image.asImageBitmap(), side = ChatSide.USER) ) } } // Cap the number of image messages. - if (messages.size > MAX_IMAGE_COUNT) { - messages = messages.subList(fromIndex = 0, toIndex = MAX_IMAGE_COUNT) + if (imageMessages.size > MAX_IMAGE_COUNT) {
+    imageMessages = imageMessages.subList(fromIndex = 0, toIndex = MAX_IMAGE_COUNT)
+  }
+  messages.addAll(imageMessages)
+
+  // Add audio messages.
+  var audioMessages: MutableList<ChatMessage> = mutableListOf()
+  if (audioClips.isNotEmpty()) {
+    for (audioClip in audioClips) {
+      audioMessages.add(
+        ChatMessageAudioClip(
+          audioData = audioClip.audioData,
+          sampleRate = audioClip.sampleRate,
+          side = ChatSide.USER,
+        )
+      )
+    }
+  }
+  // Cap the number of audio messages. 
+ if (audioMessages.size > MAX_AUDIO_CLIP_COUNT) { + audioMessages = audioMessages.subList(fromIndex = 0, toIndex = MAX_AUDIO_CLIP_COUNT) + } + messages.addAll(audioMessages) + + if (text.isNotEmpty()) { + messages.add(ChatMessageText(content = text, side = ChatSide.USER)) } - messages.add(ChatMessageText(content = text, side = ChatSide.USER)) return messages } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/ModelImportDialog.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/ModelImportDialog.kt index 1f0cd00..9c985e5 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/ModelImportDialog.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/ModelImportDialog.kt @@ -121,6 +121,7 @@ private val IMPORT_CONFIGS_LLM: List = valueType = ValueType.FLOAT, ), BooleanSwitchConfig(key = ConfigKey.SUPPORT_IMAGE, defaultValue = false), + BooleanSwitchConfig(key = ConfigKey.SUPPORT_AUDIO, defaultValue = false), SegmentedButtonConfig( key = ConfigKey.COMPATIBLE_ACCELERATORS, defaultValue = Accelerator.CPU.label, @@ -230,6 +231,12 @@ fun ModelImportDialog(uri: Uri, onDismiss: () -> Unit, onDone: (ImportedModel) - valueType = ValueType.BOOLEAN, ) as Boolean + val supportAudio = + convertValueToTargetType( + value = values.get(ConfigKey.SUPPORT_AUDIO.label)!!, + valueType = ValueType.BOOLEAN, + ) + as Boolean val importedModel: ImportedModel = ImportedModel.newBuilder() .setFileName(fileName) @@ -242,6 +249,7 @@ fun ModelImportDialog(uri: Uri, onDismiss: () -> Unit, onDone: (ImportedModel) - .setDefaultTopp(defaultTopp) .setDefaultTemperature(defaultTemperature) .setSupportImage(supportImage) + .setSupportAudio(supportAudio) .build() ) .build() diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/SettingsDialog.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/SettingsDialog.kt index ee418bb..dc0bba1 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/SettingsDialog.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/SettingsDialog.kt @@ -173,7 +173,7 @@ fun SettingsDialog( color = MaterialTheme.colorScheme.onSurfaceVariant, ) Text( - "Expired at: ${dateFormatter.format(Instant.ofEpochMilli(curHfToken.expiresAtMs))}", + "Expires at: ${dateFormatter.format(Instant.ofEpochMilli(curHfToken.expiresAtMs))}", style = MaterialTheme.typography.bodyMedium, color = MaterialTheme.colorScheme.onSurfaceVariant, ) diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Forum.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Forum.kt deleted file mode 100644 index 173ca4f..0000000 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Forum.kt +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.google.ai.edge.gallery.ui.icon - -import androidx.compose.ui.graphics.Color -import androidx.compose.ui.graphics.SolidColor -import androidx.compose.ui.graphics.vector.ImageVector -import androidx.compose.ui.graphics.vector.path -import androidx.compose.ui.unit.dp - -val Forum: ImageVector - get() { - if (_Forum != null) return _Forum!! - - _Forum = - ImageVector.Builder( - name = "Forum", - defaultWidth = 24.dp, - defaultHeight = 24.dp, - viewportWidth = 960f, - viewportHeight = 960f, - ) - .apply { - path(fill = SolidColor(Color(0xFF000000))) { - moveTo(280f, 720f) - quadToRelative(-17f, 0f, -28.5f, -11.5f) - reflectiveQuadTo(240f, 680f) - verticalLineToRelative(-80f) - horizontalLineToRelative(520f) - verticalLineToRelative(-360f) - horizontalLineToRelative(80f) - quadToRelative(17f, 0f, 28.5f, 11.5f) - reflectiveQuadTo(880f, 280f) - verticalLineToRelative(600f) - lineTo(720f, 720f) - close() - moveTo(80f, 680f) - verticalLineToRelative(-560f) - quadToRelative(0f, -17f, 11.5f, -28.5f) - reflectiveQuadTo(120f, 80f) - horizontalLineToRelative(520f) - quadToRelative(17f, 0f, 28.5f, 11.5f) - reflectiveQuadTo(680f, 120f) - verticalLineToRelative(360f) - quadToRelative(0f, 17f, -11.5f, 28.5f) - reflectiveQuadTo(640f, 520f) - horizontalLineTo(240f) - close() - moveToRelative(520f, -240f) - verticalLineToRelative(-280f) - horizontalLineTo(160f) - verticalLineToRelative(280f) - close() - moveToRelative(-440f, 0f) - verticalLineToRelative(-280f) - close() - } - } - .build() - - return _Forum!! - } - -private var _Forum: ImageVector? = null diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Mms.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Mms.kt deleted file mode 100644 index 56f9990..0000000 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Mms.kt +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.ai.edge.gallery.ui.icon - -import androidx.compose.ui.graphics.Color -import androidx.compose.ui.graphics.SolidColor -import androidx.compose.ui.graphics.vector.ImageVector -import androidx.compose.ui.graphics.vector.path -import androidx.compose.ui.unit.dp - -val Mms: ImageVector - get() { - if (_Mms != null) return _Mms!! 
- - _Mms = - ImageVector.Builder( - name = "Mms", - defaultWidth = 24.dp, - defaultHeight = 24.dp, - viewportWidth = 960f, - viewportHeight = 960f, - ) - .apply { - path(fill = SolidColor(Color(0xFF000000))) { - moveTo(240f, 560f) - horizontalLineToRelative(480f) - lineTo(570f, 360f) - lineTo(450f, 520f) - lineToRelative(-90f, -120f) - close() - moveTo(80f, 880f) - verticalLineToRelative(-720f) - quadToRelative(0f, -33f, 23.5f, -56.5f) - reflectiveQuadTo(160f, 80f) - horizontalLineToRelative(640f) - quadToRelative(33f, 0f, 56.5f, 23.5f) - reflectiveQuadTo(880f, 160f) - verticalLineToRelative(480f) - quadToRelative(0f, 33f, -23.5f, 56.5f) - reflectiveQuadTo(800f, 720f) - horizontalLineTo(240f) - close() - moveToRelative(126f, -240f) - horizontalLineToRelative(594f) - verticalLineToRelative(-480f) - horizontalLineTo(160f) - verticalLineToRelative(525f) - close() - moveToRelative(-46f, 0f) - verticalLineToRelative(-480f) - close() - } - } - .build() - - return _Mms!! - } - -private var _Mms: ImageVector? = null diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Widgets.kt.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Widgets.kt.kt deleted file mode 100644 index c727a7b..0000000 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Widgets.kt.kt +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.ai.edge.gallery.ui.icon - -import androidx.compose.ui.graphics.Color -import androidx.compose.ui.graphics.SolidColor -import androidx.compose.ui.graphics.vector.ImageVector -import androidx.compose.ui.graphics.vector.path -import androidx.compose.ui.unit.dp - -val Widgets: ImageVector - get() { - if (_Widgets != null) return _Widgets!! 
- - _Widgets = - ImageVector.Builder( - name = "Widgets", - defaultWidth = 24.dp, - defaultHeight = 24.dp, - viewportWidth = 960f, - viewportHeight = 960f, - ) - .apply { - path(fill = SolidColor(Color(0xFF000000))) { - moveTo(666f, 520f) - lineTo(440f, 294f) - lineToRelative(226f, -226f) - lineToRelative(226f, 226f) - close() - moveToRelative(-546f, -80f) - verticalLineToRelative(-320f) - horizontalLineToRelative(320f) - verticalLineToRelative(320f) - close() - moveToRelative(400f, 400f) - verticalLineToRelative(-320f) - horizontalLineToRelative(320f) - verticalLineToRelative(320f) - close() - moveToRelative(-400f, 0f) - verticalLineToRelative(-320f) - horizontalLineToRelative(320f) - verticalLineToRelative(320f) - close() - moveToRelative(80f, -480f) - horizontalLineToRelative(160f) - verticalLineToRelative(-160f) - horizontalLineTo(200f) - close() - moveToRelative(467f, 48f) - lineToRelative(113f, -113f) - lineToRelative(-113f, -113f) - lineToRelative(-113f, 113f) - close() - moveToRelative(-67f, 352f) - horizontalLineToRelative(160f) - verticalLineToRelative(-160f) - horizontalLineTo(600f) - close() - moveToRelative(-400f, 0f) - horizontalLineToRelative(160f) - verticalLineToRelative(-160f) - horizontalLineTo(200f) - close() - moveToRelative(400f, -160f) - } - } - .build() - - return _Widgets!! - } - -private var _Widgets: ImageVector? = null diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt index 128baa8..a330c35 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt @@ -62,13 +62,13 @@ object LlmChatModelHelper { Accelerator.GPU.label -> LlmInference.Backend.GPU else -> LlmInference.Backend.GPU } - val options = + val optionsBuilder = LlmInference.LlmInferenceOptions.builder() .setModelPath(model.getPath(context = context)) .setMaxTokens(maxTokens) .setPreferredBackend(preferredBackend) .setMaxNumImages(if (model.llmSupportImage) MAX_IMAGE_COUNT else 0) - .build() + val options = optionsBuilder.build() // Create an instance of the LLM Inference task and session. try { @@ -82,7 +82,9 @@ object LlmChatModelHelper { .setTopP(topP) .setTemperature(temperature) .setGraphOptions( - GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build() + GraphOptions.builder() + .setEnableVisionModality(model.llmSupportImage) + .build() ) .build(), ) @@ -115,7 +117,9 @@ object LlmChatModelHelper { .setTopP(topP) .setTemperature(temperature) .setGraphOptions( - GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build() + GraphOptions.builder() + .setEnableVisionModality(model.llmSupportImage) + .build() ) .build(), ) @@ -159,6 +163,7 @@ object LlmChatModelHelper { resultListener: ResultListener, cleanUpListener: CleanUpListener, images: List = listOf(), + audioClips: List = listOf(), ) { val instance = model.instance as LlmModelInstance @@ -172,10 +177,16 @@ object LlmChatModelHelper { // For a model that supports image modality, we need to add the text query chunk before adding // image. val session = instance.session - session.addQueryChunk(input) + if (input.trim().isNotEmpty()) { + session.addQueryChunk(input) + } for (image in images) { session.addImage(BitmapImageBuilder(image).build()) } + for (audioClip in audioClips) { + // Uncomment when audio is supported. 
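+      // Each audioClip is a ByteArray holding mono 16-bit PCM audio in a WAV
+      // container, as prepared by the recorder panel and the WAV conversion
+      // utilities added in this patch.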
+      // session.addAudio(audioClip)
+    }
    val unused = session.generateResponseAsync(resultListener)
  }
 }
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt
index 0a97f3f..23b5777 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt
@@ -22,6 +22,7 @@ import androidx.compose.ui.Modifier
 import androidx.compose.ui.platform.LocalContext
 import androidx.lifecycle.viewmodel.compose.viewModel
 import com.google.ai.edge.gallery.ui.ViewModelProvider
+import com.google.ai.edge.gallery.ui.common.chat.ChatMessageAudioClip
 import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage
 import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
 import com.google.ai.edge.gallery.ui.common.chat.ChatView
@@ -36,6 +37,10 @@ object LlmAskImageDestination {
  val route = "LlmAskImageRoute"
 }

+object LlmAskAudioDestination {
+  val route = "LlmAskAudioRoute"
+}
+
 @Composable
 fun LlmChatScreen(
  modelManagerViewModel: ModelManagerViewModel,
@@ -66,6 +71,21 @@ fun LlmAskImageScreen(
  )
 }

+@Composable
+fun LlmAskAudioScreen(
+  modelManagerViewModel: ModelManagerViewModel,
+  navigateUp: () -> Unit,
+  modifier: Modifier = Modifier,
+  viewModel: LlmAskAudioViewModel = viewModel(factory = ViewModelProvider.Factory),
+) {
+  ChatViewWrapper(
+    viewModel = viewModel,
+    modelManagerViewModel = modelManagerViewModel,
+    navigateUp = navigateUp,
+    modifier = modifier,
+  )
+}
+
 @Composable
 fun ChatViewWrapper(
  viewModel: LlmChatViewModel,
@@ -86,6 +106,7 @@ fun ChatViewWrapper(
          var text = ""
          val images: MutableList<Bitmap> = mutableListOf()
+          val audioMessages: MutableList<ChatMessageAudioClip> = mutableListOf()
          var chatMessageText: ChatMessageText? = null
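
For illustration, the type-based partition performed by the loop below is equivalent to this
filterIsInstance sketch (the same result, not how the patch writes it):

  val textContent = messages.filterIsInstance<ChatMessageText>().lastOrNull()?.content ?: ""
  val imageBitmaps = messages.filterIsInstance<ChatMessageImage>().map { it.bitmap }
  val audioClips = messages.filterIsInstance<ChatMessageAudioClip>()
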
          for (message in messages) {
            if (message is ChatMessageText) {
@@ -93,14 +114,17 @@ fun ChatViewWrapper(
              text = message.content
            } else if (message is ChatMessageImage) {
              images.add(message.bitmap)
+            } else if (message is ChatMessageAudioClip) {
+              audioMessages.add(message)
            }
          }
-          if (text.isNotEmpty() && chatMessageText != null) {
+          if ((text.isNotEmpty() && chatMessageText != null) || audioMessages.isNotEmpty()) {
            modelManagerViewModel.addTextInputHistory(text)
            viewModel.generateResponse(
              model = model,
              input = text,
              images = images,
+              audioMessages = audioMessages,
              onError = {
                viewModel.handleError(
                  context = context,
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt
index 7b0a083..d97a9df 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt
@@ -22,9 +22,11 @@ import android.util.Log
 import androidx.lifecycle.viewModelScope
 import com.google.ai.edge.gallery.data.ConfigKey
 import com.google.ai.edge.gallery.data.Model
+import com.google.ai.edge.gallery.data.TASK_LLM_ASK_AUDIO
 import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
 import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
 import com.google.ai.edge.gallery.data.Task
+import com.google.ai.edge.gallery.ui.common.chat.ChatMessageAudioClip
 import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
 import com.google.ai.edge.gallery.ui.common.chat.ChatMessageLoading
 import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
@@ -52,6 +54,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
    model: Model,
    input: String,
    images: List<Bitmap> = listOf(),
+    audioMessages: List<ChatMessageAudioClip> = listOf(),
    onError: () -> Unit,
  ) {
    val accelerator = model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = "")
@@ -72,6 +75,11 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
      val instance = model.instance as LlmModelInstance
      var prefillTokens = instance.session.sizeInTokens(input)
      prefillTokens += images.size * 257
+      for (audioMessage in audioMessages) {
+        // Each 150 ms of audio counts as one token.
+        val duration = audioMessage.getDurationInSeconds()
+        prefillTokens += (duration * 1000f / 150f).toInt()
+      }
      var firstRun = true
      var timeToFirstToken = 0f
@@ -86,6 +94,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
        model = model,
        input = input,
        images = images,
+        audioClips = audioMessages.map { it.genByteArrayForWav() },
        resultListener = { partialResult, done ->
          val curTs = System.currentTimeMillis()
@@ -214,7 +223,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
    context: Context,
    model: Model,
    modelManagerViewModel: ModelManagerViewModel,
-    triggeredMessage: ChatMessageText,
+    triggeredMessage: ChatMessageText?,
  ) {
    // Clean up.
    modelManagerViewModel.cleanupModel(task = task, model = model)
@@ -236,14 +245,20 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
    )
    // Add the triggered message back.
-    addMessage(model = model, message = triggeredMessage)
+    if (triggeredMessage != null) {
+      addMessage(model = model, message = triggeredMessage)
+    }
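
In generateResponse above, each clip is charged one prefill token per 150 ms of audio, on top
of a flat 257 tokens per image. A compact sketch of that estimate (the helper name is
hypothetical; the constants come from the code above):

  fun estimatePrefillTokens(textTokens: Int, imageCount: Int, clipSeconds: List<Float>): Int =
    // e.g. a 30-second clip adds 30_000 ms / 150 ms = 200 tokens.
    textTokens + imageCount * 257 + clipSeconds.sumOf { s -> (s * 1000f / 150f).toInt() }

    // Re-initialize the session/engine.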
modelManagerViewModel.initializeModel(context = context, task = task, model = model) // Re-generate the response automatically. - generateResponse(model = model, input = triggeredMessage.content, onError = {}) + if (triggeredMessage != null) { + generateResponse(model = model, input = triggeredMessage.content, onError = {}) + } } } class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE) + +class LlmAskAudioViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_AUDIO) diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt index f73b893..6e2fb9e 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt @@ -36,6 +36,7 @@ import com.google.ai.edge.gallery.data.ModelAllowlist import com.google.ai.edge.gallery.data.ModelDownloadStatus import com.google.ai.edge.gallery.data.ModelDownloadStatusType import com.google.ai.edge.gallery.data.TASKS +import com.google.ai.edge.gallery.data.TASK_LLM_ASK_AUDIO import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE import com.google.ai.edge.gallery.data.TASK_LLM_CHAT import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB @@ -281,15 +282,12 @@ open class ModelManagerViewModel( } } when (task.type) { - TaskType.LLM_CHAT -> - LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone) - + TaskType.LLM_CHAT, + TaskType.LLM_ASK_IMAGE, + TaskType.LLM_ASK_AUDIO, TaskType.LLM_PROMPT_LAB -> LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone) - TaskType.LLM_ASK_IMAGE -> - LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone) - TaskType.TEST_TASK_1 -> {} TaskType.TEST_TASK_2 -> {} } @@ -301,9 +299,11 @@ open class ModelManagerViewModel( model.cleanUpAfterInit = false Log.d(TAG, "Cleaning up model '${model.name}'...") when (task.type) { - TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model) - TaskType.LLM_PROMPT_LAB -> LlmChatModelHelper.cleanUp(model = model) - TaskType.LLM_ASK_IMAGE -> LlmChatModelHelper.cleanUp(model = model) + TaskType.LLM_CHAT, + TaskType.LLM_PROMPT_LAB, + TaskType.LLM_ASK_IMAGE, + TaskType.LLM_ASK_AUDIO -> LlmChatModelHelper.cleanUp(model = model) + TaskType.TEST_TASK_1 -> {} TaskType.TEST_TASK_2 -> {} } @@ -410,14 +410,19 @@ open class ModelManagerViewModel( // Create model. val model = createModelFromImportedModelInfo(info = info) - for (task in listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)) { + for (task in + listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_ASK_AUDIO, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)) { // Remove duplicated imported model if existed. val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported } if (modelIndex >= 0) { Log.d(TAG, "duplicated imported model found in task. 
Removing it first") task.models.removeAt(modelIndex) } - if ((task == TASK_LLM_ASK_IMAGE && model.llmSupportImage) || task != TASK_LLM_ASK_IMAGE) { + if ( + (task == TASK_LLM_ASK_IMAGE && model.llmSupportImage) || + (task == TASK_LLM_ASK_AUDIO && model.llmSupportAudio) || + (task != TASK_LLM_ASK_IMAGE && task != TASK_LLM_ASK_AUDIO) + ) { task.models.add(model) } task.updateTrigger.value = System.currentTimeMillis() @@ -657,6 +662,7 @@ open class ModelManagerViewModel( TASK_LLM_CHAT.models.clear() TASK_LLM_PROMPT_LAB.models.clear() TASK_LLM_ASK_IMAGE.models.clear() + TASK_LLM_ASK_AUDIO.models.clear() for (allowedModel in modelAllowlist.models) { if (allowedModel.disabled == true) { continue @@ -672,6 +678,9 @@ open class ModelManagerViewModel( if (allowedModel.taskTypes.contains(TASK_LLM_ASK_IMAGE.type.id)) { TASK_LLM_ASK_IMAGE.models.add(model) } + if (allowedModel.taskTypes.contains(TASK_LLM_ASK_AUDIO.type.id)) { + TASK_LLM_ASK_AUDIO.models.add(model) + } } // Pre-process all tasks. @@ -760,6 +769,9 @@ open class ModelManagerViewModel( if (model.llmSupportImage) { TASK_LLM_ASK_IMAGE.models.add(model) } + if (model.llmSupportAudio) { + TASK_LLM_ASK_AUDIO.models.add(model) + } // Update status. modelDownloadStatus[model.name] = @@ -800,6 +812,7 @@ open class ModelManagerViewModel( accelerators = accelerators, ) val llmSupportImage = info.llmConfig.supportImage + val llmSupportAudio = info.llmConfig.supportAudio val model = Model( name = info.fileName, @@ -811,6 +824,7 @@ open class ModelManagerViewModel( showRunAgainButton = false, imported = true, llmSupportImage = llmSupportImage, + llmSupportAudio = llmSupportAudio, ) model.preProcess() diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/navigation/GalleryNavGraph.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/navigation/GalleryNavGraph.kt index 5e8672d..138edc7 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/navigation/GalleryNavGraph.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/navigation/GalleryNavGraph.kt @@ -47,6 +47,7 @@ import androidx.navigation.compose.NavHost import androidx.navigation.compose.composable import androidx.navigation.navArgument import com.google.ai.edge.gallery.data.Model +import com.google.ai.edge.gallery.data.TASK_LLM_ASK_AUDIO import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE import com.google.ai.edge.gallery.data.TASK_LLM_CHAT import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB @@ -55,6 +56,8 @@ import com.google.ai.edge.gallery.data.TaskType import com.google.ai.edge.gallery.data.getModelByName import com.google.ai.edge.gallery.ui.ViewModelProvider import com.google.ai.edge.gallery.ui.home.HomeScreen +import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioDestination +import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioScreen import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageDestination import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageScreen import com.google.ai.edge.gallery.ui.llmchat.LlmChatDestination @@ -209,7 +212,7 @@ fun GalleryNavHost( } } - // LLM image to text. + // Ask image. composable( route = "${LlmAskImageDestination.route}/{modelName}", arguments = listOf(navArgument("modelName") { type = NavType.StringType }), @@ -225,6 +228,23 @@ fun GalleryNavHost( ) } } + + // Ask audio. 
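+  // The ask-audio registration below mirrors the ask-image route above; it is reached
+  // via "${LlmAskAudioDestination.route}/{modelName}" from navigateToTaskScreen.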
+ composable( + route = "${LlmAskAudioDestination.route}/{modelName}", + arguments = listOf(navArgument("modelName") { type = NavType.StringType }), + enterTransition = { slideEnter() }, + exitTransition = { slideExit() }, + ) { + getModelFromNavigationParam(it, TASK_LLM_ASK_AUDIO)?.let { defaultModel -> + modelManagerViewModel.selectModel(defaultModel) + + LlmAskAudioScreen( + modelManagerViewModel = modelManagerViewModel, + navigateUp = { navController.navigateUp() }, + ) + } + } } // Handle incoming intents for deep links @@ -256,6 +276,7 @@ fun navigateToTaskScreen( when (taskType) { TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}") TaskType.LLM_ASK_IMAGE -> navController.navigate("${LlmAskImageDestination.route}/${modelName}") + TaskType.LLM_ASK_AUDIO -> navController.navigate("${LlmAskAudioDestination.route}/${modelName}") TaskType.LLM_PROMPT_LAB -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}") TaskType.TEST_TASK_1 -> {} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/theme/Theme.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/theme/Theme.kt index 72ad1d5..a1a7ad1 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/theme/Theme.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/theme/Theme.kt @@ -120,6 +120,8 @@ data class CustomColors( val agentBubbleBgColor: Color = Color.Transparent, val linkColor: Color = Color.Transparent, val successColor: Color = Color.Transparent, + val recordButtonBgColor: Color = Color.Transparent, + val waveFormBgColor: Color = Color.Transparent, ) val LocalCustomColors = staticCompositionLocalOf { CustomColors() } @@ -145,6 +147,8 @@ val lightCustomColors = userBubbleBgColor = Color(0xFF32628D), linkColor = Color(0xFF32628D), successColor = Color(0xff3d860b), + recordButtonBgColor = Color(0xFFEE675C), + waveFormBgColor = Color(0xFFaaaaaa), ) val darkCustomColors = @@ -168,6 +172,8 @@ val darkCustomColors = userBubbleBgColor = Color(0xFF1f3760), linkColor = Color(0xFF9DCAFC), successColor = Color(0xFFA1CE83), + recordButtonBgColor = Color(0xFFEE675C), + waveFormBgColor = Color(0xFFaaaaaa), ) val MaterialTheme.customColors: CustomColors diff --git a/Android/src/app/src/main/proto/settings.proto b/Android/src/app/src/main/proto/settings.proto index 5540eaf..3b5f6f3 100644 --- a/Android/src/app/src/main/proto/settings.proto +++ b/Android/src/app/src/main/proto/settings.proto @@ -55,6 +55,7 @@ message LlmConfig { float default_topp = 4; float default_temperature = 5; bool support_image = 6; + bool support_audio = 7; } message Settings {
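
The view model changes above rely on two ChatMessageAudioClip helpers, getDurationInSeconds()
and genByteArrayForWav(). A minimal sketch of equivalent logic, written as standalone
functions and assuming audioData holds raw 16-bit mono PCM (2 bytes per sample):

  fun getDurationInSeconds(audioData: ByteArray, sampleRate: Int): Float =
    (audioData.size / 2f) / sampleRate

  // Wraps raw PCM bytes in a canonical 44-byte WAV header (PCM, mono, 16-bit).
  fun genByteArrayForWav(audioData: ByteArray, sampleRate: Int): ByteArray {
    val header =
      java.nio.ByteBuffer.allocate(44).order(java.nio.ByteOrder.LITTLE_ENDIAN).apply {
        put("RIFF".toByteArray())
        putInt(36 + audioData.size) // RIFF chunk size
        put("WAVE".toByteArray())
        put("fmt ".toByteArray())
        putInt(16) // fmt chunk size for PCM
        putShort(1) // audio format: PCM
        putShort(1) // channels: mono
        putInt(sampleRate)
        putInt(sampleRate * 2) // byte rate = sampleRate * channels * bytesPerSample
        putShort(2) // block align = channels * bytesPerSample
        putShort(16) // bits per sample
        put("data".toByteArray())
        putInt(audioData.size) // data chunk size
      }
    return header.array() + audioData
  }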