|
6 | 6 | from traceback import format_exception |
7 | 7 | from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, TypeVar |
8 | 8 |
|
| 9 | +from botocore.exceptions import ClientError |
| 10 | +from fastapi import Request # noqa: TC002 |
9 | 11 | from pydantic import AwareDatetime, BaseModel, JsonValue |
| 12 | +from sse_starlette import JSONServerSentEvent |
10 | 13 |
|
| 14 | +from stdapi.api_errors import ApiError |
| 15 | +from stdapi.api_providers import format_http_error |
| 16 | +from stdapi.aws_bedrock import AWS_ERROR_MAP |
11 | 17 | from stdapi.config import SETTINGS, LogLevel |
12 | 18 | from stdapi.metering import SERVER_FULL_VERSION |
13 | 19 | from stdapi.server import SERVER_NAME |
14 | | -from stdapi.utils import stdout_write, webuuid |
| 20 | +from stdapi.utils import hide_security_details, stdout_write, webuuid |
15 | 21 |
|
16 | 22 | if TYPE_CHECKING: |
17 | 23 | from collections.abc import AsyncGenerator, Generator |
18 | 24 |
|
19 | | - from fastapi import Request |
20 | 25 | from pydantic.main import IncEx |
21 | 26 | from starlette.datastructures import Headers |
22 | 27 | from types_aiobotocore_meteringmarketplace.type_defs import ( |
@@ -99,6 +104,9 @@ class EventLog(TypedDict): |
99 | 104 | #: Request HTTP headers |
100 | 105 | REQUEST_HEADERS: ContextVar[Headers] = ContextVar("request_headers") |
101 | 106 |
|
| 107 | +#: HTTP request object |
| 108 | +REQUEST: ContextVar[Request] = ContextVar("request") |
| 109 | + |
102 | 110 | #: Paths to ignore in logging |
103 | 111 | LOGGING_PATHS_IGNORE = { |
104 | 112 | "/", |
@@ -162,6 +170,7 @@ def log_request_event(request: Request) -> Generator[EventLog]: |
162 | 170 | REQUEST_ID.set(request_id) |
163 | 171 | request_time = SETTINGS.now() |
164 | 172 | REQUEST_TIME.set(request_time) |
| 173 | + REQUEST.set(request) |
165 | 174 | log = EventLog( |
166 | 175 | type="request", |
167 | 176 | level="info", |
@@ -408,3 +417,58 @@ async def log_request_stream_event[T](stream: AsyncGenerator[T]) -> AsyncGenerat |
408 | 417 | Items from the input asynchronous generator in their modified or original form. |
409 | 418 | """ |
410 | 419 | return _rebuild_and_log_stream(await stream.__anext__(), stream) |
| 420 | + |
| 421 | + |
| 422 | +async def log_request_sse_stream_event( |
| 423 | + stream: AsyncGenerator[JSONServerSentEvent], |
| 424 | +) -> AsyncGenerator[JSONServerSentEvent]: |
| 425 | + """Log, monitor, and error-guard an SSE stream for use with ``EventSourceResponse``. |
| 426 | +
|
| 427 | + Combines :func:`log_request_stream_event` and an SSE error boundary into a |
| 428 | + single step. After the HTTP response headers are sent, any exception that |
| 429 | + escapes the underlying generator cannot be turned into an HTTP error response |
| 430 | + (Starlette raises ``RuntimeError: Caught handled exception, but response |
| 431 | + already started``). This wrapper catches such exceptions, logs them via |
| 432 | + :func:`log_error_details`, and yields a terminal ``error`` SSE event |
| 433 | + formatted for the matched API provider so that ``EventSourceResponse`` can |
| 434 | + close the connection cleanly. |
| 435 | +
|
| 436 | + Args: |
| 437 | + stream: Raw SSE async generator (e.g. from an adapter's ``format_stream``). |
| 438 | +
|
| 439 | + Yields: |
| 440 | + Items from ``stream`` (after monitoring setup), followed by a provider- |
| 441 | + formatted ``error`` SSE event on failure. |
| 442 | + """ |
| 443 | + try: |
| 444 | + async for chunk in _rebuild_and_log_stream(await stream.__anext__(), stream): |
| 445 | + yield chunk |
| 446 | + except ApiError as exc: |
| 447 | + status = exc.status |
| 448 | + log_error_details(exc.args[0], status=status) |
| 449 | + yield JSONServerSentEvent( |
| 450 | + data=format_http_error( |
| 451 | + REQUEST.get(), |
| 452 | + status, |
| 453 | + hide_security_details(status, exc.args[0]), |
| 454 | + exc.param, |
| 455 | + exc.code, |
| 456 | + )[0], |
| 457 | + event="error", |
| 458 | + ) |
| 459 | + except ClientError as exc: |
| 460 | + error = exc.response["Error"] |
| 461 | + status = AWS_ERROR_MAP.get(error["Code"], (502, "server_error"))[0] |
| 462 | + log_error_details(error["Message"], status=status) |
| 463 | + yield JSONServerSentEvent( |
| 464 | + data=format_http_error( |
| 465 | + REQUEST.get(), status, hide_security_details(status, error["Message"]) |
| 466 | + )[0], |
| 467 | + event="error", |
| 468 | + ) |
| 469 | + except Exception as exc: # noqa: BLE001 |
| 470 | + log_error_details("\n".join(format_exception(exc)), level="critical") |
| 471 | + yield JSONServerSentEvent( |
| 472 | + data=format_http_error(REQUEST.get(), 500, "Internal Server Error")[0], |
| 473 | + event="error", |
| 474 | + ) |
0 commit comments