From 82cde04917da5c2db31241e4bb334ed5e0ba66fe Mon Sep 17 00:00:00 2001 From: Ken Payson Date: Fri, 22 Apr 2016 10:10:28 -0700 Subject: [PATCH] Added JWTAccessCredentials. Newer Google APIs can accept JWTs signed using ServiceAccountCredentials for authentication. (See https://jwt.io/). The new behavior for GoogleCredentials.get_application_default() will attempt to use a signed JWT if ServiceAccountCredentials are available and no scope is specified. Upon specifying a scope, OAuth2 authentication will be used. --- oauth2client/client.py | 42 ++++-- oauth2client/service_account.py | 198 +++++++++++++++++++++++++-- tests/test_client.py | 14 ++ tests/test_service_account.py | 228 ++++++++++++++++++++++++++++++++ 4 files changed, 457 insertions(+), 25 deletions(-) diff --git a/oauth2client/client.py b/oauth2client/client.py index 3c958d282..97cc7a361 100644 --- a/oauth2client/client.py +++ b/oauth2client/client.py @@ -495,6 +495,26 @@ def _update_query_params(uri, params): return urllib.parse.urlunparse(new_parts) +def _initialize_headers(headers): + """Creates a copy of the headers.""" + if headers is None: + headers = {} + else: + headers = dict(headers) + return headers + + +def _apply_user_agent(headers, user_agent): + """Adds a user-agent to the headers.""" + if user_agent is not None: + if 'user-agent' in headers: + headers['user-agent'] = (user_agent + ' ' + headers['user-agent']) + else: + headers['user-agent'] = user_agent + + return headers + + class OAuth2Credentials(Credentials): """Credentials object for OAuth 2.0. @@ -598,18 +618,9 @@ def new_request(uri, method='GET', body=None, headers=None, # Clone and modify the request headers to add the appropriate # Authorization header. - if headers is None: - headers = {} - else: - headers = dict(headers) + headers = _initialize_headers(headers) self.apply(headers) - - if self.user_agent is not None: - if 'user-agent' in headers: - headers['user-agent'] = (self.user_agent + ' ' + - headers['user-agent']) - else: - headers['user-agent'] = self.user_agent + _apply_user_agent(headers, self.user_agent) body_stream_position = None if all(getattr(body, stream_prop, None) for stream_prop in @@ -1237,6 +1248,7 @@ def from_json(cls, json_data): # TODO(issue 388): eliminate the circularity that is the reason for # this non-top-level import. from oauth2client.service_account import ServiceAccountCredentials + from oauth2client.service_account import _JWTAccessCredentials data = json.loads(_from_bytes(json_data)) # We handle service_account.ServiceAccountCredentials since it is a @@ -1244,6 +1256,10 @@ def from_json(cls, json_data): if (data['_module'] == 'oauth2client.service_account' and data['_class'] == 'ServiceAccountCredentials'): return ServiceAccountCredentials.from_json(data) + elif (data['_module'] == 'oauth2client.service_account' and + data['_class'] == '_JWTAccessCredentials'): + return _JWTAccessCredentials.from_json(data) + token_expiry = _parse_expiry(data.get('token_expiry')) google_credentials = cls( @@ -1523,8 +1539,8 @@ def _get_application_default_credential_from_file(filename): token_uri=GOOGLE_TOKEN_URI, user_agent='Python client library') else: # client_credentials['type'] == SERVICE_ACCOUNT - from oauth2client.service_account import ServiceAccountCredentials - return ServiceAccountCredentials.from_json_keyfile_dict( + from oauth2client.service_account import _JWTAccessCredentials + return _JWTAccessCredentials.from_json_keyfile_dict( client_credentials) diff --git a/oauth2client/service_account.py b/oauth2client/service_account.py index f009b0c56..8fca83b06 100644 --- a/oauth2client/service_account.py +++ b/oauth2client/service_account.py @@ -17,6 +17,7 @@ import base64 import copy import datetime +import httplib2 import json import time @@ -26,9 +27,16 @@ from oauth2client._helpers import _from_bytes from oauth2client._helpers import _urlsafe_b64encode from oauth2client import util +from oauth2client.client import _apply_user_agent +from oauth2client.client import _initialize_headers +from oauth2client.client import AccessTokenInfo from oauth2client.client import AssertionCredentials +from oauth2client.client import clean_headers from oauth2client.client import EXPIRY_FORMAT +from oauth2client.client import GoogleCredentials from oauth2client.client import SERVICE_ACCOUNT +from oauth2client.client import TokenRevokeError +from oauth2client.client import _UTCNOW from oauth2client import crypt @@ -426,6 +434,32 @@ def create_scoped(self, scopes): result._private_key_pkcs12 = self._private_key_pkcs12 result._private_key_password = self._private_key_password return result + + def create_with_claims(self, claims): + """Create credentials that specify additional claims. + + Args: + claims: dict, key-value pairs for claims. + + Returns: + ServiceAccountCredentials, a copy of the current service account + credentials with updated claims to use when obtaining access tokens. + """ + new_kwargs = dict(self._kwargs) + new_kwargs.update(claims) + result = self.__class__(self._service_account_email, + self._signer, + scopes=self._scopes, + private_key_id=self._private_key_id, + client_id=self.client_id, + user_agent=self._user_agent, + **new_kwargs) + result.token_uri = self.token_uri + result.revoke_uri = self.revoke_uri + result._private_key_pkcs8_pem = self._private_key_pkcs8_pem + result._private_key_pkcs12 = self._private_key_pkcs12 + result._private_key_password = self._private_key_password + return result def create_delegated(self, sub): """Create credentials that act as domain-wide delegation of authority. @@ -446,18 +480,158 @@ def create_delegated(self, sub): ServiceAccountCredentials, a copy of the current service account updated to act on behalf of ``sub``. """ - new_kwargs = dict(self._kwargs) - new_kwargs['sub'] = sub - result = self.__class__(self._service_account_email, - self._signer, - scopes=self._scopes, - private_key_id=self._private_key_id, - client_id=self.client_id, - user_agent=self._user_agent, - **new_kwargs) + return self.create_with_claims({'sub': sub}) + + +def _datetime_to_secs(utc_time): + # TODO(issue 298): use time_delta.total_seconds() + # time_delta.total_seconds() not supported in Python 2.6 + epoch = datetime.datetime(1970, 1, 1) + time_delta = utc_time - epoch + return time_delta.days * 86400 + time_delta.seconds + + +class _JWTAccessCredentials(ServiceAccountCredentials): + """Self signed JWT credentials. + + Makes an assertion to server using a self signed JWT from service account + credentials. These credentials do NOT use OAuth 2.0 and instead + authenticate directly. + """ + _MAX_TOKEN_LIFETIME_SECS = 3600 + """Max lifetime of the token (one hour, in seconds).""" + + def __init__(self, + service_account_email, + signer, + scopes=None, + private_key_id=None, + client_id=None, + user_agent=None, + additional_claims=None): + if additional_claims is None: + additional_claims = {} + super(_JWTAccessCredentials, self).__init__( + service_account_email, + signer, + private_key_id=private_key_id, + client_id=client_id, + user_agent=user_agent, + **additional_claims) + + def authorize(self, http): + """Authorize an httplib2.Http instance with a JWT assertion. + + Unless specified, the 'aud' of the assertion will be the base + uri of the request. + + Args: + http: An instance of ``httplib2.Http`` or something that acts + like it. + Returns: + A modified instance of http that was passed in. + Example:: + h = httplib2.Http() + h = credentials.authorize(h) + """ + request_orig = http.request + request_auth = super(_JWTAccessCredentials, self).authorize(http).request + + # The closure that will replace 'httplib2.Http.request'. + def new_request(uri, method='GET', body=None, headers=None, + redirections=httplib2.DEFAULT_MAX_REDIRECTS, + connection_type=None): + if 'aud' in self._kwargs: + # Preemptively refresh token, this is not done for OAuth2 + if self.access_token is None or self.access_token_expired: + self.refresh(None) + return request_auth(uri, method, body, + headers, redirections, + connection_type) + else: + # If we don't have an 'aud' (audience) claim, + # create a 1-time token with the uri root as the audience + headers = _initialize_headers(headers) + _apply_user_agent(headers, self.user_agent) + uri_root = uri.split('?', 1)[0] + token, unused_expiry = self._create_token({'aud': uri_root}) + + headers['Authorization'] = 'Bearer ' + token + return request_orig(uri, method, body, + clean_headers(headers), + redirections, connection_type) + + # Replace the request method with our own closure. + http.request = new_request + + return http + + def get_access_token(self, http=None, additional_claims=None): + """Create a signed jwt. + + Args: + http: unused + additional_claims: dict, additional claims to add to + the payload of the JWT. + Returns: + An AccessTokenInfo with the signed jwt + """ + if additional_claims is None: + if self.access_token is None or self.access_token_expired: + self.refresh(None) + return AccessTokenInfo(access_token=self.access_token, + expires_in=self._expires_in()) + else: + # Create a 1 time token + token, unused_expiry = self._create_token(additional_claims) + return AccessTokenInfo(access_token=token, + expires_in=self._MAX_TOKEN_LIFETIME_SECS) + + def revoke(self, http): + """Cannot revoke JWTAccessCredentials tokens.""" + pass + + def create_scoped_required(self): + # JWTAccessCredentials are unscoped by definition + return True + + def create_scoped(self, scopes): + # Returns an OAuth2 credentials with the given scope + result = ServiceAccountCredentials(self._service_account_email, + self._signer, + scopes=scopes, + private_key_id=self._private_key_id, + client_id=self.client_id, + user_agent=self._user_agent, + **self._kwargs) result.token_uri = self.token_uri result.revoke_uri = self.revoke_uri - result._private_key_pkcs8_pem = self._private_key_pkcs8_pem - result._private_key_pkcs12 = self._private_key_pkcs12 - result._private_key_password = self._private_key_password + if self._private_key_pkcs8_pem is not None: + result._private_key_pkcs8_pem = self._private_key_pkcs8_pem + if self._private_key_pkcs12 is not None: + result._private_key_pkcs12 = self._private_key_pkcs12 + if self._private_key_password is not None: + result._private_key_password = self._private_key_password return result + + def refresh(self, http): + self._refresh(None) + + def _refresh(self, http_request): + self.access_token, self.token_expiry = self._create_token() + + def _create_token(self, additional_claims=None): + now = _UTCNOW() + expiry = now + datetime.timedelta(seconds=self._MAX_TOKEN_LIFETIME_SECS) + payload = { + 'iat': _datetime_to_secs(now), + 'exp': _datetime_to_secs(expiry), + 'iss': self._service_account_email, + 'sub': self._service_account_email + } + payload.update(self._kwargs) + if additional_claims is not None: + payload.update(additional_claims) + jwt = crypt.make_signed_jwt(self._signer, payload, + key_id=self._private_key_id) + return jwt.decode('ascii'), expiry diff --git a/tests/test_client.py b/tests/test_client.py index 6e0d997ed..7e398777b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -840,6 +840,20 @@ def test_to_from_json_service_account(self): creds2_vals.pop('_signer') self.assertEqual(creds1_vals, creds2_vals) + def test_to_from_json_service_account_scoped(self): + credentials_file = datafile( + os.path.join('gcloud', _WELL_KNOWN_CREDENTIALS_FILE)) + creds1 = GoogleCredentials.from_stream(credentials_file) + creds1 = creds1.create_scoped(['dummy_scope']) + # Convert to and then back from json. + creds2 = GoogleCredentials.from_json(creds1.to_json()) + + creds1_vals = creds1.__dict__ + creds1_vals.pop('_signer') + creds2_vals = creds2.__dict__ + creds2_vals.pop('_signer') + self.assertEqual(creds1_vals, creds2_vals) + def test_parse_expiry(self): dt = datetime.datetime(2016, 1, 1) parsed_expiry = client._parse_expiry(dt) diff --git a/tests/test_service_account.py b/tests/test_service_account.py index 3d9e7db5c..9697bc567 100644 --- a/tests/test_service_account.py +++ b/tests/test_service_account.py @@ -23,11 +23,13 @@ import rsa import tempfile +import httplib2 import mock import unittest2 from .http_mock import HttpMockSequence from oauth2client import crypt +from oauth2client.service_account import _JWTAccessCredentials from oauth2client.service_account import ServiceAccountCredentials from oauth2client.service_account import SERVICE_ACCOUNT @@ -354,6 +356,232 @@ def test_access_token(self, utcnow): self.assertEqual(credentials.access_token, token2) +TOKEN_LIFE = _JWTAccessCredentials._MAX_TOKEN_LIFETIME_SECS +T1 = 42 +T1_DATE = datetime.datetime(1970, 1, 1, second=T1) +T1_EXPIRY = T1 + TOKEN_LIFE +T1_EXPIRY_DATE = T1_DATE + datetime.timedelta(seconds=TOKEN_LIFE) + +T2 = T1 + 100 +T2_DATE = T1_DATE + datetime.timedelta(seconds=100) +T2_EXPIRY = T2 + TOKEN_LIFE +T2_EXPIRY_DATE = T2_DATE + datetime.timedelta(seconds=TOKEN_LIFE) + +T3 = T1 + TOKEN_LIFE + 1 +T3_DATE = T1_DATE + datetime.timedelta(seconds=TOKEN_LIFE + 1) +T3_EXPIRY = T3 + TOKEN_LIFE +T3_EXPIRY_DATE = T3_DATE + datetime.timedelta(seconds=TOKEN_LIFE) + + +class JWTAccessCredentialsTests(unittest2.TestCase): + + def setUp(self): + self.client_id = '123' + self.service_account_email = 'dummy@google.com' + self.private_key_id = 'ABCDEF' + self.private_key = datafile('pem_from_pkcs12.pem') + self.signer = crypt.Signer.from_string(self.private_key) + self.url = 'https://test.url.com' + self.jwt = _JWTAccessCredentials(self.service_account_email, + self.signer, + private_key_id=self.private_key_id, + client_id=self.client_id, + additional_claims={'aud': self.url}) + + @mock.patch('oauth2client.service_account._UTCNOW') + @mock.patch('oauth2client.client._UTCNOW') + @mock.patch('time.time') + def test_get_access_token_no_claims(self, time, client_utcnow, utcnow): + utcnow.return_value = T1_DATE + client_utcnow.return_value = T1_DATE + time.return_value = T1 + + token_info = self.jwt.get_access_token() + payload = crypt.verify_signed_jwt_with_certs( + token_info.access_token, + {'key': datafile('public_cert.pem')}, audience=self.url) + self.assertEqual(payload['iss'], self.service_account_email) + self.assertEqual(payload['sub'], self.service_account_email) + self.assertEqual(payload['iat'], T1) + self.assertEqual(payload['exp'], T1_EXPIRY) + self.assertEqual(token_info.expires_in, T1_EXPIRY - T1) + + # Verify that we vend the same token after 100 seconds + utcnow.return_value = T2_DATE + client_utcnow.return_value = T2_DATE + token_info = self.jwt.get_access_token() + payload = crypt.verify_signed_jwt_with_certs( + token_info.access_token, + {'key': datafile('public_cert.pem')}, audience=self.url) + self.assertEqual(payload['iat'], T1) + self.assertEqual(payload['exp'], T1_EXPIRY) + self.assertEqual(token_info.expires_in, T1_EXPIRY - T2) + + # Verify that we vend a new token after _MAX_TOKEN_LIFETIME_SECS + utcnow.return_value = T3_DATE + client_utcnow.return_value = T3_DATE + time.return_value = T3 + token_info = self.jwt.get_access_token() + payload = crypt.verify_signed_jwt_with_certs( + token_info.access_token, + {'key': datafile('public_cert.pem')}, audience=self.url) + expires_in = token_info.expires_in + self.assertEqual(payload['iat'], T3) + self.assertEqual(payload['exp'], T3_EXPIRY) + self.assertEqual(expires_in, T3_EXPIRY - T3) + + @mock.patch('oauth2client.service_account._UTCNOW') + @mock.patch('time.time') + def test_get_access_token_additional_claims(self, time, utcnow): + utcnow.return_value = T1_DATE + time.return_value = T1 + + token_info = self.jwt.get_access_token(additional_claims= + {'aud': 'https://test2.url.com', + 'sub': 'dummy2@google.com' + }) + payload = crypt.verify_signed_jwt_with_certs( + token_info.access_token, + {'key' : datafile('public_cert.pem')}, + audience='https://test2.url.com') + expires_in = token_info.expires_in + self.assertEqual(payload['iss'], self.service_account_email) + self.assertEqual(payload['sub'], 'dummy2@google.com') + self.assertEqual(payload['iat'], T1) + self.assertEqual(payload['exp'], T1_EXPIRY) + self.assertEqual(expires_in, T1_EXPIRY - T1) + + def test_revoke(self): + self.jwt.revoke(None) + + def test_create_scoped_required(self): + self.assertTrue(self.jwt.create_scoped_required()) + + def test_create_scoped(self): + self.jwt._private_key_pkcs12 = '' + self.jwt._private_key_password = '' + + new_credentials = self.jwt.create_scoped('dummy_scope') + self.assertNotEqual(self.jwt, new_credentials) + self.assertIsInstance(new_credentials, ServiceAccountCredentials) + self.assertEqual('dummy_scope', new_credentials._scopes) + + @mock.patch('oauth2client.service_account._UTCNOW') + @mock.patch('oauth2client.client._UTCNOW') + @mock.patch('time.time') + def test_authorize_success(self, time, client_utcnow, utcnow): + utcnow.return_value = T1_DATE + client_utcnow.return_value = T1_DATE + time.return_value = T1 + + def mock_request(uri, method='GET', body=None, headers=None, + redirections=0, connection_type=None): + self.assertEqual(uri, self.url) + bearer, token = headers[b'Authorization'].split() + payload = crypt.verify_signed_jwt_with_certs( + token, + {'key': datafile('public_cert.pem')}, + audience=self.url) + self.assertEqual(payload['iss'], self.service_account_email) + self.assertEqual(payload['sub'], self.service_account_email) + self.assertEqual(payload['iat'], T1) + self.assertEqual(payload['exp'], T1_EXPIRY) + self.assertEqual(uri, self.url) + self.assertEqual(bearer, b'Bearer') + return (httplib2.Response({'status': '200'}), b'') + + h = httplib2.Http() + h.request = mock_request + self.jwt.authorize(h) + h.request(self.url) + + # Ensure we use the cached token + utcnow.return_value = T2_DATE + client_utcnow.return_value = T2_DATE + h.request(self.url) + + @mock.patch('oauth2client.service_account._UTCNOW') + @mock.patch('oauth2client.client._UTCNOW') + @mock.patch('time.time') + def test_authorize_no_aud(self, time, client_utcnow, utcnow): + utcnow.return_value = T1_DATE + client_utcnow.return_value = T1_DATE + time.return_value = T1 + + jwt = _JWTAccessCredentials(self.service_account_email, + self.signer, + private_key_id=self.private_key_id, + client_id=self.client_id) + + def mock_request(uri, method='GET', body=None, headers=None, + redirections=0, connection_type=None): + self.assertEqual(uri, self.url) + bearer, token = headers[b'Authorization'].split() + payload = crypt.verify_signed_jwt_with_certs( + token, + {'key': datafile('public_cert.pem')}, + audience=self.url) + self.assertEqual(payload['iss'], self.service_account_email) + self.assertEqual(payload['sub'], self.service_account_email) + self.assertEqual(payload['iat'], T1) + self.assertEqual(payload['exp'], T1_EXPIRY) + self.assertEqual(uri, self.url) + self.assertEqual(bearer, b'Bearer') + return (httplib2.Response({'status': '200'}), b'') + + h = httplib2.Http() + h.request = mock_request + jwt.authorize(h) + h.request(self.url) + + # Ensure we do not cache the token + self.assertIsNone(jwt.access_token) + + @mock.patch('oauth2client.service_account._UTCNOW') + def test_authorize_stale_token(self, utcnow): + utcnow.return_value = T1_DATE + # Create an initial token + h = HttpMockSequence([({'status': '200'}, b''), + ({'status': '200'}, b'')]) + self.jwt.authorize(h) + h.request(self.url) + token_1 = self.jwt.access_token + + # Expire the token + utcnow.return_value = T3_DATE + h.request(self.url) + token_2 = self.jwt.access_token + self.assertEquals(self.jwt.token_expiry, T3_EXPIRY_DATE) + self.assertNotEqual(token_1, token_2) + + @mock.patch('oauth2client.service_account._UTCNOW') + def test_authorize_401(self, utcnow): + utcnow.return_value = T1_DATE + + h = HttpMockSequence([ + ({'status': '200'}, b''), + ({'status': '401'}, b''), + ({'status': '200'}, b'')]) + self.jwt.authorize(h) + h.request(self.url) + token_1 = self.jwt.access_token + + utcnow.return_value = T2_DATE + self.assertEquals(h.request(self.url)[0].status, 200) + token_2 = self.jwt.access_token + # Check the 401 forced a new token + self.assertNotEqual(token_1, token_2) + + @mock.patch('oauth2client.service_account._UTCNOW') + def test_refresh(self, utcnow): + utcnow.return_value = T1_DATE + token_1 = self.jwt.access_token + + utcnow.return_value = T2_DATE + self.jwt.refresh(None) + token_2 = self.jwt.access_token + self.assertEquals(self.jwt.token_expiry, T2_EXPIRY_DATE) + self.assertNotEqual(token_1, token_2) if __name__ == '__main__': # pragma: NO COVER unittest2.main()