From 3c8d6ae14d1b3be7f6ba027973ed8946474a1558 Mon Sep 17 00:00:00 2001 From: Spica910 <68777937+Spica910@users.noreply.github.com> Date: Mon, 2 Jun 2025 17:47:56 +0900 Subject: [PATCH 1/2] Create android.yml --- .github/workflows/android.yml | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 .github/workflows/android.yml diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml new file mode 100644 index 0000000..2f633b6 --- /dev/null +++ b/.github/workflows/android.yml @@ -0,0 +1,26 @@ +name: Android CI + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: set up JDK 11 + uses: actions/setup-java@v4 + with: + java-version: '11' + distribution: 'temurin' + cache: gradle + + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew build From 010601f4ad2e7951ee0cca6a1d3ad7532dc65efb Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Mon, 2 Jun 2025 09:46:28 +0000 Subject: [PATCH 2/2] feat: Implement personal offline chat app features (Phase 1 & 2 initiation) This commit introduces a comprehensive set of features to transform the AI Edge Gallery into a personalized offline chat application. Phase 1: Core Offline Chat Functionality - Data Structures: Defined UserProfile, Persona, ChatMessage, Conversation, and UserDocument to model application data. - DataStoreRepository: Enhanced to manage persistence for all new data models, including encryption for UserProfile and local storage for conversations, personas, and user documents. Default personas are now also localized. - UI for Personal Information: Added a screen for you to input and edit your CV/resume details (name, summary, skills, experience). - Feature Removal: Streamlined the app by removing the "Ask Image" and "Prompt Lab" features to focus on chat. - UI for Persona Management: Implemented UI for creating, editing, deleting, and selecting an active persona to guide AI responses. - Core Chat Logic & UI: - Refactored LlmChatViewModel and LlmChatScreen. - Supports starting new conversations with an optional custom system prompt. - Integrates active persona and user profile summary into LLM context. - Manages conversations (saving messages, title, timestamps, model used, persona used). - Conversation History UI: Added a screen to view, open, and delete past conversations. - Localization: Implemented localization for English and Korean for all new user-facing strings and default personas. Phase 2: Document Handling (Started) - UserDocument data class defined for managing imported files. - DataStoreRepository updated to support CRUD operations for UserDocuments. The application now provides a personalized chat experience with features for managing user identity, AI personas, and conversation history, all designed for offline use. Further document handling, monetization, and cloud sync features are planned for subsequent phases. --- .../com/google/ai/edge/gallery/GalleryApp.kt | 9 +- .../ai/edge/gallery/data/AppContainer.kt | 2 +- .../edge/gallery/data/DataStoreRepository.kt | 281 +++++++++- .../ai/edge/gallery/data/PersonalAppModels.kt | 110 ++++ .../com/google/ai/edge/gallery/data/Tasks.kt | 34 +- .../ai/edge/gallery/ui/ViewModelProvider.kt | 40 +- .../gallery/ui/common/LocalAppContainer.kt | 6 + .../ConversationHistoryScreen.kt | 146 ++++++ .../ConversationHistoryViewModel.kt | 33 ++ .../ai/edge/gallery/ui/home/HomeScreen.kt | 12 +- .../ai/edge/gallery/ui/home/SettingsDialog.kt | 45 ++ .../gallery/ui/llmchat/LlmChatModelHelper.kt | 11 + .../edge/gallery/ui/llmchat/LlmChatScreen.kt | 242 +++++---- .../gallery/ui/llmchat/LlmChatViewModel.kt | 436 +++++++++------- .../ui/llmsingleturn/LlmSingleTurnScreen.kt | 218 -------- .../llmsingleturn/LlmSingleTurnViewModel.kt | 221 -------- .../ui/llmsingleturn/PromptTemplateConfigs.kt | 185 ------- .../ui/llmsingleturn/PromptTemplatesPanel.kt | 491 ------------------ .../gallery/ui/llmsingleturn/ResponsePanel.kt | 262 ---------- .../ui/llmsingleturn/SingleSelectButton.kt | 90 ---- .../ui/llmsingleturn/VerticalSplitView.kt | 133 ----- .../gallery/ui/navigation/GalleryNavGraph.kt | 173 ++++-- .../ui/persona/PersonaManagementScreen.kt | 197 +++++++ .../gallery/ui/persona/PersonaViewModel.kt | 68 +++ .../ui/userprofile/UserProfileScreen.kt | 177 +++++++ .../ui/userprofile/UserProfileViewModel.kt | 87 ++++ .../app/src/main/res/values-ko/strings.xml | 72 +++ .../src/app/src/main/res/values/strings.xml | 80 ++- 28 files changed, 1894 insertions(+), 1967 deletions(-) create mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/data/PersonalAppModels.kt create mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/LocalAppContainer.kt create mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/conversationhistory/ConversationHistoryScreen.kt create mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/conversationhistory/ConversationHistoryViewModel.kt delete mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt delete mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt delete mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/PromptTemplateConfigs.kt delete mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/PromptTemplatesPanel.kt delete mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/ResponsePanel.kt delete mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/SingleSelectButton.kt delete mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/VerticalSplitView.kt create mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/persona/PersonaManagementScreen.kt create mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/persona/PersonaViewModel.kt create mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/userprofile/UserProfileScreen.kt create mode 100644 Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/userprofile/UserProfileViewModel.kt create mode 100644 Android/src/app/src/main/res/values-ko/strings.xml diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/GalleryApp.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/GalleryApp.kt index 19f53a0..e3b34c3 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/GalleryApp.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/GalleryApp.kt @@ -47,10 +47,14 @@ import androidx.compose.ui.res.stringResource import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.sp +import androidx.compose.runtime.CompositionLocalProvider +import androidx.compose.ui.platform.LocalContext import androidx.navigation.NavHostController import androidx.navigation.compose.rememberNavController +import com.google.ai.edge.gallery.data.AppContainer // Needed for type import com.google.ai.edge.gallery.data.AppBarAction import com.google.ai.edge.gallery.data.AppBarActionType +import com.google.ai.edge.gallery.ui.common.LocalAppContainer import com.google.ai.edge.gallery.ui.navigation.GalleryNavHost /** @@ -58,7 +62,10 @@ import com.google.ai.edge.gallery.ui.navigation.GalleryNavHost */ @Composable fun GalleryApp(navController: NavHostController = rememberNavController()) { - GalleryNavHost(navController = navController) + val appContainer = (LocalContext.current.applicationContext as GalleryApplication).container + CompositionLocalProvider(LocalAppContainer provides appContainer) { + GalleryNavHost(navController = navController) + } } /** diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/AppContainer.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/AppContainer.kt index 57a905d..bb2375f 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/AppContainer.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/AppContainer.kt @@ -42,6 +42,6 @@ interface AppContainer { class DefaultAppContainer(ctx: Context, dataStore: DataStore) : AppContainer { override val context = ctx override val lifecycleProvider = GalleryLifecycleProvider() - override val dataStoreRepository = DefaultDataStoreRepository(dataStore) + override val dataStoreRepository = DefaultDataStoreRepository(dataStore, ctx) // Pass context here override val downloadRepository = DefaultDownloadRepository(ctx, lifecycleProvider) } \ No newline at end of file diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/DataStoreRepository.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/DataStoreRepository.kt index 82749e9..32f0024 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/DataStoreRepository.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/DataStoreRepository.kt @@ -16,9 +16,11 @@ package com.google.ai.edge.gallery.data +import android.content.Context // Added import import android.security.keystore.KeyGenParameterSpec import android.security.keystore.KeyProperties import android.util.Base64 +import com.google.ai.edge.gallery.R // Added import for R class import androidx.datastore.core.DataStore import androidx.datastore.preferences.core.Preferences import androidx.datastore.preferences.core.edit @@ -27,9 +29,13 @@ import androidx.datastore.preferences.core.stringPreferencesKey import com.google.gson.Gson import com.google.gson.reflect.TypeToken import com.google.ai.edge.gallery.ui.theme.THEME_AUTO +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.distinctUntilChanged import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.map import kotlinx.coroutines.runBlocking import java.security.KeyStore +import java.util.UUID import javax.crypto.Cipher import javax.crypto.KeyGenerator import javax.crypto.SecretKey @@ -50,6 +56,30 @@ interface DataStoreRepository { fun readAccessTokenData(): AccessTokenData? fun saveImportedModels(importedModels: List) fun readImportedModels(): List + + fun saveUserProfile(userProfile: UserProfile) + fun readUserProfile(): UserProfile? + fun savePersonas(personas: List) + fun readPersonas(): List + fun addPersona(persona: Persona) + fun updatePersona(persona: Persona) + fun deletePersona(personaId: String) + fun saveConversations(conversations: List) + fun readConversations(): List + fun getConversationById(conversationId: String): Conversation? + fun addConversation(conversation: Conversation) + fun updateConversation(conversation: Conversation) + fun deleteConversation(conversationId: String) + + fun saveActivePersonaId(personaId: String?) + fun readActivePersonaId(): Flow + + fun saveUserDocuments(documents: List) + fun readUserDocuments(): Flow> + fun addUserDocument(document: UserDocument) + fun updateUserDocument(document: UserDocument) + fun deleteUserDocument(documentId: String) + fun getUserDocumentById(documentId: String): Flow } /** @@ -62,7 +92,8 @@ interface DataStoreRepository { * DataStore is used to persist data as JSON strings under specified keys. */ class DefaultDataStoreRepository( - private val dataStore: DataStore + private val dataStore: DataStore, + private val context: Context // Added context ) : DataStoreRepository { @@ -85,6 +116,13 @@ class DefaultDataStoreRepository( // Data for all imported models. val IMPORTED_MODELS = stringPreferencesKey("imported_models") + + val ENCRYPTED_USER_PROFILE = stringPreferencesKey("encrypted_user_profile") + val USER_PROFILE_IV = stringPreferencesKey("user_profile_iv") + val PERSONAS_LIST = stringPreferencesKey("personas_list") + val CONVERSATIONS_LIST = stringPreferencesKey("conversations_list") + val ACTIVE_PERSONA_ID = stringPreferencesKey("active_persona_id") + val USER_DOCUMENTS_LIST = stringPreferencesKey("user_documents_list") } private val keystoreAlias: String = "com_google_aiedge_gallery_access_token_key" @@ -189,7 +227,8 @@ class DefaultDataStoreRepository( val infosStr = preferences[PreferencesKeys.IMPORTED_MODELS] ?: "[]" val gson = Gson() val listType = object : TypeToken>() {}.type - gson.fromJson(infosStr, listType) + // Ensure to return emptyList() if fromJson returns null + return gson.fromJson(infosStr, listType) ?: emptyList() } } @@ -197,7 +236,8 @@ class DefaultDataStoreRepository( val infosStr = preferences[PreferencesKeys.TEXT_INPUT_HISTORY] ?: "[]" val gson = Gson() val listType = object : TypeToken>() {}.type - return gson.fromJson(infosStr, listType) + // Ensure to return emptyList() if fromJson returns null + return gson.fromJson(infosStr, listType) ?: emptyList() } private fun getOrCreateSecretKey(): SecretKey { @@ -243,4 +283,239 @@ class DefaultDataStoreRepository( null } } + + override fun saveUserProfile(userProfile: UserProfile) { + runBlocking { + val gson = Gson() + val jsonString = gson.toJson(userProfile) + val (encryptedProfile, iv) = encrypt(jsonString) + dataStore.edit { preferences -> + preferences[PreferencesKeys.ENCRYPTED_USER_PROFILE] = encryptedProfile + preferences[PreferencesKeys.USER_PROFILE_IV] = iv + } + } + } + + override fun readUserProfile(): UserProfile? { + return runBlocking { + val preferences = dataStore.data.first() + val encryptedProfile = preferences[PreferencesKeys.ENCRYPTED_USER_PROFILE] + val iv = preferences[PreferencesKeys.USER_PROFILE_IV] + + if (encryptedProfile != null && iv != null) { + try { + val decryptedJson = decrypt(encryptedProfile, iv) + if (decryptedJson != null) { + Gson().fromJson(decryptedJson, UserProfile::class.java) + } else { + UserProfile() // Return default if decryption fails + } + } catch (e: Exception) { + UserProfile() // Return default on error + } + } else { + UserProfile() // Return default if not found + } + } + } + + override fun savePersonas(personas: List) { + runBlocking { + dataStore.edit { preferences -> + val gson = Gson() + val jsonString = gson.toJson(personas) + preferences[PreferencesKeys.PERSONAS_LIST] = jsonString + } + } + } + + override fun readPersonas(): List { + return runBlocking { + val preferences = dataStore.data.first() + val jsonString = preferences[PreferencesKeys.PERSONAS_LIST] + val gson = Gson() + val listType = object : TypeToken>() {}.type + var personas: List = if (jsonString != null) { + try { + gson.fromJson(jsonString, listType) ?: emptyList() + } catch (e: Exception) { + emptyList() // Return empty list on deserialization error + } + } else { + emptyList() + } + + if (personas.isEmpty()) { + personas = listOf( + Persona( + id = UUID.randomUUID().toString(), + name = context.getString(R.string.persona_add_edit_dialog_name_default_assist), + prompt = context.getString(R.string.persona_add_edit_dialog_prompt_default_assist), + isDefault = true + ), + Persona( + id = UUID.randomUUID().toString(), + name = context.getString(R.string.persona_add_edit_dialog_name_default_creative), + prompt = context.getString(R.string.persona_add_edit_dialog_prompt_default_creative), + isDefault = true + ) + ) + // Save these default personas back to DataStore + savePersonas(personas) + } + personas + } + } + + override fun addPersona(persona: Persona) { + val currentPersonas = readPersonas().toMutableList() + currentPersonas.add(persona) + savePersonas(currentPersonas) + } + + override fun updatePersona(persona: Persona) { + val currentPersonas = readPersonas().toMutableList() + val index = currentPersonas.indexOfFirst { it.id == persona.id } + if (index != -1) { + currentPersonas[index] = persona + savePersonas(currentPersonas) + } + } + + override fun deletePersona(personaId: String) { + val currentPersonas = readPersonas().toMutableList() + currentPersonas.removeAll { it.id == personaId } + savePersonas(currentPersonas) + } + + override fun saveConversations(conversations: List) { + runBlocking { + dataStore.edit { preferences -> + val gson = Gson() + val jsonString = gson.toJson(conversations) + preferences[PreferencesKeys.CONVERSATIONS_LIST] = jsonString + } + } + } + + override fun readConversations(): List { + return runBlocking { + val preferences = dataStore.data.first() + val jsonString = preferences[PreferencesKeys.CONVERSATIONS_LIST] + val gson = Gson() + val listType = object : TypeToken>() {}.type + if (jsonString != null) { + try { + gson.fromJson(jsonString, listType) ?: emptyList() + } catch (e: Exception) { + emptyList() // Return empty list on deserialization error + } + } else { + emptyList() + } + } + } + + override fun getConversationById(conversationId: String): Conversation? { + return readConversations().firstOrNull { it.id == conversationId } + } + + override fun addConversation(conversation: Conversation) { + val currentConversations = readConversations().toMutableList() + currentConversations.add(conversation) + saveConversations(currentConversations) + } + + override fun updateConversation(conversation: Conversation) { + val currentConversations = readConversations().toMutableList() + val index = currentConversations.indexOfFirst { it.id == conversation.id } + if (index != -1) { + currentConversations[index] = conversation + saveConversations(currentConversations) + } + } + + override fun deleteConversation(conversationId: String) { + val currentConversations = readConversations().toMutableList() + currentConversations.removeAll { it.id == conversationId } + saveConversations(currentConversations) + } + + override fun saveActivePersonaId(personaId: String?) { + runBlocking { + dataStore.edit { preferences -> + if (personaId == null) { + preferences.remove(PreferencesKeys.ACTIVE_PERSONA_ID) + } else { + preferences[PreferencesKeys.ACTIVE_PERSONA_ID] = personaId + } + } + } + } + + override fun readActivePersonaId(): Flow { + return dataStore.data.map { preferences -> + preferences[PreferencesKeys.ACTIVE_PERSONA_ID] + }.distinctUntilChanged() + } + + override fun saveUserDocuments(documents: List) { + runBlocking { + dataStore.edit { preferences -> + val gson = Gson() + val jsonString = gson.toJson(documents) + preferences[PreferencesKeys.USER_DOCUMENTS_LIST] = jsonString + } + } + } + + override fun readUserDocuments(): Flow> { + return dataStore.data.map { preferences -> + val jsonString = preferences[PreferencesKeys.USER_DOCUMENTS_LIST] + if (jsonString != null) { + val gson = Gson() + val type = object : TypeToken>() {}.type + gson.fromJson(jsonString, type) ?: emptyList() + } else { + emptyList() + } + }.distinctUntilChanged() + } + + override fun addUserDocument(document: UserDocument) { + runBlocking { // Consider making these suspend functions if runBlocking becomes an issue + val currentDocuments = readUserDocuments().first().toMutableList() + currentDocuments.removeAll { it.id == document.id } // Remove if already exists by ID, then add + currentDocuments.add(document) + saveUserDocuments(currentDocuments) + } + } + + override fun updateUserDocument(document: UserDocument) { + runBlocking { + val currentDocuments = readUserDocuments().first().toMutableList() + val index = currentDocuments.indexOfFirst { it.id == document.id } + if (index != -1) { + currentDocuments[index] = document + saveUserDocuments(currentDocuments) + } else { + // Optionally add if not found, or log an error + addUserDocument(document) // Or handle error: document to update not found + } + } + } + + override fun deleteUserDocument(documentId: String) { + runBlocking { + val currentDocuments = readUserDocuments().first().toMutableList() + currentDocuments.removeAll { it.id == documentId } + saveUserDocuments(currentDocuments) + } + } + + override fun getUserDocumentById(documentId: String): Flow { + return readUserDocuments().map { documents -> + documents.find { it.id == documentId } + } + } } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/PersonalAppModels.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/PersonalAppModels.kt new file mode 100644 index 0000000..ca0b7af --- /dev/null +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/PersonalAppModels.kt @@ -0,0 +1,110 @@ +package com.google.ai.edge.gallery.data + +import java.util.UUID + +/** + * Represents a user's profile information. + * + * @property name The name of the user. + * @property summary A brief summary or bio of the user. + * @property skills A list of the user's skills. + * @property experience A list of strings, where each string can represent a job or project description. + */ +data class UserProfile( + val name: String? = null, + val summary: String? = null, + val skills: List = emptyList(), + val experience: List = emptyList() +) + +/** + * Represents an AI persona that can be used in conversations. + * + * @property id A unique identifier for the persona (e.g., a UUID). + * @property name The name of the persona. + * @property prompt The system prompt associated with this persona, defining its behavior and responses. + * @property isDefault Indicates if this is a default persona. + */ +data class Persona( + val id: String = UUID.randomUUID().toString(), + val name: String, + val prompt: String, + val isDefault: Boolean = false +) + +/** + * Defines the role of the sender of a chat message. + */ +enum class ChatMessageRole { + /** The message is from the end-user. */ + USER, + /** The message is from the AI assistant. */ + ASSISTANT, + /** The message is a system instruction or context. */ + SYSTEM +} + +/** + * Represents a single message within a chat conversation. + * + * @property id A unique identifier for the message (e.g., a UUID). + * @property conversationId The ID of the conversation this message belongs to. + * @property timestamp The time the message was created, in epoch milliseconds. + * @property role The role of the message sender (user, assistant, or system). + * @property content The textual content of the message. + * @property personaUsedId The ID of the Persona active when this message was generated or sent, if applicable. + */ +data class ChatMessage( + val id: String = UUID.randomUUID().toString(), + val conversationId: String, + val timestamp: Long, + val role: ChatMessageRole, + val content: String, + val personaUsedId: String? = null +) + +/** + * Represents a chat conversation. + * + * @property id A unique identifier for the conversation (e.g., a UUID). + * @property title An optional user-defined title for the conversation. + * @property creationTimestamp The time the conversation was created, in epoch milliseconds. + * @property lastModifiedTimestamp The time the conversation was last modified, in epoch milliseconds. + * @property initialSystemPrompt A custom system prompt for this specific conversation, which might override a Persona's default prompt. + * @property messages A list of chat messages in this conversation. For local storage, embedding can work. For cloud sync, separate storage of messages linked by ID is better. + * @property activePersonaId The ID of the persona primarily used in this conversation. + * @property modelIdUsed The ID (e.g. name) of the AI model used in this conversation. + */ +data class Conversation( + val id: String = UUID.randomUUID().toString(), + var title: String? = null, + val creationTimestamp: Long, + var lastModifiedTimestamp: Long, + val initialSystemPrompt: String? = null, + val modelIdUsed: String? = null, // Add this + val messages: List = emptyList(), + val activePersonaId: String? = null +) + +/** + * Represents a document imported or managed by the user. + * + * @property id Unique identifier for the document (e.g., UUID). + * @property fileName Original name of the file. + * @property localPath Path to the locally stored copy of the document, if applicable. + * @property originalSource Indicates where the document came from (e.g., "local", a URL for Google Docs). + * @property fileType The MIME type or a simple extension string (e.g., "txt", "pdf", "docx"). + * @property extractedText The text content extracted from the document. Null if not yet extracted or not applicable. + * @property importTimestamp Timestamp when the document was imported. + * @property lastAccessedTimestamp Timestamp when the document was last used in a chat (optional). + */ +data class UserDocument( + val id: String = java.util.UUID.randomUUID().toString(), + val fileName: String, + val localPath: String? = null, + val originalSource: String, // e.g., "local", "google_drive_id:" + val fileType: String, // e.g., "text/plain", "application/pdf" + var extractedText: String? = null, + val importTimestamp: Long = System.currentTimeMillis(), + var lastAccessedTimestamp: Long? = null +) diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Tasks.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Tasks.kt index 41a2314..1a98c9c 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Tasks.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Tasks.kt @@ -33,9 +33,7 @@ enum class TaskType(val label: String, val id: String) { IMAGE_CLASSIFICATION(label = "Image Classification", id = "image_classification"), IMAGE_GENERATION(label = "Image Generation", id = "image_generation"), LLM_CHAT(label = "AI Chat", id = "llm_chat"), - LLM_PROMPT_LAB(label = "Prompt Lab", id = "llm_prompt_lab"), - LLM_ASK_IMAGE(label = "Ask Image", id = "llm_ask_image"), - + // LLM_PROMPT_LAB and LLM_ASK_IMAGE removed from enum TEST_TASK_1(label = "Test task 1", id = "test_task_1"), TEST_TASK_2(label = "Test task 2", id = "test_task_2") } @@ -100,25 +98,8 @@ val TASK_LLM_CHAT = Task( textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat ) -val TASK_LLM_PROMPT_LAB = Task( - type = TaskType.LLM_PROMPT_LAB, - icon = Icons.Outlined.Widgets, - models = mutableListOf(), - description = "Single turn use cases with on-device large language model", - docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", - sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt", - textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat -) - -val TASK_LLM_ASK_IMAGE = Task( - type = TaskType.LLM_ASK_IMAGE, - icon = Icons.Outlined.Mms, - models = mutableListOf(), - description = "Ask questions about images with on-device large language models", - docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", - sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt", - textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat -) +// TASK_LLM_PROMPT_LAB definition removed +// TASK_LLM_ASK_IMAGE definition removed val TASK_IMAGE_GENERATION = Task( type = TaskType.IMAGE_GENERATION, @@ -132,9 +113,12 @@ val TASK_IMAGE_GENERATION = Task( /** All tasks. */ val TASKS: List = listOf( - TASK_LLM_ASK_IMAGE, - TASK_LLM_PROMPT_LAB, - TASK_LLM_CHAT, + TASK_TEXT_CLASSIFICATION, + TASK_IMAGE_CLASSIFICATION, + TASK_IMAGE_GENERATION, + TASK_LLM_CHAT + // TASK_LLM_ASK_IMAGE removed + // TASK_LLM_PROMPT_LAB removed ) fun getModelByName(name: String): Model? { diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt index 8a257dd..dd9bc9a 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt @@ -22,13 +22,17 @@ import androidx.lifecycle.viewmodel.CreationExtras import androidx.lifecycle.viewmodel.initializer import androidx.lifecycle.viewmodel.viewModelFactory import com.google.ai.edge.gallery.GalleryApplication +import com.google.ai.edge.gallery.data.TASK_LLM_CHAT // Import TASK_LLM_CHAT import com.google.ai.edge.gallery.ui.imageclassification.ImageClassificationViewModel import com.google.ai.edge.gallery.ui.imagegeneration.ImageGenerationViewModel import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel -import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageViewModel -import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel +// import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageViewModel // Removed +// import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel // Removed import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.textclassification.TextClassificationViewModel +import com.google.ai.edge.gallery.ui.userprofile.UserProfileViewModel +import com.google.ai.edge.gallery.ui.persona.PersonaViewModel +import com.google.ai.edge.gallery.ui.conversationhistory.ConversationHistoryViewModel // Added import object ViewModelProvider { val Factory = viewModelFactory { @@ -55,23 +59,35 @@ object ViewModelProvider { // Initializer for LlmChatViewModel. initializer { - LlmChatViewModel() + val dataStoreRepository = galleryApplication().container.dataStoreRepository + LlmChatViewModel(dataStoreRepository = dataStoreRepository, curTask = TASK_LLM_CHAT) } - // Initializer for LlmSingleTurnViewModel.. - initializer { - LlmSingleTurnViewModel() - } - - // Initializer for LlmAskImageViewModel. - initializer { - LlmAskImageViewModel() - } + // Initializer for LlmSingleTurnViewModel.. - REMOVED + // Initializer for LlmAskImageViewModel. - REMOVED // Initializer for ImageGenerationViewModel. initializer { ImageGenerationViewModel() } + + // Initializer for UserProfileViewModel. + initializer { + val dataStoreRepository = galleryApplication().container.dataStoreRepository + UserProfileViewModel(dataStoreRepository = dataStoreRepository) + } + + // Initializer for PersonaViewModel. + initializer { + val dataStoreRepository = galleryApplication().container.dataStoreRepository + PersonaViewModel(dataStoreRepository = dataStoreRepository) + } + + // Initializer for ConversationHistoryViewModel. + initializer { + val dataStoreRepository = galleryApplication().container.dataStoreRepository + ConversationHistoryViewModel(dataStoreRepository = dataStoreRepository) + } } } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/LocalAppContainer.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/LocalAppContainer.kt new file mode 100644 index 0000000..6f64314 --- /dev/null +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/LocalAppContainer.kt @@ -0,0 +1,6 @@ +package com.google.ai.edge.gallery.ui.common + +import androidx.compose.runtime.compositionLocalOf +import com.google.ai.edge.gallery.data.AppContainer + +val LocalAppContainer = compositionLocalOf { error("AppContainer not provided") } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/conversationhistory/ConversationHistoryScreen.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/conversationhistory/ConversationHistoryScreen.kt new file mode 100644 index 0000000..0b94fd5 --- /dev/null +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/conversationhistory/ConversationHistoryScreen.kt @@ -0,0 +1,146 @@ +package com.google.ai.edge.gallery.ui.conversationhistory + +import androidx.compose.foundation.clickable +import androidx.compose.foundation.layout.* +import androidx.compose.foundation.lazy.LazyColumn +import androidx.compose.foundation.lazy.items +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.ArrowBack +import androidx.compose.material.icons.filled.Delete +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.text.style.TextOverflow +import androidx.compose.ui.unit.dp +import androidx.compose.ui.res.stringResource // Added import +import androidx.lifecycle.viewmodel.compose.viewModel +import androidx.navigation.NavController +import com.google.ai.edge.gallery.R // Added import for R class +import com.google.ai.edge.gallery.data.Conversation +// import com.google.ai.edge.gallery.data.getModelByName // Not strictly needed if modelName is just a string +import com.google.ai.edge.gallery.ui.ViewModelProvider // For ViewModelProvider.Factory +import com.google.ai.edge.gallery.ui.navigation.LlmChatDestination // For navigation route +import java.text.SimpleDateFormat +import java.util.Date +import java.util.Locale + +@OptIn(ExperimentalMaterial3Api::class) +@Composable +fun ConversationHistoryScreen( + navController: NavController, + viewModelFactory: ViewModelProvider.Factory, // Ensure this matches the actual factory name + viewModel: ConversationHistoryViewModel = viewModel(factory = viewModelFactory) +) { + val conversations by viewModel.conversations.collectAsState() + + Scaffold( + topBar = { + TopAppBar( + title = { Text(stringResource(R.string.conversation_history_title)) }, + navigationIcon = { + IconButton(onClick = { navController.popBackStack() }) { + Icon(Icons.Filled.ArrowBack, contentDescription = stringResource(R.string.user_profile_back_button_desc)) // Reused + } + } + ) + } + ) { paddingValues -> + if (conversations.isEmpty()) { + Box(modifier = Modifier.fillMaxSize().padding(paddingValues), contentAlignment = Alignment.Center) { + Text(stringResource(R.string.conversation_history_no_conversations)) + } + } else { + LazyColumn( + contentPadding = paddingValues, + modifier = Modifier.fillMaxSize().padding(horizontal = 8.dp), + verticalArrangement = Arrangement.spacedBy(8.dp) + ) { + items(conversations, key = { it.id }) { conversation -> + ConversationHistoryItem( + conversation = conversation, + onItemClick = { + val modelName = conversation.modelIdUsed + if (modelName != null) { + // Navigate using the route that includes conversationId and modelName + navController.navigate( + "${LlmChatDestination.routeTemplate}/conversation/${conversation.id}?modelName=${modelName}" + ) + } else { + // Fallback or error: modelIdUsed should ideally not be null for conversations created post-update. + // Consider navigating to a generic chat or showing an error. + // For now, this click might do nothing if modelIdUsed is null. + android.util.Log.w("ConvHistory", "modelIdUsed is null for conversation: ${conversation.id}") + } + }, + onDeleteClick = { viewModel.deleteConversation(conversation.id) } + ) + } + } + } + } +} + +@Composable +fun ConversationHistoryItem( + conversation: Conversation, + onItemClick: () -> Unit, + onDeleteClick: () -> Unit +) { + var showDeleteConfirmDialog by remember { mutableStateOf(false) } + val dateFormatter = remember { SimpleDateFormat("MMM dd, yyyy hh:mm a", Locale.getDefault()) } + + Card( + modifier = Modifier + .fillMaxWidth() + .clickable(onClick = onItemClick), + elevation = CardDefaults.cardElevation(defaultElevation = 2.dp) + ) { + Row( + modifier = Modifier.padding(16.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Column(modifier = Modifier.weight(1f)) { + Text( + conversation.title ?: stringResource(R.string.conversation_history_item_title_prefix, dateFormatter.format(Date(conversation.creationTimestamp))), + style = MaterialTheme.typography.titleMedium, + fontWeight = FontWeight.Bold, + maxLines = 1, + overflow = TextOverflow.Ellipsis + ) + Spacer(Modifier.height(4.dp)) + Text( + "Model: ${conversation.modelIdUsed ?: stringResource(R.string.chat_default_agent_name)}", // Display model ID or default + style = MaterialTheme.typography.bodySmall, + maxLines = 1, + overflow = TextOverflow.Ellipsis + ) + Text( + stringResource(R.string.conversation_history_last_activity_prefix, dateFormatter.format(Date(conversation.lastModifiedTimestamp))), + style = MaterialTheme.typography.bodySmall + ) + conversation.messages.lastOrNull()?.let { + Text( + "${it.role}: ${it.content.take(80)}${if (it.content.length > 80) "..." else ""}", + style = MaterialTheme.typography.bodySmall, + maxLines = 1, + overflow = TextOverflow.Ellipsis + ) + } + } + IconButton(onClick = { showDeleteConfirmDialog = true }) { + Icon(Icons.Filled.Delete, contentDescription = stringResource(R.string.persona_item_delete_desc)) // Reused + } + } + } + if (showDeleteConfirmDialog) { + AlertDialog( + onDismissRequest = { showDeleteConfirmDialog = false }, + title = { Text(stringResource(R.string.conversation_history_delete_dialog_title)) }, + text = { Text(stringResource(R.string.conversation_history_delete_dialog_message)) }, + confirmButton = { Button(onClick = { onDeleteClick(); showDeleteConfirmDialog = false }) { Text(stringResource(R.string.conversation_history_delete_dialog_confirm_button)) } }, + dismissButton = { Button(onClick = { showDeleteConfirmDialog = false }) { Text(stringResource(R.string.dialog_cancel_button)) } } + ) + } +} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/conversationhistory/ConversationHistoryViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/conversationhistory/ConversationHistoryViewModel.kt new file mode 100644 index 0000000..6472c9c --- /dev/null +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/conversationhistory/ConversationHistoryViewModel.kt @@ -0,0 +1,33 @@ +package com.google.ai.edge.gallery.ui.conversationhistory + +import androidx.lifecycle.ViewModel +import androidx.lifecycle.viewModelScope +import com.google.ai.edge.gallery.data.Conversation +import com.google.ai.edge.gallery.data.DataStoreRepository +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.launch + +class ConversationHistoryViewModel(private val dataStoreRepository: DataStoreRepository) : ViewModel() { + + private val _conversations = MutableStateFlow>(emptyList()) + val conversations: StateFlow> = _conversations.asStateFlow() + + init { + loadConversations() + } + + fun loadConversations() { + viewModelScope.launch { + _conversations.value = dataStoreRepository.readConversations().sortedByDescending { it.lastModifiedTimestamp } + } + } + + fun deleteConversation(conversationId: String) { + viewModelScope.launch { + dataStoreRepository.deleteConversation(conversationId) + loadConversations() // Refresh the list + } + } +} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/HomeScreen.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/HomeScreen.kt index 5b124c1..a7848fe 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/HomeScreen.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/HomeScreen.kt @@ -130,10 +130,13 @@ object HomeScreenDestination { } @OptIn(ExperimentalMaterial3Api::class) +import androidx.navigation.NavController + @Composable fun HomeScreen( modelManagerViewModel: ModelManagerViewModel, navigateToTaskScreen: (Task) -> Unit, + navController: NavController, // Add NavController modifier: Modifier = Modifier ) { val scrollBehavior = TopAppBarDefaults.pinnedScrollBehavior() @@ -210,6 +213,7 @@ fun HomeScreen( SettingsDialog( curThemeOverride = modelManagerViewModel.readThemeOverride(), modelManagerViewModel = modelManagerViewModel, + navController = navController, // Pass NavController onDismissed = { showSettingsDialog = false }, ) } @@ -571,10 +575,16 @@ fun getFileName(context: Context, uri: Uri): String? { @Composable fun HomeScreenPreview( ) { + // Preview will not have a real NavController, so this might need adjustment + // if SettingsDialog is to be previewed from here. For now, focusing on functionality. + // For a simple preview, one might pass a dummy NavController or conditional logic. + val context = LocalContext.current + val dummyNavController = remember { NavController(context) } GalleryTheme { HomeScreen( - modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current), + modelManagerViewModel = PreviewModelManagerViewModel(context = context), navigateToTaskScreen = {}, + navController = dummyNavController, // Pass dummy NavController for preview ) } } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/SettingsDialog.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/SettingsDialog.kt index c0cae8f..4ee53ed 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/SettingsDialog.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/SettingsDialog.kt @@ -60,8 +60,12 @@ import androidx.compose.ui.text.TextStyle import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.unit.dp import androidx.compose.ui.window.Dialog +import androidx.navigation.NavController +import com.google.ai.edge.gallery.R // Added import for R class import com.google.ai.edge.gallery.BuildConfig import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel +import com.google.ai.edge.gallery.ui.navigation.PersonaManagementDestination // For the route +import com.google.ai.edge.gallery.ui.navigation.UserProfileDestination // For the route import com.google.ai.edge.gallery.ui.theme.THEME_AUTO import com.google.ai.edge.gallery.ui.theme.THEME_DARK import com.google.ai.edge.gallery.ui.theme.THEME_LIGHT @@ -79,6 +83,7 @@ private val THEME_OPTIONS = listOf(THEME_AUTO, THEME_LIGHT, THEME_DARK) fun SettingsDialog( curThemeOverride: String, modelManagerViewModel: ModelManagerViewModel, + navController: NavController, // Add NavController onDismissed: () -> Unit, ) { var selectedTheme by remember { mutableStateOf(curThemeOverride) } @@ -255,6 +260,46 @@ fun SettingsDialog( } } } + + // Personal Profile Section + Column( + modifier = Modifier.fillMaxWidth() + ) { + Text( + stringResource(R.string.settings_personal_profile_title), + style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.Bold), + modifier = Modifier.padding(bottom = 8.dp) + ) + OutlinedButton( + onClick = { + navController.navigate(UserProfileDestination.route) + onDismissed() // Optionally dismiss settings dialog after navigation + }, + modifier = Modifier.fillMaxWidth() + ) { + Text(stringResource(R.string.settings_edit_profile_button)) + } + } + + // Persona Management Section + Column( + modifier = Modifier.fillMaxWidth() + ) { + Text( + stringResource(R.string.settings_persona_management_title), + style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.Bold), + modifier = Modifier.padding(top = 16.dp, bottom = 8.dp) // Add top padding + ) + OutlinedButton( + onClick = { + navController.navigate(PersonaManagementDestination.route) + onDismissed() // Optionally dismiss settings dialog + }, + modifier = Modifier.fillMaxWidth() + ) { + Text(stringResource(R.string.settings_manage_personas_button)) + } + } } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt index ce2094e..659fd5d 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt @@ -39,6 +39,17 @@ object LlmChatModelHelper { // Indexed by model name. private val cleanUpListeners: MutableMap = mutableMapOf() + fun primeSessionWithSystemPrompt(model: Model, systemPrompt: String) { + val instance = model.instance as? LlmModelInstance ?: return + try { + instance.session.addQueryChunk(systemPrompt) + Log.d(TAG, "Session primed with system prompt.") + } catch (e: Exception) { + Log.e(TAG, "Error priming session with system prompt: ", e) + // Consider how to handle this error, maybe throw or callback + } + } + fun initialize( context: Context, model: Model, onDone: (String) -> Unit ) { diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt index 6cebc53..63e40b1 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt @@ -17,10 +17,18 @@ package com.google.ai.edge.gallery.ui.llmchat import android.graphics.Bitmap -import androidx.compose.runtime.Composable +import androidx.compose.foundation.layout.* // For Column, Row, Spacer, padding, etc. +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.ArrowBack +import androidx.compose.material3.* // For TextField, TopAppBar, etc. +import androidx.compose.runtime.* // For remember, mutableStateOf, collectAsState +import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.platform.LocalContext +import androidx.compose.ui.res.stringResource // For string resources +import androidx.compose.ui.unit.dp import androidx.lifecycle.viewmodel.compose.viewModel +import com.google.ai.edge.gallery.data.Model // Ensure Model is imported import com.google.ai.edge.gallery.ui.ViewModelProvider import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText @@ -30,115 +38,157 @@ import kotlinx.serialization.Serializable /** Navigation destination data */ object LlmChatDestination { - @Serializable - val route = "LlmChatRoute" -} - -object LlmAskImageDestination { - @Serializable - val route = "LlmAskImageRoute" + @Serializable // Keep serializable if used in NavType directly, though we'll use String for nav args + const val routeTemplate = "LlmChatRoute" + const val conversationIdArg = "conversationId" + // Route for opening an existing conversation + val routeForConversation = "$routeTemplate/conversation/{$conversationIdArg}" + // Route for starting a new chat, potentially with a pre-selected model + const val modelNameArg = "modelName" + val routeForNewChatWithModel = "$routeTemplate/new/{$modelNameArg}" + val routeForNewChat = routeTemplate // General new chat } @Composable fun LlmChatScreen( - modelManagerViewModel: ModelManagerViewModel, - navigateUp: () -> Unit, - modifier: Modifier = Modifier, - viewModel: LlmChatViewModel = viewModel( - factory = ViewModelProvider.Factory - ), + modelManagerViewModel: ModelManagerViewModel, + navigateUp: () -> Unit, + modifier: Modifier = Modifier, + viewModel: LlmChatViewModel = viewModel(factory = ViewModelProvider.Factory), + conversationId: String? = null // New parameter ) { - ChatViewWrapper( - viewModel = viewModel, - modelManagerViewModel = modelManagerViewModel, - navigateUp = navigateUp, - modifier = modifier, - ) + var customSystemPromptInput by remember { mutableStateOf("") } + val currentConversation by viewModel.currentConversation.collectAsState() + val activePersona by viewModel.activePersona.collectAsState() + val uiMessages by viewModel.uiMessages.collectAsState() // Observe uiMessages + + // Prioritize model from conversation if available, then from ModelManagerViewModel + val selectedModel: Model? = remember(currentConversation, modelManagerViewModel.getSelectedModel(viewModel.task.type)) { + modelManagerViewModel.getSelectedModel(viewModel.task.type) + } + + LaunchedEffect(conversationId, selectedModel) { + if (selectedModel == null) { + android.util.Log.e("LlmChatScreen", "No model selected for chat.") + return@LaunchedEffect + } + if (conversationId != null) { + viewModel.loadConversation(conversationId, selectedModel) + } else { + if (currentConversation == null || (currentConversation?.id == null && currentConversation?.messages.isNullOrEmpty())) { + // Let user type system prompt. startNewConversation will be called on first send. + } + } + } + + ChatViewWrapper( + viewModel = viewModel, + modelManagerViewModel = modelManagerViewModel, + navigateUp = navigateUp, + navController = navController, // Pass navController + modifier = modifier, + customSystemPromptInput = customSystemPromptInput, + onCustomSystemPromptChange = { customSystemPromptInput = it }, + activePersonaName = activePersona?.name ?: currentConversation?.activePersonaId?.let { "ID: $it" } // Fallback to ID if name not loaded + ) } -@Composable -fun LlmAskImageScreen( - modelManagerViewModel: ModelManagerViewModel, - navigateUp: () -> Unit, - modifier: Modifier = Modifier, - viewModel: LlmAskImageViewModel = viewModel( - factory = ViewModelProvider.Factory - ), -) { - ChatViewWrapper( - viewModel = viewModel, - modelManagerViewModel = modelManagerViewModel, - navigateUp = navigateUp, - modifier = modifier, - ) -} +// Removed duplicated ChatViewWrapper call and imports that were above it. +// The ExperimentalMaterial3Api annotation is kept for the actual ChatViewWrapper below. +@OptIn(ExperimentalMaterial3Api::class) @Composable fun ChatViewWrapper( - viewModel: LlmChatViewModel, - modelManagerViewModel: ModelManagerViewModel, - navigateUp: () -> Unit, - modifier: Modifier = Modifier + viewModel: LlmChatViewModel, + modelManagerViewModel: ModelManagerViewModel, + navigateUp: () -> Unit, + navController: NavController, // Added navController parameter + modifier: Modifier = Modifier, + customSystemPromptInput: String, + onCustomSystemPromptChange: (String) -> Unit, + activePersonaName: String? ) { - val context = LocalContext.current + val context = LocalContext.current + val currentConvo by viewModel.currentConversation.collectAsState() + val messagesForUi by viewModel.uiMessages.collectAsState() + val selectedModel = modelManagerViewModel.getSelectedModel(viewModel.task.type) - ChatView( - task = viewModel.task, - viewModel = viewModel, - modelManagerViewModel = modelManagerViewModel, - onSendMessage = { model, messages -> - for (message in messages) { - viewModel.addMessage( - model = model, - message = message, + // Show system prompt input if no conversation has started or if it's a new, empty conversation + val showSystemPromptInput = currentConvo == null || (currentConvo?.id != null && currentConvo!!.messages.isEmpty()) + + + Column(modifier = modifier.fillMaxSize()) { + TopAppBar( + title = { + Text(activePersonaName ?: stringResource(viewModel.task.agentNameRes)) + }, + navigationIcon = { + IconButton(onClick = navigateUp) { + Icon(Icons.Filled.ArrowBack, contentDescription = stringResource(R.string.user_profile_back_button_desc)) // Reused + } + }, + actions = { + IconButton(onClick = { navController.navigate(ConversationHistoryDestination.route) }) { + Icon(Icons.Filled.History, contentDescription = stringResource(R.string.chat_history_button_desc)) + } + } ) - } - var text = "" - var image: Bitmap? = null - var chatMessageText: ChatMessageText? = null - for (message in messages) { - if (message is ChatMessageText) { - chatMessageText = message - text = message.content - } else if (message is ChatMessageImage) { - image = message.bitmap + if (showSystemPromptInput) { + OutlinedTextField( + value = customSystemPromptInput, + onValueChange = onCustomSystemPromptChange, + label = { Text(stringResource(R.string.chat_custom_system_prompt_label)) }, + modifier = Modifier + .fillMaxWidth() + .padding(8.dp), + maxLines = 3 + ) } - } - if (text.isNotEmpty() && chatMessageText != null) { - modelManagerViewModel.addTextInputHistory(text) - viewModel.generateResponse(model = model, input = text, image = image, onError = { - viewModel.handleError( - context = context, - model = model, + + ChatView( + task = viewModel.task, + viewModel = viewModel, modelManagerViewModel = modelManagerViewModel, - triggeredMessage = chatMessageText, - ) - }) - } - }, - onRunAgainClicked = { model, message -> - if (message is ChatMessageText) { - viewModel.runAgain(model = model, message = message, onError = { - viewModel.handleError( - context = context, - model = model, - modelManagerViewModel = modelManagerViewModel, - triggeredMessage = message, - ) - }) - } - }, - onBenchmarkClicked = { _, _, _, _ -> - }, - onResetSessionClicked = { model -> - viewModel.resetSession(model = model) - }, - showStopButtonInInputWhenInProgress = true, - onStopButtonClicked = { model -> - viewModel.stopResponse(model = model) - }, - navigateUp = navigateUp, - modifier = modifier, - ) + messages = messagesForUi, + onSendMessage = { modelFromChatView, userMessages -> + val userInputMessage = userMessages.firstNotNullOfOrNull { it as? ChatMessageText }?.content ?: "" + val imageBitmap = userMessages.firstNotNullOfOrNull { it as? ChatMessageImage }?.bitmap + + if (userInputMessage.isNotBlank() || imageBitmap != null) { + selectedModel?.let { validSelectedModel -> + modelManagerViewModel.addTextInputHistory(userInputMessage) + + if (currentConvo == null || (currentConvo?.id != null && currentConvo!!.messages.isEmpty() && !currentConvo!!.initialSystemPrompt.isNullOrEmpty().not() && customSystemPromptInput.isBlank())) { + viewModel.startNewConversation( + customSystemPrompt = if (customSystemPromptInput.isNotBlank()) customSystemPromptInput else null, + selectedPersonaId = viewModel.activePersona.value?.id, + title = userInputMessage.take(30).ifBlank { stringResource(R.string.chat_new_conversation_title_prefix) }, // Use string resource + selectedModel = validSelectedModel + ) + } + viewModel.generateChatResponse(model = validSelectedModel, input = userInputMessage, image = imageBitmap) + } ?: run { + android.util.Log.e("ChatViewWrapper", "No model selected, cannot send message.") + // Potentially show error to user + } + } + }, + // TODO: Add confirmation dialog for reset session + onResetSessionClicked = { model -> + selectedModel?.let { + onCustomSystemPromptChange("") + viewModel.startNewConversation( + customSystemPrompt = null, + selectedPersonaId = viewModel.activePersona.value?.id, + title = stringResource(R.string.chat_new_conversation_title_prefix), // Use string resource + selectedModel = it + ) + } + }, + showStopButtonInInputWhenInProgress = true, + onStopButtonClicked = { model -> viewModel.stopResponse(model) }, + navigateUp = navigateUp + ) + } } \ No newline at end of file diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt index b99cdc8..9957abf 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt @@ -21,10 +21,16 @@ import android.graphics.Bitmap import android.util.Log import androidx.lifecycle.viewModelScope import com.google.ai.edge.gallery.data.ConfigKey +import com.google.ai.edge.gallery.data.Conversation +import com.google.ai.edge.gallery.data.DataStoreRepository import com.google.ai.edge.gallery.data.Model +import com.google.ai.edge.gallery.data.Persona import com.google.ai.edge.gallery.data.TASK_LLM_CHAT -import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE +// import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE // Removed import com.google.ai.edge.gallery.data.Task +import com.google.ai.edge.gallery.data.UserProfile +import com.google.ai.edge.gallery.data.ChatMessage // Assuming ChatMessage is in data package from PersonalAppModels.kt +import com.google.ai.edge.gallery.data.ChatMessageRole import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult import com.google.ai.edge.gallery.ui.common.chat.ChatMessageLoading import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText @@ -36,7 +42,9 @@ import com.google.ai.edge.gallery.ui.common.chat.Stat import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.* import kotlinx.coroutines.launch +import java.util.UUID private const val TAG = "AGLlmChatViewModel" private val STATS = listOf( @@ -46,203 +54,265 @@ private val STATS = listOf( Stat(id = "latency", label = "Latency", unit = "sec") ) -open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task = curTask) { - fun generateResponse(model: Model, input: String, image: Bitmap? = null, onError: () -> Unit) { - val accelerator = model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = "") - viewModelScope.launch(Dispatchers.Default) { - setInProgress(true) - setPreparing(true) +open class LlmChatViewModel( + private val dataStoreRepository: DataStoreRepository, // Add this + curTask: Task = TASK_LLM_CHAT // Keep if still relevant, or simplify if chat is the only focus +) : ChatViewModel(task = curTask) { // ChatViewModel base class might need review - // Loading. - addMessage( - model = model, - message = ChatMessageLoading(accelerator = accelerator), - ) + private val _currentConversation = MutableStateFlow(null) + val currentConversation: StateFlow = _currentConversation.asStateFlow() - // Wait for instance to be initialized. - while (model.instance == null) { - delay(100) - } - delay(500) + val userProfile: StateFlow = flow { + emit(dataStoreRepository.readUserProfile()) + }.stateIn(viewModelScope, SharingStarted.Eagerly, null) - // Run inference. - val instance = model.instance as LlmModelInstance - var prefillTokens = instance.session.sizeInTokens(input) - if (image != null) { - prefillTokens += 257 - } - - var firstRun = true - var timeToFirstToken = 0f - var firstTokenTs = 0L - var decodeTokens = 0 - var prefillSpeed = 0f - var decodeSpeed: Float - val start = System.currentTimeMillis() - - try { - LlmChatModelHelper.runInference(model = model, - input = input, - image = image, - resultListener = { partialResult, done -> - val curTs = System.currentTimeMillis() - - if (firstRun) { - firstTokenTs = System.currentTimeMillis() - timeToFirstToken = (firstTokenTs - start) / 1000f - prefillSpeed = prefillTokens / timeToFirstToken - firstRun = false - setPreparing(false) - } else { - decodeTokens++ + // Get active persona ID from repo, then fetch the full Persona object + val activePersona: StateFlow = dataStoreRepository.readActivePersonaId().flatMapLatest { activeId -> + if (activeId == null) { + flowOf(null) + } else { + flow { + val personas = dataStoreRepository.readPersonas() + emit(personas.find { it.id == activeId }) } - - // Remove the last message if it is a "loading" message. - // This will only be done once. - val lastMessage = getLastMessage(model = model) - if (lastMessage?.type == ChatMessageType.LOADING) { - removeLastMessage(model = model) - - // Add an empty message that will receive streaming results. - addMessage( - model = model, - message = ChatMessageText( - content = "", - side = ChatSide.AGENT, - accelerator = accelerator - ) - ) - } - - // Incrementally update the streamed partial results. - val latencyMs: Long = if (done) System.currentTimeMillis() - start else -1 - updateLastTextMessageContentIncrementally( - model = model, partialContent = partialResult, latencyMs = latencyMs.toFloat() - ) - - if (done) { - setInProgress(false) - - decodeSpeed = decodeTokens / ((curTs - firstTokenTs) / 1000f) - if (decodeSpeed.isNaN()) { - decodeSpeed = 0f - } - - if (lastMessage is ChatMessageText) { - updateLastTextMessageLlmBenchmarkResult( - model = model, llmBenchmarkResult = ChatMessageBenchmarkLlmResult( - orderedStats = STATS, - statValues = mutableMapOf( - "prefill_speed" to prefillSpeed, - "decode_speed" to decodeSpeed, - "time_to_first_token" to timeToFirstToken, - "latency" to (curTs - start).toFloat() / 1000f, - ), - running = false, - latencyMs = -1f, - accelerator = accelerator, - ) - ) - } - } - }, - cleanUpListener = { - setInProgress(false) - setPreparing(false) - }) - } catch (e: Exception) { - Log.e(TAG, "Error occurred while running inference", e) - setInProgress(false) - setPreparing(false) - onError() - } - } - } - - fun stopResponse(model: Model) { - Log.d(TAG, "Stopping response for model ${model.name}...") - if (getLastMessage(model = model) is ChatMessageLoading) { - removeLastMessage(model = model) - } - viewModelScope.launch(Dispatchers.Default) { - setInProgress(false) - val instance = model.instance as LlmModelInstance - instance.session.cancelGenerateResponseAsync() - } - } - - fun resetSession(model: Model) { - viewModelScope.launch(Dispatchers.Default) { - setIsResettingSession(true) - clearAllMessages(model = model) - stopResponse(model = model) - - while (true) { - try { - LlmChatModelHelper.resetSession(model = model) - break - } catch (e: Exception) { - Log.d(TAG, "Failed to reset session. Trying again") } - delay(200) - } - setIsResettingSession(false) - } - } + }.stateIn(viewModelScope, SharingStarted.Eagerly, null) - fun runAgain(model: Model, message: ChatMessageText, onError: () -> Unit) { - viewModelScope.launch(Dispatchers.Default) { - // Wait for model to be initialized. - while (model.instance == null) { - delay(100) - } + // This replaces messagesByModel for the new conversation-centric approach + private val _uiMessages = MutableStateFlow>(emptyList()) + val uiMessages: StateFlow> = _uiMessages.asStateFlow() - // Clone the clicked message and add it. - addMessage(model = model, message = message.clone()) + // TODO: Review how ChatViewModel's messagesByModel and related methods + // (addMessage, getLastMessage, removeLastMessage, clearAllMessages) are used. + // They might need to be overridden or adapted to work with _currentConversation.messages + // and update _uiMessages. For now, new methods will manage _currentConversation. - // Run inference. - generateResponse( - model = model, input = message.content, onError = onError - ) - } - } + fun startNewConversation(customSystemPrompt: String?, selectedPersonaId: String?, title: String? = null, selectedModel: Model) { + viewModelScope.launch { + val newConversationId = UUID.randomUUID().toString() + val newConversation = Conversation( + id = newConversationId, + title = title, + creationTimestamp = System.currentTimeMillis(), + lastModifiedTimestamp = System.currentTimeMillis(), + initialSystemPrompt = customSystemPrompt, + activePersonaId = selectedPersonaId, + modelIdUsed = selectedModel.name, // Store the model name or a unique ID + messages = mutableListOf() // Start with empty messages + ) + dataStoreRepository.addConversation(newConversation) + _currentConversation.value = newConversation + _uiMessages.value = emptyList() // Clear UI messages for the new chat - fun handleError( - context: Context, - model: Model, - modelManagerViewModel: ModelManagerViewModel, - triggeredMessage: ChatMessageText, - ) { - // Clean up. - modelManagerViewModel.cleanupModel(task = task, model = model) + // Reset LLM session for the selected model + LlmChatModelHelper.resetSession(selectedModel) // Ensure model instance is ready - // Remove the "loading" message. - if (getLastMessage(model = model) is ChatMessageLoading) { - removeLastMessage(model = model) + // Prime with system prompt immediately if starting a new conversation + val systemPromptParts = mutableListOf() + customSystemPrompt?.let { systemPromptParts.add(it) } + activePersona.value?.prompt?.let { systemPromptParts.add(it) } + userProfile.value?.summary?.let { if(it.isNotBlank()) systemPromptParts.add("User Profile Summary: $it") } + // Add other profile details as needed, e.g., skills + + if (systemPromptParts.isNotEmpty()) { + val fullSystemPrompt = systemPromptParts.joinToString("\n\n") + // Ensure model instance is available and ready before priming + if (selectedModel.instance == null) { + Log.e(TAG, "Model instance is null before priming. Initialize model first.") + // Potentially trigger model initialization via ModelManagerViewModel if not already done. + // For now, we assume the model selected for chat is already initialized by ModelManager. + return@launch + } + LlmChatModelHelper.primeSessionWithSystemPrompt(selectedModel, fullSystemPrompt) + } + } } - // Remove the last Text message. - if (getLastMessage(model = model) == triggeredMessage) { - removeLastMessage(model = model) + fun loadConversation(conversationId: String, selectedModel: Model) { + viewModelScope.launch { + val conversation = dataStoreRepository.getConversationById(conversationId) + _currentConversation.value = conversation + if (conversation != null) { + // Convert stored ChatMessage to UI ChatMessage + _uiMessages.value = conversation.messages.map { convertToUiChatMessage(it, selectedModel) } + + // Reset session and re-prime with history + LlmChatModelHelper.resetSession(selectedModel) + + val systemPromptParts = mutableListOf() + conversation.initialSystemPrompt?.let { systemPromptParts.add(it) } + // Need to fetch persona for this conversation + val personas = dataStoreRepository.readPersonas() + val personaForLoadedConv = personas.find { it.id == conversation.activePersonaId } + personaForLoadedConv?.prompt?.let { systemPromptParts.add(it) } + userProfile.value?.summary?.let { if(it.isNotBlank()) systemPromptParts.add("User Profile Summary: $it") } + // Add other profile details + + if (systemPromptParts.isNotEmpty()) { + LlmChatModelHelper.primeSessionWithSystemPrompt(selectedModel, systemPromptParts.joinToString("\n\n")) + } + // Replay message history into the session + conversation.messages.forEach { msg -> + if (msg.role == ChatMessageRole.USER) { + (selectedModel.instance as? LlmModelInstance)?.session?.addQueryChunk(msg.content) + } else if (msg.role == ChatMessageRole.ASSISTANT) { + // If LLM Inference API supports adding assistant messages to context, do it here. + // For now, assume addQueryChunk is for user input primarily to prompt next response. + (selectedModel.instance as? LlmModelInstance)?.session?.addQueryChunk(msg.content) // Or format appropriately + } + } + } else { + _uiMessages.value = emptyList() + } + } } - // Add a warning message for re-initializing the session. - addMessage( - model = model, - message = ChatMessageWarning(content = "Error occurred. Re-initializing the session.") - ) + // Helper to convert your data model ChatMessage to the UI model ChatMessage + private fun convertToUiChatMessage(appMessage: com.google.ai.edge.gallery.data.ChatMessage, model: Model): com.google.ai.edge.gallery.ui.common.chat.ChatMessage { + val side = if (appMessage.role == ChatMessageRole.USER) ChatSide.USER else ChatSide.AGENT + // This is a simplified conversion. You might need more fields. + // The existing ChatMessage in common.chat seems to be an interface/sealed class. + // We need to map to ChatMessageText or other appropriate types. + return ChatMessageText(content = appMessage.content, side = side, accelerator = model.getStringConfigValue(ConfigKey.ACCELERATOR, "")) + } - // Add the triggered message back. - addMessage(model = model, message = triggeredMessage) - // Re-initialize the session/engine. - modelManagerViewModel.initializeModel( - context = context, task = task, model = model - ) + // Override or adapt ChatViewModel's addMessage + override fun addMessage(model: Model, message: com.google.ai.edge.gallery.ui.common.chat.ChatMessage) { + val conversation = _currentConversation.value ?: return + val role = if (message.side == ChatSide.USER) ChatMessageRole.USER else ChatMessageRole.ASSISTANT - // Re-generate the response automatically. - generateResponse(model = model, input = triggeredMessage.content, onError = {}) - } -} + // Only add ChatMessageText for now to conversation history. Loading/Error messages are transient. + if (message is ChatMessageText) { + val appChatMessage = com.google.ai.edge.gallery.data.ChatMessage( + id = UUID.randomUUID().toString(), + conversationId = conversation.id, + timestamp = System.currentTimeMillis(), + role = role, + content = message.content, + personaUsedId = conversation.activePersonaId + ) + val updatedMessages = conversation.messages.toMutableList().apply { add(appChatMessage) } + _currentConversation.value = conversation.copy( + messages = updatedMessages, + lastModifiedTimestamp = System.currentTimeMillis() + ) + // Save updated conversation + viewModelScope.launch { + _currentConversation.value?.let { dataStoreRepository.updateConversation(it) } + } + } + // Update the UI-specific message list + _uiMessages.value = _uiMessages.value.toMutableList().apply { add(message) } + } -class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE) \ No newline at end of file + // Adapt generateResponse + fun generateChatResponse(model: Model, userInput: String, image: Bitmap? = null) { // Renamed to avoid conflict if base still used + val currentConvo = _currentConversation.value + if (currentConvo == null) { + Log.e(TAG, "Cannot generate response, no active conversation.") + // Optionally, start a new conversation implicitly or show an error + return + } + + // Add user's message to conversation and UI + val userUiMessage = ChatMessageText(content = userInput, side = ChatSide.USER, accelerator = model.getStringConfigValue(ConfigKey.ACCELERATOR, "")) + addMessage(model, userUiMessage) // This will also save it to DataStore + + // The rest is similar to original generateResponse, but uses currentConvo + val accelerator = model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = "") + viewModelScope.launch(Dispatchers.Default) { + setInProgress(true) // From base ChatViewModel + setPreparing(true) // From base ChatViewModel + + addMessage(model, ChatMessageLoading(accelerator = accelerator)) // Show loading in UI + + while (model.instance == null) { delay(100) } // Wait for model instance + delay(500) + + val instance = model.instance as LlmModelInstance + // History is now part of the session, primed by startNewConversation or loadConversation. + // LlmChatModelHelper.runInference will just add the latest userInput. + + try { + LlmChatModelHelper.runInference( + model = model, + input = userInput, // Just the new input + image = image, // Handle image if provided + resultListener = { partialResult, done -> + // UI update logic for streaming response - largely same as original + // Ensure to use 'addMessage' or similar to update _uiMessages + // and save assistant's final message to _currentConversation + val lastUiMsg = _uiMessages.value.lastOrNull() + if (lastUiMsg?.type == ChatMessageType.LOADING) { + _uiMessages.value = _uiMessages.value.dropLast(1) + // Add an empty message that will receive streaming results for UI + addMessage(model, ChatMessageText(content = "", side = ChatSide.AGENT, accelerator = accelerator)) + } + + val currentAgentMessage = _uiMessages.value.lastOrNull() as? ChatMessageText + if (currentAgentMessage != null) { + _uiMessages.value = _uiMessages.value.dropLast(1) + currentAgentMessage.copy(content = currentAgentMessage.content + partialResult) + } + + + if (done) { + setInProgress(false) + val finalAssistantContent = (_uiMessages.value.lastOrNull() as? ChatMessageText)?.content ?: "" + // Add final assistant message to DataStore Conversation + val assistantAppMessage = com.google.ai.edge.gallery.data.ChatMessage( + id = UUID.randomUUID().toString(), + conversationId = currentConvo.id, + timestamp = System.currentTimeMillis(), + role = ChatMessageRole.ASSISTANT, + content = finalAssistantContent, + personaUsedId = currentConvo.activePersonaId + ) + val updatedMessages = currentConvo.messages.toMutableList().apply { add(assistantAppMessage) } + _currentConversation.value = currentConvo.copy( + messages = updatedMessages, + lastModifiedTimestamp = System.currentTimeMillis() + ) + viewModelScope.launch { _currentConversation.value?.let { dataStoreRepository.updateConversation(it) } } + // Update benchmark results if necessary (code omitted for brevity, but similar logic as original generateResponse) + val lastMessage = _uiMessages.value.lastOrNull { it.side == ChatSide.AGENT } // get the agent's message + if (lastMessage is ChatMessageText && STATS.isNotEmpty()) { // Assuming STATS is defined + // This part needs to be adapted. The original `generateResponse` calculated these. + // For now, we'll omit direct benchmark updates here to simplify, + // as they depend on variables (timeToFirstToken, prefillSpeed, etc.) + // not directly available in this refactored `generateChatResponse` structure + // without significant further adaptation of the LlmChatModelHelper.runInference callback. + // A simpler approach might be to just mark the message as not running. + val updatedAgentMessage = lastMessage.copy( + llmBenchmarkResult = lastMessage.llmBenchmarkResult?.copy(running = false) + ?: ChatMessageBenchmarkLlmResult(orderedStats = STATS, statValues = mutableMapOf(), running = false, latencyMs = -1f, accelerator = accelerator) + ) + val finalUiMessages = _uiMessages.value.toMutableList() + val agentMsgIndex = finalUiMessages.indexOfLast { it.id == lastMessage.id } + if(agentMsgIndex != -1) { + finalUiMessages[agentMsgIndex] = updatedAgentMessage + _uiMessages.value = finalUiMessages + } + } + } + }, + cleanUpListener = { + setInProgress(false) + setPreparing(false) + } + ) + } catch (e: Exception) { + Log.e(TAG, "Error in generateChatResponse: ", e) + setInProgress(false) + setPreparing(false) + // Add error message to UI + addMessage(model, ChatMessageWarning(content = "Error: ${e.message}")) + } + } + } + + // TODO: Override/adapt clearAllMessages, stopResponse, runAgain, handleError from base ChatViewModel + // to work with the new currentConversation model and _uiMessages. + // For example, clearAllMessages should clear _uiMessages and potentially currentConvo.messages then save. + // resetSession should re-prime with system prompt and history if currentConvo exists. +} \ No newline at end of file diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt deleted file mode 100644 index 94d1c66..0000000 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt +++ /dev/null @@ -1,218 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.ai.edge.gallery.ui.llmsingleturn - -import android.util.Log -import androidx.activity.compose.BackHandler -import androidx.compose.foundation.background -import androidx.compose.foundation.layout.Box -import androidx.compose.foundation.layout.Column -import androidx.compose.foundation.layout.calculateStartPadding -import androidx.compose.foundation.layout.fillMaxSize -import androidx.compose.foundation.layout.padding -import androidx.compose.material3.MaterialTheme -import androidx.compose.material3.Scaffold -import androidx.compose.runtime.Composable -import androidx.compose.runtime.LaunchedEffect -import androidx.compose.runtime.collectAsState -import androidx.compose.runtime.getValue -import androidx.compose.runtime.mutableStateOf -import androidx.compose.runtime.remember -import androidx.compose.runtime.rememberCoroutineScope -import androidx.compose.runtime.setValue -import androidx.compose.ui.Alignment -import androidx.compose.ui.Modifier -import androidx.compose.ui.draw.alpha -import androidx.compose.ui.platform.LocalContext -import androidx.compose.ui.platform.LocalLayoutDirection -import androidx.compose.ui.tooling.preview.Preview -import androidx.lifecycle.viewmodel.compose.viewModel -import com.google.ai.edge.gallery.data.ModelDownloadStatusType -import com.google.ai.edge.gallery.ui.ViewModelProvider -import com.google.ai.edge.gallery.ui.common.ErrorDialog -import com.google.ai.edge.gallery.ui.common.ModelPageAppBar -import com.google.ai.edge.gallery.ui.common.chat.ModelDownloadStatusInfoPanel -import com.google.ai.edge.gallery.ui.modelmanager.ModelInitializationStatusType -import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel -import com.google.ai.edge.gallery.ui.preview.PreviewLlmSingleTurnViewModel -import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel -import com.google.ai.edge.gallery.ui.theme.GalleryTheme -import com.google.ai.edge.gallery.ui.theme.customColors -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.launch -import kotlinx.serialization.Serializable - -/** Navigation destination data */ -object LlmSingleTurnDestination { - @Serializable - val route = "LlmSingleTurnRoute" -} - -private const val TAG = "AGLlmSingleTurnScreen" - -@Composable -fun LlmSingleTurnScreen( - modelManagerViewModel: ModelManagerViewModel, - navigateUp: () -> Unit, - modifier: Modifier = Modifier, - viewModel: LlmSingleTurnViewModel = viewModel( - factory = ViewModelProvider.Factory - ), -) { - val task = viewModel.task - val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() - val uiState by viewModel.uiState.collectAsState() - val selectedModel = modelManagerUiState.selectedModel - val scope = rememberCoroutineScope() - val context = LocalContext.current - var navigatingUp by remember { mutableStateOf(false) } - var showErrorDialog by remember { mutableStateOf(false) } - - val handleNavigateUp = { - navigatingUp = true - navigateUp() - - // clean up all models. - scope.launch(Dispatchers.Default) { - for (model in task.models) { - modelManagerViewModel.cleanupModel(task = task, model = model) - } - } - } - - // Handle system's edge swipe. - BackHandler { - handleNavigateUp() - } - - // Initialize model when model/download state changes. - val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name] - LaunchedEffect(curDownloadStatus, selectedModel.name) { - if (!navigatingUp) { - if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) { - Log.d( - TAG, - "Initializing model '${selectedModel.name}' from LlmsingleTurnScreen launched effect" - ) - modelManagerViewModel.initializeModel(context, task = task, model = selectedModel) - } - } - } - - val modelInitializationStatus = modelManagerUiState.modelInitializationStatus[selectedModel.name] - LaunchedEffect(modelInitializationStatus) { - showErrorDialog = modelInitializationStatus?.status == ModelInitializationStatusType.ERROR - } - - Scaffold(modifier = modifier, topBar = { - ModelPageAppBar( - task = task, - model = selectedModel, - modelManagerViewModel = modelManagerViewModel, - inProgress = uiState.inProgress, - modelPreparing = uiState.preparing, - onConfigChanged = { _, _ -> }, - onBackClicked = { handleNavigateUp() }, - onModelSelected = { newSelectedModel -> - scope.launch(Dispatchers.Default) { - // Clean up current model. - modelManagerViewModel.cleanupModel(task = task, model = selectedModel) - - // Update selected model. - modelManagerViewModel.selectModel(model = newSelectedModel) - } - } - ) - }) { innerPadding -> - Column( - modifier = Modifier.padding( - top = innerPadding.calculateTopPadding(), - start = innerPadding.calculateStartPadding(LocalLayoutDirection.current), - end = innerPadding.calculateStartPadding(LocalLayoutDirection.current), - ) - ) { - ModelDownloadStatusInfoPanel( - model = selectedModel, - task = task, - modelManagerViewModel = modelManagerViewModel - ) - - // Main UI after model is downloaded. - val modelDownloaded = curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED - Box( - contentAlignment = Alignment.BottomCenter, - modifier = Modifier - .weight(1f) - // Just hide the UI without removing it from the screen so that the scroll syncing - // from ResponsePanel still works. - .alpha(if (modelDownloaded) 1.0f else 0.0f) - ) { - VerticalSplitView(modifier = Modifier.fillMaxSize(), - topView = { - PromptTemplatesPanel( - model = selectedModel, - viewModel = viewModel, - modelManagerViewModel = modelManagerViewModel, - onSend = { fullPrompt -> - viewModel.generateResponse(model = selectedModel, input = fullPrompt) - }, - onStopButtonClicked = { model -> - viewModel.stopResponse(model = model) - }, - modifier = Modifier.fillMaxSize() - ) - }, - bottomView = { - Box( - contentAlignment = Alignment.BottomCenter, - modifier = Modifier - .fillMaxSize() - .background(MaterialTheme.customColors.agentBubbleBgColor) - ) { - ResponsePanel( - model = selectedModel, - viewModel = viewModel, - modelManagerViewModel = modelManagerViewModel, - modifier = Modifier - .fillMaxSize() - .padding(bottom = innerPadding.calculateBottomPadding()) - ) - } - }) - } - - if (showErrorDialog) { - ErrorDialog(error = modelInitializationStatus?.error ?: "", onDismiss = { - showErrorDialog = false - }) - } - } - } -} - -@Preview(showBackground = true) -@Composable -fun LlmSingleTurnScreenPreview() { - val context = LocalContext.current - GalleryTheme { - LlmSingleTurnScreen( - modelManagerViewModel = PreviewModelManagerViewModel(context = context), - viewModel = PreviewLlmSingleTurnViewModel(), - navigateUp = {}, - ) - } -} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt deleted file mode 100644 index 64ce612..0000000 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt +++ /dev/null @@ -1,221 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.ai.edge.gallery.ui.llmsingleturn - -import android.util.Log -import androidx.lifecycle.ViewModel -import androidx.lifecycle.viewModelScope -import com.google.ai.edge.gallery.data.Model -import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB -import com.google.ai.edge.gallery.data.Task -import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult -import com.google.ai.edge.gallery.ui.common.chat.Stat -import com.google.ai.edge.gallery.ui.common.processLlmResponse -import com.google.ai.edge.gallery.ui.llmchat.LlmChatModelHelper -import com.google.ai.edge.gallery.ui.llmchat.LlmModelInstance -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.delay -import kotlinx.coroutines.flow.MutableStateFlow -import kotlinx.coroutines.flow.asStateFlow -import kotlinx.coroutines.flow.update -import kotlinx.coroutines.launch - -private const val TAG = "AGLlmSingleTurnViewModel" - -data class LlmSingleTurnUiState( - /** - * Indicates whether the runtime is currently processing a message. - */ - val inProgress: Boolean = false, - - /** - * Indicates whether the model is preparing (before outputting any result and after initializing). - */ - val preparing: Boolean = false, - - // model ->