Skip to content
This repository was archived by the owner on Mar 26, 2026. It is now read-only.

Commit b9d0532

Browse files
committed
test: improve routing parameter assert
1 parent 2b3d709 commit b9d0532

2 files changed

Lines changed: 85 additions & 29 deletions

File tree

gapic/schema/wrappers.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,6 +1067,66 @@ def try_parse_routing_rule(cls, routing_rule: routing_pb2.RoutingRule) -> Option
10671067
params = [RoutingParameter(x.field, x.path_template) for x in params]
10681068
return cls(params)
10691069

1070+
@classmethod
1071+
def resolve(cls, routing_rule: routing_pb2.RoutingRule, request: Union[dict, str]) -> dict:
1072+
"""Resolves the routing header which should be sent along with the request.
1073+
This function performs dynamic header resolution, identical to what's in `client.py.j2`.
1074+
The routing header is determined based on the given routing rule and request.
1075+
See the following link for more information on explicit routing headers:
1076+
https://google.aip.dev/client-libraries/4222#explicit-routing-headers-googleapirouting
1077+
1078+
Args:
1079+
routing_rule(routing_pb2.RoutingRule): A collection of Routing Parameter specifications
1080+
defined by `routing_pb2.RoutingRule`.
1081+
See https://github.com/googleapis/googleapis/blob/cb39bdd75da491466f6c92bc73cd46b0fbd6ba9a/google/api/routing.proto#L391
1082+
request(Union[dict, str]): The request for which the routine rule should be resolved.
1083+
The format can be either a dictionary or json string representing the request.
1084+
1085+
Returns(dict):
1086+
A dictionary containing the resolved routing header to the sent along with the given request.
1087+
"""
1088+
1089+
def _get_field(request, field_path: str):
1090+
segments = field_path.split(".")
1091+
1092+
# Either json string or dictionary is supported
1093+
if isinstance(request, str):
1094+
current = json.loads(request)
1095+
else:
1096+
current = request
1097+
1098+
# This is to cater for the case where the `field_path` contains a
1099+
# dot-separated path of field names leading to a field in a sub-message.
1100+
for x in segments:
1101+
current = current.get(x, None)
1102+
# Break if the sub-message does not exist
1103+
if current is None:
1104+
break
1105+
return current
1106+
1107+
header_params = {}
1108+
for routing_param in routing_rule.routing_parameters:
1109+
request_field_value = _get_field(request, routing_param.field)
1110+
# Only resolve the header for routing parameter fields which are populated in the request
1111+
if request_field_value is not None:
1112+
# If there is a path_template for a given routing parameter field, the value of the field must match
1113+
# If multiple Routing Parameters describe the same key
1114+
# (via the `path_template` field or via the `field` field when
1115+
# `path_template` is not provided), "last one wins" rule
1116+
# determines which Parameter gets used.
1117+
if routing_param.path_template:
1118+
routing_param_regex = routing_param.to_regex()
1119+
regex_match = routing_param_regex.match(
1120+
request_field_value
1121+
)
1122+
if regex_match:
1123+
header_params[routing_param.key] = regex_match.group(
1124+
routing_param.key
1125+
)
1126+
else: # No need to match
1127+
header_params[routing_param.key] = request_field_value
1128+
return header_params
1129+
10701130

10711131
@dataclasses.dataclass(frozen=True)
10721132
class HttpRule:

tests/unit/schema/wrappers/test_routing.py

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from gapic.schema import wrappers
1616

17+
import json
1718
import proto
1819
import pytest
1920

@@ -23,31 +24,6 @@ class RoutingTestRequest(proto.Message):
2324
app_profile_id = proto.Field(proto.STRING, number=2)
2425

2526

26-
def resolve(rule, request):
27-
"""This function performs dynamic header resolution, identical to what's in client.py.j2."""
28-
29-
def _get_field(request, field_path: str):
30-
segments = field_path.split(".")
31-
cur = request
32-
for x in segments:
33-
cur = getattr(cur, x)
34-
return cur
35-
36-
header_params = {}
37-
for routing_param in rule.routing_parameters:
38-
# This may raise exception (which we show to clients).
39-
request_field_value = _get_field(request, routing_param.field)
40-
if routing_param.path_template:
41-
routing_param_regex = routing_param.to_regex()
42-
regex_match = routing_param_regex.match(request_field_value)
43-
if regex_match:
44-
header_params[routing_param.key] = regex_match.group(
45-
routing_param.key)
46-
else: # No need to match
47-
header_params[routing_param.key] = request_field_value
48-
return header_params
49-
50-
5127
@pytest.mark.parametrize(
5228
"req, expected",
5329
[
@@ -63,7 +39,10 @@ def _get_field(request, field_path: str):
6339
def test_routing_rule_resolve_simple_extraction(req, expected):
6440
rule = wrappers.RoutingRule(
6541
[wrappers.RoutingParameter("app_profile_id", "")])
66-
assert resolve(rule, req) == expected
42+
assert wrappers.RoutingRule.resolve(
43+
rule,
44+
RoutingTestRequest.to_dict(req)
45+
) == expected
6746

6847

6948
@pytest.mark.parametrize(
@@ -82,7 +61,10 @@ def test_routing_rule_resolve_rename_extraction(req, expected):
8261
rule = wrappers.RoutingRule(
8362
[wrappers.RoutingParameter("app_profile_id", "{routing_id=**}")]
8463
)
85-
assert resolve(rule, req) == expected
64+
assert wrappers.RoutingRule.resolve(
65+
rule,
66+
RoutingTestRequest.to_dict(req)
67+
) == expected
8668

8769

8870
@pytest.mark.parametrize(
@@ -111,7 +93,10 @@ def test_routing_rule_resolve_field_match(req, expected):
11193
),
11294
]
11395
)
114-
assert resolve(rule, req) == expected
96+
assert wrappers.RoutingRule.resolve(
97+
rule,
98+
RoutingTestRequest.to_dict(req)
99+
) == expected
115100

116101

117102
@pytest.mark.parametrize(
@@ -135,6 +120,9 @@ def test_routing_rule_resolve_field_match(req, expected):
135120
wrappers.RoutingParameter(
136121
"table_name", "projects/*/{instance_id=instances/*}/**"
137122
),
123+
wrappers.RoutingParameter(
124+
"doesnotexist", "projects/*/{instance_id=instances/*}/**"
125+
),
138126
],
139127
RoutingTestRequest(
140128
table_name="projects/100/instances/200/tables/300"),
@@ -144,7 +132,15 @@ def test_routing_rule_resolve_field_match(req, expected):
144132
)
145133
def test_routing_rule_resolve(routing_parameters, req, expected):
146134
rule = wrappers.RoutingRule(routing_parameters)
147-
got = resolve(rule, req)
135+
got = wrappers.RoutingRule.resolve(
136+
rule, RoutingTestRequest.to_dict(req)
137+
)
138+
assert got == expected
139+
140+
rule = wrappers.RoutingRule(routing_parameters)
141+
got = wrappers.RoutingRule.resolve(
142+
rule, json.dumps(RoutingTestRequest.to_dict(req))
143+
)
148144
assert got == expected
149145

150146

0 commit comments

Comments
 (0)