diff --git a/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py b/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py index 70fc2c233e7..39fda7074cc 100644 --- a/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py +++ b/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py @@ -566,25 +566,32 @@ async def get_allowed_mcp_servers( ) ) + key_access_group_extras = ( + await MCPRequestHandler._get_key_access_group_mcp_server_extras( + user_api_key_auth + ) + ) + ######################################################### # Calculate key/team allowed servers using inheritance and intersection logic ######################################################### - allowed_mcp_servers: List[str] = [] - has_lower_level_mcp_restrictions = ( - len(allowed_mcp_servers_for_key) > 0 - or len(allowed_mcp_servers_for_team) > 0 - ) - if len(allowed_mcp_servers_for_team) > 0: - if len(allowed_mcp_servers_for_key) > 0: - # Key has its own MCP permissions - use intersection with team permissions - for _mcp_server in allowed_mcp_servers_for_key: - if _mcp_server in allowed_mcp_servers_for_team: - allowed_mcp_servers.append(_mcp_server) - else: - # Key has no MCP permissions - inherit from team - allowed_mcp_servers = allowed_mcp_servers_for_team + key_set = set(allowed_mcp_servers_for_key) + team_set = set(allowed_mcp_servers_for_team) + extras_set = set(key_access_group_extras) + + has_lower_level_mcp_restrictions = bool(key_set or team_set or extras_set) + + # 1. Team-gated base scope. + if not team_set: + base = key_set # no team restriction + elif not key_set: + base = team_set # key has no own perms → inherits team else: - allowed_mcp_servers = allowed_mcp_servers_for_key + base = key_set & team_set # both restrict → intersect + + # 2. Extend with access-group extras (LIT-3189 — bypasses team + # ceiling, gated by group's assigned_team_ids / assigned_key_ids). + allowed_mcp_servers: List[str] = list(base | extras_set) ######################################################### # Check end_user permissions if end_user_id is set @@ -877,6 +884,43 @@ def is_tool_allowed( return True return False + @staticmethod + async def _get_key_access_group_mcp_server_extras( + user_api_key_auth: Optional[UserAPIKeyAuth] = None, + ) -> List[str]: + """ + Resolve the key's unified `access_group_ids` (LiteLLM_AccessGroupTable) to + MCP server IDs, gated by the access group's `assigned_team_ids` / + `assigned_key_ids`. These servers extend the team's MCP scope rather + than being capped by it. Tag-style `mcp_access_groups` (per-server tags) + are intentionally not handled here — they have no assignment fields and + remain subject to the team ceiling. + """ + if user_api_key_auth is None: + return [] + try: + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + global_mcp_server_manager, + ) + from litellm.proxy.auth.auth_checks import ( + get_authorized_resources_from_key_access_groups, + ) + + raw_server_ids = await get_authorized_resources_from_key_access_groups( + valid_token=user_api_key_auth, + team_object=None, + resource_field="access_mcp_server_ids", + ) + if not raw_server_ids: + return [] + # Permission entries may be server_ids OR names/aliases — expand to ids. + return global_mcp_server_manager.expand_permission_list(raw_server_ids) + except Exception as e: + verbose_logger.warning( + f"Failed to get key access group MCP server extras: {str(e)}" + ) + return [] + @staticmethod async def _get_allowed_mcp_servers_for_key( user_api_key_auth: Optional[UserAPIKeyAuth] = None, diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 703023d2a0f..e291cdbbfb3 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -3163,44 +3163,40 @@ async def can_team_access_model( raise -async def _key_access_group_grants_model( - model: Union[str, List[str]], +async def get_authorized_resources_from_key_access_groups( valid_token: Optional[UserAPIKeyAuth], team_object: Optional[LiteLLM_TeamTable], - llm_router: Optional[Router], -) -> bool: + resource_field: Literal[ + "access_model_names", "access_mcp_server_ids", "access_agent_ids" + ], +) -> List[str]: """ - Returns True if the key's `access_group_ids` expand to models that grant - access to `model`. Used to let a key's access group override a team's - model restriction in `common_checks`. - - A key's access group only counts if the access group itself authorizes the - caller as an owner — that is, the group's `assigned_team_ids` includes the - key's `team_id`, or the group's `assigned_key_ids` includes the key's - token. This preserves the team-as-owner boundary (a team member cannot - escalate by naming a group assigned to a different team) while still - letting a group reach the key without first being added to the team's - `access_group_ids` list. + For each access_group_id on the key, fetch the LiteLLM_AccessGroupTable row + and contribute its `resource_field` only if the group authorizes the caller + as an owner — that is, the group's `assigned_team_ids` includes the key's + `team_id`, or the group's `assigned_key_ids` includes the key's token. This + preserves the team-as-owner boundary while still letting a group reach the + key without first being added to the team's `access_group_ids` list. """ if valid_token is None: - return False + return [] key_access_group_ids = list(valid_token.access_group_ids or []) if not key_access_group_ids: - return False + return [] from litellm.proxy.proxy_server import prisma_client as _prisma_client from litellm.proxy.proxy_server import proxy_logging_obj as _proxy_logging_obj from litellm.proxy.proxy_server import user_api_key_cache as _user_api_key_cache if _prisma_client is None or _user_api_key_cache is None: - return False + return [] key_team_id = valid_token.team_id or ( team_object.team_id if team_object is not None else None ) key_token = valid_token.token - authorized_models: List[str] = [] + authorized_resources: List[str] = [] for ag_id in key_access_group_ids: try: ag = await get_access_object( @@ -3216,17 +3212,36 @@ async def _key_access_group_grants_model( ) key_authorized = bool(key_token and key_token in (ag.assigned_key_ids or [])) if team_authorized or key_authorized: - authorized_models.extend(ag.access_model_names or []) + authorized_resources.extend(getattr(ag, resource_field, []) or []) + return list(set(authorized_resources)) + + +async def _key_access_group_grants_model( + model: Union[str, List[str]], + valid_token: Optional[UserAPIKeyAuth], + team_object: Optional[LiteLLM_TeamTable], + llm_router: Optional[Router], +) -> bool: + """ + Returns True if the key's `access_group_ids` expand to models that grant + access to `model`. Used to let a key's access group override a team's + model restriction in `common_checks`. + """ + authorized_models = await get_authorized_resources_from_key_access_groups( + valid_token=valid_token, + team_object=team_object, + resource_field="access_model_names", + ) if not authorized_models: return False try: _can_object_call_model( model=model, llm_router=llm_router, - models=list(set(authorized_models)), - team_model_aliases=valid_token.team_model_aliases, - team_id=valid_token.team_id, + models=authorized_models, + team_model_aliases=valid_token.team_model_aliases if valid_token else None, + team_id=valid_token.team_id if valid_token else None, object_type="key", ) return True diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py b/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py index 88742c67a86..54a36eac2a1 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py @@ -880,7 +880,7 @@ async def test_legitimate_well_known_path_still_bypasses_auth(self): with patch( "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", ) as mock_auth: - (auth_result, *_rest) = await MCPRequestHandler.process_mcp_request(scope) + auth_result, *_rest = await MCPRequestHandler.process_mcp_request(scope) mock_auth.assert_not_called() assert isinstance(auth_result, UserAPIKeyAuth) @@ -997,7 +997,7 @@ async def mock_user_api_key_auth_fails(api_key, request): mock_mgr.get_mcp_server_by_name.return_value = ( TestMCPOAuth2FallbackTargetGating._make_server(MCPAuth.oauth2) ) - (auth_result, *_rest) = await MCPRequestHandler.process_mcp_request(scope) + auth_result, *_rest = await MCPRequestHandler.process_mcp_request(scope) assert isinstance(auth_result, UserAPIKeyAuth) async def test_fallback_blocked_when_any_target_in_header_is_not_oauth2(self): @@ -1157,7 +1157,7 @@ async def test_delegate_skips_litellm_auth_with_no_authorization(self): delegate_auth_to_upstream=True, ) ) - (auth_result, *_rest) = await MCPRequestHandler.process_mcp_request(scope) + auth_result, *_rest = await MCPRequestHandler.process_mcp_request(scope) assert isinstance(auth_result, UserAPIKeyAuth) mock_auth.assert_not_called() @@ -1400,7 +1400,7 @@ async def test_explicit_litellm_key_takes_precedence_over_delegate(self): delegate_auth_to_upstream=True, ) ) - (auth_result, *_rest) = await MCPRequestHandler.process_mcp_request(scope) + auth_result, *_rest = await MCPRequestHandler.process_mcp_request(scope) assert isinstance(auth_result, UserAPIKeyAuth) assert auth_result.user_id == "real-user" mock_auth.assert_called_once() @@ -1437,7 +1437,7 @@ async def test_litellm_key_via_authorization_header_not_bypassed(self): delegate_auth_to_upstream=True, ) ) - (auth_result, *_rest) = await MCPRequestHandler.process_mcp_request(scope) + auth_result, *_rest = await MCPRequestHandler.process_mcp_request(scope) assert isinstance(auth_result, UserAPIKeyAuth) assert auth_result.user_id == "real-user" mock_auth.assert_called_once() @@ -3160,3 +3160,299 @@ async def test_get_allowed_tools_for_server_org_no_restriction(self): user_api_key_auth=auth, ) assert sorted(result) == ["tool_a", "tool_b"] + + +# --------------------------------------------------------------------------- +# LIT-3189: key unified access_group_ids extend team MCP scope +# --------------------------------------------------------------------------- + + +def _patch_proxy_server_globals_for_mcp(): + """Non-None mocks so the helper's None-guard doesn't short-circuit.""" + return [ + patch("litellm.proxy.proxy_server.prisma_client", MagicMock()), + patch("litellm.proxy.proxy_server.user_api_key_cache", MagicMock()), + patch("litellm.proxy.proxy_server.proxy_logging_obj", MagicMock()), + ] + + +def _fake_mcp_access_group( + access_group_id, + access_mcp_server_ids=None, + assigned_team_ids=None, + assigned_key_ids=None, +): + from litellm.proxy._types import LiteLLM_AccessGroupTable + + return LiteLLM_AccessGroupTable( + access_group_id=access_group_id, + access_group_name=access_group_id, + access_mcp_server_ids=access_mcp_server_ids or [], + assigned_team_ids=assigned_team_ids or [], + assigned_key_ids=assigned_key_ids or [], + ) + + +def _start_patches(patches): + for p in patches: + p.start() + + +def _stop_patches(patches): + for p in patches: + p.stop() + + +@pytest.mark.asyncio +async def test_mcp_key_access_group_extras_when_team_authorized(): + """Group's assigned_team_ids includes key's team and grants an MCP server → server returned.""" + valid_token = UserAPIKeyAuth( + token="test-token", + access_group_ids=["mcp-premium"], + team_id="team-a", + ) + fake_ag = _fake_mcp_access_group( + access_group_id="mcp-premium", + access_mcp_server_ids=["srv-stripe"], + assigned_team_ids=["team-a"], + ) + + mock_mgr = MagicMock() + mock_mgr.expand_permission_list.side_effect = lambda x: list(x) + + patches = _patch_proxy_server_globals_for_mcp() + [ + patch( + "litellm.proxy.auth.auth_checks.get_access_object", + new_callable=AsyncMock, + return_value=fake_ag, + ), + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager", + mock_mgr, + ), + ] + _start_patches(patches) + try: + result = await MCPRequestHandler._get_key_access_group_mcp_server_extras( + valid_token + ) + assert result == ["srv-stripe"] + finally: + _stop_patches(patches) + + +@pytest.mark.asyncio +async def test_mcp_key_access_group_extras_when_key_directly_authorized(): + """Group's assigned_key_ids includes the key's token → server returned (per-key auth).""" + valid_token = UserAPIKeyAuth( + token="test-token-hashed", + access_group_ids=["mcp-per-key"], + team_id="team-a", + ) + fake_ag = _fake_mcp_access_group( + access_group_id="mcp-per-key", + access_mcp_server_ids=["srv-stripe"], + assigned_team_ids=[], + assigned_key_ids=["test-token-hashed"], + ) + + mock_mgr = MagicMock() + mock_mgr.expand_permission_list.side_effect = lambda x: list(x) + + patches = _patch_proxy_server_globals_for_mcp() + [ + patch( + "litellm.proxy.auth.auth_checks.get_access_object", + new_callable=AsyncMock, + return_value=fake_ag, + ), + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager", + mock_mgr, + ), + ] + _start_patches(patches) + try: + result = await MCPRequestHandler._get_key_access_group_mcp_server_extras( + valid_token + ) + assert result == ["srv-stripe"] + finally: + _stop_patches(patches) + + +@pytest.mark.asyncio +async def test_mcp_key_access_group_extras_when_key_has_no_groups(): + """Empty access_group_ids → no extras, no DB read.""" + valid_token = UserAPIKeyAuth( + token="test-token", + access_group_ids=[], + team_id="team-a", + ) + result = await MCPRequestHandler._get_key_access_group_mcp_server_extras( + valid_token + ) + assert result == [] + + +@pytest.mark.asyncio +async def test_mcp_key_access_group_extras_when_group_has_no_servers(): + """Group authorizes the team but its access_mcp_server_ids is empty → no extras.""" + valid_token = UserAPIKeyAuth( + token="test-token", + access_group_ids=["mcp-empty"], + team_id="team-a", + ) + fake_ag = _fake_mcp_access_group( + access_group_id="mcp-empty", + access_mcp_server_ids=[], + assigned_team_ids=["team-a"], + ) + + patches = _patch_proxy_server_globals_for_mcp() + [ + patch( + "litellm.proxy.auth.auth_checks.get_access_object", + new_callable=AsyncMock, + return_value=fake_ag, + ), + ] + _start_patches(patches) + try: + result = await MCPRequestHandler._get_key_access_group_mcp_server_extras( + valid_token + ) + assert result == [] + finally: + _stop_patches(patches) + + +@pytest.mark.asyncio +async def test_mcp_key_access_group_extras_when_group_authorizes_neither(): + """ + Escalation regression: team member attaches a foreign access group to their key. + Group grants servers BUT assigned_team_ids/assigned_key_ids exclude this caller. + No extras contributed. + """ + valid_token = UserAPIKeyAuth( + token="team-a-token", + access_group_ids=["team-b-mcp-group"], + team_id="team-a", + ) + fake_ag = _fake_mcp_access_group( + access_group_id="team-b-mcp-group", + access_mcp_server_ids=["srv-finance-only"], + assigned_team_ids=["team-b"], + assigned_key_ids=["team-b-token"], + ) + + patches = _patch_proxy_server_globals_for_mcp() + [ + patch( + "litellm.proxy.auth.auth_checks.get_access_object", + new_callable=AsyncMock, + return_value=fake_ag, + ), + ] + _start_patches(patches) + try: + result = await MCPRequestHandler._get_key_access_group_mcp_server_extras( + valid_token + ) + assert result == [] + finally: + _stop_patches(patches) + + +@pytest.mark.asyncio +async def test_mcp_key_access_group_extras_when_get_access_object_raises(): + """Group lookup failure is treated as no authorization (does not crash).""" + valid_token = UserAPIKeyAuth( + token="test-token", + access_group_ids=["missing-mcp-group"], + team_id="team-a", + ) + patches = _patch_proxy_server_globals_for_mcp() + [ + patch( + "litellm.proxy.auth.auth_checks.get_access_object", + new_callable=AsyncMock, + side_effect=Exception("not found"), + ), + ] + _start_patches(patches) + try: + result = await MCPRequestHandler._get_key_access_group_mcp_server_extras( + valid_token + ) + assert result == [] + finally: + _stop_patches(patches) + + +@pytest.mark.asyncio +async def test_get_allowed_mcp_servers_unions_key_access_group_extras(): + """End-to-end: team has [srv-team], key access group grants [srv-extra] → both in final list. + + Without this fix [srv-extra] would be intersected away because the team doesn't list it. + """ + auth = UserAPIKeyAuth( + token="test-token", + api_key="test-key", + team_id="team-a", + access_group_ids=["mcp-extra-group"], + ) + + with ( + patch.object( + MCPRequestHandler, + "_get_allowed_mcp_servers_for_key", + new_callable=AsyncMock, + return_value=[], + ), + patch.object( + MCPRequestHandler, + "_get_allowed_mcp_servers_for_team", + new_callable=AsyncMock, + return_value=["srv-team"], + ), + patch.object( + MCPRequestHandler, + "_get_key_access_group_mcp_server_extras", + new_callable=AsyncMock, + return_value=["srv-extra"], + ), + ): + result = await MCPRequestHandler.get_allowed_mcp_servers(auth) + assert sorted(result) == ["srv-extra", "srv-team"] + + +@pytest.mark.asyncio +async def test_get_allowed_mcp_servers_no_union_when_no_authorized_extras(): + """End-to-end: no authorized extras → behavior identical to today (team ceiling enforced).""" + auth = UserAPIKeyAuth( + token="test-token", + api_key="test-key", + team_id="team-a", + access_group_ids=["mcp-foreign-group"], + ) + + with ( + patch.object( + MCPRequestHandler, + "_get_allowed_mcp_servers_for_key", + new_callable=AsyncMock, + return_value=["srv-key-only"], + ), + patch.object( + MCPRequestHandler, + "_get_allowed_mcp_servers_for_team", + new_callable=AsyncMock, + return_value=["srv-team"], + ), + patch.object( + MCPRequestHandler, + "_get_key_access_group_mcp_server_extras", + new_callable=AsyncMock, + return_value=[], + ), + ): + # key ∩ team = {} (no overlap), extras = [] → final = [] + result = await MCPRequestHandler.get_allowed_mcp_servers(auth) + assert result == []