diff --git a/codex-cli/src/cli.tsx b/codex-cli/src/cli.tsx index 08b95cb8..722b9319 100644 --- a/codex-cli/src/cli.tsx +++ b/codex-cli/src/cli.tsx @@ -19,15 +19,13 @@ import { ReviewDecision } from "./utils/agent/review"; import { AutoApprovalMode } from "./utils/auto-approval-mode"; import { checkForUpdates } from "./utils/check-updates"; import { + getApiKey, loadConfig, PRETTY_PRINT, INSTRUCTIONS_FILEPATH, } from "./utils/config"; import { createInputItem } from "./utils/input-utils"; -import { - isModelSupportedForResponses, - preloadModels, -} from "./utils/model-utils.js"; +import { isModelSupportedForResponses } from "./utils/model-utils.js"; import { parseToolCall } from "./utils/parsers"; import { onExit, setInkRenderer } from "./utils/terminal"; import chalk from "chalk"; @@ -97,6 +95,7 @@ const cli = meow( help: { type: "boolean", aliases: ["h"] }, view: { type: "string" }, model: { type: "string", aliases: ["m"] }, + provider: { type: "string", aliases: ["p"] }, image: { type: "string", isMultiple: true, aliases: ["i"] }, quiet: { type: "boolean", @@ -227,7 +226,19 @@ if (cli.flags.config) { // API key handling // --------------------------------------------------------------------------- -const apiKey = process.env["OPENAI_API_KEY"]; +const fullContextMode = Boolean(cli.flags.fullContext); +let config = loadConfig(undefined, undefined, { + cwd: process.cwd(), + disableProjectDoc: Boolean(cli.flags.noProjectDoc), + projectDocPath: cli.flags.projectDoc as string | undefined, + isFullContext: fullContextMode, +}); + +const prompt = cli.input[0]; +const model = cli.flags.model ?? config.model; +const imagePaths = cli.flags.image as Array | undefined; +const provider = cli.flags.provider ?? config.provider; +const apiKey = getApiKey(provider); if (!apiKey) { // eslint-disable-next-line no-console @@ -242,24 +253,13 @@ if (!apiKey) { process.exit(1); } -const fullContextMode = Boolean(cli.flags.fullContext); -let config = loadConfig(undefined, undefined, { - cwd: process.cwd(), - disableProjectDoc: Boolean(cli.flags.noProjectDoc), - projectDocPath: cli.flags.projectDoc as string | undefined, - isFullContext: fullContextMode, -}); - -const prompt = cli.input[0]; -const model = cli.flags.model; -const imagePaths = cli.flags.image as Array | undefined; - config = { apiKey, ...config, model: model ?? config.model, - flexMode: Boolean(cli.flags.flexMode), notify: Boolean(cli.flags.notify), + flexMode: Boolean(cli.flags.flexMode), + provider, }; // Check for updates after loading config @@ -281,7 +281,10 @@ if (cli.flags.flexMode) { } } -if (!(await isModelSupportedForResponses(config.model))) { +if ( + !(await isModelSupportedForResponses(config.model)) && + (!provider || provider.toLowerCase() === "openai") +) { // eslint-disable-next-line no-console console.error( `The model "${config.model}" does not appear in the list of models ` + @@ -378,8 +381,6 @@ const approvalPolicy: ApprovalPolicy = ? 
AutoApprovalMode.AUTO_EDIT : config.approvalMode || AutoApprovalMode.SUGGEST; -preloadModels(); - const instance = render( (config.model); + const [provider, setProvider] = useState(config.provider || "openai"); const [lastResponseId, setLastResponseId] = useState(null); const [items, setItems] = useState>([]); const [loading, setLoading] = useState(false); @@ -228,7 +229,7 @@ export default function TerminalChat({ log("creating NEW AgentLoop"); log( - `model=${model} instructions=${Boolean( + `model=${model} provider=${provider} instructions=${Boolean( config.instructions, )} approvalPolicy=${approvalPolicy}`, ); @@ -238,6 +239,7 @@ export default function TerminalChat({ agentRef.current = new AgentLoop({ model, + provider, config, instructions: config.instructions, approvalPolicy, @@ -307,10 +309,15 @@ export default function TerminalChat({ agentRef.current = undefined; forceUpdate(); // re‑render after teardown too }; - // We intentionally omit 'approvalPolicy' and 'confirmationPrompt' from the deps - // so switching modes or showing confirmation dialogs doesn’t tear down the loop. - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [model, config, requestConfirmation, additionalWritableRoots]); + }, [ + model, + provider, + config, + approvalPolicy, + confirmationPrompt, + requestConfirmation, + additionalWritableRoots, + ]); // whenever loading starts/stops, reset or start a timer — but pause the // timer while a confirmation overlay is displayed so we don't trigger a @@ -417,7 +424,7 @@ export default function TerminalChat({ // ──────────────────────────────────────────────────────────────── useEffect(() => { (async () => { - const available = await getAvailableModels(); + const available = await getAvailableModels(provider); if (model && available.length > 0 && !available.includes(model)) { setItems((prev) => [ ...prev, @@ -428,7 +435,7 @@ export default function TerminalChat({ content: [ { type: "input_text", - text: `Warning: model "${model}" is not in the list of available models returned by OpenAI.`, + text: `Warning: model "${model}" is not in the list of available models for provider "${provider}".`, }, ], }, @@ -470,6 +477,7 @@ export default function TerminalChat({ version: CLI_VERSION, PWD, model, + provider, approvalPolicy, colorsByPolicy, agent, @@ -566,6 +574,7 @@ export default function TerminalChat({ {overlayMode === "model" && ( { log( @@ -582,6 +591,13 @@ export default function TerminalChat({ prev && newModel !== model ? null : prev, ); + // Save model to config + saveConfig({ + ...config, + model: newModel, + provider: provider, + }); + setItems((prev) => [ ...prev, { @@ -599,6 +615,51 @@ export default function TerminalChat({ setOverlayMode("none"); }} + onSelectProvider={(newProvider) => { + log( + "TerminalChat: interruptAgent invoked – calling agent.cancel()", + ); + if (!agent) { + log("TerminalChat: agent is not ready yet"); + } + agent?.cancel(); + setLoading(false); + + // Select default model for the new provider + const defaultModel = model; + + // Save provider to config + const updatedConfig = { + ...config, + provider: newProvider, + model: defaultModel, + }; + saveConfig(updatedConfig); + + setProvider(newProvider); + setModel(defaultModel); + setLastResponseId((prev) => + prev && newProvider !== provider ? 
null : prev, + ); + + setItems((prev) => [ + ...prev, + { + id: `switch-provider-${Date.now()}`, + type: "message", + role: "system", + content: [ + { + type: "input_text", + text: `Switched provider to ${newProvider} with model ${defaultModel}`, + }, + ], + }, + ]); + + // Don't close the overlay so user can select a model for the new provider + // setOverlayMode("none"); + }} onExit={() => setOverlayMode("none")} /> )} diff --git a/codex-cli/src/components/chat/terminal-header.tsx b/codex-cli/src/components/chat/terminal-header.tsx index 4c0ed2e1..bdc49946 100644 --- a/codex-cli/src/components/chat/terminal-header.tsx +++ b/codex-cli/src/components/chat/terminal-header.tsx @@ -9,6 +9,7 @@ export interface TerminalHeaderProps { version: string; PWD: string; model: string; + provider?: string; approvalPolicy: string; colorsByPolicy: Record; agent?: AgentLoop; @@ -21,6 +22,7 @@ const TerminalHeader: React.FC = ({ version, PWD, model, + provider = "openai", approvalPolicy, colorsByPolicy, agent, @@ -32,7 +34,7 @@ const TerminalHeader: React.FC = ({ {terminalRows < 10 ? ( // Compact header for small terminal windows - ● Codex v{version} – {PWD} – {model} –{" "} + ● Codex v{version} – {PWD} – {model} ({provider}) –{" "} {approvalPolicy} {flexModeEnabled ? " – flex-mode" : ""} @@ -65,6 +67,10 @@ const TerminalHeader: React.FC = ({ model: {model} + + provider:{" "} + {provider} + approval:{" "} diff --git a/codex-cli/src/components/model-overlay.tsx b/codex-cli/src/components/model-overlay.tsx index c006eb31..ec5e40d5 100644 --- a/codex-cli/src/components/model-overlay.tsx +++ b/codex-cli/src/components/model-overlay.tsx @@ -1,8 +1,9 @@ import TypeaheadOverlay from "./typeahead-overlay.js"; import { getAvailableModels, - RECOMMENDED_MODELS, + RECOMMENDED_MODELS as _RECOMMENDED_MODELS, } from "../utils/model-utils.js"; +import { providers } from "../utils/providers.js"; import { Box, Text, useInput } from "ink"; import React, { useEffect, useState } from "react"; @@ -16,39 +17,51 @@ import React, { useEffect, useState } from "react"; */ type Props = { currentModel: string; + currentProvider?: string; hasLastResponse: boolean; onSelect: (model: string) => void; + onSelectProvider?: (provider: string) => void; onExit: () => void; }; export default function ModelOverlay({ currentModel, + currentProvider = "openai", hasLastResponse, onSelect, + onSelectProvider, onExit, }: Props): JSX.Element { const [items, setItems] = useState>( [], ); + const [providerItems, _setProviderItems] = useState< + Array<{ label: string; value: string }> + >(Object.values(providers).map((p) => ({ label: p.name, value: p.name }))); + const [mode, setMode] = useState<"model" | "provider">("model"); + const [isLoading, setIsLoading] = useState(true); + // This effect will run when the provider changes to update the model list useEffect(() => { + setIsLoading(true); (async () => { - const models = await getAvailableModels(); - - // Split the list into recommended and “other” models. - const recommended = RECOMMENDED_MODELS.filter((m) => models.includes(m)); - const others = models.filter((m) => !recommended.includes(m)); - - const ordered = [...recommended, ...others.sort()]; - - setItems( - ordered.map((m) => ({ - label: recommended.includes(m) ? 
`⭐ ${m}` : m, - value: m, - })), - ); + try { + const models = await getAvailableModels(currentProvider); + // Convert the models to the format needed by TypeaheadOverlay + setItems( + models.map((m) => ({ + label: m, + value: m, + })), + ); + } catch (error) { + // Silently handle errors - remove console.error + // console.error("Error loading models:", error); + } finally { + setIsLoading(false); + } })(); - }, []); + }, [currentProvider]); // --------------------------------------------------------------------------- // If the conversation already contains a response we cannot change the model @@ -58,10 +71,14 @@ export default function ModelOverlay({ // available action is to dismiss the overlay (Esc or Enter). // --------------------------------------------------------------------------- - // Always register input handling so hooks are called consistently. + // Register input handling for switching between model and provider selection useInput((_input, key) => { if (hasLastResponse && (key.escape || key.return)) { onExit(); + } else if (!hasLastResponse) { + if (key.tab) { + setMode(mode === "model" ? "provider" : "model"); + } } }); @@ -91,13 +108,47 @@ export default function ModelOverlay({ ); } + if (mode === "provider") { + return ( + + + Current provider:{" "} + {currentProvider} + + press tab to switch to model selection + + } + initialItems={providerItems} + currentValue={currentProvider} + onSelect={(provider) => { + if (onSelectProvider) { + onSelectProvider(provider); + // Immediately switch to model selection so user can pick a model for the new provider + setMode("model"); + } + }} + onExit={onExit} + /> + ); + } + return ( - Current model: {currentModel} - + + + Current model: {currentModel} + + + Current provider: {currentProvider} + + {isLoading && Loading models...} + press tab to switch to provider selection + } initialItems={items} currentValue={currentModel} diff --git a/codex-cli/src/components/singlepass-cli-app.tsx b/codex-cli/src/components/singlepass-cli-app.tsx index 5d649424..c52ae1be 100644 --- a/codex-cli/src/components/singlepass-cli-app.tsx +++ b/codex-cli/src/components/singlepass-cli-app.tsx @@ -5,7 +5,12 @@ import type { FileOperation } from "../utils/singlepass/file_ops"; import Spinner from "./vendor/ink-spinner"; // Third‑party / vendor components import TextInput from "./vendor/ink-text-input"; -import { OPENAI_TIMEOUT_MS, OPENAI_BASE_URL } from "../utils/config"; +import { + OPENAI_TIMEOUT_MS, + OPENAI_BASE_URL as _OPENAI_BASE_URL, + getBaseUrl, + getApiKey, +} from "../utils/config"; import { generateDiffSummary, generateEditSummary, @@ -394,8 +399,8 @@ export function SinglePassApp({ }); const openai = new OpenAI({ - apiKey: config.apiKey ?? 
"", - baseURL: OPENAI_BASE_URL || undefined, + apiKey: getApiKey(config.provider), + baseURL: getBaseUrl(config.provider), timeout: OPENAI_TIMEOUT_MS, }); const chatResp = await openai.beta.chat.completions.parse({ diff --git a/codex-cli/src/utils/agent/agent-loop.ts b/codex-cli/src/utils/agent/agent-loop.ts index 172ba672..311cad96 100644 --- a/codex-cli/src/utils/agent/agent-loop.ts +++ b/codex-cli/src/utils/agent/agent-loop.ts @@ -1,16 +1,19 @@ import type { ReviewDecision } from "./review.js"; import type { ApplyPatchCommand, ApprovalPolicy } from "../../approvals.js"; import type { AppConfig } from "../config.js"; +import type { ResponseEvent } from "../responses.js"; import type { ResponseFunctionToolCall, ResponseInputItem, ResponseItem, + ResponseCreateParams, } from "openai/resources/responses/responses.mjs"; import type { Reasoning } from "openai/resources.mjs"; import { log } from "./log.js"; -import { OPENAI_BASE_URL, OPENAI_TIMEOUT_MS } from "../config.js"; +import { OPENAI_TIMEOUT_MS, getApiKey, getBaseUrl } from "../config.js"; import { parseToolCallArguments } from "../parsers.js"; +import { responsesCreateViaChatCompletions } from "../responses.js"; import { ORIGIN, CLI_VERSION, @@ -39,6 +42,7 @@ const alreadyProcessedResponses = new Set(); type AgentLoopParams = { model: string; + provider?: string; config?: AppConfig; instructions?: string; approvalPolicy: ApprovalPolicy; @@ -58,6 +62,7 @@ type AgentLoopParams = { export class AgentLoop { private model: string; + private provider: string; private instructions?: string; private approvalPolicy: ApprovalPolicy; private config: AppConfig; @@ -198,6 +203,7 @@ export class AgentLoop { // private cumulativeThinkingMs = 0; constructor({ model, + provider = "openai", instructions, approvalPolicy, // `config` used to be required. Some unit‑tests (and potentially other @@ -214,6 +220,7 @@ export class AgentLoop { additionalWritableRoots, }: AgentLoopParams & { config?: AppConfig }) { this.model = model; + this.provider = provider; this.instructions = instructions; this.approvalPolicy = approvalPolicy; @@ -236,7 +243,9 @@ export class AgentLoop { this.sessionId = getSessionId() || randomUUID().replaceAll("-", ""); // Configure OpenAI client with optional timeout (ms) from environment const timeoutMs = OPENAI_TIMEOUT_MS; - const apiKey = this.config.apiKey ?? process.env["OPENAI_API_KEY"] ?? ""; + const apiKey = getApiKey(this.provider); + const baseURL = getBaseUrl(this.provider); + this.oai = new OpenAI({ // The OpenAI JS SDK only requires `apiKey` when making requests against // the official API. When running unit‑tests we stub out all network @@ -245,7 +254,7 @@ export class AgentLoop { // errors inside the SDK (it validates that `apiKey` is a non‑empty // string when the field is present). ...(apiKey ? { apiKey } : {}), - baseURL: OPENAI_BASE_URL, + baseURL, defaultHeaders: { originator: ORIGIN, version: CLI_VERSION, @@ -492,11 +501,23 @@ export class AgentLoop { const mergedInstructions = [prefix, this.instructions] .filter(Boolean) .join("\n"); + + const responseCall = + !this.config.provider || + this.config.provider?.toLowerCase() === "openai" + ? 
(params: ResponseCreateParams) => + this.oai.responses.create(params) + : (params: ResponseCreateParams) => + responsesCreateViaChatCompletions( + this.oai, + params as ResponseCreateParams & { stream: true }, + ); log( `instructions (length ${mergedInstructions.length}): ${mergedInstructions}`, ); + // eslint-disable-next-line no-await-in-loop - stream = await this.oai.responses.create({ + stream = await responseCall({ model: this.model, instructions: mergedInstructions, previous_response_id: lastResponseId || undefined, @@ -720,7 +741,7 @@ export class AgentLoop { try { // eslint-disable-next-line no-await-in-loop - for await (const event of stream) { + for await (const event of stream as AsyncIterable) { log(`AgentLoop.run(): response event ${event.type}`); // process and surface each item (no‑op until we can depend on streaming events) diff --git a/codex-cli/src/utils/config.ts b/codex-cli/src/utils/config.ts index 3be77059..ada95025 100644 --- a/codex-cli/src/utils/config.ts +++ b/codex-cli/src/utils/config.ts @@ -10,6 +10,7 @@ import type { FullAutoErrorMode } from "./auto-approval-mode.js"; import { log } from "./agent/log.js"; import { AutoApprovalMode } from "./auto-approval-mode.js"; +import { providers } from "./providers.js"; import { existsSync, mkdirSync, readFileSync, writeFileSync } from "fs"; import { load as loadYaml, dump as dumpYaml } from "js-yaml"; import { homedir } from "os"; @@ -40,12 +41,33 @@ export function setApiKey(apiKey: string): void { OPENAI_API_KEY = apiKey; } +export function getBaseUrl(provider: string = "openai"): string | undefined { + const providerInfo = providers[provider.toLowerCase()]; + if (providerInfo) { + return providerInfo.baseURL; + } + return undefined; +} + +export function getApiKey(provider: string = "openai"): string | undefined { + const providerInfo = providers[provider.toLowerCase()]; + if (providerInfo) { + if (providerInfo.name === "Ollama") { + return process.env[providerInfo.envKey] ?? "dummy"; + } + return process.env[providerInfo.envKey]; + } + + return undefined; +} + // Formatting (quiet mode-only). export const PRETTY_PRINT = Boolean(process.env["PRETTY_PRINT"] || ""); // Represents config as persisted in config.json. export type StoredConfig = { model?: string; + provider?: string; approvalMode?: AutoApprovalMode; fullAutoErrorMode?: FullAutoErrorMode; memory?: MemoryConfig; @@ -76,6 +98,7 @@ export type MemoryConfig = { export type AppConfig = { apiKey?: string; model: string; + provider?: string; instructions: string; approvalMode?: AutoApprovalMode; fullAutoErrorMode?: FullAutoErrorMode; @@ -270,6 +293,7 @@ export const loadConfig = ( (options.isFullContext ? 
DEFAULT_FULL_CONTEXT_MODEL : DEFAULT_AGENTIC_MODEL), + provider: storedConfig.provider, instructions: combinedInstructions, notify: storedConfig.notify === true, approvalMode: storedConfig.approvalMode, @@ -389,6 +413,7 @@ export const saveConfig = ( // Create the config object to save const configToSave: StoredConfig = { model: config.model, + provider: config.provider, approvalMode: config.approvalMode, }; diff --git a/codex-cli/src/utils/model-utils.ts b/codex-cli/src/utils/model-utils.ts index 07d924c0..946756cb 100644 --- a/codex-cli/src/utils/model-utils.ts +++ b/codex-cli/src/utils/model-utils.ts @@ -1,4 +1,4 @@ -import { OPENAI_API_KEY } from "./config"; +import { getBaseUrl, getApiKey } from "./config"; import OpenAI from "openai"; const MODEL_LIST_TIMEOUT_MS = 2_000; // 2 seconds @@ -12,44 +12,38 @@ export const RECOMMENDED_MODELS: Array = ["o4-mini", "o3"]; * lifetime of the process and the results are cached for subsequent calls. */ -let modelsPromise: Promise> | null = null; - -async function fetchModels(): Promise> { +async function fetchModels(provider: string): Promise> { // If the user has not configured an API key we cannot hit the network. - if (!OPENAI_API_KEY) { - return RECOMMENDED_MODELS; + if (!getApiKey(provider)) { + throw new Error("No API key configured for provider: " + provider); } + const baseURL = getBaseUrl(provider); try { - const openai = new OpenAI({ apiKey: OPENAI_API_KEY }); + const openai = new OpenAI({ apiKey: getApiKey(provider), baseURL }); const list = await openai.models.list(); - const models: Array = []; for await (const model of list as AsyncIterable<{ id?: string }>) { if (model && typeof model.id === "string") { - models.push(model.id); + let modelStr = model.id; + // fix for gemini + if (modelStr.startsWith("models/")) { + modelStr = modelStr.replace("models/", ""); + } + models.push(modelStr); } } return models.sort(); - } catch { + } catch (error) { return []; } } -export function preloadModels(): void { - if (!modelsPromise) { - // Fire‑and‑forget – callers that truly need the list should `await` - // `getAvailableModels()` instead. 
- void getAvailableModels(); - } -} - -export async function getAvailableModels(): Promise> { - if (!modelsPromise) { - modelsPromise = fetchModels(); - } - return modelsPromise; +export async function getAvailableModels( + provider: string, +): Promise> { + return fetchModels(provider.toLowerCase()); } /** @@ -70,7 +64,7 @@ export async function isModelSupportedForResponses( try { const models = await Promise.race>([ - getAvailableModels(), + getAvailableModels("openai"), new Promise>((resolve) => setTimeout(() => resolve([]), MODEL_LIST_TIMEOUT_MS), ), diff --git a/codex-cli/src/utils/providers.ts b/codex-cli/src/utils/providers.ts new file mode 100644 index 00000000..adf628bb --- /dev/null +++ b/codex-cli/src/utils/providers.ts @@ -0,0 +1,45 @@ +export const providers: Record< + string, + { name: string; baseURL: string; envKey: string } +> = { + openai: { + name: "OpenAI", + baseURL: "https://api.openai.com/v1", + envKey: "OPENAI_API_KEY", + }, + openrouter: { + name: "OpenRouter", + baseURL: "https://openrouter.ai/api/v1", + envKey: "OPENROUTER_API_KEY", + }, + gemini: { + name: "Gemini", + baseURL: "https://generativelanguage.googleapis.com/v1beta/openai", + envKey: "GEMINI_API_KEY", + }, + ollama: { + name: "Ollama", + baseURL: "http://localhost:11434/v1", + envKey: "OLLAMA_API_KEY", + }, + mistral: { + name: "Mistral", + baseURL: "https://api.mistral.ai/v1", + envKey: "MISTRAL_API_KEY", + }, + deepseek: { + name: "DeepSeek", + baseURL: "https://api.deepseek.com", + envKey: "DEEPSEEK_API_KEY", + }, + xai: { + name: "xAI", + baseURL: "https://api.x.ai/v1", + envKey: "XAI_API_KEY", + }, + groq: { + name: "Groq", + baseURL: "https://api.groq.com/openai/v1", + envKey: "GROQ_API_KEY", + }, +}; diff --git a/codex-cli/src/utils/responses.ts b/codex-cli/src/utils/responses.ts new file mode 100644 index 00000000..bc972913 --- /dev/null +++ b/codex-cli/src/utils/responses.ts @@ -0,0 +1,736 @@ +import type { OpenAI } from "openai"; +import type { + ResponseCreateParams, + Response, +} from "openai/resources/responses/responses"; +// Define interfaces based on OpenAI API documentation +type ResponseCreateInput = ResponseCreateParams; +type ResponseOutput = Response; +// interface ResponseOutput { +// id: string; +// object: 'response'; +// created_at: number; +// status: 'completed' | 'failed' | 'in_progress' | 'incomplete'; +// error: { code: string; message: string } | null; +// incomplete_details: { reason: string } | null; +// instructions: string | null; +// max_output_tokens: number | null; +// model: string; +// output: Array<{ +// type: 'message'; +// id: string; +// status: 'completed' | 'in_progress'; +// role: 'assistant'; +// content: Array<{ +// type: 'output_text' | 'function_call'; +// text?: string; +// annotations?: Array; +// tool_call?: { +// id: string; +// type: 'function'; +// function: { name: string; arguments: string }; +// }; +// }>; +// }>; +// parallel_tool_calls: boolean; +// previous_response_id: string | null; +// reasoning: { effort: string | null; summary: string | null }; +// store: boolean; +// temperature: number; +// text: { format: { type: 'text' } }; +// tool_choice: string | object; +// tools: Array; +// top_p: number; +// truncation: string; +// usage: { +// input_tokens: number; +// input_tokens_details: { cached_tokens: number }; +// output_tokens: number; +// output_tokens_details: { reasoning_tokens: number }; +// total_tokens: number; +// } | null; +// user: string | null; +// metadata: Record; +// } + +// Define types for the ResponseItem content 
and parts +type ResponseContentPart = { + type: string; + [key: string]: unknown; +}; + +type ResponseItemType = { + type: string; + id?: string; + status?: string; + role?: string; + content?: Array; + [key: string]: unknown; +}; + +type ResponseEvent = + | { type: "response.created"; response: Partial } + | { type: "response.in_progress"; response: Partial } + | { + type: "response.output_item.added"; + output_index: number; + item: ResponseItemType; + } + | { + type: "response.content_part.added"; + item_id: string; + output_index: number; + content_index: number; + part: ResponseContentPart; + } + | { + type: "response.output_text.delta"; + item_id: string; + output_index: number; + content_index: number; + delta: string; + } + | { + type: "response.output_text.done"; + item_id: string; + output_index: number; + content_index: number; + text: string; + } + | { + type: "response.function_call_arguments.delta"; + item_id: string; + output_index: number; + content_index: number; + delta: string; + } + | { + type: "response.function_call_arguments.done"; + item_id: string; + output_index: number; + content_index: number; + arguments: string; + } + | { + type: "response.content_part.done"; + item_id: string; + output_index: number; + content_index: number; + part: ResponseContentPart; + } + | { + type: "response.output_item.done"; + output_index: number; + item: ResponseItemType; + } + | { type: "response.completed"; response: ResponseOutput } + | { type: "error"; code: string; message: string; param: string | null }; + +// Define a type for tool call data +type ToolCallData = { + id: string; + name: string; + arguments: string; +}; + +// Define a type for usage data +type UsageData = { + prompt_tokens?: number; + completion_tokens?: number; + total_tokens?: number; + input_tokens?: number; + input_tokens_details?: { cached_tokens: number }; + output_tokens?: number; + output_tokens_details?: { reasoning_tokens: number }; + [key: string]: unknown; +}; + +// Define a type for content output +type ResponseContentOutput = + | { + type: "function_call"; + call_id: string; + name: string; + arguments: string; + [key: string]: unknown; + } + | { + type: "output_text"; + text: string; + annotations: Array; + [key: string]: unknown; + }; + +// Global map to store conversation histories +const conversationHistories = new Map< + string, + { + previous_response_id: string | null; + messages: Array; + } +>(); + +// Utility function to generate unique IDs +function generateId(prefix: string = "msg"): string { + return `${prefix}_${Math.random().toString(36).substr(2, 9)}`; +} + +// Function to convert ResponseInputItem to ChatCompletionMessageParam +type ResponseInputItem = ResponseCreateInput["input"][number]; + +function convertInputItemToMessage( + item: string | ResponseInputItem, +): OpenAI.Chat.Completions.ChatCompletionMessageParam { + // Handle string inputs as content for a user message + if (typeof item === "string") { + return { role: "user", content: item }; + } + + // At this point we know it's a ResponseInputItem + const responseItem = item; + + if (responseItem.type === "message") { + // Use a more specific type assertion for the message content + const content = Array.isArray(responseItem.content) + ? responseItem.content + .filter((c) => typeof c === "object" && c.type === "input_text") + .map((c) => + typeof c === "object" && "text" in c + ? 
(c["text"] as string) || "" + : "", + ) + .join("") + : ""; + return { role: responseItem.role, content }; + } else if (responseItem.type === "function_call_output") { + return { + role: "tool", + tool_call_id: responseItem.call_id, + content: responseItem.output, + }; + } + throw new Error(`Unsupported input item type: ${responseItem.type}`); +} + +// Function to get full messages including history +function getFullMessages( + input: ResponseCreateInput, +): Array { + let baseHistory: Array = + []; + if (input.previous_response_id) { + const prev = conversationHistories.get(input.previous_response_id); + if (!prev) { + throw new Error( + `Previous response not found: ${input.previous_response_id}`, + ); + } + baseHistory = prev.messages; + } + + // Handle both string and ResponseInputItem in input.input + const newInputMessages = Array.isArray(input.input) + ? input.input.map(convertInputItemToMessage) + : [convertInputItemToMessage(input.input)]; + + const messages = [...baseHistory, ...newInputMessages]; + if ( + input.instructions && + messages[0]?.role !== "system" && + messages[0]?.role !== "developer" + ) { + return [{ role: "system", content: input.instructions }, ...messages]; + } + return messages; +} + +// Function to convert tools +function convertTools( + tools?: ResponseCreateInput["tools"], +): Array | undefined { + return tools + ?.filter((tool) => tool.type === "function") + .map((tool) => ({ + type: "function" as const, + function: { + name: tool.name, + description: tool.description || undefined, + parameters: tool.parameters, + }, + })); +} + +// Main function with overloading +async function responsesCreateViaChatCompletions( + openai: OpenAI, + input: ResponseCreateInput & { stream: true }, +): Promise>; +async function responsesCreateViaChatCompletions( + openai: OpenAI, + input: ResponseCreateInput & { stream?: false }, +): Promise; +async function responsesCreateViaChatCompletions( + openai: OpenAI, + input: ResponseCreateInput, +): Promise> { + if (input.stream) { + return streamResponses(openai, input); + } else { + return nonStreamResponses(openai, input); + } +} + +// Non-streaming implementation +async function nonStreamResponses( + openai: OpenAI, + input: ResponseCreateInput, +): Promise { + const fullMessages = getFullMessages(input); + const chatTools = convertTools(input.tools); + const webSearchOptions = input.tools?.some( + (tool) => tool.type === "function" && tool.name === "web_search", + ) + ? {} + : undefined; + + const chatInput: OpenAI.Chat.Completions.ChatCompletionCreateParams = { + model: input.model, + messages: fullMessages, + tools: chatTools, + web_search_options: webSearchOptions, + temperature: input.temperature, + top_p: input.top_p, + tool_choice: (input.tool_choice === "auto" + ? 
"auto" + : input.tool_choice) as OpenAI.Chat.Completions.ChatCompletionCreateParams["tool_choice"], + user: input.user, + metadata: input.metadata, + }; + + try { + const chatResponse = await openai.chat.completions.create(chatInput); + if (!("choices" in chatResponse) || chatResponse.choices.length === 0) { + throw new Error("No choices in chat completion response"); + } + const assistantMessage = chatResponse.choices?.[0]?.message; + if (!assistantMessage) { + throw new Error("No assistant message in chat completion response"); + } + + // Construct ResponseOutput + const responseId = generateId("resp"); + const outputItemId = generateId("msg"); + const outputContent: Array = []; + + // Check if the response contains tool calls + const hasFunctionCalls = + assistantMessage.tool_calls && assistantMessage.tool_calls.length > 0; + + if (hasFunctionCalls && assistantMessage.tool_calls) { + for (const toolCall of assistantMessage.tool_calls) { + if (toolCall.type === "function") { + outputContent.push({ + type: "function_call", + call_id: toolCall.id, + name: toolCall.function.name, + arguments: toolCall.function.arguments, + }); + } + } + } + + if (assistantMessage.content) { + outputContent.push({ + type: "output_text", + text: assistantMessage.content, + annotations: [], + }); + } + + // Create response with appropriate status and properties + const responseOutput = { + id: responseId, + object: "response", + created_at: Math.floor(Date.now() / 1000), + status: hasFunctionCalls ? "requires_action" : "completed", + error: null, + incomplete_details: null, + instructions: null, + max_output_tokens: null, + model: chatResponse.model, + output: [ + { + type: "message", + id: outputItemId, + status: "completed", + role: "assistant", + content: outputContent, + }, + ], + parallel_tool_calls: input.parallel_tool_calls ?? false, + previous_response_id: input.previous_response_id ?? null, + reasoning: null, + temperature: input.temperature ?? 1.0, + text: { format: { type: "text" } }, + tool_choice: input.tool_choice ?? "auto", + tools: input.tools ?? [], + top_p: input.top_p ?? 1.0, + truncation: input.truncation ?? "disabled", + usage: chatResponse.usage + ? { + input_tokens: chatResponse.usage.prompt_tokens, + input_tokens_details: { cached_tokens: 0 }, + output_tokens: chatResponse.usage.completion_tokens, + output_tokens_details: { reasoning_tokens: 0 }, + total_tokens: chatResponse.usage.total_tokens, + } + : undefined, + user: input.user ?? undefined, + metadata: input.metadata ?? {}, + output_text: "", + } as ResponseOutput; + + // Add required_action property for tool calls + if (hasFunctionCalls && assistantMessage.tool_calls) { + // Define type with required action + type ResponseWithAction = Partial & { + required_action: unknown; + }; + + // Use the defined type for the assertion + (responseOutput as ResponseWithAction).required_action = { + type: "submit_tool_outputs", + submit_tool_outputs: { + tool_calls: assistantMessage.tool_calls.map((toolCall) => ({ + id: toolCall.id, + type: toolCall.type, + function: { + name: toolCall.function.name, + arguments: toolCall.function.arguments, + }, + })), + }, + }; + } + + // Store history + const newHistory = [...fullMessages, assistantMessage]; + conversationHistories.set(responseId, { + previous_response_id: input.previous_response_id ?? null, + messages: newHistory, + }); + + return responseOutput; + } catch (error) { + const errorMessage = error instanceof Error ? 
error.message : String(error); + throw new Error(`Failed to process chat completion: ${errorMessage}`); + } +} + +// Streaming implementation +async function* streamResponses( + openai: OpenAI, + input: ResponseCreateInput, +): AsyncGenerator { + const fullMessages = getFullMessages(input); + const chatTools = convertTools(input.tools); + const webSearchOptions = input.tools?.some( + (tool) => tool.type === "function" && tool.name === "web_search", + ) + ? {} + : undefined; + + const chatInput: OpenAI.Chat.Completions.ChatCompletionCreateParams = { + model: input.model, + messages: fullMessages, + tools: chatTools, + web_search_options: webSearchOptions, + temperature: input.temperature ?? 1.0, + top_p: input.top_p ?? 1.0, + tool_choice: (input.tool_choice === "auto" + ? "auto" + : input.tool_choice) as OpenAI.Chat.Completions.ChatCompletionCreateParams["tool_choice"], + stream: true, + user: input.user, + metadata: input.metadata, + }; + + try { + // console.error("chatInput", JSON.stringify(chatInput)); + const stream = await openai.chat.completions.create(chatInput); + + // Initialize state + const responseId = generateId("resp"); + const outputItemId = generateId("msg"); + let textContentAdded = false; + let textContent = ""; + const toolCalls = new Map(); + let usage: UsageData | null = null; + const finalOutputItem: Array = []; + // Initial response + const initialResponse: Partial = { + id: responseId, + object: "response" as const, + created_at: Math.floor(Date.now() / 1000), + status: "in_progress" as const, + model: input.model, + output: [], + error: null, + incomplete_details: null, + instructions: null, + max_output_tokens: null, + parallel_tool_calls: true, + previous_response_id: input.previous_response_id ?? null, + reasoning: null, + temperature: input.temperature ?? 1.0, + text: { format: { type: "text" } }, + tool_choice: input.tool_choice ?? "auto", + tools: input.tools ?? [], + top_p: input.top_p ?? 1.0, + truncation: input.truncation ?? "disabled", + usage: undefined, + user: input.user ?? undefined, + metadata: input.metadata ?? {}, + output_text: "", + }; + yield { type: "response.created", response: initialResponse }; + yield { type: "response.in_progress", response: initialResponse }; + let isToolCall = false; + for await (const chunk of stream as AsyncIterable) { + // console.error('\nCHUNK: ', JSON.stringify(chunk)); + const choice = chunk.choices[0]; + if (!choice) { + continue; + } + if ( + !isToolCall && + (("tool_calls" in choice.delta && choice.delta.tool_calls) || + choice.finish_reason === "tool_calls") + ) { + isToolCall = true; + } + + if (chunk.usage) { + usage = { + prompt_tokens: chunk.usage.prompt_tokens, + completion_tokens: chunk.usage.completion_tokens, + total_tokens: chunk.usage.total_tokens, + input_tokens: chunk.usage.prompt_tokens, + input_tokens_details: { cached_tokens: 0 }, + output_tokens: chunk.usage.completion_tokens, + output_tokens_details: { reasoning_tokens: 0 }, + }; + } + if (isToolCall) { + for (const tcDelta of choice.delta.tool_calls || []) { + const tcIndex = tcDelta.index; + const content_index = textContentAdded ? 
tcIndex + 1 : tcIndex; + + if (!toolCalls.has(tcIndex)) { + // New tool call + const toolCallId = tcDelta.id || generateId("call"); + const functionName = tcDelta.function?.name || ""; + + yield { + type: "response.output_item.added", + item: { + type: "function_call", + id: outputItemId, + status: "in_progress", + call_id: toolCallId, + name: functionName, + arguments: "", + }, + output_index: 0, + }; + toolCalls.set(tcIndex, { + id: toolCallId, + name: functionName, + arguments: "", + }); + } + + if (tcDelta.function?.arguments) { + const current = toolCalls.get(tcIndex); + if (current) { + current.arguments += tcDelta.function.arguments; + yield { + type: "response.function_call_arguments.delta", + item_id: outputItemId, + output_index: 0, + content_index, + delta: tcDelta.function.arguments, + }; + } + } + } + + if (choice.finish_reason === "tool_calls") { + for (const [tcIndex, tc] of toolCalls) { + const item = { + type: "function_call", + id: outputItemId, + status: "completed", + call_id: tc.id, + name: tc.name, + arguments: tc.arguments, + }; + yield { + type: "response.function_call_arguments.done", + item_id: outputItemId, + output_index: tcIndex, + content_index: textContentAdded ? tcIndex + 1 : tcIndex, + arguments: tc.arguments, + }; + yield { + type: "response.output_item.done", + output_index: tcIndex, + item, + }; + finalOutputItem.push(item as unknown as ResponseContentOutput); + } + } else { + continue; + } + } else { + if (!textContentAdded) { + yield { + type: "response.content_part.added", + item_id: outputItemId, + output_index: 0, + content_index: 0, + part: { type: "output_text", text: "", annotations: [] }, + }; + textContentAdded = true; + } + if (choice.delta.content?.length) { + yield { + type: "response.output_text.delta", + item_id: outputItemId, + output_index: 0, + content_index: 0, + delta: choice.delta.content, + }; + textContent += choice.delta.content; + } + if (choice.finish_reason) { + yield { + type: "response.output_text.done", + item_id: outputItemId, + output_index: 0, + content_index: 0, + text: textContent, + }; + yield { + type: "response.content_part.done", + item_id: outputItemId, + output_index: 0, + content_index: 0, + part: { type: "output_text", text: textContent, annotations: [] }, + }; + const item = { + type: "message", + id: outputItemId, + status: "completed", + role: "assistant", + content: [ + { type: "output_text", text: textContent, annotations: [] }, + ], + }; + yield { + type: "response.output_item.done", + output_index: 0, + item, + }; + finalOutputItem.push(item as unknown as ResponseContentOutput); + } else { + continue; + } + } + + // Construct final response + const finalResponse: ResponseOutput = { + id: responseId, + object: "response" as const, + created_at: initialResponse.created_at || Math.floor(Date.now() / 1000), + status: "completed" as const, + error: null, + incomplete_details: null, + instructions: null, + max_output_tokens: null, + model: chunk.model || input.model, + output: finalOutputItem as unknown as ResponseOutput["output"], + parallel_tool_calls: true, + previous_response_id: input.previous_response_id ?? null, + reasoning: null, + temperature: input.temperature ?? 1.0, + text: { format: { type: "text" } }, + tool_choice: input.tool_choice ?? "auto", + tools: input.tools ?? [], + top_p: input.top_p ?? 1.0, + truncation: input.truncation ?? "disabled", + usage: usage as ResponseOutput["usage"], + user: input.user ?? undefined, + metadata: input.metadata ?? 
{}, + output_text: "", + } as ResponseOutput; + + // Store history + const assistantMessage = { + role: "assistant" as const, + content: textContent || null, + }; + + // Add tool_calls property if needed + if (toolCalls.size > 0) { + const toolCallsArray = Array.from(toolCalls.values()).map((tc) => ({ + id: tc.id, + type: "function" as const, + function: { name: tc.name, arguments: tc.arguments }, + })); + + // Define a more specific type for the assistant message with tool calls + type AssistantMessageWithToolCalls = + OpenAI.Chat.Completions.ChatCompletionMessageParam & { + tool_calls: Array<{ + id: string; + type: "function"; + function: { + name: string; + arguments: string; + }; + }>; + }; + + // Use type assertion with the defined type + (assistantMessage as AssistantMessageWithToolCalls).tool_calls = + toolCallsArray; + } + const newHistory = [...fullMessages, assistantMessage]; + conversationHistories.set(responseId, { + previous_response_id: input.previous_response_id ?? null, + messages: newHistory, + }); + + yield { type: "response.completed", response: finalResponse }; + } + } catch (error) { + // console.error('\nERROR: ', JSON.stringify(error)); + yield { + type: "error", + code: + error instanceof Error && "code" in error + ? (error as { code: string }).code + : "unknown", + message: error instanceof Error ? error.message : String(error), + param: null, + }; + } +} + +export { + responsesCreateViaChatCompletions, + ResponseCreateInput, + ResponseOutput, + ResponseEvent, +}; diff --git a/codex-cli/tests/responses-chat-completions.test.ts b/codex-cli/tests/responses-chat-completions.test.ts new file mode 100644 index 00000000..85ab7d7d --- /dev/null +++ b/codex-cli/tests/responses-chat-completions.test.ts @@ -0,0 +1,842 @@ +import { describe, it, expect, vi, afterEach, beforeEach } from "vitest"; +import type { OpenAI } from "openai"; +import type { + ResponseCreateInput, + ResponseEvent, +} from "../src/utils/responses"; +import type { + ResponseInputItem, + Tool, + ResponseCreateParams, + ResponseFunctionToolCallItem, + ResponseFunctionToolCall, +} from "openai/resources/responses/responses"; + +// Define specific types for streaming and non-streaming params +type ResponseCreateParamsStreaming = ResponseCreateParams & { stream: true }; +type ResponseCreateParamsNonStreaming = ResponseCreateParams & { + stream?: false; +}; + +// Define additional type guard for tool calls done event +type ToolCallsDoneEvent = Extract< + ResponseEvent, + { type: "response.function_call_arguments.done" } +>; +type OutputTextDeltaEvent = Extract< + ResponseEvent, + { type: "response.output_text.delta" } +>; +type OutputTextDoneEvent = Extract< + ResponseEvent, + { type: "response.output_text.done" } +>; +type ResponseCompletedEvent = Extract< + ResponseEvent, + { type: "response.completed" } +>; + +// Mock state to control the OpenAI client behavior +const openAiState: { + createSpy?: ReturnType; + createStreamSpy?: ReturnType; +} = {}; + +// Mock the OpenAI client +vi.mock("openai", () => { + class FakeOpenAI { + public chat = { + completions: { + create: (...args: Array) => { + if (args[0]?.stream) { + return openAiState.createStreamSpy!(...args); + } + return openAiState.createSpy!(...args); + }, + }, + }; + } + + return { + __esModule: true, + default: FakeOpenAI, + }; +}); + +// Helper function to create properly typed test inputs +function createTestInput(options: { + model: string; + userMessage: string; + stream?: boolean; + tools?: Array; + previousResponseId?: string; +}): 
ResponseCreateInput { + const message: ResponseInputItem.Message = { + type: "message", + role: "user", + content: [ + { + type: "input_text" as const, + text: options.userMessage, + }, + ], + }; + + const input: ResponseCreateInput = { + model: options.model, + input: [message], + }; + + if (options.stream !== undefined) { + // @ts-expect-error TypeScript doesn't recognize this is valid + input.stream = options.stream; + } + + if (options.tools) { + input.tools = options.tools; + } + + if (options.previousResponseId) { + input.previous_response_id = options.previousResponseId; + } + + return input; +} + +// Type guard for function call content +function isFunctionCall(content: any): content is ResponseFunctionToolCall { + return ( + content && typeof content === "object" && content.type === "function_call" + ); +} + +// Additional type guard for tool call +function isToolCall(item: any): item is ResponseFunctionToolCallItem { + return item && typeof item === "object" && item.type === "function"; +} + +// Type guards for various event types +export function _isToolCallsDoneEvent( + event: ResponseEvent, +): event is ToolCallsDoneEvent { + return event.type === "response.function_call_arguments.done"; +} + +function isOutputTextDeltaEvent( + event: ResponseEvent, +): event is OutputTextDeltaEvent { + return event.type === "response.output_text.delta"; +} + +function isOutputTextDoneEvent( + event: ResponseEvent, +): event is OutputTextDoneEvent { + return event.type === "response.output_text.done"; +} + +function isResponseCompletedEvent( + event: ResponseEvent, +): event is ResponseCompletedEvent { + return event.type === "response.completed"; +} + +// Helper function to create a mock stream for tool calls testing +function createToolCallsStream() { + async function* fakeToolStream() { + yield { + id: "chatcmpl-123", + model: "gpt-4o", + choices: [ + { + delta: { role: "assistant" }, + finish_reason: null, + index: 0, + }, + ], + }; + yield { + id: "chatcmpl-123", + model: "gpt-4o", + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + id: "call_123", + type: "function", + function: { name: "get_weather" }, + }, + ], + }, + finish_reason: null, + index: 0, + }, + ], + }; + yield { + id: "chatcmpl-123", + model: "gpt-4o", + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + function: { + arguments: '{"location":"San Franci', + }, + }, + ], + }, + finish_reason: null, + index: 0, + }, + ], + }; + yield { + id: "chatcmpl-123", + model: "gpt-4o", + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + function: { + arguments: 'sco"}', + }, + }, + ], + }, + finish_reason: null, + index: 0, + }, + ], + }; + yield { + id: "chatcmpl-123", + model: "gpt-4o", + choices: [ + { + delta: {}, + finish_reason: "tool_calls", + index: 0, + }, + ], + usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, + }; + } + + return fakeToolStream(); +} + +describe("responsesCreateViaChatCompletions", () => { + // Using any type here to avoid import issues + let responsesModule: any; + + beforeEach(async () => { + vi.resetModules(); + responsesModule = await import("../src/utils/responses"); + }); + + afterEach(() => { + vi.resetAllMocks(); + openAiState.createSpy = undefined; + openAiState.createStreamSpy = undefined; + }); + + describe("non-streaming mode", () => { + it("should convert basic user message to chat completions format", async () => { + // Setup mock response + openAiState.createSpy = vi.fn().mockResolvedValue({ + id: "chat-123", + model: "gpt-4o", + choices: [ + { 
+ message: { + role: "assistant", + content: "This is a test response", + }, + finish_reason: "stop", + }, + ], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + }, + }); + + const openaiClient = new (await import("openai")).default({ + apiKey: "test-key", + }) as unknown as OpenAI; + + const inputMessage = createTestInput({ + model: "gpt-4o", + userMessage: "Hello world", + stream: false, + }); + + const result = await responsesModule.responsesCreateViaChatCompletions( + openaiClient, + inputMessage as ResponseCreateParams & { stream?: false | undefined }, + ); + + // Verify OpenAI was called with correct parameters + expect(openAiState.createSpy).toHaveBeenCalledTimes(1); + + // Skip type checking for mock objects in tests - this is acceptable for test code + // @ts-ignore + const callArgs = openAiState.createSpy?.mock?.calls?.[0]?.[0]; + if (callArgs) { + expect(callArgs.model).toBe("gpt-4o"); + expect(callArgs.messages).toEqual([ + { role: "user", content: "Hello world" }, + ]); + expect(callArgs.stream).toBeUndefined(); + } + + // Verify result format + expect(result.id).toBeDefined(); + expect(result.object).toBe("response"); + expect(result.model).toBe("gpt-4o"); + expect(result.status).toBe("completed"); + expect(result.output).toHaveLength(1); + + // Use type guard to check the output item type + const outputItem = result.output[0]; + expect(outputItem).toBeDefined(); + + if (outputItem && outputItem.type === "message") { + expect(outputItem.role).toBe("assistant"); + expect(outputItem.content).toHaveLength(1); + + const content = outputItem.content[0]; + if (content && content.type === "output_text") { + expect(content.text).toBe("This is a test response"); + } + } + + expect(result.usage?.total_tokens).toBe(15); + }); + + it("should handle function calling correctly", async () => { + // Setup mock response with tool calls + openAiState.createSpy = vi.fn().mockResolvedValue({ + id: "chat-456", + model: "gpt-4o", + choices: [ + { + message: { + role: "assistant", + content: null, + tool_calls: [ + { + id: "call_abc123", + type: "function", + function: { + name: "get_weather", + arguments: JSON.stringify({ location: "New York" }), + }, + }, + ], + }, + finish_reason: "tool_calls", + }, + ], + usage: { + prompt_tokens: 15, + completion_tokens: 8, + total_tokens: 23, + }, + }); + + const openaiClient = new (await import("openai")).default({ + apiKey: "test-key", + }) as unknown as OpenAI; + + // Define function tool correctly + const weatherTool = { + type: "function" as const, + name: "get_weather", + description: "Get the current weather", + strict: true, + parameters: { + type: "object", + properties: { + location: { type: "string" }, + }, + required: ["location"], + }, + }; + + const inputMessage = createTestInput({ + model: "gpt-4o", + userMessage: "What's the weather in New York?", + tools: [weatherTool as any], + stream: false, + }); + + const result = await responsesModule.responsesCreateViaChatCompletions( + openaiClient, + inputMessage as ResponseCreateParams & { stream: false }, + ); + + // Verify OpenAI was called with correct parameters + expect(openAiState.createSpy).toHaveBeenCalledTimes(1); + + // Skip type checking for mock objects in tests + // @ts-ignore + const callArgs = openAiState.createSpy?.mock?.calls?.[0]?.[0]; + if (callArgs) { + expect(callArgs.model).toBe("gpt-4o"); + expect(callArgs.tools).toHaveLength(1); + expect(callArgs.tools[0].function.name).toBe("get_weather"); + } + + // Verify function call output directly instead of 
trying to check type + expect(result.output).toHaveLength(1); + + const outputItem = result.output[0]; + if (outputItem && outputItem.type === "message") { + const content = outputItem.content[0]; + + // Use the type guard function + expect(isFunctionCall(content)).toBe(true); + + // Using type assertion after type guard check + if (isFunctionCall(content)) { + // These properties should exist on ResponseFunctionToolCall + expect((content as any).name).toBe("get_weather"); + expect(JSON.parse((content as any).arguments).location).toBe( + "New York", + ); + } + } + }); + + it("should preserve conversation history", async () => { + // First interaction + openAiState.createSpy = vi.fn().mockResolvedValue({ + id: "chat-789", + model: "gpt-4o", + choices: [ + { + message: { + role: "assistant", + content: "Hello! How can I help you?", + }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 5, completion_tokens: 6, total_tokens: 11 }, + }); + + const openaiClient = new (await import("openai")).default({ + apiKey: "test-key", + }) as unknown as OpenAI; + + const firstInput = createTestInput({ + model: "gpt-4o", + userMessage: "Hi there", + stream: false, + }); + + const firstResponse = + await responsesModule.responsesCreateViaChatCompletions( + openaiClient, + firstInput as unknown as ResponseCreateParamsNonStreaming & { + stream?: false | undefined; + }, + ); + + // Reset the mock for second interaction + openAiState.createSpy.mockReset(); + openAiState.createSpy = vi.fn().mockResolvedValue({ + id: "chat-790", + model: "gpt-4o", + choices: [ + { + message: { + role: "assistant", + content: "I'm an AI assistant created by Anthropic.", + }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 15, completion_tokens: 10, total_tokens: 25 }, + }); + + // Second interaction with previous_response_id + const secondInput = createTestInput({ + model: "gpt-4o", + userMessage: "Who are you?", + previousResponseId: firstResponse.id, + stream: false, + }); + + await responsesModule.responsesCreateViaChatCompletions( + openaiClient, + secondInput as unknown as ResponseCreateParamsNonStreaming & { + stream?: false | undefined; + }, + ); + + // Verify history was included in second call + expect(openAiState.createSpy).toHaveBeenCalledTimes(1); + + // Skip type checking for mock objects in tests + // @ts-ignore + const secondCallArgs = openAiState.createSpy?.mock?.calls?.[0]?.[0]; + if (secondCallArgs) { + // Should have 3 messages: original user, assistant response, and new user message + expect(secondCallArgs.messages).toHaveLength(3); + expect(secondCallArgs.messages[0].role).toBe("user"); + expect(secondCallArgs.messages[0].content).toBe("Hi there"); + expect(secondCallArgs.messages[1].role).toBe("assistant"); + expect(secondCallArgs.messages[1].content).toBe( + "Hello! 
How can I help you?", + ); + expect(secondCallArgs.messages[2].role).toBe("user"); + expect(secondCallArgs.messages[2].content).toBe("Who are you?"); + } + }); + + it("handles tools correctly", async () => { + const testFunction = { + type: "function" as const, + name: "get_weather", + description: "Get the weather", + strict: true, + parameters: { + type: "object", + properties: { + location: { + type: "string", + description: "The location to get the weather for", + }, + }, + required: ["location"], + }, + }; + + // Mock response with a tool call + openAiState.createSpy = vi.fn().mockResolvedValue({ + id: "chatcmpl-123", + created: Date.now(), + model: "gpt-4o", + object: "chat.completion", + choices: [ + { + message: { + role: "assistant", + content: null, + tool_calls: [ + { + id: "call_123", + type: "function", + function: { + name: "get_weather", + arguments: JSON.stringify({ location: "San Francisco" }), + }, + }, + ], + }, + finish_reason: "tool_calls", + index: 0, + }, + ], + }); + + const openaiClient = new (await import("openai")).default({ + apiKey: "test-key", + }) as unknown as OpenAI; + + const inputMessage = createTestInput({ + model: "gpt-4o", + userMessage: "What's the weather in San Francisco?", + tools: [testFunction], + }); + + const result = await responsesModule.responsesCreateViaChatCompletions( + openaiClient, + inputMessage as unknown as ResponseCreateParamsNonStreaming, + ); + + expect(result.status).toBe("requires_action"); + + // Cast result to include required_action to address TypeScript issues + const resultWithAction = result as any; + + // Add null checks for required_action + expect(resultWithAction.required_action).not.toBeNull(); + expect(resultWithAction.required_action?.type).toBe( + "submit_tool_outputs", + ); + + // Safely access the tool calls with proper null checks + const toolCalls = + resultWithAction.required_action?.submit_tool_outputs?.tool_calls || []; + expect(toolCalls.length).toBe(1); + + if (toolCalls.length > 0) { + const toolCall = toolCalls[0]; + expect(toolCall.type).toBe("function"); + + if (isToolCall(toolCall)) { + // Access with type assertion after type guard + expect((toolCall as any).function.name).toBe("get_weather"); + expect(JSON.parse((toolCall as any).function.arguments)).toEqual({ + location: "San Francisco", + }); + } + } + + // Only check model, messages, and tools in exact match + expect(openAiState.createSpy).toHaveBeenCalledWith( + expect.objectContaining({ + model: "gpt-4o", + messages: [ + { + role: "user", + content: "What's the weather in San Francisco?", + }, + ], + tools: [ + expect.objectContaining({ + type: "function", + function: { + name: "get_weather", + description: "Get the weather", + parameters: { + type: "object", + properties: { + location: { + type: "string", + description: "The location to get the weather for", + }, + }, + required: ["location"], + }, + }, + }), + ], + }), + ); + }); + }); + + describe("streaming mode", () => { + it("should handle streaming responses correctly", async () => { + // Mock an async generator for streaming + async function* fakeStream() { + yield { + id: "chatcmpl-123", + model: "gpt-4o", + choices: [ + { + delta: { role: "assistant" }, + finish_reason: null, + index: 0, + }, + ], + }; + yield { + id: "chatcmpl-123", + model: "gpt-4o", + choices: [ + { + delta: { content: "Hello" }, + finish_reason: null, + index: 0, + }, + ], + }; + yield { + id: "chatcmpl-123", + model: "gpt-4o", + choices: [ + { + delta: { content: " world" }, + finish_reason: null, + index: 0, + 
}, + ], + }; + yield { + id: "chatcmpl-123", + model: "gpt-4o", + choices: [ + { + delta: {}, + finish_reason: "stop", + index: 0, + }, + ], + usage: { prompt_tokens: 5, completion_tokens: 2, total_tokens: 7 }, + }; + } + + openAiState.createStreamSpy = vi.fn().mockResolvedValue(fakeStream()); + + const openaiClient = new (await import("openai")).default({ + apiKey: "test-key", + }) as unknown as OpenAI; + + const inputMessage = createTestInput({ + model: "gpt-4o", + userMessage: "Say hello", + stream: true, + }); + + const streamGenerator = + await responsesModule.responsesCreateViaChatCompletions( + openaiClient, + inputMessage as unknown as ResponseCreateParamsStreaming & { + stream: true; + }, + ); + + // Collect all events from the stream + const events: Array = []; + for await (const event of streamGenerator) { + events.push(event); + } + + // Verify stream generation + expect(events.length).toBeGreaterThan(0); + + // Check initial events + const firstEvent = events[0]; + const secondEvent = events[1]; + expect(firstEvent?.type).toBe("response.created"); + expect(secondEvent?.type).toBe("response.in_progress"); + + // Find content delta events using proper type guard + const deltaEvents = events.filter(isOutputTextDeltaEvent); + + // Should have two delta events for "Hello" and " world" + expect(deltaEvents).toHaveLength(2); + expect(deltaEvents[0]?.delta).toBe("Hello"); + expect(deltaEvents[1]?.delta).toBe(" world"); + + // Check final completion event with type guard + const completionEvent = events.find(isResponseCompletedEvent); + expect(completionEvent).toBeDefined(); + if (completionEvent) { + expect(completionEvent.response.status).toBe("completed"); + } + + // Text should be concatenated + const textDoneEvent = events.find(isOutputTextDoneEvent); + expect(textDoneEvent).toBeDefined(); + if (textDoneEvent) { + expect(textDoneEvent.text).toBe("Hello world"); + } + }); + + it("should handle errors gracefully", async () => { + // Setup mock to throw an error + openAiState.createSpy = vi + .fn() + .mockRejectedValue(new Error("API connection error")); + + const openaiClient = new (await import("openai")).default({ + apiKey: "test-key", + }) as unknown as OpenAI; + + const inputMessage = createTestInput({ + model: "gpt-4o", + userMessage: "Test message", + stream: false, + }); + + // Expect the function to throw an error + await expect( + responsesModule.responsesCreateViaChatCompletions( + openaiClient, + inputMessage as unknown as ResponseCreateParamsNonStreaming & { + stream?: false | undefined; + }, + ), + ).rejects.toThrow("Failed to process chat completion"); + }); + + it("handles streaming with tool calls", async () => { + // Mock a streaming response with tool calls + const mockStream = createToolCallsStream(); + openAiState.createStreamSpy = vi.fn().mockReturnValue(mockStream); + + const openaiClient = new (await import("openai")).default({ + apiKey: "test-key", + }) as unknown as OpenAI; + + const testFunction = { + type: "function" as const, + name: "get_weather", + description: "Get the current weather", + strict: true, + parameters: { + type: "object", + properties: { + location: { type: "string" }, + }, + required: ["location"], + }, + }; + + const inputMessage = createTestInput({ + model: "gpt-4o", + userMessage: "What's the weather in San Francisco?", + tools: [testFunction], + stream: true, + }); + + const streamGenerator = + await responsesModule.responsesCreateViaChatCompletions( + openaiClient, + inputMessage as unknown as ResponseCreateParamsStreaming, + ); + + 
// Collect all events from the stream + const events: Array = []; + for await (const event of streamGenerator) { + events.push(event); + } + + // Verify stream generation + expect(events.length).toBeGreaterThan(0); + + // Look for function call related events of any type related to tool calls + const toolCallEvents = events.filter( + (event) => + event.type.includes("function_call") || + event.type.includes("tool") || + (event.type === "response.output_item.added" && + "item" in event && + event.item?.type === "function_call"), + ); + + expect(toolCallEvents.length).toBeGreaterThan(0); + + // Check if we have the completed event which should contain the final result + const completedEvent = events.find(isResponseCompletedEvent); + expect(completedEvent).toBeDefined(); + + if (completedEvent) { + // Get the function call from the output array + const functionCallItem = completedEvent.response.output.find( + (item) => item.type === "function_call", + ); + expect(functionCallItem).toBeDefined(); + + if (functionCallItem && functionCallItem.type === "function_call") { + expect(functionCallItem.name).toBe("get_weather"); + // The arguments is a JSON string, but we can check if it includes San Francisco + expect(functionCallItem.arguments).toContain("San Francisco"); + } + } + }); + }); +});
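A minimal, standalone sketch (not part of the diff) of the provider-resolution behavior introduced in `utils/providers.ts` and `utils/config.ts`: the provider name is lower-cased and looked up in the `providers` map to obtain a base URL and the environment variable holding the API key, and Ollama falls back to a placeholder key because the local server does not require one. The two-entry map below is a trimmed copy of the real table, included only to keep the example self-contained.

```typescript
// Sketch of getBaseUrl/getApiKey as added in this diff (trimmed provider table).
const providers: Record<string, { name: string; baseURL: string; envKey: string }> = {
  openai: { name: "OpenAI", baseURL: "https://api.openai.com/v1", envKey: "OPENAI_API_KEY" },
  ollama: { name: "Ollama", baseURL: "http://localhost:11434/v1", envKey: "OLLAMA_API_KEY" },
};

function getBaseUrl(provider: string = "openai"): string | undefined {
  // Unknown providers resolve to undefined, which leaves the SDK default in place.
  return providers[provider.toLowerCase()]?.baseURL;
}

function getApiKey(provider: string = "openai"): string | undefined {
  const info = providers[provider.toLowerCase()];
  if (!info) {
    return undefined;
  }
  // Ollama runs locally and needs no real key, so a placeholder is substituted
  // when the environment variable is unset.
  return info.name === "Ollama"
    ? process.env[info.envKey] ?? "dummy"
    : process.env[info.envKey];
}

// Example: running with `--provider ollama` points the OpenAI client at the
// local endpoint and supplies the placeholder key if OLLAMA_API_KEY is unset.
console.log(getBaseUrl("ollama")); // http://localhost:11434/v1
console.log(getApiKey("ollama"));  // value of OLLAMA_API_KEY, or "dummy"
```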