@@ -13,6 +13,7 @@ from proto.marshal.rules.dates import DurationRule, TimestampRule
1313
1414{% if 'rest' in opts .transport %}
1515from requests import Response
16+ from requests import Request
1617from requests.sessions import Session
1718{% endif %}
1819
@@ -104,7 +105,8 @@ def test_{{ service.client_name|snake_case }}_from_service_account_info(client_c
104105 {% if 'grpc' in opts .transport %}
105106 (transports.{{ service.grpc_transport_name }}, "grpc"),
106107 (transports.{{ service.grpc_asyncio_transport_name }}, "grpc_asyncio"),
107- {% elif 'rest' in opts .transport %}
108+ {% endif %}
109+ {% if 'rest' in opts .transport %}
108110 (transports.{{ service.rest_transport_name }}, "rest"),
109111 {% endif %}
110112])
@@ -160,7 +162,8 @@ def test_{{ service.client_name|snake_case }}_get_transport_class():
160162 {% if 'grpc' in opts .transport %}
161163 ({{ service.client_name }}, transports.{{ service.grpc_transport_name }}, "grpc"),
162164 ({{ service.async_client_name }}, transports.{{ service.grpc_asyncio_transport_name }}, "grpc_asyncio"),
163- {% elif 'rest' in opts .transport %}
165+ {% endif %}
166+ {% if 'rest' in opts .transport %}
164167 ({{ service.client_name }}, transports.{{ service.rest_transport_name }}, "rest"),
165168 {% endif %}
166169])
@@ -186,7 +189,7 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans
186189 options = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
187190 with mock.patch.object(transport_class, '__init__') as patched:
188191 patched.return_value = None
189- client = client_class(client_options=options)
192+ client = client_class(transport=transport_name, client_options=options)
190193 patched.assert_called_once_with(
191194 credentials=None,
192195 credentials_file=None,
@@ -203,7 +206,7 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans
203206 with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
204207 with mock.patch.object(transport_class, '__init__') as patched:
205208 patched.return_value = None
206- client = client_class()
209+ client = client_class(transport=transport_name )
207210 patched.assert_called_once_with(
208211 credentials=None,
209212 credentials_file=None,
@@ -220,7 +223,7 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans
220223 with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
221224 with mock.patch.object(transport_class, '__init__') as patched:
222225 patched.return_value = None
223- client = client_class()
226+ client = client_class(transport=transport_name )
224227 patched.assert_called_once_with(
225228 credentials=None,
226229 credentials_file=None,
@@ -247,7 +250,7 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans
247250 options = client_options.ClientOptions(quota_project_id="octopus")
248251 with mock.patch.object(transport_class, '__init__') as patched:
249252 patched.return_value = None
250- client = client_class(client_options=options)
253+ client = client_class(transport=transport_name, client_options=options)
251254 patched.assert_called_once_with(
252255 credentials=None,
253256 credentials_file=None,
@@ -265,7 +268,8 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans
265268 ({{ service.async_client_name }}, transports.{{ service.grpc_asyncio_transport_name }}, "grpc_asyncio", "true"),
266269 ({{ service.client_name }}, transports.{{ service.grpc_transport_name }}, "grpc", "false"),
267270 ({{ service.async_client_name }}, transports.{{ service.grpc_asyncio_transport_name }}, "grpc_asyncio", "false"),
268- {% elif 'rest' in opts .transport %}
271+ {% endif %}
272+ {% if 'rest' in opts .transport %}
269273 ({{ service.client_name }}, transports.{{ service.rest_transport_name }}, "rest", "true"),
270274 ({{ service.client_name }}, transports.{{ service.rest_transport_name }}, "rest", "false"),
271275 {% endif %}
@@ -285,7 +289,7 @@ def test_{{ service.client_name|snake_case }}_mtls_env_auto(client_class, transp
285289 options = client_options.ClientOptions(client_cert_source=client_cert_source_callback)
286290 with mock.patch.object(transport_class, '__init__') as patched:
287291 patched.return_value = None
288- client = client_class(client_options=options)
292+ client = client_class(transport=transport_name, client_options=options)
289293
290294 if use_client_cert_env == "false":
291295 expected_client_cert_source = None
@@ -319,7 +323,7 @@ def test_{{ service.client_name|snake_case }}_mtls_env_auto(client_class, transp
319323 expected_client_cert_source = client_cert_source_callback
320324
321325 patched.return_value = None
322- client = client_class()
326+ client = client_class(transport=transport_name )
323327 patched.assert_called_once_with(
324328 credentials=None,
325329 credentials_file=None,
@@ -336,7 +340,7 @@ def test_{{ service.client_name|snake_case }}_mtls_env_auto(client_class, transp
336340 with mock.patch.object(transport_class, '__init__') as patched:
337341 with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False):
338342 patched.return_value = None
339- client = client_class()
343+ client = client_class(transport=transport_name )
340344 patched.assert_called_once_with(
341345 credentials=None,
342346 credentials_file=None,
@@ -353,7 +357,8 @@ def test_{{ service.client_name|snake_case }}_mtls_env_auto(client_class, transp
353357 {% if 'grpc' in opts .transport %}
354358 ({{ service.client_name }}, transports.{{ service.grpc_transport_name }}, "grpc"),
355359 ({{ service.async_client_name }}, transports.{{ service.grpc_asyncio_transport_name }}, "grpc_asyncio"),
356- {% elif 'rest' in opts .transport %}
360+ {% endif %}
361+ {% if 'rest' in opts .transport %}
357362 ({{ service.client_name }}, transports.{{ service.rest_transport_name }}, "rest"),
358363 {% endif %}
359364])
@@ -364,7 +369,7 @@ def test_{{ service.client_name|snake_case }}_client_options_scopes(client_class
364369 )
365370 with mock.patch.object(transport_class, '__init__') as patched:
366371 patched.return_value = None
367- client = client_class(client_options=options)
372+ client = client_class(transport=transport_name, client_options=options)
368373 patched.assert_called_once_with(
369374 credentials=None,
370375 credentials_file=None,
@@ -380,7 +385,8 @@ def test_{{ service.client_name|snake_case }}_client_options_scopes(client_class
380385 {% if 'grpc' in opts .transport %}
381386 ({{ service.client_name }}, transports.{{ service.grpc_transport_name }}, "grpc"),
382387 ({{ service.async_client_name }}, transports.{{ service.grpc_asyncio_transport_name }}, "grpc_asyncio"),
383- {% elif 'rest' in opts .transport %}
388+ {% endif %}
389+ {% if 'rest' in opts .transport %}
384390 ({{ service.client_name }}, transports.{{ service.rest_transport_name }}, "rest"),
385391 {% endif %}
386392])
@@ -391,7 +397,7 @@ def test_{{ service.client_name|snake_case }}_client_options_credentials_file(cl
391397 )
392398 with mock.patch.object(transport_class, '__init__') as patched:
393399 patched.return_value = None
394- client = client_class(client_options=options)
400+ client = client_class(transport=transport_name, client_options=options)
395401 patched.assert_called_once_with(
396402 credentials=None,
397403 credentials_file="credentials.json",
@@ -1182,14 +1188,48 @@ def test_{{ method.name|snake_case }}_rest(transport: str = 'rest', request_type
11821188 {% endif %}
11831189
11841190
1191+ def test_{{ method.name|snake_case }}_rest_bad_request(transport: str = 'rest', request_type={{ method.input.ident }}):
1192+ client = {{ service.client_name }}(
1193+ credentials=ga_credentials.AnonymousCredentials(),
1194+ transport=transport,
1195+ )
1196+
1197+ # send a request that will satisfy transcoding
1198+ request_init = {{ method.http_options[0] .sample_request}}
1199+ {% for field in method .body_fields .values () %}
1200+ {% if not field .oneof or field .proto 3_optional %}
1201+ {# ignore oneof fields that might conflict with sample_request #}
1202+ request_init["{{ field.name }}"] = {{ field.mock_value }}
1203+ {% endif %}
1204+ {% endfor %}
1205+ request = request_type(request_init)
1206+ {% if method .client_streaming %}
1207+ requests = [request]
1208+ {% endif %}
1209+
1210+ # Mock the http request call within the method and fake a BadRequest error.
1211+ with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest):
1212+ # Wrap the value into a proper Response obj
1213+ response_value = Response()
1214+ response_value.status_code = 400
1215+ response_value.request = Request()
1216+ req.return_value = response_value
1217+ {% if method .client_streaming %}
1218+ client.{{ method.name|snake_case }}(iter(requests))
1219+ {% else %}
1220+ client.{{ method.name|snake_case }}(request)
1221+ {% endif %}
1222+
1223+
11851224def test_{{ method.name|snake_case }}_rest_from_dict():
11861225 test_{{ method.name|snake_case }}_rest(request_type=dict)
11871226
11881227
11891228{% if method .flattened_fields %}
1190- def test_{{ method.name|snake_case }}_rest_flattened():
1229+ def test_{{ method.name|snake_case }}_rest_flattened(transport: str = 'rest' ):
11911230 client = {{ service.client_name }}(
11921231 credentials=ga_credentials.AnonymousCredentials(),
1232+ transport=transport,
11931233 )
11941234
11951235 # Mock the http request call within the method and fake a response.
@@ -1242,9 +1282,10 @@ def test_{{ method.name|snake_case }}_rest_flattened():
12421282 {# TODO(kbandes) - reverse-transcode request args to check all request fields #}
12431283
12441284
1245- def test_{{ method.name|snake_case }}_rest_flattened_error():
1285+ def test_{{ method.name|snake_case }}_rest_flattened_error(transport: str = 'rest' ):
12461286 client = {{ service.client_name }}(
12471287 credentials=ga_credentials.AnonymousCredentials(),
1288+ transport=transport,
12481289 )
12491290
12501291 # Attempting to call a method with both a request object and flattened
@@ -1460,7 +1501,8 @@ def test_transport_get_channel():
14601501 {% if 'grpc' in opts .transport %}
14611502 transports.{{ service.grpc_transport_name }},
14621503 transports.{{ service.grpc_asyncio_transport_name }},
1463- {% elif 'rest' in opts .transport %}
1504+ {% endif %}
1505+ {% if 'rest' in opts .transport %}
14641506 transports.{{ service.rest_transport_name }},
14651507 {% endif %}
14661508])
0 commit comments