fix: chat completions API now also passes tools along (#1167)

Prior to this PR, there were two big misses in `chat_completions.rs`:

1. The loop in `stream_chat_completions()` was only including items of
type `ResponseItem::Message` when building up the `"messages"` JSON for
the `POST` request to the `chat/completions` endpoint. This PR fixes that
by ensuring the other variants (`FunctionCall`, `LocalShellCall`, and
`FunctionCallOutput`) are included as well.
2. In `process_chat_sse()`, we were not recording tool calls and were
only emitting items of type
`ResponseEvent::OutputItemDone(ResponseItem::Message)` to the stream.
Now we introduce `FunctionCallState`, which is used to accumulate the
`delta`s of type `tool_calls`, so we can ultimately emit a
`ResponseItem::FunctionCall`, when appropriate.

While function calling now appears to work for chat completions in my
local testing, I believe there are still edge cases that are not
covered, and this codepath would benefit from a battery of
integration tests. (As part of that further cleanup, we should also work
to support streaming responses in the UI.)

The other important part of this PR is some cleanup in
`core/src/codex.rs`. In particular, it was hard to reason about how
`run_task()` was building up the list of messages to include in a
request across the various cases:

- Responses API
- Chat Completions API
- Responses API used in concert with ZDR

I like to think things are a bit cleaner now, where:

- `zdr_transcript` (if present) contains all messages in the history of
the conversation, which includes function call outputs that have not
been sent back to the model yet
- `pending_input` includes any messages the user has submitted while the
turn is in flight that need to be injected as part of the next `POST` to
the model
- `input_for_next_turn` includes the tool call outputs that have not
been sent back to the model yet
This commit is contained in:
Michael Bolin
2025-06-02 13:47:51 -07:00
committed by GitHub
parent e40f86b446
commit d7245cbbc9
2 changed files with 299 additions and 71 deletions

View File

@@ -28,8 +28,7 @@ use crate::models::ResponseItem;
use crate::openai_tools::create_tools_json_for_chat_completions_api;
use crate::util::backoff;
/// Implementation for the classic Chat Completions API. This is intentionally
/// minimal: we only stream back plain assistant text.
/// Implementation for the classic Chat Completions API.
pub(crate) async fn stream_chat_completions(
prompt: &Prompt,
model: &str,
@@ -43,17 +42,67 @@ pub(crate) async fn stream_chat_completions(
messages.push(json!({"role": "system", "content": full_instructions}));
for item in &prompt.input {
if let ResponseItem::Message { role, content } = item {
let mut text = String::new();
for c in content {
match c {
ContentItem::InputText { text: t } | ContentItem::OutputText { text: t } => {
text.push_str(t);
match item {
ResponseItem::Message { role, content } => {
let mut text = String::new();
for c in content {
match c {
ContentItem::InputText { text: t }
| ContentItem::OutputText { text: t } => {
text.push_str(t);
}
_ => {}
}
_ => {}
}
messages.push(json!({"role": role, "content": text}));
}
ResponseItem::FunctionCall {
name,
arguments,
call_id,
} => {
messages.push(json!({
"role": "assistant",
"content": null,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {
"name": name,
"arguments": arguments,
}
}]
}));
}
ResponseItem::LocalShellCall {
id,
call_id: _,
status,
action,
} => {
// TODO: Confirm this encoding with the API team.
messages.push(json!({
"role": "assistant",
"content": null,
"tool_calls": [{
"id": id.clone().unwrap_or_else(|| "".to_string()),
"type": "local_shell_call",
"status": status,
"action": action,
}]
}));
}
ResponseItem::FunctionCallOutput { call_id, output } => {
messages.push(json!({
"role": "tool",
"tool_call_id": call_id,
"content": output.content,
}));
}
ResponseItem::Reasoning { .. } | ResponseItem::Other => {
// Omit these items from the conversation history.
continue;
}
messages.push(json!({"role": role, "content": text}));
}
}
@@ -68,9 +117,8 @@ pub(crate) async fn stream_chat_completions(
let base_url = provider.base_url.trim_end_matches('/');
let url = format!("{}/chat/completions", base_url);
debug!(url, "POST (chat)");
trace!(
"request payload: {}",
debug!(
"POST to {url}: {}",
serde_json::to_string_pretty(&payload).unwrap_or_default()
);
@@ -140,6 +188,21 @@ where
let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS;
// State to accumulate a function call across streaming chunks.
// OpenAI may split the `arguments` string over multiple `delta` events
// until the chunk whose `finish_reason` is `tool_calls` is emitted. We
// keep collecting the pieces here and forward a single
// `ResponseItem::FunctionCall` once the call is complete.
#[derive(Default)]
struct FunctionCallState {
name: Option<String>,
arguments: String,
call_id: Option<String>,
active: bool,
}
let mut fn_call_state = FunctionCallState::default();
loop {
let sse = match timeout(idle_timeout, stream.next()).await {
Ok(Some(Ok(ev))) => ev,
@@ -179,23 +242,89 @@ where
Ok(v) => v,
Err(_) => continue,
};
trace!("chat_completions received SSE chunk: {chunk:?}");
let content_opt = chunk
.get("choices")
.and_then(|c| c.get(0))
.and_then(|c| c.get("delta"))
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str());
let choice_opt = chunk.get("choices").and_then(|c| c.get(0));
if let Some(content) = content_opt {
let item = ResponseItem::Message {
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: content.to_string(),
}],
};
if let Some(choice) = choice_opt {
// Handle assistant content tokens.
if let Some(content) = choice
.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str())
{
let item = ResponseItem::Message {
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: content.to_string(),
}],
};
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
// Handle streaming function / tool calls.
if let Some(tool_calls) = choice
.get("delta")
.and_then(|d| d.get("tool_calls"))
.and_then(|tc| tc.as_array())
{
if let Some(tool_call) = tool_calls.first() {
// Mark that we have an active function call in progress.
fn_call_state.active = true;
// Extract call_id if present.
if let Some(id) = tool_call.get("id").and_then(|v| v.as_str()) {
fn_call_state.call_id.get_or_insert_with(|| id.to_string());
}
// Extract function details if present.
if let Some(function) = tool_call.get("function") {
if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
fn_call_state.name.get_or_insert_with(|| name.to_string());
}
if let Some(args_fragment) =
function.get("arguments").and_then(|a| a.as_str())
{
fn_call_state.arguments.push_str(args_fragment);
}
}
}
}
// Emit end-of-turn when finish_reason signals completion.
if let Some(finish_reason) = choice.get("finish_reason").and_then(|v| v.as_str()) {
match finish_reason {
"tool_calls" if fn_call_state.active => {
// Build the FunctionCall response item.
let item = ResponseItem::FunctionCall {
name: fn_call_state.name.clone().unwrap_or_else(|| "".to_string()),
arguments: fn_call_state.arguments.clone(),
call_id: fn_call_state.call_id.clone().unwrap_or_else(String::new),
};
// Emit it downstream.
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
"stop" => {
// Regular turn without tool-call.
}
_ => {}
}
// Emit Completed regardless of reason so the agent can advance.
let _ = tx_event
.send(Ok(ResponseEvent::Completed {
response_id: String::new(),
}))
.await;
// Prepare for potential next turn (should not happen in same stream).
// fn_call_state = FunctionCallState::default();
return; // End processing for this SSE stream.
}
}
}
}
@@ -242,9 +371,14 @@ where
Poll::Ready(None) => return Poll::Ready(None),
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
// Accumulate *assistant* text but do not emit yet.
if let crate::models::ResponseItem::Message { role, content } = &item {
if role == "assistant" {
// If this is an incremental assistant message chunk, accumulate but
// do NOT emit yet. Forward any other item (e.g. FunctionCall) right
// away so downstream consumers see it.
let is_assistant_delta = matches!(&item, crate::models::ResponseItem::Message { role, .. } if role == "assistant");
if is_assistant_delta {
if let crate::models::ResponseItem::Message { content, .. } = &item {
if let Some(text) = content.iter().find_map(|c| match c {
crate::models::ContentItem::OutputText { text } => Some(text),
_ => None,
@@ -252,10 +386,13 @@ where
this.cumulative.push_str(text);
}
}
// Swallow partial assistant chunk; keep polling.
continue;
}
// Swallow partial event; keep polling.
continue;
// Not an assistant message, so forward it immediately.
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
}
Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id }))) => {
if !this.cumulative.is_empty() {