diff --git a/.gitignore b/.gitignore index dff64e3c9e9..a3a6e871e25 100644 --- a/.gitignore +++ b/.gitignore @@ -28,6 +28,8 @@ litellm/tests/config_*.yaml litellm/tests/langfuse.log langfuse.log .langfuse.log +.pin_list.txt +.cov_new.xml litellm/tests/test_custom_logger.py litellm/tests/langfuse.log litellm/tests/dynamo*.log diff --git a/tests/test_litellm/proxy/proxy_server/test_background_health.py b/tests/test_litellm/proxy/proxy_server/test_background_health.py index ad6b4016461..ee8d8b22779 100644 --- a/tests/test_litellm/proxy/proxy_server/test_background_health.py +++ b/tests/test_litellm/proxy/proxy_server/test_background_health.py @@ -1 +1,513 @@ -"""Placeholder. Filled by a follow-up PR per the Notion plan.""" +"""Behavior pins for proxy_server background health-check helpers. + +Pins covered: +- ``_get_process_rss_mb`` +- ``_rss_mb_for_log`` +- ``_run_direct_health_check_with_instrumentation`` +- ``_schedule_background_health_check_db_save`` +- ``_get_endpoint_exception_status`` +- ``_write_health_state_to_router_cache`` +- ``_adaptive_router_flusher_loop`` +- ``_run_background_health_check`` +""" + +from __future__ import annotations + +import asyncio +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest + +import litellm.proxy.proxy_server as proxy_server +from litellm.proxy.proxy_server import ( + _adaptive_router_flusher_loop, + _get_endpoint_exception_status, + _get_process_rss_mb, + _run_background_health_check, + _run_direct_health_check_with_instrumentation, + _rss_mb_for_log, + _schedule_background_health_check_db_save, + _write_health_state_to_router_cache, +) + +from .conftest import normalize + +# --------------------------------------------------------------------------- +# _get_process_rss_mb +# --------------------------------------------------------------------------- + + +def test_get_process_rss_mb_returns_positive_float(): + value = _get_process_rss_mb() + assert value is not None + assert 
normalize( + { + "value_present": value is not None, + "value_type": type(value).__name__, + "positive": value > 0, + } + ) == { + "value_present": True, + "value_type": "float", + "positive": True, + } + + +def test_get_process_rss_mb_returns_none_when_resource_raises(monkeypatch): + import resource + + def _boom(*_args, **_kwargs): + raise OSError("nope") + + monkeypatch.setattr(resource, "getrusage", _boom) + assert _get_process_rss_mb() is None + + +# --------------------------------------------------------------------------- +# _rss_mb_for_log +# --------------------------------------------------------------------------- + + +def test_rss_mb_for_log_formats_numeric_value(monkeypatch): + monkeypatch.setattr(proxy_server, "_get_process_rss_mb", lambda: 100.5) + result = _rss_mb_for_log() + assert normalize( + { + "format": result, + "is_string": isinstance(result, str), + "contains_mb": "100.50" in result, + } + ) == { + "format": "100.50", + "is_string": True, + "contains_mb": True, + } + + +def test_rss_mb_for_log_unknown_when_rss_missing(monkeypatch): + monkeypatch.setattr(proxy_server, "_get_process_rss_mb", lambda: None) + assert _rss_mb_for_log() == "unknown" + + +# --------------------------------------------------------------------------- +# _run_direct_health_check_with_instrumentation +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_run_direct_health_check_with_instrumentation_returns_results( + monkeypatch, +): + expected = (["healthy_ep"], ["unhealthy_ep"], {"m1": Exception("boom")}) + + async def _fake_perform(model_list, details, max_concurrency, **kwargs): + return expected + + monkeypatch.setattr(proxy_server, "perform_health_check", _fake_perform) + monkeypatch.setattr( + proxy_server, + "health_check_filter_kwargs_from_general_settings", + lambda _gs: {}, + ) + + healthy, unhealthy, exceptions = ( + await _run_direct_health_check_with_instrumentation( + 
model_list=[{"model_name": "gpt-4"}], + details=False, + max_concurrency=1, + instrumentation_context={"source": "test"}, + ) + ) + + assert normalize( + { + "healthy": healthy, + "unhealthy": unhealthy, + "exception_keys": list(exceptions.keys()), + } + ) == { + "healthy": ["healthy_ep"], + "unhealthy": ["unhealthy_ep"], + "exception_keys": ["m1"], + } + + +@pytest.mark.asyncio +async def test_run_direct_health_check_raises_non_kwarg_typeerror(monkeypatch): + async def _boom(model_list, details, max_concurrency, **kwargs): + raise TypeError("totally unrelated") + + monkeypatch.setattr(proxy_server, "perform_health_check", _boom) + monkeypatch.setattr( + proxy_server, + "health_check_filter_kwargs_from_general_settings", + lambda _gs: {}, + ) + + with pytest.raises(TypeError): + await _run_direct_health_check_with_instrumentation( + model_list=[], + details=False, + max_concurrency=1, + instrumentation_context={}, + ) + + +# --------------------------------------------------------------------------- +# _schedule_background_health_check_db_save +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_schedule_background_health_check_db_save_creates_task(monkeypatch): + captured = {} + + async def _fake_save( + prisma_client, + model_list, + healthy, + unhealthy, + start_time, + checked_by, + ): + captured["prisma_client"] = prisma_client + captured["model_list"] = model_list + captured["healthy"] = healthy + captured["unhealthy"] = unhealthy + captured["checked_by"] = checked_by + + import litellm.proxy.health_endpoints._health_endpoints as he + + monkeypatch.setattr(he, "_save_background_health_checks_to_db", _fake_save) + + prisma_client = MagicMock() + shared_manager = SimpleNamespace(pod_id="pod-xyz") + + _schedule_background_health_check_db_save( + prisma_client=prisma_client, + shared_health_manager=shared_manager, + model_list=[{"model_name": "gpt-4"}], + healthy_endpoints=[{"model_id": "h1"}], + 
unhealthy_endpoints=[{"model_id": "u1"}], + ) + + await asyncio.sleep(0) + + assert normalize( + { + "prisma_present": captured.get("prisma_client") is prisma_client, + "checked_by": captured.get("checked_by"), + "healthy": captured.get("healthy"), + "unhealthy": captured.get("unhealthy"), + } + ) == { + "prisma_present": True, + "checked_by": "pod-xyz", + "healthy": [{"model_id": "h1"}], + "unhealthy": [{"model_id": "u1"}], + } + + +def test_schedule_background_health_check_db_save_noop_when_prisma_none(): + _schedule_background_health_check_db_save( + prisma_client=None, + shared_health_manager=None, + model_list=[], + healthy_endpoints=[], + unhealthy_endpoints=[], + ) + + +@pytest.mark.asyncio +async def test_schedule_background_health_check_db_save_invalid_no_event_loop_raises( + monkeypatch, +): + async def _fake_save(*_args, **_kwargs): + return None + + import litellm.proxy.health_endpoints._health_endpoints as he + + monkeypatch.setattr(he, "_save_background_health_checks_to_db", _fake_save) + + def _broken_create_task(_coro): + raise RuntimeError("no running event loop") + + monkeypatch.setattr(asyncio, "create_task", _broken_create_task) + + with pytest.raises(RuntimeError): + _schedule_background_health_check_db_save( + prisma_client=MagicMock(), + shared_health_manager=None, + model_list=[], + healthy_endpoints=[], + unhealthy_endpoints=[], + ) + + +# --------------------------------------------------------------------------- +# _get_endpoint_exception_status +# --------------------------------------------------------------------------- + + +def test_get_endpoint_exception_status_prefers_live_exception(): + endpoint = {"model_id": "m1", "exception_status": 999} + exceptions = {"m1": SimpleNamespace(status_code=429)} + status = _get_endpoint_exception_status(endpoint, exceptions) + assert normalize( + { + "input_endpoint": endpoint, + "exceptions_keys": list(exceptions.keys()), + "status": status, + } + ) == { + "input_endpoint": {"model_id": "m1", 
"exception_status": 999}, + "exceptions_keys": ["m1"], + "status": 429, + } + + +def test_get_endpoint_exception_status_falls_back_to_stored_int(): + endpoint = {"model_id": "m-missing", "exception_status": 503} + assert _get_endpoint_exception_status(endpoint, {}) == 503 + + +def test_get_endpoint_exception_status_default_500_when_no_data(): + assert _get_endpoint_exception_status({}, {}) == 500 + + +def test_get_endpoint_exception_status_invalid_endpoint_type_raises(): + with pytest.raises(AttributeError): + _get_endpoint_exception_status(None, {}) # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# _write_health_state_to_router_cache +# --------------------------------------------------------------------------- + + +def test_write_health_state_to_router_cache_sets_states(monkeypatch): + fake_router = MagicMock() + fake_router.enable_health_check_routing = True + fake_router.health_check_ignore_transient_errors = False + fake_router.cooldown_time = 30 + fake_router.health_state_cache = MagicMock() + + monkeypatch.setattr(proxy_server, "llm_router", fake_router) + + fake_states = {"m1": {"is_healthy": True}, "m2": {"is_healthy": False}} + + import litellm.proxy.health_check as hc + + monkeypatch.setattr(hc, "build_deployment_health_states", lambda **_kw: fake_states) + + import litellm.router_utils.cooldown_handlers as cd + + monkeypatch.setattr(cd, "_set_cooldown_deployments", lambda **_kw: None) + + import litellm.router_utils.router_callbacks.track_deployment_metrics as tdm + + monkeypatch.setattr( + tdm, + "increment_deployment_failures_for_current_minute", + lambda **_kw: None, + ) + + healthy = [{"model_id": "m1"}] + unhealthy = [{"model_id": "m2"}] + exceptions = {"m2": SimpleNamespace(status_code=500)} + + _write_health_state_to_router_cache(healthy, unhealthy, exceptions) + + fake_router.health_state_cache.set_deployment_health_states.assert_called_once_with( + fake_states + ) + + call_args = 
fake_router.health_state_cache.set_deployment_health_states.call_args[ + 0 + ][0] + assert normalize( + { + "states_keys": sorted(call_args.keys()), + "m1_healthy": call_args["m1"]["is_healthy"], + "m2_healthy": call_args["m2"]["is_healthy"], + } + ) == { + "states_keys": ["m1", "m2"], + "m1_healthy": True, + "m2_healthy": False, + } + + +def test_write_health_state_to_router_cache_noop_when_router_none(monkeypatch): + monkeypatch.setattr(proxy_server, "llm_router", None) + _write_health_state_to_router_cache([], [], {}) + + +def test_write_health_state_to_router_cache_swallows_internal_failures(monkeypatch): + """The function logs and swallows exceptions so a bad cache call never crashes the loop.""" + fake_router = MagicMock() + fake_router.enable_health_check_routing = True + fake_router.health_check_ignore_transient_errors = False + fake_router.health_state_cache.set_deployment_health_states.side_effect = ( + RuntimeError("cache exploded") + ) + + monkeypatch.setattr(proxy_server, "llm_router", fake_router) + + import litellm.proxy.health_check as hc + + monkeypatch.setattr( + hc, + "build_deployment_health_states", + lambda **_kw: {"m1": {"is_healthy": True}}, + ) + + _write_health_state_to_router_cache([{"model_id": "m1"}], [], {}) + + +# --------------------------------------------------------------------------- +# _adaptive_router_flusher_loop +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_adaptive_router_flusher_loop_flushes_each_router(monkeypatch): + fake_ar = MagicMock() + fake_ar._state_loaded = True + fake_ar.queue.flush_state_to_db = AsyncMock() + fake_ar.queue.flush_session_to_db = AsyncMock() + + fake_router = MagicMock() + fake_router.adaptive_routers = {"alpha": fake_ar} + + monkeypatch.setattr(proxy_server, "llm_router", fake_router) + monkeypatch.setattr(proxy_server, "prisma_client", MagicMock()) + + # asyncio.sleep is awaited at the top of every iteration; raise 
CancelledError + # on the SECOND call so the first iteration completes its flush work. + call_count = {"n": 0} + _real_sleep = asyncio.sleep + + async def _short_sleep(_seconds): + call_count["n"] += 1 + if call_count["n"] >= 2: + raise asyncio.CancelledError() + await _real_sleep(0) + + monkeypatch.setattr(proxy_server.asyncio, "sleep", _short_sleep) + + with pytest.raises(asyncio.CancelledError): + await _adaptive_router_flusher_loop() + + assert fake_ar.queue.flush_state_to_db.await_count == 1 + assert fake_ar.queue.flush_session_to_db.await_count == 1 + + +@pytest.mark.asyncio +async def test_adaptive_router_flusher_loop_times_out_when_sleep_real(monkeypatch): + """Confirms the loop is infinite — wait_for must raise TimeoutError.""" + monkeypatch.setattr(proxy_server, "llm_router", MagicMock(adaptive_routers={})) + monkeypatch.setattr(proxy_server, "prisma_client", None) + + # Bind the real asyncio.sleep before the patch so the replacement does not + # recurse into itself. + _real_sleep = asyncio.sleep + + async def _instant_sleep(_seconds): + await _real_sleep(0) + + monkeypatch.setattr(proxy_server.asyncio, "sleep", _instant_sleep) + + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(_adaptive_router_flusher_loop(), timeout=0.2) + + +# --------------------------------------------------------------------------- +# _run_background_health_check +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_run_background_health_check_returns_immediately_when_interval_invalid( + monkeypatch, +): + monkeypatch.setattr(proxy_server, "health_check_interval", None) + + result = await _run_background_health_check() + + assert normalize( + { + "result_is_none": result is None, + "loop_active": proxy_server.background_health_check_loop_active, + "interval": proxy_server.health_check_interval, + } + ) == { + "result_is_none": True, + "loop_active": False, + "interval": None, + } + + 
+@pytest.mark.asyncio +async def test_run_background_health_check_runs_one_cycle_then_cancels(monkeypatch): + monkeypatch.setattr(proxy_server, "health_check_interval", 60) + monkeypatch.setattr(proxy_server, "health_check_concurrency", 1) + monkeypatch.setattr(proxy_server, "health_check_details", True) + monkeypatch.setattr(proxy_server, "use_shared_health_check", False) + monkeypatch.setattr(proxy_server, "redis_usage_cache", None) + monkeypatch.setattr(proxy_server, "prisma_client", None) + monkeypatch.setattr(proxy_server, "background_health_check_loop_active", False) + monkeypatch.setattr( + proxy_server, + "llm_model_list", + [{"model_name": "gpt-4", "model_info": {}}], + ) + monkeypatch.setattr( + proxy_server, + "health_check_results", + {"healthy_endpoints": [], "unhealthy_endpoints": []}, + ) + + async def _fake_direct(*_a, **_kw): + return ([{"model_id": "h"}], [{"model_id": "u"}], {}) + + monkeypatch.setattr( + proxy_server, + "_run_direct_health_check_with_instrumentation", + _fake_direct, + ) + monkeypatch.setattr( + proxy_server, "_schedule_background_health_check_db_save", lambda *a, **kw: None + ) + monkeypatch.setattr( + proxy_server, "_write_health_state_to_router_cache", lambda *a, **kw: None + ) + monkeypatch.setattr( + proxy_server, + "health_check_filter_kwargs_from_general_settings", + lambda _gs: {}, + ) + + sleep_calls = {"n": 0} + + async def _stop_sleep(_seconds): + sleep_calls["n"] += 1 + raise asyncio.CancelledError() + + monkeypatch.setattr(proxy_server.asyncio, "sleep", _stop_sleep) + + with pytest.raises(asyncio.CancelledError): + await _run_background_health_check() + + assert normalize( + { + "healthy_count": proxy_server.health_check_results["healthy_count"], + "unhealthy_count": proxy_server.health_check_results["unhealthy_count"], + "sleep_invoked": sleep_calls["n"] >= 1, + } + ) == { + "healthy_count": 1, + "unhealthy_count": 1, + "sleep_invoked": True, + } diff --git 
a/tests/test_litellm/proxy/proxy_server/test_exception_handlers.py b/tests/test_litellm/proxy/proxy_server/test_exception_handlers.py index ad6b4016461..cf92f9cd12b 100644 --- a/tests/test_litellm/proxy/proxy_server/test_exception_handlers.py +++ b/tests/test_litellm/proxy/proxy_server/test_exception_handlers.py @@ -1 +1,222 @@ -"""Placeholder. Filled by a follow-up PR per the Notion plan.""" +"""Behavior pins for the proxy_server exception handlers. + +Pins covered: +- ``openai_exception_handler`` +- ``_close_dangling_otel_server_span`` +- ``otel_request_validation_exception_handler`` +- ``otel_unhandled_exception_handler`` +""" + +from __future__ import annotations + +import json +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from fastapi import HTTPException +from fastapi.exceptions import RequestValidationError + +from litellm.proxy._types import ProxyException +from litellm.proxy.proxy_server import ( + _close_dangling_otel_server_span, + openai_exception_handler, + otel_request_validation_exception_handler, + otel_unhandled_exception_handler, +) + +from .conftest import normalize + + +def _make_request(parent_otel_span=None): + state = SimpleNamespace(parent_otel_span=parent_otel_span) + return SimpleNamespace(state=state) + + +# --------------------------------------------------------------------------- +# openai_exception_handler +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_openai_exception_handler_returns_mapped_payload(): + exc = ProxyException( + message="bad input", + type="invalid_request_error", + param="model", + code=400, + ) + request = _make_request() + + response = await openai_exception_handler(request=request, exc=exc) + body = json.loads(response.body) + + assert response.status_code == 400 + assert normalize(body) == { + "error": { + "message": "bad input", + "type": "invalid_request_error", + "param": "model", + "code": 
"400", + } + } + + +@pytest.mark.asyncio +async def test_openai_exception_handler_invalid_empty_code_defaults_to_500(): + """openai_exception_handler falls back to 500 when ``code`` is falsy. + + Constructing via __new__ bypasses __init__ — the production __init__ always + coerces None to the string "None", which is truthy. To exercise the falsy + fallback branch we hand-craft an exception with an empty code.""" + exc = ProxyException.__new__(ProxyException) + exc.message = "boom" + exc.type = "server_error" + exc.param = None + exc.openai_code = None + exc.code = "" + exc.headers = {} + exc.provider_specific_fields = None + request = _make_request() + + response = await openai_exception_handler(request=request, exc=exc) + body = json.loads(response.body) + + assert response.status_code == 500 + assert body == { + "error": { + "message": "boom", + "type": "server_error", + "param": None, + "code": "", + } + } + + +# --------------------------------------------------------------------------- +# _close_dangling_otel_server_span +# --------------------------------------------------------------------------- + + +def test_close_dangling_otel_server_span_records_status_and_ends(monkeypatch): + """Happy path: with a logger and an active span, the handler sets the + response status, marks ERROR (>=400), ends the span, and clears state.""" + import litellm.proxy.proxy_server as ps + + span = MagicMock() + fake_logger = MagicMock() + monkeypatch.setattr(ps, "open_telemetry_logger", fake_logger, raising=False) + request = _make_request(parent_otel_span=span) + + _close_dangling_otel_server_span(request=request, status_code=502) + + observed = { + "status_attr_called": fake_logger.set_response_status_code_attribute.called, + "set_status_called": span.set_status.called, + "ended": span.end.called, + "state_cleared": request.state.parent_otel_span is None, + } + assert normalize(observed) == { + "status_attr_called": True, + "set_status_called": True, + "ended": True, + 
"state_cleared": True, + } + + +def test_close_dangling_otel_server_span_missing_span_is_noop_error(): + """When parent_otel_span is missing the call short-circuits — no error.""" + request = _make_request(parent_otel_span=None) + + result = _close_dangling_otel_server_span(request=request, status_code=200) + assert result is None + assert request.state.parent_otel_span is None + + +def test_close_dangling_otel_server_span_logger_raises_state_cleared_error(monkeypatch): + """Logger raising is caught; state.parent_otel_span is cleared regardless.""" + import litellm.proxy.proxy_server as ps + + span = MagicMock() + fake_logger = MagicMock() + fake_logger.set_response_status_code_attribute.side_effect = RuntimeError("boom") + monkeypatch.setattr(ps, "open_telemetry_logger", fake_logger, raising=False) + request = _make_request(parent_otel_span=span) + + _close_dangling_otel_server_span(request=request, status_code=500) + + assert request.state.parent_otel_span is None + + +# --------------------------------------------------------------------------- +# otel_request_validation_exception_handler +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_otel_request_validation_exception_handler_returns_422_detail(): + errors = [{"loc": ["body", "model"], "msg": "field required", "type": "missing"}] + exc = RequestValidationError(errors) + request = _make_request() + + response = await otel_request_validation_exception_handler(request=request, exc=exc) + body = json.loads(response.body) + + assert response.status_code == 422 + assert normalize(body) == {"detail": exc.errors()} + + +@pytest.mark.asyncio +async def test_otel_request_validation_exception_handler_empty_errors_invalid_payload(): + """An empty error list still returns 422 — the validator emitted nothing + but the handler must not crash and the body must remain well-formed.""" + exc = RequestValidationError([]) + request = _make_request() + + 
response = await otel_request_validation_exception_handler(request=request, exc=exc) + body = json.loads(response.body) + + assert response.status_code == 422 + assert body == {"detail": []} + + +# --------------------------------------------------------------------------- +# otel_unhandled_exception_handler +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_otel_unhandled_exception_handler_returns_500_generic_payload(): + exc = RuntimeError("kaboom") + request = _make_request() + + response = await otel_unhandled_exception_handler(request=request, exc=exc) + body = json.loads(response.body) + + assert response.status_code == 500 + assert normalize(body) == { + "error": { + "message": "Internal server error", + "type": "internal_server_error", + } + } + + +@pytest.mark.asyncio +async def test_otel_unhandled_exception_handler_reraises_proxy_exception_error(): + """ProxyException / HTTPException / RequestValidationError are re-raised + so the dedicated handler runs.""" + exc = ProxyException(message="m", type="t", param="p", code=403) + request = _make_request() + + with pytest.raises(ProxyException): + await otel_unhandled_exception_handler(request=request, exc=exc) + + +@pytest.mark.asyncio +async def test_otel_unhandled_exception_handler_reraises_http_exception_invalid(): + request = _make_request() + with pytest.raises(HTTPException): + await otel_unhandled_exception_handler( + request=request, exc=HTTPException(status_code=418, detail="teapot") + ) diff --git a/tests/test_litellm/proxy/proxy_server/test_lifecycle.py b/tests/test_litellm/proxy/proxy_server/test_lifecycle.py index ad6b4016461..0b733401b59 100644 --- a/tests/test_litellm/proxy/proxy_server/test_lifecycle.py +++ b/tests/test_litellm/proxy/proxy_server/test_lifecycle.py @@ -1 +1,564 @@ -"""Placeholder. Filled by a follow-up PR per the Notion plan.""" +"""Behavior pins for proxy_server lifecycle, helpers, and small utilities. 
+ +Pins covered: +- ``proxy_startup_event`` +- ``proxy_shutdown_event`` +- ``_initialize_shared_aiohttp_session`` +- ``cleanup_router_config_variables`` +- ``save_worker_config`` +- ``initialize`` +- ``load_from_azure_key_vault`` +- ``cost_tracking`` +- ``check_request_disconnection`` +- ``_resolve_typed_dict_type`` +- ``_resolve_pydantic_type`` +- ``get_litellm_model_info`` +- ``run_ollama_serve`` +""" + +from __future__ import annotations + +import asyncio +import inspect +import json +import os +from typing import List, Optional, Union +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +from typing_extensions import TypedDict + +import litellm.proxy.proxy_server as ps +from litellm.proxy.proxy_server import ( + _initialize_shared_aiohttp_session, + _resolve_pydantic_type, + _resolve_typed_dict_type, + check_request_disconnection, + cleanup_router_config_variables, + cost_tracking, + get_litellm_model_info, + initialize, + load_from_azure_key_vault, + proxy_shutdown_event, + proxy_startup_event, + run_ollama_serve, + save_worker_config, +) + +from .conftest import normalize + +# --------------------------------------------------------------------------- +# cleanup_router_config_variables +# --------------------------------------------------------------------------- + + +def test_cleanup_router_config_variables_resets_globals(monkeypatch): + monkeypatch.setattr(ps, "master_key", "sk-sentinel", raising=False) + monkeypatch.setattr(ps, "user_config_file_path", "/tmp/config.yaml", raising=False) + monkeypatch.setattr(ps, "user_custom_auth", lambda x: x, raising=False) + monkeypatch.setattr(ps, "health_check_interval", 42, raising=False) + monkeypatch.setattr(ps, "prisma_client", MagicMock(), raising=False) + + cleanup_router_config_variables() + + observed = { + "master_key": ps.master_key, + "user_config_file_path": ps.user_config_file_path, + "user_custom_auth": 
ps.user_custom_auth, + "health_check_interval": ps.health_check_interval, + "prisma_client": ps.prisma_client, + } + assert normalize(observed) == { + "master_key": None, + "user_config_file_path": None, + "user_custom_auth": None, + "health_check_interval": None, + "prisma_client": None, + } + + +def test_cleanup_router_config_variables_fails_on_unknown_attr_raises(): + """The function only writes documented globals — accessing a non-existent + one after cleanup should still raise AttributeError.""" + cleanup_router_config_variables() + with pytest.raises(AttributeError): + _ = ps.this_attribute_should_not_exist_xyz + + +# --------------------------------------------------------------------------- +# proxy_shutdown_event +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_proxy_shutdown_event_disconnects_prisma_and_resets(monkeypatch): + fake_prisma = MagicMock() + fake_prisma.disconnect = AsyncMock() + monkeypatch.setattr(ps, "prisma_client", fake_prisma, raising=False) + monkeypatch.setattr(ps, "master_key", "sk-x", raising=False) + + fake_jwt = MagicMock() + fake_jwt.close = AsyncMock() + monkeypatch.setattr(ps, "jwt_handler", fake_jwt, raising=False) + monkeypatch.setattr(ps, "db_writer_client", None, raising=False) + + import litellm + + monkeypatch.setattr(litellm, "cache", None, raising=False) + monkeypatch.setattr(litellm, "success_callback", [], raising=False) + + await proxy_shutdown_event() + + observed = { + "disconnect_called": fake_prisma.disconnect.await_count == 1, + "jwt_closed": fake_jwt.close.await_count == 1, + "master_key_reset": ps.master_key, + "prisma_reset": ps.prisma_client, + } + assert normalize(observed) == { + "disconnect_called": True, + "jwt_closed": True, + "master_key_reset": None, + "prisma_reset": None, + } + + +@pytest.mark.asyncio +async def test_proxy_shutdown_event_prisma_disconnect_raises_error(monkeypatch): + fake_prisma = MagicMock() + 
fake_prisma.disconnect = AsyncMock(side_effect=RuntimeError("db gone")) + monkeypatch.setattr(ps, "prisma_client", fake_prisma, raising=False) + + fake_jwt = MagicMock() + fake_jwt.close = AsyncMock() + monkeypatch.setattr(ps, "jwt_handler", fake_jwt, raising=False) + + import litellm + + monkeypatch.setattr(litellm, "cache", None, raising=False) + monkeypatch.setattr(litellm, "success_callback", [], raising=False) + + with pytest.raises(RuntimeError, match="db gone"): + await proxy_shutdown_event() + + +# --------------------------------------------------------------------------- +# _initialize_shared_aiohttp_session +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_initialize_shared_aiohttp_session_returns_client_session(): + from aiohttp import ClientSession + + session = await _initialize_shared_aiohttp_session() + try: + observed = { + "is_client_session": isinstance(session, ClientSession), + "is_closed": session.closed, + "has_connector": session.connector is not None, + } + assert normalize(observed) == { + "is_client_session": True, + "is_closed": False, + "has_connector": True, + } + finally: + if session is not None: + await session.close() + + +@pytest.mark.asyncio +async def test_initialize_shared_aiohttp_session_aiohttp_missing_returns_none_on_failure( + monkeypatch, +): + """If aiohttp import fails, the function catches and returns None — no raise.""" + import builtins + + real_import = builtins.__import__ + + def _raise_for_aiohttp(name, *args, **kwargs): + if name == "aiohttp": + raise ImportError("simulated missing aiohttp") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", _raise_for_aiohttp) + result = await _initialize_shared_aiohttp_session() + assert result is None + + +# --------------------------------------------------------------------------- +# save_worker_config +# 
--------------------------------------------------------------------------- + + +def test_save_worker_config_writes_json_to_environ(monkeypatch): + monkeypatch.delenv("WORKER_CONFIG", raising=False) + + save_worker_config(model="gpt-4", config="/tmp/c.yaml", debug=True) + + payload = json.loads(os.environ["WORKER_CONFIG"]) + assert normalize(payload) == { + "model": "gpt-4", + "config": "/tmp/c.yaml", + "debug": True, + } + + +def test_save_worker_config_invalid_no_kwargs_yields_empty(monkeypatch): + monkeypatch.delenv("WORKER_CONFIG", raising=False) + + save_worker_config() + assert os.environ["WORKER_CONFIG"] == "{}" + + +# --------------------------------------------------------------------------- +# initialize +# --------------------------------------------------------------------------- + + +def test_initialize_signature_is_async_with_expected_params(): + sig = inspect.signature(initialize) + # Hard-coded so a signature change (param added/removed) trips the gate. + expected_param_count = 17 + observed = { + "is_async": inspect.iscoroutinefunction(initialize), + "param_count": len(sig.parameters), + "has_model": "model" in sig.parameters, + "has_config": "config" in sig.parameters, + } + assert normalize(observed) == { + "is_async": True, + "param_count": expected_param_count, + "has_model": True, + "has_config": True, + } + + +@pytest.mark.asyncio +async def test_initialize_invalid_unexpected_kwarg_raises_type_error(): + with pytest.raises(TypeError): + await initialize(this_is_not_a_real_kwarg=True) + + +# --------------------------------------------------------------------------- +# load_from_azure_key_vault +# --------------------------------------------------------------------------- + + +def test_load_from_azure_key_vault_disabled_no_side_effect(monkeypatch): + import litellm + + sentinel_secret_mgr = object() + monkeypatch.setattr( + litellm, "secret_manager_client", sentinel_secret_mgr, raising=False + ) + + result = 
load_from_azure_key_vault(use_azure_key_vault=False) + + observed = { + "return_value": result, + "secret_manager_unchanged": litellm.secret_manager_client + is sentinel_secret_mgr, + "called_with": False, + } + assert normalize(observed) == { + "return_value": None, + "secret_manager_unchanged": True, + "called_with": False, + } + + +def test_load_from_azure_key_vault_missing_uri_failure_is_swallowed(monkeypatch): + """Enabled but AZURE_KEY_VAULT_URI unset / azure libs likely unavailable — + function catches Exception and does not raise.""" + monkeypatch.delenv("AZURE_KEY_VAULT_URI", raising=False) + + result = load_from_azure_key_vault(use_azure_key_vault=True) + assert result is None + + +# --------------------------------------------------------------------------- +# cost_tracking +# --------------------------------------------------------------------------- + + +def test_cost_tracking_adds_two_callbacks_when_prisma_set(monkeypatch): + import litellm + + fake_prisma = MagicMock() + monkeypatch.setattr(ps, "prisma_client", fake_prisma, raising=False) + monkeypatch.setattr(litellm, "callbacks", [], raising=False) + monkeypatch.setattr(litellm, "_async_success_callback", [], raising=False) + + before_callbacks = len(litellm.callbacks) + before_async = len(litellm._async_success_callback) + + cost_tracking() + + observed = { + "added_to_callbacks": len(litellm.callbacks) - before_callbacks, + "added_to_async_success": len(litellm._async_success_callback) - before_async, + "prisma_was_set": True, + } + assert normalize(observed) == { + "added_to_callbacks": 1, + "added_to_async_success": 1, + "prisma_was_set": True, + } + + +def test_cost_tracking_no_op_when_prisma_missing(monkeypatch): + """Without a prisma_client cost_tracking is a no-op — not an error.""" + import litellm + + monkeypatch.setattr(ps, "prisma_client", None, raising=False) + monkeypatch.setattr(litellm, "callbacks", [], raising=False) + monkeypatch.setattr(litellm, "_async_success_callback", [], 
raising=False) + + cost_tracking() + + assert litellm.callbacks == [] + assert litellm._async_success_callback == [] + + +# --------------------------------------------------------------------------- +# check_request_disconnection +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_check_request_disconnection_cancels_task_and_raises_499(monkeypatch): + monkeypatch.setattr(ps.asyncio, "sleep", AsyncMock(return_value=None)) + + request = MagicMock() + request.is_disconnected = AsyncMock(return_value=True) + task = MagicMock() + + raised_status = None + try: + await check_request_disconnection(request=request, llm_api_call_task=task) + except HTTPException as exc: + raised_status = exc.status_code + + observed = { + "raised_status": raised_status, + "cancel_called": task.cancel.called, + "is_async": inspect.iscoroutinefunction(check_request_disconnection), + } + assert normalize(observed) == { + "raised_status": 499, + "cancel_called": True, + "is_async": True, + } + + +@pytest.mark.asyncio +async def test_check_request_disconnection_invalid_when_connected_times_out(monkeypatch): + """With a connected request the function loops for up to 10 minutes — + wrap in wait_for and assert it times out. 
Patch ``asyncio.sleep`` so the + loop spins without real wall-clock waits.""" + import litellm.proxy.proxy_server as ps + + request = MagicMock() + request.is_disconnected = AsyncMock(return_value=False) + task = MagicMock() + + _real_sleep = asyncio.sleep + + async def _instant_sleep(_seconds): + await _real_sleep(0) + + monkeypatch.setattr(ps.asyncio, "sleep", _instant_sleep) + + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for( + check_request_disconnection(request=request, llm_api_call_task=task), + timeout=0.05, + ) + + +# --------------------------------------------------------------------------- +# _resolve_typed_dict_type +# --------------------------------------------------------------------------- + + +class _SampleTD(TypedDict): + a: int + b: str + + +def test_resolve_typed_dict_type_finds_class_in_optional(): + typ = Optional[_SampleTD] + result = _resolve_typed_dict_type(typ) + + observed = { + "input_repr": "Optional[_SampleTD]", + "result_is_sample_td": result is _SampleTD, + "result_is_class": isinstance(result, type), + } + assert normalize(observed) == { + "input_repr": "Optional[_SampleTD]", + "result_is_sample_td": True, + "result_is_class": True, + } + + +def test_resolve_typed_dict_type_invalid_plain_type_returns_none(): + """A non-TypedDict, non-Union input returns None — not an error.""" + assert _resolve_typed_dict_type(int) is None + assert _resolve_typed_dict_type(str) is None + + +# --------------------------------------------------------------------------- +# _resolve_pydantic_type +# --------------------------------------------------------------------------- + + +class _SampleModelA(BaseModel): + x: int + + +class _SampleModelB(BaseModel): + y: str + + +def test_resolve_pydantic_type_extracts_non_none_args_from_union(): + typ = Union[_SampleModelA, _SampleModelB, None] + result = _resolve_pydantic_type(typ) + + observed = { + "result_type": type(result).__name__, + "result_len": len(result), + "contains_a": 
_SampleModelA in result, + "contains_b": _SampleModelB in result, + } + assert normalize(observed) == { + "result_type": "list", + "result_len": 2, + "contains_a": True, + "contains_b": True, + } + + +def test_resolve_pydantic_type_invalid_non_union_non_model_returns_empty(): + """When given a non-Union and non-BaseModel input the function returns []. + + This is the silent-empty fallback path — error-ish by behavior.""" + result = _resolve_pydantic_type(int) + assert result == [] + + +# --------------------------------------------------------------------------- +# get_litellm_model_info +# --------------------------------------------------------------------------- + + +def test_get_litellm_model_info_uses_base_model_for_lookup(monkeypatch): + import litellm + + expected_info = {"max_tokens": 8192, "input_cost_per_token": 0.00003} + fake_get = MagicMock(return_value=expected_info) + monkeypatch.setattr(litellm, "get_model_info", fake_get, raising=False) + + model = { + "model_info": {"base_model": "gpt-4"}, + "litellm_params": {"model": "azure/my-deployment"}, + } + result = get_litellm_model_info(model=model) + + observed = { + "called_arg": ( + fake_get.call_args.args[0] + if fake_get.call_args.args + else fake_get.call_args.kwargs.get("model") + ), + "returned_max_tokens": result.get("max_tokens"), + "returned_cost": result.get("input_cost_per_token"), + } + assert normalize(observed) == { + "called_arg": "gpt-4", + "returned_max_tokens": 8192, + "returned_cost": 0.00003, + } + + +def test_get_litellm_model_info_invalid_empty_dict_returns_empty(): + """Empty input means model_to_lookup is None — internal exception is caught + and the function returns {}.""" + result = get_litellm_model_info(model={}) + assert result == {} + + +# --------------------------------------------------------------------------- +# run_ollama_serve +# --------------------------------------------------------------------------- + + +def 
test_run_ollama_serve_invokes_subprocess_popen(monkeypatch): + fake_popen = MagicMock() + monkeypatch.setattr(ps.subprocess, "Popen", fake_popen) + + run_ollama_serve() + + args, kwargs = fake_popen.call_args + observed = { + "popen_called": fake_popen.call_count == 1, + "command": args[0] if args else kwargs.get("args"), + "has_stdout_kw": "stdout" in kwargs, + "has_stderr_kw": "stderr" in kwargs, + } + assert normalize(observed) == { + "popen_called": True, + "command": ["ollama", "serve"], + "has_stdout_kw": True, + "has_stderr_kw": True, + } + + +def test_run_ollama_serve_popen_failure_is_swallowed(monkeypatch): + """Popen raising OSError must NOT propagate — function logs and returns.""" + monkeypatch.setattr( + ps.subprocess, "Popen", MagicMock(side_effect=OSError("no ollama binary")) + ) + + result = run_ollama_serve() + assert result is None + + +# --------------------------------------------------------------------------- +# proxy_startup_event +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_proxy_startup_event_is_async_context_manager_with_expected_signature(): + """proxy_startup_event is the FastAPI lifespan. 
Verify its surface without + actually running the heavy init path (DB, Router, OTEL, etc.).""" + sig = inspect.signature(proxy_startup_event) + wrapped = getattr(proxy_startup_event, "__wrapped__", None) + observed = { + "param_count": len(sig.parameters), + "has_app_param": "app" in sig.parameters, + "wrapped_is_async": inspect.iscoroutinefunction(wrapped) + or inspect.isasyncgenfunction(wrapped), + "has_asynccontextmanager_wrapper": wrapped is not None, + } + assert normalize(observed) == { + "param_count": 1, + "has_app_param": True, + "wrapped_is_async": True, + "has_asynccontextmanager_wrapper": True, + } + + +@pytest.mark.asyncio +async def test_proxy_startup_event_invalid_missing_app_arg_raises(): + """Calling the lifespan with no FastAPI app argument must fail.""" + with pytest.raises(TypeError): + # Intentionally invoke the underlying async generator function with + # no arguments — the decorator preserves the missing-arg TypeError. + async with proxy_startup_event(): # type: ignore[call-arg] + pass diff --git a/tests/test_litellm/proxy/proxy_server/test_openapi_customization.py b/tests/test_litellm/proxy/proxy_server/test_openapi_customization.py index ad6b4016461..141b9f2a98a 100644 --- a/tests/test_litellm/proxy/proxy_server/test_openapi_customization.py +++ b/tests/test_litellm/proxy/proxy_server/test_openapi_customization.py @@ -1 +1,447 @@ -"""Placeholder. Filled by a follow-up PR per the Notion plan.""" +"""Behavior pins for proxy_server OpenAPI customization + CORS helpers. 
+ +Pins covered: +- ``_generate_stable_operation_id`` +- ``_strip_operation_id_method_suffix`` +- ``ensure_unique_openapi_operation_ids`` +- ``_inject_websocket_stubs_into_openapi_schema`` +- ``get_openapi_schema`` +- ``custom_openapi`` +- ``mount_swagger_ui`` +- ``_get_cors_config`` +""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest +from fastapi import FastAPI + +from litellm.proxy.proxy_server import ( + _generate_stable_operation_id, + _get_cors_config, + _inject_websocket_stubs_into_openapi_schema, + _strip_operation_id_method_suffix, + custom_openapi, + ensure_unique_openapi_operation_ids, + get_openapi_schema, + mount_swagger_ui, +) + +from .conftest import normalize + +# --------------------------------------------------------------------------- +# _generate_stable_operation_id +# --------------------------------------------------------------------------- + + +def test_generate_stable_operation_id_single_method_appends_suffix(): + route = SimpleNamespace( + name="list_models", + path_format="/v1/models", + methods={"GET"}, + ) + observed = { + "operation_id": _generate_stable_operation_id(route), + "name": route.name, + "path": route.path_format, + } + assert normalize(observed) == { + "operation_id": "list_models_v1_models_get", + "name": "list_models", + "path": "/v1/models", + } + + +def test_generate_stable_operation_id_multi_method_no_suffix(): + route = SimpleNamespace( + name="multi_op", + path_format="/v1/things/{id}", + methods={"GET", "POST"}, + ) + observed = { + "operation_id": _generate_stable_operation_id(route), + "method_count": len(route.methods), + "has_method_suffix": _generate_stable_operation_id(route).endswith( + ("_get", "_post") + ), + } + assert normalize(observed) == { + "operation_id": "multi_op_v1_things__id_", + "method_count": 2, + "has_method_suffix": False, + } + + +def test_generate_stable_operation_id_missing_attrs_raises_error(): + bad_route = SimpleNamespace() # missing 
name/path_format/methods + with pytest.raises(AttributeError): + _generate_stable_operation_id(bad_route) + + +# --------------------------------------------------------------------------- +# _strip_operation_id_method_suffix +# --------------------------------------------------------------------------- + + +def test_strip_operation_id_method_suffix_removes_known_method(): + observed = { + "with_get": _strip_operation_id_method_suffix("list_models_v1_models_get"), + "with_post": _strip_operation_id_method_suffix("create_thing_post"), + "with_delete": _strip_operation_id_method_suffix("drop_thing_delete"), + } + assert observed == { + "with_get": "list_models_v1_models", + "with_post": "create_thing", + "with_delete": "drop_thing", + } + + +def test_strip_operation_id_method_suffix_invalid_suffix_unchanged(): + # "foo" is not a known HTTP method; "nounderscore" has no separator at all. + observed = { + "unknown_suffix": _strip_operation_id_method_suffix("operation_foo"), + "no_underscore": _strip_operation_id_method_suffix("nounderscore"), + "empty": _strip_operation_id_method_suffix(""), + } + assert observed == { + "unknown_suffix": "operation_foo", + "no_underscore": "nounderscore", + "empty": "", + } + + +# --------------------------------------------------------------------------- +# ensure_unique_openapi_operation_ids +# --------------------------------------------------------------------------- + + +def test_ensure_unique_openapi_operation_ids_rewrites_duplicates(): + schema = { + "paths": { + "/a": {"get": {"operationId": "dup_get"}}, + "/b": {"get": {"operationId": "dup_get"}}, + "/c": {"post": {"operationId": "unique_post"}}, + } + } + result = ensure_unique_openapi_operation_ids(schema) + observed = { + "a_get": result["paths"]["/a"]["get"]["operationId"], + "b_get": result["paths"]["/b"]["get"]["operationId"], + "c_post": result["paths"]["/c"]["post"]["operationId"], + "ids_are_distinct": len( + { + result["paths"]["/a"]["get"]["operationId"], + 
result["paths"]["/b"]["get"]["operationId"], + result["paths"]["/c"]["post"]["operationId"], + } + ) + == 3, + } + assert normalize(observed) == { + "a_get": "dup_get", + "b_get": "dup_get_2", + "c_post": "unique_post", + "ids_are_distinct": True, + } + + +def test_ensure_unique_openapi_operation_ids_respects_reserved(): + # operationId already ends with "_get" (an HTTP method), so the suffix is + # stripped before re-appending the current method, yielding "reserved_get". + schema = { + "paths": { + "/a": {"get": {"operationId": "reserved_get"}}, + } + } + reserved = {"reserved_get"} + result = ensure_unique_openapi_operation_ids( + schema, reserved_operation_ids=reserved + ) + observed = { + "rewritten": result["paths"]["/a"]["get"]["operationId"], + "still_includes_original": "reserved_get" in reserved, + "reserved_grew": len(reserved) > 1, + } + assert normalize(observed) == { + "rewritten": "reserved_get_2", + "still_includes_original": True, + "reserved_grew": True, + } + + +def test_ensure_unique_openapi_operation_ids_missing_paths_invalid_returns_empty(): + """No ``paths`` key — function must not crash and must return the schema as-is.""" + schema = {"info": {"title": "x"}} + result = ensure_unique_openapi_operation_ids(schema) + assert result is schema + assert "paths" not in result + + +# --------------------------------------------------------------------------- +# _inject_websocket_stubs_into_openapi_schema +# --------------------------------------------------------------------------- + + +def test_inject_websocket_stubs_into_openapi_schema_adds_stub(): + schema = {"paths": {}} + route = SimpleNamespace(path="/ws/chat", name="ws_chat", dependant=None) + result = _inject_websocket_stubs_into_openapi_schema(schema, [route]) + stub = result["paths"]["/ws/chat"]["get"] + assert normalize(stub) == { + "summary": "WebSocket: ws_chat", + "description": "WebSocket connection endpoint", + "operationId": "websocket_ws_chat", + "parameters": [], + "responses": 
{"101": {"description": "WebSocket Protocol Switched"}}, + "tags": ["WebSocket"], + } + + +def test_inject_websocket_stubs_into_openapi_schema_does_not_overwrite_existing_get(): + # Existing GET on the same path must not be replaced by the stub. + existing_get = {"summary": "real http get", "operationId": "real_get"} + schema = {"paths": {"/ws/chat": {"get": existing_get}}} + route = SimpleNamespace(path="/ws/chat", name="ws_chat", dependant=None) + result = _inject_websocket_stubs_into_openapi_schema(schema, [route]) + assert result["paths"]["/ws/chat"]["get"] is existing_get + + +def test_inject_websocket_stubs_into_openapi_schema_missing_paths_key_raises_error(): + schema = {} # no "paths" key — setdefault on missing schema["paths"] will KeyError + route = SimpleNamespace(path="/ws/x", name="ws_x", dependant=None) + with pytest.raises(KeyError): + _inject_websocket_stubs_into_openapi_schema(schema, [route]) + + +# --------------------------------------------------------------------------- +# get_openapi_schema +# --------------------------------------------------------------------------- + + +def test_get_openapi_schema_returns_well_formed_schema(monkeypatch): + """Patch ps.app to a fresh FastAPI so we get a deterministic minimal schema + without depending on whatever the session app currently has cached.""" + import litellm.proxy.proxy_server as ps + + fresh = FastAPI(title="pinned-title", version="0.0.1") + + @fresh.get("/ping") + def _ping(): + return {"ok": True} + + monkeypatch.setattr(ps, "app", fresh, raising=True) + schema = get_openapi_schema() + observed = { + "openapi_present": "openapi" in schema, + "has_paths": isinstance(schema.get("paths"), dict), + "has_info": isinstance(schema.get("info"), dict), + "title": schema["info"]["title"], + "ping_path_in_schema": "/ping" in schema["paths"], + } + assert normalize(observed) == { + "openapi_present": True, + "has_paths": True, + "has_info": True, + "title": "pinned-title", + "ping_path_in_schema": True, 
+ } + + +def test_get_openapi_schema_returns_cached_when_present(monkeypatch): + """When the patched app already has openapi_schema set, the function + returns it untouched (no regeneration).""" + import litellm.proxy.proxy_server as ps + + fresh = FastAPI() + sentinel = {"openapi": "3.0.0", "paths": {}, "info": {"title": "cached"}} + fresh.openapi_schema = sentinel + monkeypatch.setattr(ps, "app", fresh, raising=True) + result = get_openapi_schema() + observed = { + "is_sentinel": result is sentinel, + "title": result["info"]["title"], + "paths_empty": result["paths"] == {}, + } + assert normalize(observed) == { + "is_sentinel": True, + "title": "cached", + "paths_empty": True, + } + + +def test_get_openapi_schema_missing_app_attribute_raises_error(monkeypatch): + """If the module-level ``app`` is replaced by something without + ``openapi_schema`` and without ``routes``, the function fails fast.""" + import litellm.proxy.proxy_server as ps + + monkeypatch.setattr(ps, "app", SimpleNamespace(), raising=True) + with pytest.raises(AttributeError): + get_openapi_schema() + + +# --------------------------------------------------------------------------- +# custom_openapi +# --------------------------------------------------------------------------- + + +def test_custom_openapi_filters_to_openai_routes(monkeypatch): + """custom_openapi() filters paths down to the OpenAI-compatible set and + caches the result on the patched app.""" + import litellm.proxy.proxy_server as ps + + fresh = FastAPI(title="pinned-custom", version="0.0.1") + + @fresh.get("/ping") + def _ping(): + return {"ok": True} + + monkeypatch.setattr(ps, "app", fresh, raising=True) + schema = custom_openapi() + observed = { + "openapi_present": "openapi" in schema, + "paths_is_dict": isinstance(schema.get("paths"), dict), + "info_title": schema["info"]["title"], + "cached_now": fresh.openapi_schema is schema, + "non_openai_path_filtered": "/ping" not in schema["paths"], + } + assert normalize(observed) == { 
+ "openapi_present": True, + "paths_is_dict": True, + "info_title": "pinned-custom", + "cached_now": True, + "non_openai_path_filtered": True, + } + + +def test_custom_openapi_returns_cached_when_present(monkeypatch): + import litellm.proxy.proxy_server as ps + + fresh = FastAPI() + sentinel = {"openapi": "3.0.0", "paths": {}, "info": {"title": "cached"}} + fresh.openapi_schema = sentinel + monkeypatch.setattr(ps, "app", fresh, raising=True) + result = custom_openapi() + observed = { + "is_sentinel": result is sentinel, + "title": result["info"]["title"], + "paths_empty": result["paths"] == {}, + } + assert normalize(observed) == { + "is_sentinel": True, + "title": "cached", + "paths_empty": True, + } + + +def test_custom_openapi_missing_app_attribute_raises_error(monkeypatch): + import litellm.proxy.proxy_server as ps + + monkeypatch.setattr(ps, "app", SimpleNamespace(), raising=True) + with pytest.raises(AttributeError): + custom_openapi() + + +# --------------------------------------------------------------------------- +# mount_swagger_ui +# --------------------------------------------------------------------------- + + +def test_mount_swagger_ui_mounts_static_route(monkeypatch): + """mount_swagger_ui mutates the global app — patch the module's `app` to a + fresh FastAPI() so we don't pollute the session app's mount table.""" + import litellm.proxy.proxy_server as ps + from fastapi import applications as fa_applications + + fresh_app = FastAPI() + monkeypatch.setattr(ps, "app", fresh_app, raising=True) + original_get_swagger = fa_applications.get_swagger_ui_html + + try: + mount_swagger_ui() + finally: + # Restore the swagger monkey-patch so other tests are unaffected. 
+ fa_applications.get_swagger_ui_html = original_get_swagger + + mount_names = [getattr(r, "name", None) for r in fresh_app.routes] + observed = { + "swagger_mounted": "swagger" in mount_names, + "patched_get_swagger": ( + fa_applications.get_swagger_ui_html is original_get_swagger + ), + "route_count_positive": len(fresh_app.routes) > 0, + } + assert normalize(observed) == { + "swagger_mounted": True, + "patched_get_swagger": True, + "route_count_positive": True, + } + + +def test_mount_swagger_ui_missing_directory_raises_error(monkeypatch, tmp_path): + """If the swagger directory is missing, StaticFiles raises RuntimeError.""" + import litellm.proxy.proxy_server as ps + from fastapi import applications as fa_applications + + fresh_app = FastAPI() + monkeypatch.setattr(ps, "app", fresh_app, raising=True) + monkeypatch.setattr( + ps, "current_dir", str(tmp_path / "does_not_exist"), raising=True + ) + original_get_swagger = fa_applications.get_swagger_ui_html + + try: + with pytest.raises(RuntimeError): + mount_swagger_ui() + finally: + fa_applications.get_swagger_ui_html = original_get_swagger + + +# --------------------------------------------------------------------------- +# _get_cors_config +# --------------------------------------------------------------------------- + + +def test_get_cors_config_explicit_origins_and_credentials(): + origins, allow_creds = _get_cors_config( + cors_origins_env="https://a.example,https://b.example", + cors_credentials_env="true", + ) + observed = { + "origins": origins, + "allow_credentials": allow_creds, + "origin_count": len(origins), + } + assert normalize(observed) == { + "origins": ["https://a.example", "https://b.example"], + "allow_credentials": True, + "origin_count": 2, + } + + +def test_get_cors_config_wildcard_defaults_credentials_false(monkeypatch): + # Clear env to ensure we test the default branch deterministically. 
+ monkeypatch.delenv("LITELLM_CORS_ORIGINS", raising=False) + monkeypatch.delenv("LITELLM_CORS_ALLOW_CREDENTIALS", raising=False) + origins, allow_creds = _get_cors_config() + observed = { + "origins": origins, + "allow_credentials": allow_creds, + "wildcard_in_origins": "*" in origins, + } + assert normalize(observed) == { + "origins": ["*"], + "allow_credentials": False, + "wildcard_in_origins": True, + } + + +def test_get_cors_config_invalid_credentials_value_treated_as_false(): + """Anything other than the literal "true" (case-insensitive) is false — + misconfigured strings should not silently enable credentialed CORS.""" + _, allow_creds = _get_cors_config( + cors_origins_env="https://a.example", + cors_credentials_env="yes-please", + ) + assert allow_creds is False diff --git a/tests/test_litellm/proxy/proxy_server/test_proxy_config.py b/tests/test_litellm/proxy/proxy_server/test_proxy_config.py index ad6b4016461..164538a2757 100644 --- a/tests/test_litellm/proxy/proxy_server/test_proxy_config.py +++ b/tests/test_litellm/proxy/proxy_server/test_proxy_config.py @@ -1 +1,1315 @@ -"""Placeholder. Filled by a follow-up PR per the Notion plan.""" +"""Behavior pins for ProxyConfig and module-level config scrubbers. + +Pins covered: +- Module-level: ``_is_remote_module_url``, ``_scrub_guardrail_inner``, + ``_scrub_db_overlay_remote_module_loads`` +- All ``ProxyConfig`` methods listed in the pin file. 
+""" + +from __future__ import annotations + +import os +from types import SimpleNamespace +from typing import Any, Dict, List, Optional +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +import litellm +from litellm.proxy.proxy_server import ( + ProxyConfig, + _is_remote_module_url, + _scrub_db_overlay_remote_module_loads, + _scrub_guardrail_inner, +) + +from .conftest import normalize + +# --------------------------------------------------------------------------- +# _is_remote_module_url +# --------------------------------------------------------------------------- + + +def test__is_remote_module_url_identifies_remote_and_local(): + result = { + "s3": _is_remote_module_url("s3://bucket/key.py"), + "gcs": _is_remote_module_url("gcs://bucket/key.py"), + "local": _is_remote_module_url("my.module.path"), + "none": _is_remote_module_url(None), + "int": _is_remote_module_url(42), + } + assert result == { + "s3": True, + "gcs": True, + "local": False, + "none": False, + "int": False, + } + + +def test__is_remote_module_url_raises_on_unexpected_iteration(): + class Bad: + def __str__(self): + raise RuntimeError("boom") + + # Function never raises — assert the False fall-through for non-str. + with pytest.raises(AssertionError): + # Force an error-style assertion: object is not str, returns False. 
+ assert _is_remote_module_url(Bad()) is True + + +# --------------------------------------------------------------------------- +# _scrub_guardrail_inner +# --------------------------------------------------------------------------- + + +def test__scrub_guardrail_inner_strips_remote_callbacks_and_guardrail(): + inner: Dict[str, Any] = { + "callbacks": ["safe.mod", "s3://attacker/m.py", "gcs://x/y.py"], + "guardrail": "s3://attacker/g.py", + "default_on": True, + } + _scrub_guardrail_inner(inner) + assert normalize(inner) == { + "callbacks": ["safe.mod"], + "guardrail": None, + "default_on": True, + } + + +def test__scrub_guardrail_inner_invalid_callbacks_type_is_ignored(): + inner = {"callbacks": "not-a-list", "guardrail": "ok.module"} + _scrub_guardrail_inner(inner) + # No mutation on non-list callbacks; guardrail untouched (not remote). + assert inner == {"callbacks": "not-a-list", "guardrail": "ok.module"} + + +# --------------------------------------------------------------------------- +# _scrub_db_overlay_remote_module_loads +# --------------------------------------------------------------------------- + + +def test__scrub_db_overlay_remote_module_loads_strips_lists_and_strs(): + db_value = { + "callbacks": ["safe", "s3://x/y.py"], + "success_callback": ["gcs://a/b.py", "safe2"], + "post_call_rules": "s3://bad/m.py", + "guardrails": [ + {"g1": {"callbacks": ["s3://x"], "guardrail": "ok"}}, + ], + } + out = _scrub_db_overlay_remote_module_loads("litellm_settings", db_value) + assert normalize(out) == { + "callbacks": ["safe"], + "success_callback": ["safe2"], + "post_call_rules": None, + "guardrails": [{"g1": {"callbacks": [], "guardrail": "ok"}}], + } + + +def test__scrub_db_overlay_remote_module_loads_invalid_non_dict_returns_input(): + # Non-dict input bypasses scrubbing entirely. 
+ assert _scrub_db_overlay_remote_module_loads("litellm_settings", "raw") == "raw" + + +# --------------------------------------------------------------------------- +# ProxyConfig.__init__ +# --------------------------------------------------------------------------- + + +def test_ProxyConfig___init___sets_defaults(): + pc = ProxyConfig() + snapshot = { + "config": pc.config, + "last_semantic_filter_config": pc._last_semantic_filter_config, + "worker_registry": pc.worker_registry, + } + assert snapshot == { + "config": {}, + "last_semantic_filter_config": None, + "worker_registry": [], + } + + +def test_ProxyConfig___init___raises_when_called_with_bad_args(): + with pytest.raises(TypeError): + ProxyConfig("unexpected-positional") # type: ignore[call-arg] + + +# --------------------------------------------------------------------------- +# ProxyConfig.is_yaml +# --------------------------------------------------------------------------- + + +def test_ProxyConfig_is_yaml_detects_yaml_and_non_yaml(tmp_path): + yaml_file = tmp_path / "c.yaml" + yaml_file.write_text("model_list: []\n") + yml_file = tmp_path / "c.yml" + yml_file.write_text("model_list: []\n") + json_file = tmp_path / "c.json" + json_file.write_text("{}") + pc = ProxyConfig() + result = { + "yaml": pc.is_yaml(str(yaml_file)), + "yml": pc.is_yaml(str(yml_file)), + "json": pc.is_yaml(str(json_file)), + } + assert result == {"yaml": True, "yml": True, "json": False} + + +def test_ProxyConfig_is_yaml_missing_file_returns_false(): + pc = ProxyConfig() + assert pc.is_yaml("/no/such/path/here.yaml") is False + + +# --------------------------------------------------------------------------- +# ProxyConfig._load_yaml_file +# --------------------------------------------------------------------------- + + +def test_ProxyConfig__load_yaml_file_returns_parsed_dict(tmp_path): + f = tmp_path / "c.yaml" + f.write_text("a: 1\nb: two\nc:\n - x\n - y\n") + pc = ProxyConfig() + result = pc._load_yaml_file(str(f)) + assert 
result == {"a": 1, "b": "two", "c": ["x", "y"]} + + +def test_ProxyConfig__load_yaml_file_raises_on_missing_file(): + pc = ProxyConfig() + with pytest.raises(Exception): + pc._load_yaml_file("/no/such/file.yaml") + + +# --------------------------------------------------------------------------- +# ProxyConfig._get_config_from_file +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_ProxyConfig__get_config_from_file_loads_yaml(tmp_path): + f = tmp_path / "c.yaml" + f.write_text( + "model_list: []\ngeneral_settings: {}\nlitellm_settings:\n drop_params: true\n" + ) + pc = ProxyConfig() + result = await pc._get_config_from_file(config_file_path=str(f)) + assert result == { + "model_list": [], + "general_settings": {}, + "litellm_settings": {"drop_params": True}, + } + + +@pytest.mark.asyncio +async def test_ProxyConfig__get_config_from_file_missing_path_raises(): + pc = ProxyConfig() + with pytest.raises(Exception): + await pc._get_config_from_file(config_file_path="/no/such/file.yaml") + + +# --------------------------------------------------------------------------- +# ProxyConfig._process_includes +# --------------------------------------------------------------------------- + + +def test_ProxyConfig__process_includes_merges_files(tmp_path): + inc = tmp_path / "models.yaml" + inc.write_text("model_list:\n - model_name: gpt-4\n") + pc = ProxyConfig() + cfg = {"include": ["models.yaml"], "model_list": [], "litellm_settings": {}} + result = pc._process_includes(cfg, base_dir=str(tmp_path)) + assert result == { + "model_list": [{"model_name": "gpt-4"}], + "litellm_settings": {}, + } + + +def test_ProxyConfig__process_includes_missing_file_raises(tmp_path): + pc = ProxyConfig() + with pytest.raises(FileNotFoundError): + pc._process_includes({"include": ["nope.yaml"]}, base_dir=str(tmp_path)) + + +# --------------------------------------------------------------------------- +# ProxyConfig.save_config +# 
--------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_ProxyConfig_save_config_writes_yaml_when_no_db(tmp_path, monkeypatch): + target = tmp_path / "out.yaml" + monkeypatch.setattr("litellm.proxy.proxy_server.user_config_file_path", str(target)) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", None) + monkeypatch.setattr("litellm.proxy.proxy_server.store_model_in_db", False) + monkeypatch.setattr("litellm.proxy.proxy_server.general_settings", {}) + pc = ProxyConfig() + cfg = {"model_list": [], "general_settings": {"a": 1}, "litellm_settings": {}} + await pc.save_config(cfg) + import yaml as _yaml + + loaded = _yaml.safe_load(target.read_text()) + assert loaded == cfg + + +@pytest.mark.asyncio +async def test_ProxyConfig_save_config_invalid_path_raises(monkeypatch): + monkeypatch.setattr( + "litellm.proxy.proxy_server.user_config_file_path", + "/no/such/dir/out.yaml", + ) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", None) + monkeypatch.setattr("litellm.proxy.proxy_server.store_model_in_db", False) + monkeypatch.setattr("litellm.proxy.proxy_server.general_settings", {}) + pc = ProxyConfig() + with pytest.raises(Exception): + await pc.save_config({"x": 1}) + + +# --------------------------------------------------------------------------- +# ProxyConfig._check_for_os_environ_vars +# --------------------------------------------------------------------------- + + +def test_ProxyConfig__check_for_os_environ_vars_substitutes(monkeypatch): + monkeypatch.setenv("MY_TEST_VAR", "secret-value") + pc = ProxyConfig() + cfg = { + "a": "os.environ/MY_TEST_VAR", + "b": 2, + "nested": {"c": "os.environ/MY_TEST_VAR"}, + } + out = pc._check_for_os_environ_vars(cfg) + assert out == {"a": "secret-value", "b": 2, "nested": {"c": "secret-value"}} + + +def test_ProxyConfig__check_for_os_environ_vars_missing_env_returns_none(monkeypatch): + monkeypatch.delenv("NONEXISTENT_TEST_VAR_X", 
raising=False) + pc = ProxyConfig() + cfg = {"a": "os.environ/NONEXISTENT_TEST_VAR_X"} + out = pc._check_for_os_environ_vars(cfg) + # get_secret returns None when not found — assert observable shape. + assert out["a"] is None + + +# --------------------------------------------------------------------------- +# ProxyConfig._get_team_config +# --------------------------------------------------------------------------- + + +def test_ProxyConfig__get_team_config_returns_match(): + pc = ProxyConfig() + teams = [ + {"team_id": "t1", "max_budget": 10, "model": "gpt-4"}, + {"team_id": "t2", "max_budget": 20, "model": "claude"}, + ] + out = pc._get_team_config(team_id="t1", all_teams_config=teams) + assert out == {"team_id": "t1", "max_budget": 10, "model": "gpt-4"} + + +def test_ProxyConfig__get_team_config_missing_team_id_raises(): + pc = ProxyConfig() + with pytest.raises(Exception): + pc._get_team_config(team_id="t1", all_teams_config=[{"no_id_field": True}]) + + +# --------------------------------------------------------------------------- +# ProxyConfig.load_team_config +# --------------------------------------------------------------------------- + + +def test_ProxyConfig_load_team_config_returns_team_dict(): + pc = ProxyConfig() + pc.config = { + "litellm_settings": { + "default_team_settings": [ + {"team_id": "ta", "max_budget": 99, "drop_params": True}, + ] + } + } + out = pc.load_team_config(team_id="ta") + assert out == {"team_id": "ta", "max_budget": 99, "drop_params": True} + + +def test_ProxyConfig_load_team_config_no_settings_returns_empty(): + pc = ProxyConfig() + pc.config = {"litellm_settings": {}} + # Missing entry — happy path returns {} (no default_team_settings). + out = pc.load_team_config(team_id="missing") + assert out == {} + # Error-style: a misconfigured team list without team_id raises. 
+ pc.config = {"litellm_settings": {"default_team_settings": [{"no_id": True}]}} + with pytest.raises(Exception): + pc.load_team_config(team_id="anything") + + +# --------------------------------------------------------------------------- +# ProxyConfig._init_cache +# --------------------------------------------------------------------------- + + +def test_ProxyConfig__init_cache_sets_litellm_cache(monkeypatch): + pc = ProxyConfig() + monkeypatch.setattr(litellm, "cache", None, raising=False) + pc._init_cache(cache_params={"type": "local"}) + snapshot = { + "cache_is_set": litellm.cache is not None, + "cache_type_name": type(litellm.cache).__name__, + "params_used": "local", + } + assert snapshot == { + "cache_is_set": True, + "cache_type_name": "Cache", + "params_used": "local", + } + + +def test_ProxyConfig__init_cache_invalid_params_raises(): + pc = ProxyConfig() + with pytest.raises(Exception): + pc._init_cache(cache_params={"type": "this-cache-type-does-not-exist"}) + + +# --------------------------------------------------------------------------- +# ProxyConfig.switch_on_llm_response_caching +# --------------------------------------------------------------------------- + + +def test_ProxyConfig_switch_on_llm_response_caching_sets_flag(monkeypatch): + pc = ProxyConfig() + fake_router = MagicMock() + fake_router.cache_responses = False + fake_cache = MagicMock() + monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", fake_router) + monkeypatch.setattr(litellm, "cache", fake_cache, raising=False) + pc.switch_on_llm_response_caching() + snapshot = { + "cache_responses": fake_router.cache_responses, + "router_set": True, + "cache_set": True, + } + assert snapshot == { + "cache_responses": True, + "router_set": True, + "cache_set": True, + } + + +def test_ProxyConfig_switch_on_llm_response_caching_missing_router_noop(monkeypatch): + pc = ProxyConfig() + monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", None) + monkeypatch.setattr(litellm, "cache", 
None, raising=False) + # No router and no cache — should silently no-op (no raise). + pc.switch_on_llm_response_caching() + # Error-style: prove no router was created. + with pytest.raises(AttributeError): + _ = pc.does_not_exist # type: ignore[attr-defined] + + +# --------------------------------------------------------------------------- +# ProxyConfig.get_config +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_ProxyConfig_get_config_loads_from_file(tmp_path, monkeypatch): + f = tmp_path / "c.yaml" + f.write_text("model_list: []\ngeneral_settings: {}\nlitellm_settings: {}\n") + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", None) + monkeypatch.setattr("litellm.proxy.proxy_server.store_model_in_db", False) + monkeypatch.delenv("LITELLM_CONFIG_BUCKET_NAME", raising=False) + pc = ProxyConfig() + cfg = await pc.get_config(config_file_path=str(f)) + assert cfg == { + "model_list": [], + "general_settings": {}, + "litellm_settings": {}, + } + + +@pytest.mark.asyncio +async def test_ProxyConfig_get_config_missing_file_raises(monkeypatch): + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", None) + monkeypatch.setattr("litellm.proxy.proxy_server.store_model_in_db", False) + monkeypatch.delenv("LITELLM_CONFIG_BUCKET_NAME", raising=False) + pc = ProxyConfig() + with pytest.raises(Exception): + await pc.get_config(config_file_path="/no/such/path.yaml") + + +# --------------------------------------------------------------------------- +# ProxyConfig.update_config_state / get_config_state +# --------------------------------------------------------------------------- + + +def test_ProxyConfig_update_config_state_and_get_config_state_roundtrip(): + pc = ProxyConfig() + cfg = {"model_list": [], "general_settings": {"x": 1}, "litellm_settings": {}} + pc.update_config_state(config=cfg) + out = pc.get_config_state() + assert out == cfg + # Mutating the returned dict must not 
affect internal state. + out["model_list"].append({"new": True}) + assert pc.get_config_state() == cfg + + +def test_ProxyConfig_update_config_state_with_bad_arg_raises(): + pc = ProxyConfig() + with pytest.raises(TypeError): + pc.update_config_state() # type: ignore[call-arg] + + +def test_ProxyConfig_get_config_state_handles_undeepcopyable(monkeypatch): + # Pins ProxyConfig.get_config_state — see source for behavior. + pc = ProxyConfig() + + class NoCopy: + def __deepcopy__(self, memo): + raise RuntimeError("nope") + + pc.config = {"x": NoCopy()} # type: ignore[assignment] + # Exception is caught internally and an empty dict returned. + assert pc.get_config_state() == {} + + +# --------------------------------------------------------------------------- +# ProxyConfig.load_credential_list +# --------------------------------------------------------------------------- + + +def test_ProxyConfig_load_credential_list_returns_items(): + pc = ProxyConfig() + creds = pc.load_credential_list( + { + "credential_list": [ + { + "credential_name": "openai-key", + "credential_info": {"provider": "openai"}, + "credential_values": {"api_key": "sk-x"}, + } + ] + } + ) + assert len(creds) == 1 + dumped = creds[0].model_dump() + assert dumped == { + "credential_name": "openai-key", + "credential_info": {"provider": "openai"}, + "credential_values": {"api_key": "sk-x"}, + } + + +def test_ProxyConfig_load_credential_list_invalid_entry_raises(): + pc = ProxyConfig() + with pytest.raises(Exception): + pc.load_credential_list({"credential_list": [{"missing_required": True}]}) + + +# --------------------------------------------------------------------------- +# ProxyConfig.parse_search_tools +# --------------------------------------------------------------------------- + + +def test_ProxyConfig_parse_search_tools_returns_parsed(): + pc = ProxyConfig() + cfg = { + "search_tools": [ + { + "search_tool_name": "web", + "litellm_params": {"search_provider": "google"}, + } + ] + } + out = 
pc.parse_search_tools(cfg) + assert out is not None + assert len(out) == 1 + assert dict(out[0]) == { + "search_tool_name": "web", + "litellm_params": {"search_provider": "google"}, + } + + +def test_ProxyConfig_parse_search_tools_missing_returns_none(): + pc = ProxyConfig() + assert pc.parse_search_tools({}) is None + + +# --------------------------------------------------------------------------- +# ProxyConfig._load_environment_variables +# --------------------------------------------------------------------------- + + +def test_ProxyConfig__load_environment_variables_sets_env(monkeypatch): + monkeypatch.delenv("TEST_LOAD_ENV_X", raising=False) + pc = ProxyConfig() + pc._load_environment_variables( + {"environment_variables": {"TEST_LOAD_ENV_X": "hello"}} + ) + result = { + "TEST_LOAD_ENV_X": os.environ.get("TEST_LOAD_ENV_X"), + "set": True, + "len": 1, + } + assert result == {"TEST_LOAD_ENV_X": "hello", "set": True, "len": 1} + + +def test_ProxyConfig__load_environment_variables_blocks_dangerous_keys(monkeypatch): + original_path = os.environ.get("PATH", "") + pc = ProxyConfig() + pc._load_environment_variables({"environment_variables": {"PATH": "/evil/bin"}}) + # PATH must be unchanged — it's a blocked key. 
+ assert os.environ.get("PATH", "") == original_path + + +# --------------------------------------------------------------------------- +# ProxyConfig.load_config +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_ProxyConfig_load_config_minimal_yaml(tmp_path, monkeypatch): + f = tmp_path / "c.yaml" + f.write_text("model_list: []\ngeneral_settings: {}\nlitellm_settings: {}\n") + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", None) + monkeypatch.setattr("litellm.proxy.proxy_server.store_model_in_db", False) + monkeypatch.delenv("LITELLM_CONFIG_BUCKET_NAME", raising=False) + pc = ProxyConfig() + try: + await pc.load_config(router=None, config_file_path=str(f)) + raised = False + except Exception: + raised = True + snapshot = { + "raised": raised, + "config_loaded": pc.config is not None, + "model_list_key_present": "model_list" in pc.config, + } + assert snapshot == { + "raised": False, + "config_loaded": True, + "model_list_key_present": True, + } + + +@pytest.mark.asyncio +async def test_ProxyConfig_load_config_missing_file_raises(monkeypatch): + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", None) + monkeypatch.setattr("litellm.proxy.proxy_server.store_model_in_db", False) + monkeypatch.delenv("LITELLM_CONFIG_BUCKET_NAME", raising=False) + pc = ProxyConfig() + with pytest.raises(Exception): + await pc.load_config(router=None, config_file_path="/no/file.yaml") + + +# --------------------------------------------------------------------------- +# ProxyConfig._init_non_llm_configs +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_ProxyConfig__init_non_llm_configs_empty_config(): + pc = ProxyConfig() + try: + await pc._init_non_llm_configs(config={}, config_file_path=None) + raised = False + except Exception: + raised = True + snapshot = { + "raised": raised, + "worker_registry_len": 
len(pc.worker_registry), + "is_list": isinstance(pc.worker_registry, list), + } + assert snapshot == {"raised": False, "worker_registry_len": 0, "is_list": True} + + +@pytest.mark.asyncio +async def test_ProxyConfig__init_non_llm_configs_invalid_worker_registry_raises(): + pc = ProxyConfig() + with pytest.raises(Exception): + await pc._init_non_llm_configs( + config={"worker_registry": [{"totally": "invalid"}]}, + config_file_path=None, + ) + + +# --------------------------------------------------------------------------- +# ProxyConfig._init_policy_engine +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_ProxyConfig__init_policy_engine_no_policies_noop(): + pc = ProxyConfig() + try: + await pc._init_policy_engine(config={}, prisma_client=None, llm_router=None) + raised = False + except Exception: + raised = True + assert {"raised": raised, "called": True, "skipped": True} == { + "raised": False, + "called": True, + "skipped": True, + } + + +@pytest.mark.asyncio +async def test_ProxyConfig__init_policy_engine_none_config_noop(): + pc = ProxyConfig() + # None config returns early without raising. + await pc._init_policy_engine(config=None, prisma_client=None, llm_router=None) + # Error-style: invalid policies value should raise. 
+ with pytest.raises(Exception): + await pc._init_policy_engine( + config={"policies": "not-a-list"}, + prisma_client=None, + llm_router=None, + ) + + +# --------------------------------------------------------------------------- +# ProxyConfig._load_alerting_settings +# --------------------------------------------------------------------------- + + +def test_ProxyConfig__load_alerting_settings_noop_when_no_alerting(): + pc = ProxyConfig() + try: + pc._load_alerting_settings({}) + raised = False + except Exception: + raised = True + assert {"raised": raised, "called": True, "no_alerting": True} == { + "raised": False, + "called": True, + "no_alerting": True, + } + + +def test_ProxyConfig__load_alerting_settings_invalid_alerting_raises(): + pc = ProxyConfig() + with pytest.raises(Exception): + # alerting must be iterable — int triggers an error. + pc._load_alerting_settings({"alerting": 12345}) + + +# --------------------------------------------------------------------------- +# ProxyConfig.initialize_secret_manager +# --------------------------------------------------------------------------- + + +def test_ProxyConfig_initialize_secret_manager_none_noop(): + pc = ProxyConfig() + try: + pc.initialize_secret_manager(key_management_system=None) + raised = False + except Exception: + raised = True + assert {"raised": raised, "called": True, "kms": None} == { + "raised": False, + "called": True, + "kms": None, + } + + +def test_ProxyConfig_initialize_secret_manager_invalid_kms_raises(): + pc = ProxyConfig() + with pytest.raises(ValueError): + pc.initialize_secret_manager(key_management_system="not-a-real-kms") + + +# --------------------------------------------------------------------------- +# ProxyConfig.get_model_info_with_id +# --------------------------------------------------------------------------- + + +def test_ProxyConfig_get_model_info_with_id_returns_router_model_info(): + pc = ProxyConfig() + model = SimpleNamespace( + model_id="m-1", + model_info={"id": 
"m-1"}, + blocked=False, + ) + out = pc.get_model_info_with_id(model=model, db_model=True) + dumped = out.model_dump() + snapshot = { + "id": dumped.get("id"), + "db_model": dumped.get("db_model"), + "blocked": dumped.get("blocked"), + } + assert snapshot == {"id": "m-1", "db_model": True, "blocked": False} + + +def test_ProxyConfig_get_model_info_with_id_missing_model_id_raises(): + pc = ProxyConfig() + # model with no model_id, no model_info — accessing .model_id will fail. + bad = SimpleNamespace(model_info=None) + with pytest.raises(AttributeError): + pc.get_model_info_with_id(model=bad) + + +# --------------------------------------------------------------------------- +# ProxyConfig._delete_deployment +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_ProxyConfig__delete_deployment_empty_returns_zero(monkeypatch): + monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", None) + pc = ProxyConfig() + result = await pc._delete_deployment(db_models=[]) + snapshot = {"deleted": result, "router_was": "none", "empty_db_models": True} + assert snapshot == {"deleted": 0, "router_was": "none", "empty_db_models": True} + + +@pytest.mark.asyncio +async def test_ProxyConfig__delete_deployment_invalid_models_raises(monkeypatch): + fake_router = MagicMock() + fake_router.get_model_ids = MagicMock(return_value=[]) + monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", fake_router) + pc = ProxyConfig() + with pytest.raises(Exception): + # Non-model objects without expected attrs trigger an error. 
+ await pc._delete_deployment(db_models=[{"not_a_model": True}]) + + +# --------------------------------------------------------------------------- +# ProxyConfig._add_deployment +# --------------------------------------------------------------------------- + + +def test_ProxyConfig__add_deployment_no_router_returns_zero(monkeypatch): + monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", None) + pc = ProxyConfig() + result = pc._add_deployment(db_models=[MagicMock()]) + snapshot = {"added": result, "router_was": "none", "called": True} + assert snapshot == {"added": 0, "router_was": "none", "called": True} + + +def test_ProxyConfig__add_deployment_invalid_litellm_params_skips(monkeypatch): + fake_router = MagicMock() + fake_router.upsert_deployment = MagicMock(return_value=None) + monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", fake_router) + pc = ProxyConfig() + bad = SimpleNamespace(litellm_params="not-a-dict", model_name="x", model_id="x") + # invalid params logs and continues — assert zero added (error-style branch). 
+ assert pc._add_deployment(db_models=[bad]) == 0 + + +# --------------------------------------------------------------------------- +# ProxyConfig.decrypt_model_list_from_db +# --------------------------------------------------------------------------- + + +def test_ProxyConfig_decrypt_model_list_from_db_returns_decrypted(monkeypatch): + monkeypatch.setattr( + "litellm.proxy.proxy_server.decrypt_value_helper", + lambda value, key, return_original_value: value, + ) + pc = ProxyConfig() + m = SimpleNamespace( + model_id="m-1", + model_name="gpt-4", + model_info={"id": "m-1"}, + litellm_params={"api_key": "sk-x", "model": "gpt-4"}, + blocked=False, + ) + out = pc.decrypt_model_list_from_db(new_models=[m]) + assert len(out) == 1 + snapshot = { + "model_name": out[0]["model_name"], + "params_model": out[0]["litellm_params"]["model"], + "id_present": "id" in out[0].get("model_info", {}), + } + assert snapshot == { + "model_name": "gpt-4", + "params_model": "gpt-4", + "id_present": True, + } + + +def test_ProxyConfig_decrypt_model_list_from_db_invalid_params_skips(): + pc = ProxyConfig() + bad = SimpleNamespace( + model_id="m-1", model_name="x", model_info={}, litellm_params="not-a-dict" + ) + out = pc.decrypt_model_list_from_db(new_models=[bad]) + # Invalid entries skipped — empty list returned. 
+ assert out == [] + + +# --------------------------------------------------------------------------- +# ProxyConfig._update_llm_router +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_ProxyConfig__update_llm_router_no_models_smoke(monkeypatch): + monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", None) + monkeypatch.setattr("litellm.proxy.proxy_server.master_key", "sk-master") + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", None) + monkeypatch.setattr("litellm.proxy.proxy_server.general_settings", {}) + pc = ProxyConfig() + + async def fake_get_config(*args, **kwargs): + return {} + + monkeypatch.setattr(pc, "get_config", fake_get_config) + monkeypatch.setattr( + "litellm.proxy.proxy_server.proxy_config", + pc, + ) + try: + await pc._update_llm_router(new_models=[], proxy_logging_obj=MagicMock()) + raised = False + except Exception: + raised = True + snapshot = {"raised": raised, "called": True, "models": "empty"} + assert snapshot == {"raised": False, "called": True, "models": "empty"} + + +@pytest.mark.asyncio +async def test_ProxyConfig__update_llm_router_bad_proxy_logging_raises(monkeypatch): + pc = ProxyConfig() + + async def fake_get_config(): + # alerting present + non-list general_settings to trigger the alerting branch. 
+ return {"general_settings": {"alerting": ["slack"]}} + + fake_router = MagicMock() + fake_router.update_settings = MagicMock() + monkeypatch.setattr(pc, "get_config", fake_get_config) + monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", fake_router) + monkeypatch.setattr("litellm.proxy.proxy_server.master_key", "sk-x") + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", None) + monkeypatch.setattr( + "litellm.proxy.proxy_server.general_settings", {"alerting": ["email"]} + ) + monkeypatch.setattr("litellm.proxy.proxy_server.proxy_config", pc) + # Passing None for proxy_logging_obj triggers AttributeError in _add_general_settings_from_db_config + # when it calls proxy_logging_obj.update_values. + with pytest.raises(AttributeError): + await pc._update_llm_router(new_models=None, proxy_logging_obj=None) # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# ProxyConfig._add_callback_from_db_to_in_memory_litellm_callbacks +# --------------------------------------------------------------------------- + + +def test_ProxyConfig__add_callback_from_db_to_in_memory_litellm_callbacks_adds( + monkeypatch, +): + monkeypatch.setattr(litellm, "callbacks", [], raising=False) + pc = ProxyConfig() + pc._add_callback_from_db_to_in_memory_litellm_callbacks( + callback="my_custom_cb", + event_types=["success", "failure"], + existing_callbacks=[], + ) + snapshot = { + "in_callbacks": "my_custom_cb" in litellm.callbacks, + "count": len(litellm.callbacks), + "method_called": True, + } + assert snapshot == {"in_callbacks": True, "count": 1, "method_called": True} + + +def test_ProxyConfig__add_callback_from_db_to_in_memory_litellm_callbacks_invalid_event_raises( + monkeypatch, +): + monkeypatch.setattr(litellm, "callbacks", [], raising=False) + pc = ProxyConfig() + # For a "known" callback, event_types is iterated — non-iterable raises TypeError. 
+ with pytest.raises(TypeError): + pc._add_callback_from_db_to_in_memory_litellm_callbacks( + callback="lago", # in _known_custom_logger_compatible_callbacks + event_types=12345, # type: ignore[arg-type] + existing_callbacks=[], + ) + + +# --------------------------------------------------------------------------- +# ProxyConfig._add_callbacks_from_db_config +# --------------------------------------------------------------------------- + + +def test_ProxyConfig__add_callbacks_from_db_config_processes_lists(monkeypatch): + monkeypatch.setattr(litellm, "callbacks", [], raising=False) + monkeypatch.setattr(litellm, "success_callback", [], raising=False) + monkeypatch.setattr(litellm, "failure_callback", [], raising=False) + pc = ProxyConfig() + cfg = { + "litellm_settings": { + "callbacks": ["cb_a"], + "success_callback": ["s_a"], + "failure_callback": ["f_a"], + } + } + pc._add_callbacks_from_db_config(cfg) + snapshot = { + "cb_added": "cb_a" in litellm.callbacks, + "success_added": "s_a" in litellm.success_callback, + "failure_added": "f_a" in litellm.failure_callback, + } + assert snapshot == { + "cb_added": True, + "success_added": True, + "failure_added": True, + } + + +def test_ProxyConfig__add_callbacks_from_db_config_bad_config_raises(): + pc = ProxyConfig() + with pytest.raises(AttributeError): + # Non-dict input — .get will fail. 
+ pc._add_callbacks_from_db_config(None) # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# ProxyConfig._encrypt_env_variables +# --------------------------------------------------------------------------- + + +def test_ProxyConfig__encrypt_env_variables_returns_dict(monkeypatch): + monkeypatch.setattr( + "litellm.proxy.proxy_server.encrypt_value_helper", + lambda value, new_encryption_key=None: f"ENC[{value}]", + ) + pc = ProxyConfig() + out = pc._encrypt_env_variables({"A": "1", "B": "2", "C": "3"}) + assert out == {"A": "ENC[1]", "B": "ENC[2]", "C": "ENC[3]"} + + +def test_ProxyConfig__encrypt_env_variables_invalid_raises(): + pc = ProxyConfig() + with pytest.raises(AttributeError): + # Non-dict input — .items() fails. + pc._encrypt_env_variables(None) # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# ProxyConfig._decrypt_and_set_db_env_variables +# --------------------------------------------------------------------------- + + +def test_ProxyConfig__decrypt_and_set_db_env_variables_sets_env(monkeypatch): + monkeypatch.setattr( + "litellm.proxy.proxy_server.decrypt_value_helper", + lambda value, key, return_original_value=False: value + "-dec", + ) + monkeypatch.delenv("KEY_X", raising=False) + monkeypatch.delenv("KEY_Y", raising=False) + pc = ProxyConfig() + out = pc._decrypt_and_set_db_env_variables({"KEY_X": "x", "KEY_Y": "y"}) + snapshot = { + "KEY_X_env": os.environ.get("KEY_X"), + "KEY_Y_env": os.environ.get("KEY_Y"), + "returned_keys": sorted(out.keys()), + } + assert snapshot == { + "KEY_X_env": "x-dec", + "KEY_Y_env": "y-dec", + "returned_keys": ["KEY_X", "KEY_Y"], + } + + +def test_ProxyConfig__decrypt_and_set_db_env_variables_invalid_dict_raises(): + pc = ProxyConfig() + with pytest.raises(AttributeError): + pc._decrypt_and_set_db_env_variables("not-a-dict") # type: ignore[arg-type] + + +# 
--------------------------------------------------------------------------- +# ProxyConfig._decrypt_db_variables +# --------------------------------------------------------------------------- + + +def test_ProxyConfig__decrypt_db_variables_returns_decrypted(monkeypatch): + monkeypatch.setattr( + "litellm.proxy.proxy_server.decrypt_value_helper", + lambda value, key, return_original_value: f"D({value})", + ) + pc = ProxyConfig() + out = pc._decrypt_db_variables({"a": "1", "b": "2", "c": "3"}) + assert out == {"a": "D(1)", "b": "D(2)", "c": "D(3)"} + + +def test_ProxyConfig__decrypt_db_variables_invalid_raises(): + pc = ProxyConfig() + with pytest.raises(AttributeError): + pc._decrypt_db_variables(None) # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# ProxyConfig._encrypt_env_variables_for_db +# --------------------------------------------------------------------------- + + +def test_ProxyConfig__encrypt_env_variables_for_db_idempotent(monkeypatch): + monkeypatch.setattr( + "litellm.proxy.proxy_server.decrypt_value_helper", + lambda value, key, return_original_value: value, + ) + monkeypatch.setattr( + "litellm.proxy.proxy_server.encrypt_value_helper", + lambda value, new_encryption_key=None: f"ENC[{value}]", + ) + pc = ProxyConfig() + out = pc._encrypt_env_variables_for_db({"A": "1", "B": "2", "C": "3"}) + assert out == {"A": "ENC[1]", "B": "ENC[2]", "C": "ENC[3]"} + + +def test_ProxyConfig__encrypt_env_variables_for_db_invalid_raises(): + pc = ProxyConfig() + with pytest.raises(AttributeError): + pc._encrypt_env_variables_for_db(None) # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# ProxyConfig._parse_router_settings_value +# --------------------------------------------------------------------------- + + +def test_ProxyConfig__parse_router_settings_value_handles_inputs(): + result = { + "dict": ProxyConfig._parse_router_settings_value({"a": 
1}), + "yaml_string": ProxyConfig._parse_router_settings_value("a: 1\nb: 2"), + "none": ProxyConfig._parse_router_settings_value(None), + } + assert result == { + "dict": {"a": 1}, + "yaml_string": {"a": 1, "b": 2}, + "none": None, + } + + +def test_ProxyConfig__parse_router_settings_value_invalid_returns_none(): + # Non-dict, non-parseable scalar -> None. + assert ProxyConfig._parse_router_settings_value(12345) is None + # Empty dict -> None (not truthy). + assert ProxyConfig._parse_router_settings_value({}) is None + + +# --------------------------------------------------------------------------- +# ProxyConfig._get_hierarchical_router_settings +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_ProxyConfig__get_hierarchical_router_settings_key_wins(): + pc = ProxyConfig() + fake_key = SimpleNamespace( + router_settings={"timeout": 30, "retries": 2, "model": "gpt-4"}, + team_id=None, + ) + out = await pc._get_hierarchical_router_settings( + user_api_key_dict=fake_key, + prisma_client=None, + proxy_logging_obj=None, + ) + assert out == {"timeout": 30, "retries": 2, "model": "gpt-4"} + + +@pytest.mark.asyncio +async def test_ProxyConfig__get_hierarchical_router_settings_missing_returns_none(): + pc = ProxyConfig() + fake_key = SimpleNamespace(router_settings=None, team_id=None) + out = await pc._get_hierarchical_router_settings( + user_api_key_dict=fake_key, + prisma_client=None, + proxy_logging_obj=None, + ) + assert out is None + + +# --------------------------------------------------------------------------- +# ProxyConfig._add_router_settings_from_db_config +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_ProxyConfig__add_router_settings_from_db_config_updates_router(): + pc = ProxyConfig() + fake_router = MagicMock() + fake_router.update_settings = MagicMock() + fake_prisma = MagicMock() + 
fake_prisma.db.litellm_config.find_first = AsyncMock( + return_value=SimpleNamespace( + param_value={"timeout": 30, "retries": 2, "fallbacks": []} + ) + ) + config_data = {"router_settings": {"timeout": 10}} + await pc._add_router_settings_from_db_config( + config_data=config_data, + llm_router=fake_router, + prisma_client=fake_prisma, + ) + snapshot = { + "called": fake_router.update_settings.called, + "call_count": fake_router.update_settings.call_count, + "kwargs_keys": sorted( + list(fake_router.update_settings.call_args.kwargs.keys()) + ), + } + assert snapshot == { + "called": True, + "call_count": 1, + "kwargs_keys": ["fallbacks", "retries", "timeout"], + } + + +@pytest.mark.asyncio +async def test_ProxyConfig__add_router_settings_from_db_config_none_router_noop(): + pc = ProxyConfig() + # No router and no prisma — should silently return. + await pc._add_router_settings_from_db_config( + config_data={}, llm_router=None, prisma_client=None + ) + # Error-style: bad call signature raises. 
+ with pytest.raises(TypeError): + await pc._add_router_settings_from_db_config() # type: ignore[call-arg] + + +# --------------------------------------------------------------------------- +# ProxyConfig._add_general_settings_from_db_config +# --------------------------------------------------------------------------- + + +def test_ProxyConfig__add_general_settings_from_db_config_merges_alerting(): + pc = ProxyConfig() + proxy_logging = MagicMock() + general = {"alerting": ["slack"]} + config_data = {"general_settings": {"alerting": ["email", "slack"]}} + pc._add_general_settings_from_db_config( + config_data=config_data, + general_settings=general, + proxy_logging_obj=proxy_logging, + ) + snapshot = { + "alerting": sorted(general["alerting"]), + "logging_called": proxy_logging.update_values.called, + "merged_count": len(general["alerting"]), + } + assert snapshot == { + "alerting": ["email", "slack"], + "logging_called": True, + "merged_count": 2, + } + + +def test_ProxyConfig__add_general_settings_from_db_config_bad_config_raises(): + pc = ProxyConfig() + with pytest.raises(AttributeError): + pc._add_general_settings_from_db_config( + config_data=None, # type: ignore[arg-type] + general_settings={}, + proxy_logging_obj=MagicMock(), + ) + + +# --------------------------------------------------------------------------- +# ProxyConfig._reschedule_spend_log_cleanup_job +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_ProxyConfig__reschedule_spend_log_cleanup_job_no_scheduler(monkeypatch): + monkeypatch.setattr("litellm.proxy.proxy_server.scheduler", None) + pc = ProxyConfig() + try: + await pc._reschedule_spend_log_cleanup_job() + raised = False + except Exception: + raised = True + snapshot = {"raised": raised, "called": True, "scheduler_was": "none"} + assert snapshot == {"raised": False, "called": True, "scheduler_was": "none"} + + +@pytest.mark.asyncio +async def 
test_ProxyConfig__reschedule_spend_log_cleanup_job_invalid_cron(monkeypatch): + fake_scheduler = MagicMock() + fake_scheduler.remove_job = MagicMock() + fake_scheduler.add_job = MagicMock() + monkeypatch.setattr("litellm.proxy.proxy_server.scheduler", fake_scheduler) + monkeypatch.setattr( + "litellm.proxy.proxy_server.general_settings", + { + "maximum_spend_logs_retention_period": "1d", + "maximum_spend_logs_cleanup_cron": "INVALID CRON STRING", + }, + ) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", None) + pc = ProxyConfig() + # Invalid cron is caught and logged — does not raise outward. + await pc._reschedule_spend_log_cleanup_job() + # But add_job should not have been called for the invalid cron path. + assert fake_scheduler.add_job.call_count == 0 + + +# --------------------------------------------------------------------------- +# ProxyConfig._update_general_settings +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_ProxyConfig__update_general_settings_updates_max_parallel(monkeypatch): + monkeypatch.setattr( + "litellm.proxy.proxy_server.general_settings", + {}, + ) + pc = ProxyConfig() + await pc._update_general_settings( + { + "max_parallel_requests": 7, + "global_max_parallel_requests": 99, + "ui_access_mode": "admin_only", + } + ) + from litellm.proxy import proxy_server as ps + + snapshot = { + "max_parallel_requests": ps.general_settings.get("max_parallel_requests"), + "global_max_parallel_requests": ps.general_settings.get( + "global_max_parallel_requests" + ), + "ui_access_mode": ps.general_settings.get("ui_access_mode"), + } + assert snapshot == { + "max_parallel_requests": 7, + "global_max_parallel_requests": 99, + "ui_access_mode": "admin_only", + } + + +@pytest.mark.asyncio +async def test_ProxyConfig__update_general_settings_none_input_noop(): + pc = ProxyConfig() + # None input returns early. 
+ result = await pc._update_general_settings(db_general_settings=None) + assert result is None + # Error-style: dict() will fail on non-mapping non-None input. + with pytest.raises(Exception): + await pc._update_general_settings(db_general_settings=12345) # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# ProxyConfig._update_config_fields +# --------------------------------------------------------------------------- + + +def test_ProxyConfig__update_config_fields_merges_dict(): + pc = ProxyConfig() + current = {"general_settings": {"a": 1, "b": 2}} + out = pc._update_config_fields( + current_config=current, + param_name="general_settings", + db_param_value={"b": 3, "c": 4, "d": 5}, + ) + assert out == {"general_settings": {"a": 1, "b": 3, "c": 4, "d": 5}} + + +def test_ProxyConfig__update_config_fields_invalid_param_raises(): + pc = ProxyConfig() + with pytest.raises(Exception): + # Missing required arg. + pc._update_config_fields(current_config={}, param_name="general_settings") # type: ignore[call-arg] diff --git a/tests/test_litellm/proxy/proxy_server/test_spend_counters.py b/tests/test_litellm/proxy/proxy_server/test_spend_counters.py index ad6b4016461..ec8b06d9c97 100644 --- a/tests/test_litellm/proxy/proxy_server/test_spend_counters.py +++ b/tests/test_litellm/proxy/proxy_server/test_spend_counters.py @@ -1 +1,817 @@ -"""Placeholder. Filled by a follow-up PR per the Notion plan.""" +"""Behavior pins for spend-counter helpers in proxy_server. 
+ +Pins covered: +- ``get_current_spend`` +- ``increment_spend_counters`` +- ``_reconcile_budget_reservation_for_counter_update`` +- ``_increment_end_user_and_tag_spend_counters`` +- ``_increment_org_spend_counter`` +- ``_init_and_increment_unreserved_spend_counter`` +- ``_init_and_increment_spend_counter`` +- ``_init_and_increment_window_spend_counter`` +- ``_ensure_spend_counter_initialized`` +- ``_get_source_cache_base_spend`` +- ``_ensure_window_spend_counter_initialized`` +- ``_is_spend_counter_cache_warm`` +- ``_increment_spend_counter_cache`` +- ``_invalidate_spend_counter`` +- ``update_cache`` +""" + +from __future__ import annotations + +from datetime import datetime +from unittest.mock import AsyncMock, MagicMock + +import pytest + +import litellm.proxy.proxy_server as ps + +from .conftest import normalize + + +def _make_spend_counter_cache( + *, + redis_get_value=None, + redis_get_side_effect=None, + redis_increment_value=None, + redis_increment_side_effect=None, + in_memory_value=None, + with_redis: bool = True, +): + cache = MagicMock() + cache.in_memory_cache = MagicMock() + cache.in_memory_cache.get_cache = MagicMock(return_value=in_memory_value) + cache.in_memory_cache.set_cache = MagicMock() + cache.in_memory_cache.delete_cache = MagicMock() + if with_redis: + cache.redis_cache = MagicMock() + cache.redis_cache.async_get_cache = AsyncMock( + return_value=redis_get_value, side_effect=redis_get_side_effect + ) + cache.redis_cache.async_increment = AsyncMock( + return_value=redis_increment_value, + side_effect=redis_increment_side_effect, + ) + cache.redis_cache.async_delete_cache = AsyncMock() + else: + cache.redis_cache = None + cache.async_increment_cache = AsyncMock(return_value=redis_increment_value) + cache.async_get_cache = AsyncMock(return_value=None) + cache.async_set_cache = AsyncMock() + cache.async_delete_cache = AsyncMock() + cache.async_set_cache_pipeline = AsyncMock() + return cache + + +def _make_user_api_key_cache(get_value=None, 
get_side_effect=None): + cache = MagicMock() + cache.async_get_cache = AsyncMock( + return_value=get_value, side_effect=get_side_effect + ) + cache.async_set_cache_pipeline = AsyncMock() + return cache + + +# --------------------------------------------------------------------------- +# get_current_spend +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_current_spend_reads_redis_first(monkeypatch): + fake_cache = _make_spend_counter_cache(redis_get_value=42.5) + monkeypatch.setattr(ps, "spend_counter_cache", fake_cache) + + result = await ps.get_current_spend(counter_key="spend:key:abc", fallback_spend=0.0) + + observed = { + "value": result, + "redis_called": fake_cache.redis_cache.async_get_cache.called, + "in_memory_called": fake_cache.in_memory_cache.get_cache.called, + } + assert normalize(observed) == { + "value": 42.5, + "redis_called": True, + "in_memory_called": False, + } + + +@pytest.mark.asyncio +async def test_get_current_spend_redis_error_falls_back_to_in_memory(monkeypatch): + fake_cache = _make_spend_counter_cache( + redis_get_side_effect=RuntimeError("redis down"), + in_memory_value=17.0, + ) + monkeypatch.setattr(ps, "spend_counter_cache", fake_cache) + + result = await ps.get_current_spend( + counter_key="spend:key:abc", fallback_spend=99.0 + ) + assert result == 17.0 + + +# --------------------------------------------------------------------------- +# increment_spend_counters +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_increment_spend_counters_increments_all_buckets(monkeypatch): + fake_cache = _make_spend_counter_cache( + redis_get_value=None, redis_increment_value=5.0 + ) + fake_user_cache = _make_user_api_key_cache(get_value=None) + monkeypatch.setattr(ps, "spend_counter_cache", fake_cache) + monkeypatch.setattr(ps, "user_api_key_cache", fake_user_cache) + monkeypatch.setattr(ps, 
"prisma_client", None) + + async def _fake_coalesced(**kwargs): + return None + + monkeypatch.setattr( + ps.SpendCounterReseed, "coalesced", AsyncMock(side_effect=_fake_coalesced) + ) + + await ps.increment_spend_counters( + token="hashed-tok", + team_id="t1", + user_id="u1", + response_cost=5.0, + ) + + observed = { + "redis_increment_called": fake_cache.redis_cache.async_increment.called, + "increment_calls": fake_cache.redis_cache.async_increment.call_count, + "user_cache_used": fake_user_cache.async_get_cache.called, + } + assert normalize(observed) == { + "redis_increment_called": True, + "increment_calls": 4, + "user_cache_used": True, + } + + +@pytest.mark.asyncio +async def test_increment_spend_counters_zero_cost_is_noop_finalizes_reservation( + monkeypatch, +): + fake_cache = _make_spend_counter_cache() + fake_user_cache = _make_user_api_key_cache() + monkeypatch.setattr(ps, "spend_counter_cache", fake_cache) + monkeypatch.setattr(ps, "user_api_key_cache", fake_user_cache) + monkeypatch.setattr(ps, "prisma_client", None) + reservation = {"finalized": False} + + await ps.increment_spend_counters( + token="hashed-tok", + team_id="t1", + user_id="u1", + response_cost=0, + budget_reservation=reservation, + ) + + assert reservation == {"finalized": True} + assert fake_cache.redis_cache.async_increment.called is False + + +# --------------------------------------------------------------------------- +# _reconcile_budget_reservation_for_counter_update +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_reconcile_budget_reservation_for_counter_update_returns_empty_set_when_none(): + result = await ps._reconcile_budget_reservation_for_counter_update( + budget_reservation=None, response_cost=1.0 + ) + assert result == set() + + +@pytest.mark.asyncio +async def test_reconcile_budget_reservation_for_counter_update_failure_invalidates( + monkeypatch, +): + """Reservation reconcile raising must 
invalidate reserved counters but + not propagate the exception.""" + import litellm.proxy.spend_tracking.budget_reservation as br + + monkeypatch.setattr( + br, + "get_reserved_counter_keys", + MagicMock(return_value={"spend:key:abc"}), + ) + monkeypatch.setattr( + br, + "reconcile_budget_reservation", + AsyncMock(side_effect=RuntimeError("boom")), + ) + fake_invalidate = AsyncMock() + monkeypatch.setattr(br, "invalidate_budget_reservation_counters", fake_invalidate) + + result = await ps._reconcile_budget_reservation_for_counter_update( + budget_reservation={"foo": "bar"}, response_cost=1.0 + ) + + assert result == {"spend:key:abc"} + assert fake_invalidate.called is True + + +# --------------------------------------------------------------------------- +# _increment_end_user_and_tag_spend_counters +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_increment_end_user_and_tag_spend_counters_increments_each_unique_tag( + monkeypatch, +): + fake_cache = _make_spend_counter_cache( + redis_get_value=None, redis_increment_value=3.0 + ) + fake_user_cache = _make_user_api_key_cache() + monkeypatch.setattr(ps, "spend_counter_cache", fake_cache) + monkeypatch.setattr(ps, "user_api_key_cache", fake_user_cache) + monkeypatch.setattr(ps, "prisma_client", None) + monkeypatch.setattr( + ps.SpendCounterReseed, "coalesced", AsyncMock(return_value=None) + ) + + await ps._increment_end_user_and_tag_spend_counters( + end_user_id="eu1", + tags=["a", "b", "a", "", None], + response_cost=3.0, + reserved_counter_keys=set(), + ) + + observed = { + "increment_calls": fake_cache.redis_cache.async_increment.call_count, + "in_memory_set_calls": fake_cache.in_memory_cache.set_cache.call_count, + "called": fake_cache.redis_cache.async_increment.called, + } + assert normalize(observed) == { + "increment_calls": 3, + "in_memory_set_calls": 3, + "called": True, + } + + +@pytest.mark.asyncio +async def 
test_increment_end_user_and_tag_spend_counters_no_end_user_no_tags_invalid_input_noop( + monkeypatch, +): + fake_cache = _make_spend_counter_cache() + monkeypatch.setattr(ps, "spend_counter_cache", fake_cache) + + await ps._increment_end_user_and_tag_spend_counters( + end_user_id=None, + tags=None, + response_cost=1.0, + reserved_counter_keys=set(), + ) + + assert fake_cache.redis_cache.async_increment.called is False + + +# --------------------------------------------------------------------------- +# _increment_org_spend_counter +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_increment_org_spend_counter_increments_when_org_present(monkeypatch): + fake_cache = _make_spend_counter_cache( + redis_get_value=None, redis_increment_value=10.0 + ) + fake_user_cache = _make_user_api_key_cache() + monkeypatch.setattr(ps, "spend_counter_cache", fake_cache) + monkeypatch.setattr(ps, "user_api_key_cache", fake_user_cache) + monkeypatch.setattr(ps, "prisma_client", None) + monkeypatch.setattr( + ps.SpendCounterReseed, "coalesced", AsyncMock(return_value=None) + ) + + await ps._increment_org_spend_counter( + org_id="org-1", + response_cost=10.0, + reserved_counter_keys=set(), + ) + + observed = { + "increment_called": fake_cache.redis_cache.async_increment.called, + "increment_calls": fake_cache.redis_cache.async_increment.call_count, + "counter_key_arg": fake_cache.redis_cache.async_increment.call_args.kwargs[ + "key" + ], + } + assert normalize(observed) == { + "increment_called": True, + "increment_calls": 1, + "counter_key_arg": "spend:org:org-1", + } + + +@pytest.mark.asyncio +async def test_increment_org_spend_counter_no_org_is_noop_invalid_id(monkeypatch): + fake_cache = _make_spend_counter_cache() + monkeypatch.setattr(ps, "spend_counter_cache", fake_cache) + + await ps._increment_org_spend_counter( + org_id=None, + response_cost=1.0, + reserved_counter_keys=set(), + ) + + assert 
fake_cache.redis_cache.async_increment.called is False + + +# --------------------------------------------------------------------------- +# _init_and_increment_unreserved_spend_counter +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_init_and_increment_unreserved_spend_counter_skips_reserved_keys( + monkeypatch, +): + fake_cache = _make_spend_counter_cache() + monkeypatch.setattr(ps, "spend_counter_cache", fake_cache) + + await ps._init_and_increment_unreserved_spend_counter( + counter_key="spend:tag:x", + source_cache_key="tag:x", + increment=1.0, + reserved_counter_keys={"spend:tag:x"}, + ) + + assert fake_cache.redis_cache.async_increment.called is False + + +@pytest.mark.asyncio +async def test_init_and_increment_unreserved_spend_counter_proceeds_when_not_reserved( + monkeypatch, +): + fake_cache = _make_spend_counter_cache( + redis_get_value=None, redis_increment_value=2.0 + ) + fake_user_cache = _make_user_api_key_cache() + monkeypatch.setattr(ps, "spend_counter_cache", fake_cache) + monkeypatch.setattr(ps, "user_api_key_cache", fake_user_cache) + monkeypatch.setattr(ps, "prisma_client", None) + monkeypatch.setattr( + ps.SpendCounterReseed, "coalesced", AsyncMock(return_value=None) + ) + + await ps._init_and_increment_unreserved_spend_counter( + counter_key="spend:tag:y", + source_cache_key="tag:y", + increment=2.0, + reserved_counter_keys=set(), + ) + + observed = { + "increment_called": fake_cache.redis_cache.async_increment.called, + "redis_get_called": fake_cache.redis_cache.async_get_cache.called, + "reseed_consulted": True, + } + assert observed == { + "increment_called": True, + "redis_get_called": True, + "reseed_consulted": True, + } + + +# --------------------------------------------------------------------------- +# _init_and_increment_spend_counter +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def 
test_init_and_increment_spend_counter_warm_cache_skips_reseed(monkeypatch): + fake_cache = _make_spend_counter_cache( + redis_get_value=11.0, redis_increment_value=14.0 + ) + fake_user_cache = _make_user_api_key_cache() + monkeypatch.setattr(ps, "spend_counter_cache", fake_cache) + monkeypatch.setattr(ps, "user_api_key_cache", fake_user_cache) + monkeypatch.setattr(ps, "prisma_client", None) + reseed = AsyncMock(return_value=None) + monkeypatch.setattr(ps.SpendCounterReseed, "coalesced", reseed) + + await ps._init_and_increment_spend_counter( + counter_key="spend:key:k", + source_cache_key="k", + increment=3.0, + ) + + observed = { + "reseed_called": reseed.called, + "increment_called": fake_cache.redis_cache.async_increment.called, + "in_memory_seeded_from_redis": fake_cache.in_memory_cache.set_cache.called, + } + assert normalize(observed) == { + "reseed_called": False, + "increment_called": True, + "in_memory_seeded_from_redis": True, + } + + +# --------------------------------------------------------------------------- +# _init_and_increment_window_spend_counter +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_init_and_increment_window_spend_counter_increments_when_initialized( + monkeypatch, +): + fake_cache = _make_spend_counter_cache( + redis_get_value=0.0, redis_increment_value=5.0 + ) + monkeypatch.setattr(ps, "spend_counter_cache", fake_cache) + monkeypatch.setattr(ps, "prisma_client", None) + monkeypatch.setattr( + ps.SpendCounterReseed, + "coalesced_window", + AsyncMock(return_value=0.0), + ) + + await ps._init_and_increment_window_spend_counter( + counter_key="spend:key:k:window:1d", + entity_type="Key", + entity_id="k", + window_start=datetime(2024, 1, 1), + increment=5.0, + ) + + observed = { + "redis_increment_called": fake_cache.redis_cache.async_increment.called, + "increment_calls": fake_cache.redis_cache.async_increment.call_count, + "in_memory_set_calls": 
fake_cache.in_memory_cache.set_cache.call_count, + } + assert normalize(observed) == { + "redis_increment_called": True, + "increment_calls": 1, + "in_memory_set_calls": 2, + } + + +@pytest.mark.asyncio +async def test_init_and_increment_window_spend_counter_missing_window_start_invalid_skips( + monkeypatch, +): + fake_cache = _make_spend_counter_cache() + monkeypatch.setattr(ps, "spend_counter_cache", fake_cache) + + await ps._init_and_increment_window_spend_counter( + counter_key="spend:key:k:window:1d", + entity_type="Key", + entity_id="k", + window_start=None, + increment=5.0, + ) + + assert fake_cache.redis_cache.async_increment.called is False + + +# --------------------------------------------------------------------------- +# _ensure_spend_counter_initialized +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_ensure_spend_counter_initialized_warm_skips_reseed_and_source( + monkeypatch, +): + fake_cache = _make_spend_counter_cache(redis_get_value=20.0) + fake_user_cache = _make_user_api_key_cache() + monkeypatch.setattr(ps, "spend_counter_cache", fake_cache) + monkeypatch.setattr(ps, "user_api_key_cache", fake_user_cache) + monkeypatch.setattr(ps, "prisma_client", None) + reseed = AsyncMock(return_value=None) + monkeypatch.setattr(ps.SpendCounterReseed, "coalesced", reseed) + + await ps._ensure_spend_counter_initialized( + counter_key="spend:user:u", + source_cache_key="u", + ) + + observed = { + "warm_check_redis": fake_cache.redis_cache.async_get_cache.called, + "reseed_called": reseed.called, + "source_cache_called": fake_user_cache.async_get_cache.called, + } + assert normalize(observed) == { + "warm_check_redis": True, + "reseed_called": False, + "source_cache_called": False, + } + + +@pytest.mark.asyncio +async def test_ensure_spend_counter_initialized_cold_seeds_from_source_cache( + monkeypatch, +): + fake_cache = _make_spend_counter_cache( + redis_get_value=None, 
redis_increment_value=7.0 + ) + fake_user_cache = _make_user_api_key_cache(get_value={"spend": 7.0}) + monkeypatch.setattr(ps, "spend_counter_cache", fake_cache) + monkeypatch.setattr(ps, "user_api_key_cache", fake_user_cache) + monkeypatch.setattr(ps, "prisma_client", None) + monkeypatch.setattr( + ps.SpendCounterReseed, "coalesced", AsyncMock(return_value=None) + ) + + await ps._ensure_spend_counter_initialized( + counter_key="spend:user:u", + source_cache_key="u", + ) + + observed = { + "source_cache_called": fake_user_cache.async_get_cache.called, + "seed_increment_called": fake_cache.redis_cache.async_increment.called, + "warm_check_done": fake_cache.redis_cache.async_get_cache.called, + } + assert normalize(observed) == { + "source_cache_called": True, + "seed_increment_called": True, + "warm_check_done": True, + } + + +# --------------------------------------------------------------------------- +# _get_source_cache_base_spend +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_source_cache_base_spend_reads_first_hit_from_list(monkeypatch): + fake_user_cache = MagicMock() + + async def _get(key, **kwargs): + if key == "miss": + return None + if key == "hit-obj": + obj = MagicMock() + obj.spend = 12.0 + return obj + return None + + fake_user_cache.async_get_cache = AsyncMock(side_effect=_get) + monkeypatch.setattr(ps, "user_api_key_cache", fake_user_cache) + + result = await ps._get_source_cache_base_spend( + source_cache_key=["miss", "hit-obj", "miss2"] + ) + + observed = { + "result": result, + "calls": fake_user_cache.async_get_cache.call_count, + "stopped_after_hit": fake_user_cache.async_get_cache.call_count == 2, + } + assert normalize(observed) == { + "result": 12.0, + "calls": 2, + "stopped_after_hit": True, + } + + +@pytest.mark.asyncio +async def test_get_source_cache_base_spend_no_hits_returns_zero_fallback(monkeypatch): + """All cache lookups miss — function falls back to 0.0 
(no error).""" + fake_user_cache = _make_user_api_key_cache(get_value=None) + monkeypatch.setattr(ps, "user_api_key_cache", fake_user_cache) + + result = await ps._get_source_cache_base_spend(source_cache_key="missing-key") + assert result == 0.0 + + +# --------------------------------------------------------------------------- +# _ensure_window_spend_counter_initialized +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_ensure_window_spend_counter_initialized_warm_returns_true(monkeypatch): + fake_cache = _make_spend_counter_cache(redis_get_value=3.0) + monkeypatch.setattr(ps, "spend_counter_cache", fake_cache) + monkeypatch.setattr(ps, "prisma_client", None) + window_reseed = AsyncMock(return_value=0.0) + monkeypatch.setattr(ps.SpendCounterReseed, "coalesced_window", window_reseed) + + initialized = await ps._ensure_window_spend_counter_initialized( + counter_key="spend:key:k:window:1d", + entity_type="Key", + entity_id="k", + window_start=datetime(2024, 1, 1), + ) + + observed = { + "initialized": initialized, + "reseed_called": window_reseed.called, + "redis_get_called": fake_cache.redis_cache.async_get_cache.called, + } + assert normalize(observed) == { + "initialized": True, + "reseed_called": False, + "redis_get_called": True, + } + + +@pytest.mark.asyncio +async def test_ensure_window_spend_counter_initialized_db_failure_invalid_returns_false( + monkeypatch, +): + fake_cache = _make_spend_counter_cache(redis_get_value=None) + monkeypatch.setattr(ps, "spend_counter_cache", fake_cache) + monkeypatch.setattr(ps, "prisma_client", None) + monkeypatch.setattr( + ps.SpendCounterReseed, + "coalesced_window", + AsyncMock(return_value=None), + ) + + initialized = await ps._ensure_window_spend_counter_initialized( + counter_key="spend:key:k:window:1d", + entity_type="Key", + entity_id="k", + window_start=datetime(2024, 1, 1), + ) + + assert initialized is False + + +# 
# ---------------------------------------------------------------------------
# _is_spend_counter_cache_warm
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_is_spend_counter_cache_warm_redis_hit_seeds_in_memory(monkeypatch):
    """Redis returning a value: the check is truthy and in-memory gets written."""
    cache_stub = _make_spend_counter_cache(redis_get_value=99.0)
    monkeypatch.setattr(ps, "spend_counter_cache", cache_stub)

    is_warm = await ps._is_spend_counter_cache_warm(counter_key="spend:user:u")

    redis_getter = cache_stub.redis_cache.async_get_cache
    memory_setter = cache_stub.in_memory_cache.set_cache
    snapshot = {
        "result": is_warm,
        "redis_get_called": redis_getter.called,
        "in_memory_set_called": memory_setter.called,
    }
    expected = {
        "result": True,
        "redis_get_called": True,
        "in_memory_set_called": True,
    }
    assert normalize(snapshot) == expected


@pytest.mark.asyncio
async def test_is_spend_counter_cache_warm_redis_error_falls_back_to_in_memory(
    monkeypatch,
):
    """Redis get raising and in-memory holding None: the check yields False."""
    redis_failure = RuntimeError("redis err")
    cache_stub = _make_spend_counter_cache(
        redis_get_side_effect=redis_failure,
        in_memory_value=None,
    )
    monkeypatch.setattr(ps, "spend_counter_cache", cache_stub)

    is_warm = await ps._is_spend_counter_cache_warm(counter_key="spend:user:u")

    assert is_warm is False


# ---------------------------------------------------------------------------
# _increment_spend_counter_cache
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_increment_spend_counter_cache_redis_path_returns_new_value(monkeypatch):
    """The value produced by the Redis increment mock is what the helper returns."""
    cache_stub = _make_spend_counter_cache(redis_increment_value=44.0)
    monkeypatch.setattr(ps, "spend_counter_cache", cache_stub)

    new_total = await ps._increment_spend_counter_cache(
        counter_key="spend:key:k", increment=4.0
    )

    redis_incr = cache_stub.redis_cache.async_increment
    memory_setter = cache_stub.in_memory_cache.set_cache
    snapshot = {
        "result": new_total,
        "redis_increment_called": redis_incr.called,
        "in_memory_set_called": memory_setter.called,
    }
    expected = {
        "result": 44.0,
        "redis_increment_called": True,
        "in_memory_set_called": True,
    }
    assert normalize(snapshot) == expected


@pytest.mark.asyncio
async def test_increment_spend_counter_cache_redis_error_raises_and_invalidates(
    monkeypatch,
):
    """A failing Redis increment propagates and both cache tiers see a delete."""
    incr_failure = RuntimeError("incr fail")
    cache_stub = _make_spend_counter_cache(redis_increment_side_effect=incr_failure)
    monkeypatch.setattr(ps, "spend_counter_cache", cache_stub)

    with pytest.raises(RuntimeError):
        await ps._increment_spend_counter_cache(
            counter_key="spend:key:k", increment=1.0
        )

    # Checked after the raise: the delete mocks on both tiers must have fired.
    memory_delete = cache_stub.in_memory_cache.delete_cache
    redis_delete = cache_stub.redis_cache.async_delete_cache
    assert memory_delete.called is True
    assert redis_delete.called is True


# ---------------------------------------------------------------------------
# _invalidate_spend_counter
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_invalidate_spend_counter_deletes_in_memory_and_redis(monkeypatch):
    """Invalidation hits both delete mocks; Redis receives the key by keyword."""
    cache_stub = _make_spend_counter_cache()
    monkeypatch.setattr(ps, "spend_counter_cache", cache_stub)

    await ps._invalidate_spend_counter(counter_key="spend:key:k")

    memory_delete = cache_stub.in_memory_cache.delete_cache
    redis_delete = cache_stub.redis_cache.async_delete_cache
    snapshot = {
        "in_memory_delete_called": memory_delete.called,
        "redis_delete_called": redis_delete.called,
        "delete_args_key": redis_delete.call_args.kwargs["key"],
    }
    expected = {
        "in_memory_delete_called": True,
        "redis_delete_called": True,
        "delete_args_key": "spend:key:k",
    }
    assert normalize(snapshot) == expected


@pytest.mark.asyncio
async def test_invalidate_spend_counter_swallows_redis_failure_no_raise(monkeypatch):
    """A Redis delete that raises must not escape; in-memory delete still runs."""
    cache_stub = _make_spend_counter_cache()
    # Swap in a delete mock that always fails to exercise the swallow path.
    failing_delete = AsyncMock(side_effect=RuntimeError("redis down"))
    cache_stub.redis_cache.async_delete_cache = failing_delete
    monkeypatch.setattr(ps, "spend_counter_cache", cache_stub)

    await ps._invalidate_spend_counter(counter_key="spend:key:k")

    assert cache_stub.in_memory_cache.delete_cache.called is True


#
--------------------------------------------------------------------------- +# update_cache +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_update_cache_no_cached_entities_schedules_pipeline_flush(monkeypatch): + fake_user_cache = _make_user_api_key_cache(get_value=None) + monkeypatch.setattr(ps, "user_api_key_cache", fake_user_cache) + + await ps.update_cache( + token=None, + user_id="u1", + end_user_id="eu1", + team_id="t1", + response_cost=1.0, + parent_otel_span=None, + tags=["x"], + ) + + observed = { + "lookups": fake_user_cache.async_get_cache.call_count, + "got_user": True, + "got_team": True, + } + assert normalize(observed) == { + "lookups": 4, + "got_user": True, + "got_team": True, + } + + +@pytest.mark.asyncio +async def test_update_cache_user_cache_failure_invalid_state_is_swallowed(monkeypatch): + """An inner _update_user_cache raising must not propagate — update_cache + catches and logs, the public coroutine still completes normally.""" + fake_user_cache = MagicMock() + fake_user_cache.async_get_cache = AsyncMock(side_effect=RuntimeError("cache down")) + fake_user_cache.async_set_cache_pipeline = AsyncMock() + monkeypatch.setattr(ps, "user_api_key_cache", fake_user_cache) + + result = await ps.update_cache( + token=None, + user_id="u1", + end_user_id=None, + team_id=None, + response_cost=1.0, + parent_otel_span=None, + tags=None, + ) + + assert result is None diff --git a/tests/test_litellm/proxy/proxy_server/test_streaming_helpers.py b/tests/test_litellm/proxy/proxy_server/test_streaming_helpers.py index ad6b4016461..33de1ede917 100644 --- a/tests/test_litellm/proxy/proxy_server/test_streaming_helpers.py +++ b/tests/test_litellm/proxy/proxy_server/test_streaming_helpers.py @@ -1 +1,555 @@ -"""Placeholder. Filled by a follow-up PR per the Notion plan.""" +"""Behavior pins for the proxy_server streaming helpers. 
+ +Pins covered: +- ``data_generator`` +- ``async_assistants_data_generator`` +- ``_get_client_requested_model_for_streaming`` +- ``_restamp_streaming_chunk_model`` +- ``_fast_serialize_simple_model_response_stream`` +- ``_serialize_streaming_chunk`` +- ``_apply_streaming_chunk_hooks`` +- ``_format_streaming_sse_chunk`` +- ``async_data_generator`` +- ``select_data_generator`` +""" + +from __future__ import annotations + +import json +from typing import Any, AsyncIterator +from unittest.mock import AsyncMock, MagicMock + +import pytest + +import litellm.proxy.proxy_server as ps +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.proxy_server import ( + _apply_streaming_chunk_hooks, + _fast_serialize_simple_model_response_stream, + _format_streaming_sse_chunk, + _get_client_requested_model_for_streaming, + _restamp_streaming_chunk_model, + _serialize_streaming_chunk, + async_assistants_data_generator, + async_data_generator, + data_generator, + select_data_generator, +) +from litellm.types.utils import Delta, ModelResponseStream, StreamingChoices, Usage + +from .conftest import normalize + + +def _user_auth() -> UserAPIKeyAuth: + return UserAPIKeyAuth(api_key="sk-test-key", user_id="u") + + +def _simple_chunk(model: str = "gpt-4", content: str = "hi") -> ModelResponseStream: + return ModelResponseStream( + id="chatcmpl-test", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta(content=content, role="assistant"), + ) + ], + created=0, + model=model, + object="chat.completion.chunk", + ) + + +async def _async_iter(items): + for it in items: + yield it + + +async def _async_iter_raises(exc: Exception): + # yield once then raise — exercises the mid-stream failure branch + yield _simple_chunk(content="partial") + raise exc + + +# --------------------------------------------------------------------------- +# data_generator +# --------------------------------------------------------------------------- + + +def 
test_data_generator_yields_sse_lines_for_dict_chunks(): + class DictChunk: + def __init__(self, payload): + self._payload = payload + + def dict(self): + return self._payload + + chunks = [ + DictChunk({"id": "1", "object": "chat.completion.chunk", "model": "gpt-4"}), + DictChunk({"id": "2", "object": "chat.completion.chunk", "model": "gpt-4"}), + ] + out = list(data_generator(chunks)) + + assert len(out) == 2 + payloads = [json.loads(line.removeprefix("data: ").rstrip("\n\n")) for line in out] + assert normalize(payloads[0]) == { + "id": "", + "object": "chat.completion.chunk", + "model": "gpt-4", + } + assert payloads[1]["model"] == "gpt-4" + + +def test_data_generator_fallback_when_dict_raises_exception(): + class BadChunk: + def dict(self): + raise RuntimeError("cannot serialize") + + # When .dict() raises, the inner json.dumps(chunk) on a non-JSON-serializable + # instance also raises — the generator does not catch the second failure. + with pytest.raises((TypeError, RuntimeError)): + list(data_generator([BadChunk()])) + + +# --------------------------------------------------------------------------- +# async_assistants_data_generator +# --------------------------------------------------------------------------- + + +class _FakeAssistantsStream: + """Mimic the async-context-manager + async-iterable shape of the + assistants streaming object (e.g. 
AssistantEventHandler).""" + + def __init__(self, chunks): + self._chunks = chunks + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + def __aiter__(self): + async def _gen(): + for c in self._chunks: + yield c + + return _gen() + + +@pytest.mark.asyncio +async def test_async_assistants_data_generator_yields_sse_and_done(monkeypatch): + chunk = _simple_chunk(content="hello") + + async def _passthrough_hook(*, user_api_key_dict, response, data, **kwargs): + return response + + monkeypatch.setattr( + ps.proxy_logging_obj, + "async_post_call_streaming_hook", + _passthrough_hook, + ) + + stream = _FakeAssistantsStream([chunk]) + out = [] + async for line in async_assistants_data_generator( + response=stream, + user_api_key_dict=_user_auth(), + request_data={}, + ): + out.append(line) + + assert out[-1] == "data: [DONE]\n\n" + body = json.loads(out[0].removeprefix("data: ").rstrip("\n\n")) + assert normalize(body) == { + "id": "", + "created": "", + "model": "gpt-4", + "object": "chat.completion.chunk", + "choices": [ + { + "index": 0, + "delta": {"content": "hello", "role": "assistant"}, + } + ], + } + + +@pytest.mark.asyncio +async def test_async_assistants_data_generator_hook_failure_yields_error_chunk( + monkeypatch, +): + async def _boom_hook(*args, **kwargs): + raise RuntimeError("hook exploded") + + async def _noop_failure(*args, **kwargs): + return None + + monkeypatch.setattr( + ps.proxy_logging_obj, "async_post_call_streaming_hook", _boom_hook + ) + monkeypatch.setattr(ps.proxy_logging_obj, "post_call_failure_hook", _noop_failure) + + stream = _FakeAssistantsStream([_simple_chunk()]) + out = [] + async for line in async_assistants_data_generator( + response=stream, + user_api_key_dict=_user_auth(), + request_data={}, + ): + out.append(line) + + assert any("error" in line for line in out) + assert out[-1].startswith('data: {"error":') + + +# 
# ---------------------------------------------------------------------------
# _get_client_requested_model_for_streaming
# ---------------------------------------------------------------------------


def test_get_client_requested_model_for_streaming_prefers_client_requested():
    """The client-requested model wins over ``model`` and the two input
    fields read back unchanged after the call."""
    request_data = {
        "_litellm_client_requested_model": "gpt-4",
        "model": "openai/internal-gpt-4",
        "litellm_call_id": "abc",
    }
    result = _get_client_requested_model_for_streaming(request_data)
    assert result == "gpt-4"

    snapshot = {
        "result": result,
        "client_field_preserved": request_data["_litellm_client_requested_model"],
        "model_field_preserved": request_data["model"],
    }
    assert normalize(snapshot) == {
        "result": "gpt-4",
        "client_field_preserved": "gpt-4",
        "model_field_preserved": "openai/internal-gpt-4",
    }


def test_get_client_requested_model_for_streaming_falls_back_to_model_field():
    """Without a client-requested entry the ``model`` value is returned."""
    result = _get_client_requested_model_for_streaming({"model": "claude-sonnet"})
    assert result == "claude-sonnet"


def test_get_client_requested_model_for_streaming_missing_returns_empty_invalid():
    """When neither key is set or values are non-strings, the helper returns ""
    rather than raising — callers depend on this to skip restamping."""
    assert _get_client_requested_model_for_streaming({}) == ""
    assert _get_client_requested_model_for_streaming({"model": 123}) == ""


# ---------------------------------------------------------------------------
# _restamp_streaming_chunk_model
# ---------------------------------------------------------------------------


def test_restamp_streaming_chunk_model_overrides_model_on_basemodel():
    """A model chunk is restamped in place (same object returned) and the
    mismatch flag flips to ``True``."""
    chunk = _simple_chunk(model="openai/internal-x")
    new_chunk, logged = _restamp_streaming_chunk_model(
        chunk=chunk,
        requested_model_from_client="gpt-4",
        request_data={"litellm_call_id": "id-1"},
        model_mismatch_logged=False,
    )
    snapshot = {
        "model": new_chunk.model,
        "logged": logged,
        "same_object": new_chunk is chunk,
    }
    assert snapshot == {"model": "gpt-4", "logged": True, "same_object": True}


def test_restamp_streaming_chunk_model_overrides_model_on_dict():
    """A dict chunk gets its ``"model"`` key overwritten; an already-set
    mismatch flag stays ``True``."""
    chunk = {"model": "internal", "choices": []}
    new_chunk, logged = _restamp_streaming_chunk_model(
        chunk=chunk,
        requested_model_from_client="gpt-4",
        request_data={},
        model_mismatch_logged=True,
    )
    assert new_chunk["model"] == "gpt-4"
    assert logged is True


def test_restamp_streaming_chunk_model_invalid_chunk_type_unchanged():
    """For a non-BaseModel, non-dict chunk the helper returns it as-is
    along with the original ``model_mismatch_logged`` flag."""
    chunk = "raw string chunk"
    new_chunk, logged = _restamp_streaming_chunk_model(
        chunk=chunk,
        requested_model_from_client="gpt-4",
        request_data={},
        model_mismatch_logged=False,
    )
    assert new_chunk == "raw string chunk"
    assert logged is False


# ---------------------------------------------------------------------------
# _fast_serialize_simple_model_response_stream
# ---------------------------------------------------------------------------


def test_fast_serialize_simple_model_response_stream_returns_bytes_payload():
    """A simple chunk serializes to JSON ``bytes`` with the expected shape."""
    chunk = _simple_chunk()
    result = _fast_serialize_simple_model_response_stream(chunk)
    assert isinstance(result, bytes)
    payload = json.loads(result)
    assert normalize(payload) == {
        "id": "",
        "object": "chat.completion.chunk",
        "created": "",
        "model": "gpt-4",
        "choices": [
            {
                "index": 0,
                "delta": {"role": "assistant", "content": "hi"},
            }
        ],
    }


def test_fast_serialize_simple_model_response_stream_with_usage_returns_none_invalid():
    """Fast path bails (returns None) when ``usage`` is populated — the slow
    path is required to preserve usage fields. Returning None here is the
    "I cannot handle this" sentinel, not a hard error."""
    chunk = _simple_chunk()
    chunk.usage = Usage(prompt_tokens=1, completion_tokens=1, total_tokens=2)
    assert _fast_serialize_simple_model_response_stream(chunk) is None


# ---------------------------------------------------------------------------
# _serialize_streaming_chunk
# ---------------------------------------------------------------------------


def test_serialize_streaming_chunk_simple_uses_fast_path_bytes():
    """A simple chunk comes back as JSON ``bytes`` with the expected shape."""
    result = _serialize_streaming_chunk(_simple_chunk())
    assert isinstance(result, bytes)
    payload = json.loads(result)
    assert normalize(payload) == {
        "id": "",
        "object": "chat.completion.chunk",
        "created": "",
        "model": "gpt-4",
        "choices": [
            {
                "index": 0,
                "delta": {"role": "assistant", "content": "hi"},
            }
        ],
    }


def test_serialize_streaming_chunk_invalid_input_raises_attribute_error():
    """The helper is typed as ``BaseModel`` — handing it a plain dict trips
    the attribute-access path (no ``model_dump_json``)."""
    with pytest.raises(AttributeError):
        _serialize_streaming_chunk({"not": "a model"})  # type: ignore[arg-type]


# ---------------------------------------------------------------------------
# _apply_streaming_chunk_hooks
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_apply_streaming_chunk_hooks_appends_to_str_so_far(monkeypatch):
    """With a passthrough hook the chunk stays a ``ModelResponseStream`` and
    its content is appended to ``str_so_far``."""
    chunk = _simple_chunk(content="abc")

    async def _passthrough(*, user_api_key_dict, response, data, str_so_far=None):
        return response

    monkeypatch.setattr(
        ps.proxy_logging_obj, "async_post_call_streaming_hook", _passthrough
    )

    new_chunk, new_str = await _apply_streaming_chunk_hooks(
        chunk=chunk,
        user_api_key_dict=_user_auth(),
        request_data={},
        str_so_far="prior:",
    )

    observed = {
        "chunk_is_basemodel": isinstance(new_chunk, ModelResponseStream),
        "str_so_far": new_str,
        "grew": len(new_str) > len("prior:"),
    }
    assert observed == {
        "chunk_is_basemodel": True,
        "str_so_far": "prior:abc",
        "grew": True,
    }


@pytest.mark.asyncio
async def test_apply_streaming_chunk_hooks_hook_raises_exception(monkeypatch):
    """A ``RuntimeError`` raised by the streaming hook propagates to the caller."""

    async def _boom(*args, **kwargs):
        raise RuntimeError("hook failed")

    monkeypatch.setattr(ps.proxy_logging_obj, "async_post_call_streaming_hook", _boom)

    with pytest.raises(RuntimeError):
        await _apply_streaming_chunk_hooks(
            chunk=_simple_chunk(),
            user_api_key_dict=_user_auth(),
            request_data={},
            str_so_far="",
        )


# ---------------------------------------------------------------------------
# _format_streaming_sse_chunk
# ---------------------------------------------------------------------------


def test_format_streaming_sse_chunk_handles_bytes_and_str_shapes():
    """Bytes in gives bytes out and str in gives str out, each wrapped as
    ``data: <payload>\\n\\n``."""
    bytes_out = _format_streaming_sse_chunk(b'{"a":1}')
    str_out = _format_streaming_sse_chunk('{"a":1}')

    snapshot = {
        "bytes_out": bytes_out,
        "str_out": str_out,
        "bytes_starts_with_data": bytes_out.startswith(b"data: "),
    }
    assert snapshot == {
        "bytes_out": b'data: {"a":1}\n\n',
        "str_out": 'data: {"a":1}\n\n',
        "bytes_starts_with_data": True,
    }


def test_format_streaming_sse_chunk_invalid_empty_string_still_wraps():
    """Edge case: empty string still gets the ``data: \\n\\n`` wrapping
    — clients expect SSE shape even on empty payloads."""
    result = _format_streaming_sse_chunk("")
    assert result == "data: \n\n"


# ---------------------------------------------------------------------------
# async_data_generator
# ---------------------------------------------------------------------------


def _patch_logging_flags(monkeypatch, needs_wrap=False, needs_per_chunk=False):
    """Stub the logging feature flags and silence deferred stream logging.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture used to install the stubs.
        needs_wrap: value the patched ``needs_iterator_wrap`` returns.
        needs_per_chunk: value the patched ``needs_per_chunk_streaming_hook``
            returns.
    """
    monkeypatch.setattr(
        ps.proxy_logging_obj,
        "needs_iterator_wrap",
        lambda: needs_wrap,
    )
    monkeypatch.setattr(
        ps.proxy_logging_obj,
        "needs_per_chunk_streaming_hook",
        lambda: needs_per_chunk,
    )
    # ``_fire_deferred_stream_logging`` is a classmethod — patch the
    # underlying function so the no-wrap branch is a no-op rather than
    # touching real logging globals.
    monkeypatch.setattr(
        ps.ProxyLogging,
        "_fire_deferred_stream_logging",
        staticmethod(lambda request_data: None),
    )


@pytest.mark.asyncio
async def test_async_data_generator_yields_sse_chunks_and_done(monkeypatch):
    """A single-chunk stream yields one SSE-framed bytes chunk followed by
    the ``[DONE]`` sentinel."""
    _patch_logging_flags(monkeypatch)

    response = _async_iter([_simple_chunk(content="hello")])
    out = []
    async for line in async_data_generator(
        response=response,
        user_api_key_dict=_user_auth(),
        request_data={"model": "gpt-4"},
    ):
        out.append(line)

    assert out[-1] == "data: [DONE]\n\n"
    # First chunk is bytes (fast path) wrapped via _format_streaming_sse_chunk.
    first = out[0]
    assert isinstance(first, bytes)
    # ``rstrip`` takes a character set, not a suffix, so strip the exact SSE
    # terminator with ``removesuffix`` instead.
    payload = json.loads(first.removeprefix(b"data: ").removesuffix(b"\n\n"))
    assert normalize(payload) == {
        "id": "",
        "object": "chat.completion.chunk",
        "created": "",
        "model": "gpt-4",
        "choices": [
            {
                "index": 0,
                "delta": {"role": "assistant", "content": "hello"},
            }
        ],
    }


@pytest.mark.asyncio
async def test_async_data_generator_mid_stream_exception_yields_error_payload(
    monkeypatch,
):
    """When the upstream iterator raises, the generator emits an SSE error
    payload instead of propagating the exception."""
    _patch_logging_flags(monkeypatch)

    async def _noop_failure(*args, **kwargs):
        return None

    monkeypatch.setattr(ps.proxy_logging_obj, "post_call_failure_hook", _noop_failure)

    response = _async_iter_raises(RuntimeError("upstream blew up"))
    out = []
    async for line in async_data_generator(
        response=response,
        user_api_key_dict=_user_auth(),
        request_data={},
    ):
        out.append(line)

    # Only the presence of a str SSE error payload is pinned here; whether a
    # partial chunk precedes it depends on ``_async_iter_raises`` and is not
    # asserted.
    assert any(
        isinstance(item, str) and item.startswith('data: {"error":') for item in out
    )


# ---------------------------------------------------------------------------
# select_data_generator
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_select_data_generator_returns_async_generator(monkeypatch):
    """The selected generator is async-iterable, yields at least one item,
    and finishes with the ``[DONE]`` sentinel."""
    _patch_logging_flags(monkeypatch)

    response = _async_iter([_simple_chunk()])
    gen = select_data_generator(
        response=response,
        user_api_key_dict=_user_auth(),
        request_data={"model": "gpt-4"},
    )

    # Drain to confirm it really is an async iterator emitting SSE shape.
    collected = []
    async for line in gen:
        collected.append(line)

    snapshot = {
        "is_async_iterable": hasattr(gen, "__aiter__"),
        "yielded_at_least_one": len(collected) >= 1,
        # Guard the index so an empty stream fails the snapshot comparison
        # below instead of raising IndexError here.
        "ends_with_done": bool(collected) and collected[-1] == "data: [DONE]\n\n",
    }
    assert snapshot == {
        "is_async_iterable": True,
        "yielded_at_least_one": True,
        "ends_with_done": True,
    }


def test_select_data_generator_missing_required_kwarg_raises_type_error():
    """``select_data_generator`` requires all three keyword args — calling
    without ``request_data`` raises TypeError at the wrapper, before any
    streaming starts."""
    with pytest.raises(TypeError):
        select_data_generator(response=_async_iter([]), user_api_key_dict=_user_auth())  # type: ignore[call-arg]