From d365cae0771855d2ac2bddc90ef04b60ca872e7e Mon Sep 17 00:00:00 2001
From: Michael Bolin
Date: Tue, 5 Aug 2025 13:55:32 -0700
Subject: [PATCH] fix: when using `--oss`, ensure correct configuration is threaded through correctly (#1859)

This PR started as an investigation aimed at eliminating the use of
`unsafe { std::env::set_var() }` in `ollama/src/client.rs`: setting
environment variables in a multithreaded process is genuinely unsafe, and
these tests were observed to be flaky as a result. As I dug deeper, I
discovered that the logic for instantiating `OllamaClient` under test
scenarios was not quite right.

In this PR, I aimed to:

- share more code between the two creation codepaths,
  `try_from_oss_provider()` and `try_from_provider_with_base_url()`
- use the values from `Config` when setting up Ollama; because we have
  various mechanisms for overriding config values, we should always consult
  the final `Config` for things such as the `ModelProviderInfo` associated
  with the `oss` id

Once this was in place, `OllamaClient::try_from_provider_with_base_url()`
could be used in unit tests for `OllamaClient`, making it possible to create
a properly configured client without setting environment variables.
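
To make the new wiring concrete, here is a minimal sketch (illustrative only,
not code in this diff; the helper name is made up) of what a frontend such as
`exec` or `tui` is expected to do once its final `Config` has been resolved:

```rust
use codex_core::config::Config;

/// Illustrative sketch: prepare the local OSS environment only after the
/// final `Config` has been resolved, so that any `config.toml` overrides to
/// the built-in `oss` provider are honored.
async fn prepare_oss_if_requested(config: &Config, oss: bool) -> anyhow::Result<()> {
    if oss {
        // Probes the Ollama server described by `config.model_providers["oss"]`
        // and pulls the configured model if it is not already present locally.
        codex_ollama::ensure_oss_ready(config)
            .await
            .map_err(|e| anyhow::anyhow!("OSS setup failed: {e}"))?;
    }
    Ok(())
}
```

The client no longer re-derives the provider from environment variables; it
looks up the `oss` entry in `config.model_providers`, which is also what the
unit tests now exercise via `try_from_provider_with_base_url()`.
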
---
 codex-rs/core/src/lib.rs                 |   1 +
 codex-rs/core/src/model_provider_info.rs |  73 +++++-----
 codex-rs/exec/src/lib.rs                 |  30 ++--
 codex-rs/ollama/src/client.rs            | 177 +++++++++++------------
 codex-rs/ollama/src/lib.rs               |  42 +++---
 codex-rs/tui/src/lib.rs                  |  31 ++--
 6 files changed, 176 insertions(+), 178 deletions(-)

diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs
index 965cb77b..d072613e 100644
--- a/codex-rs/core/src/lib.rs
+++ b/codex-rs/core/src/lib.rs
@@ -32,6 +32,7 @@ pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID;
 pub use model_provider_info::ModelProviderInfo;
 pub use model_provider_info::WireApi;
 pub use model_provider_info::built_in_model_providers;
+pub use model_provider_info::create_oss_provider_with_base_url;
 pub mod model_family;
 mod models;
 mod openai_model_info;
diff --git a/codex-rs/core/src/model_provider_info.rs b/codex-rs/core/src/model_provider_info.rs
index 595f05ef..db369df3 100644
--- a/codex-rs/core/src/model_provider_info.rs
+++ b/codex-rs/core/src/model_provider_info.rs
@@ -234,23 +234,6 @@ pub const BUILT_IN_OSS_MODEL_PROVIDER_ID: &str = "oss";
 pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
     use ModelProviderInfo as P;
 
-    // These CODEX_OSS_ environment variables are experimental: we may
-    // switch to reading values from config.toml instead.
-    let codex_oss_base_url = match std::env::var("CODEX_OSS_BASE_URL")
-        .ok()
-        .filter(|v| !v.trim().is_empty())
-    {
-        Some(url) => url,
-        None => format!(
-            "http://localhost:{port}/v1",
-            port = std::env::var("CODEX_OSS_PORT")
-                .ok()
-                .filter(|v| !v.trim().is_empty())
-                .and_then(|v| v.parse::<u16>().ok())
-                .unwrap_or(DEFAULT_OLLAMA_PORT)
-        ),
-    };
-
     // We do not want to be in the business of adjucating which third-party
     // providers are bundled with Codex CLI, so we only include the OpenAI and
     // open source ("oss") providers by default. Users are encouraged to add to
@@ -295,29 +278,51 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
                 requires_auth: true,
             },
         ),
-        (
-            BUILT_IN_OSS_MODEL_PROVIDER_ID,
-            P {
-                name: "Open Source".into(),
-                base_url: Some(codex_oss_base_url),
-                env_key: None,
-                env_key_instructions: None,
-                wire_api: WireApi::Chat,
-                query_params: None,
-                http_headers: None,
-                env_http_headers: None,
-                request_max_retries: None,
-                stream_max_retries: None,
-                stream_idle_timeout_ms: None,
-                requires_auth: false,
-            },
-        ),
+        (BUILT_IN_OSS_MODEL_PROVIDER_ID, create_oss_provider()),
     ]
     .into_iter()
    .map(|(k, v)| (k.to_string(), v))
     .collect()
 }
 
+pub fn create_oss_provider() -> ModelProviderInfo {
+    // These CODEX_OSS_ environment variables are experimental: we may
+    // switch to reading values from config.toml instead.
+    let codex_oss_base_url = match std::env::var("CODEX_OSS_BASE_URL")
+        .ok()
+        .filter(|v| !v.trim().is_empty())
+    {
+        Some(url) => url,
+        None => format!(
+            "http://localhost:{port}/v1",
+            port = std::env::var("CODEX_OSS_PORT")
+                .ok()
+                .filter(|v| !v.trim().is_empty())
+                .and_then(|v| v.parse::<u16>().ok())
+                .unwrap_or(DEFAULT_OLLAMA_PORT)
+        ),
+    };
+
+    create_oss_provider_with_base_url(&codex_oss_base_url)
+}
+
+pub fn create_oss_provider_with_base_url(base_url: &str) -> ModelProviderInfo {
+    ModelProviderInfo {
+        name: "gpt-oss".into(),
+        base_url: Some(base_url.into()),
+        env_key: None,
+        env_key_instructions: None,
+        wire_api: WireApi::Chat,
+        query_params: None,
+        http_headers: None,
+        env_http_headers: None,
+        request_max_retries: None,
+        stream_max_retries: None,
+        stream_idle_timeout_ms: None,
+        requires_auth: false,
+    }
+}
+
 #[cfg(test)]
 mod tests {
     #![allow(clippy::unwrap_used)]
diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs
index c1af4f5b..a0360182 100644
--- a/codex-rs/exec/src/lib.rs
+++ b/codex-rs/exec/src/lib.rs
@@ -22,6 +22,7 @@ use codex_core::protocol::InputItem;
 use codex_core::protocol::Op;
 use codex_core::protocol::TaskCompleteEvent;
 use codex_core::util::is_inside_git_repo;
+use codex_ollama::DEFAULT_OSS_MODEL;
 use event_processor_with_human_output::EventProcessorWithHumanOutput;
 use event_processor_with_json_output::EventProcessorWithJsonOutput;
 use tracing::debug;
@@ -35,7 +36,7 @@ use crate::event_processor::EventProcessor;
 pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()> {
     let Cli {
         images,
-        model,
+        model: model_cli_arg,
         oss,
         config_profile,
         full_auto,
@@ -119,19 +120,18 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
     // When using `--oss`, let the bootstrapper pick the model (defaulting to
     // gpt-oss:20b) and ensure it is present locally. Also, force the built‑in
     // `oss` model provider.
-    let model_provider_override = if oss {
-        Some(BUILT_IN_OSS_MODEL_PROVIDER_ID.to_owned())
+    let model = if let Some(model) = model_cli_arg {
+        Some(model)
+    } else if oss {
+        Some(DEFAULT_OSS_MODEL.to_owned())
     } else {
-        None
+        None // No model specified, will use the default.
     };
-    let model = if oss {
-        Some(
-            codex_ollama::ensure_oss_ready(model.clone())
-                .await
-                .map_err(|e| anyhow::anyhow!("OSS setup failed: {e}"))?,
-        )
+
+    let model_provider = if oss {
+        Some(BUILT_IN_OSS_MODEL_PROVIDER_ID.to_string())
     } else {
-        model
+        None // No specific model provider override.
     };
 
     // Load configuration and determine approval policy
@@ -143,7 +143,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
         approval_policy: Some(AskForApproval::Never),
         sandbox_mode,
         cwd: cwd.map(|p| p.canonicalize().unwrap_or(p)),
-        model_provider: model_provider_override,
+        model_provider,
         codex_linux_sandbox_exe,
         base_instructions: None,
         include_plan_tool: None,
@@ -170,6 +170,12 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
         ))
     };
 
+    if oss {
+        codex_ollama::ensure_oss_ready(&config)
+            .await
+            .map_err(|e| anyhow::anyhow!("OSS setup failed: {e}"))?;
+    }
+
     // Print the effective configuration and prompt so users can see what Codex
     // is using.
     event_processor.print_config_summary(&config, &prompt);
diff --git a/codex-rs/ollama/src/client.rs b/codex-rs/ollama/src/client.rs
index f86271dc..6f462113 100644
--- a/codex-rs/ollama/src/client.rs
+++ b/codex-rs/ollama/src/client.rs
@@ -5,13 +5,17 @@ use serde_json::Value as JsonValue;
 use std::collections::VecDeque;
 use std::io;
 
-use codex_core::WireApi;
-
 use crate::parser::pull_events_from_value;
 use crate::pull::PullEvent;
 use crate::pull::PullProgressReporter;
 use crate::url::base_url_to_host_root;
 use crate::url::is_openai_compatible_base_url;
+use codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID;
+use codex_core::ModelProviderInfo;
+use codex_core::WireApi;
+use codex_core::config::Config;
+
+const OLLAMA_CONNECTION_ERROR: &str = "No running Ollama server detected. Start it with: `ollama serve` (after installing). Install instructions: https://github.com/ollama/ollama?tab=readme-ov-file#ollama";
 
 /// Client for interacting with a local Ollama instance.
 pub struct OllamaClient {
@@ -21,74 +25,77 @@ pub struct OllamaClient {
 }
 
 impl OllamaClient {
-    pub fn from_oss_provider() -> Self {
+    /// Construct a client for the built‑in open‑source ("oss") model provider
+    /// and verify that a local Ollama server is reachable. If no server is
+    /// detected, returns an error with helpful installation/run instructions.
+    pub async fn try_from_oss_provider(config: &Config) -> io::Result<Self> {
+        // Note that we must look up the provider from the Config to ensure that
+        // any overrides the user has in their config.toml are taken into
+        // account.
+        let provider = config
+            .model_providers
+            .get(BUILT_IN_OSS_MODEL_PROVIDER_ID)
+            .ok_or_else(|| {
+                io::Error::new(
+                    io::ErrorKind::NotFound,
+                    format!("Built-in provider {BUILT_IN_OSS_MODEL_PROVIDER_ID} not found",),
+                )
+            })?;
+
+        Self::try_from_provider(provider).await
+    }
+
+    #[cfg(test)]
+    async fn try_from_provider_with_base_url(base_url: &str) -> io::Result<Self> {
+        let provider = codex_core::create_oss_provider_with_base_url(base_url);
+        Self::try_from_provider(&provider).await
+    }
+
+    /// Build a client from a provider definition and verify the server is reachable.
+    async fn try_from_provider(provider: &ModelProviderInfo) -> io::Result<Self> {
         #![allow(clippy::expect_used)]
-        // Use the built-in OSS provider's base URL.
-        let built_in_model_providers = codex_core::built_in_model_providers();
-        let provider = built_in_model_providers
-            .get(codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID)
-            .expect("oss provider must exist");
         let base_url = provider
             .base_url
             .as_ref()
             .expect("oss provider must have a base_url");
-        Self::from_provider(base_url, provider.wire_api)
-    }
-
-    /// Construct a client for the built‑in open‑source ("oss") model provider
-    /// and verify that a local Ollama server is reachable. If no server is
-    /// detected, returns an error with helpful installation/run instructions.
-    pub async fn try_from_oss_provider() -> io::Result<Self> {
-        let client = Self::from_oss_provider();
-        if client.probe_server().await? {
-            Ok(client)
-        } else {
-            Err(io::Error::other(
-                "No running Ollama server detected. Start it with: `ollama serve` (after installing). Install instructions: https://github.com/ollama/ollama?tab=readme-ov-file#ollama",
-            ))
-        }
-    }
-
-    /// Build a client from a provider definition. Falls back to the default
-    /// local URL if no base_url is configured.
-    fn from_provider(base_url: &str, wire_api: WireApi) -> Self {
         let uses_openai_compat = is_openai_compatible_base_url(base_url)
-            || matches!(wire_api, WireApi::Chat) && is_openai_compatible_base_url(base_url);
+            || matches!(provider.wire_api, WireApi::Chat)
+                && is_openai_compatible_base_url(base_url);
         let host_root = base_url_to_host_root(base_url);
         let client = reqwest::Client::builder()
             .connect_timeout(std::time::Duration::from_secs(5))
             .build()
             .unwrap_or_else(|_| reqwest::Client::new());
-        Self {
+        let client = Self {
             client,
             host_root,
             uses_openai_compat,
-        }
-    }
-
-    /// Low-level constructor given a raw host root, e.g. "http://localhost:11434".
-    #[cfg(test)]
-    fn from_host_root(host_root: impl Into<String>) -> Self {
-        let client = reqwest::Client::builder()
-            .connect_timeout(std::time::Duration::from_secs(5))
-            .build()
-            .unwrap_or_else(|_| reqwest::Client::new());
-        Self {
-            client,
-            host_root: host_root.into(),
-            uses_openai_compat: false,
-        }
+        };
+        client.probe_server().await?;
+        Ok(client)
     }
 
     /// Probe whether the server is reachable by hitting the appropriate health endpoint.
-    pub async fn probe_server(&self) -> io::Result<bool> {
+    async fn probe_server(&self) -> io::Result<()> {
         let url = if self.uses_openai_compat {
             format!("{}/v1/models", self.host_root.trim_end_matches('/'))
         } else {
             format!("{}/api/tags", self.host_root.trim_end_matches('/'))
         };
-        let resp = self.client.get(url).send().await;
-        Ok(matches!(resp, Ok(r) if r.status().is_success()))
+        let resp = self.client.get(url).send().await.map_err(|err| {
+            tracing::warn!("Failed to connect to Ollama server: {err:?}");
+            io::Error::other(OLLAMA_CONNECTION_ERROR)
+        })?;
+        if resp.status().is_success() {
+            Ok(())
+        } else {
+            tracing::warn!(
+                "Failed to probe server at {}: HTTP {}",
+                self.host_root,
+                resp.status()
+            );
+            Err(io::Error::other(OLLAMA_CONNECTION_ERROR))
+        }
     }
 
     /// Return the list of model names known to the local Ollama instance.
@@ -210,6 +217,20 @@ impl OllamaClient {
             "Pull stream ended unexpectedly without success.",
         ))
     }
+
+    /// Low-level constructor given a raw host root, e.g. "http://localhost:11434".
+    #[cfg(test)]
+    fn from_host_root(host_root: impl Into<String>) -> Self {
+        let client = reqwest::Client::builder()
+            .connect_timeout(std::time::Duration::from_secs(5))
+            .build()
+            .unwrap_or_else(|_| reqwest::Client::new());
+        Self {
+            client,
+            host_root: host_root.into(),
+            uses_openai_compat: false,
+        }
+    }
 }
 
 #[cfg(test)]
@@ -217,34 +238,6 @@ mod tests {
     #![allow(clippy::expect_used, clippy::unwrap_used)]
     use super::*;
 
-    /// Simple RAII guard to set an environment variable for the duration of a test
-    /// and restore the previous value (or remove it) on drop to avoid cross-test
-    /// interference.
-    struct EnvVarGuard {
-        key: String,
-        prev: Option<String>,
-    }
-    impl EnvVarGuard {
-        fn set(key: &str, value: String) -> Self {
-            let prev = std::env::var(key).ok();
-            // set_var is safe but we mirror existing tests that use an unsafe block
-            // to silence edition lints around global mutation during tests.
-            unsafe { std::env::set_var(key, value) };
-            Self {
-                key: key.to_string(),
-                prev,
-            }
-        }
-    }
-    impl Drop for EnvVarGuard {
-        fn drop(&mut self) {
-            match &self.prev {
-                Some(v) => unsafe { std::env::set_var(&self.key, v) },
-                None => unsafe { std::env::remove_var(&self.key) },
-            }
-        }
-    }
-
     // Happy-path tests using a mock HTTP server; skip if sandbox network is disabled.
     #[tokio::test]
     async fn test_fetch_models_happy_path() {
@@ -296,7 +289,7 @@ mod tests {
             .mount(&server)
             .await;
         let native = OllamaClient::from_host_root(server.uri());
-        assert!(native.probe_server().await.expect("probe native"));
+        native.probe_server().await.expect("probe native");
 
         // OpenAI compatibility endpoint
         wiremock::Mock::given(wiremock::matchers::method("GET"))
@@ -304,11 +297,14 @@
             .respond_with(wiremock::ResponseTemplate::new(200))
             .mount(&server)
             .await;
-        // Ensure the built-in OSS provider points at our mock server for this test
-        // to avoid depending on any globally configured environment from other tests.
-        let _guard = EnvVarGuard::set("CODEX_OSS_BASE_URL", format!("{}/v1", server.uri()));
-        let ollama_client = OllamaClient::from_oss_provider();
-        assert!(ollama_client.probe_server().await.expect("probe compat"));
+        let ollama_client =
+            OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
+                .await
+                .expect("probe OpenAI compat");
+        ollama_client
+            .probe_server()
+            .await
+            .expect("probe OpenAI compat");
     }
 
     #[tokio::test]
@@ -322,9 +318,6 @@
         }
         let server = wiremock::MockServer::start().await;
 
-        // Configure built‑in `oss` provider to point at this mock server.
-        // set_var is unsafe on Rust 2024 edition; use unsafe block in tests.
-        let _guard = EnvVarGuard::set("CODEX_OSS_BASE_URL", format!("{}/v1", server.uri()));
 
         // OpenAI‑compat models endpoint responds OK.
         wiremock::Mock::given(wiremock::matchers::method("GET"))
@@ -333,7 +326,7 @@
             .mount(&server)
             .await;
 
-        let _client = OllamaClient::try_from_oss_provider()
+        OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
             .await
             .expect("client should be created when probe succeeds");
     }
@@ -349,18 +342,10 @@
         }
         let server = wiremock::MockServer::start().await;
 
-        // Point oss provider at our mock server but do NOT set up a handler
-        // for /v1/models so the request returns a non‑success status.
-        unsafe { std::env::set_var("CODEX_OSS_BASE_URL", format!("{}/v1", server.uri())) };
-
-        let err = OllamaClient::try_from_oss_provider()
+        let err = OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
             .await
             .err()
             .expect("expected error");
-        let msg = err.to_string();
-        assert!(
-            msg.contains("No running Ollama server detected."),
-            "msg = {msg}"
-        );
+        assert_eq!(OLLAMA_CONNECTION_ERROR, err.to_string());
     }
 }
diff --git a/codex-rs/ollama/src/lib.rs b/codex-rs/ollama/src/lib.rs
index d6f1e04d..0ebf1662 100644
--- a/codex-rs/ollama/src/lib.rs
+++ b/codex-rs/ollama/src/lib.rs
@@ -4,6 +4,7 @@ mod pull;
 mod url;
 
 pub use client::OllamaClient;
+use codex_core::config::Config;
 pub use pull::CliProgressReporter;
 pub use pull::PullEvent;
 pub use pull::PullProgressReporter;
@@ -15,38 +16,29 @@ pub const DEFAULT_OSS_MODEL: &str = "gpt-oss:20b";
 /// Prepare the local OSS environment when `--oss` is selected.
 ///
 /// - Ensures a local Ollama server is reachable.
-/// - Selects the final model name (CLI override or default).
 /// - Checks if the model exists locally and pulls it if missing.
-///
-/// Returns the final model name that should be used by the caller.
-pub async fn ensure_oss_ready(cli_model: Option<String>) -> std::io::Result<String> {
+pub async fn ensure_oss_ready(config: &Config) -> std::io::Result<()> {
     // Only download when the requested model is the default OSS model (or when -m is not provided).
-    let should_download = cli_model
-        .as_deref()
-        .map(|name| name == DEFAULT_OSS_MODEL)
-        .unwrap_or(true);
-    let model = cli_model.unwrap_or_else(|| DEFAULT_OSS_MODEL.to_string());
+    let model = config.model.as_ref();
 
     // Verify local Ollama is reachable.
-    let ollama_client = crate::OllamaClient::try_from_oss_provider().await?;
+    let ollama_client = crate::OllamaClient::try_from_oss_provider(config).await?;
 
-    if should_download {
-        // If the model is not present locally, pull it.
-        match ollama_client.fetch_models().await {
-            Ok(models) => {
-                if !models.iter().any(|m| m == &model) {
-                    let mut reporter = crate::CliProgressReporter::new();
-                    ollama_client
-                        .pull_with_reporter(&model, &mut reporter)
-                        .await?;
-                }
-            }
-            Err(err) => {
-                // Not fatal; higher layers may still proceed and surface errors later.
-                tracing::warn!("Failed to query local models from Ollama: {}.", err);
+    // If the model is not present locally, pull it.
+    match ollama_client.fetch_models().await {
+        Ok(models) => {
+            if !models.iter().any(|m| m == model) {
+                let mut reporter = crate::CliProgressReporter::new();
+                ollama_client
+                    .pull_with_reporter(model, &mut reporter)
+                    .await?;
             }
         }
+        Err(err) => {
+            // Not fatal; higher layers may still proceed and surface errors later.
+            tracing::warn!("Failed to query local models from Ollama: {}.", err);
+        }
     }
 
-    Ok(model)
+    Ok(())
 }
diff --git a/codex-rs/tui/src/lib.rs b/codex-rs/tui/src/lib.rs
index 0b833b13..bab728e1 100644
--- a/codex-rs/tui/src/lib.rs
+++ b/codex-rs/tui/src/lib.rs
@@ -10,6 +10,7 @@ use codex_core::config_types::SandboxMode;
 use codex_core::protocol::AskForApproval;
 use codex_core::util::is_inside_git_repo;
 use codex_login::load_auth;
+use codex_ollama::DEFAULT_OSS_MODEL;
 use log_layer::TuiLogLayer;
 use std::fs::OpenOptions;
 use std::io::Write;
@@ -71,25 +72,27 @@ pub async fn run_main(
         )
     };
 
+    // When using `--oss`, let the bootstrapper pick the model (defaulting to
+    // gpt-oss:20b) and ensure it is present locally. Also, force the built‑in
+    // `oss` model provider.
+    let model = if let Some(model) = &cli.model {
+        Some(model.clone())
+    } else if cli.oss {
+        Some(DEFAULT_OSS_MODEL.to_owned())
+    } else {
+        None // No model specified, will use the default.
+    };
+
     let model_provider_override = if cli.oss {
         Some(BUILT_IN_OSS_MODEL_PROVIDER_ID.to_owned())
     } else {
         None
     };
+
     let config = {
         // Load configuration and support CLI overrides.
         let overrides = ConfigOverrides {
-            // When using `--oss`, let the bootstrapper pick the model
-            // (defaulting to gpt-oss:20b) and ensure it is present locally.
-            model: if cli.oss {
-                Some(
-                    codex_ollama::ensure_oss_ready(cli.model.clone())
-                        .await
-                        .map_err(|e| std::io::Error::other(format!("OSS setup failed: {e}")))?,
-                )
-            } else {
-                cli.model.clone()
-            },
+            model,
             approval_policy,
             sandbox_mode,
             cwd: cli.cwd.clone().map(|p| p.canonicalize().unwrap_or(p)),
@@ -154,6 +157,12 @@ pub async fn run_main(
         .with_target(false)
         .with_filter(env_filter());
 
+    if cli.oss {
+        codex_ollama::ensure_oss_ready(&config)
+            .await
+            .map_err(|e| std::io::Error::other(format!("OSS setup failed: {e}")))?;
+    }
+
     // Channel that carries formatted log lines to the UI.
     let (log_tx, log_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
     let tui_layer = TuiLogLayer::new(log_tx.clone(), 120).with_filter(env_filter());