mirror of https://github.com/openai/harmony.git (synced 2025-08-23 01:17:09 -04:00)

commit 4bc6933549 (parent: 41a404a90b)
7 changed files with 175 additions and 390 deletions
@@ -20,3 +20,7 @@ demo = ["uvicorn", "fastapi"]
 features = ["pyo3/extension-module"]
 module-name = "openai_harmony"
 python-source = "python"
+
+[tool.pytest.ini_options]
+# Only collect tests from the top-level tests directory
+testpaths = ["tests"]
@@ -29,7 +29,7 @@ from typing import (
 )
 
 import re
-from pydantic import BaseModel, Field, root_validator, validator
+from pydantic import BaseModel, Field
 
 # Re-export the low-level Rust bindings under a private name so that we can
 # keep the *public* namespace clean and purely Pythonic.
@@ -612,6 +612,10 @@ class StreamableParser:
         self._inner.process(token)
         return self
 
+    def process_eos(self) -> "StreamableParser":
+        self._inner.process_eos()
+        return self
+
     @property
     def current_content(self) -> str:
         return self._inner.current_content
src/encoding.rs (494 lines changed)
@@ -368,26 +368,17 @@ impl HarmonyEncoding {
     pub fn parse_messages_from_completion_tokens<I>(
         &self,
         tokens: I,
-        mut role: Option<Role>,
+        role: Option<Role>,
     ) -> anyhow::Result<Vec<Message>>
     where
         I: IntoIterator<Item = Rank>,
     {
-        let mut messages = Vec::<Message>::new();
-        let mut parser = Parser {
-            encoding: self,
-            tokens: tokens.into_iter().peekable(),
-        };
-        loop {
-            let (message, did_reach_end_of_stream) = parser.parse_message(role)?;
-            messages.push(message);
-            role = None;
-            if did_reach_end_of_stream {
-                break;
-            }
+        let mut parser = StreamableParser::new(self.clone(), role)?;
+        for token in tokens {
+            parser.process(token)?;
         }
-        anyhow::ensure!(parser.tokens.next().is_none(), "Expected end of stream");
-        Ok(messages)
+        parser.process_eos()?;
+        Ok(parser.into_messages())
     }
 
     /// Helper to convert a JSON schema (OpenAPI style) to a TypeScript type definition.
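With this change the one-shot API becomes a thin wrapper over the streaming parser: feed every token, signal end-of-stream, collect messages. A minimal external-usage sketch, modeled on the tests added later in this commit; the `use` paths and the `anyhow` main are assumptions, not shown in this diff:

// Sketch only: import paths are assumptions; the calls mirror src/tests.rs below.
use openai_harmony::chat::{Message, Role};
use openai_harmony::{load_harmony_encoding, HarmonyEncodingName};

fn main() -> anyhow::Result<()> {
    let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss)?;
    // Render a completion the way a model would emit it...
    let text = "<|start|>assistant<|channel|>final<|message|>Hello!<|end|>";
    let tokens = encoding.tokenizer().encode_with_special_tokens(text);
    // ...and parse it back. Passing None for the role lets the header
    // parser infer the author, which now also covers tool-call headers.
    let messages = encoding.parse_messages_from_completion_tokens(tokens, None)?;
    assert_eq!(
        messages,
        vec![Message::from_role_and_content(Role::Assistant, "Hello!").with_channel("final")]
    );
    Ok(())
}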
@@ -982,305 +973,6 @@ impl Render<crate::chat::DeveloperContent> for HarmonyEncoding {
     }
 }
 
-enum TakeUntilStatus {
-    Found,
-    EndOfStream,
-}
-
-impl TakeUntilStatus {
-    fn was_found(&self) -> bool {
-        matches!(self, TakeUntilStatus::Found)
-    }
-}
-
-struct Parser<'a, I>
-where
-    I: Iterator<Item = Rank>,
-{
-    tokens: std::iter::Peekable<I>,
-    encoding: &'a HarmonyEncoding,
-}
-
-impl<I> Parser<'_, I>
-where
-    I: Iterator<Item = Rank>,
-{
-    fn expect_special(&mut self, token: FormattingToken) -> anyhow::Result<Rank> {
-        let next = self.tokens.next().context(format!(
-            "Expected special token ({}), but out of tokens",
-            token
-        ))?;
-        let expected = self.encoding.render_formatting_token(token)?;
-        if next != expected {
-            anyhow::bail!(
-                "Expected special token ({}) {} but got {}",
-                token,
-                expected,
-                next,
-            );
-        }
-        Ok(next)
-    }
-
-    fn take_until_any(&mut self, ends: &HashSet<Rank>) -> (Vec<Rank>, TakeUntilStatus) {
-        let mut out = vec![];
-        for t in &mut self.tokens {
-            if ends.contains(&t) {
-                return (out, TakeUntilStatus::Found);
-            }
-            out.push(t);
-        }
-        (out, TakeUntilStatus::EndOfStream)
-    }
-
-    fn take_until(&mut self, end: Rank) -> (Vec<Rank>, TakeUntilStatus) {
-        self.take_until_any(&HashSet::from([end]))
-    }
-
-    fn parse_header(&mut self, role: Option<Role>) -> anyhow::Result<ParsedHeader> {
-        // FormattingToken::Message marks the end of the header.
-        // Everything before that belongs to the header.
-        let message_start_token = self
-            .encoding
-            .render_formatting_token(FormattingToken::Message)?;
-
-        let (header_tokens, status) = self.take_until(message_start_token);
-        if !status.was_found() {
-            anyhow::bail!("Expected message start token but ran out of tokens");
-        }
-
-        // Decode the header into a UTF-8 string so we can reason about its structure.
-        let mut header_string = self
-            .encoding
-            .tokenizer
-            .decode_utf8(header_tokens)
-            .context("could not decode header")?;
-
-        // --------------------------------------------------------------------
-        // 1. Extract the channel (if any)
-        // --------------------------------------------------------------------
-        // A channel, when present, is encoded as:
-        //     <|channel|>CHANNEL_VALUE
-        // where <|channel|> is the literal rendering of FormattingToken::Channel
-        // and CHANNEL_VALUE is a contiguous string (no whitespace) naming the
-        // channel. The <|channel|> marker can appear before or after the
-        // recipient part, but *always* before the optional content-type (which
-        // must be last).
-        let mut channel: Option<String> = None;
-        if let Some(channel_marker) = self.encoding.mapped_format_token(FormattingToken::Channel) {
-            if let Some(idx) = header_string.find(channel_marker) {
-                // Slice parts around the marker
-                let after_marker = &header_string[idx + channel_marker.len()..];
-
-                // The channel value continues until the next ASCII whitespace,
-                // the start of another special token ("<"), or end-of-string.
-                let channel_end = after_marker
-                    .find(|c: char| c.is_whitespace() || c == '<')
-                    .unwrap_or(after_marker.len());
-                let channel_value = &after_marker[..channel_end];
-                if channel_value.is_empty() {
-                    anyhow::bail!("channel marker present but no channel value found in header");
-                }
-                channel = Some(channel_value.to_string());
-
-                // Remove the marker *and* the channel value from the header so
-                // the remaining pieces can be parsed independently.
-                let mut new_header = String::new();
-                new_header.push_str(&header_string[..idx]);
-                new_header.push_str(&after_marker[channel_end..]);
-                header_string = new_header;
-            }
-        }
-
-        // Trim extraneous whitespace that may have been introduced when we
-        // removed the channel section.
-        header_string = header_string.trim().to_string();
-
-        // If the constrained format marker is present but not preceded by
-        // whitespace (e.g. "to=foo<|constrain|>json"), insert a space before
-        // the marker so that splitting on whitespace treats the content type
-        // as a separate token.
-        if let Some(constrain_marker) = self
-            .encoding
-            .mapped_format_token(FormattingToken::ConstrainedFormat)
-        {
-            if header_string.contains(constrain_marker) {
-                header_string = header_string
-                    .replace(constrain_marker, &format!(" {}", constrain_marker))
-                    .trim()
-                    .to_string();
-            }
-        }
-
-        // --------------------------------------------------------------------
-        // 2. Split the remaining header into whitespace-separated tokens.
-        // --------------------------------------------------------------------
-        // Debug output for development (only active when the `debug_header_parsing` cfg flag is
-        // enabled).
-        // For debugging purposes one might want to inspect the header string
-        // at this point. To avoid unwanted stdout noise in production use
-        // the following (commented) line and recompile as needed.
-        // println!("[DEBUG header] '{}'", header_string);
-
-        let mut parts: Vec<&str> = header_string.split_ascii_whitespace().collect();
-
-        // --------------------------------------------------------------------
-        // 3. Determine the role (if not already provided).
-        // --------------------------------------------------------------------
-        let mut role_str_opt: Option<String> = None;
-        let role = match role {
-            Some(r) => r,
-            None => {
-                let role_str = parts
-                    .first()
-                    .context("message header did not contain a role")?;
-                role_str_opt = Some((*role_str).to_string());
-                let parsed_role = Role::try_from(*role_str);
-                let out = match parsed_role {
-                    Ok(r) => r,
-                    Err(_) => {
-                        // If recipient is present, treat as tool call
-                        if parts.len() > 1 || (parts.len() == 1 && parts[0].starts_with("to=")) {
-                            parts.remove(0); // Remove the unknown role string
-                            Role::Tool
-                        } else {
-                            return Err(anyhow::anyhow!("Unknown role: {}", role_str));
-                        }
-                    }
-                };
-                out
-            }
-        };
-        // If the role was supplied externally but also redundantly present in the
-        // header itself, strip it off so that it does not interfere with the
-        // parsing of the remaining fields.
-        if let Some(first) = parts.first() {
-            if *first == role.as_str() {
-                parts.remove(0);
-            }
-        }
-
-        // --------------------------------------------------------------------
-        // 4. Identify recipient and content-type.
-        // --------------------------------------------------------------------
-        let mut recipient: Option<String> = None;
-        let mut content_type: Option<String> = None;
-
-        if !parts.is_empty() {
-            // Determine whether the last token is a content-type or part of the
-            // recipient specification.
-            let num_tokens_before_pop = parts.len();
-            let last_token_owned = parts.pop().unwrap().to_string();
-
-            if last_token_owned.starts_with("to=") {
-                // The header contains a recipient but *no* content-type.
-                recipient = Some(last_token_owned.trim_start_matches("to=").to_string());
-            } else if num_tokens_before_pop == 1 {
-                // Only one token total (after potential role removal) and it doesn't start
-                // with "to=" => interpret it as a standalone recipient.
-                recipient = Some(last_token_owned);
-            } else {
-                // More than one token and the last one is not a recipient -> treat as content-type.
-                content_type = Some(last_token_owned);
-
-                // After removing the content-type there may be exactly one token describing the recipient.
-                if !parts.is_empty() {
-                    if parts.len() != 1 {
-                        anyhow::bail!("Could not parse header: too many tokens remaining after extracting content-type and recipient");
-                    }
-                    let raw_recipient = parts.pop().unwrap();
-                    recipient = if let Some(stripped) = raw_recipient.strip_prefix("to=") {
-                        Some(stripped.to_string())
-                    } else {
-                        Some(raw_recipient.to_string())
-                    };
-                }
-            }
-        }
-
-        // After processing, no unparsed tokens should remain.
-        anyhow::ensure!(
-            parts.is_empty(),
-            "unexpected tokens remaining in message header: {:?}",
-            parts
-        );
-        // We have successfully parsed the header.
-        let author = if role == Role::Tool {
-            let name = role_str_opt;
-            Author { role, name }
-        } else {
-            Author { role, name: None }
-        };
-        Ok(ParsedHeader {
-            author,
-            recipient,
-            channel,
-            content_type,
-        })
-    }
-
-    /// Parses a message from the remaining tokens.
-    ///
-    /// Returns the message and a boolean indicating whether end of stream was reached.
-    fn parse_message(&mut self, role: Option<Role>) -> anyhow::Result<(Message, bool)> {
-        let start_token = self
-            .encoding
-            .render_formatting_token(FormattingToken::Start)?;
-        match role {
-            Some(_) => {
-                if let Some(&next) = self.tokens.peek() {
-                    if next == start_token {
-                        self.tokens.next();
-                    }
-                } else {
-                    anyhow::bail!("Expected at least one token while parsing message");
-                }
-            }
-            None => {
-                self.expect_special(FormattingToken::Start)?;
-            }
-        }
-        let header = self.parse_header(role)?;
-        let ParsedHeader {
-            author,
-            recipient,
-            channel,
-            content_type,
-        } = header;
-
-        // TODO other content types
-        // since we bail on anything other than just the role in the header for now, we can assume
-        // that the content type is text
-        let end_tokens = self.encoding.stop_tokens()?;
-        let (remaining_tokens, status) = self.take_until_any(&end_tokens);
-        let remaining_text = self
-            .encoding
-            .tokenizer
-            .decode_utf8(remaining_tokens)
-            .context("could not decode message content")?;
-        let did_reach_end_of_stream = match status {
-            TakeUntilStatus::Found => self.tokens.peek().is_none(),
-            TakeUntilStatus::EndOfStream => true,
-        };
-        Ok((
-            Message {
-                author,
-                content: vec![Content::Text(TextContent {
-                    text: remaining_text,
-                })],
-                channel,
-                recipient,
-                content_type,
-            },
-            did_reach_end_of_stream,
-        ))
-    }
-}
-
-// ---------------------------------------------------------------------------
-// Streamable parsing ---------------------------------------------------------
-// ---------------------------------------------------------------------------
-
 /// Incremental parser that can consume tokens one by one.
 ///
 /// It keeps track of all tokens seen so far, exposes all fully parsed messages
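For orientation, this is the header layout the deleted Parser used to recognize, and which the surviving StreamableParser path still implements. Illustration only; the rendered text is lifted verbatim from a test added in this commit:

// Illustration only; the example string comes from src/tests.rs in this commit.
const RENDERED: &str = concat!(
    "<|start|>assistant<|channel|>commentary to=functions.get_weather",
    "<|constrain|>json<|message|>{\"latitude\":48.8566,\"longitude\":2.3522}<|call|>"
);
// The header (everything before <|message|>) parses into:
//   author.role  -> Role::Assistant          (first whitespace-separated part)
//   channel      -> "commentary"             (value following <|channel|>)
//   recipient    -> "functions.get_weather"  ("to=" part, prefix stripped)
//   content_type -> "<|constrain|>json"      (always the last header part)
// Everything after <|message|> up to the stop token is the message content.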
@@ -1334,8 +1026,11 @@ impl StreamableParser {
     }
 
     /// Consume a single token and update the internal state.
-    pub fn process(&mut self, token: Rank) -> anyhow::Result<&mut Self> {
-        self.tokens.push(token);
+    /// Consume a single token and update the internal state.
+    fn process_next(&mut self, token: Option<Rank>) -> anyhow::Result<&mut Self> {
+        if let Some(token) = token {
+            self.tokens.push(token);
+        }
         // Clone next_role up front to avoid borrow checker issues
         let next_role_clone = self.next_role.clone();
         match &mut self.state {
@@ -1343,44 +1038,89 @@ impl StreamableParser {
                 let start = self
                     .encoding
                     .render_formatting_token(FormattingToken::Start)?;
-                if token == start {
-                    self.state = StreamState::Header {
-                        header_tokens: Vec::new(),
-                    };
-                } else {
-                    anyhow::bail!(
-                        "Unexpected token {} while expecting start token {}",
-                        token,
-                        start
-                    );
+                match token {
+                    Some(token) if token == start => {
+                        self.state = StreamState::Header {
+                            header_tokens: Vec::new(),
+                        };
+                    }
+                    Some(token) => {
+                        anyhow::bail!(
+                            "Unexpected token {} while expecting start token {}",
+                            token,
+                            start
+                        );
+                    }
+                    None => {
+                        // receiving EOS while waiting for start token is actually fine
+                        // as we may have just parsed a stop token. in this case we can
+                        // simple keep state as is
+                    }
                 }
             }
             StreamState::Header { header_tokens } => {
                 let msg_tok = self
                     .encoding
                     .render_formatting_token(FormattingToken::Message)?;
-                if token == msg_tok {
-                    // Clone the tokens and next_role, then clear the state before parsing
-                    let header_tokens_cloned = header_tokens.clone();
-                    let next_role_cloned = next_role_clone;
-                    // Set state to dummy to drop mutable borrow
-                    self.state = StreamState::ExpectStart;
-                    let header =
-                        self.parse_header_from_tokens(&header_tokens_cloned, next_role_cloned)?;
-                    self.next_role = None;
-                    self.state = StreamState::Content {
-                        header,
-                        content_tokens: Vec::new(),
-                    };
-                } else {
-                    header_tokens.push(token);
+                match token {
+                    Some(token) if token == msg_tok => {
+                        // Clone the tokens and next_role, then clear the state before parsing
+                        let header_tokens_cloned = header_tokens.clone();
+                        let next_role_cloned = next_role_clone;
+                        // Set state to dummy to drop mutable borrow
+                        self.state = StreamState::ExpectStart;
+                        let header =
+                            self.parse_header_from_tokens(&header_tokens_cloned, next_role_cloned)?;
+                        self.next_role = None;
+                        self.state = StreamState::Content {
+                            header,
+                            content_tokens: Vec::new(),
+                        };
+                    }
+                    Some(token) => {
+                        header_tokens.push(token);
+                    }
+                    None => {
+                        anyhow::bail!(
+                            "Unexpected EOS while waiting for message header to complete"
+                        );
+                    }
                 }
             }
             StreamState::Content {
                 header,
                 content_tokens,
             } => {
-                if self.stop_tokens.contains(&token) {
+                let is_eos = if let Some(token) = token {
+                    if self.stop_tokens.contains(&token) {
+                        // this is a stop token, dont parse and mark EOS
+                        true
+                    } else {
+                        self.undecoded_tokens.push(token);
+                        // some tokens might not appropriately decode on their own. If they don't
+                        // we will collect them until they eventually decode
+                        match self
+                            .encoding
+                            .tokenizer()
+                            .decode_utf8(&self.undecoded_tokens)
+                        {
+                            Ok(decoded) => {
+                                content_tokens.extend(self.undecoded_tokens.iter().copied());
+                                self.last_content_delta = Some(decoded);
+                                self.undecoded_tokens.clear();
+                            }
+                            Err(_) => {
+                                self.last_content_delta = None;
+                            }
+                        }
+                        // this was not an EOS
+                        false
+                    }
+                } else {
+                    // token = None signals EOS to this function
+                    true
+                };
+                if is_eos {
                     let text = self.encoding.tokenizer().decode_utf8(content_tokens)?;
                     let message = Message {
                         author: header.author.clone(),
@@ -1393,30 +1133,21 @@ impl StreamableParser {
                     self.state = StreamState::ExpectStart;
                     self.last_content_delta = None;
                     self.undecoded_tokens.clear();
-                } else {
-                    self.undecoded_tokens.push(token);
-                    // some tokens might not appropriately decode on their own. If they don't
-                    // we will collect them until they eventually decode
-                    match self
-                        .encoding
-                        .tokenizer()
-                        .decode_utf8(&self.undecoded_tokens)
-                    {
-                        Ok(decoded) => {
-                            content_tokens.extend(self.undecoded_tokens.iter().copied());
-                            self.last_content_delta = Some(decoded);
-                            self.undecoded_tokens.clear();
-                        }
-                        Err(_) => {
-                            self.last_content_delta = None;
-                        }
-                    }
                 }
             }
         }
         Ok(self)
     }
 
+    pub fn process(&mut self, token: Rank) -> anyhow::Result<&mut Self> {
+        self.process_next(Some(token))
+    }
+
+    pub fn process_eos(&mut self) -> anyhow::Result<&mut Self> {
+        self.process_next(None)?;
+        Ok(self)
+    }
+
     fn parse_header_from_tokens(
         &self,
         header_tokens: &[Rank],
@@ -1462,7 +1193,7 @@
         {
             if header_string.contains(constrain_marker) {
                 header_string = header_string
-                    .replace(constrain_marker, &format!(" {constrain_marker}"))
+                    .replace(constrain_marker, &format!(" {}", constrain_marker))
                     .trim()
                     .to_string();
             }
@@ -1495,8 +1226,8 @@
             }
         };
 
-        if let Some(first) = parts.first() {
-            if *first == role.as_str() {
+        if let Some(&first) = parts.first() {
+            if first == role.as_str() {
                 parts.remove(0);
             }
         }
@@ -1505,23 +1236,25 @@
         let mut content_type: Option<String> = None;
 
         if !parts.is_empty() {
-            let num_tokens_before_pop = parts.len();
-            let last_token_owned = parts.pop().unwrap().to_string();
+            // Determine whether the last token is a content-type or part of the
+            // recipient specification.
+            let num_parts = parts.len();
+            // SAFETY: we know that there is at least one part remaining, because of is_empty check above
+            let last_part = parts.pop().unwrap();
 
-            if last_token_owned.starts_with("to=") {
-                recipient = Some(last_token_owned.trim_start_matches("to=").to_string());
-            } else if num_tokens_before_pop == 1 {
-                recipient = Some(last_token_owned);
+            if let Some(stripped) = last_part.strip_prefix("to=") {
+                // The header contains a recipient but *no* content-type.
+                recipient = Some(stripped.to_string());
+            } else if num_parts == 1 {
+                // Only one part total (after potential role removal) and it doesn't start
+                // with "to=" => interpret it as a standalone recipient.
+                recipient = Some(last_part.to_string());
             } else {
-                content_type = Some(last_token_owned);
+                // More than one token and the last one is not a recipient -> treat as content-type.
+                content_type = Some(last_part.to_string());
 
-                if !parts.is_empty() {
-                    if parts.len() != 1 {
-                        anyhow::bail!(
-                            "Could not parse header: too many tokens remaining after extracting content-type and recipient"
-                        );
-                    }
-                    let raw_recipient = parts.pop().unwrap();
+                // After removing the content-type there may be exactly one token describing the recipient.
+                if let Some(raw_recipient) = parts.pop() {
                     recipient = if let Some(stripped) = raw_recipient.strip_prefix("to=") {
                         Some(stripped.to_string())
                     } else {
@@ -1530,12 +1263,12 @@
                 }
             }
         }
 
         anyhow::ensure!(
             parts.is_empty(),
             "unexpected tokens remaining in message header: {:?}",
             parts
         );
 
         let author = if role == Role::Tool {
             let name = role_str_opt;
             Author { role, name }
@@ -1583,6 +1316,11 @@
         Ok(self.last_content_delta.clone())
     }
 
+    /// Consume the parser and return all parsed messages.
+    pub fn into_messages(self) -> Vec<Message> {
+        self.messages
+    }
+
     /// All fully parsed messages so far.
     pub fn messages(&self) -> &[Message] {
         &self.messages
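Putting the pieces together, the streaming surface added here (process, process_eos, into_messages) can be driven as below. A sketch under stated assumptions: the import paths and the exact last_content_delta signature are inferred; only its Ok(self.last_content_delta.clone()) body is visible in this diff:

// Sketch only: import paths and the last_content_delta signature are assumptions.
use openai_harmony::{load_harmony_encoding, HarmonyEncodingName, StreamableParser};

fn main() -> anyhow::Result<()> {
    let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss)?;
    let text = "<|start|>assistant<|channel|>final<|message|>Hello!<|end|>";
    let tokens = encoding.tokenizer().encode_with_special_tokens(text);

    // role = None: the parser starts in ExpectStart and reads the role from the header.
    let mut parser = StreamableParser::new(encoding.clone(), None)?;
    for token in tokens {
        parser.process(token)?;
        if let Some(delta) = parser.last_content_delta()? {
            print!("{delta}"); // incremental content, once the buffered tokens decode cleanly
        }
    }
    parser.process_eos()?; // flush; per the ExpectStart comment above, safe right after a stop token
    println!("\nparsed {} message(s)", parser.into_messages().len());
    Ok(())
}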
@@ -306,6 +306,13 @@ impl PyStreamableParser {
             .map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))
     }
 
+    fn process_eos(&mut self) -> PyResult<()> {
+        self.inner
+            .process_eos()
+            .map(|_| ())
+            .map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))
+    }
+
     #[getter]
     fn current_content(&self) -> PyResult<String> {
         self.inner
src/tests.rs (36 lines changed)
@@ -663,3 +663,39 @@ fn test_streamable_parser_tool_call_with_constrain_adjacent() {
         parser.messages()[0]
     );
 }
+
+#[test]
+fn test_tool_call_with_constrain_marker_adjacent() {
+    let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();
+    let text = "<|start|>assistant to=functions.get_weather<|channel|>commentary<|constrain|>json<|message|>{\"location\": \"Tokyo\"}<|end|>";
+    let tokens = encoding.tokenizer().encode_with_special_tokens(text);
+    let parsed = encoding
+        .parse_messages_from_completion_tokens(tokens, None)
+        .expect("expected to parse");
+    let expected =
+        vec![
+            Message::from_role_and_content(Role::Assistant, "{\"location\": \"Tokyo\"}")
+                .with_channel("commentary")
+                .with_recipient("functions.get_weather")
+                .with_content_type("<|constrain|>json"),
+        ];
+    assert_eq!(parsed, expected);
+}
+
+#[test]
+fn test_tool_call_with_channel_before_recipient_and_constrain_adjacent() {
+    let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();
+    let text = "<|start|>assistant<|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{\"latitude\":48.8566,\"longitude\":2.3522}<|call|>";
+    let tokens = encoding.tokenizer().encode_with_special_tokens(text);
+    let parsed = encoding
+        .parse_messages_from_completion_tokens(tokens, None)
+        .expect("expected to parse");
+    let expected = vec![Message::from_role_and_content(
+        Role::Assistant,
+        "{\"latitude\":48.8566,\"longitude\":2.3522}",
+    )
+    .with_channel("commentary")
+    .with_recipient("functions.get_weather")
+    .with_content_type("<|constrain|>json")];
+    assert_eq!(parsed, expected);
+}
test_python.sh (new executable file, 5 lines)
@@ -0,0 +1,5 @@
+#!/usr/bin/env bash
+set -e
+source .venv/bin/activate
+maturin develop -F python-binding --release
+pytest "$@"
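Because the script ends with pytest "$@", any extra arguments pass straight through to pytest (for example, ./test_python.sh -k constrain runs only the new constrain-marker tests). Note that it assumes an already-created .venv with maturin installed.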
@@ -34,7 +34,6 @@ from openai_harmony import ( # noqa: E402
     StreamableParser,
     SystemContent,
     ToolDescription,
-    ToolNamespaceConfig,
     load_harmony_encoding,
 )
 from pydantic import ValidationError
@@ -245,23 +244,18 @@ def test_tool_call_with_constrain_marker_adjacent(encoding_name):
     correctly and instead handle it as a separate content type.
     """
     encoding = load_harmony_encoding(encoding_name)
 
     text = (
         "<|start|>assistant to=functions.get_weather<|channel|>commentary"
         '<|constrain|>json<|message|>{"location": "Tokyo"}<|end|>'
     )
 
     tokens = encoding.encode(text, allowed_special="all")
-    parsed = encoding.parse_messages_from_completion_tokens(tokens, role=Role.ASSISTANT)
+    parsed = encoding.parse_messages_from_completion_tokens(tokens, role=None)
 
     expected = [
         Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}')
         .with_channel("commentary")
         .with_recipient("functions.get_weather")
         .with_content_type("<|constrain|>json"),
     ]
 
     assert parsed == expected
 
 
@@ -280,11 +274,8 @@ def test_tool_call_with_channel_before_recipient_and_constrain_adjacent(
         "<|start|>assistant<|channel|>commentary to=functions.get_weather"
         '<|constrain|>json<|message|>{"latitude":48.8566,"longitude":2.3522}<|call|>'
     )
 
     tokens = encoding.encode(text, allowed_special="all")
-    parsed = encoding.parse_messages_from_completion_tokens(tokens, role=Role.ASSISTANT)
+    parsed = encoding.parse_messages_from_completion_tokens(tokens, role=None)
 
     expected = [
         Message.from_role_and_content(
             Role.ASSISTANT, '{"latitude":48.8566,"longitude":2.3522}'