Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion litellm/litellm_core_utils/core_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,28 @@ def _get_parent_otel_span_from_kwargs(
return None


def process_response_headers(response_headers: Union[httpx.Headers, dict]) -> dict:
def process_response_headers(
response_headers: Union[httpx.Headers, dict],
preserve_litellm_internal_headers: bool = False,
) -> dict:
"""
`preserve_litellm_internal_headers` must only be True when the input is a
LiteLLM-owned dict (e.g. `_hidden_params["additional_headers"]` that has
already been through one round of processing). For raw upstream provider
headers — whether passed as `httpx.Headers` or a plain dict — it must
remain False, otherwise a malicious provider returning `x-litellm-*` could
spoof LiteLLM-internal markers (e.g. `x-litellm-attempted-fallbacks`).

When the input is an `httpx.Headers` object the flag is always treated as
False regardless of what the caller requested, because `httpx.Headers` is
always a raw provider response and can never be LiteLLM-owned.
"""
from litellm.types.utils import OPENAI_RESPONSE_HEADERS

# Raw httpx.Headers objects come directly from provider HTTP responses and
# must never be treated as LiteLLM-owned, regardless of caller intent.
_preserve = preserve_litellm_internal_headers and isinstance(response_headers, dict)

openai_headers = {}
processed_headers = {}
additional_headers = {}
Expand All @@ -256,6 +275,12 @@ def process_response_headers(response_headers: Union[httpx.Headers, dict]) -> di
"llm_provider-"
): # return raw provider headers (incl. openai-compatible ones)
processed_headers[k] = v
elif _preserve and k.startswith("x-litellm-"):
# LiteLLM's own internal headers (e.g. x-litellm-attempted-fallbacks,
# x-litellm-model-group) are not LLM provider headers and must not be
# prefixed. Downstream consumers (proxy override, callers checking
# whether a fallback happened) look up the bare key.
processed_headers[k] = v
else:
additional_headers["{}-{}".format("llm_provider", k)] = v

Expand Down
10 changes: 8 additions & 2 deletions litellm/litellm_core_utils/fallback_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
safe_deep_copy,
filter_internal_params,
)
from litellm.router_utils.add_retry_fallback_headers import (
add_fallback_headers_to_response,
)

from .asyncify import run_async_function

Expand Down Expand Up @@ -42,7 +45,7 @@ async def async_completion_with_fallbacks(**kwargs):

# Try each fallback model
most_recent_exception_str: Optional[str] = None
for fallback in fallbacks:
for attempted_fallbacks, fallback in enumerate(fallbacks):
try:
completion_kwargs = safe_deep_copy(base_kwargs)
# Handle dictionary fallback configurations
Expand All @@ -63,7 +66,10 @@ async def async_completion_with_fallbacks(**kwargs):
)

if response is not None:
return response
return add_fallback_headers_to_response(
response=response,
attempted_fallbacks=attempted_fallbacks,
)

except Exception as e:
verbose_logger.exception(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def set_hidden_params(
result=self.result, litellm_model_name=model, router_model_id=model_id
),
"additional_headers": process_response_headers(
self._get_value_from_hidden_params("additional_headers") or {}
self._get_value_from_hidden_params("additional_headers") or {},
preserve_litellm_internal_headers=True,
Comment thread
veria-ai[bot] marked this conversation as resolved.
),
"litellm_model_name": model,
}
Expand Down
128 changes: 127 additions & 1 deletion tests/test_litellm/litellm_core_utils/test_fallback_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
"""Tests for litellm.litellm_core_utils.fallback_utils."""

import pytest
import httpx

import litellm
from litellm.litellm_core_utils.fallback_utils import async_completion_with_fallbacks
from litellm.litellm_core_utils.core_helpers import process_response_headers
from litellm.litellm_core_utils.fallback_utils import (
async_completion_with_fallbacks,
)


@pytest.mark.asyncio
Expand Down Expand Up @@ -41,3 +47,123 @@ async def _fake_acompletion(*, model: str, **kwargs):
"primary-model",
"fallback-model",
]


