diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/GalleryApp.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/GalleryApp.kt index 19f53a0..e3b34c3 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/GalleryApp.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/GalleryApp.kt @@ -47,10 +47,14 @@ import androidx.compose.ui.res.stringResource import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.sp +import androidx.compose.runtime.CompositionLocalProvider +import androidx.compose.ui.platform.LocalContext import androidx.navigation.NavHostController import androidx.navigation.compose.rememberNavController +import com.google.ai.edge.gallery.data.AppContainer // Needed for type import com.google.ai.edge.gallery.data.AppBarAction import com.google.ai.edge.gallery.data.AppBarActionType +import com.google.ai.edge.gallery.ui.common.LocalAppContainer import com.google.ai.edge.gallery.ui.navigation.GalleryNavHost /** @@ -58,7 +62,10 @@ import com.google.ai.edge.gallery.ui.navigation.GalleryNavHost */ @Composable fun GalleryApp(navController: NavHostController = rememberNavController()) { - GalleryNavHost(navController = navController) + val appContainer = (LocalContext.current.applicationContext as GalleryApplication).container + CompositionLocalProvider(LocalAppContainer provides appContainer) { + GalleryNavHost(navController = navController) + } } /** diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/AppContainer.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/AppContainer.kt index 57a905d..bb2375f 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/AppContainer.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/AppContainer.kt @@ -42,6 +42,6 @@ interface AppContainer { class DefaultAppContainer(ctx: Context, dataStore: DataStore) : AppContainer { override val context = ctx override val lifecycleProvider = GalleryLifecycleProvider() - override val dataStoreRepository = DefaultDataStoreRepository(dataStore) + override val dataStoreRepository = DefaultDataStoreRepository(dataStore, ctx) // Pass context here override val downloadRepository = DefaultDownloadRepository(ctx, lifecycleProvider) } \ No newline at end of file diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/DataStoreRepository.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/DataStoreRepository.kt index 82749e9..32f0024 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/DataStoreRepository.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/DataStoreRepository.kt @@ -16,9 +16,11 @@ package com.google.ai.edge.gallery.data +import android.content.Context // Added import import android.security.keystore.KeyGenParameterSpec import android.security.keystore.KeyProperties import android.util.Base64 +import com.google.ai.edge.gallery.R // Added import for R class import androidx.datastore.core.DataStore import androidx.datastore.preferences.core.Preferences import androidx.datastore.preferences.core.edit @@ -27,9 +29,13 @@ import androidx.datastore.preferences.core.stringPreferencesKey import com.google.gson.Gson import com.google.gson.reflect.TypeToken import com.google.ai.edge.gallery.ui.theme.THEME_AUTO +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.distinctUntilChanged import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.map import kotlinx.coroutines.runBlocking import java.security.KeyStore +import java.util.UUID import javax.crypto.Cipher import javax.crypto.KeyGenerator import javax.crypto.SecretKey @@ -50,6 +56,30 @@ interface DataStoreRepository { fun readAccessTokenData(): AccessTokenData? fun saveImportedModels(importedModels: List) fun readImportedModels(): List + + fun saveUserProfile(userProfile: UserProfile) + fun readUserProfile(): UserProfile? + fun savePersonas(personas: List) + fun readPersonas(): List + fun addPersona(persona: Persona) + fun updatePersona(persona: Persona) + fun deletePersona(personaId: String) + fun saveConversations(conversations: List) + fun readConversations(): List + fun getConversationById(conversationId: String): Conversation? + fun addConversation(conversation: Conversation) + fun updateConversation(conversation: Conversation) + fun deleteConversation(conversationId: String) + + fun saveActivePersonaId(personaId: String?) + fun readActivePersonaId(): Flow + + fun saveUserDocuments(documents: List) + fun readUserDocuments(): Flow> + fun addUserDocument(document: UserDocument) + fun updateUserDocument(document: UserDocument) + fun deleteUserDocument(documentId: String) + fun getUserDocumentById(documentId: String): Flow } /** @@ -62,7 +92,8 @@ interface DataStoreRepository { * DataStore is used to persist data as JSON strings under specified keys. */ class DefaultDataStoreRepository( - private val dataStore: DataStore + private val dataStore: DataStore, + private val context: Context // Added context ) : DataStoreRepository { @@ -85,6 +116,13 @@ class DefaultDataStoreRepository( // Data for all imported models. val IMPORTED_MODELS = stringPreferencesKey("imported_models") + + val ENCRYPTED_USER_PROFILE = stringPreferencesKey("encrypted_user_profile") + val USER_PROFILE_IV = stringPreferencesKey("user_profile_iv") + val PERSONAS_LIST = stringPreferencesKey("personas_list") + val CONVERSATIONS_LIST = stringPreferencesKey("conversations_list") + val ACTIVE_PERSONA_ID = stringPreferencesKey("active_persona_id") + val USER_DOCUMENTS_LIST = stringPreferencesKey("user_documents_list") } private val keystoreAlias: String = "com_google_aiedge_gallery_access_token_key" @@ -189,7 +227,8 @@ class DefaultDataStoreRepository( val infosStr = preferences[PreferencesKeys.IMPORTED_MODELS] ?: "[]" val gson = Gson() val listType = object : TypeToken>() {}.type - gson.fromJson(infosStr, listType) + // Ensure to return emptyList() if fromJson returns null + return gson.fromJson(infosStr, listType) ?: emptyList() } } @@ -197,7 +236,8 @@ class DefaultDataStoreRepository( val infosStr = preferences[PreferencesKeys.TEXT_INPUT_HISTORY] ?: "[]" val gson = Gson() val listType = object : TypeToken>() {}.type - return gson.fromJson(infosStr, listType) + // Ensure to return emptyList() if fromJson returns null + return gson.fromJson(infosStr, listType) ?: emptyList() } private fun getOrCreateSecretKey(): SecretKey { @@ -243,4 +283,239 @@ class DefaultDataStoreRepository( null } } + + override fun saveUserProfile(userProfile: UserProfile) { + runBlocking { + val gson = Gson() + val jsonString = gson.toJson(userProfile) + val (encryptedProfile, iv) = encrypt(jsonString) + dataStore.edit { preferences -> + preferences[PreferencesKeys.ENCRYPTED_USER_PROFILE] = encryptedProfile + preferences[PreferencesKeys.USER_PROFILE_IV] = iv + } + } + } + + override fun readUserProfile(): UserProfile? { + return runBlocking { + val preferences = dataStore.data.first() + val encryptedProfile = preferences[PreferencesKeys.ENCRYPTED_USER_PROFILE] + val iv = preferences[PreferencesKeys.USER_PROFILE_IV] + + if (encryptedProfile != null && iv != null) { + try { + val decryptedJson = decrypt(encryptedProfile, iv) + if (decryptedJson != null) { + Gson().fromJson(decryptedJson, UserProfile::class.java) + } else { + UserProfile() // Return default if decryption fails + } + } catch (e: Exception) { + UserProfile() // Return default on error + } + } else { + UserProfile() // Return default if not found + } + } + } + + override fun savePersonas(personas: List) { + runBlocking { + dataStore.edit { preferences -> + val gson = Gson() + val jsonString = gson.toJson(personas) + preferences[PreferencesKeys.PERSONAS_LIST] = jsonString + } + } + } + + override fun readPersonas(): List { + return runBlocking { + val preferences = dataStore.data.first() + val jsonString = preferences[PreferencesKeys.PERSONAS_LIST] + val gson = Gson() + val listType = object : TypeToken>() {}.type + var personas: List = if (jsonString != null) { + try { + gson.fromJson(jsonString, listType) ?: emptyList() + } catch (e: Exception) { + emptyList() // Return empty list on deserialization error + } + } else { + emptyList() + } + + if (personas.isEmpty()) { + personas = listOf( + Persona( + id = UUID.randomUUID().toString(), + name = context.getString(R.string.persona_add_edit_dialog_name_default_assist), + prompt = context.getString(R.string.persona_add_edit_dialog_prompt_default_assist), + isDefault = true + ), + Persona( + id = UUID.randomUUID().toString(), + name = context.getString(R.string.persona_add_edit_dialog_name_default_creative), + prompt = context.getString(R.string.persona_add_edit_dialog_prompt_default_creative), + isDefault = true + ) + ) + // Save these default personas back to DataStore + savePersonas(personas) + } + personas + } + } + + override fun addPersona(persona: Persona) { + val currentPersonas = readPersonas().toMutableList() + currentPersonas.add(persona) + savePersonas(currentPersonas) + } + + override fun updatePersona(persona: Persona) { + val currentPersonas = readPersonas().toMutableList() + val index = currentPersonas.indexOfFirst { it.id == persona.id } + if (index != -1) { + currentPersonas[index] = persona + savePersonas(currentPersonas) + } + } + + override fun deletePersona(personaId: String) { + val currentPersonas = readPersonas().toMutableList() + currentPersonas.removeAll { it.id == personaId } + savePersonas(currentPersonas) + } + + override fun saveConversations(conversations: List) { + runBlocking { + dataStore.edit { preferences -> + val gson = Gson() + val jsonString = gson.toJson(conversations) + preferences[PreferencesKeys.CONVERSATIONS_LIST] = jsonString + } + } + } + + override fun readConversations(): List { + return runBlocking { + val preferences = dataStore.data.first() + val jsonString = preferences[PreferencesKeys.CONVERSATIONS_LIST] + val gson = Gson() + val listType = object : TypeToken>() {}.type + if (jsonString != null) { + try { + gson.fromJson(jsonString, listType) ?: emptyList() + } catch (e: Exception) { + emptyList() // Return empty list on deserialization error + } + } else { + emptyList() + } + } + } + + override fun getConversationById(conversationId: String): Conversation? { + return readConversations().firstOrNull { it.id == conversationId } + } + + override fun addConversation(conversation: Conversation) { + val currentConversations = readConversations().toMutableList() + currentConversations.add(conversation) + saveConversations(currentConversations) + } + + override fun updateConversation(conversation: Conversation) { + val currentConversations = readConversations().toMutableList() + val index = currentConversations.indexOfFirst { it.id == conversation.id } + if (index != -1) { + currentConversations[index] = conversation + saveConversations(currentConversations) + } + } + + override fun deleteConversation(conversationId: String) { + val currentConversations = readConversations().toMutableList() + currentConversations.removeAll { it.id == conversationId } + saveConversations(currentConversations) + } + + override fun saveActivePersonaId(personaId: String?) { + runBlocking { + dataStore.edit { preferences -> + if (personaId == null) { + preferences.remove(PreferencesKeys.ACTIVE_PERSONA_ID) + } else { + preferences[PreferencesKeys.ACTIVE_PERSONA_ID] = personaId + } + } + } + } + + override fun readActivePersonaId(): Flow { + return dataStore.data.map { preferences -> + preferences[PreferencesKeys.ACTIVE_PERSONA_ID] + }.distinctUntilChanged() + } + + override fun saveUserDocuments(documents: List) { + runBlocking { + dataStore.edit { preferences -> + val gson = Gson() + val jsonString = gson.toJson(documents) + preferences[PreferencesKeys.USER_DOCUMENTS_LIST] = jsonString + } + } + } + + override fun readUserDocuments(): Flow> { + return dataStore.data.map { preferences -> + val jsonString = preferences[PreferencesKeys.USER_DOCUMENTS_LIST] + if (jsonString != null) { + val gson = Gson() + val type = object : TypeToken>() {}.type + gson.fromJson(jsonString, type) ?: emptyList() + } else { + emptyList() + } + }.distinctUntilChanged() + } + + override fun addUserDocument(document: UserDocument) { + runBlocking { // Consider making these suspend functions if runBlocking becomes an issue + val currentDocuments = readUserDocuments().first().toMutableList() + currentDocuments.removeAll { it.id == document.id } // Remove if already exists by ID, then add + currentDocuments.add(document) + saveUserDocuments(currentDocuments) + } + } + + override fun updateUserDocument(document: UserDocument) { + runBlocking { + val currentDocuments = readUserDocuments().first().toMutableList() + val index = currentDocuments.indexOfFirst { it.id == document.id } + if (index != -1) { + currentDocuments[index] = document + saveUserDocuments(currentDocuments) + } else { + // Optionally add if not found, or log an error + addUserDocument(document) // Or handle error: document to update not found + } + } + } + + override fun deleteUserDocument(documentId: String) { + runBlocking { + val currentDocuments = readUserDocuments().first().toMutableList() + currentDocuments.removeAll { it.id == documentId } + saveUserDocuments(currentDocuments) + } + } + + override fun getUserDocumentById(documentId: String): Flow { + return readUserDocuments().map { documents -> + documents.find { it.id == documentId } + } + } } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/PersonalAppModels.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/PersonalAppModels.kt new file mode 100644 index 0000000..ca0b7af --- /dev/null +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/PersonalAppModels.kt @@ -0,0 +1,110 @@ +package com.google.ai.edge.gallery.data + +import java.util.UUID + +/** + * Represents a user's profile information. + * + * @property name The name of the user. + * @property summary A brief summary or bio of the user. + * @property skills A list of the user's skills. + * @property experience A list of strings, where each string can represent a job or project description. + */ +data class UserProfile( + val name: String? = null, + val summary: String? = null, + val skills: List = emptyList(), + val experience: List = emptyList() +) + +/** + * Represents an AI persona that can be used in conversations. + * + * @property id A unique identifier for the persona (e.g., a UUID). + * @property name The name of the persona. + * @property prompt The system prompt associated with this persona, defining its behavior and responses. + * @property isDefault Indicates if this is a default persona. + */ +data class Persona( + val id: String = UUID.randomUUID().toString(), + val name: String, + val prompt: String, + val isDefault: Boolean = false +) + +/** + * Defines the role of the sender of a chat message. + */ +enum class ChatMessageRole { + /** The message is from the end-user. */ + USER, + /** The message is from the AI assistant. */ + ASSISTANT, + /** The message is a system instruction or context. */ + SYSTEM +} + +/** + * Represents a single message within a chat conversation. + * + * @property id A unique identifier for the message (e.g., a UUID). + * @property conversationId The ID of the conversation this message belongs to. + * @property timestamp The time the message was created, in epoch milliseconds. + * @property role The role of the message sender (user, assistant, or system). + * @property content The textual content of the message. + * @property personaUsedId The ID of the Persona active when this message was generated or sent, if applicable. + */ +data class ChatMessage( + val id: String = UUID.randomUUID().toString(), + val conversationId: String, + val timestamp: Long, + val role: ChatMessageRole, + val content: String, + val personaUsedId: String? = null +) + +/** + * Represents a chat conversation. + * + * @property id A unique identifier for the conversation (e.g., a UUID). + * @property title An optional user-defined title for the conversation. + * @property creationTimestamp The time the conversation was created, in epoch milliseconds. + * @property lastModifiedTimestamp The time the conversation was last modified, in epoch milliseconds. + * @property initialSystemPrompt A custom system prompt for this specific conversation, which might override a Persona's default prompt. + * @property messages A list of chat messages in this conversation. For local storage, embedding can work. For cloud sync, separate storage of messages linked by ID is better. + * @property activePersonaId The ID of the persona primarily used in this conversation. + * @property modelIdUsed The ID (e.g. name) of the AI model used in this conversation. + */ +data class Conversation( + val id: String = UUID.randomUUID().toString(), + var title: String? = null, + val creationTimestamp: Long, + var lastModifiedTimestamp: Long, + val initialSystemPrompt: String? = null, + val modelIdUsed: String? = null, // Add this + val messages: List = emptyList(), + val activePersonaId: String? = null +) + +/** + * Represents a document imported or managed by the user. + * + * @property id Unique identifier for the document (e.g., UUID). + * @property fileName Original name of the file. + * @property localPath Path to the locally stored copy of the document, if applicable. + * @property originalSource Indicates where the document came from (e.g., "local", a URL for Google Docs). + * @property fileType The MIME type or a simple extension string (e.g., "txt", "pdf", "docx"). + * @property extractedText The text content extracted from the document. Null if not yet extracted or not applicable. + * @property importTimestamp Timestamp when the document was imported. + * @property lastAccessedTimestamp Timestamp when the document was last used in a chat (optional). + */ +data class UserDocument( + val id: String = java.util.UUID.randomUUID().toString(), + val fileName: String, + val localPath: String? = null, + val originalSource: String, // e.g., "local", "google_drive_id:" + val fileType: String, // e.g., "text/plain", "application/pdf" + var extractedText: String? = null, + val importTimestamp: Long = System.currentTimeMillis(), + var lastAccessedTimestamp: Long? = null +) diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Tasks.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Tasks.kt index 41a2314..1a98c9c 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Tasks.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Tasks.kt @@ -33,9 +33,7 @@ enum class TaskType(val label: String, val id: String) { IMAGE_CLASSIFICATION(label = "Image Classification", id = "image_classification"), IMAGE_GENERATION(label = "Image Generation", id = "image_generation"), LLM_CHAT(label = "AI Chat", id = "llm_chat"), - LLM_PROMPT_LAB(label = "Prompt Lab", id = "llm_prompt_lab"), - LLM_ASK_IMAGE(label = "Ask Image", id = "llm_ask_image"), - + // LLM_PROMPT_LAB and LLM_ASK_IMAGE removed from enum TEST_TASK_1(label = "Test task 1", id = "test_task_1"), TEST_TASK_2(label = "Test task 2", id = "test_task_2") } @@ -100,25 +98,8 @@ val TASK_LLM_CHAT = Task( textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat ) -val TASK_LLM_PROMPT_LAB = Task( - type = TaskType.LLM_PROMPT_LAB, - icon = Icons.Outlined.Widgets, - models = mutableListOf(), - description = "Single turn use cases with on-device large language model", - docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", - sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt", - textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat -) - -val TASK_LLM_ASK_IMAGE = Task( - type = TaskType.LLM_ASK_IMAGE, - icon = Icons.Outlined.Mms, - models = mutableListOf(), - description = "Ask questions about images with on-device large language models", - docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android", - sourceCodeUrl = "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt", - textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat -) +// TASK_LLM_PROMPT_LAB definition removed +// TASK_LLM_ASK_IMAGE definition removed val TASK_IMAGE_GENERATION = Task( type = TaskType.IMAGE_GENERATION, @@ -132,9 +113,12 @@ val TASK_IMAGE_GENERATION = Task( /** All tasks. */ val TASKS: List = listOf( - TASK_LLM_ASK_IMAGE, - TASK_LLM_PROMPT_LAB, - TASK_LLM_CHAT, + TASK_TEXT_CLASSIFICATION, + TASK_IMAGE_CLASSIFICATION, + TASK_IMAGE_GENERATION, + TASK_LLM_CHAT + // TASK_LLM_ASK_IMAGE removed + // TASK_LLM_PROMPT_LAB removed ) fun getModelByName(name: String): Model? { diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt index 8a257dd..dd9bc9a 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt @@ -22,13 +22,17 @@ import androidx.lifecycle.viewmodel.CreationExtras import androidx.lifecycle.viewmodel.initializer import androidx.lifecycle.viewmodel.viewModelFactory import com.google.ai.edge.gallery.GalleryApplication +import com.google.ai.edge.gallery.data.TASK_LLM_CHAT // Import TASK_LLM_CHAT import com.google.ai.edge.gallery.ui.imageclassification.ImageClassificationViewModel import com.google.ai.edge.gallery.ui.imagegeneration.ImageGenerationViewModel import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel -import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageViewModel -import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel +// import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageViewModel // Removed +// import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel // Removed import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.textclassification.TextClassificationViewModel +import com.google.ai.edge.gallery.ui.userprofile.UserProfileViewModel +import com.google.ai.edge.gallery.ui.persona.PersonaViewModel +import com.google.ai.edge.gallery.ui.conversationhistory.ConversationHistoryViewModel // Added import object ViewModelProvider { val Factory = viewModelFactory { @@ -55,23 +59,35 @@ object ViewModelProvider { // Initializer for LlmChatViewModel. initializer { - LlmChatViewModel() + val dataStoreRepository = galleryApplication().container.dataStoreRepository + LlmChatViewModel(dataStoreRepository = dataStoreRepository, curTask = TASK_LLM_CHAT) } - // Initializer for LlmSingleTurnViewModel.. - initializer { - LlmSingleTurnViewModel() - } - - // Initializer for LlmAskImageViewModel. - initializer { - LlmAskImageViewModel() - } + // Initializer for LlmSingleTurnViewModel.. - REMOVED + // Initializer for LlmAskImageViewModel. - REMOVED // Initializer for ImageGenerationViewModel. initializer { ImageGenerationViewModel() } + + // Initializer for UserProfileViewModel. + initializer { + val dataStoreRepository = galleryApplication().container.dataStoreRepository + UserProfileViewModel(dataStoreRepository = dataStoreRepository) + } + + // Initializer for PersonaViewModel. + initializer { + val dataStoreRepository = galleryApplication().container.dataStoreRepository + PersonaViewModel(dataStoreRepository = dataStoreRepository) + } + + // Initializer for ConversationHistoryViewModel. + initializer { + val dataStoreRepository = galleryApplication().container.dataStoreRepository + ConversationHistoryViewModel(dataStoreRepository = dataStoreRepository) + } } } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/LocalAppContainer.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/LocalAppContainer.kt new file mode 100644 index 0000000..6f64314 --- /dev/null +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/LocalAppContainer.kt @@ -0,0 +1,6 @@ +package com.google.ai.edge.gallery.ui.common + +import androidx.compose.runtime.compositionLocalOf +import com.google.ai.edge.gallery.data.AppContainer + +val LocalAppContainer = compositionLocalOf { error("AppContainer not provided") } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/conversationhistory/ConversationHistoryScreen.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/conversationhistory/ConversationHistoryScreen.kt new file mode 100644 index 0000000..0b94fd5 --- /dev/null +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/conversationhistory/ConversationHistoryScreen.kt @@ -0,0 +1,146 @@ +package com.google.ai.edge.gallery.ui.conversationhistory + +import androidx.compose.foundation.clickable +import androidx.compose.foundation.layout.* +import androidx.compose.foundation.lazy.LazyColumn +import androidx.compose.foundation.lazy.items +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.ArrowBack +import androidx.compose.material.icons.filled.Delete +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.text.style.TextOverflow +import androidx.compose.ui.unit.dp +import androidx.compose.ui.res.stringResource // Added import +import androidx.lifecycle.viewmodel.compose.viewModel +import androidx.navigation.NavController +import com.google.ai.edge.gallery.R // Added import for R class +import com.google.ai.edge.gallery.data.Conversation +// import com.google.ai.edge.gallery.data.getModelByName // Not strictly needed if modelName is just a string +import com.google.ai.edge.gallery.ui.ViewModelProvider // For ViewModelProvider.Factory +import com.google.ai.edge.gallery.ui.navigation.LlmChatDestination // For navigation route +import java.text.SimpleDateFormat +import java.util.Date +import java.util.Locale + +@OptIn(ExperimentalMaterial3Api::class) +@Composable +fun ConversationHistoryScreen( + navController: NavController, + viewModelFactory: ViewModelProvider.Factory, // Ensure this matches the actual factory name + viewModel: ConversationHistoryViewModel = viewModel(factory = viewModelFactory) +) { + val conversations by viewModel.conversations.collectAsState() + + Scaffold( + topBar = { + TopAppBar( + title = { Text(stringResource(R.string.conversation_history_title)) }, + navigationIcon = { + IconButton(onClick = { navController.popBackStack() }) { + Icon(Icons.Filled.ArrowBack, contentDescription = stringResource(R.string.user_profile_back_button_desc)) // Reused + } + } + ) + } + ) { paddingValues -> + if (conversations.isEmpty()) { + Box(modifier = Modifier.fillMaxSize().padding(paddingValues), contentAlignment = Alignment.Center) { + Text(stringResource(R.string.conversation_history_no_conversations)) + } + } else { + LazyColumn( + contentPadding = paddingValues, + modifier = Modifier.fillMaxSize().padding(horizontal = 8.dp), + verticalArrangement = Arrangement.spacedBy(8.dp) + ) { + items(conversations, key = { it.id }) { conversation -> + ConversationHistoryItem( + conversation = conversation, + onItemClick = { + val modelName = conversation.modelIdUsed + if (modelName != null) { + // Navigate using the route that includes conversationId and modelName + navController.navigate( + "${LlmChatDestination.routeTemplate}/conversation/${conversation.id}?modelName=${modelName}" + ) + } else { + // Fallback or error: modelIdUsed should ideally not be null for conversations created post-update. + // Consider navigating to a generic chat or showing an error. + // For now, this click might do nothing if modelIdUsed is null. + android.util.Log.w("ConvHistory", "modelIdUsed is null for conversation: ${conversation.id}") + } + }, + onDeleteClick = { viewModel.deleteConversation(conversation.id) } + ) + } + } + } + } +} + +@Composable +fun ConversationHistoryItem( + conversation: Conversation, + onItemClick: () -> Unit, + onDeleteClick: () -> Unit +) { + var showDeleteConfirmDialog by remember { mutableStateOf(false) } + val dateFormatter = remember { SimpleDateFormat("MMM dd, yyyy hh:mm a", Locale.getDefault()) } + + Card( + modifier = Modifier + .fillMaxWidth() + .clickable(onClick = onItemClick), + elevation = CardDefaults.cardElevation(defaultElevation = 2.dp) + ) { + Row( + modifier = Modifier.padding(16.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Column(modifier = Modifier.weight(1f)) { + Text( + conversation.title ?: stringResource(R.string.conversation_history_item_title_prefix, dateFormatter.format(Date(conversation.creationTimestamp))), + style = MaterialTheme.typography.titleMedium, + fontWeight = FontWeight.Bold, + maxLines = 1, + overflow = TextOverflow.Ellipsis + ) + Spacer(Modifier.height(4.dp)) + Text( + "Model: ${conversation.modelIdUsed ?: stringResource(R.string.chat_default_agent_name)}", // Display model ID or default + style = MaterialTheme.typography.bodySmall, + maxLines = 1, + overflow = TextOverflow.Ellipsis + ) + Text( + stringResource(R.string.conversation_history_last_activity_prefix, dateFormatter.format(Date(conversation.lastModifiedTimestamp))), + style = MaterialTheme.typography.bodySmall + ) + conversation.messages.lastOrNull()?.let { + Text( + "${it.role}: ${it.content.take(80)}${if (it.content.length > 80) "..." else ""}", + style = MaterialTheme.typography.bodySmall, + maxLines = 1, + overflow = TextOverflow.Ellipsis + ) + } + } + IconButton(onClick = { showDeleteConfirmDialog = true }) { + Icon(Icons.Filled.Delete, contentDescription = stringResource(R.string.persona_item_delete_desc)) // Reused + } + } + } + if (showDeleteConfirmDialog) { + AlertDialog( + onDismissRequest = { showDeleteConfirmDialog = false }, + title = { Text(stringResource(R.string.conversation_history_delete_dialog_title)) }, + text = { Text(stringResource(R.string.conversation_history_delete_dialog_message)) }, + confirmButton = { Button(onClick = { onDeleteClick(); showDeleteConfirmDialog = false }) { Text(stringResource(R.string.conversation_history_delete_dialog_confirm_button)) } }, + dismissButton = { Button(onClick = { showDeleteConfirmDialog = false }) { Text(stringResource(R.string.dialog_cancel_button)) } } + ) + } +} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/conversationhistory/ConversationHistoryViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/conversationhistory/ConversationHistoryViewModel.kt new file mode 100644 index 0000000..6472c9c --- /dev/null +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/conversationhistory/ConversationHistoryViewModel.kt @@ -0,0 +1,33 @@ +package com.google.ai.edge.gallery.ui.conversationhistory + +import androidx.lifecycle.ViewModel +import androidx.lifecycle.viewModelScope +import com.google.ai.edge.gallery.data.Conversation +import com.google.ai.edge.gallery.data.DataStoreRepository +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.launch + +class ConversationHistoryViewModel(private val dataStoreRepository: DataStoreRepository) : ViewModel() { + + private val _conversations = MutableStateFlow>(emptyList()) + val conversations: StateFlow> = _conversations.asStateFlow() + + init { + loadConversations() + } + + fun loadConversations() { + viewModelScope.launch { + _conversations.value = dataStoreRepository.readConversations().sortedByDescending { it.lastModifiedTimestamp } + } + } + + fun deleteConversation(conversationId: String) { + viewModelScope.launch { + dataStoreRepository.deleteConversation(conversationId) + loadConversations() // Refresh the list + } + } +} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/HomeScreen.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/HomeScreen.kt index 5b124c1..a7848fe 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/HomeScreen.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/HomeScreen.kt @@ -130,10 +130,13 @@ object HomeScreenDestination { } @OptIn(ExperimentalMaterial3Api::class) +import androidx.navigation.NavController + @Composable fun HomeScreen( modelManagerViewModel: ModelManagerViewModel, navigateToTaskScreen: (Task) -> Unit, + navController: NavController, // Add NavController modifier: Modifier = Modifier ) { val scrollBehavior = TopAppBarDefaults.pinnedScrollBehavior() @@ -210,6 +213,7 @@ fun HomeScreen( SettingsDialog( curThemeOverride = modelManagerViewModel.readThemeOverride(), modelManagerViewModel = modelManagerViewModel, + navController = navController, // Pass NavController onDismissed = { showSettingsDialog = false }, ) } @@ -571,10 +575,16 @@ fun getFileName(context: Context, uri: Uri): String? { @Composable fun HomeScreenPreview( ) { + // Preview will not have a real NavController, so this might need adjustment + // if SettingsDialog is to be previewed from here. For now, focusing on functionality. + // For a simple preview, one might pass a dummy NavController or conditional logic. + val context = LocalContext.current + val dummyNavController = remember { NavController(context) } GalleryTheme { HomeScreen( - modelManagerViewModel = PreviewModelManagerViewModel(context = LocalContext.current), + modelManagerViewModel = PreviewModelManagerViewModel(context = context), navigateToTaskScreen = {}, + navController = dummyNavController, // Pass dummy NavController for preview ) } } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/SettingsDialog.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/SettingsDialog.kt index c0cae8f..4ee53ed 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/SettingsDialog.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/home/SettingsDialog.kt @@ -60,8 +60,12 @@ import androidx.compose.ui.text.TextStyle import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.unit.dp import androidx.compose.ui.window.Dialog +import androidx.navigation.NavController +import com.google.ai.edge.gallery.R // Added import for R class import com.google.ai.edge.gallery.BuildConfig import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel +import com.google.ai.edge.gallery.ui.navigation.PersonaManagementDestination // For the route +import com.google.ai.edge.gallery.ui.navigation.UserProfileDestination // For the route import com.google.ai.edge.gallery.ui.theme.THEME_AUTO import com.google.ai.edge.gallery.ui.theme.THEME_DARK import com.google.ai.edge.gallery.ui.theme.THEME_LIGHT @@ -79,6 +83,7 @@ private val THEME_OPTIONS = listOf(THEME_AUTO, THEME_LIGHT, THEME_DARK) fun SettingsDialog( curThemeOverride: String, modelManagerViewModel: ModelManagerViewModel, + navController: NavController, // Add NavController onDismissed: () -> Unit, ) { var selectedTheme by remember { mutableStateOf(curThemeOverride) } @@ -255,6 +260,46 @@ fun SettingsDialog( } } } + + // Personal Profile Section + Column( + modifier = Modifier.fillMaxWidth() + ) { + Text( + stringResource(R.string.settings_personal_profile_title), + style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.Bold), + modifier = Modifier.padding(bottom = 8.dp) + ) + OutlinedButton( + onClick = { + navController.navigate(UserProfileDestination.route) + onDismissed() // Optionally dismiss settings dialog after navigation + }, + modifier = Modifier.fillMaxWidth() + ) { + Text(stringResource(R.string.settings_edit_profile_button)) + } + } + + // Persona Management Section + Column( + modifier = Modifier.fillMaxWidth() + ) { + Text( + stringResource(R.string.settings_persona_management_title), + style = MaterialTheme.typography.titleSmall.copy(fontWeight = FontWeight.Bold), + modifier = Modifier.padding(top = 16.dp, bottom = 8.dp) // Add top padding + ) + OutlinedButton( + onClick = { + navController.navigate(PersonaManagementDestination.route) + onDismissed() // Optionally dismiss settings dialog + }, + modifier = Modifier.fillMaxWidth() + ) { + Text(stringResource(R.string.settings_manage_personas_button)) + } + } } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt index ce2094e..659fd5d 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt @@ -39,6 +39,17 @@ object LlmChatModelHelper { // Indexed by model name. private val cleanUpListeners: MutableMap = mutableMapOf() + fun primeSessionWithSystemPrompt(model: Model, systemPrompt: String) { + val instance = model.instance as? LlmModelInstance ?: return + try { + instance.session.addQueryChunk(systemPrompt) + Log.d(TAG, "Session primed with system prompt.") + } catch (e: Exception) { + Log.e(TAG, "Error priming session with system prompt: ", e) + // Consider how to handle this error, maybe throw or callback + } + } + fun initialize( context: Context, model: Model, onDone: (String) -> Unit ) { diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt index 6cebc53..63e40b1 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt @@ -17,10 +17,18 @@ package com.google.ai.edge.gallery.ui.llmchat import android.graphics.Bitmap -import androidx.compose.runtime.Composable +import androidx.compose.foundation.layout.* // For Column, Row, Spacer, padding, etc. +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.ArrowBack +import androidx.compose.material3.* // For TextField, TopAppBar, etc. +import androidx.compose.runtime.* // For remember, mutableStateOf, collectAsState +import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.platform.LocalContext +import androidx.compose.ui.res.stringResource // For string resources +import androidx.compose.ui.unit.dp import androidx.lifecycle.viewmodel.compose.viewModel +import com.google.ai.edge.gallery.data.Model // Ensure Model is imported import com.google.ai.edge.gallery.ui.ViewModelProvider import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText @@ -30,115 +38,157 @@ import kotlinx.serialization.Serializable /** Navigation destination data */ object LlmChatDestination { - @Serializable - val route = "LlmChatRoute" -} - -object LlmAskImageDestination { - @Serializable - val route = "LlmAskImageRoute" + @Serializable // Keep serializable if used in NavType directly, though we'll use String for nav args + const val routeTemplate = "LlmChatRoute" + const val conversationIdArg = "conversationId" + // Route for opening an existing conversation + val routeForConversation = "$routeTemplate/conversation/{$conversationIdArg}" + // Route for starting a new chat, potentially with a pre-selected model + const val modelNameArg = "modelName" + val routeForNewChatWithModel = "$routeTemplate/new/{$modelNameArg}" + val routeForNewChat = routeTemplate // General new chat } @Composable fun LlmChatScreen( - modelManagerViewModel: ModelManagerViewModel, - navigateUp: () -> Unit, - modifier: Modifier = Modifier, - viewModel: LlmChatViewModel = viewModel( - factory = ViewModelProvider.Factory - ), + modelManagerViewModel: ModelManagerViewModel, + navigateUp: () -> Unit, + modifier: Modifier = Modifier, + viewModel: LlmChatViewModel = viewModel(factory = ViewModelProvider.Factory), + conversationId: String? = null // New parameter ) { - ChatViewWrapper( - viewModel = viewModel, - modelManagerViewModel = modelManagerViewModel, - navigateUp = navigateUp, - modifier = modifier, - ) + var customSystemPromptInput by remember { mutableStateOf("") } + val currentConversation by viewModel.currentConversation.collectAsState() + val activePersona by viewModel.activePersona.collectAsState() + val uiMessages by viewModel.uiMessages.collectAsState() // Observe uiMessages + + // Prioritize model from conversation if available, then from ModelManagerViewModel + val selectedModel: Model? = remember(currentConversation, modelManagerViewModel.getSelectedModel(viewModel.task.type)) { + modelManagerViewModel.getSelectedModel(viewModel.task.type) + } + + LaunchedEffect(conversationId, selectedModel) { + if (selectedModel == null) { + android.util.Log.e("LlmChatScreen", "No model selected for chat.") + return@LaunchedEffect + } + if (conversationId != null) { + viewModel.loadConversation(conversationId, selectedModel) + } else { + if (currentConversation == null || (currentConversation?.id == null && currentConversation?.messages.isNullOrEmpty())) { + // Let user type system prompt. startNewConversation will be called on first send. + } + } + } + + ChatViewWrapper( + viewModel = viewModel, + modelManagerViewModel = modelManagerViewModel, + navigateUp = navigateUp, + navController = navController, // Pass navController + modifier = modifier, + customSystemPromptInput = customSystemPromptInput, + onCustomSystemPromptChange = { customSystemPromptInput = it }, + activePersonaName = activePersona?.name ?: currentConversation?.activePersonaId?.let { "ID: $it" } // Fallback to ID if name not loaded + ) } -@Composable -fun LlmAskImageScreen( - modelManagerViewModel: ModelManagerViewModel, - navigateUp: () -> Unit, - modifier: Modifier = Modifier, - viewModel: LlmAskImageViewModel = viewModel( - factory = ViewModelProvider.Factory - ), -) { - ChatViewWrapper( - viewModel = viewModel, - modelManagerViewModel = modelManagerViewModel, - navigateUp = navigateUp, - modifier = modifier, - ) -} +// Removed duplicated ChatViewWrapper call and imports that were above it. +// The ExperimentalMaterial3Api annotation is kept for the actual ChatViewWrapper below. +@OptIn(ExperimentalMaterial3Api::class) @Composable fun ChatViewWrapper( - viewModel: LlmChatViewModel, - modelManagerViewModel: ModelManagerViewModel, - navigateUp: () -> Unit, - modifier: Modifier = Modifier + viewModel: LlmChatViewModel, + modelManagerViewModel: ModelManagerViewModel, + navigateUp: () -> Unit, + navController: NavController, // Added navController parameter + modifier: Modifier = Modifier, + customSystemPromptInput: String, + onCustomSystemPromptChange: (String) -> Unit, + activePersonaName: String? ) { - val context = LocalContext.current + val context = LocalContext.current + val currentConvo by viewModel.currentConversation.collectAsState() + val messagesForUi by viewModel.uiMessages.collectAsState() + val selectedModel = modelManagerViewModel.getSelectedModel(viewModel.task.type) - ChatView( - task = viewModel.task, - viewModel = viewModel, - modelManagerViewModel = modelManagerViewModel, - onSendMessage = { model, messages -> - for (message in messages) { - viewModel.addMessage( - model = model, - message = message, + // Show system prompt input if no conversation has started or if it's a new, empty conversation + val showSystemPromptInput = currentConvo == null || (currentConvo?.id != null && currentConvo!!.messages.isEmpty()) + + + Column(modifier = modifier.fillMaxSize()) { + TopAppBar( + title = { + Text(activePersonaName ?: stringResource(viewModel.task.agentNameRes)) + }, + navigationIcon = { + IconButton(onClick = navigateUp) { + Icon(Icons.Filled.ArrowBack, contentDescription = stringResource(R.string.user_profile_back_button_desc)) // Reused + } + }, + actions = { + IconButton(onClick = { navController.navigate(ConversationHistoryDestination.route) }) { + Icon(Icons.Filled.History, contentDescription = stringResource(R.string.chat_history_button_desc)) + } + } ) - } - var text = "" - var image: Bitmap? = null - var chatMessageText: ChatMessageText? = null - for (message in messages) { - if (message is ChatMessageText) { - chatMessageText = message - text = message.content - } else if (message is ChatMessageImage) { - image = message.bitmap + if (showSystemPromptInput) { + OutlinedTextField( + value = customSystemPromptInput, + onValueChange = onCustomSystemPromptChange, + label = { Text(stringResource(R.string.chat_custom_system_prompt_label)) }, + modifier = Modifier + .fillMaxWidth() + .padding(8.dp), + maxLines = 3 + ) } - } - if (text.isNotEmpty() && chatMessageText != null) { - modelManagerViewModel.addTextInputHistory(text) - viewModel.generateResponse(model = model, input = text, image = image, onError = { - viewModel.handleError( - context = context, - model = model, + + ChatView( + task = viewModel.task, + viewModel = viewModel, modelManagerViewModel = modelManagerViewModel, - triggeredMessage = chatMessageText, - ) - }) - } - }, - onRunAgainClicked = { model, message -> - if (message is ChatMessageText) { - viewModel.runAgain(model = model, message = message, onError = { - viewModel.handleError( - context = context, - model = model, - modelManagerViewModel = modelManagerViewModel, - triggeredMessage = message, - ) - }) - } - }, - onBenchmarkClicked = { _, _, _, _ -> - }, - onResetSessionClicked = { model -> - viewModel.resetSession(model = model) - }, - showStopButtonInInputWhenInProgress = true, - onStopButtonClicked = { model -> - viewModel.stopResponse(model = model) - }, - navigateUp = navigateUp, - modifier = modifier, - ) + messages = messagesForUi, + onSendMessage = { modelFromChatView, userMessages -> + val userInputMessage = userMessages.firstNotNullOfOrNull { it as? ChatMessageText }?.content ?: "" + val imageBitmap = userMessages.firstNotNullOfOrNull { it as? ChatMessageImage }?.bitmap + + if (userInputMessage.isNotBlank() || imageBitmap != null) { + selectedModel?.let { validSelectedModel -> + modelManagerViewModel.addTextInputHistory(userInputMessage) + + if (currentConvo == null || (currentConvo?.id != null && currentConvo!!.messages.isEmpty() && !currentConvo!!.initialSystemPrompt.isNullOrEmpty().not() && customSystemPromptInput.isBlank())) { + viewModel.startNewConversation( + customSystemPrompt = if (customSystemPromptInput.isNotBlank()) customSystemPromptInput else null, + selectedPersonaId = viewModel.activePersona.value?.id, + title = userInputMessage.take(30).ifBlank { stringResource(R.string.chat_new_conversation_title_prefix) }, // Use string resource + selectedModel = validSelectedModel + ) + } + viewModel.generateChatResponse(model = validSelectedModel, input = userInputMessage, image = imageBitmap) + } ?: run { + android.util.Log.e("ChatViewWrapper", "No model selected, cannot send message.") + // Potentially show error to user + } + } + }, + // TODO: Add confirmation dialog for reset session + onResetSessionClicked = { model -> + selectedModel?.let { + onCustomSystemPromptChange("") + viewModel.startNewConversation( + customSystemPrompt = null, + selectedPersonaId = viewModel.activePersona.value?.id, + title = stringResource(R.string.chat_new_conversation_title_prefix), // Use string resource + selectedModel = it + ) + } + }, + showStopButtonInInputWhenInProgress = true, + onStopButtonClicked = { model -> viewModel.stopResponse(model) }, + navigateUp = navigateUp + ) + } } \ No newline at end of file diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt index b99cdc8..9957abf 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt @@ -21,10 +21,16 @@ import android.graphics.Bitmap import android.util.Log import androidx.lifecycle.viewModelScope import com.google.ai.edge.gallery.data.ConfigKey +import com.google.ai.edge.gallery.data.Conversation +import com.google.ai.edge.gallery.data.DataStoreRepository import com.google.ai.edge.gallery.data.Model +import com.google.ai.edge.gallery.data.Persona import com.google.ai.edge.gallery.data.TASK_LLM_CHAT -import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE +// import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE // Removed import com.google.ai.edge.gallery.data.Task +import com.google.ai.edge.gallery.data.UserProfile +import com.google.ai.edge.gallery.data.ChatMessage // Assuming ChatMessage is in data package from PersonalAppModels.kt +import com.google.ai.edge.gallery.data.ChatMessageRole import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult import com.google.ai.edge.gallery.ui.common.chat.ChatMessageLoading import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText @@ -36,7 +42,9 @@ import com.google.ai.edge.gallery.ui.common.chat.Stat import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.* import kotlinx.coroutines.launch +import java.util.UUID private const val TAG = "AGLlmChatViewModel" private val STATS = listOf( @@ -46,203 +54,265 @@ private val STATS = listOf( Stat(id = "latency", label = "Latency", unit = "sec") ) -open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task = curTask) { - fun generateResponse(model: Model, input: String, image: Bitmap? = null, onError: () -> Unit) { - val accelerator = model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = "") - viewModelScope.launch(Dispatchers.Default) { - setInProgress(true) - setPreparing(true) +open class LlmChatViewModel( + private val dataStoreRepository: DataStoreRepository, // Add this + curTask: Task = TASK_LLM_CHAT // Keep if still relevant, or simplify if chat is the only focus +) : ChatViewModel(task = curTask) { // ChatViewModel base class might need review - // Loading. - addMessage( - model = model, - message = ChatMessageLoading(accelerator = accelerator), - ) + private val _currentConversation = MutableStateFlow(null) + val currentConversation: StateFlow = _currentConversation.asStateFlow() - // Wait for instance to be initialized. - while (model.instance == null) { - delay(100) - } - delay(500) + val userProfile: StateFlow = flow { + emit(dataStoreRepository.readUserProfile()) + }.stateIn(viewModelScope, SharingStarted.Eagerly, null) - // Run inference. - val instance = model.instance as LlmModelInstance - var prefillTokens = instance.session.sizeInTokens(input) - if (image != null) { - prefillTokens += 257 - } - - var firstRun = true - var timeToFirstToken = 0f - var firstTokenTs = 0L - var decodeTokens = 0 - var prefillSpeed = 0f - var decodeSpeed: Float - val start = System.currentTimeMillis() - - try { - LlmChatModelHelper.runInference(model = model, - input = input, - image = image, - resultListener = { partialResult, done -> - val curTs = System.currentTimeMillis() - - if (firstRun) { - firstTokenTs = System.currentTimeMillis() - timeToFirstToken = (firstTokenTs - start) / 1000f - prefillSpeed = prefillTokens / timeToFirstToken - firstRun = false - setPreparing(false) - } else { - decodeTokens++ + // Get active persona ID from repo, then fetch the full Persona object + val activePersona: StateFlow = dataStoreRepository.readActivePersonaId().flatMapLatest { activeId -> + if (activeId == null) { + flowOf(null) + } else { + flow { + val personas = dataStoreRepository.readPersonas() + emit(personas.find { it.id == activeId }) } - - // Remove the last message if it is a "loading" message. - // This will only be done once. - val lastMessage = getLastMessage(model = model) - if (lastMessage?.type == ChatMessageType.LOADING) { - removeLastMessage(model = model) - - // Add an empty message that will receive streaming results. - addMessage( - model = model, - message = ChatMessageText( - content = "", - side = ChatSide.AGENT, - accelerator = accelerator - ) - ) - } - - // Incrementally update the streamed partial results. - val latencyMs: Long = if (done) System.currentTimeMillis() - start else -1 - updateLastTextMessageContentIncrementally( - model = model, partialContent = partialResult, latencyMs = latencyMs.toFloat() - ) - - if (done) { - setInProgress(false) - - decodeSpeed = decodeTokens / ((curTs - firstTokenTs) / 1000f) - if (decodeSpeed.isNaN()) { - decodeSpeed = 0f - } - - if (lastMessage is ChatMessageText) { - updateLastTextMessageLlmBenchmarkResult( - model = model, llmBenchmarkResult = ChatMessageBenchmarkLlmResult( - orderedStats = STATS, - statValues = mutableMapOf( - "prefill_speed" to prefillSpeed, - "decode_speed" to decodeSpeed, - "time_to_first_token" to timeToFirstToken, - "latency" to (curTs - start).toFloat() / 1000f, - ), - running = false, - latencyMs = -1f, - accelerator = accelerator, - ) - ) - } - } - }, - cleanUpListener = { - setInProgress(false) - setPreparing(false) - }) - } catch (e: Exception) { - Log.e(TAG, "Error occurred while running inference", e) - setInProgress(false) - setPreparing(false) - onError() - } - } - } - - fun stopResponse(model: Model) { - Log.d(TAG, "Stopping response for model ${model.name}...") - if (getLastMessage(model = model) is ChatMessageLoading) { - removeLastMessage(model = model) - } - viewModelScope.launch(Dispatchers.Default) { - setInProgress(false) - val instance = model.instance as LlmModelInstance - instance.session.cancelGenerateResponseAsync() - } - } - - fun resetSession(model: Model) { - viewModelScope.launch(Dispatchers.Default) { - setIsResettingSession(true) - clearAllMessages(model = model) - stopResponse(model = model) - - while (true) { - try { - LlmChatModelHelper.resetSession(model = model) - break - } catch (e: Exception) { - Log.d(TAG, "Failed to reset session. Trying again") } - delay(200) - } - setIsResettingSession(false) - } - } + }.stateIn(viewModelScope, SharingStarted.Eagerly, null) - fun runAgain(model: Model, message: ChatMessageText, onError: () -> Unit) { - viewModelScope.launch(Dispatchers.Default) { - // Wait for model to be initialized. - while (model.instance == null) { - delay(100) - } + // This replaces messagesByModel for the new conversation-centric approach + private val _uiMessages = MutableStateFlow>(emptyList()) + val uiMessages: StateFlow> = _uiMessages.asStateFlow() - // Clone the clicked message and add it. - addMessage(model = model, message = message.clone()) + // TODO: Review how ChatViewModel's messagesByModel and related methods + // (addMessage, getLastMessage, removeLastMessage, clearAllMessages) are used. + // They might need to be overridden or adapted to work with _currentConversation.messages + // and update _uiMessages. For now, new methods will manage _currentConversation. - // Run inference. - generateResponse( - model = model, input = message.content, onError = onError - ) - } - } + fun startNewConversation(customSystemPrompt: String?, selectedPersonaId: String?, title: String? = null, selectedModel: Model) { + viewModelScope.launch { + val newConversationId = UUID.randomUUID().toString() + val newConversation = Conversation( + id = newConversationId, + title = title, + creationTimestamp = System.currentTimeMillis(), + lastModifiedTimestamp = System.currentTimeMillis(), + initialSystemPrompt = customSystemPrompt, + activePersonaId = selectedPersonaId, + modelIdUsed = selectedModel.name, // Store the model name or a unique ID + messages = mutableListOf() // Start with empty messages + ) + dataStoreRepository.addConversation(newConversation) + _currentConversation.value = newConversation + _uiMessages.value = emptyList() // Clear UI messages for the new chat - fun handleError( - context: Context, - model: Model, - modelManagerViewModel: ModelManagerViewModel, - triggeredMessage: ChatMessageText, - ) { - // Clean up. - modelManagerViewModel.cleanupModel(task = task, model = model) + // Reset LLM session for the selected model + LlmChatModelHelper.resetSession(selectedModel) // Ensure model instance is ready - // Remove the "loading" message. - if (getLastMessage(model = model) is ChatMessageLoading) { - removeLastMessage(model = model) + // Prime with system prompt immediately if starting a new conversation + val systemPromptParts = mutableListOf() + customSystemPrompt?.let { systemPromptParts.add(it) } + activePersona.value?.prompt?.let { systemPromptParts.add(it) } + userProfile.value?.summary?.let { if(it.isNotBlank()) systemPromptParts.add("User Profile Summary: $it") } + // Add other profile details as needed, e.g., skills + + if (systemPromptParts.isNotEmpty()) { + val fullSystemPrompt = systemPromptParts.joinToString("\n\n") + // Ensure model instance is available and ready before priming + if (selectedModel.instance == null) { + Log.e(TAG, "Model instance is null before priming. Initialize model first.") + // Potentially trigger model initialization via ModelManagerViewModel if not already done. + // For now, we assume the model selected for chat is already initialized by ModelManager. + return@launch + } + LlmChatModelHelper.primeSessionWithSystemPrompt(selectedModel, fullSystemPrompt) + } + } } - // Remove the last Text message. - if (getLastMessage(model = model) == triggeredMessage) { - removeLastMessage(model = model) + fun loadConversation(conversationId: String, selectedModel: Model) { + viewModelScope.launch { + val conversation = dataStoreRepository.getConversationById(conversationId) + _currentConversation.value = conversation + if (conversation != null) { + // Convert stored ChatMessage to UI ChatMessage + _uiMessages.value = conversation.messages.map { convertToUiChatMessage(it, selectedModel) } + + // Reset session and re-prime with history + LlmChatModelHelper.resetSession(selectedModel) + + val systemPromptParts = mutableListOf() + conversation.initialSystemPrompt?.let { systemPromptParts.add(it) } + // Need to fetch persona for this conversation + val personas = dataStoreRepository.readPersonas() + val personaForLoadedConv = personas.find { it.id == conversation.activePersonaId } + personaForLoadedConv?.prompt?.let { systemPromptParts.add(it) } + userProfile.value?.summary?.let { if(it.isNotBlank()) systemPromptParts.add("User Profile Summary: $it") } + // Add other profile details + + if (systemPromptParts.isNotEmpty()) { + LlmChatModelHelper.primeSessionWithSystemPrompt(selectedModel, systemPromptParts.joinToString("\n\n")) + } + // Replay message history into the session + conversation.messages.forEach { msg -> + if (msg.role == ChatMessageRole.USER) { + (selectedModel.instance as? LlmModelInstance)?.session?.addQueryChunk(msg.content) + } else if (msg.role == ChatMessageRole.ASSISTANT) { + // If LLM Inference API supports adding assistant messages to context, do it here. + // For now, assume addQueryChunk is for user input primarily to prompt next response. + (selectedModel.instance as? LlmModelInstance)?.session?.addQueryChunk(msg.content) // Or format appropriately + } + } + } else { + _uiMessages.value = emptyList() + } + } } - // Add a warning message for re-initializing the session. - addMessage( - model = model, - message = ChatMessageWarning(content = "Error occurred. Re-initializing the session.") - ) + // Helper to convert your data model ChatMessage to the UI model ChatMessage + private fun convertToUiChatMessage(appMessage: com.google.ai.edge.gallery.data.ChatMessage, model: Model): com.google.ai.edge.gallery.ui.common.chat.ChatMessage { + val side = if (appMessage.role == ChatMessageRole.USER) ChatSide.USER else ChatSide.AGENT + // This is a simplified conversion. You might need more fields. + // The existing ChatMessage in common.chat seems to be an interface/sealed class. + // We need to map to ChatMessageText or other appropriate types. + return ChatMessageText(content = appMessage.content, side = side, accelerator = model.getStringConfigValue(ConfigKey.ACCELERATOR, "")) + } - // Add the triggered message back. - addMessage(model = model, message = triggeredMessage) - // Re-initialize the session/engine. - modelManagerViewModel.initializeModel( - context = context, task = task, model = model - ) + // Override or adapt ChatViewModel's addMessage + override fun addMessage(model: Model, message: com.google.ai.edge.gallery.ui.common.chat.ChatMessage) { + val conversation = _currentConversation.value ?: return + val role = if (message.side == ChatSide.USER) ChatMessageRole.USER else ChatMessageRole.ASSISTANT - // Re-generate the response automatically. - generateResponse(model = model, input = triggeredMessage.content, onError = {}) - } -} + // Only add ChatMessageText for now to conversation history. Loading/Error messages are transient. + if (message is ChatMessageText) { + val appChatMessage = com.google.ai.edge.gallery.data.ChatMessage( + id = UUID.randomUUID().toString(), + conversationId = conversation.id, + timestamp = System.currentTimeMillis(), + role = role, + content = message.content, + personaUsedId = conversation.activePersonaId + ) + val updatedMessages = conversation.messages.toMutableList().apply { add(appChatMessage) } + _currentConversation.value = conversation.copy( + messages = updatedMessages, + lastModifiedTimestamp = System.currentTimeMillis() + ) + // Save updated conversation + viewModelScope.launch { + _currentConversation.value?.let { dataStoreRepository.updateConversation(it) } + } + } + // Update the UI-specific message list + _uiMessages.value = _uiMessages.value.toMutableList().apply { add(message) } + } -class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE) \ No newline at end of file + // Adapt generateResponse + fun generateChatResponse(model: Model, userInput: String, image: Bitmap? = null) { // Renamed to avoid conflict if base still used + val currentConvo = _currentConversation.value + if (currentConvo == null) { + Log.e(TAG, "Cannot generate response, no active conversation.") + // Optionally, start a new conversation implicitly or show an error + return + } + + // Add user's message to conversation and UI + val userUiMessage = ChatMessageText(content = userInput, side = ChatSide.USER, accelerator = model.getStringConfigValue(ConfigKey.ACCELERATOR, "")) + addMessage(model, userUiMessage) // This will also save it to DataStore + + // The rest is similar to original generateResponse, but uses currentConvo + val accelerator = model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = "") + viewModelScope.launch(Dispatchers.Default) { + setInProgress(true) // From base ChatViewModel + setPreparing(true) // From base ChatViewModel + + addMessage(model, ChatMessageLoading(accelerator = accelerator)) // Show loading in UI + + while (model.instance == null) { delay(100) } // Wait for model instance + delay(500) + + val instance = model.instance as LlmModelInstance + // History is now part of the session, primed by startNewConversation or loadConversation. + // LlmChatModelHelper.runInference will just add the latest userInput. + + try { + LlmChatModelHelper.runInference( + model = model, + input = userInput, // Just the new input + image = image, // Handle image if provided + resultListener = { partialResult, done -> + // UI update logic for streaming response - largely same as original + // Ensure to use 'addMessage' or similar to update _uiMessages + // and save assistant's final message to _currentConversation + val lastUiMsg = _uiMessages.value.lastOrNull() + if (lastUiMsg?.type == ChatMessageType.LOADING) { + _uiMessages.value = _uiMessages.value.dropLast(1) + // Add an empty message that will receive streaming results for UI + addMessage(model, ChatMessageText(content = "", side = ChatSide.AGENT, accelerator = accelerator)) + } + + val currentAgentMessage = _uiMessages.value.lastOrNull() as? ChatMessageText + if (currentAgentMessage != null) { + _uiMessages.value = _uiMessages.value.dropLast(1) + currentAgentMessage.copy(content = currentAgentMessage.content + partialResult) + } + + + if (done) { + setInProgress(false) + val finalAssistantContent = (_uiMessages.value.lastOrNull() as? ChatMessageText)?.content ?: "" + // Add final assistant message to DataStore Conversation + val assistantAppMessage = com.google.ai.edge.gallery.data.ChatMessage( + id = UUID.randomUUID().toString(), + conversationId = currentConvo.id, + timestamp = System.currentTimeMillis(), + role = ChatMessageRole.ASSISTANT, + content = finalAssistantContent, + personaUsedId = currentConvo.activePersonaId + ) + val updatedMessages = currentConvo.messages.toMutableList().apply { add(assistantAppMessage) } + _currentConversation.value = currentConvo.copy( + messages = updatedMessages, + lastModifiedTimestamp = System.currentTimeMillis() + ) + viewModelScope.launch { _currentConversation.value?.let { dataStoreRepository.updateConversation(it) } } + // Update benchmark results if necessary (code omitted for brevity, but similar logic as original generateResponse) + val lastMessage = _uiMessages.value.lastOrNull { it.side == ChatSide.AGENT } // get the agent's message + if (lastMessage is ChatMessageText && STATS.isNotEmpty()) { // Assuming STATS is defined + // This part needs to be adapted. The original `generateResponse` calculated these. + // For now, we'll omit direct benchmark updates here to simplify, + // as they depend on variables (timeToFirstToken, prefillSpeed, etc.) + // not directly available in this refactored `generateChatResponse` structure + // without significant further adaptation of the LlmChatModelHelper.runInference callback. + // A simpler approach might be to just mark the message as not running. + val updatedAgentMessage = lastMessage.copy( + llmBenchmarkResult = lastMessage.llmBenchmarkResult?.copy(running = false) + ?: ChatMessageBenchmarkLlmResult(orderedStats = STATS, statValues = mutableMapOf(), running = false, latencyMs = -1f, accelerator = accelerator) + ) + val finalUiMessages = _uiMessages.value.toMutableList() + val agentMsgIndex = finalUiMessages.indexOfLast { it.id == lastMessage.id } + if(agentMsgIndex != -1) { + finalUiMessages[agentMsgIndex] = updatedAgentMessage + _uiMessages.value = finalUiMessages + } + } + } + }, + cleanUpListener = { + setInProgress(false) + setPreparing(false) + } + ) + } catch (e: Exception) { + Log.e(TAG, "Error in generateChatResponse: ", e) + setInProgress(false) + setPreparing(false) + // Add error message to UI + addMessage(model, ChatMessageWarning(content = "Error: ${e.message}")) + } + } + } + + // TODO: Override/adapt clearAllMessages, stopResponse, runAgain, handleError from base ChatViewModel + // to work with the new currentConversation model and _uiMessages. + // For example, clearAllMessages should clear _uiMessages and potentially currentConvo.messages then save. + // resetSession should re-prime with system prompt and history if currentConvo exists. +} \ No newline at end of file diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt deleted file mode 100644 index 94d1c66..0000000 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnScreen.kt +++ /dev/null @@ -1,218 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.ai.edge.gallery.ui.llmsingleturn - -import android.util.Log -import androidx.activity.compose.BackHandler -import androidx.compose.foundation.background -import androidx.compose.foundation.layout.Box -import androidx.compose.foundation.layout.Column -import androidx.compose.foundation.layout.calculateStartPadding -import androidx.compose.foundation.layout.fillMaxSize -import androidx.compose.foundation.layout.padding -import androidx.compose.material3.MaterialTheme -import androidx.compose.material3.Scaffold -import androidx.compose.runtime.Composable -import androidx.compose.runtime.LaunchedEffect -import androidx.compose.runtime.collectAsState -import androidx.compose.runtime.getValue -import androidx.compose.runtime.mutableStateOf -import androidx.compose.runtime.remember -import androidx.compose.runtime.rememberCoroutineScope -import androidx.compose.runtime.setValue -import androidx.compose.ui.Alignment -import androidx.compose.ui.Modifier -import androidx.compose.ui.draw.alpha -import androidx.compose.ui.platform.LocalContext -import androidx.compose.ui.platform.LocalLayoutDirection -import androidx.compose.ui.tooling.preview.Preview -import androidx.lifecycle.viewmodel.compose.viewModel -import com.google.ai.edge.gallery.data.ModelDownloadStatusType -import com.google.ai.edge.gallery.ui.ViewModelProvider -import com.google.ai.edge.gallery.ui.common.ErrorDialog -import com.google.ai.edge.gallery.ui.common.ModelPageAppBar -import com.google.ai.edge.gallery.ui.common.chat.ModelDownloadStatusInfoPanel -import com.google.ai.edge.gallery.ui.modelmanager.ModelInitializationStatusType -import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel -import com.google.ai.edge.gallery.ui.preview.PreviewLlmSingleTurnViewModel -import com.google.ai.edge.gallery.ui.preview.PreviewModelManagerViewModel -import com.google.ai.edge.gallery.ui.theme.GalleryTheme -import com.google.ai.edge.gallery.ui.theme.customColors -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.launch -import kotlinx.serialization.Serializable - -/** Navigation destination data */ -object LlmSingleTurnDestination { - @Serializable - val route = "LlmSingleTurnRoute" -} - -private const val TAG = "AGLlmSingleTurnScreen" - -@Composable -fun LlmSingleTurnScreen( - modelManagerViewModel: ModelManagerViewModel, - navigateUp: () -> Unit, - modifier: Modifier = Modifier, - viewModel: LlmSingleTurnViewModel = viewModel( - factory = ViewModelProvider.Factory - ), -) { - val task = viewModel.task - val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() - val uiState by viewModel.uiState.collectAsState() - val selectedModel = modelManagerUiState.selectedModel - val scope = rememberCoroutineScope() - val context = LocalContext.current - var navigatingUp by remember { mutableStateOf(false) } - var showErrorDialog by remember { mutableStateOf(false) } - - val handleNavigateUp = { - navigatingUp = true - navigateUp() - - // clean up all models. - scope.launch(Dispatchers.Default) { - for (model in task.models) { - modelManagerViewModel.cleanupModel(task = task, model = model) - } - } - } - - // Handle system's edge swipe. - BackHandler { - handleNavigateUp() - } - - // Initialize model when model/download state changes. - val curDownloadStatus = modelManagerUiState.modelDownloadStatus[selectedModel.name] - LaunchedEffect(curDownloadStatus, selectedModel.name) { - if (!navigatingUp) { - if (curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED) { - Log.d( - TAG, - "Initializing model '${selectedModel.name}' from LlmsingleTurnScreen launched effect" - ) - modelManagerViewModel.initializeModel(context, task = task, model = selectedModel) - } - } - } - - val modelInitializationStatus = modelManagerUiState.modelInitializationStatus[selectedModel.name] - LaunchedEffect(modelInitializationStatus) { - showErrorDialog = modelInitializationStatus?.status == ModelInitializationStatusType.ERROR - } - - Scaffold(modifier = modifier, topBar = { - ModelPageAppBar( - task = task, - model = selectedModel, - modelManagerViewModel = modelManagerViewModel, - inProgress = uiState.inProgress, - modelPreparing = uiState.preparing, - onConfigChanged = { _, _ -> }, - onBackClicked = { handleNavigateUp() }, - onModelSelected = { newSelectedModel -> - scope.launch(Dispatchers.Default) { - // Clean up current model. - modelManagerViewModel.cleanupModel(task = task, model = selectedModel) - - // Update selected model. - modelManagerViewModel.selectModel(model = newSelectedModel) - } - } - ) - }) { innerPadding -> - Column( - modifier = Modifier.padding( - top = innerPadding.calculateTopPadding(), - start = innerPadding.calculateStartPadding(LocalLayoutDirection.current), - end = innerPadding.calculateStartPadding(LocalLayoutDirection.current), - ) - ) { - ModelDownloadStatusInfoPanel( - model = selectedModel, - task = task, - modelManagerViewModel = modelManagerViewModel - ) - - // Main UI after model is downloaded. - val modelDownloaded = curDownloadStatus?.status == ModelDownloadStatusType.SUCCEEDED - Box( - contentAlignment = Alignment.BottomCenter, - modifier = Modifier - .weight(1f) - // Just hide the UI without removing it from the screen so that the scroll syncing - // from ResponsePanel still works. - .alpha(if (modelDownloaded) 1.0f else 0.0f) - ) { - VerticalSplitView(modifier = Modifier.fillMaxSize(), - topView = { - PromptTemplatesPanel( - model = selectedModel, - viewModel = viewModel, - modelManagerViewModel = modelManagerViewModel, - onSend = { fullPrompt -> - viewModel.generateResponse(model = selectedModel, input = fullPrompt) - }, - onStopButtonClicked = { model -> - viewModel.stopResponse(model = model) - }, - modifier = Modifier.fillMaxSize() - ) - }, - bottomView = { - Box( - contentAlignment = Alignment.BottomCenter, - modifier = Modifier - .fillMaxSize() - .background(MaterialTheme.customColors.agentBubbleBgColor) - ) { - ResponsePanel( - model = selectedModel, - viewModel = viewModel, - modelManagerViewModel = modelManagerViewModel, - modifier = Modifier - .fillMaxSize() - .padding(bottom = innerPadding.calculateBottomPadding()) - ) - } - }) - } - - if (showErrorDialog) { - ErrorDialog(error = modelInitializationStatus?.error ?: "", onDismiss = { - showErrorDialog = false - }) - } - } - } -} - -@Preview(showBackground = true) -@Composable -fun LlmSingleTurnScreenPreview() { - val context = LocalContext.current - GalleryTheme { - LlmSingleTurnScreen( - modelManagerViewModel = PreviewModelManagerViewModel(context = context), - viewModel = PreviewLlmSingleTurnViewModel(), - navigateUp = {}, - ) - } -} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt deleted file mode 100644 index 64ce612..0000000 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnViewModel.kt +++ /dev/null @@ -1,221 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.ai.edge.gallery.ui.llmsingleturn - -import android.util.Log -import androidx.lifecycle.ViewModel -import androidx.lifecycle.viewModelScope -import com.google.ai.edge.gallery.data.Model -import com.google.ai.edge.gallery.data.TASK_LLM_PROMPT_LAB -import com.google.ai.edge.gallery.data.Task -import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult -import com.google.ai.edge.gallery.ui.common.chat.Stat -import com.google.ai.edge.gallery.ui.common.processLlmResponse -import com.google.ai.edge.gallery.ui.llmchat.LlmChatModelHelper -import com.google.ai.edge.gallery.ui.llmchat.LlmModelInstance -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.delay -import kotlinx.coroutines.flow.MutableStateFlow -import kotlinx.coroutines.flow.asStateFlow -import kotlinx.coroutines.flow.update -import kotlinx.coroutines.launch - -private const val TAG = "AGLlmSingleTurnViewModel" - -data class LlmSingleTurnUiState( - /** - * Indicates whether the runtime is currently processing a message. - */ - val inProgress: Boolean = false, - - /** - * Indicates whether the model is preparing (before outputting any result and after initializing). - */ - val preparing: Boolean = false, - - // model ->