Skip to content
This repository has been archived by the owner on Jan 18, 2025. It is now read-only.

Commit

Permalink
Thread safety changes
Browse files Browse the repository at this point in the history
  • Loading branch information
kpayson64 committed May 6, 2016
1 parent 21083a6 commit 3e6d957
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 56 deletions.
25 changes: 14 additions & 11 deletions oauth2client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,19 +598,9 @@ def new_request(uri, method='GET', body=None, headers=None,

# Clone and modify the request headers to add the appropriate
# Authorization header.
if headers is None:
headers = {}
else:
headers = dict(headers)
headers = self._initialize_headers(headers)
self.apply(headers)

if self.user_agent is not None:
if 'user-agent' in headers:
headers['user-agent'] = (self.user_agent + ' ' +
headers['user-agent'])
else:
headers['user-agent'] = self.user_agent

body_stream_position = None
if all(getattr(body, stream_prop, None) for stream_prop in
('read', 'seek', 'tell')):
Expand Down Expand Up @@ -844,6 +834,19 @@ def _generate_refresh_request_headers(self):

return headers

def _initialize_headers(self, headers):
"""Initialize the headers/apply user_agent as needed."""
if headers is None:
headers = {}
else:
headers = dict(headers)
if self.user_agent is not None:
if 'user-agent' in headers:
headers['user-agent'] = (self.user_agent + ' ' + headers['user-agent'])
else:
headers['user-agent'] = self.user_agent
return headers

def _refresh(self, http_request):
"""Refreshes the access_token.
Expand Down
36 changes: 19 additions & 17 deletions oauth2client/service_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from oauth2client import util
from oauth2client.client import AccessTokenInfo
from oauth2client.client import AssertionCredentials
from oauth2client.client import clean_headers
from oauth2client.client import EXPIRY_FORMAT
from oauth2client.client import GoogleCredentials
from oauth2client.client import SERVICE_ACCOUNT
Expand Down Expand Up @@ -515,7 +516,6 @@ def __init__(self,
client_id=client_id,
user_agent=user_agent,
**additional_claims)
self._aud = None

def authorize(self, http):
"""Authorize an httplib2.Http instance with a JWT assertion.
Expand All @@ -532,27 +532,31 @@ def authorize(self, http):
h = httplib2.Http()
h = credentials.authorize(h)
"""
super(_JWTAccessCredentials, self).authorize(http)
request_orig = http.request
request_auth = super(_JWTAccessCredentials, self).authorize(http).request

# The closure that will replace 'httplib2.Http.request'.
def new_request(uri, method='GET', body=None, headers=None,
redirections=httplib2.DEFAULT_MAX_REDIRECTS,
connection_type=None):
# Preemptively refresh token, this is not done for OAuth2
if self.access_token is None or self.access_token_expired:
self.refresh(None)

# If we don't have an 'aud' (audience) claim,
# extract it from the uri.
if 'aud' not in self._kwargs:
if 'aud' in self._kwargs:
# Preemptively refresh token, this is not done for OAuth2
if self.access_token is None or self.access_token_expired:
self.refresh(None)
return request_auth(uri, method, body,
headers, redirections,
connection_type)
else:
# If we don't have an 'aud' (audience) claim,
# create a 1-time token with the uri root as the audience
uri_root = uri.split('?', 1)[0]
if uri_root != self._aud:
self.access_token = None
self.token_expiry = None
self._aud = uri_root
return request_orig(uri, method, body,
headers, redirections, connection_type)
headers = self._initialize_headers(headers)
token, unused_expiry = self._create_token({'aud': uri_root})

headers['Authorization'] = 'Bearer ' + token
return request_orig(uri, method, body,
clean_headers(headers),
redirections, connection_type)

# Replace the request method with our own closure.
http.request = new_request
Expand Down Expand Up @@ -622,8 +626,6 @@ def _create_token(self, additional_claims=None):
'iss': self._service_account_email,
'sub': self._service_account_email
}
if self._aud is not None:
payload['aud'] = self._aud
payload.update(self._kwargs)
if additional_claims is not None:
payload.update(additional_claims)
Expand Down
48 changes: 20 additions & 28 deletions tests/test_service_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,38 +512,30 @@ def test_authorize_no_aud(self, time, client_utcnow, utcnow):
self.signer,
private_key_id=self.private_key_id,
client_id=self.client_id)
h = HttpMockSequence([({'status': '200'}, b''),
({'status': '200'}, b''),
({'status': '200'}, b'')])
jwt.authorize(h)
h.request(self.url)

payload = crypt.verify_signed_jwt_with_certs(
jwt.access_token,
{'key': datafile('public_cert.pem')},
audience=self.url)
self.assertEqual(payload['iss'], self.service_account_email)
self.assertEqual(payload['sub'], self.service_account_email)
self.assertEqual(payload['iat'], T1)
self.assertEqual(payload['exp'], T1_EXPIRY)
def mock_request(uri, method='GET', body=None, headers=None,
redirections=0, connection_type=None):
self.assertEqual(uri, self.url)
bearer, token = headers[b'Authorization'].split()
payload = crypt.verify_signed_jwt_with_certs(
token,
{'key': datafile('public_cert.pem')},
audience=self.url)
self.assertEqual(payload['iss'], self.service_account_email)
self.assertEqual(payload['sub'], self.service_account_email)
self.assertEqual(payload['iat'], T1)
self.assertEqual(payload['exp'], T1_EXPIRY)
self.assertEqual(uri, self.url)
self.assertEqual(bearer, b'Bearer')
return (httplib2.Response({'status': '200'}), b'')

# Ensure we use the cached token
utcnow.return_value = T2_DATE
client_utcnow.return_value = T2_DATE
time.return_value = T2
h = httplib2.Http()
h.request = mock_request
jwt.authorize(h)
h.request(self.url)
payload = crypt.verify_signed_jwt_with_certs(
jwt.access_token,
{'key': datafile('public_cert.pem')},
audience=self.url)
self.assertEqual(payload['iat'], T1)

# Ensure we create a new token for new url
h.request('http://some.new.url/location/new')
payload = crypt.verify_signed_jwt_with_certs(
jwt.access_token,
{'key': datafile('public_cert.pem')},
audience='http://some.new.url/location/new')
# Ensure we do not cache the token
self.assertIsNone(jwt.access_token)

@mock.patch('oauth2client.service_account._UTCNOW')
def test_authorize_stale_token(self, utcnow):
Expand Down

0 comments on commit 3e6d957

Please sign in to comment.