mirror of
https://github.com/openai/harmony.git
synced 2025-08-22 16:17:08 -04:00
fix: make HarmonyEncoding
usable concurrently
the `conversation_has_function_tools` atomic bool makes `HarmonyEncoding` stateful
This commit is contained in:
parent
b255cbeb62
commit
d00ac3de49
6 changed files with 130 additions and 34 deletions
|
@ -425,6 +425,10 @@ class RenderConversationConfig(BaseModel):
|
|||
auto_drop_analysis: bool = True
|
||||
|
||||
|
||||
class RenderOptions(BaseModel):
|
||||
conversation_has_function_tools: bool = False
|
||||
|
||||
|
||||
class HarmonyEncoding:
|
||||
"""High-level wrapper around the Rust ``PyHarmonyEncoding`` class."""
|
||||
|
||||
|
@ -498,9 +502,20 @@ class HarmonyEncoding:
|
|||
config=config_dict,
|
||||
)
|
||||
|
||||
def render(self, message: Message) -> List[int]:
|
||||
def render(
|
||||
self, message: Message, render_options: Optional[RenderOptions] = None
|
||||
) -> List[int]:
|
||||
"""Render a single message into tokens."""
|
||||
return self._inner.render(message_json=message.to_json())
|
||||
if render_options is None:
|
||||
render_options_dict = {"conversation_has_function_tools": False}
|
||||
else:
|
||||
render_options_dict = {
|
||||
"conversation_has_function_tools": render_options.conversation_has_function_tools
|
||||
}
|
||||
|
||||
return self._inner.render(
|
||||
message_json=message.to_json(), render_options=render_options_dict
|
||||
)
|
||||
|
||||
# -- Parsing -------------------------------------------------------
|
||||
|
||||
|
|
|
@ -5,10 +5,7 @@ use crate::{
|
|||
use anyhow::Context as _;
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
sync::Arc,
|
||||
vec,
|
||||
};
|
||||
|
||||
|
@ -92,7 +89,6 @@ pub struct HarmonyEncoding {
|
|||
pub(crate) format_token_mapping: HashMap<FormattingToken, String>,
|
||||
pub(crate) stop_formatting_tokens: HashSet<FormattingToken>,
|
||||
pub(crate) stop_formatting_tokens_for_assistant_actions: HashSet<FormattingToken>,
|
||||
pub(crate) conversation_has_function_tools: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for HarmonyEncoding {
|
||||
|
@ -191,8 +187,9 @@ impl HarmonyEncoding {
|
|||
}
|
||||
})
|
||||
});
|
||||
self.conversation_has_function_tools
|
||||
.store(has_function_tools, Ordering::Relaxed);
|
||||
let render_options = RenderOptions {
|
||||
conversation_has_function_tools: has_function_tools,
|
||||
};
|
||||
let last_assistant_is_final = messages
|
||||
.iter()
|
||||
.rev()
|
||||
|
@ -217,9 +214,7 @@ impl HarmonyEncoding {
|
|||
&& first_final_idx.is_some_and(|first| *idx < first)
|
||||
&& msg.channel.as_deref() == Some("analysis"))
|
||||
})
|
||||
.try_for_each(|(_, msg)| self.render_into(msg, into));
|
||||
self.conversation_has_function_tools
|
||||
.store(false, Ordering::Relaxed);
|
||||
.try_for_each(|(_, msg)| self.render_into(msg, into, Some(&render_options)));
|
||||
result?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -305,18 +300,27 @@ impl HarmonyEncoding {
|
|||
}
|
||||
|
||||
/// Render a single message into tokens.
|
||||
pub fn render(&self, message: &Message) -> anyhow::Result<Vec<Rank>> {
|
||||
pub fn render(
|
||||
&self,
|
||||
message: &Message,
|
||||
render_options: Option<&RenderOptions>,
|
||||
) -> anyhow::Result<Vec<Rank>> {
|
||||
let mut out = vec![];
|
||||
Render::<Message>::render(self, message, &mut out)?;
|
||||
Render::<Message>::render(self, message, &mut out, render_options)?;
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
/// Render a single message into the provided buffer.
|
||||
pub fn render_into<B>(&self, message: &Message, into: &mut B) -> anyhow::Result<()>
|
||||
pub fn render_into<B>(
|
||||
&self,
|
||||
message: &Message,
|
||||
into: &mut B,
|
||||
render_options: Option<&RenderOptions>,
|
||||
) -> anyhow::Result<()>
|
||||
where
|
||||
B: Extend<Rank>,
|
||||
{
|
||||
Render::<Message>::render(self, message, into)
|
||||
Render::<Message>::render(self, message, into, render_options)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -772,14 +776,29 @@ impl HarmonyEncoding {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
pub struct RenderOptions {
|
||||
pub conversation_has_function_tools: bool,
|
||||
}
|
||||
|
||||
trait Render<T: ?Sized> {
|
||||
fn render<B>(&self, item: &T, into: &mut B) -> anyhow::Result<()>
|
||||
fn render<B>(
|
||||
&self,
|
||||
item: &T,
|
||||
into: &mut B,
|
||||
render_options: Option<&RenderOptions>,
|
||||
) -> anyhow::Result<()>
|
||||
where
|
||||
B: Extend<Rank>;
|
||||
}
|
||||
|
||||
impl Render<Message> for HarmonyEncoding {
|
||||
fn render<B>(&self, message: &Message, into: &mut B) -> anyhow::Result<()>
|
||||
fn render<B>(
|
||||
&self,
|
||||
message: &Message,
|
||||
into: &mut B,
|
||||
render_options: Option<&RenderOptions>,
|
||||
) -> anyhow::Result<()>
|
||||
where
|
||||
B: Extend<Rank>,
|
||||
{
|
||||
|
@ -836,7 +855,7 @@ impl Render<Message> for HarmonyEncoding {
|
|||
message.author.role
|
||||
);
|
||||
}
|
||||
Render::<Content>::render(self, content, into)?;
|
||||
Render::<Content>::render(self, content, into, render_options)?;
|
||||
}
|
||||
|
||||
// If there is a tool call we should render a tool call token
|
||||
|
@ -851,15 +870,22 @@ impl Render<Message> for HarmonyEncoding {
|
|||
|
||||
// Dispatch Content variants to their specific Render implementations
|
||||
impl Render<Content> for HarmonyEncoding {
|
||||
fn render<B>(&self, content: &Content, into: &mut B) -> anyhow::Result<()>
|
||||
fn render<B>(
|
||||
&self,
|
||||
content: &Content,
|
||||
into: &mut B,
|
||||
render_options: Option<&RenderOptions>,
|
||||
) -> anyhow::Result<()>
|
||||
where
|
||||
B: Extend<Rank>,
|
||||
{
|
||||
match content {
|
||||
Content::Text(text) => Render::<TextContent>::render(self, text, into),
|
||||
Content::SystemContent(sys) => Render::<SystemContent>::render(self, sys, into),
|
||||
Content::Text(text) => Render::<TextContent>::render(self, text, into, render_options),
|
||||
Content::SystemContent(sys) => {
|
||||
Render::<SystemContent>::render(self, sys, into, render_options)
|
||||
}
|
||||
Content::DeveloperContent(dev) => {
|
||||
Render::<crate::chat::DeveloperContent>::render(self, dev, into)
|
||||
Render::<crate::chat::DeveloperContent>::render(self, dev, into, render_options)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -867,7 +893,12 @@ impl Render<Content> for HarmonyEncoding {
|
|||
|
||||
// Render plain text content
|
||||
impl Render<TextContent> for HarmonyEncoding {
|
||||
fn render<B>(&self, text: &TextContent, into: &mut B) -> anyhow::Result<()>
|
||||
fn render<B>(
|
||||
&self,
|
||||
text: &TextContent,
|
||||
into: &mut B,
|
||||
_render_options: Option<&RenderOptions>,
|
||||
) -> anyhow::Result<()>
|
||||
where
|
||||
B: Extend<Rank>,
|
||||
{
|
||||
|
@ -877,7 +908,12 @@ impl Render<TextContent> for HarmonyEncoding {
|
|||
|
||||
// Render system-specific content (model identity, instructions, effort)
|
||||
impl Render<SystemContent> for HarmonyEncoding {
|
||||
fn render<B>(&self, sys: &SystemContent, into: &mut B) -> anyhow::Result<()>
|
||||
fn render<B>(
|
||||
&self,
|
||||
sys: &SystemContent,
|
||||
into: &mut B,
|
||||
render_options: Option<&RenderOptions>,
|
||||
) -> anyhow::Result<()>
|
||||
where
|
||||
B: Extend<Rank>,
|
||||
{
|
||||
|
@ -923,7 +959,7 @@ impl Render<SystemContent> for HarmonyEncoding {
|
|||
if channel_config.channel_required {
|
||||
channels_header.push_str(" Channel must be included for every message.");
|
||||
}
|
||||
if self.conversation_has_function_tools.load(Ordering::Relaxed) {
|
||||
if render_options.is_some_and(|o| o.conversation_has_function_tools) {
|
||||
channels_header.push('\n');
|
||||
channels_header.push_str(
|
||||
"Calls to these tools must go to the commentary channel: 'functions'.",
|
||||
|
@ -940,7 +976,12 @@ impl Render<SystemContent> for HarmonyEncoding {
|
|||
|
||||
// Render developer-specific content (instructions, tools)
|
||||
impl Render<crate::chat::DeveloperContent> for HarmonyEncoding {
|
||||
fn render<B>(&self, dev: &crate::chat::DeveloperContent, into: &mut B) -> anyhow::Result<()>
|
||||
fn render<B>(
|
||||
&self,
|
||||
dev: &crate::chat::DeveloperContent,
|
||||
into: &mut B,
|
||||
_render_options: Option<&RenderOptions>,
|
||||
) -> anyhow::Result<()>
|
||||
where
|
||||
B: Extend<Rank>,
|
||||
{
|
||||
|
|
|
@ -178,13 +178,29 @@ impl PyHarmonyEncoding {
|
|||
}
|
||||
|
||||
/// Render a single message into tokens.
|
||||
fn render(&self, message_json: &str) -> PyResult<Vec<u32>> {
|
||||
fn render(
|
||||
&self,
|
||||
message_json: &str,
|
||||
render_options: Option<Bound<'_, PyDict>>,
|
||||
) -> PyResult<Vec<u32>> {
|
||||
let message: crate::chat::Message = serde_json::from_str(message_json).map_err(|e| {
|
||||
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("invalid message JSON: {e}"))
|
||||
})?;
|
||||
|
||||
let rust_options = if let Some(options_dict) = render_options {
|
||||
let conversation_has_function_tools = options_dict
|
||||
.get_item("conversation_has_function_tools")?
|
||||
.and_then(|v| v.extract().ok())
|
||||
.unwrap_or(false);
|
||||
Some(crate::encoding::RenderOptions {
|
||||
conversation_has_function_tools,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
self.inner
|
||||
.render(&message)
|
||||
.render(&message, rust_options.as_ref())
|
||||
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
sync::{atomic::AtomicBool, Arc},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
|
@ -76,7 +76,6 @@ pub fn load_harmony_encoding(name: HarmonyEncodingName) -> anyhow::Result<Harmon
|
|||
FormattingToken::EndMessageDoneSampling,
|
||||
FormattingToken::EndMessageAssistantToTool,
|
||||
]),
|
||||
conversation_has_function_tools: Arc::new(AtomicBool::new(false)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -525,7 +525,7 @@ fn test_render_and_render_conversation_roundtrip() {
|
|||
let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();
|
||||
let msg = Message::from_role_and_content(Role::User, "Hello");
|
||||
let convo = Conversation::from_messages([msg.clone()]);
|
||||
let tokens_msg = encoding.render(&msg).unwrap();
|
||||
let tokens_msg = encoding.render(&msg, None).unwrap();
|
||||
let tokens_convo = encoding.render_conversation(&convo, None).unwrap();
|
||||
assert_eq!(tokens_msg, tokens_convo);
|
||||
let tokens_completion = encoding
|
||||
|
|
|
@ -18,6 +18,9 @@ extern "C" {
|
|||
|
||||
#[wasm_bindgen(typescript_type = "RenderConversationConfig")]
|
||||
pub type JsRenderConversationConfig;
|
||||
|
||||
#[wasm_bindgen(typescript_type = "RenderOptions")]
|
||||
pub type JsRenderOptions;
|
||||
}
|
||||
|
||||
#[wasm_bindgen(typescript_custom_section)]
|
||||
|
@ -127,12 +130,34 @@ impl JsHarmonyEncoding {
|
|||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub fn render(&self, message: JsMessage) -> Result<Vec<u32>, JsValue> {
|
||||
pub fn render(
|
||||
&self,
|
||||
message: JsMessage,
|
||||
render_options: JsRenderOptions,
|
||||
) -> Result<Vec<u32>, JsValue> {
|
||||
let message: JsValue = message.into();
|
||||
let message: crate::chat::Message = serde_wasm_bindgen::from_value(message)
|
||||
.map_err(|e| JsValue::from_str(&format!("invalid message JSON: {e}")))?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct RenderOptions {
|
||||
conversation_has_function_tools: Option<bool>,
|
||||
}
|
||||
let render_options: JsValue = render_options.into();
|
||||
let rust_options = if render_options.is_undefined() || render_options.is_null() {
|
||||
None
|
||||
} else {
|
||||
let cfg: RenderOptions = serde_wasm_bindgen::from_value(render_options)
|
||||
.map_err(|e| JsValue::from_str(&format!("invalid render options: {e}")))?;
|
||||
Some(crate::encoding::RenderOptions {
|
||||
conversation_has_function_tools: cfg
|
||||
.conversation_has_function_tools
|
||||
.unwrap_or(false),
|
||||
})
|
||||
};
|
||||
|
||||
self.inner
|
||||
.render(&message)
|
||||
.render(&message, rust_options.as_ref())
|
||||
.map_err(|e| JsValue::from_str(&e.to_string()))
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue