feat: read model_provider and model_providers from config.toml (#853)
This is the first step in supporting other model providers in the Rust CLI. Specifically, this PR adds support for the new entries in `Config` and `ConfigOverrides` to specify a `ModelProviderInfo`, which is the basic config needed for an LLM provider. This PR does not get us all the way there yet because `client.rs` still categorically appends `/responses` to the URL and expects the endpoint to support the OpenAI Responses API. Will fix that next!
This commit is contained in:
@@ -26,10 +26,9 @@ use tracing::warn;
|
|||||||
use crate::error::CodexErr;
|
use crate::error::CodexErr;
|
||||||
use crate::error::Result;
|
use crate::error::Result;
|
||||||
use crate::flags::CODEX_RS_SSE_FIXTURE;
|
use crate::flags::CODEX_RS_SSE_FIXTURE;
|
||||||
use crate::flags::OPENAI_API_BASE;
|
|
||||||
use crate::flags::OPENAI_REQUEST_MAX_RETRIES;
|
use crate::flags::OPENAI_REQUEST_MAX_RETRIES;
|
||||||
use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
|
use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
|
||||||
use crate::flags::get_api_key;
|
use crate::model_provider_info::ModelProviderInfo;
|
||||||
use crate::models::ResponseItem;
|
use crate::models::ResponseItem;
|
||||||
use crate::util::backoff;
|
use crate::util::backoff;
|
||||||
|
|
||||||
@@ -141,13 +140,16 @@ static DEFAULT_TOOLS: LazyLock<Vec<ResponsesApiTool>> = LazyLock::new(|| {
|
|||||||
pub struct ModelClient {
|
pub struct ModelClient {
|
||||||
model: String,
|
model: String,
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
|
provider: ModelProviderInfo,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ModelClient {
|
impl ModelClient {
|
||||||
pub fn new(model: impl ToString) -> Self {
|
pub fn new(model: impl ToString, provider: ModelProviderInfo) -> Self {
|
||||||
let model = model.to_string();
|
Self {
|
||||||
let client = reqwest::Client::new();
|
model: model.to_string(),
|
||||||
Self { model, client }
|
client: reqwest::Client::new(),
|
||||||
|
provider,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn stream(&mut self, prompt: &Prompt) -> Result<ResponseStream> {
|
pub async fn stream(&mut self, prompt: &Prompt) -> Result<ResponseStream> {
|
||||||
@@ -188,7 +190,9 @@ impl ModelClient {
|
|||||||
stream: true,
|
stream: true,
|
||||||
};
|
};
|
||||||
|
|
||||||
let url = format!("{}/v1/responses", *OPENAI_API_BASE);
|
let base_url = self.provider.base_url.clone();
|
||||||
|
let base_url = base_url.trim_end_matches('/');
|
||||||
|
let url = format!("{}/responses", base_url);
|
||||||
debug!(url, "POST");
|
debug!(url, "POST");
|
||||||
trace!("request payload: {}", serde_json::to_string(&payload)?);
|
trace!("request payload: {}", serde_json::to_string(&payload)?);
|
||||||
|
|
||||||
@@ -196,10 +200,14 @@ impl ModelClient {
|
|||||||
loop {
|
loop {
|
||||||
attempt += 1;
|
attempt += 1;
|
||||||
|
|
||||||
|
let api_key = self
|
||||||
|
.provider
|
||||||
|
.api_key()
|
||||||
|
.ok_or_else(|| crate::error::CodexErr::EnvVar("API_KEY"))?;
|
||||||
let res = self
|
let res = self
|
||||||
.client
|
.client
|
||||||
.post(&url)
|
.post(&url)
|
||||||
.bearer_auth(get_api_key()?)
|
.bearer_auth(api_key)
|
||||||
.header("OpenAI-Beta", "responses=experimental")
|
.header("OpenAI-Beta", "responses=experimental")
|
||||||
.header(reqwest::header::ACCEPT, "text/event-stream")
|
.header(reqwest::header::ACCEPT, "text/event-stream")
|
||||||
.json(&payload)
|
.json(&payload)
|
||||||
|
|||||||
@@ -80,6 +80,7 @@ impl Codex {
|
|||||||
let (tx_sub, rx_sub) = async_channel::bounded(64);
|
let (tx_sub, rx_sub) = async_channel::bounded(64);
|
||||||
let (tx_event, rx_event) = async_channel::bounded(64);
|
let (tx_event, rx_event) = async_channel::bounded(64);
|
||||||
let configure_session = Op::ConfigureSession {
|
let configure_session = Op::ConfigureSession {
|
||||||
|
provider: config.model_provider.clone(),
|
||||||
model: config.model.clone(),
|
model: config.model.clone(),
|
||||||
instructions: config.instructions.clone(),
|
instructions: config.instructions.clone(),
|
||||||
approval_policy: config.approval_policy,
|
approval_policy: config.approval_policy,
|
||||||
@@ -504,6 +505,7 @@ async fn submission_loop(
|
|||||||
sess.abort();
|
sess.abort();
|
||||||
}
|
}
|
||||||
Op::ConfigureSession {
|
Op::ConfigureSession {
|
||||||
|
provider,
|
||||||
model,
|
model,
|
||||||
instructions,
|
instructions,
|
||||||
approval_policy,
|
approval_policy,
|
||||||
@@ -512,7 +514,7 @@ async fn submission_loop(
|
|||||||
notify,
|
notify,
|
||||||
cwd,
|
cwd,
|
||||||
} => {
|
} => {
|
||||||
info!(model, "Configuring session");
|
info!("Configuring session: model={model}; provider={provider:?}");
|
||||||
if !cwd.is_absolute() {
|
if !cwd.is_absolute() {
|
||||||
let message = format!("cwd is not absolute: {cwd:?}");
|
let message = format!("cwd is not absolute: {cwd:?}");
|
||||||
error!(message);
|
error!(message);
|
||||||
@@ -526,7 +528,7 @@ async fn submission_loop(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let client = ModelClient::new(model.clone());
|
let client = ModelClient::new(model.clone(), provider.clone());
|
||||||
|
|
||||||
// abort any current running session and clone its state
|
// abort any current running session and clone its state
|
||||||
let state = match sess.take() {
|
let state = match sess.take() {
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
use crate::flags::OPENAI_DEFAULT_MODEL;
|
use crate::flags::OPENAI_DEFAULT_MODEL;
|
||||||
use crate::mcp_server_config::McpServerConfig;
|
use crate::mcp_server_config::McpServerConfig;
|
||||||
|
use crate::model_provider_info::ModelProviderInfo;
|
||||||
|
use crate::model_provider_info::built_in_model_providers;
|
||||||
use crate::protocol::AskForApproval;
|
use crate::protocol::AskForApproval;
|
||||||
use crate::protocol::SandboxPermission;
|
use crate::protocol::SandboxPermission;
|
||||||
use crate::protocol::SandboxPolicy;
|
use crate::protocol::SandboxPolicy;
|
||||||
@@ -19,6 +21,9 @@ pub struct Config {
|
|||||||
/// Optional override of model selection.
|
/// Optional override of model selection.
|
||||||
pub model: String,
|
pub model: String,
|
||||||
|
|
||||||
|
/// Info needed to make an API request to the model.
|
||||||
|
pub model_provider: ModelProviderInfo,
|
||||||
|
|
||||||
/// Approval policy for executing commands.
|
/// Approval policy for executing commands.
|
||||||
pub approval_policy: AskForApproval,
|
pub approval_policy: AskForApproval,
|
||||||
|
|
||||||
@@ -61,6 +66,9 @@ pub struct Config {
|
|||||||
|
|
||||||
/// Definition for MCP servers that Codex can reach out to for tool calls.
|
/// Definition for MCP servers that Codex can reach out to for tool calls.
|
||||||
pub mcp_servers: HashMap<String, McpServerConfig>,
|
pub mcp_servers: HashMap<String, McpServerConfig>,
|
||||||
|
|
||||||
|
/// Combined provider map (defaults merged with user-defined overrides).
|
||||||
|
pub model_providers: HashMap<String, ModelProviderInfo>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Base config deserialized from ~/.codex/config.toml.
|
/// Base config deserialized from ~/.codex/config.toml.
|
||||||
@@ -69,6 +77,9 @@ pub struct ConfigToml {
|
|||||||
/// Optional override of model selection.
|
/// Optional override of model selection.
|
||||||
pub model: Option<String>,
|
pub model: Option<String>,
|
||||||
|
|
||||||
|
/// Provider to use from the model_providers map.
|
||||||
|
pub model_provider: Option<String>,
|
||||||
|
|
||||||
/// Default approval policy for executing commands.
|
/// Default approval policy for executing commands.
|
||||||
pub approval_policy: Option<AskForApproval>,
|
pub approval_policy: Option<AskForApproval>,
|
||||||
|
|
||||||
@@ -93,6 +104,10 @@ pub struct ConfigToml {
|
|||||||
/// Definition for MCP servers that Codex can reach out to for tool calls.
|
/// Definition for MCP servers that Codex can reach out to for tool calls.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub mcp_servers: HashMap<String, McpServerConfig>,
|
pub mcp_servers: HashMap<String, McpServerConfig>,
|
||||||
|
|
||||||
|
/// User-defined provider entries that extend/override the built-in list.
|
||||||
|
#[serde(default)]
|
||||||
|
pub model_providers: HashMap<String, ModelProviderInfo>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ConfigToml {
|
impl ConfigToml {
|
||||||
@@ -152,6 +167,7 @@ pub struct ConfigOverrides {
|
|||||||
pub approval_policy: Option<AskForApproval>,
|
pub approval_policy: Option<AskForApproval>,
|
||||||
pub sandbox_policy: Option<SandboxPolicy>,
|
pub sandbox_policy: Option<SandboxPolicy>,
|
||||||
pub disable_response_storage: Option<bool>,
|
pub disable_response_storage: Option<bool>,
|
||||||
|
pub provider: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
@@ -161,10 +177,13 @@ impl Config {
|
|||||||
pub fn load_with_overrides(overrides: ConfigOverrides) -> std::io::Result<Self> {
|
pub fn load_with_overrides(overrides: ConfigOverrides) -> std::io::Result<Self> {
|
||||||
let cfg: ConfigToml = ConfigToml::load_from_toml()?;
|
let cfg: ConfigToml = ConfigToml::load_from_toml()?;
|
||||||
tracing::warn!("Config parsed from config.toml: {cfg:?}");
|
tracing::warn!("Config parsed from config.toml: {cfg:?}");
|
||||||
Ok(Self::load_from_base_config_with_overrides(cfg, overrides))
|
Self::load_from_base_config_with_overrides(cfg, overrides)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn load_from_base_config_with_overrides(cfg: ConfigToml, overrides: ConfigOverrides) -> Self {
|
fn load_from_base_config_with_overrides(
|
||||||
|
cfg: ConfigToml,
|
||||||
|
overrides: ConfigOverrides,
|
||||||
|
) -> std::io::Result<Self> {
|
||||||
// Instructions: user-provided instructions.md > embedded default.
|
// Instructions: user-provided instructions.md > embedded default.
|
||||||
let instructions =
|
let instructions =
|
||||||
Self::load_instructions().or_else(|| Some(EMBEDDED_INSTRUCTIONS.to_string()));
|
Self::load_instructions().or_else(|| Some(EMBEDDED_INSTRUCTIONS.to_string()));
|
||||||
@@ -176,6 +195,7 @@ impl Config {
|
|||||||
approval_policy,
|
approval_policy,
|
||||||
sandbox_policy,
|
sandbox_policy,
|
||||||
disable_response_storage,
|
disable_response_storage,
|
||||||
|
provider,
|
||||||
} = overrides;
|
} = overrides;
|
||||||
|
|
||||||
let sandbox_policy = match sandbox_policy {
|
let sandbox_policy = match sandbox_policy {
|
||||||
@@ -193,8 +213,28 @@ impl Config {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
Self {
|
let mut model_providers = built_in_model_providers();
|
||||||
|
// Merge user-defined providers into the built-in list.
|
||||||
|
for (key, provider) in cfg.model_providers.into_iter() {
|
||||||
|
model_providers.entry(key).or_insert(provider);
|
||||||
|
}
|
||||||
|
|
||||||
|
let model_provider_name = provider
|
||||||
|
.or(cfg.model_provider)
|
||||||
|
.unwrap_or_else(|| "openai".to_string());
|
||||||
|
let model_provider = model_providers
|
||||||
|
.get(&model_provider_name)
|
||||||
|
.ok_or_else(|| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::NotFound,
|
||||||
|
format!("Model provider `{model_provider_name}` not found"),
|
||||||
|
)
|
||||||
|
})?
|
||||||
|
.clone();
|
||||||
|
|
||||||
|
let config = Self {
|
||||||
model: model.or(cfg.model).unwrap_or_else(default_model),
|
model: model.or(cfg.model).unwrap_or_else(default_model),
|
||||||
|
model_provider,
|
||||||
cwd: cwd.map_or_else(
|
cwd: cwd.map_or_else(
|
||||||
|| {
|
|| {
|
||||||
tracing::info!("cwd not set, using current dir");
|
tracing::info!("cwd not set, using current dir");
|
||||||
@@ -222,7 +262,9 @@ impl Config {
|
|||||||
notify: cfg.notify,
|
notify: cfg.notify,
|
||||||
instructions,
|
instructions,
|
||||||
mcp_servers: cfg.mcp_servers,
|
mcp_servers: cfg.mcp_servers,
|
||||||
}
|
model_providers,
|
||||||
|
};
|
||||||
|
Ok(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn load_instructions() -> Option<String> {
|
fn load_instructions() -> Option<String> {
|
||||||
@@ -238,6 +280,7 @@ impl Config {
|
|||||||
ConfigToml::default(),
|
ConfigToml::default(),
|
||||||
ConfigOverrides::default(),
|
ConfigOverrides::default(),
|
||||||
)
|
)
|
||||||
|
.expect("defaults for test should always succeed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,12 +2,11 @@ use std::time::Duration;
|
|||||||
|
|
||||||
use env_flags::env_flags;
|
use env_flags::env_flags;
|
||||||
|
|
||||||
use crate::error::CodexErr;
|
|
||||||
use crate::error::Result;
|
|
||||||
|
|
||||||
env_flags! {
|
env_flags! {
|
||||||
pub OPENAI_DEFAULT_MODEL: &str = "o3";
|
pub OPENAI_DEFAULT_MODEL: &str = "o3";
|
||||||
pub OPENAI_API_BASE: &str = "https://api.openai.com";
|
pub OPENAI_API_BASE: &str = "https://api.openai.com/v1";
|
||||||
|
|
||||||
|
/// Fallback when the provider-specific key is not set.
|
||||||
pub OPENAI_API_KEY: Option<&str> = None;
|
pub OPENAI_API_KEY: Option<&str> = None;
|
||||||
pub OPENAI_TIMEOUT_MS: Duration = Duration::from_millis(300_000), |value| {
|
pub OPENAI_TIMEOUT_MS: Duration = Duration::from_millis(300_000), |value| {
|
||||||
value.parse().map(Duration::from_millis)
|
value.parse().map(Duration::from_millis)
|
||||||
@@ -21,9 +20,6 @@ env_flags! {
|
|||||||
value.parse().map(Duration::from_millis)
|
value.parse().map(Duration::from_millis)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Fixture path for offline tests (see client.rs).
|
||||||
pub CODEX_RS_SSE_FIXTURE: Option<&str> = None;
|
pub CODEX_RS_SSE_FIXTURE: Option<&str> = None;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_api_key() -> Result<&'static str> {
|
|
||||||
OPENAI_API_KEY.ok_or_else(|| CodexErr::EnvVar("OPENAI_API_KEY"))
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
|
|
||||||
mod client;
|
mod client;
|
||||||
pub mod codex;
|
pub mod codex;
|
||||||
|
pub use codex::Codex;
|
||||||
pub mod codex_wrapper;
|
pub mod codex_wrapper;
|
||||||
pub mod config;
|
pub mod config;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
@@ -18,6 +19,8 @@ pub mod linux;
|
|||||||
mod mcp_connection_manager;
|
mod mcp_connection_manager;
|
||||||
pub mod mcp_server_config;
|
pub mod mcp_server_config;
|
||||||
mod mcp_tool_call;
|
mod mcp_tool_call;
|
||||||
|
mod model_provider_info;
|
||||||
|
pub use model_provider_info::ModelProviderInfo;
|
||||||
mod models;
|
mod models;
|
||||||
pub mod protocol;
|
pub mod protocol;
|
||||||
mod rollout;
|
mod rollout;
|
||||||
@@ -25,5 +28,3 @@ mod safety;
|
|||||||
mod user_notification;
|
mod user_notification;
|
||||||
pub mod util;
|
pub mod util;
|
||||||
mod zdr_transcript;
|
mod zdr_transcript;
|
||||||
|
|
||||||
pub use codex::Codex;
|
|
||||||
|
|||||||
103
codex-rs/core/src/model_provider_info.rs
Normal file
103
codex-rs/core/src/model_provider_info.rs
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
//! Registry of model providers supported by Codex.
|
||||||
|
//!
|
||||||
|
//! Providers can be defined in two places:
|
||||||
|
//! 1. Built-in defaults compiled into the binary so Codex works out-of-the-box.
|
||||||
|
//! 2. User-defined entries inside `~/.codex/config.toml` under the `model_providers`
|
||||||
|
//! key. These override or extend the defaults at runtime.
|
||||||
|
|
||||||
|
use serde::Deserialize;
|
||||||
|
use serde::Serialize;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
/// Serializable representation of a provider definition.
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ModelProviderInfo {
|
||||||
|
/// Friendly display name.
|
||||||
|
pub name: String,
|
||||||
|
/// Base URL for the provider's OpenAI-compatible API.
|
||||||
|
pub base_url: String,
|
||||||
|
/// Environment variable that stores the user's API key for this provider.
|
||||||
|
pub env_key: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelProviderInfo {
|
||||||
|
/// Returns the API key for this provider if present in the environment.
|
||||||
|
pub fn api_key(&self) -> Option<String> {
|
||||||
|
std::env::var(&self.env_key).ok()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Built-in default provider list.
|
||||||
|
pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
|
||||||
|
use ModelProviderInfo as P;
|
||||||
|
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"openai",
|
||||||
|
P {
|
||||||
|
name: "OpenAI".into(),
|
||||||
|
base_url: "https://api.openai.com/v1".into(),
|
||||||
|
env_key: "OPENAI_API_KEY".into(),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"openrouter",
|
||||||
|
P {
|
||||||
|
name: "OpenRouter".into(),
|
||||||
|
base_url: "https://openrouter.ai/api/v1".into(),
|
||||||
|
env_key: "OPENROUTER_API_KEY".into(),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"gemini",
|
||||||
|
P {
|
||||||
|
name: "Gemini".into(),
|
||||||
|
base_url: "https://generativelanguage.googleapis.com/v1beta/openai".into(),
|
||||||
|
env_key: "GEMINI_API_KEY".into(),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"ollama",
|
||||||
|
P {
|
||||||
|
name: "Ollama".into(),
|
||||||
|
base_url: "http://localhost:11434/v1".into(),
|
||||||
|
env_key: "OLLAMA_API_KEY".into(),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"mistral",
|
||||||
|
P {
|
||||||
|
name: "Mistral".into(),
|
||||||
|
base_url: "https://api.mistral.ai/v1".into(),
|
||||||
|
env_key: "MISTRAL_API_KEY".into(),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"deepseek",
|
||||||
|
P {
|
||||||
|
name: "DeepSeek".into(),
|
||||||
|
base_url: "https://api.deepseek.com".into(),
|
||||||
|
env_key: "DEEPSEEK_API_KEY".into(),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"xai",
|
||||||
|
P {
|
||||||
|
name: "xAI".into(),
|
||||||
|
base_url: "https://api.x.ai/v1".into(),
|
||||||
|
env_key: "XAI_API_KEY".into(),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"groq",
|
||||||
|
P {
|
||||||
|
name: "Groq".into(),
|
||||||
|
base_url: "https://api.groq.com/openai/v1".into(),
|
||||||
|
env_key: "GROQ_API_KEY".into(),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
.into_iter()
|
||||||
|
.map(|(k, v)| (k.to_string(), v))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
@@ -11,6 +11,8 @@ use mcp_types::CallToolResult;
|
|||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
|
|
||||||
|
use crate::model_provider_info::ModelProviderInfo;
|
||||||
|
|
||||||
/// Submission Queue Entry - requests from user
|
/// Submission Queue Entry - requests from user
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct Submission {
|
pub struct Submission {
|
||||||
@@ -27,6 +29,9 @@ pub struct Submission {
|
|||||||
pub enum Op {
|
pub enum Op {
|
||||||
/// Configure the model session.
|
/// Configure the model session.
|
||||||
ConfigureSession {
|
ConfigureSession {
|
||||||
|
/// Provider identifier ("openai", "openrouter", ...).
|
||||||
|
provider: ModelProviderInfo,
|
||||||
|
|
||||||
/// If not specified, server will use its default model.
|
/// If not specified, server will use its default model.
|
||||||
model: String,
|
model: String,
|
||||||
/// Model instructions
|
/// Model instructions
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use codex_core::Codex;
|
use codex_core::Codex;
|
||||||
|
use codex_core::ModelProviderInfo;
|
||||||
use codex_core::config::Config;
|
use codex_core::config::Config;
|
||||||
use codex_core::protocol::InputItem;
|
use codex_core::protocol::InputItem;
|
||||||
use codex_core::protocol::Op;
|
use codex_core::protocol::Op;
|
||||||
@@ -80,14 +81,21 @@ async fn keeps_previous_response_id_between_tasks() {
|
|||||||
// Update environment – `set_var` is `unsafe` starting with the 2024
|
// Update environment – `set_var` is `unsafe` starting with the 2024
|
||||||
// edition so we group the calls into a single `unsafe { … }` block.
|
// edition so we group the calls into a single `unsafe { … }` block.
|
||||||
unsafe {
|
unsafe {
|
||||||
std::env::set_var("OPENAI_API_KEY", "test-key");
|
|
||||||
std::env::set_var("OPENAI_API_BASE", server.uri());
|
|
||||||
std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");
|
std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");
|
||||||
std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "0");
|
std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "0");
|
||||||
}
|
}
|
||||||
|
let model_provider = ModelProviderInfo {
|
||||||
|
name: "openai".into(),
|
||||||
|
base_url: format!("{}/v1", server.uri()),
|
||||||
|
// Environment variable that should exist in the test environment.
|
||||||
|
// ModelClient will return an error if the environment variable for the
|
||||||
|
// provider is not set.
|
||||||
|
env_key: "PATH".into(),
|
||||||
|
};
|
||||||
|
|
||||||
// Init session
|
// Init session
|
||||||
let config = Config::load_default_config_for_test();
|
let mut config = Config::load_default_config_for_test();
|
||||||
|
config.model_provider = model_provider;
|
||||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||||
let (codex, _init_id) = Codex::spawn(config, ctrl_c.clone()).await.unwrap();
|
let (codex, _init_id) = Codex::spawn(config, ctrl_c.clone()).await.unwrap();
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use codex_core::Codex;
|
use codex_core::Codex;
|
||||||
|
use codex_core::ModelProviderInfo;
|
||||||
use codex_core::config::Config;
|
use codex_core::config::Config;
|
||||||
use codex_core::protocol::InputItem;
|
use codex_core::protocol::InputItem;
|
||||||
use codex_core::protocol::Op;
|
use codex_core::protocol::Op;
|
||||||
@@ -68,15 +69,23 @@ async fn retries_on_early_close() {
|
|||||||
// scope is very small and clearly delineated.
|
// scope is very small and clearly delineated.
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
std::env::set_var("OPENAI_API_KEY", "test-key");
|
|
||||||
std::env::set_var("OPENAI_API_BASE", server.uri());
|
|
||||||
std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");
|
std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");
|
||||||
std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "1");
|
std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "1");
|
||||||
std::env::set_var("OPENAI_STREAM_IDLE_TIMEOUT_MS", "2000");
|
std::env::set_var("OPENAI_STREAM_IDLE_TIMEOUT_MS", "2000");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let model_provider = ModelProviderInfo {
|
||||||
|
name: "openai".into(),
|
||||||
|
base_url: format!("{}/v1", server.uri()),
|
||||||
|
// Environment variable that should exist in the test environment.
|
||||||
|
// ModelClient will return an error if the environment variable for the
|
||||||
|
// provider is not set.
|
||||||
|
env_key: "PATH".into(),
|
||||||
|
};
|
||||||
|
|
||||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||||
let config = Config::load_default_config_for_test();
|
let mut config = Config::load_default_config_for_test();
|
||||||
|
config.model_provider = model_provider;
|
||||||
let (codex, _init_id) = Codex::spawn(config, ctrl_c).await.unwrap();
|
let (codex, _init_id) = Codex::spawn(config, ctrl_c).await.unwrap();
|
||||||
|
|
||||||
codex
|
codex
|
||||||
|
|||||||
@@ -66,6 +66,7 @@ pub async fn run_main(cli: Cli) -> anyhow::Result<()> {
|
|||||||
None
|
None
|
||||||
},
|
},
|
||||||
cwd: cwd.map(|p| p.canonicalize().unwrap_or(p)),
|
cwd: cwd.map(|p| p.canonicalize().unwrap_or(p)),
|
||||||
|
provider: None,
|
||||||
};
|
};
|
||||||
let config = Config::load_with_overrides(overrides)?;
|
let config = Config::load_with_overrides(overrides)?;
|
||||||
|
|
||||||
|
|||||||
@@ -158,6 +158,7 @@ impl CodexToolCallParam {
|
|||||||
approval_policy: approval_policy.map(Into::into),
|
approval_policy: approval_policy.map(Into::into),
|
||||||
sandbox_policy,
|
sandbox_policy,
|
||||||
disable_response_storage,
|
disable_response_storage,
|
||||||
|
provider: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let cfg = codex_core::config::Config::load_with_overrides(overrides)?;
|
let cfg = codex_core::config::Config::load_with_overrides(overrides)?;
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ pub fn run_main(cli: Cli) -> std::io::Result<()> {
|
|||||||
None
|
None
|
||||||
},
|
},
|
||||||
cwd: cli.cwd.clone().map(|p| p.canonicalize().unwrap_or(p)),
|
cwd: cli.cwd.clone().map(|p| p.canonicalize().unwrap_or(p)),
|
||||||
|
provider: None,
|
||||||
};
|
};
|
||||||
#[allow(clippy::print_stderr)]
|
#[allow(clippy::print_stderr)]
|
||||||
match Config::load_with_overrides(overrides) {
|
match Config::load_with_overrides(overrides) {
|
||||||
|
|||||||
Reference in New Issue
Block a user