From 8f837f109381339e5fc100ec7f82042135843092 Mon Sep 17 00:00:00 2001 From: Michael Bolin Date: Wed, 10 Sep 2025 23:31:28 -0700 Subject: [PATCH] fix: add check to ensure output of generate_mcp_types.py matches codex-rs/mcp-types/src/lib.rs (#3450) As a follow-up to https://github.com/openai/codex/pull/3439, this adds a CI job to ensure the codegen script has to be updated in order to change `codex-rs/mcp-types/src/lib.rs`. --- .github/workflows/rust-ci.yml | 2 + codex-rs/mcp-types/check_lib_rs.py | 21 ++++ codex-rs/mcp-types/generate_mcp_types.py | 140 ++++++++++++++++------- 3 files changed, 123 insertions(+), 40 deletions(-) create mode 100755 codex-rs/mcp-types/check_lib_rs.py diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml index bae33e30..280939c6 100644 --- a/.github/workflows/rust-ci.yml +++ b/.github/workflows/rust-ci.yml @@ -62,6 +62,8 @@ jobs: components: rustfmt - name: cargo fmt run: cargo fmt -- --config imports_granularity=Item --check + - name: Verify codegen for mcp-types + run: ./mcp-types/check_lib_rs.py cargo_shear: name: cargo shear diff --git a/codex-rs/mcp-types/check_lib_rs.py b/codex-rs/mcp-types/check_lib_rs.py new file mode 100755 index 00000000..37b623a2 --- /dev/null +++ b/codex-rs/mcp-types/check_lib_rs.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 + +import subprocess +import sys +from pathlib import Path + + +def main() -> int: + crate_dir = Path(__file__).resolve().parent + generator = crate_dir / "generate_mcp_types.py" + + result = subprocess.run( + [sys.executable, str(generator), "--check"], + cwd=crate_dir, + check=False, + ) + return result.returncode + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/codex-rs/mcp-types/generate_mcp_types.py b/codex-rs/mcp-types/generate_mcp_types.py index 03c82baf..60c261e8 100755 --- a/codex-rs/mcp-types/generate_mcp_types.py +++ b/codex-rs/mcp-types/generate_mcp_types.py @@ -5,15 +5,19 @@ import argparse import json import subprocess import sys +import tempfile from dataclasses import ( dataclass, ) +from difflib import unified_diff from pathlib import Path +from shutil import copy2 # Helper first so it is defined when other functions call it. from typing import Any, Literal + SCHEMA_VERSION = "2025-06-18" JSONRPC_VERSION = "2.0" @@ -43,16 +47,31 @@ def main() -> int: default_schema_file = ( Path(__file__).resolve().parent / "schema" / SCHEMA_VERSION / "schema.json" ) + default_lib_rs = Path(__file__).resolve().parent / "src/lib.rs" parser.add_argument( "schema_file", nargs="?", default=default_schema_file, help="schema.json file to process", ) + parser.add_argument( + "--check", + action="store_true", + help="Regenerate lib.rs in a sandbox and ensure the checked-in file matches", + ) args = parser.parse_args() - schema_file = args.schema_file + schema_file = Path(args.schema_file) + crate_dir = Path(__file__).resolve().parent - lib_rs = Path(__file__).resolve().parent / "src/lib.rs" + if args.check: + return run_check(schema_file, crate_dir, default_lib_rs) + + generate_lib_rs(schema_file, default_lib_rs, fmt=True) + return 0 + + +def generate_lib_rs(schema_file: Path, lib_rs: Path, fmt: bool) -> None: + lib_rs.parent.mkdir(parents=True, exist_ok=True) global DEFINITIONS # Allow helper functions to access the schema. @@ -117,9 +136,7 @@ fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }} for req_name in CLIENT_REQUEST_TYPE_NAMES: defn = definitions[req_name] - method_const = ( - defn.get("properties", {}).get("method", {}).get("const", req_name) - ) + method_const = defn.get("properties", {}).get("method", {}).get("const", req_name) payload_type = f"<{req_name} as ModelContextProtocolRequest>::Params" try_from_impl_lines.append(f' "{method_const}" => {{\n') try_from_impl_lines.append( @@ -128,9 +145,7 @@ fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }} try_from_impl_lines.append( f" let params: {payload_type} = serde_json::from_value(params_json)?;\n" ) - try_from_impl_lines.append( - f" Ok(ClientRequest::{req_name}(params))\n" - ) + try_from_impl_lines.append(f" Ok(ClientRequest::{req_name}(params))\n") try_from_impl_lines.append(" },\n") try_from_impl_lines.append( @@ -144,9 +159,7 @@ fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }} # Generate TryFrom for ServerNotification notif_impl_lines: list[str] = [] - notif_impl_lines.append( - "impl TryFrom for ServerNotification {\n" - ) + notif_impl_lines.append("impl TryFrom for ServerNotification {\n") notif_impl_lines.append(" type Error = serde_json::Error;\n") notif_impl_lines.append( " fn try_from(n: JSONRPCNotification) -> std::result::Result {\n" @@ -155,9 +168,7 @@ fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }} for notif_name in SERVER_NOTIFICATION_TYPE_NAMES: n_def = definitions[notif_name] - method_const = ( - n_def.get("properties", {}).get("method", {}).get("const", notif_name) - ) + method_const = n_def.get("properties", {}).get("method", {}).get("const", notif_name) payload_type = f"<{notif_name} as ModelContextProtocolNotification>::Params" notif_impl_lines.append(f' "{method_const}" => {{\n') # params may be optional @@ -167,9 +178,7 @@ fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }} notif_impl_lines.append( f" let params: {payload_type} = serde_json::from_value(params_json)?;\n" ) - notif_impl_lines.append( - f" Ok(ServerNotification::{notif_name}(params))\n" - ) + notif_impl_lines.append(f" Ok(ServerNotification::{notif_name}(params))\n") notif_impl_lines.append(" },\n") notif_impl_lines.append( @@ -185,13 +194,70 @@ fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }} for chunk in out: f.write(chunk) - subprocess.check_call( - ["cargo", "fmt", "--", "--config", "imports_granularity=Item"], - cwd=lib_rs.parent.parent, - stderr=subprocess.DEVNULL, - ) + if fmt: + subprocess.check_call( + ["cargo", "fmt", "--", "--config", "imports_granularity=Item"], + cwd=lib_rs.parent.parent, + stderr=subprocess.DEVNULL, + ) - return 0 + +def run_check(schema_file: Path, crate_dir: Path, checked_in_lib: Path) -> int: + config_path = crate_dir.parent / "rustfmt.toml" + eprint(f"Running --check with schema {schema_file}") + + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + eprint(f"Created temporary workspace at {tmp_path}") + manifest_path = tmp_path / "Cargo.toml" + eprint(f"Copying Cargo.toml into {manifest_path}") + copy2(crate_dir / "Cargo.toml", manifest_path) + manifest_text = manifest_path.read_text(encoding="utf-8") + manifest_text = manifest_text.replace( + "version = { workspace = true }", + 'version = "0.0.0"', + ) + manifest_text = manifest_text.replace("\n[lints]\nworkspace = true\n", "\n") + manifest_path.write_text(manifest_text, encoding="utf-8") + src_dir = tmp_path / "src" + src_dir.mkdir(parents=True, exist_ok=True) + eprint(f"Generating lib.rs into {src_dir}") + generated_lib = src_dir / "lib.rs" + + generate_lib_rs(schema_file, generated_lib, fmt=False) + + eprint("Formatting generated lib.rs with rustfmt") + subprocess.check_call( + [ + "rustfmt", + "--config-path", + str(config_path), + str(generated_lib), + ], + cwd=tmp_path, + stderr=subprocess.DEVNULL, + ) + + eprint("Comparing generated lib.rs with checked-in version") + checked_in_contents = checked_in_lib.read_text(encoding="utf-8") + generated_contents = generated_lib.read_text(encoding="utf-8") + + if checked_in_contents == generated_contents: + eprint("lib.rs matches checked-in version") + return 0 + + diff = unified_diff( + checked_in_contents.splitlines(keepends=True), + generated_contents.splitlines(keepends=True), + fromfile=str(checked_in_lib), + tofile=str(generated_lib), + ) + diff_text = "".join(diff) + eprint("Generated lib.rs does not match the checked-in version. Diff:") + if diff_text: + eprint(diff_text, end="") + eprint("Re-run generate_mcp_types.py without --check to update src/lib.rs.") + return 1 def add_definition(name: str, definition: dict[str, Any], out: list[str]) -> None: @@ -421,15 +487,11 @@ def define_untagged_enum(name: str, type_list: list[str], out: list[str]) -> Non case "integer": out.append(" Integer(i64),\n") case _: - raise ValueError( - f"Unknown type in untagged enum: {simple_type} in {name}" - ) + raise ValueError(f"Unknown type in untagged enum: {simple_type} in {name}") out.append("}\n\n") -def define_any_of( - name: str, list_of_refs: list[Any], description: str | None = None -) -> list[str]: +def define_any_of(name: str, list_of_refs: list[Any], description: str | None = None) -> list[str]: """Generate a Rust enum for a JSON-Schema `anyOf` union. For most types we simply map each `$ref` inside the `anyOf` list to a @@ -494,9 +556,7 @@ def define_any_of( if name == "ClientRequest": payload_type = f"<{ref_name} as ModelContextProtocolRequest>::Params" else: - payload_type = ( - f"<{ref_name} as ModelContextProtocolNotification>::Params" - ) + payload_type = f"<{ref_name} as ModelContextProtocolNotification>::Params" # Determine the wire value for `method` so we can annotate the # variant appropriately. If for some reason the schema does not @@ -504,9 +564,7 @@ def define_any_of( # least compile (although deserialization will likely fail). request_def = DEFINITIONS.get(ref_name, {}) method_const = ( - request_def.get("properties", {}) - .get("method", {}) - .get("const", ref_name) + request_def.get("properties", {}).get("method", {}).get("const", ref_name) ) out.append(f' #[serde(rename = "{method_const}")]\n') @@ -556,7 +614,7 @@ def map_type( if type_prop == "string": if const_prop := typedef.get("const", None): assert isinstance(const_prop, str) - return f'&\'static str = "{const_prop }"' + return f'&\'static str = "{const_prop}"' else: return "String" elif type_prop == "integer": @@ -632,7 +690,7 @@ def rust_prop_name(name: str, is_optional: bool) -> RustProp: serde_annotations.append('skip_serializing_if = "Option::is_none"') if serde_annotations: - serde_str = f'#[serde({", ".join(serde_annotations)})]' + serde_str = f"#[serde({', '.join(serde_annotations)})]" else: serde_str = None return RustProp(prop_name, serde_str) @@ -640,9 +698,7 @@ def rust_prop_name(name: str, is_optional: bool) -> RustProp: def to_snake_case(name: str) -> str: """Convert a camelCase or PascalCase name to snake_case.""" - snake_case = name[0].lower() + "".join( - "_" + c.lower() if c.isupper() else c for c in name[1:] - ) + snake_case = name[0].lower() + "".join("_" + c.lower() if c.isupper() else c for c in name[1:]) if snake_case != name: return snake_case else: @@ -678,5 +734,9 @@ def emit_doc_comment(text: str | None, out: list[str]) -> None: out.append(f"/// {line.rstrip()}\n") +def eprint(*args: Any, **kwargs: Any) -> None: + print(*args, file=sys.stderr, **kwargs) + + if __name__ == "__main__": sys.exit(main())