350 lines
13 KiB
Rust
350 lines
13 KiB
Rust
use bytes::BytesMut;
|
||
use futures::StreamExt;
|
||
use futures::stream::BoxStream;
|
||
use serde_json::Value as JsonValue;
|
||
use std::collections::VecDeque;
|
||
use std::io;
|
||
|
||
use crate::parser::pull_events_from_value;
|
||
use crate::pull::PullEvent;
|
||
use crate::pull::PullProgressReporter;
|
||
use crate::url::base_url_to_host_root;
|
||
use crate::url::is_openai_compatible_base_url;
|
||
use codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID;
|
||
use codex_core::ModelProviderInfo;
|
||
use codex_core::WireApi;
|
||
use codex_core::config::Config;
|
||
|
||
/// User-facing message carried by the `io::Error` that `probe_server` returns
/// when the request cannot be sent or the server answers with a non-success
/// HTTP status. The tests in this file compare against it verbatim.
const OLLAMA_CONNECTION_ERROR: &str = "No running Ollama server detected. Start it with: `ollama serve` (after installing). Install instructions: https://github.com/ollama/ollama?tab=readme-ov-file#ollama";
|
||
|
||
/// Client for interacting with a local Ollama instance.
|
||
pub struct OllamaClient {
|
||
client: reqwest::Client,
|
||
host_root: String,
|
||
uses_openai_compat: bool,
|
||
}
|
||
|
||
impl OllamaClient {
|
||
/// Construct a client for the built‑in open‑source ("oss") model provider
|
||
/// and verify that a local Ollama server is reachable. If no server is
|
||
/// detected, returns an error with helpful installation/run instructions.
|
||
pub async fn try_from_oss_provider(config: &Config) -> io::Result<Self> {
|
||
// Note that we must look up the provider from the Config to ensure that
|
||
// any overrides the user has in their config.toml are taken into
|
||
// account.
|
||
let provider = config
|
||
.model_providers
|
||
.get(BUILT_IN_OSS_MODEL_PROVIDER_ID)
|
||
.ok_or_else(|| {
|
||
io::Error::new(
|
||
io::ErrorKind::NotFound,
|
||
format!("Built-in provider {BUILT_IN_OSS_MODEL_PROVIDER_ID} not found",),
|
||
)
|
||
})?;
|
||
|
||
Self::try_from_provider(provider).await
|
||
}
|
||
|
||
/// Test-only constructor: wraps `base_url` in an OSS provider definition and
/// hands it to `try_from_provider`, so the server probe runs as well.
#[cfg(test)]
async fn try_from_provider_with_base_url(base_url: &str) -> io::Result<Self> {
    Self::try_from_provider(&codex_core::create_oss_provider_with_base_url(base_url)).await
}
|
||
|
||
/// Build a client from a provider definition and verify the server is reachable.
|
||
async fn try_from_provider(provider: &ModelProviderInfo) -> io::Result<Self> {
|
||
#![expect(clippy::expect_used)]
|
||
let base_url = provider
|
||
.base_url
|
||
.as_ref()
|
||
.expect("oss provider must have a base_url");
|
||
let uses_openai_compat = is_openai_compatible_base_url(base_url)
|
||
|| matches!(provider.wire_api, WireApi::Chat)
|
||
&& is_openai_compatible_base_url(base_url);
|
||
let host_root = base_url_to_host_root(base_url);
|
||
let client = reqwest::Client::builder()
|
||
.connect_timeout(std::time::Duration::from_secs(5))
|
||
.build()
|
||
.unwrap_or_else(|_| reqwest::Client::new());
|
||
let client = Self {
|
||
client,
|
||
host_root,
|
||
uses_openai_compat,
|
||
};
|
||
client.probe_server().await?;
|
||
Ok(client)
|
||
}
|
||
|
||
/// Probe whether the server is reachable by hitting the appropriate health endpoint.
|
||
async fn probe_server(&self) -> io::Result<()> {
|
||
let url = if self.uses_openai_compat {
|
||
format!("{}/v1/models", self.host_root.trim_end_matches('/'))
|
||
} else {
|
||
format!("{}/api/tags", self.host_root.trim_end_matches('/'))
|
||
};
|
||
let resp = self.client.get(url).send().await.map_err(|err| {
|
||
tracing::warn!("Failed to connect to Ollama server: {err:?}");
|
||
io::Error::other(OLLAMA_CONNECTION_ERROR)
|
||
})?;
|
||
if resp.status().is_success() {
|
||
Ok(())
|
||
} else {
|
||
tracing::warn!(
|
||
"Failed to probe server at {}: HTTP {}",
|
||
self.host_root,
|
||
resp.status()
|
||
);
|
||
Err(io::Error::other(OLLAMA_CONNECTION_ERROR))
|
||
}
|
||
}
|
||
|
||
/// Return the list of model names known to the local Ollama instance.
|
||
pub async fn fetch_models(&self) -> io::Result<Vec<String>> {
|
||
let tags_url = format!("{}/api/tags", self.host_root.trim_end_matches('/'));
|
||
let resp = self
|
||
.client
|
||
.get(tags_url)
|
||
.send()
|
||
.await
|
||
.map_err(io::Error::other)?;
|
||
if !resp.status().is_success() {
|
||
return Ok(Vec::new());
|
||
}
|
||
let val = resp.json::<JsonValue>().await.map_err(io::Error::other)?;
|
||
let names = val
|
||
.get("models")
|
||
.and_then(|m| m.as_array())
|
||
.map(|arr| {
|
||
arr.iter()
|
||
.filter_map(|v| v.get("name").and_then(|n| n.as_str()))
|
||
.map(str::to_string)
|
||
.collect::<Vec<_>>()
|
||
})
|
||
.unwrap_or_default();
|
||
Ok(names)
|
||
}
|
||
|
||
/// Start a model pull and emit streaming events. The returned stream ends when
|
||
/// a Success event is observed or the server closes the connection.
|
||
pub async fn pull_model_stream(
|
||
&self,
|
||
model: &str,
|
||
) -> io::Result<BoxStream<'static, PullEvent>> {
|
||
let url = format!("{}/api/pull", self.host_root.trim_end_matches('/'));
|
||
let resp = self
|
||
.client
|
||
.post(url)
|
||
.json(&serde_json::json!({"model": model, "stream": true}))
|
||
.send()
|
||
.await
|
||
.map_err(io::Error::other)?;
|
||
if !resp.status().is_success() {
|
||
return Err(io::Error::other(format!(
|
||
"failed to start pull: HTTP {}",
|
||
resp.status()
|
||
)));
|
||
}
|
||
|
||
let mut stream = resp.bytes_stream();
|
||
let mut buf = BytesMut::new();
|
||
let _pending: VecDeque<PullEvent> = VecDeque::new();
|
||
|
||
// Using an async stream adaptor backed by unfold-like manual loop.
|
||
let s = async_stream::stream! {
|
||
while let Some(chunk) = stream.next().await {
|
||
match chunk {
|
||
Ok(bytes) => {
|
||
buf.extend_from_slice(&bytes);
|
||
while let Some(pos) = buf.iter().position(|b| *b == b'\n') {
|
||
let line = buf.split_to(pos + 1);
|
||
if let Ok(text) = std::str::from_utf8(&line) {
|
||
let text = text.trim();
|
||
if text.is_empty() { continue; }
|
||
if let Ok(value) = serde_json::from_str::<JsonValue>(text) {
|
||
for ev in pull_events_from_value(&value) { yield ev; }
|
||
if let Some(err_msg) = value.get("error").and_then(|e| e.as_str()) {
|
||
yield PullEvent::Error(err_msg.to_string());
|
||
return;
|
||
}
|
||
if let Some(status) = value.get("status").and_then(|s| s.as_str())
|
||
&& status == "success" { yield PullEvent::Success; return; }
|
||
}
|
||
}
|
||
}
|
||
}
|
||
Err(_) => {
|
||
// Connection error: end the stream.
|
||
return;
|
||
}
|
||
}
|
||
}
|
||
};
|
||
|
||
Ok(Box::pin(s))
|
||
}
|
||
|
||
/// High-level helper to pull a model and drive a progress reporter.
|
||
pub async fn pull_with_reporter(
|
||
&self,
|
||
model: &str,
|
||
reporter: &mut dyn PullProgressReporter,
|
||
) -> io::Result<()> {
|
||
reporter.on_event(&PullEvent::Status(format!("Pulling model {model}...")))?;
|
||
let mut stream = self.pull_model_stream(model).await?;
|
||
while let Some(event) = stream.next().await {
|
||
reporter.on_event(&event)?;
|
||
match event {
|
||
PullEvent::Success => {
|
||
return Ok(());
|
||
}
|
||
PullEvent::Error(err) => {
|
||
// Empirically, ollama returns a 200 OK response even when
|
||
// the output stream includes an error message. Verify with:
|
||
//
|
||
// `curl -i http://localhost:11434/api/pull -d '{ "model": "foobarbaz" }'`
|
||
//
|
||
// As such, we have to check the event stream, not the
|
||
// HTTP response status, to determine whether to return Err.
|
||
return Err(io::Error::other(format!("Pull failed: {err}")));
|
||
}
|
||
PullEvent::ChunkProgress { .. } | PullEvent::Status(_) => {
|
||
continue;
|
||
}
|
||
}
|
||
}
|
||
Err(io::Error::other(
|
||
"Pull stream ended unexpectedly without success.",
|
||
))
|
||
}
|
||
|
||
/// Test-only constructor from a raw host root, e.g. "http://localhost:11434".
///
/// No probe is performed and the client always uses the native (non
/// OpenAI-compatible) endpoints.
#[cfg(test)]
fn from_host_root(host_root: impl Into<String>) -> Self {
    let builder =
        reqwest::Client::builder().connect_timeout(std::time::Duration::from_secs(5));
    Self {
        // Fall back to a default client if the builder fails.
        client: builder.build().unwrap_or_else(|_| reqwest::Client::new()),
        host_root: host_root.into(),
        uses_openai_compat: false,
    }
}
|
||
}
|
||
|
||
#[cfg(test)]
|
||
mod tests {
|
||
use super::*;
|
||
|
||
// Happy-path tests using a mock HTTP server; skip if sandbox network is disabled.
|
||
#[tokio::test]
|
||
async fn test_fetch_models_happy_path() {
|
||
if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||
tracing::info!(
|
||
"{} is set; skipping test_fetch_models_happy_path",
|
||
codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
|
||
);
|
||
return;
|
||
}
|
||
|
||
let server = wiremock::MockServer::start().await;
|
||
wiremock::Mock::given(wiremock::matchers::method("GET"))
|
||
.and(wiremock::matchers::path("/api/tags"))
|
||
.respond_with(
|
||
wiremock::ResponseTemplate::new(200).set_body_raw(
|
||
serde_json::json!({
|
||
"models": [ {"name": "llama3.2:3b"}, {"name":"mistral"} ]
|
||
})
|
||
.to_string(),
|
||
"application/json",
|
||
),
|
||
)
|
||
.mount(&server)
|
||
.await;
|
||
|
||
let client = OllamaClient::from_host_root(server.uri());
|
||
let models = client.fetch_models().await.expect("fetch models");
|
||
assert!(models.contains(&"llama3.2:3b".to_string()));
|
||
assert!(models.contains(&"mistral".to_string()));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_probe_server_happy_path_openai_compat_and_native() {
|
||
if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||
tracing::info!(
|
||
"{} set; skipping test_probe_server_happy_path_openai_compat_and_native",
|
||
codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
|
||
);
|
||
return;
|
||
}
|
||
|
||
let server = wiremock::MockServer::start().await;
|
||
|
||
// Native endpoint
|
||
wiremock::Mock::given(wiremock::matchers::method("GET"))
|
||
.and(wiremock::matchers::path("/api/tags"))
|
||
.respond_with(wiremock::ResponseTemplate::new(200))
|
||
.mount(&server)
|
||
.await;
|
||
let native = OllamaClient::from_host_root(server.uri());
|
||
native.probe_server().await.expect("probe native");
|
||
|
||
// OpenAI compatibility endpoint
|
||
wiremock::Mock::given(wiremock::matchers::method("GET"))
|
||
.and(wiremock::matchers::path("/v1/models"))
|
||
.respond_with(wiremock::ResponseTemplate::new(200))
|
||
.mount(&server)
|
||
.await;
|
||
let ollama_client =
|
||
OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
|
||
.await
|
||
.expect("probe OpenAI compat");
|
||
ollama_client
|
||
.probe_server()
|
||
.await
|
||
.expect("probe OpenAI compat");
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_try_from_oss_provider_ok_when_server_running() {
|
||
if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||
tracing::info!(
|
||
"{} set; skipping test_try_from_oss_provider_ok_when_server_running",
|
||
codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
|
||
);
|
||
return;
|
||
}
|
||
|
||
let server = wiremock::MockServer::start().await;
|
||
|
||
// OpenAI‑compat models endpoint responds OK.
|
||
wiremock::Mock::given(wiremock::matchers::method("GET"))
|
||
.and(wiremock::matchers::path("/v1/models"))
|
||
.respond_with(wiremock::ResponseTemplate::new(200))
|
||
.mount(&server)
|
||
.await;
|
||
|
||
OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
|
||
.await
|
||
.expect("client should be created when probe succeeds");
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn test_try_from_oss_provider_err_when_server_missing() {
|
||
if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||
tracing::info!(
|
||
"{} set; skipping test_try_from_oss_provider_err_when_server_missing",
|
||
codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
|
||
);
|
||
return;
|
||
}
|
||
|
||
let server = wiremock::MockServer::start().await;
|
||
let err = OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
|
||
.await
|
||
.err()
|
||
.expect("expected error");
|
||
assert_eq!(OLLAMA_CONNECTION_ERROR, err.to_string());
|
||
}
|
||
}
|