Skip to content

Commit 4d818b9

Browse files
feat: Support timeout as aiohttp.ClientTimeout and total_attempts (max retries) in AsyncAuthorizedSession (#1961)
b/485304839 --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent aba9348 commit 4d818b9

File tree

5 files changed

+262
-24
lines changed

5 files changed

+262
-24
lines changed

packages/google-auth/google/auth/aio/transport/aiohttp.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Transport adapter for Asynchronous HTTP Requests based on aiohttp.
16-
"""
15+
"""Transport adapter for Asynchronous HTTP Requests based on aiohttp."""
1716

1817
import asyncio
1918
import logging
20-
from typing import AsyncGenerator, Mapping, Optional
19+
from typing import AsyncGenerator, Mapping, Optional, TYPE_CHECKING, Union
2120

2221
try:
2322
import aiohttp # type: ignore
@@ -31,6 +30,15 @@
3130
from google.auth.aio import _helpers as _helpers_async
3231
from google.auth.aio import transport
3332

33+
if TYPE_CHECKING: # pragma: NO COVER
34+
from aiohttp import ClientTimeout # type: ignore
35+
36+
else:
37+
try:
38+
from aiohttp import ClientTimeout
39+
except (ImportError, AttributeError):
40+
ClientTimeout = None
41+
3442
_LOGGER = logging.getLogger(__name__)
3543

3644

@@ -123,7 +131,7 @@ async def __call__(
123131
method: str = "GET",
124132
body: Optional[bytes] = None,
125133
headers: Optional[Mapping[str, str]] = None,
126-
timeout: float = transport._DEFAULT_TIMEOUT_SECONDS,
134+
timeout: Union[float, ClientTimeout] = transport._DEFAULT_TIMEOUT_SECONDS,
127135
**kwargs,
128136
) -> transport.Response:
129137
"""
@@ -158,7 +166,10 @@ async def __call__(
158166
if not self._session:
159167
self._session = aiohttp.ClientSession()
160168

161-
client_timeout = aiohttp.ClientTimeout(total=timeout)
169+
if isinstance(timeout, aiohttp.ClientTimeout):
170+
client_timeout = timeout
171+
else:
172+
client_timeout = aiohttp.ClientTimeout(total=timeout)
162173
_helpers.request_log(_LOGGER, method, url, body, headers)
163174
response = await self._session.request(
164175
method,
@@ -176,8 +187,12 @@ async def __call__(
176187
raise client_exc from caught_exc
177188

178189
except asyncio.TimeoutError as caught_exc:
190+
if isinstance(timeout, aiohttp.ClientTimeout):
191+
timeout_seconds = timeout.total
192+
else:
193+
timeout_seconds = timeout
179194
timeout_exc = exceptions.TimeoutError(
180-
f"Request timed out after {timeout} seconds."
195+
f"Request timed out after {timeout_seconds} seconds."
181196
)
182197
raise timeout_exc from caught_exc
183198

packages/google-auth/google/auth/aio/transport/sessions.py

Lines changed: 221 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,22 @@
1616
from contextlib import asynccontextmanager
1717
import functools
1818
import time
19-
from typing import Mapping, Optional
19+
from typing import Mapping, Optional, TYPE_CHECKING, Union
2020

2121
from google.auth import _exponential_backoff, exceptions
2222
from google.auth.aio import transport
2323
from google.auth.aio.credentials import Credentials
2424
from google.auth.exceptions import TimeoutError
2525

26+
if TYPE_CHECKING: # pragma: NO COVER
27+
from aiohttp import ClientTimeout # type: ignore
28+
29+
else:
30+
try:
31+
from aiohttp import ClientTimeout
32+
except (ImportError, AttributeError):
33+
ClientTimeout = None
34+
2635
try:
2736
from google.auth.aio.transport.aiohttp import Request as AiohttpRequest
2837

@@ -137,7 +146,8 @@ async def request(
137146
data: Optional[bytes] = None,
138147
headers: Optional[Mapping[str, str]] = None,
139148
max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS,
140-
timeout: float = transport._DEFAULT_TIMEOUT_SECONDS,
149+
timeout: Union[float, ClientTimeout] = transport._DEFAULT_TIMEOUT_SECONDS,
150+
total_attempts: Optional[int] = transport.DEFAULT_MAX_RETRY_ATTEMPTS,
141151
**kwargs,
142152
) -> transport.Response:
143153
"""
@@ -146,14 +156,16 @@ async def request(
146156
url (str): The URI to be requested.
147157
data (Optional[bytes]): The payload or body in HTTP request.
148158
headers (Optional[Mapping[str, str]]): Request headers.
149-
timeout (float):
159+
timeout (float, aiohttp.ClientTimeout):
150160
The amount of time in seconds to wait for the server response
151161
with each individual request.
152162
max_allowed_time (float):
153163
If the method runs longer than this, a ``Timeout`` exception is
154164
automatically raised. Unlike the ``timeout`` parameter, this
155165
value applies to the total method execution time, even if
156166
multiple requests are made under the hood.
167+
total_attempts (int):
168+
The total number of retry attempts.
157169
158170
Mind that it is not guaranteed that the timeout error is raised
159171
at ``max_allowed_time``. It might take longer, for example, if
@@ -172,7 +184,7 @@ async def request(
172184
"""
173185

174186
retries = _exponential_backoff.AsyncExponentialBackoff(
175-
total_attempts=transport.DEFAULT_MAX_RETRY_ATTEMPTS
187+
total_attempts=total_attempts,
176188
)
177189
async with timeout_guard(max_allowed_time) as with_timeout:
178190
await with_timeout(
@@ -198,11 +210,50 @@ async def get(
198210
data: Optional[bytes] = None,
199211
headers: Optional[Mapping[str, str]] = None,
200212
max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS,
201-
timeout: float = transport._DEFAULT_TIMEOUT_SECONDS,
213+
timeout: Union[float, ClientTimeout] = transport._DEFAULT_TIMEOUT_SECONDS,
214+
total_attempts: Optional[int] = transport.DEFAULT_MAX_RETRY_ATTEMPTS,
202215
**kwargs,
203216
) -> transport.Response:
217+
"""
218+
Args:
219+
url (str): The URI to be requested.
220+
data (Optional[bytes]): The payload or body in HTTP request.
221+
headers (Optional[Mapping[str, str]]): Request headers.
222+
max_allowed_time (float):
223+
If the method runs longer than this, a ``Timeout`` exception is
224+
automatically raised. Unlike the ``timeout`` parameter, this
225+
value applies to the total method execution time, even if
226+
multiple requests are made under the hood.
227+
timeout (float, aiohttp.ClientTimeout):
228+
The amount of time in seconds to wait for the server response
229+
with each individual request.
230+
total_attempts (int):
231+
The total number of retry attempts.
232+
233+
Mind that it is not guaranteed that the timeout error is raised
234+
at ``max_allowed_time``. It might take longer, for example, if
235+
an underlying request takes a lot of time, but the request
236+
itself does not timeout, e.g. if a large file is being
237+
transmitted. The timeout error will be raised after such
238+
request completes.
239+
240+
Returns:
241+
google.auth.aio.transport.Response: The HTTP response.
242+
243+
Raises:
244+
google.auth.exceptions.TimeoutError: If the method does not complete within
245+
the configured `max_allowed_time` or the request exceeds the configured
246+
`timeout`.
247+
"""
204248
return await self.request(
205-
"GET", url, data, headers, max_allowed_time, timeout, **kwargs
249+
"GET",
250+
url,
251+
data,
252+
headers,
253+
max_allowed_time,
254+
timeout,
255+
total_attempts,
256+
**kwargs,
206257
)
207258

208259
@functools.wraps(request)
@@ -212,11 +263,50 @@ async def post(
212263
data: Optional[bytes] = None,
213264
headers: Optional[Mapping[str, str]] = None,
214265
max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS,
215-
timeout: float = transport._DEFAULT_TIMEOUT_SECONDS,
266+
timeout: Union[float, ClientTimeout] = transport._DEFAULT_TIMEOUT_SECONDS,
267+
total_attempts: Optional[int] = transport.DEFAULT_MAX_RETRY_ATTEMPTS,
216268
**kwargs,
217269
) -> transport.Response:
270+
"""
271+
Args:
272+
url (str): The URI to be requested.
273+
data (Optional[bytes]): The payload or body in HTTP request.
274+
headers (Optional[Mapping[str, str]]): Request headers.
275+
max_allowed_time (float):
276+
If the method runs longer than this, a ``Timeout`` exception is
277+
automatically raised. Unlike the ``timeout`` parameter, this
278+
value applies to the total method execution time, even if
279+
multiple requests are made under the hood.
280+
timeout (float, aiohttp.ClientTimeout):
281+
The amount of time in seconds to wait for the server response
282+
with each individual request.
283+
total_attempts (int):
284+
The total number of retry attempts.
285+
286+
Mind that it is not guaranteed that the timeout error is raised
287+
at ``max_allowed_time``. It might take longer, for example, if
288+
an underlying request takes a lot of time, but the request
289+
itself does not timeout, e.g. if a large file is being
290+
transmitted. The timeout error will be raised after such
291+
request completes.
292+
293+
Returns:
294+
google.auth.aio.transport.Response: The HTTP response.
295+
296+
Raises:
297+
google.auth.exceptions.TimeoutError: If the method does not complete within
298+
the configured `max_allowed_time` or the request exceeds the configured
299+
`timeout`.
300+
"""
218301
return await self.request(
219-
"POST", url, data, headers, max_allowed_time, timeout, **kwargs
302+
"POST",
303+
url,
304+
data,
305+
headers,
306+
max_allowed_time,
307+
timeout,
308+
total_attempts,
309+
**kwargs,
220310
)
221311

222312
@functools.wraps(request)
@@ -226,11 +316,50 @@ async def put(
226316
data: Optional[bytes] = None,
227317
headers: Optional[Mapping[str, str]] = None,
228318
max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS,
229-
timeout: float = transport._DEFAULT_TIMEOUT_SECONDS,
319+
timeout: Union[float, ClientTimeout] = transport._DEFAULT_TIMEOUT_SECONDS,
320+
total_attempts: Optional[int] = transport.DEFAULT_MAX_RETRY_ATTEMPTS,
230321
**kwargs,
231322
) -> transport.Response:
323+
"""
324+
Args:
325+
url (str): The URI to be requested.
326+
data (Optional[bytes]): The payload or body in HTTP request.
327+
headers (Optional[Mapping[str, str]]): Request headers.
328+
max_allowed_time (float):
329+
If the method runs longer than this, a ``Timeout`` exception is
330+
automatically raised. Unlike the ``timeout`` parameter, this
331+
value applies to the total method execution time, even if
332+
multiple requests are made under the hood.
333+
timeout (float, aiohttp.ClientTimeout):
334+
The amount of time in seconds to wait for the server response
335+
with each individual request.
336+
total_attempts (int):
337+
The total number of retry attempts.
338+
339+
Mind that it is not guaranteed that the timeout error is raised
340+
at ``max_allowed_time``. It might take longer, for example, if
341+
an underlying request takes a lot of time, but the request
342+
itself does not timeout, e.g. if a large file is being
343+
transmitted. The timeout error will be raised after such
344+
request completes.
345+
346+
Returns:
347+
google.auth.aio.transport.Response: The HTTP response.
348+
349+
Raises:
350+
google.auth.exceptions.TimeoutError: If the method does not complete within
351+
the configured `max_allowed_time` or the request exceeds the configured
352+
`timeout`.
353+
"""
232354
return await self.request(
233-
"PUT", url, data, headers, max_allowed_time, timeout, **kwargs
355+
"PUT",
356+
url,
357+
data,
358+
headers,
359+
max_allowed_time,
360+
timeout,
361+
total_attempts,
362+
**kwargs,
234363
)
235364

236365
@functools.wraps(request)
@@ -240,11 +369,50 @@ async def patch(
240369
data: Optional[bytes] = None,
241370
headers: Optional[Mapping[str, str]] = None,
242371
max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS,
243-
timeout: float = transport._DEFAULT_TIMEOUT_SECONDS,
372+
timeout: Union[float, ClientTimeout] = transport._DEFAULT_TIMEOUT_SECONDS,
373+
total_attempts: Optional[int] = transport.DEFAULT_MAX_RETRY_ATTEMPTS,
244374
**kwargs,
245375
) -> transport.Response:
376+
"""
377+
Args:
378+
url (str): The URI to be requested.
379+
data (Optional[bytes]): The payload or body in HTTP request.
380+
headers (Optional[Mapping[str, str]]): Request headers.
381+
max_allowed_time (float):
382+
If the method runs longer than this, a ``Timeout`` exception is
383+
automatically raised. Unlike the ``timeout`` parameter, this
384+
value applies to the total method execution time, even if
385+
multiple requests are made under the hood.
386+
timeout (float, aiohttp.ClientTimeout):
387+
The amount of time in seconds to wait for the server response
388+
with each individual request.
389+
total_attempts (int):
390+
The total number of retry attempts.
391+
392+
Mind that it is not guaranteed that the timeout error is raised
393+
at ``max_allowed_time``. It might take longer, for example, if
394+
an underlying request takes a lot of time, but the request
395+
itself does not timeout, e.g. if a large file is being
396+
transmitted. The timeout error will be raised after such
397+
request completes.
398+
399+
Returns:
400+
google.auth.aio.transport.Response: The HTTP response.
401+
402+
Raises:
403+
google.auth.exceptions.TimeoutError: If the method does not complete within
404+
the configured `max_allowed_time` or the request exceeds the configured
405+
`timeout`.
406+
"""
246407
return await self.request(
247-
"PATCH", url, data, headers, max_allowed_time, timeout, **kwargs
408+
"PATCH",
409+
url,
410+
data,
411+
headers,
412+
max_allowed_time,
413+
timeout,
414+
total_attempts,
415+
**kwargs,
248416
)
249417

250418
@functools.wraps(request)
@@ -254,11 +422,50 @@ async def delete(
254422
data: Optional[bytes] = None,
255423
headers: Optional[Mapping[str, str]] = None,
256424
max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS,
257-
timeout: float = transport._DEFAULT_TIMEOUT_SECONDS,
425+
timeout: Union[float, ClientTimeout] = transport._DEFAULT_TIMEOUT_SECONDS,
426+
total_attempts: Optional[int] = transport.DEFAULT_MAX_RETRY_ATTEMPTS,
258427
**kwargs,
259428
) -> transport.Response:
429+
"""
430+
Args:
431+
url (str): The URI to be requested.
432+
data (Optional[bytes]): The payload or body in HTTP request.
433+
headers (Optional[Mapping[str, str]]): Request headers.
434+
max_allowed_time (float):
435+
If the method runs longer than this, a ``Timeout`` exception is
436+
automatically raised. Unlike the ``timeout`` parameter, this
437+
value applies to the total method execution time, even if
438+
multiple requests are made under the hood.
439+
timeout (float, aiohttp.ClientTimeout):
440+
The amount of time in seconds to wait for the server response
441+
with each individual request.
442+
total_attempts (int):
443+
The total number of retry attempts.
444+
445+
Mind that it is not guaranteed that the timeout error is raised
446+
at ``max_allowed_time``. It might take longer, for example, if
447+
an underlying request takes a lot of time, but the request
448+
itself does not timeout, e.g. if a large file is being
449+
transmitted. The timeout error will be raised after such
450+
request completes.
451+
452+
Returns:
453+
google.auth.aio.transport.Response: The HTTP response.
454+
455+
Raises:
456+
google.auth.exceptions.TimeoutError: If the method does not complete within
457+
the configured `max_allowed_time` or the request exceeds the configured
458+
`timeout`.
459+
"""
260460
return await self.request(
261-
"DELETE", url, data, headers, max_allowed_time, timeout, **kwargs
461+
"DELETE",
462+
url,
463+
data,
464+
headers,
465+
max_allowed_time,
466+
timeout,
467+
total_attempts,
468+
**kwargs,
262469
)
263470

264471
async def close(self) -> None:

0 commit comments

Comments
 (0)