提交 272302f3 authored 作者: Vittorio's avatar Vittorio

修复connector的提示注入问题

上级 7ac9c84a
import { beforeEach, describe, expect, it, vi } from "vitest";
const {
mockFindConnector,
mockFindChats,
mockDeleteWhere,
mockDelete,
mockUpdateWhere,
mockUpdateSet,
mockUpdate,
mockPragmaAll,
mockLegacyRun,
mockPrepare,
} = vi.hoisted(() => {
const mockFindConnector = vi.fn();
const mockFindChats = vi.fn();
const mockDeleteWhere = vi.fn();
const mockDelete = vi.fn(() => ({
where: mockDeleteWhere,
}));
const mockUpdateWhere = vi.fn();
const mockUpdateSet = vi.fn(() => ({
where: mockUpdateWhere,
}));
const mockUpdate = vi.fn(() => ({
set: mockUpdateSet,
}));
const mockPragmaAll = vi.fn();
const mockLegacyRun = vi.fn();
const mockPrepare = vi.fn((sql: string) => {
if (sql === "PRAGMA table_info(apps)") {
return { all: mockPragmaAll };
}
if (sql === "UPDATE apps SET connector_id = NULL WHERE connector_id = ?") {
return { run: mockLegacyRun };
}
throw new Error(`Unexpected SQL: ${sql}`);
});
return {
mockFindConnector,
mockFindChats,
mockDeleteWhere,
mockDelete,
mockUpdateWhere,
mockUpdateSet,
mockUpdate,
mockPragmaAll,
mockLegacyRun,
mockPrepare,
};
});
vi.mock("../db", () => ({
db: {
query: {
connectors: {
findFirst: mockFindConnector,
},
chats: {
findMany: mockFindChats,
},
},
update: mockUpdate,
delete: mockDelete,
$client: {
prepare: mockPrepare,
},
},
}));
import { deleteConnectorWithCleanup } from "../ipc/handlers/connector_handlers";
import { DyadError } from "@/errors/dyad_error";
describe("deleteConnectorWithCleanup", () => {
beforeEach(() => {
vi.clearAllMocks();
mockDeleteWhere.mockResolvedValue(undefined);
mockUpdateWhere.mockResolvedValue(undefined);
});
it("cleans up legacy app references and chat connector ids before delete", async () => {
mockFindConnector.mockResolvedValue({
id: 3,
name: "Test Connector",
});
mockPragmaAll.mockReturnValue([{ name: "id" }, { name: "connector_id" }]);
mockFindChats.mockResolvedValue([
{ id: 10, connectorIdsJson: [1, 3, 5] },
{ id: 11, connectorIdsJson: [3] },
{ id: 12, connectorIdsJson: [7] },
{ id: 13, connectorIdsJson: null },
]);
await deleteConnectorWithCleanup(3);
expect(mockLegacyRun).toHaveBeenCalledWith(3);
expect(mockUpdate).toHaveBeenCalledTimes(2);
expect(mockUpdateSet).toHaveBeenNthCalledWith(1, {
connectorIdsJson: [1, 5],
});
expect(mockUpdateSet).toHaveBeenNthCalledWith(2, {
connectorIdsJson: null,
});
expect(mockDelete).toHaveBeenCalledTimes(1);
expect(mockDeleteWhere).toHaveBeenCalledTimes(1);
});
it("skips legacy cleanup when the old apps.connector_id column is absent", async () => {
mockFindConnector.mockResolvedValue({
id: 4,
name: "Modern Connector",
});
mockPragmaAll.mockReturnValue([{ name: "id" }, { name: "path" }]);
mockFindChats.mockResolvedValue([]);
await deleteConnectorWithCleanup(4);
expect(mockLegacyRun).not.toHaveBeenCalled();
expect(mockDelete).toHaveBeenCalledTimes(1);
});
it("throws a not found error when the connector does not exist", async () => {
mockFindConnector.mockResolvedValue(undefined);
const deletion = deleteConnectorWithCleanup(999);
await expect(deletion).rejects.toBeInstanceOf(DyadError);
await expect(deletion).rejects.toMatchObject({
message: "Connector not found",
});
});
});
......@@ -118,7 +118,10 @@ import {
} from "./free_agent_quota_handlers";
import { AI_STREAMING_ERROR_MESSAGE_PREFIX } from "@/shared/texts";
import { getCurrentCommitHash } from "../utils/git_utils";
import { formatOpenApiConnectorSystemPrompt } from "../utils/openapi_utils";
import {
formatOpenApiConnectorSystemPrompt,
formatSelectedOpenApiConnectorRawSpecPrompt,
} from "../utils/openapi_utils";
import { resolveEffectiveConnectorIds } from "../utils/connector_selection";
import {
processChatMessagesWithVersionedFiles as getVersionedFiles,
......@@ -944,6 +947,35 @@ ${componentSnippet}
}
}
const selectedRequestConnectorIds = Array.from(
new Set((req.connectorIds ?? []).filter(Number.isFinite)),
);
if (selectedRequestConnectorIds.length > 0) {
const selectedConnectors = await db.query.connectors.findMany({
where: inArray(connectors.id, selectedRequestConnectorIds),
});
const selectedConnectorsById = new Map(
selectedConnectors.map((connector) => [connector.id, connector]),
);
for (const connectorId of selectedRequestConnectorIds) {
const connector = selectedConnectorsById.get(connectorId);
if (
connector?.type === "openapi" &&
connector.rawSpec &&
Object.keys(connector.rawSpec).length > 0
) {
systemPrompt +=
"\n\n" +
formatSelectedOpenApiConnectorRawSpecPrompt({
name: connector.name,
sourceUrl: connector.sourceUrl,
rawSpec: connector.rawSpec,
});
}
}
}
const isSummarizeIntent = req.prompt.startsWith(
"Summarize from chat-id=",
);
......
import { and, desc, eq } from "drizzle-orm";
import { db } from "../../db";
import { connectors } from "../../db/schema";
import { chats, connectors } from "../../db/schema";
import { createTypedHandler } from "./base";
import { connectorContracts } from "../types/connector";
import {
......@@ -35,6 +35,75 @@ function toConnectorDetail(connector: typeof connectors.$inferSelect) {
};
}
function hasLegacyAppConnectorColumn(): boolean {
type TableInfoRow = {
name?: string;
};
const rows = db.$client
.prepare("PRAGMA table_info(apps)")
.all() as TableInfoRow[];
return rows.some((row) => row.name === "connector_id");
}
async function removeConnectorIdFromChats(connectorId: number): Promise<void> {
const existingChats = await db.query.chats.findMany({
columns: {
id: true,
connectorIdsJson: true,
},
});
for (const chat of existingChats) {
const connectorIds = chat.connectorIdsJson ?? [];
if (!connectorIds.includes(connectorId)) {
continue;
}
const filteredConnectorIds = connectorIds.filter(
(id) => id !== connectorId,
);
await db
.update(chats)
.set({
connectorIdsJson:
filteredConnectorIds.length > 0 ? filteredConnectorIds : null,
})
.where(eq(chats.id, chat.id));
}
}
export async function deleteConnectorWithCleanup(
connectorId: number,
): Promise<void> {
const connector = await db.query.connectors.findFirst({
where: eq(connectors.id, connectorId),
});
if (!connector) {
throw new DyadError("Connector not found", DyadErrorKind.NotFound);
}
try {
if (hasLegacyAppConnectorColumn()) {
db.$client
.prepare("UPDATE apps SET connector_id = NULL WHERE connector_id = ?")
.run(connectorId);
}
await removeConnectorIdFromChats(connectorId);
await db.delete(connectors).where(eq(connectors.id, connectorId));
} catch (error) {
throw new DyadError(
error instanceof Error
? `Failed to delete connector: ${error.message}`
: "Failed to delete connector",
DyadErrorKind.External,
);
}
}
export function registerConnectorHandlers() {
createTypedHandler(connectorContracts.listConnectors, async () => {
const allConnectors = await db.query.connectors.findMany({
......@@ -173,15 +242,7 @@ export function registerConnectorHandlers() {
});
createTypedHandler(connectorContracts.deleteConnector, async (_, params) => {
const connector = await db.query.connectors.findFirst({
where: eq(connectors.id, params.connectorId),
});
if (!connector) {
throw new DyadError("Connector not found", DyadErrorKind.NotFound);
}
await db.delete(connectors).where(eq(connectors.id, params.connectorId));
await deleteConnectorWithCleanup(params.connectorId);
});
}
......
......@@ -31,7 +31,10 @@ import { isLocalAgentBackedMode, isTurboEditsV2Enabled } from "@/lib/schemas";
import { DyadError, DyadErrorKind } from "@/errors/dyad_error";
import { resolveChatModeForTurn } from "./chat_mode_resolution";
import { resolveEffectiveConnectorIds } from "../utils/connector_selection";
import { formatOpenApiConnectorSystemPrompt } from "../utils/openapi_utils";
import {
formatOpenApiConnectorSystemPrompt,
formatSelectedOpenApiConnectorRawSpecPrompt,
} from "../utils/openapi_utils";
const logger = log.scope("token_count_handlers");
......@@ -155,6 +158,35 @@ export function registerTokenCountHandlers() {
}
}
const selectedRequestConnectorIds = Array.from(
new Set((req.connectorIds ?? []).filter(Number.isFinite)),
);
if (selectedRequestConnectorIds.length > 0) {
const selectedConnectors = await db.query.connectors.findMany({
where: inArray(connectors.id, selectedRequestConnectorIds),
});
const selectedConnectorsById = new Map(
selectedConnectors.map((connector) => [connector.id, connector]),
);
for (const connectorId of selectedRequestConnectorIds) {
const connector = selectedConnectorsById.get(connectorId);
if (
connector?.type === "openapi" &&
connector.rawSpec &&
Object.keys(connector.rawSpec).length > 0
) {
systemPrompt +=
"\n\n" +
formatSelectedOpenApiConnectorRawSpecPrompt({
name: connector.name,
sourceUrl: connector.sourceUrl,
rawSpec: connector.rawSpec,
});
}
}
}
const systemPromptTokens = estimateTokens(systemPrompt + supabaseContext);
// Extract codebase information if app is associated with the chat
......
......@@ -2,6 +2,7 @@ import { describe, expect, it } from "vitest";
import {
extractOpenApiEndpoints,
formatOpenApiConnectorSystemPrompt,
formatSelectedOpenApiConnectorRawSpecPrompt,
} from "./openapi_utils";
describe("extractOpenApiEndpoints", () => {
......@@ -100,4 +101,86 @@ describe("extractOpenApiEndpoints", () => {
"Prefer these endpoints over inventing new API routes.",
);
});
it("formats the full raw spec for a selected connector", () => {
const prompt = formatSelectedOpenApiConnectorRawSpecPrompt({
name: "Weather API",
sourceUrl: "https://example.com/openapi.json",
rawSpec: {
openapi: "3.1.0",
servers: [{ url: "https://api.example.com" }],
components: {
securitySchemes: {
bearerAuth: {
type: "http",
scheme: "bearer",
},
},
schemas: {
ForecastRequest: {
type: "object",
required: ["city"],
properties: {
city: { type: "string" },
unit: { type: "string", enum: ["metric", "imperial"] },
},
},
ForecastResponse: {
type: "object",
properties: {
temperature: { type: "number" },
condition: { type: "string" },
},
},
},
},
security: [{ bearerAuth: [] }],
paths: {
"/forecast": {
post: {
operationId: "createForecast",
summary: "Create forecast request",
parameters: [
{
name: "x-region",
in: "header",
required: true,
schema: { type: "string" },
},
],
requestBody: {
content: {
"application/json": {
schema: { $ref: "#/components/schemas/ForecastRequest" },
},
},
},
responses: {
"200": {
description: "Forecast response",
content: {
"application/json": {
schema: {
$ref: "#/components/schemas/ForecastResponse",
},
},
},
},
},
},
},
},
},
});
expect(prompt).toContain("# API Connector Raw Spec");
expect(prompt).toContain("OpenAPI JSON:");
expect(prompt).toContain('"openapi": "3.1.0"');
expect(prompt).toContain('"servers": [');
expect(prompt).toContain('"components": {');
expect(prompt).toContain('"ForecastRequest": {');
expect(prompt).toContain('"ForecastResponse": {');
expect(prompt).toContain('"requestBody": {');
expect(prompt).toContain('"responses": {');
});
});
......@@ -22,6 +22,8 @@ type OpenApiLikeDocument = {
paths?: Record<string, Record<string, Record<string, unknown>>>;
};
const MAX_CONNECTOR_SUMMARY_ENDPOINTS = 60;
function parseSpecText(rawText: string): Record<string, unknown> {
try {
return JSON.parse(rawText) as Record<string, unknown>;
......@@ -174,7 +176,7 @@ export function formatOpenApiConnectorSystemPrompt(input: {
endpoints: ConnectorEndpoint[];
}): string {
const endpointLines = input.endpoints
.slice(0, 60)
.slice(0, MAX_CONNECTOR_SUMMARY_ENDPOINTS)
.map((endpoint) => {
const summary = endpoint.summary?.trim()
? ` - ${endpoint.summary.trim()}`
......@@ -186,7 +188,10 @@ export function formatOpenApiConnectorSystemPrompt(input: {
})
.join("\n");
const remainingCount = Math.max(input.endpoints.length - 60, 0);
const remainingCount = Math.max(
input.endpoints.length - MAX_CONNECTOR_SUMMARY_ENDPOINTS,
0,
);
const descriptionBlock =
input.description && input.description.trim().length > 0
? `Description: ${input.description.trim()}\n`
......@@ -204,3 +209,24 @@ ${remainingCount > 0 ? `- ...and ${remainingCount} more endpoints` : ""}
Prefer these endpoints over inventing new API routes. If you need to wire this API into the app, generate code that calls these external endpoints rather than creating fake local endpoints unless the user explicitly asks for a mock or proxy layer.`;
}
export function formatSelectedOpenApiConnectorRawSpecPrompt(input: {
name: string;
sourceUrl: string;
rawSpec: Record<string, unknown>;
}): string {
const rawSpecJson = JSON.stringify(input.rawSpec, null, 2);
return `# API Connector Raw Spec
The user explicitly selected this connector for the current turn. Use the full OpenAPI/Swagger document below as the source of truth for request parameters, request bodies, response bodies, auth, and schema definitions.
Connector: ${input.name}
Spec URL: ${input.sourceUrl}
OpenAPI JSON:
\`\`\`json
${rawSpecJson}
\`\`\`
Use this raw spec directly when generating request helpers, payload types, response handling, and authentication code for this selected connector.`;
}
......@@ -13,6 +13,7 @@ import {
Pencil,
Folder,
Star,
Loader2,
} from "lucide-react";
import {
Popover,
......@@ -33,60 +34,19 @@ import {
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import { GitHubConnector } from "@/components/GitHubConnector";
import { SupabaseConnector } from "@/components/SupabaseConnector";
import { NeonConnector } from "@/components/NeonConnector";
import { showError, showSuccess } from "@/lib/toast";
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
import { Label } from "@/components/ui/label";
import { Info, Loader2 } from "lucide-react";
import {
Card,
CardDescription,
CardHeader,
CardTitle,
} from "@/components/ui/card";
import { invalidateAppQuery } from "@/hooks/useLoadApp";
import { useDebounce } from "@/hooks/useDebounce";
import { useCheckName } from "@/hooks/useCheckName";
import { AppUpgrades } from "@/components/AppUpgrades";
import { CapacitorControls } from "@/components/CapacitorControls";
import { GithubCollaboratorManager } from "@/components/GithubCollaboratorManager";
import { useAddAppToFavorite } from "@/hooks/useAddAppToFavorite";
import { useTranslation } from "react-i18next";
import { queryKeys } from "@/lib/queryKeys";
function UnavailableIntegrationCard({
provider,
}: {
provider: "supabase" | "neon";
}) {
const { t } = useTranslation("home");
const label = provider === "supabase" ? "Supabase" : "Neon";
const descriptionKey =
provider === "supabase"
? "integrations.mutualExclusion.supabaseUnavailable"
: "integrations.mutualExclusion.neonUnavailable";
return (
<Card className="mt-1">
<CardHeader className="flex flex-row items-center gap-3 py-3">
<Info className="h-5 w-5 text-muted-foreground shrink-0" />
<div>
<CardTitle className="text-sm">{label}</CardTitle>
<CardDescription className="text-xs">
{t(descriptionKey)}
</CardDescription>
</div>
</CardHeader>
</Card>
);
}
export default function AppDetailsPage() {
const navigate = useNavigate();
const router = useRouter();
const search = useSearch({ from: "/app-details" as const });
const { t } = useTranslation("home");
const { apps: appsList, refreshApps } = useLoadApps();
const [isDeleteDialogOpen, setIsDeleteDialogOpen] = useState(false);
const [isDeleting, setIsDeleting] = useState(false);
......@@ -118,8 +78,6 @@ export default function AppDetailsPage() {
// Get the appId and provider filter from search params
const appId = search.appId ? Number(search.appId) : null;
const providerFilter = search.provider;
const { data: screenshotsData } = useQuery({
queryKey: queryKeys.apps.screenshots({ appId }),
queryFn: () => ipc.app.listAppScreenshots({ appId: appId! }),
......@@ -487,58 +445,6 @@ export default function AppDetailsPage() {
Open in Chat
<MessageCircle className="h-4 w-4" />
</Button>
<div className="border border-gray-200 rounded-md p-4">
<GitHubConnector appId={appId} folderName={selectedApp.path} />
{selectedApp.githubOrg && selectedApp.githubRepo && appId && (
<div className="pt-4 border-t border-gray-100 dark:border-gray-800">
<GithubCollaboratorManager appId={appId} />
</div>
)}
</div>
{/* When providerFilter is set, show the selected connector only if the other provider isn't already active */}
{providerFilter === "supabase" &&
appId &&
!selectedApp?.neonProjectId && <SupabaseConnector appId={appId} />}
{providerFilter === "supabase" &&
appId &&
selectedApp?.neonProjectId && (
<UnavailableIntegrationCard provider="supabase" />
)}
{providerFilter === "neon" &&
appId &&
!selectedApp?.supabaseProjectId && <NeonConnector appId={appId} />}
{providerFilter === "neon" &&
appId &&
selectedApp?.supabaseProjectId && (
<UnavailableIntegrationCard provider="neon" />
)}
{/* When no providerFilter, show both with existing mutual exclusion */}
{!providerFilter && (
<>
{appId &&
!selectedApp?.neonProjectId &&
!selectedApp?.supabaseProjectId && (
<div className="flex items-start gap-2 rounded-md border border-muted bg-muted/30 px-3 py-2 text-xs text-muted-foreground">
<Info className="h-4 w-4 shrink-0 mt-0.5" />
<span>{t("integrations.mutualExclusion.chooseOne")}</span>
</div>
)}
{appId && !selectedApp?.neonProjectId && (
<SupabaseConnector appId={appId} />
)}
{appId && selectedApp?.neonProjectId && (
<UnavailableIntegrationCard provider="supabase" />
)}
{appId && !selectedApp?.supabaseProjectId && (
<NeonConnector appId={appId} />
)}
{appId && selectedApp?.supabaseProjectId && (
<UnavailableIntegrationCard provider="neon" />
)}
</>
)}
{appId && <CapacitorControls appId={appId} />}
<AppUpgrades appId={appId} />
</div>
{/* Rename Dialog */}
......
......@@ -209,6 +209,8 @@ export default function HomePage() {
chatId,
appId,
attachments,
connectorIds: selectedConnector ? [selectedConnector.id] : undefined,
connectorSelectionMode: selectedConnector ? "append" : undefined,
requestedChatMode: initialChatMode,
});
await new Promise((resolve) =>
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论