From e924070cee27eb433cb83c4ec962f942b8776fc0 Mon Sep 17 00:00:00 2001
From: Michael Bolin
Date: Thu, 8 May 2025 21:46:06 -0700
Subject: [PATCH] feat: support the chat completions API in the Rust CLI (#862)

This is a substantial PR to add support for the chat completions API, which in turn makes it possible to use non-OpenAI model providers (just like in the TypeScript CLI):

* It moves a number of structs from `client.rs` to `client_common.rs` so they can be shared.
* It introduces support for the chat completions API in `chat_completions.rs`.
* It updates `ModelProviderInfo` so that `env_key` is `Option<String>` instead of `String` (e.g., for ollama, which needs no API key) and adds a `wire_api` field.
* It updates `client.rs` to choose between `stream_responses()` and `stream_chat_completions()` based on the `wire_api` for the `ModelProviderInfo`.
* It updates the `exec` and TUI CLIs to no longer fail if the `OPENAI_API_KEY` environment variable is not set.
* It updates the TUI so that `EventMsg::Error` is displayed more prominently when it occurs, which matters now that users must be alerted to the `CodexErr::EnvVar` variant.
* `CodexErr::EnvVar` was updated to include an optional `instructions` field so we can preserve the behavior where we direct users to https://platform.openai.com if `OPENAI_API_KEY` is not set.
* Cleaned up the "welcome message" in the TUI to ensure the model provider is displayed.
* Updated the docs in `codex-rs/README.md`.

To exercise the chat completions API with OpenAI models, I added the following to my `config.toml`:

```toml
model = "gpt-4o"
model_provider = "openai-chat-completions"

[model_providers.openai-chat-completions]
name = "OpenAI using Chat Completions"
base_url = "https://api.openai.com/v1"
env_key = "OPENAI_API_KEY"
wire_api = "chat"
```

To test a non-OpenAI provider, I installed ollama with mistral locally on my Mac because ChatGPT said that would be a good match for my hardware:

```shell
brew install ollama
ollama serve
ollama pull mistral
```

Then I added the following to my `~/.codex/config.toml`:

```toml
model = "mistral"
model_provider = "ollama"
```

Note this code could certainly use more test coverage, but I want to get it in now so folks can start playing with it.

For reference, I believe https://github.com/openai/codex/pull/247 was roughly the comparable PR on the TypeScript side.
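For readers skimming the diff, here is a small, self-contained sketch of the idea behind the new `wire_api` field. It only mirrors the field and variant names from `model_provider_info.rs`; the `ProviderSketch` struct and the `endpoint()` helper are illustrative stand-ins, not the actual Codex types (the real dispatch lives in `ModelClient::stream()` in `client.rs`).

```rust
// Illustrative sketch only: shows how `wire_api` determines which endpoint a
// provider's `base_url` is combined with.
#[derive(Debug, Clone, Copy)]
enum WireApi {
    /// OpenAI's Responses API, served from `/v1/responses`.
    Responses,
    /// Classic Chat Completions, served from `/v1/chat/completions`.
    Chat,
}

struct ProviderSketch {
    name: &'static str,
    base_url: &'static str,
    env_key: Option<&'static str>,
    wire_api: WireApi,
}

fn endpoint(provider: &ProviderSketch) -> String {
    let base = provider.base_url.trim_end_matches('/');
    match provider.wire_api {
        WireApi::Responses => format!("{base}/responses"),
        WireApi::Chat => format!("{base}/chat/completions"),
    }
}

fn main() {
    let providers = [
        ProviderSketch {
            name: "OpenAI",
            base_url: "https://api.openai.com/v1",
            env_key: Some("OPENAI_API_KEY"),
            wire_api: WireApi::Responses,
        },
        ProviderSketch {
            name: "Ollama",
            base_url: "http://localhost:11434/v1",
            env_key: None, // local server, no API key required
            wire_api: WireApi::Chat,
        },
    ];
    for p in &providers {
        println!("{} -> {} (env_key: {:?})", p.name, endpoint(p), p.env_key);
    }
}
```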
---
 codex-rs/README.md                            |  55 ++++
 codex-rs/core/src/chat_completions.rs         | 310 ++++++++++++++++++
 codex-rs/core/src/client.rs                   |  93 ++----
 codex-rs/core/src/client_common.rs            |  72 ++++
 codex-rs/core/src/codex.rs                    |  72 +++-
 codex-rs/core/src/config.rs                   |  10 +-
 ..._transcript.rs => conversation_history.rs} |  14 +-
 codex-rs/core/src/error.rs                    |  26 +-
 codex-rs/core/src/lib.rs                      |   6 +-
 codex-rs/core/src/model_provider_info.rs      |  84 ++++-
 codex-rs/core/src/models.rs                   |   2 +-
 codex-rs/core/src/protocol.rs                 |   1 +
 codex-rs/core/tests/previous_response_id.rs   |   4 +-
 codex-rs/core/tests/stream_no_completed.rs    |   4 +-
 codex-rs/exec/src/lib.rs                      |  39 ---
 codex-rs/tui/src/app_event.rs                 |   1 +
 codex-rs/tui/src/chatwidget.rs                |  13 +-
 .../tui/src/conversation_history_widget.rs    |   8 +
 codex-rs/tui/src/history_cell.rs              |  73 +++--
 codex-rs/tui/src/lib.rs                       |  16 -
 20 files changed, 703 insertions(+), 200 deletions(-)
 create mode 100644 codex-rs/core/src/chat_completions.rs
 create mode 100644 codex-rs/core/src/client_common.rs
 rename codex-rs/core/src/{zdr_transcript.rs => conversation_history.rs} (72%)

diff --git a/codex-rs/README.md b/codex-rs/README.md
index f5a1e24d..d49a5949 100644
--- a/codex-rs/README.md
+++ b/codex-rs/README.md
@@ -33,6 +33,61 @@ The model that Codex should use.
 model = "o3" # overrides the default of "o4-mini"
 ```
 
+### model_provider
+
+Codex comes bundled with a number of predefined "model providers". This config value is a string that indicates which provider to use. You can also define your own providers via `model_providers`.
+
+For example, if you are running ollama with Mistral locally, then you would need to add the following to your config:
+
+```toml
+model = "mistral"
+model_provider = "ollama"
+```
+
+because the following definition for `ollama` is included in Codex:
+
+```toml
+[model_providers.ollama]
+name = "Ollama"
+base_url = "http://localhost:11434/v1"
+wire_api = "chat"
+```
+
+This option defaults to `"openai"`, and the corresponding provider is defined as follows:
+
+```toml
+[model_providers.openai]
+name = "OpenAI"
+base_url = "https://api.openai.com/v1"
+env_key = "OPENAI_API_KEY"
+wire_api = "responses"
+```
+
+### model_providers
+
+This option lets you override and amend the default set of model providers bundled with Codex. This value is a map where each key is the value to use with `model_provider` to select the corresponding provider.
+
+For example, if you wanted to add a provider that uses the OpenAI 4o model via the chat completions API, then you could add the following configuration:
+
+```toml
+# Recall that in TOML, root keys must be listed before tables.
+model = "gpt-4o"
+model_provider = "openai-chat-completions"
+
+[model_providers.openai-chat-completions]
+# Name of the provider that will be displayed in the Codex UI.
+name = "OpenAI using Chat Completions"
+# The path `/chat/completions` will be appended to this URL to make the POST
+# request for the chat completions.
+base_url = "https://api.openai.com/v1"
+# If `env_key` is set, it identifies an environment variable that must be set when
+# using Codex with this provider. The value of the environment variable must be
+# non-empty and will be used in the `Bearer TOKEN` HTTP header for the POST request.
+env_key = "OPENAI_API_KEY"
+# Valid values for wire_api are "chat" and "responses".
+wire_api = "chat" +``` + ### approval_policy Determines when the user should be prompted to approve whether Codex can execute a command: diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs new file mode 100644 index 00000000..8e818c2f --- /dev/null +++ b/codex-rs/core/src/chat_completions.rs @@ -0,0 +1,310 @@ +use std::time::Duration; + +use bytes::Bytes; +use eventsource_stream::Eventsource; +use futures::Stream; +use futures::StreamExt; +use futures::TryStreamExt; +use reqwest::StatusCode; +use serde_json::json; +use std::pin::Pin; +use std::task::Context; +use std::task::Poll; +use tokio::sync::mpsc; +use tokio::time::timeout; +use tracing::debug; +use tracing::trace; + +use crate::ModelProviderInfo; +use crate::client_common::Prompt; +use crate::client_common::ResponseEvent; +use crate::client_common::ResponseStream; +use crate::error::CodexErr; +use crate::error::Result; +use crate::flags::OPENAI_REQUEST_MAX_RETRIES; +use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS; +use crate::models::ContentItem; +use crate::models::ResponseItem; +use crate::util::backoff; + +/// Implementation for the classic Chat Completions API. This is intentionally +/// minimal: we only stream back plain assistant text. +pub(crate) async fn stream_chat_completions( + prompt: &Prompt, + model: &str, + client: &reqwest::Client, + provider: &ModelProviderInfo, +) -> Result { + // Build messages array + let mut messages = Vec::::new(); + + if let Some(instr) = &prompt.instructions { + messages.push(json!({"role": "system", "content": instr})); + } + + for item in &prompt.input { + if let ResponseItem::Message { role, content } = item { + let mut text = String::new(); + for c in content { + match c { + ContentItem::InputText { text: t } | ContentItem::OutputText { text: t } => { + text.push_str(t); + } + _ => {} + } + } + messages.push(json!({"role": role, "content": text})); + } + } + + let payload = json!({ + "model": model, + "messages": messages, + "stream": true + }); + + let base_url = provider.base_url.trim_end_matches('/'); + let url = format!("{}/chat/completions", base_url); + + debug!(url, "POST (chat)"); + trace!("request payload: {}", payload); + + let api_key = provider.api_key()?; + let mut attempt = 0; + loop { + attempt += 1; + + let mut req_builder = client.post(&url); + if let Some(api_key) = &api_key { + req_builder = req_builder.bearer_auth(api_key.clone()); + } + let res = req_builder + .header(reqwest::header::ACCEPT, "text/event-stream") + .json(&payload) + .send() + .await; + + match res { + Ok(resp) if resp.status().is_success() => { + let (tx_event, rx_event) = mpsc::channel::>(16); + let stream = resp.bytes_stream().map_err(CodexErr::Reqwest); + tokio::spawn(process_chat_sse(stream, tx_event)); + return Ok(ResponseStream { rx_event }); + } + Ok(res) => { + let status = res.status(); + if !(status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error()) { + let body = (res.text().await).unwrap_or_default(); + return Err(CodexErr::UnexpectedStatus(status, body)); + } + + if attempt > *OPENAI_REQUEST_MAX_RETRIES { + return Err(CodexErr::RetryLimit(status)); + } + + let retry_after_secs = res + .headers() + .get(reqwest::header::RETRY_AFTER) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()); + + let delay = retry_after_secs + .map(|s| Duration::from_millis(s * 1_000)) + .unwrap_or_else(|| backoff(attempt)); + tokio::time::sleep(delay).await; + } + Err(e) => { + if attempt > *OPENAI_REQUEST_MAX_RETRIES { + return Err(e.into()); + } 
+ let delay = backoff(attempt); + tokio::time::sleep(delay).await; + } + } + } +} + +/// Lightweight SSE processor for the Chat Completions streaming format. The +/// output is mapped onto Codex's internal [`ResponseEvent`] so that the rest +/// of the pipeline can stay agnostic of the underlying wire format. +async fn process_chat_sse(stream: S, tx_event: mpsc::Sender>) +where + S: Stream> + Unpin, +{ + let mut stream = stream.eventsource(); + + let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS; + + loop { + let sse = match timeout(idle_timeout, stream.next()).await { + Ok(Some(Ok(ev))) => ev, + Ok(Some(Err(e))) => { + let _ = tx_event.send(Err(CodexErr::Stream(e.to_string()))).await; + return; + } + Ok(None) => { + // Stream closed gracefully – emit Completed with dummy id. + let _ = tx_event + .send(Ok(ResponseEvent::Completed { + response_id: String::new(), + })) + .await; + return; + } + Err(_) => { + let _ = tx_event + .send(Err(CodexErr::Stream("idle timeout waiting for SSE".into()))) + .await; + return; + } + }; + + // OpenAI Chat streaming sends a literal string "[DONE]" when finished. + if sse.data.trim() == "[DONE]" { + let _ = tx_event + .send(Ok(ResponseEvent::Completed { + response_id: String::new(), + })) + .await; + return; + } + + // Parse JSON chunk + let chunk: serde_json::Value = match serde_json::from_str(&sse.data) { + Ok(v) => v, + Err(_) => continue, + }; + + let content_opt = chunk + .get("choices") + .and_then(|c| c.get(0)) + .and_then(|c| c.get("delta")) + .and_then(|d| d.get("content")) + .and_then(|c| c.as_str()); + + if let Some(content) = content_opt { + let item = ResponseItem::Message { + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: content.to_string(), + }], + }; + + let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; + } + } +} + +/// Optional client-side aggregation helper +/// +/// Stream adapter that merges the incremental `OutputItemDone` chunks coming from +/// [`process_chat_sse`] into a *running* assistant message, **suppressing the +/// per-token deltas**. The stream stays silent while the model is thinking +/// and only emits two events per turn: +/// +/// 1. `ResponseEvent::OutputItemDone` with the *complete* assistant message +/// (fully concatenated). +/// 2. The original `ResponseEvent::Completed` right after it. +/// +/// This mirrors the behaviour the TypeScript CLI exposes to its higher layers. +/// +/// The adapter is intentionally *lossless*: callers who do **not** opt in via +/// [`AggregateStreamExt::aggregate()`] keep receiving the original unmodified +/// events. +pub(crate) struct AggregatedChatStream { + inner: S, + cumulative: String, + pending_completed: Option, +} + +impl Stream for AggregatedChatStream +where + S: Stream> + Unpin, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + // First, flush any buffered Completed event from the previous call. + if let Some(ev) = this.pending_completed.take() { + return Poll::Ready(Some(Ok(ev))); + } + + loop { + match Pin::new(&mut this.inner).poll_next(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(None) => return Poll::Ready(None), + Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), + Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => { + // Accumulate *assistant* text but do not emit yet. 
+ if let crate::models::ResponseItem::Message { role, content } = &item { + if role == "assistant" { + if let Some(text) = content.iter().find_map(|c| match c { + crate::models::ContentItem::OutputText { text } => Some(text), + _ => None, + }) { + this.cumulative.push_str(text); + } + } + } + + // Swallow partial event; keep polling. + continue; + } + Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id }))) => { + if !this.cumulative.is_empty() { + let aggregated_item = crate::models::ResponseItem::Message { + role: "assistant".to_string(), + content: vec![crate::models::ContentItem::OutputText { + text: std::mem::take(&mut this.cumulative), + }], + }; + + // Buffer Completed so it is returned *after* the aggregated message. + this.pending_completed = Some(ResponseEvent::Completed { response_id }); + + return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone( + aggregated_item, + )))); + } + + // Nothing aggregated – forward Completed directly. + return Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id }))); + } // No other `Ok` variants exist at the moment, continue polling. + } + } + } +} + +/// Extension trait that activates aggregation on any stream of [`ResponseEvent`]. +pub(crate) trait AggregateStreamExt: Stream> + Sized { + /// Returns a new stream that emits **only** the final assistant message + /// per turn instead of every incremental delta. The produced + /// `ResponseEvent` sequence for a typical text turn looks like: + /// + /// ```ignore + /// OutputItemDone() + /// Completed { .. } + /// ``` + /// + /// No other `OutputItemDone` events will be seen by the caller. + /// + /// Usage: + /// + /// ```ignore + /// let agg_stream = client.stream(&prompt).await?.aggregate(); + /// while let Some(event) = agg_stream.next().await { + /// // event now contains cumulative text + /// } + /// ``` + fn aggregate(self) -> AggregatedChatStream { + AggregatedChatStream { + inner: self, + cumulative: String::new(), + pending_completed: None, + } + } +} + +impl AggregateStreamExt for T where T: Stream> + Sized {} diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 9216e68c..1b21f6e0 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -1,11 +1,7 @@ use std::collections::BTreeMap; -use std::collections::HashMap; use std::io::BufRead; use std::path::Path; -use std::pin::Pin; use std::sync::LazyLock; -use std::task::Context; -use std::task::Poll; use std::time::Duration; use bytes::Bytes; @@ -23,66 +19,22 @@ use tracing::debug; use tracing::trace; use tracing::warn; +use crate::chat_completions::stream_chat_completions; +use crate::client_common::Payload; +use crate::client_common::Prompt; +use crate::client_common::Reasoning; +use crate::client_common::ResponseEvent; +use crate::client_common::ResponseStream; use crate::error::CodexErr; use crate::error::Result; use crate::flags::CODEX_RS_SSE_FIXTURE; use crate::flags::OPENAI_REQUEST_MAX_RETRIES; use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS; use crate::model_provider_info::ModelProviderInfo; +use crate::model_provider_info::WireApi; use crate::models::ResponseItem; use crate::util::backoff; -/// API request payload for a single model turn. -#[derive(Default, Debug, Clone)] -pub struct Prompt { - /// Conversation context input items. - pub input: Vec, - /// Optional previous response ID (when storage is enabled). - pub prev_id: Option, - /// Optional initial instructions (only sent on first turn). 
- pub instructions: Option, - /// Whether to store response on server side (disable_response_storage = !store). - pub store: bool, - - /// Additional tools sourced from external MCP servers. Note each key is - /// the "fully qualified" tool name (i.e., prefixed with the server name), - /// which should be reported to the model in place of Tool::name. - pub extra_tools: HashMap, -} - -#[derive(Debug)] -pub enum ResponseEvent { - OutputItemDone(ResponseItem), - Completed { response_id: String }, -} - -#[derive(Debug, Serialize)] -struct Payload<'a> { - model: &'a str, - #[serde(skip_serializing_if = "Option::is_none")] - instructions: Option<&'a String>, - // TODO(mbolin): ResponseItem::Other should not be serialized. Currently, - // we code defensively to avoid this case, but perhaps we should use a - // separate enum for serialization. - input: &'a Vec, - tools: &'a [serde_json::Value], - tool_choice: &'static str, - parallel_tool_calls: bool, - reasoning: Option, - #[serde(skip_serializing_if = "Option::is_none")] - previous_response_id: Option, - /// true when using the Responses API. - store: bool, - stream: bool, -} - -#[derive(Debug, Serialize)] -struct Reasoning { - effort: &'static str, - #[serde(skip_serializing_if = "Option::is_none")] - generate_summary: Option, -} - /// When serialized as JSON, this produces a valid "Tool" in the OpenAI /// Responses API. #[derive(Debug, Serialize)] @@ -152,7 +104,20 @@ impl ModelClient { } } - pub async fn stream(&mut self, prompt: &Prompt) -> Result { + /// Dispatches to either the Responses or Chat implementation depending on + /// the provider config. Public callers always invoke `stream()` – the + /// specialised helpers are private to avoid accidental misuse. + pub async fn stream(&self, prompt: &Prompt) -> Result { + match self.provider.wire_api { + WireApi::Responses => self.stream_responses(prompt).await, + WireApi::Chat => { + stream_chat_completions(prompt, &self.model, &self.client, &self.provider).await + } + } + } + + /// Implementation for the OpenAI *Responses* experimental API. + async fn stream_responses(&self, prompt: &Prompt) -> Result { if let Some(path) = &*CODEX_RS_SSE_FIXTURE { // short circuit for tests warn!(path, "Streaming from fixture"); @@ -202,8 +167,8 @@ impl ModelClient { let api_key = self .provider - .api_key() - .ok_or_else(|| crate::error::CodexErr::EnvVar("API_KEY"))?; + .api_key()? + .expect("Repsones API requires an API key"); let res = self .client .post(&url) @@ -396,18 +361,6 @@ where } } -pub struct ResponseStream { - rx_event: mpsc::Receiver>, -} - -impl Stream for ResponseStream { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.rx_event.poll_recv(cx) - } -} - /// used in tests to stream from a text SSE file async fn stream_from_fixture(path: impl AsRef) -> Result { let (tx_event, rx_event) = mpsc::channel::>(16); diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs new file mode 100644 index 00000000..514b6b60 --- /dev/null +++ b/codex-rs/core/src/client_common.rs @@ -0,0 +1,72 @@ +use crate::error::Result; +use crate::models::ResponseItem; +use futures::Stream; +use serde::Serialize; +use std::collections::HashMap; +use std::pin::Pin; +use std::task::Context; +use std::task::Poll; +use tokio::sync::mpsc; + +/// API request payload for a single model turn. +#[derive(Default, Debug, Clone)] +pub struct Prompt { + /// Conversation context input items. 
+ pub input: Vec, + /// Optional previous response ID (when storage is enabled). + pub prev_id: Option, + /// Optional initial instructions (only sent on first turn). + pub instructions: Option, + /// Whether to store response on server side (disable_response_storage = !store). + pub store: bool, + + /// Additional tools sourced from external MCP servers. Note each key is + /// the "fully qualified" tool name (i.e., prefixed with the server name), + /// which should be reported to the model in place of Tool::name. + pub extra_tools: HashMap, +} + +#[derive(Debug)] +pub enum ResponseEvent { + OutputItemDone(ResponseItem), + Completed { response_id: String }, +} + +#[derive(Debug, Serialize)] +pub(crate) struct Reasoning { + pub(crate) effort: &'static str, + #[serde(skip_serializing_if = "Option::is_none")] + pub(crate) generate_summary: Option, +} + +#[derive(Debug, Serialize)] +pub(crate) struct Payload<'a> { + pub(crate) model: &'a str, + #[serde(skip_serializing_if = "Option::is_none")] + pub(crate) instructions: Option<&'a String>, + // TODO(mbolin): ResponseItem::Other should not be serialized. Currently, + // we code defensively to avoid this case, but perhaps we should use a + // separate enum for serialization. + pub(crate) input: &'a Vec, + pub(crate) tools: &'a [serde_json::Value], + pub(crate) tool_choice: &'static str, + pub(crate) parallel_tool_calls: bool, + pub(crate) reasoning: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub(crate) previous_response_id: Option, + /// true when using the Responses API. + pub(crate) store: bool, + pub(crate) stream: bool, +} + +pub(crate) struct ResponseStream { + pub(crate) rx_event: mpsc::Receiver>, +} + +impl Stream for ResponseStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.rx_event.poll_recv(cx) + } +} diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 89bc364b..f68eb73f 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -31,10 +31,13 @@ use tracing::info; use tracing::trace; use tracing::warn; +use crate::WireApi; +use crate::chat_completions::AggregateStreamExt; use crate::client::ModelClient; -use crate::client::Prompt; -use crate::client::ResponseEvent; +use crate::client_common::Prompt; +use crate::client_common::ResponseEvent; use crate::config::Config; +use crate::conversation_history::ConversationHistory; use crate::error::CodexErr; use crate::error::Result as CodexResult; use crate::exec::ExecParams; @@ -65,7 +68,6 @@ use crate::safety::assess_command_safety; use crate::safety::assess_patch_safety; use crate::user_notification::UserNotification; use crate::util::backoff; -use crate::zdr_transcript::ZdrTranscript; /// The high-level interface to the Codex system. /// It operates as a queue pair where you send submissions and receive events. 
@@ -181,7 +183,7 @@ struct State { previous_response_id: Option, pending_approvals: HashMap>, pending_input: Vec, - zdr_transcript: Option, + zdr_transcript: Option, } impl Session { @@ -416,11 +418,15 @@ impl Drop for Session { } impl State { - pub fn partial_clone(&self) -> Self { + pub fn partial_clone(&self, retain_zdr_transcript: bool) -> Self { Self { approved_commands: self.approved_commands.clone(), previous_response_id: self.previous_response_id.clone(), - zdr_transcript: self.zdr_transcript.clone(), + zdr_transcript: if retain_zdr_transcript { + self.zdr_transcript.clone() + } else { + None + }, ..Default::default() } } @@ -534,14 +540,19 @@ async fn submission_loop( let client = ModelClient::new(model.clone(), provider.clone()); // abort any current running session and clone its state + let retain_zdr_transcript = + record_conversation_history(disable_response_storage, provider.wire_api); let state = match sess.take() { Some(sess) => { sess.abort(); - sess.state.lock().unwrap().partial_clone() + sess.state + .lock() + .unwrap() + .partial_clone(retain_zdr_transcript) } None => State { - zdr_transcript: if disable_response_storage { - Some(ZdrTranscript::new()) + zdr_transcript: if retain_zdr_transcript { + Some(ConversationHistory::new()) } else { None }, @@ -670,21 +681,35 @@ async fn run_task(sess: Arc, sub_id: String, input: Vec) { let pending_input = sess.get_pending_input().into_iter().map(ResponseItem::from); net_new_turn_input.extend(pending_input); + // Persist only the net-new items of this turn to the rollout. + sess.record_rollout_items(&net_new_turn_input).await; + + // Construct the input that we will send to the model. When using the + // Chat completions API (or ZDR clients), the model needs the full + // conversation history on each turn. The rollout file, however, should + // only record the new items that originated in this turn so that it + // represents an append-only log without duplicates. let turn_input: Vec = if let Some(transcript) = sess.state.lock().unwrap().zdr_transcript.as_mut() { - // If we are using ZDR, we need to send the transcript with every turn. - let mut full_transcript = transcript.contents(); - full_transcript.extend(net_new_turn_input.clone()); + // If we are using Chat/ZDR, we need to send the transcript with every turn. + + // 1. Build up the conversation history for the next turn. + let full_transcript = [transcript.contents(), net_new_turn_input.clone()].concat(); + + // 2. Update the in-memory transcript so that future turns + // include these items as part of the history. transcript.record_items(net_new_turn_input); + + // Note that `transcript.record_items()` does some filtering + // such that `full_transcript` may include items that were + // excluded from `transcript`. full_transcript } else { + // Responses API path – we can just send the new items and + // record the same. net_new_turn_input }; - // Persist the input part of the turn to the rollout (user messages / - // function_call_output from previous step). 
- sess.record_rollout_items(&turn_input).await; - let turn_input_messages: Vec = turn_input .iter() .filter_map(|item| match item { @@ -794,6 +819,7 @@ async fn run_turn( match try_run_turn(sess, &sub_id, &prompt).await { Ok(output) => return Ok(output), Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted), + Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)), Err(e) => { if retries < *OPENAI_STREAM_MAX_RETRIES { retries += 1; @@ -838,7 +864,7 @@ async fn try_run_turn( sub_id: &str, prompt: &Prompt, ) -> CodexResult> { - let mut stream = sess.client.clone().stream(prompt).await?; + let mut stream = sess.client.clone().stream(prompt).await?.aggregate(); // Buffer all the incoming messages from the stream first, then execute them. // If we execute a function call in the middle of handling the stream, it can time out. @@ -1612,3 +1638,15 @@ fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option bool { + if disable_response_storage { + return true; + } + + match wire_api { + WireApi::Responses => false, + WireApi::Chat => true, + } +} diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs index 087d6afb..2264792b 100644 --- a/codex-rs/core/src/config.rs +++ b/codex-rs/core/src/config.rs @@ -21,6 +21,9 @@ pub struct Config { /// Optional override of model selection. pub model: String, + /// Key into the model_providers map that specifies which provider to use. + pub model_provider_id: String, + /// Info needed to make an API request to the model. pub model_provider: ModelProviderInfo, @@ -219,21 +222,22 @@ impl Config { model_providers.entry(key).or_insert(provider); } - let model_provider_name = provider + let model_provider_id = provider .or(cfg.model_provider) .unwrap_or_else(|| "openai".to_string()); let model_provider = model_providers - .get(&model_provider_name) + .get(&model_provider_id) .ok_or_else(|| { std::io::Error::new( std::io::ErrorKind::NotFound, - format!("Model provider `{model_provider_name}` not found"), + format!("Model provider `{model_provider_id}` not found"), ) })? .clone(); let config = Self { model: model.or(cfg.model).unwrap_or_else(default_model), + model_provider_id, model_provider, cwd: cwd.map_or_else( || { diff --git a/codex-rs/core/src/zdr_transcript.rs b/codex-rs/core/src/conversation_history.rs similarity index 72% rename from codex-rs/core/src/zdr_transcript.rs rename to codex-rs/core/src/conversation_history.rs index 25fdc5a6..8d19e0cb 100644 --- a/codex-rs/core/src/zdr_transcript.rs +++ b/codex-rs/core/src/conversation_history.rs @@ -1,16 +1,18 @@ use crate::models::ResponseItem; -/// Transcript that needs to be maintained for ZDR clients for which -/// previous_response_id is not available, so we must include the transcript -/// with every API call. This must include each `function_call` and its -/// corresponding `function_call_output`. +/// Transcript of conversation history that is needed: +/// - for ZDR clients for which previous_response_id is not available, so we +/// must include the transcript with every API call. This must include each +/// `function_call` and its corresponding `function_call_output`. +/// - for clients using the "chat completions" API as opposed to the +/// "responses" API. #[derive(Debug, Clone)] -pub(crate) struct ZdrTranscript { +pub(crate) struct ConversationHistory { /// The oldest items are at the beginning of the vector. 
items: Vec, } -impl ZdrTranscript { +impl ConversationHistory { pub(crate) fn new() -> Self { Self { items: Vec::new() } } diff --git a/codex-rs/core/src/error.rs b/codex-rs/core/src/error.rs index 0e438700..35b099e6 100644 --- a/codex-rs/core/src/error.rs +++ b/codex-rs/core/src/error.rs @@ -55,7 +55,7 @@ pub enum CodexErr { /// Returned by run_command_stream when the user pressed Ctrl‑C (SIGINT). Session uses this to /// surface a polite FunctionCallOutput back to the model instead of crashing the CLI. - #[error("interrupted (Ctrl‑C)")] + #[error("interrupted (Ctrl-C)")] Interrupted, /// Unexpected HTTP status code. @@ -97,8 +97,28 @@ pub enum CodexErr { #[error(transparent)] TokioJoin(#[from] JoinError), - #[error("missing environment variable {0}")] - EnvVar(&'static str), + #[error("{0}")] + EnvVar(EnvVarError), +} + +#[derive(Debug)] +pub struct EnvVarError { + /// Name of the environment variable that is missing. + pub var: String, + + /// Optional instructions to help the user get a valid value for the + /// variable and set it. + pub instructions: Option, +} + +impl std::fmt::Display for EnvVarError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Missing environment variable: `{}`.", self.var)?; + if let Some(instructions) = &self.instructions { + write!(f, " {instructions}")?; + } + Ok(()) + } } impl CodexErr { diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index 1c3a46df..7774e0f5 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -5,11 +5,15 @@ // the TUI or the tracing stack). #![deny(clippy::print_stdout, clippy::print_stderr)] +mod chat_completions; + mod client; +mod client_common; pub mod codex; pub use codex::Codex; pub mod codex_wrapper; pub mod config; +mod conversation_history; pub mod error; pub mod exec; mod flags; @@ -21,10 +25,10 @@ pub mod mcp_server_config; mod mcp_tool_call; mod model_provider_info; pub use model_provider_info::ModelProviderInfo; +pub use model_provider_info::WireApi; mod models; pub mod protocol; mod rollout; mod safety; mod user_notification; pub mod util; -mod zdr_transcript; diff --git a/codex-rs/core/src/model_provider_info.rs b/codex-rs/core/src/model_provider_info.rs index e7069c04..969797cb 100644 --- a/codex-rs/core/src/model_provider_info.rs +++ b/codex-rs/core/src/model_provider_info.rs @@ -8,6 +8,25 @@ use serde::Deserialize; use serde::Serialize; use std::collections::HashMap; +use std::env::VarError; + +use crate::error::EnvVarError; + +/// Wire protocol that the provider speaks. Most third-party services only +/// implement the classic OpenAI Chat Completions JSON schema, whereas OpenAI +/// itself (and a handful of others) additionally expose the more modern +/// *Responses* API. The two protocols use different request/response shapes +/// and *cannot* be auto-detected at runtime, therefore each provider entry +/// must declare which one it expects. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum WireApi { + /// The experimental “Responses” API exposed by OpenAI at `/v1/responses`. + #[default] + Responses, + /// Regular Chat Completions compatible with `/v1/chat/completions`. + Chat, +} /// Serializable representation of a provider definition. #[derive(Debug, Clone, Deserialize, Serialize)] @@ -17,13 +36,38 @@ pub struct ModelProviderInfo { /// Base URL for the provider's OpenAI-compatible API. 
pub base_url: String, /// Environment variable that stores the user's API key for this provider. - pub env_key: String, + pub env_key: Option, + + /// Optional instructions to help the user get a valid value for the + /// variable and set it. + pub env_key_instructions: Option, + + /// Which wire protocol this provider expects. + pub wire_api: WireApi, } impl ModelProviderInfo { - /// Returns the API key for this provider if present in the environment. - pub fn api_key(&self) -> Option { - std::env::var(&self.env_key).ok() + /// If `env_key` is Some, returns the API key for this provider if present + /// (and non-empty) in the environment. If `env_key` is required but + /// cannot be found, returns an error. + pub fn api_key(&self) -> crate::error::Result> { + match &self.env_key { + Some(env_key) => std::env::var(env_key) + .and_then(|v| { + if v.trim().is_empty() { + Err(VarError::NotPresent) + } else { + Ok(Some(v)) + } + }) + .map_err(|_| { + crate::error::CodexErr::EnvVar(EnvVarError { + var: env_key.clone(), + instructions: self.env_key_instructions.clone(), + }) + }), + None => Ok(None), + } } } @@ -37,7 +81,9 @@ pub fn built_in_model_providers() -> HashMap { P { name: "OpenAI".into(), base_url: "https://api.openai.com/v1".into(), - env_key: "OPENAI_API_KEY".into(), + env_key: Some("OPENAI_API_KEY".into()), + env_key_instructions: Some("Create an API key (https://platform.openai.com) and export it as an environment variable.".into()), + wire_api: WireApi::Responses, }, ), ( @@ -45,7 +91,9 @@ pub fn built_in_model_providers() -> HashMap { P { name: "OpenRouter".into(), base_url: "https://openrouter.ai/api/v1".into(), - env_key: "OPENROUTER_API_KEY".into(), + env_key: Some("OPENROUTER_API_KEY".into()), + env_key_instructions: None, + wire_api: WireApi::Chat, }, ), ( @@ -53,7 +101,9 @@ pub fn built_in_model_providers() -> HashMap { P { name: "Gemini".into(), base_url: "https://generativelanguage.googleapis.com/v1beta/openai".into(), - env_key: "GEMINI_API_KEY".into(), + env_key: Some("GEMINI_API_KEY".into()), + env_key_instructions: None, + wire_api: WireApi::Chat, }, ), ( @@ -61,7 +111,9 @@ pub fn built_in_model_providers() -> HashMap { P { name: "Ollama".into(), base_url: "http://localhost:11434/v1".into(), - env_key: "OLLAMA_API_KEY".into(), + env_key: None, + env_key_instructions: None, + wire_api: WireApi::Chat, }, ), ( @@ -69,7 +121,9 @@ pub fn built_in_model_providers() -> HashMap { P { name: "Mistral".into(), base_url: "https://api.mistral.ai/v1".into(), - env_key: "MISTRAL_API_KEY".into(), + env_key: Some("MISTRAL_API_KEY".into()), + env_key_instructions: None, + wire_api: WireApi::Chat, }, ), ( @@ -77,7 +131,9 @@ pub fn built_in_model_providers() -> HashMap { P { name: "DeepSeek".into(), base_url: "https://api.deepseek.com".into(), - env_key: "DEEPSEEK_API_KEY".into(), + env_key: Some("DEEPSEEK_API_KEY".into()), + env_key_instructions: None, + wire_api: WireApi::Chat, }, ), ( @@ -85,7 +141,9 @@ pub fn built_in_model_providers() -> HashMap { P { name: "xAI".into(), base_url: "https://api.x.ai/v1".into(), - env_key: "XAI_API_KEY".into(), + env_key: Some("XAI_API_KEY".into()), + env_key_instructions: None, + wire_api: WireApi::Chat, }, ), ( @@ -93,7 +151,9 @@ pub fn built_in_model_providers() -> HashMap { P { name: "Groq".into(), base_url: "https://api.groq.com/openai/v1".into(), - env_key: "GROQ_API_KEY".into(), + env_key: Some("GROQ_API_KEY".into()), + env_key_instructions: None, + wire_api: WireApi::Chat, }, ), ] diff --git a/codex-rs/core/src/models.rs 
b/codex-rs/core/src/models.rs index 81e19833..fad5a318 100644 --- a/codex-rs/core/src/models.rs +++ b/codex-rs/core/src/models.rs @@ -116,10 +116,10 @@ pub struct ShellToolCallParams { pub timeout_ms: Option, } -#[expect(dead_code)] #[derive(Deserialize, Debug, Clone)] pub struct FunctionCallOutputPayload { pub content: String, + #[expect(dead_code)] pub success: Option, } diff --git a/codex-rs/core/src/protocol.rs b/codex-rs/core/src/protocol.rs index 613dfe72..131ccb7a 100644 --- a/codex-rs/core/src/protocol.rs +++ b/codex-rs/core/src/protocol.rs @@ -25,6 +25,7 @@ pub struct Submission { /// Submission operation #[derive(Debug, Clone, Deserialize, Serialize)] #[serde(tag = "type", rename_all = "snake_case")] +#[allow(clippy::large_enum_variant)] #[non_exhaustive] pub enum Op { /// Configure the model session. diff --git a/codex-rs/core/tests/previous_response_id.rs b/codex-rs/core/tests/previous_response_id.rs index 0c4428b8..c318f38b 100644 --- a/codex-rs/core/tests/previous_response_id.rs +++ b/codex-rs/core/tests/previous_response_id.rs @@ -92,7 +92,9 @@ async fn keeps_previous_response_id_between_tasks() { // Environment variable that should exist in the test environment. // ModelClient will return an error if the environment variable for the // provider is not set. - env_key: "PATH".into(), + env_key: Some("PATH".into()), + env_key_instructions: None, + wire_api: codex_core::WireApi::Responses, }; // Init session diff --git a/codex-rs/core/tests/stream_no_completed.rs b/codex-rs/core/tests/stream_no_completed.rs index abb3d3ca..cfb7d44b 100644 --- a/codex-rs/core/tests/stream_no_completed.rs +++ b/codex-rs/core/tests/stream_no_completed.rs @@ -82,7 +82,9 @@ async fn retries_on_early_close() { // Environment variable that should exist in the test environment. // ModelClient will return an error if the environment variable for the // provider is not set. - env_key: "PATH".into(), + env_key: Some("PATH".into()), + env_key_instructions: None, + wire_api: codex_core::WireApi::Responses, }; let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new()); diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs index d8e4b9f5..d711388f 100644 --- a/codex-rs/exec/src/lib.rs +++ b/codex-rs/exec/src/lib.rs @@ -16,8 +16,6 @@ use codex_core::protocol::Op; use codex_core::protocol::SandboxPolicy; use codex_core::util::is_inside_git_repo; use event_processor::EventProcessor; -use owo_colors::OwoColorize; -use owo_colors::Style; use tracing::debug; use tracing::error; use tracing::info; @@ -45,8 +43,6 @@ pub async fn run_main(cli: Cli) -> anyhow::Result<()> { ), }; - assert_api_key(stderr_with_ansi); - let sandbox_policy = if full_auto { Some(SandboxPolicy::new_full_auto_policy()) } else { @@ -163,38 +159,3 @@ pub async fn run_main(cli: Cli) -> anyhow::Result<()> { Ok(()) } - -/// If a valid API key is not present in the environment, print an error to -/// stderr and exits with 1; otherwise, does nothing. 
-fn assert_api_key(stderr_with_ansi: bool) { - if !has_api_key() { - let (msg_style, var_style, url_style) = if stderr_with_ansi { - ( - Style::new().red(), - Style::new().bold(), - Style::new().bold().underline(), - ) - } else { - (Style::new(), Style::new(), Style::new()) - }; - - eprintln!( - "\n{msg}\n\nSet the environment variable {var} and re-run this command.\nYou can create a key here: {url}\n", - msg = "Missing OpenAI API key.".style(msg_style), - var = "OPENAI_API_KEY".style(var_style), - url = "https://platform.openai.com/account/api-keys".style(url_style), - ); - std::process::exit(1); - } -} - -/// Returns `true` if a recognized API key is present in the environment. -/// -/// At present we only support `OPENAI_API_KEY`, mirroring the behavior of the -/// Node-based `codex-cli`. Additional providers can be added here when the -/// Rust implementation gains first-class support for them. -fn has_api_key() -> bool { - std::env::var("OPENAI_API_KEY") - .map(|s| !s.trim().is_empty()) - .unwrap_or(false) -} diff --git a/codex-rs/tui/src/app_event.rs b/codex-rs/tui/src/app_event.rs index 2b320375..dd5053cf 100644 --- a/codex-rs/tui/src/app_event.rs +++ b/codex-rs/tui/src/app_event.rs @@ -1,6 +1,7 @@ use codex_core::protocol::Event; use crossterm::event::KeyEvent; +#[allow(clippy::large_enum_variant)] pub(crate) enum AppEvent { CodexEvent(Event), diff --git a/codex-rs/tui/src/chatwidget.rs b/codex-rs/tui/src/chatwidget.rs index cb037e0a..53bb24b8 100644 --- a/codex-rs/tui/src/chatwidget.rs +++ b/codex-rs/tui/src/chatwidget.rs @@ -162,12 +162,8 @@ impl ChatWidget<'_> { } fn submit_welcome_message(&mut self) -> std::result::Result<(), SendError> { - self.handle_codex_event(Event { - id: "welcome".to_string(), - msg: EventMsg::AgentMessage { - message: "Welcome to codex!".to_string(), - }, - })?; + self.conversation_history.add_welcome_message(&self.config); + self.request_redraw()?; Ok(()) } @@ -231,8 +227,6 @@ impl ChatWidget<'_> { } EventMsg::TaskStarted => { self.bottom_pane.set_task_running(true)?; - self.conversation_history - .add_background_event(format!("task {id} started")); self.request_redraw()?; } EventMsg::TaskComplete => { @@ -240,8 +234,7 @@ impl ChatWidget<'_> { self.request_redraw()?; } EventMsg::Error { message } => { - self.conversation_history - .add_background_event(format!("Error: {message}")); + self.conversation_history.add_error(message); self.bottom_pane.set_task_running(false)?; } EventMsg::ExecApprovalRequest { diff --git a/codex-rs/tui/src/conversation_history_widget.rs b/codex-rs/tui/src/conversation_history_widget.rs index ca069997..e3bb9121 100644 --- a/codex-rs/tui/src/conversation_history_widget.rs +++ b/codex-rs/tui/src/conversation_history_widget.rs @@ -162,6 +162,10 @@ impl ConversationHistoryWidget { self.scroll_position = usize::MAX; } + pub fn add_welcome_message(&mut self, config: &Config) { + self.add_to_history(HistoryCell::new_welcome_message(config)); + } + pub fn add_user_message(&mut self, message: String) { self.add_to_history(HistoryCell::new_user_prompt(message)); } @@ -174,6 +178,10 @@ impl ConversationHistoryWidget { self.add_to_history(HistoryCell::new_background_event(message)); } + pub fn add_error(&mut self, message: String) { + self.add_to_history(HistoryCell::new_error_event(message)); + } + /// Add a pending patch entry (before user approval). 
pub fn add_patch_event( &mut self, diff --git a/codex-rs/tui/src/history_cell.rs b/codex-rs/tui/src/history_cell.rs index d8e2b2e2..53035a98 100644 --- a/codex-rs/tui/src/history_cell.rs +++ b/codex-rs/tui/src/history_cell.rs @@ -32,6 +32,9 @@ pub(crate) enum PatchEventType { /// `Vec>` representation to make it easier to display in a /// scrollable list. pub(crate) enum HistoryCell { + /// Welcome message. + WelcomeMessage { lines: Vec> }, + /// Message from the user. UserPrompt { lines: Vec> }, @@ -69,6 +72,9 @@ pub(crate) enum HistoryCell { /// Background event BackgroundEvent { lines: Vec> }, + /// Error event from the backend. + ErrorEvent { lines: Vec> }, + /// Info describing the newly‑initialized session. SessionInfo { lines: Vec> }, @@ -85,6 +91,31 @@ pub(crate) enum HistoryCell { const TOOL_CALL_MAX_LINES: usize = 5; impl HistoryCell { + pub(crate) fn new_welcome_message(config: &Config) -> Self { + let mut lines: Vec> = vec![ + Line::from(vec![ + "OpenAI ".into(), + "Codex".bold(), + " (research preview)".dim(), + ]), + Line::from(""), + Line::from("codex session:".magenta().bold()), + ]; + + let entries = vec![ + ("workdir", config.cwd.display().to_string()), + ("model", config.model.clone()), + ("provider", config.model_provider_id.clone()), + ("approval", format!("{:?}", config.approval_policy)), + ("sandbox", format!("{:?}", config.sandbox_policy)), + ]; + for (key, value) in entries { + lines.push(Line::from(vec![format!("{key}: ").bold(), value.into()])); + } + lines.push(Line::from("")); + HistoryCell::WelcomeMessage { lines } + } + pub(crate) fn new_user_prompt(message: String) -> Self { let mut lines: Vec> = Vec::new(); lines.push(Line::from("user".cyan().bold())); @@ -245,26 +276,26 @@ impl HistoryCell { HistoryCell::BackgroundEvent { lines } } + pub(crate) fn new_error_event(message: String) -> Self { + let lines: Vec> = vec![ + vec!["ERROR: ".red().bold(), message.into()].into(), + "".into(), + ]; + HistoryCell::ErrorEvent { lines } + } + pub(crate) fn new_session_info(config: &Config, model: String) -> Self { - let mut lines: Vec> = Vec::new(); - - lines.push(Line::from("codex session:".magenta().bold())); - lines.push(Line::from(vec!["↳ model: ".bold(), model.into()])); - lines.push(Line::from(vec![ - "↳ cwd: ".bold(), - config.cwd.display().to_string().into(), - ])); - lines.push(Line::from(vec![ - "↳ approval: ".bold(), - format!("{:?}", config.approval_policy).into(), - ])); - lines.push(Line::from(vec![ - "↳ sandbox: ".bold(), - format!("{:?}", config.sandbox_policy).into(), - ])); - lines.push(Line::from("")); - - HistoryCell::SessionInfo { lines } + if config.model == model { + HistoryCell::SessionInfo { lines: vec![] } + } else { + let lines = vec![ + Line::from("model changed:".magenta().bold()), + Line::from(format!("requested: {}", config.model)), + Line::from(format!("used: {}", model)), + Line::from(""), + ]; + HistoryCell::SessionInfo { lines } + } } /// Create a new `PendingPatch` cell that lists the file‑level summary of @@ -329,9 +360,11 @@ impl HistoryCell { pub(crate) fn lines(&self) -> &Vec> { match self { - HistoryCell::UserPrompt { lines, .. } + HistoryCell::WelcomeMessage { lines, .. } + | HistoryCell::UserPrompt { lines, .. } | HistoryCell::AgentMessage { lines, .. } | HistoryCell::BackgroundEvent { lines, .. } + | HistoryCell::ErrorEvent { lines, .. } | HistoryCell::SessionInfo { lines, .. } | HistoryCell::ActiveExecCommand { lines, .. } | HistoryCell::CompletedExecCommand { lines, .. 
} diff --git a/codex-rs/tui/src/lib.rs b/codex-rs/tui/src/lib.rs index 42da0f48..fe4f9954 100644 --- a/codex-rs/tui/src/lib.rs +++ b/codex-rs/tui/src/lib.rs @@ -33,8 +33,6 @@ mod user_approval_widget; pub use cli::Cli; pub fn run_main(cli: Cli) -> std::io::Result<()> { - assert_env_var_set(); - let (sandbox_policy, approval_policy) = if cli.full_auto { ( Some(SandboxPolicy::new_full_auto_policy()), @@ -172,20 +170,6 @@ fn run_ratatui_app( app_result } -#[expect( - clippy::print_stderr, - reason = "TUI should not have been displayed yet, so we can write to stderr." -)] -fn assert_env_var_set() { - if std::env::var("OPENAI_API_KEY").is_err() { - eprintln!("Welcome to codex! It looks like you're missing: `OPENAI_API_KEY`"); - eprintln!( - "Create an API key (https://platform.openai.com) and export as an environment variable" - ); - std::process::exit(1); - } -} - #[expect( clippy::print_stderr, reason = "TUI should no longer be displayed, so we can write to stderr."