Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
293a6ee
feat(auth): add tenant_id to session and request context
andylim-duo Mar 11, 2026
f597175
fix(test): use async context manager for ServerSession test
andylim-duo Mar 11, 2026
f5351e0
Merge branch 'main' into feature/multi-tenant-session-context
andylim-duo Mar 11, 2026
a7d9a9f
test(auth): add tenant isolation tests for concurrent requests
andylim-duo Mar 11, 2026
2f6fa78
Merge branch 'main' into feature/multi-tenant-session-context
andylim-duo Mar 12, 2026
9f4b679
docs(context): add tenant_id field description to RequestContext docs…
andylim-duo Mar 12, 2026
9e3ded2
feat(auth): populate session.tenant_id from auth context on first req…
andylim-duo Mar 12, 2026
161f123
fix(test): resolve pyright reportUnnecessaryComparison error
andylim-duo Mar 12, 2026
f3ed099
test(auth): add E2E tests for tenant_id binding in request/notificati…
andylim-duo Mar 12, 2026
2dcebd0
style: apply ruff formatting to test file
andylim-duo Mar 12, 2026
ec564e7
fix: remove stale pragma no cover from send_roots_list_changed
andylim-duo Mar 12, 2026
3dfb2d7
feat(auth): enforce set-once semantics on ServerSession.tenant_id
andylim-duo Mar 12, 2026
117f0da
refactor(auth): decouple core server from auth module for tenant extr…
andylim-duo Mar 13, 2026
02725c8
fix(test): remove dead code in test_get_tenant_id_with_tenant
andylim-duo Mar 16, 2026
04c7535
fix(test): replace anyio.sleep(0.01) with anyio.lowlevel.checkpoint()
andylim-duo Mar 16, 2026
92dff29
fix(test): use explicit import for anyio.lowlevel.checkpoint
andylim-duo Mar 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ async def list_tools(self, *, params: types.PaginatedRequestParams | None = None

return result

async def send_roots_list_changed(self) -> None: # pragma: no cover
async def send_roots_list_changed(self) -> None:
"""Send a roots/list_changed notification."""
await self.send_notification(types.RootsListChangedNotification())

Expand Down
19 changes: 17 additions & 2 deletions src/mcp/server/auth/middleware/auth_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
from mcp.server.auth.provider import AccessToken
from mcp.shared._context import tenant_id_var

# Create a contextvar to store the authenticated user
# The default is None, indicating no authenticated user is present
Expand All @@ -20,6 +21,16 @@ def get_access_token() -> AccessToken | None:
return auth_user.access_token if auth_user else None


def get_tenant_id() -> str | None:
"""Get the tenant_id from the current authentication context.

Returns:
The tenant_id if an authenticated user with a tenant is available, None otherwise.
"""
access_token = get_access_token()
return access_token.tenant_id if access_token else None


class AuthContextMiddleware:
"""Middleware that extracts the authenticated user from the request
and sets it in a contextvar for easy access throughout the request lifecycle.
Expand All @@ -36,11 +47,15 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
user = scope.get("user")
if isinstance(user, AuthenticatedUser):
# Set the authenticated user in the contextvar
token = auth_context_var.set(user)
auth_token = auth_context_var.set(user)
# Propagate tenant_id to the transport-agnostic contextvar
tenant_id = user.access_token.tenant_id if user.access_token else None
tenant_token = tenant_id_var.set(tenant_id)
try:
await self.app(scope, receive, send)
finally:
auth_context_var.reset(token)
tenant_id_var.reset(tenant_token)
auth_context_var.reset(auth_token)
else:
# No authenticated user, just process the request
await self.app(scope, receive, send)
9 changes: 9 additions & 0 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ async def main():
from mcp.server.streamable_http import EventStore
from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager
from mcp.server.transport_security import TransportSecuritySettings
from mcp.shared._context import tenant_id_var
from mcp.shared.exceptions import MCPError
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.session import RequestResponder
Expand Down Expand Up @@ -451,11 +452,15 @@ async def _handle_request(
task_metadata = None
if hasattr(req, "params") and req.params is not None:
task_metadata = getattr(req.params, "task", None)
tenant_id = tenant_id_var.get()
if tenant_id is not None and session.tenant_id is None:
session.tenant_id = tenant_id
ctx = ServerRequestContext(
request_id=message.request_id,
meta=message.request_meta,
session=session,
lifespan_context=lifespan_context,
tenant_id=tenant_id,
experimental=Experimental(
task_metadata=task_metadata,
_client_capabilities=client_capabilities,
Expand Down Expand Up @@ -495,9 +500,13 @@ async def _handle_notification(
try:
client_capabilities = session.client_params.capabilities if session.client_params else None
task_support = self._experimental_handlers.task_support if self._experimental_handlers else None
tenant_id = tenant_id_var.get()
if tenant_id is not None and session.tenant_id is None:
session.tenant_id = tenant_id
ctx = ServerRequestContext(
session=session,
lifespan_context=lifespan_context,
tenant_id=tenant_id,
experimental=Experimental(
task_metadata=None,
_client_capabilities=client_capabilities,
Expand Down
22 changes: 22 additions & 0 deletions src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class ServerSession(
_initialized: InitializationState = InitializationState.NotInitialized
_client_params: types.InitializeRequestParams | None = None
_experimental_features: ExperimentalServerSessionFeatures | None = None
_tenant_id: str | None = None

def __init__(
self,
Expand Down Expand Up @@ -108,6 +109,27 @@ def _receive_notification_adapter(self) -> TypeAdapter[types.ClientNotification]
def client_params(self) -> types.InitializeRequestParams | None:
return self._client_params

@property
def tenant_id(self) -> str | None:
"""Get the tenant_id for this session."""
return self._tenant_id

@tenant_id.setter
def tenant_id(self, value: str | None) -> None:
"""Set the tenant_id for this session (set-once).

Once a session is bound to a tenant, the tenant_id cannot be changed.
This prevents accidental tenant reassignment which could be a security issue.

Raises:
ValueError: If tenant_id is already set to a different value.
"""
if self._tenant_id is not None and value != self._tenant_id:
raise ValueError(
f"Cannot change tenant_id from '{self._tenant_id}' to '{value}': session is already bound to a tenant"
)
self._tenant_id = value
Comment thread
andylim-duo marked this conversation as resolved.

Comment thread
andylim-duo marked this conversation as resolved.
@property
def experimental(self) -> ExperimentalServerSessionFeatures:
"""Experimental APIs for server→client task operations.
Expand Down
11 changes: 11 additions & 0 deletions src/mcp/shared/_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Request context for MCP handlers."""

import contextvars
from dataclasses import dataclass
from typing import Any, Generic

Expand All @@ -8,6 +9,11 @@
from mcp.shared.session import BaseSession
from mcp.types import RequestId, RequestParamsMeta

# Transport-agnostic contextvar for tenant identification.
# Set by the transport layer (e.g., AuthContextMiddleware for HTTP+OAuth).
# Read by the core server to populate RequestContext.tenant_id.
tenant_id_var = contextvars.ContextVar[str | None]("tenant_id", default=None)

SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])


Expand All @@ -17,8 +23,13 @@ class RequestContext(Generic[SessionT]):

For request handlers, request_id is always populated.
For notification handlers, request_id is None.

The tenant_id field is used in multi-tenant server deployments to identify
which tenant the request belongs to. It is populated from session context
and enables tenant-specific request handling and isolation.
"""

session: SessionT
request_id: RequestId | None = None
meta: RequestParamsMeta | None = None
tenant_id: str | None = None
128 changes: 128 additions & 0 deletions tests/server/auth/middleware/test_auth_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
AuthContextMiddleware,
auth_context_var,
get_access_token,
get_tenant_id,
)
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
from mcp.server.auth.provider import AccessToken
from mcp.shared._context import tenant_id_var


class MockApp:
Expand Down Expand Up @@ -117,3 +119,129 @@ async def send(message: Message) -> None: # pragma: no cover
# Verify context is still empty after middleware
assert auth_context_var.get() is None
assert get_access_token() is None


@pytest.fixture
def access_token_with_tenant() -> AccessToken:
"""Create an access token with a tenant_id."""
return AccessToken(
token="tenant_token",
client_id="test_client",
scopes=["read", "write"],
expires_at=int(time.time()) + 3600,
tenant_id="tenant-abc",
)


def test_get_tenant_id_without_auth_context():
"""Test get_tenant_id returns None when no auth context exists."""
assert auth_context_var.get() is None
assert get_tenant_id() is None


@pytest.mark.anyio
async def test_get_tenant_id_with_tenant(access_token_with_tenant: AccessToken):
"""Test get_tenant_id returns tenant_id when auth context has a tenant."""
user = AuthenticatedUser(access_token_with_tenant)
scope: Scope = {"type": "http", "user": user}

tenant_id_during_call: str | None = None

class TenantCheckApp:
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
nonlocal tenant_id_during_call
tenant_id_during_call = get_tenant_id()

middleware = AuthContextMiddleware(TenantCheckApp())

async def receive() -> Message: # pragma: no cover
return {"type": "http.request"}

