From 2a95e5853b6cd67e952dd1f2ad1c104969810e14 Mon Sep 17 00:00:00 2001
From: Google AI Edge Gallery
Date: Tue, 24 Jun 2025 09:23:21 -0700
Subject: [PATCH] Add a simple local test for allowlisted model.

PiperOrigin-RevId: 775265777
---
 .../ai/edge/gallery/data/ModelAllowlist.kt    |   8 +-
 .../edge/gallery/data/ModelAllowlistTest.kt   | 107 ++++++++++++++++++
 2 files changed, 111 insertions(+), 4 deletions(-)
 create mode 100644 Android/src/app/src/test/java/com/google/ai/edge/gallery/data/ModelAllowlistTest.kt

diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt
index 5cf11ca..03263ca 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt
@@ -50,10 +50,10 @@ data class AllowedModel(
       taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_PROMPT_LAB.type.id)
     var configs: List<Config> = listOf()
     if (isLlmModel) {
-      var defaultTopK: Int = defaultConfig.topK ?: DEFAULT_TOPK
-      var defaultTopP: Float = defaultConfig.topP ?: DEFAULT_TOPP
-      var defaultTemperature: Float = defaultConfig.temperature ?: DEFAULT_TEMPERATURE
-      var defaultMaxToken = defaultConfig.maxTokens ?: 1024
+      val defaultTopK: Int = defaultConfig.topK ?: DEFAULT_TOPK
+      val defaultTopP: Float = defaultConfig.topP ?: DEFAULT_TOPP
+      val defaultTemperature: Float = defaultConfig.temperature ?: DEFAULT_TEMPERATURE
+      val defaultMaxToken = defaultConfig.maxTokens ?: 1024
       var accelerators: List<Accelerator> = DEFAULT_ACCELERATORS
       if (defaultConfig.accelerators != null) {
         val items = defaultConfig.accelerators.split(",")
diff --git a/Android/src/app/src/test/java/com/google/ai/edge/gallery/data/ModelAllowlistTest.kt b/Android/src/app/src/test/java/com/google/ai/edge/gallery/data/ModelAllowlistTest.kt
new file mode 100644
index 0000000..bfc04bb
--- /dev/null
+++ b/Android/src/app/src/test/java/com/google/ai/edge/gallery/data/ModelAllowlistTest.kt
@@ -0,0 +1,107 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.ai.edge.gallery.data
+
+import org.junit.Assert.assertEquals
+import org.junit.Assert.assertFalse
+import org.junit.Assert.assertTrue
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+
+@RunWith(JUnit4::class)
+class ModelAllowlistTest {
+  @Test
+  fun toModel_success() {
+    val modelName = "test_model"
+    val modelId = "test_model_id"
+    val modelFile = "test_model_file"
+    val description = "test description"
+    val sizeInBytes = 100L
+    val version = "20250623"
+    val topK = 10
+    val topP = 0.5f
+    val temperature = 0.1f
+    val maxTokens = 1000
+    val accelerators = "gpu,cpu"
+    val taskTypes = listOf("llm_chat", "ask_image")
+    val estimatedPeakMemoryInBytes = 300L
+
+    val allowedModel =
+      AllowedModel(
+        name = modelName,
+        modelId = modelId,
+        modelFile = modelFile,
+        description = description,
+        sizeInBytes = sizeInBytes,
+        version = version,
+        defaultConfig =
+          DefaultConfig(
+            topK = topK,
+            topP = topP,
+            temperature = temperature,
+            maxTokens = maxTokens,
+            accelerators = accelerators,
+          ),
+        taskTypes = taskTypes,
+        llmSupportImage = true,
+        llmSupportAudio = true,
+        estimatedPeakMemoryInBytes = estimatedPeakMemoryInBytes,
+      )
+    val model = allowedModel.toModel()
+
+    // Check that basic fields are set correctly.
+    assertEquals(model.name, modelName)
+    assertEquals(model.version, version)
+    assertEquals(model.info, description)
+    assertEquals(
+      model.url,
+      "https://huggingface.co/test_model_id/resolve/main/test_model_file?download=true",
+    )
+    assertEquals(model.sizeInBytes, sizeInBytes)
+    assertEquals(model.estimatedPeakMemoryInBytes, estimatedPeakMemoryInBytes)
+    assertEquals(model.downloadFileName, modelFile)
+    assertFalse(model.showBenchmarkButton)
+    assertFalse(model.showRunAgainButton)
+    assertTrue(model.llmSupportImage)
+    assertTrue(model.llmSupportAudio)
+
+    // Check that configs are set correctly.
+    assertEquals(model.configs.size, 5)
+
+    // A label for showing max tokens (non-changeable).
+    assertTrue(model.configs[0] is LabelConfig)
+    assertEquals((model.configs[0] as LabelConfig).defaultValue, "$maxTokens")
+
+    // A slider for topK.
+    assertTrue(model.configs[1] is NumberSliderConfig)
+    assertEquals((model.configs[1] as NumberSliderConfig).defaultValue, topK.toFloat())
+
+    // A slider for topP.
+    assertTrue(model.configs[2] is NumberSliderConfig)
+    assertEquals((model.configs[2] as NumberSliderConfig).defaultValue, topP)
+
+    // A slider for temperature.
+    assertTrue(model.configs[3] is NumberSliderConfig)
+    assertEquals((model.configs[3] as NumberSliderConfig).defaultValue, temperature)
+
+    // A segmented button for accelerators.
+    assertTrue(model.configs[4] is SegmentedButtonConfig)
+    assertEquals((model.configs[4] as SegmentedButtonConfig).defaultValue, "GPU")
+    assertEquals((model.configs[4] as SegmentedButtonConfig).options, listOf("GPU", "CPU"))
+  }
+}