mirror of
https://github.com/google-ai-edge/gallery.git
synced 2025-07-17 11:46:39 -04:00
I've integrated web search into the LLM Chat. Here's what I did:
- I added a WebSearchService to call the Tavily API.
- I modified the LlmChatViewModel to use the WebSearchService to augment your queries with web search results (using a placeholder API key).
- I added UI feedback for web search status (loading, errors, no results).
- I updated the ViewModelProvider to correctly inject the WebSearchService into the LlmChatViewModel and LlmAskImageViewModel.
This commit is contained in:
parent
ebb605131d
commit
2ed268e5ce
3 changed files with 183 additions and 8 deletions
|
@ -0,0 +1,100 @@
|
||||||
|
package com.google.ai.edge.gallery.data
|
||||||
|
|
||||||
|
import android.util.Log
|
||||||
|
import kotlinx.coroutines.Dispatchers
|
||||||
|
import kotlinx.coroutines.withContext
|
||||||
|
import okhttp3.MediaType.Companion.toMediaTypeOrNull
|
||||||
|
import okhttp3.OkHttpClient
|
||||||
|
import okhttp3.Request
|
||||||
|
import okhttp3.RequestBody.Companion.toRequestBody
|
||||||
|
import org.json.JSONObject
|
||||||
|
import java.io.IOException
|
||||||
|
|
||||||
|
/**
 * One search hit returned by the Tavily search API.
 *
 * @property title   Title of the matched web page.
 * @property url     URL of the matched web page.
 * @property content Text snippet extracted from the page by Tavily.
 * @property score   Relevance score assigned by Tavily (higher is more relevant).
 */
data class TavilySearchResult(
  val title: String,
  val url: String,
  val content: String,
  val score: Double,
)
|
||||||
|
|
||||||
|
/**
 * Parsed top-level response from the Tavily search API.
 *
 * All fields are nullable: Tavily may omit the synthesized [answer], the
 * echoed [query], or return no [results] at all.
 */
data class TavilySearchResponse(
  val answer: String?,
  val query: String?,
  val results: List<TavilySearchResult>?,
)
|
||||||
|
|
||||||
|
/**
 * Minimal client for the Tavily web search API (POST https://api.tavily.com/search).
 *
 * Ordinary failures (HTTP errors, network errors, malformed JSON) are logged
 * and reported to callers as a null return value — this service never throws
 * on such failures. Coroutine cancellation, however, is always propagated.
 */
class WebSearchService {

  private val client = OkHttpClient()

  /**
   * Searches the web for [query] via Tavily, on [Dispatchers.IO].
   *
   * The API key is sent both in the JSON body ("api_key") and as a Bearer
   * token, matching Tavily's accepted auth styles.
   *
   * @param apiKey Tavily API key.
   * @param query free-text search query.
   * @return the parsed response, or null on HTTP error, empty body, network
   *   failure, or unparseable JSON.
   */
  suspend fun search(apiKey: String, query: String): TavilySearchResponse? {
    return withContext(Dispatchers.IO) {
      try {
        val jsonRequestBody = JSONObject().apply {
          put("api_key", apiKey)
          put("query", query)
          put("search_depth", "basic")
          put("include_answer", true)
          put("max_results", 3)
          // include_domains and exclude_domains are empty by default
        }.toString()

        val request = Request.Builder()
          .url("https://api.tavily.com/search")
          .header("Authorization", "Bearer $apiKey")
          .header("Content-Type", "application/json")
          .post(jsonRequestBody.toRequestBody("application/json; charset=utf-8".toMediaTypeOrNull()))
          .build()

        client.newCall(request).execute().use { response ->
          if (!response.isSuccessful) {
            Log.e("WebSearchService", "API Error: ${response.code} ${response.message}")
            return@withContext null
          }

          val responseBody = response.body?.string()
          if (responseBody == null) {
            Log.e("WebSearchService", "Empty response body")
            return@withContext null
          }

          parseTavilyResponse(responseBody)
        }
      } catch (e: kotlinx.coroutines.CancellationException) {
        // Never swallow cancellation: rethrow so structured concurrency and
        // caller-side cancellation keep working (the generic Exception catch
        // below would otherwise trap it).
        throw e
      } catch (e: IOException) {
        Log.e("WebSearchService", "Network Error: ${e.message}", e)
        null
      } catch (e: Exception) {
        Log.e("WebSearchService", "Error during search: ${e.message}", e)
        null
      }
    }
  }

