openai-harmony/src/wasm_module.rs
Dominik Kundel 253cdca537 Initial commit
Co-authored-by: scott-oai <142930063+scott-oai@users.noreply.github.com>
Co-authored-by: Zhuohan Li <zhuohan@openai.com>
2025-08-05 08:42:25 -07:00

345 lines
11 KiB
Rust

#![cfg(feature = "wasm-binding")]
use wasm_bindgen::prelude::*;
use crate::{
chat::{Message, Role, ToolNamespaceConfig},
encoding::{HarmonyEncoding, StreamableParser},
load_harmony_encoding as inner_load_harmony_encoding, HarmonyEncodingName,
};
use serde::Deserialize;
use serde_json;
use serde_wasm_bindgen;
// Opaque JS-side value types accepted as arguments by the exported methods in
// this file. Each `typescript_type` names an interface declared in the custom
// TypeScript section of this file, so the generated typings show the structural
// type instead of `any`. The values are converted to `JsValue` and deserialized
// with `serde_wasm_bindgen` by the methods that receive them.
#[wasm_bindgen]
extern "C" {
// A whole conversation (`{ messages: Message[] }` on the TypeScript side).
#[wasm_bindgen(typescript_type = "Conversation")]
pub type JsConversation;
// A single message.
#[wasm_bindgen(typescript_type = "Message")]
pub type JsMessage;
// Optional rendering options; callers may also pass `undefined`/`null`.
#[wasm_bindgen(typescript_type = "RenderConversationConfig")]
pub type JsRenderConversationConfig;
}
// Hand-written TypeScript declarations registered as a wasm-bindgen custom
// TypeScript section. They declare the interfaces named by the
// `typescript_type` attributes on the extern types in this file
// (`Conversation`, `Message`, `RenderConversationConfig`) plus the types those
// interfaces reference. The string is emitted verbatim, so any edit here
// changes the published typings.
// NOTE(review): these shapes are maintained by hand; presumably they must be
// kept in sync with the serde representation of the Rust chat types — confirm.
#[wasm_bindgen(typescript_custom_section)]
const TS_APPEND: &str = r#"
export interface Author {
role: 'user' | 'assistant' | 'system' | 'developer' | 'tool';
name?: string;
}
export type Content =
| { type: 'text'; text: string }
| { type: 'system_content'; model_identity?: string; reasoning_effort?: string; tools?: Record<string, ToolNamespaceConfig>; conversation_start_date?: string; knowledge_cutoff?: string }
| { type: 'developer_content'; instructions?: string; tools?: Record<string, ToolNamespaceConfig> };
export interface Message {
author: Author;
content: Content[];
channel?: string;
recipient?: string;
content_type?: string;
}
export interface Conversation {
messages: Message[];
}
export interface RenderConversationConfig {
auto_drop_analysis?: boolean;
}
export interface ToolNamespaceConfig {
name: string;
description?: string;
tools: any[];
}
"#;
/// JavaScript-facing handle around a [`HarmonyEncoding`].
///
/// No constructor is exported for this type in this file; instances are
/// produced by `load_harmony_encoding`.
#[wasm_bindgen]
pub struct JsHarmonyEncoding {
// The wrapped Rust encoding; every exported method delegates to it.
inner: HarmonyEncoding,
}
#[wasm_bindgen]
impl JsHarmonyEncoding {
    /// Name of the wrapped encoding, as reported by the inner encoding.
    #[wasm_bindgen(getter)]
    pub fn name(&self) -> String {
        self.inner.name().to_string()
    }

    /// Renders `conversation` into tokens for a completion whose next turn is
    /// written by `next_turn_role`.
    ///
    /// `config` may be `undefined`/`null`, in which case no config is passed to
    /// the inner renderer. When a config object is given without
    /// `auto_drop_analysis`, that flag is treated as `true`.
    ///
    /// # Errors
    /// Returns a string `JsValue` when the conversation or config cannot be
    /// deserialized, when `next_turn_role` is not a known role, or when the
    /// inner renderer fails.
    #[wasm_bindgen(js_name = renderConversationForCompletion)]
    pub fn render_conversation_for_completion(
        &self,
        conversation: JsConversation,
        next_turn_role: &str,
        config: JsRenderConversationConfig,
    ) -> Result<Vec<u32>, JsValue> {
        // Validation order (conversation, role, config) decides which error a
        // caller sees first when several inputs are bad; keep it stable.
        let conversation = parse_js_conversation(conversation)?;
        let role = parse_role(next_turn_role)?;
        let rust_config = parse_render_config(config)?;
        self.inner
            .render_conversation_for_completion(&conversation, role, rust_config.as_ref())
            .map_err(to_js_error)
    }

