openai-harmony/src/py_module.rs
Jordan Wu d00ac3de49 fix: make HarmonyEncoding usable concurrently
the `conversation_has_function_tools` atomic bool makes `HarmonyEncoding` stateful
2025-08-05 16:55:03 -07:00

422 lines
14 KiB
Rust

//! Python bindings for the harmony crate.
//!
//! The bindings are kept intentionally small: we expose the `HarmonyEncoding` type
//! together with the operations that are required by the original Rust test
//! suite (rendering a conversation for completion, parsing messages from
//! completion tokens and decoding tokens back into UTF-8). All higher-level
//! data-structures (Conversation, Message, SystemContent, DeveloperContent, …) are passed across the FFI
//! boundary as JSON. This allows us to keep the Rust ↔ Python interface very
//! light-weight while still re-using the exact same logic that is implemented
//! in Rust.
//!
//! A thin, typed, user-facing Python wrapper around these low-level bindings is
//! provided in `harmony/__init__.py`.
use pyo3::prelude::*;
// We need the `Python` type later on.
use pyo3::create_exception;
use pyo3::exceptions::PyRuntimeError;
use pyo3::Python;
use pyo3::types::{PyAny, PyDict, PyModule};
use pyo3::Bound;
// Define a custom Python exception so users can catch Harmony specific errors.
create_exception!(openai_harmony, HarmonyError, PyRuntimeError);
use crate::{
chat::{Message, Role, ToolNamespaceConfig},
encoding::{HarmonyEncoding, StreamableParser},
load_harmony_encoding, HarmonyEncodingName,
};
/// A thin PyO3 wrapper around the Rust `HarmonyEncoding` struct.
#[pyclass]
struct PyHarmonyEncoding {
inner: HarmonyEncoding,
}
/// Streaming parser exposed to Python.
#[pyclass]
struct PyStreamableParser {
inner: StreamableParser,
}
#[pyclass]
pub enum PyStreamState {
ExpectStart,
Header,
Content,
}
#[pymethods]
impl PyHarmonyEncoding {
/// Create a new `HarmonyEncoding` by name.
#[new]
fn new(name: &str) -> PyResult<Self> {
let parsed: HarmonyEncodingName = name
.parse::<HarmonyEncodingName>()
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
let encoding = load_harmony_encoding(parsed)
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))?;
Ok(Self { inner: encoding })
}
/// Return the name of the encoding.
#[getter]
fn name(&self) -> &str {
self.inner.name()
}
/// Render a conversation (in JSON format) for completion.
///
/// Parameters
/// ----------
/// conversation_json : str
/// A JSON encoded `Conversation` (as produced by `serde_json`).
/// next_turn_role : str
/// The role of the *next* turn (e.g. "assistant").
/// config : dict (optional)
/// Optional config dict. Only supports 'auto_drop_analysis' (bool).
///
/// Returns
/// -------
/// List[int]
/// The encoded token sequence.
fn render_conversation_for_completion(
&self,
conversation_json: &str,
next_turn_role: &str,
config: Option<Bound<'_, PyDict>>,
) -> PyResult<Vec<u32>> {
// Deserialize the conversation first.
let conversation: crate::chat::Conversation = serde_json::from_str(conversation_json)
.map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"invalid conversation JSON: {e}"
))
})?;
// Convert the role string into the `Role` enum.
let role = Role::try_from(next_turn_role).map_err(|_| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"unknown role: {next_turn_role}"
))
})?;
// Parse config
let rust_config = if let Some(cfg_dict) = config {
let auto_drop_analysis = cfg_dict
.get_item("auto_drop_analysis")?
.and_then(|v| v.extract().ok())
.unwrap_or(true);
Some(crate::encoding::RenderConversationConfig { auto_drop_analysis })
} else {
None
};
self.inner
.render_conversation_for_completion(&conversation, role, rust_config.as_ref())
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))
}
/// Render a conversation without appending a new role.
fn render_conversation(
&self,
conversation_json: &str,
config: Option<Bound<'_, PyDict>>,
) -> PyResult<Vec<u32>> {
let conversation: crate::chat::Conversation = serde_json::from_str(conversation_json)
.map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"invalid conversation JSON: {e}"
))
})?;
let rust_config = if let Some(cfg_dict) = config {
let auto_drop_analysis = cfg_dict
.get_item("auto_drop_analysis")?
.and_then(|v| v.extract().ok())
.unwrap_or(true);
Some(crate::encoding::RenderConversationConfig { auto_drop_analysis })
} else {
None
};
self.inner
.render_conversation(&conversation, rust_config.as_ref())
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
}
/// Render a conversation for training.
fn render_conversation_for_training(
&self,
conversation_json: &str,
config: Option<Bound<'_, PyDict>>,
) -> PyResult<Vec<u32>> {
let conversation: crate::chat::Conversation = serde_json::from_str(conversation_json)
.map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"invalid conversation JSON: {e}"
))
})?;
let rust_config = if let Some(cfg_dict) = config {
let auto_drop_analysis = cfg_dict
.get_item("auto_drop_analysis")?
.and_then(|v| v.extract().ok())
.unwrap_or(true);
Some(crate::encoding::RenderConversationConfig { auto_drop_analysis })
} else {
None
};
self.inner
.render_conversation_for_training(&conversation, rust_config.as_ref())
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
}
/// Render a single message into tokens.
fn render(
&self,
message_json: &str,
render_options: Option<Bound<'_, PyDict>>,
) -> PyResult<Vec<u32>> {
let message: crate::chat::Message = serde_json::from_str(message_json).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("invalid message JSON: {e}"))
})?;
let rust_options = if let Some(options_dict) = render_options {
let conversation_has_function_tools = options_dict
.get_item("conversation_has_function_tools")?
.and_then(|v| v.extract().ok())
.unwrap_or(false);
Some(crate::encoding::RenderOptions {
conversation_has_function_tools,
})
} else {
None
};
self.inner
.render(&message, rust_options.as_ref())
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
}
/// Given a list of completion tokens, parse them back into a sequence of
/// messages. The result is returned as a JSON string which can be
/// deserialised on the Python side.
#[allow(clippy::needless_pass_by_value)]
fn parse_messages_from_completion_tokens(
&self,
tokens: Vec<u32>,
role: Option<&str>,
) -> PyResult<String> {
let role_parsed = if let Some(r) = role {
Some(Role::try_from(r).map_err(|_| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("unknown role: {r}"))
})?)
} else {
None
};
let messages: Vec<Message> = self
.inner
.parse_messages_from_completion_tokens(tokens, role_parsed)
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))?;
serde_json::to_string(&messages).map_err(|e| {
PyErr::new::<HarmonyError, _>(format!("failed to serialise messages to JSON: {e}"))
})
}
/// Decode a sequence of tokens into text using the underlying tokenizer.
fn decode_utf8(&self, tokens: Vec<u32>) -> PyResult<String> {
self.inner
.tokenizer()
.decode_utf8(tokens)
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))
}
/// Decode a sequence of tokens into raw bytes using the underlying tokenizer.
fn decode_bytes(&self, tokens: Vec<u32>) -> PyResult<Vec<u8>> {
self.inner
.tokenizer()
.decode_bytes(tokens)
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))
}
/// Encode text into tokens using the underlying tokenizer with a set of allowed special tokens.
fn encode(&self, text: &str, allowed_special: Option<Bound<'_, PyAny>>) -> PyResult<Vec<u32>> {
let allowed_vec: Vec<String> = match allowed_special {
Some(obj) => obj.extract::<Vec<String>>().map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"invalid allowed_special: {e}"
))
})?,
None => Vec::new(),
};
let allowed_set: std::collections::HashSet<&str> =
allowed_vec.iter().map(|s| s.as_str()).collect();
Ok(self.inner.tokenizer().encode(text, &allowed_set).0)
}
/// Return the list of special tokens for this tokenizer.
fn special_tokens(&self) -> Vec<String> {
self.inner
.tokenizer()
.special_tokens()
.into_iter()
.map(str::to_string)
.collect()
}
/// Check whether a token id corresponds to a special token.
fn is_special_token(&self, token: u32) -> bool {
self.inner.tokenizer().is_special_token(token)
}
/// Return the stop tokens for the encoding.
fn stop_tokens(&self) -> PyResult<Vec<u32>> {
self.inner
.stop_tokens()
.map(|set| set.into_iter().collect())
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))
}
/// Return the stop tokens for assistant actions.
fn stop_tokens_for_assistant_actions(&self) -> PyResult<Vec<u32>> {
self.inner
.stop_tokens_for_assistant_actions()
.map(|set| set.into_iter().collect())
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))
}
}
#[pymethods]
impl PyStreamableParser {
#[new]
fn new(encoding: &PyHarmonyEncoding, role: Option<&str>) -> PyResult<Self> {
let parsed_role = role
.map(|r| {
Role::try_from(r).map_err(|_| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("unknown role: {r}"))
})
})
.transpose()?;
let inner = StreamableParser::new(encoding.inner.clone(), parsed_role)
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))?;
Ok(Self { inner })
}
fn process(&mut self, token: u32) -> PyResult<()> {
self.inner
.process(token)
.map(|_| ())
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))
}
fn process_eos(&mut self) -> PyResult<()> {
self.inner
.process_eos()
.map(|_| ())
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))
}
#[getter]
fn current_content(&self) -> PyResult<String> {
self.inner
.current_content()
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))
}
#[getter]
fn current_role(&self) -> Option<String> {
self.inner.current_role().map(|r| r.as_str().to_string())
}
#[getter]
fn current_content_type(&self) -> Option<String> {
self.inner.current_content_type()
}
#[getter]
fn last_content_delta(&self) -> PyResult<Option<String>> {
self.inner
.last_content_delta()
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))
}
#[getter]
fn messages(&self) -> PyResult<String> {
serde_json::to_string(self.inner.messages()).map_err(|e| {
PyErr::new::<HarmonyError, _>(format!("failed to serialise messages to JSON: {e}"))
})
}
#[getter]
fn tokens(&self) -> Vec<u32> {
self.inner.tokens().to_vec()
}
#[getter]
fn state(&self) -> PyResult<String> {
self.inner
.state_json()
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))
}
#[getter]
fn current_recipient(&self) -> Option<String> {
self.inner.current_recipient()
}
#[getter]
fn current_channel(&self) -> Option<String> {
self.inner.current_channel()
}
}
/// Python module definition.
#[pymodule]
fn openai_harmony(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
// Register the PyHarmonyEncoding class.
m.add_class::<PyHarmonyEncoding>()?;
m.add_class::<PyStreamableParser>()?;
m.add_class::<PyStreamState>()?;
m.add("HarmonyError", _py.get_type::<HarmonyError>())?;
// Convenience function mirroring the Rust-side `load_harmony_encoding` but
// returning an *instance* of `PyHarmonyEncoding`.
#[pyfunction(name = "load_harmony_encoding")]
fn load_harmony_encoding_py(py: Python<'_>, name: &str) -> PyResult<Py<PyHarmonyEncoding>> {
let enc = PyHarmonyEncoding::new(name)?;
Py::new(py, enc)
}
m.add_function(pyo3::wrap_pyfunction!(load_harmony_encoding_py, m)?)?;
// Convenience functions to get the tool configs for the browser and python tools.
#[pyfunction]
fn get_tool_namespace_config(py: Python<'_>, tool: &str) -> PyResult<PyObject> {
let cfg = match tool {
"browser" => ToolNamespaceConfig::browser(),
"python" => ToolNamespaceConfig::python(),
_ => {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Unknown tool namespace: {tool}"
)));
}
};
let py_cfg =
serde_json::to_value(&cfg).map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))?;
let json_str = serde_json::to_string(&py_cfg)
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))?;
let json_mod = PyModule::import(py, "json")?;
let py_obj = json_mod.call_method1("loads", (json_str,))?;
Ok(py_obj.into())
}
m.add_function(pyo3::wrap_pyfunction!(get_tool_namespace_config, m)?)?;
Ok(())
}