Better support for sliding to change model

This commit is contained in:
Jing Jin 2025-04-27 18:59:01 -07:00
parent 3341286efa
commit d0beaab31e
8 changed files with 433 additions and 237 deletions

View file

@ -90,6 +90,8 @@ data class Model(
// The following fields are managed by the app. Don't need to set manually. // The following fields are managed by the app. Don't need to set manually.
var instance: Any? = null, var instance: Any? = null,
var initializing: Boolean = false, var initializing: Boolean = false,
// TODO(jingjin): use a "queue" system to manage model init and cleanup.
var cleanUpAfterInit: Boolean = false,
var configValues: Map<String, Any> = mapOf(), var configValues: Map<String, Any> = mapOf(),
var totalBytes: Long = 0L, var totalBytes: Long = 0L,
var accessToken: String? = null, var accessToken: String? = null,

View file

@ -16,26 +16,19 @@
package com.google.aiedge.gallery.ui.common package com.google.aiedge.gallery.ui.common
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size import androidx.compose.foundation.layout.size
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.ArrowBack import androidx.compose.material.icons.automirrored.rounded.ArrowBack
import androidx.compose.material.icons.rounded.ArrowDropDown
import androidx.compose.material.icons.rounded.Settings import androidx.compose.material.icons.rounded.Settings
import androidx.compose.material3.CenterAlignedTopAppBar import androidx.compose.material3.CenterAlignedTopAppBar
import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ModalBottomSheet
import androidx.compose.material3.Text import androidx.compose.material3.Text
import androidx.compose.material3.rememberModalBottomSheetState
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
@ -44,7 +37,7 @@ import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip import androidx.compose.ui.draw.alpha
import androidx.compose.ui.graphics.vector.ImageVector import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.vectorResource import androidx.compose.ui.res.vectorResource
@ -54,7 +47,6 @@ import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatusType import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.chat.ConfigDialog import com.google.aiedge.gallery.ui.common.chat.ConfigDialog
import com.google.aiedge.gallery.ui.common.modelitem.StatusIcon
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
@OptIn(ExperimentalMaterial3Api::class) @OptIn(ExperimentalMaterial3Api::class)
@ -69,14 +61,15 @@ fun ModelPageAppBar(
onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit = { _, _ -> }, onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit = { _, _ -> },
) { ) {
var showConfigDialog by remember { mutableStateOf(false) } var showConfigDialog by remember { mutableStateOf(false) }
var showModelPicker by remember { mutableStateOf(false) }
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
val context = LocalContext.current val context = LocalContext.current
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[model.name] val curDownloadStatus = modelManagerUiState.modelDownloadStatus[model.name]
CenterAlignedTopAppBar(title = { CenterAlignedTopAppBar(title = {
Column(horizontalAlignment = Alignment.CenterHorizontally) { Column(
horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.spacedBy(4.dp)
) {
// Task type. // Task type.
Row( Row(
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
@ -95,29 +88,13 @@ fun ModelPageAppBar(
) )
} }
// Model name. // Model chips pager.
Row(verticalAlignment = Alignment.CenterVertically, ModelPickerChipsPager(
horizontalArrangement = Arrangement.spacedBy(2.dp), task = task,
modifier = Modifier initialModel = model,
.clip(CircleShape) modelManagerViewModel = modelManagerViewModel,
.background(MaterialTheme.colorScheme.surfaceContainerHigh) onModelSelected = onModelSelected,
.clickable { )
showModelPicker = true
}
.padding(start = 8.dp, end = 2.dp)) {
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
Text(
model.name,
style = MaterialTheme.typography.labelSmall,
modifier = Modifier.padding(start = 4.dp),
)
Icon(
Icons.Rounded.ArrowDropDown,
modifier = Modifier.size(20.dp),
contentDescription = "",
)
}
} }
}, modifier = modifier, }, modifier = modifier,
// The back button. // The back button.
@ -131,14 +108,20 @@ fun ModelPageAppBar(
}, },
// The config button for the model (if existed). // The config button for the model (if existed).
actions = { actions = {
if (model.configs.isNotEmpty() && curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) { val showConfigButton =
IconButton(onClick = { showConfigDialog = true }) { model.configs.isNotEmpty() && curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED
Icon( IconButton(
imageVector = Icons.Rounded.Settings, onClick = {
contentDescription = "", showConfigDialog = true
tint = MaterialTheme.colorScheme.primary },
) enabled = showConfigButton,
} modifier = Modifier.alpha(if (showConfigButton) 1f else 0f)
) {
Icon(
imageVector = Icons.Rounded.Settings,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary
)
} }
}) })
@ -193,21 +176,4 @@ fun ModelPageAppBar(
}, },
) )
} }
// Model picker.
if (showModelPicker) {
ModalBottomSheet(
onDismissRequest = { showModelPicker = false },
sheetState = sheetState,
) {
ModelPicker(
task = task,
modelManagerViewModel = modelManagerViewModel,
onModelSelected = { model ->
showModelPicker = false
onModelSelected(model)
}
)
}
}
} }

View file

@ -0,0 +1,178 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common
import androidx.compose.animation.AnimatedVisibility
import androidx.compose.animation.fadeIn
import androidx.compose.animation.fadeOut
import androidx.compose.animation.scaleIn
import androidx.compose.animation.scaleOut
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.pager.HorizontalPager
import androidx.compose.foundation.pager.rememberPagerState
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.ArrowDropDown
import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ModalBottomSheet
import androidx.compose.material3.Text
import androidx.compose.material3.rememberModalBottomSheetState
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.modelitem.StatusIcon
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlin.math.absoluteValue
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun ModelPickerChipsPager(
task: Task,
initialModel: Model,
modelManagerViewModel: ModelManagerViewModel,
onModelSelected: (Model) -> Unit,
) {
var showModelPicker by remember { mutableStateOf(false) }
var modelPickerModel by remember { mutableStateOf<Model?>(null) }
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
val scope = rememberCoroutineScope()
val pagerState = rememberPagerState(initialPage = task.models.indexOf(initialModel),
pageCount = { task.models.size })
// Sync scrolling.
LaunchedEffect(modelManagerViewModel.pagerScrollState) {
modelManagerViewModel.pagerScrollState.collect { state ->
pagerState.scrollToPage(state.page, state.offset)
}
}
HorizontalPager(state = pagerState, userScrollEnabled = false) { pageIndex ->
val model = task.models[pageIndex]
// Calculate the alpha of the current page based on how far they are from the center.
val pageOffset =
((pagerState.currentPage - pageIndex) + pagerState.currentPageOffsetFraction).absoluteValue
val curAlpha = 1f - (pageOffset * 1.5f).coerceIn(0f, 1f)
val modelInitializationStatus =
modelManagerUiState.modelInitializationStatus[model.name]
Box(
modifier = Modifier
.fillMaxWidth()
.graphicsLayer { alpha = curAlpha },
contentAlignment = Alignment.Center
) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(2.dp)
) {
Row(verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(2.dp),
modifier = Modifier
.clip(CircleShape)
.background(MaterialTheme.colorScheme.surfaceContainerHigh)
.clickable {
modelPickerModel = model
showModelPicker = true
}
.padding(start = 8.dp, end = 2.dp)
.padding(vertical = 1.dp)) Inner@{
Box(contentAlignment = Alignment.Center, modifier = Modifier.size(18.dp)) {
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
this@Inner.AnimatedVisibility(
visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
enter = scaleIn() + fadeIn(),
exit = scaleOut() + fadeOut(),
) {
// Circular progress indicator.
CircularProgressIndicator(
modifier = Modifier
.size(17.dp)
.alpha(0.5f),
strokeWidth = 2.dp,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
Text(
model.name,
style = MaterialTheme.typography.labelSmall,
modifier = Modifier.padding(start = 4.dp),
)
Icon(
Icons.Rounded.ArrowDropDown,
modifier = Modifier.size(20.dp),
contentDescription = "",
)
}
}
}
}
// Model picker.
val curModelPickerModel = modelPickerModel
if (showModelPicker && curModelPickerModel != null) {
ModalBottomSheet(
onDismissRequest = { showModelPicker = false },
sheetState = sheetState,
) {
ModelPicker(
task = task,
modelManagerViewModel = modelManagerViewModel,
onModelSelected = { selectedModel ->
showModelPicker = false
scope.launch(Dispatchers.Default) {
// Scroll to the selected model.
pagerState.animateScrollToPage(task.models.indexOf(selectedModel))
}
onModelSelected(selectedModel)
}
)
}
}
}

View file

@ -16,11 +16,6 @@
package com.google.aiedge.gallery.ui.common.chat package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.animation.AnimatedVisibility
import androidx.compose.animation.fadeIn
import androidx.compose.animation.fadeOut
import androidx.compose.animation.scaleIn
import androidx.compose.animation.scaleOut
import androidx.compose.foundation.background import androidx.compose.foundation.background
import androidx.compose.foundation.clickable import androidx.compose.foundation.clickable
import androidx.compose.foundation.gestures.detectTapGestures import androidx.compose.foundation.gestures.detectTapGestures
@ -33,7 +28,6 @@ import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.ime import androidx.compose.foundation.layout.ime
import androidx.compose.foundation.layout.imePadding import androidx.compose.foundation.layout.imePadding
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.width import androidx.compose.foundation.layout.width
@ -426,16 +420,6 @@ fun ChatPanel(
} }
} }
// Model initialization in-progress message.
this@Column.AnimatedVisibility(
visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
enter = scaleIn() + fadeIn(),
exit = scaleOut() + fadeOut(),
modifier = Modifier.offset(y = 12.dp)
) {
ModelInitializationStatusChip()
}
SnackbarHost(hostState = snackbarHostState, modifier = Modifier.padding(vertical = 4.dp)) SnackbarHost(hostState = snackbarHostState, modifier = Modifier.padding(vertical = 4.dp))
} }

