11import type { Client } from "@modelcontextprotocol/sdk/client/index.js" ;
2+ import escapeHtml from "escape-html" ;
23import type { RequestOptions } from "@modelcontextprotocol/sdk/shared/protocol.js" ;
34import type {
45 CallToolRequest ,
@@ -44,6 +45,13 @@ export type MCPServerOptions = {
4445 } ;
4546} ;
4647
48+ /**
49+ * Result of an OAuth callback request
50+ */
51+ export type MCPOAuthCallbackResult =
52+ | { serverId : string ; authSuccess : true ; authError ?: undefined }
53+ | { serverId : string ; authSuccess : false ; authError : string } ;
54+
4755/**
4856 * Options for registering an MCP server
4957 */
@@ -187,6 +195,19 @@ export class MCPClientManager {
187195 ) ;
188196 }
189197
198+ private failConnection (
199+ serverId : string ,
200+ error : string
201+ ) : MCPOAuthCallbackResult {
202+ this . clearServerAuthUrl ( serverId ) ;
203+ if ( this . mcpConnections [ serverId ] ) {
204+ this . mcpConnections [ serverId ] . connectionState = MCPConnectionState . FAILED ;
205+ this . mcpConnections [ serverId ] . connectionError = error ;
206+ }
207+ this . _onServerStateChanged . fire ( ) ;
208+ return { serverId, authSuccess : false , authError : error } ;
209+ }
210+
190211 jsonSchema : typeof import ( "ai" ) . jsonSchema | undefined ;
191212
192213 /**
@@ -663,19 +684,19 @@ export class MCPClientManager {
663684 return servers . some ( ( server ) => server . id === serverId ) ;
664685 }
665686
666- async handleCallbackRequest ( req : Request ) {
687+ async handleCallbackRequest ( req : Request ) : Promise < MCPOAuthCallbackResult > {
667688 const url = new URL ( req . url ) ;
668689 const code = url . searchParams . get ( "code" ) ;
669690 const state = url . searchParams . get ( "state" ) ;
670691 const error = url . searchParams . get ( "error" ) ;
671692 const errorDescription = url . searchParams . get ( "error_description" ) ;
672693
694+ // Early validation - these throw because we can't identify the connection
673695 if ( ! state ) {
674696 throw new Error ( "Unauthorized: no state provided" ) ;
675697 }
676698
677699 const serverId = this . extractServerIdFromState ( state ) ;
678-
679700 if ( ! serverId ) {
680701 throw new Error (
681702 "No serverId found in state parameter. Expected format: {nonce}.{serverId}"
@@ -684,7 +705,6 @@ export class MCPClientManager {
684705
685706 const servers = this . getServersFromStorage ( ) ;
686707 const serverExists = servers . some ( ( server ) => server . id === serverId ) ;
687-
688708 if ( ! serverExists ) {
689709 throw new Error (
690710 `No server found with id "${ serverId } ". Was the request matched with \`isCallbackRequest()\`?`
@@ -695,89 +715,61 @@ export class MCPClientManager {
695715 throw new Error ( `Could not find serverId: ${ serverId } ` ) ;
696716 }
697717
718+ // We have a valid connection - all errors from here should fail the connection
698719 const conn = this . mcpConnections [ serverId ] ;
699- if ( ! conn . options . transport . authProvider ) {
700- throw new Error (
701- "Trying to finalize authentication for a server connection without an authProvider"
702- ) ;
703- }
704720
705- const authProvider = conn . options . transport . authProvider ;
706- authProvider . serverId = serverId ;
721+ try {
722+ if ( ! conn . options . transport . authProvider ) {
723+ throw new Error (
724+ "Trying to finalize authentication for a server connection without an authProvider"
725+ ) ;
726+ }
707727
708- // Two-phase state validation: check first (non-destructive), consume later
709- // This prevents DoS attacks where attacker consumes valid state before legitimate callback
710- const stateValidation = await authProvider . checkState ( state ) ;
711- if ( ! stateValidation . valid ) {
712- this . clearServerAuthUrl ( serverId ) ;
713- if ( this . mcpConnections [ serverId ] ) {
714- this . mcpConnections [ serverId ] . connectionState =
715- MCPConnectionState . FAILED ;
728+ const authProvider = conn . options . transport . authProvider ;
729+ authProvider . serverId = serverId ;
730+
731+ // Two-phase state validation: check first (non-destructive), consume later
732+ // This prevents DoS attacks where attacker consumes valid state before legitimate callback
733+ const stateValidation = await authProvider . checkState ( state ) ;
734+ if ( ! stateValidation . valid ) {
735+ throw new Error ( stateValidation . error || "Invalid state" ) ;
716736 }
717- this . _onServerStateChanged . fire ( ) ;
718- return {
719- serverId,
720- authSuccess : false ,
721- authError : stateValidation . error || "Invalid state"
722- } ;
723- }
724737
725- if ( error ) {
726- return {
727- serverId,
728- authSuccess : false ,
729- authError : errorDescription || error
730- } ;
731- }
738+ if ( error ) {
739+ // Escape external OAuth error params to prevent XSS
740+ throw new Error ( escapeHtml ( errorDescription || error ) ) ;
741+ }
732742
733- if ( ! code ) {
734- throw new Error ( "Unauthorized: no code provided" ) ;
735- }
743+ if ( ! code ) {
744+ throw new Error ( "Unauthorized: no code provided" ) ;
745+ }
736746
737- if (
738- this . mcpConnections [ serverId ] . connectionState ===
739- MCPConnectionState . READY ||
740- this . mcpConnections [ serverId ] . connectionState ===
741- MCPConnectionState . CONNECTED
742- ) {
743- this . clearServerAuthUrl ( serverId ) ;
744- return {
745- serverId,
746- authSuccess : true
747- } ;
748- }
747+ // Already authenticated - just return success
748+ if (
749+ conn . connectionState === MCPConnectionState . READY ||
750+ conn . connectionState === MCPConnectionState . CONNECTED
751+ ) {
752+ this . clearServerAuthUrl ( serverId ) ;
753+ return { serverId, authSuccess : true } ;
754+ }
749755
750- if (
751- this . mcpConnections [ serverId ] . connectionState !==
752- MCPConnectionState . AUTHENTICATING
753- ) {
754- throw new Error (
755- `Failed to authenticate: the client is in "${ this . mcpConnections [ serverId ] . connectionState } " state, expected "authenticating"`
756- ) ;
757- }
756+ if ( conn . connectionState !== MCPConnectionState . AUTHENTICATING ) {
757+ throw new Error (
758+ `Failed to authenticate: the client is in "${ conn . connectionState } " state, expected "authenticating"`
759+ ) ;
760+ }
758761
759- try {
760762 await authProvider . consumeState ( state ) ;
761763 await conn . completeAuthorization ( code ) ;
762764 await authProvider . deleteCodeVerifier ( ) ;
763765 this . clearServerAuthUrl ( serverId ) ;
766+ conn . connectionError = null ;
764767 this . _onServerStateChanged . fire ( ) ;
765768
766- return {
767- serverId,
768- authSuccess : true
769- } ;
770- } catch ( authError ) {
771- const errorMessage =
772- authError instanceof Error ? authError . message : String ( authError ) ;
773-
774- this . _onServerStateChanged . fire ( ) ;
775-
776- return {
777- serverId,
778- authSuccess : false ,
779- authError : errorMessage
780- } ;
769+ return { serverId, authSuccess : true } ;
770+ } catch ( err ) {
771+ const message = err instanceof Error ? err . message : String ( err ) ;
772+ return this . failConnection ( serverId , message ) ;
781773 }
782774 }
783775
0 commit comments