Skip to content
24 changes: 13 additions & 11 deletions packages/opencode/src/mcp/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import { NamedError } from "@opencode-ai/core/util/error"
import { InstallationVersion } from "@opencode-ai/core/installation/version"
import { withTimeout } from "@/util/timeout"
import { FSUtil } from "@opencode-ai/core/fs-util"
import { McpOAuthProvider, OAUTH_CALLBACK_PATH } from "./oauth-provider"
import { McpOAuthPendingProvider, McpOAuthProvider, OAUTH_CALLBACK_PATH } from "./oauth-provider"
import { McpOAuthCallback } from "./oauth-callback"
import { McpAuth } from "./auth"
import { EventV2Bridge } from "@/event-v2-bridge"
Expand Down Expand Up @@ -109,7 +109,7 @@ export type Status = Schema.Schema.Type<typeof Status>

// Store transports for OAuth servers to allow finishing auth
type TransportWithAuth = StreamableHTTPClientTransport | SSEClientTransport
const pendingOAuthTransports = new Map<string, TransportWithAuth>()
const pendingOAuthTransports = new Map<string, { transport: TransportWithAuth; provider?: McpOAuthPendingProvider }>()

// Prompt cache types
type PromptInfo = Awaited<ReturnType<MCPClient["listPrompts"]>>["prompts"][number]
Expand Down Expand Up @@ -301,7 +301,7 @@ export const layer = Layer.effect(
})
.pipe(Effect.ignore, Effect.as(undefined))
} else {
pendingOAuthTransports.set(key, transport)
pendingOAuthTransports.set(key, { transport })
lastStatus = { status: "needs_auth" as const }
return events
.publish(TuiEvent.ToastShow, {
Expand Down Expand Up @@ -819,7 +819,7 @@ export const layer = Layer.effect(
.join("")
yield* auth.updateOAuthState(mcpName, oauthState)
let capturedUrl: URL | undefined
const authProvider = new McpOAuthProvider(
const authProvider = new McpOAuthPendingProvider(
mcpName,
mcpConfig.url,
{
Expand All @@ -845,15 +845,16 @@ export const layer = Layer.effect(
return yield* Effect.tryPromise({
try: () => {
const client = createClient(directory)
return client
.connect(transport)
.then(() => ({ authorizationUrl: "", oauthState, client }) satisfies AuthResult)
return client.connect(transport).then(async () => {
await authProvider.commit()
return { authorizationUrl: "", oauthState, client } satisfies AuthResult
})
},
catch: (error) => error,
}).pipe(
Effect.catch((error) => {
if (error instanceof UnauthorizedError && capturedUrl) {
pendingOAuthTransports.set(mcpName, transport)
pendingOAuthTransports.set(mcpName, { transport, provider: authProvider })
return Effect.succeed({ authorizationUrl: capturedUrl.toString(), oauthState } satisfies AuthResult)
}
return Effect.die(error)
Expand Down Expand Up @@ -924,11 +925,11 @@ export const layer = Layer.effect(

const finishAuth = Effect.fn("MCP.finishAuth")(function* (mcpName: string, authorizationCode: string) {
yield* requireMcpConfig(mcpName)
const transport = pendingOAuthTransports.get(mcpName)
if (!transport) throw new Error(`No pending OAuth flow for MCP server: ${mcpName}`)
const pending = pendingOAuthTransports.get(mcpName)
if (!pending) throw new Error(`No pending OAuth flow for MCP server: ${mcpName}`)

const result = yield* Effect.tryPromise({
try: () => transport.finishAuth(authorizationCode).then(() => true as const),
try: () => pending.transport.finishAuth(authorizationCode).then(() => true as const),
catch: (error) => {
return error
},
Expand All @@ -938,6 +939,7 @@ export const layer = Layer.effect(
return { status: "failed", error: "OAuth completion failed" } satisfies Status
}

yield* Effect.promise(() => pending.provider?.commit() ?? Promise.resolve())
yield* auth.clearCodeVerifier(mcpName)
pendingOAuthTransports.delete(mcpName)

Expand Down
71 changes: 62 additions & 9 deletions packages/opencode/src/mcp/oauth-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ export interface McpOAuthCallbacks {

export class McpOAuthProvider implements OAuthClientProvider {
constructor(
private mcpName: string,
private serverUrl: string,
private config: McpOAuthConfig,
protected mcpName: string,
protected serverUrl: string,
protected config: McpOAuthConfig,
private callbacks: McpOAuthCallbacks,
private auth: McpAuth.Interface,
protected auth: McpAuth.Interface,
) {}

get redirectUrl(): string {
Expand All @@ -53,7 +53,6 @@ export class McpOAuthProvider implements OAuthClientProvider {
}

async clientInformation(): Promise<OAuthClientInformation | undefined> {
// Check config first (pre-registered client)
if (this.config.clientId) {
return {
client_id: this.config.clientId,
Expand Down Expand Up @@ -164,10 +163,7 @@ export class McpOAuthProvider implements OAuthClientProvider {

async invalidateCredentials(type: "all" | "client" | "tokens"): Promise<void> {
const entry = await Effect.runPromise(this.auth.get(this.mcpName))
if (!entry) {
return
}

if (!entry) return
switch (type) {
case "all":
await Effect.runPromise(this.auth.remove(this.mcpName))
Expand All @@ -184,6 +180,63 @@ export class McpOAuthProvider implements OAuthClientProvider {
}
}

export class McpOAuthPendingProvider extends McpOAuthProvider {
private pendingClientInfo?: OAuthClientInformationFull
private pendingTokens?: OAuthTokens

override async clientInformation(): Promise<OAuthClientInformation | undefined> {
if (!this.config.clientId) return this.pendingClientInfo
return {
client_id: this.config.clientId,
client_secret: this.config.clientSecret,
}
}

override async saveClientInformation(info: OAuthClientInformationFull): Promise<void> {
this.pendingClientInfo = info
}

override async tokens(): Promise<OAuthTokens | undefined> {
return this.pendingTokens
}

override async saveTokens(tokens: OAuthTokens): Promise<void> {
this.pendingTokens = tokens
}

override async invalidateCredentials(type: "all" | "client" | "tokens"): Promise<void> {
if (type === "all" || type === "client") this.pendingClientInfo = undefined
if (type === "all" || type === "tokens") this.pendingTokens = undefined
}

async commit(): Promise<void> {
if (!this.pendingTokens) return
await Effect.runPromise(
this.auth.set(
this.mcpName,
{
tokens: {
accessToken: this.pendingTokens.access_token,
refreshToken: this.pendingTokens.refresh_token,
expiresAt: this.pendingTokens.expires_in ? Date.now() / 1000 + this.pendingTokens.expires_in : undefined,
scope: this.pendingTokens.scope,
},
clientInfo:
this.pendingClientInfo && !this.config.clientId
? {
clientId: this.pendingClientInfo.client_id,
clientSecret: this.pendingClientInfo.client_secret,
clientIdIssuedAt: this.pendingClientInfo.client_id_issued_at,
clientSecretExpiresAt: this.pendingClientInfo.client_secret_expires_at,
}
: undefined,
},
this.serverUrl,
),
)
}
}

export { OAUTH_CALLBACK_PORT, OAUTH_CALLBACK_PATH }

/**
Expand Down
73 changes: 72 additions & 1 deletion packages/opencode/test/mcp/oauth-auto-connect.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ let simulateAuthFlow = true
let connectSucceedsImmediately = false
let serverCapabilities: { tools?: object; resources?: object } = { tools: {} }
let listToolsCalls = 0
let finishAuthFails = false
let finishAuthStoresCredentials = false

// Mock the transport constructors to simulate OAuth auto-auth on 401
void mock.module("@modelcontextprotocol/sdk/client/streamableHttp.js", () => ({
Expand All @@ -32,6 +34,10 @@ void mock.module("@modelcontextprotocol/sdk/client/streamableHttp.js", () => ({
state?: () => Promise<string>
redirectToAuthorization?: (url: URL) => Promise<void>
saveCodeVerifier?: (v: string) => Promise<void>
tokens?: () => Promise<{ access_token: string } | undefined>
clientInformation?: () => Promise<{ client_id: string } | undefined>
saveClientInformation?: (info: { client_id: string; client_secret?: string }) => Promise<void>
saveTokens?: (tokens: { access_token: string; token_type: string }) => Promise<void>
}
| undefined
constructor(url: URL, options?: { authProvider?: unknown }) {
Expand All @@ -49,6 +55,8 @@ void mock.module("@modelcontextprotocol/sdk/client/streamableHttp.js", () => ({
// It calls auth() which eventually calls provider.state(), then
// provider.redirectToAuthorization(), then throws UnauthorizedError.
if (simulateAuthFlow && this.authProvider) {
if (await this.authProvider.tokens?.()) throw new MockUnauthorizedError()
if (await this.authProvider.clientInformation?.()) throw new MockUnauthorizedError()
// The SDK calls provider.state() to get the OAuth state parameter
if (this.authProvider.state) {
await this.authProvider.state()
Expand All @@ -65,7 +73,14 @@ void mock.module("@modelcontextprotocol/sdk/client/streamableHttp.js", () => ({
}
throw new MockUnauthorizedError()
}
async finishAuth(_code: string) {}
async finishAuth(_code: string) {
if (finishAuthFails) throw new Error("Token exchange failed")
if (finishAuthStoresCredentials) {
await this.authProvider?.saveClientInformation?.({ client_id: "replacement-client" })
await this.authProvider?.saveTokens?.({ access_token: "replacement-token", token_type: "Bearer" })
}
}
async close() {}
},
}))

Expand Down Expand Up @@ -125,6 +140,8 @@ beforeEach(() => {
connectSucceedsImmediately = false
serverCapabilities = { tools: {} }
listToolsCalls = 0
finishAuthFails = false
finishAuthStoresCredentials = false
})

// Import modules after mocking
Expand All @@ -133,6 +150,7 @@ const { EventV2Bridge } = await import("../../src/event-v2-bridge")
const { Config } = await import("../../src/config/config")
const { McpAuth } = await import("../../src/mcp/auth")
const { McpOAuthProvider } = await import("../../src/mcp/oauth-provider")
const { McpOAuthCallback } = await import("../../src/mcp/oauth-callback")
const { FSUtil } = await import("@opencode-ai/core/fs-util")
const { CrossSpawnSpawner } = await import("@opencode-ai/core/cross-spawn-spawner")

Expand Down Expand Up @@ -227,6 +245,59 @@ mcpTest.instance("state() returns existing state when one is saved", () =>
}),
)

mcpTest.instance(
"failed reauthentication preserves existing credentials",
() =>
Effect.gen(function* () {
yield* Effect.addFinalizer(() => Effect.promise(() => McpOAuthCallback.stop()).pipe(Effect.ignore))
const mcp = yield* MCP.Service
const auth = yield* McpAuth.Service
const name = "test-reauth-failure"
const url = "https://example.com/mcp"
const clientInfo = { clientId: "dynamic-client", clientSecret: "dynamic-secret" }

yield* auth.updateClientInfo(name, clientInfo, url)
yield* auth.updateTokens(name, { accessToken: "working-token" }, url)
expect((yield* mcp.startAuth(name)).authorizationUrl).toContain("https://auth.example.com/authorize")
finishAuthFails = true

expect(yield* mcp.finishAuth(name, "invalid-code")).toEqual({
status: "failed",
error: "OAuth completion failed",
})
const entry = yield* auth.get(name)
expect(entry?.tokens?.accessToken).toBe("working-token")
expect(entry?.clientInfo).toEqual(clientInfo)
}),
{ config: config("test-reauth-failure") },
)

mcpTest.instance(
"successful reauthentication commits replacement credentials",
() =>
Effect.gen(function* () {
yield* Effect.addFinalizer(() => Effect.promise(() => McpOAuthCallback.stop()).pipe(Effect.ignore))
const mcp = yield* MCP.Service
const auth = yield* McpAuth.Service
const name = "test-reauth-success"
const url = "https://example.com/mcp"

yield* auth.updateClientInfo(name, { clientId: "old-client" }, url)
yield* auth.updateTokens(name, { accessToken: "old-token" }, url)
expect((yield* mcp.startAuth(name)).authorizationUrl).toContain("https://auth.example.com/authorize")
expect((yield* auth.get(name))?.tokens?.accessToken).toBe("old-token")
finishAuthStoresCredentials = true
connectSucceedsImmediately = true

expect((yield* mcp.finishAuth(name, "valid-code")).status).toBe("connected")
const entry = yield* auth.get(name)
expect(entry?.tokens?.accessToken).toBe("replacement-token")
expect(entry?.clientInfo?.clientId).toBe("replacement-client")
expect(entry?.serverUrl).toBe(url)
}),
{ config: config("test-reauth-success") },
)

mcpTest.instance(
"auth status only reports credentials stored for the configured server URL",
() =>
Expand Down
Loading