Unverified 提交 069c2212 authored 作者: Will Chen's avatar Will Chen 提交者: GitHub

Implement saver mode (#154)

上级 3763423d
import { ipcMain } from "electron"; import { ipcMain } from "electron";
import { CoreMessage, TextPart, ImagePart, streamText } from "ai"; import { CoreMessage, TextPart, ImagePart } from "ai";
import { db } from "../../db"; import { db } from "../../db";
import { chats, messages } from "../../db/schema"; import { chats, messages } from "../../db/schema";
import { and, eq, isNull } from "drizzle-orm"; import { and, eq, isNull } from "drizzle-orm";
...@@ -29,6 +29,7 @@ import * as crypto from "crypto"; ...@@ -29,6 +29,7 @@ import * as crypto from "crypto";
import { readFile, writeFile, unlink } from "fs/promises"; import { readFile, writeFile, unlink } from "fs/promises";
import { getMaxTokens } from "../utils/token_utils"; import { getMaxTokens } from "../utils/token_utils";
import { MAX_CHAT_TURNS_IN_CONTEXT } from "@/constants/settings_constants"; import { MAX_CHAT_TURNS_IN_CONTEXT } from "@/constants/settings_constants";
import { streamTextWithBackup } from "../utils/stream_utils";
const logger = log.scope("chat_stream_handlers"); const logger = log.scope("chat_stream_handlers");
...@@ -214,7 +215,7 @@ export function registerChatStreamHandlers() { ...@@ -214,7 +215,7 @@ export function registerChatStreamHandlers() {
} else { } else {
// Normal AI processing for non-test prompts // Normal AI processing for non-test prompts
const settings = readSettings(); const settings = readSettings();
const modelClient = await getModelClient( const { modelClient, backupModelClients } = await getModelClient(
settings.selectedModel, settings.selectedModel,
settings, settings,
); );
...@@ -372,13 +373,14 @@ This conversation includes one or more image attachments. When the user uploads ...@@ -372,13 +373,14 @@ This conversation includes one or more image attachments. When the user uploads
} }
// When calling streamText, the messages need to be properly formatted for mixed content // When calling streamText, the messages need to be properly formatted for mixed content
const { textStream } = streamText({ const { textStream } = streamTextWithBackup({
maxTokens: await getMaxTokens(settings.selectedModel), maxTokens: await getMaxTokens(settings.selectedModel),
temperature: 0, temperature: 0,
model: modelClient, model: modelClient,
backupModelClients: backupModelClients,
system: systemPrompt, system: systemPrompt,
messages: chatMessages.filter((m) => m.content), messages: chatMessages.filter((m) => m.content),
onError: (error) => { onError: (error: any) => {
logger.error("Error streaming text:", error); logger.error("Error streaming text:", error);
const message = const message =
(error as any)?.error?.message || JSON.stringify(error); (error as any)?.error?.message || JSON.stringify(error);
......
...@@ -12,7 +12,9 @@ export function createLoggedHandler(logger: log.LogFunctions) { ...@@ -12,7 +12,9 @@ export function createLoggedHandler(logger: log.LogFunctions) {
logger.log(`IPC: ${channel} called with args: ${JSON.stringify(args)}`); logger.log(`IPC: ${channel} called with args: ${JSON.stringify(args)}`);
try { try {
const result = await fn(event, ...args); const result = await fn(event, ...args);
logger.log(`IPC: ${channel} returned: ${JSON.stringify(result)}`); logger.log(
`IPC: ${channel} returned: ${JSON.stringify(result).slice(0, 100)}...`,
);
return result; return result;
} catch (error) { } catch (error) {
logger.error( logger.error(
......
import { LanguageModelV1 } from "ai";
import { createOpenAI } from "@ai-sdk/openai"; import { createOpenAI } from "@ai-sdk/openai";
import { createGoogleGenerativeAI as createGoogle } from "@ai-sdk/google"; import { createGoogleGenerativeAI as createGoogle } from "@ai-sdk/google";
import { createAnthropic } from "@ai-sdk/anthropic"; import { createAnthropic } from "@ai-sdk/anthropic";
...@@ -8,6 +9,8 @@ import type { LargeLanguageModel, UserSettings } from "../../lib/schemas"; ...@@ -8,6 +9,8 @@ import type { LargeLanguageModel, UserSettings } from "../../lib/schemas";
import { getEnvVar } from "./read_env"; import { getEnvVar } from "./read_env";
import log from "electron-log"; import log from "electron-log";
import { getLanguageModelProviders } from "../shared/language_model_helpers"; import { getLanguageModelProviders } from "../shared/language_model_helpers";
import { LanguageModelProvider } from "../ipc_types";
import { llmErrorStore } from "@/main/llm_error_store";
const AUTO_MODELS = [ const AUTO_MODELS = [
{ {
...@@ -24,11 +27,19 @@ const AUTO_MODELS = [ ...@@ -24,11 +27,19 @@ const AUTO_MODELS = [
}, },
]; ];
export interface ModelClient {
model: LanguageModelV1;
builtinProviderId?: string;
}
const logger = log.scope("getModelClient"); const logger = log.scope("getModelClient");
export async function getModelClient( export async function getModelClient(
model: LargeLanguageModel, model: LargeLanguageModel,
settings: UserSettings, settings: UserSettings,
) { ): Promise<{
modelClient: ModelClient;
backupModelClients: ModelClient[];
}> {
const allProviders = await getLanguageModelProviders(); const allProviders = await getLanguageModelProviders();
const dyadApiKey = settings.providerSettings?.auto?.apiKey?.value; const dyadApiKey = settings.providerSettings?.auto?.apiKey?.value;
...@@ -83,7 +94,44 @@ export async function getModelClient( ...@@ -83,7 +94,44 @@ export async function getModelClient(
logger.info("Using Dyad Pro API key via Gateway"); logger.info("Using Dyad Pro API key via Gateway");
// Do not use free variant (for openrouter). // Do not use free variant (for openrouter).
const modelName = model.name.split(":free")[0]; const modelName = model.name.split(":free")[0];
return provider(`${providerConfig.gatewayPrefix}${modelName}`); const autoModelClient = {
model: provider(`${providerConfig.gatewayPrefix}${modelName}`),
builtinProviderId: "auto",
};
const googleSettings = settings.providerSettings?.google;
// Budget saver mode logic (all must be true):
// 1. Pro Saver Mode is enabled
// 2. Provider is Google
// 3. API Key is set
// 4. Has no recent errors
if (
settings.enableProSaverMode &&
providerConfig.id === "google" &&
googleSettings &&
googleSettings.apiKey?.value &&
llmErrorStore.modelHasNoRecentError({
model: model.name,
provider: providerConfig.id,
})
) {
return {
modelClient: getRegularModelClient(
{
provider: providerConfig.id,
name: model.name,
},
settings,
providerConfig,
).modelClient,
backupModelClients: [autoModelClient],
};
} else {
return {
modelClient: autoModelClient,
backupModelClients: [],
};
}
} else { } else {
logger.warn( logger.warn(
`Dyad Pro enabled, but provider ${model.provider} does not have a gateway prefix defined. Falling back to direct provider connection.`, `Dyad Pro enabled, but provider ${model.provider} does not have a gateway prefix defined. Falling back to direct provider connection.`,
...@@ -91,7 +139,14 @@ export async function getModelClient( ...@@ -91,7 +139,14 @@ export async function getModelClient(
// Fall through to regular provider logic if gateway prefix is missing // Fall through to regular provider logic if gateway prefix is missing
} }
} }
return getRegularModelClient(model, settings, providerConfig);
}
function getRegularModelClient(
model: LargeLanguageModel,
settings: UserSettings,
providerConfig: LanguageModelProvider,
) {
// Get API key for the specific provider // Get API key for the specific provider
const apiKey = const apiKey =
settings.providerSettings?.[model.provider]?.apiKey?.value || settings.providerSettings?.[model.provider]?.apiKey?.value ||
...@@ -99,30 +154,60 @@ export async function getModelClient( ...@@ -99,30 +154,60 @@ export async function getModelClient(
? getEnvVar(providerConfig.envVarName) ? getEnvVar(providerConfig.envVarName)
: undefined); : undefined);
const providerId = providerConfig.id;
// Create client based on provider ID or type // Create client based on provider ID or type
switch (providerConfig.id) { switch (providerId) {
case "openai": { case "openai": {
const provider = createOpenAI({ apiKey }); const provider = createOpenAI({ apiKey });
return provider(model.name); return {
modelClient: {
model: provider(model.name),
builtinProviderId: providerId,
},
backupModelClients: [],
};
} }
case "anthropic": { case "anthropic": {
const provider = createAnthropic({ apiKey }); const provider = createAnthropic({ apiKey });
return provider(model.name); return {
modelClient: {
model: provider(model.name),
builtinProviderId: providerId,
},
backupModelClients: [],
};
} }
case "google": { case "google": {
const provider = createGoogle({ apiKey }); const provider = createGoogle({ apiKey });
return provider(model.name); return {
modelClient: {
model: provider(model.name),
builtinProviderId: providerId,
},
backupModelClients: [],
};
} }
case "openrouter": { case "openrouter": {
const provider = createOpenRouter({ apiKey }); const provider = createOpenRouter({ apiKey });
return provider(model.name); return {
modelClient: {
model: provider(model.name),
builtinProviderId: providerId,
},
backupModelClients: [],
};
} }
case "ollama": { case "ollama": {
// Ollama typically runs locally and doesn't require an API key in the same way // Ollama typically runs locally and doesn't require an API key in the same way
const provider = createOllama({ const provider = createOllama({
baseURL: providerConfig.apiBaseUrl, baseURL: providerConfig.apiBaseUrl,
}); });
return provider(model.name); return {
modelClient: {
model: provider(model.name),
},
backupModelClients: [],
};
} }
case "lmstudio": { case "lmstudio": {
// LM Studio uses OpenAI compatible API // LM Studio uses OpenAI compatible API
...@@ -131,7 +216,12 @@ export async function getModelClient( ...@@ -131,7 +216,12 @@ export async function getModelClient(
name: "lmstudio", name: "lmstudio",
baseURL, baseURL,
}); });
return provider(model.name); return {
modelClient: {
model: provider(model.name),
},
backupModelClients: [],
};
} }
default: { default: {
// Handle custom providers // Handle custom providers
...@@ -147,7 +237,12 @@ export async function getModelClient( ...@@ -147,7 +237,12 @@ export async function getModelClient(
baseURL: providerConfig.apiBaseUrl, baseURL: providerConfig.apiBaseUrl,
apiKey: apiKey, apiKey: apiKey,
}); });
return provider(model.name); return {
modelClient: {
model: provider(model.name),
},
backupModelClients: [],
};
} }
// If it's not a known ID and not type 'custom', it's unsupported // If it's not a known ID and not type 'custom', it's unsupported
throw new Error(`Unsupported model provider: ${model.provider}`); throw new Error(`Unsupported model provider: ${model.provider}`);
......
import { streamText } from "ai";
import log from "electron-log";
import { ModelClient } from "./get_model_client";
import { llmErrorStore } from "@/main/llm_error_store";
const logger = log.scope("stream_utils");
export interface StreamTextWithBackupParams
extends Omit<Parameters<typeof streamText>[0], "model"> {
model: ModelClient; // primary client
backupModelClients?: ModelClient[]; // ordered fall-backs
}
export function streamTextWithBackup(params: StreamTextWithBackupParams): {
textStream: AsyncIterable<string>;
} {
const {
model: primaryModel,
backupModelClients = [],
onError: callerOnError,
abortSignal: callerAbort,
...rest
} = params;
const modelClients: ModelClient[] = [primaryModel, ...backupModelClients];
async function* combinedGenerator(): AsyncIterable<string> {
let lastErr: { error: unknown } | undefined = undefined;
for (let i = 0; i < modelClients.length; i++) {
const currentModelClient = modelClients[i];
/* Local abort controller for this single attempt */
const attemptAbort = new AbortController();
if (callerAbort) {
if (callerAbort.aborted) {
// Already aborted, trigger immediately
attemptAbort.abort();
} else {
callerAbort.addEventListener("abort", () => attemptAbort.abort(), {
once: true,
});
}
}
let errorFromCurrent: { error: unknown } | undefined = undefined; // set when onError fires
const providerId = currentModelClient.builtinProviderId;
if (providerId) {
llmErrorStore.clearModelError({
model: currentModelClient.model.modelId,
provider: providerId,
});
}
logger.info(
"Streaming text with model",
currentModelClient.model.modelId,
"provider",
currentModelClient.model.provider,
"builtinProviderId",
currentModelClient.builtinProviderId,
);
const { textStream } = streamText({
...rest,
maxRetries: 0,
model: currentModelClient.model,
abortSignal: attemptAbort.signal,
onError: (error) => {
const providerId = currentModelClient.builtinProviderId;
if (providerId) {
llmErrorStore.recordModelError({
model: currentModelClient.model.modelId,
provider: providerId,
});
}
logger.error(
`Error streaming text with ${providerId} and model ${currentModelClient.model.modelId}: ${error}`,
error,
);
errorFromCurrent = error;
attemptAbort.abort(); // kill fetch / SSE
},
});
try {
for await (const chunk of textStream) {
/* If onError fired during streaming, bail out immediately. */
if (errorFromCurrent) throw errorFromCurrent;
yield chunk;
}
/* Stream ended – check if it actually failed */
if (errorFromCurrent) throw errorFromCurrent;
/* Completed successfully – stop trying more models. */
return;
} catch (err) {
if (typeof err === "object" && err !== null && "error" in err) {
lastErr = err as { error: unknown };
} else {
lastErr = { error: err };
}
logger.warn(
`[streamTextWithBackup] model #${i} failed – ${
i < modelClients.length - 1
? "switching to backup"
: "no backups left"
}`,
err,
);
/* loop continues to next model (if any) */
}
}
/* Every model failed */
if (!lastErr) {
throw new Error("Invariant in StreamTextWithbackup failed!");
}
callerOnError?.(lastErr);
logger.error("All model invocations failed", lastErr);
// throw lastErr ?? new Error("All model invocations failed");
}
return { textStream: combinedGenerator() };
}
class LlmErrorStore {
private modelErrorToTimestamp: Record<string, number> = {};
constructor() {}
recordModelError({ model, provider }: { model: string; provider: string }) {
this.modelErrorToTimestamp[this.getKey({ model, provider })] = Date.now();
}
clearModelError({ model, provider }: { model: string; provider: string }) {
delete this.modelErrorToTimestamp[this.getKey({ model, provider })];
}
modelHasNoRecentError({
model,
provider,
}: {
model: string;
provider: string;
}): boolean {
const key = this.getKey({ model, provider });
const timestamp = this.modelErrorToTimestamp[key];
if (!timestamp) {
return true;
}
const oneHourAgo = Date.now() - 1000 * 60 * 60;
return timestamp < oneHourAgo;
}
private getKey({ model, provider }: { model: string; provider: string }) {
return `${provider}::${model}`;
}
}
export const llmErrorStore = new LlmErrorStore();
# Fake LLM Server
A simple server that mimics the OpenAI streaming chat completions API for testing purposes.
## Features
- Implements a basic version of the OpenAI chat completions API
- Supports both streaming and non-streaming responses
- Always responds with "hello world" message
- Simulates a 429 rate limit error when the last message is "[429]"
- Configurable through environment variables
## Installation
```bash
npm install
```
## Usage
Start the server:
```bash
# Development mode
npm run dev
# Production mode
npm run build
npm start
```
### Example usage
```
curl -X POST http://localhost:3500/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"messages":[{"role":"user","content":"Say something"}],"model":"any-model","stream":true}'
```
The server will be available at http://localhost:3500 by default.
## API Endpoints
### POST /v1/chat/completions
This endpoint mimics OpenAI's chat completions API.
#### Request Format
```json
{
"messages": [{ "role": "user", "content": "Your prompt here" }],
"model": "any-model",
"stream": true
}
```
- Set `stream: true` to receive a streaming response
- Set `stream: false` or omit it for a regular JSON response
#### Response
For non-streaming requests, you'll get a standard JSON response:
```json
{
"id": "chatcmpl-123456789",
"object": "chat.completion",
"created": 1699000000,
"model": "fake-model",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "hello world"
},
"finish_reason": "stop"
}
]
}
```
For streaming requests, you'll receive a series of server-sent events (SSE), each containing a chunk of the response.
### Simulating Rate Limit Errors
To test how your application handles rate limiting, send a message with content exactly equal to `[429]`:
```json
{
"messages": [{ "role": "user", "content": "[429]" }],
"model": "any-model"
}
```
This will return a 429 status code with the following response:
```json
{
"error": {
"message": "Too many requests. Please try again later.",
"type": "rate_limit_error",
"param": null,
"code": "rate_limit_exceeded"
}
}
```
## Configuration
You can configure the server by modifying the `PORT` variable in the code.
## Use Case
This server is primarily intended for testing applications that integrate with OpenAI's API, allowing you to develop and test without making actual API calls to OpenAI.
"use strict";
var __importDefault =
(this && this.__importDefault) ||
function (mod) {
return mod && mod.__esModule ? mod : { default: mod };
};
Object.defineProperty(exports, "__esModule", { value: true });
const express_1 = __importDefault(require("express"));
const http_1 = require("http");
const cors_1 = __importDefault(require("cors"));
// Create Express app
const app = (0, express_1.default)();
app.use((0, cors_1.default)());
app.use(express_1.default.json());
const PORT = 3500;
// Helper function to create OpenAI-like streaming response chunks
function createStreamChunk(content, role = "assistant", isLast = false) {
const chunk = {
id: `chatcmpl-${Date.now()}`,
object: "chat.completion.chunk",
created: Math.floor(Date.now() / 1000),
model: "fake-model",
choices: [
{
index: 0,
delta: isLast ? {} : { content, role },
finish_reason: isLast ? "stop" : null,
},
],
};
return `data: ${JSON.stringify(chunk)}\n\n${isLast ? "data: [DONE]\n\n" : ""}`;
}
// Handle POST requests to /v1/chat/completions
app.post("/v1/chat/completions", (req, res) => {
const { stream = false } = req.body;
// Non-streaming response
if (!stream) {
return res.json({
id: `chatcmpl-${Date.now()}`,
object: "chat.completion",
created: Math.floor(Date.now() / 1000),
model: "fake-model",
choices: [
{
index: 0,
message: {
role: "assistant",
content: "hello world",
},
finish_reason: "stop",
},
],
});
}
// Streaming response
res.setHeader("Content-Type", "text/event-stream");
res.setHeader("Cache-Control", "no-cache");
res.setHeader("Connection", "keep-alive");
// Split the "hello world" message into characters to simulate streaming
const message = "hello world";
const messageChars = message.split("");
// Stream each character with a delay
let index = 0;
// Send role first
res.write(createStreamChunk("", "assistant"));
const interval = setInterval(() => {
if (index < messageChars.length) {
res.write(createStreamChunk(messageChars[index]));
index++;
} else {
// Send the final chunk
res.write(createStreamChunk("", "assistant", true));
clearInterval(interval);
res.end();
}
}, 100);
});
// Start the server
const server = (0, http_1.createServer)(app);
server.listen(PORT, () => {
console.log(`Fake LLM server running on http://localhost:${PORT}`);
});
// Handle SIGINT (Ctrl+C)
process.on("SIGINT", () => {
console.log("Shutting down fake LLM server");
server.close(() => {
console.log("Server closed");
process.exit(0);
});
});
import express from "express";
import { createServer } from "http";
import cors from "cors";
// Create Express app
const app = express();
app.use(cors());
app.use(express.json());
const PORT = 3500;
// Helper function to create OpenAI-like streaming response chunks
function createStreamChunk(
content: string,
role: string = "assistant",
isLast: boolean = false,
) {
const chunk = {
id: `chatcmpl-${Date.now()}`,
object: "chat.completion.chunk",
created: Math.floor(Date.now() / 1000),
model: "fake-model",
choices: [
{
index: 0,
delta: isLast ? {} : { content, role },
finish_reason: isLast ? "stop" : null,
},
],
};
return `data: ${JSON.stringify(chunk)}\n\n${isLast ? "data: [DONE]\n\n" : ""}`;
}
// Handle POST requests to /v1/chat/completions
app.post("/v1/chat/completions", (req, res) => {
const { stream = false, messages = [] } = req.body;
// Check if the last message contains "[429]" to simulate rate limiting
const lastMessage = messages[messages.length - 1];
if (lastMessage && lastMessage.content === "[429]") {
return res.status(429).json({
error: {
message: "Too many requests. Please try again later.",
type: "rate_limit_error",
param: null,
code: "rate_limit_exceeded",
},
});
}
// Non-streaming response
if (!stream) {
return res.json({
id: `chatcmpl-${Date.now()}`,
object: "chat.completion",
created: Math.floor(Date.now() / 1000),
model: "fake-model",
choices: [
{
index: 0,
message: {
role: "assistant",
content: "hello world",
},
finish_reason: "stop",
},
],
});
}
// Streaming response
res.setHeader("Content-Type", "text/event-stream");
res.setHeader("Cache-Control", "no-cache");
res.setHeader("Connection", "keep-alive");
// Split the "hello world" message into characters to simulate streaming
const message = "hello world";
const messageChars = message.split("");
// Stream each character with a delay
let index = 0;
// Send role first
res.write(createStreamChunk("", "assistant"));
const interval = setInterval(() => {
if (index < messageChars.length) {
res.write(createStreamChunk(messageChars[index]));
index++;
} else {
// Send the final chunk
res.write(createStreamChunk("", "assistant", true));
clearInterval(interval);
res.end();
}
}, 100);
});
// Start the server
const server = createServer(app);
server.listen(PORT, () => {
console.log(`Fake LLM server running on http://localhost:${PORT}`);
});
// Handle SIGINT (Ctrl+C)
process.on("SIGINT", () => {
console.log("Shutting down fake LLM server");
server.close(() => {
console.log("Server closed");
process.exit(0);
});
});
差异被折叠。
{
"name": "fake-llm-server",
"version": "1.0.0",
"main": "dist/index.js",
"scripts": {
"build": "tsc",
"start": "node dist/index.js",
"dev": "ts-node index.ts",
"test": "echo \"Error: no test specified\" && exit 1"
},
"keywords": [],
"author": "",
"license": "ISC",
"description": "Fake OpenAI API server for testing",
"dependencies": {
"cors": "^2.8.5",
"express": "^4.18.2",
"stream": "0.0.2"
},
"devDependencies": {
"@types/cors": "^2.8.18",
"@types/express": "^4.17.21",
"@types/node": "^20.17.46",
"ts-node": "^10.9.2",
"typescript": "^5.8.3"
}
}
{
"compilerOptions": {
"target": "ES2020",
"module": "commonjs",
"outDir": "dist",
"strict": true,
"esModuleInterop": true,
"skipLibCheck": true,
"forceConsistentCasingInFileNames": true
},
"include": ["*.ts"],
"exclude": ["node_modules"]
}
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论