windows fix again

This commit is contained in:
Scott Lessans 2025-08-05 12:55:20 -07:00
parent 69001b7064
commit 155f53eca8

View file

@ -1,3 +1,5 @@
use std::path::Path;
use crate::{
chat::{
Author, Conversation, DeveloperContent, Message, ReasoningEffort, Role, SystemContent,
@ -10,12 +12,25 @@ use crate::{
use pretty_assertions::{assert_eq, Comparison};
use serde_json::json;
fn parse_tokens(text: &str) -> Vec<Rank> {
text.split_whitespace()
fn parse_tokens(text: impl AsRef<str>) -> Vec<Rank> {
text.as_ref()
.split_whitespace()
.map(|s| s.parse().unwrap())
.collect()
}
fn load_test_data(path: impl AsRef<Path>) -> String {
// on windows, we need to replace \r\n with \n
let cargo_manifest_dir = Path::new(env!("CARGO_MANIFEST_DIR"));
let src_dir = cargo_manifest_dir.join("src");
let path = src_dir.join(path);
std::fs::read_to_string(path)
.unwrap()
.replace("\r\n", "\n")
.trim_end()
.to_string()
}
const ENCODINGS: [HarmonyEncodingName; 1] = [HarmonyEncodingName::HarmonyGptOss];
#[test]
@ -25,9 +40,7 @@ fn test_simple_convo() {
let expected_tokens = encoding
.tokenizer
.encode(
include_str!("../test-data/test_simple_convo.txt")
.replace("\r\n", "\n")
.trim_end(),
load_test_data("../test-data/test_simple_convo.txt").as_str(),
&encoding.tokenizer.special_tokens(),
)
.0;
@ -52,47 +65,42 @@ fn test_simple_convo_with_effort() {
let test_cases = [
(
ReasoningEffort::Low,
include_str!("../test-data/test_simple_convo_low_effort.txt"),
load_test_data("../test-data/test_simple_convo_low_effort.txt"),
true,
),
(
ReasoningEffort::Medium,
include_str!("../test-data/test_simple_convo_medium_effort.txt"),
load_test_data("../test-data/test_simple_convo_medium_effort.txt"),
true,
),
(
ReasoningEffort::High,
include_str!("../test-data/test_simple_convo_high_effort.txt"),
load_test_data("../test-data/test_simple_convo_high_effort.txt"),
true,
),
(
ReasoningEffort::Low,
include_str!("../test-data/test_simple_convo_low_effort_no_instruction.txt"),
load_test_data("../test-data/test_simple_convo_low_effort_no_instruction.txt"),
false,
),
(
ReasoningEffort::Medium,
include_str!("../test-data/test_simple_convo_medium_effort_no_instruction.txt"),
load_test_data("../test-data/test_simple_convo_medium_effort_no_instruction.txt"),
false,
),
(
ReasoningEffort::High,
include_str!("../test-data/test_simple_convo_high_effort_no_instruction.txt"),
load_test_data("../test-data/test_simple_convo_high_effort_no_instruction.txt"),
false,
),
];
for encoding_name in ENCODINGS {
let encoding = load_harmony_encoding(encoding_name).unwrap();
for (effort, expected_text, use_instruction) in test_cases {
// on windows, we need to replace \r\n with \n
let expected_text = expected_text.replace("\r\n", "\n");
for &(effort, ref expected_text, use_instruction) in &test_cases {
let expected_tokens = encoding
.tokenizer
.encode(
expected_text.trim_end(),
&encoding.tokenizer.special_tokens(),
)
.encode(expected_text.as_str(), &encoding.tokenizer.special_tokens())
.0;
let sys = SystemContent::new()
.with_model_identity("You are ChatGPT, a large language model trained by OpenAI.")
@ -127,8 +135,8 @@ fn test_simple_convo_with_effort() {
#[test]
fn test_simple_reasoning_response() {
let expected_tokens = parse_tokens(include_str!(
"../test-data/test_simple_reasoning_response.txt"
let expected_tokens = parse_tokens(load_test_data(
"../test-data/test_simple_reasoning_response.txt",
));
for encoding_name in ENCODINGS {
let encoding = load_harmony_encoding(encoding_name).unwrap();
@ -184,7 +192,7 @@ fn test_reasoning_system_message() {
let expected = encoding
.tokenizer
.encode(
include_str!("../test-data/test_reasoning_system_message.txt").trim_end(),
load_test_data("../test-data/test_reasoning_system_message.txt").as_str(),
&encoding.tokenizer.special_tokens(),
)
.0;
@ -215,8 +223,8 @@ fn test_reasoning_system_message_no_instruction() {
let expected = encoding
.tokenizer
.encode(
include_str!("../test-data/test_reasoning_system_message_no_instruction.txt")
.trim_end(),
load_test_data("../test-data/test_reasoning_system_message_no_instruction.txt")
.as_str(),
&encoding.tokenizer.special_tokens(),
)
.0;
@ -249,8 +257,8 @@ fn test_reasoning_system_message_with_dates() {
let expected = encoding
.tokenizer
.encode(
include_str!("../test-data/test_reasoning_system_message_with_dates.txt")
.trim_end(),
load_test_data("../test-data/test_reasoning_system_message_with_dates.txt")
.as_str(),
&encoding.tokenizer.special_tokens(),
)
.0;
@ -279,8 +287,7 @@ fn test_reasoning_system_message_with_dates() {
#[test]
fn test_render_functions_with_parameters() {
let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();
let expected_output =
include_str!("../test-data/test_render_functions_with_parameters.txt").trim_end();
let expected_output = load_test_data("../test-data/test_render_functions_with_parameters.txt");
let sys = SystemContent::new()
.with_reasoning_effort(ReasoningEffort::High)
@ -386,7 +393,7 @@ fn test_render_functions_with_parameters() {
#[test]
fn test_browser_and_python_tool() {
let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();
let expected_output = include_str!("../test-data/test_browser_and_python_tool.txt").trim_end();
let expected_output = load_test_data("../test-data/test_browser_and_python_tool.txt");
let convo = Conversation::from_messages([Message::from_role_and_content(
Role::System,
@ -407,7 +414,7 @@ fn test_browser_and_python_tool() {
#[test]
fn test_dropping_cot_by_default() {
let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();
let expected_output = include_str!("../test-data/test_dropping_cot_by_default.txt").trim_end();
let expected_output = load_test_data("../test-data/test_dropping_cot_by_default.txt");
let convo = Conversation::from_messages([
Message::from_role_and_content(Role::User, "What is 2 + 2?"),
@ -437,8 +444,7 @@ fn test_dropping_cot_by_default() {
#[test]
fn test_does_not_drop_if_ongoing_analysis() {
let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();
let expected_output =
include_str!("../test-data/test_does_not_drop_if_ongoing_analysis.txt").trim_end();
let expected_output = load_test_data("../test-data/test_does_not_drop_if_ongoing_analysis.txt");
let convo = Conversation::from_messages([
Message::from_role_and_content(Role::User, "What is the weather in SF?"),
@ -474,7 +480,7 @@ fn test_does_not_drop_if_ongoing_analysis() {
#[test]
fn test_preserve_cot() {
let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();
let expected_output = include_str!("../test-data/test_preserve_cot.txt").trim_end();
let expected_output = load_test_data("../test-data/test_preserve_cot.txt");
let convo = Conversation::from_messages([
Message::from_role_and_content(Role::User, "What is 2 + 2?"),
@ -538,10 +544,10 @@ fn test_decode_utf8_invalid_token() {
#[test]
fn test_tool_response_parsing() {
let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();
let text_tokens = include_str!("../test-data/test_tool_response_parsing.txt").trim_end();
let text_tokens = load_test_data("../test-data/test_tool_response_parsing.txt");
let tokens = encoding
.tokenizer
.encode(text_tokens, &encoding.tokenizer.special_tokens())
.encode(&text_tokens, &encoding.tokenizer.special_tokens())
.0;
let expected_message = Message::from_author_and_content(
@ -620,10 +626,10 @@ fn test_invalid_utf8_decoding() {
#[test]
fn test_streamable_parser() {
let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();
let text = include_str!("../test-data/test_streamable_parser.txt").trim_end();
let text = load_test_data("../test-data/test_streamable_parser.txt");
let tokens = encoding
.tokenizer
.encode(text, &encoding.tokenizer.special_tokens())
.encode(&text, &encoding.tokenizer.special_tokens())
.0;
let mut parser =
crate::encoding::StreamableParser::new(encoding.clone(), Some(Role::Assistant)).unwrap();