Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions litellm/llms/azure/containers/transformation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from typing import Optional
from urllib.parse import parse_qs, urlparse, urlunparse

from litellm.llms.azure.common_utils import BaseAzureLLM
from litellm.llms.openai.containers.transformation import OpenAIContainerConfig
from litellm.types.router import GenericLiteLLMParams

# Endpoint-specific path suffixes that may appear in a deployment's api_base
# (e.g. the responses endpoint URL is stored as api_base for Azure models).
# Strip these before building the containers URL so we always start from the
# resource root (https://resource.cognitiveservices.azure.com).
_AZURE_ENDPOINT_PATHS = ("/openai/responses",)


class AzureContainerConfig(OpenAIContainerConfig):
"""
Expand All @@ -27,6 +34,27 @@ def validate_environment(
litellm_params=GenericLiteLLMParams(api_key=api_key),
)

@staticmethod
def _normalize_api_base(api_base: Optional[str]) -> Optional[str]:
"""Strip endpoint-specific path suffixes from api_base to get the resource root."""
if not api_base:
return api_base
parsed = urlparse(api_base)
path = parsed.path.rstrip("/")
for ep in _AZURE_ENDPOINT_PATHS:
if path.endswith(ep):
return urlunparse(
(parsed.scheme, parsed.netloc, path[: -len(ep)], "", "", "")
)
Comment thread
cursor[bot] marked this conversation as resolved.
return api_base

@staticmethod
def _extract_api_version(api_base: Optional[str]) -> Optional[str]:
"""Return the api-version query param from api_base if present."""
if not api_base:
return None
return parse_qs(urlparse(api_base).query).get("api-version", [None])[0]

def get_complete_url(
self,
api_base: Optional[str],
Expand All @@ -39,10 +67,19 @@ def get_complete_url(
{endpoint}/openai/v1/containers
when api_version is 'v1', 'latest', or 'preview'; otherwise:
{endpoint}/openai/containers
The deployment's api_base may be the responses endpoint URL
(e.g. .../openai/responses?api-version=2025-04-01-preview). We
prefer the api-version embedded there over the deployment's
api_version field, which may point to an older chat API version.
"""
effective_params = dict(litellm_params)
api_version_from_base = self._extract_api_version(api_base)
if api_version_from_base:
effective_params["api_version"] = api_version_from_base
return BaseAzureLLM._get_base_azure_url(
api_base=api_base,
litellm_params=litellm_params,
api_base=self._normalize_api_base(api_base),
litellm_params=effective_params,
route="/openai/containers",
default_api_version="v1",
)
26 changes: 18 additions & 8 deletions litellm/llms/custom_httpx/container_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,26 +257,31 @@ def _sync_handle(
returns_binary = endpoint_config.get("returns_binary", False)
is_multipart = endpoint_config.get("is_multipart", False)

# An empty dict passed as `params` to httpx strips any existing query
# string from the URL (e.g. ?api-version=...). Use None instead so
# httpx leaves the URL's own query string intact.
effective_params = query_params or None

try:
if method == "GET":
response = http_client.get(
url=url, headers=headers, params=query_params
url=url, headers=headers, params=effective_params
)
elif method == "DELETE":
response = http_client.delete(
url=url, headers=headers, params=query_params
url=url, headers=headers, params=effective_params
)
elif method == "POST":
if is_multipart and "file" in kwargs:
files, headers = _prepare_multipart_file_upload(
kwargs["file"], headers
)
response = http_client.post(
url=url, headers=headers, params=query_params, files=files
url=url, headers=headers, params=effective_params, files=files
)
else:
response = http_client.post(
url=url, headers=headers, params=query_params
url=url, headers=headers, params=effective_params
)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
Expand Down Expand Up @@ -376,26 +381,31 @@ async def _async_handle(
returns_binary = endpoint_config.get("returns_binary", False)
is_multipart = endpoint_config.get("is_multipart", False)

# An empty dict passed as `params` to httpx strips any existing query
Comment thread
cursor[bot] marked this conversation as resolved.
# string from the URL (e.g. ?api-version=...). Use None instead so
# httpx leaves the URL's own query string intact.
effective_params = query_params or None

try:
if method == "GET":
response = await http_client.get(
url=url, headers=headers, params=query_params
url=url, headers=headers, params=effective_params
)
elif method == "DELETE":
response = await http_client.delete(
url=url, headers=headers, params=query_params
url=url, headers=headers, params=effective_params
)
elif method == "POST":
if is_multipart and "file" in kwargs:
files, headers = _prepare_multipart_file_upload(
kwargs["file"], headers
)
response = await http_client.post(
url=url, headers=headers, params=query_params, files=files
url=url, headers=headers, params=effective_params, files=files
)
else:
response = await http_client.post(
url=url, headers=headers, params=query_params
url=url, headers=headers, params=effective_params
)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
Expand Down
20 changes: 10 additions & 10 deletions litellm/llms/custom_httpx/llm_http_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7834,7 +7834,7 @@ def container_list_handler(
response = sync_httpx_client.get(
url=url,
headers=headers,
params=params,
params=params or None,
)

return container_provider_config.transform_container_list_response(
Expand Down Expand Up @@ -7911,7 +7911,7 @@ async def async_container_list_handler(
response = await async_httpx_client.get(
url=url,
headers=headers,
params=params,
params=params or None,
)

return container_provider_config.transform_container_list_response(
Expand Down Expand Up @@ -8001,7 +8001,7 @@ def container_retrieve_handler(
response = sync_httpx_client.get(
url=url,
headers=headers,
params=params,
params=params or None,
)

return container_provider_config.transform_container_retrieve_response(
Expand Down Expand Up @@ -8078,7 +8078,7 @@ async def async_container_retrieve_handler(
response = await async_httpx_client.get(
url=url,
headers=headers,
params=params,
params=params or None,
)

return container_provider_config.transform_container_retrieve_response(
Expand Down Expand Up @@ -8168,7 +8168,7 @@ def container_delete_handler(
response = sync_httpx_client.delete(
url=url,
headers=headers,
params=params,
params=params or None,
)

return container_provider_config.transform_container_delete_response(
Expand Down Expand Up @@ -8245,7 +8245,7 @@ async def async_container_delete_handler(
response = await async_httpx_client.delete(
url=url,
headers=headers,
params=params,
params=params or None,
)

return container_provider_config.transform_container_delete_response(
Expand Down Expand Up @@ -8341,7 +8341,7 @@ def container_file_list_handler(
response = sync_httpx_client.get(
url=url,
headers=headers,
params=params,
params=params or None,
)

return container_provider_config.transform_container_file_list_response(
Expand Down Expand Up @@ -8420,7 +8420,7 @@ async def async_container_file_list_handler(
response = await async_httpx_client.get(
url=url,
headers=headers,
params=params,
params=params or None,
)

return container_provider_config.transform_container_file_list_response(
Expand Down Expand Up @@ -8508,7 +8508,7 @@ def container_file_content_handler(
response = sync_httpx_client.get(
url=url,
headers=headers,
params=params,
params=params or None,
)

return container_provider_config.transform_container_file_content_response(
Expand Down Expand Up @@ -8584,7 +8584,7 @@ async def async_container_file_content_handler(
response = await async_httpx_client.get(
url=url,
headers=headers,
params=params,
params=params or None,
)

return container_provider_config.transform_container_file_content_response(
Expand Down

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions litellm/proxy/container_endpoints/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ async def retrieve_container(
custom_llm_provider=custom_llm_provider,
)
data.update(
get_container_forwarding_params(
await get_container_forwarding_params(
container_id,
original_container_id,
custom_llm_provider,
Expand Down Expand Up @@ -433,7 +433,7 @@ async def delete_container(
custom_llm_provider=custom_llm_provider,
)
data.update(
get_container_forwarding_params(
await get_container_forwarding_params(
container_id,
original_container_id,
custom_llm_provider,
Expand Down
14 changes: 8 additions & 6 deletions litellm/proxy/container_endpoints/handler_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,12 @@ async def _process_binary_request(
)
data: Dict[str, Any] = {
"file_id": file_id,
**get_container_forwarding_params(
container_id=container_id,
original_container_id=original_container_id,
custom_llm_provider=resolved_provider,
**(
await get_container_forwarding_params(
container_id=container_id,
original_container_id=original_container_id,
custom_llm_provider=resolved_provider,
)
),
}
processor = ProxyBaseLLMRequestProcessing(data=data)
Expand Down Expand Up @@ -316,7 +318,7 @@ async def _process_multipart_upload_request(
)

data.update(
get_container_forwarding_params(
await get_container_forwarding_params(
container_id=container_id,
original_container_id=original_container_id,
custom_llm_provider=resolved_provider,
Expand Down Expand Up @@ -396,7 +398,7 @@ async def _process_request(
)
)
data.update(
get_container_forwarding_params(
await get_container_forwarding_params(
container_id=path_params["container_id"],
original_container_id=original_container_id,
custom_llm_provider=resolved_provider,
Expand Down
75 changes: 74 additions & 1 deletion litellm/proxy/container_endpoints/ownership.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@
_NEGATIVE_OWNER_SENTINEL = "__litellm_container_no_owner__"
_CONTAINER_OWNER_CACHE = InMemoryCache(max_size_in_memory=10000, default_ttl=60)

# Caches the stored ``unified_object_id`` (the encoded container ID
# captured at create time) so ``get_container_forwarding_params`` can
# recover the deployment ``model_id`` for native upstream IDs without
# re-hitting Prisma on every retrieve/delete.
_NEGATIVE_STORED_ID_SENTINEL = "__litellm_container_no_stored_id__"
_CONTAINER_STORED_ID_CACHE = InMemoryCache(max_size_in_memory=10000, default_ttl=60)

# Per-caller-scope cache for ``GET /v1/containers`` list filtering. Without
# this, every list call issues a fresh ``find_many`` against
# ``litellm_managedobjecttable``. The cache key is the sorted owner-scope
Expand Down Expand Up @@ -56,7 +63,7 @@ def decode_container_id_for_ownership(
return original_container_id, custom_llm_provider


def get_container_forwarding_params(
async def get_container_forwarding_params(
container_id: str, original_container_id: str, custom_llm_provider: str
) -> Dict[str, str]:
params = {
Expand All @@ -65,6 +72,20 @@ def get_container_forwarding_params(
}
decoded = ResponsesAPIRequestUtils._decode_container_id(container_id)
model_id = decoded.get("model_id")
if not (isinstance(model_id, str) and model_id):
# Native upstream IDs (e.g. Azure ``cntr_<hex>``) carry no LiteLLM
# routing payload, so decoding the user-supplied id yields no
# ``model_id``. Recover it from the encoded ``unified_object_id``
# captured on the ownership row at create time — when the router
# selected a specific deployment that ID embeds the model_id.
stored_id = await _get_stored_container_id(
original_container_id, custom_llm_provider
)
if stored_id and stored_id != container_id:
stored_decoded = ResponsesAPIRequestUtils._decode_container_id(stored_id)
stored_model_id = stored_decoded.get("model_id")
if isinstance(stored_model_id, str) and stored_model_id:
model_id = stored_model_id
if isinstance(model_id, str) and model_id:
params["model_id"] = model_id
return params
Expand Down Expand Up @@ -168,6 +189,7 @@ async def record_container_owner(
)

_CONTAINER_OWNER_CACHE.set_cache(model_object_id, owner)
_CONTAINER_STORED_ID_CACHE.set_cache(model_object_id, container_id)
# Drop the caller's own list-cache entry so the just-created container
# shows up on their next ``GET /v1/containers``. Other callers with
# disjoint scope tuples have their own entries; intersecting-scope
Expand Down Expand Up @@ -207,9 +229,60 @@ async def _get_container_owner(
_CONTAINER_OWNER_CACHE.set_cache(
model_object_id, owner if owner is not None else _NEGATIVE_OWNER_SENTINEL
)
stored_id = getattr(row, "unified_object_id", None) if row is not None else None
_CONTAINER_STORED_ID_CACHE.set_cache(
model_object_id,
(
stored_id
if isinstance(stored_id, str) and stored_id
else _NEGATIVE_STORED_ID_SENTINEL
),
)
return owner


async def _get_stored_container_id(
original_container_id: str, custom_llm_provider: str
) -> Optional[str]:
"""Return the ``unified_object_id`` stored at create time, if any.

Used by :func:`get_container_forwarding_params` to recover the
deployment ``model_id`` for native upstream container IDs: the stored
value is the encoded form produced by ``encode_container_id_in_response``
when the router selected a specific deployment.
"""
model_object_id = _container_model_object_id(
original_container_id, custom_llm_provider
)

cached = _CONTAINER_STORED_ID_CACHE.get_cache(model_object_id)
if cached == _NEGATIVE_STORED_ID_SENTINEL:
return None
if isinstance(cached, str) and cached:
return cached

prisma_client = await _get_prisma_client()
if prisma_client is None:
return None

row = await prisma_client.db.litellm_managedobjecttable.find_first(
where={
"model_object_id": model_object_id,
"file_purpose": CONTAINER_OBJECT_PURPOSE,
}
)
stored_id = getattr(row, "unified_object_id", None) if row is not None else None
_CONTAINER_STORED_ID_CACHE.set_cache(
model_object_id,
(
stored_id
if isinstance(stored_id, str) and stored_id
else _NEGATIVE_STORED_ID_SENTINEL
),
)
return stored_id if isinstance(stored_id, str) and stored_id else None


async def assert_user_can_access_container(
container_id: str,
user_api_key_dict: UserAPIKeyAuth,
Expand Down
Loading
Loading