This commit is contained in:
Amirhossein Ghanipour 2025-08-16 12:32:16 +09:00 committed by GitHub
commit 27c4fa226a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 235 additions and 22 deletions

33
Cargo.lock generated
View file

@ -572,6 +572,16 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "fs2"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9564fc758e15025b46aa6643b1b77d047d1a56a1aea6e01002ac0c7026876213"
dependencies = [
"libc",
"winapi",
]
[[package]]
name = "futures"
version = "0.3.31"
@ -1324,6 +1334,7 @@ dependencies = [
"bstr",
"clap",
"fancy-regex",
"fs2",
"futures",
"image",
"pretty_assertions",
@ -2556,6 +2567,28 @@ version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082"
[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
"winapi-i686-pc-windows-gnu",
"winapi-x86_64-pc-windows-gnu",
]
[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-core"
version = "0.61.0"

View file

@ -18,6 +18,7 @@ wasm-binding = ["wasm-bindgen", "serde-wasm-bindgen", "wasm-bindgen-futures"]
[dependencies]
anyhow = "1.0.98"
base64 = "0.22.1"
fs2 = "0.4.3"
image = "0.25.6"
serde = { version = "1.0.219", features = ["derive"] }
serde_json = { version = "1.0.140", features = ["preserve_order"] }

View file

@ -36,13 +36,10 @@ from pydantic import BaseModel, Field
try:
from .openai_harmony import (
HarmonyError as HarmonyError, # expose the actual Rust error directly
)
from .openai_harmony import PyHarmonyEncoding as _PyHarmonyEncoding # type: ignore
from .openai_harmony import (
PyHarmonyEncoding as _PyHarmonyEncoding, # type: ignore
PyStreamableParser as _PyStreamableParser, # type: ignore
)
from .openai_harmony import (
load_harmony_encoding as _load_harmony_encoding, # type: ignore
load_harmony_encoding_from_file as _load_harmony_encoding_from_file, # type: ignore
)
except ModuleNotFoundError: # pragma: no cover raised during type-checking
@ -690,6 +687,32 @@ def load_harmony_encoding(name: str | "HarmonyEncodingName") -> HarmonyEncoding:
return HarmonyEncoding(inner)
def load_harmony_encoding_from_file(
    name: str,
    vocab_file: str,
    special_tokens: list[tuple[str, int]],
    pattern: str,
    n_ctx: int,
    max_message_tokens: int,
    max_action_length: int,
    expected_hash: str | None = None,
) -> HarmonyEncoding:
    """Build a :class:`HarmonyEncoding` from a vocab file on disk (offline usage).

    Intended for environments where network access is restricted, or for
    reproducible builds that must avoid remote downloads.
    """
    # Delegate straight to the Rust extension, which reads the vocab file,
    # optionally checks ``expected_hash``, and constructs the raw encoding.
    raw_encoding: _PyHarmonyEncoding = _load_harmony_encoding_from_file(
        name,
        vocab_file,
        special_tokens,
        pattern,
        n_ctx,
        max_message_tokens,
        max_action_length,
        expected_hash,
    )
    return HarmonyEncoding(raw_encoding)
# For *mypy* we expose a minimal stub of the `HarmonyEncodingName` enum. At
# **runtime** the user is expected to pass the *string* names because the Rust
# side only operates on strings anyway.
@ -719,4 +742,5 @@ __all__ = [
"StreamableParser",
"StreamState",
"HarmonyError",
"load_harmony_encoding_from_file",
]

View file

@ -154,6 +154,31 @@ impl HarmonyEncoding {
})
.collect()
}
pub fn from_local_file(
name: String,
vocab_file: &std::path::Path,
expected_hash: Option<&str>,
special_tokens: impl IntoIterator<Item = (String, u32)>,
pattern: &str,
n_ctx: usize,
max_message_tokens: usize,
max_action_length: usize,
) -> anyhow::Result<Self> {
use crate::tiktoken_ext::public_encodings::load_encoding_from_file;
let bpe = load_encoding_from_file(vocab_file, expected_hash, special_tokens, pattern)?;
Ok(HarmonyEncoding {
name,
n_ctx,
max_message_tokens,
max_action_length,
tokenizer_name: vocab_file.display().to_string(),
tokenizer: std::sync::Arc::new(bpe),
format_token_mapping: Default::default(),
stop_formatting_tokens: Default::default(),
stop_formatting_tokens_for_assistant_actions: Default::default(),
})
}
}
// Methods for rendering conversations