@pytest.mark.asyncio
async def test_async_completion_with_fallbacks_sets_attempted_fallbacks_header():
    """
    A response produced by a fallback model must expose
    `x-litellm-attempted-fallbacks` in its hidden `additional_headers`.

    The proxy and other callers use that header to detect that a fallback
    happened; without it `_override_openai_response_model` stamps the
    requested model back over the fallback model used. See issue #28241.
    """
    # The primary call is forced to fail via an Exception mock response, so
    # the single configured fallback is the one that answers.
    backup_config = {
        "model": "openai/backup-llm",
        "api_key": "fake-key",
        "mock_response": "backup-resp",
    }

    result = await async_completion_with_fallbacks(
        model="openai/primary-llm",
        messages=[{"role": "user", "content": "hi"}],
        api_key="fake-key",
        mock_response=Exception("forced failure"),
        kwargs={"fallbacks": [backup_config]},
    )

    hidden = getattr(result, "_hidden_params", None)
    assert isinstance(hidden, dict)

    additional_headers = hidden.get("additional_headers") or {}
    assert additional_headers.get("x-litellm-attempted-fallbacks") == 1


@pytest.mark.asyncio
async def test_async_completion_with_fallbacks_header_is_zero_when_primary_succeeds():
    """
    If the primary model answers on the first attempt, the
    `x-litellm-attempted-fallbacks` header must be `0` (no fallback used).

    This mirrors the existing router-level semantics in
    `async_function_with_fallbacks`.
    """
    # A fallback is configured but should never be needed: the primary call
    # is mocked to succeed.
    backup_config = {
        "model": "openai/backup-llm",
        "api_key": "fake-key",
        "mock_response": "backup-resp",
    }

    result = await async_completion_with_fallbacks(
        model="openai/primary-llm",
        messages=[{"role": "user", "content": "hi"}],
        api_key="fake-key",
        mock_response="primary-resp",
        kwargs={"fallbacks": [backup_config]},
    )

    hidden = getattr(result, "_hidden_params", None)
    assert isinstance(hidden, dict)

    additional_headers = hidden.get("additional_headers") or {}
    assert additional_headers.get("x-litellm-attempted-fallbacks") == 0

    # The content must come from the primary model's mock, not the fallback's.
    assert result.choices[0].message.content == "primary-resp"


def test_process_response_headers_preserves_x_litellm_headers_when_internal():
    """
    With `preserve_litellm_internal_headers=True` on a LiteLLM-owned dict,
    keys starting with `x-litellm-` must keep their bare name instead of
    receiving the `llm_provider-` prefix.

    These are markers set by LiteLLM itself (e.g. fallback / retry headers);
    the proxy and other callers look them up by the bare key. Any other
    header is still prefixed.
    """
    litellm_owned_headers = {
        "x-litellm-attempted-fallbacks": 1,
        "x-litellm-model-group": "gpt-4",
        "x-stainless-arch": "arm64",
    }

    processed = process_response_headers(
        litellm_owned_headers,
        preserve_litellm_internal_headers=True,
    )

    expected_entries = (
        ("x-litellm-attempted-fallbacks", 1),
        ("x-litellm-model-group", "gpt-4"),
        ("llm_provider-x-stainless-arch", "arm64"),
    )
    for header_name, header_value in expected_entries:
        assert processed[header_name] == header_value


def test_process_response_headers_prefixes_x_litellm_from_raw_provider():
    """
    With the default `preserve_litellm_internal_headers=False` (raw
    upstream-provider headers), a key starting with `x-litellm-` MUST still
    be given the `llm_provider-` prefix.

    Otherwise a malicious provider could return
    `x-litellm-attempted-fallbacks` and spoof a LiteLLM-internal marker,
    bypassing the proxy model-override guard.
    """
    spoofed_key = "x-litellm-attempted-fallbacks"
    provider_headers = {
        spoofed_key: 99,
        "x-stainless-arch": "arm64",
    }

    # No flag passed: exercises the default handling.
    processed = process_response_headers(provider_headers)

    assert spoofed_key not in processed
    assert processed["llm_provider-x-litellm-attempted-fallbacks"] == 99
    assert processed["llm_provider-x-stainless-arch"] == "arm64"


def test_process_response_headers_ignores_preserve_flag_for_httpx_headers():
    """
    `preserve_litellm_internal_headers=True` must be ignored when the input
    is an `httpx.Headers` object.

    Some providers store raw httpx.Headers directly in
    _hidden_params["additional_headers"] without a prior normalization pass.
    If the flag were honored for such inputs, a provider returning
    x-litellm-attempted-fallbacks could spoof it as a bare LiteLLM-internal
    marker and make the proxy skip stamping the correct response model.
    """
    spoofed_key = "x-litellm-attempted-fallbacks"
    provider_headers = httpx.Headers(
        {
            spoofed_key: "1",
            "content-type": "application/json",
        }
    )

    processed = process_response_headers(
        provider_headers,
        preserve_litellm_internal_headers=True,
    )

    # The header must appear only under its prefixed name.
    assert spoofed_key not in processed
    assert processed["llm_provider-x-litellm-attempted-fallbacks"] == "1"
Loading