mirror of
https://github.com/google-ai-edge/gallery.git
synced 2025-07-14 02:12:02 -04:00
Better support for sliding to change model
This commit is contained in:
parent
3341286efa
commit
d0beaab31e
8 changed files with 433 additions and 237 deletions
|
@ -90,6 +90,8 @@ data class Model(
|
|||
// The following fields are managed by the app. Don't need to set manually.
|
||||
var instance: Any? = null,
|
||||
var initializing: Boolean = false,
|
||||
// TODO(jingjin): use a "queue" system to manage model init and cleanup.
|
||||
var cleanUpAfterInit: Boolean = false,
|
||||
var configValues: Map<String, Any> = mapOf(),
|
||||
var totalBytes: Long = 0L,
|
||||
var accessToken: String? = null,
|
||||
|
|
|
@ -16,26 +16,19 @@
|
|||
|
||||
package com.google.aiedge.gallery.ui.common
|
||||
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.layout.size
|
||||
import androidx.compose.foundation.shape.CircleShape
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.automirrored.rounded.ArrowBack
|
||||
import androidx.compose.material.icons.rounded.ArrowDropDown
|
||||
import androidx.compose.material.icons.rounded.Settings
|
||||
import androidx.compose.material3.CenterAlignedTopAppBar
|
||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.IconButton
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.ModalBottomSheet
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.material3.rememberModalBottomSheetState
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.getValue
|
||||
|
@ -44,7 +37,7 @@ import androidx.compose.runtime.remember
|
|||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.draw.alpha
|
||||
import androidx.compose.ui.graphics.vector.ImageVector
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.res.vectorResource
|
||||
|
@ -54,7 +47,6 @@ import com.google.aiedge.gallery.data.Model
|
|||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.ui.common.chat.ConfigDialog
|
||||
import com.google.aiedge.gallery.ui.common.modelitem.StatusIcon
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
|
@ -69,14 +61,15 @@ fun ModelPageAppBar(
|
|||
onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit = { _, _ -> },
|
||||
) {
|
||||
var showConfigDialog by remember { mutableStateOf(false) }
|
||||
var showModelPicker by remember { mutableStateOf(false) }
|
||||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||
val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
|
||||
val context = LocalContext.current
|
||||
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[model.name]
|
||||
|
||||
CenterAlignedTopAppBar(title = {
|
||||
Column(horizontalAlignment = Alignment.CenterHorizontally) {
|
||||
Column(
|
||||
horizontalAlignment = Alignment.CenterHorizontally,
|
||||
verticalArrangement = Arrangement.spacedBy(4.dp)
|
||||
) {
|
||||
// Task type.
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
|
@ -95,29 +88,13 @@ fun ModelPageAppBar(
|
|||
)
|
||||
}
|
||||
|
||||
// Model name.
|
||||
Row(verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(2.dp),
|
||||
modifier = Modifier
|
||||
.clip(CircleShape)
|
||||
.background(MaterialTheme.colorScheme.surfaceContainerHigh)
|
||||
.clickable {
|
||||
showModelPicker = true
|
||||
}
|
||||
.padding(start = 8.dp, end = 2.dp)) {
|
||||
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
|
||||
Text(
|
||||
model.name,
|
||||
style = MaterialTheme.typography.labelSmall,
|
||||
modifier = Modifier.padding(start = 4.dp),
|
||||
)
|
||||
Icon(
|
||||
Icons.Rounded.ArrowDropDown,
|
||||
modifier = Modifier.size(20.dp),
|
||||
contentDescription = "",
|
||||
)
|
||||
}
|
||||
|
||||
// Model chips pager.
|
||||
ModelPickerChipsPager(
|
||||
task = task,
|
||||
initialModel = model,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
onModelSelected = onModelSelected,
|
||||
)
|
||||
}
|
||||
}, modifier = modifier,
|
||||
// The back button.
|
||||
|
@ -131,14 +108,20 @@ fun ModelPageAppBar(
|
|||
},
|
||||
// The config button for the model (if existed).
|
||||
actions = {
|
||||
if (model.configs.isNotEmpty() && curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
|
||||
IconButton(onClick = { showConfigDialog = true }) {
|
||||
Icon(
|
||||
imageVector = Icons.Rounded.Settings,
|
||||
contentDescription = "",
|
||||
tint = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
}
|
||||
val showConfigButton =
|
||||
model.configs.isNotEmpty() && curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED
|
||||
IconButton(
|
||||
onClick = {
|
||||
showConfigDialog = true
|
||||
},
|
||||
enabled = showConfigButton,
|
||||
modifier = Modifier.alpha(if (showConfigButton) 1f else 0f)
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Rounded.Settings,
|
||||
contentDescription = "",
|
||||
tint = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
|
@ -193,21 +176,4 @@ fun ModelPageAppBar(
|
|||
},
|
||||
)
|
||||
}
|
||||
|
||||
// Model picker.
|
||||
if (showModelPicker) {
|
||||
ModalBottomSheet(
|
||||
onDismissRequest = { showModelPicker = false },
|
||||
sheetState = sheetState,
|
||||
) {
|
||||
ModelPicker(
|
||||
task = task,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
onModelSelected = { model ->
|
||||
showModelPicker = false
|
||||
onModelSelected(model)
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,178 @@
|
|||
/*
|
||||
* Copyright 2025 Google LLC
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.google.aiedge.gallery.ui.common
|
||||
|
||||
import androidx.compose.animation.AnimatedVisibility
|
||||
import androidx.compose.animation.fadeIn
|
||||
import androidx.compose.animation.fadeOut
|
||||
import androidx.compose.animation.scaleIn
|
||||
import androidx.compose.animation.scaleOut
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.layout.size
|
||||
import androidx.compose.foundation.pager.HorizontalPager
|
||||
import androidx.compose.foundation.pager.rememberPagerState
|
||||
import androidx.compose.foundation.shape.CircleShape
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.rounded.ArrowDropDown
|
||||
import androidx.compose.material3.CircularProgressIndicator
|
||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.ModalBottomSheet
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.material3.rememberModalBottomSheetState
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.rememberCoroutineScope
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.alpha
|
||||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.graphics.graphicsLayer
|
||||
import androidx.compose.ui.unit.dp
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.ui.common.modelitem.StatusIcon
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlin.math.absoluteValue
|
||||
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@Composable
|
||||
fun ModelPickerChipsPager(
|
||||
task: Task,
|
||||
initialModel: Model,
|
||||
modelManagerViewModel: ModelManagerViewModel,
|
||||
onModelSelected: (Model) -> Unit,
|
||||
) {
|
||||
var showModelPicker by remember { mutableStateOf(false) }
|
||||
var modelPickerModel by remember { mutableStateOf<Model?>(null) }
|
||||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||
val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
|
||||
val scope = rememberCoroutineScope()
|
||||
|
||||
val pagerState = rememberPagerState(initialPage = task.models.indexOf(initialModel),
|
||||
pageCount = { task.models.size })
|
||||
|
||||
// Sync scrolling.
|
||||
LaunchedEffect(modelManagerViewModel.pagerScrollState) {
|
||||
modelManagerViewModel.pagerScrollState.collect { state ->
|
||||
pagerState.scrollToPage(state.page, state.offset)
|
||||
}
|
||||
}
|
||||
|
||||
HorizontalPager(state = pagerState, userScrollEnabled = false) { pageIndex ->
|
||||
val model = task.models[pageIndex]
|
||||
|
||||
// Calculate the alpha of the current page based on how far they are from the center.
|
||||
val pageOffset =
|
||||
((pagerState.currentPage - pageIndex) + pagerState.currentPageOffsetFraction).absoluteValue
|
||||
val curAlpha = 1f - (pageOffset * 1.5f).coerceIn(0f, 1f)
|
||||
|
||||
val modelInitializationStatus =
|
||||
modelManagerUiState.modelInitializationStatus[model.name]
|
||||
|
||||
Box(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.graphicsLayer { alpha = curAlpha },
|
||||
contentAlignment = Alignment.Center
|
||||
) {
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(2.dp)
|
||||
) {
|
||||
Row(verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(2.dp),
|
||||
modifier = Modifier
|
||||
.clip(CircleShape)
|
||||
.background(MaterialTheme.colorScheme.surfaceContainerHigh)
|
||||
.clickable {
|
||||
modelPickerModel = model
|
||||
showModelPicker = true
|
||||
}
|
||||
.padding(start = 8.dp, end = 2.dp)
|
||||
.padding(vertical = 1.dp)) Inner@{
|
||||
Box(contentAlignment = Alignment.Center, modifier = Modifier.size(18.dp)) {
|
||||
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
|
||||
this@Inner.AnimatedVisibility(
|
||||
visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
|
||||
enter = scaleIn() + fadeIn(),
|
||||
exit = scaleOut() + fadeOut(),
|
||||
) {
|
||||
// Circular progress indicator.
|
||||
CircularProgressIndicator(
|
||||
modifier = Modifier
|
||||
.size(17.dp)
|
||||
.alpha(0.5f),
|
||||
strokeWidth = 2.dp,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
}
|
||||
}
|
||||
Text(
|
||||
model.name,
|
||||
style = MaterialTheme.typography.labelSmall,
|
||||
modifier = Modifier.padding(start = 4.dp),
|
||||
)
|
||||
Icon(
|
||||
Icons.Rounded.ArrowDropDown,
|
||||
modifier = Modifier.size(20.dp),
|
||||
contentDescription = "",
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Model picker.
|
||||
val curModelPickerModel = modelPickerModel
|
||||
if (showModelPicker && curModelPickerModel != null) {
|
||||
ModalBottomSheet(
|
||||
onDismissRequest = { showModelPicker = false },
|
||||
sheetState = sheetState,
|
||||
) {
|
||||
ModelPicker(
|
||||
task = task,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
onModelSelected = { selectedModel ->
|
||||
showModelPicker = false
|
||||
|
||||
scope.launch(Dispatchers.Default) {
|
||||
// Scroll to the selected model.
|
||||
pagerState.animateScrollToPage(task.models.indexOf(selectedModel))
|
||||
}
|
||||
|
||||
onModelSelected(selectedModel)
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -16,11 +16,6 @@
|
|||
|
||||
package com.google.aiedge.gallery.ui.common.chat
|
||||
|
||||
import androidx.compose.animation.AnimatedVisibility
|
||||
import androidx.compose.animation.fadeIn
|
||||
import androidx.compose.animation.fadeOut
|
||||
import androidx.compose.animation.scaleIn
|
||||
import androidx.compose.animation.scaleOut
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.gestures.detectTapGestures
|
||||
|
@ -33,7 +28,6 @@ import androidx.compose.foundation.layout.fillMaxSize
|
|||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.ime
|
||||
import androidx.compose.foundation.layout.imePadding
|
||||
import androidx.compose.foundation.layout.offset
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.layout.size
|
||||
import androidx.compose.foundation.layout.width
|
||||
|
@ -426,16 +420,6 @@ fun ChatPanel(
|
|||
}
|
||||
}
|
||||
|
||||
// Model initialization in-progress message.
|
||||
this@Column.AnimatedVisibility(
|
||||
visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
|
||||
enter = scaleIn() + fadeIn(),
|
||||
exit = scaleOut() + fadeOut(),
|
||||
modifier = Modifier.offset(y = 12.dp)
|
||||
) {
|
||||
ModelInitializationStatusChip()
|
||||
}
|
||||
|
||||
SnackbarHost(hostState = snackbarHostState, modifier = Modifier.padding(vertical = 4.dp))
|
||||
}
|
||||
|
||||
|
|
|
@ -25,14 +25,17 @@ import androidx.compose.foundation.layout.fillMaxSize
|
|||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.pager.HorizontalPager
|
||||
import androidx.compose.foundation.pager.rememberPagerState
|
||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.Scaffold
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.rememberCoroutineScope
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.runtime.snapshotFlow
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.graphics.graphicsLayer
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
|
@ -42,6 +45,7 @@ import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
|||
import com.google.aiedge.gallery.data.Task
|
||||
import com.google.aiedge.gallery.ui.common.ModelPageAppBar
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.modelmanager.PagerScrollState
|
||||
import com.google.aiedge.gallery.ui.preview.PreviewChatModel
|
||||
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
|
||||
|
@ -84,8 +88,10 @@ fun ChatView(
|
|||
pageCount = { task.models.size })
|
||||
val context = LocalContext.current
|
||||
val scope = rememberCoroutineScope()
|
||||
var navigatingUp by remember { mutableStateOf(false) }
|
||||
|
||||
val handleNavigateUp = {
|
||||
navigatingUp = true
|
||||
navigateUp()
|
||||
|
||||
// clean up all models.
|
||||
|
@ -99,9 +105,11 @@ fun ChatView(
|
|||
// Initialize model when model/download state changes.
|
||||
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name]
|
||||
LaunchedEffect(curDownloadStatus, selectedModel.name) {
|
||||
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
|
||||
Log.d(TAG, "Initializing model '${selectedModel.name}' from ChatView launched effect")
|
||||
modelManagerViewModel.initializeModel(context, task = task, model = selectedModel)
|
||||
if (!navigatingUp) {
|
||||
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
|
||||
Log.d(TAG, "Initializing model '${selectedModel.name}' from ChatView launched effect")
|
||||
modelManagerViewModel.initializeModel(context, task = task, model = selectedModel)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -118,6 +126,25 @@ fun ChatView(
|
|||
modelManagerViewModel.selectModel(curSelectedModel)
|
||||
}
|
||||
|
||||
LaunchedEffect(pagerState) {
|
||||
// Collect from the a snapshotFlow reading the currentPage
|
||||
snapshotFlow { pagerState.currentPage }.collect { page ->
|
||||
Log.d(TAG, "Page changed to $page")
|
||||
}
|
||||
}
|
||||
|
||||
// Trigger scroll sync.
|
||||
LaunchedEffect(pagerState) {
|
||||
snapshotFlow {
|
||||
PagerScrollState(
|
||||
page = pagerState.currentPage,
|
||||
offset = pagerState.currentPageOffsetFraction
|
||||
)
|
||||
}.collect { scrollState ->
|
||||
modelManagerViewModel.pagerScrollState.value = scrollState
|
||||
}
|
||||
}
|
||||
|
||||
// Handle system's edge swipe.
|
||||
BackHandler {
|
||||
handleNavigateUp()
|
||||
|
|
|
@ -18,17 +18,11 @@ package com.google.aiedge.gallery.ui.llmsingleturn
|
|||
|
||||
import android.util.Log
|
||||
import androidx.activity.compose.BackHandler
|
||||
import androidx.compose.animation.AnimatedVisibility
|
||||
import androidx.compose.animation.fadeIn
|
||||
import androidx.compose.animation.fadeOut
|
||||
import androidx.compose.animation.scaleIn
|
||||
import androidx.compose.animation.scaleOut
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.calculateStartPadding
|
||||
import androidx.compose.foundation.layout.fillMaxSize
|
||||
import androidx.compose.foundation.layout.offset
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.Scaffold
|
||||
|
@ -42,6 +36,7 @@ import androidx.compose.runtime.rememberCoroutineScope
|
|||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.alpha
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.platform.LocalLayoutDirection
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
|
@ -50,8 +45,6 @@ import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
|||
import com.google.aiedge.gallery.ui.ViewModelProvider
|
||||
import com.google.aiedge.gallery.ui.common.ModelPageAppBar
|
||||
import com.google.aiedge.gallery.ui.common.chat.ModelDownloadStatusInfoPanel
|
||||
import com.google.aiedge.gallery.ui.common.chat.ModelInitializationStatusChip
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.preview.PreviewLlmSingleTurnViewModel
|
||||
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||
|
@ -116,9 +109,6 @@ fun LlmSingleTurnScreen(
|
|||
}
|
||||
}
|
||||
|
||||
val modelInitializationStatus =
|
||||
modelManagerUiState.modelInitializationStatus[selectedModel.name]
|
||||
|
||||
Scaffold(modifier = modifier, topBar = {
|
||||
ModelPageAppBar(
|
||||
task = task,
|
||||
|
@ -151,49 +141,42 @@ fun LlmSingleTurnScreen(
|
|||
)
|
||||
|
||||
// Main UI after model is downloaded.
|
||||
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
|
||||
Box(
|
||||
contentAlignment = Alignment.BottomCenter,
|
||||
modifier = Modifier.weight(1f)
|
||||
) {
|
||||
VerticalSplitView(modifier = Modifier.fillMaxSize(),
|
||||
topView = {
|
||||
PromptTemplatesPanel(
|
||||
val modelDownloaded = curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED
|
||||
Box(
|
||||
contentAlignment = Alignment.BottomCenter,
|
||||
modifier = Modifier
|
||||
.weight(1f)
|
||||
// Just hide the UI without removing it from the screen so that the scroll syncing
|
||||
// from ResponsePanel still works.
|
||||
.alpha(if (modelDownloaded) 1.0f else 0.0f)
|
||||
) {
|
||||
VerticalSplitView(modifier = Modifier.fillMaxSize(),
|
||||
topView = {
|
||||
PromptTemplatesPanel(
|
||||
model = selectedModel,
|
||||
viewModel = viewModel,
|
||||
onSend = { fullPrompt ->
|
||||
viewModel.generateResponse(model = selectedModel, input = fullPrompt)
|
||||
}, modifier = Modifier.fillMaxSize()
|
||||
)
|
||||
},
|
||||
bottomView = {
|
||||
Box(
|
||||
contentAlignment = Alignment.BottomCenter,
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.background(MaterialTheme.customColors.agentBubbleBgColor)
|
||||
) {
|
||||
ResponsePanel(
|
||||
model = selectedModel,
|
||||
viewModel = viewModel,
|
||||
onSend = { fullPrompt ->
|
||||
viewModel.generateResponse(model = selectedModel, input = fullPrompt)
|
||||
}, modifier = Modifier.fillMaxSize()
|
||||
)
|
||||
},
|
||||
bottomView = {
|
||||
Box(
|
||||
contentAlignment = Alignment.BottomCenter,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.background(MaterialTheme.customColors.agentBubbleBgColor)
|
||||
) {
|
||||
ResponsePanel(
|
||||
model = selectedModel,
|
||||
viewModel = viewModel,
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.padding(bottom = innerPadding.calculateBottomPadding())
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
// Model initialization in-progress message.
|
||||
this@Column.AnimatedVisibility(
|
||||
visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
|
||||
enter = scaleIn() + fadeIn(),
|
||||
exit = scaleOut() + fadeOut(),
|
||||
modifier = Modifier.offset(y = -innerPadding.calculateBottomPadding())
|
||||
) {
|
||||
ModelInitializationStatusChip()
|
||||
}
|
||||
|
||||
}
|
||||
.padding(bottom = innerPadding.calculateBottomPadding())
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package com.google.aiedge.gallery.ui.llmsingleturn
|
||||
|
||||
import android.util.Log
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
|
@ -24,6 +25,8 @@ import androidx.compose.foundation.layout.fillMaxSize
|
|||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.layout.size
|
||||
import androidx.compose.foundation.pager.HorizontalPager
|
||||
import androidx.compose.foundation.pager.rememberPagerState
|
||||
import androidx.compose.foundation.rememberScrollState
|
||||
import androidx.compose.foundation.verticalScroll
|
||||
import androidx.compose.material.icons.Icons
|
||||
|
@ -45,162 +48,195 @@ import androidx.compose.runtime.getValue
|
|||
import androidx.compose.runtime.mutableIntStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.runtime.snapshotFlow
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.alpha
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.platform.LocalClipboardManager
|
||||
import androidx.compose.ui.text.AnnotatedString
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
import androidx.compose.ui.unit.dp
|
||||
import com.google.aiedge.gallery.data.Model
|
||||
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
|
||||
import com.google.aiedge.gallery.ui.common.chat.MarkdownText
|
||||
import com.google.aiedge.gallery.ui.common.chat.MessageBodyBenchmarkLlm
|
||||
import com.google.aiedge.gallery.ui.common.chat.MessageBodyLoading
|
||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
import com.google.aiedge.gallery.ui.modelmanager.PagerScrollState
|
||||
|
||||
private val OPTIONS = listOf("Response", "Benchmark")
|
||||
private val ICONS = listOf(Icons.Outlined.AutoAwesome, Icons.Outlined.Timer)
|
||||
private const val TAG = "AGResponsePanel"
|
||||
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@Composable
|
||||
fun ResponsePanel(
|
||||
model: Model,
|
||||
viewModel: LlmSingleTurnViewModel,
|
||||
modelManagerViewModel: ModelManagerViewModel,
|
||||
modifier: Modifier = Modifier,
|
||||
) {
|
||||
val task = TASK_LLM_SINGLE_TURN
|
||||
val uiState by viewModel.uiState.collectAsState()
|
||||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||
val inProgress = uiState.inProgress
|
||||
val initializing = uiState.initializing
|
||||
val selectedPromptTemplateType = uiState.selectedPromptTemplateType
|
||||
val response = uiState.responsesByModel[model.name]?.get(selectedPromptTemplateType.label) ?: ""
|
||||
val benchmark = uiState.benchmarkByModel[model.name]?.get(selectedPromptTemplateType.label)
|
||||
val responseScrollState = rememberScrollState()
|
||||
var selectedOptionIndex by remember { mutableIntStateOf(0) }
|
||||
val clipboardManager = LocalClipboardManager.current
|
||||
|
||||
// Scroll to bottom when response changes.
|
||||
LaunchedEffect(response) {
|
||||
if (inProgress) {
|
||||
responseScrollState.animateScrollTo(responseScrollState.maxValue)
|
||||
}
|
||||
}
|
||||
val pagerState = rememberPagerState(
|
||||
initialPage = task.models.indexOf(model),
|
||||
pageCount = { task.models.size })
|
||||
|
||||
// Select the "response" tab when prompt template changes.
|
||||
LaunchedEffect(selectedPromptTemplateType) {
|
||||
selectedOptionIndex = 0
|
||||
}
|
||||
|
||||
if (initializing) {
|
||||
Box(
|
||||
contentAlignment = Alignment.TopStart,
|
||||
modifier = modifier
|
||||
.fillMaxSize()
|
||||
.padding(horizontal = 16.dp)
|
||||
) {
|
||||
MessageBodyLoading()
|
||||
// Update selected model and clean up previous model when page is settled on a model page.
|
||||
LaunchedEffect(pagerState.settledPage) {
|
||||
val curSelectedModel = task.models[pagerState.settledPage]
|
||||
Log.d(
|
||||
TAG,
|
||||
"Pager settled on model '${curSelectedModel.name}' from '${model.name}'. Updating selected model."
|
||||
)
|
||||
if (curSelectedModel.name != model.name) {
|
||||
modelManagerViewModel.cleanupModel(task = task, model = model)
|
||||
}
|
||||
} else {
|
||||
// Message when response is empty.
|
||||
if (response.isEmpty()) {
|
||||
Row(
|
||||
modifier = Modifier.fillMaxSize(),
|
||||
horizontalArrangement = Arrangement.Center,
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
Text(
|
||||
"Response will appear here",
|
||||
modifier = Modifier.alpha(0.5f),
|
||||
style = MaterialTheme.typography.labelMedium,
|
||||
)
|
||||
modelManagerViewModel.selectModel(curSelectedModel)
|
||||
}
|
||||
|
||||
// Trigger scroll sync.
|
||||
LaunchedEffect(pagerState) {
|
||||
snapshotFlow {
|
||||
PagerScrollState(
|
||||
page = pagerState.currentPage,
|
||||
offset = pagerState.currentPageOffsetFraction
|
||||
)
|
||||
}.collect { scrollState ->
|
||||
modelManagerViewModel.pagerScrollState.value = scrollState
|
||||
}
|
||||
}
|
||||
|
||||
// Scroll pager when selected model changes.
|
||||
LaunchedEffect(modelManagerUiState.selectedModel) {
|
||||
pagerState.animateScrollToPage(task.models.indexOf(model))
|
||||
}
|
||||
|
||||
HorizontalPager(state = pagerState) { pageIndex ->
|
||||
val curPageModel = task.models[pageIndex]
|
||||
|
||||
val response =
|
||||
uiState.responsesByModel[curPageModel.name]?.get(selectedPromptTemplateType.label) ?: ""
|
||||
val benchmark =
|
||||
uiState.benchmarkByModel[curPageModel.name]?.get(selectedPromptTemplateType.label)
|
||||
|
||||
// Scroll to bottom when response changes.
|
||||
LaunchedEffect(response) {
|
||||
if (inProgress) {
|
||||
responseScrollState.animateScrollTo(responseScrollState.maxValue)
|
||||
}
|
||||
}
|
||||
// Response markdown.
|
||||
else {
|
||||
Column(
|
||||
|
||||
if (initializing) {
|
||||
Box(
|
||||
contentAlignment = Alignment.TopStart,
|
||||
modifier = modifier
|
||||
.fillMaxSize()
|
||||
.padding(horizontal = 16.dp)
|
||||
.padding(bottom = 4.dp)
|
||||
) {
|
||||
// Response/benchmark switch.
|
||||
Row(modifier = Modifier.fillMaxWidth()) {
|
||||
PrimaryTabRow(
|
||||
selectedTabIndex = selectedOptionIndex,
|
||||
containerColor = Color.Transparent,
|
||||
) {
|
||||
OPTIONS.forEachIndexed { index, title ->
|
||||
Tab(selected = selectedOptionIndex == index, onClick = {
|
||||
selectedOptionIndex = index
|
||||
}, text = {
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(4.dp)
|
||||
) {
|
||||
Icon(
|
||||
ICONS[index],
|
||||
contentDescription = "",
|
||||
modifier = Modifier
|
||||
.size(16.dp)
|
||||
.alpha(0.7f)
|
||||
)
|
||||
Text(text = title)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
MessageBodyLoading()
|
||||
}
|
||||
} else {
|
||||
// Message when response is empty.
|
||||
if (response.isEmpty()) {
|
||||
Row(
|
||||
modifier = Modifier.fillMaxSize(),
|
||||
horizontalArrangement = Arrangement.Center,
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
Text(
|
||||
"Response will appear here",
|
||||
modifier = Modifier.alpha(0.5f),
|
||||
style = MaterialTheme.typography.labelMedium,
|
||||
)
|
||||
}
|
||||
if (selectedOptionIndex == 0) {
|
||||
Box(
|
||||
contentAlignment = Alignment.BottomEnd,
|
||||
modifier = Modifier.weight(1f)
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.verticalScroll(responseScrollState)
|
||||
}
|
||||
// Response markdown.
|
||||
else {
|
||||
Column(
|
||||
modifier = modifier
|
||||
.padding(horizontal = 16.dp)
|
||||
.padding(bottom = 4.dp)
|
||||
) {
|
||||
// Response/benchmark switch.
|
||||
Row(modifier = Modifier.fillMaxWidth()) {
|
||||
PrimaryTabRow(
|
||||
selectedTabIndex = selectedOptionIndex,
|
||||
containerColor = Color.Transparent,
|
||||
) {
|
||||
MarkdownText(
|
||||
text = response,
|
||||
modifier = Modifier.padding(top = 8.dp, bottom = 40.dp)
|
||||
)
|
||||
}
|
||||
// Copy button.
|
||||
IconButton(
|
||||
onClick = {
|
||||
val clipData = AnnotatedString(response)
|
||||
clipboardManager.setText(clipData)
|
||||
},
|
||||
colors = IconButtonDefaults.iconButtonColors(
|
||||
containerColor = MaterialTheme.colorScheme.surfaceContainerHighest,
|
||||
contentColor = MaterialTheme.colorScheme.primary,
|
||||
),
|
||||
) {
|
||||
Icon(
|
||||
Icons.Outlined.ContentCopy,
|
||||
contentDescription = "",
|
||||
modifier = Modifier.size(20.dp),
|
||||
)
|
||||
OPTIONS.forEachIndexed { index, title ->
|
||||
Tab(selected = selectedOptionIndex == index, onClick = {
|
||||
selectedOptionIndex = index
|
||||
}, text = {
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(4.dp)
|
||||
) {
|
||||
Icon(
|
||||
ICONS[index],
|
||||
contentDescription = "",
|
||||
modifier = Modifier
|
||||
.size(16.dp)
|
||||
.alpha(0.7f)
|
||||
)
|
||||
Text(text = title)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (selectedOptionIndex == 1) {
|
||||
if (benchmark != null) {
|
||||
MessageBodyBenchmarkLlm(message = benchmark, modifier = Modifier.fillMaxWidth())
|
||||
if (selectedOptionIndex == 0) {
|
||||
Box(
|
||||
contentAlignment = Alignment.BottomEnd,
|
||||
modifier = Modifier.weight(1f)
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.verticalScroll(responseScrollState)
|
||||
) {
|
||||
MarkdownText(
|
||||
text = response,
|
||||
modifier = Modifier.padding(top = 8.dp, bottom = 40.dp)
|
||||
)
|
||||
}
|
||||
// Copy button.
|
||||
IconButton(
|
||||
onClick = {
|
||||
val clipData = AnnotatedString(response)
|
||||
clipboardManager.setText(clipData)
|
||||
},
|
||||
colors = IconButtonDefaults.iconButtonColors(
|
||||
containerColor = MaterialTheme.colorScheme.surfaceContainerHighest,
|
||||
contentColor = MaterialTheme.colorScheme.primary,
|
||||
),
|
||||
) {
|
||||
Icon(
|
||||
Icons.Outlined.ContentCopy,
|
||||
contentDescription = "",
|
||||
modifier = Modifier.size(20.dp),
|
||||
)
|
||||
}
|
||||
}
|
||||
} else if (selectedOptionIndex == 1) {
|
||||
if (benchmark != null) {
|
||||
MessageBodyBenchmarkLlm(message = benchmark, modifier = Modifier.fillMaxWidth())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Preview(showBackground = true)
|
||||
@Composable
|
||||
fun ResponsePanelPreview() {
|
||||
GalleryTheme {
|
||||
ResponsePanel(
|
||||
model = TASK_LLM_SINGLE_TURN.models[0],
|
||||
viewModel = LlmSingleTurnViewModel(),
|
||||
modifier = Modifier.fillMaxSize()
|
||||
)
|
||||
}
|
||||
}
|
|
@ -133,6 +133,11 @@ data class ModelManagerUiState(
|
|||
val textInputHistory: List<String> = listOf(),
|
||||
)
|
||||
|
||||
data class PagerScrollState(
|
||||
val page: Int = 0,
|
||||
val offset: Float = 0f,
|
||||
)
|
||||
|
||||
/**
|
||||
* ViewModel responsible for managing models, their download status, and initialization.
|
||||
*
|
||||
|
@ -150,9 +155,12 @@ open class ModelManagerViewModel(
|
|||
downloadRepository.getEnqueuedOrRunningWorkInfos()
|
||||
protected val _uiState = MutableStateFlow(createUiState())
|
||||
val uiState = _uiState.asStateFlow()
|
||||
|
||||
val authService = AuthorizationService(context)
|
||||
var curAccessToken: String = ""
|
||||
|
||||
var pagerScrollState: MutableStateFlow<PagerScrollState> = MutableStateFlow(PagerScrollState())
|
||||
|
||||
init {
|
||||
Log.d(TAG, "In-progress worker infos: $inProgressWorkInfos")
|
||||
|
||||
|
@ -269,6 +277,7 @@ open class ModelManagerViewModel(
|
|||
|
||||
// Skip if initialization is in progress.
|
||||
if (model.initializing) {
|
||||
model.cleanUpAfterInit = false
|
||||
Log.d(TAG, "Model '${model.name}' is being initialized. Skipping.")
|
||||
return@launch
|
||||
}
|
||||
|
@ -299,6 +308,10 @@ open class ModelManagerViewModel(
|
|||
model = model,
|
||||
status = ModelInitializationStatusType.INITIALIZED,
|
||||
)
|
||||
if (model.cleanUpAfterInit) {
|
||||
Log.d(TAG, "Model '${model.name}' needs cleaning up after init.")
|
||||
cleanupModel(task = task, model = model)
|
||||
}
|
||||
} else if (error.isNotEmpty()) {
|
||||
Log.d(TAG, "Model '${model.name}' failed to initialize")
|
||||
updateModelInitializationStatus(
|
||||
|
@ -345,6 +358,7 @@ open class ModelManagerViewModel(
|
|||
|
||||
fun cleanupModel(task: Task, model: Model) {
|
||||
if (model.instance != null) {
|
||||
model.cleanUpAfterInit = false
|
||||
Log.d(TAG, "Cleaning up model '${model.name}'...")
|
||||
when (task.type) {
|
||||
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model)
|
||||
|
@ -360,6 +374,12 @@ open class ModelManagerViewModel(
|
|||
updateModelInitializationStatus(
|
||||
model = model, status = ModelInitializationStatusType.NOT_INITIALIZED
|
||||
)
|
||||
} else {
|
||||
// When model is being initialized and we are trying to clean it up at same time, we mark it
|
||||
// to clean up and it will be cleaned up after initialization is done.
|
||||
if (model.initializing) {
|
||||
model.cleanUpAfterInit = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue