Refactor code to migrate manual dependency injection to using Hilt

PiperOrigin-RevId: 776356661
This commit is contained in:
Google AI Edge Gallery 2025-06-26 18:08:36 -07:00 committed by Copybara-Service
parent 665c86a640
commit 323124a628
21 changed files with 209 additions and 506 deletions

View file

@ -20,6 +20,8 @@ plugins {
alias(libs.plugins.kotlin.compose) alias(libs.plugins.kotlin.compose)
alias(libs.plugins.kotlin.serialization) alias(libs.plugins.kotlin.serialization)
alias(libs.plugins.protobuf) alias(libs.plugins.protobuf)
alias(libs.plugins.hilt.application)
kotlin("kapt")
} }
android { android {
@ -91,11 +93,15 @@ dependencies {
implementation(libs.openid.appauth) implementation(libs.openid.appauth)
implementation(libs.androidx.splashscreen) implementation(libs.androidx.splashscreen)
implementation(libs.protobuf.javalite) implementation(libs.protobuf.javalite)
implementation(libs.hilt.android)
implementation(libs.hilt.navigation.compose)
kapt(libs.hilt.android.compiler)
testImplementation(libs.junit) testImplementation(libs.junit)
androidTestImplementation(libs.androidx.junit) androidTestImplementation(libs.androidx.junit)
androidTestImplementation(libs.androidx.espresso.core) androidTestImplementation(libs.androidx.espresso.core)
androidTestImplementation(platform(libs.androidx.compose.bom)) androidTestImplementation(platform(libs.androidx.compose.bom))
androidTestImplementation(libs.androidx.ui.test.junit4) androidTestImplementation(libs.androidx.ui.test.junit4)
androidTestImplementation(libs.hilt.android.testing)
debugImplementation(libs.androidx.ui.tooling) debugImplementation(libs.androidx.ui.tooling)
debugImplementation(libs.androidx.ui.test.manifest) debugImplementation(libs.androidx.ui.test.manifest)
} }

View file

@ -17,48 +17,23 @@
package com.google.ai.edge.gallery package com.google.ai.edge.gallery
import android.app.Application import android.app.Application
import android.content.Context
import androidx.datastore.core.CorruptionException
import androidx.datastore.core.DataStore
import androidx.datastore.core.Serializer
import androidx.datastore.dataStore
import com.google.ai.edge.gallery.common.writeLaunchInfo import com.google.ai.edge.gallery.common.writeLaunchInfo
import com.google.ai.edge.gallery.data.AppContainer import com.google.ai.edge.gallery.data.DataStoreRepository
import com.google.ai.edge.gallery.data.DefaultAppContainer
import com.google.ai.edge.gallery.proto.Settings
import com.google.ai.edge.gallery.ui.theme.ThemeSettings import com.google.ai.edge.gallery.ui.theme.ThemeSettings
import com.google.protobuf.InvalidProtocolBufferException import dagger.hilt.android.HiltAndroidApp
import java.io.InputStream import javax.inject.Inject
import java.io.OutputStream
object SettingsSerializer : Serializer<Settings> {
override val defaultValue: Settings = Settings.getDefaultInstance()
override suspend fun readFrom(input: InputStream): Settings {
try {
return Settings.parseFrom(input)
} catch (exception: InvalidProtocolBufferException) {
throw CorruptionException("Cannot read proto.", exception)
}
}
override suspend fun writeTo(t: Settings, output: OutputStream) = t.writeTo(output)
}
private val Context.dataStore: DataStore<Settings> by
dataStore(fileName = "settings.pb", serializer = SettingsSerializer)
@HiltAndroidApp
class GalleryApplication : Application() { class GalleryApplication : Application() {
/** AppContainer instance used by the rest of classes to obtain dependencies */
lateinit var container: AppContainer @Inject lateinit var dataStoreRepository: DataStoreRepository
override fun onCreate() { override fun onCreate() {
super.onCreate() super.onCreate()
writeLaunchInfo(context = this) writeLaunchInfo(context = this)
container = DefaultAppContainer(this, dataStore)
// Load saved theme. // Load saved theme.
ThemeSettings.themeOverride.value = container.dataStoreRepository.readTheme() ThemeSettings.themeOverride.value = dataStoreRepository.readTheme()
} }
} }

View file

@ -27,8 +27,11 @@ import androidx.compose.material3.Surface
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.core.splashscreen.SplashScreen.Companion.installSplashScreen import androidx.core.splashscreen.SplashScreen.Companion.installSplashScreen
import com.google.ai.edge.gallery.ui.theme.GalleryTheme import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import dagger.hilt.android.AndroidEntryPoint
@AndroidEntryPoint
class MainActivity : ComponentActivity() { class MainActivity : ComponentActivity() {
override fun onCreate(savedInstanceState: Bundle?) { override fun onCreate(savedInstanceState: Bundle?) {
installSplashScreen() installSplashScreen()

View file

@ -0,0 +1,38 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery
import androidx.datastore.core.CorruptionException
import androidx.datastore.core.Serializer
import com.google.ai.edge.gallery.proto.Settings
import com.google.protobuf.InvalidProtocolBufferException
import java.io.InputStream
import java.io.OutputStream
object SettingsSerializer : Serializer<Settings> {
override val defaultValue: Settings = Settings.getDefaultInstance()
override suspend fun readFrom(input: InputStream): Settings {
try {
return Settings.parseFrom(input)
} catch (exception: InvalidProtocolBufferException) {
throw CorruptionException("Cannot read proto.", exception)
}
}
override suspend fun writeTo(t: Settings, output: OutputStream) = t.writeTo(output)
}

View file

@ -0,0 +1,86 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.di
import android.content.Context
import androidx.datastore.core.DataStore
import androidx.datastore.core.DataStoreFactory
import androidx.datastore.core.Serializer
import androidx.datastore.dataStoreFile
import com.google.ai.edge.gallery.AppLifecycleProvider
import com.google.ai.edge.gallery.GalleryLifecycleProvider
import com.google.ai.edge.gallery.SettingsSerializer
import com.google.ai.edge.gallery.data.DataStoreRepository
import com.google.ai.edge.gallery.data.DefaultDataStoreRepository
import com.google.ai.edge.gallery.data.DefaultDownloadRepository
import com.google.ai.edge.gallery.data.DownloadRepository
import com.google.ai.edge.gallery.proto.Settings
import dagger.Module
import dagger.Provides
import dagger.hilt.InstallIn
import dagger.hilt.android.qualifiers.ApplicationContext
import dagger.hilt.components.SingletonComponent
import javax.inject.Singleton
@Module
@InstallIn(SingletonComponent::class)
internal object AppModule {
// Provides the SettingsSerializer
@Provides
@Singleton
fun provideSettingsSerializer(): Serializer<Settings> {
return SettingsSerializer
}
// Provides DataStore<Settings>
@Provides
@Singleton
fun provideSettingsDataStore(
@ApplicationContext context: Context,
settingsSerializer: Serializer<Settings>,
): DataStore<Settings> {
return DataStoreFactory.create(
serializer = settingsSerializer,
produceFile = { context.dataStoreFile("settings.pb") },
)
}
// Provides AppLifecycleProvider
@Provides
@Singleton
fun provideAppLifecycleProvider(): AppLifecycleProvider {
return GalleryLifecycleProvider()
}
// Provides DataStoreRepository
@Provides
@Singleton
fun provideDataStoreRepository(dataStore: DataStore<Settings>): DataStoreRepository {
return DefaultDataStoreRepository(dataStore)
}
// Provides DownloadRepository
@Provides
@Singleton
fun provideDownloadRepository(
@ApplicationContext context: Context,
lifecycleProvider: AppLifecycleProvider,
): DownloadRepository {
return DefaultDownloadRepository(context, lifecycleProvider)
}
}

View file

@ -1,64 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui
import androidx.lifecycle.ViewModelProvider.AndroidViewModelFactory
import androidx.lifecycle.viewmodel.CreationExtras
import androidx.lifecycle.viewmodel.initializer
import androidx.lifecycle.viewmodel.viewModelFactory
import com.google.ai.edge.gallery.GalleryApplication
import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
object ViewModelProvider {
val Factory = viewModelFactory {
// Initializer for ModelManagerViewModel.
initializer {
val downloadRepository = galleryApplication().container.downloadRepository
val dataStoreRepository = galleryApplication().container.dataStoreRepository
val lifecycleProvider = galleryApplication().container.lifecycleProvider
ModelManagerViewModel(
downloadRepository = downloadRepository,
dataStoreRepository = dataStoreRepository,
lifecycleProvider = lifecycleProvider,
context = galleryApplication().container.context,
)
}
// Initializer for LlmChatViewModel.
initializer { LlmChatViewModel() }
// Initializer for LlmSingleTurnViewModel..
initializer { LlmSingleTurnViewModel() }
// Initializer for LlmAskImageViewModel.
initializer { LlmAskImageViewModel() }
// Initializer for LlmAskAudioViewModel.
initializer { LlmAskAudioViewModel() }
}
}
/**
* Extension function to queries for [Application] object and returns an instance of
* [GalleryApplication].
*/
fun CreationExtras.galleryApplication(): GalleryApplication =
(this[AndroidViewModelFactory.APPLICATION_KEY] as GalleryApplication)

View file

@ -53,7 +53,7 @@ data class ChatUiState(
) )
/** ViewModel responsible for managing the chat UI state and handling chat-related operations. */ /** ViewModel responsible for managing the chat UI state and handling chat-related operations. */
open class ChatViewModel(val task: Task) : ViewModel() { abstract class ChatViewModel(val task: Task) : ViewModel() {
private val _uiState = MutableStateFlow(createUiState(task = task)) private val _uiState = MutableStateFlow(createUiState(task = task))
val uiState = _uiState.asStateFlow() val uiState = _uiState.asStateFlow()

View file

@ -20,8 +20,6 @@ import android.graphics.Bitmap
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.ai.edge.gallery.ui.ViewModelProvider
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageAudioClip import com.google.ai.edge.gallery.ui.common.chat.ChatMessageAudioClip
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
@ -46,7 +44,7 @@ fun LlmChatScreen(
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit, navigateUp: () -> Unit,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
viewModel: LlmChatViewModel = viewModel(factory = ViewModelProvider.Factory), viewModel: LlmChatViewModel,
) { ) {
ChatViewWrapper( ChatViewWrapper(
viewModel = viewModel, viewModel = viewModel,
@ -61,7 +59,7 @@ fun LlmAskImageScreen(
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit, navigateUp: () -> Unit,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
viewModel: LlmAskImageViewModel = viewModel(factory = ViewModelProvider.Factory), viewModel: LlmAskImageViewModel,
) { ) {
ChatViewWrapper( ChatViewWrapper(
viewModel = viewModel, viewModel = viewModel,
@ -76,7 +74,7 @@ fun LlmAskAudioScreen(
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit, navigateUp: () -> Unit,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
viewModel: LlmAskAudioViewModel = viewModel(factory = ViewModelProvider.Factory), viewModel: LlmAskAudioViewModel,
) { ) {
ChatViewWrapper( ChatViewWrapper(
viewModel = viewModel, viewModel = viewModel,
@ -88,7 +86,7 @@ fun LlmAskAudioScreen(
@Composable @Composable
fun ChatViewWrapper( fun ChatViewWrapper(
viewModel: LlmChatViewModel, viewModel: LlmChatViewModelBase,
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit, navigateUp: () -> Unit,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,

View file

@ -36,6 +36,8 @@ import com.google.ai.edge.gallery.ui.common.chat.ChatSide
import com.google.ai.edge.gallery.ui.common.chat.ChatViewModel import com.google.ai.edge.gallery.ui.common.chat.ChatViewModel
import com.google.ai.edge.gallery.ui.common.chat.Stat import com.google.ai.edge.gallery.ui.common.chat.Stat
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import dagger.hilt.android.lifecycle.HiltViewModel
import javax.inject.Inject
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
@ -49,7 +51,7 @@ private val STATS =
Stat(id = "latency", label = "Latency", unit = "sec"), Stat(id = "latency", label = "Latency", unit = "sec"),
) )
open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task = curTask) { open class LlmChatViewModelBase(val curTask: Task) : ChatViewModel(task = curTask) {
fun generateResponse( fun generateResponse(
model: Model, model: Model,
input: String, input: String,
@ -75,9 +77,9 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
val instance = model.instance as LlmModelInstance val instance = model.instance as LlmModelInstance
var prefillTokens = instance.session.sizeInTokens(input) var prefillTokens = instance.session.sizeInTokens(input)
prefillTokens += images.size * 257 prefillTokens += images.size * 257
for (audioMessages in audioMessages) { for (audioMessage in audioMessages) {
// 150ms = 1 audio token // 150ms = 1 audio token
val duration = audioMessages.getDurationInSeconds() val duration = audioMessage.getDurationInSeconds()
prefillTokens += (duration * 1000f / 150f).toInt() prefillTokens += (duration * 1000f / 150f).toInt()
} }
@ -259,6 +261,13 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
} }
} }
class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE) @HiltViewModel
class LlmChatViewModel @Inject constructor() : LlmChatViewModelBase(curTask = TASK_LLM_CHAT)
class LlmAskAudioViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_AUDIO) @HiltViewModel
class LlmAskImageViewModel @Inject constructor() :
LlmChatViewModelBase(curTask = TASK_LLM_ASK_IMAGE)
@HiltViewModel
class LlmAskAudioViewModel @Inject constructor() :
LlmChatViewModelBase(curTask = TASK_LLM_ASK_AUDIO)

View file

@ -43,9 +43,8 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha import androidx.compose.ui.draw.alpha
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalLayoutDirection import androidx.compose.ui.platform.LocalLayoutDirection
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.ai.edge.gallery.data.ModelDownloadStatusType import com.google.ai.edge.gallery.data.ModelDownloadStatusType
import com.google.ai.edge.gallery.ui.ViewModelProvider import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB
import com.google.ai.edge.gallery.ui.common.ErrorDialog import com.google.ai.edge.gallery.ui.common.ErrorDialog
import com.google.ai.edge.gallery.ui.common.ModelPageAppBar import com.google.ai.edge.gallery.ui.common.ModelPageAppBar
import com.google.ai.edge.gallery.ui.common.chat.ModelDownloadStatusInfoPanel import com.google.ai.edge.gallery.ui.common.chat.ModelDownloadStatusInfoPanel
@ -67,9 +66,9 @@ fun LlmSingleTurnScreen(
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit, navigateUp: () -> Unit,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
viewModel: LlmSingleTurnViewModel = viewModel(factory = ViewModelProvider.Factory), viewModel: LlmSingleTurnViewModel,
) { ) {
val task = viewModel.task val task = TASK_LLM_PROMPT_LAB
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val uiState by viewModel.uiState.collectAsState() val uiState by viewModel.uiState.collectAsState()
val selectedModel = modelManagerUiState.selectedModel val selectedModel = modelManagerUiState.selectedModel

View file

@ -27,6 +27,8 @@ import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.ai.edge.gallery.ui.common.chat.Stat import com.google.ai.edge.gallery.ui.common.chat.Stat
import com.google.ai.edge.gallery.ui.llmchat.LlmChatModelHelper import com.google.ai.edge.gallery.ui.llmchat.LlmChatModelHelper
import com.google.ai.edge.gallery.ui.llmchat.LlmModelInstance import com.google.ai.edge.gallery.ui.llmchat.LlmModelInstance
import dagger.hilt.android.lifecycle.HiltViewModel
import javax.inject.Inject
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
@ -63,8 +65,9 @@ private val STATS =
Stat(id = "latency", label = "Latency", unit = "sec"), Stat(id = "latency", label = "Latency", unit = "sec"),
) )
open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewModel() { @HiltViewModel
private val _uiState = MutableStateFlow(createUiState(task = task)) class LlmSingleTurnViewModel @Inject constructor() : ViewModel() {
private val _uiState = MutableStateFlow(createUiState(task = TASK_LLM_PROMPT_LAB))
val uiState = _uiState.asStateFlow() val uiState = _uiState.asStateFlow()
fun generateResponse(model: Model, input: String) { fun generateResponse(model: Model, input: String) {

View file

@ -52,9 +52,12 @@ import com.google.ai.edge.gallery.ui.common.AuthConfig
import com.google.ai.edge.gallery.ui.llmchat.LlmChatModelHelper import com.google.ai.edge.gallery.ui.llmchat.LlmChatModelHelper
import com.google.gson.Gson import com.google.gson.Gson
import com.google.gson.reflect.TypeToken import com.google.gson.reflect.TypeToken
import dagger.hilt.android.lifecycle.HiltViewModel
import dagger.hilt.android.qualifiers.ApplicationContext
import java.io.File import java.io.File
import java.net.HttpURLConnection import java.net.HttpURLConnection
import java.net.URL import java.net.URL
import javax.inject.Inject
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
@ -134,11 +137,14 @@ data class PagerScrollState(val page: Int = 0, val offset: Float = 0f)
* cleaning up models. It also manages the UI state for model management, including the list of * cleaning up models. It also manages the UI state for model management, including the list of
* tasks, models, download statuses, and initialization statuses. * tasks, models, download statuses, and initialization statuses.
*/ */
open class ModelManagerViewModel( @HiltViewModel
open class ModelManagerViewModel
@Inject
constructor(
private val downloadRepository: DownloadRepository, private val downloadRepository: DownloadRepository,
private val dataStoreRepository: DataStoreRepository, private val dataStoreRepository: DataStoreRepository,
private val lifecycleProvider: AppLifecycleProvider, private val lifecycleProvider: AppLifecycleProvider,
context: Context, @ApplicationContext private val context: Context,
) : ViewModel() { ) : ViewModel() {
private val externalFilesDir = context.getExternalFilesDir(null) private val externalFilesDir = context.getExternalFilesDir(null)
private val inProgressWorkInfos: List<AGWorkInfo> = private val inProgressWorkInfos: List<AGWorkInfo> =

View file

@ -36,10 +36,10 @@ import androidx.compose.runtime.setValue
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.IntOffset import androidx.compose.ui.unit.IntOffset
import androidx.compose.ui.zIndex import androidx.compose.ui.zIndex
import androidx.hilt.navigation.compose.hiltViewModel
import androidx.lifecycle.Lifecycle import androidx.lifecycle.Lifecycle
import androidx.lifecycle.LifecycleEventObserver import androidx.lifecycle.LifecycleEventObserver
import androidx.lifecycle.compose.LocalLifecycleOwner import androidx.lifecycle.compose.LocalLifecycleOwner
import androidx.lifecycle.viewmodel.compose.viewModel
import androidx.navigation.NavBackStackEntry import androidx.navigation.NavBackStackEntry
import androidx.navigation.NavHostController import androidx.navigation.NavHostController
import androidx.navigation.NavType import androidx.navigation.NavType
@ -54,16 +54,19 @@ import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB
import com.google.ai.edge.gallery.data.Task import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.data.TaskType import com.google.ai.edge.gallery.data.TaskType
import com.google.ai.edge.gallery.data.getModelByName import com.google.ai.edge.gallery.data.getModelByName
import com.google.ai.edge.gallery.ui.ViewModelProvider
import com.google.ai.edge.gallery.ui.home.HomeScreen import com.google.ai.edge.gallery.ui.home.HomeScreen
import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioDestination import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioDestination
import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioScreen import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioScreen
import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageDestination import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageDestination
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageScreen import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageScreen
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmChatDestination import com.google.ai.edge.gallery.ui.llmchat.LlmChatDestination
import com.google.ai.edge.gallery.ui.llmchat.LlmChatScreen import com.google.ai.edge.gallery.ui.llmchat.LlmChatScreen
import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnDestination import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnDestination
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnScreen import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnScreen
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
import com.google.ai.edge.gallery.ui.modelmanager.ModelManager import com.google.ai.edge.gallery.ui.modelmanager.ModelManager
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
@ -107,7 +110,7 @@ private fun AnimatedContentTransitionScope<*>.slideExit(): ExitTransition {
fun GalleryNavHost( fun GalleryNavHost(
navController: NavHostController, navController: NavHostController,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
modelManagerViewModel: ModelManagerViewModel = viewModel(factory = ViewModelProvider.Factory), modelManagerViewModel: ModelManagerViewModel = hiltViewModel(),
) { ) {
val lifecycleOwner = LocalLifecycleOwner.current val lifecycleOwner = LocalLifecycleOwner.current
var showModelManager by remember { mutableStateOf(false) } var showModelManager by remember { mutableStateOf(false) }
@ -184,11 +187,14 @@ fun GalleryNavHost(
arguments = listOf(navArgument("modelName") { type = NavType.StringType }), arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() }, enterTransition = { slideEnter() },
exitTransition = { slideExit() }, exitTransition = { slideExit() },
) { ) { backStackEntry ->
getModelFromNavigationParam(it, TASK_LLM_CHAT)?.let { defaultModel -> val viewModel: LlmChatViewModel = hiltViewModel(backStackEntry)
getModelFromNavigationParam(backStackEntry, TASK_LLM_CHAT)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel) modelManagerViewModel.selectModel(defaultModel)
LlmChatScreen( LlmChatScreen(
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() }, navigateUp = { navController.navigateUp() },
) )
@ -201,11 +207,14 @@ fun GalleryNavHost(
arguments = listOf(navArgument("modelName") { type = NavType.StringType }), arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() }, enterTransition = { slideEnter() },
exitTransition = { slideExit() }, exitTransition = { slideExit() },
) { ) { backStackEntry ->
getModelFromNavigationParam(it, TASK_LLM_PROMPT_LAB)?.let { defaultModel -> val viewModel: LlmSingleTurnViewModel = hiltViewModel(backStackEntry)
getModelFromNavigationParam(backStackEntry, TASK_LLM_PROMPT_LAB)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel) modelManagerViewModel.selectModel(defaultModel)
LlmSingleTurnScreen( LlmSingleTurnScreen(
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() }, navigateUp = { navController.navigateUp() },
) )
@ -218,11 +227,14 @@ fun GalleryNavHost(
arguments = listOf(navArgument("modelName") { type = NavType.StringType }), arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() }, enterTransition = { slideEnter() },
exitTransition = { slideExit() }, exitTransition = { slideExit() },
) { ) { backStackEntry ->
getModelFromNavigationParam(it, TASK_LLM_ASK_IMAGE)?.let { defaultModel -> val viewModel: LlmAskImageViewModel = hiltViewModel()
getModelFromNavigationParam(backStackEntry, TASK_LLM_ASK_IMAGE)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel) modelManagerViewModel.selectModel(defaultModel)
LlmAskImageScreen( LlmAskImageScreen(
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() }, navigateUp = { navController.navigateUp() },
) )
@ -235,11 +247,14 @@ fun GalleryNavHost(
arguments = listOf(navArgument("modelName") { type = NavType.StringType }), arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() }, enterTransition = { slideEnter() },
exitTransition = { slideExit() }, exitTransition = { slideExit() },
) { ) { backStackEntry ->
getModelFromNavigationParam(it, TASK_LLM_ASK_AUDIO)?.let { defaultModel -> val viewModel: LlmAskAudioViewModel = hiltViewModel()
getModelFromNavigationParam(backStackEntry, TASK_LLM_ASK_AUDIO)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel) modelManagerViewModel.selectModel(defaultModel)
LlmAskAudioScreen( LlmAskAudioScreen(
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() }, navigateUp = { navController.navigateUp() },
) )

View file

@ -1,91 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.preview
import android.content.Context
import android.graphics.Bitmap
import android.graphics.Canvas
import android.graphics.drawable.Drawable
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.asImageBitmap
import androidx.core.content.ContextCompat
import androidx.core.graphics.createBitmap
import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.common.Classification
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageClassification
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
import com.google.ai.edge.gallery.ui.common.chat.ChatSide
import com.google.ai.edge.gallery.ui.common.chat.ChatViewModel
class PreviewChatModel(context: Context) : ChatViewModel(task = TASK_TEST1) {
init {
val model = task.models[1]
addMessage(
model = model,
message =
ChatMessageText(
content =
"Thanks everyone for your enthusiasm on the team lunch, but people who can sign on the cheque is OOO next week \uD83D\uDE02,",
side = ChatSide.USER,
),
)
addMessage(
model = model,
message =
ChatMessageText(content = "Today is Wednesday!", side = ChatSide.AGENT, latencyMs = 1232f),
)
addMessage(
model = model,
message =
ChatMessageClassification(
classifications =
listOf(
Classification(label = "label1", score = 0.3f, color = Color.Red),
Classification(label = "label2", score = 0.7f, color = Color.Blue),
),
latencyMs = 12345f,
),
)
val bitmap =
getBitmapFromVectorDrawable(
context = context,
drawableId = R.drawable.ic_launcher_background,
)!!
addMessage(
model = model,
message =
ChatMessageImage(
bitmap = bitmap,
imageBitMap = bitmap.asImageBitmap(),
side = ChatSide.USER,
),
)
}
private fun getBitmapFromVectorDrawable(context: Context, drawableId: Int): Bitmap? {
val drawable: Drawable =
ContextCompat.getDrawable(context, drawableId) ?: return null // Drawable not found
val bitmap = createBitmap(drawable.intrinsicWidth, drawable.intrinsicHeight)
val canvas = Canvas(bitmap)
drawable.setBounds(0, 0, canvas.width, canvas.height)
drawable.draw(canvas)
return bitmap
}
}

View file

@ -1,57 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.preview
// TODO(migration)
//
// import com.google.ai.edge.gallery.data.AccessTokenData
// import com.google.ai.edge.gallery.data.DataStoreRepository
// import com.google.ai.edge.gallery.data.ImportedModelInfo
// class PreviewDataStoreRepository : DataStoreRepository
class PreviewDataStoreRepository {
// override fun saveTextInputHistory(history: List<String>) {
// }
// override fun readTextInputHistory(): List<String> {
// return listOf()
// }
// override fun saveThemeOverride(theme: String) {
// }
// override fun readThemeOverride(): String {
// return ""
// }
// override fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long) {
// }
// override fun readAccessTokenData(): AccessTokenData? {
// return null
// }
// override fun clearAccessTokenData() {
// }
// override fun saveImportedModels(importedModels: List<ImportedModelInfo>) {
// }
// override fun readImportedModels(): List<ImportedModelInfo> {
// return listOf()
// }
}

View file

@ -1,44 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.preview
import com.google.ai.edge.gallery.data.AGWorkInfo
import com.google.ai.edge.gallery.data.DownloadRepository
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.ModelDownloadStatus
import java.util.UUID
class PreviewDownloadRepository : DownloadRepository {
override fun downloadModel(
model: Model,
onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit,
) {}
override fun cancelDownloadModel(model: Model) {}
override fun cancelAll(models: List<Model>, onComplete: () -> Unit) {}
override fun observerWorkerProgress(
workerId: UUID,
model: Model,
onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit,
) {}
override fun getEnqueuedOrRunningWorkInfos(): List<AGWorkInfo> {
return listOf()
}
}

View file

@ -1,21 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.preview
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
class PreviewLlmSingleTurnViewModel : LlmSingleTurnViewModel(task = TASK_TEST1)

View file

@ -1,63 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.preview
class PreviewModelManagerViewModel {}
// class PreviewModelManagerViewModel(context: Context) :
// ModelManagerViewModel(
// downloadRepository = PreviewDownloadRepository(),
// // dataStoreRepository = PreviewDataStoreRepository(),
// context = context,
// ) {
// init {
// for ((index, task) in ALL_PREVIEW_TASKS.withIndex()) {
// task.index = index
// for (model in task.models) {
// model.preProcess()
// }
// }
// val modelDownloadStatus =
// mapOf(
// MODEL_TEST1.name to
// ModelDownloadStatus(
// status = ModelDownloadStatusType.IN_PROGRESS,
// receivedBytes = 1234,
// totalBytes = 3456,
// bytesPerSecond = 2333,
// remainingMs = 324,
// ),
// MODEL_TEST2.name to ModelDownloadStatus(status = ModelDownloadStatusType.SUCCEEDED),
// MODEL_TEST3.name to
// ModelDownloadStatus(
// status = ModelDownloadStatusType.FAILED,
// errorMessage = "Http code 404",
// ),
// MODEL_TEST4.name to ModelDownloadStatus(status = ModelDownloadStatusType.NOT_DOWNLOADED),
// )
// val newUiState =
// ModelManagerUiState(
// tasks = ALL_PREVIEW_TASKS,
// modelDownloadStatus = modelDownloadStatus,
// modelInitializationStatus = mapOf(),
// selectedModel = MODEL_TEST2,
// )
// _uiState.update { newUiState }
// }
// }

View file

@ -1,103 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.preview
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.AccountBox
import androidx.compose.material.icons.rounded.AutoAwesome
import com.google.ai.edge.gallery.data.BooleanSwitchConfig
import com.google.ai.edge.gallery.data.Config
import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.LabelConfig
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.NumberSliderConfig
import com.google.ai.edge.gallery.data.SegmentedButtonConfig
import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.data.TaskType
import com.google.ai.edge.gallery.data.ValueType
val TEST_CONFIGS1: List<Config> =
listOf(
LabelConfig(key = ConfigKey.NAME, defaultValue = "Test name"),
NumberSliderConfig(
key = ConfigKey.MAX_RESULT_COUNT,
sliderMin = 1f,
sliderMax = 5f,
defaultValue = 3f,
valueType = ValueType.INT,
),
BooleanSwitchConfig(key = ConfigKey.USE_GPU, defaultValue = false),
SegmentedButtonConfig(
key = ConfigKey.THEME,
defaultValue = "Auto",
options = listOf("Auto", "Light", "Dark"),
),
)
val MODEL_TEST1: Model =
Model(
name = "deterministic3",
downloadFileName = "deterministric3.json",
url = "https://storage.googleapis.com/tfweb/model-graph-vis-v2-test-models/deterministic3.json",
sizeInBytes = 40146048L,
configs = TEST_CONFIGS1,
)
val MODEL_TEST2: Model =
Model(
name = "isnet",
downloadFileName = "isnet.tflite",
url =
"https://storage.googleapis.com/tfweb/model-graph-vis-v2-test-models/isnet-general-use-int8.tflite",
sizeInBytes = 44366296L,
configs = TEST_CONFIGS1,
)
val MODEL_TEST3: Model =
Model(
name = "yolo",
downloadFileName = "yolo.json",
url = "https://storage.googleapis.com/tfweb/model-graph-vis-v2-test-models/yolo.json",
sizeInBytes = 40641364L,
)
val MODEL_TEST4: Model =
Model(
name = "mobilenet v3",
downloadFileName = "mobilenet_v3_large.pt2",
url =
"https://storage.googleapis.com/tfweb/model-graph-vis-v2-test-models/mobilenet_v3_large.pt2",
sizeInBytes = 277135998L,
)
val TASK_TEST1 =
Task(
type = TaskType.TEST_TASK_1,
icon = Icons.Rounded.AutoAwesome,
models = mutableListOf(MODEL_TEST1, MODEL_TEST2),
description = "This is a test task (1)",
)
val TASK_TEST2 =
Task(
type = TaskType.TEST_TASK_2,
icon = Icons.Rounded.AccountBox,
models = mutableListOf(MODEL_TEST3, MODEL_TEST4),
description = "This is a test task (2)",
)
val ALL_PREVIEW_TASKS: List<Task> = listOf(TASK_TEST1, TASK_TEST2)

View file

@ -19,4 +19,5 @@ plugins {
alias(libs.plugins.android.application) apply false alias(libs.plugins.android.application) apply false
alias(libs.plugins.kotlin.android) apply false alias(libs.plugins.kotlin.android) apply false
alias(libs.plugins.kotlin.compose) apply false alias(libs.plugins.kotlin.compose) apply false
alias(libs.plugins.hilt.application) apply false
} }

View file

@ -29,6 +29,8 @@ playServicesTfliteGpu= "16.4.0"
cameraX = "1.4.2" cameraX = "1.4.2"
netOpenidAppauth = "0.11.1" netOpenidAppauth = "0.11.1"
splashscreen = "1.2.0-beta01" splashscreen = "1.2.0-beta01"
hilt = "2.56.2"
hiltNavigation = "1.2.0"
[libraries] [libraries]
androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" } androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" }
@ -67,6 +69,10 @@ camerax-view = { group = "androidx.camera", name = "camera-view", version.ref =
openid-appauth = { group = "net.openid", name = "appauth", version.ref = "netOpenidAppauth" } openid-appauth = { group = "net.openid", name = "appauth", version.ref = "netOpenidAppauth" }
androidx-splashscreen = { group = "androidx.core", name = "core-splashscreen", version.ref = "splashscreen" } androidx-splashscreen = { group = "androidx.core", name = "core-splashscreen", version.ref = "splashscreen" }
protobuf-javalite = { group = "com.google.protobuf", name = "protobuf-javalite", version.ref = "protobufJavaLite" } protobuf-javalite = { group = "com.google.protobuf", name = "protobuf-javalite", version.ref = "protobufJavaLite" }
hilt-android = { module = "com.google.dagger:hilt-android", version.ref = "hilt" }
hilt-navigation-compose = { module = "androidx.hilt:hilt-navigation-compose", version.ref = "hiltNavigation" }
hilt-android-testing = { module = "com.google.dagger:hilt-android-testing", version.ref = "hilt" }
hilt-android-compiler = { module = "com.google.dagger:hilt-android-compiler", version.ref = "hilt" }
[plugins] [plugins]
android-application = { id = "com.android.application", version.ref = "agp" } android-application = { id = "com.android.application", version.ref = "agp" }
@ -74,3 +80,4 @@ kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" }
kotlin-compose = { id = "org.jetbrains.kotlin.plugin.compose", version.ref = "kotlin" } kotlin-compose = { id = "org.jetbrains.kotlin.plugin.compose", version.ref = "kotlin" }
kotlin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "serializationPlugin" } kotlin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "serializationPlugin" }
protobuf = {id = "com.google.protobuf", version.ref = "protobuf"} protobuf = {id = "com.google.protobuf", version.ref = "protobuf"}
hilt-application = { id = "com.google.dagger.hilt.android", version.ref = "hilt" }