diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt index 78a897c..7b99103 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/data/Model.kt @@ -90,6 +90,8 @@ data class Model( // The following fields are managed by the app. Don't need to set manually. var instance: Any? = null, var initializing: Boolean = false, + // TODO(jingjin): use a "queue" system to manage model init and cleanup. + var cleanUpAfterInit: Boolean = false, var configValues: Map = mapOf(), var totalBytes: Long = 0L, var accessToken: String? = null, diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPageAppBar.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPageAppBar.kt index 6832e59..1d940db 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPageAppBar.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPageAppBar.kt @@ -16,26 +16,19 @@ package com.google.aiedge.gallery.ui.common -import androidx.compose.foundation.background -import androidx.compose.foundation.clickable import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Row -import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.size -import androidx.compose.foundation.shape.CircleShape import androidx.compose.material.icons.Icons import androidx.compose.material.icons.automirrored.rounded.ArrowBack -import androidx.compose.material.icons.rounded.ArrowDropDown import androidx.compose.material.icons.rounded.Settings import androidx.compose.material3.CenterAlignedTopAppBar import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.Icon import androidx.compose.material3.IconButton import androidx.compose.material3.MaterialTheme -import androidx.compose.material3.ModalBottomSheet import androidx.compose.material3.Text -import androidx.compose.material3.rememberModalBottomSheetState import androidx.compose.runtime.Composable import androidx.compose.runtime.collectAsState import androidx.compose.runtime.getValue @@ -44,7 +37,7 @@ import androidx.compose.runtime.remember import androidx.compose.runtime.setValue import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier -import androidx.compose.ui.draw.clip +import androidx.compose.ui.draw.alpha import androidx.compose.ui.graphics.vector.ImageVector import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.res.vectorResource @@ -54,7 +47,6 @@ import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.ModelDownloadStatusType import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.ui.common.chat.ConfigDialog -import com.google.aiedge.gallery.ui.common.modelitem.StatusIcon import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel @OptIn(ExperimentalMaterial3Api::class) @@ -69,14 +61,15 @@ fun ModelPageAppBar( onConfigChanged: (oldConfigValues: Map, newConfigValues: Map) -> Unit = { _, _ -> }, ) { var showConfigDialog by remember { mutableStateOf(false) } - var showModelPicker by remember { mutableStateOf(false) } val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() - val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true) val context = LocalContext.current val curDownloadStatus = modelManagerUiState.modelDownloadStatus[model.name] CenterAlignedTopAppBar(title = { - Column(horizontalAlignment = Alignment.CenterHorizontally) { + Column( + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.spacedBy(4.dp) + ) { // Task type. Row( verticalAlignment = Alignment.CenterVertically, @@ -95,29 +88,13 @@ fun ModelPageAppBar( ) } - // Model name. - Row(verticalAlignment = Alignment.CenterVertically, - horizontalArrangement = Arrangement.spacedBy(2.dp), - modifier = Modifier - .clip(CircleShape) - .background(MaterialTheme.colorScheme.surfaceContainerHigh) - .clickable { - showModelPicker = true - } - .padding(start = 8.dp, end = 2.dp)) { - StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name]) - Text( - model.name, - style = MaterialTheme.typography.labelSmall, - modifier = Modifier.padding(start = 4.dp), - ) - Icon( - Icons.Rounded.ArrowDropDown, - modifier = Modifier.size(20.dp), - contentDescription = "", - ) - } - + // Model chips pager. + ModelPickerChipsPager( + task = task, + initialModel = model, + modelManagerViewModel = modelManagerViewModel, + onModelSelected = onModelSelected, + ) } }, modifier = modifier, // The back button. @@ -131,14 +108,20 @@ fun ModelPageAppBar( }, // The config button for the model (if existed). actions = { - if (model.configs.isNotEmpty() && curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) { - IconButton(onClick = { showConfigDialog = true }) { - Icon( - imageVector = Icons.Rounded.Settings, - contentDescription = "", - tint = MaterialTheme.colorScheme.primary - ) - } + val showConfigButton = + model.configs.isNotEmpty() && curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED + IconButton( + onClick = { + showConfigDialog = true + }, + enabled = showConfigButton, + modifier = Modifier.alpha(if (showConfigButton) 1f else 0f) + ) { + Icon( + imageVector = Icons.Rounded.Settings, + contentDescription = "", + tint = MaterialTheme.colorScheme.primary + ) } }) @@ -193,21 +176,4 @@ fun ModelPageAppBar( }, ) } - - // Model picker. - if (showModelPicker) { - ModalBottomSheet( - onDismissRequest = { showModelPicker = false }, - sheetState = sheetState, - ) { - ModelPicker( - task = task, - modelManagerViewModel = modelManagerViewModel, - onModelSelected = { model -> - showModelPicker = false - onModelSelected(model) - } - ) - } - } } \ No newline at end of file diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPickerChipsPager.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPickerChipsPager.kt new file mode 100644 index 0000000..8c82b58 --- /dev/null +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/ModelPickerChipsPager.kt @@ -0,0 +1,178 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.aiedge.gallery.ui.common + +import androidx.compose.animation.AnimatedVisibility +import androidx.compose.animation.fadeIn +import androidx.compose.animation.fadeOut +import androidx.compose.animation.scaleIn +import androidx.compose.animation.scaleOut +import androidx.compose.foundation.background +import androidx.compose.foundation.clickable +import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.Box +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.size +import androidx.compose.foundation.pager.HorizontalPager +import androidx.compose.foundation.pager.rememberPagerState +import androidx.compose.foundation.shape.CircleShape +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.rounded.ArrowDropDown +import androidx.compose.material3.CircularProgressIndicator +import androidx.compose.material3.ExperimentalMaterial3Api +import androidx.compose.material3.Icon +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.ModalBottomSheet +import androidx.compose.material3.Text +import androidx.compose.material3.rememberModalBottomSheetState +import androidx.compose.runtime.Composable +import androidx.compose.runtime.LaunchedEffect +import androidx.compose.runtime.collectAsState +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember +import androidx.compose.runtime.rememberCoroutineScope +import androidx.compose.runtime.setValue +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.alpha +import androidx.compose.ui.draw.clip +import androidx.compose.ui.graphics.graphicsLayer +import androidx.compose.ui.unit.dp +import com.google.aiedge.gallery.data.Model +import com.google.aiedge.gallery.data.Task +import com.google.aiedge.gallery.ui.common.modelitem.StatusIcon +import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType +import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.launch +import kotlin.math.absoluteValue + +@OptIn(ExperimentalMaterial3Api::class) +@Composable +fun ModelPickerChipsPager( + task: Task, + initialModel: Model, + modelManagerViewModel: ModelManagerViewModel, + onModelSelected: (Model) -> Unit, +) { + var showModelPicker by remember { mutableStateOf(false) } + var modelPickerModel by remember { mutableStateOf(null) } + val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() + val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true) + val scope = rememberCoroutineScope() + + val pagerState = rememberPagerState(initialPage = task.models.indexOf(initialModel), + pageCount = { task.models.size }) + + // Sync scrolling. + LaunchedEffect(modelManagerViewModel.pagerScrollState) { + modelManagerViewModel.pagerScrollState.collect { state -> + pagerState.scrollToPage(state.page, state.offset) + } + } + + HorizontalPager(state = pagerState, userScrollEnabled = false) { pageIndex -> + val model = task.models[pageIndex] + + // Calculate the alpha of the current page based on how far they are from the center. + val pageOffset = + ((pagerState.currentPage - pageIndex) + pagerState.currentPageOffsetFraction).absoluteValue + val curAlpha = 1f - (pageOffset * 1.5f).coerceIn(0f, 1f) + + val modelInitializationStatus = + modelManagerUiState.modelInitializationStatus[model.name] + + Box( + modifier = Modifier + .fillMaxWidth() + .graphicsLayer { alpha = curAlpha }, + contentAlignment = Alignment.Center + ) { + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(2.dp) + ) { + Row(verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(2.dp), + modifier = Modifier + .clip(CircleShape) + .background(MaterialTheme.colorScheme.surfaceContainerHigh) + .clickable { + modelPickerModel = model + showModelPicker = true + } + .padding(start = 8.dp, end = 2.dp) + .padding(vertical = 1.dp)) Inner@{ + Box(contentAlignment = Alignment.Center, modifier = Modifier.size(18.dp)) { + StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name]) + this@Inner.AnimatedVisibility( + visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING, + enter = scaleIn() + fadeIn(), + exit = scaleOut() + fadeOut(), + ) { + // Circular progress indicator. + CircularProgressIndicator( + modifier = Modifier + .size(17.dp) + .alpha(0.5f), + strokeWidth = 2.dp, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + Text( + model.name, + style = MaterialTheme.typography.labelSmall, + modifier = Modifier.padding(start = 4.dp), + ) + Icon( + Icons.Rounded.ArrowDropDown, + modifier = Modifier.size(20.dp), + contentDescription = "", + ) + } + } + } + } + + // Model picker. + val curModelPickerModel = modelPickerModel + if (showModelPicker && curModelPickerModel != null) { + ModalBottomSheet( + onDismissRequest = { showModelPicker = false }, + sheetState = sheetState, + ) { + ModelPicker( + task = task, + modelManagerViewModel = modelManagerViewModel, + onModelSelected = { selectedModel -> + showModelPicker = false + + scope.launch(Dispatchers.Default) { + // Scroll to the selected model. + pagerState.animateScrollToPage(task.models.indexOf(selectedModel)) + } + + onModelSelected(selectedModel) + } + ) + } + } +} \ No newline at end of file diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt index a573590..8e66bf2 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt @@ -16,11 +16,6 @@ package com.google.aiedge.gallery.ui.common.chat -import androidx.compose.animation.AnimatedVisibility -import androidx.compose.animation.fadeIn -import androidx.compose.animation.fadeOut -import androidx.compose.animation.scaleIn -import androidx.compose.animation.scaleOut import androidx.compose.foundation.background import androidx.compose.foundation.clickable import androidx.compose.foundation.gestures.detectTapGestures @@ -33,7 +28,6 @@ import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.ime import androidx.compose.foundation.layout.imePadding -import androidx.compose.foundation.layout.offset import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.size import androidx.compose.foundation.layout.width @@ -426,16 +420,6 @@ fun ChatPanel( } } - // Model initialization in-progress message. - this@Column.AnimatedVisibility( - visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING, - enter = scaleIn() + fadeIn(), - exit = scaleOut() + fadeOut(), - modifier = Modifier.offset(y = 12.dp) - ) { - ModelInitializationStatusChip() - } - SnackbarHost(hostState = snackbarHostState, modifier = Modifier.padding(vertical = 4.dp)) } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatView.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatView.kt index 44d37b7..60992de 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatView.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatView.kt @@ -25,14 +25,17 @@ import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.padding import androidx.compose.foundation.pager.HorizontalPager import androidx.compose.foundation.pager.rememberPagerState -import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.MaterialTheme import androidx.compose.material3.Scaffold import androidx.compose.runtime.Composable import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.collectAsState import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember import androidx.compose.runtime.rememberCoroutineScope +import androidx.compose.runtime.setValue +import androidx.compose.runtime.snapshotFlow import androidx.compose.ui.Modifier import androidx.compose.ui.graphics.graphicsLayer import androidx.compose.ui.platform.LocalContext @@ -42,6 +45,7 @@ import com.google.aiedge.gallery.data.ModelDownloadStatusType import com.google.aiedge.gallery.data.Task import com.google.aiedge.gallery.ui.common.ModelPageAppBar import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel +import com.google.aiedge.gallery.ui.modelmanager.PagerScrollState import com.google.aiedge.gallery.ui.preview.PreviewChatModel import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel import com.google.aiedge.gallery.ui.preview.TASK_TEST1 @@ -84,8 +88,10 @@ fun ChatView( pageCount = { task.models.size }) val context = LocalContext.current val scope = rememberCoroutineScope() + var navigatingUp by remember { mutableStateOf(false) } val handleNavigateUp = { + navigatingUp = true navigateUp() // clean up all models. @@ -99,9 +105,11 @@ fun ChatView( // Initialize model when model/download state changes. val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name] LaunchedEffect(curDownloadStatus, selectedModel.name) { - if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) { - Log.d(TAG, "Initializing model '${selectedModel.name}' from ChatView launched effect") - modelManagerViewModel.initializeModel(context, task = task, model = selectedModel) + if (!navigatingUp) { + if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) { + Log.d(TAG, "Initializing model '${selectedModel.name}' from ChatView launched effect") + modelManagerViewModel.initializeModel(context, task = task, model = selectedModel) + } } } @@ -118,6 +126,25 @@ fun ChatView( modelManagerViewModel.selectModel(curSelectedModel) } + LaunchedEffect(pagerState) { + // Collect from the a snapshotFlow reading the currentPage + snapshotFlow { pagerState.currentPage }.collect { page -> + Log.d(TAG, "Page changed to $page") + } + } + + // Trigger scroll sync. + LaunchedEffect(pagerState) { + snapshotFlow { + PagerScrollState( + page = pagerState.currentPage, + offset = pagerState.currentPageOffsetFraction + ) + }.collect { scrollState -> + modelManagerViewModel.pagerScrollState.value = scrollState + } + } + // Handle system's edge swipe. BackHandler { handleNavigateUp() diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt index 0d8a2c7..f34b1f4 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt @@ -18,17 +18,11 @@ package com.google.aiedge.gallery.ui.llmsingleturn import android.util.Log import androidx.activity.compose.BackHandler -import androidx.compose.animation.AnimatedVisibility -import androidx.compose.animation.fadeIn -import androidx.compose.animation.fadeOut -import androidx.compose.animation.scaleIn -import androidx.compose.animation.scaleOut import androidx.compose.foundation.background import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.calculateStartPadding import androidx.compose.foundation.layout.fillMaxSize -import androidx.compose.foundation.layout.offset import androidx.compose.foundation.layout.padding import androidx.compose.material3.MaterialTheme import androidx.compose.material3.Scaffold @@ -42,6 +36,7 @@ import androidx.compose.runtime.rememberCoroutineScope import androidx.compose.runtime.setValue import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.alpha import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalLayoutDirection import androidx.compose.ui.tooling.preview.Preview @@ -50,8 +45,6 @@ import com.google.aiedge.gallery.data.ModelDownloadStatusType import com.google.aiedge.gallery.ui.ViewModelProvider import com.google.aiedge.gallery.ui.common.ModelPageAppBar import com.google.aiedge.gallery.ui.common.chat.ModelDownloadStatusInfoPanel -import com.google.aiedge.gallery.ui.common.chat.ModelInitializationStatusChip -import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.aiedge.gallery.ui.preview.PreviewLlmSingleTurnViewModel import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel @@ -116,9 +109,6 @@ fun LlmSingleTurnScreen( } } - val modelInitializationStatus = - modelManagerUiState.modelInitializationStatus[selectedModel.name] - Scaffold(modifier = modifier, topBar = { ModelPageAppBar( task = task, @@ -151,49 +141,42 @@ fun LlmSingleTurnScreen( ) // Main UI after model is downloaded. - if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) { - Box( - contentAlignment = Alignment.BottomCenter, - modifier = Modifier.weight(1f) - ) { - VerticalSplitView(modifier = Modifier.fillMaxSize(), - topView = { - PromptTemplatesPanel( + val modelDownloaded = curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED + Box( + contentAlignment = Alignment.BottomCenter, + modifier = Modifier + .weight(1f) + // Just hide the UI without removing it from the screen so that the scroll syncing + // from ResponsePanel still works. + .alpha(if (modelDownloaded) 1.0f else 0.0f) + ) { + VerticalSplitView(modifier = Modifier.fillMaxSize(), + topView = { + PromptTemplatesPanel( + model = selectedModel, + viewModel = viewModel, + onSend = { fullPrompt -> + viewModel.generateResponse(model = selectedModel, input = fullPrompt) + }, modifier = Modifier.fillMaxSize() + ) + }, + bottomView = { + Box( + contentAlignment = Alignment.BottomCenter, + modifier = Modifier + .fillMaxSize() + .background(MaterialTheme.customColors.agentBubbleBgColor) + ) { + ResponsePanel( model = selectedModel, viewModel = viewModel, - onSend = { fullPrompt -> - viewModel.generateResponse(model = selectedModel, input = fullPrompt) - }, modifier = Modifier.fillMaxSize() - ) - }, - bottomView = { - Box( - contentAlignment = Alignment.BottomCenter, + modelManagerViewModel = modelManagerViewModel, modifier = Modifier .fillMaxSize() - .background(MaterialTheme.customColors.agentBubbleBgColor) - ) { - ResponsePanel( - model = selectedModel, - viewModel = viewModel, - modifier = Modifier - .fillMaxSize() - .padding(bottom = innerPadding.calculateBottomPadding()) - ) - } - }) - - // Model initialization in-progress message. - this@Column.AnimatedVisibility( - visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING, - enter = scaleIn() + fadeIn(), - exit = scaleOut() + fadeOut(), - modifier = Modifier.offset(y = -innerPadding.calculateBottomPadding()) - ) { - ModelInitializationStatusChip() - } - - } + .padding(bottom = innerPadding.calculateBottomPadding()) + ) + } + }) } } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/ResponsePanel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/ResponsePanel.kt index 4a81d64..ebc79ac 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/ResponsePanel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/llmsingleturn/ResponsePanel.kt @@ -16,6 +16,7 @@ package com.google.aiedge.gallery.ui.llmsingleturn +import android.util.Log import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column @@ -24,6 +25,8 @@ import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.size +import androidx.compose.foundation.pager.HorizontalPager +import androidx.compose.foundation.pager.rememberPagerState import androidx.compose.foundation.rememberScrollState import androidx.compose.foundation.verticalScroll import androidx.compose.material.icons.Icons @@ -45,162 +48,195 @@ import androidx.compose.runtime.getValue import androidx.compose.runtime.mutableIntStateOf import androidx.compose.runtime.remember import androidx.compose.runtime.setValue +import androidx.compose.runtime.snapshotFlow import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.draw.alpha import androidx.compose.ui.graphics.Color import androidx.compose.ui.platform.LocalClipboardManager import androidx.compose.ui.text.AnnotatedString -import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.unit.dp import com.google.aiedge.gallery.data.Model import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN import com.google.aiedge.gallery.ui.common.chat.MarkdownText import com.google.aiedge.gallery.ui.common.chat.MessageBodyBenchmarkLlm import com.google.aiedge.gallery.ui.common.chat.MessageBodyLoading -import com.google.aiedge.gallery.ui.theme.GalleryTheme +import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel +import com.google.aiedge.gallery.ui.modelmanager.PagerScrollState private val OPTIONS = listOf("Response", "Benchmark") private val ICONS = listOf(Icons.Outlined.AutoAwesome, Icons.Outlined.Timer) +private const val TAG = "AGResponsePanel" @OptIn(ExperimentalMaterial3Api::class) @Composable fun ResponsePanel( model: Model, viewModel: LlmSingleTurnViewModel, + modelManagerViewModel: ModelManagerViewModel, modifier: Modifier = Modifier, ) { + val task = TASK_LLM_SINGLE_TURN val uiState by viewModel.uiState.collectAsState() + val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() val inProgress = uiState.inProgress val initializing = uiState.initializing val selectedPromptTemplateType = uiState.selectedPromptTemplateType - val response = uiState.responsesByModel[model.name]?.get(selectedPromptTemplateType.label) ?: "" - val benchmark = uiState.benchmarkByModel[model.name]?.get(selectedPromptTemplateType.label) val responseScrollState = rememberScrollState() var selectedOptionIndex by remember { mutableIntStateOf(0) } val clipboardManager = LocalClipboardManager.current - - // Scroll to bottom when response changes. - LaunchedEffect(response) { - if (inProgress) { - responseScrollState.animateScrollTo(responseScrollState.maxValue) - } - } + val pagerState = rememberPagerState( + initialPage = task.models.indexOf(model), + pageCount = { task.models.size }) // Select the "response" tab when prompt template changes. LaunchedEffect(selectedPromptTemplateType) { selectedOptionIndex = 0 } - if (initializing) { - Box( - contentAlignment = Alignment.TopStart, - modifier = modifier - .fillMaxSize() - .padding(horizontal = 16.dp) - ) { - MessageBodyLoading() + // Update selected model and clean up previous model when page is settled on a model page. + LaunchedEffect(pagerState.settledPage) { + val curSelectedModel = task.models[pagerState.settledPage] + Log.d( + TAG, + "Pager settled on model '${curSelectedModel.name}' from '${model.name}'. Updating selected model." + ) + if (curSelectedModel.name != model.name) { + modelManagerViewModel.cleanupModel(task = task, model = model) } - } else { - // Message when response is empty. - if (response.isEmpty()) { - Row( - modifier = Modifier.fillMaxSize(), - horizontalArrangement = Arrangement.Center, - verticalAlignment = Alignment.CenterVertically - ) { - Text( - "Response will appear here", - modifier = Modifier.alpha(0.5f), - style = MaterialTheme.typography.labelMedium, - ) + modelManagerViewModel.selectModel(curSelectedModel) + } + + // Trigger scroll sync. + LaunchedEffect(pagerState) { + snapshotFlow { + PagerScrollState( + page = pagerState.currentPage, + offset = pagerState.currentPageOffsetFraction + ) + }.collect { scrollState -> + modelManagerViewModel.pagerScrollState.value = scrollState + } + } + + // Scroll pager when selected model changes. + LaunchedEffect(modelManagerUiState.selectedModel) { + pagerState.animateScrollToPage(task.models.indexOf(model)) + } + + HorizontalPager(state = pagerState) { pageIndex -> + val curPageModel = task.models[pageIndex] + + val response = + uiState.responsesByModel[curPageModel.name]?.get(selectedPromptTemplateType.label) ?: "" + val benchmark = + uiState.benchmarkByModel[curPageModel.name]?.get(selectedPromptTemplateType.label) + + // Scroll to bottom when response changes. + LaunchedEffect(response) { + if (inProgress) { + responseScrollState.animateScrollTo(responseScrollState.maxValue) } } - // Response markdown. - else { - Column( + + if (initializing) { + Box( + contentAlignment = Alignment.TopStart, modifier = modifier + .fillMaxSize() .padding(horizontal = 16.dp) - .padding(bottom = 4.dp) ) { - // Response/benchmark switch. - Row(modifier = Modifier.fillMaxWidth()) { - PrimaryTabRow( - selectedTabIndex = selectedOptionIndex, - containerColor = Color.Transparent, - ) { - OPTIONS.forEachIndexed { index, title -> - Tab(selected = selectedOptionIndex == index, onClick = { - selectedOptionIndex = index - }, text = { - Row( - verticalAlignment = Alignment.CenterVertically, - horizontalArrangement = Arrangement.spacedBy(4.dp) - ) { - Icon( - ICONS[index], - contentDescription = "", - modifier = Modifier - .size(16.dp) - .alpha(0.7f) - ) - Text(text = title) - } - }) - } - } + MessageBodyLoading() + } + } else { + // Message when response is empty. + if (response.isEmpty()) { + Row( + modifier = Modifier.fillMaxSize(), + horizontalArrangement = Arrangement.Center, + verticalAlignment = Alignment.CenterVertically + ) { + Text( + "Response will appear here", + modifier = Modifier.alpha(0.5f), + style = MaterialTheme.typography.labelMedium, + ) } - if (selectedOptionIndex == 0) { - Box( - contentAlignment = Alignment.BottomEnd, - modifier = Modifier.weight(1f) - ) { - Column( - modifier = Modifier - .fillMaxSize() - .verticalScroll(responseScrollState) + } + // Response markdown. + else { + Column( + modifier = modifier + .padding(horizontal = 16.dp) + .padding(bottom = 4.dp) + ) { + // Response/benchmark switch. + Row(modifier = Modifier.fillMaxWidth()) { + PrimaryTabRow( + selectedTabIndex = selectedOptionIndex, + containerColor = Color.Transparent, ) { - MarkdownText( - text = response, - modifier = Modifier.padding(top = 8.dp, bottom = 40.dp) - ) - } - // Copy button. - IconButton( - onClick = { - val clipData = AnnotatedString(response) - clipboardManager.setText(clipData) - }, - colors = IconButtonDefaults.iconButtonColors( - containerColor = MaterialTheme.colorScheme.surfaceContainerHighest, - contentColor = MaterialTheme.colorScheme.primary, - ), - ) { - Icon( - Icons.Outlined.ContentCopy, - contentDescription = "", - modifier = Modifier.size(20.dp), - ) + OPTIONS.forEachIndexed { index, title -> + Tab(selected = selectedOptionIndex == index, onClick = { + selectedOptionIndex = index + }, text = { + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(4.dp) + ) { + Icon( + ICONS[index], + contentDescription = "", + modifier = Modifier + .size(16.dp) + .alpha(0.7f) + ) + Text(text = title) + } + }) + } } } - } else if (selectedOptionIndex == 1) { - if (benchmark != null) { - MessageBodyBenchmarkLlm(message = benchmark, modifier = Modifier.fillMaxWidth()) + if (selectedOptionIndex == 0) { + Box( + contentAlignment = Alignment.BottomEnd, + modifier = Modifier.weight(1f) + ) { + Column( + modifier = Modifier + .fillMaxSize() + .verticalScroll(responseScrollState) + ) { + MarkdownText( + text = response, + modifier = Modifier.padding(top = 8.dp, bottom = 40.dp) + ) + } + // Copy button. + IconButton( + onClick = { + val clipData = AnnotatedString(response) + clipboardManager.setText(clipData) + }, + colors = IconButtonDefaults.iconButtonColors( + containerColor = MaterialTheme.colorScheme.surfaceContainerHighest, + contentColor = MaterialTheme.colorScheme.primary, + ), + ) { + Icon( + Icons.Outlined.ContentCopy, + contentDescription = "", + modifier = Modifier.size(20.dp), + ) + } + } + } else if (selectedOptionIndex == 1) { + if (benchmark != null) { + MessageBodyBenchmarkLlm(message = benchmark, modifier = Modifier.fillMaxWidth()) + } } } } } } } - -@Preview(showBackground = true) -@Composable -fun ResponsePanelPreview() { - GalleryTheme { - ResponsePanel( - model = TASK_LLM_SINGLE_TURN.models[0], - viewModel = LlmSingleTurnViewModel(), - modifier = Modifier.fillMaxSize() - ) - } -} \ No newline at end of file diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt index f637ba5..79bf5d9 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt @@ -133,6 +133,11 @@ data class ModelManagerUiState( val textInputHistory: List = listOf(), ) +data class PagerScrollState( + val page: Int = 0, + val offset: Float = 0f, +) + /** * ViewModel responsible for managing models, their download status, and initialization. * @@ -150,9 +155,12 @@ open class ModelManagerViewModel( downloadRepository.getEnqueuedOrRunningWorkInfos() protected val _uiState = MutableStateFlow(createUiState()) val uiState = _uiState.asStateFlow() + val authService = AuthorizationService(context) var curAccessToken: String = "" + var pagerScrollState: MutableStateFlow = MutableStateFlow(PagerScrollState()) + init { Log.d(TAG, "In-progress worker infos: $inProgressWorkInfos") @@ -269,6 +277,7 @@ open class ModelManagerViewModel( // Skip if initialization is in progress. if (model.initializing) { + model.cleanUpAfterInit = false Log.d(TAG, "Model '${model.name}' is being initialized. Skipping.") return@launch } @@ -299,6 +308,10 @@ open class ModelManagerViewModel( model = model, status = ModelInitializationStatusType.INITIALIZED, ) + if (model.cleanUpAfterInit) { + Log.d(TAG, "Model '${model.name}' needs cleaning up after init.") + cleanupModel(task = task, model = model) + } } else if (error.isNotEmpty()) { Log.d(TAG, "Model '${model.name}' failed to initialize") updateModelInitializationStatus( @@ -345,6 +358,7 @@ open class ModelManagerViewModel( fun cleanupModel(task: Task, model: Model) { if (model.instance != null) { + model.cleanUpAfterInit = false Log.d(TAG, "Cleaning up model '${model.name}'...") when (task.type) { TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model) @@ -360,6 +374,12 @@ open class ModelManagerViewModel( updateModelInitializationStatus( model = model, status = ModelInitializationStatusType.NOT_INITIALIZED ) + } else { + // When model is being initialized and we are trying to clean it up at same time, we mark it + // to clean up and it will be cleaned up after initialization is done. + if (model.initializing) { + model.cleanUpAfterInit = true + } } }