Add support for a separate chatgpt auth endpoint (#1712)
Adds a `CodexAuth` type that encapsulates information about available auth modes and logic for refreshing the token. Changes `Responses` API to send requests to different endpoints based on the auth type. Updates login_with_chatgpt to support API-less mode and skip the key exchange.
This commit is contained in:
2
codex-rs/Cargo.lock
generated
2
codex-rs/Cargo.lock
generated
@@ -673,7 +673,9 @@ dependencies = [
|
|||||||
"async-channel",
|
"async-channel",
|
||||||
"base64 0.22.1",
|
"base64 0.22.1",
|
||||||
"bytes",
|
"bytes",
|
||||||
|
"chrono",
|
||||||
"codex-apply-patch",
|
"codex-apply-patch",
|
||||||
|
"codex-login",
|
||||||
"codex-mcp-client",
|
"codex-mcp-client",
|
||||||
"core_test_support",
|
"core_test_support",
|
||||||
"dirs",
|
"dirs",
|
||||||
|
|||||||
@@ -21,10 +21,14 @@ pub(crate) async fn chatgpt_get_request<T: DeserializeOwned>(
|
|||||||
let token =
|
let token =
|
||||||
get_chatgpt_token_data().ok_or_else(|| anyhow::anyhow!("ChatGPT token not available"))?;
|
get_chatgpt_token_data().ok_or_else(|| anyhow::anyhow!("ChatGPT token not available"))?;
|
||||||
|
|
||||||
|
let account_id = token.account_id.ok_or_else(|| {
|
||||||
|
anyhow::anyhow!("ChatGPT account ID not available, please re-run `codex login`")
|
||||||
|
});
|
||||||
|
|
||||||
let response = client
|
let response = client
|
||||||
.get(&url)
|
.get(&url)
|
||||||
.bearer_auth(&token.access_token)
|
.bearer_auth(&token.access_token)
|
||||||
.header("chatgpt-account-id", &token.account_id)
|
.header("chatgpt-account-id", account_id?)
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
.header("User-Agent", "codex-cli")
|
.header("User-Agent", "codex-cli")
|
||||||
.send()
|
.send()
|
||||||
|
|||||||
@@ -18,7 +18,10 @@ pub fn set_chatgpt_token_data(value: TokenData) {
|
|||||||
|
|
||||||
/// Initialize the ChatGPT token from auth.json file
|
/// Initialize the ChatGPT token from auth.json file
|
||||||
pub async fn init_chatgpt_token_from_auth(codex_home: &Path) -> std::io::Result<()> {
|
pub async fn init_chatgpt_token_from_auth(codex_home: &Path) -> std::io::Result<()> {
|
||||||
let auth_json = codex_login::try_read_auth_json(codex_home).await?;
|
let auth = codex_login::load_auth(codex_home)?;
|
||||||
set_chatgpt_token_data(auth_json.tokens.clone());
|
if let Some(auth) = auth {
|
||||||
|
let token_data = auth.get_token_data().await?;
|
||||||
|
set_chatgpt_token_data(token_data);
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ use codex_core::config::Config;
|
|||||||
use codex_core::config::ConfigOverrides;
|
use codex_core::config::ConfigOverrides;
|
||||||
use codex_core::protocol::Submission;
|
use codex_core::protocol::Submission;
|
||||||
use codex_core::util::notify_on_sigint;
|
use codex_core::util::notify_on_sigint;
|
||||||
|
use codex_login::load_auth;
|
||||||
use tokio::io::AsyncBufReadExt;
|
use tokio::io::AsyncBufReadExt;
|
||||||
use tokio::io::BufReader;
|
use tokio::io::BufReader;
|
||||||
use tracing::error;
|
use tracing::error;
|
||||||
@@ -35,8 +36,9 @@ pub async fn run_main(opts: ProtoCli) -> anyhow::Result<()> {
|
|||||||
.map_err(anyhow::Error::msg)?;
|
.map_err(anyhow::Error::msg)?;
|
||||||
|
|
||||||
let config = Config::load_with_cli_overrides(overrides_vec, ConfigOverrides::default())?;
|
let config = Config::load_with_cli_overrides(overrides_vec, ConfigOverrides::default())?;
|
||||||
|
let auth = load_auth(&config.codex_home)?;
|
||||||
let ctrl_c = notify_on_sigint();
|
let ctrl_c = notify_on_sigint();
|
||||||
let CodexSpawnOk { codex, .. } = Codex::spawn(config, ctrl_c.clone()).await?;
|
let CodexSpawnOk { codex, .. } = Codex::spawn(config, auth, ctrl_c.clone()).await?;
|
||||||
let codex = Arc::new(codex);
|
let codex = Arc::new(codex);
|
||||||
|
|
||||||
// Task that reads JSON lines from stdin and forwards to Submission Queue
|
// Task that reads JSON lines from stdin and forwards to Submission Queue
|
||||||
|
|||||||
@@ -110,12 +110,15 @@ stream_idle_timeout_ms = 300000 # 5m idle timeout
|
|||||||
```
|
```
|
||||||
|
|
||||||
#### request_max_retries
|
#### request_max_retries
|
||||||
|
|
||||||
How many times Codex will retry a failed HTTP request to the model provider. Defaults to `4`.
|
How many times Codex will retry a failed HTTP request to the model provider. Defaults to `4`.
|
||||||
|
|
||||||
#### stream_max_retries
|
#### stream_max_retries
|
||||||
|
|
||||||
Number of times Codex will attempt to reconnect when a streaming response is interrupted. Defaults to `10`.
|
Number of times Codex will attempt to reconnect when a streaming response is interrupted. Defaults to `10`.
|
||||||
|
|
||||||
#### stream_idle_timeout_ms
|
#### stream_idle_timeout_ms
|
||||||
|
|
||||||
How long Codex will wait for activity on a streaming response before treating the connection as lost. Defaults to `300_000` (5 minutes).
|
How long Codex will wait for activity on a streaming response before treating the connection as lost. Defaults to `300_000` (5 minutes).
|
||||||
|
|
||||||
## model_provider
|
## model_provider
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ base64 = "0.22"
|
|||||||
bytes = "1.10.1"
|
bytes = "1.10.1"
|
||||||
codex-apply-patch = { path = "../apply-patch" }
|
codex-apply-patch = { path = "../apply-patch" }
|
||||||
codex-mcp-client = { path = "../mcp-client" }
|
codex-mcp-client = { path = "../mcp-client" }
|
||||||
|
chrono = { version = "0.4", features = ["serde"] }
|
||||||
|
codex-login = { path = "../login" }
|
||||||
dirs = "6"
|
dirs = "6"
|
||||||
env-flags = "0.1.1"
|
env-flags = "0.1.1"
|
||||||
eventsource-stream = "0.2.3"
|
eventsource-stream = "0.2.3"
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ use std::path::Path;
|
|||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
|
use codex_login::AuthMode;
|
||||||
|
use codex_login::CodexAuth;
|
||||||
use eventsource_stream::Eventsource;
|
use eventsource_stream::Eventsource;
|
||||||
use futures::prelude::*;
|
use futures::prelude::*;
|
||||||
use reqwest::StatusCode;
|
use reqwest::StatusCode;
|
||||||
@@ -28,6 +30,7 @@ use crate::config::Config;
|
|||||||
use crate::config_types::ReasoningEffort as ReasoningEffortConfig;
|
use crate::config_types::ReasoningEffort as ReasoningEffortConfig;
|
||||||
use crate::config_types::ReasoningSummary as ReasoningSummaryConfig;
|
use crate::config_types::ReasoningSummary as ReasoningSummaryConfig;
|
||||||
use crate::error::CodexErr;
|
use crate::error::CodexErr;
|
||||||
|
use crate::error::EnvVarError;
|
||||||
use crate::error::Result;
|
use crate::error::Result;
|
||||||
use crate::flags::CODEX_RS_SSE_FIXTURE;
|
use crate::flags::CODEX_RS_SSE_FIXTURE;
|
||||||
use crate::model_provider_info::ModelProviderInfo;
|
use crate::model_provider_info::ModelProviderInfo;
|
||||||
@@ -41,6 +44,7 @@ use std::sync::Arc;
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct ModelClient {
|
pub struct ModelClient {
|
||||||
config: Arc<Config>,
|
config: Arc<Config>,
|
||||||
|
auth: Option<CodexAuth>,
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
provider: ModelProviderInfo,
|
provider: ModelProviderInfo,
|
||||||
session_id: Uuid,
|
session_id: Uuid,
|
||||||
@@ -51,6 +55,7 @@ pub struct ModelClient {
|
|||||||
impl ModelClient {
|
impl ModelClient {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
config: Arc<Config>,
|
config: Arc<Config>,
|
||||||
|
auth: Option<CodexAuth>,
|
||||||
provider: ModelProviderInfo,
|
provider: ModelProviderInfo,
|
||||||
effort: ReasoningEffortConfig,
|
effort: ReasoningEffortConfig,
|
||||||
summary: ReasoningSummaryConfig,
|
summary: ReasoningSummaryConfig,
|
||||||
@@ -58,6 +63,7 @@ impl ModelClient {
|
|||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
config,
|
config,
|
||||||
|
auth,
|
||||||
client: reqwest::Client::new(),
|
client: reqwest::Client::new(),
|
||||||
provider,
|
provider,
|
||||||
session_id,
|
session_id,
|
||||||
@@ -115,6 +121,25 @@ impl ModelClient {
|
|||||||
return stream_from_fixture(path, self.provider.clone()).await;
|
return stream_from_fixture(path, self.provider.clone()).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let auth = self.auth.as_ref().ok_or_else(|| {
|
||||||
|
CodexErr::EnvVar(EnvVarError {
|
||||||
|
var: "OPENAI_API_KEY".to_string(),
|
||||||
|
instructions: Some("Create an API key (https://platform.openai.com) and export it as an environment variable.".to_string()),
|
||||||
|
})
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let store = prompt.store && auth.mode != AuthMode::ChatGPT;
|
||||||
|
|
||||||
|
let base_url = match self.provider.base_url.clone() {
|
||||||
|
Some(url) => url,
|
||||||
|
None => match auth.mode {
|
||||||
|
AuthMode::ChatGPT => "https://chatgpt.com/backend-api/codex".to_string(),
|
||||||
|
AuthMode::ApiKey => "https://api.openai.com/v1".to_string(),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
let token = auth.get_token().await?;
|
||||||
|
|
||||||
let full_instructions = prompt.get_full_instructions(&self.config.model);
|
let full_instructions = prompt.get_full_instructions(&self.config.model);
|
||||||
let tools_json = create_tools_json_for_responses_api(
|
let tools_json = create_tools_json_for_responses_api(
|
||||||
prompt,
|
prompt,
|
||||||
@@ -125,7 +150,7 @@ impl ModelClient {
|
|||||||
|
|
||||||
// Request encrypted COT if we are not storing responses,
|
// Request encrypted COT if we are not storing responses,
|
||||||
// otherwise reasoning items will be referenced by ID
|
// otherwise reasoning items will be referenced by ID
|
||||||
let include = if !prompt.store && reasoning.is_some() {
|
let include: Vec<String> = if !store && reasoning.is_some() {
|
||||||
vec!["reasoning.encrypted_content".to_string()]
|
vec!["reasoning.encrypted_content".to_string()]
|
||||||
} else {
|
} else {
|
||||||
vec![]
|
vec![]
|
||||||
@@ -139,8 +164,7 @@ impl ModelClient {
|
|||||||
tool_choice: "auto",
|
tool_choice: "auto",
|
||||||
parallel_tool_calls: false,
|
parallel_tool_calls: false,
|
||||||
reasoning,
|
reasoning,
|
||||||
store: prompt.store,
|
store,
|
||||||
// TODO: make this configurable
|
|
||||||
stream: true,
|
stream: true,
|
||||||
include,
|
include,
|
||||||
};
|
};
|
||||||
@@ -153,17 +177,21 @@ impl ModelClient {
|
|||||||
|
|
||||||
let mut attempt = 0;
|
let mut attempt = 0;
|
||||||
let max_retries = self.provider.request_max_retries();
|
let max_retries = self.provider.request_max_retries();
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
attempt += 1;
|
attempt += 1;
|
||||||
|
|
||||||
let req_builder = self
|
let req_builder = self
|
||||||
.provider
|
.client
|
||||||
.create_request_builder(&self.client)?
|
.post(format!("{base_url}/responses"))
|
||||||
.header("OpenAI-Beta", "responses=experimental")
|
.header("OpenAI-Beta", "responses=experimental")
|
||||||
.header("session_id", self.session_id.to_string())
|
.header("session_id", self.session_id.to_string())
|
||||||
|
.bearer_auth(&token)
|
||||||
.header(reqwest::header::ACCEPT, "text/event-stream")
|
.header(reqwest::header::ACCEPT, "text/event-stream")
|
||||||
.json(&payload);
|
.json(&payload);
|
||||||
|
|
||||||
|
let req_builder = self.provider.apply_http_headers(req_builder);
|
||||||
|
|
||||||
let res = req_builder.send().await;
|
let res = req_builder.send().await;
|
||||||
if let Ok(resp) = &res {
|
if let Ok(resp) = &res {
|
||||||
trace!(
|
trace!(
|
||||||
@@ -572,7 +600,7 @@ mod tests {
|
|||||||
|
|
||||||
let provider = ModelProviderInfo {
|
let provider = ModelProviderInfo {
|
||||||
name: "test".to_string(),
|
name: "test".to_string(),
|
||||||
base_url: "https://test.com".to_string(),
|
base_url: Some("https://test.com".to_string()),
|
||||||
env_key: Some("TEST_API_KEY".to_string()),
|
env_key: Some("TEST_API_KEY".to_string()),
|
||||||
env_key_instructions: None,
|
env_key_instructions: None,
|
||||||
wire_api: WireApi::Responses,
|
wire_api: WireApi::Responses,
|
||||||
@@ -582,6 +610,7 @@ mod tests {
|
|||||||
request_max_retries: Some(0),
|
request_max_retries: Some(0),
|
||||||
stream_max_retries: Some(0),
|
stream_max_retries: Some(0),
|
||||||
stream_idle_timeout_ms: Some(1000),
|
stream_idle_timeout_ms: Some(1000),
|
||||||
|
requires_auth: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let events = collect_events(
|
let events = collect_events(
|
||||||
@@ -631,7 +660,7 @@ mod tests {
|
|||||||
let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n");
|
let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n");
|
||||||
let provider = ModelProviderInfo {
|
let provider = ModelProviderInfo {
|
||||||
name: "test".to_string(),
|
name: "test".to_string(),
|
||||||
base_url: "https://test.com".to_string(),
|
base_url: Some("https://test.com".to_string()),
|
||||||
env_key: Some("TEST_API_KEY".to_string()),
|
env_key: Some("TEST_API_KEY".to_string()),
|
||||||
env_key_instructions: None,
|
env_key_instructions: None,
|
||||||
wire_api: WireApi::Responses,
|
wire_api: WireApi::Responses,
|
||||||
@@ -641,6 +670,7 @@ mod tests {
|
|||||||
request_max_retries: Some(0),
|
request_max_retries: Some(0),
|
||||||
stream_max_retries: Some(0),
|
stream_max_retries: Some(0),
|
||||||
stream_idle_timeout_ms: Some(1000),
|
stream_idle_timeout_ms: Some(1000),
|
||||||
|
requires_auth: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let events = collect_events(&[sse1.as_bytes()], provider).await;
|
let events = collect_events(&[sse1.as_bytes()], provider).await;
|
||||||
@@ -733,7 +763,7 @@ mod tests {
|
|||||||
|
|
||||||
let provider = ModelProviderInfo {
|
let provider = ModelProviderInfo {
|
||||||
name: "test".to_string(),
|
name: "test".to_string(),
|
||||||
base_url: "https://test.com".to_string(),
|
base_url: Some("https://test.com".to_string()),
|
||||||
env_key: Some("TEST_API_KEY".to_string()),
|
env_key: Some("TEST_API_KEY".to_string()),
|
||||||
env_key_instructions: None,
|
env_key_instructions: None,
|
||||||
wire_api: WireApi::Responses,
|
wire_api: WireApi::Responses,
|
||||||
@@ -743,6 +773,7 @@ mod tests {
|
|||||||
request_max_retries: Some(0),
|
request_max_retries: Some(0),
|
||||||
stream_max_retries: Some(0),
|
stream_max_retries: Some(0),
|
||||||
stream_idle_timeout_ms: Some(1000),
|
stream_idle_timeout_ms: Some(1000),
|
||||||
|
requires_auth: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let out = run_sse(evs, provider).await;
|
let out = run_sse(evs, provider).await;
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ use async_channel::Sender;
|
|||||||
use codex_apply_patch::ApplyPatchAction;
|
use codex_apply_patch::ApplyPatchAction;
|
||||||
use codex_apply_patch::MaybeApplyPatchVerified;
|
use codex_apply_patch::MaybeApplyPatchVerified;
|
||||||
use codex_apply_patch::maybe_parse_apply_patch_verified;
|
use codex_apply_patch::maybe_parse_apply_patch_verified;
|
||||||
|
use codex_login::CodexAuth;
|
||||||
use futures::prelude::*;
|
use futures::prelude::*;
|
||||||
use mcp_types::CallToolResult;
|
use mcp_types::CallToolResult;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
@@ -103,7 +104,11 @@ pub struct CodexSpawnOk {
|
|||||||
|
|
||||||
impl Codex {
|
impl Codex {
|
||||||
/// Spawn a new [`Codex`] and initialize the session.
|
/// Spawn a new [`Codex`] and initialize the session.
|
||||||
pub async fn spawn(config: Config, ctrl_c: Arc<Notify>) -> CodexResult<CodexSpawnOk> {
|
pub async fn spawn(
|
||||||
|
config: Config,
|
||||||
|
auth: Option<CodexAuth>,
|
||||||
|
ctrl_c: Arc<Notify>,
|
||||||
|
) -> CodexResult<CodexSpawnOk> {
|
||||||
// experimental resume path (undocumented)
|
// experimental resume path (undocumented)
|
||||||
let resume_path = config.experimental_resume.clone();
|
let resume_path = config.experimental_resume.clone();
|
||||||
info!("resume_path: {resume_path:?}");
|
info!("resume_path: {resume_path:?}");
|
||||||
@@ -132,7 +137,7 @@ impl Codex {
|
|||||||
// Generate a unique ID for the lifetime of this Codex session.
|
// Generate a unique ID for the lifetime of this Codex session.
|
||||||
let session_id = Uuid::new_v4();
|
let session_id = Uuid::new_v4();
|
||||||
tokio::spawn(submission_loop(
|
tokio::spawn(submission_loop(
|
||||||
session_id, config, rx_sub, tx_event, ctrl_c,
|
session_id, config, auth, rx_sub, tx_event, ctrl_c,
|
||||||
));
|
));
|
||||||
let codex = Codex {
|
let codex = Codex {
|
||||||
next_id: AtomicU64::new(0),
|
next_id: AtomicU64::new(0),
|
||||||
@@ -525,6 +530,7 @@ impl AgentTask {
|
|||||||
async fn submission_loop(
|
async fn submission_loop(
|
||||||
mut session_id: Uuid,
|
mut session_id: Uuid,
|
||||||
config: Arc<Config>,
|
config: Arc<Config>,
|
||||||
|
auth: Option<CodexAuth>,
|
||||||
rx_sub: Receiver<Submission>,
|
rx_sub: Receiver<Submission>,
|
||||||
tx_event: Sender<Event>,
|
tx_event: Sender<Event>,
|
||||||
ctrl_c: Arc<Notify>,
|
ctrl_c: Arc<Notify>,
|
||||||
@@ -636,6 +642,7 @@ async fn submission_loop(
|
|||||||
|
|
||||||
let client = ModelClient::new(
|
let client = ModelClient::new(
|
||||||
config.clone(),
|
config.clone(),
|
||||||
|
auth.clone(),
|
||||||
provider.clone(),
|
provider.clone(),
|
||||||
model_reasoning_effort,
|
model_reasoning_effort,
|
||||||
model_reasoning_summary,
|
model_reasoning_summary,
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ use crate::config::Config;
|
|||||||
use crate::protocol::Event;
|
use crate::protocol::Event;
|
||||||
use crate::protocol::EventMsg;
|
use crate::protocol::EventMsg;
|
||||||
use crate::util::notify_on_sigint;
|
use crate::util::notify_on_sigint;
|
||||||
|
use codex_login::load_auth;
|
||||||
use tokio::sync::Notify;
|
use tokio::sync::Notify;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
@@ -25,11 +26,12 @@ pub struct CodexConversation {
|
|||||||
/// that callers can surface the information to the UI.
|
/// that callers can surface the information to the UI.
|
||||||
pub async fn init_codex(config: Config) -> anyhow::Result<CodexConversation> {
|
pub async fn init_codex(config: Config) -> anyhow::Result<CodexConversation> {
|
||||||
let ctrl_c = notify_on_sigint();
|
let ctrl_c = notify_on_sigint();
|
||||||
|
let auth = load_auth(&config.codex_home)?;
|
||||||
let CodexSpawnOk {
|
let CodexSpawnOk {
|
||||||
codex,
|
codex,
|
||||||
init_id,
|
init_id,
|
||||||
session_id,
|
session_id,
|
||||||
} = Codex::spawn(config, ctrl_c.clone()).await?;
|
} = Codex::spawn(config, auth, ctrl_c.clone()).await?;
|
||||||
|
|
||||||
// The first event must be `SessionInitialized`. Validate and forward it to
|
// The first event must be `SessionInitialized`. Validate and forward it to
|
||||||
// the caller so that they can display it in the conversation history.
|
// the caller so that they can display it in the conversation history.
|
||||||
|
|||||||
@@ -526,6 +526,7 @@ impl Config {
|
|||||||
.chatgpt_base_url
|
.chatgpt_base_url
|
||||||
.or(cfg.chatgpt_base_url)
|
.or(cfg.chatgpt_base_url)
|
||||||
.unwrap_or("https://chatgpt.com/backend-api/".to_string()),
|
.unwrap_or("https://chatgpt.com/backend-api/".to_string()),
|
||||||
|
|
||||||
experimental_resume,
|
experimental_resume,
|
||||||
include_plan_tool: include_plan_tool.unwrap_or(false),
|
include_plan_tool: include_plan_tool.unwrap_or(false),
|
||||||
};
|
};
|
||||||
@@ -794,7 +795,7 @@ disable_response_storage = true
|
|||||||
|
|
||||||
let openai_chat_completions_provider = ModelProviderInfo {
|
let openai_chat_completions_provider = ModelProviderInfo {
|
||||||
name: "OpenAI using Chat Completions".to_string(),
|
name: "OpenAI using Chat Completions".to_string(),
|
||||||
base_url: "https://api.openai.com/v1".to_string(),
|
base_url: Some("https://api.openai.com/v1".to_string()),
|
||||||
env_key: Some("OPENAI_API_KEY".to_string()),
|
env_key: Some("OPENAI_API_KEY".to_string()),
|
||||||
wire_api: crate::WireApi::Chat,
|
wire_api: crate::WireApi::Chat,
|
||||||
env_key_instructions: None,
|
env_key_instructions: None,
|
||||||
@@ -804,6 +805,7 @@ disable_response_storage = true
|
|||||||
request_max_retries: Some(4),
|
request_max_retries: Some(4),
|
||||||
stream_max_retries: Some(10),
|
stream_max_retries: Some(10),
|
||||||
stream_idle_timeout_ms: Some(300_000),
|
stream_idle_timeout_ms: Some(300_000),
|
||||||
|
requires_auth: false,
|
||||||
};
|
};
|
||||||
let model_provider_map = {
|
let model_provider_map = {
|
||||||
let mut model_provider_map = built_in_model_providers();
|
let mut model_provider_map = built_in_model_providers();
|
||||||
|
|||||||
@@ -30,8 +30,8 @@ mod message_history;
|
|||||||
mod model_provider_info;
|
mod model_provider_info;
|
||||||
pub use model_provider_info::ModelProviderInfo;
|
pub use model_provider_info::ModelProviderInfo;
|
||||||
pub use model_provider_info::WireApi;
|
pub use model_provider_info::WireApi;
|
||||||
|
pub use model_provider_info::built_in_model_providers;
|
||||||
mod models;
|
mod models;
|
||||||
pub mod openai_api_key;
|
|
||||||
mod openai_model_info;
|
mod openai_model_info;
|
||||||
mod openai_tools;
|
mod openai_tools;
|
||||||
pub mod plan_tool;
|
pub mod plan_tool;
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ use std::env::VarError;
|
|||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use crate::error::EnvVarError;
|
use crate::error::EnvVarError;
|
||||||
use crate::openai_api_key::get_openai_api_key;
|
|
||||||
|
|
||||||
/// Value for the `OpenAI-Originator` header that is sent with requests to
|
/// Value for the `OpenAI-Originator` header that is sent with requests to
|
||||||
/// OpenAI.
|
/// OpenAI.
|
||||||
@@ -30,7 +29,7 @@ const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
|
|||||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "lowercase")]
|
#[serde(rename_all = "lowercase")]
|
||||||
pub enum WireApi {
|
pub enum WireApi {
|
||||||
/// The experimental "Responses" API exposed by OpenAI at `/v1/responses`.
|
/// The Responses API exposed by OpenAI at `/v1/responses`.
|
||||||
Responses,
|
Responses,
|
||||||
|
|
||||||
/// Regular Chat Completions compatible with `/v1/chat/completions`.
|
/// Regular Chat Completions compatible with `/v1/chat/completions`.
|
||||||
@@ -44,7 +43,7 @@ pub struct ModelProviderInfo {
|
|||||||
/// Friendly display name.
|
/// Friendly display name.
|
||||||
pub name: String,
|
pub name: String,
|
||||||
/// Base URL for the provider's OpenAI-compatible API.
|
/// Base URL for the provider's OpenAI-compatible API.
|
||||||
pub base_url: String,
|
pub base_url: Option<String>,
|
||||||
/// Environment variable that stores the user's API key for this provider.
|
/// Environment variable that stores the user's API key for this provider.
|
||||||
pub env_key: Option<String>,
|
pub env_key: Option<String>,
|
||||||
|
|
||||||
@@ -78,6 +77,10 @@ pub struct ModelProviderInfo {
|
|||||||
/// Idle timeout (in milliseconds) to wait for activity on a streaming response before treating
|
/// Idle timeout (in milliseconds) to wait for activity on a streaming response before treating
|
||||||
/// the connection as lost.
|
/// the connection as lost.
|
||||||
pub stream_idle_timeout_ms: Option<u64>,
|
pub stream_idle_timeout_ms: Option<u64>,
|
||||||
|
|
||||||
|
/// Whether this provider requires some form of standard authentication (API key, ChatGPT token).
|
||||||
|
#[serde(default)]
|
||||||
|
pub requires_auth: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ModelProviderInfo {
|
impl ModelProviderInfo {
|
||||||
@@ -93,11 +96,11 @@ impl ModelProviderInfo {
|
|||||||
&'a self,
|
&'a self,
|
||||||
client: &'a reqwest::Client,
|
client: &'a reqwest::Client,
|
||||||
) -> crate::error::Result<reqwest::RequestBuilder> {
|
) -> crate::error::Result<reqwest::RequestBuilder> {
|
||||||
let api_key = self.api_key()?;
|
|
||||||
|
|
||||||
let url = self.get_full_url();
|
let url = self.get_full_url();
|
||||||
|
|
||||||
let mut builder = client.post(url);
|
let mut builder = client.post(url);
|
||||||
|
|
||||||
|
let api_key = self.api_key()?;
|
||||||
if let Some(key) = api_key {
|
if let Some(key) = api_key {
|
||||||
builder = builder.bearer_auth(key);
|
builder = builder.bearer_auth(key);
|
||||||
}
|
}
|
||||||
@@ -117,9 +120,15 @@ impl ModelProviderInfo {
|
|||||||
.join("&");
|
.join("&");
|
||||||
format!("?{full_params}")
|
format!("?{full_params}")
|
||||||
});
|
});
|
||||||
let base_url = &self.base_url;
|
let base_url = self
|
||||||
|
.base_url
|
||||||
|
.clone()
|
||||||
|
.unwrap_or("https://api.openai.com/v1".to_string());
|
||||||
|
|
||||||
match self.wire_api {
|
match self.wire_api {
|
||||||
WireApi::Responses => format!("{base_url}/responses{query_string}"),
|
WireApi::Responses => {
|
||||||
|
format!("{base_url}/responses{query_string}")
|
||||||
|
}
|
||||||
WireApi::Chat => format!("{base_url}/chat/completions{query_string}"),
|
WireApi::Chat => format!("{base_url}/chat/completions{query_string}"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -127,7 +136,10 @@ impl ModelProviderInfo {
|
|||||||
/// Apply provider-specific HTTP headers (both static and environment-based)
|
/// Apply provider-specific HTTP headers (both static and environment-based)
|
||||||
/// onto an existing `reqwest::RequestBuilder` and return the updated
|
/// onto an existing `reqwest::RequestBuilder` and return the updated
|
||||||
/// builder.
|
/// builder.
|
||||||
fn apply_http_headers(&self, mut builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
|
pub fn apply_http_headers(
|
||||||
|
&self,
|
||||||
|
mut builder: reqwest::RequestBuilder,
|
||||||
|
) -> reqwest::RequestBuilder {
|
||||||
if let Some(extra) = &self.http_headers {
|
if let Some(extra) = &self.http_headers {
|
||||||
for (k, v) in extra {
|
for (k, v) in extra {
|
||||||
builder = builder.header(k, v);
|
builder = builder.header(k, v);
|
||||||
@@ -152,11 +164,7 @@ impl ModelProviderInfo {
|
|||||||
fn api_key(&self) -> crate::error::Result<Option<String>> {
|
fn api_key(&self) -> crate::error::Result<Option<String>> {
|
||||||
match &self.env_key {
|
match &self.env_key {
|
||||||
Some(env_key) => {
|
Some(env_key) => {
|
||||||
let env_value = if env_key == crate::openai_api_key::OPENAI_API_KEY_ENV_VAR {
|
let env_value = std::env::var(env_key);
|
||||||
get_openai_api_key().map_or_else(|| Err(VarError::NotPresent), Ok)
|
|
||||||
} else {
|
|
||||||
std::env::var(env_key)
|
|
||||||
};
|
|
||||||
env_value
|
env_value
|
||||||
.and_then(|v| {
|
.and_then(|v| {
|
||||||
if v.trim().is_empty() {
|
if v.trim().is_empty() {
|
||||||
@@ -204,47 +212,51 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
|
|||||||
// providers are bundled with Codex CLI, so we only include the OpenAI
|
// providers are bundled with Codex CLI, so we only include the OpenAI
|
||||||
// provider by default. Users are encouraged to add to `model_providers`
|
// provider by default. Users are encouraged to add to `model_providers`
|
||||||
// in config.toml to add their own providers.
|
// in config.toml to add their own providers.
|
||||||
[
|
[(
|
||||||
(
|
"openai",
|
||||||
"openai",
|
P {
|
||||||
P {
|
name: "OpenAI".into(),
|
||||||
name: "OpenAI".into(),
|
// Allow users to override the default OpenAI endpoint by
|
||||||
// Allow users to override the default OpenAI endpoint by
|
// exporting `OPENAI_BASE_URL`. This is useful when pointing
|
||||||
// exporting `OPENAI_BASE_URL`. This is useful when pointing
|
// Codex at a proxy, mock server, or Azure-style deployment
|
||||||
// Codex at a proxy, mock server, or Azure-style deployment
|
// without requiring a full TOML override for the built-in
|
||||||
// without requiring a full TOML override for the built-in
|
// OpenAI provider.
|
||||||
// OpenAI provider.
|
base_url: std::env::var("OPENAI_BASE_URL")
|
||||||
base_url: std::env::var("OPENAI_BASE_URL")
|
.ok()
|
||||||
.ok()
|
.filter(|v| !v.trim().is_empty()),
|
||||||
.filter(|v| !v.trim().is_empty())
|
env_key: None,
|
||||||
.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
|
env_key_instructions: None,
|
||||||
env_key: Some("OPENAI_API_KEY".into()),
|
wire_api: WireApi::Responses,
|
||||||
env_key_instructions: Some("Create an API key (https://platform.openai.com) and export it as an environment variable.".into()),
|
query_params: None,
|
||||||
wire_api: WireApi::Responses,
|
http_headers: Some(
|
||||||
query_params: None,
|
[
|
||||||
http_headers: Some(
|
(
|
||||||
[
|
"originator".to_string(),
|
||||||
("originator".to_string(), OPENAI_ORIGINATOR_HEADER.to_string()),
|
OPENAI_ORIGINATOR_HEADER.to_string(),
|
||||||
("version".to_string(), env!("CARGO_PKG_VERSION").to_string()),
|
),
|
||||||
]
|
("version".to_string(), env!("CARGO_PKG_VERSION").to_string()),
|
||||||
.into_iter()
|
]
|
||||||
.collect(),
|
.into_iter()
|
||||||
),
|
.collect(),
|
||||||
env_http_headers: Some(
|
),
|
||||||
[
|
env_http_headers: Some(
|
||||||
("OpenAI-Organization".to_string(), "OPENAI_ORGANIZATION".to_string()),
|
[
|
||||||
("OpenAI-Project".to_string(), "OPENAI_PROJECT".to_string()),
|
(
|
||||||
]
|
"OpenAI-Organization".to_string(),
|
||||||
.into_iter()
|
"OPENAI_ORGANIZATION".to_string(),
|
||||||
.collect(),
|
),
|
||||||
),
|
("OpenAI-Project".to_string(), "OPENAI_PROJECT".to_string()),
|
||||||
// Use global defaults for retry/timeout unless overridden in config.toml.
|
]
|
||||||
request_max_retries: None,
|
.into_iter()
|
||||||
stream_max_retries: None,
|
.collect(),
|
||||||
stream_idle_timeout_ms: None,
|
),
|
||||||
},
|
// Use global defaults for retry/timeout unless overridden in config.toml.
|
||||||
),
|
request_max_retries: None,
|
||||||
]
|
stream_max_retries: None,
|
||||||
|
stream_idle_timeout_ms: None,
|
||||||
|
requires_auth: true,
|
||||||
|
},
|
||||||
|
)]
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(k, v)| (k.to_string(), v))
|
.map(|(k, v)| (k.to_string(), v))
|
||||||
.collect()
|
.collect()
|
||||||
@@ -264,7 +276,7 @@ base_url = "http://localhost:11434/v1"
|
|||||||
"#;
|
"#;
|
||||||
let expected_provider = ModelProviderInfo {
|
let expected_provider = ModelProviderInfo {
|
||||||
name: "Ollama".into(),
|
name: "Ollama".into(),
|
||||||
base_url: "http://localhost:11434/v1".into(),
|
base_url: Some("http://localhost:11434/v1".into()),
|
||||||
env_key: None,
|
env_key: None,
|
||||||
env_key_instructions: None,
|
env_key_instructions: None,
|
||||||
wire_api: WireApi::Chat,
|
wire_api: WireApi::Chat,
|
||||||
@@ -274,6 +286,7 @@ base_url = "http://localhost:11434/v1"
|
|||||||
request_max_retries: None,
|
request_max_retries: None,
|
||||||
stream_max_retries: None,
|
stream_max_retries: None,
|
||||||
stream_idle_timeout_ms: None,
|
stream_idle_timeout_ms: None,
|
||||||
|
requires_auth: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
||||||
@@ -290,7 +303,7 @@ query_params = { api-version = "2025-04-01-preview" }
|
|||||||
"#;
|
"#;
|
||||||
let expected_provider = ModelProviderInfo {
|
let expected_provider = ModelProviderInfo {
|
||||||
name: "Azure".into(),
|
name: "Azure".into(),
|
||||||
base_url: "https://xxxxx.openai.azure.com/openai".into(),
|
base_url: Some("https://xxxxx.openai.azure.com/openai".into()),
|
||||||
env_key: Some("AZURE_OPENAI_API_KEY".into()),
|
env_key: Some("AZURE_OPENAI_API_KEY".into()),
|
||||||
env_key_instructions: None,
|
env_key_instructions: None,
|
||||||
wire_api: WireApi::Chat,
|
wire_api: WireApi::Chat,
|
||||||
@@ -302,6 +315,7 @@ query_params = { api-version = "2025-04-01-preview" }
|
|||||||
request_max_retries: None,
|
request_max_retries: None,
|
||||||
stream_max_retries: None,
|
stream_max_retries: None,
|
||||||
stream_idle_timeout_ms: None,
|
stream_idle_timeout_ms: None,
|
||||||
|
requires_auth: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
||||||
@@ -319,7 +333,7 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
|
|||||||
"#;
|
"#;
|
||||||
let expected_provider = ModelProviderInfo {
|
let expected_provider = ModelProviderInfo {
|
||||||
name: "Example".into(),
|
name: "Example".into(),
|
||||||
base_url: "https://example.com".into(),
|
base_url: Some("https://example.com".into()),
|
||||||
env_key: Some("API_KEY".into()),
|
env_key: Some("API_KEY".into()),
|
||||||
env_key_instructions: None,
|
env_key_instructions: None,
|
||||||
wire_api: WireApi::Chat,
|
wire_api: WireApi::Chat,
|
||||||
@@ -333,6 +347,7 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
|
|||||||
request_max_retries: None,
|
request_max_retries: None,
|
||||||
stream_max_retries: None,
|
stream_max_retries: None,
|
||||||
stream_idle_timeout_ms: None,
|
stream_idle_timeout_ms: None,
|
||||||
|
requires_auth: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
||||||
|
|||||||
@@ -1,24 +0,0 @@
|
|||||||
use std::env;
|
|
||||||
use std::sync::LazyLock;
|
|
||||||
use std::sync::RwLock;
|
|
||||||
|
|
||||||
pub const OPENAI_API_KEY_ENV_VAR: &str = "OPENAI_API_KEY";
|
|
||||||
|
|
||||||
static OPENAI_API_KEY: LazyLock<RwLock<Option<String>>> = LazyLock::new(|| {
|
|
||||||
let val = env::var(OPENAI_API_KEY_ENV_VAR)
|
|
||||||
.ok()
|
|
||||||
.and_then(|s| if s.is_empty() { None } else { Some(s) });
|
|
||||||
RwLock::new(val)
|
|
||||||
});
|
|
||||||
|
|
||||||
pub fn get_openai_api_key() -> Option<String> {
|
|
||||||
#![allow(clippy::unwrap_used)]
|
|
||||||
OPENAI_API_KEY.read().unwrap().clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn set_openai_api_key(value: String) {
|
|
||||||
#![allow(clippy::unwrap_used)]
|
|
||||||
if !value.is_empty() {
|
|
||||||
*OPENAI_API_KEY.write().unwrap() = Some(value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,11 +1,19 @@
|
|||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
use chrono::Utc;
|
||||||
use codex_core::Codex;
|
use codex_core::Codex;
|
||||||
use codex_core::CodexSpawnOk;
|
use codex_core::CodexSpawnOk;
|
||||||
use codex_core::ModelProviderInfo;
|
use codex_core::ModelProviderInfo;
|
||||||
|
use codex_core::built_in_model_providers;
|
||||||
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||||
use codex_core::protocol::EventMsg;
|
use codex_core::protocol::EventMsg;
|
||||||
use codex_core::protocol::InputItem;
|
use codex_core::protocol::InputItem;
|
||||||
use codex_core::protocol::Op;
|
use codex_core::protocol::Op;
|
||||||
use codex_core::protocol::SessionConfiguredEvent;
|
use codex_core::protocol::SessionConfiguredEvent;
|
||||||
|
use codex_login::AuthDotJson;
|
||||||
|
use codex_login::AuthMode;
|
||||||
|
use codex_login::CodexAuth;
|
||||||
|
use codex_login::TokenData;
|
||||||
use core_test_support::load_default_config_for_test;
|
use core_test_support::load_default_config_for_test;
|
||||||
use core_test_support::load_sse_fixture_with_id;
|
use core_test_support::load_sse_fixture_with_id;
|
||||||
use core_test_support::wait_for_event;
|
use core_test_support::wait_for_event;
|
||||||
@@ -48,32 +56,23 @@ async fn includes_session_id_and_model_headers_in_request() {
|
|||||||
.await;
|
.await;
|
||||||
|
|
||||||
let model_provider = ModelProviderInfo {
|
let model_provider = ModelProviderInfo {
|
||||||
name: "openai".into(),
|
base_url: Some(format!("{}/v1", server.uri())),
|
||||||
base_url: format!("{}/v1", server.uri()),
|
..built_in_model_providers()["openai"].clone()
|
||||||
// Environment variable that should exist in the test environment.
|
|
||||||
// ModelClient will return an error if the environment variable for the
|
|
||||||
// provider is not set.
|
|
||||||
env_key: Some("PATH".into()),
|
|
||||||
env_key_instructions: None,
|
|
||||||
wire_api: codex_core::WireApi::Responses,
|
|
||||||
query_params: None,
|
|
||||||
http_headers: Some(
|
|
||||||
[("originator".to_string(), "codex_cli_rs".to_string())]
|
|
||||||
.into_iter()
|
|
||||||
.collect(),
|
|
||||||
),
|
|
||||||
env_http_headers: None,
|
|
||||||
request_max_retries: Some(0),
|
|
||||||
stream_max_retries: Some(0),
|
|
||||||
stream_idle_timeout_ms: None,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Init session
|
// Init session
|
||||||
let codex_home = TempDir::new().unwrap();
|
let codex_home = TempDir::new().unwrap();
|
||||||
let mut config = load_default_config_for_test(&codex_home);
|
let mut config = load_default_config_for_test(&codex_home);
|
||||||
config.model_provider = model_provider;
|
config.model_provider = model_provider;
|
||||||
|
|
||||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||||
let CodexSpawnOk { codex, .. } = Codex::spawn(config, ctrl_c.clone()).await.unwrap();
|
let CodexSpawnOk { codex, .. } = Codex::spawn(
|
||||||
|
config,
|
||||||
|
Some(CodexAuth::from_api_key("Test API Key".to_string())),
|
||||||
|
ctrl_c.clone(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
codex
|
codex
|
||||||
.submit(Op::UserInput {
|
.submit(Op::UserInput {
|
||||||
@@ -95,15 +94,20 @@ async fn includes_session_id_and_model_headers_in_request() {
|
|||||||
|
|
||||||
// get request from the server
|
// get request from the server
|
||||||
let request = &server.received_requests().await.unwrap()[0];
|
let request = &server.received_requests().await.unwrap()[0];
|
||||||
let request_body = request.headers.get("session_id").unwrap();
|
let request_session_id = request.headers.get("session_id").unwrap();
|
||||||
let originator = request.headers.get("originator").unwrap();
|
let request_originator = request.headers.get("originator").unwrap();
|
||||||
|
let request_authorization = request.headers.get("authorization").unwrap();
|
||||||
|
|
||||||
assert!(current_session_id.is_some());
|
assert!(current_session_id.is_some());
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
request_body.to_str().unwrap(),
|
request_session_id.to_str().unwrap(),
|
||||||
current_session_id.as_ref().unwrap()
|
current_session_id.as_ref().unwrap()
|
||||||
);
|
);
|
||||||
assert_eq!(originator.to_str().unwrap(), "codex_cli_rs");
|
assert_eq!(request_originator.to_str().unwrap(), "codex_cli_rs");
|
||||||
|
assert_eq!(
|
||||||
|
request_authorization.to_str().unwrap(),
|
||||||
|
"Bearer Test API Key"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
@@ -126,22 +130,9 @@ async fn includes_base_instructions_override_in_request() {
|
|||||||
.await;
|
.await;
|
||||||
|
|
||||||
let model_provider = ModelProviderInfo {
|
let model_provider = ModelProviderInfo {
|
||||||
name: "openai".into(),
|
base_url: Some(format!("{}/v1", server.uri())),
|
||||||
base_url: format!("{}/v1", server.uri()),
|
..built_in_model_providers()["openai"].clone()
|
||||||
// Environment variable that should exist in the test environment.
|
|
||||||
// ModelClient will return an error if the environment variable for the
|
|
||||||
// provider is not set.
|
|
||||||
env_key: Some("PATH".into()),
|
|
||||||
env_key_instructions: None,
|
|
||||||
wire_api: codex_core::WireApi::Responses,
|
|
||||||
query_params: None,
|
|
||||||
http_headers: None,
|
|
||||||
env_http_headers: None,
|
|
||||||
request_max_retries: Some(0),
|
|
||||||
stream_max_retries: Some(0),
|
|
||||||
stream_idle_timeout_ms: None,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let codex_home = TempDir::new().unwrap();
|
let codex_home = TempDir::new().unwrap();
|
||||||
let mut config = load_default_config_for_test(&codex_home);
|
let mut config = load_default_config_for_test(&codex_home);
|
||||||
|
|
||||||
@@ -149,7 +140,13 @@ async fn includes_base_instructions_override_in_request() {
|
|||||||
config.model_provider = model_provider;
|
config.model_provider = model_provider;
|
||||||
|
|
||||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||||
let CodexSpawnOk { codex, .. } = Codex::spawn(config, ctrl_c.clone()).await.unwrap();
|
let CodexSpawnOk { codex, .. } = Codex::spawn(
|
||||||
|
config,
|
||||||
|
Some(CodexAuth::from_api_key("Test API Key".to_string())),
|
||||||
|
ctrl_c.clone(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
codex
|
codex
|
||||||
.submit(Op::UserInput {
|
.submit(Op::UserInput {
|
||||||
@@ -172,3 +169,108 @@ async fn includes_base_instructions_override_in_request() {
|
|||||||
.contains("test instructions")
|
.contains("test instructions")
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn chatgpt_auth_sends_correct_request() {
|
||||||
|
#![allow(clippy::unwrap_used)]
|
||||||
|
|
||||||
|
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||||
|
println!(
|
||||||
|
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mock server
|
||||||
|
let server = MockServer::start().await;
|
||||||
|
|
||||||
|
// First request – must NOT include `previous_response_id`.
|
||||||
|
let first = ResponseTemplate::new(200)
|
||||||
|
.insert_header("content-type", "text/event-stream")
|
||||||
|
.set_body_raw(sse_completed("resp1"), "text/event-stream");
|
||||||
|
|
||||||
|
Mock::given(method("POST"))
|
||||||
|
.and(path("/api/codex/responses"))
|
||||||
|
.respond_with(first)
|
||||||
|
.expect(1)
|
||||||
|
.mount(&server)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let model_provider = ModelProviderInfo {
|
||||||
|
base_url: Some(format!("{}/api/codex", server.uri())),
|
||||||
|
..built_in_model_providers()["openai"].clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
// Init session
|
||||||
|
let codex_home = TempDir::new().unwrap();
|
||||||
|
let mut config = load_default_config_for_test(&codex_home);
|
||||||
|
config.model_provider = model_provider;
|
||||||
|
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||||
|
let CodexSpawnOk { codex, .. } = Codex::spawn(
|
||||||
|
config,
|
||||||
|
Some(auth_from_token("Access Token".to_string())),
|
||||||
|
ctrl_c.clone(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
codex
|
||||||
|
.submit(Op::UserInput {
|
||||||
|
items: vec![InputItem::Text {
|
||||||
|
text: "hello".into(),
|
||||||
|
}],
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let EventMsg::SessionConfigured(SessionConfiguredEvent { session_id, .. }) =
|
||||||
|
wait_for_event(&codex, |ev| matches!(ev, EventMsg::SessionConfigured(_))).await
|
||||||
|
else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
|
||||||
|
let current_session_id = Some(session_id.to_string());
|
||||||
|
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||||
|
|
||||||
|
// get request from the server
|
||||||
|
let request = &server.received_requests().await.unwrap()[0];
|
||||||
|
let request_session_id = request.headers.get("session_id").unwrap();
|
||||||
|
let request_originator = request.headers.get("originator").unwrap();
|
||||||
|
let request_authorization = request.headers.get("authorization").unwrap();
|
||||||
|
let request_body = request.body_json::<serde_json::Value>().unwrap();
|
||||||
|
|
||||||
|
assert!(current_session_id.is_some());
|
||||||
|
assert_eq!(
|
||||||
|
request_session_id.to_str().unwrap(),
|
||||||
|
current_session_id.as_ref().unwrap()
|
||||||
|
);
|
||||||
|
assert_eq!(request_originator.to_str().unwrap(), "codex_cli_rs");
|
||||||
|
assert_eq!(
|
||||||
|
request_authorization.to_str().unwrap(),
|
||||||
|
"Bearer Access Token"
|
||||||
|
);
|
||||||
|
assert!(!request_body["store"].as_bool().unwrap());
|
||||||
|
assert!(request_body["stream"].as_bool().unwrap());
|
||||||
|
assert_eq!(
|
||||||
|
request_body["include"][0].as_str().unwrap(),
|
||||||
|
"reasoning.encrypted_content"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn auth_from_token(id_token: String) -> CodexAuth {
|
||||||
|
CodexAuth::new(
|
||||||
|
None,
|
||||||
|
AuthMode::ChatGPT,
|
||||||
|
PathBuf::new(),
|
||||||
|
Some(AuthDotJson {
|
||||||
|
tokens: TokenData {
|
||||||
|
id_token,
|
||||||
|
access_token: "Access Token".to_string(),
|
||||||
|
refresh_token: "test".to_string(),
|
||||||
|
account_id: None,
|
||||||
|
},
|
||||||
|
last_refresh: Utc::now(),
|
||||||
|
openai_api_key: None,
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ async fn spawn_codex() -> Result<Codex, CodexErr> {
|
|||||||
config.model_provider.request_max_retries = Some(2);
|
config.model_provider.request_max_retries = Some(2);
|
||||||
config.model_provider.stream_max_retries = Some(2);
|
config.model_provider.stream_max_retries = Some(2);
|
||||||
let CodexSpawnOk { codex: agent, .. } =
|
let CodexSpawnOk { codex: agent, .. } =
|
||||||
Codex::spawn(config, std::sync::Arc::new(Notify::new())).await?;
|
Codex::spawn(config, None, std::sync::Arc::new(Notify::new())).await?;
|
||||||
|
|
||||||
Ok(agent)
|
Ok(agent)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
|||||||
use codex_core::protocol::EventMsg;
|
use codex_core::protocol::EventMsg;
|
||||||
use codex_core::protocol::InputItem;
|
use codex_core::protocol::InputItem;
|
||||||
use codex_core::protocol::Op;
|
use codex_core::protocol::Op;
|
||||||
|
use codex_login::CodexAuth;
|
||||||
use core_test_support::load_default_config_for_test;
|
use core_test_support::load_default_config_for_test;
|
||||||
use core_test_support::load_sse_fixture;
|
use core_test_support::load_sse_fixture;
|
||||||
use core_test_support::load_sse_fixture_with_id;
|
use core_test_support::load_sse_fixture_with_id;
|
||||||
@@ -75,7 +76,7 @@ async fn retries_on_early_close() {
|
|||||||
|
|
||||||
let model_provider = ModelProviderInfo {
|
let model_provider = ModelProviderInfo {
|
||||||
name: "openai".into(),
|
name: "openai".into(),
|
||||||
base_url: format!("{}/v1", server.uri()),
|
base_url: Some(format!("{}/v1", server.uri())),
|
||||||
// Environment variable that should exist in the test environment.
|
// Environment variable that should exist in the test environment.
|
||||||
// ModelClient will return an error if the environment variable for the
|
// ModelClient will return an error if the environment variable for the
|
||||||
// provider is not set.
|
// provider is not set.
|
||||||
@@ -89,13 +90,20 @@ async fn retries_on_early_close() {
|
|||||||
request_max_retries: Some(0),
|
request_max_retries: Some(0),
|
||||||
stream_max_retries: Some(1),
|
stream_max_retries: Some(1),
|
||||||
stream_idle_timeout_ms: Some(2000),
|
stream_idle_timeout_ms: Some(2000),
|
||||||
|
requires_auth: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||||
let codex_home = TempDir::new().unwrap();
|
let codex_home = TempDir::new().unwrap();
|
||||||
let mut config = load_default_config_for_test(&codex_home);
|
let mut config = load_default_config_for_test(&codex_home);
|
||||||
config.model_provider = model_provider;
|
config.model_provider = model_provider;
|
||||||
let CodexSpawnOk { codex, .. } = Codex::spawn(config, ctrl_c).await.unwrap();
|
let CodexSpawnOk { codex, .. } = Codex::spawn(
|
||||||
|
config,
|
||||||
|
Some(CodexAuth::from_api_key("Test API Key".to_string())),
|
||||||
|
ctrl_c,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
codex
|
codex
|
||||||
.submit(Op::UserInput {
|
.submit(Op::UserInput {
|
||||||
|
|||||||
@@ -1,20 +1,152 @@
|
|||||||
use chrono::DateTime;
|
use chrono::DateTime;
|
||||||
|
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
|
use std::env;
|
||||||
use std::fs::OpenOptions;
|
use std::fs::OpenOptions;
|
||||||
use std::io::Read;
|
use std::io::Read;
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
use std::os::unix::fs::OpenOptionsExt;
|
use std::os::unix::fs::OpenOptionsExt;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
use std::path::PathBuf;
|
||||||
use std::process::Stdio;
|
use std::process::Stdio;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::Mutex;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::process::Command;
|
use tokio::process::Command;
|
||||||
|
|
||||||
const SOURCE_FOR_PYTHON_SERVER: &str = include_str!("./login_with_chatgpt.py");
|
const SOURCE_FOR_PYTHON_SERVER: &str = include_str!("./login_with_chatgpt.py");
|
||||||
|
|
||||||
const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
|
const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
|
||||||
|
const OPENAI_API_KEY_ENV_VAR: &str = "OPENAI_API_KEY";
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, PartialEq)]
|
||||||
|
pub enum AuthMode {
|
||||||
|
ApiKey,
|
||||||
|
ChatGPT,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct CodexAuth {
|
||||||
|
pub api_key: Option<String>,
|
||||||
|
pub mode: AuthMode,
|
||||||
|
auth_dot_json: Arc<Mutex<Option<AuthDotJson>>>,
|
||||||
|
auth_file: PathBuf,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialEq for CodexAuth {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
self.mode == other.mode
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CodexAuth {
|
||||||
|
pub fn new(
|
||||||
|
api_key: Option<String>,
|
||||||
|
mode: AuthMode,
|
||||||
|
auth_file: PathBuf,
|
||||||
|
auth_dot_json: Option<AuthDotJson>,
|
||||||
|
) -> Self {
|
||||||
|
let auth_dot_json = Arc::new(Mutex::new(auth_dot_json));
|
||||||
|
Self {
|
||||||
|
api_key,
|
||||||
|
mode,
|
||||||
|
auth_file,
|
||||||
|
auth_dot_json,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_api_key(api_key: String) -> Self {
|
||||||
|
Self {
|
||||||
|
api_key: Some(api_key),
|
||||||
|
mode: AuthMode::ApiKey,
|
||||||
|
auth_file: PathBuf::new(),
|
||||||
|
auth_dot_json: Arc::new(Mutex::new(None)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_token_data(&self) -> Result<TokenData, std::io::Error> {
|
||||||
|
#[expect(clippy::unwrap_used)]
|
||||||
|
let auth_dot_json = self.auth_dot_json.lock().unwrap().clone();
|
||||||
|
|
||||||
|
match auth_dot_json {
|
||||||
|
Some(auth_dot_json) => {
|
||||||
|
if auth_dot_json.last_refresh < Utc::now() - chrono::Duration::days(28) {
|
||||||
|
let refresh_response = tokio::time::timeout(
|
||||||
|
Duration::from_secs(60),
|
||||||
|
try_refresh_token(auth_dot_json.tokens.refresh_token.clone()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|_| {
|
||||||
|
std::io::Error::other("timed out while refreshing OpenAI API key")
|
||||||
|
})?
|
||||||
|
.map_err(std::io::Error::other)?;
|
||||||
|
|
||||||
|
let updated_auth_dot_json = update_tokens(
|
||||||
|
&self.auth_file,
|
||||||
|
refresh_response.id_token,
|
||||||
|
refresh_response.access_token,
|
||||||
|
refresh_response.refresh_token,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
#[expect(clippy::unwrap_used)]
|
||||||
|
let mut auth_dot_json = self.auth_dot_json.lock().unwrap();
|
||||||
|
*auth_dot_json = Some(updated_auth_dot_json);
|
||||||
|
}
|
||||||
|
Ok(auth_dot_json.tokens.clone())
|
||||||
|
}
|
||||||
|
None => Err(std::io::Error::other("Token data is not available.")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_token(&self) -> Result<String, std::io::Error> {
|
||||||
|
match self.mode {
|
||||||
|
AuthMode::ApiKey => Ok(self.api_key.clone().unwrap_or_default()),
|
||||||
|
AuthMode::ChatGPT => {
|
||||||
|
let id_token = self.get_token_data().await?.access_token;
|
||||||
|
|
||||||
|
Ok(id_token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Loads the available auth information from the auth.json or OPENAI_API_KEY environment variable.
|
||||||
|
pub fn load_auth(codex_home: &Path) -> std::io::Result<Option<CodexAuth>> {
|
||||||
|
let auth_file = codex_home.join("auth.json");
|
||||||
|
|
||||||
|
let auth_dot_json = try_read_auth_json(&auth_file).ok();
|
||||||
|
|
||||||
|
let auth_json_api_key = auth_dot_json
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|a| a.openai_api_key.clone())
|
||||||
|
.filter(|s| !s.is_empty());
|
||||||
|
|
||||||
|
let openai_api_key = env::var(OPENAI_API_KEY_ENV_VAR)
|
||||||
|
.ok()
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
|
.or(auth_json_api_key);
|
||||||
|
|
||||||
|
if openai_api_key.is_none() && auth_dot_json.is_none() {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mode = if openai_api_key.is_some() {
|
||||||
|
AuthMode::ApiKey
|
||||||
|
} else {
|
||||||
|
AuthMode::ChatGPT
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Some(CodexAuth {
|
||||||
|
api_key: openai_api_key,
|
||||||
|
mode,
|
||||||
|
auth_file,
|
||||||
|
auth_dot_json: Arc::new(Mutex::new(auth_dot_json)),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
/// Run `python3 -c {{SOURCE_FOR_PYTHON_SERVER}}` with the CODEX_HOME
|
/// Run `python3 -c {{SOURCE_FOR_PYTHON_SERVER}}` with the CODEX_HOME
|
||||||
/// environment variable set to the provided `codex_home` path. If the
|
/// environment variable set to the provided `codex_home` path. If the
|
||||||
@@ -25,14 +157,12 @@ const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
|
|||||||
/// If `capture_output` is true, the subprocess's output will be captured and
|
/// If `capture_output` is true, the subprocess's output will be captured and
|
||||||
/// recorded in memory. Otherwise, the subprocess's output will be sent to the
|
/// recorded in memory. Otherwise, the subprocess's output will be sent to the
|
||||||
/// current process's stdout/stderr.
|
/// current process's stdout/stderr.
|
||||||
pub async fn login_with_chatgpt(
|
pub async fn login_with_chatgpt(codex_home: &Path, capture_output: bool) -> std::io::Result<()> {
|
||||||
codex_home: &Path,
|
|
||||||
capture_output: bool,
|
|
||||||
) -> std::io::Result<String> {
|
|
||||||
let child = Command::new("python3")
|
let child = Command::new("python3")
|
||||||
.arg("-c")
|
.arg("-c")
|
||||||
.arg(SOURCE_FOR_PYTHON_SERVER)
|
.arg(SOURCE_FOR_PYTHON_SERVER)
|
||||||
.env("CODEX_HOME", codex_home)
|
.env("CODEX_HOME", codex_home)
|
||||||
|
.env("CODEX_CLIENT_ID", CLIENT_ID)
|
||||||
.stdin(Stdio::null())
|
.stdin(Stdio::null())
|
||||||
.stdout(if capture_output {
|
.stdout(if capture_output {
|
||||||
Stdio::piped()
|
Stdio::piped()
|
||||||
@@ -48,7 +178,7 @@ pub async fn login_with_chatgpt(
|
|||||||
|
|
||||||
let output = child.wait_with_output().await?;
|
let output = child.wait_with_output().await?;
|
||||||
if output.status.success() {
|
if output.status.success() {
|
||||||
try_read_openai_api_key(codex_home).await
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||||
Err(std::io::Error::other(format!(
|
Err(std::io::Error::other(format!(
|
||||||
@@ -57,65 +187,54 @@ pub async fn login_with_chatgpt(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Attempt to read the `OPENAI_API_KEY` from the `auth.json` file in the given
|
|
||||||
/// `CODEX_HOME` directory, refreshing it, if necessary.
|
|
||||||
pub async fn try_read_openai_api_key(codex_home: &Path) -> std::io::Result<String> {
|
|
||||||
let auth_dot_json = try_read_auth_json(codex_home).await?;
|
|
||||||
Ok(auth_dot_json.openai_api_key)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Attempt to read and refresh the `auth.json` file in the given `CODEX_HOME` directory.
|
/// Attempt to read and refresh the `auth.json` file in the given `CODEX_HOME` directory.
|
||||||
/// Returns the full AuthDotJson structure after refreshing if necessary.
|
/// Returns the full AuthDotJson structure after refreshing if necessary.
|
||||||
pub async fn try_read_auth_json(codex_home: &Path) -> std::io::Result<AuthDotJson> {
|
pub fn try_read_auth_json(auth_file: &Path) -> std::io::Result<AuthDotJson> {
|
||||||
let auth_path = codex_home.join("auth.json");
|
let mut file = std::fs::File::open(auth_file)?;
|
||||||
let mut file = std::fs::File::open(&auth_path)?;
|
|
||||||
let mut contents = String::new();
|
let mut contents = String::new();
|
||||||
file.read_to_string(&mut contents)?;
|
file.read_to_string(&mut contents)?;
|
||||||
let auth_dot_json: AuthDotJson = serde_json::from_str(&contents)?;
|
let auth_dot_json: AuthDotJson = serde_json::from_str(&contents)?;
|
||||||
|
|
||||||
if is_expired(&auth_dot_json) {
|
Ok(auth_dot_json)
|
||||||
let refresh_response =
|
}
|
||||||
tokio::time::timeout(Duration::from_secs(60), try_refresh_token(&auth_dot_json))
|
|
||||||
.await
|
|
||||||
.map_err(|_| std::io::Error::other("timed out while refreshing OpenAI API key"))?
|
|
||||||
.map_err(std::io::Error::other)?;
|
|
||||||
let mut auth_dot_json = auth_dot_json;
|
|
||||||
auth_dot_json.tokens.id_token = refresh_response.id_token;
|
|
||||||
if let Some(refresh_token) = refresh_response.refresh_token {
|
|
||||||
auth_dot_json.tokens.refresh_token = refresh_token;
|
|
||||||
}
|
|
||||||
auth_dot_json.last_refresh = Utc::now();
|
|
||||||
|
|
||||||
let mut options = OpenOptions::new();
|
async fn update_tokens(
|
||||||
options.truncate(true).write(true).create(true);
|
auth_file: &Path,
|
||||||
#[cfg(unix)]
|
id_token: String,
|
||||||
{
|
access_token: Option<String>,
|
||||||
options.mode(0o600);
|
refresh_token: Option<String>,
|
||||||
}
|
) -> std::io::Result<AuthDotJson> {
|
||||||
|
let mut options = OpenOptions::new();
|
||||||
let json_data = serde_json::to_string(&auth_dot_json)?;
|
options.truncate(true).write(true).create(true);
|
||||||
{
|
#[cfg(unix)]
|
||||||
let mut file = options.open(&auth_path)?;
|
{
|
||||||
file.write_all(json_data.as_bytes())?;
|
options.mode(0o600);
|
||||||
file.flush()?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(auth_dot_json)
|
|
||||||
} else {
|
|
||||||
Ok(auth_dot_json)
|
|
||||||
}
|
}
|
||||||
|
let mut auth_dot_json = try_read_auth_json(auth_file)?;
|
||||||
|
|
||||||
|
auth_dot_json.tokens.id_token = id_token.to_string();
|
||||||
|
if let Some(access_token) = access_token {
|
||||||
|
auth_dot_json.tokens.access_token = access_token.to_string();
|
||||||
|
}
|
||||||
|
if let Some(refresh_token) = refresh_token {
|
||||||
|
auth_dot_json.tokens.refresh_token = refresh_token.to_string();
|
||||||
|
}
|
||||||
|
auth_dot_json.last_refresh = Utc::now();
|
||||||
|
|
||||||
|
let json_data = serde_json::to_string_pretty(&auth_dot_json)?;
|
||||||
|
{
|
||||||
|
let mut file = options.open(auth_file)?;
|
||||||
|
file.write_all(json_data.as_bytes())?;
|
||||||
|
file.flush()?;
|
||||||
|
}
|
||||||
|
Ok(auth_dot_json)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn is_expired(auth_dot_json: &AuthDotJson) -> bool {
|
async fn try_refresh_token(refresh_token: String) -> std::io::Result<RefreshResponse> {
|
||||||
let last_refresh = auth_dot_json.last_refresh;
|
|
||||||
last_refresh < Utc::now() - chrono::Duration::days(28)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn try_refresh_token(auth_dot_json: &AuthDotJson) -> std::io::Result<RefreshResponse> {
|
|
||||||
let refresh_request = RefreshRequest {
|
let refresh_request = RefreshRequest {
|
||||||
client_id: CLIENT_ID,
|
client_id: CLIENT_ID,
|
||||||
grant_type: "refresh_token",
|
grant_type: "refresh_token",
|
||||||
refresh_token: auth_dot_json.tokens.refresh_token.clone(),
|
refresh_token,
|
||||||
scope: "openid profile email",
|
scope: "openid profile email",
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -150,24 +269,25 @@ struct RefreshRequest {
|
|||||||
scope: &'static str,
|
scope: &'static str,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize, Clone)]
|
||||||
struct RefreshResponse {
|
struct RefreshResponse {
|
||||||
id_token: String,
|
id_token: String,
|
||||||
|
access_token: Option<String>,
|
||||||
refresh_token: Option<String>,
|
refresh_token: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Expected structure for $CODEX_HOME/auth.json.
|
/// Expected structure for $CODEX_HOME/auth.json.
|
||||||
#[derive(Deserialize, Serialize)]
|
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
|
||||||
pub struct AuthDotJson {
|
pub struct AuthDotJson {
|
||||||
#[serde(rename = "OPENAI_API_KEY")]
|
#[serde(rename = "OPENAI_API_KEY")]
|
||||||
pub openai_api_key: String,
|
pub openai_api_key: Option<String>,
|
||||||
|
|
||||||
pub tokens: TokenData,
|
pub tokens: TokenData,
|
||||||
|
|
||||||
pub last_refresh: DateTime<Utc>,
|
pub last_refresh: DateTime<Utc>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize, Serialize, Clone)]
|
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
|
||||||
pub struct TokenData {
|
pub struct TokenData {
|
||||||
/// This is a JWT.
|
/// This is a JWT.
|
||||||
pub id_token: String,
|
pub id_token: String,
|
||||||
@@ -177,5 +297,5 @@ pub struct TokenData {
|
|||||||
|
|
||||||
pub refresh_token: String,
|
pub refresh_token: String,
|
||||||
|
|
||||||
pub account_id: String,
|
pub account_id: Option<String>,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -41,7 +41,6 @@ from typing import Any, Dict # for type hints
|
|||||||
REQUIRED_PORT = 1455
|
REQUIRED_PORT = 1455
|
||||||
URL_BASE = f"http://localhost:{REQUIRED_PORT}"
|
URL_BASE = f"http://localhost:{REQUIRED_PORT}"
|
||||||
DEFAULT_ISSUER = "https://auth.openai.com"
|
DEFAULT_ISSUER = "https://auth.openai.com"
|
||||||
DEFAULT_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
|
||||||
|
|
||||||
EXIT_CODE_WHEN_ADDRESS_ALREADY_IN_USE = 13
|
EXIT_CODE_WHEN_ADDRESS_ALREADY_IN_USE = 13
|
||||||
|
|
||||||
@@ -58,7 +57,7 @@ class TokenData:
|
|||||||
class AuthBundle:
|
class AuthBundle:
|
||||||
"""Aggregates authentication data produced after successful OAuth flow."""
|
"""Aggregates authentication data produced after successful OAuth flow."""
|
||||||
|
|
||||||
api_key: str
|
api_key: str | None
|
||||||
token_data: TokenData
|
token_data: TokenData
|
||||||
last_refresh: str
|
last_refresh: str
|
||||||
|
|
||||||
@@ -78,12 +77,18 @@ def main() -> None:
|
|||||||
eprint("ERROR: CODEX_HOME environment variable is not set")
|
eprint("ERROR: CODEX_HOME environment variable is not set")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
client_id = os.getenv("CODEX_CLIENT_ID")
|
||||||
|
if not client_id:
|
||||||
|
eprint("ERROR: CODEX_CLIENT_ID environment variable is not set")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
# Spawn server.
|
# Spawn server.
|
||||||
try:
|
try:
|
||||||
httpd = _ApiKeyHTTPServer(
|
httpd = _ApiKeyHTTPServer(
|
||||||
("127.0.0.1", REQUIRED_PORT),
|
("127.0.0.1", REQUIRED_PORT),
|
||||||
_ApiKeyHTTPHandler,
|
_ApiKeyHTTPHandler,
|
||||||
codex_home=codex_home,
|
codex_home=codex_home,
|
||||||
|
client_id=client_id,
|
||||||
verbose=args.verbose,
|
verbose=args.verbose,
|
||||||
)
|
)
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
@@ -157,7 +162,7 @@ class _ApiKeyHTTPHandler(http.server.BaseHTTPRequestHandler):
|
|||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
auth_bundle, success_url = self._exchange_code_for_api_key(code)
|
auth_bundle, success_url = self._exchange_code(code)
|
||||||
except Exception as exc: # noqa: BLE001 – propagate to client
|
except Exception as exc: # noqa: BLE001 – propagate to client
|
||||||
self.send_error(500, f"Token exchange failed: {exc}")
|
self.send_error(500, f"Token exchange failed: {exc}")
|
||||||
return
|
return
|
||||||
@@ -211,68 +216,22 @@ class _ApiKeyHTTPHandler(http.server.BaseHTTPRequestHandler):
|
|||||||
if getattr(self.server, "verbose", False): # type: ignore[attr-defined]
|
if getattr(self.server, "verbose", False): # type: ignore[attr-defined]
|
||||||
super().log_message(fmt, *args)
|
super().log_message(fmt, *args)
|
||||||
|
|
||||||
def _exchange_code_for_api_key(self, code: str) -> tuple[AuthBundle, str]:
|
def _obtain_api_key(
|
||||||
"""Perform token + token-exchange to obtain an OpenAI API key.
|
self,
|
||||||
|
token_claims: Dict[str, Any],
|
||||||
|
access_claims: Dict[str, Any],
|
||||||
|
token_data: TokenData,
|
||||||
|
) -> tuple[str | None, str | None]:
|
||||||
|
"""Obtain an API key from the auth service.
|
||||||
|
|
||||||
Returns (AuthBundle, success_url).
|
Returns (api_key, success_url) if successful, None otherwise.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
token_endpoint = f"{self.server.issuer}/oauth/token"
|
|
||||||
|
|
||||||
# 1. Authorization-code -> (id_token, access_token, refresh_token)
|
|
||||||
data = urllib.parse.urlencode(
|
|
||||||
{
|
|
||||||
"grant_type": "authorization_code",
|
|
||||||
"code": code,
|
|
||||||
"redirect_uri": self.server.redirect_uri,
|
|
||||||
"client_id": self.server.client_id,
|
|
||||||
"code_verifier": self.server.pkce.code_verifier,
|
|
||||||
}
|
|
||||||
).encode()
|
|
||||||
|
|
||||||
token_data: TokenData
|
|
||||||
|
|
||||||
with urllib.request.urlopen(
|
|
||||||
urllib.request.Request(
|
|
||||||
token_endpoint,
|
|
||||||
data=data,
|
|
||||||
method="POST",
|
|
||||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
|
||||||
)
|
|
||||||
) as resp:
|
|
||||||
payload = json.loads(resp.read().decode())
|
|
||||||
|
|
||||||
# Extract chatgpt_account_id from id_token
|
|
||||||
id_token_parts = payload["id_token"].split(".")
|
|
||||||
if len(id_token_parts) != 3:
|
|
||||||
raise ValueError("Invalid ID token")
|
|
||||||
id_token_claims = _decode_jwt_segment(id_token_parts[1])
|
|
||||||
auth_claims = id_token_claims.get("https://api.openai.com/auth", {})
|
|
||||||
chatgpt_account_id = auth_claims.get("chatgpt_account_id", "")
|
|
||||||
|
|
||||||
token_data = TokenData(
|
|
||||||
id_token=payload["id_token"],
|
|
||||||
access_token=payload["access_token"],
|
|
||||||
refresh_token=payload["refresh_token"],
|
|
||||||
account_id=chatgpt_account_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
access_token_parts = token_data.access_token.split(".")
|
|
||||||
if len(access_token_parts) != 3:
|
|
||||||
raise ValueError("Invalid access token")
|
|
||||||
|
|
||||||
access_token_claims = _decode_jwt_segment(access_token_parts[1])
|
|
||||||
|
|
||||||
token_claims = id_token_claims.get("https://api.openai.com/auth", {})
|
|
||||||
access_claims = access_token_claims.get("https://api.openai.com/auth", {})
|
|
||||||
|
|
||||||
org_id = token_claims.get("organization_id")
|
org_id = token_claims.get("organization_id")
|
||||||
if not org_id:
|
|
||||||
raise ValueError("Missing organization in id_token claims")
|
|
||||||
|
|
||||||
project_id = token_claims.get("project_id")
|
project_id = token_claims.get("project_id")
|
||||||
if not project_id:
|
|
||||||
raise ValueError("Missing project in id_token claims")
|
if not org_id or not project_id:
|
||||||
|
return (None, None)
|
||||||
|
|
||||||
random_id = secrets.token_hex(6)
|
random_id = secrets.token_hex(6)
|
||||||
|
|
||||||
@@ -292,7 +251,7 @@ class _ApiKeyHTTPHandler(http.server.BaseHTTPRequestHandler):
|
|||||||
exchanged_access_token: str
|
exchanged_access_token: str
|
||||||
with urllib.request.urlopen(
|
with urllib.request.urlopen(
|
||||||
urllib.request.Request(
|
urllib.request.Request(
|
||||||
token_endpoint,
|
self.server.token_endpoint,
|
||||||
data=exchange_data,
|
data=exchange_data,
|
||||||
method="POST",
|
method="POST",
|
||||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||||
@@ -340,6 +299,65 @@ class _ApiKeyHTTPHandler(http.server.BaseHTTPRequestHandler):
|
|||||||
except Exception as exc: # pragma: no cover – best-effort only
|
except Exception as exc: # pragma: no cover – best-effort only
|
||||||
eprint(f"Unable to redeem ChatGPT subscriber API credits: {exc}")
|
eprint(f"Unable to redeem ChatGPT subscriber API credits: {exc}")
|
||||||
|
|
||||||
|
return (exchanged_access_token, success_url)
|
||||||
|
|
||||||
|
def _exchange_code(self, code: str) -> tuple[AuthBundle, str]:
|
||||||
|
"""Perform token + token-exchange to obtain an OpenAI API key.
|
||||||
|
|
||||||
|
Returns (AuthBundle, success_url).
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 1. Authorization-code -> (id_token, access_token, refresh_token)
|
||||||
|
data = urllib.parse.urlencode(
|
||||||
|
{
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"code": code,
|
||||||
|
"redirect_uri": self.server.redirect_uri,
|
||||||
|
"client_id": self.server.client_id,
|
||||||
|
"code_verifier": self.server.pkce.code_verifier,
|
||||||
|
}
|
||||||
|
).encode()
|
||||||
|
|
||||||
|
token_data: TokenData
|
||||||
|
|
||||||
|
with urllib.request.urlopen(
|
||||||
|
urllib.request.Request(
|
||||||
|
self.server.token_endpoint,
|
||||||
|
data=data,
|
||||||
|
method="POST",
|
||||||
|
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||||
|
)
|
||||||
|
) as resp:
|
||||||
|
payload = json.loads(resp.read().decode())
|
||||||
|
|
||||||
|
# Extract chatgpt_account_id from id_token
|
||||||
|
id_token_parts = payload["id_token"].split(".")
|
||||||
|
if len(id_token_parts) != 3:
|
||||||
|
raise ValueError("Invalid ID token")
|
||||||
|
id_token_claims = _decode_jwt_segment(id_token_parts[1])
|
||||||
|
auth_claims = id_token_claims.get("https://api.openai.com/auth", {})
|
||||||
|
chatgpt_account_id = auth_claims.get("chatgpt_account_id", "")
|
||||||
|
|
||||||
|
token_data = TokenData(
|
||||||
|
id_token=payload["id_token"],
|
||||||
|
access_token=payload["access_token"],
|
||||||
|
refresh_token=payload["refresh_token"],
|
||||||
|
account_id=chatgpt_account_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
access_token_parts = token_data.access_token.split(".")
|
||||||
|
if len(access_token_parts) != 3:
|
||||||
|
raise ValueError("Invalid access token")
|
||||||
|
|
||||||
|
access_token_claims = _decode_jwt_segment(access_token_parts[1])
|
||||||
|
|
||||||
|
token_claims = id_token_claims.get("https://api.openai.com/auth", {})
|
||||||
|
access_claims = access_token_claims.get("https://api.openai.com/auth", {})
|
||||||
|
|
||||||
|
exchanged_access_token, success_url = self._obtain_api_key(
|
||||||
|
token_claims, access_claims, token_data
|
||||||
|
)
|
||||||
|
|
||||||
# Persist refresh_token/id_token for future use (redeem credits etc.)
|
# Persist refresh_token/id_token for future use (redeem credits etc.)
|
||||||
last_refresh_str = (
|
last_refresh_str = (
|
||||||
datetime.datetime.now(datetime.timezone.utc)
|
datetime.datetime.now(datetime.timezone.utc)
|
||||||
@@ -353,7 +371,7 @@ class _ApiKeyHTTPHandler(http.server.BaseHTTPRequestHandler):
|
|||||||
last_refresh=last_refresh_str,
|
last_refresh=last_refresh_str,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (auth_bundle, success_url)
|
return (auth_bundle, success_url or f"{URL_BASE}/success")
|
||||||
|
|
||||||
def request_shutdown(self) -> None:
|
def request_shutdown(self) -> None:
|
||||||
# shutdown() must be invoked from another thread to avoid
|
# shutdown() must be invoked from another thread to avoid
|
||||||
@@ -413,6 +431,7 @@ class _ApiKeyHTTPServer(http.server.HTTPServer):
|
|||||||
request_handler_class: type[http.server.BaseHTTPRequestHandler],
|
request_handler_class: type[http.server.BaseHTTPRequestHandler],
|
||||||
*,
|
*,
|
||||||
codex_home: str,
|
codex_home: str,
|
||||||
|
client_id: str,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(server_address, request_handler_class, bind_and_activate=True)
|
super().__init__(server_address, request_handler_class, bind_and_activate=True)
|
||||||
@@ -422,7 +441,8 @@ class _ApiKeyHTTPServer(http.server.HTTPServer):
|
|||||||
self.verbose: bool = verbose
|
self.verbose: bool = verbose
|
||||||
|
|
||||||
self.issuer: str = DEFAULT_ISSUER
|
self.issuer: str = DEFAULT_ISSUER
|
||||||
self.client_id: str = DEFAULT_CLIENT_ID
|
self.token_endpoint: str = f"{self.issuer}/oauth/token"
|
||||||
|
self.client_id: str = client_id
|
||||||
port = server_address[1]
|
port = server_address[1]
|
||||||
self.redirect_uri: str = f"http://localhost:{port}/auth/callback"
|
self.redirect_uri: str = f"http://localhost:{port}/auth/callback"
|
||||||
self.pkce: PkceCodes = _generate_pkce()
|
self.pkce: PkceCodes = _generate_pkce()
|
||||||
@@ -581,8 +601,8 @@ def maybe_redeem_credits(
|
|||||||
granted = redeem_data.get("granted_chatgpt_subscriber_api_credits", 0)
|
granted = redeem_data.get("granted_chatgpt_subscriber_api_credits", 0)
|
||||||
if granted and granted > 0:
|
if granted and granted > 0:
|
||||||
eprint(
|
eprint(
|
||||||
f"""Thanks for being a ChatGPT {'Plus' if plan_type=='plus' else 'Pro'} subscriber!
|
f"""Thanks for being a ChatGPT {"Plus" if plan_type == "plus" else "Pro"} subscriber!
|
||||||
If you haven't already redeemed, you should receive {'$5' if plan_type=='plus' else '$50'} in API credits.
|
If you haven't already redeemed, you should receive {"$5" if plan_type == "plus" else "$50"} in API credits.
|
||||||
|
|
||||||
Credits: https://platform.openai.com/settings/organization/billing/credit-grants
|
Credits: https://platform.openai.com/settings/organization/billing/credit-grants
|
||||||
More info: https://help.openai.com/en/articles/11381614""",
|
More info: https://help.openai.com/en/articles/11381614""",
|
||||||
|
|||||||
@@ -6,16 +6,14 @@ use app::App;
|
|||||||
use codex_core::config::Config;
|
use codex_core::config::Config;
|
||||||
use codex_core::config::ConfigOverrides;
|
use codex_core::config::ConfigOverrides;
|
||||||
use codex_core::config_types::SandboxMode;
|
use codex_core::config_types::SandboxMode;
|
||||||
use codex_core::openai_api_key::OPENAI_API_KEY_ENV_VAR;
|
|
||||||
use codex_core::openai_api_key::get_openai_api_key;
|
|
||||||
use codex_core::openai_api_key::set_openai_api_key;
|
|
||||||
use codex_core::protocol::AskForApproval;
|
use codex_core::protocol::AskForApproval;
|
||||||
use codex_core::util::is_inside_git_repo;
|
use codex_core::util::is_inside_git_repo;
|
||||||
use codex_login::try_read_openai_api_key;
|
use codex_login::load_auth;
|
||||||
use log_layer::TuiLogLayer;
|
use log_layer::TuiLogLayer;
|
||||||
use std::fs::OpenOptions;
|
use std::fs::OpenOptions;
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
use tracing::error;
|
||||||
use tracing_appender::non_blocking;
|
use tracing_appender::non_blocking;
|
||||||
use tracing_subscriber::EnvFilter;
|
use tracing_subscriber::EnvFilter;
|
||||||
use tracing_subscriber::prelude::*;
|
use tracing_subscriber::prelude::*;
|
||||||
@@ -140,7 +138,7 @@ pub async fn run_main(
|
|||||||
.with(tui_layer)
|
.with(tui_layer)
|
||||||
.try_init();
|
.try_init();
|
||||||
|
|
||||||
let show_login_screen = should_show_login_screen(&config).await;
|
let show_login_screen = should_show_login_screen(&config);
|
||||||
if show_login_screen {
|
if show_login_screen {
|
||||||
std::io::stdout()
|
std::io::stdout()
|
||||||
.write_all(b"No API key detected.\nLogin with your ChatGPT account? [Yn] ")?;
|
.write_all(b"No API key detected.\nLogin with your ChatGPT account? [Yn] ")?;
|
||||||
@@ -153,8 +151,8 @@ pub async fn run_main(
|
|||||||
}
|
}
|
||||||
// Spawn a task to run the login command.
|
// Spawn a task to run the login command.
|
||||||
// Block until the login command is finished.
|
// Block until the login command is finished.
|
||||||
let new_key = codex_login::login_with_chatgpt(&config.codex_home, false).await?;
|
codex_login::login_with_chatgpt(&config.codex_home, false).await?;
|
||||||
set_openai_api_key(new_key);
|
|
||||||
std::io::stdout().write_all(b"Login successful.\n")?;
|
std::io::stdout().write_all(b"Login successful.\n")?;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -217,28 +215,21 @@ fn restore() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn should_show_login_screen(config: &Config) -> bool {
|
#[allow(clippy::unwrap_used)]
|
||||||
if is_in_need_of_openai_api_key(config) {
|
fn should_show_login_screen(config: &Config) -> bool {
|
||||||
|
if config.model_provider.requires_auth {
|
||||||
// Reading the OpenAI API key is an async operation because it may need
|
// Reading the OpenAI API key is an async operation because it may need
|
||||||
// to refresh the token. Block on it.
|
// to refresh the token. Block on it.
|
||||||
let codex_home = config.codex_home.clone();
|
let codex_home = config.codex_home.clone();
|
||||||
if let Ok(openai_api_key) = try_read_openai_api_key(&codex_home).await {
|
match load_auth(&codex_home) {
|
||||||
set_openai_api_key(openai_api_key);
|
Ok(Some(_)) => false,
|
||||||
false
|
Ok(None) => true,
|
||||||
} else {
|
Err(err) => {
|
||||||
true
|
error!("Failed to read auth.json: {err}");
|
||||||
|
true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn is_in_need_of_openai_api_key(config: &Config) -> bool {
|
|
||||||
let is_using_openai_key = config
|
|
||||||
.model_provider
|
|
||||||
.env_key
|
|
||||||
.as_ref()
|
|
||||||
.map(|s| s == OPENAI_API_KEY_ENV_VAR)
|
|
||||||
.unwrap_or(false);
|
|
||||||
is_using_openai_key && get_openai_api_key().is_none()
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user