Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions src/gemini/oauth.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import { beforeEach, describe, expect, it, mock } from "bun:test";

import { exchangeGeminiWithVerifier } from "./oauth";

describe("exchangeGeminiWithVerifier", () => {
beforeEach(() => {
mock.restore();
});

it("returns a failure when code is not a string", async () => {
const result = await exchangeGeminiWithVerifier(
{ code: "not-a-string" } as unknown as string,
"verifier",
);

expect(result.type).toBe("failed");
if (result.type === "failed") {
expect(result.error).toContain("Missing authorization code");
}
});

it("returns a failure when verifier is not a string", async () => {
const result = await exchangeGeminiWithVerifier(
"auth-code",
{ verifier: "not-a-string" } as unknown as string,
);

expect(result.type).toBe("failed");
if (result.type === "failed") {
expect(result.error).toContain("Missing PKCE verifier");
}
});

it("allows retry after a failed token exchange", async () => {
const fetchMock = mock(async () => {
return new Response(
JSON.stringify({ error: "internal_error" }),
{ status: 500, statusText: "Internal Server Error" },
);
});
(globalThis as { fetch: typeof fetch }).fetch = fetchMock as unknown as typeof fetch;

const first = await exchangeGeminiWithVerifier("retry-code-1", "retry-verifier-1");
const second = await exchangeGeminiWithVerifier("retry-code-1", "retry-verifier-1");

expect(first.type).toBe("failed");
expect(second.type).toBe("failed");
expect(fetchMock.mock.calls.length).toBe(2);
});

it("marks code consumed after successful exchange", async () => {
let callCount = 0;
const fetchMock = mock(async () => {
callCount += 1;
if (callCount === 1) {
return new Response(
JSON.stringify({
access_token: "access-token",
expires_in: 3600,
refresh_token: "refresh-token",
}),
{ status: 200 },
);
}

return new Response(JSON.stringify({ email: "user@example.com" }), { status: 200 });
});
(globalThis as { fetch: typeof fetch }).fetch = fetchMock as unknown as typeof fetch;

const first = await exchangeGeminiWithVerifier("success-code-1", "success-verifier-1");
const second = await exchangeGeminiWithVerifier("success-code-1", "success-verifier-1");

expect(first.type).toBe("success");
expect(second.type).toBe("failed");
if (second.type === "failed") {
expect(second.error).toContain("already submitted");
}
expect(callCount).toBe(2);
});
});
78 changes: 71 additions & 7 deletions src/gemini/oauth.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { generatePKCE } from "@openauthjs/openauth/pkce";
import { randomBytes } from "node:crypto";
import { createHash, randomBytes } from "node:crypto";

