Initial commit

Co-authored-by: scott-oai <142930063+scott-oai@users.noreply.github.com>
Co-authored-by: Zhuohan Li <zhuohan@openai.com>
This commit is contained in:
Dominik Kundel 2025-08-05 08:25:17 -07:00 committed by Scott Lessans
commit 253cdca537
70 changed files with 15013 additions and 0 deletions

View file

@ -0,0 +1,702 @@
"""Python wrapper around the Rust implementation of *harmony*.
The heavy lifting (tokenisation, rendering, parsing, ) is implemented in
Rust. The thin bindings are available through the private ``openai_harmony``
extension module which is compiled via *maturin* / *PyO3*.
This package provides a small, typed convenience layer that mirrors the public
API of the Rust crate so that it can be used from Python code in an
idiomatic way (``dataclasses``, ``Enum``s, ).
"""
from __future__ import annotations
import functools
import json
from enum import Enum
from typing import (
AbstractSet,
Any,
Collection,
Dict,
List,
Literal,
Optional,
Pattern,
Sequence,
TypeVar,
Union,
)
import re
from pydantic import BaseModel, Field, root_validator, validator
# Re-export the low-level Rust bindings under a private name so that we can
# keep the *public* namespace clean and purely Pythonic.
try:
from .openai_harmony import (
HarmonyError as HarmonyError, # expose the actual Rust error directly
)
from .openai_harmony import PyHarmonyEncoding as _PyHarmonyEncoding # type: ignore
from .openai_harmony import (
PyStreamableParser as _PyStreamableParser, # type: ignore
)
from .openai_harmony import (
load_harmony_encoding as _load_harmony_encoding, # type: ignore
)
except ModuleNotFoundError: # pragma: no cover raised during type-checking
# When running *mypy* without the compiled extension in place we still want
# to succeed. Therefore we create dummy stubs that satisfy the type
# checker. They will, however, raise at **runtime** if accessed.
class _Stub: # pylint: disable=too-few-public-methods
def __getattr__(self, name: str) -> None: # noqa: D401
raise RuntimeError(
"The compiled harmony bindings are not available. Make sure to "
"build the project with `maturin develop` before running this "
"code."
)
_load_harmony_encoding = _Stub() # type: ignore
_PyHarmonyEncoding = _Stub() # type: ignore
_PyStreamableParser = _Stub() # type: ignore
_HarmonyError = RuntimeError
def _special_token_regex(tokens: frozenset[str]) -> Pattern[str]:
inner = "|".join(re.escape(token) for token in tokens)
return re.compile(f"({inner})")
def raise_disallowed_special_token(token: str) -> None:
raise HarmonyError(
"Encountered text corresponding to disallowed special token "
f"{token!r}.\n"
"If you want this text to be encoded as a special token, "
f"pass it to `allowed_special`, e.g. `allowed_special={{'{token}', ...}}`.\n"
"If you want this text to be encoded as normal text, disable the check for this token "
f"by passing `disallowed_special=(enc.special_tokens_set - {{'{token}'}})`.\n"
"To disable this check for all special tokens, pass `disallowed_special=()`.\n"
)
# ---------------------------------------------------------------------------
# Chat-related data-structures (mirroring ``src/chat.rs``)
# ---------------------------------------------------------------------------
class Role(str, Enum):
"""The role of a message author (mirrors ``chat::Role``)."""
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"
DEVELOPER = "developer"
TOOL = "tool"
@classmethod
def _missing_(cls, value: object) -> "Role": # type: ignore[override]
raise ValueError(f"Unknown role: {value!r}")
class Author(BaseModel):
role: Role
name: Optional[str] = None
@classmethod
def new(cls, role: Role, name: str) -> "Author": # noqa: D401 keep parity with Rust API
return cls(role=role, name=name)
# Content hierarchy ---------------------------------------------------------
T = TypeVar("T")
class Content(BaseModel): # noqa: D101 simple wrapper
def to_dict(self) -> Dict[str, Any]:
raise NotImplementedError
class TextContent(Content):
text: str
def to_dict(self) -> Dict[str, Any]:
return {"type": "text", "text": self.text}
class ToolDescription(BaseModel):
name: str
description: str
parameters: Optional[dict] = None
@classmethod
def new(
cls, name: str, description: str, parameters: Optional[dict] = None
) -> "ToolDescription": # noqa: D401
return cls(name=name, description=description, parameters=parameters)
class ReasoningEffort(str, Enum):
LOW = "Low"
MEDIUM = "Medium"
HIGH = "High"
class ChannelConfig(BaseModel):
valid_channels: List[str]
channel_required: bool
@classmethod
def require_channels(cls, channels: List[str]) -> "ChannelConfig": # noqa: D401
return cls(valid_channels=channels, channel_required=True)
class ToolNamespaceConfig(BaseModel):
name: str
description: Optional[str] = None
tools: List[ToolDescription]
@staticmethod
def browser() -> "ToolNamespaceConfig":
from .openai_harmony import (
get_tool_namespace_config as _get_tool_namespace_config,
)
cfg = _get_tool_namespace_config("browser")
return ToolNamespaceConfig(**cfg)
@staticmethod
def python() -> "ToolNamespaceConfig":
from .openai_harmony import (
get_tool_namespace_config as _get_tool_namespace_config,
)
cfg = _get_tool_namespace_config("python")
return ToolNamespaceConfig(**cfg)
class SystemContent(Content):
model_identity: Optional[str] = (
"You are ChatGPT, a large language model trained by OpenAI."
)
reasoning_effort: Optional[ReasoningEffort] = ReasoningEffort.MEDIUM
conversation_start_date: Optional[str] = None
knowledge_cutoff: Optional[str] = "2024-06"
channel_config: Optional[ChannelConfig] = Field(
default_factory=lambda: ChannelConfig.require_channels(
["analysis", "commentary", "final"]
)
)
tools: Optional[dict[str, ToolNamespaceConfig]] = None
@classmethod
def new(cls) -> "SystemContent":
return cls()
# Fluent interface ------------------------------------------------------
def with_model_identity(self, model_identity: str) -> "SystemContent":
self.model_identity = model_identity
return self
def with_reasoning_effort(
self, reasoning_effort: ReasoningEffort
) -> "SystemContent":
self.reasoning_effort = reasoning_effort
return self
def with_conversation_start_date(
self, conversation_start_date: str
) -> "SystemContent":
self.conversation_start_date = conversation_start_date
return self
def with_knowledge_cutoff(self, knowledge_cutoff: str) -> "SystemContent":
self.knowledge_cutoff = knowledge_cutoff
return self
def with_channel_config(self, channel_config: ChannelConfig) -> "SystemContent":
self.channel_config = channel_config
return self
def with_required_channels(self, channels: list[str]) -> "SystemContent":
self.channel_config = ChannelConfig.require_channels(channels)
return self
def with_tools(self, ns_config: ToolNamespaceConfig) -> "SystemContent":
if self.tools is None:
self.tools = {}
self.tools[ns_config.name] = ns_config
return self
def with_browser_tool(self) -> "SystemContent":
return self.with_tools(ToolNamespaceConfig.browser())
def with_python_tool(self) -> "SystemContent":
return self.with_tools(ToolNamespaceConfig.python())
def to_dict(self) -> dict:
out = self.model_dump(exclude_none=True)
out["type"] = "system_content"
return out
@classmethod
def from_dict(cls, raw: dict) -> "SystemContent":
return cls(**raw)
class DeveloperContent(Content):
instructions: Optional[str] = None
tools: Optional[dict[str, ToolNamespaceConfig]] = None
@classmethod
def new(cls) -> "DeveloperContent":
return cls()
def with_instructions(self, instructions: str) -> "DeveloperContent":
self.instructions = instructions
return self
def with_tools(self, ns_config: ToolNamespaceConfig) -> "DeveloperContent":
if self.tools is None:
self.tools = {}
self.tools[ns_config.name] = ns_config
return self
def with_function_tools(
self, tools: Sequence[ToolDescription]
) -> "DeveloperContent":
return self.with_tools(
ToolNamespaceConfig(name="functions", description=None, tools=list(tools))
)
def to_dict(self) -> dict:
out = self.model_dump(exclude_none=True)
out["type"] = "developer_content"
return out
@classmethod
def from_dict(cls, raw: dict) -> "DeveloperContent":
return cls(**raw)
# Message & Conversation -----------------------------------------------------
class Message(BaseModel):
author: Author
content: List[Content] = Field(default_factory=list)
channel: Optional[str] = None
recipient: Optional[str] = None
content_type: Optional[str] = None
# ------------------------------------------------------------------
# Convenience constructors (mirroring the Rust API)
# ------------------------------------------------------------------
@classmethod
def from_author_and_content(
cls, author: Author, content: Union[str, Content]
) -> "Message":
if isinstance(content, str):
content = TextContent(text=content)
return cls(author=author, content=[content])
@classmethod
def from_role_and_content(
cls, role: Role, content: Union[str, Content]
) -> "Message": # noqa: D401 parity with Rust API
return cls.from_author_and_content(Author(role=role), content)
@classmethod
def from_role_and_contents(
cls, role: Role, contents: Sequence[Content]
) -> "Message":
return cls(author=Author(role=role), content=list(contents))
# ------------------------------------------------------------------
# Builder helpers
# ------------------------------------------------------------------
def adding_content(self, content: Union[str, Content]) -> "Message":
if isinstance(content, str):
content = TextContent(text=content)
self.content.append(content)
return self
def with_channel(self, channel: str) -> "Message":
self.channel = channel
return self
def with_recipient(self, recipient: str) -> "Message":
self.recipient = recipient
return self
def with_content_type(self, content_type: str) -> "Message":
self.content_type = content_type
return self
# ------------------------------------------------------------------
# Serialisation helpers
# ------------------------------------------------------------------
def to_dict(self) -> Dict[str, Any]: # noqa: D401 simple mapper
out: Dict[str, Any] = {
**self.author.model_dump(),
"content": [c.to_dict() for c in self.content],
}
if self.channel is not None:
out["channel"] = self.channel
if self.recipient is not None:
out["recipient"] = self.recipient
if self.content_type is not None:
out["content_type"] = self.content_type
return out
def to_json(self) -> str: # noqa: D401
return json.dumps(self.to_dict())
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Message":
# Simple, sufficient implementation for test-roundtrip purposes.
role = Role(data["role"])
author = Author(role=role, name=data.get("name"))
contents: List[Content] = []
raw_content = data["content"]
# The Rust side serialises *single* text contents as a **plain string**
# for convenience. Detect this shortcut and normalise it to the list
# representation that the rest of the Python code expects.
if isinstance(raw_content, str):
raw_content = [{"type": "text", "text": raw_content}]
for raw in raw_content:
if raw.get("type") == "text":
contents.append(TextContent(**raw))
elif raw.get("type") == "system_content":
contents.append(SystemContent(**raw))
elif raw.get("type") == "developer_content":
contents.append(DeveloperContent(**raw))
else: # pragma: no cover unknown variant
raise ValueError(f"Unknown content variant: {raw}")
msg = cls(author=author, content=contents)
msg.channel = data.get("channel")
msg.recipient = data.get("recipient")
msg.content_type = data.get("content_type")
return msg
class Conversation(BaseModel):
messages: List[Message] = Field(default_factory=list)
@classmethod
def from_messages(cls, messages: Sequence[Message]) -> "Conversation": # noqa: D401
return cls(messages=list(messages))
def __iter__(self):
return iter(self.messages)
# Serialisation helpers -------------------------------------------------
def to_dict(self) -> Dict[str, Any]: # noqa: D401
return {"messages": [m.to_dict() for m in self.messages]}
def to_json(self) -> str: # noqa: D401
return json.dumps(self.to_dict())
@classmethod
def from_json(cls, payload: str) -> "Conversation": # noqa: D401
data = json.loads(payload)
return cls(messages=[Message.from_dict(m) for m in data["messages"]])
# ---------------------------------------------------------------------------
# Encoding interaction (thin wrappers around the Rust bindings)
# ---------------------------------------------------------------------------
class RenderConversationConfig(BaseModel):
auto_drop_analysis: bool = True
class HarmonyEncoding:
"""High-level wrapper around the Rust ``PyHarmonyEncoding`` class."""
def __init__(self, inner: _PyHarmonyEncoding):
self._inner = inner
# ------------------------------------------------------------------
# Delegated helpers
# ------------------------------------------------------------------
@property
def name(self) -> str: # noqa: D401
return self._inner.name # type: ignore[attr-defined]
@functools.cached_property
def special_tokens_set(self) -> set[str]:
return set(self._inner.special_tokens())
# -- Rendering -----------------------------------------------------
def render_conversation_for_completion(
self,
conversation: Conversation,
next_turn_role: Role,
config: Optional[RenderConversationConfig] = None,
) -> List[int]:
"""
Render a conversation for completion.
Args:
conversation: Conversation object
next_turn_role: Role for the next turn
config: Optional RenderConversationConfig (default auto_drop_analysis=True)
"""
if config is None:
config_dict = {"auto_drop_analysis": True}
else:
config_dict = {"auto_drop_analysis": config.auto_drop_analysis}
return self._inner.render_conversation_for_completion(
conversation_json=conversation.to_json(),
next_turn_role=str(next_turn_role.value),
config=config_dict,
)
def render_conversation(
self,
conversation: Conversation,
config: Optional[RenderConversationConfig] = None,
) -> List[int]:
"""Render a conversation without appending a new role."""
if config is None:
config_dict = {"auto_drop_analysis": True}
else:
config_dict = {"auto_drop_analysis": config.auto_drop_analysis}
return self._inner.render_conversation(
conversation_json=conversation.to_json(),
config=config_dict,
)
def render_conversation_for_training(
self,
conversation: Conversation,
config: Optional[RenderConversationConfig] = None,
) -> List[int]:
"""Render a conversation for training."""
if config is None:
config_dict = {"auto_drop_analysis": True}
else:
config_dict = {"auto_drop_analysis": config.auto_drop_analysis}
return self._inner.render_conversation_for_training(
conversation_json=conversation.to_json(),
config=config_dict,
)
def render(self, message: Message) -> List[int]:
"""Render a single message into tokens."""
return self._inner.render(message_json=message.to_json())
# -- Parsing -------------------------------------------------------
def parse_messages_from_completion_tokens(
self, tokens: Sequence[int], role: Optional[Role] | None = None
) -> List[Message]:
raw_json: str = self._inner.parse_messages_from_completion_tokens(
list(tokens), None if role is None else str(role.value)
)
return [Message.from_dict(m) for m in json.loads(raw_json)]
# -- Token decoding ------------------------------------------------
def decode_utf8(self, tokens: Sequence[int]) -> str:
"""Decode a list of tokens into a UTF-8 string. Will raise an error if the tokens result in invalid UTF-8. Use decode if you want to replace invalid UTF-8 with the unicode replacement character."""
return self._inner.decode_utf8(list(tokens))
def encode(
self,
text: str,
*,
allowed_special: Literal["all"] | AbstractSet[str] = set(),
disallowed_special: Literal["all"] | Collection[str] = "all",
) -> list[int]:
"""Encodes a string into tokens.
Special tokens are artificial tokens used to unlock capabilities from a model,
such as fill-in-the-middle. So we want to be careful about accidentally encoding special
tokens, since they can be used to trick a model into doing something we don't want it to do.
Hence, by default, encode will raise an error if it encounters text that corresponds
to a special token. This can be controlled on a per-token level using the `allowed_special`
and `disallowed_special` parameters. In particular:
- Setting `disallowed_special` to () will prevent this function from raising errors and
cause all text corresponding to special tokens to be encoded as natural text.
- Setting `allowed_special` to "all" will cause this function to treat all text
corresponding to special tokens to be encoded as special tokens.
```
>>> enc.encode("hello world")
[31373, 995]
>>> enc.encode("<|endoftext|>", allowed_special={"<|endoftext|>"})
[50256]
>>> enc.encode("<|endoftext|>", allowed_special="all")
[50256]
>>> enc.encode("<|endoftext|>")
# Raises ValueError
>>> enc.encode("<|endoftext|>", disallowed_special=())
[27, 91, 437, 1659, 5239, 91, 29]
```
"""
if allowed_special == "all":
allowed_special = self.special_tokens_set
if disallowed_special == "all":
disallowed_special = self.special_tokens_set - set(allowed_special)
if disallowed_special:
if not isinstance(disallowed_special, frozenset):
disallowed_special = frozenset(disallowed_special)
if match := _special_token_regex(disallowed_special).search(text):
raise_disallowed_special_token(match.group())
try:
return self._inner.encode(text, list(allowed_special))
except UnicodeEncodeError:
text = text.encode("utf-16", "surrogatepass").decode("utf-16", "replace")
return self._inner.encode(text, list(allowed_special))
def decode(self, tokens: Sequence[int], errors: str = "replace") -> str:
"""Decodes a list of tokens into a string.
WARNING: the default behaviour of this function is lossy, since decoded bytes are not
guaranteed to be valid UTF-8. You can use `decode_utf8` if you want to raise an error on invalid UTF-8.
```
>>> enc.decode([31373, 995])
'hello world'
```
"""
data = bytes(self._inner.decode_bytes(list(tokens)))
return data.decode("utf-8", errors=errors)
def is_special_token(self, token: int) -> bool:
"""Returns if an individual token is a special token"""
return self._inner.is_special_token(token)
# -- Stop tokens --------------------------------------------------
def stop_tokens(self) -> List[int]:
return self._inner.stop_tokens()
def stop_tokens_for_assistant_actions(self) -> List[int]:
return self._inner.stop_tokens_for_assistant_actions()
class StreamState(Enum):
EXPECT_START = "ExpectStart"
HEADER = "Header"
CONTENT = "Content"
class StreamableParser:
"""Incremental parser over completion tokens."""
def __init__(self, encoding: HarmonyEncoding, role: Role | None):
role_str = str(role.value) if role is not None else None
self._inner = _PyStreamableParser(encoding._inner, role_str)
def process(self, token: int) -> "StreamableParser":
self._inner.process(token)
return self
@property
def current_content(self) -> str:
return self._inner.current_content
@property
def current_role(self) -> Optional[Role]:
raw = self._inner.current_role
return Role(raw) if raw is not None else None
@property
def current_content_type(self) -> Optional[str]:
return self._inner.current_content_type
@property
def last_content_delta(self) -> Optional[str]:
return self._inner.last_content_delta
@property
def messages(self) -> List[Message]:
raw = self._inner.messages
return [Message.from_dict(m) for m in json.loads(raw)]
@property
def tokens(self) -> List[int]:
return self._inner.tokens
@property
def state_data(self) -> Dict[str, Any]:
"""Return a JSON string representing the parser's internal state."""
return json.loads(self._inner.state)
@property
def state(self) -> StreamState:
data = self.state_data
return StreamState(data["state"])
@property
def current_recipient(self) -> Optional[str]:
return self._inner.current_recipient
@property
def current_channel(self) -> Optional[str]:
return self._inner.current_channel
# Public helper --------------------------------------------------------------
def load_harmony_encoding(name: str | "HarmonyEncodingName") -> HarmonyEncoding: # type: ignore[name-defined]
"""Load an encoding by *name* (delegates to the Rust implementation)."""
# Allow both strings and enum values.
if not isinstance(name, str):
name = str(name)
inner: _PyHarmonyEncoding = _load_harmony_encoding(name)
return HarmonyEncoding(inner)
# For *mypy* we expose a minimal stub of the `HarmonyEncodingName` enum. At
# **runtime** the user is expected to pass the *string* names because the Rust
# side only operates on strings anyway.
class HarmonyEncodingName(str, Enum): # noqa: D101 simple enum stub
HARMONY_GPT_OSS = "HarmonyGptOss"
def __str__(self) -> str: # noqa: D401
return str(self.value)
# What should be re-exported when the user does ``from harmony import *``?
__all__ = [
"Role",
"Author",
"Content",
"TextContent",
"ToolDescription",
"SystemContent",
"Message",
"Conversation",
"HarmonyEncoding",
"HarmonyEncodingName",
"load_harmony_encoding",
"StreamableParser",
"StreamState",
"HarmonyError",
]