View file

@ -25,14 +25,17 @@ import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.pager.HorizontalPager import androidx.compose.foundation.pager.HorizontalPager
import androidx.compose.foundation.pager.rememberPagerState import androidx.compose.foundation.pager.rememberPagerState
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Scaffold import androidx.compose.material3.Scaffold
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.runtime.snapshotFlow
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.graphicsLayer import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
@ -42,6 +45,7 @@ import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.ModelPageAppBar import com.google.aiedge.gallery.ui.common.ModelPageAppBar
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.modelmanager.PagerScrollState
import com.google.aiedge.gallery.ui.preview.PreviewChatModel import com.google.aiedge.gallery.ui.preview.PreviewChatModel
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.TASK_TEST1 import com.google.aiedge.gallery.ui.preview.TASK_TEST1
@ -84,8 +88,10 @@ fun ChatView(
pageCount = { task.models.size }) pageCount = { task.models.size })
val context = LocalContext.current val context = LocalContext.current
val scope = rememberCoroutineScope() val scope = rememberCoroutineScope()
var navigatingUp by remember { mutableStateOf(false) }
val handleNavigateUp = { val handleNavigateUp = {
navigatingUp = true
navigateUp() navigateUp()
// clean up all models. // clean up all models.
@ -99,9 +105,11 @@ fun ChatView(
// Initialize model when model/download state changes. // Initialize model when model/download state changes.
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name] val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name]
LaunchedEffect(curDownloadStatus, selectedModel.name) { LaunchedEffect(curDownloadStatus, selectedModel.name) {
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) { if (!navigatingUp) {
Log.d(TAG, "Initializing model '${selectedModel.name}' from ChatView launched effect") if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
modelManagerViewModel.initializeModel(context, task = task, model = selectedModel) Log.d(TAG, "Initializing model '${selectedModel.name}' from ChatView launched effect")
modelManagerViewModel.initializeModel(context, task = task, model = selectedModel)
}
} }
} }
@ -118,6 +126,25 @@ fun ChatView(
modelManagerViewModel.selectModel(curSelectedModel) modelManagerViewModel.selectModel(curSelectedModel)
} }
LaunchedEffect(pagerState) {
// Collect from the a snapshotFlow reading the currentPage
snapshotFlow { pagerState.currentPage }.collect { page ->
Log.d(TAG, "Page changed to $page")
}
}
// Trigger scroll sync.
LaunchedEffect(pagerState) {
snapshotFlow {
PagerScrollState(
page = pagerState.currentPage,
offset = pagerState.currentPageOffsetFraction
)
}.collect { scrollState ->
modelManagerViewModel.pagerScrollState.value = scrollState
}
}
// Handle system's edge swipe. // Handle system's edge swipe.
BackHandler { BackHandler {
handleNavigateUp() handleNavigateUp()

View file

@ -18,17 +18,11 @@ package com.google.aiedge.gallery.ui.llmsingleturn
import android.util.Log import android.util.Log
import androidx.activity.compose.BackHandler import androidx.activity.compose.BackHandler
import androidx.compose.animation.AnimatedVisibility
import androidx.compose.animation.fadeIn
import androidx.compose.animation.fadeOut
import androidx.compose.animation.scaleIn
import androidx.compose.animation.scaleOut
import androidx.compose.foundation.background import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.calculateStartPadding import androidx.compose.foundation.layout.calculateStartPadding
import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Scaffold import androidx.compose.material3.Scaffold
@ -42,6 +36,7 @@ import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalLayoutDirection import androidx.compose.ui.platform.LocalLayoutDirection
import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.tooling.preview.Preview
@ -50,8 +45,6 @@ import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.ui.ViewModelProvider import com.google.aiedge.gallery.ui.ViewModelProvider
import com.google.aiedge.gallery.ui.common.ModelPageAppBar import com.google.aiedge.gallery.ui.common.ModelPageAppBar
import com.google.aiedge.gallery.ui.common.chat.ModelDownloadStatusInfoPanel import com.google.aiedge.gallery.ui.common.chat.ModelDownloadStatusInfoPanel
import com.google.aiedge.gallery.ui.common.chat.ModelInitializationStatusChip
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.PreviewLlmSingleTurnViewModel import com.google.aiedge.gallery.ui.preview.PreviewLlmSingleTurnViewModel
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
@ -116,9 +109,6 @@ fun LlmSingleTurnScreen(
} }
} }
val modelInitializationStatus =
modelManagerUiState.modelInitializationStatus[selectedModel.name]
Scaffold(modifier = modifier, topBar = { Scaffold(modifier = modifier, topBar = {
ModelPageAppBar( ModelPageAppBar(
task = task, task = task,
@ -151,49 +141,42 @@ fun LlmSingleTurnScreen(
) )
// Main UI after model is downloaded. // Main UI after model is downloaded.
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) { val modelDownloaded = curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED
Box( Box(
contentAlignment = Alignment.BottomCenter, contentAlignment = Alignment.BottomCenter,
modifier = Modifier.weight(1f) modifier = Modifier
) { .weight(1f)
VerticalSplitView(modifier = Modifier.fillMaxSize(), // Just hide the UI without removing it from the screen so that the scroll syncing
topView = { // from ResponsePanel still works.
PromptTemplatesPanel( .alpha(if (modelDownloaded) 1.0f else 0.0f)
) {
VerticalSplitView(modifier = Modifier.fillMaxSize(),
topView = {
PromptTemplatesPanel(
model = selectedModel,
viewModel = viewModel,
onSend = { fullPrompt ->
viewModel.generateResponse(model = selectedModel, input = fullPrompt)
}, modifier = Modifier.fillMaxSize()
)
},
bottomView = {
Box(
contentAlignment = Alignment.BottomCenter,
modifier = Modifier
.fillMaxSize()
.background(MaterialTheme.customColors.agentBubbleBgColor)
) {
ResponsePanel(
model = selectedModel, model = selectedModel,
viewModel = viewModel, viewModel = viewModel,
onSend = { fullPrompt -> modelManagerViewModel = modelManagerViewModel,
viewModel.generateResponse(model = selectedModel, input = fullPrompt)
}, modifier = Modifier.fillMaxSize()
)
},
bottomView = {
Box(
contentAlignment = Alignment.BottomCenter,
modifier = Modifier modifier = Modifier
.fillMaxSize() .fillMaxSize()
.background(MaterialTheme.customColors.agentBubbleBgColor) .padding(bottom = innerPadding.calculateBottomPadding())
) { )
ResponsePanel( }
model = selectedModel, })
viewModel = viewModel,
modifier = Modifier
.fillMaxSize()
.padding(bottom = innerPadding.calculateBottomPadding())
)
}
})
// Model initialization in-progress message.
this@Column.AnimatedVisibility(
visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
enter = scaleIn() + fadeIn(),
exit = scaleOut() + fadeOut(),
modifier = Modifier.offset(y = -innerPadding.calculateBottomPadding())
) {
ModelInitializationStatusChip()
}
}
} }
} }
} }

View file

@ -16,6 +16,7 @@
package com.google.aiedge.gallery.ui.llmsingleturn package com.google.aiedge.gallery.ui.llmsingleturn
import android.util.Log
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
@ -24,6 +25,8 @@ import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size import androidx.compose.foundation.layout.size
import androidx.compose.foundation.pager.HorizontalPager
import androidx.compose.foundation.pager.rememberPagerState
import androidx.compose.foundation.rememberScrollState import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.verticalScroll import androidx.compose.foundation.verticalScroll
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
@ -45,162 +48,195 @@ import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableIntStateOf import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.remember import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue import androidx.compose.runtime.setValue
import androidx.compose.runtime.snapshotFlow
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha import androidx.compose.ui.draw.alpha
import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.Color
import androidx.compose.ui.platform.LocalClipboardManager import androidx.compose.ui.platform.LocalClipboardManager
import androidx.compose.ui.text.AnnotatedString import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
import com.google.aiedge.gallery.ui.common.chat.MarkdownText import com.google.aiedge.gallery.ui.common.chat.MarkdownText
import com.google.aiedge.gallery.ui.common.chat.MessageBodyBenchmarkLlm import com.google.aiedge.gallery.ui.common.chat.MessageBodyBenchmarkLlm
import com.google.aiedge.gallery.ui.common.chat.MessageBodyLoading import com.google.aiedge.gallery.ui.common.chat.MessageBodyLoading
import com.google.aiedge.gallery.ui.theme.GalleryTheme import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.modelmanager.PagerScrollState
private val OPTIONS = listOf("Response", "Benchmark") private val OPTIONS = listOf("Response", "Benchmark")
private val ICONS = listOf(Icons.Outlined.AutoAwesome, Icons.Outlined.Timer) private val ICONS = listOf(Icons.Outlined.AutoAwesome, Icons.Outlined.Timer)
private const val TAG = "AGResponsePanel"
@OptIn(ExperimentalMaterial3Api::class) @OptIn(ExperimentalMaterial3Api::class)
@Composable @Composable
fun ResponsePanel( fun ResponsePanel(
model: Model, model: Model,
viewModel: LlmSingleTurnViewModel, viewModel: LlmSingleTurnViewModel,
modelManagerViewModel: ModelManagerViewModel,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
) { ) {
val task = TASK_LLM_SINGLE_TURN
val uiState by viewModel.uiState.collectAsState() val uiState by viewModel.uiState.collectAsState()
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val inProgress = uiState.inProgress val inProgress = uiState.inProgress
val initializing = uiState.initializing val initializing = uiState.initializing
val selectedPromptTemplateType = uiState.selectedPromptTemplateType val selectedPromptTemplateType = uiState.selectedPromptTemplateType
val response = uiState.responsesByModel[model.name]?.get(selectedPromptTemplateType.label) ?: ""
val benchmark = uiState.benchmarkByModel[model.name]?.get(selectedPromptTemplateType.label)
val responseScrollState = rememberScrollState() val responseScrollState = rememberScrollState()
var selectedOptionIndex by remember { mutableIntStateOf(0) } var selectedOptionIndex by remember { mutableIntStateOf(0) }
val clipboardManager = LocalClipboardManager.current val clipboardManager = LocalClipboardManager.current
val pagerState = rememberPagerState(
// Scroll to bottom when response changes. initialPage = task.models.indexOf(model),
LaunchedEffect(response) { pageCount = { task.models.size })
if (inProgress) {
responseScrollState.animateScrollTo(responseScrollState.maxValue)
}
}
// Select the "response" tab when prompt template changes. // Select the "response" tab when prompt template changes.
LaunchedEffect(selectedPromptTemplateType) { LaunchedEffect(selectedPromptTemplateType) {
selectedOptionIndex = 0 selectedOptionIndex = 0
} }
if (initializing) { // Update selected model and clean up previous model when page is settled on a model page.
Box( LaunchedEffect(pagerState.settledPage) {
contentAlignment = Alignment.TopStart, val curSelectedModel = task.models[pagerState.settledPage]
modifier = modifier Log.d(
.fillMaxSize() TAG,
.padding(horizontal = 16.dp) "Pager settled on model '${curSelectedModel.name}' from '${model.name}'. Updating selected model."
) {
MessageBodyLoading()
}
} else {
// Message when response is empty.
if (response.isEmpty()) {
Row(
modifier = Modifier.fillMaxSize(),
horizontalArrangement = Arrangement.Center,
verticalAlignment = Alignment.CenterVertically
) {
Text(
"Response will appear here",
modifier = Modifier.alpha(0.5f),
style = MaterialTheme.typography.labelMedium,
)
}
}
// Response markdown.
else {
Column(
modifier = modifier
.padding(horizontal = 16.dp)
.padding(bottom = 4.dp)
) {
// Response/benchmark switch.
Row(modifier = Modifier.fillMaxWidth()) {
PrimaryTabRow(
selectedTabIndex = selectedOptionIndex,
containerColor = Color.Transparent,
) {
OPTIONS.forEachIndexed { index, title ->
Tab(selected = selectedOptionIndex == index, onClick = {
selectedOptionIndex = index
}, text = {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(4.dp)
) {
Icon(
ICONS[index],
contentDescription = "",
modifier = Modifier
.size(16.dp)
.alpha(0.7f)
)
Text(text = title)
}
})
}
}
}
if (selectedOptionIndex == 0) {
Box(
contentAlignment = Alignment.BottomEnd,
modifier = Modifier.weight(1f)
) {
Column(
modifier = Modifier
.fillMaxSize()
.verticalScroll(responseScrollState)
) {
MarkdownText(
text = response,
modifier = Modifier.padding(top = 8.dp, bottom = 40.dp)
)
}
// Copy button.
IconButton(
onClick = {
val clipData = AnnotatedString(response)
clipboardManager.setText(clipData)
},
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.surfaceContainerHighest,
contentColor = MaterialTheme.colorScheme.primary,
),
) {
Icon(
Icons.Outlined.ContentCopy,
contentDescription = "",
modifier = Modifier.size(20.dp),
)
}
}
} else if (selectedOptionIndex == 1) {
if (benchmark != null) {
MessageBodyBenchmarkLlm(message = benchmark, modifier = Modifier.fillMaxWidth())
}
}
}
}
}
}
@Preview(showBackground = true)
@Composable
fun ResponsePanelPreview() {
GalleryTheme {
ResponsePanel(
model = TASK_LLM_SINGLE_TURN.models[0],
viewModel = LlmSingleTurnViewModel(),
modifier = Modifier.fillMaxSize()
) )
if (curSelectedModel.name != model.name) {
modelManagerViewModel.cleanupModel(task = task, model = model)
}
modelManagerViewModel.selectModel(curSelectedModel)
}
// Trigger scroll sync.
LaunchedEffect(pagerState) {
snapshotFlow {
PagerScrollState(
page = pagerState.currentPage,
offset = pagerState.currentPageOffsetFraction
)
}.collect { scrollState ->
modelManagerViewModel.pagerScrollState.value = scrollState
}
}
// Scroll pager when selected model changes.
LaunchedEffect(modelManagerUiState.selectedModel) {
pagerState.animateScrollToPage(task.models.indexOf(model))
}
HorizontalPager(state = pagerState) { pageIndex ->
val curPageModel = task.models[pageIndex]
val response =
uiState.responsesByModel[curPageModel.name]?.get(selectedPromptTemplateType.label) ?: ""
val benchmark =
uiState.benchmarkByModel[curPageModel.name]?.get(selectedPromptTemplateType.label)
// Scroll to bottom when response changes.
LaunchedEffect(response) {
if (inProgress) {
responseScrollState.animateScrollTo(responseScrollState.maxValue)
}
}
if (initializing) {
Box(
contentAlignment = Alignment.TopStart,
modifier = modifier
.fillMaxSize()
.padding(horizontal = 16.dp)
) {
MessageBodyLoading()
}
} else {
// Message when response is empty.
if (response.isEmpty()) {
Row(
modifier = Modifier.fillMaxSize(),
horizontalArrangement = Arrangement.Center,
verticalAlignment = Alignment.CenterVertically
) {
Text(
"Response will appear here",
modifier = Modifier.alpha(0.5f),
style = MaterialTheme.typography.labelMedium,
)
}
}
// Response markdown.
else {
Column(
modifier = modifier
.padding(horizontal = 16.dp)
.padding(bottom = 4.dp)
) {
// Response/benchmark switch.
Row(modifier = Modifier.fillMaxWidth()) {
PrimaryTabRow(
selectedTabIndex = selectedOptionIndex,
containerColor = Color.Transparent,
) {
OPTIONS.forEachIndexed { index, title ->
Tab(selected = selectedOptionIndex == index, onClick = {
selectedOptionIndex = index
}, text = {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(4.dp)
) {
Icon(
ICONS[index],
contentDescription = "",
modifier = Modifier
.size(16.dp)
.alpha(0.7f)
)
Text(text = title)
}
})
}
}
}
if (selectedOptionIndex == 0) {
Box(
contentAlignment = Alignment.BottomEnd,
modifier = Modifier.weight(1f)
) {
Column(
modifier = Modifier
.fillMaxSize()
.verticalScroll(responseScrollState)
) {
MarkdownText(
text = response,
modifier = Modifier.padding(top = 8.dp, bottom = 40.dp)
)
}
// Copy button.
IconButton(
onClick = {
val clipData = AnnotatedString(response)
clipboardManager.setText(clipData)
},
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.surfaceContainerHighest,
contentColor = MaterialTheme.colorScheme.primary,
),
) {
Icon(
Icons.Outlined.ContentCopy,
contentDescription = "",
modifier = Modifier.size(20.dp),
)
}
}
} else if (selectedOptionIndex == 1) {
if (benchmark != null) {
MessageBodyBenchmarkLlm(message = benchmark, modifier = Modifier.fillMaxWidth())
}
}
}
}
}
} }
} }

