From fd0673e457da65f1fccb4551fe59b9df79b27d85 Mon Sep 17 00:00:00 2001
From: jif-oai
Date: Wed, 22 Oct 2025 16:01:02 +0100
Subject: [PATCH] feat: local tokenizer (#5508)

---
 codex-rs/Cargo.lock                 |  46 +++++++-
 codex-rs/Cargo.toml                 |   4 +-
 codex-rs/utils/tokenizer/Cargo.toml |  15 +++
 codex-rs/utils/tokenizer/src/lib.rs | 156 ++++++++++++++++++++++++++++
 4 files changed, 218 insertions(+), 3 deletions(-)
 create mode 100644 codex-rs/utils/tokenizer/Cargo.toml
 create mode 100644 codex-rs/utils/tokenizer/src/lib.rs

diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock
index 88f7f551..507c7d7b 100644
--- a/codex-rs/Cargo.lock
+++ b/codex-rs/Cargo.lock
@@ -1517,6 +1517,16 @@ dependencies = [
 name = "codex-utils-string"
 version = "0.0.0"
 
+[[package]]
+name = "codex-utils-tokenizer"
+version = "0.0.0"
+dependencies = [
+ "anyhow",
+ "pretty_assertions",
+ "thiserror 2.0.16",
+ "tiktoken-rs",
+]
+
 [[package]]
 name = "color-eyre"
 version = "0.6.5"
@@ -2314,6 +2324,17 @@ dependencies = [
  "once_cell",
 ]
 
+[[package]]
+name = "fancy-regex"
+version = "0.13.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2"
+dependencies = [
+ "bit-set",
+ "regex-automata",
+ "regex-syntax 0.8.5",
+]
+
 [[package]]
 name = "fastrand"
 version = "2.3.0"
@@ -4631,7 +4652,7 @@ dependencies = [
  "pin-project-lite",
  "quinn-proto",
  "quinn-udp",
- "rustc-hash",
+ "rustc-hash 2.1.1",
  "rustls",
  "socket2 0.6.0",
  "thiserror 2.0.16",
@@ -4651,7 +4672,7 @@ dependencies = [
  "lru-slab",
  "rand 0.9.2",
  "ring",
- "rustc-hash",
+ "rustc-hash 2.1.1",
  "rustls",
  "rustls-pki-types",
  "slab",
@@ -4996,6 +5017,12 @@ version = "0.1.25"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "989e6739f80c4ad5b13e0fd7fe89531180375b18520cc8c82080e4dc4035b84f"
 
+[[package]]
+name = "rustc-hash"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
+
 [[package]]
 name = "rustc-hash"
 version = "2.1.1"
@@ -6176,6 +6203,21 @@ dependencies = [
  "zune-jpeg",
 ]
 
+[[package]]
+name = "tiktoken-rs"
+version = "0.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "25563eeba904d770acf527e8b370fe9a5547bacd20ff84a0b6c3bc41288e5625"
+dependencies = [
+ "anyhow",
+ "base64",
+ "bstr",
+ "fancy-regex",
+ "lazy_static",
+ "regex",
+ "rustc-hash 1.1.0",
+]
+
 [[package]]
 name = "time"
 version = "0.3.44"
diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml
index cc05b527..83c9e78d 100644
--- a/codex-rs/Cargo.toml
+++ b/codex-rs/Cargo.toml
@@ -37,6 +37,7 @@ members = [
     "utils/readiness",
     "utils/pty",
     "utils/string",
+    "utils/tokenizer",
 ]
 
 resolver = "2"
@@ -82,6 +83,7 @@ codex-utils-json-to-toml = { path = "utils/json-to-toml" }
 codex-utils-pty = { path = "utils/pty" }
 codex-utils-readiness = { path = "utils/readiness" }
 codex-utils-string = { path = "utils/string" }
+codex-utils-tokenizer = { path = "utils/tokenizer" }
 core_test_support = { path = "core/tests/common" }
 mcp-types = { path = "mcp-types" }
 mcp_test_support = { path = "mcp-server/tests/common" }
@@ -246,7 +248,7 @@ unwrap_used = "deny"
 # cargo-shear cannot see the platform-specific openssl-sys usage, so we
 # silence the false positive here instead of deleting a real dependency.
 [workspace.metadata.cargo-shear]
-ignored = ["openssl-sys", "codex-utils-readiness"]
+ignored = ["openssl-sys", "codex-utils-readiness", "codex-utils-tokenizer"]
 
 [profile.release]
 lto = "fat"
diff --git a/codex-rs/utils/tokenizer/Cargo.toml b/codex-rs/utils/tokenizer/Cargo.toml
new file mode 100644
index 00000000..6f6b4dec
--- /dev/null
+++ b/codex-rs/utils/tokenizer/Cargo.toml
@@ -0,0 +1,15 @@
+[package]
+edition.workspace = true
+name = "codex-utils-tokenizer"
+version.workspace = true
+
+[lints]
+workspace = true
+
+[dependencies]
+anyhow = { workspace = true }
+thiserror = { workspace = true }
+tiktoken-rs = "0.7"
+
+[dev-dependencies]
+pretty_assertions = { workspace = true }
diff --git a/codex-rs/utils/tokenizer/src/lib.rs b/codex-rs/utils/tokenizer/src/lib.rs
new file mode 100644
index 00000000..93740889
--- /dev/null
+++ b/codex-rs/utils/tokenizer/src/lib.rs
@@ -0,0 +1,156 @@
+use std::fmt;
+
+use anyhow::Context;
+use anyhow::Error as AnyhowError;
+use thiserror::Error;
+use tiktoken_rs::CoreBPE;
+
+/// Supported local encodings.
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
+pub enum EncodingKind {
+    O200kBase,
+    Cl100kBase,
+}
+
+impl fmt::Display for EncodingKind {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Self::O200kBase => f.write_str("o200k_base"),
+            Self::Cl100kBase => f.write_str("cl100k_base"),
+        }
+    }
+}
+
+/// Tokenizer error type.
+#[derive(Debug, Error)]
+pub enum TokenizerError {
+    #[error("failed to load encoding {kind}")]
+    LoadEncoding {
+        kind: EncodingKind,
+        #[source]
+        source: AnyhowError,
+    },
+    #[error("failed to decode tokens")]
+    Decode {
+        #[source]
+        source: AnyhowError,
+    },
+}
+
+/// Thin wrapper around a `tiktoken_rs::CoreBPE` tokenizer.
+#[derive(Clone)]
+pub struct Tokenizer {
+    inner: CoreBPE,
+}
+
+impl Tokenizer {
+    /// Build a tokenizer for a specific encoding.
+    pub fn new(kind: EncodingKind) -> Result<Self, TokenizerError> {
+        let loader: fn() -> anyhow::Result<CoreBPE> = match kind {
+            EncodingKind::O200kBase => tiktoken_rs::o200k_base,
+            EncodingKind::Cl100kBase => tiktoken_rs::cl100k_base,
+        };
+
+        let inner = loader().map_err(|source| TokenizerError::LoadEncoding { kind, source })?;
+        Ok(Self { inner })
+    }
+
+    /// Build a tokenizer using an `OpenAI` model name (maps to an encoding).
+    /// Falls back to the `o200k_base` encoding when the model is unknown.
+    pub fn for_model(model: &str) -> Result<Self, TokenizerError> {
+        match tiktoken_rs::get_bpe_from_model(model) {
+            Ok(inner) => Ok(Self { inner }),
+            Err(model_error) => {
+                let inner = tiktoken_rs::o200k_base()
+                    .with_context(|| {
+                        format!("fallback after model lookup failure for {model}: {model_error}")
+                    })
+                    .map_err(|source| TokenizerError::LoadEncoding {
+                        kind: EncodingKind::O200kBase,
+                        source,
+                    })?;
+                Ok(Self { inner })
+            }
+        }
+    }
+
+    /// Encode text to token IDs. If `with_special_tokens` is true, special
+    /// tokens are allowed and may appear in the result.
+    #[must_use]
+    pub fn encode(&self, text: &str, with_special_tokens: bool) -> Vec<i32> {
+        let raw = if with_special_tokens {
+            self.inner.encode_with_special_tokens(text)
+        } else {
+            self.inner.encode_ordinary(text)
+        };
+        raw.into_iter().map(|t| t as i32).collect()
+    }
+
+    /// Count tokens in `text` as a signed integer.
+    #[must_use]
+    pub fn count(&self, text: &str) -> i64 {
+        // Signed length to satisfy our style preference.
+        i64::try_from(self.inner.encode_ordinary(text).len()).unwrap_or(i64::MAX)
+    }
+
+    /// Decode token IDs back to text.
+    pub fn decode(&self, tokens: &[i32]) -> Result<String, TokenizerError> {
+        let raw: Vec<u32> = tokens.iter().map(|t| *t as u32).collect();
+        self.inner
+            .decode(raw)
+            .map_err(|source| TokenizerError::Decode { source })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use pretty_assertions::assert_eq;
+
+    #[test]
+    fn cl100k_base_roundtrip_simple() -> Result<(), TokenizerError> {
+        let tok = Tokenizer::new(EncodingKind::Cl100kBase)?;
+        let s = "hello world";
+        let ids = tok.encode(s, false);
+        // Stable expectation for cl100k_base
+        assert_eq!(ids, vec![15339, 1917]);
+        let back = tok.decode(&ids)?;
+        assert_eq!(back, s);
+        Ok(())
+    }
+
+    #[test]
+    fn preserves_whitespace_and_special_tokens_flag() -> Result<(), TokenizerError> {
+        let tok = Tokenizer::new(EncodingKind::Cl100kBase)?;
+        let s = "This has multiple spaces";
+        let ids_no_special = tok.encode(s, false);
+        let round = tok.decode(&ids_no_special)?;
+        assert_eq!(round, s);
+
+        // With special tokens allowed, result may be identical for normal text,
+        // but the API should still function.
+        let ids_with_special = tok.encode(s, true);
+        let round2 = tok.decode(&ids_with_special)?;
+        assert_eq!(round2, s);
+        Ok(())
+    }
+
+    #[test]
+    fn model_mapping_builds_tokenizer() -> Result<(), TokenizerError> {
+        // Use a model name that the tiktoken-rs lookup table recognizes.
+        let tok = Tokenizer::for_model("gpt-5")?;
+        let ids = tok.encode("ok", false);
+        let back = tok.decode(&ids)?;
+        assert_eq!(back, "ok");
+        Ok(())
+    }
+
+    #[test]
+    fn unknown_model_defaults_to_o200k_base() -> Result<(), TokenizerError> {
+        let fallback = Tokenizer::new(EncodingKind::O200kBase)?;
+        let tok = Tokenizer::for_model("does-not-exist")?;
+        let text = "fallback please";
+        assert_eq!(tok.encode(text, false), fallback.encode(text, false));
+        Ok(())
+    }
+}
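
A brief usage sketch of the new crate's public API (illustrative only, not part of the patch; the downstream binary, the model name, and the prompt text are hypothetical, but the calls mirror the Tokenizer functions added in lib.rs above):

use codex_utils_tokenizer::EncodingKind;
use codex_utils_tokenizer::Tokenizer;

fn main() -> anyhow::Result<()> {
    // Model-based construction; unknown model names fall back to o200k_base.
    let tok = Tokenizer::for_model("gpt-4o")?;
    let prompt = "Summarize the following diff";

    // Count tokens without keeping the encoded IDs around.
    println!("prompt is {} tokens", tok.count(prompt));

    // Explicit encoding choice plus an encode/decode round trip.
    let tok = Tokenizer::new(EncodingKind::Cl100kBase)?;
    let ids = tok.encode(prompt, false);
    assert_eq!(tok.decode(&ids)?, prompt);
    Ok(())
}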