feat: introduce mcp-types crate (#787)

This adds our own `mcp-types` crate to our Cargo workspace. We vendor in
the
[`2025-03-26/schema.json`](05f2045136/schema/2025-03-26/schema.json)
from the MCP repo and introduce a `generate_mcp_types.py` script to
codegen the `lib.rs` from the JSON schema.

Test coverage is currently light, but I plan to refine things as we
start making use of this crate.

And yes, I am aware that
https://github.com/modelcontextprotocol/rust-sdk exists, though the
published https://crates.io/crates/rmcp appears to be a competing
effort. While things are up in the air, it seems better for us to
control our own version of this code.

Incidentally, Codex did a lot of the work for this PR. I told it to
never edit `lib.rs` directly and instead to update
`generate_mcp_types.py` and then re-run it to update `lib.rs`. It
followed these instructions and once things were working end-to-end, I
iteratively asked for changes to the tests until the API looked
reasonable (and the code worked). Codex was responsible for figuring out
what to do to `generate_mcp_types.py` to achieve the requested test/API
changes.
This commit is contained in:
Michael Bolin
2025-05-02 13:33:14 -07:00
committed by GitHub
parent f6b1ce2e3a
commit 83961e0299
9 changed files with 4055 additions and 0 deletions

8
codex-rs/Cargo.lock generated
View File

@@ -1940,6 +1940,14 @@ dependencies = [
"regex-automata 0.1.10",
]
[[package]]
name = "mcp-types"
version = "0.1.0"
dependencies = [
"serde",
"serde_json",
]
[[package]]
name = "memchr"
version = "2.7.4"

View File

@@ -7,6 +7,7 @@ members = [
"core",
"exec",
"execpolicy",
"mcp-types",
"tui",
]

View File

@@ -0,0 +1,8 @@
[package]
name = "mcp-types"
version = "0.1.0"
edition = "2021"
[dependencies]
serde = { version = "1", features = ["derive"] }
serde_json = "1"

View File

@@ -0,0 +1,8 @@
# mcp-types
Types for Model Context Protocol. Inspired by https://crates.io/crates/lsp-types.
As documented on https://modelcontextprotocol.io/specification/2025-03-26/basic:
- TypeScript schema is the source of truth: https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.ts
- JSON schema is amenable to automated tooling: https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.json

View File

@@ -0,0 +1,621 @@
#!/usr/bin/env python3
# flake8: noqa: E501
import json
import subprocess
import sys
from dataclasses import (
dataclass,
)
from pathlib import Path
# Helper first so it is defined when other functions call it.
from typing import Any, Literal
STANDARD_DERIVE = "#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]\n"
# Will be populated with the schema's `definitions` map in `main()` so that
# helper functions (for example `define_any_of`) can perform look-ups while
# generating code.
DEFINITIONS: dict[str, Any] = {}
# Names of the concrete *Request types that make up the ClientRequest enum.
CLIENT_REQUEST_TYPE_NAMES: list[str] = []
# Concrete *Notification types that make up the ServerNotification enum.
SERVER_NOTIFICATION_TYPE_NAMES: list[str] = []
def main() -> int:
num_args = len(sys.argv)
if num_args == 1:
schema_file = (
Path(__file__).resolve().parent / "schema" / "2025-03-26" / "schema.json"
)
elif num_args == 2:
schema_file = Path(sys.argv[1])
else:
print("Usage: python3 codegen.py <schema.json>")
return 1
lib_rs = Path(__file__).resolve().parent / "src/lib.rs"
global DEFINITIONS # Allow helper functions to access the schema.
with schema_file.open(encoding="utf-8") as f:
schema_json = json.load(f)
DEFINITIONS = schema_json["definitions"]
out = [
"""
// @generated
// DO NOT EDIT THIS FILE DIRECTLY.
// Run the following in the crate root to regenerate this file:
//
// ```shell
// ./generate_mcp_types.py
// ```
use serde::Deserialize;
use serde::Serialize;
use serde::de::DeserializeOwned;
use std::convert::TryFrom;
/// Paired request/response types for the Model Context Protocol (MCP).
pub trait ModelContextProtocolRequest {
const METHOD: &'static str;
type Params: DeserializeOwned + Serialize + Send + Sync + 'static;
type Result: DeserializeOwned + Serialize + Send + Sync + 'static;
}
/// One-way message in the Model Context Protocol (MCP).
pub trait ModelContextProtocolNotification {
const METHOD: &'static str;
type Params: DeserializeOwned + Serialize + Send + Sync + 'static;
}
"""
]
definitions = schema_json["definitions"]
# Keep track of every *Request type so we can generate the TryFrom impl at
# the end.
# The concrete *Request types referenced by the ClientRequest enum will be
# captured dynamically while we are processing that definition.
for name, definition in definitions.items():
add_definition(name, definition, out)
# No-op: list collected via define_any_of("ClientRequest").
# Generate TryFrom impl string and append to out before writing to file.
try_from_impl_lines: list[str] = []
try_from_impl_lines.append("impl TryFrom<JSONRPCRequest> for ClientRequest {\n")
try_from_impl_lines.append(" type Error = serde_json::Error;\n")
try_from_impl_lines.append(
" fn try_from(req: JSONRPCRequest) -> std::result::Result<Self, Self::Error> {\n"
)
try_from_impl_lines.append(" match req.method.as_str() {\n")
for req_name in CLIENT_REQUEST_TYPE_NAMES:
defn = definitions[req_name]
method_const = (
defn.get("properties", {}).get("method", {}).get("const", req_name)
)
payload_type = f"<{req_name} as ModelContextProtocolRequest>::Params"
try_from_impl_lines.append(f' "{method_const}" => {{\n')
try_from_impl_lines.append(
" let params_json = req.params.unwrap_or(serde_json::Value::Null);\n"
)
try_from_impl_lines.append(
f" let params: {payload_type} = serde_json::from_value(params_json)?;\n"
)
try_from_impl_lines.append(
f" Ok(ClientRequest::{req_name}(params))\n"
)
try_from_impl_lines.append(" },\n")
try_from_impl_lines.append(
' _ => Err(serde_json::Error::io(std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Unknown method: {}", req.method)))),\n'
)
try_from_impl_lines.append(" }\n")
try_from_impl_lines.append(" }\n")
try_from_impl_lines.append("}\n\n")
out.extend(try_from_impl_lines)
# Generate TryFrom for ServerNotification
notif_impl_lines: list[str] = []
notif_impl_lines.append(
"impl TryFrom<JSONRPCNotification> for ServerNotification {\n"
)
notif_impl_lines.append(" type Error = serde_json::Error;\n")
notif_impl_lines.append(
" fn try_from(n: JSONRPCNotification) -> std::result::Result<Self, Self::Error> {\n"
)
notif_impl_lines.append(" match n.method.as_str() {\n")
for notif_name in SERVER_NOTIFICATION_TYPE_NAMES:
n_def = definitions[notif_name]
method_const = (
n_def.get("properties", {}).get("method", {}).get("const", notif_name)
)
payload_type = f"<{notif_name} as ModelContextProtocolNotification>::Params"
notif_impl_lines.append(f' "{method_const}" => {{\n')
# params may be optional
notif_impl_lines.append(
" let params_json = n.params.unwrap_or(serde_json::Value::Null);\n"
)
notif_impl_lines.append(
f" let params: {payload_type} = serde_json::from_value(params_json)?;\n"
)
notif_impl_lines.append(
f" Ok(ServerNotification::{notif_name}(params))\n"
)
notif_impl_lines.append(" },\n")
notif_impl_lines.append(
' _ => Err(serde_json::Error::io(std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Unknown method: {}", n.method)))),\n'
)
notif_impl_lines.append(" }\n")
notif_impl_lines.append(" }\n")
notif_impl_lines.append("}\n")
out.extend(notif_impl_lines)
with open(lib_rs, "w", encoding="utf-8") as f:
for chunk in out:
f.write(chunk)
subprocess.check_call(
["cargo", "fmt", "--", "--config", "imports_granularity=Item"],
cwd=lib_rs.parent.parent,
stderr=subprocess.DEVNULL,
)
return 0
def add_definition(name: str, definition: dict[str, Any], out: list[str]) -> None:
# Capture description
description = definition.get("description")
properties = definition.get("properties", {})
if properties:
required_props = set(definition.get("required", []))
out.extend(define_struct(name, properties, required_props, description))
return
enum_values = definition.get("enum", [])
if enum_values:
assert definition.get("type") == "string"
define_string_enum(name, enum_values, out, description)
return
any_of = definition.get("anyOf", [])
if any_of:
assert isinstance(any_of, list)
if name == "JSONRPCMessage":
# Special case for JSONRPCMessage because its definition in the
# JSON schema does not quite match how we think about this type
# definition in Rust.
deep_copied_any_of = json.loads(json.dumps(any_of))
deep_copied_any_of[2] = {
"$ref": "#/definitions/JSONRPCBatchRequest",
}
deep_copied_any_of[5] = {
"$ref": "#/definitions/JSONRPCBatchResponse",
}
out.extend(define_any_of(name, deep_copied_any_of, description))
else:
out.extend(define_any_of(name, any_of, description))
return
type_prop = definition.get("type", None)
if type_prop:
if type_prop == "string":
# Newtype pattern
out.append(STANDARD_DERIVE)
out.append(f"pub struct {name}(String);\n\n")
return
elif types := check_string_list(type_prop):
define_untagged_enum(name, types, out)
return
elif type_prop == "array":
item_name = name + "Item"
out.extend(define_any_of(item_name, definition["items"]["anyOf"]))
out.append(f"pub type {name} = Vec<{item_name}>;\n\n")
return
raise ValueError(f"Unknown type: {type_prop} in {name}")
ref_prop = definition.get("$ref", None)
if ref_prop:
ref = type_from_ref(ref_prop)
out.extend(f"pub type {name} = {ref};\n\n")
return
raise ValueError(f"Definition for {name} could not be processed.")
extra_defs = []
@dataclass
class StructField:
viz: Literal["pub"] | Literal["const"]
name: str
type_name: str
serde: str | None = None
def append(self, out: list[str], supports_const: bool) -> None:
# Omit these for now.
if self.name == "jsonrpc":
return
if self.serde:
out.append(f" {self.serde}\n")
if self.viz == "const":
if supports_const:
out.append(f" const {self.name}: {self.type_name};\n")
else:
out.append(f" pub {self.name}: String, // {self.type_name}\n")
else:
out.append(f" pub {self.name}: {self.type_name},\n")
def define_struct(
name: str,
properties: dict[str, Any],
required_props: set[str],
description: str | None,
) -> list[str]:
out: list[str] = []
fields: list[StructField] = []
for prop_name, prop in properties.items():
if prop_name == "_meta":
# TODO?
continue
prop_type = map_type(prop, prop_name, name)
if prop_name not in required_props:
prop_type = f"Option<{prop_type}>"
rs_prop = rust_prop_name(prop_name)
if prop_type.startswith("&'static str"):
fields.append(StructField("const", rs_prop.name, prop_type, rs_prop.serde))
else:
fields.append(StructField("pub", rs_prop.name, prop_type, rs_prop.serde))
if implements_request_trait(name):
add_trait_impl(name, "ModelContextProtocolRequest", fields, out)
elif implements_notification_trait(name):
add_trait_impl(name, "ModelContextProtocolNotification", fields, out)
else:
# Add doc comment if available.
emit_doc_comment(description, out)
out.append(STANDARD_DERIVE)
out.append(f"pub struct {name} {{\n")
for field in fields:
field.append(out, supports_const=False)
out.append("}\n\n")
# Declare any extra structs after the main struct.
if extra_defs:
out.extend(extra_defs)
# Clear the extra structs for the next definition.
extra_defs.clear()
return out
def infer_result_type(request_type_name: str) -> str:
"""Return the corresponding Result type name for a given *Request name."""
if not request_type_name.endswith("Request"):
return "Result" # fallback
candidate = request_type_name[:-7] + "Result"
if candidate in DEFINITIONS:
return candidate
# Fallback to generic Result if specific one missing.
return "Result"
def implements_request_trait(name: str) -> bool:
return name.endswith("Request") and name not in (
"Request",
"JSONRPCRequest",
"PaginatedRequest",
)
def implements_notification_trait(name: str) -> bool:
return name.endswith("Notification") and name not in (
"Notification",
"JSONRPCNotification",
)
def add_trait_impl(
type_name: str, trait_name: str, fields: list[StructField], out: list[str]
) -> None:
# out.append("#[derive(Debug)]\n")
out.append(STANDARD_DERIVE)
out.append(f"pub enum {type_name} {{}}\n\n")
out.append(f"impl {trait_name} for {type_name} {{\n")
for field in fields:
if field.name == "method":
field.name = "METHOD"
field.append(out, supports_const=True)
elif field.name == "params":
out.append(f" type Params = {field.type_name};\n")
else:
print(f"Warning: {type_name} has unexpected field {field.name}.")
if trait_name == "ModelContextProtocolRequest":
result_type = infer_result_type(type_name)
out.append(f" type Result = {result_type};\n")
out.append("}\n\n")
def define_string_enum(
name: str, enum_values: Any, out: list[str], description: str | None
) -> None:
emit_doc_comment(description, out)
out.append(STANDARD_DERIVE)
out.append(f"pub enum {name} {{\n")
for value in enum_values:
assert isinstance(value, str)
out.append(f' #[serde(rename = "{value}")]\n')
out.append(f" {capitalize(value)},\n")
out.append("}\n\n")
return out
def define_untagged_enum(name: str, type_list: list[str], out: list[str]) -> None:
out.append(STANDARD_DERIVE)
out.append("#[serde(untagged)]\n")
out.append(f"pub enum {name} {{\n")
for simple_type in type_list:
match simple_type:
case "string":
out.append(" String(String),\n")
case "integer":
out.append(" Integer(i64),\n")
case _:
raise ValueError(
f"Unknown type in untagged enum: {simple_type} in {name}"
)
out.append("}\n\n")
def define_any_of(
name: str, list_of_refs: list[Any], description: str | None = None
) -> list[str]:
"""Generate a Rust enum for a JSON-Schema `anyOf` union.
For most types we simply map each `$ref` inside the `anyOf` list to a
similarly named enum variant that holds the referenced type as its
payload. For certain well-known composite types (currently only
`ClientRequest`) we need a little bit of extra intelligence:
* The JSON shape of a request is `{ "method": <string>, "params": <object?> }`.
* We want to deserialize directly into `ClientRequest` using Serde's
`#[serde(tag = "method", content = "params")]` representation so that
the enum payload is **only** the request's `params` object.
* Therefore each enum variant needs to carry the dedicated `…Params` type
(wrapped in `Option<…>` if the `params` field is not required), not the
full `…Request` struct from the schema definition.
"""
# Verify each item in list_of_refs is a dict with a $ref key.
refs = [item["$ref"] for item in list_of_refs if isinstance(item, dict)]
out: list[str] = []
if description:
emit_doc_comment(description, out)
out.append(STANDARD_DERIVE)
if serde := get_serde_annotation_for_anyof_type(name):
out.append(serde + "\n")
out.append(f"pub enum {name} {{\n")
if name == "ClientRequest":
# Record the set of request type names so we can later generate a
# `TryFrom<JSONRPCRequest>` implementation.
global CLIENT_REQUEST_TYPE_NAMES
CLIENT_REQUEST_TYPE_NAMES = [type_from_ref(r) for r in refs]
if name == "ServerNotification":
global SERVER_NOTIFICATION_TYPE_NAMES
SERVER_NOTIFICATION_TYPE_NAMES = [type_from_ref(r) for r in refs]
for ref in refs:
ref_name = type_from_ref(ref)
# For JSONRPCMessage variants, drop the common "JSONRPC" prefix to
# make the enum easier to read (e.g. `Request` instead of
# `JSONRPCRequest`). The payload type remains unchanged.
variant_name = (
ref_name[len("JSONRPC") :]
if name == "JSONRPCMessage" and ref_name.startswith("JSONRPC")
else ref_name
)
# Special-case for `ClientRequest` and `ServerNotification` so the enum
# variant's payload is the *Params type rather than the full *Request /
# *Notification marker type.
if name in ("ClientRequest", "ServerNotification"):
# Rely on the trait implementation to tell us the exact Rust type
# of the `params` payload. This guarantees we stay in sync with any
# special-case logic used elsewhere (e.g. objects with
# `additionalProperties` mapping to `serde_json::Value`).
if name == "ClientRequest":
payload_type = f"<{ref_name} as ModelContextProtocolRequest>::Params"
else:
payload_type = (
f"<{ref_name} as ModelContextProtocolNotification>::Params"
)
# Determine the wire value for `method` so we can annotate the
# variant appropriately. If for some reason the schema does not
# specify a constant we fall back to the type name, which will at
# least compile (although deserialization will likely fail).
request_def = DEFINITIONS.get(ref_name, {})
method_const = (
request_def.get("properties", {})
.get("method", {})
.get("const", ref_name)
)
out.append(f' #[serde(rename = "{method_const}")]\n')
out.append(f" {variant_name}({payload_type}),\n")
else:
# The regular/straight-forward case.
out.append(f" {variant_name}({ref_name}),\n")
out.append("}\n\n")
return out
def get_serde_annotation_for_anyof_type(type_name: str) -> str | None:
# TODO: Solve this in a more generic way.
match type_name:
case "ClientRequest":
return '#[serde(tag = "method", content = "params")]'
case "ServerNotification":
return '#[serde(tag = "method", content = "params")]'
case "JSONRPCMessage":
return "#[serde(untagged)]"
case _:
return None
def map_type(
typedef: dict[str, any],
prop_name: str | None = None,
struct_name: str | None = None,
) -> str:
"""typedef must have a `type` key, but may also have an `items`key."""
ref_prop = typedef.get("$ref", None)
if ref_prop:
return type_from_ref(ref_prop)
any_of = typedef.get("anyOf", None)
if any_of:
assert prop_name is not None
assert struct_name is not None
custom_type = struct_name + capitalize(prop_name)
extra_defs.extend(define_any_of(custom_type, any_of))
return custom_type
type_prop = typedef.get("type", None)
if type_prop is None:
# Likely `unknown` in TypeScript, like the JSONRPCError.data property.
return "serde_json::Value"
if type_prop == "string":
if const_prop := typedef.get("const", None):
assert isinstance(const_prop, str)
return f'&\'static str = "{const_prop }"'
else:
return "String"
elif type_prop == "integer":
return "i64"
elif type_prop == "number":
return "f64"
elif type_prop == "boolean":
return "bool"
elif type_prop == "array":
item_type = typedef.get("items", None)
if item_type:
item_type = map_type(item_type, prop_name, struct_name)
assert isinstance(item_type, str)
return f"Vec<{item_type}>"
else:
raise ValueError("Array type without items.")
elif type_prop == "object":
# If the schema says `additionalProperties: {}` this is effectively an
# open-ended map, so deserialize into `serde_json::Value` for maximum
# flexibility.
if typedef.get("additionalProperties") is not None:
return "serde_json::Value"
# If there are *no* properties declared treat it similarly.
if not typedef.get("properties"):
return "serde_json::Value"
# Otherwise, synthesize a nested struct for the inline object.
assert prop_name is not None
assert struct_name is not None
custom_type = struct_name + capitalize(prop_name)
extra_defs.extend(
define_struct(
custom_type,
typedef["properties"],
set(typedef.get("required", [])),
typedef.get("description"),
)
)
return custom_type
else:
raise ValueError(f"Unknown type: {type_prop} in {typedef}")
@dataclass
class RustProp:
name: str
# serde annotation, if necessary
serde: str | None = None
def rust_prop_name(name: str) -> RustProp:
"""Convert a JSON property name to a Rust property name."""
if name == "type":
return RustProp("r#type", None)
elif name == "ref":
return RustProp("r#ref", None)
elif snake_case := to_snake_case(name):
return RustProp(snake_case, f'#[serde(rename = "{name}")]')
else:
return RustProp(name, None)
def to_snake_case(name: str) -> str:
"""Convert a camelCase or PascalCase name to snake_case."""
snake_case = name[0].lower() + "".join(
"_" + c.lower() if c.isupper() else c for c in name[1:]
)
if snake_case != name:
return snake_case
else:
return None
def capitalize(name: str) -> str:
"""Capitalize the first letter of a name."""
return name[0].upper() + name[1:]
def check_string_list(value: Any) -> list[str] | None:
"""If the value is a list of strings, return it. Otherwise, return None."""
if not isinstance(value, list):
return None
for item in value:
if not isinstance(item, str):
return None
return value
def type_from_ref(ref: str) -> str:
"""Convert a JSON reference to a Rust type."""
assert ref.startswith("#/definitions/")
return ref.split("/")[-1]
def emit_doc_comment(text: str | None, out: list[str]) -> None:
"""Append Rust doc comments derived from the JSON-schema description."""
if not text:
return
for line in text.strip().split("\n"):
out.append(f"/// {line.rstrip()}\n")
if __name__ == "__main__":
sys.exit(main())

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,65 @@
use mcp_types::ClientCapabilities;
use mcp_types::ClientRequest;
use mcp_types::Implementation;
use mcp_types::InitializeRequestParams;
use mcp_types::JSONRPCMessage;
use mcp_types::JSONRPCRequest;
use mcp_types::RequestId;
use serde_json::json;
#[test]
fn deserialize_initialize_request() {
let raw = r#"{
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"capabilities": {},
"clientInfo": { "name": "acme-client", "version": "1.2.3" },
"protocolVersion": "2025-03-26"
}
}"#;
// Deserialize full JSONRPCMessage first.
let msg: JSONRPCMessage =
serde_json::from_str(raw).expect("failed to deserialize JSONRPCMessage");
// Extract the request variant.
let JSONRPCMessage::Request(json_req) = msg else {
unreachable!()
};
let expected_req = JSONRPCRequest {
id: RequestId::Integer(1),
method: "initialize".into(),
params: Some(json!({
"capabilities": {},
"clientInfo": { "name": "acme-client", "version": "1.2.3" },
"protocolVersion": "2025-03-26"
})),
};
assert_eq!(json_req, expected_req);
let client_req: ClientRequest =
ClientRequest::try_from(json_req).expect("conversion must succeed");
let ClientRequest::InitializeRequest(init_params) = client_req else {
unreachable!()
};
assert_eq!(
init_params,
InitializeRequestParams {
capabilities: ClientCapabilities {
experimental: None,
roots: None,
sampling: None,
},
client_info: Implementation {
name: "acme-client".into(),
version: "1.2.3".into(),
},
protocol_version: "2025-03-26".into(),
}
);
}

View File

@@ -0,0 +1,43 @@
use mcp_types::JSONRPCMessage;
use mcp_types::ProgressNotificationParams;
use mcp_types::ProgressToken;
use mcp_types::ServerNotification;
#[test]
fn deserialize_progress_notification() {
let raw = r#"{
"jsonrpc": "2.0",
"method": "notifications/progress",
"params": {
"message": "Half way there",
"progress": 0.5,
"progressToken": 99,
"total": 1.0
}
}"#;
// Deserialize full JSONRPCMessage first.
let msg: JSONRPCMessage = serde_json::from_str(raw).expect("invalid JSONRPCMessage");
// Extract the notification variant.
let JSONRPCMessage::Notification(notif) = msg else {
unreachable!()
};
// Convert via generated TryFrom.
let server_notif: ServerNotification =
ServerNotification::try_from(notif).expect("conversion must succeed");
let ServerNotification::ProgressNotification(params) = server_notif else {
unreachable!()
};
let expected_params = ProgressNotificationParams {
message: Some("Half way there".into()),
progress: 0.5,
progress_token: ProgressToken::Integer(99),
total: Some(1.0),
};
assert_eq!(params, expected_params);
}