diff --git a/dev-requirements.in b/dev-requirements.in index 755231ed71..25e6cf51e7 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -12,6 +12,7 @@ codespell google-cloud-bigquery google-cloud-bigquery-storage IPython +keyrings.alt # Only install tensorflow if not running on an arm Mac. tensorflow==2.8.1; platform_machine!='arm64' or platform_system!='Darwin' diff --git a/flytekit/clis/auth/__init__.py b/flytekit/clients/auth/__init__.py similarity index 100% rename from flytekit/clis/auth/__init__.py rename to flytekit/clients/auth/__init__.py diff --git a/flytekit/clis/auth/auth.py b/flytekit/clients/auth/auth_client.py similarity index 56% rename from flytekit/clis/auth/auth.py rename to flytekit/clients/auth/auth_client.py index f54379485a..94afa13612 100644 --- a/flytekit/clis/auth/auth.py +++ b/flytekit/clients/auth/auth_client.py @@ -1,32 +1,31 @@ +from __future__ import annotations + import base64 as _base64 import hashlib as _hashlib import http.server as _BaseHTTPServer +import logging +import multiprocessing import os as _os import re as _re +import typing import urllib.parse as _urlparse import webbrowser as _webbrowser +from dataclasses import dataclass from http import HTTPStatus as _StatusCodes -from multiprocessing import get_context as _mp_get_context +from multiprocessing import get_context from urllib.parse import urlencode as _urlencode -import keyring as _keyring import requests as _requests -from flytekit.loggers import auth_logger +from .default_html import get_default_success_html +from .exceptions import AccessTokenNotFoundError +from .keyring import Credentials _code_verifier_length = 64 _random_seed_length = 40 _utf_8 = "utf-8" -# Identifies the service used for storing passwords in keyring -_keyring_service_name = "flyteauth" -# Identifies the key used for storing and fetching from keyring. In our case, instead of a username as the keyring docs -# suggest, we are storing a user's oidc. -_keyring_access_token_storage_key = "access_token" -_keyring_refresh_token_storage_key = "refresh_token" - - def _generate_code_verifier(): """ Generates a 'code_verifier' as described in https://tools.ietf.org/html/rfc7636#section-4.1 @@ -77,6 +76,17 @@ def state(self): return self._state +@dataclass +class EndpointMetadata(object): + """ + This class can be used to control the rendering of the page on login successful or failure + """ + + endpoint: str + success_html: typing.Optional[bytes] = None + failure_html: typing.Optional[bytes] = None + + class OAuthCallbackHandler(_BaseHTTPServer.BaseHTTPRequestHandler): """ A simple wrapper around BaseHTTPServer.BaseHTTPRequestHandler that handles a callback URL that accepts an @@ -87,12 +97,16 @@ def do_GET(self): url = _urlparse.urlparse(self.path) if url.path.strip("/") == self.server.redirect_path.strip("/"): self.send_response(_StatusCodes.OK) + self.send_header("Content-type", "text/html") self.end_headers() self.handle_login(dict(_urlparse.parse_qsl(url.query))) + if self.server.remote_metadata.success_html is None: + self.wfile.write(bytes(get_default_success_html(self.server.remote_metadata.endpoint), "utf-8")) + self.wfile.flush() else: self.send_response(_StatusCodes.NOT_FOUND) - def handle_login(self, data): + def handle_login(self, data: dict): self.server.handle_authorization_code(AuthorizationCode(data["code"], data["state"])) @@ -104,49 +118,97 @@ class OAuthHTTPServer(_BaseHTTPServer.HTTPServer): def __init__( self, - server_address, - RequestHandlerClass, - bind_and_activate=True, - redirect_path=None, - queue=None, + server_address: typing.Tuple[str, int], + remote_metadata: EndpointMetadata, + request_handler_class: typing.Type[_BaseHTTPServer.BaseHTTPRequestHandler], + bind_and_activate: bool = True, + redirect_path: str = None, + queue: multiprocessing.Queue = None, ): - _BaseHTTPServer.HTTPServer.__init__(self, server_address, RequestHandlerClass, bind_and_activate) + _BaseHTTPServer.HTTPServer.__init__(self, server_address, request_handler_class, bind_and_activate) self._redirect_path = redirect_path + self._remote_metadata = remote_metadata self._auth_code = None self._queue = queue @property - def redirect_path(self): + def redirect_path(self) -> str: return self._redirect_path - def handle_authorization_code(self, auth_code): + @property + def remote_metadata(self) -> EndpointMetadata: + return self._remote_metadata + + def handle_authorization_code(self, auth_code: str): self._queue.put(auth_code) self.server_close() - def handle_request(self, queue=None): + def handle_request(self, queue: multiprocessing.Queue = None) -> typing.Any: self._queue = queue return super().handle_request() -class Credentials(object): - def __init__(self, access_token=None): - self._access_token = access_token +class _SingletonPerEndpoint(type): + """ + A metaclass to create per endpoint singletons for AuthorizationClient objects + """ + + _instances: typing.Dict[str, AuthorizationClient] = {} - @property - def access_token(self): - return self._access_token + def __call__(cls, *args, **kwargs): + endpoint = "" + if args: + endpoint = args[0] + elif "auth_endpoint" in kwargs: + endpoint = kwargs["auth_endpoint"] + else: + raise ValueError("parameter auth_endpoint is required") + if endpoint not in cls._instances: + cls._instances[endpoint] = super(_SingletonPerEndpoint, cls).__call__(*args, **kwargs) + return cls._instances[endpoint] -class AuthorizationClient(object): +class AuthorizationClient(metaclass=_SingletonPerEndpoint): + """ + Authorization client that stores the credentials in keyring and uses oauth2 standard flow to retrieve the + credentials. NOTE: This will open an web browser to retreive the credentials. + """ + def __init__( self, - auth_endpoint=None, - token_endpoint=None, - scopes=None, - client_id=None, - redirect_uri=None, + endpoint: str, + auth_endpoint: str, + token_endpoint: str, + scopes: typing.Optional[typing.List[str]] = None, + client_id: typing.Optional[str] = None, + redirect_uri: typing.Optional[str] = None, + endpoint_metadata: typing.Optional[EndpointMetadata] = None, + verify: typing.Optional[typing.Union[bool, str]] = None, ): + """ + Create new AuthorizationClient + + :param endpoint: str endpoint to connect to + :param auth_endpoint: str endpoint where auth metadata can be found + :param token_endpoint: str endpoint to retrieve token from + :param scopes: list[str] oauth2 scopes + :param client_id + :param verify: (optional) Either a boolean, in which case it controls whether we verify + the server's TLS certificate, or a string, in which case it must be a path + to a CA bundle to use. Defaults to ``True``. When set to + ``False``, requests will accept any TLS certificate presented by + the server, and will ignore hostname mismatches and/or expired + certificates, which will make your application vulnerable to + man-in-the-middle (MitM) attacks. Setting verify to ``False`` + may be useful during local development or testing. + """ + self._endpoint = endpoint self._auth_endpoint = auth_endpoint + if endpoint_metadata is None: + remote_url = _urlparse.urlparse(self._auth_endpoint) + self._remote = EndpointMetadata(endpoint=remote_url.hostname) + else: + self._remote = endpoint_metadata self._token_endpoint = token_endpoint self._client_id = client_id self._scopes = scopes or [] @@ -156,10 +218,8 @@ def __init__( self._code_challenge = code_challenge state = _generate_state_parameter() self._state = state - self._credentials = None - self._refresh_token = None + self._verify = verify self._headers = {"content-type": "application/x-www-form-urlencoded"} - self._expired = False self._params = { "client_id": client_id, # This must match the Client ID of the OAuth application. @@ -174,55 +234,27 @@ def __init__( "code_challenge_method": "S256", } - # Prefer to use already-fetched token values when they've been set globally. - self.set_tokens_from_store() - def __repr__(self): return f"AuthorizationClient({self._auth_endpoint}, {self._token_endpoint}, {self._client_id}, {self._scopes}, {self._redirect_uri})" - def set_tokens_from_store(self): - self._refresh_token = _keyring.get_password(_keyring_service_name, _keyring_refresh_token_storage_key) - access_token = _keyring.get_password(_keyring_service_name, _keyring_access_token_storage_key) - if access_token: - self._credentials = Credentials(access_token=access_token) - - @property - def has_valid_credentials(self) -> bool: - return self._credentials is not None - - def start_authorization_flow(self): - # In the absence of globally-set token values, initiate the token request flow - ctx = _mp_get_context("fork") - q = ctx.Queue() - - # First prepare the callback server in the background - server = self._create_callback_server() - server_process = ctx.Process(target=server.handle_request, args=(q,)) - server_process.daemon = True - server_process.start() - - # Send the call to request the authorization code in the background - self._request_authorization_code() - - # Request the access token once the auth code has been received. - auth_code = q.get() - server_process.terminate() - self.request_access_token(auth_code) - def _create_callback_server(self): server_url = _urlparse.urlparse(self._redirect_uri) server_address = (server_url.hostname, server_url.port) - return OAuthHTTPServer(server_address, OAuthCallbackHandler, redirect_path=server_url.path) + return OAuthHTTPServer( + server_address, + self._remote, + OAuthCallbackHandler, + redirect_path=server_url.path, + ) def _request_authorization_code(self): scheme, netloc, path, _, _, _ = _urlparse.urlparse(self._auth_endpoint) query = _urlencode(self._params) endpoint = _urlparse.urlunparse((scheme, netloc, path, None, query, None)) - auth_logger.debug(f"Requesting authorization code through {endpoint}") + logging.debug(f"Requesting authorization code through {endpoint}") _webbrowser.open_new_tab(endpoint) - def _initialize_credentials(self, auth_token_resp): - + def _credentials_from_response(self, auth_token_resp) -> Credentials: """ The auth_token_resp body is of the form: { @@ -232,21 +264,16 @@ def _initialize_credentials(self, auth_token_resp): } """ response_body = auth_token_resp.json() + refresh_token = None if "access_token" not in response_body: raise ValueError('Expected "access_token" in response from oauth server') if "refresh_token" in response_body: - self._refresh_token = response_body["refresh_token"] - _keyring.set_password( - _keyring_service_name, _keyring_refresh_token_storage_key, response_body["refresh_token"] - ) - + refresh_token = response_body["refresh_token"] access_token = response_body["access_token"] - _keyring.set_password(_keyring_service_name, _keyring_access_token_storage_key, access_token) - # Once keyring credentials have been updated, get the singleton AuthorizationClient to read them again. - self.set_tokens_from_store() + return Credentials(access_token, refresh_token, self._endpoint) - def request_access_token(self, auth_code): + def _request_access_token(self, auth_code) -> Credentials: if self._state != auth_code.state: raise ValueError(f"Unexpected state parameter [{auth_code.state}] passed") self._params.update( @@ -262,6 +289,7 @@ def request_access_token(self, auth_code): data=self._params, headers=self._headers, allow_redirects=False, + verify=self._verify, ) if resp.status_code != _StatusCodes.OK: # TODO: handle expected (?) error cases: @@ -269,38 +297,53 @@ def request_access_token(self, auth_code): raise Exception( "Failed to request access token with response: [{}] {}".format(resp.status_code, resp.content) ) - self._initialize_credentials(resp) + return self._credentials_from_response(resp) + + def get_creds_from_remote(self) -> Credentials: + """ + This is the entrypoint method. It will kickoff the full authentication flow and trigger a web-browser to + retrieve credentials + """ + # In the absence of globally-set token values, initiate the token request flow + ctx = get_context("fork") + q = ctx.Queue() + + # First prepare the callback server in the background + server = self._create_callback_server() + + server_process = ctx.Process(target=server.handle_request, args=(q,)) + server_process.daemon = True - def refresh_access_token(self): - if self._refresh_token is None: + try: + server_process.start() + + # Send the call to request the authorization code in the background + self._request_authorization_code() + + # Request the access token once the auth code has been received. + auth_code = q.get() + return self._request_access_token(auth_code) + finally: + server_process.terminate() + + def refresh_access_token(self, credentials: Credentials) -> Credentials: + if credentials.refresh_token is None: raise ValueError("no refresh token available with which to refresh authorization credentials") resp = _requests.post( url=self._token_endpoint, - data={"grant_type": "refresh_token", "client_id": self._client_id, "refresh_token": self._refresh_token}, + data={ + "grant_type": "refresh_token", + "client_id": self._client_id, + "refresh_token": credentials.refresh_token, + }, headers=self._headers, allow_redirects=False, + verify=self._verify, ) if resp.status_code != _StatusCodes.OK: - self._expired = True # In the absence of a successful response, assume the refresh token is expired. This should indicate # to the caller that the AuthorizationClient is defunct and a new one needs to be re-initialized. + raise AccessTokenNotFoundError(f"Non-200 returned from refresh token endpoint {resp.status_code}") - _keyring.delete_password(_keyring_service_name, _keyring_access_token_storage_key) - _keyring.delete_password(_keyring_service_name, _keyring_refresh_token_storage_key) - raise ValueError(f"Non-200 returned from refresh token endpoint {resp.status_code}") - self._initialize_credentials(resp) - - @property - def credentials(self): - """ - :return flytekit.clis.auth.auth.Credentials: - """ - return self._credentials - - @property - def expired(self): - """ - :return bool: - """ - return self._expired + return self._credentials_from_response(resp) diff --git a/flytekit/clients/auth/authenticator.py b/flytekit/clients/auth/authenticator.py new file mode 100644 index 0000000000..183c1787cd --- /dev/null +++ b/flytekit/clients/auth/authenticator.py @@ -0,0 +1,235 @@ +import base64 +import logging +import subprocess +import typing +from abc import abstractmethod +from dataclasses import dataclass + +import requests + +from .auth_client import AuthorizationClient +from .exceptions import AccessTokenNotFoundError, AuthenticationError +from .keyring import Credentials, KeyringStore + + +@dataclass +class ClientConfig: + """ + Client Configuration that is needed by the authenticator + """ + + token_endpoint: str + authorization_endpoint: str + redirect_uri: str + client_id: str + scopes: typing.List[str] = None + header_key: str = "authorization" + + +class ClientConfigStore(object): + """ + Client Config store retrieve client config. this can be done in multiple ways + """ + + @abstractmethod + def get_client_config(self) -> ClientConfig: + ... + + +class StaticClientConfigStore(ClientConfigStore): + def __init__(self, cfg: ClientConfig): + self._cfg = cfg + + def get_client_config(self) -> ClientConfig: + return self._cfg + + +class Authenticator(object): + """ + Base authenticator for all authentication flows + """ + + def __init__(self, endpoint: str, header_key: str, credentials: Credentials = None): + self._endpoint = endpoint + self._creds = credentials + self._header_key = header_key if header_key else "authorization" + + def get_credentials(self) -> Credentials: + return self._creds + + def _set_credentials(self, creds): + self._creds = creds + + def _set_header_key(self, h: str): + self._header_key = h + + def fetch_grpc_call_auth_metadata(self) -> typing.Optional[typing.Tuple[str, str]]: + if self._creds: + return self._header_key, f"Bearer {self._creds.access_token}" + return None + + @abstractmethod + def refresh_credentials(self): + ... + + +class PKCEAuthenticator(Authenticator): + """ + This Authenticator encapsulates the entire PKCE flow and automatically opens a browser window for login + """ + + def __init__( + self, + endpoint: str, + cfg_store: ClientConfigStore, + header_key: typing.Optional[str] = None, + verify: typing.Optional[typing.Union[bool, str]] = None, + ): + """ + Initialize with default creds from KeyStore using the endpoint name + """ + super().__init__(endpoint, header_key, KeyringStore.retrieve(endpoint)) + self._cfg_store = cfg_store + self._auth_client = None + self._verify = verify + + def _initialize_auth_client(self): + if not self._auth_client: + cfg = self._cfg_store.get_client_config() + self._set_header_key(cfg.header_key) + self._auth_client = AuthorizationClient( + endpoint=self._endpoint, + redirect_uri=cfg.redirect_uri, + client_id=cfg.client_id, + scopes=cfg.scopes, + auth_endpoint=cfg.authorization_endpoint, + token_endpoint=cfg.token_endpoint, + verify=self._verify, + ) + + def refresh_credentials(self): + """ """ + self._initialize_auth_client() + if self._creds: + """We have an access token so lets try to refresh it""" + try: + self._creds = self._auth_client.refresh_access_token(self._creds) + if self._creds: + KeyringStore.store(self._creds) + return + except AccessTokenNotFoundError: + logging.warning("Failed to refresh token. Kicking off a full authorization flow.") + KeyringStore.delete(self._endpoint) + + self._creds = self._auth_client.get_creds_from_remote() + KeyringStore.store(self._creds) + + +class CommandAuthenticator(Authenticator): + """ + This Authenticator retreives access_token using the provided command + """ + + def __init__(self, command: typing.List[str], header_key: str = None): + self._cmd = command + if not self._cmd: + raise AuthenticationError("Command cannot be empty for command authenticator") + super().__init__(None, header_key) + + def refresh_credentials(self): + """ + This function is used when the configuration value for AUTH_MODE is set to 'external_process'. + It reads an id token generated by an external process started by running the 'command'. + """ + logging.debug("Starting external process to generate id token. Command {}".format(self._cmd)) + try: + output = subprocess.run(self._cmd, capture_output=True, text=True, check=True) + except subprocess.CalledProcessError as e: + logging.error("Failed to generate token from command {}".format(self._cmd)) + raise AuthenticationError("Problems refreshing token with command: " + str(e)) + self._creds = Credentials(output.stdout.strip()) + + +class ClientCredentialsAuthenticator(Authenticator): + """ + This Authenticator uses ClientId and ClientSecret to authenticate + """ + + _utf_8 = "utf-8" + + def __init__( + self, + endpoint: str, + client_id: str, + client_secret: str, + cfg_store: ClientConfigStore, + header_key: str = None, + ): + if not client_id or not client_secret: + raise ValueError("Client ID and Client SECRET both are required.") + cfg = cfg_store.get_client_config() + self._token_endpoint = cfg.token_endpoint + self._scopes = cfg.scopes + self._client_id = client_id + self._client_secret = client_secret + super().__init__(endpoint, cfg.header_key or header_key) + + @staticmethod + def get_token(token_endpoint: str, authorization_header: str, scopes: typing.List[str]) -> typing.Tuple[str, int]: + """ + :rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration + in seconds + """ + headers = { + "Authorization": authorization_header, + "Cache-Control": "no-cache", + "Accept": "application/json", + "Content-Type": "application/x-www-form-urlencoded", + } + body = { + "grant_type": "client_credentials", + } + if scopes is not None: + body["scope"] = ",".join(scopes) + response = requests.post(token_endpoint, data=body, headers=headers) + if response.status_code != 200: + logging.error("Non-200 ({}) received from IDP: {}".format(response.status_code, response.text)) + raise AuthenticationError("Non-200 received from IDP") + + response = response.json() + return response["access_token"], response["expires_in"] + + @staticmethod + def get_basic_authorization_header(client_id: str, client_secret: str) -> str: + """ + This function transforms the client id and the client secret into a header that conforms with http basic auth. + It joins the id and the secret with a : then base64 encodes it, then adds the appropriate text + + :param client_id: str + :param client_secret: str + :rtype: str + """ + concated = "{}:{}".format(client_id, client_secret) + return "Basic {}".format( + base64.b64encode(concated.encode(ClientCredentialsAuthenticator._utf_8)).decode( + ClientCredentialsAuthenticator._utf_8 + ) + ) + + def refresh_credentials(self): + """ + This function is used by the _handle_rpc_error() decorator, depending on the AUTH_MODE config object. This handler + is meant for SDK use-cases of auth (like pyflyte, or when users call SDK functions that require access to Admin, + like when waiting for another workflow to complete from within a task). This function uses basic auth, which means + the credentials for basic auth must be present from wherever this code is running. + + """ + token_endpoint = self._token_endpoint + scopes = self._scopes + + # Note that unlike the Pkce flow, the client ID does not come from Admin. + logging.debug(f"Basic authorization flow with client id {self._client_id} scope {scopes}") + authorization_header = self.get_basic_authorization_header(self._client_id, self._client_secret) + token, expires_in = self.get_token(token_endpoint, authorization_header, scopes) + logging.info("Retrieved new token, expires in {}".format(expires_in)) + self._creds = Credentials(token) diff --git a/flytekit/clients/auth/default_html.py b/flytekit/clients/auth/default_html.py new file mode 100644 index 0000000000..cd7f3d8964 --- /dev/null +++ b/flytekit/clients/auth/default_html.py @@ -0,0 +1,12 @@ +def get_default_success_html(endpoint: str) -> str: + return f""" + + + OAuth2 Authentication Success + + +

