Mirror of https://github.com/google-ai-edge/gallery.git (synced 2025-07-05 14:10:35 -04:00)

commit d0989adce1 (parent 33c3ee638e)

Add audio support.

- Add a new task "audio scribe".
- Allow users to record audio clips or pick .wav files to interact with the model.
- Add support for importing models with audio capability.
- Fix a typo in the Settings dialog (thanks, https://github.com/rhnvrm!).

PiperOrigin-RevId: 774832681

27 changed files with 1369 additions and 288 deletions
@@ -29,6 +29,7 @@
<uses-permission android:name="android.permission.FOREGROUND_SERVICE_DATA_SYNC"/>
<uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.POST_NOTIFICATIONS" />
<uses-permission android:name="android.permission.RECORD_AUDIO" />
<uses-permission android:name="android.permission.WAKE_LOCK"/>

<uses-feature
@@ -25,3 +25,5 @@ interface LatencyProvider {
data class Classification(val label: String, val score: Float, val color: Color)

data class JsonObjAndTextContent<T>(val jsonObj: T, val textContent: String)

class AudioClip(val audioData: ByteArray, val sampleRate: Int)
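AudioClip wraps raw 16-bit little-endian mono PCM samples together with the sample rate they were captured at. As an illustrative sketch (the helper below is not part of the commit), the clip duration follows directly from the byte count, mirroring the arithmetic used later in ChatMessageAudioClip.getDurationInSeconds:

import com.google.ai.edge.gallery.common.AudioClip

// Illustrative helper, not part of this commit: 2 bytes per 16-bit sample, one channel.
fun durationSeconds(clip: AudioClip): Float =
  (clip.audioData.size / 2).toFloat() / clip.sampleRate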
@@ -17,12 +17,17 @@
package com.google.ai.edge.gallery.common

import android.content.Context
import android.net.Uri
import android.util.Log
import com.google.ai.edge.gallery.data.SAMPLE_RATE
import com.google.gson.Gson
import com.google.gson.reflect.TypeToken
import java.io.File
import java.net.HttpURLConnection
import java.net.URL
import java.nio.ByteBuffer
import java.nio.ByteOrder
import kotlin.math.floor

data class LaunchInfo(val ts: Long)

@@ -112,3 +117,135 @@ inline fun <reified T> getJsonResponse(url: String): JsonObjAndTextContent<T>? {
|
||||
return null
|
||||
}
|
||||
|
||||
fun convertWavToMonoWithMaxSeconds(
|
||||
context: Context,
|
||||
stereoUri: Uri,
|
||||
maxSeconds: Int = 30,
|
||||
): AudioClip? {
|
||||
Log.d(TAG, "Start to convert wav file to mono channel")
|
||||
|
||||
try {
|
||||
val inputStream = context.contentResolver.openInputStream(stereoUri) ?: return null
|
||||
val originalBytes = inputStream.readBytes()
|
||||
inputStream.close()
|
||||
|
||||
// Read WAV header
|
||||
if (originalBytes.size < 44) {
|
||||
// Not a valid WAV file
|
||||
Log.e(TAG, "Not a valid wav file")
|
||||
return null
|
||||
}
|
||||
|
||||
val headerBuffer = ByteBuffer.wrap(originalBytes, 0, 44).order(ByteOrder.LITTLE_ENDIAN)
|
||||
val channels = headerBuffer.getShort(22)
|
||||
var sampleRate = headerBuffer.getInt(24)
|
||||
val bitDepth = headerBuffer.getShort(34)
|
||||
Log.d(TAG, "File metadata: channels: $channels, sampleRate: $sampleRate, bitDepth: $bitDepth")
|
||||
|
||||
// Normalize audio to 16-bit.
|
||||
val audioDataBytes = originalBytes.copyOfRange(fromIndex = 44, toIndex = originalBytes.size)
|
||||
var sixteenBitBytes: ByteArray =
|
||||
if (bitDepth.toInt() == 8) {
|
||||
Log.d(TAG, "Converting 8-bit audio to 16-bit.")
|
||||
convert8BitTo16Bit(audioDataBytes)
|
||||
} else {
|
||||
// Assume 16-bit or other format that can be handled directly
|
||||
audioDataBytes
|
||||
}
|
||||
|
||||
// Convert byte array to short array for processing
|
||||
val shortBuffer =
|
||||
ByteBuffer.wrap(sixteenBitBytes).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer()
|
||||
var pcmSamples = ShortArray(shortBuffer.remaining())
|
||||
shortBuffer.get(pcmSamples)
|
||||
|
||||
// Resample if sample rate is less than 16000 Hz ---
|
||||
if (sampleRate < SAMPLE_RATE) {
|
||||
Log.d(TAG, "Resampling from $sampleRate Hz to $SAMPLE_RATE Hz.")
|
||||
pcmSamples = resample(pcmSamples, sampleRate, SAMPLE_RATE, channels.toInt())
|
||||
sampleRate = SAMPLE_RATE
|
||||
Log.d(TAG, "Resampling complete. New sample count: ${pcmSamples.size}")
|
||||
}
|
||||
|
||||
// Convert stereo to mono if necessary
|
||||
var monoSamples =
|
||||
if (channels.toInt() == 2) {
|
||||
Log.d(TAG, "Converting stereo to mono.")
|
||||
val mono = ShortArray(pcmSamples.size / 2)
|
||||
for (i in mono.indices) {
|
||||
val left = pcmSamples[i * 2]
|
||||
val right = pcmSamples[i * 2 + 1]
|
||||
mono[i] = ((left + right) / 2).toShort()
|
||||
}
|
||||
mono
|
||||
} else {
|
||||
Log.d(TAG, "Audio is already mono. No channel conversion needed.")
|
||||
pcmSamples
|
||||
}
|
||||
|
||||
// Trim the audio to maxSeconds ---
|
||||
val maxSamples = maxSeconds * sampleRate
|
||||
if (monoSamples.size > maxSamples) {
|
||||
Log.d(TAG, "Trimming clip from ${monoSamples.size} samples to $maxSamples samples.")
|
||||
monoSamples = monoSamples.copyOfRange(0, maxSamples)
|
||||
}
|
||||
|
||||
val monoByteBuffer = ByteBuffer.allocate(monoSamples.size * 2).order(ByteOrder.LITTLE_ENDIAN)
|
||||
monoByteBuffer.asShortBuffer().put(monoSamples)
|
||||
return AudioClip(audioData = monoByteBuffer.array(), sampleRate = sampleRate)
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Failed to convert wav to mono", e)
|
||||
return null
|
||||
}
|
||||
}
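The fixed offsets read above (22, 24, 34) and the 44-byte data offset assume the canonical single-"fmt "-chunk PCM WAV layout; files carrying extra chunks (e.g. LIST or fact) place the data chunk elsewhere and would not parse this way. For reference, a sketch of the fields the parser relies on (the constant names below are illustrative, not part of the commit):

// Offsets in the canonical 44-byte PCM WAV header (all multi-byte fields little-endian).
private const val WAV_OFFSET_NUM_CHANNELS = 22    // UInt16: 1 = mono, 2 = stereo
private const val WAV_OFFSET_SAMPLE_RATE = 24     // UInt32: frames per second
private const val WAV_OFFSET_BITS_PER_SAMPLE = 34 // UInt16: 8 or 16 in this code path
private const val WAV_HEADER_SIZE = 44            // PCM sample data starts here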
|
||||
|
||||
/** Converts 8-bit unsigned PCM audio data to 16-bit signed PCM. */
|
||||
private fun convert8BitTo16Bit(eightBitData: ByteArray): ByteArray {
|
||||
// The new 16-bit data will be twice the size
|
||||
val sixteenBitData = ByteArray(eightBitData.size * 2)
|
||||
val buffer = ByteBuffer.wrap(sixteenBitData).order(ByteOrder.LITTLE_ENDIAN)
|
||||
|
||||
for (byte in eightBitData) {
|
||||
// Convert the unsigned 8-bit byte (0-255) to a signed 16-bit short (-32768 to 32767)
|
||||
// 1. Get the unsigned value by masking with 0xFF
|
||||
// 2. Subtract 128 to center the waveform around 0 (range becomes -128 to 127)
|
||||
// 3. Scale by 256 to expand to the 16-bit range
|
||||
val unsignedByte = byte.toInt() and 0xFF
|
||||
val sixteenBitSample = ((unsignedByte - 128) * 256).toShort()
|
||||
buffer.putShort(sixteenBitSample)
|
||||
}
|
||||
return sixteenBitData
|
||||
}
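The mapping sends the unsigned 8-bit range onto the signed 16-bit range: the 8-bit silence level 128 goes to 0, while 0 and 255 land at -32768 and 32512 respectively. A small self-contained check of that arithmetic (illustrative only, not part of the commit):

import kotlin.test.Test // or run the checks in a plain main()

// Illustrative check of the 8-bit -> 16-bit mapping used above.
fun to16Bit(unsignedByte: Int): Short = ((unsignedByte - 128) * 256).toShort()

fun main() {
  check(to16Bit(0) == (-32768).toShort())  // negative peak
  check(to16Bit(128) == 0.toShort())       // 8-bit silence level maps to zero
  check(to16Bit(255) == 32512.toShort())   // positive peak: (255 - 128) * 256
}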
|
||||
|
||||
/** Resamples PCM audio data from an original sample rate to a target sample rate. */
|
||||
private fun resample(
|
||||
inputSamples: ShortArray,
|
||||
originalSampleRate: Int,
|
||||
targetSampleRate: Int,
|
||||
channels: Int,
|
||||
): ShortArray {
|
||||
if (originalSampleRate == targetSampleRate) {
|
||||
return inputSamples
|
||||
}
|
||||
|
||||
val ratio = targetSampleRate.toDouble() / originalSampleRate
|
||||
val outputLength = (inputSamples.size * ratio).toInt()
|
||||
val resampledData = ShortArray(outputLength)
|
||||
|
||||
if (channels == 1) { // Mono
|
||||
for (i in resampledData.indices) {
|
||||
val position = i / ratio
|
||||
val index1 = floor(position).toInt()
|
||||
val index2 = index1 + 1
|
||||
val fraction = position - index1
|
||||
|
||||
val sample1 = if (index1 < inputSamples.size) inputSamples[index1].toDouble() else 0.0
|
||||
val sample2 = if (index2 < inputSamples.size) inputSamples[index2].toDouble() else 0.0
|
||||
|
||||
resampledData[i] = (sample1 * (1 - fraction) + sample2 * fraction).toInt().toShort()
|
||||
}
|
||||
}
|
||||
|
||||
return resampledData
|
||||
}
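A sketch of how a caller might use the new utility off the main thread; this mirrors the MessageInputText change later in this commit, but the wrapper function itself is an illustration, not part of the commit:

import android.content.Context
import android.net.Uri
import com.google.ai.edge.gallery.common.AudioClip
import com.google.ai.edge.gallery.common.convertWavToMonoWithMaxSeconds
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext

// Illustrative wrapper: decode a picked .wav Uri into a mono 16-bit clip capped
// at 30 seconds, or null if the file cannot be parsed.
suspend fun loadPickedWav(context: Context, uri: Uri): AudioClip? =
  withContext(Dispatchers.IO) {
    convertWavToMonoWithMaxSeconds(context = context, stereoUri = uri, maxSeconds = 30)
  }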
@@ -50,6 +50,7 @@ enum class ConfigKey(val label: String) {
DEFAULT_TOPP("Default TopP"),
DEFAULT_TEMPERATURE("Default temperature"),
SUPPORT_IMAGE("Support image"),
SUPPORT_AUDIO("Support audio"),
MAX_RESULT_COUNT("Max result count"),
USE_GPU("Use GPU"),
ACCELERATOR("Choose accelerator"),
@@ -44,3 +44,12 @@ val DEFAULT_ACCELERATORS = listOf(Accelerator.GPU)

// Max number of images allowed in an "ask image" session.
const val MAX_IMAGE_COUNT = 10

// Max number of audio clips in an "ask audio" session.
const val MAX_AUDIO_CLIP_COUNT = 10

// Max audio clip duration in seconds.
const val MAX_AUDIO_CLIP_DURATION_SEC = 30

// Audio-recording related consts.
const val SAMPLE_RATE = 16000
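Taken together these constants bound the memory a single clip can occupy: 16 kHz, 16-bit mono PCM is 32,000 bytes per second, so a 30-second clip tops out at 960,000 bytes. A quick sketch of that arithmetic (the names below are illustrative, not part of the commit):

// Illustrative arithmetic only.
const val BYTES_PER_SAMPLE = 2 // 16-bit PCM
val maxClipBytes = SAMPLE_RATE * BYTES_PER_SAMPLE * MAX_AUDIO_CLIP_DURATION_SEC
// 16000 * 2 * 30 = 960_000 bytes (~0.92 MiB) per clip, or roughly 9.2 MiB when
// MAX_AUDIO_CLIP_COUNT (10) clips are attached to one session.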
@@ -87,6 +87,9 @@ data class Model(
/** Whether the LLM model supports image input. */
val llmSupportImage: Boolean = false,

/** Whether the LLM model supports audio input. */
val llmSupportAudio: Boolean = false,

/** Whether the model is imported or not. */
val imported: Boolean = false,

@@ -38,6 +38,7 @@ data class AllowedModel(
val taskTypes: List<String>,
val disabled: Boolean? = null,
val llmSupportImage: Boolean? = null,
val llmSupportAudio: Boolean? = null,
val estimatedPeakMemoryInBytes: Long? = null,
) {
fun toModel(): Model {
@@ -96,6 +97,7 @@ data class AllowedModel(
showRunAgainButton = showRunAgainButton,
learnMoreUrl = "https://huggingface.co/${modelId}",
llmSupportImage = llmSupportImage == true,
llmSupportAudio = llmSupportAudio == true,
)
}

@@ -17,19 +17,22 @@
package com.google.ai.edge.gallery.data
|
||||
|
||||
import androidx.annotation.StringRes
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.outlined.Forum
|
||||
import androidx.compose.material.icons.outlined.Mic
|
||||
import androidx.compose.material.icons.outlined.Mms
|
||||
import androidx.compose.material.icons.outlined.Widgets
|
||||
import androidx.compose.runtime.MutableState
|
||||
import androidx.compose.runtime.mutableLongStateOf
|
||||
import androidx.compose.ui.graphics.vector.ImageVector
|
||||
import com.google.ai.edge.gallery.R
|
||||
import com.google.ai.edge.gallery.ui.icon.Forum
|
||||
import com.google.ai.edge.gallery.ui.icon.Mms
|
||||
import com.google.ai.edge.gallery.ui.icon.Widgets
|
||||
|
||||
/** Type of task. */
|
||||
enum class TaskType(val label: String, val id: String) {
|
||||
LLM_CHAT(label = "AI Chat", id = "llm_chat"),
|
||||
LLM_PROMPT_LAB(label = "Prompt Lab", id = "llm_prompt_lab"),
|
||||
LLM_ASK_IMAGE(label = "Ask Image", id = "llm_ask_image"),
|
||||
LLM_ASK_AUDIO(label = "Audio Scribe", id = "llm_ask_audio"),
|
||||
TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
|
||||
TEST_TASK_2(label = "Test task 2", id = "test_task_2"),
|
||||
}
|
||||
|
@@ -71,7 +74,7 @@ data class Task(
val TASK_LLM_CHAT =
|
||||
Task(
|
||||
type = TaskType.LLM_CHAT,
|
||||
icon = Forum,
|
||||
icon = Icons.Outlined.Forum,
|
||||
models = mutableListOf(),
|
||||
description = "Chat with on-device large language models",
|
||||
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
||||
|
@@ -83,7 +86,7 @@ val TASK_LLM_CHAT =
val TASK_LLM_PROMPT_LAB =
|
||||
Task(
|
||||
type = TaskType.LLM_PROMPT_LAB,
|
||||
icon = Widgets,
|
||||
icon = Icons.Outlined.Widgets,
|
||||
models = mutableListOf(),
|
||||
description = "Single turn use cases with on-device large language model",
|
||||
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
||||
|
@@ -95,7 +98,7 @@ val TASK_LLM_PROMPT_LAB =
val TASK_LLM_ASK_IMAGE =
|
||||
Task(
|
||||
type = TaskType.LLM_ASK_IMAGE,
|
||||
icon = Mms,
|
||||
icon = Icons.Outlined.Mms,
|
||||
models = mutableListOf(),
|
||||
description = "Ask questions about images with on-device large language models",
|
||||
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
||||
|
@@ -104,8 +107,23 @@ val TASK_LLM_ASK_IMAGE =
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
|
||||
)
|
||||
|
||||
val TASK_LLM_ASK_AUDIO =
|
||||
Task(
|
||||
type = TaskType.LLM_ASK_AUDIO,
|
||||
icon = Icons.Outlined.Mic,
|
||||
models = mutableListOf(),
|
||||
// TODO(do not submit)
|
||||
description =
|
||||
"Instantly transcribe and/or translate audio clips using on-device large language models",
|
||||
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
|
||||
sourceCodeUrl =
|
||||
"https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt",
|
||||
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
|
||||
)
|
||||
|
||||
/** All tasks. */
|
||||
val TASKS: List<Task> = listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)
|
||||
val TASKS: List<Task> =
|
||||
listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_ASK_AUDIO, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)
|
||||
|
||||
fun getModelByName(name: String): Model? {
|
||||
for (task in TASKS) {
|
||||
@@ -21,6 +21,7 @@ import androidx.lifecycle.viewmodel.CreationExtras
import androidx.lifecycle.viewmodel.initializer
import androidx.lifecycle.viewmodel.viewModelFactory
import com.google.ai.edge.gallery.GalleryApplication
import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
@@ -49,6 +50,9 @@ object ViewModelProvider {

// Initializer for LlmAskImageViewModel.
initializer { LlmAskImageViewModel() }

// Initializer for LlmAskAudioViewModel.
initializer { LlmAskAudioViewModel() }
}
}

@@ -0,0 +1,344 @@
/*
|
||||
* Copyright 2025 Google LLC
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.google.ai.edge.gallery.ui.common.chat
|
||||
|
||||
import android.media.AudioAttributes
|
||||
import android.media.AudioFormat
|
||||
import android.media.AudioTrack
|
||||
import android.util.Log
|
||||
import androidx.compose.foundation.Canvas
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.height
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.layout.width
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.rounded.PlayArrow
|
||||
import androidx.compose.material.icons.rounded.Stop
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.IconButton
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.DisposableEffect
|
||||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.MutableState
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableFloatStateOf
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.rememberCoroutineScope
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.geometry.CornerRadius
|
||||
import androidx.compose.ui.geometry.Offset
|
||||
import androidx.compose.ui.geometry.Size
|
||||
import androidx.compose.ui.geometry.toRect
|
||||
import androidx.compose.ui.graphics.BlendMode
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.graphics.drawscope.drawIntoCanvas
|
||||
import androidx.compose.ui.unit.dp
|
||||
import com.google.ai.edge.gallery.data.MAX_AUDIO_CLIP_DURATION_SEC
|
||||
import com.google.ai.edge.gallery.ui.theme.customColors
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.ByteOrder
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.isActive
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.withContext
|
||||
|
||||
private const val TAG = "AGAudioPlaybackPanel"
|
||||
private const val BAR_SPACE = 2
|
||||
private const val BAR_WIDTH = 2
|
||||
private const val MIN_BAR_COUNT = 16
|
||||
private const val MAX_BAR_COUNT = 48
|
||||
|
||||
/**
|
||||
* A Composable that displays an audio playback panel, including play/stop controls, a waveform
|
||||
* visualization, and the duration of the audio clip.
|
||||
*/
|
||||
@Composable
|
||||
fun AudioPlaybackPanel(
|
||||
audioData: ByteArray,
|
||||
sampleRate: Int,
|
||||
isRecording: Boolean,
|
||||
modifier: Modifier = Modifier,
|
||||
onDarkBg: Boolean = false,
|
||||
) {
|
||||
val coroutineScope = rememberCoroutineScope()
|
||||
var isPlaying by remember { mutableStateOf(false) }
|
||||
val audioTrackState = remember { mutableStateOf<AudioTrack?>(null) }
|
||||
val durationInSeconds =
|
||||
remember(audioData) {
|
||||
// PCM 16-bit
|
||||
val bytesPerSample = 2
|
||||
val bytesPerFrame = bytesPerSample * 1 // mono
|
||||
val totalFrames = audioData.size.toDouble() / bytesPerFrame
|
||||
totalFrames / sampleRate
|
||||
}
|
||||
val barCount =
|
||||
remember(durationInSeconds) {
|
||||
val f = durationInSeconds / MAX_AUDIO_CLIP_DURATION_SEC
|
||||
((MAX_BAR_COUNT - MIN_BAR_COUNT) * f + MIN_BAR_COUNT).toInt()
|
||||
}
|
||||
val amplitudeLevels =
|
||||
remember(audioData) { generateAmplitudeLevels(audioData = audioData, barCount = barCount) }
|
||||
var playbackProgress by remember { mutableFloatStateOf(0f) }
|
||||
|
||||
// Reset when a new recording is started.
|
||||
LaunchedEffect(isRecording) {
|
||||
if (isRecording) {
|
||||
val audioTrack = audioTrackState.value
|
||||
audioTrack?.stop()
|
||||
isPlaying = false
|
||||
playbackProgress = 0f
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup on Composable Disposal.
|
||||
DisposableEffect(Unit) {
|
||||
onDispose {
|
||||
val audioTrack = audioTrackState.value
|
||||
audioTrack?.stop()
|
||||
audioTrack?.release()
|
||||
}
|
||||
}
|
||||
|
||||
Row(verticalAlignment = Alignment.CenterVertically, modifier = modifier) {
|
||||
// Button to play/stop the clip.
|
||||
IconButton(
|
||||
onClick = {
|
||||
coroutineScope.launch {
|
||||
if (!isPlaying) {
|
||||
isPlaying = true
|
||||
playAudio(
|
||||
audioTrackState = audioTrackState,
|
||||
audioData = audioData,
|
||||
sampleRate = sampleRate,
|
||||
onProgress = { playbackProgress = it },
|
||||
onCompletion = {
|
||||
playbackProgress = 0f
|
||||
isPlaying = false
|
||||
},
|
||||
)
|
||||
} else {
|
||||
stopPlayAudio(audioTrackState = audioTrackState)
|
||||
playbackProgress = 0f
|
||||
isPlaying = false
|
||||
}
|
||||
}
|
||||
}
|
||||
) {
|
||||
Icon(
|
||||
if (isPlaying) Icons.Rounded.Stop else Icons.Rounded.PlayArrow,
|
||||
contentDescription = "",
|
||||
tint = if (onDarkBg) Color.White else MaterialTheme.colorScheme.primary,
|
||||
)
|
||||
}
|
||||
|
||||
// Visualization
|
||||
AmplitudeBarGraph(
|
||||
amplitudeLevels = amplitudeLevels,
|
||||
progress = playbackProgress,
|
||||
modifier =
|
||||
Modifier.width((barCount * BAR_WIDTH + (barCount - 1) * BAR_SPACE).dp).height(24.dp),
|
||||
onDarkBg = onDarkBg,
|
||||
)
|
||||
|
||||
// Duration
|
||||
Text(
|
||||
"${"%.1f".format(durationInSeconds)}s",
|
||||
style = MaterialTheme.typography.labelLarge,
|
||||
color = if (onDarkBg) Color.White else MaterialTheme.colorScheme.primary,
|
||||
modifier = Modifier.padding(start = 12.dp),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
private fun AmplitudeBarGraph(
|
||||
amplitudeLevels: List<Float>,
|
||||
progress: Float,
|
||||
modifier: Modifier = Modifier,
|
||||
onDarkBg: Boolean = false,
|
||||
) {
|
||||
val barColor = MaterialTheme.customColors.waveFormBgColor
|
||||
val progressColor = if (onDarkBg) Color.White else MaterialTheme.colorScheme.primary
|
||||
|
||||
Canvas(modifier = modifier) {
|
||||
val barCount = amplitudeLevels.size
|
||||
val barWidth = (size.width - BAR_SPACE.dp.toPx() * (barCount - 1)) / barCount
|
||||
val cornerRadius = CornerRadius(x = barWidth, y = barWidth)
|
||||
|
||||
// Use drawIntoCanvas for advanced blend mode operations
|
||||
drawIntoCanvas { canvas ->
|
||||
|
||||
// 1. Save the current state of the canvas onto a temporary, offscreen layer
|
||||
canvas.saveLayer(size.toRect(), androidx.compose.ui.graphics.Paint())
|
||||
|
||||
// 2. Draw the bars in grey.
|
||||
amplitudeLevels.forEachIndexed { index, level ->
|
||||
val barHeight = (level * size.height).coerceAtLeast(1.5f)
|
||||
val left = index * (barWidth + BAR_SPACE.dp.toPx())
|
||||
drawRoundRect(
|
||||
color = barColor,
|
||||
topLeft = Offset(x = left, y = size.height / 2 - barHeight / 2),
|
||||
size = Size(barWidth, barHeight),
|
||||
cornerRadius = cornerRadius,
|
||||
)
|
||||
}
|
||||
|
||||
// 3. Draw the progress rectangle using BlendMode.SrcIn to only draw where the bars already
|
||||
// exists.
|
||||
val progressWidth = size.width * progress
|
||||
drawRect(
|
||||
color = progressColor,
|
||||
topLeft = Offset.Zero,
|
||||
size = Size(progressWidth, size.height),
|
||||
blendMode = BlendMode.SrcIn,
|
||||
)
|
||||
|
||||
// 4. Restore the layer, merging it onto the main canvas
|
||||
canvas.restore()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun playAudio(
|
||||
audioTrackState: MutableState<AudioTrack?>,
|
||||
audioData: ByteArray,
|
||||
sampleRate: Int,
|
||||
onProgress: (Float) -> Unit,
|
||||
onCompletion: () -> Unit,
|
||||
) {
|
||||
Log.d(TAG, "Start playing audio...")
|
||||
|
||||
try {
|
||||
withContext(Dispatchers.IO) {
|
||||
var lastProgressUpdateMs = 0L
|
||||
audioTrackState.value?.release()
|
||||
val audioTrack =
|
||||
AudioTrack.Builder()
|
||||
.setAudioAttributes(
|
||||
AudioAttributes.Builder()
|
||||
.setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
|
||||
.setUsage(AudioAttributes.USAGE_MEDIA)
|
||||
.build()
|
||||
)
|
||||
.setAudioFormat(
|
||||
AudioFormat.Builder()
|
||||
.setEncoding(AudioFormat.ENCODING_PCM_16BIT)
|
||||
.setSampleRate(sampleRate)
|
||||
.setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
|
||||
.build()
|
||||
)
|
||||
.setTransferMode(AudioTrack.MODE_STATIC)
|
||||
.setBufferSizeInBytes(audioData.size)
|
||||
.build()
|
||||
|
||||
val bytesPerFrame = 2 // For PCM 16-bit Mono
|
||||
val totalFrames = audioData.size / bytesPerFrame
|
||||
|
||||
audioTrackState.value = audioTrack
|
||||
audioTrack.write(audioData, 0, audioData.size)
|
||||
audioTrack.play()
|
||||
|
||||
// Coroutine to monitor progress
|
||||
while (isActive && audioTrack.playState == AudioTrack.PLAYSTATE_PLAYING) {
|
||||
val currentFrames = audioTrack.playbackHeadPosition
|
||||
if (currentFrames >= totalFrames) {
|
||||
break // Exit loop when playback is done
|
||||
}
|
||||
val progress = currentFrames.toFloat() / totalFrames
|
||||
val curMs = System.currentTimeMillis()
|
||||
if (curMs - lastProgressUpdateMs > 30) {
|
||||
onProgress(progress)
|
||||
lastProgressUpdateMs = curMs
|
||||
}
|
||||
}
|
||||
|
||||
if (isActive) {
|
||||
audioTrackState.value?.stop()
|
||||
}
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
// Ignore
|
||||
} finally {
|
||||
onProgress(1f)
|
||||
onCompletion()
|
||||
}
|
||||
}
|
||||
|
||||
private fun stopPlayAudio(audioTrackState: MutableState<AudioTrack?>) {
|
||||
Log.d(TAG, "Stopping playing audio...")
|
||||
|
||||
val audioTrack = audioTrackState.value
|
||||
audioTrack?.stop()
|
||||
audioTrack?.release()
|
||||
audioTrackState.value = null
|
||||
}
|
||||
|
||||
/**
|
||||
* Processes a raw PCM 16-bit audio byte array to generate a list of normalized amplitude levels for
|
||||
* visualization.
|
||||
*/
|
||||
private fun generateAmplitudeLevels(audioData: ByteArray, barCount: Int): List<Float> {
|
||||
if (audioData.isEmpty()) {
|
||||
return List(barCount) { 0f }
|
||||
}
|
||||
|
||||
// 1. Parse bytes into 16-bit short samples (PCM 16-bit)
|
||||
val shortBuffer = ByteBuffer.wrap(audioData).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer()
|
||||
val samples = ShortArray(shortBuffer.remaining())
|
||||
shortBuffer.get(samples)
|
||||
|
||||
if (samples.isEmpty()) {
|
||||
return List(barCount) { 0f }
|
||||
}
|
||||
|
||||
// 2. Determine the size of each chunk
|
||||
val chunkSize = samples.size / barCount
|
||||
val amplitudeLevels = mutableListOf<Float>()
|
||||
|
||||
// 3. Get the max value for each chunk
|
||||
for (i in 0 until barCount) {
|
||||
val chunkStart = i * chunkSize
|
||||
val chunkEnd = (chunkStart + chunkSize).coerceAtMost(samples.size)
|
||||
|
||||
var maxAmplitudeInChunk = 0.0
|
||||
|
||||
for (j in chunkStart until chunkEnd) {
|
||||
val sampleAbs = kotlin.math.abs(samples[j].toDouble())
|
||||
if (sampleAbs > maxAmplitudeInChunk) {
|
||||
maxAmplitudeInChunk = sampleAbs
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Normalize the value (0 to 1)
|
||||
// Short.MAX_VALUE is 32767.0, a good reference for max amplitude
|
||||
val normalizedRms = (maxAmplitudeInChunk / Short.MAX_VALUE).toFloat().coerceIn(0f, 1f)
|
||||
amplitudeLevels.add(normalizedRms)
|
||||
}
|
||||
|
||||
// Normalize the resulting levels so that the max value becomes 0.9.
|
||||
val maxVal = amplitudeLevels.max()
|
||||
if (maxVal == 0f) {
|
||||
return amplitudeLevels
|
||||
}
|
||||
val scaleFactor = 0.9f / maxVal
|
||||
return amplitudeLevels.map { it * scaleFactor }
|
||||
}
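The waveform width in this panel scales with clip length: durationInSeconds comes from the PCM byte count (size / 2 / sampleRate), and barCount interpolates linearly between MIN_BAR_COUNT and MAX_BAR_COUNT over the 30-second range. A worked sketch of the numbers, as comments only (illustrative, not part of the commit):

// Illustrative arithmetic for the waveform sizing above.
// A full 30 s clip: f = 30 / 30 = 1.0   -> barCount = (48 - 16) * 1.0 + 16 = 48
// A 7.5 s clip:     f = 7.5 / 30 = 0.25 -> barCount = (48 - 16) * 0.25 + 16 = 24
// Panel width in dp: barCount * BAR_WIDTH + (barCount - 1) * BAR_SPACE
//   48 bars -> 48 * 2 + 47 * 2 = 190.dp; 24 bars -> 24 * 2 + 23 * 2 = 94.dp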
@@ -0,0 +1,327 @@
/*
|
||||
* Copyright 2025 Google LLC
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.google.ai.edge.gallery.ui.common.chat
|
||||
|
||||
import android.annotation.SuppressLint
|
||||
import android.content.Context
|
||||
import android.media.AudioFormat
|
||||
import android.media.AudioRecord
|
||||
import android.media.MediaRecorder
|
||||
import android.util.Log
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.height
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.layout.size
|
||||
import androidx.compose.foundation.shape.CircleShape
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.rounded.ArrowUpward
|
||||
import androidx.compose.material.icons.rounded.Mic
|
||||
import androidx.compose.material.icons.rounded.Stop
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.IconButton
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.DisposableEffect
|
||||
import androidx.compose.runtime.MutableLongState
|
||||
import androidx.compose.runtime.MutableState
|
||||
import androidx.compose.runtime.derivedStateOf
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableIntStateOf
|
||||
import androidx.compose.runtime.mutableLongStateOf
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.rememberCoroutineScope
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.graphics.graphicsLayer
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.res.painterResource
|
||||
import androidx.compose.ui.unit.dp
|
||||
import com.google.ai.edge.gallery.R
|
||||
import com.google.ai.edge.gallery.data.MAX_AUDIO_CLIP_DURATION_SEC
|
||||
import com.google.ai.edge.gallery.data.SAMPLE_RATE
|
||||
import com.google.ai.edge.gallery.ui.theme.customColors
|
||||
import java.io.ByteArrayOutputStream
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.ByteOrder
|
||||
import kotlin.math.abs
|
||||
import kotlin.math.pow
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.coroutineScope
|
||||
import kotlinx.coroutines.launch
|
||||
|
||||
private const val TAG = "AGAudioRecorderPanel"
|
||||
|
||||
private const val CHANNEL_CONFIG = AudioFormat.CHANNEL_IN_MONO
|
||||
private const val AUDIO_FORMAT = AudioFormat.ENCODING_PCM_16BIT
|
||||
|
||||
/**
|
||||
* A Composable that provides an audio recording panel. It allows users to record audio clips,
|
||||
* displays the recording duration and a live amplitude visualization, and provides options to play
|
||||
* back the recorded clip or send it.
|
||||
*/
|
||||
@Composable
|
||||
fun AudioRecorderPanel(onSendAudioClip: (ByteArray) -> Unit) {
|
||||
val context = LocalContext.current
|
||||
val coroutineScope = rememberCoroutineScope()
|
||||
|
||||
var isRecording by remember { mutableStateOf(false) }
|
||||
val elapsedMs = remember { mutableLongStateOf(0L) }
|
||||
val audioRecordState = remember { mutableStateOf<AudioRecord?>(null) }
|
||||
val audioStream = remember { ByteArrayOutputStream() }
|
||||
val recordedBytes = remember { mutableStateOf<ByteArray?>(null) }
|
||||
var currentAmplitude by remember { mutableIntStateOf(0) }
|
||||
|
||||
val elapsedSeconds by remember {
|
||||
derivedStateOf { "%.1f".format(elapsedMs.value.toFloat() / 1000f) }
|
||||
}
|
||||
|
||||
// Cleanup on Composable Disposal.
|
||||
DisposableEffect(Unit) { onDispose { audioRecordState.value?.release() } }
|
||||
|
||||
Column(modifier = Modifier.padding(bottom = 12.dp)) {
|
||||
// Title bar.
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.SpaceBetween,
|
||||
) {
|
||||
// Logo and state.
|
||||
Row(
|
||||
modifier = Modifier.padding(start = 16.dp),
|
||||
horizontalArrangement = Arrangement.spacedBy(12.dp),
|
||||
) {
|
||||
Icon(
|
||||
painterResource(R.drawable.logo),
|
||||
modifier = Modifier.size(20.dp),
|
||||
contentDescription = "",
|
||||
tint = Color.Unspecified,
|
||||
)
|
||||
Text(
|
||||
"Record audio clip (up to $MAX_AUDIO_CLIP_DURATION_SEC seconds)",
|
||||
style = MaterialTheme.typography.labelLarge,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Recorded clip.
|
||||
Row(
|
||||
modifier = Modifier.fillMaxWidth().padding(vertical = 12.dp).height(40.dp),
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.Center,
|
||||
) {
|
||||
val curRecordedBytes = recordedBytes.value
|
||||
if (curRecordedBytes == null) {
|
||||
// Info message when there is no recorded clip and the recording has not started yet.
|
||||
if (!isRecording) {
|
||||
Text(
|
||||
"Tap the record button to start",
|
||||
style = MaterialTheme.typography.labelLarge,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant,
|
||||
)
|
||||
}
|
||||
// Visualization for clip being recorded.
|
||||
else {
|
||||
Row(
|
||||
horizontalArrangement = Arrangement.spacedBy(12.dp),
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
) {
|
||||
Box(
|
||||
modifier =
|
||||
Modifier.size(8.dp)
|
||||
.background(MaterialTheme.customColors.recordButtonBgColor, CircleShape)
|
||||
)
|
||||
Text("$elapsedSeconds s")
|
||||
}
|
||||
}
|
||||
}
|
||||
// Controls for recorded clip.
|
||||
else {
|
||||
Row {
|
||||
// Clip player.
|
||||
AudioPlaybackPanel(
|
||||
audioData = curRecordedBytes,
|
||||
sampleRate = SAMPLE_RATE,
|
||||
isRecording = isRecording,
|
||||
)
|
||||
|
||||
// Button to send the clip
|
||||
IconButton(onClick = { onSendAudioClip(curRecordedBytes) }) {
|
||||
Icon(Icons.Rounded.ArrowUpward, contentDescription = "")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Buttons
|
||||
Box(contentAlignment = Alignment.Center, modifier = Modifier.fillMaxWidth().height(40.dp)) {
|
||||
// Visualization of the current amplitude.
|
||||
if (isRecording) {
|
||||
// Normalize the amplitude (0-32767) to a fraction (0.0-1.0)
|
||||
// We use a power scale (exponent < 1) to make the pulse more visible for lower volumes.
|
||||
val normalizedAmplitude = (currentAmplitude.toFloat() / 32767f).pow(0.35f)
|
||||
// Define the min and max size of the circle
|
||||
val minSize = 38.dp
|
||||
val maxSize = 100.dp
|
||||
|
||||
// Map the normalized amplitude to our size range
|
||||
val scale by
|
||||
remember(normalizedAmplitude) {
|
||||
derivedStateOf { (minSize + (maxSize - minSize) * normalizedAmplitude) / minSize }
|
||||
}
|
||||
Box(
|
||||
modifier =
|
||||
Modifier.size(minSize)
|
||||
.graphicsLayer(scaleX = scale, scaleY = scale, clip = false, alpha = 0.3f)
|
||||
.background(MaterialTheme.customColors.recordButtonBgColor, CircleShape)
|
||||
)
|
||||
}
|
||||
|
||||
// Record/stop button.
|
||||
IconButton(
|
||||
onClick = {
|
||||
coroutineScope.launch {
|
||||
if (!isRecording) {
|
||||
isRecording = true
|
||||
recordedBytes.value = null
|
||||
startRecording(
|
||||
context = context,
|
||||
audioRecordState = audioRecordState,
|
||||
audioStream = audioStream,
|
||||
elapsedMs = elapsedMs,
|
||||
onAmplitudeChanged = { currentAmplitude = it },
|
||||
onMaxDurationReached = {
|
||||
val curRecordedBytes =
|
||||
stopRecording(audioRecordState = audioRecordState, audioStream = audioStream)
|
||||
recordedBytes.value = curRecordedBytes
|
||||
isRecording = false
|
||||
},
|
||||
)
|
||||
} else {
|
||||
val curRecordedBytes =
|
||||
stopRecording(audioRecordState = audioRecordState, audioStream = audioStream)
|
||||
recordedBytes.value = curRecordedBytes
|
||||
isRecording = false
|
||||
}
|
||||
}
|
||||
},
|
||||
modifier =
|
||||
Modifier.clip(CircleShape).background(MaterialTheme.customColors.recordButtonBgColor),
|
||||
) {
|
||||
Icon(
|
||||
if (isRecording) Icons.Rounded.Stop else Icons.Rounded.Mic,
|
||||
contentDescription = "",
|
||||
tint = Color.White,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Permission is checked in parent composable.
|
||||
@SuppressLint("MissingPermission")
|
||||
private suspend fun startRecording(
|
||||
context: Context,
|
||||
audioRecordState: MutableState<AudioRecord?>,
|
||||
audioStream: ByteArrayOutputStream,
|
||||
elapsedMs: MutableLongState,
|
||||
onAmplitudeChanged: (Int) -> Unit,
|
||||
onMaxDurationReached: () -> Unit,
|
||||
) {
|
||||
Log.d(TAG, "Start recording...")
|
||||
val minBufferSize = AudioRecord.getMinBufferSize(SAMPLE_RATE, CHANNEL_CONFIG, AUDIO_FORMAT)
|
||||
|
||||
audioRecordState.value?.release()
|
||||
val recorder =
|
||||
AudioRecord(
|
||||
MediaRecorder.AudioSource.MIC,
|
||||
SAMPLE_RATE,
|
||||
CHANNEL_CONFIG,
|
||||
AUDIO_FORMAT,
|
||||
minBufferSize,
|
||||
)
|
||||
|
||||
audioRecordState.value = recorder
|
||||
val buffer = ByteArray(minBufferSize)
|
||||
|
||||
// The function will only return when the recording is done (when stopRecording is called).
|
||||
coroutineScope {
|
||||
launch(Dispatchers.IO) {
|
||||
recorder.startRecording()
|
||||
|
||||
val startMs = System.currentTimeMillis()
|
||||
elapsedMs.value = 0L
|
||||
while (audioRecordState.value?.recordingState == AudioRecord.RECORDSTATE_RECORDING) {
|
||||
val bytesRead = recorder.read(buffer, 0, buffer.size)
|
||||
if (bytesRead > 0) {
|
||||
val currentAmplitude = calculatePeakAmplitude(buffer = buffer, bytesRead = bytesRead)
|
||||
onAmplitudeChanged(currentAmplitude)
|
||||
audioStream.write(buffer, 0, bytesRead)
|
||||
}
|
||||
elapsedMs.value = System.currentTimeMillis() - startMs
|
||||
if (elapsedMs.value >= MAX_AUDIO_CLIP_DURATION_SEC * 1000) {
|
||||
onMaxDurationReached()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun stopRecording(
|
||||
audioRecordState: MutableState<AudioRecord?>,
|
||||
audioStream: ByteArrayOutputStream,
|
||||
): ByteArray {
|
||||
Log.d(TAG, "Stopping recording...")
|
||||
|
||||
val recorder = audioRecordState.value
|
||||
if (recorder?.recordingState == AudioRecord.RECORDSTATE_RECORDING) {
|
||||
recorder.stop()
|
||||
}
|
||||
recorder?.release()
|
||||
audioRecordState.value = null
|
||||
|
||||
val recordedBytes = audioStream.toByteArray()
|
||||
audioStream.reset()
|
||||
Log.d(TAG, "Stopped. Recorded ${recordedBytes.size} bytes.")
|
||||
|
||||
return recordedBytes
|
||||
}
|
||||
|
||||
private fun calculatePeakAmplitude(buffer: ByteArray, bytesRead: Int): Int {
|
||||
// Wrap the byte array in a ByteBuffer and set the order to little-endian
|
||||
val shortBuffer =
|
||||
ByteBuffer.wrap(buffer, 0, bytesRead).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer()
|
||||
|
||||
var maxAmplitude = 0
|
||||
// Iterate through the short buffer to find the maximum absolute value
|
||||
while (shortBuffer.hasRemaining()) {
|
||||
val currentSample = abs(shortBuffer.get().toInt())
|
||||
if (currentSample > maxAmplitude) {
|
||||
maxAmplitude = currentSample
|
||||
}
|
||||
}
|
||||
return maxAmplitude
|
||||
}
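The record button's pulse maps the peak amplitude returned by calculatePeakAmplitude onto a circle scale using a power curve, which keeps quiet input visibly animated. A self-contained sketch of that mapping (the function below is illustrative, not part of the commit):

import kotlin.math.pow

// Illustrative only: mirrors the scale computation in AudioRecorderPanel.
// amplitude is the peak 16-bit sample value (0..32767); the result multiplies a
// 38.dp circle, growing it toward 100.dp as the input gets louder.
fun pulseScale(amplitude: Int, minSizeDp: Float = 38f, maxSizeDp: Float = 100f): Float {
  val normalized = (amplitude.toFloat() / 32767f).pow(0.35f) // exponent < 1 boosts quiet input
  return (minSizeDp + (maxSizeDp - minSizeDp) * normalized) / minSizeDp
}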
@@ -17,18 +17,22 @@
package com.google.ai.edge.gallery.ui.common.chat
|
||||
|
||||
import android.graphics.Bitmap
|
||||
import android.util.Log
|
||||
import androidx.compose.ui.graphics.ImageBitmap
|
||||
import androidx.compose.ui.unit.Dp
|
||||
import com.google.ai.edge.gallery.common.Classification
|
||||
import com.google.ai.edge.gallery.data.Model
|
||||
import com.google.ai.edge.gallery.data.PromptTemplate
|
||||
|
||||
private const val TAG = "AGChatMessage"
|
||||
|
||||
enum class ChatMessageType {
|
||||
INFO,
|
||||
WARNING,
|
||||
TEXT,
|
||||
IMAGE,
|
||||
IMAGE_WITH_HISTORY,
|
||||
AUDIO_CLIP,
|
||||
LOADING,
|
||||
CLASSIFICATION,
|
||||
CONFIG_VALUES_CHANGE,
|
||||
|
@@ -121,6 +125,90 @@ class ChatMessageImage(
}
|
||||
}
|
||||
|
||||
/** Chat message for audio clip. */
|
||||
class ChatMessageAudioClip(
|
||||
val audioData: ByteArray,
|
||||
val sampleRate: Int,
|
||||
override val side: ChatSide,
|
||||
override val latencyMs: Float = 0f,
|
||||
) : ChatMessage(type = ChatMessageType.AUDIO_CLIP, side = side, latencyMs = latencyMs) {
|
||||
override fun clone(): ChatMessageAudioClip {
|
||||
return ChatMessageAudioClip(
|
||||
audioData = audioData,
|
||||
sampleRate = sampleRate,
|
||||
side = side,
|
||||
latencyMs = latencyMs,
|
||||
)
|
||||
}
|
||||
|
||||
fun genByteArrayForWav(): ByteArray {
|
||||
val header = ByteArray(44)
|
||||
|
||||
val pcmDataSize = audioData.size
|
||||
val wavFileSize = pcmDataSize + 44 // 44 bytes for the header
|
||||
val channels = 1 // Mono
|
||||
val bitsPerSample: Short = 16
|
||||
val byteRate = sampleRate * channels * bitsPerSample / 8
|
||||
Log.d(TAG, "Wav metadata: sampleRate: $sampleRate")
|
||||
|
||||
// RIFF/WAVE header
|
||||
header[0] = 'R'.code.toByte()
|
||||
header[1] = 'I'.code.toByte()
|
||||
header[2] = 'F'.code.toByte()
|
||||
header[3] = 'F'.code.toByte()
|
||||
header[4] = (wavFileSize and 0xff).toByte()
|
||||
header[5] = (wavFileSize shr 8 and 0xff).toByte()
|
||||
header[6] = (wavFileSize shr 16 and 0xff).toByte()
|
||||
header[7] = (wavFileSize shr 24 and 0xff).toByte()
|
||||
header[8] = 'W'.code.toByte()
|
||||
header[9] = 'A'.code.toByte()
|
||||
header[10] = 'V'.code.toByte()
|
||||
header[11] = 'E'.code.toByte()
|
||||
header[12] = 'f'.code.toByte()
|
||||
header[13] = 'm'.code.toByte()
|
||||
header[14] = 't'.code.toByte()
|
||||
header[15] = ' '.code.toByte()
|
||||
header[16] = 16
|
||||
header[17] = 0
|
||||
header[18] = 0
|
||||
header[19] = 0 // Sub-chunk size (16 for PCM)
|
||||
header[20] = 1
|
||||
header[21] = 0 // Audio format (1 for PCM)
|
||||
header[22] = channels.toByte()
|
||||
header[23] = 0 // Number of channels
|
||||
header[24] = (sampleRate and 0xff).toByte()
|
||||
header[25] = (sampleRate shr 8 and 0xff).toByte()
|
||||
header[26] = (sampleRate shr 16 and 0xff).toByte()
|
||||
header[27] = (sampleRate shr 24 and 0xff).toByte()
|
||||
header[28] = (byteRate and 0xff).toByte()
|
||||
header[29] = (byteRate shr 8 and 0xff).toByte()
|
||||
header[30] = (byteRate shr 16 and 0xff).toByte()
|
||||
header[31] = (byteRate shr 24 and 0xff).toByte()
|
||||
header[32] = (channels * bitsPerSample / 8).toByte()
|
||||
header[33] = 0 // Block align
|
||||
header[34] = bitsPerSample.toByte()
|
||||
header[35] = (bitsPerSample.toInt() shr 8 and 0xff).toByte() // Bits per sample
|
||||
header[36] = 'd'.code.toByte()
|
||||
header[37] = 'a'.code.toByte()
|
||||
header[38] = 't'.code.toByte()
|
||||
header[39] = 'a'.code.toByte()
|
||||
header[40] = (pcmDataSize and 0xff).toByte()
|
||||
header[41] = (pcmDataSize shr 8 and 0xff).toByte()
|
||||
header[42] = (pcmDataSize shr 16 and 0xff).toByte()
|
||||
header[43] = (pcmDataSize shr 24 and 0xff).toByte()
|
||||
|
||||
return header + audioData
|
||||
}
|
||||
|
||||
fun getDurationInSeconds(): Float {
|
||||
// PCM 16-bit
|
||||
val bytesPerSample = 2
|
||||
val bytesPerFrame = bytesPerSample * 1 // mono
|
||||
val totalFrames = audioData.size.toFloat() / bytesPerFrame
|
||||
return totalFrames / sampleRate
|
||||
}
|
||||
}
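genByteArrayForWav prepends a standard 44-byte RIFF/WAVE header (PCM, mono, 16-bit) so the raw samples become a playable .wav. A hypothetical use, e.g. exporting a received clip to a file; the function, file name, and location are assumptions, not part of the commit:

import android.content.Context
import java.io.File

// Hypothetical sketch: persist an audio chat message as a playable .wav file.
fun saveClipToCache(context: Context, message: ChatMessageAudioClip): File =
  File(context.cacheDir, "clip_${System.currentTimeMillis()}.wav").apply {
    writeBytes(message.genByteArrayForWav())
  }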
|
||||
|
||||
/** Chat message for images with history. */
|
||||
class ChatMessageImageWithHistory(
|
||||
val bitmaps: List<Bitmap>,
|
||||
|
|
|
@@ -137,6 +137,19 @@ fun ChatPanel(
}
|
||||
imageMessageCount
|
||||
}
|
||||
val audioClipMesssageCountToLastconfigChange =
|
||||
remember(messages) {
|
||||
var audioClipMessageCount = 0
|
||||
for (message in messages.reversed()) {
|
||||
if (message is ChatMessageConfigValuesChange) {
|
||||
break
|
||||
}
|
||||
if (message is ChatMessageAudioClip) {
|
||||
audioClipMessageCount++
|
||||
}
|
||||
}
|
||||
audioClipMessageCount
|
||||
}
|
||||
|
||||
var curMessage by remember { mutableStateOf("") } // Correct state
|
||||
val focusManager = LocalFocusManager.current
|
||||
|
@@ -342,6 +355,9 @@ fun ChatPanel(
imageHistoryCurIndex = imageHistoryCurIndex,
|
||||
)
|
||||
|
||||
// Audio clip.
|
||||
is ChatMessageAudioClip -> MessageBodyAudioClip(message = message)
|
||||
|
||||
// Classification result
|
||||
is ChatMessageClassification ->
|
||||
MessageBodyClassification(
|
||||
|
@@ -467,6 +483,22 @@ fun ChatPanel(
)
|
||||
}
|
||||
}
|
||||
// Show an info message for the ask audio task to get users started.
|
||||
else if (task.type == TaskType.LLM_ASK_AUDIO && messages.isEmpty()) {
|
||||
Column(
|
||||
modifier = Modifier.padding(horizontal = 16.dp).fillMaxSize(),
|
||||
horizontalAlignment = Alignment.CenterHorizontally,
|
||||
verticalArrangement = Arrangement.Center,
|
||||
) {
|
||||
MessageBodyInfo(
|
||||
ChatMessageInfo(
|
||||
content =
|
||||
"To get started, tap the + icon to add your audio clips. You can add up to 10 clips, each up to 30 seconds long."
|
||||
),
|
||||
smallFontSize = false,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Chat input
|
||||
|
@@ -482,6 +514,7 @@ fun ChatPanel(
isResettingSession = uiState.isResettingSession,
|
||||
modelPreparing = uiState.preparing,
|
||||
imageMessageCount = imageMessageCountToLastConfigChange,
|
||||
audioClipMessageCount = audioClipMesssageCountToLastconfigChange,
|
||||
modelInitializing =
|
||||
modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
|
||||
textFieldPlaceHolderRes = task.textInputPlaceHolderRes,
|
||||
|
@@ -504,7 +537,10 @@ fun ChatPanel(
onStopButtonClicked = onStopButtonClicked,
|
||||
// showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen,
|
||||
showPromptTemplatesInMenu = false,
|
||||
showImagePickerInMenu = selectedModel.llmSupportImage,
|
||||
showImagePickerInMenu =
|
||||
selectedModel.llmSupportImage && task.type === TaskType.LLM_ASK_IMAGE,
|
||||
showAudioItemsInMenu =
|
||||
selectedModel.llmSupportAudio && task.type === TaskType.LLM_ASK_AUDIO,
|
||||
showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress,
|
||||
)
|
||||
}
|
||||
|
|
|
@@ -0,0 +1,33 @@
/*
|
||||
* Copyright 2025 Google LLC
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.google.ai.edge.gallery.ui.common.chat
|
||||
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.unit.dp
|
||||
|
||||
@Composable
|
||||
fun MessageBodyAudioClip(message: ChatMessageAudioClip, modifier: Modifier = Modifier) {
|
||||
AudioPlaybackPanel(
|
||||
audioData = message.audioData,
|
||||
sampleRate = message.sampleRate,
|
||||
isRecording = false,
|
||||
modifier = Modifier.padding(end = 16.dp),
|
||||
onDarkBg = true,
|
||||
)
|
||||
}
@@ -21,6 +21,7 @@ package com.google.ai.edge.gallery.ui.common.chat
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme
|
||||
import android.Manifest
|
||||
import android.content.Context
|
||||
import android.content.Intent
|
||||
import android.content.pm.PackageManager
|
||||
import android.graphics.Bitmap
|
||||
import android.graphics.BitmapFactory
|
||||
|
@@ -65,9 +66,11 @@ import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.automirrored.rounded.Send
|
||||
import androidx.compose.material.icons.rounded.Add
|
||||
import androidx.compose.material.icons.rounded.AudioFile
|
||||
import androidx.compose.material.icons.rounded.Close
|
||||
import androidx.compose.material.icons.rounded.FlipCameraAndroid
|
||||
import androidx.compose.material.icons.rounded.History
|
||||
import androidx.compose.material.icons.rounded.Mic
|
||||
import androidx.compose.material.icons.rounded.Photo
|
||||
import androidx.compose.material.icons.rounded.PhotoCamera
|
||||
import androidx.compose.material.icons.rounded.PostAdd
|
||||
|
@@ -107,9 +110,14 @@ import androidx.compose.ui.unit.dp
import androidx.compose.ui.viewinterop.AndroidView
|
||||
import androidx.core.content.ContextCompat
|
||||
import androidx.lifecycle.compose.LocalLifecycleOwner
|
||||
import com.google.ai.edge.gallery.common.AudioClip
|
||||
import com.google.ai.edge.gallery.common.convertWavToMonoWithMaxSeconds
|
||||
import com.google.ai.edge.gallery.data.MAX_AUDIO_CLIP_COUNT
|
||||
import com.google.ai.edge.gallery.data.MAX_IMAGE_COUNT
|
||||
import com.google.ai.edge.gallery.data.SAMPLE_RATE
|
||||
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
import java.util.concurrent.Executors
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.launch
|
||||
|
||||
private const val TAG = "AGMessageInputText"
|
||||
|
@@ -128,6 +136,7 @@ fun MessageInputText(
isResettingSession: Boolean,
|
||||
inProgress: Boolean,
|
||||
imageMessageCount: Int,
|
||||
audioClipMessageCount: Int,
|
||||
modelInitializing: Boolean,
|
||||
@StringRes textFieldPlaceHolderRes: Int,
|
||||
onValueChanged: (String) -> Unit,
|
||||
|
@@ -137,6 +146,7 @@ fun MessageInputText(
onStopButtonClicked: () -> Unit = {},
|
||||
showPromptTemplatesInMenu: Boolean = false,
|
||||
showImagePickerInMenu: Boolean = false,
|
||||
showAudioItemsInMenu: Boolean = false,
|
||||
showStopButtonWhenInProgress: Boolean = false,
|
||||
) {
|
||||
val context = LocalContext.current
|
||||
|
@@ -146,7 +156,12 @@ fun MessageInputText(
var showTextInputHistorySheet by remember { mutableStateOf(false) }
|
||||
var showCameraCaptureBottomSheet by remember { mutableStateOf(false) }
|
||||
val cameraCaptureSheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
|
||||
var showAudioRecorderBottomSheet by remember { mutableStateOf(false) }
|
||||
val audioRecorderSheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
|
||||
var pickedImages by remember { mutableStateOf<List<Bitmap>>(listOf()) }
|
||||
var pickedAudioClips by remember { mutableStateOf<List<AudioClip>>(listOf()) }
|
||||
var hasFrontCamera by remember { mutableStateOf(false) }
|
||||
|
||||
val updatePickedImages: (List<Bitmap>) -> Unit = { bitmaps ->
|
||||
var newPickedImages: MutableList<Bitmap> = mutableListOf()
|
||||
newPickedImages.addAll(pickedImages)
|
||||
|
@@ -156,7 +171,16 @@ fun MessageInputText(
}
|
||||
pickedImages = newPickedImages.toList()
|
||||
}
|
||||
var hasFrontCamera by remember { mutableStateOf(false) }
|
||||
|
||||
val updatePickedAudioClips: (List<AudioClip>) -> Unit = { audioDataList ->
|
||||
var newAudioDataList: MutableList<AudioClip> = mutableListOf()
|
||||
newAudioDataList.addAll(pickedAudioClips)
|
||||
newAudioDataList.addAll(audioDataList)
|
||||
if (newAudioDataList.size > MAX_AUDIO_CLIP_COUNT) {
|
||||
newAudioDataList = newAudioDataList.subList(fromIndex = 0, toIndex = MAX_AUDIO_CLIP_COUNT)
|
||||
}
|
||||
pickedAudioClips = newAudioDataList.toList()
|
||||
}
|
||||
|
||||
LaunchedEffect(Unit) { checkFrontCamera(context = context, callback = { hasFrontCamera = it }) }
|
||||
|
||||
|
@@ -170,6 +194,16 @@ fun MessageInputText(
}
|
||||
}
|
||||
|
||||
// Permission request when recording audio clips.
|
||||
val recordAudioClipsPermissionLauncher =
|
||||
rememberLauncherForActivityResult(ActivityResultContracts.RequestPermission()) {
|
||||
permissionGranted ->
|
||||
if (permissionGranted) {
|
||||
showAddContentMenu = false
|
||||
showAudioRecorderBottomSheet = true
|
||||
}
|
||||
}
|
||||
|
||||
// Registers a photo picker activity launcher in single-select mode.
|
||||
val pickMedia =
|
||||
rememberLauncherForActivityResult(ActivityResultContracts.PickMultipleVisualMedia()) { uris ->
|
||||
|
@@ -184,9 +218,31 @@ fun MessageInputText(
}
|
||||
}
|
||||
|
||||
val pickWav =
|
||||
rememberLauncherForActivityResult(
|
||||
contract = ActivityResultContracts.StartActivityForResult()
|
||||
) { result ->
|
||||
if (result.resultCode == android.app.Activity.RESULT_OK) {
|
||||
result.data?.data?.let { uri ->
|
||||
Log.d(TAG, "Picked wav file: $uri")
|
||||
scope.launch(Dispatchers.IO) {
|
||||
convertWavToMonoWithMaxSeconds(context = context, stereoUri = uri)?.let { audioClip ->
|
||||
updatePickedAudioClips(
|
||||
listOf(
|
||||
AudioClip(audioData = audioClip.audioData, sampleRate = audioClip.sampleRate)
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Log.d(TAG, "Wav picking cancelled.")
|
||||
}
|
||||
}
|
||||
|
||||
Column {
|
||||
// A preview panel for the selected image.
|
||||
if (pickedImages.isNotEmpty()) {
|
||||
// A preview panel for the selected images and audio clips.
|
||||
if (pickedImages.isNotEmpty() || pickedAudioClips.isNotEmpty()) {
|
||||
Row(
|
||||
modifier =
|
||||
Modifier.offset(x = 16.dp).fillMaxWidth().horizontalScroll(rememberScrollState()),
|
||||
|
@@ -203,20 +259,30 @@ fun MessageInputText(
.clip(RoundedCornerShape(8.dp))
|
||||
.border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp)),
|
||||
)
|
||||
MediaPanelCloseButton { pickedImages = pickedImages.filter { image != it } }
|
||||
}
|
||||
}
|
||||
|
||||
for ((index, audioClip) in pickedAudioClips.withIndex()) {
|
||||
Box(contentAlignment = Alignment.TopEnd) {
|
||||
Box(
|
||||
modifier =
|
||||
Modifier.offset(x = 10.dp, y = (-10).dp)
|
||||
.clip(CircleShape)
|
||||
Modifier.shadow(2.dp, shape = RoundedCornerShape(8.dp))
|
||||
.clip(RoundedCornerShape(8.dp))
|
||||
.background(MaterialTheme.colorScheme.surface)
|
||||
.border((1.5).dp, MaterialTheme.colorScheme.outline, CircleShape)
|
||||
.clickable { pickedImages = pickedImages.filter { image != it } }
|
||||
.border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp))
|
||||
) {
|
||||
Icon(
|
||||
Icons.Rounded.Close,
|
||||
contentDescription = "",
|
||||
modifier = Modifier.padding(3.dp).size(16.dp),
|
||||
AudioPlaybackPanel(
|
||||
audioData = audioClip.audioData,
|
||||
sampleRate = audioClip.sampleRate,
|
||||
isRecording = false,
|
||||
modifier = Modifier.padding(end = 16.dp),
|
||||
)
|
||||
}
|
||||
MediaPanelCloseButton {
|
||||
pickedAudioClips =
|
||||
pickedAudioClips.filterIndexed { curIndex, curAudioData -> curIndex != index }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -239,10 +305,13 @@ fun MessageInputText(
verticalAlignment = Alignment.CenterVertically,
|
||||
) {
|
||||
val enableAddImageMenuItems = (imageMessageCount + pickedImages.size) < MAX_IMAGE_COUNT
|
||||
val enableRecordAudioClipMenuItems =
|
||||
(audioClipMessageCount + pickedAudioClips.size) < MAX_AUDIO_CLIP_COUNT
|
||||
DropdownMenu(
|
||||
expanded = showAddContentMenu,
|
||||
onDismissRequest = { showAddContentMenu = false },
|
||||
) {
|
||||
// Image related menu items.
|
||||
if (showImagePickerInMenu) {
|
||||
// Take a picture.
|
||||
DropdownMenuItem(
|
||||
|
@@ -295,6 +364,70 @@ fun MessageInputText(
)
|
||||
}
|
||||
|
||||
// Audio related menu items.
|
||||
if (showAudioItemsInMenu) {
|
||||
DropdownMenuItem(
|
||||
text = {
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(6.dp),
|
||||
) {
|
||||
Icon(Icons.Rounded.Mic, contentDescription = "")
|
||||
Text("Record audio clip")
|
||||
}
|
||||
},
|
||||
enabled = enableRecordAudioClipMenuItems,
|
||||
onClick = {
|
||||
// Check permission
|
||||
when (PackageManager.PERMISSION_GRANTED) {
|
||||
// Already got permission. Call the lambda.
|
||||
ContextCompat.checkSelfPermission(context, Manifest.permission.RECORD_AUDIO) -> {
|
||||
showAddContentMenu = false
|
||||
showAudioRecorderBottomSheet = true
|
||||
}
|
||||
|
||||
// Otherwise, ask for permission
|
||||
else -> {
|
||||
recordAudioClipsPermissionLauncher.launch(Manifest.permission.RECORD_AUDIO)
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
DropdownMenuItem(
|
||||
text = {
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(6.dp),
|
||||
) {
|
||||
Icon(Icons.Rounded.AudioFile, contentDescription = "")
|
||||
Text("Pick wav file")
|
||||
}
|
||||
},
|
||||
enabled = enableRecordAudioClipMenuItems,
|
||||
onClick = {
|
||||
showAddContentMenu = false
|
||||
|
||||
// Show file picker.
|
||||
val intent =
|
||||
Intent(Intent.ACTION_GET_CONTENT).apply {
|
||||
addCategory(Intent.CATEGORY_OPENABLE)
|
||||
type = "audio/*"
|
||||
|
||||
// Provide a list of more specific MIME types to filter for.
|
||||
val mimeTypes = arrayOf("audio/wav", "audio/x-wav")
|
||||
putExtra(Intent.EXTRA_MIME_TYPES, mimeTypes)
|
||||
|
||||
// Single select.
|
||||
putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
|
||||
.addFlags(Intent.FLAG_GRANT_PERSISTABLE_URI_PERMISSION)
|
||||
.addFlags(Intent.FLAG_GRANT_READ_URI_PERMISSION)
|
||||
}
|
||||
pickWav.launch(intent)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// Prompt templates.
|
||||
if (showPromptTemplatesInMenu) {
|
||||
DropdownMenuItem(
|
||||
|
@@ -369,15 +502,22 @@ fun MessageInputText(
)
|
||||
}
|
||||
}
|
||||
} // Send button. Only shown when text is not empty.
|
||||
else if (curMessage.isNotEmpty()) {
|
||||
}
|
||||
// Send button. Only shown when text is not empty, or there is at least one recorded
|
||||
// audio clip.
|
||||
else if (curMessage.isNotEmpty() || pickedAudioClips.isNotEmpty()) {
|
||||
IconButton(
|
||||
enabled = !inProgress && !isResettingSession,
|
||||
onClick = {
|
||||
onSendMessage(
|
||||
createMessagesToSend(pickedImages = pickedImages, text = curMessage.trim())
|
||||
createMessagesToSend(
|
||||
pickedImages = pickedImages,
|
||||
audioClips = pickedAudioClips,
|
||||
text = curMessage.trim(),
|
||||
)
|
||||
)
|
||||
pickedImages = listOf()
|
||||
pickedAudioClips = listOf()
|
||||
},
|
||||
colors =
|
||||
IconButtonDefaults.iconButtonColors(
|
||||
|
@ -403,8 +543,15 @@ fun MessageInputText(
|
|||
history = modelManagerUiState.textInputHistory,
|
||||
onDismissed = { showTextInputHistorySheet = false },
|
||||
onHistoryItemClicked = { item ->
|
||||
onSendMessage(createMessagesToSend(pickedImages = pickedImages, text = item))
|
||||
onSendMessage(
|
||||
createMessagesToSend(
|
||||
pickedImages = pickedImages,
|
||||
audioClips = pickedAudioClips,
|
||||
text = item,
|
||||
)
|
||||
)
|
||||
pickedImages = listOf()
|
||||
pickedAudioClips = listOf()
|
||||
modelManagerViewModel.promoteTextInputHistoryItem(item)
|
||||
},
|
||||
onHistoryItemDeleted = { item -> modelManagerViewModel.deleteTextInputHistory(item) },
|
||||
|
@ -582,6 +729,43 @@ fun MessageInputText(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (showAudioRecorderBottomSheet) {
|
||||
ModalBottomSheet(
|
||||
sheetState = audioRecorderSheetState,
|
||||
onDismissRequest = { showAudioRecorderBottomSheet = false },
|
||||
) {
|
||||
AudioRecorderPanel(
|
||||
onSendAudioClip = { audioData ->
|
||||
scope.launch {
|
||||
updatePickedAudioClips(
|
||||
listOf(AudioClip(audioData = audioData, sampleRate = SAMPLE_RATE))
|
||||
)
|
||||
audioRecorderSheetState.hide()
|
||||
showAudioRecorderBottomSheet = false
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
private fun MediaPanelCloseButton(onClicked: () -> Unit) {
|
||||
Box(
|
||||
modifier =
|
||||
Modifier.offset(x = 10.dp, y = (-10).dp)
|
||||
.clip(CircleShape)
|
||||
.background(MaterialTheme.colorScheme.surface)
|
||||
.border((1.5).dp, MaterialTheme.colorScheme.outline, CircleShape)
|
||||
.clickable { onClicked() }
|
||||
) {
|
||||
Icon(
|
||||
Icons.Rounded.Close,
|
||||
contentDescription = "",
|
||||
modifier = Modifier.padding(3.dp).size(16.dp),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
private fun handleImagesSelected(
|
||||
|
@ -641,20 +825,50 @@ private fun checkFrontCamera(context: Context, callback: (Boolean) -> Unit) {
|
|||
)
|
||||
}
|
||||
|
||||
private fun createMessagesToSend(pickedImages: List<Bitmap>, text: String): List<ChatMessage> {
|
||||
private fun createMessagesToSend(
|
||||
pickedImages: List<Bitmap>,
|
||||
audioClips: List<AudioClip>,
|
||||
text: String,
|
||||
): List<ChatMessage> {
|
||||
var messages: MutableList<ChatMessage> = mutableListOf()
|
||||
|
||||
// Add image messages.
|
||||
var imageMessages: MutableList<ChatMessageImage> = mutableListOf()
|
||||
if (pickedImages.isNotEmpty()) {
|
||||
for (image in pickedImages) {
|
||||
messages.add(
|
||||
imageMessages.add(
|
||||
ChatMessageImage(bitmap = image, imageBitMap = image.asImageBitmap(), side = ChatSide.USER)
|
||||
)
|
||||
}
|
||||
}
|
||||
// Cap the number of image messages.
|
||||
if (messages.size > MAX_IMAGE_COUNT) {
|
||||
messages = messages.subList(fromIndex = 0, toIndex = MAX_IMAGE_COUNT)
|
||||
if (imageMessages.size > MAX_IMAGE_COUNT) {
|
||||
imageMessages = imageMessages.subList(fromIndex = 0, toIndex = MAX_IMAGE_COUNT)
|
||||
}
|
||||
messages.addAll(imageMessages)
|
||||
|
||||
// Add audio messages.
|
||||
var audioMessages: MutableList<ChatMessageAudioClip> = mutableListOf()
|
||||
if (audioClips.isNotEmpty()) {
|
||||
for (audioClip in audioClips) {
|
||||
audioMessages.add(
|
||||
ChatMessageAudioClip(
|
||||
audioData = audioClip.audioData,
|
||||
sampleRate = audioClip.sampleRate,
|
||||
side = ChatSide.USER,
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
// Cap the number of audio messages.
|
||||
if (audioMessages.size > MAX_AUDIO_CLIP_COUNT) {
|
||||
audioMessages = audioMessages.subList(fromIndex = 0, toIndex = MAX_AUDIO_CLIP_COUNT)
|
||||
}
|
||||
messages.addAll(audioMessages)
|
||||
|
||||
if (text.isNotEmpty()) {
|
||||
messages.add(ChatMessageText(content = text, side = ChatSide.USER))
|
||||
}
|
||||
messages.add(ChatMessageText(content = text, side = ChatSide.USER))
|
||||
|
||||
return messages
|
||||
}
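The rewritten createMessagesToSend assembles the outgoing list in a fixed order: capped image messages, then capped audio-clip messages, then the text message when it is non-empty. For illustration only, the same capping intent could be expressed with take(); this is not the commit's code, just a sketch using the constants referenced above.

// Illustration only: keep at most MAX_IMAGE_COUNT images and MAX_AUDIO_CLIP_COUNT clips.
val cappedImages: List<ChatMessageImage> = imageMessages.take(MAX_IMAGE_COUNT)
val cappedAudio: List<ChatMessageAudioClip> = audioMessages.take(MAX_AUDIO_CLIP_COUNT)

The image cap mirrors the MAX_IMAGE_COUNT limit that LlmChatModelHelper passes to setMaxNumImages further down in this commit.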

@@ -121,6 +121,7 @@ private val IMPORT_CONFIGS_LLM: List<Config> =
valueType = ValueType.FLOAT,
),
BooleanSwitchConfig(key = ConfigKey.SUPPORT_IMAGE, defaultValue = false),
BooleanSwitchConfig(key = ConfigKey.SUPPORT_AUDIO, defaultValue = false),
SegmentedButtonConfig(
key = ConfigKey.COMPATIBLE_ACCELERATORS,
defaultValue = Accelerator.CPU.label,

@@ -230,6 +231,12 @@ fun ModelImportDialog(uri: Uri, onDismiss: () -> Unit, onDone: (ImportedModel) -
valueType = ValueType.BOOLEAN,
)
as Boolean
val supportAudio =
convertValueToTargetType(
value = values.get(ConfigKey.SUPPORT_AUDIO.label)!!,
valueType = ValueType.BOOLEAN,
)
as Boolean
val importedModel: ImportedModel =
ImportedModel.newBuilder()
.setFileName(fileName)

@@ -242,6 +249,7 @@ fun ModelImportDialog(uri: Uri, onDismiss: () -> Unit, onDone: (ImportedModel) -
.setDefaultTopp(defaultTopp)
.setDefaultTemperature(defaultTemperature)
.setSupportImage(supportImage)
.setSupportAudio(supportAudio)
.build()
)
.build()

@@ -173,7 +173,7 @@ fun SettingsDialog(
color = MaterialTheme.colorScheme.onSurfaceVariant,
)
Text(
"Expired at: ${dateFormatter.format(Instant.ofEpochMilli(curHfToken.expiresAtMs))}",
"Expires at: ${dateFormatter.format(Instant.ofEpochMilli(curHfToken.expiresAtMs))}",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant,
)
@@ -1,78 +0,0 @@
/*
 * Copyright 2025 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.google.ai.edge.gallery.ui.icon

import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.SolidColor
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.graphics.vector.path
import androidx.compose.ui.unit.dp

val Forum: ImageVector
get() {
if (_Forum != null) return _Forum!!

_Forum =
ImageVector.Builder(
name = "Forum",
defaultWidth = 24.dp,
defaultHeight = 24.dp,
viewportWidth = 960f,
viewportHeight = 960f,
)
.apply {
path(fill = SolidColor(Color(0xFF000000))) {
moveTo(280f, 720f)
quadToRelative(-17f, 0f, -28.5f, -11.5f)
reflectiveQuadTo(240f, 680f)
verticalLineToRelative(-80f)
horizontalLineToRelative(520f)
verticalLineToRelative(-360f)
horizontalLineToRelative(80f)
quadToRelative(17f, 0f, 28.5f, 11.5f)
reflectiveQuadTo(880f, 280f)
verticalLineToRelative(600f)
lineTo(720f, 720f)
close()
moveTo(80f, 680f)
verticalLineToRelative(-560f)
quadToRelative(0f, -17f, 11.5f, -28.5f)
reflectiveQuadTo(120f, 80f)
horizontalLineToRelative(520f)
quadToRelative(17f, 0f, 28.5f, 11.5f)
reflectiveQuadTo(680f, 120f)
verticalLineToRelative(360f)
quadToRelative(0f, 17f, -11.5f, 28.5f)
reflectiveQuadTo(640f, 520f)
horizontalLineTo(240f)
close()
moveToRelative(520f, -240f)
verticalLineToRelative(-280f)
horizontalLineTo(160f)
verticalLineToRelative(280f)
close()
moveToRelative(-440f, 0f)
verticalLineToRelative(-280f)
close()
}
}
.build()

return _Forum!!
}

private var _Forum: ImageVector? = null
@@ -1,73 +0,0 @@
/*
 * Copyright 2025 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.google.ai.edge.gallery.ui.icon

import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.SolidColor
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.graphics.vector.path
import androidx.compose.ui.unit.dp

val Mms: ImageVector
get() {
if (_Mms != null) return _Mms!!

_Mms =
ImageVector.Builder(
name = "Mms",
defaultWidth = 24.dp,
defaultHeight = 24.dp,
viewportWidth = 960f,
viewportHeight = 960f,
)
.apply {
path(fill = SolidColor(Color(0xFF000000))) {
moveTo(240f, 560f)
horizontalLineToRelative(480f)
lineTo(570f, 360f)
lineTo(450f, 520f)
lineToRelative(-90f, -120f)
close()
moveTo(80f, 880f)
verticalLineToRelative(-720f)
quadToRelative(0f, -33f, 23.5f, -56.5f)
reflectiveQuadTo(160f, 80f)
horizontalLineToRelative(640f)
quadToRelative(33f, 0f, 56.5f, 23.5f)
reflectiveQuadTo(880f, 160f)
verticalLineToRelative(480f)
quadToRelative(0f, 33f, -23.5f, 56.5f)
reflectiveQuadTo(800f, 720f)
horizontalLineTo(240f)
close()
moveToRelative(126f, -240f)
horizontalLineToRelative(594f)
verticalLineToRelative(-480f)
horizontalLineTo(160f)
verticalLineToRelative(525f)
close()
moveToRelative(-46f, 0f)
verticalLineToRelative(-480f)
close()
}
}
.build()

return _Mms!!
}

private var _Mms: ImageVector? = null
@@ -1,87 +0,0 @@
/*
 * Copyright 2025 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.google.ai.edge.gallery.ui.icon

import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.SolidColor
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.graphics.vector.path
import androidx.compose.ui.unit.dp

val Widgets: ImageVector
get() {
if (_Widgets != null) return _Widgets!!

_Widgets =
ImageVector.Builder(
name = "Widgets",
defaultWidth = 24.dp,
defaultHeight = 24.dp,
viewportWidth = 960f,
viewportHeight = 960f,
)
.apply {
path(fill = SolidColor(Color(0xFF000000))) {
moveTo(666f, 520f)
lineTo(440f, 294f)
lineToRelative(226f, -226f)
lineToRelative(226f, 226f)
close()
moveToRelative(-546f, -80f)
verticalLineToRelative(-320f)
horizontalLineToRelative(320f)
verticalLineToRelative(320f)
close()
moveToRelative(400f, 400f)
verticalLineToRelative(-320f)
horizontalLineToRelative(320f)
verticalLineToRelative(320f)
close()
moveToRelative(-400f, 0f)
verticalLineToRelative(-320f)
horizontalLineToRelative(320f)
verticalLineToRelative(320f)
close()
moveToRelative(80f, -480f)
horizontalLineToRelative(160f)
verticalLineToRelative(-160f)
horizontalLineTo(200f)
close()
moveToRelative(467f, 48f)
lineToRelative(113f, -113f)
lineToRelative(-113f, -113f)
lineToRelative(-113f, 113f)
close()
moveToRelative(-67f, 352f)
horizontalLineToRelative(160f)
verticalLineToRelative(-160f)
horizontalLineTo(600f)
close()
moveToRelative(-400f, 0f)
horizontalLineToRelative(160f)
verticalLineToRelative(-160f)
horizontalLineTo(200f)
close()
moveToRelative(400f, -160f)
}
}
.build()

return _Widgets!!
}

private var _Widgets: ImageVector? = null
@@ -62,13 +62,13 @@ object LlmChatModelHelper {
Accelerator.GPU.label -> LlmInference.Backend.GPU
else -> LlmInference.Backend.GPU
}
val options =
val optionsBuilder =
LlmInference.LlmInferenceOptions.builder()
.setModelPath(model.getPath(context = context))
.setMaxTokens(maxTokens)
.setPreferredBackend(preferredBackend)
.setMaxNumImages(if (model.llmSupportImage) MAX_IMAGE_COUNT else 0)
.build()
val options = optionsBuilder.build()

// Create an instance of the LLM Inference task and session.
try {

@@ -82,7 +82,9 @@ object LlmChatModelHelper {
.setTopP(topP)
.setTemperature(temperature)
.setGraphOptions(
GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
GraphOptions.builder()
.setEnableVisionModality(model.llmSupportImage)
.build()
)
.build(),
)

@@ -115,7 +117,9 @@ object LlmChatModelHelper {
.setTopP(topP)
.setTemperature(temperature)
.setGraphOptions(
GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
GraphOptions.builder()
.setEnableVisionModality(model.llmSupportImage)
.build()
)
.build(),
)

@@ -159,6 +163,7 @@ object LlmChatModelHelper {
resultListener: ResultListener,
cleanUpListener: CleanUpListener,
images: List<Bitmap> = listOf(),
audioClips: List<ByteArray> = listOf(),
) {
val instance = model.instance as LlmModelInstance

@@ -172,10 +177,16 @@ object LlmChatModelHelper {
// For a model that supports image modality, we need to add the text query chunk before adding
// image.
val session = instance.session
session.addQueryChunk(input)
if (input.trim().isNotEmpty()) {
session.addQueryChunk(input)
}
for (image in images) {
session.addImage(BitmapImageBuilder(image).build())
}
for (audioClip in audioClips) {
// Uncomment when audio is supported.
// session.addAudio(audioClip)
}
val unused = session.generateResponseAsync(resultListener)
}
}
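runInference now accepts audioClips as a plain List<ByteArray>; the view model further down produces those arrays from each ChatMessageAudioClip via genByteArrayForWav(), which is not shown in this hunk. As a rough, hedged reference, a minimal 16-bit mono WAV wrapper around raw PCM bytes could look like the sketch below; the function name and layout are assumptions, not the app's implementation.

import java.nio.ByteBuffer
import java.nio.ByteOrder

// Sketch only: prepend a canonical 44-byte WAV header to 16-bit mono PCM data.
fun pcmToWavBytes(pcm: ByteArray, sampleRate: Int): ByteArray {
  val channels = 1
  val bitsPerSample = 16
  val byteRate = sampleRate * channels * bitsPerSample / 8
  val header = ByteBuffer.allocate(44).order(ByteOrder.LITTLE_ENDIAN)
  header.put("RIFF".toByteArray())
  header.putInt(36 + pcm.size)                              // RIFF chunk size
  header.put("WAVE".toByteArray())
  header.put("fmt ".toByteArray())
  header.putInt(16)                                         // fmt chunk size
  header.putShort(1)                                        // audio format: PCM
  header.putShort(channels.toShort())
  header.putInt(sampleRate)
  header.putInt(byteRate)
  header.putShort((channels * bitsPerSample / 8).toShort()) // block align
  header.putShort(bitsPerSample.toShort())
  header.put("data".toByteArray())
  header.putInt(pcm.size)                                   // data chunk size
  return header.array() + pcm
}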

@@ -22,6 +22,7 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.ai.edge.gallery.ui.ViewModelProvider
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageAudioClip
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
import com.google.ai.edge.gallery.ui.common.chat.ChatView

@@ -36,6 +37,10 @@ object LlmAskImageDestination {
val route = "LlmAskImageRoute"
}

object LlmAskAudioDestination {
val route = "LlmAskAudioRoute"
}

@Composable
fun LlmChatScreen(
modelManagerViewModel: ModelManagerViewModel,

@@ -66,6 +71,21 @@ fun LlmAskImageScreen(
)
}

@Composable
fun LlmAskAudioScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
viewModel: LlmAskAudioViewModel = viewModel(factory = ViewModelProvider.Factory),
) {
ChatViewWrapper(
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
navigateUp = navigateUp,
modifier = modifier,
)
}

@Composable
fun ChatViewWrapper(
viewModel: LlmChatViewModel,

@@ -86,6 +106,7 @@ fun ChatViewWrapper(

var text = ""
val images: MutableList<Bitmap> = mutableListOf()
val audioMessages: MutableList<ChatMessageAudioClip> = mutableListOf()
var chatMessageText: ChatMessageText? = null
for (message in messages) {
if (message is ChatMessageText) {

@@ -93,14 +114,17 @@ fun ChatViewWrapper(
text = message.content
} else if (message is ChatMessageImage) {
images.add(message.bitmap)
} else if (message is ChatMessageAudioClip) {
audioMessages.add(message)
}
}
if (text.isNotEmpty() && chatMessageText != null) {
if ((text.isNotEmpty() && chatMessageText != null) || audioMessages.isNotEmpty()) {
modelManagerViewModel.addTextInputHistory(text)
viewModel.generateResponse(
model = model,
input = text,
images = images,
audioMessages = audioMessages,
onError = {
viewModel.handleError(
context = context,

@@ -22,9 +22,11 @@ import android.util.Log
import androidx.lifecycle.viewModelScope
import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_AUDIO
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageAudioClip
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageLoading
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText

@@ -52,6 +54,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
model: Model,
input: String,
images: List<Bitmap> = listOf(),
audioMessages: List<ChatMessageAudioClip> = listOf(),
onError: () -> Unit,
) {
val accelerator = model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = "")

@@ -72,6 +75,11 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
val instance = model.instance as LlmModelInstance
var prefillTokens = instance.session.sizeInTokens(input)
prefillTokens += images.size * 257
for (audioMessages in audioMessages) {
// 150ms = 1 audio token
val duration = audioMessages.getDurationInSeconds()
prefillTokens += (duration * 1000f / 150f).toInt()
}

var firstRun = true
var timeToFirstToken = 0f
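The hunk above estimates the prefill size before generation: text tokens from sizeInTokens, a flat 257 tokens per attached image, and roughly one token per 150 ms of audio. A small worked example of that arithmetic, with a made-up text-token count purely for illustration:

// Illustration of the estimate above, not code from this commit.
val textTokens = 25                             // e.g. whatever sizeInTokens(input) returned
val imageTokens = 2 * 257                       // two attached images -> 514 tokens
val audioTokens = (6f * 1000f / 150f).toInt()   // a 6-second clip -> 40 tokens
val estimatedPrefillTokens = textTokens + imageTokens + audioTokens  // 579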

@@ -86,6 +94,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
model = model,
input = input,
images = images,
audioClips = audioMessages.map { it.genByteArrayForWav() },
resultListener = { partialResult, done ->
val curTs = System.currentTimeMillis()

@@ -214,7 +223,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
context: Context,
model: Model,
modelManagerViewModel: ModelManagerViewModel,
triggeredMessage: ChatMessageText,
triggeredMessage: ChatMessageText?,
) {
// Clean up.
modelManagerViewModel.cleanupModel(task = task, model = model)

@@ -236,14 +245,20 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
)

// Add the triggered message back.
addMessage(model = model, message = triggeredMessage)
if (triggeredMessage != null) {
addMessage(model = model, message = triggeredMessage)
}

// Re-initialize the session/engine.
modelManagerViewModel.initializeModel(context = context, task = task, model = model)

// Re-generate the response automatically.
generateResponse(model = model, input = triggeredMessage.content, onError = {})
if (triggeredMessage != null) {
generateResponse(model = model, input = triggeredMessage.content, onError = {})
}
}
}

class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE)

class LlmAskAudioViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_AUDIO)
@@ -36,6 +36,7 @@ import com.google.ai.edge.gallery.data.ModelAllowlist
import com.google.ai.edge.gallery.data.ModelDownloadStatus
import com.google.ai.edge.gallery.data.ModelDownloadStatusType
import com.google.ai.edge.gallery.data.TASKS
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_AUDIO
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB

@@ -281,15 +282,12 @@ open class ModelManagerViewModel(
}
}
when (task.type) {
TaskType.LLM_CHAT ->
LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone)

TaskType.LLM_CHAT,
TaskType.LLM_ASK_IMAGE,
TaskType.LLM_ASK_AUDIO,
TaskType.LLM_PROMPT_LAB ->
LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone)

TaskType.LLM_ASK_IMAGE ->
LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone)

TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {}
}

@@ -301,9 +299,11 @@ open class ModelManagerViewModel(
model.cleanUpAfterInit = false
Log.d(TAG, "Cleaning up model '${model.name}'...")
when (task.type) {
TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model)
TaskType.LLM_PROMPT_LAB -> LlmChatModelHelper.cleanUp(model = model)
TaskType.LLM_ASK_IMAGE -> LlmChatModelHelper.cleanUp(model = model)
TaskType.LLM_CHAT,
TaskType.LLM_PROMPT_LAB,
TaskType.LLM_ASK_IMAGE,
TaskType.LLM_ASK_AUDIO -> LlmChatModelHelper.cleanUp(model = model)

TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {}
}

@@ -410,14 +410,19 @@ open class ModelManagerViewModel(
// Create model.
val model = createModelFromImportedModelInfo(info = info)

for (task in listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)) {
for (task in
listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_ASK_AUDIO, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)) {
// Remove duplicated imported model if existed.
val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported }
if (modelIndex >= 0) {
Log.d(TAG, "duplicated imported model found in task. Removing it first")
task.models.removeAt(modelIndex)
}
if ((task == TASK_LLM_ASK_IMAGE && model.llmSupportImage) || task != TASK_LLM_ASK_IMAGE) {
if (
(task == TASK_LLM_ASK_IMAGE && model.llmSupportImage) ||
(task == TASK_LLM_ASK_AUDIO && model.llmSupportAudio) ||
(task != TASK_LLM_ASK_IMAGE && task != TASK_LLM_ASK_AUDIO)
) {
task.models.add(model)
}
task.updateTrigger.value = System.currentTimeMillis()

@@ -657,6 +662,7 @@ open class ModelManagerViewModel(
TASK_LLM_CHAT.models.clear()
TASK_LLM_PROMPT_LAB.models.clear()
TASK_LLM_ASK_IMAGE.models.clear()
TASK_LLM_ASK_AUDIO.models.clear()
for (allowedModel in modelAllowlist.models) {
if (allowedModel.disabled == true) {
continue

@@ -672,6 +678,9 @@ open class ModelManagerViewModel(
if (allowedModel.taskTypes.contains(TASK_LLM_ASK_IMAGE.type.id)) {
TASK_LLM_ASK_IMAGE.models.add(model)
}
if (allowedModel.taskTypes.contains(TASK_LLM_ASK_AUDIO.type.id)) {
TASK_LLM_ASK_AUDIO.models.add(model)
}
}

// Pre-process all tasks.

@@ -760,6 +769,9 @@ open class ModelManagerViewModel(
if (model.llmSupportImage) {
TASK_LLM_ASK_IMAGE.models.add(model)
}
if (model.llmSupportAudio) {
TASK_LLM_ASK_AUDIO.models.add(model)
}

// Update status.
modelDownloadStatus[model.name] =

@@ -800,6 +812,7 @@ open class ModelManagerViewModel(
accelerators = accelerators,
)
val llmSupportImage = info.llmConfig.supportImage
val llmSupportAudio = info.llmConfig.supportAudio
val model =
Model(
name = info.fileName,

@@ -811,6 +824,7 @@ open class ModelManagerViewModel(
showRunAgainButton = false,
imported = true,
llmSupportImage = llmSupportImage,
llmSupportAudio = llmSupportAudio,
)
model.preProcess()
@@ -47,6 +47,7 @@ import androidx.navigation.compose.NavHost
import androidx.navigation.compose.composable
import androidx.navigation.navArgument
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_AUDIO
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB

@@ -55,6 +56,8 @@ import com.google.ai.edge.gallery.data.TaskType
import com.google.ai.edge.gallery.data.getModelByName
import com.google.ai.edge.gallery.ui.ViewModelProvider
import com.google.ai.edge.gallery.ui.home.HomeScreen
import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioDestination
import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioScreen
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageDestination
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageScreen
import com.google.ai.edge.gallery.ui.llmchat.LlmChatDestination

@@ -209,7 +212,7 @@ fun GalleryNavHost(
}
}

// LLM image to text.
// Ask image.
composable(
route = "${LlmAskImageDestination.route}/{modelName}",
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),

@@ -225,6 +228,23 @@ fun GalleryNavHost(
)
}
}

// Ask audio.
composable(
route = "${LlmAskAudioDestination.route}/{modelName}",
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() },
exitTransition = { slideExit() },
) {
getModelFromNavigationParam(it, TASK_LLM_ASK_AUDIO)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel)

LlmAskAudioScreen(
modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() },
)
}
}
}

// Handle incoming intents for deep links

@@ -256,6 +276,7 @@ fun navigateToTaskScreen(
when (taskType) {
TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
TaskType.LLM_ASK_IMAGE -> navController.navigate("${LlmAskImageDestination.route}/${modelName}")
TaskType.LLM_ASK_AUDIO -> navController.navigate("${LlmAskAudioDestination.route}/${modelName}")
TaskType.LLM_PROMPT_LAB ->
navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
TaskType.TEST_TASK_1 -> {}
@@ -120,6 +120,8 @@ data class CustomColors(
val agentBubbleBgColor: Color = Color.Transparent,
val linkColor: Color = Color.Transparent,
val successColor: Color = Color.Transparent,
val recordButtonBgColor: Color = Color.Transparent,
val waveFormBgColor: Color = Color.Transparent,
)

val LocalCustomColors = staticCompositionLocalOf { CustomColors() }

@@ -145,6 +147,8 @@ val lightCustomColors =
userBubbleBgColor = Color(0xFF32628D),
linkColor = Color(0xFF32628D),
successColor = Color(0xff3d860b),
recordButtonBgColor = Color(0xFFEE675C),
waveFormBgColor = Color(0xFFaaaaaa),
)

val darkCustomColors =

@@ -168,6 +172,8 @@ val darkCustomColors =
userBubbleBgColor = Color(0xFF1f3760),
linkColor = Color(0xFF9DCAFC),
successColor = Color(0xFFA1CE83),
recordButtonBgColor = Color(0xFFEE675C),
waveFormBgColor = Color(0xFFaaaaaa),
)

val MaterialTheme.customColors: CustomColors
@@ -55,6 +55,7 @@ message LlmConfig {
float default_topp = 4;
float default_temperature = 5;
bool support_image = 6;
bool support_audio = 7;
}

message Settings {