@@ -1287,65 +1287,8 @@ def test_{{ method_name }}_raw_page_lro():
12871287{% endfor %} {# method in methods for grpc #}
12881288
12891289{% for method in service .methods .values () if 'rest' in opts .transport %}{% with method_name = method .name |snake_case + "_unary" if method .operation_service else method .name |snake_case %}{% if method .http_options %}
1290- {# TODO(kbandes): remove this if condition when streaming are supported. #}
1291- {% if not (method .server_streaming or method .client_streaming ) %}
1292- @pytest.mark.parametrize("request_type", [
1293- {{ method.input.ident }},
1294- dict,
1295- ])
1296- def test_{{ method_name }}_rest(request_type, transport: str = 'rest'):
1297- client = {{ service.client_name }}(
1298- credentials=ga_credentials.AnonymousCredentials(),
1299- transport="rest",
1300- )
1301- # Send a request that will satisfy transcoding
1302- request = {{ method.input.ident }}({{ method.http_options[0] .sample_request(method) }})
1303- {% if method .client_streaming %}
1304- requests = [request]
1305- {% endif %}
1306-
1307-
1308- with mock.patch.object(type(client.transport._session), 'request') as req:
1309- {% if method .void %}
1310- return_value = None
1311- {% elif method .lro %}
1312- return_value = operations_pb2.Operation(name='operations/spam')
1313- {% elif method .server_streaming %}
1314- return_value = iter([{{ method.output.ident }}()])
1315- {% else %}
1316- return_value = {{ method.output.ident }}(
1317- {% for field in method .output .fields .values () | rejectattr ('message' )%}
1318- {% if not field .oneof or field .proto 3_optional %}
1319- {{ field.name }}={{ field.mock_value }},
1320- {% endif %}{% endfor %}
1321- {# This is a hack to only pick one field #}
1322- {% for oneof_fields in method .output .oneof_fields ().values () %}
1323- {% with field = oneof_fields [0] %}
1324- {{ field.name }}={{ field.mock_value }},
1325- {% endwith %}
1326- {% endfor %}
1327- )
1328- {% endif %}
1329- req.return_value = Response()
1330- req.return_value.status_code = 500
1331- req.return_value.request = PreparedRequest()
1332- {% if method .void %}
1333- json_return_value = ''
1334- {% else %}
1335- json_return_value = {{ method.output.ident }}.to_json(return_value)
1336- {% endif %}
1337- req.return_value._content = json_return_value.encode("UTF-8")
1338- with pytest.raises(core_exceptions.GoogleAPIError):
1339- # We only care that the correct exception is raised when putting
1340- # the request over the wire, so an empty request is fine.
1341- {% if method .client_streaming %}
1342- client.{{ method_name }}(iter([requests]))
1343- {% else %}
1344- client.{{ method_name }}(request)
1345- {% endif %}
1346-
1347-
13481290{# TODO(kbandes): remove this if condition when lro and streaming are supported. #}
1291+ {% if not (method .server_streaming or method .client_streaming ) %}
13491292@pytest.mark.parametrize("request_type", [
13501293 {{ method.input.ident }},
13511294 dict,
@@ -1458,7 +1401,7 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
14581401 ))
14591402
14601403 # verify fields with default values are dropped
1461- {% for req_field in method .input .required_fields if req_field .is_primitive %}
1404+ {% for req_field in method .input .required_fields if req_field .is_primitive and req_field . name in method . query_params %}
14621405 {% set field_name = req_field .name | camel_case %}
14631406 assert "{{ field_name }}" not in jsonified_request
14641407 {% endfor %}
@@ -1467,7 +1410,7 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
14671410 jsonified_request.update(unset_fields)
14681411
14691412 # verify required fields with default values are now present
1470- {% for req_field in method .input .required_fields if req_field .is_primitive %}
1413+ {% for req_field in method .input .required_fields if req_field .is_primitive and req_field . name in method . query_params %}
14711414 {% set field_name = req_field .name | camel_case %}
14721415 assert "{{ field_name }}" in jsonified_request
14731416 assert jsonified_request["{{ field_name }}"] == request_init["{{ req_field.name }}"]
@@ -1480,6 +1423,10 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
14801423 {% endfor %}
14811424
14821425 unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).{{ method.name | snake_case }}._get_unset_required_fields(jsonified_request)
1426+ {% if method .query_params %}
1427+ # Check that path parameters and body parameters are not mixing in.
1428+ assert not set(unset_fields) - set(({% for param in method .query_params %} "{{param}}", {% endfor %} ))
1429+ {% endif %}
14831430 jsonified_request.update(unset_fields)
14841431
14851432 # verify required fields with non-default values are left alone
@@ -1544,7 +1491,7 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
15441491 {% endif %}
15451492
15461493 expected_params = [
1547- {% for req_field in method .input .required_fields if req_field .is_primitive %}
1494+ {% for req_field in method .input .required_fields if req_field .is_primitive and req_field . name in method . query_params %}
15481495 (
15491496 "{{ req_field.name | camel_case }}",
15501497 {% if req_field .field_pb .type == 9 %}
@@ -1559,6 +1506,12 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
15591506 assert expected_params == actual_params
15601507
15611508
1509+ def test_{{ method_name }}_rest_unset_required_fields():
1510+ transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials)
1511+
1512+ unset_fields = transport.{{ method.name|snake_case }}._get_unset_required_fields({})
1513+ assert set(unset_fields) == (set(({% for param in method .query_params %} "{{ param|camel_case }}", {% endfor %} )) & set(({% for param in method .input .required_fields %} "{{ param.name|camel_case }}", {% endfor %} )))
1514+
15621515 {% endif %} {# required_fields #}
15631516
15641517
0 commit comments