    /// Renders `conversation` into tokens.
    ///
    /// `config` follows the same rules as in
    /// `render_conversation_for_completion`.
    ///
    /// # Errors
    /// Returns a string `JsValue` when the conversation or config cannot be
    /// deserialized, or when the inner renderer fails.
    #[wasm_bindgen(js_name = renderConversation)]
    pub fn render_conversation(
        &self,
        conversation: JsConversation,
        config: JsRenderConversationConfig,
    ) -> Result<Vec<u32>, JsValue> {
        let conversation = parse_js_conversation(conversation)?;
        let rust_config = parse_render_config(config)?;
        self.inner
            .render_conversation(&conversation, rust_config.as_ref())
            .map_err(to_js_error)
    }

    /// Renders a single message into tokens.
    ///
    /// # Errors
    /// Returns a string `JsValue` when `message` cannot be deserialized or the
    /// inner renderer fails.
    #[wasm_bindgen]
    pub fn render(&self, message: JsMessage) -> Result<Vec<u32>, JsValue> {
        let message: crate::chat::Message = serde_wasm_bindgen::from_value(message.into())
            .map_err(|e| JsValue::from_str(&format!("invalid message JSON: {e}")))?;
        self.inner.render(&message).map_err(to_js_error)
    }

    /// Parses completion `tokens` into messages and returns them as a JSON
    /// string (the caller is expected to `JSON.parse` it).
    ///
    /// `role`, when provided, is parsed as a role name and forwarded to the
    /// inner parser; when omitted, `None` is forwarded.
    ///
    /// # Errors
    /// Returns a string `JsValue` for an unknown role, a parse failure in the
    /// inner encoding, or a JSON serialization failure.
    #[wasm_bindgen(js_name = parseMessagesFromCompletionTokens)]
    pub fn parse_messages_from_completion_tokens(
        &self,
        tokens: Vec<u32>,
        role: Option<String>,
    ) -> Result<String, JsValue> {
        let role_parsed = role.as_deref().map(parse_role).transpose()?;
        let messages: Vec<Message> = self
            .inner
            .parse_messages_from_completion_tokens(tokens, role_parsed)
            .map_err(to_js_error)?;
        serde_json::to_string(&messages)
            .map_err(|e| JsValue::from_str(&format!("failed to serialise messages to JSON: {e}")))
    }

    /// Decodes `tokens` to a string using the inner tokenizer.
    ///
    /// # Errors
    /// Returns the tokenizer's error, stringified.
    #[wasm_bindgen(js_name = decodeUtf8)]
    pub fn decode_utf8(&self, tokens: Vec<u32>) -> Result<String, JsValue> {
        self.inner
            .tokenizer()
            .decode_utf8(tokens)
            .map_err(to_js_error)
    }

    /// Decodes `tokens` to raw bytes using the inner tokenizer.
    ///
    /// # Errors
    /// Returns the tokenizer's error, stringified.
    #[wasm_bindgen(js_name = decodeBytes)]
    pub fn decode_bytes(&self, tokens: Vec<u32>) -> Result<Vec<u8>, JsValue> {
        self.inner
            .tokenizer()
            .decode_bytes(tokens)
            .map_err(to_js_error)
    }

    /// Encodes `text` into tokens.
    ///
    /// `allowed_special` is deserialized as an array of strings and passed to
    /// the tokenizer as the allowed-special set; `undefined`/`null` is treated
    /// as an empty set.
    ///
    /// # Errors
    /// Returns a string `JsValue` when `allowed_special` is present but cannot
    /// be deserialized as an array of strings.
    #[wasm_bindgen]
    pub fn encode(&self, text: &str, allowed_special: JsValue) -> Result<Vec<u32>, JsValue> {
        let allowed_vec: Vec<String> =
            if allowed_special.is_undefined() || allowed_special.is_null() {
                Vec::new()
            } else {
                serde_wasm_bindgen::from_value(allowed_special)
                    .map_err(|e| JsValue::from_str(&format!("invalid allowed_special: {e}")))?
            };
        // The tokenizer takes borrowed strs, so the set borrows from `allowed_vec`.
        let allowed_set: std::collections::HashSet<&str> =
            allowed_vec.iter().map(String::as_str).collect();
        // Only the first element of the tokenizer's returned tuple is exposed.
        Ok(self.inner.tokenizer().encode(text, &allowed_set).0)
    }

    /// Lists the tokenizer's special tokens as owned strings.
    #[wasm_bindgen(js_name = specialTokens)]
    pub fn special_tokens(&self) -> Vec<String> {
        self.inner
            .tokenizer()
            .special_tokens()
            .into_iter()
            .map(str::to_string)
            .collect()
    }

    /// Reports whether the tokenizer considers `token` a special token.
    #[wasm_bindgen(js_name = isSpecialToken)]
    pub fn is_special_token(&self, token: u32) -> bool {
        self.inner.tokenizer().is_special_token(token)
    }

