mirror of
https://github.com/google-ai-edge/gallery.git
synced 2025-07-15 02:36:43 -04:00
- Improve UX
- Fix a bug related to LLM inference engine cleanup.
This commit is contained in:
parent
b1aa6511dd
commit
f9cab2f06d
21 changed files with 216 additions and 78 deletions
|
@ -31,7 +31,7 @@ android {
|
|||
minSdk = 26
|
||||
targetSdk = 35
|
||||
versionCode = 1
|
||||
versionName = "1.0.0"
|
||||
versionName = "1.0.1"
|
||||
|
||||
// Needed for HuggingFace auth workflows.
|
||||
manifestPlaceholders["appAuthRedirectScheme"] = "com.google.ai.edge.gallery.oauth"
|
||||
|
|
|
@ -89,7 +89,7 @@ class NumberSliderConfig(
|
|||
type = ConfigEditorType.NUMBER_SLIDER,
|
||||
key = key,
|
||||
defaultValue = defaultValue,
|
||||
valueType = valueType
|
||||
valueType = valueType,
|
||||
)
|
||||
|
||||
/**
|
||||
|
|
|
@ -209,7 +209,7 @@ enum class ConfigKey(val label: String) {
|
|||
SUPPORT_IMAGE("Support image"),
|
||||
MAX_RESULT_COUNT("Max result count"),
|
||||
USE_GPU("Use GPU"),
|
||||
ACCELERATOR("Accelerator"),
|
||||
ACCELERATOR("Choose accelerator"),
|
||||
COMPATIBLE_ACCELERATORS("Compatible accelerators"),
|
||||
WARM_UP_ITERATIONS("Warm up iterations"),
|
||||
BENCHMARK_ITERATIONS("Benchmark iterations"),
|
||||
|
|
|
@ -16,12 +16,14 @@
|
|||
|
||||
package com.google.ai.edge.gallery.ui.common
|
||||
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.offset
|
||||
import androidx.compose.foundation.layout.size
|
||||
import androidx.compose.foundation.shape.CircleShape
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.automirrored.rounded.ArrowBack
|
||||
import androidx.compose.material.icons.rounded.MapsUgc
|
||||
|
@ -42,7 +44,7 @@ import androidx.compose.runtime.setValue
|
|||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.alpha
|
||||
import androidx.compose.ui.draw.scale
|
||||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.graphics.vector.ImageVector
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.res.vectorResource
|
||||
|
@ -68,7 +70,7 @@ fun ModelPageAppBar(
|
|||
modifier: Modifier = Modifier,
|
||||
isResettingSession: Boolean = false,
|
||||
onResetSessionClicked: (Model) -> Unit = {},
|
||||
showResetSessionButton: Boolean = false,
|
||||
canShowResetSessionButton: Boolean = false,
|
||||
onConfigChanged: (oldConfigValues: Map<String, Any>, newConfigValues: Map<String, Any>) -> Unit = { _, _ -> },
|
||||
) {
|
||||
var showConfigDialog by remember { mutableStateOf(false) }
|
||||
|
@ -96,7 +98,7 @@ fun ModelPageAppBar(
|
|||
)
|
||||
Text(
|
||||
task.type.label,
|
||||
style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.SemiBold),
|
||||
style = MaterialTheme.typography.titleMedium.copy(fontWeight = FontWeight.SemiBold),
|
||||
color = getTaskIconColor(task = task)
|
||||
)
|
||||
}
|
||||
|
@ -121,11 +123,12 @@ fun ModelPageAppBar(
|
|||
},
|
||||
// The config button for the model (if existed).
|
||||
actions = {
|
||||
val showConfigButton =
|
||||
model.configs.isNotEmpty() && curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED
|
||||
val downloadSucceeded = curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED
|
||||
val showConfigButton = model.configs.isNotEmpty() && downloadSucceeded
|
||||
val showResetSessionButton = canShowResetSessionButton && downloadSucceeded
|
||||
Box(modifier = Modifier.size(42.dp), contentAlignment = Alignment.Center) {
|
||||
var configButtonOffset = 0.dp
|
||||
if (showConfigButton && showResetSessionButton) {
|
||||
if (showConfigButton && canShowResetSessionButton) {
|
||||
configButtonOffset = (-40).dp
|
||||
}
|
||||
val isModelInitializing =
|
||||
|
@ -138,14 +141,14 @@ fun ModelPageAppBar(
|
|||
},
|
||||
enabled = enableConfigButton,
|
||||
modifier = Modifier
|
||||
.scale(0.75f)
|
||||
.offset(x = configButtonOffset)
|
||||
.alpha(if (!enableConfigButton) 0.5f else 1f)
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Rounded.Tune,
|
||||
contentDescription = "",
|
||||
tint = MaterialTheme.colorScheme.primary
|
||||
tint = MaterialTheme.colorScheme.primary,
|
||||
modifier = Modifier.size(20.dp)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -164,14 +167,23 @@ fun ModelPageAppBar(
|
|||
},
|
||||
enabled = enableResetButton,
|
||||
modifier = Modifier
|
||||
.scale(0.75f)
|
||||
.alpha(if (!enableResetButton) 0.5f else 1f)
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Rounded.MapsUgc,
|
||||
contentDescription = "",
|
||||
tint = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
Box(
|
||||
modifier = Modifier
|
||||
.size(32.dp)
|
||||
.clip(CircleShape)
|
||||
.background(MaterialTheme.colorScheme.surfaceContainer),
|
||||
contentAlignment = Alignment.Center
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Rounded.MapsUgc,
|
||||
contentDescription = "",
|
||||
tint = MaterialTheme.colorScheme.primary,
|
||||
modifier = Modifier
|
||||
.size(20.dp)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package com.google.ai.edge.gallery.ui.common
|
||||
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Column
|
||||
|
@ -26,6 +27,7 @@ import androidx.compose.foundation.layout.padding
|
|||
import androidx.compose.foundation.layout.size
|
||||
import androidx.compose.foundation.layout.width
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.filled.CheckCircle
|
||||
import androidx.compose.material.icons.outlined.CheckCircle
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
|
@ -35,6 +37,7 @@ import androidx.compose.runtime.collectAsState
|
|||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.graphics.vector.ImageVector
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.res.vectorResource
|
||||
|
@ -83,6 +86,7 @@ fun ModelPicker(
|
|||
|
||||
// Model list.
|
||||
for (model in task.models) {
|
||||
val selected = model.name == modelManagerUiState.selectedModel.name
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.SpaceBetween,
|
||||
|
@ -91,12 +95,16 @@ fun ModelPicker(
|
|||
.clickable {
|
||||
onModelSelected(model)
|
||||
}
|
||||
.background(if (selected) MaterialTheme.colorScheme.surfaceContainer else Color.Transparent)
|
||||
.padding(horizontal = 16.dp, vertical = 8.dp),
|
||||
) {
|
||||
Spacer(modifier = Modifier.width(24.dp))
|
||||
Column(modifier = Modifier.weight(1f)) {
|
||||
Text(model.name, style = MaterialTheme.typography.bodyMedium)
|
||||
Row(horizontalArrangement = Arrangement.spacedBy(4.dp)) {
|
||||
Row(
|
||||
horizontalArrangement = Arrangement.spacedBy(4.dp),
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
|
||||
Text(
|
||||
model.sizeInBytes.humanReadableSize(),
|
||||
|
@ -105,9 +113,9 @@ fun ModelPicker(
|
|||
)
|
||||
}
|
||||
}
|
||||
if (model.name == modelManagerUiState.selectedModel.name) {
|
||||
if (selected) {
|
||||
Icon(
|
||||
Icons.Outlined.CheckCircle,
|
||||
Icons.Filled.CheckCircle,
|
||||
modifier = Modifier.size(16.dp),
|
||||
contentDescription = ""
|
||||
)
|
||||
|
|
|
@ -130,8 +130,8 @@ fun ModelPickerChipsPager(
|
|||
showModelPicker = true
|
||||
}
|
||||
.padding(start = 8.dp, end = 2.dp)
|
||||
.padding(vertical = 1.dp)) Inner@{
|
||||
Box(contentAlignment = Alignment.Center, modifier = Modifier.size(18.dp)) {
|
||||
.padding(vertical = 4.dp)) Inner@{
|
||||
Box(contentAlignment = Alignment.Center, modifier = Modifier.size(21.dp)) {
|
||||
StatusIcon(downloadStatus = modelManagerUiState.modelDownloadStatus[model.name])
|
||||
this@Inner.AnimatedVisibility(
|
||||
visible = modelInitializationStatus?.status == ModelInitializationStatusType.INITIALIZING,
|
||||
|
@ -141,7 +141,7 @@ fun ModelPickerChipsPager(
|
|||
// Circular progress indicator.
|
||||
CircularProgressIndicator(
|
||||
modifier = Modifier
|
||||
.size(17.dp)
|
||||
.size(24.dp)
|
||||
.alpha(0.5f),
|
||||
strokeWidth = 2.dp,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
|
@ -150,7 +150,7 @@ fun ModelPickerChipsPager(
|
|||
}
|
||||
Text(
|
||||
model.name,
|
||||
style = MaterialTheme.typography.labelMedium,
|
||||
style = MaterialTheme.typography.labelLarge,
|
||||
modifier = Modifier
|
||||
.padding(start = 4.dp)
|
||||
.widthIn(0.dp, screenWidthDp - 250.dp),
|
||||
|
|
|
@ -80,6 +80,7 @@ import androidx.compose.ui.text.AnnotatedString
|
|||
import androidx.compose.ui.tooling.preview.Preview
|
||||
import androidx.compose.ui.unit.dp
|
||||
import com.google.ai.edge.gallery.R
|
||||
import com.google.ai.edge.gallery.data.ConfigKey
|
||||
import com.google.ai.edge.gallery.data.Model
|
||||
import com.google.ai.edge.gallery.data.Task
|
||||
import com.google.ai.edge.gallery.data.TaskType
|
||||
|
@ -267,7 +268,7 @@ fun ChatPanel(
|
|||
// Sender row.
|
||||
MessageSender(
|
||||
message = message,
|
||||
agentNameRes = task.agentNameRes,
|
||||
agentName = stringResource(task.agentNameRes),
|
||||
imageHistoryCurIndex = imageHistoryCurIndex.intValue
|
||||
)
|
||||
|
||||
|
|
|
@ -174,7 +174,7 @@ fun ChatView(
|
|||
task = task,
|
||||
model = selectedModel,
|
||||
modelManagerViewModel = modelManagerViewModel,
|
||||
showResetSessionButton = true,
|
||||
canShowResetSessionButton = true,
|
||||
isResettingSession = uiState.isResettingSession,
|
||||
inProgress = uiState.inProgress,
|
||||
modelPreparing = uiState.preparing,
|
||||
|
|
|
@ -44,12 +44,12 @@ fun MarkdownText(
|
|||
smallFontSize: Boolean = false
|
||||
) {
|
||||
val fontSize =
|
||||
if (smallFontSize) MaterialTheme.typography.bodySmall.fontSize else MaterialTheme.typography.bodyMedium.fontSize
|
||||
if (smallFontSize) MaterialTheme.typography.bodyMedium.fontSize else MaterialTheme.typography.bodyLarge.fontSize
|
||||
CompositionLocalProvider {
|
||||
ProvideTextStyle(
|
||||
value = TextStyle(
|
||||
fontSize = fontSize,
|
||||
lineHeight = fontSize * 1.4,
|
||||
lineHeight = fontSize * 1.3,
|
||||
)
|
||||
) {
|
||||
RichText(
|
||||
|
|
|
@ -38,7 +38,7 @@ fun MessageBodyText(message: ChatMessageText) {
|
|||
if (message.side == ChatSide.USER) {
|
||||
Text(
|
||||
message.content,
|
||||
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.Medium),
|
||||
style = MaterialTheme.typography.bodyLarge.copy(fontWeight = FontWeight.Medium),
|
||||
color = Color.White,
|
||||
modifier = Modifier.padding(12.dp)
|
||||
)
|
||||
|
|
|
@ -23,10 +23,18 @@ import android.graphics.Bitmap
|
|||
import android.graphics.BitmapFactory
|
||||
import android.graphics.Matrix
|
||||
import android.net.Uri
|
||||
import android.util.Log
|
||||
import androidx.activity.compose.rememberLauncherForActivityResult
|
||||
import androidx.activity.result.PickVisualMediaRequest
|
||||
import androidx.activity.result.contract.ActivityResultContracts
|
||||
import androidx.annotation.StringRes
|
||||
import androidx.camera.core.CameraControl
|
||||
import androidx.camera.core.CameraSelector
|
||||
import androidx.camera.core.ImageCapture
|
||||
import androidx.camera.lifecycle.ProcessCameraProvider
|
||||
import androidx.camera.lifecycle.awaitInstance
|
||||
import androidx.camera.view.LifecycleCameraController
|
||||
import androidx.camera.view.PreviewView
|
||||
import androidx.compose.foundation.Image
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.border
|
||||
|
@ -36,6 +44,7 @@ import androidx.compose.foundation.layout.Box
|
|||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.Spacer
|
||||
import androidx.compose.foundation.layout.fillMaxSize
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.height
|
||||
import androidx.compose.foundation.layout.offset
|
||||
|
@ -53,16 +62,21 @@ import androidx.compose.material.icons.rounded.Photo
|
|||
import androidx.compose.material.icons.rounded.PhotoCamera
|
||||
import androidx.compose.material.icons.rounded.PostAdd
|
||||
import androidx.compose.material.icons.rounded.Stop
|
||||
import androidx.compose.material3.Button
|
||||
import androidx.compose.material3.DropdownMenu
|
||||
import androidx.compose.material3.DropdownMenuItem
|
||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.IconButton
|
||||
import androidx.compose.material3.IconButtonDefaults
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.ModalBottomSheet
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.material3.TextField
|
||||
import androidx.compose.material3.TextFieldDefaults
|
||||
import androidx.compose.material3.rememberModalBottomSheetState
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
|
@ -73,18 +87,23 @@ import androidx.compose.ui.Modifier
|
|||
import androidx.compose.ui.draw.alpha
|
||||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.draw.shadow
|
||||
import androidx.compose.ui.focus.focusModifier
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.graphics.asImageBitmap
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.res.stringResource
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.viewinterop.AndroidView
|
||||
import androidx.core.content.ContextCompat
|
||||
import androidx.lifecycle.LifecycleOwner
|
||||
import androidx.lifecycle.compose.LocalLifecycleOwner
|
||||
import com.google.ai.edge.gallery.R
|
||||
import com.google.ai.edge.gallery.ui.common.createTempPictureUri
|
||||
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||
import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel
|
||||
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
|
||||
import java.util.concurrent.Executors
|
||||
|
||||
/**
|
||||
* Composable function to display a text input field for composing chat messages.
|
||||
|
@ -92,6 +111,7 @@ import com.google.ai.edge.gallery.ui.theme.GalleryTheme
|
|||
* This function renders a row containing a text field for message input and a send button.
|
||||
* It handles message composition, input validation, and sending messages.
|
||||
*/
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@Composable
|
||||
fun MessageInputText(
|
||||
modelManagerViewModel: ModelManagerViewModel,
|
||||
|
@ -114,6 +134,8 @@ fun MessageInputText(
|
|||
val modelManagerUiState by modelManagerViewModel.uiState.collectAsState()
|
||||
var showAddContentMenu by remember { mutableStateOf(false) }
|
||||
var showTextInputHistorySheet by remember { mutableStateOf(false) }
|
||||
var showCameraCaptureBottomSheet by remember { mutableStateOf(false) }
|
||||
var cameraCaptureSheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
|
||||
var tempPhotoUri by remember { mutableStateOf(value = Uri.EMPTY) }
|
||||
var pickedImages by remember { mutableStateOf<List<Bitmap>>(listOf()) }
|
||||
val updatePickedImages: (Bitmap) -> Unit = { bitmap ->
|
||||
|
@ -145,6 +167,7 @@ fun MessageInputText(
|
|||
if (permissionGranted) {
|
||||
showAddContentMenu = false
|
||||
tempPhotoUri = context.createTempPictureUri()
|
||||
// showCameraCaptureBottomSheet = true
|
||||
cameraLauncher.launch(tempPhotoUri)
|
||||
}
|
||||
}
|
||||
|
@ -241,6 +264,7 @@ fun MessageInputText(
|
|||
) -> {
|
||||
showAddContentMenu = false
|
||||
tempPhotoUri = context.createTempPictureUri()
|
||||
// showCameraCaptureBottomSheet = true
|
||||
cameraLauncher.launch(tempPhotoUri)
|
||||
}
|
||||
|
||||
|
@ -313,7 +337,7 @@ fun MessageInputText(
|
|||
disabledIndicatorColor = Color.Transparent,
|
||||
disabledContainerColor = Color.Transparent,
|
||||
),
|
||||
textStyle = MaterialTheme.typography.bodyMedium,
|
||||
textStyle = MaterialTheme.typography.bodyLarge,
|
||||
modifier = Modifier
|
||||
.weight(1f)
|
||||
.padding(start = 36.dp),
|
||||
|
@ -379,6 +403,90 @@ fun MessageInputText(
|
|||
modelManagerViewModel.clearTextInputHistory()
|
||||
})
|
||||
}
|
||||
|
||||
if (showCameraCaptureBottomSheet) {
|
||||
ModalBottomSheet(
|
||||
sheetState = cameraCaptureSheetState,
|
||||
onDismissRequest = { showCameraCaptureBottomSheet = false }) {
|
||||
val lifecycleOwner = LocalLifecycleOwner.current
|
||||
val previewUseCase = remember { androidx.camera.core.Preview.Builder().build() }
|
||||
val imageCaptureUseCase = remember { ImageCapture.Builder().build() }
|
||||
var cameraProvider by remember { mutableStateOf<ProcessCameraProvider?>(null) }
|
||||
var cameraControl by remember { mutableStateOf<CameraControl?>(null) }
|
||||
val localContext = LocalContext.current
|
||||
|
||||
val executor = remember { Executors.newSingleThreadExecutor() }
|
||||
val capturedImageUri = remember { mutableStateOf<Uri?>(null) }
|
||||
|
||||
fun rebindCameraProvider() {
|
||||
cameraProvider?.let { cameraProvider ->
|
||||
val cameraSelector = CameraSelector.Builder()
|
||||
.requireLensFacing(CameraSelector.LENS_FACING_FRONT)
|
||||
.build()
|
||||
cameraProvider.unbindAll()
|
||||
val camera = cameraProvider.bindToLifecycle(
|
||||
lifecycleOwner = lifecycleOwner,
|
||||
cameraSelector = cameraSelector,
|
||||
previewUseCase,
|
||||
imageCaptureUseCase
|
||||
)
|
||||
cameraControl = camera.cameraControl
|
||||
}
|
||||
}
|
||||
|
||||
LaunchedEffect(Unit) {
|
||||
cameraProvider = ProcessCameraProvider.awaitInstance(localContext)
|
||||
rebindCameraProvider()
|
||||
}
|
||||
|
||||
// val cameraController = remember {
|
||||
// LifecycleCameraController(context).apply {
|
||||
// bindToLifecycle(lifecycleOwner)
|
||||
// }
|
||||
// }
|
||||
|
||||
Box(modifier = Modifier.fillMaxSize()) {
|
||||
// PreviewView for the camera feed.
|
||||
AndroidView(
|
||||
modifier = Modifier.fillMaxSize(),
|
||||
factory = { ctx ->
|
||||
PreviewView(context).also {
|
||||
previewUseCase.surfaceProvider = it.surfaceProvider
|
||||
rebindCameraProvider()
|
||||
}
|
||||
// PreviewView(ctx).apply {
|
||||
// scaleType = PreviewView.ScaleType.FILL_START
|
||||
// implementationMode = PreviewView.ImplementationMode.COMPATIBLE
|
||||
// controller = cameraController // Attach the lifecycle-aware camera controller.
|
||||
// }
|
||||
},
|
||||
// onRelease = {
|
||||
// // Called when the PreviewView is removed from the composable hierarchy
|
||||
// cameraController.unbind() // Unbinds the camera to free up resources
|
||||
// }
|
||||
)
|
||||
// Button that triggers the image capture process
|
||||
IconButton(
|
||||
colors = IconButtonDefaults.iconButtonColors(
|
||||
containerColor = MaterialTheme.colorScheme.tertiaryContainer,
|
||||
),
|
||||
modifier = Modifier
|
||||
.align(Alignment.BottomCenter)
|
||||
.padding(bottom = 32.dp)
|
||||
.size(64.dp),
|
||||
onClick = {
|
||||
},
|
||||
) {
|
||||
Icon(
|
||||
Icons.Rounded.PhotoCamera,
|
||||
contentDescription = "",
|
||||
tint = MaterialTheme.colorScheme.onTertiaryContainer,
|
||||
modifier = Modifier.size(36.dp)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun handleImageSelected(
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
package com.google.ai.edge.gallery.ui.common.chat
|
||||
|
||||
import android.graphics.Bitmap
|
||||
import androidx.annotation.StringRes
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.Row
|
||||
|
@ -38,7 +37,6 @@ import androidx.compose.ui.unit.dp
|
|||
import com.google.ai.edge.gallery.R
|
||||
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
|
||||
import com.google.ai.edge.gallery.ui.theme.bodySmallNarrow
|
||||
import com.google.ai.edge.gallery.ui.theme.bodySmallSemiBold
|
||||
|
||||
data class MessageLayoutConfig(
|
||||
val horizontalArrangement: Arrangement.Horizontal,
|
||||
|
@ -56,7 +54,9 @@ data class MessageLayoutConfig(
|
|||
*/
|
||||
@Composable
|
||||
fun MessageSender(
|
||||
message: ChatMessage, @StringRes agentNameRes: Int, imageHistoryCurIndex: Int = 0
|
||||
message: ChatMessage,
|
||||
agentName: String = "",
|
||||
imageHistoryCurIndex: Int = 0
|
||||
) {
|
||||
// No user label for system messages.
|
||||
if (message.side == ChatSide.SYSTEM) {
|
||||
|
@ -64,7 +64,7 @@ fun MessageSender(
|
|||
}
|
||||
|
||||
val (horizontalArrangement, modifier, userLabel, rightSideLabel) = getMessageLayoutConfig(
|
||||
message = message, agentNameRes = agentNameRes, imageHistoryCurIndex = imageHistoryCurIndex
|
||||
message = message, agentName = agentName, imageHistoryCurIndex = imageHistoryCurIndex
|
||||
)
|
||||
|
||||
Row(
|
||||
|
@ -76,7 +76,7 @@ fun MessageSender(
|
|||
// Sender label.
|
||||
Text(
|
||||
userLabel,
|
||||
style = bodySmallSemiBold,
|
||||
style = MaterialTheme.typography.titleSmall,
|
||||
)
|
||||
|
||||
when (message) {
|
||||
|
@ -152,7 +152,7 @@ fun MessageSender(
|
|||
@Composable
|
||||
private fun getMessageLayoutConfig(
|
||||
message: ChatMessage,
|
||||
@StringRes agentNameRes: Int,
|
||||
agentName: String,
|
||||
imageHistoryCurIndex: Int,
|
||||
): MessageLayoutConfig {
|
||||
var userLabel = stringResource(R.string.chat_you)
|
||||
|
@ -161,7 +161,7 @@ private fun getMessageLayoutConfig(
|
|||
var modifier = Modifier.padding(bottom = 2.dp)
|
||||
|
||||
if (message.side == ChatSide.AGENT) {
|
||||
userLabel = stringResource(agentNameRes)
|
||||
userLabel = agentName
|
||||
}
|
||||
|
||||
when (message) {
|
||||
|
@ -207,12 +207,12 @@ fun MessageSenderPreview() {
|
|||
// Agent message.
|
||||
MessageSender(
|
||||
message = ChatMessageText(content = "hello world", side = ChatSide.AGENT),
|
||||
agentNameRes = R.string.chat_generic_agent_name
|
||||
agentName = stringResource(R.string.chat_generic_agent_name)
|
||||
)
|
||||
// User message.
|
||||
MessageSender(
|
||||
message = ChatMessageText(content = "hello world", side = ChatSide.USER),
|
||||
agentNameRes = R.string.chat_generic_agent_name
|
||||
agentName = stringResource(R.string.chat_generic_agent_name)
|
||||
)
|
||||
// Benchmark during warmup.
|
||||
MessageSender(
|
||||
|
@ -225,7 +225,8 @@ fun MessageSenderPreview() {
|
|||
warmupTotal = 50,
|
||||
iterationCurrent = 0,
|
||||
iterationTotal = 200
|
||||
), agentNameRes = R.string.chat_generic_agent_name
|
||||
),
|
||||
agentName = stringResource(R.string.chat_generic_agent_name)
|
||||
)
|
||||
// Benchmark during running.
|
||||
MessageSender(
|
||||
|
@ -238,7 +239,8 @@ fun MessageSenderPreview() {
|
|||
warmupTotal = 50,
|
||||
iterationCurrent = 123,
|
||||
iterationTotal = 200
|
||||
), agentNameRes = R.string.chat_generic_agent_name
|
||||
),
|
||||
agentName = stringResource(R.string.chat_generic_agent_name)
|
||||
)
|
||||
// Image generation during running.
|
||||
MessageSender(
|
||||
|
@ -248,7 +250,7 @@ fun MessageSenderPreview() {
|
|||
totalIterations = 10,
|
||||
ChatSide.AGENT
|
||||
),
|
||||
agentNameRes = R.string.chat_generic_agent_name,
|
||||
agentName = stringResource(R.string.chat_generic_agent_name),
|
||||
imageHistoryCurIndex = 4,
|
||||
)
|
||||
}
|
||||
|
|
|
@ -226,7 +226,7 @@ fun ModelDownloadingAnimation(
|
|||
Text(
|
||||
sizeLabel,
|
||||
color = MaterialTheme.colorScheme.secondary,
|
||||
style = labelSmallNarrow.copy(fontSize = 9.sp, lineHeight = 10.sp),
|
||||
style = MaterialTheme.typography.labelMedium,
|
||||
textAlign = TextAlign.Center,
|
||||
overflow = TextOverflow.Visible,
|
||||
modifier = Modifier
|
||||
|
@ -266,7 +266,7 @@ fun ModelDownloadingAnimation(
|
|||
"Feel free to switch apps or lock your device.\n"
|
||||
+ "The download will continue in the background.\n"
|
||||
+ "We'll send a notification when it's done.",
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
style = MaterialTheme.typography.bodyLarge,
|
||||
textAlign = TextAlign.Center
|
||||
)
|
||||
}
|
||||
|
|
|
@ -38,6 +38,8 @@ import com.google.ai.edge.gallery.data.ModelDownloadStatusType
|
|||
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
|
||||
import com.google.ai.edge.gallery.ui.theme.customColors
|
||||
|
||||
private val SIZE = 18.dp
|
||||
|
||||
/**
|
||||
* Composable function to display an icon representing the download status of a model.
|
||||
*/
|
||||
|
@ -53,7 +55,7 @@ fun StatusIcon(downloadStatus: ModelDownloadStatus?, modifier: Modifier = Modifi
|
|||
Icons.AutoMirrored.Outlined.HelpOutline,
|
||||
tint = Color(0xFFCCCCCC),
|
||||
contentDescription = "",
|
||||
modifier = Modifier.size(14.dp)
|
||||
modifier = Modifier.size(SIZE)
|
||||
)
|
||||
|
||||
ModelDownloadStatusType.SUCCEEDED -> {
|
||||
|
@ -61,7 +63,7 @@ fun StatusIcon(downloadStatus: ModelDownloadStatus?, modifier: Modifier = Modifi
|
|||
Icons.Filled.DownloadForOffline,
|
||||
tint = MaterialTheme.customColors.successColor,
|
||||
contentDescription = "",
|
||||
modifier = Modifier.size(14.dp)
|
||||
modifier = Modifier.size(SIZE)
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -69,13 +71,13 @@ fun StatusIcon(downloadStatus: ModelDownloadStatus?, modifier: Modifier = Modifi
|
|||
Icons.Rounded.Error,
|
||||
tint = Color(0xFFAA0000),
|
||||
contentDescription = "",
|
||||
modifier = Modifier.size(14.dp)
|
||||
modifier = Modifier.size(SIZE)
|
||||
)
|
||||
|
||||
ModelDownloadStatusType.IN_PROGRESS -> Icon(
|
||||
Icons.Rounded.Downloading,
|
||||
contentDescription = "",
|
||||
modifier = Modifier.size(14.dp)
|
||||
modifier = Modifier.size(SIZE)
|
||||
)
|
||||
|
||||
else -> {}
|
||||
|
|
|
@ -387,7 +387,7 @@ private fun TaskList(
|
|||
Text(
|
||||
introText,
|
||||
textAlign = TextAlign.Center,
|
||||
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
|
||||
style = MaterialTheme.typography.bodyLarge.copy(fontWeight = FontWeight.SemiBold),
|
||||
modifier = Modifier.padding(bottom = 20.dp)
|
||||
)
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ package com.google.ai.edge.gallery.ui.llmchat
|
|||
import com.google.ai.edge.gallery.data.Accelerator
|
||||
import com.google.ai.edge.gallery.data.Config
|
||||
import com.google.ai.edge.gallery.data.ConfigKey
|
||||
import com.google.ai.edge.gallery.data.LabelConfig
|
||||
import com.google.ai.edge.gallery.data.NumberSliderConfig
|
||||
import com.google.ai.edge.gallery.data.SegmentedButtonConfig
|
||||
import com.google.ai.edge.gallery.data.ValueType
|
||||
|
@ -37,12 +38,9 @@ fun createLlmChatConfigs(
|
|||
accelerators: List<Accelerator> = DEFAULT_ACCELERATORS,
|
||||
): List<Config> {
|
||||
return listOf(
|
||||
NumberSliderConfig(
|
||||
LabelConfig(
|
||||
key = ConfigKey.MAX_TOKENS,
|
||||
sliderMin = 100f,
|
||||
sliderMax = 1024f,
|
||||
defaultValue = defaultMaxToken.toFloat(),
|
||||
valueType = ValueType.INT
|
||||
defaultValue = "$defaultMaxToken",
|
||||
),
|
||||
NumberSliderConfig(
|
||||
key = ConfigKey.TOPK,
|
||||
|
|
|
@ -84,27 +84,31 @@ object LlmChatModelHelper {
|
|||
}
|
||||
|
||||
fun resetSession(model: Model) {
|
||||
Log.d(TAG, "Resetting session for model '${model.name}'")
|
||||
try {
|
||||
Log.d(TAG, "Resetting session for model '${model.name}'")
|
||||
|
||||
val instance = model.instance as LlmModelInstance? ?: return
|
||||
val session = instance.session
|
||||
session.close()
|
||||
val instance = model.instance as LlmModelInstance? ?: return
|
||||
val session = instance.session
|
||||
session.close()
|
||||
|
||||
val inference = instance.engine
|
||||
val topK = model.getIntConfigValue(key = ConfigKey.TOPK, defaultValue = DEFAULT_TOPK)
|
||||
val topP = model.getFloatConfigValue(key = ConfigKey.TOPP, defaultValue = DEFAULT_TOPP)
|
||||
val temperature =
|
||||
model.getFloatConfigValue(key = ConfigKey.TEMPERATURE, defaultValue = DEFAULT_TEMPERATURE)
|
||||
val newSession = LlmInferenceSession.createFromOptions(
|
||||
inference,
|
||||
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
|
||||
.setTemperature(temperature)
|
||||
.setGraphOptions(
|
||||
GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
|
||||
).build()
|
||||
)
|
||||
instance.session = newSession
|
||||
Log.d(TAG, "Resetting done")
|
||||
val inference = instance.engine
|
||||
val topK = model.getIntConfigValue(key = ConfigKey.TOPK, defaultValue = DEFAULT_TOPK)
|
||||
val topP = model.getFloatConfigValue(key = ConfigKey.TOPP, defaultValue = DEFAULT_TOPP)
|
||||
val temperature =
|
||||
model.getFloatConfigValue(key = ConfigKey.TEMPERATURE, defaultValue = DEFAULT_TEMPERATURE)
|
||||
val newSession = LlmInferenceSession.createFromOptions(
|
||||
inference,
|
||||
LlmInferenceSession.LlmInferenceSessionOptions.builder().setTopK(topK).setTopP(topP)
|
||||
.setTemperature(temperature)
|
||||
.setGraphOptions(
|
||||
GraphOptions.builder().setEnableVisionModality(model.llmSupportImage).build()
|
||||
).build()
|
||||
)
|
||||
instance.session = newSession
|
||||
Log.d(TAG, "Resetting done")
|
||||
} catch (e: Exception) {
|
||||
Log.d(TAG, "Failed to reset session", e)
|
||||
}
|
||||
}
|
||||
|
||||
fun cleanUp(model: Model) {
|
||||
|
@ -114,7 +118,7 @@ object LlmChatModelHelper {
|
|||
|
||||
val instance = model.instance as LlmModelInstance
|
||||
try {
|
||||
instance.session.close()
|
||||
// This will also close the session. Do not call session.close manually.
|
||||
instance.engine.close()
|
||||
} catch (e: Exception) {
|
||||
// ignore
|
||||
|
|
|
@ -241,7 +241,7 @@ fun PromptTemplatesPanel(
|
|||
disabledIndicatorColor = Color.Transparent,
|
||||
disabledContainerColor = Color.Transparent,
|
||||
),
|
||||
textStyle = MaterialTheme.typography.bodyMedium,
|
||||
textStyle = MaterialTheme.typography.bodyLarge,
|
||||
placeholder = { Text("Enter content") },
|
||||
modifier = Modifier
|
||||
.padding(bottom = 40.dp)
|
||||
|
|
|
@ -102,7 +102,7 @@ fun ModelList(
|
|||
Text(
|
||||
task.description,
|
||||
textAlign = TextAlign.Center,
|
||||
style = MaterialTheme.typography.bodyMedium.copy(fontWeight = FontWeight.SemiBold),
|
||||
style = MaterialTheme.typography.bodyLarge.copy(fontWeight = FontWeight.SemiBold),
|
||||
modifier = Modifier.fillMaxWidth()
|
||||
)
|
||||
}
|
||||
|
@ -200,7 +200,7 @@ fun ClickableLink(
|
|||
Text(
|
||||
text = annotatedText,
|
||||
textAlign = TextAlign.Center,
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
style = MaterialTheme.typography.bodyLarge,
|
||||
modifier = Modifier
|
||||
.padding(start = 6.dp)
|
||||
.clickable {
|
||||
|
|
|
@ -79,6 +79,9 @@ val bodySmallNarrow =
|
|||
val bodySmallSemiBold =
|
||||
baseline.bodySmall.copy(fontFamily = nunitoFontFamily, fontWeight = FontWeight.SemiBold)
|
||||
|
||||
val bodyMediumSemiBold =
|
||||
baseline.bodyMedium.copy(fontFamily = nunitoFontFamily, fontWeight = FontWeight.SemiBold)
|
||||
|
||||
val bodySmallMediumNarrow =
|
||||
baseline.bodySmall.copy(fontFamily = nunitoFontFamily, letterSpacing = 0.0.sp, fontSize = 14.sp)
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ pluginManagement {
|
|||
dependencyResolutionManagement {
|
||||
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
|
||||
repositories {
|
||||
mavenLocal()
|
||||
// mavenLocal()
|
||||
google()
|
||||
mavenCentral()
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue