fix: use macros to ensure request/response symmetry (#4529)

Manually curating `protocol-ts/src/lib.rs` was error-prone, as expected.
I finally asked Codex to write some Rust macros so we can ensure that:

- For every variant of `ClientRequest` and `ServerRequest`, there is an
associated `params` and `response` type.
- All response types are included automatically in the output of `codex
generate-ts`.
This commit is contained in:
Michael Bolin
2025-09-30 18:06:05 -07:00
committed by GitHub
parent 7fc3edf8a7
commit 32853ecbc5
9 changed files with 254 additions and 134 deletions

View File

@@ -13,7 +13,9 @@ use crate::protocol::ReviewDecision;
use crate::protocol::SandboxPolicy;
use crate::protocol::TurnAbortReason;
use mcp_types::JSONRPCNotification;
use mcp_types::JSONRPCRequest;
use mcp_types::RequestId;
use paste::paste;
use serde::Deserialize;
use serde::Serialize;
use strum_macros::Display;
@@ -89,137 +91,137 @@ pub enum AuthMode {
ChatGPT,
}
/// Request from the client to the server.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(tag = "method", rename_all = "camelCase")]
pub enum ClientRequest {
/// Generates an `enum ClientRequest` where each variant is a request that the
/// client can send to the server. Each variant has associated `params` and
/// `response` types. Also generates a `export_client_responses()` function to
/// export all response types to TypeScript.
macro_rules! client_request_definitions {
(
$(
$(#[$variant_meta:meta])*
$variant:ident {
params: $(#[$params_meta:meta])* $params:ty,
response: $response:ty,
}
),* $(,)?
) => {
/// Request from the client to the server.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(tag = "method", rename_all = "camelCase")]
pub enum ClientRequest {
$(
$(#[$variant_meta])*
$variant {
#[serde(rename = "id")]
request_id: RequestId,
$(#[$params_meta])*
params: $params,
},
)*
}
pub fn export_client_responses(
out_dir: &::std::path::Path,
) -> ::std::result::Result<(), ::ts_rs::ExportError> {
$(
<$response as ::ts_rs::TS>::export_all_to(out_dir)?;
)*
Ok(())
}
};
}
client_request_definitions! {
Initialize {
#[serde(rename = "id")]
request_id: RequestId,
params: InitializeParams,
response: InitializeResponse,
},
NewConversation {
#[serde(rename = "id")]
request_id: RequestId,
params: NewConversationParams,
response: NewConversationResponse,
},
/// List recorded Codex conversations (rollouts) with optional pagination and search.
ListConversations {
#[serde(rename = "id")]
request_id: RequestId,
params: ListConversationsParams,
response: ListConversationsResponse,
},
/// Resume a recorded Codex conversation from a rollout file.
ResumeConversation {
#[serde(rename = "id")]
request_id: RequestId,
params: ResumeConversationParams,
response: ResumeConversationResponse,
},
ArchiveConversation {
#[serde(rename = "id")]
request_id: RequestId,
params: ArchiveConversationParams,
response: ArchiveConversationResponse,
},
SendUserMessage {
#[serde(rename = "id")]
request_id: RequestId,
params: SendUserMessageParams,
response: SendUserMessageResponse,
},
SendUserTurn {
#[serde(rename = "id")]
request_id: RequestId,
params: SendUserTurnParams,
response: SendUserTurnResponse,
},
InterruptConversation {
#[serde(rename = "id")]
request_id: RequestId,
params: InterruptConversationParams,
response: InterruptConversationResponse,
},
AddConversationListener {
#[serde(rename = "id")]
request_id: RequestId,
params: AddConversationListenerParams,
response: AddConversationSubscriptionResponse,
},
RemoveConversationListener {
#[serde(rename = "id")]
request_id: RequestId,
params: RemoveConversationListenerParams,
response: RemoveConversationSubscriptionResponse,
},
GitDiffToRemote {
#[serde(rename = "id")]
request_id: RequestId,
params: GitDiffToRemoteParams,
response: GitDiffToRemoteResponse,
},
LoginApiKey {
#[serde(rename = "id")]
request_id: RequestId,
params: LoginApiKeyParams,
response: LoginApiKeyResponse,
},
LoginChatGpt {
#[serde(rename = "id")]
request_id: RequestId,
#[ts(type = "undefined")]
#[serde(skip_serializing_if = "Option::is_none")]
params: Option<()>,
params: #[ts(type = "undefined")] #[serde(skip_serializing_if = "Option::is_none")] Option<()>,
response: LoginChatGptResponse,
},
CancelLoginChatGpt {
#[serde(rename = "id")]
request_id: RequestId,
params: CancelLoginChatGptParams,
response: CancelLoginChatGptResponse,
},
LogoutChatGpt {
#[serde(rename = "id")]
request_id: RequestId,
#[ts(type = "undefined")]
#[serde(skip_serializing_if = "Option::is_none")]
params: Option<()>,
params: #[ts(type = "undefined")] #[serde(skip_serializing_if = "Option::is_none")] Option<()>,
response: LogoutChatGptResponse,
},
GetAuthStatus {
#[serde(rename = "id")]
request_id: RequestId,
params: GetAuthStatusParams,
response: GetAuthStatusResponse,
},
GetUserSavedConfig {
#[serde(rename = "id")]
request_id: RequestId,
#[ts(type = "undefined")]
#[serde(skip_serializing_if = "Option::is_none")]
params: Option<()>,
params: #[ts(type = "undefined")] #[serde(skip_serializing_if = "Option::is_none")] Option<()>,
response: GetUserSavedConfigResponse,
},
SetDefaultModel {
#[serde(rename = "id")]
request_id: RequestId,
params: SetDefaultModelParams,
response: SetDefaultModelResponse,
},
GetUserAgent {
#[serde(rename = "id")]
request_id: RequestId,
#[ts(type = "undefined")]
#[serde(skip_serializing_if = "Option::is_none")]
params: Option<()>,
params: #[ts(type = "undefined")] #[serde(skip_serializing_if = "Option::is_none")] Option<()>,
response: GetUserAgentResponse,
},
UserInfo {
#[serde(rename = "id")]
request_id: RequestId,
#[ts(type = "undefined")]
#[serde(skip_serializing_if = "Option::is_none")]
params: Option<()>,
params: #[ts(type = "undefined")] #[serde(skip_serializing_if = "Option::is_none")] Option<()>,
response: UserInfoResponse,
},
FuzzyFileSearch {
#[serde(rename = "id")]
request_id: RequestId,
params: FuzzyFileSearchParams,
response: FuzzyFileSearchResponse,
},
/// Execute a command (argv vector) under the server's sandbox.
ExecOneOffCommand {
#[serde(rename = "id")]
request_id: RequestId,
params: ExecOneOffCommandParams,
response: ExecOneOffCommandResponse,
},
}
@@ -449,7 +451,7 @@ pub struct ExecOneOffCommandParams {
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(rename_all = "camelCase")]
pub struct ExecArbitraryCommandResponse {
pub struct ExecOneOffCommandResponse {
pub exit_code: i32,
pub stdout: String,
pub stderr: String,
@@ -653,30 +655,102 @@ pub enum InputItem {
},
}
// TODO(mbolin): Need test to ensure these constants match the enum variants.
/// Generates an `enum ServerRequest` where each variant is a request that the
/// server can send to the client along with the corresponding params and
/// response types. It also generates helper types used by the app/server
/// infrastructure (method constants, payload enum, and export helpers).
macro_rules! server_request_definitions {
(
$(
$(#[$variant_meta:meta])*
$variant:ident => $method:literal
),* $(,)?
) => {
paste! {
$(pub const [<$variant:snake:upper _METHOD>]: &str = $method;)*
pub const APPLY_PATCH_APPROVAL_METHOD: &str = "applyPatchApproval";
pub const EXEC_COMMAND_APPROVAL_METHOD: &str = "execCommandApproval";
/// Method names for server-initiated requests (camelCase to match JSON-RPC).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServerRequestMethod {
$( $variant ),*
}
/// Request initiated from the server and sent to the client.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(tag = "method", rename_all = "camelCase")]
pub enum ServerRequest {
impl ServerRequestMethod {
pub const fn as_str(self) -> &'static str {
match self {
$(ServerRequestMethod::$variant => $method,)*
}
}
}
/// Request initiated from the server and sent to the client.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(tag = "method", rename_all = "camelCase")]
pub enum ServerRequest {
$(
$(#[$variant_meta])*
$variant {
#[serde(rename = "id")]
request_id: RequestId,
params: [<$variant Params>],
},
)*
}
#[derive(Debug, Clone, PartialEq)]
pub enum ServerRequestPayload {
$( $variant([<$variant Params>]), )*
}
impl ServerRequestPayload {
pub fn method(&self) -> &'static str {
match self {
$(Self::$variant(..) => $method,)*
}
}
pub fn into_params_value(self) -> serde_json::Value {
match self {
$(Self::$variant(params) => serde_json::to_value(params).unwrap_or_default(),)*
}
}
pub fn into_request(self, request_id: RequestId) -> ServerRequest {
match self {
$(Self::$variant(params) => ServerRequest::$variant { request_id, params },)*
}
}
}
}
pub fn export_server_responses(
out_dir: &::std::path::Path,
) -> ::std::result::Result<(), ::ts_rs::ExportError> {
paste! {
$(<[<$variant Response>] as ::ts_rs::TS>::export_all_to(out_dir)?;)*
}
Ok(())
}
};
}
impl TryFrom<JSONRPCRequest> for ServerRequest {
type Error = serde_json::Error;
fn try_from(value: JSONRPCRequest) -> Result<Self, Self::Error> {
serde_json::from_value(serde_json::to_value(value)?)
}
}
server_request_definitions! {
/// Request to approve a patch.
ApplyPatchApproval {
#[serde(rename = "id")]
request_id: RequestId,
params: ApplyPatchApprovalParams,
},
ApplyPatchApproval => "applyPatchApproval",
/// Request to exec a command.
ExecCommandApproval {
#[serde(rename = "id")]
request_id: RequestId,
params: ExecCommandApprovalParams,
},
ExecCommandApproval => "execCommandApproval",
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(rename_all = "camelCase")]
pub struct ApplyPatchApprovalParams {
pub conversation_id: ConversationId,
/// Use to correlate this with [codex_core::protocol::PatchApplyBeginEvent]
@@ -693,6 +767,7 @@ pub struct ApplyPatchApprovalParams {
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(rename_all = "camelCase")]
pub struct ExecCommandApprovalParams {
pub conversation_id: ConversationId,
/// Use to correlate this with [codex_core::protocol::ExecCommandBeginEvent]
@@ -766,6 +841,7 @@ pub struct SessionConfiguredNotification {
pub history_log_id: u64,
/// Current number of entries in the history log.
#[ts(type = "number")]
pub history_entry_count: usize,
/// Optional initial messages (as events) for resumed sessions.
@@ -903,4 +979,40 @@ mod tests {
);
Ok(())
}
#[test]
fn serialize_server_request() -> Result<()> {
let conversation_id = ConversationId::from_string("67e55044-10b1-426f-9247-bb680e5fe0c8")?;
let params = ExecCommandApprovalParams {
conversation_id,
call_id: "call-42".to_string(),
command: vec!["echo".to_string(), "hello".to_string()],
cwd: PathBuf::from("/tmp"),
reason: Some("because tests".to_string()),
};
let request = ServerRequest::ExecCommandApproval {
request_id: RequestId::Integer(7),
params: params.clone(),
};
assert_eq!(
json!({
"method": "execCommandApproval",
"id": 7,
"params": {
"conversationId": "67e55044-10b1-426f-9247-bb680e5fe0c8",
"callId": "call-42",
"command": ["echo", "hello"],
"cwd": "/tmp",
"reason": "because tests",
}
}),
serde_json::to_value(&request)?,
);
let payload = ServerRequestPayload::ExecCommandApproval(params);
assert_eq!("execCommandApproval", EXEC_COMMAND_APPROVAL_METHOD);
assert_eq!(EXEC_COMMAND_APPROVAL_METHOD, payload.method());
Ok(())
}
}