diff --git a/packages/opencode/src/mcp/index.ts b/packages/opencode/src/mcp/index.ts index 023eab520ee1..7f6141ca3da2 100644 --- a/packages/opencode/src/mcp/index.ts +++ b/packages/opencode/src/mcp/index.ts @@ -22,7 +22,7 @@ import { NamedError } from "@opencode-ai/core/util/error" import { InstallationVersion } from "@opencode-ai/core/installation/version" import { withTimeout } from "@/util/timeout" import { FSUtil } from "@opencode-ai/core/fs-util" -import { McpOAuthProvider, OAUTH_CALLBACK_PATH } from "./oauth-provider" +import { McpOAuthPendingProvider, McpOAuthProvider, OAUTH_CALLBACK_PATH } from "./oauth-provider" import { McpOAuthCallback } from "./oauth-callback" import { McpAuth } from "./auth" import { EventV2Bridge } from "@/event-v2-bridge" @@ -109,7 +109,7 @@ export type Status = Schema.Schema.Type // Store transports for OAuth servers to allow finishing auth type TransportWithAuth = StreamableHTTPClientTransport | SSEClientTransport -const pendingOAuthTransports = new Map() +const pendingOAuthTransports = new Map() // Prompt cache types type PromptInfo = Awaited>["prompts"][number] @@ -301,7 +301,7 @@ export const layer = Layer.effect( }) .pipe(Effect.ignore, Effect.as(undefined)) } else { - pendingOAuthTransports.set(key, transport) + pendingOAuthTransports.set(key, { transport }) lastStatus = { status: "needs_auth" as const } return events .publish(TuiEvent.ToastShow, { @@ -819,7 +819,7 @@ export const layer = Layer.effect( .join("") yield* auth.updateOAuthState(mcpName, oauthState) let capturedUrl: URL | undefined - const authProvider = new McpOAuthProvider( + const authProvider = new McpOAuthPendingProvider( mcpName, mcpConfig.url, { @@ -845,15 +845,16 @@ export const layer = Layer.effect( return yield* Effect.tryPromise({ try: () => { const client = createClient(directory) - return client - .connect(transport) - .then(() => ({ authorizationUrl: "", oauthState, client }) satisfies AuthResult) + return client.connect(transport).then(async () => { + await authProvider.commit() + return { authorizationUrl: "", oauthState, client } satisfies AuthResult + }) }, catch: (error) => error, }).pipe( Effect.catch((error) => { if (error instanceof UnauthorizedError && capturedUrl) { - pendingOAuthTransports.set(mcpName, transport) + pendingOAuthTransports.set(mcpName, { transport, provider: authProvider }) return Effect.succeed({ authorizationUrl: capturedUrl.toString(), oauthState } satisfies AuthResult) } return Effect.die(error) @@ -924,11 +925,11 @@ export const layer = Layer.effect( const finishAuth = Effect.fn("MCP.finishAuth")(function* (mcpName: string, authorizationCode: string) { yield* requireMcpConfig(mcpName) - const transport = pendingOAuthTransports.get(mcpName) - if (!transport) throw new Error(`No pending OAuth flow for MCP server: ${mcpName}`) + const pending = pendingOAuthTransports.get(mcpName) + if (!pending) throw new Error(`No pending OAuth flow for MCP server: ${mcpName}`) const result = yield* Effect.tryPromise({ - try: () => transport.finishAuth(authorizationCode).then(() => true as const), + try: () => pending.transport.finishAuth(authorizationCode).then(() => true as const), catch: (error) => { return error }, @@ -938,6 +939,7 @@ export const layer = Layer.effect( return { status: "failed", error: "OAuth completion failed" } satisfies Status } + yield* Effect.promise(() => pending.provider?.commit() ?? Promise.resolve()) yield* auth.clearCodeVerifier(mcpName) pendingOAuthTransports.delete(mcpName) diff --git a/packages/opencode/src/mcp/oauth-provider.ts b/packages/opencode/src/mcp/oauth-provider.ts index aa29777f5447..596bfe1d551f 100644 --- a/packages/opencode/src/mcp/oauth-provider.ts +++ b/packages/opencode/src/mcp/oauth-provider.ts @@ -25,11 +25,11 @@ export interface McpOAuthCallbacks { export class McpOAuthProvider implements OAuthClientProvider { constructor( - private mcpName: string, - private serverUrl: string, - private config: McpOAuthConfig, + protected mcpName: string, + protected serverUrl: string, + protected config: McpOAuthConfig, private callbacks: McpOAuthCallbacks, - private auth: McpAuth.Interface, + protected auth: McpAuth.Interface, ) {} get redirectUrl(): string { @@ -53,7 +53,6 @@ export class McpOAuthProvider implements OAuthClientProvider { } async clientInformation(): Promise { - // Check config first (pre-registered client) if (this.config.clientId) { return { client_id: this.config.clientId, @@ -164,10 +163,7 @@ export class McpOAuthProvider implements OAuthClientProvider { async invalidateCredentials(type: "all" | "client" | "tokens"): Promise { const entry = await Effect.runPromise(this.auth.get(this.mcpName)) - if (!entry) { - return - } - + if (!entry) return switch (type) { case "all": await Effect.runPromise(this.auth.remove(this.mcpName)) @@ -184,6 +180,63 @@ export class McpOAuthProvider implements OAuthClientProvider { } } +export class McpOAuthPendingProvider extends McpOAuthProvider { + private pendingClientInfo?: OAuthClientInformationFull + private pendingTokens?: OAuthTokens + + override async clientInformation(): Promise { + if (!this.config.clientId) return this.pendingClientInfo + return { + client_id: this.config.clientId, + client_secret: this.config.clientSecret, + } + } + + override async saveClientInformation(info: OAuthClientInformationFull): Promise { + this.pendingClientInfo = info + } + + override async tokens(): Promise { + return this.pendingTokens + } + + override async saveTokens(tokens: OAuthTokens): Promise { + this.pendingTokens = tokens + } + + override async invalidateCredentials(type: "all" | "client" | "tokens"): Promise { + if (type === "all" || type === "client") this.pendingClientInfo = undefined + if (type === "all" || type === "tokens") this.pendingTokens = undefined + } + + async commit(): Promise { + if (!this.pendingTokens) return + await Effect.runPromise( + this.auth.set( + this.mcpName, + { + tokens: { + accessToken: this.pendingTokens.access_token, + refreshToken: this.pendingTokens.refresh_token, + expiresAt: this.pendingTokens.expires_in ? Date.now() / 1000 + this.pendingTokens.expires_in : undefined, + scope: this.pendingTokens.scope, + }, + clientInfo: + this.pendingClientInfo && !this.config.clientId + ? { + clientId: this.pendingClientInfo.client_id, + clientSecret: this.pendingClientInfo.client_secret, + clientIdIssuedAt: this.pendingClientInfo.client_id_issued_at, + clientSecretExpiresAt: this.pendingClientInfo.client_secret_expires_at, + } + : undefined, + }, + this.serverUrl, + ), + ) + } +} + export { OAUTH_CALLBACK_PORT, OAUTH_CALLBACK_PATH } /** diff --git a/packages/opencode/test/mcp/oauth-auto-connect.test.ts b/packages/opencode/test/mcp/oauth-auto-connect.test.ts index 7a87b9a25800..43b10929e4bb 100644 --- a/packages/opencode/test/mcp/oauth-auto-connect.test.ts +++ b/packages/opencode/test/mcp/oauth-auto-connect.test.ts @@ -23,6 +23,8 @@ let simulateAuthFlow = true let connectSucceedsImmediately = false let serverCapabilities: { tools?: object; resources?: object } = { tools: {} } let listToolsCalls = 0 +let finishAuthFails = false +let finishAuthStoresCredentials = false // Mock the transport constructors to simulate OAuth auto-auth on 401 void mock.module("@modelcontextprotocol/sdk/client/streamableHttp.js", () => ({ @@ -32,6 +34,10 @@ void mock.module("@modelcontextprotocol/sdk/client/streamableHttp.js", () => ({ state?: () => Promise redirectToAuthorization?: (url: URL) => Promise saveCodeVerifier?: (v: string) => Promise + tokens?: () => Promise<{ access_token: string } | undefined> + clientInformation?: () => Promise<{ client_id: string } | undefined> + saveClientInformation?: (info: { client_id: string; client_secret?: string }) => Promise + saveTokens?: (tokens: { access_token: string; token_type: string }) => Promise } | undefined constructor(url: URL, options?: { authProvider?: unknown }) { @@ -49,6 +55,8 @@ void mock.module("@modelcontextprotocol/sdk/client/streamableHttp.js", () => ({ // It calls auth() which eventually calls provider.state(), then // provider.redirectToAuthorization(), then throws UnauthorizedError. if (simulateAuthFlow && this.authProvider) { + if (await this.authProvider.tokens?.()) throw new MockUnauthorizedError() + if (await this.authProvider.clientInformation?.()) throw new MockUnauthorizedError() // The SDK calls provider.state() to get the OAuth state parameter if (this.authProvider.state) { await this.authProvider.state() @@ -65,7 +73,14 @@ void mock.module("@modelcontextprotocol/sdk/client/streamableHttp.js", () => ({ } throw new MockUnauthorizedError() } - async finishAuth(_code: string) {} + async finishAuth(_code: string) { + if (finishAuthFails) throw new Error("Token exchange failed") + if (finishAuthStoresCredentials) { + await this.authProvider?.saveClientInformation?.({ client_id: "replacement-client" }) + await this.authProvider?.saveTokens?.({ access_token: "replacement-token", token_type: "Bearer" }) + } + } + async close() {} }, })) @@ -125,6 +140,8 @@ beforeEach(() => { connectSucceedsImmediately = false serverCapabilities = { tools: {} } listToolsCalls = 0 + finishAuthFails = false + finishAuthStoresCredentials = false }) // Import modules after mocking @@ -133,6 +150,7 @@ const { EventV2Bridge } = await import("../../src/event-v2-bridge") const { Config } = await import("../../src/config/config") const { McpAuth } = await import("../../src/mcp/auth") const { McpOAuthProvider } = await import("../../src/mcp/oauth-provider") +const { McpOAuthCallback } = await import("../../src/mcp/oauth-callback") const { FSUtil } = await import("@opencode-ai/core/fs-util") const { CrossSpawnSpawner } = await import("@opencode-ai/core/cross-spawn-spawner") @@ -227,6 +245,59 @@ mcpTest.instance("state() returns existing state when one is saved", () => }), ) +mcpTest.instance( + "failed reauthentication preserves existing credentials", + () => + Effect.gen(function* () { + yield* Effect.addFinalizer(() => Effect.promise(() => McpOAuthCallback.stop()).pipe(Effect.ignore)) + const mcp = yield* MCP.Service + const auth = yield* McpAuth.Service + const name = "test-reauth-failure" + const url = "https://example.com/mcp" + const clientInfo = { clientId: "dynamic-client", clientSecret: "dynamic-secret" } + + yield* auth.updateClientInfo(name, clientInfo, url) + yield* auth.updateTokens(name, { accessToken: "working-token" }, url) + expect((yield* mcp.startAuth(name)).authorizationUrl).toContain("https://auth.example.com/authorize") + finishAuthFails = true + + expect(yield* mcp.finishAuth(name, "invalid-code")).toEqual({ + status: "failed", + error: "OAuth completion failed", + }) + const entry = yield* auth.get(name) + expect(entry?.tokens?.accessToken).toBe("working-token") + expect(entry?.clientInfo).toEqual(clientInfo) + }), + { config: config("test-reauth-failure") }, +) + +mcpTest.instance( + "successful reauthentication commits replacement credentials", + () => + Effect.gen(function* () { + yield* Effect.addFinalizer(() => Effect.promise(() => McpOAuthCallback.stop()).pipe(Effect.ignore)) + const mcp = yield* MCP.Service + const auth = yield* McpAuth.Service + const name = "test-reauth-success" + const url = "https://example.com/mcp" + + yield* auth.updateClientInfo(name, { clientId: "old-client" }, url) + yield* auth.updateTokens(name, { accessToken: "old-token" }, url) + expect((yield* mcp.startAuth(name)).authorizationUrl).toContain("https://auth.example.com/authorize") + expect((yield* auth.get(name))?.tokens?.accessToken).toBe("old-token") + finishAuthStoresCredentials = true + connectSucceedsImmediately = true + + expect((yield* mcp.finishAuth(name, "valid-code")).status).toBe("connected") + const entry = yield* auth.get(name) + expect(entry?.tokens?.accessToken).toBe("replacement-token") + expect(entry?.clientInfo?.clientId).toBe("replacement-client") + expect(entry?.serverUrl).toBe(url) + }), + { config: config("test-reauth-success") }, +) + mcpTest.instance( "auth status only reports credentials stored for the configured server URL", () =>