diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 825b2a48..34e79320 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -823,6 +823,7 @@ version = "0.0.0" dependencies = [ "base64 0.22.1", "chrono", + "codex-protocol", "pretty_assertions", "rand 0.8.5", "reqwest", diff --git a/codex-rs/cli/src/main.rs b/codex-rs/cli/src/main.rs index d237fe67..2acc3d84 100644 --- a/codex-rs/cli/src/main.rs +++ b/codex-rs/cli/src/main.rs @@ -159,7 +159,7 @@ async fn cli_main(codex_linux_sandbox_exe: Option) -> anyhow::Result<() codex_exec::run_main(exec_cli, codex_linux_sandbox_exe).await?; } Some(Subcommand::Mcp) => { - codex_mcp_server::run_main(codex_linux_sandbox_exe).await?; + codex_mcp_server::run_main(codex_linux_sandbox_exe, cli.config_overrides).await?; } Some(Subcommand::Login(mut login_cli)) => { prepend_config_flags(&mut login_cli.config_overrides, cli.config_overrides); diff --git a/codex-rs/login/Cargo.toml b/codex-rs/login/Cargo.toml index c1e21ca6..bf04a8e3 100644 --- a/codex-rs/login/Cargo.toml +++ b/codex-rs/login/Cargo.toml @@ -9,6 +9,7 @@ workspace = true [dependencies] base64 = "0.22" chrono = { version = "0.4", features = ["serde"] } +codex-protocol = { path = "../protocol" } rand = "0.8" reqwest = { version = "0.12", features = ["json", "blocking"] } serde = { version = "1", features = ["derive"] } diff --git a/codex-rs/login/src/lib.rs b/codex-rs/login/src/lib.rs index 8c9a5cf3..1f118823 100644 --- a/codex-rs/login/src/lib.rs +++ b/codex-rs/login/src/lib.rs @@ -29,13 +29,7 @@ mod token_data; pub const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann"; pub const OPENAI_API_KEY_ENV_VAR: &str = "OPENAI_API_KEY"; - -#[derive(Clone, Debug, PartialEq, Copy, Eq, Serialize, Deserialize)] -#[serde(rename_all = "lowercase")] -pub enum AuthMode { - ApiKey, - ChatGPT, -} +pub use codex_protocol::mcp_protocol::AuthMode; #[derive(Debug, Clone)] pub struct CodexAuth { diff --git a/codex-rs/mcp-server/Cargo.toml b/codex-rs/mcp-server/Cargo.toml index 44a39071..cddf4cf3 100644 --- a/codex-rs/mcp-server/Cargo.toml +++ b/codex-rs/mcp-server/Cargo.toml @@ -17,7 +17,7 @@ workspace = true [dependencies] anyhow = "1" codex-arg0 = { path = "../arg0" } -codex-common = { path = "../common" } +codex-common = { path = "../common", features = ["cli"] } codex-core = { path = "../core" } codex-login = { path = "../login" } codex-protocol = { path = "../protocol" } diff --git a/codex-rs/mcp-server/src/codex_message_processor.rs b/codex-rs/mcp-server/src/codex_message_processor.rs index 07e06d66..657cda25 100644 --- a/codex-rs/mcp-server/src/codex_message_processor.rs +++ b/codex-rs/mcp-server/src/codex_message_processor.rs @@ -14,6 +14,7 @@ use codex_core::protocol::Event; use codex_core::protocol::EventMsg; use codex_core::protocol::ExecApprovalRequestEvent; use codex_core::protocol::ReviewDecision; +use codex_protocol::mcp_protocol::AuthMode; use codex_protocol::mcp_protocol::GitDiffToRemoteResponse; use mcp_types::JSONRPCErrorError; use mcp_types::RequestId; @@ -30,14 +31,17 @@ use crate::outgoing_message::OutgoingNotification; use codex_core::protocol::InputItem as CoreInputItem; use codex_core::protocol::Op; use codex_login::CLIENT_ID; +use codex_login::CodexAuth; use codex_login::ServerOptions as LoginServerOptions; use codex_login::ShutdownHandle; +use codex_login::logout; use codex_login::run_login_server; use codex_protocol::mcp_protocol::APPLY_PATCH_APPROVAL_METHOD; use codex_protocol::mcp_protocol::AddConversationListenerParams; use codex_protocol::mcp_protocol::AddConversationSubscriptionResponse; use codex_protocol::mcp_protocol::ApplyPatchApprovalParams; use codex_protocol::mcp_protocol::ApplyPatchApprovalResponse; +use codex_protocol::mcp_protocol::AuthStatusChangeNotification; use codex_protocol::mcp_protocol::ClientRequest; use codex_protocol::mcp_protocol::ConversationId; use codex_protocol::mcp_protocol::EXEC_COMMAND_APPROVAL_METHOD; @@ -46,7 +50,6 @@ use codex_protocol::mcp_protocol::ExecCommandApprovalResponse; use codex_protocol::mcp_protocol::InputItem as WireInputItem; use codex_protocol::mcp_protocol::InterruptConversationParams; use codex_protocol::mcp_protocol::InterruptConversationResponse; -use codex_protocol::mcp_protocol::LOGIN_CHATGPT_COMPLETE_EVENT; use codex_protocol::mcp_protocol::LoginChatGptCompleteNotification; use codex_protocol::mcp_protocol::LoginChatGptResponse; use codex_protocol::mcp_protocol::NewConversationParams; @@ -57,6 +60,7 @@ use codex_protocol::mcp_protocol::SendUserMessageParams; use codex_protocol::mcp_protocol::SendUserMessageResponse; use codex_protocol::mcp_protocol::SendUserTurnParams; use codex_protocol::mcp_protocol::SendUserTurnResponse; +use codex_protocol::mcp_protocol::ServerNotification; // Duration before a ChatGPT login attempt is abandoned. const LOGIN_CHATGPT_TIMEOUT: Duration = Duration::from_secs(10 * 60); @@ -77,6 +81,7 @@ pub(crate) struct CodexMessageProcessor { conversation_manager: Arc, outgoing: Arc, codex_linux_sandbox_exe: Option, + config: Arc, conversation_listeners: HashMap>, active_login: Arc>>, // Queue of pending interrupt requests per conversation. We reply when TurnAborted arrives. @@ -88,11 +93,13 @@ impl CodexMessageProcessor { conversation_manager: Arc, outgoing: Arc, codex_linux_sandbox_exe: Option, + config: Arc, ) -> Self { Self { conversation_manager, outgoing, codex_linux_sandbox_exe, + config, conversation_listeners: HashMap::new(), active_login: Arc::new(Mutex::new(None)), pending_interrupts: Arc::new(Mutex::new(HashMap::new())), @@ -128,6 +135,12 @@ impl CodexMessageProcessor { ClientRequest::CancelLoginChatGpt { request_id, params } => { self.cancel_login_chatgpt(request_id, params.login_id).await; } + ClientRequest::LogoutChatGpt { request_id } => { + self.logout_chatgpt(request_id).await; + } + ClientRequest::GetAuthStatus { request_id } => { + self.get_auth_status(request_id).await; + } ClientRequest::GitDiffToRemote { request_id, params } => { self.git_diff_to_origin(request_id, params.cwd).await; } @@ -135,19 +148,7 @@ impl CodexMessageProcessor { } async fn login_chatgpt(&mut self, request_id: RequestId) { - let config = - match Config::load_with_cli_overrides(Default::default(), ConfigOverrides::default()) { - Ok(cfg) => cfg, - Err(err) => { - let error = JSONRPCErrorError { - code: INTERNAL_ERROR_CODE, - message: format!("error loading config for login: {err}"), - data: None, - }; - self.outgoing.send_error(request_id, error).await; - return; - } - }; + let config = self.config.as_ref(); let opts = LoginServerOptions { open_browser: false, @@ -199,19 +200,25 @@ impl CodexMessageProcessor { (false, Some("Login timed out".to_string())) } }; - let notification = LoginChatGptCompleteNotification { + let payload = LoginChatGptCompleteNotification { login_id, success, error: error_msg, }; - let params = serde_json::to_value(¬ification).ok(); outgoing_clone - .send_notification(OutgoingNotification { - method: LOGIN_CHATGPT_COMPLETE_EVENT.to_string(), - params, - }) + .send_server_notification(ServerNotification::LoginChatGptComplete(payload)) .await; + // Send an auth status change notification. + if success { + let payload = AuthStatusChangeNotification { + auth_method: Some(AuthMode::ChatGPT), + }; + outgoing_clone + .send_server_notification(ServerNotification::AuthStatusChange(payload)) + .await; + } + // Clear the active login if it matches this attempt. It may have been replaced or cancelled. let mut guard = active_login.lock().await; if guard.as_ref().map(|l| l.login_id) == Some(login_id) { @@ -260,6 +267,78 @@ impl CodexMessageProcessor { } } + async fn logout_chatgpt(&mut self, request_id: RequestId) { + { + // Cancel any active login attempt. + let mut guard = self.active_login.lock().await; + if let Some(active) = guard.take() { + active.drop(); + } + } + + // Load config to locate codex_home for persistent logout. + let config = self.config.as_ref(); + + if let Err(err) = logout(&config.codex_home) { + let error = JSONRPCErrorError { + code: INTERNAL_ERROR_CODE, + message: format!("logout failed: {err}"), + data: None, + }; + self.outgoing.send_error(request_id, error).await; + return; + } + + self.outgoing + .send_response( + request_id, + codex_protocol::mcp_protocol::LogoutChatGptResponse {}, + ) + .await; + + // Send auth status change notification. + let payload = AuthStatusChangeNotification { auth_method: None }; + self.outgoing + .send_server_notification(ServerNotification::AuthStatusChange(payload)) + .await; + } + + async fn get_auth_status(&self, request_id: RequestId) { + // Load config to determine codex_home and preferred auth method. + let config = self.config.as_ref(); + + let preferred_auth_method: AuthMode = config.preferred_auth_method; + let response = + match CodexAuth::from_codex_home(&config.codex_home, config.preferred_auth_method) { + Ok(Some(auth)) => { + // Verify that the current auth mode has a valid, non-empty token. + // If token acquisition fails or is empty, treat as unauthenticated. + let reported_auth_method = match auth.get_token().await { + Ok(token) if !token.is_empty() => Some(auth.mode), + Ok(_) => None, // Empty token + Err(err) => { + tracing::warn!("failed to get token for auth status: {err}"); + None + } + }; + codex_protocol::mcp_protocol::GetAuthStatusResponse { + auth_method: reported_auth_method, + preferred_auth_method, + } + } + Ok(None) => codex_protocol::mcp_protocol::GetAuthStatusResponse { + auth_method: None, + preferred_auth_method, + }, + Err(_) => codex_protocol::mcp_protocol::GetAuthStatusResponse { + auth_method: None, + preferred_auth_method, + }, + }; + + self.outgoing.send_response(request_id, response).await; + } + async fn process_new_conversation(&self, request_id: RequestId, params: NewConversationParams) { let config = match derive_config_from_params(params, self.codex_linux_sandbox_exe.clone()) { Ok(config) => config, diff --git a/codex-rs/mcp-server/src/lib.rs b/codex-rs/mcp-server/src/lib.rs index e22df1f9..aaf3e314 100644 --- a/codex-rs/mcp-server/src/lib.rs +++ b/codex-rs/mcp-server/src/lib.rs @@ -1,9 +1,14 @@ //! Prototype MCP server. #![deny(clippy::print_stdout, clippy::print_stderr)] +use std::io::ErrorKind; use std::io::Result as IoResult; use std::path::PathBuf; +use codex_common::CliConfigOverrides; +use codex_core::config::Config; +use codex_core::config::ConfigOverrides; + use mcp_types::JSONRPCMessage; use tokio::io::AsyncBufReadExt; use tokio::io::AsyncWriteExt; @@ -41,7 +46,10 @@ pub use crate::patch_approval::PatchApprovalResponse; /// plenty for an interactive CLI. const CHANNEL_CAPACITY: usize = 128; -pub async fn run_main(codex_linux_sandbox_exe: Option) -> IoResult<()> { +pub async fn run_main( + codex_linux_sandbox_exe: Option, + cli_config_overrides: CliConfigOverrides, +) -> IoResult<()> { // Install a simple subscriber so `tracing` output is visible. Users can // control the log level with `RUST_LOG`. tracing_subscriber::fmt() @@ -77,10 +85,27 @@ pub async fn run_main(codex_linux_sandbox_exe: Option) -> IoResult<()> } }); + // Parse CLI overrides once and derive the base Config eagerly so later + // components do not need to work with raw TOML values. + let cli_kv_overrides = cli_config_overrides.parse_overrides().map_err(|e| { + std::io::Error::new( + ErrorKind::InvalidInput, + format!("error parsing -c overrides: {e}"), + ) + })?; + let config = Config::load_with_cli_overrides(cli_kv_overrides, ConfigOverrides::default()) + .map_err(|e| { + std::io::Error::new(ErrorKind::InvalidData, format!("error loading config: {e}")) + })?; + // Task: process incoming messages. let processor_handle = tokio::spawn({ let outgoing_message_sender = OutgoingMessageSender::new(outgoing_tx); - let mut processor = MessageProcessor::new(outgoing_message_sender, codex_linux_sandbox_exe); + let mut processor = MessageProcessor::new( + outgoing_message_sender, + codex_linux_sandbox_exe, + std::sync::Arc::new(config), + ); async move { while let Some(msg) = incoming_rx.recv().await { match msg { diff --git a/codex-rs/mcp-server/src/main.rs b/codex-rs/mcp-server/src/main.rs index 60ddeeab..314944fa 100644 --- a/codex-rs/mcp-server/src/main.rs +++ b/codex-rs/mcp-server/src/main.rs @@ -1,9 +1,10 @@ use codex_arg0::arg0_dispatch_or_else; +use codex_common::CliConfigOverrides; use codex_mcp_server::run_main; fn main() -> anyhow::Result<()> { arg0_dispatch_or_else(|codex_linux_sandbox_exe| async move { - run_main(codex_linux_sandbox_exe).await?; + run_main(codex_linux_sandbox_exe, CliConfigOverrides::default()).await?; Ok(()) }) } diff --git a/codex-rs/mcp-server/src/message_processor.rs b/codex-rs/mcp-server/src/message_processor.rs index 1ddcc6bc..a22f9c5b 100644 --- a/codex-rs/mcp-server/src/message_processor.rs +++ b/codex-rs/mcp-server/src/message_processor.rs @@ -1,6 +1,5 @@ use std::collections::HashMap; use std::path::PathBuf; -use std::sync::Arc; use crate::codex_message_processor::CodexMessageProcessor; use crate::codex_tool_config::CodexToolCallParam; @@ -12,7 +11,7 @@ use crate::outgoing_message::OutgoingMessageSender; use codex_protocol::mcp_protocol::ClientRequest; use codex_core::ConversationManager; -use codex_core::config::Config as CodexConfig; +use codex_core::config::Config; use codex_core::protocol::Submission; use mcp_types::CallToolRequestParams; use mcp_types::CallToolResult; @@ -30,6 +29,7 @@ use mcp_types::ServerCapabilitiesTools; use mcp_types::ServerNotification; use mcp_types::TextContent; use serde_json::json; +use std::sync::Arc; use tokio::sync::Mutex; use tokio::task; use uuid::Uuid; @@ -49,6 +49,7 @@ impl MessageProcessor { pub(crate) fn new( outgoing: OutgoingMessageSender, codex_linux_sandbox_exe: Option, + config: Arc, ) -> Self { let outgoing = Arc::new(outgoing); let conversation_manager = Arc::new(ConversationManager::default()); @@ -56,6 +57,7 @@ impl MessageProcessor { conversation_manager.clone(), outgoing.clone(), codex_linux_sandbox_exe.clone(), + config, ); Self { codex_message_processor, @@ -344,7 +346,7 @@ impl MessageProcessor { } } async fn handle_tool_call_codex(&self, id: RequestId, arguments: Option) { - let (initial_prompt, config): (String, CodexConfig) = match arguments { + let (initial_prompt, config): (String, Config) = match arguments { Some(json_val) => match serde_json::from_value::(json_val) { Ok(tool_cfg) => match tool_cfg.into_config(self.codex_linux_sandbox_exe.clone()) { Ok(cfg) => cfg, diff --git a/codex-rs/mcp-server/src/outgoing_message.rs b/codex-rs/mcp-server/src/outgoing_message.rs index c5e51a34..16241a08 100644 --- a/codex-rs/mcp-server/src/outgoing_message.rs +++ b/codex-rs/mcp-server/src/outgoing_message.rs @@ -3,6 +3,7 @@ use std::sync::atomic::AtomicI64; use std::sync::atomic::Ordering; use codex_core::protocol::Event; +use codex_protocol::mcp_protocol::ServerNotification; use mcp_types::JSONRPC_VERSION; use mcp_types::JSONRPCError; use mcp_types::JSONRPCErrorError; @@ -121,6 +122,17 @@ impl OutgoingMessageSender { .await; } + pub(crate) async fn send_server_notification(&self, notification: ServerNotification) { + let method = format!("codex/event/{}", notification); + let params = match serde_json::to_value(¬ification) { + Ok(serde_json::Value::Object(mut map)) => map.remove("data"), + _ => None, + }; + let outgoing_message = + OutgoingMessage::Notification(OutgoingNotification { method, params }); + let _ = self.sender.send(outgoing_message).await; + } + pub(crate) async fn send_notification(&self, notification: OutgoingNotification) { let outgoing_message = OutgoingMessage::Notification(notification); let _ = self.sender.send(outgoing_message).await; diff --git a/codex-rs/protocol-ts/src/lib.rs b/codex-rs/protocol-ts/src/lib.rs index 6bbc9269..2366ae86 100644 --- a/codex-rs/protocol-ts/src/lib.rs +++ b/codex-rs/protocol-ts/src/lib.rs @@ -41,6 +41,7 @@ pub fn generate_ts(out_dir: &Path, prettier: Option<&Path>) -> Result<()> { codex_protocol::mcp_protocol::ApplyPatchApprovalResponse::export_all_to(out_dir)?; codex_protocol::mcp_protocol::ExecCommandApprovalParams::export_all_to(out_dir)?; codex_protocol::mcp_protocol::ExecCommandApprovalResponse::export_all_to(out_dir)?; + codex_protocol::mcp_protocol::ServerNotification::export_all_to(out_dir)?; // Prepend header to each generated .ts file let ts_files = ts_files_in(out_dir)?; diff --git a/codex-rs/protocol/src/mcp_protocol.rs b/codex-rs/protocol/src/mcp_protocol.rs index 68f5c01d..7cb38e15 100644 --- a/codex-rs/protocol/src/mcp_protocol.rs +++ b/codex-rs/protocol/src/mcp_protocol.rs @@ -13,6 +13,7 @@ use crate::protocol::TurnAbortReason; use mcp_types::RequestId; use serde::Deserialize; use serde::Serialize; +use strum_macros::Display; use ts_rs::TS; use uuid::Uuid; @@ -36,6 +37,13 @@ impl GitSha { } } +#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, TS)] +#[serde(rename_all = "lowercase")] +pub enum AuthMode { + ApiKey, + ChatGPT, +} + /// Request from the client to the server. #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] #[serde(tag = "method", rename_all = "camelCase")] @@ -79,6 +87,14 @@ pub enum ClientRequest { request_id: RequestId, params: CancelLoginChatGptParams, }, + LogoutChatGpt { + #[serde(rename = "id")] + request_id: RequestId, + }, + GetAuthStatus { + #[serde(rename = "id")] + request_id: RequestId, + }, GitDiffToRemote { #[serde(rename = "id")] request_id: RequestId, @@ -161,18 +177,6 @@ pub struct GitDiffToRemoteResponse { pub diff: String, } -// Event name for notifying client of login completion or failure. -pub const LOGIN_CHATGPT_COMPLETE_EVENT: &str = "codex/event/login_chatgpt_complete"; - -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] -#[serde(rename_all = "camelCase")] -pub struct LoginChatGptCompleteNotification { - pub login_id: Uuid, - pub success: bool, - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, -} - #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] #[serde(rename_all = "camelCase")] pub struct CancelLoginChatGptParams { @@ -189,6 +193,30 @@ pub struct GitDiffToRemoteParams { #[serde(rename_all = "camelCase")] pub struct CancelLoginChatGptResponse {} +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] +#[serde(rename_all = "camelCase")] +pub struct LogoutChatGptParams { + pub login_id: Uuid, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] +#[serde(rename_all = "camelCase")] +pub struct LogoutChatGptResponse {} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] +#[serde(rename_all = "camelCase")] +pub struct GetAuthStatusParams { + pub login_id: Uuid, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] +#[serde(rename_all = "camelCase")] +pub struct GetAuthStatusResponse { + #[serde(skip_serializing_if = "Option::is_none")] + pub auth_method: Option, + pub preferred_auth_method: AuthMode, +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] #[serde(rename_all = "camelCase")] pub struct SendUserMessageParams { @@ -321,6 +349,34 @@ pub struct ApplyPatchApprovalResponse { pub decision: ReviewDecision, } +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] +#[serde(rename_all = "camelCase")] +pub struct LoginChatGptCompleteNotification { + pub login_id: Uuid, + pub success: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] +#[serde(rename_all = "camelCase")] +pub struct AuthStatusChangeNotification { + /// Current authentication method; omitted if signed out. + #[serde(skip_serializing_if = "Option::is_none")] + pub auth_method: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS, Display)] +#[serde(tag = "type", content = "data", rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum ServerNotification { + /// Authentication status changed + AuthStatusChange(AuthStatusChangeNotification), + + /// ChatGPT login flow completed + LoginChatGptComplete(LoginChatGptCompleteNotification), +} + #[cfg(test)] mod tests { use super::*;