diff --git a/src/tests.rs b/src/tests.rs index 9e19a0c..d6b8504 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,3 +1,5 @@ +use std::path::Path; + use crate::{ chat::{ Author, Conversation, DeveloperContent, Message, ReasoningEffort, Role, SystemContent, @@ -10,12 +12,25 @@ use crate::{ use pretty_assertions::{assert_eq, Comparison}; use serde_json::json; -fn parse_tokens(text: &str) -> Vec { - text.split_whitespace() +fn parse_tokens(text: impl AsRef) -> Vec { + text.as_ref() + .split_whitespace() .map(|s| s.parse().unwrap()) .collect() } +fn load_test_data(path: impl AsRef) -> String { + // on windows, we need to replace \r\n with \n + let cargo_manifest_dir = Path::new(env!("CARGO_MANIFEST_DIR")); + let src_dir = cargo_manifest_dir.join("src"); + let path = src_dir.join(path); + std::fs::read_to_string(path) + .unwrap() + .replace("\r\n", "\n") + .trim_end() + .to_string() +} + const ENCODINGS: [HarmonyEncodingName; 1] = [HarmonyEncodingName::HarmonyGptOss]; #[test] @@ -25,9 +40,7 @@ fn test_simple_convo() { let expected_tokens = encoding .tokenizer .encode( - include_str!("../test-data/test_simple_convo.txt") - .replace("\r\n", "\n") - .trim_end(), + load_test_data("../test-data/test_simple_convo.txt").as_str(), &encoding.tokenizer.special_tokens(), ) .0; @@ -52,47 +65,42 @@ fn test_simple_convo_with_effort() { let test_cases = [ ( ReasoningEffort::Low, - include_str!("../test-data/test_simple_convo_low_effort.txt"), + load_test_data("../test-data/test_simple_convo_low_effort.txt"), true, ), ( ReasoningEffort::Medium, - include_str!("../test-data/test_simple_convo_medium_effort.txt"), + load_test_data("../test-data/test_simple_convo_medium_effort.txt"), true, ), ( ReasoningEffort::High, - include_str!("../test-data/test_simple_convo_high_effort.txt"), + load_test_data("../test-data/test_simple_convo_high_effort.txt"), true, ), ( ReasoningEffort::Low, - include_str!("../test-data/test_simple_convo_low_effort_no_instruction.txt"), + load_test_data("../test-data/test_simple_convo_low_effort_no_instruction.txt"), false, ), ( ReasoningEffort::Medium, - include_str!("../test-data/test_simple_convo_medium_effort_no_instruction.txt"), + load_test_data("../test-data/test_simple_convo_medium_effort_no_instruction.txt"), false, ), ( ReasoningEffort::High, - include_str!("../test-data/test_simple_convo_high_effort_no_instruction.txt"), + load_test_data("../test-data/test_simple_convo_high_effort_no_instruction.txt"), false, ), ]; for encoding_name in ENCODINGS { let encoding = load_harmony_encoding(encoding_name).unwrap(); - for (effort, expected_text, use_instruction) in test_cases { - // on windows, we need to replace \r\n with \n - let expected_text = expected_text.replace("\r\n", "\n"); + for &(effort, ref expected_text, use_instruction) in &test_cases { let expected_tokens = encoding .tokenizer - .encode( - expected_text.trim_end(), - &encoding.tokenizer.special_tokens(), - ) + .encode(expected_text.as_str(), &encoding.tokenizer.special_tokens()) .0; let sys = SystemContent::new() .with_model_identity("You are ChatGPT, a large language model trained by OpenAI.") @@ -127,8 +135,8 @@ fn test_simple_convo_with_effort() { #[test] fn test_simple_reasoning_response() { - let expected_tokens = parse_tokens(include_str!( - "../test-data/test_simple_reasoning_response.txt" + let expected_tokens = parse_tokens(load_test_data( + "../test-data/test_simple_reasoning_response.txt", )); for encoding_name in ENCODINGS { let encoding = load_harmony_encoding(encoding_name).unwrap(); @@ -184,7 +192,7 @@ fn test_reasoning_system_message() { let expected = encoding .tokenizer .encode( - include_str!("../test-data/test_reasoning_system_message.txt").trim_end(), + load_test_data("../test-data/test_reasoning_system_message.txt").as_str(), &encoding.tokenizer.special_tokens(), ) .0; @@ -215,8 +223,8 @@ fn test_reasoning_system_message_no_instruction() { let expected = encoding .tokenizer .encode( - include_str!("../test-data/test_reasoning_system_message_no_instruction.txt") - .trim_end(), + load_test_data("../test-data/test_reasoning_system_message_no_instruction.txt") + .as_str(), &encoding.tokenizer.special_tokens(), ) .0; @@ -249,8 +257,8 @@ fn test_reasoning_system_message_with_dates() { let expected = encoding .tokenizer .encode( - include_str!("../test-data/test_reasoning_system_message_with_dates.txt") - .trim_end(), + load_test_data("../test-data/test_reasoning_system_message_with_dates.txt") + .as_str(), &encoding.tokenizer.special_tokens(), ) .0; @@ -279,8 +287,7 @@ fn test_reasoning_system_message_with_dates() { #[test] fn test_render_functions_with_parameters() { let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap(); - let expected_output = - include_str!("../test-data/test_render_functions_with_parameters.txt").trim_end(); + let expected_output = load_test_data("../test-data/test_render_functions_with_parameters.txt"); let sys = SystemContent::new() .with_reasoning_effort(ReasoningEffort::High) @@ -386,7 +393,7 @@ fn test_render_functions_with_parameters() { #[test] fn test_browser_and_python_tool() { let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap(); - let expected_output = include_str!("../test-data/test_browser_and_python_tool.txt").trim_end(); + let expected_output = load_test_data("../test-data/test_browser_and_python_tool.txt"); let convo = Conversation::from_messages([Message::from_role_and_content( Role::System, @@ -407,7 +414,7 @@ fn test_browser_and_python_tool() { #[test] fn test_dropping_cot_by_default() { let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap(); - let expected_output = include_str!("../test-data/test_dropping_cot_by_default.txt").trim_end(); + let expected_output = load_test_data("../test-data/test_dropping_cot_by_default.txt"); let convo = Conversation::from_messages([ Message::from_role_and_content(Role::User, "What is 2 + 2?"), @@ -437,8 +444,7 @@ fn test_dropping_cot_by_default() { #[test] fn test_does_not_drop_if_ongoing_analysis() { let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap(); - let expected_output = - include_str!("../test-data/test_does_not_drop_if_ongoing_analysis.txt").trim_end(); + let expected_output = load_test_data("../test-data/test_does_not_drop_if_ongoing_analysis.txt"); let convo = Conversation::from_messages([ Message::from_role_and_content(Role::User, "What is the weather in SF?"), @@ -474,7 +480,7 @@ fn test_does_not_drop_if_ongoing_analysis() { #[test] fn test_preserve_cot() { let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap(); - let expected_output = include_str!("../test-data/test_preserve_cot.txt").trim_end(); + let expected_output = load_test_data("../test-data/test_preserve_cot.txt"); let convo = Conversation::from_messages([ Message::from_role_and_content(Role::User, "What is 2 + 2?"), @@ -538,10 +544,10 @@ fn test_decode_utf8_invalid_token() { #[test] fn test_tool_response_parsing() { let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap(); - let text_tokens = include_str!("../test-data/test_tool_response_parsing.txt").trim_end(); + let text_tokens = load_test_data("../test-data/test_tool_response_parsing.txt"); let tokens = encoding .tokenizer - .encode(text_tokens, &encoding.tokenizer.special_tokens()) + .encode(&text_tokens, &encoding.tokenizer.special_tokens()) .0; let expected_message = Message::from_author_and_content( @@ -620,10 +626,10 @@ fn test_invalid_utf8_decoding() { #[test] fn test_streamable_parser() { let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap(); - let text = include_str!("../test-data/test_streamable_parser.txt").trim_end(); + let text = load_test_data("../test-data/test_streamable_parser.txt"); let tokens = encoding .tokenizer - .encode(text, &encoding.tokenizer.special_tokens()) + .encode(&text, &encoding.tokenizer.special_tokens()) .0; let mut parser = crate::encoding::StreamableParser::new(encoding.clone(), Some(Role::Assistant)).unwrap();