mirror of
https://github.com/google-ai-edge/gallery.git
synced 2025-07-14 02:12:02 -04:00
Better support for sliding to change model
This commit is contained in:
parent
3341286efa
commit
d0beaab31e
8 changed files with 433 additions and 237 deletions
|
@ -90,6 +90,8 @@ data class Model(
|
||||||
// The following fields are managed by the app. Don't need to set manually.
|
// The following fields are managed by the app. Don't need to set manually.
|
||||||
var instance: Any? = null,
|
var instance: Any? = null,
|
||||||
var initializing: Boolean = false,
|
var initializing: Boolean = false,
|
||||||
|
// TODO(jingjin): use a "queue" system to manage model init and cleanup.
|
||||||
|
var cleanUpAfterInit: Boolean = false,
|
||||||
var configValues: Map<String, Any> = mapOf(),
|
var configValues: Map<String, Any> = mapOf(),
|
||||||
var totalBytes: Long = 0L,
|
var totalBytes: Long = 0L,
|
||||||
var accessToken: String? = null,
|
var accessToken: String? = null,
|
||||||
|
|
|
@ -16,26 +16,19 @@
|
||||||
|
|
||||||
package com.google.aiedge.gallery.ui.common
|
package com.google.aiedge.gallery.ui.common
|
||||||
|
|
||||||
import androidx.compose.foundation.background
|
|
||||||
import androidx.compose.foundation.clickable
|
|
||||||
import androidx.compose.foundation.layout.Arrangement
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
import androidx.compose.foundation.layout.Column
|
import androidx.compose.foundation.layout.Column
|
||||||
import androidx.compose.foundation.layout.Row
|
import androidx.compose.foundation.layout.Row
|
||||||
import androidx.compose.foundation.layout.padding
|
|
||||||
import androidx.compose.foundation.layout.size
|
import androidx.compose.foundation.layout.size
|
||||||
import androidx.compose.foundation.shape.CircleShape
|
|
||||||
import androidx.compose.material.icons.Icons
|
import androidx.compose.material.icons.Icons
|
||||||
import androidx.compose.material.icons.automirrored.rounded.ArrowBack
|
import androidx.compose.material.icons.automirrored.rounded.ArrowBack
|
||||||
import androidx.compose.material.icons.rounded.ArrowDropDown
|
|
||||||
import androidx.compose.material.icons.rounded.Settings
|
import androidx.compose.material.icons.rounded.Settings
|
||||||
import androidx.compose.material3.CenterAlignedTopAppBar
|
import androidx.compose.material3.CenterAlignedTopAppBar
|
||||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||||
import androidx.compose.material3.Icon
|
import androidx.compose.material3.Icon
|
||||||
import androidx.compose.material3.IconButton
|
import androidx.compose.material3.IconButton
|
||||||
import androidx.compose.material3.MaterialTheme
|
import androidx.compose.material3.MaterialTheme
|
||||||
import androidx.compose.material3.ModalBottomSheet
|
|
||||||
import androidx.compose.material3.Text
|
import androidx.compose.material3.Text
|
||||||
import androidx.compose.material3.rememberModalBottomSheetState
|
|
||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
import androidx.compose.runtime.collectAsState
|
import androidx.compose.runtime.collectAsState
|
||||||
import androidx.compose.runtime.getValue
|
import androidx.compose.runtime.getValue
|
||||||
|
@ -44,7 +37,7 @@ import androidx.compose.runtime.remember
|
||||||
import androidx.compose.runtime.setValue
|
import androidx.compose.runtime.setValue
|
||||||
import androidx.compose.ui.Alignment
|
import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.draw.clip
|
import androidx.compose.ui.draw.alpha
|
||||||
import androidx.compose.ui.graphics.vector.ImageVector
|
import androidx.compose.ui.graphics.vector.ImageVector
|
||||||
import androidx.compose.ui.platform.LocalContext
|
import androidx.compose.ui.platform.LocalContext
|
||||||
import androidx.compose.ui.res.vectorResource
|
import androidx.compose.ui.res.vectorResource
|
||||||
|
@ -54,7 +47,6 @@ import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ConfigDialog
|
import com.google.aiedge.gallery.ui.common.chat.ConfigDialog
|
||||||
import com.google.aiedge.gallery.ui.common.modelitem.StatusIcon
|
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
|
|
||||||
@OptIn(ExperimentalMaterial3Api::class)
|
@OptIn(ExperimentalMaterial3Api::class)
|
||||||
|
@ -69,14 +61,15 @@ fun ModelPageAppBar(
|
||||||
onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit = { _, _ -> },
|
onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit = { _, _ -> },
|
||||||
) {
|
) {
|
||||||
var showConfigDialog by remember { mutableStateOf(false) }
|
var showConfigDialog by remember { mutableStateOf(false) }
|
||||||
var showModelPicker by remember { mutableStateOf(false) }
|
|
||||||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||||
val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
|
|
||||||
val context = LocalContext.current
|
val context = LocalContext.current
|
||||||
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[model.name]
|
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[model.name]
|
||||||
|
|
||||||
CenterAlignedTopAppBar(title = {
|
CenterAlignedTopAppBar(title = {
|
||||||
Column(horizontalAlignment = Alignment.CenterHorizontally) {
|
Column(
|
||||||
|
horizontalAlignment = Alignment.CenterHorizontally,
|
||||||
|
verticalArrangement = Arrangement.spacedBy(4.dp)
|
||||||
|
) {
|
||||||
// Task type.
|
// Task type.
|
||||||
Row(
|
Row(
|
||||||
verticalAlignment = Alignment.CenterVertically,
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
@ -95,29 +88,13 @@ fun ModelPageAppBar(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Model name.
|
// Model chips pager.
|
||||||
Row(verticalAlignment = Alignment.CenterVertically,
|
ModelPickerChipsPager(
|
||||||
horizontalArrangement = Arrangement.spacedBy(2.dp),
|
task = task,
|
||||||
modifier = Modifier
|
initialModel = model,
|
||||||
.clip(CircleShape)
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
.background(MaterialTheme.colorScheme.surfaceContainerHigh)
|
onModelSelected = onModelSelected,
|
||||||
.clickable {
|
|
||||||
showModelPicker = true
|
|
||||||
}
|
|
||||||
.padding(start = 8.dp, end = 2.dp)) {
|
|
||||||
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
|
|
||||||
Text(
|
|
||||||
model.name,
|
|
||||||
style = MaterialTheme.typography.labelSmall,
|
|
||||||
modifier = Modifier.padding(start = 4.dp),
|
|
||||||
)
|
)
|
||||||
Icon(
|
|
||||||
Icons.Rounded.ArrowDropDown,
|
|
||||||
modifier = Modifier.size(20.dp),
|
|
||||||
contentDescription = "",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}, modifier = modifier,
|
}, modifier = modifier,
|
||||||
// The back button.
|
// The back button.
|
||||||
|
@ -131,15 +108,21 @@ fun ModelPageAppBar(
|
||||||
},
|
},
|
||||||
// The config button for the model (if existed).
|
// The config button for the model (if existed).
|
||||||
actions = {
|
actions = {
|
||||||
if (model.configs.isNotEmpty() && curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
|
val showConfigButton =
|
||||||
IconButton(onClick = { showConfigDialog = true }) {
|
model.configs.isNotEmpty() && curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED
|
||||||
|
IconButton(
|
||||||
|
onClick = {
|
||||||
|
showConfigDialog = true
|
||||||
|
},
|
||||||
|
enabled = showConfigButton,
|
||||||
|
modifier = Modifier.alpha(if (showConfigButton) 1f else 0f)
|
||||||
|
) {
|
||||||
Icon(
|
Icon(
|
||||||
imageVector = Icons.Rounded.Settings,
|
imageVector = Icons.Rounded.Settings,
|
||||||
contentDescription = "",
|
contentDescription = "",
|
||||||
tint = MaterialTheme.colorScheme.primary
|
tint = MaterialTheme.colorScheme.primary
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
// Config dialog.
|
// Config dialog.
|
||||||
|
@ -193,21 +176,4 @@ fun ModelPageAppBar(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Model picker.
|
|
||||||
if (showModelPicker) {
|
|
||||||
ModalBottomSheet(
|
|
||||||
onDismissRequest = { showModelPicker = false },
|
|
||||||
sheetState = sheetState,
|
|
||||||
) {
|
|
||||||
ModelPicker(
|
|
||||||
task = task,
|
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
|
||||||
onModelSelected = { model ->
|
|
||||||
showModelPicker = false
|
|
||||||
onModelSelected(model)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
|
@ -0,0 +1,178 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.google.aiedge.gallery.ui.common
|
||||||
|
|
||||||
|
import androidx.compose.animation.AnimatedVisibility
|
||||||
|
import androidx.compose.animation.fadeIn
|
||||||
|
import androidx.compose.animation.fadeOut
|
||||||
|
import androidx.compose.animation.scaleIn
|
||||||
|
import androidx.compose.animation.scaleOut
|
||||||
|
import androidx.compose.foundation.background
|
||||||
|
import androidx.compose.foundation.clickable
|
||||||
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
|
import androidx.compose.foundation.layout.Box
|
||||||
|
import androidx.compose.foundation.layout.Row
|
||||||
|
import androidx.compose.foundation.layout.fillMaxWidth
|
||||||
|
import androidx.compose.foundation.layout.padding
|
||||||
|
import androidx.compose.foundation.layout.size
|
||||||
|
import androidx.compose.foundation.pager.HorizontalPager
|
||||||
|
import androidx.compose.foundation.pager.rememberPagerState
|
||||||
|
import androidx.compose.foundation.shape.CircleShape
|
||||||
|
import androidx.compose.material.icons.Icons
|
||||||
|
import androidx.compose.material.icons.rounded.ArrowDropDown
|
||||||
|
import androidx.compose.material3.CircularProgressIndicator
|
||||||
|
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||||
|
import androidx.compose.material3.Icon
|
||||||
|
import androidx.compose.material3.MaterialTheme
|
||||||
|
import androidx.compose.material3.ModalBottomSheet
|
||||||
|
import androidx.compose.material3.Text
|
||||||
|
import androidx.compose.material3.rememberModalBottomSheetState
|
||||||
|
import androidx.compose.runtime.Composable
|
||||||
|
import androidx.compose.runtime.LaunchedEffect
|
||||||
|
import androidx.compose.runtime.collectAsState
|
||||||
|
import androidx.compose.runtime.getValue
|
||||||
|
import androidx.compose.runtime.mutableStateOf
|
||||||
|
import androidx.compose.runtime.remember
|
||||||
|
import androidx.compose.runtime.rememberCoroutineScope
|
||||||
|
import androidx.compose.runtime.setValue
|
||||||
|
import androidx.compose.ui.Alignment
|
||||||
|
import androidx.compose.ui.Modifier
|
||||||
|
import androidx.compose.ui.draw.alpha
|
||||||
|
import androidx.compose.ui.draw.clip
|
||||||
|
import androidx.compose.ui.graphics.graphicsLayer
|
||||||
|
import androidx.compose.ui.unit.dp
|
||||||
|
import com.google.aiedge.gallery.data.Model
|
||||||
|
import com.google.aiedge.gallery.data.Task
|
||||||
|
import com.google.aiedge.gallery.ui.common.modelitem.StatusIcon
|
||||||
|
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType
|
||||||
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
|
import kotlinx.coroutines.Dispatchers
|
||||||
|
import kotlinx.coroutines.launch
|
||||||
|
import kotlin.math.absoluteValue
|
||||||
|
|
||||||
|
@OptIn(ExperimentalMaterial3Api::class)
|
||||||
|
@Composable
|
||||||
|
fun ModelPickerChipsPager(
|
||||||
|
task: Task,
|
||||||
|
initialModel: Model,
|
||||||
|
modelManagerViewModel: ModelManagerViewModel,
|
||||||
|
onModelSelected: (Model) -> Unit,
|
||||||
|
) {
|
||||||
|
var showModelPicker by remember { mutableStateOf(false) }
|
||||||
|
var modelPickerModel by remember { mutableStateOf<Model?>(null) }
|
||||||
|
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||||
|
val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
|
||||||
|
val scope = rememberCoroutineScope()
|
||||||
|
|
||||||
|
val pagerState = rememberPagerState(initialPage = task.models.indexOf(initialModel),
|
||||||
|
pageCount = { task.models.size })
|
||||||
|
|
||||||
|
// Sync scrolling.
|
||||||
|
LaunchedEffect(modelManagerViewModel.pagerScrollState) {
|
||||||
|
modelManagerViewModel.pagerScrollState.collect { state ->
|
||||||
|
pagerState.scrollToPage(state.page, state.offset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
HorizontalPager(state = pagerState, userScrollEnabled = false) { pageIndex ->
|
||||||
|
val model = task.models[pageIndex]
|
||||||
|
|
||||||
|
// Calculate the alpha of the current page based on how far they are from the center.
|
||||||
|
val pageOffset =
|
||||||
|
((pagerState.currentPage - pageIndex) + pagerState.currentPageOffsetFraction).absoluteValue
|
||||||
|
val curAlpha = 1f - (pageOffset * 1.5f).coerceIn(0f, 1f)
|
||||||
|
|
||||||
|
val modelInitializationStatus =
|
||||||
|
modelManagerUiState.modelInitializationStatus[model.name]
|
||||||
|
|
||||||
|
Box(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.graphicsLayer { alpha = curAlpha },
|
||||||
|
contentAlignment = Alignment.Center
|
||||||
|
) {
|
||||||
|
Row(
|
||||||
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(2.dp)
|
||||||
|
) {
|
||||||
|
Row(verticalAlignment = Alignment.CenterVertically,
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(2.dp),
|
||||||
|
modifier = Modifier
|
||||||
|
.clip(CircleShape)
|
||||||
|
.background(MaterialTheme.colorScheme.surfaceContainerHigh)
|
||||||
|
.clickable {
|
||||||
|
modelPickerModel = model
|
||||||
|
showModelPicker = true
|
||||||
|
}
|
||||||
|
.padding(start = 8.dp, end = 2.dp)
|
||||||
|
.padding(vertical = 1.dp)) Inner@{
|
||||||
|
Box(contentAlignment = Alignment.Center, modifier = Modifier.size(18.dp)) {
|
||||||
|
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
|
||||||
|
this@Inner.AnimatedVisibility(
|
||||||
|
visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
|
||||||
|
enter = scaleIn() + fadeIn(),
|
||||||
|
exit = scaleOut() + fadeOut(),
|
||||||
|
) {
|
||||||
|
// Circular progress indicator.
|
||||||
|
CircularProgressIndicator(
|
||||||
|
modifier = Modifier
|
||||||
|
.size(17.dp)
|
||||||
|
.alpha(0.5f),
|
||||||
|
strokeWidth = 2.dp,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Text(
|
||||||
|
model.name,
|
||||||
|
style = MaterialTheme.typography.labelSmall,
|
||||||
|
modifier = Modifier.padding(start = 4.dp),
|
||||||
|
)
|
||||||
|
Icon(
|
||||||
|
Icons.Rounded.ArrowDropDown,
|
||||||
|
modifier = Modifier.size(20.dp),
|
||||||
|
contentDescription = "",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model picker.
|
||||||
|
val curModelPickerModel = modelPickerModel
|
||||||
|
if (showModelPicker && curModelPickerModel != null) {
|
||||||
|
ModalBottomSheet(
|
||||||
|
onDismissRequest = { showModelPicker = false },
|
||||||
|
sheetState = sheetState,
|
||||||
|
) {
|
||||||
|
ModelPicker(
|
||||||
|
task = task,
|
||||||
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
|
onModelSelected = { selectedModel ->
|
||||||
|
showModelPicker = false
|
||||||
|
|
||||||
|
scope.launch(Dispatchers.Default) {
|
||||||
|
// Scroll to the selected model.
|
||||||
|
pagerState.animateScrollToPage(task.models.indexOf(selectedModel))
|
||||||
|
}
|
||||||
|
|
||||||
|
onModelSelected(selectedModel)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -16,11 +16,6 @@
|
||||||
|
|
||||||
package com.google.aiedge.gallery.ui.common.chat
|
package com.google.aiedge.gallery.ui.common.chat
|
||||||
|
|
||||||
import androidx.compose.animation.AnimatedVisibility
|
|
||||||
import androidx.compose.animation.fadeIn
|
|
||||||
import androidx.compose.animation.fadeOut
|
|
||||||
import androidx.compose.animation.scaleIn
|
|
||||||
import androidx.compose.animation.scaleOut
|
|
||||||
import androidx.compose.foundation.background
|
import androidx.compose.foundation.background
|
||||||
import androidx.compose.foundation.clickable
|
import androidx.compose.foundation.clickable
|
||||||
import androidx.compose.foundation.gestures.detectTapGestures
|
import androidx.compose.foundation.gestures.detectTapGestures
|
||||||
|
@ -33,7 +28,6 @@ import androidx.compose.foundation.layout.fillMaxSize
|
||||||
import androidx.compose.foundation.layout.fillMaxWidth
|
import androidx.compose.foundation.layout.fillMaxWidth
|
||||||
import androidx.compose.foundation.layout.ime
|
import androidx.compose.foundation.layout.ime
|
||||||
import androidx.compose.foundation.layout.imePadding
|
import androidx.compose.foundation.layout.imePadding
|
||||||
import androidx.compose.foundation.layout.offset
|
|
||||||
import androidx.compose.foundation.layout.padding
|
import androidx.compose.foundation.layout.padding
|
||||||
import androidx.compose.foundation.layout.size
|
import androidx.compose.foundation.layout.size
|
||||||
import androidx.compose.foundation.layout.width
|
import androidx.compose.foundation.layout.width
|
||||||
|
@ -426,16 +420,6 @@ fun ChatPanel(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Model initialization in-progress message.
|
|
||||||
this@Column.AnimatedVisibility(
|
|
||||||
visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
|
|
||||||
enter = scaleIn() + fadeIn(),
|
|
||||||
exit = scaleOut() + fadeOut(),
|
|
||||||
modifier = Modifier.offset(y = 12.dp)
|
|
||||||
) {
|
|
||||||
ModelInitializationStatusChip()
|
|
||||||
}
|
|
||||||
|
|
||||||
SnackbarHost(hostState = snackbarHostState, modifier = Modifier.padding(vertical = 4.dp))
|
SnackbarHost(hostState = snackbarHostState, modifier = Modifier.padding(vertical = 4.dp))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -25,14 +25,17 @@ import androidx.compose.foundation.layout.fillMaxSize
|
||||||
import androidx.compose.foundation.layout.padding
|
import androidx.compose.foundation.layout.padding
|
||||||
import androidx.compose.foundation.pager.HorizontalPager
|
import androidx.compose.foundation.pager.HorizontalPager
|
||||||
import androidx.compose.foundation.pager.rememberPagerState
|
import androidx.compose.foundation.pager.rememberPagerState
|
||||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
|
||||||
import androidx.compose.material3.MaterialTheme
|
import androidx.compose.material3.MaterialTheme
|
||||||
import androidx.compose.material3.Scaffold
|
import androidx.compose.material3.Scaffold
|
||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
import androidx.compose.runtime.LaunchedEffect
|
import androidx.compose.runtime.LaunchedEffect
|
||||||
import androidx.compose.runtime.collectAsState
|
import androidx.compose.runtime.collectAsState
|
||||||
import androidx.compose.runtime.getValue
|
import androidx.compose.runtime.getValue
|
||||||
|
import androidx.compose.runtime.mutableStateOf
|
||||||
|
import androidx.compose.runtime.remember
|
||||||
import androidx.compose.runtime.rememberCoroutineScope
|
import androidx.compose.runtime.rememberCoroutineScope
|
||||||
|
import androidx.compose.runtime.setValue
|
||||||
|
import androidx.compose.runtime.snapshotFlow
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.graphics.graphicsLayer
|
import androidx.compose.ui.graphics.graphicsLayer
|
||||||
import androidx.compose.ui.platform.LocalContext
|
import androidx.compose.ui.platform.LocalContext
|
||||||
|
@ -42,6 +45,7 @@ import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||||
import com.google.aiedge.gallery.data.Task
|
import com.google.aiedge.gallery.data.Task
|
||||||
import com.google.aiedge.gallery.ui.common.ModelPageAppBar
|
import com.google.aiedge.gallery.ui.common.ModelPageAppBar
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
|
import com.google.aiedge.gallery.ui.modelmanager.PagerScrollState
|
||||||
import com.google.aiedge.gallery.ui.preview.PreviewChatModel
|
import com.google.aiedge.gallery.ui.preview.PreviewChatModel
|
||||||
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||||
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
|
import com.google.aiedge.gallery.ui.preview.TASK_TEST1
|
||||||
|
@ -84,8 +88,10 @@ fun ChatView(
|
||||||
pageCount = { task.models.size })
|
pageCount = { task.models.size })
|
||||||
val context = LocalContext.current
|
val context = LocalContext.current
|
||||||
val scope = rememberCoroutineScope()
|
val scope = rememberCoroutineScope()
|
||||||
|
var navigatingUp by remember { mutableStateOf(false) }
|
||||||
|
|
||||||
val handleNavigateUp = {
|
val handleNavigateUp = {
|
||||||
|
navigatingUp = true
|
||||||
navigateUp()
|
navigateUp()
|
||||||
|
|
||||||
// clean up all models.
|
// clean up all models.
|
||||||
|
@ -99,11 +105,13 @@ fun ChatView(
|
||||||
// Initialize model when model/download state changes.
|
// Initialize model when model/download state changes.
|
||||||
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name]
|
val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name]
|
||||||
LaunchedEffect(curDownloadStatus, selectedModel.name) {
|
LaunchedEffect(curDownloadStatus, selectedModel.name) {
|
||||||
|
if (!navigatingUp) {
|
||||||
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
|
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
|
||||||
Log.d(TAG, "Initializing model '${selectedModel.name}' from ChatView launched effect")
|
Log.d(TAG, "Initializing model '${selectedModel.name}' from ChatView launched effect")
|
||||||
modelManagerViewModel.initializeModel(context, task = task, model = selectedModel)
|
modelManagerViewModel.initializeModel(context, task = task, model = selectedModel)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Update selected model and clean up previous model when page is settled on a model page.
|
// Update selected model and clean up previous model when page is settled on a model page.
|
||||||
LaunchedEffect(pagerState.settledPage) {
|
LaunchedEffect(pagerState.settledPage) {
|
||||||
|
@ -118,6 +126,25 @@ fun ChatView(
|
||||||
modelManagerViewModel.selectModel(curSelectedModel)
|
modelManagerViewModel.selectModel(curSelectedModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LaunchedEffect(pagerState) {
|
||||||
|
// Collect from the a snapshotFlow reading the currentPage
|
||||||
|
snapshotFlow { pagerState.currentPage }.collect { page ->
|
||||||
|
Log.d(TAG, "Page changed to $page")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trigger scroll sync.
|
||||||
|
LaunchedEffect(pagerState) {
|
||||||
|
snapshotFlow {
|
||||||
|
PagerScrollState(
|
||||||
|
page = pagerState.currentPage,
|
||||||
|
offset = pagerState.currentPageOffsetFraction
|
||||||
|
)
|
||||||
|
}.collect { scrollState ->
|
||||||
|
modelManagerViewModel.pagerScrollState.value = scrollState
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Handle system's edge swipe.
|
// Handle system's edge swipe.
|
||||||
BackHandler {
|
BackHandler {
|
||||||
handleNavigateUp()
|
handleNavigateUp()
|
||||||
|
|
|
@ -18,17 +18,11 @@ package com.google.aiedge.gallery.ui.llmsingleturn
|
||||||
|
|
||||||
import android.util.Log
|
import android.util.Log
|
||||||
import androidx.activity.compose.BackHandler
|
import androidx.activity.compose.BackHandler
|
||||||
import androidx.compose.animation.AnimatedVisibility
|
|
||||||
import androidx.compose.animation.fadeIn
|
|
||||||
import androidx.compose.animation.fadeOut
|
|
||||||
import androidx.compose.animation.scaleIn
|
|
||||||
import androidx.compose.animation.scaleOut
|
|
||||||
import androidx.compose.foundation.background
|
import androidx.compose.foundation.background
|
||||||
import androidx.compose.foundation.layout.Box
|
import androidx.compose.foundation.layout.Box
|
||||||
import androidx.compose.foundation.layout.Column
|
import androidx.compose.foundation.layout.Column
|
||||||
import androidx.compose.foundation.layout.calculateStartPadding
|
import androidx.compose.foundation.layout.calculateStartPadding
|
||||||
import androidx.compose.foundation.layout.fillMaxSize
|
import androidx.compose.foundation.layout.fillMaxSize
|
||||||
import androidx.compose.foundation.layout.offset
|
|
||||||
import androidx.compose.foundation.layout.padding
|
import androidx.compose.foundation.layout.padding
|
||||||
import androidx.compose.material3.MaterialTheme
|
import androidx.compose.material3.MaterialTheme
|
||||||
import androidx.compose.material3.Scaffold
|
import androidx.compose.material3.Scaffold
|
||||||
|
@ -42,6 +36,7 @@ import androidx.compose.runtime.rememberCoroutineScope
|
||||||
import androidx.compose.runtime.setValue
|
import androidx.compose.runtime.setValue
|
||||||
import androidx.compose.ui.Alignment
|
import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
|
import androidx.compose.ui.draw.alpha
|
||||||
import androidx.compose.ui.platform.LocalContext
|
import androidx.compose.ui.platform.LocalContext
|
||||||
import androidx.compose.ui.platform.LocalLayoutDirection
|
import androidx.compose.ui.platform.LocalLayoutDirection
|
||||||
import androidx.compose.ui.tooling.preview.Preview
|
import androidx.compose.ui.tooling.preview.Preview
|
||||||
|
@ -50,8 +45,6 @@ import com.google.aiedge.gallery.data.ModelDownloadStatusType
|
||||||
import com.google.aiedge.gallery.ui.ViewModelProvider
|
import com.google.aiedge.gallery.ui.ViewModelProvider
|
||||||
import com.google.aiedge.gallery.ui.common.ModelPageAppBar
|
import com.google.aiedge.gallery.ui.common.ModelPageAppBar
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ModelDownloadStatusInfoPanel
|
import com.google.aiedge.gallery.ui.common.chat.ModelDownloadStatusInfoPanel
|
||||||
import com.google.aiedge.gallery.ui.common.chat.ModelInitializationStatusChip
|
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.ModelInitializationStatusType
|
|
||||||
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
import com.google.aiedge.gallery.ui.preview.PreviewLlmSingleTurnViewModel
|
import com.google.aiedge.gallery.ui.preview.PreviewLlmSingleTurnViewModel
|
||||||
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
import com.google.aiedge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||||
|
@ -116,9 +109,6 @@ fun LlmSingleTurnScreen(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val modelInitializationStatus =
|
|
||||||
modelManagerUiState.modelInitializationStatus[selectedModel.name]
|
|
||||||
|
|
||||||
Scaffold(modifier = modifier, topBar = {
|
Scaffold(modifier = modifier, topBar = {
|
||||||
ModelPageAppBar(
|
ModelPageAppBar(
|
||||||
task = task,
|
task = task,
|
||||||
|
@ -151,10 +141,14 @@ fun LlmSingleTurnScreen(
|
||||||
)
|
)
|
||||||
|
|
||||||
// Main UI after model is downloaded.
|
// Main UI after model is downloaded.
|
||||||
if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) {
|
val modelDownloaded = curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED
|
||||||
Box(
|
Box(
|
||||||
contentAlignment = Alignment.BottomCenter,
|
contentAlignment = Alignment.BottomCenter,
|
||||||
modifier = Modifier.weight(1f)
|
modifier = Modifier
|
||||||
|
.weight(1f)
|
||||||
|
// Just hide the UI without removing it from the screen so that the scroll syncing
|
||||||
|
// from ResponsePanel still works.
|
||||||
|
.alpha(if (modelDownloaded) 1.0f else 0.0f)
|
||||||
) {
|
) {
|
||||||
VerticalSplitView(modifier = Modifier.fillMaxSize(),
|
VerticalSplitView(modifier = Modifier.fillMaxSize(),
|
||||||
topView = {
|
topView = {
|
||||||
|
@ -176,24 +170,13 @@ fun LlmSingleTurnScreen(
|
||||||
ResponsePanel(
|
ResponsePanel(
|
||||||
model = selectedModel,
|
model = selectedModel,
|
||||||
viewModel = viewModel,
|
viewModel = viewModel,
|
||||||
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
modifier = Modifier
|
modifier = Modifier
|
||||||
.fillMaxSize()
|
.fillMaxSize()
|
||||||
.padding(bottom = innerPadding.calculateBottomPadding())
|
.padding(bottom = innerPadding.calculateBottomPadding())
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Model initialization in-progress message.
|
|
||||||
this@Column.AnimatedVisibility(
|
|
||||||
visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
|
|
||||||
enter = scaleIn() + fadeIn(),
|
|
||||||
exit = scaleOut() + fadeOut(),
|
|
||||||
modifier = Modifier.offset(y = -innerPadding.calculateBottomPadding())
|
|
||||||
) {
|
|
||||||
ModelInitializationStatusChip()
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package com.google.aiedge.gallery.ui.llmsingleturn
|
package com.google.aiedge.gallery.ui.llmsingleturn
|
||||||
|
|
||||||
|
import android.util.Log
|
||||||
import androidx.compose.foundation.layout.Arrangement
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
import androidx.compose.foundation.layout.Box
|
import androidx.compose.foundation.layout.Box
|
||||||
import androidx.compose.foundation.layout.Column
|
import androidx.compose.foundation.layout.Column
|
||||||
|
@ -24,6 +25,8 @@ import androidx.compose.foundation.layout.fillMaxSize
|
||||||
import androidx.compose.foundation.layout.fillMaxWidth
|
import androidx.compose.foundation.layout.fillMaxWidth
|
||||||
import androidx.compose.foundation.layout.padding
|
import androidx.compose.foundation.layout.padding
|
||||||
import androidx.compose.foundation.layout.size
|
import androidx.compose.foundation.layout.size
|
||||||
|
import androidx.compose.foundation.pager.HorizontalPager
|
||||||
|
import androidx.compose.foundation.pager.rememberPagerState
|
||||||
import androidx.compose.foundation.rememberScrollState
|
import androidx.compose.foundation.rememberScrollState
|
||||||
import androidx.compose.foundation.verticalScroll
|
import androidx.compose.foundation.verticalScroll
|
||||||
import androidx.compose.material.icons.Icons
|
import androidx.compose.material.icons.Icons
|
||||||
|
@ -45,40 +48,89 @@ import androidx.compose.runtime.getValue
|
||||||
import androidx.compose.runtime.mutableIntStateOf
|
import androidx.compose.runtime.mutableIntStateOf
|
||||||
import androidx.compose.runtime.remember
|
import androidx.compose.runtime.remember
|
||||||
import androidx.compose.runtime.setValue
|
import androidx.compose.runtime.setValue
|
||||||
|
import androidx.compose.runtime.snapshotFlow
|
||||||
import androidx.compose.ui.Alignment
|
import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.draw.alpha
|
import androidx.compose.ui.draw.alpha
|
||||||
import androidx.compose.ui.graphics.Color
|
import androidx.compose.ui.graphics.Color
|
||||||
import androidx.compose.ui.platform.LocalClipboardManager
|
import androidx.compose.ui.platform.LocalClipboardManager
|
||||||
import androidx.compose.ui.text.AnnotatedString
|
import androidx.compose.ui.text.AnnotatedString
|
||||||
import androidx.compose.ui.tooling.preview.Preview
|
|
||||||
import androidx.compose.ui.unit.dp
|
import androidx.compose.ui.unit.dp
|
||||||
import com.google.aiedge.gallery.data.Model
|
import com.google.aiedge.gallery.data.Model
|
||||||
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
|
import com.google.aiedge.gallery.data.TASK_LLM_SINGLE_TURN
|
||||||
import com.google.aiedge.gallery.ui.common.chat.MarkdownText
|
import com.google.aiedge.gallery.ui.common.chat.MarkdownText
|
||||||
import com.google.aiedge.gallery.ui.common.chat.MessageBodyBenchmarkLlm
|
import com.google.aiedge.gallery.ui.common.chat.MessageBodyBenchmarkLlm
|
||||||
import com.google.aiedge.gallery.ui.common.chat.MessageBodyLoading
|
import com.google.aiedge.gallery.ui.common.chat.MessageBodyLoading
|
||||||
import com.google.aiedge.gallery.ui.theme.GalleryTheme
|
import com.google.aiedge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
|
import com.google.aiedge.gallery.ui.modelmanager.PagerScrollState
|
||||||
|
|
||||||
private val OPTIONS = listOf("Response", "Benchmark")
|
private val OPTIONS = listOf("Response", "Benchmark")
|
||||||
private val ICONS = listOf(Icons.Outlined.AutoAwesome, Icons.Outlined.Timer)
|
private val ICONS = listOf(Icons.Outlined.AutoAwesome, Icons.Outlined.Timer)
|
||||||
|
private const val TAG = "AGResponsePanel"
|
||||||
|
|
||||||
@OptIn(ExperimentalMaterial3Api::class)
|
@OptIn(ExperimentalMaterial3Api::class)
|
||||||
@Composable
|
@Composable
|
||||||
fun ResponsePanel(
|
fun ResponsePanel(
|
||||||
model: Model,
|
model: Model,
|
||||||
viewModel: LlmSingleTurnViewModel,
|
viewModel: LlmSingleTurnViewModel,
|
||||||
|
modelManagerViewModel: ModelManagerViewModel,
|
||||||
modifier: Modifier = Modifier,
|
modifier: Modifier = Modifier,
|
||||||
) {
|
) {
|
||||||
|
val task = TASK_LLM_SINGLE_TURN
|
||||||
val uiState by viewModel.uiState.collectAsState()
|
val uiState by viewModel.uiState.collectAsState()
|
||||||
|
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||||
val inProgress = uiState.inProgress
|
val inProgress = uiState.inProgress
|
||||||
val initializing = uiState.initializing
|
val initializing = uiState.initializing
|
||||||
val selectedPromptTemplateType = uiState.selectedPromptTemplateType
|
val selectedPromptTemplateType = uiState.selectedPromptTemplateType
|
||||||
val response = uiState.responsesByModel[model.name]?.get(selectedPromptTemplateType.label) ?: ""
|
|
||||||
val benchmark = uiState.benchmarkByModel[model.name]?.get(selectedPromptTemplateType.label)
|
|
||||||
val responseScrollState = rememberScrollState()
|
val responseScrollState = rememberScrollState()
|
||||||
var selectedOptionIndex by remember { mutableIntStateOf(0) }
|
var selectedOptionIndex by remember { mutableIntStateOf(0) }
|
||||||
val clipboardManager = LocalClipboardManager.current
|
val clipboardManager = LocalClipboardManager.current
|
||||||
|
val pagerState = rememberPagerState(
|
||||||
|
initialPage = task.models.indexOf(model),
|
||||||
|
pageCount = { task.models.size })
|
||||||
|
|
||||||
|
// Select the "response" tab when prompt template changes.
|
||||||
|
LaunchedEffect(selectedPromptTemplateType) {
|
||||||
|
selectedOptionIndex = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update selected model and clean up previous model when page is settled on a model page.
|
||||||
|
LaunchedEffect(pagerState.settledPage) {
|
||||||
|
val curSelectedModel = task.models[pagerState.settledPage]
|
||||||
|
Log.d(
|
||||||
|
TAG,
|
||||||
|
"Pager settled on model '${curSelectedModel.name}' from '${model.name}'. Updating selected model."
|
||||||
|
)
|
||||||
|
if (curSelectedModel.name != model.name) {
|
||||||
|
modelManagerViewModel.cleanupModel(task = task, model = model)
|
||||||
|
}
|
||||||
|
modelManagerViewModel.selectModel(curSelectedModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trigger scroll sync.
|
||||||
|
LaunchedEffect(pagerState) {
|
||||||
|
snapshotFlow {
|
||||||
|
PagerScrollState(
|
||||||
|
page = pagerState.currentPage,
|
||||||
|
offset = pagerState.currentPageOffsetFraction
|
||||||
|
)
|
||||||
|
}.collect { scrollState ->
|
||||||
|
modelManagerViewModel.pagerScrollState.value = scrollState
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scroll pager when selected model changes.
|
||||||
|
LaunchedEffect(modelManagerUiState.selectedModel) {
|
||||||
|
pagerState.animateScrollToPage(task.models.indexOf(model))
|
||||||
|
}
|
||||||
|
|
||||||
|
HorizontalPager(state = pagerState) { pageIndex ->
|
||||||
|
val curPageModel = task.models[pageIndex]
|
||||||
|
|
||||||
|
val response =
|
||||||
|
uiState.responsesByModel[curPageModel.name]?.get(selectedPromptTemplateType.label) ?: ""
|
||||||
|
val benchmark =
|
||||||
|
uiState.benchmarkByModel[curPageModel.name]?.get(selectedPromptTemplateType.label)
|
||||||
|
|
||||||
// Scroll to bottom when response changes.
|
// Scroll to bottom when response changes.
|
||||||
LaunchedEffect(response) {
|
LaunchedEffect(response) {
|
||||||
|
@ -87,11 +139,6 @@ fun ResponsePanel(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Select the "response" tab when prompt template changes.
|
|
||||||
LaunchedEffect(selectedPromptTemplateType) {
|
|
||||||
selectedOptionIndex = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
if (initializing) {
|
if (initializing) {
|
||||||
Box(
|
Box(
|
||||||
contentAlignment = Alignment.TopStart,
|
contentAlignment = Alignment.TopStart,
|
||||||
|
@ -191,16 +238,5 @@ fun ResponsePanel(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
@Preview(showBackground = true)
|
|
||||||
@Composable
|
|
||||||
fun ResponsePanelPreview() {
|
|
||||||
GalleryTheme {
|
|
||||||
ResponsePanel(
|
|
||||||
model = TASK_LLM_SINGLE_TURN.models[0],
|
|
||||||
viewModel = LlmSingleTurnViewModel(),
|
|
||||||
modifier = Modifier.fillMaxSize()
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -133,6 +133,11 @@ data class ModelManagerUiState(
|
||||||
val textInputHistory: List<String> = listOf(),
|
val textInputHistory: List<String> = listOf(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
data class PagerScrollState(
|
||||||
|
val page: Int = 0,
|
||||||
|
val offset: Float = 0f,
|
||||||
|
)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ViewModel responsible for managing models, their download status, and initialization.
|
* ViewModel responsible for managing models, their download status, and initialization.
|
||||||
*
|
*
|
||||||
|
@ -150,9 +155,12 @@ open class ModelManagerViewModel(
|
||||||
downloadRepository.getEnqueuedOrRunningWorkInfos()
|
downloadRepository.getEnqueuedOrRunningWorkInfos()
|
||||||
protected val _uiState = MutableStateFlow(createUiState())
|
protected val _uiState = MutableStateFlow(createUiState())
|
||||||
val uiState = _uiState.asStateFlow()
|
val uiState = _uiState.asStateFlow()
|
||||||
|
|
||||||
val authService = AuthorizationService(context)
|
val authService = AuthorizationService(context)
|
||||||
var curAccessToken: String = ""
|
var curAccessToken: String = ""
|
||||||
|
|
||||||
|
var pagerScrollState: MutableStateFlow<PagerScrollState> = MutableStateFlow(PagerScrollState())
|
||||||
|
|
||||||
init {
|
init {
|
||||||
Log.d(TAG, "In-progress worker infos: $inProgressWorkInfos")
|
Log.d(TAG, "In-progress worker infos: $inProgressWorkInfos")
|
||||||
|
|
||||||
|
@ -269,6 +277,7 @@ open class ModelManagerViewModel(
|
||||||
|
|
||||||
// Skip if initialization is in progress.
|
// Skip if initialization is in progress.
|
||||||
if (model.initializing) {
|
if (model.initializing) {
|
||||||
|
model.cleanUpAfterInit = false
|
||||||
Log.d(TAG, "Model '${model.name}' is being initialized. Skipping.")
|
Log.d(TAG, "Model '${model.name}' is being initialized. Skipping.")
|
||||||
return@launch
|
return@launch
|
||||||
}
|
}
|
||||||
|
@ -299,6 +308,10 @@ open class ModelManagerViewModel(
|
||||||
model = model,
|
model = model,
|
||||||
status = ModelInitializationStatusType.INITIALIZED,
|
status = ModelInitializationStatusType.INITIALIZED,
|
||||||
)
|
)
|
||||||
|
if (model.cleanUpAfterInit) {
|
||||||
|
Log.d(TAG, "Model '${model.name}' needs cleaning up after init.")
|
||||||
|
cleanupModel(task = task, model = model)
|
||||||
|
}
|
||||||
} else if (error.isNotEmpty()) {
|
} else if (error.isNotEmpty()) {
|
||||||
Log.d(TAG, "Model '${model.name}' failed to initialize")
|
Log.d(TAG, "Model '${model.name}' failed to initialize")
|
||||||
updateModelInitializationStatus(
|
updateModelInitializationStatus(
|
||||||
|
@ -345,6 +358,7 @@ open class ModelManagerViewModel(
|
||||||
|
|
||||||
fun cleanupModel(task: Task, model: Model) {
|
fun cleanupModel(task: Task, model: Model) {
|
||||||
if (model.instance != null) {
|
if (model.instance != null) {
|
||||||
|
model.cleanUpAfterInit = false
|
||||||
Log.d(TAG, "Cleaning up model '${model.name}'...")
|
Log.d(TAG, "Cleaning up model '${model.name}'...")
|
||||||
when (task.type) {
|
when (task.type) {
|
||||||
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model)
|
TaskType.TEXT_CLASSIFICATION -> TextClassificationModelHelper.cleanUp(model = model)
|
||||||
|
@ -360,6 +374,12 @@ open class ModelManagerViewModel(
|
||||||
updateModelInitializationStatus(
|
updateModelInitializationStatus(
|
||||||
model = model, status = ModelInitializationStatusType.NOT_INITIALIZED
|
model = model, status = ModelInitializationStatusType.NOT_INITIALIZED
|
||||||
)
|
)
|
||||||
|
} else {
|
||||||
|
// When model is being initialized and we are trying to clean it up at same time, we mark it
|
||||||
|
// to clean up and it will be cleaned up after initialization is done.
|
||||||
|
if (model.initializing) {
|
||||||
|
model.cleanUpAfterInit = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue