fix: chat completions API now also passes tools along (#1167)
Prior to this PR, there were two big misses in `chat_completions.rs`: 1. The loop in `stream_chat_completions()` was only including items of type `ResponseItem::Message` when building up the `"messages"` JSON for the `POST` request to the `chat/completions` endpoint. This fixes things by ensuring other variants (`FunctionCall`, `LocalShellCall`, and `FunctionCallOutput`) are included, as well. 2. In `process_chat_sse()`, we were not recording tool calls and were only emitting items of type `ResponseEvent::OutputItemDone(ResponseItem::Message)` to the stream. Now we introduce `FunctionCallState`, which is used to accumulate the `delta`s of type `tool_calls`, so we can ultimately emit a `ResponseItem::FunctionCall`, when appropriate. While function calling now appears to work for chat completions with my local testing, I believe that there are still edge cases that are not covered and that this codepath would benefit from a battery of integration tests. (As part of that further cleanup, we should also work to support streaming responses in the UI.) The other important part of this PR is some cleanup in `core/src/codex.rs`. In particular, it was hard to reason about how `run_task()` was building up the list of messages to include in a request across the various cases: - Responses API - Chat Completions API - Responses API used in concert with ZDR I like to think things are a bit cleaner now where: - `zdr_transcript` (if present) contains all messages in the history of the conversation, which includes function call outputs that have not been sent back to the model yet - `pending_input` includes any messages the user has submitted while the turn is in flight that need to be injected as part of the next `POST` to the model - `input_for_next_turn` includes the tool call outputs that have not been sent back to the model yet
This commit is contained in:
@@ -28,8 +28,7 @@ use crate::models::ResponseItem;
|
||||
use crate::openai_tools::create_tools_json_for_chat_completions_api;
|
||||
use crate::util::backoff;
|
||||
|
||||
/// Implementation for the classic Chat Completions API. This is intentionally
|
||||
/// minimal: we only stream back plain assistant text.
|
||||
/// Implementation for the classic Chat Completions API.
|
||||
pub(crate) async fn stream_chat_completions(
|
||||
prompt: &Prompt,
|
||||
model: &str,
|
||||
@@ -43,17 +42,67 @@ pub(crate) async fn stream_chat_completions(
|
||||
messages.push(json!({"role": "system", "content": full_instructions}));
|
||||
|
||||
for item in &prompt.input {
|
||||
if let ResponseItem::Message { role, content } = item {
|
||||
let mut text = String::new();
|
||||
for c in content {
|
||||
match c {
|
||||
ContentItem::InputText { text: t } | ContentItem::OutputText { text: t } => {
|
||||
text.push_str(t);
|
||||
match item {
|
||||
ResponseItem::Message { role, content } => {
|
||||
let mut text = String::new();
|
||||
for c in content {
|
||||
match c {
|
||||
ContentItem::InputText { text: t }
|
||||
| ContentItem::OutputText { text: t } => {
|
||||
text.push_str(t);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
messages.push(json!({"role": role, "content": text}));
|
||||
}
|
||||
ResponseItem::FunctionCall {
|
||||
name,
|
||||
arguments,
|
||||
call_id,
|
||||
} => {
|
||||
messages.push(json!({
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"arguments": arguments,
|
||||
}
|
||||
}]
|
||||
}));
|
||||
}
|
||||
ResponseItem::LocalShellCall {
|
||||
id,
|
||||
call_id: _,
|
||||
status,
|
||||
action,
|
||||
} => {
|
||||
// Confirm with API team.
|
||||
messages.push(json!({
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [{
|
||||
"id": id.clone().unwrap_or_else(|| "".to_string()),
|
||||
"type": "local_shell_call",
|
||||
"status": status,
|
||||
"action": action,
|
||||
}]
|
||||
}));
|
||||
}
|
||||
ResponseItem::FunctionCallOutput { call_id, output } => {
|
||||
messages.push(json!({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": output.content,
|
||||
}));
|
||||
}
|
||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => {
|
||||
// Omit these items from the conversation history.
|
||||
continue;
|
||||
}
|
||||
messages.push(json!({"role": role, "content": text}));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,9 +117,8 @@ pub(crate) async fn stream_chat_completions(
|
||||
let base_url = provider.base_url.trim_end_matches('/');
|
||||
let url = format!("{}/chat/completions", base_url);
|
||||
|
||||
debug!(url, "POST (chat)");
|
||||
trace!(
|
||||
"request payload: {}",
|
||||
debug!(
|
||||
"POST to {url}: {}",
|
||||
serde_json::to_string_pretty(&payload).unwrap_or_default()
|
||||
);
|
||||
|
||||
@@ -140,6 +188,21 @@ where
|
||||
|
||||
let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS;
|
||||
|
||||
// State to accumulate a function call across streaming chunks.
|
||||
// OpenAI may split the `arguments` string over multiple `delta` events
|
||||
// until the chunk whose `finish_reason` is `tool_calls` is emitted. We
|
||||
// keep collecting the pieces here and forward a single
|
||||
// `ResponseItem::FunctionCall` once the call is complete.
|
||||
#[derive(Default)]
|
||||
struct FunctionCallState {
|
||||
name: Option<String>,
|
||||
arguments: String,
|
||||
call_id: Option<String>,
|
||||
active: bool,
|
||||
}
|
||||
|
||||
let mut fn_call_state = FunctionCallState::default();
|
||||
|
||||
loop {
|
||||
let sse = match timeout(idle_timeout, stream.next()).await {
|
||||
Ok(Some(Ok(ev))) => ev,
|
||||
@@ -179,23 +242,89 @@ where
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
trace!("chat_completions received SSE chunk: {chunk:?}");
|
||||
|
||||
let content_opt = chunk
|
||||
.get("choices")
|
||||
.and_then(|c| c.get(0))
|
||||
.and_then(|c| c.get("delta"))
|
||||
.and_then(|d| d.get("content"))
|
||||
.and_then(|c| c.as_str());
|
||||
let choice_opt = chunk.get("choices").and_then(|c| c.get(0));
|
||||
|
||||
if let Some(content) = content_opt {
|
||||
let item = ResponseItem::Message {
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: content.to_string(),
|
||||
}],
|
||||
};
|
||||
if let Some(choice) = choice_opt {
|
||||
// Handle assistant content tokens.
|
||||
if let Some(content) = choice
|
||||
.get("delta")
|
||||
.and_then(|d| d.get("content"))
|
||||
.and_then(|c| c.as_str())
|
||||
{
|
||||
let item = ResponseItem::Message {
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: content.to_string(),
|
||||
}],
|
||||
};
|
||||
|
||||
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
|
||||
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
|
||||
}
|
||||
|
||||
// Handle streaming function / tool calls.
|
||||
if let Some(tool_calls) = choice
|
||||
.get("delta")
|
||||
.and_then(|d| d.get("tool_calls"))
|
||||
.and_then(|tc| tc.as_array())
|
||||
{
|
||||
if let Some(tool_call) = tool_calls.first() {
|
||||
// Mark that we have an active function call in progress.
|
||||
fn_call_state.active = true;
|
||||
|
||||
// Extract call_id if present.
|
||||
if let Some(id) = tool_call.get("id").and_then(|v| v.as_str()) {
|
||||
fn_call_state.call_id.get_or_insert_with(|| id.to_string());
|
||||
}
|
||||
|
||||
// Extract function details if present.
|
||||
if let Some(function) = tool_call.get("function") {
|
||||
if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
|
||||
fn_call_state.name.get_or_insert_with(|| name.to_string());
|
||||
}
|
||||
|
||||
if let Some(args_fragment) =
|
||||
function.get("arguments").and_then(|a| a.as_str())
|
||||
{
|
||||
fn_call_state.arguments.push_str(args_fragment);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Emit end-of-turn when finish_reason signals completion.
|
||||
if let Some(finish_reason) = choice.get("finish_reason").and_then(|v| v.as_str()) {
|
||||
match finish_reason {
|
||||
"tool_calls" if fn_call_state.active => {
|
||||
// Build the FunctionCall response item.
|
||||
let item = ResponseItem::FunctionCall {
|
||||
name: fn_call_state.name.clone().unwrap_or_else(|| "".to_string()),
|
||||
arguments: fn_call_state.arguments.clone(),
|
||||
call_id: fn_call_state.call_id.clone().unwrap_or_else(String::new),
|
||||
};
|
||||
|
||||
// Emit it downstream.
|
||||
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
|
||||
}
|
||||
"stop" => {
|
||||
// Regular turn without tool-call.
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Emit Completed regardless of reason so the agent can advance.
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::Completed {
|
||||
response_id: String::new(),
|
||||
}))
|
||||
.await;
|
||||
|
||||
// Prepare for potential next turn (should not happen in same stream).
|
||||
// fn_call_state = FunctionCallState::default();
|
||||
|
||||
return; // End processing for this SSE stream.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -242,9 +371,14 @@ where
|
||||
Poll::Ready(None) => return Poll::Ready(None),
|
||||
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
|
||||
Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
|
||||
// Accumulate *assistant* text but do not emit yet.
|
||||
if let crate::models::ResponseItem::Message { role, content } = &item {
|
||||
if role == "assistant" {
|
||||
// If this is an incremental assistant message chunk, accumulate but
|
||||
// do NOT emit yet. Forward any other item (e.g. FunctionCall) right
|
||||
// away so downstream consumers see it.
|
||||
|
||||
let is_assistant_delta = matches!(&item, crate::models::ResponseItem::Message { role, .. } if role == "assistant");
|
||||
|
||||
if is_assistant_delta {
|
||||
if let crate::models::ResponseItem::Message { content, .. } = &item {
|
||||
if let Some(text) = content.iter().find_map(|c| match c {
|
||||
crate::models::ContentItem::OutputText { text } => Some(text),
|
||||
_ => None,
|
||||
@@ -252,10 +386,13 @@ where
|
||||
this.cumulative.push_str(text);
|
||||
}
|
||||
}
|
||||
|
||||
// Swallow partial assistant chunk; keep polling.
|
||||
continue;
|
||||
}
|
||||
|
||||
// Swallow partial event; keep polling.
|
||||
continue;
|
||||
// Not an assistant message – forward immediately.
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id }))) => {
|
||||
if !this.cumulative.is_empty() {
|
||||
|
||||
Reference in New Issue
Block a user