diff --git a/README.md b/README.md
index c7f6a1d5..dd5e4662 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,7 @@ This is the home of the **Codex CLI**, which is a coding agent from OpenAI that
 - [Quickstart](#quickstart)
   - [OpenAI API Users](#openai-api-users)
   - [OpenAI Plus/Pro Users](#openai-pluspro-users)
+  - [Using Open Source Models](#using-open-source-models)
 - [Why Codex?](#why-codex)
 - [Security model & permissions](#security-model--permissions)
   - [Platform sandboxing details](#platform-sandboxing-details)
@@ -186,6 +187,41 @@ they'll be committed to your working directory.
 
 ---
 
+## Using Open Source Models
+
+Codex can run fully locally against an OpenAI‑compatible OSS host (such as Ollama) using the `--oss` flag:
+
+- Interactive UI:
+  - `codex --oss`
+- Non‑interactive (programmatic) mode:
+  - `echo "Refactor utils" | codex exec --oss`
+
+Model selection when using `--oss`:
+
+- If you omit `-m/--model`, Codex defaults to `gpt-oss:20b` and verifies that the model exists locally, downloading it if needed.
+- To pick a size explicitly, pass one of:
+  - `-m "gpt-oss:20b"`
+  - `-m "gpt-oss:120b"`
+
+Point Codex at your own OSS host:
+
+- By default, `--oss` talks to `http://localhost:11434/v1`.
+- To use a different host, set one of these environment variables before running Codex:
+  - `CODEX_OSS_BASE_URL`, for example:
+    - `CODEX_OSS_BASE_URL="http://my-ollama.example.com:11434/v1" codex --oss -m gpt-oss:20b`
+  - or `CODEX_OSS_PORT` (when the host is localhost):
+    - `CODEX_OSS_PORT=11434 codex --oss`
+
+Advanced: you can persist this in your config instead of environment variables by overriding the built‑in `oss` provider in `~/.codex/config.toml`:
+
+```toml
+[model_providers.oss]
+name = "Open Source"
+base_url = "http://my-ollama.example.com:11434/v1"
+```
+
+---
+
 ## Why Codex?
 
Codex CLI is built for developers who already **live in the terminal** and want diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 2e20a7d6..4e21baf7 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -729,6 +729,7 @@ dependencies = [ "codex-arg0", "codex-common", "codex-core", + "codex-ollama", "owo-colors", "predicates", "serde_json", @@ -838,6 +839,23 @@ dependencies = [ "wiremock", ] +[[package]] +name = "codex-ollama" +version = "0.0.0" +dependencies = [ + "async-stream", + "bytes", + "codex-core", + "futures", + "reqwest", + "serde_json", + "tempfile", + "tokio", + "toml 0.9.4", + "tracing", + "wiremock", +] + [[package]] name = "codex-tui" version = "0.0.0" @@ -852,6 +870,7 @@ dependencies = [ "codex-core", "codex-file-search", "codex-login", + "codex-ollama", "color-eyre", "crossterm", "image", diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index 0f8085c7..0ed88522 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -14,6 +14,7 @@ members = [ "mcp-client", "mcp-server", "mcp-types", + "ollama", "tui", ] resolver = "2" diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs index d97d5ec1..e62fcc39 100644 --- a/codex-rs/core/src/config.rs +++ b/codex-rs/core/src/config.rs @@ -385,6 +385,8 @@ pub struct ConfigOverrides { pub codex_linux_sandbox_exe: Option, pub base_instructions: Option, pub include_plan_tool: Option, + pub default_disable_response_storage: Option, + pub default_show_raw_agent_reasoning: Option, } impl Config { @@ -408,6 +410,8 @@ impl Config { codex_linux_sandbox_exe, base_instructions, include_plan_tool, + default_disable_response_storage, + default_show_raw_agent_reasoning, } = overrides; let config_profile = match config_profile_key.as_ref().or(cfg.profile.as_ref()) { @@ -525,6 +529,7 @@ impl Config { disable_response_storage: config_profile .disable_response_storage .or(cfg.disable_response_storage) + .or(default_disable_response_storage) .unwrap_or(false), notify: cfg.notify, user_instructions, @@ -539,7 +544,10 @@ impl Config { codex_linux_sandbox_exe, hide_agent_reasoning: cfg.hide_agent_reasoning.unwrap_or(false), - show_raw_agent_reasoning: cfg.show_raw_agent_reasoning.unwrap_or(false), + show_raw_agent_reasoning: cfg + .show_raw_agent_reasoning + .or(default_show_raw_agent_reasoning) + .unwrap_or(false), model_reasoning_effort: config_profile .model_reasoning_effort .or(cfg.model_reasoning_effort) diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index f9c608b5..965cb77b 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -28,6 +28,7 @@ mod mcp_connection_manager; mod mcp_tool_call; mod message_history; mod model_provider_info; +pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID; pub use model_provider_info::ModelProviderInfo; pub use model_provider_info::WireApi; pub use model_provider_info::built_in_model_providers; diff --git a/codex-rs/core/src/model_family.rs b/codex-rs/core/src/model_family.rs index 9bc61270..7c4a9de6 100644 --- a/codex-rs/core/src/model_family.rs +++ b/codex-rs/core/src/model_family.rs @@ -85,6 +85,8 @@ pub fn find_family_for_model(slug: &str) -> Option { ) } else if slug.starts_with("gpt-4o") { simple_model_family!(slug, "gpt-4o") + } else if slug.starts_with("gpt-oss") { + simple_model_family!(slug, "gpt-oss") } else if slug.starts_with("gpt-3.5") { simple_model_family!(slug, "gpt-3.5") } else { diff --git a/codex-rs/core/src/model_provider_info.rs b/codex-rs/core/src/model_provider_info.rs index 49478660..595f05ef 100644 --- 
a/codex-rs/core/src/model_provider_info.rs +++ b/codex-rs/core/src/model_provider_info.rs @@ -226,53 +226,93 @@ impl ModelProviderInfo { } } +const DEFAULT_OLLAMA_PORT: u32 = 11434; + +pub const BUILT_IN_OSS_MODEL_PROVIDER_ID: &str = "oss"; + /// Built-in default provider list. pub fn built_in_model_providers() -> HashMap { use ModelProviderInfo as P; - // We do not want to be in the business of adjucating which third-party - // providers are bundled with Codex CLI, so we only include the OpenAI - // provider by default. Users are encouraged to add to `model_providers` - // in config.toml to add their own providers. - [( - "openai", - P { - name: "OpenAI".into(), - // Allow users to override the default OpenAI endpoint by - // exporting `OPENAI_BASE_URL`. This is useful when pointing - // Codex at a proxy, mock server, or Azure-style deployment - // without requiring a full TOML override for the built-in - // OpenAI provider. - base_url: std::env::var("OPENAI_BASE_URL") + // These CODEX_OSS_ environment variables are experimental: we may + // switch to reading values from config.toml instead. + let codex_oss_base_url = match std::env::var("CODEX_OSS_BASE_URL") + .ok() + .filter(|v| !v.trim().is_empty()) + { + Some(url) => url, + None => format!( + "http://localhost:{port}/v1", + port = std::env::var("CODEX_OSS_PORT") .ok() - .filter(|v| !v.trim().is_empty()), - env_key: None, - env_key_instructions: None, - wire_api: WireApi::Responses, - query_params: None, - http_headers: Some( - [("version".to_string(), env!("CARGO_PKG_VERSION").to_string())] + .filter(|v| !v.trim().is_empty()) + .and_then(|v| v.parse::().ok()) + .unwrap_or(DEFAULT_OLLAMA_PORT) + ), + }; + + // We do not want to be in the business of adjucating which third-party + // providers are bundled with Codex CLI, so we only include the OpenAI and + // open source ("oss") providers by default. Users are encouraged to add to + // `model_providers` in config.toml to add their own providers. + [ + ( + "openai", + P { + name: "OpenAI".into(), + // Allow users to override the default OpenAI endpoint by + // exporting `OPENAI_BASE_URL`. This is useful when pointing + // Codex at a proxy, mock server, or Azure-style deployment + // without requiring a full TOML override for the built-in + // OpenAI provider. + base_url: std::env::var("OPENAI_BASE_URL") + .ok() + .filter(|v| !v.trim().is_empty()), + env_key: None, + env_key_instructions: None, + wire_api: WireApi::Responses, + query_params: None, + http_headers: Some( + [("version".to_string(), env!("CARGO_PKG_VERSION").to_string())] + .into_iter() + .collect(), + ), + env_http_headers: Some( + [ + ( + "OpenAI-Organization".to_string(), + "OPENAI_ORGANIZATION".to_string(), + ), + ("OpenAI-Project".to_string(), "OPENAI_PROJECT".to_string()), + ] .into_iter() .collect(), - ), - env_http_headers: Some( - [ - ( - "OpenAI-Organization".to_string(), - "OPENAI_ORGANIZATION".to_string(), - ), - ("OpenAI-Project".to_string(), "OPENAI_PROJECT".to_string()), - ] - .into_iter() - .collect(), - ), - // Use global defaults for retry/timeout unless overridden in config.toml. - request_max_retries: None, - stream_max_retries: None, - stream_idle_timeout_ms: None, - requires_auth: true, - }, - )] + ), + // Use global defaults for retry/timeout unless overridden in config.toml. 
+ request_max_retries: None, + stream_max_retries: None, + stream_idle_timeout_ms: None, + requires_auth: true, + }, + ), + ( + BUILT_IN_OSS_MODEL_PROVIDER_ID, + P { + name: "Open Source".into(), + base_url: Some(codex_oss_base_url), + env_key: None, + env_key_instructions: None, + wire_api: WireApi::Chat, + query_params: None, + http_headers: None, + env_http_headers: None, + request_max_retries: None, + stream_max_retries: None, + stream_idle_timeout_ms: None, + requires_auth: false, + }, + ), + ] .into_iter() .map(|(k, v)| (k.to_string(), v)) .collect() diff --git a/codex-rs/exec/Cargo.toml b/codex-rs/exec/Cargo.toml index cd521410..aee480d7 100644 --- a/codex-rs/exec/Cargo.toml +++ b/codex-rs/exec/Cargo.toml @@ -25,6 +25,7 @@ codex-common = { path = "../common", features = [ "sandbox_summary", ] } codex-core = { path = "../core" } +codex-ollama = { path = "../ollama" } owo-colors = "4.2.0" serde_json = "1" shlex = "1.3.0" diff --git a/codex-rs/exec/src/cli.rs b/codex-rs/exec/src/cli.rs index 53af25c7..ea659e32 100644 --- a/codex-rs/exec/src/cli.rs +++ b/codex-rs/exec/src/cli.rs @@ -14,6 +14,9 @@ pub struct Cli { #[arg(long, short = 'm')] pub model: Option, + #[arg(long = "oss", default_value_t = false)] + pub oss: bool, + /// Select the sandbox policy to use when executing model-generated shell /// commands. #[arg(long = "sandbox", short = 's')] diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs index ce4d7f65..c1af4f5b 100644 --- a/codex-rs/exec/src/lib.rs +++ b/codex-rs/exec/src/lib.rs @@ -9,6 +9,7 @@ use std::path::PathBuf; use std::sync::Arc; pub use cli::Cli; +use codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID; use codex_core::codex_wrapper::CodexConversation; use codex_core::codex_wrapper::{self}; use codex_core::config::Config; @@ -35,6 +36,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any let Cli { images, model, + oss, config_profile, full_auto, dangerously_bypass_approvals_and_sandbox, @@ -114,6 +116,24 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any sandbox_mode_cli_arg.map(Into::::into) }; + // When using `--oss`, let the bootstrapper pick the model (defaulting to + // gpt-oss:20b) and ensure it is present locally. Also, force the built‑in + // `oss` model provider. + let model_provider_override = if oss { + Some(BUILT_IN_OSS_MODEL_PROVIDER_ID.to_owned()) + } else { + None + }; + let model = if oss { + Some( + codex_ollama::ensure_oss_ready(model.clone()) + .await + .map_err(|e| anyhow::anyhow!("OSS setup failed: {e}"))?, + ) + } else { + model + }; + // Load configuration and determine approval policy let overrides = ConfigOverrides { model, @@ -123,10 +143,12 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any approval_policy: Some(AskForApproval::Never), sandbox_mode, cwd: cwd.map(|p| p.canonicalize().unwrap_or(p)), - model_provider: None, + model_provider: model_provider_override, codex_linux_sandbox_exe, base_instructions: None, include_plan_tool: None, + default_disable_response_storage: oss.then_some(true), + default_show_raw_agent_reasoning: oss.then_some(true), }; // Parse `-c` overrides. 
let cli_kv_overrides = match config_overrides.parse_overrides() { diff --git a/codex-rs/mcp-server/src/codex_tool_config.rs b/codex-rs/mcp-server/src/codex_tool_config.rs index 877d0e05..f1a502bb 100644 --- a/codex-rs/mcp-server/src/codex_tool_config.rs +++ b/codex-rs/mcp-server/src/codex_tool_config.rs @@ -158,6 +158,8 @@ impl CodexToolCallParam { codex_linux_sandbox_exe, base_instructions, include_plan_tool, + default_disable_response_storage: None, + default_show_raw_agent_reasoning: None, }; let cli_overrides = cli_overrides diff --git a/codex-rs/mcp-server/src/tool_handlers/create_conversation.rs b/codex-rs/mcp-server/src/tool_handlers/create_conversation.rs index 28a89651..c1f40356 100644 --- a/codex-rs/mcp-server/src/tool_handlers/create_conversation.rs +++ b/codex-rs/mcp-server/src/tool_handlers/create_conversation.rs @@ -59,6 +59,8 @@ pub(crate) async fn handle_create_conversation( codex_linux_sandbox_exe: None, base_instructions, include_plan_tool: None, + default_disable_response_storage: None, + default_show_raw_agent_reasoning: None, }; let cfg: CodexConfig = match CodexConfig::load_with_cli_overrides(cli_overrides, overrides) { diff --git a/codex-rs/ollama/Cargo.toml b/codex-rs/ollama/Cargo.toml new file mode 100644 index 00000000..ead9a064 --- /dev/null +++ b/codex-rs/ollama/Cargo.toml @@ -0,0 +1,32 @@ +[package] +edition = "2024" +name = "codex-ollama" +version = { workspace = true } + +[lib] +name = "codex_ollama" +path = "src/lib.rs" + +[lints] +workspace = true + +[dependencies] +async-stream = "0.3" +bytes = "1.10.1" +codex-core = { path = "../core" } +futures = "0.3" +reqwest = { version = "0.12", features = ["json", "stream"] } +serde_json = "1" +tokio = { version = "1", features = [ + "io-std", + "macros", + "process", + "rt-multi-thread", + "signal", +] } +toml = "0.9.2" +tracing = { version = "0.1.41", features = ["log"] } +wiremock = "0.6" + +[dev-dependencies] +tempfile = "3" diff --git a/codex-rs/ollama/src/client.rs b/codex-rs/ollama/src/client.rs new file mode 100644 index 00000000..8a15039f --- /dev/null +++ b/codex-rs/ollama/src/client.rs @@ -0,0 +1,366 @@ +use bytes::BytesMut; +use futures::StreamExt; +use futures::stream::BoxStream; +use serde_json::Value as JsonValue; +use std::collections::VecDeque; +use std::io; + +use codex_core::WireApi; + +use crate::parser::pull_events_from_value; +use crate::pull::PullEvent; +use crate::pull::PullProgressReporter; +use crate::url::base_url_to_host_root; +use crate::url::is_openai_compatible_base_url; + +/// Client for interacting with a local Ollama instance. +pub struct OllamaClient { + client: reqwest::Client, + host_root: String, + uses_openai_compat: bool, +} + +impl OllamaClient { + pub fn from_oss_provider() -> Self { + #![allow(clippy::expect_used)] + // Use the built-in OSS provider's base URL. + let built_in_model_providers = codex_core::built_in_model_providers(); + let provider = built_in_model_providers + .get(codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID) + .expect("oss provider must exist"); + let base_url = provider + .base_url + .as_ref() + .expect("oss provider must have a base_url"); + Self::from_provider(base_url, provider.wire_api) + } + + /// Construct a client for the built‑in open‑source ("oss") model provider + /// and verify that a local Ollama server is reachable. If no server is + /// detected, returns an error with helpful installation/run instructions. + pub async fn try_from_oss_provider() -> io::Result { + let client = Self::from_oss_provider(); + if client.probe_server().await? 
{ + Ok(client) + } else { + Err(io::Error::other( + "No running Ollama server detected. Start it with: `ollama serve` (after installing). Install instructions: https://github.com/ollama/ollama?tab=readme-ov-file#ollama", + )) + } + } + + /// Build a client from a provider definition. Falls back to the default + /// local URL if no base_url is configured. + fn from_provider(base_url: &str, wire_api: WireApi) -> Self { + let uses_openai_compat = is_openai_compatible_base_url(base_url) + || matches!(wire_api, WireApi::Chat) && is_openai_compatible_base_url(base_url); + let host_root = base_url_to_host_root(base_url); + let client = reqwest::Client::builder() + .connect_timeout(std::time::Duration::from_secs(5)) + .build() + .unwrap_or_else(|_| reqwest::Client::new()); + Self { + client, + host_root, + uses_openai_compat, + } + } + + /// Low-level constructor given a raw host root, e.g. "http://localhost:11434". + #[cfg(test)] + fn from_host_root(host_root: impl Into) -> Self { + let client = reqwest::Client::builder() + .connect_timeout(std::time::Duration::from_secs(5)) + .build() + .unwrap_or_else(|_| reqwest::Client::new()); + Self { + client, + host_root: host_root.into(), + uses_openai_compat: false, + } + } + + /// Probe whether the server is reachable by hitting the appropriate health endpoint. + pub async fn probe_server(&self) -> io::Result { + let url = if self.uses_openai_compat { + format!("{}/v1/models", self.host_root.trim_end_matches('/')) + } else { + format!("{}/api/tags", self.host_root.trim_end_matches('/')) + }; + let resp = self.client.get(url).send().await; + Ok(matches!(resp, Ok(r) if r.status().is_success())) + } + + /// Return the list of model names known to the local Ollama instance. + pub async fn fetch_models(&self) -> io::Result> { + let tags_url = format!("{}/api/tags", self.host_root.trim_end_matches('/')); + let resp = self + .client + .get(tags_url) + .send() + .await + .map_err(io::Error::other)?; + if !resp.status().is_success() { + return Ok(Vec::new()); + } + let val = resp.json::().await.map_err(io::Error::other)?; + let names = val + .get("models") + .and_then(|m| m.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.get("name").and_then(|n| n.as_str())) + .map(|s| s.to_string()) + .collect::>() + }) + .unwrap_or_default(); + Ok(names) + } + + /// Start a model pull and emit streaming events. The returned stream ends when + /// a Success event is observed or the server closes the connection. + pub async fn pull_model_stream( + &self, + model: &str, + ) -> io::Result> { + let url = format!("{}/api/pull", self.host_root.trim_end_matches('/')); + let resp = self + .client + .post(url) + .json(&serde_json::json!({"model": model, "stream": true})) + .send() + .await + .map_err(io::Error::other)?; + if !resp.status().is_success() { + return Err(io::Error::other(format!( + "failed to start pull: HTTP {}", + resp.status() + ))); + } + + let mut stream = resp.bytes_stream(); + let mut buf = BytesMut::new(); + let _pending: VecDeque = VecDeque::new(); + + // Using an async stream adaptor backed by unfold-like manual loop. + let s = async_stream::stream! 
{ + while let Some(chunk) = stream.next().await { + match chunk { + Ok(bytes) => { + buf.extend_from_slice(&bytes); + while let Some(pos) = buf.iter().position(|b| *b == b'\n') { + let line = buf.split_to(pos + 1); + if let Ok(text) = std::str::from_utf8(&line) { + let text = text.trim(); + if text.is_empty() { continue; } + if let Ok(value) = serde_json::from_str::(text) { + for ev in pull_events_from_value(&value) { yield ev; } + if let Some(err_msg) = value.get("error").and_then(|e| e.as_str()) { + yield PullEvent::Error(err_msg.to_string()); + return; + } + if let Some(status) = value.get("status").and_then(|s| s.as_str()) { + if status == "success" { yield PullEvent::Success; return; } + } + } + } + } + } + Err(_) => { + // Connection error: end the stream. + return; + } + } + } + }; + + Ok(Box::pin(s)) + } + + /// High-level helper to pull a model and drive a progress reporter. + pub async fn pull_with_reporter( + &self, + model: &str, + reporter: &mut dyn PullProgressReporter, + ) -> io::Result<()> { + reporter.on_event(&PullEvent::Status(format!("Pulling model {model}...")))?; + let mut stream = self.pull_model_stream(model).await?; + while let Some(event) = stream.next().await { + reporter.on_event(&event)?; + match event { + PullEvent::Success => { + return Ok(()); + } + PullEvent::Error(err) => { + // Emperically, ollama returns a 200 OK response even when + // the output stream includes an error message. Verify with: + // + // `curl -i http://localhost:11434/api/pull -d '{ "model": "foobarbaz" }'` + // + // As such, we have to check the event stream, not the + // HTTP response status, to determine whether to return Err. + return Err(io::Error::other(format!("Pull failed: {err}"))); + } + PullEvent::ChunkProgress { .. } | PullEvent::Status(_) => { + continue; + } + } + } + Err(io::Error::other( + "Pull stream ended unexpectedly without success.", + )) + } +} + +#[cfg(test)] +mod tests { + #![allow(clippy::expect_used, clippy::unwrap_used)] + use super::*; + + /// Simple RAII guard to set an environment variable for the duration of a test + /// and restore the previous value (or remove it) on drop to avoid cross-test + /// interference. + struct EnvVarGuard { + key: String, + prev: Option, + } + impl EnvVarGuard { + fn set(key: &str, value: String) -> Self { + let prev = std::env::var(key).ok(); + // set_var is safe but we mirror existing tests that use an unsafe block + // to silence edition lints around global mutation during tests. + unsafe { std::env::set_var(key, value) }; + Self { + key: key.to_string(), + prev, + } + } + } + impl Drop for EnvVarGuard { + fn drop(&mut self) { + match &self.prev { + Some(v) => unsafe { std::env::set_var(&self.key, v) }, + None => unsafe { std::env::remove_var(&self.key) }, + } + } + } + + // Happy-path tests using a mock HTTP server; skip if sandbox network is disabled. 
+ #[tokio::test] + async fn test_fetch_models_happy_path() { + if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_fetch_models_happy_path", + codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("GET")) + .and(wiremock::matchers::path("/api/tags")) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ + "models": [ {"name": "llama3.2:3b"}, {"name":"mistral"} ] + }) + .to_string(), + "application/json", + ), + ) + .mount(&server) + .await; + + let client = OllamaClient::from_host_root(server.uri()); + let models = client.fetch_models().await.expect("fetch models"); + assert!(models.contains(&"llama3.2:3b".to_string())); + assert!(models.contains(&"mistral".to_string())); + } + + #[tokio::test] + async fn test_probe_server_happy_path_openai_compat_and_native() { + if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} set; skipping test_probe_server_happy_path_openai_compat_and_native", + codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = wiremock::MockServer::start().await; + + // Native endpoint + wiremock::Mock::given(wiremock::matchers::method("GET")) + .and(wiremock::matchers::path("/api/tags")) + .respond_with(wiremock::ResponseTemplate::new(200)) + .mount(&server) + .await; + let native = OllamaClient::from_host_root(server.uri()); + assert!(native.probe_server().await.expect("probe native")); + + // OpenAI compatibility endpoint + wiremock::Mock::given(wiremock::matchers::method("GET")) + .and(wiremock::matchers::path("/v1/models")) + .respond_with(wiremock::ResponseTemplate::new(200)) + .mount(&server) + .await; + // Ensure the built-in OSS provider points at our mock server for this test + // to avoid depending on any globally configured environment from other tests. + let _guard = EnvVarGuard::set("CODEX_OSS_BASE_URL", format!("{}/v1", server.uri())); + let ollama_client = OllamaClient::from_oss_provider(); + assert!(ollama_client.probe_server().await.expect("probe compat")); + } + + #[tokio::test] + async fn test_try_from_oss_provider_ok_when_server_running() { + if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} set; skipping test_try_from_oss_provider_ok_when_server_running", + codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = wiremock::MockServer::start().await; + // Configure built‑in `oss` provider to point at this mock server. + // set_var is unsafe on Rust 2024 edition; use unsafe block in tests. + let _guard = EnvVarGuard::set("CODEX_OSS_BASE_URL", format!("{}/v1", server.uri())); + + // OpenAI‑compat models endpoint responds OK. 
+ wiremock::Mock::given(wiremock::matchers::method("GET")) + .and(wiremock::matchers::path("/v1/models")) + .respond_with(wiremock::ResponseTemplate::new(200)) + .mount(&server) + .await; + + let _client = OllamaClient::try_from_oss_provider() + .await + .expect("client should be created when probe succeeds"); + } + + #[tokio::test] + async fn test_try_from_oss_provider_err_when_server_missing() { + if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} set; skipping test_try_from_oss_provider_err_when_server_missing", + codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = wiremock::MockServer::start().await; + // Point oss provider at our mock server but do NOT set up a handler + // for /v1/models so the request returns a non‑success status. + unsafe { std::env::set_var("CODEX_OSS_BASE_URL", format!("{}/v1", server.uri())) }; + + let err = OllamaClient::try_from_oss_provider() + .await + .err() + .expect("expected error"); + let msg = err.to_string(); + assert!( + msg.contains("No running Ollama server detected."), + "msg = {msg}" + ); + } +} diff --git a/codex-rs/ollama/src/lib.rs b/codex-rs/ollama/src/lib.rs new file mode 100644 index 00000000..d6f1e04d --- /dev/null +++ b/codex-rs/ollama/src/lib.rs @@ -0,0 +1,52 @@ +mod client; +mod parser; +mod pull; +mod url; + +pub use client::OllamaClient; +pub use pull::CliProgressReporter; +pub use pull::PullEvent; +pub use pull::PullProgressReporter; +pub use pull::TuiProgressReporter; + +/// Default OSS model to use when `--oss` is passed without an explicit `-m`. +pub const DEFAULT_OSS_MODEL: &str = "gpt-oss:20b"; + +/// Prepare the local OSS environment when `--oss` is selected. +/// +/// - Ensures a local Ollama server is reachable. +/// - Selects the final model name (CLI override or default). +/// - Checks if the model exists locally and pulls it if missing. +/// +/// Returns the final model name that should be used by the caller. +pub async fn ensure_oss_ready(cli_model: Option) -> std::io::Result { + // Only download when the requested model is the default OSS model (or when -m is not provided). + let should_download = cli_model + .as_deref() + .map(|name| name == DEFAULT_OSS_MODEL) + .unwrap_or(true); + let model = cli_model.unwrap_or_else(|| DEFAULT_OSS_MODEL.to_string()); + + // Verify local Ollama is reachable. + let ollama_client = crate::OllamaClient::try_from_oss_provider().await?; + + if should_download { + // If the model is not present locally, pull it. + match ollama_client.fetch_models().await { + Ok(models) => { + if !models.iter().any(|m| m == &model) { + let mut reporter = crate::CliProgressReporter::new(); + ollama_client + .pull_with_reporter(&model, &mut reporter) + .await?; + } + } + Err(err) => { + // Not fatal; higher layers may still proceed and surface errors later. + tracing::warn!("Failed to query local models from Ollama: {}.", err); + } + } + } + + Ok(model) +} diff --git a/codex-rs/ollama/src/parser.rs b/codex-rs/ollama/src/parser.rs new file mode 100644 index 00000000..b3ed2ca8 --- /dev/null +++ b/codex-rs/ollama/src/parser.rs @@ -0,0 +1,82 @@ +use serde_json::Value as JsonValue; + +use crate::pull::PullEvent; + +// Convert a single JSON object representing a pull update into one or more events. 
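+//
+// Illustrative sketch of the assumed pull-stream shape (exercised by the unit
+// tests below): a line such as
+//   {"status":"pulling sha256:abc","digest":"sha256:abc","total":100,"completed":10}
+// produces a `Status` event followed by a `ChunkProgress` event, while
+//   {"status":"success"}
+// produces a `Status` event plus a terminal `Success` event.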
+pub(crate) fn pull_events_from_value(value: &JsonValue) -> Vec { + let mut events = Vec::new(); + if let Some(status) = value.get("status").and_then(|s| s.as_str()) { + events.push(PullEvent::Status(status.to_string())); + if status == "success" { + events.push(PullEvent::Success); + } + } + let digest = value + .get("digest") + .and_then(|d| d.as_str()) + .unwrap_or("") + .to_string(); + let total = value.get("total").and_then(|t| t.as_u64()); + let completed = value.get("completed").and_then(|t| t.as_u64()); + if total.is_some() || completed.is_some() { + events.push(PullEvent::ChunkProgress { + digest, + total, + completed, + }); + } + events +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_pull_events_decoder_status_and_success() { + let v: JsonValue = serde_json::json!({"status":"verifying"}); + let events = pull_events_from_value(&v); + assert!(matches!(events.as_slice(), [PullEvent::Status(s)] if s == "verifying")); + + let v2: JsonValue = serde_json::json!({"status":"success"}); + let events2 = pull_events_from_value(&v2); + assert_eq!(events2.len(), 2); + assert!(matches!(events2[0], PullEvent::Status(ref s) if s == "success")); + assert!(matches!(events2[1], PullEvent::Success)); + } + + #[test] + fn test_pull_events_decoder_progress() { + let v: JsonValue = serde_json::json!({"digest":"sha256:abc","total":100}); + let events = pull_events_from_value(&v); + assert_eq!(events.len(), 1); + match &events[0] { + PullEvent::ChunkProgress { + digest, + total, + completed, + } => { + assert_eq!(digest, "sha256:abc"); + assert_eq!(*total, Some(100)); + assert_eq!(*completed, None); + } + _ => panic!("expected ChunkProgress"), + } + + let v2: JsonValue = serde_json::json!({"digest":"sha256:def","completed":42}); + let events2 = pull_events_from_value(&v2); + assert_eq!(events2.len(), 1); + match &events2[0] { + PullEvent::ChunkProgress { + digest, + total, + completed, + } => { + assert_eq!(digest, "sha256:def"); + assert_eq!(*total, None); + assert_eq!(*completed, Some(42)); + } + _ => panic!("expected ChunkProgress"), + } + } +} diff --git a/codex-rs/ollama/src/pull.rs b/codex-rs/ollama/src/pull.rs new file mode 100644 index 00000000..0dd35cd7 --- /dev/null +++ b/codex-rs/ollama/src/pull.rs @@ -0,0 +1,147 @@ +use std::collections::HashMap; +use std::io; +use std::io::Write; + +/// Events emitted while pulling a model from Ollama. +#[derive(Debug, Clone)] +pub enum PullEvent { + /// A human-readable status message (e.g., "verifying", "writing"). + Status(String), + /// Byte-level progress update for a specific layer digest. + ChunkProgress { + digest: String, + total: Option, + completed: Option, + }, + /// The pull finished successfully. + Success, + + /// Error event with a message. + Error(String), +} + +/// A simple observer for pull progress events. Implementations decide how to +/// render progress (CLI, TUI, logs, ...). +pub trait PullProgressReporter { + fn on_event(&mut self, event: &PullEvent) -> io::Result<()>; +} + +/// A minimal CLI reporter that writes inline progress to stderr. 
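+///
+/// Assumed usage sketch: pass a `CliProgressReporter::new()` as the reporter to
+/// `OllamaClient::pull_with_reporter`, which feeds it `Status`, `ChunkProgress`,
+/// and finally `Success` or `Error` events. Progress is rendered as a single
+/// stderr line rewritten in place, after a one-time header with the total
+/// download size.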
+pub struct CliProgressReporter { + printed_header: bool, + last_line_len: usize, + last_completed_sum: u64, + last_instant: std::time::Instant, + totals_by_digest: HashMap, +} + +impl Default for CliProgressReporter { + fn default() -> Self { + Self::new() + } +} + +impl CliProgressReporter { + pub fn new() -> Self { + Self { + printed_header: false, + last_line_len: 0, + last_completed_sum: 0, + last_instant: std::time::Instant::now(), + totals_by_digest: HashMap::new(), + } + } +} + +impl PullProgressReporter for CliProgressReporter { + fn on_event(&mut self, event: &PullEvent) -> io::Result<()> { + let mut out = std::io::stderr(); + match event { + PullEvent::Status(status) => { + // Avoid noisy manifest messages; otherwise show status inline. + if status.eq_ignore_ascii_case("pulling manifest") { + return Ok(()); + } + let pad = self.last_line_len.saturating_sub(status.len()); + let line = format!("\r{status}{}", " ".repeat(pad)); + self.last_line_len = status.len(); + out.write_all(line.as_bytes())?; + out.flush() + } + PullEvent::ChunkProgress { + digest, + total, + completed, + } => { + if let Some(t) = *total { + self.totals_by_digest + .entry(digest.clone()) + .or_insert((0, 0)) + .0 = t; + } + if let Some(c) = *completed { + self.totals_by_digest + .entry(digest.clone()) + .or_insert((0, 0)) + .1 = c; + } + + let (sum_total, sum_completed) = self + .totals_by_digest + .values() + .fold((0u64, 0u64), |acc, (t, c)| (acc.0 + *t, acc.1 + *c)); + if sum_total > 0 { + if !self.printed_header { + let gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0); + let header = format!("Downloading model: total {gb:.2} GB\n"); + out.write_all(b"\r\x1b[2K")?; + out.write_all(header.as_bytes())?; + self.printed_header = true; + } + let now = std::time::Instant::now(); + let dt = now + .duration_since(self.last_instant) + .as_secs_f64() + .max(0.001); + let dbytes = sum_completed.saturating_sub(self.last_completed_sum) as f64; + let speed_mb_s = dbytes / (1024.0 * 1024.0) / dt; + self.last_completed_sum = sum_completed; + self.last_instant = now; + + let done_gb = (sum_completed as f64) / (1024.0 * 1024.0 * 1024.0); + let total_gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0); + let pct = (sum_completed as f64) * 100.0 / (sum_total as f64); + let text = + format!("{done_gb:.2}/{total_gb:.2} GB ({pct:.1}%) {speed_mb_s:.1} MB/s"); + let pad = self.last_line_len.saturating_sub(text.len()); + let line = format!("\r{text}{}", " ".repeat(pad)); + self.last_line_len = text.len(); + out.write_all(line.as_bytes())?; + out.flush() + } else { + Ok(()) + } + } + PullEvent::Error(_) => { + // This will be handled by the caller, so we don't do anything + // here or the error will be printed twice. + Ok(()) + } + PullEvent::Success => { + out.write_all(b"\n")?; + out.flush() + } + } + } +} + +/// For now the TUI reporter delegates to the CLI reporter. This keeps UI and +/// CLI behavior aligned until a dedicated TUI integration is implemented. +#[derive(Default)] +pub struct TuiProgressReporter(CliProgressReporter); + +impl PullProgressReporter for TuiProgressReporter { + fn on_event(&mut self, event: &PullEvent) -> io::Result<()> { + self.0.on_event(event) + } +} diff --git a/codex-rs/ollama/src/url.rs b/codex-rs/ollama/src/url.rs new file mode 100644 index 00000000..7c143ce4 --- /dev/null +++ b/codex-rs/ollama/src/url.rs @@ -0,0 +1,39 @@ +/// Identify whether a base_url points at an OpenAI-compatible root (".../v1"). 
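+///
+/// Illustrative examples (assumed typical values): `"http://localhost:11434/v1"`
+/// and `"http://localhost:11434/v1/"` count as OpenAI-compatible roots, while the
+/// native host root `"http://localhost:11434"` does not.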
+pub(crate) fn is_openai_compatible_base_url(base_url: &str) -> bool { + base_url.trim_end_matches('/').ends_with("/v1") +} + +/// Convert a provider base_url into the native Ollama host root. +/// For example, "http://localhost:11434/v1" -> "http://localhost:11434". +pub fn base_url_to_host_root(base_url: &str) -> String { + let trimmed = base_url.trim_end_matches('/'); + if trimmed.ends_with("/v1") { + trimmed + .trim_end_matches("/v1") + .trim_end_matches('/') + .to_string() + } else { + trimmed.to_string() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_base_url_to_host_root() { + assert_eq!( + base_url_to_host_root("http://localhost:11434/v1"), + "http://localhost:11434" + ); + assert_eq!( + base_url_to_host_root("http://localhost:11434"), + "http://localhost:11434" + ); + assert_eq!( + base_url_to_host_root("http://localhost:11434/"), + "http://localhost:11434" + ); + } +} diff --git a/codex-rs/tui/Cargo.toml b/codex-rs/tui/Cargo.toml index 60af056a..49d843f0 100644 --- a/codex-rs/tui/Cargo.toml +++ b/codex-rs/tui/Cargo.toml @@ -33,6 +33,7 @@ codex-common = { path = "../common", features = [ codex-core = { path = "../core" } codex-file-search = { path = "../file-search" } codex-login = { path = "../login" } +codex-ollama = { path = "../ollama" } color-eyre = "0.6.3" crossterm = { version = "0.28.1", features = ["bracketed-paste"] } image = { version = "^0.25.6", default-features = false, features = ["jpeg"] } diff --git a/codex-rs/tui/src/cli.rs b/codex-rs/tui/src/cli.rs index cb1b725a..85dffbeb 100644 --- a/codex-rs/tui/src/cli.rs +++ b/codex-rs/tui/src/cli.rs @@ -17,6 +17,12 @@ pub struct Cli { #[arg(long, short = 'm')] pub model: Option, + /// Convenience flag to select the local open source model provider. + /// Equivalent to -c model_provider=oss; verifies a local Ollama server is + /// running. + #[arg(long = "oss", default_value_t = false)] + pub oss: bool, + /// Configuration profile from config.toml to specify default options. #[arg(long = "profile", short = 'p')] pub config_profile: Option, diff --git a/codex-rs/tui/src/lib.rs b/codex-rs/tui/src/lib.rs index c619ce8f..0b833b13 100644 --- a/codex-rs/tui/src/lib.rs +++ b/codex-rs/tui/src/lib.rs @@ -3,6 +3,7 @@ // alternate‑screen mode starts; that file opts‑out locally via `allow`. #![deny(clippy::print_stdout, clippy::print_stderr)] use app::App; +use codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID; use codex_core::config::Config; use codex_core::config::ConfigOverrides; use codex_core::config_types::SandboxMode; @@ -70,18 +71,35 @@ pub async fn run_main( ) }; + let model_provider_override = if cli.oss { + Some(BUILT_IN_OSS_MODEL_PROVIDER_ID.to_owned()) + } else { + None + }; let config = { // Load configuration and support CLI overrides. let overrides = ConfigOverrides { - model: cli.model.clone(), + // When using `--oss`, let the bootstrapper pick the model + // (defaulting to gpt-oss:20b) and ensure it is present locally. 
+ model: if cli.oss { + Some( + codex_ollama::ensure_oss_ready(cli.model.clone()) + .await + .map_err(|e| std::io::Error::other(format!("OSS setup failed: {e}")))?, + ) + } else { + cli.model.clone() + }, approval_policy, sandbox_mode, cwd: cli.cwd.clone().map(|p| p.canonicalize().unwrap_or(p)), - model_provider: None, + model_provider: model_provider_override, config_profile: cli.config_profile.clone(), codex_linux_sandbox_exe, base_instructions: None, include_plan_tool: Some(true), + default_disable_response_storage: cli.oss.then_some(true), + default_show_raw_agent_reasoning: cli.oss.then_some(true), }; // Parse `-c` overrides from the CLI. let cli_kv_overrides = match cli.config_overrides.parse_overrides() {