Merge pull request #7 from openai/dev/scl/unify-parsers

Unify Parsers, Fix pypi README
This commit is contained in:
Scott Lessans 2025-08-05 10:21:33 -07:00 committed by GitHub
commit 9f015d0fa9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 224 additions and 448 deletions

2
Cargo.lock generated
View file

@ -1317,7 +1317,7 @@ checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]] [[package]]
name = "openai-harmony" name = "openai-harmony"
version = "0.0.1" version = "0.0.2"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"base64", "base64",

View file

@ -1,6 +1,6 @@
[package] [package]
name = "openai-harmony" name = "openai-harmony"
version = "0.0.1" version = "0.0.2"
edition = "2021" edition = "2021"
license = "Apache-2.0" license = "Apache-2.0"
repository = "https://github.com/openai/harmony" repository = "https://github.com/openai/harmony"

View file

@ -12,6 +12,8 @@ classifiers = [
] ]
dynamic = ["version"] dynamic = ["version"]
dependencies = ["pydantic>=2.11.7"] dependencies = ["pydantic>=2.11.7"]
description = "OpenAI's response format for its open-weight model series gpt-oss"
readme = "README.md"
[project.optional-dependencies] [project.optional-dependencies]
demo = ["uvicorn", "fastapi"] demo = ["uvicorn", "fastapi"]
@ -20,3 +22,7 @@ demo = ["uvicorn", "fastapi"]
features = ["pyo3/extension-module"] features = ["pyo3/extension-module"]
module-name = "openai_harmony" module-name = "openai_harmony"
python-source = "python" python-source = "python"
[tool.pytest.ini_options]
# Only collect tests from the top-level tests directory
testpaths = ["tests"]

View file

@ -29,7 +29,7 @@ from typing import (
) )
import re import re
from pydantic import BaseModel, Field, root_validator, validator from pydantic import BaseModel, Field
# Re-export the low-level Rust bindings under a private name so that we can # Re-export the low-level Rust bindings under a private name so that we can
# keep the *public* namespace clean and purely Pythonic. # keep the *public* namespace clean and purely Pythonic.
@ -612,6 +612,10 @@ class StreamableParser:
self._inner.process(token) self._inner.process(token)
return self return self
def process_eos(self) -> "StreamableParser":
    # Signal end-of-stream to the underlying Rust parser so that any
    # in-progress message content is finalized. Returns self for chaining,
    # mirroring process().
    self._inner.process_eos()
    return self
@property @property
def current_content(self) -> str: def current_content(self) -> str:
return self._inner.current_content return self._inner.current_content

View file

