mirror of
https://github.com/google-ai-edge/gallery.git
synced 2025-07-12 17:32:30 -04:00
Save the loaded model allowlist to a local file so that it can be read when the allowlist cannot be loaded from internet.
Also improve the image clipping transitioning from full image back to the image thumbnail in chat ui.
This commit is contained in:
parent
3495b7cf6e
commit
e7dda4b4ad
4 changed files with 70 additions and 18 deletions
|
@ -66,6 +66,10 @@ interface LatencyProvider {
|
|||
private const val START_THINKING = "***Thinking...***"
|
||||
private const val DONE_THINKING = "***Done thinking***"
|
||||
|
||||
data class JsonObjAndTextContent<T>(
|
||||
val jsonObj: T, val textContent: String,
|
||||
)
|
||||
|
||||
/** Format the bytes into a human-readable format. */
|
||||
fun Long.humanReadableSize(si: Boolean = true, extraDecimalForGbAndAbove: Boolean = false): String {
|
||||
val bytes = this
|
||||
|
@ -473,16 +477,14 @@ fun processTasks() {
|
|||
|
||||
fun processLlmResponse(response: String): String {
|
||||
// Add "thinking" and "done thinking" around the thinking content.
|
||||
var newContent = response
|
||||
.replace("<think>", "$START_THINKING\n")
|
||||
.replace("</think>", "\n$DONE_THINKING")
|
||||
var newContent =
|
||||
response.replace("<think>", "$START_THINKING\n").replace("</think>", "\n$DONE_THINKING")
|
||||
|
||||
// Remove empty thinking content.
|
||||
val endThinkingIndex = newContent.indexOf(DONE_THINKING)
|
||||
if (endThinkingIndex >= 0) {
|
||||
val thinkingContent =
|
||||
newContent.substring(0, endThinkingIndex + DONE_THINKING.length)
|
||||
.replace(START_THINKING, "")
|
||||
newContent.substring(0, endThinkingIndex + DONE_THINKING.length).replace(START_THINKING, "")
|
||||
.replace(DONE_THINKING, "")
|
||||
if (thinkingContent.isBlank()) {
|
||||
newContent = newContent.substring(endThinkingIndex + DONE_THINKING.length)
|
||||
|
@ -495,7 +497,7 @@ fun processLlmResponse(response: String): String {
|
|||
}
|
||||
|
||||
@OptIn(ExperimentalSerializationApi::class)
|
||||
inline fun <reified T> getJsonResponse(url: String): T? {
|
||||
inline fun <reified T> getJsonResponse(url: String): JsonObjAndTextContent<T>? {
|
||||
try {
|
||||
val connection = URL(url).openConnection() as HttpURLConnection
|
||||
connection.requestMethod = "GET"
|
||||
|
@ -514,14 +516,13 @@ inline fun <reified T> getJsonResponse(url: String): T? {
|
|||
allowTrailingComma = true
|
||||
}
|
||||
val jsonObj = json.decodeFromString<T>(response)
|
||||
return jsonObj
|
||||
return JsonObjAndTextContent(jsonObj = jsonObj, textContent = response)
|
||||
} else {
|
||||
Log.e("AGUtils", "HTTP error: $responseCode")
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.e(
|
||||
"AGUtils",
|
||||
"Error when getting json response: ${e.message}"
|
||||
"AGUtils", "Error when getting json response: ${e.message}"
|
||||
)
|
||||
e.printStackTrace()
|
||||
}
|
||||
|
|
|
@ -342,7 +342,12 @@ fun ChatPanel(
|
|||
}
|
||||
.sharedElement(
|
||||
sharedContentState = rememberSharedContentState(key = "selected_image"),
|
||||
animatedVisibilityScope = this@AnimatedContent
|
||||
animatedVisibilityScope = this@AnimatedContent,
|
||||
clipInOverlayDuringTransition = OverlayClip(
|
||||
MessageBubbleShape(
|
||||
radius = bubbleBorderRadius
|
||||
)
|
||||
)
|
||||
),
|
||||
)
|
||||
}
|
||||
|
@ -378,7 +383,7 @@ fun ChatPanel(
|
|||
) {
|
||||
LatencyText(message = message)
|
||||
// A button to show stats for the LLM message.
|
||||
if ((task.type == TaskType.LLM_CHAT || task.type == TaskType.LLM_ASK_IMAGE) && message is ChatMessageText
|
||||
if (task.type.id.startsWith("llm_") && message is ChatMessageText
|
||||
// This means we only want to show the action button when the message is done
|
||||
// generating, at which point the latency will be set.
|
||||
&& message.latencyMs >= 0
|
||||
|
@ -565,7 +570,7 @@ fun ChatPanel(
|
|||
)
|
||||
.sharedElement(
|
||||
sharedContentState = rememberSharedContentState(key = "selected_image"),
|
||||
animatedVisibilityScope = this@AnimatedContent
|
||||
animatedVisibilityScope = this@AnimatedContent,
|
||||
),
|
||||
contentScale = ContentScale.Fit,
|
||||
)
|
||||
|
|
|
@ -51,12 +51,12 @@ fun NewReleaseNotification() {
|
|||
val info = getJsonResponse<ReleaseInfo>("https://api.github.com/repos/$REPO/releases/latest")
|
||||
if (info != null) {
|
||||
val curRelease = BuildConfig.VERSION_NAME
|
||||
val newRelease = info.tag_name
|
||||
val newRelease = info.jsonObj.tag_name
|
||||
val isNewer = isNewerRelease(currentRelease = curRelease, newRelease = newRelease)
|
||||
Log.d(TAG, "curRelease: $curRelease, newRelease: $newRelease, isNewer: $isNewer")
|
||||
if (isNewer) {
|
||||
newReleaseVersion = newRelease
|
||||
newReleaseUrl = info.html_url
|
||||
newReleaseUrl = info.jsonObj.html_url
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -60,6 +60,8 @@ import kotlinx.coroutines.flow.asStateFlow
|
|||
import kotlinx.coroutines.flow.update
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.withContext
|
||||
import kotlinx.serialization.ExperimentalSerializationApi
|
||||
import kotlinx.serialization.json.Json
|
||||
import net.openid.appauth.AuthorizationException
|
||||
import net.openid.appauth.AuthorizationRequest
|
||||
import net.openid.appauth.AuthorizationResponse
|
||||
|
@ -73,6 +75,7 @@ private const val TAG = "AGModelManagerViewModel"
|
|||
private const val TEXT_INPUT_HISTORY_MAX_SIZE = 50
|
||||
private const val MODEL_ALLOWLIST_URL =
|
||||
"https://raw.githubusercontent.com/google-ai-edge/gallery/refs/heads/main/model_allowlist.json"
|
||||
private const val MODEL_ALLOWLIST_FILENAME = "model_allowlist.json"
|
||||
|
||||
data class ModelInitializationStatus(
|
||||
val status: ModelInitializationStatusType, var error: String = ""
|
||||
|
@ -146,6 +149,7 @@ data class PagerScrollState(
|
|||
* and cleaning up models. It also manages the UI state for model management, including the
|
||||
* list of tasks, models, download statuses, and initialization statuses.
|
||||
*/
|
||||
@OptIn(ExperimentalSerializationApi::class)
|
||||
open class ModelManagerViewModel(
|
||||
private val downloadRepository: DownloadRepository,
|
||||
private val dataStoreRepository: DataStoreRepository,
|
||||
|
@ -227,8 +231,7 @@ open class ModelManagerViewModel(
|
|||
dataStoreRepository.saveImportedModels(importedModels = importedModels)
|
||||
}
|
||||
val newUiState = uiState.value.copy(
|
||||
modelDownloadStatus = curModelDownloadStatus,
|
||||
tasks = uiState.value.tasks.toList()
|
||||
modelDownloadStatus = curModelDownloadStatus, tasks = uiState.value.tasks.toList()
|
||||
)
|
||||
_uiState.update { newUiState }
|
||||
}
|
||||
|
@ -660,8 +663,17 @@ open class ModelManagerViewModel(
|
|||
viewModelScope.launch(Dispatchers.IO) {
|
||||
try {
|
||||
// Load model allowlist json.
|
||||
val modelAllowlist: ModelAllowlist? =
|
||||
getJsonResponse<ModelAllowlist>(url = MODEL_ALLOWLIST_URL)
|
||||
Log.d(TAG, "Loading model allowlist from internet...")
|
||||
val data = getJsonResponse<ModelAllowlist>(url = MODEL_ALLOWLIST_URL)
|
||||
var modelAllowlist: ModelAllowlist? = data?.jsonObj
|
||||
|
||||
if (modelAllowlist == null) {
|
||||
Log.d(TAG, "Failed to load model allowlist from internet. Trying to load it from disk")
|
||||
modelAllowlist = readModelAllowlistFromDisk()
|
||||
} else {
|
||||
Log.d(TAG, "Done: loading model allowlist from internet")
|
||||
saveModelAllowlistToDisk(modelAllowlistContent = data?.textContent ?: "{}")
|
||||
}
|
||||
|
||||
if (modelAllowlist == null) {
|
||||
_uiState.update { uiState.value.copy(loadingModelAllowlistError = "Failed to load model list") }
|
||||
|
@ -707,6 +719,40 @@ open class ModelManagerViewModel(
|
|||
}
|
||||
}
|
||||
|
||||
private fun saveModelAllowlistToDisk(modelAllowlistContent: String) {
|
||||
try {
|
||||
Log.d(TAG, "Saving model allowlist to disk...")
|
||||
val file = File(externalFilesDir, MODEL_ALLOWLIST_FILENAME)
|
||||
file.writeText(modelAllowlistContent)
|
||||
Log.d(TAG, "Done: saving model allowlist to disk.")
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "failed to write model allowlist to disk", e)
|
||||
}
|
||||
}
|
||||
|
||||
private fun readModelAllowlistFromDisk(): ModelAllowlist? {
|
||||
try {
|
||||
Log.d(TAG, "Reading model allowlist from disk...")
|
||||
val file = File(externalFilesDir, MODEL_ALLOWLIST_FILENAME)
|
||||
if (file.exists()) {
|
||||
val content = file.readText()
|
||||
Log.d(TAG, "Model allowlist content from local file: $content")
|
||||
val json = Json {
|
||||
// Handle potential extra fields
|
||||
ignoreUnknownKeys = true
|
||||
allowComments = true
|
||||
allowTrailingComma = true
|
||||
}
|
||||
return json.decodeFromString<ModelAllowlist>(content)
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "failed to read model allowlist from disk", e)
|
||||
return null
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
private fun isModelPartiallyDownloaded(model: Model): Boolean {
|
||||
return inProgressWorkInfos.find { it.modelName == model.name } != null
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue