Unverified 提交 759c60f2 authored 作者: Will Chen's avatar Will Chen 提交者: GitHub

Fix auto model for engine (#2170)

<!-- CURSOR_SUMMARY --> > [!NOTE] > Fixes auto/local-agent model routing and provider-specific behavior. > > - Refactors `llm_engine_provider` to require `chatParams { providerId }`; wires it into custom fetch and model constructors (`provider`, `chatModel`, `responses`) > - Updates `get_model_client` to pass `providerId`, tag fallback models (`openai`, `anthropic`, `google`), and use `responses()` for OpenAI in `local-agent` and auto fallback > - Adds Responses API support to fake LLM server (`/v1/responses`) with streaming SSE handler and dump generation > - Extends test helper to parse Responses API dumps (`body.input`) and adds `localAgentUseAutoModel` setup flag > - New e2e test `local_agent_auto.spec.ts` with snapshot verifying request payload via Responses API > > <sup>Written by [Cursor Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit e5b02d253c9842f14da0099bb11e1b05548e9245. This will update automatically on new commits. Configure [here](https://cursor.com/dashboard?tab=bugbot).</sup> <!-- /CURSOR_SUMMARY --> <!-- This is an auto-generated description by cubic. --> --- ## Summary by cubic Fixes auto model selection in the Dyad engine by passing the correct providerId with each request. Ensures provider-specific options are applied and avoids incorrect defaults. - **Bug Fixes** - Pass providerId to chat/responses models and into the fetch layer for getExtraProviderOptions. - Remove originalProviderId from createDyadEngine; add ChatParams and update provider API signatures. - Tag fallback models with providerId (openai, anthropic, google) and forward model.provider in getModelClient. <sup>Written for commit e5b02d253c9842f14da0099bb11e1b05548e9245. Summary will update on new commits.</sup> <!-- End of auto-generated description by cubic. -->
上级 be7b6977
...@@ -315,7 +315,12 @@ export class PageObject { ...@@ -315,7 +315,12 @@ export class PageObject {
async setUpDyadPro({ async setUpDyadPro({
autoApprove = false, autoApprove = false,
localAgent = false, localAgent = false,
}: { autoApprove?: boolean; localAgent?: boolean } = {}) { localAgentUseAutoModel = false,
}: {
autoApprove?: boolean;
localAgent?: boolean;
localAgentUseAutoModel?: boolean;
} = {}) {
await this.baseSetup(); await this.baseSetup();
await this.goToSettingsTab(); await this.goToSettingsTab();
if (autoApprove) { if (autoApprove) {
...@@ -328,7 +333,7 @@ export class PageObject { ...@@ -328,7 +333,7 @@ export class PageObject {
await this.goToAppsTab(); await this.goToAppsTab();
// Select a non-openAI model for local agent mode, // Select a non-openAI model for local agent mode,
// since openAI models go to the responses API. // since openAI models go to the responses API.
if (localAgent) { if (localAgent && !localAgentUseAutoModel) {
await this.selectModel({ await this.selectModel({
provider: "Anthropic", provider: "Anthropic",
model: "Claude Opus 4.5", model: "Claude Opus 4.5",
...@@ -800,14 +805,26 @@ export class PageObject { ...@@ -800,14 +805,26 @@ export class PageObject {
// Perform snapshot comparison // Perform snapshot comparison
const parsedDump = JSON.parse(dumpContent); const parsedDump = JSON.parse(dumpContent);
if (type === "request") { if (type === "request") {
parsedDump["body"]["messages"] = parsedDump["body"]["messages"].map( if (parsedDump["body"]["input"]) {
(message: any) => { parsedDump["body"]["input"] = parsedDump["body"]["input"].map(
if (message.role === "system") { (input: any) => {
message.content = "[[SYSTEM_MESSAGE]]"; if (input.role === "system") {
} input.content = "[[SYSTEM_MESSAGE]]";
return message; }
}, return input;
); },
);
}
if (parsedDump["body"]["messages"]) {
parsedDump["body"]["messages"] = parsedDump["body"]["messages"].map(
(message: any) => {
if (message.role === "system") {
message.content = "[[SYSTEM_MESSAGE]]";
}
return message;
},
);
}
// Normalize fileIds to be deterministic based on content // Normalize fileIds to be deterministic based on content
normalizeVersionedFiles(parsedDump); normalizeVersionedFiles(parsedDump);
expect( expect(
...@@ -816,9 +833,15 @@ export class PageObject { ...@@ -816,9 +833,15 @@ export class PageObject {
return; return;
} }
expect( expect(
prettifyDump(parsedDump["body"]["messages"], { prettifyDump(
onlyLastMessage: type === "last-message", // responses API
}), parsedDump["body"]["input"] ??
// chat completion API
parsedDump["body"]["messages"],
{
onlyLastMessage: type === "last-message",
},
),
).toMatchSnapshot(name); ).toMatchSnapshot(name);
} }
......
import { testSkipIfWindows } from "./helpers/test_helper";
testSkipIfWindows("local-agent - auto model", async ({ po }) => {
await po.setUpDyadPro({ localAgent: true, localAgentUseAutoModel: true });
await po.importApp("minimal");
await po.selectLocalAgentMode();
await po.sendPrompt("[dump]");
await po.snapshotServerDump("request");
});
...@@ -91,7 +91,6 @@ export async function getModelClient( ...@@ -91,7 +91,6 @@ export async function getModelClient(
const provider = createDyadEngine({ const provider = createDyadEngine({
apiKey: dyadApiKey, apiKey: dyadApiKey,
baseURL: dyadEngineUrl ?? "https://engine.dyad.sh/v1", baseURL: dyadEngineUrl ?? "https://engine.dyad.sh/v1",
originalProviderId: model.provider,
dyadOptions: { dyadOptions: {
enableLazyEdits: enableLazyEdits:
settings.selectedChatMode === "ask" settings.selectedChatMode === "ask"
...@@ -214,12 +213,13 @@ function getProModelClient({ ...@@ -214,12 +213,13 @@ function getProModelClient({
model: createFallback({ model: createFallback({
models: [ models: [
// openai requires no prefix. // openai requires no prefix.
provider.responses(`${GPT_5_2_MODEL_NAME}`), provider.responses(`${GPT_5_2_MODEL_NAME}`, { providerId: "openai" }),
provider(`anthropic/${SONNET_4_5}`), provider(`anthropic/${SONNET_4_5}`, { providerId: "anthropic" }),
provider(`gemini/${GEMINI_3_FLASH}`), provider(`gemini/${GEMINI_3_FLASH}`, { providerId: "google" }),
], ],
}), }),
// Using openAI as the default provider. // Using openAI as the default provider.
// TODO: we should remove this and rely on the provider id passed into the provider().
builtinProviderId: "openai", builtinProviderId: "openai",
}; };
} }
...@@ -228,12 +228,12 @@ function getProModelClient({ ...@@ -228,12 +228,12 @@ function getProModelClient({
model.provider === "openai" model.provider === "openai"
) { ) {
return { return {
model: provider.responses(modelId), model: provider.responses(modelId, { providerId: model.provider }),
builtinProviderId: model.provider, builtinProviderId: model.provider,
}; };
} }
return { return {
model: provider(modelId), model: provider(modelId, { providerId: model.provider }),
builtinProviderId: model.provider, builtinProviderId: model.provider,
}; };
} }
......
...@@ -14,7 +14,9 @@ import type { LanguageModel } from "ai"; ...@@ -14,7 +14,9 @@ import type { LanguageModel } from "ai";
const logger = log.scope("llm_engine_provider"); const logger = log.scope("llm_engine_provider");
export type ExampleChatModelId = string & {}; export type ExampleChatModelId = string & {};
export interface ExampleChatSettings {} export interface ChatParams {
providerId: string;
}
export interface ExampleProviderSettings { export interface ExampleProviderSettings {
/** /**
Example API key. Example API key.
...@@ -38,7 +40,6 @@ or to provide a custom fetch implementation for e.g. testing. ...@@ -38,7 +40,6 @@ or to provide a custom fetch implementation for e.g. testing.
*/ */
fetch?: FetchFunction; fetch?: FetchFunction;
originalProviderId: string;
dyadOptions: { dyadOptions: {
enableLazyEdits?: boolean; enableLazyEdits?: boolean;
enableSmartFilesContext?: boolean; enableSmartFilesContext?: boolean;
...@@ -51,17 +52,14 @@ export interface DyadEngineProvider { ...@@ -51,17 +52,14 @@ export interface DyadEngineProvider {
/** /**
Creates a model for text generation. Creates a model for text generation.
*/ */
(modelId: ExampleChatModelId, settings?: ExampleChatSettings): LanguageModel; (modelId: ExampleChatModelId, chatParams: ChatParams): LanguageModel;
/** /**
Creates a chat model for text generation. Creates a chat model for text generation.
*/ */
chatModel( chatModel(modelId: ExampleChatModelId, chatParams: ChatParams): LanguageModel;
modelId: ExampleChatModelId,
settings?: ExampleChatSettings,
): LanguageModel;
responses(modelId: ExampleChatModelId): LanguageModel; responses(modelId: ExampleChatModelId, chatParams: ChatParams): LanguageModel;
} }
export function createDyadEngine( export function createDyadEngine(
...@@ -103,7 +101,11 @@ export function createDyadEngine( ...@@ -103,7 +101,11 @@ export function createDyadEngine(
}); });
// Custom fetch implementation that adds dyad-specific options to the request // Custom fetch implementation that adds dyad-specific options to the request
const createDyadFetch = (): FetchFunction => { const createDyadFetch = ({
providerId,
}: {
providerId: string;
}): FetchFunction => {
return (input: RequestInfo | URL, init?: RequestInit) => { return (input: RequestInfo | URL, init?: RequestInit) => {
// Use default fetch if no init or body // Use default fetch if no init or body
if (!init || !init.body || typeof init.body !== "string") { if (!init || !init.body || typeof init.body !== "string") {
...@@ -114,10 +116,7 @@ export function createDyadEngine( ...@@ -114,10 +116,7 @@ export function createDyadEngine(
// Parse the request body to manipulate it // Parse the request body to manipulate it
const parsedBody = { const parsedBody = {
...JSON.parse(init.body), ...JSON.parse(init.body),
...getExtraProviderOptions( ...getExtraProviderOptions(providerId, options.settings),
options.originalProviderId,
options.settings,
),
}; };
const dyadVersionedFiles = parsedBody.dyadVersionedFiles; const dyadVersionedFiles = parsedBody.dyadVersionedFiles;
if ("dyadVersionedFiles" in parsedBody) { if ("dyadVersionedFiles" in parsedBody) {
...@@ -195,25 +194,32 @@ export function createDyadEngine( ...@@ -195,25 +194,32 @@ export function createDyadEngine(
}; };
}; };
const createChatModel = (modelId: ExampleChatModelId) => { const createChatModel = (
modelId: ExampleChatModelId,
chatParams: ChatParams,
) => {
const config = { const config = {
...getCommonModelConfig(), ...getCommonModelConfig(),
fetch: createDyadFetch(), fetch: createDyadFetch({ providerId: chatParams.providerId }),
}; };
return new OpenAICompatibleChatLanguageModel(modelId, config); return new OpenAICompatibleChatLanguageModel(modelId, config);
}; };
const createResponsesModel = (modelId: ExampleChatModelId) => { const createResponsesModel = (
modelId: ExampleChatModelId,
chatParams: ChatParams,
) => {
const config = { const config = {
...getCommonModelConfig(), ...getCommonModelConfig(),
fetch: createDyadFetch(), fetch: createDyadFetch({ providerId: chatParams.providerId }),
}; };
return new OpenAIResponsesLanguageModel(modelId, config); return new OpenAIResponsesLanguageModel(modelId, config);
}; };
const provider = (modelId: ExampleChatModelId) => createChatModel(modelId); const provider = (modelId: ExampleChatModelId, chatParams: ChatParams) =>
createChatModel(modelId, chatParams);
provider.chatModel = createChatModel; provider.chatModel = createChatModel;
provider.responses = createResponsesModel; provider.responses = createResponsesModel;
......
...@@ -2,6 +2,7 @@ import express from "express"; ...@@ -2,6 +2,7 @@ import express from "express";
import { createServer } from "http"; import { createServer } from "http";
import cors from "cors"; import cors from "cors";
import { createChatCompletionHandler } from "./chatCompletionHandler"; import { createChatCompletionHandler } from "./chatCompletionHandler";
import { createResponsesHandler } from "./responsesHandler";
import { import {
handleDeviceCode, handleDeviceCode,
handleAccessToken, handleAccessToken,
...@@ -152,6 +153,8 @@ app.get("/lmstudio/api/v0/models", (req, res) => { ...@@ -152,6 +153,8 @@ app.get("/lmstudio/api/v0/models", (req, res) => {
`/${provider}/v1/chat/completions`, `/${provider}/v1/chat/completions`,
createChatCompletionHandler(provider), createChatCompletionHandler(provider),
); );
// Also add responses API endpoints for each provider
app.post(`/${provider}/v1/responses`, createResponsesHandler(provider));
}); });
// Azure-specific endpoints (Azure client uses different URL patterns) // Azure-specific endpoints (Azure client uses different URL patterns)
...@@ -163,6 +166,7 @@ app.post( ...@@ -163,6 +166,7 @@ app.post(
// Default test provider handler: // Default test provider handler:
app.post("/v1/chat/completions", createChatCompletionHandler(".")); app.post("/v1/chat/completions", createChatCompletionHandler("."));
app.post("/v1/responses", createResponsesHandler("."));
// GitHub API Mock Endpoints // GitHub API Mock Endpoints
console.log("Setting up GitHub mock endpoints"); console.log("Setting up GitHub mock endpoints");
......
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论