diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/Utils.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/Utils.kt index 665a073..50a2d27 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/Utils.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/Utils.kt @@ -66,6 +66,10 @@ interface LatencyProvider { private const val START_THINKING = "***Thinking...***" private const val DONE_THINKING = "***Done thinking***" +data class JsonObjAndTextContent( + val jsonObj: T, val textContent: String, +) + /** Format the bytes into a human-readable format. */ fun Long.humanReadableSize(si: Boolean = true, extraDecimalForGbAndAbove: Boolean = false): String { val bytes = this @@ -473,16 +477,14 @@ fun processTasks() { fun processLlmResponse(response: String): String { // Add "thinking" and "done thinking" around the thinking content. - var newContent = response - .replace("", "$START_THINKING\n") - .replace("", "\n$DONE_THINKING") + var newContent = + response.replace("", "$START_THINKING\n").replace("", "\n$DONE_THINKING") // Remove empty thinking content. val endThinkingIndex = newContent.indexOf(DONE_THINKING) if (endThinkingIndex >= 0) { val thinkingContent = - newContent.substring(0, endThinkingIndex + DONE_THINKING.length) - .replace(START_THINKING, "") + newContent.substring(0, endThinkingIndex + DONE_THINKING.length).replace(START_THINKING, "") .replace(DONE_THINKING, "") if (thinkingContent.isBlank()) { newContent = newContent.substring(endThinkingIndex + DONE_THINKING.length) @@ -495,7 +497,7 @@ fun processLlmResponse(response: String): String { } @OptIn(ExperimentalSerializationApi::class) -inline fun getJsonResponse(url: String): T? { +inline fun getJsonResponse(url: String): JsonObjAndTextContent? { try { val connection = URL(url).openConnection() as HttpURLConnection connection.requestMethod = "GET" @@ -514,14 +516,13 @@ inline fun getJsonResponse(url: String): T? { allowTrailingComma = true } val jsonObj = json.decodeFromString(response) - return jsonObj + return JsonObjAndTextContent(jsonObj = jsonObj, textContent = response) } else { Log.e("AGUtils", "HTTP error: $responseCode") } } catch (e: Exception) { Log.e( - "AGUtils", - "Error when getting json response: ${e.message}" + "AGUtils", "Error when getting json response: ${e.message}" ) e.printStackTrace() } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt index 5ce0f4b..3ae0043 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/common/chat/ChatPanel.kt @@ -342,7 +342,12 @@ fun ChatPanel( } .sharedElement( sharedContentState = rememberSharedContentState(key = "selected_image"), - animatedVisibilityScope = this@AnimatedContent + animatedVisibilityScope = this@AnimatedContent, + clipInOverlayDuringTransition = OverlayClip( + MessageBubbleShape( + radius = bubbleBorderRadius + ) + ) ), ) } @@ -378,7 +383,7 @@ fun ChatPanel( ) { LatencyText(message = message) // A button to show stats for the LLM message. - if ((task.type == TaskType.LLM_CHAT || task.type == TaskType.LLM_ASK_IMAGE) && message is ChatMessageText + if (task.type.id.startsWith("llm_") && message is ChatMessageText // This means we only want to show the action button when the message is done // generating, at which point the latency will be set. && message.latencyMs >= 0 @@ -565,7 +570,7 @@ fun ChatPanel( ) .sharedElement( sharedContentState = rememberSharedContentState(key = "selected_image"), - animatedVisibilityScope = this@AnimatedContent + animatedVisibilityScope = this@AnimatedContent, ), contentScale = ContentScale.Fit, ) diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/NewReleaseNotification.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/NewReleaseNotification.kt index 231a92b..05978dc 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/NewReleaseNotification.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/home/NewReleaseNotification.kt @@ -51,12 +51,12 @@ fun NewReleaseNotification() { val info = getJsonResponse("https://api.github.com/repos/$REPO/releases/latest") if (info != null) { val curRelease = BuildConfig.VERSION_NAME - val newRelease = info.tag_name + val newRelease = info.jsonObj.tag_name val isNewer = isNewerRelease(currentRelease = curRelease, newRelease = newRelease) Log.d(TAG, "curRelease: $curRelease, newRelease: $newRelease, isNewer: $isNewer") if (isNewer) { newReleaseVersion = newRelease - newReleaseUrl = info.html_url + newReleaseUrl = info.jsonObj.html_url } } } diff --git a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt index ef8b636..50d5a57 100644 --- a/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt +++ b/Android/src/app/src/main/java/com/google/aiedge/gallery/ui/modelmanager/ModelManagerViewModel.kt @@ -60,6 +60,8 @@ import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.update import kotlinx.coroutines.launch import kotlinx.coroutines.withContext +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.json.Json import net.openid.appauth.AuthorizationException import net.openid.appauth.AuthorizationRequest import net.openid.appauth.AuthorizationResponse @@ -73,6 +75,7 @@ private const val TAG = "AGModelManagerViewModel" private const val TEXT_INPUT_HISTORY_MAX_SIZE = 50 private const val MODEL_ALLOWLIST_URL = "https://raw.githubusercontent.com/google-ai-edge/gallery/refs/heads/main/model_allowlist.json" +private const val MODEL_ALLOWLIST_FILENAME = "model_allowlist.json" data class ModelInitializationStatus( val status: ModelInitializationStatusType, var error: String = "" @@ -146,6 +149,7 @@ data class PagerScrollState( * and cleaning up models. It also manages the UI state for model management, including the * list of tasks, models, download statuses, and initialization statuses. */ +@OptIn(ExperimentalSerializationApi::class) open class ModelManagerViewModel( private val downloadRepository: DownloadRepository, private val dataStoreRepository: DataStoreRepository, @@ -227,8 +231,7 @@ open class ModelManagerViewModel( dataStoreRepository.saveImportedModels(importedModels = importedModels) } val newUiState = uiState.value.copy( - modelDownloadStatus = curModelDownloadStatus, - tasks = uiState.value.tasks.toList() + modelDownloadStatus = curModelDownloadStatus, tasks = uiState.value.tasks.toList() ) _uiState.update { newUiState } } @@ -660,8 +663,17 @@ open class ModelManagerViewModel( viewModelScope.launch(Dispatchers.IO) { try { // Load model allowlist json. - val modelAllowlist: ModelAllowlist? = - getJsonResponse(url = MODEL_ALLOWLIST_URL) + Log.d(TAG, "Loading model allowlist from internet...") + val data = getJsonResponse(url = MODEL_ALLOWLIST_URL) + var modelAllowlist: ModelAllowlist? = data?.jsonObj + + if (modelAllowlist == null) { + Log.d(TAG, "Failed to load model allowlist from internet. Trying to load it from disk") + modelAllowlist = readModelAllowlistFromDisk() + } else { + Log.d(TAG, "Done: loading model allowlist from internet") + saveModelAllowlistToDisk(modelAllowlistContent = data?.textContent ?: "{}") + } if (modelAllowlist == null) { _uiState.update { uiState.value.copy(loadingModelAllowlistError = "Failed to load model list") } @@ -707,6 +719,40 @@ open class ModelManagerViewModel( } } + private fun saveModelAllowlistToDisk(modelAllowlistContent: String) { + try { + Log.d(TAG, "Saving model allowlist to disk...") + val file = File(externalFilesDir, MODEL_ALLOWLIST_FILENAME) + file.writeText(modelAllowlistContent) + Log.d(TAG, "Done: saving model allowlist to disk.") + } catch (e: Exception) { + Log.e(TAG, "failed to write model allowlist to disk", e) + } + } + + private fun readModelAllowlistFromDisk(): ModelAllowlist? { + try { + Log.d(TAG, "Reading model allowlist from disk...") + val file = File(externalFilesDir, MODEL_ALLOWLIST_FILENAME) + if (file.exists()) { + val content = file.readText() + Log.d(TAG, "Model allowlist content from local file: $content") + val json = Json { + // Handle potential extra fields + ignoreUnknownKeys = true + allowComments = true + allowTrailingComma = true + } + return json.decodeFromString(content) + } + } catch (e: Exception) { + Log.e(TAG, "failed to read model allowlist from disk", e) + return null + } + + return null + } + private fun isModelPartiallyDownloaded(model: Model): Boolean { return inProgressWorkInfos.find { it.modelName == model.name } != null }