Successfully logged into {endpoint}

+ Flyte login + + +""" # noqa diff --git a/flytekit/clients/auth/exceptions.py b/flytekit/clients/auth/exceptions.py new file mode 100644 index 0000000000..6e790e47a4 --- /dev/null +++ b/flytekit/clients/auth/exceptions.py @@ -0,0 +1,14 @@ +class AccessTokenNotFoundError(RuntimeError): + """ + This error is raised with Access token is not found or if Refreshing the token fails + """ + + pass + + +class AuthenticationError(RuntimeError): + """ + This is raised for any AuthenticationError + """ + + pass diff --git a/flytekit/clients/auth/keyring.py b/flytekit/clients/auth/keyring.py new file mode 100644 index 0000000000..c2b19c46b6 --- /dev/null +++ b/flytekit/clients/auth/keyring.py @@ -0,0 +1,64 @@ +import logging +import typing +from dataclasses import dataclass + +import keyring as _keyring +from keyring.errors import NoKeyringError + + +@dataclass +class Credentials(object): + """ + Stores the credentials together + """ + + access_token: str + refresh_token: str = "na" + for_endpoint: str = "flyte-default" + + +class KeyringStore: + """ + Methods to access Keyring Store. + """ + + _access_token_key = "access_token" + _refresh_token_key = "refresh_token" + + @staticmethod + def store(credentials: Credentials) -> Credentials: + try: + _keyring.set_password( + credentials.for_endpoint, + KeyringStore._refresh_token_key, + credentials.refresh_token, + ) + _keyring.set_password( + credentials.for_endpoint, + KeyringStore._access_token_key, + credentials.access_token, + ) + except NoKeyringError as e: + logging.warning(f"KeyRing not available, tokens will not be cached. Error: {e}") + return credentials + + @staticmethod + def retrieve(for_endpoint: str) -> typing.Optional[Credentials]: + try: + refresh_token = _keyring.get_password(for_endpoint, KeyringStore._refresh_token_key) + access_token = _keyring.get_password(for_endpoint, KeyringStore._access_token_key) + except NoKeyringError as e: + logging.warning(f"KeyRing not available, tokens will not be cached. Error: {e}") + return None + + if not access_token: + return None + return Credentials(access_token, refresh_token, for_endpoint) + + @staticmethod + def delete(for_endpoint: str): + try: + _keyring.delete_password(for_endpoint, KeyringStore._access_token_key) + _keyring.delete_password(for_endpoint, KeyringStore._refresh_token_key) + except NoKeyringError as e: + logging.warning(f"KeyRing not available, tokens will not be cached. Error: {e}") diff --git a/flytekit/clients/auth_helper.py b/flytekit/clients/auth_helper.py new file mode 100644 index 0000000000..41fc5c025f --- /dev/null +++ b/flytekit/clients/auth_helper.py @@ -0,0 +1,193 @@ +import logging +import ssl + +import grpc +from flyteidl.service.auth_pb2 import OAuth2MetadataRequest, PublicClientAuthConfigRequest +from flyteidl.service.auth_pb2_grpc import AuthMetadataServiceStub +from OpenSSL import crypto + +from flytekit.clients.auth.authenticator import ( + Authenticator, + ClientConfig, + ClientConfigStore, + ClientCredentialsAuthenticator, + CommandAuthenticator, + PKCEAuthenticator, +) +from flytekit.clients.grpc_utils.auth_interceptor import AuthUnaryInterceptor +from flytekit.clients.grpc_utils.wrap_exception_interceptor import RetryExceptionWrapperInterceptor +from flytekit.configuration import AuthType, PlatformConfig + + +class RemoteClientConfigStore(ClientConfigStore): + """ + This class implements the ClientConfigStore that is served by the Flyte Server, that implements AuthMetadataService + """ + + def __init__(self, secure_channel: grpc.Channel): + self._secure_channel = secure_channel + + def get_client_config(self) -> ClientConfig: + """ + Retrieves the ClientConfig from the given grpc.Channel assuming AuthMetadataService is available + """ + metadata_service = AuthMetadataServiceStub(self._secure_channel) + public_client_config = metadata_service.GetPublicClientConfig(PublicClientAuthConfigRequest()) + oauth2_metadata = metadata_service.GetOAuth2Metadata(OAuth2MetadataRequest()) + return ClientConfig( + token_endpoint=oauth2_metadata.token_endpoint, + authorization_endpoint=oauth2_metadata.authorization_endpoint, + redirect_uri=public_client_config.redirect_uri, + client_id=public_client_config.client_id, + scopes=public_client_config.scopes, + header_key=public_client_config.authorization_metadata_key or None, + ) + + +def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Authenticator: + """ + Returns a new authenticator based on the platform config. + """ + cfg_auth = cfg.auth_mode + if type(cfg_auth) is str: + try: + cfg_auth = AuthType[cfg_auth.upper()] + except KeyError: + logging.warning(f"Authentication type {cfg_auth} does not exist, defaulting to standard") + cfg_auth = AuthType.STANDARD + + if cfg_auth == AuthType.STANDARD or cfg_auth == AuthType.PKCE: + verify = None + if cfg.insecure_skip_verify: + verify = False + elif cfg.ca_cert_file_path: + verify = cfg.ca_cert_file_path + return PKCEAuthenticator(cfg.endpoint, cfg_store, verify=verify) + elif cfg_auth == AuthType.BASIC or cfg_auth == AuthType.CLIENT_CREDENTIALS or cfg_auth == AuthType.CLIENTSECRET: + return ClientCredentialsAuthenticator( + endpoint=cfg.endpoint, + client_id=cfg.client_id, + client_secret=cfg.client_credentials_secret, + cfg_store=cfg_store, + ) + elif cfg_auth == AuthType.EXTERNAL_PROCESS or cfg_auth == AuthType.EXTERNALCOMMAND: + client_cfg = None + if cfg_store: + client_cfg = cfg_store.get_client_config() + return CommandAuthenticator( + command=cfg.command, + header_key=client_cfg.header_key if client_cfg else None, + ) + else: + raise ValueError( + f"Invalid auth mode [{cfg_auth}] specified." f"Please update the creds config to use a valid value" + ) + + +def upgrade_channel_to_authenticated(cfg: PlatformConfig, in_channel: grpc.Channel) -> grpc.Channel: + """ + Given a grpc.Channel, preferrably a secure channel, it returns a composed channel that uses Interceptor to + perform an Oauth2.0 Auth flow + :param cfg: PlatformConfig + :param in_channel: grpc.Channel Precreated channel + :return: grpc.Channel. New composite channel + """ + authenticator = get_authenticator(cfg, RemoteClientConfigStore(in_channel)) + return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(authenticator)) + + +def get_authenticated_channel(cfg: PlatformConfig) -> grpc.Channel: + """ + Returns a new channel for the given config that is authenticated + """ + channel = ( + grpc.insecure_channel(cfg.endpoint) + if cfg.insecure + else grpc.secure_channel(cfg.endpoint, grpc.ssl_channel_credentials()) + ) # noqa + return upgrade_channel_to_authenticated(cfg, channel) + + +def load_cert(cert_file: str) -> crypto.X509: + """ + Given a cert-file loads the PEM certificate and returns + """ + st_cert = open(cert_file, "rt").read() + return crypto.load_certificate(crypto.FILETYPE_PEM, st_cert) + + +def bootstrap_creds_from_server(endpoint: str) -> grpc.ChannelCredentials: + """ + Retrieves the SSL cert from the remote and uses that. should be used only if insecure-skip-verify + """ + # Get port from endpoint or use 443 + endpoint_parts = endpoint.rsplit(":", 1) + if len(endpoint_parts) == 2 and endpoint_parts[1].isdigit(): + server_address = (endpoint_parts[0], endpoint_parts[1]) + else: + server_address = (endpoint, "443") + cert = ssl.get_server_certificate(server_address) # noqa + return grpc.ssl_channel_credentials(str.encode(cert)) + + +def get_channel(cfg: PlatformConfig, **kwargs) -> grpc.Channel: + """ + Creates a new grpc.Channel given a platformConfig. + It is possible to pass additional options to the underlying channel. Examples for various options are as below + + .. code-block:: python + + get_channel(cfg=PlatformConfig(...)) + + .. code-block:: python + :caption: Additional options to insecure / secure channel. Example `options` and `compression` refer to grpc guide + + get_channel(cfg=PlatformConfig(...), options=..., compression=...) + + .. code-block:: python + :caption: Create secure channel with custom `grpc.ssl_channel_credentials` + + get_channel(cfg=PlatformConfig(insecure=False,...), credentials=...) + + + :param cfg: PlatformConfig + :param kwargs: Optional arguments to be passed to channel method. Refer to usage example above + :return: grpc.Channel (secure / insecure) + """ + if cfg.insecure: + return grpc.insecure_channel(cfg.endpoint, **kwargs) + + credentials = None + if "credentials" not in kwargs: + if cfg.insecure_skip_verify: + credentials = bootstrap_creds_from_server(cfg.endpoint) + elif cfg.ca_cert_file_path: + credentials = grpc.ssl_channel_credentials(load_cert(cfg.ca_cert_file_path)) + else: + credentials = grpc.ssl_channel_credentials( + root_certificates=kwargs.get("root_certificates", None), + private_key=kwargs.get("private_key", None), + certificate_chain=kwargs.get("certificate_chain", None), + ) + else: + credentials = kwargs["credentials"] + return grpc.secure_channel( + target=cfg.endpoint, + credentials=credentials, + options=kwargs.get("options", None), + compression=kwargs.get("compression", None), + ) + + +def wrap_exceptions_channel(cfg: PlatformConfig, in_channel: grpc.Channel) -> grpc.Channel: + """ + Wraps the input channel with RetryExceptionWrapperInterceptor. This wrapper will cover all + exceptions and raise Exception from the Family flytekit.exceptions + + .. note:: This channel should be usually the outermost channel. This channel will raise a FlyteException + + :param cfg: PlatformConfig + :param in_channel: grpc.Channel + :return: grpc.Channel + """ + return grpc.intercept_channel(in_channel, RetryExceptionWrapperInterceptor(max_retries=cfg.rpc_retries)) diff --git a/flytekit/clients/grpc_utils/__init__.py b/flytekit/clients/grpc_utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flytekit/clients/grpc_utils/auth_interceptor.py b/flytekit/clients/grpc_utils/auth_interceptor.py new file mode 100644 index 0000000000..21bcc30136 --- /dev/null +++ b/flytekit/clients/grpc_utils/auth_interceptor.py @@ -0,0 +1,80 @@ +import typing +from collections import namedtuple + +import grpc + +from flytekit.clients.auth.authenticator import Authenticator + + +class _ClientCallDetails( + namedtuple("_ClientCallDetails", ("method", "timeout", "metadata", "credentials")), + grpc.ClientCallDetails, +): + """ + Wrapper class for initializing a new ClientCallDetails instance. + We cannot make this of type - NamedTuple because, NamedTuple has a metaclass of type NamedTupleMeta and both + the metaclasses conflict + """ + + pass + + +class AuthUnaryInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor): + """ + This Interceptor can be used to automatically add Auth Metadata for every call - lazily in case authentication + is needed. + """ + + def __init__(self, authenticator: Authenticator): + self._authenticator = authenticator + + def _call_details_with_auth_metadata(self, client_call_details: grpc.ClientCallDetails) -> grpc.ClientCallDetails: + """ + Returns new ClientCallDetails with metadata added. + """ + metadata = None + auth_metadata = self._authenticator.fetch_grpc_call_auth_metadata() + if auth_metadata: + metadata = [] + if client_call_details.metadata: + metadata.extend(list(client_call_details.metadata)) + metadata.append(auth_metadata) + + return _ClientCallDetails( + client_call_details.method, + client_call_details.timeout, + metadata, + client_call_details.credentials, + ) + + def intercept_unary_unary( + self, + continuation: typing.Callable, + client_call_details: grpc.ClientCallDetails, + request: typing.Any, + ): + """ + Intercepts unary calls and adds auth metadata if available. On Unauthenticated, resets the token and refreshes + and then retries with the new token + """ + updated_call_details = self._call_details_with_auth_metadata(client_call_details) + fut: grpc.Future = continuation(updated_call_details, request) + e = fut.exception() + if e: + if e.code() == grpc.StatusCode.UNAUTHENTICATED: + self._authenticator.refresh_credentials() + updated_call_details = self._call_details_with_auth_metadata(client_call_details) + return continuation(updated_call_details, request) + return fut + + def intercept_unary_stream(self, continuation, client_call_details, request): + """ + Handles a stream call and adds authentication metadata if needed + """ + updated_call_details = self._call_details_with_auth_metadata(client_call_details) + c: grpc.Call = continuation(updated_call_details, request) + if c.code() == grpc.StatusCode.UNAUTHENTICATED: + self._authenticator.refresh_credentials() + updated_call_details = self._call_details_with_auth_metadata(client_call_details) + return continuation(updated_call_details, request) + return c diff --git a/flytekit/clients/grpc_utils/wrap_exception_interceptor.py b/flytekit/clients/grpc_utils/wrap_exception_interceptor.py new file mode 100644 index 0000000000..ea796f464a --- /dev/null +++ b/flytekit/clients/grpc_utils/wrap_exception_interceptor.py @@ -0,0 +1,49 @@ +import typing +from typing import Union + +import grpc + +from flytekit.exceptions.base import FlyteException +from flytekit.exceptions.system import FlyteSystemException +from flytekit.exceptions.user import ( + FlyteAuthenticationException, + FlyteEntityAlreadyExistsException, + FlyteEntityNotExistException, + FlyteInvalidInputException, +) + + +class RetryExceptionWrapperInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor): + def __init__(self, max_retries: int = 3): + self._max_retries = 3 + + @staticmethod + def _raise_if_exc(request: typing.Any, e: Union[grpc.Call, grpc.Future]): + if isinstance(e, grpc.RpcError): + if e.code() == grpc.StatusCode.UNAUTHENTICATED: + raise FlyteAuthenticationException() from e + elif e.code() == grpc.StatusCode.ALREADY_EXISTS: + raise FlyteEntityAlreadyExistsException() from e + elif e.code() == grpc.StatusCode.NOT_FOUND: + raise FlyteEntityNotExistException() from e + elif e.code() == grpc.StatusCode.INVALID_ARGUMENT: + raise FlyteInvalidInputException(request) from e + raise FlyteSystemException() from e + + def intercept_unary_unary(self, continuation, client_call_details, request): + retries = 0 + while True: + fut: grpc.Future = continuation(client_call_details, request) + e = fut.exception() + try: + if e: + self._raise_if_exc(request, e) + return fut + except FlyteException as e: + if retries == self._max_retries: + raise e + retries = retries + 1 + + def intercept_unary_stream(self, continuation, client_call_details, request): + c: grpc.Call = continuation(client_call_details, request) + return c diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index 6c8f54e9ce..e71485b17c 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -1,91 +1,20 @@ from __future__ import annotations -import base64 as _base64 -import ssl -import subprocess -import time import typing -from typing import Optional import grpc -import requests as _requests from flyteidl.admin.project_pb2 import ProjectListRequest from flyteidl.admin.signal_pb2 import SignalList, SignalListRequest, SignalSetRequest, SignalSetResponse from flyteidl.service import admin_pb2_grpc as _admin_service -from flyteidl.service import auth_pb2 -from flyteidl.service import auth_pb2_grpc as auth_service from flyteidl.service import dataproxy_pb2 as _dataproxy_pb2 from flyteidl.service import dataproxy_pb2_grpc as dataproxy_service from flyteidl.service import signal_pb2_grpc as signal_service from flyteidl.service.dataproxy_pb2_grpc import DataProxyServiceStub -from google.protobuf.json_format import MessageToJson as _MessageToJson -from flytekit.clis.auth import credentials as _credentials_access -from flytekit.configuration import AuthType, PlatformConfig -from flytekit.exceptions import user as _user_exceptions -from flytekit.exceptions.user import FlyteAuthenticationException +from flytekit.clients.auth_helper import get_channel, upgrade_channel_to_authenticated, wrap_exceptions_channel +from flytekit.configuration import PlatformConfig from flytekit.loggers import cli_logger -_utf_8 = "utf-8" - - -def _handle_rpc_error(retry=False): - def decorator(fn): - def handler(*args, **kwargs): - """ - Wraps rpc errors as Flyte exceptions and handles authentication the client. - """ - max_retries = 3 - max_wait_time = 1000 - - for i in range(max_retries): - try: - return fn(*args, **kwargs) - except grpc.RpcError as e: - if e.code() == grpc.StatusCode.UNAUTHENTICATED: - # Always retry auth errors. - if i == (max_retries - 1): - # Exit the loop and wrap the authentication error. - raise _user_exceptions.FlyteAuthenticationException(str(e)) - cli_logger.debug(f"Unauthenticated RPC error {e}, refreshing credentials and retrying\n") - args[0].refresh_credentials() - elif e.code() == grpc.StatusCode.ALREADY_EXISTS: - # There are two cases that we should throw error immediately - # 1. Entity already exists when we register entity - # 2. Entity not found when we fetch entity - raise _user_exceptions.FlyteEntityAlreadyExistsException(e) - elif e.code() == grpc.StatusCode.NOT_FOUND: - raise _user_exceptions.FlyteEntityNotExistException(e) - else: - # No more retries if retry=False or max_retries reached. - if (retry is False) or i == (max_retries - 1): - raise - else: - # Retry: Start with 200ms wait-time and exponentially back-off up to 1 second. - wait_time = min(200 * (2**i), max_wait_time) - cli_logger.error(f"Non-auth RPC error {e}, sleeping {wait_time}ms and retrying") - time.sleep(wait_time / 1000) - - return handler - - return decorator - - -def _handle_invalid_create_request(fn): - def handler(self, create_request): - try: - fn(self, create_request) - except grpc.RpcError as e: - if e.code() == grpc.StatusCode.INVALID_ARGUMENT: - cli_logger.error("Error creating Flyte entity because of invalid arguments. Create request: ") - cli_logger.error(_MessageToJson(create_request)) - cli_logger.error("Details returned from the flyte admin: ") - cli_logger.error(e.details) - # Re-raise since we're not handling the error here and add the create_request details - raise e - - return handler - class RawSynchronousFlyteClient(object): """ @@ -112,54 +41,9 @@ def __init__(self, cfg: PlatformConfig, **kwargs): insecure: if insecure is desired """ self._cfg = cfg - if cfg.insecure: - self._channel = grpc.insecure_channel(cfg.endpoint, **kwargs) - elif cfg.insecure_skip_verify: - # Get port from endpoint or use 443 - endpoint_parts = cfg.endpoint.rsplit(":", 1) - if len(endpoint_parts) == 2 and endpoint_parts[1].isdigit(): - server_address = tuple(endpoint_parts) - else: - server_address = (cfg.endpoint, "443") - cert = ssl.get_server_certificate(server_address) - credentials = grpc.ssl_channel_credentials(str.encode(cert)) - options = kwargs.get("options", []) - self._channel = grpc.secure_channel( - target=cfg.endpoint, - credentials=credentials, - options=options, - compression=kwargs.get("compression", None), - ) - else: - if "credentials" not in kwargs: - credentials = grpc.ssl_channel_credentials( - root_certificates=kwargs.get("root_certificates", None), - private_key=kwargs.get("private_key", None), - certificate_chain=kwargs.get("certificate_chain", None), - ) - else: - credentials = kwargs["credentials"] - self._channel = grpc.secure_channel( - target=cfg.endpoint, - credentials=credentials, - options=kwargs.get("options", None), - compression=kwargs.get("compression", None), - ) + self._channel = wrap_exceptions_channel(cfg, upgrade_channel_to_authenticated(cfg, get_channel(cfg))) self._stub = _admin_service.AdminServiceStub(self._channel) - self._auth_stub = auth_service.AuthMetadataServiceStub(self._channel) self._signal = signal_service.SignalServiceStub(self._channel) - try: - resp = self._auth_stub.GetPublicClientConfig(auth_pb2.PublicClientAuthConfigRequest()) - self._public_client_config = resp - except grpc.RpcError: - cli_logger.debug("No public client auth config found, skipping.") - self._public_client_config = None - try: - resp = self._auth_stub.GetOAuth2Metadata(auth_pb2.OAuth2MetadataRequest()) - self._oauth2_metadata = resp - except grpc.RpcError: - cli_logger.debug("No OAuth2 Metadata found, skipping.") - self._oauth2_metadata = None self._dataproxy_stub = dataproxy_service.DataProxyServiceStub(self._channel) cli_logger.info( @@ -175,161 +59,16 @@ def with_root_certificate(cls, cfg: PlatformConfig, root_cert_file: str) -> RawS b = fp.read() return RawSynchronousFlyteClient(cfg, credentials=grpc.ssl_channel_credentials(root_certificates=b)) - @property - def public_client_config(self) -> Optional[auth_pb2.PublicClientAuthConfigResponse]: - return self._public_client_config - - @property - def oauth2_metadata(self) -> Optional[auth_pb2.OAuth2MetadataResponse]: - return self._oauth2_metadata - @property def url(self) -> str: return self._cfg.endpoint - def _refresh_credentials_standard(self): - """ - This function is used when the configuration value for AUTH_MODE is set to 'standard'. - This either fetches the existing access token or initiates the flow to request a valid access token and store it. - :param self: RawSynchronousFlyteClient - :return: - """ - authorization_header_key = self.public_client_config.authorization_metadata_key or None - if not self.oauth2_metadata or not self.public_client_config: - raise ValueError( - "Raw Flyte client attempting client credentials flow but no response from Admin detected. " - "Check your Admin server's .well-known endpoints to make sure they're working as expected." - ) - - client = _credentials_access.get_client( - redirect_endpoint=self.public_client_config.redirect_uri, - client_id=self.public_client_config.client_id, - scopes=self.public_client_config.scopes, - auth_endpoint=self.oauth2_metadata.authorization_endpoint, - token_endpoint=self.oauth2_metadata.token_endpoint, - ) - - if client.has_valid_credentials and not self.check_access_token(client.credentials.access_token): - # When Python starts up, if credentials have been stored in the keyring, then the AuthorizationClient - # will have read them into its _credentials field, but it won't be in the RawSynchronousFlyteClient's - # metadata field yet. Therefore, if there's a mismatch, copy it over. - self.set_access_token(client.credentials.access_token, authorization_header_key) - return - - try: - client.refresh_access_token() - except ValueError: - client.start_authorization_flow() - - self.set_access_token(client.credentials.access_token, authorization_header_key) - - def _refresh_credentials_basic(self): - """ - This function is used by the _handle_rpc_error() decorator, depending on the AUTH_MODE config object. This handler - is meant for SDK use-cases of auth (like pyflyte, or when users call SDK functions that require access to Admin, - like when waiting for another workflow to complete from within a task). This function uses basic auth, which means - the credentials for basic auth must be present from wherever this code is running. - - :param self: RawSynchronousFlyteClient - :return: - """ - if not self.oauth2_metadata or not self.public_client_config: - raise ValueError( - "Raw Flyte client attempting client credentials flow but no response from Admin detected. " - "Check your Admin server's .well-known endpoints to make sure they're working as expected." - ) - - token_endpoint = self.oauth2_metadata.token_endpoint - scopes = self._cfg.scopes or self.public_client_config.scopes - scopes = ",".join(scopes) - - # Note that unlike the Pkce flow, the client ID does not come from Admin. - client_secret = self._cfg.client_credentials_secret - if not client_secret: - raise FlyteAuthenticationException("No client credentials secret provided in the config") - cli_logger.debug(f"Basic authorization flow with client id {self._cfg.client_id} scope {scopes}") - authorization_header = get_basic_authorization_header(self._cfg.client_id, client_secret) - token, expires_in = get_token(token_endpoint, authorization_header, scopes) - cli_logger.info("Retrieved new token, expires in {}".format(expires_in)) - authorization_header_key = self.public_client_config.authorization_metadata_key or None - self.set_access_token(token, authorization_header_key) - - def _refresh_credentials_from_command(self): - """ - This function is used when the configuration value for AUTH_MODE is set to 'external_process'. - It reads an id token generated by an external process started by running the 'command'. - - :param self: RawSynchronousFlyteClient - :return: - """ - command = self._cfg.command - if not command: - raise FlyteAuthenticationException("No command specified in configuration for command authentication") - cli_logger.debug("Starting external process to generate id token. Command {}".format(command)) - try: - output = subprocess.run(command, capture_output=True, text=True, check=True) - except subprocess.CalledProcessError as e: - cli_logger.error("Failed to generate token from command {}".format(command)) - raise _user_exceptions.FlyteAuthenticationException("Problems refreshing token with command: " + str(e)) - authorization_header_key = self.public_client_config.authorization_metadata_key or None - if not authorization_header_key: - self.set_access_token(output.stdout.strip()) - self.set_access_token(output.stdout.strip(), authorization_header_key) - - def _refresh_credentials_noop(self): - pass - - def refresh_credentials(self): - cfg_auth = self._cfg.auth_mode - if type(cfg_auth) is str: - try: - cfg_auth = AuthType[cfg_auth.upper()] - except KeyError: - cli_logger.warning(f"Authentication type {cfg_auth} does not exist, defaulting to standard") - cfg_auth = AuthType.STANDARD - - if cfg_auth == AuthType.STANDARD or cfg_auth == AuthType.PKCE: - return self._refresh_credentials_standard() - elif cfg_auth == AuthType.BASIC or cfg_auth == AuthType.CLIENT_CREDENTIALS or cfg_auth == AuthType.CLIENTSECRET: - return self._refresh_credentials_basic() - elif cfg_auth == AuthType.EXTERNAL_PROCESS or cfg_auth == AuthType.EXTERNALCOMMAND: - return self._refresh_credentials_from_command() - else: - raise ValueError( - f"Invalid auth mode [{cfg_auth}] specified." f"Please update the creds config to use a valid value" - ) - - def set_access_token(self, access_token: str, authorization_header_key: Optional[str] = "authorization"): - # Always set the header to lower-case regardless of what the config is. The grpc libraries that Admin uses - # to parse the metadata don't change the metadata, but they do automatically lower the key you're looking for. - cli_logger.debug(f"Adding authorization header. Header name: {authorization_header_key}.") - self._metadata = [ - ( - authorization_header_key, - f"Bearer {access_token}", - ) - ] - - def check_access_token(self, access_token: str) -> bool: - """ - This checks to see if the given access token is the same as the one already stored in the client. The reason - this is useful is so that we can prevent unnecessary refreshing of tokens. - - :param access_token: The access token to check - :return: If no access token is stored, or if the stored token doesn't match, return False. - """ - if self._metadata is None: - return False - return access_token == self._metadata[0][1].replace("Bearer ", "") - #################################################################################################################### # # Task Endpoints # #################################################################################################################### - @_handle_rpc_error() - @_handle_invalid_create_request def create_task(self, task_create_request): """ This will create a task definition in the Admin database. Once successful, the task object can be @@ -350,7 +89,6 @@ def create_task(self, task_create_request): """ return self._stub.CreateTask(task_create_request, metadata=self._metadata) - @_handle_rpc_error(retry=True) def list_task_ids_paginated(self, identifier_list_request): """ This returns a page of identifiers for the tasks for a given project and domain. Filters can also be @@ -376,7 +114,6 @@ def list_task_ids_paginated(self, identifier_list_request): """ return self._stub.ListTaskIds(identifier_list_request, metadata=self._metadata) - @_handle_rpc_error(retry=True) def list_tasks_paginated(self, resource_list_request): """ This returns a page of task metadata for tasks in a given project and domain. Optionally, @@ -398,7 +135,6 @@ def list_tasks_paginated(self, resource_list_request): """ return self._stub.ListTasks(resource_list_request, metadata=self._metadata) - @_handle_rpc_error(retry=True) def get_task(self, get_object_request): """ This returns a single task for a given identifier. @@ -409,14 +145,12 @@ def get_task(self, get_object_request): """ return self._stub.GetTask(get_object_request, metadata=self._metadata) - @_handle_rpc_error(retry=True) def set_signal(self, signal_set_request: SignalSetRequest) -> SignalSetResponse: """ This sets a signal """ return self._signal.SetSignal(signal_set_request, metadata=self._metadata) - @_handle_rpc_error(retry=True) def list_signals(self, signal_list_request: SignalListRequest) -> SignalList: """ This lists signals @@ -429,8 +163,6 @@ def list_signals(self, signal_list_request: SignalListRequest) -> SignalList: # #################################################################################################################### - @_handle_rpc_error() - @_handle_invalid_create_request def create_workflow(self, workflow_create_request): """ This will create a workflow definition in the Admin database. Once successful, the workflow object can be @@ -451,7 +183,6 @@ def create_workflow(self, workflow_create_request): """ return self._stub.CreateWorkflow(workflow_create_request, metadata=self._metadata) - @_handle_rpc_error(retry=True) def list_workflow_ids_paginated(self, identifier_list_request): """ This returns a page of identifiers for the workflows for a given project and domain. Filters can also be @@ -477,7 +208,6 @@ def list_workflow_ids_paginated(self, identifier_list_request): """ return self._stub.ListWorkflowIds(identifier_list_request, metadata=self._metadata) - @_handle_rpc_error(retry=True) def list_workflows_paginated(self, resource_list_request): """ This returns a page of workflow meta-information for workflows in a given project and domain. Optionally, @@ -499,7 +229,6 @@ def list_workflows_paginated(self, resource_list_request): """ return self._stub.ListWorkflows(resource_list_request, metadata=self._metadata) - @_handle_rpc_error(retry=True) def get_workflow(self, get_object_request): """ This returns a single workflow for a given identifier. @@ -516,8 +245,6 @@ def get_workflow(self, get_object_request): # #################################################################################################################### - @_handle_rpc_error() - @_handle_invalid_create_request def create_launch_plan(self, launch_plan_create_request): """ This will create a launch plan definition in the Admin database. Once successful, the launch plan object can be @@ -541,7 +268,6 @@ def create_launch_plan(self, launch_plan_create_request): # TODO: List endpoints when they come in - @_handle_rpc_error(retry=True) def get_launch_plan(self, object_get_request): """ Retrieves a launch plan entity. @@ -551,7 +277,6 @@ def get_launch_plan(self, object_get_request): """ return self._stub.GetLaunchPlan(object_get_request, metadata=self._metadata) - @_handle_rpc_error(retry=True) def get_active_launch_plan(self, active_launch_plan_request): """ Retrieves a launch plan entity. @@ -561,7 +286,6 @@ def get_active_launch_plan(self, active_launch_plan_request): """ return self._stub.GetActiveLaunchPlan(active_launch_plan_request, metadata=self._metadata) - @_handle_rpc_error() def update_launch_plan(self, update_request): """ Allows updates to a launch plan at a given identifier. Currently, a launch plan may only have it's state @@ -572,7 +296,6 @@ def update_launch_plan(self, update_request): """ return self._stub.UpdateLaunchPlan(update_request, metadata=self._metadata) - @_handle_rpc_error(retry=True) def list_launch_plan_ids_paginated(self, identifier_list_request): """ Lists launch plan named identifiers for a given project and domain. @@ -582,7 +305,6 @@ def list_launch_plan_ids_paginated(self, identifier_list_request): """ return self._stub.ListLaunchPlanIds(identifier_list_request, metadata=self._metadata) - @_handle_rpc_error(retry=True) def list_launch_plans_paginated(self, resource_list_request): """ Lists Launch Plans for a given Identifier (project, domain, name) @@ -592,7 +314,6 @@ def list_launch_plans_paginated(self, resource_list_request): """ return self._stub.ListLaunchPlans(resource_list_request, metadata=self._metadata) - @_handle_rpc_error(retry=True) def list_active_launch_plans_paginated(self, active_launch_plan_list_request): """ Lists Active Launch Plans for a given (project, domain) @@ -608,7 +329,6 @@ def list_active_launch_plans_paginated(self, active_launch_plan_list_request): # #################################################################################################################### - @_handle_rpc_error() def update_named_entity(self, update_named_entity_request): """ :param flyteidl.admin.common_pb2.NamedEntityUpdateRequest update_named_entity_request: @@ -622,7 +342,6 @@ def update_named_entity(self, update_named_entity_request): # #################################################################################################################### - @_handle_rpc_error() def create_execution(self, create_execution_request): """ This will create an execution for the given execution spec. @@ -631,7 +350,6 @@ def create_execution(self, create_execution_request): """ return self._stub.CreateExecution(create_execution_request, metadata=self._metadata) - @_handle_rpc_error() def recover_execution(self, recover_execution_request): """ This will recreate an execution with the same spec as the one belonging to the given execution identifier. @@ -640,7 +358,6 @@ def recover_execution(self, recover_execution_request): """ return self._stub.RecoverExecution(recover_execution_request, metadata=self._metadata) - @_handle_rpc_error(retry=True) def get_execution(self, get_object_request): """ Returns an execution of a workflow entity. @@ -650,7 +367,6 @@ def get_execution(self, get_object_request): """ return self._stub.GetExecution(get_object_request, metadata=self._metadata) - @_handle_rpc_error(retry=True) def get_execution_data(self, get_execution_data_request): """ Returns signed URLs to LiteralMap blobs for an execution's inputs and outputs (when available). @@ -660,7 +376,6 @@ def get_execution_data(self, get_execution_data_request): """ return self._stub.GetExecutionData(get_execution_data_request, metadata=self._metadata) - @_handle_rpc_error(retry=True) def list_executions_paginated(self, resource_list_request): """ Lists the executions for a given identifier. @@ -670,7 +385,6 @@ def list_executions_paginated(self, resource_list_request): """ return self._stub.ListExecutions(resource_list_request, metadata=self._metadata) - @_handle_rpc_error() def terminate_execution(self, terminate_execution_request): """ :param flyteidl.admin.execution_pb2.TerminateExecutionRequest terminate_execution_request: @@ -678,7 +392,6 @@ def terminate_execution(self, terminate_execution_request): """ return self._stub.TerminateExecution(terminate_execution_request, metadata=self._metadata) - @_handle_rpc_error() def relaunch_execution(self, relaunch_execution_request): """ :param flyteidl.admin.execution_pb2.ExecutionRelaunchRequest relaunch_execution_request: @@ -692,7 +405,6 @@ def relaunch_execution(self, relaunch_execution_request): # #################################################################################################################### - @_handle_rpc_error(retry=True) def get_node_execution(self, node_execution_request): """ :param flyteidl.admin.node_execution_pb2.NodeExecutionGetRequest node_execution_request: @@ -700,7 +412,6 @@ def get_node_execution(self, node_execution_request): """ return self._stub.GetNodeExecution(node_execution_request, metadata=self._metadata) - @_handle_rpc_error(retry=True) def get_node_execution_data(self, get_node_execution_data_request): """ Returns signed URLs to LiteralMap blobs for a node execution's inputs and outputs (when available). @@ -710,7 +421,6 @@ def get_node_execution_data(self, get_node_execution_data_request): """ return self._stub.GetNodeExecutionData(get_node_execution_data_request, metadata=self._metadata) - @_handle_rpc_error(retry=True) def list_node_executions_paginated(self, node_execution_list_request): """ :param flyteidl.admin.node_execution_pb2.NodeExecutionListRequest node_execution_list_request: @@ -718,7 +428,6 @@ def list_node_executions_paginated(self, node_execution_list_request): """ return self._stub.ListNodeExecutions(node_execution_list_request, metadata=self._metadata) - @_handle_rpc_error(retry=True) def list_node_executions_for_task_paginated(self, node_execution_for_task_list_request): """ :param flyteidl.admin.node_execution_pb2.NodeExecutionListRequest node_execution_for_task_list_request: @@ -732,7 +441,6 @@ def list_node_executions_for_task_paginated(self, node_execution_for_task_list_r # #################################################################################################################### - @_handle_rpc_error(retry=True) def get_task_execution(self, task_execution_request): """ :param flyteidl.admin.task_execution_pb2.TaskExecutionGetRequest task_execution_request: @@ -740,7 +448,6 @@ def get_task_execution(self, task_execution_request): """ return self._stub.GetTaskExecution(task_execution_request, metadata=self._metadata) - @_handle_rpc_error(retry=True) def get_task_execution_data(self, get_task_execution_data_request): """ Returns signed URLs to LiteralMap blobs for a task execution's inputs and outputs (when available). @@ -750,7 +457,6 @@ def get_task_execution_data(self, get_task_execution_data_request): """ return self._stub.GetTaskExecutionData(get_task_execution_data_request, metadata=self._metadata) - @_handle_rpc_error(retry=True) def list_task_executions_paginated(self, task_execution_list_request): """ :param flyteidl.admin.task_execution_pb2.TaskExecutionListRequest task_execution_list_request: @@ -764,7 +470,6 @@ def list_task_executions_paginated(self, task_execution_list_request): # #################################################################################################################### - @_handle_rpc_error(retry=True) def list_projects(self, project_list_request: typing.Optional[ProjectListRequest] = None): """ This will return a list of the projects registered with the Flyte Admin Service @@ -775,7 +480,6 @@ def list_projects(self, project_list_request: typing.Optional[ProjectListRequest project_list_request = ProjectListRequest() return self._stub.ListProjects(project_list_request, metadata=self._metadata) - @_handle_rpc_error() def register_project(self, project_register_request): """ Registers a project along with a set of domains. @@ -784,7 +488,6 @@ def register_project(self, project_register_request): """ return self._stub.RegisterProject(project_register_request, metadata=self._metadata) - @_handle_rpc_error() def update_project(self, project): """ Update an existing project specified by id. @@ -798,7 +501,6 @@ def update_project(self, project): # Matching Attributes Endpoints # #################################################################################################################### - @_handle_rpc_error() def update_project_domain_attributes(self, project_domain_attributes_update_request): """ This updates the attributes for a project and domain registered with the Flyte Admin Service @@ -809,7 +511,6 @@ def update_project_domain_attributes(self, project_domain_attributes_update_requ project_domain_attributes_update_request, metadata=self._metadata ) - @_handle_rpc_error() def update_workflow_attributes(self, workflow_attributes_update_request): """ This updates the attributes for a project, domain, and workflow registered with the Flyte Admin Service @@ -818,7 +519,6 @@ def update_workflow_attributes(self, workflow_attributes_update_request): """ return self._stub.UpdateWorkflowAttributes(workflow_attributes_update_request, metadata=self._metadata) - @_handle_rpc_error(retry=True) def get_project_domain_attributes(self, project_domain_attributes_get_request): """ This fetches the attributes for a project and domain registered with the Flyte Admin Service @@ -827,7 +527,6 @@ def get_project_domain_attributes(self, project_domain_attributes_get_request): """ return self._stub.GetProjectDomainAttributes(project_domain_attributes_get_request, metadata=self._metadata) - @_handle_rpc_error(retry=True) def get_workflow_attributes(self, workflow_attributes_get_request): """ This fetches the attributes for a project, domain, and workflow registered with the Flyte Admin Service @@ -836,7 +535,6 @@ def get_workflow_attributes(self, workflow_attributes_get_request): """ return self._stub.GetWorkflowAttributes(workflow_attributes_get_request, metadata=self._metadata) - @_handle_rpc_error(retry=True) def list_matchable_attributes(self, matchable_attributes_list_request): """ This fetches the attributes for a specific resource type registered with the Flyte Admin Service @@ -859,7 +557,6 @@ def list_matchable_attributes(self, matchable_attributes_list_request): # Data proxy endpoints # #################################################################################################################### - @_handle_rpc_error(retry=True) def create_upload_location( self, create_upload_location_request: _dataproxy_pb2.CreateUploadLocationRequest ) -> _dataproxy_pb2.CreateUploadLocationResponse: @@ -870,7 +567,6 @@ def create_upload_location( """ return self._dataproxy_stub.CreateUploadLocation(create_upload_location_request, metadata=self._metadata) - @_handle_rpc_error(retry=True) def create_download_location( self, create_download_location_request: _dataproxy_pb2.CreateDownloadLocationRequest ) -> _dataproxy_pb2.CreateDownloadLocationResponse: @@ -880,43 +576,3 @@ def create_download_location( :rtype: flyteidl.service.dataproxy_pb2.CreateDownloadLocationResponse """ return self._dataproxy_stub.CreateDownloadLocation(create_download_location_request, metadata=self._metadata) - - -def get_token(token_endpoint, authorization_header, scope): - """ - :param Text token_endpoint: - :param Text authorization_header: This is the value for the "Authorization" key. (eg 'Bearer abc123') - :param Text scope: - :rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration - in seconds - """ - headers = { - "Authorization": authorization_header, - "Cache-Control": "no-cache", - "Accept": "application/json", - "Content-Type": "application/x-www-form-urlencoded", - } - body = { - "grant_type": "client_credentials", - } - if scope is not None: - body["scope"] = scope - response = _requests.post(token_endpoint, data=body, headers=headers) - if response.status_code != 200: - cli_logger.error("Non-200 ({}) received from IDP: {}".format(response.status_code, response.text)) - raise FlyteAuthenticationException("Non-200 received from IDP") - - response = response.json() - return response["access_token"], response["expires_in"] - - -def get_basic_authorization_header(client_id, client_secret): - """ - This function transforms the client id and the client secret into a header that conforms with http basic auth. - It joins the id and the secret with a : then base64 encodes it, then adds the appropriate text. - :param Text client_id: - :param Text client_secret: - :rtype: Text - """ - concated = "{}:{}".format(client_id, client_secret) - return "Basic {}".format(_base64.b64encode(concated.encode(_utf_8)).decode(_utf_8)) diff --git a/flytekit/clis/auth/credentials.py b/flytekit/clis/auth/credentials.py deleted file mode 100644 index a8475c8dfc..0000000000 --- a/flytekit/clis/auth/credentials.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import List - -from flytekit.clis.auth.auth import AuthorizationClient -from flytekit.loggers import auth_logger - -# Default, well known-URI string used for fetching JSON metadata. See https://tools.ietf.org/html/rfc8414#section-3. -discovery_endpoint_path = "./.well-known/oauth-authorization-server" - -# Lazy initialized authorization client singleton -_authorization_client = None - - -def get_client( - redirect_endpoint: str, client_id: str, scopes: List[str], auth_endpoint: str, token_endpoint: str -) -> AuthorizationClient: - global _authorization_client - if _authorization_client is not None and not _authorization_client.expired: - return _authorization_client - - _authorization_client = AuthorizationClient( - redirect_uri=redirect_endpoint, - client_id=client_id, - scopes=scopes, - auth_endpoint=auth_endpoint, - token_endpoint=token_endpoint, - ) - - auth_logger.debug(f"Created oauth client with redirect {_authorization_client}") - - if not _authorization_client.has_valid_credentials: - _authorization_client.start_authorization_flow() - - return _authorization_client diff --git a/flytekit/clis/sdk_in_container/constants.py b/flytekit/clis/sdk_in_container/constants.py index 46513553b9..d228babf43 100644 --- a/flytekit/clis/sdk_in_container/constants.py +++ b/flytekit/clis/sdk_in_container/constants.py @@ -9,6 +9,7 @@ CTX_CONFIG_FILE = "config_file" CTX_PROJECT_ROOT = "project_root" CTX_MODULE = "module" +CTX_VERBOSE = "verbose" project_option = _click.option( diff --git a/flytekit/clis/sdk_in_container/pyflyte.py b/flytekit/clis/sdk_in_container/pyflyte.py index 1f843450ed..5e1136d14c 100644 --- a/flytekit/clis/sdk_in_container/pyflyte.py +++ b/flytekit/clis/sdk_in_container/pyflyte.py @@ -1,8 +1,12 @@ +import typing + import click +import grpc +from google.protobuf.json_format import MessageToJson from flytekit import configuration from flytekit.clis.sdk_in_container.backfill import backfill -from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE, CTX_PACKAGES +from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE, CTX_PACKAGES, CTX_VERBOSE from flytekit.clis.sdk_in_container.init import init from flytekit.clis.sdk_in_container.local_cache import local_cache from flytekit.clis.sdk_in_container.package import package @@ -10,6 +14,8 @@ from flytekit.clis.sdk_in_container.run import run from flytekit.clis.sdk_in_container.serialize import serialize from flytekit.configuration.internal import LocalSDK +from flytekit.exceptions.base import FlyteException +from flytekit.exceptions.user import FlyteInvalidInputException from flytekit.loggers import cli_logger @@ -28,7 +34,60 @@ def validate_package(ctx, param, values): return pkgs -@click.group("pyflyte", invoke_without_command=True) +def pretty_print_grpc_error(e: grpc.RpcError): + if isinstance(e, grpc._channel._InactiveRpcError): # noqa + click.secho(f"RPC Failed, with Status: {e.code()}", fg="red") + click.secho(f"\tdetails: {e.details()}", fg="magenta") + click.secho(f"\tDebug string {e.debug_error_string()}", dim=True) + return + + +def pretty_print_exception(e: Exception): + if isinstance(e, click.exceptions.Exit): + raise e + + if isinstance(e, click.ClickException): + click.secho(e.message, fg="red") + raise e + + if isinstance(e, FlyteException): + if isinstance(e, FlyteInvalidInputException): + click.secho("Request rejected by the API, due to Invalid input.", fg="red") + click.secho(f"\tReason: {str(e)}", dim=True) + click.secho(f"\tInput Request: {MessageToJson(e.request)}", dim=True) + return + click.secho(f"Failed with Exception: Reason: {e._ERROR_CODE}", fg="red") # noqa + cause = e.__cause__ + if cause: + if isinstance(cause, grpc.RpcError): + pretty_print_grpc_error(cause) + else: + click.secho(f"Underlying Exception: {cause}") + return + + if isinstance(e, grpc.RpcError): + pretty_print_grpc_error(e) + return + + click.secho(f"Failed with Unknown Exception {type(e)} Reason: {e}", fg="red") # noqa + + +class ErrorHandlingCommand(click.Group): + def invoke(self, ctx: click.Context) -> typing.Any: + try: + return super().invoke(ctx) + except Exception as e: + if CTX_VERBOSE in ctx.obj and ctx.obj[CTX_VERBOSE]: + print("Verbose mode on") + raise e + pretty_print_exception(e) + raise SystemExit(e) + + +@click.group("pyflyte", invoke_without_command=True, cls=ErrorHandlingCommand) +@click.option( + "--verbose", required=False, default=False, is_flag=True, help="Show verbose messages and exception traces" +) @click.option( "-k", "--pkgs", @@ -47,7 +106,7 @@ def validate_package(ctx, param, values): help="Path to config file for use within container", ) @click.pass_context -def main(ctx, pkgs=None, config=None): +def main(ctx, pkgs: typing.List[str], config: str, verbose: bool): """ Entrypoint for all the user commands. """ @@ -63,6 +122,7 @@ def main(ctx, pkgs=None, config=None): if pkgs is None: pkgs = [] ctx.obj[CTX_PACKAGES] = pkgs + ctx.obj[CTX_VERBOSE] = verbose main.add_command(serialize) @@ -72,6 +132,7 @@ def main(ctx, pkgs=None, config=None): main.add_command(run) main.add_command(register) main.add_command(backfill) +main.epilog if __name__ == "__main__": main() diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index 220f9209ea..77273cc81c 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -298,28 +298,31 @@ class PlatformConfig(object): This object contains the settings to talk to a Flyte backend (the DNS location of your Admin server basically). :param endpoint: DNS for Flyte backend - :param insecure: Whether or not to use SSL - :param insecure_skip_verify: Wether to skip SSL certificate verification - :param console_endpoint: endpoint for console if different than Flyte backend - :param command: This command is executed to return a token using an external process. + :param insecure: Whether to use SSL + :param insecure_skip_verify: Whether to skip SSL certificate verification + :param console_endpoint: endpoint for console if different from Flyte backend + :param command: This command is executed to return a token using an external process :param client_id: This is the public identifier for the app which handles authorization for a Flyte deployment. More details here: https://www.oauth.com/oauth2-servers/client-registration/client-id-secret/. :param client_credentials_secret: Used for service auth, which is automatically called during pyflyte. This will allow the Flyte engine to read the password directly from the environment variable. Note that this is - less secure! Please only use this if mounting the secret as a file is impossible. - :param scopes: List of scopes to request. This is only applicable to the client credentials flow. - :param auth_mode: The OAuth mode to use. Defaults to pkce flow. + less secure! Please only use this if mounting the secret as a file is impossible + :param scopes: List of scopes to request. This is only applicable to the client credentials flow + :param auth_mode: The OAuth mode to use. Defaults to pkce flow + :param ca_cert_file_path: [optional] str Root Cert to be loaded and used to verify admin """ endpoint: str = "localhost:30080" insecure: bool = False insecure_skip_verify: bool = False + ca_cert_file_path: typing.Optional[str] = None console_endpoint: typing.Optional[str] = None command: typing.Optional[typing.List[str]] = None client_id: typing.Optional[str] = None client_credentials_secret: typing.Optional[str] = None scopes: List[str] = field(default_factory=list) auth_mode: AuthType = AuthType.STANDARD + rpc_retries: int = 3 @classmethod def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None) -> PlatformConfig: @@ -334,6 +337,7 @@ def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None kwargs = set_if_exists( kwargs, "insecure_skip_verify", _internal.Platform.INSECURE_SKIP_VERIFY.read(config_file) ) + kwargs = set_if_exists(kwargs, "ca_cert_file_path", _internal.Platform.CA_CERT_FILE_PATH.read(config_file)) kwargs = set_if_exists(kwargs, "command", _internal.Credentials.COMMAND.read(config_file)) kwargs = set_if_exists(kwargs, "client_id", _internal.Credentials.CLIENT_ID.read(config_file)) kwargs = set_if_exists( diff --git a/flytekit/configuration/internal.py b/flytekit/configuration/internal.py index 5c29045db5..5c3729e63b 100644 --- a/flytekit/configuration/internal.py +++ b/flytekit/configuration/internal.py @@ -108,6 +108,9 @@ class Platform(object): LegacyConfigEntry(SECTION, "insecure_skip_verify", bool), YamlConfigEntry("admin.insecureSkipVerify", bool) ) CONSOLE_ENDPOINT = ConfigEntry(LegacyConfigEntry(SECTION, "console_endpoint"), YamlConfigEntry("console.endpoint")) + CA_CERT_FILE_PATH = ConfigEntry( + LegacyConfigEntry(SECTION, "ca_cert_file_path"), YamlConfigEntry("admin.caCertFilePath") + ) class LocalSDK(object): diff --git a/flytekit/exceptions/user.py b/flytekit/exceptions/user.py index 6575218666..59b58a0a5a 100644 --- a/flytekit/exceptions/user.py +++ b/flytekit/exceptions/user.py @@ -1,3 +1,5 @@ +import typing + from flytekit.exceptions.base import FlyteException as _FlyteException from flytekit.exceptions.base import FlyteRecoverableException as _Recoverable @@ -84,3 +86,11 @@ class FlyteRecoverableException(FlyteUserException, _Recoverable): class FlyteAuthenticationException(FlyteAssertion): _ERROR_CODE = "USER:AuthenticationError" + + +class FlyteInvalidInputException(FlyteUserException): + _ERROR_CODE = "USER:BadInputToAPI" + + def __init__(self, request: typing.Any): + self.request = request + super(self).__init__() diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 93badd5374..03cc9a66e9 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -596,6 +596,7 @@ def _serialize_and_register( settings: typing.Optional[SerializationSettings], version: str, options: typing.Optional[Options] = None, + create_default_launchplan: bool = True, ) -> Identifier: """ This method serializes and register the given Flyte entity @@ -630,7 +631,7 @@ def _serialize_and_register( cp_entity, settings=settings, version=version, - create_default_launchplan=True, + create_default_launchplan=create_default_launchplan, options=options, og_entity=entity, ) @@ -685,14 +686,7 @@ def register_workflow( b.domain = ident.domain b.version = ident.version serialization_settings = b.build() - ident = self._serialize_and_register(entity, serialization_settings, version, options) - if default_launch_plan: - default_lp = LaunchPlan.get_default_launch_plan(self.context, entity) - self.register_launch_plan( - default_lp, version=ident.version, project=ident.project, domain=ident.domain, options=options - ) - remote_logger.debug("Created default launch plan for Workflow") - + ident = self._serialize_and_register(entity, serialization_settings, version, options, default_launch_plan) fwf = self.fetch_workflow(ident.project, ident.domain, ident.name, ident.version) fwf._python_interface = entity.python_interface return fwf diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index b211153f44..e963f3dfc6 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -152,6 +152,7 @@ def test_union_type_with_invalid_input(): runner.invoke( pyflyte.main, [ + "--verbose", "run", os.path.join(DIR_NAME, "workflow.py"), "test_union2", diff --git a/tests/flytekit/unit/clients/auth/__init__.py b/tests/flytekit/unit/clients/auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/cli/auth/test_auth.py b/tests/flytekit/unit/clients/auth/test_auth_client.py similarity index 52% rename from tests/flytekit/unit/cli/auth/test_auth.py rename to tests/flytekit/unit/clients/auth/test_auth_client.py index 487ddfcd1d..5ab843da8b 100644 --- a/tests/flytekit/unit/cli/auth/test_auth.py +++ b/tests/flytekit/unit/clients/auth/test_auth_client.py @@ -2,29 +2,40 @@ import re from multiprocessing import Queue as _Queue -from flytekit.clis.auth import auth as _auth +from flytekit.clients.auth.auth_client import ( + EndpointMetadata, + OAuthHTTPServer, + _create_code_challenge, + _generate_code_verifier, + _generate_state_parameter, +) def test_generate_code_verifier(): - verifier = _auth._generate_code_verifier() + verifier = _generate_code_verifier() assert verifier is not None assert 43 < len(verifier) < 128 assert not re.search(r"[^a-zA-Z0-9_\-.~]+", verifier) def test_generate_state_parameter(): - param = _auth._generate_state_parameter() + param = _generate_state_parameter() assert not re.search(r"[^a-zA-Z0-9-_.,]+", param) def test_create_code_challenge(): test_code_verifier = "test_code_verifier" - assert _auth._create_code_challenge(test_code_verifier) == "Qq1fGD0HhxwbmeMrqaebgn1qhvKeguQPXqLdpmixaM4" + assert _create_code_challenge(test_code_verifier) == "Qq1fGD0HhxwbmeMrqaebgn1qhvKeguQPXqLdpmixaM4" def test_oauth_http_server(): queue = _Queue() - server = _auth.OAuthHTTPServer(("localhost", 9000), _BaseHTTPServer.BaseHTTPRequestHandler, queue=queue) + server = OAuthHTTPServer( + ("localhost", 9000), + remote_metadata=EndpointMetadata(endpoint="example.com"), + request_handler_class=_BaseHTTPServer.BaseHTTPRequestHandler, + queue=queue, + ) test_auth_code = "auth_code" server.handle_authorization_code(test_auth_code) auth_code = queue.get() diff --git a/tests/flytekit/unit/clients/auth/test_authenticator.py b/tests/flytekit/unit/clients/auth/test_authenticator.py new file mode 100644 index 0000000000..4c968cf0bd --- /dev/null +++ b/tests/flytekit/unit/clients/auth/test_authenticator.py @@ -0,0 +1,95 @@ +import json +import subprocess +from unittest.mock import MagicMock, patch + +import pytest + +from flytekit.clients.auth.authenticator import ( + ClientConfig, + ClientCredentialsAuthenticator, + CommandAuthenticator, + PKCEAuthenticator, + StaticClientConfigStore, +) +from flytekit.clients.auth.exceptions import AuthenticationError + +ENDPOINT = "example.com" + +client_config = ClientConfig( + token_endpoint="token_endpoint", + authorization_endpoint="auth_endpoint", + redirect_uri="redirect_uri", + client_id="client", +) + +static_cfg_store = StaticClientConfigStore(client_config) + + +@patch("flytekit.clients.auth.authenticator.KeyringStore") +@patch("flytekit.clients.auth.auth_client.AuthorizationClient.get_creds_from_remote") +@patch("flytekit.clients.auth.auth_client.AuthorizationClient.refresh_access_token") +def test_pkce_authenticator(mock_refresh: MagicMock, mock_get_creds: MagicMock, mock_keyring: MagicMock): + mock_keyring.retrieve.return_value = None + authn = PKCEAuthenticator(ENDPOINT, static_cfg_store) + assert authn._verify is None + + authn = PKCEAuthenticator(ENDPOINT, static_cfg_store, verify=False) + assert authn._verify is False + + assert authn._creds is None + assert authn._auth_client is None + authn.refresh_credentials() + assert authn._auth_client + mock_get_creds.assert_called() + mock_refresh.assert_not_called() + mock_keyring.store.assert_called() + + authn.refresh_credentials() + mock_refresh.assert_called() + + +@patch("subprocess.run") +def test_command_authenticator(mock_subprocess: MagicMock): + with pytest.raises(AuthenticationError): + authn = CommandAuthenticator(None) # noqa + + authn = CommandAuthenticator(["echo"]) + + authn.refresh_credentials() + assert authn._creds + mock_subprocess.assert_called() + + mock_subprocess.side_effect = subprocess.CalledProcessError(-1, ["x"]) + + with pytest.raises(AuthenticationError): + authn.refresh_credentials() + + +def test_get_basic_authorization_header(): + header = ClientCredentialsAuthenticator.get_basic_authorization_header("client_id", "abc") + assert header == "Basic Y2xpZW50X2lkOmFiYw==" + + +@patch("flytekit.clients.auth.authenticator.requests") +def test_get_token(mock_requests): + response = MagicMock() + response.status_code = 200 + response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") + mock_requests.post.return_value = response + access, expiration = ClientCredentialsAuthenticator.get_token("https://corp.idp.net", "abc123", ["my_scope"]) + assert access == "abc" + assert expiration == 60 + + +@patch("flytekit.clients.auth.authenticator.requests") +def test_client_creds_authenticator(mock_requests): + authn = ClientCredentialsAuthenticator( + ENDPOINT, client_id="client", client_secret="secret", cfg_store=static_cfg_store + ) + + response = MagicMock() + response.status_code = 200 + response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") + mock_requests.post.return_value = response + authn.refresh_credentials() + assert authn._creds diff --git a/tests/flytekit/unit/clients/auth/test_default_html.py b/tests/flytekit/unit/clients/auth/test_default_html.py new file mode 100644 index 0000000000..391fb6a542 --- /dev/null +++ b/tests/flytekit/unit/clients/auth/test_default_html.py @@ -0,0 +1,18 @@ +from flytekit.clients.auth.default_html import get_default_success_html + + +def test_default_html(): + assert ( + get_default_success_html("flyte.org") + == """ + + + OAuth2 Authentication Success + + +

