feat: support the chat completions API in the Rust CLI (#862)

This is a substantial PR to add support for the chat completions API,
which in turn makes it possible to use non-OpenAI model providers (just
like in the TypeScript CLI):

* It moves a number of structs from `client.rs` to `client_common.rs` so
they can be shared.
* It introduces support for the chat completions API in
`chat_completions.rs`.
* It updates `ModelProviderInfo` so that `env_key` is `Option<String>`
instead of `String` (needed for providers such as ollama that require no
API key) and adds a `wire_api` field (see the sketch after this list).
* It updates `client.rs` to choose between `stream_responses()` and
`stream_chat_completions()` based on the `wire_api` of the
`ModelProviderInfo`.
* It updates the `exec` and TUI CLIs to no longer fail if the
`OPENAI_API_KEY` environment variable is not set.
* It updates the TUI so that `EventMsg::Error` is displayed more
prominently when it occurs, particularly now that it is important to
alert users to the `CodexErr::EnvVar` variant.
* It updates `CodexErr::EnvVar` to include an optional `instructions`
field so we can preserve the behavior of directing users to
https://platform.openai.com if `OPENAI_API_KEY` is not set.
* It cleans up the "welcome message" in the TUI to ensure the model
provider is displayed.
* It updates the docs in `codex-rs/README.md`.
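
For context, here is a rough sketch of what the updated provider
definition looks like after this change. The field names follow the PR
description and the config examples below; the exact definition in the
source may differ:

```rust
use serde::Deserialize;

/// Which wire protocol a provider speaks.
#[derive(Clone, Copy, Debug, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum WireApi {
    /// OpenAI's Responses API.
    Responses,
    /// The widely implemented chat completions API.
    Chat,
}

#[derive(Clone, Debug, Deserialize)]
pub struct ModelProviderInfo {
    pub name: String,
    pub base_url: String,
    /// `None` for providers such as ollama that need no API key.
    pub env_key: Option<String>,
    /// Selects `stream_responses()` vs. `stream_chat_completions()`.
    pub wire_api: WireApi,
}
```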

To exercise the chat completions API with OpenAI models, I added the
following to my `~/.codex/config.toml`:

```toml
model = "gpt-4o"
model_provider = "openai-chat-completions"

[model_providers.openai-chat-completions]
name = "OpenAI using Chat Completions"
base_url = "https://api.openai.com/v1"
env_key = "OPENAI_API_KEY"
wire_api = "chat"
```

To test a non-OpenAI provider, I installed ollama with the mistral model
locally on my Mac (ChatGPT said that would be a good match for my
hardware):

```shell
brew install ollama
ollama serve
ollama pull mistral
```

Then I added the following to my `~/.codex/config.toml`:

```toml
model = "mistral"
model_provider = "ollama"
```
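
The built-in `ollama` provider is presumably equivalent to spelling it
out yourself along these lines (Ollama serves an OpenAI-compatible API
on port 11434 and needs no API key, hence no `env_key`; the exact
built-in definition may differ):

```toml
[model_providers.ollama]
name = "Ollama"
base_url = "http://localhost:11434/v1"
wire_api = "chat"
```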

Note that this code could certainly use more test coverage, but I want
to get this in so folks can start playing with it.

For reference, I believe https://github.com/openai/codex/pull/247 was
roughly the comparable PR on the TypeScript side.
commit e924070cee (parent a538e6acb2)
Author: Michael Bolin
Date: 2025-05-08 21:46:06 -07:00 (committed via GitHub)
20 changed files with 703 additions and 200 deletions

```diff
@@ -31,10 +31,13 @@ use tracing::info;
 use tracing::trace;
 use tracing::warn;
 
+use crate::WireApi;
+use crate::chat_completions::AggregateStreamExt;
 use crate::client::ModelClient;
-use crate::client::Prompt;
-use crate::client::ResponseEvent;
+use crate::client_common::Prompt;
+use crate::client_common::ResponseEvent;
 use crate::config::Config;
+use crate::conversation_history::ConversationHistory;
 use crate::error::CodexErr;
 use crate::error::Result as CodexResult;
 use crate::exec::ExecParams;
```
```diff
@@ -65,7 +68,6 @@ use crate::safety::assess_command_safety;
 use crate::safety::assess_patch_safety;
 use crate::user_notification::UserNotification;
 use crate::util::backoff;
-use crate::zdr_transcript::ZdrTranscript;
 
 /// The high-level interface to the Codex system.
 /// It operates as a queue pair where you send submissions and receive events.
```
```diff
@@ -181,7 +183,7 @@ struct State {
     previous_response_id: Option<String>,
     pending_approvals: HashMap<String, oneshot::Sender<ReviewDecision>>,
     pending_input: Vec<ResponseInputItem>,
-    zdr_transcript: Option<ZdrTranscript>,
+    zdr_transcript: Option<ConversationHistory>,
 }
 
 impl Session {
```
```diff
@@ -416,11 +418,15 @@ impl Drop for Session {
 }
 
 impl State {
-    pub fn partial_clone(&self) -> Self {
+    pub fn partial_clone(&self, retain_zdr_transcript: bool) -> Self {
         Self {
             approved_commands: self.approved_commands.clone(),
             previous_response_id: self.previous_response_id.clone(),
-            zdr_transcript: self.zdr_transcript.clone(),
+            zdr_transcript: if retain_zdr_transcript {
+                self.zdr_transcript.clone()
+            } else {
+                None
+            },
             ..Default::default()
         }
     }
```
```diff
@@ -534,14 +540,19 @@ async fn submission_loop(
         let client = ModelClient::new(model.clone(), provider.clone());
 
         // abort any current running session and clone its state
+        let retain_zdr_transcript =
+            record_conversation_history(disable_response_storage, provider.wire_api);
         let state = match sess.take() {
             Some(sess) => {
                 sess.abort();
-                sess.state.lock().unwrap().partial_clone()
+                sess.state
+                    .lock()
+                    .unwrap()
+                    .partial_clone(retain_zdr_transcript)
             }
             None => State {
-                zdr_transcript: if disable_response_storage {
-                    Some(ZdrTranscript::new())
+                zdr_transcript: if retain_zdr_transcript {
+                    Some(ConversationHistory::new())
                 } else {
                     None
                 },
```
```diff
@@ -670,21 +681,35 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
     let pending_input = sess.get_pending_input().into_iter().map(ResponseItem::from);
     net_new_turn_input.extend(pending_input);
 
-    // Persist only the net-new items of this turn to the rollout.
-    sess.record_rollout_items(&net_new_turn_input).await;
+    // Construct the input that we will send to the model. When using the
+    // Chat completions API (or ZDR clients), the model needs the full
+    // conversation history on each turn. The rollout file, however, should
+    // only record the new items that originated in this turn so that it
+    // represents an append-only log without duplicates.
     let turn_input: Vec<ResponseItem> =
         if let Some(transcript) = sess.state.lock().unwrap().zdr_transcript.as_mut() {
-            // If we are using ZDR, we need to send the transcript with every turn.
-            let mut full_transcript = transcript.contents();
-            full_transcript.extend(net_new_turn_input.clone());
+            // If we are using Chat/ZDR, we need to send the transcript with every turn.
+
+            // 1. Build up the conversation history for the next turn.
+            let full_transcript = [transcript.contents(), net_new_turn_input.clone()].concat();
+
+            // 2. Update the in-memory transcript so that future turns
+            //    include these items as part of the history.
             transcript.record_items(net_new_turn_input);
+
+            // Note that `transcript.record_items()` does some filtering
+            // such that `full_transcript` may include items that were
+            // excluded from `transcript`.
             full_transcript
         } else {
+            // Responses API path: we can just send the new items and
+            // record the same.
             net_new_turn_input
         };
 
+    // Persist the input part of the turn to the rollout (user messages /
+    // function_call_output from previous step).
+    sess.record_rollout_items(&turn_input).await;
+
     let turn_input_messages: Vec<String> = turn_input
         .iter()
         .filter_map(|item| match item {
```
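
The hunk above leans on the new `ConversationHistory` type (the renamed
`ZdrTranscript`). A minimal sketch of the interface this diff assumes,
inferred from the call sites rather than copied from the source:

```rust
use crate::models::ResponseItem; // assumed module path

/// Conversation items that must be replayed to the model on every turn
/// when the server is not storing state for us (chat completions or ZDR).
#[derive(Clone, Default)]
pub struct ConversationHistory {
    items: Vec<ResponseItem>,
}

impl ConversationHistory {
    pub fn new() -> Self {
        Self::default()
    }

    /// History to prepend to the next turn's input.
    pub fn contents(&self) -> Vec<ResponseItem> {
        self.items.clone()
    }

    /// Appends this turn's items. The real implementation filters some
    /// item types here, which is why `full_transcript` above may contain
    /// items that never make it into the stored history.
    pub fn record_items(&mut self, items: Vec<ResponseItem>) {
        self.items.extend(items);
    }
}
```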
```diff
@@ -794,6 +819,7 @@ async fn run_turn(
         match try_run_turn(sess, &sub_id, &prompt).await {
             Ok(output) => return Ok(output),
             Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
+            Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
             Err(e) => {
                 if retries < *OPENAI_STREAM_MAX_RETRIES {
                     retries += 1;
```
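
`CodexErr::EnvVar` is deliberately excluded from the retry loop above: a
missing environment variable will not fix itself, so the error should
surface to the user immediately. Per the summary, the variant now
carries optional instructions; a rough sketch of the shape (the
`EnvVarError` name and its fields are assumptions based on the PR
description):

```rust
use thiserror::Error;

/// Payload for `CodexErr::EnvVar`. The optional `instructions` preserve
/// the behavior of pointing users at https://platform.openai.com when
/// `OPENAI_API_KEY` specifically is missing.
#[derive(Debug)]
pub struct EnvVarError {
    /// Name of the environment variable that was not set.
    pub var: String,
    /// Extra guidance to display to the user, if any.
    pub instructions: Option<String>,
}

#[derive(Debug, Error)]
pub enum CodexErr {
    #[error("missing environment variable: {}", .0.var)]
    EnvVar(EnvVarError),
    // ...other variants elided...
}
```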
```diff
@@ -838,7 +864,7 @@ async fn try_run_turn(
     sub_id: &str,
     prompt: &Prompt,
 ) -> CodexResult<Vec<ProcessedResponseItem>> {
-    let mut stream = sess.client.clone().stream(prompt).await?;
+    let mut stream = sess.client.clone().stream(prompt).await?.aggregate();
 
     // Buffer all the incoming messages from the stream first, then execute them.
     // If we execute a function call in the middle of handling the stream, it can time out.
```
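
The new `.aggregate()` call comes from the `AggregateStreamExt` trait
imported from `chat_completions.rs`. The chat completions API streams
assistant output as incremental deltas, while the surrounding code
expects complete items, so the adapter presumably buffers deltas and
yields whole messages. A sketch of the extension-trait pattern involved
(`AggregatedStream` is a hypothetical name; only the trait and the
`.aggregate()` call are confirmed by the diff):

```rust
use futures::Stream;

use crate::client_common::ResponseEvent;
use crate::error::Result;

/// Adapter that buffers streamed text deltas and yields complete
/// messages. The real adapter carries buffering state alongside `inner`.
pub struct AggregatedStream<S> {
    inner: S,
}

/// Extension trait so call sites can simply write `stream.aggregate()`.
pub trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Sized {
    fn aggregate(self) -> AggregatedStream<Self> {
        AggregatedStream { inner: self }
    }
}

/// Blanket impl: any compatible stream gets `.aggregate()` for free.
impl<S> AggregateStreamExt for S where S: Stream<Item = Result<ResponseEvent>> {}
```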
```diff
@@ -1612,3 +1638,15 @@ fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<St
         }
     })
 }
+
+/// See [`ConversationHistory`] for details.
+fn record_conversation_history(disable_response_storage: bool, wire_api: WireApi) -> bool {
+    if disable_response_storage {
+        return true;
+    }
+
+    match wire_api {
+        WireApi::Responses => false,
+        WireApi::Chat => true,
+    }
+}
```
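
Tying it together: per the summary, `client.rs` now picks the streaming
implementation based on the provider's `wire_api`. A hedged sketch of
that dispatch (the method names come from the PR description; the
`ResponseStream` type and surrounding signatures are assumptions):

```rust
impl ModelClient {
    pub async fn stream(&self, prompt: &Prompt) -> Result<ResponseStream> {
        match self.provider.wire_api {
            // OpenAI's stateful Responses API.
            WireApi::Responses => self.stream_responses(prompt).await,
            // Generic chat completions, e.g. ollama, or OpenAI with
            // `wire_api = "chat"` as in the config example above.
            WireApi::Chat => self.stream_chat_completions(prompt).await,
        }
    }
}
```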