From 0a492199158dc36aff9da566f6effdab4c7df4f2 Mon Sep 17 00:00:00 2001 From: Google AI Edge Gallery Date: Mon, 16 Jun 2025 13:52:07 -0700 Subject: [PATCH 1/9] Update the memory warning trigger condition. PiperOrigin-RevId: 772158962 --- .../ai/edge/gallery/ui/common/DownloadAndTryButton.kt | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/DownloadAndTryButton.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/DownloadAndTryButton.kt index 8511c6e..039ed6d 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/DownloadAndTryButton.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/DownloadAndTryButton.kt @@ -64,6 +64,7 @@ import kotlinx.coroutines.launch import kotlinx.coroutines.withContext private const val TAG = "AGDownloadAndTryButton" +private const val SYSTEM_RESERVED_MEMORY_IN_BYTES = 3 * (1L shl 30) // TODO: // - replace the download button in chat view page with this one, and add a flag to not "onclick" @@ -290,7 +291,6 @@ fun DownloadAndTryButton( val activityManager = context.getSystemService(android.app.Activity.ACTIVITY_SERVICE) as? ActivityManager val estimatedPeakMemoryInBytes = model.estimatedPeakMemoryInBytes - val isMemoryLow = if (activityManager != null && estimatedPeakMemoryInBytes != null) { val memoryInfo = ActivityManager.MemoryInfo() @@ -302,11 +302,10 @@ fun DownloadAndTryButton( // The device should be able to run the model if `availMem` is larger than the // estimated peak memory. Android also has a mechanism to kill background apps to - // free up memory for the foreground app. We believe that if half of the total - // memory on the device is larger than the estimated peak memory, it can run the - // model fine with this mechanism. For example, a phone with 12GB memory can have - // very few `availMem` but will have no problem running most models. 
- max(memoryInfo.availMem, memoryInfo.totalMem / 2) < estimatedPeakMemoryInBytes + // free up memory for the foreground app. Reserving 3G for system buffer memory to + // avoid the app being killed by the system. + max(memoryInfo.availMem, memoryInfo.totalMem - SYSTEM_RESERVED_MEMORY_IN_BYTES) < + estimatedPeakMemoryInBytes } else { false } From 74a013c2e2d09bbcc2a2a55b0b0e590be31b6b44 Mon Sep 17 00:00:00 2001 From: Chunlei Niu Date: Tue, 17 Jun 2025 09:36:57 -0700 Subject: [PATCH 2/9] Remove OpenCL native lib from Android manifest, which is already provided in GenAI tasks dependency. PiperOrigin-RevId: 772508137 --- Android/src/app/src/main/AndroidManifest.xml | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/Android/src/app/src/main/AndroidManifest.xml b/Android/src/app/src/main/AndroidManifest.xml index a87aa98..2f04e88 100644 --- a/Android/src/app/src/main/AndroidManifest.xml +++ b/Android/src/app/src/main/AndroidManifest.xml @@ -70,17 +70,6 @@ - - - - - Date: Wed, 18 Jun 2025 10:55:49 -0700 Subject: [PATCH 3/9] Update bug_report.md --- .github/ISSUE_TEMPLATE/bug_report.md | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index a7d15c7..9e38ca2 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -12,6 +12,10 @@ A clear and concise description of what the bug is. **To Reproduce:** Steps to reproduce the behavior: +1. Go to '...' +2. Click on '....' +3. Scroll down to '....' +4. See error **Expected behavior:** A clear and concise description of what you expected to happen. @@ -19,16 +23,10 @@ A clear and concise description of what you expected to happen. **Screenshots:** If applicable, add screenshots to help explain your problem. -**Desktop (please complete the following information):** - - OS: [e.g. iOS] - - Browser [e.g. chrome, safari] - - Version [e.g. 
22] - -**Smartphone (please complete the following information):** - - Device: [e.g. iPhone6] - - OS: [e.g. iOS8.1] - - Browser [e.g. stock browser, safari] - - Version [e.g. 22] +**Device & App Information (Please complete the following):** +- Device: [e.g., Samsung Galaxy S23, Google Pixel 7] +- Android Version: [e.g., Android 12, Android 13] +- App Version: [e.g., 1.0.1, v1.0.2] **Additional context:** Add any other context about the problem here. From d0989adce15d73e6d3221a9d7ed54fdf352c074c Mon Sep 17 00:00:00 2001 From: Google AI Edge Gallery Date: Mon, 23 Jun 2025 10:24:10 -0700 Subject: [PATCH 4/9] Add audio support. - Add a new task "audio scribe". - Allow users to record audio clips or pick wav files to interact with model. - Add support for importing models with audio capability. - Fix a typo in Settings dialog (Thanks https://github.com/rhnvrm!) PiperOrigin-RevId: 774832681 --- Android/src/app/src/main/AndroidManifest.xml | 1 + .../google/ai/edge/gallery/common/Types.kt | 2 + .../google/ai/edge/gallery/common/Utils.kt | 137 +++++++ .../com/google/ai/edge/gallery/data/Config.kt | 1 + .../com/google/ai/edge/gallery/data/Consts.kt | 9 + .../com/google/ai/edge/gallery/data/Model.kt | 3 + .../ai/edge/gallery/data/ModelAllowlist.kt | 2 + .../com/google/ai/edge/gallery/data/Tasks.kt | 32 +- .../ai/edge/gallery/ui/ViewModelProvider.kt | 4 + .../ui/common/chat/AudioPlaybackPanel.kt | 344 ++++++++++++++++++ .../ui/common/chat/AudioRecorderPanel.kt | 327 +++++++++++++++++ .../gallery/ui/common/chat/ChatMessage.kt | 88 +++++ .../edge/gallery/ui/common/chat/ChatPanel.kt | 38 +- .../ui/common/chat/MessageBodyAudioClip.kt | 33 ++ .../ui/common/chat/MessageInputText.kt | 254 ++++++++++++- .../edge/gallery/ui/home/ModelImportDialog.kt | 8 + .../ai/edge/gallery/ui/home/SettingsDialog.kt | 2 +- .../google/ai/edge/gallery/ui/icon/Forum.kt | 78 ---- .../com/google/ai/edge/gallery/ui/icon/Mms.kt | 73 ---- .../ai/edge/gallery/ui/icon/Widgets.kt.kt | 87 ----- 
.../gallery/ui/llmchat/LlmChatModelHelper.kt | 21 +- .../edge/gallery/ui/llmchat/LlmChatScreen.kt | 26 +- .../gallery/ui/llmchat/LlmChatViewModel.kt | 21 +- .../ui/modelmanager/ModelManagerViewModel.kt | 36 +- .../gallery/ui/navigation/GalleryNavGraph.kt | 23 +- .../google/ai/edge/gallery/ui/theme/Theme.kt | 6 + Android/src/app/src/main/proto/settings.proto | 1 + 27 files changed, 1369 insertions(+), 288 deletions(-) create mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioPlaybackPanel.kt create mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioRecorderPanel.kt create mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageBodyAudioClip.kt delete mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Forum.kt delete mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Mms.kt delete mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Widgets.kt.kt diff --git a/Android/src/app/src/main/AndroidManifest.xml b/Android/src/app/src/main/AndroidManifest.xml index 2f04e88..6bacbd3 100644 --- a/Android/src/app/src/main/AndroidManifest.xml +++ b/Android/src/app/src/main/AndroidManifest.xml @@ -29,6 +29,7 @@ + (val jsonObj: T, val textContent: String) + +class AudioClip(val audioData: ByteArray, val sampleRate: Int) diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt index 6d37f9d..1ada15d 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt @@ -17,12 +17,17 @@ package com.google.ai.edge.gallery.common import android.content.Context +import android.net.Uri import android.util.Log +import com.google.ai.edge.gallery.data.SAMPLE_RATE import com.google.gson.Gson import 
com.google.gson.reflect.TypeToken import java.io.File import java.net.HttpURLConnection import java.net.URL +import java.nio.ByteBuffer +import java.nio.ByteOrder +import kotlin.math.floor data class LaunchInfo(val ts: Long) @@ -112,3 +117,135 @@ inline fun getJsonResponse(url: String): JsonObjAndTextContent? { return null } + +fun convertWavToMonoWithMaxSeconds( + context: Context, + stereoUri: Uri, + maxSeconds: Int = 30, +): AudioClip? { + Log.d(TAG, "Start to convert wav file to mono channel") + + try { + val inputStream = context.contentResolver.openInputStream(stereoUri) ?: return null + val originalBytes = inputStream.readBytes() + inputStream.close() + + // Read WAV header + if (originalBytes.size < 44) { + // Not a valid WAV file + Log.e(TAG, "Not a valid wav file") + return null + } + + val headerBuffer = ByteBuffer.wrap(originalBytes, 0, 44).order(ByteOrder.LITTLE_ENDIAN) + val channels = headerBuffer.getShort(22) + var sampleRate = headerBuffer.getInt(24) + val bitDepth = headerBuffer.getShort(34) + Log.d(TAG, "File metadata: channels: $channels, sampleRate: $sampleRate, bitDepth: $bitDepth") + + // Normalize audio to 16-bit. 
+ val audioDataBytes = originalBytes.copyOfRange(fromIndex = 44, toIndex = originalBytes.size) + var sixteenBitBytes: ByteArray = + if (bitDepth.toInt() == 8) { + Log.d(TAG, "Converting 8-bit audio to 16-bit.") + convert8BitTo16Bit(audioDataBytes) + } else { + // Assume 16-bit or other format that can be handled directly + audioDataBytes + } + + // Convert byte array to short array for processing + val shortBuffer = + ByteBuffer.wrap(sixteenBitBytes).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer() + var pcmSamples = ShortArray(shortBuffer.remaining()) + shortBuffer.get(pcmSamples) + + // Resample if sample rate is less than 16000 Hz --- + if (sampleRate < SAMPLE_RATE) { + Log.d(TAG, "Resampling from $sampleRate Hz to $SAMPLE_RATE Hz.") + pcmSamples = resample(pcmSamples, sampleRate, SAMPLE_RATE, channels.toInt()) + sampleRate = SAMPLE_RATE + Log.d(TAG, "Resampling complete. New sample count: ${pcmSamples.size}") + } + + // Convert stereo to mono if necessary + var monoSamples = + if (channels.toInt() == 2) { + Log.d(TAG, "Converting stereo to mono.") + val mono = ShortArray(pcmSamples.size / 2) + for (i in mono.indices) { + val left = pcmSamples[i * 2] + val right = pcmSamples[i * 2 + 1] + mono[i] = ((left + right) / 2).toShort() + } + mono + } else { + Log.d(TAG, "Audio is already mono. 
No channel conversion needed.") + pcmSamples + } + + // Trim the audio to maxSeconds --- + val maxSamples = maxSeconds * sampleRate + if (monoSamples.size > maxSamples) { + Log.d(TAG, "Trimming clip from ${monoSamples.size} samples to $maxSamples samples.") + monoSamples = monoSamples.copyOfRange(0, maxSamples) + } + + val monoByteBuffer = ByteBuffer.allocate(monoSamples.size * 2).order(ByteOrder.LITTLE_ENDIAN) + monoByteBuffer.asShortBuffer().put(monoSamples) + return AudioClip(audioData = monoByteBuffer.array(), sampleRate = sampleRate) + } catch (e: Exception) { + Log.e(TAG, "Failed to convert wav to mono", e) + return null + } +} + +/** Converts 8-bit unsigned PCM audio data to 16-bit signed PCM. */ +private fun convert8BitTo16Bit(eightBitData: ByteArray): ByteArray { + // The new 16-bit data will be twice the size + val sixteenBitData = ByteArray(eightBitData.size * 2) + val buffer = ByteBuffer.wrap(sixteenBitData).order(ByteOrder.LITTLE_ENDIAN) + + for (byte in eightBitData) { + // Convert the unsigned 8-bit byte (0-255) to a signed 16-bit short (-32768 to 32767) + // 1. Get the unsigned value by masking with 0xFF + // 2. Subtract 128 to center the waveform around 0 (range becomes -128 to 127) + // 3. Scale by 256 to expand to the 16-bit range + val unsignedByte = byte.toInt() and 0xFF + val sixteenBitSample = ((unsignedByte - 128) * 256).toShort() + buffer.putShort(sixteenBitSample) + } + return sixteenBitData +} + +/** Resamples PCM audio data from an original sample rate to a target sample rate. 
*/ +private fun resample( + inputSamples: ShortArray, + originalSampleRate: Int, + targetSampleRate: Int, + channels: Int, +): ShortArray { + if (originalSampleRate == targetSampleRate) { + return inputSamples + } + + val ratio = targetSampleRate.toDouble() / originalSampleRate + val outputLength = (inputSamples.size * ratio).toInt() + val resampledData = ShortArray(outputLength) + + if (channels == 1) { // Mono + for (i in resampledData.indices) { + val position = i / ratio + val index1 = floor(position).toInt() + val index2 = index1 + 1 + val fraction = position - index1 + + val sample1 = if (index1 < inputSamples.size) inputSamples[index1].toDouble() else 0.0 + val sample2 = if (index2 < inputSamples.size) inputSamples[index2].toDouble() else 0.0 + + resampledData[i] = (sample1 * (1 - fraction) + sample2 * fraction).toInt().toShort() + } + } + + return resampledData +} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt index d961ecc..2b9fdf7 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt @@ -50,6 +50,7 @@ enum class ConfigKey(val label: String) { DEFAULT_TOPP("Default TopP"), DEFAULT_TEMPERATURE("Default temperature"), SUPPORT_IMAGE("Support image"), + SUPPORT_AUDIO("Support audio"), MAX_RESULT_COUNT("Max result count"), USE_GPU("Use GPU"), ACCELERATOR("Choose accelerator"), diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Consts.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Consts.kt index a2209af..85fd71e 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Consts.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Consts.kt @@ -44,3 +44,12 @@ val DEFAULT_ACCELERATORS = listOf(Accelerator.GPU) // Max number of images allowed in a "ask image" session. 
const val MAX_IMAGE_COUNT = 10 + +// Max number of audio clip in an "ask audio" session. +const val MAX_AUDIO_CLIP_COUNT = 10 + +// Max audio clip duration in seconds. +const val MAX_AUDIO_CLIP_DURATION_SEC = 30 + +// Audio-recording related consts. +const val SAMPLE_RATE = 16000 diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt index 580bf2b..02e1368 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt @@ -87,6 +87,9 @@ data class Model( /** Whether the LLM model supports image input. */ val llmSupportImage: Boolean = false, + /** Whether the LLM model supports audio input. */ + val llmSupportAudio: Boolean = false, + /** Whether the model is imported or not. */ val imported: Boolean = false, diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt index f336638..5cf11ca 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt @@ -38,6 +38,7 @@ data class AllowedModel( val taskTypes: List, val disabled: Boolean? = null, val llmSupportImage: Boolean? = null, + val llmSupportAudio: Boolean? = null, val estimatedPeakMemoryInBytes: Long? 
= null, ) { fun toModel(): Model { @@ -96,6 +97,7 @@ data class AllowedModel( showRunAgainButton = showRunAgainButton, learnMoreUrl = "https://huggingface.co/${modelId}", llmSupportImage = llmSupportImage == true, + llmSupportAudio = llmSupportAudio == true, ) } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Tasks.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Tasks.kt index e95feab..c52d2c3 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Tasks.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Tasks.kt @@ -17,19 +17,22 @@ package com.google.ai.edge.gallery.data import androidx.annotation.StringRes +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.outlined.Forum +import androidx.compose.material.icons.outlined.Mic +import androidx.compose.material.icons.outlined.Mms +import androidx.compose.material.icons.outlined.Widgets import androidx.compose.runtime.MutableState import androidx.compose.runtime.mutableLongStateOf import androidx.compose.ui.graphics.vector.ImageVector import com.google.ai.edge.gallery.R -import com.google.ai.edge.gallery.ui.icon.Forum -import com.google.ai.edge.gallery.ui.icon.Mms -import com.google.ai.edge.gallery.ui.icon.Widgets /** Type of task. 
*/ enum class TaskType(val label: String, val id: String) { LLM_CHAT(label = "AI Chat", id = "llm_chat"), LLM_PROMPT_LAB(label = "Prompt Lab", id = "llm_prompt_lab"), LLM_ASK_IMAGE(label = "Ask Image", id = "llm_ask_image"), + LLM_ASK_AUDIO(label = "Audio Scribe", id = "llm_ask_audio"), TEST_TASK_1(label = "Test task 1", id = "test_task_1"), TEST_TASK_2(label = "Test task 2", id = "test_task_2"), } @@ -71,7 +74,7 @@ data class Task( val TASK_LLM_CHAT = Task( type = TaskType.LLM_CHAT, - icon = Forum, + icon = Icons.Outlined.Forum, models = mutableListOf(), description = "Chat with on-device large language models", docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", @@ -83,7 +86,7 @@ val TASK_LLM_CHAT = val TASK_LLM_PROMPT_LAB = Task( type = TaskType.LLM_PROMPT_LAB, - icon = Widgets, + icon = Icons.Outlined.Widgets, models = mutableListOf(), description = "Single turn use cases with on-device large language model", docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", @@ -95,7 +98,7 @@ val TASK_LLM_PROMPT_LAB = val TASK_LLM_ASK_IMAGE = Task( type = TaskType.LLM_ASK_IMAGE, - icon = Mms, + icon = Icons.Outlined.Mms, models = mutableListOf(), description = "Ask questions about images with on-device large language models", docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", @@ -104,8 +107,23 @@ val TASK_LLM_ASK_IMAGE = textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat, ) +val TASK_LLM_ASK_AUDIO = + Task( + type = TaskType.LLM_ASK_AUDIO, + icon = Icons.Outlined.Mic, + models = mutableListOf(), + // TODO(do not submit) + description = + "Instantly transcribe and/or translate audio clips using on-device large language models", + docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", + sourceCodeUrl = + 
"https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt", + textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat, + ) + /** All tasks. */ -val TASKS: List = listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT) +val TASKS: List = + listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_ASK_AUDIO, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT) fun getModelByName(name: String): Model? { for (task in TASKS) { diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt index fdded53..6ceb148 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt @@ -21,6 +21,7 @@ import androidx.lifecycle.viewmodel.CreationExtras import androidx.lifecycle.viewmodel.initializer import androidx.lifecycle.viewmodel.viewModelFactory import com.google.ai.edge.gallery.GalleryApplication +import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioViewModel import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageViewModel import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel @@ -49,6 +50,9 @@ object ViewModelProvider { // Initializer for LlmAskImageViewModel. initializer { LlmAskImageViewModel() } + + // Initializer for LlmAskAudioViewModel. 
+ initializer { LlmAskAudioViewModel() } } } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioPlaybackPanel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioPlaybackPanel.kt new file mode 100644 index 0000000..9faebf0 --- /dev/null +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioPlaybackPanel.kt @@ -0,0 +1,344 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.ai.edge.gallery.ui.common.chat + +import android.media.AudioAttributes +import android.media.AudioFormat +import android.media.AudioTrack +import android.util.Log +import androidx.compose.foundation.Canvas +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.height +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.width +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.rounded.PlayArrow +import androidx.compose.material.icons.rounded.Stop +import androidx.compose.material3.Icon +import androidx.compose.material3.IconButton +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Text +import androidx.compose.runtime.Composable +import androidx.compose.runtime.DisposableEffect +import androidx.compose.runtime.LaunchedEffect +import androidx.compose.runtime.MutableState +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableFloatStateOf +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember +import androidx.compose.runtime.rememberCoroutineScope +import androidx.compose.runtime.setValue +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.geometry.CornerRadius +import androidx.compose.ui.geometry.Offset +import androidx.compose.ui.geometry.Size +import androidx.compose.ui.geometry.toRect +import androidx.compose.ui.graphics.BlendMode +import androidx.compose.ui.graphics.Color +import androidx.compose.ui.graphics.drawscope.drawIntoCanvas +import androidx.compose.ui.unit.dp +import com.google.ai.edge.gallery.data.MAX_AUDIO_CLIP_DURATION_SEC +import com.google.ai.edge.gallery.ui.theme.customColors +import java.nio.ByteBuffer +import java.nio.ByteOrder +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.isActive +import kotlinx.coroutines.launch +import 
kotlinx.coroutines.withContext + +private const val TAG = "AGAudioPlaybackPanel" +private const val BAR_SPACE = 2 +private const val BAR_WIDTH = 2 +private const val MIN_BAR_COUNT = 16 +private const val MAX_BAR_COUNT = 48 + +/** + * A Composable that displays an audio playback panel, including play/stop controls, a waveform + * visualization, and the duration of the audio clip. + */ +@Composable +fun AudioPlaybackPanel( + audioData: ByteArray, + sampleRate: Int, + isRecording: Boolean, + modifier: Modifier = Modifier, + onDarkBg: Boolean = false, +) { + val coroutineScope = rememberCoroutineScope() + var isPlaying by remember { mutableStateOf(false) } + val audioTrackState = remember { mutableStateOf(null) } + val durationInSeconds = + remember(audioData) { + // PCM 16-bit + val bytesPerSample = 2 + val bytesPerFrame = bytesPerSample * 1 // mono + val totalFrames = audioData.size.toDouble() / bytesPerFrame + totalFrames / sampleRate + } + val barCount = + remember(durationInSeconds) { + val f = durationInSeconds / MAX_AUDIO_CLIP_DURATION_SEC + ((MAX_BAR_COUNT - MIN_BAR_COUNT) * f + MIN_BAR_COUNT).toInt() + } + val amplitudeLevels = + remember(audioData) { generateAmplitudeLevels(audioData = audioData, barCount = barCount) } + var playbackProgress by remember { mutableFloatStateOf(0f) } + + // Reset when a new recording is started. + LaunchedEffect(isRecording) { + if (isRecording) { + val audioTrack = audioTrackState.value + audioTrack?.stop() + isPlaying = false + playbackProgress = 0f + } + } + + // Cleanup on Composable Disposal. + DisposableEffect(Unit) { + onDispose { + val audioTrack = audioTrackState.value + audioTrack?.stop() + audioTrack?.release() + } + } + + Row(verticalAlignment = Alignment.CenterVertically, modifier = modifier) { + // Button to play/stop the clip. 
+ IconButton( + onClick = { + coroutineScope.launch { + if (!isPlaying) { + isPlaying = true + playAudio( + audioTrackState = audioTrackState, + audioData = audioData, + sampleRate = sampleRate, + onProgress = { playbackProgress = it }, + onCompletion = { + playbackProgress = 0f + isPlaying = false + }, + ) + } else { + stopPlayAudio(audioTrackState = audioTrackState) + playbackProgress = 0f + isPlaying = false + } + } + } + ) { + Icon( + if (isPlaying) Icons.Rounded.Stop else Icons.Rounded.PlayArrow, + contentDescription = "", + tint = if (onDarkBg) Color.White else MaterialTheme.colorScheme.primary, + ) + } + + // Visualization + AmplitudeBarGraph( + amplitudeLevels = amplitudeLevels, + progress = playbackProgress, + modifier = + Modifier.width((barCount * BAR_WIDTH + (barCount - 1) * BAR_SPACE).dp).height(24.dp), + onDarkBg = onDarkBg, + ) + + // Duration + Text( + "${"%.1f".format(durationInSeconds)}s", + style = MaterialTheme.typography.labelLarge, + color = if (onDarkBg) Color.White else MaterialTheme.colorScheme.primary, + modifier = Modifier.padding(start = 12.dp), + ) + } +} + +@Composable +private fun AmplitudeBarGraph( + amplitudeLevels: List, + progress: Float, + modifier: Modifier = Modifier, + onDarkBg: Boolean = false, +) { + val barColor = MaterialTheme.customColors.waveFormBgColor + val progressColor = if (onDarkBg) Color.White else MaterialTheme.colorScheme.primary + + Canvas(modifier = modifier) { + val barCount = amplitudeLevels.size + val barWidth = (size.width - BAR_SPACE.dp.toPx() * (barCount - 1)) / barCount + val cornerRadius = CornerRadius(x = barWidth, y = barWidth) + + // Use drawIntoCanvas for advanced blend mode operations + drawIntoCanvas { canvas -> + + // 1. Save the current state of the canvas onto a temporary, offscreen layer + canvas.saveLayer(size.toRect(), androidx.compose.ui.graphics.Paint()) + + // 2. Draw the bars in grey. 
+ amplitudeLevels.forEachIndexed { index, level -> + val barHeight = (level * size.height).coerceAtLeast(1.5f) + val left = index * (barWidth + BAR_SPACE.dp.toPx()) + drawRoundRect( + color = barColor, + topLeft = Offset(x = left, y = size.height / 2 - barHeight / 2), + size = Size(barWidth, barHeight), + cornerRadius = cornerRadius, + ) + } + + // 3. Draw the progress rectangle using BlendMode.SrcIn to only draw where the bars already + // exists. + val progressWidth = size.width * progress + drawRect( + color = progressColor, + topLeft = Offset.Zero, + size = Size(progressWidth, size.height), + blendMode = BlendMode.SrcIn, + ) + + // 4. Restore the layer, merging it onto the main canvas + canvas.restore() + } + } +} + +private suspend fun playAudio( + audioTrackState: MutableState, + audioData: ByteArray, + sampleRate: Int, + onProgress: (Float) -> Unit, + onCompletion: () -> Unit, +) { + Log.d(TAG, "Start playing audio...") + + try { + withContext(Dispatchers.IO) { + var lastProgressUpdateMs = 0L + audioTrackState.value?.release() + val audioTrack = + AudioTrack.Builder() + .setAudioAttributes( + AudioAttributes.Builder() + .setContentType(AudioAttributes.CONTENT_TYPE_SPEECH) + .setUsage(AudioAttributes.USAGE_MEDIA) + .build() + ) + .setAudioFormat( + AudioFormat.Builder() + .setEncoding(AudioFormat.ENCODING_PCM_16BIT) + .setSampleRate(sampleRate) + .setChannelMask(AudioFormat.CHANNEL_OUT_MONO) + .build() + ) + .setTransferMode(AudioTrack.MODE_STATIC) + .setBufferSizeInBytes(audioData.size) + .build() + + val bytesPerFrame = 2 // For PCM 16-bit Mono + val totalFrames = audioData.size / bytesPerFrame + + audioTrackState.value = audioTrack + audioTrack.write(audioData, 0, audioData.size) + audioTrack.play() + + // Coroutine to monitor progress + while (isActive && audioTrack.playState == AudioTrack.PLAYSTATE_PLAYING) { + val currentFrames = audioTrack.playbackHeadPosition + if (currentFrames >= totalFrames) { + break // Exit loop when playback is done + } + val 
progress = currentFrames.toFloat() / totalFrames + val curMs = System.currentTimeMillis() + if (curMs - lastProgressUpdateMs > 30) { + onProgress(progress) + lastProgressUpdateMs = curMs + } + } + + if (isActive) { + audioTrackState.value?.stop() + } + } + } catch (e: Exception) { + // Ignore + } finally { + onProgress(1f) + onCompletion() + } +} + +private fun stopPlayAudio(audioTrackState: MutableState) { + Log.d(TAG, "Stopping playing audio...") + + val audioTrack = audioTrackState.value + audioTrack?.stop() + audioTrack?.release() + audioTrackState.value = null +} + +/** + * Processes a raw PCM 16-bit audio byte array to generate a list of normalized amplitude levels for + * visualization. + */ +private fun generateAmplitudeLevels(audioData: ByteArray, barCount: Int): List { + if (audioData.isEmpty()) { + return List(barCount) { 0f } + } + + // 1. Parse bytes into 16-bit short samples (PCM 16-bit) + val shortBuffer = ByteBuffer.wrap(audioData).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer() + val samples = ShortArray(shortBuffer.remaining()) + shortBuffer.get(samples) + + if (samples.isEmpty()) { + return List(barCount) { 0f } + } + + // 2. Determine the size of each chunk + val chunkSize = samples.size / barCount + val amplitudeLevels = mutableListOf() + + // 3. Get the max value for each chunk + for (i in 0 until barCount) { + val chunkStart = i * chunkSize + val chunkEnd = (chunkStart + chunkSize).coerceAtMost(samples.size) + + var maxAmplitudeInChunk = 0.0 + + for (j in chunkStart until chunkEnd) { + val sampleAbs = kotlin.math.abs(samples[j].toDouble()) + if (sampleAbs > maxAmplitudeInChunk) { + maxAmplitudeInChunk = sampleAbs + } + } + + // 4. Normalize the value (0 to 1) + // Short.MAX_VALUE is 32767.0, a good reference for max amplitude + val normalizedRms = (maxAmplitudeInChunk / Short.MAX_VALUE).toFloat().coerceIn(0f, 1f) + amplitudeLevels.add(normalizedRms) + } + + // Normalize the resulting levels so that the max value becomes 0.9. 
+ val maxVal = amplitudeLevels.max() + if (maxVal == 0f) { + return amplitudeLevels + } + val scaleFactor = 0.9f / maxVal + return amplitudeLevels.map { it * scaleFactor } +} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioRecorderPanel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioRecorderPanel.kt new file mode 100644 index 0000000..9b018dc --- /dev/null +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/AudioRecorderPanel.kt @@ -0,0 +1,327 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.ai.edge.gallery.ui.common.chat + +import android.annotation.SuppressLint +import android.content.Context +import android.media.AudioFormat +import android.media.AudioRecord +import android.media.MediaRecorder +import android.util.Log +import androidx.compose.foundation.background +import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.Box +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.height +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.size +import androidx.compose.foundation.shape.CircleShape +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.rounded.ArrowUpward +import androidx.compose.material.icons.rounded.Mic +import androidx.compose.material.icons.rounded.Stop +import androidx.compose.material3.Icon +import androidx.compose.material3.IconButton +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Text +import androidx.compose.runtime.Composable +import androidx.compose.runtime.DisposableEffect +import androidx.compose.runtime.MutableLongState +import androidx.compose.runtime.MutableState +import androidx.compose.runtime.derivedStateOf +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableIntStateOf +import androidx.compose.runtime.mutableLongStateOf +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember +import androidx.compose.runtime.rememberCoroutineScope +import androidx.compose.runtime.setValue +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.clip +import androidx.compose.ui.graphics.Color +import androidx.compose.ui.graphics.graphicsLayer +import androidx.compose.ui.platform.LocalContext +import 
androidx.compose.ui.res.painterResource +import androidx.compose.ui.unit.dp +import com.google.ai.edge.gallery.R +import com.google.ai.edge.gallery.data.MAX_AUDIO_CLIP_DURATION_SEC +import com.google.ai.edge.gallery.data.SAMPLE_RATE +import com.google.ai.edge.gallery.ui.theme.customColors +import java.io.ByteArrayOutputStream +import java.nio.ByteBuffer +import java.nio.ByteOrder +import kotlin.math.abs +import kotlin.math.pow +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.launch + +private const val TAG = "AGAudioRecorderPanel" + +private const val CHANNEL_CONFIG = AudioFormat.CHANNEL_IN_MONO +private const val AUDIO_FORMAT = AudioFormat.ENCODING_PCM_16BIT + +/** + * A Composable that provides an audio recording panel. It allows users to record audio clips, + * displays the recording duration and a live amplitude visualization, and provides options to play + * back the recorded clip or send it. + */ +@Composable +fun AudioRecorderPanel(onSendAudioClip: (ByteArray) -> Unit) { + val context = LocalContext.current + val coroutineScope = rememberCoroutineScope() + + var isRecording by remember { mutableStateOf(false) } + val elapsedMs = remember { mutableLongStateOf(0L) } + val audioRecordState = remember { mutableStateOf(null) } + val audioStream = remember { ByteArrayOutputStream() } + val recordedBytes = remember { mutableStateOf(null) } + var currentAmplitude by remember { mutableIntStateOf(0) } + + val elapsedSeconds by remember { + derivedStateOf { "%.1f".format(elapsedMs.value.toFloat() / 1000f) } + } + + // Cleanup on Composable Disposal. + DisposableEffect(Unit) { onDispose { audioRecordState.value?.release() } } + + Column(modifier = Modifier.padding(bottom = 12.dp)) { + // Title bar. + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.SpaceBetween, + ) { + // Logo and state. 
+ Row( + modifier = Modifier.padding(start = 16.dp), + horizontalArrangement = Arrangement.spacedBy(12.dp), + ) { + Icon( + painterResource(R.drawable.logo), + modifier = Modifier.size(20.dp), + contentDescription = "", + tint = Color.Unspecified, + ) + Text( + "Record audio clip (up to $MAX_AUDIO_CLIP_DURATION_SEC seconds)", + style = MaterialTheme.typography.labelLarge, + ) + } + } + + // Recorded clip. + Row( + modifier = Modifier.fillMaxWidth().padding(vertical = 12.dp).height(40.dp), + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.Center, + ) { + val curRecordedBytes = recordedBytes.value + if (curRecordedBytes == null) { + // Info message when there is no recorded clip and the recording has not started yet. + if (!isRecording) { + Text( + "Tap the record button to start", + style = MaterialTheme.typography.labelLarge, + color = MaterialTheme.colorScheme.onSurfaceVariant, + ) + } + // Visualization for clip being recorded. + else { + Row( + horizontalArrangement = Arrangement.spacedBy(12.dp), + verticalAlignment = Alignment.CenterVertically, + ) { + Box( + modifier = + Modifier.size(8.dp) + .background(MaterialTheme.customColors.recordButtonBgColor, CircleShape) + ) + Text("$elapsedSeconds s") + } + } + } + // Controls for recorded clip. + else { + Row { + // Clip player. + AudioPlaybackPanel( + audioData = curRecordedBytes, + sampleRate = SAMPLE_RATE, + isRecording = isRecording, + ) + + // Button to send the clip + IconButton(onClick = { onSendAudioClip(curRecordedBytes) }) { + Icon(Icons.Rounded.ArrowUpward, contentDescription = "") + } + } + } + } + + // Buttons + Box(contentAlignment = Alignment.Center, modifier = Modifier.fillMaxWidth().height(40.dp)) { + // Visualization of the current amplitude. + if (isRecording) { + // Normalize the amplitude (0-32767) to a fraction (0.0-1.0) + // We use a power scale (exponent < 1) to make the pulse more visible for lower volumes. 
+ val normalizedAmplitude = (currentAmplitude.toFloat() / 32767f).pow(0.35f) + // Define the min and max size of the circle + val minSize = 38.dp + val maxSize = 100.dp + + // Map the normalized amplitude to our size range + val scale by + remember(normalizedAmplitude) { + derivedStateOf { (minSize + (maxSize - minSize) * normalizedAmplitude) / minSize } + } + Box( + modifier = + Modifier.size(minSize) + .graphicsLayer(scaleX = scale, scaleY = scale, clip = false, alpha = 0.3f) + .background(MaterialTheme.customColors.recordButtonBgColor, CircleShape) + ) + } + + // Record/stop button. + IconButton( + onClick = { + coroutineScope.launch { + if (!isRecording) { + isRecording = true + recordedBytes.value = null + startRecording( + context = context, + audioRecordState = audioRecordState, + audioStream = audioStream, + elapsedMs = elapsedMs, + onAmplitudeChanged = { currentAmplitude = it }, + onMaxDurationReached = { + val curRecordedBytes = + stopRecording(audioRecordState = audioRecordState, audioStream = audioStream) + recordedBytes.value = curRecordedBytes + isRecording = false + }, + ) + } else { + val curRecordedBytes = + stopRecording(audioRecordState = audioRecordState, audioStream = audioStream) + recordedBytes.value = curRecordedBytes + isRecording = false + } + } + }, + modifier = + Modifier.clip(CircleShape).background(MaterialTheme.customColors.recordButtonBgColor), + ) { + Icon( + if (isRecording) Icons.Rounded.Stop else Icons.Rounded.Mic, + contentDescription = "", + tint = Color.White, + ) + } + } + } +} + +// Permission is checked in parent composable. 
+@SuppressLint("MissingPermission") +private suspend fun startRecording( + context: Context, + audioRecordState: MutableState, + audioStream: ByteArrayOutputStream, + elapsedMs: MutableLongState, + onAmplitudeChanged: (Int) -> Unit, + onMaxDurationReached: () -> Unit, +) { + Log.d(TAG, "Start recording...") + val minBufferSize = AudioRecord.getMinBufferSize(SAMPLE_RATE, CHANNEL_CONFIG, AUDIO_FORMAT) + + audioRecordState.value?.release() + val recorder = + AudioRecord( + MediaRecorder.AudioSource.MIC, + SAMPLE_RATE, + CHANNEL_CONFIG, + AUDIO_FORMAT, + minBufferSize, + ) + + audioRecordState.value = recorder + val buffer = ByteArray(minBufferSize) + + // The function will only return when the recording is done (when stopRecording is called). + coroutineScope { + launch(Dispatchers.IO) { + recorder.startRecording() + + val startMs = System.currentTimeMillis() + elapsedMs.value = 0L + while (audioRecordState.value?.recordingState == AudioRecord.RECORDSTATE_RECORDING) { + val bytesRead = recorder.read(buffer, 0, buffer.size) + if (bytesRead > 0) { + val currentAmplitude = calculatePeakAmplitude(buffer = buffer, bytesRead = bytesRead) + onAmplitudeChanged(currentAmplitude) + audioStream.write(buffer, 0, bytesRead) + } + elapsedMs.value = System.currentTimeMillis() - startMs + if (elapsedMs.value >= MAX_AUDIO_CLIP_DURATION_SEC * 1000) { + onMaxDurationReached() + break + } + } + } + } +} + +private fun stopRecording( + audioRecordState: MutableState, + audioStream: ByteArrayOutputStream, +): ByteArray { + Log.d(TAG, "Stopping recording...") + + val recorder = audioRecordState.value + if (recorder?.recordingState == AudioRecord.RECORDSTATE_RECORDING) { + recorder.stop() + } + recorder?.release() + audioRecordState.value = null + + val recordedBytes = audioStream.toByteArray() + audioStream.reset() + Log.d(TAG, "Stopped. 
Recorded ${recordedBytes.size} bytes.") + + return recordedBytes +} + +private fun calculatePeakAmplitude(buffer: ByteArray, bytesRead: Int): Int { + // Wrap the byte array in a ByteBuffer and set the order to little-endian + val shortBuffer = + ByteBuffer.wrap(buffer, 0, bytesRead).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer() + + var maxAmplitude = 0 + // Iterate through the short buffer to find the maximum absolute value + while (shortBuffer.hasRemaining()) { + val currentSample = abs(shortBuffer.get().toInt()) + if (currentSample > maxAmplitude) { + maxAmplitude = currentSample + } + } + return maxAmplitude +} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatMessage.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatMessage.kt index 7eba194..e1f45b0 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatMessage.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatMessage.kt @@ -17,18 +17,22 @@ package com.google.ai.edge.gallery.ui.common.chat import android.graphics.Bitmap +import android.util.Log import androidx.compose.ui.graphics.ImageBitmap import androidx.compose.ui.unit.Dp import com.google.ai.edge.gallery.common.Classification import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.gallery.data.PromptTemplate +private const val TAG = "AGChatMessage" + enum class ChatMessageType { INFO, WARNING, TEXT, IMAGE, IMAGE_WITH_HISTORY, + AUDIO_CLIP, LOADING, CLASSIFICATION, CONFIG_VALUES_CHANGE, @@ -121,6 +125,90 @@ class ChatMessageImage( } } +/** Chat message for audio clip. 
*/ +class ChatMessageAudioClip( + val audioData: ByteArray, + val sampleRate: Int, + override val side: ChatSide, + override val latencyMs: Float = 0f, +) : ChatMessage(type = ChatMessageType.AUDIO_CLIP, side = side, latencyMs = latencyMs) { + override fun clone(): ChatMessageAudioClip { + return ChatMessageAudioClip( + audioData = audioData, + sampleRate = sampleRate, + side = side, + latencyMs = latencyMs, + ) + } + + fun genByteArrayForWav(): ByteArray { + val header = ByteArray(44) + + val pcmDataSize = audioData.size + val wavFileSize = pcmDataSize + 44 // 44 bytes for the header + val channels = 1 // Mono + val bitsPerSample: Short = 16 + val byteRate = sampleRate * channels * bitsPerSample / 8 + Log.d(TAG, "Wav metadata: sampleRate: $sampleRate") + + // RIFF/WAVE header + header[0] = 'R'.code.toByte() + header[1] = 'I'.code.toByte() + header[2] = 'F'.code.toByte() + header[3] = 'F'.code.toByte() + header[4] = (wavFileSize and 0xff).toByte() + header[5] = (wavFileSize shr 8 and 0xff).toByte() + header[6] = (wavFileSize shr 16 and 0xff).toByte() + header[7] = (wavFileSize shr 24 and 0xff).toByte() + header[8] = 'W'.code.toByte() + header[9] = 'A'.code.toByte() + header[10] = 'V'.code.toByte() + header[11] = 'E'.code.toByte() + header[12] = 'f'.code.toByte() + header[13] = 'm'.code.toByte() + header[14] = 't'.code.toByte() + header[15] = ' '.code.toByte() + header[16] = 16 + header[17] = 0 + header[18] = 0 + header[19] = 0 // Sub-chunk size (16 for PCM) + header[20] = 1 + header[21] = 0 // Audio format (1 for PCM) + header[22] = channels.toByte() + header[23] = 0 // Number of channels + header[24] = (sampleRate and 0xff).toByte() + header[25] = (sampleRate shr 8 and 0xff).toByte() + header[26] = (sampleRate shr 16 and 0xff).toByte() + header[27] = (sampleRate shr 24 and 0xff).toByte() + header[28] = (byteRate and 0xff).toByte() + header[29] = (byteRate shr 8 and 0xff).toByte() + header[30] = (byteRate shr 16 and 0xff).toByte() + header[31] = (byteRate shr 24 
and 0xff).toByte() + header[32] = (channels * bitsPerSample / 8).toByte() + header[33] = 0 // Block align + header[34] = bitsPerSample.toByte() + header[35] = (bitsPerSample.toInt() shr 8 and 0xff).toByte() // Bits per sample + header[36] = 'd'.code.toByte() + header[37] = 'a'.code.toByte() + header[38] = 't'.code.toByte() + header[39] = 'a'.code.toByte() + header[40] = (pcmDataSize and 0xff).toByte() + header[41] = (pcmDataSize shr 8 and 0xff).toByte() + header[42] = (pcmDataSize shr 16 and 0xff).toByte() + header[43] = (pcmDataSize shr 24 and 0xff).toByte() + + return header + audioData + } + + fun getDurationInSeconds(): Float { + // PCM 16-bit + val bytesPerSample = 2 + val bytesPerFrame = bytesPerSample * 1 // mono + val totalFrames = audioData.size.toFloat() / bytesPerFrame + return totalFrames / sampleRate + } +} + /** Chat message for images with history. */ class ChatMessageImageWithHistory( val bitmaps: List, diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatPanel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatPanel.kt index 2b8b168..80f2cfa 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatPanel.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatPanel.kt @@ -137,6 +137,19 @@ fun ChatPanel( } imageMessageCount } + val audioClipMesssageCountToLastconfigChange = + remember(messages) { + var audioClipMessageCount = 0 + for (message in messages.reversed()) { + if (message is ChatMessageConfigValuesChange) { + break + } + if (message is ChatMessageAudioClip) { + audioClipMessageCount++ + } + } + audioClipMessageCount + } var curMessage by remember { mutableStateOf("") } // Correct state val focusManager = LocalFocusManager.current @@ -342,6 +355,9 @@ fun ChatPanel( imageHistoryCurIndex = imageHistoryCurIndex, ) + // Audio clip. 
+ is ChatMessageAudioClip -> MessageBodyAudioClip(message = message) + // Classification result is ChatMessageClassification -> MessageBodyClassification( @@ -467,6 +483,22 @@ fun ChatPanel( ) } } + // Show an info message for ask audio task to get users started. + else if (task.type == TaskType.LLM_ASK_AUDIO && messages.isEmpty()) { + Column( + modifier = Modifier.padding(horizontal = 16.dp).fillMaxSize(), + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.Center, + ) { + MessageBodyInfo( + ChatMessageInfo( + content = + "To get started, tap the + icon to add your audio clips. You can add up to 10 clips, each up to 30 seconds long." + ), + smallFontSize = false, + ) + } + } } // Chat input @@ -482,6 +514,7 @@ fun ChatPanel( isResettingSession = uiState.isResettingSession, modelPreparing = uiState.preparing, imageMessageCount = imageMessageCountToLastConfigChange, + audioClipMessageCount = audioClipMesssageCountToLastconfigChange, modelInitializing = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING, textFieldPlaceHolderRes = task.textInputPlaceHolderRes, @@ -504,7 +537,10 @@ fun ChatPanel( onStopButtonClicked = onStopButtonClicked, // showPromptTemplatesInMenu = isLlmTask && notLlmStartScreen, showPromptTemplatesInMenu = false, - showImagePickerInMenu = selectedModel.llmSupportImage, + showImagePickerInMenu = + selectedModel.llmSupportImage && task.type === TaskType.LLM_ASK_IMAGE, + showAudioItemsInMenu = + selectedModel.llmSupportAudio && task.type === TaskType.LLM_ASK_AUDIO, showStopButtonWhenInProgress = showStopButtonInInputWhenInProgress, ) } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageBodyAudioClip.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageBodyAudioClip.kt new file mode 100644 index 0000000..e4e3716 --- /dev/null +++ 
b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageBodyAudioClip.kt @@ -0,0 +1,33 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.edge.gallery.ui.common.chat + +import androidx.compose.foundation.layout.padding +import androidx.compose.runtime.Composable +import androidx.compose.ui.Modifier +import androidx.compose.ui.unit.dp + +@Composable +fun MessageBodyAudioClip(message: ChatMessageAudioClip, modifier: Modifier = Modifier) { + AudioPlaybackPanel( + audioData = message.audioData, + sampleRate = message.sampleRate, + isRecording = false, + modifier = Modifier.padding(end = 16.dp), + onDarkBg = true, + ) +} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageInputText.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageInputText.kt index 1e2e056..d958815 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageInputText.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/MessageInputText.kt @@ -21,6 +21,7 @@ package com.google.ai.edge.gallery.ui.common.chat // import com.google.ai.edge.gallery.ui.theme.GalleryTheme import android.Manifest import android.content.Context +import android.content.Intent import android.content.pm.PackageManager import android.graphics.Bitmap import android.graphics.BitmapFactory @@ -65,9 +66,11 @@ 
import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.material.icons.Icons import androidx.compose.material.icons.automirrored.rounded.Send import androidx.compose.material.icons.rounded.Add +import androidx.compose.material.icons.rounded.AudioFile import androidx.compose.material.icons.rounded.Close import androidx.compose.material.icons.rounded.FlipCameraAndroid import androidx.compose.material.icons.rounded.History +import androidx.compose.material.icons.rounded.Mic import androidx.compose.material.icons.rounded.Photo import androidx.compose.material.icons.rounded.PhotoCamera import androidx.compose.material.icons.rounded.PostAdd @@ -107,9 +110,14 @@ import androidx.compose.ui.unit.dp import androidx.compose.ui.viewinterop.AndroidView import androidx.core.content.ContextCompat import androidx.lifecycle.compose.LocalLifecycleOwner +import com.google.ai.edge.gallery.common.AudioClip +import com.google.ai.edge.gallery.common.convertWavToMonoWithMaxSeconds +import com.google.ai.edge.gallery.data.MAX_AUDIO_CLIP_COUNT import com.google.ai.edge.gallery.data.MAX_IMAGE_COUNT +import com.google.ai.edge.gallery.data.SAMPLE_RATE import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import java.util.concurrent.Executors +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.launch private const val TAG = "AGMessageInputText" @@ -128,6 +136,7 @@ fun MessageInputText( isResettingSession: Boolean, inProgress: Boolean, imageMessageCount: Int, + audioClipMessageCount: Int, modelInitializing: Boolean, @StringRes textFieldPlaceHolderRes: Int, onValueChanged: (String) -> Unit, @@ -137,6 +146,7 @@ fun MessageInputText( onStopButtonClicked: () -> Unit = {}, showPromptTemplatesInMenu: Boolean = false, showImagePickerInMenu: Boolean = false, + showAudioItemsInMenu: Boolean = false, showStopButtonWhenInProgress: Boolean = false, ) { val context = LocalContext.current @@ -146,7 +156,12 @@ fun MessageInputText( var 
showTextInputHistorySheet by remember { mutableStateOf(false) } var showCameraCaptureBottomSheet by remember { mutableStateOf(false) } val cameraCaptureSheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true) + var showAudioRecorderBottomSheet by remember { mutableStateOf(false) } + val audioRecorderSheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true) var pickedImages by remember { mutableStateOf>(listOf()) } + var pickedAudioClips by remember { mutableStateOf>(listOf()) } + var hasFrontCamera by remember { mutableStateOf(false) } + val updatePickedImages: (List) -> Unit = { bitmaps -> var newPickedImages: MutableList = mutableListOf() newPickedImages.addAll(pickedImages) @@ -156,7 +171,16 @@ fun MessageInputText( } pickedImages = newPickedImages.toList() } - var hasFrontCamera by remember { mutableStateOf(false) } + + val updatePickedAudioClips: (List) -> Unit = { audioDataList -> + var newAudioDataList: MutableList = mutableListOf() + newAudioDataList.addAll(pickedAudioClips) + newAudioDataList.addAll(audioDataList) + if (newAudioDataList.size > MAX_AUDIO_CLIP_COUNT) { + newAudioDataList = newAudioDataList.subList(fromIndex = 0, toIndex = MAX_AUDIO_CLIP_COUNT) + } + pickedAudioClips = newAudioDataList.toList() + } LaunchedEffect(Unit) { checkFrontCamera(context = context, callback = { hasFrontCamera = it }) } @@ -170,6 +194,16 @@ fun MessageInputText( } } + // Permission request when recording audio clips. + val recordAudioClipsPermissionLauncher = + rememberLauncherForActivityResult(ActivityResultContracts.RequestPermission()) { + permissionGranted -> + if (permissionGranted) { + showAddContentMenu = false + showAudioRecorderBottomSheet = true + } + } + // Registers a photo picker activity launcher in single-select mode. 
val pickMedia = rememberLauncherForActivityResult(ActivityResultContracts.PickMultipleVisualMedia()) { uris -> @@ -184,9 +218,31 @@ fun MessageInputText( } } + val pickWav = + rememberLauncherForActivityResult( + contract = ActivityResultContracts.StartActivityForResult() + ) { result -> + if (result.resultCode == android.app.Activity.RESULT_OK) { + result.data?.data?.let { uri -> + Log.d(TAG, "Picked wav file: $uri") + scope.launch(Dispatchers.IO) { + convertWavToMonoWithMaxSeconds(context = context, stereoUri = uri)?.let { audioClip -> + updatePickedAudioClips( + listOf( + AudioClip(audioData = audioClip.audioData, sampleRate = audioClip.sampleRate) + ) + ) + } + } + } + } else { + Log.d(TAG, "Wav picking cancelled.") + } + } + Column { - // A preview panel for the selected image. - if (pickedImages.isNotEmpty()) { + // A preview panel for the selected images and audio clips. + if (pickedImages.isNotEmpty() || pickedAudioClips.isNotEmpty()) { Row( modifier = Modifier.offset(x = 16.dp).fillMaxWidth().horizontalScroll(rememberScrollState()), @@ -203,20 +259,30 @@ fun MessageInputText( .clip(RoundedCornerShape(8.dp)) .border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp)), ) + MediaPanelCloseButton { pickedImages = pickedImages.filter { image != it } } + } + } + + for ((index, audioClip) in pickedAudioClips.withIndex()) { + Box(contentAlignment = Alignment.TopEnd) { Box( modifier = - Modifier.offset(x = 10.dp, y = (-10).dp) - .clip(CircleShape) + Modifier.shadow(2.dp, shape = RoundedCornerShape(8.dp)) + .clip(RoundedCornerShape(8.dp)) .background(MaterialTheme.colorScheme.surface) - .border((1.5).dp, MaterialTheme.colorScheme.outline, CircleShape) - .clickable { pickedImages = pickedImages.filter { image != it } } + .border(1.dp, MaterialTheme.colorScheme.outline, RoundedCornerShape(8.dp)) ) { - Icon( - Icons.Rounded.Close, - contentDescription = "", - modifier = Modifier.padding(3.dp).size(16.dp), + AudioPlaybackPanel( + audioData = 
audioClip.audioData, + sampleRate = audioClip.sampleRate, + isRecording = false, + modifier = Modifier.padding(end = 16.dp), ) } + MediaPanelCloseButton { + pickedAudioClips = + pickedAudioClips.filterIndexed { curIndex, curAudioData -> curIndex != index } + } } } } @@ -239,10 +305,13 @@ fun MessageInputText( verticalAlignment = Alignment.CenterVertically, ) { val enableAddImageMenuItems = (imageMessageCount + pickedImages.size) < MAX_IMAGE_COUNT + val enableRecordAudioClipMenuItems = + (audioClipMessageCount + pickedAudioClips.size) < MAX_AUDIO_CLIP_COUNT DropdownMenu( expanded = showAddContentMenu, onDismissRequest = { showAddContentMenu = false }, ) { + // Image related menu items. if (showImagePickerInMenu) { // Take a picture. DropdownMenuItem( @@ -295,6 +364,70 @@ fun MessageInputText( ) } + // Audio related menu items. + if (showAudioItemsInMenu) { + DropdownMenuItem( + text = { + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(6.dp), + ) { + Icon(Icons.Rounded.Mic, contentDescription = "") + Text("Record audio clip") + } + }, + enabled = enableRecordAudioClipMenuItems, + onClick = { + // Check permission + when (PackageManager.PERMISSION_GRANTED) { + // Already got permission. Call the lambda. + ContextCompat.checkSelfPermission(context, Manifest.permission.RECORD_AUDIO) -> { + showAddContentMenu = false + showAudioRecorderBottomSheet = true + } + + // Otherwise, ask for permission + else -> { + recordAudioClipsPermissionLauncher.launch(Manifest.permission.RECORD_AUDIO) + } + } + }, + ) + + DropdownMenuItem( + text = { + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(6.dp), + ) { + Icon(Icons.Rounded.AudioFile, contentDescription = "") + Text("Pick wav file") + } + }, + enabled = enableRecordAudioClipMenuItems, + onClick = { + showAddContentMenu = false + + // Show file picker. 
+ val intent = + Intent(Intent.ACTION_GET_CONTENT).apply { + addCategory(Intent.CATEGORY_OPENABLE) + type = "audio/*" + + // Provide a list of more specific MIME types to filter for. + val mimeTypes = arrayOf("audio/wav", "audio/x-wav") + putExtra(Intent.EXTRA_MIME_TYPES, mimeTypes) + + // Single select. + putExtra(Intent.EXTRA_ALLOW_MULTIPLE, false) + .addFlags(Intent.FLAG_GRANT_PERSISTABLE_URI_PERMISSION) + .addFlags(Intent.FLAG_GRANT_READ_URI_PERMISSION) + } + pickWav.launch(intent) + }, + ) + } + // Prompt templates. if (showPromptTemplatesInMenu) { DropdownMenuItem( @@ -369,15 +502,22 @@ fun MessageInputText( ) } } - } // Send button. Only shown when text is not empty. - else if (curMessage.isNotEmpty()) { + } + // Send button. Only shown when text is not empty, or there is at least one recorded + // audio clip. + else if (curMessage.isNotEmpty() || pickedAudioClips.isNotEmpty()) { IconButton( enabled = !inProgress && !isResettingSession, onClick = { onSendMessage( - createMessagesToSend(pickedImages = pickedImages, text = curMessage.trim()) + createMessagesToSend( + pickedImages = pickedImages, + audioClips = pickedAudioClips, + text = curMessage.trim(), + ) ) pickedImages = listOf() + pickedAudioClips = listOf() }, colors = IconButtonDefaults.iconButtonColors( @@ -403,8 +543,15 @@ fun MessageInputText( history = modelManagerUiState.textInputHistory, onDismissed = { showTextInputHistorySheet = false }, onHistoryItemClicked = { item -> - onSendMessage(createMessagesToSend(pickedImages = pickedImages, text = item)) + onSendMessage( + createMessagesToSend( + pickedImages = pickedImages, + audioClips = pickedAudioClips, + text = item, + ) + ) pickedImages = listOf() + pickedAudioClips = listOf() modelManagerViewModel.promoteTextInputHistoryItem(item) }, onHistoryItemDeleted = { item -> modelManagerViewModel.deleteTextInputHistory(item) }, @@ -582,6 +729,43 @@ fun MessageInputText( } } } + + if (showAudioRecorderBottomSheet) { + ModalBottomSheet( + sheetState = 
audioRecorderSheetState, + onDismissRequest = { showAudioRecorderBottomSheet = false }, + ) { + AudioRecorderPanel( + onSendAudioClip = { audioData -> + scope.launch { + updatePickedAudioClips( + listOf(AudioClip(audioData = audioData, sampleRate = SAMPLE_RATE)) + ) + audioRecorderSheetState.hide() + showAudioRecorderBottomSheet = false + } + } + ) + } + } +} + +@Composable +private fun MediaPanelCloseButton(onClicked: () -> Unit) { + Box( + modifier = + Modifier.offset(x = 10.dp, y = (-10).dp) + .clip(CircleShape) + .background(MaterialTheme.colorScheme.surface) + .border((1.5).dp, MaterialTheme.colorScheme.outline, CircleShape) + .clickable { onClicked() } + ) { + Icon( + Icons.Rounded.Close, + contentDescription = "", + modifier = Modifier.padding(3.dp).size(16.dp), + ) + } } private fun handleImagesSelected( @@ -641,20 +825,50 @@ private fun checkFrontCamera(context: Context, callback: (Boolean) -> Unit) { ) } -private fun createMessagesToSend(pickedImages: List, text: String): List { +private fun createMessagesToSend( + pickedImages: List, + audioClips: List, + text: String, +): List { var messages: MutableList = mutableListOf() + + // Add image messages. + var imageMessages: MutableList = mutableListOf() if (pickedImages.isNotEmpty()) { for (image in pickedImages) { - messages.add( + imageMessages.add( ChatMessageImage(bitmap = image, imageBitMap = image.asImageBitmap(), side = ChatSide.USER) ) } } // Cap the number of image messages. - if (messages.size > MAX_IMAGE_COUNT) { - messages = messages.subList(fromIndex = 0, toIndex = MAX_IMAGE_COUNT) + if (imageMessages.size > MAX_IMAGE_COUNT) { + imageMessages = imageMessages.subList(fromIndex = 0, toIndex = MAX_IMAGE_COUNT) + } + messages.addAll(imageMessages) + + // Add audio messages. 
+ var audioMessages: MutableList = mutableListOf() + if (audioClips.isNotEmpty()) { + for (audioClip in audioClips) { + audioMessages.add( + ChatMessageAudioClip( + audioData = audioClip.audioData, + sampleRate = audioClip.sampleRate, + side = ChatSide.USER, + ) + ) + } + } + // Cap the number of audio messages. + if (audioMessages.size > MAX_AUDIO_CLIP_COUNT) { + audioMessages = audioMessages.subList(fromIndex = 0, toIndex = MAX_AUDIO_CLIP_COUNT) + } + messages.addAll(audioMessages) + + if (text.isNotEmpty()) { + messages.add(ChatMessageText(content = text, side = ChatSide.USER)) } - messages.add(ChatMessageText(content = text, side = ChatSide.USER)) return messages } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/ModelImportDialog.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/ModelImportDialog.kt index 1f0cd00..9c985e5 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/ModelImportDialog.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/ModelImportDialog.kt @@ -121,6 +121,7 @@ private val IMPORT_CONFIGS_LLM: List = valueType = ValueType.FLOAT, ), BooleanSwitchConfig(key = ConfigKey.SUPPORT_IMAGE, defaultValue = false), + BooleanSwitchConfig(key = ConfigKey.SUPPORT_AUDIO, defaultValue = false), SegmentedButtonConfig( key = ConfigKey.COMPATIBLE_ACCELERATORS, defaultValue = Accelerator.CPU.label, @@ -230,6 +231,12 @@ fun ModelImportDialog(uri: Uri, onDismiss: () -> Unit, onDone: (ImportedModel) - valueType = ValueType.BOOLEAN, ) as Boolean + val supportAudio = + convertValueToTargetType( + value = values.get(ConfigKey.SUPPORT_AUDIO.label)!!, + valueType = ValueType.BOOLEAN, + ) + as Boolean val importedModel: ImportedModel = ImportedModel.newBuilder() .setFileName(fileName) @@ -242,6 +249,7 @@ fun ModelImportDialog(uri: Uri, onDismiss: () -> Unit, onDone: (ImportedModel) - .setDefaultTopp(defaultTopp) .setDefaultTemperature(defaultTemperature) 
.setSupportImage(supportImage) + .setSupportAudio(supportAudio) .build() ) .build() diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/SettingsDialog.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/SettingsDialog.kt index ee418bb..dc0bba1 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/SettingsDialog.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/SettingsDialog.kt @@ -173,7 +173,7 @@ fun SettingsDialog( color = MaterialTheme.colorScheme.onSurfaceVariant, ) Text( - "Expired at: ${dateFormatter.format(Instant.ofEpochMilli(curHfToken.expiresAtMs))}", + "Expires at: ${dateFormatter.format(Instant.ofEpochMilli(curHfToken.expiresAtMs))}", style = MaterialTheme.typography.bodyMedium, color = MaterialTheme.colorScheme.onSurfaceVariant, ) diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Forum.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Forum.kt deleted file mode 100644 index 173ca4f..0000000 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Forum.kt +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.google.ai.edge.gallery.ui.icon - -import androidx.compose.ui.graphics.Color -import androidx.compose.ui.graphics.SolidColor -import androidx.compose.ui.graphics.vector.ImageVector -import androidx.compose.ui.graphics.vector.path -import androidx.compose.ui.unit.dp - -val Forum: ImageVector - get() { - if (_Forum != null) return _Forum!! - - _Forum = - ImageVector.Builder( - name = "Forum", - defaultWidth = 24.dp, - defaultHeight = 24.dp, - viewportWidth = 960f, - viewportHeight = 960f, - ) - .apply { - path(fill = SolidColor(Color(0xFF000000))) { - moveTo(280f, 720f) - quadToRelative(-17f, 0f, -28.5f, -11.5f) - reflectiveQuadTo(240f, 680f) - verticalLineToRelative(-80f) - horizontalLineToRelative(520f) - verticalLineToRelative(-360f) - horizontalLineToRelative(80f) - quadToRelative(17f, 0f, 28.5f, 11.5f) - reflectiveQuadTo(880f, 280f) - verticalLineToRelative(600f) - lineTo(720f, 720f) - close() - moveTo(80f, 680f) - verticalLineToRelative(-560f) - quadToRelative(0f, -17f, 11.5f, -28.5f) - reflectiveQuadTo(120f, 80f) - horizontalLineToRelative(520f) - quadToRelative(17f, 0f, 28.5f, 11.5f) - reflectiveQuadTo(680f, 120f) - verticalLineToRelative(360f) - quadToRelative(0f, 17f, -11.5f, 28.5f) - reflectiveQuadTo(640f, 520f) - horizontalLineTo(240f) - close() - moveToRelative(520f, -240f) - verticalLineToRelative(-280f) - horizontalLineTo(160f) - verticalLineToRelative(280f) - close() - moveToRelative(-440f, 0f) - verticalLineToRelative(-280f) - close() - } - } - .build() - - return _Forum!! - } - -private var _Forum: ImageVector? 
= null diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Mms.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Mms.kt deleted file mode 100644 index 56f9990..0000000 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Mms.kt +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.ai.edge.gallery.ui.icon - -import androidx.compose.ui.graphics.Color -import androidx.compose.ui.graphics.SolidColor -import androidx.compose.ui.graphics.vector.ImageVector -import androidx.compose.ui.graphics.vector.path -import androidx.compose.ui.unit.dp - -val Mms: ImageVector - get() { - if (_Mms != null) return _Mms!! 
- - _Mms = - ImageVector.Builder( - name = "Mms", - defaultWidth = 24.dp, - defaultHeight = 24.dp, - viewportWidth = 960f, - viewportHeight = 960f, - ) - .apply { - path(fill = SolidColor(Color(0xFF000000))) { - moveTo(240f, 560f) - horizontalLineToRelative(480f) - lineTo(570f, 360f) - lineTo(450f, 520f) - lineToRelative(-90f, -120f) - close() - moveTo(80f, 880f) - verticalLineToRelative(-720f) - quadToRelative(0f, -33f, 23.5f, -56.5f) - reflectiveQuadTo(160f, 80f) - horizontalLineToRelative(640f) - quadToRelative(33f, 0f, 56.5f, 23.5f) - reflectiveQuadTo(880f, 160f) - verticalLineToRelative(480f) - quadToRelative(0f, 33f, -23.5f, 56.5f) - reflectiveQuadTo(800f, 720f) - horizontalLineTo(240f) - close() - moveToRelative(126f, -240f) - horizontalLineToRelative(594f) - verticalLineToRelative(-480f) - horizontalLineTo(160f) - verticalLineToRelative(525f) - close() - moveToRelative(-46f, 0f) - verticalLineToRelative(-480f) - close() - } - } - .build() - - return _Mms!! - } - -private var _Mms: ImageVector? = null diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Widgets.kt.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Widgets.kt.kt deleted file mode 100644 index c727a7b..0000000 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/icon/Widgets.kt.kt +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.google.ai.edge.gallery.ui.icon - -import androidx.compose.ui.graphics.Color -import androidx.compose.ui.graphics.SolidColor -import androidx.compose.ui.graphics.vector.ImageVector -import androidx.compose.ui.graphics.vector.path -import androidx.compose.ui.unit.dp - -val Widgets: ImageVector - get() { - if (_Widgets != null) return _Widgets!! - - _Widgets = - ImageVector.Builder( - name = "Widgets", - defaultWidth = 24.dp, - defaultHeight = 24.dp, - viewportWidth = 960f, - viewportHeight = 960f, - ) - .apply { - path(fill = SolidColor(Color(0xFF000000))) { - moveTo(666f, 520f) - lineTo(440f, 294f) - lineToRelative(226f, -226f) - lineToRelative(226f, 226f) - close() - moveToRelative(-546f, -80f) - verticalLineToRelative(-320f) - horizontalLineToRelative(320f) - verticalLineToRelative(320f) - close() - moveToRelative(400f, 400f) - verticalLineToRelative(-320f) - horizontalLineToRelative(320f) - verticalLineToRelative(320f) - close() - moveToRelative(-400f, 0f) - verticalLineToRelative(-320f) - horizontalLineToRelative(320f) - verticalLineToRelative(320f) - close() - moveToRelative(80f, -480f) - horizontalLineToRelative(160f) - verticalLineToRelative(-160f) - horizontalLineTo(200f) - close() - moveToRelative(467f, 48f) - lineToRelative(113f, -113f) - lineToRelative(-113f, -113f) - lineToRelative(-113f, 113f) - close() - moveToRelative(-67f, 352f) - horizontalLineToRelative(160f) - verticalLineToRelative(-160f) - horizontalLineTo(600f) - close() - moveToRelative(-400f, 0f) - horizontalLineToRelative(160f) - verticalLineToRelative(-160f) - horizontalLineTo(200f) - close() - moveToRelative(400f, -160f) - } - } - .build() - - return _Widgets!! - } - -private var _Widgets: ImageVector? 
= null diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt index 128baa8..a330c35 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt @@ -62,13 +62,13 @@ object LlmChatModelHelper { Accelerator.GPU.label -> LlmInference.Backend.GPU else -> LlmInference.Backend.GPU } - val options = + val optionsBuilder = LlmInference.LlmInferenceOptions.builder() .setModelPath(model.getPath(context = context)) .setMaxTokens(maxTokens) .setPreferredBackend(preferredBackend) .setMaxNumImages(if (model.llmSupportImage) MAX_IMAGE_COUNT else 0) - .build() + val options = optionsBuilder.build() // Create an instance of the LLM Inference task and session. try { @@ -82,7 +82,9 @@ object LlmChatModelHelper { .setTopP(topP) .setTemperature(temperature) .setGraphOptions( - GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build() + GraphOptions.builder() + .setEnableVisionModality(model.llmSupportImage) + .build() ) .build(), ) @@ -115,7 +117,9 @@ object LlmChatModelHelper { .setTopP(topP) .setTemperature(temperature) .setGraphOptions( - GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build() + GraphOptions.builder() + .setEnableVisionModality(model.llmSupportImage) + .build() ) .build(), ) @@ -159,6 +163,7 @@ object LlmChatModelHelper { resultListener: ResultListener, cleanUpListener: CleanUpListener, images: List = listOf(), + audioClips: List = listOf(), ) { val instance = model.instance as LlmModelInstance @@ -172,10 +177,16 @@ object LlmChatModelHelper { // For a model that supports image modality, we need to add the text query chunk before adding // image. 
val session = instance.session - session.addQueryChunk(input) + if (input.trim().isNotEmpty()) { + session.addQueryChunk(input) + } for (image in images) { session.addImage(BitmapImageBuilder(image).build()) } + for (audioClip in audioClips) { + // Uncomment when audio is supported. + // session.addAudio(audioClip) + } val unused = session.generateResponseAsync(resultListener) } } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt index 0a97f3f..23b5777 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt @@ -22,6 +22,7 @@ import androidx.compose.ui.Modifier import androidx.compose.ui.platform.LocalContext import androidx.lifecycle.viewmodel.compose.viewModel import com.google.ai.edge.gallery.ui.ViewModelProvider +import com.google.ai.edge.gallery.ui.common.chat.ChatMessageAudioClip import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText import com.google.ai.edge.gallery.ui.common.chat.ChatView @@ -36,6 +37,10 @@ object LlmAskImageDestination { val route = "LlmAskImageRoute" } +object LlmAskAudioDestination { + val route = "LlmAskAudioRoute" +} + @Composable fun LlmChatScreen( modelManagerViewModel: ModelManagerViewModel, @@ -66,6 +71,21 @@ fun LlmAskImageScreen( ) } +@Composable +fun LlmAskAudioScreen( + modelManagerViewModel: ModelManagerViewModel, + navigateUp: () -> Unit, + modifier: Modifier = Modifier, + viewModel: LlmAskAudioViewModel = viewModel(factory = ViewModelProvider.Factory), +) { + ChatViewWrapper( + viewModel = viewModel, + modelManagerViewModel = modelManagerViewModel, + navigateUp = navigateUp, + modifier = modifier, + ) +} + @Composable fun ChatViewWrapper( viewModel: LlmChatViewModel, @@ -86,6 +106,7 @@ 
fun ChatViewWrapper( var text = "" val images: MutableList = mutableListOf() + val audioMessages: MutableList = mutableListOf() var chatMessageText: ChatMessageText? = null for (message in messages) { if (message is ChatMessageText) { @@ -93,14 +114,17 @@ fun ChatViewWrapper( text = message.content } else if (message is ChatMessageImage) { images.add(message.bitmap) + } else if (message is ChatMessageAudioClip) { + audioMessages.add(message) } } - if (text.isNotEmpty() && chatMessageText != null) { + if ((text.isNotEmpty() && chatMessageText != null) || audioMessages.isNotEmpty()) { modelManagerViewModel.addTextInputHistory(text) viewModel.generateResponse( model = model, input = text, images = images, + audioMessages = audioMessages, onError = { viewModel.handleError( context = context, diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt index 7b0a083..d97a9df 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt @@ -22,9 +22,11 @@ import android.util.Log import androidx.lifecycle.viewModelScope import com.google.ai.edge.gallery.data.ConfigKey import com.google.ai.edge.gallery.data.Model +import com.google.ai.edge.gallery.data.TASK_LLM_ASK_AUDIO import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE import com.google.ai.edge.gallery.data.TASK_LLM_CHAT import com.google.ai.edge.gallery.data.Task +import com.google.ai.edge.gallery.ui.common.chat.ChatMessageAudioClip import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult import com.google.ai.edge.gallery.ui.common.chat.ChatMessageLoading import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText @@ -52,6 +54,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task model: Model, 
input: String, images: List = listOf(), + audioMessages: List = listOf(), onError: () -> Unit, ) { val accelerator = model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = "") @@ -72,6 +75,11 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task val instance = model.instance as LlmModelInstance var prefillTokens = instance.session.sizeInTokens(input) prefillTokens += images.size * 257 + for (audioMessages in audioMessages) { + // 150ms = 1 audio token + val duration = audioMessages.getDurationInSeconds() + prefillTokens += (duration * 1000f / 150f).toInt() + } var firstRun = true var timeToFirstToken = 0f @@ -86,6 +94,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task model = model, input = input, images = images, + audioClips = audioMessages.map { it.genByteArrayForWav() }, resultListener = { partialResult, done -> val curTs = System.currentTimeMillis() @@ -214,7 +223,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task context: Context, model: Model, modelManagerViewModel: ModelManagerViewModel, - triggeredMessage: ChatMessageText, + triggeredMessage: ChatMessageText?, ) { // Clean up. modelManagerViewModel.cleanupModel(task = task, model = model) @@ -236,14 +245,20 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task ) // Add the triggered message back. - addMessage(model = model, message = triggeredMessage) + if (triggeredMessage != null) { + addMessage(model = model, message = triggeredMessage) + } // Re-initialize the session/engine. modelManagerViewModel.initializeModel(context = context, task = task, model = model) // Re-generate the response automatically. 
- generateResponse(model = model, input = triggeredMessage.content, onError = {}) + if (triggeredMessage != null) { + generateResponse(model = model, input = triggeredMessage.content, onError = {}) + } } } class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE) + +class LlmAskAudioViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_AUDIO) diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt index f73b893..6e2fb9e 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt @@ -36,6 +36,7 @@ import com.google.ai.edge.gallery.data.ModelAllowlist import com.google.ai.edge.gallery.data.ModelDownloadStatus import com.google.ai.edge.gallery.data.ModelDownloadStatusType import com.google.ai.edge.gallery.data.TASKS +import com.google.ai.edge.gallery.data.TASK_LLM_ASK_AUDIO import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE import com.google.ai.edge.gallery.data.TASK_LLM_CHAT import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB @@ -281,15 +282,12 @@ open class ModelManagerViewModel( } } when (task.type) { - TaskType.LLM_CHAT -> - LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone) - + TaskType.LLM_CHAT, + TaskType.LLM_ASK_IMAGE, + TaskType.LLM_ASK_AUDIO, TaskType.LLM_PROMPT_LAB -> LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone) - TaskType.LLM_ASK_IMAGE -> - LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone) - TaskType.TEST_TASK_1 -> {} TaskType.TEST_TASK_2 -> {} } @@ -301,9 +299,11 @@ open class ModelManagerViewModel( model.cleanUpAfterInit = false Log.d(TAG, "Cleaning up model '${model.name}'...") when (task.type) { - TaskType.LLM_CHAT -> 
LlmChatModelHelper.cleanUp(model = model) - TaskType.LLM_PROMPT_LAB -> LlmChatModelHelper.cleanUp(model = model) - TaskType.LLM_ASK_IMAGE -> LlmChatModelHelper.cleanUp(model = model) + TaskType.LLM_CHAT, + TaskType.LLM_PROMPT_LAB, + TaskType.LLM_ASK_IMAGE, + TaskType.LLM_ASK_AUDIO -> LlmChatModelHelper.cleanUp(model = model) + TaskType.TEST_TASK_1 -> {} TaskType.TEST_TASK_2 -> {} } @@ -410,14 +410,19 @@ open class ModelManagerViewModel( // Create model. val model = createModelFromImportedModelInfo(info = info) - for (task in listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)) { + for (task in + listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_ASK_AUDIO, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)) { // Remove duplicated imported model if existed. val modelIndex = task.models.indexOfFirst { info.fileName == it.name && it.imported } if (modelIndex >= 0) { Log.d(TAG, "duplicated imported model found in task. Removing it first") task.models.removeAt(modelIndex) } - if ((task == TASK_LLM_ASK_IMAGE && model.llmSupportImage) || task != TASK_LLM_ASK_IMAGE) { + if ( + (task == TASK_LLM_ASK_IMAGE && model.llmSupportImage) || + (task == TASK_LLM_ASK_AUDIO && model.llmSupportAudio) || + (task != TASK_LLM_ASK_IMAGE && task != TASK_LLM_ASK_AUDIO) + ) { task.models.add(model) } task.updateTrigger.value = System.currentTimeMillis() @@ -657,6 +662,7 @@ open class ModelManagerViewModel( TASK_LLM_CHAT.models.clear() TASK_LLM_PROMPT_LAB.models.clear() TASK_LLM_ASK_IMAGE.models.clear() + TASK_LLM_ASK_AUDIO.models.clear() for (allowedModel in modelAllowlist.models) { if (allowedModel.disabled == true) { continue @@ -672,6 +678,9 @@ open class ModelManagerViewModel( if (allowedModel.taskTypes.contains(TASK_LLM_ASK_IMAGE.type.id)) { TASK_LLM_ASK_IMAGE.models.add(model) } + if (allowedModel.taskTypes.contains(TASK_LLM_ASK_AUDIO.type.id)) { + TASK_LLM_ASK_AUDIO.models.add(model) + } } // Pre-process all tasks. 
@@ -760,6 +769,9 @@ open class ModelManagerViewModel( if (model.llmSupportImage) { TASK_LLM_ASK_IMAGE.models.add(model) } + if (model.llmSupportAudio) { + TASK_LLM_ASK_AUDIO.models.add(model) + } // Update status. modelDownloadStatus[model.name] = @@ -800,6 +812,7 @@ open class ModelManagerViewModel( accelerators = accelerators, ) val llmSupportImage = info.llmConfig.supportImage + val llmSupportAudio = info.llmConfig.supportAudio val model = Model( name = info.fileName, @@ -811,6 +824,7 @@ open class ModelManagerViewModel( showRunAgainButton = false, imported = true, llmSupportImage = llmSupportImage, + llmSupportAudio = llmSupportAudio, ) model.preProcess() diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/navigation/GalleryNavGraph.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/navigation/GalleryNavGraph.kt index 5e8672d..138edc7 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/navigation/GalleryNavGraph.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/navigation/GalleryNavGraph.kt @@ -47,6 +47,7 @@ import androidx.navigation.compose.NavHost import androidx.navigation.compose.composable import androidx.navigation.navArgument import com.google.ai.edge.gallery.data.Model +import com.google.ai.edge.gallery.data.TASK_LLM_ASK_AUDIO import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE import com.google.ai.edge.gallery.data.TASK_LLM_CHAT import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB @@ -55,6 +56,8 @@ import com.google.ai.edge.gallery.data.TaskType import com.google.ai.edge.gallery.data.getModelByName import com.google.ai.edge.gallery.ui.ViewModelProvider import com.google.ai.edge.gallery.ui.home.HomeScreen +import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioDestination +import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioScreen import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageDestination import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageScreen 
import com.google.ai.edge.gallery.ui.llmchat.LlmChatDestination @@ -209,7 +212,7 @@ fun GalleryNavHost( } } - // LLM image to text. + // Ask image. composable( route = "${LlmAskImageDestination.route}/{modelName}", arguments = listOf(navArgument("modelName") { type = NavType.StringType }), @@ -225,6 +228,23 @@ fun GalleryNavHost( ) } } + + // Ask audio. + composable( + route = "${LlmAskAudioDestination.route}/{modelName}", + arguments = listOf(navArgument("modelName") { type = NavType.StringType }), + enterTransition = { slideEnter() }, + exitTransition = { slideExit() }, + ) { + getModelFromNavigationParam(it, TASK_LLM_ASK_AUDIO)?.let { defaultModel -> + modelManagerViewModel.selectModel(defaultModel) + + LlmAskAudioScreen( + modelManagerViewModel = modelManagerViewModel, + navigateUp = { navController.navigateUp() }, + ) + } + } } // Handle incoming intents for deep links @@ -256,6 +276,7 @@ fun navigateToTaskScreen( when (taskType) { TaskType.LLM_CHAT -> navController.navigate("${LlmChatDestination.route}/${modelName}") TaskType.LLM_ASK_IMAGE -> navController.navigate("${LlmAskImageDestination.route}/${modelName}") + TaskType.LLM_ASK_AUDIO -> navController.navigate("${LlmAskAudioDestination.route}/${modelName}") TaskType.LLM_PROMPT_LAB -> navController.navigate("${LlmSingleTurnDestination.route}/${modelName}") TaskType.TEST_TASK_1 -> {} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/theme/Theme.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/theme/Theme.kt index 72ad1d5..a1a7ad1 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/theme/Theme.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/theme/Theme.kt @@ -120,6 +120,8 @@ data class CustomColors( val agentBubbleBgColor: Color = Color.Transparent, val linkColor: Color = Color.Transparent, val successColor: Color = Color.Transparent, + val recordButtonBgColor: Color = Color.Transparent, + val waveFormBgColor: Color = 
Color.Transparent, ) val LocalCustomColors = staticCompositionLocalOf { CustomColors() } @@ -145,6 +147,8 @@ val lightCustomColors = userBubbleBgColor = Color(0xFF32628D), linkColor = Color(0xFF32628D), successColor = Color(0xff3d860b), + recordButtonBgColor = Color(0xFFEE675C), + waveFormBgColor = Color(0xFFaaaaaa), ) val darkCustomColors = @@ -168,6 +172,8 @@ val darkCustomColors = userBubbleBgColor = Color(0xFF1f3760), linkColor = Color(0xFF9DCAFC), successColor = Color(0xFFA1CE83), + recordButtonBgColor = Color(0xFFEE675C), + waveFormBgColor = Color(0xFFaaaaaa), ) val MaterialTheme.customColors: CustomColors diff --git a/Android/src/app/src/main/proto/settings.proto b/Android/src/app/src/main/proto/settings.proto index 5540eaf..3b5f6f3 100644 --- a/Android/src/app/src/main/proto/settings.proto +++ b/Android/src/app/src/main/proto/settings.proto @@ -55,6 +55,7 @@ message LlmConfig { float default_topp = 4; float default_temperature = 5; bool support_image = 6; + bool support_audio = 7; } message Settings { From 2a95e5853b6cd67e952dd1f2ad1c104969810e14 Mon Sep 17 00:00:00 2001 From: Google AI Edge Gallery Date: Tue, 24 Jun 2025 09:23:21 -0700 Subject: [PATCH 5/9] Add a simple local test for allowlisted model. 
PiperOrigin-RevId: 775265777 --- .../ai/edge/gallery/data/ModelAllowlist.kt | 8 +- .../edge/gallery/data/ModelAllowlistTest.kt | 107 ++++++++++++++++++ 2 files changed, 111 insertions(+), 4 deletions(-) create mode 100644 Android/src/app/src/test/java/com/google/ai/edge/gallery/data/ModelAllowlistTest.kt diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt index 5cf11ca..03263ca 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt @@ -50,10 +50,10 @@ data class AllowedModel( taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_PROMPT_LAB.type.id) var configs: List = listOf() if (isLlmModel) { - var defaultTopK: Int = defaultConfig.topK ?: DEFAULT_TOPK - var defaultTopP: Float = defaultConfig.topP ?: DEFAULT_TOPP - var defaultTemperature: Float = defaultConfig.temperature ?: DEFAULT_TEMPERATURE - var defaultMaxToken = defaultConfig.maxTokens ?: 1024 + val defaultTopK: Int = defaultConfig.topK ?: DEFAULT_TOPK + val defaultTopP: Float = defaultConfig.topP ?: DEFAULT_TOPP + val defaultTemperature: Float = defaultConfig.temperature ?: DEFAULT_TEMPERATURE + val defaultMaxToken = defaultConfig.maxTokens ?: 1024 var accelerators: List = DEFAULT_ACCELERATORS if (defaultConfig.accelerators != null) { val items = defaultConfig.accelerators.split(",") diff --git a/Android/src/app/src/test/java/com/google/ai/edge/gallery/data/ModelAllowlistTest.kt b/Android/src/app/src/test/java/com/google/ai/edge/gallery/data/ModelAllowlistTest.kt new file mode 100644 index 0000000..bfc04bb --- /dev/null +++ b/Android/src/app/src/test/java/com/google/ai/edge/gallery/data/ModelAllowlistTest.kt @@ -0,0 +1,107 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not 
use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.edge.gallery.data + +import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Assert.assertTrue +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 + +@RunWith(JUnit4::class) +class ModelAllowlistTest { + @Test + fun toModel_success() { + val modelName = "test_model" + val modelId = "test_model_id" + val modelFile = "test_model_file" + val description = "test description" + val sizeInBytes = 100L + val version = "20250623" + val topK = 10 + val topP = 0.5f + val temperature = 0.1f + val maxTokens = 1000 + val accelerators = "gpu,cpu" + val taskTypes = listOf("llm_chat", "ask_image") + val estimatedPeakMemoryInBytes = 300L + + val allowedModel = + AllowedModel( + name = modelName, + modelId = modelId, + modelFile = modelFile, + description = description, + sizeInBytes = sizeInBytes, + version = version, + defaultConfig = + DefaultConfig( + topK = topK, + topP = topP, + temperature = temperature, + maxTokens = maxTokens, + accelerators = accelerators, + ), + taskTypes = taskTypes, + llmSupportImage = true, + llmSupportAudio = true, + estimatedPeakMemoryInBytes = estimatedPeakMemoryInBytes, + ) + val model = allowedModel.toModel() + + // Check that basic fields are set correctly. 
+ assertEquals(model.name, modelName) + assertEquals(model.version, version) + assertEquals(model.info, description) + assertEquals( + model.url, + "https://huggingface.co/test_model_id/resolve/main/test_model_file?download=true", + ) + assertEquals(model.sizeInBytes, sizeInBytes) + assertEquals(model.estimatedPeakMemoryInBytes, estimatedPeakMemoryInBytes) + assertEquals(model.downloadFileName, modelFile) + assertFalse(model.showBenchmarkButton) + assertFalse(model.showRunAgainButton) + assertTrue(model.llmSupportImage) + assertTrue(model.llmSupportAudio) + + // Check that configs are set correctly. + assertEquals(model.configs.size, 5) + + // A label for showing max tokens (non-changeable). + assertTrue(model.configs[0] is LabelConfig) + assertEquals((model.configs[0] as LabelConfig).defaultValue, "$maxTokens") + + // A slider for topK. + assertTrue(model.configs[1] is NumberSliderConfig) + assertEquals((model.configs[1] as NumberSliderConfig).defaultValue, topK.toFloat()) + + // A slider for topP. + assertTrue(model.configs[2] is NumberSliderConfig) + assertEquals((model.configs[2] as NumberSliderConfig).defaultValue, topP) + + // A slider for temperature. + assertTrue(model.configs[3] is NumberSliderConfig) + assertEquals((model.configs[3] as NumberSliderConfig).defaultValue, temperature) + + // A segmented button for accelerators. 
+ assertTrue(model.configs[4] is SegmentedButtonConfig) + assertEquals((model.configs[4] as SegmentedButtonConfig).defaultValue, "GPU") + assertEquals((model.configs[4] as SegmentedButtonConfig).options, listOf("GPU", "CPU")) + } +} From 3c5302b4fc30ded23939e4aca5d73f7c172d6c1a Mon Sep 17 00:00:00 2001 From: Jing Jin <8752427+jinjingforever@users.noreply.github.com> Date: Wed, 25 Jun 2025 00:10:25 -0700 Subject: [PATCH 6/9] Update build_android.yaml --- .github/workflows/build_android.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/build_android.yaml b/.github/workflows/build_android.yaml index 0ecbccb..1515255 100644 --- a/.github/workflows/build_android.yaml +++ b/.github/workflows/build_android.yaml @@ -21,5 +21,9 @@ jobs: steps: - name: Checkout the source code uses: actions/checkout@v3 + - uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '21' - name: Build run: ./gradlew assembleRelease From db0242fe88111add2e1eaf5d08cf65facd753998 Mon Sep 17 00:00:00 2001 From: Google AI Edge Gallery Date: Wed, 25 Jun 2025 13:54:34 -0700 Subject: [PATCH 7/9] [gallery] improve demo experience to keep the screen always on. 
PiperOrigin-RevId: 775825781 --- .../src/main/java/com/google/ai/edge/gallery/MainActivity.kt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/MainActivity.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/MainActivity.kt index dc172dc..10cbfee 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/MainActivity.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/MainActivity.kt @@ -18,6 +18,7 @@ package com.google.ai.edge.gallery import android.os.Build import android.os.Bundle +import android.view.WindowManager import androidx.activity.ComponentActivity import androidx.activity.compose.setContent import androidx.activity.enableEdgeToEdge @@ -39,5 +40,7 @@ class MainActivity : ComponentActivity() { window.isNavigationBarContrastEnforced = false } setContent { GalleryTheme { Surface(modifier = Modifier.fillMaxSize()) { GalleryApp() } } } + // Keep the screen on while the app is running for better demo experience. 
+ window.addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON) } } From 665c86a6408163623ce00c539a98fb0fce5264a9 Mon Sep 17 00:00:00 2001 From: Google AI Edge Gallery Date: Wed, 25 Jun 2025 16:21:43 -0700 Subject: [PATCH 8/9] Update the max number of audio clips to 1 PiperOrigin-RevId: 775879562 --- .../app/src/main/java/com/google/ai/edge/gallery/data/Consts.kt | 2 +- .../java/com/google/ai/edge/gallery/ui/common/chat/ChatPanel.kt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Consts.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Consts.kt index 85fd71e..2315099 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Consts.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Consts.kt @@ -46,7 +46,7 @@ val DEFAULT_ACCELERATORS = listOf(Accelerator.GPU) const val MAX_IMAGE_COUNT = 10 // Max number of audio clip in an "ask audio" session. -const val MAX_AUDIO_CLIP_COUNT = 10 +const val MAX_AUDIO_CLIP_COUNT = 1 // Max audio clip duration in seconds. const val MAX_AUDIO_CLIP_DURATION_SEC = 30 diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatPanel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatPanel.kt index 80f2cfa..06615e2 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatPanel.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatPanel.kt @@ -493,7 +493,7 @@ fun ChatPanel( MessageBodyInfo( ChatMessageInfo( content = - "To get started, tap the + icon to add your audio clips. You can add up to 10 clips, each up to 30 seconds long." + "To get started, tap the + icon to add your audio clip. Limited to 1 clip up to 30 seconds long." 
), smallFontSize = false, ) From 323124a628ef7ad05b910624dc1f89804bf67520 Mon Sep 17 00:00:00 2001 From: Google AI Edge Gallery Date: Thu, 26 Jun 2025 18:08:36 -0700 Subject: [PATCH 9/9] Refactor code to migrate manual dependency injection to using Hilt PiperOrigin-RevId: 776356661 --- Android/src/app/build.gradle.kts | 6 + .../ai/edge/gallery/GalleryApplication.kt | 39 ++----- .../google/ai/edge/gallery/MainActivity.kt | 3 + .../ai/edge/gallery/SettingsSerializer.kt | 38 +++++++ .../google/ai/edge/gallery/di/AppModule.kt | 86 +++++++++++++++ .../ai/edge/gallery/ui/ViewModelProvider.kt | 64 ----------- .../gallery/ui/common/chat/ChatViewModel.kt | 2 +- .../edge/gallery/ui/llmchat/LlmChatScreen.kt | 10 +- .../gallery/ui/llmchat/LlmChatViewModel.kt | 19 +++- .../ui/llmsingleturn/LlmSingleTurnScreen.kt | 7 +- .../llmsingleturn/LlmSingleTurnViewModel.kt | 7 +- .../ui/modelmanager/ModelManagerViewModel.kt | 10 +- .../gallery/ui/navigation/GalleryNavGraph.kt | 37 +++++-- .../gallery/ui/preview/PreviewChatModel.kt | 91 ---------------- .../ui/preview/PreviewDataStoreRepository.kt | 57 ---------- .../ui/preview/PreviewDownloadRepository.kt | 44 -------- .../preview/PreviewLlmSingleTurnViewModel.kt | 21 ---- .../preview/PreviewModelManagerViewModel.kt | 63 ----------- .../edge/gallery/ui/preview/PreviewTasks.kt | 103 ------------------ Android/src/build.gradle.kts | 1 + Android/src/gradle/libs.versions.toml | 7 ++ 21 files changed, 209 insertions(+), 506 deletions(-) create mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/SettingsSerializer.kt create mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/di/AppModule.kt delete mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt delete mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewChatModel.kt delete mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewDataStoreRepository.kt 
delete mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewDownloadRepository.kt delete mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewLlmSingleTurnViewModel.kt delete mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewModelManagerViewModel.kt delete mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewTasks.kt diff --git a/Android/src/app/build.gradle.kts b/Android/src/app/build.gradle.kts index c259a9a..b051876 100644 --- a/Android/src/app/build.gradle.kts +++ b/Android/src/app/build.gradle.kts @@ -20,6 +20,8 @@ plugins { alias(libs.plugins.kotlin.compose) alias(libs.plugins.kotlin.serialization) alias(libs.plugins.protobuf) + alias(libs.plugins.hilt.application) + kotlin("kapt") } android { @@ -91,11 +93,15 @@ dependencies { implementation(libs.openid.appauth) implementation(libs.androidx.splashscreen) implementation(libs.protobuf.javalite) + implementation(libs.hilt.android) + implementation(libs.hilt.navigation.compose) + kapt(libs.hilt.android.compiler) testImplementation(libs.junit) androidTestImplementation(libs.androidx.junit) androidTestImplementation(libs.androidx.espresso.core) androidTestImplementation(platform(libs.androidx.compose.bom)) androidTestImplementation(libs.androidx.ui.test.junit4) + androidTestImplementation(libs.hilt.android.testing) debugImplementation(libs.androidx.ui.tooling) debugImplementation(libs.androidx.ui.test.manifest) } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/GalleryApplication.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/GalleryApplication.kt index 4572311..1574f0b 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/GalleryApplication.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/GalleryApplication.kt @@ -17,48 +17,23 @@ package com.google.ai.edge.gallery import android.app.Application -import 
android.content.Context -import androidx.datastore.core.CorruptionException -import androidx.datastore.core.DataStore -import androidx.datastore.core.Serializer -import androidx.datastore.dataStore import com.google.ai.edge.gallery.common.writeLaunchInfo -import com.google.ai.edge.gallery.data.AppContainer -import com.google.ai.edge.gallery.data.DefaultAppContainer -import com.google.ai.edge.gallery.proto.Settings +import com.google.ai.edge.gallery.data.DataStoreRepository import com.google.ai.edge.gallery.ui.theme.ThemeSettings -import com.google.protobuf.InvalidProtocolBufferException -import java.io.InputStream -import java.io.OutputStream - -object SettingsSerializer : Serializer { - override val defaultValue: Settings = Settings.getDefaultInstance() - - override suspend fun readFrom(input: InputStream): Settings { - try { - return Settings.parseFrom(input) - } catch (exception: InvalidProtocolBufferException) { - throw CorruptionException("Cannot read proto.", exception) - } - } - - override suspend fun writeTo(t: Settings, output: OutputStream) = t.writeTo(output) -} - -private val Context.dataStore: DataStore by - dataStore(fileName = "settings.pb", serializer = SettingsSerializer) +import dagger.hilt.android.HiltAndroidApp +import javax.inject.Inject +@HiltAndroidApp class GalleryApplication : Application() { - /** AppContainer instance used by the rest of classes to obtain dependencies */ - lateinit var container: AppContainer + + @Inject lateinit var dataStoreRepository: DataStoreRepository override fun onCreate() { super.onCreate() writeLaunchInfo(context = this) - container = DefaultAppContainer(this, dataStore) // Load saved theme. 
- ThemeSettings.themeOverride.value = container.dataStoreRepository.readTheme() + ThemeSettings.themeOverride.value = dataStoreRepository.readTheme() } } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/MainActivity.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/MainActivity.kt index 10cbfee..cd3c9ea 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/MainActivity.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/MainActivity.kt @@ -27,8 +27,11 @@ import androidx.compose.material3.Surface import androidx.compose.ui.Modifier import androidx.core.splashscreen.SplashScreen.Companion.installSplashScreen import com.google.ai.edge.gallery.ui.theme.GalleryTheme +import dagger.hilt.android.AndroidEntryPoint +@AndroidEntryPoint class MainActivity : ComponentActivity() { + override fun onCreate(savedInstanceState: Bundle?) { installSplashScreen() diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/SettingsSerializer.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/SettingsSerializer.kt new file mode 100644 index 0000000..c7381f1 --- /dev/null +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/SettingsSerializer.kt @@ -0,0 +1,38 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.ai.edge.gallery + +import androidx.datastore.core.CorruptionException +import androidx.datastore.core.Serializer +import com.google.ai.edge.gallery.proto.Settings +import com.google.protobuf.InvalidProtocolBufferException +import java.io.InputStream +import java.io.OutputStream + +object SettingsSerializer : Serializer { + override val defaultValue: Settings = Settings.getDefaultInstance() + + override suspend fun readFrom(input: InputStream): Settings { + try { + return Settings.parseFrom(input) + } catch (exception: InvalidProtocolBufferException) { + throw CorruptionException("Cannot read proto.", exception) + } + } + + override suspend fun writeTo(t: Settings, output: OutputStream) = t.writeTo(output) +} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/di/AppModule.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/di/AppModule.kt new file mode 100644 index 0000000..a634822 --- /dev/null +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/di/AppModule.kt @@ -0,0 +1,86 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.ai.edge.gallery.di + +import android.content.Context +import androidx.datastore.core.DataStore +import androidx.datastore.core.DataStoreFactory +import androidx.datastore.core.Serializer +import androidx.datastore.dataStoreFile +import com.google.ai.edge.gallery.AppLifecycleProvider +import com.google.ai.edge.gallery.GalleryLifecycleProvider +import com.google.ai.edge.gallery.SettingsSerializer +import com.google.ai.edge.gallery.data.DataStoreRepository +import com.google.ai.edge.gallery.data.DefaultDataStoreRepository +import com.google.ai.edge.gallery.data.DefaultDownloadRepository +import com.google.ai.edge.gallery.data.DownloadRepository +import com.google.ai.edge.gallery.proto.Settings +import dagger.Module +import dagger.Provides +import dagger.hilt.InstallIn +import dagger.hilt.android.qualifiers.ApplicationContext +import dagger.hilt.components.SingletonComponent +import javax.inject.Singleton + +@Module +@InstallIn(SingletonComponent::class) +internal object AppModule { + + // Provides the SettingsSerializer + @Provides + @Singleton + fun provideSettingsSerializer(): Serializer { + return SettingsSerializer + } + + // Provides DataStore + @Provides + @Singleton + fun provideSettingsDataStore( + @ApplicationContext context: Context, + settingsSerializer: Serializer, + ): DataStore { + return DataStoreFactory.create( + serializer = settingsSerializer, + produceFile = { context.dataStoreFile("settings.pb") }, + ) + } + + // Provides AppLifecycleProvider + @Provides + @Singleton + fun provideAppLifecycleProvider(): AppLifecycleProvider { + return GalleryLifecycleProvider() + } + + // Provides DataStoreRepository + @Provides + @Singleton + fun provideDataStoreRepository(dataStore: DataStore): DataStoreRepository { + return DefaultDataStoreRepository(dataStore) + } + + // Provides DownloadRepository + @Provides + @Singleton + fun provideDownloadRepository( + @ApplicationContext context: Context, + lifecycleProvider: 
AppLifecycleProvider, + ): DownloadRepository { + return DefaultDownloadRepository(context, lifecycleProvider) + } +} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt deleted file mode 100644 index 6ceb148..0000000 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.ai.edge.gallery.ui - -import androidx.lifecycle.ViewModelProvider.AndroidViewModelFactory -import androidx.lifecycle.viewmodel.CreationExtras -import androidx.lifecycle.viewmodel.initializer -import androidx.lifecycle.viewmodel.viewModelFactory -import com.google.ai.edge.gallery.GalleryApplication -import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioViewModel -import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageViewModel -import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel -import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel -import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel - -object ViewModelProvider { - val Factory = viewModelFactory { - // Initializer for ModelManagerViewModel. 
- initializer { - val downloadRepository = galleryApplication().container.downloadRepository - val dataStoreRepository = galleryApplication().container.dataStoreRepository - val lifecycleProvider = galleryApplication().container.lifecycleProvider - ModelManagerViewModel( - downloadRepository = downloadRepository, - dataStoreRepository = dataStoreRepository, - lifecycleProvider = lifecycleProvider, - context = galleryApplication().container.context, - ) - } - - // Initializer for LlmChatViewModel. - initializer { LlmChatViewModel() } - - // Initializer for LlmSingleTurnViewModel.. - initializer { LlmSingleTurnViewModel() } - - // Initializer for LlmAskImageViewModel. - initializer { LlmAskImageViewModel() } - - // Initializer for LlmAskAudioViewModel. - initializer { LlmAskAudioViewModel() } - } -} - -/** - * Extension function to queries for [Application] object and returns an instance of - * [GalleryApplication]. - */ -fun CreationExtras.galleryApplication(): GalleryApplication = - (this[AndroidViewModelFactory.APPLICATION_KEY] as GalleryApplication) diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatViewModel.kt index 95785aa..d2f64e2 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatViewModel.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatViewModel.kt @@ -53,7 +53,7 @@ data class ChatUiState( ) /** ViewModel responsible for managing the chat UI state and handling chat-related operations. 
*/ -open class ChatViewModel(val task: Task) : ViewModel() { +abstract class ChatViewModel(val task: Task) : ViewModel() { private val _uiState = MutableStateFlow(createUiState(task = task)) val uiState = _uiState.asStateFlow() diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt index 23b5777..e61e2eb 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt @@ -20,8 +20,6 @@ import android.graphics.Bitmap import androidx.compose.runtime.Composable import androidx.compose.ui.Modifier import androidx.compose.ui.platform.LocalContext -import androidx.lifecycle.viewmodel.compose.viewModel -import com.google.ai.edge.gallery.ui.ViewModelProvider import com.google.ai.edge.gallery.ui.common.chat.ChatMessageAudioClip import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText @@ -46,7 +44,7 @@ fun LlmChatScreen( modelManagerViewModel: ModelManagerViewModel, navigateUp: () -> Unit, modifier: Modifier = Modifier, - viewModel: LlmChatViewModel = viewModel(factory = ViewModelProvider.Factory), + viewModel: LlmChatViewModel, ) { ChatViewWrapper( viewModel = viewModel, @@ -61,7 +59,7 @@ fun LlmAskImageScreen( modelManagerViewModel: ModelManagerViewModel, navigateUp: () -> Unit, modifier: Modifier = Modifier, - viewModel: LlmAskImageViewModel = viewModel(factory = ViewModelProvider.Factory), + viewModel: LlmAskImageViewModel, ) { ChatViewWrapper( viewModel = viewModel, @@ -76,7 +74,7 @@ fun LlmAskAudioScreen( modelManagerViewModel: ModelManagerViewModel, navigateUp: () -> Unit, modifier: Modifier = Modifier, - viewModel: LlmAskAudioViewModel = viewModel(factory = ViewModelProvider.Factory), + viewModel: LlmAskAudioViewModel, ) { ChatViewWrapper( 
viewModel = viewModel, @@ -88,7 +86,7 @@ fun LlmAskAudioScreen( @Composable fun ChatViewWrapper( - viewModel: LlmChatViewModel, + viewModel: LlmChatViewModelBase, modelManagerViewModel: ModelManagerViewModel, navigateUp: () -> Unit, modifier: Modifier = Modifier, diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt index d97a9df..205d4e0 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt @@ -36,6 +36,8 @@ import com.google.ai.edge.gallery.ui.common.chat.ChatSide import com.google.ai.edge.gallery.ui.common.chat.ChatViewModel import com.google.ai.edge.gallery.ui.common.chat.Stat import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel +import dagger.hilt.android.lifecycle.HiltViewModel +import javax.inject.Inject import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.delay import kotlinx.coroutines.launch @@ -49,7 +51,7 @@ private val STATS = Stat(id = "latency", label = "Latency", unit = "sec"), ) -open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task = curTask) { +open class LlmChatViewModelBase(val curTask: Task) : ChatViewModel(task = curTask) { fun generateResponse( model: Model, input: String, @@ -75,9 +77,9 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task val instance = model.instance as LlmModelInstance var prefillTokens = instance.session.sizeInTokens(input) prefillTokens += images.size * 257 - for (audioMessages in audioMessages) { + for (audioMessage in audioMessages) { // 150ms = 1 audio token - val duration = audioMessages.getDurationInSeconds() + val duration = audioMessage.getDurationInSeconds() prefillTokens += (duration * 1000f / 150f).toInt() } @@ -259,6 +261,13 @@ open class 
LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task } } -class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE) +@HiltViewModel +class LlmChatViewModel @Inject constructor() : LlmChatViewModelBase(curTask = TASK_LLM_CHAT) -class LlmAskAudioViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_AUDIO) +@HiltViewModel +class LlmAskImageViewModel @Inject constructor() : + LlmChatViewModelBase(curTask = TASK_LLM_ASK_IMAGE) + +@HiltViewModel +class LlmAskAudioViewModel @Inject constructor() : + LlmChatViewModelBase(curTask = TASK_LLM_ASK_AUDIO) diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt index 2759ba4..f62e2f8 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt @@ -43,9 +43,8 @@ import androidx.compose.ui.Modifier import androidx.compose.ui.draw.alpha import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalLayoutDirection -import androidx.lifecycle.viewmodel.compose.viewModel import com.google.ai.edge.gallery.data.ModelDownloadStatusType -import com.google.ai.edge.gallery.ui.ViewModelProvider +import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB import com.google.ai.edge.gallery.ui.common.ErrorDialog import com.google.ai.edge.gallery.ui.common.ModelPageAppBar import com.google.ai.edge.gallery.ui.common.chat.ModelDownloadStatusInfoPanel @@ -67,9 +66,9 @@ fun LlmSingleTurnScreen( modelManagerViewModel: ModelManagerViewModel, navigateUp: () -> Unit, modifier: Modifier = Modifier, - viewModel: LlmSingleTurnViewModel = viewModel(factory = ViewModelProvider.Factory), + viewModel: LlmSingleTurnViewModel, ) { - val task = viewModel.task + val task = TASK_LLM_PROMPT_LAB val 
modelManagerUiState by modelManagerViewModel.uiState.collectAsState() val uiState by viewModel.uiState.collectAsState() val selectedModel = modelManagerUiState.selectedModel diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt index 88414d6..861d5fd 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt @@ -27,6 +27,8 @@ import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult import com.google.ai.edge.gallery.ui.common.chat.Stat import com.google.ai.edge.gallery.ui.llmchat.LlmChatModelHelper import com.google.ai.edge.gallery.ui.llmchat.LlmModelInstance +import dagger.hilt.android.lifecycle.HiltViewModel +import javax.inject.Inject import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.delay import kotlinx.coroutines.flow.MutableStateFlow @@ -63,8 +65,9 @@ private val STATS = Stat(id = "latency", label = "Latency", unit = "sec"), ) -open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewModel() { - private val _uiState = MutableStateFlow(createUiState(task = task)) +@HiltViewModel +class LlmSingleTurnViewModel @Inject constructor() : ViewModel() { + private val _uiState = MutableStateFlow(createUiState(task = TASK_LLM_PROMPT_LAB)) val uiState = _uiState.asStateFlow() fun generateResponse(model: Model, input: String) { diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt index 6e2fb9e..b8b4eee 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt +++ 
b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt @@ -52,9 +52,12 @@ import com.google.ai.edge.gallery.ui.common.AuthConfig import com.google.ai.edge.gallery.ui.llmchat.LlmChatModelHelper import com.google.gson.Gson import com.google.gson.reflect.TypeToken +import dagger.hilt.android.lifecycle.HiltViewModel +import dagger.hilt.android.qualifiers.ApplicationContext import java.io.File import java.net.HttpURLConnection import java.net.URL +import javax.inject.Inject import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.delay import kotlinx.coroutines.flow.MutableStateFlow @@ -134,11 +137,14 @@ data class PagerScrollState(val page: Int = 0, val offset: Float = 0f) * cleaning up models. It also manages the UI state for model management, including the list of * tasks, models, download statuses, and initialization statuses. */ -open class ModelManagerViewModel( +@HiltViewModel +open class ModelManagerViewModel +@Inject +constructor( private val downloadRepository: DownloadRepository, private val dataStoreRepository: DataStoreRepository, private val lifecycleProvider: AppLifecycleProvider, - context: Context, + @ApplicationContext private val context: Context, ) : ViewModel() { private val externalFilesDir = context.getExternalFilesDir(null) private val inProgressWorkInfos: List = diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/navigation/GalleryNavGraph.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/navigation/GalleryNavGraph.kt index 138edc7..e4ebd48 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/navigation/GalleryNavGraph.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/navigation/GalleryNavGraph.kt @@ -36,10 +36,10 @@ import androidx.compose.runtime.setValue import androidx.compose.ui.Modifier import androidx.compose.ui.unit.IntOffset import androidx.compose.ui.zIndex +import androidx.hilt.navigation.compose.hiltViewModel 
import androidx.lifecycle.Lifecycle import androidx.lifecycle.LifecycleEventObserver import androidx.lifecycle.compose.LocalLifecycleOwner -import androidx.lifecycle.viewmodel.compose.viewModel import androidx.navigation.NavBackStackEntry import androidx.navigation.NavHostController import androidx.navigation.NavType @@ -54,16 +54,19 @@ import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB import com.google.ai.edge.gallery.data.Task import com.google.ai.edge.gallery.data.TaskType import com.google.ai.edge.gallery.data.getModelByName -import com.google.ai.edge.gallery.ui.ViewModelProvider import com.google.ai.edge.gallery.ui.home.HomeScreen import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioDestination import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioScreen +import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioViewModel import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageDestination import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageScreen +import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageViewModel import com.google.ai.edge.gallery.ui.llmchat.LlmChatDestination import com.google.ai.edge.gallery.ui.llmchat.LlmChatScreen +import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnDestination import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnScreen +import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelManager import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel @@ -107,7 +110,7 @@ private fun AnimatedContentTransitionScope<*>.slideExit(): ExitTransition { fun GalleryNavHost( navController: NavHostController, modifier: Modifier = Modifier, - modelManagerViewModel: ModelManagerViewModel = viewModel(factory = ViewModelProvider.Factory), + modelManagerViewModel: ModelManagerViewModel = hiltViewModel(), ) { val lifecycleOwner = LocalLifecycleOwner.current var 
showModelManager by remember { mutableStateOf(false) } @@ -184,11 +187,14 @@ fun GalleryNavHost( arguments = listOf(navArgument("modelName") { type = NavType.StringType }), enterTransition = { slideEnter() }, exitTransition = { slideExit() }, - ) { - getModelFromNavigationParam(it, TASK_LLM_CHAT)?.let { defaultModel -> + ) { backStackEntry -> + val viewModel: LlmChatViewModel = hiltViewModel(backStackEntry) + + getModelFromNavigationParam(backStackEntry, TASK_LLM_CHAT)?.let { defaultModel -> modelManagerViewModel.selectModel(defaultModel) LlmChatScreen( + viewModel = viewModel, modelManagerViewModel = modelManagerViewModel, navigateUp = { navController.navigateUp() }, ) @@ -201,11 +207,14 @@ fun GalleryNavHost( arguments = listOf(navArgument("modelName") { type = NavType.StringType }), enterTransition = { slideEnter() }, exitTransition = { slideExit() }, - ) { - getModelFromNavigationParam(it, TASK_LLM_PROMPT_LAB)?.let { defaultModel -> + ) { backStackEntry -> + val viewModel: LlmSingleTurnViewModel = hiltViewModel(backStackEntry) + + getModelFromNavigationParam(backStackEntry, TASK_LLM_PROMPT_LAB)?.let { defaultModel -> modelManagerViewModel.selectModel(defaultModel) LlmSingleTurnScreen( + viewModel = viewModel, modelManagerViewModel = modelManagerViewModel, navigateUp = { navController.navigateUp() }, ) @@ -218,11 +227,14 @@ fun GalleryNavHost( arguments = listOf(navArgument("modelName") { type = NavType.StringType }), enterTransition = { slideEnter() }, exitTransition = { slideExit() }, - ) { - getModelFromNavigationParam(it, TASK_LLM_ASK_IMAGE)?.let { defaultModel -> + ) { backStackEntry -> + val viewModel: LlmAskImageViewModel = hiltViewModel() + + getModelFromNavigationParam(backStackEntry, TASK_LLM_ASK_IMAGE)?.let { defaultModel -> modelManagerViewModel.selectModel(defaultModel) LlmAskImageScreen( + viewModel = viewModel, modelManagerViewModel = modelManagerViewModel, navigateUp = { navController.navigateUp() }, ) @@ -235,11 +247,14 @@ fun GalleryNavHost( 
arguments = listOf(navArgument("modelName") { type = NavType.StringType }), enterTransition = { slideEnter() }, exitTransition = { slideExit() }, - ) { - getModelFromNavigationParam(it, TASK_LLM_ASK_AUDIO)?.let { defaultModel -> + ) { backStackEntry -> + val viewModel: LlmAskAudioViewModel = hiltViewModel() + + getModelFromNavigationParam(backStackEntry, TASK_LLM_ASK_AUDIO)?.let { defaultModel -> modelManagerViewModel.selectModel(defaultModel) LlmAskAudioScreen( + viewModel = viewModel, modelManagerViewModel = modelManagerViewModel, navigateUp = { navController.navigateUp() }, ) diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewChatModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewChatModel.kt deleted file mode 100644 index bab07cb..0000000 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewChatModel.kt +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.google.ai.edge.gallery.ui.preview - -import android.content.Context -import android.graphics.Bitmap -import android.graphics.Canvas -import android.graphics.drawable.Drawable -import androidx.compose.ui.graphics.Color -import androidx.compose.ui.graphics.asImageBitmap -import androidx.core.content.ContextCompat -import androidx.core.graphics.createBitmap -import com.google.ai.edge.gallery.R -import com.google.ai.edge.gallery.common.Classification -import com.google.ai.edge.gallery.ui.common.chat.ChatMessageClassification -import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage -import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText -import com.google.ai.edge.gallery.ui.common.chat.ChatSide -import com.google.ai.edge.gallery.ui.common.chat.ChatViewModel - -class PreviewChatModel(context: Context) : ChatViewModel(task = TASK_TEST1) { - init { - val model = task.models[1] - addMessage( - model = model, - message = - ChatMessageText( - content = - "Thanks everyone for your enthusiasm on the team lunch, but people who can sign on the cheque is OOO next week \uD83D\uDE02,", - side = ChatSide.USER, - ), - ) - addMessage( - model = model, - message = - ChatMessageText(content = "Today is Wednesday!", side = ChatSide.AGENT, latencyMs = 1232f), - ) - addMessage( - model = model, - message = - ChatMessageClassification( - classifications = - listOf( - Classification(label = "label1", score = 0.3f, color = Color.Red), - Classification(label = "label2", score = 0.7f, color = Color.Blue), - ), - latencyMs = 12345f, - ), - ) - val bitmap = - getBitmapFromVectorDrawable( - context = context, - drawableId = R.drawable.ic_launcher_background, - )!! - addMessage( - model = model, - message = - ChatMessageImage( - bitmap = bitmap, - imageBitMap = bitmap.asImageBitmap(), - side = ChatSide.USER, - ), - ) - } - - private fun getBitmapFromVectorDrawable(context: Context, drawableId: Int): Bitmap? 
{ - val drawable: Drawable = - ContextCompat.getDrawable(context, drawableId) ?: return null // Drawable not found - - val bitmap = createBitmap(drawable.intrinsicWidth, drawable.intrinsicHeight) - val canvas = Canvas(bitmap) - drawable.setBounds(0, 0, canvas.width, canvas.height) - drawable.draw(canvas) - - return bitmap - } -} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewDataStoreRepository.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewDataStoreRepository.kt deleted file mode 100644 index b5e46c9..0000000 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewDataStoreRepository.kt +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.google.ai.edge.gallery.ui.preview - -// TODO(migration) -// -// import com.google.ai.edge.gallery.data.AccessTokenData -// import com.google.ai.edge.gallery.data.DataStoreRepository -// import com.google.ai.edge.gallery.data.ImportedModelInfo - -// class PreviewDataStoreRepository : DataStoreRepository -class PreviewDataStoreRepository { - // override fun saveTextInputHistory(history: List) { - // } - - // override fun readTextInputHistory(): List { - // return listOf() - // } - - // override fun saveThemeOverride(theme: String) { - // } - - // override fun readThemeOverride(): String { - // return "" - // } - - // override fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long) { - // } - - // override fun readAccessTokenData(): AccessTokenData? { - // return null - // } - - // override fun clearAccessTokenData() { - // } - - // override fun saveImportedModels(importedModels: List) { - // } - - // override fun readImportedModels(): List { - // return listOf() - // } -} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewDownloadRepository.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewDownloadRepository.kt deleted file mode 100644 index 88e7dfa..0000000 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewDownloadRepository.kt +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.ai.edge.gallery.ui.preview - -import com.google.ai.edge.gallery.data.AGWorkInfo -import com.google.ai.edge.gallery.data.DownloadRepository -import com.google.ai.edge.gallery.data.Model -import com.google.ai.edge.gallery.data.ModelDownloadStatus -import java.util.UUID - -class PreviewDownloadRepository : DownloadRepository { - override fun downloadModel( - model: Model, - onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit, - ) {} - - override fun cancelDownloadModel(model: Model) {} - - override fun cancelAll(models: List, onComplete: () -> Unit) {} - - override fun observerWorkerProgress( - workerId: UUID, - model: Model, - onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit, - ) {} - - override fun getEnqueuedOrRunningWorkInfos(): List { - return listOf() - } -} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewLlmSingleTurnViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewLlmSingleTurnViewModel.kt deleted file mode 100644 index 0c8d9dc..0000000 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewLlmSingleTurnViewModel.kt +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.google.ai.edge.gallery.ui.preview - -import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel - -class PreviewLlmSingleTurnViewModel : LlmSingleTurnViewModel(task = TASK_TEST1) diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewModelManagerViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewModelManagerViewModel.kt deleted file mode 100644 index 612cbe6..0000000 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewModelManagerViewModel.kt +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.google.ai.edge.gallery.ui.preview - -class PreviewModelManagerViewModel {} - -// class PreviewModelManagerViewModel(context: Context) : -// ModelManagerViewModel( -// downloadRepository = PreviewDownloadRepository(), -// // dataStoreRepository = PreviewDataStoreRepository(), -// context = context, -// ) { - -// init { -// for ((index, task) in ALL_PREVIEW_TASKS.withIndex()) { -// task.index = index -// for (model in task.models) { -// model.preProcess() -// } -// } - -// val modelDownloadStatus = -// mapOf( -// MODEL_TEST1.name to -// ModelDownloadStatus( -// status = ModelDownloadStatusType.IN_PROGRESS, -// receivedBytes = 1234, -// totalBytes = 3456, -// bytesPerSecond = 2333, -// remainingMs = 324, -// ), -// MODEL_TEST2.name to ModelDownloadStatus(status = ModelDownloadStatusType.SUCCEEDED), -// MODEL_TEST3.name to -// ModelDownloadStatus( -// status = ModelDownloadStatusType.FAILED, -// errorMessage = "Http code 404", -// ), -// MODEL_TEST4.name to ModelDownloadStatus(status = ModelDownloadStatusType.NOT_DOWNLOADED), -// ) -// val newUiState = -// ModelManagerUiState( -// tasks = ALL_PREVIEW_TASKS, -// modelDownloadStatus = modelDownloadStatus, -// modelInitializationStatus = mapOf(), -// selectedModel = MODEL_TEST2, -// ) -// _uiState.update { newUiState } -// } -// } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewTasks.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewTasks.kt deleted file mode 100644 index 44e2e3d..0000000 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/preview/PreviewTasks.kt +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.ai.edge.gallery.ui.preview - -import androidx.compose.material.icons.Icons -import androidx.compose.material.icons.rounded.AccountBox -import androidx.compose.material.icons.rounded.AutoAwesome -import com.google.ai.edge.gallery.data.BooleanSwitchConfig -import com.google.ai.edge.gallery.data.Config -import com.google.ai.edge.gallery.data.ConfigKey -import com.google.ai.edge.gallery.data.LabelConfig -import com.google.ai.edge.gallery.data.Model -import com.google.ai.edge.gallery.data.NumberSliderConfig -import com.google.ai.edge.gallery.data.SegmentedButtonConfig -import com.google.ai.edge.gallery.data.Task -import com.google.ai.edge.gallery.data.TaskType -import com.google.ai.edge.gallery.data.ValueType - -val TEST_CONFIGS1: List = - listOf( - LabelConfig(key = ConfigKey.NAME, defaultValue = "Test name"), - NumberSliderConfig( - key = ConfigKey.MAX_RESULT_COUNT, - sliderMin = 1f, - sliderMax = 5f, - defaultValue = 3f, - valueType = ValueType.INT, - ), - BooleanSwitchConfig(key = ConfigKey.USE_GPU, defaultValue = false), - SegmentedButtonConfig( - key = ConfigKey.THEME, - defaultValue = "Auto", - options = listOf("Auto", "Light", "Dark"), - ), - ) - -val MODEL_TEST1: Model = - Model( - name = "deterministic3", - downloadFileName = "deterministric3.json", - url = "https://storage.googleapis.com/tfweb/model-graph-vis-v2-test-models/deterministic3.json", - sizeInBytes = 40146048L, - configs = TEST_CONFIGS1, - ) - -val MODEL_TEST2: Model = - Model( - name = "isnet", - downloadFileName = "isnet.tflite", - url = - 
"https://storage.googleapis.com/tfweb/model-graph-vis-v2-test-models/isnet-general-use-int8.tflite", - sizeInBytes = 44366296L, - configs = TEST_CONFIGS1, - ) - -val MODEL_TEST3: Model = - Model( - name = "yolo", - downloadFileName = "yolo.json", - url = "https://storage.googleapis.com/tfweb/model-graph-vis-v2-test-models/yolo.json", - sizeInBytes = 40641364L, - ) - -val MODEL_TEST4: Model = - Model( - name = "mobilenet v3", - downloadFileName = "mobilenet_v3_large.pt2", - url = - "https://storage.googleapis.com/tfweb/model-graph-vis-v2-test-models/mobilenet_v3_large.pt2", - sizeInBytes = 277135998L, - ) - -val TASK_TEST1 = - Task( - type = TaskType.TEST_TASK_1, - icon = Icons.Rounded.AutoAwesome, - models = mutableListOf(MODEL_TEST1, MODEL_TEST2), - description = "This is a test task (1)", - ) - -val TASK_TEST2 = - Task( - type = TaskType.TEST_TASK_2, - icon = Icons.Rounded.AccountBox, - models = mutableListOf(MODEL_TEST3, MODEL_TEST4), - description = "This is a test task (2)", - ) - -val ALL_PREVIEW_TASKS: List = listOf(TASK_TEST1, TASK_TEST2) diff --git a/Android/src/build.gradle.kts b/Android/src/build.gradle.kts index da93723..fe99361 100644 --- a/Android/src/build.gradle.kts +++ b/Android/src/build.gradle.kts @@ -19,4 +19,5 @@ plugins { alias(libs.plugins.android.application) apply false alias(libs.plugins.kotlin.android) apply false alias(libs.plugins.kotlin.compose) apply false + alias(libs.plugins.hilt.application) apply false } diff --git a/Android/src/gradle/libs.versions.toml b/Android/src/gradle/libs.versions.toml index 704cefe..57b3aa1 100644 --- a/Android/src/gradle/libs.versions.toml +++ b/Android/src/gradle/libs.versions.toml @@ -29,6 +29,8 @@ playServicesTfliteGpu= "16.4.0" cameraX = "1.4.2" netOpenidAppauth = "0.11.1" splashscreen = "1.2.0-beta01" +hilt = "2.56.2" +hiltNavigation = "1.2.0" [libraries] androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" } @@ -67,6 +69,10 @@ camerax-view = { group = 
"androidx.camera", name = "camera-view", version.ref = openid-appauth = { group = "net.openid", name = "appauth", version.ref = "netOpenidAppauth" } androidx-splashscreen = { group = "androidx.core", name = "core-splashscreen", version.ref = "splashscreen" } protobuf-javalite = { group = "com.google.protobuf", name = "protobuf-javalite", version.ref = "protobufJavaLite" } +hilt-android = { module = "com.google.dagger:hilt-android", version.ref = "hilt" } +hilt-navigation-compose = { module = "androidx.hilt:hilt-navigation-compose", version.ref = "hiltNavigation" } +hilt-android-testing = { module = "com.google.dagger:hilt-android-testing", version.ref = "hilt" } +hilt-android-compiler = { module = "com.google.dagger:hilt-android-compiler", version.ref = "hilt" } [plugins] android-application = { id = "com.android.application", version.ref = "agp" } @@ -74,3 +80,4 @@ kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" } kotlin-compose = { id = "org.jetbrains.kotlin.plugin.compose", version.ref = "kotlin" } kotlin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "serializationPlugin" } protobuf = {id = "com.google.protobuf", version.ref = "protobuf"} +hilt-application = { id = "com.google.dagger.hilt.android", version.ref = "hilt" }