fix: when using --oss, ensure configuration is threaded through correctly (#1859)

This PR started as an investigation aimed at eliminating the use of
`unsafe { std::env::set_var() }` in `ollama/src/client.rs`: setting
environment variables in a multithreaded context really is unsafe, and
the tests that did so were observed to be flaky as a result.

As I dug deeper into the issue, however, I discovered that the logic for
instantiating `OllamaClient` in test scenarios was not quite right.
In this PR, I aimed to:

- share more code between the two creation codepaths,
`try_from_oss_provider()` and `try_from_provider_with_base_url()` (see
the sketch after this list)
- use the values from `Config` when setting up Ollama: since we have
various mechanisms for overriding config values, we should always
consult the effective `Config` for things such as the
`ModelProviderInfo` associated with the `oss` id
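
As a rough sketch of the resulting shape (condensed from the
`ollama/src/client.rs` diff below, with error handling and imports
trimmed), both entry points now funnel into a single shared
`try_from_provider()`:

```rust
impl OllamaClient {
    /// Resolve the built-in "oss" provider from the effective `Config` so that
    /// any overrides from config.toml are honored, then verify the server.
    pub async fn try_from_oss_provider(config: &Config) -> io::Result<Self> {
        let provider = config
            .model_providers
            .get(BUILT_IN_OSS_MODEL_PROVIDER_ID)
            .ok_or_else(|| {
                io::Error::new(io::ErrorKind::NotFound, "oss provider not found")
            })?;
        Self::try_from_provider(provider).await
    }

    /// Test-only entry point: build an equivalent provider from a raw base URL
    /// and go through the same shared constructor.
    #[cfg(test)]
    async fn try_from_provider_with_base_url(base_url: &str) -> io::Result<Self> {
        let provider = codex_core::create_oss_provider_with_base_url(base_url);
        Self::try_from_provider(&provider).await
    }
}
```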

Once this was in place, `OllamaClient::try_from_provider_with_base_url()`
could be used in the unit tests for `OllamaClient`, making it possible to
create a properly configured client without setting environment variables.
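
Concretely, a probe test can now point a client directly at a mock server
instead of mutating process-global environment variables (condensed from
the updated tests in the diff below; the test name is illustrative):

```rust
#[tokio::test]
async fn creates_client_when_mock_server_is_reachable() {
    // Stand up a mock OpenAI-compatible endpoint for the probe to hit.
    let server = wiremock::MockServer::start().await;
    wiremock::Mock::given(wiremock::matchers::method("GET"))
        .and(wiremock::matchers::path("/v1/models"))
        .respond_with(wiremock::ResponseTemplate::new(200))
        .mount(&server)
        .await;

    // No CODEX_OSS_BASE_URL / EnvVarGuard required: the base URL is passed in.
    let _client = OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
        .await
        .expect("client should be created when probe succeeds");
}
```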
Author: Michael Bolin
Committed by: GitHub
Date: 2025-08-05 13:55:32 -07:00
Parent: 0c5fa271bc
Commit: d365cae077
6 changed files with 176 additions and 178 deletions

View File

@@ -32,6 +32,7 @@ pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID;
pub use model_provider_info::ModelProviderInfo;
pub use model_provider_info::WireApi;
pub use model_provider_info::built_in_model_providers;
pub use model_provider_info::create_oss_provider_with_base_url;
pub mod model_family;
mod models;
mod openai_model_info;

View File

@@ -234,23 +234,6 @@ pub const BUILT_IN_OSS_MODEL_PROVIDER_ID: &str = "oss";
pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
use ModelProviderInfo as P;
// These CODEX_OSS_ environment variables are experimental: we may
// switch to reading values from config.toml instead.
let codex_oss_base_url = match std::env::var("CODEX_OSS_BASE_URL")
.ok()
.filter(|v| !v.trim().is_empty())
{
Some(url) => url,
None => format!(
"http://localhost:{port}/v1",
port = std::env::var("CODEX_OSS_PORT")
.ok()
.filter(|v| !v.trim().is_empty())
.and_then(|v| v.parse::<u32>().ok())
.unwrap_or(DEFAULT_OLLAMA_PORT)
),
};
// We do not want to be in the business of adjudicating which third-party
// providers are bundled with Codex CLI, so we only include the OpenAI and
// open source ("oss") providers by default. Users are encouraged to add to
@@ -295,29 +278,51 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
requires_auth: true,
},
),
(
BUILT_IN_OSS_MODEL_PROVIDER_ID,
P {
name: "Open Source".into(),
base_url: Some(codex_oss_base_url),
env_key: None,
env_key_instructions: None,
wire_api: WireApi::Chat,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_auth: false,
},
),
(BUILT_IN_OSS_MODEL_PROVIDER_ID, create_oss_provider()),
]
.into_iter()
.map(|(k, v)| (k.to_string(), v))
.collect()
}
pub fn create_oss_provider() -> ModelProviderInfo {
// These CODEX_OSS_ environment variables are experimental: we may
// switch to reading values from config.toml instead.
let codex_oss_base_url = match std::env::var("CODEX_OSS_BASE_URL")
.ok()
.filter(|v| !v.trim().is_empty())
{
Some(url) => url,
None => format!(
"http://localhost:{port}/v1",
port = std::env::var("CODEX_OSS_PORT")
.ok()
.filter(|v| !v.trim().is_empty())
.and_then(|v| v.parse::<u32>().ok())
.unwrap_or(DEFAULT_OLLAMA_PORT)
),
};
create_oss_provider_with_base_url(&codex_oss_base_url)
}
pub fn create_oss_provider_with_base_url(base_url: &str) -> ModelProviderInfo {
ModelProviderInfo {
name: "gpt-oss".into(),
base_url: Some(base_url.into()),
env_key: None,
env_key_instructions: None,
wire_api: WireApi::Chat,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_auth: false,
}
}
#[cfg(test)]
mod tests {
#![allow(clippy::unwrap_used)]

View File

@@ -22,6 +22,7 @@ use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_core::protocol::TaskCompleteEvent;
use codex_core::util::is_inside_git_repo;
use codex_ollama::DEFAULT_OSS_MODEL;
use event_processor_with_human_output::EventProcessorWithHumanOutput;
use event_processor_with_json_output::EventProcessorWithJsonOutput;
use tracing::debug;
@@ -35,7 +36,7 @@ use crate::event_processor::EventProcessor;
pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()> {
let Cli {
images,
model,
model: model_cli_arg,
oss,
config_profile,
full_auto,
@@ -119,19 +120,18 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
// When using `--oss`, let the bootstrapper pick the model (defaulting to
// gpt-oss:20b) and ensure it is present locally. Also, force the builtin
// `oss` model provider.
let model_provider_override = if oss {
Some(BUILT_IN_OSS_MODEL_PROVIDER_ID.to_owned())
let model = if let Some(model) = model_cli_arg {
Some(model)
} else if oss {
Some(DEFAULT_OSS_MODEL.to_owned())
} else {
None
None // No model specified, will use the default.
};
let model = if oss {
Some(
codex_ollama::ensure_oss_ready(model.clone())
.await
.map_err(|e| anyhow::anyhow!("OSS setup failed: {e}"))?,
)
let model_provider = if oss {
Some(BUILT_IN_OSS_MODEL_PROVIDER_ID.to_string())
} else {
model
None // No specific model provider override.
};
// Load configuration and determine approval policy
@@ -143,7 +143,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
approval_policy: Some(AskForApproval::Never),
sandbox_mode,
cwd: cwd.map(|p| p.canonicalize().unwrap_or(p)),
model_provider: model_provider_override,
model_provider,
codex_linux_sandbox_exe,
base_instructions: None,
include_plan_tool: None,
@@ -170,6 +170,12 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
))
};
if oss {
codex_ollama::ensure_oss_ready(&config)
.await
.map_err(|e| anyhow::anyhow!("OSS setup failed: {e}"))?;
}
// Print the effective configuration and prompt so users can see what Codex
// is using.
event_processor.print_config_summary(&config, &prompt);

View File

@@ -5,13 +5,17 @@ use serde_json::Value as JsonValue;
use std::collections::VecDeque;
use std::io;
use codex_core::WireApi;
use crate::parser::pull_events_from_value;
use crate::pull::PullEvent;
use crate::pull::PullProgressReporter;
use crate::url::base_url_to_host_root;
use crate::url::is_openai_compatible_base_url;
use codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID;
use codex_core::ModelProviderInfo;
use codex_core::WireApi;
use codex_core::config::Config;
const OLLAMA_CONNECTION_ERROR: &str = "No running Ollama server detected. Start it with: `ollama serve` (after installing). Install instructions: https://github.com/ollama/ollama?tab=readme-ov-file#ollama";
/// Client for interacting with a local Ollama instance.
pub struct OllamaClient {
@@ -21,74 +25,77 @@ pub struct OllamaClient {
}
impl OllamaClient {
pub fn from_oss_provider() -> Self {
/// Construct a client for the builtin open source ("oss") model provider
/// and verify that a local Ollama server is reachable. If no server is
/// detected, returns an error with helpful installation/run instructions.
pub async fn try_from_oss_provider(config: &Config) -> io::Result<Self> {
// Note that we must look up the provider from the Config to ensure that
// any overrides the user has in their config.toml are taken into
// account.
let provider = config
.model_providers
.get(BUILT_IN_OSS_MODEL_PROVIDER_ID)
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::NotFound,
format!("Built-in provider {BUILT_IN_OSS_MODEL_PROVIDER_ID} not found",),
)
})?;
Self::try_from_provider(provider).await
}
#[cfg(test)]
async fn try_from_provider_with_base_url(base_url: &str) -> io::Result<Self> {
let provider = codex_core::create_oss_provider_with_base_url(base_url);
Self::try_from_provider(&provider).await
}
/// Build a client from a provider definition and verify the server is reachable.
async fn try_from_provider(provider: &ModelProviderInfo) -> io::Result<Self> {
#![allow(clippy::expect_used)]
// Use the built-in OSS provider's base URL.
let built_in_model_providers = codex_core::built_in_model_providers();
let provider = built_in_model_providers
.get(codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID)
.expect("oss provider must exist");
let base_url = provider
.base_url
.as_ref()
.expect("oss provider must have a base_url");
Self::from_provider(base_url, provider.wire_api)
}
/// Construct a client for the builtin open source ("oss") model provider
/// and verify that a local Ollama server is reachable. If no server is
/// detected, returns an error with helpful installation/run instructions.
pub async fn try_from_oss_provider() -> io::Result<Self> {
let client = Self::from_oss_provider();
if client.probe_server().await? {
Ok(client)
} else {
Err(io::Error::other(
"No running Ollama server detected. Start it with: `ollama serve` (after installing). Install instructions: https://github.com/ollama/ollama?tab=readme-ov-file#ollama",
))
}
}
/// Build a client from a provider definition. Falls back to the default
/// local URL if no base_url is configured.
fn from_provider(base_url: &str, wire_api: WireApi) -> Self {
let uses_openai_compat = is_openai_compatible_base_url(base_url)
|| matches!(wire_api, WireApi::Chat) && is_openai_compatible_base_url(base_url);
|| matches!(provider.wire_api, WireApi::Chat)
&& is_openai_compatible_base_url(base_url);
let host_root = base_url_to_host_root(base_url);
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(5))
.build()
.unwrap_or_else(|_| reqwest::Client::new());
Self {
let client = Self {
client,
host_root,
uses_openai_compat,
}
}
/// Low-level constructor given a raw host root, e.g. "http://localhost:11434".
#[cfg(test)]
fn from_host_root(host_root: impl Into<String>) -> Self {
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(5))
.build()
.unwrap_or_else(|_| reqwest::Client::new());
Self {
client,
host_root: host_root.into(),
uses_openai_compat: false,
}
};
client.probe_server().await?;
Ok(client)
}
/// Probe whether the server is reachable by hitting the appropriate health endpoint.
pub async fn probe_server(&self) -> io::Result<bool> {
async fn probe_server(&self) -> io::Result<()> {
let url = if self.uses_openai_compat {
format!("{}/v1/models", self.host_root.trim_end_matches('/'))
} else {
format!("{}/api/tags", self.host_root.trim_end_matches('/'))
};
let resp = self.client.get(url).send().await;
Ok(matches!(resp, Ok(r) if r.status().is_success()))
let resp = self.client.get(url).send().await.map_err(|err| {
tracing::warn!("Failed to connect to Ollama server: {err:?}");
io::Error::other(OLLAMA_CONNECTION_ERROR)
})?;
if resp.status().is_success() {
Ok(())
} else {
tracing::warn!(
"Failed to probe server at {}: HTTP {}",
self.host_root,
resp.status()
);
Err(io::Error::other(OLLAMA_CONNECTION_ERROR))
}
}
/// Return the list of model names known to the local Ollama instance.
@@ -210,6 +217,20 @@ impl OllamaClient {
"Pull stream ended unexpectedly without success.",
))
}
/// Low-level constructor given a raw host root, e.g. "http://localhost:11434".
#[cfg(test)]
fn from_host_root(host_root: impl Into<String>) -> Self {
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(5))
.build()
.unwrap_or_else(|_| reqwest::Client::new());
Self {
client,
host_root: host_root.into(),
uses_openai_compat: false,
}
}
}
#[cfg(test)]
@@ -217,34 +238,6 @@ mod tests {
#![allow(clippy::expect_used, clippy::unwrap_used)]
use super::*;
/// Simple RAII guard to set an environment variable for the duration of a test
/// and restore the previous value (or remove it) on drop to avoid cross-test
/// interference.
struct EnvVarGuard {
key: String,
prev: Option<String>,
}
impl EnvVarGuard {
fn set(key: &str, value: String) -> Self {
let prev = std::env::var(key).ok();
// set_var is safe but we mirror existing tests that use an unsafe block
// to silence edition lints around global mutation during tests.
unsafe { std::env::set_var(key, value) };
Self {
key: key.to_string(),
prev,
}
}
}
impl Drop for EnvVarGuard {
fn drop(&mut self) {
match &self.prev {
Some(v) => unsafe { std::env::set_var(&self.key, v) },
None => unsafe { std::env::remove_var(&self.key) },
}
}
}
// Happy-path tests using a mock HTTP server; skip if sandbox network is disabled.
#[tokio::test]
async fn test_fetch_models_happy_path() {
@@ -296,7 +289,7 @@ mod tests {
.mount(&server)
.await;
let native = OllamaClient::from_host_root(server.uri());
assert!(native.probe_server().await.expect("probe native"));
native.probe_server().await.expect("probe native");
// OpenAI compatibility endpoint
wiremock::Mock::given(wiremock::matchers::method("GET"))
@@ -304,11 +297,14 @@ mod tests {
.respond_with(wiremock::ResponseTemplate::new(200))
.mount(&server)
.await;
// Ensure the built-in OSS provider points at our mock server for this test
// to avoid depending on any globally configured environment from other tests.
let _guard = EnvVarGuard::set("CODEX_OSS_BASE_URL", format!("{}/v1", server.uri()));
let ollama_client = OllamaClient::from_oss_provider();
assert!(ollama_client.probe_server().await.expect("probe compat"));
let ollama_client =
OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
.await
.expect("probe OpenAI compat");
ollama_client
.probe_server()
.await
.expect("probe OpenAI compat");
}
#[tokio::test]
@@ -322,9 +318,6 @@ mod tests {
}
let server = wiremock::MockServer::start().await;
// Configure builtin `oss` provider to point at this mock server.
// set_var is unsafe on Rust 2024 edition; use unsafe block in tests.
let _guard = EnvVarGuard::set("CODEX_OSS_BASE_URL", format!("{}/v1", server.uri()));
// OpenAI-compat models endpoint responds OK.
wiremock::Mock::given(wiremock::matchers::method("GET"))
@@ -333,7 +326,7 @@ mod tests {
.mount(&server)
.await;
let _client = OllamaClient::try_from_oss_provider()
OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
.await
.expect("client should be created when probe succeeds");
}
@@ -349,18 +342,10 @@ mod tests {
}
let server = wiremock::MockServer::start().await;
// Point oss provider at our mock server but do NOT set up a handler
// for /v1/models so the request returns a non-success status.
unsafe { std::env::set_var("CODEX_OSS_BASE_URL", format!("{}/v1", server.uri())) };
let err = OllamaClient::try_from_oss_provider()
let err = OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
.await
.err()
.expect("expected error");
let msg = err.to_string();
assert!(
msg.contains("No running Ollama server detected."),
"msg = {msg}"
);
assert_eq!(OLLAMA_CONNECTION_ERROR, err.to_string());
}
}

View File

@@ -4,6 +4,7 @@ mod pull;
mod url;
pub use client::OllamaClient;
use codex_core::config::Config;
pub use pull::CliProgressReporter;
pub use pull::PullEvent;
pub use pull::PullProgressReporter;
@@ -15,38 +16,29 @@ pub const DEFAULT_OSS_MODEL: &str = "gpt-oss:20b";
/// Prepare the local OSS environment when `--oss` is selected.
///
/// - Ensures a local Ollama server is reachable.
/// - Selects the final model name (CLI override or default).
/// - Checks if the model exists locally and pulls it if missing.
///
/// Returns the final model name that should be used by the caller.
pub async fn ensure_oss_ready(cli_model: Option<String>) -> std::io::Result<String> {
pub async fn ensure_oss_ready(config: &Config) -> std::io::Result<()> {
// Only download when the requested model is the default OSS model (or when -m is not provided).
let should_download = cli_model
.as_deref()
.map(|name| name == DEFAULT_OSS_MODEL)
.unwrap_or(true);
let model = cli_model.unwrap_or_else(|| DEFAULT_OSS_MODEL.to_string());
let model = config.model.as_ref();
// Verify local Ollama is reachable.
let ollama_client = crate::OllamaClient::try_from_oss_provider().await?;
let ollama_client = crate::OllamaClient::try_from_oss_provider(config).await?;
if should_download {
// If the model is not present locally, pull it.
match ollama_client.fetch_models().await {
Ok(models) => {
if !models.iter().any(|m| m == &model) {
let mut reporter = crate::CliProgressReporter::new();
ollama_client
.pull_with_reporter(&model, &mut reporter)
.await?;
}
}
Err(err) => {
// Not fatal; higher layers may still proceed and surface errors later.
tracing::warn!("Failed to query local models from Ollama: {}.", err);
// If the model is not present locally, pull it.
match ollama_client.fetch_models().await {
Ok(models) => {
if !models.iter().any(|m| m == model) {
let mut reporter = crate::CliProgressReporter::new();
ollama_client
.pull_with_reporter(model, &mut reporter)
.await?;
}
}
Err(err) => {
// Not fatal; higher layers may still proceed and surface errors later.
tracing::warn!("Failed to query local models from Ollama: {}.", err);
}
}
Ok(model)
Ok(())
}

View File

@@ -10,6 +10,7 @@ use codex_core::config_types::SandboxMode;
use codex_core::protocol::AskForApproval;
use codex_core::util::is_inside_git_repo;
use codex_login::load_auth;
use codex_ollama::DEFAULT_OSS_MODEL;
use log_layer::TuiLogLayer;
use std::fs::OpenOptions;
use std::io::Write;
@@ -71,25 +72,27 @@ pub async fn run_main(
)
};
// When using `--oss`, let the bootstrapper pick the model (defaulting to
// gpt-oss:20b) and ensure it is present locally. Also, force the builtin
// `oss` model provider.
let model = if let Some(model) = &cli.model {
Some(model.clone())
} else if cli.oss {
Some(DEFAULT_OSS_MODEL.to_owned())
} else {
None // No model specified, will use the default.
};
let model_provider_override = if cli.oss {
Some(BUILT_IN_OSS_MODEL_PROVIDER_ID.to_owned())
} else {
None
};
let config = {
// Load configuration and support CLI overrides.
let overrides = ConfigOverrides {
// When using `--oss`, let the bootstrapper pick the model
// (defaulting to gpt-oss:20b) and ensure it is present locally.
model: if cli.oss {
Some(
codex_ollama::ensure_oss_ready(cli.model.clone())
.await
.map_err(|e| std::io::Error::other(format!("OSS setup failed: {e}")))?,
)
} else {
cli.model.clone()
},
model,
approval_policy,
sandbox_mode,
cwd: cli.cwd.clone().map(|p| p.canonicalize().unwrap_or(p)),
@@ -154,6 +157,12 @@ pub async fn run_main(
.with_target(false)
.with_filter(env_filter());
if cli.oss {
codex_ollama::ensure_oss_ready(&config)
.await
.map_err(|e| std::io::Error::other(format!("OSS setup failed: {e}")))?;
}
// Channel that carries formatted log lines to the UI.
let (log_tx, log_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
let tui_layer = TuiLogLayer::new(log_tx.clone(), 120).with_filter(env_filter());