Mirror of https://github.com/openai/harmony.git, synced 2025-08-24 13:17:08 -04:00

Commit 27c4fa226a: Merge 1dca639293 into 508cbaa7f6
8 changed files with 235 additions and 22 deletions

Cargo.lock (generated, 33 lines changed)
@@ -572,6 +572,16 @@ dependencies = [
 "percent-encoding",
]

[[package]]
name = "fs2"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9564fc758e15025b46aa6643b1b77d047d1a56a1aea6e01002ac0c7026876213"
dependencies = [
 "libc",
 "winapi",
]

[[package]]
name = "futures"
version = "0.3.31"

@@ -1324,6 +1334,7 @@ dependencies = [
 "bstr",
 "clap",
 "fancy-regex",
 "fs2",
 "futures",
 "image",
 "pretty_assertions",

@@ -2556,6 +2567,28 @@ version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082"

[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
 "winapi-i686-pc-windows-gnu",
 "winapi-x86_64-pc-windows-gnu",
]

[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"

[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"

[[package]]
name = "windows-core"
version = "0.61.0"

@@ -18,6 +18,7 @@ wasm-binding = ["wasm-bindgen", "serde-wasm-bindgen", "wasm-bindgen-futures"]
[dependencies]
anyhow = "1.0.98"
base64 = "0.22.1"
fs2 = "0.4.3"
image = "0.25.6"
serde = { version = "1.0.219", features = ["derive"] }
serde_json = { version = "1.0.140", features = ["preserve_order"] }

@@ -36,13 +36,10 @@ from pydantic import BaseModel, Field
try:
    from .openai_harmony import (
        HarmonyError as HarmonyError,  # expose the actual Rust error directly
    )
    from .openai_harmony import PyHarmonyEncoding as _PyHarmonyEncoding  # type: ignore
    from .openai_harmony import (
        PyHarmonyEncoding as _PyHarmonyEncoding,  # type: ignore
        PyStreamableParser as _PyStreamableParser,  # type: ignore
    )
    from .openai_harmony import (
        load_harmony_encoding as _load_harmony_encoding,  # type: ignore
        load_harmony_encoding_from_file as _load_harmony_encoding_from_file,  # type: ignore
    )

except ModuleNotFoundError:  # pragma: no cover – raised during type-checking

@@ -690,6 +687,32 @@ def load_harmony_encoding(name: str | "HarmonyEncodingName") -> HarmonyEncoding:
    return HarmonyEncoding(inner)


def load_harmony_encoding_from_file(
    name: str,
    vocab_file: str,
    special_tokens: list[tuple[str, int]],
    pattern: str,
    n_ctx: int,
    max_message_tokens: int,
    max_action_length: int,
    expected_hash: str | None = None,
) -> HarmonyEncoding:
    """Load a HarmonyEncoding from a local vocab file (offline usage).
    Use this when network access is restricted or for reproducible builds where you want to avoid remote downloads.
    """
    inner: _PyHarmonyEncoding = _load_harmony_encoding_from_file(
        name,
        vocab_file,
        special_tokens,
        pattern,
        n_ctx,
        max_message_tokens,
        max_action_length,
        expected_hash,
    )
    return HarmonyEncoding(inner)


# For *mypy* we expose a minimal stub of the `HarmonyEncodingName` enum. At
# **runtime** the user is expected to pass the *string* names because the Rust
# side only operates on strings anyway.

@@ -719,4 +742,5 @@ __all__ = [
    "StreamableParser",
    "StreamState",
    "HarmonyError",
    "load_harmony_encoding_from_file",
]

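Taken together with the test added at the end of this diff, the new Python entry point can be exercised roughly as follows. This is only a sketch, not part of the change: the vocab path is a placeholder, the special-token table is abbreviated (the full Harmony table and the real o200k split pattern appear in the new test), and the expected hash is optional.

# Sketch only: placeholder path, abbreviated token table, simplified pattern.
from openai_harmony import load_harmony_encoding_from_file

encoding = load_harmony_encoding_from_file(
    name="local_o200k_harmony",                 # free-form label for the encoding
    vocab_file="/path/to/o200k_base.tiktoken",  # placeholder: a locally stored BPE vocab
    special_tokens=[("<|start|>", 200006), ("<|end|>", 200007)],  # abbreviated table
    pattern=r"\s+|\S+",                         # placeholder split regex, not the real o200k pattern
    n_ctx=8192,
    max_message_tokens=4096,
    max_action_length=256,
    expected_hash=None,                         # or a hex digest to verify the file on load
)
tokens = encoding.encode("Hello world!")
assert encoding.decode(tokens).startswith("Hello world")
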
@@ -154,6 +154,31 @@ impl HarmonyEncoding {
            })
            .collect()
    }

    pub fn from_local_file(
        name: String,
        vocab_file: &std::path::Path,
        expected_hash: Option<&str>,
        special_tokens: impl IntoIterator<Item = (String, u32)>,
        pattern: &str,
        n_ctx: usize,
        max_message_tokens: usize,
        max_action_length: usize,
    ) -> anyhow::Result<Self> {
        use crate::tiktoken_ext::public_encodings::load_encoding_from_file;
        let bpe = load_encoding_from_file(vocab_file, expected_hash, special_tokens, pattern)?;
        Ok(HarmonyEncoding {
            name,
            n_ctx,
            max_message_tokens,
            max_action_length,
            tokenizer_name: vocab_file.display().to_string(),
            tokenizer: std::sync::Arc::new(bpe),
            format_token_mapping: Default::default(),
            stop_formatting_tokens: Default::default(),
            stop_formatting_tokens_for_assistant_actions: Default::default(),
        })
    }
}

// Methods for rendering conversations

@@ -396,6 +396,38 @@ fn openai_harmony(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    }
    m.add_function(pyo3::wrap_pyfunction!(load_harmony_encoding_py, m)?)?;

    // Convenience function to load a HarmonyEncoding from a local vocab file for offline
    // scenarios or reproducible builds where remote download is not possible.
    #[pyfunction(name = "load_harmony_encoding_from_file")]
    fn load_harmony_encoding_from_file_py(
        py: Python<'_>,
        name: &str,
        vocab_file: &str,
        special_tokens: Vec<(String, u32)>,
        pattern: &str,
        n_ctx: usize,
        max_message_tokens: usize,
        max_action_length: usize,
        expected_hash: Option<&str>,
    ) -> PyResult<Py<PyHarmonyEncoding>> {
        let encoding = HarmonyEncoding::from_local_file(
            name.to_string(),
            std::path::Path::new(vocab_file),
            expected_hash,
            special_tokens,
            pattern,
            n_ctx,
            max_message_tokens,
            max_action_length,
        )
        .map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))?;
        Py::new(py, PyHarmonyEncoding { inner: encoding })
    }
    m.add_function(pyo3::wrap_pyfunction!(
        load_harmony_encoding_from_file_py,
        m
    )?)?;

    // Convenience functions to get the tool configs for the browser and python tools.
    #[pyfunction]
    fn get_tool_namespace_config(py: Python<'_>, tool: &str) -> PyResult<PyObject> {

@@ -1,2 +1,2 @@
mod public_encodings;
pub mod public_encodings;
pub use public_encodings::{set_tiktoken_base_url, Encoding};

@@ -9,6 +9,7 @@ use std::{
use base64::{prelude::BASE64_STANDARD, Engine as _};

use crate::tiktoken::{CoreBPE, Rank};
use fs2::FileExt;
use sha1::Sha1;
use sha2::{Digest as _, Sha256};

@@ -420,24 +421,35 @@ fn download_or_find_cached_file(
) -> Result<PathBuf, RemoteVocabFileError> {
    let cache_dir = resolve_cache_dir()?;
    let cache_path = resolve_cache_path(&cache_dir, url);
    if cache_path.exists() {
        if verify_file_hash(&cache_path, expected_hash)? {
            return Ok(cache_path);
        }
        let _ = std::fs::remove_file(&cache_path);
    }
    let hash = load_remote_file(url, &cache_path)?;
    if let Some(expected_hash) = expected_hash {
        if hash != expected_hash {
    let lock_path = cache_path.with_extension("lock");
    let lock_file = File::create(&lock_path).map_err(|e| {
        RemoteVocabFileError::IOError(format!("creating lock file {lock_path:?}"), e)
    })?;
    lock_file
        .lock_exclusive()
        .map_err(|e| RemoteVocabFileError::IOError(format!("locking file {lock_path:?}"), e))?;
    let result = (|| {
        if cache_path.exists() {
            if verify_file_hash(&cache_path, expected_hash)? {
                return Ok(cache_path);
            }
            let _ = std::fs::remove_file(&cache_path);
            return Err(RemoteVocabFileError::HashMismatch {
                file_url: url.to_string(),
                expected_hash: expected_hash.to_string(),
                computed_hash: hash,
            });
        }
    }
    Ok(cache_path)
        let hash = load_remote_file(url, &cache_path)?;
        if let Some(expected_hash) = expected_hash {
            if hash != expected_hash {
                let _ = std::fs::remove_file(&cache_path);
                return Err(RemoteVocabFileError::HashMismatch {
                    file_url: url.to_string(),
                    expected_hash: expected_hash.to_string(),
                    computed_hash: hash,
                });
            }
        }
        Ok(cache_path)
    })();
    let _ = fs2::FileExt::unlock(&lock_file);
    result
}

#[cfg(target_arch = "wasm32")]

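The hunk above serializes concurrent cache downloads behind an exclusive fs2 file lock: it locks a sibling ".lock" file, re-checks the cache once the lock is held, downloads and verifies the hash only if still needed, then unlocks. As a rough, non-authoritative illustration of that flow (a Python stand-in using the standard fcntl module, so Unix-only; the download() helper and paths are hypothetical, not the library's code):

# Illustration of the lock / re-check / download / verify flow, not the library's code.
import fcntl
import hashlib
import os
import urllib.request


def sha256_of(path: str) -> str:
    with open(path, "rb") as f:
        return hashlib.sha256(f.read()).hexdigest()


def download(url: str, dest: str) -> str:
    # Hypothetical helper: fetch the file and return its hash.
    urllib.request.urlretrieve(url, dest)
    return sha256_of(dest)


def download_or_find_cached(url: str, cache_path: str, expected_hash: str | None) -> str:
    lock_path = cache_path + ".lock"
    with open(lock_path, "w") as lock_file:
        fcntl.flock(lock_file, fcntl.LOCK_EX)  # only one process downloads at a time
        try:
            if os.path.exists(cache_path):
                # Another process may have finished the download while we waited.
                if expected_hash is None or sha256_of(cache_path) == expected_hash:
                    return cache_path
                os.remove(cache_path)  # stale or corrupt cache entry
            computed = download(url, cache_path)
            if expected_hash is not None and computed != expected_hash:
                os.remove(cache_path)
                raise ValueError(f"hash mismatch for {url}: {computed} != {expected_hash}")
            return cache_path
        finally:
            fcntl.flock(lock_file, fcntl.LOCK_UN)

The test_parallel_load_encodings test added in the next hunk loads each encoding from eight threads at once, which is the scenario this lock is meant to make safe.
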
@@ -572,4 +584,25 @@ mod tests {
            let _ = encoding.load().unwrap();
        }
    }

    #[test]
    fn test_parallel_load_encodings() {
        use std::thread;

        let encodings = Encoding::all();
        for encoding in encodings {
            let name = encoding.name();
            let handles: Vec<_> = (0..8)
                .map(|_| {
                    let name = name.to_string();
                    thread::spawn(move || {
                        Encoding::from_name(&name).unwrap().load().unwrap();
                    })
                })
                .collect();
            for handle in handles {
                handle.join().expect("Thread panicked");
            }
        }
    }
}

@@ -35,6 +35,7 @@ from openai_harmony import (  # noqa: E402
    SystemContent,
    ToolDescription,
    load_harmony_encoding,
    load_harmony_encoding_from_file,
)
from pydantic import ValidationError

@@ -981,3 +982,67 @@ def test_streamable_parser_tool_call_with_constrain_adjacent():
    ]

    assert parser.messages == expected


def test_load_harmony_encoding_from_file(tmp_path):
    import os
    from openai_harmony import load_harmony_encoding_from_file

    cache_dir = os.environ.get("TIKTOKEN_RS_CACHE_DIR")
    if not cache_dir:
        cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "tiktoken-rs-cache")
    import hashlib
    url = "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken"
    cache_key = hashlib.sha1(url.encode()).hexdigest()
    vocab_file = os.path.join(cache_dir, cache_key)
    if not os.path.exists(vocab_file):
        import pytest
        pytest.skip("No local vocab file available for offline test")

    special_tokens = [
        ("<|startoftext|>", 199998),
        ("<|endoftext|>", 199999),
        ("<|reserved_200000|>", 200000),
        ("<|reserved_200001|>", 200001),
        ("<|return|>", 200002),
        ("<|constrain|>", 200003),
        ("<|reserved_200004|>", 200004),
        ("<|channel|>", 200005),
        ("<|start|>", 200006),
        ("<|end|>", 200007),
        ("<|message|>", 200008),
        ("<|reserved_200009|>", 200009),
        ("<|reserved_200010|>", 200010),
        ("<|reserved_200011|>", 200011),
        ("<|call|>", 200012),
        ("<|reserved_200013|>", 200013),
    ]
    pattern = "|".join([
        "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
        "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
        "\\p{N}{1,3}",
        " ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*",
        "\\s*[\\r\\n]+",
        "\\s+(?!\\S)",
        "\\s+",
    ])
    n_ctx = 8192
    max_message_tokens = 4096
    max_action_length = 256
    expected_hash = "446a9538cb6c348e3516120d7c08b09f57c36495e2acfffe59a5bf8b0cfb1a2d"

    encoding = load_harmony_encoding_from_file(
        name="test_local",
        vocab_file=vocab_file,
        special_tokens=special_tokens,
        pattern=pattern,
        n_ctx=n_ctx,
        max_message_tokens=max_message_tokens,
        max_action_length=max_action_length,
        expected_hash=expected_hash,
    )

    text = "Hello world!"
    tokens = encoding.encode(text)
    decoded = encoding.decode(tokens)
    assert decoded.startswith("Hello world")
