From 9285350842fc87b4bf84f4788151707153270b6f Mon Sep 17 00:00:00 2001 From: easong-openai Date: Tue, 5 Aug 2025 11:31:11 -0700 Subject: [PATCH] Introduce `--oss` flag to use gpt-oss models (#1848) This adds support for easily running Codex backed by a local Ollama instance running our new open source models. See https://github.com/openai/gpt-oss for details. If you pass in `--oss` you'll be prompted to install/launch ollama, and it will automatically download the 20b model and attempt to use it. We'll likely want to expand this with some options later to make the experience smoother for users who can't run the 20b or want to run the 120b. Co-authored-by: Michael Bolin --- README.md | 36 ++ codex-rs/Cargo.lock | 19 + codex-rs/Cargo.toml | 1 + codex-rs/core/src/config.rs | 10 +- codex-rs/core/src/lib.rs | 1 + codex-rs/core/src/model_family.rs | 2 + codex-rs/core/src/model_provider_info.rs | 120 ++++-- codex-rs/exec/Cargo.toml | 1 + codex-rs/exec/src/cli.rs | 3 + codex-rs/exec/src/lib.rs | 24 +- codex-rs/mcp-server/src/codex_tool_config.rs | 2 + .../src/tool_handlers/create_conversation.rs | 2 + codex-rs/ollama/Cargo.toml | 32 ++ codex-rs/ollama/src/client.rs | 366 ++++++++++++++++++ codex-rs/ollama/src/lib.rs | 52 +++ codex-rs/ollama/src/parser.rs | 82 ++++ codex-rs/ollama/src/pull.rs | 147 +++++++ codex-rs/ollama/src/url.rs | 39 ++ codex-rs/tui/Cargo.toml | 1 + codex-rs/tui/src/cli.rs | 6 + codex-rs/tui/src/lib.rs | 22 +- 21 files changed, 924 insertions(+), 44 deletions(-) create mode 100644 codex-rs/ollama/Cargo.toml create mode 100644 codex-rs/ollama/src/client.rs create mode 100644 codex-rs/ollama/src/lib.rs create mode 100644 codex-rs/ollama/src/parser.rs create mode 100644 codex-rs/ollama/src/pull.rs create mode 100644 codex-rs/ollama/src/url.rs diff --git a/README.md b/README.md index c7f6a1d5..dd5e4662 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,7 @@ This is the home of the **Codex CLI**, which is a coding agent from OpenAI that - [Quickstart](#quickstart) - [OpenAI API Users](#openai-api-users) - [OpenAI Plus/Pro Users](#openai-pluspro-users) + - [Using OpenAI Open Source Models](#using-open-source-models) - [Why Codex?](#why-codex) - [Security model & permissions](#security-model--permissions) - [Platform sandboxing details](#platform-sandboxing-details) @@ -186,6 +187,41 @@ they'll be committed to your working directory. --- +## Using Open Source Models + +Codex can run fully locally against an OpenAI‑compatible OSS host (like Ollama) using the `--oss` flag: + +- Interactive UI: + - codex --oss +- Non‑interactive (programmatic) mode: + - echo "Refactor utils" | codex exec --oss + +Model selection when using `--oss`: + +- If you omit `-m/--model`, Codex defaults to -m gpt-oss:20b and will verify it exists locally (downloading if needed). +- To pick a different size, pass one of: + - -m "gpt-oss:20b" + - -m "gpt-oss:120b" + +Point Codex at your own OSS host: + +- By default, `--oss` talks to http://localhost:11434/v1. 
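+  - For example, if a local `ollama serve` is already listening on that port, `codex --oss` needs no extra configuration.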
+- To use a different host, set one of these environment variables before running Codex: + - CODEX_OSS_BASE_URL, for example: + - CODEX_OSS_BASE_URL="http://my-ollama.example.com:11434/v1" codex --oss -m gpt-oss:20b + - or CODEX_OSS_PORT (when the host is localhost): + - CODEX_OSS_PORT=11434 codex --oss + +Advanced: you can persist this in your config instead of environment variables by overriding the built‑in `oss` provider in `~/.codex/config.toml`: + +```toml +[model_providers.oss] +name = "Open Source" +base_url = "http://my-ollama.example.com:11434/v1" +``` + +--- + ## Why Codex? Codex CLI is built for developers who already **live in the terminal** and want diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 2e20a7d6..4e21baf7 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -729,6 +729,7 @@ dependencies = [ "codex-arg0", "codex-common", "codex-core", + "codex-ollama", "owo-colors", "predicates", "serde_json", @@ -838,6 +839,23 @@ dependencies = [ "wiremock", ] +[[package]] +name = "codex-ollama" +version = "0.0.0" +dependencies = [ + "async-stream", + "bytes", + "codex-core", + "futures", + "reqwest", + "serde_json", + "tempfile", + "tokio", + "toml 0.9.4", + "tracing", + "wiremock", +] + [[package]] name = "codex-tui" version = "0.0.0" @@ -852,6 +870,7 @@ dependencies = [ "codex-core", "codex-file-search", "codex-login", + "codex-ollama", "color-eyre", "crossterm", "image", diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index 0f8085c7..0ed88522 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -14,6 +14,7 @@ members = [ "mcp-client", "mcp-server", "mcp-types", + "ollama", "tui", ] resolver = "2" diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs index d97d5ec1..e62fcc39 100644 --- a/codex-rs/core/src/config.rs +++ b/codex-rs/core/src/config.rs @@ -385,6 +385,8 @@ pub struct ConfigOverrides { pub codex_linux_sandbox_exe: Option, pub base_instructions: Option, pub include_plan_tool: Option, + pub default_disable_response_storage: Option, + pub default_show_raw_agent_reasoning: Option, } impl Config { @@ -408,6 +410,8 @@ impl Config { codex_linux_sandbox_exe, base_instructions, include_plan_tool, + default_disable_response_storage, + default_show_raw_agent_reasoning, } = overrides; let config_profile = match config_profile_key.as_ref().or(cfg.profile.as_ref()) { @@ -525,6 +529,7 @@ impl Config { disable_response_storage: config_profile .disable_response_storage .or(cfg.disable_response_storage) + .or(default_disable_response_storage) .unwrap_or(false), notify: cfg.notify, user_instructions, @@ -539,7 +544,10 @@ impl Config { codex_linux_sandbox_exe, hide_agent_reasoning: cfg.hide_agent_reasoning.unwrap_or(false), - show_raw_agent_reasoning: cfg.show_raw_agent_reasoning.unwrap_or(false), + show_raw_agent_reasoning: cfg + .show_raw_agent_reasoning + .or(default_show_raw_agent_reasoning) + .unwrap_or(false), model_reasoning_effort: config_profile .model_reasoning_effort .or(cfg.model_reasoning_effort) diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index f9c608b5..965cb77b 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -28,6 +28,7 @@ mod mcp_connection_manager; mod mcp_tool_call; mod message_history; mod model_provider_info; +pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID; pub use model_provider_info::ModelProviderInfo; pub use model_provider_info::WireApi; pub use model_provider_info::built_in_model_providers; diff --git a/codex-rs/core/src/model_family.rs 
b/codex-rs/core/src/model_family.rs index 9bc61270..7c4a9de6 100644 --- a/codex-rs/core/src/model_family.rs +++ b/codex-rs/core/src/model_family.rs @@ -85,6 +85,8 @@ pub fn find_family_for_model(slug: &str) -> Option { ) } else if slug.starts_with("gpt-4o") { simple_model_family!(slug, "gpt-4o") + } else if slug.starts_with("gpt-oss") { + simple_model_family!(slug, "gpt-oss") } else if slug.starts_with("gpt-3.5") { simple_model_family!(slug, "gpt-3.5") } else { diff --git a/codex-rs/core/src/model_provider_info.rs b/codex-rs/core/src/model_provider_info.rs index 49478660..595f05ef 100644 --- a/codex-rs/core/src/model_provider_info.rs +++ b/codex-rs/core/src/model_provider_info.rs @@ -226,53 +226,93 @@ impl ModelProviderInfo { } } +const DEFAULT_OLLAMA_PORT: u32 = 11434; + +pub const BUILT_IN_OSS_MODEL_PROVIDER_ID: &str = "oss"; + /// Built-in default provider list. pub fn built_in_model_providers() -> HashMap { use ModelProviderInfo as P; - // We do not want to be in the business of adjucating which third-party - // providers are bundled with Codex CLI, so we only include the OpenAI - // provider by default. Users are encouraged to add to `model_providers` - // in config.toml to add their own providers. - [( - "openai", - P { - name: "OpenAI".into(), - // Allow users to override the default OpenAI endpoint by - // exporting `OPENAI_BASE_URL`. This is useful when pointing - // Codex at a proxy, mock server, or Azure-style deployment - // without requiring a full TOML override for the built-in - // OpenAI provider. - base_url: std::env::var("OPENAI_BASE_URL") + // These CODEX_OSS_ environment variables are experimental: we may + // switch to reading values from config.toml instead. + let codex_oss_base_url = match std::env::var("CODEX_OSS_BASE_URL") + .ok() + .filter(|v| !v.trim().is_empty()) + { + Some(url) => url, + None => format!( + "http://localhost:{port}/v1", + port = std::env::var("CODEX_OSS_PORT") .ok() - .filter(|v| !v.trim().is_empty()), - env_key: None, - env_key_instructions: None, - wire_api: WireApi::Responses, - query_params: None, - http_headers: Some( - [("version".to_string(), env!("CARGO_PKG_VERSION").to_string())] + .filter(|v| !v.trim().is_empty()) + .and_then(|v| v.parse::().ok()) + .unwrap_or(DEFAULT_OLLAMA_PORT) + ), + }; + + // We do not want to be in the business of adjucating which third-party + // providers are bundled with Codex CLI, so we only include the OpenAI and + // open source ("oss") providers by default. Users are encouraged to add to + // `model_providers` in config.toml to add their own providers. + [ + ( + "openai", + P { + name: "OpenAI".into(), + // Allow users to override the default OpenAI endpoint by + // exporting `OPENAI_BASE_URL`. This is useful when pointing + // Codex at a proxy, mock server, or Azure-style deployment + // without requiring a full TOML override for the built-in + // OpenAI provider. 
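+                // For example (illustrative URL only):
+                //   OPENAI_BASE_URL="http://localhost:8080/v1" codex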
+ base_url: std::env::var("OPENAI_BASE_URL") + .ok() + .filter(|v| !v.trim().is_empty()), + env_key: None, + env_key_instructions: None, + wire_api: WireApi::Responses, + query_params: None, + http_headers: Some( + [("version".to_string(), env!("CARGO_PKG_VERSION").to_string())] + .into_iter() + .collect(), + ), + env_http_headers: Some( + [ + ( + "OpenAI-Organization".to_string(), + "OPENAI_ORGANIZATION".to_string(), + ), + ("OpenAI-Project".to_string(), "OPENAI_PROJECT".to_string()), + ] .into_iter() .collect(), - ), - env_http_headers: Some( - [ - ( - "OpenAI-Organization".to_string(), - "OPENAI_ORGANIZATION".to_string(), - ), - ("OpenAI-Project".to_string(), "OPENAI_PROJECT".to_string()), - ] - .into_iter() - .collect(), - ), - // Use global defaults for retry/timeout unless overridden in config.toml. - request_max_retries: None, - stream_max_retries: None, - stream_idle_timeout_ms: None, - requires_auth: true, - }, - )] + ), + // Use global defaults for retry/timeout unless overridden in config.toml. + request_max_retries: None, + stream_max_retries: None, + stream_idle_timeout_ms: None, + requires_auth: true, + }, + ), + ( + BUILT_IN_OSS_MODEL_PROVIDER_ID, + P { + name: "Open Source".into(), + base_url: Some(codex_oss_base_url), + env_key: None, + env_key_instructions: None, + wire_api: WireApi::Chat, + query_params: None, + http_headers: None, + env_http_headers: None, + request_max_retries: None, + stream_max_retries: None, + stream_idle_timeout_ms: None, + requires_auth: false, + }, + ), + ] .into_iter() .map(|(k, v)| (k.to_string(), v)) .collect() diff --git a/codex-rs/exec/Cargo.toml b/codex-rs/exec/Cargo.toml index cd521410..aee480d7 100644 --- a/codex-rs/exec/Cargo.toml +++ b/codex-rs/exec/Cargo.toml @@ -25,6 +25,7 @@ codex-common = { path = "../common", features = [ "sandbox_summary", ] } codex-core = { path = "../core" } +codex-ollama = { path = "../ollama" } owo-colors = "4.2.0" serde_json = "1" shlex = "1.3.0" diff --git a/codex-rs/exec/src/cli.rs b/codex-rs/exec/src/cli.rs index 53af25c7..ea659e32 100644 --- a/codex-rs/exec/src/cli.rs +++ b/codex-rs/exec/src/cli.rs @@ -14,6 +14,9 @@ pub struct Cli { #[arg(long, short = 'm')] pub model: Option, + #[arg(long = "oss", default_value_t = false)] + pub oss: bool, + /// Select the sandbox policy to use when executing model-generated shell /// commands. #[arg(long = "sandbox", short = 's')] diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs index ce4d7f65..c1af4f5b 100644 --- a/codex-rs/exec/src/lib.rs +++ b/codex-rs/exec/src/lib.rs @@ -9,6 +9,7 @@ use std::path::PathBuf; use std::sync::Arc; pub use cli::Cli; +use codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID; use codex_core::codex_wrapper::CodexConversation; use codex_core::codex_wrapper::{self}; use codex_core::config::Config; @@ -35,6 +36,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any let Cli { images, model, + oss, config_profile, full_auto, dangerously_bypass_approvals_and_sandbox, @@ -114,6 +116,24 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any sandbox_mode_cli_arg.map(Into::::into) }; + // When using `--oss`, let the bootstrapper pick the model (defaulting to + // gpt-oss:20b) and ensure it is present locally. Also, force the built‑in + // `oss` model provider. 
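+    // `ensure_oss_ready` verifies the local server, pulls the default model if
+    // it is missing, and returns the model name to use.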
+ let model_provider_override = if oss { + Some(BUILT_IN_OSS_MODEL_PROVIDER_ID.to_owned()) + } else { + None + }; + let model = if oss { + Some( + codex_ollama::ensure_oss_ready(model.clone()) + .await + .map_err(|e| anyhow::anyhow!("OSS setup failed: {e}"))?, + ) + } else { + model + }; + // Load configuration and determine approval policy let overrides = ConfigOverrides { model, @@ -123,10 +143,12 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any approval_policy: Some(AskForApproval::Never), sandbox_mode, cwd: cwd.map(|p| p.canonicalize().unwrap_or(p)), - model_provider: None, + model_provider: model_provider_override, codex_linux_sandbox_exe, base_instructions: None, include_plan_tool: None, + default_disable_response_storage: oss.then_some(true), + default_show_raw_agent_reasoning: oss.then_some(true), }; // Parse `-c` overrides. let cli_kv_overrides = match config_overrides.parse_overrides() { diff --git a/codex-rs/mcp-server/src/codex_tool_config.rs b/codex-rs/mcp-server/src/codex_tool_config.rs index 877d0e05..f1a502bb 100644 --- a/codex-rs/mcp-server/src/codex_tool_config.rs +++ b/codex-rs/mcp-server/src/codex_tool_config.rs @@ -158,6 +158,8 @@ impl CodexToolCallParam { codex_linux_sandbox_exe, base_instructions, include_plan_tool, + default_disable_response_storage: None, + default_show_raw_agent_reasoning: None, }; let cli_overrides = cli_overrides diff --git a/codex-rs/mcp-server/src/tool_handlers/create_conversation.rs b/codex-rs/mcp-server/src/tool_handlers/create_conversation.rs index 28a89651..c1f40356 100644 --- a/codex-rs/mcp-server/src/tool_handlers/create_conversation.rs +++ b/codex-rs/mcp-server/src/tool_handlers/create_conversation.rs @@ -59,6 +59,8 @@ pub(crate) async fn handle_create_conversation( codex_linux_sandbox_exe: None, base_instructions, include_plan_tool: None, + default_disable_response_storage: None, + default_show_raw_agent_reasoning: None, }; let cfg: CodexConfig = match CodexConfig::load_with_cli_overrides(cli_overrides, overrides) { diff --git a/codex-rs/ollama/Cargo.toml b/codex-rs/ollama/Cargo.toml new file mode 100644 index 00000000..ead9a064 --- /dev/null +++ b/codex-rs/ollama/Cargo.toml @@ -0,0 +1,32 @@ +[package] +edition = "2024" +name = "codex-ollama" +version = { workspace = true } + +[lib] +name = "codex_ollama" +path = "src/lib.rs" + +[lints] +workspace = true + +[dependencies] +async-stream = "0.3" +bytes = "1.10.1" +codex-core = { path = "../core" } +futures = "0.3" +reqwest = { version = "0.12", features = ["json", "stream"] } +serde_json = "1" +tokio = { version = "1", features = [ + "io-std", + "macros", + "process", + "rt-multi-thread", + "signal", +] } +toml = "0.9.2" +tracing = { version = "0.1.41", features = ["log"] } +wiremock = "0.6" + +[dev-dependencies] +tempfile = "3" diff --git a/codex-rs/ollama/src/client.rs b/codex-rs/ollama/src/client.rs new file mode 100644 index 00000000..8a15039f --- /dev/null +++ b/codex-rs/ollama/src/client.rs @@ -0,0 +1,366 @@ +use bytes::BytesMut; +use futures::StreamExt; +use futures::stream::BoxStream; +use serde_json::Value as JsonValue; +use std::collections::VecDeque; +use std::io; + +use codex_core::WireApi; + +use crate::parser::pull_events_from_value; +use crate::pull::PullEvent; +use crate::pull::PullProgressReporter; +use crate::url::base_url_to_host_root; +use crate::url::is_openai_compatible_base_url; + +/// Client for interacting with a local Ollama instance. 
+pub struct OllamaClient { + client: reqwest::Client, + host_root: String, + uses_openai_compat: bool, +} + +impl OllamaClient { + pub fn from_oss_provider() -> Self { + #![allow(clippy::expect_used)] + // Use the built-in OSS provider's base URL. + let built_in_model_providers = codex_core::built_in_model_providers(); + let provider = built_in_model_providers + .get(codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID) + .expect("oss provider must exist"); + let base_url = provider + .base_url + .as_ref() + .expect("oss provider must have a base_url"); + Self::from_provider(base_url, provider.wire_api) + } + + /// Construct a client for the built‑in open‑source ("oss") model provider + /// and verify that a local Ollama server is reachable. If no server is + /// detected, returns an error with helpful installation/run instructions. + pub async fn try_from_oss_provider() -> io::Result { + let client = Self::from_oss_provider(); + if client.probe_server().await? { + Ok(client) + } else { + Err(io::Error::other( + "No running Ollama server detected. Start it with: `ollama serve` (after installing). Install instructions: https://github.com/ollama/ollama?tab=readme-ov-file#ollama", + )) + } + } + + /// Build a client from a provider definition. Falls back to the default + /// local URL if no base_url is configured. + fn from_provider(base_url: &str, wire_api: WireApi) -> Self { + let uses_openai_compat = is_openai_compatible_base_url(base_url) + || matches!(wire_api, WireApi::Chat) && is_openai_compatible_base_url(base_url); + let host_root = base_url_to_host_root(base_url); + let client = reqwest::Client::builder() + .connect_timeout(std::time::Duration::from_secs(5)) + .build() + .unwrap_or_else(|_| reqwest::Client::new()); + Self { + client, + host_root, + uses_openai_compat, + } + } + + /// Low-level constructor given a raw host root, e.g. "http://localhost:11434". + #[cfg(test)] + fn from_host_root(host_root: impl Into) -> Self { + let client = reqwest::Client::builder() + .connect_timeout(std::time::Duration::from_secs(5)) + .build() + .unwrap_or_else(|_| reqwest::Client::new()); + Self { + client, + host_root: host_root.into(), + uses_openai_compat: false, + } + } + + /// Probe whether the server is reachable by hitting the appropriate health endpoint. + pub async fn probe_server(&self) -> io::Result { + let url = if self.uses_openai_compat { + format!("{}/v1/models", self.host_root.trim_end_matches('/')) + } else { + format!("{}/api/tags", self.host_root.trim_end_matches('/')) + }; + let resp = self.client.get(url).send().await; + Ok(matches!(resp, Ok(r) if r.status().is_success())) + } + + /// Return the list of model names known to the local Ollama instance. + pub async fn fetch_models(&self) -> io::Result> { + let tags_url = format!("{}/api/tags", self.host_root.trim_end_matches('/')); + let resp = self + .client + .get(tags_url) + .send() + .await + .map_err(io::Error::other)?; + if !resp.status().is_success() { + return Ok(Vec::new()); + } + let val = resp.json::().await.map_err(io::Error::other)?; + let names = val + .get("models") + .and_then(|m| m.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.get("name").and_then(|n| n.as_str())) + .map(|s| s.to_string()) + .collect::>() + }) + .unwrap_or_default(); + Ok(names) + } + + /// Start a model pull and emit streaming events. The returned stream ends when + /// a Success event is observed or the server closes the connection. 
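+    ///
+    /// The response body is newline-delimited JSON; each line is decoded with
+    /// `pull_events_from_value` and may yield status, progress, or error events.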
+ pub async fn pull_model_stream( + &self, + model: &str, + ) -> io::Result> { + let url = format!("{}/api/pull", self.host_root.trim_end_matches('/')); + let resp = self + .client + .post(url) + .json(&serde_json::json!({"model": model, "stream": true})) + .send() + .await + .map_err(io::Error::other)?; + if !resp.status().is_success() { + return Err(io::Error::other(format!( + "failed to start pull: HTTP {}", + resp.status() + ))); + } + + let mut stream = resp.bytes_stream(); + let mut buf = BytesMut::new(); + let _pending: VecDeque = VecDeque::new(); + + // Using an async stream adaptor backed by unfold-like manual loop. + let s = async_stream::stream! { + while let Some(chunk) = stream.next().await { + match chunk { + Ok(bytes) => { + buf.extend_from_slice(&bytes); + while let Some(pos) = buf.iter().position(|b| *b == b'\n') { + let line = buf.split_to(pos + 1); + if let Ok(text) = std::str::from_utf8(&line) { + let text = text.trim(); + if text.is_empty() { continue; } + if let Ok(value) = serde_json::from_str::(text) { + for ev in pull_events_from_value(&value) { yield ev; } + if let Some(err_msg) = value.get("error").and_then(|e| e.as_str()) { + yield PullEvent::Error(err_msg.to_string()); + return; + } + if let Some(status) = value.get("status").and_then(|s| s.as_str()) { + if status == "success" { yield PullEvent::Success; return; } + } + } + } + } + } + Err(_) => { + // Connection error: end the stream. + return; + } + } + } + }; + + Ok(Box::pin(s)) + } + + /// High-level helper to pull a model and drive a progress reporter. + pub async fn pull_with_reporter( + &self, + model: &str, + reporter: &mut dyn PullProgressReporter, + ) -> io::Result<()> { + reporter.on_event(&PullEvent::Status(format!("Pulling model {model}...")))?; + let mut stream = self.pull_model_stream(model).await?; + while let Some(event) = stream.next().await { + reporter.on_event(&event)?; + match event { + PullEvent::Success => { + return Ok(()); + } + PullEvent::Error(err) => { + // Emperically, ollama returns a 200 OK response even when + // the output stream includes an error message. Verify with: + // + // `curl -i http://localhost:11434/api/pull -d '{ "model": "foobarbaz" }'` + // + // As such, we have to check the event stream, not the + // HTTP response status, to determine whether to return Err. + return Err(io::Error::other(format!("Pull failed: {err}"))); + } + PullEvent::ChunkProgress { .. } | PullEvent::Status(_) => { + continue; + } + } + } + Err(io::Error::other( + "Pull stream ended unexpectedly without success.", + )) + } +} + +#[cfg(test)] +mod tests { + #![allow(clippy::expect_used, clippy::unwrap_used)] + use super::*; + + /// Simple RAII guard to set an environment variable for the duration of a test + /// and restore the previous value (or remove it) on drop to avoid cross-test + /// interference. + struct EnvVarGuard { + key: String, + prev: Option, + } + impl EnvVarGuard { + fn set(key: &str, value: String) -> Self { + let prev = std::env::var(key).ok(); + // set_var is safe but we mirror existing tests that use an unsafe block + // to silence edition lints around global mutation during tests. + unsafe { std::env::set_var(key, value) }; + Self { + key: key.to_string(), + prev, + } + } + } + impl Drop for EnvVarGuard { + fn drop(&mut self) { + match &self.prev { + Some(v) => unsafe { std::env::set_var(&self.key, v) }, + None => unsafe { std::env::remove_var(&self.key) }, + } + } + } + + // Happy-path tests using a mock HTTP server; skip if sandbox network is disabled. 
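+    // (wiremock binds a real local HTTP listener, which is why the network-disabled
+    // sandbox check applies even though no traffic leaves the machine.)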
+ #[tokio::test] + async fn test_fetch_models_happy_path() { + if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_fetch_models_happy_path", + codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("GET")) + .and(wiremock::matchers::path("/api/tags")) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ + "models": [ {"name": "llama3.2:3b"}, {"name":"mistral"} ] + }) + .to_string(), + "application/json", + ), + ) + .mount(&server) + .await; + + let client = OllamaClient::from_host_root(server.uri()); + let models = client.fetch_models().await.expect("fetch models"); + assert!(models.contains(&"llama3.2:3b".to_string())); + assert!(models.contains(&"mistral".to_string())); + } + + #[tokio::test] + async fn test_probe_server_happy_path_openai_compat_and_native() { + if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} set; skipping test_probe_server_happy_path_openai_compat_and_native", + codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = wiremock::MockServer::start().await; + + // Native endpoint + wiremock::Mock::given(wiremock::matchers::method("GET")) + .and(wiremock::matchers::path("/api/tags")) + .respond_with(wiremock::ResponseTemplate::new(200)) + .mount(&server) + .await; + let native = OllamaClient::from_host_root(server.uri()); + assert!(native.probe_server().await.expect("probe native")); + + // OpenAI compatibility endpoint + wiremock::Mock::given(wiremock::matchers::method("GET")) + .and(wiremock::matchers::path("/v1/models")) + .respond_with(wiremock::ResponseTemplate::new(200)) + .mount(&server) + .await; + // Ensure the built-in OSS provider points at our mock server for this test + // to avoid depending on any globally configured environment from other tests. + let _guard = EnvVarGuard::set("CODEX_OSS_BASE_URL", format!("{}/v1", server.uri())); + let ollama_client = OllamaClient::from_oss_provider(); + assert!(ollama_client.probe_server().await.expect("probe compat")); + } + + #[tokio::test] + async fn test_try_from_oss_provider_ok_when_server_running() { + if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} set; skipping test_try_from_oss_provider_ok_when_server_running", + codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = wiremock::MockServer::start().await; + // Configure built‑in `oss` provider to point at this mock server. + // set_var is unsafe on Rust 2024 edition; use unsafe block in tests. + let _guard = EnvVarGuard::set("CODEX_OSS_BASE_URL", format!("{}/v1", server.uri())); + + // OpenAI‑compat models endpoint responds OK. 
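+        // `try_from_oss_provider` probes `/v1/models` here because the
+        // CODEX_OSS_BASE_URL set above ends in `/v1`.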
+ wiremock::Mock::given(wiremock::matchers::method("GET")) + .and(wiremock::matchers::path("/v1/models")) + .respond_with(wiremock::ResponseTemplate::new(200)) + .mount(&server) + .await; + + let _client = OllamaClient::try_from_oss_provider() + .await + .expect("client should be created when probe succeeds"); + } + + #[tokio::test] + async fn test_try_from_oss_provider_err_when_server_missing() { + if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} set; skipping test_try_from_oss_provider_err_when_server_missing", + codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = wiremock::MockServer::start().await; + // Point oss provider at our mock server but do NOT set up a handler + // for /v1/models so the request returns a non‑success status. + unsafe { std::env::set_var("CODEX_OSS_BASE_URL", format!("{}/v1", server.uri())) }; + + let err = OllamaClient::try_from_oss_provider() + .await + .err() + .expect("expected error"); + let msg = err.to_string(); + assert!( + msg.contains("No running Ollama server detected."), + "msg = {msg}" + ); + } +} diff --git a/codex-rs/ollama/src/lib.rs b/codex-rs/ollama/src/lib.rs new file mode 100644 index 00000000..d6f1e04d --- /dev/null +++ b/codex-rs/ollama/src/lib.rs @@ -0,0 +1,52 @@ +mod client; +mod parser; +mod pull; +mod url; + +pub use client::OllamaClient; +pub use pull::CliProgressReporter; +pub use pull::PullEvent; +pub use pull::PullProgressReporter; +pub use pull::TuiProgressReporter; + +/// Default OSS model to use when `--oss` is passed without an explicit `-m`. +pub const DEFAULT_OSS_MODEL: &str = "gpt-oss:20b"; + +/// Prepare the local OSS environment when `--oss` is selected. +/// +/// - Ensures a local Ollama server is reachable. +/// - Selects the final model name (CLI override or default). +/// - Checks if the model exists locally and pulls it if missing. +/// +/// Returns the final model name that should be used by the caller. +pub async fn ensure_oss_ready(cli_model: Option) -> std::io::Result { + // Only download when the requested model is the default OSS model (or when -m is not provided). + let should_download = cli_model + .as_deref() + .map(|name| name == DEFAULT_OSS_MODEL) + .unwrap_or(true); + let model = cli_model.unwrap_or_else(|| DEFAULT_OSS_MODEL.to_string()); + + // Verify local Ollama is reachable. + let ollama_client = crate::OllamaClient::try_from_oss_provider().await?; + + if should_download { + // If the model is not present locally, pull it. + match ollama_client.fetch_models().await { + Ok(models) => { + if !models.iter().any(|m| m == &model) { + let mut reporter = crate::CliProgressReporter::new(); + ollama_client + .pull_with_reporter(&model, &mut reporter) + .await?; + } + } + Err(err) => { + // Not fatal; higher layers may still proceed and surface errors later. + tracing::warn!("Failed to query local models from Ollama: {}.", err); + } + } + } + + Ok(model) +} diff --git a/codex-rs/ollama/src/parser.rs b/codex-rs/ollama/src/parser.rs new file mode 100644 index 00000000..b3ed2ca8 --- /dev/null +++ b/codex-rs/ollama/src/parser.rs @@ -0,0 +1,82 @@ +use serde_json::Value as JsonValue; + +use crate::pull::PullEvent; + +// Convert a single JSON object representing a pull update into one or more events. 
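+// A `status` field maps to `PullEvent::Status` (plus `PullEvent::Success` when it
+// equals "success"); `digest`/`total`/`completed` map to `PullEvent::ChunkProgress`.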
+pub(crate) fn pull_events_from_value(value: &JsonValue) -> Vec { + let mut events = Vec::new(); + if let Some(status) = value.get("status").and_then(|s| s.as_str()) { + events.push(PullEvent::Status(status.to_string())); + if status == "success" { + events.push(PullEvent::Success); + } + } + let digest = value + .get("digest") + .and_then(|d| d.as_str()) + .unwrap_or("") + .to_string(); + let total = value.get("total").and_then(|t| t.as_u64()); + let completed = value.get("completed").and_then(|t| t.as_u64()); + if total.is_some() || completed.is_some() { + events.push(PullEvent::ChunkProgress { + digest, + total, + completed, + }); + } + events +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_pull_events_decoder_status_and_success() { + let v: JsonValue = serde_json::json!({"status":"verifying"}); + let events = pull_events_from_value(&v); + assert!(matches!(events.as_slice(), [PullEvent::Status(s)] if s == "verifying")); + + let v2: JsonValue = serde_json::json!({"status":"success"}); + let events2 = pull_events_from_value(&v2); + assert_eq!(events2.len(), 2); + assert!(matches!(events2[0], PullEvent::Status(ref s) if s == "success")); + assert!(matches!(events2[1], PullEvent::Success)); + } + + #[test] + fn test_pull_events_decoder_progress() { + let v: JsonValue = serde_json::json!({"digest":"sha256:abc","total":100}); + let events = pull_events_from_value(&v); + assert_eq!(events.len(), 1); + match &events[0] { + PullEvent::ChunkProgress { + digest, + total, + completed, + } => { + assert_eq!(digest, "sha256:abc"); + assert_eq!(*total, Some(100)); + assert_eq!(*completed, None); + } + _ => panic!("expected ChunkProgress"), + } + + let v2: JsonValue = serde_json::json!({"digest":"sha256:def","completed":42}); + let events2 = pull_events_from_value(&v2); + assert_eq!(events2.len(), 1); + match &events2[0] { + PullEvent::ChunkProgress { + digest, + total, + completed, + } => { + assert_eq!(digest, "sha256:def"); + assert_eq!(*total, None); + assert_eq!(*completed, Some(42)); + } + _ => panic!("expected ChunkProgress"), + } + } +} diff --git a/codex-rs/ollama/src/pull.rs b/codex-rs/ollama/src/pull.rs new file mode 100644 index 00000000..0dd35cd7 --- /dev/null +++ b/codex-rs/ollama/src/pull.rs @@ -0,0 +1,147 @@ +use std::collections::HashMap; +use std::io; +use std::io::Write; + +/// Events emitted while pulling a model from Ollama. +#[derive(Debug, Clone)] +pub enum PullEvent { + /// A human-readable status message (e.g., "verifying", "writing"). + Status(String), + /// Byte-level progress update for a specific layer digest. + ChunkProgress { + digest: String, + total: Option, + completed: Option, + }, + /// The pull finished successfully. + Success, + + /// Error event with a message. + Error(String), +} + +/// A simple observer for pull progress events. Implementations decide how to +/// render progress (CLI, TUI, logs, ...). +pub trait PullProgressReporter { + fn on_event(&mut self, event: &PullEvent) -> io::Result<()>; +} + +/// A minimal CLI reporter that writes inline progress to stderr. 
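+///
+/// Progress is aggregated across layer digests so a single line can show the
+/// combined downloaded/total bytes and an approximate transfer speed.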
+pub struct CliProgressReporter { + printed_header: bool, + last_line_len: usize, + last_completed_sum: u64, + last_instant: std::time::Instant, + totals_by_digest: HashMap, +} + +impl Default for CliProgressReporter { + fn default() -> Self { + Self::new() + } +} + +impl CliProgressReporter { + pub fn new() -> Self { + Self { + printed_header: false, + last_line_len: 0, + last_completed_sum: 0, + last_instant: std::time::Instant::now(), + totals_by_digest: HashMap::new(), + } + } +} + +impl PullProgressReporter for CliProgressReporter { + fn on_event(&mut self, event: &PullEvent) -> io::Result<()> { + let mut out = std::io::stderr(); + match event { + PullEvent::Status(status) => { + // Avoid noisy manifest messages; otherwise show status inline. + if status.eq_ignore_ascii_case("pulling manifest") { + return Ok(()); + } + let pad = self.last_line_len.saturating_sub(status.len()); + let line = format!("\r{status}{}", " ".repeat(pad)); + self.last_line_len = status.len(); + out.write_all(line.as_bytes())?; + out.flush() + } + PullEvent::ChunkProgress { + digest, + total, + completed, + } => { + if let Some(t) = *total { + self.totals_by_digest + .entry(digest.clone()) + .or_insert((0, 0)) + .0 = t; + } + if let Some(c) = *completed { + self.totals_by_digest + .entry(digest.clone()) + .or_insert((0, 0)) + .1 = c; + } + + let (sum_total, sum_completed) = self + .totals_by_digest + .values() + .fold((0u64, 0u64), |acc, (t, c)| (acc.0 + *t, acc.1 + *c)); + if sum_total > 0 { + if !self.printed_header { + let gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0); + let header = format!("Downloading model: total {gb:.2} GB\n"); + out.write_all(b"\r\x1b[2K")?; + out.write_all(header.as_bytes())?; + self.printed_header = true; + } + let now = std::time::Instant::now(); + let dt = now + .duration_since(self.last_instant) + .as_secs_f64() + .max(0.001); + let dbytes = sum_completed.saturating_sub(self.last_completed_sum) as f64; + let speed_mb_s = dbytes / (1024.0 * 1024.0) / dt; + self.last_completed_sum = sum_completed; + self.last_instant = now; + + let done_gb = (sum_completed as f64) / (1024.0 * 1024.0 * 1024.0); + let total_gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0); + let pct = (sum_completed as f64) * 100.0 / (sum_total as f64); + let text = + format!("{done_gb:.2}/{total_gb:.2} GB ({pct:.1}%) {speed_mb_s:.1} MB/s"); + let pad = self.last_line_len.saturating_sub(text.len()); + let line = format!("\r{text}{}", " ".repeat(pad)); + self.last_line_len = text.len(); + out.write_all(line.as_bytes())?; + out.flush() + } else { + Ok(()) + } + } + PullEvent::Error(_) => { + // This will be handled by the caller, so we don't do anything + // here or the error will be printed twice. + Ok(()) + } + PullEvent::Success => { + out.write_all(b"\n")?; + out.flush() + } + } + } +} + +/// For now the TUI reporter delegates to the CLI reporter. This keeps UI and +/// CLI behavior aligned until a dedicated TUI integration is implemented. +#[derive(Default)] +pub struct TuiProgressReporter(CliProgressReporter); + +impl PullProgressReporter for TuiProgressReporter { + fn on_event(&mut self, event: &PullEvent) -> io::Result<()> { + self.0.on_event(event) + } +} diff --git a/codex-rs/ollama/src/url.rs b/codex-rs/ollama/src/url.rs new file mode 100644 index 00000000..7c143ce4 --- /dev/null +++ b/codex-rs/ollama/src/url.rs @@ -0,0 +1,39 @@ +/// Identify whether a base_url points at an OpenAI-compatible root (".../v1"). 
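+/// For example, "http://localhost:11434/v1" and "http://localhost:11434/v1/"
+/// both match, while "http://localhost:11434" does not.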
+pub(crate) fn is_openai_compatible_base_url(base_url: &str) -> bool { + base_url.trim_end_matches('/').ends_with("/v1") +} + +/// Convert a provider base_url into the native Ollama host root. +/// For example, "http://localhost:11434/v1" -> "http://localhost:11434". +pub fn base_url_to_host_root(base_url: &str) -> String { + let trimmed = base_url.trim_end_matches('/'); + if trimmed.ends_with("/v1") { + trimmed + .trim_end_matches("/v1") + .trim_end_matches('/') + .to_string() + } else { + trimmed.to_string() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_base_url_to_host_root() { + assert_eq!( + base_url_to_host_root("http://localhost:11434/v1"), + "http://localhost:11434" + ); + assert_eq!( + base_url_to_host_root("http://localhost:11434"), + "http://localhost:11434" + ); + assert_eq!( + base_url_to_host_root("http://localhost:11434/"), + "http://localhost:11434" + ); + } +} diff --git a/codex-rs/tui/Cargo.toml b/codex-rs/tui/Cargo.toml index 60af056a..49d843f0 100644 --- a/codex-rs/tui/Cargo.toml +++ b/codex-rs/tui/Cargo.toml @@ -33,6 +33,7 @@ codex-common = { path = "../common", features = [ codex-core = { path = "../core" } codex-file-search = { path = "../file-search" } codex-login = { path = "../login" } +codex-ollama = { path = "../ollama" } color-eyre = "0.6.3" crossterm = { version = "0.28.1", features = ["bracketed-paste"] } image = { version = "^0.25.6", default-features = false, features = ["jpeg"] } diff --git a/codex-rs/tui/src/cli.rs b/codex-rs/tui/src/cli.rs index cb1b725a..85dffbeb 100644 --- a/codex-rs/tui/src/cli.rs +++ b/codex-rs/tui/src/cli.rs @@ -17,6 +17,12 @@ pub struct Cli { #[arg(long, short = 'm')] pub model: Option, + /// Convenience flag to select the local open source model provider. + /// Equivalent to -c model_provider=oss; verifies a local Ollama server is + /// running. + #[arg(long = "oss", default_value_t = false)] + pub oss: bool, + /// Configuration profile from config.toml to specify default options. #[arg(long = "profile", short = 'p')] pub config_profile: Option, diff --git a/codex-rs/tui/src/lib.rs b/codex-rs/tui/src/lib.rs index c619ce8f..0b833b13 100644 --- a/codex-rs/tui/src/lib.rs +++ b/codex-rs/tui/src/lib.rs @@ -3,6 +3,7 @@ // alternate‑screen mode starts; that file opts‑out locally via `allow`. #![deny(clippy::print_stdout, clippy::print_stderr)] use app::App; +use codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID; use codex_core::config::Config; use codex_core::config::ConfigOverrides; use codex_core::config_types::SandboxMode; @@ -70,18 +71,35 @@ pub async fn run_main( ) }; + let model_provider_override = if cli.oss { + Some(BUILT_IN_OSS_MODEL_PROVIDER_ID.to_owned()) + } else { + None + }; let config = { // Load configuration and support CLI overrides. let overrides = ConfigOverrides { - model: cli.model.clone(), + // When using `--oss`, let the bootstrapper pick the model + // (defaulting to gpt-oss:20b) and ensure it is present locally. 
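+            // If the local server is unreachable, `ensure_oss_ready` returns an
+            // error and startup aborts rather than silently falling back.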
+            model: if cli.oss {
+                Some(
+                    codex_ollama::ensure_oss_ready(cli.model.clone())
+                        .await
+                        .map_err(|e| std::io::Error::other(format!("OSS setup failed: {e}")))?,
+                )
+            } else {
+                cli.model.clone()
+            },
             approval_policy,
             sandbox_mode,
             cwd: cli.cwd.clone().map(|p| p.canonicalize().unwrap_or(p)),
-            model_provider: None,
+            model_provider: model_provider_override,
             config_profile: cli.config_profile.clone(),
             codex_linux_sandbox_exe,
             base_instructions: None,
             include_plan_tool: Some(true),
+            default_disable_response_storage: cli.oss.then_some(true),
+            default_show_raw_agent_reasoning: cli.oss.then_some(true),
         };
         // Parse `-c` overrides from the CLI.
         let cli_kv_overrides = match cli.config_overrides.parse_overrides() {