Skip to content

Commit 7e23c7e

Browse files
authored
feat: add support for context manager in client (#987)
* feat: add support for context manager in client. * chore: remove extra whitespace. * chore: adds autogenerated unit tests. * chore: adds stronger warning. * chore: fixes tests. * chore: adds auto-generated tests for ads. * chore: updates golden files. * chore: refactor. * chore: refactor. * feat: adds close() to transport and ctx to async client. * feat: adds close method and removes ctx from transport in ads. * chore: adds warning infobox to docstring. * chore: updates integration tests. * chore: fixes typo.
1 parent 9961f43 commit 7e23c7e

48 files changed

Lines changed: 643 additions & 12 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

packages/gapic-generator/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/client.py.j2

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,18 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
148148
"""
149149
return self._transport
150150

151+
def __enter__(self):
152+
return self
153+
154+
def __exit__(self, type, value, traceback):
155+
"""Releases underlying transport's resources.
156+
157+
.. warning::
158+
ONLY use as a context manager if the transport is NOT shared
159+
with other clients! Exiting the with block will CLOSE the transport
160+
and may cause errors in other clients!
161+
"""
162+
self.transport.close()
151163

152164
{% for message in service.resource_messages|sort(attribute="resource_type") %}
153165
@staticmethod

packages/gapic-generator/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/base.py.j2

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,14 @@ class {{ service.name }}Transport(metaclass=abc.ABCMeta):
103103
{% endfor %} {# precomputed wrappers loop #}
104104
}
105105

106+
def close(self):
107+
"""Closes resources associated with the transport.
108+
109+
.. warning::
110+
Only call this method if the transport is NOT shared
111+
with other clients - this may cause errors in other clients!
112+
"""
113+
raise NotImplementedError()
106114

107115
{% if service.has_lro %}
108116

packages/gapic-generator/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/grpc.py.j2

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,9 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
188188
**kwargs
189189
)
190190

191+
def close(self):
192+
self.grpc_channel.close()
193+
191194
@property
192195
def grpc_channel(self) -> grpc.Channel:
193196
"""Return the channel designed to connect to this service.

packages/gapic-generator/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,9 @@ def test_{{ service.name|snake_case }}_base_transport():
654654
with pytest.raises(NotImplementedError):
655655
getattr(transport, method)(request=object())
656656

657+
with pytest.raises(NotImplementedError):
658+
transport.close()
659+
657660
{% if service.has_lro %}
658661
# Additionally, the LRO client (a property) should
659662
# also raise NotImplementedError
@@ -903,5 +906,26 @@ def test_client_withDEFAULT_CLIENT_INFO():
903906
)
904907
prep.assert_called_once_with(client_info)
905908

909+
def test_grpc_transport_close():
910+
client = {{ service.client_name }}(
911+
credentials=ga_credentials.AnonymousCredentials(),
912+
transport='grpc',
913+
)
914+
with mock.patch.object(type(client.transport._grpc_channel), 'close') as chan_close:
915+
with client as _:
916+
chan_close.assert_not_called()
917+
chan_close.assert_called_once()
918+
919+
def test_grpc_client_ctx():
920+
client = {{ service.client_name }}(
921+
credentials=ga_credentials.AnonymousCredentials(),
922+
transport='grpc',
923+
)
924+
# Test client calls underlying transport.
925+
with mock.patch.object(type(client.transport), "close") as close:
926+
close.assert_not_called()
927+
with client as _:
928+
pass
929+
close.assert_called()
906930

907931
{% endblock %}

packages/gapic-generator/gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,12 @@ class {{ service.async_client_name }}:
599599
return response
600600
{% endif %}
601601

602+
async def __aenter__(self):
603+
return self
604+
605+
async def __aexit__(self, exc_type, exc, tb):
606+
await self.transport.close()
607+
602608
try:
603609
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
604610
gapic_version=pkg_resources.get_distribution(

packages/gapic-generator/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,19 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
477477
{{ "\n" }}
478478
{% endfor %}
479479

480+
def __enter__(self):
481+
return self
482+
483+
def __exit__(self, type, value, traceback):
484+
"""Releases underlying transport's resources.
485+
486+
.. warning::
487+
ONLY use as a context manager if the transport is NOT shared
488+
with other clients! Exiting the with block will CLOSE the transport
489+
and may cause errors in other clients!
490+
"""
491+
self.transport.close()
492+
480493
{% if opts.add_iam_methods %}
481494
def set_iam_policy(
482495
self,

packages/gapic-generator/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/base.py.j2

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,14 @@ class {{ service.name }}Transport(abc.ABC):
174174
{% endfor %} {# precomputed wrappers loop #}
175175
}
176176

177+
def close(self):
178+
"""Closes resources associated with the transport.
179+
180+
.. warning::
181+
Only call this method if the transport is NOT shared
182+
with other clients - this may cause errors in other clients!
183+
"""
184+
raise NotImplementedError()
177185

178186
{% if service.has_lro %}
179187

packages/gapic-generator/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
224224
"""Return the channel designed to connect to this service.
225225
"""
226226
return self._grpc_channel
227+
227228
{% if service.has_lro %}
228229

229230
@property
@@ -355,6 +356,9 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
355356
return self._stubs["test_iam_permissions"]
356357
{% endif %}
357358

359+
def close(self):
360+
self.grpc_channel.close()
361+
358362
__all__ = (
359363
'{{ service.name }}GrpcTransport',
360364
)

packages/gapic-generator/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc_asyncio.py.j2

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,10 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
359359
return self._stubs["test_iam_permissions"]
360360
{% endif %}
361361

362+
def close(self):
363+
return self.grpc_channel.close()
364+
365+
362366
__all__ = (
363367
'{{ service.name }}GrpcAsyncIOTransport',
364368
)

packages/gapic-generator/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,9 @@ class {{service.name}}RestTransport({{service.name}}Transport):
283283
return self._{{method.name | snake_case}}
284284
{%- endfor %}
285285

286+
def close(self):
287+
self._session.close()
288+
286289

287290
__all__=(
288291
'{{ service.name }}RestTransport',

0 commit comments

Comments
 (0)