- Improve UX

- Fix a bug related to LLM inference engine cleanup.
This commit is contained in:
Jing Jin 2025-05-21 12:33:35 -07:00
parent b1aa6511dd
commit f9cab2f06d
21 changed files with 216 additions and 78 deletions

View file

@ -31,7 +31,7 @@ android {
minSdk = 26 minSdk = 26
targetSdk = 35 targetSdk = 35
versionCode = 1 versionCode = 1
versionName = "1.0.0" versionName = "1.0.1"
// Needed for HuggingFace auth workflows. // Needed for HuggingFace auth workflows.
manifestPlaceholders["appAuthRedirectScheme"] = "com.google.ai.edge.gallery.oauth" manifestPlaceholders["appAuthRedirectScheme"] = "com.google.ai.edge.gallery.oauth"

View file

@ -89,7 +89,7 @@ class NumberSliderConfig(
type = ConfigEditorType.NUMBER_SLIDER, type = ConfigEditorType.NUMBER_SLIDER,
key = key, key = key,
defaultValue = defaultValue, defaultValue = defaultValue,
valueType = valueType valueType = valueType,
) )
/** /**

View file

@ -209,7 +209,7 @@ enum class ConfigKey(val label: String) {
SUPPORT_IMAGE("Support image"), SUPPORT_IMAGE("Support image"),
MAX_RESULT_COUNT("Max result count"), MAX_RESULT_COUNT("Max result count"),
USE_GPU("Use GPU"), USE_GPU("Use GPU"),
ACCELERATOR("Accelerator"), ACCELERATOR("Choose accelerator"),
COMPATIBLE_ACCELERATORS("Compatible accelerators"), COMPATIBLE_ACCELERATORS("Compatible accelerators"),
WARM_UP_ITERATIONS("Warm up iterations"), WARM_UP_ITERATIONS("Warm up iterations"),
BENCHMARK_ITERATIONS("Benchmark iterations"), BENCHMARK_ITERATIONS("Benchmark iterations"),

View file

@ -16,12 +16,14 @@
package com.google.ai.edge.gallery.ui.common package com.google.ai.edge.gallery.ui.common
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.offset import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.size import androidx.compose.foundation.layout.size
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.ArrowBack import androidx.compose.material.icons.automirrored.rounded.ArrowBack
import androidx.compose.material.icons.rounded.MapsUgc import androidx.compose.material.icons.rounded.MapsUgc
@ -42,7 +44,7 @@ import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.scale import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.vector.ImageVector import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.vectorResource import androidx.compose.ui.res.vectorResource
@ -68,7 +70,7 @@ fun ModelPageAppBar(
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
isResettingSession: Boolean = false, isResettingSession: Boolean = false,
onResetSessionClicked: (Model) -> Unit = {}, onResetSessionClicked: (Model) -> Unit = {},
showResetSessionButton: Boolean = false, canShowResetSessionButton: Boolean = false,
onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit = { _, _ -> }, onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit = { _, _ -> },
) { ) {
var showConfigDialog by remember { mutableStateOf(false) } var showConfigDialog by remember { mutableStateOf(false) }
@ -96,7 +98,7 @@ fun ModelPageAppBar(
) )
Text( Text(
task.type.label, task.type.label,
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.SemiBold), style = MaterialTheme.typography.titleMedium.copy(fontWeight = FontWeight.SemiBold),
color = getTaskIconColor(task = task) color = getTaskIconColor(task = task)
) )
} }
@ -121,11 +123,12 @@ fun ModelPageAppBar(
}, },
// The config button for the model (if existed). // The config button for the model (if existed).
actions = { actions = {
val showConfigButton = val downloadSucceeded = curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED
model.configs.isNotEmpty() && curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED val showConfigButton = model.configs.isNotEmpty() && downloadSucceeded
val showResetSessionButton = canShowResetSessionButton && downloadSucceeded
Box(modifier = Modifier.size(42.dp), contentAlignment = Alignment.Center) { Box(modifier = Modifier.size(42.dp), contentAlignment = Alignment.Center) {
var configButtonOffset = 0.dp var configButtonOffset = 0.dp
if (showConfigButton && showResetSessionButton) { if (showConfigButton && canShowResetSessionButton) {
configButtonOffset = (-40).dp configButtonOffset = (-40).dp
} }
val isModelInitializing = val isModelInitializing =
@ -138,14 +141,14 @@ fun ModelPageAppBar(
}, },
enabled = enableConfigButton, enabled = enableConfigButton,
modifier = Modifier modifier = Modifier
.scale(0.75f)
.offset(x = configButtonOffset) .offset(x = configButtonOffset)
.alpha(if (!enableConfigButton) 0.5f else 1f) .alpha(if (!enableConfigButton) 0.5f else 1f)
) { ) {
Icon( Icon(
imageVector = Icons.Rounded.Tune, imageVector = Icons.Rounded.Tune,
contentDescription = "", contentDescription = "",
tint = MaterialTheme.colorScheme.primary tint = MaterialTheme.colorScheme.primary,
modifier = Modifier.size(20.dp)
) )
} }
} }
@ -164,14 +167,23 @@ fun ModelPageAppBar(
}, },
enabled = enableResetButton, enabled = enableResetButton,
modifier = Modifier modifier = Modifier
.scale(0.75f)
.alpha(if (!enableResetButton) 0.5f else 1f) .alpha(if (!enableResetButton) 0.5f else 1f)
) { ) {
Icon( Box(
imageVector = Icons.Rounded.MapsUgc, modifier = Modifier
contentDescription = "", .size(32.dp)
tint = MaterialTheme.colorScheme.primary .clip(CircleShape)
) .background(MaterialTheme.colorScheme.surfaceContainer),
contentAlignment = Alignment.Center
) {
Icon(
imageVector = Icons.Rounded.MapsUgc,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary,
modifier = Modifier
.size(20.dp)
)
}
} }
} }
} }

View file

@ -16,6 +16,7 @@
package com.google.ai.edge.gallery.ui.common package com.google.ai.edge.gallery.ui.common
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
@ -26,6 +27,7 @@ import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.width import androidx.compose.foundation.layout.width
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.CheckCircle
import androidx.compose.material.icons.outlined.CheckCircle import androidx.compose.material.icons.outlined.CheckCircle
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
@ -35,6 +37,7 @@ import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.vector.ImageVector import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.vectorResource import androidx.compose.ui.res.vectorResource
@ -83,6 +86,7 @@ fun ModelPicker(
// Model list. // Model list.
for (model in task.models) { for (model in task.models) {
val selected = model.name == modelManagerUiState.selectedModel.name
Row( Row(
verticalAlignment = Alignment.CenterVertically, verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.SpaceBetween, horizontalArrangement = Arrangement.SpaceBetween,
@ -91,12 +95,16 @@ fun ModelPicker(
.clickable { .clickable {
onModelSelected(model) onModelSelected(model)
} }
.background(if (selected) MaterialTheme.colorScheme.surfaceContainer else Color.Transparent)
.padding(horizontal = 16.dp, vertical = 8.dp), .padding(horizontal = 16.dp, vertical = 8.dp),
) { ) {
Spacer(modifier = Modifier.width(24.dp)) Spacer(modifier = Modifier.width(24.dp))
Column(modifier = Modifier.weight(1f)) { Column(modifier = Modifier.weight(1f)) {
Text(model.name, style = MaterialTheme.typography.bodyMedium) Text(model.name, style = MaterialTheme.typography.bodyMedium)
Row(horizontalArrangement = Arrangement.spacedBy(4.dp)) { Row(
horizontalArrangement = Arrangement.spacedBy(4.dp),
verticalAlignment = Alignment.CenterVertically
) {
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name]) StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
Text( Text(
model.sizeInBytes.humanReadableSize(), model.sizeInBytes.humanReadableSize(),
@ -105,9 +113,9 @@ fun ModelPicker(
) )
} }
} }
if (model.name == modelManagerUiState.selectedModel.name) { if (selected) {
Icon( Icon(
Icons.Outlined.CheckCircle, Icons.Filled.CheckCircle,
modifier = Modifier.size(16.dp), modifier = Modifier.size(16.dp),
contentDescription = "" contentDescription = ""
) )

View file

@ -130,8 +130,8 @@ fun ModelPickerChipsPager(
showModelPicker = true showModelPicker = true
} }
.padding(start = 8.dp, end = 2.dp) .padding(start = 8.dp, end = 2.dp)
.padding(vertical = 1.dp)) Inner@{ .padding(vertical = 4.dp)) Inner@{
Box(contentAlignment = Alignment.Center, modifier = Modifier.size(18.dp)) { Box(contentAlignment = Alignment.Center, modifier = Modifier.size(21.dp)) {
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name]) StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
this@Inner.AnimatedVisibility( this@Inner.AnimatedVisibility(
visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING, visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
@ -141,7 +141,7 @@ fun ModelPickerChipsPager(
// Circular progress indicator. // Circular progress indicator.
CircularProgressIndicator( CircularProgressIndicator(
modifier = Modifier modifier = Modifier
.size(17.dp) .size(24.dp)
.alpha(0.5f), .alpha(0.5f),
strokeWidth = 2.dp, strokeWidth = 2.dp,
color = MaterialTheme.colorScheme.onSurfaceVariant color = MaterialTheme.colorScheme.onSurfaceVariant
@ -150,7 +150,7 @@ fun ModelPickerChipsPager(
} }
Text( Text(
model.name, model.name,
style = MaterialTheme.typography.labelMedium, style = MaterialTheme.typography.labelLarge,
modifier = Modifier modifier = Modifier
.padding(start = 4.dp) .padding(start = 4.dp)
.widthIn(0.dp, screenWidthDp - 250.dp), .widthIn(0.dp, screenWidthDp - 250.dp),

View file

@ -80,6 +80,7 @@ import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.R import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.Task import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.data.TaskType import com.google.ai.edge.gallery.data.TaskType
@ -267,7 +268,7 @@ fun ChatPanel(
// Sender row. // Sender row.
MessageSender( MessageSender(
message = message, message = message,
agentNameRes = task.agentNameRes, agentName = stringResource(task.agentNameRes),
imageHistoryCurIndex = imageHistoryCurIndex.intValue imageHistoryCurIndex = imageHistoryCurIndex.intValue
) )

View file

@ -174,7 +174,7 @@ fun ChatView(
task = task, task = task,
model = selectedModel, model = selectedModel,
modelManagerViewModel = modelManagerViewModel, modelManagerViewModel = modelManagerViewModel,
showResetSessionButton = true, canShowResetSessionButton = true,
isResettingSession = uiState.isResettingSession, isResettingSession = uiState.isResettingSession,
inProgress = uiState.inProgress, inProgress = uiState.inProgress,
modelPreparing = uiState.preparing, modelPreparing = uiState.preparing,

View file

@ -44,12 +44,12 @@ fun MarkdownText(
smallFontSize: Boolean = false smallFontSize: Boolean = false
) { ) {
val fontSize = val fontSize =
if (smallFontSize) MaterialTheme.typography.bodySmall.fontSize else MaterialTheme.typography.bodyMedium.fontSize if (smallFontSize) MaterialTheme.typography.bodyMedium.fontSize else MaterialTheme.typography.bodyLarge.fontSize
CompositionLocalProvider { CompositionLocalProvider {
ProvideTextStyle( ProvideTextStyle(
value = TextStyle( value = TextStyle(
fontSize = fontSize, fontSize = fontSize,
lineHeight = fontSize * 1.4, lineHeight = fontSize * 1.3,
) )
) { ) {
RichText( RichText(

View file

@ -38,7 +38,7 @@ fun MessageBodyText(message: ChatMessageText) {
if (message.side == ChatSide.USER) { if (message.side == ChatSide.USER) {
Text( Text(
message.content, message.content,
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.Medium), style = MaterialTheme.typography.bodyLarge.copy(fontWeight = FontWeight.Medium),
color = Color.White, color = Color.White,
modifier = Modifier.padding(12.dp) modifier = Modifier.padding(12.dp)
) )

View file

@ -23,10 +23,18 @@ import android.graphics.Bitmap
import android.graphics.BitmapFactory import android.graphics.BitmapFactory
import android.graphics.Matrix import android.graphics.Matrix
import android.net.Uri import android.net.Uri
import android.util.Log
import androidx.activity.compose.rememberLauncherForActivityResult import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.PickVisualMediaRequest import androidx.activity.result.PickVisualMediaRequest
import androidx.activity.result.contract.ActivityResultContracts import androidx.activity.result.contract.ActivityResultContracts
import androidx.annotation.StringRes import androidx.annotation.StringRes
import androidx.camera.core.CameraControl
import androidx.camera.core.CameraSelector
import androidx.camera.core.ImageCapture
import androidx.camera.lifecycle.ProcessCameraProvider
import androidx.camera.lifecycle.awaitInstance
import androidx.camera.view.LifecycleCameraController
import androidx.camera.view.PreviewView
import androidx.compose.foundation.Image import androidx.compose.foundation.Image
import androidx.compose.foundation.background import androidx.compose.foundation.background
import androidx.compose.foundation.border import androidx.compose.foundation.border
@ -36,6 +44,7 @@ import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.offset import androidx.compose.foundation.layout.offset
@ -53,16 +62,21 @@ import androidx.compose.material.icons.rounded.Photo
import androidx.compose.material.icons.rounded.PhotoCamera import androidx.compose.material.icons.rounded.PhotoCamera
import androidx.compose.material.icons.rounded.PostAdd import androidx.compose.material.icons.rounded.PostAdd
import androidx.compose.material.icons.rounded.Stop import androidx.compose.material.icons.rounded.Stop
import androidx.compose.material3.Button
import androidx.compose.material3.DropdownMenu import androidx.compose.material3.DropdownMenu
import androidx.compose.material3.DropdownMenuItem import androidx.compose.material3.DropdownMenuItem
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton import androidx.compose.material3.IconButton
import androidx.compose.material3.IconButtonDefaults import androidx.compose.material3.IconButtonDefaults
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ModalBottomSheet
import androidx.compose.material3.Text import androidx.compose.material3.Text
import androidx.compose.material3.TextField import androidx.compose.material3.TextField
import androidx.compose.material3.TextFieldDefaults import androidx.compose.material3.TextFieldDefaults
import androidx.compose.material3.rememberModalBottomSheetState
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.mutableStateOf
@ -73,18 +87,23 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip import androidx.compose.ui.draw.clip
import androidx.compose.ui.draw.shadow import androidx.compose.ui.draw.shadow
import androidx.compose.ui.focus.focusModifier
import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.asImageBitmap import androidx.compose.ui.graphics.asImageBitmap
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.stringResource import androidx.compose.ui.res.stringResource
import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.compose.ui.viewinterop.AndroidView
import androidx.core.content.ContextCompat import androidx.core.content.ContextCompat
import androidx.lifecycle.LifecycleOwner
import androidx.lifecycle.compose.LocalLifecycleOwner
import com.google.ai.edge.gallery.R import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.ui.common.createTempPictureUri import com.google.ai.edge.gallery.ui.common.createTempPictureUri
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.ai.edge.gallery.ui.theme.GalleryTheme import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import java.util.concurrent.Executors
/** /**
* Composable function to display a text input field for composing chat messages. * Composable function to display a text input field for composing chat messages.
@ -92,6 +111,7 @@ import com.google.ai.edge.gallery.ui.theme.GalleryTheme
* This function renders a row containing a text field for message input and a send button. * This function renders a row containing a text field for message input and a send button.
* It handles message composition, input validation, and sending messages. * It handles message composition, input validation, and sending messages.
*/ */
@OptIn(ExperimentalMaterial3Api::class)
@Composable @Composable
fun MessageInputText( fun MessageInputText(
modelManagerViewModel: ModelManagerViewModel, modelManagerViewModel: ModelManagerViewModel,
@ -114,6 +134,8 @@ fun MessageInputText(
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
var showAddContentMenu by remember { mutableStateOf(false) } var showAddContentMenu by remember { mutableStateOf(false) }
var showTextInputHistorySheet by remember { mutableStateOf(false) } var showTextInputHistorySheet by remember { mutableStateOf(false) }
var showCameraCaptureBottomSheet by remember { mutableStateOf(false) }
var cameraCaptureSheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
var tempPhotoUri by remember { mutableStateOf(value = Uri.EMPTY) } var tempPhotoUri by remember { mutableStateOf(value = Uri.EMPTY) }
var pickedImages by remember { mutableStateOf<List<Bitmap>>(listOf()) } var pickedImages by remember { mutableStateOf<List<Bitmap>>(listOf()) }
val updatePickedImages: (Bitmap) -> Unit = { bitmap -> val updatePickedImages: (Bitmap) -> Unit = { bitmap ->
@ -145,6 +167,7 @@ fun MessageInputText(
if (permissionGranted) { if (permissionGranted) {
showAddContentMenu = false showAddContentMenu = false
tempPhotoUri = context.createTempPictureUri() tempPhotoUri = context.createTempPictureUri()
// showCameraCaptureBottomSheet = true
cameraLauncher.launch(tempPhotoUri) cameraLauncher.launch(tempPhotoUri)
} }
} }
@ -241,6 +264,7 @@ fun MessageInputText(
) -> { ) -> {
showAddContentMenu = false showAddContentMenu = false
tempPhotoUri = context.createTempPictureUri() tempPhotoUri = context.createTempPictureUri()
// showCameraCaptureBottomSheet = true
cameraLauncher.launch(tempPhotoUri) cameraLauncher.launch(tempPhotoUri)
} }
@ -313,7 +337,7 @@ fun MessageInputText(
disabledIndicatorColor = Color.Transparent, disabledIndicatorColor = Color.Transparent,
disabledContainerColor = Color.Transparent, disabledContainerColor = Color.Transparent,
), ),
textStyle = MaterialTheme.typography.bodyMedium, textStyle = MaterialTheme.typography.bodyLarge,
modifier = Modifier modifier = Modifier
.weight(1f) .weight(1f)
.padding(start = 36.dp), .padding(start = 36.dp),
@ -379,6 +403,90 @@ fun MessageInputText(
modelManagerViewModel.clearTextInputHistory() modelManagerViewModel.clearTextInputHistory()
}) })
} }
if (showCameraCaptureBottomSheet) {
ModalBottomSheet(
sheetState = cameraCaptureSheetState,
onDismissRequest = { showCameraCaptureBottomSheet = false }) {
val lifecycleOwner = LocalLifecycleOwner.current
val previewUseCase = remember { androidx.camera.core.Preview.Builder().build() }
val imageCaptureUseCase = remember { ImageCapture.Builder().build() }
var cameraProvider by remember { mutableStateOf<ProcessCameraProvider?>(null) }
var cameraControl by remember { mutableStateOf<CameraControl?>(null) }
val localContext = LocalContext.current
val executor = remember { Executors.newSingleThreadExecutor() }
val capturedImageUri = remember { mutableStateOf<Uri?>(null) }
fun rebindCameraProvider() {
cameraProvider?.let { cameraProvider ->
val cameraSelector = CameraSelector.Builder()
.requireLensFacing(CameraSelector.LENS_FACING_FRONT)
.build()
cameraProvider.unbindAll()
val camera = cameraProvider.bindToLifecycle(
lifecycleOwner = lifecycleOwner,
cameraSelector = cameraSelector,
previewUseCase,
imageCaptureUseCase
)
cameraControl = camera.cameraControl
}
}
LaunchedEffect(Unit) {
cameraProvider = ProcessCameraProvider.awaitInstance(localContext)
rebindCameraProvider()
}
// val cameraController = remember {
// LifecycleCameraController(context).apply {
// bindToLifecycle(lifecycleOwner)
// }
// }
Box(modifier = Modifier.fillMaxSize()) {
// PreviewView for the camera feed.
AndroidView(
modifier = Modifier.fillMaxSize(),
factory = { ctx ->
PreviewView(context).also {
previewUseCase.surfaceProvider = it.surfaceProvider
rebindCameraProvider()
}
// PreviewView(ctx).apply {
// scaleType = PreviewView.ScaleType.FILL_START
// implementationMode = PreviewView.ImplementationMode.COMPATIBLE
// controller = cameraController // Attach the lifecycle-aware camera controller.
// }
},
// onRelease = {
// // Called when the PreviewView is removed from the composable hierarchy
// cameraController.unbind() // Unbinds the camera to free up resources
// }
)
// Button that triggers the image capture process
IconButton(
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.tertiaryContainer,
),
modifier = Modifier
.align(Alignment.BottomCenter)
.padding(bottom = 32.dp)
.size(64.dp),
onClick = {
},
) {
Icon(
Icons.Rounded.PhotoCamera,
contentDescription = "",
tint = MaterialTheme.colorScheme.onTertiaryContainer,
modifier = Modifier.size(36.dp)
)
}
}
}
}
} }
private fun handleImageSelected( private fun handleImageSelected(

View file

@ -17,7 +17,6 @@
package com.google.ai.edge.gallery.ui.common.chat package com.google.ai.edge.gallery.ui.common.chat
import android.graphics.Bitmap import android.graphics.Bitmap
import androidx.annotation.StringRes
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Row
@ -38,7 +37,6 @@ import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.R import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.ui.theme.GalleryTheme import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.bodySmallNarrow import com.google.ai.edge.gallery.ui.theme.bodySmallNarrow
import com.google.ai.edge.gallery.ui.theme.bodySmallSemiBold
data class MessageLayoutConfig( data class MessageLayoutConfig(
val horizontalArrangement: Arrangement.Horizontal, val horizontalArrangement: Arrangement.Horizontal,
@ -56,7 +54,9 @@ data class MessageLayoutConfig(
*/ */
@Composable @Composable
fun MessageSender( fun MessageSender(
message: ChatMessage, @StringRes agentNameRes: Int, imageHistoryCurIndex: Int = 0 message: ChatMessage,
agentName: String = "",
imageHistoryCurIndex: Int = 0
) { ) {
// No user label for system messages. // No user label for system messages.
if (message.side == ChatSide.SYSTEM) { if (message.side == ChatSide.SYSTEM) {
@ -64,7 +64,7 @@ fun MessageSender(
} }
val (horizontalArrangement, modifier, userLabel, rightSideLabel) = getMessageLayoutConfig( val (horizontalArrangement, modifier, userLabel, rightSideLabel) = getMessageLayoutConfig(
message = message, agentNameRes = agentNameRes, imageHistoryCurIndex = imageHistoryCurIndex message = message, agentName = agentName, imageHistoryCurIndex = imageHistoryCurIndex
) )
Row( Row(
@ -76,7 +76,7 @@ fun MessageSender(
// Sender label. // Sender label.
Text( Text(
userLabel, userLabel,
style = bodySmallSemiBold, style = MaterialTheme.typography.titleSmall,
) )
when (message) { when (message) {
@ -152,7 +152,7 @@ fun MessageSender(
@Composable @Composable
private fun getMessageLayoutConfig( private fun getMessageLayoutConfig(
message: ChatMessage, message: ChatMessage,
@StringRes agentNameRes: Int, agentName: String,
imageHistoryCurIndex: Int, imageHistoryCurIndex: Int,
): MessageLayoutConfig { ): MessageLayoutConfig {
var userLabel = stringResource(R.string.chat_you) var userLabel = stringResource(R.string.chat_you)
@ -161,7 +161,7 @@ private fun getMessageLayoutConfig(
var modifier = Modifier.padding(bottom = 2.dp) var modifier = Modifier.padding(bottom = 2.dp)
if (message.side == ChatSide.AGENT) { if (message.side == ChatSide.AGENT) {
userLabel = stringResource(agentNameRes) userLabel = agentName
} }
when (message) { when (message) {
@ -207,12 +207,12 @@ fun MessageSenderPreview() {
// Agent message. // Agent message.
MessageSender( MessageSender(
message = ChatMessageText(content = "hello world", side = ChatSide.AGENT), message = ChatMessageText(content = "hello world", side = ChatSide.AGENT),
agentNameRes = R.string.chat_generic_agent_name agentName = stringResource(R.string.chat_generic_agent_name)
) )
// User message. // User message.
MessageSender( MessageSender(
message = ChatMessageText(content = "hello world", side = ChatSide.USER), message = ChatMessageText(content = "hello world", side = ChatSide.USER),
agentNameRes = R.string.chat_generic_agent_name agentName = stringResource(R.string.chat_generic_agent_name)
) )
// Benchmark during warmup. // Benchmark during warmup.
MessageSender( MessageSender(
@ -225,7 +225,8 @@ fun MessageSenderPreview() {
warmupTotal = 50, warmupTotal = 50,
iterationCurrent = 0, iterationCurrent = 0,
iterationTotal = 200 iterationTotal = 200
), agentNameRes = R.string.chat_generic_agent_name ),
agentName = stringResource(R.string.chat_generic_agent_name)
) )
// Benchmark during running. // Benchmark during running.
MessageSender( MessageSender(
@ -238,7 +239,8 @@ fun MessageSenderPreview() {
warmupTotal = 50, warmupTotal = 50,
iterationCurrent = 123, iterationCurrent = 123,
iterationTotal = 200 iterationTotal = 200
), agentNameRes = R.string.chat_generic_agent_name ),
agentName = stringResource(R.string.chat_generic_agent_name)
) )
// Image generation during running. // Image generation during running.
MessageSender( MessageSender(
@ -248,7 +250,7 @@ fun MessageSenderPreview() {
totalIterations = 10, totalIterations = 10,
ChatSide.AGENT ChatSide.AGENT
), ),
agentNameRes = R.string.chat_generic_agent_name, agentName = stringResource(R.string.chat_generic_agent_name),
imageHistoryCurIndex = 4, imageHistoryCurIndex = 4,
) )
} }

View file

@ -226,7 +226,7 @@ fun ModelDownloadingAnimation(
Text( Text(
sizeLabel, sizeLabel,
color = MaterialTheme.colorScheme.secondary, color = MaterialTheme.colorScheme.secondary,
style = labelSmallNarrow.copy(fontSize = 9.sp, lineHeight = 10.sp), style = MaterialTheme.typography.labelMedium,
textAlign = TextAlign.Center, textAlign = TextAlign.Center,
overflow = TextOverflow.Visible, overflow = TextOverflow.Visible,
modifier = Modifier modifier = Modifier
@ -266,7 +266,7 @@ fun ModelDownloadingAnimation(
"Feel free to switch apps or lock your device.\n" "Feel free to switch apps or lock your device.\n"
+ "The download will continue in the background.\n" + "The download will continue in the background.\n"
+ "We'll send a notification when it's done.", + "We'll send a notification when it's done.",
style = MaterialTheme.typography.bodyMedium, style = MaterialTheme.typography.bodyLarge,
textAlign = TextAlign.Center textAlign = TextAlign.Center
) )
} }

View file

@ -38,6 +38,8 @@ import com.google.ai.edge.gallery.data.ModelDownloadStatusType
import com.google.ai.edge.gallery.ui.theme.GalleryTheme import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.customColors import com.google.ai.edge.gallery.ui.theme.customColors
private val SIZE = 18.dp
/** /**
* Composable function to display an icon representing the download status of a model. * Composable function to display an icon representing the download status of a model.
*/ */
@ -53,7 +55,7 @@ fun StatusIcon(downloadStatus: ModelDownloadStatus?, modifier: Modifier = Modifi
Icons.AutoMirrored.Outlined.HelpOutline, Icons.AutoMirrored.Outlined.HelpOutline,
tint = Color(0xFFCCCCCC), tint = Color(0xFFCCCCCC),
contentDescription = "", contentDescription = "",
modifier = Modifier.size(14.dp) modifier = Modifier.size(SIZE)
) )
ModelDownloadStatusType.SUCCEEDED -> { ModelDownloadStatusType.SUCCEEDED -> {
@ -61,7 +63,7 @@ fun StatusIcon(downloadStatus: ModelDownloadStatus?, modifier: Modifier = Modifi
Icons.Filled.DownloadForOffline, Icons.Filled.DownloadForOffline,
tint = MaterialTheme.customColors.successColor, tint = MaterialTheme.customColors.successColor,
contentDescription = "", contentDescription = "",
modifier = Modifier.size(14.dp) modifier = Modifier.size(SIZE)
) )
} }
@ -69,13 +71,13 @@ fun StatusIcon(downloadStatus: ModelDownloadStatus?, modifier: Modifier = Modifi
Icons.Rounded.Error, Icons.Rounded.Error,
tint = Color(0xFFAA0000), tint = Color(0xFFAA0000),
contentDescription = "", contentDescription = "",
modifier = Modifier.size(14.dp) modifier = Modifier.size(SIZE)
) )
ModelDownloadStatusType.IN_PROGRESS -> Icon( ModelDownloadStatusType.IN_PROGRESS -> Icon(
Icons.Rounded.Downloading, Icons.Rounded.Downloading,
contentDescription = "", contentDescription = "",
modifier = Modifier.size(14.dp) modifier = Modifier.size(SIZE)
) )
else -> {} else -> {}

View file

@ -387,7 +387,7 @@ private fun TaskList(
Text( Text(
introText, introText,
textAlign = TextAlign.Center, textAlign = TextAlign.Center,
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold), style = MaterialTheme.typography.bodyLarge.copy(fontWeight = FontWeight.SemiBold),
modifier = Modifier.padding(bottom = 20.dp) modifier = Modifier.padding(bottom = 20.dp)
) )
} }

View file

@ -19,6 +19,7 @@ package com.google.ai.edge.gallery.ui.llmchat
import com.google.ai.edge.gallery.data.Accelerator import com.google.ai.edge.gallery.data.Accelerator
import com.google.ai.edge.gallery.data.Config import com.google.ai.edge.gallery.data.Config
import com.google.ai.edge.gallery.data.ConfigKey import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.LabelConfig
import com.google.ai.edge.gallery.data.NumberSliderConfig import com.google.ai.edge.gallery.data.NumberSliderConfig
import com.google.ai.edge.gallery.data.SegmentedButtonConfig import com.google.ai.edge.gallery.data.SegmentedButtonConfig
import com.google.ai.edge.gallery.data.ValueType import com.google.ai.edge.gallery.data.ValueType
@ -37,12 +38,9 @@ fun createLlmChatConfigs(
accelerators: List<Accelerator> = DEFAULT_ACCELERATORS, accelerators: List<Accelerator> = DEFAULT_ACCELERATORS,
): List<Config> { ): List<Config> {
return listOf( return listOf(
NumberSliderConfig( LabelConfig(
key = ConfigKey.MAX_TOKENS, key = ConfigKey.MAX_TOKENS,
sliderMin = 100f, defaultValue = "$defaultMaxToken",
sliderMax = 1024f,
defaultValue = defaultMaxToken.toFloat(),
valueType = ValueType.INT
), ),
NumberSliderConfig( NumberSliderConfig(
key = ConfigKey.TOPK, key = ConfigKey.TOPK,

View file

@ -84,27 +84,31 @@ object LlmChatModelHelper {
} }
fun resetSession(model: Model) { fun resetSession(model: Model) {
Log.d(TAG, "Resetting session for model '${model.name}'") try {
Log.d(TAG, "Resetting session for model '${model.name}'")
val instance = model.instance as LlmModelInstance? ?: return val instance = model.instance as LlmModelInstance? ?: return
val session = instance.session val session = instance.session
session.close() session.close()
val inference = instance.engine val inference = instance.engine
val topK = model.getIntConfigValue(key = ConfigKey.TOPK, defaultValue = DEFAULT_TOPK) val topK = model.getIntConfigValue(key = ConfigKey.TOPK, defaultValue = DEFAULT_TOPK)
val topP = model.getFloatConfigValue(key = ConfigKey.TOPP, defaultValue = DEFAULT_TOPP) val topP = model.getFloatConfigValue(key = ConfigKey.TOPP, defaultValue = DEFAULT_TOPP)
val temperature = val temperature =
model.getFloatConfigValue(key = ConfigKey.TEMPERATURE, defaultValue = DEFAULT_TEMPERATURE) model.getFloatConfigValue(key = ConfigKey.TEMPERATURE, defaultValue = DEFAULT_TEMPERATURE)
val newSession = LlmInferenceSession.createFromOptions( val newSession = LlmInferenceSession.createFromOptions(
inference, inference,
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP) LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
.setTemperature(temperature) .setTemperature(temperature)
.setGraphOptions( .setGraphOptions(
GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build() GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
).build() ).build()
) )
instance.session = newSession instance.session = newSession
Log.d(TAG, "Resetting done") Log.d(TAG, "Resetting done")
} catch (e: Exception) {
Log.d(TAG, "Failed to reset session", e)
}
} }
fun cleanUp(model: Model) { fun cleanUp(model: Model) {
@ -114,7 +118,7 @@ object LlmChatModelHelper {
val instance = model.instance as LlmModelInstance val instance = model.instance as LlmModelInstance
try { try {
instance.session.close() // This will also close the session. Do not call session.close manually.
instance.engine.close() instance.engine.close()
} catch (e: Exception) { } catch (e: Exception) {
// ignore // ignore

View file

@ -241,7 +241,7 @@ fun PromptTemplatesPanel(
disabledIndicatorColor = Color.Transparent, disabledIndicatorColor = Color.Transparent,
disabledContainerColor = Color.Transparent, disabledContainerColor = Color.Transparent,
), ),
textStyle = MaterialTheme.typography.bodyMedium, textStyle = MaterialTheme.typography.bodyLarge,
placeholder = { Text("Enter content") }, placeholder = { Text("Enter content") },
modifier = Modifier modifier = Modifier
.padding(bottom = 40.dp) .padding(bottom = 40.dp)

View file

@ -102,7 +102,7 @@ fun ModelList(
Text( Text(
task.description, task.description,
textAlign = TextAlign.Center, textAlign = TextAlign.Center,
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold), style = MaterialTheme.typography.bodyLarge.copy(fontWeight = FontWeight.SemiBold),
modifier = Modifier.fillMaxWidth() modifier = Modifier.fillMaxWidth()
) )
} }
@ -200,7 +200,7 @@ fun ClickableLink(
Text( Text(
text = annotatedText, text = annotatedText,
textAlign = TextAlign.Center, textAlign = TextAlign.Center,
style = MaterialTheme.typography.bodySmall, style = MaterialTheme.typography.bodyLarge,
modifier = Modifier modifier = Modifier
.padding(start = 6.dp) .padding(start = 6.dp)
.clickable { .clickable {

View file

@ -79,6 +79,9 @@ val bodySmallNarrow =
val bodySmallSemiBold = val bodySmallSemiBold =
baseline.bodySmall.copy(fontFamily = nunitoFontFamily, fontWeight = FontWeight.SemiBold) baseline.bodySmall.copy(fontFamily = nunitoFontFamily, fontWeight = FontWeight.SemiBold)
val bodyMediumSemiBold =
baseline.bodyMedium.copy(fontFamily = nunitoFontFamily, fontWeight = FontWeight.SemiBold)
val bodySmallMediumNarrow = val bodySmallMediumNarrow =
baseline.bodySmall.copy(fontFamily = nunitoFontFamily, letterSpacing = 0.0.sp, fontSize = 14.sp) baseline.bodySmall.copy(fontFamily = nunitoFontFamily, letterSpacing = 0.0.sp, fontSize = 14.sp)

View file

@ -30,7 +30,7 @@ pluginManagement {
dependencyResolutionManagement { dependencyResolutionManagement {
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS) repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
repositories { repositories {
mavenLocal() // mavenLocal()
google() google()
mavenCentral() mavenCentral()
} }