mirror of
https://github.com/openai/harmony.git
synced 2025-08-26 10:17:09 -04:00
Fix tokenization of <|constrain|> content type in rendering
This commit is contained in:
parent
3efbf74253
commit
7285cafe67
2 changed files with 49 additions and 2 deletions
|
@ -835,7 +835,22 @@ impl Render<Message> for HarmonyEncoding {
|
||||||
|
|
||||||
// finally content type
|
// finally content type
|
||||||
if let Some(content_type) = &message.content_type {
|
if let Some(content_type) = &message.content_type {
|
||||||
self.render_text_into(format!(" {content_type}"), into)?;
|
// <|constrain|> is a unique case which needs to be tokenized as a special token
|
||||||
|
if let Some(constrain_marker) = self.mapped_format_token(FormattingToken::ConstrainedFormat) {
|
||||||
|
if content_type.starts_with(constrain_marker) {
|
||||||
|
// Render the space, then the constrain marker as a special token, then the rest as text (if any)
|
||||||
|
self.render_text_into(" ", into)?;
|
||||||
|
self.render_formatting_token_into(FormattingToken::ConstrainedFormat, into)?;
|
||||||
|
let rest = &content_type[constrain_marker.len()..];
|
||||||
|
if !rest.is_empty() {
|
||||||
|
self.render_text_into(rest, into)?;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
self.render_text_into(format!(" {content_type}"), into)?;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
self.render_text_into(format!(" {content_type}"), into)?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
self.render_formatting_token_into(FormattingToken::Message, into)?;
|
self.render_formatting_token_into(FormattingToken::Message, into)?;
|
||||||
|
|
|
@ -233,6 +233,36 @@ def test_simple_tool_call(encoding_name):
|
||||||
assert parsed == expected
|
assert parsed == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"encoding_name",
|
||||||
|
[
|
||||||
|
HarmonyEncodingName.HARMONY_GPT_OSS,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_tool_call_with_constrain_tokenized_correctly(encoding_name):
|
||||||
|
"""
|
||||||
|
Despite passing <|constrain|> as a string in "content_type" it has to be kept as a special token.
|
||||||
|
"""
|
||||||
|
encoding = load_harmony_encoding(encoding_name)
|
||||||
|
text = (
|
||||||
|
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
|
||||||
|
' <|constrain|>json<|message|>{"location": "Tokyo"}<|call|>'
|
||||||
|
)
|
||||||
|
tokens = encoding.encode(text, allowed_special="all")
|
||||||
|
parsed = encoding.parse_messages_from_completion_tokens(tokens, role=None)
|
||||||
|
expected = [
|
||||||
|
Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}')
|
||||||
|
.with_channel("commentary")
|
||||||
|
.with_recipient("functions.get_weather")
|
||||||
|
.with_content_type("<|constrain|>json"),
|
||||||
|
]
|
||||||
|
assert parsed == expected
|
||||||
|
|
||||||
|
rendered = encoding.render_conversation(Conversation.from_messages(expected))
|
||||||
|
assert text == encoding.decode_utf8(tokens)
|
||||||
|
assert rendered == tokens
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"encoding_name",
|
"encoding_name",
|
||||||
[
|
[
|
||||||
|
@ -248,7 +278,7 @@ def test_tool_call_with_constrain_marker_adjacent(encoding_name):
|
||||||
encoding = load_harmony_encoding(encoding_name)
|
encoding = load_harmony_encoding(encoding_name)
|
||||||
text = (
|
text = (
|
||||||
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
|
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
|
||||||
'<|constrain|>json<|message|>{"location": "Tokyo"}<|end|>'
|
'<|constrain|>json<|message|>{"location": "Tokyo"}<|call|>'
|
||||||
)
|
)
|
||||||
tokens = encoding.encode(text, allowed_special="all")
|
tokens = encoding.encode(text, allowed_special="all")
|
||||||
parsed = encoding.parse_messages_from_completion_tokens(tokens, role=None)
|
parsed = encoding.parse_messages_from_completion_tokens(tokens, role=None)
|
||||||
|
@ -702,6 +732,8 @@ def test_does_not_drop_if_ongoing_analysis():
|
||||||
)
|
)
|
||||||
|
|
||||||
assert encoding.decode_utf8(tokens) == expected_output
|
assert encoding.decode_utf8(tokens) == expected_output
|
||||||
|
# ensure that <|constrain|>json part is tokenized correctly as special tokens
|
||||||
|
assert encoding.encode(expected_output, allowed_special="all") == tokens
|
||||||
|
|
||||||
|
|
||||||
def test_preserve_cot():
|
def test_preserve_cot():
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue