feat: local tokenizer (#5508)
codex-rs/Cargo.lock (generated, 46 changed lines)

@@ -1517,6 +1517,16 @@ dependencies = [
 name = "codex-utils-string"
 version = "0.0.0"
 
+[[package]]
+name = "codex-utils-tokenizer"
+version = "0.0.0"
+dependencies = [
+ "anyhow",
+ "pretty_assertions",
+ "thiserror 2.0.16",
+ "tiktoken-rs",
+]
+
 [[package]]
 name = "color-eyre"
 version = "0.6.5"

@@ -2314,6 +2324,17 @@ dependencies = [
 "once_cell",
 ]
 
+[[package]]
+name = "fancy-regex"
+version = "0.13.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2"
+dependencies = [
+ "bit-set",
+ "regex-automata",
+ "regex-syntax 0.8.5",
+]
+
 [[package]]
 name = "fastrand"
 version = "2.3.0"

@@ -4631,7 +4652,7 @@ dependencies = [
 "pin-project-lite",
 "quinn-proto",
 "quinn-udp",
-"rustc-hash",
+"rustc-hash 2.1.1",
 "rustls",
 "socket2 0.6.0",
 "thiserror 2.0.16",

@@ -4651,7 +4672,7 @@ dependencies = [
 "lru-slab",
 "rand 0.9.2",
 "ring",
-"rustc-hash",
+"rustc-hash 2.1.1",
 "rustls",
 "rustls-pki-types",
 "slab",

@@ -4996,6 +5017,12 @@ version = "0.1.25"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "989e6739f80c4ad5b13e0fd7fe89531180375b18520cc8c82080e4dc4035b84f"
 
+[[package]]
+name = "rustc-hash"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
+
 [[package]]
 name = "rustc-hash"
 version = "2.1.1"

@@ -6176,6 +6203,21 @@ dependencies = [
 "zune-jpeg",
 ]
 
+[[package]]
+name = "tiktoken-rs"
+version = "0.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "25563eeba904d770acf527e8b370fe9a5547bacd20ff84a0b6c3bc41288e5625"
+dependencies = [
+ "anyhow",
+ "base64",
+ "bstr",
+ "fancy-regex",
+ "lazy_static",
+ "regex",
+ "rustc-hash 1.1.0",
+]
+
 [[package]]
 name = "time"
 version = "0.3.44"

codex-rs/Cargo.toml

@@ -37,6 +37,7 @@ members = [
     "utils/readiness",
     "utils/pty",
     "utils/string",
+    "utils/tokenizer",
 ]
 resolver = "2"
 

@@ -82,6 +83,7 @@ codex-utils-json-to-toml = { path = "utils/json-to-toml" }
 codex-utils-pty = { path = "utils/pty" }
 codex-utils-readiness = { path = "utils/readiness" }
 codex-utils-string = { path = "utils/string" }
+codex-utils-tokenizer = { path = "utils/tokenizer" }
 core_test_support = { path = "core/tests/common" }
 mcp-types = { path = "mcp-types" }
 mcp_test_support = { path = "mcp-server/tests/common" }

@@ -246,7 +248,7 @@ unwrap_used = "deny"
 # cargo-shear cannot see the platform-specific openssl-sys usage, so we
 # silence the false positive here instead of deleting a real dependency.
 [workspace.metadata.cargo-shear]
-ignored = ["openssl-sys", "codex-utils-readiness"]
+ignored = ["openssl-sys", "codex-utils-readiness", "codex-utils-tokenizer"]
 
 [profile.release]
 lto = "fat"

codex-rs/utils/tokenizer/Cargo.toml (new file, 15 lines)

@@ -0,0 +1,15 @@
+[package]
+edition.workspace = true
+name = "codex-utils-tokenizer"
+version.workspace = true
+
+[lints]
+workspace = true
+
+[dependencies]
+anyhow = { workspace = true }
+thiserror = { workspace = true }
+tiktoken-rs = "0.7"
+
+[dev-dependencies]
+pretty_assertions = { workspace = true }

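The crate manifest pins `tiktoken-rs` directly while taking `anyhow` and `thiserror` from the workspace. A sibling crate that wants the tokenizer would go through the `codex-utils-tokenizer` alias registered in the root manifest above; a minimal sketch of a hypothetical consumer manifest (not part of this commit):

[dependencies]
codex-utils-tokenizer = { workspace = true }
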
codex-rs/utils/tokenizer/src/lib.rs (new file, 156 lines)

@@ -0,0 +1,156 @@
+use std::fmt;
+
+use anyhow::Context;
+use anyhow::Error as AnyhowError;
+use thiserror::Error;
+use tiktoken_rs::CoreBPE;
+
+/// Supported local encodings.
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
+pub enum EncodingKind {
+    O200kBase,
+    Cl100kBase,
+}
+
+impl fmt::Display for EncodingKind {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Self::O200kBase => f.write_str("o200k_base"),
+            Self::Cl100kBase => f.write_str("cl100k_base"),
+        }
+    }
+}
+
+/// Tokenizer error type.
+#[derive(Debug, Error)]
+pub enum TokenizerError {
+    #[error("failed to load encoding {kind}")]
+    LoadEncoding {
+        kind: EncodingKind,
+        #[source]
+        source: AnyhowError,
+    },
+    #[error("failed to decode tokens")]
+    Decode {
+        #[source]
+        source: AnyhowError,
+    },
+}
+
+/// Thin wrapper around a `tiktoken_rs::CoreBPE` tokenizer.
+#[derive(Clone)]
+pub struct Tokenizer {
+    inner: CoreBPE,
+}
+
+impl Tokenizer {
+    /// Build a tokenizer for a specific encoding.
+    pub fn new(kind: EncodingKind) -> Result<Self, TokenizerError> {
+        let loader: fn() -> anyhow::Result<CoreBPE> = match kind {
+            EncodingKind::O200kBase => tiktoken_rs::o200k_base,
+            EncodingKind::Cl100kBase => tiktoken_rs::cl100k_base,
+        };
+
+        let inner = loader().map_err(|source| TokenizerError::LoadEncoding { kind, source })?;
+        Ok(Self { inner })
+    }
+
+    /// Build a tokenizer using an `OpenAI` model name (maps to an encoding).
+    /// Falls back to the `o200k_base` encoding when the model is unknown.
+    pub fn for_model(model: &str) -> Result<Self, TokenizerError> {
+        match tiktoken_rs::get_bpe_from_model(model) {
+            Ok(inner) => Ok(Self { inner }),
+            Err(model_error) => {
+                let inner = tiktoken_rs::o200k_base()
+                    .with_context(|| {
+                        format!("fallback after model lookup failure for {model}: {model_error}")
+                    })
+                    .map_err(|source| TokenizerError::LoadEncoding {
+                        kind: EncodingKind::O200kBase,
+                        source,
+                    })?;
+                Ok(Self { inner })
+            }
+        }
+    }
+
+    /// Encode text to token IDs. If `with_special_tokens` is true, special
+    /// tokens are allowed and may appear in the result.
+    #[must_use]
+    pub fn encode(&self, text: &str, with_special_tokens: bool) -> Vec<i32> {
+        let raw = if with_special_tokens {
+            self.inner.encode_with_special_tokens(text)
+        } else {
+            self.inner.encode_ordinary(text)
+        };
+        raw.into_iter().map(|t| t as i32).collect()
+    }
+
+    /// Count tokens in `text` as a signed integer.
+    #[must_use]
+    pub fn count(&self, text: &str) -> i64 {
+        // Signed length to satisfy our style preference.
+        i64::try_from(self.inner.encode_ordinary(text).len()).unwrap_or(i64::MAX)
+    }
+
+    /// Decode token IDs back to text.
+    pub fn decode(&self, tokens: &[i32]) -> Result<String, TokenizerError> {
+        let raw: Vec<u32> = tokens.iter().map(|t| *t as u32).collect();
+        self.inner
+            .decode(raw)
+            .map_err(|source| TokenizerError::Decode { source })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use pretty_assertions::assert_eq;
+
+    #[test]
+    fn cl100k_base_roundtrip_simple() -> Result<(), TokenizerError> {
+        let tok = Tokenizer::new(EncodingKind::Cl100kBase)?;
+        let s = "hello world";
+        let ids = tok.encode(s, false);
+        // Stable expectation for cl100k_base
+        assert_eq!(ids, vec![15339, 1917]);
+        let back = tok.decode(&ids)?;
+        assert_eq!(back, s);
+        Ok(())
+    }
+
+    #[test]
+    fn preserves_whitespace_and_special_tokens_flag() -> Result<(), TokenizerError> {
+        let tok = Tokenizer::new(EncodingKind::Cl100kBase)?;
+        let s = "This has multiple spaces";
+        let ids_no_special = tok.encode(s, false);
+        let round = tok.decode(&ids_no_special)?;
+        assert_eq!(round, s);
+
+        // With special tokens allowed, result may be identical for normal text,
+        // but the API should still function.
+        let ids_with_special = tok.encode(s, true);
+        let round2 = tok.decode(&ids_with_special)?;
+        assert_eq!(round2, s);
+        Ok(())
+    }
+
+    #[test]
+    fn model_mapping_builds_tokenizer() -> Result<(), TokenizerError> {
+        // Choose a long-standing model alias that maps to cl100k_base.
+        let tok = Tokenizer::for_model("gpt-5")?;
+        let ids = tok.encode("ok", false);
+        let back = tok.decode(&ids)?;
+        assert_eq!(back, "ok");
+        Ok(())
+    }
+
+    #[test]
+    fn unknown_model_defaults_to_o200k_base() -> Result<(), TokenizerError> {
+        let fallback = Tokenizer::new(EncodingKind::O200kBase)?;
+        let tok = Tokenizer::for_model("does-not-exist")?;
+        let text = "fallback please";
+        assert_eq!(tok.encode(text, false), fallback.encode(text, false));
+        Ok(())
+    }
+}

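Taken together, the new crate exposes a small surface: construct a tokenizer by encoding kind, or by model name with an `o200k_base` fallback, then encode, count, or decode. A minimal usage sketch against the API above (the `main` harness, model name, and sample strings are illustrative, not part of this commit):

use codex_utils_tokenizer::EncodingKind;
use codex_utils_tokenizer::Tokenizer;
use codex_utils_tokenizer::TokenizerError;

fn main() -> Result<(), TokenizerError> {
    // Explicit encoding selection.
    let tok = Tokenizer::new(EncodingKind::O200kBase)?;
    let ids = tok.encode("hello world", false); // ordinary encoding, no special tokens
    assert_eq!(tok.decode(&ids)?, "hello world"); // lossless round trip

    // Model-name lookup; unknown model names fall back to o200k_base.
    let by_model = Tokenizer::for_model("gpt-4o")?;
    let n = by_model.count("some prompt text"); // token count as i64
    println!("{n} tokens, ids = {ids:?}");
    Ok(())
}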