View file

@ -396,6 +396,38 @@ fn openai_harmony(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
}
m.add_function(pyo3::wrap_pyfunction!(load_harmony_encoding_py, m)?)?;
// Convenience wrapper exposing offline encoding construction to Python, for
// scenarios where remote download is not possible (or reproducible builds).
#[pyfunction(name = "load_harmony_encoding_from_file")]
fn load_harmony_encoding_from_file_py(
    py: Python<'_>,
    name: &str,
    vocab_file: &str,
    special_tokens: Vec<(String, u32)>,
    pattern: &str,
    n_ctx: usize,
    max_message_tokens: usize,
    max_action_length: usize,
    expected_hash: Option<&str>,
) -> PyResult<Py<PyHarmonyEncoding>> {
    let vocab_path = std::path::Path::new(vocab_file);
    // Surface any Rust-side failure to Python as the shared HarmonyError type.
    match HarmonyEncoding::from_local_file(
        name.to_owned(),
        vocab_path,
        expected_hash,
        special_tokens,
        pattern,
        n_ctx,
        max_message_tokens,
        max_action_length,
    ) {
        Ok(inner) => Py::new(py, PyHarmonyEncoding { inner }),
        Err(e) => Err(PyErr::new::<HarmonyError, _>(e.to_string())),
    }
}
m.add_function(pyo3::wrap_pyfunction!(
load_harmony_encoding_from_file_py,
m
)?)?;
// Convenience functions to get the tool configs for the browser and python tools.
#[pyfunction]
fn get_tool_namespace_config(py: Python<'_>, tool: &str) -> PyResult<PyObject> {

View file

@ -1,2 +1,2 @@
mod public_encodings;
pub mod public_encodings;
pub use public_encodings::{set_tiktoken_base_url, Encoding};

View file

@ -9,6 +9,7 @@ use std::{
use base64::{prelude::BASE64_STANDARD, Engine as _};
use crate::tiktoken::{CoreBPE, Rank};
use fs2::FileExt;
use sha1::Sha1;
use sha2::{Digest as _, Sha256};
@ -420,24 +421,35 @@ fn download_or_find_cached_file(
) -> Result<PathBuf, RemoteVocabFileError> {
let cache_dir = resolve_cache_dir()?;
let cache_path = resolve_cache_path(&cache_dir, url);
if cache_path.exists() {
if verify_file_hash(&cache_path, expected_hash)? {
return Ok(cache_path);
}
let _ = std::fs::remove_file(&cache_path);
}
let hash = load_remote_file(url, &cache_path)?;
if let Some(expected_hash) = expected_hash {
if hash != expected_hash {
let lock_path = cache_path.with_extension("lock");
let lock_file = File::create(&lock_path).map_err(|e| {
RemoteVocabFileError::IOError(format!("creating lock file {lock_path:?}"), e)
})?;
lock_file
.lock_exclusive()
.map_err(|e| RemoteVocabFileError::IOError(format!("locking file {lock_path:?}"), e))?;
let result = (|| {
if cache_path.exists() {
if verify_file_hash(&cache_path, expected_hash)? {
return Ok(cache_path);
}
let _ = std::fs::remove_file(&cache_path);
return Err(RemoteVocabFileError::HashMismatch {
file_url: url.to_string(),
expected_hash: expected_hash.to_string(),
computed_hash: hash,
});
}
}
Ok(cache_path)
let hash = load_remote_file(url, &cache_path)?;
if let Some(expected_hash) = expected_hash {
if hash != expected_hash {
let _ = std::fs::remove_file(&cache_path);
return Err(RemoteVocabFileError::HashMismatch {
file_url: url.to_string(),
expected_hash: expected_hash.to_string(),
computed_hash: hash,
});
}
}
Ok(cache_path)
})();
let _ = fs2::FileExt::unlock(&lock_file);
result
}
#[cfg(target_arch = "wasm32")]
@ -572,4 +584,25 @@ mod tests {
let _ = encoding.load().unwrap();
}
}
#[test]
fn test_parallel_load_encodings() {
    use std::thread;

    // Loading the same encoding from many threads at once must succeed;
    // concurrent loads of each known encoding are exercised with 8 workers.
    for encoding in Encoding::all() {
        let encoding_name = encoding.name().to_string();
        let workers: Vec<_> = (0..8)
            .map(|_| {
                let encoding_name = encoding_name.clone();
                thread::spawn(move || {
                    Encoding::from_name(&encoding_name).unwrap().load().unwrap();
                })
            })
            .collect();
        for worker in workers {
            worker.join().expect("Thread panicked");
        }
    }
}
}

View file

@ -35,6 +35,7 @@ from openai_harmony import ( # noqa: E402
SystemContent,
ToolDescription,
load_harmony_encoding,
load_harmony_encoding_from_file,
)
from pydantic import ValidationError
@ -981,3 +982,67 @@ def test_streamable_parser_tool_call_with_constrain_adjacent():
]
assert parser.messages == expected
def test_load_harmony_encoding_from_file(tmp_path):
    """Load an encoding from a locally cached vocab file and round-trip text.

    The vocab file is looked up in the tiktoken-rs download cache; the test is
    skipped when it is not present, so it never requires network access.
    Fix: drop the redundant local re-import of ``load_harmony_encoding_from_file``
    (it shadowed the module-level import) and hoist the scattered mid-function
    imports to the top of the test.
    """
    import hashlib
    import os

    import pytest

    # Resolve the cache directory the Rust side uses: env override first,
    # then the default location under the user's home cache.
    cache_dir = os.environ.get("TIKTOKEN_RS_CACHE_DIR")
    if not cache_dir:
        cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "tiktoken-rs-cache")

    # Cache entries are keyed by the SHA-1 hex digest of the source URL.
    url = "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken"
    cache_key = hashlib.sha1(url.encode()).hexdigest()
    vocab_file = os.path.join(cache_dir, cache_key)
    if not os.path.exists(vocab_file):
        pytest.skip("No local vocab file available for offline test")

    # Special-token table mirroring the o200k harmony layout.
    special_tokens = [
        ("<|startoftext|>", 199998),
        ("<|endoftext|>", 199999),
        ("<|reserved_200000|>", 200000),
        ("<|reserved_200001|>", 200001),
        ("<|return|>", 200002),
        ("<|constrain|>", 200003),
        ("<|reserved_200004|>", 200004),
        ("<|channel|>", 200005),
        ("<|start|>", 200006),
        ("<|end|>", 200007),
        ("<|message|>", 200008),
        ("<|reserved_200009|>", 200009),
        ("<|reserved_200010|>", 200010),
        ("<|reserved_200011|>", 200011),
        ("<|call|>", 200012),
        ("<|reserved_200013|>", 200013),
    ]
    # o200k_base pre-tokenization regex, built from its alternation branches.
    pattern = "|".join([
        "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
        "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
        "\\p{N}{1,3}",
        " ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*",
        "\\s*[\\r\\n]+",
        "\\s+(?!\\S)",
        "\\s+",
    ])
    n_ctx = 8192
    max_message_tokens = 4096
    max_action_length = 256
    # SHA-256 of the o200k_base vocab file; the loader verifies it.
    expected_hash = "446a9538cb6c348e3516120d7c08b09f57c36495e2acfffe59a5bf8b0cfb1a2d"

    encoding = load_harmony_encoding_from_file(
        name="test_local",
        vocab_file=vocab_file,
        special_tokens=special_tokens,
        pattern=pattern,
        n_ctx=n_ctx,
        max_message_tokens=max_message_tokens,
        max_action_length=max_action_length,
        expected_hash=expected_hash,
    )

    # Round-trip: encoding then decoding should reproduce the input text.
    text = "Hello world!"
    tokens = encoding.encode(text)
    decoded = encoding.decode(tokens)
    assert decoded.startswith("Hello world")