feat: local tokenizer (#5508)

This commit is contained in:
jif-oai
2025-10-22 16:01:02 +01:00
committed by GitHub
parent 00b1e130b3
commit fd0673e457
4 changed files with 218 additions and 3 deletions

46
codex-rs/Cargo.lock generated
View File

@@ -1517,6 +1517,16 @@ dependencies = [
name = "codex-utils-string"
version = "0.0.0"
[[package]]
name = "codex-utils-tokenizer"
version = "0.0.0"
dependencies = [
"anyhow",
"pretty_assertions",
"thiserror 2.0.16",
"tiktoken-rs",
]
[[package]]
name = "color-eyre"
version = "0.6.5"
@@ -2314,6 +2324,17 @@ dependencies = [
"once_cell",
]
[[package]]
name = "fancy-regex"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2"
dependencies = [
"bit-set",
"regex-automata",
"regex-syntax 0.8.5",
]
[[package]]
name = "fastrand"
version = "2.3.0"
@@ -4631,7 +4652,7 @@ dependencies = [
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash",
"rustc-hash 2.1.1",
"rustls",
"socket2 0.6.0",
"thiserror 2.0.16",
@@ -4651,7 +4672,7 @@ dependencies = [
"lru-slab",
"rand 0.9.2",
"ring",
"rustc-hash",
"rustc-hash 2.1.1",
"rustls",
"rustls-pki-types",
"slab",
@@ -4996,6 +5017,12 @@ version = "0.1.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "989e6739f80c4ad5b13e0fd7fe89531180375b18520cc8c82080e4dc4035b84f"
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustc-hash"
version = "2.1.1"
@@ -6176,6 +6203,21 @@ dependencies = [
"zune-jpeg",
]
[[package]]
name = "tiktoken-rs"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25563eeba904d770acf527e8b370fe9a5547bacd20ff84a0b6c3bc41288e5625"
dependencies = [
"anyhow",
"base64",
"bstr",
"fancy-regex",
"lazy_static",
"regex",
"rustc-hash 1.1.0",
]
[[package]]
name = "time"
version = "0.3.44"

View File

@@ -37,6 +37,7 @@ members = [
"utils/readiness",
"utils/pty",
"utils/string",
"utils/tokenizer",
]
resolver = "2"
@@ -82,6 +83,7 @@ codex-utils-json-to-toml = { path = "utils/json-to-toml" }
codex-utils-pty = { path = "utils/pty" }
codex-utils-readiness = { path = "utils/readiness" }
codex-utils-string = { path = "utils/string" }
codex-utils-tokenizer = { path = "utils/tokenizer" }
core_test_support = { path = "core/tests/common" }
mcp-types = { path = "mcp-types" }
mcp_test_support = { path = "mcp-server/tests/common" }
@@ -246,7 +248,7 @@ unwrap_used = "deny"
# cargo-shear cannot see the platform-specific openssl-sys usage, so we
# silence the false positive here instead of deleting a real dependency.
[workspace.metadata.cargo-shear]
ignored = ["openssl-sys", "codex-utils-readiness"]
ignored = ["openssl-sys", "codex-utils-readiness", "codex-utils-tokenizer"]
[profile.release]
lto = "fat"

View File

@@ -0,0 +1,15 @@
[package]
edition.workspace = true
name = "codex-utils-tokenizer"
version.workspace = true
[lints]
workspace = true
[dependencies]
anyhow = { workspace = true }
thiserror = { workspace = true }
tiktoken-rs = "0.7"
[dev-dependencies]
pretty_assertions = { workspace = true }

View File

@@ -0,0 +1,156 @@
use std::fmt;
use anyhow::Context;
use anyhow::Error as AnyhowError;
use thiserror::Error;
use tiktoken_rs::CoreBPE;
/// Supported local encodings.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum EncodingKind {
O200kBase,
Cl100kBase,
}
impl fmt::Display for EncodingKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::O200kBase => f.write_str("o200k_base"),
Self::Cl100kBase => f.write_str("cl100k_base"),
}
}
}
/// Tokenizer error type.
#[derive(Debug, Error)]
pub enum TokenizerError {
#[error("failed to load encoding {kind}")]
LoadEncoding {
kind: EncodingKind,
#[source]
source: AnyhowError,
},
#[error("failed to decode tokens")]
Decode {
#[source]
source: AnyhowError,
},
}
/// Thin wrapper around a `tiktoken_rs::CoreBPE` tokenizer.
#[derive(Clone)]
pub struct Tokenizer {
inner: CoreBPE,
}
impl Tokenizer {
/// Build a tokenizer for a specific encoding.
pub fn new(kind: EncodingKind) -> Result<Self, TokenizerError> {
let loader: fn() -> anyhow::Result<CoreBPE> = match kind {
EncodingKind::O200kBase => tiktoken_rs::o200k_base,
EncodingKind::Cl100kBase => tiktoken_rs::cl100k_base,
};
let inner = loader().map_err(|source| TokenizerError::LoadEncoding { kind, source })?;
Ok(Self { inner })
}
/// Build a tokenizer using an `OpenAI` model name (maps to an encoding).
/// Falls back to the `o200k_base` encoding when the model is unknown.
pub fn for_model(model: &str) -> Result<Self, TokenizerError> {
match tiktoken_rs::get_bpe_from_model(model) {
Ok(inner) => Ok(Self { inner }),
Err(model_error) => {
let inner = tiktoken_rs::o200k_base()
.with_context(|| {
format!("fallback after model lookup failure for {model}: {model_error}")
})
.map_err(|source| TokenizerError::LoadEncoding {
kind: EncodingKind::O200kBase,
source,
})?;
Ok(Self { inner })
}
}
}
/// Encode text to token IDs. If `with_special_tokens` is true, special
/// tokens are allowed and may appear in the result.
#[must_use]
pub fn encode(&self, text: &str, with_special_tokens: bool) -> Vec<i32> {
let raw = if with_special_tokens {
self.inner.encode_with_special_tokens(text)
} else {
self.inner.encode_ordinary(text)
};
raw.into_iter().map(|t| t as i32).collect()
}
/// Count tokens in `text` as a signed integer.
#[must_use]
pub fn count(&self, text: &str) -> i64 {
// Signed length to satisfy our style preference.
i64::try_from(self.inner.encode_ordinary(text).len()).unwrap_or(i64::MAX)
}
/// Decode token IDs back to text.
pub fn decode(&self, tokens: &[i32]) -> Result<String, TokenizerError> {
let raw: Vec<u32> = tokens.iter().map(|t| *t as u32).collect();
self.inner
.decode(raw)
.map_err(|source| TokenizerError::Decode { source })
}
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
#[test]
fn cl100k_base_roundtrip_simple() -> Result<(), TokenizerError> {
let tok = Tokenizer::new(EncodingKind::Cl100kBase)?;
let s = "hello world";
let ids = tok.encode(s, false);
// Stable expectation for cl100k_base
assert_eq!(ids, vec![15339, 1917]);
let back = tok.decode(&ids)?;
assert_eq!(back, s);
Ok(())
}
#[test]
fn preserves_whitespace_and_special_tokens_flag() -> Result<(), TokenizerError> {
let tok = Tokenizer::new(EncodingKind::Cl100kBase)?;
let s = "This has multiple spaces";
let ids_no_special = tok.encode(s, false);
let round = tok.decode(&ids_no_special)?;
assert_eq!(round, s);
// With special tokens allowed, result may be identical for normal text,
// but the API should still function.
let ids_with_special = tok.encode(s, true);
let round2 = tok.decode(&ids_with_special)?;
assert_eq!(round2, s);
Ok(())
}
#[test]
fn model_mapping_builds_tokenizer() -> Result<(), TokenizerError> {
// Choose a long-standing model alias that maps to cl100k_base.
let tok = Tokenizer::for_model("gpt-5")?;
let ids = tok.encode("ok", false);
let back = tok.decode(&ids)?;
assert_eq!(back, "ok");
Ok(())
}
#[test]
fn unknown_model_defaults_to_o200k_base() -> Result<(), TokenizerError> {
let fallback = Tokenizer::new(EncodingKind::O200kBase)?;
let tok = Tokenizer::for_model("does-not-exist")?;
let text = "fallback please";
assert_eq!(tok.encode(text, false), fallback.encode(text, false));
Ok(())
}
}