Unverified 提交 d535db62 authored 作者: Will Chen's avatar Will Chen 提交者: GitHub

Upgrade to AI sdk with codemod (#1000)

上级 573642ae
- paragraph: hi - paragraph: hi
- paragraph: ollamachunkollamachunk - img
- text: file1.txt
- img
- text: file1.txt
- paragraph: More EOM
- button "Retry": - button "Retry":
- img - img
\ No newline at end of file
差异被折叠。
...@@ -84,17 +84,18 @@ ...@@ -84,17 +84,18 @@
"vitest": "^3.1.1" "vitest": "^3.1.1"
}, },
"dependencies": { "dependencies": {
"@ai-sdk/anthropic": "^1.2.8", "@ai-sdk/anthropic": "^2.0.4",
"@ai-sdk/google": "^1.2.19", "@ai-sdk/google": "^2.0.6",
"@ai-sdk/openai": "^1.3.24", "@ai-sdk/openai": "^2.0.15",
"@ai-sdk/openai-compatible": "^0.2.13", "@ai-sdk/openai-compatible": "^1.0.8",
"@ai-sdk/provider-utils": "^3.0.3",
"@biomejs/biome": "^1.9.4", "@biomejs/biome": "^1.9.4",
"@dyad-sh/supabase-management-js": "v1.0.0", "@dyad-sh/supabase-management-js": "v1.0.0",
"@lexical/react": "^0.33.1", "@lexical/react": "^0.33.1",
"@monaco-editor/react": "^4.7.0-rc.0", "@monaco-editor/react": "^4.7.0-rc.0",
"@neondatabase/api-client": "^2.1.0", "@neondatabase/api-client": "^2.1.0",
"@neondatabase/serverless": "^1.0.1", "@neondatabase/serverless": "^1.0.1",
"@openrouter/ai-sdk-provider": "^0.4.5", "@openrouter/ai-sdk-provider": "^1.1.2",
"@radix-ui/react-accordion": "^1.2.4", "@radix-ui/react-accordion": "^1.2.4",
"@radix-ui/react-alert-dialog": "^1.1.13", "@radix-ui/react-alert-dialog": "^1.1.13",
"@radix-ui/react-checkbox": "^1.3.2", "@radix-ui/react-checkbox": "^1.3.2",
...@@ -118,7 +119,7 @@ ...@@ -118,7 +119,7 @@
"@types/uuid": "^10.0.0", "@types/uuid": "^10.0.0",
"@vercel/sdk": "^1.10.0", "@vercel/sdk": "^1.10.0",
"@vitejs/plugin-react": "^4.3.4", "@vitejs/plugin-react": "^4.3.4",
"ai": "^4.3.4", "ai": "^5.0.15",
"better-sqlite3": "^11.9.1", "better-sqlite3": "^11.9.1",
"class-variance-authority": "^0.7.1", "class-variance-authority": "^0.7.1",
"clsx": "^2.1.1", "clsx": "^2.1.1",
...@@ -140,7 +141,6 @@ ...@@ -140,7 +141,6 @@
"lexical-beautiful-mentions": "^0.1.47", "lexical-beautiful-mentions": "^0.1.47",
"lucide-react": "^0.487.0", "lucide-react": "^0.487.0",
"monaco-editor": "^0.52.2", "monaco-editor": "^0.52.2",
"ollama-ai-provider": "^1.2.0",
"openai": "^4.91.1", "openai": "^4.91.1",
"posthog-js": "^1.236.3", "posthog-js": "^1.236.3",
"react": "^19.0.0", "react": "^19.0.0",
...@@ -158,7 +158,8 @@ ...@@ -158,7 +158,8 @@
"tree-kill": "^1.2.2", "tree-kill": "^1.2.2",
"tw-animate-css": "^1.2.5", "tw-animate-css": "^1.2.5",
"update-electron-app": "^3.1.1", "update-electron-app": "^3.1.1",
"uuid": "^11.1.0" "uuid": "^11.1.0",
"zod": "^3.25.76"
}, },
"lint-staged": { "lint-staged": {
"**/*.{js,mjs,cjs,jsx,ts,mts,cts,tsx,vue,astro,svelte}": "oxlint", "**/*.{js,mjs,cjs,jsx,ts,mts,cts,tsx,vue,astro,svelte}": "oxlint",
......
import { v4 as uuidv4 } from "uuid"; import { v4 as uuidv4 } from "uuid";
import { ipcMain } from "electron"; import { ipcMain } from "electron";
import { import {
CoreMessage, ModelMessage,
TextPart, TextPart,
ImagePart, ImagePart,
streamText, streamText,
...@@ -134,14 +134,14 @@ async function processStreamChunks({ ...@@ -134,14 +134,14 @@ async function processStreamChunks({
chunk = "</think>"; chunk = "</think>";
inThinkingBlock = false; inThinkingBlock = false;
} }
chunk += part.textDelta; chunk += part.text;
} else if (part.type === "reasoning") { } else if (part.type === "reasoning-delta") {
if (!inThinkingBlock) { if (!inThinkingBlock) {
chunk = "<think>"; chunk = "<think>";
inThinkingBlock = true; inThinkingBlock = true;
} }
chunk += escapeDyadTags(part.textDelta); chunk += escapeDyadTags(part.text);
} }
if (!chunk) { if (!chunk) {
...@@ -603,7 +603,7 @@ This conversation includes one or more image attachments. When the user uploads ...@@ -603,7 +603,7 @@ This conversation includes one or more image attachments. When the user uploads
] as const) ] as const)
: []; : [];
let chatMessages: CoreMessage[] = [ let chatMessages: ModelMessage[] = [
...codebasePrefix, ...codebasePrefix,
...otherCodebasePrefix, ...otherCodebasePrefix,
...limitedMessageHistory.map((msg) => ({ ...limitedMessageHistory.map((msg) => ({
...@@ -647,7 +647,7 @@ This conversation includes one or more image attachments. When the user uploads ...@@ -647,7 +647,7 @@ This conversation includes one or more image attachments. When the user uploads
content: content:
"Summarize the following chat: " + "Summarize the following chat: " +
formatMessagesForSummary(previousChat?.messages ?? []), formatMessagesForSummary(previousChat?.messages ?? []),
} satisfies CoreMessage, } satisfies ModelMessage,
]; ];
} }
...@@ -655,7 +655,7 @@ This conversation includes one or more image attachments. When the user uploads ...@@ -655,7 +655,7 @@ This conversation includes one or more image attachments. When the user uploads
chatMessages, chatMessages,
modelClient, modelClient,
}: { }: {
chatMessages: CoreMessage[]; chatMessages: ModelMessage[];
modelClient: ModelClient; modelClient: ModelClient;
}) => { }) => {
const dyadRequestId = uuidv4(); const dyadRequestId = uuidv4();
...@@ -668,7 +668,7 @@ This conversation includes one or more image attachments. When the user uploads ...@@ -668,7 +668,7 @@ This conversation includes one or more image attachments. When the user uploads
logger.log("sending AI request"); logger.log("sending AI request");
} }
return streamText({ return streamText({
maxTokens: await getMaxTokens(settings.selectedModel), maxOutputTokens: await getMaxTokens(settings.selectedModel),
temperature: await getTemperature(settings.selectedModel), temperature: await getTemperature(settings.selectedModel),
maxRetries: 2, maxRetries: 2,
model: modelClient.model, model: modelClient.model,
...@@ -798,7 +798,7 @@ This conversation includes one or more image attachments. When the user uploads ...@@ -798,7 +798,7 @@ This conversation includes one or more image attachments. When the user uploads
break; break;
} }
if (part.type !== "text-delta") continue; // ignore reasoning for continuation if (part.type !== "text-delta") continue; // ignore reasoning for continuation
fullResponse += part.textDelta; fullResponse += part.text;
fullResponse = cleanFullResponse(fullResponse); fullResponse = cleanFullResponse(fullResponse);
fullResponse = await processResponseChunkUpdate({ fullResponse = await processResponseChunkUpdate({
fullResponse, fullResponse,
...@@ -825,7 +825,7 @@ This conversation includes one or more image attachments. When the user uploads ...@@ -825,7 +825,7 @@ This conversation includes one or more image attachments. When the user uploads
let autoFixAttempts = 0; let autoFixAttempts = 0;
const originalFullResponse = fullResponse; const originalFullResponse = fullResponse;
const previousAttempts: CoreMessage[] = []; const previousAttempts: ModelMessage[] = [];
while ( while (
problemReport.problems.length > 0 && problemReport.problems.length > 0 &&
autoFixAttempts < 2 && autoFixAttempts < 2 &&
...@@ -1161,9 +1161,9 @@ async function replaceTextAttachmentWithContent( ...@@ -1161,9 +1161,9 @@ async function replaceTextAttachmentWithContent(
// Helper function to convert traditional message to one with proper image attachments // Helper function to convert traditional message to one with proper image attachments
async function prepareMessageWithAttachments( async function prepareMessageWithAttachments(
message: CoreMessage, message: ModelMessage,
attachmentPaths: string[], attachmentPaths: string[],
): Promise<CoreMessage> { ): Promise<ModelMessage> {
let textContent = message.content; let textContent = message.content;
// Get the original text content // Get the original text content
if (typeof textContent !== "string") { if (typeof textContent !== "string") {
......
...@@ -37,7 +37,9 @@ export function parseOllamaHost(host?: string): string { ...@@ -37,7 +37,9 @@ export function parseOllamaHost(host?: string): string {
return `http://${host}:11434`; return `http://${host}:11434`;
} }
const OLLAMA_API_URL = parseOllamaHost(process.env.OLLAMA_HOST); export function getOllamaApiUrl(): string {
return parseOllamaHost(process.env.OLLAMA_HOST);
}
interface OllamaModel { interface OllamaModel {
name: string; name: string;
...@@ -55,7 +57,7 @@ interface OllamaModel { ...@@ -55,7 +57,7 @@ interface OllamaModel {
export async function fetchOllamaModels(): Promise<LocalModelListResponse> { export async function fetchOllamaModels(): Promise<LocalModelListResponse> {
try { try {
const response = await fetch(`${OLLAMA_API_URL}/api/tags`); const response = await fetch(`${getOllamaApiUrl()}/api/tags`);
if (!response.ok) { if (!response.ok) {
throw new Error(`Failed to fetch model: ${response.statusText}`); throw new Error(`Failed to fetch model: ${response.statusText}`);
} }
......
import { LanguageModelV1 } from "ai";
import { createOpenAI } from "@ai-sdk/openai"; import { createOpenAI } from "@ai-sdk/openai";
import { createGoogleGenerativeAI as createGoogle } from "@ai-sdk/google"; import { createGoogleGenerativeAI as createGoogle } from "@ai-sdk/google";
import { createAnthropic } from "@ai-sdk/anthropic"; import { createAnthropic } from "@ai-sdk/anthropic";
import { createOpenRouter } from "@openrouter/ai-sdk-provider"; import { createOpenRouter } from "@openrouter/ai-sdk-provider";
import { createOllama } from "ollama-ai-provider";
import { createOpenAICompatible } from "@ai-sdk/openai-compatible"; import { createOpenAICompatible } from "@ai-sdk/openai-compatible";
import type { LargeLanguageModel, UserSettings } from "../../lib/schemas"; import type { LargeLanguageModel, UserSettings } from "../../lib/schemas";
import { getEnvVar } from "./read_env"; import { getEnvVar } from "./read_env";
...@@ -13,6 +11,9 @@ import { LanguageModelProvider } from "../ipc_types"; ...@@ -13,6 +11,9 @@ import { LanguageModelProvider } from "../ipc_types";
import { createDyadEngine } from "./llm_engine_provider"; import { createDyadEngine } from "./llm_engine_provider";
import { LM_STUDIO_BASE_URL } from "./lm_studio_utils"; import { LM_STUDIO_BASE_URL } from "./lm_studio_utils";
import { LanguageModel } from "ai";
import { createOllamaProvider } from "./ollama_provider";
import { getOllamaApiUrl } from "../handlers/local_model_ollama_handler";
const dyadEngineUrl = process.env.DYAD_ENGINE_URL; const dyadEngineUrl = process.env.DYAD_ENGINE_URL;
const dyadGatewayUrl = process.env.DYAD_GATEWAY_URL; const dyadGatewayUrl = process.env.DYAD_GATEWAY_URL;
...@@ -33,7 +34,7 @@ const AUTO_MODELS = [ ...@@ -33,7 +34,7 @@ const AUTO_MODELS = [
]; ];
export interface ModelClient { export interface ModelClient {
model: LanguageModelV1; model: LanguageModel;
builtinProviderId?: string; builtinProviderId?: string;
} }
...@@ -168,7 +169,10 @@ function getRegularModelClient( ...@@ -168,7 +169,10 @@ function getRegularModelClient(
model: LargeLanguageModel, model: LargeLanguageModel,
settings: UserSettings, settings: UserSettings,
providerConfig: LanguageModelProvider, providerConfig: LanguageModelProvider,
) { ): {
modelClient: ModelClient;
backupModelClients: ModelClient[];
} {
// Get API key for the specific provider // Get API key for the specific provider
const apiKey = const apiKey =
settings.providerSettings?.[model.provider]?.apiKey?.value || settings.providerSettings?.[model.provider]?.apiKey?.value ||
...@@ -220,13 +224,11 @@ function getRegularModelClient( ...@@ -220,13 +224,11 @@ function getRegularModelClient(
}; };
} }
case "ollama": { case "ollama": {
// Ollama typically runs locally and doesn't require an API key in the same way const provider = createOllamaProvider({ baseURL: getOllamaApiUrl() });
const provider = createOllama({
baseURL: process.env.OLLAMA_HOST,
});
return { return {
modelClient: { modelClient: {
model: provider(model.name), model: provider(model.name),
builtinProviderId: providerId,
}, },
backupModelClients: [], backupModelClients: [],
}; };
......
import { import { LanguageModel } from "ai";
LanguageModelV1,
LanguageModelV1ObjectGenerationMode,
} from "@ai-sdk/provider";
import { OpenAICompatibleChatLanguageModel } from "@ai-sdk/openai-compatible"; import { OpenAICompatibleChatLanguageModel } from "@ai-sdk/openai-compatible";
import { import {
FetchFunction, FetchFunction,
...@@ -9,7 +6,6 @@ import { ...@@ -9,7 +6,6 @@ import {
withoutTrailingSlash, withoutTrailingSlash,
} from "@ai-sdk/provider-utils"; } from "@ai-sdk/provider-utils";
import { OpenAICompatibleChatSettings } from "@ai-sdk/openai-compatible";
import log from "electron-log"; import log from "electron-log";
import { getExtraProviderOptions } from "./thinking_utils"; import { getExtraProviderOptions } from "./thinking_utils";
import type { UserSettings } from "../../lib/schemas"; import type { UserSettings } from "../../lib/schemas";
...@@ -18,7 +14,7 @@ const logger = log.scope("llm_engine_provider"); ...@@ -18,7 +14,7 @@ const logger = log.scope("llm_engine_provider");
export type ExampleChatModelId = string & {}; export type ExampleChatModelId = string & {};
export interface ExampleChatSettings extends OpenAICompatibleChatSettings { export interface ExampleChatSettings {
files?: { path: string; content: string }[]; files?: { path: string; content: string }[];
} }
export interface ExampleProviderSettings { export interface ExampleProviderSettings {
...@@ -56,10 +52,7 @@ export interface DyadEngineProvider { ...@@ -56,10 +52,7 @@ export interface DyadEngineProvider {
/** /**
Creates a model for text generation. Creates a model for text generation.
*/ */
( (modelId: ExampleChatModelId, settings?: ExampleChatSettings): LanguageModel;
modelId: ExampleChatModelId,
settings?: ExampleChatSettings,
): LanguageModelV1;
/** /**
Creates a chat model for text generation. Creates a chat model for text generation.
...@@ -67,7 +60,7 @@ Creates a chat model for text generation. ...@@ -67,7 +60,7 @@ Creates a chat model for text generation.
chatModel( chatModel(
modelId: ExampleChatModelId, modelId: ExampleChatModelId,
settings?: ExampleChatSettings, settings?: ExampleChatSettings,
): LanguageModelV1; ): LanguageModel;
} }
export function createDyadEngine( export function createDyadEngine(
...@@ -113,13 +106,13 @@ export function createDyadEngine( ...@@ -113,13 +106,13 @@ export function createDyadEngine(
settings: ExampleChatSettings = {}, settings: ExampleChatSettings = {},
) => { ) => {
// Extract files from settings to process them appropriately // Extract files from settings to process them appropriately
const { files, ...restSettings } = settings; const { files } = settings;
// Create configuration with file handling // Create configuration with file handling
const config = { const config = {
...getCommonModelConfig(), ...getCommonModelConfig(),
defaultObjectGenerationMode: // defaultObjectGenerationMode:
"tool" as LanguageModelV1ObjectGenerationMode, // "tool" as LanguageModelV1ObjectGenerationMode,
// Custom fetch implementation that adds files to the request // Custom fetch implementation that adds files to the request
fetch: (input: RequestInfo | URL, init?: RequestInit) => { fetch: (input: RequestInfo | URL, init?: RequestInit) => {
// Use default fetch if no init or body // Use default fetch if no init or body
...@@ -181,7 +174,7 @@ export function createDyadEngine( ...@@ -181,7 +174,7 @@ export function createDyadEngine(
}, },
}; };
return new OpenAICompatibleChatLanguageModel(modelId, restSettings, config); return new OpenAICompatibleChatLanguageModel(modelId, config);
}; };
const provider = ( const provider = (
......
import { LanguageModel } from "ai";
import { createOpenAICompatible } from "@ai-sdk/openai-compatible";
import type { FetchFunction } from "@ai-sdk/provider-utils";
import { withoutTrailingSlash } from "@ai-sdk/provider-utils";
import type {} from "@ai-sdk/provider";
type OllamaChatModelId = string;
export interface OllamaProviderOptions {
/**
* Base URL for the Ollama API. For real Ollama, use e.g. http://localhost:11434/api
* The provider will POST to `${baseURL}/chat`.
* If undefined, defaults to http://localhost:11434/api
*/
baseURL?: string;
headers?: Record<string, string>;
fetch?: FetchFunction;
}
export interface OllamaChatSettings {}
export interface OllamaProvider {
(modelId: OllamaChatModelId, settings?: OllamaChatSettings): LanguageModel;
}
export function createOllamaProvider(
options?: OllamaProviderOptions,
): OllamaProvider {
const base = withoutTrailingSlash(
options?.baseURL ?? "http://localhost:11434",
)!;
const v1Base = (base.endsWith("/v1") ? base : `${base}/v1`) as string;
const provider = createOpenAICompatible({
name: "ollama",
baseURL: v1Base,
headers: options?.headers,
});
return (modelId: OllamaChatModelId) => provider(modelId);
}
...@@ -94,51 +94,6 @@ app.get("/ollama/api/tags", (req, res) => { ...@@ -94,51 +94,6 @@ app.get("/ollama/api/tags", (req, res) => {
res.json(ollamaModels); res.json(ollamaModels);
}); });
app.post("/ollama/chat", (req, res) => {
// Tell the client we're going to stream NDJSON
res.setHeader("Content-Type", "application/x-ndjson");
res.setHeader("Cache-Control", "no-cache");
// Chunk #1 – partial answer
const firstChunk = {
model: "llama3.2",
created_at: "2023-08-04T08:52:19.385406455-07:00",
message: {
role: "assistant",
content: "ollamachunk",
images: null,
},
done: false,
};
// Chunk #2 – final answer + metrics
const secondChunk = {
model: "llama3.2",
created_at: "2023-08-04T19:22:45.499127Z",
message: {
role: "assistant",
content: "",
},
done: true,
total_duration: 4883583458,
load_duration: 1334875,
prompt_eval_count: 26,
prompt_eval_duration: 342546000,
eval_count: 282,
eval_duration: 4535599000,
};
// Send the first object right away
res.write(JSON.stringify(firstChunk) + "\n");
res.write(JSON.stringify(firstChunk) + "\n");
// …and the second one a moment later to mimic streaming
setTimeout(() => {
res.write(JSON.stringify(secondChunk) + "\n");
res.end(); // Close the HTTP stream
}, 300); // 300 ms delay – tweak as you like
});
// LM Studio specific endpoints // LM Studio specific endpoints
app.get("/lmstudio/api/v0/models", (req, res) => { app.get("/lmstudio/api/v0/models", (req, res) => {
const lmStudioModels = { const lmStudioModels = {
...@@ -182,7 +137,7 @@ app.get("/lmstudio/api/v0/models", (req, res) => { ...@@ -182,7 +137,7 @@ app.get("/lmstudio/api/v0/models", (req, res) => {
res.json(lmStudioModels); res.json(lmStudioModels);
}); });
["lmstudio", "gateway", "engine"].forEach((provider) => { ["lmstudio", "gateway", "engine", "ollama"].forEach((provider) => {
app.post( app.post(
`/${provider}/v1/chat/completions`, `/${provider}/v1/chat/completions`,
createChatCompletionHandler(provider), createChatCompletionHandler(provider),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论