diff --git a/pyproject.toml b/pyproject.toml index 251e497..ccf3dfc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,3 +20,7 @@ demo = ["uvicorn", "fastapi"] features = ["pyo3/extension-module"] module-name = "openai_harmony" python-source = "python" + +[tool.pytest.ini_options] +# Only collect tests from the top-level tests directory +testpaths = ["tests"] diff --git a/python/openai_harmony/__init__.py b/python/openai_harmony/__init__.py index e763af1..3485864 100644 --- a/python/openai_harmony/__init__.py +++ b/python/openai_harmony/__init__.py @@ -29,7 +29,7 @@ from typing import ( ) import re -from pydantic import BaseModel, Field, root_validator, validator +from pydantic import BaseModel, Field # Re-export the low-level Rust bindings under a private name so that we can # keep the *public* namespace clean and purely Pythonic. @@ -612,6 +612,10 @@ class StreamableParser: self._inner.process(token) return self + def process_eos(self) -> "StreamableParser": + self._inner.process_eos() + return self + @property def current_content(self) -> str: return self._inner.current_content diff --git a/src/encoding.rs b/src/encoding.rs index 75b5cf5..f5d0378 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -368,26 +368,17 @@ impl HarmonyEncoding { pub fn parse_messages_from_completion_tokens( &self, tokens: I, - mut role: Option, + role: Option, ) -> anyhow::Result> where I: IntoIterator, { - let mut messages = Vec::::new(); - let mut parser = Parser { - encoding: self, - tokens: tokens.into_iter().peekable(), - }; - loop { - let (message, did_reach_end_of_stream) = parser.parse_message(role)?; - messages.push(message); - role = None; - if did_reach_end_of_stream { - break; - } + let mut parser = StreamableParser::new(self.clone(), role)?; + for token in tokens { + parser.process(token)?; } - anyhow::ensure!(parser.tokens.next().is_none(), "Expected end of stream"); - Ok(messages) + parser.process_eos()?; + Ok(parser.into_messages()) } /// Helper to convert a JSON 
schema (OpenAPI style) to a TypeScript type definition. @@ -982,305 +973,6 @@ impl Render for HarmonyEncoding { } } -enum TakeUntilStatus { - Found, - EndOfStream, -} - -impl TakeUntilStatus { - fn was_found(&self) -> bool { - matches!(self, TakeUntilStatus::Found) - } -} - -struct Parser<'a, I> -where - I: Iterator, -{ - tokens: std::iter::Peekable, - encoding: &'a HarmonyEncoding, -} - -impl Parser<'_, I> -where - I: Iterator, -{ - fn expect_special(&mut self, token: FormattingToken) -> anyhow::Result { - let next = self.tokens.next().context(format!( - "Expected special token ({}), but out of tokens", - token - ))?; - let expected = self.encoding.render_formatting_token(token)?; - if next != expected { - anyhow::bail!( - "Expected special token ({}) {} but got {}", - token, - expected, - next, - ); - } - Ok(next) - } - - fn take_until_any(&mut self, ends: &HashSet) -> (Vec, TakeUntilStatus) { - let mut out = vec![]; - for t in &mut self.tokens { - if ends.contains(&t) { - return (out, TakeUntilStatus::Found); - } - out.push(t); - } - (out, TakeUntilStatus::EndOfStream) - } - - fn take_until(&mut self, end: Rank) -> (Vec, TakeUntilStatus) { - self.take_until_any(&HashSet::from([end])) - } - - fn parse_header(&mut self, role: Option) -> anyhow::Result { - // FormattingToken::Message marks the end of the header. - // Everything before that belongs to the header. - let message_start_token = self - .encoding - .render_formatting_token(FormattingToken::Message)?; - - let (header_tokens, status) = self.take_until(message_start_token); - if !status.was_found() { - anyhow::bail!("Expected message start token but ran out of tokens"); - } - - // Decode the header into a UTF-8 string so we can reason about its structure. - let mut header_string = self - .encoding - .tokenizer - .decode_utf8(header_tokens) - .context("could not decode header")?; - - // -------------------------------------------------------------------- - // 1. 
Extract the channel (if any) - // -------------------------------------------------------------------- - // A channel, when present, is encoded as: - // <|channel|>CHANNEL_VALUE - // where <|channel|> is the literal rendering of FormattingToken::Channel - // and CHANNEL_VALUE is a contiguous string (no whitespace) naming the - // channel. The <|channel|> marker can appear before or after the - // recipient part, but *always* before the optional content-type (which - // must be last). - let mut channel: Option = None; - if let Some(channel_marker) = self.encoding.mapped_format_token(FormattingToken::Channel) { - if let Some(idx) = header_string.find(channel_marker) { - // Slice parts around the marker - let after_marker = &header_string[idx + channel_marker.len()..]; - - // The channel value continues until the next ASCII whitespace, - // the start of another special token ("<"), or end-of-string. - let channel_end = after_marker - .find(|c: char| c.is_whitespace() || c == '<') - .unwrap_or(after_marker.len()); - let channel_value = &after_marker[..channel_end]; - if channel_value.is_empty() { - anyhow::bail!("channel marker present but no channel value found in header"); - } - channel = Some(channel_value.to_string()); - - // Remove the marker *and* the channel value from the header so - // the remaining pieces can be parsed independently. - let mut new_header = String::new(); - new_header.push_str(&header_string[..idx]); - new_header.push_str(&after_marker[channel_end..]); - header_string = new_header; - } - } - - // Trim extraneous whitespace that may have been introduced when we - // removed the channel section. - header_string = header_string.trim().to_string(); - - // If the constrained format marker is present but not preceded by - // whitespace (e.g. "to=foo<|constrain|>json"), insert a space before - // the marker so that splitting on whitespace treats the content type - // as a separate token. 
- if let Some(constrain_marker) = self - .encoding - .mapped_format_token(FormattingToken::ConstrainedFormat) - { - if header_string.contains(constrain_marker) { - header_string = header_string - .replace(constrain_marker, &format!(" {}", constrain_marker)) - .trim() - .to_string(); - } - } - - // -------------------------------------------------------------------- - // 2. Split the remaining header into whitespace-separated tokens. - // -------------------------------------------------------------------- - // Debug output for development (only active when the `debug_header_parsing` cfg flag is - // enabled). - // For debugging purposes one might want to inspect the header string - // at this point. To avoid unwanted stdout noise in production use - // the following (commented) line and recompile as needed. - // println!("[DEBUG header] '{}'", header_string); - - let mut parts: Vec<&str> = header_string.split_ascii_whitespace().collect(); - - // -------------------------------------------------------------------- - // 3. Determine the role (if not already provided). - // -------------------------------------------------------------------- - let mut role_str_opt: Option = None; - let role = match role { - Some(r) => r, - None => { - let role_str = parts - .first() - .context("message header did not contain a role")?; - role_str_opt = Some((*role_str).to_string()); - let parsed_role = Role::try_from(*role_str); - let out = match parsed_role { - Ok(r) => r, - Err(_) => { - // If recipient is present, treat as tool call - if parts.len() > 1 || (parts.len() == 1 && parts[0].starts_with("to=")) { - parts.remove(0); // Remove the unknown role string - Role::Tool - } else { - return Err(anyhow::anyhow!("Unknown role: {}", role_str)); - } - } - }; - out - } - }; - // If the role was supplied externally but also redundantly present in the - // header itself, strip it off so that it does not interfere with the - // parsing of the remaining fields. 
- if let Some(first) = parts.first() { - if *first == role.as_str() { - parts.remove(0); - } - } - - // -------------------------------------------------------------------- - // 4. Identify recipient and content-type. - // -------------------------------------------------------------------- - let mut recipient: Option = None; - let mut content_type: Option = None; - - if !parts.is_empty() { - // Determine whether the last token is a content-type or part of the - // recipient specification. - let num_tokens_before_pop = parts.len(); - let last_token_owned = parts.pop().unwrap().to_string(); - - if last_token_owned.starts_with("to=") { - // The header contains a recipient but *no* content-type. - recipient = Some(last_token_owned.trim_start_matches("to=").to_string()); - } else if num_tokens_before_pop == 1 { - // Only one token total (after potential role removal) and it doesn't start - // with "to=" => interpret it as a standalone recipient. - recipient = Some(last_token_owned); - } else { - // More than one token and the last one is not a recipient -> treat as content-type. - content_type = Some(last_token_owned); - - // After removing the content-type there may be exactly one token describing the recipient. - if !parts.is_empty() { - if parts.len() != 1 { - anyhow::bail!("Could not parse header: too many tokens remaining after extracting content-type and recipient"); - } - let raw_recipient = parts.pop().unwrap(); - recipient = if let Some(stripped) = raw_recipient.strip_prefix("to=") { - Some(stripped.to_string()) - } else { - Some(raw_recipient.to_string()) - }; - } - } - } - - // After processing, no unparsed tokens should remain. - anyhow::ensure!( - parts.is_empty(), - "unexpected tokens remaining in message header: {:?}", - parts - ); - // We have successfully parsed the header. 
- let author = if role == Role::Tool { - let name = role_str_opt; - Author { role, name } - } else { - Author { role, name: None } - }; - Ok(ParsedHeader { - author, - recipient, - channel, - content_type, - }) - } - - /// Parses a message from the remaining tokens. - /// - /// Returns the message and a boolean indicating whether end of stream was reached. - fn parse_message(&mut self, role: Option) -> anyhow::Result<(Message, bool)> { - let start_token = self - .encoding - .render_formatting_token(FormattingToken::Start)?; - match role { - Some(_) => { - if let Some(&next) = self.tokens.peek() { - if next == start_token { - self.tokens.next(); - } - } else { - anyhow::bail!("Expected at least one token while parsing message"); - } - } - None => { - self.expect_special(FormattingToken::Start)?; - } - } - let header = self.parse_header(role)?; - let ParsedHeader { - author, - recipient, - channel, - content_type, - } = header; - - // TODO other content types - // since we bail on anything other than just the role in the header for now, we can assume - // that the content type is text - let end_tokens = self.encoding.stop_tokens()?; - let (remaining_tokens, status) = self.take_until_any(&end_tokens); - let remaining_text = self - .encoding - .tokenizer - .decode_utf8(remaining_tokens) - .context("could not decode message content")?; - let did_reach_end_of_stream = match status { - TakeUntilStatus::Found => self.tokens.peek().is_none(), - TakeUntilStatus::EndOfStream => true, - }; - Ok(( - Message { - author, - content: vec![Content::Text(TextContent { - text: remaining_text, - })], - channel, - recipient, - content_type, - }, - did_reach_end_of_stream, - )) - } -} - -// --------------------------------------------------------------------------- -// Streamable parsing --------------------------------------------------------- -// --------------------------------------------------------------------------- - /// Incremental parser that can consume tokens one by one. 
/// /// It keeps track of all tokens seen so far, exposes all fully parsed messages @@ -1334,8 +1026,11 @@ impl StreamableParser { } /// Consume a single token and update the internal state. - pub fn process(&mut self, token: Rank) -> anyhow::Result<&mut Self> { - self.tokens.push(token); + /// Shared implementation for `process` and `process_eos`; a `None` token signals end-of-stream. + fn process_next(&mut self, token: Option) -> anyhow::Result<&mut Self> { + if let Some(token) = token { + self.tokens.push(token); + } // Clone next_role up front to avoid borrow checker issues let next_role_clone = self.next_role.clone(); match &mut self.state { @@ -1343,44 +1038,89 @@ let start = self .encoding .render_formatting_token(FormattingToken::Start)?; - if token == start { - self.state = StreamState::Header { - header_tokens: Vec::new(), - }; - } else { - anyhow::bail!( - "Unexpected token {} while expecting start token {}", - token, - start - ); + match token { + Some(token) if token == start => { + self.state = StreamState::Header { + header_tokens: Vec::new(), + }; + } + Some(token) => { + anyhow::bail!( + "Unexpected token {} while expecting start token {}", + token, + start + ); + } + None => { + // receiving EOS while waiting for start token is actually fine + // as we may have just parsed a stop token. 
in this case we can + // simply keep state as is } } } StreamState::Header { header_tokens } => { let msg_tok = self .encoding .render_formatting_token(FormattingToken::Message)?; - if token == msg_tok { - // Clone the tokens and next_role, then clear the state before parsing - let header_tokens_cloned = header_tokens.clone(); - let next_role_cloned = next_role_clone; - // Set state to dummy to drop mutable borrow - self.state = StreamState::ExpectStart; - let header = - self.parse_header_from_tokens(&header_tokens_cloned, next_role_cloned)?; - self.next_role = None; - self.state = StreamState::Content { - header, - content_tokens: Vec::new(), - }; - } else { - header_tokens.push(token); + match token { + Some(token) if token == msg_tok => { + // Clone the tokens and next_role, then clear the state before parsing + let header_tokens_cloned = header_tokens.clone(); + let next_role_cloned = next_role_clone; + // Set state to dummy to drop mutable borrow + self.state = StreamState::ExpectStart; + let header = + self.parse_header_from_tokens(&header_tokens_cloned, next_role_cloned)?; + self.next_role = None; + self.state = StreamState::Content { + header, + content_tokens: Vec::new(), + }; + } + Some(token) => { + header_tokens.push(token); + } + None => { + anyhow::bail!( + "Unexpected EOS while waiting for message header to complete" + ); + } } } StreamState::Content { header, content_tokens, } => { - if self.stop_tokens.contains(&token) { + let is_eos = if let Some(token) = token { + if self.stop_tokens.contains(&token) { + // this is a stop token, don't parse and mark EOS + true + } else { + self.undecoded_tokens.push(token); + // some tokens might not appropriately decode on their own.
If they don't + // we will collect them until they eventually decode + match self + .encoding + .tokenizer() + .decode_utf8(&self.undecoded_tokens) + { + Ok(decoded) => { + content_tokens.extend(self.undecoded_tokens.iter().copied()); + self.last_content_delta = Some(decoded); + self.undecoded_tokens.clear(); + } + Err(_) => { + self.last_content_delta = None; + } + } + // this was not an EOS + false + } + } else { + // token = None signals EOS to this function + true + }; + if is_eos { let text = self.encoding.tokenizer().decode_utf8(content_tokens)?; let message = Message { author: header.author.clone(), @@ -1393,30 +1133,21 @@ impl StreamableParser { self.state = StreamState::ExpectStart; self.last_content_delta = None; self.undecoded_tokens.clear(); - } else { - self.undecoded_tokens.push(token); - // some tokens might not appropriately decode on their own. If they don't - // we will collect them until they eventually decode - match self - .encoding - .tokenizer() - .decode_utf8(&self.undecoded_tokens) - { - Ok(decoded) => { - content_tokens.extend(self.undecoded_tokens.iter().copied()); - self.last_content_delta = Some(decoded); - self.undecoded_tokens.clear(); - } - Err(_) => { - self.last_content_delta = None; - } - } } } } Ok(self) } + pub fn process(&mut self, token: Rank) -> anyhow::Result<&mut Self> { + self.process_next(Some(token)) + } + + pub fn process_eos(&mut self) -> anyhow::Result<&mut Self> { + self.process_next(None)?; + Ok(self) + } + fn parse_header_from_tokens( &self, header_tokens: &[Rank], @@ -1462,7 +1193,7 @@ impl StreamableParser { { if header_string.contains(constrain_marker) { header_string = header_string - .replace(constrain_marker, &format!(" {constrain_marker}")) + .replace(constrain_marker, &format!(" {}", constrain_marker)) .trim() .to_string(); } @@ -1495,8 +1226,8 @@ impl StreamableParser { } }; - if let Some(first) = parts.first() { - if *first == role.as_str() { + if let Some(&first) = parts.first() { + if first == 
role.as_str() { parts.remove(0); } } @@ -1505,23 +1236,25 @@ impl StreamableParser { let mut content_type: Option = None; if !parts.is_empty() { - let num_tokens_before_pop = parts.len(); - let last_token_owned = parts.pop().unwrap().to_string(); + // Determine whether the last token is a content-type or part of the + // recipient specification. + let num_parts = parts.len(); + // SAFETY: we know that there is at least one part remaining, because of is_empty check above + let last_part = parts.pop().unwrap(); - if last_token_owned.starts_with("to=") { - recipient = Some(last_token_owned.trim_start_matches("to=").to_string()); - } else if num_tokens_before_pop == 1 { - recipient = Some(last_token_owned); + if let Some(stripped) = last_part.strip_prefix("to=") { + // The header contains a recipient but *no* content-type. + recipient = Some(stripped.to_string()); + } else if num_parts == 1 { + // Only one part total (after potential role removal) and it doesn't start + // with "to=" => interpret it as a standalone recipient. + recipient = Some(last_part.to_string()); } else { - content_type = Some(last_token_owned); + // More than one token and the last one is not a recipient -> treat as content-type. + content_type = Some(last_part.to_string()); - if !parts.is_empty() { - if parts.len() != 1 { - anyhow::bail!( - "Could not parse header: too many tokens remaining after extracting content-type and recipient" - ); - } - let raw_recipient = parts.pop().unwrap(); + // After removing the content-type there may be exactly one token describing the recipient. 
+ if let Some(raw_recipient) = parts.pop() { recipient = if let Some(stripped) = raw_recipient.strip_prefix("to=") { Some(stripped.to_string()) } else { @@ -1530,12 +1263,12 @@ impl StreamableParser { } } } - anyhow::ensure!( parts.is_empty(), "unexpected tokens remaining in message header: {:?}", parts ); + let author = if role == Role::Tool { let name = role_str_opt; Author { role, name } @@ -1583,6 +1316,11 @@ impl StreamableParser { Ok(self.last_content_delta.clone()) } + /// Consume the parser and return all parsed messages. + pub fn into_messages(self) -> Vec { + self.messages + } + /// All fully parsed messages so far. pub fn messages(&self) -> &[Message] { &self.messages diff --git a/src/py_module.rs b/src/py_module.rs index fe68b3c..fbb3129 100644 --- a/src/py_module.rs +++ b/src/py_module.rs @@ -306,6 +306,13 @@ impl PyStreamableParser { .map_err(|e| PyErr::new::(e.to_string())) } + fn process_eos(&mut self) -> PyResult<()> { + self.inner + .process_eos() + .map(|_| ()) + .map_err(|e| PyErr::new::(e.to_string())) + } + #[getter] fn current_content(&self) -> PyResult { self.inner diff --git a/src/tests.rs b/src/tests.rs index fe74f49..2f0e117 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -663,3 +663,39 @@ fn test_streamable_parser_tool_call_with_constrain_adjacent() { parser.messages()[0] ); } + +#[test] +fn test_tool_call_with_constrain_marker_adjacent() { + let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap(); + let text = "<|start|>assistant to=functions.get_weather<|channel|>commentary<|constrain|>json<|message|>{\"location\": \"Tokyo\"}<|end|>"; + let tokens = encoding.tokenizer().encode_with_special_tokens(text); + let parsed = encoding + .parse_messages_from_completion_tokens(tokens, None) + .expect("expected to parse"); + let expected = + vec![ + Message::from_role_and_content(Role::Assistant, "{\"location\": \"Tokyo\"}") + .with_channel("commentary") + .with_recipient("functions.get_weather") + 
.with_content_type("<|constrain|>json"), + ]; + assert_eq!(parsed, expected); +} + +#[test] +fn test_tool_call_with_channel_before_recipient_and_constrain_adjacent() { + let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap(); + let text = "<|start|>assistant<|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{\"latitude\":48.8566,\"longitude\":2.3522}<|call|>"; + let tokens = encoding.tokenizer().encode_with_special_tokens(text); + let parsed = encoding + .parse_messages_from_completion_tokens(tokens, None) + .expect("expected to parse"); + let expected = vec![Message::from_role_and_content( + Role::Assistant, + "{\"latitude\":48.8566,\"longitude\":2.3522}", + ) + .with_channel("commentary") + .with_recipient("functions.get_weather") + .with_content_type("<|constrain|>json")]; + assert_eq!(parsed, expected); +} diff --git a/test_python.sh b/test_python.sh new file mode 100755 index 0000000..09a24de --- /dev/null +++ b/test_python.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash +set -e +source .venv/bin/activate +maturin develop -F python-binding --release +pytest "$@" diff --git a/tests/test_harmony.py b/tests/test_harmony.py index 2b00672..8392d7f 100644 --- a/tests/test_harmony.py +++ b/tests/test_harmony.py @@ -34,7 +34,6 @@ from openai_harmony import ( # noqa: E402 StreamableParser, SystemContent, ToolDescription, - ToolNamespaceConfig, load_harmony_encoding, ) from pydantic import ValidationError @@ -245,23 +244,18 @@ def test_tool_call_with_constrain_marker_adjacent(encoding_name): correctly and instead handle it as a separate content type. 
""" encoding = load_harmony_encoding(encoding_name) - text = ( "<|start|>assistant to=functions.get_weather<|channel|>commentary" '<|constrain|>json<|message|>{"location": "Tokyo"}<|end|>' ) - tokens = encoding.encode(text, allowed_special="all") - - parsed = encoding.parse_messages_from_completion_tokens(tokens, role=Role.ASSISTANT) - + parsed = encoding.parse_messages_from_completion_tokens(tokens, role=None) expected = [ Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}') .with_channel("commentary") .with_recipient("functions.get_weather") .with_content_type("<|constrain|>json"), ] - assert parsed == expected @@ -280,11 +274,8 @@ def test_tool_call_with_channel_before_recipient_and_constrain_adjacent( "<|start|>assistant<|channel|>commentary to=functions.get_weather" '<|constrain|>json<|message|>{"latitude":48.8566,"longitude":2.3522}<|call|>' ) - tokens = encoding.encode(text, allowed_special="all") - - parsed = encoding.parse_messages_from_completion_tokens(tokens, role=Role.ASSISTANT) - + parsed = encoding.parse_messages_from_completion_tokens(tokens, role=None) expected = [ Message.from_role_and_content( Role.ASSISTANT, '{"latitude":48.8566,"longitude":2.3522}'