mirror of
https://github.com/openai/harmony.git
synced 2025-09-03 06:25:48 -04:00
Initial commit
Co-authored-by: scott-oai <142930063+scott-oai@users.noreply.github.com> Co-authored-by: Zhuohan Li <zhuohan@openai.com>
This commit is contained in:
commit
253cdca537
70 changed files with 15013 additions and 0 deletions
405
src/py_module.rs
Normal file
405
src/py_module.rs
Normal file
|
@ -0,0 +1,405 @@
|
|||
//! Python bindings for the harmony crate.
|
||||
//!
|
||||
//! The bindings are kept intentionally small: we expose the `HarmonyEncoding` type
|
||||
//! together with the operations that are required by the original Rust test
|
||||
//! suite (rendering a conversation for completion, parsing messages from
|
||||
//! completion tokens and decoding tokens back into UTF-8). All higher-level
|
||||
//! data-structures (Conversation, Message, SystemContent, DeveloperContent, …) are passed across the FFI
|
||||
//! boundary as JSON. This allows us to keep the Rust ↔ Python interface very
|
||||
//! light-weight while still re-using the exact same logic that is implemented
|
||||
//! in Rust.
|
||||
//!
|
||||
//! A thin, typed, user-facing Python wrapper around these low-level bindings is
|
||||
//! provided in `harmony/__init__.py`.
|
||||
|
||||
// Only compile when the `python-binding` feature is enabled.
|
||||
#![cfg(feature = "python-binding")]
|
||||
|
||||
use pyo3::prelude::*;
|
||||
|
||||
// We need the `Python` type later on.
|
||||
use pyo3::create_exception;
|
||||
use pyo3::exceptions::PyRuntimeError;
|
||||
use pyo3::Python;
|
||||
|
||||
use pyo3::types::{PyAny, PyDict, PyModule};
|
||||
use pyo3::Bound;
|
||||
|
||||
// Define a custom Python exception so users can catch Harmony specific errors.
|
||||
create_exception!(openai_harmony, HarmonyError, PyRuntimeError);
|
||||
|
||||
use crate::{
|
||||
chat::{Message, Role, ToolNamespaceConfig},
|
||||
encoding::{HarmonyEncoding, StreamableParser},
|
||||
load_harmony_encoding, HarmonyEncodingName,
|
||||
};
|
||||
|
||||
use serde_json;
|
||||
|
||||
/// A thin PyO3 wrapper around the Rust `HarmonyEncoding` struct.
|
||||
#[pyclass]
|
||||
struct PyHarmonyEncoding {
|
||||
inner: HarmonyEncoding,
|
||||
}
|
||||
|
||||
/// Streaming parser exposed to Python.
|
||||
#[pyclass]
|
||||
struct PyStreamableParser {
|
||||
inner: StreamableParser,
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
pub enum PyStreamState {
|
||||
ExpectStart,
|
||||
Header,
|
||||
Content,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyHarmonyEncoding {
|
||||
/// Create a new `HarmonyEncoding` by name.
|
||||
#[new]
|
||||
fn new(name: &str) -> PyResult<Self> {
|
||||
let parsed: HarmonyEncodingName = name
|
||||
.parse::<HarmonyEncodingName>()
|
||||
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
|
||||
let encoding = load_harmony_encoding(parsed)
|
||||
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))?;
|
||||
Ok(Self { inner: encoding })
|
||||
}
|
||||
|
||||
/// Return the name of the encoding.
|
||||
#[getter]
|
||||
fn name(&self) -> &str {
|
||||
self.inner.name()
|
||||
}
|
||||
|
||||
/// Render a conversation (in JSON format) for completion.
|
||||
///
|
||||
/// Parameters
|
||||
/// ----------
|
||||
/// conversation_json : str
|
||||
/// A JSON encoded `Conversation` (as produced by `serde_json`).
|
||||
/// next_turn_role : str
|
||||
/// The role of the *next* turn (e.g. "assistant").
|
||||
/// config : dict (optional)
|
||||
/// Optional config dict. Only supports 'auto_drop_analysis' (bool).
|
||||
///
|
||||
/// Returns
|
||||
/// -------
|
||||
/// List[int]
|
||||
/// The encoded token sequence.
|
||||
fn render_conversation_for_completion(
|
||||
&self,
|
||||
conversation_json: &str,
|
||||
next_turn_role: &str,
|
||||
config: Option<Bound<'_, PyDict>>,
|
||||
) -> PyResult<Vec<u32>> {
|
||||
// Deserialize the conversation first.
|
||||
let conversation: crate::chat::Conversation = serde_json::from_str(conversation_json)
|
||||
.map_err(|e| {
|
||||
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
|
||||
"invalid conversation JSON: {e}"
|
||||
))
|
||||
})?;
|
||||
|
||||
// Convert the role string into the `Role` enum.
|
||||
let role = Role::try_from(next_turn_role).map_err(|_| {
|
||||
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
|
||||
"unknown role: {next_turn_role}"
|
||||
))
|
||||
})?;
|
||||
|
||||
// Parse config
|
||||
let rust_config = if let Some(cfg_dict) = config {
|
||||
let auto_drop_analysis = cfg_dict
|
||||
.get_item("auto_drop_analysis")?
|
||||
.and_then(|v| v.extract().ok())
|
||||
.unwrap_or(true);
|
||||
Some(crate::encoding::RenderConversationConfig { auto_drop_analysis })
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
self.inner
|
||||
.render_conversation_for_completion(&conversation, role, rust_config.as_ref())
|
||||
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))
|
||||
}
|
||||
|
||||
/// Render a conversation without appending a new role.
|
||||
fn render_conversation(
|
||||
&self,
|
||||
conversation_json: &str,
|
||||
config: Option<Bound<'_, PyDict>>,
|
||||
) -> PyResult<Vec<u32>> {
|
||||
let conversation: crate::chat::Conversation = serde_json::from_str(conversation_json)
|
||||
.map_err(|e| {
|
||||
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
|
||||
"invalid conversation JSON: {e}"
|
||||
))
|
||||
})?;
|
||||
|
||||
let rust_config = if let Some(cfg_dict) = config {
|
||||
let auto_drop_analysis = cfg_dict
|
||||
.get_item("auto_drop_analysis")?
|
||||
.and_then(|v| v.extract().ok())
|
||||
.unwrap_or(true);
|
||||
Some(crate::encoding::RenderConversationConfig { auto_drop_analysis })
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
self.inner
|
||||
.render_conversation(&conversation, rust_config.as_ref())
|
||||
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
|
||||
}
|
||||
|
||||
/// Render a conversation for training.
|
||||
fn render_conversation_for_training(
|
||||
&self,
|
||||
conversation_json: &str,
|
||||
config: Option<Bound<'_, PyDict>>,
|
||||
) -> PyResult<Vec<u32>> {
|
||||
let conversation: crate::chat::Conversation = serde_json::from_str(conversation_json)
|
||||
.map_err(|e| {
|
||||
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
|
||||
"invalid conversation JSON: {e}"
|
||||
))
|
||||
})?;
|
||||
|
||||
let rust_config = if let Some(cfg_dict) = config {
|
||||
let auto_drop_analysis = cfg_dict
|
||||
.get_item("auto_drop_analysis")?
|
||||
.and_then(|v| v.extract().ok())
|
||||
.unwrap_or(true);
|
||||
Some(crate::encoding::RenderConversationConfig { auto_drop_analysis })
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
self.inner
|
||||
.render_conversation_for_training(&conversation, rust_config.as_ref())
|
||||
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
|
||||
}
|
||||
|
||||
/// Render a single message into tokens.
|
||||
fn render(&self, message_json: &str) -> PyResult<Vec<u32>> {
|
||||
let message: crate::chat::Message = serde_json::from_str(message_json).map_err(|e| {
|
||||
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("invalid message JSON: {e}"))
|
||||
})?;
|
||||
|
||||
self.inner
|
||||
.render(&message)
|
||||
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
|
||||
}
|
||||
|
||||
/// Given a list of completion tokens, parse them back into a sequence of
|
||||
/// messages. The result is returned as a JSON string which can be
|
||||
/// deserialised on the Python side.
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
fn parse_messages_from_completion_tokens(
|
||||
&self,
|
||||
tokens: Vec<u32>,
|
||||
role: Option<&str>,
|
||||
) -> PyResult<String> {
|
||||
let role_parsed = if let Some(r) = role {
|
||||
Some(Role::try_from(r).map_err(|_| {
|
||||
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("unknown role: {r}"))
|
||||
})?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let messages: Vec<Message> = self
|
||||
.inner
|
||||
.parse_messages_from_completion_tokens(tokens, role_parsed)
|
||||
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))?;
|
||||
|
||||
serde_json::to_string(&messages).map_err(|e| {
|
||||
PyErr::new::<HarmonyError, _>(format!("failed to serialise messages to JSON: {e}"))
|
||||
})
|
||||
}
|
||||
|
||||
/// Decode a sequence of tokens into text using the underlying tokenizer.
|
||||
fn decode_utf8(&self, tokens: Vec<u32>) -> PyResult<String> {
|
||||
self.inner
|
||||
.tokenizer()
|
||||
.decode_utf8(tokens)
|
||||
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))
|
||||
}
|
||||
|
||||
/// Decode a sequence of tokens into raw bytes using the underlying tokenizer.
|
||||
fn decode_bytes(&self, tokens: Vec<u32>) -> PyResult<Vec<u8>> {
|
||||
self.inner
|
||||
.tokenizer()
|
||||
.decode_bytes(tokens)
|
||||
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))
|
||||
}
|
||||
|
||||
/// Encode text into tokens using the underlying tokenizer with a set of allowed special tokens.
|
||||
fn encode(&self, text: &str, allowed_special: Option<Bound<'_, PyAny>>) -> PyResult<Vec<u32>> {
|
||||
let allowed_vec: Vec<String> = match allowed_special {
|
||||
Some(obj) => obj.extract::<Vec<String>>().map_err(|e| {
|
||||
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
|
||||
"invalid allowed_special: {e}"
|
||||
))
|
||||
})?,
|
||||
None => Vec::new(),
|
||||
};
|
||||
let allowed_set: std::collections::HashSet<&str> =
|
||||
allowed_vec.iter().map(|s| s.as_str()).collect();
|
||||
Ok(self.inner.tokenizer().encode(text, &allowed_set).0)
|
||||
}
|
||||
|
||||
/// Return the list of special tokens for this tokenizer.
|
||||
fn special_tokens(&self) -> Vec<String> {
|
||||
self.inner
|
||||
.tokenizer()
|
||||
.special_tokens()
|
||||
.into_iter()
|
||||
.map(str::to_string)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Check whether a token id corresponds to a special token.
|
||||
fn is_special_token(&self, token: u32) -> bool {
|
||||
self.inner.tokenizer().is_special_token(token)
|
||||
}
|
||||
|
||||
/// Return the stop tokens for the encoding.
|
||||
fn stop_tokens(&self) -> PyResult<Vec<u32>> {
|
||||
self.inner
|
||||
.stop_tokens()
|
||||
.map(|set| set.into_iter().collect())
|
||||
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))
|
||||
}
|
||||
|
||||
/// Return the stop tokens for assistant actions.
|
||||
fn stop_tokens_for_assistant_actions(&self) -> PyResult<Vec<u32>> {
|
||||
self.inner
|
||||
.stop_tokens_for_assistant_actions()
|
||||
.map(|set| set.into_iter().collect())
|
||||
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyStreamableParser {
|
||||
#[new]
|
||||
fn new(encoding: &PyHarmonyEncoding, role: Option<&str>) -> PyResult<Self> {
|
||||
let parsed_role = role
|
||||
.map(|r| {
|
||||
Role::try_from(r).map_err(|_| {
|
||||
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("unknown role: {r}"))
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
let inner = StreamableParser::new(encoding.inner.clone(), parsed_role)
|
||||
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))?;
|
||||
Ok(Self { inner })
|
||||
}
|
||||
|
||||
fn process(&mut self, token: u32) -> PyResult<()> {
|
||||
self.inner
|
||||
.process(token)
|
||||
.map(|_| ())
|
||||
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn current_content(&self) -> PyResult<String> {
|
||||
self.inner
|
||||
.current_content()
|
||||
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn current_role(&self) -> Option<String> {
|
||||
self.inner.current_role().map(|r| r.as_str().to_string())
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn current_content_type(&self) -> Option<String> {
|
||||
self.inner.current_content_type()
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn last_content_delta(&self) -> PyResult<Option<String>> {
|
||||
self.inner
|
||||
.last_content_delta()
|
||||
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn messages(&self) -> PyResult<String> {
|
||||
serde_json::to_string(self.inner.messages()).map_err(|e| {
|
||||
PyErr::new::<HarmonyError, _>(format!("failed to serialise messages to JSON: {e}"))
|
||||
})
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn tokens(&self) -> Vec<u32> {
|
||||
self.inner.tokens().to_vec()
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn state(&self) -> PyResult<String> {
|
||||
self.inner
|
||||
.state_json()
|
||||
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn current_recipient(&self) -> Option<String> {
|
||||
self.inner.current_recipient()
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn current_channel(&self) -> Option<String> {
|
||||
self.inner.current_channel()
|
||||
}
|
||||
}
|
||||
|
||||
/// Python module definition.
|
||||
#[pymodule]
|
||||
fn openai_harmony(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
// Register the PyHarmonyEncoding class.
|
||||
m.add_class::<PyHarmonyEncoding>()?;
|
||||
m.add_class::<PyStreamableParser>()?;
|
||||
m.add_class::<PyStreamState>()?;
|
||||
m.add("HarmonyError", _py.get_type::<HarmonyError>())?;
|
||||
|
||||
// Convenience function mirroring the Rust-side `load_harmony_encoding` but
|
||||
// returning an *instance* of `PyHarmonyEncoding`.
|
||||
#[pyfunction(name = "load_harmony_encoding")]
|
||||
fn load_harmony_encoding_py(py: Python<'_>, name: &str) -> PyResult<Py<PyHarmonyEncoding>> {
|
||||
let enc = PyHarmonyEncoding::new(name)?;
|
||||
Py::new(py, enc)
|
||||
}
|
||||
m.add_function(pyo3::wrap_pyfunction!(load_harmony_encoding_py, m)?)?;
|
||||
|
||||
// Convenience functions to get the tool configs for the browser and python tools.
|
||||
#[pyfunction]
|
||||
fn get_tool_namespace_config(py: Python<'_>, tool: &str) -> PyResult<PyObject> {
|
||||
let cfg = match tool {
|
||||
"browser" => ToolNamespaceConfig::browser(),
|
||||
"python" => ToolNamespaceConfig::python(),
|
||||
_ => {
|
||||
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
|
||||
"Unknown tool namespace: {}",
|
||||
tool
|
||||
)));
|
||||
}
|
||||
};
|
||||
let py_cfg =
|
||||
serde_json::to_value(&cfg).map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))?;
|
||||
let json_str = serde_json::to_string(&py_cfg)
|
||||
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))?;
|
||||
let json_mod = PyModule::import(py, "json")?;
|
||||
let py_obj = json_mod.call_method1("loads", (json_str,))?;
|
||||
Ok(py_obj.into())
|
||||
}
|
||||
m.add_function(pyo3::wrap_pyfunction!(get_tool_namespace_config, m)?)?;
|
||||
|
||||
Ok(())
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue