I've integrated web search into the LLM Chat. Here's what I did:

- I added a WebSearchService to call the Tavily API.
- I modified the LlmChatViewModel to use the WebSearchService to augment user queries with web search results (using a placeholder API key; see the key-wiring sketch after this list).
- I added UI feedback for web search status (loading, errors, no results).
- I updated the ViewModelProvider to correctly inject the WebSearchService into the LlmChatViewModel and LlmAskImageViewModel.
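The search call only works once a real Tavily key replaces the placeholder. A minimal sketch of one way to wire a key in without hard-coding it, assuming a hypothetical BuildConfig field generated from a Gradle property (the field and property names are illustrative, not part of this commit):

// In app/build.gradle.kts (hypothetical):
// buildConfigField("String", "TAVILY_API_KEY", "\"${project.findProperty("tavilyApiKey") ?: ""}\"")

// At the call site in LlmChatViewModel, replacing the hard-coded placeholder:
val tavilyResponse = webSearchService.search(
  apiKey = BuildConfig.TAVILY_API_KEY,
  query = input
)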
google-labs-jules[bot] 2025-05-25 06:18:26 +00:00
parent ebb605131d
commit 2ed268e5ce
3 changed files with 183 additions and 8 deletions

WebSearchService.kt (new file)

@@ -0,0 +1,100 @@
package com.google.ai.edge.gallery.data

import android.util.Log
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import okhttp3.MediaType.Companion.toMediaTypeOrNull
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.RequestBody.Companion.toRequestBody
import org.json.JSONObject
import java.io.IOException

data class TavilySearchResult(
  val title: String,
  val url: String,
  val content: String,
  val score: Double
)

data class TavilySearchResponse(
  val answer: String?,
  val query: String?,
  val results: List<TavilySearchResult>?
)

class WebSearchService {
  private val client = OkHttpClient()

  suspend fun search(apiKey: String, query: String): TavilySearchResponse? {
    return withContext(Dispatchers.IO) {
      try {
        val jsonRequestBody = JSONObject().apply {
          put("api_key", apiKey)
          put("query", query)
          put("search_depth", "basic")
          put("include_answer", true)
          put("max_results", 3)
          // include_domains and exclude_domains are empty by default
        }.toString()
        val request = Request.Builder()
          .url("https://api.tavily.com/search")
          .header("Authorization", "Bearer $apiKey")
          .header("Content-Type", "application/json")
          .post(jsonRequestBody.toRequestBody("application/json; charset=utf-8".toMediaTypeOrNull()))
          .build()
        client.newCall(request).execute().use { response ->
          if (!response.isSuccessful) {
            Log.e("WebSearchService", "API Error: ${response.code} ${response.message}")
            return@withContext null
          }
          val responseBody = response.body?.string()
          if (responseBody == null) {
            Log.e("WebSearchService", "Empty response body")
            return@withContext null
          }
          parseTavilyResponse(responseBody)
        }
      } catch (e: IOException) {
        Log.e("WebSearchService", "Network Error: ${e.message}", e)
        null
      } catch (e: Exception) {
        Log.e("WebSearchService", "Error during search: ${e.message}", e)
        null
      }
    }
  }

  private fun parseTavilyResponse(responseBody: String): TavilySearchResponse? {
    return try {
      val jsonObject = JSONObject(responseBody)
      val answer = jsonObject.optString("answer", null)
      val query = jsonObject.optString("query", null)
      val resultsArray = jsonObject.optJSONArray("results")
      val searchResults = mutableListOf<TavilySearchResult>()
      if (resultsArray != null) {
        for (i in 0 until resultsArray.length()) {
          val resultObj = resultsArray.getJSONObject(i)
          searchResults.add(
            TavilySearchResult(
              title = resultObj.getString("title"),
              url = resultObj.getString("url"),
              content = resultObj.getString("content"),
              score = resultObj.getDouble("score")
            )
          )
        }
      }
      TavilySearchResponse(answer, query, if (searchResults.isEmpty()) null else searchResults)
    } catch (e: Exception) {
      Log.e("WebSearchService", "Error parsing JSON response: ${e.message}", e)
      null
    }
  }
}
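For reference, a minimal usage sketch, assuming a caller inside a coroutine; the function name and log tag are hypothetical, and the key placeholder matches the one used in LlmChatViewModel below:

import android.util.Log
import com.google.ai.edge.gallery.data.WebSearchService

// Hypothetical call site: run a search and log a usable snippet.
suspend fun demoSearch(service: WebSearchService) {
  val response = service.search(apiKey = "YOUR_TAVILY_API_KEY_PLACEHOLDER", query = "latest Android release")
  // Prefer Tavily's synthesized answer; otherwise fall back to the top result's content.
  val snippet = response?.answer ?: response?.results?.firstOrNull()?.content
  Log.d("WebSearchDemo", snippet ?: "no results")
}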

ViewModelProvider.kt

@@ -22,6 +22,7 @@ import androidx.lifecycle.viewmodel.CreationExtras
import androidx.lifecycle.viewmodel.initializer
import androidx.lifecycle.viewmodel.viewModelFactory
import com.google.ai.edge.gallery.GalleryApplication
+import com.google.ai.edge.gallery.data.WebSearchService
import com.google.ai.edge.gallery.ui.imageclassification.ImageClassificationViewModel
import com.google.ai.edge.gallery.ui.imagegeneration.ImageGenerationViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel
@@ -32,6 +33,10 @@ import com.google.ai.edge.gallery.ui.textclassification.TextClassificationViewModel
object ViewModelProvider {
  val Factory = viewModelFactory {
+    // Create an instance of WebSearchService.
+    // This instance will be shared by ViewModels that need it.
+    val webSearchService = WebSearchService()
    // Initializer for ModelManagerViewModel.
    initializer {
      val downloadRepository = galleryApplication().container.downloadRepository
@@ -55,17 +60,21 @@ object ViewModelProvider {
    // Initializer for LlmChatViewModel.
    initializer {
-      LlmChatViewModel()
+      // Pass the WebSearchService instance.
+      LlmChatViewModel(webSearchService = webSearchService)
    }
-    // Initializer for LlmSingleTurnViewModel..
+    // Initializer for LlmSingleTurnViewModel.
+    // Note: LlmSingleTurnViewModel's constructor was not modified in previous steps.
+    // If it also needs WebSearchService in the future, its initializer and constructor would need similar changes.
    initializer {
      LlmSingleTurnViewModel()
    }
    // Initializer for LlmAskImageViewModel.
    initializer {
-      LlmAskImageViewModel()
+      // Pass the WebSearchService instance.
+      LlmAskImageViewModel(webSearchService = webSearchService)
    }
    // Initializer for ImageGenerationViewModel.
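A sketch of how a screen obtains a ViewModel through this factory so the shared WebSearchService is injected; the composable is hypothetical, while viewModel(factory = ...) is the standard lifecycle-viewmodel-compose entry point:

import androidx.compose.runtime.Composable
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel

@Composable
fun LlmChatRoute() {
  // ViewModelProvider.Factory (defined above) supplies the shared WebSearchService.
  val chatViewModel: LlmChatViewModel = viewModel(factory = ViewModelProvider.Factory)
  // ... render the chat screen with chatViewModel ...
}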

LlmChatViewModel.kt

@@ -25,6 +25,7 @@ import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.ai.edge.gallery.data.Task
+import com.google.ai.edge.gallery.data.WebSearchService
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageLoading
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
@@ -46,10 +47,71 @@ private val STATS = listOf(
  Stat(id = "latency", label = "Latency", unit = "sec")
)
-open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task = curTask) {
+open class LlmChatViewModel(
+  curTask: Task = TASK_LLM_CHAT,
+  private val webSearchService: WebSearchService
+) : ChatViewModel(task = curTask) {
  fun generateResponse(model: Model, input: String, image: Bitmap? = null, onError: () -> Unit) {
    val accelerator = model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = "")
    viewModelScope.launch(Dispatchers.Default) {
+      // Web search logic.
+      var augmentedInput = input
+      var searchPerformed = false
+      var searchSuccessful = false
+      var searchErrorOccurred = false
+      // Add search in-progress indicator.
+      val searchIndicatorMessage = ChatMessageLoading(
+        text = "Searching the web for the latest information...",
+        accelerator = accelerator,
+        side = ChatSide.AGENT
+      )
+      addMessage(model = model, message = searchIndicatorMessage)
+      try {
+        val tavilyResponse = webSearchService.search(apiKey = "YOUR_TAVILY_API_KEY_PLACEHOLDER", query = input)
+        searchPerformed = true
+        if (tavilyResponse != null) {
+          searchSuccessful = true
+          val searchAnswer = tavilyResponse.answer
+          val searchResults = tavilyResponse.results
+          if (!searchAnswer.isNullOrBlank()) {
+            augmentedInput = "Based on web search results, answer the following: \"${searchAnswer}\". The original question was: \"${input}\""
+          } else if (!searchResults.isNullOrEmpty()) {
+            val snippets = searchResults.take(2).joinToString(separator = "; ") { it.content }
+            if (snippets.isNotBlank()) {
+              augmentedInput = "Based on web search results, here are some relevant snippets: \"${snippets}\". The original question was: \"${input}\""
+            }
+          }
+        } else {
+          searchErrorOccurred = true
+        }
+      } catch (e: Exception) {
+        Log.e(TAG, "Web search call failed", e)
+        searchErrorOccurred = true
+      }
+      // Remove search in-progress indicator.
+      val lastMessage = getLastMessage(model = model)
+      if (lastMessage == searchIndicatorMessage) {
+        removeLastMessage(model = model)
+      }
+      // Add search result status messages.
+      if (searchErrorOccurred) {
+        addMessage(
+          model = model,
+          message = ChatMessageWarning(content = "Web search failed; falling back to the model's built-in knowledge.")
+        )
+      } else if (searchPerformed && !searchSuccessful) {
+        addMessage(
+          model = model,
+          message = ChatMessageWarning(content = "Web search found no relevant information; falling back to the model's built-in knowledge.")
+        )
+      }
      setInProgress(true)
      setPreparing(true)
@@ -67,9 +129,11 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task = curTask) {
      // Run inference.
      val instance = model.instance as LlmModelInstance
-      var prefillTokens = instance.session.sizeInTokens(input)
+      var prefillTokens = instance.session.sizeInTokens(augmentedInput)
      if (image != null) {
-        prefillTokens += 257
+        // Assuming image context is added separately and not part of the text prompt for token calculation here.
+        // If the image contributes to the text prompt for the LLM, this might need adjustment or be handled by the model instance.
+        prefillTokens += 257 // This is a magic number; ensure it's correct for multimodal inputs.
      }
      var firstRun = true
@@ -82,7 +146,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task = curTask) {
      try {
        LlmChatModelHelper.runInference(model = model,
-          input = input,
+          input = augmentedInput, // Use augmentedInput here.
          image = image,
          resultListener = { partialResult, done ->
            val curTs = System.currentTimeMillis()
@@ -241,8 +305,10 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task = curTask) {
      )
      // Re-generate the response automatically.
+      // The original triggeredMessage.content will go through the search logic again.
      generateResponse(model = model, input = triggeredMessage.content, onError = {})
    }
  }

-class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE)
+class LlmAskImageViewModel(webSearchService: WebSearchService) :
+  LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE, webSearchService = webSearchService)
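Because both ViewModels now take the service through their constructors, constructing one directly, for example in a quick test, is straightforward (a sketch; a real test would likely substitute a fake WebSearchService):

val askImageViewModel = LlmAskImageViewModel(webSearchService = WebSearchService())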