async def send(message: Message) -> None: # pragma: no cover
pass

await middleware(scope, receive, send)

assert tenant_id_during_call == "tenant-abc"
# Verify context is reset after middleware
assert get_tenant_id() is None


@pytest.mark.anyio
async def test_middleware_sets_tenant_id_var(access_token_with_tenant: AccessToken):
"""Test AuthContextMiddleware populates the transport-agnostic tenant_id_var."""
user = AuthenticatedUser(access_token_with_tenant)
scope: Scope = {"type": "http", "user": user}

observed_tenant_id: str | None = None

class CheckApp:
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
nonlocal observed_tenant_id
observed_tenant_id = tenant_id_var.get()

middleware = AuthContextMiddleware(CheckApp())

async def receive() -> Message: # pragma: no cover
return {"type": "http.request"}

async def send(message: Message) -> None: # pragma: no cover
pass

await middleware(scope, receive, send)

assert observed_tenant_id == "tenant-abc"
# Verify contextvar is reset after middleware
assert tenant_id_var.get() is None


@pytest.mark.anyio
async def test_middleware_sets_tenant_id_var_none_without_tenant(valid_access_token: AccessToken):
"""Test AuthContextMiddleware sets tenant_id_var to None when token has no tenant."""
user = AuthenticatedUser(valid_access_token)
scope: Scope = {"type": "http", "user": user}

observed_tenant_id: str | None = "sentinel"

class CheckApp:
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
nonlocal observed_tenant_id
observed_tenant_id = tenant_id_var.get()

middleware = AuthContextMiddleware(CheckApp())

async def receive() -> Message: # pragma: no cover
return {"type": "http.request"}

async def send(message: Message) -> None: # pragma: no cover
pass

await middleware(scope, receive, send)

assert observed_tenant_id is None


@pytest.mark.anyio
async def test_get_tenant_id_without_tenant(valid_access_token: AccessToken):
"""Test get_tenant_id returns None when auth context has no tenant."""
tenant_id_during_call: str | None = "not-none"

class TenantCheckApp:
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
nonlocal tenant_id_during_call
tenant_id_during_call = get_tenant_id()

middleware = AuthContextMiddleware(TenantCheckApp())

user = AuthenticatedUser(valid_access_token)
scope: Scope = {"type": "http", "user": user}

async def receive() -> Message: # pragma: no cover
return {"type": "http.request"}

async def send(message: Message) -> None: # pragma: no cover
pass

await middleware(scope, receive, send)

assert tenant_id_during_call is None
Loading
Loading