From 350b00d54b91a84a3b240e7ad9ac46300b009f21 Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Sun, 17 Aug 2025 10:03:52 -0700 Subject: [PATCH] Added MCP server command to enable authentication using ChatGPT (#2373) This PR adds two new APIs for the MCP server: 1) loginChatGpt, and 2) cancelLoginChatGpt. The first starts a login server and returns a local URL that allows for browser-based authentication, and the second provides a way to cancel the login attempt. If the login attempt succeeds, a notification (in the form of an event) is sent to a subscriber. I also added a timeout mechanism for the existing login server. The loginChatGpt code path uses a 10-minute timeout by default, so if the user fails to complete the login flow in that timeframe, the login server automatically shuts down. I tested the timeout code by manually setting the timeout to a much lower number and confirming that it works as expected when used e2e. --- codex-rs/Cargo.lock | 1 + codex-rs/login/src/lib.rs | 1 + codex-rs/login/src/server.rs | 98 ++++++++++- codex-rs/login/tests/login_server_e2e.rs | 2 + codex-rs/mcp-server/Cargo.toml | 1 + .../mcp-server/src/codex_message_processor.rs | 153 ++++++++++++++++++ codex-rs/mcp-server/src/wire_format.rs | 39 +++++ 7 files changed, 290 insertions(+), 5 deletions(-) diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 67f0199c..f118cb67 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -855,6 +855,7 @@ dependencies = [ "assert_cmd", "codex-arg0", "codex-core", + "codex-login", "mcp-types", "mcp_test_support", "pretty_assertions", diff --git a/codex-rs/login/src/lib.rs b/codex-rs/login/src/lib.rs index 7a5f0277..80fc0e82 100644 --- a/codex-rs/login/src/lib.rs +++ b/codex-rs/login/src/lib.rs @@ -18,6 +18,7 @@ use std::time::Duration; pub use crate::server::LoginServer; pub use crate::server::ServerOptions; +pub use crate::server::ShutdownHandle; pub use crate::server::run_login_server; pub use crate::token_data::TokenData; use crate::token_data::parse_id_token; diff --git a/codex-rs/login/src/server.rs b/codex-rs/login/src/server.rs index 9365905f..566b562d 100644 --- a/codex-rs/login/src/server.rs +++ b/codex-rs/login/src/server.rs @@ -4,7 +4,9 @@ use std::path::PathBuf; use std::sync::Arc; use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering; +use std::sync::mpsc; use std::thread; +use std::time::Duration; use crate::AuthDotJson; use crate::get_auth_file; @@ -27,6 +29,7 @@ pub struct ServerOptions { pub port: u16, pub open_browser: bool, pub force_state: Option, + pub login_timeout: Option, } impl ServerOptions { @@ -38,16 +41,17 @@ impl ServerOptions { port: DEFAULT_PORT, open_browser: true, force_state: None, + login_timeout: None, } } } -#[derive(Debug)] pub struct LoginServer { pub auth_url: String, pub actual_port: u16, pub server_handle: thread::JoinHandle>, pub shutdown_flag: Arc, + pub server: Arc, } impl LoginServer { @@ -59,8 +63,32 @@ impl LoginServer { } pub fn cancel(&self) { - self.shutdown_flag.store(true, Ordering::SeqCst); + shutdown(&self.shutdown_flag, &self.server); } + + pub fn cancel_handle(&self) -> ShutdownHandle { + ShutdownHandle { + shutdown_flag: self.shutdown_flag.clone(), + server: self.server.clone(), + } + } +} + +#[derive(Clone)] +pub struct ShutdownHandle { + shutdown_flag: Arc, + server: Arc, +} + +impl ShutdownHandle { + pub fn cancel(&self) { + shutdown(&self.shutdown_flag, &self.server); + } +} + +pub fn shutdown(shutdown_flag: &AtomicBool, server: &Server) { + shutdown_flag.store(true, Ordering::SeqCst); + server.unblock(); } pub fn run_login_server( @@ -80,6 +108,7 @@ pub fn run_login_server( )); } }; + let server = Arc::new(server); let redirect_uri = format!("http://localhost:{actual_port}/auth/callback"); let auth_url = build_authorize_url(&opts.issuer, &opts.client_id, &redirect_uri, &pkce, &state); @@ -89,11 +118,35 @@ pub fn run_login_server( } let shutdown_flag = shutdown_flag.unwrap_or_else(|| Arc::new(AtomicBool::new(false))); let shutdown_flag_clone = shutdown_flag.clone(); + let timeout_flag = Arc::new(AtomicBool::new(false)); + + // Channel used to signal completion to timeout watcher. + let (done_tx, done_rx) = mpsc::channel::<()>(); + + if let Some(timeout) = opts.login_timeout { + spawn_timeout_watcher( + done_rx, + timeout, + shutdown_flag.clone(), + timeout_flag.clone(), + server.clone(), + ); + } + + let server_for_thread = server.clone(); let server_handle = thread::spawn(move || { while !shutdown_flag.load(Ordering::SeqCst) { - let req = match server.recv() { + let req = match server_for_thread.recv() { Ok(r) => r, - Err(e) => return Err(io::Error::other(e)), + Err(e) => { + // If we've been asked to shut down, break gracefully so that + // we can report timeout or cancellation status uniformly. + if shutdown_flag.load(Ordering::SeqCst) { + break; + } else { + return Err(io::Error::other(e)); + } + } }; let url_raw = req.url().to_string(); @@ -198,6 +251,9 @@ pub fn run_login_server( } let _ = req.respond(resp); shutdown_flag.store(true, Ordering::SeqCst); + + // Login has succeeded, so disarm the timeout watcher. + let _ = done_tx.send(()); return Ok(()); } _ => { @@ -205,7 +261,15 @@ pub fn run_login_server( } } } - Err(io::Error::other("Login flow was not completed")) + + // Login has failed or timed out, so disarm the timeout watcher. + let _ = done_tx.send(()); + + if timeout_flag.load(Ordering::SeqCst) { + Err(io::Error::other("Login timed out")) + } else { + Err(io::Error::other("Login was not completed")) + } }); Ok(LoginServer { @@ -213,9 +277,33 @@ pub fn run_login_server( actual_port, server_handle, shutdown_flag: shutdown_flag_clone, + server, }) } +/// Spawns a detached thread that waits for either a completion signal on `done_rx` +/// or the specified `timeout` to elapse. If the timeout elapses first it marks +/// the `shutdown_flag`, records `timeout_flag`, and unblocks the HTTP server so +/// that the main server loop can exit promptly. +fn spawn_timeout_watcher( + done_rx: mpsc::Receiver<()>, + timeout: Duration, + shutdown_flag: Arc, + timeout_flag: Arc, + server: Arc, +) { + thread::spawn(move || { + if done_rx.recv_timeout(timeout).is_err() + && shutdown_flag + .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst) + .is_ok() + { + timeout_flag.store(true, Ordering::SeqCst); + server.unblock(); + } + }); +} + fn build_authorize_url( issuer: &str, client_id: &str, diff --git a/codex-rs/login/tests/login_server_e2e.rs b/codex-rs/login/tests/login_server_e2e.rs index b3e12468..6b7098b9 100644 --- a/codex-rs/login/tests/login_server_e2e.rs +++ b/codex-rs/login/tests/login_server_e2e.rs @@ -100,6 +100,7 @@ fn end_to_end_login_flow_persists_auth_json() { port: 0, open_browser: false, force_state: Some(state), + login_timeout: None, }; let server = run_login_server(opts, None).unwrap(); let login_port = server.actual_port; @@ -158,6 +159,7 @@ fn creates_missing_codex_home_dir() { port: 0, open_browser: false, force_state: Some(state), + login_timeout: None, }; let server = run_login_server(opts, None).unwrap(); let login_port = server.actual_port; diff --git a/codex-rs/mcp-server/Cargo.toml b/codex-rs/mcp-server/Cargo.toml index 2f618808..6274ba8e 100644 --- a/codex-rs/mcp-server/Cargo.toml +++ b/codex-rs/mcp-server/Cargo.toml @@ -18,6 +18,7 @@ workspace = true anyhow = "1" codex-arg0 = { path = "../arg0" } codex-core = { path = "../core" } +codex-login = { path = "../login" } mcp-types = { path = "../mcp-types" } schemars = "0.8.22" serde = { version = "1", features = ["derive"] } diff --git a/codex-rs/mcp-server/src/codex_message_processor.rs b/codex-rs/mcp-server/src/codex_message_processor.rs index d930c03b..3a859fbe 100644 --- a/codex-rs/mcp-server/src/codex_message_processor.rs +++ b/codex-rs/mcp-server/src/codex_message_processor.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; +use std::time::Duration; use codex_core::CodexConversation; use codex_core::ConversationManager; @@ -14,6 +15,7 @@ use codex_core::protocol::ExecApprovalRequestEvent; use codex_core::protocol::ReviewDecision; use mcp_types::JSONRPCErrorError; use mcp_types::RequestId; +use tokio::sync::Mutex; use tokio::sync::oneshot; use tracing::error; use uuid::Uuid; @@ -36,6 +38,9 @@ use crate::wire_format::ExecCommandApprovalResponse; use crate::wire_format::InputItem as WireInputItem; use crate::wire_format::InterruptConversationParams; use crate::wire_format::InterruptConversationResponse; +use crate::wire_format::LOGIN_CHATGPT_COMPLETE_EVENT; +use crate::wire_format::LoginChatGptCompleteNotification; +use crate::wire_format::LoginChatGptResponse; use crate::wire_format::NewConversationParams; use crate::wire_format::NewConversationResponse; use crate::wire_format::RemoveConversationListenerParams; @@ -46,6 +51,24 @@ use crate::wire_format::SendUserTurnParams; use crate::wire_format::SendUserTurnResponse; use codex_core::protocol::InputItem as CoreInputItem; use codex_core::protocol::Op; +use codex_login::CLIENT_ID; +use codex_login::ServerOptions as LoginServerOptions; +use codex_login::ShutdownHandle; +use codex_login::run_login_server; + +// Duration before a ChatGPT login attempt is abandoned. +const LOGIN_CHATGPT_TIMEOUT: Duration = Duration::from_secs(10 * 60); + +struct ActiveLogin { + shutdown_handle: ShutdownHandle, + login_id: Uuid, +} + +impl ActiveLogin { + fn drop(&self) { + self.shutdown_handle.cancel(); + } +} /// Handles JSON-RPC messages for Codex conversations. pub(crate) struct CodexMessageProcessor { @@ -53,6 +76,7 @@ pub(crate) struct CodexMessageProcessor { outgoing: Arc, codex_linux_sandbox_exe: Option, conversation_listeners: HashMap>, + active_login: Arc>>, } impl CodexMessageProcessor { @@ -66,6 +90,7 @@ impl CodexMessageProcessor { outgoing, codex_linux_sandbox_exe, conversation_listeners: HashMap::new(), + active_login: Arc::new(Mutex::new(None)), } } @@ -92,6 +117,134 @@ impl CodexMessageProcessor { ClientRequest::RemoveConversationListener { request_id, params } => { self.remove_conversation_listener(request_id, params).await; } + ClientRequest::LoginChatGpt { request_id } => { + self.login_chatgpt(request_id).await; + } + ClientRequest::CancelLoginChatGpt { request_id, params } => { + self.cancel_login_chatgpt(request_id, params.login_id).await; + } + } + } + + async fn login_chatgpt(&mut self, request_id: RequestId) { + let config = + match Config::load_with_cli_overrides(Default::default(), ConfigOverrides::default()) { + Ok(cfg) => cfg, + Err(err) => { + let error = JSONRPCErrorError { + code: INTERNAL_ERROR_CODE, + message: format!("error loading config for login: {err}"), + data: None, + }; + self.outgoing.send_error(request_id, error).await; + return; + } + }; + + let opts = LoginServerOptions { + open_browser: false, + login_timeout: Some(LOGIN_CHATGPT_TIMEOUT), + ..LoginServerOptions::new(config.codex_home.clone(), CLIENT_ID.to_string()) + }; + + enum LoginChatGptReply { + Response(LoginChatGptResponse), + Error(JSONRPCErrorError), + } + + let reply = match run_login_server(opts, None) { + Ok(server) => { + let login_id = Uuid::new_v4(); + + // Replace active login if present. + { + let mut guard = self.active_login.lock().await; + if let Some(existing) = guard.take() { + existing.drop(); + } + *guard = Some(ActiveLogin { + shutdown_handle: server.cancel_handle(), + login_id, + }); + } + + let response = LoginChatGptResponse { + login_id, + auth_url: server.auth_url.clone(), + }; + + // Spawn background task to monitor completion. + let outgoing_clone = self.outgoing.clone(); + let active_login = self.active_login.clone(); + tokio::spawn(async move { + let result = + tokio::task::spawn_blocking(move || server.block_until_done()).await; + let (success, error_msg) = match result { + Ok(Ok(())) => (true, None), + Ok(Err(err)) => (false, Some(format!("Login server error: {err}"))), + Err(join_err) => ( + false, + Some(format!("failed to join login server thread: {join_err}")), + ), + }; + let notification = LoginChatGptCompleteNotification { + login_id, + success, + error: error_msg, + }; + let params = serde_json::to_value(¬ification).ok(); + outgoing_clone + .send_notification(OutgoingNotification { + method: LOGIN_CHATGPT_COMPLETE_EVENT.to_string(), + params, + }) + .await; + + // Clear the active login if it matches this attempt. It may have been replaced or cancelled. + let mut guard = active_login.lock().await; + if guard.as_ref().map(|l| l.login_id) == Some(login_id) { + *guard = None; + } + }); + + LoginChatGptReply::Response(response) + } + Err(err) => LoginChatGptReply::Error(JSONRPCErrorError { + code: INTERNAL_ERROR_CODE, + message: format!("failed to start login server: {err}"), + data: None, + }), + }; + + match reply { + LoginChatGptReply::Response(resp) => { + self.outgoing.send_response(request_id, resp).await + } + LoginChatGptReply::Error(err) => self.outgoing.send_error(request_id, err).await, + } + } + + async fn cancel_login_chatgpt(&mut self, request_id: RequestId, login_id: Uuid) { + let mut guard = self.active_login.lock().await; + if guard.as_ref().map(|l| l.login_id) == Some(login_id) { + if let Some(active) = guard.take() { + active.drop(); + } + drop(guard); + self.outgoing + .send_response( + request_id, + crate::wire_format::CancelLoginChatGptResponse {}, + ) + .await; + } else { + drop(guard); + let error = JSONRPCErrorError { + code: INVALID_REQUEST_ERROR_CODE, + message: format!("login id not found: {login_id}"), + data: None, + }; + self.outgoing.send_error(request_id, error).await; } } diff --git a/codex-rs/mcp-server/src/wire_format.rs b/codex-rs/mcp-server/src/wire_format.rs index 2dca1b79..f8fb53b4 100644 --- a/codex-rs/mcp-server/src/wire_format.rs +++ b/codex-rs/mcp-server/src/wire_format.rs @@ -60,6 +60,15 @@ pub enum ClientRequest { request_id: RequestId, params: RemoveConversationListenerParams, }, + LoginChatGpt { + #[serde(rename = "id")] + request_id: RequestId, + }, + CancelLoginChatGpt { + #[serde(rename = "id")] + request_id: RequestId, + params: CancelLoginChatGptParams, + }, } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)] @@ -122,6 +131,36 @@ pub struct AddConversationSubscriptionResponse { #[serde(rename_all = "camelCase")] pub struct RemoveConversationSubscriptionResponse {} +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +pub struct LoginChatGptResponse { + pub login_id: Uuid, + /// URL the client should open in a browser to initiate the OAuth flow. + pub auth_url: String, +} + +// Event name for notifying client of login completion or failure. +pub const LOGIN_CHATGPT_COMPLETE_EVENT: &str = "codex/event/login_chatgpt_complete"; + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +pub struct LoginChatGptCompleteNotification { + pub login_id: Uuid, + pub success: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +pub struct CancelLoginChatGptParams { + pub login_id: Uuid, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +pub struct CancelLoginChatGptResponse {} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[serde(rename_all = "camelCase")] pub struct SendUserMessageParams {