fix: add check to ensure output of generate_mcp_types.py matches codex-rs/mcp-types/src/lib.rs (#3450)
As a follow-up to https://github.com/openai/codex/pull/3439, this adds a CI job to ensure the codegen script has to be updated in order to change `codex-rs/mcp-types/src/lib.rs`.
This commit is contained in:
2
.github/workflows/rust-ci.yml
vendored
2
.github/workflows/rust-ci.yml
vendored
@@ -62,6 +62,8 @@ jobs:
|
|||||||
components: rustfmt
|
components: rustfmt
|
||||||
- name: cargo fmt
|
- name: cargo fmt
|
||||||
run: cargo fmt -- --config imports_granularity=Item --check
|
run: cargo fmt -- --config imports_granularity=Item --check
|
||||||
|
- name: Verify codegen for mcp-types
|
||||||
|
run: ./mcp-types/check_lib_rs.py
|
||||||
|
|
||||||
cargo_shear:
|
cargo_shear:
|
||||||
name: cargo shear
|
name: cargo shear
|
||||||
|
|||||||
21
codex-rs/mcp-types/check_lib_rs.py
Executable file
21
codex-rs/mcp-types/check_lib_rs.py
Executable file
@@ -0,0 +1,21 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
crate_dir = Path(__file__).resolve().parent
|
||||||
|
generator = crate_dir / "generate_mcp_types.py"
|
||||||
|
|
||||||
|
result = subprocess.run(
|
||||||
|
[sys.executable, str(generator), "--check"],
|
||||||
|
cwd=crate_dir,
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
return result.returncode
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
@@ -5,15 +5,19 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
import tempfile
|
||||||
|
|
||||||
from dataclasses import (
|
from dataclasses import (
|
||||||
dataclass,
|
dataclass,
|
||||||
)
|
)
|
||||||
|
from difflib import unified_diff
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from shutil import copy2
|
||||||
|
|
||||||
# Helper first so it is defined when other functions call it.
|
# Helper first so it is defined when other functions call it.
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
|
||||||
SCHEMA_VERSION = "2025-06-18"
|
SCHEMA_VERSION = "2025-06-18"
|
||||||
JSONRPC_VERSION = "2.0"
|
JSONRPC_VERSION = "2.0"
|
||||||
|
|
||||||
@@ -43,16 +47,31 @@ def main() -> int:
|
|||||||
default_schema_file = (
|
default_schema_file = (
|
||||||
Path(__file__).resolve().parent / "schema" / SCHEMA_VERSION / "schema.json"
|
Path(__file__).resolve().parent / "schema" / SCHEMA_VERSION / "schema.json"
|
||||||
)
|
)
|
||||||
|
default_lib_rs = Path(__file__).resolve().parent / "src/lib.rs"
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"schema_file",
|
"schema_file",
|
||||||
nargs="?",
|
nargs="?",
|
||||||
default=default_schema_file,
|
default=default_schema_file,
|
||||||
help="schema.json file to process",
|
help="schema.json file to process",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--check",
|
||||||
|
action="store_true",
|
||||||
|
help="Regenerate lib.rs in a sandbox and ensure the checked-in file matches",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
schema_file = args.schema_file
|
schema_file = Path(args.schema_file)
|
||||||
|
crate_dir = Path(__file__).resolve().parent
|
||||||
|
|
||||||
lib_rs = Path(__file__).resolve().parent / "src/lib.rs"
|
if args.check:
|
||||||
|
return run_check(schema_file, crate_dir, default_lib_rs)
|
||||||
|
|
||||||
|
generate_lib_rs(schema_file, default_lib_rs, fmt=True)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def generate_lib_rs(schema_file: Path, lib_rs: Path, fmt: bool) -> None:
|
||||||
|
lib_rs.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
global DEFINITIONS # Allow helper functions to access the schema.
|
global DEFINITIONS # Allow helper functions to access the schema.
|
||||||
|
|
||||||
@@ -117,9 +136,7 @@ fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }}
|
|||||||
|
|
||||||
for req_name in CLIENT_REQUEST_TYPE_NAMES:
|
for req_name in CLIENT_REQUEST_TYPE_NAMES:
|
||||||
defn = definitions[req_name]
|
defn = definitions[req_name]
|
||||||
method_const = (
|
method_const = defn.get("properties", {}).get("method", {}).get("const", req_name)
|
||||||
defn.get("properties", {}).get("method", {}).get("const", req_name)
|
|
||||||
)
|
|
||||||
payload_type = f"<{req_name} as ModelContextProtocolRequest>::Params"
|
payload_type = f"<{req_name} as ModelContextProtocolRequest>::Params"
|
||||||
try_from_impl_lines.append(f' "{method_const}" => {{\n')
|
try_from_impl_lines.append(f' "{method_const}" => {{\n')
|
||||||
try_from_impl_lines.append(
|
try_from_impl_lines.append(
|
||||||
@@ -128,9 +145,7 @@ fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }}
|
|||||||
try_from_impl_lines.append(
|
try_from_impl_lines.append(
|
||||||
f" let params: {payload_type} = serde_json::from_value(params_json)?;\n"
|
f" let params: {payload_type} = serde_json::from_value(params_json)?;\n"
|
||||||
)
|
)
|
||||||
try_from_impl_lines.append(
|
try_from_impl_lines.append(f" Ok(ClientRequest::{req_name}(params))\n")
|
||||||
f" Ok(ClientRequest::{req_name}(params))\n"
|
|
||||||
)
|
|
||||||
try_from_impl_lines.append(" },\n")
|
try_from_impl_lines.append(" },\n")
|
||||||
|
|
||||||
try_from_impl_lines.append(
|
try_from_impl_lines.append(
|
||||||
@@ -144,9 +159,7 @@ fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }}
|
|||||||
|
|
||||||
# Generate TryFrom for ServerNotification
|
# Generate TryFrom for ServerNotification
|
||||||
notif_impl_lines: list[str] = []
|
notif_impl_lines: list[str] = []
|
||||||
notif_impl_lines.append(
|
notif_impl_lines.append("impl TryFrom<JSONRPCNotification> for ServerNotification {\n")
|
||||||
"impl TryFrom<JSONRPCNotification> for ServerNotification {\n"
|
|
||||||
)
|
|
||||||
notif_impl_lines.append(" type Error = serde_json::Error;\n")
|
notif_impl_lines.append(" type Error = serde_json::Error;\n")
|
||||||
notif_impl_lines.append(
|
notif_impl_lines.append(
|
||||||
" fn try_from(n: JSONRPCNotification) -> std::result::Result<Self, Self::Error> {\n"
|
" fn try_from(n: JSONRPCNotification) -> std::result::Result<Self, Self::Error> {\n"
|
||||||
@@ -155,9 +168,7 @@ fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }}
|
|||||||
|
|
||||||
for notif_name in SERVER_NOTIFICATION_TYPE_NAMES:
|
for notif_name in SERVER_NOTIFICATION_TYPE_NAMES:
|
||||||
n_def = definitions[notif_name]
|
n_def = definitions[notif_name]
|
||||||
method_const = (
|
method_const = n_def.get("properties", {}).get("method", {}).get("const", notif_name)
|
||||||
n_def.get("properties", {}).get("method", {}).get("const", notif_name)
|
|
||||||
)
|
|
||||||
payload_type = f"<{notif_name} as ModelContextProtocolNotification>::Params"
|
payload_type = f"<{notif_name} as ModelContextProtocolNotification>::Params"
|
||||||
notif_impl_lines.append(f' "{method_const}" => {{\n')
|
notif_impl_lines.append(f' "{method_const}" => {{\n')
|
||||||
# params may be optional
|
# params may be optional
|
||||||
@@ -167,9 +178,7 @@ fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }}
|
|||||||
notif_impl_lines.append(
|
notif_impl_lines.append(
|
||||||
f" let params: {payload_type} = serde_json::from_value(params_json)?;\n"
|
f" let params: {payload_type} = serde_json::from_value(params_json)?;\n"
|
||||||
)
|
)
|
||||||
notif_impl_lines.append(
|
notif_impl_lines.append(f" Ok(ServerNotification::{notif_name}(params))\n")
|
||||||
f" Ok(ServerNotification::{notif_name}(params))\n"
|
|
||||||
)
|
|
||||||
notif_impl_lines.append(" },\n")
|
notif_impl_lines.append(" },\n")
|
||||||
|
|
||||||
notif_impl_lines.append(
|
notif_impl_lines.append(
|
||||||
@@ -185,13 +194,70 @@ fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }}
|
|||||||
for chunk in out:
|
for chunk in out:
|
||||||
f.write(chunk)
|
f.write(chunk)
|
||||||
|
|
||||||
subprocess.check_call(
|
if fmt:
|
||||||
["cargo", "fmt", "--", "--config", "imports_granularity=Item"],
|
subprocess.check_call(
|
||||||
cwd=lib_rs.parent.parent,
|
["cargo", "fmt", "--", "--config", "imports_granularity=Item"],
|
||||||
stderr=subprocess.DEVNULL,
|
cwd=lib_rs.parent.parent,
|
||||||
)
|
stderr=subprocess.DEVNULL,
|
||||||
|
)
|
||||||
|
|
||||||
return 0
|
|
||||||
|
def run_check(schema_file: Path, crate_dir: Path, checked_in_lib: Path) -> int:
|
||||||
|
config_path = crate_dir.parent / "rustfmt.toml"
|
||||||
|
eprint(f"Running --check with schema {schema_file}")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
tmp_path = Path(tmp_dir)
|
||||||
|
eprint(f"Created temporary workspace at {tmp_path}")
|
||||||
|
manifest_path = tmp_path / "Cargo.toml"
|
||||||
|
eprint(f"Copying Cargo.toml into {manifest_path}")
|
||||||
|
copy2(crate_dir / "Cargo.toml", manifest_path)
|
||||||
|
manifest_text = manifest_path.read_text(encoding="utf-8")
|
||||||
|
manifest_text = manifest_text.replace(
|
||||||
|
"version = { workspace = true }",
|
||||||
|
'version = "0.0.0"',
|
||||||
|
)
|
||||||
|
manifest_text = manifest_text.replace("\n[lints]\nworkspace = true\n", "\n")
|
||||||
|
manifest_path.write_text(manifest_text, encoding="utf-8")
|
||||||
|
src_dir = tmp_path / "src"
|
||||||
|
src_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
eprint(f"Generating lib.rs into {src_dir}")
|
||||||
|
generated_lib = src_dir / "lib.rs"
|
||||||
|
|
||||||
|
generate_lib_rs(schema_file, generated_lib, fmt=False)
|
||||||
|
|
||||||
|
eprint("Formatting generated lib.rs with rustfmt")
|
||||||
|
subprocess.check_call(
|
||||||
|
[
|
||||||
|
"rustfmt",
|
||||||
|
"--config-path",
|
||||||
|
str(config_path),
|
||||||
|
str(generated_lib),
|
||||||
|
],
|
||||||
|
cwd=tmp_path,
|
||||||
|
stderr=subprocess.DEVNULL,
|
||||||
|
)
|
||||||
|
|
||||||
|
eprint("Comparing generated lib.rs with checked-in version")
|
||||||
|
checked_in_contents = checked_in_lib.read_text(encoding="utf-8")
|
||||||
|
generated_contents = generated_lib.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
if checked_in_contents == generated_contents:
|
||||||
|
eprint("lib.rs matches checked-in version")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
diff = unified_diff(
|
||||||
|
checked_in_contents.splitlines(keepends=True),
|
||||||
|
generated_contents.splitlines(keepends=True),
|
||||||
|
fromfile=str(checked_in_lib),
|
||||||
|
tofile=str(generated_lib),
|
||||||
|
)
|
||||||
|
diff_text = "".join(diff)
|
||||||
|
eprint("Generated lib.rs does not match the checked-in version. Diff:")
|
||||||
|
if diff_text:
|
||||||
|
eprint(diff_text, end="")
|
||||||
|
eprint("Re-run generate_mcp_types.py without --check to update src/lib.rs.")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
def add_definition(name: str, definition: dict[str, Any], out: list[str]) -> None:
|
def add_definition(name: str, definition: dict[str, Any], out: list[str]) -> None:
|
||||||
@@ -421,15 +487,11 @@ def define_untagged_enum(name: str, type_list: list[str], out: list[str]) -> Non
|
|||||||
case "integer":
|
case "integer":
|
||||||
out.append(" Integer(i64),\n")
|
out.append(" Integer(i64),\n")
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(
|
raise ValueError(f"Unknown type in untagged enum: {simple_type} in {name}")
|
||||||
f"Unknown type in untagged enum: {simple_type} in {name}"
|
|
||||||
)
|
|
||||||
out.append("}\n\n")
|
out.append("}\n\n")
|
||||||
|
|
||||||
|
|
||||||
def define_any_of(
|
def define_any_of(name: str, list_of_refs: list[Any], description: str | None = None) -> list[str]:
|
||||||
name: str, list_of_refs: list[Any], description: str | None = None
|
|
||||||
) -> list[str]:
|
|
||||||
"""Generate a Rust enum for a JSON-Schema `anyOf` union.
|
"""Generate a Rust enum for a JSON-Schema `anyOf` union.
|
||||||
|
|
||||||
For most types we simply map each `$ref` inside the `anyOf` list to a
|
For most types we simply map each `$ref` inside the `anyOf` list to a
|
||||||
@@ -494,9 +556,7 @@ def define_any_of(
|
|||||||
if name == "ClientRequest":
|
if name == "ClientRequest":
|
||||||
payload_type = f"<{ref_name} as ModelContextProtocolRequest>::Params"
|
payload_type = f"<{ref_name} as ModelContextProtocolRequest>::Params"
|
||||||
else:
|
else:
|
||||||
payload_type = (
|
payload_type = f"<{ref_name} as ModelContextProtocolNotification>::Params"
|
||||||
f"<{ref_name} as ModelContextProtocolNotification>::Params"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Determine the wire value for `method` so we can annotate the
|
# Determine the wire value for `method` so we can annotate the
|
||||||
# variant appropriately. If for some reason the schema does not
|
# variant appropriately. If for some reason the schema does not
|
||||||
@@ -504,9 +564,7 @@ def define_any_of(
|
|||||||
# least compile (although deserialization will likely fail).
|
# least compile (although deserialization will likely fail).
|
||||||
request_def = DEFINITIONS.get(ref_name, {})
|
request_def = DEFINITIONS.get(ref_name, {})
|
||||||
method_const = (
|
method_const = (
|
||||||
request_def.get("properties", {})
|
request_def.get("properties", {}).get("method", {}).get("const", ref_name)
|
||||||
.get("method", {})
|
|
||||||
.get("const", ref_name)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
out.append(f' #[serde(rename = "{method_const}")]\n')
|
out.append(f' #[serde(rename = "{method_const}")]\n')
|
||||||
@@ -556,7 +614,7 @@ def map_type(
|
|||||||
if type_prop == "string":
|
if type_prop == "string":
|
||||||
if const_prop := typedef.get("const", None):
|
if const_prop := typedef.get("const", None):
|
||||||
assert isinstance(const_prop, str)
|
assert isinstance(const_prop, str)
|
||||||
return f'&\'static str = "{const_prop }"'
|
return f'&\'static str = "{const_prop}"'
|
||||||
else:
|
else:
|
||||||
return "String"
|
return "String"
|
||||||
elif type_prop == "integer":
|
elif type_prop == "integer":
|
||||||
@@ -632,7 +690,7 @@ def rust_prop_name(name: str, is_optional: bool) -> RustProp:
|
|||||||
serde_annotations.append('skip_serializing_if = "Option::is_none"')
|
serde_annotations.append('skip_serializing_if = "Option::is_none"')
|
||||||
|
|
||||||
if serde_annotations:
|
if serde_annotations:
|
||||||
serde_str = f'#[serde({", ".join(serde_annotations)})]'
|
serde_str = f"#[serde({', '.join(serde_annotations)})]"
|
||||||
else:
|
else:
|
||||||
serde_str = None
|
serde_str = None
|
||||||
return RustProp(prop_name, serde_str)
|
return RustProp(prop_name, serde_str)
|
||||||
@@ -640,9 +698,7 @@ def rust_prop_name(name: str, is_optional: bool) -> RustProp:
|
|||||||
|
|
||||||
def to_snake_case(name: str) -> str:
|
def to_snake_case(name: str) -> str:
|
||||||
"""Convert a camelCase or PascalCase name to snake_case."""
|
"""Convert a camelCase or PascalCase name to snake_case."""
|
||||||
snake_case = name[0].lower() + "".join(
|
snake_case = name[0].lower() + "".join("_" + c.lower() if c.isupper() else c for c in name[1:])
|
||||||
"_" + c.lower() if c.isupper() else c for c in name[1:]
|
|
||||||
)
|
|
||||||
if snake_case != name:
|
if snake_case != name:
|
||||||
return snake_case
|
return snake_case
|
||||||
else:
|
else:
|
||||||
@@ -678,5 +734,9 @@ def emit_doc_comment(text: str | None, out: list[str]) -> None:
|
|||||||
out.append(f"/// {line.rstrip()}\n")
|
out.append(f"/// {line.rstrip()}\n")
|
||||||
|
|
||||||
|
|
||||||
|
def eprint(*args: Any, **kwargs: Any) -> None:
|
||||||
|
print(*args, file=sys.stderr, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
sys.exit(main())
|
sys.exit(main())
|
||||||
|
|||||||
Reference in New Issue
Block a user