Save the loaded model allowlist to a local file so that it can be read when the allowlist cannot be loaded from internet.

Also improve the image clipping transitioning from full image back to the image thumbnail in chat ui.
This commit is contained in:
Jing Jin 2025-05-18 17:26:31 -07:00
parent 3495b7cf6e
commit e7dda4b4ad
4 changed files with 70 additions and 18 deletions

View file

@ -66,6 +66,10 @@ interface LatencyProvider {
private const val START_THINKING = "***Thinking...***" private const val START_THINKING = "***Thinking...***"
private const val DONE_THINKING = "***Done thinking***" private const val DONE_THINKING = "***Done thinking***"
data class JsonObjAndTextContent<T>(
val jsonObj: T, val textContent: String,
)
/** Format the bytes into a human-readable format. */ /** Format the bytes into a human-readable format. */
fun Long.humanReadableSize(si: Boolean = true, extraDecimalForGbAndAbove: Boolean = false): String { fun Long.humanReadableSize(si: Boolean = true, extraDecimalForGbAndAbove: Boolean = false): String {
val bytes = this val bytes = this
@ -473,16 +477,14 @@ fun processTasks() {
fun processLlmResponse(response: String): String { fun processLlmResponse(response: String): String {
// Add "thinking" and "done thinking" around the thinking content. // Add "thinking" and "done thinking" around the thinking content.
var newContent = response var newContent =
.replace("<think>", "$START_THINKING\n") response.replace("<think>", "$START_THINKING\n").replace("</think>", "\n$DONE_THINKING")
.replace("</think>", "\n$DONE_THINKING")
// Remove empty thinking content. // Remove empty thinking content.
val endThinkingIndex = newContent.indexOf(DONE_THINKING) val endThinkingIndex = newContent.indexOf(DONE_THINKING)
if (endThinkingIndex >= 0) { if (endThinkingIndex >= 0) {
val thinkingContent = val thinkingContent =
newContent.substring(0, endThinkingIndex + DONE_THINKING.length) newContent.substring(0, endThinkingIndex + DONE_THINKING.length).replace(START_THINKING, "")
.replace(START_THINKING, "")
.replace(DONE_THINKING, "") .replace(DONE_THINKING, "")
if (thinkingContent.isBlank()) { if (thinkingContent.isBlank()) {
newContent = newContent.substring(endThinkingIndex + DONE_THINKING.length) newContent = newContent.substring(endThinkingIndex + DONE_THINKING.length)
@ -495,7 +497,7 @@ fun processLlmResponse(response: String): String {
} }
@OptIn(ExperimentalSerializationApi::class) @OptIn(ExperimentalSerializationApi::class)
inline fun <reified T> getJsonResponse(url: String): T? { inline fun <reified T> getJsonResponse(url: String): JsonObjAndTextContent<T>? {
try { try {
val connection = URL(url).openConnection() as HttpURLConnection val connection = URL(url).openConnection() as HttpURLConnection
connection.requestMethod = "GET" connection.requestMethod = "GET"
@ -514,14 +516,13 @@ inline fun <reified T> getJsonResponse(url: String): T? {
allowTrailingComma = true allowTrailingComma = true
} }
val jsonObj = json.decodeFromString<T>(response) val jsonObj = json.decodeFromString<T>(response)
return jsonObj return JsonObjAndTextContent(jsonObj = jsonObj, textContent = response)
} else { } else {
Log.e("AGUtils", "HTTP error: $responseCode") Log.e("AGUtils", "HTTP error: $responseCode")
} }
} catch (e: Exception) { } catch (e: Exception) {
Log.e( Log.e(
"AGUtils", "AGUtils", "Error when getting json response: ${e.message}"
"Error when getting json response: ${e.message}"
) )
e.printStackTrace() e.printStackTrace()
} }

View file

@ -342,7 +342,12 @@ fun ChatPanel(
} }
.sharedElement( .sharedElement(
sharedContentState = rememberSharedContentState(key = "selected_image"), sharedContentState = rememberSharedContentState(key = "selected_image"),
animatedVisibilityScope = this@AnimatedContent animatedVisibilityScope = this@AnimatedContent,
clipInOverlayDuringTransition = OverlayClip(
MessageBubbleShape(
radius = bubbleBorderRadius
)
)
), ),
) )
} }
@ -378,7 +383,7 @@ fun ChatPanel(
) { ) {
LatencyText(message = message) LatencyText(message = message)
// A button to show stats for the LLM message. // A button to show stats for the LLM message.
if ((task.type == TaskType.LLM_CHAT || task.type == TaskType.LLM_ASK_IMAGE) && message is ChatMessageText if (task.type.id.startsWith("llm_") && message is ChatMessageText
// This means we only want to show the action button when the message is done // This means we only want to show the action button when the message is done
// generating, at which point the latency will be set. // generating, at which point the latency will be set.
&& message.latencyMs >= 0 && message.latencyMs >= 0
@ -565,7 +570,7 @@ fun ChatPanel(
) )
.sharedElement( .sharedElement(
sharedContentState = rememberSharedContentState(key = "selected_image"), sharedContentState = rememberSharedContentState(key = "selected_image"),
animatedVisibilityScope = this@AnimatedContent animatedVisibilityScope = this@AnimatedContent,
), ),
contentScale = ContentScale.Fit, contentScale = ContentScale.Fit,
) )

View file

@ -51,12 +51,12 @@ fun NewReleaseNotification() {
val info = getJsonResponse<ReleaseInfo>("https://api.github.com/repos/$REPO/releases/latest") val info = getJsonResponse<ReleaseInfo>("https://api.github.com/repos/$REPO/releases/latest")
if (info != null) { if (info != null) {
val curRelease = BuildConfig.VERSION_NAME val curRelease = BuildConfig.VERSION_NAME
val newRelease = info.tag_name val newRelease = info.jsonObj.tag_name
val isNewer = isNewerRelease(currentRelease = curRelease, newRelease = newRelease) val isNewer = isNewerRelease(currentRelease = curRelease, newRelease = newRelease)
Log.d(TAG, "curRelease: $curRelease, newRelease: $newRelease, isNewer: $isNewer") Log.d(TAG, "curRelease: $curRelease, newRelease: $newRelease, isNewer: $isNewer")
if (isNewer) { if (isNewer) {
newReleaseVersion = newRelease newReleaseVersion = newRelease
newReleaseUrl = info.html_url newReleaseUrl = info.jsonObj.html_url
} }
} }
} }

View file

@ -60,6 +60,8 @@ import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.update import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext import kotlinx.coroutines.withContext
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.json.Json
import net.openid.appauth.AuthorizationException import net.openid.appauth.AuthorizationException
import net.openid.appauth.AuthorizationRequest import net.openid.appauth.AuthorizationRequest
import net.openid.appauth.AuthorizationResponse import net.openid.appauth.AuthorizationResponse
@ -73,6 +75,7 @@ private const val TAG = "AGModelManagerViewModel"
private const val TEXT_INPUT_HISTORY_MAX_SIZE = 50 private const val TEXT_INPUT_HISTORY_MAX_SIZE = 50
private const val MODEL_ALLOWLIST_URL = private const val MODEL_ALLOWLIST_URL =
"https://raw.githubusercontent.com/google-ai-edge/gallery/refs/heads/main/model_allowlist.json" "https://raw.githubusercontent.com/google-ai-edge/gallery/refs/heads/main/model_allowlist.json"
private const val MODEL_ALLOWLIST_FILENAME = "model_allowlist.json"
data class ModelInitializationStatus( data class ModelInitializationStatus(
val status: ModelInitializationStatusType, var error: String = "" val status: ModelInitializationStatusType, var error: String = ""
@ -146,6 +149,7 @@ data class PagerScrollState(
* and cleaning up models. It also manages the UI state for model management, including the * and cleaning up models. It also manages the UI state for model management, including the
* list of tasks, models, download statuses, and initialization statuses. * list of tasks, models, download statuses, and initialization statuses.
*/ */
@OptIn(ExperimentalSerializationApi::class)
open class ModelManagerViewModel( open class ModelManagerViewModel(
private val downloadRepository: DownloadRepository, private val downloadRepository: DownloadRepository,
private val dataStoreRepository: DataStoreRepository, private val dataStoreRepository: DataStoreRepository,
@ -227,8 +231,7 @@ open class ModelManagerViewModel(
dataStoreRepository.saveImportedModels(importedModels = importedModels) dataStoreRepository.saveImportedModels(importedModels = importedModels)
} }
val newUiState = uiState.value.copy( val newUiState = uiState.value.copy(
modelDownloadStatus = curModelDownloadStatus, modelDownloadStatus = curModelDownloadStatus, tasks = uiState.value.tasks.toList()
tasks = uiState.value.tasks.toList()
) )
_uiState.update { newUiState } _uiState.update { newUiState }
} }
@ -660,8 +663,17 @@ open class ModelManagerViewModel(
viewModelScope.launch(Dispatchers.IO) { viewModelScope.launch(Dispatchers.IO) {
try { try {
// Load model allowlist json. // Load model allowlist json.
val modelAllowlist: ModelAllowlist? = Log.d(TAG, "Loading model allowlist from internet...")
getJsonResponse<ModelAllowlist>(url = MODEL_ALLOWLIST_URL) val data = getJsonResponse<ModelAllowlist>(url = MODEL_ALLOWLIST_URL)
var modelAllowlist: ModelAllowlist? = data?.jsonObj
if (modelAllowlist == null) {
Log.d(TAG, "Failed to load model allowlist from internet. Trying to load it from disk")
modelAllowlist = readModelAllowlistFromDisk()
} else {
Log.d(TAG, "Done: loading model allowlist from internet")
saveModelAllowlistToDisk(modelAllowlistContent = data?.textContent ?: "{}")
}
if (modelAllowlist == null) { if (modelAllowlist == null) {
_uiState.update { uiState.value.copy(loadingModelAllowlistError = "Failed to load model list") } _uiState.update { uiState.value.copy(loadingModelAllowlistError = "Failed to load model list") }
@ -707,6 +719,40 @@ open class ModelManagerViewModel(
} }
} }
private fun saveModelAllowlistToDisk(modelAllowlistContent: String) {
try {
Log.d(TAG, "Saving model allowlist to disk...")
val file = File(externalFilesDir, MODEL_ALLOWLIST_FILENAME)
file.writeText(modelAllowlistContent)
Log.d(TAG, "Done: saving model allowlist to disk.")
} catch (e: Exception) {
Log.e(TAG, "failed to write model allowlist to disk", e)
}
}
private fun readModelAllowlistFromDisk(): ModelAllowlist? {
try {
Log.d(TAG, "Reading model allowlist from disk...")
val file = File(externalFilesDir, MODEL_ALLOWLIST_FILENAME)
if (file.exists()) {
val content = file.readText()
Log.d(TAG, "Model allowlist content from local file: $content")
val json = Json {
// Handle potential extra fields
ignoreUnknownKeys = true
allowComments = true
allowTrailingComma = true
}
return json.decodeFromString<ModelAllowlist>(content)
}
} catch (e: Exception) {
Log.e(TAG, "failed to read model allowlist from disk", e)
return null
}
return null
}
private fun isModelPartiallyDownloaded(model: Model): Boolean { private fun isModelPartiallyDownloaded(model: Model): Boolean {
return inProgressWorkInfos.find { it.modelName == model.name } != null return inProgressWorkInfos.find { it.modelName == model.name } != null
} }