mirror of
https://github.com/openai/harmony.git
synced 2025-08-23 01:17:09 -04:00
fix: make HarmonyEncoding usable concurrently
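The `conversation_has_function_tools` atomic bool made `HarmonyEncoding` stateful: `render_conversation` stored the flag on the shared encoding and reset it afterwards, so concurrent renders could observe each other's value. The flag is now passed per call through `RenderOptions`.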
This commit is contained in:
parent
b255cbeb62
commit
d00ac3de49
6 changed files with 130 additions and 34 deletions
|
@ -425,6 +425,10 @@ class RenderConversationConfig(BaseModel):
|
||||||
auto_drop_analysis: bool = True
|
auto_drop_analysis: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class RenderOptions(BaseModel):
|
||||||
|
conversation_has_function_tools: bool = False
|
||||||
|
|
||||||
|
|
||||||
class HarmonyEncoding:
|
class HarmonyEncoding:
|
||||||
"""High-level wrapper around the Rust ``PyHarmonyEncoding`` class."""
|
"""High-level wrapper around the Rust ``PyHarmonyEncoding`` class."""
|
||||||
|
|
||||||
|
@ -498,9 +502,20 @@ class HarmonyEncoding:
|
||||||
config=config_dict,
|
config=config_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
def render(self, message: Message) -> List[int]:
|
def render(
|
||||||
|
self, message: Message, render_options: Optional[RenderOptions] = None
|
||||||
|
) -> List[int]:
|
||||||
"""Render a single message into tokens."""
|
"""Render a single message into tokens."""
|
||||||
return self._inner.render(message_json=message.to_json())
|
if render_options is None:
|
||||||
|
render_options_dict = {"conversation_has_function_tools": False}
|
||||||
|
else:
|
||||||
|
render_options_dict = {
|
||||||
|
"conversation_has_function_tools": render_options.conversation_has_function_tools
|
||||||
|
}
|
||||||
|
|
||||||
|
return self._inner.render(
|
||||||
|
message_json=message.to_json(), render_options=render_options_dict
|
||||||
|
)
|
||||||
|
|
||||||
# -- Parsing -------------------------------------------------------
|
# -- Parsing -------------------------------------------------------
|
||||||
|
|
||||||
|
|
|
@ -5,10 +5,7 @@ use crate::{
|
||||||
use anyhow::Context as _;
|
use anyhow::Context as _;
|
||||||
use std::{
|
use std::{
|
||||||
collections::{HashMap, HashSet},
|
collections::{HashMap, HashSet},
|
||||||
sync::{
|
sync::Arc,
|
||||||
atomic::{AtomicBool, Ordering},
|
|
||||||
Arc,
|
|
||||||
},
|
|
||||||
vec,
|
vec,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -92,7 +89,6 @@ pub struct HarmonyEncoding {
|
||||||
pub(crate) format_token_mapping: HashMap<FormattingToken, String>,
|
pub(crate) format_token_mapping: HashMap<FormattingToken, String>,
|
||||||
pub(crate) stop_formatting_tokens: HashSet<FormattingToken>,
|
pub(crate) stop_formatting_tokens: HashSet<FormattingToken>,
|
||||||
pub(crate) stop_formatting_tokens_for_assistant_actions: HashSet<FormattingToken>,
|
pub(crate) stop_formatting_tokens_for_assistant_actions: HashSet<FormattingToken>,
|
||||||
pub(crate) conversation_has_function_tools: Arc<AtomicBool>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Debug for HarmonyEncoding {
|
impl std::fmt::Debug for HarmonyEncoding {
|
||||||
|
@ -191,8 +187,9 @@ impl HarmonyEncoding {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
self.conversation_has_function_tools
|
let render_options = RenderOptions {
|
||||||
.store(has_function_tools, Ordering::Relaxed);
|
conversation_has_function_tools: has_function_tools,
|
||||||
|
};
|
||||||
let last_assistant_is_final = messages
|
let last_assistant_is_final = messages
|
||||||
.iter()
|
.iter()
|
||||||
.rev()
|
.rev()
|
||||||
|
@ -217,9 +214,7 @@ impl HarmonyEncoding {
|
||||||
&& first_final_idx.is_some_and(|first| *idx < first)
|
&& first_final_idx.is_some_and(|first| *idx < first)
|
||||||
&& msg.channel.as_deref() == Some("analysis"))
|
&& msg.channel.as_deref() == Some("analysis"))
|
||||||
})
|
})
|
||||||
.try_for_each(|(_, msg)| self.render_into(msg, into));
|
.try_for_each(|(_, msg)| self.render_into(msg, into, Some(&render_options)));
|
||||||
self.conversation_has_function_tools
|
|
||||||
.store(false, Ordering::Relaxed);
|
|
||||||
result?;
|
result?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -305,18 +300,27 @@ impl HarmonyEncoding {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Render a single message into tokens.
|
/// Render a single message into tokens.
|
||||||
pub fn render(&self, message: &Message) -> anyhow::Result<Vec<Rank>> {
|
pub fn render(
|
||||||
|
&self,
|
||||||
|
message: &Message,
|
||||||
|
render_options: Option<&RenderOptions>,
|
||||||
|
) -> anyhow::Result<Vec<Rank>> {
|
||||||
let mut out = vec![];
|
let mut out = vec![];
|
||||||
Render::<Message>::render(self, message, &mut out)?;
|
Render::<Message>::render(self, message, &mut out, render_options)?;
|
||||||
Ok(out)
|
Ok(out)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Render a single message into the provided buffer.
|
/// Render a single message into the provided buffer.
|
||||||
pub fn render_into<B>(&self, message: &Message, into: &mut B) -> anyhow::Result<()>
|
pub fn render_into<B>(
|
||||||
|
&self,
|
||||||
|
message: &Message,
|
||||||
|
into: &mut B,
|
||||||
|
render_options: Option<&RenderOptions>,
|
||||||
|
) -> anyhow::Result<()>
|
||||||
where
|
where
|
||||||
B: Extend<Rank>,
|
B: Extend<Rank>,
|
||||||
{
|
{
|
||||||
Render::<Message>::render(self, message, into)
|
Render::<Message>::render(self, message, into, render_options)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -772,14 +776,29 @@ impl HarmonyEncoding {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, Default)]
|
||||||
|
pub struct RenderOptions {
|
||||||
|
pub conversation_has_function_tools: bool,
|
||||||
|
}
|
||||||
|
|
||||||
trait Render<T: ?Sized> {
|
trait Render<T: ?Sized> {
|
||||||
fn render<B>(&self, item: &T, into: &mut B) -> anyhow::Result<()>
|
fn render<B>(
|
||||||
|
&self,
|
||||||
|
item: &T,
|
||||||
|
into: &mut B,
|
||||||
|
render_options: Option<&RenderOptions>,
|
||||||
|
) -> anyhow::Result<()>
|
||||||
where
|
where
|
||||||
B: Extend<Rank>;
|
B: Extend<Rank>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Render<Message> for HarmonyEncoding {
|
impl Render<Message> for HarmonyEncoding {
|
||||||
fn render<B>(&self, message: &Message, into: &mut B) -> anyhow::Result<()>
|
fn render<B>(
|
||||||
|
&self,
|
||||||
|
message: &Message,
|
||||||
|
into: &mut B,
|
||||||
|
render_options: Option<&RenderOptions>,
|
||||||
|
) -> anyhow::Result<()>
|
||||||
where
|
where
|
||||||
B: Extend<Rank>,
|
B: Extend<Rank>,
|
||||||
{
|
{
|
||||||
|
@ -836,7 +855,7 @@ impl Render<Message> for HarmonyEncoding {
|
||||||
message.author.role
|
message.author.role
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
Render::<Content>::render(self, content, into)?;
|
Render::<Content>::render(self, content, into, render_options)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there is a tool call we should render a tool call token
|
// If there is a tool call we should render a tool call token
|
||||||
|
@ -851,15 +870,22 @@ impl Render<Message> for HarmonyEncoding {
|
||||||
|
|
||||||
// Dispatch Content variants to their specific Render implementations
|
// Dispatch Content variants to their specific Render implementations
|
||||||
impl Render<Content> for HarmonyEncoding {
|
impl Render<Content> for HarmonyEncoding {
|
||||||
fn render<B>(&self, content: &Content, into: &mut B) -> anyhow::Result<()>
|
fn render<B>(
|
||||||
|
&self,
|
||||||
|
content: &Content,
|
||||||
|
into: &mut B,
|
||||||
|
render_options: Option<&RenderOptions>,
|
||||||
|
) -> anyhow::Result<()>
|
||||||
where
|
where
|
||||||
B: Extend<Rank>,
|
B: Extend<Rank>,
|
||||||
{
|
{
|
||||||
match content {
|
match content {
|
||||||
Content::Text(text) => Render::<TextContent>::render(self, text, into),
|
Content::Text(text) => Render::<TextContent>::render(self, text, into, render_options),
|
||||||
Content::SystemContent(sys) => Render::<SystemContent>::render(self, sys, into),
|
Content::SystemContent(sys) => {
|
||||||
|
Render::<SystemContent>::render(self, sys, into, render_options)
|
||||||
|
}
|
||||||
Content::DeveloperContent(dev) => {
|
Content::DeveloperContent(dev) => {
|
||||||
Render::<crate::chat::DeveloperContent>::render(self, dev, into)
|
Render::<crate::chat::DeveloperContent>::render(self, dev, into, render_options)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -867,7 +893,12 @@ impl Render<Content> for HarmonyEncoding {
|
||||||
|
|
||||||
// Render plain text content
|
// Render plain text content
|
||||||
impl Render<TextContent> for HarmonyEncoding {
|
impl Render<TextContent> for HarmonyEncoding {
|
||||||
fn render<B>(&self, text: &TextContent, into: &mut B) -> anyhow::Result<()>
|
fn render<B>(
|
||||||
|
&self,
|
||||||
|
text: &TextContent,
|
||||||
|
into: &mut B,
|
||||||
|
_render_options: Option<&RenderOptions>,
|
||||||
|
) -> anyhow::Result<()>
|
||||||
where
|
where
|
||||||
B: Extend<Rank>,
|
B: Extend<Rank>,
|
||||||
{
|
{
|
||||||
|
@ -877,7 +908,12 @@ impl Render<TextContent> for HarmonyEncoding {
|
||||||
|
|
||||||
// Render system-specific content (model identity, instructions, effort)
|
// Render system-specific content (model identity, instructions, effort)
|
||||||
impl Render<SystemContent> for HarmonyEncoding {
|
impl Render<SystemContent> for HarmonyEncoding {
|
||||||
fn render<B>(&self, sys: &SystemContent, into: &mut B) -> anyhow::Result<()>
|
fn render<B>(
|
||||||
|
&self,
|
||||||
|
sys: &SystemContent,
|
||||||
|
into: &mut B,
|
||||||
|
render_options: Option<&RenderOptions>,
|
||||||
|
) -> anyhow::Result<()>
|
||||||
where
|
where
|
||||||
B: Extend<Rank>,
|
B: Extend<Rank>,
|
||||||
{
|
{
|
||||||
|
@ -923,7 +959,7 @@ impl Render<SystemContent> for HarmonyEncoding {
|
||||||
if channel_config.channel_required {
|
if channel_config.channel_required {
|
||||||
channels_header.push_str(" Channel must be included for every message.");
|
channels_header.push_str(" Channel must be included for every message.");
|
||||||
}
|
}
|
||||||
if self.conversation_has_function_tools.load(Ordering::Relaxed) {
|
if render_options.is_some_and(|o| o.conversation_has_function_tools) {
|
||||||
channels_header.push('\n');
|
channels_header.push('\n');
|
||||||
channels_header.push_str(
|
channels_header.push_str(
|
||||||
"Calls to these tools must go to the commentary channel: 'functions'.",
|
"Calls to these tools must go to the commentary channel: 'functions'.",
|
||||||
|
@ -940,7 +976,12 @@ impl Render<SystemContent> for HarmonyEncoding {
|
||||||
|
|
||||||
// Render developer-specific content (instructions, tools)
|
// Render developer-specific content (instructions, tools)
|
||||||
impl Render<crate::chat::DeveloperContent> for HarmonyEncoding {
|
impl Render<crate::chat::DeveloperContent> for HarmonyEncoding {
|
||||||
fn render<B>(&self, dev: &crate::chat::DeveloperContent, into: &mut B) -> anyhow::Result<()>
|
fn render<B>(
|
||||||
|
&self,
|
||||||
|
dev: &crate::chat::DeveloperContent,
|
||||||
|
into: &mut B,
|
||||||
|
_render_options: Option<&RenderOptions>,
|
||||||
|
) -> anyhow::Result<()>
|
||||||
where
|
where
|
||||||
B: Extend<Rank>,
|
B: Extend<Rank>,
|
||||||
{
|
{
|
||||||
|
|
|
@ -178,13 +178,29 @@ impl PyHarmonyEncoding {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Render a single message into tokens.
|
/// Render a single message into tokens.
|
||||||
fn render(&self, message_json: &str) -> PyResult<Vec<u32>> {
|
fn render(
|
||||||
|
&self,
|
||||||
|
message_json: &str,
|
||||||
|
render_options: Option<Bound<'_, PyDict>>,
|
||||||
|
) -> PyResult<Vec<u32>> {
|
||||||
let message: crate::chat::Message = serde_json::from_str(message_json).map_err(|e| {
|
let message: crate::chat::Message = serde_json::from_str(message_json).map_err(|e| {
|
||||||
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("invalid message JSON: {e}"))
|
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("invalid message JSON: {e}"))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
let rust_options = if let Some(options_dict) = render_options {
|
||||||
|
let conversation_has_function_tools = options_dict
|
||||||
|
.get_item("conversation_has_function_tools")?
|
||||||
|
.and_then(|v| v.extract().ok())
|
||||||
|
.unwrap_or(false);
|
||||||
|
Some(crate::encoding::RenderOptions {
|
||||||
|
conversation_has_function_tools,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
self.inner
|
self.inner
|
||||||
.render(&message)
|
.render(&message, rust_options.as_ref())
|
||||||
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
|
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use std::{
|
use std::{
|
||||||
collections::{HashMap, HashSet},
|
collections::{HashMap, HashSet},
|
||||||
sync::{atomic::AtomicBool, Arc},
|
sync::Arc,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -76,7 +76,6 @@ pub fn load_harmony_encoding(name: HarmonyEncodingName) -> anyhow::Result<Harmon
|
||||||
FormattingToken::EndMessageDoneSampling,
|
FormattingToken::EndMessageDoneSampling,
|
||||||
FormattingToken::EndMessageAssistantToTool,
|
FormattingToken::EndMessageAssistantToTool,
|
||||||
]),
|
]),
|
||||||
conversation_has_function_tools: Arc::new(AtomicBool::new(false)),
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -525,7 +525,7 @@ fn test_render_and_render_conversation_roundtrip() {
|
||||||
let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();
|
let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();
|
||||||
let msg = Message::from_role_and_content(Role::User, "Hello");
|
let msg = Message::from_role_and_content(Role::User, "Hello");
|
||||||
let convo = Conversation::from_messages([msg.clone()]);
|
let convo = Conversation::from_messages([msg.clone()]);
|
||||||
let tokens_msg = encoding.render(&msg).unwrap();
|
let tokens_msg = encoding.render(&msg, None).unwrap();
|
||||||
let tokens_convo = encoding.render_conversation(&convo, None).unwrap();
|
let tokens_convo = encoding.render_conversation(&convo, None).unwrap();
|
||||||
assert_eq!(tokens_msg, tokens_convo);
|
assert_eq!(tokens_msg, tokens_convo);
|
||||||
let tokens_completion = encoding
|
let tokens_completion = encoding
|
||||||
|
|
|
@ -18,6 +18,9 @@ extern "C" {
|
||||||
|
|
||||||
#[wasm_bindgen(typescript_type = "RenderConversationConfig")]
|
#[wasm_bindgen(typescript_type = "RenderConversationConfig")]
|
||||||
pub type JsRenderConversationConfig;
|
pub type JsRenderConversationConfig;
|
||||||
|
|
||||||
|
#[wasm_bindgen(typescript_type = "RenderOptions")]
|
||||||
|
pub type JsRenderOptions;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[wasm_bindgen(typescript_custom_section)]
|
#[wasm_bindgen(typescript_custom_section)]
|
||||||
|
@ -127,12 +130,34 @@ impl JsHarmonyEncoding {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[wasm_bindgen]
|
#[wasm_bindgen]
|
||||||
pub fn render(&self, message: JsMessage) -> Result<Vec<u32>, JsValue> {
|
pub fn render(
|
||||||
|
&self,
|
||||||
|
message: JsMessage,
|
||||||
|
render_options: JsRenderOptions,
|
||||||
|
) -> Result<Vec<u32>, JsValue> {
|
||||||
let message: JsValue = message.into();
|
let message: JsValue = message.into();
|
||||||
let message: crate::chat::Message = serde_wasm_bindgen::from_value(message)
|
let message: crate::chat::Message = serde_wasm_bindgen::from_value(message)
|
||||||
.map_err(|e| JsValue::from_str(&format!("invalid message JSON: {e}")))?;
|
.map_err(|e| JsValue::from_str(&format!("invalid message JSON: {e}")))?;
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct RenderOptions {
|
||||||
|
conversation_has_function_tools: Option<bool>,
|
||||||
|
}
|
||||||
|
let render_options: JsValue = render_options.into();
|
||||||
|
let rust_options = if render_options.is_undefined() || render_options.is_null() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
let cfg: RenderOptions = serde_wasm_bindgen::from_value(render_options)
|
||||||
|
.map_err(|e| JsValue::from_str(&format!("invalid render options: {e}")))?;
|
||||||
|
Some(crate::encoding::RenderOptions {
|
||||||
|
conversation_has_function_tools: cfg
|
||||||
|
.conversation_has_function_tools
|
||||||
|
.unwrap_or(false),
|
||||||
|
})
|
||||||
|
};
|
||||||
|
|
||||||
self.inner
|
self.inner
|
||||||
.render(&message)
|
.render(&message, rust_options.as_ref())
|
||||||
.map_err(|e| JsValue::from_str(&e.to_string()))
|
.map_err(|e| JsValue::from_str(&e.to_string()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue