feat: use actual tokenizer for unified_exec truncation (#5514)
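
Previously unified_exec's truncate_middle estimated token counts at
4 bytes per token. Count with the new codex-utils-tokenizer crate
instead (o200k_base via Tokenizer::try_default), keeping the byte-based
estimate only as a fallback when the tokenizer cannot load. Tests now
derive expected counts from the same tokenizer rather than hard-coded
byte math.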

codex-rs/Cargo.lock (generated)
@@ -1066,6 +1066,7 @@ dependencies = [
  "codex-rmcp-client",
  "codex-utils-pty",
  "codex-utils-string",
+ "codex-utils-tokenizer",
  "core-foundation 0.9.4",
  "core_test_support",
  "dirs",

@@ -28,6 +28,7 @@ codex-rmcp-client = { workspace = true }
 codex-async-utils = { workspace = true }
 codex-utils-string = { workspace = true }
 codex-utils-pty = { workspace = true }
+codex-utils-tokenizer = { workspace = true }
 dirs = { workspace = true }
 dunce = { workspace = true }
 env-flags = { workspace = true }

@@ -1,18 +1,35 @@
 //! Utilities for truncating large chunks of output while preserving a prefix
 //! and suffix on UTF-8 boundaries.

+use codex_utils_tokenizer::Tokenizer;
+
 /// Truncate the middle of a UTF-8 string to at most `max_bytes` bytes,
 /// preserving the beginning and the end. Returns the possibly truncated
-/// string and `Some(original_token_count)` (estimated at 4 bytes/token)
+/// string and `Some(original_token_count)` (counted with the local tokenizer;
+/// falls back to a 4-bytes-per-token estimate if the tokenizer cannot load)
 /// if truncation occurred; otherwise returns the original string and `None`.
 pub(crate) fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>) {
     if s.len() <= max_bytes {
         return (s.to_string(), None);
     }

-    let est_tokens = (s.len() as u64).div_ceil(4);
+    // Build a tokenizer for counting (default to o200k_base; fall back to cl100k_base).
+    // If both fail, fall back to a 4-bytes-per-token estimate.
+    let tok = Tokenizer::try_default().ok();
+    let token_count = |text: &str| -> u64 {
+        if let Some(ref t) = tok {
+            t.count(text) as u64
+        } else {
+            (text.len() as u64).div_ceil(4)
+        }
+    };
+
+    let total_tokens = token_count(s);
     if max_bytes == 0 {
-        return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
+        return (
+            format!("…{total_tokens} tokens truncated…"),
+            Some(total_tokens),
+        );
     }

     fn truncate_on_boundary(input: &str, max_len: usize) -> &str {
@@ -50,13 +67,17 @@ pub(crate) fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>
         idx
     }

-    let mut guess_tokens = est_tokens;
+    // Iterate to stabilize marker length → keep budget → boundaries.
+    let mut guess_tokens: u64 = 1;
     for _ in 0..4 {
         let marker = format!("…{guess_tokens} tokens truncated…");
         let marker_len = marker.len();
         let keep_budget = max_bytes.saturating_sub(marker_len);
         if keep_budget == 0 {
-            return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
+            return (
+                format!("…{total_tokens} tokens truncated…"),
+                Some(total_tokens),
+            );
         }

         let left_budget = keep_budget / 2;
@@ -67,59 +88,72 @@ pub(crate) fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>
             suffix_start = prefix_end;
         }

-        let kept_content_bytes = prefix_end + (s.len() - suffix_start);
-        let truncated_content_bytes = s.len().saturating_sub(kept_content_bytes);
-        let new_tokens = (truncated_content_bytes as u64).div_ceil(4);
+        // Tokens actually removed (middle slice) using the real tokenizer.
+        let removed_tokens = token_count(&s[prefix_end..suffix_start]);

-        if new_tokens == guess_tokens {
-            let mut out = String::with_capacity(marker_len + kept_content_bytes + 1);
+        // If the number of digits in the token count does not change the marker length,
+        // we can finalize output.
+        let final_marker = format!("…{removed_tokens} tokens truncated…");
+        if final_marker.len() == marker_len {
+            let kept_content_bytes = prefix_end + (s.len() - suffix_start);
+            let mut out = String::with_capacity(final_marker.len() + kept_content_bytes + 1);
             out.push_str(&s[..prefix_end]);
-            out.push_str(&marker);
+            out.push_str(&final_marker);
             out.push('\n');
             out.push_str(&s[suffix_start..]);
-            return (out, Some(est_tokens));
+            return (out, Some(total_tokens));
         }

-        guess_tokens = new_tokens;
+        guess_tokens = removed_tokens;
     }

+    // Fallback build after iterations: compute with the last guess.
     let marker = format!("…{guess_tokens} tokens truncated…");
     let marker_len = marker.len();
     let keep_budget = max_bytes.saturating_sub(marker_len);
     if keep_budget == 0 {
-        return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
+        return (
+            format!("…{total_tokens} tokens truncated…"),
+            Some(total_tokens),
+        );
     }

     let left_budget = keep_budget / 2;
     let right_budget = keep_budget - left_budget;
     let prefix_end = pick_prefix_end(s, left_budget);
-    let suffix_start = pick_suffix_start(s, right_budget);
+    let mut suffix_start = pick_suffix_start(s, right_budget);
+    if suffix_start < prefix_end {
+        suffix_start = prefix_end;
+    }

     let mut out = String::with_capacity(marker_len + prefix_end + (s.len() - suffix_start) + 1);
     out.push_str(&s[..prefix_end]);
     out.push_str(&marker);
     out.push('\n');
     out.push_str(&s[suffix_start..]);
-    (out, Some(est_tokens))
+    (out, Some(total_tokens))
 }

 #[cfg(test)]
 mod tests {
     use super::truncate_middle;
+    use codex_utils_tokenizer::Tokenizer;

     #[test]
     fn truncate_middle_no_newlines_fallback() {
+        let tok = Tokenizer::try_default().expect("load tokenizer");
         let s = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ*";
         let max_bytes = 32;
         let (out, original) = truncate_middle(s, max_bytes);
         assert!(out.starts_with("abc"));
         assert!(out.contains("tokens truncated"));
         assert!(out.ends_with("XYZ*"));
-        assert_eq!(original, Some((s.len() as u64).div_ceil(4)));
+        assert_eq!(original, Some(tok.count(s) as u64));
     }

     #[test]
     fn truncate_middle_prefers_newline_boundaries() {
+        let tok = Tokenizer::try_default().expect("load tokenizer");
         let mut s = String::new();
         for i in 1..=20 {
             s.push_str(&format!("{i:03}\n"));
@@ -131,50 +165,36 @@ mod tests {
         assert!(out.starts_with("001\n002\n003\n004\n"));
         assert!(out.contains("tokens truncated"));
         assert!(out.ends_with("017\n018\n019\n020\n"));
-        assert_eq!(tokens, Some(20));
+        assert_eq!(tokens, Some(tok.count(&s) as u64));
     }

     #[test]
     fn truncate_middle_handles_utf8_content() {
+        let tok = Tokenizer::try_default().expect("load tokenizer");
         let s = "😀😀😀😀😀😀😀😀😀😀\nsecond line with ascii text\n";
         let max_bytes = 32;
         let (out, tokens) = truncate_middle(s, max_bytes);

         assert!(out.contains("tokens truncated"));
         assert!(!out.contains('\u{fffd}'));
-        assert_eq!(tokens, Some((s.len() as u64).div_ceil(4)));
+        assert_eq!(tokens, Some(tok.count(s) as u64));
     }

     #[test]
     fn truncate_middle_prefers_newline_boundaries_2() {
+        let tok = Tokenizer::try_default().expect("load tokenizer");
         // Build a multi-line string of 20 numbered lines (each "NNN\n").
         let mut s = String::new();
         for i in 1..=20 {
             s.push_str(&format!("{i:03}\n"));
         }
-        // Total length: 20 lines * 4 bytes per line = 80 bytes.
         assert_eq!(s.len(), 80);

-        // Choose a cap that forces truncation while leaving room for
-        // a few lines on each side after accounting for the marker.
         let max_bytes = 64;
-        // Expect exact output: first 4 lines, marker, last 4 lines, and correct token estimate (80/4 = 20).
-        assert_eq!(
-            truncate_middle(&s, max_bytes),
-            (
-                r#"001
-002
-003
-004
-…12 tokens truncated…
-017
-018
-019
-020
-"#
-                .to_string(),
-                Some(20)
-            )
-        );
+        let (out, total) = truncate_middle(&s, max_bytes);
+        assert!(out.starts_with("001\n002\n003\n004\n"));
+        assert!(out.contains("tokens truncated"));
+        assert!(out.ends_with("017\n018\n019\n020\n"));
+        assert_eq!(total, Some(tok.count(&s) as u64));
     }
 }
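
Note: read in isolation, the counting strategy above amounts to the
following standalone sketch (illustrative only, not part of the commit;
count_tokens is a hypothetical wrapper around the closure defined inside
truncate_middle):

    use codex_utils_tokenizer::Tokenizer;

    // Prefer the real tokenizer; fall back to the old 4-bytes-per-token
    // estimate only if the tokenizer fails to load.
    fn count_tokens(text: &str) -> u64 {
        match Tokenizer::try_default() {
            Ok(t) => t.count(text) as u64,
            Err(_) => (text.len() as u64).div_ceil(4),
        }
    }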

@@ -55,8 +55,13 @@ impl Tokenizer {
         Ok(Self { inner })
     }

+    /// Default to `O200kBase`
+    pub fn try_default() -> Result<Self, TokenizerError> {
+        Self::new(EncodingKind::O200kBase)
+    }
+
     /// Build a tokenizer using an `OpenAI` model name (maps to an encoding).
-    /// Falls back to the `o200k_base` encoding when the model is unknown.
+    /// Falls back to the `O200kBase` encoding when the model is unknown.
     pub fn for_model(model: &str) -> Result<Self, TokenizerError> {
         match tiktoken_rs::get_bpe_from_model(model) {
             Ok(inner) => Ok(Self { inner }),
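
Note: a minimal usage sketch of the new constructor, mirroring the tests
above and assuming only the codex_utils_tokenizer API shown in this diff:

    use codex_utils_tokenizer::Tokenizer;

    fn main() {
        // Build the default (o200k_base) tokenizer and count a string.
        let tok = Tokenizer::try_default().expect("load tokenizer");
        println!("{} tokens", tok.count("unified_exec output"));
    }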