    /// Returns the encoding's stop tokens as an array.
    ///
    /// # Errors
    /// Returns the inner encoding's error, stringified.
    #[wasm_bindgen(js_name = stopTokens)]
    pub fn stop_tokens(&self) -> Result<Vec<u32>, JsValue> {
        self.inner
            .stop_tokens()
            .map(|set| set.into_iter().collect())
            .map_err(to_js_error)
    }

    /// Returns the encoding's stop tokens for assistant actions as an array.
    ///
    /// # Errors
    /// Returns the inner encoding's error, stringified.
    #[wasm_bindgen(js_name = stopTokensForAssistantActions)]
    pub fn stop_tokens_for_assistant_actions(&self) -> Result<Vec<u32>, JsValue> {
        self.inner
            .stop_tokens_for_assistant_actions()
            .map(|set| set.into_iter().collect())
            .map_err(to_js_error)
    }
}

/// Converts any stringifiable error into a string `JsValue`.
fn to_js_error<E: ToString>(err: E) -> JsValue {
    JsValue::from_str(&err.to_string())
}

/// Parses a role name, mapping failure to an `"unknown role: …"` JS error.
fn parse_role(role: &str) -> Result<Role, JsValue> {
    Role::try_from(role).map_err(|_| JsValue::from_str(&format!("unknown role: {role}")))
}

/// Deserializes a JS conversation object into the Rust conversation type.
fn parse_js_conversation(
    conversation: JsConversation,
) -> Result<crate::chat::Conversation, JsValue> {
    serde_wasm_bindgen::from_value(conversation.into())
        .map_err(|e| JsValue::from_str(&format!("invalid conversation JSON: {e}")))
}

/// Converts the optional JS render config into the Rust render config.
///
/// Yields `Ok(None)` for `undefined`/`null`. Otherwise the object is
/// deserialized and a missing `auto_drop_analysis` is treated as `true`.
fn parse_render_config(
    config: JsRenderConversationConfig,
) -> Result<Option<crate::encoding::RenderConversationConfig>, JsValue> {
    #[derive(Deserialize)]
    struct Config {
        auto_drop_analysis: Option<bool>,
    }

    let config: JsValue = config.into();
    if config.is_undefined() || config.is_null() {
        return Ok(None);
    }
    let cfg: Config = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("invalid config: {e}")))?;
    Ok(Some(crate::encoding::RenderConversationConfig {
        auto_drop_analysis: cfg.auto_drop_analysis.unwrap_or(true),
    }))
}
/// JavaScript-facing handle around a [`StreamableParser`].
///
/// Constructed from JS through the exported constructor on this type, which
/// takes a `JsHarmonyEncoding` and a role name.
#[wasm_bindgen]
pub struct JsStreamableParser {
// The wrapped Rust parser; every exported method and getter delegates to it.
inner: StreamableParser,
}
#[wasm_bindgen]
impl JsStreamableParser {
    /// Creates a parser from a clone of `encoding`'s inner encoding and the
    /// role named by `role`.
    ///
    /// # Errors
    /// Returns `"unknown role: …"` when `role` is not a known role name, or the
    /// stringified error from the inner parser's constructor.
    #[wasm_bindgen(constructor)]
    pub fn new(encoding: &JsHarmonyEncoding, role: &str) -> Result<JsStreamableParser, JsValue> {
        let start_role = match Role::try_from(role) {
            Ok(parsed) => parsed,
            Err(_) => return Err(JsValue::from_str(&format!("unknown role: {role}"))),
        };
        match StreamableParser::new(encoding.inner.clone(), start_role) {
            Ok(inner) => Ok(Self { inner }),
            Err(err) => Err(JsValue::from_str(&err.to_string())),
        }
    }

    /// Feeds one token to the parser.
    ///
    /// # Errors
    /// Returns the inner parser's error, stringified.
    pub fn process(&mut self, token: u32) -> Result<(), JsValue> {
        // The inner call's success value is deliberately discarded.
        match self.inner.process(token) {
            Ok(_) => Ok(()),
            Err(err) => Err(JsValue::from_str(&err.to_string())),
        }
    }

    /// Current content as reported by the inner parser.
    ///
    /// # Errors
    /// Returns the inner parser's error, stringified.
    #[wasm_bindgen(getter, js_name = currentContent)]
    pub fn current_content(&self) -> Result<String, JsValue> {
        match self.inner.current_content() {
            Ok(text) => Ok(text),
            Err(err) => Err(JsValue::from_str(&err.to_string())),
        }
    }

    /// Current role as its string name, or an empty string when the inner
    /// parser does not provide one.
    #[wasm_bindgen(getter, js_name = currentRole)]
    pub fn current_role(&self) -> String {
        let role_name = self
            .inner
            .current_role()
            .map(|role| role.as_str().to_string());
        role_name.unwrap_or_default()
    }

    /// Current content type, or an empty string when the inner parser does not
    /// provide one.
    #[wasm_bindgen(getter, js_name = currentContentType)]
    pub fn current_content_type(&self) -> String {
        let content_type = self.inner.current_content_type();
        content_type.unwrap_or_default()
    }

    /// Most recent content delta, or an empty string when there is none.
    ///
    /// # Errors
    /// Returns the inner parser's error, stringified.
    #[wasm_bindgen(getter, js_name = lastContentDelta)]
    pub fn last_content_delta(&self) -> Result<String, JsValue> {
        self.inner
            .last_content_delta()
            .map(|delta| delta.unwrap_or_default())
            .map_err(|err| JsValue::from_str(&err.to_string()))
    }

    /// Messages held by the inner parser, serialized as a JSON string.
    ///
    /// # Errors
    /// Returns a string `JsValue` when JSON serialization fails.
    #[wasm_bindgen(getter)]
    pub fn messages(&self) -> Result<String, JsValue> {
        match serde_json::to_string(self.inner.messages()) {
            Ok(json) => Ok(json),
            Err(err) => Err(JsValue::from_str(&format!(
                "failed to serialise messages to JSON: {err}"
            ))),
        }
    }

    /// Copy of the tokens held by the inner parser.
    #[wasm_bindgen(getter)]
    pub fn tokens(&self) -> Vec<u32> {
        let seen = self.inner.tokens();
        seen.to_vec()
    }

    /// Parser state as returned by the inner parser's `state_json`.
    ///
    /// # Errors
    /// Returns the inner parser's error, stringified.
    #[wasm_bindgen(getter)]
    pub fn state(&self) -> Result<String, JsValue> {
        match self.inner.state_json() {
            Ok(json) => Ok(json),
            Err(err) => Err(JsValue::from_str(&err.to_string())),
        }
    }

    /// Current recipient, or an empty string when the inner parser does not
    /// provide one.
    #[wasm_bindgen(getter, js_name = currentRecipient)]
    pub fn current_recipient(&self) -> String {
        let recipient = self.inner.current_recipient();
        recipient.unwrap_or_default()
    }

    /// Current channel, or an empty string when the inner parser does not
    /// provide one.
    #[wasm_bindgen(getter, js_name = currentChannel)]
    pub fn current_channel(&self) -> String {
        let channel = self.inner.current_channel();
        channel.unwrap_or_default()
    }
}
/// Enum exported to JavaScript with three variants.
///
/// Nothing else in this file references it; the parser's `state` getter returns
/// a JSON string instead.
// NOTE(review): presumably mirrors the phases of the streaming parser
// (waiting for a message start, reading a header, reading content) — confirm
// against the parser implementation.
#[wasm_bindgen]
pub enum StreamState {
ExpectStart,
Header,
Content,
}
/// Loads the encoding called `name` and wraps it for JavaScript.
///
/// When `base_url` is provided it is passed to
/// `crate::tiktoken_ext::set_tiktoken_base_url` first. This happens before
/// `name` is validated, so the base URL is updated even if the call then fails.
///
/// # Errors
/// Returns a string `JsValue` when `name` does not parse as a
/// `HarmonyEncodingName` or when loading the encoding fails.
#[wasm_bindgen]
pub async fn load_harmony_encoding(
    name: &str,
    base_url: Option<String>,
) -> Result<JsHarmonyEncoding, JsValue> {
    if let Some(url) = base_url {
        crate::tiktoken_ext::set_tiktoken_base_url(url);
    }

    let encoding_name = match name.parse::<HarmonyEncodingName>() {
        Ok(parsed) => parsed,
        Err(err) => return Err(JsValue::from_str(&err.to_string())),
    };

    match inner_load_harmony_encoding(encoding_name).await {
        Ok(inner) => Ok(JsHarmonyEncoding { inner }),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
/// Returns the built-in tool namespace config for `tool` as a JS value.
///
/// Only `"browser"` and `"python"` are recognised.
///
/// # Errors
/// Returns `"Unknown tool namespace: …"` for any other name, or the stringified
/// serialization error if the config cannot be converted to a JS value.
#[wasm_bindgen]
pub fn get_tool_namespace_config(tool: &str) -> Result<JsValue, JsValue> {
    let namespace = match tool {
        "browser" => Ok(ToolNamespaceConfig::browser()),
        "python" => Ok(ToolNamespaceConfig::python()),
        _ => Err(JsValue::from_str(&format!(
            "Unknown tool namespace: {}",
            tool
        ))),
    }?;

    match serde_wasm_bindgen::to_value(&namespace) {
        Ok(value) => Ok(value),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}