llmx/codex-rs/core/src/chat_completions.rs
Michael Bolin 61b881d4e5 fix: agent instructions were not being included when ~/.codex/instructions.md was empty (#908)
I had seen issues where `codex-rs` would not always write files without
me pressuring it to do so, and between that and the report of
https://github.com/openai/codex/issues/900, I decided to look into this
further. I found two serious issues with agent instructions:

(1) We were only sending agent instructions on the first turn, but
looking at the TypeScript code, we should be sending them on every turn.

(2) There was a serious issue where the agent instructions were
frequently lost:

* The TypeScript CLI appears to keep writing `~/.codex/instructions.md`:
55142e3e6c/codex-cli/src/utils/config.ts (L586)
* If `instructions.md` is present, the Rust CLI uses the contents of it
INSTEAD OF the default prompt, even if `instructions.md` is empty:
55142e3e6c/codex-rs/core/src/config.rs (L202-L203)

The combination of these two things means that I have been using
`codex-rs` without these key instructions:
https://github.com/openai/codex/blob/main/codex-rs/core/prompt.md

Looking at the TypeScript code, it appears we should be concatenating
these three items every time (if they exist):

* `prompt.md`
* `~/.codex/instructions.md`
* nearest `AGENTS.md`

This PR fixes things so that:

* `Config.instructions` is `None` if `instructions.md` is empty
* `Payload.instructions` is now `&'a str` instead of `Option<&'a String>`
  because we should always have _something_ to send
* `Prompt` now has a `get_full_instructions()` helper that returns a
  `Cow<str>` that will always include the agent instructions first (see
  the sketch below).
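
For illustration, a minimal sketch of what that helper could look like;
the `BASE_INSTRUCTIONS` constant, the `include_str!` path, and the field
layout are assumptions for this sketch, not the actual codex-rs code:

```rust
use std::borrow::Cow;

// Assumed for illustration: the built-in agent instructions (prompt.md)
// compiled into the binary.
const BASE_INSTRUCTIONS: &str = include_str!("../prompt.md");

pub struct Prompt {
    /// `None` when ~/.codex/instructions.md is missing *or* empty.
    pub instructions: Option<String>,
    // ... input items elided ...
}

impl Prompt {
    /// Always returns something to send, with the agent instructions first.
    pub fn get_full_instructions(&self) -> Cow<'_, str> {
        match &self.instructions {
            Some(user) => Cow::Owned(format!("{BASE_INSTRUCTIONS}\n{user}")),
            None => Cow::Borrowed(BASE_INSTRUCTIONS),
        }
    }
}
```

Together with treating an empty `instructions.md` as `None` when loading
`Config`, this guarantees the base prompt is never silently dropped.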
2025-05-12 17:24:44 -07:00

use std::time::Duration;

use bytes::Bytes;
use eventsource_stream::Eventsource;
use futures::Stream;
use futures::StreamExt;
use futures::TryStreamExt;
use reqwest::StatusCode;
use serde_json::json;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use tokio::sync::mpsc;
use tokio::time::timeout;
use tracing::debug;
use tracing::trace;

use crate::ModelProviderInfo;
use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
use crate::client_common::ResponseStream;
use crate::error::CodexErr;
use crate::error::Result;
use crate::flags::OPENAI_REQUEST_MAX_RETRIES;
use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
use crate::models::ContentItem;
use crate::models::ResponseItem;
use crate::util::backoff;

/// Implementation for the classic Chat Completions API. This is intentionally
/// minimal: we only stream back plain assistant text.
pub(crate) async fn stream_chat_completions(
    prompt: &Prompt,
    model: &str,
    client: &reqwest::Client,
    provider: &ModelProviderInfo,
) -> Result<ResponseStream> {
    // Build messages array
    let mut messages = Vec::<serde_json::Value>::new();

    let full_instructions = prompt.get_full_instructions();
    messages.push(json!({"role": "system", "content": full_instructions}));

    for item in &prompt.input {
        if let ResponseItem::Message { role, content } = item {
            let mut text = String::new();
            for c in content {
                match c {
                    ContentItem::InputText { text: t } | ContentItem::OutputText { text: t } => {
                        text.push_str(t);
                    }
                    _ => {}
                }
            }
            messages.push(json!({"role": role, "content": text}));
        }
    }

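    // At this point `messages` is the system prompt followed by the flattened
    // conversation, e.g.:
    //   [{"role": "system", "content": "<agent instructions>"},
    //    {"role": "user", "content": "..."}]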
    let payload = json!({
        "model": model,
        "messages": messages,
        "stream": true
    });

    let base_url = provider.base_url.trim_end_matches('/');
    let url = format!("{}/chat/completions", base_url);
    debug!(url, "POST (chat)");
    trace!("request payload: {}", payload);

    let api_key = provider.api_key()?;

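    // Retry loop: 429 and 5xx responses (and transport errors) are retried
    // with exponential backoff, honoring a `Retry-After` header when present,
    // up to OPENAI_REQUEST_MAX_RETRIES attempts; anything else fails fast.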
    let mut attempt = 0;
    loop {
        attempt += 1;

        let mut req_builder = client.post(&url);
        if let Some(api_key) = &api_key {
            req_builder = req_builder.bearer_auth(api_key.clone());
        }
        let res = req_builder
            .header(reqwest::header::ACCEPT, "text/event-stream")
            .json(&payload)
            .send()
            .await;
        match res {
            Ok(resp) if resp.status().is_success() => {
                let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(16);
                let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
                tokio::spawn(process_chat_sse(stream, tx_event));
                return Ok(ResponseStream { rx_event });
            }
            Ok(res) => {
                let status = res.status();
                if !(status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error()) {
                    let body = (res.text().await).unwrap_or_default();
                    return Err(CodexErr::UnexpectedStatus(status, body));
                }

                if attempt > *OPENAI_REQUEST_MAX_RETRIES {
                    return Err(CodexErr::RetryLimit(status));
                }

                let retry_after_secs = res
                    .headers()
                    .get(reqwest::header::RETRY_AFTER)
                    .and_then(|v| v.to_str().ok())
                    .and_then(|s| s.parse::<u64>().ok());

                let delay = retry_after_secs
                    .map(|s| Duration::from_millis(s * 1_000))
                    .unwrap_or_else(|| backoff(attempt));
                tokio::time::sleep(delay).await;
            }
            Err(e) => {
                if attempt > *OPENAI_REQUEST_MAX_RETRIES {
                    return Err(e.into());
                }
                let delay = backoff(attempt);
                tokio::time::sleep(delay).await;
            }
        }
    }
}

/// Lightweight SSE processor for the Chat Completions streaming format. The
/// output is mapped onto Codex's internal [`ResponseEvent`] so that the rest
/// of the pipeline can stay agnostic of the underlying wire format.
async fn process_chat_sse<S>(stream: S, tx_event: mpsc::Sender<Result<ResponseEvent>>)
where
    S: Stream<Item = Result<Bytes>> + Unpin,
{
    let mut stream = stream.eventsource();
    let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS;

    loop {
        let sse = match timeout(idle_timeout, stream.next()).await {
            Ok(Some(Ok(ev))) => ev,
            Ok(Some(Err(e))) => {
                let _ = tx_event.send(Err(CodexErr::Stream(e.to_string()))).await;
                return;
            }
            Ok(None) => {
                // Stream closed gracefully; emit Completed with a dummy id.
                let _ = tx_event
                    .send(Ok(ResponseEvent::Completed {
                        response_id: String::new(),
                    }))
                    .await;
                return;
            }
            Err(_) => {
                let _ = tx_event
                    .send(Err(CodexErr::Stream("idle timeout waiting for SSE".into())))
                    .await;
                return;
            }
        };

        // OpenAI Chat streaming sends a literal string "[DONE]" when finished.
        if sse.data.trim() == "[DONE]" {
            let _ = tx_event
                .send(Ok(ResponseEvent::Completed {
                    response_id: String::new(),
                }))
                .await;
            return;
        }

        // Parse JSON chunk
        let chunk: serde_json::Value = match serde_json::from_str(&sse.data) {
            Ok(v) => v,
            Err(_) => continue,
        };
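        // Each chunk carries its incremental delta at `choices[0].delta.content`,
        // e.g. {"choices":[{"delta":{"content":"Hel"}}]}; chunks without one
        // are ignored.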
        let content_opt = chunk
            .get("choices")
            .and_then(|c| c.get(0))
            .and_then(|c| c.get("delta"))
            .and_then(|d| d.get("content"))
            .and_then(|c| c.as_str());

        if let Some(content) = content_opt {
            let item = ResponseItem::Message {
                role: "assistant".to_string(),
                content: vec![ContentItem::OutputText {
                    text: content.to_string(),
                }],
            };

            let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
        }
    }
}

/// Optional client-side aggregation helper
///
/// Stream adapter that merges the incremental `OutputItemDone` chunks coming from
/// [`process_chat_sse`] into a *running* assistant message, **suppressing the
/// per-token deltas**. The stream stays silent while the model is thinking
/// and only emits two events per turn:
///
/// 1. `ResponseEvent::OutputItemDone` with the *complete* assistant message
/// (fully concatenated).
/// 2. The original `ResponseEvent::Completed` right after it.
///
/// This mirrors the behaviour the TypeScript CLI exposes to its higher layers.
///
/// The adapter is intentionally *lossless*: callers who do **not** opt in via
/// [`AggregateStreamExt::aggregate()`] keep receiving the original unmodified
/// events.
pub(crate) struct AggregatedChatStream<S> {
    inner: S,
    cumulative: String,
    pending_completed: Option<ResponseEvent>,
}

impl<S> Stream for AggregatedChatStream<S>
where
    S: Stream<Item = Result<ResponseEvent>> + Unpin,
{
    type Item = Result<ResponseEvent>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        // First, flush any buffered Completed event from the previous call.
        if let Some(ev) = this.pending_completed.take() {
            return Poll::Ready(Some(Ok(ev)));
        }

        loop {
            match Pin::new(&mut this.inner).poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
                Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
                    // Accumulate *assistant* text but do not emit yet.
                    if let crate::models::ResponseItem::Message { role, content } = &item {
                        if role == "assistant" {
                            if let Some(text) = content.iter().find_map(|c| match c {
                                crate::models::ContentItem::OutputText { text } => Some(text),
                                _ => None,
                            }) {
                                this.cumulative.push_str(text);
                            }
                        }
                    }

                    // Swallow partial event; keep polling.
                    continue;
                }
                Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id }))) => {
                    if !this.cumulative.is_empty() {
                        let aggregated_item = crate::models::ResponseItem::Message {
                            role: "assistant".to_string(),
                            content: vec![crate::models::ContentItem::OutputText {
                                text: std::mem::take(&mut this.cumulative),
                            }],
                        };

                        // Buffer Completed so it is returned *after* the aggregated message.
                        this.pending_completed = Some(ResponseEvent::Completed { response_id });
                        return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
                            aggregated_item,
                        ))));
                    }

                    // Nothing aggregated; forward Completed directly.
                    return Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id })));
                } // No other `Ok` variants exist at the moment, continue polling.
            }
        }
    }
}

/// Extension trait that activates aggregation on any stream of [`ResponseEvent`].
pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Sized {
    /// Returns a new stream that emits **only** the final assistant message
    /// per turn instead of every incremental delta. The produced
    /// `ResponseEvent` sequence for a typical text turn looks like:
    ///
    /// ```ignore
    /// OutputItemDone(<full message>)
    /// Completed { .. }
    /// ```
    ///
    /// No other `OutputItemDone` events will be seen by the caller.
    ///
    /// Usage:
    ///
    /// ```ignore
    /// let mut agg_stream = client.stream(&prompt).await?.aggregate();
    /// while let Some(event) = agg_stream.next().await {
    ///     // event now contains cumulative text
    /// }
    /// ```
    fn aggregate(self) -> AggregatedChatStream<Self> {
        AggregatedChatStream {
            inner: self,
            cumulative: String::new(),
            pending_completed: None,
        }
    }
}

impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}