openai-harmony/python/openai_harmony/__init__.py
Simon Willison bd5ed13085
Add 'DeveloperContent' to __all__
For better import *
2025-08-08 04:58:53 -07:00

722 lines
24 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Python wrapper around the Rust implementation of *harmony*.
The heavy lifting (tokenisation, rendering, parsing, …) is implemented in
Rust. The thin bindings are available through the private ``openai_harmony``
extension module which is compiled via *maturin* / *PyO3*.
This package provides a small, typed convenience layer that mirrors the public
API of the Rust crate so that it can be used from Python code in an
idiomatic way (``dataclasses``, ``Enum``s, …).
"""
from __future__ import annotations
import functools
import json
from enum import Enum
from typing import (
AbstractSet,
Any,
Collection,
Dict,
List,
Literal,
Optional,
Pattern,
Sequence,
TypeVar,
Union,
)
import re
from pydantic import BaseModel, Field
# Re-export the low-level Rust bindings under a private name so that we can
# keep the *public* namespace clean and purely Pythonic.
try:
    from .openai_harmony import (
        HarmonyError as HarmonyError,  # expose the actual Rust error directly
    )
    from .openai_harmony import PyHarmonyEncoding as _PyHarmonyEncoding  # type: ignore
    from .openai_harmony import (
        PyStreamableParser as _PyStreamableParser,  # type: ignore
    )
    from .openai_harmony import (
        load_harmony_encoding as _load_harmony_encoding,  # type: ignore
    )
except ModuleNotFoundError:  # pragma: no cover raised during type-checking
    # When running *mypy* without the compiled extension in place we still want
    # to succeed. Therefore we create dummy stubs that satisfy the type
    # checker. They will, however, raise at **runtime** if accessed.
    class _Stub:  # pylint: disable=too-few-public-methods
        def __getattr__(self, name: str) -> None:  # noqa: D401
            raise RuntimeError(
                "The compiled harmony bindings are not available. Make sure to "
                "build the project with `maturin develop` before running this "
                "code."
            )

    _load_harmony_encoding = _Stub()  # type: ignore
    _PyHarmonyEncoding = _Stub()  # type: ignore
    _PyStreamableParser = _Stub()  # type: ignore
    # BUG FIX: the module body (``raise_disallowed_special_token``) and
    # ``__all__`` reference the public ``HarmonyError`` name, but this fallback
    # branch previously only bound ``_HarmonyError`` — so using the package
    # without the compiled extension raised ``NameError``. Bind the public
    # name here and keep the old private alias for backward compatibility.
    HarmonyError = RuntimeError  # type: ignore[misc, assignment]
    _HarmonyError = RuntimeError
def _special_token_regex(tokens: frozenset[str]) -> Pattern[str]:
inner = "|".join(re.escape(token) for token in tokens)
return re.compile(f"({inner})")
def raise_disallowed_special_token(token: str) -> None:
    """Raise ``HarmonyError`` explaining how to allow or ignore *token*."""
    message = (
        "Encountered text corresponding to disallowed special token "
        f"{token!r}.\n"
        "If you want this text to be encoded as a special token, "
        f"pass it to `allowed_special`, e.g. `allowed_special={{'{token}', ...}}`.\n"
        "If you want this text to be encoded as normal text, disable the check for this token "
        f"by passing `disallowed_special=(enc.special_tokens_set - {{'{token}'}})`.\n"
        "To disable this check for all special tokens, pass `disallowed_special=()`.\n"
    )
    raise HarmonyError(message)
# ---------------------------------------------------------------------------
# Chat-related data-structures (mirroring ``src/chat.rs``)
# ---------------------------------------------------------------------------
class Role(str, Enum):
    """Author role of a chat message (mirrors ``chat::Role``).

    A ``str`` subclass so values compare equal to their plain-string form.
    Constructing from an unknown value raises ``ValueError`` instead of the
    default enum lookup error.
    """

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    DEVELOPER = "developer"
    TOOL = "tool"

    @classmethod
    def _missing_(cls, value: object) -> "Role":  # type: ignore[override]
        raise ValueError(f"Unknown role: {value!r}")
class Author(BaseModel):
    """The author of a message: a role plus an optional author name."""

    role: Role
    name: Optional[str] = None

    @classmethod
    def new(cls, role: Role, name: str) -> "Author":  # noqa: D401 keep parity with Rust API
        """Construct an author (parity with the Rust ``Author::new``)."""
        return cls(role=role, name=name)
# Content hierarchy ---------------------------------------------------------
# NOTE(review): ``T`` is never used elsewhere in this module; presumably kept
# for external consumers — confirm before removing.
T = TypeVar("T")
class Content(BaseModel):
    """Abstract base for all message-content variants."""

    def to_dict(self) -> Dict[str, Any]:
        """Serialise this content into a plain dict; subclasses must override."""
        raise NotImplementedError
class TextContent(Content):
    """Plain-text message content."""

    text: str

    def to_dict(self) -> Dict[str, Any]:
        """Serialise as ``{"type": "text", "text": ...}``."""
        return {"type": "text", "text": self.text}
class ToolDescription(BaseModel):
    """Description of a single tool: name, docs, and optional parameter schema."""

    name: str
    description: str
    parameters: Optional[dict] = None

    @classmethod
    def new(
        cls, name: str, description: str, parameters: Optional[dict] = None
    ) -> "ToolDescription":  # noqa: D401
        """Construct a tool description (parity with the Rust API)."""
        return cls(name=name, description=description, parameters=parameters)
class ReasoningEffort(str, Enum):
    """Reasoning-effort level; note the capitalised serialised values."""

    LOW = "Low"
    MEDIUM = "Medium"
    HIGH = "High"
class ChannelConfig(BaseModel):
    """The set of valid channels and whether a channel must be present."""

    valid_channels: List[str]
    channel_required: bool

    @classmethod
    def require_channels(cls, channels: List[str]) -> "ChannelConfig":  # noqa: D401
        """Build a config where *channels* are valid and a channel is mandatory."""
        return cls(valid_channels=channels, channel_required=True)
class ToolNamespaceConfig(BaseModel):
    """Configuration for one tool namespace (e.g. ``browser`` or ``python``)."""

    name: str
    description: Optional[str] = None
    tools: List[ToolDescription]

    @staticmethod
    def _builtin(name: str) -> "ToolNamespaceConfig":
        """Load a built-in namespace config by *name* from the Rust bindings.

        Factored out of ``browser()``/``python()``, which previously repeated
        the same import-and-load sequence verbatim.
        """
        # Imported lazily so the pure-Python module stays importable when the
        # compiled extension is absent (mirrors the top-level try/except).
        from .openai_harmony import (
            get_tool_namespace_config as _get_tool_namespace_config,
        )

        return ToolNamespaceConfig(**_get_tool_namespace_config(name))

    @staticmethod
    def browser() -> "ToolNamespaceConfig":
        """Return the built-in ``browser`` tool namespace."""
        return ToolNamespaceConfig._builtin("browser")

    @staticmethod
    def python() -> "ToolNamespaceConfig":
        """Return the built-in ``python`` tool namespace."""
        return ToolNamespaceConfig._builtin("python")
class SystemContent(Content):
    """Content of a system message: model identity, reasoning effort,
    conversation metadata, channel configuration, and tool namespaces.

    All ``with_*`` methods mutate ``self`` in place and return it, forming a
    fluent builder interface.
    """

    model_identity: Optional[str] = (
        "You are ChatGPT, a large language model trained by OpenAI."
    )
    reasoning_effort: Optional[ReasoningEffort] = ReasoningEffort.MEDIUM
    conversation_start_date: Optional[str] = None
    knowledge_cutoff: Optional[str] = "2024-06"
    channel_config: Optional[ChannelConfig] = Field(
        default_factory=lambda: ChannelConfig.require_channels(
            ["analysis", "commentary", "final"]
        )
    )
    tools: Optional[dict[str, ToolNamespaceConfig]] = None

    @classmethod
    def new(cls) -> "SystemContent":
        """Create a system content populated entirely with defaults."""
        return cls()

    # Fluent interface ------------------------------------------------------
    def with_model_identity(self, model_identity: str) -> "SystemContent":
        """Set the model-identity line; returns ``self``."""
        self.model_identity = model_identity
        return self

    def with_reasoning_effort(
        self, reasoning_effort: ReasoningEffort
    ) -> "SystemContent":
        """Set the reasoning effort; returns ``self``."""
        self.reasoning_effort = reasoning_effort
        return self

    def with_conversation_start_date(
        self, conversation_start_date: str
    ) -> "SystemContent":
        """Set the conversation start date; returns ``self``."""
        self.conversation_start_date = conversation_start_date
        return self

    def with_knowledge_cutoff(self, knowledge_cutoff: str) -> "SystemContent":
        """Set the knowledge cutoff; returns ``self``."""
        self.knowledge_cutoff = knowledge_cutoff
        return self

    def with_channel_config(self, channel_config: ChannelConfig) -> "SystemContent":
        """Replace the channel configuration; returns ``self``."""
        self.channel_config = channel_config
        return self

    def with_required_channels(self, channels: list[str]) -> "SystemContent":
        """Require *channels* via a fresh :class:`ChannelConfig`; returns ``self``."""
        self.channel_config = ChannelConfig.require_channels(channels)
        return self

    def with_tools(self, ns_config: ToolNamespaceConfig) -> "SystemContent":
        """Register *ns_config* under its namespace name; returns ``self``."""
        registry = self.tools
        if registry is None:
            registry = {}
            self.tools = registry
        registry[ns_config.name] = ns_config
        return self

    def with_browser_tool(self) -> "SystemContent":
        """Attach the built-in browser tool namespace; returns ``self``."""
        return self.with_tools(ToolNamespaceConfig.browser())

    def with_python_tool(self) -> "SystemContent":
        """Attach the built-in python tool namespace; returns ``self``."""
        return self.with_tools(ToolNamespaceConfig.python())

    def to_dict(self) -> dict:
        """Serialise (omitting ``None`` fields) with a ``type`` discriminator."""
        data = self.model_dump(exclude_none=True)
        data["type"] = "system_content"
        return data

    @classmethod
    def from_dict(cls, raw: dict) -> "SystemContent":
        """Reconstruct from a dict produced by :meth:`to_dict`."""
        return cls(**raw)
class DeveloperContent(Content):
    """Content of a developer message: free-form instructions plus optional
    tool namespaces. The ``with_*`` methods mutate and return ``self``.
    """

    instructions: Optional[str] = None
    tools: Optional[dict[str, ToolNamespaceConfig]] = None

    @classmethod
    def new(cls) -> "DeveloperContent":
        """Create an empty developer content."""
        return cls()

    def with_instructions(self, instructions: str) -> "DeveloperContent":
        """Set the instructions text; returns ``self``."""
        self.instructions = instructions
        return self

    def with_tools(self, ns_config: ToolNamespaceConfig) -> "DeveloperContent":
        """Register *ns_config* under its namespace name; returns ``self``."""
        registry = self.tools
        if registry is None:
            registry = {}
            self.tools = registry
        registry[ns_config.name] = ns_config
        return self

    def with_function_tools(
        self, tools: Sequence[ToolDescription]
    ) -> "DeveloperContent":
        """Register *tools* under the conventional ``functions`` namespace."""
        ns = ToolNamespaceConfig(name="functions", description=None, tools=list(tools))
        return self.with_tools(ns)

    def to_dict(self) -> dict:
        """Serialise (omitting ``None`` fields) with a ``type`` discriminator."""
        data = self.model_dump(exclude_none=True)
        data["type"] = "developer_content"
        return data

    @classmethod
    def from_dict(cls, raw: dict) -> "DeveloperContent":
        """Reconstruct from a dict produced by :meth:`to_dict`."""
        return cls(**raw)
# Message & Conversation -----------------------------------------------------
class Message(BaseModel):
    """A single chat message: an author plus a list of content parts, with
    optional channel / recipient / content-type metadata."""

    author: Author
    content: List[Content] = Field(default_factory=list)
    channel: Optional[str] = None
    recipient: Optional[str] = None
    content_type: Optional[str] = None

    # ------------------------------------------------------------------
    # Convenience constructors (mirroring the Rust API)
    # ------------------------------------------------------------------
    @classmethod
    def from_author_and_content(
        cls, author: Author, content: Union[str, Content]
    ) -> "Message":
        """Build a message from *author* and one content (strings are wrapped)."""
        part = TextContent(text=content) if isinstance(content, str) else content
        return cls(author=author, content=[part])

    @classmethod
    def from_role_and_content(
        cls, role: Role, content: Union[str, Content]
    ) -> "Message":  # noqa: D401 parity with Rust API
        """Build a message from a bare *role* and one content part."""
        return cls.from_author_and_content(Author(role=role), content)

    @classmethod
    def from_role_and_contents(
        cls, role: Role, contents: Sequence[Content]
    ) -> "Message":
        """Build a message from a bare *role* and several content parts."""
        return cls(author=Author(role=role), content=list(contents))

    # ------------------------------------------------------------------
    # Builder helpers (fluent; mutate ``self`` in place)
    # ------------------------------------------------------------------
    def adding_content(self, content: Union[str, Content]) -> "Message":
        """Append one content part (strings are wrapped); returns ``self``."""
        part = TextContent(text=content) if isinstance(content, str) else content
        self.content.append(part)
        return self

    def with_channel(self, channel: str) -> "Message":
        """Set the channel; returns ``self``."""
        self.channel = channel
        return self

    def with_recipient(self, recipient: str) -> "Message":
        """Set the recipient; returns ``self``."""
        self.recipient = recipient
        return self

    def with_content_type(self, content_type: str) -> "Message":
        """Set the content type; returns ``self``."""
        self.content_type = content_type
        return self

    # ------------------------------------------------------------------
    # Serialisation helpers
    # ------------------------------------------------------------------
    def to_dict(self) -> Dict[str, Any]:  # noqa: D401 simple mapper
        """Serialise to a plain dict; optional fields are emitted only if set."""
        out: Dict[str, Any] = {
            **self.author.model_dump(),
            "content": [part.to_dict() for part in self.content],
        }
        for key in ("channel", "recipient", "content_type"):
            value = getattr(self, key)
            if value is not None:
                out[key] = value
        return out

    def to_json(self) -> str:  # noqa: D401
        """Serialise to a JSON string."""
        return json.dumps(self.to_dict())

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Message":
        """Reconstruct a message from its dict form.

        Simple, sufficient implementation for test-roundtrip purposes.
        """
        author = Author(role=Role(data["role"]), name=data.get("name"))
        raw_content = data["content"]
        # The Rust side serialises *single* text contents as a **plain string**
        # for convenience. Detect this shortcut and normalise it to the list
        # representation that the rest of the Python code expects.
        if isinstance(raw_content, str):
            raw_content = [{"type": "text", "text": raw_content}]
        variants = {
            "text": TextContent,
            "system_content": SystemContent,
            "developer_content": DeveloperContent,
        }
        contents: List[Content] = []
        for raw in raw_content:
            variant = variants.get(raw.get("type"))
            if variant is None:  # pragma: no cover unknown variant
                raise ValueError(f"Unknown content variant: {raw}")
            contents.append(variant(**raw))
        msg = cls(author=author, content=contents)
        msg.channel = data.get("channel")
        msg.recipient = data.get("recipient")
        msg.content_type = data.get("content_type")
        return msg
class Conversation(BaseModel):
    """An ordered sequence of messages."""

    messages: List[Message] = Field(default_factory=list)

    @classmethod
    def from_messages(cls, messages: Sequence[Message]) -> "Conversation":  # noqa: D401
        """Build a conversation from *messages*."""
        return cls(messages=list(messages))

    def __iter__(self):
        """Iterate directly over the contained messages."""
        return iter(self.messages)

    # Serialisation helpers -------------------------------------------------
    def to_dict(self) -> Dict[str, Any]:  # noqa: D401
        """Serialise to ``{"messages": [...]}``."""
        return {"messages": [message.to_dict() for message in self.messages]}

    def to_json(self) -> str:  # noqa: D401
        """Serialise to a JSON string."""
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, payload: str) -> "Conversation":  # noqa: D401
        """Parse a conversation back from its JSON form."""
        raw = json.loads(payload)
        return cls(messages=[Message.from_dict(m) for m in raw["messages"]])
# ---------------------------------------------------------------------------
# Encoding interaction (thin wrappers around the Rust bindings)
# ---------------------------------------------------------------------------
class RenderConversationConfig(BaseModel):
    """Options controlling conversation rendering."""

    # When True (the default), analysis-channel content is dropped during
    # rendering (the option name; exact semantics live on the Rust side).
    auto_drop_analysis: bool = True
class RenderOptions(BaseModel):
    """Options controlling single-message rendering."""

    # Whether the surrounding conversation declares function tools; forwarded
    # to the Rust renderer (see HarmonyEncoding.render).
    conversation_has_function_tools: bool = False
class HarmonyEncoding:
    """High-level wrapper around the Rust ``PyHarmonyEncoding`` class."""

    def __init__(self, inner: "_PyHarmonyEncoding"):
        self._inner = inner

    # ------------------------------------------------------------------
    # Delegated helpers
    # ------------------------------------------------------------------
    @property
    def name(self) -> str:  # noqa: D401
        """Name of the underlying encoding."""
        return self._inner.name  # type: ignore[attr-defined]

    @functools.cached_property
    def special_tokens_set(self) -> "set[str]":
        """Set of special-token strings (computed once, then cached)."""
        return set(self._inner.special_tokens())

    # -- Rendering -----------------------------------------------------
    @staticmethod
    def _render_config_dict(
        config: "Optional[RenderConversationConfig]",
    ) -> "Dict[str, Any]":
        """Normalise an optional render config into the dict the Rust layer expects.

        Factored out of the three render methods below, which previously
        repeated this None-handling verbatim. Behaviour is unchanged: the
        default remains ``auto_drop_analysis=True``.
        """
        if config is None:
            return {"auto_drop_analysis": True}
        return {"auto_drop_analysis": config.auto_drop_analysis}

    def render_conversation_for_completion(
        self,
        conversation: "Conversation",
        next_turn_role: "Role",
        config: "Optional[RenderConversationConfig]" = None,
    ) -> "List[int]":
        """
        Render a conversation for completion.
        Args:
            conversation: Conversation object
            next_turn_role: Role for the next turn
            config: Optional RenderConversationConfig (default auto_drop_analysis=True)
        """
        return self._inner.render_conversation_for_completion(
            conversation_json=conversation.to_json(),
            next_turn_role=str(next_turn_role.value),
            config=self._render_config_dict(config),
        )

    def render_conversation(
        self,
        conversation: "Conversation",
        config: "Optional[RenderConversationConfig]" = None,
    ) -> "List[int]":
        """Render a conversation without appending a new role."""
        return self._inner.render_conversation(
            conversation_json=conversation.to_json(),
            config=self._render_config_dict(config),
        )

    def render_conversation_for_training(
        self,
        conversation: "Conversation",
        config: "Optional[RenderConversationConfig]" = None,
    ) -> "List[int]":
        """Render a conversation for training."""
        return self._inner.render_conversation_for_training(
            conversation_json=conversation.to_json(),
            config=self._render_config_dict(config),
        )

    def render(
        self, message: "Message", render_options: "Optional[RenderOptions]" = None
    ) -> "List[int]":
        """Render a single message into tokens."""
        if render_options is None:
            render_options_dict = {"conversation_has_function_tools": False}
        else:
            render_options_dict = {
                "conversation_has_function_tools": render_options.conversation_has_function_tools
            }
        return self._inner.render(
            message_json=message.to_json(), render_options=render_options_dict
        )

    # -- Parsing -------------------------------------------------------
    def parse_messages_from_completion_tokens(
        self, tokens: "Sequence[int]", role: "Optional[Role]" = None
    ) -> "List[Message]":
        """Parse completion *tokens* back into a list of messages.

        ``role`` is forwarded to the Rust parser as its string value (or
        ``None``). The annotation previously read ``Optional[Role] | None``;
        the redundant ``| None`` has been dropped — behaviour is unchanged.
        """
        raw_json: str = self._inner.parse_messages_from_completion_tokens(
            list(tokens), None if role is None else str(role.value)
        )
        return [Message.from_dict(m) for m in json.loads(raw_json)]

    # -- Token decoding ------------------------------------------------
    def decode_utf8(self, tokens: "Sequence[int]") -> str:
        """Decode a list of tokens into a UTF-8 string. Will raise an error if the tokens result in invalid UTF-8. Use decode if you want to replace invalid UTF-8 with the unicode replacement character."""
        return self._inner.decode_utf8(list(tokens))

    def encode(
        self,
        text: str,
        *,
        # ``frozenset()`` instead of the previous mutable-default ``set()``;
        # the argument is never mutated, so behaviour is identical.
        allowed_special: 'Literal["all"] | AbstractSet[str]' = frozenset(),
        disallowed_special: 'Literal["all"] | Collection[str]' = "all",
    ) -> "list[int]":
        """Encodes a string into tokens.
        Special tokens are artificial tokens used to unlock capabilities from a model,
        such as fill-in-the-middle. So we want to be careful about accidentally encoding special
        tokens, since they can be used to trick a model into doing something we don't want it to do.
        Hence, by default, encode will raise an error if it encounters text that corresponds
        to a special token. This can be controlled on a per-token level using the `allowed_special`
        and `disallowed_special` parameters. In particular:
        - Setting `disallowed_special` to () will prevent this function from raising errors and
          cause all text corresponding to special tokens to be encoded as natural text.
        - Setting `allowed_special` to "all" will cause this function to treat all text
          corresponding to special tokens to be encoded as special tokens.
        ```
        >>> enc.encode("hello world")
        [31373, 995]
        >>> enc.encode("<|endoftext|>", allowed_special={"<|endoftext|>"})
        [50256]
        >>> enc.encode("<|endoftext|>", allowed_special="all")
        [50256]
        >>> enc.encode("<|endoftext|>")
        # Raises ValueError
        >>> enc.encode("<|endoftext|>", disallowed_special=())
        [27, 91, 437, 1659, 5239, 91, 29]
        ```
        """
        if allowed_special == "all":
            allowed_special = self.special_tokens_set
        if disallowed_special == "all":
            disallowed_special = self.special_tokens_set - set(allowed_special)
        if disallowed_special:
            if not isinstance(disallowed_special, frozenset):
                disallowed_special = frozenset(disallowed_special)
            if match := _special_token_regex(disallowed_special).search(text):
                raise_disallowed_special_token(match.group())
        try:
            return self._inner.encode(text, list(allowed_special))
        except UnicodeEncodeError:
            # Lone surrogates cannot be encoded; round-trip through UTF-16 to
            # replace them, then retry.
            text = text.encode("utf-16", "surrogatepass").decode("utf-16", "replace")
            return self._inner.encode(text, list(allowed_special))

    def decode(self, tokens: "Sequence[int]", errors: str = "replace") -> str:
        """Decodes a list of tokens into a string.
        WARNING: the default behaviour of this function is lossy, since decoded bytes are not
        guaranteed to be valid UTF-8. You can use `decode_utf8` if you want to raise an error on invalid UTF-8.
        ```
        >>> enc.decode([31373, 995])
        'hello world'
        ```
        """
        data = bytes(self._inner.decode_bytes(list(tokens)))
        return data.decode("utf-8", errors=errors)

    def is_special_token(self, token: int) -> bool:
        """Returns if an individual token is a special token"""
        return self._inner.is_special_token(token)

    # -- Stop tokens --------------------------------------------------
    def stop_tokens(self) -> "List[int]":
        """All stop tokens for this encoding (delegates to Rust)."""
        return self._inner.stop_tokens()

    def stop_tokens_for_assistant_actions(self) -> "List[int]":
        """Stop tokens relevant for assistant actions (delegates to Rust)."""
        return self._inner.stop_tokens_for_assistant_actions()
class StreamState(Enum):
    """High-level state of the streaming parser (values mirror the Rust names)."""

    EXPECT_START = "ExpectStart"
    HEADER = "Header"
    CONTENT = "Content"
class StreamableParser:
    """Incremental parser over completion tokens.

    Thin wrapper around the Rust ``PyStreamableParser``; every property simply
    delegates to the inner parser (converting JSON payloads where needed).
    """

    def __init__(self, encoding: HarmonyEncoding, role: Role | None):
        # The Rust constructor takes the role as its string value, or None.
        role_str = str(role.value) if role is not None else None
        self._inner = _PyStreamableParser(encoding._inner, role_str)

    def process(self, token: int) -> "StreamableParser":
        """Feed one token to the parser; returns ``self`` for chaining."""
        self._inner.process(token)
        return self

    def process_eos(self) -> "StreamableParser":
        """Signal end-of-stream; returns ``self`` for chaining."""
        self._inner.process_eos()
        return self

    @property
    def current_content(self) -> str:
        """Delegates to the Rust parser's ``current_content``."""
        return self._inner.current_content

    @property
    def current_role(self) -> Optional[Role]:
        """Current role as a :class:`Role`, or ``None`` if not yet known."""
        raw = self._inner.current_role
        return Role(raw) if raw is not None else None

    @property
    def current_content_type(self) -> Optional[str]:
        """Delegates to the Rust parser's ``current_content_type``."""
        return self._inner.current_content_type

    @property
    def last_content_delta(self) -> Optional[str]:
        """Delegates to the Rust parser's ``last_content_delta``."""
        return self._inner.last_content_delta

    @property
    def messages(self) -> List[Message]:
        """Messages parsed so far, deserialised from the Rust JSON payload."""
        raw = self._inner.messages
        return [Message.from_dict(m) for m in json.loads(raw)]

    @property
    def tokens(self) -> List[int]:
        """Delegates to the Rust parser's ``tokens``."""
        return self._inner.tokens

    @property
    def state_data(self) -> Dict[str, Any]:
        """Return the parser's internal state as a dict (parsed from its JSON form)."""
        return json.loads(self._inner.state)

    @property
    def state(self) -> StreamState:
        """High-level parser state derived from :attr:`state_data`."""
        data = self.state_data
        return StreamState(data["state"])

    @property
    def current_recipient(self) -> Optional[str]:
        """Delegates to the Rust parser's ``current_recipient``."""
        return self._inner.current_recipient

    @property
    def current_channel(self) -> Optional[str]:
        """Delegates to the Rust parser's ``current_channel``."""
        return self._inner.current_channel
