Better support for sliding to change model

This commit is contained in:
Jing Jin 2025-04-27 18:59:01 -07:00
parent 3341286efa
commit d0beaab31e
8 changed files with 433 additions and 237 deletions

View file

@ -90,6 +90,8 @@ data class Model(
// The following fields are managed by the app. Don't need to set manually.
var instance: Any? = null,
var initializing: Boolean = false,
// TODO(jingjin): use a "queue" system to manage model init and cleanup.
var cleanUpAfterInit: Boolean = false,
var configValues: Map<String, Any> = mapOf(),
var totalBytes: Long = 0L,
var accessToken: String? = null,

View file

@ -16,26 +16,19 @@
package com.google.aiedge.gallery.ui.common
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.ArrowBack
import androidx.compose.material.icons.rounded.ArrowDropDown
import androidx.compose.material.icons.rounded.Settings
import androidx.compose.material3.CenterAlignedTopAppBar
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ModalBottomSheet
import androidx.compose.material3.Text
import androidx.compose.material3.rememberModalBottomSheetState
import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
@ -44,7 +37,7 @@ import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.vectorResource
@ -54,7 +47,6 @@ import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.chat.ConfigDialog
import com.google.aiedge.gallery.ui.common.modelitem.StatusIcon
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
@OptIn(ExperimentalMaterial3Api::class)
@ -69,14 +61,15 @@ fun ModelPageAppBar(
onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit = { _, _ -> },
) {
var showConfigDialog by remember { mutableStateOf(false) }
var showModelPicker by remember { mutableStateOf(false) }
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
val context = LocalContext.current
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[model.name]
CenterAlignedTopAppBar(title = {
Column(horizontalAlignment = Alignment.CenterHorizontally) {
Column(
horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.spacedBy(4.dp)
) {
// Task type.
Row(
verticalAlignment = Alignment.CenterVertically,
@ -95,29 +88,13 @@ fun ModelPageAppBar(
)
}
// Model name.
Row(verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(2.dp),
modifier = Modifier
.clip(CircleShape)
.background(MaterialTheme.colorScheme.surfaceContainerHigh)
.clickable {
showModelPicker = true
}
.padding(start = 8.dp, end = 2.dp)) {
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
Text(
model.name,
style = MaterialTheme.typography.labelSmall,
modifier = Modifier.padding(start = 4.dp),
// Model chips pager.
ModelPickerChipsPager(
task = task,
initialModel = model,
modelManagerViewModel = modelManagerViewModel,
onModelSelected = onModelSelected,
)
Icon(
Icons.Rounded.ArrowDropDown,
modifier = Modifier.size(20.dp),
contentDescription = "",
)
}
}
}, modifier = modifier,
// The back button.
@ -131,15 +108,21 @@ fun ModelPageAppBar(
},
// The config button for the model (if existed).
actions = {
if (model.configs.isNotEmpty() && curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
IconButton(onClick = { showConfigDialog = true }) {
val showConfigButton =
model.configs.isNotEmpty() && curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED
IconButton(
onClick = {
showConfigDialog = true
},
enabled = showConfigButton,
modifier = Modifier.alpha(if (showConfigButton) 1f else 0f)
) {
Icon(
imageVector = Icons.Rounded.Settings,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary
)
}
}
})
// Config dialog.
@ -193,21 +176,4 @@ fun ModelPageAppBar(
},
)
}
// Model picker.
if (showModelPicker) {
ModalBottomSheet(
onDismissRequest = { showModelPicker = false },
sheetState = sheetState,
) {
ModelPicker(
task = task,
modelManagerViewModel = modelManagerViewModel,
onModelSelected = { model ->
showModelPicker = false
onModelSelected(model)
}
)
}
}
}

View file

@ -0,0 +1,178 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.aiedge.gallery.ui.common
import androidx.compose.animation.AnimatedVisibility
import androidx.compose.animation.fadeIn
import androidx.compose.animation.fadeOut
import androidx.compose.animation.scaleIn
import androidx.compose.animation.scaleOut
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.pager.HorizontalPager
import androidx.compose.foundation.pager.rememberPagerState
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.rounded.ArrowDropDown
import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ModalBottomSheet
import androidx.compose.material3.Text
import androidx.compose.material3.rememberModalBottomSheetState
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.modelitem.StatusIcon
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlin.math.absoluteValue
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun ModelPickerChipsPager(
task: Task,
initialModel: Model,
modelManagerViewModel: ModelManagerViewModel,
onModelSelected: (Model) -> Unit,
) {
var showModelPicker by remember { mutableStateOf(false) }
var modelPickerModel by remember { mutableStateOf<Model?>(null) }
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
val scope = rememberCoroutineScope()
val pagerState = rememberPagerState(initialPage = task.models.indexOf(initialModel),
pageCount = { task.models.size })
// Sync scrolling.
LaunchedEffect(modelManagerViewModel.pagerScrollState) {
modelManagerViewModel.pagerScrollState.collect { state ->
pagerState.scrollToPage(state.page, state.offset)
}
}
HorizontalPager(state = pagerState, userScrollEnabled = false) { pageIndex ->
val model = task.models[pageIndex]
// Calculate the alpha of the current page based on how far they are from the center.
val pageOffset =
((pagerState.currentPage - pageIndex) + pagerState.currentPageOffsetFraction).absoluteValue
val curAlpha = 1f - (pageOffset * 1.5f).coerceIn(0f, 1f)
val modelInitializationStatus =
modelManagerUiState.modelInitializationStatus[model.name]
Box(
modifier = Modifier
.fillMaxWidth()
.graphicsLayer { alpha = curAlpha },
contentAlignment = Alignment.Center
) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(2.dp)
) {
Row(verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(2.dp),
modifier = Modifier
.clip(CircleShape)
.background(MaterialTheme.colorScheme.surfaceContainerHigh)
.clickable {
modelPickerModel = model
showModelPicker = true
}
.padding(start = 8.dp, end = 2.dp)
.padding(vertical = 1.dp)) Inner@{
Box(contentAlignment = Alignment.Center, modifier = Modifier.size(18.dp)) {
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
this@Inner.AnimatedVisibility(
visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
enter = scaleIn() + fadeIn(),
exit = scaleOut() + fadeOut(),
) {
// Circular progress indicator.
CircularProgressIndicator(
modifier = Modifier
.size(17.dp)
.alpha(0.5f),
strokeWidth = 2.dp,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
Text(
model.name,
style = MaterialTheme.typography.labelSmall,
modifier = Modifier.padding(start = 4.dp),
)
Icon(
Icons.Rounded.ArrowDropDown,
modifier = Modifier.size(20.dp),
contentDescription = "",
)
}
}
}
}
// Model picker.
val curModelPickerModel = modelPickerModel
if (showModelPicker && curModelPickerModel != null) {
ModalBottomSheet(
onDismissRequest = { showModelPicker = false },
sheetState = sheetState,
) {
ModelPicker(
task = task,
modelManagerViewModel = modelManagerViewModel,
onModelSelected = { selectedModel ->
showModelPicker = false
scope.launch(Dispatchers.Default) {
// Scroll to the selected model.
pagerState.animateScrollToPage(task.models.indexOf(selectedModel))
}
onModelSelected(selectedModel)
}
)
}
}
}

View file

@ -16,11 +16,6 @@
package com.google.aiedge.gallery.ui.common.chat
import androidx.compose.animation.AnimatedVisibility
import androidx.compose.animation.fadeIn
import androidx.compose.animation.fadeOut
import androidx.compose.animation.scaleIn
import androidx.compose.animation.scaleOut
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
import androidx.compose.foundation.gestures.detectTapGestures
@ -33,7 +28,6 @@ import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.ime
import androidx.compose.foundation.layout.imePadding
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.width
@ -426,16 +420,6 @@ fun ChatPanel(
}
}
// Model initialization in-progress message.
this@Column.AnimatedVisibility(
visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
enter = scaleIn() + fadeIn(),
exit = scaleOut() + fadeOut(),
modifier = Modifier.offset(y = 12.dp)
) {
ModelInitializationStatusChip()
}
SnackbarHost(hostState = snackbarHostState, modifier = Modifier.padding(vertical = 4.dp))
}

View file

@ -25,14 +25,17 @@ import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.pager.HorizontalPager
import androidx.compose.foundation.pager.rememberPagerState
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Scaffold
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.runtime.snapshotFlow
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.platform.LocalContext
@ -42,6 +45,7 @@ import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.data.Task
import com.google.aiedge.gallery.ui.common.ModelPageAppBar
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.modelmanager.PagerScrollState
import com.google.aiedge.gallery.ui.preview.PreviewChatModel
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
@ -84,8 +88,10 @@ fun ChatView(
pageCount = { task.models.size })
val context = LocalContext.current
val scope = rememberCoroutineScope()
var navigatingUp by remember { mutableStateOf(false) }
val handleNavigateUp = {
navigatingUp = true
navigateUp()
// clean up all models.
@ -99,11 +105,13 @@ fun ChatView(
// Initialize model when model/download state changes.
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name]
LaunchedEffect(curDownloadStatus, selectedModel.name) {
if (!navigatingUp) {
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
Log.d(TAG, "Initializing model '${selectedModel.name}' from ChatView launched effect")
modelManagerViewModel.initializeModel(context, task = task, model = selectedModel)
}
}
}
// Update selected model and clean up previous model when page is settled on a model page.
LaunchedEffect(pagerState.settledPage) {
@ -118,6 +126,25 @@ fun ChatView(
modelManagerViewModel.selectModel(curSelectedModel)
}
LaunchedEffect(pagerState) {
// Collect from the a snapshotFlow reading the currentPage
snapshotFlow { pagerState.currentPage }.collect { page ->
Log.d(TAG, "Page changed to $page")
}
}
// Trigger scroll sync.
LaunchedEffect(pagerState) {
snapshotFlow {
PagerScrollState(
page = pagerState.currentPage,
offset = pagerState.currentPageOffsetFraction
)
}.collect { scrollState ->
modelManagerViewModel.pagerScrollState.value = scrollState
}
}
// Handle system's edge swipe.
BackHandler {
handleNavigateUp()

View file

@ -18,17 +18,11 @@ package com.google.aiedge.gallery.ui.llmsingleturn
import android.util.Log
import androidx.activity.compose.BackHandler
import androidx.compose.animation.AnimatedVisibility
import androidx.compose.animation.fadeIn
import androidx.compose.animation.fadeOut
import androidx.compose.animation.scaleIn
import androidx.compose.animation.scaleOut
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.calculateStartPadding
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.padding
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Scaffold
@ -42,6 +36,7 @@ import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalLayoutDirection
import androidx.compose.ui.tooling.preview.Preview
@ -50,8 +45,6 @@ import com.google.aiedge.gallery.data.ModelDownloadStatusType
import com.google.aiedge.gallery.ui.ViewModelProvider
import com.google.aiedge.gallery.ui.common.ModelPageAppBar
import com.google.aiedge.gallery.ui.common.chat.ModelDownloadStatusInfoPanel
import com.google.aiedge.gallery.ui.common.chat.ModelInitializationStatusChip
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.preview.PreviewLlmSingleTurnViewModel
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
@ -116,9 +109,6 @@ fun LlmSingleTurnScreen(
}
}
val modelInitializationStatus =
modelManagerUiState.modelInitializationStatus[selectedModel.name]
Scaffold(modifier = modifier, topBar = {
ModelPageAppBar(
task = task,
@ -151,10 +141,14 @@ fun LlmSingleTurnScreen(
)
// Main UI after model is downloaded.
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
val modelDownloaded = curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED
Box(
contentAlignment = Alignment.BottomCenter,
modifier = Modifier.weight(1f)
modifier = Modifier
.weight(1f)
// Just hide the UI without removing it from the screen so that the scroll syncing
// from ResponsePanel still works.
.alpha(if (modelDownloaded) 1.0f else 0.0f)
) {
VerticalSplitView(modifier = Modifier.fillMaxSize(),
topView = {
@ -176,24 +170,13 @@ fun LlmSingleTurnScreen(
ResponsePanel(
model = selectedModel,
viewModel = viewModel,
modelManagerViewModel = modelManagerViewModel,
modifier = Modifier
.fillMaxSize()
.padding(bottom = innerPadding.calculateBottomPadding())
)
}
})
// Model initialization in-progress message.
this@Column.AnimatedVisibility(
visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
enter = scaleIn() + fadeIn(),
exit = scaleOut() + fadeOut(),
modifier = Modifier.offset(y = -innerPadding.calculateBottomPadding())
) {
ModelInitializationStatusChip()
}
}
}
}
}

View file

@ -16,6 +16,7 @@
package com.google.aiedge.gallery.ui.llmsingleturn
import android.util.Log
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
@ -24,6 +25,8 @@ import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.pager.HorizontalPager
import androidx.compose.foundation.pager.rememberPagerState
import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.verticalScroll
import androidx.compose.material.icons.Icons
@ -45,40 +48,89 @@ import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.runtime.snapshotFlow
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.platform.LocalClipboardManager
import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.aiedge.gallery.data.Model
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
import com.google.aiedge.gallery.ui.common.chat.MarkdownText
import com.google.aiedge.gallery.ui.common.chat.MessageBodyBenchmarkLlm
import com.google.aiedge.gallery.ui.common.chat.MessageBodyLoading
import com.google.aiedge.gallery.ui.theme.GalleryTheme
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.aiedge.gallery.ui.modelmanager.PagerScrollState
private val OPTIONS = listOf("Response", "Benchmark")
private val ICONS = listOf(Icons.Outlined.AutoAwesome, Icons.Outlined.Timer)
private const val TAG = "AGResponsePanel"
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun ResponsePanel(
model: Model,
viewModel: LlmSingleTurnViewModel,
modelManagerViewModel: ModelManagerViewModel,
modifier: Modifier = Modifier,
) {
val task = TASK_LLM_SINGLE_TURN
val uiState by viewModel.uiState.collectAsState()
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
val inProgress = uiState.inProgress
val initializing = uiState.initializing
val selectedPromptTemplateType = uiState.selectedPromptTemplateType
val response = uiState.responsesByModel[model.name]?.get(selectedPromptTemplateType.label) ?: ""
val benchmark = uiState.benchmarkByModel[model.name]?.get(selectedPromptTemplateType.label)
val responseScrollState = rememberScrollState()
var selectedOptionIndex by remember { mutableIntStateOf(0) }
val clipboardManager = LocalClipboardManager.current
val pagerState = rememberPagerState(
initialPage = task.models.indexOf(model),
pageCount = { task.models.size })
// Select the "response" tab when prompt template changes.
LaunchedEffect(selectedPromptTemplateType) {
selectedOptionIndex = 0
}
// Update selected model and clean up previous model when page is settled on a model page.
LaunchedEffect(pagerState.settledPage) {
val curSelectedModel = task.models[pagerState.settledPage]
Log.d(
TAG,
"Pager settled on model '${curSelectedModel.name}' from '${model.name}'. Updating selected model."
)
if (curSelectedModel.name != model.name) {
modelManagerViewModel.cleanupModel(task = task, model = model)
}
modelManagerViewModel.selectModel(curSelectedModel)
}
// Trigger scroll sync.
LaunchedEffect(pagerState) {
snapshotFlow {
PagerScrollState(
page = pagerState.currentPage,
offset = pagerState.currentPageOffsetFraction
)
}.collect { scrollState ->
modelManagerViewModel.pagerScrollState.value = scrollState
}
}
// Scroll pager when selected model changes.
LaunchedEffect(modelManagerUiState.selectedModel) {
pagerState.animateScrollToPage(task.models.indexOf(model))
}
HorizontalPager(state = pagerState) { pageIndex ->
val curPageModel = task.models[pageIndex]
val response =
uiState.responsesByModel[curPageModel.name]?.get(selectedPromptTemplateType.label) ?: ""
val benchmark =
uiState.benchmarkByModel[curPageModel.name]?.get(selectedPromptTemplateType.label)
// Scroll to bottom when response changes.
LaunchedEffect(response) {
@ -87,11 +139,6 @@ fun ResponsePanel(
}
}
// Select the "response" tab when prompt template changes.
LaunchedEffect(selectedPromptTemplateType) {
selectedOptionIndex = 0
}
if (initializing) {
Box(
contentAlignment = Alignment.TopStart,
@ -191,16 +238,5 @@ fun ResponsePanel(
}
}
}
}
@Preview(showBackground = true)
@Composable
fun ResponsePanelPreview() {
GalleryTheme {
ResponsePanel(
model = TASK_LLM_SINGLE_TURN.models[0],
viewModel = LlmSingleTurnViewModel(),
modifier = Modifier.fillMaxSize()
)
}
}

View file

@ -133,6 +133,11 @@ data class ModelManagerUiState(
val textInputHistory: List<String> = listOf(),
)
data class PagerScrollState(
val page: Int = 0,
val offset: Float = 0f,
)
/**
* ViewModel responsible for managing models, their download status, and initialization.
*
@ -150,9 +155,12 @@ open class ModelManagerViewModel(
downloadRepository.getEnqueuedOrRunningWorkInfos()
protected val _uiState = MutableStateFlow(createUiState())
val uiState = _uiState.asStateFlow()
val authService = AuthorizationService(context)
var curAccessToken: String = ""
var pagerScrollState: MutableStateFlow<PagerScrollState> = MutableStateFlow(PagerScrollState())
init {
Log.d(TAG, "In-progress worker infos: $inProgressWorkInfos")
@ -269,6 +277,7 @@ open class ModelManagerViewModel(
// Skip if initialization is in progress.
if (model.initializing) {
model.cleanUpAfterInit = false
Log.d(TAG, "Model '${model.name}' is being initialized. Skipping.")
return@launch
}
@ -299,6 +308,10 @@ open class ModelManagerViewModel(
model = model,
status = ModelInitializationStatusType.INITIALIZED,
)
if (model.cleanUpAfterInit) {
Log.d(TAG, "Model '${model.name}' needs cleaning up after init.")
cleanupModel(task = task, model = model)
}
} else if (error.isNotEmpty()) {
Log.d(TAG, "Model '${model.name}' failed to initialize")
updateModelInitializationStatus(
@ -345,6 +358,7 @@ open class ModelManagerViewModel(
fun cleanupModel(task: Task, model: Model) {
if (model.instance != null) {
model.cleanUpAfterInit = false
Log.d(TAG, "Cleaning up model '${model.name}'...")
when (task.type) {
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model)
@ -360,6 +374,12 @@ open class ModelManagerViewModel(
updateModelInitializationStatus(
model = model, status = ModelInitializationStatusType.NOT_INITIALIZED
)
} else {
// When model is being initialized and we are trying to clean it up at same time, we mark it
// to clean up and it will be cleaned up after initialization is done.
if (model.initializing) {
model.cleanUpAfterInit = true
}
}
}