import {
GEMINI_CLIENT_ID,
Expand Down Expand Up @@ -54,6 +54,10 @@ interface GeminiUserInfo {
email?: string;
}

const AUTHORIZATION_CODE_REPLAY_TTL_MS = 10 * 60 * 1000;
const exchangeInFlight = new Map<string, Promise<GeminiTokenExchangeResult>>();
const consumedExchanges = new Map<string, number>();

/**
* Build the Gemini OAuth authorization URL including PKCE.
*/
Expand All @@ -71,8 +75,6 @@ export async function authorizeGemini(): Promise<GeminiAuthorization> {
url.searchParams.set("state", state);
url.searchParams.set("access_type", "offline");
url.searchParams.set("prompt", "consent");
// Add a fragment so any stray terminal glyphs are ignored by the auth server.
url.hash = "opencode";

return {
url: url.toString(),
Expand All @@ -88,14 +90,59 @@ export async function exchangeGeminiWithVerifier(
code: string,
verifier: string,
): Promise<GeminiTokenExchangeResult> {
try {
return await exchangeGeminiWithVerifierInternal(code, verifier);
} catch (error) {
const normalizedCode = typeof code === "string" ? code.trim() : "";
const normalizedVerifier = typeof verifier === "string" ? verifier.trim() : "";
if (isGeminiDebugEnabled() && (typeof code !== "string" || typeof verifier !== "string")) {
logGeminiDebugMessage(
`OAuth exchange received non-string inputs: code=${typeof code} verifier=${typeof verifier}`,
);
}
if (!normalizedCode) {
return {
type: "failed",
error: error instanceof Error ? error.message : "Unknown error",
error: "Missing authorization code in exchange request",
};
}
if (!normalizedVerifier) {
return {
type: "failed",
error: "Missing PKCE verifier for OAuth exchange",
};
}

pruneConsumedExchanges();
const exchangeKey = buildExchangeKey(normalizedCode, normalizedVerifier);
if (consumedExchanges.has(exchangeKey)) {
return {
type: "failed",
error: "Authorization code was already submitted. Start a new login flow.",
};
}

const pending = exchangeInFlight.get(exchangeKey);
if (pending) {
return pending;
}

const exchangePromise = exchangeGeminiWithVerifierInternal(normalizedCode, normalizedVerifier).catch(
(error): GeminiTokenExchangeResult => ({
type: "failed",
error: error instanceof Error ? error.message : "Unknown error",
}),
);
exchangeInFlight.set(exchangeKey, exchangePromise);

let exchangeResult: GeminiTokenExchangeResult | undefined;
try {
exchangeResult = await exchangePromise;
return exchangeResult;
} finally {
exchangeInFlight.delete(exchangeKey);
if (exchangeResult?.type === "success") {
consumedExchanges.set(exchangeKey, Date.now());
}
pruneConsumedExchanges();
}
}

async function exchangeGeminiWithVerifierInternal(
Expand All @@ -104,6 +151,7 @@ async function exchangeGeminiWithVerifierInternal(
): Promise<GeminiTokenExchangeResult> {
if (isGeminiDebugEnabled()) {
logGeminiDebugMessage("OAuth exchange: POST https://oauth2.googleapis.com/token");
logGeminiDebugMessage(`OAuth exchange code fingerprint: ${fingerprint(code)} len=${code.length}`);
}
const tokenResponse = await fetch("https://oauth2.googleapis.com/token", {
method: "POST",
Expand Down Expand Up @@ -175,3 +223,19 @@ async function exchangeGeminiWithVerifierInternal(
email: userInfo.email,
};
}

function buildExchangeKey(code: string, verifier: string): string {
return createHash("sha256").update(code).update("\u0000").update(verifier).digest("hex");
}

function pruneConsumedExchanges(now = Date.now()): void {
for (const [key, consumedAt] of consumedExchanges.entries()) {
if (now - consumedAt > AUTHORIZATION_CODE_REPLAY_TTL_MS) {
consumedExchanges.delete(key);
}
}
}

function fingerprint(value: string): string {
return createHash("sha256").update(value).digest("hex").slice(0, 12);
}
47 changes: 47 additions & 0 deletions src/plugin/oauth-authorize.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import { describe, expect, it } from "bun:test";

import {
normalizeAuthorizationCode,
parseOAuthCallbackInput,
} from "./oauth-authorize";

describe("oauth authorize helpers", () => {
it("parses full callback URLs", () => {
const parsed = parseOAuthCallbackInput(
"http://localhost:8085/oauth2callback?code=4%2Fabc123&state=state-1",
);

expect(parsed.source).toBe("url");
expect(parsed.code).toBe("4/abc123");
expect(parsed.state).toBe("state-1");
});

it("parses query-style callback inputs", () => {
const parsed = parseOAuthCallbackInput("code=4%2Fabc123&state=state-2");

expect(parsed.source).toBe("query");
expect(parsed.code).toBe("4/abc123");
expect(parsed.state).toBe("state-2");
});

it("falls back to raw code when no query markers are present", () => {
const parsed = parseOAuthCallbackInput("4/0AbCDef");

expect(parsed.source).toBe("raw");
expect(parsed.code).toBe("4/0AbCDef");
});

it("normalizes encoded authorization codes", () => {
const singleEncoded = normalizeAuthorizationCode("4%2Fabc");
const doubleEncoded = normalizeAuthorizationCode("4%252Fabc");

expect(singleEncoded).toBe("4/abc");
expect(doubleEncoded).toBe("4/abc");
});

it("rejects malformed authorization codes", () => {
expect(normalizeAuthorizationCode(" ")).toBeUndefined();
expect(normalizeAuthorizationCode("4/abc 123")).toBeUndefined();
expect(normalizeAuthorizationCode("4/abc\n123")).toBeUndefined();
});
});
Loading