From e7dda4b4ad154916a81b21ded18ff451360cd2c8 Mon Sep 17 00:00:00 2001 From: Jing Jin <8752427+jinjingforever@users.noreply.github.com> Date: Sun, 18 May 2025 17:26:31 -0700 Subject: [PATCH] Save the loaded model allowlist to a local file so that it can be read when the allowlist cannot be loaded from internet. Also improve the image clipping transitioning from full image back to the image thumbnail in chat ui. --- .../google/aiedge/gallery/ui/common/Utils.kt | 19 +++---- .../gallery/ui/common/chat/ChatPanel.kt | 11 ++-- .../gallery/ui/home/NewReleaseNotification.kt | 4 +- .../ui/modelmanager/ModelManagerViewModel.kt | 54 +++++++++++++++++-- 4 files changed, 70 insertions(+), 18 deletions(-) diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/Utils.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/Utils.kt index 665a073..50a2d27 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/Utils.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/Utils.kt @@ -66,6 +66,10 @@ interface LatencyProvider { private const val START_THINKING = "***Thinking...***" private const val DONE_THINKING = "***Done thinking***" +data class JsonObjAndTextContent( + val jsonObj: T, val textContent: String, +) + /** Format the bytes into a human-readable format. */ fun Long.humanReadableSize(si: Boolean = true, extraDecimalForGbAndAbove: Boolean = false): String { val bytes = this @@ -473,16 +477,14 @@ fun processTasks() { fun processLlmResponse(response: String): String { // Add "thinking" and "done thinking" around the thinking content. - var newContent = response - .replace("", "$START_THINKING\n") - .replace("", "\n$DONE_THINKING") + var newContent = + response.replace("", "$START_THINKING\n").replace("", "\n$DONE_THINKING") // Remove empty thinking content. val endThinkingIndex = newContent.indexOf(DONE_THINKING) if (endThinkingIndex >= 0) { val thinkingContent = - newContent.substring(0, endThinkingIndex + DONE_THINKING.length) - .replace(START_THINKING, "") + newContent.substring(0, endThinkingIndex + DONE_THINKING.length).replace(START_THINKING, "") .replace(DONE_THINKING, "") if (thinkingContent.isBlank()) { newContent = newContent.substring(endThinkingIndex + DONE_THINKING.length) @@ -495,7 +497,7 @@ fun processLlmResponse(response: String): String { } @OptIn(ExperimentalSerializationApi::class) -inline fun getJsonResponse(url: String): T? { +inline fun getJsonResponse(url: String): JsonObjAndTextContent? { try { val connection = URL(url).openConnection() as HttpURLConnection connection.requestMethod = "GET" @@ -514,14 +516,13 @@ inline fun getJsonResponse(url: String): T? { allowTrailingComma = true } val jsonObj = json.decodeFromString(response) - return jsonObj + return JsonObjAndTextContent(jsonObj = jsonObj, textContent = response) } else { Log.e("AGUtils", "HTTP error: $responseCode") } } catch (e: Exception) { Log.e( - "AGUtils", - "Error when getting json response: ${e.message}" + "AGUtils", "Error when getting json response: ${e.message}" ) e.printStackTrace() } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt index 5ce0f4b..3ae0043 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt @@ -342,7 +342,12 @@ fun ChatPanel( } .sharedElement( sharedContentState = rememberSharedContentState(key = "selected_image"), - animatedVisibilityScope = this@AnimatedContent + animatedVisibilityScope = this@AnimatedContent, + clipInOverlayDuringTransition = OverlayClip( + MessageBubbleShape( + radius = bubbleBorderRadius + ) + ) ), ) } @@ -378,7 +383,7 @@ fun ChatPanel( ) { LatencyText(message = message) // A button to show stats for the LLM message. - if ((task.type == TaskType.LLM_CHAT || task.type == TaskType.LLM_ASK_IMAGE) && message is ChatMessageText + if (task.type.id.startsWith("llm_") && message is ChatMessageText // This means we only want to show the action button when the message is done // generating, at which point the latency will be set. && message.latencyMs >= 0 @@ -565,7 +570,7 @@ fun ChatPanel( ) .sharedElement( sharedContentState = rememberSharedContentState(key = "selected_image"), - animatedVisibilityScope = this@AnimatedContent + animatedVisibilityScope = this@AnimatedContent, ), contentScale = ContentScale.Fit, ) diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/NewReleaseNotification.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/NewReleaseNotification.kt index 231a92b..05978dc 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/NewReleaseNotification.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/NewReleaseNotification.kt @@ -51,12 +51,12 @@ fun NewReleaseNotification() { val info = getJsonResponse("https://api.github.com/repos/$REPO/releases/latest") if (info != null) { val curRelease = BuildConfig.VERSION_NAME - val newRelease = info.tag_name + val newRelease = info.jsonObj.tag_name val isNewer = isNewerRelease(currentRelease = curRelease, newRelease = newRelease) Log.d(TAG, "curRelease: $curRelease, newRelease: $newRelease, isNewer: $isNewer") if (isNewer) { newReleaseVersion = newRelease - newReleaseUrl = info.html_url + newReleaseUrl = info.jsonObj.html_url } } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt index ef8b636..50d5a57 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt @@ -60,6 +60,8 @@ import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.update import kotlinx.coroutines.launch import kotlinx.coroutines.withContext +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.json.Json import net.openid.appauth.AuthorizationException import net.openid.appauth.AuthorizationRequest import net.openid.appauth.AuthorizationResponse @@ -73,6 +75,7 @@ private const val TAG = "AGModelManagerViewModel" private const val TEXT_INPUT_HISTORY_MAX_SIZE = 50 private const val MODEL_ALLOWLIST_URL = "https://raw.githubusercontent.com/google-ai-edge/gallery/refs/heads/main/model_allowlist.json" +private const val MODEL_ALLOWLIST_FILENAME = "model_allowlist.json" data class ModelInitializationStatus( val status: ModelInitializationStatusType, var error: String = "" @@ -146,6 +149,7 @@ data class PagerScrollState( * and cleaning up models. It also manages the UI state for model management, including the * list of tasks, models, download statuses, and initialization statuses. */ +@OptIn(ExperimentalSerializationApi::class) open class ModelManagerViewModel( private val downloadRepository: DownloadRepository, private val dataStoreRepository: DataStoreRepository, @@ -227,8 +231,7 @@ open class ModelManagerViewModel( dataStoreRepository.saveImportedModels(importedModels = importedModels) } val newUiState = uiState.value.copy( - modelDownloadStatus = curModelDownloadStatus, - tasks = uiState.value.tasks.toList() + modelDownloadStatus = curModelDownloadStatus, tasks = uiState.value.tasks.toList() ) _uiState.update { newUiState } } @@ -660,8 +663,17 @@ open class ModelManagerViewModel( viewModelScope.launch(Dispatchers.IO) { try { // Load model allowlist json. - val modelAllowlist: ModelAllowlist? = - getJsonResponse(url = MODEL_ALLOWLIST_URL) + Log.d(TAG, "Loading model allowlist from internet...") + val data = getJsonResponse(url = MODEL_ALLOWLIST_URL) + var modelAllowlist: ModelAllowlist? = data?.jsonObj + + if (modelAllowlist == null) { + Log.d(TAG, "Failed to load model allowlist from internet. Trying to load it from disk") + modelAllowlist = readModelAllowlistFromDisk() + } else { + Log.d(TAG, "Done: loading model allowlist from internet") + saveModelAllowlistToDisk(modelAllowlistContent = data?.textContent ?: "{}") + } if (modelAllowlist == null) { _uiState.update { uiState.value.copy(loadingModelAllowlistError = "Failed to load model list") } @@ -707,6 +719,40 @@ open class ModelManagerViewModel( } } + private fun saveModelAllowlistToDisk(modelAllowlistContent: String) { + try { + Log.d(TAG, "Saving model allowlist to disk...") + val file = File(externalFilesDir, MODEL_ALLOWLIST_FILENAME) + file.writeText(modelAllowlistContent) + Log.d(TAG, "Done: saving model allowlist to disk.") + } catch (e: Exception) { + Log.e(TAG, "failed to write model allowlist to disk", e) + } + } + + private fun readModelAllowlistFromDisk(): ModelAllowlist? { + try { + Log.d(TAG, "Reading model allowlist from disk...") + val file = File(externalFilesDir, MODEL_ALLOWLIST_FILENAME) + if (file.exists()) { + val content = file.readText() + Log.d(TAG, "Model allowlist content from local file: $content") + val json = Json { + // Handle potential extra fields + ignoreUnknownKeys = true + allowComments = true + allowTrailingComma = true + } + return json.decodeFromString(content) + } + } catch (e: Exception) { + Log.e(TAG, "failed to read model allowlist from disk", e) + return null + } + + return null + } + private fun isModelPartiallyDownloaded(model: Model): Boolean { return inProgressWorkInfos.find { it.modelName == model.name } != null }