Skip to content
This repository was archived by the owner on Mar 26, 2026. It is now read-only.

Commit 14b1760

Browse files
committed
Provide AsyncIO support for generated code
1 parent 920e419 commit 14b1760

24 files changed

Lines changed: 1406 additions & 55 deletions

File tree

.circleci/config.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ jobs:
236236
name: Install system dependencies.
237237
command: |
238238
apt-get update
239-
apt-get install -y curl pandoc unzip
239+
apt-get install -y curl pandoc unzip git
240240
- run:
241241
name: Install nox.
242242
command: pip install nox
@@ -302,7 +302,7 @@ jobs:
302302
name: Install system dependencies.
303303
command: |
304304
apt-get update
305-
apt-get install -y curl pandoc unzip
305+
apt-get install -y curl pandoc unzip git
306306
- run:
307307
name: Install nox.
308308
command: pip install nox

gapic/schema/wrappers.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,13 @@ def __getattr__(self, name):
566566

567567
@utils.cached_property
568568
def client_output(self):
569+
return self._client_output(enable_asyncio=False)
570+
571+
@utils.cached_property
572+
def client_output_async(self):
573+
return self._client_output(enable_asyncio=True)
574+
575+
def _client_output(self, enable_asyncio: bool):
569576
"""Return the output from the client layer.
570577
571578
This takes into account transformations made by the outer GAPIC
@@ -584,8 +591,8 @@ def client_output(self):
584591
if self.lro:
585592
return PythonType(meta=metadata.Metadata(
586593
address=metadata.Address(
587-
name='Operation',
588-
module='operation',
594+
name='AsyncOperation' if enable_asyncio else 'Operation',
595+
module='operation_async' if enable_asyncio else 'operation',
589596
package=('google', 'api_core'),
590597
collisions=self.lro.response_type.ident.collisions,
591598
),
@@ -603,7 +610,7 @@ def client_output(self):
603610
if self.paged_result_field:
604611
return PythonType(meta=metadata.Metadata(
605612
address=metadata.Address(
606-
name=f'{self.name}Pager',
613+
name=f'{self.name}AsyncPager' if enable_asyncio else f'{self.name}Pager',
607614
package=self.ident.api_naming.module_namespace + (self.ident.api_naming.versioned_module_name,) + self.ident.subpackage + (
608615
'services',
609616
utils.to_snake_case(self.ident.parent[-1]),
@@ -734,6 +741,8 @@ def _ref_types(self, recursive: bool) -> Sequence[Union[MessageType, EnumType]]:
734741
if not self.void:
735742
answer.append(self.client_output)
736743
answer.extend(self.client_output.field_types)
744+
answer.append(self.client_output_async)
745+
answer.extend(self.client_output_async.field_types)
737746

738747
# If this method has LRO, it is possible (albeit unlikely) that
739748
# the LRO messages reside in a different module.
@@ -791,6 +800,11 @@ def client_name(self) -> str:
791800
"""Returns the name of the generated client class"""
792801
return self.name + "Client"
793802

803+
@property
804+
def async_client_name(self) -> str:
805+
"""Returns the name of the generated AsyncIO client class"""
806+
return self.name + "AsyncClient"
807+
794808
@property
795809
def transport_name(self):
796810
return self.name + "Transport"
@@ -799,6 +813,10 @@ def transport_name(self):
799813
def grpc_transport_name(self):
800814
return self.name + "GrpcTransport"
801815

816+
@property
817+
def grpc_asyncio_transport_name(self):
818+
return self.name + "GrpcAsyncIOTransport"
819+
802820
@property
803821
def has_lro(self) -> bool:
804822
"""Return whether the service has a long-running method."""
@@ -846,7 +864,7 @@ def names(self) -> FrozenSet[str]:
846864
used for imports.
847865
"""
848866
# Put together a set of the service and method names.
849-
answer = {self.name, self.client_name}
867+
answer = {self.name, self.client_name, self.async_client_name}
850868
answer.update(
851869
utils.to_snake_case(i.name) for i in self.methods.values()
852870
)

gapic/templates/%namespace/%name/__init__.py.j2

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ _lazy_name_to_package_map = {
3535
'types': '{% if api.naming.module_namespace %}{{ api.naming.module_namespace|join(".") }}.{% endif -%}{{ api.naming.versioned_module_name }}.types',
3636
{%- for service in api.services.values()|sort(attribute='name')|unique(attribute='name') if service.meta.address.subpackage == api.subpackage_view %}
3737
'{{ service.client_name|snake_case }}': '{% if api.naming.module_namespace %}{{ api.naming.module_namespace|join(".") }}.{% endif -%}{{ api.naming.versioned_module_name }}.services.{{ service.name|snake_case }}.client',
38+
'{{ service.async_client_name|snake_case }}': '{% if api.naming.module_namespace %}{{ api.naming.module_namespace|join(".") }}.{% endif -%}{{ api.naming.versioned_module_name }}.services.{{ service.name|snake_case }}.async_client',
3839
'{{ service.transport_name|snake_case }}': '{% if api.naming.module_namespace %}{{ api.naming.module_namespace|join(".") }}.{% endif -%}{{ api.naming.versioned_module_name }}.services.transports.base',
3940
'{{ service.grpc_transport_name|snake_case }}': '{% if api.naming.module_namespace %}{{ api.naming.module_namespace|join(".") }}.{% endif -%}{{ api.naming.versioned_module_name }}.services.transports.grpc',
4041
{%- endfor %} {# Need to do types and enums #}
@@ -105,6 +106,8 @@ from {% if api.naming.module_namespace %}{{ api.naming.module_namespace|join('.'
105106
if service.meta.address.subpackage == api.subpackage_view -%}
106107
from {% if api.naming.module_namespace %}{{ api.naming.module_namespace|join('.') }}.{% endif -%}
107108
{{ api.naming.versioned_module_name }}.services.{{ service.name|snake_case }}.client import {{ service.client_name }}
109+
from {% if api.naming.module_namespace %}{{ api.naming.module_namespace|join('.') }}.{% endif -%}
110+
{{ api.naming.versioned_module_name }}.services.{{ service.name|snake_case }}.async_client import {{ service.async_client_name }}
108111
{% endfor -%}
109112

110113
{# Import messages and enums from each proto.
@@ -141,6 +144,7 @@ __all__ = (
141144
{% for service in api.services.values()|sort(attribute='name')
142145
if service.meta.address.subpackage == api.subpackage_view -%}
143146
'{{ service.client_name }}',
147+
'{{ service.async_client_name }}',
144148
{% endfor -%}
145149
{% for proto in api.protos.values()|sort(attribute='module_name')
146150
if proto.meta.address.subpackage == api.subpackage_view -%}

gapic/templates/%namespace/%name_%version/%sub/services/%service/__init__.py.j2

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
{% block content %}
44
from .client import {{ service.client_name }}
5+
from .async_client import {{ service.async_client_name }}
56

67
__all__ = (
78
'{{ service.client_name }}',
9+
'{{ service.async_client_name }}',
810
)
911
{% endblock %}
Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
{% extends '_base.py.j2' %}
2+
3+
{% block content %}
4+
from collections import OrderedDict
5+
import functools
6+
import re
7+
from typing import Dict, {% if service.any_server_streaming %}AsyncIterable, {% endif %}{% if service.any_client_streaming %}AsyncIterator, {% endif %}Sequence, Tuple, Type, Union
8+
import pkg_resources
9+
10+
import google.api_core.client_options as ClientOptions # type: ignore
11+
from google.api_core import exceptions # type: ignore
12+
from google.api_core import gapic_v1 # type: ignore
13+
from google.api_core import retry as retries # type: ignore
14+
from google.auth import credentials # type: ignore
15+
from google.oauth2 import service_account # type: ignore
16+
17+
{% filter sort_lines -%}
18+
{% for method in service.methods.values() -%}
19+
{% for ref_type in method.flat_ref_types -%}
20+
{{ ref_type.ident.python_import }}
21+
{% endfor -%}
22+
{% endfor -%}
23+
{% endfilter %}
24+
from .transports.base import {{ service.name }}Transport
25+
from .transports.grpc_asyncio import {{ service.grpc_asyncio_transport_name }}
26+
from .client import {{ service.client_name }}
27+
28+
29+
class {{ service.async_client_name }}:
30+
"""{{ service.meta.doc|rst(width=72, indent=4) }}"""
31+
32+
_client: {{ service.client_name }}
33+
34+
DEFAULT_ENDPOINT = {{ service.client_name }}.DEFAULT_ENDPOINT
35+
DEFAULT_MTLS_ENDPOINT = {{ service.client_name }}.DEFAULT_MTLS_ENDPOINT
36+
37+
{% for message in service.resource_messages -%}
38+
{{ message.resource_type|snake_case }}_path = staticmethod({{ service.client_name }}.{{ message.resource_type|snake_case }}_path)
39+
40+
{% endfor %}
41+
42+
from_service_account_file = {{ service.client_name }}.from_service_account_file
43+
from_service_account_json = from_service_account_file
44+
45+
get_transport_class = functools.partial(type({{ service.client_name }}).get_transport_class, type({{ service.client_name }}))
46+
47+
def __init__(self, *,
48+
credentials: credentials.Credentials = None,
49+
transport: Union[str, {{ service.name }}Transport] = "grpc_asyncio",
50+
client_options: ClientOptions = None,
51+
) -> None:
52+
"""Instantiate the {{ (service.client_name|snake_case).replace('_', ' ') }}.
53+
54+
Args:
55+
credentials (Optional[google.auth.credentials.Credentials]): The
56+
authorization credentials to attach to requests. These
57+
credentials identify the application to the service; if none
58+
are specified, the client will attempt to ascertain the
59+
credentials from the environment.
60+
transport (Union[str, ~.{{ service.name }}Transport]): The
61+
transport to use. If set to None, a transport is chosen
62+
automatically.
63+
client_options (ClientOptions): Custom options for the client.
64+
(1) The ``api_endpoint`` property can be used to override the
65+
default endpoint provided by the client.
66+
(2) If ``transport`` argument is None, ``client_options`` can be
67+
used to create a mutual TLS transport. If ``client_cert_source``
68+
is provided, mutual TLS transport will be created with the given
69+
``api_endpoint`` or the default mTLS endpoint, and the client
70+
SSL credentials obtained from ``client_cert_source``.
71+
72+
Raises:
73+
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
74+
creation failed for any reason.
75+
"""
76+
{# NOTE(lidiz) Not using kwargs since we want the docstring and types. #}
77+
self._client = {{ service.client_name }}(
78+
credentials=credentials,
79+
transport=transport,
80+
client_options=client_options,
81+
)
82+
83+
{% for method in service.methods.values() -%}
84+
{% if not method.server_streaming %}async {% endif -%}def {{ method.name|snake_case }}(self,
85+
{%- if not method.client_streaming %}
86+
request: {{ method.input.ident }} = None,
87+
*,
88+
{% for field in method.flattened_fields.values() -%}
89+
{{ field.name }}: {{ field.ident }} = None,
90+
{% endfor -%}
91+
{%- else %}
92+
requests: AsyncIterator[{{ method.input.ident }}] = None,
93+
*,
94+
{% endif -%}
95+
retry: retries.Retry = gapic_v1.method.DEFAULT,
96+
timeout: float = None,
97+
metadata: Sequence[Tuple[str, str]] = (),
98+
{%- if not method.server_streaming %}
99+
) -> {{ method.client_output_async.ident }}:
100+
{%- else %}
101+
) -> AsyncIterable[{{ method.client_output_async.ident }}]:
102+
{%- endif %}
103+
r"""{{ method.meta.doc|rst(width=72, indent=8) }}
104+
105+
Args:
106+
{%- if not method.client_streaming %}
107+
request (:class:`{{ method.input.ident.sphinx }}`):
108+
The request object.{{ ' ' -}}
109+
{{ method.input.meta.doc|wrap(width=72, offset=36, indent=16) }}
110+
{% for key, field in method.flattened_fields.items() -%}
111+
{{ field.name }} (:class:`{{ field.ident.sphinx }}`):
112+
{{ field.meta.doc|rst(width=72, indent=16, nl=False) }}
113+
This corresponds to the ``{{ key }}`` field
114+
on the ``request`` instance; if ``request`` is provided, this
115+
should not be set.
116+
{% endfor -%}
117+
{%- else %}
118+
requests (AsyncIterator[`{{ method.input.ident.sphinx }}`]):
119+
The request object AsyncIterator.{{ ' ' -}}
120+
{{ method.input.meta.doc|wrap(width=72, offset=36, indent=16) }}
121+
{%- endif %}
122+
retry (google.api_core.retry.Retry): Designation of what errors, if any,
123+
should be retried.
124+
timeout (float): The timeout for this request.
125+
metadata (Sequence[Tuple[str, str]]): Strings which should be
126+
sent along with the request as metadata.
127+
{%- if not method.void %}
128+
129+
Returns:
130+
{%- if not method.server_streaming %}
131+
{{ method.client_output_async.ident.sphinx }}:
132+
{%- else %}
133+
AsyncIterable[{{ method.client_output_async.ident.sphinx }}]:
134+
{%- endif %}
135+
{{ method.client_output_async.meta.doc|rst(width=72, indent=16) }}
136+
{%- endif %}
137+
"""
138+
{%- if not method.client_streaming %}
139+
# Create or coerce a protobuf request object.
140+
{% if method.flattened_fields -%}
141+
# Sanity check: If we got a request object, we should *not* have
142+
# gotten any keyword arguments that map to the request.
143+
if request is not None and any([{{ method.flattened_fields.values()|join(', ', attribute='name') }}]):
144+
raise ValueError('If the `request` argument is set, then none of '
145+
'the individual field arguments should be set.')
146+
147+
{% endif -%}
148+
{% if method.input.ident.package != method.ident.package -%} {# request lives in a different package, so there is no proto wrapper #}
149+
# The request isn't a proto-plus wrapped type,
150+
# so it must be constructed via keyword expansion.
151+
if isinstance(request, dict):
152+
request = {{ method.input.ident }}(**request)
153+
{% if method.flattened_fields -%}{# Cross-package req and flattened fields #}
154+
elif not request:
155+
request = {{ method.input.ident }}()
156+
{% endif -%}{# Cross-package req and flattened fields #}
157+
{%- else %}
158+
request = {{ method.input.ident }}(request)
159+
{% endif %} {# different request package #}
160+
161+
{#- Vanilla python protobuf wrapper types cannot _set_ repeated fields #}
162+
{% if method.flattened_fields -%}
163+
# If we have keyword arguments corresponding to fields on the
164+
# request, apply these.
165+
{% endif -%}
166+
{%- for key, field in method.flattened_fields.items() if not(field.repeated and method.input.ident.package != method.ident.package) %}
167+
if {{ field.name }} is not None:
168+
request.{{ key }} = {{ field.name }}
169+
{%- endfor %}
170+
{# They can be _extended_, however -#}
171+
{%- for key, field in method.flattened_fields.items() if (field.repeated and method.input.ident.package != method.ident.package) %}
172+
if {{ field.name }}:
173+
request.{{ key }}.extend({{ field.name }})
174+
{%- endfor %}
175+
{%- endif %}
176+
177+
# Wrap the RPC method; this adds retry and timeout information,
178+
# and friendly error handling.
179+
rpc = gapic_v1.method_async.wrap_method(
180+
self._client._transport.{{ method.name|snake_case }},
181+
{%- if method.retry %}
182+
default_retry=retries.Retry(
183+
{% if method.retry.initial_backoff %}initial={{ method.retry.initial_backoff }},{% endif %}
184+
{% if method.retry.max_backoff %}maximum={{ method.retry.max_backoff }},{% endif %}
185+
{% if method.retry.backoff_multiplier %}multiplier={{ method.retry.backoff_multiplier }},{% endif %}
186+
predicate=retries.if_exception_type(
187+
{%- filter sort_lines %}
188+
{%- for ex in method.retry.retryable_exceptions %}
189+
exceptions.{{ ex.__name__ }},
190+
{%- endfor %}
191+
{%- endfilter %}
192+
),
193+
),
194+
{%- endif %}
195+
default_timeout={{ method.timeout }},
196+
client_info=_client_info,
197+
)
198+
{%- if method.field_headers %}
199+
200+
# Certain fields should be provided within the metadata header;
201+
# add these here.
202+
metadata = tuple(metadata) + (
203+
gapic_v1.routing_header.to_grpc_metadata((
204+
{%- for field_header in method.field_headers %}
205+
('{{ field_header }}', request.{{ field_header }}),
206+
{%- endfor %}
207+
)),
208+
)
209+
{%- endif %}
210+
211+
# Send the request.
212+
{% if not method.void %}response = {% endif %}
213+
{%- if not method.server_streaming %}await {% endif %}rpc(
214+
{%- if not method.client_streaming %}
215+
request,
216+
{%- else %}
217+
requests,
218+
{%- endif %}
219+
retry=retry,
220+
timeout=timeout,
221+
metadata=metadata,
222+
)
223+
{%- if method.lro %}
224+
225+
# Wrap the response in an operation future.
226+
response = operation_async.from_gapic(
227+
response,
228+
self._client._transport.operations_client,
229+
{{ method.lro.response_type.ident }},
230+
metadata_type={{ method.lro.metadata_type.ident }},
231+
)
232+
{%- elif method.paged_result_field %}
233+
234+
# This method is paged; wrap the response in a pager, which provides
235+
# an `__aiter__` convenience method.
236+
response = {{ method.client_output_async.ident }}(
237+
method=rpc,
238+
request=request,
239+
response=response,
240+
)
241+
{%- endif %}
242+
{%- if not method.void %}
243+
244+
# Done; return the response.
245+
return response
246+
{%- endif %}
247+
{{ '\n' }}
248+
{% endfor %}
249+
250+
251+
try:
252+
_client_info = gapic_v1.client_info.ClientInfo(
253+
gapic_version=pkg_resources.get_distribution(
254+
'{{ api.naming.warehouse_package_name }}',
255+
).version,
256+
)
257+
except pkg_resources.DistributionNotFound:
258+
_client_info = gapic_v1.client_info.ClientInfo()
259+
260+
261+
__all__ = (
262+
'{{ service.async_client_name }}',
263+
)
264+
{% endblock %}

0 commit comments

Comments
 (0)