Niral 2025-07-08 20:48:51 +03:00 committed by GitHub
commit ddc62e791d
16 changed files with 453 additions and 19 deletions

.gitignore
View file

@@ -1 +1,3 @@
.DS_Store
.idea/
.gemini/

View file

@@ -101,6 +101,7 @@ dependencies {
implementation(libs.play.services.oss.licenses)
implementation(platform(libs.firebase.bom))
implementation(libs.firebase.analytics)
implementation("commons-fileupload:commons-fileupload:1.4")
kapt(libs.hilt.android.compiler)
testImplementation(libs.junit)
androidTestImplementation(libs.androidx.junit)
@@ -115,4 +116,4 @@ dependencies {
protobuf {
protoc { artifact = "com.google.protobuf:protoc:4.26.1" }
generateProtoTasks { all().forEach { it.plugins { create("java") { option("lite") } } } }
}
}

View file

@@ -32,6 +32,7 @@
<uses-permission android:name="android.permission.RECORD_AUDIO" />
<uses-permission android:name="android.permission.WAKE_LOCK"/>
<uses-permission android:name="android.permission.ACCESS_NETWORK_STATE" />
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
<uses-feature
android:name="android.hardware.camera"
@@ -47,6 +48,7 @@
android:roundIcon="@mipmap/ic_launcher"
android:supportsRtl="true"
android:theme="@style/Theme.Gallery"
android:enableOnBackInvokedCallback="true"
tools:targetApi="31">
<!--
android:configChanges="uiMode" tells the system don't destroy and

View file

@@ -26,6 +26,7 @@ import androidx.activity.enableEdgeToEdge
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.material3.Surface
import androidx.compose.ui.Modifier
import androidx.core.app.ActivityCompat
import androidx.core.splashscreen.SplashScreen.Companion.installSplashScreen
import com.google.ai.edge.gallery.ui.theme.GalleryTheme
import com.google.firebase.analytics.FirebaseAnalytics
@@ -60,6 +61,10 @@ class MainActivity : ComponentActivity() {
setContent { GalleryTheme { Surface(modifier = Modifier.fillMaxSize()) { GalleryApp() } } }
// Keep the screen on while the app is running for better demo experience.
window.addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON)
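// Runtime permission prompts exist from API 23 (M) onward. Note that on API 33+
// READ_EXTERNAL_STORAGE is superseded by the granular READ_MEDIA_* permissions.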
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
requestPermissions(arrayOf(android.Manifest.permission.READ_EXTERNAL_STORAGE), 1)
}
}
companion object {

View file

@@ -33,6 +33,7 @@ enum class TaskType(val label: String, val id: String) {
LLM_PROMPT_LAB(label = "Prompt Lab", id = "llm_prompt_lab"),
LLM_ASK_IMAGE(label = "Ask Image", id = "llm_ask_image"),
LLM_ASK_AUDIO(label = "Audio Scribe", id = "llm_ask_audio"),
TOGGLE_SERVER(label = "Toggle Server", id = "toggle_server"),
TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
TEST_TASK_2(label = "Test task 2", id = "test_task_2"),
}
@@ -121,9 +122,20 @@ val TASK_LLM_ASK_AUDIO =
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
)
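// Placeholder task that surfaces the in-app server toggle on the home screen.
// It carries no models of its own; the server reuses the Ask Image model.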
val TASK_TOGGLE_SERVER =
Task(
type = TaskType.TOGGLE_SERVER,
icon = Icons.Outlined.Forum,
models = mutableListOf(),
description = "Toggle an LLM endpoint server running on-device (Placeholder).",
docUrl = "",
sourceCodeUrl = "",
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
)
/** All tasks. */
val TASKS: List<Task> =
listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_ASK_AUDIO, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)
listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_ASK_AUDIO, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT, TASK_TOGGLE_SERVER)
fun getModelByName(name: String): Model? {
for (task in TASKS) {

View file

@@ -0,0 +1,278 @@
package com.google.ai.edge.gallery.server
import android.content.Context
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.util.Log
import com.google.ai.edge.gallery.data.TASK_LLM_ASK_IMAGE
import com.google.ai.edge.gallery.ui.llmchat.LlmChatModelHelper
import dagger.hilt.android.qualifiers.ApplicationContext
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.IOException
import java.io.InputStream
import java.io.PrintWriter
import java.net.ServerSocket
import java.net.Socket
import java.net.SocketException
import java.net.URLDecoder
import java.util.concurrent.CountDownLatch
import javax.inject.Inject
import javax.inject.Singleton
import org.apache.commons.fileupload.FileItem
import org.apache.commons.fileupload.disk.DiskFileItemFactory
import org.apache.commons.fileupload.servlet.ServletFileUpload
@Singleton
class InAppServer @Inject constructor(
@ApplicationContext private val context: Context,
private val llmChatModelHelper: LlmChatModelHelper
) {
private var serverSocket: ServerSocket? = null
private var serverThread: Thread? = null
@Volatile
private var isServerRunning = false
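// Spins up a background thread that initializes the Ask Image model once, then
// accepts client connections in a loop until stop() closes the socket.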
fun start() {
if (isServerRunning) return
serverThread = Thread {
try {
llmChatModelHelper.initialize(context, TASK_LLM_ASK_IMAGE.models.first()) {
if (it.isNotEmpty()) {
Log.e(TAG, "Failed to initialize model: $it")
return@initialize
}
}
serverSocket = ServerSocket(DEVICE_PORT)
isServerRunning = true
Log.i(TAG, "In-App Server started on port " + DEVICE_PORT)
while (isServerRunning) {
try {
val clientSocket = serverSocket!!.accept()
Log.i(TAG, "Client connected: " + clientSocket.inetAddress)
handleClient(clientSocket)
} catch (e: SocketException) {
if (!isServerRunning) {
Log.i(TAG, "Server socket closed intentionally.")
} else {
Log.e(TAG, "Error accepting connection", e)
}
}
}
} catch (e: IOException) {
Log.e(TAG, "Error starting server", e)
isServerRunning = false
}
}
serverThread!!.start()
}
fun stop() {
if (!isServerRunning) return
try {
isServerRunning = false
if (serverSocket != null && !serverSocket!!.isClosed) {
serverSocket!!.close()
}
if (serverThread != null) {
serverThread!!.interrupt()
serverThread = null
}
llmChatModelHelper.cleanUp(TASK_LLM_ASK_IMAGE.models.first())
Log.i(TAG, "In-App Server stopped.")
} catch (e: IOException) {
Log.e(TAG, "Error stopping server", e)
}
}
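// Handles a single HTTP request on the server thread: parses the request line and
// headers, extracts the prompt (and optional multipart image), then streams the
// model's response back over the same socket.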
private fun handleClient(clientSocket: Socket) {
try {
val inputStream = clientSocket.inputStream
val writer = PrintWriter(clientSocket.outputStream, true)
val requestLine = readLine(inputStream)
if (requestLine.isBlank()) {
clientSocket.close()
return
}
Log.i(TAG, "Request: $requestLine")
val requestParts = requestLine.split(" ")
val method = requestParts[0]
var contentType = ""
var contentLength = 0
var line = readLine(inputStream)
while (line.isNotEmpty()) {
if (line.startsWith("Content-Type:", ignoreCase = true)) {
contentType = line.substringAfter(":").trim()
} else if (line.startsWith("Content-Length:", ignoreCase = true)) {
contentLength = line.substringAfter(":").trim().toInt()
}
line = readLine(inputStream)
}
var prompt = ""
var imageData: ByteArray? = null
if (method == "POST") {
if (contentLength > 0) {
val bodyBytes = ByteArray(contentLength)
var bytesRead = 0
while (bytesRead < contentLength) {
val read = inputStream.read(bodyBytes, bytesRead, contentLength - bytesRead)
if (read == -1) break
bytesRead += read
}
if (ServletFileUpload.isMultipartContent(RequestContext(ByteArrayInputStream(bodyBytes), contentType, contentLength))) {
val factory = DiskFileItemFactory()
val upload = ServletFileUpload(factory)
val items = upload.parseRequest(RequestContext(ByteArrayInputStream(bodyBytes), contentType, contentLength))
for (item in items) {
if (item.isFormField) {
if (item.fieldName == "prompt") {
prompt = item.string
}
} else {
if (item.fieldName == "image") {
imageData = item.get()
}
}
}
} else {
prompt = String(bodyBytes)
}
}
} else { // GET
val queryParams = getQueryParams(requestLine)
prompt = queryParams["prompt"] ?: ""
}
if (prompt.isBlank() && imageData == null) {
writer.println("HTTP/1.1 400 Bad Request")
writer.println("Content-Type: text/plain")
writer.println()
writer.println("No prompt or image provided.")
writer.flush()
clientSocket.close()
return
}
writer.println("HTTP/1.1 200 OK")
writer.println("Content-Type: text/plain")
writer.println("Connection: close")
writer.println()
writer.flush()
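// Block this handler until inference completes so the socket stays open while
// partial results stream; clients are therefore served one at a time.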
val latch = CountDownLatch(1)
llmChatModelHelper.resetSession(TASK_LLM_ASK_IMAGE.models.first())
val images: List<Bitmap> = imageData?.let {
val bitmap = BitmapFactory.decodeByteArray(it, 0, it.size)
listOf(bitmap)
} ?: emptyList()
llmChatModelHelper.runInference(
model = TASK_LLM_ASK_IMAGE.models.first(),
input = prompt,
images = images,
resultListener = { partialResult, done ->
writer.print(partialResult)
writer.flush()
if (done) {
clientSocket.close()
latch.countDown()
}
},
cleanUpListener = {
if (!clientSocket.isClosed) {
writer.flush()
clientSocket.close()
}
latch.countDown()
}
)
latch.await()
} catch (e: Exception) {
Log.e(TAG, "Error handling client", e)
} finally {
try {
if (!clientSocket.isClosed) {
clientSocket.close()
}
} catch (e: IOException) {
Log.e(TAG, "Error closing client socket", e)
}
}
}
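// Reads one CRLF-terminated header line byte-by-byte. A BufferedReader would
// read ahead and swallow body bytes that the multipart parser still needs.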
private fun readLine(stream: InputStream): String {
val buffer = ByteArrayOutputStream()
while (true) {
val b = stream.read()
if (b == -1) break
if (b == '\n'.code) {
break
}
buffer.write(b)
}
val bytes = buffer.toByteArray()
if (bytes.isNotEmpty() && bytes.last() == '\r'.toByte()) {
return String(bytes, 0, bytes.size - 1, Charsets.ISO_8859_1)
}
return String(bytes, Charsets.ISO_8859_1)
}
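// Extracts URL-decoded query parameters from the request line of a GET request.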
private fun getQueryParams(requestLine: String): Map<String, String> {
val queryParams = mutableMapOf<String, String>()
val urlParts = requestLine.split(" ")[1].split("?")
if (urlParts.size > 1) {
val query = urlParts[1]
for (param in query.split("&")) {
val pair = param.split("=")
if (pair.size > 1) {
queryParams[URLDecoder.decode(pair[0], "UTF-8")] =
URLDecoder.decode(pair[1], "UTF-8")
}
}
}
return queryParams
}
fun isRunning(): Boolean {
return isServerRunning
}
companion object {
private const val TAG = "AIEdgeServer"
private const val DEVICE_PORT = 8080
}
}
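// Minimal adapter that lets commons-fileupload parse a raw socket body without a
// servlet container, by exposing the buffered bytes through its RequestContext API.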
class RequestContext(
private val inputStream: java.io.InputStream,
private val contentType: String,
private val contentLength: Int
) : org.apache.commons.fileupload.RequestContext {
override fun getCharacterEncoding(): String {
return "UTF-8"
}
override fun getContentType(): String {
return contentType
}
override fun getContentLength(): Int {
return contentLength
}
override fun getInputStream(): java.io.InputStream {
return inputStream
}
}

View file

@@ -32,6 +32,8 @@ import com.google.mediapipe.framework.image.BitmapImageBuilder
import com.google.mediapipe.tasks.genai.llminference.GraphOptions
import com.google.mediapipe.tasks.genai.llminference.LlmInference
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
import javax.inject.Inject
import javax.inject.Singleton
private const val TAG = "AGLlmChatModelHelper"
@@ -41,7 +43,8 @@ typealias CleanUpListener = () -> Unit
data class LlmModelInstance(val engine: LlmInference, var session: LlmInferenceSession)
object LlmChatModelHelper {
@Singleton
class LlmChatModelHelper @Inject constructor() {
// Indexed by model name.
private val cleanUpListeners: MutableMap<String, CleanUpListener> = mutableMapOf()

View file

@@ -51,7 +51,10 @@ private val STATS =
Stat(id = "latency", label = "Latency", unit = "sec"),
)
open class LlmChatViewModelBase(val curTask: Task) : ChatViewModel(task = curTask) {
open class LlmChatViewModelBase(
val curTask: Task,
private val llmChatModelHelper: LlmChatModelHelper
) : ChatViewModel(task = curTask) {
fun generateResponse(
model: Model,
input: String,
@@ -92,7 +95,7 @@ open class LlmChatViewModelBase(val curTask: Task) : ChatViewModel(task = curTask) {
val start = System.currentTimeMillis()
try {
LlmChatModelHelper.runInference(
llmChatModelHelper.runInference(
model = model,
input = input,
images = images,
@@ -195,7 +198,7 @@ open class LlmChatViewModelBase(val curTask: Task) : ChatViewModel(task = curTask) {
while (true) {
try {
LlmChatModelHelper.resetSession(model = model)
llmChatModelHelper.resetSession(model = model)
break
} catch (e: Exception) {
Log.d(TAG, "Failed to reset session. Trying again")
@@ -262,12 +265,16 @@ open class LlmChatViewModelBase(val curTask: Task) : ChatViewModel(task = curTask) {
}
@HiltViewModel
class LlmChatViewModel @Inject constructor() : LlmChatViewModelBase(curTask = TASK_LLM_CHAT)
class LlmChatViewModel @Inject constructor(
llmChatModelHelper: LlmChatModelHelper
) : LlmChatViewModelBase(curTask = TASK_LLM_CHAT, llmChatModelHelper = llmChatModelHelper)
@HiltViewModel
class LlmAskImageViewModel @Inject constructor() :
LlmChatViewModelBase(curTask = TASK_LLM_ASK_IMAGE)
class LlmAskImageViewModel @Inject constructor(
llmChatModelHelper: LlmChatModelHelper
) : LlmChatViewModelBase(curTask = TASK_LLM_ASK_IMAGE, llmChatModelHelper = llmChatModelHelper)
@HiltViewModel
class LlmAskAudioViewModel @Inject constructor() :
LlmChatViewModelBase(curTask = TASK_LLM_ASK_AUDIO)
class LlmAskAudioViewModel @Inject constructor(
llmChatModelHelper: LlmChatModelHelper
) : LlmChatViewModelBase(curTask = TASK_LLM_ASK_AUDIO, llmChatModelHelper = llmChatModelHelper)

View file

@@ -66,7 +66,9 @@ private val STATS =
)
@HiltViewModel
class LlmSingleTurnViewModel @Inject constructor() : ViewModel() {
class LlmSingleTurnViewModel @Inject constructor(
private val llmChatModelHelper: LlmChatModelHelper
) : ViewModel() {
private val _uiState = MutableStateFlow(createUiState(task = TASK_LLM_PROMPT_LAB))
val uiState = _uiState.asStateFlow()
@@ -80,7 +82,7 @@ class LlmSingleTurnViewModel @Inject constructor() : ViewModel() {
delay(100)
}
LlmChatModelHelper.resetSession(model = model)
llmChatModelHelper.resetSession(model = model)
delay(500)
// Run inference.
@@ -96,7 +98,7 @@ class LlmSingleTurnViewModel @Inject constructor() : ViewModel() {
val start = System.currentTimeMillis()
var response = ""
var lastBenchmarkUpdateTs = 0L
LlmChatModelHelper.runInference(
llmChatModelHelper.runInference(
model = model,
input = input,
resultListener = { partialResult, done ->

View file

@@ -144,6 +144,7 @@ constructor(
private val downloadRepository: DownloadRepository,
private val dataStoreRepository: DataStoreRepository,
private val lifecycleProvider: AppLifecycleProvider,
private val llmChatModelHelper: LlmChatModelHelper,
@ApplicationContext private val context: Context,
) : ViewModel() {
private val externalFilesDir = context.getExternalFilesDir(null)
@@ -292,8 +293,8 @@ constructor(
TaskType.LLM_ASK_IMAGE,
TaskType.LLM_ASK_AUDIO,
TaskType.LLM_PROMPT_LAB ->
LlmChatModelHelper.initialize(context = context, model = model, onDone = onDone)
llmChatModelHelper.initialize(context = context, model = model, onDone = onDone)
TaskType.TOGGLE_SERVER -> {}
TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {}
}
@@ -308,8 +309,8 @@ constructor(
TaskType.LLM_CHAT,
TaskType.LLM_PROMPT_LAB,
TaskType.LLM_ASK_IMAGE,
TaskType.LLM_ASK_AUDIO -> LlmChatModelHelper.cleanUp(model = model)
TaskType.LLM_ASK_AUDIO -> llmChatModelHelper.cleanUp(model = model)
TaskType.TOGGLE_SERVER -> {}
TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {}
}

View file

@@ -0,0 +1,5 @@
package com.google.ai.edge.gallery.ui.navigation
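/** A navigation destination identified by a unique route string. */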
interface Destination {
val route: String
}

View file

@@ -69,6 +69,8 @@ import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnScreen
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
import com.google.ai.edge.gallery.ui.modelmanager.ModelManager
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.toggleserver.ToggleServerDestination
import com.google.ai.edge.gallery.ui.toggleserver.ToggleServerScreen
private const val TAG = "AGGalleryNavGraph"
private const val ROUTE_PLACEHOLDER = "placeholder"
@@ -143,7 +145,14 @@ fun GalleryNavHost(
modelManagerViewModel = modelManagerViewModel,
navigateToTaskScreen = { task ->
pickedTask = task
showModelManager = true
if (task.type == TaskType.TOGGLE_SERVER) {
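// The server screen needs no model download, so skip the model manager
// and navigate straight to it.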
navigateToTaskScreen(
navController = navController,
taskType = task.type,
)
} else {
showModelManager = true
}
},
)
@@ -260,6 +269,14 @@ fun GalleryNavHost(
)
}
}
composable(
route = ToggleServerDestination.route,
enterTransition = { slideEnter() },
exitTransition = { slideExit() },
) {
ToggleServerScreen()
}
}
// Handle incoming intents for deep links
@@ -294,6 +311,7 @@ fun navigateToTaskScreen(
TaskType.LLM_ASK_AUDIO -> navController.navigate("${LlmAskAudioDestination.route}/${modelName}")
TaskType.LLM_PROMPT_LAB ->
navController.navigate("${LlmSingleTurnDestination.route}/${modelName}")
TaskType.TOGGLE_SERVER -> navController.navigate(ToggleServerDestination.route)
TaskType.TEST_TASK_1 -> {}
TaskType.TEST_TASK_2 -> {}
}

View file

@@ -0,0 +1,7 @@
package com.google.ai.edge.gallery.ui.toggleserver
import com.google.ai.edge.gallery.ui.navigation.Destination
object ToggleServerDestination : Destination {
override val route = "toggle_server"
}

View file

@@ -0,0 +1,30 @@
package com.google.ai.edge.gallery.ui.toggleserver
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.material3.Button
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.hilt.navigation.compose.hiltViewModel
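// Minimal screen: a single button that reflects the current server state and toggles it.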
@Composable
fun ToggleServerScreen(
toggleServerViewModel: ToggleServerViewModel = hiltViewModel()
) {
val isServerRunning by toggleServerViewModel.isServerRunning.collectAsState()
Column(
modifier = Modifier.fillMaxSize(),
verticalArrangement = Arrangement.Center,
horizontalAlignment = Alignment.CenterHorizontally
) {
Button(onClick = { toggleServerViewModel.toggleServer() }) {
Text(if (isServerRunning) "Stop In-App Server" else "Start In-App Server")
}
}
}

View file

@@ -0,0 +1,33 @@
package com.google.ai.edge.gallery.ui.toggleserver
import androidx.lifecycle.ViewModel
import com.google.ai.edge.gallery.server.InAppServer
import dagger.hilt.android.lifecycle.HiltViewModel
import javax.inject.Inject
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import android.util.Log
@HiltViewModel
class ToggleServerViewModel @Inject constructor(
private val inAppServer: InAppServer
) : ViewModel() {
private val _isServerRunning = MutableStateFlow(inAppServer.isRunning())
val isServerRunning: StateFlow<Boolean> = _isServerRunning
fun toggleServer() {
Log.d("ToggleServerViewModel", "toggleServer called")
if (inAppServer.isRunning()) {
inAppServer.stop()
} else {
inAppServer.start()
}
_isServerRunning.value = inAppServer.isRunning()
}
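// Shut the server down with the ViewModel so the socket and model are not leaked.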
override fun onCleared() {
super.onCleared()
inAppServer.stop()
}
}

View file

@@ -19,6 +19,34 @@ The Google AI Edge Gallery is an experimental app that puts the power of cutting
**AI Chat**
<img width="1532" alt="AI Chat" src="https://github.com/user-attachments/assets/edaa4f89-237a-4b84-b647-b3c4631f09dc" />
## 🔌 Toggle Server
The "Toggle Server" feature runs a local HTTP server on your mobile device that allows you to interact with the on-device AI models from your laptop using `curl`, with all communication tunneled exclusively over a USB cable connection.
### Usage
1. **Enable USB Debugging**:
* Follow these [steps](https://developer.android.com/studio/debug/dev-options) to enable developer options and USB debugging on your device.
2. **Connect Device to Computer & Enable Port Forwarding**:
```bash
adb -d forward tcp:8080 tcp:8080
```
3. **Start the Server in the App**:
* Navigate to the "Toggle Server" screen.
* Tap the "Start In-App Server" button.
4. **Send Requests with `curl`**:
* **Prompt only**:
```bash
curl -X POST -F "prompt=Hello, world!" http://localhost:8080
```
* **Image and prompt**:
```bash
curl -X POST -F "prompt=What is in this image?" -F "image=@/path/to/your/image.jpg" http://localhost:8080
```
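Based on the request handling in `InAppServer` (see `getQueryParams`), a plain GET with a `prompt` query parameter should also work. Responses stream back token by token, so pass `-N` to disable curl's output buffering:
```bash
curl -N "http://localhost:8080/?prompt=Tell me a joke"
```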
## ✨ Core Features
* **📱 Run Locally, Fully Offline:** Experience the magic of GenAI without an internet connection. All processing happens directly on your device.