@ -217,7 +217,7 @@ impl HarmonyEncoding {
&& first_final_idx.is_some_and(|first| *idx < first) && first_final_idx.is_some_and(|first| *idx < first)
&& msg.channel.as_deref() == Some("analysis")) && msg.channel.as_deref() == Some("analysis"))
}) })
.try_for_each(|(_, msg)| self.render_into(*msg, into)); .try_for_each(|(_, msg)| self.render_into(msg, into));
self.conversation_has_function_tools self.conversation_has_function_tools
.store(false, Ordering::Relaxed); .store(false, Ordering::Relaxed);
result?; result?;
@ -368,36 +368,27 @@ impl HarmonyEncoding {
pub fn parse_messages_from_completion_tokens<I>( pub fn parse_messages_from_completion_tokens<I>(
&self, &self,
tokens: I, tokens: I,
mut role: Option<Role>, role: Option<Role>,
) -> anyhow::Result<Vec<Message>> ) -> anyhow::Result<Vec<Message>>
where where
I: IntoIterator<Item = Rank>, I: IntoIterator<Item = Rank>,
{ {
let mut messages = Vec::<Message>::new(); let mut parser = StreamableParser::new(self.clone(), role)?;
let mut parser = Parser { for token in tokens {
encoding: self, parser.process(token)?;
tokens: tokens.into_iter().peekable(),
};
loop {
let (message, did_reach_end_of_stream) = parser.parse_message(role)?;
messages.push(message);
role = None;
if did_reach_end_of_stream {
break;
}
} }
anyhow::ensure!(parser.tokens.next().is_none(), "Expected end of stream"); parser.process_eos()?;
Ok(messages) Ok(parser.into_messages())
} }
/// Helper to convert a JSON schema (OpenAPI style) to a TypeScript type definition. /// Helper to convert a JSON schema (OpenAPI style) to a TypeScript type definition.
fn json_schema_to_typescript(schema: &serde_json::Value, indent: &str) -> String { fn json_schema_to_typescript(schema: &serde_json::Value, indent: &str) -> String {
// Helper to check if this schema is an enum // Helper to check if this schema is an enum
fn is_enum(schema: &serde_json::Value) -> bool { fn is_enum(schema: &serde_json::Value) -> bool {
return schema schema
.get("enum") .get("enum")
.and_then(|e| e.as_array()) .and_then(|e| e.as_array())
.map_or(false, |arr| !arr.is_empty()); .is_some_and(|arr| !arr.is_empty())
} }
// Handle oneOf at the top level // Handle oneOf at the top level
@ -407,30 +398,29 @@ impl HarmonyEncoding {
let mut first = true; let mut first = true;
for variant in arr { for variant in arr {
if !first { if !first {
out.push_str("\n"); out.push('\n');
out.push_str(&format!("{} | ", indent)); out.push_str(&format!("{indent} | "));
} else { } else {
out.push_str(&format!("\n{} | ", indent)); out.push_str(&format!("\n{indent} | "));
first = false; first = false;
} }
let type_str = let type_str =
Self::json_schema_to_typescript(variant, &format!("{} ", indent)); Self::json_schema_to_typescript(variant, &format!("{indent} "));
let mut type_str = type_str; let mut type_str = type_str;
if variant if variant
.get("nullable") .get("nullable")
.and_then(|n| n.as_bool()) .and_then(|n| n.as_bool())
.unwrap_or(false) .unwrap_or(false)
&& !type_str.contains("null")
{ {
if !type_str.contains("null") { type_str = format!("{type_str} | null");
type_str = format!("{} | null", type_str);
}
} }
out.push_str(&type_str); out.push_str(&type_str);
// Add trailing comments (description, default) // Add trailing comments (description, default)
let mut trailing_comments = Vec::new(); let mut trailing_comments = Vec::new();
if let Some(desc) = variant.get("description") { if let Some(desc) = variant.get("description") {
if let Some(desc_str) = desc.as_str() { if let Some(desc_str) = desc.as_str() {
trailing_comments.push(format!("{}", desc_str)); trailing_comments.push(desc_str.to_string());
} }
} }
if let Some(default) = variant.get("default") { if let Some(default) = variant.get("default") {
@ -438,7 +428,7 @@ impl HarmonyEncoding {
trailing_comments trailing_comments
.push(format!("default: \"{}\"", default.as_str().unwrap())); .push(format!("default: \"{}\"", default.as_str().unwrap()));
} else { } else {
trailing_comments.push(format!("default: {}", default)); trailing_comments.push(format!("default: {default}"));
} }
} }
if !trailing_comments.is_empty() { if !trailing_comments.is_empty() {
@ -472,7 +462,7 @@ impl HarmonyEncoding {
// Render object-level description as comment // Render object-level description as comment
if let Some(desc) = schema.get("description") { if let Some(desc) = schema.get("description") {
if let Some(desc_str) = desc.as_str() { if let Some(desc_str) = desc.as_str() {
out.push_str(&format!("{}// {}\n", indent, desc_str)); out.push_str(&format!("{indent}// {desc_str}\n"));
} }
} }
out.push_str("{\n"); out.push_str("{\n");
@ -495,8 +485,7 @@ impl HarmonyEncoding {
if let Some(title) = val.get("title") { if let Some(title) = val.get("title") {
if let Some(title_str) = title.as_str() { if let Some(title_str) = title.as_str() {
out.push_str(&format!( out.push_str(&format!(
"{0}// {1}\n{0}//\n", "{indent}// {title_str}\n{indent}//\n"
indent, title_str
)); ));
} }
} }
@ -504,19 +493,18 @@ impl HarmonyEncoding {
if val.get("oneOf").is_none() { if val.get("oneOf").is_none() {
if let Some(desc) = val.get("description") { if let Some(desc) = val.get("description") {
if let Some(desc_str) = desc.as_str() { if let Some(desc_str) = desc.as_str() {
out.push_str(&format!("{}// {}\n", indent, desc_str)); out.push_str(&format!("{indent}// {desc_str}\n"));
} }
} }
} }
if let Some(examples) = val.get("examples") { if let Some(examples) = val.get("examples") {
if let Some(arr) = examples.as_array() { if let Some(arr) = examples.as_array() {
if !arr.is_empty() { if !arr.is_empty() {
out.push_str(&format!("{}// Examples:\n", indent)); out.push_str(&format!("{indent}// Examples:\n"));
for ex in arr { for ex in arr {
if let Some(ex_str) = ex.as_str() { if let Some(ex_str) = ex.as_str() {
out.push_str(&format!( out.push_str(&format!(
"{}// - \"{}\"\n", "{indent}// - \"{ex_str}\"\n"
indent, ex_str
)); ));
} }
} }
@ -535,7 +523,7 @@ impl HarmonyEncoding {
} }
let mut skip_property_desc = false; let mut skip_property_desc = false;
if let Some(desc_str) = property_desc { if let Some(desc_str) = property_desc {
if let Some(first_variant) = arr.get(0) { if let Some(first_variant) = arr.first() {
if let Some(variant_desc) = if let Some(variant_desc) =
first_variant.get("description") first_variant.get("description")
{ {
@ -553,10 +541,7 @@ impl HarmonyEncoding {
let mut rendered_property_desc_above = false; let mut rendered_property_desc_above = false;
if !skip_property_desc { if !skip_property_desc {
if let Some(desc_str) = property_desc { if let Some(desc_str) = property_desc {
out.push_str(&format!( out.push_str(&format!("{indent}// {desc_str}\n"));
"{}// {}\n",
indent, desc_str
));
rendered_property_desc_above = true; rendered_property_desc_above = true;
} }
} }
@ -575,8 +560,7 @@ impl HarmonyEncoding {
)); ));
} else { } else {
out.push_str(&format!( out.push_str(&format!(
"{}// default: {}\n", "{indent}// default: {default}\n"
indent, default
)); ));
} }
} }
@ -593,10 +577,10 @@ impl HarmonyEncoding {
)); ));
// Render each variant // Render each variant
for (i, variant) in arr.iter().enumerate() { for (i, variant) in arr.iter().enumerate() {
out.push_str(&format!("{} | ", indent)); out.push_str(&format!("{indent} | "));
let type_str = Self::json_schema_to_typescript( let type_str = Self::json_schema_to_typescript(
variant, variant,
&format!("{} ", indent), &format!("{indent} "),
); );
// Handle nullable in variant // Handle nullable in variant
let mut type_str = type_str; let mut type_str = type_str;
@ -604,10 +588,9 @@ impl HarmonyEncoding {
.get("nullable") .get("nullable")
.and_then(|n| n.as_bool()) .and_then(|n| n.as_bool())
.unwrap_or(false) .unwrap_or(false)
&& !type_str.contains("null")
{ {
if !type_str.contains("null") { type_str = format!("{type_str} | null");
type_str = format!("{} | null", type_str);
}
} }
out.push_str(&type_str); out.push_str(&type_str);
// Add variant-level comments after the type // Add variant-level comments after the type
@ -619,7 +602,7 @@ impl HarmonyEncoding {
// Only render if not equal to property-level description // Only render if not equal to property-level description
if Some(desc_str) != property_desc { if Some(desc_str) != property_desc {
trailing_comments trailing_comments
.push(format!("{}", desc_str)); .push(desc_str.to_string());
} }
} }
} }
@ -636,7 +619,7 @@ impl HarmonyEncoding {
)); ));
} else { } else {
trailing_comments trailing_comments
.push(format!("default: {}", default)); .push(format!("default: {default}"));
} }
} }
if !trailing_comments.is_empty() { if !trailing_comments.is_empty() {
@ -645,9 +628,9 @@ impl HarmonyEncoding {
trailing_comments.join(" ") trailing_comments.join(" ")
)); ));
} }
out.push_str("\n"); out.push('\n');
} }
out.push_str(&format!("{},\n", indent)); out.push_str(&format!("{indent},\n"));
continue; continue;
} }
} }
@ -663,21 +646,18 @@ impl HarmonyEncoding {
} }
)); ));
// Handle nullable // Handle nullable
let mut type_str = Self::json_schema_to_typescript( let mut type_str =
val, Self::json_schema_to_typescript(val, &format!("{indent} "));
&format!("{} ", indent),
);
if val if val
.get("nullable") .get("nullable")
.and_then(|n| n.as_bool()) .and_then(|n| n.as_bool())
.unwrap_or(false) .unwrap_or(false)
&& !type_str.contains("null")
{ {
if !type_str.contains("null") { type_str = format!("{type_str} | null");
type_str = format!("{} | null", type_str);
}
} }
out.push_str(&type_str); out.push_str(&type_str);
out.push_str(","); out.push(',');
// Add default as comment if present (and not already handled) // Add default as comment if present (and not already handled)
if val.get("oneOf").is_none() { if val.get("oneOf").is_none() {
if let Some(default) = val.get("default") { if let Some(default) = val.get("default") {
@ -692,15 +672,15 @@ impl HarmonyEncoding {
default.as_str().unwrap() default.as_str().unwrap()
)); ));
} else { } else {
out.push_str(&format!(" // default: {}", default)); out.push_str(&format!(" // default: {default}"));
} }
} }
} }
out.push_str("\n"); out.push('\n');
} }
} }
} }
out.push_str(&format!("{}}}", indent)); out.push_str(&format!("{indent}}}"));
out out
} }
"string" => { "string" => {
@ -708,7 +688,7 @@ impl HarmonyEncoding {
if let Some(arr) = enum_vals.as_array() { if let Some(arr) = enum_vals.as_array() {
let enums: Vec<String> = arr let enums: Vec<String> = arr
.iter() .iter()
.filter_map(|v| v.as_str().map(|s| format!("\"{}\"", s))) .filter_map(|v| v.as_str().map(|s| format!("\"{s}\"")))
.collect(); .collect();
if !enums.is_empty() { if !enums.is_empty() {
return enums.join(" | "); return enums.join(" | ");
@ -756,13 +736,13 @@ impl HarmonyEncoding {
) -> String { ) -> String {
let mut tool_sections = Vec::<String>::new(); let mut tool_sections = Vec::<String>::new();
tool_sections.push("# Tools".to_string()); tool_sections.push("# Tools".to_string());
for (_namespace, ns_config) in tools { for ns_config in tools.values() {
let mut tool_section_content = Vec::<String>::new(); let mut tool_section_content = Vec::<String>::new();
tool_section_content.push(format!("## {}\n", ns_config.name)); tool_section_content.push(format!("## {}\n", ns_config.name));
if let Some(desc) = &ns_config.description { if let Some(desc) = &ns_config.description {
for line in desc.lines() { for line in desc.lines() {
if !ns_config.tools.is_empty() { if !ns_config.tools.is_empty() {
tool_section_content.push(format!("// {}", line)); tool_section_content.push(format!("// {line}"));
} else { } else {
tool_section_content.push(line.to_string()); tool_section_content.push(line.to_string());
} }
@ -772,7 +752,7 @@ impl HarmonyEncoding {
tool_section_content.push(format!("namespace {} {{\n", ns_config.name)); tool_section_content.push(format!("namespace {} {{\n", ns_config.name));
for tool in &ns_config.tools { for tool in &ns_config.tools {
for line in tool.description.lines() { for line in tool.description.lines() {
tool_section_content.push(format!("// {}", line)); tool_section_content.push(format!("// {line}"));
} }
if let Some(params) = &tool.parameters { if let Some(params) = &tool.parameters {
let param_type = Self::json_schema_to_typescript(params, ""); let param_type = Self::json_schema_to_typescript(params, "");
@ -817,14 +797,14 @@ impl Render<Message> for HarmonyEncoding {
// For users and assistants we put both the role, and optionally the user name. // For users and assistants we put both the role, and optionally the user name.
self.render_text_into(message.author.role.as_str(), into)?; self.render_text_into(message.author.role.as_str(), into)?;
if let Some(name) = &message.author.name { if let Some(name) = &message.author.name {
self.render_text_into(format!(":{}", name), into)?; self.render_text_into(format!(":{name}"), into)?;
} }
}; };
// next render the header recipient, if there is one // next render the header recipient, if there is one
if let Some(recipient) = &message.recipient { if let Some(recipient) = &message.recipient {
if recipient != "all" { if recipient != "all" {
self.render_text_into(format!(" to={}", recipient), into)?; self.render_text_into(format!(" to={recipient}"), into)?;
} }
} }
@ -836,7 +816,7 @@ impl Render<Message> for HarmonyEncoding {
// finally content type // finally content type
if let Some(content_type) = &message.content_type { if let Some(content_type) = &message.content_type {
self.render_text_into(format!(" {}", content_type), into)?; self.render_text_into(format!(" {content_type}"), into)?;
} }
self.render_formatting_token_into(FormattingToken::Message, into)?; self.render_formatting_token_into(FormattingToken::Message, into)?;
@ -944,7 +924,7 @@ impl Render<SystemContent> for HarmonyEncoding {
channels_header.push_str(" Channel must be included for every message."); channels_header.push_str(" Channel must be included for every message.");
} }
if self.conversation_has_function_tools.load(Ordering::Relaxed) { if self.conversation_has_function_tools.load(Ordering::Relaxed) {
channels_header.push_str("\n"); channels_header.push('\n');
channels_header.push_str( channels_header.push_str(
"Calls to these tools must go to the commentary channel: 'functions'.", "Calls to these tools must go to the commentary channel: 'functions'.",
); );
@ -982,305 +962,6 @@ impl Render<crate::chat::DeveloperContent> for HarmonyEncoding {
} }
} }
/// Outcome of scanning the token stream for a terminator token:
/// either the terminator was encountered, or the stream ran out first.
enum TakeUntilStatus {
    /// The requested end token was found (and consumed).
    Found,
    /// The iterator was exhausted before any end token appeared.
    EndOfStream,
}
impl TakeUntilStatus {
    /// Returns `true` when the terminator token was actually encountered,
    /// as opposed to the stream being exhausted first.
    fn was_found(&self) -> bool {
        match self {
            TakeUntilStatus::Found => true,
            TakeUntilStatus::EndOfStream => false,
        }
    }
}
/// Pull-based message parser over a stream of token ranks.
///
/// Consumes tokens from `tokens` one message at a time (see `parse_message`),
/// using `encoding` to render/recognize the special formatting tokens and to
/// decode token runs back into UTF-8 text.
struct Parser<'a, I>
where
    I: Iterator<Item = Rank>,
{
    // Peekable so parse_message can look ahead for an optional start token
    // without consuming it.
    tokens: std::iter::Peekable<I>,
    encoding: &'a HarmonyEncoding,
}
impl<I> Parser<'_, I>
where
    I: Iterator<Item = Rank>,
{
    /// Consume the next token and verify it is the rendered form of the given
    /// formatting token; error if the stream is empty or the token differs.
    fn expect_special(&mut self, token: FormattingToken) -> anyhow::Result<Rank> {
        let next = self.tokens.next().context(format!(
            "Expected special token ({}), but out of tokens",
            token
        ))?;
        let expected = self.encoding.render_formatting_token(token)?;
        if next != expected {
            anyhow::bail!(
                "Expected special token ({}) {} but got {}",
                token,
                expected,
                next,
            );
        }
        Ok(next)
    }

    /// Collect tokens until one of `ends` is seen. The end token itself is
    /// consumed but NOT included in the returned vector. The status reports
    /// whether an end token was found or the stream ran out.
    fn take_until_any(&mut self, ends: &HashSet<Rank>) -> (Vec<Rank>, TakeUntilStatus) {
        let mut out = vec![];
        for t in &mut self.tokens {
            if ends.contains(&t) {
                return (out, TakeUntilStatus::Found);
            }
            out.push(t);
        }
        (out, TakeUntilStatus::EndOfStream)
    }

    /// Convenience wrapper around `take_until_any` for a single end token.
    fn take_until(&mut self, end: Rank) -> (Vec<Rank>, TakeUntilStatus) {
        self.take_until_any(&HashSet::from([end]))
    }

    /// Parse a message header (everything before the message-start marker).
    ///
    /// If `role` is `Some`, it is taken as the author role and a redundant
    /// role token in the header is stripped; otherwise the role is read from
    /// the header itself. Extracts the optional channel, recipient
    /// (`to=...`), and content type. Errors if the header cannot be decoded
    /// or leftover tokens remain after parsing.
    fn parse_header(&mut self, role: Option<Role>) -> anyhow::Result<ParsedHeader> {
        // FormattingToken::Message marks the end of the header.
        // Everything before that belongs to the header.
        let message_start_token = self
            .encoding
            .render_formatting_token(FormattingToken::Message)?;
        let (header_tokens, status) = self.take_until(message_start_token);
        if !status.was_found() {
            anyhow::bail!("Expected message start token but ran out of tokens");
        }

        // Decode the header into a UTF-8 string so we can reason about its structure.
        let mut header_string = self
            .encoding
            .tokenizer
            .decode_utf8(header_tokens)
            .context("could not decode header")?;

        // --------------------------------------------------------------------
        // 1. Extract the channel (if any)
        // --------------------------------------------------------------------
        // A channel, when present, is encoded as:
        //     <|channel|>CHANNEL_VALUE
        // where <|channel|> is the literal rendering of FormattingToken::Channel
        // and CHANNEL_VALUE is a contiguous string (no whitespace) naming the
        // channel. The <|channel|> marker can appear before or after the
        // recipient part, but *always* before the optional content-type (which
        // must be last).
        let mut channel: Option<String> = None;
        if let Some(channel_marker) = self.encoding.mapped_format_token(FormattingToken::Channel) {
            if let Some(idx) = header_string.find(channel_marker) {
                // Slice parts around the marker
                let after_marker = &header_string[idx + channel_marker.len()..];
                // The channel value continues until the next ASCII whitespace,
                // the start of another special token ("<"), or end-of-string.
                let channel_end = after_marker
                    .find(|c: char| c.is_whitespace() || c == '<')
                    .unwrap_or(after_marker.len());
                let channel_value = &after_marker[..channel_end];
                if channel_value.is_empty() {
                    anyhow::bail!("channel marker present but no channel value found in header");
                }
                channel = Some(channel_value.to_string());
                // Remove the marker *and* the channel value from the header so
                // the remaining pieces can be parsed independently.
                let mut new_header = String::new();
                new_header.push_str(&header_string[..idx]);
                new_header.push_str(&after_marker[channel_end..]);
                header_string = new_header;
            }
        }

        // Trim extraneous whitespace that may have been introduced when we
        // removed the channel section.
        header_string = header_string.trim().to_string();

        // If the constrained format marker is present but not preceded by
        // whitespace (e.g. "to=foo<|constrain|>json"), insert a space before
        // the marker so that splitting on whitespace treats the content type
        // as a separate token.
        if let Some(constrain_marker) = self
            .encoding
            .mapped_format_token(FormattingToken::ConstrainedFormat)
        {
            if header_string.contains(constrain_marker) {
                header_string = header_string
                    .replace(constrain_marker, &format!(" {}", constrain_marker))
                    .trim()
                    .to_string();
            }
        }

        // --------------------------------------------------------------------
        // 2. Split the remaining header into whitespace-separated tokens.
        // --------------------------------------------------------------------
        // Debug output for development (only active when the `debug_header_parsing` cfg flag is
        // enabled).
        // For debugging purposes one might want to inspect the header string
        // at this point. To avoid unwanted stdout noise in production use
        // the following (commented) line and recompile as needed.
        // println!("[DEBUG header] '{}'", header_string);
        let mut parts: Vec<&str> = header_string.split_ascii_whitespace().collect();

        // --------------------------------------------------------------------
        // 3. Determine the role (if not already provided).
        // --------------------------------------------------------------------
        let mut role_str_opt: Option<String> = None;
        let role = match role {
            Some(r) => r,
            None => {
                let role_str = parts
                    .first()
                    .context("message header did not contain a role")?;
                role_str_opt = Some((*role_str).to_string());
                let parsed_role = Role::try_from(*role_str);
                let out = match parsed_role {
                    Ok(r) => r,
                    Err(_) => {
                        // If recipient is present, treat as tool call
                        if parts.len() > 1 || (parts.len() == 1 && parts[0].starts_with("to=")) {
                            parts.remove(0); // Remove the unknown role string
                            Role::Tool
                        } else {
                            return Err(anyhow::anyhow!("Unknown role: {}", role_str));
                        }
                    }
                };
                out
            }
        };

        // If the role was supplied externally but also redundantly present in the
        // header itself, strip it off so that it does not interfere with the
        // parsing of the remaining fields.
        if let Some(first) = parts.first() {
            if *first == role.as_str() {
                parts.remove(0);
            }
        }

        // --------------------------------------------------------------------
        // 4. Identify recipient and content-type.
        // --------------------------------------------------------------------
        let mut recipient: Option<String> = None;
        let mut content_type: Option<String> = None;

        if !parts.is_empty() {
            // Determine whether the last token is a content-type or part of the
            // recipient specification.
            let num_tokens_before_pop = parts.len();
            let last_token_owned = parts.pop().unwrap().to_string();

            if last_token_owned.starts_with("to=") {
                // The header contains a recipient but *no* content-type.
                recipient = Some(last_token_owned.trim_start_matches("to=").to_string());
            } else if num_tokens_before_pop == 1 {
                // Only one token total (after potential role removal) and it doesn't start
                // with "to=" => interpret it as a standalone recipient.
                recipient = Some(last_token_owned);
            } else {
                // More than one token and the last one is not a recipient -> treat as content-type.
                content_type = Some(last_token_owned);

                // After removing the content-type there may be exactly one token describing the recipient.
                if !parts.is_empty() {
                    if parts.len() != 1 {
                        anyhow::bail!("Could not parse header: too many tokens remaining after extracting content-type and recipient");
                    }
                    let raw_recipient = parts.pop().unwrap();
                    recipient = if let Some(stripped) = raw_recipient.strip_prefix("to=") {
                        Some(stripped.to_string())
                    } else {
                        Some(raw_recipient.to_string())
                    };
                }
            }
        }

        // After processing, no unparsed tokens should remain.
        anyhow::ensure!(
            parts.is_empty(),
            "unexpected tokens remaining in message header: {:?}",
            parts
        );

        // We have successfully parsed the header.
        // Tool messages keep the raw role string as the author name (the tool's
        // identifier); all other roles have no name at this point.
        let author = if role == Role::Tool {
            let name = role_str_opt;
            Author { role, name }
        } else {
            Author { role, name: None }
        };

        Ok(ParsedHeader {
            author,
            recipient,
            channel,
            content_type,
        })
    }

    /// Parses a message from the remaining tokens.
    ///
    /// Returns the message and a boolean indicating whether end of stream was reached.
    fn parse_message(&mut self, role: Option<Role>) -> anyhow::Result<(Message, bool)> {
        let start_token = self
            .encoding
            .render_formatting_token(FormattingToken::Start)?;
        // When the role is supplied externally the start token is optional;
        // otherwise it is required to open the message.
        match role {
            Some(_) => {
                if let Some(&next) = self.tokens.peek() {
                    if next == start_token {
                        self.tokens.next();
                    }
                } else {
                    anyhow::bail!("Expected at least one token while parsing message");
                }
            }
            None => {
                self.expect_special(FormattingToken::Start)?;
            }
        }
        let header = self.parse_header(role)?;
        let ParsedHeader {
            author,
            recipient,
            channel,
            content_type,
        } = header;

        // TODO other content types
        // since we bail on anything other than just the role in the header for now, we can assume
        // that the content type is text
        let end_tokens = self.encoding.stop_tokens()?;
        let (remaining_tokens, status) = self.take_until_any(&end_tokens);
        let remaining_text = self
            .encoding
            .tokenizer
            .decode_utf8(remaining_tokens)
            .context("could not decode message content")?;
        // End of stream is reached either when the tokens ran out, or when a
        // stop token was found and nothing follows it.
        let did_reach_end_of_stream = match status {
            TakeUntilStatus::Found => self.tokens.peek().is_none(),
            TakeUntilStatus::EndOfStream => true,
        };
        Ok((
            Message {
                author,
                content: vec![Content::Text(TextContent {
                    text: remaining_text,
                })],
                channel,
                recipient,
                content_type,
            },
            did_reach_end_of_stream,
        ))
    }
}
// ---------------------------------------------------------------------------
// Streamable parsing ---------------------------------------------------------
// ---------------------------------------------------------------------------
/// Incremental parser that can consume tokens one by one. /// Incremental parser that can consume tokens one by one.
/// ///
/// It keeps track of all tokens seen so far, exposes all fully parsed messages /// It keeps track of all tokens seen so far, exposes all fully parsed messages
@ -1334,8 +1015,11 @@ impl StreamableParser {
} }
/// Consume a single token and update the internal state. /// Consume a single token and update the internal state.
pub fn process(&mut self, token: Rank) -> anyhow::Result<&mut Self> { /// Consume a single token and update the internal state.
self.tokens.push(token); fn process_next(&mut self, token: Option<Rank>) -> anyhow::Result<&mut Self> {
if let Some(token) = token {
self.tokens.push(token);
}
// Clone next_role up front to avoid borrow checker issues // Clone next_role up front to avoid borrow checker issues
let next_role_clone = self.next_role.clone(); let next_role_clone = self.next_role.clone();
match &mut self.state { match &mut self.state {
@ -1343,44 +1027,89 @@ impl StreamableParser {
let start = self let start = self
.encoding .encoding
.render_formatting_token(FormattingToken::Start)?; .render_formatting_token(FormattingToken::Start)?;
if token == start { match token {
self.state = StreamState::Header { Some(token) if token == start => {
header_tokens: Vec::new(), self.state = StreamState::Header {
}; header_tokens: Vec::new(),
} else { };
anyhow::bail!( }
"Unexpected token {} while expecting start token {}", Some(token) => {
token, anyhow::bail!(
start "Unexpected token {} while expecting start token {}",
); token,
start
);
}
None => {
// receiving EOS while waiting for start token is actually fine
// as we may have just parsed a stop token. in this case we can
// simply keep the state as is
}
} }
} }
StreamState::Header { header_tokens } => { StreamState::Header { header_tokens } => {
let msg_tok = self let msg_tok = self
.encoding .encoding
.render_formatting_token(FormattingToken::Message)?; .render_formatting_token(FormattingToken::Message)?;
if token == msg_tok { match token {
// Clone the tokens and next_role, then clear the state before parsing Some(token) if token == msg_tok => {
let header_tokens_cloned = header_tokens.clone(); // Clone the tokens and next_role, then clear the state before parsing
let next_role_cloned = next_role_clone; let header_tokens_cloned = header_tokens.clone();
// Set state to dummy to drop mutable borrow let next_role_cloned = next_role_clone;
self.state = StreamState::ExpectStart; // Set state to dummy to drop mutable borrow
let header = self.state = StreamState::ExpectStart;
self.parse_header_from_tokens(&header_tokens_cloned, next_role_cloned)?; let header =
self.next_role = None; self.parse_header_from_tokens(&header_tokens_cloned, next_role_cloned)?;
self.state = StreamState::Content { self.next_role = None;
header, self.state = StreamState::Content {
content_tokens: Vec::new(), header,
}; content_tokens: Vec::new(),
} else { };
header_tokens.push(token); }
Some(token) => {
header_tokens.push(token);
}
None => {
anyhow::bail!(
"Unexpected EOS while waiting for message header to complete"
);
}
} }
} }
StreamState::Content { StreamState::Content {
header, header,
content_tokens, content_tokens,
} => { } => {
if self.stop_tokens.contains(&token) { let is_eos = if let Some(token) = token {
if self.stop_tokens.contains(&token) {
// this is a stop token, dont parse and mark EOS
true
} else {
self.undecoded_tokens.push(token);
// some tokens might not appropriately decode on their own. If they don't
// we will collect them until they eventually decode
match self
.encoding
.tokenizer()
.decode_utf8(&self.undecoded_tokens)
{
Ok(decoded) => {
content_tokens.extend(self.undecoded_tokens.iter().copied());
self.last_content_delta = Some(decoded);
self.undecoded_tokens.clear();
}
Err(_) => {
self.last_content_delta = None;
}
}
// this was not an EOS
false
}
} else {
// token = None signals EOS to this function
true
};
if is_eos {
let text = self.encoding.tokenizer().decode_utf8(content_tokens)?; let text = self.encoding.tokenizer().decode_utf8(content_tokens)?;
let message = Message { let message = Message {
author: header.author.clone(), author: header.author.clone(),
@ -1393,30 +1122,21 @@ impl StreamableParser {
self.state = StreamState::ExpectStart; self.state = StreamState::ExpectStart;
self.last_content_delta = None; self.last_content_delta = None;
self.undecoded_tokens.clear(); self.undecoded_tokens.clear();
} else {
self.undecoded_tokens.push(token);
// some tokens might not appropriately decode on their own. If they don't
// we will collect them until they eventually decode
match self
.encoding
.tokenizer()
.decode_utf8(&self.undecoded_tokens)
{
Ok(decoded) => {
content_tokens.extend(self.undecoded_tokens.iter().copied());
self.last_content_delta = Some(decoded);
self.undecoded_tokens.clear();
}
Err(_) => {
self.last_content_delta = None;
}
}
} }
} }
} }
Ok(self) Ok(self)
} }
pub fn process(&mut self, token: Rank) -> anyhow::Result<&mut Self> {
self.process_next(Some(token))
}
pub fn process_eos(&mut self) -> anyhow::Result<&mut Self> {
self.process_next(None)?;
Ok(self)
}
fn parse_header_from_tokens( fn parse_header_from_tokens(
&self, &self,
header_tokens: &[Rank], header_tokens: &[Rank],
@ -1495,8 +1215,8 @@ impl StreamableParser {
} }
}; };
if let Some(first) = parts.first() { if let Some(&first) = parts.first() {
if *first == role.as_str() { if first == role.as_str() {
parts.remove(0); parts.remove(0);
} }
} }
@ -1505,23 +1225,25 @@ impl StreamableParser {
let mut content_type: Option<String> = None; let mut content_type: Option<String> = None;
if !parts.is_empty() { if !parts.is_empty() {
let num_tokens_before_pop = parts.len(); // Determine whether the last token is a content-type or part of the
let last_token_owned = parts.pop().unwrap().to_string(); // recipient specification.
let num_parts = parts.len();
// SAFETY: we know that there is at least one part remaining, because of is_empty check above
let last_part = parts.pop().unwrap();
if last_token_owned.starts_with("to=") { if let Some(stripped) = last_part.strip_prefix("to=") {
recipient = Some(last_token_owned.trim_start_matches("to=").to_string()); // The header contains a recipient but *no* content-type.
} else if num_tokens_before_pop == 1 { recipient = Some(stripped.to_string());
recipient = Some(last_token_owned); } else if num_parts == 1 {
// Only one part total (after potential role removal) and it doesn't start
// with "to=" => interpret it as a standalone recipient.
recipient = Some(last_part.to_string());
} else { } else {
content_type = Some(last_token_owned); // More than one token and the last one is not a recipient -> treat as content-type.
content_type = Some(last_part.to_string());
if !parts.is_empty() { // After removing the content-type there may be exactly one token describing the recipient.
if parts.len() != 1 { if let Some(raw_recipient) = parts.pop() {
anyhow::bail!(
"Could not parse header: too many tokens remaining after extracting content-type and recipient"
);
}
let raw_recipient = parts.pop().unwrap();
recipient = if let Some(stripped) = raw_recipient.strip_prefix("to=") { recipient = if let Some(stripped) = raw_recipient.strip_prefix("to=") {
Some(stripped.to_string()) Some(stripped.to_string())
} else { } else {
@ -1530,12 +1252,12 @@ impl StreamableParser {
} }
} }
} }
anyhow::ensure!( anyhow::ensure!(
parts.is_empty(), parts.is_empty(),
"unexpected tokens remaining in message header: {:?}", "unexpected tokens remaining in message header: {:?}",
parts parts
); );
let author = if role == Role::Tool { let author = if role == Role::Tool {
let name = role_str_opt; let name = role_str_opt;
Author { role, name } Author { role, name }
@ -1583,6 +1305,11 @@ impl StreamableParser {
Ok(self.last_content_delta.clone()) Ok(self.last_content_delta.clone())
} }
/// Consume the parser and return all parsed messages.
pub fn into_messages(self) -> Vec<Message> {
self.messages
}
/// All fully parsed messages so far. /// All fully parsed messages so far.
pub fn messages(&self) -> &[Message] { pub fn messages(&self) -> &[Message] {
&self.messages &self.messages

View file

@ -306,6 +306,13 @@ impl PyStreamableParser {
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string())) .map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))
} }
/// Python binding for the Rust `process_eos`.
///
/// Delegates to the inner parser; the `&mut Self` returned by the Rust API is
/// discarded (`.map(|_| ())`) and any error is converted into a Python
/// `HarmonyError` carrying the error's display string.
fn process_eos(&mut self) -> PyResult<()> {
self.inner
.process_eos()
.map(|_| ())
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))
}
#[getter] #[getter]
fn current_content(&self) -> PyResult<String> { fn current_content(&self) -> PyResult<String> {
self.inner self.inner

View file

@ -501,11 +501,11 @@ fn test_preserve_cot() {
fn test_reserved_token_decoding() { fn test_reserved_token_decoding() {
let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap(); let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();
assert_eq!( assert_eq!(
encoding.tokenizer.decode_utf8(&[200014]).unwrap(), encoding.tokenizer.decode_utf8([200014]).unwrap(),
"<|reserved_200014|>" "<|reserved_200014|>"
); );
assert_eq!( assert_eq!(
encoding.tokenizer.decode_utf8(&[201088]).unwrap(), encoding.tokenizer.decode_utf8([201088]).unwrap(),
"<|reserved_201088|>" "<|reserved_201088|>"
); );
} }
@ -527,7 +527,7 @@ fn test_render_and_render_conversation_roundtrip() {
#[test] #[test]
fn test_decode_utf8_invalid_token() { fn test_decode_utf8_invalid_token() {
let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap(); let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();
let result = encoding.tokenizer.decode_utf8(&[99999999]); let result = encoding.tokenizer.decode_utf8([99999999]);
assert!(result.is_err(), "Expected error for invalid token"); assert!(result.is_err(), "Expected error for invalid token");
} }
@ -663,3 +663,39 @@ fn test_streamable_parser_tool_call_with_constrain_adjacent() {
parser.messages()[0] parser.messages()[0]
); );
} }
#[test]
fn test_tool_call_with_constrain_marker_adjacent() {
    // A `<|constrain|>` marker glued directly onto the content type must be
    // retained as part of the content type, not folded into the message body.
    let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();
    let source = "<|start|>assistant to=functions.get_weather<|channel|>commentary<|constrain|>json<|message|>{\"location\": \"Tokyo\"}<|end|>";
    let tokens = encoding.tokenizer().encode_with_special_tokens(source);
    let parsed = encoding
        .parse_messages_from_completion_tokens(tokens, None)
        .expect("expected to parse");
    let expected_message =
        Message::from_role_and_content(Role::Assistant, "{\"location\": \"Tokyo\"}")
            .with_channel("commentary")
            .with_recipient("functions.get_weather")
            .with_content_type("<|constrain|>json");
    assert_eq!(parsed, vec![expected_message]);
}
#[test]
fn test_tool_call_with_channel_before_recipient_and_constrain_adjacent() {
    // Here the channel precedes the recipient and the `<|constrain|>` marker is
    // adjacent to the content type; all header fields must still be recovered.
    let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();
    let source = "<|start|>assistant<|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{\"latitude\":48.8566,\"longitude\":2.3522}<|call|>";
    let tokens = encoding.tokenizer().encode_with_special_tokens(source);
    let parsed = encoding
        .parse_messages_from_completion_tokens(tokens, None)
        .expect("expected to parse");
    let expected_message = Message::from_role_and_content(
        Role::Assistant,
        "{\"latitude\":48.8566,\"longitude\":2.3522}",
    )
    .with_channel("commentary")
    .with_recipient("functions.get_weather")
    .with_content_type("<|constrain|>json");
    assert_eq!(parsed, vec![expected_message]);
}

5
test_python.sh Executable file
View file

@ -0,0 +1,5 @@
#!/usr/bin/env bash
# Build the Rust extension into the local virtualenv and run the Python test
# suite. Extra arguments are forwarded to pytest, e.g. `./test_python.sh -k parser`.
#
# `pipefail` makes a failure anywhere in a pipeline abort the script; plain
# `set -e` only checks the last command of a pipeline.
set -eo pipefail
# Run from the directory containing this script so the relative `.venv` path
# resolves regardless of the caller's working directory.
cd "$(dirname "$0")"
source .venv/bin/activate
maturin develop -F python-binding --release
pytest "$@"

View file

@ -34,7 +34,6 @@ from openai_harmony import ( # noqa: E402
StreamableParser, StreamableParser,
SystemContent, SystemContent,
ToolDescription, ToolDescription,
ToolNamespaceConfig,
load_harmony_encoding, load_harmony_encoding,
) )
from pydantic import ValidationError from pydantic import ValidationError
@ -245,23 +244,18 @@ def test_tool_call_with_constrain_marker_adjacent(encoding_name):
correctly and instead handle it as a separate content type. correctly and instead handle it as a separate content type.
""" """
encoding = load_harmony_encoding(encoding_name) encoding = load_harmony_encoding(encoding_name)
text = ( text = (
"<|start|>assistant to=functions.get_weather<|channel|>commentary" "<|start|>assistant to=functions.get_weather<|channel|>commentary"
'<|constrain|>json<|message|>{"location": "Tokyo"}<|end|>' '<|constrain|>json<|message|>{"location": "Tokyo"}<|end|>'
) )
tokens = encoding.encode(text, allowed_special="all") tokens = encoding.encode(text, allowed_special="all")
parsed = encoding.parse_messages_from_completion_tokens(tokens, role=None)
parsed = encoding.parse_messages_from_completion_tokens(tokens, role=Role.ASSISTANT)
expected = [ expected = [
Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}') Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}')
.with_channel("commentary") .with_channel("commentary")
.with_recipient("functions.get_weather") .with_recipient("functions.get_weather")
.with_content_type("<|constrain|>json"), .with_content_type("<|constrain|>json"),
] ]
assert parsed == expected assert parsed == expected
@ -280,11 +274,8 @@ def test_tool_call_with_channel_before_recipient_and_constrain_adjacent(
"<|start|>assistant<|channel|>commentary to=functions.get_weather" "<|start|>assistant<|channel|>commentary to=functions.get_weather"
'<|constrain|>json<|message|>{"latitude":48.8566,"longitude":2.3522}<|call|>' '<|constrain|>json<|message|>{"latitude":48.8566,"longitude":2.3522}<|call|>'
) )
tokens = encoding.encode(text, allowed_special="all") tokens = encoding.encode(text, allowed_special="all")
parsed = encoding.parse_messages_from_completion_tokens(tokens, role=None)
parsed = encoding.parse_messages_from_completion_tokens(tokens, role=Role.ASSISTANT)
expected = [ expected = [
Message.from_role_and_content( Message.from_role_and_content(
Role.ASSISTANT, '{"latitude":48.8566,"longitude":2.3522}' Role.ASSISTANT, '{"latitude":48.8566,"longitude":2.3522}'