diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/WebSearchService.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/WebSearchService.kt new file mode 100644 index 0000000..c8c8ed1 --- /dev/null +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/WebSearchService.kt @@ -0,0 +1,100 @@ +package com.google.ai.edge.gallery.data + +import android.util.Log +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext +import okhttp3.MediaType.Companion.toMediaTypeOrNull +import okhttp3.OkHttpClient +import okhttp3.Request +import okhttp3.RequestBody.Companion.toRequestBody +import org.json.JSONObject +import java.io.IOException + +data class TavilySearchResult( + val title: String, + val url: String, + val content: String, + val score: Double +) + +data class TavilySearchResponse( + val answer: String?, + val query: String?, + val results: List? +) + +class WebSearchService { + + private val client = OkHttpClient() + + suspend fun search(apiKey: String, query: String): TavilySearchResponse? { + return withContext(Dispatchers.IO) { + try { + val jsonRequestBody = JSONObject().apply { + put("api_key", apiKey) + put("query", query) + put("search_depth", "basic") + put("include_answer", true) + put("max_results", 3) + // include_domains and exclude_domains are empty by default + }.toString() + + val request = Request.Builder() + .url("https://api.tavily.com/search") + .header("Authorization", "Bearer $apiKey") + .header("Content-Type", "application/json") + .post(jsonRequestBody.toRequestBody("application/json; charset=utf-8".toMediaTypeOrNull())) + .build() + + client.newCall(request).execute().use { response -> + if (!response.isSuccessful) { + Log.e("WebSearchService", "API Error: ${response.code} ${response.message}") + return@withContext null + } + + val responseBody = response.body?.string() + if (responseBody == null) { + Log.e("WebSearchService", "Empty response body") + return@withContext null + } + + parseTavilyResponse(responseBody) + } + } catch (e: IOException) { + Log.e("WebSearchService", "Network Error: ${e.message}", e) + null + } catch (e: Exception) { + Log.e("WebSearchService", "Error during search: ${e.message}", e) + null + } + } + } + + private fun parseTavilyResponse(responseBody: String): TavilySearchResponse? { + return try { + val jsonObject = JSONObject(responseBody) + val answer = jsonObject.optString("answer", null) + val query = jsonObject.optString("query", null) + + val resultsArray = jsonObject.optJSONArray("results") + val searchResults = mutableListOf() + if (resultsArray != null) { + for (i in 0 until resultsArray.length()) { + val resultObj = resultsArray.getJSONObject(i) + searchResults.add( + TavilySearchResult( + title = resultObj.getString("title"), + url = resultObj.getString("url"), + content = resultObj.getString("content"), + score = resultObj.getDouble("score") + ) + ) + } + } + TavilySearchResponse(answer, query, if (searchResults.isEmpty()) null else searchResults) + } catch (e: Exception) { + Log.e("WebSearchService", "Error parsing JSON response: ${e.message}", e) + null + } + } +} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt index 8a257dd..2ac788c 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt @@ -22,6 +22,7 @@ import androidx.lifecycle.viewmodel.CreationExtras import androidx.lifecycle.viewmodel.initializer import androidx.lifecycle.viewmodel.viewModelFactory import com.google.ai.edge.gallery.GalleryApplication +import com.google.ai.edge.gallery.data.WebSearchService import com.google.ai.edge.gallery.ui.imageclassification.ImageClassificationViewModel import com.google.ai.edge.gallery.ui.imagegeneration.ImageGenerationViewModel import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel @@ -32,6 +33,10 @@ import com.google.ai.edge.gallery.ui.textclassification.TextClassificationViewMo object ViewModelProvider { val Factory = viewModelFactory { + // Create an instance of WebSearchService + // This instance will be shared by ViewModels that need it. + val webSearchService = WebSearchService() + // Initializer for ModelManagerViewModel. initializer { val downloadRepository = galleryApplication().container.downloadRepository @@ -55,17 +60,21 @@ object ViewModelProvider { // Initializer for LlmChatViewModel. initializer { - LlmChatViewModel() + // Pass the WebSearchService instance + LlmChatViewModel(webSearchService = webSearchService) } - // Initializer for LlmSingleTurnViewModel.. + // Initializer for LlmSingleTurnViewModel. + // Note: LlmSingleTurnViewModel's constructor was not modified in previous steps. + // If it also needs WebSearchService in the future, its initializer and constructor would need similar changes. initializer { LlmSingleTurnViewModel() } // Initializer for LlmAskImageViewModel. initializer { - LlmAskImageViewModel() + // Pass the WebSearchService instance + LlmAskImageViewModel(webSearchService = webSearchService) } // Initializer for ImageGenerationViewModel. diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt index b99cdc8..e951f4b 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt @@ -25,6 +25,7 @@ import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.gallery.data.TASK_LLM_CHAT import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE import com.google.ai.edge.gallery.data.Task +import com.google.ai.edge.gallery.data.WebSearchService import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult import com.google.ai.edge.gallery.ui.common.chat.ChatMessageLoading import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText @@ -46,10 +47,71 @@ private val STATS = listOf( Stat(id = "latency", label = "Latency", unit = "sec") ) -open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task = curTask) { +open class LlmChatViewModel( + curTask: Task = TASK_LLM_CHAT, + private val webSearchService: WebSearchService +) : ChatViewModel(task = curTask) { fun generateResponse(model: Model, input: String, image: Bitmap? = null, onError: () -> Unit) { val accelerator = model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = "") viewModelScope.launch(Dispatchers.Default) { + // Web Search Logic + var augmentedInput = input + var searchPerformed = false + var searchSuccessful = false + var searchErrorOccurred = false + + // Add search in-progress indicator + val searchIndicatorMessage = ChatMessageLoading( + text = "正在為您搜索網路獲取最新資訊...", + accelerator = accelerator, + side = ChatSide.AGENT + ) + addMessage(model = model, message = searchIndicatorMessage) + + try { + val tavilyResponse = webSearchService.search(apiKey = "YOUR_TAVILY_API_KEY_PLACEHOLDER", query = input) + searchPerformed = true + + if (tavilyResponse != null) { + searchSuccessful = true + val searchAnswer = tavilyResponse.answer + val searchResults = tavilyResponse.results + + if (!searchAnswer.isNullOrBlank()) { + augmentedInput = "Based on web search results, answer the following: \"${searchAnswer}\". The original question was: \"${input}\"" + } else if (!searchResults.isNullOrEmpty()) { + val snippets = searchResults.take(2).joinToString(separator = "; ") { it.content } + if (snippets.isNotBlank()) { + augmentedInput = "Based on web search results, here are some relevant snippets: \"${snippets}\". The original question was: \"${input}\"" + } + } + } else { + searchErrorOccurred = true + } + } catch (e: Exception) { + Log.e(TAG, "Web search call failed", e) + searchErrorOccurred = true + } + + // Remove search in-progress indicator + val lastMessage = getLastMessage(model = model) + if (lastMessage == searchIndicatorMessage) { + removeLastMessage(model = model) + } + + // Add search result status messages + if (searchErrorOccurred) { + addMessage( + model = model, + message = ChatMessageWarning(content = "網路搜索失敗,將嘗試使用模型知識回答。") + ) + } else if (searchPerformed && !searchSuccessful) { + addMessage( + model = model, + message = ChatMessageWarning(content = "網路搜索未能找到相關資訊,將嘗試使用模型知識回答。") + ) + } + setInProgress(true) setPreparing(true) @@ -67,9 +129,11 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task // Run inference. val instance = model.instance as LlmModelInstance - var prefillTokens = instance.session.sizeInTokens(input) + var prefillTokens = instance.session.sizeInTokens(augmentedInput) if (image != null) { - prefillTokens += 257 + // Assuming image context is added separately and not part of the text prompt for token calculation here. + // If image contributes to text prompt for LLM, this might need adjustment or be handled by the model instance. + prefillTokens += 257 // This is a magic number, ensure it's correct for multimodal inputs. } var firstRun = true @@ -82,7 +146,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task try { LlmChatModelHelper.runInference(model = model, - input = input, + input = augmentedInput, // Use augmentedInput here image = image, resultListener = { partialResult, done -> val curTs = System.currentTimeMillis() @@ -241,8 +305,10 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task ) // Re-generate the response automatically. + // The original triggeredMessage.content will go through the search logic again. generateResponse(model = model, input = triggeredMessage.content, onError = {}) } } -class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE) \ No newline at end of file +class LlmAskImageViewModel(webSearchService: WebSearchService) : + LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE, webSearchService = webSearchService) \ No newline at end of file