diff --git a/bun.lock b/bun.lock index c5b7479..27869f7 100644 --- a/bun.lock +++ b/bun.lock @@ -1,6 +1,5 @@ { "lockfileVersion": 1, - "configVersion": 0, "workspaces": { "": { "name": "db-studio", @@ -102,7 +101,9 @@ "name": "proxy", "dependencies": { "@tanstack/ai": "^0.2.2", + "@tanstack/ai-anthropic": "^0.2.0", "@tanstack/ai-gemini": "^0.3.2", + "@tanstack/ai-openai": "^0.3.0", "@upstash/redis": "^1.36.1", "hono": "^4.11.3", "hono-rate-limiter": "^0.5.3", @@ -115,7 +116,7 @@ }, "packages/server": { "name": "db-studio", - "version": "1.2.20", + "version": "1.2.21", "bin": { "db-studio": "./dist/index.js", }, @@ -217,6 +218,8 @@ "@antfu/ni": ["@antfu/ni@25.0.0", "", { "dependencies": { "ansis": "^4.0.0", "fzf": "^0.5.2", "package-manager-detector": "^1.3.0", "tinyexec": "^1.0.1" }, "bin": { "na": "bin/na.mjs", "ni": "bin/ni.mjs", "nr": "bin/nr.mjs", "nci": "bin/nci.mjs", "nlx": "bin/nlx.mjs", "nun": "bin/nun.mjs", "nup": "bin/nup.mjs" } }, "sha512-9q/yCljni37pkMr4sPrI3G4jqdIk074+iukc5aFJl7kmDCCsiJrbZ6zKxnES1Gwg+i9RcDZwvktl23puGslmvA=="], + "@anthropic-ai/sdk": ["@anthropic-ai/sdk@0.71.2", "", { "dependencies": { "json-schema-to-ts": "^3.1.1" }, "peerDependencies": { "zod": "^3.25.0 || ^4.0.0" }, "optionalPeers": ["zod"], "bin": { "anthropic-ai-sdk": "bin/cli" } }, "sha512-TGNDEUuEstk/DKu0/TflXAEt+p+p/WhTlFzEnoosvbaDU2LTjm42igSdlL0VijrKpWejtOKxX0b8A7uc+XiSAQ=="], + "@asamuzakjp/css-color": ["@asamuzakjp/css-color@4.1.1", "", { "dependencies": { "@csstools/css-calc": "^2.1.4", "@csstools/css-color-parser": "^3.1.0", "@csstools/css-parser-algorithms": "^3.0.5", "@csstools/css-tokenizer": "^3.0.4", "lru-cache": "^11.2.4" } }, "sha512-B0Hv6G3gWGMn0xKJ0txEi/jM5iFpT3MfDxmhZFb4W047GvytCf1DHQ1D69W3zHI4yWe2aTZAA0JnbMZ7Xc8DuQ=="], "@asamuzakjp/dom-selector": ["@asamuzakjp/dom-selector@6.7.6", "", { "dependencies": { "@asamuzakjp/nwsapi": "^2.3.9", "bidi-js": "^1.0.3", "css-tree": "^3.1.0", "is-potential-custom-element-name": "^1.0.1", "lru-cache": "^11.2.4" } }, "sha512-hBaJER6A9MpdG3WgdlOolHmbOYvSk46y7IQN/1+iqiCuUu6iWdQrs9DGKF8ocqsEqWujWf/V7b7vaDgiUmIvUg=="], @@ -945,12 +948,16 @@ "@tanstack/ai": ["@tanstack/ai@0.2.2", "", { "dependencies": { "@tanstack/devtools-event-client": "^0.4.0", "partial-json": "^0.1.7" } }, "sha512-qqnUSKYMuJnGhiL6t8BAu3Joc9QhQTJIxUIWgQlObDhdY+dCJMLyv+Z7Zw+WqzCCjDfvWmHgLNWDI8+f3KkOPw=="], + "@tanstack/ai-anthropic": ["@tanstack/ai-anthropic@0.2.0", "", { "dependencies": { "@anthropic-ai/sdk": "^0.71.0" }, "peerDependencies": { "@tanstack/ai": "^0.2.0", "zod": "^4.0.0" } }, "sha512-52uwfHGmclhFQx+xOlaNnqLUfh1O1arcAIPlenzFp8infqgG9wVPwtDfvGRlr8zUkYlEKlXsZSpUxRbT21RYuA=="], + "@tanstack/ai-client": ["@tanstack/ai-client@0.2.2", "", { "dependencies": { "@tanstack/ai": "0.2.2" } }, "sha512-7WVYzMas6ACtt4NMGqnduWbKwDr1syYrmgQoy+hw3Iu5lEEIqdI/tG4koJJROtOM7ogUjKQBpIgsX4Tej+fPWA=="], "@tanstack/ai-devtools-core": ["@tanstack/ai-devtools-core@0.2.0", "", { "dependencies": { "@tanstack/ai": "0.2.2", "@tanstack/devtools-ui": "^0.4.4", "@tanstack/devtools-utils": "^0.2.3", "goober": "^2.1.18", "solid-js": "^1.9.10" } }, "sha512-7QP47lu7IOxe32QHmx8ThiEqsnn0Ye6NOirV3s9WjqyuHlMowzt5eZjbm8w86c6eOn4NGdkolba+jZPVCKrfcQ=="], "@tanstack/ai-gemini": ["@tanstack/ai-gemini@0.3.2", "", { "dependencies": { "@google/genai": "^1.30.0" }, "peerDependencies": { "@tanstack/ai": "^0.2.2" } }, "sha512-5s4lcmGJAb8G7lKGxkn/8soDLn/+Ao5IEh+2ADmd9PTuyOrwgdAus5n3I672GYW4EI/i2eXcMuTxe50DGAsVnw=="], + "@tanstack/ai-openai": ["@tanstack/ai-openai@0.3.0", "", { "dependencies": { "openai": "^6.9.1" }, "peerDependencies": { "@tanstack/ai": "^0.2.2", "zod": "^4.0.0" } }, "sha512-ZaMYUiU97LLDhJFbCQgtTVzCitMQSXFe4j1WGsXrz5fPe9ZUqYzMFSJRozONluM+vuI6z02XF2PJPzQldGKbMQ=="], + "@tanstack/ai-react": ["@tanstack/ai-react@0.2.2", "", { "dependencies": { "@tanstack/ai-client": "0.2.2" }, "peerDependencies": { "@tanstack/ai": "^0.2.2", "@types/react": ">=18.0.0", "react": ">=18.0.0" } }, "sha512-CNSOOoAUjre5lQxbQVqsXIJEJsTEdyPfyuvAFtgBGvbAJAVn5+6AZyOJzZZz1YKO3F05yD4hMkygelvoSPTUJA=="], "@tanstack/devtools": ["@tanstack/devtools@0.10.3", "", { "dependencies": { "@solid-primitives/event-listener": "^2.4.3", "@solid-primitives/keyboard": "^1.3.3", "@solid-primitives/resize-observer": "^2.1.3", "@tanstack/devtools-client": "0.0.5", "@tanstack/devtools-event-bus": "0.4.0", "@tanstack/devtools-ui": "0.4.4", "clsx": "^2.1.1", "goober": "^2.1.16", "solid-js": "^1.9.9" } }, "sha512-M2HnKtaNf3Z8JDTNDq+X7/1gwOqSwTnCyC0GR+TYiRZM9mkY9GpvTqp6p6bx3DT8onu2URJiVxgHD9WK2e3MNQ=="], @@ -1041,7 +1048,7 @@ "@types/babel__traverse": ["@types/babel__traverse@7.28.0", "", { "dependencies": { "@babel/types": "^7.28.2" } }, "sha512-8PvcXf70gTDZBgt9ptxJ8elBeBjcLOAcOtoO/mPJjtji1+CdGbHgm77om1GrsPxsiE+uXIpNSK64UYaIwQXd4Q=="], - "@types/bun": ["@types/bun@1.3.6", "", { "dependencies": { "bun-types": "1.3.6" } }, "sha512-uWCv6FO/8LcpREhenN1d1b6fcspAB+cefwD7uti8C8VffIv0Um08TKMn98FynpTiU38+y2dUO55T11NgDt8VAA=="], + "@types/bun": ["@types/bun@1.3.8", "", { "dependencies": { "bun-types": "1.3.8" } }, "sha512-3LvWJ2q5GerAXYxO2mffLTqOzEu5qnhEAlh48Vnu8WQfnmSwbgagjGZV6BoHKJztENYEDn6QmVd949W4uESRJA=="], "@types/chai": ["@types/chai@5.2.3", "", { "dependencies": { "@types/deep-eql": "*", "assertion-error": "^2.0.1" } }, "sha512-Mw558oeA9fFbv65/y4mHtXDs9bPnFMZAL/jxdPFUpOHHIXX91mcgEHbS5Lahr+pwZFR8A7GQleRWeI6cGFC2UA=="], @@ -1379,7 +1386,7 @@ "builtins": ["builtins@1.0.3", "", {}, "sha512-uYBjakWipfaO/bXI7E8rq6kpwHRZK5cNYrUv2OzZSI/FvmdMyXJ2tG9dKcjEC5YHmHpUAwsargWIZNWdxb/bnQ=="], - "bun-types": ["bun-types@1.3.6", "", { "dependencies": { "@types/node": "*" } }, "sha512-OlFwHcnNV99r//9v5IIOgQ9Uk37gZqrNMCcqEaExdkVq3Avwqok1bJFmvGMCkCE0FqzdY8VMOZpfpR3lwI+CsQ=="], + "bun-types": ["bun-types@1.3.8", "", { "dependencies": { "@types/node": "*" } }, "sha512-fL99nxdOWvV4LqjmC+8Q9kW3M4QTtTR1eePs94v5ctGqU8OeceWrSUaRw3JYb7tU3FkMIAjkueehrHPPPGKi5Q=="], "bundle-name": ["bundle-name@4.1.0", "", { "dependencies": { "run-applescript": "^7.0.0" } }, "sha512-tjwM5exMg6BGRI+kNmTntNsvdZS1X8BFYS6tnJ2hdH0kVxM6/eVZ2xy+FqStSWvYmtfFMDLIxurorHwDKfDz5Q=="], @@ -2259,6 +2266,8 @@ "json-parse-even-better-errors": ["json-parse-even-better-errors@2.3.1", "", {}, "sha512-xyFwyhro/JEof6Ghe2iz2NcXoj2sloNsWr/XsERDK/oiPCfaNhl5ONfp+jQdAZRQQ0IJWNzH9zIZF7li91kh2w=="], + "json-schema-to-ts": ["json-schema-to-ts@3.1.1", "", { "dependencies": { "@babel/runtime": "^7.18.3", "ts-algebra": "^2.0.0" } }, "sha512-+DWg8jCJG2TEnpy7kOm/7/AxaYoaRbjVB4LFZLySZlWn8exGs3A4OLJR966cVvU26N7X9TWxl+Jsw7dzAqKT6g=="], + "json-schema-traverse": ["json-schema-traverse@1.0.0", "", {}, "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug=="], "json-schema-typed": ["json-schema-typed@8.0.2", "", {}, "sha512-fQhoXdcvc3V28x7C7BMs4P5+kNlgUURe2jmUT1T//oBRMDrqy1QPelJimwZGo7Hg9VPV3EQV5Bnq4hbFy2vetA=="], @@ -2623,6 +2632,8 @@ "open": ["open@10.2.0", "", { "dependencies": { "default-browser": "^5.2.1", "define-lazy-prop": "^3.0.0", "is-inside-container": "^1.0.0", "wsl-utils": "^0.1.0" } }, "sha512-YgBpdJHPyQ2UE5x+hlSXcnejzAvD0b22U2OuAP+8OnlJT+PjWPxtgmGqKKc+RgTM63U9gN0YzrYc71R2WT/hTA=="], + "openai": ["openai@6.17.0", "", { "peerDependencies": { "ws": "^8.18.0", "zod": "^3.25 || ^4.0" }, "optionalPeers": ["ws", "zod"], "bin": { "openai": "bin/cli" } }, "sha512-NHRpPEUPzAvFOAFs9+9pC6+HCw/iWsYsKCMPXH5Kw7BpMxqd8g/A07/1o7Gx2TWtCnzevVRyKMRFqyiHyAlqcA=="], + "optionator": ["optionator@0.9.4", "", { "dependencies": { "deep-is": "^0.1.3", "fast-levenshtein": "^2.0.6", "levn": "^0.4.1", "prelude-ls": "^1.2.1", "type-check": "^0.4.0", "word-wrap": "^1.2.5" } }, "sha512-6IpQ7mKUxRcZNLIObR0hz7lxsapSSIYNZJwXPGeF0mTVqGKFIXj1DQcMoT22S3ROcLyY/rz0PWaWZ9ayWmad9g=="], "ora": ["ora@8.2.0", "", { "dependencies": { "chalk": "^5.3.0", "cli-cursor": "^5.0.0", "cli-spinners": "^2.9.2", "is-interactive": "^2.0.0", "is-unicode-supported": "^2.0.0", "log-symbols": "^6.0.0", "stdin-discarder": "^0.2.2", "string-width": "^7.2.0", "strip-ansi": "^7.1.0" } }, "sha512-weP+BZ8MVNnlCm8c0Qdc1WSWq4Qn7I+9CJGm7Qali6g44e/PUzbjNqJX5NJ9ljlNMosfJvg1fKEGILklK9cwnw=="], @@ -3163,6 +3174,8 @@ "trough": ["trough@2.2.0", "", {}, "sha512-tmMpK00BjZiUyVyvrBK7knerNgmgvcV/KLVyuma/SC+TQN167GrMRciANTz09+k3zW8L8t60jWO1GpfkZdjTaw=="], + "ts-algebra": ["ts-algebra@2.0.0", "", {}, "sha512-FPAhNPFMrkwz76P7cdjdmiShwMynZYN6SgOujD1urY4oNm80Ou9oMdmbR45LotcKOXoy7wSmHkRFE6Mxbrhefw=="], + "ts-api-utils": ["ts-api-utils@2.4.0", "", { "peerDependencies": { "typescript": ">=4.8.4" } }, "sha512-3TaVTaAv2gTiMB35i3FiGJaRfwb3Pyn/j3m/bfAvGe8FB7CF6u+LMYqYlDh7reQf7UNvoTvdfAqHGmPGOSsPmA=="], "ts-dedent": ["ts-dedent@2.2.0", "", {}, "sha512-q5W7tVM71e2xjHZTlgfTDoPF/SmqKG5hddq9SzR49CH2hayqRKJtQ4mtRlSxKaJlR/+9rEM+mnBHf7I2/BQcpQ=="], diff --git a/packages/core/src/components/chat/chat-sidebar.tsx b/packages/core/src/components/chat/chat-sidebar.tsx index 22329f2..44f591b 100644 --- a/packages/core/src/components/chat/chat-sidebar.tsx +++ b/packages/core/src/components/chat/chat-sidebar.tsx @@ -2,8 +2,9 @@ import { fetchServerSentEvents, useChat } from "@tanstack/ai-react"; import { Plus } from "lucide-react"; -import { useState } from "react"; -import { CHAT_SUGGESTIONS, DEFAULTS } from "shared/constants"; +import { useEffect, useMemo, useState } from "react"; +import { CHAT_SUGGESTIONS, DEFAULTS, MODEL_LIST } from "shared/constants"; +import type { ExecuteQueryResult } from "shared/types"; import { Conversation, ConversationContent, @@ -33,21 +34,120 @@ import { import { Suggestion, Suggestions } from "@/components/ai-elements/suggestion"; import { SheetSidebar } from "@/components/sheet-sidebar"; import { Button } from "@/components/ui/button"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; +import { useExecuteQuery } from "@/hooks/use-execute-query"; import { useRateLimit } from "@/hooks/use-rate-limit"; +import { getDbType } from "@/lib/api"; +import { useAiPrefillStore } from "@/stores/ai-prefill.store"; +import { useAiSettingsStore } from "@/stores/ai-settings.store"; +import { useDatabaseStore } from "@/stores/database.store"; +import { useInsertSqlStore } from "@/stores/insert-sql.store"; import { useSheetStore } from "@/stores/sheet.store"; export const ChatSidebar = () => { const { rateLimit, refetchRateLimit } = useRateLimit(); const [text, setText] = useState(""); + const [chatResults, setChatResults] = useState< + Record + >({}); const { isSheetOpen, closeSheet } = useSheetStore(); + const { selectedDatabase } = useDatabaseStore(); + const { + includeSchemaInAiContext, + useByocProxy, + byocProxyUrl, + provider, + model, + apiKeys, + setModel, + } = useAiSettingsStore(); + const { executeQuery: executeSandboxQuery } = useExecuteQuery("sandbox"); + + const chatUrl = useMemo(() => { + const dbType = getDbType() ?? "pg"; + return `${DEFAULTS.BASE_URL}/${dbType}/chat`; + }, []); + + const chatBody = useMemo(() => { + const body: Record = { + db: selectedDatabase ?? undefined, + includeSchemaInAiContext, + provider, + model, + apiKey: apiKeys[provider] || undefined, + }; + if (useByocProxy && byocProxyUrl?.trim()) { + body.proxyUrl = byocProxyUrl.trim(); + } + return body; + }, [ + selectedDatabase, + includeSchemaInAiContext, + useByocProxy, + byocProxyUrl, + provider, + model, + apiKeys, + ]); + + const { prefillMessage, setPrefillMessage } = useAiPrefillStore(); + const { setPendingSql } = useInsertSqlStore(); + + function extractSqlBlock(text: string): string | null { + const match = text.match(/```sql\s*([\s\S]*?)```/); + return match ? match[1].trim() : null; + } + + useEffect(() => { + if (isSheetOpen("ai-assistant") && prefillMessage) { + setText(prefillMessage); + setPrefillMessage(null); + } + }, [isSheetOpen("ai-assistant"), prefillMessage, setPrefillMessage]); const { messages, sendMessage, isLoading, clear, stop } = useChat({ - connection: fetchServerSentEvents(`${DEFAULTS.BASE_URL}/chat`), + connection: fetchServerSentEvents(chatUrl, { body: chatBody }), onError: (error) => console.error("Error:", error.message), onResponse: (response) => console.log("Response:", response), onFinish: (message) => { console.log("Finish:", message); + const messageAny = message as { + parts?: { type: string; content?: string }[]; + content?: string; + }; + const textParts = Array.isArray(messageAny.parts) + ? messageAny.parts + .filter((part) => part.type === "text") + .map((part) => part.content ?? "") + .join("") + : (messageAny.content ?? ""); + const sqlBlock = extractSqlBlock(textParts); + if (sqlBlock && message.id && selectedDatabase) { + setChatResults((prev) => ({ + ...prev, + [message.id]: { isLoading: true }, + })); + executeSandboxQuery({ query: sqlBlock }) + .then((result) => { + setChatResults((prev) => ({ + ...prev, + [message.id]: { isLoading: false, result }, + })); + }) + .catch((error: Error) => { + setChatResults((prev) => ({ + ...prev, + [message.id]: { isLoading: false, error: error.message }, + })); + }); + } refetchRateLimit(); }, }); @@ -67,6 +167,7 @@ export const ChatSidebar = () => { const handleNewChat = () => { clear(); setText(""); + setChatResults({}); }; const handleSuggestionClick = (suggestion: string) => { @@ -86,6 +187,24 @@ export const ChatSidebar = () => { title="AI Assistant" cta={
+ {rateLimit && ( @@ -139,10 +258,13 @@ export const ChatSidebar = () => { ); const textContent = message.parts .filter((part) => part.type === "text") - .map((part) => part.content) + .map((part) => ("content" in part ? part.content : "")) .join(""); const hasThinking = thinkingParts.length > 0; + const sqlBlock = + message.role === "assistant" ? extractSqlBlock(textContent) : null; + const chatResult = message.id ? chatResults[message.id] : undefined; return ( { > -
+
{hasThinking && message.role === "assistant" && ( - {thinkingParts.map((part) => part.content).join("\n")} + {thinkingParts + .map((part) => ("content" in part ? part.content : "")) + .join("\n")} )} {textContent} + {sqlBlock && ( +
+ + {chatResult?.isLoading && ( +
+ Running in sandbox... +
+ )} + {chatResult?.error && ( +
+ Error: {chatResult.error} +
+ )} + {chatResult?.result && ( +
+
+ Sandbox result — no changes saved +
+
+ {chatResult.result.rowCount} rows •{" "} + {chatResult.result.duration.toFixed(2)}ms +
+ {chatResult.result.rows.length > 0 && ( +
+																				{JSON.stringify(
+																					chatResult.result.rows.slice(0, 5),
+																					null,
+																					2,
+																				)}
+																			
+ )} +
+ )} +
+ )}
@@ -222,7 +389,7 @@ export const ChatSidebar = () => { className="h-8!" status={status} onClick={isLoading ? handleStop : undefined} - disabled={rateLimit && rateLimit.remaining === 0} + disabled={rateLimit?.remaining === 0 || !selectedDatabase} /> diff --git a/packages/core/src/components/components/command-palette.tsx b/packages/core/src/components/components/command-palette.tsx index 64b293f..7a50420 100644 --- a/packages/core/src/components/components/command-palette.tsx +++ b/packages/core/src/components/components/command-palette.tsx @@ -682,13 +682,7 @@ export function CommandPalette() {
- handleAction(() => { - toast.info("Settings - Coming Soon!", { - description: "Customize your DB Studio experience", - }); - }) - } + onSelect={() => handleAction(() => openSheet("settings"), "Opening settings")} >
diff --git a/packages/core/src/components/runnr-tab/cdoe-editor.tsx b/packages/core/src/components/runnr-tab/cdoe-editor.tsx index 255c8ed..f5bc119 100644 --- a/packages/core/src/components/runnr-tab/cdoe-editor.tsx +++ b/packages/core/src/components/runnr-tab/cdoe-editor.tsx @@ -1,6 +1,7 @@ import * as monaco from "monaco-editor"; import { useEffect, useRef } from "react"; import { toast } from "sonner"; +import { useInsertSqlStore } from "@/stores/insert-sql.store"; import { BUILTIN_FUNCTIONS, BUILTIN_TYPES, @@ -33,6 +34,8 @@ export const CodeEditor = ({ }: CodeEditorProps) => { const monacoEl = useRef(null); const editorRef = useRef(null); + const pendingInsertSql = useInsertSqlStore((s) => s.pendingSql); + const setPendingSql = useInsertSqlStore((s) => s.setPendingSql); useEffect(() => { if (!monacoEl.current) return; @@ -292,6 +295,18 @@ export const CodeEditor = ({ // Set initial query change onQueryChange(editorInstance.getValue()); + // Insert any pending SQL from chat (e.g. user clicked "Insert into editor") + const pending = useInsertSqlStore.getState().pendingSql; + if (pending) { + const sel = editorInstance.getSelection(); + const model = editorInstance.getModel(); + const range = sel ?? model?.getFullModelRange(); + if (range) { + editorInstance.executeEdits("insert-sql", [{ range, text: pending }]); + useInsertSqlStore.getState().setPendingSql(null); + } + } + // Reset unsaved changes state when editor is initialized if (queryId) { const initialValue = editorInstance.getValue(); @@ -319,6 +334,18 @@ export const CodeEditor = ({ }; }, [initialQuery, queryId, savedQuery, onQueryChange, onUnsavedChanges, onExecuteQuery]); + // When pending SQL is set from chat (e.g. "Insert into editor"), insert at cursor + useEffect(() => { + if (!pendingInsertSql || !editorRef.current) return; + const editor = editorRef.current; + const sel = editor.getSelection(); + const model = editor.getModel(); + const range = sel ?? model?.getFullModelRange(); + if (!range) return; + editor.executeEdits("insert-sql", [{ range, text: pendingInsertSql }]); + setPendingSql(null); + }, [pendingInsertSql, setPendingSql]); + return (
{ const [showAs] = useQueryState(CONSTANTS.RUNNER_STATE_KEYS.SHOW_AS); + const { suggestFix, isSuggestingFix, suggestFixError } = useSuggestFix(); + const { setPendingSql } = useInsertSqlStore(); + const [fixResult, setFixResult] = useState(null); + + useEffect(() => { + setFixResult(null); + }, [error]); const renderResults = useMemo(() => { if (error) { + const errorDetails = (error as Error & { details?: unknown }).details; return (
-
Error: {error.message}
+
+
Error: {error.message}
+
+ + {suggestFixError && ( + {suggestFixError.message} + )} +
+ {fixResult && ( +
+
{fixResult.explanation}
+
{fixResult.suggestedQuery}
+ +
+ )} +
); } @@ -71,6 +123,11 @@ export const QueryResultContainer = ({ className="absolute bottom-0 left-0 right-0 border-t-2 border-zinc-80 flex flex-col bg-[#1E1E1E] w-full" style={{ height: "calc(100vh - 400px)" }} > + {runMode === "sandbox" && ( +
+ Sandbox run — no changes saved +
+ )}
{isLoading ? (
diff --git a/packages/core/src/components/runnr-tab/runner-header.tsx b/packages/core/src/components/runnr-tab/runner-header.tsx index 67a5a25..07d74bc 100644 --- a/packages/core/src/components/runnr-tab/runner-header.tsx +++ b/packages/core/src/components/runnr-tab/runner-header.tsx @@ -1,14 +1,30 @@ -import { AlignLeft, Braces, Command, CornerDownLeft, Heart, Save, Table } from "lucide-react"; +import { + AlignLeft, + Braces, + Command, + CornerDownLeft, + Heart, + Save, + Shield, + Sparkles, + Table, + Zap, +} from "lucide-react"; import { useQueryState } from "nuqs"; import { Button } from "@/components/ui/button"; import { ToggleGroup, ToggleGroupItem } from "@/components/ui/toggle-group"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; +import { useAiPrefillStore } from "@/stores/ai-prefill.store"; +import { useSheetStore } from "@/stores/sheet.store"; import { CONSTANTS } from "@/utils/constants"; import type { QueryResult } from "./runner-tab"; export const RunnerHeader = ({ isExecutingQuery, handleButtonClick, + handleSandboxRun, + handleOptimizeQuery, + isOptimizing, handleFormatQuery, handleSaveQuery, handleFavorite, @@ -16,9 +32,13 @@ export const RunnerHeader = ({ queryId, hasUnsavedChanges, queryResult, + lastRunMode, }: { isExecutingQuery: boolean; handleButtonClick: () => void; + handleSandboxRun: () => void; + handleOptimizeQuery: () => void; + isOptimizing: boolean; handleFormatQuery: () => void; handleSaveQuery: () => void; handleFavorite: () => void; @@ -26,8 +46,16 @@ export const RunnerHeader = ({ queryId: string; hasUnsavedChanges: boolean; queryResult: QueryResult | null; + lastRunMode: "normal" | "sandbox"; }) => { const [showAs, setShowAs] = useQueryState(CONSTANTS.RUNNER_STATE_KEYS.SHOW_AS); + const { openSheet } = useSheetStore(); + const { setPrefillMessage } = useAiPrefillStore(); + + const handleGenerateWithAi = () => { + setPrefillMessage("Generate a SQL query for this database."); + openSheet("ai-assistant"); + }; return (
@@ -45,6 +73,42 @@ export const RunnerHeader = ({ + + + + + +

Run in sandbox (no changes saved)

+
+
+ + + + + + +

Generate with AI

+
+
+ + + +
+
+ )} + {analyzeError && ( +
Analyze error: {analyzeError.message}
+ )} + {suggestOptimizationError && ( +
+ Optimization error: {suggestOptimizationError.message} +
+ )} +
+ )} + }> {
); diff --git a/packages/core/src/components/runnr-tab/table-view.tsx b/packages/core/src/components/runnr-tab/table-view.tsx index 36b5094..415c0cd 100644 --- a/packages/core/src/components/runnr-tab/table-view.tsx +++ b/packages/core/src/components/runnr-tab/table-view.tsx @@ -1,10 +1,10 @@ import { flexRender, getCoreRowModel, type Row, useReactTable } from "@tanstack/react-table"; import { useVirtualizer } from "@tanstack/react-virtual"; import { useMemo, useRef } from "react"; -import type { ExecuteQueryResponse } from "shared/types"; +import type { ExecuteQueryResult } from "shared/types"; import { formatCellValue } from "@/utils/format-cell-value"; -export const TableView = ({ results }: { results: ExecuteQueryResponse | null }) => { +export const TableView = ({ results }: { results: ExecuteQueryResult | null }) => { const columns = useMemo(() => { if (!results?.columns) return []; diff --git a/packages/core/src/components/settings/settings-sheet.tsx b/packages/core/src/components/settings/settings-sheet.tsx new file mode 100644 index 0000000..448323d --- /dev/null +++ b/packages/core/src/components/settings/settings-sheet.tsx @@ -0,0 +1,211 @@ +"use client"; + +import { useMemo } from "react"; +import { AI_PROVIDERS, type AiProvider, MODEL_LIST } from "shared/constants"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { Sheet, SheetContent, SheetHeader, SheetTitle } from "@/components/ui/sheet"; +import { Switch } from "@/components/ui/switch"; +import { cn } from "@/lib/utils"; +import { useAiSettingsStore } from "@/stores/ai-settings.store"; +import { useSheetStore } from "@/stores/sheet.store"; + +function isValidByocUrl(value: string): boolean { + if (!value.trim()) return true; + try { + const url = new URL(value.trim()); + return url.protocol === "https:"; + } catch { + return false; + } +} + +export const SettingsSheet = () => { + const { closeSheet, isSheetOpen } = useSheetStore(); + const { + includeSchemaInAiContext, + useByocProxy, + byocProxyUrl, + provider, + model, + apiKeys, + setIncludeSchemaInAiContext, + setUseByocProxy, + setByocProxyUrl, + setProvider, + setModel, + setApiKeyForProvider, + } = useAiSettingsStore(); + + const byocUrlError = useMemo(() => { + if (!useByocProxy || !byocProxyUrl.trim()) return null; + return isValidByocUrl(byocProxyUrl) ? null : "Enter a valid HTTPS URL or leave empty"; + }, [useByocProxy, byocProxyUrl]); + + return ( + { + if (!open) closeSheet("settings"); + }} + > + + + Settings + + +
+
+

AI & Database Context

+ +
+
+ +

+ Introspect tables and columns and add them to the AI prompt for context-aware + SQL and answers. +

+
+ +
+ +
+
+ +

+ Use your own proxy/API for AI instead of the default. +

+
+ +
+ + {useByocProxy && ( +
+ + setByocProxyUrl(e.target.value)} + className={cn("h-8", byocUrlError && "border-destructive")} + aria-invalid={Boolean(byocUrlError)} + aria-describedby={byocUrlError ? "byoc-url-error" : undefined} + /> + {byocUrlError ? ( +

+ {byocUrlError} +

+ ) : ( +

+ Must be HTTPS. Same API as default: POST /chat with messages, systemPrompt, + conversationId; respond with SSE. +

+ )} +
+ )} +
+ +
+

AI Provider

+ +
+ + +
+ +
+ + +
+ +
+ + setApiKeyForProvider(provider, e.target.value)} + className="h-8" + /> +

+ Leave empty to use the default proxy key. Keys are stored locally. +

+
+
+
+
+
+ ); +}; diff --git a/packages/core/src/components/table-tab/header/settings-menu.tsx b/packages/core/src/components/table-tab/header/settings-menu.tsx index 561b160..1a9f45b 100644 --- a/packages/core/src/components/table-tab/header/settings-menu.tsx +++ b/packages/core/src/components/table-tab/header/settings-menu.tsx @@ -1,16 +1,17 @@ import { Settings } from "lucide-react"; import { Button } from "@/components/ui/button"; - -// todo: light & dark mode -// todo: json cell tab size +import { useSheetStore } from "@/stores/sheet.store"; export const SettingsBtn = () => { + const { openSheet } = useSheetStore(); + return ( diff --git a/packages/core/src/hooks/use-analyze-query.ts b/packages/core/src/hooks/use-analyze-query.ts new file mode 100644 index 0000000..4c7eda1 --- /dev/null +++ b/packages/core/src/hooks/use-analyze-query.ts @@ -0,0 +1,30 @@ +import { useMutation } from "@tanstack/react-query"; +import type { AnalyzeQueryResult, BaseResponse } from "shared/types"; +import { api } from "@/lib/api"; +import { useDatabaseStore } from "@/stores/database.store"; + +export const useAnalyzeQuery = () => { + const { selectedDatabase } = useDatabaseStore(); + + const { + mutateAsync: analyzeQuery, + isPending: isAnalyzing, + error: analyzeError, + } = useMutation({ + mutationFn: async ({ query }) => { + const params = new URLSearchParams({ db: selectedDatabase ?? "" }); + const res = await api.post>( + "/query/analyze", + { query }, + { params }, + ); + return res.data.data; + }, + }); + + return { + analyzeQuery, + isAnalyzing, + analyzeError, + }; +}; diff --git a/packages/core/src/hooks/use-execute-query.ts b/packages/core/src/hooks/use-execute-query.ts index ecadd7a..e3ab2f2 100644 --- a/packages/core/src/hooks/use-execute-query.ts +++ b/packages/core/src/hooks/use-execute-query.ts @@ -3,7 +3,9 @@ import type { BaseResponse, ExecuteQueryResult } from "shared/types"; import { api } from "@/lib/api"; import { useDatabaseStore } from "@/stores/database.store"; -export const useExecuteQuery = () => { +type ExecuteQueryMode = "normal" | "sandbox"; + +export const useExecuteQuery = (mode: ExecuteQueryMode = "normal") => { const { selectedDatabase } = useDatabaseStore(); const { @@ -13,8 +15,9 @@ export const useExecuteQuery = () => { } = useMutation({ mutationFn: async ({ query }) => { const params = new URLSearchParams({ db: selectedDatabase ?? "" }); + const endpoint = mode === "sandbox" ? "/query/sandbox" : "/query"; const res = await api.post>( - "/query", + endpoint, { query }, { params }, ); diff --git a/packages/core/src/hooks/use-suggest-fix.ts b/packages/core/src/hooks/use-suggest-fix.ts new file mode 100644 index 0000000..cf6704a --- /dev/null +++ b/packages/core/src/hooks/use-suggest-fix.ts @@ -0,0 +1,43 @@ +import { useMutation } from "@tanstack/react-query"; +import type { BaseResponse, SuggestFixResult } from "shared/types"; +import { api } from "@/lib/api"; +import { useAiSettingsStore } from "@/stores/ai-settings.store"; +import { useDatabaseStore } from "@/stores/database.store"; + +export const useSuggestFix = () => { + const { selectedDatabase } = useDatabaseStore(); + const { useByocProxy, byocProxyUrl, provider, model, apiKeys } = useAiSettingsStore(); + + const { + mutateAsync: suggestFix, + isPending: isSuggestingFix, + error: suggestFixError, + } = useMutation< + SuggestFixResult, + Error, + { query: string; errorMessage: string; errorDetails?: unknown } + >({ + mutationFn: async ({ query, errorMessage, errorDetails }) => { + const params = new URLSearchParams({ db: selectedDatabase ?? "" }); + const body = { + query, + errorMessage, + errorDetails, + proxyUrl: useByocProxy ? byocProxyUrl.trim() : undefined, + provider, + model, + apiKey: apiKeys[provider] || undefined, + }; + const res = await api.post>("/chat/suggest-fix", body, { + params, + }); + return res.data.data; + }, + }); + + return { + suggestFix, + isSuggestingFix, + suggestFixError, + }; +}; diff --git a/packages/core/src/hooks/use-suggest-optimization.ts b/packages/core/src/hooks/use-suggest-optimization.ts new file mode 100644 index 0000000..0778067 --- /dev/null +++ b/packages/core/src/hooks/use-suggest-optimization.ts @@ -0,0 +1,39 @@ +import { useMutation } from "@tanstack/react-query"; +import type { BaseResponse, SuggestOptimizationResult } from "shared/types"; +import { api } from "@/lib/api"; +import { useAiSettingsStore } from "@/stores/ai-settings.store"; +import { useDatabaseStore } from "@/stores/database.store"; + +export const useSuggestOptimization = () => { + const { selectedDatabase } = useDatabaseStore(); + const { useByocProxy, byocProxyUrl, provider, model, apiKeys } = useAiSettingsStore(); + + const { + mutateAsync: suggestOptimization, + isPending: isSuggestingOptimization, + error: suggestOptimizationError, + } = useMutation({ + mutationFn: async ({ query }) => { + const params = new URLSearchParams({ db: selectedDatabase ?? "" }); + const body = { + query, + proxyUrl: useByocProxy ? byocProxyUrl.trim() : undefined, + provider, + model, + apiKey: apiKeys[provider] || undefined, + }; + const res = await api.post>( + "/query/suggest-optimization", + body, + { params }, + ); + return res.data.data; + }, + }); + + return { + suggestOptimization, + isSuggestingOptimization, + suggestOptimizationError, + }; +}; diff --git a/packages/core/src/routes/__root.tsx b/packages/core/src/routes/__root.tsx index 98867db..647bab2 100644 --- a/packages/core/src/routes/__root.tsx +++ b/packages/core/src/routes/__root.tsx @@ -5,6 +5,7 @@ import { createRootRoute, Outlet } from "@tanstack/react-router"; import { TanStackRouterDevtoolsPanel } from "@tanstack/react-router-devtools"; import { NuqsAdapter } from "nuqs/adapters/react"; import { AddTableForm } from "@/components/add-table/add-table-form"; +import { SettingsSheet } from "@/components/settings/settings-sheet"; import { Toaster } from "@/components/ui/sonner"; import { Spinner } from "@/components/ui/spinner"; import { useCurrentDatabase } from "@/hooks/use-databases-list"; @@ -53,6 +54,7 @@ export const Route = createRootRoute({ {/* Global sheets */} + {/* Devtools */} void; +}; + +export const useAiPrefillStore = create()((set) => ({ + prefillMessage: null, + setPrefillMessage: (message) => set({ prefillMessage: message }), +})); diff --git a/packages/core/src/stores/ai-settings.store.ts b/packages/core/src/stores/ai-settings.store.ts new file mode 100644 index 0000000..07bc2e2 --- /dev/null +++ b/packages/core/src/stores/ai-settings.store.ts @@ -0,0 +1,52 @@ +import { type AiProvider, MODEL_LIST } from "shared/constants"; +import { create } from "zustand"; +import { persist } from "zustand/middleware"; + +type AiSettingsState = { + includeSchemaInAiContext: boolean; + useByocProxy: boolean; + byocProxyUrl: string; + provider: AiProvider; + model: string; + apiKeys: Partial>; + setIncludeSchemaInAiContext: (value: boolean) => void; + setUseByocProxy: (value: boolean) => void; + setByocProxyUrl: (value: string) => void; + setProvider: (value: AiProvider) => void; + setModel: (value: string) => void; + setApiKeyForProvider: (provider: AiProvider, apiKey: string) => void; +}; + +export const useAiSettingsStore = create()( + persist( + (set) => ({ + includeSchemaInAiContext: true, + useByocProxy: false, + byocProxyUrl: "", + provider: "gemini", + model: + MODEL_LIST.find((item) => item.provider === "gemini")?.id ?? "gemini-3-flash-preview", + apiKeys: {}, + setIncludeSchemaInAiContext: (value) => set({ includeSchemaInAiContext: value }), + setUseByocProxy: (value) => set({ useByocProxy: value }), + setByocProxyUrl: (value) => set({ byocProxyUrl: value }), + setProvider: (value) => + set((state) => { + const providerModels = MODEL_LIST.filter((item) => item.provider === value); + const nextModel = providerModels[0]?.id ?? state.model; + return { provider: value, model: nextModel }; + }), + setModel: (value) => set({ model: value }), + setApiKeyForProvider: (provider, apiKey) => + set((state) => ({ + apiKeys: { + ...state.apiKeys, + [provider]: apiKey, + }, + })), + }), + { + name: "db-studio-ai-settings", + }, + ), +); diff --git a/packages/core/src/stores/insert-sql.store.ts b/packages/core/src/stores/insert-sql.store.ts new file mode 100644 index 0000000..80cf003 --- /dev/null +++ b/packages/core/src/stores/insert-sql.store.ts @@ -0,0 +1,11 @@ +import { create } from "zustand"; + +type InsertSqlState = { + pendingSql: string | null; + setPendingSql: (sql: string | null) => void; +}; + +export const useInsertSqlStore = create()((set) => ({ + pendingSql: null, + setPendingSql: (sql) => set({ pendingSql: sql }), +})); diff --git a/packages/core/src/stores/sheet.store.ts b/packages/core/src/stores/sheet.store.ts index 1a2f17f..39b454e 100644 --- a/packages/core/src/stores/sheet.store.ts +++ b/packages/core/src/stores/sheet.store.ts @@ -6,7 +6,8 @@ type SheetName = | "add-record" | `add-foreign-key-${number}` | "record-reference" - | "ai-assistant"; + | "ai-assistant" + | "settings"; type SheetState = { openSheets: SheetName[]; diff --git a/packages/proxy/package.json b/packages/proxy/package.json index 8843c1c..20eb25d 100644 --- a/packages/proxy/package.json +++ b/packages/proxy/package.json @@ -14,12 +14,14 @@ "secrets:clear": "wrangler secret delete --all" }, "dependencies": { - "shared": "workspace:*", "@tanstack/ai": "^0.2.2", + "@tanstack/ai-anthropic": "^0.2.0", "@tanstack/ai-gemini": "^0.3.2", + "@tanstack/ai-openai": "^0.3.0", "@upstash/redis": "^1.36.1", "hono": "^4.11.3", - "hono-rate-limiter": "^0.5.3" + "hono-rate-limiter": "^0.5.3", + "shared": "workspace:*" }, "devDependencies": { "@biomejs/biome": "^2.2.6", diff --git a/packages/proxy/src/index.ts b/packages/proxy/src/index.ts index 1c31383..61ee80d 100644 --- a/packages/proxy/src/index.ts +++ b/packages/proxy/src/index.ts @@ -1,14 +1,60 @@ -import { LIMIT } from "shared/constants"; import { env } from "cloudflare:workers"; import { chat, toServerSentEventsResponse } from "@tanstack/ai"; +import { anthropicText, createAnthropicChat } from "@tanstack/ai-anthropic"; import { createGeminiChat } from "@tanstack/ai-gemini"; +import { createOpenaiChat, openaiText } from "@tanstack/ai-openai"; import { Hono } from "hono"; import { cors } from "hono/cors"; +import { LIMIT } from "shared/constants"; import { createProxyLimiter, keyGenerator } from "./limit"; import { getRedis } from "./redis"; const app = new Hono(); +const DEFAULT_MODELS = { + gemini: "gemini-3-flash-preview", + openai: "gpt-4o", + anthropic: "claude-3-5-sonnet-20241022", +} as const; + +type Provider = keyof typeof DEFAULT_MODELS; + +const buildAdapter = ({ + provider, + model, + apiKey, +}: { + provider?: Provider; + model?: string; + apiKey?: string; +}) => { + const resolvedProvider = provider && provider in DEFAULT_MODELS ? provider : "gemini"; + const resolvedModel = model ?? DEFAULT_MODELS[resolvedProvider]; + + if (resolvedProvider === "openai") { + const model = resolvedModel as Parameters[0]; + if (apiKey) { + return createOpenaiChat(model, apiKey); + } + return openaiText(model); + } + + if (resolvedProvider === "anthropic") { + const model = resolvedModel as Parameters[0]; + if (apiKey) { + return createAnthropicChat(model, apiKey); + } + return anthropicText(model); + } + + const geminiModel = resolvedModel as Parameters[0]; + return createGeminiChat(geminiModel, apiKey ?? env.GEMINI_API_KEY, { + temperature: 0.1, + topP: 0.9, + maxOutputTokens: 1024, + }); +}; + app.use( "/*", cors({ @@ -32,7 +78,8 @@ app.use("/chat", createProxyLimiter()); */ app.post("/chat", async (c) => { try { - const { messages, systemPrompt, conversationId } = await c.req.json(); + const { messages, systemPrompt, conversationId, provider, model, apiKey } = + await c.req.json(); if (!messages || !Array.isArray(messages)) { return c.json({ error: "Invalid request: messages array required" }, 400); } @@ -41,18 +88,16 @@ app.post("/chat", async (c) => { console.log("systemPrompt", systemPrompt); console.log("conversationId", conversationId); - const stream = chat({ - adapter: createGeminiChat("gemini-3-flash-preview", env.GEMINI_API_KEY, { - temperature: 0.1, // Very low - we want deterministic, accurate SQL - topP: 0.9, // Very low - we want deterministic, accurate SQL - maxOutputTokens: 1024, // Short responses - SQL + brief explanation - }), + const stream = await chat({ + adapter: buildAdapter({ provider, model, apiKey }), messages, conversationId, systemPrompts: [systemPrompt], }); - return toServerSentEventsResponse(stream); + return toServerSentEventsResponse( + stream as Parameters[0], + ); } catch (error) { console.error("Proxy error:", error); const errorMessage = error instanceof Error ? error.message : "An error occurred"; diff --git a/packages/proxy/src/limit.ts b/packages/proxy/src/limit.ts index 21e192d..2c8f09c 100644 --- a/packages/proxy/src/limit.ts +++ b/packages/proxy/src/limit.ts @@ -1,6 +1,6 @@ -import { LIMIT, ONE_DAY } from "shared/constants"; import type { Context, MiddlewareHandler } from "hono"; import { rateLimiter } from "hono-rate-limiter"; +import { LIMIT, ONE_DAY } from "shared/constants"; import { getRedisStore } from "./redis"; export const keyGenerator = (c: Context) => { diff --git a/packages/server/src/app.types.ts b/packages/server/src/app.types.ts index 7bc92d4..0246a70 100644 --- a/packages/server/src/app.types.ts +++ b/packages/server/src/app.types.ts @@ -9,7 +9,7 @@ import type { TablesRoutes } from "@/routes/tables.routes.js"; export type BaseResponseType = TypedResponse, 200>; -export type ApiErrorType = TypedResponse; +export type ApiErrorType = TypedResponse; /** * ApiHandler is a type that represents a response or error from an API endpoint. diff --git a/packages/server/src/dao/query.dao.ts b/packages/server/src/dao/query.dao.ts index 8df9479..4c9ae2a 100644 --- a/packages/server/src/dao/query.dao.ts +++ b/packages/server/src/dao/query.dao.ts @@ -1,7 +1,15 @@ import { HTTPException } from "hono/http-exception"; -import type { DatabaseSchemaType, ExecuteQueryParams, ExecuteQueryResult } from "shared/types"; +import { DatabaseError, type QueryResult } from "pg"; +import type { + AnalyzeQueryResult, + DatabaseSchemaType, + ExecuteQueryParams, + ExecuteQueryResult, +} from "shared/types"; import { getDbPool } from "@/db-manager.js"; +const SANDBOX_STATEMENT_TIMEOUT_MS = 8000; + export const executeQuery = async ({ query, db, @@ -20,7 +28,31 @@ export const executeQuery = async ({ const cleanedQuery = query.trim().replace(/;+$/, ""); const startTime = performance.now(); - const result = await pool.query(cleanedQuery); + let result: QueryResult; + try { + result = await pool.query(cleanedQuery); + } catch (error) { + if (error instanceof DatabaseError) { + const headers = new Headers(); + headers.set("Content-Type", "application/json"); + throw new HTTPException(400, { + message: "Query failed", + res: new Response( + JSON.stringify({ + error: error.message, + details: { + code: error.code, + position: error.position, + detail: error.detail, + hint: error.hint, + }, + }), + { status: 400, statusText: "Bad Request", headers }, + ), + }); + } + throw error; + } const duration = performance.now() - startTime; const columns = result.fields.map((field) => field.name); @@ -33,3 +65,125 @@ export const executeQuery = async ({ message: result.rows.length === 0 ? "OK" : undefined, }; }; + +export const executeQuerySandbox = async ({ + query, + db, +}: { + query: ExecuteQueryParams["query"]; + db: DatabaseSchemaType["db"]; +}): Promise => { + const pool = getDbPool(db); + if (!query || !query.trim()) { + throw new HTTPException(400, { + message: "Query is required", + }); + } + + const cleanedQuery = query.trim().replace(/;+$/, ""); + const client = await pool.connect(); + + try { + await client.query("BEGIN"); + await client.query("SET LOCAL statement_timeout = $1", [ + `${SANDBOX_STATEMENT_TIMEOUT_MS}`, + ]); + + const startTime = performance.now(); + let result: QueryResult; + try { + result = await client.query(cleanedQuery); + } catch (error) { + if (error instanceof DatabaseError) { + const headers = new Headers(); + headers.set("Content-Type", "application/json"); + throw new HTTPException(400, { + message: "Query failed", + res: new Response( + JSON.stringify({ + error: error.message, + details: { + code: error.code, + position: error.position, + detail: error.detail, + hint: error.hint, + }, + }), + { status: 400, statusText: "Bad Request", headers }, + ), + }); + } + throw error; + } + const duration = performance.now() - startTime; + + const columns = result.fields.map((field) => field.name); + + await client.query("ROLLBACK"); + + return { + columns, + rows: result.rows, + rowCount: result.rows.length, + duration, + message: result.rows.length === 0 ? "OK" : undefined, + }; + } catch (error) { + try { + await client.query("ROLLBACK"); + } catch { + // Ignore rollback errors, original error is more important + } + throw error; + } finally { + client.release(); + } +}; + +export const analyzeQuery = async ({ + query, + db, +}: { + query: ExecuteQueryParams["query"]; + db: DatabaseSchemaType["db"]; +}): Promise => { + const pool = getDbPool(db); + if (!query || !query.trim()) { + throw new HTTPException(400, { + message: "Query is required", + }); + } + + const cleanedQuery = query.trim().replace(/;+$/, ""); + const client = await pool.connect(); + + try { + await client.query("BEGIN"); + await client.query("SET LOCAL statement_timeout = $1", [ + `${SANDBOX_STATEMENT_TIMEOUT_MS}`, + ]); + + const result = await client.query(`EXPLAIN (ANALYZE, FORMAT JSON) ${cleanedQuery}`); + + await client.query("ROLLBACK"); + + const planJson = result.rows[0]?.["QUERY PLAN"]; + const planRoot = Array.isArray(planJson) ? planJson[0] : planJson; + const executionTimeMs = + typeof planRoot?.["Execution Time"] === "number" ? planRoot["Execution Time"] : 0; + + return { + plan: planJson, + executionTimeMs, + }; + } catch (error) { + try { + await client.query("ROLLBACK"); + } catch { + // Ignore rollback errors + } + throw error; + } finally { + client.release(); + } +}; diff --git a/packages/server/src/dao/table-columns.dao.ts b/packages/server/src/dao/table-columns.dao.ts index 7c4815e..3a3b856 100644 --- a/packages/server/src/dao/table-columns.dao.ts +++ b/packages/server/src/dao/table-columns.dao.ts @@ -74,7 +74,18 @@ export async function getTableColumns({ }); } - return rows.map((r: any) => { + type ColumnRow = { + columnName: string; + dataType: string; + isNullable: boolean; + columnDefault: string | null; + isPrimaryKey: boolean; + isForeignKey: boolean; + referencedTable: string | null; + referencedColumn: string | null; + enumValues?: string | string[] | null; + }; + return (rows as ColumnRow[]).map((r) => { // Parse enumValues to always return string[] | null let parsedEnumValues: string[] | null = null; if (r.enumValues) { diff --git a/packages/server/src/routes/chat.routes.ts b/packages/server/src/routes/chat.routes.ts index 236a925..0d1a705 100644 --- a/packages/server/src/routes/chat.routes.ts +++ b/packages/server/src/routes/chat.routes.ts @@ -1,9 +1,30 @@ import { zValidator } from "@hono/zod-validator"; import { Hono } from "hono"; import { DEFAULTS } from "shared/constants"; -import { chatSchema } from "shared/types"; +import { + chatSchema, + databaseSchema, + type SuggestFixResult, + suggestFixSchema, +} from "shared/types"; import { getDetailedSchema } from "@/dao/table-details-schema.js"; -import { generateSystemPrompt } from "@/utils/system-prompt-generator.js"; +import { readSseText } from "@/utils/read-sse-text.js"; +import { + generateSystemPrompt, + getMinimalSystemPrompt, +} from "@/utils/system-prompt-generator.js"; + +/** Validate BYOC proxy URL: https only to avoid SSRF */ +function getProxyUrl(proxyUrl?: string): string { + if (!proxyUrl || !proxyUrl.trim()) return DEFAULTS.PROXY_URL; + try { + const url = new URL(proxyUrl.trim()); + if (url.protocol !== "https:") return DEFAULTS.PROXY_URL; + return url.origin; + } catch { + return DEFAULTS.PROXY_URL; + } +} export const chatRoutes = new Hono() /** @@ -13,24 +34,43 @@ export const chatRoutes = new Hono() /** * POST /chat - Handle AI chat requests with streaming - * Proxies to the Cloudflare Worker which has the Gemini API key + * Proxies to the Cloudflare Worker which has the Gemini API key (or BYOC URL) */ .post("/", zValidator("json", chatSchema), async (c) => { - const { messages, conversationId, db } = c.req.valid("json"); + const { + messages, + conversationId, + db, + includeSchemaInAiContext, + proxyUrl: clientProxyUrl, + provider, + model, + apiKey, + } = c.req.valid("json"); console.log("POST /chat messages", messages); - // Get the database schema and generate system prompt - const schema = await getDetailedSchema(db); - const systemPrompt = generateSystemPrompt(schema); + const useSchemaContext = includeSchemaInAiContext !== false; + let systemPrompt: string; + if (useSchemaContext) { + const schema = await getDetailedSchema(db); + systemPrompt = generateSystemPrompt(schema); + } else { + systemPrompt = getMinimalSystemPrompt(); + } const payload = { messages, conversationId, systemPrompt, + provider, + model, + apiKey, }; + const proxyBaseUrl = getProxyUrl(clientProxyUrl); + // Forward request to the proxy with the system prompt - const proxyResponse = await fetch(`${DEFAULTS.PROXY_URL}/chat`, { + const proxyResponse = await fetch(`${proxyBaseUrl}/chat`, { method: "POST", headers: { "Content-Type": "application/json", @@ -57,6 +97,76 @@ export const chatRoutes = new Hono() Connection: "keep-alive", }, }); - }); + }) + + /** + * POST /chat/suggest-fix - Suggest a fix for a failed SQL query + */ + .post( + "/suggest-fix", + zValidator("query", databaseSchema), + zValidator("json", suggestFixSchema), + async (c) => { + const { db } = c.req.valid("query"); + const { + query, + errorMessage, + errorDetails, + proxyUrl: clientProxyUrl, + provider, + model, + apiKey, + } = c.req.valid("json"); + + const schema = await getDetailedSchema(db); + const systemPrompt = generateSystemPrompt(schema); + const errorDetailsText = errorDetails + ? `\nError details: ${JSON.stringify(errorDetails)}` + : ""; + + const prompt = `The following SQL query failed. Fix it and explain briefly.\n\nFailed query:\n${query}\n\nError:\n${errorMessage}${errorDetailsText}\n\nReturn ONLY valid JSON with keys: suggestedQuery, explanation.`; + + const payload = { + messages: [{ role: "user", content: prompt }], + systemPrompt, + conversationId: "suggest-fix", + provider, + model, + apiKey, + }; + + const proxyBaseUrl = getProxyUrl(clientProxyUrl); + const proxyResponse = await fetch(`${proxyBaseUrl}/chat`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(payload), + }); + + if (!proxyResponse.ok) { + const errorData = await proxyResponse.json(); + return c.json( + { error: errorData.error || "Proxy request failed" }, + proxyResponse.status as 400 | 500, + ); + } + + const rawText = await readSseText(proxyResponse); + let result: SuggestFixResult | null = null; + + try { + result = JSON.parse(rawText) as SuggestFixResult; + } catch { + result = null; + } + + if (!result?.suggestedQuery) { + return c.json({ error: "Failed to parse AI response" }, 500); + } + + return c.json({ data: result }, 200); + }, + ); export type ChatRoutes = typeof chatRoutes; diff --git a/packages/server/src/routes/query.routes.ts b/packages/server/src/routes/query.routes.ts index a0d89e0..44a6df0 100644 --- a/packages/server/src/routes/query.routes.ts +++ b/packages/server/src/routes/query.routes.ts @@ -1,8 +1,33 @@ import { zValidator } from "@hono/zod-validator"; import { Hono } from "hono"; -import { databaseSchema, type ExecuteQueryResult, executeQuerySchema } from "shared/types"; -import type { ApiHandler } from "@/app.types.js"; -import { executeQuery } from "@/dao/query.dao.js"; +import { DEFAULTS } from "shared/constants"; +import { + type AnalyzeQueryResult, + type ApiError, + analyzeQuerySchema, + databaseSchema, + type ExecuteQueryResult, + executeQuerySchema, + type SuggestOptimizationResult, + suggestOptimizationSchema, +} from "shared/types"; +import type { ApiErrorType, ApiHandler } from "@/app.types.js"; +import { analyzeQuery, executeQuery, executeQuerySandbox } from "@/dao/query.dao.js"; +import { getDetailedSchema } from "@/dao/table-details-schema.js"; +import { readSseText } from "@/utils/read-sse-text.js"; +import { generateSystemPrompt } from "@/utils/system-prompt-generator.js"; + +/** Validate BYOC proxy URL: https only to avoid SSRF */ +function getProxyUrl(proxyUrl?: string): string { + if (!proxyUrl || !proxyUrl.trim()) return DEFAULTS.PROXY_URL; + try { + const url = new URL(proxyUrl.trim()); + if (url.protocol !== "https:") return DEFAULTS.PROXY_URL; + return url.origin; + } catch { + return DEFAULTS.PROXY_URL; + } +} export const queryRoutes = new Hono() /** @@ -27,6 +52,99 @@ export const queryRoutes = new Hono() const data = await executeQuery({ query, db }); return c.json({ data }, 200); }, + ) + + /** + * POST /query/sandbox + * Executes a SQL query in a sandbox (transaction + rollback) + * @param {DatabaseSchemaType} query - The database to use + * @param {ExecuteQuerySchemaType} json - The query to execute + * @returns {ApiHandler} The result of the query + */ + .post( + "/sandbox", + zValidator("query", databaseSchema), + zValidator("json", executeQuerySchema), + async (c): ApiHandler => { + const { query } = c.req.valid("json"); + const { db } = c.req.valid("query"); + const data = await executeQuerySandbox({ query, db }); + return c.json({ data }, 200); + }, + ) + + /** + * POST /query/analyze + * Runs EXPLAIN ANALYZE for timing and plan insights + */ + .post( + "/analyze", + zValidator("query", databaseSchema), + zValidator("json", analyzeQuerySchema), + async (c): ApiHandler => { + const { query } = c.req.valid("json"); + const { db } = c.req.valid("query"); + const data = await analyzeQuery({ query, db }); + return c.json({ data }, 200); + }, + ) + + /** + * POST /query/suggest-optimization + * Suggest a faster version of a query + */ + .post( + "/suggest-optimization", + zValidator("query", databaseSchema), + zValidator("json", suggestOptimizationSchema), + async (c): ApiHandler => { + const { db } = c.req.valid("query"); + const { query, proxyUrl, provider, model, apiKey } = c.req.valid("json"); + + const schema = await getDetailedSchema(db); + const systemPrompt = generateSystemPrompt(schema); + const prompt = `Optimize the following SQL query for performance. Keep semantics identical.\n\nQuery:\n${query}\n\nReturn ONLY valid JSON with keys: suggestedQuery, explanation.`; + + const proxyUrlToUse = getProxyUrl(proxyUrl); + + const proxyResponse = await fetch(`${proxyUrlToUse}/chat`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + messages: [{ role: "user", content: prompt }], + systemPrompt, + conversationId: "suggest-optimization", + provider, + model, + apiKey, + }), + }); + + if (!proxyResponse.ok) { + const errorData = (await proxyResponse.json()) as { error?: string }; + const status = proxyResponse.status === 400 ? 400 : 500; + return c.json( + { error: errorData.error ?? "Proxy request failed" } satisfies ApiError, + status, + ) as unknown as ApiErrorType; + } + + const rawText = await readSseText(proxyResponse); + let parsed: SuggestOptimizationResult | null = null; + try { + parsed = JSON.parse(rawText) as SuggestOptimizationResult; + } catch { + parsed = null; + } + + if (!parsed?.suggestedQuery) { + return c.json({ error: "Failed to parse AI response" }, 500); + } + + return c.json({ data: parsed }, 200); + }, ); export type QueryRoutes = typeof queryRoutes; diff --git a/packages/server/src/utils/read-sse-text.ts b/packages/server/src/utils/read-sse-text.ts new file mode 100644 index 0000000..fa5e5a5 --- /dev/null +++ b/packages/server/src/utils/read-sse-text.ts @@ -0,0 +1,37 @@ +export async function readSseText(response: Response): Promise { + if (!response.body) return ""; + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + let output = ""; + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split("\n"); + buffer = lines.pop() ?? ""; + + for (const line of lines) { + if (!line.startsWith("data: ")) continue; + const data = line.slice(6).trim(); + if (!data || data === "[DONE]") continue; + const chunk = JSON.parse(data) as { + type?: string; + delta?: string; + content?: string; + error?: { message?: string }; + }; + if (chunk.type === "error") { + throw new Error(chunk.error?.message ?? "AI stream error"); + } + if (typeof chunk.delta === "string") { + output += chunk.delta; + } else if (typeof chunk.content === "string") { + output += chunk.content; + } + } + } + + return output.trim(); +} diff --git a/packages/server/src/utils/system-prompt-generator.ts b/packages/server/src/utils/system-prompt-generator.ts index 2879cc8..792892b 100644 --- a/packages/server/src/utils/system-prompt-generator.ts +++ b/packages/server/src/utils/system-prompt-generator.ts @@ -1,5 +1,16 @@ import type { DatabaseSchema } from "shared/types"; +const MINIMAL_SYSTEM_PROMPT = `You are a database assistant for db-studio. Your responses must be CONCISE and FOCUSED. +You help with general database and SQL questions. Without schema context, use generic table/column names in examples. +When generating SQL, wrap queries in \`\`\`sql code blocks. Keep responses SHORT.`; + +/** + * Generate minimal system prompt (no schema context) + */ +export function getMinimalSystemPrompt(): string { + return MINIMAL_SYSTEM_PROMPT; +} + /** * Generate system prompt with database context */ diff --git a/packages/shared/src/constants/chat.ts b/packages/shared/src/constants/chat.ts index 46d2933..61a0ff6 100644 --- a/packages/shared/src/constants/chat.ts +++ b/packages/shared/src/constants/chat.ts @@ -13,35 +13,48 @@ export const CHAT_SUGGESTIONS = [ "Help me write a safe UPDATE query", ]; -// const MODEL_LIST = [ -// { -// id: "gpt-4o", -// name: "GPT-4o", -// chef: "OpenAI", -// chefSlug: "openai", -// }, -// { -// id: "gpt-4o-mini", -// name: "GPT-4o Mini", -// chef: "OpenAI", -// chefSlug: "openai", -// }, -// { -// id: "claude-opus-4-20250514", -// name: "Claude 4 Opus", -// chef: "Anthropic", -// chefSlug: "anthropic", -// }, -// { -// id: "claude-sonnet-4-20250514", -// name: "Claude 4 Sonnet", -// chef: "Anthropic", -// chefSlug: "anthropic", -// }, -// { -// id: "gemini-2.0-flash-exp", -// name: "Gemini 2.0 Flash", -// chef: "Google", -// chefSlug: "google", -// }, -// ]; +export const AI_PROVIDERS = ["gemini", "openai", "anthropic"] as const; + +export type AiProvider = (typeof AI_PROVIDERS)[number]; + +export const MODEL_LIST = [ + { + id: "gemini-3-flash-preview", + name: "Gemini 3 Flash Preview", + provider: "gemini", + }, + { + id: "gemini-2.0-flash-exp", + name: "Gemini 2.0 Flash", + provider: "gemini", + }, + { + id: "gpt-4o", + name: "GPT-4o", + provider: "openai", + }, + { + id: "gpt-4o-mini", + name: "GPT-4o Mini", + provider: "openai", + }, + { + id: "claude-3-5-sonnet-20241022", + name: "Claude 3.5 Sonnet", + provider: "anthropic", + }, + { + id: "claude-3-opus-20240229", + name: "Claude 3 Opus", + provider: "anthropic", + }, +] as const; + +export const MODELS_BY_PROVIDER = MODEL_LIST.reduce( + (acc, model) => { + acc[model.provider] ??= []; + acc[model.provider]?.push(model); + return acc; + }, + {} as Record, +); diff --git a/packages/shared/src/types/api-response.types.ts b/packages/shared/src/types/api-response.types.ts index a3e25e8..a1571a4 100644 --- a/packages/shared/src/types/api-response.types.ts +++ b/packages/shared/src/types/api-response.types.ts @@ -12,5 +12,5 @@ export type BaseResponse = { */ export type ApiError = { error: string; - details?: string; + details?: unknown; }; diff --git a/packages/shared/src/types/chat.types.ts b/packages/shared/src/types/chat.types.ts index a57712c..7628884 100644 --- a/packages/shared/src/types/chat.types.ts +++ b/packages/shared/src/types/chat.types.ts @@ -1,13 +1,46 @@ import { z } from "zod"; import { databaseSchema } from "./database.types.js"; -export const chatSchema = z.object({ - messages: z.array( - z.object({ - role: z.enum(["user", "assistant"]), - content: z.string("Content is required"), - }), - ), +const messageSchema = z.object({ + role: z.enum(["user", "assistant"]), + content: z.string("Content is required"), +}); + +/** Optional nested payload from TanStack AI client (body.data) */ +const chatDataSchema = z.object({ conversationId: z.string().optional(), - db: databaseSchema.shape.db, + db: databaseSchema.shape.db.optional(), + includeSchemaInAiContext: z.boolean().optional(), + proxyUrl: z.string().optional(), + provider: z.string().optional(), + model: z.string().optional(), + apiKey: z.string().optional(), }); + +/** Accepts flat body or TanStack shape { messages, data: { conversationId, db, ... } } */ +export const chatSchema = z + .object({ + messages: z.array(messageSchema), + conversationId: z.string().optional(), + db: databaseSchema.shape.db.optional(), + includeSchemaInAiContext: z.boolean().optional(), + proxyUrl: z.string().optional(), + provider: z.string().optional(), + model: z.string().optional(), + apiKey: z.string().optional(), + data: chatDataSchema.optional(), + }) + .refine( + (v) => v.db !== undefined || v.data?.db !== undefined, + "db is required (top-level or in data)", + ) + .transform((v) => ({ + messages: v.messages, + conversationId: v.conversationId ?? v.data?.conversationId, + db: (v.db ?? v.data?.db) as string, + includeSchemaInAiContext: v.includeSchemaInAiContext ?? v.data?.includeSchemaInAiContext, + proxyUrl: v.proxyUrl ?? v.data?.proxyUrl, + provider: v.provider ?? v.data?.provider, + model: v.model ?? v.data?.model, + apiKey: v.apiKey ?? v.data?.apiKey, + })); diff --git a/packages/shared/src/types/index.ts b/packages/shared/src/types/index.ts index 5556126..3b6847d 100644 --- a/packages/shared/src/types/index.ts +++ b/packages/shared/src/types/index.ts @@ -12,7 +12,9 @@ export * from "./delete-column.types.js"; // done export * from "./delete-record.types.js"; // done export * from "./execute-query.types.js"; // done export * from "./export-table.types.js"; +export * from "./query-optimization.types.js"; export * from "./rate-limit-response.type.js"; +export * from "./suggest-fix.types.js"; export * from "./table-data.types.js"; // done export * from "./table-info.type.js"; // done export * from "./update-recors.types.js"; // done diff --git a/packages/shared/src/types/query-optimization.types.ts b/packages/shared/src/types/query-optimization.types.ts new file mode 100644 index 0000000..2855292 --- /dev/null +++ b/packages/shared/src/types/query-optimization.types.ts @@ -0,0 +1,27 @@ +import { z } from "zod"; + +export const analyzeQuerySchema = z.object({ + query: z.string("Query is required"), +}); + +export type AnalyzeQueryParams = z.infer; + +export type AnalyzeQueryResult = { + plan: unknown; + executionTimeMs: number; +}; + +export const suggestOptimizationSchema = z.object({ + query: z.string("Query is required"), + proxyUrl: z.string().optional(), + provider: z.string().optional(), + model: z.string().optional(), + apiKey: z.string().optional(), +}); + +export type SuggestOptimizationParams = z.infer; + +export type SuggestOptimizationResult = { + suggestedQuery: string; + explanation: string; +}; diff --git a/packages/shared/src/types/suggest-fix.types.ts b/packages/shared/src/types/suggest-fix.types.ts new file mode 100644 index 0000000..20c1b30 --- /dev/null +++ b/packages/shared/src/types/suggest-fix.types.ts @@ -0,0 +1,25 @@ +import { z } from "zod"; + +export const suggestFixSchema = z.object({ + query: z.string("Query is required"), + errorMessage: z.string("Error message is required"), + proxyUrl: z.string().optional(), + provider: z.string().optional(), + model: z.string().optional(), + apiKey: z.string().optional(), + errorDetails: z + .object({ + code: z.string().optional(), + position: z.string().optional(), + detail: z.string().optional(), + hint: z.string().optional(), + }) + .optional(), +}); + +export type SuggestFixParams = z.infer; + +export type SuggestFixResult = { + suggestedQuery: string; + explanation: string; +};