  /**
   * Parses a Tavily JSON response body into a [TavilySearchResponse].
   *
   * "answer" and "query" are read via explicit [JSONObject.isNull] checks:
   * optString(name, null) can coerce an explicit JSON null to the literal
   * string "null" on some org.json implementations, which would leak into
   * the prompt-augmentation text downstream.
   *
   * @return the parsed response, or null if the body is not valid JSON or a
   *   result entry is missing a required field.
   */
  private fun parseTavilyResponse(responseBody: String): TavilySearchResponse? {
    return try {
      val jsonObject = JSONObject(responseBody)
      // isNull is true for both a missing key and an explicit JSON null.
      val answer = if (jsonObject.isNull("answer")) null else jsonObject.optString("answer")
      val query = if (jsonObject.isNull("query")) null else jsonObject.optString("query")

      val resultsArray = jsonObject.optJSONArray("results")
      val searchResults = mutableListOf<TavilySearchResult>()
      if (resultsArray != null) {
        for (i in 0 until resultsArray.length()) {
          val resultObj = resultsArray.getJSONObject(i)
          searchResults.add(
            TavilySearchResult(
              title = resultObj.getString("title"),
              url = resultObj.getString("url"),
              content = resultObj.getString("content"),
              score = resultObj.getDouble("score")
            )
          )
        }
      }
      // Callers treat "no results" as null rather than an empty list.
      TavilySearchResponse(answer, query, if (searchResults.isEmpty()) null else searchResults)
    } catch (e: Exception) {
      Log.e("WebSearchService", "Error parsing JSON response: ${e.message}", e)
      null
    }
  }
}
|
|
@ -22,6 +22,7 @@ import androidx.lifecycle.viewmodel.CreationExtras
|
||||||
import androidx.lifecycle.viewmodel.initializer
|
import androidx.lifecycle.viewmodel.initializer
|
||||||
import androidx.lifecycle.viewmodel.viewModelFactory
|
import androidx.lifecycle.viewmodel.viewModelFactory
|
||||||
import com.google.ai.edge.gallery.GalleryApplication
|
import com.google.ai.edge.gallery.GalleryApplication
|
||||||
|
import com.google.ai.edge.gallery.data.WebSearchService
|
||||||
import com.google.ai.edge.gallery.ui.imageclassification.ImageClassificationViewModel
|
import com.google.ai.edge.gallery.ui.imageclassification.ImageClassificationViewModel
|
||||||
import com.google.ai.edge.gallery.ui.imagegeneration.ImageGenerationViewModel
|
import com.google.ai.edge.gallery.ui.imagegeneration.ImageGenerationViewModel
|
||||||
import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel
|
import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel
|
||||||
|
@ -32,6 +33,10 @@ import com.google.ai.edge.gallery.ui.textclassification.TextClassificationViewMo
|
||||||
|
|
||||||
object ViewModelProvider {
|
object ViewModelProvider {
|
||||||
val Factory = viewModelFactory {
|
val Factory = viewModelFactory {
|
||||||
|
// Create an instance of WebSearchService
|
||||||
|
// This instance will be shared by ViewModels that need it.
|
||||||
|
val webSearchService = WebSearchService()
|
||||||
|
|
||||||
// Initializer for ModelManagerViewModel.
|
// Initializer for ModelManagerViewModel.
|
||||||
initializer {
|
initializer {
|
||||||
val downloadRepository = galleryApplication().container.downloadRepository
|
val downloadRepository = galleryApplication().container.downloadRepository
|
||||||
|
@ -55,17 +60,21 @@ object ViewModelProvider {
|
||||||
|
|
||||||
// Initializer for LlmChatViewModel.
|
// Initializer for LlmChatViewModel.
|
||||||
initializer {
|
initializer {
|
||||||
LlmChatViewModel()
|
// Pass the WebSearchService instance
|
||||||
|
LlmChatViewModel(webSearchService = webSearchService)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initializer for LlmSingleTurnViewModel..
|
// Initializer for LlmSingleTurnViewModel.
|
||||||
|
// Note: LlmSingleTurnViewModel's constructor was not modified in previous steps.
|
||||||
|
// If it also needs WebSearchService in the future, its initializer and constructor would need similar changes.
|
||||||
initializer {
|
initializer {
|
||||||
LlmSingleTurnViewModel()
|
LlmSingleTurnViewModel()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initializer for LlmAskImageViewModel.
|
// Initializer for LlmAskImageViewModel.
|
||||||
initializer {
|
initializer {
|
||||||
LlmAskImageViewModel()
|
// Pass the WebSearchService instance
|
||||||
|
LlmAskImageViewModel(webSearchService = webSearchService)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initializer for ImageGenerationViewModel.
|
// Initializer for ImageGenerationViewModel.
|
||||||
|
|
|
@ -25,6 +25,7 @@ import com.google.ai.edge.gallery.data.Model
|
||||||
import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
|
import com.google.ai.edge.gallery.data.TASK_LLM_CHAT
|
||||||
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
|
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
|
||||||
import com.google.ai.edge.gallery.data.Task
|
import com.google.ai.edge.gallery.data.Task
|
||||||
|
import com.google.ai.edge.gallery.data.WebSearchService
|
||||||
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
|
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageBenchmarkLlmResult
|
||||||
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageLoading
|
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageLoading
|
||||||
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
|
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
|
||||||
|
@ -46,10 +47,71 @@ private val STATS = listOf(
|
||||||
Stat(id = "latency", label = "Latency", unit = "sec")
|
Stat(id = "latency", label = "Latency", unit = "sec")
|
||||||
)
|
)
|
||||||
|
|
||||||
open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task = curTask) {
|
open class LlmChatViewModel(
|
||||||
|
curTask: Task = TASK_LLM_CHAT,
|
||||||
|
private val webSearchService: WebSearchService
|
||||||
|
) : ChatViewModel(task = curTask) {
|
||||||
fun generateResponse(model: Model, input: String, image: Bitmap? = null, onError: () -> Unit) {
|
fun generateResponse(model: Model, input: String, image: Bitmap? = null, onError: () -> Unit) {
|
||||||
val accelerator = model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = "")
|
val accelerator = model.getStringConfigValue(key = ConfigKey.ACCELERATOR, defaultValue = "")
|
||||||
viewModelScope.launch(Dispatchers.Default) {
|
viewModelScope.launch(Dispatchers.Default) {
|
||||||
|
// Web Search Logic
|
||||||
|
var augmentedInput = input
|
||||||
|
var searchPerformed = false
|
||||||
|
var searchSuccessful = false
|
||||||
|
var searchErrorOccurred = false
|
||||||
|
|
||||||
|
// Add search in-progress indicator
|
||||||
|
val searchIndicatorMessage = ChatMessageLoading(
|
||||||
|
text = "正在為您搜索網路獲取最新資訊...",
|
||||||
|
accelerator = accelerator,
|
||||||
|
side = ChatSide.AGENT
|
||||||
|
)
|
||||||
|
addMessage(model = model, message = searchIndicatorMessage)
|
||||||
|
|
||||||
|
try {
|
||||||
|
val tavilyResponse = webSearchService.search(apiKey = "YOUR_TAVILY_API_KEY_PLACEHOLDER", query = input)
|
||||||
|
searchPerformed = true
|
||||||
|
|
||||||
|
if (tavilyResponse != null) {
|
||||||
|
searchSuccessful = true
|
||||||
|
val searchAnswer = tavilyResponse.answer
|
||||||
|
val searchResults = tavilyResponse.results
|
||||||
|
|
||||||
|
if (!searchAnswer.isNullOrBlank()) {
|
||||||
|
augmentedInput = "Based on web search results, answer the following: \"${searchAnswer}\". The original question was: \"${input}\""
|
||||||
|
} else if (!searchResults.isNullOrEmpty()) {
|
||||||
|
val snippets = searchResults.take(2).joinToString(separator = "; ") { it.content }
|
||||||
|
if (snippets.isNotBlank()) {
|
||||||
|
augmentedInput = "Based on web search results, here are some relevant snippets: \"${snippets}\". The original question was: \"${input}\""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
searchErrorOccurred = true
|
||||||
|
}
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.e(TAG, "Web search call failed", e)
|
||||||
|
searchErrorOccurred = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove search in-progress indicator
|
||||||
|
val lastMessage = getLastMessage(model = model)
|
||||||
|
if (lastMessage == searchIndicatorMessage) {
|
||||||
|
removeLastMessage(model = model)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add search result status messages
|
||||||
|
if (searchErrorOccurred) {
|
||||||
|
addMessage(
|
||||||
|
model = model,
|
||||||
|
message = ChatMessageWarning(content = "網路搜索失敗,將嘗試使用模型知識回答。")
|
||||||
|
)
|
||||||
|
} else if (searchPerformed && !searchSuccessful) {
|
||||||
|
addMessage(
|
||||||
|
model = model,
|
||||||
|
message = ChatMessageWarning(content = "網路搜索未能找到相關資訊,將嘗試使用模型知識回答。")
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
setInProgress(true)
|
setInProgress(true)
|
||||||
setPreparing(true)
|
setPreparing(true)
|
||||||
|
|
||||||
|
@ -67,9 +129,11 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
|
||||||
|
|
||||||
// Run inference.
|
// Run inference.
|
||||||
val instance = model.instance as LlmModelInstance
|
val instance = model.instance as LlmModelInstance
|
||||||
var prefillTokens = instance.session.sizeInTokens(input)
|
var prefillTokens = instance.session.sizeInTokens(augmentedInput)
|
||||||
if (image != null) {
|
if (image != null) {
|
||||||
prefillTokens += 257
|
// Assuming image context is added separately and not part of the text prompt for token calculation here.
|
||||||
|
// If image contributes to text prompt for LLM, this might need adjustment or be handled by the model instance.
|
||||||
|
prefillTokens += 257 // This is a magic number, ensure it's correct for multimodal inputs.
|
||||||
}
|
}
|
||||||
|
|
||||||
var firstRun = true
|
var firstRun = true
|
||||||
|
@ -82,7 +146,7 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
|
||||||
|
|
||||||
try {
|
try {
|
||||||
LlmChatModelHelper.runInference(model = model,
|
LlmChatModelHelper.runInference(model = model,
|
||||||
input = input,
|
input = augmentedInput, // Use augmentedInput here
|
||||||
image = image,
|
image = image,
|
||||||
resultListener = { partialResult, done ->
|
resultListener = { partialResult, done ->
|
||||||
val curTs = System.currentTimeMillis()
|
val curTs = System.currentTimeMillis()
|
||||||
|
@ -241,8 +305,10 @@ open class LlmChatViewModel(curTask: Task = TASK_LLM_CHAT) : ChatViewModel(task
|
||||||
)
|
)
|
||||||
|
|
||||||
// Re-generate the response automatically.
|
// Re-generate the response automatically.
|
||||||
|
// The original triggeredMessage.content will go through the search logic again.
|
||||||
generateResponse(model = model, input = triggeredMessage.content, onError = {})
|
generateResponse(model = model, input = triggeredMessage.content, onError = {})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class LlmAskImageViewModel : LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE)
|
class LlmAskImageViewModel(webSearchService: WebSearchService) :
|
||||||
|
LlmChatViewModel(curTask = TASK_LLM_ASK_IMAGE, webSearchService = webSearchService)
|
Loading…
Add table
Add a link
Reference in a new issue