diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 6a8e76dd..9c604e79 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -807,6 +807,7 @@ dependencies = [ "toml 0.9.1", "tracing", "tracing-subscriber", + "uuid", "wiremock", ] diff --git a/codex-rs/cli/src/proto.rs b/codex-rs/cli/src/proto.rs index 14869955..ec395dd1 100644 --- a/codex-rs/cli/src/proto.rs +++ b/codex-rs/cli/src/proto.rs @@ -35,7 +35,7 @@ pub async fn run_main(opts: ProtoCli) -> anyhow::Result<()> { let config = Config::load_with_cli_overrides(overrides_vec, ConfigOverrides::default())?; let ctrl_c = notify_on_sigint(); - let (codex, _init_id) = Codex::spawn(config, ctrl_c.clone()).await?; + let (codex, _init_id, _session_id) = Codex::spawn(config, ctrl_c.clone()).await?; let codex = Arc::new(codex); // Task that reads JSON lines from stdin and forwards to Submission Queue diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index d23981b9..392e84ea 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -101,7 +101,7 @@ impl Codex { /// Spawn a new [`Codex`] and initialize the session. Returns the instance /// of `Codex` and the ID of the `SessionInitialized` event that was /// submitted to start the session. - pub async fn spawn(config: Config, ctrl_c: Arc) -> CodexResult<(Codex, String)> { + pub async fn spawn(config: Config, ctrl_c: Arc) -> CodexResult<(Codex, String, Uuid)> { // experimental resume path (undocumented) let resume_path = config.experimental_resume.clone(); info!("resume_path: {resume_path:?}"); @@ -124,7 +124,12 @@ impl Codex { }; let config = Arc::new(config); - tokio::spawn(submission_loop(config, rx_sub, tx_event, ctrl_c)); + + // Generate a unique ID for the lifetime of this Codex session. + let session_id = Uuid::new_v4(); + tokio::spawn(submission_loop( + session_id, config, rx_sub, tx_event, ctrl_c, + )); let codex = Codex { next_id: AtomicU64::new(0), tx_sub, @@ -132,7 +137,7 @@ impl Codex { }; let init_id = codex.submit(configure_session).await?; - Ok((codex, init_id)) + Ok((codex, init_id, session_id)) } /// Submit the `op` wrapped in a `Submission` with a unique ID. @@ -521,14 +526,12 @@ impl AgentTask { } async fn submission_loop( + mut session_id: Uuid, config: Arc, rx_sub: Receiver, tx_event: Sender, ctrl_c: Arc, ) { - // Generate a unique ID for the lifetime of this Codex session. - let mut session_id = Uuid::new_v4(); - let mut sess: Option> = None; // shorthand - send an event when there is no active session let send_no_session_event = |sub_id: String| async { diff --git a/codex-rs/core/src/codex_wrapper.rs b/codex-rs/core/src/codex_wrapper.rs index f2ece22d..31f8295e 100644 --- a/codex-rs/core/src/codex_wrapper.rs +++ b/codex-rs/core/src/codex_wrapper.rs @@ -6,15 +6,16 @@ use crate::protocol::Event; use crate::protocol::EventMsg; use crate::util::notify_on_sigint; use tokio::sync::Notify; +use uuid::Uuid; /// Spawn a new [`Codex`] and initialize the session. /// /// Returns the wrapped [`Codex`] **and** the `SessionInitialized` event that /// is received as a response to the initial `ConfigureSession` submission so /// that callers can surface the information to the UI. -pub async fn init_codex(config: Config) -> anyhow::Result<(Codex, Event, Arc)> { +pub async fn init_codex(config: Config) -> anyhow::Result<(Codex, Event, Arc, Uuid)> { let ctrl_c = notify_on_sigint(); - let (codex, init_id) = Codex::spawn(config, ctrl_c.clone()).await?; + let (codex, init_id, session_id) = Codex::spawn(config, ctrl_c.clone()).await?; // The first event must be `SessionInitialized`. Validate and forward it to // the caller so that they can display it in the conversation history. @@ -33,5 +34,5 @@ pub async fn init_codex(config: Config) -> anyhow::Result<(Codex, Event, Arc Result { let mut config = load_default_config_for_test(&codex_home); config.model_provider.request_max_retries = Some(2); config.model_provider.stream_max_retries = Some(2); - let (agent, _init_id) = Codex::spawn(config, std::sync::Arc::new(Notify::new())).await?; + let (agent, _init_id, _session_id) = + Codex::spawn(config, std::sync::Arc::new(Notify::new())).await?; Ok(agent) } diff --git a/codex-rs/core/tests/previous_response_id.rs b/codex-rs/core/tests/previous_response_id.rs index 9630cc10..6523c764 100644 --- a/codex-rs/core/tests/previous_response_id.rs +++ b/codex-rs/core/tests/previous_response_id.rs @@ -113,7 +113,7 @@ async fn keeps_previous_response_id_between_tasks() { let mut config = load_default_config_for_test(&codex_home); config.model_provider = model_provider; let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new()); - let (codex, _init_id) = Codex::spawn(config, ctrl_c.clone()).await.unwrap(); + let (codex, _init_id, _session_id) = Codex::spawn(config, ctrl_c.clone()).await.unwrap(); // Task 1 – triggers first request (no previous_response_id) codex diff --git a/codex-rs/core/tests/stream_no_completed.rs b/codex-rs/core/tests/stream_no_completed.rs index f2de5de1..1a0455be 100644 --- a/codex-rs/core/tests/stream_no_completed.rs +++ b/codex-rs/core/tests/stream_no_completed.rs @@ -95,7 +95,7 @@ async fn retries_on_early_close() { let codex_home = TempDir::new().unwrap(); let mut config = load_default_config_for_test(&codex_home); config.model_provider = model_provider; - let (codex, _init_id) = Codex::spawn(config, ctrl_c).await.unwrap(); + let (codex, _init_id, _session_id) = Codex::spawn(config, ctrl_c).await.unwrap(); codex .submit(Op::UserInput { diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs index b557c893..769d3c3b 100644 --- a/codex-rs/exec/src/lib.rs +++ b/codex-rs/exec/src/lib.rs @@ -153,7 +153,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any .with_writer(std::io::stderr) .try_init(); - let (codex_wrapper, event, ctrl_c) = codex_wrapper::init_codex(config).await?; + let (codex_wrapper, event, ctrl_c, _session_id) = codex_wrapper::init_codex(config).await?; let codex = Arc::new(codex_wrapper); info!("Codex initialized with event: {event:?}"); diff --git a/codex-rs/mcp-server/Cargo.toml b/codex-rs/mcp-server/Cargo.toml index f43b101b..e524576a 100644 --- a/codex-rs/mcp-server/Cargo.toml +++ b/codex-rs/mcp-server/Cargo.toml @@ -33,6 +33,7 @@ tokio = { version = "1", features = [ "rt-multi-thread", "signal", ] } +uuid = { version = "1", features = ["serde", "v4"] } [dev-dependencies] assert_cmd = "2" diff --git a/codex-rs/mcp-server/src/codex_tool_config.rs b/codex-rs/mcp-server/src/codex_tool_config.rs index 9a31dbcc..54d108c0 100644 --- a/codex-rs/mcp-server/src/codex_tool_config.rs +++ b/codex-rs/mcp-server/src/codex_tool_config.rs @@ -160,6 +160,47 @@ impl CodexToolCallParam { } } +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "camelCase")] +pub(crate) struct CodexToolCallReplyParam { + /// The *session id* for this conversation. + pub session_id: String, + + /// The *next user prompt* to continue the Codex conversation. + pub prompt: String, +} + +/// Builds a `Tool` definition for the `codex-reply` tool-call. +pub(crate) fn create_tool_for_codex_tool_call_reply_param() -> Tool { + let schema = SchemaSettings::draft2019_09() + .with(|s| { + s.inline_subschemas = true; + s.option_add_null_type = false; + }) + .into_generator() + .into_root_schema_for::(); + + #[expect(clippy::expect_used)] + let schema_value = + serde_json::to_value(&schema).expect("Codex reply tool schema should serialise to JSON"); + + let tool_input_schema = + serde_json::from_value::(schema_value).unwrap_or_else(|e| { + panic!("failed to create Tool from schema: {e}"); + }); + + Tool { + name: "codex-reply".to_string(), + title: Some("Codex Reply".to_string()), + input_schema: tool_input_schema, + output_schema: None, + description: Some( + "Continue a Codex session by providing the session id and prompt.".to_string(), + ), + annotations: None, + } +} + #[cfg(test)] mod tests { use super::*; @@ -235,4 +276,34 @@ mod tests { }); assert_eq!(expected_tool_json, tool_json); } + + #[test] + fn verify_codex_tool_reply_json_schema() { + let tool = create_tool_for_codex_tool_call_reply_param(); + #[expect(clippy::expect_used)] + let tool_json = serde_json::to_value(&tool).expect("tool serializes"); + let expected_tool_json = serde_json::json!({ + "description": "Continue a Codex session by providing the session id and prompt.", + "inputSchema": { + "properties": { + "prompt": { + "description": "The *next user prompt* to continue the Codex conversation.", + "type": "string" + }, + "sessionId": { + "description": "The *session id* for this conversation.", + "type": "string" + }, + }, + "required": [ + "prompt", + "sessionId", + ], + "type": "object", + }, + "name": "codex-reply", + "title": "Codex Reply", + }); + assert_eq!(expected_tool_json, tool_json); + } } diff --git a/codex-rs/mcp-server/src/codex_tool_runner.rs b/codex-rs/mcp-server/src/codex_tool_runner.rs index 163055de..3893a485 100644 --- a/codex-rs/mcp-server/src/codex_tool_runner.rs +++ b/codex-rs/mcp-server/src/codex_tool_runner.rs @@ -2,6 +2,7 @@ //! Tokio task. Separated from `message_processor.rs` to keep that file small //! and to make future feature-growth easier to manage. +use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; @@ -27,7 +28,9 @@ use mcp_types::TextContent; use serde::Deserialize; use serde::Serialize; use serde_json::json; +use tokio::sync::Mutex; use tracing::error; +use uuid::Uuid; use crate::outgoing_message::OutgoingMessageSender; @@ -42,8 +45,9 @@ pub async fn run_codex_tool_session( initial_prompt: String, config: CodexConfig, outgoing: Arc, + session_map: Arc>>>, ) { - let (codex, first_event, _ctrl_c) = match init_codex(config).await { + let (codex, first_event, _ctrl_c, session_id) = match init_codex(config).await { Ok(res) => res, Err(e) => { let result = CallToolResult { @@ -61,6 +65,11 @@ pub async fn run_codex_tool_session( }; let codex = Arc::new(codex); + // update the session map so we can retrieve the session in a reply, and then drop it, since + // we no longer need it for this function + session_map.lock().await.insert(session_id, codex.clone()); + drop(session_map); + // Send initial SessionConfigured event. outgoing.send_event_as_notification(&first_event).await; @@ -85,6 +94,37 @@ pub async fn run_codex_tool_session( tracing::error!("Failed to submit initial prompt: {e}"); } + run_codex_tool_session_inner(codex, outgoing, id).await; +} + +pub async fn run_codex_tool_session_reply( + codex: Arc, + outgoing: Arc, + request_id: RequestId, + prompt: String, +) { + if let Err(e) = codex + .submit(Op::UserInput { + items: vec![InputItem::Text { text: prompt }], + }) + .await + { + tracing::error!("Failed to submit user input: {e}"); + } + + run_codex_tool_session_inner(codex, outgoing, request_id).await; +} + +async fn run_codex_tool_session_inner( + codex: Arc, + outgoing: Arc, + request_id: RequestId, +) { + let sub_id = match &request_id { + RequestId::String(s) => s.clone(), + RequestId::Integer(n) => n.to_string(), + }; + // Stream events until the task needs to pause for user interaction or // completes. loop { @@ -128,7 +168,7 @@ pub async fn run_codex_tool_session( outgoing .send_error( - id.clone(), + request_id.clone(), JSONRPCErrorError { code: INVALID_PARAMS_ERROR_CODE, message, @@ -168,7 +208,9 @@ pub async fn run_codex_tool_session( is_error: None, structured_content: None, }; - outgoing.send_response(id.clone(), result.into()).await; + outgoing + .send_response(request_id.clone(), result.into()) + .await; // Continue, don't break so the session continues. continue; } @@ -186,7 +228,9 @@ pub async fn run_codex_tool_session( is_error: None, structured_content: None, }; - outgoing.send_response(id.clone(), result.into()).await; + outgoing + .send_response(request_id.clone(), result.into()) + .await; break; } EventMsg::SessionConfigured(_) => { @@ -234,7 +278,9 @@ pub async fn run_codex_tool_session( // structured way. structured_content: None, }; - outgoing.send_response(id.clone(), result.into()).await; + outgoing + .send_response(request_id.clone(), result.into()) + .await; break; } } diff --git a/codex-rs/mcp-server/src/message_processor.rs b/codex-rs/mcp-server/src/message_processor.rs index 61c320ed..e72a52e0 100644 --- a/codex-rs/mcp-server/src/message_processor.rs +++ b/codex-rs/mcp-server/src/message_processor.rs @@ -1,10 +1,14 @@ +use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; use crate::codex_tool_config::CodexToolCallParam; +use crate::codex_tool_config::CodexToolCallReplyParam; use crate::codex_tool_config::create_tool_for_codex_tool_call_param; +use crate::codex_tool_config::create_tool_for_codex_tool_call_reply_param; use crate::outgoing_message::OutgoingMessageSender; +use codex_core::Codex; use codex_core::config::Config as CodexConfig; use mcp_types::CallToolRequestParams; use mcp_types::CallToolResult; @@ -22,12 +26,15 @@ use mcp_types::ServerCapabilitiesTools; use mcp_types::ServerNotification; use mcp_types::TextContent; use serde_json::json; +use tokio::sync::Mutex; use tokio::task; +use uuid::Uuid; pub(crate) struct MessageProcessor { outgoing: Arc, initialized: bool, codex_linux_sandbox_exe: Option, + session_map: Arc>>>, } impl MessageProcessor { @@ -41,6 +48,7 @@ impl MessageProcessor { outgoing: Arc::new(outgoing), initialized: false, codex_linux_sandbox_exe, + session_map: Arc::new(Mutex::new(HashMap::new())), } } @@ -272,7 +280,10 @@ impl MessageProcessor { ) { tracing::trace!("tools/list -> {params:?}"); let result = ListToolsResult { - tools: vec![create_tool_for_codex_tool_call_param()], + tools: vec![ + create_tool_for_codex_tool_call_param(), + create_tool_for_codex_tool_call_reply_param(), + ], next_cursor: None, }; @@ -288,23 +299,29 @@ impl MessageProcessor { tracing::info!("tools/call -> params: {:?}", params); let CallToolRequestParams { name, arguments } = params; - // We only support the "codex" tool for now. - if name != "codex" { - // Tool not found – return error result so the LLM can react. - let result = CallToolResult { - content: vec![ContentBlock::TextContent(TextContent { - r#type: "text".to_string(), - text: format!("Unknown tool '{name}'"), - annotations: None, - })], - is_error: Some(true), - structured_content: None, - }; - self.send_response::(id, result) - .await; - return; + match name.as_str() { + "codex" => self.handle_tool_call_codex(id, arguments).await, + "codex-reply" => { + self.handle_tool_call_codex_session_reply(id, arguments) + .await + } + _ => { + let result = CallToolResult { + content: vec![ContentBlock::TextContent(TextContent { + r#type: "text".to_string(), + text: format!("Unknown tool '{name}'"), + annotations: None, + })], + is_error: Some(true), + structured_content: None, + }; + self.send_response::(id, result) + .await; + } } + } + async fn handle_tool_call_codex(&self, id: RequestId, arguments: Option) { let (initial_prompt, config): (String, CodexConfig) = match arguments { Some(json_val) => match serde_json::from_value::(json_val) { Ok(tool_cfg) => match tool_cfg.into_config(self.codex_linux_sandbox_exe.clone()) { @@ -359,15 +376,127 @@ impl MessageProcessor { } }; - // Clone outgoing sender to move into async task. + // Clone outgoing and session map to move into async task. let outgoing = self.outgoing.clone(); + let session_map = self.session_map.clone(); // Spawn an async task to handle the Codex session so that we do not // block the synchronous message-processing loop. task::spawn(async move { // Run the Codex session and stream events back to the client. - crate::codex_tool_runner::run_codex_tool_session(id, initial_prompt, config, outgoing) - .await; + crate::codex_tool_runner::run_codex_tool_session( + id, + initial_prompt, + config, + outgoing, + session_map, + ) + .await; + }); + } + + async fn handle_tool_call_codex_session_reply( + &self, + request_id: RequestId, + arguments: Option, + ) { + tracing::info!("tools/call -> params: {:?}", arguments); + + // parse arguments + let CodexToolCallReplyParam { session_id, prompt } = match arguments { + Some(json_val) => match serde_json::from_value::(json_val) { + Ok(params) => params, + Err(e) => { + tracing::error!("Failed to parse Codex tool call reply parameters: {e}"); + let result = CallToolResult { + content: vec![ContentBlock::TextContent(TextContent { + r#type: "text".to_owned(), + text: format!("Failed to parse configuration for Codex tool: {e}"), + annotations: None, + })], + is_error: Some(true), + structured_content: None, + }; + self.send_response::(request_id, result) + .await; + return; + } + }, + None => { + tracing::error!( + "Missing arguments for codex-reply tool-call; the `session_id` and `prompt` fields are required." + ); + let result = CallToolResult { + content: vec![ContentBlock::TextContent(TextContent { + r#type: "text".to_owned(), + text: "Missing arguments for codex-reply tool-call; the `session_id` and `prompt` fields are required.".to_owned(), + annotations: None, + })], + is_error: Some(true), + structured_content: None, + }; + self.send_response::(request_id, result) + .await; + return; + } + }; + let session_id = match Uuid::parse_str(&session_id) { + Ok(id) => id, + Err(e) => { + tracing::error!("Failed to parse session_id: {e}"); + let result = CallToolResult { + content: vec![ContentBlock::TextContent(TextContent { + r#type: "text".to_owned(), + text: format!("Failed to parse session_id: {e}"), + annotations: None, + })], + is_error: Some(true), + structured_content: None, + }; + self.send_response::(request_id, result) + .await; + return; + } + }; + + // load codex from session map + let session_map_mutex = Arc::clone(&self.session_map); + + // Clone outgoing and session map to move into async task. + let outgoing = self.outgoing.clone(); + + // Spawn an async task to handle the Codex session so that we do not + // block the synchronous message-processing loop. + task::spawn(async move { + let session_map = session_map_mutex.lock().await; + let codex = match session_map.get(&session_id) { + Some(codex) => codex, + None => { + tracing::warn!("Session not found for session_id: {session_id}"); + let result = CallToolResult { + content: vec![ContentBlock::TextContent(TextContent { + r#type: "text".to_owned(), + text: format!("Session not found for session_id: {session_id}"), + annotations: None, + })], + is_error: Some(true), + structured_content: None, + }; + // unwrap_or_default is fine here because we know the result is valid JSON + outgoing + .send_response(request_id, serde_json::to_value(result).unwrap_or_default()) + .await; + return; + } + }; + + crate::codex_tool_runner::run_codex_tool_session_reply( + codex.clone(), + outgoing, + request_id, + prompt.clone(), + ) + .await; }); } diff --git a/codex-rs/tui/src/chatwidget.rs b/codex-rs/tui/src/chatwidget.rs index c22bbf97..c70c6f6d 100644 --- a/codex-rs/tui/src/chatwidget.rs +++ b/codex-rs/tui/src/chatwidget.rs @@ -96,14 +96,15 @@ impl ChatWidget<'_> { // Create the Codex asynchronously so the UI loads as quickly as possible. let config_for_agent_loop = config.clone(); tokio::spawn(async move { - let (codex, session_event, _ctrl_c) = match init_codex(config_for_agent_loop).await { - Ok(vals) => vals, - Err(e) => { - // TODO: surface this error to the user. - tracing::error!("failed to initialize codex: {e}"); - return; - } - }; + let (codex, session_event, _ctrl_c, _session_id) = + match init_codex(config_for_agent_loop).await { + Ok(vals) => vals, + Err(e) => { + // TODO: surface this error to the user. + tracing::error!("failed to initialize codex: {e}"); + return; + } + }; // Forward the captured `SessionInitialized` event that was consumed // inside `init_codex()` so it can be rendered in the UI.