mirror of
https://github.com/google-ai-edge/gallery.git
synced 2025-07-17 11:46:39 -04:00
feat: Add Toggle Server for local inference
This commit introduces a new "Toggle Server" feature that runs a local HTTP server on the device. This allows developers and researchers to interact with the on-device AI models using `curl`, with all communication tunneled exclusively over the USB cable. The server can handle multipart/form-data requests, allowing users to send a prompt, an image, or both. This provides a powerful new way to test, debug, and integrate the on-device models.
This commit is contained in:
parent
d97e115993
commit
05ad04deda
16 changed files with 453 additions and 19 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -1 +1,3 @@
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
.idea/
|
||||||
|
.gemini/
|
||||||
|
|
|
@ -99,6 +99,7 @@ dependencies {
|
||||||
implementation(libs.hilt.navigation.compose)
|
implementation(libs.hilt.navigation.compose)
|
||||||
implementation(platform(libs.firebase.bom))
|
implementation(platform(libs.firebase.bom))
|
||||||
implementation(libs.firebase.analytics)
|
implementation(libs.firebase.analytics)
|
||||||
|
implementation("commons-fileupload:commons-fileupload:1.4")
|
||||||
kapt(libs.hilt.android.compiler)
|
kapt(libs.hilt.android.compiler)
|
||||||
testImplementation(libs.junit)
|
testImplementation(libs.junit)
|
||||||
androidTestImplementation(libs.androidx.junit)
|
androidTestImplementation(libs.androidx.junit)
|
||||||
|
@ -113,4 +114,4 @@ dependencies {
|
||||||
protobuf {
|
protobuf {
|
||||||
protoc { artifact = "com.google.protobuf:protoc:4.26.1" }
|
protoc { artifact = "com.google.protobuf:protoc:4.26.1" }
|
||||||
generateProtoTasks { all().forEach { it.plugins { create("java") { option("lite") } } } }
|
generateProtoTasks { all().forEach { it.plugins { create("java") { option("lite") } } } }
|
||||||
}
|
}
|
|
@ -32,6 +32,7 @@
|
||||||
<uses-permission android:name="android.permission.RECORD_AUDIO" />
|
<uses-permission android:name="android.permission.RECORD_AUDIO" />
|
||||||
<uses-permission android:name="android.permission.WAKE_LOCK"/>
|
<uses-permission android:name="android.permission.WAKE_LOCK"/>
|
||||||
<uses-permission android:name="android.permission.ACCESS_NETWORK_STATE" />
|
<uses-permission android:name="android.permission.ACCESS_NETWORK_STATE" />
|
||||||
|
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
|
||||||
|
|
||||||
<uses-feature
|
<uses-feature
|
||||||
android:name="android.hardware.camera"
|
android:name="android.hardware.camera"
|
||||||
|
@ -47,6 +48,7 @@
|
||||||
android:roundIcon="@mipmap/ic_launcher"
|
android:roundIcon="@mipmap/ic_launcher"
|
||||||
android:supportsRtl="true"
|
android:supportsRtl="true"
|
||||||
android:theme="@style/Theme.Gallery"
|
android:theme="@style/Theme.Gallery"
|
||||||
|
android:enableOnBackInvokedCallback="true"
|
||||||
tools:targetApi="31">
|
tools:targetApi="31">
|
||||||
<activity
|
<activity
|
||||||
android:name="com.google.ai.edge.gallery.MainActivity"
|
android:name="com.google.ai.edge.gallery.MainActivity"
|
||||||
|
|
|
@ -26,6 +26,7 @@ import androidx.activity.enableEdgeToEdge
|
||||||
import androidx.compose.foundation.layout.fillMaxSize
|
import androidx.compose.foundation.layout.fillMaxSize
|
||||||
import androidx.compose.material3.Surface
|
import androidx.compose.material3.Surface
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
|
import androidx.core.app.ActivityCompat
|
||||||
import androidx.core.splashscreen.SplashScreen.Companion.installSplashScreen
|
import androidx.core.splashscreen.SplashScreen.Companion.installSplashScreen
|
||||||
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
|
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
|
||||||
import com.google.firebase.analytics.FirebaseAnalytics
|
import com.google.firebase.analytics.FirebaseAnalytics
|
||||||
|
@ -60,6 +61,10 @@ class MainActivity : ComponentActivity() {
|
||||||
setContent { GalleryTheme { Surface(modifier = Modifier.fillMaxSize()) { GalleryApp() } } }
|
setContent { GalleryTheme { Surface(modifier = Modifier.fillMaxSize()) { GalleryApp() } } }
|
||||||
// Keep the screen on while the app is running for better demo experience.
|
// Keep the screen on while the app is running for better demo experience.
|
||||||
window.addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON)
|
window.addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON)
|
||||||
|
|
||||||
|
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
|
||||||
|
requestPermissions(arrayOf(android.Manifest.permission.READ_EXTERNAL_STORAGE), 1)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
|
|
|
@ -33,6 +33,7 @@ enum class TaskType(val label: String, val id: String) {
|
||||||
LLM_PROMPT_LAB(label = "Prompt Lab", id = "llm_prompt_lab"),
|
LLM_PROMPT_LAB(label = "Prompt Lab", id = "llm_prompt_lab"),
|
||||||
LLM_ASK_IMAGE(label = "Ask Image", id = "llm_ask_image"),
|
LLM_ASK_IMAGE(label = "Ask Image", id = "llm_ask_image"),
|
||||||
LLM_ASK_AUDIO(label = "Audio Scribe", id = "llm_ask_audio"),
|
LLM_ASK_AUDIO(label = "Audio Scribe", id = "llm_ask_audio"),
|
||||||
|
TOGGLE_SERVER(label = "Toggle Server", id = "toggle_server"),
|
||||||
TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
|
TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
|
||||||
TEST_TASK_2(label = "Test task 2", id = "test_task_2"),
|
TEST_TASK_2(label = "Test task 2", id = "test_task_2"),
|
||||||
}
|
}
|
||||||
|
@ -121,9 +122,20 @@ val TASK_LLM_ASK_AUDIO =
|
||||||
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
|
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
val TASK_TOGGLE_SERVER =
|
||||||
|
Task(
|
||||||
|
type = TaskType.TOGGLE_SERVER,
|
||||||
|
icon = Icons.Outlined.Forum,
|
||||||
|
models = mutableListOf(),
|
||||||
|
description = "Toggle an LLM endpoint server running on-device (Placeholder).",
|
||||||
|
docUrl = "",
|
||||||
|
sourceCodeUrl = "",
|
||||||
|
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
|
||||||
|
)
|
||||||
|
|
||||||
/** All tasks. */
|
/** All tasks. */
|
||||||
val TASKS: List<Task> =
|
val TASKS: List<Task> =
|
||||||
listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_ASK_AUDIO, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)
|
listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_ASK_AUDIO, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT, TASK_TOGGLE_SERVER)
|
||||||
|
|
||||||
fun getModelByName(name: String): Model? {
|
fun getModelByName(name: String): Model? {
|
||||||
for (task in TASKS) {
|
for (task in TASKS) {
|
||||||
|
|
|
@ -0,0 +1,278 @@
|
||||||
|
package com.google.ai.edge.gallery.server
|
||||||
|
|
||||||
|
import android.content.Context
|
||||||
|
import android.graphics.Bitmap
|
||||||
|
import android.graphics.BitmapFactory
|
||||||
|
import android.util.Log
|
||||||
|
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
|
||||||
|
import com.google.ai.edge.gallery.ui.llmchat.LlmChatModelHelper
|
||||||
|
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||||
|
import java.io.ByteArrayInputStream
|
||||||
|
import java.io.ByteArrayOutputStream
|
||||||
|
import java.io.IOException
|
||||||
|
import java.io.InputStream
|
||||||
|
import java.io.PrintWriter
|
||||||
|
import java.net.ServerSocket
|
||||||
|
import java.net.Socket
|
||||||
|
import java.net.SocketException
|
||||||
|
import java.net.URLDecoder
|
||||||
|
import java.util.concurrent.CountDownLatch
|
||||||
|
import javax.inject.Inject
|
||||||
|
import javax.inject.Singleton
|
||||||
|
import org.apache.commons.fileupload.FileItem
|
||||||
|
import org.apache.commons.fileupload.disk.DiskFileItemFactory
|
||||||
|
import org.apache.commons.fileupload.servlet.ServletFileUpload
|
||||||
|
|
||||||
|
@Singleton
class InAppServer @Inject constructor(
  @ApplicationContext private val context: Context,
  private val llmChatModelHelper: LlmChatModelHelper
) {

  private var serverSocket: ServerSocket? = null
  private var serverThread: Thread? = null
  @Volatile private var isServerRunning = false

  /**
   * Starts the HTTP server on a background thread.
   *
   * Initializes the LLM model backing the "Ask Image" task, then accepts and handles client
   * connections sequentially until [stop] is called. No-op if the server is already running.
   */
  fun start() {
    if (isServerRunning) return

    serverThread = Thread {
      try {
        // The image-capable chat model serves all requests (text-only and image+text).
        // firstOrNull: the model list may be empty (e.g. nothing downloaded yet); the
        // original first() would crash the thread with NoSuchElementException.
        val model = TASK_LLM_ASK_IMAGE.models.firstOrNull()
        if (model == null) {
          Log.e(TAG, "No model available; server not started.")
          return@Thread
        }
        llmChatModelHelper.initialize(context, model) { error ->
          if (error.isNotEmpty()) {
            Log.e(TAG, "Failed to initialize model: $error")
            return@initialize
          }
        }

        // Bind to the loopback interface only: clients are expected to reach the server via
        // `adb forward` over USB, so the port must not be reachable from Wi-Fi or cellular
        // networks. (The no-bind-address constructor listens on all interfaces.)
        serverSocket =
          ServerSocket(DEVICE_PORT, BACKLOG, java.net.InetAddress.getLoopbackAddress())
        isServerRunning = true
        Log.i(TAG, "In-App Server started on port " + DEVICE_PORT)

        // Clients are handled one at a time on this thread: the model holds a single
        // inference session, so concurrent requests would interleave within that session.
        while (isServerRunning) {
          try {
            val clientSocket = serverSocket!!.accept()
            Log.i(TAG, "Client connected: " + clientSocket.inetAddress)
            handleClient(clientSocket)
          } catch (e: SocketException) {
            // stop() closes the listening socket, which unblocks accept() with this exception.
            if (!isServerRunning) {
              Log.i(TAG, "Server socket closed intentionally.")
            } else {
              Log.e(TAG, "Error accepting connection", e)
            }
          }
        }
      } catch (e: IOException) {
        Log.e(TAG, "Error starting server", e)
        isServerRunning = false
      }
    }
    serverThread!!.start()
  }

  /** Stops the server, closes the listening socket, and releases the model. No-op if stopped. */
  fun stop() {
    if (!isServerRunning) return

    try {
      isServerRunning = false
      serverSocket?.takeIf { !it.isClosed }?.close()
      serverThread?.interrupt()
      serverThread = null
      // Guard against an empty model list instead of crashing in first().
      TASK_LLM_ASK_IMAGE.models.firstOrNull()?.let { llmChatModelHelper.cleanUp(it) }
      Log.i(TAG, "In-App Server stopped.")
    } catch (e: IOException) {
      Log.e(TAG, "Error stopping server", e)
    }
  }

  /**
   * Handles a single HTTP request: parses the request line, headers and body, runs inference on
   * the extracted prompt/image, and streams the result back to the client.
   *
   * Supports GET with a `prompt` query parameter and POST with either a raw-text body or
   * multipart/form-data containing `prompt` and/or `image` fields.
   */
  private fun handleClient(clientSocket: Socket) {
    try {
      val inputStream = clientSocket.inputStream
      val writer = PrintWriter(clientSocket.outputStream, true)

      val requestLine = readLine(inputStream)
      if (requestLine.isBlank()) {
        clientSocket.close()
        return
      }
      Log.i(TAG, "Request: $requestLine")

      // "METHOD /path HTTP/1.1" — reject anything shorter instead of crashing on indexing
      // (getQueryParams below also indexes split(" ")[1]).
      val requestParts = requestLine.split(" ")
      if (requestParts.size < 2) {
        writeBadRequest(writer, "Malformed request line.")
        clientSocket.close()
        return
      }
      val method = requestParts[0]

      // Read headers until the blank line that terminates them.
      var contentType = ""
      var contentLength = 0
      var line = readLine(inputStream)
      while (line.isNotEmpty()) {
        if (line.startsWith("Content-Type:", ignoreCase = true)) {
          contentType = line.substringAfter(":").trim()
        } else if (line.startsWith("Content-Length:", ignoreCase = true)) {
          // toIntOrNull: a malformed header must not crash the server thread with
          // NumberFormatException.
          contentLength = line.substringAfter(":").trim().toIntOrNull() ?: 0
        }
        line = readLine(inputStream)
      }

      var prompt = ""
      var imageData: ByteArray? = null

      if (method == "POST") {
        if (contentLength > 0) {
          val bodyBytes = readBody(inputStream, contentLength)
          if (
            ServletFileUpload.isMultipartContent(
              RequestContext(ByteArrayInputStream(bodyBytes), contentType, contentLength)
            )
          ) {
            val upload = ServletFileUpload(DiskFileItemFactory())
            val items =
              upload.parseRequest(
                RequestContext(ByteArrayInputStream(bodyBytes), contentType, contentLength)
              )
            for (item in items) {
              if (item.isFormField) {
                if (item.fieldName == "prompt") {
                  prompt = item.string
                }
              } else if (item.fieldName == "image") {
                imageData = item.get()
              }
            }
          } else {
            // Non-multipart POST: the whole body is the prompt. Decode explicitly as UTF-8
            // rather than the platform default charset.
            prompt = String(bodyBytes, Charsets.UTF_8)
          }
        }
      } else { // GET
        prompt = getQueryParams(requestLine)["prompt"] ?: ""
      }

      if (prompt.isBlank() && imageData == null) {
        writeBadRequest(writer, "No prompt or image provided.")
        clientSocket.close()
        return
      }

      // Resolve the model once, with the same empty-list guard as start()/stop().
      val model = TASK_LLM_ASK_IMAGE.models.firstOrNull()
      if (model == null) {
        Log.e(TAG, "No model available to serve request.")
        clientSocket.close()
        return
      }

      // Send response headers up front; the body is streamed as inference produces tokens.
      writer.println("HTTP/1.1 200 OK")
      writer.println("Content-Type: text/plain")
      writer.println("Connection: close")
      writer.println()
      writer.flush()

      // Block this (server) thread until inference completes so requests stay serialized.
      val latch = CountDownLatch(1)
      llmChatModelHelper.resetSession(model)

      val images: List<Bitmap> =
        imageData?.let {
          val bitmap = BitmapFactory.decodeByteArray(it, 0, it.size)
          listOf(bitmap)
        } ?: emptyList()

      llmChatModelHelper.runInference(
        model = model,
        input = prompt,
        images = images,
        resultListener = { partialResult, done ->
          writer.print(partialResult)
          writer.flush()
          if (done) {
            clientSocket.close()
            latch.countDown()
          }
        },
        cleanUpListener = {
          if (!clientSocket.isClosed) {
            writer.flush()
            clientSocket.close()
          }
          latch.countDown()
        },
      )
      latch.await()
    } catch (e: Exception) {
      Log.e(TAG, "Error handling client", e)
    } finally {
      try {
        if (!clientSocket.isClosed) {
          clientSocket.close()
        }
      } catch (e: IOException) {
        Log.e(TAG, "Error closing client socket", e)
      }
    }
  }

  /** Writes a 400 response with [message] as a plain-text body. */
  private fun writeBadRequest(writer: PrintWriter, message: String) {
    writer.println("HTTP/1.1 400 Bad Request")
    writer.println("Content-Type: text/plain")
    writer.println()
    writer.println(message)
    writer.flush()
  }

  /** Reads up to [contentLength] bytes of the request body (stops early on EOF). */
  private fun readBody(stream: InputStream, contentLength: Int): ByteArray {
    val bodyBytes = ByteArray(contentLength)
    var bytesRead = 0
    while (bytesRead < contentLength) {
      val read = stream.read(bodyBytes, bytesRead, contentLength - bytesRead)
      if (read == -1) break
      bytesRead += read
    }
    return bodyBytes
  }

  /**
   * Reads one CRLF- (or LF-) terminated line from [stream] byte-by-byte, without buffering past
   * it, so the request body that follows the headers stays intact. Header bytes are decoded as
   * ISO-8859-1.
   */
  private fun readLine(stream: InputStream): String {
    val buffer = ByteArrayOutputStream()
    while (true) {
      val b = stream.read()
      if (b == -1) break
      if (b == '\n'.code) {
        break
      }
      buffer.write(b)
    }
    val bytes = buffer.toByteArray()
    // '\r'.code.toByte() instead of the deprecated Char.toByte().
    if (bytes.isNotEmpty() && bytes.last() == '\r'.code.toByte()) {
      return String(bytes, 0, bytes.size - 1, Charsets.ISO_8859_1)
    }
    return String(bytes, Charsets.ISO_8859_1)
  }

  /** Extracts URL-decoded query parameters from the request line's target, e.g. `/?prompt=x`. */
  private fun getQueryParams(requestLine: String): Map<String, String> {
    val queryParams = mutableMapOf<String, String>()
    val urlParts = requestLine.split(" ")[1].split("?")
    if (urlParts.size > 1) {
      val query = urlParts[1]
      for (param in query.split("&")) {
        val pair = param.split("=")
        if (pair.size > 1) {
          queryParams[URLDecoder.decode(pair[0], "UTF-8")] = URLDecoder.decode(pair[1], "UTF-8")
        }
      }
    }
    return queryParams
  }

  /** True while the server's accept loop is active. */
  fun isRunning(): Boolean {
    return isServerRunning
  }

  companion object {
    private const val TAG = "AIEdgeServer"
    private const val DEVICE_PORT = 8080
    // Listen-backlog size; needed only to reach the 3-arg ServerSocket constructor that
    // accepts a bind address.
    private const val BACKLOG = 50
  }
}
|
||||||
|
|
||||||
|
/**
 * Minimal adapter that exposes an in-memory request body to Commons FileUpload's
 * [org.apache.commons.fileupload.RequestContext] so multipart bodies can be parsed outside a
 * servlet container.
 */
class RequestContext(
  private val inputStream: java.io.InputStream,
  private val contentType: String,
  private val contentLength: Int
) : org.apache.commons.fileupload.RequestContext {

  /** Encoding used by FileUpload when decoding form-field text. */
  override fun getCharacterEncoding(): String = "UTF-8"

  /** The request's Content-Type header value (including the multipart boundary). */
  override fun getContentType(): String = contentType

  /** Total body length in bytes, as reported by the Content-Length header. */
  override fun getContentLength(): Int = contentLength

  /** Stream over the (already fully buffered) request body. */
  override fun getInputStream(): java.io.InputStream = inputStream
}
|
|
@ -32,6 +32,8 @@ import com.google.mediapipe.framework.image.BitmapImageBuilder
|
||||||
import com.google.mediapipe.tasks.genai.llminference.GraphOptions
|
import com.google.mediapipe.tasks.genai.llminference.GraphOptions
|
||||||
import com.google.mediapipe.tasks.genai.llminference.LlmInference
|
import com.google.mediapipe.tasks.genai.llminference.LlmInference
|
||||||
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
|
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
|
||||||
|
import javax.inject.Inject
|
||||||
|
import javax.inject.Singleton
|
||||||
|
|
||||||
private const val TAG = "AGLlmChatModelHelper"
|
private const val TAG = "AGLlmChatModelHelper"
|
||||||
|
|
||||||
|
@ -41,7 +43,8 @@ typealias CleanUpListener = () -> Unit
|
||||||
|
|
||||||
data class LlmModelInstance(val engine: LlmInference, var session: LlmInferenceSession)
|
data class LlmModelInstance(val engine: LlmInference, var session: LlmInferenceSession)
|
||||||
|
|
||||||
object LlmChatModelHelper {
|
@Singleton
|
||||||
|
class LlmChatModelHelper @Inject constructor() {
|
||||||
// Indexed by model name.
|
// Indexed by model name.
|
||||||
private val cleanUpListeners: MutableMap<String, CleanUpListener> = mutableMapOf()
|
private val cleanUpListeners: MutableMap<String, CleanUpListener> = mutableMapOf()
|
||||||
|
|
||||||
|
|
|
@ -51,7 +51,10 @@ private val STATS =
|
||||||
Stat(id = "latency", label = "Latency", unit = "sec"),
|
Stat(id = "latency", label = "Latency", unit = "sec"),
|
||||||
)
|
)
|
||||||
|
|
||||||
open class LlmChatViewModelBase(val curTask: Task) : ChatViewModel(task = curTask) {
|
open class LlmChatViewModelBase(
|
||||||
|
val curTask: Task,
|
||||||
|
private val llmChatModelHelper: LlmChatModelHelper
|
||||||
|
) : ChatViewModel(task = curTask) {
|
||||||
fun generateResponse(
|
fun generateResponse(
|
||||||
model: Model,
|
model: Model,
|
||||||
input: String,
|
input: String,
|
||||||
|
@ -92,7 +95,7 @@ open class LlmChatViewModelBase(val curTask: Task) : ChatViewModel(task = curTas
|
||||||
val start = System.currentTimeMillis()
|
val start = System.currentTimeMillis()
|
||||||
|
|
||||||
try {
|
try {
|
||||||
LlmChatModelHelper.runInference(
|
llmChatModelHelper.runInference(
|
||||||
model = model,
|
model = model,
|
||||||
input = input,
|
input = input,
|
||||||
images = images,
|
images = images,
|
||||||
|
@ -195,7 +198,7 @@ open class LlmChatViewModelBase(val curTask: Task) : ChatViewModel(task = curTas
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
try {
|
try {
|
||||||
LlmChatModelHelper.resetSession(model = model)
|
llmChatModelHelper.resetSession(model = model)
|
||||||
break
|
break
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
Log.d(TAG, "Failed to reset session. Trying again")
|
Log.d(TAG, "Failed to reset session. Trying again")
|
||||||
|
@ -262,12 +265,16 @@ open class LlmChatViewModelBase(val curTask: Task) : ChatViewModel(task = curTas
|
||||||
}
|
}
|
||||||
|
|
||||||
@HiltViewModel
|
@HiltViewModel
|
||||||
class LlmChatViewModel @Inject constructor() : LlmChatViewModelBase(curTask = TASK_LLM_CHAT)
|
class LlmChatViewModel @Inject constructor(
|
||||||
|
llmChatModelHelper: LlmChatModelHelper
|
||||||
|
) : LlmChatViewModelBase(curTask = TASK_LLM_CHAT, llmChatModelHelper = llmChatModelHelper)
|
||||||
|
|
||||||
@HiltViewModel
|
@HiltViewModel
|
||||||
class LlmAskImageViewModel @Inject constructor() :
|
class LlmAskImageViewModel @Inject constructor(
|
||||||
LlmChatViewModelBase(curTask = TASK_LLM_ASK_IMAGE)
|
llmChatModelHelper: LlmChatModelHelper
|
||||||
|
) : LlmChatViewModelBase(curTask = TASK_LLM_ASK_IMAGE, llmChatModelHelper = llmChatModelHelper)
|
||||||
|
|
||||||
@HiltViewModel
|
@HiltViewModel
|
||||||
class LlmAskAudioViewModel @Inject constructor() :
|
class LlmAskAudioViewModel @Inject constructor(
|
||||||
LlmChatViewModelBase(curTask = TASK_LLM_ASK_AUDIO)
|
llmChatModelHelper: LlmChatModelHelper
|
||||||
|
) : LlmChatViewModelBase(curTask = TASK_LLM_ASK_AUDIO, llmChatModelHelper = llmChatModelHelper)
|
||||||
|
|
|
@ -66,7 +66,9 @@ private val STATS =
|
||||||
)
|
)
|
||||||
|
|
||||||
@HiltViewModel
|
@HiltViewModel
|
||||||
class LlmSingleTurnViewModel @Inject constructor() : ViewModel() {
|
class LlmSingleTurnViewModel @Inject constructor(
|
||||||
|
private val llmChatModelHelper: LlmChatModelHelper
|
||||||
|
) : ViewModel() {
|
||||||
private val _uiState = MutableStateFlow(createUiState(task = TASK_LLM_PROMPT_LAB))
|
private val _uiState = MutableStateFlow(createUiState(task = TASK_LLM_PROMPT_LAB))
|
||||||
val uiState = _uiState.asStateFlow()
|
val uiState = _uiState.asStateFlow()
|
||||||
|
|
||||||
|
@ -80,7 +82,7 @@ class LlmSingleTurnViewModel @Inject constructor() : ViewModel() {
|
||||||
delay(100)
|
delay(100)
|
||||||
}
|
}
|
||||||
|
|
||||||
LlmChatModelHelper.resetSession(model = model)
|
llmChatModelHelper.resetSession(model = model)
|
||||||
delay(500)
|
delay(500)
|
||||||
|
|
||||||
// Run inference.
|
// Run inference.
|
||||||
|
@ -96,7 +98,7 @@ class LlmSingleTurnViewModel @Inject constructor() : ViewModel() {
|
||||||
val start = System.currentTimeMillis()
|
val start = System.currentTimeMillis()
|
||||||
var response = ""
|
var response = ""
|
||||||
var lastBenchmarkUpdateTs = 0L
|
var lastBenchmarkUpdateTs = 0L
|
||||||
LlmChatModelHelper.runInference(
|
llmChatModelHelper.runInference(
|
||||||
model = model,
|
model = model,
|
||||||
input = input,
|
input = input,
|
||||||
resultListener = { partialResult, done ->
|
resultListener = { partialResult, done ->
|
||||||
|
|
|
@ -144,6 +144,7 @@ constructor(
|
||||||
private val downloadRepository: DownloadRepository,
|
private val downloadRepository: DownloadRepository,
|
||||||
private val dataStoreRepository: DataStoreRepository,
|
private val dataStoreRepository: DataStoreRepository,
|
||||||
private val lifecycleProvider: AppLifecycleProvider,
|
private val lifecycleProvider: AppLifecycleProvider,
|
||||||
|
private val llmChatModelHelper: LlmChatModelHelper,
|
||||||
@ApplicationContext private val context: Context,
|
@ApplicationContext private val context: Context,
|
||||||
) : ViewModel() {
|
) : ViewModel() {
|
||||||
private val externalFilesDir = context.getExternalFilesDir(null)
|
private val externalFilesDir = context.getExternalFilesDir(null)
|
||||||
|
@ -292,8 +293,8 @@ constructor(
|
||||||
TaskType.LLM_ASK_IMAGE,
|
TaskType.LLM_ASK_IMAGE,
|
||||||
TaskType.LLM_ASK_AUDIO,
|
TaskType.LLM_ASK_AUDIO,
|
||||||
TaskType.LLM_PROMPT_LAB ->
|
TaskType.LLM_PROMPT_LAB ->
|
||||||
LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone)
|
llmChatModelHelper.initialize(context = context, model = model, onDone = onDone)
|
||||||
|
TaskType.TOGGLE_SERVER -> {}
|
||||||
TaskType.TEST_TASK_1 -> {}
|
TaskType.TEST_TASK_1 -> {}
|
||||||
TaskType.TEST_TASK_2 -> {}
|
TaskType.TEST_TASK_2 -> {}
|
||||||
}
|
}
|
||||||
|
@ -308,8 +309,8 @@ constructor(
|
||||||
TaskType.LLM_CHAT,
|
TaskType.LLM_CHAT,
|
||||||
TaskType.LLM_PROMPT_LAB,
|
TaskType.LLM_PROMPT_LAB,
|
||||||
TaskType.LLM_ASK_IMAGE,
|
TaskType.LLM_ASK_IMAGE,
|
||||||
TaskType.LLM_ASK_AUDIO -> LlmChatModelHelper.cleanUp(model = model)
|
TaskType.LLM_ASK_AUDIO -> llmChatModelHelper.cleanUp(model = model)
|
||||||
|
TaskType.TOGGLE_SERVER -> {}
|
||||||
TaskType.TEST_TASK_1 -> {}
|
TaskType.TEST_TASK_1 -> {}
|
||||||
TaskType.TEST_TASK_2 -> {}
|
TaskType.TEST_TASK_2 -> {}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
package com.google.ai.edge.gallery.ui.navigation
|
||||||
|
|
||||||
|
/** A navigation destination in the app's navigation graph. */
interface Destination {
  // Unique route string used to address this destination in the NavController.
  val route: String
}
|
|
@ -69,6 +69,8 @@ import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnScreen
|
||||||
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
|
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
|
||||||
import com.google.ai.edge.gallery.ui.modelmanager.ModelManager
|
import com.google.ai.edge.gallery.ui.modelmanager.ModelManager
|
||||||
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
|
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
|
||||||
|
import com.google.ai.edge.gallery.ui.toggleserver.ToggleServerDestination
|
||||||
|
import com.google.ai.edge.gallery.ui.toggleserver.ToggleServerScreen
|
||||||
|
|
||||||
private const val TAG = "AGGalleryNavGraph"
|
private const val TAG = "AGGalleryNavGraph"
|
||||||
private const val ROUTE_PLACEHOLDER = "placeholder"
|
private const val ROUTE_PLACEHOLDER = "placeholder"
|
||||||
|
@ -143,7 +145,14 @@ fun GalleryNavHost(
|
||||||
modelManagerViewModel = modelManagerViewModel,
|
modelManagerViewModel = modelManagerViewModel,
|
||||||
navigateToTaskScreen = { task ->
|
navigateToTaskScreen = { task ->
|
||||||
pickedTask = task
|
pickedTask = task
|
||||||
showModelManager = true
|
if (task.type == TaskType.TOGGLE_SERVER) {
|
||||||
|
navigateToTaskScreen(
|
||||||
|
navController = navController,
|
||||||
|
taskType = task.type,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
showModelManager = true
|
||||||
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -260,6 +269,14 @@ fun GalleryNavHost(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
composable(
|
||||||
|
route = ToggleServerDestination.route,
|
||||||
|
enterTransition = { slideEnter() },
|
||||||
|
exitTransition = { slideExit() },
|
||||||
|
) {
|
||||||
|
ToggleServerScreen()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle incoming intents for deep links
|
// Handle incoming intents for deep links
|
||||||
|
@ -294,6 +311,7 @@ fun navigateToTaskScreen(
|
||||||
TaskType.LLM_ASK_AUDIO -> navController.navigate("${LlmAskAudioDestination.route}/${modelName}")
|
TaskType.LLM_ASK_AUDIO -> navController.navigate("${LlmAskAudioDestination.route}/${modelName}")
|
||||||
TaskType.LLM_PROMPT_LAB ->
|
TaskType.LLM_PROMPT_LAB ->
|
||||||
navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
|
navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
|
||||||
|
TaskType.TOGGLE_SERVER -> navController.navigate(ToggleServerDestination.route)
|
||||||
TaskType.TEST_TASK_1 -> {}
|
TaskType.TEST_TASK_1 -> {}
|
||||||
TaskType.TEST_TASK_2 -> {}
|
TaskType.TEST_TASK_2 -> {}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
package com.google.ai.edge.gallery.ui.toggleserver
|
||||||
|
|
||||||
|
import com.google.ai.edge.gallery.ui.navigation.Destination
|
||||||
|
|
||||||
|
/** Navigation destination for the "Toggle Server" screen. */
object ToggleServerDestination : Destination {
  override val route = "toggle_server"
}
|
|
@ -0,0 +1,30 @@
|
||||||
|
package com.google.ai.edge.gallery.ui.toggleserver
|
||||||
|
|
||||||
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
|
import androidx.compose.foundation.layout.Column
|
||||||
|
import androidx.compose.foundation.layout.fillMaxSize
|
||||||
|
import androidx.compose.material3.Button
|
||||||
|
import androidx.compose.material3.Text
|
||||||
|
import androidx.compose.runtime.Composable
|
||||||
|
import androidx.compose.runtime.collectAsState
|
||||||
|
import androidx.compose.runtime.getValue
|
||||||
|
import androidx.compose.ui.Alignment
|
||||||
|
import androidx.compose.ui.Modifier
|
||||||
|
import androidx.hilt.navigation.compose.hiltViewModel
|
||||||
|
|
||||||
|
/**
 * Screen with a single centered button that starts or stops the in-app HTTP server.
 *
 * The button label reflects the current server state exposed by [ToggleServerViewModel].
 */
@Composable
fun ToggleServerScreen(
  toggleServerViewModel: ToggleServerViewModel = hiltViewModel()
) {
  // Observe server state so the label recomposes when the server is toggled.
  val running by toggleServerViewModel.isServerRunning.collectAsState()
  val label = if (running) "Stop In-App Server" else "Start In-App Server"

  Column(
    modifier = Modifier.fillMaxSize(),
    verticalArrangement = Arrangement.Center,
    horizontalAlignment = Alignment.CenterHorizontally
  ) {
    Button(onClick = toggleServerViewModel::toggleServer) { Text(label) }
  }
}
|
|
@ -0,0 +1,33 @@
|
||||||
|
package com.google.ai.edge.gallery.ui.toggleserver
|
||||||
|
|
||||||
|
import androidx.lifecycle.ViewModel
|
||||||
|
import com.google.ai.edge.gallery.server.InAppServer
|
||||||
|
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||||
|
import javax.inject.Inject
|
||||||
|
import kotlinx.coroutines.flow.MutableStateFlow
|
||||||
|
import kotlinx.coroutines.flow.StateFlow
|
||||||
|
import android.util.Log
|
||||||
|
|
||||||
|
/**
 * ViewModel backing the Toggle Server screen.
 *
 * Exposes the server's running state as a [StateFlow] and toggles the [InAppServer] on request.
 */
@HiltViewModel
class ToggleServerViewModel @Inject constructor(
  private val inAppServer: InAppServer
) : ViewModel() {

  // Backing state, seeded from the server's actual status.
  private val _isServerRunning = MutableStateFlow(inAppServer.isRunning())
  /** Whether the in-app server is running (or has just been asked to start). */
  val isServerRunning: StateFlow<Boolean> = _isServerRunning

  /**
   * Starts the server if it is stopped, stops it otherwise.
   *
   * InAppServer.start() brings the server up on a background thread, so isRunning() can still
   * report false immediately after the call returns. The original code re-read isRunning()
   * here, which left the UI showing "stopped" after a successful start. Publish the requested
   * state instead so the button label flips immediately.
   */
  fun toggleServer() {
    Log.d("ToggleServerViewModel", "toggleServer called")
    val shouldRun = !inAppServer.isRunning()
    if (shouldRun) {
      inAppServer.start()
    } else {
      inAppServer.stop()
    }
    _isServerRunning.value = shouldRun
  }

  /** Stops the server when this ViewModel is destroyed so the socket is not leaked. */
  override fun onCleared() {
    super.onCleared()
    inAppServer.stop()
  }
}
|
28
README.md
28
README.md
|
@ -19,6 +19,34 @@ The Google AI Edge Gallery is an experimental app that puts the power of cutting
|
||||||
**AI Chat**
|
**AI Chat**
|
||||||
<img width="1532" alt="AI Chat" src="https://github.com/user-attachments/assets/edaa4f89-237a-4b84-b647-b3c4631f09dc" />
|
<img width="1532" alt="AI Chat" src="https://github.com/user-attachments/assets/edaa4f89-237a-4b84-b647-b3c4631f09dc" />
|
||||||
|
|
||||||
|
## 🔌 Toggle Server
|
||||||
|
|
||||||
|
The "Toggle Server" feature runs a local HTTP server on your mobile device that allows you to interact with the on-device AI models from your laptop using `curl`, with all communication tunneled exclusively over a USB cable connection.
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
1. **Enable USB Debugging**:
|
||||||
|
* Follow these [steps](https://developer.android.com/studio/debug/dev-options) to enable developer options and USB debugging on your device (port forwarding is set up in the next step).
|
||||||
|
|
||||||
|
2. **Connect Device to Computer & Enable Port Forwarding**:
|
||||||
|
```bash
|
||||||
|
adb -d forward tcp:8080 tcp:8080
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Start the Server in the App**:
|
||||||
|
* Navigate to the "Toggle Server" screen.
|
||||||
|
* Tap the "Start In-App Server" button.
|
||||||
|
|
||||||
|
4. **Send Requests with `curl`**:
|
||||||
|
* **Prompt only**:
|
||||||
|
```bash
|
||||||
|
curl -X POST -F "prompt=Hello, world!" http://localhost:8080
|
||||||
|
```
|
||||||
|
* **Image and prompt**:
|
||||||
|
```bash
|
||||||
|
curl -X POST -F "prompt=What is in this image?" -F "image=@/path/to/your/image.jpg" http://localhost:8080
|
||||||
|
```
|
||||||
|
|
||||||
## ✨ Core Features
|
## ✨ Core Features
|
||||||
|
|
||||||
* **📱 Run Locally, Fully Offline:** Experience the magic of GenAI without an internet connection. All processing happens directly on your device.
|
* **📱 Run Locally, Fully Offline:** Experience the magic of GenAI without an internet connection. All processing happens directly on your device.
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue