diff --git a/.gitignore b/.gitignore index 2ab8cb2a0e..8ee1d607dd 100644 --- a/.gitignore +++ b/.gitignore @@ -109,6 +109,8 @@ runs/ __pycache__ tmpfork +.cmux/*.tmp.* +.mux/*.tmp.* .mux-agent-cli storybook-static/ *.tgz diff --git a/src/common/constants/paths.ts b/src/common/constants/paths.ts index 389d8c6a60..192907f1ca 100644 --- a/src/common/constants/paths.ts +++ b/src/common/constants/paths.ts @@ -77,6 +77,17 @@ export function getMuxSessionsDir(rootDir?: string): string { return join(root, "sessions"); } +/** + * Get the directory where user-installed extensions live. + * Example: ~/.mux/ext + * + * @param rootDir - Optional root directory (defaults to getMuxHome()) + */ +export function getMuxExtDir(rootDir?: string): string { + const root = rootDir ?? getMuxHome(); + return join(root, "ext"); +} + /** * Get the directory where plan files are stored. * Example: ~/.mux/plans/workspace-id.md diff --git a/src/node/services/aiService.ts b/src/node/services/aiService.ts index 373346c32c..8794503e41 100644 --- a/src/node/services/aiService.ts +++ b/src/node/services/aiService.ts @@ -18,6 +18,7 @@ import { import type { MuxMessage, MuxTextPart } from "@/common/types/message"; import { createMuxMessage } from "@/common/types/message"; import type { Config, ProviderConfig } from "@/node/config"; +import { ExtensionManager } from "./extensions/extensionManager"; import { StreamManager } from "./streamManager"; import type { InitStateManager } from "./initStateManager"; import type { SendMessageError } from "@/common/types/errors"; @@ -278,6 +279,7 @@ export class AIService extends EventEmitter { private readonly initStateManager: InitStateManager; private readonly mockModeEnabled: boolean; private readonly mockScenarioPlayer?: MockScenarioPlayer; + private readonly extensionManager: ExtensionManager; private readonly backgroundProcessManager?: BackgroundProcessManager; constructor( @@ -297,6 +299,7 @@ export class AIService extends EventEmitter { this.partialService = partialService; this.initStateManager = initStateManager; this.backgroundProcessManager = backgroundProcessManager; + this.extensionManager = new ExtensionManager(); this.streamManager = new StreamManager(historyService, partialService, sessionUsageService); void this.ensureSessionsDir(); this.setupStreamEventForwarding(); @@ -1172,10 +1175,19 @@ export class AIService extends EventEmitter { // Apply tool policy to filter tools (if policy provided) const tools = applyToolPolicy(allTools, toolPolicy); + const toolsWithExtensions = this.extensionManager.wrapToolsWithPostToolUse(tools, { + workspaceId, + projectPath: metadata.projectPath, + workspacePath, + runtimeConfig: metadata.runtimeConfig, + runtimeTempDir, + runtime, + }); + log.info("AIService.streamMessage: tool configuration", { workspaceId, model: modelString, - toolNames: Object.keys(tools), + toolNames: Object.keys(toolsWithExtensions), hasToolPolicy: Boolean(toolPolicy), }); @@ -1336,7 +1348,7 @@ export class AIService extends EventEmitter { systemMessage, messages: finalMessages, tools: Object.fromEntries( - Object.entries(tools).map(([name, tool]) => [ + Object.entries(toolsWithExtensions).map(([name, tool]) => [ name, { description: tool.description, @@ -1365,7 +1377,7 @@ export class AIService extends EventEmitter { systemMessage, runtime, abortSignal, - tools, + toolsWithExtensions, { systemMessageTokens, timestamp: Date.now(), diff --git a/src/node/services/extensions/extensionManager.test.ts b/src/node/services/extensions/extensionManager.test.ts new file mode 100644 index 0000000000..06b48ec4a0 --- /dev/null +++ b/src/node/services/extensions/extensionManager.test.ts @@ -0,0 +1,114 @@ +import * as os from "os"; +import * as path from "path"; +import { mkdtemp, mkdir, rm, writeFile } from "fs/promises"; +import { z } from "zod"; +import { tool } from "ai"; +import { ExtensionManager } from "./extensionManager"; + +describe("ExtensionManager", () => { + it("wrapToolsWithPostToolUse lets extensions transform tool results", async () => { + const tempRoot = await mkdtemp(path.join(os.tmpdir(), "mux-ext-mgr-")); + try { + const extDir = path.join(tempRoot, "ext"); + await mkdir(extDir, { recursive: true }); + + await writeFile( + path.join(extDir, "transform.js"), + [ + "module.exports = {", + " onPostToolUse: (payload) => ({ result: { seen: payload.toolCallId, original: payload.result } }),", + "};", + "", + ].join("\n"), + "utf-8" + ); + + const manager = new ExtensionManager({ extDir, hookTimeoutMs: 1000 }); + + const base = tool({ + description: "test", + inputSchema: z.object({ x: z.number() }), + execute: async ({ x }: { x: number }) => ({ x }), + }); + + const wrapped = manager.wrapToolsWithPostToolUse( + { test: base }, + { + workspaceId: "w1", + projectPath: "/tmp/project", + workspacePath: "/tmp/project", + runtimeConfig: { type: "local" }, + runtimeTempDir: "/tmp", + // Runtime isn't used by this test extension; provide a minimal stub. + runtime: { + exec: async () => { + throw new Error("not used"); + }, + } as never, + } + ); + + const testTool = wrapped.test; + if (!testTool?.execute) { + throw new Error("wrapped tool missing execute"); + } + + const result = await (testTool.execute as (args: unknown, options: unknown) => Promise)( + { x: 1 }, + { toolCallId: "call-123" } + ); + + expect(result).toEqual({ seen: "call-123", original: { x: 1 } }); + } finally { + await rm(tempRoot, { recursive: true, force: true }); + } + }); + + it("returns original tool result when hook times out", async () => { + const tempRoot = await mkdtemp(path.join(os.tmpdir(), "mux-ext-mgr-timeout-")); + try { + const extDir = path.join(tempRoot, "ext"); + await mkdir(extDir, { recursive: true }); + + await writeFile( + path.join(extDir, "hang.js"), + [ + "module.exports = {", + " onPostToolUse: async () => new Promise(() => {}),", + "};", + "", + ].join("\n"), + "utf-8" + ); + + const manager = new ExtensionManager({ extDir, hookTimeoutMs: 10 }); + + const base = tool({ + description: "test", + inputSchema: z.object({}), + execute: async () => "ok", + }); + + const wrapped = manager.wrapToolsWithPostToolUse( + { test: base }, + { + workspaceId: "w1", + projectPath: "/tmp/project", + workspacePath: "/tmp/project", + runtimeConfig: { type: "local" }, + runtimeTempDir: "/tmp", + runtime: {} as never, + } + ); + + const result = await (wrapped.test!.execute as (args: unknown, options: unknown) => Promise)( + {}, + { toolCallId: "call" } + ); + + expect(result).toBe("ok"); + } finally { + await rm(tempRoot, { recursive: true, force: true }); + } + }); +}); diff --git a/src/node/services/extensions/extensionManager.ts b/src/node/services/extensions/extensionManager.ts new file mode 100644 index 0000000000..662f3e328c --- /dev/null +++ b/src/node/services/extensions/extensionManager.ts @@ -0,0 +1,219 @@ +import type { Tool } from "ai"; +import { log } from "@/node/services/log"; +import { discoverExtensions, type ExtensionInfo } from "@/node/utils/extensions/discovery"; +import { getMuxExtDir } from "@/common/constants/paths"; +import type { Runtime } from "@/node/runtime/Runtime"; +import type { RuntimeConfig } from "@/common/types/runtime"; +import type { Extension, PostToolUseHookPayload, PostToolUseHookReturn } from "./types"; + +function withTimeout(promise: Promise, timeoutMs: number, label: string): Promise { + return new Promise((resolve, reject) => { + const timer = setTimeout(() => { + reject(new Error(`${label} timed out after ${timeoutMs}ms`)); + }, timeoutMs); + + promise + .then((value) => { + clearTimeout(timer); + resolve(value); + }) + .catch((error: unknown) => { + clearTimeout(timer); + reject(error); + }); + }); +} + +function normalizeExtensionExport(exported: unknown): Extension | null { + // Support both: + // - module.exports = { ... } + // - module.exports = { default: { ... } } + const candidate = (() => { + if (typeof exported === "object" && exported !== null) { + const rec = exported as Record; + if ("default" in rec) { + return rec.default; + } + } + return exported; + })(); + + if (typeof candidate !== "object" || candidate === null) { + return null; + } + + // NOTE: We don't validate every property; we just check that it's an object. + // Hook existence is checked at call time. + return candidate as Extension; +} + +function getToolCallIdFromOptions(options: unknown): string { + if (typeof options !== "object" || options === null) { + return "unknown"; + } + const rec = options as Record; + const value = rec.toolCallId; + return typeof value === "string" ? value : "unknown"; +} + +export interface ExtensionManagerOptions { + /** Override extension directory (useful for tests) */ + extDir?: string; + /** Per-extension hook timeout */ + hookTimeoutMs?: number; +} + +interface LoadedExtension { + id: string; + entryPath: string; + extension: Extension; +} + +export interface WrapToolsContext { + workspaceId: string; + projectPath: string; + workspacePath: string; + runtimeConfig: RuntimeConfig; + runtimeTempDir: string; + runtime: Runtime; +} + +export class ExtensionManager { + private readonly extDir: string; + private readonly hookTimeoutMs: number; + + private initPromise: Promise | null = null; + private loaded: LoadedExtension[] = []; + + constructor(options?: ExtensionManagerOptions) { + this.extDir = options?.extDir ?? getMuxExtDir(); + this.hookTimeoutMs = options?.hookTimeoutMs ?? 5000; + } + + private async initializeOnce(): Promise { + if (this.initPromise) { + return this.initPromise; + } + + this.initPromise = (async () => { + const discovered = await discoverExtensions(this.extDir); + if (discovered.length === 0) { + log.debug("No extensions discovered", { extDir: this.extDir }); + return; + } + + log.info(`Loading ${discovered.length} extension(s) from ${this.extDir}`); + + const loaded: LoadedExtension[] = []; + for (const ext of discovered) { + const result = this.loadExtension(ext); + if (result) { + loaded.push(result); + } + } + + this.loaded = loaded; + log.info(`Loaded ${loaded.length}/${discovered.length} extension(s)`); + })(); + + return this.initPromise; + } + + private loadExtension(ext: ExtensionInfo): LoadedExtension | null { + try { + const exported: unknown = require(ext.entryPath); + const normalized = normalizeExtensionExport(exported); + if (!normalized) { + log.warn("Extension did not export an object", { id: ext.id, entryPath: ext.entryPath }); + return null; + } + + return { id: ext.id, entryPath: ext.entryPath, extension: normalized }; + } catch (error) { + const message = error instanceof Error ? error.message : String(error); + log.error("Failed to load extension", { id: ext.id, entryPath: ext.entryPath, error: message }); + return null; + } + } + + private async runPostToolUseHook(payload: PostToolUseHookPayload): Promise { + await this.initializeOnce(); + if (this.loaded.length === 0) { + return payload.result; + } + + let currentResult: unknown = payload.result; + + for (const loaded of this.loaded) { + const handler = loaded.extension.onPostToolUse; + if (!handler) { + continue; + } + + try { + const hookPayload: PostToolUseHookPayload = { ...payload, result: currentResult }; + const returned = await withTimeout( + Promise.resolve(handler(hookPayload)), + this.hookTimeoutMs, + `Extension ${loaded.id} onPostToolUse` + ); + + const cast = returned as PostToolUseHookReturn; + if (typeof cast === "object" && cast !== null && "result" in cast) { + currentResult = (cast as { result: unknown }).result; + } + } catch (error) { + const message = error instanceof Error ? error.message : String(error); + log.error("Extension onPostToolUse failed", { + id: loaded.id, + entryPath: loaded.entryPath, + toolName: payload.toolName, + error: message, + }); + } + } + + return currentResult; + } + + /** + * Wrap tools so that after each tool executes, we call extensions' onPostToolUse hooks. + * + * If an extension returns { result }, that becomes the tool result returned to the model. + */ + wrapToolsWithPostToolUse(tools: Record, ctx: WrapToolsContext): Record { + const wrapped: Record = {}; + + for (const [toolName, tool] of Object.entries(tools)) { + if (!tool.execute) { + wrapped[toolName] = tool; + continue; + } + + const originalExecute = tool.execute; + wrapped[toolName] = { + ...tool, + execute: async (args: Parameters[0], options) => { + const result: unknown = await originalExecute(args, options); + + const toolCallId = getToolCallIdFromOptions(options); + return this.runPostToolUseHook({ + workspaceId: ctx.workspaceId, + projectPath: ctx.projectPath, + workspacePath: ctx.workspacePath, + runtimeConfig: ctx.runtimeConfig, + runtimeTempDir: ctx.runtimeTempDir, + toolName, + toolCallId, + args, + result, + timestamp: Date.now(), + runtime: ctx.runtime, + }); + }, + }; + } + + return wrapped; + } +} diff --git a/src/node/services/extensions/types.ts b/src/node/services/extensions/types.ts new file mode 100644 index 0000000000..49f9aeed32 --- /dev/null +++ b/src/node/services/extensions/types.ts @@ -0,0 +1,31 @@ +import type { Runtime } from "@/node/runtime/Runtime"; +import type { RuntimeConfig } from "@/common/types/runtime"; + +export interface PostToolUseHookPayload { + workspaceId: string; + projectPath: string; + workspacePath: string; + runtimeConfig: RuntimeConfig; + runtimeTempDir: string; + + toolName: string; + toolCallId: string; + args: unknown; + result: unknown; + timestamp: number; + + /** Full Runtime handle for this workspace (local/worktree/ssh). */ + runtime: Runtime; +} + +/** + * Optional return value from onPostToolUse. + * + * If an extension returns { result }, that value becomes the tool result returned + * to the model (and shown in the UI). + */ +export type PostToolUseHookReturn = void | { result: unknown }; + +export interface Extension { + onPostToolUse?: (payload: PostToolUseHookPayload) => Promise | PostToolUseHookReturn; +} diff --git a/src/node/services/serverService.test.ts b/src/node/services/serverService.test.ts index 25676aaac8..5eba36e2b4 100644 --- a/src/node/services/serverService.test.ts +++ b/src/node/services/serverService.test.ts @@ -77,15 +77,13 @@ describe("ServerService.startServer", () => { } test("cleans up server when lockfile acquisition fails", async () => { - // Skip on Windows where chmod doesn't work the same way - if (process.platform === "win32") { - return; - } - const service = new ServerService(); - // Make muxHome read-only so lockfile.acquire() will fail - await fs.chmod(tempDir, 0o444); + // Force lockfile.acquire() to fail in a way that works even when tests run as root. + // ServerLockfile writes to `${muxHome}/server.lock.${pid}.tmp` before renaming. + // By creating that temp path as a directory, fs.writeFile() will fail deterministically. + const tempLockPath = path.join(tempDir, `server.lock.${process.pid}.tmp`); + await fs.mkdir(tempLockPath); let thrownError: Error | null = null; @@ -103,7 +101,7 @@ describe("ServerService.startServer", () => { // Verify that an error was thrown expect(thrownError).not.toBeNull(); - expect(thrownError!.message).toMatch(/EACCES|permission denied/i); + expect(thrownError!.message).toMatch(/EISDIR|is a directory|EACCES|permission denied/i); // Verify the server is NOT left running expect(service.isServerRunning()).toBe(false); diff --git a/src/node/utils/extensions/discovery.test.ts b/src/node/utils/extensions/discovery.test.ts new file mode 100644 index 0000000000..06548f6894 --- /dev/null +++ b/src/node/utils/extensions/discovery.test.ts @@ -0,0 +1,61 @@ +import * as os from "os"; +import * as path from "path"; +import { mkdtemp, mkdir, rm, writeFile } from "fs/promises"; +import { discoverExtensions } from "./discovery"; + +async function writeJson(filePath: string, value: unknown): Promise { + await writeFile(filePath, JSON.stringify(value, null, 2), "utf-8"); +} + +describe("discoverExtensions", () => { + it("discovers .js file extensions and folder extensions with manifest.json", async () => { + const tempRoot = await mkdtemp(path.join(os.tmpdir(), "mux-ext-discovery-")); + try { + const extDir = path.join(tempRoot, "ext"); + await mkdir(extDir, { recursive: true }); + + // File extension + await writeFile( + path.join(extDir, "a-file.js"), + "module.exports = { onPostToolUse() {} };\n", + "utf-8" + ); + + // Folder extension + const folderExtDir = path.join(extDir, "b-folder"); + await mkdir(folderExtDir, { recursive: true }); + await writeJson(path.join(folderExtDir, "manifest.json"), { entrypoint: "index.js" }); + await writeFile( + path.join(folderExtDir, "index.js"), + "module.exports = { onPostToolUse() {} };\n", + "utf-8" + ); + + // Invalid: missing entrypoint + const missingEntrypointDir = path.join(extDir, "c-missing"); + await mkdir(missingEntrypointDir, { recursive: true }); + await writeJson(path.join(missingEntrypointDir, "manifest.json"), { entrypoint: "nope.js" }); + + // Invalid: non-js file + await writeFile(path.join(extDir, "not-js.txt"), "nope", "utf-8"); + + const result = await discoverExtensions(extDir); + expect(result.map((e) => e.id)).toEqual(["a-file", "b-folder"]); + expect(result[0]?.type).toBe("file"); + expect(result[1]?.type).toBe("folder"); + } finally { + await rm(tempRoot, { recursive: true, force: true }); + } + }); + + it("returns [] if extension directory does not exist", async () => { + const tempRoot = await mkdtemp(path.join(os.tmpdir(), "mux-ext-missing-")); + try { + const missingExtDir = path.join(tempRoot, "does-not-exist"); + const result = await discoverExtensions(missingExtDir); + expect(result).toEqual([]); + } finally { + await rm(tempRoot, { recursive: true, force: true }); + } + }); +}); diff --git a/src/node/utils/extensions/discovery.ts b/src/node/utils/extensions/discovery.ts new file mode 100644 index 0000000000..a9f932f549 --- /dev/null +++ b/src/node/utils/extensions/discovery.ts @@ -0,0 +1,113 @@ +import * as path from "path"; +import { readdir, readFile, stat } from "fs/promises"; +import { log } from "@/node/services/log"; + +export interface ExtensionManifest { + entrypoint: string; +} + +export type ExtensionType = "file" | "folder"; + +export interface ExtensionInfo { + /** + * Stable identifier for the extension. + * - file extensions: basename without extension + * - folder extensions: folder name + */ + id: string; + + /** Absolute path to the extension entrypoint JS file */ + entryPath: string; + + /** Absolute directory containing the extension (file's parent or folder itself) */ + rootDir: string; + + type: ExtensionType; +} + +async function fileExists(filePath: string): Promise { + try { + const s = await stat(filePath); + return s.isFile(); + } catch { + return false; + } +} + +async function readManifest(manifestPath: string): Promise { + try { + const raw = await readFile(manifestPath, "utf-8"); + const parsed = JSON.parse(raw) as unknown; + if ( + typeof parsed === "object" && + parsed !== null && + "entrypoint" in parsed && + typeof (parsed as { entrypoint?: unknown }).entrypoint === "string" + ) { + return { entrypoint: (parsed as { entrypoint: string }).entrypoint }; + } + return null; + } catch (error) { + log.warn("Failed to read extension manifest", { manifestPath, error }); + return null; + } +} + +/** + * Discover extensions from a single directory. + * + * Supported layouts: + * - ~/.mux/ext/my-ext.js + * - ~/.mux/ext/my-ext/manifest.json (with { "entrypoint": "index.js" }) + */ +export async function discoverExtensions(extDir: string): Promise { + let entries; + try { + entries = await readdir(extDir, { withFileTypes: true }); + } catch { + return []; + } + + const extensions: ExtensionInfo[] = []; + + for (const entry of entries) { + const abs = path.join(extDir, entry.name); + + if (entry.isFile()) { + if (!entry.name.endsWith(".js") && !entry.name.endsWith(".cjs")) { + continue; + } + + const id = entry.name.replace(/\.(cjs|js)$/u, ""); + extensions.push({ + id, + entryPath: abs, + rootDir: extDir, + type: "file", + }); + continue; + } + + if (entry.isDirectory()) { + const manifestPath = path.join(abs, "manifest.json"); + const manifest = await readManifest(manifestPath); + if (!manifest) continue; + + const entryPath = path.join(abs, manifest.entrypoint); + if (!(await fileExists(entryPath))) { + log.warn("Extension manifest entrypoint missing", { entryPath, manifestPath }); + continue; + } + + extensions.push({ + id: entry.name, + entryPath, + rootDir: abs, + type: "folder", + }); + } + } + + extensions.sort((a, b) => a.id.localeCompare(b.id)); + return extensions; +}