# Public helper --------------------------------------------------------------
def load_harmony_encoding(name: str | "HarmonyEncodingName") -> HarmonyEncoding:  # type: ignore[name-defined]
    """Load an encoding by *name* (delegates to the Rust implementation)."""
    # Accept both plain strings and enum values; the Rust side wants a string.
    key = name if isinstance(name, str) else str(name)
    return HarmonyEncoding(_load_harmony_encoding(key))
# For *mypy* we expose a minimal stub of the `HarmonyEncodingName` enum. At
# **runtime** the user is expected to pass the *string* names because the Rust
# side only operates on strings anyway.
class HarmonyEncodingName(str, Enum):  # noqa: D101 simple enum stub
    """Minimal enum of known encoding names; ``str()`` yields the raw value."""

    HARMONY_GPT_OSS = "HarmonyGptOss"

    def __str__(self) -> str:  # noqa: D401
        return str(self.value)
# What should be re-exported when the user does ``from harmony import *``?
# What should be re-exported when the user does ``from harmony import *``?
# Completed to cover every public name used in the module's public
# signatures: ReasoningEffort, ChannelConfig, ToolNamespaceConfig,
# RenderConversationConfig and RenderOptions were previously missing.
__all__ = [
    "Role",
    "Author",
    "Content",
    "TextContent",
    "DeveloperContent",
    "ToolDescription",
    "ToolNamespaceConfig",
    "ReasoningEffort",
    "ChannelConfig",
    "SystemContent",
    "Message",
    "Conversation",
    "RenderConversationConfig",
    "RenderOptions",
    "HarmonyEncoding",
    "HarmonyEncodingName",
    "load_harmony_encoding",
    "StreamableParser",
    "StreamState",
    "HarmonyError",
]