2025-08-05 11:31:11 -07:00
|
|
|
|
use bytes::BytesMut;
|
|
|
|
|
|
use futures::StreamExt;
|
|
|
|
|
|
use futures::stream::BoxStream;
|
|
|
|
|
|
use serde_json::Value as JsonValue;
|
|
|
|
|
|
use std::collections::VecDeque;
|
|
|
|
|
|
use std::io;
|
|
|
|
|
|
|
|
|
|
|
|
use crate::parser::pull_events_from_value;
|
|
|
|
|
|
use crate::pull::PullEvent;
|
|
|
|
|
|
use crate::pull::PullProgressReporter;
|
|
|
|
|
|
use crate::url::base_url_to_host_root;
|
|
|
|
|
|
use crate::url::is_openai_compatible_base_url;
|
2025-08-05 13:55:32 -07:00
|
|
|
|
use codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID;
|
|
|
|
|
|
use codex_core::ModelProviderInfo;
|
|
|
|
|
|
use codex_core::WireApi;
|
|
|
|
|
|
use codex_core::config::Config;
|
|
|
|
|
|
|
|
|
|
|
|
/// User-facing error returned whenever the health probe cannot reach a local
/// Ollama server; includes install/run instructions.
const OLLAMA_CONNECTION_ERROR: &str = "No running Ollama server detected. Start it with: `ollama serve` (after installing). Install instructions: https://github.com/ollama/ollama?tab=readme-ov-file#ollama";
|
2025-08-05 11:31:11 -07:00
|
|
|
|
|
|
|
|
|
|
/// Client for interacting with a local Ollama instance.
pub struct OllamaClient {
    /// HTTP client used for all requests to the server.
    client: reqwest::Client,
    /// Scheme + host (+ port) root of the server, e.g. "http://localhost:11434".
    host_root: String,
    /// When true, health probes use the OpenAI-compatible surface
    /// ("/v1/models") rather than Ollama's native endpoint ("/api/tags").
    uses_openai_compat: bool,
}
|
|
|
|
|
|
|
|
|
|
|
|
impl OllamaClient {
|
|
|
|
|
|
/// Construct a client for the built‑in open‑source ("oss") model provider
|
|
|
|
|
|
/// and verify that a local Ollama server is reachable. If no server is
|
|
|
|
|
|
/// detected, returns an error with helpful installation/run instructions.
|
2025-08-05 13:55:32 -07:00
|
|
|
|
pub async fn try_from_oss_provider(config: &Config) -> io::Result<Self> {
|
|
|
|
|
|
// Note that we must look up the provider from the Config to ensure that
|
|
|
|
|
|
// any overrides the user has in their config.toml are taken into
|
|
|
|
|
|
// account.
|
|
|
|
|
|
let provider = config
|
|
|
|
|
|
.model_providers
|
|
|
|
|
|
.get(BUILT_IN_OSS_MODEL_PROVIDER_ID)
|
|
|
|
|
|
.ok_or_else(|| {
|
|
|
|
|
|
io::Error::new(
|
|
|
|
|
|
io::ErrorKind::NotFound,
|
|
|
|
|
|
format!("Built-in provider {BUILT_IN_OSS_MODEL_PROVIDER_ID} not found",),
|
|
|
|
|
|
)
|
|
|
|
|
|
})?;
|
|
|
|
|
|
|
|
|
|
|
|
Self::try_from_provider(provider).await
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
#[cfg(test)]
|
|
|
|
|
|
async fn try_from_provider_with_base_url(base_url: &str) -> io::Result<Self> {
|
|
|
|
|
|
let provider = codex_core::create_oss_provider_with_base_url(base_url);
|
|
|
|
|
|
Self::try_from_provider(&provider).await
|
2025-08-05 11:31:11 -07:00
|
|
|
|
}
|
|
|
|
|
|
|
2025-08-05 13:55:32 -07:00
|
|
|
|
/// Build a client from a provider definition and verify the server is reachable.
|
|
|
|
|
|
async fn try_from_provider(provider: &ModelProviderInfo) -> io::Result<Self> {
|
|
|
|
|
|
#![allow(clippy::expect_used)]
|
|
|
|
|
|
let base_url = provider
|
|
|
|
|
|
.base_url
|
|
|
|
|
|
.as_ref()
|
|
|
|
|
|
.expect("oss provider must have a base_url");
|
2025-08-05 11:31:11 -07:00
|
|
|
|
let uses_openai_compat = is_openai_compatible_base_url(base_url)
|
2025-08-05 13:55:32 -07:00
|
|
|
|
|| matches!(provider.wire_api, WireApi::Chat)
|
|
|
|
|
|
&& is_openai_compatible_base_url(base_url);
|
2025-08-05 11:31:11 -07:00
|
|
|
|
let host_root = base_url_to_host_root(base_url);
|
|
|
|
|
|
let client = reqwest::Client::builder()
|
|
|
|
|
|
.connect_timeout(std::time::Duration::from_secs(5))
|
|
|
|
|
|
.build()
|
|
|
|
|
|
.unwrap_or_else(|_| reqwest::Client::new());
|
2025-08-05 13:55:32 -07:00
|
|
|
|
let client = Self {
|
2025-08-05 11:31:11 -07:00
|
|
|
|
client,
|
|
|
|
|
|
host_root,
|
|
|
|
|
|
uses_openai_compat,
|
2025-08-05 13:55:32 -07:00
|
|
|
|
};
|
|
|
|
|
|
client.probe_server().await?;
|
|
|
|
|
|
Ok(client)
|
2025-08-05 11:31:11 -07:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
/// Probe whether the server is reachable by hitting the appropriate health endpoint.
|
2025-08-05 13:55:32 -07:00
|
|
|
|
async fn probe_server(&self) -> io::Result<()> {
|
2025-08-05 11:31:11 -07:00
|
|
|
|
let url = if self.uses_openai_compat {
|
|
|
|
|
|
format!("{}/v1/models", self.host_root.trim_end_matches('/'))
|
|
|
|
|
|
} else {
|
|
|
|
|
|
format!("{}/api/tags", self.host_root.trim_end_matches('/'))
|
|
|
|
|
|
};
|
2025-08-05 13:55:32 -07:00
|
|
|
|
let resp = self.client.get(url).send().await.map_err(|err| {
|
|
|
|
|
|
tracing::warn!("Failed to connect to Ollama server: {err:?}");
|
|
|
|
|
|
io::Error::other(OLLAMA_CONNECTION_ERROR)
|
|
|
|
|
|
})?;
|
|
|
|
|
|
if resp.status().is_success() {
|
|
|
|
|
|
Ok(())
|
|
|
|
|
|
} else {
|
|
|
|
|
|
tracing::warn!(
|
|
|
|
|
|
"Failed to probe server at {}: HTTP {}",
|
|
|
|
|
|
self.host_root,
|
|
|
|
|
|
resp.status()
|
|
|
|
|
|
);
|
|
|
|
|
|
Err(io::Error::other(OLLAMA_CONNECTION_ERROR))
|
|
|
|
|
|
}
|
2025-08-05 11:31:11 -07:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
/// Return the list of model names known to the local Ollama instance.
|
|
|
|
|
|
pub async fn fetch_models(&self) -> io::Result<Vec<String>> {
|
|
|
|
|
|
let tags_url = format!("{}/api/tags", self.host_root.trim_end_matches('/'));
|
|
|
|
|
|
let resp = self
|
|
|
|
|
|
.client
|
|
|
|
|
|
.get(tags_url)
|
|
|
|
|
|
.send()
|
|
|
|
|
|
.await
|
|
|
|
|
|
.map_err(io::Error::other)?;
|
|
|
|
|
|
if !resp.status().is_success() {
|
|
|
|
|
|
return Ok(Vec::new());
|
|
|
|
|
|
}
|
|
|
|
|
|
let val = resp.json::<JsonValue>().await.map_err(io::Error::other)?;
|
|
|
|
|
|
let names = val
|
|
|
|
|
|
.get("models")
|
|
|
|
|
|
.and_then(|m| m.as_array())
|
|
|
|
|
|
.map(|arr| {
|
|
|
|
|
|
arr.iter()
|
|
|
|
|
|
.filter_map(|v| v.get("name").and_then(|n| n.as_str()))
|
|
|
|
|
|
.map(|s| s.to_string())
|
|
|
|
|
|
.collect::<Vec<_>>()
|
|
|
|
|
|
})
|
|
|
|
|
|
.unwrap_or_default();
|
|
|
|
|
|
Ok(names)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
/// Start a model pull and emit streaming events. The returned stream ends when
|
|
|
|
|
|
/// a Success event is observed or the server closes the connection.
|
|
|
|
|
|
pub async fn pull_model_stream(
|
|
|
|
|
|
&self,
|
|
|
|
|
|
model: &str,
|
|
|
|
|
|
) -> io::Result<BoxStream<'static, PullEvent>> {
|
|
|
|
|
|
let url = format!("{}/api/pull", self.host_root.trim_end_matches('/'));
|
|
|
|
|
|
let resp = self
|
|
|
|
|
|
.client
|
|
|
|
|
|
.post(url)
|
|
|
|
|
|
.json(&serde_json::json!({"model": model, "stream": true}))
|
|
|
|
|
|
.send()
|
|
|
|
|
|
.await
|
|
|
|
|
|
.map_err(io::Error::other)?;
|
|
|
|
|
|
if !resp.status().is_success() {
|
|
|
|
|
|
return Err(io::Error::other(format!(
|
|
|
|
|
|
"failed to start pull: HTTP {}",
|
|
|
|
|
|
resp.status()
|
|
|
|
|
|
)));
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
let mut stream = resp.bytes_stream();
|
|
|
|
|
|
let mut buf = BytesMut::new();
|
|
|
|
|
|
let _pending: VecDeque<PullEvent> = VecDeque::new();
|
|
|
|
|
|
|
|
|
|
|
|
// Using an async stream adaptor backed by unfold-like manual loop.
|
|
|
|
|
|
let s = async_stream::stream! {
|
|
|
|
|
|
while let Some(chunk) = stream.next().await {
|
|
|
|
|
|
match chunk {
|
|
|
|
|
|
Ok(bytes) => {
|
|
|
|
|
|
buf.extend_from_slice(&bytes);
|
|
|
|
|
|
while let Some(pos) = buf.iter().position(|b| *b == b'\n') {
|
|
|
|
|
|
let line = buf.split_to(pos + 1);
|
|
|
|
|
|
if let Ok(text) = std::str::from_utf8(&line) {
|
|
|
|
|
|
let text = text.trim();
|
|
|
|
|
|
if text.is_empty() { continue; }
|
|
|
|
|
|
if let Ok(value) = serde_json::from_str::<JsonValue>(text) {
|
|
|
|
|
|
for ev in pull_events_from_value(&value) { yield ev; }
|
|
|
|
|
|
if let Some(err_msg) = value.get("error").and_then(|e| e.as_str()) {
|
|
|
|
|
|
yield PullEvent::Error(err_msg.to_string());
|
|
|
|
|
|
return;
|
|
|
|
|
|
}
|
|
|
|
|
|
if let Some(status) = value.get("status").and_then(|s| s.as_str()) {
|
|
|
|
|
|
if status == "success" { yield PullEvent::Success; return; }
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
Err(_) => {
|
|
|
|
|
|
// Connection error: end the stream.
|
|
|
|
|
|
return;
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
Ok(Box::pin(s))
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
/// High-level helper to pull a model and drive a progress reporter.
|
|
|
|
|
|
pub async fn pull_with_reporter(
|
|
|
|
|
|
&self,
|
|
|
|
|
|
model: &str,
|
|
|
|
|
|
reporter: &mut dyn PullProgressReporter,
|
|
|
|
|
|
) -> io::Result<()> {
|
|
|
|
|
|
reporter.on_event(&PullEvent::Status(format!("Pulling model {model}...")))?;
|
|
|
|
|
|
let mut stream = self.pull_model_stream(model).await?;
|
|
|
|
|
|
while let Some(event) = stream.next().await {
|
|
|
|
|
|
reporter.on_event(&event)?;
|
|
|
|
|
|
match event {
|
|
|
|
|
|
PullEvent::Success => {
|
|
|
|
|
|
return Ok(());
|
|
|
|
|
|
}
|
|
|
|
|
|
PullEvent::Error(err) => {
|
2025-08-05 11:39:30 -07:00
|
|
|
|
// Empirically, ollama returns a 200 OK response even when
|
2025-08-05 11:31:11 -07:00
|
|
|
|
// the output stream includes an error message. Verify with:
|
|
|
|
|
|
//
|
|
|
|
|
|
// `curl -i http://localhost:11434/api/pull -d '{ "model": "foobarbaz" }'`
|
|
|
|
|
|
//
|
|
|
|
|
|
// As such, we have to check the event stream, not the
|
|
|
|
|
|
// HTTP response status, to determine whether to return Err.
|
|
|
|
|
|
return Err(io::Error::other(format!("Pull failed: {err}")));
|
|
|
|
|
|
}
|
|
|
|
|
|
PullEvent::ChunkProgress { .. } | PullEvent::Status(_) => {
|
|
|
|
|
|
continue;
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
Err(io::Error::other(
|
|
|
|
|
|
"Pull stream ended unexpectedly without success.",
|
|
|
|
|
|
))
|
|
|
|
|
|
}
|
2025-08-05 13:55:32 -07:00
|
|
|
|
|
|
|
|
|
|
/// Low-level constructor given a raw host root, e.g. "http://localhost:11434".
|
|
|
|
|
|
#[cfg(test)]
|
|
|
|
|
|
fn from_host_root(host_root: impl Into<String>) -> Self {
|
|
|
|
|
|
let client = reqwest::Client::builder()
|
|
|
|
|
|
.connect_timeout(std::time::Duration::from_secs(5))
|
|
|
|
|
|
.build()
|
|
|
|
|
|
.unwrap_or_else(|_| reqwest::Client::new());
|
|
|
|
|
|
Self {
|
|
|
|
|
|
client,
|
|
|
|
|
|
host_root: host_root.into(),
|
|
|
|
|
|
uses_openai_compat: false,
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
2025-08-05 11:31:11 -07:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
#[cfg(test)]
mod tests {
    #![allow(clippy::expect_used, clippy::unwrap_used)]
    use super::*;

    // Happy-path tests using a mock HTTP server; skip if sandbox network is disabled.

    /// `fetch_models` parses the names out of a native `/api/tags` response.
    #[tokio::test]
    async fn test_fetch_models_happy_path() {
        if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            tracing::info!(
                "{} is set; skipping test_fetch_models_happy_path",
                codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
            );
            return;
        }

        // Serve a canned two-model tags payload.
        let server = wiremock::MockServer::start().await;
        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/api/tags"))
            .respond_with(
                wiremock::ResponseTemplate::new(200).set_body_raw(
                    serde_json::json!({
                        "models": [ {"name": "llama3.2:3b"}, {"name":"mistral"} ]
                    })
                    .to_string(),
                    "application/json",
                ),
            )
            .mount(&server)
            .await;

        let client = OllamaClient::from_host_root(server.uri());
        let models = client.fetch_models().await.expect("fetch models");
        assert!(models.contains(&"llama3.2:3b".to_string()));
        assert!(models.contains(&"mistral".to_string()));
    }

    /// `probe_server` succeeds against both the native endpoint and the
    /// OpenAI-compat endpoint when each returns HTTP 200.
    #[tokio::test]
    async fn test_probe_server_happy_path_openai_compat_and_native() {
        if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            tracing::info!(
                "{} set; skipping test_probe_server_happy_path_openai_compat_and_native",
                codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
            );
            return;
        }

        let server = wiremock::MockServer::start().await;

        // Native endpoint
        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/api/tags"))
            .respond_with(wiremock::ResponseTemplate::new(200))
            .mount(&server)
            .await;
        let native = OllamaClient::from_host_root(server.uri());
        native.probe_server().await.expect("probe native");

        // OpenAI compatibility endpoint
        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/v1/models"))
            .respond_with(wiremock::ResponseTemplate::new(200))
            .mount(&server)
            .await;
        // The "/v1" base URL puts the client in OpenAI-compat mode, so this
        // constructor already probes /v1/models once; probe again explicitly.
        let ollama_client =
            OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
                .await
                .expect("probe OpenAI compat");
        ollama_client
            .probe_server()
            .await
            .expect("probe OpenAI compat");
    }

    /// Construction succeeds when the compat health endpoint answers 200.
    #[tokio::test]
    async fn test_try_from_oss_provider_ok_when_server_running() {
        if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            tracing::info!(
                "{} set; skipping test_try_from_oss_provider_ok_when_server_running",
                codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
            );
            return;
        }

        let server = wiremock::MockServer::start().await;

        // OpenAI‑compat models endpoint responds OK.
        wiremock::Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/v1/models"))
            .respond_with(wiremock::ResponseTemplate::new(200))
            .mount(&server)
            .await;

        OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
            .await
            .expect("client should be created when probe succeeds");
    }

    /// Construction fails with the canned connection error when the probe
    /// endpoint is not served (the mock server returns 404 with no mounts).
    #[tokio::test]
    async fn test_try_from_oss_provider_err_when_server_missing() {
        if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
            tracing::info!(
                "{} set; skipping test_try_from_oss_provider_err_when_server_missing",
                codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
            );
            return;
        }

        let server = wiremock::MockServer::start().await;
        let err = OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
            .await
            .err()
            .expect("expected error");
        assert_eq!(OLLAMA_CONNECTION_ERROR, err.to_string());
    }
}
|