Mirror of https://github.com/google-ai-edge/gallery.git, synced 2025-07-16 11:16:43 -04:00
I've integrated web search into the LLM Chat. Here's what I did:
- I added a WebSearchService to call the Tavily API.
- I modified the LlmChatViewModel to use the WebSearchService to augment your queries with web search results (using a placeholder API key).
- I added UI feedback for web search status (loading, errors, no results).
- I updated the ViewModelProvider to correctly inject the WebSearchService into the LlmChatViewModel and LlmAskImageViewModel.
This commit is contained in: parent ebb605131d, commit 2ed268e5ce
3 changed files with 183 additions and 8 deletions
@@ -0,0 +1,100 @@
package com.google.ai.edge.gallery.data

import android.util.Log
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import okhttp3.MediaType.Companion.toMediaTypeOrNull
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.RequestBody.Companion.toRequestBody
import org.json.JSONObject
import java.io.IOException

data class TavilySearchResult(
  val title: String,
  val url: String,
  val content: String,
  val score: Double
)

data class TavilySearchResponse(
  val answer: String?,
  val query: String?,
  val results: List<TavilySearchResult>?
)

class WebSearchService {

  private val client = OkHttpClient()

  suspend fun search(apiKey: String, query: String): TavilySearchResponse? {
    return withContext(Dispatchers.IO) {
      try {
        val jsonRequestBody = JSONObject().apply {
          put("api_key", apiKey)
          put("query", query)
          put("search_depth", "basic")
          put("include_answer", true)
          put("max_results", 3)
          // include_domains and exclude_domains are empty by default
        }.toString()

        val request = Request.Builder()
          .url("https://api.tavily.com/search")
          .header("Authorization", "Bearer $apiKey")
          .header("Content-Type", "application/json")
          .post(jsonRequestBody.toRequestBody("application/json; charset=utf-8".toMediaTypeOrNull()))
          .build()

        client.newCall(request).execute().use { response ->
          if (!response.isSuccessful) {
            Log.e("WebSearchService", "API Error: ${response.code} ${response.message}")
            return@withContext null
          }

          val responseBody = response.body?.string()
          if (responseBody == null) {
            Log.e("WebSearchService", "Empty response body")
            return@withContext null
          }

          parseTavilyResponse(responseBody)
        }
      } catch (e: IOException) {
        Log.e("WebSearchService", "Network Error: ${e.message}", e)
        null
      } catch (e: Exception) {
        Log.e("WebSearchService", "Error during search: ${e.message}", e)
        null
      }
    }
  }

  private fun parseTavilyResponse(responseBody: String): TavilySearchResponse? {
    return try {
      val jsonObject = JSONObject(responseBody)
      val answer = jsonObject.optString("answer", null)
      val query = jsonObject.optString("query", null)

      val resultsArray = jsonObject.optJSONArray("results")
      val searchResults = mutableListOf<TavilySearchResult>()
      if (resultsArray != null) {
        for (i in 0 until resultsArray.length()) {
          val resultObj = resultsArray.getJSONObject(i)
          searchResults.add(
            TavilySearchResult(
              title = resultObj.getString("title"),
              url = resultObj.getString("url"),
              content = resultObj.getString("content"),
              score = resultObj.getDouble("score")
            )
          )
        }
      }
      TavilySearchResponse(answer, query, if (searchResults.isEmpty()) null else searchResults)
    } catch (e: Exception) {
      Log.e("WebSearchService", "Error parsing JSON response: ${e.message}", e)
      null
    }
  }
}
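For context, a minimal usage sketch of the service above, assuming it is called from a coroutine; the API key and query shown here are placeholders, not values from this commit:

import android.util.Log
import com.google.ai.edge.gallery.data.WebSearchService

suspend fun demoSearch(service: WebSearchService) {
  // Placeholder API key and query, for illustration only.
  val response = service.search(apiKey = "YOUR_TAVILY_API_KEY_PLACEHOLDER", query = "latest Android release")
  if (response == null) {
    Log.d("WebSearchDemo", "Search failed or returned an unusable response.")
    return
  }
  Log.d("WebSearchDemo", "Answer: ${response.answer}")
  response.results?.forEach { result ->
    Log.d("WebSearchDemo", "${result.title} (${result.url}) score=${result.score}")
  }
}

Because search() already switches to Dispatchers.IO internally, callers can invoke it from any coroutine context.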
@@ -22,6 +22,7 @@ import androidx.lifecycle.viewmodel.CreationExtras
import androidx.lifecycle.viewmodel.initializer
import androidx.lifecycle.viewmodel.viewModelFactory
import com.google.ai.edge.gallery.GalleryApplication
+import com.google.ai.edge.gallery.data.WebSearchService
import com.google.ai.edge.gallery.ui.imageclassification.ImageClassificationViewModel
import com.google.ai.edge.gallery.ui.imagegeneration.ImageGenerationViewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel

@@ -32,6 +33,10 @@ import com.google.ai.edge.gallery.ui.textclassification.TextClassificationViewMo

object ViewModelProvider {
  val Factory = viewModelFactory {
+    // Create an instance of WebSearchService
+    // This instance will be shared by ViewModels that need it.
+    val webSearchService = WebSearchService()
+
    // Initializer for ModelManagerViewModel.
    initializer {
      val downloadRepository = galleryApplication().container.downloadRepository

@@ -55,17 +60,21 @@ object ViewModelProvider {

    // Initializer for LlmChatViewModel.
    initializer {
-      LlmChatViewModel()
+      // Pass the WebSearchService instance
+      LlmChatViewModel(webSearchService = webSearchService)
    }

-    // Initializer for LlmSingleTurnViewModel..
+    // Initializer for LlmSingleTurnViewModel.
+    // Note: LlmSingleTurnViewModel's constructor was not modified in previous steps.
+    // If it also needs WebSearchService in the future, its initializer and constructor would need similar changes.
    initializer {
      LlmSingleTurnViewModel()
    }

    // Initializer for LlmAskImageViewModel.
    initializer {
-      LlmAskImageViewModel()
+      // Pass the WebSearchService instance
+      LlmAskImageViewModel(webSearchService = webSearchService)
    }

    // Initializer for ImageGenerationViewModel.
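For reference, a minimal sketch of how a Compose screen might obtain the chat ViewModel through this factory; the composable name is hypothetical, and it assumes the app's ViewModelProvider object defined above is imported or in the same package:

import androidx.compose.runtime.Composable
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel

@Composable
fun LlmChatRoute() {
  // The factory injects the shared WebSearchService into the ViewModel's constructor.
  val chatViewModel: LlmChatViewModel = viewModel(factory = ViewModelProvider.Factory)
  // ... render the chat screen with chatViewModel ...
}

Creating the WebSearchService once at the top of viewModelFactory means LlmChatViewModel and LlmAskImageViewModel share a single OkHttpClient rather than each building its own connection pool.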
@@ -25,6 +25,7 @@ import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.ai.edge.gallery.data.Task
+import com.google.ai.edge.gallery.data.WebSearchService
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageLoading
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText

@@ -46,10 +47,71 @@ private val STATS = listOf(
  Stat(id = "latency", label = "Latency", unit = "sec")
)

-open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task = curTask) {
+open class LlmChatViewModel(
+  curTask: Task = TASK_LLM_CHAT,
+  private val webSearchService: WebSearchService
+) : ChatViewModel(task = curTask) {
  fun generateResponse(model: Model, input: String, image: Bitmap? = null, onError: () -> Unit) {
    val accelerator = model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = "")
    viewModelScope.launch(Dispatchers.Default) {
      // Web Search Logic
      var augmentedInput = input
      var searchPerformed = false
      var searchSuccessful = false
      var searchErrorOccurred = false

      // Add search in-progress indicator
      val searchIndicatorMessage = ChatMessageLoading(
        text = "正在為您搜索網路獲取最新資訊...", // "Searching the web for the latest information for you..."
        accelerator = accelerator,
        side = ChatSide.AGENT
      )
      addMessage(model = model, message = searchIndicatorMessage)

      try {
        val tavilyResponse = webSearchService.search(apiKey = "YOUR_TAVILY_API_KEY_PLACEHOLDER", query = input)
        searchPerformed = true

        if (tavilyResponse != null) {
          searchSuccessful = true
          val searchAnswer = tavilyResponse.answer
          val searchResults = tavilyResponse.results

          if (!searchAnswer.isNullOrBlank()) {
            augmentedInput = "Based on web search results, answer the following: \"${searchAnswer}\". The original question was: \"${input}\""
          } else if (!searchResults.isNullOrEmpty()) {
            val snippets = searchResults.take(2).joinToString(separator = "; ") { it.content }
            if (snippets.isNotBlank()) {
              augmentedInput = "Based on web search results, here are some relevant snippets: \"${snippets}\". The original question was: \"${input}\""
            }
          }
        } else {
          searchErrorOccurred = true
        }
      } catch (e: Exception) {
        Log.e(TAG, "Web search call failed", e)
        searchErrorOccurred = true
      }

      // Remove search in-progress indicator
      val lastMessage = getLastMessage(model = model)
      if (lastMessage == searchIndicatorMessage) {
        removeLastMessage(model = model)
      }

      // Add search result status messages
      if (searchErrorOccurred) {
        addMessage(
          model = model,
          message = ChatMessageWarning(content = "網路搜索失敗,將嘗試使用模型知識回答。") // "Web search failed; will try to answer using the model's own knowledge."
        )
      } else if (searchPerformed && !searchSuccessful) {
        addMessage(
          model = model,
          message = ChatMessageWarning(content = "網路搜索未能找到相關資訊,將嘗試使用模型知識回答。") // "Web search found no relevant information; will try to answer using the model's own knowledge."
        )
      }

      setInProgress(true)
      setPreparing(true)

@@ -67,9 +129,11 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task

      // Run inference.
      val instance = model.instance as LlmModelInstance
-      var prefillTokens = instance.session.sizeInTokens(input)
+      var prefillTokens = instance.session.sizeInTokens(augmentedInput)
      if (image != null) {
-        prefillTokens += 257
+        // Assuming image context is added separately and not part of the text prompt for token calculation here.
+        // If image contributes to text prompt for LLM, this might need adjustment or be handled by the model instance.
+        prefillTokens += 257 // This is a magic number, ensure it's correct for multimodal inputs.
      }

      var firstRun = true

@@ -82,7 +146,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task

      try {
        LlmChatModelHelper.runInference(model = model,
-          input = input,
+          input = augmentedInput, // Use augmentedInput here
          image = image,
          resultListener = { partialResult, done ->
            val curTs = System.currentTimeMillis()

@@ -241,8 +305,10 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
      )

      // Re-generate the response automatically.
+      // The original triggeredMessage.content will go through the search logic again.
      generateResponse(model = model, input = triggeredMessage.content, onError = {})
    }
  }

-class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE)
+class LlmAskImageViewModel(webSearchService: WebSearchService) :
+  LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE, webSearchService = webSearchService)
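As a quick illustration of the augmentation step above, here is the shape of the prompt the answer branch produces; the query and answer values are made up for this example:

// Hypothetical values, not taken from the commit.
val input = "What is the capital of Australia?"
val searchAnswer = "Canberra is the capital city of Australia."

// Mirrors the branch taken when the Tavily answer field is non-blank.
val augmentedInput =
  "Based on web search results, answer the following: \"$searchAnswer\". The original question was: \"$input\""
// augmentedInput ==
//   Based on web search results, answer the following: "Canberra is the capital city of Australia.". The original question was: "What is the capital of Australia?"

Since LlmAskImageViewModel now simply forwards the injected WebSearchService to LlmChatViewModel, image questions go through the same search-then-augment flow before inference.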