llmx/codex-rs/ollama/src/client.rs

use bytes::BytesMut;
use futures::StreamExt;
use futures::stream::BoxStream;
use serde_json::Value as JsonValue;
use std::io;
use crate::parser::pull_events_from_value;
use crate::pull::PullEvent;
use crate::pull::PullProgressReporter;
use crate::url::base_url_to_host_root;
use crate::url::is_openai_compatible_base_url;
use codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID;
use codex_core::ModelProviderInfo;
use codex_core::config::Config;
const OLLAMA_CONNECTION_ERROR: &str = "No running Ollama server detected. Start it with: `ollama serve` (after installing). Install instructions: https://github.com/ollama/ollama?tab=readme-ov-file#ollama";
/// Client for interacting with a local Ollama instance.
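///
/// # Example
///
/// A minimal usage sketch (marked `ignore`: it assumes a loaded [`Config`]
/// and an Ollama server already listening locally):
///
/// ```ignore
/// let client = OllamaClient::try_from_oss_provider(&config).await?;
/// for model in client.fetch_models().await? {
///     println!("{model}");
/// }
/// ```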
pub struct OllamaClient {
client: reqwest::Client,
host_root: String,
uses_openai_compat: bool,
}
impl OllamaClient {
    /// Construct a client for the built-in open-source ("oss") model provider
/// and verify that a local Ollama server is reachable. If no server is
/// detected, returns an error with helpful installation/run instructions.
pub async fn try_from_oss_provider(config: &Config) -> io::Result<Self> {
// Note that we must look up the provider from the Config to ensure that
// any overrides the user has in their config.toml are taken into
// account.
let provider = config
.model_providers
.get(BUILT_IN_OSS_MODEL_PROVIDER_ID)
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::NotFound,
                    format!("Built-in provider {BUILT_IN_OSS_MODEL_PROVIDER_ID} not found"),
)
})?;
Self::try_from_provider(provider).await
}
#[cfg(test)]
async fn try_from_provider_with_base_url(base_url: &str) -> io::Result<Self> {
let provider = codex_core::create_oss_provider_with_base_url(base_url);
Self::try_from_provider(&provider).await
}
/// Build a client from a provider definition and verify the server is reachable.
async fn try_from_provider(provider: &ModelProviderInfo) -> io::Result<Self> {
#![expect(clippy::expect_used)]
let base_url = provider
.base_url
.as_ref()
.expect("oss provider must have a base_url");
        // True when the base URL points at the OpenAI compatibility endpoint
        // (e.g. ".../v1"); the provider's wire API does not change the answer.
        let uses_openai_compat = is_openai_compatible_base_url(base_url);
let host_root = base_url_to_host_root(base_url);
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(5))
.build()
.unwrap_or_else(|_| reqwest::Client::new());
let client = Self {
client,
host_root,
uses_openai_compat,
};
client.probe_server().await?;
Ok(client)
}
/// Probe whether the server is reachable by hitting the appropriate health endpoint.
async fn probe_server(&self) -> io::Result<()> {
let url = if self.uses_openai_compat {
format!("{}/v1/models", self.host_root.trim_end_matches('/'))
} else {
format!("{}/api/tags", self.host_root.trim_end_matches('/'))
};
let resp = self.client.get(url).send().await.map_err(|err| {
tracing::warn!("Failed to connect to Ollama server: {err:?}");
io::Error::other(OLLAMA_CONNECTION_ERROR)
})?;
if resp.status().is_success() {
Ok(())
} else {
tracing::warn!(
"Failed to probe server at {}: HTTP {}",
self.host_root,
resp.status()
);
Err(io::Error::other(OLLAMA_CONNECTION_ERROR))
}
}
/// Return the list of model names known to the local Ollama instance.
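    ///
    /// The `/api/tags` response is parsed leniently; it is expected to look
    /// roughly like the sketch below (only each model's `name` is read, and
    /// all other fields are ignored):
    ///
    /// ```text
    /// { "models": [ { "name": "llama3.2:3b" }, { "name": "mistral" } ] }
    /// ```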
pub async fn fetch_models(&self) -> io::Result<Vec<String>> {
let tags_url = format!("{}/api/tags", self.host_root.trim_end_matches('/'));
let resp = self
.client
.get(tags_url)
.send()
.await
.map_err(io::Error::other)?;
if !resp.status().is_success() {
return Ok(Vec::new());
}
let val = resp.json::<JsonValue>().await.map_err(io::Error::other)?;
let names = val
.get("models")
.and_then(|m| m.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.get("name").and_then(|n| n.as_str()))
.map(str::to_string)
.collect::<Vec<_>>()
})
.unwrap_or_default();
Ok(names)
}
/// Start a model pull and emit streaming events. The returned stream ends when
/// a Success event is observed or the server closes the connection.
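    ///
    /// The pull endpoint replies with newline-delimited JSON, one event per
    /// line. A typical stream looks roughly like this abridged sketch (fields
    /// beyond `status` and `error` are interpreted by the parser module):
    ///
    /// ```text
    /// {"status":"pulling manifest"}
    /// {"status":"pulling sha256:...","total":123456,"completed":4096}
    /// {"status":"success"}
    /// ```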
pub async fn pull_model_stream(
&self,
model: &str,
) -> io::Result<BoxStream<'static, PullEvent>> {
let url = format!("{}/api/pull", self.host_root.trim_end_matches('/'));
let resp = self
.client
.post(url)
.json(&serde_json::json!({"model": model, "stream": true}))
.send()
.await
.map_err(io::Error::other)?;
if !resp.status().is_success() {
return Err(io::Error::other(format!(
"failed to start pull: HTTP {}",
resp.status()
)));
}
let mut stream = resp.bytes_stream();
let mut buf = BytesMut::new();
        // Accumulate raw bytes and emit events for each complete
        // newline-delimited JSON line.
let s = async_stream::stream! {
while let Some(chunk) = stream.next().await {
match chunk {
Ok(bytes) => {
buf.extend_from_slice(&bytes);
while let Some(pos) = buf.iter().position(|b| *b == b'\n') {
let line = buf.split_to(pos + 1);
if let Ok(text) = std::str::from_utf8(&line) {
let text = text.trim();
if text.is_empty() { continue; }
if let Ok(value) = serde_json::from_str::<JsonValue>(text) {
for ev in pull_events_from_value(&value) { yield ev; }
if let Some(err_msg) = value.get("error").and_then(|e| e.as_str()) {
yield PullEvent::Error(err_msg.to_string());
return;
}
if let Some(status) = value.get("status").and_then(|s| s.as_str())
&& status == "success" { yield PullEvent::Success; return; }
}
}
}
}
Err(_) => {
// Connection error: end the stream.
return;
}
}
}
};
Ok(Box::pin(s))
}
/// High-level helper to pull a model and drive a progress reporter.
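    ///
    /// A minimal reporter sketch; the trait shape is inferred from the call
    /// sites in this file, so treat it as illustrative rather than normative:
    ///
    /// ```ignore
    /// struct PrintReporter;
    ///
    /// impl PullProgressReporter for PrintReporter {
    ///     fn on_event(&mut self, event: &PullEvent) -> std::io::Result<()> {
    ///         // Print every event; assumes `PullEvent` implements `Debug`.
    ///         eprintln!("{event:?}");
    ///         Ok(())
    ///     }
    /// }
    ///
    /// client.pull_with_reporter("llama3.2:3b", &mut PrintReporter).await?;
    /// ```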
pub async fn pull_with_reporter(
&self,
model: &str,
reporter: &mut dyn PullProgressReporter,
) -> io::Result<()> {
reporter.on_event(&PullEvent::Status(format!("Pulling model {model}...")))?;
let mut stream = self.pull_model_stream(model).await?;
while let Some(event) = stream.next().await {
reporter.on_event(&event)?;
match event {
PullEvent::Success => {
return Ok(());
}
PullEvent::Error(err) => {
// Empirically, ollama returns a 200 OK response even when
// the output stream includes an error message. Verify with:
//
// `curl -i http://localhost:11434/api/pull -d '{ "model": "foobarbaz" }'`
//
// As such, we have to check the event stream, not the
// HTTP response status, to determine whether to return Err.
return Err(io::Error::other(format!("Pull failed: {err}")));
}
PullEvent::ChunkProgress { .. } | PullEvent::Status(_) => {
continue;
}
}
}
Err(io::Error::other(
"Pull stream ended unexpectedly without success.",
))
}
/// Low-level constructor given a raw host root, e.g. "http://localhost:11434".
#[cfg(test)]
fn from_host_root(host_root: impl Into<String>) -> Self {
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(5))
.build()
.unwrap_or_else(|_| reqwest::Client::new());
Self {
client,
host_root: host_root.into(),
uses_openai_compat: false,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
// Happy-path tests using a mock HTTP server; skip if sandbox network is disabled.
#[tokio::test]
async fn test_fetch_models_happy_path() {
if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
tracing::info!(
"{} is set; skipping test_fetch_models_happy_path",
codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
);
return;
}
let server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/api/tags"))
.respond_with(
wiremock::ResponseTemplate::new(200).set_body_raw(
serde_json::json!({
"models": [ {"name": "llama3.2:3b"}, {"name":"mistral"} ]
})
.to_string(),
"application/json",
),
)
.mount(&server)
.await;
let client = OllamaClient::from_host_root(server.uri());
let models = client.fetch_models().await.expect("fetch models");
assert!(models.contains(&"llama3.2:3b".to_string()));
assert!(models.contains(&"mistral".to_string()));
}
#[tokio::test]
async fn test_probe_server_happy_path_openai_compat_and_native() {
if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
tracing::info!(
"{} set; skipping test_probe_server_happy_path_openai_compat_and_native",
codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
);
return;
}
let server = wiremock::MockServer::start().await;
// Native endpoint
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/api/tags"))
.respond_with(wiremock::ResponseTemplate::new(200))
.mount(&server)
.await;
let native = OllamaClient::from_host_root(server.uri());
native.probe_server().await.expect("probe native");
// OpenAI compatibility endpoint
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/v1/models"))
.respond_with(wiremock::ResponseTemplate::new(200))
.mount(&server)
.await;
let ollama_client =
OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
.await
.expect("probe OpenAI compat");
ollama_client
.probe_server()
.await
.expect("probe OpenAI compat");
}
#[tokio::test]
async fn test_try_from_oss_provider_ok_when_server_running() {
if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
tracing::info!(
"{} set; skipping test_try_from_oss_provider_ok_when_server_running",
codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
);
return;
}
let server = wiremock::MockServer::start().await;
        // OpenAI-compat models endpoint responds OK.
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/v1/models"))
.respond_with(wiremock::ResponseTemplate::new(200))
.mount(&server)
.await;
OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
.await
.expect("client should be created when probe succeeds");
}
#[tokio::test]
async fn test_try_from_oss_provider_err_when_server_missing() {
if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
tracing::info!(
"{} set; skipping test_try_from_oss_provider_err_when_server_missing",
codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR
);
return;
}
let server = wiremock::MockServer::start().await;
let err = OllamaClient::try_from_provider_with_base_url(&format!("{}/v1", server.uri()))
.await
.err()
.expect("expected error");
assert_eq!(OLLAMA_CONNECTION_ERROR, err.to_string());
}
}