View file

@ -133,6 +133,11 @@ data class ModelManagerUiState(
val textInputHistory: List<String> = listOf(), val textInputHistory: List<String> = listOf(),
) )
data class PagerScrollState(
val page: Int = 0,
val offset: Float = 0f,
)
/** /**
* ViewModel responsible for managing models, their download status, and initialization. * ViewModel responsible for managing models, their download status, and initialization.
* *
@ -150,9 +155,12 @@ open class ModelManagerViewModel(
downloadRepository.getEnqueuedOrRunningWorkInfos() downloadRepository.getEnqueuedOrRunningWorkInfos()
protected val _uiState = MutableStateFlow(createUiState()) protected val _uiState = MutableStateFlow(createUiState())
val uiState = _uiState.asStateFlow() val uiState = _uiState.asStateFlow()
val authService = AuthorizationService(context) val authService = AuthorizationService(context)
var curAccessToken: String = "" var curAccessToken: String = ""
var pagerScrollState: MutableStateFlow<PagerScrollState> = MutableStateFlow(PagerScrollState())
init { init {
Log.d(TAG, "In-progress worker infos: $inProgressWorkInfos") Log.d(TAG, "In-progress worker infos: $inProgressWorkInfos")
@ -269,6 +277,7 @@ open class ModelManagerViewModel(
// Skip if initialization is in progress. // Skip if initialization is in progress.
if (model.initializing) { if (model.initializing) {
model.cleanUpAfterInit = false
Log.d(TAG, "Model '${model.name}' is being initialized. Skipping.") Log.d(TAG, "Model '${model.name}' is being initialized. Skipping.")
return@launch return@launch
} }
@ -299,6 +308,10 @@ open class ModelManagerViewModel(
model = model, model = model,
status = ModelInitializationStatusType.INITIALIZED, status = ModelInitializationStatusType.INITIALIZED,
) )
if (model.cleanUpAfterInit) {
Log.d(TAG, "Model '${model.name}' needs cleaning up after init.")
cleanupModel(task = task, model = model)
}
} else if (error.isNotEmpty()) { } else if (error.isNotEmpty()) {
Log.d(TAG, "Model '${model.name}' failed to initialize") Log.d(TAG, "Model '${model.name}' failed to initialize")
updateModelInitializationStatus( updateModelInitializationStatus(
@ -345,6 +358,7 @@ open class ModelManagerViewModel(
fun cleanupModel(task: Task, model: Model) { fun cleanupModel(task: Task, model: Model) {
if (model.instance != null) { if (model.instance != null) {
model.cleanUpAfterInit = false
Log.d(TAG, "Cleaning up model '${model.name}'...") Log.d(TAG, "Cleaning up model '${model.name}'...")
when (task.type) { when (task.type) {
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model) TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model)
@ -360,6 +374,12 @@ open class ModelManagerViewModel(
updateModelInitializationStatus( updateModelInitializationStatus(
model = model, status = ModelInitializationStatusType.NOT_INITIALIZED model = model, status = ModelInitializationStatusType.NOT_INITIALIZED
) )
} else {
// When model is being initialized and we are trying to clean it up at same time, we mark it
// to clean up and it will be cleaned up after initialization is done.
if (model.initializing) {
model.cleanUpAfterInit = true
}
} }
} }