diff --git a/flytekit/clients/auth/auth_client.py b/flytekit/clients/auth/auth_client.py index 94afa13612..ec1fd4d3e1 100644 --- a/flytekit/clients/auth/auth_client.py +++ b/flytekit/clients/auth/auth_client.py @@ -269,9 +269,11 @@ def _credentials_from_response(self, auth_token_resp) -> Credentials: raise ValueError('Expected "access_token" in response from oauth server') if "refresh_token" in response_body: refresh_token = response_body["refresh_token"] + if "expires_in" in response_body: + expires_in = response_body["expires_in"] access_token = response_body["access_token"] - return Credentials(access_token, refresh_token, self._endpoint) + return Credentials(access_token, refresh_token, self._endpoint, expires_in=expires_in) def _request_access_token(self, auth_code) -> Credentials: if self._state != auth_code.state: diff --git a/flytekit/clients/auth/authenticator.py b/flytekit/clients/auth/authenticator.py index 1fe0d9711c..9582c901d8 100644 --- a/flytekit/clients/auth/authenticator.py +++ b/flytekit/clients/auth/authenticator.py @@ -1,12 +1,10 @@ -import base64 import logging import subprocess import typing from abc import abstractmethod from dataclasses import dataclass -import requests - +from . import token_client from .auth_client import AuthorizationClient from .exceptions import AccessTokenNotFoundError, AuthenticationError from .keyring import Credentials, KeyringStore @@ -22,6 +20,7 @@ class ClientConfig: authorization_endpoint: str redirect_uri: str client_id: str + device_authorization_endpoint: typing.Optional[str] = None scopes: typing.List[str] = None header_key: str = "authorization" @@ -155,8 +154,6 @@ class ClientCredentialsAuthenticator(Authenticator): This Authenticator uses ClientId and ClientSecret to authenticate """ - _utf_8 = "utf-8" - def __init__( self, endpoint: str, @@ -176,48 +173,6 @@ def __init__( self._client_secret = client_secret super().__init__(endpoint, cfg.header_key or header_key) - @staticmethod - def get_token(token_endpoint: str, authorization_header: str, scopes: typing.List[str]) -> typing.Tuple[str, int]: - """ - :rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration - in seconds - """ - headers = { - "Authorization": authorization_header, - "Cache-Control": "no-cache", - "Accept": "application/json", - "Content-Type": "application/x-www-form-urlencoded", - } - body = { - "grant_type": "client_credentials", - } - if scopes is not None: - body["scope"] = ",".join(scopes) - response = requests.post(token_endpoint, data=body, headers=headers) - if response.status_code != 200: - logging.error("Non-200 ({}) received from IDP: {}".format(response.status_code, response.text)) - raise AuthenticationError("Non-200 received from IDP") - - response = response.json() - return response["access_token"], response["expires_in"] - - @staticmethod - def get_basic_authorization_header(client_id: str, client_secret: str) -> str: - """ - This function transforms the client id and the client secret into a header that conforms with http basic auth. - It joins the id and the secret with a : then base64 encodes it, then adds the appropriate text - - :param client_id: str - :param client_secret: str - :rtype: str - """ - concated = "{}:{}".format(client_id, client_secret) - return "Basic {}".format( - base64.b64encode(concated.encode(ClientCredentialsAuthenticator._utf_8)).decode( - ClientCredentialsAuthenticator._utf_8 - ) - ) - def refresh_credentials(self): """ This function is used by the _handle_rpc_error() decorator, depending on the AUTH_MODE config object. This handler @@ -231,7 +186,56 @@ def refresh_credentials(self): # Note that unlike the Pkce flow, the client ID does not come from Admin. logging.debug(f"Basic authorization flow with client id {self._client_id} scope {scopes}") - authorization_header = self.get_basic_authorization_header(self._client_id, self._client_secret) - token, expires_in = self.get_token(token_endpoint, authorization_header, scopes) + authorization_header = token_client.get_basic_authorization_header(self._client_id, self._client_secret) + token, expires_in = token_client.get_token(token_endpoint, scopes, authorization_header) logging.info("Retrieved new token, expires in {}".format(expires_in)) self._creds = Credentials(token) + + +class DeviceCodeAuthenticator(Authenticator): + """ + This Authenticator implements the Device Code authorization flow useful for headless user authentication. + + Examples described + - https://developer.okta.com/docs/guides/device-authorization-grant/main/ + - https://auth0.com/docs/get-started/authentication-and-authorization-flow/device-authorization-flow#device-flow + """ + + def __init__( + self, + endpoint: str, + cfg_store: ClientConfigStore, + header_key: typing.Optional[str] = None, + audience: typing.Optional[str] = None, + ): + self._audience = audience + cfg = cfg_store.get_client_config() + self._client_id = cfg.client_id + self._device_auth_endpoint = cfg.device_authorization_endpoint + self._scope = cfg.scopes + self._token_endpoint = cfg.token_endpoint + if self._device_auth_endpoint is None: + raise AuthenticationError( + "Device Authentication is not available on the Flyte backend / authentication server" + ) + super().__init__( + endpoint=endpoint, header_key=header_key or cfg.header_key, credentials=KeyringStore.retrieve(endpoint) + ) + + def refresh_credentials(self): + resp = token_client.get_device_code(self._device_auth_endpoint, self._client_id, self._audience, self._scope) + print( + f""" +To Authenticate navigate in a browser to the following URL: {resp.verification_uri} and enter code: {resp.user_code} +OR copy paste the following URL: {resp.verification_uri_complete} + """ + ) + try: + # Currently the refresh token is not retreived. We may want to add support for refreshTokens so that + # access tokens can be refreshed for once authenticated machines + token, expires_in = token_client.poll_token_endpoint(resp, self._token_endpoint, client_id=self._client_id) + self._creds = Credentials(access_token=token, expires_in=expires_in, for_endpoint=self._endpoint) + KeyringStore.store(self._creds) + except Exception: + KeyringStore.delete(self._endpoint) + raise diff --git a/flytekit/clients/auth/exceptions.py b/flytekit/clients/auth/exceptions.py index 6e790e47a4..5086c5b6e1 100644 --- a/flytekit/clients/auth/exceptions.py +++ b/flytekit/clients/auth/exceptions.py @@ -12,3 +12,11 @@ class AuthenticationError(RuntimeError): """ pass + + +class AuthenticationPending(RuntimeError): + """ + This is raised if the token endpoint returns authentication pending + """ + + pass diff --git a/flytekit/clients/auth/keyring.py b/flytekit/clients/auth/keyring.py index c2b19c46b6..831558f4c2 100644 --- a/flytekit/clients/auth/keyring.py +++ b/flytekit/clients/auth/keyring.py @@ -15,6 +15,7 @@ class Credentials(object): access_token: str refresh_token: str = "na" for_endpoint: str = "flyte-default" + expires_in: typing.Optional[int] = None class KeyringStore: diff --git a/flytekit/clients/auth/token_client.py b/flytekit/clients/auth/token_client.py new file mode 100644 index 0000000000..e7e55f74a9 --- /dev/null +++ b/flytekit/clients/auth/token_client.py @@ -0,0 +1,151 @@ +import base64 +import enum +import logging +import time +import typing +from dataclasses import dataclass +from datetime import datetime, timedelta + +import requests + +from flytekit.clients.auth.exceptions import AuthenticationError, AuthenticationPending + +utf_8 = "utf-8" + +# Errors that Token endpoint will return +error_slow_down = "slow_down" +error_auth_pending = "authorization_pending" + + +# Grant Types +class GrantType(str, enum.Enum): + CLIENT_CREDS = "client_credentials" + DEVICE_CODE = "urn:ietf:params:oauth:grant-type:device_code" + + +@dataclass +class DeviceCodeResponse: + """ + Response from device auth flow endpoint + {'device_code': 'code', + 'user_code': 'BNDJJFXL', + 'verification_uri': 'url', + 'verification_uri_complete': 'url', + 'expires_in': 600, + 'interval': 5} + """ + + device_code: str + user_code: str + verification_uri: str + verification_uri_complete: str + expires_in: int + interval: int + + @classmethod + def from_json_response(cls, j: typing.Dict) -> "DeviceCodeResponse": + return cls( + device_code=j["device_code"], + user_code=j["user_code"], + verification_uri=j["verification_uri"], + verification_uri_complete=j["verification_uri_complete"], + expires_in=j["expires_in"], + interval=j["interval"], + ) + + +def get_basic_authorization_header(client_id: str, client_secret: str) -> str: + """ + This function transforms the client id and the client secret into a header that conforms with http basic auth. + It joins the id and the secret with a : then base64 encodes it, then adds the appropriate text + + :param client_id: str + :param client_secret: str + :rtype: str + """ + concated = "{}:{}".format(client_id, client_secret) + return "Basic {}".format(base64.b64encode(concated.encode(utf_8)).decode(utf_8)) + + +def get_token( + token_endpoint: str, + scopes: typing.Optional[typing.List[str]] = None, + authorization_header: typing.Optional[str] = None, + client_id: typing.Optional[str] = None, + device_code: typing.Optional[str] = None, + grant_type: GrantType = GrantType.CLIENT_CREDS, +) -> typing.Tuple[str, int]: + """ + :rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration + in seconds + """ + headers = { + "Cache-Control": "no-cache", + "Accept": "application/json", + "Content-Type": "application/x-www-form-urlencoded", + } + if authorization_header: + headers["Authorization"] = authorization_header + body = { + "grant_type": grant_type.value, + } + if client_id: + body["client_id"] = client_id + if device_code: + body["device_code"] = device_code + if scopes is not None: + body["scope"] = ",".join(scopes) + + response = requests.post(token_endpoint, data=body, headers=headers) + if not response.ok: + j = response.json() + if "error" in j: + err = j["error"] + if err == error_auth_pending or err == error_slow_down: + raise AuthenticationPending(f"Token not yet available, try again in some time {err}") + logging.error("Status Code ({}) received from IDP: {}".format(response.status_code, response.text)) + raise AuthenticationError("Status Code ({}) received from IDP: {}".format(response.status_code, response.text)) + + j = response.json() + return j["access_token"], j["expires_in"] + + +def get_device_code( + device_auth_endpoint: str, + client_id: str, + audience: typing.Optional[str] = None, + scope: typing.Optional[typing.List[str]] = None, +) -> DeviceCodeResponse: + """ + Retrieves the device Authentication code that can be done to authenticate the request using a browser on a + separate device + """ + payload = {"client_id": client_id, "scope": scope, "audience": audience} + resp = requests.post(device_auth_endpoint, payload) + if not resp.ok: + raise AuthenticationError(f"Unable to retrieve Device Authentication Code for {payload}, Reason {resp.reason}") + return DeviceCodeResponse.from_json_response(resp.json()) + + +def poll_token_endpoint(resp: DeviceCodeResponse, token_endpoint: str, client_id: str) -> typing.Tuple[str, int]: + tick = datetime.now() + interval = timedelta(seconds=resp.interval) + end_time = tick + timedelta(seconds=resp.expires_in) + while tick < end_time: + try: + access_token, expires_in = get_token( + token_endpoint, + grant_type=GrantType.DEVICE_CODE, + client_id=client_id, + device_code=resp.device_code, + ) + print("Authentication successful!") + return access_token, expires_in + except AuthenticationPending: + ... + except Exception: + raise + print("Authentication Pending...") + time.sleep(interval.total_seconds()) + tick = tick + interval + raise AuthenticationError("Authentication failed!") diff --git a/flytekit/clients/auth_helper.py b/flytekit/clients/auth_helper.py index 3a5464fd6e..93bd883324 100644 --- a/flytekit/clients/auth_helper.py +++ b/flytekit/clients/auth_helper.py @@ -12,6 +12,7 @@ ClientConfigStore, ClientCredentialsAuthenticator, CommandAuthenticator, + DeviceCodeAuthenticator, PKCEAuthenticator, ) from flytekit.clients.grpc_utils.auth_interceptor import AuthUnaryInterceptor @@ -41,6 +42,7 @@ def get_client_config(self) -> ClientConfig: client_id=public_client_config.client_id, scopes=public_client_config.scopes, header_key=public_client_config.authorization_metadata_key or None, + device_authorization_endpoint=oauth2_metadata.device_authorization_endpoint, ) @@ -79,6 +81,8 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth command=cfg.command, header_key=client_cfg.header_key if client_cfg else None, ) + elif cfg_auth == AuthType.DEVICEFLOW: + return DeviceCodeAuthenticator(endpoint=cfg.endpoint, cfg_store=cfg_store, audience=cfg.audience) else: raise ValueError( f"Invalid auth mode [{cfg_auth}] specified." f"Please update the creds config to use a valid value" diff --git a/flytekit/clis/sdk_in_container/utils.py b/flytekit/clis/sdk_in_container/utils.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index 6485f3a9d5..7d259b760e 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -344,6 +344,7 @@ class AuthType(enum.Enum): CLIENTSECRET = "ClientSecret" PKCE = "Pkce" EXTERNALCOMMAND = "ExternalCommand" + DEVICEFLOW = "DeviceFlow" @dataclass(init=True, repr=True, eq=True, frozen=True) @@ -376,6 +377,7 @@ class PlatformConfig(object): client_credentials_secret: typing.Optional[str] = None scopes: List[str] = field(default_factory=list) auth_mode: AuthType = AuthType.STANDARD + audience: typing.Optional[str] = None rpc_retries: int = 3 @classmethod diff --git a/tests/flytekit/unit/clients/auth/test_authenticator.py b/tests/flytekit/unit/clients/auth/test_authenticator.py index 5c1586970a..fdbddb2ebe 100644 --- a/tests/flytekit/unit/clients/auth/test_authenticator.py +++ b/tests/flytekit/unit/clients/auth/test_authenticator.py @@ -8,10 +8,12 @@ ClientConfig, ClientCredentialsAuthenticator, CommandAuthenticator, + DeviceCodeAuthenticator, PKCEAuthenticator, StaticClientConfigStore, ) from flytekit.clients.auth.exceptions import AuthenticationError +from flytekit.clients.auth.token_client import DeviceCodeResponse ENDPOINT = "example.com" @@ -65,24 +67,8 @@ def test_command_authenticator(mock_subprocess: MagicMock): authn.refresh_credentials() -def test_get_basic_authorization_header(): - header = ClientCredentialsAuthenticator.get_basic_authorization_header("client_id", "abc") - assert header == "Basic Y2xpZW50X2lkOmFiYw==" - - -@patch("flytekit.clients.auth.authenticator.requests") -def test_get_token(mock_requests): - response = MagicMock() - response.status_code = 200 - response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") - mock_requests.post.return_value = response - access, expiration = ClientCredentialsAuthenticator.get_token("https://corp.idp.net", "abc123", ["my_scope"]) - assert access == "abc" - assert expiration == 60 - - -@patch("flytekit.clients.auth.authenticator.requests") -def test_client_creds_authenticator_without_custom_scopes(mock_requests): +@patch("flytekit.clients.auth.token_client.requests") +def test_client_creds_authenticator(mock_requests): authn = ClientCredentialsAuthenticator( ENDPOINT, client_id="client", client_secret="secret", cfg_store=static_cfg_store ) @@ -93,12 +79,43 @@ def test_client_creds_authenticator_without_custom_scopes(mock_requests): mock_requests.post.return_value = response authn.refresh_credentials() expected_scopes = static_cfg_store.get_client_config().scopes - assert authn._creds assert authn._scopes == expected_scopes -@patch("flytekit.clients.auth.authenticator.requests") +@patch("flytekit.clients.auth.authenticator.KeyringStore") +@patch("flytekit.clients.auth.token_client.get_device_code") +@patch("flytekit.clients.auth.token_client.poll_token_endpoint") +def test_device_flow_authenticator(poll_mock: MagicMock, device_mock: MagicMock, mock_keyring: MagicMock): + with pytest.raises(AuthenticationError): + DeviceCodeAuthenticator( + ENDPOINT, + static_cfg_store, + audience="x", + ) + + cfg_store = StaticClientConfigStore( + ClientConfig( + token_endpoint="token_endpoint", + authorization_endpoint="auth_endpoint", + redirect_uri="redirect_uri", + client_id="client", + device_authorization_endpoint="dev", + ) + ) + authn = DeviceCodeAuthenticator( + ENDPOINT, + cfg_store, + audience="x", + ) + + device_mock.return_value = DeviceCodeResponse("x", "y", "s", "m", 1000, 0) + poll_mock.return_value = ("access", 100) + authn.refresh_credentials() + assert authn._creds + + +@patch("flytekit.clients.auth.token_client.requests") def test_client_creds_authenticator_with_custom_scopes(mock_requests): expected_scopes = ["foo", "baz"] authn = ClientCredentialsAuthenticator( diff --git a/tests/flytekit/unit/clients/auth/test_token_client.py b/tests/flytekit/unit/clients/auth/test_token_client.py new file mode 100644 index 0000000000..6e56c351bc --- /dev/null +++ b/tests/flytekit/unit/clients/auth/test_token_client.py @@ -0,0 +1,78 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest + +from flytekit.clients.auth.exceptions import AuthenticationError +from flytekit.clients.auth.token_client import ( + DeviceCodeResponse, + error_auth_pending, + get_basic_authorization_header, + get_device_code, + get_token, + poll_token_endpoint, +) + + +def test_get_basic_authorization_header(): + header = get_basic_authorization_header("client_id", "abc") + assert header == "Basic Y2xpZW50X2lkOmFiYw==" + + +@patch("flytekit.clients.auth.token_client.requests") +def test_get_token(mock_requests): + response = MagicMock() + response.status_code = 200 + response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") + mock_requests.post.return_value = response + access, expiration = get_token("https://corp.idp.net", client_id="abc123", scopes=["my_scope"]) + assert access == "abc" + assert expiration == 60 + + +@patch("flytekit.clients.auth.token_client.requests") +def test_get_device_code(mock_requests): + response = MagicMock() + response.ok = False + mock_requests.post.return_value = response + with pytest.raises(AuthenticationError): + get_device_code("test.com", "test") + + response.ok = True + response.json.return_value = { + "device_code": "code", + "user_code": "BNDJJFXL", + "verification_uri": "url", + "verification_uri_complete": "url", + "expires_in": 600, + "interval": 5, + } + mock_requests.post.return_value = response + c = get_device_code("test.com", "test") + assert c + assert c.device_code == "code" + + +@patch("flytekit.clients.auth.token_client.requests") +def test_poll_token_endpoint(mock_requests): + response = MagicMock() + response.ok = False + response.json.return_value = {"error": error_auth_pending} + mock_requests.post.return_value = response + + r = DeviceCodeResponse( + device_code="x", user_code="y", verification_uri="v", verification_uri_complete="v1", expires_in=1, interval=1 + ) + with pytest.raises(AuthenticationError): + poll_token_endpoint(r, "test.com", "test") + + response = MagicMock() + response.ok = True + response.json.return_value = {"access_token": "abc", "expires_in": 60} + mock_requests.post.return_value = response + r = DeviceCodeResponse( + device_code="x", user_code="y", verification_uri="v", verification_uri_complete="v1", expires_in=1, interval=0 + ) + t, e = poll_token_endpoint(r, "test.com", "test") + assert t + assert e diff --git a/tests/flytekit/unit/clients/test_auth_helper.py b/tests/flytekit/unit/clients/test_auth_helper.py index 8f14de730e..3bd57918f4 100644 --- a/tests/flytekit/unit/clients/test_auth_helper.py +++ b/tests/flytekit/unit/clients/test_auth_helper.py @@ -9,6 +9,7 @@ ClientConfigStore, ClientCredentialsAuthenticator, CommandAuthenticator, + DeviceCodeAuthenticator, PKCEAuthenticator, ) from flytekit.clients.auth.exceptions import AuthenticationError @@ -31,6 +32,8 @@ OAUTH_AUTHORIZE = "https://your.domain.io/oauth2/authorize" +DEVICE_AUTH_ENDPOINT = "https://your.domain.io/..." + def get_auth_service_mock() -> MagicMock: auth_stub_mock = MagicMock() @@ -66,13 +69,14 @@ def test_remote_client_config_store(mock_auth_service: MagicMock): assert ccfg.authorization_endpoint == OAUTH_AUTHORIZE -def get_client_config() -> ClientConfigStore: +def get_client_config(**kwargs) -> ClientConfigStore: cfg_store = MagicMock() cfg_store.get_client_config.return_value = ClientConfig( token_endpoint=TOKEN_ENDPOINT, authorization_endpoint=OAUTH_AUTHORIZE, redirect_uri=REDIRECT_URI, client_id=CLIENT_ID, + **kwargs ) return cfg_store @@ -135,6 +139,15 @@ def test_get_authenticator_cmd(): assert authn._cmd == ["echo"] +def test_get_authenticator_deviceflow(): + cfg = PlatformConfig(auth_mode=AuthType.DEVICEFLOW) + with pytest.raises(AuthenticationError): + get_authenticator(cfg, get_client_config()) + + authn = get_authenticator(cfg, get_client_config(device_authorization_endpoint=DEVICE_AUTH_ENDPOINT)) + assert isinstance(authn, DeviceCodeAuthenticator) + + def test_wrap_exceptions_channel(): ch = MagicMock() out_ch = wrap_exceptions_channel(PlatformConfig(), ch)