feat: support the chat completions API in the Rust CLI (#862)
This is a substantial PR to add support for the chat completions API, which in turn makes it possible to use non-OpenAI model providers (just like in the TypeScript CLI):

* It moves a number of structs from `client.rs` to `client_common.rs` so they can be shared.
* It introduces support for the chat completions API in `chat_completions.rs`.
* It updates `ModelProviderInfo` so that `env_key` is `Option<String>` instead of `String` (needed for providers such as Ollama that do not require an API key) and adds a `wire_api` field.
* It updates `client.rs` to choose between `stream_responses()` and `stream_chat_completions()` based on the `wire_api` of the `ModelProviderInfo`.
* It updates the `exec` and TUI CLIs to no longer fail if the `OPENAI_API_KEY` environment variable is not set.
* It updates the TUI so that `EventMsg::Error` is displayed more prominently when it occurs, which matters now that users must be alerted to the `CodexErr::EnvVar` variant.
* `CodexErr::EnvVar` was updated to include an optional `instructions` field so we can preserve the behavior of directing users to https://platform.openai.com when `OPENAI_API_KEY` is not set.
* Cleaned up the "welcome message" in the TUI to ensure the model provider is displayed.
* Updated the docs in `codex-rs/README.md`.

To exercise the chat completions API with OpenAI models, I added the following to my `config.toml`:

```toml
model = "gpt-4o"
model_provider = "openai-chat-completions"

[model_providers.openai-chat-completions]
name = "OpenAI using Chat Completions"
base_url = "https://api.openai.com/v1"
env_key = "OPENAI_API_KEY"
wire_api = "chat"
```

To test a non-OpenAI provider, I installed Ollama with Mistral locally on my Mac because ChatGPT said that would be a good match for my hardware:

```shell
brew install ollama
ollama serve
ollama pull mistral
```

Then I added the following to my `~/.codex/config.toml`:

```toml
model = "mistral"
model_provider = "ollama"
```

Note this code could certainly use more test coverage, but I want to get this in so folks can start playing with it. For reference, I believe https://github.com/openai/codex/pull/247 was roughly the comparable PR on the TypeScript side.
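For completeness, here is a hypothetical `config.toml` entry for a provider that needs no API key at all, which is the case the new `env_key: Option<String>` field is meant to cover. The provider id `local-chat` is made up for illustration; the built-in `ollama` provider in this PR uses the same shape (no `env_key`, `wire_api = "chat"`, base URL `http://localhost:11434/v1`):

```toml
model = "mistral"
model_provider = "local-chat"

# A provider with no `env_key`: no API key is required or looked up.
[model_providers.local-chat]
name = "Local OpenAI-compatible server"
base_url = "http://localhost:11434/v1"
wire_api = "chat"
```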
codex-rs/core/src/chat_completions.rs (new file, 310 lines)
@@ -0,0 +1,310 @@
use std::time::Duration;

use bytes::Bytes;
use eventsource_stream::Eventsource;
use futures::Stream;
use futures::StreamExt;
use futures::TryStreamExt;
use reqwest::StatusCode;
use serde_json::json;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use tokio::sync::mpsc;
use tokio::time::timeout;
use tracing::debug;
use tracing::trace;

use crate::ModelProviderInfo;
use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
use crate::client_common::ResponseStream;
use crate::error::CodexErr;
use crate::error::Result;
use crate::flags::OPENAI_REQUEST_MAX_RETRIES;
use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
use crate::models::ContentItem;
use crate::models::ResponseItem;
use crate::util::backoff;

/// Implementation for the classic Chat Completions API. This is intentionally
/// minimal: we only stream back plain assistant text.
pub(crate) async fn stream_chat_completions(
    prompt: &Prompt,
    model: &str,
    client: &reqwest::Client,
    provider: &ModelProviderInfo,
) -> Result<ResponseStream> {
    // Build messages array
    let mut messages = Vec::<serde_json::Value>::new();

    if let Some(instr) = &prompt.instructions {
        messages.push(json!({"role": "system", "content": instr}));
    }

    for item in &prompt.input {
        if let ResponseItem::Message { role, content } = item {
            let mut text = String::new();
            for c in content {
                match c {
                    ContentItem::InputText { text: t } | ContentItem::OutputText { text: t } => {
                        text.push_str(t);
                    }
                    _ => {}
                }
            }
            messages.push(json!({"role": role, "content": text}));
        }
    }

    let payload = json!({
        "model": model,
        "messages": messages,
        "stream": true
    });

    let base_url = provider.base_url.trim_end_matches('/');
    let url = format!("{}/chat/completions", base_url);

    debug!(url, "POST (chat)");
    trace!("request payload: {}", payload);

    let api_key = provider.api_key()?;
    let mut attempt = 0;
    loop {
        attempt += 1;

        let mut req_builder = client.post(&url);
        if let Some(api_key) = &api_key {
            req_builder = req_builder.bearer_auth(api_key.clone());
        }
        let res = req_builder
            .header(reqwest::header::ACCEPT, "text/event-stream")
            .json(&payload)
            .send()
            .await;

        match res {
            Ok(resp) if resp.status().is_success() => {
                let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(16);
                let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
                tokio::spawn(process_chat_sse(stream, tx_event));
                return Ok(ResponseStream { rx_event });
            }
            Ok(res) => {
                let status = res.status();
                if !(status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error()) {
                    let body = (res.text().await).unwrap_or_default();
                    return Err(CodexErr::UnexpectedStatus(status, body));
                }

                if attempt > *OPENAI_REQUEST_MAX_RETRIES {
                    return Err(CodexErr::RetryLimit(status));
                }

                let retry_after_secs = res
                    .headers()
                    .get(reqwest::header::RETRY_AFTER)
                    .and_then(|v| v.to_str().ok())
                    .and_then(|s| s.parse::<u64>().ok());

                let delay = retry_after_secs
                    .map(|s| Duration::from_millis(s * 1_000))
                    .unwrap_or_else(|| backoff(attempt));
                tokio::time::sleep(delay).await;
            }
            Err(e) => {
                if attempt > *OPENAI_REQUEST_MAX_RETRIES {
                    return Err(e.into());
                }
                let delay = backoff(attempt);
                tokio::time::sleep(delay).await;
            }
        }
    }
}

/// Lightweight SSE processor for the Chat Completions streaming format. The
/// output is mapped onto Codex's internal [`ResponseEvent`] so that the rest
/// of the pipeline can stay agnostic of the underlying wire format.
async fn process_chat_sse<S>(stream: S, tx_event: mpsc::Sender<Result<ResponseEvent>>)
where
    S: Stream<Item = Result<Bytes>> + Unpin,
{
    let mut stream = stream.eventsource();

    let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS;

    loop {
        let sse = match timeout(idle_timeout, stream.next()).await {
            Ok(Some(Ok(ev))) => ev,
            Ok(Some(Err(e))) => {
                let _ = tx_event.send(Err(CodexErr::Stream(e.to_string()))).await;
                return;
            }
            Ok(None) => {
                // Stream closed gracefully – emit Completed with dummy id.
                let _ = tx_event
                    .send(Ok(ResponseEvent::Completed {
                        response_id: String::new(),
                    }))
                    .await;
                return;
            }
            Err(_) => {
                let _ = tx_event
                    .send(Err(CodexErr::Stream("idle timeout waiting for SSE".into())))
                    .await;
                return;
            }
        };

        // OpenAI Chat streaming sends a literal string "[DONE]" when finished.
        if sse.data.trim() == "[DONE]" {
            let _ = tx_event
                .send(Ok(ResponseEvent::Completed {
                    response_id: String::new(),
                }))
                .await;
            return;
        }

        // Parse JSON chunk
        let chunk: serde_json::Value = match serde_json::from_str(&sse.data) {
            Ok(v) => v,
            Err(_) => continue,
        };

        let content_opt = chunk
            .get("choices")
            .and_then(|c| c.get(0))
            .and_then(|c| c.get("delta"))
            .and_then(|d| d.get("content"))
            .and_then(|c| c.as_str());

        if let Some(content) = content_opt {
            let item = ResponseItem::Message {
                role: "assistant".to_string(),
                content: vec![ContentItem::OutputText {
                    text: content.to_string(),
                }],
            };

            let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
        }
    }
}

/// Optional client-side aggregation helper
///
/// Stream adapter that merges the incremental `OutputItemDone` chunks coming from
/// [`process_chat_sse`] into a *running* assistant message, **suppressing the
/// per-token deltas**. The stream stays silent while the model is thinking
/// and only emits two events per turn:
///
/// 1. `ResponseEvent::OutputItemDone` with the *complete* assistant message
///    (fully concatenated).
/// 2. The original `ResponseEvent::Completed` right after it.
///
/// This mirrors the behaviour the TypeScript CLI exposes to its higher layers.
///
/// The adapter is intentionally *lossless*: callers who do **not** opt in via
/// [`AggregateStreamExt::aggregate()`] keep receiving the original unmodified
/// events.
pub(crate) struct AggregatedChatStream<S> {
    inner: S,
    cumulative: String,
    pending_completed: Option<ResponseEvent>,
}

impl<S> Stream for AggregatedChatStream<S>
where
    S: Stream<Item = Result<ResponseEvent>> + Unpin,
{
    type Item = Result<ResponseEvent>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        // First, flush any buffered Completed event from the previous call.
        if let Some(ev) = this.pending_completed.take() {
            return Poll::Ready(Some(Ok(ev)));
        }

        loop {
            match Pin::new(&mut this.inner).poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
                Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
                    // Accumulate *assistant* text but do not emit yet.
                    if let crate::models::ResponseItem::Message { role, content } = &item {
                        if role == "assistant" {
                            if let Some(text) = content.iter().find_map(|c| match c {
                                crate::models::ContentItem::OutputText { text } => Some(text),
                                _ => None,
                            }) {
                                this.cumulative.push_str(text);
                            }
                        }
                    }

                    // Swallow partial event; keep polling.
                    continue;
                }
                Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id }))) => {
                    if !this.cumulative.is_empty() {
                        let aggregated_item = crate::models::ResponseItem::Message {
                            role: "assistant".to_string(),
                            content: vec![crate::models::ContentItem::OutputText {
                                text: std::mem::take(&mut this.cumulative),
                            }],
                        };

                        // Buffer Completed so it is returned *after* the aggregated message.
                        this.pending_completed = Some(ResponseEvent::Completed { response_id });

                        return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
                            aggregated_item,
                        ))));
                    }

                    // Nothing aggregated – forward Completed directly.
                    return Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id })));
                } // No other `Ok` variants exist at the moment, continue polling.
            }
        }
    }
}

/// Extension trait that activates aggregation on any stream of [`ResponseEvent`].
pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Sized {
    /// Returns a new stream that emits **only** the final assistant message
    /// per turn instead of every incremental delta. The produced
    /// `ResponseEvent` sequence for a typical text turn looks like:
    ///
    /// ```ignore
    /// OutputItemDone(<full message>)
    /// Completed { .. }
    /// ```
    ///
    /// No other `OutputItemDone` events will be seen by the caller.
    ///
    /// Usage:
    ///
    /// ```ignore
    /// let agg_stream = client.stream(&prompt).await?.aggregate();
    /// while let Some(event) = agg_stream.next().await {
    ///     // event now contains cumulative text
    /// }
    /// ```
    fn aggregate(self) -> AggregatedChatStream<Self> {
        AggregatedChatStream {
            inner: self,
            cumulative: String::new(),
            pending_completed: None,
        }
    }
}

impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
codex-rs/core/src/client.rs
@@ -1,11 +1,7 @@
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::io::BufRead;
use std::path::Path;
use std::pin::Pin;
use std::sync::LazyLock;
use std::task::Context;
use std::task::Poll;
use std::time::Duration;

use bytes::Bytes;
@@ -23,66 +19,22 @@ use tracing::debug;
use tracing::trace;
use tracing::warn;

use crate::chat_completions::stream_chat_completions;
use crate::client_common::Payload;
use crate::client_common::Prompt;
use crate::client_common::Reasoning;
use crate::client_common::ResponseEvent;
use crate::client_common::ResponseStream;
use crate::error::CodexErr;
use crate::error::Result;
use crate::flags::CODEX_RS_SSE_FIXTURE;
use crate::flags::OPENAI_REQUEST_MAX_RETRIES;
use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::WireApi;
use crate::models::ResponseItem;
use crate::util::backoff;

/// API request payload for a single model turn.
#[derive(Default, Debug, Clone)]
pub struct Prompt {
    /// Conversation context input items.
    pub input: Vec<ResponseItem>,
    /// Optional previous response ID (when storage is enabled).
    pub prev_id: Option<String>,
    /// Optional initial instructions (only sent on first turn).
    pub instructions: Option<String>,
    /// Whether to store response on server side (disable_response_storage = !store).
    pub store: bool,

    /// Additional tools sourced from external MCP servers. Note each key is
    /// the "fully qualified" tool name (i.e., prefixed with the server name),
    /// which should be reported to the model in place of Tool::name.
    pub extra_tools: HashMap<String, mcp_types::Tool>,
}

#[derive(Debug)]
pub enum ResponseEvent {
    OutputItemDone(ResponseItem),
    Completed { response_id: String },
}

#[derive(Debug, Serialize)]
struct Payload<'a> {
    model: &'a str,
    #[serde(skip_serializing_if = "Option::is_none")]
    instructions: Option<&'a String>,
    // TODO(mbolin): ResponseItem::Other should not be serialized. Currently,
    // we code defensively to avoid this case, but perhaps we should use a
    // separate enum for serialization.
    input: &'a Vec<ResponseItem>,
    tools: &'a [serde_json::Value],
    tool_choice: &'static str,
    parallel_tool_calls: bool,
    reasoning: Option<Reasoning>,
    #[serde(skip_serializing_if = "Option::is_none")]
    previous_response_id: Option<String>,
    /// true when using the Responses API.
    store: bool,
    stream: bool,
}

#[derive(Debug, Serialize)]
struct Reasoning {
    effort: &'static str,
    #[serde(skip_serializing_if = "Option::is_none")]
    generate_summary: Option<bool>,
}

/// When serialized as JSON, this produces a valid "Tool" in the OpenAI
/// Responses API.
#[derive(Debug, Serialize)]
@@ -152,7 +104,20 @@ impl ModelClient {
        }
    }

    pub async fn stream(&mut self, prompt: &Prompt) -> Result<ResponseStream> {
    /// Dispatches to either the Responses or Chat implementation depending on
    /// the provider config. Public callers always invoke `stream()` – the
    /// specialised helpers are private to avoid accidental misuse.
    pub async fn stream(&self, prompt: &Prompt) -> Result<ResponseStream> {
        match self.provider.wire_api {
            WireApi::Responses => self.stream_responses(prompt).await,
            WireApi::Chat => {
                stream_chat_completions(prompt, &self.model, &self.client, &self.provider).await
            }
        }
    }

    /// Implementation for the OpenAI *Responses* experimental API.
    async fn stream_responses(&self, prompt: &Prompt) -> Result<ResponseStream> {
        if let Some(path) = &*CODEX_RS_SSE_FIXTURE {
            // short circuit for tests
            warn!(path, "Streaming from fixture");
@@ -202,8 +167,8 @@ impl ModelClient {

        let api_key = self
            .provider
            .api_key()
            .ok_or_else(|| crate::error::CodexErr::EnvVar("API_KEY"))?;
            .api_key()?
            .expect("Responses API requires an API key");
        let res = self
            .client
            .post(&url)
@@ -396,18 +361,6 @@ where
    }
}

pub struct ResponseStream {
    rx_event: mpsc::Receiver<Result<ResponseEvent>>,
}

impl Stream for ResponseStream {
    type Item = Result<ResponseEvent>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.rx_event.poll_recv(cx)
    }
}

/// used in tests to stream from a text SSE file
async fn stream_from_fixture(path: impl AsRef<Path>) -> Result<ResponseStream> {
    let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(16);
codex-rs/core/src/client_common.rs (new file, 72 lines)
@@ -0,0 +1,72 @@
use crate::error::Result;
use crate::models::ResponseItem;
use futures::Stream;
use serde::Serialize;
use std::collections::HashMap;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use tokio::sync::mpsc;

/// API request payload for a single model turn.
#[derive(Default, Debug, Clone)]
pub struct Prompt {
    /// Conversation context input items.
    pub input: Vec<ResponseItem>,
    /// Optional previous response ID (when storage is enabled).
    pub prev_id: Option<String>,
    /// Optional initial instructions (only sent on first turn).
    pub instructions: Option<String>,
    /// Whether to store response on server side (disable_response_storage = !store).
    pub store: bool,

    /// Additional tools sourced from external MCP servers. Note each key is
    /// the "fully qualified" tool name (i.e., prefixed with the server name),
    /// which should be reported to the model in place of Tool::name.
    pub extra_tools: HashMap<String, mcp_types::Tool>,
}

#[derive(Debug)]
pub enum ResponseEvent {
    OutputItemDone(ResponseItem),
    Completed { response_id: String },
}

#[derive(Debug, Serialize)]
pub(crate) struct Reasoning {
    pub(crate) effort: &'static str,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) generate_summary: Option<bool>,
}

#[derive(Debug, Serialize)]
pub(crate) struct Payload<'a> {
    pub(crate) model: &'a str,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) instructions: Option<&'a String>,
    // TODO(mbolin): ResponseItem::Other should not be serialized. Currently,
    // we code defensively to avoid this case, but perhaps we should use a
    // separate enum for serialization.
    pub(crate) input: &'a Vec<ResponseItem>,
    pub(crate) tools: &'a [serde_json::Value],
    pub(crate) tool_choice: &'static str,
    pub(crate) parallel_tool_calls: bool,
    pub(crate) reasoning: Option<Reasoning>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) previous_response_id: Option<String>,
    /// true when using the Responses API.
    pub(crate) store: bool,
    pub(crate) stream: bool,
}

pub(crate) struct ResponseStream {
    pub(crate) rx_event: mpsc::Receiver<Result<ResponseEvent>>,
}

impl Stream for ResponseStream {
    type Item = Result<ResponseEvent>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.rx_event.poll_recv(cx)
    }
}
codex-rs/core/src/codex.rs
@@ -31,10 +31,13 @@ use tracing::info;
use tracing::trace;
use tracing::warn;

use crate::WireApi;
use crate::chat_completions::AggregateStreamExt;
use crate::client::ModelClient;
use crate::client::Prompt;
use crate::client::ResponseEvent;
use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
use crate::config::Config;
use crate::conversation_history::ConversationHistory;
use crate::error::CodexErr;
use crate::error::Result as CodexResult;
use crate::exec::ExecParams;
@@ -65,7 +68,6 @@ use crate::safety::assess_command_safety;
use crate::safety::assess_patch_safety;
use crate::user_notification::UserNotification;
use crate::util::backoff;
use crate::zdr_transcript::ZdrTranscript;

/// The high-level interface to the Codex system.
/// It operates as a queue pair where you send submissions and receive events.
@@ -181,7 +183,7 @@ struct State {
    previous_response_id: Option<String>,
    pending_approvals: HashMap<String, oneshot::Sender<ReviewDecision>>,
    pending_input: Vec<ResponseInputItem>,
    zdr_transcript: Option<ZdrTranscript>,
    zdr_transcript: Option<ConversationHistory>,
}

impl Session {
@@ -416,11 +418,15 @@ impl Drop for Session {
}

impl State {
    pub fn partial_clone(&self) -> Self {
    pub fn partial_clone(&self, retain_zdr_transcript: bool) -> Self {
        Self {
            approved_commands: self.approved_commands.clone(),
            previous_response_id: self.previous_response_id.clone(),
            zdr_transcript: self.zdr_transcript.clone(),
            zdr_transcript: if retain_zdr_transcript {
                self.zdr_transcript.clone()
            } else {
                None
            },
            ..Default::default()
        }
    }
@@ -534,14 +540,19 @@ async fn submission_loop(
            let client = ModelClient::new(model.clone(), provider.clone());

            // abort any current running session and clone its state
            let retain_zdr_transcript =
                record_conversation_history(disable_response_storage, provider.wire_api);
            let state = match sess.take() {
                Some(sess) => {
                    sess.abort();
                    sess.state.lock().unwrap().partial_clone()
                    sess.state
                        .lock()
                        .unwrap()
                        .partial_clone(retain_zdr_transcript)
                }
                None => State {
                    zdr_transcript: if disable_response_storage {
                        Some(ZdrTranscript::new())
                    zdr_transcript: if retain_zdr_transcript {
                        Some(ConversationHistory::new())
                    } else {
                        None
                    },
@@ -670,21 +681,35 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
    let pending_input = sess.get_pending_input().into_iter().map(ResponseItem::from);
    net_new_turn_input.extend(pending_input);

    // Persist only the net-new items of this turn to the rollout.
    sess.record_rollout_items(&net_new_turn_input).await;

    // Construct the input that we will send to the model. When using the
    // Chat completions API (or ZDR clients), the model needs the full
    // conversation history on each turn. The rollout file, however, should
    // only record the new items that originated in this turn so that it
    // represents an append-only log without duplicates.
    let turn_input: Vec<ResponseItem> =
        if let Some(transcript) = sess.state.lock().unwrap().zdr_transcript.as_mut() {
            // If we are using ZDR, we need to send the transcript with every turn.
            let mut full_transcript = transcript.contents();
            full_transcript.extend(net_new_turn_input.clone());
            // If we are using Chat/ZDR, we need to send the transcript with every turn.

            // 1. Build up the conversation history for the next turn.
            let full_transcript = [transcript.contents(), net_new_turn_input.clone()].concat();

            // 2. Update the in-memory transcript so that future turns
            //    include these items as part of the history.
            transcript.record_items(net_new_turn_input);

            // Note that `transcript.record_items()` does some filtering
            // such that `full_transcript` may include items that were
            // excluded from `transcript`.
            full_transcript
        } else {
            // Responses API path – we can just send the new items and
            // record the same.
            net_new_turn_input
        };

    // Persist the input part of the turn to the rollout (user messages /
    // function_call_output from previous step).
    sess.record_rollout_items(&turn_input).await;

    let turn_input_messages: Vec<String> = turn_input
        .iter()
        .filter_map(|item| match item {
@@ -794,6 +819,7 @@ async fn run_turn(
        match try_run_turn(sess, &sub_id, &prompt).await {
            Ok(output) => return Ok(output),
            Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
            Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
            Err(e) => {
                if retries < *OPENAI_STREAM_MAX_RETRIES {
                    retries += 1;
@@ -838,7 +864,7 @@ async fn try_run_turn(
    sub_id: &str,
    prompt: &Prompt,
) -> CodexResult<Vec<ProcessedResponseItem>> {
    let mut stream = sess.client.clone().stream(prompt).await?;
    let mut stream = sess.client.clone().stream(prompt).await?.aggregate();

    // Buffer all the incoming messages from the stream first, then execute them.
    // If we execute a function call in the middle of handling the stream, it can time out.
@@ -1612,3 +1638,15 @@ fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<String>
        }
    })
}

/// See [`ConversationHistory`] for details.
fn record_conversation_history(disable_response_storage: bool, wire_api: WireApi) -> bool {
    if disable_response_storage {
        return true;
    }

    match wire_api {
        WireApi::Responses => false,
        WireApi::Chat => true,
    }
}
codex-rs/core/src/config.rs
@@ -21,6 +21,9 @@ pub struct Config {
    /// Optional override of model selection.
    pub model: String,

    /// Key into the model_providers map that specifies which provider to use.
    pub model_provider_id: String,

    /// Info needed to make an API request to the model.
    pub model_provider: ModelProviderInfo,

@@ -219,21 +222,22 @@ impl Config {
            model_providers.entry(key).or_insert(provider);
        }

        let model_provider_name = provider
        let model_provider_id = provider
            .or(cfg.model_provider)
            .unwrap_or_else(|| "openai".to_string());
        let model_provider = model_providers
            .get(&model_provider_name)
            .get(&model_provider_id)
            .ok_or_else(|| {
                std::io::Error::new(
                    std::io::ErrorKind::NotFound,
                    format!("Model provider `{model_provider_name}` not found"),
                    format!("Model provider `{model_provider_id}` not found"),
                )
            })?
            .clone();

        let config = Self {
            model: model.or(cfg.model).unwrap_or_else(default_model),
            model_provider_id,
            model_provider,
            cwd: cwd.map_or_else(
                || {
codex-rs/core/src/zdr_transcript.rs → codex-rs/core/src/conversation_history.rs
@@ -1,16 +1,18 @@
use crate::models::ResponseItem;

/// Transcript that needs to be maintained for ZDR clients for which
/// previous_response_id is not available, so we must include the transcript
/// with every API call. This must include each `function_call` and its
/// corresponding `function_call_output`.
/// Transcript of conversation history that is needed:
/// - for ZDR clients for which previous_response_id is not available, so we
///   must include the transcript with every API call. This must include each
///   `function_call` and its corresponding `function_call_output`.
/// - for clients using the "chat completions" API as opposed to the
///   "responses" API.
#[derive(Debug, Clone)]
pub(crate) struct ZdrTranscript {
pub(crate) struct ConversationHistory {
    /// The oldest items are at the beginning of the vector.
    items: Vec<ResponseItem>,
}

impl ZdrTranscript {
impl ConversationHistory {
    pub(crate) fn new() -> Self {
        Self { items: Vec::new() }
    }
codex-rs/core/src/error.rs
@@ -55,7 +55,7 @@ pub enum CodexErr {

    /// Returned by run_command_stream when the user pressed Ctrl‑C (SIGINT). Session uses this to
    /// surface a polite FunctionCallOutput back to the model instead of crashing the CLI.
    #[error("interrupted (Ctrl‑C)")]
    #[error("interrupted (Ctrl-C)")]
    Interrupted,

    /// Unexpected HTTP status code.
@@ -97,8 +97,28 @@ pub enum CodexErr {
    #[error(transparent)]
    TokioJoin(#[from] JoinError),

    #[error("missing environment variable {0}")]
    EnvVar(&'static str),
    #[error("{0}")]
    EnvVar(EnvVarError),
}

#[derive(Debug)]
pub struct EnvVarError {
    /// Name of the environment variable that is missing.
    pub var: String,

    /// Optional instructions to help the user get a valid value for the
    /// variable and set it.
    pub instructions: Option<String>,
}

impl std::fmt::Display for EnvVarError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Missing environment variable: `{}`.", self.var)?;
        if let Some(instructions) = &self.instructions {
            write!(f, " {instructions}")?;
        }
        Ok(())
    }
}

impl CodexErr {
codex-rs/core/src/lib.rs
@@ -5,11 +5,15 @@
// the TUI or the tracing stack).
#![deny(clippy::print_stdout, clippy::print_stderr)]

mod chat_completions;

mod client;
mod client_common;
pub mod codex;
pub use codex::Codex;
pub mod codex_wrapper;
pub mod config;
mod conversation_history;
pub mod error;
pub mod exec;
mod flags;
@@ -21,10 +25,10 @@ pub mod mcp_server_config;
mod mcp_tool_call;
mod model_provider_info;
pub use model_provider_info::ModelProviderInfo;
pub use model_provider_info::WireApi;
mod models;
pub mod protocol;
mod rollout;
mod safety;
mod user_notification;
pub mod util;
mod zdr_transcript;
codex-rs/core/src/model_provider_info.rs
@@ -8,6 +8,25 @@
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::env::VarError;

use crate::error::EnvVarError;

/// Wire protocol that the provider speaks. Most third-party services only
/// implement the classic OpenAI Chat Completions JSON schema, whereas OpenAI
/// itself (and a handful of others) additionally expose the more modern
/// *Responses* API. The two protocols use different request/response shapes
/// and *cannot* be auto-detected at runtime, therefore each provider entry
/// must declare which one it expects.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum WireApi {
    /// The experimental “Responses” API exposed by OpenAI at `/v1/responses`.
    #[default]
    Responses,
    /// Regular Chat Completions compatible with `/v1/chat/completions`.
    Chat,
}

/// Serializable representation of a provider definition.
#[derive(Debug, Clone, Deserialize, Serialize)]
@@ -17,13 +36,38 @@ pub struct ModelProviderInfo {
    /// Base URL for the provider's OpenAI-compatible API.
    pub base_url: String,
    /// Environment variable that stores the user's API key for this provider.
    pub env_key: String,
    pub env_key: Option<String>,

    /// Optional instructions to help the user get a valid value for the
    /// variable and set it.
    pub env_key_instructions: Option<String>,

    /// Which wire protocol this provider expects.
    pub wire_api: WireApi,
}

impl ModelProviderInfo {
    /// Returns the API key for this provider if present in the environment.
    pub fn api_key(&self) -> Option<String> {
        std::env::var(&self.env_key).ok()
    /// If `env_key` is Some, returns the API key for this provider if present
    /// (and non-empty) in the environment. If `env_key` is required but
    /// cannot be found, returns an error.
    pub fn api_key(&self) -> crate::error::Result<Option<String>> {
        match &self.env_key {
            Some(env_key) => std::env::var(env_key)
                .and_then(|v| {
                    if v.trim().is_empty() {
                        Err(VarError::NotPresent)
                    } else {
                        Ok(Some(v))
                    }
                })
                .map_err(|_| {
                    crate::error::CodexErr::EnvVar(EnvVarError {
                        var: env_key.clone(),
                        instructions: self.env_key_instructions.clone(),
                    })
                }),
            None => Ok(None),
        }
    }
}
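As a sketch of how these fields surface to users, the following is a hypothetical `[model_providers]` entry in `~/.codex/config.toml`; the `example` id, URL, and variable name are placeholders, and the keys simply mirror the `ModelProviderInfo` struct above (with `wire_api` taking the lowercase values `"responses"` or `"chat"` per the serde attribute on `WireApi`):

```toml
[model_providers.example]
name = "Example Provider"
base_url = "https://api.example.com/v1"
env_key = "EXAMPLE_API_KEY"
# Surfaced via EnvVarError when EXAMPLE_API_KEY is missing or empty.
env_key_instructions = "Create an API key in the Example dashboard and export it as EXAMPLE_API_KEY."
wire_api = "chat"
```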
@@ -37,7 +81,9 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
            P {
                name: "OpenAI".into(),
                base_url: "https://api.openai.com/v1".into(),
                env_key: "OPENAI_API_KEY".into(),
                env_key: Some("OPENAI_API_KEY".into()),
                env_key_instructions: Some("Create an API key (https://platform.openai.com) and export it as an environment variable.".into()),
                wire_api: WireApi::Responses,
            },
        ),
        (
@@ -45,7 +91,9 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
            P {
                name: "OpenRouter".into(),
                base_url: "https://openrouter.ai/api/v1".into(),
                env_key: "OPENROUTER_API_KEY".into(),
                env_key: Some("OPENROUTER_API_KEY".into()),
                env_key_instructions: None,
                wire_api: WireApi::Chat,
            },
        ),
        (
@@ -53,7 +101,9 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
            P {
                name: "Gemini".into(),
                base_url: "https://generativelanguage.googleapis.com/v1beta/openai".into(),
                env_key: "GEMINI_API_KEY".into(),
                env_key: Some("GEMINI_API_KEY".into()),
                env_key_instructions: None,
                wire_api: WireApi::Chat,
            },
        ),
        (
@@ -61,7 +111,9 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
            P {
                name: "Ollama".into(),
                base_url: "http://localhost:11434/v1".into(),
                env_key: "OLLAMA_API_KEY".into(),
                env_key: None,
                env_key_instructions: None,
                wire_api: WireApi::Chat,
            },
        ),
        (
@@ -69,7 +121,9 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
            P {
                name: "Mistral".into(),
                base_url: "https://api.mistral.ai/v1".into(),
                env_key: "MISTRAL_API_KEY".into(),
                env_key: Some("MISTRAL_API_KEY".into()),
                env_key_instructions: None,
                wire_api: WireApi::Chat,
            },
        ),
        (
@@ -77,7 +131,9 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
            P {
                name: "DeepSeek".into(),
                base_url: "https://api.deepseek.com".into(),
                env_key: "DEEPSEEK_API_KEY".into(),
                env_key: Some("DEEPSEEK_API_KEY".into()),
                env_key_instructions: None,
                wire_api: WireApi::Chat,
            },
        ),
        (
@@ -85,7 +141,9 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
            P {
                name: "xAI".into(),
                base_url: "https://api.x.ai/v1".into(),
                env_key: "XAI_API_KEY".into(),
                env_key: Some("XAI_API_KEY".into()),
                env_key_instructions: None,
                wire_api: WireApi::Chat,
            },
        ),
        (
@@ -93,7 +151,9 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
            P {
                name: "Groq".into(),
                base_url: "https://api.groq.com/openai/v1".into(),
                env_key: "GROQ_API_KEY".into(),
                env_key: Some("GROQ_API_KEY".into()),
                env_key_instructions: None,
                wire_api: WireApi::Chat,
            },
        ),
    ]
codex-rs/core/src/models.rs
@@ -116,10 +116,10 @@ pub struct ShellToolCallParams {
    pub timeout_ms: Option<u64>,
}

#[expect(dead_code)]
#[derive(Deserialize, Debug, Clone)]
pub struct FunctionCallOutputPayload {
    pub content: String,
    #[expect(dead_code)]
    pub success: Option<bool>,
}
codex-rs/core/src/protocol.rs
@@ -25,6 +25,7 @@ pub struct Submission {
/// Submission operation
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[allow(clippy::large_enum_variant)]
#[non_exhaustive]
pub enum Op {
    /// Configure the model session.
Integration tests under codex-rs/core/tests:
@@ -92,7 +92,9 @@ async fn keeps_previous_response_id_between_tasks() {
        // Environment variable that should exist in the test environment.
        // ModelClient will return an error if the environment variable for the
        // provider is not set.
        env_key: "PATH".into(),
        env_key: Some("PATH".into()),
        env_key_instructions: None,
        wire_api: codex_core::WireApi::Responses,
    };

    // Init session
@@ -82,7 +82,9 @@ async fn retries_on_early_close() {
        // Environment variable that should exist in the test environment.
        // ModelClient will return an error if the environment variable for the
        // provider is not set.
        env_key: "PATH".into(),
        env_key: Some("PATH".into()),
        env_key_instructions: None,
        wire_api: codex_core::WireApi::Responses,
    };

    let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());