fix: when using --oss, ensure correct configuration is threaded through correctly (#1859)
This PR started as an investigation aimed at eliminating the use
of `unsafe { std::env::set_var() }` in `ollama/src/client.rs`: setting
environment variables in a multithreaded context really is unsafe, and
these tests were observed to be flaky as a result.
As I dug deeper into the issue, though, I discovered that the logic for
instantiating `OllamaClient` under test scenarios was not quite right.
In this PR, I aimed to:
- share more code between the two creation codepaths,
`try_from_oss_provider()` and `try_from_provider_with_base_url()`
- use the values from `Config` when setting up Ollama: since we have
various mechanisms for overriding config values, we should be sure
that we are always using the final, resolved `Config` for things such as
the `ModelProviderInfo` associated with the `oss` id
Once this was in place,
`OllamaClient::try_from_provider_with_base_url()` could be used in the
unit tests for `OllamaClient`, making it possible to create a properly
configured client without having to set environment variables.
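
As a rough sketch of the test-side pattern this unlocks (adapted from the updated unit tests in the diff below; the test name and wiremock setup here are illustrative, and `try_from_provider_with_base_url()` is a `#[cfg(test)]`-only helper, so this shape only applies inside the client module's own tests):

```rust
// Illustrative only: mirrors the updated unit tests in ollama/src/client.rs.
#[tokio::test]
async fn probes_mock_server_without_env_vars() {
    // Stand up a mock HTTP server instead of pointing CODEX_OSS_BASE_URL at one.
    let server = wiremock::MockServer::start().await;

    // The OpenAI-compatible health endpoint the client probes during construction.
    wiremock::Mock::given(wiremock::matchers::method("GET"))
        .and(wiremock::matchers::path("/v1/models"))
        .respond_with(wiremock::ResponseTemplate::new(200))
        .mount(&server)
        .await;

    // No `std::env::set_var` needed: the base URL is passed in explicitly.
    let _client = OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
        .await
        .expect("client should be created when the probe succeeds");
}
```

Because `try_from_provider()` now probes the server as part of construction, a missing or unhealthy server surfaces as an `io::Error` from the constructor itself rather than via a separate `probe_server()` assertion.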
@@ -32,6 +32,7 @@ pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID;
 pub use model_provider_info::ModelProviderInfo;
 pub use model_provider_info::WireApi;
 pub use model_provider_info::built_in_model_providers;
+pub use model_provider_info::create_oss_provider_with_base_url;
 pub mod model_family;
 mod models;
 mod openai_model_info;

@@ -234,23 +234,6 @@ pub const BUILT_IN_OSS_MODEL_PROVIDER_ID: &str = "oss";
 pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
     use ModelProviderInfo as P;
 
-    // These CODEX_OSS_ environment variables are experimental: we may
-    // switch to reading values from config.toml instead.
-    let codex_oss_base_url = match std::env::var("CODEX_OSS_BASE_URL")
-        .ok()
-        .filter(|v| !v.trim().is_empty())
-    {
-        Some(url) => url,
-        None => format!(
-            "http://localhost:{port}/v1",
-            port = std::env::var("CODEX_OSS_PORT")
-                .ok()
-                .filter(|v| !v.trim().is_empty())
-                .and_then(|v| v.parse::<u32>().ok())
-                .unwrap_or(DEFAULT_OLLAMA_PORT)
-        ),
-    };
-
     // We do not want to be in the business of adjucating which third-party
     // providers are bundled with Codex CLI, so we only include the OpenAI and
     // open source ("oss") providers by default. Users are encouraged to add to

@@ -295,29 +278,51 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
                 requires_auth: true,
             },
         ),
-        (
-            BUILT_IN_OSS_MODEL_PROVIDER_ID,
-            P {
-                name: "Open Source".into(),
-                base_url: Some(codex_oss_base_url),
-                env_key: None,
-                env_key_instructions: None,
-                wire_api: WireApi::Chat,
-                query_params: None,
-                http_headers: None,
-                env_http_headers: None,
-                request_max_retries: None,
-                stream_max_retries: None,
-                stream_idle_timeout_ms: None,
-                requires_auth: false,
-            },
-        ),
+        (BUILT_IN_OSS_MODEL_PROVIDER_ID, create_oss_provider()),
     ]
     .into_iter()
     .map(|(k, v)| (k.to_string(), v))
     .collect()
 }
 
+pub fn create_oss_provider() -> ModelProviderInfo {
+    // These CODEX_OSS_ environment variables are experimental: we may
+    // switch to reading values from config.toml instead.
+    let codex_oss_base_url = match std::env::var("CODEX_OSS_BASE_URL")
+        .ok()
+        .filter(|v| !v.trim().is_empty())
+    {
+        Some(url) => url,
+        None => format!(
+            "http://localhost:{port}/v1",
+            port = std::env::var("CODEX_OSS_PORT")
+                .ok()
+                .filter(|v| !v.trim().is_empty())
+                .and_then(|v| v.parse::<u32>().ok())
+                .unwrap_or(DEFAULT_OLLAMA_PORT)
+        ),
+    };
+
+    create_oss_provider_with_base_url(&codex_oss_base_url)
+}
+
+pub fn create_oss_provider_with_base_url(base_url: &str) -> ModelProviderInfo {
+    ModelProviderInfo {
+        name: "gpt-oss".into(),
+        base_url: Some(base_url.into()),
+        env_key: None,
+        env_key_instructions: None,
+        wire_api: WireApi::Chat,
+        query_params: None,
+        http_headers: None,
+        env_http_headers: None,
+        request_max_retries: None,
+        stream_max_retries: None,
+        stream_idle_timeout_ms: None,
+        requires_auth: false,
+    }
+}
+
 #[cfg(test)]
 mod tests {
     #![allow(clippy::unwrap_used)]

@@ -22,6 +22,7 @@ use codex_core::protocol::InputItem;
 use codex_core::protocol::Op;
 use codex_core::protocol::TaskCompleteEvent;
 use codex_core::util::is_inside_git_repo;
+use codex_ollama::DEFAULT_OSS_MODEL;
 use event_processor_with_human_output::EventProcessorWithHumanOutput;
 use event_processor_with_json_output::EventProcessorWithJsonOutput;
 use tracing::debug;

@@ -35,7 +36,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()> {
 pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()> {
     let Cli {
         images,
-        model,
+        model: model_cli_arg,
         oss,
         config_profile,
         full_auto,

@@ -119,19 +120,18 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
     // When using `--oss`, let the bootstrapper pick the model (defaulting to
     // gpt-oss:20b) and ensure it is present locally. Also, force the built‑in
     // `oss` model provider.
-    let model_provider_override = if oss {
-        Some(BUILT_IN_OSS_MODEL_PROVIDER_ID.to_owned())
+    let model = if let Some(model) = model_cli_arg {
+        Some(model)
+    } else if oss {
+        Some(DEFAULT_OSS_MODEL.to_owned())
     } else {
-        None
+        None // No model specified, will use the default.
     };
-    let model = if oss {
-        Some(
-            codex_ollama::ensure_oss_ready(model.clone())
-                .await
-                .map_err(|e| anyhow::anyhow!("OSS setup failed: {e}"))?,
-        )
+
+    let model_provider = if oss {
+        Some(BUILT_IN_OSS_MODEL_PROVIDER_ID.to_string())
     } else {
-        model
+        None // No specific model provider override.
     };
 
     // Load configuration and determine approval policy

@@ -143,7 +143,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
         approval_policy: Some(AskForApproval::Never),
         sandbox_mode,
         cwd: cwd.map(|p| p.canonicalize().unwrap_or(p)),
-        model_provider: model_provider_override,
+        model_provider,
         codex_linux_sandbox_exe,
         base_instructions: None,
         include_plan_tool: None,

@@ -170,6 +170,12 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
         ))
     };
 
+    if oss {
+        codex_ollama::ensure_oss_ready(&config)
+            .await
+            .map_err(|e| anyhow::anyhow!("OSS setup failed: {e}"))?;
+    }
+
     // Print the effective configuration and prompt so users can see what Codex
     // is using.
     event_processor.print_config_summary(&config, &prompt);

@@ -5,13 +5,17 @@ use serde_json::Value as JsonValue;
 use std::collections::VecDeque;
 use std::io;
 
-use codex_core::WireApi;
-
 use crate::parser::pull_events_from_value;
 use crate::pull::PullEvent;
 use crate::pull::PullProgressReporter;
 use crate::url::base_url_to_host_root;
 use crate::url::is_openai_compatible_base_url;
+use codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID;
+use codex_core::ModelProviderInfo;
+use codex_core::WireApi;
+use codex_core::config::Config;
+
+const OLLAMA_CONNECTION_ERROR: &str = "No running Ollama server detected. Start it with: `ollama serve` (after installing). Install instructions: https://github.com/ollama/ollama?tab=readme-ov-file#ollama";
 
 /// Client for interacting with a local Ollama instance.
 pub struct OllamaClient {

@@ -21,74 +25,77 @@ pub struct OllamaClient {
 }
 
 impl OllamaClient {
-    pub fn from_oss_provider() -> Self {
+    /// Construct a client for the built‑in open‑source ("oss") model provider
+    /// and verify that a local Ollama server is reachable. If no server is
+    /// detected, returns an error with helpful installation/run instructions.
+    pub async fn try_from_oss_provider(config: &Config) -> io::Result<Self> {
+        // Note that we must look up the provider from the Config to ensure that
+        // any overrides the user has in their config.toml are taken into
+        // account.
+        let provider = config
+            .model_providers
+            .get(BUILT_IN_OSS_MODEL_PROVIDER_ID)
+            .ok_or_else(|| {
+                io::Error::new(
+                    io::ErrorKind::NotFound,
+                    format!("Built-in provider {BUILT_IN_OSS_MODEL_PROVIDER_ID} not found",),
+                )
+            })?;
+
+        Self::try_from_provider(provider).await
+    }
+
+    #[cfg(test)]
+    async fn try_from_provider_with_base_url(base_url: &str) -> io::Result<Self> {
+        let provider = codex_core::create_oss_provider_with_base_url(base_url);
+        Self::try_from_provider(&provider).await
+    }
+
+    /// Build a client from a provider definition and verify the server is reachable.
+    async fn try_from_provider(provider: &ModelProviderInfo) -> io::Result<Self> {
         #![allow(clippy::expect_used)]
-        // Use the built-in OSS provider's base URL.
-        let built_in_model_providers = codex_core::built_in_model_providers();
-        let provider = built_in_model_providers
-            .get(codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID)
-            .expect("oss provider must exist");
         let base_url = provider
             .base_url
             .as_ref()
             .expect("oss provider must have a base_url");
-        Self::from_provider(base_url, provider.wire_api)
-    }
-
-    /// Construct a client for the built‑in open‑source ("oss") model provider
-    /// and verify that a local Ollama server is reachable. If no server is
-    /// detected, returns an error with helpful installation/run instructions.
-    pub async fn try_from_oss_provider() -> io::Result<Self> {
-        let client = Self::from_oss_provider();
-        if client.probe_server().await? {
-            Ok(client)
-        } else {
-            Err(io::Error::other(
-                "No running Ollama server detected. Start it with: `ollama serve` (after installing). Install instructions: https://github.com/ollama/ollama?tab=readme-ov-file#ollama",
-            ))
-        }
-    }
-
-    /// Build a client from a provider definition. Falls back to the default
-    /// local URL if no base_url is configured.
-    fn from_provider(base_url: &str, wire_api: WireApi) -> Self {
         let uses_openai_compat = is_openai_compatible_base_url(base_url)
-            || matches!(wire_api, WireApi::Chat) && is_openai_compatible_base_url(base_url);
+            || matches!(provider.wire_api, WireApi::Chat)
+                && is_openai_compatible_base_url(base_url);
         let host_root = base_url_to_host_root(base_url);
         let client = reqwest::Client::builder()
             .connect_timeout(std::time::Duration::from_secs(5))
             .build()
             .unwrap_or_else(|_| reqwest::Client::new());
-        Self {
+        let client = Self {
             client,
             host_root,
             uses_openai_compat,
-        }
-    }
-
-    /// Low-level constructor given a raw host root, e.g. "http://localhost:11434".
-    #[cfg(test)]
-    fn from_host_root(host_root: impl Into<String>) -> Self {
-        let client = reqwest::Client::builder()
-            .connect_timeout(std::time::Duration::from_secs(5))
-            .build()
-            .unwrap_or_else(|_| reqwest::Client::new());
-        Self {
-            client,
-            host_root: host_root.into(),
-            uses_openai_compat: false,
-        }
+        };
+        client.probe_server().await?;
+        Ok(client)
     }
 
     /// Probe whether the server is reachable by hitting the appropriate health endpoint.
-    pub async fn probe_server(&self) -> io::Result<bool> {
+    async fn probe_server(&self) -> io::Result<()> {
         let url = if self.uses_openai_compat {
             format!("{}/v1/models", self.host_root.trim_end_matches('/'))
         } else {
             format!("{}/api/tags", self.host_root.trim_end_matches('/'))
         };
-        let resp = self.client.get(url).send().await;
-        Ok(matches!(resp, Ok(r) if r.status().is_success()))
+        let resp = self.client.get(url).send().await.map_err(|err| {
+            tracing::warn!("Failed to connect to Ollama server: {err:?}");
+            io::Error::other(OLLAMA_CONNECTION_ERROR)
+        })?;
+        if resp.status().is_success() {
+            Ok(())
+        } else {
+            tracing::warn!(
+                "Failed to probe server at {}: HTTP {}",
+                self.host_root,
+                resp.status()
+            );
+            Err(io::Error::other(OLLAMA_CONNECTION_ERROR))
+        }
     }
 
     /// Return the list of model names known to the local Ollama instance.

@@ -210,6 +217,20 @@ impl OllamaClient {
             "Pull stream ended unexpectedly without success.",
         ))
     }
+
+    /// Low-level constructor given a raw host root, e.g. "http://localhost:11434".
+    #[cfg(test)]
+    fn from_host_root(host_root: impl Into<String>) -> Self {
+        let client = reqwest::Client::builder()
+            .connect_timeout(std::time::Duration::from_secs(5))
+            .build()
+            .unwrap_or_else(|_| reqwest::Client::new());
+        Self {
+            client,
+            host_root: host_root.into(),
+            uses_openai_compat: false,
+        }
+    }
 }
 
 #[cfg(test)]

@@ -217,34 +238,6 @@ mod tests {
     #![allow(clippy::expect_used, clippy::unwrap_used)]
     use super::*;
 
-    /// Simple RAII guard to set an environment variable for the duration of a test
-    /// and restore the previous value (or remove it) on drop to avoid cross-test
-    /// interference.
-    struct EnvVarGuard {
-        key: String,
-        prev: Option<String>,
-    }
-    impl EnvVarGuard {
-        fn set(key: &str, value: String) -> Self {
-            let prev = std::env::var(key).ok();
-            // set_var is safe but we mirror existing tests that use an unsafe block
-            // to silence edition lints around global mutation during tests.
-            unsafe { std::env::set_var(key, value) };
-            Self {
-                key: key.to_string(),
-                prev,
-            }
-        }
-    }
-    impl Drop for EnvVarGuard {
-        fn drop(&mut self) {
-            match &self.prev {
-                Some(v) => unsafe { std::env::set_var(&self.key, v) },
-                None => unsafe { std::env::remove_var(&self.key) },
-            }
-        }
-    }
-
     // Happy-path tests using a mock HTTP server; skip if sandbox network is disabled.
     #[tokio::test]
     async fn test_fetch_models_happy_path() {

@@ -296,7 +289,7 @@ mod tests {
             .mount(&server)
             .await;
         let native = OllamaClient::from_host_root(server.uri());
-        assert!(native.probe_server().await.expect("probe native"));
+        native.probe_server().await.expect("probe native");
 
         // OpenAI compatibility endpoint
         wiremock::Mock::given(wiremock::matchers::method("GET"))

@@ -304,11 +297,14 @@ mod tests {
             .respond_with(wiremock::ResponseTemplate::new(200))
             .mount(&server)
             .await;
-        // Ensure the built-in OSS provider points at our mock server for this test
-        // to avoid depending on any globally configured environment from other tests.
-        let _guard = EnvVarGuard::set("CODEX_OSS_BASE_URL", format!("{}/v1", server.uri()));
-        let ollama_client = OllamaClient::from_oss_provider();
-        assert!(ollama_client.probe_server().await.expect("probe compat"));
+        let ollama_client =
+            OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
+                .await
+                .expect("probe OpenAI compat");
+        ollama_client
+            .probe_server()
+            .await
+            .expect("probe OpenAI compat");
     }
 
     #[tokio::test]

@@ -322,9 +318,6 @@ mod tests {
         }
 
         let server = wiremock::MockServer::start().await;
-        // Configure built‑in `oss` provider to point at this mock server.
-        // set_var is unsafe on Rust 2024 edition; use unsafe block in tests.
-        let _guard = EnvVarGuard::set("CODEX_OSS_BASE_URL", format!("{}/v1", server.uri()));
 
         // OpenAI‑compat models endpoint responds OK.
         wiremock::Mock::given(wiremock::matchers::method("GET"))

@@ -333,7 +326,7 @@ mod tests {
             .mount(&server)
             .await;
 
-        let _client = OllamaClient::try_from_oss_provider()
+        OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
             .await
             .expect("client should be created when probe succeeds");
     }

@@ -349,18 +342,10 @@ mod tests {
         }
 
         let server = wiremock::MockServer::start().await;
-        // Point oss provider at our mock server but do NOT set up a handler
-        // for /v1/models so the request returns a non‑success status.
-        unsafe { std::env::set_var("CODEX_OSS_BASE_URL", format!("{}/v1", server.uri())) };
-
-        let err = OllamaClient::try_from_oss_provider()
+        let err = OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
            .await
             .err()
             .expect("expected error");
-        let msg = err.to_string();
-        assert!(
-            msg.contains("No running Ollama server detected."),
-            "msg = {msg}"
-        );
+        assert_eq!(OLLAMA_CONNECTION_ERROR, err.to_string());
     }
 }

@@ -4,6 +4,7 @@ mod pull;
 mod url;
 
 pub use client::OllamaClient;
+use codex_core::config::Config;
 pub use pull::CliProgressReporter;
 pub use pull::PullEvent;
 pub use pull::PullProgressReporter;

@@ -15,38 +16,29 @@ pub const DEFAULT_OSS_MODEL: &str = "gpt-oss:20b";
 /// Prepare the local OSS environment when `--oss` is selected.
 ///
 /// - Ensures a local Ollama server is reachable.
-/// - Selects the final model name (CLI override or default).
 /// - Checks if the model exists locally and pulls it if missing.
-///
-/// Returns the final model name that should be used by the caller.
-pub async fn ensure_oss_ready(cli_model: Option<String>) -> std::io::Result<String> {
+pub async fn ensure_oss_ready(config: &Config) -> std::io::Result<()> {
     // Only download when the requested model is the default OSS model (or when -m is not provided).
-    let should_download = cli_model
-        .as_deref()
-        .map(|name| name == DEFAULT_OSS_MODEL)
-        .unwrap_or(true);
-    let model = cli_model.unwrap_or_else(|| DEFAULT_OSS_MODEL.to_string());
+    let model = config.model.as_ref();
 
     // Verify local Ollama is reachable.
-    let ollama_client = crate::OllamaClient::try_from_oss_provider().await?;
+    let ollama_client = crate::OllamaClient::try_from_oss_provider(config).await?;
 
-    if should_download {
-        // If the model is not present locally, pull it.
-        match ollama_client.fetch_models().await {
-            Ok(models) => {
-                if !models.iter().any(|m| m == &model) {
-                    let mut reporter = crate::CliProgressReporter::new();
-                    ollama_client
-                        .pull_with_reporter(&model, &mut reporter)
-                        .await?;
-                }
-            }
-            Err(err) => {
-                // Not fatal; higher layers may still proceed and surface errors later.
-                tracing::warn!("Failed to query local models from Ollama: {}.", err);
+    // If the model is not present locally, pull it.
+    match ollama_client.fetch_models().await {
+        Ok(models) => {
+            if !models.iter().any(|m| m == model) {
+                let mut reporter = crate::CliProgressReporter::new();
+                ollama_client
+                    .pull_with_reporter(model, &mut reporter)
+                    .await?;
             }
         }
+        Err(err) => {
+            // Not fatal; higher layers may still proceed and surface errors later.
+            tracing::warn!("Failed to query local models from Ollama: {}.", err);
+        }
     }
 
-    Ok(model)
+    Ok(())
 }

@@ -10,6 +10,7 @@ use codex_core::config_types::SandboxMode;
 use codex_core::protocol::AskForApproval;
 use codex_core::util::is_inside_git_repo;
 use codex_login::load_auth;
+use codex_ollama::DEFAULT_OSS_MODEL;
 use log_layer::TuiLogLayer;
 use std::fs::OpenOptions;
 use std::io::Write;

@@ -71,25 +72,27 @@ pub async fn run_main(
         )
     };
+
+    // When using `--oss`, let the bootstrapper pick the model (defaulting to
+    // gpt-oss:20b) and ensure it is present locally. Also, force the built‑in
+    // `oss` model provider.
+    let model = if let Some(model) = &cli.model {
+        Some(model.clone())
+    } else if cli.oss {
+        Some(DEFAULT_OSS_MODEL.to_owned())
+    } else {
+        None // No model specified, will use the default.
+    };
+
     let model_provider_override = if cli.oss {
         Some(BUILT_IN_OSS_MODEL_PROVIDER_ID.to_owned())
     } else {
         None
     };
 
     let config = {
         // Load configuration and support CLI overrides.
         let overrides = ConfigOverrides {
-            // When using `--oss`, let the bootstrapper pick the model
-            // (defaulting to gpt-oss:20b) and ensure it is present locally.
-            model: if cli.oss {
-                Some(
-                    codex_ollama::ensure_oss_ready(cli.model.clone())
-                        .await
-                        .map_err(|e| std::io::Error::other(format!("OSS setup failed: {e}")))?,
-                )
-            } else {
-                cli.model.clone()
-            },
+            model,
             approval_policy,
             sandbox_mode,
             cwd: cli.cwd.clone().map(|p| p.canonicalize().unwrap_or(p)),

@@ -154,6 +157,12 @@ pub async fn run_main(
         .with_target(false)
         .with_filter(env_filter());
 
+    if cli.oss {
+        codex_ollama::ensure_oss_ready(&config)
+            .await
+            .map_err(|e| std::io::Error::other(format!("OSS setup failed: {e}")))?;
+    }
+
     // Channel that carries formatted log lines to the UI.
     let (log_tx, log_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
     let tui_layer = TuiLogLayer::new(log_tx.clone(), 120).with_filter(env_filter());