- Improve UX

- Fix a bug related to LLM inference engine cleanup.
This commit is contained in:
Jing Jin 2025-05-21 12:33:35 -07:00
parent b1aa6511dd
commit f9cab2f06d
21 changed files with 216 additions and 78 deletions

View file

@ -31,7 +31,7 @@ android {
minSdk = 26
targetSdk = 35
versionCode = 1
versionName = "1.0.0"
versionName = "1.0.1"
// Needed for HuggingFace auth workflows.
manifestPlaceholders["appAuthRedirectScheme"] = "com.google.ai.edge.gallery.oauth"

View file

@ -89,7 +89,7 @@ class NumberSliderConfig(
type = ConfigEditorType.NUMBER_SLIDER,
key = key,
defaultValue = defaultValue,
valueType = valueType
valueType = valueType,
)
/**

View file

@ -209,7 +209,7 @@ enum class ConfigKey(val label: String) {
SUPPORT_IMAGE("Support image"),
MAX_RESULT_COUNT("Max result count"),
USE_GPU("Use GPU"),
ACCELERATOR("Accelerator"),
ACCELERATOR("Choose accelerator"),
COMPATIBLE_ACCELERATORS("Compatible accelerators"),
WARM_UP_ITERATIONS("Warm up iterations"),
BENCHMARK_ITERATIONS("Benchmark iterations"),

View file

@ -16,12 +16,14 @@
package com.google.ai.edge.gallery.ui.common
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.rounded.ArrowBack
import androidx.compose.material.icons.rounded.MapsUgc
@ -42,7 +44,7 @@ import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.scale
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.vectorResource
@ -68,7 +70,7 @@ fun ModelPageAppBar(
modifier: Modifier = Modifier,
isResettingSession: Boolean = false,
onResetSessionClicked: (Model) -> Unit = {},
showResetSessionButton: Boolean = false,
canShowResetSessionButton: Boolean = false,
onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit = { _, _ -> },
) {
var showConfigDialog by remember { mutableStateOf(false) }
@ -96,7 +98,7 @@ fun ModelPageAppBar(
)
Text(
task.type.label,
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.SemiBold),
style = MaterialTheme.typography.titleMedium.copy(fontWeight = FontWeight.SemiBold),
color = getTaskIconColor(task = task)
)
}
@ -121,11 +123,12 @@ fun ModelPageAppBar(
},
// The config button for the model (if existed).
actions = {
val showConfigButton =
model.configs.isNotEmpty() && curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED
val downloadSucceeded = curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED
val showConfigButton = model.configs.isNotEmpty() && downloadSucceeded
val showResetSessionButton = canShowResetSessionButton && downloadSucceeded
Box(modifier = Modifier.size(42.dp), contentAlignment = Alignment.Center) {
var configButtonOffset = 0.dp
if (showConfigButton && showResetSessionButton) {
if (showConfigButton && canShowResetSessionButton) {
configButtonOffset = (-40).dp
}
val isModelInitializing =
@ -138,14 +141,14 @@ fun ModelPageAppBar(
},
enabled = enableConfigButton,
modifier = Modifier
.scale(0.75f)
.offset(x = configButtonOffset)
.alpha(if (!enableConfigButton) 0.5f else 1f)
) {
Icon(
imageVector = Icons.Rounded.Tune,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary
tint = MaterialTheme.colorScheme.primary,
modifier = Modifier.size(20.dp)
)
}
}
@ -164,14 +167,23 @@ fun ModelPageAppBar(
},
enabled = enableResetButton,
modifier = Modifier
.scale(0.75f)
.alpha(if (!enableResetButton) 0.5f else 1f)
) {
Icon(
imageVector = Icons.Rounded.MapsUgc,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary
)
Box(
modifier = Modifier
.size(32.dp)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.surfaceContainer),
contentAlignment = Alignment.Center
) {
Icon(
imageVector = Icons.Rounded.MapsUgc,
contentDescription = "",
tint = MaterialTheme.colorScheme.primary,
modifier = Modifier
.size(20.dp)
)
}
}
}
}

View file

@ -16,6 +16,7 @@
package com.google.ai.edge.gallery.ui.common
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
@ -26,6 +27,7 @@ import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.layout.width
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.CheckCircle
import androidx.compose.material.icons.outlined.CheckCircle
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
@ -35,6 +37,7 @@ import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.vectorResource
@ -83,6 +86,7 @@ fun ModelPicker(
// Model list.
for (model in task.models) {
val selected = model.name == modelManagerUiState.selectedModel.name
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.SpaceBetween,
@ -91,12 +95,16 @@ fun ModelPicker(
.clickable {
onModelSelected(model)
}
.background(if (selected) MaterialTheme.colorScheme.surfaceContainer else Color.Transparent)
.padding(horizontal = 16.dp, vertical = 8.dp),
) {
Spacer(modifier = Modifier.width(24.dp))
Column(modifier = Modifier.weight(1f)) {
Text(model.name, style = MaterialTheme.typography.bodyMedium)
Row(horizontalArrangement = Arrangement.spacedBy(4.dp)) {
Row(
horizontalArrangement = Arrangement.spacedBy(4.dp),
verticalAlignment = Alignment.CenterVertically
) {
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
Text(
model.sizeInBytes.humanReadableSize(),
@ -105,9 +113,9 @@ fun ModelPicker(
)
}
}
if (model.name == modelManagerUiState.selectedModel.name) {
if (selected) {
Icon(
Icons.Outlined.CheckCircle,
Icons.Filled.CheckCircle,
modifier = Modifier.size(16.dp),
contentDescription = ""
)

View file

@ -130,8 +130,8 @@ fun ModelPickerChipsPager(
showModelPicker = true
}
.padding(start = 8.dp, end = 2.dp)
.padding(vertical = 1.dp)) Inner@{
Box(contentAlignment = Alignment.Center, modifier = Modifier.size(18.dp)) {
.padding(vertical = 4.dp)) Inner@{
Box(contentAlignment = Alignment.Center, modifier = Modifier.size(21.dp)) {
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
this@Inner.AnimatedVisibility(
visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
@ -141,7 +141,7 @@ fun ModelPickerChipsPager(
// Circular progress indicator.
CircularProgressIndicator(
modifier = Modifier
.size(17.dp)
.size(24.dp)
.alpha(0.5f),
strokeWidth = 2.dp,
color = MaterialTheme.colorScheme.onSurfaceVariant
@ -150,7 +150,7 @@ fun ModelPickerChipsPager(
}
Text(
model.name,
style = MaterialTheme.typography.labelMedium,
style = MaterialTheme.typography.labelLarge,
modifier = Modifier
.padding(start = 4.dp)
.widthIn(0.dp, screenWidthDp - 250.dp),

View file

@ -80,6 +80,7 @@ import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.data.TaskType
@ -267,7 +268,7 @@ fun ChatPanel(
// Sender row.
MessageSender(
message = message,
agentNameRes = task.agentNameRes,
agentName = stringResource(task.agentNameRes),
imageHistoryCurIndex = imageHistoryCurIndex.intValue
)

View file

@ -174,7 +174,7 @@ fun ChatView(
task = task,
model = selectedModel,
modelManagerViewModel = modelManagerViewModel,
showResetSessionButton = true,
canShowResetSessionButton = true,
isResettingSession = uiState.isResettingSession,
inProgress = uiState.inProgress,
modelPreparing = uiState.preparing,

View file

@ -44,12 +44,12 @@ fun MarkdownText(
smallFontSize: Boolean = false
) {
val fontSize =
if (smallFontSize) MaterialTheme.typography.bodySmall.fontSize else MaterialTheme.typography.bodyMedium.fontSize
if (smallFontSize) MaterialTheme.typography.bodyMedium.fontSize else MaterialTheme.typography.bodyLarge.fontSize
CompositionLocalProvider {
ProvideTextStyle(
value = TextStyle(
fontSize = fontSize,
lineHeight = fontSize * 1.4,
lineHeight = fontSize * 1.3,
)
) {
RichText(

View file

@ -38,7 +38,7 @@ fun MessageBodyText(message: ChatMessageText) {
if (message.side == ChatSide.USER) {
Text(
message.content,
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.Medium),
style = MaterialTheme.typography.bodyLarge.copy(fontWeight = FontWeight.Medium),
color = Color.White,
modifier = Modifier.padding(12.dp)
)

View file

@ -23,10 +23,18 @@ import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.graphics.Matrix
import android.net.Uri
import android.util.Log
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.PickVisualMediaRequest
import androidx.activity.result.contract.ActivityResultContracts
import androidx.annotation.StringRes
import androidx.camera.core.CameraControl
import androidx.camera.core.CameraSelector
import androidx.camera.core.ImageCapture
import androidx.camera.lifecycle.ProcessCameraProvider
import androidx.camera.lifecycle.awaitInstance
import androidx.camera.view.LifecycleCameraController
import androidx.camera.view.PreviewView
import androidx.compose.foundation.Image
import androidx.compose.foundation.background
import androidx.compose.foundation.border
@ -36,6 +44,7 @@ import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.offset
@ -53,16 +62,21 @@ import androidx.compose.material.icons.rounded.Photo
import androidx.compose.material.icons.rounded.PhotoCamera
import androidx.compose.material.icons.rounded.PostAdd
import androidx.compose.material.icons.rounded.Stop
import androidx.compose.material3.Button
import androidx.compose.material3.DropdownMenu
import androidx.compose.material3.DropdownMenuItem
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.IconButtonDefaults
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.ModalBottomSheet
import androidx.compose.material3.Text
import androidx.compose.material3.TextField
import androidx.compose.material3.TextFieldDefaults
import androidx.compose.material3.rememberModalBottomSheetState
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
@ -73,18 +87,23 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.draw.clip
import androidx.compose.ui.draw.shadow
import androidx.compose.ui.focus.focusModifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.asImageBitmap
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import androidx.compose.ui.viewinterop.AndroidView
import androidx.core.content.ContextCompat
import androidx.lifecycle.LifecycleOwner
import androidx.lifecycle.compose.LocalLifecycleOwner
import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.ui.common.createTempPictureUri
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import java.util.concurrent.Executors
/**
* Composable function to display a text input field for composing chat messages.
@ -92,6 +111,7 @@ import com.google.ai.edge.gallery.ui.theme.GalleryTheme
* This function renders a row containing a text field for message input and a send button.
* It handles message composition, input validation, and sending messages.
*/
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun MessageInputText(
modelManagerViewModel: ModelManagerViewModel,
@ -114,6 +134,8 @@ fun MessageInputText(
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
var showAddContentMenu by remember { mutableStateOf(false) }
var showTextInputHistorySheet by remember { mutableStateOf(false) }
var showCameraCaptureBottomSheet by remember { mutableStateOf(false) }
var cameraCaptureSheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
var tempPhotoUri by remember { mutableStateOf(value = Uri.EMPTY) }
var pickedImages by remember { mutableStateOf<List<Bitmap>>(listOf()) }
val updatePickedImages: (Bitmap) -> Unit = { bitmap ->
@ -145,6 +167,7 @@ fun MessageInputText(
if (permissionGranted) {
showAddContentMenu = false
tempPhotoUri = context.createTempPictureUri()
// showCameraCaptureBottomSheet = true
cameraLauncher.launch(tempPhotoUri)
}
}
@ -241,6 +264,7 @@ fun MessageInputText(
) -> {
showAddContentMenu = false
tempPhotoUri = context.createTempPictureUri()
// showCameraCaptureBottomSheet = true
cameraLauncher.launch(tempPhotoUri)
}
@ -313,7 +337,7 @@ fun MessageInputText(
disabledIndicatorColor = Color.Transparent,
disabledContainerColor = Color.Transparent,
),
textStyle = MaterialTheme.typography.bodyMedium,
textStyle = MaterialTheme.typography.bodyLarge,
modifier = Modifier
.weight(1f)
.padding(start = 36.dp),
@ -379,6 +403,90 @@ fun MessageInputText(
modelManagerViewModel.clearTextInputHistory()
})
}
if (showCameraCaptureBottomSheet) {
ModalBottomSheet(
sheetState = cameraCaptureSheetState,
onDismissRequest = { showCameraCaptureBottomSheet = false }) {
val lifecycleOwner = LocalLifecycleOwner.current
val previewUseCase = remember { androidx.camera.core.Preview.Builder().build() }
val imageCaptureUseCase = remember { ImageCapture.Builder().build() }
var cameraProvider by remember { mutableStateOf<ProcessCameraProvider?>(null) }
var cameraControl by remember { mutableStateOf<CameraControl?>(null) }
val localContext = LocalContext.current
val executor = remember { Executors.newSingleThreadExecutor() }
val capturedImageUri = remember { mutableStateOf<Uri?>(null) }
fun rebindCameraProvider() {
cameraProvider?.let { cameraProvider ->
val cameraSelector = CameraSelector.Builder()
.requireLensFacing(CameraSelector.LENS_FACING_FRONT)
.build()
cameraProvider.unbindAll()
val camera = cameraProvider.bindToLifecycle(
lifecycleOwner = lifecycleOwner,
cameraSelector = cameraSelector,
previewUseCase,
imageCaptureUseCase
)
cameraControl = camera.cameraControl
}
}
LaunchedEffect(Unit) {
cameraProvider = ProcessCameraProvider.awaitInstance(localContext)
rebindCameraProvider()
}
// val cameraController = remember {
// LifecycleCameraController(context).apply {
// bindToLifecycle(lifecycleOwner)
// }
// }
Box(modifier = Modifier.fillMaxSize()) {
// PreviewView for the camera feed.
AndroidView(
modifier = Modifier.fillMaxSize(),
factory = { ctx ->
PreviewView(context).also {
previewUseCase.surfaceProvider = it.surfaceProvider
rebindCameraProvider()
}
// PreviewView(ctx).apply {
// scaleType = PreviewView.ScaleType.FILL_START
// implementationMode = PreviewView.ImplementationMode.COMPATIBLE
// controller = cameraController // Attach the lifecycle-aware camera controller.
// }
},
// onRelease = {
// // Called when the PreviewView is removed from the composable hierarchy
// cameraController.unbind() // Unbinds the camera to free up resources
// }
)
// Button that triggers the image capture process
IconButton(
colors = IconButtonDefaults.iconButtonColors(
containerColor = MaterialTheme.colorScheme.tertiaryContainer,
),
modifier = Modifier
.align(Alignment.BottomCenter)
.padding(bottom = 32.dp)
.size(64.dp),
onClick = {
},
) {
Icon(
Icons.Rounded.PhotoCamera,
contentDescription = "",
tint = MaterialTheme.colorScheme.onTertiaryContainer,
modifier = Modifier.size(36.dp)
)
}
}
}
}
}
private fun handleImageSelected(

View file

@ -17,7 +17,6 @@
package com.google.ai.edge.gallery.ui.common.chat
import android.graphics.Bitmap
import androidx.annotation.StringRes
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
@ -38,7 +37,6 @@ import androidx.compose.ui.unit.dp
import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.bodySmallNarrow
import com.google.ai.edge.gallery.ui.theme.bodySmallSemiBold
data class MessageLayoutConfig(
val horizontalArrangement: Arrangement.Horizontal,
@ -56,7 +54,9 @@ data class MessageLayoutConfig(
*/
@Composable
fun MessageSender(
message: ChatMessage, @StringRes agentNameRes: Int, imageHistoryCurIndex: Int = 0
message: ChatMessage,
agentName: String = "",
imageHistoryCurIndex: Int = 0
) {
// No user label for system messages.
if (message.side == ChatSide.SYSTEM) {
@ -64,7 +64,7 @@ fun MessageSender(
}
val (horizontalArrangement, modifier, userLabel, rightSideLabel) = getMessageLayoutConfig(
message = message, agentNameRes = agentNameRes, imageHistoryCurIndex = imageHistoryCurIndex
message = message, agentName = agentName, imageHistoryCurIndex = imageHistoryCurIndex
)
Row(
@ -76,7 +76,7 @@ fun MessageSender(
// Sender label.
Text(
userLabel,
style = bodySmallSemiBold,
style = MaterialTheme.typography.titleSmall,
)
when (message) {
@ -152,7 +152,7 @@ fun MessageSender(
@Composable
private fun getMessageLayoutConfig(
message: ChatMessage,
@StringRes agentNameRes: Int,
agentName: String,
imageHistoryCurIndex: Int,
): MessageLayoutConfig {
var userLabel = stringResource(R.string.chat_you)
@ -161,7 +161,7 @@ private fun getMessageLayoutConfig(
var modifier = Modifier.padding(bottom = 2.dp)
if (message.side == ChatSide.AGENT) {
userLabel = stringResource(agentNameRes)
userLabel = agentName
}
when (message) {
@ -207,12 +207,12 @@ fun MessageSenderPreview() {
// Agent message.
MessageSender(
message = ChatMessageText(content = "hello world", side = ChatSide.AGENT),
agentNameRes = R.string.chat_generic_agent_name
agentName = stringResource(R.string.chat_generic_agent_name)
)
// User message.
MessageSender(
message = ChatMessageText(content = "hello world", side = ChatSide.USER),
agentNameRes = R.string.chat_generic_agent_name
agentName = stringResource(R.string.chat_generic_agent_name)
)
// Benchmark during warmup.
MessageSender(
@ -225,7 +225,8 @@ fun MessageSenderPreview() {
warmupTotal = 50,
iterationCurrent = 0,
iterationTotal = 200
), agentNameRes = R.string.chat_generic_agent_name
),
agentName = stringResource(R.string.chat_generic_agent_name)
)
// Benchmark during running.
MessageSender(
@ -238,7 +239,8 @@ fun MessageSenderPreview() {
warmupTotal = 50,
iterationCurrent = 123,
iterationTotal = 200
), agentNameRes = R.string.chat_generic_agent_name
),
agentName = stringResource(R.string.chat_generic_agent_name)
)
// Image generation during running.
MessageSender(
@ -248,7 +250,7 @@ fun MessageSenderPreview() {
totalIterations = 10,
ChatSide.AGENT
),
agentNameRes = R.string.chat_generic_agent_name,
agentName = stringResource(R.string.chat_generic_agent_name),
imageHistoryCurIndex = 4,
)
}

View file

@ -226,7 +226,7 @@ fun ModelDownloadingAnimation(
Text(
sizeLabel,
color = MaterialTheme.colorScheme.secondary,
style = labelSmallNarrow.copy(fontSize = 9.sp, lineHeight = 10.sp),
style = MaterialTheme.typography.labelMedium,
textAlign = TextAlign.Center,
overflow = TextOverflow.Visible,
modifier = Modifier
@ -266,7 +266,7 @@ fun ModelDownloadingAnimation(
"Feel free to switch apps or lock your device.\n"
+ "The download will continue in the background.\n"
+ "We'll send a notification when it's done.",
style = MaterialTheme.typography.bodyMedium,
style = MaterialTheme.typography.bodyLarge,
textAlign = TextAlign.Center
)
}

View file

@ -38,6 +38,8 @@ import com.google.ai.edge.gallery.data.ModelDownloadStatusType
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.ai.edge.gallery.ui.theme.customColors
private val SIZE = 18.dp
/**
* Composable function to display an icon representing the download status of a model.
*/
@ -53,7 +55,7 @@ fun StatusIcon(downloadStatus: ModelDownloadStatus?, modifier: Modifier = Modifi
Icons.AutoMirrored.Outlined.HelpOutline,
tint = Color(0xFFCCCCCC),
contentDescription = "",
modifier = Modifier.size(14.dp)
modifier = Modifier.size(SIZE)
)
ModelDownloadStatusType.SUCCEEDED -> {
@ -61,7 +63,7 @@ fun StatusIcon(downloadStatus: ModelDownloadStatus?, modifier: Modifier = Modifi
Icons.Filled.DownloadForOffline,
tint = MaterialTheme.customColors.successColor,
contentDescription = "",
modifier = Modifier.size(14.dp)
modifier = Modifier.size(SIZE)
)
}
@ -69,13 +71,13 @@ fun StatusIcon(downloadStatus: ModelDownloadStatus?, modifier: Modifier = Modifi
Icons.Rounded.Error,
tint = Color(0xFFAA0000),
contentDescription = "",
modifier = Modifier.size(14.dp)
modifier = Modifier.size(SIZE)
)
ModelDownloadStatusType.IN_PROGRESS -> Icon(
Icons.Rounded.Downloading,
contentDescription = "",
modifier = Modifier.size(14.dp)
modifier = Modifier.size(SIZE)
)
else -> {}

View file

@ -387,7 +387,7 @@ private fun TaskList(
Text(
introText,
textAlign = TextAlign.Center,
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
style = MaterialTheme.typography.bodyLarge.copy(fontWeight = FontWeight.SemiBold),
modifier = Modifier.padding(bottom = 20.dp)
)
}

View file

@ -19,6 +19,7 @@ package com.google.ai.edge.gallery.ui.llmchat
import com.google.ai.edge.gallery.data.Accelerator
import com.google.ai.edge.gallery.data.Config
import com.google.ai.edge.gallery.data.ConfigKey
import com.google.ai.edge.gallery.data.LabelConfig
import com.google.ai.edge.gallery.data.NumberSliderConfig
import com.google.ai.edge.gallery.data.SegmentedButtonConfig
import com.google.ai.edge.gallery.data.ValueType
@ -37,12 +38,9 @@ fun createLlmChatConfigs(
accelerators: List<Accelerator> = DEFAULT_ACCELERATORS,
): List<Config> {
return listOf(
NumberSliderConfig(
LabelConfig(
key = ConfigKey.MAX_TOKENS,
sliderMin = 100f,
sliderMax = 1024f,
defaultValue = defaultMaxToken.toFloat(),
valueType = ValueType.INT
defaultValue = "$defaultMaxToken",
),
NumberSliderConfig(
key = ConfigKey.TOPK,

View file

@ -84,27 +84,31 @@ object LlmChatModelHelper {
}
fun resetSession(model: Model) {
Log.d(TAG, "Resetting session for model '${model.name}'")
try {
Log.d(TAG, "Resetting session for model '${model.name}'")
val instance = model.instance as LlmModelInstance? ?: return
val session = instance.session
session.close()
val instance = model.instance as LlmModelInstance? ?: return
val session = instance.session
session.close()
val inference = instance.engine
val topK = model.getIntConfigValue(key = ConfigKey.TOPK, defaultValue = DEFAULT_TOPK)
val topP = model.getFloatConfigValue(key = ConfigKey.TOPP, defaultValue = DEFAULT_TOPP)
val temperature =
model.getFloatConfigValue(key = ConfigKey.TEMPERATURE, defaultValue = DEFAULT_TEMPERATURE)
val newSession = LlmInferenceSession.createFromOptions(
inference,
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
.setTemperature(temperature)
.setGraphOptions(
GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
).build()
)
instance.session = newSession
Log.d(TAG, "Resetting done")
val inference = instance.engine
val topK = model.getIntConfigValue(key = ConfigKey.TOPK, defaultValue = DEFAULT_TOPK)
val topP = model.getFloatConfigValue(key = ConfigKey.TOPP, defaultValue = DEFAULT_TOPP)
val temperature =
model.getFloatConfigValue(key = ConfigKey.TEMPERATURE, defaultValue = DEFAULT_TEMPERATURE)
val newSession = LlmInferenceSession.createFromOptions(
inference,
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
.setTemperature(temperature)
.setGraphOptions(
GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
).build()
)
instance.session = newSession
Log.d(TAG, "Resetting done")
} catch (e: Exception) {
Log.d(TAG, "Failed to reset session", e)
}
}
fun cleanUp(model: Model) {
@ -114,7 +118,7 @@ object LlmChatModelHelper {
val instance = model.instance as LlmModelInstance
try {
instance.session.close()
// This will also close the session. Do not call session.close manually.
instance.engine.close()
} catch (e: Exception) {
// ignore

View file

@ -241,7 +241,7 @@ fun PromptTemplatesPanel(
disabledIndicatorColor = Color.Transparent,
disabledContainerColor = Color.Transparent,
),
textStyle = MaterialTheme.typography.bodyMedium,
textStyle = MaterialTheme.typography.bodyLarge,
placeholder = { Text("Enter content") },
modifier = Modifier
.padding(bottom = 40.dp)

View file

@ -102,7 +102,7 @@ fun ModelList(
Text(
task.description,
textAlign = TextAlign.Center,
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
style = MaterialTheme.typography.bodyLarge.copy(fontWeight = FontWeight.SemiBold),
modifier = Modifier.fillMaxWidth()
)
}
@ -200,7 +200,7 @@ fun ClickableLink(
Text(
text = annotatedText,
textAlign = TextAlign.Center,
style = MaterialTheme.typography.bodySmall,
style = MaterialTheme.typography.bodyLarge,
modifier = Modifier
.padding(start = 6.dp)
.clickable {

View file

@ -79,6 +79,9 @@ val bodySmallNarrow =
val bodySmallSemiBold =
baseline.bodySmall.copy(fontFamily = nunitoFontFamily, fontWeight = FontWeight.SemiBold)
val bodyMediumSemiBold =
baseline.bodyMedium.copy(fontFamily = nunitoFontFamily, fontWeight = FontWeight.SemiBold)
val bodySmallMediumNarrow =
baseline.bodySmall.copy(fontFamily = nunitoFontFamily, letterSpacing = 0.0.sp, fontSize = 14.sp)

View file

@ -30,7 +30,7 @@ pluginManagement {
dependencyResolutionManagement {
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
repositories {
mavenLocal()
// mavenLocal()
google()
mavenCentral()
}