feat(guardrails): add Microsoft Purview DLP guardrail#24966
Greptile Summary
This PR adds a Microsoft Purview DLP guardrail that evaluates prompts and responses via the Microsoft Graph
Confidence Score: 5/5
The new guardrail is well-scoped, uses litellm's existing HTTP handler, and the previously identified bugs (missing raise_for_status, unbounded cache, missing correlation_id) are all addressed in the current implementation.
The core implementation handles every failure mode correctly: Graph API errors surface via raise_for_status and are mapped to appropriate HTTP status codes, the protection-scope cache is bounded with LRU eviction, correlation IDs are derived from litellm_call_id, and blocking mode consistently fails closed. The refactoring of get_last_user_message is a clean deduplication with verified behavioral equivalence. The test suite is thorough and uses only mocks.
No files require special attention. The most complex logic in base.py and purview_dlp.py is well-covered by the test suite.
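The bounded scope cache mentioned in the summary can be pictured with a small sketch. This is an illustration only: `BoundedScopeCache` and its method names are hypothetical and are not the PR's class, and the sketch assumes an `OrderedDict` keyed by user ID with a fixed TTL, which is how the diff describes the cache.

```python
import time
from collections import OrderedDict
from typing import Any, Dict, Optional, Tuple


class BoundedScopeCache:
    """Hypothetical stand-in for an LRU-bounded, TTL-limited scope cache."""

    def __init__(self, maxsize: int = 1000, ttl_seconds: float = 3600.0) -> None:
        # user_id -> (etag, scope_response, fetched_at)
        self._entries: "OrderedDict[str, Tuple[str, Dict[str, Any], float]]" = (
            OrderedDict()
        )
        self._maxsize = maxsize
        self._ttl = ttl_seconds

    def get(
        self, user_id: str, now: Optional[float] = None
    ) -> Optional[Tuple[str, Dict[str, Any]]]:
        now = time.time() if now is None else now
        entry = self._entries.get(user_id)
        if entry is None or (now - entry[2]) >= self._ttl:
            return None  # miss or expired: the caller re-fetches
        self._entries.move_to_end(user_id)  # mark as most recently used
        return entry[0], entry[1]

    def put(
        self,
        user_id: str,
        etag: str,
        scope: Dict[str, Any],
        now: Optional[float] = None,
    ) -> None:
        now = time.time() if now is None else now
        self._entries[user_id] = (etag, scope, now)
        # Assigning to an existing key keeps its old position, so move it.
        self._entries.move_to_end(user_id)
        while len(self._entries) > self._maxsize:
            self._entries.popitem(last=False)  # evict least recently used
```

Passing `now` explicitly is a convenience of this sketch so that expiry and eviction can be exercised without sleeping.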
| Filename | Overview |
|---|---|
| litellm/proxy/guardrails/guardrail_hooks/microsoft_purview/base.py | New PurviewGuardrailBase: OAuth2 token caching, LRU-bounded scope cache (1000 entries, 1h TTL), ETag-based protection scope reuse, and Graph API helpers — all with raise_for_status() checks and correlation_id from litellm_call_id. |
| litellm/proxy/guardrails/guardrail_hooks/microsoft_purview/purview_dlp.py | New MicrosoftPurviewDLPGuardrail: pre_call / post_call / logging_only modes with full streaming support, Responses API handling, fail-closed blocking, and deferred async audit via logging_hook/async_logging_hook. |
| litellm/proxy/guardrails/guardrail_hooks/microsoft_purview/`__init__.py` | Guardrail initializer and registry entries for microsoft_purview; validates required credentials and wires up MicrosoftPurviewDLPGuardrail to the callback manager. |
| litellm/proxy/guardrails/guardrail_hooks/azure/base.py | Refactored get_last_user_message method to delegate to the shared common_utils utility, removing ~30 lines of duplicated logic. Behavior is identical. |
| litellm/proxy/guardrails/guardrail_hooks/openai/base.py | Same deduplication refactor as azure/base.py — delegates get_last_user_message to the shared utility. |
| litellm/litellm_core_utils/prompt_templates/common_utils.py | Minor docstring correction (get_user_prompt → get_last_user_message) and removal of a redundant circular self-import; function logic unchanged. |
| litellm/types/guardrails.py | Adds MICROSOFT_PURVIEW to SupportedGuardrailIntegrations enum; alphabetically reorders AktoConfigModel import. |
| tests/test_litellm/proxy/guardrails/guardrail_hooks/test_microsoft_purview.py | 2659-line unit-test suite using AsyncMock/Mock/patch — no real network calls; covers should_block, token caching, scope caching, pre/post/logging hooks, streaming, Responses API, and identity resolution. |
| .circleci/config.yml | Removes a single stray blank line between two CI job definitions; no functional change. |
Reviews (27): Last reviewed commit: "fix(purview): drop caller-influenceable ..."
completion_prompt_to_str returns None for both token-id lists *and*
empty/whitespace-only strings (stripped). The previous check 'raw_prompt
is not None and prompt_text is None' conflated these cases, raising the
misleading 'Token-id completion prompts cannot be scanned' error for
harmless empty-string prompts like {"prompt": ""}.
Tighten the check to only reject true token-id prompts (non-empty list
of ints). Empty/whitespace string prompts now fall through to the
'no prompt text → skip scan' path.
Co-authored-by: Yassin Kortam <yassin@berri.ai>
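As a rough illustration of the tightened check described in this commit message, the following sketch separates true token-id prompts from empty or whitespace-only strings. The helper name `is_token_id_prompt` is hypothetical; the PR's actual code may be structured differently.

```python
from typing import Any


def is_token_id_prompt(raw_prompt: Any) -> bool:
    """Return True only for a non-empty list made up entirely of ints."""
    return (
        isinstance(raw_prompt, list)
        and len(raw_prompt) > 0
        and all(isinstance(x, int) for x in raw_prompt)
    )
```

Under this check, `{"prompt": ""}` is no longer rejected: it yields no prompt text and falls through to the skip-scan path, while a prompt such as `[1, 2, 3]` is still rejected as unscannable.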
The streaming iterator hook previously routed all assembled streams through stream_chunk_builder, which only knows chat/text-completion deltas. Responses API streams emit typed events (response.created, response.completed, ...) whose final event carries the full ResponsesAPIResponse, so stream_chunk_builder would raise APIError or pass the assembled response through unchanged. Detect Responses API streaming chunks before the chat/text fallthrough and extract the assembled ResponsesAPIResponse from the latest response.completed (or response.failed / response.incomplete) event, then scan its output_text via the same _completion_response_text_parts path used by non-streaming.
Cursor Bugbot has reviewed your changes using high effort and found 1 potential issue.
There are 2 total unresolved issues (including 1 from previous review).
Autofix Details
Bugbot Autofix prepared a fix for the issue found in the latest run.
- ✅ Fixed: Responses API stream detection can silently misroute events
  - Changed `_assemble_responses_api_from_chunks` to return a `(is_responses_api_stream, assembled)` tuple, and updated the streaming hook to fail closed with an accurate "Incomplete Responses API stream" error when events are detected but no final response-bearing event arrives, instead of silently falling through to `stream_chunk_builder`.
Preview (fd38a2ac0a)
diff --git a/.circleci/config.yml b/.circleci/config.yml
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -2541,7 +2541,6 @@
paths:
- litellm-docker-database.tar.zst
-
test_bad_database_url:
machine:
image: ubuntu-2204:2024.04.1
diff --git a/litellm/litellm_core_utils/prompt_templates/common_utils.py b/litellm/litellm_core_utils/prompt_templates/common_utils.py
--- a/litellm/litellm_core_utils/prompt_templates/common_utils.py
+++ b/litellm/litellm_core_utils/prompt_templates/common_utils.py
@@ -1204,12 +1204,8 @@
{"role": "assistant", "content": "I'm good, thank you!"},
{"role": "user", "content": "What is the weather in Tokyo?"},
]
- get_user_prompt(messages) -> "What is the weather in Tokyo?"
+ get_last_user_message(messages) -> "What is the weather in Tokyo?"
"""
- from litellm.litellm_core_utils.prompt_templates.common_utils import (
- convert_content_list_to_str,
- )
-
if not messages:
return None
diff --git a/litellm/proxy/guardrails/guardrail_hooks/azure/base.py b/litellm/proxy/guardrails/guardrail_hooks/azure/base.py
--- a/litellm/proxy/guardrails/guardrail_hooks/azure/base.py
+++ b/litellm/proxy/guardrails/guardrail_hooks/azure/base.py
@@ -2,6 +2,9 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from litellm._logging import verbose_proxy_logger
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+ get_last_user_message,
+)
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
@@ -134,32 +137,4 @@
]
get_user_prompt(messages) -> "What is the weather in Tokyo?"
"""
- from litellm.litellm_core_utils.prompt_templates.common_utils import (
- convert_content_list_to_str,
- )
-
- if not messages:
- return None
-
- # Iterate from the end to find the last consecutive block of user messages
- user_messages = []
- for message in reversed(messages):
- if message.get("role") == "user":
- user_messages.append(message)
- else:
- # Stop when we hit a non-user message
- break
-
- if not user_messages:
- return None
-
- # Reverse to get the messages in chronological order
- user_messages.reverse()
-
- user_prompt = ""
- for message in user_messages:
- text_content = convert_content_list_to_str(message)
- user_prompt += text_content + "\n"
-
- result = user_prompt.strip()
- return result if result else None
+ return get_last_user_message(messages)
diff --git a/litellm/proxy/guardrails/guardrail_hooks/microsoft_purview/__init__.py b/litellm/proxy/guardrails/guardrail_hooks/microsoft_purview/__init__.py
new file mode 100644
--- /dev/null
+++ b/litellm/proxy/guardrails/guardrail_hooks/microsoft_purview/__init__.py
@@ -1,0 +1,57 @@
+from typing import TYPE_CHECKING
+
+from litellm.types.guardrails import SupportedGuardrailIntegrations
+
+from .purview_dlp import MicrosoftPurviewDLPGuardrail
+
+if TYPE_CHECKING:
+ from litellm.types.guardrails import Guardrail, LitellmParams
+
+
+def initialize_guardrail(litellm_params: "LitellmParams", guardrail: "Guardrail"):
+ import litellm
+
+ tenant_id = getattr(litellm_params, "tenant_id", None)
+ client_id = getattr(litellm_params, "client_id", None)
+
+ # client_secret can be passed via the standard api_key field or as
+ # a dedicated client_secret parameter.
+ client_secret = litellm_params.api_key or getattr(
+ litellm_params, "client_secret", None
+ )
+
+ if not tenant_id:
+ raise ValueError("Microsoft Purview: tenant_id is required")
+ if not client_id:
+ raise ValueError("Microsoft Purview: client_id is required")
+ if not client_secret:
+ raise ValueError("Microsoft Purview: client_secret (or api_key) is required")
+
+ guardrail_name = guardrail.get("guardrail_name")
+ if not guardrail_name:
+ raise ValueError("Microsoft Purview: guardrail_name is required")
+
+ purview_guardrail = MicrosoftPurviewDLPGuardrail(
+ guardrail_name=guardrail_name,
+ tenant_id=str(tenant_id),
+ client_id=str(client_id),
+ client_secret=str(client_secret),
+ purview_app_name=str(
+ getattr(litellm_params, "purview_app_name", None) or "LiteLLM"
+ ),
+ user_id_field=str(getattr(litellm_params, "user_id_field", None) or "user_id"),
+ event_hook=litellm_params.mode,
+ default_on=litellm_params.default_on,
+ )
+
+ litellm.logging_callback_manager.add_litellm_callback(purview_guardrail)
+ return purview_guardrail
+
+
+guardrail_initializer_registry = {
+ SupportedGuardrailIntegrations.MICROSOFT_PURVIEW.value: initialize_guardrail,
+}
+
+guardrail_class_registry = {
+ SupportedGuardrailIntegrations.MICROSOFT_PURVIEW.value: MicrosoftPurviewDLPGuardrail,
+}
diff --git a/litellm/proxy/guardrails/guardrail_hooks/microsoft_purview/base.py b/litellm/proxy/guardrails/guardrail_hooks/microsoft_purview/base.py
new file mode 100644
--- /dev/null
+++ b/litellm/proxy/guardrails/guardrail_hooks/microsoft_purview/base.py
@@ -1,0 +1,487 @@
+import threading
+import time
+import uuid
+from collections import OrderedDict
+from types import SimpleNamespace
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+
+from litellm._logging import verbose_proxy_logger
+from litellm.litellm_core_utils.url_utils import encode_url_path_segment
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+ convert_content_list_to_str,
+)
+from litellm.llms.custom_httpx.http_handler import (
+ get_async_httpx_client,
+ httpxSpecialProvider,
+)
+
+if TYPE_CHECKING:
+ from litellm.types.llms.openai import AllMessageValues
+
+GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"
+TOKEN_ENDPOINT_TEMPLATE = (
+ "https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/token"
+)
+GRAPH_SCOPE = "https://graph.microsoft.com/.default"
+
+# Protection scope cache TTL in seconds (1 hour, per Microsoft recommendation).
+SCOPE_CACHE_TTL_SECONDS = 3600.0
+
+
+class PurviewGuardrailBase:
+ """
+ Base class for Microsoft Purview guardrails.
+
+ Manages OAuth2 client-credentials token acquisition, protection scope
+ computation with ETag caching, and authenticated POST calls to the
+ Microsoft Graph API.
+ """
+
+ def __init__(
+ self,
+ tenant_id: str,
+ client_id: str,
+ client_secret: str,
+ purview_app_name: str = "LiteLLM",
+ user_id_field: str = "user_id",
+ **kwargs: Any,
+ ):
+ # Forward remaining kwargs to the next class in the MRO
+ # (typically CustomGuardrail).
+ super().__init__(**kwargs)
+
+ self.async_handler = get_async_httpx_client(
+ llm_provider=httpxSpecialProvider.GuardrailCallback
+ )
+ self.tenant_id = tenant_id
+ self.client_id = client_id
+ self.client_secret = client_secret
+ self.purview_app_name = purview_app_name
+ self.user_id_field = user_id_field
+
+ # Token cache: (access_token, expires_at_epoch)
+ self._token_cache: Optional[Tuple[str, float]] = None
+
+ # Protection scope cache: user_id -> (etag, scope_response, fetched_at)
+ # Capped at 1000 entries (LRU eviction) to avoid unbounded growth.
+ self._scope_cache: OrderedDict[str, Tuple[str, Dict[str, Any], float]] = (
+ OrderedDict()
+ )
+ self._scope_cache_maxsize = 1000
+ # Use a threading.Lock (not asyncio.Lock) because this lock is acquired
+ # from both the proxy's main asyncio event loop and from short-lived
+ # event loops created by the logging_hook thread fallback. In Python
+ # 3.10+ an asyncio.Lock is bound to the first event loop that acquires
+ # it and raises RuntimeError from any other loop, which would silently
+ # break audit logging via the thread fallback. All critical sections
+ # below are pure in-memory dict ops with no awaits, so a synchronous
+ # lock is both correct and sufficient.
+ self._cache_lock = threading.Lock()
+
+ @staticmethod
+ def _encode_graph_user_id(user_id: str) -> str:
+ """Percent-encode Entra user id for Graph ``/users/{id}/...`` path segments."""
+ return encode_url_path_segment(user_id, field_name="user_id")
+
+ # ------------------------------------------------------------------
+ # OAuth2 token management
+ # ------------------------------------------------------------------
+
+ async def _get_access_token(self) -> str:
+ """Acquire or return cached OAuth2 token via client_credentials grant."""
+ now = time.time()
+ with self._cache_lock:
+ if self._token_cache and self._token_cache[1] > now + 60:
+ return self._token_cache[0]
+
+ url = TOKEN_ENDPOINT_TEMPLATE.format(tenant_id=self.tenant_id)
+ data = {
+ "grant_type": "client_credentials",
+ "client_id": self.client_id,
+ "client_secret": self.client_secret,
+ "scope": GRAPH_SCOPE,
+ }
+ response = await self.async_handler.post(
+ url=url,
+ data=data,
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
+ )
+ response.raise_for_status()
+ token_data = response.json()
+ access_token = token_data["access_token"]
+ expires_in = int(token_data.get("expires_in", 3599))
+ # Recompute ``now`` after the await so the expiry reflects when the
+ # token was actually received, not when the request started.
+ with self._cache_lock:
+ self._token_cache = (access_token, time.time() + expires_in)
+ verbose_proxy_logger.debug(
+ "Purview: acquired new OAuth2 token (expires_in=%ds)", expires_in
+ )
+ return access_token
+
+ # ------------------------------------------------------------------
+ # Graph API helpers
+ # ------------------------------------------------------------------
+
+ async def _graph_post(
+ self,
+ url: str,
+ json_body: Dict[str, Any],
+ extra_headers: Optional[Dict[str, str]] = None,
+ ) -> Tuple[Dict[str, Any], Dict[str, str]]:
+ """POST to Graph API with bearer auth.
+
+ Returns:
+ Tuple of (response_json, response_headers).
+ """
+ token = await self._get_access_token()
+ headers = {
+ "Authorization": f"Bearer {token}",
+ "Content-Type": "application/json",
+ }
+ if extra_headers:
+ headers.update(extra_headers)
+
+ verbose_proxy_logger.debug("Purview Graph POST %s", url)
+ response = await self.async_handler.post(
+ url=url, headers=headers, json=json_body
+ )
+ response.raise_for_status()
+ response_json: Dict[str, Any] = response.json()
+ response_headers = dict(response.headers)
+ verbose_proxy_logger.debug("Purview Graph response: %s", response_json)
+ return response_json, response_headers
+
+ # ------------------------------------------------------------------
+ # Protection scopes
+ # ------------------------------------------------------------------
+
+ async def _compute_protection_scopes(
+ self, user_id: str
+ ) -> Tuple[str, Dict[str, Any]]:
+ """Call protectionScopes/compute and cache with ETag.
+
+ Returns:
+ Tuple of (etag, scope_response).
+ """
+ encoded_user_id = self._encode_graph_user_id(user_id)
+ now = time.time()
+
+ with self._cache_lock:
+ cached = self._scope_cache.get(user_id)
+ if cached and (now - cached[2]) < SCOPE_CACHE_TTL_SECONDS:
+ self._scope_cache.move_to_end(user_id)
+ return cached[0], cached[1]
+
+ url = (
+ f"{GRAPH_API_BASE}/users/{encoded_user_id}"
+ "/dataSecurityAndGovernance/protectionScopes/compute"
+ )
+ body: Dict[str, Any] = {
+ "activities": "uploadText,downloadText",
+ "locations": [
+ {
+ "@odata.type": "microsoft.graph.policyLocationApplication",
+ "value": self.client_id,
+ }
+ ],
+ }
+
+ response_json, response_headers = await self._graph_post(url, body)
+ etag = response_headers.get("etag", response_headers.get("ETag", ""))
+
+ # Recompute ``now`` after the await so the TTL reflects when the
+ # scope response was actually received, not when the request started.
+ fetched_at = time.time()
+ with self._cache_lock:
+ self._scope_cache[user_id] = (etag, response_json, fetched_at)
+ # Move refreshed entry to the end so it is treated as most-recently-used.
+ # OrderedDict.__setitem__ preserves existing insertion order for known
+ # keys, so an explicit move_to_end() call is required.
+ self._scope_cache.move_to_end(user_id)
+ # Evict least-recently-used entry when cache exceeds max size.
+ while len(self._scope_cache) > self._scope_cache_maxsize:
+ self._scope_cache.popitem(last=False)
+ return etag, response_json
+
+ # ------------------------------------------------------------------
+ # Process content
+ # ------------------------------------------------------------------
+
+ async def _process_content(
+ self,
+ user_id: str,
+ text: str,
+ activity: str,
+ etag: str,
+ correlation_id: Optional[str] = None,
+ ) -> Dict[str, Any]:
+ """Call processContent for DLP policy evaluation.
+
+ Args:
+ user_id: Entra object ID of the user.
+ text: The content to evaluate.
+ activity: ``"uploadText"`` for prompts, ``"downloadText"`` for responses.
+ etag: Cached ETag from protectionScopes/compute.
+ correlation_id: Optional conversation/thread ID.
+ """
+ encoded_user_id = self._encode_graph_user_id(user_id)
+ url = (
+ f"{GRAPH_API_BASE}/users/{encoded_user_id}"
+ "/dataSecurityAndGovernance/processContent"
+ )
+ body: Dict[str, Any] = {
+ "contentToProcess": {
+ "contentEntries": [
+ {
+ "@odata.type": "microsoft.graph.processConversationMetadata",
+ "identifier": str(uuid.uuid4()),
+ "content": {
+ "@odata.type": "microsoft.graph.textContent",
+ "data": text,
+ },
+ "name": f"{self.purview_app_name} message",
+ "correlationId": correlation_id or str(uuid.uuid4()),
+ "sequenceNumber": 0,
+ "isTruncated": False,
+ }
+ ],
+ "activityMetadata": {"activity": activity},
+ "deviceMetadata": {},
+ "protectedAppMetadata": {
+ "name": self.purview_app_name,
+ "version": "1.0",
+ "applicationLocation": {
+ "@odata.type": "microsoft.graph.policyLocationApplication",
+ "value": self.client_id,
+ },
+ },
+ "integratedAppMetadata": {
+ "name": self.purview_app_name,
+ "version": "1.0",
+ },
+ }
+ }
+
+ extra_headers: Dict[str, str] = {}
+ if etag:
+ extra_headers["If-None-Match"] = etag
+
+ response_json, _ = await self._graph_post(url, body, extra_headers)
+
+ # If policies changed, invalidate scope cache so next call re-fetches.
+ if response_json.get("protectionScopeState") == "modified":
+ with self._cache_lock:
+ self._scope_cache.pop(user_id, None)
+
+ return response_json
+
+ # ------------------------------------------------------------------
+ # User ID resolution
+ # ------------------------------------------------------------------
+
+ def _resolve_user_id(
+ self, data: Dict[str, Any], user_api_key_dict: Any
+ ) -> Optional[str]:
+ """Resolve the Entra user object ID from request data or auth context.
+
+ Trust order (strongest first). Blocking DLP uses only
+ ``_resolve_trusted_user_id`` (API-key-bound ``user_id``). This resolver
+ also supports audit/logging fallbacks:
+
+ 1. ``user_api_key_dict.user_id`` — LiteLLM key / JWT-bound user
+ 2. ``user_api_key_dict.end_user_id`` — request-derived; audit only
+ 3. ``metadata["user_api_key_user_id"]`` — proxy-injected from the key
+ 4. ``metadata[user_id_field]`` — caller-supplied; audit only
+ """
+ trusted = self._resolve_trusted_user_id(data, user_api_key_dict)
+ if trusted:
+ return trusted
+
+ if hasattr(user_api_key_dict, "end_user_id") and user_api_key_dict.end_user_id:
+ return str(user_api_key_dict.end_user_id)
+
+ metadata = data.get("metadata") or data.get("litellm_metadata") or {}
+ uid = metadata.get("user_api_key_user_id")
+ if uid:
+ return str(uid)
+
+ uid = metadata.get(self.user_id_field)
+ if uid:
+ return str(uid)
+
+ return None
+
+ @staticmethod
+ def _logging_kwargs_metadata(kwargs: Dict[str, Any]) -> Dict[str, Any]:
+ """Metadata dict from ``model_call_details`` / logging kwargs."""
+ litellm_params = kwargs.get("litellm_params") or {}
+ if not isinstance(litellm_params, dict):
+ return {}
+ md = litellm_params.get("metadata")
+ return md if isinstance(md, dict) else {}
+
+ def _resolve_trusted_user_id(
+ self, data: Dict[str, Any], user_api_key_dict: Any
+ ) -> Optional[str]:
+ """Resolve user ID from API-key/JWT-bound identity for blocking DLP.
+
+ Uses only ``UserAPIKeyAuth.user_id`` (bound on the LiteLLM key or JWT).
+ Intentionally omits ``UserAPIKeyAuth.end_user_id`` because the proxy sets
+ it from caller-controlled request fields (``user``, ``metadata.user_id``,
+ ``safety_identifier``, custom headers, etc.) via
+ ``get_end_user_id_from_request_body``.
+
+ Also omits ``metadata[user_id_field]`` and
+ ``metadata["user_api_key_user_id"]`` for the same impersonation risk when
+ the key has no bound user.
+
+ Returns ``None`` when no authenticated identity is available. Blocking
+ hooks must fail closed rather than skip the DLP check.
+ """
+ if hasattr(user_api_key_dict, "user_id") and user_api_key_dict.user_id:
+ return str(user_api_key_dict.user_id)
+
+ return None
+
+ def _resolve_user_id_from_logging_kwargs(
+ self, kwargs: Dict[str, Any]
+ ) -> Optional[str]:
+ """Same trust order as ``_resolve_user_id`` for logging-only hooks (no ``UserAPIKeyAuth``)."""
+ md = self._logging_kwargs_metadata(kwargs)
+ shim = SimpleNamespace(
+ user_id=md.get("user_api_key_user_id")
+ or kwargs.get("user_api_key_user_id"),
+ end_user_id=md.get("user_api_key_end_user_id")
+ or kwargs.get("user_api_key_end_user_id"),
+ )
+ return self._resolve_user_id({"metadata": md}, shim)
+
+ # ------------------------------------------------------------------
+ # Policy action evaluation
+ # ------------------------------------------------------------------
+
+ @staticmethod
+ def _should_block(response: Dict[str, Any]) -> bool:
+ """Return True if any policyAction requires blocking."""
+ for action in response.get("policyActions", []):
+ odata_type = action.get("@odata.type", "")
+ action_field = action.get("action", "")
+
+ if "restrictAccessAction" in odata_type or action_field == "restrictAccess":
+ restriction = action.get("restrictionAction", "")
+ if restriction == "block":
+ return True
+ return False
+
+ # ------------------------------------------------------------------
+ # Prompt text for DLP
+ # ------------------------------------------------------------------
+
+ @staticmethod
+ def completion_prompt_to_str(prompt: Any) -> Optional[str]:
+ """Normalize OpenAI ``/v1/completions`` ``prompt`` for text DLP.
+
+ Supports string prompts and list-of-string prompts. List-of-token-id prompts
+ are skipped (no plaintext for Purview to evaluate).
+ """
+ if prompt is None:
+ return None
+ if isinstance(prompt, str):
+ stripped = prompt.strip()
+ return stripped or None
+ if isinstance(prompt, list) and prompt:
+ if all(isinstance(x, str) for x in prompt):
+ joined = "\n".join(s.strip() for s in prompt if isinstance(s, str))
+ return joined.strip() or None
+ if all(isinstance(x, int) for x in prompt):
+ verbose_proxy_logger.debug(
+ "Purview DLP: completions prompt is token ids only; skipping text scan"
+ )
+ return None
+ str_parts = [x for x in prompt if isinstance(x, str)]
+ if str_parts:
+ joined = "\n".join(s.strip() for s in str_parts)
+ return joined.strip() or None
+ return None
+
+ @staticmethod
+ def _extract_tool_call_args_from_message(message: Any) -> List[str]:
+ """Return plaintext arguments strings from tool_calls and function_call fields.
+
+ Covers both the request path (assistant messages in chat histories that
+ carry tool_calls / function_call) and the response path (model-generated
+ tool calls returned in a ModelResponse). Both dict-style and object-style
+ representations are handled.
+ """
+ args: List[str] = []
+
+ # tool_calls: [{"function": {"arguments": "..."}}]
+ tool_calls = (
+ message.get("tool_calls")
+ if isinstance(message, dict)
+ else getattr(message, "tool_calls", None)
+ )
+ if tool_calls:
+ for tc in tool_calls:
+ fn = (
+ tc.get("function")
+ if isinstance(tc, dict)
+ else getattr(tc, "function", None)
+ )
+ if fn is None:
+ continue
+ arguments = (
+ fn.get("arguments")
+ if isinstance(fn, dict)
+ else getattr(fn, "arguments", None)
+ )
+ if isinstance(arguments, str) and arguments.strip():
+ args.append(arguments)
+
+ # Legacy function_call: {"arguments": "..."}
+ function_call = (
+ message.get("function_call")
+ if isinstance(message, dict)
+ else getattr(message, "function_call", None)
+ )
+ if function_call is not None:
+ arguments = (
+ function_call.get("arguments")
+ if isinstance(function_call, dict)
+ else getattr(function_call, "arguments", None)
+ )
+ if isinstance(arguments, str) and arguments.strip():
+ args.append(arguments)
+
+ return args
+
+ def get_prompt_text_for_dlp(
+ self, messages: List["AllMessageValues"]
+ ) -> Optional[str]:
+ """Concatenate text from every chat message (all roles) for pre-call DLP.
+
+ Evaluates the same payload the model receives, not only the trailing user
+ turn. Each message is separated by ``\\n\\n`` so that tokens at message
+ boundaries are not merged (e.g., ``"end of msg1\\n\\nstart of msg2"``
+ rather than ``"end of msg1start of msg2"``), which preserves DLP pattern
+ detection accuracy across message boundaries.
+
+ Tool-call arguments (``tool_calls[].function.arguments`` and
+ ``function_call.arguments``) are included alongside message content so
+ that sensitive data hidden in function arguments is not bypassed.
+ """
+ if not messages:
+ return None
+ parts: List[str] = []
+ for msg in messages:
+ segments: List[str] = []
+ content = convert_content_list_to_str(message=msg).strip()
+ if content:
+ segments.append(content)
+ segments.extend(self._extract_tool_call_args_from_message(msg))
+ combined = "\n".join(segments)
+ if combined.strip():
+ parts.append(combined.strip())
+ text = "\n\n".join(parts)
+ return text or None
diff --git a/litellm/proxy/guardrails/guardrail_hooks/microsoft_purview/purview_dlp.py b/litellm/proxy/guardrails/guardrail_hooks/microsoft_purview/purview_dlp.py
new file mode 100644
--- /dev/null
+++ b/litellm/proxy/guardrails/guardrail_hooks/microsoft_purview/purview_dlp.py
@@ -1,0 +1,658 @@
+"""
+Microsoft Purview DLP Guardrail for LiteLLM.
+
+Supports three modes:
+- pre_call: Block sensitive data in prompts before they reach the LLM.
+- post_call: Block sensitive data in LLM responses.
+- logging_only: Log interactions to Purview for audit/compliance without blocking.
+"""
+
+import asyncio
+import threading
+import uuid
+from datetime import datetime
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ AsyncGenerator,
+ Dict,
+ List,
+ Optional,
+ Tuple,
+ Type,
+ Union,
+ cast,
+)
+
+from fastapi import HTTPException
+
+from litellm._logging import verbose_proxy_logger
+from litellm.integrations.custom_guardrail import (
+ CustomGuardrail,
+ log_guardrail_information,
+)
+from litellm.types.guardrails import GuardrailEventHooks
+from litellm.types.utils import (
+ Choices,
+ GuardrailStatus,
+ ModelResponse,
+ ModelResponseStream,
+ ResponsesAPIResponse,
+ TextChoices,
+ TextCompletionResponse,
+)
+
+from .base import PurviewGuardrailBase
+
+if TYPE_CHECKING:
+ from litellm.proxy._types import UserAPIKeyAuth
+ from litellm.types.proxy.guardrails.guardrail_hooks.base import (
+ GuardrailConfigModel,
+ )
+ from litellm.types.utils import (
+ CallTypesLiteral,
+ EmbeddingResponse,
+ ImageResponse,
+ )
+
+
+class MicrosoftPurviewDLPGuardrail(PurviewGuardrailBase, CustomGuardrail):
+ """
+ Microsoft Purview DLP guardrail.
+
+ Evaluates prompts and responses against Microsoft Purview DLP policies
+ via the Microsoft Graph ``processContent`` API.
+ """
+
+ def __init__(
+ self,
+ guardrail_name: str,
+ tenant_id: str,
+ client_id: str,
+ client_secret: str,
+ purview_app_name: str = "LiteLLM",
+ user_id_field: str = "user_id",
+ **kwargs: Any,
+ ):
+ supported_event_hooks = [
+ GuardrailEventHooks.pre_call,
+ GuardrailEventHooks.post_call,
+ GuardrailEventHooks.logging_only,
+ ]
+
+ super().__init__(
+ tenant_id=tenant_id,
+ client_id=client_id,
+ client_secret=client_secret,
+ purview_app_name=purview_app_name,
+ user_id_field=user_id_field,
+ guardrail_name=guardrail_name,
+ supported_event_hooks=supported_event_hooks,
+ **kwargs,
+ )
+ self.guardrail_provider = "microsoft_purview"
+ verbose_proxy_logger.info(
+ "Initialized Microsoft Purview DLP Guardrail: %s",
+ guardrail_name,
+ )
+
+ @staticmethod
+ def get_config_model() -> Optional[Type["GuardrailConfigModel"]]:
+ return None # Config model can be added later for UI support
+
+ # ------------------------------------------------------------------
+ # Core DLP check
+ # ------------------------------------------------------------------
+
+ async def _check_content(
+ self,
+ user_id: str,
+ text: str,
+ activity: str,
+ request_data: Dict[str, Any],
+ block_on_violation: bool = True,
+ ) -> Dict[str, Any]:
+ """Evaluate content against Purview DLP policies.
+
+ Args:
+ user_id: Entra object ID.
+ text: Content to evaluate.
+ activity: ``"uploadText"`` or ``"downloadText"``.
+ request_data: Original request dict (used for logging metadata).
+ block_on_violation: If False, log only — do not raise.
+
+ Returns:
+ The processContent response dict.
+ """
+ start_time = datetime.now()
+ status: GuardrailStatus = "success"
+ response: Dict[str, Any] = {}
+
+ try:
+ etag, _ = await self._compute_protection_scopes(user_id)
+ correlation_id = request_data.get("litellm_call_id") or str(uuid.uuid4())
+ response = await self._process_content(
+ user_id=user_id,
+ text=text,
+ activity=activity,
+ etag=etag,
+ correlation_id=correlation_id,
+ )
+
+ if self._should_block(response):
+ status = "guardrail_intervened"
+ except Exception as exc:
+ status = "guardrail_failed_to_respond"
+ if block_on_violation:
+ raise
+ verbose_proxy_logger.warning(
+ "Purview DLP: API/network error in logging-only mode (not re-raised): %s",
+ exc,
+ )
+ finally:
+ end_time = datetime.now()
+ self.add_standard_logging_guardrail_information_to_request_data(
+ guardrail_provider=self.guardrail_provider,
+ guardrail_json_response=response,
+ request_data=request_data,
+ guardrail_status=status,
+ start_time=start_time.timestamp(),
+ end_time=end_time.timestamp(),
+ duration=(end_time - start_time).total_seconds(),
... diff truncated: showing 800 of 3755 lines
Previously, _assemble_responses_api_from_chunks returned None both when the stream was not a Responses API stream and when it was a Responses API stream but no final ResponsesAPIResponse-bearing event was received. The caller treated both cases identically and fell through to stream_chunk_builder, which does not understand Responses API events. Return a (is_responses_api_stream, assembled) tuple so the caller can fail closed with an accurate error when Responses API events were seen but no final response event arrived, instead of misrouting events to the chat chunk builder. Co-authored-by: Yassin Kortam <yassin@berri.ai>
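The tuple-return change can be sketched as follows. This is a simplified model, not the PR's implementation: stream events are represented as plain dicts, detection by a `response.` type prefix is an assumption, and `assemble_responses_api` and `route_stream` are hypothetical names.

```python
from typing import Any, Dict, List, Optional, Tuple

# Events whose payload carries the full assembled response.
TERMINAL_EVENTS = ("response.completed", "response.failed", "response.incomplete")


def assemble_responses_api(
    chunks: List[Dict[str, Any]],
) -> Tuple[bool, Optional[Dict[str, Any]]]:
    """Return (is_responses_api_stream, assembled_response_or_None)."""
    is_stream = False
    assembled: Optional[Dict[str, Any]] = None
    for chunk in chunks:
        event_type = chunk.get("type")
        if isinstance(event_type, str) and event_type.startswith("response."):
            is_stream = True
            if event_type in TERMINAL_EVENTS and chunk.get("response") is not None:
                assembled = chunk["response"]  # keep the latest terminal event
    return is_stream, assembled


def route_stream(chunks: List[Dict[str, Any]]) -> str:
    """Decide which path scans the stream, failing closed on incomplete streams."""
    is_stream, assembled = assemble_responses_api(chunks)
    if is_stream and assembled is None:
        # Responses API events were seen but no final response arrived.
        raise ValueError("Incomplete Responses API stream")
    return "responses_api" if is_stream else "chat_chunk_builder"
```

Returning the flag separately from the assembled value is what lets the caller tell "not a Responses API stream" apart from "Responses API stream with no final event", two cases that a single `None` return collapsed together.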
…locking mode
Previously a bare `raise` in `_check_content` re-propagated raw network / HTTP errors (e.g. httpx.HTTPStatusError, ConnectionError) to the client, which would surface as a 500. Now blocking-mode failures from the Graph `processContent` call (and OAuth token / protection-scopes calls) are converted to HTTPException(400) with a structured detail payload, while HTTPException instances raised by upstream layers continue to propagate unchanged. Logging-only mode is unaffected.
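A minimal sketch of this error-mapping pattern is shown below. To stay dependency-free it uses a stand-in `HTTPError` class in place of `fastapi.HTTPException`; the helper name `check_blocking` and the keys of the detail payload are hypothetical and are not taken from the PR.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict


class HTTPError(Exception):
    """Stand-in for fastapi.HTTPException so the sketch needs only the stdlib."""

    def __init__(self, status_code: int, detail: Dict[str, Any]) -> None:
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail


async def check_blocking(
    call: Callable[[], Awaitable[Dict[str, Any]]],
) -> Dict[str, Any]:
    """Run a DLP call in blocking mode, mapping raw failures to a 400."""
    try:
        return await call()
    except HTTPError:
        raise  # errors already shaped by upstream layers pass through unchanged
    except Exception as exc:
        # Hypothetical payload shape; the real detail structure may differ.
        detail = {"error": "Purview DLP check failed", "reason": type(exc).__name__}
        raise HTTPError(400, detail) from exc
```

The ordering of the `except` clauses matters here: catching the HTTP error type first is what keeps an upstream status code from being rewritten to 400.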
bugbot run
Cursor Bugbot has reviewed your changes using high effort and found 2 potential issues.
Bugbot Autofix resolved 1 of the 2 issues found in the latest run.
- ✅ Fixed: ResponsesAPIResponse tool call arguments bypass DLP scanning
  - Added a helper that extracts `arguments` from `function_call` items in `ResponsesAPIResponse.output` and included them in `_completion_response_text_parts` so tool-call args are DLP-scanned alongside `output_text`.
Preview (734a0cb5cf)
diff --git a/.circleci/config.yml b/.circleci/config.yml
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -2541,7 +2541,6 @@
paths:
- litellm-docker-database.tar.zst
-
test_bad_database_url:
machine:
image: ubuntu-2204:2024.04.1
diff --git a/litellm/litellm_core_utils/prompt_templates/common_utils.py b/litellm/litellm_core_utils/prompt_templates/common_utils.py
--- a/litellm/litellm_core_utils/prompt_templates/common_utils.py
+++ b/litellm/litellm_core_utils/prompt_templates/common_utils.py
@@ -1204,12 +1204,8 @@
{"role": "assistant", "content": "I'm good, thank you!"},
{"role": "user", "content": "What is the weather in Tokyo?"},
]
- get_user_prompt(messages) -> "What is the weather in Tokyo?"
+ get_last_user_message(messages) -> "What is the weather in Tokyo?"
"""
- from litellm.litellm_core_utils.prompt_templates.common_utils import (
- convert_content_list_to_str,
- )
-
if not messages:
return None
diff --git a/litellm/proxy/guardrails/guardrail_hooks/azure/base.py b/litellm/proxy/guardrails/guardrail_hooks/azure/base.py
--- a/litellm/proxy/guardrails/guardrail_hooks/azure/base.py
+++ b/litellm/proxy/guardrails/guardrail_hooks/azure/base.py
@@ -2,6 +2,9 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from litellm._logging import verbose_proxy_logger
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+ get_last_user_message,
+)
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
@@ -134,32 +137,4 @@
]
get_user_prompt(messages) -> "What is the weather in Tokyo?"
"""
- from litellm.litellm_core_utils.prompt_templates.common_utils import (
- convert_content_list_to_str,
- )
-
- if not messages:
- return None
-
- # Iterate from the end to find the last consecutive block of user messages
- user_messages = []
- for message in reversed(messages):
- if message.get("role") == "user":
- user_messages.append(message)
- else:
- # Stop when we hit a non-user message
- break
-
- if not user_messages:
- return None
-
- # Reverse to get the messages in chronological order
- user_messages.reverse()
-
- user_prompt = ""
- for message in user_messages:
- text_content = convert_content_list_to_str(message)
- user_prompt += text_content + "\n"
-
- result = user_prompt.strip()
- return result if result else None
+ return get_last_user_message(messages)
diff --git a/litellm/proxy/guardrails/guardrail_hooks/microsoft_purview/__init__.py b/litellm/proxy/guardrails/guardrail_hooks/microsoft_purview/__init__.py
new file mode 100644
--- /dev/null
+++ b/litellm/proxy/guardrails/guardrail_hooks/microsoft_purview/__init__.py
@@ -1,0 +1,57 @@
+from typing import TYPE_CHECKING
+
+from litellm.types.guardrails import SupportedGuardrailIntegrations
+
+from .purview_dlp import MicrosoftPurviewDLPGuardrail
+
+if TYPE_CHECKING:
+    from litellm.types.guardrails import Guardrail, LitellmParams
+
+
+def initialize_guardrail(litellm_params: "LitellmParams", guardrail: "Guardrail"):
+    import litellm
+
+    tenant_id = getattr(litellm_params, "tenant_id", None)
+    client_id = getattr(litellm_params, "client_id", None)
+
+    # client_secret can be passed via the standard api_key field or as
+    # a dedicated client_secret parameter.
+    client_secret = litellm_params.api_key or getattr(
+        litellm_params, "client_secret", None
+    )
+
+    if not tenant_id:
+        raise ValueError("Microsoft Purview: tenant_id is required")
+    if not client_id:
+        raise ValueError("Microsoft Purview: client_id is required")
+    if not client_secret:
+        raise ValueError("Microsoft Purview: client_secret (or api_key) is required")
+
+    guardrail_name = guardrail.get("guardrail_name")
+    if not guardrail_name:
+        raise ValueError("Microsoft Purview: guardrail_name is required")
+
+    purview_guardrail = MicrosoftPurviewDLPGuardrail(
+        guardrail_name=guardrail_name,
+        tenant_id=str(tenant_id),
+        client_id=str(client_id),
+        client_secret=str(client_secret),
+        purview_app_name=str(
+            getattr(litellm_params, "purview_app_name", None) or "LiteLLM"
+        ),
+        user_id_field=str(getattr(litellm_params, "user_id_field", None) or "user_id"),
+        event_hook=litellm_params.mode,
+        default_on=litellm_params.default_on,
+    )
+
+    litellm.logging_callback_manager.add_litellm_callback(purview_guardrail)
+    return purview_guardrail
+
+
+guardrail_initializer_registry = {
+    SupportedGuardrailIntegrations.MICROSOFT_PURVIEW.value: initialize_guardrail,
+}
+
+guardrail_class_registry = {
+    SupportedGuardrailIntegrations.MICROSOFT_PURVIEW.value: MicrosoftPurviewDLPGuardrail,
+}
diff --git a/litellm/proxy/guardrails/guardrail_hooks/microsoft_purview/base.py b/litellm/proxy/guardrails/guardrail_hooks/microsoft_purview/base.py
new file mode 100644
--- /dev/null
+++ b/litellm/proxy/guardrails/guardrail_hooks/microsoft_purview/base.py
@@ -1,0 +1,487 @@
+import threading
+import time
+import uuid
+from collections import OrderedDict
+from types import SimpleNamespace
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+
+from litellm._logging import verbose_proxy_logger
+from litellm.litellm_core_utils.url_utils import encode_url_path_segment
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    convert_content_list_to_str,
+)
+from litellm.llms.custom_httpx.http_handler import (
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
+
+if TYPE_CHECKING:
+    from litellm.types.llms.openai import AllMessageValues
+
+GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"
+TOKEN_ENDPOINT_TEMPLATE = (
+    "https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/token"
+)
+GRAPH_SCOPE = "https://graph.microsoft.com/.default"
+
+# Protection scope cache TTL in seconds (1 hour, per Microsoft recommendation).
+SCOPE_CACHE_TTL_SECONDS = 3600.0
+
+
+class PurviewGuardrailBase:
+    """
+    Base class for Microsoft Purview guardrails.
+
+    Manages OAuth2 client-credentials token acquisition, protection scope
+    computation with ETag caching, and authenticated POST calls to the
+    Microsoft Graph API.
+    """
+
+    def __init__(
+        self,
+        tenant_id: str,
+        client_id: str,
+        client_secret: str,
+        purview_app_name: str = "LiteLLM",
+        user_id_field: str = "user_id",
+        **kwargs: Any,
+    ):
+        # Forward remaining kwargs to the next class in the MRO
+        # (typically CustomGuardrail).
+        super().__init__(**kwargs)
+
+        self.async_handler = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.GuardrailCallback
+        )
+        self.tenant_id = tenant_id
+        self.client_id = client_id
+        self.client_secret = client_secret
+        self.purview_app_name = purview_app_name
+        self.user_id_field = user_id_field
+
+        # Token cache: (access_token, expires_at_epoch)
+        self._token_cache: Optional[Tuple[str, float]] = None
+
+        # Protection scope cache: user_id -> (etag, scope_response, fetched_at)
+        # Capped at 1000 entries (LRU eviction) to avoid unbounded growth.
+        self._scope_cache: OrderedDict[str, Tuple[str, Dict[str, Any], float]] = (
+            OrderedDict()
+        )
+        self._scope_cache_maxsize = 1000
+        # Use a threading.Lock (not asyncio.Lock) because this lock is acquired
+        # from both the proxy's main asyncio event loop and from short-lived
+        # event loops created by the logging_hook thread fallback. In Python
+        # 3.10+ an asyncio.Lock is bound to the first event loop that acquires
+        # it and raises RuntimeError from any other loop, which would silently
+        # break audit logging via the thread fallback. All critical sections
+        # below are pure in-memory dict ops with no awaits, so a synchronous
+        # lock is both correct and sufficient.
+        self._cache_lock = threading.Lock()
+
+    @staticmethod
+    def _encode_graph_user_id(user_id: str) -> str:
+        """Percent-encode Entra user id for Graph ``/users/{id}/...`` path segments."""
+        return encode_url_path_segment(user_id, field_name="user_id")
+
+    # ------------------------------------------------------------------
+    # OAuth2 token management
+    # ------------------------------------------------------------------
+
+    async def _get_access_token(self) -> str:
+        """Acquire or return cached OAuth2 token via client_credentials grant."""
+        now = time.time()
+        with self._cache_lock:
+            if self._token_cache and self._token_cache[1] > now + 60:
+                return self._token_cache[0]
+
+        url = TOKEN_ENDPOINT_TEMPLATE.format(tenant_id=self.tenant_id)
+        data = {
+            "grant_type": "client_credentials",
+            "client_id": self.client_id,
+            "client_secret": self.client_secret,
+            "scope": GRAPH_SCOPE,
+        }
+        response = await self.async_handler.post(
+            url=url,
+            data=data,
+            headers={"Content-Type": "application/x-www-form-urlencoded"},
+        )
+        response.raise_for_status()
+        token_data = response.json()
+        access_token = token_data["access_token"]
+        expires_in = int(token_data.get("expires_in", 3599))
+        # Recompute ``now`` after the await so the expiry reflects when the
+        # token was actually received, not when the request started.
+        with self._cache_lock:
+            self._token_cache = (access_token, time.time() + expires_in)
+        verbose_proxy_logger.debug(
+            "Purview: acquired new OAuth2 token (expires_in=%ds)", expires_in
+        )
+        return access_token
+
+    # ------------------------------------------------------------------
+    # Graph API helpers
+    # ------------------------------------------------------------------
+
+    async def _graph_post(
+        self,
+        url: str,
+        json_body: Dict[str, Any],
+        extra_headers: Optional[Dict[str, str]] = None,
+    ) -> Tuple[Dict[str, Any], Dict[str, str]]:
+        """POST to Graph API with bearer auth.
+
+        Returns:
+            Tuple of (response_json, response_headers).
+        """
+        token = await self._get_access_token()
+        headers = {
+            "Authorization": f"Bearer {token}",
+            "Content-Type": "application/json",
+        }
+        if extra_headers:
+            headers.update(extra_headers)
+
+        verbose_proxy_logger.debug("Purview Graph POST %s", url)
+        response = await self.async_handler.post(
+            url=url, headers=headers, json=json_body
+        )
+        response.raise_for_status()
+        response_json: Dict[str, Any] = response.json()
+        response_headers = dict(response.headers)
+        verbose_proxy_logger.debug("Purview Graph response: %s", response_json)
+        return response_json, response_headers
+
+    # ------------------------------------------------------------------
+    # Protection scopes
+    # ------------------------------------------------------------------
+
+    async def _compute_protection_scopes(
+        self, user_id: str
+    ) -> Tuple[str, Dict[str, Any]]:
+        """Call protectionScopes/compute and cache with ETag.
+
+        Returns:
+            Tuple of (etag, scope_response).
+        """
+        encoded_user_id = self._encode_graph_user_id(user_id)
+        now = time.time()
+
+        with self._cache_lock:
+            cached = self._scope_cache.get(user_id)
+            if cached and (now - cached[2]) < SCOPE_CACHE_TTL_SECONDS:
+                self._scope_cache.move_to_end(user_id)
+                return cached[0], cached[1]
+
+        url = (
+            f"{GRAPH_API_BASE}/users/{encoded_user_id}"
+            "/dataSecurityAndGovernance/protectionScopes/compute"
+        )
+        body: Dict[str, Any] = {
+            "activities": "uploadText,downloadText",
+            "locations": [
+                {
+                    "@odata.type": "microsoft.graph.policyLocationApplication",
+                    "value": self.client_id,
+                }
+            ],
+        }
+
+        response_json, response_headers = await self._graph_post(url, body)
+        etag = response_headers.get("etag", response_headers.get("ETag", ""))
+
+        # Recompute ``now`` after the await so the TTL reflects when the
+        # scope response was actually received, not when the request started.
+        fetched_at = time.time()
+        with self._cache_lock:
+            self._scope_cache[user_id] = (etag, response_json, fetched_at)
+            # Move refreshed entry to the end so it is treated as most-recently-used.
+            # OrderedDict.__setitem__ preserves existing insertion order for known
+            # keys, so an explicit move_to_end() call is required.
+            self._scope_cache.move_to_end(user_id)
+            # Evict least-recently-used entry when cache exceeds max size.
+            while len(self._scope_cache) > self._scope_cache_maxsize:
+                self._scope_cache.popitem(last=False)
+        return etag, response_json
+
+    # ------------------------------------------------------------------
+    # Process content
+    # ------------------------------------------------------------------
+
+    async def _process_content(
+        self,
+        user_id: str,
+        text: str,
+        activity: str,
+        etag: str,
+        correlation_id: Optional[str] = None,
+    ) -> Dict[str, Any]:
+        """Call processContent for DLP policy evaluation.
+
+        Args:
+            user_id: Entra object ID of the user.
+            text: The content to evaluate.
+            activity: ``"uploadText"`` for prompts, ``"downloadText"`` for responses.
+            etag: Cached ETag from protectionScopes/compute.
+            correlation_id: Optional conversation/thread ID.
+        """
+        encoded_user_id = self._encode_graph_user_id(user_id)
+        url = (
+            f"{GRAPH_API_BASE}/users/{encoded_user_id}"
+            "/dataSecurityAndGovernance/processContent"
+        )
+        body: Dict[str, Any] = {
+            "contentToProcess": {
+                "contentEntries": [
+                    {
+                        "@odata.type": "microsoft.graph.processConversationMetadata",
+                        "identifier": str(uuid.uuid4()),
+                        "content": {
+                            "@odata.type": "microsoft.graph.textContent",
+                            "data": text,
+                        },
+                        "name": f"{self.purview_app_name} message",
+                        "correlationId": correlation_id or str(uuid.uuid4()),
+                        "sequenceNumber": 0,
+                        "isTruncated": False,
+                    }
+                ],
+                "activityMetadata": {"activity": activity},
+                "deviceMetadata": {},
+                "protectedAppMetadata": {
+                    "name": self.purview_app_name,
+                    "version": "1.0",
+                    "applicationLocation": {
+                        "@odata.type": "microsoft.graph.policyLocationApplication",
+                        "value": self.client_id,
+                    },
+                },
+                "integratedAppMetadata": {
+                    "name": self.purview_app_name,
+                    "version": "1.0",
+                },
+            }
+        }
+
+        extra_headers: Dict[str, str] = {}
+        if etag:
+            extra_headers["If-None-Match"] = etag
+
+        response_json, _ = await self._graph_post(url, body, extra_headers)
+
+        # If policies changed, invalidate scope cache so next call re-fetches.
+        if response_json.get("protectionScopeState") == "modified":
+            with self._cache_lock:
+                self._scope_cache.pop(user_id, None)
+
+        return response_json
+
+    # ------------------------------------------------------------------
+    # User ID resolution
+    # ------------------------------------------------------------------
+
+    def _resolve_user_id(
+        self, data: Dict[str, Any], user_api_key_dict: Any
+    ) -> Optional[str]:
+        """Resolve the Entra user object ID from request data or auth context.
+
+        Trust order (strongest first). Blocking DLP uses only
+        ``_resolve_trusted_user_id`` (API-key-bound ``user_id``). This resolver
+        also supports audit/logging fallbacks:
+
+        1. ``user_api_key_dict.user_id`` — LiteLLM key / JWT-bound user
+        2. ``user_api_key_dict.end_user_id`` — request-derived; audit only
+        3. ``metadata["user_api_key_user_id"]`` — proxy-injected from the key
+        4. ``metadata[user_id_field]`` — caller-supplied; audit only
+        """
+        trusted = self._resolve_trusted_user_id(data, user_api_key_dict)
+        if trusted:
+            return trusted
+
+        if hasattr(user_api_key_dict, "end_user_id") and user_api_key_dict.end_user_id:
+            return str(user_api_key_dict.end_user_id)
+
+        metadata = data.get("metadata") or data.get("litellm_metadata") or {}
+        uid = metadata.get("user_api_key_user_id")
+        if uid:
+            return str(uid)
+
+        uid = metadata.get(self.user_id_field)
+        if uid:
+            return str(uid)
+
+        return None
+
+    @staticmethod
+    def _logging_kwargs_metadata(kwargs: Dict[str, Any]) -> Dict[str, Any]:
+        """Metadata dict from ``model_call_details`` / logging kwargs."""
+        litellm_params = kwargs.get("litellm_params") or {}
+        if not isinstance(litellm_params, dict):
+            return {}
+        md = litellm_params.get("metadata")
+        return md if isinstance(md, dict) else {}
+
+    def _resolve_trusted_user_id(
+        self, data: Dict[str, Any], user_api_key_dict: Any
+    ) -> Optional[str]:
+        """Resolve user ID from API-key/JWT-bound identity for blocking DLP.
+
+        Uses only ``UserAPIKeyAuth.user_id`` (bound on the LiteLLM key or JWT).
+        Intentionally omits ``UserAPIKeyAuth.end_user_id`` because the proxy sets
+        it from caller-controlled request fields (``user``, ``metadata.user_id``,
+        ``safety_identifier``, custom headers, etc.) via
+        ``get_end_user_id_from_request_body``.
+
+        Also omits ``metadata[user_id_field]`` and
+        ``metadata["user_api_key_user_id"]`` for the same impersonation risk when
+        the key has no bound user.
+
+        Returns ``None`` when no authenticated identity is available. Blocking
+        hooks must fail closed rather than skip the DLP check.
+        """
+        if hasattr(user_api_key_dict, "user_id") and user_api_key_dict.user_id:
+            return str(user_api_key_dict.user_id)
+
+        return None
+
+    def _resolve_user_id_from_logging_kwargs(
+        self, kwargs: Dict[str, Any]
+    ) -> Optional[str]:
+        """Same trust order as ``_resolve_user_id`` for logging-only hooks (no ``UserAPIKeyAuth``)."""
+        md = self._logging_kwargs_metadata(kwargs)
+        shim = SimpleNamespace(
+            user_id=md.get("user_api_key_user_id")
+            or kwargs.get("user_api_key_user_id"),
+            end_user_id=md.get("user_api_key_end_user_id")
+            or kwargs.get("user_api_key_end_user_id"),
+        )
+        return self._resolve_user_id({"metadata": md}, shim)
+
+    # ------------------------------------------------------------------
+    # Policy action evaluation
+    # ------------------------------------------------------------------
+
+    @staticmethod
+    def _should_block(response: Dict[str, Any]) -> bool:
+        """Return True if any policyAction requires blocking."""
+        for action in response.get("policyActions", []):
+            odata_type = action.get("@odata.type", "")
+            action_field = action.get("action", "")
+
+            if "restrictAccessAction" in odata_type or action_field == "restrictAccess":
+                restriction = action.get("restrictionAction", "")
+                if restriction == "block":
+                    return True
+        return False
+
+    # ------------------------------------------------------------------
+    # Prompt text for DLP
+    # ------------------------------------------------------------------
+
+    @staticmethod
+    def completion_prompt_to_str(prompt: Any) -> Optional[str]:
+        """Normalize OpenAI ``/v1/completions`` ``prompt`` for text DLP.
+
+        Supports string prompts and list-of-string prompts. List-of-token-id prompts
+        are skipped (no plaintext for Purview to evaluate).
+        """
+        if prompt is None:
+            return None
+        if isinstance(prompt, str):
+            stripped = prompt.strip()
+            return stripped or None
+        if isinstance(prompt, list) and prompt:
+            if all(isinstance(x, str) for x in prompt):
+                joined = "\n".join(s.strip() for s in prompt if isinstance(s, str))
+                return joined.strip() or None
+            if all(isinstance(x, int) for x in prompt):
+                verbose_proxy_logger.debug(
+                    "Purview DLP: completions prompt is token ids only; skipping text scan"
+                )
+                return None
+            str_parts = [x for x in prompt if isinstance(x, str)]
+            if str_parts:
+                joined = "\n".join(s.strip() for s in str_parts)
+                return joined.strip() or None
+        return None
+
+    @staticmethod
+    def _extract_tool_call_args_from_message(message: Any) -> List[str]:
+        """Return plaintext arguments strings from tool_calls and function_call fields.
+
+        Covers both the request path (assistant messages in chat histories that
+        carry tool_calls / function_call) and the response path (model-generated
+        tool calls returned in a ModelResponse). Both dict-style and object-style
+        representations are handled.
+        """
+        args: List[str] = []
+
+        # tool_calls: [{"function": {"arguments": "..."}}]
+        tool_calls = (
+            message.get("tool_calls")
+            if isinstance(message, dict)
+            else getattr(message, "tool_calls", None)
+        )
+        if tool_calls:
+            for tc in tool_calls:
+                fn = (
+                    tc.get("function")
+                    if isinstance(tc, dict)
+                    else getattr(tc, "function", None)
+                )
+                if fn is None:
+                    continue
+                arguments = (
+                    fn.get("arguments")
+                    if isinstance(fn, dict)
+                    else getattr(fn, "arguments", None)
+                )
+                if isinstance(arguments, str) and arguments.strip():
+                    args.append(arguments)
+
+        # Legacy function_call: {"arguments": "..."}
+        function_call = (
+            message.get("function_call")
+            if isinstance(message, dict)
+            else getattr(message, "function_call", None)
+        )
+        if function_call is not None:
+            arguments = (
+                function_call.get("arguments")
+                if isinstance(function_call, dict)
+                else getattr(function_call, "arguments", None)
+            )
+            if isinstance(arguments, str) and arguments.strip():
+                args.append(arguments)
+
+        return args
+
+    def get_prompt_text_for_dlp(
+        self, messages: List["AllMessageValues"]
+    ) -> Optional[str]:
+        """Concatenate text from every chat message (all roles) for pre-call DLP.
+
+        Evaluates the same payload the model receives, not only the trailing user
+        turn. Each message is separated by ``\\n\\n`` so that tokens at message
+        boundaries are not merged (e.g., ``"end of msg1\\n\\nstart of msg2"``
+        rather than ``"end of msg1start of msg2"``), which preserves DLP pattern
+        detection accuracy across message boundaries.
+
+        Tool-call arguments (``tool_calls[].function.arguments`` and
+        ``function_call.arguments``) are included alongside message content so
+        that sensitive data hidden in function arguments is not bypassed.
+        """
+        if not messages:
+            return None
+        parts: List[str] = []
+        for msg in messages:
+            segments: List[str] = []
+            content = convert_content_list_to_str(message=msg).strip()
+            if content:
+                segments.append(content)
+            segments.extend(self._extract_tool_call_args_from_message(msg))
+            combined = "\n".join(segments)
+            if combined.strip():
+                parts.append(combined.strip())
+        text = "\n\n".join(parts)
+        return text or None
diff --git a/litellm/proxy/guardrails/guardrail_hooks/microsoft_purview/purview_dlp.py b/litellm/proxy/guardrails/guardrail_hooks/microsoft_purview/purview_dlp.py
new file mode 100644
--- /dev/null
+++ b/litellm/proxy/guardrails/guardrail_hooks/microsoft_purview/purview_dlp.py
@@ -1,0 +1,697 @@
+"""
+Microsoft Purview DLP Guardrail for LiteLLM.
+
+Supports three modes:
+- pre_call: Block sensitive data in prompts before they reach the LLM.
+- post_call: Block sensitive data in LLM responses.
+- logging_only: Log interactions to Purview for audit/compliance without blocking.
+"""
+
+import asyncio
+import threading
+import uuid
+from datetime import datetime
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncGenerator,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+    cast,
+)
+
+from fastapi import HTTPException
+
+from litellm._logging import verbose_proxy_logger
+from litellm.integrations.custom_guardrail import (
+    CustomGuardrail,
+    log_guardrail_information,
+)
+from litellm.types.guardrails import GuardrailEventHooks
+from litellm.types.utils import (
+    Choices,
+    GuardrailStatus,
+    ModelResponse,
+    ModelResponseStream,
+    ResponsesAPIResponse,
+    TextChoices,
+    TextCompletionResponse,
+)
+
+from .base import PurviewGuardrailBase
+
+if TYPE_CHECKING:
+    from litellm.proxy._types import UserAPIKeyAuth
+    from litellm.types.proxy.guardrails.guardrail_hooks.base import (
+        GuardrailConfigModel,
+    )
+    from litellm.types.utils import (
+        CallTypesLiteral,
+        EmbeddingResponse,
+        ImageResponse,
+    )
+
+
+class MicrosoftPurviewDLPGuardrail(PurviewGuardrailBase, CustomGuardrail):
+    """
+    Microsoft Purview DLP guardrail.
+
+    Evaluates prompts and responses against Microsoft Purview DLP policies
+    via the Microsoft Graph ``processContent`` API.
+    """
+
+    def __init__(
+        self,
+        guardrail_name: str,
+        tenant_id: str,
+        client_id: str,
+        client_secret: str,
+        purview_app_name: str = "LiteLLM",
+        user_id_field: str = "user_id",
+        **kwargs: Any,
+    ):
+        supported_event_hooks = [
+            GuardrailEventHooks.pre_call,
+            GuardrailEventHooks.post_call,
+            GuardrailEventHooks.logging_only,
+        ]
+
+        super().__init__(
+            tenant_id=tenant_id,
+            client_id=client_id,
+            client_secret=client_secret,
+            purview_app_name=purview_app_name,
+            user_id_field=user_id_field,
+            guardrail_name=guardrail_name,
+            supported_event_hooks=supported_event_hooks,
+            **kwargs,
+        )
+        self.guardrail_provider = "microsoft_purview"
+        verbose_proxy_logger.info(
+            "Initialized Microsoft Purview DLP Guardrail: %s",
+            guardrail_name,
+        )
+
+    @staticmethod
+    def get_config_model() -> Optional[Type["GuardrailConfigModel"]]:
+        return None  # Config model can be added later for UI support
+
+    # ------------------------------------------------------------------
+    # Core DLP check
+    # ------------------------------------------------------------------
+
+    async def _check_content(
+        self,
+        user_id: str,
+        text: str,
+        activity: str,
+        request_data: Dict[str, Any],
+        block_on_violation: bool = True,
+    ) -> Dict[str, Any]:
+        """Evaluate content against Purview DLP policies.
+
+        Args:
+            user_id: Entra object ID.
+            text: Content to evaluate.
+            activity: ``"uploadText"`` or ``"downloadText"``.
+            request_data: Original request dict (used for logging metadata).
+            block_on_violation: If False, log only — do not raise.
+
+        Returns:
+            The processContent response dict.
+        """
+        start_time = datetime.now()
+        status: GuardrailStatus = "success"
+        response: Dict[str, Any] = {}
+
+        try:
+            etag, _ = await self._compute_protection_scopes(user_id)
+            correlation_id = request_data.get("litellm_call_id") or str(uuid.uuid4())
+            response = await self._process_content(
+                user_id=user_id,
+                text=text,
+                activity=activity,
+                etag=etag,
+                correlation_id=correlation_id,
+            )
+
+            if self._should_block(response):
+                status = "guardrail_intervened"
+        except HTTPException:
+            status = "guardrail_failed_to_respond"
+            raise
+        except Exception as exc:
+            status = "guardrail_failed_to_respond"
+            if block_on_violation:
+                raise HTTPException(
+                    status_code=400,
+                    detail={
+                        "error": "Microsoft Purview DLP: upstream policy evaluation failed",
+                        "activity": activity,
+                        "exception": str(exc),
+                    },
+                ) from exc
+            verbose_proxy_logger.warning(
+                "Purview DLP: API/network error in logging-only mode (not re-raised): %s",
+                exc,
+            )
... diff truncated: showing 800 of 3961 lines
Reviewed by Cursor Bugbot for commit b269c24.
`ResponsesAPIResponse.output_text` only aggregates `output_text` content blocks and ignores `function_call` items, so sensitive data in model-generated tool-call arguments would bypass the DLP scan. Mirror the `ModelResponse` path by extracting `function_call` arguments explicitly from the `output` list.

Co-authored-by: Yassin Kortam <yassin@berri.ai>
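As a rough illustration of the extraction this commit describes, the sketch below walks a Responses API `output` list and collects the `arguments` string of every `function_call` item. The helper names `_field` and `extract_function_call_arguments` are hypothetical and not taken from the PR; items are accepted either as dicts or as attribute-style objects.

```python
from typing import Any, List


def _field(obj: Any, name: str) -> Any:
    """Read a field from either a dict-style or an object-style item."""
    return obj.get(name) if isinstance(obj, dict) else getattr(obj, name, None)


def extract_function_call_arguments(output_items: List[Any]) -> List[str]:
    """Collect non-empty `arguments` strings from `function_call` output items."""
    found: List[str] = []
    for item in output_items or []:
        if _field(item, "type") != "function_call":
            continue  # message items are already covered by output_text
        arguments = _field(item, "arguments")
        if isinstance(arguments, str) and arguments.strip():
            found.append(arguments)
    return found
```

The returned strings would then be appended to the text parts that are sent for DLP evaluation, next to the aggregated `output_text`.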
…allback

If a Responses API stream slips past `_assemble_responses_api_from_chunks` (no chunks with type starting with 'response.') and `stream_chunk_builder` somehow returns a `ResponsesAPIResponse`, route it through `_completion_response_text_parts` instead of the 'not a ModelResponse' pass-through that would leak content unscanned.

bugbot run
…errors

`httpx.HTTPStatusError` from Graph API (429, 503, etc.) was always wrapped as `HTTPException(400)`, making rate-limits and infrastructure errors indistinguishable from a DLP policy block and stripping retry-after info. Now:

- 429 and 5xx pass through with their original status code; the upstream `Retry-After` header is forwarded.
- 401/403 (proxy-side credential/consent issue, not actionable by the client) map to 502 Bad Gateway.
- A debug log makes the `logging_hook` -> `async_logging_hook` deferral observable so audit failures don't silently disappear if the framework stops dispatching `async_logging_hook` for some code path.
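The status mapping in this commit message can be written down as a small pure function. This is a sketch under stated assumptions, not the PR's code: `map_graph_error` is a hypothetical name, and the fallback to 400 for any other upstream status is an assumption inferred from the earlier wrap-everything-as-400 behavior.

```python
from typing import Dict, Optional, Tuple


def map_graph_error(
    upstream_status: int, retry_after: Optional[str] = None
) -> Tuple[int, Dict[str, str]]:
    """Map a Graph API error status to (client_status, response_headers)."""
    headers: Dict[str, str] = {}
    if upstream_status == 429 or 500 <= upstream_status <= 599:
        # Rate limits and infrastructure errors keep their original status,
        # and the upstream Retry-After value is forwarded when present.
        if retry_after:
            headers["Retry-After"] = retry_after
        return upstream_status, headers
    if upstream_status in (401, 403):
        # Credential/consent problems are on the proxy side, so the client
        # sees a gateway error rather than an auth error.
        return 502, headers
    # Assumption: everything else still fails closed as a 400.
    return 400, headers
```

Keeping 429 and 5xx distinct from 400 lets clients retry with backoff instead of treating a transient Graph outage as a policy block.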
OpenAI /v1/completions accepts prompt: [[token, ids]] (multi-prompt token-id batches). The previous blocking-mode check only fired on a flat list[int], so nested or mixed token-id prompts skipped the Purview scan while the model still received the data. Extract the token-id detection into PurviewGuardrailBase.is_token_id_prompt and use it from the pre-call hook so every list shape Purview cannot decode fails closed.
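A possible shape for the token-id detection is sketched below. The PR names the helper `PurviewGuardrailBase.is_token_id_prompt`, but its exact semantics are not shown in the visible diff, so the rule used here (any integer token id, flat or nested one level, marks the prompt as undecodable) is an assumption.

```python
from typing import Any


def _is_token_id(value: Any) -> bool:
    """Integers count as token ids; bool is excluded even though it subclasses int."""
    return isinstance(value, int) and not isinstance(value, bool)


def is_token_id_prompt(prompt: Any) -> bool:
    """Return True when a /v1/completions prompt carries token ids.

    Covers a flat list of ints, nested multi-prompt batches such as
    [[1, 2], [3]], and lists that mix strings with token-id lists.
    """
    if not isinstance(prompt, list) or not prompt:
        return False
    for item in prompt:
        if _is_token_id(item):
            return True
        if isinstance(item, list) and any(_is_token_id(tok) for tok in item):
            return True
    return False
```

In blocking mode, a pre-call hook could reject any prompt for which this returns True, since there is no plaintext for Purview to evaluate.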
…esolver

Logging-only hook now resolves the Purview user from only the proxy-injected `user_api_key_user_id` (which mirrors `UserAPIKeyAuth.user_id` after the proxy strips caller-supplied `user_api_key_*` keys). Skipping the audit when no trusted identity is available prevents a caller from submitting `metadata.user_id` pointing at a victim's Entra object id and having their prompt/response sent to Purview under that user's identity.
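The trusted-only resolution for the logging path might look like the following minimal sketch. `resolve_audit_user_id` and `should_send_audit` are hypothetical names; the only field consulted is the proxy-injected `user_api_key_user_id`, as the commit message describes.

```python
from typing import Any, Dict, Optional


def resolve_audit_user_id(metadata: Optional[Dict[str, Any]]) -> Optional[str]:
    """Return the proxy-injected user id, ignoring caller-supplied fields."""
    uid = (metadata or {}).get("user_api_key_user_id")
    return str(uid) if uid else None


def should_send_audit(metadata: Optional[Dict[str, Any]]) -> bool:
    """Skip the Purview audit entirely when no trusted identity exists."""
    return resolve_audit_user_id(metadata) is not None
```

A request whose metadata contains only a caller-supplied `user_id` therefore produces no audit record, rather than one attributed to whichever Entra object id the caller chose.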

Relevant issues
Pre-Submission checklist
- `tests/test_litellm/` directory, Adding at least 1 test is a hard requirement - see details
- `make test-unit`
- `@greptileai` and received a Confidence Score of at least 4/5 before requesting a maintainer review

Delays in PR merge?
If you're seeing a delay in your PR being merged, ping the LiteLLM Team on Slack (#pr-review).
CI (LiteLLM team)
Branch creation CI run
Link:
CI run for the last commit
Link:
Merge / cherry-pick CI run
Links:
Type
🆕 New Feature
Changes
Adds a Microsoft Purview DLP guardrail that evaluates prompts and responses against Purview DLP policies via the Microsoft Graph `processContent` API.

Three modes supported:
- `pre_call`: blocks sensitive prompts before they reach the LLM
- `post_call`: blocks sensitive content in LLM responses
- `logging_only`: sends both prompt and response to Purview for audit without blocking

Implementation:
- `litellm/proxy/guardrails/guardrail_hooks/microsoft_purview/base.py`: `PurviewGuardrailBase` handles OAuth2 client-credentials token acquisition, protection scope caching (ETag-based, 1h TTL per Microsoft recommendation), and Graph API calls
- `litellm/proxy/guardrails/guardrail_hooks/microsoft_purview/purview_dlp.py`: `MicrosoftPurviewDLPGuardrail` implements pre/post call hooks and audit logging hook
- `litellm/proxy/guardrails/guardrail_hooks/microsoft_purview/__init__.py`: initializer + auto-discovery registries
- `litellm/types/guardrails.py`: adds `MICROSOFT_PURVIEW` to `SupportedGuardrailIntegrations`
- `docs/my-website/docs/proxy/guardrails/microsoft_purview.md`: full usage docs

Note
Medium Risk
Adds a new guardrail that can block prompts/responses based on external Microsoft Graph DLP policy evaluation, introducing new failure modes (auth/rate limits/network) and request-flow impact across pre/post-call and streaming paths.
Overview
Adds a Microsoft Purview DLP guardrail integration that evaluates prompts and responses via Microsoft Graph (`protectionScopes/compute` + `processContent`), supporting blocking (`pre_call`, `post_call`, including buffered streaming) and audit-only (`logging_only`) modes with token + protection-scope caching and identity hardening.

Refactors existing OpenAI/Azure guardrail bases to reuse the shared `get_last_user_message()` helper instead of duplicating the “last user block” extraction logic, and registers the new integration under `SupportedGuardrailIntegrations.MICROSOFT_PURVIEW` with extensive unit test coverage.

Reviewed by Cursor Bugbot for commit 87fc3bc. Bugbot is set up for automated code reviews on this repo.
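The protection-scope caching mentioned in the overview (bounded, LRU-evicted, with a TTL) can be sketched in isolation as follows. `ScopeCache` is a hypothetical standalone class, not part of the PR; it mirrors the `OrderedDict` plus `threading.Lock` approach visible in `base.py`, and the injectable `clock` parameter is added here only to make the behavior easy to test.

```python
import threading
import time
from collections import OrderedDict
from typing import Any, Callable, Dict, Optional, Tuple


class ScopeCache:
    """LRU cache of user_id -> (etag, scope_response) with a time-to-live."""

    def __init__(
        self,
        maxsize: int = 1000,
        ttl_seconds: float = 3600.0,
        clock: Callable[[], float] = time.time,
    ) -> None:
        self._entries: "OrderedDict[str, Tuple[str, Dict[str, Any], float]]" = OrderedDict()
        self._maxsize = maxsize
        self._ttl = ttl_seconds
        self._clock = clock
        self._lock = threading.Lock()  # critical sections contain no awaits

    def get(self, user_id: str) -> Optional[Tuple[str, Dict[str, Any]]]:
        now = self._clock()
        with self._lock:
            entry = self._entries.get(user_id)
            if entry is None or (now - entry[2]) >= self._ttl:
                return None  # missing or expired
            self._entries.move_to_end(user_id)  # mark as most recently used
            return entry[0], entry[1]

    def put(self, user_id: str, etag: str, scope: Dict[str, Any]) -> None:
        with self._lock:
            self._entries[user_id] = (etag, scope, self._clock())
            # Assigning to an existing key keeps its old position, so move it.
            self._entries.move_to_end(user_id)
            while len(self._entries) > self._maxsize:
                self._entries.popitem(last=False)  # evict least recently used

    def invalidate(self, user_id: str) -> None:
        with self._lock:
            self._entries.pop(user_id, None)
```

A caller would consult `get` before calling `protectionScopes/compute`, call `put` with the returned ETag afterwards, and call `invalidate` when `processContent` reports that the protection scope state was modified.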