Refactor code to migrate manual dependency injection to using Hilt

PiperOrigin-RevId: 776356661
This commit is contained in:
Google AI Edge Gallery 2025-06-26 18:08:36 -07:00 committed by Copybara-Service
parent 665c86a640
commit 323124a628
21 changed files with 209 additions and 506 deletions

View file

@ -20,6 +20,8 @@ plugins {
alias(libs.plugins.kotlin.compose)
alias(libs.plugins.kotlin.serialization)
alias(libs.plugins.protobuf)
alias(libs.plugins.hilt.application)
kotlin("kapt")
}
android {
@ -91,11 +93,15 @@ dependencies {
implementation(libs.openid.appauth)
implementation(libs.androidx.splashscreen)
implementation(libs.protobuf.javalite)
implementation(libs.hilt.android)
implementation(libs.hilt.navigation.compose)
kapt(libs.hilt.android.compiler)
testImplementation(libs.junit)
androidTestImplementation(libs.androidx.junit)
androidTestImplementation(libs.androidx.espresso.core)
androidTestImplementation(platform(libs.androidx.compose.bom))
androidTestImplementation(libs.androidx.ui.test.junit4)
androidTestImplementation(libs.hilt.android.testing)
debugImplementation(libs.androidx.ui.tooling)
debugImplementation(libs.androidx.ui.test.manifest)
}

View file

@ -17,48 +17,23 @@
package com.google.ai.edge.gallery
import android.app.Application
import android.content.Context
import androidx.datastore.core.CorruptionException
import androidx.datastore.core.DataStore
import androidx.datastore.core.Serializer
import androidx.datastore.dataStore
import com.google.ai.edge.gallery.common.writeLaunchInfo
import com.google.ai.edge.gallery.data.AppContainer
import com.google.ai.edge.gallery.data.DefaultAppContainer
import com.google.ai.edge.gallery.proto.Settings
import com.google.ai.edge.gallery.data.DataStoreRepository
import com.google.ai.edge.gallery.ui.theme.ThemeSettings
import com.google.protobuf.InvalidProtocolBufferException
import java.io.InputStream
import java.io.OutputStream
object SettingsSerializer : Serializer<Settings> {
override val defaultValue: Settings = Settings.getDefaultInstance()
override suspend fun readFrom(input: InputStream): Settings {
try {
return Settings.parseFrom(input)
} catch (exception: InvalidProtocolBufferException) {
throw CorruptionException("Cannot read proto.", exception)
}
}
override suspend fun writeTo(t: Settings, output: OutputStream) = t.writeTo(output)
}
private val Context.dataStore: DataStore<Settings> by
dataStore(fileName = "settings.pb", serializer = SettingsSerializer)
import dagger.hilt.android.HiltAndroidApp
import javax.inject.Inject
@HiltAndroidApp
class GalleryApplication : Application() {
/** AppContainer instance used by the rest of classes to obtain dependencies */
lateinit var container: AppContainer
@Inject lateinit var dataStoreRepository: DataStoreRepository
override fun onCreate() {
super.onCreate()
writeLaunchInfo(context = this)
container = DefaultAppContainer(this, dataStore)
// Load saved theme.
ThemeSettings.themeOverride.value = container.dataStoreRepository.readTheme()
ThemeSettings.themeOverride.value = dataStoreRepository.readTheme()
}
}

View file

@ -27,8 +27,11 @@ import androidx.compose.material3.Surface
import androidx.compose.ui.Modifier
import androidx.core.splashscreen.SplashScreen.Companion.installSplashScreen
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import dagger.hilt.android.AndroidEntryPoint
@AndroidEntryPoint
class MainActivity : ComponentActivity() {
override fun onCreate(savedInstanceState: Bundle?) {
installSplashScreen()

View file

@ -0,0 +1,38 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery
import androidx.datastore.core.CorruptionException
import androidx.datastore.core.Serializer
import com.google.ai.edge.gallery.proto.Settings
import com.google.protobuf.InvalidProtocolBufferException
import java.io.InputStream
import java.io.OutputStream
object SettingsSerializer : Serializer<Settings> {
override val defaultValue: Settings = Settings.getDefaultInstance()
override suspend fun readFrom(input: InputStream): Settings {
try {
return Settings.parseFrom(input)
} catch (exception: InvalidProtocolBufferException) {
throw CorruptionException("Cannot read proto.", exception)
}
}
override suspend fun writeTo(t: Settings, output: OutputStream) = t.writeTo(output)
}

View file

@ -0,0 +1,86 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.di
import android.content.Context
import androidx.datastore.core.DataStore
import androidx.datastore.core.DataStoreFactory
import androidx.datastore.core.Serializer
import androidx.datastore.dataStoreFile
import com.google.ai.edge.gallery.AppLifecycleProvider
import com.google.ai.edge.gallery.GalleryLifecycleProvider
import com.google.ai.edge.gallery.SettingsSerializer
import com.google.ai.edge.gallery.data.DataStoreRepository
import com.google.ai.edge.gallery.data.DefaultDataStoreRepository
import com.google.ai.edge.gallery.data.DefaultDownloadRepository
import com.google.ai.edge.gallery.data.DownloadRepository
import com.google.ai.edge.gallery.proto.Settings
import dagger.Module
import dagger.Provides
import dagger.hilt.InstallIn
import dagger.hilt.android.qualifiers.ApplicationContext
import dagger.hilt.components.SingletonComponent
import javax.inject.Singleton
@Module
@InstallIn(SingletonComponent::class)
internal object AppModule {
// Provides the SettingsSerializer
@Provides
@Singleton
fun provideSettingsSerializer(): Serializer<Settings> {
return SettingsSerializer
}
// Provides DataStore<Settings>
@Provides
@Singleton
fun provideSettingsDataStore(
@ApplicationContext context: Context,
settingsSerializer: Serializer<Settings>,
): DataStore<Settings> {
return DataStoreFactory.create(
serializer = settingsSerializer,
produceFile = { context.dataStoreFile("settings.pb") },
)
}
// Provides AppLifecycleProvider
@Provides
@Singleton
fun provideAppLifecycleProvider(): AppLifecycleProvider {
return GalleryLifecycleProvider()
}
// Provides DataStoreRepository
@Provides
@Singleton
fun provideDataStoreRepository(dataStore: DataStore<Settings>): DataStoreRepository {
return DefaultDataStoreRepository(dataStore)
}
// Provides DownloadRepository
@Provides
@Singleton
fun provideDownloadRepository(
@ApplicationContext context: Context,
lifecycleProvider: AppLifecycleProvider,
): DownloadRepository {
return DefaultDownloadRepository(context, lifecycleProvider)
}
}

View file

@ -1,64 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui
import androidx.lifecycle.ViewModelProvider.AndroidViewModelFactory
import androidx.lifecycle.viewmodel.CreationExtras
import androidx.lifecycle.viewmodel.initializer
import androidx.lifecycle.viewmodel.viewModelFactory
import com.google.ai.edge.gallery.GalleryApplication
import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
object ViewModelProvider {
val Factory = viewModelFactory {
// Initializer for ModelManagerViewModel.
initializer {
val downloadRepository = galleryApplication().container.downloadRepository
val dataStoreRepository = galleryApplication().container.dataStoreRepository
val lifecycleProvider = galleryApplication().container.lifecycleProvider
ModelManagerViewModel(
downloadRepository = downloadRepository,
dataStoreRepository = dataStoreRepository,
lifecycleProvider = lifecycleProvider,
context = galleryApplication().container.context,
)
}
// Initializer for LlmChatViewModel.
initializer { LlmChatViewModel() }
// Initializer for LlmSingleTurnViewModel..
initializer { LlmSingleTurnViewModel() }
// Initializer for LlmAskImageViewModel.
initializer { LlmAskImageViewModel() }
// Initializer for LlmAskAudioViewModel.
initializer { LlmAskAudioViewModel() }
}
}
/**
* Extension function to queries for [Application] object and returns an instance of
* [GalleryApplication].
*/
fun CreationExtras.galleryApplication(): GalleryApplication =
(this[AndroidViewModelFactory.APPLICATION_KEY] as GalleryApplication)

View file

@ -53,7 +53,7 @@ data class ChatUiState(
)
/** ViewModel responsible for managing the chat UI state and handling chat-related operations. */
open class ChatViewModel(val task: Task) : ViewModel() {
abstract class ChatViewModel(val task: Task) : ViewModel() {
private val _uiState = MutableStateFlow(createUiState(task = task))
val uiState = _uiState.asStateFlow()

View file

@ -20,8 +20,6 @@ import android.graphics.Bitmap
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.ai.edge.gallery.ui.ViewModelProvider
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageAudioClip
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
@ -46,7 +44,7 @@ fun LlmChatScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
viewModel: LlmChatViewModel = viewModel(factory = ViewModelProvider.Factory),
viewModel: LlmChatViewModel,
) {
ChatViewWrapper(
viewModel = viewModel,
@ -61,7 +59,7 @@ fun LlmAskImageScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
viewModel: LlmAskImageViewModel = viewModel(factory = ViewModelProvider.Factory),
viewModel: LlmAskImageViewModel,
) {
ChatViewWrapper(
viewModel = viewModel,
@ -76,7 +74,7 @@ fun LlmAskAudioScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
viewModel: LlmAskAudioViewModel = viewModel(factory = ViewModelProvider.Factory),
viewModel: LlmAskAudioViewModel,
) {
ChatViewWrapper(
viewModel = viewModel,
@ -88,7 +86,7 @@ fun LlmAskAudioScreen(
@Composable
fun ChatViewWrapper(
viewModel: LlmChatViewModel,
viewModel: LlmChatViewModelBase,
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,

View file

@ -36,6 +36,8 @@ import com.google.ai.edge.gallery.ui.common.chat.ChatSide
import com.google.ai.edge.gallery.ui.common.chat.ChatViewModel
import com.google.ai.edge.gallery.ui.common.chat.Stat
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import dagger.hilt.android.lifecycle.HiltViewModel
import javax.inject.Inject
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
@ -49,7 +51,7 @@ private val STATS =
Stat(id = "latency", label = "Latency", unit = "sec"),
)
open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task = curTask) {
open class LlmChatViewModelBase(val curTask: Task) : ChatViewModel(task = curTask) {
fun generateResponse(
model: Model,
input: String,
@ -75,9 +77,9 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
val instance = model.instance as LlmModelInstance
var prefillTokens = instance.session.sizeInTokens(input)
prefillTokens += images.size * 257
for (audioMessages in audioMessages) {
for (audioMessage in audioMessages) {
// 150ms = 1 audio token
val duration = audioMessages.getDurationInSeconds()
val duration = audioMessage.getDurationInSeconds()
prefillTokens += (duration * 1000f / 150f).toInt()
}
@ -259,6 +261,13 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
}
}
class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE)
@HiltViewModel
class LlmChatViewModel @Inject constructor() : LlmChatViewModelBase(curTask = TASK_LLM_CHAT)
class LlmAskAudioViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_AUDIO)
@HiltViewModel
class LlmAskImageViewModel @Inject constructor() :
LlmChatViewModelBase(curTask = TASK_LLM_ASK_IMAGE)
@HiltViewModel
class LlmAskAudioViewModel @Inject constructor() :
LlmChatViewModelBase(curTask = TASK_LLM_ASK_AUDIO)

View file

@ -43,9 +43,8 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalLayoutDirection
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.ai.edge.gallery.data.ModelDownloadStatusType
import com.google.ai.edge.gallery.ui.ViewModelProvider
import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB
import com.google.ai.edge.gallery.ui.common.ErrorDialog
import com.google.ai.edge.gallery.ui.common.ModelPageAppBar
import com.google.ai.edge.gallery.ui.common.chat.ModelDownloadStatusInfoPanel
@ -67,9 +66,9 @@ fun LlmSingleTurnScreen(
modelManagerViewModel: ModelManagerViewModel,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
viewModel: LlmSingleTurnViewModel = viewModel(factory = ViewModelProvider.Factory),
viewModel: LlmSingleTurnViewModel,
) {
val task = viewModel.task
val task = TASK_LLM_PROMPT_LAB
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val uiState by viewModel.uiState.collectAsState()
val selectedModel = modelManagerUiState.selectedModel

View file

@ -27,6 +27,8 @@ import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.ai.edge.gallery.ui.common.chat.Stat
import com.google.ai.edge.gallery.ui.llmchat.LlmChatModelHelper
import com.google.ai.edge.gallery.ui.llmchat.LlmModelInstance
import dagger.hilt.android.lifecycle.HiltViewModel
import javax.inject.Inject
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow
@ -63,8 +65,9 @@ private val STATS =
Stat(id = "latency", label = "Latency", unit = "sec"),
)
open class LlmSingleTurnViewModel(val task: Task = TASK_LLM_PROMPT_LAB) : ViewModel() {
private val _uiState = MutableStateFlow(createUiState(task = task))
@HiltViewModel
class LlmSingleTurnViewModel @Inject constructor() : ViewModel() {
private val _uiState = MutableStateFlow(createUiState(task = TASK_LLM_PROMPT_LAB))
val uiState = _uiState.asStateFlow()
fun generateResponse(model: Model, input: String) {

View file

@ -52,9 +52,12 @@ import com.google.ai.edge.gallery.ui.common.AuthConfig
import com.google.ai.edge.gallery.ui.llmchat.LlmChatModelHelper
import com.google.gson.Gson
import com.google.gson.reflect.TypeToken
import dagger.hilt.android.lifecycle.HiltViewModel
import dagger.hilt.android.qualifiers.ApplicationContext
import java.io.File
import java.net.HttpURLConnection
import java.net.URL
import javax.inject.Inject
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow
@ -134,11 +137,14 @@ data class PagerScrollState(val page: Int = 0, val offset: Float = 0f)
* cleaning up models. It also manages the UI state for model management, including the list of
* tasks, models, download statuses, and initialization statuses.
*/
open class ModelManagerViewModel(
@HiltViewModel
open class ModelManagerViewModel
@Inject
constructor(
private val downloadRepository: DownloadRepository,
private val dataStoreRepository: DataStoreRepository,
private val lifecycleProvider: AppLifecycleProvider,
context: Context,
@ApplicationContext private val context: Context,
) : ViewModel() {
private val externalFilesDir = context.getExternalFilesDir(null)
private val inProgressWorkInfos: List<AGWorkInfo> =

View file

@ -36,10 +36,10 @@ import androidx.compose.runtime.setValue
import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.IntOffset
import androidx.compose.ui.zIndex
import androidx.hilt.navigation.compose.hiltViewModel
import androidx.lifecycle.Lifecycle
import androidx.lifecycle.LifecycleEventObserver
import androidx.lifecycle.compose.LocalLifecycleOwner
import androidx.lifecycle.viewmodel.compose.viewModel
import androidx.navigation.NavBackStackEntry
import androidx.navigation.NavHostController
import androidx.navigation.NavType
@ -54,16 +54,19 @@ import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB
import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.data.TaskType
import com.google.ai.edge.gallery.data.getModelByName
import com.google.ai.edge.gallery.ui.ViewModelProvider
import com.google.ai.edge.gallery.ui.home.HomeScreen
import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioDestination
import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioScreen
import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageDestination
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageScreen
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmChatDestination
import com.google.ai.edge.gallery.ui.llmchat.LlmChatScreen
import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnDestination
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnScreen
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
import com.google.ai.edge.gallery.ui.modelmanager.ModelManager
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
@ -107,7 +110,7 @@ private fun AnimatedContentTransitionScope<*>.slideExit(): ExitTransition {
fun GalleryNavHost(
navController: NavHostController,
modifier: Modifier = Modifier,
modelManagerViewModel: ModelManagerViewModel = viewModel(factory = ViewModelProvider.Factory),
modelManagerViewModel: ModelManagerViewModel = hiltViewModel(),
) {
val lifecycleOwner = LocalLifecycleOwner.current
var showModelManager by remember { mutableStateOf(false) }
@ -184,11 +187,14 @@ fun GalleryNavHost(
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() },
exitTransition = { slideExit() },
) {
getModelFromNavigationParam(it, TASK_LLM_CHAT)?.let { defaultModel ->
) { backStackEntry ->
val viewModel: LlmChatViewModel = hiltViewModel(backStackEntry)
getModelFromNavigationParam(backStackEntry, TASK_LLM_CHAT)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel)
LlmChatScreen(
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() },
)
@ -201,11 +207,14 @@ fun GalleryNavHost(
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() },
exitTransition = { slideExit() },
) {
getModelFromNavigationParam(it, TASK_LLM_PROMPT_LAB)?.let { defaultModel ->
) { backStackEntry ->
val viewModel: LlmSingleTurnViewModel = hiltViewModel(backStackEntry)
getModelFromNavigationParam(backStackEntry, TASK_LLM_PROMPT_LAB)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel)
LlmSingleTurnScreen(
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() },
)
@ -218,11 +227,14 @@ fun GalleryNavHost(
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() },
exitTransition = { slideExit() },
) {
getModelFromNavigationParam(it, TASK_LLM_ASK_IMAGE)?.let { defaultModel ->
) { backStackEntry ->
val viewModel: LlmAskImageViewModel = hiltViewModel()
getModelFromNavigationParam(backStackEntry, TASK_LLM_ASK_IMAGE)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel)
LlmAskImageScreen(
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() },
)
@ -235,11 +247,14 @@ fun GalleryNavHost(
arguments = listOf(navArgument("modelName") { type = NavType.StringType }),
enterTransition = { slideEnter() },
exitTransition = { slideExit() },
) {
getModelFromNavigationParam(it, TASK_LLM_ASK_AUDIO)?.let { defaultModel ->
) { backStackEntry ->
val viewModel: LlmAskAudioViewModel = hiltViewModel()
getModelFromNavigationParam(backStackEntry, TASK_LLM_ASK_AUDIO)?.let { defaultModel ->
modelManagerViewModel.selectModel(defaultModel)
LlmAskAudioScreen(
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
navigateUp = { navController.navigateUp() },
)

View file

@ -1,91 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.preview
import android.content.Context
import android.graphics.Bitmap
import android.graphics.Canvas
import android.graphics.drawable.Drawable
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.asImageBitmap
import androidx.core.content.ContextCompat
import androidx.core.graphics.createBitmap
import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.common.Classification
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageClassification
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
import com.google.ai.edge.gallery.ui.common.chat.ChatSide
import com.google.ai.edge.gallery.ui.common.chat.ChatViewModel
class PreviewChatModel(context: Context) : ChatViewModel(task = TASK_TEST1) {
init {
val model = task.models[1]
addMessage(
model = model,
message =
ChatMessageText(
content =
"Thanks everyone for your enthusiasm on the team lunch, but people who can sign on the cheque is OOO next week \uD83D\uDE02,",
side = ChatSide.USER,
),
)
addMessage(
model = model,
message =
ChatMessageText(content = "Today is Wednesday!", side = ChatSide.AGENT, latencyMs = 1232f),
)
addMessage(
model = model,
message =
ChatMessageClassification(
classifications =
listOf(
Classification(label = "label1", score = 0.3f, color = Color.Red),
Classification(label = "label2", score = 0.7f, color = Color.Blue),
),
latencyMs = 12345f,
),
)
val bitmap =
getBitmapFromVectorDrawable(
context = context,
drawableId = R.drawable.ic_launcher_background,
)!!
addMessage(
model = model,
message =
ChatMessageImage(
bitmap = bitmap,
imageBitMap = bitmap.asImageBitmap(),
side = ChatSide.USER,
),
)
}
private fun getBitmapFromVectorDrawable(context: Context, drawableId: Int): Bitmap? {
val drawable: Drawable =
ContextCompat.getDrawable(context, drawableId) ?: return null // Drawable not found
val bitmap = createBitmap(drawable.intrinsicWidth, drawable.intrinsicHeight)
val canvas = Canvas(bitmap)
drawable.setBounds(0, 0, canvas.width, canvas.height)
drawable.draw(canvas)
return bitmap
}
}

View file

@ -1,57 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.preview
// TODO(migration)
//
// import com.google.ai.edge.gallery.data.AccessTokenData
// import com.google.ai.edge.gallery.data.DataStoreRepository
// import com.google.ai.edge.gallery.data.ImportedModelInfo
// class PreviewDataStoreRepository : DataStoreRepository
class PreviewDataStoreRepository {
// override fun saveTextInputHistory(history: List<String>) {
// }
// override fun readTextInputHistory(): List<String> {
// return listOf()
// }
// override fun saveThemeOverride(theme: String) {
// }
// override fun readThemeOverride(): String {
// return ""
// }
// override fun saveAccessTokenData(accessToken: String, refreshToken: String, expiresAt: Long) {
// }
// override fun readAccessTokenData(): AccessTokenData? {
// return null
// }
// override fun clearAccessTokenData() {
// }
// override fun saveImportedModels(importedModels: List<ImportedModelInfo>) {
// }
// override fun readImportedModels(): List<ImportedModelInfo> {
// return listOf()
// }
}

View file

@ -1,44 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.preview
import com.google.ai.edge.gallery.data.AGWorkInfo
import com.google.ai.edge.gallery.data.DownloadRepository
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.ModelDownloadStatus
import java.util.UUID
class PreviewDownloadRepository : DownloadRepository {
override fun downloadModel(
model: Model,
onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit,
) {}
override fun cancelDownloadModel(model: Model) {}
override fun cancelAll(models: List<Model>, onComplete: () -> Unit) {}
override fun observerWorkerProgress(
workerId: UUID,
model: Model,
onStatusUpdated: (model: Model, status: ModelDownloadStatus) -> Unit,
) {}
override fun getEnqueuedOrRunningWorkInfos(): List<AGWorkInfo> {
return listOf()
}
}

View file

@ -1,21 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.preview
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
class PreviewLlmSingleTurnViewModel : LlmSingleTurnViewModel(task = TASK_TEST1)

View file

@ -1,63 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.preview
class PreviewModelManagerViewModel {}
// class PreviewModelManagerViewModel(context: Context) :
// ModelManagerViewModel(
// downloadRepository = PreviewDownloadRepository(),
// // dataStoreRepository = PreviewDataStoreRepository(),
// context = context,
// ) {
// init {
// for ((index, task) in ALL_PREVIEW_TASKS.withIndex()) {
// task.index = index
// for (model in task.models) {
// model.preProcess()
// }
// }
// val modelDownloadStatus =
// mapOf(
// MODEL_TEST1.name to
// ModelDownloadStatus(
// status = ModelDownloadStatusType.IN_PROGRESS,
// receivedBytes = 1234,
// totalBytes = 3456,
// bytesPerSecond = 2333,
// remainingMs = 324,
// ),
// MODEL_TEST2.name to ModelDownloadStatus(status = ModelDownloadStatusType.SUCCEEDED),
// MODEL_TEST3.name to
// ModelDownloadStatus(
// status = ModelDownloadStatusType.FAILED,
// errorMessage = "Http code 404",
// ),
// MODEL_TEST4.name to ModelDownloadStatus(status = ModelDownloadStatusType.NOT_DOWNLOADED),
// )
// val newUiState =
// ModelManagerUiState(
// tasks = ALL_PREVIEW_TASKS,
// modelDownloadStatus = modelDownloadStatus,
// modelInitializationStatus = mapOf(),
// selectedModel = MODEL_TEST2,
// )
// _uiState.update { newUiState }
// }
// }

View file

@ -1,103 +0,0 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.ai.edge.gallery.ui.preview
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.AccountBox
import androidx.compose.material.icons.rounded.AutoAwesome
import com.google.ai.edge.gallery.data.BooleanSwitchConfig
import com.google.ai.edge.gallery.data.Config
import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.LabelConfig
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.NumberSliderConfig
import com.google.ai.edge.gallery.data.SegmentedButtonConfig
import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.data.TaskType
import com.google.ai.edge.gallery.data.ValueType
val TEST_CONFIGS1: List<Config> =
listOf(
LabelConfig(key = ConfigKey.NAME, defaultValue = "Test name"),
NumberSliderConfig(
key = ConfigKey.MAX_RESULT_COUNT,
sliderMin = 1f,
sliderMax = 5f,
defaultValue = 3f,
valueType = ValueType.INT,
),
BooleanSwitchConfig(key = ConfigKey.USE_GPU, defaultValue = false),
SegmentedButtonConfig(
key = ConfigKey.THEME,
defaultValue = "Auto",
options = listOf("Auto", "Light", "Dark"),
),
)
val MODEL_TEST1: Model =
Model(
name = "deterministic3",
downloadFileName = "deterministric3.json",
url = "https://storage.googleapis.com/tfweb/model-graph-vis-v2-test-models/deterministic3.json",
sizeInBytes = 40146048L,
configs = TEST_CONFIGS1,
)
val MODEL_TEST2: Model =
Model(
name = "isnet",
downloadFileName = "isnet.tflite",
url =
"https://storage.googleapis.com/tfweb/model-graph-vis-v2-test-models/isnet-general-use-int8.tflite",
sizeInBytes = 44366296L,
configs = TEST_CONFIGS1,
)
val MODEL_TEST3: Model =
Model(
name = "yolo",
downloadFileName = "yolo.json",
url = "https://storage.googleapis.com/tfweb/model-graph-vis-v2-test-models/yolo.json",
sizeInBytes = 40641364L,
)
val MODEL_TEST4: Model =
Model(
name = "mobilenet v3",
downloadFileName = "mobilenet_v3_large.pt2",
url =
"https://storage.googleapis.com/tfweb/model-graph-vis-v2-test-models/mobilenet_v3_large.pt2",
sizeInBytes = 277135998L,
)
val TASK_TEST1 =
Task(
type = TaskType.TEST_TASK_1,
icon = Icons.Rounded.AutoAwesome,
models = mutableListOf(MODEL_TEST1, MODEL_TEST2),
description = "This is a test task (1)",
)
val TASK_TEST2 =
Task(
type = TaskType.TEST_TASK_2,
icon = Icons.Rounded.AccountBox,
models = mutableListOf(MODEL_TEST3, MODEL_TEST4),
description = "This is a test task (2)",
)
val ALL_PREVIEW_TASKS: List<Task> = listOf(TASK_TEST1, TASK_TEST2)

View file

@ -19,4 +19,5 @@ plugins {
alias(libs.plugins.android.application) apply false
alias(libs.plugins.kotlin.android) apply false
alias(libs.plugins.kotlin.compose) apply false
alias(libs.plugins.hilt.application) apply false
}

View file

@ -29,6 +29,8 @@ playServicesTfliteGpu= "16.4.0"
cameraX = "1.4.2"
netOpenidAppauth = "0.11.1"
splashscreen = "1.2.0-beta01"
hilt = "2.56.2"
hiltNavigation = "1.2.0"
[libraries]
androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" }
@ -67,6 +69,10 @@ camerax-view = { group = "androidx.camera", name = "camera-view", version.ref =
openid-appauth = { group = "net.openid", name = "appauth", version.ref = "netOpenidAppauth" }
androidx-splashscreen = { group = "androidx.core", name = "core-splashscreen", version.ref = "splashscreen" }
protobuf-javalite = { group = "com.google.protobuf", name = "protobuf-javalite", version.ref = "protobufJavaLite" }
hilt-android = { module = "com.google.dagger:hilt-android", version.ref = "hilt" }
hilt-navigation-compose = { module = "androidx.hilt:hilt-navigation-compose", version.ref = "hiltNavigation" }
hilt-android-testing = { module = "com.google.dagger:hilt-android-testing", version.ref = "hilt" }
hilt-android-compiler = { module = "com.google.dagger:hilt-android-compiler", version.ref = "hilt" }
[plugins]
android-application = { id = "com.android.application", version.ref = "agp" }
@ -74,3 +80,4 @@ kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" }
kotlin-compose = { id = "org.jetbrains.kotlin.plugin.compose", version.ref = "kotlin" }
kotlin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "serializationPlugin" }
protobuf = {id = "com.google.protobuf", version.ref = "protobuf"}
hilt-application = { id = "com.google.dagger.hilt.android", version.ref = "hilt" }