From 8a4645f0f9d9c20a8ecd3381dbb88ce5f56084cc Mon Sep 17 00:00:00 2001 From: Dmytro Dzhulgakov Date: Fri, 8 Aug 2025 17:50:17 -0700 Subject: [PATCH] Fix tokenization of <|constrain|> content type in rendering (#47) --- src/encoding.rs | 17 ++++++++++++++++- tests/test_harmony.py | 34 +++++++++++++++++++++++++++++++++- 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/src/encoding.rs b/src/encoding.rs index afe1fce..d57f8ec 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -835,7 +835,22 @@ impl Render for HarmonyEncoding { // finally content type if let Some(content_type) = &message.content_type { - self.render_text_into(format!(" {content_type}"), into)?; + // <|constrain|> is a unique case which needs to be tokenized as a special token + if let Some(constrain_marker) = self.mapped_format_token(FormattingToken::ConstrainedFormat) { + if content_type.starts_with(constrain_marker) { + // Render the space, then the constrain marker as a special token, then the rest as text (if any) + self.render_text_into(" ", into)?; + self.render_formatting_token_into(FormattingToken::ConstrainedFormat, into)?; + let rest = &content_type[constrain_marker.len()..]; + if !rest.is_empty() { + self.render_text_into(rest, into)?; + } + } else { + self.render_text_into(format!(" {content_type}"), into)?; + } + } else { + self.render_text_into(format!(" {content_type}"), into)?; + } } self.render_formatting_token_into(FormattingToken::Message, into)?; diff --git a/tests/test_harmony.py b/tests/test_harmony.py index 07d5562..dd34e81 100644 --- a/tests/test_harmony.py +++ b/tests/test_harmony.py @@ -233,6 +233,36 @@ def test_simple_tool_call(encoding_name): assert parsed == expected +@pytest.mark.parametrize( + "encoding_name", + [ + HarmonyEncodingName.HARMONY_GPT_OSS, + ], +) +def test_tool_call_with_constrain_tokenized_correctly(encoding_name): + """ + Despite passing <|constrain|> as a string in "content_type" it has to be kept as a special token. + """ + encoding = load_harmony_encoding(encoding_name) + text = ( + "<|start|>assistant to=functions.get_weather<|channel|>commentary" + ' <|constrain|>json<|message|>{"location": "Tokyo"}<|call|>' + ) + tokens = encoding.encode(text, allowed_special="all") + parsed = encoding.parse_messages_from_completion_tokens(tokens, role=None) + expected = [ + Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}') + .with_channel("commentary") + .with_recipient("functions.get_weather") + .with_content_type("<|constrain|>json"), + ] + assert parsed == expected + + rendered = encoding.render_conversation(Conversation.from_messages(expected)) + assert text == encoding.decode_utf8(tokens) + assert rendered == tokens + + @pytest.mark.parametrize( "encoding_name", [ @@ -248,7 +278,7 @@ def test_tool_call_with_constrain_marker_adjacent(encoding_name): encoding = load_harmony_encoding(encoding_name) text = ( "<|start|>assistant to=functions.get_weather<|channel|>commentary" - '<|constrain|>json<|message|>{"location": "Tokyo"}<|end|>' + '<|constrain|>json<|message|>{"location": "Tokyo"}<|call|>' ) tokens = encoding.encode(text, allowed_special="all") parsed = encoding.parse_messages_from_completion_tokens(tokens, role=None) @@ -702,6 +732,8 @@ def test_does_not_drop_if_ongoing_analysis(): ) assert encoding.decode_utf8(tokens) == expected_output + # ensure that <|constrain|>json part is tokenized correctly as special tokens + assert encoding.encode(expected_output, allowed_special="all") == tokens def test_preserve_cot():