From d00ac3de4947868ecb00f8a0c7652ccdacb493d8 Mon Sep 17 00:00:00 2001 From: Jordan Wu <101218661+jordan-definitive@users.noreply.github.com> Date: Tue, 5 Aug 2025 14:52:12 -0700 Subject: [PATCH] fix: make `HarmonyEncoding` usable concurrently the `conversation_has_function_tools` atomic bool makes `HarmonyEncoding` stateful --- python/openai_harmony/__init__.py | 19 ++++++- src/encoding.rs | 91 ++++++++++++++++++++++--------- src/py_module.rs | 20 ++++++- src/registry.rs | 3 +- src/tests.rs | 2 +- src/wasm_module.rs | 29 +++++++++- 6 files changed, 130 insertions(+), 34 deletions(-) diff --git a/python/openai_harmony/__init__.py b/python/openai_harmony/__init__.py index 3485864..13b5fdd 100644 --- a/python/openai_harmony/__init__.py +++ b/python/openai_harmony/__init__.py @@ -425,6 +425,10 @@ class RenderConversationConfig(BaseModel): auto_drop_analysis: bool = True +class RenderOptions(BaseModel): + conversation_has_function_tools: bool = False + + class HarmonyEncoding: """High-level wrapper around the Rust ``PyHarmonyEncoding`` class.""" @@ -498,9 +502,20 @@ class HarmonyEncoding: config=config_dict, ) - def render(self, message: Message) -> List[int]: + def render( + self, message: Message, render_options: Optional[RenderOptions] = None + ) -> List[int]: """Render a single message into tokens.""" - return self._inner.render(message_json=message.to_json()) + if render_options is None: + render_options_dict = {"conversation_has_function_tools": False} + else: + render_options_dict = { + "conversation_has_function_tools": render_options.conversation_has_function_tools + } + + return self._inner.render( + message_json=message.to_json(), render_options=render_options_dict + ) # -- Parsing ------------------------------------------------------- diff --git a/src/encoding.rs b/src/encoding.rs index c58e8b8..afe1fce 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -5,10 +5,7 @@ use crate::{ use anyhow::Context as _; use std::{ collections::{HashMap, HashSet}, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, + sync::Arc, vec, }; @@ -92,7 +89,6 @@ pub struct HarmonyEncoding { pub(crate) format_token_mapping: HashMap, pub(crate) stop_formatting_tokens: HashSet, pub(crate) stop_formatting_tokens_for_assistant_actions: HashSet, - pub(crate) conversation_has_function_tools: Arc, } impl std::fmt::Debug for HarmonyEncoding { @@ -191,8 +187,9 @@ impl HarmonyEncoding { } }) }); - self.conversation_has_function_tools - .store(has_function_tools, Ordering::Relaxed); + let render_options = RenderOptions { + conversation_has_function_tools: has_function_tools, + }; let last_assistant_is_final = messages .iter() .rev() @@ -217,9 +214,7 @@ impl HarmonyEncoding { && first_final_idx.is_some_and(|first| *idx < first) && msg.channel.as_deref() == Some("analysis")) }) - .try_for_each(|(_, msg)| self.render_into(msg, into)); - self.conversation_has_function_tools - .store(false, Ordering::Relaxed); + .try_for_each(|(_, msg)| self.render_into(msg, into, Some(&render_options))); result?; Ok(()) } @@ -305,18 +300,27 @@ impl HarmonyEncoding { } /// Render a single message into tokens. - pub fn render(&self, message: &Message) -> anyhow::Result> { + pub fn render( + &self, + message: &Message, + render_options: Option<&RenderOptions>, + ) -> anyhow::Result> { let mut out = vec![]; - Render::::render(self, message, &mut out)?; + Render::::render(self, message, &mut out, render_options)?; Ok(out) } /// Render a single message into the provided buffer. - pub fn render_into(&self, message: &Message, into: &mut B) -> anyhow::Result<()> + pub fn render_into( + &self, + message: &Message, + into: &mut B, + render_options: Option<&RenderOptions>, + ) -> anyhow::Result<()> where B: Extend, { - Render::::render(self, message, into) + Render::::render(self, message, into, render_options) } } @@ -772,14 +776,29 @@ impl HarmonyEncoding { } } +#[derive(Clone, Copy, Debug, Default)] +pub struct RenderOptions { + pub conversation_has_function_tools: bool, +} + trait Render { - fn render(&self, item: &T, into: &mut B) -> anyhow::Result<()> + fn render( + &self, + item: &T, + into: &mut B, + render_options: Option<&RenderOptions>, + ) -> anyhow::Result<()> where B: Extend; } impl Render for HarmonyEncoding { - fn render(&self, message: &Message, into: &mut B) -> anyhow::Result<()> + fn render( + &self, + message: &Message, + into: &mut B, + render_options: Option<&RenderOptions>, + ) -> anyhow::Result<()> where B: Extend, { @@ -836,7 +855,7 @@ impl Render for HarmonyEncoding { message.author.role ); } - Render::::render(self, content, into)?; + Render::::render(self, content, into, render_options)?; } // If there is a tool call we should render a tool call token @@ -851,15 +870,22 @@ impl Render for HarmonyEncoding { // Dispatch Content variants to their specific Render implementations impl Render for HarmonyEncoding { - fn render(&self, content: &Content, into: &mut B) -> anyhow::Result<()> + fn render( + &self, + content: &Content, + into: &mut B, + render_options: Option<&RenderOptions>, + ) -> anyhow::Result<()> where B: Extend, { match content { - Content::Text(text) => Render::::render(self, text, into), - Content::SystemContent(sys) => Render::::render(self, sys, into), + Content::Text(text) => Render::::render(self, text, into, render_options), + Content::SystemContent(sys) => { + Render::::render(self, sys, into, render_options) + } Content::DeveloperContent(dev) => { - Render::::render(self, dev, into) + Render::::render(self, dev, into, render_options) } } } @@ -867,7 +893,12 @@ impl Render for HarmonyEncoding { // Render plain text content impl Render for HarmonyEncoding { - fn render(&self, text: &TextContent, into: &mut B) -> anyhow::Result<()> + fn render( + &self, + text: &TextContent, + into: &mut B, + _render_options: Option<&RenderOptions>, + ) -> anyhow::Result<()> where B: Extend, { @@ -877,7 +908,12 @@ impl Render for HarmonyEncoding { // Render system-specific content (model identity, instructions, effort) impl Render for HarmonyEncoding { - fn render(&self, sys: &SystemContent, into: &mut B) -> anyhow::Result<()> + fn render( + &self, + sys: &SystemContent, + into: &mut B, + render_options: Option<&RenderOptions>, + ) -> anyhow::Result<()> where B: Extend, { @@ -923,7 +959,7 @@ impl Render for HarmonyEncoding { if channel_config.channel_required { channels_header.push_str(" Channel must be included for every message."); } - if self.conversation_has_function_tools.load(Ordering::Relaxed) { + if render_options.is_some_and(|o| o.conversation_has_function_tools) { channels_header.push('\n'); channels_header.push_str( "Calls to these tools must go to the commentary channel: 'functions'.", @@ -940,7 +976,12 @@ impl Render for HarmonyEncoding { // Render developer-specific content (instructions, tools) impl Render for HarmonyEncoding { - fn render(&self, dev: &crate::chat::DeveloperContent, into: &mut B) -> anyhow::Result<()> + fn render( + &self, + dev: &crate::chat::DeveloperContent, + into: &mut B, + _render_options: Option<&RenderOptions>, + ) -> anyhow::Result<()> where B: Extend, { diff --git a/src/py_module.rs b/src/py_module.rs index e7bb9e5..c5c7b0a 100644 --- a/src/py_module.rs +++ b/src/py_module.rs @@ -178,13 +178,29 @@ impl PyHarmonyEncoding { } /// Render a single message into tokens. - fn render(&self, message_json: &str) -> PyResult> { + fn render( + &self, + message_json: &str, + render_options: Option>, + ) -> PyResult> { let message: crate::chat::Message = serde_json::from_str(message_json).map_err(|e| { PyErr::new::(format!("invalid message JSON: {e}")) })?; + let rust_options = if let Some(options_dict) = render_options { + let conversation_has_function_tools = options_dict + .get_item("conversation_has_function_tools")? + .and_then(|v| v.extract().ok()) + .unwrap_or(false); + Some(crate::encoding::RenderOptions { + conversation_has_function_tools, + }) + } else { + None + }; + self.inner - .render(&message) + .render(&message, rust_options.as_ref()) .map_err(|e| PyErr::new::(e.to_string())) } diff --git a/src/registry.rs b/src/registry.rs index 6d8a98f..d1ffd2e 100644 --- a/src/registry.rs +++ b/src/registry.rs @@ -1,6 +1,6 @@ use std::{ collections::{HashMap, HashSet}, - sync::{atomic::AtomicBool, Arc}, + sync::Arc, }; use crate::{ @@ -76,7 +76,6 @@ pub fn load_harmony_encoding(name: HarmonyEncodingName) -> anyhow::Result Result, JsValue> { + pub fn render( + &self, + message: JsMessage, + render_options: JsRenderOptions, + ) -> Result, JsValue> { let message: JsValue = message.into(); let message: crate::chat::Message = serde_wasm_bindgen::from_value(message) .map_err(|e| JsValue::from_str(&format!("invalid message JSON: {e}")))?; + + #[derive(Deserialize)] + struct RenderOptions { + conversation_has_function_tools: Option, + } + let render_options: JsValue = render_options.into(); + let rust_options = if render_options.is_undefined() || render_options.is_null() { + None + } else { + let cfg: RenderOptions = serde_wasm_bindgen::from_value(render_options) + .map_err(|e| JsValue::from_str(&format!("invalid render options: {e}")))?; + Some(crate::encoding::RenderOptions { + conversation_has_function_tools: cfg + .conversation_has_function_tools + .unwrap_or(false), + }) + }; + self.inner - .render(&message) + .render(&message, rust_options.as_ref()) .map_err(|e| JsValue::from_str(&e.to_string())) }