use bytes::BytesMut;
use futures::StreamExt;
use futures::stream::BoxStream;
use serde_json::Value as JsonValue;
use std::collections::VecDeque;
use std::io;

use codex_core::WireApi;

use crate::parser::pull_events_from_value;
use crate::pull::PullEvent;
use crate::pull::PullProgressReporter;
use crate::url::base_url_to_host_root;
use crate::url::is_openai_compatible_base_url;

/// Client for interacting with a local Ollama instance.
pub struct OllamaClient {
    client: reqwest::Client,
    host_root: String,
    uses_openai_compat: bool,
}

impl OllamaClient {
    pub fn from_oss_provider() -> Self {
        #![allow(clippy::expect_used)]
        // Use the built-in OSS provider's base URL.
        let built_in_model_providers = codex_core::built_in_model_providers();
        let provider = built_in_model_providers
            .get(codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID)
            .expect("oss provider must exist");
        let base_url = provider
            .base_url
            .as_ref()
            .expect("oss provider must have a base_url");
        Self::from_provider(base_url, provider.wire_api)
    }

    /// Construct a client for the built‑in open‑source ("oss") model provider
    /// and verify that a local Ollama server is reachable. If no server is
    /// detected, returns an error with helpful installation/run instructions.
    pub async fn try_from_oss_provider() -> io::Result<Self> {
        let client = Self::from_oss_provider();
        if client.probe_server().await? {
            Ok(client)
        } else {
            Err(io::Error::other(
                "No running Ollama server detected. Start it with: `ollama serve` (after installing). Install instructions: https://github.com/ollama/ollama?tab=readme-ov-file#ollama",
            ))
        }
    }

    /// Build a client from a provider definition. Falls back to the default
    /// local URL if no base_url is configured.
    fn from_provider(base_url: &str, wire_api: WireApi) -> Self {
        let uses_openai_compat = is_openai_compatible_base_url(base_url)
            || matches!(wire_api, WireApi::Chat) && is_openai_compatible_base_url(base_url);
        let host_root = base_url_to_host_root(base_url);
        let client = reqwest::Client::builder()
            .connect_timeout(std::time::Duration::from_secs(5))
            .build()
            .unwrap_or_else(|_| reqwest::Client::new());
        Self {
            client,
            host_root,
            uses_openai_compat,
        }
    }

    /// Low-level constructor given a raw host root, e.g. "http://localhost:11434".
    #[cfg(test)]
    fn from_host_root(host_root: impl Into<String>) -> Self {
        let client = reqwest::Client::builder()
            .connect_timeout(std::time::Duration::from_secs(5))
            .build()
            .unwrap_or_else(|_| reqwest::Client::new());
        Self {
            client,
            host_root: host_root.into(),
            uses_openai_compat: false,
        }
    }

    /// Probe whether the server is reachable by hitting the appropriate health endpoint.
    pub async fn probe_server(&self) -> io::Result<bool> {
        let url = if self.uses_openai_compat {
            format!("{}/v1/models", self.host_root.trim_end_matches('/'))
        } else {
            format!("{}/api/tags", self.host_root.trim_end_matches('/'))
        };
        let resp = self.client.get(url).send().await;
        Ok(matches!(resp, Ok(r) if r.status().is_success()))
    }

    /// Return the list of model names known to the local Ollama instance.
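    ///
    /// A minimal usage sketch (not part of the original docs): it assumes a
    /// local Ollama server is already running at the built-in OSS provider's
    /// base URL.
    ///
    /// ```ignore
    /// let client = OllamaClient::try_from_oss_provider().await?;
    /// let models = client.fetch_models().await?;
    /// if models.is_empty() {
    ///     eprintln!("no local models found; try `ollama pull <model>`");
    /// }
    /// ```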
    pub async fn fetch_models(&self) -> io::Result<Vec<String>> {
        let tags_url = format!("{}/api/tags", self.host_root.trim_end_matches('/'));
        let resp = self
            .client
            .get(tags_url)
            .send()
            .await
            .map_err(io::Error::other)?;
        if !resp.status().is_success() {
            return Ok(Vec::new());
        }
        let val = resp.json::<JsonValue>().await.map_err(io::Error::other)?;
        let names = val
            .get("models")
            .and_then(|m| m.as_array())
            .map(|arr| {
                arr.iter()
                    .filter_map(|v| v.get("name").and_then(|n| n.as_str()))
                    .map(|s| s.to_string())
                    .collect::<Vec<String>>()
            })
            .unwrap_or_default();
        Ok(names)
    }

    /// Start a model pull and emit streaming events. The returned stream ends when
    /// a Success event is observed or the server closes the connection.
    pub async fn pull_model_stream(
        &self,
        model: &str,
    ) -> io::Result<BoxStream<'static, PullEvent>> {
        let url = format!("{}/api/pull", self.host_root.trim_end_matches('/'));
        let resp = self
            .client
            .post(url)
            .json(&serde_json::json!({"model": model, "stream": true}))
            .send()
            .await
            .map_err(io::Error::other)?;
        if !resp.status().is_success() {
            return Err(io::Error::other(format!(
                "failed to start pull: HTTP {}",
                resp.status()
            )));
        }
        let mut stream = resp.bytes_stream();
        let mut buf = BytesMut::new();
        let _pending: VecDeque<PullEvent> = VecDeque::new();
        // Using an async stream adaptor backed by unfold-like manual loop.
        let s = async_stream::stream! {
            while let Some(chunk) = stream.next().await {
                match chunk {
                    Ok(bytes) => {
                        buf.extend_from_slice(&bytes);
                        // The pull endpoint streams newline-delimited JSON; parse one
                        // complete line at a time.
                        while let Some(pos) = buf.iter().position(|b| *b == b'\n') {
                            let line = buf.split_to(pos + 1);
                            if let Ok(text) = std::str::from_utf8(&line) {
                                let text = text.trim();
                                if text.is_empty() {
                                    continue;
                                }
                                if let Ok(value) = serde_json::from_str::<JsonValue>(text) {
                                    for ev in pull_events_from_value(&value) {
                                        yield ev;
                                    }
                                    if let Some(err_msg) =
                                        value.get("error").and_then(|e| e.as_str())
                                    {
                                        yield PullEvent::Error(err_msg.to_string());
                                        return;
                                    }
                                    if let Some(status) =
                                        value.get("status").and_then(|s| s.as_str())
                                    {
                                        if status == "success" {
                                            yield PullEvent::Success;
                                            return;
                                        }
                                    }
                                }
                            }
                        }
                    }
                    Err(_) => {
                        // Connection error: end the stream.
                        return;
                    }
                }
            }
        };
        Ok(Box::pin(s))
    }

    /// High-level helper to pull a model and drive a progress reporter.
    pub async fn pull_with_reporter(
        &self,
        model: &str,
        reporter: &mut dyn PullProgressReporter,
    ) -> io::Result<()> {
        reporter.on_event(&PullEvent::Status(format!("Pulling model {model}...")))?;
        let mut stream = self.pull_model_stream(model).await?;
        while let Some(event) = stream.next().await {
            reporter.on_event(&event)?;
            match event {
                PullEvent::Success => {
                    return Ok(());
                }
                PullEvent::Error(err) => {
                    // Empirically, ollama returns a 200 OK response even when
                    // the output stream includes an error message. Verify with:
                    //
                    // `curl -i http://localhost:11434/api/pull -d '{ "model": "foobarbaz" }'`
                    //
                    // As such, we have to check the event stream, not the
                    // HTTP response status, to determine whether to return Err.
                    return Err(io::Error::other(format!("Pull failed: {err}")));
                }
                PullEvent::ChunkProgress { .. } | PullEvent::Status(_) => {
                    continue;
                }
            }
        }
        Err(io::Error::other(
            "Pull stream ended unexpectedly without success.",
        ))
    }
}

#[cfg(test)]
mod tests {
    #![allow(clippy::expect_used, clippy::unwrap_used)]

    use super::*;

    /// Simple RAII guard to set an environment variable for the duration of a test
    /// and restore the previous value (or remove it) on drop to avoid cross-test
    /// interference.
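    ///
    /// A minimal sketch of the intended usage (it mirrors the calls in the
    /// tests below; the URL value is illustrative):
    ///
    /// ```ignore
    /// let _guard = EnvVarGuard::set("CODEX_OSS_BASE_URL", "http://localhost:11434/v1".to_string());
    /// // ... run test code that reads CODEX_OSS_BASE_URL ...
    /// // The previous value is restored automatically when `_guard` is dropped.
    /// ```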
    struct EnvVarGuard {
        key: String,
        prev: Option<String>,
    }

    impl EnvVarGuard {
        fn set(key: &str, value: String) -> Self {
            let prev = std::env::var(key).ok();
            // set_var is safe but we mirror existing tests that use an unsafe block
            // to silence edition lints around global mutation during tests.
            unsafe { std::env::set_var(key, value) };
            Self {
                key: key.to_string(),
                prev,
            }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match &self.prev {
                Some(v) => unsafe { std::env::set_var(&self.key, v) },
                None => unsafe { std::env::remove_var(&self.key) },
            }
        }
    }

    // Happy-path tests using a mock HTTP server; skip if sandbox network is disabled.
    #[tokio::test]
    async fn test_fetch_models_happy_path() {
        if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            tracing::info!(
                "{} is set; skipping test_fetch_models_happy_path",
                codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
            );
            return;
        }

        let server = wiremock::MockServer::start().await;
        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/api/tags"))
            .respond_with(
                wiremock::ResponseTemplate::new(200).set_body_raw(
                    serde_json::json!({
                        "models": [
                            {"name": "llama3.2:3b"},
                            {"name": "mistral"}
                        ]
                    })
                    .to_string(),
                    "application/json",
                ),
            )
            .mount(&server)
            .await;

        let client = OllamaClient::from_host_root(server.uri());
        let models = client.fetch_models().await.expect("fetch models");
        assert!(models.contains(&"llama3.2:3b".to_string()));
        assert!(models.contains(&"mistral".to_string()));
    }

    #[tokio::test]
    async fn test_probe_server_happy_path_openai_compat_and_native() {
        if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            tracing::info!(
                "{} set; skipping test_probe_server_happy_path_openai_compat_and_native",
                codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
            );
            return;
        }

        let server = wiremock::MockServer::start().await;

        // Native endpoint
        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/api/tags"))
            .respond_with(wiremock::ResponseTemplate::new(200))
            .mount(&server)
            .await;
        let native = OllamaClient::from_host_root(server.uri());
        assert!(native.probe_server().await.expect("probe native"));

        // OpenAI compatibility endpoint
        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/v1/models"))
            .respond_with(wiremock::ResponseTemplate::new(200))
            .mount(&server)
            .await;

        // Ensure the built-in OSS provider points at our mock server for this test
        // to avoid depending on any globally configured environment from other tests.
        let _guard = EnvVarGuard::set("CODEX_OSS_BASE_URL", format!("{}/v1", server.uri()));
        let ollama_client = OllamaClient::from_oss_provider();
        assert!(ollama_client.probe_server().await.expect("probe compat"));
    }

    #[tokio::test]
    async fn test_try_from_oss_provider_ok_when_server_running() {
        if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            tracing::info!(
                "{} set; skipping test_try_from_oss_provider_ok_when_server_running",
                codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
            );
            return;
        }

        let server = wiremock::MockServer::start().await;

        // Configure built‑in `oss` provider to point at this mock server.
        // set_var is unsafe on Rust 2024 edition; use unsafe block in tests.
        let _guard = EnvVarGuard::set("CODEX_OSS_BASE_URL", format!("{}/v1", server.uri()));

        // OpenAI‑compat models endpoint responds OK.
        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/v1/models"))
            .respond_with(wiremock::ResponseTemplate::new(200))
            .mount(&server)
            .await;

        let _client = OllamaClient::try_from_oss_provider()
            .await
            .expect("client should be created when probe succeeds");
    }

    #[tokio::test]
    async fn test_try_from_oss_provider_err_when_server_missing() {
        if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            tracing::info!(
                "{} set; skipping test_try_from_oss_provider_err_when_server_missing",
                codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
            );
            return;
        }

        let server = wiremock::MockServer::start().await;

        // Point the oss provider at our mock server but do NOT set up a handler
        // for /v1/models, so the request returns a non‑success status. Use the
        // guard so the variable is restored and does not leak into other tests.
        let _guard = EnvVarGuard::set("CODEX_OSS_BASE_URL", format!("{}/v1", server.uri()));

        let err = OllamaClient::try_from_oss_provider()
            .await
            .err()
            .expect("expected error");
        let msg = err.to_string();
        assert!(
            msg.contains("No running Ollama server detected."),
            "msg = {msg}"
        );
    }
}
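
// What follows is a usage sketch rather than part of the test suite proper: it
// shows how `pull_with_reporter` is intended to be driven. It is `#[ignore]`d
// because it needs a live Ollama server, and the `PullProgressReporter` method
// signature assumed below is inferred from its usage in this file, not taken
// from the trait's definition.
#[cfg(test)]
mod pull_usage_sketch {
    #![allow(clippy::expect_used)]

    use std::io;

    use crate::pull::PullEvent;
    use crate::pull::PullProgressReporter;

    use super::OllamaClient;

    /// Minimal reporter that only surfaces status lines; assumes the trait's
    /// sole method is `on_event(&mut self, &PullEvent) -> io::Result<()>`.
    struct StatusPrinter;

    impl PullProgressReporter for StatusPrinter {
        fn on_event(&mut self, event: &PullEvent) -> io::Result<()> {
            if let PullEvent::Status(status) = event {
                eprintln!("ollama: {status}");
            }
            Ok(())
        }
    }

    #[tokio::test]
    #[ignore = "requires a running local Ollama server"]
    async fn pull_model_with_status_reporter() {
        let client = OllamaClient::try_from_oss_provider()
            .await
            .expect("a local Ollama server should be reachable");
        let mut reporter = StatusPrinter;
        // The model name is illustrative; any tag known to Ollama works.
        client
            .pull_with_reporter("llama3.2:3b", &mut reporter)
            .await
            .expect("pull should end with a success event");
    }
}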