fix: inconsistent usage of base URL and API key (#507)

A recent commit introduced the ability to use third-party model
providers. (Really appreciate it!)

However, the usage is inconsistent: some pieces of code use the custom
providers, whereas others still have the old behavior. Additionally,
`OPENAI_BASE_URL` is now being disregarded when it shouldn't be.

This PR normalizes the usage to `getApiKey` and `getBaseUrl`, and
enables the use of `OPENAI_BASE_URL` if present.

---------

Co-authored-by: Gabriel Bianconi <GabrielBianconi@users.noreply.github.com>
This commit is contained in:
Gabriel Bianconi
2025-04-22 10:51:26 -04:00
committed by GitHub
parent d78f77edb7
commit 98a22273d9
5 changed files with 39 additions and 19 deletions

View File

@@ -13,7 +13,7 @@ import { useTerminalSize } from "../../hooks/use-terminal-size.js";
import { AgentLoop } from "../../utils/agent/agent-loop.js"; import { AgentLoop } from "../../utils/agent/agent-loop.js";
import { ReviewDecision } from "../../utils/agent/review.js"; import { ReviewDecision } from "../../utils/agent/review.js";
import { generateCompactSummary } from "../../utils/compact-summary.js"; import { generateCompactSummary } from "../../utils/compact-summary.js";
import { OPENAI_BASE_URL, saveConfig } from "../../utils/config.js"; import { getBaseUrl, getApiKey, saveConfig } from "../../utils/config.js";
import { extractAppliedPatches as _extractAppliedPatches } from "../../utils/extract-applied-patches.js"; import { extractAppliedPatches as _extractAppliedPatches } from "../../utils/extract-applied-patches.js";
import { getGitDiff } from "../../utils/get-diff.js"; import { getGitDiff } from "../../utils/get-diff.js";
import { createInputItem } from "../../utils/input-utils.js"; import { createInputItem } from "../../utils/input-utils.js";
@@ -65,18 +65,21 @@ const colorsByPolicy: Record<ApprovalPolicy, ColorName | undefined> = {
* *
* @param command The command to explain * @param command The command to explain
* @param model The model to use for generating the explanation * @param model The model to use for generating the explanation
* @param flexMode Whether to use the flex-mode service tier
* @param config The configuration object
* @returns A human-readable explanation of what the command does * @returns A human-readable explanation of what the command does
*/ */
async function generateCommandExplanation( async function generateCommandExplanation(
command: Array<string>, command: Array<string>,
model: string, model: string,
flexMode: boolean, flexMode: boolean,
config: AppConfig,
): Promise<string> { ): Promise<string> {
try { try {
// Create a temporary OpenAI client // Create a temporary OpenAI client
const oai = new OpenAI({ const oai = new OpenAI({
apiKey: process.env["OPENAI_API_KEY"], apiKey: getApiKey(config.provider),
baseURL: OPENAI_BASE_URL, baseURL: getBaseUrl(config.provider),
}); });
// Format the command for display // Format the command for display
@@ -156,6 +159,7 @@ export default function TerminalChat({
items, items,
model, model,
Boolean(config.flexMode), Boolean(config.flexMode),
config,
); );
setItems([ setItems([
{ {
@@ -272,6 +276,7 @@ export default function TerminalChat({
command, command,
model, model,
Boolean(config.flexMode), Boolean(config.flexMode),
config,
); );
log(`Generated explanation: ${explanation}`); log(`Generated explanation: ${explanation}`);

View File

@@ -5,12 +5,7 @@ import type { FileOperation } from "../utils/singlepass/file_ops";
import Spinner from "./vendor/ink-spinner"; // Thirdparty / vendor components import Spinner from "./vendor/ink-spinner"; // Thirdparty / vendor components
import TextInput from "./vendor/ink-text-input"; import TextInput from "./vendor/ink-text-input";
import { import { OPENAI_TIMEOUT_MS, getBaseUrl, getApiKey } from "../utils/config";
OPENAI_TIMEOUT_MS,
OPENAI_BASE_URL as _OPENAI_BASE_URL,
getBaseUrl,
getApiKey,
} from "../utils/config";
import { import {
generateDiffSummary, generateDiffSummary,
generateEditSummary, generateEditSummary,
@@ -399,8 +394,8 @@ export function SinglePassApp({
}); });
const openai = new OpenAI({ const openai = new OpenAI({
apiKey: getApiKey(config.provider ?? "openai"), apiKey: getApiKey(config.provider),
baseURL: getBaseUrl(config.provider ?? "openai"), baseURL: getBaseUrl(config.provider),
timeout: OPENAI_TIMEOUT_MS, timeout: OPENAI_TIMEOUT_MS,
}); });
const chatResp = await openai.beta.chat.completions.parse({ const chatResp = await openai.beta.chat.completions.parse({

View File

@@ -1,8 +1,8 @@
import type { AppConfig } from "./config.js";
import type { ResponseItem } from "openai/resources/responses/responses.mjs"; import type { ResponseItem } from "openai/resources/responses/responses.mjs";
import { OPENAI_BASE_URL } from "./config.js"; import { getBaseUrl, getApiKey } from "./config.js";
import OpenAI from "openai"; import OpenAI from "openai";
/** /**
* Generate a condensed summary of the conversation items. * Generate a condensed summary of the conversation items.
* @param items The list of conversation items to summarize * @param items The list of conversation items to summarize
@@ -14,16 +14,18 @@ import OpenAI from "openai";
* @param items The list of conversation items to summarize * @param items The list of conversation items to summarize
* @param model The model to use for generating the summary * @param model The model to use for generating the summary
* @param flexMode Whether to use the flex-mode service tier * @param flexMode Whether to use the flex-mode service tier
* @param config The configuration object
* @returns A concise structured summary string * @returns A concise structured summary string
*/ */
export async function generateCompactSummary( export async function generateCompactSummary(
items: Array<ResponseItem>, items: Array<ResponseItem>,
model: string, model: string,
flexMode = false, flexMode = false,
config: AppConfig,
): Promise<string> { ): Promise<string> {
const oai = new OpenAI({ const oai = new OpenAI({
apiKey: process.env["OPENAI_API_KEY"], apiKey: getApiKey(config.provider),
baseURL: OPENAI_BASE_URL, baseURL: getBaseUrl(config.provider),
}); });
const conversationText = items const conversationText = items

View File

@@ -41,15 +41,26 @@ export function setApiKey(apiKey: string): void {
OPENAI_API_KEY = apiKey; OPENAI_API_KEY = apiKey;
} }
export function getBaseUrl(provider: string): string | undefined { export function getBaseUrl(provider: string = "openai"): string | undefined {
// If the provider is `openai` and `OPENAI_BASE_URL` is set, use it
if (provider === "openai" && OPENAI_BASE_URL !== "") {
return OPENAI_BASE_URL;
}
const providerInfo = providers[provider.toLowerCase()]; const providerInfo = providers[provider.toLowerCase()];
if (providerInfo) { if (providerInfo) {
return providerInfo.baseURL; return providerInfo.baseURL;
} }
// If the provider is not found in the providers list and `OPENAI_BASE_URL` is set, use it
if (OPENAI_BASE_URL !== "") {
return OPENAI_BASE_URL;
}
return undefined; return undefined;
} }
export function getApiKey(provider: string): string | undefined { export function getApiKey(provider: string = "openai"): string | undefined {
const providerInfo = providers[provider.toLowerCase()]; const providerInfo = providers[provider.toLowerCase()];
if (providerInfo) { if (providerInfo) {
if (providerInfo.name === "Ollama") { if (providerInfo.name === "Ollama") {
@@ -58,6 +69,11 @@ export function getApiKey(provider: string): string | undefined {
return process.env[providerInfo.envKey]; return process.env[providerInfo.envKey];
} }
// If the provider is not found in the providers list and `OPENAI_API_KEY` is set, use it
if (OPENAI_API_KEY !== "") {
return OPENAI_API_KEY;
}
return undefined; return undefined;
} }

View File

@@ -20,9 +20,11 @@ async function fetchModels(provider: string): Promise<Array<string>> {
throw new Error("No API key configured for provider: " + provider); throw new Error("No API key configured for provider: " + provider);
} }
const baseURL = getBaseUrl(provider);
try { try {
const openai = new OpenAI({ apiKey: getApiKey(provider), baseURL }); const openai = new OpenAI({
apiKey: getApiKey(provider),
baseURL: getBaseUrl(provider),
});
const list = await openai.models.list(); const list = await openai.models.list();
const models: Array<string> = []; const models: Array<string> = [];
for await (const model of list as AsyncIterable<{ id?: string }>) { for await (const model of list as AsyncIterable<{ id?: string }>) {