Merge pull request #26 from jordan-wu-97/jordan/fix-function-call-atomic-bool

fix: make `HarmonyEncoding` usable concurrently
This commit is contained in:
Scott Lessans 2025-08-05 17:12:50 -07:00 committed by GitHub
commit 82b3afb9eb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 130 additions and 34 deletions

View file

@ -425,6 +425,10 @@ class RenderConversationConfig(BaseModel):
auto_drop_analysis: bool = True
class RenderOptions(BaseModel):
conversation_has_function_tools: bool = False
class HarmonyEncoding:
"""High-level wrapper around the Rust ``PyHarmonyEncoding`` class."""
@ -498,9 +502,20 @@ class HarmonyEncoding:
config=config_dict,
)
def render(self, message: Message) -> List[int]:
def render(
self, message: Message, render_options: Optional[RenderOptions] = None
) -> List[int]:
"""Render a single message into tokens."""
return self._inner.render(message_json=message.to_json())
if render_options is None:
render_options_dict = {"conversation_has_function_tools": False}
else:
render_options_dict = {
"conversation_has_function_tools": render_options.conversation_has_function_tools
}
return self._inner.render(
message_json=message.to_json(), render_options=render_options_dict
)
# -- Parsing -------------------------------------------------------

View file

@ -5,10 +5,7 @@ use crate::{
use anyhow::Context as _;
use std::{
collections::{HashMap, HashSet},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
sync::Arc,
vec,
};
@ -92,7 +89,6 @@ pub struct HarmonyEncoding {
pub(crate) format_token_mapping: HashMap<FormattingToken, String>,
pub(crate) stop_formatting_tokens: HashSet<FormattingToken>,
pub(crate) stop_formatting_tokens_for_assistant_actions: HashSet<FormattingToken>,
pub(crate) conversation_has_function_tools: Arc<AtomicBool>,
}
impl std::fmt::Debug for HarmonyEncoding {
@ -191,8 +187,9 @@ impl HarmonyEncoding {
}
})
});
self.conversation_has_function_tools
.store(has_function_tools, Ordering::Relaxed);
let render_options = RenderOptions {
conversation_has_function_tools: has_function_tools,
};
let last_assistant_is_final = messages
.iter()
.rev()
@ -217,9 +214,7 @@ impl HarmonyEncoding {
&& first_final_idx.is_some_and(|first| *idx < first)
&& msg.channel.as_deref() == Some("analysis"))
})
.try_for_each(|(_, msg)| self.render_into(msg, into));
self.conversation_has_function_tools
.store(false, Ordering::Relaxed);
.try_for_each(|(_, msg)| self.render_into(msg, into, Some(&render_options)));
result?;
Ok(())
}
@ -305,18 +300,27 @@ impl HarmonyEncoding {
}
/// Render a single message into tokens.
pub fn render(&self, message: &Message) -> anyhow::Result<Vec<Rank>> {
pub fn render(
&self,
message: &Message,
render_options: Option<&RenderOptions>,
) -> anyhow::Result<Vec<Rank>> {
let mut out = vec![];
Render::<Message>::render(self, message, &mut out)?;
Render::<Message>::render(self, message, &mut out, render_options)?;
Ok(out)
}
/// Render a single message into the provided buffer.
pub fn render_into<B>(&self, message: &Message, into: &mut B) -> anyhow::Result<()>
pub fn render_into<B>(
&self,
message: &Message,
into: &mut B,
render_options: Option<&RenderOptions>,
) -> anyhow::Result<()>
where
B: Extend<Rank>,
{
Render::<Message>::render(self, message, into)
Render::<Message>::render(self, message, into, render_options)
}
}
@ -772,14 +776,29 @@ impl HarmonyEncoding {
}
}
#[derive(Clone, Copy, Debug, Default)]
pub struct RenderOptions {
pub conversation_has_function_tools: bool,
}
trait Render<T: ?Sized> {
fn render<B>(&self, item: &T, into: &mut B) -> anyhow::Result<()>
fn render<B>(
&self,
item: &T,
into: &mut B,
render_options: Option<&RenderOptions>,
) -> anyhow::Result<()>
where
B: Extend<Rank>;
}
impl Render<Message> for HarmonyEncoding {
fn render<B>(&self, message: &Message, into: &mut B) -> anyhow::Result<()>
fn render<B>(
&self,
message: &Message,
into: &mut B,
render_options: Option<&RenderOptions>,
) -> anyhow::Result<()>
where
B: Extend<Rank>,
{
@ -836,7 +855,7 @@ impl Render<Message> for HarmonyEncoding {
message.author.role
);
}
Render::<Content>::render(self, content, into)?;
Render::<Content>::render(self, content, into, render_options)?;
}
// If there is a tool call we should render a tool call token
@ -851,15 +870,22 @@ impl Render<Message> for HarmonyEncoding {
// Dispatch Content variants to their specific Render implementations
impl Render<Content> for HarmonyEncoding {
fn render<B>(&self, content: &Content, into: &mut B) -> anyhow::Result<()>
fn render<B>(
&self,
content: &Content,
into: &mut B,
render_options: Option<&RenderOptions>,
) -> anyhow::Result<()>
where
B: Extend<Rank>,
{
match content {
Content::Text(text) => Render::<TextContent>::render(self, text, into),
Content::SystemContent(sys) => Render::<SystemContent>::render(self, sys, into),
Content::Text(text) => Render::<TextContent>::render(self, text, into, render_options),
Content::SystemContent(sys) => {
Render::<SystemContent>::render(self, sys, into, render_options)
}
Content::DeveloperContent(dev) => {
Render::<crate::chat::DeveloperContent>::render(self, dev, into)
Render::<crate::chat::DeveloperContent>::render(self, dev, into, render_options)
}
}
}
@ -867,7 +893,12 @@ impl Render<Content> for HarmonyEncoding {
// Render plain text content
impl Render<TextContent> for HarmonyEncoding {
fn render<B>(&self, text: &TextContent, into: &mut B) -> anyhow::Result<()>
fn render<B>(
&self,
text: &TextContent,
into: &mut B,
_render_options: Option<&RenderOptions>,
) -> anyhow::Result<()>
where
B: Extend<Rank>,
{
@ -877,7 +908,12 @@ impl Render<TextContent> for HarmonyEncoding {
// Render system-specific content (model identity, instructions, effort)
impl Render<SystemContent> for HarmonyEncoding {
fn render<B>(&self, sys: &SystemContent, into: &mut B) -> anyhow::Result<()>
fn render<B>(
&self,
sys: &SystemContent,
into: &mut B,
render_options: Option<&RenderOptions>,
) -> anyhow::Result<()>
where
B: Extend<Rank>,
{
@ -923,7 +959,7 @@ impl Render<SystemContent> for HarmonyEncoding {
if channel_config.channel_required {
channels_header.push_str(" Channel must be included for every message.");
}
if self.conversation_has_function_tools.load(Ordering::Relaxed) {
if render_options.is_some_and(|o| o.conversation_has_function_tools) {
channels_header.push('\n');
channels_header.push_str(
"Calls to these tools must go to the commentary channel: 'functions'.",
@ -940,7 +976,12 @@ impl Render<SystemContent> for HarmonyEncoding {
// Render developer-specific content (instructions, tools)
impl Render<crate::chat::DeveloperContent> for HarmonyEncoding {
fn render<B>(&self, dev: &crate::chat::DeveloperContent, into: &mut B) -> anyhow::Result<()>
fn render<B>(
&self,
dev: &crate::chat::DeveloperContent,
into: &mut B,
_render_options: Option<&RenderOptions>,
) -> anyhow::Result<()>
where
B: Extend<Rank>,
{

View file

@ -178,13 +178,29 @@ impl PyHarmonyEncoding {
}
/// Render a single message into tokens.
fn render(&self, message_json: &str) -> PyResult<Vec<u32>> {
fn render(
&self,
message_json: &str,
render_options: Option<Bound<'_, PyDict>>,
) -> PyResult<Vec<u32>> {
let message: crate::chat::Message = serde_json::from_str(message_json).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("invalid message JSON: {e}"))
})?;
let rust_options = if let Some(options_dict) = render_options {
let conversation_has_function_tools = options_dict
.get_item("conversation_has_function_tools")?
.and_then(|v| v.extract().ok())
.unwrap_or(false);
Some(crate::encoding::RenderOptions {
conversation_has_function_tools,
})
} else {
None
};
self.inner
.render(&message)
.render(&message, rust_options.as_ref())
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
}

View file

@ -1,6 +1,6 @@
use std::{
collections::{HashMap, HashSet},
sync::{atomic::AtomicBool, Arc},
sync::Arc,
};
use crate::{
@ -76,7 +76,6 @@ pub fn load_harmony_encoding(name: HarmonyEncodingName) -> anyhow::Result<Harmon
FormattingToken::EndMessageDoneSampling,
FormattingToken::EndMessageAssistantToTool,
]),
conversation_has_function_tools: Arc::new(AtomicBool::new(false)),
})
}
}

View file

@ -525,7 +525,7 @@ fn test_render_and_render_conversation_roundtrip() {
let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();
let msg = Message::from_role_and_content(Role::User, "Hello");
let convo = Conversation::from_messages([msg.clone()]);
let tokens_msg = encoding.render(&msg).unwrap();
let tokens_msg = encoding.render(&msg, None).unwrap();
let tokens_convo = encoding.render_conversation(&convo, None).unwrap();
assert_eq!(tokens_msg, tokens_convo);
let tokens_completion = encoding

View file

@ -18,6 +18,9 @@ extern "C" {
#[wasm_bindgen(typescript_type = "RenderConversationConfig")]
pub type JsRenderConversationConfig;
#[wasm_bindgen(typescript_type = "RenderOptions")]
pub type JsRenderOptions;
}
#[wasm_bindgen(typescript_custom_section)]
@ -127,12 +130,34 @@ impl JsHarmonyEncoding {
}
#[wasm_bindgen]
pub fn render(&self, message: JsMessage) -> Result<Vec<u32>, JsValue> {
pub fn render(
&self,
message: JsMessage,
render_options: JsRenderOptions,
) -> Result<Vec<u32>, JsValue> {
let message: JsValue = message.into();
let message: crate::chat::Message = serde_wasm_bindgen::from_value(message)
.map_err(|e| JsValue::from_str(&format!("invalid message JSON: {e}")))?;
#[derive(Deserialize)]
struct RenderOptions {
conversation_has_function_tools: Option<bool>,
}
let render_options: JsValue = render_options.into();
let rust_options = if render_options.is_undefined() || render_options.is_null() {
None
} else {
let cfg: RenderOptions = serde_wasm_bindgen::from_value(render_options)
.map_err(|e| JsValue::from_str(&format!("invalid render options: {e}")))?;
Some(crate::encoding::RenderOptions {
conversation_has_function_tools: cfg
.conversation_has_function_tools
.unwrap_or(false),
})
};
self.inner
.render(&message)
.render(&message, rust_options.as_ref())
.map_err(|e| JsValue::from_str(&e.to_string()))
}