feat: support the chat completions API in the Rust CLI (#862)

This is a substantial PR to add support for the chat completions API,
which in turn makes it possible to use non-OpenAI model providers (just
like in the TypeScript CLI):

* It moves a number of structs from `client.rs` to `client_common.rs` so
they can be shared.
* It introduces support for the chat completions API in
`chat_completions.rs`.
* It updates `ModelProviderInfo` so that `env_key` is `Option<String>`
instead of `String` (e.g., for ollama) and adds a `wire_api` field.
* It updates `client.rs` to choose between `stream_responses()` and
`stream_chat_completions()` based on the `wire_api` for the
`ModelProviderInfo`
* It updates the `exec` and TUI CLIs to no longer fail if the
`OPENAI_API_KEY` environment variable is not set
* It updates the TUI so that `EventMsg::Error` is displayed more
prominently when it occurs, particularly now that it is important to
alert users to the `CodexErr::EnvVar` variant.
* `CodexErr::EnvVar` was updated to include an optional `instructions`
field so we can preserve the behavior where we direct users to
https://platform.openai.com if `OPENAI_API_KEY` is not set.
* Cleaned up the "welcome message" in the TUI to ensure the model
provider is displayed.
* Updated the docs in `codex-rs/README.md`.

To exercise the chat completions API from OpenAI models, I added the
following to my `config.toml`:

```toml
model = "gpt-4o"
model_provider = "openai-chat-completions"

[model_providers.openai-chat-completions]
name = "OpenAI using Chat Completions"
base_url = "https://api.openai.com/v1"
env_key = "OPENAI_API_KEY"
wire_api = "chat"
```

To test a non-OpenAI provider, however, I installed ollama with mistral
locally on my Mac because ChatGPT said that would be a good match for my
hardware:

```shell
brew install ollama
ollama serve
ollama pull mistral
```

Then I added the following to my `~/.codex/config.toml`:

```toml
model = "mistral"
model_provider = "ollama"
```

Note this code could certainly use more test coverage, but I want to get
this in so folks can start playing with it.

For reference, I believe https://github.com/openai/codex/pull/247 was
roughly the comparable PR on the TypeScript side.
This commit is contained in:
Michael Bolin
2025-05-08 21:46:06 -07:00
committed by GitHub
parent a538e6acb2
commit e924070cee
20 changed files with 703 additions and 200 deletions

View File

@@ -1,11 +1,7 @@
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::io::BufRead;
use std::path::Path;
use std::pin::Pin;
use std::sync::LazyLock;
use std::task::Context;
use std::task::Poll;
use std::time::Duration;
use bytes::Bytes;
@@ -23,66 +19,22 @@ use tracing::debug;
use tracing::trace;
use tracing::warn;
use crate::chat_completions::stream_chat_completions;
use crate::client_common::Payload;
use crate::client_common::Prompt;
use crate::client_common::Reasoning;
use crate::client_common::ResponseEvent;
use crate::client_common::ResponseStream;
use crate::error::CodexErr;
use crate::error::Result;
use crate::flags::CODEX_RS_SSE_FIXTURE;
use crate::flags::OPENAI_REQUEST_MAX_RETRIES;
use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::WireApi;
use crate::models::ResponseItem;
use crate::util::backoff;
/// API request payload for a single model turn.
///
/// The type derives `Default`, so a default-constructed `Prompt` has an empty
/// `input`, no `prev_id`, no `instructions`, `store == false`, and an empty
/// `extra_tools` map.
#[derive(Default, Debug, Clone)]
pub struct Prompt {
/// Conversation context input items.
pub input: Vec<ResponseItem>,
/// Optional previous response ID (when storage is enabled).
pub prev_id: Option<String>,
/// Optional initial instructions (only sent on first turn).
pub instructions: Option<String>,
/// Whether to store the response on the server side. This is the inverse of
/// the `disable_response_storage` setting (disable_response_storage = !store).
pub store: bool,
/// Additional tools sourced from external MCP servers. Note each key is
/// the "fully qualified" tool name (i.e., prefixed with the server name),
/// which should be reported to the model in place of Tool::name.
pub extra_tools: HashMap<String, mcp_types::Tool>,
}
/// A single event produced while streaming a model response.
///
/// `ResponseStream` yields `Result<ResponseEvent>` items, so values of this
/// type are what consumers of that stream receive on success.
#[derive(Debug)]
pub enum ResponseEvent {
/// Carries one `ResponseItem`. Going by the variant name, this is emitted
/// once an output item is complete (the emitting code is not visible here).
OutputItemDone(ResponseItem),
/// Carries the ID of the response as a `String`.
/// NOTE(review): presumably signals the end of the response and the ID is
/// what a later `Prompt::prev_id` would refer to; confirm against the caller.
Completed { response_id: String },
}
/// Serializable request body that borrows most of its data for lifetime `'a`.
///
/// NOTE(review): presumably this is the JSON body posted to the model
/// provider; confirm against the call site, which is not visible here.
/// Field order and the serde attributes below determine the serialized
/// output, so they should not be reordered casually.
#[derive(Debug, Serialize)]
struct Payload<'a> {
model: &'a str,
// Left out of the serialized output entirely when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
instructions: Option<&'a String>,
// TODO(mbolin): ResponseItem::Other should not be serialized. Currently,
// we code defensively to avoid this case, but perhaps we should use a
// separate enum for serialization.
input: &'a Vec<ResponseItem>,
// Tool definitions, already converted to raw JSON values.
tools: &'a [serde_json::Value],
tool_choice: &'static str,
parallel_tool_calls: bool,
// No `skip_serializing_if` here, so a `None` is still written out
// (serde's JSON serializer emits it as `null`).
reasoning: Option<Reasoning>,
// Left out of the serialized output entirely when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
previous_response_id: Option<String>,
/// true when using the Responses API.
store: bool,
stream: bool,
}
/// Serializable reasoning options; used as the type of `Payload::reasoning`.
#[derive(Debug, Serialize)]
struct Reasoning {
/// Effort level as a static string. The set of accepted values is not
/// visible in this part of the file.
effort: &'static str,
/// Left out of the serialized output entirely when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
generate_summary: Option<bool>,
}
/// When serialized as JSON, this produces a valid "Tool" in the OpenAI
/// Responses API.
#[derive(Debug, Serialize)]
@@ -152,7 +104,20 @@ impl ModelClient {
}
}
pub async fn stream(&mut self, prompt: &Prompt) -> Result<ResponseStream> {
/// Dispatches to either the Responses or Chat implementation depending on
/// the provider config. Public callers always invoke `stream()`; the
/// specialised helpers are private to avoid accidental misuse.
///
/// The selection is made on `self.provider.wire_api`, and the result of the
/// chosen implementation (a `ResponseStream` or an error) is returned as-is.
pub async fn stream(&self, prompt: &Prompt) -> Result<ResponseStream> {
match self.provider.wire_api {
// Responses wire API: handled by the private method on this client.
WireApi::Responses => self.stream_responses(prompt).await,
// Chat wire API: handled by the free function imported from
// `crate::chat_completions`, which is handed this client's model,
// HTTP client and provider info.
WireApi::Chat => {
stream_chat_completions(prompt, &self.model, &self.client, &self.provider).await
}
}
}
/// Implementation for the OpenAI *Responses* experimental API.
async fn stream_responses(&self, prompt: &Prompt) -> Result<ResponseStream> {
if let Some(path) = &*CODEX_RS_SSE_FIXTURE {
// short circuit for tests
warn!(path, "Streaming from fixture");
@@ -202,8 +167,8 @@ impl ModelClient {
let api_key = self
.provider
.api_key()
.ok_or_else(|| crate::error::CodexErr::EnvVar("API_KEY"))?;
.api_key()?
.expect("Repsones API requires an API key");
let res = self
.client
.post(&url)
@@ -396,18 +361,6 @@ where
}
}
/// Stream of `Result<ResponseEvent>` items backed by an `mpsc` receiver.
pub struct ResponseStream {
/// Receiving half of the channel from which events (or errors) are read.
rx_event: mpsc::Receiver<Result<ResponseEvent>>,
}
/// `Stream` implementation that forwards polling to the inner receiver.
impl Stream for ResponseStream {
type Item = Result<ResponseEvent>;
/// Polls the wrapped receiver with `poll_recv` and returns its result
/// unchanged, so this stream yields exactly what the channel yields.
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.rx_event.poll_recv(cx)
}
}
/// used in tests to stream from a text SSE file
async fn stream_from_fixture(path: impl AsRef<Path>) -> Result<ResponseStream> {
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(16);