diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index cda2e81d..d1b9cb39 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -785,6 +785,7 @@ dependencies = [ "codex-protocol", "mcp-types", "mcp_test_support", + "os_info", "pretty_assertions", "schemars 0.8.22", "serde", @@ -843,6 +844,7 @@ dependencies = [ "anyhow", "clap", "codex-protocol", + "mcp-types", "ts-rs", ] @@ -2680,9 +2682,11 @@ version = "0.0.0" dependencies = [ "anyhow", "assert_cmd", + "codex-core", "codex-mcp-server", "codex-protocol", "mcp-types", + "os_info", "pretty_assertions", "serde", "serde_json", diff --git a/codex-rs/core/src/default_client.rs b/codex-rs/core/src/default_client.rs index a2192ffa..36212f25 100644 --- a/codex-rs/core/src/default_client.rs +++ b/codex-rs/core/src/default_client.rs @@ -1,5 +1,23 @@ use reqwest::header::HeaderValue; use std::sync::LazyLock; +use std::sync::Mutex; + +/// Set this to add a suffix to the User-Agent string. +/// +/// It is not ideal that we're using a global singleton for this. +/// This is primarily designed to differentiate MCP clients from each other. +/// Because there can only be one MCP server per process, it should be safe for this to be a global static. +/// However, future users of this should use this with caution as a result. +/// In addition, we want to be confident that this value is used for ALL clients and doing that requires a +/// lot of wiring and it's easy to miss code paths by doing so. +/// See https://github.com/openai/codex/pull/3388/files for an example of what that would look like. +/// Finally, we want to make sure this is set for ALL mcp clients without needing to know a special env var +/// or having to set data that they already specified in the mcp initialize request somewhere else. +/// +/// A space is automatically added between the suffix and the rest of the User-Agent string. +/// The full user agent string is returned from the mcp initialize response. +/// Parenthesis will be added by Codex. This should only specify what goes inside of the parenthesis. +pub static USER_AGENT_SUFFIX: LazyLock>> = LazyLock::new(|| Mutex::new(None)); pub const CODEX_INTERNAL_ORIGINATOR_OVERRIDE_ENV_VAR: &str = "CODEX_INTERNAL_ORIGINATOR_OVERRIDE"; @@ -32,14 +50,58 @@ pub static ORIGINATOR: LazyLock = LazyLock::new(|| { pub fn get_codex_user_agent() -> String { let build_version = env!("CARGO_PKG_VERSION"); let os_info = os_info::get(); - format!( + let prefix = format!( "{}/{build_version} ({} {}; {}) {}", ORIGINATOR.value.as_str(), os_info.os_type(), os_info.version(), os_info.architecture().unwrap_or("unknown"), crate::terminal::user_agent() - ) + ); + let suffix = USER_AGENT_SUFFIX + .lock() + .ok() + .and_then(|guard| guard.clone()); + let suffix = suffix + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()) + .map_or_else(String::new, |value| format!(" ({value})")); + + let candidate = format!("{prefix}{suffix}"); + sanitize_user_agent(candidate, &prefix) +} + +/// Sanitize the user agent string. +/// +/// Invalid characters are replaced with an underscore. +/// +/// If the user agent fails to parse, it falls back to fallback and then to ORIGINATOR. +fn sanitize_user_agent(candidate: String, fallback: &str) -> String { + if HeaderValue::from_str(candidate.as_str()).is_ok() { + return candidate; + } + + let sanitized: String = candidate + .chars() + .map(|ch| if matches!(ch, ' '..='~') { ch } else { '_' }) + .collect(); + if !sanitized.is_empty() && HeaderValue::from_str(sanitized.as_str()).is_ok() { + tracing::warn!( + "Sanitized Codex user agent because provided suffix contained invalid header characters" + ); + sanitized + } else if HeaderValue::from_str(fallback).is_ok() { + tracing::warn!( + "Falling back to base Codex user agent because provided suffix could not be sanitized" + ); + fallback.to_string() + } else { + tracing::warn!( + "Falling back to default Codex originator because base user agent string is invalid" + ); + ORIGINATOR.value.clone() + } } /// Create a reqwest client with default `originator` and `User-Agent` headers set. @@ -114,6 +176,28 @@ mod tests { assert_eq!(ua_header.to_str().unwrap(), expected_ua); } + #[test] + fn test_invalid_suffix_is_sanitized() { + let prefix = "codex_cli_rs/0.0.0"; + let suffix = "bad\rsuffix"; + + assert_eq!( + sanitize_user_agent(format!("{prefix} ({suffix})"), prefix), + "codex_cli_rs/0.0.0 (bad_suffix)" + ); + } + + #[test] + fn test_invalid_suffix_is_sanitized2() { + let prefix = "codex_cli_rs/0.0.0"; + let suffix = "bad\0suffix"; + + assert_eq!( + sanitize_user_agent(format!("{prefix} ({suffix})"), prefix), + "codex_cli_rs/0.0.0 (bad_suffix)" + ); + } + #[test] #[cfg(target_os = "macos")] fn test_macos() { diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 387ea173..9fc72755 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -17,7 +17,7 @@ use anyhow::Result; use anyhow::anyhow; use codex_mcp_client::McpClient; use mcp_types::ClientCapabilities; -use mcp_types::Implementation; +use mcp_types::McpClientInfo; use mcp_types::Tool; use serde_json::json; @@ -159,7 +159,7 @@ impl McpConnectionManager { // indicates this should be an empty object. elicitation: Some(json!({})), }, - client_info: Implementation { + client_info: McpClientInfo { name: "codex-mcp-client".to_owned(), version: env!("CARGO_PKG_VERSION").to_owned(), title: Some("Codex".into()), diff --git a/codex-rs/mcp-client/src/main.rs b/codex-rs/mcp-client/src/main.rs index 10cfe389..a9f4e335 100644 --- a/codex-rs/mcp-client/src/main.rs +++ b/codex-rs/mcp-client/src/main.rs @@ -17,10 +17,10 @@ use anyhow::Context; use anyhow::Result; use codex_mcp_client::McpClient; use mcp_types::ClientCapabilities; -use mcp_types::Implementation; use mcp_types::InitializeRequestParams; use mcp_types::ListToolsRequestParams; use mcp_types::MCP_SCHEMA_VERSION; +use mcp_types::McpClientInfo; use tracing_subscriber::EnvFilter; #[tokio::main] @@ -60,7 +60,7 @@ async fn main() -> Result<()> { sampling: None, elicitation: None, }, - client_info: Implementation { + client_info: McpClientInfo { name: "codex-mcp-client".to_owned(), version: env!("CARGO_PKG_VERSION").to_owned(), title: Some("Codex".to_string()), diff --git a/codex-rs/mcp-server/Cargo.toml b/codex-rs/mcp-server/Cargo.toml index 335e4ce6..c219c729 100644 --- a/codex-rs/mcp-server/Cargo.toml +++ b/codex-rs/mcp-server/Cargo.toml @@ -41,6 +41,7 @@ uuid = { version = "1", features = ["serde", "v4"] } [dev-dependencies] assert_cmd = "2" mcp_test_support = { path = "tests/common" } +os_info = "3.12.0" pretty_assertions = "1.4.1" tempfile = "3" wiremock = "0.6" diff --git a/codex-rs/mcp-server/src/message_processor.rs b/codex-rs/mcp-server/src/message_processor.rs index dccd86e9..8018095d 100644 --- a/codex-rs/mcp-server/src/message_processor.rs +++ b/codex-rs/mcp-server/src/message_processor.rs @@ -14,6 +14,8 @@ use codex_protocol::mcp_protocol::ConversationId; use codex_core::AuthManager; use codex_core::ConversationManager; use codex_core::config::Config; +use codex_core::default_client::USER_AGENT_SUFFIX; +use codex_core::default_client::get_codex_user_agent; use codex_core::protocol::Submission; use mcp_types::CallToolRequestParams; use mcp_types::CallToolResult; @@ -208,6 +210,14 @@ impl MessageProcessor { return; } + let client_info = params.client_info; + let name = client_info.name; + let version = client_info.version; + let user_agent_suffix = format!("{name}; {version}"); + if let Ok(mut suffix) = USER_AGENT_SUFFIX.lock() { + *suffix = Some(user_agent_suffix); + } + self.initialized = true; // Build a minimal InitializeResult. Fill with placeholders. @@ -224,10 +234,11 @@ impl MessageProcessor { }, instructions: None, protocol_version: params.protocol_version.clone(), - server_info: mcp_types::Implementation { + server_info: mcp_types::McpServerInfo { name: "codex-mcp-server".to_string(), version: env!("CARGO_PKG_VERSION").to_string(), title: Some("Codex".to_string()), + user_agent: get_codex_user_agent(), }, }; diff --git a/codex-rs/mcp-server/tests/common/Cargo.toml b/codex-rs/mcp-server/tests/common/Cargo.toml index 88ad93f5..6bdef423 100644 --- a/codex-rs/mcp-server/tests/common/Cargo.toml +++ b/codex-rs/mcp-server/tests/common/Cargo.toml @@ -9,9 +9,11 @@ path = "lib.rs" [dependencies] anyhow = "1" assert_cmd = "2" +codex-core = { path = "../../../core" } codex-mcp-server = { path = "../.." } codex-protocol = { path = "../../../protocol" } mcp-types = { path = "../../../mcp-types" } +os_info = "3.12.0" pretty_assertions = "1.4.1" serde = { version = "1" } serde_json = "1" diff --git a/codex-rs/mcp-server/tests/common/mcp_process.rs b/codex-rs/mcp-server/tests/common/mcp_process.rs index 64f2cc38..66d546e2 100644 --- a/codex-rs/mcp-server/tests/common/mcp_process.rs +++ b/codex-rs/mcp-server/tests/common/mcp_process.rs @@ -26,13 +26,13 @@ use codex_protocol::mcp_protocol::SendUserTurnParams; use mcp_types::CallToolRequestParams; use mcp_types::ClientCapabilities; -use mcp_types::Implementation; use mcp_types::InitializeRequestParams; use mcp_types::JSONRPC_VERSION; use mcp_types::JSONRPCMessage; use mcp_types::JSONRPCNotification; use mcp_types::JSONRPCRequest; use mcp_types::JSONRPCResponse; +use mcp_types::McpClientInfo; use mcp_types::ModelContextProtocolNotification; use mcp_types::ModelContextProtocolRequest; use mcp_types::RequestId; @@ -111,7 +111,7 @@ impl McpProcess { roots: None, sampling: None, }, - client_info: Implementation { + client_info: McpClientInfo { name: "elicitation test".into(), title: Some("Elicitation Test".into()), version: "0.0.0".into(), @@ -129,6 +129,14 @@ impl McpProcess { .await?; let initialized = self.read_jsonrpc_message().await?; + let os_info = os_info::get(); + let user_agent = format!( + "codex_cli_rs/0.0.0 ({} {}; {}) {} (elicitation test; 0.0.0)", + os_info.os_type(), + os_info.version(), + os_info.architecture().unwrap_or("unknown"), + codex_core::terminal::user_agent() + ); assert_eq!( JSONRPCMessage::Response(JSONRPCResponse { jsonrpc: JSONRPC_VERSION.into(), @@ -142,7 +150,8 @@ impl McpProcess { "serverInfo": { "name": "codex-mcp-server", "title": "Codex", - "version": "0.0.0" + "version": "0.0.0", + "user_agent": user_agent }, "protocolVersion": mcp_types::MCP_SCHEMA_VERSION }) diff --git a/codex-rs/mcp-server/tests/suite/user_agent.rs b/codex-rs/mcp-server/tests/suite/user_agent.rs index 19de87a5..718e1452 100644 --- a/codex-rs/mcp-server/tests/suite/user_agent.rs +++ b/codex-rs/mcp-server/tests/suite/user_agent.rs @@ -1,4 +1,3 @@ -use codex_core::default_client::get_codex_user_agent; use codex_protocol::mcp_protocol::GetUserAgentResponse; use mcp_test_support::McpProcess; use mcp_test_support::to_response; @@ -34,11 +33,18 @@ async fn get_user_agent_returns_current_codex_user_agent() { .expect("getUserAgent timeout") .expect("getUserAgent response"); + let os_info = os_info::get(); + let user_agent = format!( + "codex_cli_rs/0.0.0 ({} {}; {}) {} (elicitation test; 0.0.0)", + os_info.os_type(), + os_info.version(), + os_info.architecture().unwrap_or("unknown"), + codex_core::terminal::user_agent() + ); + let received: GetUserAgentResponse = to_response(response).expect("deserialize getUserAgent response"); - let expected = GetUserAgentResponse { - user_agent: get_codex_user_agent(), - }; + let expected = GetUserAgentResponse { user_agent }; assert_eq!(received, expected); } diff --git a/codex-rs/mcp-types/src/lib.rs b/codex-rs/mcp-types/src/lib.rs index 988bec2a..2f862be8 100644 --- a/codex-rs/mcp-types/src/lib.rs +++ b/codex-rs/mcp-types/src/lib.rs @@ -482,13 +482,23 @@ pub struct ImageContent { /// Describes the name and version of an MCP implementation, with an optional title for UI representation. #[derive(Debug, Clone, PartialEq, Deserialize, Serialize, TS)] -pub struct Implementation { +pub struct McpClientInfo { pub name: String, #[serde(default, skip_serializing_if = "Option::is_none")] pub title: Option, pub version: String, } +/// Describes the name and version of an MCP implementation, with an optional title for UI representation. +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, TS)] +pub struct McpServerInfo { + pub name: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub title: Option, + pub version: String, + pub user_agent: String, +} + #[derive(Debug, Clone, PartialEq, Deserialize, Serialize, TS)] pub enum InitializeRequest {} @@ -502,7 +512,7 @@ impl ModelContextProtocolRequest for InitializeRequest { pub struct InitializeRequestParams { pub capabilities: ClientCapabilities, #[serde(rename = "clientInfo")] - pub client_info: Implementation, + pub client_info: McpClientInfo, #[serde(rename = "protocolVersion")] pub protocol_version: String, } @@ -516,7 +526,7 @@ pub struct InitializeResult { #[serde(rename = "protocolVersion")] pub protocol_version: String, #[serde(rename = "serverInfo")] - pub server_info: Implementation, + pub server_info: McpServerInfo, } impl From for serde_json::Value { diff --git a/codex-rs/mcp-types/tests/suite/initialize.rs b/codex-rs/mcp-types/tests/suite/initialize.rs index 04778f2a..2bd3c789 100644 --- a/codex-rs/mcp-types/tests/suite/initialize.rs +++ b/codex-rs/mcp-types/tests/suite/initialize.rs @@ -1,10 +1,10 @@ use mcp_types::ClientCapabilities; use mcp_types::ClientRequest; -use mcp_types::Implementation; use mcp_types::InitializeRequestParams; use mcp_types::JSONRPC_VERSION; use mcp_types::JSONRPCMessage; use mcp_types::JSONRPCRequest; +use mcp_types::McpClientInfo; use mcp_types::RequestId; use serde_json::json; @@ -58,7 +58,7 @@ fn deserialize_initialize_request() { sampling: None, elicitation: None, }, - client_info: Implementation { + client_info: McpClientInfo { name: "acme-client".into(), title: Some("Acme".to_string()), version: "1.2.3".into(), diff --git a/codex-rs/protocol-ts/Cargo.toml b/codex-rs/protocol-ts/Cargo.toml index 9faa9344..1131a621 100644 --- a/codex-rs/protocol-ts/Cargo.toml +++ b/codex-rs/protocol-ts/Cargo.toml @@ -16,6 +16,7 @@ path = "src/main.rs" [dependencies] anyhow = "1" +mcp-types = { path = "../mcp-types" } codex-protocol = { path = "../protocol" } ts-rs = "11" clap = { version = "4", features = ["derive"] } diff --git a/codex-rs/protocol-ts/src/lib.rs b/codex-rs/protocol-ts/src/lib.rs index 08cb9407..776c8ba3 100644 --- a/codex-rs/protocol-ts/src/lib.rs +++ b/codex-rs/protocol-ts/src/lib.rs @@ -16,6 +16,7 @@ pub fn generate_ts(out_dir: &Path, prettier: Option<&Path>) -> Result<()> { ensure_dir(out_dir)?; // Generate TS bindings + mcp_types::InitializeResult::export_all_to(out_dir)?; codex_protocol::mcp_protocol::ConversationId::export_all_to(out_dir)?; codex_protocol::mcp_protocol::InputItem::export_all_to(out_dir)?; codex_protocol::mcp_protocol::ClientRequest::export_all_to(out_dir)?;