Skip to content

Commit aa7f4d5

Browse files
authored
feat: forward compatible diregapic LRO support (googleapis#1085)
Detect whether a method fulfills the criteria for DIREGAPIC LRO. If so, fudge the name of the generated method by adding the suffix '_primitive'. This change is made for both the synchronous and async client variants. Any generated unit tests are changed to use and reference the fudged name. The names of the corresponding transport method is NOT changed.
1 parent a03bc22 commit aa7f4d5

13 files changed

Lines changed: 1159 additions & 56 deletions

File tree

BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ toolchain(
5151

5252
py_binary(
5353
name = "gapic_plugin",
54-
srcs = glob(["gapic/**/*.py"]),
54+
srcs = glob(["gapic/**/*.py", "google/**/*.py"]),
5555
data = [":pandoc_binary"] + glob([
5656
"gapic/**/*.j2",
5757
"gapic/**/.*.j2",

gapic/schema/wrappers.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from google.api import resource_pb2
4242
from google.api_core import exceptions
4343
from google.api_core import path_template
44+
from google.cloud import extended_operations_pb2 as ex_ops_pb2
4445
from google.protobuf import descriptor_pb2 # type: ignore
4546
from google.protobuf.json_format import MessageToDict # type: ignore
4647

@@ -344,6 +345,39 @@ def oneof_fields(self, include_optional=False):
344345

345346
return oneof_fields
346347

348+
@utils.cached_property
349+
def is_diregapic_operation(self) -> bool:
350+
if not self.name == "Operation":
351+
return False
352+
353+
name, status, error_code, error_message = False, False, False, False
354+
duplicate_msg = f"Message '{self.name}' has multiple fields with the same operation response mapping: {{}}"
355+
for f in self.field:
356+
maybe_op_mapping = f.options.Extensions[ex_ops_pb2.operation_field]
357+
OperationResponseMapping = ex_ops_pb2.OperationResponseMapping
358+
359+
if maybe_op_mapping == OperationResponseMapping.NAME:
360+
if name:
361+
raise TypeError(duplicate_msg.format("name"))
362+
name = True
363+
364+
if maybe_op_mapping == OperationResponseMapping.STATUS:
365+
if status:
366+
raise TypeError(duplicate_msg.format("status"))
367+
status = True
368+
369+
if maybe_op_mapping == OperationResponseMapping.ERROR_CODE:
370+
if error_code:
371+
raise TypeError(duplicate_msg.format("error_code"))
372+
error_code = True
373+
374+
if maybe_op_mapping == OperationResponseMapping.ERROR_MESSAGE:
375+
if error_message:
376+
raise TypeError(duplicate_msg.format("error_message"))
377+
error_message = True
378+
379+
return name and status and error_code and error_message
380+
347381
@utils.cached_property
348382
def required_fields(self) -> Sequence['Field']:
349383
required_fields = [
@@ -765,6 +799,10 @@ class Method:
765799
def __getattr__(self, name):
766800
return getattr(self.method_pb, name)
767801

802+
@property
803+
def is_operation_polling_method(self):
804+
return self.output.is_diregapic_operation and self.options.Extensions[ex_ops_pb2.operation_polling_method]
805+
768806
@utils.cached_property
769807
def client_output(self):
770808
return self._client_output(enable_asyncio=False)
@@ -838,6 +876,10 @@ def _client_output(self, enable_asyncio: bool):
838876
# Return the usual output.
839877
return self.output
840878

879+
@property
880+
def operation_service(self) -> Optional[str]:
881+
return self.options.Extensions[ex_ops_pb2.operation_service]
882+
841883
@property
842884
def is_deprecated(self) -> bool:
843885
"""Returns true if the method is deprecated, false otherwise."""
@@ -1172,6 +1214,10 @@ class Service:
11721214
def __getattr__(self, name):
11731215
return getattr(self.service_pb, name)
11741216

1217+
@property
1218+
def custom_polling_method(self) -> Optional[Method]:
1219+
return next((m for m in self.methods.values() if m.is_operation_polling_method), None)
1220+
11751221
@property
11761222
def client_name(self) -> str:
11771223
"""Returns the name of the generated client class"""

gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,9 @@ class {{ service.async_client_name }}:
150150
)
151151

152152
{% for method in service.methods.values() %}
153-
{%+ if not method.server_streaming %}async {% endif %}def {{ method.name|snake_case }}(self,
153+
{% with method_name = method.name|snake_case + "_unary" if method.operation_service else method.name|snake_case %}
154+
{%+ if not method.server_streaming %}async {% endif %}def {{ method_name }}(self,
155+
{% endwith %}
154156
{% if not method.client_streaming %}
155157
request: Union[{{ method.input.ident }}, dict] = None,
156158
*,

gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,11 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
315315

316316

317317
{% for method in service.methods.values() %}
318+
{% if method.operation_service %}{# DIREGAPIC LRO #}
319+
def {{ method.name|snake_case }}_unary(self,
320+
{% else %}
318321
def {{ method.name|snake_case }}(self,
322+
{% endif %}{# DIREGAPIC LRO #}
319323
{% if not method.client_streaming %}
320324
request: Union[{{ method.input.ident }}, dict] = None,
321325
*,

0 commit comments

Comments
 (0)