fix: use macros to ensure request/response symmetry (#4529)

Manually curating `protocol-ts/src/lib.rs` was error-prone, as expected.
I finally asked Codex to write some Rust macros so we can ensure that:

- For every variant of `ClientRequest` and `ServerRequest`, there is an
associated `params` and `response` type.
- All response types are included automatically in the output of `codex
generate-ts`.
This commit is contained in:
Michael Bolin
2025-09-30 18:06:05 -07:00
committed by GitHub
parent 7fc3edf8a7
commit 32853ecbc5
9 changed files with 254 additions and 134 deletions

1
codex-rs/Cargo.lock generated
View File

@@ -1103,6 +1103,7 @@ dependencies = [
"icu_locale_core", "icu_locale_core",
"mcp-types", "mcp-types",
"mime_guess", "mime_guess",
"paste",
"pretty_assertions", "pretty_assertions",
"serde", "serde",
"serde_json", "serde_json",

View File

@@ -123,6 +123,7 @@ opentelemetry-semantic-conventions = "0.30.0"
opentelemetry_sdk = "0.30.0" opentelemetry_sdk = "0.30.0"
os_info = "3.12.0" os_info = "3.12.0"
owo-colors = "4.2.0" owo-colors = "4.2.0"
paste = "1.0.15"
path-absolutize = "3.1.1" path-absolutize = "3.1.1"
path-clean = "1.0.1" path-clean = "1.0.1"
pathdiff = "0.2" pathdiff = "0.2"

View File

@@ -36,7 +36,6 @@ use codex_core::protocol::ReviewDecision;
use codex_login::ServerOptions as LoginServerOptions; use codex_login::ServerOptions as LoginServerOptions;
use codex_login::ShutdownHandle; use codex_login::ShutdownHandle;
use codex_login::run_login_server; use codex_login::run_login_server;
use codex_protocol::mcp_protocol::APPLY_PATCH_APPROVAL_METHOD;
use codex_protocol::mcp_protocol::AddConversationListenerParams; use codex_protocol::mcp_protocol::AddConversationListenerParams;
use codex_protocol::mcp_protocol::AddConversationSubscriptionResponse; use codex_protocol::mcp_protocol::AddConversationSubscriptionResponse;
use codex_protocol::mcp_protocol::ApplyPatchApprovalParams; use codex_protocol::mcp_protocol::ApplyPatchApprovalParams;
@@ -47,11 +46,10 @@ use codex_protocol::mcp_protocol::AuthStatusChangeNotification;
use codex_protocol::mcp_protocol::ClientRequest; use codex_protocol::mcp_protocol::ClientRequest;
use codex_protocol::mcp_protocol::ConversationId; use codex_protocol::mcp_protocol::ConversationId;
use codex_protocol::mcp_protocol::ConversationSummary; use codex_protocol::mcp_protocol::ConversationSummary;
use codex_protocol::mcp_protocol::EXEC_COMMAND_APPROVAL_METHOD;
use codex_protocol::mcp_protocol::ExecArbitraryCommandResponse;
use codex_protocol::mcp_protocol::ExecCommandApprovalParams; use codex_protocol::mcp_protocol::ExecCommandApprovalParams;
use codex_protocol::mcp_protocol::ExecCommandApprovalResponse; use codex_protocol::mcp_protocol::ExecCommandApprovalResponse;
use codex_protocol::mcp_protocol::ExecOneOffCommandParams; use codex_protocol::mcp_protocol::ExecOneOffCommandParams;
use codex_protocol::mcp_protocol::ExecOneOffCommandResponse;
use codex_protocol::mcp_protocol::FuzzyFileSearchParams; use codex_protocol::mcp_protocol::FuzzyFileSearchParams;
use codex_protocol::mcp_protocol::FuzzyFileSearchResponse; use codex_protocol::mcp_protocol::FuzzyFileSearchResponse;
use codex_protocol::mcp_protocol::GetUserAgentResponse; use codex_protocol::mcp_protocol::GetUserAgentResponse;
@@ -76,6 +74,7 @@ use codex_protocol::mcp_protocol::SendUserMessageResponse;
use codex_protocol::mcp_protocol::SendUserTurnParams; use codex_protocol::mcp_protocol::SendUserTurnParams;
use codex_protocol::mcp_protocol::SendUserTurnResponse; use codex_protocol::mcp_protocol::SendUserTurnResponse;
use codex_protocol::mcp_protocol::ServerNotification; use codex_protocol::mcp_protocol::ServerNotification;
use codex_protocol::mcp_protocol::ServerRequestPayload;
use codex_protocol::mcp_protocol::SessionConfiguredNotification; use codex_protocol::mcp_protocol::SessionConfiguredNotification;
use codex_protocol::mcp_protocol::SetDefaultModelParams; use codex_protocol::mcp_protocol::SetDefaultModelParams;
use codex_protocol::mcp_protocol::SetDefaultModelResponse; use codex_protocol::mcp_protocol::SetDefaultModelResponse;
@@ -632,7 +631,7 @@ impl CodexMessageProcessor {
.await .await
{ {
Ok(output) => { Ok(output) => {
let response = ExecArbitraryCommandResponse { let response = ExecOneOffCommandResponse {
exit_code: output.exit_code, exit_code: output.exit_code,
stdout: output.stdout.text, stdout: output.stdout.text,
stderr: output.stderr.text, stderr: output.stderr.text,
@@ -1268,9 +1267,8 @@ async fn apply_bespoke_event_handling(
reason, reason,
grant_root, grant_root,
}; };
let value = serde_json::to_value(&params).unwrap_or_default();
let rx = outgoing let rx = outgoing
.send_request(APPLY_PATCH_APPROVAL_METHOD, Some(value)) .send_request(ServerRequestPayload::ApplyPatchApproval(params))
.await; .await;
// TODO(mbolin): Enforce a timeout so this task does not live indefinitely? // TODO(mbolin): Enforce a timeout so this task does not live indefinitely?
tokio::spawn(async move { tokio::spawn(async move {
@@ -1290,9 +1288,8 @@ async fn apply_bespoke_event_handling(
cwd, cwd,
reason, reason,
}; };
let value = serde_json::to_value(&params).unwrap_or_default();
let rx = outgoing let rx = outgoing
.send_request(EXEC_COMMAND_APPROVAL_METHOD, Some(value)) .send_request(ServerRequestPayload::ExecCommandApproval(params))
.await; .await;
// TODO(mbolin): Enforce a timeout so this task does not live indefinitely? // TODO(mbolin): Enforce a timeout so this task does not live indefinitely?

View File

@@ -3,6 +3,7 @@ use std::sync::atomic::AtomicI64;
use std::sync::atomic::Ordering; use std::sync::atomic::Ordering;
use codex_protocol::mcp_protocol::ServerNotification; use codex_protocol::mcp_protocol::ServerNotification;
use codex_protocol::mcp_protocol::ServerRequestPayload;
use mcp_types::JSONRPC_VERSION; use mcp_types::JSONRPC_VERSION;
use mcp_types::JSONRPCError; use mcp_types::JSONRPCError;
use mcp_types::JSONRPCErrorError; use mcp_types::JSONRPCErrorError;
@@ -38,8 +39,7 @@ impl OutgoingMessageSender {
pub(crate) async fn send_request( pub(crate) async fn send_request(
&self, &self,
method: &str, request: ServerRequestPayload,
params: Option<serde_json::Value>,
) -> oneshot::Receiver<Result> { ) -> oneshot::Receiver<Result> {
let id = RequestId::Integer(self.next_request_id.fetch_add(1, Ordering::Relaxed)); let id = RequestId::Integer(self.next_request_id.fetch_add(1, Ordering::Relaxed));
let outgoing_message_id = id.clone(); let outgoing_message_id = id.clone();
@@ -49,6 +49,14 @@ impl OutgoingMessageSender {
request_id_to_callback.insert(id, tx_approve); request_id_to_callback.insert(id, tx_approve);
} }
let method = request.method();
let params_value = request.into_params_value();
let params = if params_value.is_null() {
None
} else {
Some(params_value)
};
let outgoing_message = OutgoingMessage::Request(OutgoingRequest { let outgoing_message = OutgoingMessage::Request(OutgoingRequest {
id: outgoing_message_id, id: outgoing_message_id,
method: method.to_string(), method: method.to_string(),

View File

@@ -26,6 +26,7 @@ use codex_protocol::mcp_protocol::RemoveConversationListenerParams;
use codex_protocol::mcp_protocol::ResumeConversationParams; use codex_protocol::mcp_protocol::ResumeConversationParams;
use codex_protocol::mcp_protocol::SendUserMessageParams; use codex_protocol::mcp_protocol::SendUserMessageParams;
use codex_protocol::mcp_protocol::SendUserTurnParams; use codex_protocol::mcp_protocol::SendUserTurnParams;
use codex_protocol::mcp_protocol::ServerRequest;
use codex_protocol::mcp_protocol::SetDefaultModelParams; use codex_protocol::mcp_protocol::SetDefaultModelParams;
use mcp_types::JSONRPC_VERSION; use mcp_types::JSONRPC_VERSION;
@@ -373,7 +374,7 @@ impl McpProcess {
Ok(message) Ok(message)
} }
pub async fn read_stream_until_request_message(&mut self) -> anyhow::Result<JSONRPCRequest> { pub async fn read_stream_until_request_message(&mut self) -> anyhow::Result<ServerRequest> {
eprintln!("in read_stream_until_request_message()"); eprintln!("in read_stream_until_request_message()");
loop { loop {
@@ -384,7 +385,9 @@ impl McpProcess {
eprintln!("notification: {message:?}"); eprintln!("notification: {message:?}");
} }
JSONRPCMessage::Request(jsonrpc_request) => { JSONRPCMessage::Request(jsonrpc_request) => {
return Ok(jsonrpc_request); return jsonrpc_request.try_into().with_context(
|| "failed to deserialize ServerRequest from JSONRPCRequest",
);
} }
JSONRPCMessage::Error(_) => { JSONRPCMessage::Error(_) => {
anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}"); anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}");

View File

@@ -12,7 +12,7 @@ use codex_core::protocol_config_types::ReasoningSummary;
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR; use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
use codex_protocol::mcp_protocol::AddConversationListenerParams; use codex_protocol::mcp_protocol::AddConversationListenerParams;
use codex_protocol::mcp_protocol::AddConversationSubscriptionResponse; use codex_protocol::mcp_protocol::AddConversationSubscriptionResponse;
use codex_protocol::mcp_protocol::EXEC_COMMAND_APPROVAL_METHOD; use codex_protocol::mcp_protocol::ExecCommandApprovalParams;
use codex_protocol::mcp_protocol::NewConversationParams; use codex_protocol::mcp_protocol::NewConversationParams;
use codex_protocol::mcp_protocol::NewConversationResponse; use codex_protocol::mcp_protocol::NewConversationResponse;
use codex_protocol::mcp_protocol::RemoveConversationListenerParams; use codex_protocol::mcp_protocol::RemoveConversationListenerParams;
@@ -21,6 +21,7 @@ use codex_protocol::mcp_protocol::SendUserMessageParams;
use codex_protocol::mcp_protocol::SendUserMessageResponse; use codex_protocol::mcp_protocol::SendUserMessageResponse;
use codex_protocol::mcp_protocol::SendUserTurnParams; use codex_protocol::mcp_protocol::SendUserTurnParams;
use codex_protocol::mcp_protocol::SendUserTurnResponse; use codex_protocol::mcp_protocol::SendUserTurnResponse;
use codex_protocol::mcp_protocol::ServerRequest;
use mcp_types::JSONRPCNotification; use mcp_types::JSONRPCNotification;
use mcp_types::JSONRPCResponse; use mcp_types::JSONRPCResponse;
use mcp_types::RequestId; use mcp_types::RequestId;
@@ -290,11 +291,28 @@ async fn test_send_user_turn_changes_approval_policy_behavior() {
.await .await
.expect("waiting for exec approval request timeout") .expect("waiting for exec approval request timeout")
.expect("exec approval request"); .expect("exec approval request");
assert_eq!(request.method, EXEC_COMMAND_APPROVAL_METHOD); let ServerRequest::ExecCommandApproval { request_id, params } = request else {
panic!("expected ExecCommandApproval request, got: {request:?}");
};
assert_eq!(
ExecCommandApprovalParams {
conversation_id,
call_id: "call1".to_string(),
command: vec![
"python3".to_string(),
"-c".to_string(),
"print(42)".to_string(),
],
cwd: working_directory.clone(),
reason: None,
},
params
);
// Approve so the first turn can complete // Approve so the first turn can complete
mcp.send_response( mcp.send_response(
request.id, request_id,
serde_json::json!({ "decision": codex_core::protocol::ReviewDecision::Approved }), serde_json::json!({ "decision": codex_core::protocol::ReviewDecision::Approved }),
) )
.await .await

View File

@@ -1,6 +1,12 @@
use anyhow::Context; use anyhow::Context;
use anyhow::Result; use anyhow::Result;
use anyhow::anyhow; use anyhow::anyhow;
use codex_protocol::mcp_protocol::ClientNotification;
use codex_protocol::mcp_protocol::ClientRequest;
use codex_protocol::mcp_protocol::ServerNotification;
use codex_protocol::mcp_protocol::ServerRequest;
use codex_protocol::mcp_protocol::export_client_responses;
use codex_protocol::mcp_protocol::export_server_responses;
use std::ffi::OsStr; use std::ffi::OsStr;
use std::fs; use std::fs;
use std::io::Read; use std::io::Read;
@@ -15,44 +21,17 @@ const HEADER: &str = "// GENERATED CODE! DO NOT MODIFY BY HAND!\n\n";
pub fn generate_ts(out_dir: &Path, prettier: Option<&Path>) -> Result<()> { pub fn generate_ts(out_dir: &Path, prettier: Option<&Path>) -> Result<()> {
ensure_dir(out_dir)?; ensure_dir(out_dir)?;
use codex_protocol::mcp_protocol::*; // Generate the TS bindings client -> server messages.
// Generating the TS bindings for these top-level enums ensures all types
// reachable from them will be generated by induction, so they do not need
// to be listed individually.
ClientRequest::export_all_to(out_dir)?; ClientRequest::export_all_to(out_dir)?;
export_client_responses(out_dir)?;
ClientNotification::export_all_to(out_dir)?; ClientNotification::export_all_to(out_dir)?;
// Generate the TS bindings server -> client messages.
ServerRequest::export_all_to(out_dir)?; ServerRequest::export_all_to(out_dir)?;
export_server_responses(out_dir)?;
ServerNotification::export_all_to(out_dir)?; ServerNotification::export_all_to(out_dir)?;
// Response types for ClientRequest (mirror enum order). // Generate index.ts that re-exports all types.
InitializeResponse::export_all_to(out_dir)?;
NewConversationResponse::export_all_to(out_dir)?;
ListConversationsResponse::export_all_to(out_dir)?;
ResumeConversationResponse::export_all_to(out_dir)?;
ArchiveConversationResponse::export_all_to(out_dir)?;
SendUserMessageResponse::export_all_to(out_dir)?;
SendUserTurnResponse::export_all_to(out_dir)?;
InterruptConversationResponse::export_all_to(out_dir)?;
AddConversationSubscriptionResponse::export_all_to(out_dir)?;
RemoveConversationSubscriptionResponse::export_all_to(out_dir)?;
GitDiffToRemoteResponse::export_all_to(out_dir)?;
LoginApiKeyResponse::export_all_to(out_dir)?;
LoginChatGptResponse::export_all_to(out_dir)?;
CancelLoginChatGptResponse::export_all_to(out_dir)?;
LogoutChatGptResponse::export_all_to(out_dir)?;
GetAuthStatusResponse::export_all_to(out_dir)?;
GetUserSavedConfigResponse::export_all_to(out_dir)?;
SetDefaultModelResponse::export_all_to(out_dir)?;
GetUserAgentResponse::export_all_to(out_dir)?;
UserInfoResponse::export_all_to(out_dir)?;
FuzzyFileSearchResponse::export_all_to(out_dir)?;
ExecArbitraryCommandResponse::export_all_to(out_dir)?;
// Response types for ServerRequest (mirror enum order).
ApplyPatchApprovalResponse::export_all_to(out_dir)?;
ExecCommandApprovalResponse::export_all_to(out_dir)?;
generate_index_ts(out_dir)?; generate_index_ts(out_dir)?;
// Prepend header to each generated .ts file // Prepend header to each generated .ts file

View File

@@ -16,6 +16,7 @@ icu_decimal = { workspace = true }
icu_locale_core = { workspace = true } icu_locale_core = { workspace = true }
mcp-types = { workspace = true } mcp-types = { workspace = true }
mime_guess = { workspace = true } mime_guess = { workspace = true }
paste = { workspace = true }
serde = { workspace = true, features = ["derive"] } serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true } serde_json = { workspace = true }
serde_with = { workspace = true, features = ["macros", "base64"] } serde_with = { workspace = true, features = ["macros", "base64"] }

View File

@@ -13,7 +13,9 @@ use crate::protocol::ReviewDecision;
use crate::protocol::SandboxPolicy; use crate::protocol::SandboxPolicy;
use crate::protocol::TurnAbortReason; use crate::protocol::TurnAbortReason;
use mcp_types::JSONRPCNotification; use mcp_types::JSONRPCNotification;
use mcp_types::JSONRPCRequest;
use mcp_types::RequestId; use mcp_types::RequestId;
use paste::paste;
use serde::Deserialize; use serde::Deserialize;
use serde::Serialize; use serde::Serialize;
use strum_macros::Display; use strum_macros::Display;
@@ -89,137 +91,137 @@ pub enum AuthMode {
ChatGPT, ChatGPT,
} }
/// Request from the client to the server. /// Generates an `enum ClientRequest` where each variant is a request that the
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] /// client can send to the server. Each variant has associated `params` and
#[serde(tag = "method", rename_all = "camelCase")] /// `response` types. Also generates a `export_client_responses()` function to
pub enum ClientRequest { /// export all response types to TypeScript.
macro_rules! client_request_definitions {
(
$(
$(#[$variant_meta:meta])*
$variant:ident {
params: $(#[$params_meta:meta])* $params:ty,
response: $response:ty,
}
),* $(,)?
) => {
/// Request from the client to the server.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(tag = "method", rename_all = "camelCase")]
pub enum ClientRequest {
$(
$(#[$variant_meta])*
$variant {
#[serde(rename = "id")]
request_id: RequestId,
$(#[$params_meta])*
params: $params,
},
)*
}
pub fn export_client_responses(
out_dir: &::std::path::Path,
) -> ::std::result::Result<(), ::ts_rs::ExportError> {
$(
<$response as ::ts_rs::TS>::export_all_to(out_dir)?;
)*
Ok(())
}
};
}
client_request_definitions! {
Initialize { Initialize {
#[serde(rename = "id")]
request_id: RequestId,
params: InitializeParams, params: InitializeParams,
response: InitializeResponse,
}, },
NewConversation { NewConversation {
#[serde(rename = "id")]
request_id: RequestId,
params: NewConversationParams, params: NewConversationParams,
response: NewConversationResponse,
}, },
/// List recorded Codex conversations (rollouts) with optional pagination and search. /// List recorded Codex conversations (rollouts) with optional pagination and search.
ListConversations { ListConversations {
#[serde(rename = "id")]
request_id: RequestId,
params: ListConversationsParams, params: ListConversationsParams,
response: ListConversationsResponse,
}, },
/// Resume a recorded Codex conversation from a rollout file. /// Resume a recorded Codex conversation from a rollout file.
ResumeConversation { ResumeConversation {
#[serde(rename = "id")]
request_id: RequestId,
params: ResumeConversationParams, params: ResumeConversationParams,
response: ResumeConversationResponse,
}, },
ArchiveConversation { ArchiveConversation {
#[serde(rename = "id")]
request_id: RequestId,
params: ArchiveConversationParams, params: ArchiveConversationParams,
response: ArchiveConversationResponse,
}, },
SendUserMessage { SendUserMessage {
#[serde(rename = "id")]
request_id: RequestId,
params: SendUserMessageParams, params: SendUserMessageParams,
response: SendUserMessageResponse,
}, },
SendUserTurn { SendUserTurn {
#[serde(rename = "id")]
request_id: RequestId,
params: SendUserTurnParams, params: SendUserTurnParams,
response: SendUserTurnResponse,
}, },
InterruptConversation { InterruptConversation {
#[serde(rename = "id")]
request_id: RequestId,
params: InterruptConversationParams, params: InterruptConversationParams,
response: InterruptConversationResponse,
}, },
AddConversationListener { AddConversationListener {
#[serde(rename = "id")]
request_id: RequestId,
params: AddConversationListenerParams, params: AddConversationListenerParams,
response: AddConversationSubscriptionResponse,
}, },
RemoveConversationListener { RemoveConversationListener {
#[serde(rename = "id")]
request_id: RequestId,
params: RemoveConversationListenerParams, params: RemoveConversationListenerParams,
response: RemoveConversationSubscriptionResponse,
}, },
GitDiffToRemote { GitDiffToRemote {
#[serde(rename = "id")]
request_id: RequestId,
params: GitDiffToRemoteParams, params: GitDiffToRemoteParams,
response: GitDiffToRemoteResponse,
}, },
LoginApiKey { LoginApiKey {
#[serde(rename = "id")]
request_id: RequestId,
params: LoginApiKeyParams, params: LoginApiKeyParams,
response: LoginApiKeyResponse,
}, },
LoginChatGpt { LoginChatGpt {
#[serde(rename = "id")] params: #[ts(type = "undefined")] #[serde(skip_serializing_if = "Option::is_none")] Option<()>,
request_id: RequestId, response: LoginChatGptResponse,
#[ts(type = "undefined")]
#[serde(skip_serializing_if = "Option::is_none")]
params: Option<()>,
}, },
CancelLoginChatGpt { CancelLoginChatGpt {
#[serde(rename = "id")]
request_id: RequestId,
params: CancelLoginChatGptParams, params: CancelLoginChatGptParams,
response: CancelLoginChatGptResponse,
}, },
LogoutChatGpt { LogoutChatGpt {
#[serde(rename = "id")] params: #[ts(type = "undefined")] #[serde(skip_serializing_if = "Option::is_none")] Option<()>,
request_id: RequestId, response: LogoutChatGptResponse,
#[ts(type = "undefined")]
#[serde(skip_serializing_if = "Option::is_none")]
params: Option<()>,
}, },
GetAuthStatus { GetAuthStatus {
#[serde(rename = "id")]
request_id: RequestId,
params: GetAuthStatusParams, params: GetAuthStatusParams,
response: GetAuthStatusResponse,
}, },
GetUserSavedConfig { GetUserSavedConfig {
#[serde(rename = "id")] params: #[ts(type = "undefined")] #[serde(skip_serializing_if = "Option::is_none")] Option<()>,
request_id: RequestId, response: GetUserSavedConfigResponse,
#[ts(type = "undefined")]
#[serde(skip_serializing_if = "Option::is_none")]
params: Option<()>,
}, },
SetDefaultModel { SetDefaultModel {
#[serde(rename = "id")]
request_id: RequestId,
params: SetDefaultModelParams, params: SetDefaultModelParams,
response: SetDefaultModelResponse,
}, },
GetUserAgent { GetUserAgent {
#[serde(rename = "id")] params: #[ts(type = "undefined")] #[serde(skip_serializing_if = "Option::is_none")] Option<()>,
request_id: RequestId, response: GetUserAgentResponse,
#[ts(type = "undefined")]
#[serde(skip_serializing_if = "Option::is_none")]
params: Option<()>,
}, },
UserInfo { UserInfo {
#[serde(rename = "id")] params: #[ts(type = "undefined")] #[serde(skip_serializing_if = "Option::is_none")] Option<()>,
request_id: RequestId, response: UserInfoResponse,
#[ts(type = "undefined")]
#[serde(skip_serializing_if = "Option::is_none")]
params: Option<()>,
}, },
FuzzyFileSearch { FuzzyFileSearch {
#[serde(rename = "id")]
request_id: RequestId,
params: FuzzyFileSearchParams, params: FuzzyFileSearchParams,
response: FuzzyFileSearchResponse,
}, },
/// Execute a command (argv vector) under the server's sandbox. /// Execute a command (argv vector) under the server's sandbox.
ExecOneOffCommand { ExecOneOffCommand {
#[serde(rename = "id")]
request_id: RequestId,
params: ExecOneOffCommandParams, params: ExecOneOffCommandParams,
response: ExecOneOffCommandResponse,
}, },
} }
@@ -449,7 +451,7 @@ pub struct ExecOneOffCommandParams {
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct ExecArbitraryCommandResponse { pub struct ExecOneOffCommandResponse {
pub exit_code: i32, pub exit_code: i32,
pub stdout: String, pub stdout: String,
pub stderr: String, pub stderr: String,
@@ -653,30 +655,102 @@ pub enum InputItem {
}, },
} }
// TODO(mbolin): Need test to ensure these constants match the enum variants. /// Generates an `enum ServerRequest` where each variant is a request that the
/// server can send to the client along with the corresponding params and
/// response types. It also generates helper types used by the app/server
/// infrastructure (method constants, payload enum, and export helpers).
macro_rules! server_request_definitions {
(
$(
$(#[$variant_meta:meta])*
$variant:ident => $method:literal
),* $(,)?
) => {
paste! {
$(pub const [<$variant:snake:upper _METHOD>]: &str = $method;)*
pub const APPLY_PATCH_APPROVAL_METHOD: &str = "applyPatchApproval"; /// Method names for server-initiated requests (camelCase to match JSON-RPC).
pub const EXEC_COMMAND_APPROVAL_METHOD: &str = "execCommandApproval"; #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServerRequestMethod {
$( $variant ),*
}
/// Request initiated from the server and sent to the client. impl ServerRequestMethod {
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] pub const fn as_str(self) -> &'static str {
#[serde(tag = "method", rename_all = "camelCase")] match self {
pub enum ServerRequest { $(ServerRequestMethod::$variant => $method,)*
}
}
}
/// Request initiated from the server and sent to the client.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(tag = "method", rename_all = "camelCase")]
pub enum ServerRequest {
$(
$(#[$variant_meta])*
$variant {
#[serde(rename = "id")]
request_id: RequestId,
params: [<$variant Params>],
},
)*
}
#[derive(Debug, Clone, PartialEq)]
pub enum ServerRequestPayload {
$( $variant([<$variant Params>]), )*
}
impl ServerRequestPayload {
pub fn method(&self) -> &'static str {
match self {
$(Self::$variant(..) => $method,)*
}
}
pub fn into_params_value(self) -> serde_json::Value {
match self {
$(Self::$variant(params) => serde_json::to_value(params).unwrap_or_default(),)*
}
}
pub fn into_request(self, request_id: RequestId) -> ServerRequest {
match self {
$(Self::$variant(params) => ServerRequest::$variant { request_id, params },)*
}
}
}
}
pub fn export_server_responses(
out_dir: &::std::path::Path,
) -> ::std::result::Result<(), ::ts_rs::ExportError> {
paste! {
$(<[<$variant Response>] as ::ts_rs::TS>::export_all_to(out_dir)?;)*
}
Ok(())
}
};
}
impl TryFrom<JSONRPCRequest> for ServerRequest {
type Error = serde_json::Error;
fn try_from(value: JSONRPCRequest) -> Result<Self, Self::Error> {
serde_json::from_value(serde_json::to_value(value)?)
}
}
server_request_definitions! {
/// Request to approve a patch. /// Request to approve a patch.
ApplyPatchApproval { ApplyPatchApproval => "applyPatchApproval",
#[serde(rename = "id")]
request_id: RequestId,
params: ApplyPatchApprovalParams,
},
/// Request to exec a command. /// Request to exec a command.
ExecCommandApproval { ExecCommandApproval => "execCommandApproval",
#[serde(rename = "id")]
request_id: RequestId,
params: ExecCommandApprovalParams,
},
} }
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(rename_all = "camelCase")]
pub struct ApplyPatchApprovalParams { pub struct ApplyPatchApprovalParams {
pub conversation_id: ConversationId, pub conversation_id: ConversationId,
/// Use to correlate this with [codex_core::protocol::PatchApplyBeginEvent] /// Use to correlate this with [codex_core::protocol::PatchApplyBeginEvent]
@@ -693,6 +767,7 @@ pub struct ApplyPatchApprovalParams {
} }
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(rename_all = "camelCase")]
pub struct ExecCommandApprovalParams { pub struct ExecCommandApprovalParams {
pub conversation_id: ConversationId, pub conversation_id: ConversationId,
/// Use to correlate this with [codex_core::protocol::ExecCommandBeginEvent] /// Use to correlate this with [codex_core::protocol::ExecCommandBeginEvent]
@@ -766,6 +841,7 @@ pub struct SessionConfiguredNotification {
pub history_log_id: u64, pub history_log_id: u64,
/// Current number of entries in the history log. /// Current number of entries in the history log.
#[ts(type = "number")]
pub history_entry_count: usize, pub history_entry_count: usize,
/// Optional initial messages (as events) for resumed sessions. /// Optional initial messages (as events) for resumed sessions.
@@ -903,4 +979,40 @@ mod tests {
); );
Ok(()) Ok(())
} }
#[test]
fn serialize_server_request() -> Result<()> {
let conversation_id = ConversationId::from_string("67e55044-10b1-426f-9247-bb680e5fe0c8")?;
let params = ExecCommandApprovalParams {
conversation_id,
call_id: "call-42".to_string(),
command: vec!["echo".to_string(), "hello".to_string()],
cwd: PathBuf::from("/tmp"),
reason: Some("because tests".to_string()),
};
let request = ServerRequest::ExecCommandApproval {
request_id: RequestId::Integer(7),
params: params.clone(),
};
assert_eq!(
json!({
"method": "execCommandApproval",
"id": 7,
"params": {
"conversationId": "67e55044-10b1-426f-9247-bb680e5fe0c8",
"callId": "call-42",
"command": ["echo", "hello"],
"cwd": "/tmp",
"reason": "because tests",
}
}),
serde_json::to_value(&request)?,
);
let payload = ServerRequestPayload::ExecCommandApproval(params);
assert_eq!("execCommandApproval", EXEC_COMMAND_APPROVAL_METHOD);
assert_eq!(EXEC_COMMAND_APPROVAL_METHOD, payload.method());
Ok(())
}
} }