Files
llmx/codex-rs/core/src/chat_completions.rs
Dylan 3e8bcf0247 [prompts] Add <environment_context> (#1869)
## Summary
Includes a new user message in the api payload which provides useful
environment context for the model, so it knows about things like the
current working directory and the sandbox.

## Testing
Updated unit tests
2025-08-06 01:13:31 -07:00

639 lines
25 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use std::time::Duration;
use bytes::Bytes;
use eventsource_stream::Eventsource;
use futures::Stream;
use futures::StreamExt;
use futures::TryStreamExt;
use reqwest::StatusCode;
use serde_json::json;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use tokio::sync::mpsc;
use tokio::time::timeout;
use tracing::debug;
use tracing::trace;
use crate::ModelProviderInfo;
use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
use crate::client_common::ResponseStream;
use crate::error::CodexErr;
use crate::error::Result;
use crate::model_family::ModelFamily;
use crate::models::ContentItem;
use crate::models::ReasoningItemContent;
use crate::models::ResponseItem;
use crate::openai_tools::create_tools_json_for_chat_completions_api;
use crate::util::backoff;
/// Implementation for the classic Chat Completions API.
pub(crate) async fn stream_chat_completions(
prompt: &Prompt,
model_family: &ModelFamily,
client: &reqwest::Client,
provider: &ModelProviderInfo,
) -> Result<ResponseStream> {
// Build messages array
let mut messages = Vec::<serde_json::Value>::new();
let full_instructions = prompt.get_full_instructions(model_family);
messages.push(json!({"role": "system", "content": full_instructions}));
let input = prompt.get_formatted_input();
for item in &input {
match item {
ResponseItem::Message { role, content, .. } => {
let mut text = String::new();
for c in content {
match c {
ContentItem::InputText { text: t }
| ContentItem::OutputText { text: t } => {
text.push_str(t);
}
_ => {}
}
}
messages.push(json!({"role": role, "content": text}));
}
ResponseItem::FunctionCall {
name,
arguments,
call_id,
..
} => {
messages.push(json!({
"role": "assistant",
"content": null,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {
"name": name,
"arguments": arguments,
}
}]
}));
}
ResponseItem::LocalShellCall {
id,
call_id: _,
status,
action,
} => {
// Confirm with API team.
messages.push(json!({
"role": "assistant",
"content": null,
"tool_calls": [{
"id": id.clone().unwrap_or_else(|| "".to_string()),
"type": "local_shell_call",
"status": status,
"action": action,
}]
}));
}
ResponseItem::FunctionCallOutput { call_id, output } => {
messages.push(json!({
"role": "tool",
"tool_call_id": call_id,
"content": output.content,
}));
}
ResponseItem::Reasoning { .. } | ResponseItem::Other => {
// Omit these items from the conversation history.
continue;
}
}
}
let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
let payload = json!({
"model": model_family.slug,
"messages": messages,
"stream": true,
"tools": tools_json,
});
debug!(
"POST to {}: {}",
provider.get_full_url(&None),
serde_json::to_string_pretty(&payload).unwrap_or_default()
);
let mut attempt = 0;
let max_retries = provider.request_max_retries();
loop {
attempt += 1;
let req_builder = provider.create_request_builder(client, &None).await?;
let res = req_builder
.header(reqwest::header::ACCEPT, "text/event-stream")
.json(&payload)
.send()
.await;
match res {
Ok(resp) if resp.status().is_success() => {
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
tokio::spawn(process_chat_sse(
stream,
tx_event,
provider.stream_idle_timeout(),
));
return Ok(ResponseStream { rx_event });
}
Ok(res) => {
let status = res.status();
if !(status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error()) {
let body = (res.text().await).unwrap_or_default();
return Err(CodexErr::UnexpectedStatus(status, body));
}
if attempt > max_retries {
return Err(CodexErr::RetryLimit(status));
}
let retry_after_secs = res
.headers()
.get(reqwest::header::RETRY_AFTER)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok());
let delay = retry_after_secs
.map(|s| Duration::from_millis(s * 1_000))
.unwrap_or_else(|| backoff(attempt));
tokio::time::sleep(delay).await;
}
Err(e) => {
if attempt > max_retries {
return Err(e.into());
}
let delay = backoff(attempt);
tokio::time::sleep(delay).await;
}
}
}
}
/// Lightweight SSE processor for the Chat Completions streaming format. The
/// output is mapped onto Codex's internal [`ResponseEvent`] so that the rest
/// of the pipeline can stay agnostic of the underlying wire format.
async fn process_chat_sse<S>(
stream: S,
tx_event: mpsc::Sender<Result<ResponseEvent>>,
idle_timeout: Duration,
) where
S: Stream<Item = Result<Bytes>> + Unpin,
{
let mut stream = stream.eventsource();
// State to accumulate a function call across streaming chunks.
// OpenAI may split the `arguments` string over multiple `delta` events
// until the chunk whose `finish_reason` is `tool_calls` is emitted. We
// keep collecting the pieces here and forward a single
// `ResponseItem::FunctionCall` once the call is complete.
#[derive(Default)]
struct FunctionCallState {
name: Option<String>,
arguments: String,
call_id: Option<String>,
active: bool,
}
let mut fn_call_state = FunctionCallState::default();
let mut assistant_text = String::new();
let mut reasoning_text = String::new();
loop {
let sse = match timeout(idle_timeout, stream.next()).await {
Ok(Some(Ok(ev))) => ev,
Ok(Some(Err(e))) => {
let _ = tx_event.send(Err(CodexErr::Stream(e.to_string()))).await;
return;
}
Ok(None) => {
// Stream closed gracefully emit Completed with dummy id.
let _ = tx_event
.send(Ok(ResponseEvent::Completed {
response_id: String::new(),
token_usage: None,
}))
.await;
return;
}
Err(_) => {
let _ = tx_event
.send(Err(CodexErr::Stream("idle timeout waiting for SSE".into())))
.await;
return;
}
};
// OpenAI Chat streaming sends a literal string "[DONE]" when finished.
if sse.data.trim() == "[DONE]" {
// Emit any finalized items before closing so downstream consumers receive
// terminal events for both assistant content and raw reasoning.
if !assistant_text.is_empty() {
let item = ResponseItem::Message {
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: std::mem::take(&mut assistant_text),
}],
id: None,
};
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
if !reasoning_text.is_empty() {
let item = ResponseItem::Reasoning {
id: String::new(),
summary: Vec::new(),
content: Some(vec![ReasoningItemContent::ReasoningText {
text: std::mem::take(&mut reasoning_text),
}]),
encrypted_content: None,
};
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
let _ = tx_event
.send(Ok(ResponseEvent::Completed {
response_id: String::new(),
token_usage: None,
}))
.await;
return;
}
// Parse JSON chunk
let chunk: serde_json::Value = match serde_json::from_str(&sse.data) {
Ok(v) => v,
Err(_) => continue,
};
trace!("chat_completions received SSE chunk: {chunk:?}");
let choice_opt = chunk.get("choices").and_then(|c| c.get(0));
if let Some(choice) = choice_opt {
// Handle assistant content tokens as streaming deltas.
if let Some(content) = choice
.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str())
{
if !content.is_empty() {
assistant_text.push_str(content);
let _ = tx_event
.send(Ok(ResponseEvent::OutputTextDelta(content.to_string())))
.await;
}
}
// Forward any reasoning/thinking deltas if present.
// Some providers stream `reasoning` as a plain string while others
// nest the text under an object (e.g. `{ "reasoning": { "text": "…" } }`).
if let Some(reasoning_val) = choice.get("delta").and_then(|d| d.get("reasoning")) {
let mut maybe_text = reasoning_val.as_str().map(|s| s.to_string());
if maybe_text.is_none() && reasoning_val.is_object() {
if let Some(s) = reasoning_val
.get("text")
.and_then(|t| t.as_str())
.filter(|s| !s.is_empty())
{
maybe_text = Some(s.to_string());
} else if let Some(s) = reasoning_val
.get("content")
.and_then(|t| t.as_str())
.filter(|s| !s.is_empty())
{
maybe_text = Some(s.to_string());
}
}
if let Some(reasoning) = maybe_text {
let _ = tx_event
.send(Ok(ResponseEvent::ReasoningContentDelta(reasoning)))
.await;
}
}
// Handle streaming function / tool calls.
if let Some(tool_calls) = choice
.get("delta")
.and_then(|d| d.get("tool_calls"))
.and_then(|tc| tc.as_array())
{
if let Some(tool_call) = tool_calls.first() {
// Mark that we have an active function call in progress.
fn_call_state.active = true;
// Extract call_id if present.
if let Some(id) = tool_call.get("id").and_then(|v| v.as_str()) {
fn_call_state.call_id.get_or_insert_with(|| id.to_string());
}
// Extract function details if present.
if let Some(function) = tool_call.get("function") {
if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
fn_call_state.name.get_or_insert_with(|| name.to_string());
}
if let Some(args_fragment) =
function.get("arguments").and_then(|a| a.as_str())
{
fn_call_state.arguments.push_str(args_fragment);
}
}
}
}
// Emit end-of-turn when finish_reason signals completion.
if let Some(finish_reason) = choice.get("finish_reason").and_then(|v| v.as_str()) {
match finish_reason {
"tool_calls" if fn_call_state.active => {
// First, flush the terminal raw reasoning so UIs can finalize
// the reasoning stream before any exec/tool events begin.
if !reasoning_text.is_empty() {
let item = ResponseItem::Reasoning {
id: String::new(),
summary: Vec::new(),
content: Some(vec![ReasoningItemContent::ReasoningText {
text: std::mem::take(&mut reasoning_text),
}]),
encrypted_content: None,
};
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
// Then emit the FunctionCall response item.
let item = ResponseItem::FunctionCall {
id: None,
name: fn_call_state.name.clone().unwrap_or_else(|| "".to_string()),
arguments: fn_call_state.arguments.clone(),
call_id: fn_call_state.call_id.clone().unwrap_or_else(String::new),
};
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
"stop" => {
// Regular turn without tool-call. Emit the final assistant message
// as a single OutputItemDone so non-delta consumers see the result.
if !assistant_text.is_empty() {
let item = ResponseItem::Message {
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: std::mem::take(&mut assistant_text),
}],
id: None,
};
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
// Also emit a terminal Reasoning item so UIs can finalize raw reasoning.
if !reasoning_text.is_empty() {
let item = ResponseItem::Reasoning {
id: String::new(),
summary: Vec::new(),
content: Some(vec![ReasoningItemContent::ReasoningText {
text: std::mem::take(&mut reasoning_text),
}]),
encrypted_content: None,
};
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
}
_ => {}
}
// Emit Completed regardless of reason so the agent can advance.
let _ = tx_event
.send(Ok(ResponseEvent::Completed {
response_id: String::new(),
token_usage: None,
}))
.await;
// Prepare for potential next turn (should not happen in same stream).
// fn_call_state = FunctionCallState::default();
return; // End processing for this SSE stream.
}
}
}
}
/// Optional client-side aggregation helper
///
/// Stream adapter that merges the incremental `OutputItemDone` chunks coming from
/// [`process_chat_sse`] into a *running* assistant message, **suppressing the
/// per-token deltas**. The stream stays silent while the model is thinking
/// and only emits two events per turn:
///
/// 1. `ResponseEvent::OutputItemDone` with the *complete* assistant message
/// (fully concatenated).
/// 2. The original `ResponseEvent::Completed` right after it.
///
/// This mirrors the behaviour the TypeScript CLI exposes to its higher layers.
///
/// The adapter is intentionally *lossless*: callers who do **not** opt in via
/// [`AggregateStreamExt::aggregate()`] keep receiving the original unmodified
/// events.
#[derive(Copy, Clone, Eq, PartialEq)]
enum AggregateMode {
AggregatedOnly,
Streaming,
}
pub(crate) struct AggregatedChatStream<S> {
inner: S,
cumulative: String,
cumulative_reasoning: String,
pending: std::collections::VecDeque<ResponseEvent>,
mode: AggregateMode,
}
impl<S> Stream for AggregatedChatStream<S>
where
S: Stream<Item = Result<ResponseEvent>> + Unpin,
{
type Item = Result<ResponseEvent>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
// First, flush any buffered events from the previous call.
if let Some(ev) = this.pending.pop_front() {
return Poll::Ready(Some(Ok(ev)));
}
loop {
match Pin::new(&mut this.inner).poll_next(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => return Poll::Ready(None),
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
// If this is an incremental assistant message chunk, accumulate but
// do NOT emit yet. Forward any other item (e.g. FunctionCall) right
// away so downstream consumers see it.
let is_assistant_delta = matches!(&item, crate::models::ResponseItem::Message { role, .. } if role == "assistant");
if is_assistant_delta {
// Only use the final assistant message if we have not
// seen any deltas; otherwise, deltas already built the
// cumulative text and this would duplicate it.
if this.cumulative.is_empty() {
if let crate::models::ResponseItem::Message { content, .. } = &item {
if let Some(text) = content.iter().find_map(|c| match c {
crate::models::ContentItem::OutputText { text } => Some(text),
_ => None,
}) {
this.cumulative.push_str(text);
}
}
}
// Swallow assistant message here; emit on Completed.
continue;
}
// Not an assistant message forward immediately.
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
}
Poll::Ready(Some(Ok(ResponseEvent::Completed {
response_id,
token_usage,
}))) => {
// Build any aggregated items in the correct order: Reasoning first, then Message.
let mut emitted_any = false;
if !this.cumulative_reasoning.is_empty()
&& matches!(this.mode, AggregateMode::AggregatedOnly)
{
let aggregated_reasoning = crate::models::ResponseItem::Reasoning {
id: String::new(),
summary: Vec::new(),
content: Some(vec![
crate::models::ReasoningItemContent::ReasoningText {
text: std::mem::take(&mut this.cumulative_reasoning),
},
]),
encrypted_content: None,
};
this.pending
.push_back(ResponseEvent::OutputItemDone(aggregated_reasoning));
emitted_any = true;
}
if !this.cumulative.is_empty() {
let aggregated_message = crate::models::ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![crate::models::ContentItem::OutputText {
text: std::mem::take(&mut this.cumulative),
}],
};
this.pending
.push_back(ResponseEvent::OutputItemDone(aggregated_message));
emitted_any = true;
}
// Always emit Completed last when anything was aggregated.
if emitted_any {
this.pending.push_back(ResponseEvent::Completed {
response_id: response_id.clone(),
token_usage: token_usage.clone(),
});
// Return the first pending event now.
if let Some(ev) = this.pending.pop_front() {
return Poll::Ready(Some(Ok(ev)));
}
}
// Nothing aggregated forward Completed directly.
return Poll::Ready(Some(Ok(ResponseEvent::Completed {
response_id,
token_usage,
})));
}
Poll::Ready(Some(Ok(ResponseEvent::Created))) => {
// These events are exclusive to the Responses API and
// will never appear in a Chat Completions stream.
continue;
}
Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))) => {
// Always accumulate deltas so we can emit a final OutputItemDone at Completed.
this.cumulative.push_str(&delta);
if matches!(this.mode, AggregateMode::Streaming) {
// In streaming mode, also forward the delta immediately.
return Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta))));
} else {
continue;
}
}
Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta)))) => {
// Always accumulate reasoning deltas so we can emit a final Reasoning item at Completed.
this.cumulative_reasoning.push_str(&delta);
if matches!(this.mode, AggregateMode::Streaming) {
// In streaming mode, also forward the delta immediately.
return Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta))));
} else {
continue;
}
}
Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta(_)))) => {
continue;
}
}
}
}
}
/// Extension trait that activates aggregation on any stream of [`ResponseEvent`].
pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Sized {
/// Returns a new stream that emits **only** the final assistant message
/// per turn instead of every incremental delta. The produced
/// `ResponseEvent` sequence for a typical text turn looks like:
///
/// ```ignore
/// OutputItemDone(<full message>)
/// Completed
/// ```
///
/// No other `OutputItemDone` events will be seen by the caller.
///
/// Usage:
///
/// ```ignore
/// let agg_stream = client.stream(&prompt).await?.aggregate();
/// while let Some(event) = agg_stream.next().await {
/// // event now contains cumulative text
/// }
/// ```
fn aggregate(self) -> AggregatedChatStream<Self> {
AggregatedChatStream::new(self, AggregateMode::AggregatedOnly)
}
}
impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
impl<S> AggregatedChatStream<S> {
fn new(inner: S, mode: AggregateMode) -> Self {
AggregatedChatStream {
inner,
cumulative: String::new(),
cumulative_reasoning: String::new(),
pending: std::collections::VecDeque::new(),
mode,
}
}
pub(crate) fn streaming_mode(inner: S) -> Self {
Self::new(inner, AggregateMode::Streaming)
}
}