Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ runs/
__pycache__

tmpfork
.cmux/*.tmp.*
.mux/*.tmp.*
.mux-agent-cli
storybook-static/
*.tgz
Expand Down
11 changes: 11 additions & 0 deletions src/common/constants/paths.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,17 @@ export function getMuxSessionsDir(rootDir?: string): string {
return join(root, "sessions");
}

/**
* Get the directory where user-installed extensions live.
* Example: ~/.mux/ext
*
* @param rootDir - Optional root directory (defaults to getMuxHome())
*/
export function getMuxExtDir(rootDir?: string): string {
const root = rootDir ?? getMuxHome();
return join(root, "ext");
}

/**
* Get the directory where plan files are stored.
* Example: ~/.mux/plans/workspace-id.md
Expand Down
18 changes: 15 additions & 3 deletions src/node/services/aiService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import {
import type { MuxMessage, MuxTextPart } from "@/common/types/message";
import { createMuxMessage } from "@/common/types/message";
import type { Config, ProviderConfig } from "@/node/config";
import { ExtensionManager } from "./extensions/extensionManager";
import { StreamManager } from "./streamManager";
import type { InitStateManager } from "./initStateManager";
import type { SendMessageError } from "@/common/types/errors";
Expand Down Expand Up @@ -278,6 +279,7 @@ export class AIService extends EventEmitter {
private readonly initStateManager: InitStateManager;
private readonly mockModeEnabled: boolean;
private readonly mockScenarioPlayer?: MockScenarioPlayer;
private readonly extensionManager: ExtensionManager;
private readonly backgroundProcessManager?: BackgroundProcessManager;

constructor(
Expand All @@ -297,6 +299,7 @@ export class AIService extends EventEmitter {
this.partialService = partialService;
this.initStateManager = initStateManager;
this.backgroundProcessManager = backgroundProcessManager;
this.extensionManager = new ExtensionManager();
this.streamManager = new StreamManager(historyService, partialService, sessionUsageService);
void this.ensureSessionsDir();
this.setupStreamEventForwarding();
Expand Down Expand Up @@ -1172,10 +1175,19 @@ export class AIService extends EventEmitter {
// Apply tool policy to filter tools (if policy provided)
const tools = applyToolPolicy(allTools, toolPolicy);

const toolsWithExtensions = this.extensionManager.wrapToolsWithPostToolUse(tools, {
workspaceId,
projectPath: metadata.projectPath,
workspacePath,
runtimeConfig: metadata.runtimeConfig,
runtimeTempDir,
runtime,
});

log.info("AIService.streamMessage: tool configuration", {
workspaceId,
model: modelString,
toolNames: Object.keys(tools),
toolNames: Object.keys(toolsWithExtensions),
hasToolPolicy: Boolean(toolPolicy),
});

Expand Down Expand Up @@ -1336,7 +1348,7 @@ export class AIService extends EventEmitter {
systemMessage,
messages: finalMessages,
tools: Object.fromEntries(
Object.entries(tools).map(([name, tool]) => [
Object.entries(toolsWithExtensions).map(([name, tool]) => [
name,
{
description: tool.description,
Expand Down Expand Up @@ -1365,7 +1377,7 @@ export class AIService extends EventEmitter {
systemMessage,
runtime,
abortSignal,
tools,
toolsWithExtensions,
{
systemMessageTokens,
timestamp: Date.now(),
Expand Down
114 changes: 114 additions & 0 deletions src/node/services/extensions/extensionManager.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import * as os from "os";
import * as path from "path";
import { mkdtemp, mkdir, rm, writeFile } from "fs/promises";
import { z } from "zod";
import { tool } from "ai";
import { ExtensionManager } from "./extensionManager";

describe("ExtensionManager", () => {
it("wrapToolsWithPostToolUse lets extensions transform tool results", async () => {
const tempRoot = await mkdtemp(path.join(os.tmpdir(), "mux-ext-mgr-"));
try {
const extDir = path.join(tempRoot, "ext");
await mkdir(extDir, { recursive: true });

await writeFile(
path.join(extDir, "transform.js"),
[
"module.exports = {",
" onPostToolUse: (payload) => ({ result: { seen: payload.toolCallId, original: payload.result } }),",
"};",
"",
].join("\n"),
"utf-8"
);

const manager = new ExtensionManager({ extDir, hookTimeoutMs: 1000 });

const base = tool({
description: "test",
inputSchema: z.object({ x: z.number() }),
execute: async ({ x }: { x: number }) => ({ x }),
});

const wrapped = manager.wrapToolsWithPostToolUse(
{ test: base },
{
workspaceId: "w1",
projectPath: "/tmp/project",
workspacePath: "/tmp/project",
runtimeConfig: { type: "local" },
runtimeTempDir: "/tmp",
// Runtime isn't used by this test extension; provide a minimal stub.
runtime: {
exec: async () => {
throw new Error("not used");
},
} as never,
}
);

const testTool = wrapped.test;
if (!testTool?.execute) {
throw new Error("wrapped tool missing execute");
}

const result = await (testTool.execute as (args: unknown, options: unknown) => Promise<unknown>)(
{ x: 1 },
{ toolCallId: "call-123" }
);

expect(result).toEqual({ seen: "call-123", original: { x: 1 } });
} finally {
await rm(tempRoot, { recursive: true, force: true });
}
});

it("returns original tool result when hook times out", async () => {
const tempRoot = await mkdtemp(path.join(os.tmpdir(), "mux-ext-mgr-timeout-"));
try {
const extDir = path.join(tempRoot, "ext");
await mkdir(extDir, { recursive: true });

await writeFile(
path.join(extDir, "hang.js"),
[
"module.exports = {",
" onPostToolUse: async () => new Promise(() => {}),",
"};",
"",
].join("\n"),
"utf-8"
);

const manager = new ExtensionManager({ extDir, hookTimeoutMs: 10 });

const base = tool({
description: "test",
inputSchema: z.object({}),
execute: async () => "ok",
});

const wrapped = manager.wrapToolsWithPostToolUse(
{ test: base },
{
workspaceId: "w1",
projectPath: "/tmp/project",
workspacePath: "/tmp/project",
runtimeConfig: { type: "local" },
runtimeTempDir: "/tmp",
runtime: {} as never,
}
);

const result = await (wrapped.test!.execute as (args: unknown, options: unknown) => Promise<unknown>)(
{},
{ toolCallId: "call" }
);

expect(result).toBe("ok");
} finally {
await rm(tempRoot, { recursive: true, force: true });
}
});
});
Loading
Loading