Add Python type stubs for Rust bindings in the C extension

This commit is contained in:
Xuehai Pan 2025-08-06 20:18:02 +08:00
parent 9528c7b4a0
commit cc2a2ce2bf
3 changed files with 122 additions and 31 deletions

View file

@ -13,8 +13,10 @@ from __future__ import annotations
import functools import functools
import json import json
import re
from enum import Enum from enum import Enum
from typing import ( from typing import (
TYPE_CHECKING,
AbstractSet, AbstractSet,
Any, Any,
Collection, Collection,
@ -28,40 +30,43 @@ from typing import (
Union, Union,
) )
import re
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
# Re-export the low-level Rust bindings under a private name so that we can if not TYPE_CHECKING:
# keep the *public* namespace clean and purely Pythonic. # Re-export the low-level Rust bindings under a private name so that we can
try: # keep the *public* namespace clean and purely Pythonic.
try:
from .openai_harmony import (
HarmonyError as HarmonyError, # expose the actual Rust error directly
)
from .openai_harmony import PyHarmonyEncoding as _PyHarmonyEncoding
from .openai_harmony import PyStreamableParser as _PyStreamableParser
from .openai_harmony import load_harmony_encoding as _load_harmony_encoding
except ModuleNotFoundError: # pragma: no cover raised during type-checking
# When running *mypy* without the compiled extension in place we still want
# to succeed. Therefore we create dummy stubs that satisfy the type
# checker. They will, however, raise at **runtime** if accessed.
class _Stub: # pylint: disable=too-few-public-methods
def __getattr__(self, name: str) -> None: # noqa: D401
raise RuntimeError(
"The compiled harmony bindings are not available. Make sure to "
"build the project with `maturin develop` before running this "
"code."
)
_load_harmony_encoding = _Stub()
_PyHarmonyEncoding = _Stub()
_PyStreamableParser = _Stub()
_HarmonyError = RuntimeError
else: # pragma: no branch
from .openai_harmony import ( from .openai_harmony import (
HarmonyError as HarmonyError, # expose the actual Rust error directly HarmonyError as HarmonyError, # expose the actual Rust error directly
) )
from .openai_harmony import PyHarmonyEncoding as _PyHarmonyEncoding # type: ignore from .openai_harmony import PyHarmonyEncoding as _PyHarmonyEncoding
from .openai_harmony import ( from .openai_harmony import PyStreamableParser as _PyStreamableParser
PyStreamableParser as _PyStreamableParser, # type: ignore from .openai_harmony import load_harmony_encoding as _load_harmony_encoding
)
from .openai_harmony import (
load_harmony_encoding as _load_harmony_encoding, # type: ignore
)
except ModuleNotFoundError: # pragma: no cover raised during type-checking
# When running *mypy* without the compiled extension in place we still want
# to succeed. Therefore we create dummy stubs that satisfy the type
# checker. They will, however, raise at **runtime** if accessed.
class _Stub: # pylint: disable=too-few-public-methods
def __getattr__(self, name: str) -> None: # noqa: D401
raise RuntimeError(
"The compiled harmony bindings are not available. Make sure to "
"build the project with `maturin develop` before running this "
"code."
)
_load_harmony_encoding = _Stub() # type: ignore
_PyHarmonyEncoding = _Stub() # type: ignore
_PyStreamableParser = _Stub() # type: ignore
_HarmonyError = RuntimeError
def _special_token_regex(tokens: frozenset[str]) -> Pattern[str]: def _special_token_regex(tokens: frozenset[str]) -> Pattern[str]:
@ -441,7 +446,7 @@ class HarmonyEncoding:
@property @property
def name(self) -> str: # noqa: D401 def name(self) -> str: # noqa: D401
return self._inner.name # type: ignore[attr-defined] return self._inner.name
@functools.cached_property @functools.cached_property
def special_tokens_set(self) -> set[str]: def special_tokens_set(self) -> set[str]:
@ -679,7 +684,7 @@ class StreamableParser:
# Public helper -------------------------------------------------------------- # Public helper --------------------------------------------------------------
def load_harmony_encoding(name: str | "HarmonyEncodingName") -> HarmonyEncoding: # type: ignore[name-defined] def load_harmony_encoding(name: str | "HarmonyEncodingName") -> HarmonyEncoding:
"""Load an encoding by *name* (delegates to the Rust implementation).""" """Load an encoding by *name* (delegates to the Rust implementation)."""
# Allow both strings and enum values. # Allow both strings and enum values.

View file

@ -0,0 +1,86 @@
"""Type stubs for the OpenAI Harmony Rust bindings in the C extension."""
from collections.abc import Iterable
from enum import Enum
from typing import Any
class HarmonyError(RuntimeError): ...
class PyHarmonyEncoding:
def __init__(self, name: str) -> None: ...
@property
def name(self) -> str: ...
def decode_bytes(self, tokens: list[int]) -> bytes: ...
def decode_utf8(self, tokens: list[int]) -> str: ...
def encode(
self,
text: str,
allowed_special: Iterable[str] | None = None,
) -> list[int]: ...
def is_special_token(self, token: int) -> bool: ...
def parse_messages_from_completion_tokens(
self,
tokens: list[int],
role: str | None = None,
) -> str: ...
def render(
self,
message_json: str,
render_options: dict[str, Any] | None = None,
) -> list[int]: ...
def render_conversation(
self,
conversation_json: str,
config: dict[str, Any] | None = None,
) -> list[int]: ...
def render_conversation_for_completion(
self,
conversation_json: str,
next_turn_role: str,
config: dict[str, Any] | None = None,
) -> list[int]: ...
def render_conversation_for_training(
self,
conversation_json: str,
config: dict[str, Any] | None = None,
) -> list[int]: ...
def special_tokens(self) -> list[str]: ...
def stop_tokens(self) -> list[int]: ...
def stop_tokens_for_assistant_actions(self) -> list[int]: ...
class PyStreamableParser:
def __init__(
self,
encoding: PyHarmonyEncoding,
role: str | None = None,
) -> None: ...
def process(self, token: int) -> None: ...
def process_eos(self) -> None: ...
@property
def current_channel(self) -> str | None: ...
@property
def current_content(self) -> str: ...
@property
def current_content_type(self) -> str | None: ...
@property
def current_recipient(self) -> str | None: ...
@property
def current_role(self) -> str | None: ...
@property
def last_content_delta(self) -> str | None: ...
@property
def messages(self) -> str: ...
@property
def state(self) -> str: ...
@property
def tokens(self) -> list[int]: ...
class PyStreamState(Enum):
ExpectStart = ...
Header = ...
Content = ...
def __int__(self) -> int: ...
def get_tool_namespace_config(tool: str) -> Any: ...
def load_harmony_encoding(name: str) -> PyHarmonyEncoding: ...

View file