From 1dca6392934bf4e3c403b2ecc2104e8ff3f67f45 Mon Sep 17 00:00:00 2001 From: Amirhossein Ghanipour Date: Wed, 6 Aug 2025 20:28:46 +0330 Subject: [PATCH] fix race conditions and add offline tokenizer loading api --- Cargo.lock | 33 ++++++++++++++ Cargo.toml | 1 + python/openai_harmony/__init__.py | 34 ++++++++++++--- src/encoding.rs | 25 +++++++++++ src/py_module.rs | 32 ++++++++++++++ src/tiktoken_ext/mod.rs | 2 +- src/tiktoken_ext/public_encodings.rs | 65 +++++++++++++++++++++------- tests/test_harmony.py | 65 ++++++++++++++++++++++++++++ 8 files changed, 235 insertions(+), 22 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ce97b77..23e78a4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -572,6 +572,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fs2" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9564fc758e15025b46aa6643b1b77d047d1a56a1aea6e01002ac0c7026876213" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "futures" version = "0.3.31" @@ -1324,6 +1334,7 @@ dependencies = [ "bstr", "clap", "fancy-regex", + "fs2", "futures", "image", "pretty_assertions", @@ -2556,6 +2567,28 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-core" version = "0.61.0" diff --git a/Cargo.toml b/Cargo.toml index 25d070c..87a8088 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ wasm-binding = ["wasm-bindgen", "serde-wasm-bindgen", "wasm-bindgen-futures"] [dependencies] anyhow = "1.0.98" base64 = "0.22.1" +fs2 = "0.4.3" image = "0.25.6" serde = { version = "1.0.219", features = ["derive"] } serde_json = { version = "1.0.140", features = ["preserve_order"] } diff --git a/python/openai_harmony/__init__.py b/python/openai_harmony/__init__.py index 13b5fdd..a922ff4 100644 --- a/python/openai_harmony/__init__.py +++ b/python/openai_harmony/__init__.py @@ -36,13 +36,10 @@ from pydantic import BaseModel, Field try: from .openai_harmony import ( HarmonyError as HarmonyError, # expose the actual Rust error directly - ) - from .openai_harmony import PyHarmonyEncoding as _PyHarmonyEncoding # type: ignore - from .openai_harmony import ( + PyHarmonyEncoding as _PyHarmonyEncoding, # type: ignore PyStreamableParser as _PyStreamableParser, # type: ignore - ) - from .openai_harmony import ( load_harmony_encoding as _load_harmony_encoding, # type: ignore + load_harmony_encoding_from_file as _load_harmony_encoding_from_file, # type: ignore ) except ModuleNotFoundError: # pragma: no cover – raised during type-checking @@ -690,6 +687,32 @@ def load_harmony_encoding(name: str | "HarmonyEncodingName") -> HarmonyEncoding: return HarmonyEncoding(inner) +def load_harmony_encoding_from_file( + name: str, + vocab_file: str, + special_tokens: list[tuple[str, int]], + pattern: str, + n_ctx: int, + max_message_tokens: int, + max_action_length: int, + expected_hash: str | None = None, +) -> HarmonyEncoding: + """Load a HarmonyEncoding from a local vocab file (offline usage). + Use this when network access is restricted or for reproducible builds where you want to avoid remote downloads. + """ + inner: _PyHarmonyEncoding = _load_harmony_encoding_from_file( + name, + vocab_file, + special_tokens, + pattern, + n_ctx, + max_message_tokens, + max_action_length, + expected_hash, + ) + return HarmonyEncoding(inner) + + # For *mypy* we expose a minimal stub of the `HarmonyEncodingName` enum. At # **runtime** the user is expected to pass the *string* names because the Rust # side only operates on strings anyway. @@ -718,4 +741,5 @@ __all__ = [ "StreamableParser", "StreamState", "HarmonyError", + "load_harmony_encoding_from_file", ] diff --git a/src/encoding.rs b/src/encoding.rs index afe1fce..c791352 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -154,6 +154,31 @@ impl HarmonyEncoding { }) .collect() } + + pub fn from_local_file( + name: String, + vocab_file: &std::path::Path, + expected_hash: Option<&str>, + special_tokens: impl IntoIterator, + pattern: &str, + n_ctx: usize, + max_message_tokens: usize, + max_action_length: usize, + ) -> anyhow::Result { + use crate::tiktoken_ext::public_encodings::load_encoding_from_file; + let bpe = load_encoding_from_file(vocab_file, expected_hash, special_tokens, pattern)?; + Ok(HarmonyEncoding { + name, + n_ctx, + max_message_tokens, + max_action_length, + tokenizer_name: vocab_file.display().to_string(), + tokenizer: std::sync::Arc::new(bpe), + format_token_mapping: Default::default(), + stop_formatting_tokens: Default::default(), + stop_formatting_tokens_for_assistant_actions: Default::default(), + }) + } } // Methods for rendering conversations diff --git a/src/py_module.rs b/src/py_module.rs index c5c7b0a..a479397 100644 --- a/src/py_module.rs +++ b/src/py_module.rs @@ -396,6 +396,38 @@ fn openai_harmony(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { } m.add_function(pyo3::wrap_pyfunction!(load_harmony_encoding_py, m)?)?; + // Convenience function to load a HarmonyEncoding from a local vocab file for offline + // scenarios or reproducible builds where remote download is not possible. + #[pyfunction(name = "load_harmony_encoding_from_file")] + fn load_harmony_encoding_from_file_py( + py: Python<'_>, + name: &str, + vocab_file: &str, + special_tokens: Vec<(String, u32)>, + pattern: &str, + n_ctx: usize, + max_message_tokens: usize, + max_action_length: usize, + expected_hash: Option<&str>, + ) -> PyResult> { + let encoding = HarmonyEncoding::from_local_file( + name.to_string(), + std::path::Path::new(vocab_file), + expected_hash, + special_tokens, + pattern, + n_ctx, + max_message_tokens, + max_action_length, + ) + .map_err(|e| PyErr::new::(e.to_string()))?; + Py::new(py, PyHarmonyEncoding { inner: encoding }) + } + m.add_function(pyo3::wrap_pyfunction!( + load_harmony_encoding_from_file_py, + m + )?)?; + // Convenience functions to get the tool configs for the browser and python tools. #[pyfunction] fn get_tool_namespace_config(py: Python<'_>, tool: &str) -> PyResult { diff --git a/src/tiktoken_ext/mod.rs b/src/tiktoken_ext/mod.rs index 5ad31ae..df92258 100644 --- a/src/tiktoken_ext/mod.rs +++ b/src/tiktoken_ext/mod.rs @@ -1,2 +1,2 @@ -mod public_encodings; +pub mod public_encodings; pub use public_encodings::{set_tiktoken_base_url, Encoding}; diff --git a/src/tiktoken_ext/public_encodings.rs b/src/tiktoken_ext/public_encodings.rs index ab9c435..bd425c3 100644 --- a/src/tiktoken_ext/public_encodings.rs +++ b/src/tiktoken_ext/public_encodings.rs @@ -9,6 +9,7 @@ use std::{ use base64::{prelude::BASE64_STANDARD, Engine as _}; use crate::tiktoken::{CoreBPE, Rank}; +use fs2::FileExt; use sha1::Sha1; use sha2::{Digest as _, Sha256}; @@ -420,24 +421,35 @@ fn download_or_find_cached_file( ) -> Result { let cache_dir = resolve_cache_dir()?; let cache_path = resolve_cache_path(&cache_dir, url); - if cache_path.exists() { - if verify_file_hash(&cache_path, expected_hash)? { - return Ok(cache_path); - } - let _ = std::fs::remove_file(&cache_path); - } - let hash = load_remote_file(url, &cache_path)?; - if let Some(expected_hash) = expected_hash { - if hash != expected_hash { + let lock_path = cache_path.with_extension("lock"); + let lock_file = File::create(&lock_path).map_err(|e| { + RemoteVocabFileError::IOError(format!("creating lock file {lock_path:?}"), e) + })?; + lock_file + .lock_exclusive() + .map_err(|e| RemoteVocabFileError::IOError(format!("locking file {lock_path:?}"), e))?; + let result = (|| { + if cache_path.exists() { + if verify_file_hash(&cache_path, expected_hash)? { + return Ok(cache_path); + } let _ = std::fs::remove_file(&cache_path); - return Err(RemoteVocabFileError::HashMismatch { - file_url: url.to_string(), - expected_hash: expected_hash.to_string(), - computed_hash: hash, - }); } - } - Ok(cache_path) + let hash = load_remote_file(url, &cache_path)?; + if let Some(expected_hash) = expected_hash { + if hash != expected_hash { + let _ = std::fs::remove_file(&cache_path); + return Err(RemoteVocabFileError::HashMismatch { + file_url: url.to_string(), + expected_hash: expected_hash.to_string(), + computed_hash: hash, + }); + } + } + Ok(cache_path) + })(); + let _ = fs2::FileExt::unlock(&lock_file); + result } #[cfg(target_arch = "wasm32")] @@ -572,4 +584,25 @@ mod tests { let _ = encoding.load().unwrap(); } } + + #[test] + fn test_parallel_load_encodings() { + use std::thread; + + let encodings = Encoding::all(); + for encoding in encodings { + let name = encoding.name(); + let handles: Vec<_> = (0..8) + .map(|_| { + let name = name.to_string(); + thread::spawn(move || { + Encoding::from_name(&name).unwrap().load().unwrap(); + }) + }) + .collect(); + for handle in handles { + handle.join().expect("Thread panicked"); + } + } + } } diff --git a/tests/test_harmony.py b/tests/test_harmony.py index 07d5562..a962f79 100644 --- a/tests/test_harmony.py +++ b/tests/test_harmony.py @@ -35,6 +35,7 @@ from openai_harmony import ( # noqa: E402 SystemContent, ToolDescription, load_harmony_encoding, + load_harmony_encoding_from_file, ) from pydantic import ValidationError @@ -949,3 +950,67 @@ def test_streamable_parser_tool_call_with_constrain_adjacent(): ] assert parser.messages == expected + + +def test_load_harmony_encoding_from_file(tmp_path): + import os + from openai_harmony import load_harmony_encoding_from_file + + cache_dir = os.environ.get("TIKTOKEN_RS_CACHE_DIR") + if not cache_dir: + cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "tiktoken-rs-cache") + import hashlib + url = "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken" + cache_key = hashlib.sha1(url.encode()).hexdigest() + vocab_file = os.path.join(cache_dir, cache_key) + if not os.path.exists(vocab_file): + import pytest + pytest.skip("No local vocab file available for offline test") + + special_tokens = [ + ("<|startoftext|>", 199998), + ("<|endoftext|>", 199999), + ("<|reserved_200000|>", 200000), + ("<|reserved_200001|>", 200001), + ("<|return|>", 200002), + ("<|constrain|>", 200003), + ("<|reserved_200004|>", 200004), + ("<|channel|>", 200005), + ("<|start|>", 200006), + ("<|end|>", 200007), + ("<|message|>", 200008), + ("<|reserved_200009|>", 200009), + ("<|reserved_200010|>", 200010), + ("<|reserved_200011|>", 200011), + ("<|call|>", 200012), + ("<|reserved_200013|>", 200013), + ] + pattern = "|".join([ + "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?", + "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?", + "\\p{N}{1,3}", + " ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*", + "\\s*[\\r\\n]+", + "\\s+(?!\\S)", + "\\s+", + ]) + n_ctx = 8192 + max_message_tokens = 4096 + max_action_length = 256 + expected_hash = "446a9538cb6c348e3516120d7c08b09f57c36495e2acfffe59a5bf8b0cfb1a2d" + + encoding = load_harmony_encoding_from_file( + name="test_local", + vocab_file=vocab_file, + special_tokens=special_tokens, + pattern=pattern, + n_ctx=n_ctx, + max_message_tokens=max_message_tokens, + max_action_length=max_action_length, + expected_hash=expected_hash, + ) + + text = "Hello world!" + tokens = encoding.encode(text) + decoded = encoding.decode(tokens) + assert decoded.startswith("Hello world") \ No newline at end of file