Successfully logged into flyte.org

+ Flyte login + + +""" + ) # noqa diff --git a/tests/flytekit/unit/clients/auth/test_keyring_store.py b/tests/flytekit/unit/clients/auth/test_keyring_store.py new file mode 100644 index 0000000000..d068a1f451 --- /dev/null +++ b/tests/flytekit/unit/clients/auth/test_keyring_store.py @@ -0,0 +1,32 @@ +from unittest.mock import MagicMock, patch + +from keyring.errors import NoKeyringError + +from flytekit.clients.auth.keyring import Credentials, KeyringStore + + +@patch("keyring.get_password") +def test_keyring_store_get(kr_get_password: MagicMock): + kr_get_password.return_value = "t" + assert KeyringStore.retrieve("example1.com") is not None + + kr_get_password.side_effect = NoKeyringError() + assert KeyringStore.retrieve("example2.com") is None + + +@patch("keyring.delete_password") +def test_keyring_store_delete(kr_del_password: MagicMock): + kr_del_password.return_value = None + assert KeyringStore.delete("example1.com") is None + + kr_del_password.side_effect = NoKeyringError() + assert KeyringStore.delete("example2.com") is None + + +@patch("keyring.set_password") +def test_keyring_store_set(kr_set_password: MagicMock): + kr_set_password.return_value = None + assert KeyringStore.store(Credentials(access_token="a", refresh_token="r", for_endpoint="f")) + + kr_set_password.side_effect = NoKeyringError() + assert KeyringStore.retrieve("example2.com") is None diff --git a/tests/flytekit/unit/clients/test_auth_helper.py b/tests/flytekit/unit/clients/test_auth_helper.py new file mode 100644 index 0000000000..8f14de730e --- /dev/null +++ b/tests/flytekit/unit/clients/test_auth_helper.py @@ -0,0 +1,154 @@ +import os.path +from unittest.mock import MagicMock, patch + +import pytest +from flyteidl.service.auth_pb2 import OAuth2MetadataResponse, PublicClientAuthConfigResponse + +from flytekit.clients.auth.authenticator import ( + ClientConfig, + ClientConfigStore, + ClientCredentialsAuthenticator, + CommandAuthenticator, + PKCEAuthenticator, +) +from flytekit.clients.auth.exceptions import AuthenticationError +from flytekit.clients.auth_helper import ( + RemoteClientConfigStore, + get_authenticator, + load_cert, + upgrade_channel_to_authenticated, + wrap_exceptions_channel, +) +from flytekit.clients.grpc_utils.auth_interceptor import AuthUnaryInterceptor +from flytekit.clients.grpc_utils.wrap_exception_interceptor import RetryExceptionWrapperInterceptor +from flytekit.configuration import AuthType, PlatformConfig + +REDIRECT_URI = "http://localhost:53593/callback" + +TOKEN_ENDPOINT = "https://your.domain.io/oauth2/token" + +CLIENT_ID = "flytectl" + +OAUTH_AUTHORIZE = "https://your.domain.io/oauth2/authorize" + + +def get_auth_service_mock() -> MagicMock: + auth_stub_mock = MagicMock() + auth_stub_mock.GetPublicClientConfig.return_value = PublicClientAuthConfigResponse( + client_id=CLIENT_ID, + redirect_uri=REDIRECT_URI, + scopes=["offline", "all"], + authorization_metadata_key="flyte-authorization", + ) + auth_stub_mock.GetOAuth2Metadata.return_value = OAuth2MetadataResponse( + issuer="https://your.domain.io", + authorization_endpoint=OAUTH_AUTHORIZE, + token_endpoint=TOKEN_ENDPOINT, + response_types_supported=["code", "token", "code token"], + scopes_supported=["all"], + token_endpoint_auth_methods_supported=["client_secret_basic"], + jwks_uri="https://your.domain.io/oauth2/jwks", + code_challenge_methods_supported=["S256"], + grant_types_supported=["client_credentials", "refresh_token", "authorization_code"], + ) + return auth_stub_mock + + +@patch("flytekit.clients.auth_helper.AuthMetadataServiceStub") +def test_remote_client_config_store(mock_auth_service: MagicMock): + ch = MagicMock() + cs = RemoteClientConfigStore(ch) + mock_auth_service.return_value = get_auth_service_mock() + + ccfg = cs.get_client_config() + assert ccfg is not None + assert ccfg.client_id == CLIENT_ID + assert ccfg.authorization_endpoint == OAUTH_AUTHORIZE + + +def get_client_config() -> ClientConfigStore: + cfg_store = MagicMock() + cfg_store.get_client_config.return_value = ClientConfig( + token_endpoint=TOKEN_ENDPOINT, + authorization_endpoint=OAUTH_AUTHORIZE, + redirect_uri=REDIRECT_URI, + client_id=CLIENT_ID, + ) + return cfg_store + + +def test_get_authenticator_basic(): + cfg = PlatformConfig(auth_mode=AuthType.BASIC) + + with pytest.raises(ValueError, match="Client ID and Client SECRET both are required"): + get_authenticator(cfg, None) + + cfg = PlatformConfig(auth_mode=AuthType.BASIC, client_credentials_secret="xyz", client_id="id") + authn = get_authenticator(cfg, get_client_config()) + assert authn + assert isinstance(authn, ClientCredentialsAuthenticator) + + cfg = PlatformConfig(auth_mode=AuthType.CLIENT_CREDENTIALS, client_credentials_secret="xyz", client_id="id") + authn = get_authenticator(cfg, get_client_config()) + assert authn + assert isinstance(authn, ClientCredentialsAuthenticator) + + cfg = PlatformConfig(auth_mode=AuthType.CLIENTSECRET, client_credentials_secret="xyz", client_id="id") + authn = get_authenticator(cfg, get_client_config()) + assert authn + assert isinstance(authn, ClientCredentialsAuthenticator) + + +def test_get_authenticator_pkce(): + cfg = PlatformConfig() + authn = get_authenticator(cfg, get_client_config()) + assert authn + assert isinstance(authn, PKCEAuthenticator) + + cfg = PlatformConfig(insecure_skip_verify=True) + authn = get_authenticator(cfg, get_client_config()) + assert authn + assert isinstance(authn, PKCEAuthenticator) + assert authn._verify is False + + cfg = PlatformConfig(ca_cert_file_path="/file") + authn = get_authenticator(cfg, get_client_config()) + assert authn + assert isinstance(authn, PKCEAuthenticator) + assert authn._verify == "/file" + + +def test_get_authenticator_cmd(): + cfg = PlatformConfig(auth_mode=AuthType.EXTERNAL_PROCESS) + with pytest.raises(AuthenticationError): + get_authenticator(cfg, get_client_config()) + + cfg = PlatformConfig(auth_mode=AuthType.EXTERNAL_PROCESS, command=["echo"]) + authn = get_authenticator(cfg, get_client_config()) + assert authn + assert isinstance(authn, CommandAuthenticator) + + cfg = PlatformConfig(auth_mode=AuthType.EXTERNALCOMMAND, command=["echo"]) + authn = get_authenticator(cfg, get_client_config()) + assert authn + assert isinstance(authn, CommandAuthenticator) + assert authn._cmd == ["echo"] + + +def test_wrap_exceptions_channel(): + ch = MagicMock() + out_ch = wrap_exceptions_channel(PlatformConfig(), ch) + assert isinstance(out_ch._interceptor, RetryExceptionWrapperInterceptor) # noqa + + +def test_upgrade_channel_to_auth(): + ch = MagicMock() + out_ch = upgrade_channel_to_authenticated(PlatformConfig(), ch) + assert isinstance(out_ch._interceptor, AuthUnaryInterceptor) # noqa + + +def test_load_cert(): + cert_file = os.path.join(os.path.dirname(__file__), "testdata", "rootCACert.pem") + f = load_cert(cert_file) + assert f + print(f) diff --git a/tests/flytekit/unit/clients/test_raw.py b/tests/flytekit/unit/clients/test_raw.py index 10a7e09333..ee4e516354 100644 --- a/tests/flytekit/unit/clients/test_raw.py +++ b/tests/flytekit/unit/clients/test_raw.py @@ -1,148 +1,9 @@ -import json -from subprocess import CompletedProcess +from unittest import mock -import grpc -import mock -import pytest from flyteidl.admin import project_pb2 as _project_pb2 -from flyteidl.service import auth_pb2 -from mock import MagicMock, patch -from flytekit.clients.raw import ( - RawSynchronousFlyteClient, - _handle_invalid_create_request, - get_basic_authorization_header, - get_token, -) -from flytekit.configuration import AuthType, PlatformConfig -from flytekit.configuration.internal import Credentials - - -def get_admin_stub_mock() -> mock.MagicMock: - auth_stub_mock = mock.MagicMock() - auth_stub_mock.GetPublicClientConfig.return_value = auth_pb2.PublicClientAuthConfigResponse( - client_id="flytectl", - redirect_uri="http://localhost:53593/callback", - scopes=["offline", "all"], - authorization_metadata_key="flyte-authorization", - ) - auth_stub_mock.GetOAuth2Metadata.return_value = auth_pb2.OAuth2MetadataResponse( - issuer="https://your.domain.io", - authorization_endpoint="https://your.domain.io/oauth2/authorize", - token_endpoint="https://your.domain.io/oauth2/token", - response_types_supported=["code", "token", "code token"], - scopes_supported=["all"], - token_endpoint_auth_methods_supported=["client_secret_basic"], - jwks_uri="https://your.domain.io/oauth2/jwks", - code_challenge_methods_supported=["S256"], - grant_types_supported=["client_credentials", "refresh_token", "authorization_code"], - ) - return auth_stub_mock - - -@mock.patch("flytekit.clients.raw.signal_service") -@mock.patch("flytekit.clients.raw.dataproxy_service") -@mock.patch("flytekit.clients.raw.auth_service") -@mock.patch("flytekit.clients.raw._admin_service") -@mock.patch("flytekit.clients.raw.grpc.insecure_channel") -@mock.patch("flytekit.clients.raw.grpc.secure_channel") -def test_client_set_token(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy, mock_signal): - mock_secure_channel.return_value = True - mock_channel.return_value = True - mock_admin.AdminServiceStub.return_value = True - mock_admin_auth.AuthMetadataServiceStub.return_value = get_admin_stub_mock() - client = RawSynchronousFlyteClient(PlatformConfig(endpoint="a.b.com", insecure=True)) - client.set_access_token("abc") - assert client._metadata[0][1] == "Bearer abc" - assert client.check_access_token("abc") - - -@mock.patch("flytekit.clients.raw.RawSynchronousFlyteClient.set_access_token") -@mock.patch("flytekit.clients.raw.auth_service") -@mock.patch("subprocess.run") -def test_refresh_credentials_from_command(mock_call_to_external_process, mock_admin_auth, mock_set_access_token): - token = "token" - command = ["command", "generating", "token"] - - mock_admin_auth.AuthMetadataServiceStub.return_value = get_admin_stub_mock() - client = RawSynchronousFlyteClient(PlatformConfig(command=command)) - - mock_call_to_external_process.return_value = CompletedProcess(command, 0, stdout=token) - client._refresh_credentials_from_command() - - mock_call_to_external_process.assert_called_with(command, capture_output=True, text=True, check=True) - mock_set_access_token.assert_called_with(token, client.public_client_config.authorization_metadata_key) - - -@mock.patch("flytekit.clients.raw.signal_service") -@mock.patch("flytekit.clients.raw.dataproxy_service") -@mock.patch("flytekit.clients.raw.get_basic_authorization_header") -@mock.patch("flytekit.clients.raw.get_token") -@mock.patch("flytekit.clients.raw.auth_service") -@mock.patch("flytekit.clients.raw._admin_service") -@mock.patch("flytekit.clients.raw.grpc.insecure_channel") -@mock.patch("flytekit.clients.raw.grpc.secure_channel") -def test_refresh_client_credentials_aka_basic( - mock_secure_channel, - mock_channel, - mock_admin, - mock_admin_auth, - mock_get_token, - mock_get_basic_header, - mock_dataproxy, - mock_signal, -): - mock_secure_channel.return_value = True - mock_channel.return_value = True - mock_admin.AdminServiceStub.return_value = True - mock_get_basic_header.return_value = "Basic 123" - mock_get_token.return_value = ("token1", 1234567) - - mock_admin_auth.AuthMetadataServiceStub.return_value = get_admin_stub_mock() - client = RawSynchronousFlyteClient( - PlatformConfig( - endpoint="a.b.com", insecure=True, client_credentials_secret="sosecret", scopes=["a", "b", "c", "d"] - ) - ) - client._metadata = None - assert not client.check_access_token("fdsa") - client._refresh_credentials_basic() - - # Scopes from configuration take precendence. - mock_get_token.assert_called_once_with("https://your.domain.io/oauth2/token", "Basic 123", "a,b,c,d") - - client.set_access_token("token") - assert client._metadata[0][0] == "authorization" - - -@mock.patch("flytekit.clients.raw.signal_service") -@mock.patch("flytekit.clients.raw.dataproxy_service") -@mock.patch("flytekit.clients.raw.auth_service") -@mock.patch("flytekit.clients.raw._admin_service") -@mock.patch("flytekit.clients.raw.grpc.insecure_channel") -@mock.patch("flytekit.clients.raw.grpc.secure_channel") -def test_raises(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy, mock_signal): - mock_secure_channel.return_value = True - mock_channel.return_value = True - mock_admin.AdminServiceStub.return_value = True - - # If the public client config is missing then raise an error - mocked_auth = get_admin_stub_mock() - mocked_auth.GetPublicClientConfig.return_value = None - mock_admin_auth.AuthMetadataServiceStub.return_value = mocked_auth - client = RawSynchronousFlyteClient(PlatformConfig(endpoint="a.b.com", insecure=True)) - assert client.public_client_config is None - with pytest.raises(ValueError): - client._refresh_credentials_basic() - - # If the oauth2 metadata is missing then raise an error - mocked_auth = get_admin_stub_mock() - mocked_auth.GetOAuth2Metadata.return_value = None - mock_admin_auth.AuthMetadataServiceStub.return_value = mocked_auth - client = RawSynchronousFlyteClient(PlatformConfig(endpoint="a.b.com", insecure=True)) - assert client.oauth2_metadata is None - with pytest.raises(ValueError): - client._refresh_credentials_basic() +from flytekit.clients.raw import RawSynchronousFlyteClient +from flytekit.configuration import PlatformConfig @mock.patch("flytekit.clients.raw._admin_service") @@ -161,84 +22,3 @@ def test_list_projects_paginated(mock_channel, mock_admin): project_list_request = _project_pb2.ProjectListRequest(limit=100, token="", filters=None, sort_by=None) client.list_projects(project_list_request) mock_admin.AdminServiceStub().ListProjects.assert_called_with(project_list_request, metadata=None) - - -def test_get_basic_authorization_header(): - header = get_basic_authorization_header("client_id", "abc") - assert header == "Basic Y2xpZW50X2lkOmFiYw==" - - -@patch("flytekit.clients.raw._requests") -def test_get_token(mock_requests): - response = MagicMock() - response.status_code = 200 - response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") - mock_requests.post.return_value = response - access, expiration = get_token("https://corp.idp.net", "abc123", "my_scope") - assert access == "abc" - assert expiration == 60 - - -@patch.object(RawSynchronousFlyteClient, "_refresh_credentials_standard") -def test_refresh_standard(mocked_method): - cc = RawSynchronousFlyteClient(PlatformConfig()) - cc.refresh_credentials() - assert mocked_method.called - - -@patch.object(RawSynchronousFlyteClient, "_refresh_credentials_basic") -def test_refresh_basic(mocked_method): - cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode=AuthType.BASIC)) - cc.refresh_credentials() - assert mocked_method.called - - cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode=AuthType.CLIENT_CREDENTIALS)) - cc.refresh_credentials() - assert mocked_method.call_count == 2 - - -@patch.object(RawSynchronousFlyteClient, "_refresh_credentials_basic") -def test_basic_strings(mocked_method): - cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode="basic")) - cc.refresh_credentials() - assert mocked_method.called - - cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode="client_credentials")) - cc.refresh_credentials() - assert mocked_method.call_count == 2 - - -@patch.object(RawSynchronousFlyteClient, "_refresh_credentials_from_command") -def test_refresh_command(mocked_method): - cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode=AuthType.EXTERNALCOMMAND)) - cc.refresh_credentials() - assert mocked_method.called - - -@patch.object(RawSynchronousFlyteClient, "_refresh_credentials_from_command") -def test_refresh_from_environment_variable(mocked_method, monkeypatch: pytest.MonkeyPatch): - monkeypatch.setenv(Credentials.AUTH_MODE.legacy.get_env_name(), AuthType.EXTERNAL_PROCESS.name, prepend=False) - cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode=None).auto(None)) - cc.refresh_credentials() - assert mocked_method.called - - -def test__handle_invalid_create_request_decorator_happy(): - client = RawSynchronousFlyteClient(PlatformConfig(auth_mode=AuthType.CLIENT_CREDENTIALS)) - mocked_method = client._stub.CreateWorkflow = mock.Mock() - _handle_invalid_create_request(client.create_workflow("/flyteidl.service.AdminService/CreateWorkflow")) - mocked_method.assert_called_once() - - -@patch("flytekit.clients.raw.cli_logger") -@patch("flytekit.clients.raw._MessageToJson") -def test__handle_invalid_create_request_decorator_raises(mock_to_JSON, mock_logger): - mock_to_JSON(return_value="test") - err = grpc.RpcError() - err.details = "There is already a workflow with different structure." - err.code = lambda: grpc.StatusCode.INVALID_ARGUMENT - client = RawSynchronousFlyteClient(PlatformConfig(auth_mode=AuthType.CLIENT_CREDENTIALS)) - client._stub.CreateWorkflow = mock.Mock(side_effect=err) - with pytest.raises(grpc.RpcError): - _handle_invalid_create_request(client.create_workflow("/flyteidl.service.AdminService/CreateWorkflow")) - mock_logger.error.assert_called_with("There is already a workflow with different structure.") diff --git a/tests/flytekit/unit/clients/testdata/rootCACert.pem b/tests/flytekit/unit/clients/testdata/rootCACert.pem new file mode 100644 index 0000000000..e9117bced2 --- /dev/null +++ b/tests/flytekit/unit/clients/testdata/rootCACert.pem @@ -0,0 +1,17 @@ +-----BEGIN CERTIFICATE----- +MIICsDCCAZgCCQC8mxWHhdmIjDANBgkqhkiG9w0BAQsFADAaMQswCQYDVQQGEwJC +UjELMAkGA1UECAwCV0EwHhcNMjMwMjIyMDU1MTAwWhcNMzMwMjE5MDU1MTAwWjAa +MQswCQYDVQQGEwJCUjELMAkGA1UECAwCV0EwggEiMA0GCSqGSIb3DQEBAQUAA4IB +DwAwggEKAoIBAQDru2F+xX/Z4Q2W6h5qzwZUpCUjkgjQgSZPHI11hfLqOemXoW9o +66LtU/QOeMXSLGnhq7WSILqVz9sCmRvQmxYTFuur8eQXPEpqggOVWfJQ6vj4ssvf +a5j5KWytM9ixgwUmw9xwAAQ4FC1LQsYyD44Lkb+OPJZEK55ZOyGAPblVJW9WSvnX +DsWXbsDENMU7XlpMtVI5/tToBLaSKIMhbZlssCJtjc7omBTL14yy9L7Bj321/a0m +8SRH/XtfmTux/HHq60qTvUWiHrL+CjehZQcehGlKpqaYq5sEQCVWt/cfyvHLo0gd +X6ayAfmno8WeNbXqsuoU+xckI+8WI4vQ2cfnAgMBAAEwDQYJKoZIhvcNAQELBQAD +ggEBAB++cNraohRh91Xa1GW6IQAInyDnxsXSY0wXrpJsAN541lETdk9+L1CbtHxk +5hwvrkUItzrUjIIKq11k0NceM3ClYyO826By/DMjWtPMYp0eSXLKJJnAT5euhONy +9eBLyNFO0yUj77fEiEj6k5PAUBCgs6ZzWTVCgBiNKPAT6WxaYeIwXdQvC0KoJ0t7 +0SD7/I4i9SSjw3lCRZfMKdd7MEPTpi5hXpZPphg9HYJX5o1KSjWvMTYDUzaQtOlf +GM9zNSXug/GyYgVgUyg2dqp/ohbMtqgFH1kTbvMlLmS6BtQxyi11G6QWv6gdb8z5 +es7Wv+5ZqVjroswzEGi/h72Xo0E= +-----END CERTIFICATE----- diff --git a/tests/flytekit/unit/clients/testdata/rootCAKey.pem b/tests/flytekit/unit/clients/testdata/rootCAKey.pem new file mode 100644 index 0000000000..0f71fce5f1 --- /dev/null +++ b/tests/flytekit/unit/clients/testdata/rootCAKey.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEA67thfsV/2eENluoeas8GVKQlI5II0IEmTxyNdYXy6jnpl6Fv +aOui7VP0DnjF0ixp4au1kiC6lc/bApkb0JsWExbrq/HkFzxKaoIDlVnyUOr4+LLL +32uY+SlsrTPYsYMFJsPccAAEOBQtS0LGMg+OC5G/jjyWRCueWTshgD25VSVvVkr5 +1w7Fl27AxDTFO15aTLVSOf7U6AS2kiiDIW2ZbLAibY3O6JgUy9eMsvS+wY99tf2t +JvEkR/17X5k7sfxx6utKk71Foh6y/go3oWUHHoRpSqammKubBEAlVrf3H8rxy6NI +HV+msgH5p6PFnjW16rLqFPsXJCPvFiOL0NnH5wIDAQABAoIBAA+cVRSEF7diA/he +gK0qEI1CYYM9hH/qTZMnnOaPfEqukx2Lf0k/cYat7JeYv+DvOAPNzzRiHnkVTreZ +VBI4cvnIpsq4Nhaj03nCKmKVlkpthRdTH9Un1vWJHL1LlaoLtyeeCNcR6TWdgHJf +daiTByEVAc51jK3vBYl7NPi9HazZsT8p1IxN8pXrM4yXWZ2UDhVkNfiylmg+qsnt +s5C8ENmOoAEeUxfQjPUK72UrQAZM7+N+yGvXPHTkJAEs2yx25NbTUZugXa6+Ehlx +55RuXLbDB0cn/eUp0wmLAryYekTm0AjPy4wsvwoKAkxSnIMdfj6kO0EbEYZv5yFC +aPTeJyECgYEA+h1L3dqnhykNvu6xk1T6fmJ1up8xYH7syToz/QXTlNHJKYQnmZDt +u2Y6jmOC4Tv892/CAeDJZqw0KGJb94sDzD4oE0etf79SyfJ3Gh0BBU9P4W1r2uPx +VNak8V/g/qAfP2cwZz/h9jPWkQ1UcfZa1Mpali4EEHdODAwqxXouHTsCgYEA8Udx +jzyrQuwINJMKfzws52QYsuuy2IUbhfbz9r7lqauvbuYx0PCUH/fnwxUCqjr1BS/l +jdy1SjlcvJ1lHsuJ6YaHu30FjNcDiyr0gAEXQeuIpIVD74BdTNJpKoy2AWIdIYtj +wbjm95V3XOdo/tmFYx+6Ign702Lmi+gkDtDxRUUCgYEAwaLAw6eun5OHEtTVIc1e +iU5M+wiYP67EPx4Sdcd3APZRmRS5W8i6ZKVGnEoqX5oDxMT/HFkdU6HqV4Ge1c0I +Sa2tdQ+/IPHMdJCE6PCfg67dlxcRs0tZ4Wa0GDM0i60HxBxteuIYXHXRnkcFo50o +wSlQbIh/mQfkoqsgyfZHkVUCgYAxiCcp7pyB+o6crGsFP8dAIW5onLZ0eK7zy4S9 +7Oac9F/pdlxXtmvSPERZ6iBH7h6K2BBaFSsqd6gwGGe/8Kz5QeLvfHT9Os7BbSoQ +dSjfIYlFrQ4LRuDgenmYgJaEpi2wyzrJdDoGLar5aZBGcUVO2h6OClqmRLFrm1Z7 +rC07uQKBgBqrNfLs0aBolWXhD5KcWl7Eg+lr8Zm3htYTXSZaQwpuDLQ/SA7F02fu +UWbJVNN44xeJZdFtygxqPLOXrRoNYuYwJ6A5SIjERFiwg7kxGqlWKMoVs0HFFsEb +noBtXOzyx5GEyghotTaAr/wdS7eY8ccmTsKm15td1MLUxjK1nlS/ +-----END RSA PRIVATE KEY-----