feat: support the chat completions API in the Rust CLI (#862)

This is a substantial PR to add support for the chat completions API,
which in turn makes it possible to use non-OpenAI model providers (just
like in the TypeScript CLI):

* It moves a number of structs from `client.rs` to `client_common.rs` so
they can be shared.
* It introduces support for the chat completions API in
`chat_completions.rs`.
* It updates `ModelProviderInfo` so that `env_key` is `Option<String>`
instead of `String` (so providers such as Ollama can omit it) and adds a
`wire_api` field.
* It updates `client.rs` to choose between `stream_responses()` and
`stream_chat_completions()` based on the `wire_api` for the
`ModelProviderInfo` (see the sketch after this list).
* It updates the `exec` and TUI CLIs to no longer fail if the
`OPENAI_API_KEY` environment variable is not set
* It updates the TUI so that `EventMsg::Error` is displayed more
prominently when it occurs, particularly now that it is important to
alert users to the `CodexErr::EnvVar` variant.
* `CodexErr::EnvVar` was updated to include an optional `instructions`
field so we can preserve the behavior where we direct users to
https://platform.openai.com if `OPENAI_API_KEY` is not set.
* Cleaned up the "welcome message" in the TUI to ensure the model
provider is displayed.
* Updated the docs in `codex-rs/README.md`.
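
For reference, the dispatch in `ModelClient::stream()` boils down to the following (a simplified excerpt of the `client.rs` change in the diff below):

```rust
// Simplified sketch of the new wire_api dispatch in ModelClient::stream();
// see client.rs in the diff below for the full version.
pub async fn stream(&self, prompt: &Prompt) -> Result<ResponseStream> {
    match self.provider.wire_api {
        // OpenAI's experimental Responses API.
        WireApi::Responses => self.stream_responses(prompt).await,
        // Classic Chat Completions, which most third-party providers implement.
        WireApi::Chat => {
            stream_chat_completions(prompt, &self.model, &self.client, &self.provider).await
        }
    }
}
```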

To exercise the chat completions API with OpenAI models, I added the
following to my `config.toml`:

```toml
model = "gpt-4o"
model_provider = "openai-chat-completions"

[model_providers.openai-chat-completions]
name = "OpenAI using Chat Completions"
base_url = "https://api.openai.com/v1"
env_key = "OPENAI_API_KEY"
wire_api = "chat"
```

To test a non-OpenAI provider, I installed Ollama with Mistral locally on
my Mac (ChatGPT said that would be a good match for my hardware):

```shell
brew install ollama
ollama serve
ollama pull mistral
```

Then I added the following to my `~/.codex/config.toml`:

```toml
model = "mistral"
model_provider = "ollama"
```
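
Because the built-in `ollama` provider has no `env_key`, this works without `OPENAI_API_KEY` being set. For providers that do require a key, the missing-variable case now surfaces as `CodexErr::EnvVar`; here is a rough sketch of what that message looks like, based on the `EnvVarError` `Display` impl added in this PR:

```rust
// Rough sketch of the error text shown when a provider's env_key is missing,
// using the built-in `openai` provider's instructions as an example.
use codex_core::error::EnvVarError;

fn main() {
    let err = EnvVarError {
        var: "OPENAI_API_KEY".to_string(),
        instructions: Some(
            "Create an API key (https://platform.openai.com) and export it as an environment variable.".to_string(),
        ),
    };
    // Prints: Missing environment variable: `OPENAI_API_KEY`. Create an API key
    // (https://platform.openai.com) and export it as an environment variable.
    println!("{err}");
}
```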

Note this code could certainly use more test coverage, but I want to get
this in so folks can start playing with it.

For reference, I believe https://github.com/openai/codex/pull/247 was
roughly the comparable PR on the TypeScript side.
Author: Michael Bolin
Date: 2025-05-08 21:46:06 -07:00 (committed by GitHub)
Parent: a538e6acb2
Commit: e924070cee
20 changed files with 703 additions and 200 deletions

@@ -0,0 +1,310 @@
use std::time::Duration;
use bytes::Bytes;
use eventsource_stream::Eventsource;
use futures::Stream;
use futures::StreamExt;
use futures::TryStreamExt;
use reqwest::StatusCode;
use serde_json::json;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use tokio::sync::mpsc;
use tokio::time::timeout;
use tracing::debug;
use tracing::trace;
use crate::ModelProviderInfo;
use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
use crate::client_common::ResponseStream;
use crate::error::CodexErr;
use crate::error::Result;
use crate::flags::OPENAI_REQUEST_MAX_RETRIES;
use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
use crate::models::ContentItem;
use crate::models::ResponseItem;
use crate::util::backoff;
/// Implementation for the classic Chat Completions API. This is intentionally
/// minimal: we only stream back plain assistant text.
pub(crate) async fn stream_chat_completions(
prompt: &Prompt,
model: &str,
client: &reqwest::Client,
provider: &ModelProviderInfo,
) -> Result<ResponseStream> {
// Build messages array
let mut messages = Vec::<serde_json::Value>::new();
if let Some(instr) = &prompt.instructions {
messages.push(json!({"role": "system", "content": instr}));
}
for item in &prompt.input {
if let ResponseItem::Message { role, content } = item {
let mut text = String::new();
for c in content {
match c {
ContentItem::InputText { text: t } | ContentItem::OutputText { text: t } => {
text.push_str(t);
}
_ => {}
}
}
messages.push(json!({"role": role, "content": text}));
}
}
let payload = json!({
"model": model,
"messages": messages,
"stream": true
});
let base_url = provider.base_url.trim_end_matches('/');
let url = format!("{}/chat/completions", base_url);
debug!(url, "POST (chat)");
trace!("request payload: {}", payload);
let api_key = provider.api_key()?;
let mut attempt = 0;
loop {
attempt += 1;
let mut req_builder = client.post(&url);
if let Some(api_key) = &api_key {
req_builder = req_builder.bearer_auth(api_key.clone());
}
let res = req_builder
.header(reqwest::header::ACCEPT, "text/event-stream")
.json(&payload)
.send()
.await;
match res {
Ok(resp) if resp.status().is_success() => {
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(16);
let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
tokio::spawn(process_chat_sse(stream, tx_event));
return Ok(ResponseStream { rx_event });
}
Ok(res) => {
let status = res.status();
if !(status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error()) {
let body = (res.text().await).unwrap_or_default();
return Err(CodexErr::UnexpectedStatus(status, body));
}
if attempt > *OPENAI_REQUEST_MAX_RETRIES {
return Err(CodexErr::RetryLimit(status));
}
let retry_after_secs = res
.headers()
.get(reqwest::header::RETRY_AFTER)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok());
let delay = retry_after_secs
.map(|s| Duration::from_millis(s * 1_000))
.unwrap_or_else(|| backoff(attempt));
tokio::time::sleep(delay).await;
}
Err(e) => {
if attempt > *OPENAI_REQUEST_MAX_RETRIES {
return Err(e.into());
}
let delay = backoff(attempt);
tokio::time::sleep(delay).await;
}
}
}
}
/// Lightweight SSE processor for the Chat Completions streaming format. The
/// output is mapped onto Codex's internal [`ResponseEvent`] so that the rest
/// of the pipeline can stay agnostic of the underlying wire format.
async fn process_chat_sse<S>(stream: S, tx_event: mpsc::Sender<Result<ResponseEvent>>)
where
S: Stream<Item = Result<Bytes>> + Unpin,
{
let mut stream = stream.eventsource();
let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS;
loop {
let sse = match timeout(idle_timeout, stream.next()).await {
Ok(Some(Ok(ev))) => ev,
Ok(Some(Err(e))) => {
let _ = tx_event.send(Err(CodexErr::Stream(e.to_string()))).await;
return;
}
Ok(None) => {
// Stream closed gracefully; emit Completed with a dummy id.
let _ = tx_event
.send(Ok(ResponseEvent::Completed {
response_id: String::new(),
}))
.await;
return;
}
Err(_) => {
let _ = tx_event
.send(Err(CodexErr::Stream("idle timeout waiting for SSE".into())))
.await;
return;
}
};
// OpenAI Chat streaming sends a literal string "[DONE]" when finished.
if sse.data.trim() == "[DONE]" {
let _ = tx_event
.send(Ok(ResponseEvent::Completed {
response_id: String::new(),
}))
.await;
return;
}
// Parse JSON chunk
let chunk: serde_json::Value = match serde_json::from_str(&sse.data) {
Ok(v) => v,
Err(_) => continue,
};
let content_opt = chunk
.get("choices")
.and_then(|c| c.get(0))
.and_then(|c| c.get("delta"))
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str());
if let Some(content) = content_opt {
let item = ResponseItem::Message {
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: content.to_string(),
}],
};
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
}
}
/// Optional client-side aggregation helper
///
/// Stream adapter that merges the incremental `OutputItemDone` chunks coming from
/// [`process_chat_sse`] into a *running* assistant message, **suppressing the
/// per-token deltas**. The stream stays silent while the model is thinking
/// and only emits two events per turn:
///
/// 1. `ResponseEvent::OutputItemDone` with the *complete* assistant message
/// (fully concatenated).
/// 2. The original `ResponseEvent::Completed` right after it.
///
/// This mirrors the behaviour the TypeScript CLI exposes to its higher layers.
///
/// The adapter is intentionally *lossless*: callers who do **not** opt in via
/// [`AggregateStreamExt::aggregate()`] keep receiving the original unmodified
/// events.
pub(crate) struct AggregatedChatStream<S> {
inner: S,
cumulative: String,
pending_completed: Option<ResponseEvent>,
}
impl<S> Stream for AggregatedChatStream<S>
where
S: Stream<Item = Result<ResponseEvent>> + Unpin,
{
type Item = Result<ResponseEvent>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
// First, flush any buffered Completed event from the previous call.
if let Some(ev) = this.pending_completed.take() {
return Poll::Ready(Some(Ok(ev)));
}
loop {
match Pin::new(&mut this.inner).poll_next(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => return Poll::Ready(None),
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
// Accumulate *assistant* text but do not emit yet.
if let crate::models::ResponseItem::Message { role, content } = &item {
if role == "assistant" {
if let Some(text) = content.iter().find_map(|c| match c {
crate::models::ContentItem::OutputText { text } => Some(text),
_ => None,
}) {
this.cumulative.push_str(text);
}
}
}
// Swallow partial event; keep polling.
continue;
}
Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id }))) => {
if !this.cumulative.is_empty() {
let aggregated_item = crate::models::ResponseItem::Message {
role: "assistant".to_string(),
content: vec![crate::models::ContentItem::OutputText {
text: std::mem::take(&mut this.cumulative),
}],
};
// Buffer Completed so it is returned *after* the aggregated message.
this.pending_completed = Some(ResponseEvent::Completed { response_id });
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
aggregated_item,
))));
}
// Nothing aggregated; forward Completed directly.
return Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id })));
} // No other `Ok` variants exist at the moment; continue polling.
}
}
}
}
/// Extension trait that activates aggregation on any stream of [`ResponseEvent`].
pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Sized {
/// Returns a new stream that emits **only** the final assistant message
/// per turn instead of every incremental delta. The produced
/// `ResponseEvent` sequence for a typical text turn looks like:
///
/// ```ignore
/// OutputItemDone(<full message>)
/// Completed { .. }
/// ```
///
/// No other `OutputItemDone` events will be seen by the caller.
///
/// Usage:
///
/// ```ignore
/// let agg_stream = client.stream(&prompt).await?.aggregate();
/// while let Some(event) = agg_stream.next().await {
/// // event now contains cumulative text
/// }
/// ```
fn aggregate(self) -> AggregatedChatStream<Self> {
AggregatedChatStream {
inner: self,
cumulative: String::new(),
pending_completed: None,
}
}
}
impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}

@@ -1,11 +1,7 @@
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::io::BufRead;
use std::path::Path;
use std::pin::Pin;
use std::sync::LazyLock;
use std::task::Context;
use std::task::Poll;
use std::time::Duration;
use bytes::Bytes;
@@ -23,66 +19,22 @@ use tracing::debug;
use tracing::trace;
use tracing::warn;
use crate::chat_completions::stream_chat_completions;
use crate::client_common::Payload;
use crate::client_common::Prompt;
use crate::client_common::Reasoning;
use crate::client_common::ResponseEvent;
use crate::client_common::ResponseStream;
use crate::error::CodexErr;
use crate::error::Result;
use crate::flags::CODEX_RS_SSE_FIXTURE;
use crate::flags::OPENAI_REQUEST_MAX_RETRIES;
use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::WireApi;
use crate::models::ResponseItem;
use crate::util::backoff;
/// API request payload for a single model turn.
#[derive(Default, Debug, Clone)]
pub struct Prompt {
/// Conversation context input items.
pub input: Vec<ResponseItem>,
/// Optional previous response ID (when storage is enabled).
pub prev_id: Option<String>,
/// Optional initial instructions (only sent on first turn).
pub instructions: Option<String>,
/// Whether to store response on server side (disable_response_storage = !store).
pub store: bool,
/// Additional tools sourced from external MCP servers. Note each key is
/// the "fully qualified" tool name (i.e., prefixed with the server name),
/// which should be reported to the model in place of Tool::name.
pub extra_tools: HashMap<String, mcp_types::Tool>,
}
#[derive(Debug)]
pub enum ResponseEvent {
OutputItemDone(ResponseItem),
Completed { response_id: String },
}
#[derive(Debug, Serialize)]
struct Payload<'a> {
model: &'a str,
#[serde(skip_serializing_if = "Option::is_none")]
instructions: Option<&'a String>,
// TODO(mbolin): ResponseItem::Other should not be serialized. Currently,
// we code defensively to avoid this case, but perhaps we should use a
// separate enum for serialization.
input: &'a Vec<ResponseItem>,
tools: &'a [serde_json::Value],
tool_choice: &'static str,
parallel_tool_calls: bool,
reasoning: Option<Reasoning>,
#[serde(skip_serializing_if = "Option::is_none")]
previous_response_id: Option<String>,
/// true when using the Responses API.
store: bool,
stream: bool,
}
#[derive(Debug, Serialize)]
struct Reasoning {
effort: &'static str,
#[serde(skip_serializing_if = "Option::is_none")]
generate_summary: Option<bool>,
}
/// When serialized as JSON, this produces a valid "Tool" in the OpenAI
/// Responses API.
#[derive(Debug, Serialize)]
@@ -152,7 +104,20 @@ impl ModelClient {
}
}
pub async fn stream(&mut self, prompt: &Prompt) -> Result<ResponseStream> {
/// Dispatches to either the Responses or Chat implementation depending on
/// the provider config. Public callers always invoke `stream()`; the
/// specialised helpers are private to avoid accidental misuse.
pub async fn stream(&self, prompt: &Prompt) -> Result<ResponseStream> {
match self.provider.wire_api {
WireApi::Responses => self.stream_responses(prompt).await,
WireApi::Chat => {
stream_chat_completions(prompt, &self.model, &self.client, &self.provider).await
}
}
}
/// Implementation for the OpenAI *Responses* experimental API.
async fn stream_responses(&self, prompt: &Prompt) -> Result<ResponseStream> {
if let Some(path) = &*CODEX_RS_SSE_FIXTURE {
// short circuit for tests
warn!(path, "Streaming from fixture");
@@ -202,8 +167,8 @@ impl ModelClient {
let api_key = self
.provider
.api_key()
.ok_or_else(|| crate::error::CodexErr::EnvVar("API_KEY"))?;
.api_key()?
.expect("Repsones API requires an API key");
let res = self
.client
.post(&url)
@@ -396,18 +361,6 @@ where
}
}
pub struct ResponseStream {
rx_event: mpsc::Receiver<Result<ResponseEvent>>,
}
impl Stream for ResponseStream {
type Item = Result<ResponseEvent>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.rx_event.poll_recv(cx)
}
}
/// used in tests to stream from a text SSE file
async fn stream_from_fixture(path: impl AsRef<Path>) -> Result<ResponseStream> {
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(16);

@@ -0,0 +1,72 @@
use crate::error::Result;
use crate::models::ResponseItem;
use futures::Stream;
use serde::Serialize;
use std::collections::HashMap;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use tokio::sync::mpsc;
/// API request payload for a single model turn.
#[derive(Default, Debug, Clone)]
pub struct Prompt {
/// Conversation context input items.
pub input: Vec<ResponseItem>,
/// Optional previous response ID (when storage is enabled).
pub prev_id: Option<String>,
/// Optional initial instructions (only sent on first turn).
pub instructions: Option<String>,
/// Whether to store response on server side (disable_response_storage = !store).
pub store: bool,
/// Additional tools sourced from external MCP servers. Note each key is
/// the "fully qualified" tool name (i.e., prefixed with the server name),
/// which should be reported to the model in place of Tool::name.
pub extra_tools: HashMap<String, mcp_types::Tool>,
}
#[derive(Debug)]
pub enum ResponseEvent {
OutputItemDone(ResponseItem),
Completed { response_id: String },
}
#[derive(Debug, Serialize)]
pub(crate) struct Reasoning {
pub(crate) effort: &'static str,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) generate_summary: Option<bool>,
}
#[derive(Debug, Serialize)]
pub(crate) struct Payload<'a> {
pub(crate) model: &'a str,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) instructions: Option<&'a String>,
// TODO(mbolin): ResponseItem::Other should not be serialized. Currently,
// we code defensively to avoid this case, but perhaps we should use a
// separate enum for serialization.
pub(crate) input: &'a Vec<ResponseItem>,
pub(crate) tools: &'a [serde_json::Value],
pub(crate) tool_choice: &'static str,
pub(crate) parallel_tool_calls: bool,
pub(crate) reasoning: Option<Reasoning>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) previous_response_id: Option<String>,
/// true when using the Responses API.
pub(crate) store: bool,
pub(crate) stream: bool,
}
pub(crate) struct ResponseStream {
pub(crate) rx_event: mpsc::Receiver<Result<ResponseEvent>>,
}
impl Stream for ResponseStream {
type Item = Result<ResponseEvent>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.rx_event.poll_recv(cx)
}
}

@@ -31,10 +31,13 @@ use tracing::info;
use tracing::trace;
use tracing::warn;
use crate::WireApi;
use crate::chat_completions::AggregateStreamExt;
use crate::client::ModelClient;
use crate::client::Prompt;
use crate::client::ResponseEvent;
use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
use crate::config::Config;
use crate::conversation_history::ConversationHistory;
use crate::error::CodexErr;
use crate::error::Result as CodexResult;
use crate::exec::ExecParams;
@@ -65,7 +68,6 @@ use crate::safety::assess_command_safety;
use crate::safety::assess_patch_safety;
use crate::user_notification::UserNotification;
use crate::util::backoff;
use crate::zdr_transcript::ZdrTranscript;
/// The high-level interface to the Codex system.
/// It operates as a queue pair where you send submissions and receive events.
@@ -181,7 +183,7 @@ struct State {
previous_response_id: Option<String>,
pending_approvals: HashMap<String, oneshot::Sender<ReviewDecision>>,
pending_input: Vec<ResponseInputItem>,
zdr_transcript: Option<ZdrTranscript>,
zdr_transcript: Option<ConversationHistory>,
}
impl Session {
@@ -416,11 +418,15 @@ impl Drop for Session {
}
impl State {
pub fn partial_clone(&self) -> Self {
pub fn partial_clone(&self, retain_zdr_transcript: bool) -> Self {
Self {
approved_commands: self.approved_commands.clone(),
previous_response_id: self.previous_response_id.clone(),
zdr_transcript: self.zdr_transcript.clone(),
zdr_transcript: if retain_zdr_transcript {
self.zdr_transcript.clone()
} else {
None
},
..Default::default()
}
}
@@ -534,14 +540,19 @@ async fn submission_loop(
let client = ModelClient::new(model.clone(), provider.clone());
// abort any current running session and clone its state
let retain_zdr_transcript =
record_conversation_history(disable_response_storage, provider.wire_api);
let state = match sess.take() {
Some(sess) => {
sess.abort();
sess.state.lock().unwrap().partial_clone()
sess.state
.lock()
.unwrap()
.partial_clone(retain_zdr_transcript)
}
None => State {
zdr_transcript: if disable_response_storage {
Some(ZdrTranscript::new())
zdr_transcript: if retain_zdr_transcript {
Some(ConversationHistory::new())
} else {
None
},
@@ -670,21 +681,35 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
let pending_input = sess.get_pending_input().into_iter().map(ResponseItem::from);
net_new_turn_input.extend(pending_input);
// Persist only the net-new items of this turn to the rollout.
sess.record_rollout_items(&net_new_turn_input).await;
// Construct the input that we will send to the model. When using the
// Chat completions API (or ZDR clients), the model needs the full
// conversation history on each turn. The rollout file, however, should
// only record the new items that originated in this turn so that it
// represents an append-only log without duplicates.
let turn_input: Vec<ResponseItem> =
if let Some(transcript) = sess.state.lock().unwrap().zdr_transcript.as_mut() {
// If we are using ZDR, we need to send the transcript with every turn.
let mut full_transcript = transcript.contents();
full_transcript.extend(net_new_turn_input.clone());
// If we are using Chat/ZDR, we need to send the transcript with every turn.
// 1. Build up the conversation history for the next turn.
let full_transcript = [transcript.contents(), net_new_turn_input.clone()].concat();
// 2. Update the in-memory transcript so that future turns
// include these items as part of the history.
transcript.record_items(net_new_turn_input);
// Note that `transcript.record_items()` does some filtering
// such that `full_transcript` may include items that were
// excluded from `transcript`.
full_transcript
} else {
// Responses API path: we can just send the new items and
// record the same.
net_new_turn_input
};
// Persist the input part of the turn to the rollout (user messages /
// function_call_output from previous step).
sess.record_rollout_items(&turn_input).await;
let turn_input_messages: Vec<String> = turn_input
.iter()
.filter_map(|item| match item {
@@ -794,6 +819,7 @@ async fn run_turn(
match try_run_turn(sess, &sub_id, &prompt).await {
Ok(output) => return Ok(output),
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
Err(e) => {
if retries < *OPENAI_STREAM_MAX_RETRIES {
retries += 1;
@@ -838,7 +864,7 @@ async fn try_run_turn(
sub_id: &str,
prompt: &Prompt,
) -> CodexResult<Vec<ProcessedResponseItem>> {
let mut stream = sess.client.clone().stream(prompt).await?;
let mut stream = sess.client.clone().stream(prompt).await?.aggregate();
// Buffer all the incoming messages from the stream first, then execute them.
// If we execute a function call in the middle of handling the stream, it can time out.
@@ -1612,3 +1638,15 @@ fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<St
}
})
}
/// See [`ConversationHistory`] for details.
fn record_conversation_history(disable_response_storage: bool, wire_api: WireApi) -> bool {
if disable_response_storage {
return true;
}
match wire_api {
WireApi::Responses => false,
WireApi::Chat => true,
}
}

@@ -21,6 +21,9 @@ pub struct Config {
/// Optional override of model selection.
pub model: String,
/// Key into the model_providers map that specifies which provider to use.
pub model_provider_id: String,
/// Info needed to make an API request to the model.
pub model_provider: ModelProviderInfo,
@@ -219,21 +222,22 @@ impl Config {
model_providers.entry(key).or_insert(provider);
}
let model_provider_name = provider
let model_provider_id = provider
.or(cfg.model_provider)
.unwrap_or_else(|| "openai".to_string());
let model_provider = model_providers
.get(&model_provider_name)
.get(&model_provider_id)
.ok_or_else(|| {
std::io::Error::new(
std::io::ErrorKind::NotFound,
format!("Model provider `{model_provider_name}` not found"),
format!("Model provider `{model_provider_id}` not found"),
)
})?
.clone();
let config = Self {
model: model.or(cfg.model).unwrap_or_else(default_model),
model_provider_id,
model_provider,
cwd: cwd.map_or_else(
|| {

@@ -1,16 +1,18 @@
use crate::models::ResponseItem;
/// Transcript that needs to be maintained for ZDR clients for which
/// previous_response_id is not available, so we must include the transcript
/// with every API call. This must include each `function_call` and its
/// corresponding `function_call_output`.
/// Transcript of conversation history that is needed:
/// - for ZDR clients for which previous_response_id is not available, so we
/// must include the transcript with every API call. This must include each
/// `function_call` and its corresponding `function_call_output`.
/// - for clients using the "chat completions" API as opposed to the
/// "responses" API.
#[derive(Debug, Clone)]
pub(crate) struct ZdrTranscript {
pub(crate) struct ConversationHistory {
/// The oldest items are at the beginning of the vector.
items: Vec<ResponseItem>,
}
impl ZdrTranscript {
impl ConversationHistory {
pub(crate) fn new() -> Self {
Self { items: Vec::new() }
}

@@ -55,7 +55,7 @@ pub enum CodexErr {
/// Returned by run_command_stream when the user pressed CtrlC (SIGINT). Session uses this to
/// surface a polite FunctionCallOutput back to the model instead of crashing the CLI.
#[error("interrupted (CtrlC)")]
#[error("interrupted (Ctrl-C)")]
Interrupted,
/// Unexpected HTTP status code.
@@ -97,8 +97,28 @@ pub enum CodexErr {
#[error(transparent)]
TokioJoin(#[from] JoinError),
#[error("missing environment variable {0}")]
EnvVar(&'static str),
#[error("{0}")]
EnvVar(EnvVarError),
}
#[derive(Debug)]
pub struct EnvVarError {
/// Name of the environment variable that is missing.
pub var: String,
/// Optional instructions to help the user get a valid value for the
/// variable and set it.
pub instructions: Option<String>,
}
impl std::fmt::Display for EnvVarError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Missing environment variable: `{}`.", self.var)?;
if let Some(instructions) = &self.instructions {
write!(f, " {instructions}")?;
}
Ok(())
}
}
impl CodexErr {

@@ -5,11 +5,15 @@
// the TUI or the tracing stack).
#![deny(clippy::print_stdout, clippy::print_stderr)]
mod chat_completions;
mod client;
mod client_common;
pub mod codex;
pub use codex::Codex;
pub mod codex_wrapper;
pub mod config;
mod conversation_history;
pub mod error;
pub mod exec;
mod flags;
@@ -21,10 +25,10 @@ pub mod mcp_server_config;
mod mcp_tool_call;
mod model_provider_info;
pub use model_provider_info::ModelProviderInfo;
pub use model_provider_info::WireApi;
mod models;
pub mod protocol;
mod rollout;
mod safety;
mod user_notification;
pub mod util;
mod zdr_transcript;

@@ -8,6 +8,25 @@
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::env::VarError;
use crate::error::EnvVarError;
/// Wire protocol that the provider speaks. Most third-party services only
/// implement the classic OpenAI Chat Completions JSON schema, whereas OpenAI
/// itself (and a handful of others) additionally exposes the more modern
/// *Responses* API. The two protocols use different request/response shapes
/// and *cannot* be auto-detected at runtime, so each provider entry
/// must declare which one it expects.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum WireApi {
/// The experimental “Responses” API exposed by OpenAI at `/v1/responses`.
#[default]
Responses,
/// Regular Chat Completions compatible with `/v1/chat/completions`.
Chat,
}
/// Serializable representation of a provider definition.
#[derive(Debug, Clone, Deserialize, Serialize)]
@@ -17,13 +36,38 @@ pub struct ModelProviderInfo {
/// Base URL for the provider's OpenAI-compatible API.
pub base_url: String,
/// Environment variable that stores the user's API key for this provider.
pub env_key: String,
pub env_key: Option<String>,
/// Optional instructions to help the user get a valid value for the
/// variable and set it.
pub env_key_instructions: Option<String>,
/// Which wire protocol this provider expects.
pub wire_api: WireApi,
}
impl ModelProviderInfo {
/// Returns the API key for this provider if present in the environment.
pub fn api_key(&self) -> Option<String> {
std::env::var(&self.env_key).ok()
/// If `env_key` is Some, returns the API key for this provider if present
/// (and non-empty) in the environment. If `env_key` is required but
/// cannot be found, returns an error.
pub fn api_key(&self) -> crate::error::Result<Option<String>> {
match &self.env_key {
Some(env_key) => std::env::var(env_key)
.and_then(|v| {
if v.trim().is_empty() {
Err(VarError::NotPresent)
} else {
Ok(Some(v))
}
})
.map_err(|_| {
crate::error::CodexErr::EnvVar(EnvVarError {
var: env_key.clone(),
instructions: self.env_key_instructions.clone(),
})
}),
None => Ok(None),
}
}
}
@@ -37,7 +81,9 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
P {
name: "OpenAI".into(),
base_url: "https://api.openai.com/v1".into(),
env_key: "OPENAI_API_KEY".into(),
env_key: Some("OPENAI_API_KEY".into()),
env_key_instructions: Some("Create an API key (https://platform.openai.com) and export it as an environment variable.".into()),
wire_api: WireApi::Responses,
},
),
(
@@ -45,7 +91,9 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
P {
name: "OpenRouter".into(),
base_url: "https://openrouter.ai/api/v1".into(),
env_key: "OPENROUTER_API_KEY".into(),
env_key: Some("OPENROUTER_API_KEY".into()),
env_key_instructions: None,
wire_api: WireApi::Chat,
},
),
(
@@ -53,7 +101,9 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
P {
name: "Gemini".into(),
base_url: "https://generativelanguage.googleapis.com/v1beta/openai".into(),
env_key: "GEMINI_API_KEY".into(),
env_key: Some("GEMINI_API_KEY".into()),
env_key_instructions: None,
wire_api: WireApi::Chat,
},
),
(
@@ -61,7 +111,9 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
P {
name: "Ollama".into(),
base_url: "http://localhost:11434/v1".into(),
env_key: "OLLAMA_API_KEY".into(),
env_key: None,
env_key_instructions: None,
wire_api: WireApi::Chat,
},
),
(
@@ -69,7 +121,9 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
P {
name: "Mistral".into(),
base_url: "https://api.mistral.ai/v1".into(),
env_key: "MISTRAL_API_KEY".into(),
env_key: Some("MISTRAL_API_KEY".into()),
env_key_instructions: None,
wire_api: WireApi::Chat,
},
),
(
@@ -77,7 +131,9 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
P {
name: "DeepSeek".into(),
base_url: "https://api.deepseek.com".into(),
env_key: "DEEPSEEK_API_KEY".into(),
env_key: Some("DEEPSEEK_API_KEY".into()),
env_key_instructions: None,
wire_api: WireApi::Chat,
},
),
(
@@ -85,7 +141,9 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
P {
name: "xAI".into(),
base_url: "https://api.x.ai/v1".into(),
env_key: "XAI_API_KEY".into(),
env_key: Some("XAI_API_KEY".into()),
env_key_instructions: None,
wire_api: WireApi::Chat,
},
),
(
@@ -93,7 +151,9 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
P {
name: "Groq".into(),
base_url: "https://api.groq.com/openai/v1".into(),
env_key: "GROQ_API_KEY".into(),
env_key: Some("GROQ_API_KEY".into()),
env_key_instructions: None,
wire_api: WireApi::Chat,
},
),
]

@@ -116,10 +116,10 @@ pub struct ShellToolCallParams {
pub timeout_ms: Option<u64>,
}
#[expect(dead_code)]
#[derive(Deserialize, Debug, Clone)]
pub struct FunctionCallOutputPayload {
pub content: String,
#[expect(dead_code)]
pub success: Option<bool>,
}

@@ -25,6 +25,7 @@ pub struct Submission {
/// Submission operation
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[allow(clippy::large_enum_variant)]
#[non_exhaustive]
pub enum Op {
/// Configure the model session.

@@ -92,7 +92,9 @@ async fn keeps_previous_response_id_between_tasks() {
// Environment variable that should exist in the test environment.
// ModelClient will return an error if the environment variable for the
// provider is not set.
env_key: "PATH".into(),
env_key: Some("PATH".into()),
env_key_instructions: None,
wire_api: codex_core::WireApi::Responses,
};
// Init session

@@ -82,7 +82,9 @@ async fn retries_on_early_close() {
// Environment variable that should exist in the test environment.
// ModelClient will return an error if the environment variable for the
// provider is not set.
env_key: "PATH".into(),
env_key: Some("PATH".into()),
env_key_instructions: None,
wire_api: codex_core::WireApi::Responses,
};
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());