mirror of https://github.com/google-ai-edge/gallery.git
synced 2025-07-05 14:10:35 -04:00

Add audio support.

- Add a new task "audio scribe".
- Allow users to record audio clips or pick wav files to interact with the model.
- Add support for importing models with audio capability.
- Fix a typo in the Settings dialog (thanks https://github.com/rhnvrm!).

PiperOrigin-RevId: 774832681

commit d0989adce1 (parent 33c3ee638e)
27 changed files with 1369 additions and 288 deletions
@@ -29,6 +29,7 @@
     <uses-permission android:name="android.permission.FOREGROUND_SERVICE_DATA_SYNC"/>
     <uses-permission android:name="android.permission.INTERNET" />
     <uses-permission android:name="android.permission.POST_NOTIFICATIONS" />
+    <uses-permission android:name="android.permission.RECORD_AUDIO" />
     <uses-permission android:name="android.permission.WAKE_LOCK"/>

     <uses-feature
@@ -25,3 +25,5 @@ interface LatencyProvider {
 data class Classification(val label: String, val score: Float, val color: Color)

 data class JsonObjAndTextContent<T>(val jsonObj: T, val textContent: String)
+
+class AudioClip(val audioData: ByteArray, val sampleRate: Int)
@@ -17,12 +17,17 @@
 package com.google.ai.edge.gallery.common

 import android.content.Context
+import android.net.Uri
 import android.util.Log
+import com.google.ai.edge.gallery.data.SAMPLE_RATE
 import com.google.gson.Gson
 import com.google.gson.reflect.TypeToken
 import java.io.File
 import java.net.HttpURLConnection
 import java.net.URL
+import java.nio.ByteBuffer
+import java.nio.ByteOrder
+import kotlin.math.floor

 data class LaunchInfo(val ts: Long)

@@ -112,3 +117,135 @@ inline fun <reified T> getJsonResponse(url: String): JsonObjAndTextContent<T>? {

   return null
 }
+
+fun convertWavToMonoWithMaxSeconds(
+  context: Context,
+  stereoUri: Uri,
+  maxSeconds: Int = 30,
+): AudioClip? {
+  Log.d(TAG, "Start to convert wav file to mono channel")
+
+  try {
+    val inputStream = context.contentResolver.openInputStream(stereoUri) ?: return null
+    val originalBytes = inputStream.readBytes()
+    inputStream.close()
+
+    // Read WAV header
+    if (originalBytes.size < 44) {
+      // Not a valid WAV file
+      Log.e(TAG, "Not a valid wav file")
+      return null
+    }
+
+    val headerBuffer = ByteBuffer.wrap(originalBytes, 0, 44).order(ByteOrder.LITTLE_ENDIAN)
+    val channels = headerBuffer.getShort(22)
+    var sampleRate = headerBuffer.getInt(24)
+    val bitDepth = headerBuffer.getShort(34)
+    Log.d(TAG, "File metadata: channels: $channels, sampleRate: $sampleRate, bitDepth: $bitDepth")
+
+    // Normalize audio to 16-bit.
+    val audioDataBytes = originalBytes.copyOfRange(fromIndex = 44, toIndex = originalBytes.size)
+    var sixteenBitBytes: ByteArray =
+      if (bitDepth.toInt() == 8) {
+        Log.d(TAG, "Converting 8-bit audio to 16-bit.")
+        convert8BitTo16Bit(audioDataBytes)
+      } else {
+        // Assume 16-bit or other format that can be handled directly
+        audioDataBytes
+      }
+
+    // Convert byte array to short array for processing
+    val shortBuffer =
+      ByteBuffer.wrap(sixteenBitBytes).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer()
+    var pcmSamples = ShortArray(shortBuffer.remaining())
+    shortBuffer.get(pcmSamples)
+
+    // Resample if sample rate is less than 16000 Hz
+    if (sampleRate < SAMPLE_RATE) {
+      Log.d(TAG, "Resampling from $sampleRate Hz to $SAMPLE_RATE Hz.")
+      pcmSamples = resample(pcmSamples, sampleRate, SAMPLE_RATE, channels.toInt())
+      sampleRate = SAMPLE_RATE
+      Log.d(TAG, "Resampling complete. New sample count: ${pcmSamples.size}")
+    }
+
+    // Convert stereo to mono if necessary
+    var monoSamples =
+      if (channels.toInt() == 2) {
+        Log.d(TAG, "Converting stereo to mono.")
+        val mono = ShortArray(pcmSamples.size / 2)
+        for (i in mono.indices) {
+          val left = pcmSamples[i * 2]
+          val right = pcmSamples[i * 2 + 1]
+          mono[i] = ((left + right) / 2).toShort()
+        }
+        mono
+      } else {
+        Log.d(TAG, "Audio is already mono. No channel conversion needed.")
+        pcmSamples
+      }
+
+    // Trim the audio to maxSeconds
+    val maxSamples = maxSeconds * sampleRate
+    if (monoSamples.size > maxSamples) {
+      Log.d(TAG, "Trimming clip from ${monoSamples.size} samples to $maxSamples samples.")
+      monoSamples = monoSamples.copyOfRange(0, maxSamples)
+    }
+
+    val monoByteBuffer = ByteBuffer.allocate(monoSamples.size * 2).order(ByteOrder.LITTLE_ENDIAN)
+    monoByteBuffer.asShortBuffer().put(monoSamples)
+    return AudioClip(audioData = monoByteBuffer.array(), sampleRate = sampleRate)
+  } catch (e: Exception) {
+    Log.e(TAG, "Failed to convert wav to mono", e)
+    return null
+  }
+}
+
+/** Converts 8-bit unsigned PCM audio data to 16-bit signed PCM. */
+private fun convert8BitTo16Bit(eightBitData: ByteArray): ByteArray {
+  // The new 16-bit data will be twice the size
+  val sixteenBitData = ByteArray(eightBitData.size * 2)
+  val buffer = ByteBuffer.wrap(sixteenBitData).order(ByteOrder.LITTLE_ENDIAN)
+
+  for (byte in eightBitData) {
+    // Convert the unsigned 8-bit byte (0-255) to a signed 16-bit short (-32768 to 32767)
+    // 1. Get the unsigned value by masking with 0xFF
+    // 2. Subtract 128 to center the waveform around 0 (range becomes -128 to 127)
+    // 3. Scale by 256 to expand to the 16-bit range
+    val unsignedByte = byte.toInt() and 0xFF
+    val sixteenBitSample = ((unsignedByte - 128) * 256).toShort()
+    buffer.putShort(sixteenBitSample)
+  }
+  return sixteenBitData
+}
+
+/** Resamples PCM audio data from an original sample rate to a target sample rate. */
+private fun resample(
+  inputSamples: ShortArray,
+  originalSampleRate: Int,
+  targetSampleRate: Int,
+  channels: Int,
+): ShortArray {
+  if (originalSampleRate == targetSampleRate) {
+    return inputSamples
+  }
+
+  val ratio = targetSampleRate.toDouble() / originalSampleRate
+  val outputLength = (inputSamples.size * ratio).toInt()
+  val resampledData = ShortArray(outputLength)
+
+  if (channels == 1) { // Mono
+    for (i in resampledData.indices) {
+      val position = i / ratio
+      val index1 = floor(position).toInt()
+      val index2 = index1 + 1
+      val fraction = position - index1
+
+      val sample1 = if (index1 < inputSamples.size) inputSamples[index1].toDouble() else 0.0
+      val sample2 = if (index2 < inputSamples.size) inputSamples[index2].toDouble() else 0.0
+
+      resampledData[i] = (sample1 * (1 - fraction) + sample2 * fraction).toInt().toShort()
+    }
+  }
+
+  return resampledData
+}
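The convertWavToMonoWithMaxSeconds helper above normalizes a picked WAV file to mono 16-bit PCM at no less than 16 kHz and caps it at maxSeconds before it reaches the model, using linear interpolation (sample1 * (1 - fraction) + sample2 * fraction) when resampling. A minimal sketch of how it could be wired to a Compose file picker; the composable name, button label, and onAudioClip callback are illustrative assumptions, not part of this commit:

import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.material3.Button
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.platform.LocalContext
import com.google.ai.edge.gallery.common.AudioClip
import com.google.ai.edge.gallery.common.convertWavToMonoWithMaxSeconds

// Hypothetical call site: let the user pick a .wav file and convert it before sending.
@Composable
fun PickWavButton(onAudioClip: (AudioClip) -> Unit) {
  val context = LocalContext.current
  val launcher =
    rememberLauncherForActivityResult(ActivityResultContracts.GetContent()) { uri ->
      if (uri != null) {
        // Same conversion added in this commit; returns null for invalid or unreadable input.
        convertWavToMonoWithMaxSeconds(context = context, stereoUri = uri, maxSeconds = 30)
          ?.let(onAudioClip)
      }
    }
  Button(onClick = { launcher.launch("audio/*") }) { Text("Pick wav file") }
}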
@@ -50,6 +50,7 @@ enum class ConfigKey(val label: String) {
   DEFAULT_TOPP("Default TopP"),
   DEFAULT_TEMPERATURE("Default temperature"),
   SUPPORT_IMAGE("Support image"),
+  SUPPORT_AUDIO("Support audio"),
   MAX_RESULT_COUNT("Max result count"),
   USE_GPU("Use GPU"),
   ACCELERATOR("Choose accelerator"),
@@ -44,3 +44,12 @@ val DEFAULT_ACCELERATORS = listOf(Accelerator.GPU)
 // Max number of images allowed in a "ask image" session.
 const val MAX_IMAGE_COUNT = 10
+
+// Max number of audio clip in an "ask audio" session.
+const val MAX_AUDIO_CLIP_COUNT = 10
+
+// Max audio clip duration in seconds.
+const val MAX_AUDIO_CLIP_DURATION_SEC = 30
+
+// Audio-recording related consts.
+const val SAMPLE_RATE = 16000
@@ -87,6 +87,9 @@ data class Model(
   /** Whether the LLM model supports image input. */
   val llmSupportImage: Boolean = false,

+  /** Whether the LLM model supports audio input. */
+  val llmSupportAudio: Boolean = false,
+
   /** Whether the model is imported or not. */
   val imported: Boolean = false,
@@ -38,6 +38,7 @@ data class AllowedModel(
   val taskTypes: List<String>,
   val disabled: Boolean? = null,
   val llmSupportImage: Boolean? = null,
+  val llmSupportAudio: Boolean? = null,
   val estimatedPeakMemoryInBytes: Long? = null,
 ) {
   fun toModel(): Model {

@@ -96,6 +97,7 @@ data class AllowedModel(
       showRunAgainButton = showRunAgainButton,
       learnMoreUrl = "https://huggingface.co/${modelId}",
       llmSupportImage = llmSupportImage == true,
+      llmSupportAudio = llmSupportAudio == true,
     )
   }
@@ -17,19 +17,22 @@
 package com.google.ai.edge.gallery.data

 import androidx.annotation.StringRes
+import androidx.compose.material.icons.Icons
+import androidx.compose.material.icons.outlined.Forum
+import androidx.compose.material.icons.outlined.Mic
+import androidx.compose.material.icons.outlined.Mms
+import androidx.compose.material.icons.outlined.Widgets
 import androidx.compose.runtime.MutableState
 import androidx.compose.runtime.mutableLongStateOf
 import androidx.compose.ui.graphics.vector.ImageVector
 import com.google.ai.edge.gallery.R
-import com.google.ai.edge.gallery.ui.icon.Forum
-import com.google.ai.edge.gallery.ui.icon.Mms
-import com.google.ai.edge.gallery.ui.icon.Widgets

 /** Type of task. */
 enum class TaskType(val label: String, val id: String) {
   LLM_CHAT(label = "AI Chat", id = "llm_chat"),
   LLM_PROMPT_LAB(label = "Prompt Lab", id = "llm_prompt_lab"),
   LLM_ASK_IMAGE(label = "Ask Image", id = "llm_ask_image"),
+  LLM_ASK_AUDIO(label = "Audio Scribe", id = "llm_ask_audio"),
   TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
   TEST_TASK_2(label = "Test task 2", id = "test_task_2"),
 }

@@ -71,7 +74,7 @@ data class Task(
 val TASK_LLM_CHAT =
   Task(
     type = TaskType.LLM_CHAT,
-    icon = Forum,
+    icon = Icons.Outlined.Forum,
     models = mutableListOf(),
     description = "Chat with on-device large language models",
     docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",

@@ -83,7 +86,7 @@ val TASK_LLM_CHAT =
 val TASK_LLM_PROMPT_LAB =
   Task(
     type = TaskType.LLM_PROMPT_LAB,
-    icon = Widgets,
+    icon = Icons.Outlined.Widgets,
     models = mutableListOf(),
     description = "Single turn use cases with on-device large language model",
     docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",

@@ -95,7 +98,7 @@ val TASK_LLM_PROMPT_LAB =
 val TASK_LLM_ASK_IMAGE =
   Task(
     type = TaskType.LLM_ASK_IMAGE,
-    icon = Mms,
+    icon = Icons.Outlined.Mms,
     models = mutableListOf(),
     description = "Ask questions about images with on-device large language models",
     docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",

@@ -104,8 +107,23 @@ val TASK_LLM_ASK_IMAGE =
     textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
   )

+val TASK_LLM_ASK_AUDIO =
+  Task(
+    type = TaskType.LLM_ASK_AUDIO,
+    icon = Icons.Outlined.Mic,
+    models = mutableListOf(),
+    // TODO(do not submit)
+    description =
+      "Instantly transcribe and/or translate audio clips using on-device large language models",
+    docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
+    sourceCodeUrl =
+      "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt",
+    textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
+  )
+
 /** All tasks. */
-val TASKS: List<Task> = listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)
+val TASKS: List<Task> =
+  listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_ASK_AUDIO, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)

 fun getModelByName(name: String): Model? {
   for (task in TASKS) {
@@ -21,6 +21,7 @@ import androidx.lifecycle.viewmodel.CreationExtras
 import androidx.lifecycle.viewmodel.initializer
 import androidx.lifecycle.viewmodel.viewModelFactory
 import com.google.ai.edge.gallery.GalleryApplication
+import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioViewModel
 import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageViewModel
 import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel
 import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel

@@ -49,6 +50,9 @@ object ViewModelProvider {

     // Initializer for LlmAskImageViewModel.
     initializer { LlmAskImageViewModel() }
+
+    // Initializer for LlmAskAudioViewModel.
+    initializer { LlmAskAudioViewModel() }
   }
 }
@@ -0,0 +1,344 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.ai.edge.gallery.ui.common.chat
+
+import android.media.AudioAttributes
+import android.media.AudioFormat
+import android.media.AudioTrack
+import android.util.Log
+import androidx.compose.foundation.Canvas
+import androidx.compose.foundation.layout.Row
+import androidx.compose.foundation.layout.height
+import androidx.compose.foundation.layout.padding
+import androidx.compose.foundation.layout.width
+import androidx.compose.material.icons.Icons
+import androidx.compose.material.icons.rounded.PlayArrow
+import androidx.compose.material.icons.rounded.Stop
+import androidx.compose.material3.Icon
+import androidx.compose.material3.IconButton
+import androidx.compose.material3.MaterialTheme
+import androidx.compose.material3.Text
+import androidx.compose.runtime.Composable
+import androidx.compose.runtime.DisposableEffect
+import androidx.compose.runtime.LaunchedEffect
+import androidx.compose.runtime.MutableState
+import androidx.compose.runtime.getValue
+import androidx.compose.runtime.mutableFloatStateOf
+import androidx.compose.runtime.mutableStateOf
+import androidx.compose.runtime.remember
+import androidx.compose.runtime.rememberCoroutineScope
+import androidx.compose.runtime.setValue
+import androidx.compose.ui.Alignment
+import androidx.compose.ui.Modifier
+import androidx.compose.ui.geometry.CornerRadius
+import androidx.compose.ui.geometry.Offset
+import androidx.compose.ui.geometry.Size
+import androidx.compose.ui.geometry.toRect
+import androidx.compose.ui.graphics.BlendMode
+import androidx.compose.ui.graphics.Color
+import androidx.compose.ui.graphics.drawscope.drawIntoCanvas
+import androidx.compose.ui.unit.dp
+import com.google.ai.edge.gallery.data.MAX_AUDIO_CLIP_DURATION_SEC
+import com.google.ai.edge.gallery.ui.theme.customColors
+import java.nio.ByteBuffer
+import java.nio.ByteOrder
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.isActive
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.withContext
+
+private const val TAG = "AGAudioPlaybackPanel"
+private const val BAR_SPACE = 2
+private const val BAR_WIDTH = 2
+private const val MIN_BAR_COUNT = 16
+private const val MAX_BAR_COUNT = 48
+
+/**
+ * A Composable that displays an audio playback panel, including play/stop controls, a waveform
+ * visualization, and the duration of the audio clip.
+ */
+@Composable
+fun AudioPlaybackPanel(
+  audioData: ByteArray,
+  sampleRate: Int,
+  isRecording: Boolean,
+  modifier: Modifier = Modifier,
+  onDarkBg: Boolean = false,
+) {
+  val coroutineScope = rememberCoroutineScope()
+  var isPlaying by remember { mutableStateOf(false) }
+  val audioTrackState = remember { mutableStateOf<AudioTrack?>(null) }
+  val durationInSeconds =
+    remember(audioData) {
+      // PCM 16-bit
+      val bytesPerSample = 2
+      val bytesPerFrame = bytesPerSample * 1 // mono
+      val totalFrames = audioData.size.toDouble() / bytesPerFrame
+      totalFrames / sampleRate
+    }
+  val barCount =
+    remember(durationInSeconds) {
+      val f = durationInSeconds / MAX_AUDIO_CLIP_DURATION_SEC
+      ((MAX_BAR_COUNT - MIN_BAR_COUNT) * f + MIN_BAR_COUNT).toInt()
+    }
+  val amplitudeLevels =
+    remember(audioData) { generateAmplitudeLevels(audioData = audioData, barCount = barCount) }
+  var playbackProgress by remember { mutableFloatStateOf(0f) }
+
+  // Reset when a new recording is started.
+  LaunchedEffect(isRecording) {
+    if (isRecording) {
+      val audioTrack = audioTrackState.value
+      audioTrack?.stop()
+      isPlaying = false
+      playbackProgress = 0f
+    }
+  }
+
+  // Cleanup on Composable Disposal.
+  DisposableEffect(Unit) {
+    onDispose {
+      val audioTrack = audioTrackState.value
+      audioTrack?.stop()
+      audioTrack?.release()
+    }
+  }
+
+  Row(verticalAlignment = Alignment.CenterVertically, modifier = modifier) {
+    // Button to play/stop the clip.
+    IconButton(
+      onClick = {
+        coroutineScope.launch {
+          if (!isPlaying) {
+            isPlaying = true
+            playAudio(
+              audioTrackState = audioTrackState,
+              audioData = audioData,
+              sampleRate = sampleRate,
+              onProgress = { playbackProgress = it },
+              onCompletion = {
+                playbackProgress = 0f
+                isPlaying = false
+              },
+            )
+          } else {
+            stopPlayAudio(audioTrackState = audioTrackState)
+            playbackProgress = 0f
+            isPlaying = false
+          }
+        }
+      }
+    ) {
+      Icon(
+        if (isPlaying) Icons.Rounded.Stop else Icons.Rounded.PlayArrow,
+        contentDescription = "",
+        tint = if (onDarkBg) Color.White else MaterialTheme.colorScheme.primary,
+      )
+    }
+
+    // Visualization
+    AmplitudeBarGraph(
+      amplitudeLevels = amplitudeLevels,
+      progress = playbackProgress,
+      modifier =
+        Modifier.width((barCount * BAR_WIDTH + (barCount - 1) * BAR_SPACE).dp).height(24.dp),
+      onDarkBg = onDarkBg,
+    )
+
+    // Duration
+    Text(
+      "${"%.1f".format(durationInSeconds)}s",
+      style = MaterialTheme.typography.labelLarge,
+      color = if (onDarkBg) Color.White else MaterialTheme.colorScheme.primary,
+      modifier = Modifier.padding(start = 12.dp),
+    )
+  }
+}
+
+@Composable
+private fun AmplitudeBarGraph(
+  amplitudeLevels: List<Float>,
+  progress: Float,
+  modifier: Modifier = Modifier,
+  onDarkBg: Boolean = false,
+) {
+  val barColor = MaterialTheme.customColors.waveFormBgColor
+  val progressColor = if (onDarkBg) Color.White else MaterialTheme.colorScheme.primary
+
+  Canvas(modifier = modifier) {
+    val barCount = amplitudeLevels.size
+    val barWidth = (size.width - BAR_SPACE.dp.toPx() * (barCount - 1)) / barCount
+    val cornerRadius = CornerRadius(x = barWidth, y = barWidth)
+
+    // Use drawIntoCanvas for advanced blend mode operations
+    drawIntoCanvas { canvas ->
+      // 1. Save the current state of the canvas onto a temporary, offscreen layer
+      canvas.saveLayer(size.toRect(), androidx.compose.ui.graphics.Paint())
+
+      // 2. Draw the bars in grey.
+      amplitudeLevels.forEachIndexed { index, level ->
+        val barHeight = (level * size.height).coerceAtLeast(1.5f)
+        val left = index * (barWidth + BAR_SPACE.dp.toPx())
+        drawRoundRect(
+          color = barColor,
+          topLeft = Offset(x = left, y = size.height / 2 - barHeight / 2),
+          size = Size(barWidth, barHeight),
+          cornerRadius = cornerRadius,
+        )
+      }
+
+      // 3. Draw the progress rectangle using BlendMode.SrcIn to only draw where the bars already
+      // exist.
+      val progressWidth = size.width * progress
+      drawRect(
+        color = progressColor,
+        topLeft = Offset.Zero,
+        size = Size(progressWidth, size.height),
+        blendMode = BlendMode.SrcIn,
+      )
+
+      // 4. Restore the layer, merging it onto the main canvas
+      canvas.restore()
+    }
+  }
+}
+
+private suspend fun playAudio(
+  audioTrackState: MutableState<AudioTrack?>,
+  audioData: ByteArray,
+  sampleRate: Int,
+  onProgress: (Float) -> Unit,
+  onCompletion: () -> Unit,
+) {
+  Log.d(TAG, "Start playing audio...")
+
+  try {
+    withContext(Dispatchers.IO) {
+      var lastProgressUpdateMs = 0L
+      audioTrackState.value?.release()
+      val audioTrack =
+        AudioTrack.Builder()
+          .setAudioAttributes(
+            AudioAttributes.Builder()
+              .setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
+              .setUsage(AudioAttributes.USAGE_MEDIA)
+              .build()
+          )
+          .setAudioFormat(
+            AudioFormat.Builder()
+              .setEncoding(AudioFormat.ENCODING_PCM_16BIT)
+              .setSampleRate(sampleRate)
+              .setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
+              .build()
+          )
+          .setTransferMode(AudioTrack.MODE_STATIC)
+          .setBufferSizeInBytes(audioData.size)
+          .build()
+
+      val bytesPerFrame = 2 // For PCM 16-bit Mono
+      val totalFrames = audioData.size / bytesPerFrame
+
+      audioTrackState.value = audioTrack
+      audioTrack.write(audioData, 0, audioData.size)
+      audioTrack.play()
+
+      // Coroutine to monitor progress
+      while (isActive && audioTrack.playState == AudioTrack.PLAYSTATE_PLAYING) {
+        val currentFrames = audioTrack.playbackHeadPosition
+        if (currentFrames >= totalFrames) {
+          break // Exit loop when playback is done
+        }
+        val progress = currentFrames.toFloat() / totalFrames
+        val curMs = System.currentTimeMillis()
+        if (curMs - lastProgressUpdateMs > 30) {
+          onProgress(progress)
+          lastProgressUpdateMs = curMs
+        }
+      }
+
+      if (isActive) {
+        audioTrackState.value?.stop()
+      }
+    }
+  } catch (e: Exception) {
+    // Ignore
+  } finally {
+    onProgress(1f)
+    onCompletion()
+  }
+}
+
+private fun stopPlayAudio(audioTrackState: MutableState<AudioTrack?>) {
+  Log.d(TAG, "Stopping playing audio...")
+
+  val audioTrack = audioTrackState.value
+  audioTrack?.stop()
+  audioTrack?.release()
+  audioTrackState.value = null
+}
+
+/**
+ * Processes a raw PCM 16-bit audio byte array to generate a list of normalized amplitude levels
+ * for visualization.
+ */
+private fun generateAmplitudeLevels(audioData: ByteArray, barCount: Int): List<Float> {
+  if (audioData.isEmpty()) {
+    return List(barCount) { 0f }
+  }
+
+  // 1. Parse bytes into 16-bit short samples (PCM 16-bit)
+  val shortBuffer = ByteBuffer.wrap(audioData).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer()
+  val samples = ShortArray(shortBuffer.remaining())
+  shortBuffer.get(samples)
+
+  if (samples.isEmpty()) {
+    return List(barCount) { 0f }
+  }
+
+  // 2. Determine the size of each chunk
+  val chunkSize = samples.size / barCount
+  val amplitudeLevels = mutableListOf<Float>()
+
+  // 3. Get the max value for each chunk
+  for (i in 0 until barCount) {
+    val chunkStart = i * chunkSize
+    val chunkEnd = (chunkStart + chunkSize).coerceAtMost(samples.size)
+
+    var maxAmplitudeInChunk = 0.0
+
+    for (j in chunkStart until chunkEnd) {
+      val sampleAbs = kotlin.math.abs(samples[j].toDouble())
+      if (sampleAbs > maxAmplitudeInChunk) {
+        maxAmplitudeInChunk = sampleAbs
+      }
+    }
+
+    // 4. Normalize the value (0 to 1)
+    // Short.MAX_VALUE is 32767.0, a good reference for max amplitude
+    val normalizedRms = (maxAmplitudeInChunk / Short.MAX_VALUE).toFloat().coerceIn(0f, 1f)
+    amplitudeLevels.add(normalizedRms)
+  }
+
+  // Normalize the resulting levels so that the max value becomes 0.9.
+  val maxVal = amplitudeLevels.max()
+  if (maxVal == 0f) {
+    return amplitudeLevels
+  }
+  val scaleFactor = 0.9f / maxVal
+  return amplitudeLevels.map { it * scaleFactor }
+}
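A quick sanity check on the duration math in AudioPlaybackPanel: the clips are 16-bit mono PCM, so one frame is 2 bytes and durationInSeconds is simply audioData.size / (2 * sampleRate). At the app's fixed 16 kHz recording rate that is 32,000 bytes per second, so the 30-second cap corresponds to roughly 960,000 bytes of raw PCM (plus the 44-byte header once a clip is exported as WAV).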
@@ -0,0 +1,327 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.ai.edge.gallery.ui.common.chat
+
+import android.annotation.SuppressLint
+import android.content.Context
+import android.media.AudioFormat
+import android.media.AudioRecord
+import android.media.MediaRecorder
+import android.util.Log
+import androidx.compose.foundation.background
+import androidx.compose.foundation.layout.Arrangement
+import androidx.compose.foundation.layout.Box
+import androidx.compose.foundation.layout.Column
+import androidx.compose.foundation.layout.Row
+import androidx.compose.foundation.layout.fillMaxWidth
+import androidx.compose.foundation.layout.height
+import androidx.compose.foundation.layout.padding
+import androidx.compose.foundation.layout.size
+import androidx.compose.foundation.shape.CircleShape
+import androidx.compose.material.icons.Icons
+import androidx.compose.material.icons.rounded.ArrowUpward
+import androidx.compose.material.icons.rounded.Mic
+import androidx.compose.material.icons.rounded.Stop
+import androidx.compose.material3.Icon
+import androidx.compose.material3.IconButton
+import androidx.compose.material3.MaterialTheme
+import androidx.compose.material3.Text
+import androidx.compose.runtime.Composable
+import androidx.compose.runtime.DisposableEffect
+import androidx.compose.runtime.MutableLongState
+import androidx.compose.runtime.MutableState
+import androidx.compose.runtime.derivedStateOf
+import androidx.compose.runtime.getValue
+import androidx.compose.runtime.mutableIntStateOf
+import androidx.compose.runtime.mutableLongStateOf
+import androidx.compose.runtime.mutableStateOf
+import androidx.compose.runtime.remember
+import androidx.compose.runtime.rememberCoroutineScope
+import androidx.compose.runtime.setValue
+import androidx.compose.ui.Alignment
+import androidx.compose.ui.Modifier
+import androidx.compose.ui.draw.clip
+import androidx.compose.ui.graphics.Color
+import androidx.compose.ui.graphics.graphicsLayer
+import androidx.compose.ui.platform.LocalContext
+import androidx.compose.ui.res.painterResource
+import androidx.compose.ui.unit.dp
+import com.google.ai.edge.gallery.R
+import com.google.ai.edge.gallery.data.MAX_AUDIO_CLIP_DURATION_SEC
+import com.google.ai.edge.gallery.data.SAMPLE_RATE
+import com.google.ai.edge.gallery.ui.theme.customColors
+import java.io.ByteArrayOutputStream
+import java.nio.ByteBuffer
+import java.nio.ByteOrder
+import kotlin.math.abs
+import kotlin.math.pow
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.coroutineScope
+import kotlinx.coroutines.launch
+
+private const val TAG = "AGAudioRecorderPanel"
+
+private const val CHANNEL_CONFIG = AudioFormat.CHANNEL_IN_MONO
+private const val AUDIO_FORMAT = AudioFormat.ENCODING_PCM_16BIT
+
+/**
+ * A Composable that provides an audio recording panel. It allows users to record audio clips,
+ * displays the recording duration and a live amplitude visualization, and provides options to play
+ * back the recorded clip or send it.
+ */
+@Composable
+fun AudioRecorderPanel(onSendAudioClip: (ByteArray) -> Unit) {
+  val context = LocalContext.current
+  val coroutineScope = rememberCoroutineScope()
+
+  var isRecording by remember { mutableStateOf(false) }
+  val elapsedMs = remember { mutableLongStateOf(0L) }
+  val audioRecordState = remember { mutableStateOf<AudioRecord?>(null) }
+  val audioStream = remember { ByteArrayOutputStream() }
+  val recordedBytes = remember { mutableStateOf<ByteArray?>(null) }
+  var currentAmplitude by remember { mutableIntStateOf(0) }
+
+  val elapsedSeconds by remember {
+    derivedStateOf { "%.1f".format(elapsedMs.value.toFloat() / 1000f) }
+  }
+
+  // Cleanup on Composable Disposal.
+  DisposableEffect(Unit) { onDispose { audioRecordState.value?.release() } }
+
+  Column(modifier = Modifier.padding(bottom = 12.dp)) {
+    // Title bar.
+    Row(
+      verticalAlignment = Alignment.CenterVertically,
+      horizontalArrangement = Arrangement.SpaceBetween,
+    ) {
+      // Logo and state.
+      Row(
+        modifier = Modifier.padding(start = 16.dp),
+        horizontalArrangement = Arrangement.spacedBy(12.dp),
+      ) {
+        Icon(
+          painterResource(R.drawable.logo),
+          modifier = Modifier.size(20.dp),
+          contentDescription = "",
+          tint = Color.Unspecified,
+        )
+        Text(
+          "Record audio clip (up to $MAX_AUDIO_CLIP_DURATION_SEC seconds)",
+          style = MaterialTheme.typography.labelLarge,
+        )
+      }
+    }
+
+    // Recorded clip.
+    Row(
+      modifier = Modifier.fillMaxWidth().padding(vertical = 12.dp).height(40.dp),
+      verticalAlignment = Alignment.CenterVertically,
+      horizontalArrangement = Arrangement.Center,
+    ) {
+      val curRecordedBytes = recordedBytes.value
+      if (curRecordedBytes == null) {
+        // Info message when there is no recorded clip and the recording has not started yet.
+        if (!isRecording) {
+          Text(
+            "Tap the record button to start",
+            style = MaterialTheme.typography.labelLarge,
+            color = MaterialTheme.colorScheme.onSurfaceVariant,
+          )
+        }
+        // Visualization for clip being recorded.
+        else {
+          Row(
+            horizontalArrangement = Arrangement.spacedBy(12.dp),
+            verticalAlignment = Alignment.CenterVertically,
+          ) {
+            Box(
+              modifier =
+                Modifier.size(8.dp)
+                  .background(MaterialTheme.customColors.recordButtonBgColor, CircleShape)
+            )
+            Text("$elapsedSeconds s")
+          }
+        }
+      }
+      // Controls for recorded clip.
+      else {
+        Row {
+          // Clip player.
+          AudioPlaybackPanel(
+            audioData = curRecordedBytes,
+            sampleRate = SAMPLE_RATE,
+            isRecording = isRecording,
+          )
+
+          // Button to send the clip
+          IconButton(onClick = { onSendAudioClip(curRecordedBytes) }) {
+            Icon(Icons.Rounded.ArrowUpward, contentDescription = "")
+          }
+        }
+      }
+    }
+
+    // Buttons
+    Box(contentAlignment = Alignment.Center, modifier = Modifier.fillMaxWidth().height(40.dp)) {
+      // Visualization of the current amplitude.
+      if (isRecording) {
+        // Normalize the amplitude (0-32767) to a fraction (0.0-1.0)
+        // We use a power scale (exponent < 1) to make the pulse more visible for lower volumes.
+        val normalizedAmplitude = (currentAmplitude.toFloat() / 32767f).pow(0.35f)
+        // Define the min and max size of the circle
+        val minSize = 38.dp
+        val maxSize = 100.dp
+
+        // Map the normalized amplitude to our size range
+        val scale by
+          remember(normalizedAmplitude) {
+            derivedStateOf { (minSize + (maxSize - minSize) * normalizedAmplitude) / minSize }
+          }
+        Box(
+          modifier =
+            Modifier.size(minSize)
+              .graphicsLayer(scaleX = scale, scaleY = scale, clip = false, alpha = 0.3f)
+              .background(MaterialTheme.customColors.recordButtonBgColor, CircleShape)
+        )
+      }
+
+      // Record/stop button.
+      IconButton(
+        onClick = {
+          coroutineScope.launch {
+            if (!isRecording) {
+              isRecording = true
+              recordedBytes.value = null
+              startRecording(
+                context = context,
+                audioRecordState = audioRecordState,
+                audioStream = audioStream,
+                elapsedMs = elapsedMs,
+                onAmplitudeChanged = { currentAmplitude = it },
+                onMaxDurationReached = {
+                  val curRecordedBytes =
+                    stopRecording(audioRecordState = audioRecordState, audioStream = audioStream)
+                  recordedBytes.value = curRecordedBytes
+                  isRecording = false
+                },
+              )
+            } else {
+              val curRecordedBytes =
+                stopRecording(audioRecordState = audioRecordState, audioStream = audioStream)
+              recordedBytes.value = curRecordedBytes
+              isRecording = false
+            }
+          }
+        },
+        modifier =
+          Modifier.clip(CircleShape).background(MaterialTheme.customColors.recordButtonBgColor),
+      ) {
+        Icon(
+          if (isRecording) Icons.Rounded.Stop else Icons.Rounded.Mic,
+          contentDescription = "",
+          tint = Color.White,
+        )
+      }
+    }
+  }
+}
+
+// Permission is checked in parent composable.
+@SuppressLint("MissingPermission")
+private suspend fun startRecording(
+  context: Context,
+  audioRecordState: MutableState<AudioRecord?>,
+  audioStream: ByteArrayOutputStream,
+  elapsedMs: MutableLongState,
+  onAmplitudeChanged: (Int) -> Unit,
+  onMaxDurationReached: () -> Unit,
+) {
+  Log.d(TAG, "Start recording...")
+  val minBufferSize = AudioRecord.getMinBufferSize(SAMPLE_RATE, CHANNEL_CONFIG, AUDIO_FORMAT)
+
+  audioRecordState.value?.release()
+  val recorder =
+    AudioRecord(
+      MediaRecorder.AudioSource.MIC,
+      SAMPLE_RATE,
+      CHANNEL_CONFIG,
+      AUDIO_FORMAT,
+      minBufferSize,
+    )
+
+  audioRecordState.value = recorder
+  val buffer = ByteArray(minBufferSize)
+
+  // The function will only return when the recording is done (when stopRecording is called).
+  coroutineScope {
+    launch(Dispatchers.IO) {
+      recorder.startRecording()
+
+      val startMs = System.currentTimeMillis()
+      elapsedMs.value = 0L
+      while (audioRecordState.value?.recordingState == AudioRecord.RECORDSTATE_RECORDING) {
+        val bytesRead = recorder.read(buffer, 0, buffer.size)
+        if (bytesRead > 0) {
+          val currentAmplitude = calculatePeakAmplitude(buffer = buffer, bytesRead = bytesRead)
+          onAmplitudeChanged(currentAmplitude)
+          audioStream.write(buffer, 0, bytesRead)
+        }
+        elapsedMs.value = System.currentTimeMillis() - startMs
+        if (elapsedMs.value >= MAX_AUDIO_CLIP_DURATION_SEC * 1000) {
+          onMaxDurationReached()
+          break
+        }
+      }
+    }
+  }
+}
+
+private fun stopRecording(
+  audioRecordState: MutableState<AudioRecord?>,
+  audioStream: ByteArrayOutputStream,
+): ByteArray {
+  Log.d(TAG, "Stopping recording...")
+
+  val recorder = audioRecordState.value
+  if (recorder?.recordingState == AudioRecord.RECORDSTATE_RECORDING) {
+    recorder.stop()
+  }
+  recorder?.release()
+  audioRecordState.value = null
+
+  val recordedBytes = audioStream.toByteArray()
+  audioStream.reset()
+  Log.d(TAG, "Stopped. Recorded ${recordedBytes.size} bytes.")
+
+  return recordedBytes
+}
+
+private fun calculatePeakAmplitude(buffer: ByteArray, bytesRead: Int): Int {
+  // Wrap the byte array in a ByteBuffer and set the order to little-endian
+  val shortBuffer =
+    ByteBuffer.wrap(buffer, 0, bytesRead).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer()
+
+  var maxAmplitude = 0
+  // Iterate through the short buffer to find the maximum absolute value
+  while (shortBuffer.hasRemaining()) {
+    val currentSample = abs(shortBuffer.get().toInt())
+    if (currentSample > maxAmplitude) {
+      maxAmplitude = currentSample
+    }
+  }
+  return maxAmplitude
+}
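AudioRecorderPanel deliberately leaves the RECORD_AUDIO runtime-permission check to its caller (hence the @SuppressLint("MissingPermission") annotation on startRecording). A minimal sketch of what such a caller could look like; the wrapper composable name and button label are assumptions for illustration, and only AudioRecorderPanel itself comes from this commit:

import android.Manifest
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.material3.Button
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue

// Hypothetical wrapper: request RECORD_AUDIO first, then show the recorder panel.
@Composable
fun RecordAudioClipSection(onSendAudioClip: (ByteArray) -> Unit) {
  var granted by remember { mutableStateOf(false) }
  val permissionLauncher =
    rememberLauncherForActivityResult(ActivityResultContracts.RequestPermission()) { ok ->
      granted = ok
    }
  if (granted) {
    // The panel added in this commit; it records 16 kHz mono 16-bit PCM and reports it here.
    AudioRecorderPanel(onSendAudioClip = onSendAudioClip)
  } else {
    Button(onClick = { permissionLauncher.launch(Manifest.permission.RECORD_AUDIO) }) {
      Text("Allow microphone access")
    }
  }
}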
@@ -17,18 +17,22 @@
 package com.google.ai.edge.gallery.ui.common.chat

 import android.graphics.Bitmap
+import android.util.Log
 import androidx.compose.ui.graphics.ImageBitmap
 import androidx.compose.ui.unit.Dp
 import com.google.ai.edge.gallery.common.Classification
 import com.google.ai.edge.gallery.data.Model
 import com.google.ai.edge.gallery.data.PromptTemplate

+private const val TAG = "AGChatMessage"
+
 enum class ChatMessageType {
   INFO,
   WARNING,
   TEXT,
   IMAGE,
   IMAGE_WITH_HISTORY,
+  AUDIO_CLIP,
   LOADING,
   CLASSIFICATION,
   CONFIG_VALUES_CHANGE,

@@ -121,6 +125,90 @@ class ChatMessageImage(
   }
 }

+/** Chat message for audio clip. */
+class ChatMessageAudioClip(
+  val audioData: ByteArray,
+  val sampleRate: Int,
+  override val side: ChatSide,
+  override val latencyMs: Float = 0f,
+) : ChatMessage(type = ChatMessageType.AUDIO_CLIP, side = side, latencyMs = latencyMs) {
+  override fun clone(): ChatMessageAudioClip {
+    return ChatMessageAudioClip(
+      audioData = audioData,
+      sampleRate = sampleRate,
+      side = side,
+      latencyMs = latencyMs,
+    )
+  }
+
+  fun genByteArrayForWav(): ByteArray {
+    val header = ByteArray(44)
+
+    val pcmDataSize = audioData.size
+    val wavFileSize = pcmDataSize + 44 // 44 bytes for the header
+    val channels = 1 // Mono
+    val bitsPerSample: Short = 16
+    val byteRate = sampleRate * channels * bitsPerSample / 8
+    Log.d(TAG, "Wav metadata: sampleRate: $sampleRate")
+
+    // RIFF/WAVE header
+    header[0] = 'R'.code.toByte()
+    header[1] = 'I'.code.toByte()
+    header[2] = 'F'.code.toByte()
+    header[3] = 'F'.code.toByte()
+    header[4] = (wavFileSize and 0xff).toByte()
+    header[5] = (wavFileSize shr 8 and 0xff).toByte()
+    header[6] = (wavFileSize shr 16 and 0xff).toByte()
+    header[7] = (wavFileSize shr 24 and 0xff).toByte()
+    header[8] = 'W'.code.toByte()
+    header[9] = 'A'.code.toByte()
+    header[10] = 'V'.code.toByte()
+    header[11] = 'E'.code.toByte()
+    header[12] = 'f'.code.toByte()
+    header[13] = 'm'.code.toByte()
+    header[14] = 't'.code.toByte()
+    header[15] = ' '.code.toByte()
+    header[16] = 16
+    header[17] = 0
+    header[18] = 0
+    header[19] = 0 // Sub-chunk size (16 for PCM)
+    header[20] = 1
+    header[21] = 0 // Audio format (1 for PCM)
+    header[22] = channels.toByte()
+    header[23] = 0 // Number of channels
+    header[24] = (sampleRate and 0xff).toByte()
+    header[25] = (sampleRate shr 8 and 0xff).toByte()
+    header[26] = (sampleRate shr 16 and 0xff).toByte()
+    header[27] = (sampleRate shr 24 and 0xff).toByte()
+    header[28] = (byteRate and 0xff).toByte()
+    header[29] = (byteRate shr 8 and 0xff).toByte()
+    header[30] = (byteRate shr 16 and 0xff).toByte()
+    header[31] = (byteRate shr 24 and 0xff).toByte()
+    header[32] = (channels * bitsPerSample / 8).toByte()
+    header[33] = 0 // Block align
+    header[34] = bitsPerSample.toByte()
+    header[35] = (bitsPerSample.toInt() shr 8 and 0xff).toByte() // Bits per sample
+    header[36] = 'd'.code.toByte()
+    header[37] = 'a'.code.toByte()
+    header[38] = 't'.code.toByte()
+    header[39] = 'a'.code.toByte()
+    header[40] = (pcmDataSize and 0xff).toByte()
+    header[41] = (pcmDataSize shr 8 and 0xff).toByte()
+    header[42] = (pcmDataSize shr 16 and 0xff).toByte()
+    header[43] = (pcmDataSize shr 24 and 0xff).toByte()
+
+    return header + audioData
+  }
+
+  fun getDurationInSeconds(): Float {
+    // PCM 16-bit
+    val bytesPerSample = 2
+    val bytesPerFrame = bytesPerSample * 1 // mono
+    val totalFrames = audioData.size.toFloat() / bytesPerFrame
+    return totalFrames / sampleRate
+  }
+}
+
 /** Chat message for images with history. */
 class ChatMessageImageWithHistory(
   val bitmaps: List<Bitmap>,
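genByteArrayForWav above writes the 44-byte RIFF/WAVE header by hand, one little-endian byte at a time. The same header can be expressed more compactly with a little-endian ByteBuffer; the sketch below makes the same assumptions (PCM, 16-bit, mono) and is illustrative rather than the code this commit ships. One detail worth noting: the RIFF chunk-size field at offset 4 is conventionally the total file size minus 8 (pcmDataSize + 36), whereas the commit writes the full file size there; most decoders accept either.

import java.nio.ByteBuffer
import java.nio.ByteOrder

// Sketch: build an equivalent 44-byte WAV header with ByteBuffer instead of indexed byte writes.
fun wavHeader(pcmDataSize: Int, sampleRate: Int): ByteArray {
  val channels: Short = 1        // mono, as in the commit
  val bitsPerSample: Short = 16  // PCM 16-bit
  val pcmFormat: Short = 1       // audio format 1 = PCM
  val byteRate = sampleRate * channels * bitsPerSample / 8
  val blockAlign = (channels * bitsPerSample / 8).toShort()
  return ByteBuffer.allocate(44)
    .order(ByteOrder.LITTLE_ENDIAN)
    .apply {
      put("RIFF".toByteArray(Charsets.US_ASCII))
      putInt(36 + pcmDataSize) // RIFF chunk size: total file size minus 8
      put("WAVE".toByteArray(Charsets.US_ASCII))
      put("fmt ".toByteArray(Charsets.US_ASCII))
      putInt(16)               // fmt sub-chunk size for PCM
      putShort(pcmFormat)
      putShort(channels)
      putInt(sampleRate)
      putInt(byteRate)
      putShort(blockAlign)
      putShort(bitsPerSample)
      put("data".toByteArray(Charsets.US_ASCII))
      putInt(pcmDataSize)
    }
    .array()
}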
@@ -137,6 +137,19 @@ fun ChatPanel(
       }
       imageMessageCount
     }
+  val audioClipMesssageCountToLastconfigChange =
+    remember(messages) {
+      var audioClipMessageCount = 0
+      for (message in messages.reversed()) {
+        if (message is ChatMessageConfigValuesChange) {
+          break
+        }
+        if (message is ChatMessageAudioClip) {
+          audioClipMessageCount++
+        }
+      }
+      audioClipMessageCount
+    }

   var curMessage by remember { mutableStateOf("") } // Correct state
   val focusManager = LocalFocusManager.current

@@ -342,6 +355,9 @@ fun ChatPanel(
                   imageHistoryCurIndex = imageHistoryCurIndex,
                 )

+              // Audio clip.
+              is ChatMessageAudioClip -> MessageBodyAudioClip(message = message)
+
               // Classification result
               is ChatMessageClassification ->
                 MessageBodyClassification(

@@ -467,6 +483,22 @@ fun ChatPanel(
             )
           }
         }
+        // Show an info message for the ask audio task to get users started.
+        else if (task.type == TaskType.LLM_ASK_AUDIO && messages.isEmpty()) {
+          Column(
+            modifier = Modifier.padding(horizontal = 16.dp).fillMaxSize(),
+            horizontalAlignment = Alignment.CenterHorizontally,
+            verticalArrangement = Arrangement.Center,
+          ) {
+            MessageBodyInfo(
+              ChatMessageInfo(
+                content =
+                  "To get started, tap the + icon to add your audio clips. You can add up to 10 clips, each up to 30 seconds long."
+              ),
+              smallFontSize = false,
+            )
+          }
+        }
       }

       // Chat input

@@ -482,6 +514,7 @@ fun ChatPanel(
         isResettingSession = uiState.isResettingSession,
         modelPreparing = uiState.preparing,
         imageMessageCount = imageMessageCountToLastConfigChange,
+        audioClipMessageCount = audioClipMesssageCountToLastconfigChange,
         modelInitializing =
           modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
         textFieldPlaceHolderRes = task.textInputPlaceHolderRes,

@@ -504,7 +537,10 @@ fun ChatPanel(
         onStopButtonClicked = onStopButtonClicked,
         // showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen,
         showPromptTemplatesInMenu = false,
-        showImagePickerInMenu = selectedModel.llmSupportImage,
+        showImagePickerInMenu =
+          selectedModel.llmSupportImage && task.type === TaskType.LLM_ASK_IMAGE,
+        showAudioItemsInMenu =
+          selectedModel.llmSupportAudio && task.type === TaskType.LLM_ASK_AUDIO,
         showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress,
       )
     }
@@ -0,0 +1,33 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.ai.edge.gallery.ui.common.chat
+
+import androidx.compose.foundation.layout.padding
+import androidx.compose.runtime.Composable
+import androidx.compose.ui.Modifier
+import androidx.compose.ui.unit.dp
+
+@Composable
+fun MessageBodyAudioClip(message: ChatMessageAudioClip, modifier: Modifier = Modifier) {
+  AudioPlaybackPanel(
+    audioData = message.audioData,
+    sampleRate = message.sampleRate,
+    isRecording = false,
+    modifier = Modifier.padding(end = 16.dp),
+    onDarkBg = true,
+  )
+}
@@ -21,6 +21,7 @@ package com.google.ai.edge.gallery.ui.common.chat
 // import com.google.ai.edge.gallery.ui.theme.GalleryTheme
 import android.Manifest
 import android.content.Context
+import android.content.Intent
 import android.content.pm.PackageManager
 import android.graphics.Bitmap
 import android.graphics.BitmapFactory
@@ -65,9 +66,11 @@ import androidx.compose.foundation.shape.RoundedCornerShape
 import androidx.compose.material.icons.Icons
 import androidx.compose.material.icons.automirrored.rounded.Send
 import androidx.compose.material.icons.rounded.Add
+import androidx.compose.material.icons.rounded.AudioFile
 import androidx.compose.material.icons.rounded.Close
 import androidx.compose.material.icons.rounded.FlipCameraAndroid
 import androidx.compose.material.icons.rounded.History
+import androidx.compose.material.icons.rounded.Mic
 import androidx.compose.material.icons.rounded.Photo
 import androidx.compose.material.icons.rounded.PhotoCamera
 import androidx.compose.material.icons.rounded.PostAdd
@@ -107,9 +110,14 @@ import androidx.compose.ui.unit.dp
 import androidx.compose.ui.viewinterop.AndroidView
 import androidx.core.content.ContextCompat
 import androidx.lifecycle.compose.LocalLifecycleOwner
+import com.google.ai.edge.gallery.common.AudioClip
+import com.google.ai.edge.gallery.common.convertWavToMonoWithMaxSeconds
+import com.google.ai.edge.gallery.data.MAX_AUDIO_CLIP_COUNT
 import com.google.ai.edge.gallery.data.MAX_IMAGE_COUNT
+import com.google.ai.edge.gallery.data.SAMPLE_RATE
 import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
 import java.util.concurrent.Executors
+import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.launch

 private const val TAG = "AGMessageInputText"

@@ -128,6 +136,7 @@ fun MessageInputText(
   isResettingSession: Boolean,
   inProgress: Boolean,
   imageMessageCount: Int,
+  audioClipMessageCount: Int,
   modelInitializing: Boolean,
   @StringRes textFieldPlaceHolderRes: Int,
   onValueChanged: (String) -> Unit,
@@ -137,6 +146,7 @@ fun MessageInputText(
   onStopButtonClicked: () -> Unit = {},
   showPromptTemplatesInMenu: Boolean = false,
   showImagePickerInMenu: Boolean = false,
+  showAudioItemsInMenu: Boolean = false,
   showStopButtonWhenInProgress: Boolean = false,
 ) {
   val context = LocalContext.current
@@ -146,7 +156,12 @@ fun MessageInputText(
   var showTextInputHistorySheet by remember { mutableStateOf(false) }
   var showCameraCaptureBottomSheet by remember { mutableStateOf(false) }
   val cameraCaptureSheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
+  var showAudioRecorderBottomSheet by remember { mutableStateOf(false) }
+  val audioRecorderSheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
   var pickedImages by remember { mutableStateOf<List<Bitmap>>(listOf()) }
+  var pickedAudioClips by remember { mutableStateOf<List<AudioClip>>(listOf()) }
+  var hasFrontCamera by remember { mutableStateOf(false) }

   val updatePickedImages: (List<Bitmap>) -> Unit = { bitmaps ->
     var newPickedImages: MutableList<Bitmap> = mutableListOf()
     newPickedImages.addAll(pickedImages)

@@ -156,7 +171,16 @@ fun MessageInputText(
     }
     pickedImages = newPickedImages.toList()
   }
-  var hasFrontCamera by remember { mutableStateOf(false) }
+
+  val updatePickedAudioClips: (List<AudioClip>) -> Unit = { audioDataList ->
+    var newAudioDataList: MutableList<AudioClip> = mutableListOf()
+    newAudioDataList.addAll(pickedAudioClips)
+    newAudioDataList.addAll(audioDataList)
+    if (newAudioDataList.size > MAX_AUDIO_CLIP_COUNT) {
+      newAudioDataList = newAudioDataList.subList(fromIndex = 0, toIndex = MAX_AUDIO_CLIP_COUNT)
+    }
+    pickedAudioClips = newAudioDataList.toList()
+  }

   LaunchedEffect(Unit) { checkFrontCamera(context = context, callback = { hasFrontCamera = it }) }

@@ -170,6 +194,16 @@ fun MessageInputText(
     }
   }

+  // Permission request when recording audio clips.
+  val recordAudioClipsPermissionLauncher =
+    rememberLauncherForActivityResult(ActivityResultContracts.RequestPermission()) {
+      permissionGranted ->
+      if (permissionGranted) {
+        showAddContentMenu = false
+        showAudioRecorderBottomSheet = true
+      }
+    }
+
   // Registers a photo picker activity launcher in single-select mode.
   val pickMedia =
     rememberLauncherForActivityResult(ActivityResultContracts.PickMultipleVisualMedia()) { uris ->

@@ -184,9 +218,31 @@ fun MessageInputText(
       }
     }

+  val pickWav =
+    rememberLauncherForActivityResult(
+      contract = ActivityResultContracts.StartActivityForResult()
+    ) { result ->
+      if (result.resultCode == android.app.Activity.RESULT_OK) {
+        result.data?.data?.let { uri ->
+          Log.d(TAG, "Picked wav file: $uri")
+          scope.launch(Dispatchers.IO) {
+            convertWavToMonoWithMaxSeconds(context = context, stereoUri = uri)?.let { audioClip ->
+              updatePickedAudioClips(
+                listOf(
+                  AudioClip(audioData = audioClip.audioData, sampleRate = audioClip.sampleRate)
+                )
+              )
+            }
+          }
+        }
+      } else {
+        Log.d(TAG, "Wav picking cancelled.")
+      }
+    }
+
   Column {
-    // A preview panel for the selected image.
-    if (pickedImages.isNotEmpty()) {
+    // A preview panel for the selected images and audio clips.
+    if (pickedImages.isNotEmpty() || pickedAudioClips.isNotEmpty()) {
       Row(
         modifier =
           Modifier.offset(x = 16.dp).fillMaxWidth().horizontalScroll(rememberScrollState()),

@@ -203,20 +259,30 @@ fun MessageInputText(
                 .clip(RoundedCornerShape(8.dp))
                 .border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp)),
             )
+            MediaPanelCloseButton { pickedImages = pickedImages.filter { image != it } }
+          }
+        }
+
+        for ((index, audioClip) in pickedAudioClips.withIndex()) {
+          Box(contentAlignment = Alignment.TopEnd) {
             Box(
               modifier =
-                Modifier.offset(x = 10.dp, y = (-10).dp)
-                  .clip(CircleShape)
+                Modifier.shadow(2.dp, shape = RoundedCornerShape(8.dp))
+                  .clip(RoundedCornerShape(8.dp))
                   .background(MaterialTheme.colorScheme.surface)
-                  .border((1.5).dp, MaterialTheme.colorScheme.outline, CircleShape)
-                  .clickable { pickedImages = pickedImages.filter { image != it } }
+                  .border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp))
             ) {
-              Icon(
-                Icons.Rounded.Close,
-                contentDescription = "",
-                modifier = Modifier.padding(3.dp).size(16.dp),
+              AudioPlaybackPanel(
+                audioData = audioClip.audioData,
+                sampleRate = audioClip.sampleRate,
+                isRecording = false,
+                modifier = Modifier.padding(end = 16.dp),
               )
             }
+            MediaPanelCloseButton {
+              pickedAudioClips =
+                pickedAudioClips.filterIndexed { curIndex, curAudioData -> curIndex != index }
+            }
           }
         }
       }

@@ -239,10 +305,13 @@ fun MessageInputText(
         verticalAlignment = Alignment.CenterVertically,
       ) {
         val enableAddImageMenuItems = (imageMessageCount + pickedImages.size) < MAX_IMAGE_COUNT
+        val enableRecordAudioClipMenuItems =
+          (audioClipMessageCount + pickedAudioClips.size) < MAX_AUDIO_CLIP_COUNT
         DropdownMenu(
           expanded = showAddContentMenu,
           onDismissRequest = { showAddContentMenu = false },
         ) {
+          // Image related menu items.
           if (showImagePickerInMenu) {
             // Take a picture.
             DropdownMenuItem(

@@ -295,6 +364,70 @@ fun MessageInputText(
             )
           }

+          // Audio related menu items.
+          if (showAudioItemsInMenu) {
+            DropdownMenuItem(
+              text = {
+                Row(
+                  verticalAlignment = Alignment.CenterVertically,
+                  horizontalArrangement = Arrangement.spacedBy(6.dp),
+                ) {
+                  Icon(Icons.Rounded.Mic, contentDescription = "")
+                  Text("Record audio clip")
+                }
+              },
+              enabled = enableRecordAudioClipMenuItems,
+              onClick = {
+                // Check permission
+                when (PackageManager.PERMISSION_GRANTED) {
+                  // Already got permission. Call the lambda.
+                  ContextCompat.checkSelfPermission(context, Manifest.permission.RECORD_AUDIO) -> {
+                    showAddContentMenu = false
+                    showAudioRecorderBottomSheet = true
+                  }
+
+                  // Otherwise, ask for permission
+                  else -> {
+                    recordAudioClipsPermissionLauncher.launch(Manifest.permission.RECORD_AUDIO)
+                  }
+                }
+              },
+            )
+
+            DropdownMenuItem(
+              text = {
+                Row(
+                  verticalAlignment = Alignment.CenterVertically,
+                  horizontalArrangement = Arrangement.spacedBy(6.dp),
+                ) {
+                  Icon(Icons.Rounded.AudioFile, contentDescription = "")
+                  Text("Pick wav file")
+                }
+              },
+              enabled = enableRecordAudioClipMenuItems,
+              onClick = {
+                showAddContentMenu = false
+
+                // Show file picker.
+                val intent =
+                  Intent(Intent.ACTION_GET_CONTENT).apply {
+                    addCategory(Intent.CATEGORY_OPENABLE)
+                    type = "audio/*"
+
+                    // Provide a list of more specific MIME types to filter for.
+                    val mimeTypes = arrayOf("audio/wav", "audio/x-wav")
+                    putExtra(Intent.EXTRA_MIME_TYPES, mimeTypes)
+
+                    // Single select.
+                    putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false)
+                      .addFlags(Intent.FLAG_GRANT_PERSISTABLE_URI_PERMISSION)
+                      .addFlags(Intent.FLAG_GRANT_READ_URI_PERMISSION)
+                  }
+                pickWav.launch(intent)
+              },
+            )
+          }
+
           // Prompt templates.
           if (showPromptTemplatesInMenu) {
             DropdownMenuItem(

@@ -369,15 +502,22 @@ fun MessageInputText(
               )
             }
           }
-        } // Send button. Only shown when text is not empty.
-        else if (curMessage.isNotEmpty()) {
+        }
+        // Send button. Only shown when text is not empty, or there is at least one recorded
+        // audio clip.
+        else if (curMessage.isNotEmpty() || pickedAudioClips.isNotEmpty()) {
           IconButton(
             enabled = !inProgress && !isResettingSession,
             onClick = {
               onSendMessage(
-                createMessagesToSend(pickedImages = pickedImages, text = curMessage.trim())
+                createMessagesToSend(
+                  pickedImages = pickedImages,
+                  audioClips = pickedAudioClips,
+                  text = curMessage.trim(),
+                )
               )
               pickedImages = listOf()
+              pickedAudioClips = listOf()
             },
             colors =
               IconButtonDefaults.iconButtonColors(

@@ -403,8 +543,15 @@ fun MessageInputText(
       history = modelManagerUiState.textInputHistory,
       onDismissed = { showTextInputHistorySheet = false },
       onHistoryItemClicked = { item ->
-        onSendMessage(createMessagesToSend(pickedImages = pickedImages, text = item))
+        onSendMessage(
+          createMessagesToSend(
+            pickedImages = pickedImages,
+            audioClips = pickedAudioClips,
+            text = item,
+          )
+        )
         pickedImages = listOf()
+        pickedAudioClips = listOf()
         modelManagerViewModel.promoteTextInputHistoryItem(item)
       },
       onHistoryItemDeleted = { item -> modelManagerViewModel.deleteTextInputHistory(item) },

@@ -582,6 +729,43 @@ fun MessageInputText(
       }
     }
   }
+
+  if (showAudioRecorderBottomSheet) {
+    ModalBottomSheet(
+      sheetState = audioRecorderSheetState,
+      onDismissRequest = { showAudioRecorderBottomSheet = false },
+    ) {
+      AudioRecorderPanel(
+        onSendAudioClip = { audioData ->
+          scope.launch {
+            updatePickedAudioClips(
+              listOf(AudioClip(audioData = audioData, sampleRate = SAMPLE_RATE))
+            )
+            audioRecorderSheetState.hide()
+            showAudioRecorderBottomSheet = false
+          }
+        }
+      )
+    }
+  }
+}
+
+@Composable
+private fun MediaPanelCloseButton(onClicked: () -> Unit) {
+  Box(
+    modifier =
+      Modifier.offset(x = 10.dp, y = (-10).dp)
+        .clip(CircleShape)
+        .background(MaterialTheme.colorScheme.surface)
+        .border((1.5).dp, MaterialTheme.colorScheme.outline, CircleShape)
+        .clickable { onClicked() }
+  ) {
+    Icon(
+      Icons.Rounded.Close,
+      contentDescription = "",
+      modifier = Modifier.padding(3.dp).size(16.dp),
+    )
+  }
 }

 private fun handleImagesSelected(

@@ -641,20 +825,50 @@ private fun checkFrontCamera(context: Context, callback: (Boolean) -> Unit) {
   )
 }

-private fun createMessagesToSend(pickedImages: List<Bitmap>, text: String): List<ChatMessage> {
+private fun createMessagesToSend(
+  pickedImages: List<Bitmap>,
+  audioClips: List<AudioClip>,
+  text: String,
+): List<ChatMessage> {
   var messages: MutableList<ChatMessage> = mutableListOf()
+
+  // Add image messages.
+  var imageMessages: MutableList<ChatMessageImage> = mutableListOf()
   if (pickedImages.isNotEmpty()) {
     for (image in pickedImages) {
-      messages.add(
+      imageMessages.add(
         ChatMessageImage(bitmap = image, imageBitMap = image.asImageBitmap(), side = ChatSide.USER)
       )
     }
   }
   // Cap the number of image messages.
-  if (messages.size > MAX_IMAGE_COUNT) {
-    messages = messages.subList(fromIndex = 0, toIndex = MAX_IMAGE_COUNT)
+  if (imageMessages.size > MAX_IMAGE_COUNT) {
+    imageMessages = imageMessages.subList(fromIndex = 0, toIndex = MAX_IMAGE_COUNT)
+  }
+  messages.addAll(imageMessages)
+
+  // Add audio messages.
+  var audioMessages: MutableList<ChatMessageAudioClip> = mutableListOf()
+  if (audioClips.isNotEmpty()) {
+    for (audioClip in audioClips) {
+      audioMessages.add(
+        ChatMessageAudioClip(
+          audioData = audioClip.audioData,
+          sampleRate = audioClip.sampleRate,
+          side = ChatSide.USER,
+        )
+      )
+    }
+  }
+  // Cap the number of audio messages.
+  if (audioMessages.size > MAX_AUDIO_CLIP_COUNT) {
+    audioMessages = audioMessages.subList(fromIndex = 0, toIndex = MAX_AUDIO_CLIP_COUNT)
+  }
+  messages.addAll(audioMessages)
+
+  if (text.isNotEmpty()) {
+    messages.add(ChatMessageText(content = text, side = ChatSide.USER))
   }
-  messages.add(ChatMessageText(content = text, side = ChatSide.USER))

   return messages
 }

@@ -121,6 +121,7 @@ private val IMPORT_CONFIGS_LLM: List<Config> =
       valueType = ValueType.FLOAT,
     ),
     BooleanSwitchConfig(key = ConfigKey.SUPPORT_IMAGE, defaultValue = false),
+    BooleanSwitchConfig(key = ConfigKey.SUPPORT_AUDIO, defaultValue = false),
     SegmentedButtonConfig(
       key = ConfigKey.COMPATIBLE_ACCELERATORS,
       defaultValue = Accelerator.CPU.label,
@@ -230,6 +231,12 @@ fun ModelImportDialog(uri: Uri, onDismiss: () -> Unit, onDone: (ImportedModel) -
             valueType = ValueType.BOOLEAN,
           )
             as Boolean
+        val supportAudio =
+          convertValueToTargetType(
+            value = values.get(ConfigKey.SUPPORT_AUDIO.label)!!,
+            valueType = ValueType.BOOLEAN,
+          )
+            as Boolean
         val importedModel: ImportedModel =
           ImportedModel.newBuilder()
             .setFileName(fileName)
@@ -242,6 +249,7 @@ fun ModelImportDialog(uri: Uri, onDismiss: () -> Unit, onDone: (ImportedModel) -
             .setDefaultTopp(defaultTopp)
             .setDefaultTemperature(defaultTemperature)
             .setSupportImage(supportImage)
+            .setSupportAudio(supportAudio)
             .build()
         )
         .build()

@@ -173,7 +173,7 @@ fun SettingsDialog(
           color = MaterialTheme.colorScheme.onSurfaceVariant,
         )
         Text(
-          "Expired at: ${dateFormatter.format(Instant.ofEpochMilli(curHfToken.expiresAtMs))}",
+          "Expires at: ${dateFormatter.format(Instant.ofEpochMilli(curHfToken.expiresAtMs))}",
           style = MaterialTheme.typography.bodyMedium,
           color = MaterialTheme.colorScheme.onSurfaceVariant,
         )

@@ -1,78 +0,0 @@
-/*
- * Copyright 2025 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.ai.edge.gallery.ui.icon
-
-import androidx.compose.ui.graphics.Color
-import androidx.compose.ui.graphics.SolidColor
-import androidx.compose.ui.graphics.vector.ImageVector
-import androidx.compose.ui.graphics.vector.path
-import androidx.compose.ui.unit.dp
-
-val Forum: ImageVector
-  get() {
-    if (_Forum != null) return _Forum!!
-
-    _Forum =
-      ImageVector.Builder(
-          name = "Forum",
-          defaultWidth = 24.dp,
-          defaultHeight = 24.dp,
-          viewportWidth = 960f,
-          viewportHeight = 960f,
-        )
-        .apply {
-          path(fill = SolidColor(Color(0xFF000000))) {
-            moveTo(280f, 720f)
-            quadToRelative(-17f, 0f, -28.5f, -11.5f)
-            reflectiveQuadTo(240f, 680f)
-            verticalLineToRelative(-80f)
-            horizontalLineToRelative(520f)
-            verticalLineToRelative(-360f)
-            horizontalLineToRelative(80f)
-            quadToRelative(17f, 0f, 28.5f, 11.5f)
-            reflectiveQuadTo(880f, 280f)
-            verticalLineToRelative(600f)
-            lineTo(720f, 720f)
-            close()
-            moveTo(80f, 680f)
-            verticalLineToRelative(-560f)
-            quadToRelative(0f, -17f, 11.5f, -28.5f)
-            reflectiveQuadTo(120f, 80f)
-            horizontalLineToRelative(520f)
-            quadToRelative(17f, 0f, 28.5f, 11.5f)
-            reflectiveQuadTo(680f, 120f)
-            verticalLineToRelative(360f)
-            quadToRelative(0f, 17f, -11.5f, 28.5f)
-            reflectiveQuadTo(640f, 520f)
-            horizontalLineTo(240f)
-            close()
-            moveToRelative(520f, -240f)
-            verticalLineToRelative(-280f)
-            horizontalLineTo(160f)
-            verticalLineToRelative(280f)
-            close()
-            moveToRelative(-440f, 0f)
-            verticalLineToRelative(-280f)
-            close()
-          }
-        }
-        .build()
-
-    return _Forum!!
-  }
-
-private var _Forum: ImageVector? = null

@@ -1,73 +0,0 @@
-/*
- * Copyright 2025 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.ai.edge.gallery.ui.icon
-
-import androidx.compose.ui.graphics.Color
-import androidx.compose.ui.graphics.SolidColor
-import androidx.compose.ui.graphics.vector.ImageVector
-import androidx.compose.ui.graphics.vector.path
-import androidx.compose.ui.unit.dp
-
-val Mms: ImageVector
-  get() {
-    if (_Mms != null) return _Mms!!
-
-    _Mms =
-      ImageVector.Builder(
-          name = "Mms",
-          defaultWidth = 24.dp,
-          defaultHeight = 24.dp,
-          viewportWidth = 960f,
-          viewportHeight = 960f,
-        )
-        .apply {
-          path(fill = SolidColor(Color(0xFF000000))) {
-            moveTo(240f, 560f)
-            horizontalLineToRelative(480f)
-            lineTo(570f, 360f)
-            lineTo(450f, 520f)
-            lineToRelative(-90f, -120f)
-            close()
-            moveTo(80f, 880f)
-            verticalLineToRelative(-720f)
-            quadToRelative(0f, -33f, 23.5f, -56.5f)
-            reflectiveQuadTo(160f, 80f)
-            horizontalLineToRelative(640f)
-            quadToRelative(33f, 0f, 56.5f, 23.5f)
-            reflectiveQuadTo(880f, 160f)
-            verticalLineToRelative(480f)
-            quadToRelative(0f, 33f, -23.5f, 56.5f)
-            reflectiveQuadTo(800f, 720f)
-            horizontalLineTo(240f)
-            close()
-            moveToRelative(126f, -240f)
-            horizontalLineToRelative(594f)
-            verticalLineToRelative(-480f)
-            horizontalLineTo(160f)
-            verticalLineToRelative(525f)
-            close()
-            moveToRelative(-46f, 0f)
-            verticalLineToRelative(-480f)
-            close()
-          }
-        }
-        .build()
-
-    return _Mms!!
-  }
-
-private var _Mms: ImageVector? = null

@@ -1,87 +0,0 @@
-/*
- * Copyright 2025 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.google.ai.edge.gallery.ui.icon
-
-import androidx.compose.ui.graphics.Color
-import androidx.compose.ui.graphics.SolidColor
-import androidx.compose.ui.graphics.vector.ImageVector
-import androidx.compose.ui.graphics.vector.path
-import androidx.compose.ui.unit.dp
-
-val Widgets: ImageVector
-  get() {
-    if (_Widgets != null) return _Widgets!!
-
-    _Widgets =
-      ImageVector.Builder(
-          name = "Widgets",
-          defaultWidth = 24.dp,
-          defaultHeight = 24.dp,
-          viewportWidth = 960f,
-          viewportHeight = 960f,
-        )
-        .apply {
-          path(fill = SolidColor(Color(0xFF000000))) {
-            moveTo(666f, 520f)
-            lineTo(440f, 294f)
-            lineToRelative(226f, -226f)
-            lineToRelative(226f, 226f)
-            close()
-            moveToRelative(-546f, -80f)
-            verticalLineToRelative(-320f)
-            horizontalLineToRelative(320f)
-            verticalLineToRelative(320f)
-            close()
-            moveToRelative(400f, 400f)
-            verticalLineToRelative(-320f)
-            horizontalLineToRelative(320f)
-            verticalLineToRelative(320f)
-            close()
-            moveToRelative(-400f, 0f)
-            verticalLineToRelative(-320f)
-            horizontalLineToRelative(320f)
-            verticalLineToRelative(320f)
-            close()
-            moveToRelative(80f, -480f)
-            horizontalLineToRelative(160f)
-            verticalLineToRelative(-160f)
-            horizontalLineTo(200f)
-            close()
-            moveToRelative(467f, 48f)
-            lineToRelative(113f, -113f)
-            lineToRelative(-113f, -113f)
-            lineToRelative(-113f, 113f)
-            close()
-            moveToRelative(-67f, 352f)
-            horizontalLineToRelative(160f)
-            verticalLineToRelative(-160f)
-            horizontalLineTo(600f)
-            close()
-            moveToRelative(-400f, 0f)
-            horizontalLineToRelative(160f)
-            verticalLineToRelative(-160f)
-            horizontalLineTo(200f)
-            close()
-            moveToRelative(400f, -160f)
-          }
-        }
-        .build()
-
-    return _Widgets!!
-  }
-
-private var _Widgets: ImageVector? = null

@@ -62,13 +62,13 @@ object LlmChatModelHelper {
         Accelerator.GPU.label -> LlmInference.Backend.GPU
         else -> LlmInference.Backend.GPU
       }
-    val options =
+    val optionsBuilder =
       LlmInference.LlmInferenceOptions.builder()
         .setModelPath(model.getPath(context = context))
         .setMaxTokens(maxTokens)
         .setPreferredBackend(preferredBackend)
        .setMaxNumImages(if (model.llmSupportImage) MAX_IMAGE_COUNT else 0)
-        .build()
+    val options = optionsBuilder.build()

     // Create an instance of the LLM Inference task and session.
     try {
@@ -82,7 +82,9 @@ object LlmChatModelHelper {
             .setTopP(topP)
             .setTemperature(temperature)
             .setGraphOptions(
-              GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
+              GraphOptions.builder()
+                .setEnableVisionModality(model.llmSupportImage)
+                .build()
             )
             .build(),
         )
@@ -115,7 +117,9 @@ object LlmChatModelHelper {
             .setTopP(topP)
             .setTemperature(temperature)
             .setGraphOptions(
-              GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
+              GraphOptions.builder()
+                .setEnableVisionModality(model.llmSupportImage)
+                .build()
             )
             .build(),
         )
@@ -159,6 +163,7 @@ object LlmChatModelHelper {
     resultListener: ResultListener,
     cleanUpListener: CleanUpListener,
     images: List<Bitmap> = listOf(),
+    audioClips: List<ByteArray> = listOf(),
   ) {
     val instance = model.instance as LlmModelInstance

@@ -172,10 +177,16 @@ object LlmChatModelHelper {
     // For a model that supports image modality, we need to add the text query chunk before adding
     // image.
     val session = instance.session
-    session.addQueryChunk(input)
+    if (input.trim().isNotEmpty()) {
+      session.addQueryChunk(input)
+    }
     for (image in images) {
       session.addImage(BitmapImageBuilder(image).build())
     }
+    for (audioClip in audioClips) {
+      // Uncomment when audio is supported.
+      // session.addAudio(audioClip)
+    }
     val unused = session.generateResponseAsync(resultListener)
   }
 }

@@ -22,6 +22,7 @@ import androidx.compose.ui.Modifier
 import androidx.compose.ui.platform.LocalContext
 import androidx.lifecycle.viewmodel.compose.viewModel
 import com.google.ai.edge.gallery.ui.ViewModelProvider
+import com.google.ai.edge.gallery.ui.common.chat.ChatMessageAudioClip
 import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage
 import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
 import com.google.ai.edge.gallery.ui.common.chat.ChatView
@@ -36,6 +37,10 @@ object LlmAskImageDestination {
   val route = "LlmAskImageRoute"
 }

+object LlmAskAudioDestination {
+  val route = "LlmAskAudioRoute"
+}
+
 @Composable
 fun LlmChatScreen(
   modelManagerViewModel: ModelManagerViewModel,
@@ -66,6 +71,21 @@ fun LlmAskImageScreen(
   )
 }

+@Composable
+fun LlmAskAudioScreen(
+  modelManagerViewModel: ModelManagerViewModel,
+  navigateUp: () -> Unit,
+  modifier: Modifier = Modifier,
+  viewModel: LlmAskAudioViewModel = viewModel(factory = ViewModelProvider.Factory),
+) {
+  ChatViewWrapper(
+    viewModel = viewModel,
+    modelManagerViewModel = modelManagerViewModel,
+    navigateUp = navigateUp,
+    modifier = modifier,
+  )
+}
+
 @Composable
 fun ChatViewWrapper(
   viewModel: LlmChatViewModel,
@@ -86,6 +106,7 @@ fun ChatViewWrapper(

       var text = ""
       val images: MutableList<Bitmap> = mutableListOf()
+      val audioMessages: MutableList<ChatMessageAudioClip> = mutableListOf()
       var chatMessageText: ChatMessageText? = null
       for (message in messages) {
         if (message is ChatMessageText) {
@@ -93,14 +114,17 @@ fun ChatViewWrapper(
           text = message.content
         } else if (message is ChatMessageImage) {
           images.add(message.bitmap)
+        } else if (message is ChatMessageAudioClip) {
+          audioMessages.add(message)
         }
       }
-      if (text.isNotEmpty() && chatMessageText != null) {
+      if ((text.isNotEmpty() && chatMessageText != null) || audioMessages.isNotEmpty()) {
         modelManagerViewModel.addTextInputHistory(text)
         viewModel.generateResponse(
           model = model,
           input = text,
           images = images,
+          audioMessages = audioMessages,
           onError = {
             viewModel.handleError(
               context = context,

@@ -22,9 +22,11 @@ import android.util.Log
 import androidx.lifecycle.viewModelScope
 import com.google.ai.edge.gallery.data.ConfigKey
 import com.google.ai.edge.gallery.data.Model
+import com.google.ai.edge.gallery.data.TASK_LLM_ASK_AUDIO
 import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
 import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
 import com.google.ai.edge.gallery.data.Task
+import com.google.ai.edge.gallery.ui.common.chat.ChatMessageAudioClip
 import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
 import com.google.ai.edge.gallery.ui.common.chat.ChatMessageLoading
 import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
@@ -52,6 +54,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
     model: Model,
     input: String,
     images: List<Bitmap> = listOf(),
+    audioMessages: List<ChatMessageAudioClip> = listOf(),
     onError: () -> Unit,
   ) {
     val accelerator = model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = "")
@@ -72,6 +75,11 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
       val instance = model.instance as LlmModelInstance
       var prefillTokens = instance.session.sizeInTokens(input)
       prefillTokens += images.size * 257
+      for (audioMessages in audioMessages) {
+        // 150ms = 1 audio token
+        val duration = audioMessages.getDurationInSeconds()
+        prefillTokens += (duration * 1000f / 150f).toInt()
+      }

       var firstRun = true
       var timeToFirstToken = 0f
@@ -86,6 +94,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
         model = model,
         input = input,
         images = images,
+        audioClips = audioMessages.map { it.genByteArrayForWav() },
         resultListener = { partialResult, done ->
           val curTs = System.currentTimeMillis()
@@ -214,7 +223,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
     context: Context,
     model: Model,
     modelManagerViewModel: ModelManagerViewModel,
-    triggeredMessage: ChatMessageText,
+    triggeredMessage: ChatMessageText?,
   ) {
     // Clean up.
     modelManagerViewModel.cleanupModel(task = task, model = model)
@@ -236,14 +245,20 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
       )

       // Add the triggered message back.
-      addMessage(model = model, message = triggeredMessage)
+      if (triggeredMessage != null) {
+        addMessage(model = model, message = triggeredMessage)
+      }

       // Re-initialize the session/engine.
       modelManagerViewModel.initializeModel(context = context, task = task, model = model)

       // Re-generate the response automatically.
-      generateResponse(model = model, input = triggeredMessage.content, onError = {})
+      if (triggeredMessage != null) {
+        generateResponse(model = model, input = triggeredMessage.content, onError = {})
+      }
     }
   }

 class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE)
+
+class LlmAskAudioViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_AUDIO)

@@ -36,6 +36,7 @@ import com.google.ai.edge.gallery.data.ModelAllowlist
 import com.google.ai.edge.gallery.data.ModelDownloadStatus
 import com.google.ai.edge.gallery.data.ModelDownloadStatusType
 import com.google.ai.edge.gallery.data.TASKS
+import com.google.ai.edge.gallery.data.TASK_LLM_ASK_AUDIO
 import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
 import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
 import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB
@@ -281,15 +282,12 @@ open class ModelManagerViewModel(
       }
     }
     when (task.type) {
-      TaskType.LLM_CHAT ->
-        LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone)
+      TaskType.LLM_CHAT,
+      TaskType.LLM_ASK_IMAGE,
+      TaskType.LLM_ASK_AUDIO,
       TaskType.LLM_PROMPT_LAB ->
         LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone)

-      TaskType.LLM_ASK_IMAGE ->
-        LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone)
-
       TaskType.TEST_TASK_1 -> {}
       TaskType.TEST_TASK_2 -> {}
     }
@@ -301,9 +299,11 @@ open class ModelManagerViewModel(
     model.cleanUpAfterInit = false
     Log.d(TAG, "Cleaning up model '${model.name}'...")
     when (task.type) {
-      TaskType.LLM_CHAT -> LlmChatModelHelper.cleanUp(model = model)
-      TaskType.LLM_PROMPT_LAB -> LlmChatModelHelper.cleanUp(model = model)
-      TaskType.LLM_ASK_IMAGE -> LlmChatModelHelper.cleanUp(model = model)
+      TaskType.LLM_CHAT,
+      TaskType.LLM_PROMPT_LAB,
+      TaskType.LLM_ASK_IMAGE,
+      TaskType.LLM_ASK_AUDIO -> LlmChatModelHelper.cleanUp(model = model)
+
       TaskType.TEST_TASK_1 -> {}
       TaskType.TEST_TASK_2 -> {}
     }
@@ -410,14 +410,19 @@ open class ModelManagerViewModel(
     // Create model.
     val model = createModelFromImportedModelInfo(info = info)

-    for (task in listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)) {
+    for (task in
+      listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_ASK_AUDIO, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)) {
       // Remove duplicated imported model if existed.
       val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported }
       if (modelIndex >= 0) {
         Log.d(TAG, "duplicated imported model found in task. Removing it first")
         task.models.removeAt(modelIndex)
       }
-      if ((task == TASK_LLM_ASK_IMAGE && model.llmSupportImage) || task != TASK_LLM_ASK_IMAGE) {
+      if (
+        (task == TASK_LLM_ASK_IMAGE && model.llmSupportImage) ||
+          (task == TASK_LLM_ASK_AUDIO && model.llmSupportAudio) ||
+          (task != TASK_LLM_ASK_IMAGE && task != TASK_LLM_ASK_AUDIO)
+      ) {
         task.models.add(model)
       }
       task.updateTrigger.value = System.currentTimeMillis()
@@ -657,6 +662,7 @@ open class ModelManagerViewModel(
     TASK_LLM_CHAT.models.clear()
     TASK_LLM_PROMPT_LAB.models.clear()
     TASK_LLM_ASK_IMAGE.models.clear()
+    TASK_LLM_ASK_AUDIO.models.clear()
     for (allowedModel in modelAllowlist.models) {
       if (allowedModel.disabled == true) {
         continue
@@ -672,6 +678,9 @@ open class ModelManagerViewModel(
       if (allowedModel.taskTypes.contains(TASK_LLM_ASK_IMAGE.type.id)) {
         TASK_LLM_ASK_IMAGE.models.add(model)
       }
+      if (allowedModel.taskTypes.contains(TASK_LLM_ASK_AUDIO.type.id)) {
+        TASK_LLM_ASK_AUDIO.models.add(model)
+      }
     }

     // Pre-process all tasks.
@@ -760,6 +769,9 @@ open class ModelManagerViewModel(
       if (model.llmSupportImage) {
         TASK_LLM_ASK_IMAGE.models.add(model)
       }
+      if (model.llmSupportAudio) {
+        TASK_LLM_ASK_AUDIO.models.add(model)
+      }

       // Update status.
       modelDownloadStatus[model.name] =
@@ -800,6 +812,7 @@ open class ModelManagerViewModel(
       accelerators = accelerators,
     )
     val llmSupportImage = info.llmConfig.supportImage
+    val llmSupportAudio = info.llmConfig.supportAudio
     val model =
       Model(
         name = info.fileName,
@@ -811,6 +824,7 @@ open class ModelManagerViewModel(
         showRunAgainButton = false,
         imported = true,
         llmSupportImage = llmSupportImage,
+        llmSupportAudio = llmSupportAudio,
       )
     model.preProcess()

@@ -47,6 +47,7 @@ import androidx.navigation.compose.NavHost
 import androidx.navigation.compose.composable
 import androidx.navigation.navArgument
 import com.google.ai.edge.gallery.data.Model
+import com.google.ai.edge.gallery.data.TASK_LLM_ASK_AUDIO
 import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
 import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
 import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB
@@ -55,6 +56,8 @@ import com.google.ai.edge.gallery.data.TaskType
 import com.google.ai.edge.gallery.data.getModelByName
 import com.google.ai.edge.gallery.ui.ViewModelProvider
 import com.google.ai.edge.gallery.ui.home.HomeScreen
+import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioDestination
+import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioScreen
 import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageDestination
 import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageScreen
 import com.google.ai.edge.gallery.ui.llmchat.LlmChatDestination
@@ -209,7 +212,7 @@ fun GalleryNavHost(
       }
     }

-    // LLM image to text.
+    // Ask image.
     composable(
       route = "${LlmAskImageDestination.route}/{modelName}",
       arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
@@ -225,6 +228,23 @@ fun GalleryNavHost(
         )
       }
     }
+
+    // Ask audio.
+    composable(
+      route = "${LlmAskAudioDestination.route}/{modelName}",
+      arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
+      enterTransition = { slideEnter() },
+      exitTransition = { slideExit() },
+    ) {
+      getModelFromNavigationParam(it, TASK_LLM_ASK_AUDIO)?.let { defaultModel ->
+        modelManagerViewModel.selectModel(defaultModel)
+
+        LlmAskAudioScreen(
+          modelManagerViewModel = modelManagerViewModel,
+          navigateUp = { navController.navigateUp() },
+        )
+      }
+    }
   }

   // Handle incoming intents for deep links
@@ -256,6 +276,7 @@ fun navigateToTaskScreen(
   when (taskType) {
     TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}")
     TaskType.LLM_ASK_IMAGE -> navController.navigate("${LlmAskImageDestination.route}/${modelName}")
+    TaskType.LLM_ASK_AUDIO -> navController.navigate("${LlmAskAudioDestination.route}/${modelName}")
     TaskType.LLM_PROMPT_LAB ->
       navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
     TaskType.TEST_TASK_1 -> {}

@@ -120,6 +120,8 @@ data class CustomColors(
   val agentBubbleBgColor: Color = Color.Transparent,
   val linkColor: Color = Color.Transparent,
   val successColor: Color = Color.Transparent,
+  val recordButtonBgColor: Color = Color.Transparent,
+  val waveFormBgColor: Color = Color.Transparent,
 )

 val LocalCustomColors = staticCompositionLocalOf { CustomColors() }
@@ -145,6 +147,8 @@ val lightCustomColors =
     userBubbleBgColor = Color(0xFF32628D),
     linkColor = Color(0xFF32628D),
     successColor = Color(0xff3d860b),
+    recordButtonBgColor = Color(0xFFEE675C),
+    waveFormBgColor = Color(0xFFaaaaaa),
   )

 val darkCustomColors =
@@ -168,6 +172,8 @@ val darkCustomColors =
     userBubbleBgColor = Color(0xFF1f3760),
     linkColor = Color(0xFF9DCAFC),
     successColor = Color(0xFFA1CE83),
+    recordButtonBgColor = Color(0xFFEE675C),
+    waveFormBgColor = Color(0xFFaaaaaa),
   )

 val MaterialTheme.customColors: CustomColors

@@ -55,6 +55,7 @@ message LlmConfig {
   float default_topp = 4;
   float default_temperature = 5;
   bool support_image = 6;
+  bool support_audio = 7;
 }

 message Settings {