fix: use macros to ensure request/response symmetry (#4529)

Manually curating `protocol-ts/src/lib.rs` was error-prone, as expected.
I finally asked Codex to write some Rust macros so we can ensure that:

- For every variant of `ClientRequest` and `ServerRequest`, there is an
associated `params` and `response` type.
- All response types are included automatically in the output of `codex
generate-ts`.
This commit is contained in:
Michael Bolin
2025-09-30 18:06:05 -07:00
committed by GitHub
parent 7fc3edf8a7
commit 32853ecbc5
9 changed files with 254 additions and 134 deletions

1
codex-rs/Cargo.lock generated
View File

@@ -1103,6 +1103,7 @@ dependencies = [
"icu_locale_core",
"mcp-types",
"mime_guess",
"paste",
"pretty_assertions",
"serde",
"serde_json",

View File

@@ -123,6 +123,7 @@ opentelemetry-semantic-conventions = "0.30.0"
opentelemetry_sdk = "0.30.0"
os_info = "3.12.0"
owo-colors = "4.2.0"
paste = "1.0.15"
path-absolutize = "3.1.1"
path-clean = "1.0.1"
pathdiff = "0.2"

View File

@@ -36,7 +36,6 @@ use codex_core::protocol::ReviewDecision;
use codex_login::ServerOptions as LoginServerOptions;
use codex_login::ShutdownHandle;
use codex_login::run_login_server;
use codex_protocol::mcp_protocol::APPLY_PATCH_APPROVAL_METHOD;
use codex_protocol::mcp_protocol::AddConversationListenerParams;
use codex_protocol::mcp_protocol::AddConversationSubscriptionResponse;
use codex_protocol::mcp_protocol::ApplyPatchApprovalParams;
@@ -47,11 +46,10 @@ use codex_protocol::mcp_protocol::AuthStatusChangeNotification;
use codex_protocol::mcp_protocol::ClientRequest;
use codex_protocol::mcp_protocol::ConversationId;
use codex_protocol::mcp_protocol::ConversationSummary;
use codex_protocol::mcp_protocol::EXEC_COMMAND_APPROVAL_METHOD;
use codex_protocol::mcp_protocol::ExecArbitraryCommandResponse;
use codex_protocol::mcp_protocol::ExecCommandApprovalParams;
use codex_protocol::mcp_protocol::ExecCommandApprovalResponse;
use codex_protocol::mcp_protocol::ExecOneOffCommandParams;
use codex_protocol::mcp_protocol::ExecOneOffCommandResponse;
use codex_protocol::mcp_protocol::FuzzyFileSearchParams;
use codex_protocol::mcp_protocol::FuzzyFileSearchResponse;
use codex_protocol::mcp_protocol::GetUserAgentResponse;
@@ -76,6 +74,7 @@ use codex_protocol::mcp_protocol::SendUserMessageResponse;
use codex_protocol::mcp_protocol::SendUserTurnParams;
use codex_protocol::mcp_protocol::SendUserTurnResponse;
use codex_protocol::mcp_protocol::ServerNotification;
use codex_protocol::mcp_protocol::ServerRequestPayload;
use codex_protocol::mcp_protocol::SessionConfiguredNotification;
use codex_protocol::mcp_protocol::SetDefaultModelParams;
use codex_protocol::mcp_protocol::SetDefaultModelResponse;
@@ -632,7 +631,7 @@ impl CodexMessageProcessor {
.await
{
Ok(output) => {
let response = ExecArbitraryCommandResponse {
let response = ExecOneOffCommandResponse {
exit_code: output.exit_code,
stdout: output.stdout.text,
stderr: output.stderr.text,
@@ -1268,9 +1267,8 @@ async fn apply_bespoke_event_handling(
reason,
grant_root,
};
let value = serde_json::to_value(&params).unwrap_or_default();
let rx = outgoing
.send_request(APPLY_PATCH_APPROVAL_METHOD, Some(value))
.send_request(ServerRequestPayload::ApplyPatchApproval(params))
.await;
// TODO(mbolin): Enforce a timeout so this task does not live indefinitely?
tokio::spawn(async move {
@@ -1290,9 +1288,8 @@ async fn apply_bespoke_event_handling(
cwd,
reason,
};
let value = serde_json::to_value(&params).unwrap_or_default();
let rx = outgoing
.send_request(EXEC_COMMAND_APPROVAL_METHOD, Some(value))
.send_request(ServerRequestPayload::ExecCommandApproval(params))
.await;
// TODO(mbolin): Enforce a timeout so this task does not live indefinitely?

View File

@@ -3,6 +3,7 @@ use std::sync::atomic::AtomicI64;
use std::sync::atomic::Ordering;
use codex_protocol::mcp_protocol::ServerNotification;
use codex_protocol::mcp_protocol::ServerRequestPayload;
use mcp_types::JSONRPC_VERSION;
use mcp_types::JSONRPCError;
use mcp_types::JSONRPCErrorError;
@@ -38,8 +39,7 @@ impl OutgoingMessageSender {
pub(crate) async fn send_request(
&self,
method: &str,
params: Option<serde_json::Value>,
request: ServerRequestPayload,
) -> oneshot::Receiver<Result> {
let id = RequestId::Integer(self.next_request_id.fetch_add(1, Ordering::Relaxed));
let outgoing_message_id = id.clone();
@@ -49,6 +49,14 @@ impl OutgoingMessageSender {
request_id_to_callback.insert(id, tx_approve);
}
let method = request.method();
let params_value = request.into_params_value();
let params = if params_value.is_null() {
None
} else {
Some(params_value)
};
let outgoing_message = OutgoingMessage::Request(OutgoingRequest {
id: outgoing_message_id,
method: method.to_string(),

View File

@@ -26,6 +26,7 @@ use codex_protocol::mcp_protocol::RemoveConversationListenerParams;
use codex_protocol::mcp_protocol::ResumeConversationParams;
use codex_protocol::mcp_protocol::SendUserMessageParams;
use codex_protocol::mcp_protocol::SendUserTurnParams;
use codex_protocol::mcp_protocol::ServerRequest;
use codex_protocol::mcp_protocol::SetDefaultModelParams;
use mcp_types::JSONRPC_VERSION;
@@ -373,7 +374,7 @@ impl McpProcess {
Ok(message)
}
pub async fn read_stream_until_request_message(&mut self) -> anyhow::Result<JSONRPCRequest> {
pub async fn read_stream_until_request_message(&mut self) -> anyhow::Result<ServerRequest> {
eprintln!("in read_stream_until_request_message()");
loop {
@@ -384,7 +385,9 @@ impl McpProcess {
eprintln!("notification: {message:?}");
}
JSONRPCMessage::Request(jsonrpc_request) => {
return Ok(jsonrpc_request);
return jsonrpc_request.try_into().with_context(
|| "failed to deserialize ServerRequest from JSONRPCRequest",
);
}
JSONRPCMessage::Error(_) => {
anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}");

View File

@@ -12,7 +12,7 @@ use codex_core::protocol_config_types::ReasoningSummary;
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
use codex_protocol::mcp_protocol::AddConversationListenerParams;
use codex_protocol::mcp_protocol::AddConversationSubscriptionResponse;
use codex_protocol::mcp_protocol::EXEC_COMMAND_APPROVAL_METHOD;
use codex_protocol::mcp_protocol::ExecCommandApprovalParams;
use codex_protocol::mcp_protocol::NewConversationParams;
use codex_protocol::mcp_protocol::NewConversationResponse;
use codex_protocol::mcp_protocol::RemoveConversationListenerParams;
@@ -21,6 +21,7 @@ use codex_protocol::mcp_protocol::SendUserMessageParams;
use codex_protocol::mcp_protocol::SendUserMessageResponse;
use codex_protocol::mcp_protocol::SendUserTurnParams;
use codex_protocol::mcp_protocol::SendUserTurnResponse;
use codex_protocol::mcp_protocol::ServerRequest;
use mcp_types::JSONRPCNotification;
use mcp_types::JSONRPCResponse;
use mcp_types::RequestId;
@@ -290,11 +291,28 @@ async fn test_send_user_turn_changes_approval_policy_behavior() {
.await
.expect("waiting for exec approval request timeout")
.expect("exec approval request");
assert_eq!(request.method, EXEC_COMMAND_APPROVAL_METHOD);
let ServerRequest::ExecCommandApproval { request_id, params } = request else {
panic!("expected ExecCommandApproval request, got: {request:?}");
};
assert_eq!(
ExecCommandApprovalParams {
conversation_id,
call_id: "call1".to_string(),
command: vec![
"python3".to_string(),
"-c".to_string(),
"print(42)".to_string(),
],
cwd: working_directory.clone(),
reason: None,
},
params
);
// Approve so the first turn can complete
mcp.send_response(
request.id,
request_id,
serde_json::json!({ "decision": codex_core::protocol::ReviewDecision::Approved }),
)
.await

View File

@@ -1,6 +1,12 @@
use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use codex_protocol::mcp_protocol::ClientNotification;
use codex_protocol::mcp_protocol::ClientRequest;
use codex_protocol::mcp_protocol::ServerNotification;
use codex_protocol::mcp_protocol::ServerRequest;
use codex_protocol::mcp_protocol::export_client_responses;
use codex_protocol::mcp_protocol::export_server_responses;
use std::ffi::OsStr;
use std::fs;
use std::io::Read;
@@ -15,44 +21,17 @@ const HEADER: &str = "// GENERATED CODE! DO NOT MODIFY BY HAND!\n\n";
pub fn generate_ts(out_dir: &Path, prettier: Option<&Path>) -> Result<()> {
ensure_dir(out_dir)?;
use codex_protocol::mcp_protocol::*;
// Generating the TS bindings for these top-level enums ensures all types
// reachable from them will be generated by induction, so they do not need
// to be listed individually.
// Generate the TS bindings client -> server messages.
ClientRequest::export_all_to(out_dir)?;
export_client_responses(out_dir)?;
ClientNotification::export_all_to(out_dir)?;
// Generate the TS bindings server -> client messages.
ServerRequest::export_all_to(out_dir)?;
export_server_responses(out_dir)?;
ServerNotification::export_all_to(out_dir)?;
// Response types for ClientRequest (mirror enum order).
InitializeResponse::export_all_to(out_dir)?;
NewConversationResponse::export_all_to(out_dir)?;
ListConversationsResponse::export_all_to(out_dir)?;
ResumeConversationResponse::export_all_to(out_dir)?;
ArchiveConversationResponse::export_all_to(out_dir)?;
SendUserMessageResponse::export_all_to(out_dir)?;
SendUserTurnResponse::export_all_to(out_dir)?;
InterruptConversationResponse::export_all_to(out_dir)?;
AddConversationSubscriptionResponse::export_all_to(out_dir)?;
RemoveConversationSubscriptionResponse::export_all_to(out_dir)?;
GitDiffToRemoteResponse::export_all_to(out_dir)?;
LoginApiKeyResponse::export_all_to(out_dir)?;
LoginChatGptResponse::export_all_to(out_dir)?;
CancelLoginChatGptResponse::export_all_to(out_dir)?;
LogoutChatGptResponse::export_all_to(out_dir)?;
GetAuthStatusResponse::export_all_to(out_dir)?;
GetUserSavedConfigResponse::export_all_to(out_dir)?;
SetDefaultModelResponse::export_all_to(out_dir)?;
GetUserAgentResponse::export_all_to(out_dir)?;
UserInfoResponse::export_all_to(out_dir)?;
FuzzyFileSearchResponse::export_all_to(out_dir)?;
ExecArbitraryCommandResponse::export_all_to(out_dir)?;
// Response types for ServerRequest (mirror enum order).
ApplyPatchApprovalResponse::export_all_to(out_dir)?;
ExecCommandApprovalResponse::export_all_to(out_dir)?;
// Generate index.ts that re-exports all types.
generate_index_ts(out_dir)?;
// Prepend header to each generated .ts file

View File

@@ -16,6 +16,7 @@ icu_decimal = { workspace = true }
icu_locale_core = { workspace = true }
mcp-types = { workspace = true }
mime_guess = { workspace = true }
paste = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
serde_with = { workspace = true, features = ["macros", "base64"] }

View File

@@ -13,7 +13,9 @@ use crate::protocol::ReviewDecision;
use crate::protocol::SandboxPolicy;
use crate::protocol::TurnAbortReason;
use mcp_types::JSONRPCNotification;
use mcp_types::JSONRPCRequest;
use mcp_types::RequestId;
use paste::paste;
use serde::Deserialize;
use serde::Serialize;
use strum_macros::Display;
@@ -89,137 +91,137 @@ pub enum AuthMode {
ChatGPT,
}
/// Request from the client to the server.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(tag = "method", rename_all = "camelCase")]
pub enum ClientRequest {
Initialize {
/// Generates an `enum ClientRequest` where each variant is a request that the
/// client can send to the server. Each variant has associated `params` and
/// `response` types. Also generates a `export_client_responses()` function to
/// export all response types to TypeScript.
macro_rules! client_request_definitions {
(
$(
$(#[$variant_meta:meta])*
$variant:ident {
params: $(#[$params_meta:meta])* $params:ty,
response: $response:ty,
}
),* $(,)?
) => {
/// Request from the client to the server.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(tag = "method", rename_all = "camelCase")]
pub enum ClientRequest {
$(
$(#[$variant_meta])*
$variant {
#[serde(rename = "id")]
request_id: RequestId,
$(#[$params_meta])*
params: $params,
},
)*
}
pub fn export_client_responses(
out_dir: &::std::path::Path,
) -> ::std::result::Result<(), ::ts_rs::ExportError> {
$(
<$response as ::ts_rs::TS>::export_all_to(out_dir)?;
)*
Ok(())
}
};
}
client_request_definitions! {
Initialize {
params: InitializeParams,
response: InitializeResponse,
},
NewConversation {
#[serde(rename = "id")]
request_id: RequestId,
params: NewConversationParams,
response: NewConversationResponse,
},
/// List recorded Codex conversations (rollouts) with optional pagination and search.
ListConversations {
#[serde(rename = "id")]
request_id: RequestId,
params: ListConversationsParams,
response: ListConversationsResponse,
},
/// Resume a recorded Codex conversation from a rollout file.
ResumeConversation {
#[serde(rename = "id")]
request_id: RequestId,
params: ResumeConversationParams,
response: ResumeConversationResponse,
},
ArchiveConversation {
#[serde(rename = "id")]
request_id: RequestId,
params: ArchiveConversationParams,
response: ArchiveConversationResponse,
},
SendUserMessage {
#[serde(rename = "id")]
request_id: RequestId,
params: SendUserMessageParams,
response: SendUserMessageResponse,
},
SendUserTurn {
#[serde(rename = "id")]
request_id: RequestId,
params: SendUserTurnParams,
response: SendUserTurnResponse,
},
InterruptConversation {
#[serde(rename = "id")]
request_id: RequestId,
params: InterruptConversationParams,
response: InterruptConversationResponse,
},
AddConversationListener {
#[serde(rename = "id")]
request_id: RequestId,
params: AddConversationListenerParams,
response: AddConversationSubscriptionResponse,
},
RemoveConversationListener {
#[serde(rename = "id")]
request_id: RequestId,
params: RemoveConversationListenerParams,
response: RemoveConversationSubscriptionResponse,
},
GitDiffToRemote {
#[serde(rename = "id")]
request_id: RequestId,
params: GitDiffToRemoteParams,
response: GitDiffToRemoteResponse,
},
LoginApiKey {
#[serde(rename = "id")]
request_id: RequestId,
params: LoginApiKeyParams,
response: LoginApiKeyResponse,
},
LoginChatGpt {
#[serde(rename = "id")]
request_id: RequestId,
#[ts(type = "undefined")]
#[serde(skip_serializing_if = "Option::is_none")]
params: Option<()>,
params: #[ts(type = "undefined")] #[serde(skip_serializing_if = "Option::is_none")] Option<()>,
response: LoginChatGptResponse,
},
CancelLoginChatGpt {
#[serde(rename = "id")]
request_id: RequestId,
params: CancelLoginChatGptParams,
response: CancelLoginChatGptResponse,
},
LogoutChatGpt {
#[serde(rename = "id")]
request_id: RequestId,
#[ts(type = "undefined")]
#[serde(skip_serializing_if = "Option::is_none")]
params: Option<()>,
params: #[ts(type = "undefined")] #[serde(skip_serializing_if = "Option::is_none")] Option<()>,
response: LogoutChatGptResponse,
},
GetAuthStatus {
#[serde(rename = "id")]
request_id: RequestId,
params: GetAuthStatusParams,
response: GetAuthStatusResponse,
},
GetUserSavedConfig {
#[serde(rename = "id")]
request_id: RequestId,
#[ts(type = "undefined")]
#[serde(skip_serializing_if = "Option::is_none")]
params: Option<()>,
params: #[ts(type = "undefined")] #[serde(skip_serializing_if = "Option::is_none")] Option<()>,
response: GetUserSavedConfigResponse,
},
SetDefaultModel {
#[serde(rename = "id")]
request_id: RequestId,
params: SetDefaultModelParams,
response: SetDefaultModelResponse,
},
GetUserAgent {
#[serde(rename = "id")]
request_id: RequestId,
#[ts(type = "undefined")]
#[serde(skip_serializing_if = "Option::is_none")]
params: Option<()>,
params: #[ts(type = "undefined")] #[serde(skip_serializing_if = "Option::is_none")] Option<()>,
response: GetUserAgentResponse,
},
UserInfo {
#[serde(rename = "id")]
request_id: RequestId,
#[ts(type = "undefined")]
#[serde(skip_serializing_if = "Option::is_none")]
params: Option<()>,
params: #[ts(type = "undefined")] #[serde(skip_serializing_if = "Option::is_none")] Option<()>,
response: UserInfoResponse,
},
FuzzyFileSearch {
#[serde(rename = "id")]
request_id: RequestId,
params: FuzzyFileSearchParams,
response: FuzzyFileSearchResponse,
},
/// Execute a command (argv vector) under the server's sandbox.
ExecOneOffCommand {
#[serde(rename = "id")]
request_id: RequestId,
params: ExecOneOffCommandParams,
response: ExecOneOffCommandResponse,
},
}
@@ -449,7 +451,7 @@ pub struct ExecOneOffCommandParams {
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(rename_all = "camelCase")]
pub struct ExecArbitraryCommandResponse {
pub struct ExecOneOffCommandResponse {
pub exit_code: i32,
pub stdout: String,
pub stderr: String,
@@ -653,30 +655,102 @@ pub enum InputItem {
},
}
// TODO(mbolin): Need test to ensure these constants match the enum variants.
/// Generates an `enum ServerRequest` where each variant is a request that the
/// server can send to the client along with the corresponding params and
/// response types. It also generates helper types used by the app/server
/// infrastructure (method constants, payload enum, and export helpers).
macro_rules! server_request_definitions {
(
$(
$(#[$variant_meta:meta])*
$variant:ident => $method:literal
),* $(,)?
) => {
paste! {
$(pub const [<$variant:snake:upper _METHOD>]: &str = $method;)*
pub const APPLY_PATCH_APPROVAL_METHOD: &str = "applyPatchApproval";
pub const EXEC_COMMAND_APPROVAL_METHOD: &str = "execCommandApproval";
/// Method names for server-initiated requests (camelCase to match JSON-RPC).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServerRequestMethod {
$( $variant ),*
}
/// Request initiated from the server and sent to the client.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(tag = "method", rename_all = "camelCase")]
pub enum ServerRequest {
impl ServerRequestMethod {
pub const fn as_str(self) -> &'static str {
match self {
$(ServerRequestMethod::$variant => $method,)*
}
}
}
/// Request initiated from the server and sent to the client.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(tag = "method", rename_all = "camelCase")]
pub enum ServerRequest {
$(
$(#[$variant_meta])*
$variant {
#[serde(rename = "id")]
request_id: RequestId,
params: [<$variant Params>],
},
)*
}
#[derive(Debug, Clone, PartialEq)]
pub enum ServerRequestPayload {
$( $variant([<$variant Params>]), )*
}
impl ServerRequestPayload {
pub fn method(&self) -> &'static str {
match self {
$(Self::$variant(..) => $method,)*
}
}
pub fn into_params_value(self) -> serde_json::Value {
match self {
$(Self::$variant(params) => serde_json::to_value(params).unwrap_or_default(),)*
}
}
pub fn into_request(self, request_id: RequestId) -> ServerRequest {
match self {
$(Self::$variant(params) => ServerRequest::$variant { request_id, params },)*
}
}
}
}
pub fn export_server_responses(
out_dir: &::std::path::Path,
) -> ::std::result::Result<(), ::ts_rs::ExportError> {
paste! {
$(<[<$variant Response>] as ::ts_rs::TS>::export_all_to(out_dir)?;)*
}
Ok(())
}
};
}
impl TryFrom<JSONRPCRequest> for ServerRequest {
type Error = serde_json::Error;
fn try_from(value: JSONRPCRequest) -> Result<Self, Self::Error> {
serde_json::from_value(serde_json::to_value(value)?)
}
}
server_request_definitions! {
/// Request to approve a patch.
ApplyPatchApproval {
#[serde(rename = "id")]
request_id: RequestId,
params: ApplyPatchApprovalParams,
},
ApplyPatchApproval => "applyPatchApproval",
/// Request to exec a command.
ExecCommandApproval {
#[serde(rename = "id")]
request_id: RequestId,
params: ExecCommandApprovalParams,
},
ExecCommandApproval => "execCommandApproval",
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(rename_all = "camelCase")]
pub struct ApplyPatchApprovalParams {
pub conversation_id: ConversationId,
/// Use to correlate this with [codex_core::protocol::PatchApplyBeginEvent]
@@ -693,6 +767,7 @@ pub struct ApplyPatchApprovalParams {
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(rename_all = "camelCase")]
pub struct ExecCommandApprovalParams {
pub conversation_id: ConversationId,
/// Use to correlate this with [codex_core::protocol::ExecCommandBeginEvent]
@@ -766,6 +841,7 @@ pub struct SessionConfiguredNotification {
pub history_log_id: u64,
/// Current number of entries in the history log.
#[ts(type = "number")]
pub history_entry_count: usize,
/// Optional initial messages (as events) for resumed sessions.
@@ -903,4 +979,40 @@ mod tests {
);
Ok(())
}
#[test]
fn serialize_server_request() -> Result<()> {
let conversation_id = ConversationId::from_string("67e55044-10b1-426f-9247-bb680e5fe0c8")?;
let params = ExecCommandApprovalParams {
conversation_id,
call_id: "call-42".to_string(),
command: vec!["echo".to_string(), "hello".to_string()],
cwd: PathBuf::from("/tmp"),
reason: Some("because tests".to_string()),
};
let request = ServerRequest::ExecCommandApproval {
request_id: RequestId::Integer(7),
params: params.clone(),
};
assert_eq!(
json!({
"method": "execCommandApproval",
"id": 7,
"params": {
"conversationId": "67e55044-10b1-426f-9247-bb680e5fe0c8",
"callId": "call-42",
"command": ["echo", "hello"],
"cwd": "/tmp",
"reason": "because tests",
}
}),
serde_json::to_value(&request)?,
);
let payload = ServerRequestPayload::ExecCommandApproval(params);
assert_eq!("execCommandApproval", EXEC_COMMAND_APPROVAL_METHOD);
assert_eq!(EXEC_COMMAND_APPROVAL_METHOD, payload.method());
Ok(())
}
}