Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ def on_request(self, request):
"""
self._enforce_https(request)

if self._need_new_token:
if self._token is None or self._need_new_token:
self._token = self._credential.get_token(*self._scopes)
Comment thread
lmazuel marked this conversation as resolved.
self._update_headers(request.http_request.headers, self._token.token)
self._update_headers(request.http_request.headers, self._token.token)


class AzureKeyCredentialPolicy(SansIOHTTPPolicy):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
@pytest.mark.asyncio
async def test_bearer_policy_adds_header():
"""The bearer token policy should add a header containing a token from its credential"""
expected_token = AccessToken("expected_token", 0)
# 2524608000 == 01/01/2050 @ 12:00am (UTC)
expected_token = AccessToken("expected_token", 2524608000)

async def verify_authorization_header(request):
assert request.http_request.headers["Authorization"] == "Bearer {}".format(expected_token.token)
Expand All @@ -37,6 +38,10 @@ async def get_token(_):
await pipeline.run(HttpRequest("GET", "https://spam.eggs"), context=None)
assert get_token_calls == 1

await pipeline.run(HttpRequest("GET", "https://spam.eggs"), context=None)
# Didn't need a new token
assert get_token_calls == 1


@pytest.mark.asyncio
async def test_bearer_policy_send():
Expand Down
10 changes: 8 additions & 2 deletions sdk/core/azure-core/tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,24 @@

def test_bearer_policy_adds_header():
"""The bearer token policy should add a header containing a token from its credential"""
expected_token = AccessToken("expected_token", 0)
# 2524608000 == 01/01/2050 @ 12:00am (UTC)
expected_token = AccessToken("expected_token", 2524608000)

def verify_authorization_header(request):
assert request.http_request.headers["Authorization"] == "Bearer {}".format(expected_token.token)

fake_credential = Mock(get_token=Mock(return_value=expected_token))
policies = [BearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)]

Pipeline(transport=Mock(), policies=policies).run(HttpRequest("GET", "https://spam.eggs"))
pipeline = Pipeline(transport=Mock(), policies=policies)
pipeline.run(HttpRequest("GET", "https://spam.eggs"))

assert fake_credential.get_token.call_count == 1

pipeline.run(HttpRequest("GET", "https://spam.eggs"))

# Didn't need a new token
assert fake_credential.get_token.call_count == 1

def test_bearer_policy_send():
"""The bearer token policy should invoke the next policy's send method and return the result"""
Expand Down