feat: use actual tokenizer for unified_exec truncation (#5514)

jif-oai
2025-10-23 17:08:06 +01:00
committed by GitHub
parent 6745b12427
commit 0b4527146e
4 changed files with 68 additions and 41 deletions

codex-rs/Cargo.lock (generated)

@@ -1066,6 +1066,7 @@ dependencies = [
  "codex-rmcp-client",
  "codex-utils-pty",
  "codex-utils-string",
+ "codex-utils-tokenizer",
  "core-foundation 0.9.4",
  "core_test_support",
  "dirs",


@@ -28,6 +28,7 @@ codex-rmcp-client = { workspace = true }
 codex-async-utils = { workspace = true }
 codex-utils-string = { workspace = true }
 codex-utils-pty = { workspace = true }
+codex-utils-tokenizer = { workspace = true }
 dirs = { workspace = true }
 dunce = { workspace = true }
 env-flags = { workspace = true }


@@ -1,18 +1,35 @@
 //! Utilities for truncating large chunks of output while preserving a prefix
 //! and suffix on UTF-8 boundaries.

+use codex_utils_tokenizer::Tokenizer;
+
 /// Truncate the middle of a UTF-8 string to at most `max_bytes` bytes,
 /// preserving the beginning and the end. Returns the possibly truncated
-/// string and `Some(original_token_count)` (estimated at 4 bytes/token)
+/// string and `Some(original_token_count)` (counted with the local tokenizer;
+/// falls back to a 4-bytes-per-token estimate if the tokenizer cannot load)
 /// if truncation occurred; otherwise returns the original string and `None`.
 pub(crate) fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>) {
     if s.len() <= max_bytes {
         return (s.to_string(), None);
     }

-    let est_tokens = (s.len() as u64).div_ceil(4);
+    // Build a tokenizer for counting (defaults to the o200k_base encoding).
+    // If it fails to load, fall back to a 4-bytes-per-token estimate.
+    let tok = Tokenizer::try_default().ok();
+    let token_count = |text: &str| -> u64 {
+        if let Some(ref t) = tok {
+            t.count(text) as u64
+        } else {
+            (text.len() as u64).div_ceil(4)
+        }
+    };
+    let total_tokens = token_count(s);
+
     if max_bytes == 0 {
-        return (format!("{est_tokens} tokens truncated…"), Some(est_tokens));
+        return (
+            format!("{total_tokens} tokens truncated…"),
+            Some(total_tokens),
+        );
     }

     fn truncate_on_boundary(input: &str, max_len: usize) -> &str {
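
The counting closure above degrades gracefully: a real BPE count when the encoding loads, the previous 4-bytes-per-token estimate otherwise. A minimal standalone sketch of the same pattern, written against `tiktoken_rs` directly (the library the tokenizer crate wraps, per the last file in this diff); the `count_tokens` name is illustrative:

    use tiktoken_rs::o200k_base;

    fn main() {
        // Load the o200k_base BPE once; fall back to ~4 bytes/token if it fails.
        let bpe = o200k_base().ok();
        let count_tokens = |text: &str| -> u64 {
            match &bpe {
                Some(b) => b.encode_with_special_tokens(text).len() as u64,
                None => (text.len() as u64).div_ceil(4),
            }
        };
        println!("{} tokens", count_tokens("some long command output"));
    }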
@@ -50,13 +67,17 @@ pub(crate) fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>
         idx
     }

-    let mut guess_tokens = est_tokens;
+    // Iterate to stabilize marker length → keep budget → boundaries.
+    let mut guess_tokens: u64 = 1;
     for _ in 0..4 {
         let marker = format!("{guess_tokens} tokens truncated…");
         let marker_len = marker.len();
         let keep_budget = max_bytes.saturating_sub(marker_len);
         if keep_budget == 0 {
-            return (format!("{est_tokens} tokens truncated…"), Some(est_tokens));
+            return (
+                format!("{total_tokens} tokens truncated…"),
+                Some(total_tokens),
+            );
         }

         let left_budget = keep_budget / 2;
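
Why iterate at all: the marker embeds the removed-token count, so its byte length depends on that number's digit count; the marker length determines the keep budget, the budget moves the cut boundaries, and the boundaries change the removed-token count again. The loop retries (at most four times) until the marker length is a fixed point. A small sketch of the digit-width effect; byte counts assume UTF-8, where "…" is 3 bytes:

    // Marker is "<N> tokens truncated…": 20 bytes of fixed text plus N's digits,
    // so its length depends only on N's digit count.
    fn marker_len(n: u64) -> usize {
        format!("{n} tokens truncated…").len()
    }

    fn main() {
        assert_eq!(marker_len(1), 21); // 1 digit + 20 fixed bytes
        assert_eq!(marker_len(120), 23); // 3 digits + 20 fixed bytes
        // Growing from 1 to 3 digits shrinks the keep budget by 2 bytes, which can
        // move the boundaries and change the removed-token count — hence the retry.
    }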
@@ -67,59 +88,72 @@ pub(crate) fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>
             suffix_start = prefix_end;
         }

-        let kept_content_bytes = prefix_end + (s.len() - suffix_start);
-        let truncated_content_bytes = s.len().saturating_sub(kept_content_bytes);
-        let new_tokens = (truncated_content_bytes as u64).div_ceil(4);
-        if new_tokens == guess_tokens {
-            let mut out = String::with_capacity(marker_len + kept_content_bytes + 1);
+        // Tokens actually removed (middle slice) using the real tokenizer.
+        let removed_tokens = token_count(&s[prefix_end..suffix_start]);
+
+        // If the number of digits in the token count does not change the marker
+        // length, we can finalize output.
+        let final_marker = format!("{removed_tokens} tokens truncated…");
+        if final_marker.len() == marker_len {
+            let kept_content_bytes = prefix_end + (s.len() - suffix_start);
+            let mut out = String::with_capacity(final_marker.len() + kept_content_bytes + 1);
             out.push_str(&s[..prefix_end]);
-            out.push_str(&marker);
+            out.push_str(&final_marker);
             out.push('\n');
             out.push_str(&s[suffix_start..]);
-            return (out, Some(est_tokens));
+            return (out, Some(total_tokens));
         }
-        guess_tokens = new_tokens;
+        guess_tokens = removed_tokens;
     }

+    // Fallback build after iterations: compute with the last guess.
     let marker = format!("{guess_tokens} tokens truncated…");
     let marker_len = marker.len();
     let keep_budget = max_bytes.saturating_sub(marker_len);
     if keep_budget == 0 {
-        return (format!("{est_tokens} tokens truncated…"), Some(est_tokens));
+        return (
+            format!("{total_tokens} tokens truncated…"),
+            Some(total_tokens),
+        );
     }
     let left_budget = keep_budget / 2;
     let right_budget = keep_budget - left_budget;
     let prefix_end = pick_prefix_end(s, left_budget);
-    let suffix_start = pick_suffix_start(s, right_budget);
+    let mut suffix_start = pick_suffix_start(s, right_budget);
+    if suffix_start < prefix_end {
+        suffix_start = prefix_end;
+    }

     let mut out = String::with_capacity(marker_len + prefix_end + (s.len() - suffix_start) + 1);
     out.push_str(&s[..prefix_end]);
     out.push_str(&marker);
     out.push('\n');
     out.push_str(&s[suffix_start..]);
-    (out, Some(est_tokens))
+    (out, Some(total_tokens))
 }

 #[cfg(test)]
 mod tests {
     use super::truncate_middle;
+    use codex_utils_tokenizer::Tokenizer;

     #[test]
     fn truncate_middle_no_newlines_fallback() {
+        let tok = Tokenizer::try_default().expect("load tokenizer");
         let s = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ*";
         let max_bytes = 32;
         let (out, original) = truncate_middle(s, max_bytes);
         assert!(out.starts_with("abc"));
         assert!(out.contains("tokens truncated"));
         assert!(out.ends_with("XYZ*"));
-        assert_eq!(original, Some((s.len() as u64).div_ceil(4)));
+        assert_eq!(original, Some(tok.count(s) as u64));
     }

     #[test]
     fn truncate_middle_prefers_newline_boundaries() {
+        let tok = Tokenizer::try_default().expect("load tokenizer");
         let mut s = String::new();
         for i in 1..=20 {
             s.push_str(&format!("{i:03}\n"));
@@ -131,50 +165,36 @@ mod tests {
         assert!(out.starts_with("001\n002\n003\n004\n"));
         assert!(out.contains("tokens truncated"));
         assert!(out.ends_with("017\n018\n019\n020\n"));
-        assert_eq!(tokens, Some(20));
+        assert_eq!(tokens, Some(tok.count(&s) as u64));
     }

     #[test]
     fn truncate_middle_handles_utf8_content() {
+        let tok = Tokenizer::try_default().expect("load tokenizer");
         let s = "😀😀😀😀😀😀😀😀😀😀\nsecond line with ascii text\n";
         let max_bytes = 32;
         let (out, tokens) = truncate_middle(s, max_bytes);
         assert!(out.contains("tokens truncated"));
         assert!(!out.contains('\u{fffd}'));
-        assert_eq!(tokens, Some((s.len() as u64).div_ceil(4)));
+        assert_eq!(tokens, Some(tok.count(s) as u64));
     }

     #[test]
     fn truncate_middle_prefers_newline_boundaries_2() {
+        let tok = Tokenizer::try_default().expect("load tokenizer");
         // Build a multi-line string of 20 numbered lines (each "NNN\n").
         let mut s = String::new();
         for i in 1..=20 {
             s.push_str(&format!("{i:03}\n"));
         }
-        // Total length: 20 lines * 4 bytes per line = 80 bytes.
         assert_eq!(s.len(), 80);
-        // Choose a cap that forces truncation while leaving room for
-        // a few lines on each side after accounting for the marker.
         let max_bytes = 64;
-        // Expect exact output: first 4 lines, marker, last 4 lines, and correct token estimate (80/4 = 20).
-        assert_eq!(
-            truncate_middle(&s, max_bytes),
-            (
-                r#"001
-002
-003
-004
-…12 tokens truncated…
-017
-018
-019
-020
-"#
-                .to_string(),
-                Some(20)
-            )
-        );
+        let (out, total) = truncate_middle(&s, max_bytes);
+        assert!(out.starts_with("001\n002\n003\n004\n"));
+        assert!(out.contains("tokens truncated"));
+        assert!(out.ends_with("017\n018\n019\n020\n"));
+        assert_eq!(total, Some(tok.count(&s) as u64));
     }
 }
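
Net effect for callers such as unified_exec: the reported count now reflects the configured encoding instead of a byte heuristic. The call shape, for reference (the 10 KiB cap and `exec_output` are illustrative, not taken from this diff):

    // Sketch of a call site; `exec_output` is a hypothetical captured stdout/stderr.
    let (out, truncated) = truncate_middle(&exec_output, 10 * 1024);
    // `truncated` is None when the output fit within the cap; otherwise it is
    // Some(token count of the whole original string), and `out` keeps a prefix
    // and suffix joined by the "N tokens truncated…" marker on its own line.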


@@ -55,8 +55,13 @@ impl Tokenizer {
         Ok(Self { inner })
     }

+    /// Default to `O200kBase`
+    pub fn try_default() -> Result<Self, TokenizerError> {
+        Self::new(EncodingKind::O200kBase)
+    }
+
     /// Build a tokenizer using an `OpenAI` model name (maps to an encoding).
-    /// Falls back to the `o200k_base` encoding when the model is unknown.
+    /// Falls back to the `O200kBase` encoding when the model is unknown.
     pub fn for_model(model: &str) -> Result<Self, TokenizerError> {
         match tiktoken_rs::get_bpe_from_model(model) {
             Ok(inner) => Ok(Self { inner }),
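
For reference, the new constructor composes with the crate's existing `count` method (used by the truncation code above) like so; a sketch, assuming `TokenizerError` implements `Display`:

    use codex_utils_tokenizer::Tokenizer;

    fn main() {
        match Tokenizer::try_default() {
            Ok(tok) => println!("{} tokens", tok.count("hello world")),
            Err(err) => eprintln!("tokenizer unavailable: {err}"),
        }
    }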