mirror of
https://github.com/openai/harmony.git
synced 2025-08-25 07:17:08 -04:00
fix race conditions and add offline tokenizer loading api
This commit is contained in:
parent
9528c7b4a0
commit
1dca639293
8 changed files with 235 additions and 22 deletions
33
Cargo.lock
generated
33
Cargo.lock
generated
|
@ -572,6 +572,16 @@ dependencies = [
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fs2"
|
||||||
|
version = "0.4.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9564fc758e15025b46aa6643b1b77d047d1a56a1aea6e01002ac0c7026876213"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
"winapi",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures"
|
name = "futures"
|
||||||
version = "0.3.31"
|
version = "0.3.31"
|
||||||
|
@ -1324,6 +1334,7 @@ dependencies = [
|
||||||
"bstr",
|
"bstr",
|
||||||
"clap",
|
"clap",
|
||||||
"fancy-regex",
|
"fancy-regex",
|
||||||
|
"fs2",
|
||||||
"futures",
|
"futures",
|
||||||
"image",
|
"image",
|
||||||
"pretty_assertions",
|
"pretty_assertions",
|
||||||
|
@ -2556,6 +2567,28 @@ version = "0.1.8"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082"
|
checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "winapi"
|
||||||
|
version = "0.3.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
|
||||||
|
dependencies = [
|
||||||
|
"winapi-i686-pc-windows-gnu",
|
||||||
|
"winapi-x86_64-pc-windows-gnu",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "winapi-i686-pc-windows-gnu"
|
||||||
|
version = "0.4.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "winapi-x86_64-pc-windows-gnu"
|
||||||
|
version = "0.4.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows-core"
|
name = "windows-core"
|
||||||
version = "0.61.0"
|
version = "0.61.0"
|
||||||
|
|
|
@ -18,6 +18,7 @@ wasm-binding = ["wasm-bindgen", "serde-wasm-bindgen", "wasm-bindgen-futures"]
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1.0.98"
|
anyhow = "1.0.98"
|
||||||
base64 = "0.22.1"
|
base64 = "0.22.1"
|
||||||
|
fs2 = "0.4.3"
|
||||||
image = "0.25.6"
|
image = "0.25.6"
|
||||||
serde = { version = "1.0.219", features = ["derive"] }
|
serde = { version = "1.0.219", features = ["derive"] }
|
||||||
serde_json = { version = "1.0.140", features = ["preserve_order"] }
|
serde_json = { version = "1.0.140", features = ["preserve_order"] }
|
||||||
|
|
|
@ -36,13 +36,10 @@ from pydantic import BaseModel, Field
|
||||||
try:
|
try:
|
||||||
from .openai_harmony import (
|
from .openai_harmony import (
|
||||||
HarmonyError as HarmonyError, # expose the actual Rust error directly
|
HarmonyError as HarmonyError, # expose the actual Rust error directly
|
||||||
)
|
PyHarmonyEncoding as _PyHarmonyEncoding, # type: ignore
|
||||||
from .openai_harmony import PyHarmonyEncoding as _PyHarmonyEncoding # type: ignore
|
|
||||||
from .openai_harmony import (
|
|
||||||
PyStreamableParser as _PyStreamableParser, # type: ignore
|
PyStreamableParser as _PyStreamableParser, # type: ignore
|
||||||
)
|
|
||||||
from .openai_harmony import (
|
|
||||||
load_harmony_encoding as _load_harmony_encoding, # type: ignore
|
load_harmony_encoding as _load_harmony_encoding, # type: ignore
|
||||||
|
load_harmony_encoding_from_file as _load_harmony_encoding_from_file, # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
except ModuleNotFoundError: # pragma: no cover – raised during type-checking
|
except ModuleNotFoundError: # pragma: no cover – raised during type-checking
|
||||||
|
@ -690,6 +687,32 @@ def load_harmony_encoding(name: str | "HarmonyEncodingName") -> HarmonyEncoding:
|
||||||
return HarmonyEncoding(inner)
|
return HarmonyEncoding(inner)
|
||||||
|
|
||||||
|
|
||||||
|
def load_harmony_encoding_from_file(
    name: str,
    vocab_file: str,
    special_tokens: list[tuple[str, int]],
    pattern: str,
    n_ctx: int,
    max_message_tokens: int,
    max_action_length: int,
    expected_hash: str | None = None,
) -> HarmonyEncoding:
    """Build a ``HarmonyEncoding`` from a vocabulary file on local disk.

    Intended for offline use: environments without network access, or
    reproducible builds that must not download a tokenizer remotely.

    Args:
        name: Name given to the resulting encoding.
        vocab_file: Path to the local vocabulary file.
        special_tokens: ``(token_text, token_id)`` pairs to register as
            special tokens.
        pattern: Pattern string forwarded unchanged to the native tokenizer
            loader (presumably the pre-tokenization split regex -- confirm
            against the Rust loader).
        n_ctx: Forwarded to the native encoding as its ``n_ctx`` setting.
        max_message_tokens: Forwarded to the native encoding as its
            ``max_message_tokens`` setting.
        max_action_length: Forwarded to the native encoding as its
            ``max_action_length`` setting.
        expected_hash: Optional hash the vocabulary file is checked against;
            ``None`` is passed through as-is. (Hash algorithm is decided by
            the native loader -- confirm there.)

    Returns:
        A ``HarmonyEncoding`` wrapping the natively loaded encoding.
    """
    # The native function is called positionally, in the order the binding
    # declares its parameters (``expected_hash`` last).
    loader_args = (
        name,
        vocab_file,
        special_tokens,
        pattern,
        n_ctx,
        max_message_tokens,
        max_action_length,
        expected_hash,
    )
    native_encoding: _PyHarmonyEncoding = _load_harmony_encoding_from_file(
        *loader_args
    )
    return HarmonyEncoding(native_encoding)
|
||||||
|
|
||||||
|
|
||||||
# For *mypy* we expose a minimal stub of the `HarmonyEncodingName` enum. At
|
# For *mypy* we expose a minimal stub of the `HarmonyEncodingName` enum. At
|
||||||
# **runtime** the user is expected to pass the *string* names because the Rust
|
# **runtime** the user is expected to pass the *string* names because the Rust
|
||||||
# side only operates on strings anyway.
|
# side only operates on strings anyway.
|
||||||
|
@ -718,4 +741,5 @@ __all__ = [
|
||||||
"StreamableParser",
|
"StreamableParser",
|
||||||
"StreamState",
|
"StreamState",
|
||||||
"HarmonyError",
|
"HarmonyError",
|
||||||
|
"load_harmony_encoding_from_file",
|
||||||
]
|
]
|
||||||
|
|
|
@ -154,6 +154,31 @@ impl HarmonyEncoding {
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn from_local_file(
|
||||||
|
name: String,
|
||||||
|
vocab_file: &std::path::Path,
|
||||||
|
expected_hash: Option<&str>,
|
||||||
|
special_tokens: impl IntoIterator<Item = (String, u32)>,
|
||||||
|
pattern: &str,
|
||||||
|
n_ctx: usize,
|
||||||
|
max_message_tokens: usize,
|
||||||
|
max_action_length: usize,
|
||||||
|
) -> anyhow::Result<Self> {
|
||||||
|
use crate::tiktoken_ext::public_encodings::load_encoding_from_file;
|
||||||
|
let bpe = load_encoding_from_file(vocab_file, expected_hash, special_tokens, pattern)?;
|
||||||
|
Ok(HarmonyEncoding {
|
||||||
|
name,
|
||||||
|
n_ctx,
|
||||||
|
max_message_tokens,
|
||||||
|
max_action_length,
|
||||||
|
tokenizer_name: vocab_file.display().to_string(),
|
||||||
|
tokenizer: std::sync::Arc::new(bpe),
|
||||||
|
format_token_mapping: Default::default(),
|
||||||
|
stop_formatting_tokens: Default::default(),
|
||||||
|
stop_formatting_tokens_for_assistant_actions: Default::default(),
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Methods for rendering conversations
|
// Methods for rendering conversations
|
||||||
|
|
|
@ -396,6 +396,38 @@ fn openai_harmony(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||||
}
|
}
|
||||||
m.add_function(pyo3::wrap_pyfunction!(load_harmony_encoding_py, m)?)?;
|
m.add_function(pyo3::wrap_pyfunction!(load_harmony_encoding_py, m)?)?;
|
||||||
|
|
||||||
|
// Convenience function to load a HarmonyEncoding from a local vocab file for offline
|
||||||
|
// scenarios or reproducible builds where remote download is not possible.
|
||||||
|
#[pyfunction(name = "load_harmony_encoding_from_file")]
|
||||||
|
fn load_harmony_encoding_from_file_py(
|
||||||
|
py: Python<'_>,
|
||||||
|
name: &str,
|
||||||
|
vocab_file: &str,
|
||||||
|
special_tokens: Vec<(String, u32)>,
|
||||||
|
pattern: &str,
|
||||||
|
n_ctx: usize,
|
||||||
|
max_message_tokens: usize,
|
||||||
|
max_action_length: usize,
|
||||||
|
expected_hash: Option<&str>,
|
||||||
|
) -> PyResult<Py<PyHarmonyEncoding>> {
|
||||||
|
let encoding = HarmonyEncoding::from_local_file(
|
||||||
|
name.to_string(),
|
||||||
|
std::path::Path::new(vocab_file),
|
||||||
|
expected_hash,
|
||||||
|
special_tokens,
|
||||||
|
pattern,
|
||||||
|
n_ctx,
|
||||||
|
max_message_tokens,
|
||||||
|
max_action_length,
|
||||||
|
)
|
||||||
|
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))?;
|
||||||
|
Py::new(py, PyHarmonyEncoding { inner: encoding })
|
||||||
|
}
|
||||||
|
m.add_function(pyo3::wrap_pyfunction!(
|
||||||
|
load_harmony_encoding_from_file_py,
|
||||||
|
m
|
||||||
|
)?)?;
|
||||||
|
|
||||||
// Convenience functions to get the tool configs for the browser and python tools.
|
// Convenience functions to get the tool configs for the browser and python tools.
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
fn get_tool_namespace_config(py: Python<'_>, tool: &str) -> PyResult<PyObject> {
|
fn get_tool_namespace_config(py: Python<'_>, tool: &str) -> PyResult<PyObject> {
|
||||||
|
|
|
@ -1,2 +1,2 @@
|
||||||
mod public_encodings;
|
pub mod public_encodings;
|
||||||
pub use public_encodings::{set_tiktoken_base_url, Encoding};
|
pub use public_encodings::{set_tiktoken_base_url, Encoding};
|
||||||
|
|
|
@ -9,6 +9,7 @@ use std::{
|
||||||
use base64::{prelude::BASE64_STANDARD, Engine as _};
|
use base64::{prelude::BASE64_STANDARD, Engine as _};
|
||||||
|
|
||||||
use crate::tiktoken::{CoreBPE, Rank};
|
use crate::tiktoken::{CoreBPE, Rank};
|
||||||
|
use fs2::FileExt;
|
||||||
use sha1::Sha1;
|
use sha1::Sha1;
|
||||||
use sha2::{Digest as _, Sha256};
|
use sha2::{Digest as _, Sha256};
|
||||||
|
|
||||||
|
@ -420,24 +421,35 @@ fn download_or_find_cached_file(
|
||||||
) -> Result<PathBuf, RemoteVocabFileError> {
|
) -> Result<PathBuf, RemoteVocabFileError> {
|
||||||
let cache_dir = resolve_cache_dir()?;
|
let cache_dir = resolve_cache_dir()?;
|
||||||
let cache_path = resolve_cache_path(&cache_dir, url);
|
let cache_path = resolve_cache_path(&cache_dir, url);
|
||||||
if cache_path.exists() {
|
let lock_path = cache_path.with_extension("lock");
|
||||||
if verify_file_hash(&cache_path, expected_hash)? {
|
let lock_file = File::create(&lock_path).map_err(|e| {
|
||||||
return Ok(cache_path);
|
RemoteVocabFileError::IOError(format!("creating lock file {lock_path:?}"), e)
|
||||||
}
|
})?;
|
||||||
let _ = std::fs::remove_file(&cache_path);
|
lock_file
|
||||||
}
|
.lock_exclusive()
|
||||||
let hash = load_remote_file(url, &cache_path)?;
|
.map_err(|e| RemoteVocabFileError::IOError(format!("locking file {lock_path:?}"), e))?;
|
||||||
if let Some(expected_hash) = expected_hash {
|
let result = (|| {
|
||||||
if hash != expected_hash {
|
if cache_path.exists() {
|
||||||
|
if verify_file_hash(&cache_path, expected_hash)? {
|
||||||
|
return Ok(cache_path);
|
||||||
|
}
|
||||||
let _ = std::fs::remove_file(&cache_path);
|
let _ = std::fs::remove_file(&cache_path);
|
||||||
return Err(RemoteVocabFileError::HashMismatch {
|
|
||||||
file_url: url.to_string(),
|
|
||||||
expected_hash: expected_hash.to_string(),
|
|
||||||
computed_hash: hash,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
}
|
let hash = load_remote_file(url, &cache_path)?;
|
||||||
Ok(cache_path)
|
if let Some(expected_hash) = expected_hash {
|
||||||
|
if hash != expected_hash {
|
||||||
|
let _ = std::fs::remove_file(&cache_path);
|
||||||
|
return Err(RemoteVocabFileError::HashMismatch {
|
||||||
|
file_url: url.to_string(),
|
||||||
|
expected_hash: expected_hash.to_string(),
|
||||||
|
computed_hash: hash,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(cache_path)
|
||||||
|
})();
|
||||||
|
let _ = fs2::FileExt::unlock(&lock_file);
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(target_arch = "wasm32")]
|
#[cfg(target_arch = "wasm32")]
|
||||||
|
@ -572,4 +584,25 @@ mod tests {
|
||||||
let _ = encoding.load().unwrap();
|
let _ = encoding.load().unwrap();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Loads every known encoding from eight threads at once; the test fails if
// any thread panics (for example because a concurrent load returned an error).
#[test]
fn test_parallel_load_encodings() {
    use std::thread;

    for encoding in Encoding::all() {
        let name = encoding.name();

        // Start all workers for this encoding before waiting on any of them,
        // so the loads genuinely overlap.
        let mut workers = Vec::new();
        for _ in 0..8 {
            let owned_name = name.to_string();
            workers.push(thread::spawn(move || {
                Encoding::from_name(&owned_name).unwrap().load().unwrap();
            }));
        }

        for worker in workers {
            worker.join().expect("Thread panicked");
        }
    }
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,6 +35,7 @@ from openai_harmony import ( # noqa: E402
|
||||||
SystemContent,
|
SystemContent,
|
||||||
ToolDescription,
|
ToolDescription,
|
||||||
load_harmony_encoding,
|
load_harmony_encoding,
|
||||||
|
load_harmony_encoding_from_file,
|
||||||
)
|
)
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
@ -949,3 +950,67 @@ def test_streamable_parser_tool_call_with_constrain_adjacent():
|
||||||
]
|
]
|
||||||
|
|
||||||
assert parser.messages == expected
|
assert parser.messages == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_harmony_encoding_from_file(tmp_path):
    """Load an encoding from a locally cached vocab file and round-trip text.

    The test performs no download. It looks for the ``o200k_base`` vocab file
    in the tiktoken cache directory (``TIKTOKEN_RS_CACHE_DIR`` if set,
    otherwise ``~/.cache/tiktoken-rs-cache``) under the SHA-1 hex digest of
    its URL, and is skipped when that file is not present.
    """
    import hashlib
    import os

    import pytest

    cache_dir = os.environ.get("TIKTOKEN_RS_CACHE_DIR") or os.path.join(
        os.path.expanduser("~"), ".cache", "tiktoken-rs-cache"
    )
    url = "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken"
    cache_key = hashlib.sha1(url.encode()).hexdigest()
    vocab_file = os.path.join(cache_dir, cache_key)
    if not os.path.exists(vocab_file):
        pytest.skip("No local vocab file available for offline test")

    special_tokens = [
        ("<|startoftext|>", 199998),
        ("<|endoftext|>", 199999),
        ("<|reserved_200000|>", 200000),
        ("<|reserved_200001|>", 200001),
        ("<|return|>", 200002),
        ("<|constrain|>", 200003),
        ("<|reserved_200004|>", 200004),
        ("<|channel|>", 200005),
        ("<|start|>", 200006),
        ("<|end|>", 200007),
        ("<|message|>", 200008),
        ("<|reserved_200009|>", 200009),
        ("<|reserved_200010|>", 200010),
        ("<|reserved_200011|>", 200011),
        ("<|call|>", 200012),
        ("<|reserved_200013|>", 200013),
    ]
    # Raw strings keep the regex fragments readable; the resulting pattern is
    # character-for-character the same as the doubled-backslash form.
    pattern = "|".join(
        [
            r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
            r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
            r"\p{N}{1,3}",
            r" ?[^\s\p{L}\p{N}]+[\r\n/]*",
            r"\s*[\r\n]+",
            r"\s+(?!\S)",
            r"\s+",
        ]
    )
    expected_hash = "446a9538cb6c348e3516120d7c08b09f57c36495e2acfffe59a5bf8b0cfb1a2d"

    encoding = load_harmony_encoding_from_file(
        name="test_local",
        vocab_file=vocab_file,
        special_tokens=special_tokens,
        pattern=pattern,
        n_ctx=8192,
        max_message_tokens=4096,
        max_action_length=256,
        expected_hash=expected_hash,
    )

    text = "Hello world!"
    tokens = encoding.encode(text)
    # An empty token list would make the round-trip check below meaningless.
    assert tokens
    # Require an exact round-trip; a prefix check would let truncated or
    # corrupted decoding slip through.
    assert encoding.decode(tokens) == text
|
Loading…
Add table
Add a link
Reference in a new issue