diff --git a/dev-requirements.in b/dev-requirements.in
index 755231ed71..25e6cf51e7 100644
--- a/dev-requirements.in
+++ b/dev-requirements.in
@@ -12,6 +12,7 @@ codespell
google-cloud-bigquery
google-cloud-bigquery-storage
IPython
+keyrings.alt
# Only install tensorflow if not running on an arm Mac.
tensorflow==2.8.1; platform_machine!='arm64' or platform_system!='Darwin'
diff --git a/flytekit/clis/auth/__init__.py b/flytekit/clients/auth/__init__.py
similarity index 100%
rename from flytekit/clis/auth/__init__.py
rename to flytekit/clients/auth/__init__.py
diff --git a/flytekit/clis/auth/auth.py b/flytekit/clients/auth/auth_client.py
similarity index 56%
rename from flytekit/clis/auth/auth.py
rename to flytekit/clients/auth/auth_client.py
index f54379485a..94afa13612 100644
--- a/flytekit/clis/auth/auth.py
+++ b/flytekit/clients/auth/auth_client.py
@@ -1,32 +1,31 @@
+from __future__ import annotations
+
import base64 as _base64
import hashlib as _hashlib
import http.server as _BaseHTTPServer
+import logging
+import multiprocessing
import os as _os
import re as _re
+import typing
import urllib.parse as _urlparse
import webbrowser as _webbrowser
+from dataclasses import dataclass
from http import HTTPStatus as _StatusCodes
-from multiprocessing import get_context as _mp_get_context
+from multiprocessing import get_context
from urllib.parse import urlencode as _urlencode
-import keyring as _keyring
import requests as _requests
-from flytekit.loggers import auth_logger
+from .default_html import get_default_success_html
+from .exceptions import AccessTokenNotFoundError
+from .keyring import Credentials
_code_verifier_length = 64
_random_seed_length = 40
_utf_8 = "utf-8"
-# Identifies the service used for storing passwords in keyring
-_keyring_service_name = "flyteauth"
-# Identifies the key used for storing and fetching from keyring. In our case, instead of a username as the keyring docs
-# suggest, we are storing a user's oidc.
-_keyring_access_token_storage_key = "access_token"
-_keyring_refresh_token_storage_key = "refresh_token"
-
-
def _generate_code_verifier():
"""
Generates a 'code_verifier' as described in https://tools.ietf.org/html/rfc7636#section-4.1
@@ -77,6 +76,17 @@ def state(self):
return self._state
+@dataclass
+class EndpointMetadata(object):
+ """
+ This class can be used to control the rendering of the page on login successful or failure
+ """
+
+ endpoint: str
+ success_html: typing.Optional[bytes] = None
+ failure_html: typing.Optional[bytes] = None
+
+
class OAuthCallbackHandler(_BaseHTTPServer.BaseHTTPRequestHandler):
"""
A simple wrapper around BaseHTTPServer.BaseHTTPRequestHandler that handles a callback URL that accepts an
@@ -87,12 +97,16 @@ def do_GET(self):
url = _urlparse.urlparse(self.path)
if url.path.strip("/") == self.server.redirect_path.strip("/"):
self.send_response(_StatusCodes.OK)
+ self.send_header("Content-type", "text/html")
self.end_headers()
self.handle_login(dict(_urlparse.parse_qsl(url.query)))
+ if self.server.remote_metadata.success_html is None:
+ self.wfile.write(bytes(get_default_success_html(self.server.remote_metadata.endpoint), "utf-8"))
+ self.wfile.flush()
else:
self.send_response(_StatusCodes.NOT_FOUND)
- def handle_login(self, data):
+ def handle_login(self, data: dict):
self.server.handle_authorization_code(AuthorizationCode(data["code"], data["state"]))
@@ -104,49 +118,97 @@ class OAuthHTTPServer(_BaseHTTPServer.HTTPServer):
def __init__(
self,
- server_address,
- RequestHandlerClass,
- bind_and_activate=True,
- redirect_path=None,
- queue=None,
+ server_address: typing.Tuple[str, int],
+ remote_metadata: EndpointMetadata,
+ request_handler_class: typing.Type[_BaseHTTPServer.BaseHTTPRequestHandler],
+ bind_and_activate: bool = True,
+ redirect_path: str = None,
+ queue: multiprocessing.Queue = None,
):
- _BaseHTTPServer.HTTPServer.__init__(self, server_address, RequestHandlerClass, bind_and_activate)
+ _BaseHTTPServer.HTTPServer.__init__(self, server_address, request_handler_class, bind_and_activate)
self._redirect_path = redirect_path
+ self._remote_metadata = remote_metadata
self._auth_code = None
self._queue = queue
@property
- def redirect_path(self):
+ def redirect_path(self) -> str:
return self._redirect_path
- def handle_authorization_code(self, auth_code):
+ @property
+ def remote_metadata(self) -> EndpointMetadata:
+ return self._remote_metadata
+
+ def handle_authorization_code(self, auth_code: str):
self._queue.put(auth_code)
self.server_close()
- def handle_request(self, queue=None):
+ def handle_request(self, queue: multiprocessing.Queue = None) -> typing.Any:
self._queue = queue
return super().handle_request()
-class Credentials(object):
- def __init__(self, access_token=None):
- self._access_token = access_token
+class _SingletonPerEndpoint(type):
+ """
+ A metaclass to create per endpoint singletons for AuthorizationClient objects
+ """
+
+ _instances: typing.Dict[str, AuthorizationClient] = {}
- @property
- def access_token(self):
- return self._access_token
+ def __call__(cls, *args, **kwargs):
+ endpoint = ""
+ if args:
+ endpoint = args[0]
+ elif "auth_endpoint" in kwargs:
+ endpoint = kwargs["auth_endpoint"]
+ else:
+ raise ValueError("parameter auth_endpoint is required")
+ if endpoint not in cls._instances:
+ cls._instances[endpoint] = super(_SingletonPerEndpoint, cls).__call__(*args, **kwargs)
+ return cls._instances[endpoint]
-class AuthorizationClient(object):
+class AuthorizationClient(metaclass=_SingletonPerEndpoint):
+ """
+ Authorization client that stores the credentials in keyring and uses oauth2 standard flow to retrieve the
+ credentials. NOTE: This will open an web browser to retreive the credentials.
+ """
+
def __init__(
self,
- auth_endpoint=None,
- token_endpoint=None,
- scopes=None,
- client_id=None,
- redirect_uri=None,
+ endpoint: str,
+ auth_endpoint: str,
+ token_endpoint: str,
+ scopes: typing.Optional[typing.List[str]] = None,
+ client_id: typing.Optional[str] = None,
+ redirect_uri: typing.Optional[str] = None,
+ endpoint_metadata: typing.Optional[EndpointMetadata] = None,
+ verify: typing.Optional[typing.Union[bool, str]] = None,
):
+ """
+ Create new AuthorizationClient
+
+ :param endpoint: str endpoint to connect to
+ :param auth_endpoint: str endpoint where auth metadata can be found
+ :param token_endpoint: str endpoint to retrieve token from
+ :param scopes: list[str] oauth2 scopes
+ :param client_id
+ :param verify: (optional) Either a boolean, in which case it controls whether we verify
+ the server's TLS certificate, or a string, in which case it must be a path
+ to a CA bundle to use. Defaults to ``True``. When set to
+ ``False``, requests will accept any TLS certificate presented by
+ the server, and will ignore hostname mismatches and/or expired
+ certificates, which will make your application vulnerable to
+ man-in-the-middle (MitM) attacks. Setting verify to ``False``
+ may be useful during local development or testing.
+ """
+ self._endpoint = endpoint
self._auth_endpoint = auth_endpoint
+ if endpoint_metadata is None:
+ remote_url = _urlparse.urlparse(self._auth_endpoint)
+ self._remote = EndpointMetadata(endpoint=remote_url.hostname)
+ else:
+ self._remote = endpoint_metadata
self._token_endpoint = token_endpoint
self._client_id = client_id
self._scopes = scopes or []
@@ -156,10 +218,8 @@ def __init__(
self._code_challenge = code_challenge
state = _generate_state_parameter()
self._state = state
- self._credentials = None
- self._refresh_token = None
+ self._verify = verify
self._headers = {"content-type": "application/x-www-form-urlencoded"}
- self._expired = False
self._params = {
"client_id": client_id, # This must match the Client ID of the OAuth application.
@@ -174,55 +234,27 @@ def __init__(
"code_challenge_method": "S256",
}
- # Prefer to use already-fetched token values when they've been set globally.
- self.set_tokens_from_store()
-
def __repr__(self):
return f"AuthorizationClient({self._auth_endpoint}, {self._token_endpoint}, {self._client_id}, {self._scopes}, {self._redirect_uri})"
- def set_tokens_from_store(self):
- self._refresh_token = _keyring.get_password(_keyring_service_name, _keyring_refresh_token_storage_key)
- access_token = _keyring.get_password(_keyring_service_name, _keyring_access_token_storage_key)
- if access_token:
- self._credentials = Credentials(access_token=access_token)
-
- @property
- def has_valid_credentials(self) -> bool:
- return self._credentials is not None
-
- def start_authorization_flow(self):
- # In the absence of globally-set token values, initiate the token request flow
- ctx = _mp_get_context("fork")
- q = ctx.Queue()
-
- # First prepare the callback server in the background
- server = self._create_callback_server()
- server_process = ctx.Process(target=server.handle_request, args=(q,))
- server_process.daemon = True
- server_process.start()
-
- # Send the call to request the authorization code in the background
- self._request_authorization_code()
-
- # Request the access token once the auth code has been received.
- auth_code = q.get()
- server_process.terminate()
- self.request_access_token(auth_code)
-
def _create_callback_server(self):
server_url = _urlparse.urlparse(self._redirect_uri)
server_address = (server_url.hostname, server_url.port)
- return OAuthHTTPServer(server_address, OAuthCallbackHandler, redirect_path=server_url.path)
+ return OAuthHTTPServer(
+ server_address,
+ self._remote,
+ OAuthCallbackHandler,
+ redirect_path=server_url.path,
+ )
def _request_authorization_code(self):
scheme, netloc, path, _, _, _ = _urlparse.urlparse(self._auth_endpoint)
query = _urlencode(self._params)
endpoint = _urlparse.urlunparse((scheme, netloc, path, None, query, None))
- auth_logger.debug(f"Requesting authorization code through {endpoint}")
+ logging.debug(f"Requesting authorization code through {endpoint}")
_webbrowser.open_new_tab(endpoint)
- def _initialize_credentials(self, auth_token_resp):
-
+ def _credentials_from_response(self, auth_token_resp) -> Credentials:
"""
The auth_token_resp body is of the form:
{
@@ -232,21 +264,16 @@ def _initialize_credentials(self, auth_token_resp):
}
"""
response_body = auth_token_resp.json()
+ refresh_token = None
if "access_token" not in response_body:
raise ValueError('Expected "access_token" in response from oauth server')
if "refresh_token" in response_body:
- self._refresh_token = response_body["refresh_token"]
- _keyring.set_password(
- _keyring_service_name, _keyring_refresh_token_storage_key, response_body["refresh_token"]
- )
-
+ refresh_token = response_body["refresh_token"]
access_token = response_body["access_token"]
- _keyring.set_password(_keyring_service_name, _keyring_access_token_storage_key, access_token)
- # Once keyring credentials have been updated, get the singleton AuthorizationClient to read them again.
- self.set_tokens_from_store()
+ return Credentials(access_token, refresh_token, self._endpoint)
- def request_access_token(self, auth_code):
+ def _request_access_token(self, auth_code) -> Credentials:
if self._state != auth_code.state:
raise ValueError(f"Unexpected state parameter [{auth_code.state}] passed")
self._params.update(
@@ -262,6 +289,7 @@ def request_access_token(self, auth_code):
data=self._params,
headers=self._headers,
allow_redirects=False,
+ verify=self._verify,
)
if resp.status_code != _StatusCodes.OK:
# TODO: handle expected (?) error cases:
@@ -269,38 +297,53 @@ def request_access_token(self, auth_code):
raise Exception(
"Failed to request access token with response: [{}] {}".format(resp.status_code, resp.content)
)
- self._initialize_credentials(resp)
+ return self._credentials_from_response(resp)
+
+ def get_creds_from_remote(self) -> Credentials:
+ """
+ This is the entrypoint method. It will kickoff the full authentication flow and trigger a web-browser to
+ retrieve credentials
+ """
+ # In the absence of globally-set token values, initiate the token request flow
+ ctx = get_context("fork")
+ q = ctx.Queue()
+
+ # First prepare the callback server in the background
+ server = self._create_callback_server()
+
+ server_process = ctx.Process(target=server.handle_request, args=(q,))
+ server_process.daemon = True
- def refresh_access_token(self):
- if self._refresh_token is None:
+ try:
+ server_process.start()
+
+ # Send the call to request the authorization code in the background
+ self._request_authorization_code()
+
+ # Request the access token once the auth code has been received.
+ auth_code = q.get()
+ return self._request_access_token(auth_code)
+ finally:
+ server_process.terminate()
+
+ def refresh_access_token(self, credentials: Credentials) -> Credentials:
+ if credentials.refresh_token is None:
raise ValueError("no refresh token available with which to refresh authorization credentials")
resp = _requests.post(
url=self._token_endpoint,
- data={"grant_type": "refresh_token", "client_id": self._client_id, "refresh_token": self._refresh_token},
+ data={
+ "grant_type": "refresh_token",
+ "client_id": self._client_id,
+ "refresh_token": credentials.refresh_token,
+ },
headers=self._headers,
allow_redirects=False,
+ verify=self._verify,
)
if resp.status_code != _StatusCodes.OK:
- self._expired = True
# In the absence of a successful response, assume the refresh token is expired. This should indicate
# to the caller that the AuthorizationClient is defunct and a new one needs to be re-initialized.
+ raise AccessTokenNotFoundError(f"Non-200 returned from refresh token endpoint {resp.status_code}")
- _keyring.delete_password(_keyring_service_name, _keyring_access_token_storage_key)
- _keyring.delete_password(_keyring_service_name, _keyring_refresh_token_storage_key)
- raise ValueError(f"Non-200 returned from refresh token endpoint {resp.status_code}")
- self._initialize_credentials(resp)
-
- @property
- def credentials(self):
- """
- :return flytekit.clis.auth.auth.Credentials:
- """
- return self._credentials
-
- @property
- def expired(self):
- """
- :return bool:
- """
- return self._expired
+ return self._credentials_from_response(resp)
diff --git a/flytekit/clients/auth/authenticator.py b/flytekit/clients/auth/authenticator.py
new file mode 100644
index 0000000000..183c1787cd
--- /dev/null
+++ b/flytekit/clients/auth/authenticator.py
@@ -0,0 +1,235 @@
+import base64
+import logging
+import subprocess
+import typing
+from abc import abstractmethod
+from dataclasses import dataclass
+
+import requests
+
+from .auth_client import AuthorizationClient
+from .exceptions import AccessTokenNotFoundError, AuthenticationError
+from .keyring import Credentials, KeyringStore
+
+
+@dataclass
+class ClientConfig:
+ """
+ Client Configuration that is needed by the authenticator
+ """
+
+ token_endpoint: str
+ authorization_endpoint: str
+ redirect_uri: str
+ client_id: str
+ scopes: typing.List[str] = None
+ header_key: str = "authorization"
+
+
+class ClientConfigStore(object):
+ """
+ Client Config store retrieve client config. this can be done in multiple ways
+ """
+
+ @abstractmethod
+ def get_client_config(self) -> ClientConfig:
+ ...
+
+
+class StaticClientConfigStore(ClientConfigStore):
+ def __init__(self, cfg: ClientConfig):
+ self._cfg = cfg
+
+ def get_client_config(self) -> ClientConfig:
+ return self._cfg
+
+
+class Authenticator(object):
+ """
+ Base authenticator for all authentication flows
+ """
+
+ def __init__(self, endpoint: str, header_key: str, credentials: Credentials = None):
+ self._endpoint = endpoint
+ self._creds = credentials
+ self._header_key = header_key if header_key else "authorization"
+
+ def get_credentials(self) -> Credentials:
+ return self._creds
+
+ def _set_credentials(self, creds):
+ self._creds = creds
+
+ def _set_header_key(self, h: str):
+ self._header_key = h
+
+ def fetch_grpc_call_auth_metadata(self) -> typing.Optional[typing.Tuple[str, str]]:
+ if self._creds:
+ return self._header_key, f"Bearer {self._creds.access_token}"
+ return None
+
+ @abstractmethod
+ def refresh_credentials(self):
+ ...
+
+
+class PKCEAuthenticator(Authenticator):
+ """
+ This Authenticator encapsulates the entire PKCE flow and automatically opens a browser window for login
+ """
+
+ def __init__(
+ self,
+ endpoint: str,
+ cfg_store: ClientConfigStore,
+ header_key: typing.Optional[str] = None,
+ verify: typing.Optional[typing.Union[bool, str]] = None,
+ ):
+ """
+ Initialize with default creds from KeyStore using the endpoint name
+ """
+ super().__init__(endpoint, header_key, KeyringStore.retrieve(endpoint))
+ self._cfg_store = cfg_store
+ self._auth_client = None
+ self._verify = verify
+
+ def _initialize_auth_client(self):
+ if not self._auth_client:
+ cfg = self._cfg_store.get_client_config()
+ self._set_header_key(cfg.header_key)
+ self._auth_client = AuthorizationClient(
+ endpoint=self._endpoint,
+ redirect_uri=cfg.redirect_uri,
+ client_id=cfg.client_id,
+ scopes=cfg.scopes,
+ auth_endpoint=cfg.authorization_endpoint,
+ token_endpoint=cfg.token_endpoint,
+ verify=self._verify,
+ )
+
+ def refresh_credentials(self):
+ """ """
+ self._initialize_auth_client()
+ if self._creds:
+ """We have an access token so lets try to refresh it"""
+ try:
+ self._creds = self._auth_client.refresh_access_token(self._creds)
+ if self._creds:
+ KeyringStore.store(self._creds)
+ return
+ except AccessTokenNotFoundError:
+ logging.warning("Failed to refresh token. Kicking off a full authorization flow.")
+ KeyringStore.delete(self._endpoint)
+
+ self._creds = self._auth_client.get_creds_from_remote()
+ KeyringStore.store(self._creds)
+
+
+class CommandAuthenticator(Authenticator):
+ """
+ This Authenticator retreives access_token using the provided command
+ """
+
+ def __init__(self, command: typing.List[str], header_key: str = None):
+ self._cmd = command
+ if not self._cmd:
+ raise AuthenticationError("Command cannot be empty for command authenticator")
+ super().__init__(None, header_key)
+
+ def refresh_credentials(self):
+ """
+ This function is used when the configuration value for AUTH_MODE is set to 'external_process'.
+ It reads an id token generated by an external process started by running the 'command'.
+ """
+ logging.debug("Starting external process to generate id token. Command {}".format(self._cmd))
+ try:
+ output = subprocess.run(self._cmd, capture_output=True, text=True, check=True)
+ except subprocess.CalledProcessError as e:
+ logging.error("Failed to generate token from command {}".format(self._cmd))
+ raise AuthenticationError("Problems refreshing token with command: " + str(e))
+ self._creds = Credentials(output.stdout.strip())
+
+
+class ClientCredentialsAuthenticator(Authenticator):
+ """
+ This Authenticator uses ClientId and ClientSecret to authenticate
+ """
+
+ _utf_8 = "utf-8"
+
+ def __init__(
+ self,
+ endpoint: str,
+ client_id: str,
+ client_secret: str,
+ cfg_store: ClientConfigStore,
+ header_key: str = None,
+ ):
+ if not client_id or not client_secret:
+ raise ValueError("Client ID and Client SECRET both are required.")
+ cfg = cfg_store.get_client_config()
+ self._token_endpoint = cfg.token_endpoint
+ self._scopes = cfg.scopes
+ self._client_id = client_id
+ self._client_secret = client_secret
+ super().__init__(endpoint, cfg.header_key or header_key)
+
+ @staticmethod
+ def get_token(token_endpoint: str, authorization_header: str, scopes: typing.List[str]) -> typing.Tuple[str, int]:
+ """
+ :rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration
+ in seconds
+ """
+ headers = {
+ "Authorization": authorization_header,
+ "Cache-Control": "no-cache",
+ "Accept": "application/json",
+ "Content-Type": "application/x-www-form-urlencoded",
+ }
+ body = {
+ "grant_type": "client_credentials",
+ }
+ if scopes is not None:
+ body["scope"] = ",".join(scopes)
+ response = requests.post(token_endpoint, data=body, headers=headers)
+ if response.status_code != 200:
+ logging.error("Non-200 ({}) received from IDP: {}".format(response.status_code, response.text))
+ raise AuthenticationError("Non-200 received from IDP")
+
+ response = response.json()
+ return response["access_token"], response["expires_in"]
+
+ @staticmethod
+ def get_basic_authorization_header(client_id: str, client_secret: str) -> str:
+ """
+ This function transforms the client id and the client secret into a header that conforms with http basic auth.
+ It joins the id and the secret with a : then base64 encodes it, then adds the appropriate text
+
+ :param client_id: str
+ :param client_secret: str
+ :rtype: str
+ """
+ concated = "{}:{}".format(client_id, client_secret)
+ return "Basic {}".format(
+ base64.b64encode(concated.encode(ClientCredentialsAuthenticator._utf_8)).decode(
+ ClientCredentialsAuthenticator._utf_8
+ )
+ )
+
+ def refresh_credentials(self):
+ """
+ This function is used by the _handle_rpc_error() decorator, depending on the AUTH_MODE config object. This handler
+ is meant for SDK use-cases of auth (like pyflyte, or when users call SDK functions that require access to Admin,
+ like when waiting for another workflow to complete from within a task). This function uses basic auth, which means
+ the credentials for basic auth must be present from wherever this code is running.
+
+ """
+ token_endpoint = self._token_endpoint
+ scopes = self._scopes
+
+ # Note that unlike the Pkce flow, the client ID does not come from Admin.
+ logging.debug(f"Basic authorization flow with client id {self._client_id} scope {scopes}")
+ authorization_header = self.get_basic_authorization_header(self._client_id, self._client_secret)
+ token, expires_in = self.get_token(token_endpoint, authorization_header, scopes)
+ logging.info("Retrieved new token, expires in {}".format(expires_in))
+ self._creds = Credentials(token)
diff --git a/flytekit/clients/auth/default_html.py b/flytekit/clients/auth/default_html.py
new file mode 100644
index 0000000000..cd7f3d8964
--- /dev/null
+++ b/flytekit/clients/auth/default_html.py
@@ -0,0 +1,12 @@
+def get_default_success_html(endpoint: str) -> str:
+ return f"""
+
+
+ OAuth2 Authentication Success
+
+
+ Successfully logged into {endpoint}
+
+
+
+""" # noqa
diff --git a/flytekit/clients/auth/exceptions.py b/flytekit/clients/auth/exceptions.py
new file mode 100644
index 0000000000..6e790e47a4
--- /dev/null
+++ b/flytekit/clients/auth/exceptions.py
@@ -0,0 +1,14 @@
+class AccessTokenNotFoundError(RuntimeError):
+ """
+ This error is raised with Access token is not found or if Refreshing the token fails
+ """
+
+ pass
+
+
+class AuthenticationError(RuntimeError):
+ """
+ This is raised for any AuthenticationError
+ """
+
+ pass
diff --git a/flytekit/clients/auth/keyring.py b/flytekit/clients/auth/keyring.py
new file mode 100644
index 0000000000..c2b19c46b6
--- /dev/null
+++ b/flytekit/clients/auth/keyring.py
@@ -0,0 +1,64 @@
+import logging
+import typing
+from dataclasses import dataclass
+
+import keyring as _keyring
+from keyring.errors import NoKeyringError
+
+
+@dataclass
+class Credentials(object):
+ """
+ Stores the credentials together
+ """
+
+ access_token: str
+ refresh_token: str = "na"
+ for_endpoint: str = "flyte-default"
+
+
+class KeyringStore:
+ """
+ Methods to access Keyring Store.
+ """
+
+ _access_token_key = "access_token"
+ _refresh_token_key = "refresh_token"
+
+ @staticmethod
+ def store(credentials: Credentials) -> Credentials:
+ try:
+ _keyring.set_password(
+ credentials.for_endpoint,
+ KeyringStore._refresh_token_key,
+ credentials.refresh_token,
+ )
+ _keyring.set_password(
+ credentials.for_endpoint,
+ KeyringStore._access_token_key,
+ credentials.access_token,
+ )
+ except NoKeyringError as e:
+ logging.warning(f"KeyRing not available, tokens will not be cached. Error: {e}")
+ return credentials
+
+ @staticmethod
+ def retrieve(for_endpoint: str) -> typing.Optional[Credentials]:
+ try:
+ refresh_token = _keyring.get_password(for_endpoint, KeyringStore._refresh_token_key)
+ access_token = _keyring.get_password(for_endpoint, KeyringStore._access_token_key)
+ except NoKeyringError as e:
+ logging.warning(f"KeyRing not available, tokens will not be cached. Error: {e}")
+ return None
+
+ if not access_token:
+ return None
+ return Credentials(access_token, refresh_token, for_endpoint)
+
+ @staticmethod
+ def delete(for_endpoint: str):
+ try:
+ _keyring.delete_password(for_endpoint, KeyringStore._access_token_key)
+ _keyring.delete_password(for_endpoint, KeyringStore._refresh_token_key)
+ except NoKeyringError as e:
+ logging.warning(f"KeyRing not available, tokens will not be cached. Error: {e}")
diff --git a/flytekit/clients/auth_helper.py b/flytekit/clients/auth_helper.py
new file mode 100644
index 0000000000..41fc5c025f
--- /dev/null
+++ b/flytekit/clients/auth_helper.py
@@ -0,0 +1,193 @@
+import logging
+import ssl
+
+import grpc
+from flyteidl.service.auth_pb2 import OAuth2MetadataRequest, PublicClientAuthConfigRequest
+from flyteidl.service.auth_pb2_grpc import AuthMetadataServiceStub
+from OpenSSL import crypto
+
+from flytekit.clients.auth.authenticator import (
+ Authenticator,
+ ClientConfig,
+ ClientConfigStore,
+ ClientCredentialsAuthenticator,
+ CommandAuthenticator,
+ PKCEAuthenticator,
+)
+from flytekit.clients.grpc_utils.auth_interceptor import AuthUnaryInterceptor
+from flytekit.clients.grpc_utils.wrap_exception_interceptor import RetryExceptionWrapperInterceptor
+from flytekit.configuration import AuthType, PlatformConfig
+
+
+class RemoteClientConfigStore(ClientConfigStore):
+ """
+ This class implements the ClientConfigStore that is served by the Flyte Server, that implements AuthMetadataService
+ """
+
+ def __init__(self, secure_channel: grpc.Channel):
+ self._secure_channel = secure_channel
+
+ def get_client_config(self) -> ClientConfig:
+ """
+ Retrieves the ClientConfig from the given grpc.Channel assuming AuthMetadataService is available
+ """
+ metadata_service = AuthMetadataServiceStub(self._secure_channel)
+ public_client_config = metadata_service.GetPublicClientConfig(PublicClientAuthConfigRequest())
+ oauth2_metadata = metadata_service.GetOAuth2Metadata(OAuth2MetadataRequest())
+ return ClientConfig(
+ token_endpoint=oauth2_metadata.token_endpoint,
+ authorization_endpoint=oauth2_metadata.authorization_endpoint,
+ redirect_uri=public_client_config.redirect_uri,
+ client_id=public_client_config.client_id,
+ scopes=public_client_config.scopes,
+ header_key=public_client_config.authorization_metadata_key or None,
+ )
+
+
+def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Authenticator:
+ """
+ Returns a new authenticator based on the platform config.
+ """
+ cfg_auth = cfg.auth_mode
+ if type(cfg_auth) is str:
+ try:
+ cfg_auth = AuthType[cfg_auth.upper()]
+ except KeyError:
+ logging.warning(f"Authentication type {cfg_auth} does not exist, defaulting to standard")
+ cfg_auth = AuthType.STANDARD
+
+ if cfg_auth == AuthType.STANDARD or cfg_auth == AuthType.PKCE:
+ verify = None
+ if cfg.insecure_skip_verify:
+ verify = False
+ elif cfg.ca_cert_file_path:
+ verify = cfg.ca_cert_file_path
+ return PKCEAuthenticator(cfg.endpoint, cfg_store, verify=verify)
+ elif cfg_auth == AuthType.BASIC or cfg_auth == AuthType.CLIENT_CREDENTIALS or cfg_auth == AuthType.CLIENTSECRET:
+ return ClientCredentialsAuthenticator(
+ endpoint=cfg.endpoint,
+ client_id=cfg.client_id,
+ client_secret=cfg.client_credentials_secret,
+ cfg_store=cfg_store,
+ )
+ elif cfg_auth == AuthType.EXTERNAL_PROCESS or cfg_auth == AuthType.EXTERNALCOMMAND:
+ client_cfg = None
+ if cfg_store:
+ client_cfg = cfg_store.get_client_config()
+ return CommandAuthenticator(
+ command=cfg.command,
+ header_key=client_cfg.header_key if client_cfg else None,
+ )
+ else:
+ raise ValueError(
+ f"Invalid auth mode [{cfg_auth}] specified." f"Please update the creds config to use a valid value"
+ )
+
+
+def upgrade_channel_to_authenticated(cfg: PlatformConfig, in_channel: grpc.Channel) -> grpc.Channel:
+ """
+ Given a grpc.Channel, preferrably a secure channel, it returns a composed channel that uses Interceptor to
+ perform an Oauth2.0 Auth flow
+ :param cfg: PlatformConfig
+ :param in_channel: grpc.Channel Precreated channel
+ :return: grpc.Channel. New composite channel
+ """
+ authenticator = get_authenticator(cfg, RemoteClientConfigStore(in_channel))
+ return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(authenticator))
+
+
+def get_authenticated_channel(cfg: PlatformConfig) -> grpc.Channel:
+ """
+ Returns a new channel for the given config that is authenticated
+ """
+ channel = (
+ grpc.insecure_channel(cfg.endpoint)
+ if cfg.insecure
+ else grpc.secure_channel(cfg.endpoint, grpc.ssl_channel_credentials())
+ ) # noqa
+ return upgrade_channel_to_authenticated(cfg, channel)
+
+
+def load_cert(cert_file: str) -> crypto.X509:
+ """
+ Given a cert-file loads the PEM certificate and returns
+ """
+ st_cert = open(cert_file, "rt").read()
+ return crypto.load_certificate(crypto.FILETYPE_PEM, st_cert)
+
+
+def bootstrap_creds_from_server(endpoint: str) -> grpc.ChannelCredentials:
+ """
+ Retrieves the SSL cert from the remote and uses that. should be used only if insecure-skip-verify
+ """
+ # Get port from endpoint or use 443
+ endpoint_parts = endpoint.rsplit(":", 1)
+ if len(endpoint_parts) == 2 and endpoint_parts[1].isdigit():
+ server_address = (endpoint_parts[0], endpoint_parts[1])
+ else:
+ server_address = (endpoint, "443")
+ cert = ssl.get_server_certificate(server_address) # noqa
+ return grpc.ssl_channel_credentials(str.encode(cert))
+
+
+def get_channel(cfg: PlatformConfig, **kwargs) -> grpc.Channel:
+ """
+ Creates a new grpc.Channel given a platformConfig.
+ It is possible to pass additional options to the underlying channel. Examples for various options are as below
+
+ .. code-block:: python
+
+ get_channel(cfg=PlatformConfig(...))
+
+ .. code-block:: python
+ :caption: Additional options to insecure / secure channel. Example `options` and `compression` refer to grpc guide
+
+ get_channel(cfg=PlatformConfig(...), options=..., compression=...)
+
+ .. code-block:: python
+ :caption: Create secure channel with custom `grpc.ssl_channel_credentials`
+
+ get_channel(cfg=PlatformConfig(insecure=False,...), credentials=...)
+
+
+ :param cfg: PlatformConfig
+ :param kwargs: Optional arguments to be passed to channel method. Refer to usage example above
+ :return: grpc.Channel (secure / insecure)
+ """
+ if cfg.insecure:
+ return grpc.insecure_channel(cfg.endpoint, **kwargs)
+
+ credentials = None
+ if "credentials" not in kwargs:
+ if cfg.insecure_skip_verify:
+ credentials = bootstrap_creds_from_server(cfg.endpoint)
+ elif cfg.ca_cert_file_path:
+ credentials = grpc.ssl_channel_credentials(load_cert(cfg.ca_cert_file_path))
+ else:
+ credentials = grpc.ssl_channel_credentials(
+ root_certificates=kwargs.get("root_certificates", None),
+ private_key=kwargs.get("private_key", None),
+ certificate_chain=kwargs.get("certificate_chain", None),
+ )
+ else:
+ credentials = kwargs["credentials"]
+ return grpc.secure_channel(
+ target=cfg.endpoint,
+ credentials=credentials,
+ options=kwargs.get("options", None),
+ compression=kwargs.get("compression", None),
+ )
+
+
+def wrap_exceptions_channel(cfg: PlatformConfig, in_channel: grpc.Channel) -> grpc.Channel:
+ """
+ Wraps the input channel with RetryExceptionWrapperInterceptor. This wrapper will cover all
+ exceptions and raise Exception from the Family flytekit.exceptions
+
+ .. note:: This channel should be usually the outermost channel. This channel will raise a FlyteException
+
+ :param cfg: PlatformConfig
+ :param in_channel: grpc.Channel
+ :return: grpc.Channel
+ """
+ return grpc.intercept_channel(in_channel, RetryExceptionWrapperInterceptor(max_retries=cfg.rpc_retries))
diff --git a/flytekit/clients/grpc_utils/__init__.py b/flytekit/clients/grpc_utils/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/flytekit/clients/grpc_utils/auth_interceptor.py b/flytekit/clients/grpc_utils/auth_interceptor.py
new file mode 100644
index 0000000000..21bcc30136
--- /dev/null
+++ b/flytekit/clients/grpc_utils/auth_interceptor.py
@@ -0,0 +1,80 @@
+import typing
+from collections import namedtuple
+
+import grpc
+
+from flytekit.clients.auth.authenticator import Authenticator
+
+
+class _ClientCallDetails(
+ namedtuple("_ClientCallDetails", ("method", "timeout", "metadata", "credentials")),
+ grpc.ClientCallDetails,
+):
+ """
+ Wrapper class for initializing a new ClientCallDetails instance.
+ We cannot make this of type - NamedTuple because, NamedTuple has a metaclass of type NamedTupleMeta and both
+ the metaclasses conflict
+ """
+
+ pass
+
+
+class AuthUnaryInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor):
+ """
+ This Interceptor can be used to automatically add Auth Metadata for every call - lazily in case authentication
+ is needed.
+ """
+
+ def __init__(self, authenticator: Authenticator):
+ self._authenticator = authenticator
+
+ def _call_details_with_auth_metadata(self, client_call_details: grpc.ClientCallDetails) -> grpc.ClientCallDetails:
+ """
+ Returns new ClientCallDetails with metadata added.
+ """
+ metadata = None
+ auth_metadata = self._authenticator.fetch_grpc_call_auth_metadata()
+ if auth_metadata:
+ metadata = []
+ if client_call_details.metadata:
+ metadata.extend(list(client_call_details.metadata))
+ metadata.append(auth_metadata)
+
+ return _ClientCallDetails(
+ client_call_details.method,
+ client_call_details.timeout,
+ metadata,
+ client_call_details.credentials,
+ )
+
+ def intercept_unary_unary(
+ self,
+ continuation: typing.Callable,
+ client_call_details: grpc.ClientCallDetails,
+ request: typing.Any,
+ ):
+ """
+ Intercepts unary calls and adds auth metadata if available. On Unauthenticated, resets the token and refreshes
+ and then retries with the new token
+ """
+ updated_call_details = self._call_details_with_auth_metadata(client_call_details)
+ fut: grpc.Future = continuation(updated_call_details, request)
+ e = fut.exception()
+ if e:
+ if e.code() == grpc.StatusCode.UNAUTHENTICATED:
+ self._authenticator.refresh_credentials()
+ updated_call_details = self._call_details_with_auth_metadata(client_call_details)
+ return continuation(updated_call_details, request)
+ return fut
+
+ def intercept_unary_stream(self, continuation, client_call_details, request):
+ """
+ Handles a stream call and adds authentication metadata if needed
+ """
+ updated_call_details = self._call_details_with_auth_metadata(client_call_details)
+ c: grpc.Call = continuation(updated_call_details, request)
+ if c.code() == grpc.StatusCode.UNAUTHENTICATED:
+ self._authenticator.refresh_credentials()
+ updated_call_details = self._call_details_with_auth_metadata(client_call_details)
+ return continuation(updated_call_details, request)
+ return c
diff --git a/flytekit/clients/grpc_utils/wrap_exception_interceptor.py b/flytekit/clients/grpc_utils/wrap_exception_interceptor.py
new file mode 100644
index 0000000000..ea796f464a
--- /dev/null
+++ b/flytekit/clients/grpc_utils/wrap_exception_interceptor.py
@@ -0,0 +1,49 @@
+import typing
+from typing import Union
+
+import grpc
+
+from flytekit.exceptions.base import FlyteException
+from flytekit.exceptions.system import FlyteSystemException
+from flytekit.exceptions.user import (
+ FlyteAuthenticationException,
+ FlyteEntityAlreadyExistsException,
+ FlyteEntityNotExistException,
+ FlyteInvalidInputException,
+)
+
+
+class RetryExceptionWrapperInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor):
+ def __init__(self, max_retries: int = 3):
+ self._max_retries = 3
+
+ @staticmethod
+ def _raise_if_exc(request: typing.Any, e: Union[grpc.Call, grpc.Future]):
+ if isinstance(e, grpc.RpcError):
+ if e.code() == grpc.StatusCode.UNAUTHENTICATED:
+ raise FlyteAuthenticationException() from e
+ elif e.code() == grpc.StatusCode.ALREADY_EXISTS:
+ raise FlyteEntityAlreadyExistsException() from e
+ elif e.code() == grpc.StatusCode.NOT_FOUND:
+ raise FlyteEntityNotExistException() from e
+ elif e.code() == grpc.StatusCode.INVALID_ARGUMENT:
+ raise FlyteInvalidInputException(request) from e
+ raise FlyteSystemException() from e
+
+ def intercept_unary_unary(self, continuation, client_call_details, request):
+ retries = 0
+ while True:
+ fut: grpc.Future = continuation(client_call_details, request)
+ e = fut.exception()
+ try:
+ if e:
+ self._raise_if_exc(request, e)
+ return fut
+ except FlyteException as e:
+ if retries == self._max_retries:
+ raise e
+ retries = retries + 1
+
+ def intercept_unary_stream(self, continuation, client_call_details, request):
+ c: grpc.Call = continuation(client_call_details, request)
+ return c
diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py
index 6c8f54e9ce..e71485b17c 100644
--- a/flytekit/clients/raw.py
+++ b/flytekit/clients/raw.py
@@ -1,91 +1,20 @@
from __future__ import annotations
-import base64 as _base64
-import ssl
-import subprocess
-import time
import typing
-from typing import Optional
import grpc
-import requests as _requests
from flyteidl.admin.project_pb2 import ProjectListRequest
from flyteidl.admin.signal_pb2 import SignalList, SignalListRequest, SignalSetRequest, SignalSetResponse
from flyteidl.service import admin_pb2_grpc as _admin_service
-from flyteidl.service import auth_pb2
-from flyteidl.service import auth_pb2_grpc as auth_service
from flyteidl.service import dataproxy_pb2 as _dataproxy_pb2
from flyteidl.service import dataproxy_pb2_grpc as dataproxy_service
from flyteidl.service import signal_pb2_grpc as signal_service
from flyteidl.service.dataproxy_pb2_grpc import DataProxyServiceStub
-from google.protobuf.json_format import MessageToJson as _MessageToJson
-from flytekit.clis.auth import credentials as _credentials_access
-from flytekit.configuration import AuthType, PlatformConfig
-from flytekit.exceptions import user as _user_exceptions
-from flytekit.exceptions.user import FlyteAuthenticationException
+from flytekit.clients.auth_helper import get_channel, upgrade_channel_to_authenticated, wrap_exceptions_channel
+from flytekit.configuration import PlatformConfig
from flytekit.loggers import cli_logger
-_utf_8 = "utf-8"
-
-
-def _handle_rpc_error(retry=False):
- def decorator(fn):
- def handler(*args, **kwargs):
- """
- Wraps rpc errors as Flyte exceptions and handles authentication the client.
- """
- max_retries = 3
- max_wait_time = 1000
-
- for i in range(max_retries):
- try:
- return fn(*args, **kwargs)
- except grpc.RpcError as e:
- if e.code() == grpc.StatusCode.UNAUTHENTICATED:
- # Always retry auth errors.
- if i == (max_retries - 1):
- # Exit the loop and wrap the authentication error.
- raise _user_exceptions.FlyteAuthenticationException(str(e))
- cli_logger.debug(f"Unauthenticated RPC error {e}, refreshing credentials and retrying\n")
- args[0].refresh_credentials()
- elif e.code() == grpc.StatusCode.ALREADY_EXISTS:
- # There are two cases that we should throw error immediately
- # 1. Entity already exists when we register entity
- # 2. Entity not found when we fetch entity
- raise _user_exceptions.FlyteEntityAlreadyExistsException(e)
- elif e.code() == grpc.StatusCode.NOT_FOUND:
- raise _user_exceptions.FlyteEntityNotExistException(e)
- else:
- # No more retries if retry=False or max_retries reached.
- if (retry is False) or i == (max_retries - 1):
- raise
- else:
- # Retry: Start with 200ms wait-time and exponentially back-off up to 1 second.
- wait_time = min(200 * (2**i), max_wait_time)
- cli_logger.error(f"Non-auth RPC error {e}, sleeping {wait_time}ms and retrying")
- time.sleep(wait_time / 1000)
-
- return handler
-
- return decorator
-
-
-def _handle_invalid_create_request(fn):
- def handler(self, create_request):
- try:
- fn(self, create_request)
- except grpc.RpcError as e:
- if e.code() == grpc.StatusCode.INVALID_ARGUMENT:
- cli_logger.error("Error creating Flyte entity because of invalid arguments. Create request: ")
- cli_logger.error(_MessageToJson(create_request))
- cli_logger.error("Details returned from the flyte admin: ")
- cli_logger.error(e.details)
- # Re-raise since we're not handling the error here and add the create_request details
- raise e
-
- return handler
-
class RawSynchronousFlyteClient(object):
"""
@@ -112,54 +41,9 @@ def __init__(self, cfg: PlatformConfig, **kwargs):
insecure: if insecure is desired
"""
self._cfg = cfg
- if cfg.insecure:
- self._channel = grpc.insecure_channel(cfg.endpoint, **kwargs)
- elif cfg.insecure_skip_verify:
- # Get port from endpoint or use 443
- endpoint_parts = cfg.endpoint.rsplit(":", 1)
- if len(endpoint_parts) == 2 and endpoint_parts[1].isdigit():
- server_address = tuple(endpoint_parts)
- else:
- server_address = (cfg.endpoint, "443")
- cert = ssl.get_server_certificate(server_address)
- credentials = grpc.ssl_channel_credentials(str.encode(cert))
- options = kwargs.get("options", [])
- self._channel = grpc.secure_channel(
- target=cfg.endpoint,
- credentials=credentials,
- options=options,
- compression=kwargs.get("compression", None),
- )
- else:
- if "credentials" not in kwargs:
- credentials = grpc.ssl_channel_credentials(
- root_certificates=kwargs.get("root_certificates", None),
- private_key=kwargs.get("private_key", None),
- certificate_chain=kwargs.get("certificate_chain", None),
- )
- else:
- credentials = kwargs["credentials"]
- self._channel = grpc.secure_channel(
- target=cfg.endpoint,
- credentials=credentials,
- options=kwargs.get("options", None),
- compression=kwargs.get("compression", None),
- )
+ self._channel = wrap_exceptions_channel(cfg, upgrade_channel_to_authenticated(cfg, get_channel(cfg)))
self._stub = _admin_service.AdminServiceStub(self._channel)
- self._auth_stub = auth_service.AuthMetadataServiceStub(self._channel)
self._signal = signal_service.SignalServiceStub(self._channel)
- try:
- resp = self._auth_stub.GetPublicClientConfig(auth_pb2.PublicClientAuthConfigRequest())
- self._public_client_config = resp
- except grpc.RpcError:
- cli_logger.debug("No public client auth config found, skipping.")
- self._public_client_config = None
- try:
- resp = self._auth_stub.GetOAuth2Metadata(auth_pb2.OAuth2MetadataRequest())
- self._oauth2_metadata = resp
- except grpc.RpcError:
- cli_logger.debug("No OAuth2 Metadata found, skipping.")
- self._oauth2_metadata = None
self._dataproxy_stub = dataproxy_service.DataProxyServiceStub(self._channel)
cli_logger.info(
@@ -175,161 +59,16 @@ def with_root_certificate(cls, cfg: PlatformConfig, root_cert_file: str) -> RawS
b = fp.read()
return RawSynchronousFlyteClient(cfg, credentials=grpc.ssl_channel_credentials(root_certificates=b))
- @property
- def public_client_config(self) -> Optional[auth_pb2.PublicClientAuthConfigResponse]:
- return self._public_client_config
-
- @property
- def oauth2_metadata(self) -> Optional[auth_pb2.OAuth2MetadataResponse]:
- return self._oauth2_metadata
-
@property
def url(self) -> str:
return self._cfg.endpoint
- def _refresh_credentials_standard(self):
- """
- This function is used when the configuration value for AUTH_MODE is set to 'standard'.
- This either fetches the existing access token or initiates the flow to request a valid access token and store it.
- :param self: RawSynchronousFlyteClient
- :return:
- """
- authorization_header_key = self.public_client_config.authorization_metadata_key or None
- if not self.oauth2_metadata or not self.public_client_config:
- raise ValueError(
- "Raw Flyte client attempting client credentials flow but no response from Admin detected. "
- "Check your Admin server's .well-known endpoints to make sure they're working as expected."
- )
-
- client = _credentials_access.get_client(
- redirect_endpoint=self.public_client_config.redirect_uri,
- client_id=self.public_client_config.client_id,
- scopes=self.public_client_config.scopes,
- auth_endpoint=self.oauth2_metadata.authorization_endpoint,
- token_endpoint=self.oauth2_metadata.token_endpoint,
- )
-
- if client.has_valid_credentials and not self.check_access_token(client.credentials.access_token):
- # When Python starts up, if credentials have been stored in the keyring, then the AuthorizationClient
- # will have read them into its _credentials field, but it won't be in the RawSynchronousFlyteClient's
- # metadata field yet. Therefore, if there's a mismatch, copy it over.
- self.set_access_token(client.credentials.access_token, authorization_header_key)
- return
-
- try:
- client.refresh_access_token()
- except ValueError:
- client.start_authorization_flow()
-
- self.set_access_token(client.credentials.access_token, authorization_header_key)
-
- def _refresh_credentials_basic(self):
- """
- This function is used by the _handle_rpc_error() decorator, depending on the AUTH_MODE config object. This handler
- is meant for SDK use-cases of auth (like pyflyte, or when users call SDK functions that require access to Admin,
- like when waiting for another workflow to complete from within a task). This function uses basic auth, which means
- the credentials for basic auth must be present from wherever this code is running.
-
- :param self: RawSynchronousFlyteClient
- :return:
- """
- if not self.oauth2_metadata or not self.public_client_config:
- raise ValueError(
- "Raw Flyte client attempting client credentials flow but no response from Admin detected. "
- "Check your Admin server's .well-known endpoints to make sure they're working as expected."
- )
-
- token_endpoint = self.oauth2_metadata.token_endpoint
- scopes = self._cfg.scopes or self.public_client_config.scopes
- scopes = ",".join(scopes)
-
- # Note that unlike the Pkce flow, the client ID does not come from Admin.
- client_secret = self._cfg.client_credentials_secret
- if not client_secret:
- raise FlyteAuthenticationException("No client credentials secret provided in the config")
- cli_logger.debug(f"Basic authorization flow with client id {self._cfg.client_id} scope {scopes}")
- authorization_header = get_basic_authorization_header(self._cfg.client_id, client_secret)
- token, expires_in = get_token(token_endpoint, authorization_header, scopes)
- cli_logger.info("Retrieved new token, expires in {}".format(expires_in))
- authorization_header_key = self.public_client_config.authorization_metadata_key or None
- self.set_access_token(token, authorization_header_key)
-
- def _refresh_credentials_from_command(self):
- """
- This function is used when the configuration value for AUTH_MODE is set to 'external_process'.
- It reads an id token generated by an external process started by running the 'command'.
-
- :param self: RawSynchronousFlyteClient
- :return:
- """
- command = self._cfg.command
- if not command:
- raise FlyteAuthenticationException("No command specified in configuration for command authentication")
- cli_logger.debug("Starting external process to generate id token. Command {}".format(command))
- try:
- output = subprocess.run(command, capture_output=True, text=True, check=True)
- except subprocess.CalledProcessError as e:
- cli_logger.error("Failed to generate token from command {}".format(command))
- raise _user_exceptions.FlyteAuthenticationException("Problems refreshing token with command: " + str(e))
- authorization_header_key = self.public_client_config.authorization_metadata_key or None
- if not authorization_header_key:
- self.set_access_token(output.stdout.strip())
- self.set_access_token(output.stdout.strip(), authorization_header_key)
-
- def _refresh_credentials_noop(self):
- pass
-
- def refresh_credentials(self):
- cfg_auth = self._cfg.auth_mode
- if type(cfg_auth) is str:
- try:
- cfg_auth = AuthType[cfg_auth.upper()]
- except KeyError:
- cli_logger.warning(f"Authentication type {cfg_auth} does not exist, defaulting to standard")
- cfg_auth = AuthType.STANDARD
-
- if cfg_auth == AuthType.STANDARD or cfg_auth == AuthType.PKCE:
- return self._refresh_credentials_standard()
- elif cfg_auth == AuthType.BASIC or cfg_auth == AuthType.CLIENT_CREDENTIALS or cfg_auth == AuthType.CLIENTSECRET:
- return self._refresh_credentials_basic()
- elif cfg_auth == AuthType.EXTERNAL_PROCESS or cfg_auth == AuthType.EXTERNALCOMMAND:
- return self._refresh_credentials_from_command()
- else:
- raise ValueError(
- f"Invalid auth mode [{cfg_auth}] specified." f"Please update the creds config to use a valid value"
- )
-
- def set_access_token(self, access_token: str, authorization_header_key: Optional[str] = "authorization"):
- # Always set the header to lower-case regardless of what the config is. The grpc libraries that Admin uses
- # to parse the metadata don't change the metadata, but they do automatically lower the key you're looking for.
- cli_logger.debug(f"Adding authorization header. Header name: {authorization_header_key}.")
- self._metadata = [
- (
- authorization_header_key,
- f"Bearer {access_token}",
- )
- ]
-
- def check_access_token(self, access_token: str) -> bool:
- """
- This checks to see if the given access token is the same as the one already stored in the client. The reason
- this is useful is so that we can prevent unnecessary refreshing of tokens.
-
- :param access_token: The access token to check
- :return: If no access token is stored, or if the stored token doesn't match, return False.
- """
- if self._metadata is None:
- return False
- return access_token == self._metadata[0][1].replace("Bearer ", "")
-
####################################################################################################################
#
# Task Endpoints
#
####################################################################################################################
- @_handle_rpc_error()
- @_handle_invalid_create_request
def create_task(self, task_create_request):
"""
This will create a task definition in the Admin database. Once successful, the task object can be
@@ -350,7 +89,6 @@ def create_task(self, task_create_request):
"""
return self._stub.CreateTask(task_create_request, metadata=self._metadata)
- @_handle_rpc_error(retry=True)
def list_task_ids_paginated(self, identifier_list_request):
"""
This returns a page of identifiers for the tasks for a given project and domain. Filters can also be
@@ -376,7 +114,6 @@ def list_task_ids_paginated(self, identifier_list_request):
"""
return self._stub.ListTaskIds(identifier_list_request, metadata=self._metadata)
- @_handle_rpc_error(retry=True)
def list_tasks_paginated(self, resource_list_request):
"""
This returns a page of task metadata for tasks in a given project and domain. Optionally,
@@ -398,7 +135,6 @@ def list_tasks_paginated(self, resource_list_request):
"""
return self._stub.ListTasks(resource_list_request, metadata=self._metadata)
- @_handle_rpc_error(retry=True)
def get_task(self, get_object_request):
"""
This returns a single task for a given identifier.
@@ -409,14 +145,12 @@ def get_task(self, get_object_request):
"""
return self._stub.GetTask(get_object_request, metadata=self._metadata)
- @_handle_rpc_error(retry=True)
def set_signal(self, signal_set_request: SignalSetRequest) -> SignalSetResponse:
"""
This sets a signal
"""
return self._signal.SetSignal(signal_set_request, metadata=self._metadata)
- @_handle_rpc_error(retry=True)
def list_signals(self, signal_list_request: SignalListRequest) -> SignalList:
"""
This lists signals
@@ -429,8 +163,6 @@ def list_signals(self, signal_list_request: SignalListRequest) -> SignalList:
#
####################################################################################################################
- @_handle_rpc_error()
- @_handle_invalid_create_request
def create_workflow(self, workflow_create_request):
"""
This will create a workflow definition in the Admin database. Once successful, the workflow object can be
@@ -451,7 +183,6 @@ def create_workflow(self, workflow_create_request):
"""
return self._stub.CreateWorkflow(workflow_create_request, metadata=self._metadata)
- @_handle_rpc_error(retry=True)
def list_workflow_ids_paginated(self, identifier_list_request):
"""
This returns a page of identifiers for the workflows for a given project and domain. Filters can also be
@@ -477,7 +208,6 @@ def list_workflow_ids_paginated(self, identifier_list_request):
"""
return self._stub.ListWorkflowIds(identifier_list_request, metadata=self._metadata)
- @_handle_rpc_error(retry=True)
def list_workflows_paginated(self, resource_list_request):
"""
This returns a page of workflow meta-information for workflows in a given project and domain. Optionally,
@@ -499,7 +229,6 @@ def list_workflows_paginated(self, resource_list_request):
"""
return self._stub.ListWorkflows(resource_list_request, metadata=self._metadata)
- @_handle_rpc_error(retry=True)
def get_workflow(self, get_object_request):
"""
This returns a single workflow for a given identifier.
@@ -516,8 +245,6 @@ def get_workflow(self, get_object_request):
#
####################################################################################################################
- @_handle_rpc_error()
- @_handle_invalid_create_request
def create_launch_plan(self, launch_plan_create_request):
"""
This will create a launch plan definition in the Admin database. Once successful, the launch plan object can be
@@ -541,7 +268,6 @@ def create_launch_plan(self, launch_plan_create_request):
# TODO: List endpoints when they come in
- @_handle_rpc_error(retry=True)
def get_launch_plan(self, object_get_request):
"""
Retrieves a launch plan entity.
@@ -551,7 +277,6 @@ def get_launch_plan(self, object_get_request):
"""
return self._stub.GetLaunchPlan(object_get_request, metadata=self._metadata)
- @_handle_rpc_error(retry=True)
def get_active_launch_plan(self, active_launch_plan_request):
"""
Retrieves a launch plan entity.
@@ -561,7 +286,6 @@ def get_active_launch_plan(self, active_launch_plan_request):
"""
return self._stub.GetActiveLaunchPlan(active_launch_plan_request, metadata=self._metadata)
- @_handle_rpc_error()
def update_launch_plan(self, update_request):
"""
Allows updates to a launch plan at a given identifier. Currently, a launch plan may only have it's state
@@ -572,7 +296,6 @@ def update_launch_plan(self, update_request):
"""
return self._stub.UpdateLaunchPlan(update_request, metadata=self._metadata)
- @_handle_rpc_error(retry=True)
def list_launch_plan_ids_paginated(self, identifier_list_request):
"""
Lists launch plan named identifiers for a given project and domain.
@@ -582,7 +305,6 @@ def list_launch_plan_ids_paginated(self, identifier_list_request):
"""
return self._stub.ListLaunchPlanIds(identifier_list_request, metadata=self._metadata)
- @_handle_rpc_error(retry=True)
def list_launch_plans_paginated(self, resource_list_request):
"""
Lists Launch Plans for a given Identifier (project, domain, name)
@@ -592,7 +314,6 @@ def list_launch_plans_paginated(self, resource_list_request):
"""
return self._stub.ListLaunchPlans(resource_list_request, metadata=self._metadata)
- @_handle_rpc_error(retry=True)
def list_active_launch_plans_paginated(self, active_launch_plan_list_request):
"""
Lists Active Launch Plans for a given (project, domain)
@@ -608,7 +329,6 @@ def list_active_launch_plans_paginated(self, active_launch_plan_list_request):
#
####################################################################################################################
- @_handle_rpc_error()
def update_named_entity(self, update_named_entity_request):
"""
:param flyteidl.admin.common_pb2.NamedEntityUpdateRequest update_named_entity_request:
@@ -622,7 +342,6 @@ def update_named_entity(self, update_named_entity_request):
#
####################################################################################################################
- @_handle_rpc_error()
def create_execution(self, create_execution_request):
"""
This will create an execution for the given execution spec.
@@ -631,7 +350,6 @@ def create_execution(self, create_execution_request):
"""
return self._stub.CreateExecution(create_execution_request, metadata=self._metadata)
- @_handle_rpc_error()
def recover_execution(self, recover_execution_request):
"""
This will recreate an execution with the same spec as the one belonging to the given execution identifier.
@@ -640,7 +358,6 @@ def recover_execution(self, recover_execution_request):
"""
return self._stub.RecoverExecution(recover_execution_request, metadata=self._metadata)
- @_handle_rpc_error(retry=True)
def get_execution(self, get_object_request):
"""
Returns an execution of a workflow entity.
@@ -650,7 +367,6 @@ def get_execution(self, get_object_request):
"""
return self._stub.GetExecution(get_object_request, metadata=self._metadata)
- @_handle_rpc_error(retry=True)
def get_execution_data(self, get_execution_data_request):
"""
Returns signed URLs to LiteralMap blobs for an execution's inputs and outputs (when available).
@@ -660,7 +376,6 @@ def get_execution_data(self, get_execution_data_request):
"""
return self._stub.GetExecutionData(get_execution_data_request, metadata=self._metadata)
- @_handle_rpc_error(retry=True)
def list_executions_paginated(self, resource_list_request):
"""
Lists the executions for a given identifier.
@@ -670,7 +385,6 @@ def list_executions_paginated(self, resource_list_request):
"""
return self._stub.ListExecutions(resource_list_request, metadata=self._metadata)
- @_handle_rpc_error()
def terminate_execution(self, terminate_execution_request):
"""
:param flyteidl.admin.execution_pb2.TerminateExecutionRequest terminate_execution_request:
@@ -678,7 +392,6 @@ def terminate_execution(self, terminate_execution_request):
"""
return self._stub.TerminateExecution(terminate_execution_request, metadata=self._metadata)
- @_handle_rpc_error()
def relaunch_execution(self, relaunch_execution_request):
"""
:param flyteidl.admin.execution_pb2.ExecutionRelaunchRequest relaunch_execution_request:
@@ -692,7 +405,6 @@ def relaunch_execution(self, relaunch_execution_request):
#
####################################################################################################################
- @_handle_rpc_error(retry=True)
def get_node_execution(self, node_execution_request):
"""
:param flyteidl.admin.node_execution_pb2.NodeExecutionGetRequest node_execution_request:
@@ -700,7 +412,6 @@ def get_node_execution(self, node_execution_request):
"""
return self._stub.GetNodeExecution(node_execution_request, metadata=self._metadata)
- @_handle_rpc_error(retry=True)
def get_node_execution_data(self, get_node_execution_data_request):
"""
Returns signed URLs to LiteralMap blobs for a node execution's inputs and outputs (when available).
@@ -710,7 +421,6 @@ def get_node_execution_data(self, get_node_execution_data_request):
"""
return self._stub.GetNodeExecutionData(get_node_execution_data_request, metadata=self._metadata)
- @_handle_rpc_error(retry=True)
def list_node_executions_paginated(self, node_execution_list_request):
"""
:param flyteidl.admin.node_execution_pb2.NodeExecutionListRequest node_execution_list_request:
@@ -718,7 +428,6 @@ def list_node_executions_paginated(self, node_execution_list_request):
"""
return self._stub.ListNodeExecutions(node_execution_list_request, metadata=self._metadata)
- @_handle_rpc_error(retry=True)
def list_node_executions_for_task_paginated(self, node_execution_for_task_list_request):
"""
:param flyteidl.admin.node_execution_pb2.NodeExecutionListRequest node_execution_for_task_list_request:
@@ -732,7 +441,6 @@ def list_node_executions_for_task_paginated(self, node_execution_for_task_list_r
#
####################################################################################################################
- @_handle_rpc_error(retry=True)
def get_task_execution(self, task_execution_request):
"""
:param flyteidl.admin.task_execution_pb2.TaskExecutionGetRequest task_execution_request:
@@ -740,7 +448,6 @@ def get_task_execution(self, task_execution_request):
"""
return self._stub.GetTaskExecution(task_execution_request, metadata=self._metadata)
- @_handle_rpc_error(retry=True)
def get_task_execution_data(self, get_task_execution_data_request):
"""
Returns signed URLs to LiteralMap blobs for a task execution's inputs and outputs (when available).
@@ -750,7 +457,6 @@ def get_task_execution_data(self, get_task_execution_data_request):
"""
return self._stub.GetTaskExecutionData(get_task_execution_data_request, metadata=self._metadata)
- @_handle_rpc_error(retry=True)
def list_task_executions_paginated(self, task_execution_list_request):
"""
:param flyteidl.admin.task_execution_pb2.TaskExecutionListRequest task_execution_list_request:
@@ -764,7 +470,6 @@ def list_task_executions_paginated(self, task_execution_list_request):
#
####################################################################################################################
- @_handle_rpc_error(retry=True)
def list_projects(self, project_list_request: typing.Optional[ProjectListRequest] = None):
"""
This will return a list of the projects registered with the Flyte Admin Service
@@ -775,7 +480,6 @@ def list_projects(self, project_list_request: typing.Optional[ProjectListRequest
project_list_request = ProjectListRequest()
return self._stub.ListProjects(project_list_request, metadata=self._metadata)
- @_handle_rpc_error()
def register_project(self, project_register_request):
"""
Registers a project along with a set of domains.
@@ -784,7 +488,6 @@ def register_project(self, project_register_request):
"""
return self._stub.RegisterProject(project_register_request, metadata=self._metadata)
- @_handle_rpc_error()
def update_project(self, project):
"""
Update an existing project specified by id.
@@ -798,7 +501,6 @@ def update_project(self, project):
# Matching Attributes Endpoints
#
####################################################################################################################
- @_handle_rpc_error()
def update_project_domain_attributes(self, project_domain_attributes_update_request):
"""
This updates the attributes for a project and domain registered with the Flyte Admin Service
@@ -809,7 +511,6 @@ def update_project_domain_attributes(self, project_domain_attributes_update_requ
project_domain_attributes_update_request, metadata=self._metadata
)
- @_handle_rpc_error()
def update_workflow_attributes(self, workflow_attributes_update_request):
"""
This updates the attributes for a project, domain, and workflow registered with the Flyte Admin Service
@@ -818,7 +519,6 @@ def update_workflow_attributes(self, workflow_attributes_update_request):
"""
return self._stub.UpdateWorkflowAttributes(workflow_attributes_update_request, metadata=self._metadata)
- @_handle_rpc_error(retry=True)
def get_project_domain_attributes(self, project_domain_attributes_get_request):
"""
This fetches the attributes for a project and domain registered with the Flyte Admin Service
@@ -827,7 +527,6 @@ def get_project_domain_attributes(self, project_domain_attributes_get_request):
"""
return self._stub.GetProjectDomainAttributes(project_domain_attributes_get_request, metadata=self._metadata)
- @_handle_rpc_error(retry=True)
def get_workflow_attributes(self, workflow_attributes_get_request):
"""
This fetches the attributes for a project, domain, and workflow registered with the Flyte Admin Service
@@ -836,7 +535,6 @@ def get_workflow_attributes(self, workflow_attributes_get_request):
"""
return self._stub.GetWorkflowAttributes(workflow_attributes_get_request, metadata=self._metadata)
- @_handle_rpc_error(retry=True)
def list_matchable_attributes(self, matchable_attributes_list_request):
"""
This fetches the attributes for a specific resource type registered with the Flyte Admin Service
@@ -859,7 +557,6 @@ def list_matchable_attributes(self, matchable_attributes_list_request):
# Data proxy endpoints
#
####################################################################################################################
- @_handle_rpc_error(retry=True)
def create_upload_location(
self, create_upload_location_request: _dataproxy_pb2.CreateUploadLocationRequest
) -> _dataproxy_pb2.CreateUploadLocationResponse:
@@ -870,7 +567,6 @@ def create_upload_location(
"""
return self._dataproxy_stub.CreateUploadLocation(create_upload_location_request, metadata=self._metadata)
- @_handle_rpc_error(retry=True)
def create_download_location(
self, create_download_location_request: _dataproxy_pb2.CreateDownloadLocationRequest
) -> _dataproxy_pb2.CreateDownloadLocationResponse:
@@ -880,43 +576,3 @@ def create_download_location(
:rtype: flyteidl.service.dataproxy_pb2.CreateDownloadLocationResponse
"""
return self._dataproxy_stub.CreateDownloadLocation(create_download_location_request, metadata=self._metadata)
-
-
-def get_token(token_endpoint, authorization_header, scope):
- """
- :param Text token_endpoint:
- :param Text authorization_header: This is the value for the "Authorization" key. (eg 'Bearer abc123')
- :param Text scope:
- :rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration
- in seconds
- """
- headers = {
- "Authorization": authorization_header,
- "Cache-Control": "no-cache",
- "Accept": "application/json",
- "Content-Type": "application/x-www-form-urlencoded",
- }
- body = {
- "grant_type": "client_credentials",
- }
- if scope is not None:
- body["scope"] = scope
- response = _requests.post(token_endpoint, data=body, headers=headers)
- if response.status_code != 200:
- cli_logger.error("Non-200 ({}) received from IDP: {}".format(response.status_code, response.text))
- raise FlyteAuthenticationException("Non-200 received from IDP")
-
- response = response.json()
- return response["access_token"], response["expires_in"]
-
-
-def get_basic_authorization_header(client_id, client_secret):
- """
- This function transforms the client id and the client secret into a header that conforms with http basic auth.
- It joins the id and the secret with a : then base64 encodes it, then adds the appropriate text.
- :param Text client_id:
- :param Text client_secret:
- :rtype: Text
- """
- concated = "{}:{}".format(client_id, client_secret)
- return "Basic {}".format(_base64.b64encode(concated.encode(_utf_8)).decode(_utf_8))
diff --git a/flytekit/clis/auth/credentials.py b/flytekit/clis/auth/credentials.py
deleted file mode 100644
index a8475c8dfc..0000000000
--- a/flytekit/clis/auth/credentials.py
+++ /dev/null
@@ -1,33 +0,0 @@
-from typing import List
-
-from flytekit.clis.auth.auth import AuthorizationClient
-from flytekit.loggers import auth_logger
-
-# Default, well known-URI string used for fetching JSON metadata. See https://tools.ietf.org/html/rfc8414#section-3.
-discovery_endpoint_path = "./.well-known/oauth-authorization-server"
-
-# Lazy initialized authorization client singleton
-_authorization_client = None
-
-
-def get_client(
- redirect_endpoint: str, client_id: str, scopes: List[str], auth_endpoint: str, token_endpoint: str
-) -> AuthorizationClient:
- global _authorization_client
- if _authorization_client is not None and not _authorization_client.expired:
- return _authorization_client
-
- _authorization_client = AuthorizationClient(
- redirect_uri=redirect_endpoint,
- client_id=client_id,
- scopes=scopes,
- auth_endpoint=auth_endpoint,
- token_endpoint=token_endpoint,
- )
-
- auth_logger.debug(f"Created oauth client with redirect {_authorization_client}")
-
- if not _authorization_client.has_valid_credentials:
- _authorization_client.start_authorization_flow()
-
- return _authorization_client
diff --git a/flytekit/clis/sdk_in_container/constants.py b/flytekit/clis/sdk_in_container/constants.py
index 46513553b9..d228babf43 100644
--- a/flytekit/clis/sdk_in_container/constants.py
+++ b/flytekit/clis/sdk_in_container/constants.py
@@ -9,6 +9,7 @@
CTX_CONFIG_FILE = "config_file"
CTX_PROJECT_ROOT = "project_root"
CTX_MODULE = "module"
+CTX_VERBOSE = "verbose"
project_option = _click.option(
diff --git a/flytekit/clis/sdk_in_container/pyflyte.py b/flytekit/clis/sdk_in_container/pyflyte.py
index 1f843450ed..5e1136d14c 100644
--- a/flytekit/clis/sdk_in_container/pyflyte.py
+++ b/flytekit/clis/sdk_in_container/pyflyte.py
@@ -1,8 +1,12 @@
+import typing
+
import click
+import grpc
+from google.protobuf.json_format import MessageToJson
from flytekit import configuration
from flytekit.clis.sdk_in_container.backfill import backfill
-from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE, CTX_PACKAGES
+from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE, CTX_PACKAGES, CTX_VERBOSE
from flytekit.clis.sdk_in_container.init import init
from flytekit.clis.sdk_in_container.local_cache import local_cache
from flytekit.clis.sdk_in_container.package import package
@@ -10,6 +14,8 @@
from flytekit.clis.sdk_in_container.run import run
from flytekit.clis.sdk_in_container.serialize import serialize
from flytekit.configuration.internal import LocalSDK
+from flytekit.exceptions.base import FlyteException
+from flytekit.exceptions.user import FlyteInvalidInputException
from flytekit.loggers import cli_logger
@@ -28,7 +34,60 @@ def validate_package(ctx, param, values):
return pkgs
-@click.group("pyflyte", invoke_without_command=True)
+def pretty_print_grpc_error(e: grpc.RpcError):
+ if isinstance(e, grpc._channel._InactiveRpcError): # noqa
+ click.secho(f"RPC Failed, with Status: {e.code()}", fg="red")
+ click.secho(f"\tdetails: {e.details()}", fg="magenta")
+ click.secho(f"\tDebug string {e.debug_error_string()}", dim=True)
+ return
+
+
+def pretty_print_exception(e: Exception):
+ if isinstance(e, click.exceptions.Exit):
+ raise e
+
+ if isinstance(e, click.ClickException):
+ click.secho(e.message, fg="red")
+ raise e
+
+ if isinstance(e, FlyteException):
+ if isinstance(e, FlyteInvalidInputException):
+ click.secho("Request rejected by the API, due to Invalid input.", fg="red")
+ click.secho(f"\tReason: {str(e)}", dim=True)
+ click.secho(f"\tInput Request: {MessageToJson(e.request)}", dim=True)
+ return
+ click.secho(f"Failed with Exception: Reason: {e._ERROR_CODE}", fg="red") # noqa
+ cause = e.__cause__
+ if cause:
+ if isinstance(cause, grpc.RpcError):
+ pretty_print_grpc_error(cause)
+ else:
+ click.secho(f"Underlying Exception: {cause}")
+ return
+
+ if isinstance(e, grpc.RpcError):
+ pretty_print_grpc_error(e)
+ return
+
+ click.secho(f"Failed with Unknown Exception {type(e)} Reason: {e}", fg="red") # noqa
+
+
+class ErrorHandlingCommand(click.Group):
+ def invoke(self, ctx: click.Context) -> typing.Any:
+ try:
+ return super().invoke(ctx)
+ except Exception as e:
+ if CTX_VERBOSE in ctx.obj and ctx.obj[CTX_VERBOSE]:
+ print("Verbose mode on")
+ raise e
+ pretty_print_exception(e)
+ raise SystemExit(e)
+
+
+@click.group("pyflyte", invoke_without_command=True, cls=ErrorHandlingCommand)
+@click.option(
+ "--verbose", required=False, default=False, is_flag=True, help="Show verbose messages and exception traces"
+)
@click.option(
"-k",
"--pkgs",
@@ -47,7 +106,7 @@ def validate_package(ctx, param, values):
help="Path to config file for use within container",
)
@click.pass_context
-def main(ctx, pkgs=None, config=None):
+def main(ctx, pkgs: typing.List[str], config: str, verbose: bool):
"""
Entrypoint for all the user commands.
"""
@@ -63,6 +122,7 @@ def main(ctx, pkgs=None, config=None):
if pkgs is None:
pkgs = []
ctx.obj[CTX_PACKAGES] = pkgs
+ ctx.obj[CTX_VERBOSE] = verbose
main.add_command(serialize)
@@ -72,6 +132,7 @@ def main(ctx, pkgs=None, config=None):
main.add_command(run)
main.add_command(register)
main.add_command(backfill)
+main.epilog
if __name__ == "__main__":
main()
diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py
index 220f9209ea..77273cc81c 100644
--- a/flytekit/configuration/__init__.py
+++ b/flytekit/configuration/__init__.py
@@ -298,28 +298,31 @@ class PlatformConfig(object):
This object contains the settings to talk to a Flyte backend (the DNS location of your Admin server basically).
:param endpoint: DNS for Flyte backend
- :param insecure: Whether or not to use SSL
- :param insecure_skip_verify: Wether to skip SSL certificate verification
- :param console_endpoint: endpoint for console if different than Flyte backend
- :param command: This command is executed to return a token using an external process.
+ :param insecure: Whether to use SSL
+ :param insecure_skip_verify: Whether to skip SSL certificate verification
+ :param console_endpoint: endpoint for console if different from Flyte backend
+ :param command: This command is executed to return a token using an external process
:param client_id: This is the public identifier for the app which handles authorization for a Flyte deployment.
More details here: https://www.oauth.com/oauth2-servers/client-registration/client-id-secret/.
:param client_credentials_secret: Used for service auth, which is automatically called during pyflyte. This will
allow the Flyte engine to read the password directly from the environment variable. Note that this is
- less secure! Please only use this if mounting the secret as a file is impossible.
- :param scopes: List of scopes to request. This is only applicable to the client credentials flow.
- :param auth_mode: The OAuth mode to use. Defaults to pkce flow.
+ less secure! Please only use this if mounting the secret as a file is impossible
+ :param scopes: List of scopes to request. This is only applicable to the client credentials flow
+ :param auth_mode: The OAuth mode to use. Defaults to pkce flow
+ :param ca_cert_file_path: [optional] str Root Cert to be loaded and used to verify admin
"""
endpoint: str = "localhost:30080"
insecure: bool = False
insecure_skip_verify: bool = False
+ ca_cert_file_path: typing.Optional[str] = None
console_endpoint: typing.Optional[str] = None
command: typing.Optional[typing.List[str]] = None
client_id: typing.Optional[str] = None
client_credentials_secret: typing.Optional[str] = None
scopes: List[str] = field(default_factory=list)
auth_mode: AuthType = AuthType.STANDARD
+ rpc_retries: int = 3
@classmethod
def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None) -> PlatformConfig:
@@ -334,6 +337,7 @@ def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None
kwargs = set_if_exists(
kwargs, "insecure_skip_verify", _internal.Platform.INSECURE_SKIP_VERIFY.read(config_file)
)
+ kwargs = set_if_exists(kwargs, "ca_cert_file_path", _internal.Platform.CA_CERT_FILE_PATH.read(config_file))
kwargs = set_if_exists(kwargs, "command", _internal.Credentials.COMMAND.read(config_file))
kwargs = set_if_exists(kwargs, "client_id", _internal.Credentials.CLIENT_ID.read(config_file))
kwargs = set_if_exists(
diff --git a/flytekit/configuration/internal.py b/flytekit/configuration/internal.py
index 5c29045db5..5c3729e63b 100644
--- a/flytekit/configuration/internal.py
+++ b/flytekit/configuration/internal.py
@@ -108,6 +108,9 @@ class Platform(object):
LegacyConfigEntry(SECTION, "insecure_skip_verify", bool), YamlConfigEntry("admin.insecureSkipVerify", bool)
)
CONSOLE_ENDPOINT = ConfigEntry(LegacyConfigEntry(SECTION, "console_endpoint"), YamlConfigEntry("console.endpoint"))
+ CA_CERT_FILE_PATH = ConfigEntry(
+ LegacyConfigEntry(SECTION, "ca_cert_file_path"), YamlConfigEntry("admin.caCertFilePath")
+ )
class LocalSDK(object):
diff --git a/flytekit/exceptions/user.py b/flytekit/exceptions/user.py
index 6575218666..59b58a0a5a 100644
--- a/flytekit/exceptions/user.py
+++ b/flytekit/exceptions/user.py
@@ -1,3 +1,5 @@
+import typing
+
from flytekit.exceptions.base import FlyteException as _FlyteException
from flytekit.exceptions.base import FlyteRecoverableException as _Recoverable
@@ -84,3 +86,11 @@ class FlyteRecoverableException(FlyteUserException, _Recoverable):
class FlyteAuthenticationException(FlyteAssertion):
_ERROR_CODE = "USER:AuthenticationError"
+
+
+class FlyteInvalidInputException(FlyteUserException):
+ _ERROR_CODE = "USER:BadInputToAPI"
+
+ def __init__(self, request: typing.Any):
+ self.request = request
+ super(self).__init__()
diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py
index 93badd5374..03cc9a66e9 100644
--- a/flytekit/remote/remote.py
+++ b/flytekit/remote/remote.py
@@ -596,6 +596,7 @@ def _serialize_and_register(
settings: typing.Optional[SerializationSettings],
version: str,
options: typing.Optional[Options] = None,
+ create_default_launchplan: bool = True,
) -> Identifier:
"""
This method serializes and register the given Flyte entity
@@ -630,7 +631,7 @@ def _serialize_and_register(
cp_entity,
settings=settings,
version=version,
- create_default_launchplan=True,
+ create_default_launchplan=create_default_launchplan,
options=options,
og_entity=entity,
)
@@ -685,14 +686,7 @@ def register_workflow(
b.domain = ident.domain
b.version = ident.version
serialization_settings = b.build()
- ident = self._serialize_and_register(entity, serialization_settings, version, options)
- if default_launch_plan:
- default_lp = LaunchPlan.get_default_launch_plan(self.context, entity)
- self.register_launch_plan(
- default_lp, version=ident.version, project=ident.project, domain=ident.domain, options=options
- )
- remote_logger.debug("Created default launch plan for Workflow")
-
+ ident = self._serialize_and_register(entity, serialization_settings, version, options, default_launch_plan)
fwf = self.fetch_workflow(ident.project, ident.domain, ident.name, ident.version)
fwf._python_interface = entity.python_interface
return fwf
diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py
index b211153f44..e963f3dfc6 100644
--- a/tests/flytekit/unit/cli/pyflyte/test_run.py
+++ b/tests/flytekit/unit/cli/pyflyte/test_run.py
@@ -152,6 +152,7 @@ def test_union_type_with_invalid_input():
runner.invoke(
pyflyte.main,
[
+ "--verbose",
"run",
os.path.join(DIR_NAME, "workflow.py"),
"test_union2",
diff --git a/tests/flytekit/unit/clients/auth/__init__.py b/tests/flytekit/unit/clients/auth/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/flytekit/unit/cli/auth/test_auth.py b/tests/flytekit/unit/clients/auth/test_auth_client.py
similarity index 52%
rename from tests/flytekit/unit/cli/auth/test_auth.py
rename to tests/flytekit/unit/clients/auth/test_auth_client.py
index 487ddfcd1d..5ab843da8b 100644
--- a/tests/flytekit/unit/cli/auth/test_auth.py
+++ b/tests/flytekit/unit/clients/auth/test_auth_client.py
@@ -2,29 +2,40 @@
import re
from multiprocessing import Queue as _Queue
-from flytekit.clis.auth import auth as _auth
+from flytekit.clients.auth.auth_client import (
+ EndpointMetadata,
+ OAuthHTTPServer,
+ _create_code_challenge,
+ _generate_code_verifier,
+ _generate_state_parameter,
+)
def test_generate_code_verifier():
- verifier = _auth._generate_code_verifier()
+ verifier = _generate_code_verifier()
assert verifier is not None
assert 43 < len(verifier) < 128
assert not re.search(r"[^a-zA-Z0-9_\-.~]+", verifier)
def test_generate_state_parameter():
- param = _auth._generate_state_parameter()
+ param = _generate_state_parameter()
assert not re.search(r"[^a-zA-Z0-9-_.,]+", param)
def test_create_code_challenge():
test_code_verifier = "test_code_verifier"
- assert _auth._create_code_challenge(test_code_verifier) == "Qq1fGD0HhxwbmeMrqaebgn1qhvKeguQPXqLdpmixaM4"
+ assert _create_code_challenge(test_code_verifier) == "Qq1fGD0HhxwbmeMrqaebgn1qhvKeguQPXqLdpmixaM4"
def test_oauth_http_server():
queue = _Queue()
- server = _auth.OAuthHTTPServer(("localhost", 9000), _BaseHTTPServer.BaseHTTPRequestHandler, queue=queue)
+ server = OAuthHTTPServer(
+ ("localhost", 9000),
+ remote_metadata=EndpointMetadata(endpoint="example.com"),
+ request_handler_class=_BaseHTTPServer.BaseHTTPRequestHandler,
+ queue=queue,
+ )
test_auth_code = "auth_code"
server.handle_authorization_code(test_auth_code)
auth_code = queue.get()
diff --git a/tests/flytekit/unit/clients/auth/test_authenticator.py b/tests/flytekit/unit/clients/auth/test_authenticator.py
new file mode 100644
index 0000000000..4c968cf0bd
--- /dev/null
+++ b/tests/flytekit/unit/clients/auth/test_authenticator.py
@@ -0,0 +1,95 @@
+import json
+import subprocess
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from flytekit.clients.auth.authenticator import (
+ ClientConfig,
+ ClientCredentialsAuthenticator,
+ CommandAuthenticator,
+ PKCEAuthenticator,
+ StaticClientConfigStore,
+)
+from flytekit.clients.auth.exceptions import AuthenticationError
+
+ENDPOINT = "example.com"
+
+client_config = ClientConfig(
+ token_endpoint="token_endpoint",
+ authorization_endpoint="auth_endpoint",
+ redirect_uri="redirect_uri",
+ client_id="client",
+)
+
+static_cfg_store = StaticClientConfigStore(client_config)
+
+
+@patch("flytekit.clients.auth.authenticator.KeyringStore")
+@patch("flytekit.clients.auth.auth_client.AuthorizationClient.get_creds_from_remote")
+@patch("flytekit.clients.auth.auth_client.AuthorizationClient.refresh_access_token")
+def test_pkce_authenticator(mock_refresh: MagicMock, mock_get_creds: MagicMock, mock_keyring: MagicMock):
+ mock_keyring.retrieve.return_value = None
+ authn = PKCEAuthenticator(ENDPOINT, static_cfg_store)
+ assert authn._verify is None
+
+ authn = PKCEAuthenticator(ENDPOINT, static_cfg_store, verify=False)
+ assert authn._verify is False
+
+ assert authn._creds is None
+ assert authn._auth_client is None
+ authn.refresh_credentials()
+ assert authn._auth_client
+ mock_get_creds.assert_called()
+ mock_refresh.assert_not_called()
+ mock_keyring.store.assert_called()
+
+ authn.refresh_credentials()
+ mock_refresh.assert_called()
+
+
+@patch("subprocess.run")
+def test_command_authenticator(mock_subprocess: MagicMock):
+ with pytest.raises(AuthenticationError):
+ authn = CommandAuthenticator(None) # noqa
+
+ authn = CommandAuthenticator(["echo"])
+
+ authn.refresh_credentials()
+ assert authn._creds
+ mock_subprocess.assert_called()
+
+ mock_subprocess.side_effect = subprocess.CalledProcessError(-1, ["x"])
+
+ with pytest.raises(AuthenticationError):
+ authn.refresh_credentials()
+
+
+def test_get_basic_authorization_header():
+ header = ClientCredentialsAuthenticator.get_basic_authorization_header("client_id", "abc")
+ assert header == "Basic Y2xpZW50X2lkOmFiYw=="
+
+
+@patch("flytekit.clients.auth.authenticator.requests")
+def test_get_token(mock_requests):
+ response = MagicMock()
+ response.status_code = 200
+ response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""")
+ mock_requests.post.return_value = response
+ access, expiration = ClientCredentialsAuthenticator.get_token("https://corp.idp.net", "abc123", ["my_scope"])
+ assert access == "abc"
+ assert expiration == 60
+
+
+@patch("flytekit.clients.auth.authenticator.requests")
+def test_client_creds_authenticator(mock_requests):
+ authn = ClientCredentialsAuthenticator(
+ ENDPOINT, client_id="client", client_secret="secret", cfg_store=static_cfg_store
+ )
+
+ response = MagicMock()
+ response.status_code = 200
+ response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""")
+ mock_requests.post.return_value = response
+ authn.refresh_credentials()
+ assert authn._creds
diff --git a/tests/flytekit/unit/clients/auth/test_default_html.py b/tests/flytekit/unit/clients/auth/test_default_html.py
new file mode 100644
index 0000000000..391fb6a542
--- /dev/null
+++ b/tests/flytekit/unit/clients/auth/test_default_html.py
@@ -0,0 +1,18 @@
+from flytekit.clients.auth.default_html import get_default_success_html
+
+
+def test_default_html():
+ assert (
+ get_default_success_html("flyte.org")
+ == """
+
+
+ OAuth2 Authentication Success
+
+
+ Successfully logged into flyte.org
+
+
+
+"""
+ ) # noqa
diff --git a/tests/flytekit/unit/clients/auth/test_keyring_store.py b/tests/flytekit/unit/clients/auth/test_keyring_store.py
new file mode 100644
index 0000000000..d068a1f451
--- /dev/null
+++ b/tests/flytekit/unit/clients/auth/test_keyring_store.py
@@ -0,0 +1,32 @@
+from unittest.mock import MagicMock, patch
+
+from keyring.errors import NoKeyringError
+
+from flytekit.clients.auth.keyring import Credentials, KeyringStore
+
+
+@patch("keyring.get_password")
+def test_keyring_store_get(kr_get_password: MagicMock):
+ kr_get_password.return_value = "t"
+ assert KeyringStore.retrieve("example1.com") is not None
+
+ kr_get_password.side_effect = NoKeyringError()
+ assert KeyringStore.retrieve("example2.com") is None
+
+
+@patch("keyring.delete_password")
+def test_keyring_store_delete(kr_del_password: MagicMock):
+ kr_del_password.return_value = None
+ assert KeyringStore.delete("example1.com") is None
+
+ kr_del_password.side_effect = NoKeyringError()
+ assert KeyringStore.delete("example2.com") is None
+
+
+@patch("keyring.set_password")
+def test_keyring_store_set(kr_set_password: MagicMock):
+ kr_set_password.return_value = None
+ assert KeyringStore.store(Credentials(access_token="a", refresh_token="r", for_endpoint="f"))
+
+ kr_set_password.side_effect = NoKeyringError()
+ assert KeyringStore.retrieve("example2.com") is None
diff --git a/tests/flytekit/unit/clients/test_auth_helper.py b/tests/flytekit/unit/clients/test_auth_helper.py
new file mode 100644
index 0000000000..8f14de730e
--- /dev/null
+++ b/tests/flytekit/unit/clients/test_auth_helper.py
@@ -0,0 +1,154 @@
+import os.path
+from unittest.mock import MagicMock, patch
+
+import pytest
+from flyteidl.service.auth_pb2 import OAuth2MetadataResponse, PublicClientAuthConfigResponse
+
+from flytekit.clients.auth.authenticator import (
+ ClientConfig,
+ ClientConfigStore,
+ ClientCredentialsAuthenticator,
+ CommandAuthenticator,
+ PKCEAuthenticator,
+)
+from flytekit.clients.auth.exceptions import AuthenticationError
+from flytekit.clients.auth_helper import (
+ RemoteClientConfigStore,
+ get_authenticator,
+ load_cert,
+ upgrade_channel_to_authenticated,
+ wrap_exceptions_channel,
+)
+from flytekit.clients.grpc_utils.auth_interceptor import AuthUnaryInterceptor
+from flytekit.clients.grpc_utils.wrap_exception_interceptor import RetryExceptionWrapperInterceptor
+from flytekit.configuration import AuthType, PlatformConfig
+
+REDIRECT_URI = "http://localhost:53593/callback"
+
+TOKEN_ENDPOINT = "https://your.domain.io/oauth2/token"
+
+CLIENT_ID = "flytectl"
+
+OAUTH_AUTHORIZE = "https://your.domain.io/oauth2/authorize"
+
+
+def get_auth_service_mock() -> MagicMock:
+ auth_stub_mock = MagicMock()
+ auth_stub_mock.GetPublicClientConfig.return_value = PublicClientAuthConfigResponse(
+ client_id=CLIENT_ID,
+ redirect_uri=REDIRECT_URI,
+ scopes=["offline", "all"],
+ authorization_metadata_key="flyte-authorization",
+ )
+ auth_stub_mock.GetOAuth2Metadata.return_value = OAuth2MetadataResponse(
+ issuer="https://your.domain.io",
+ authorization_endpoint=OAUTH_AUTHORIZE,
+ token_endpoint=TOKEN_ENDPOINT,
+ response_types_supported=["code", "token", "code token"],
+ scopes_supported=["all"],
+ token_endpoint_auth_methods_supported=["client_secret_basic"],
+ jwks_uri="https://your.domain.io/oauth2/jwks",
+ code_challenge_methods_supported=["S256"],
+ grant_types_supported=["client_credentials", "refresh_token", "authorization_code"],
+ )
+ return auth_stub_mock
+
+
+@patch("flytekit.clients.auth_helper.AuthMetadataServiceStub")
+def test_remote_client_config_store(mock_auth_service: MagicMock):
+ ch = MagicMock()
+ cs = RemoteClientConfigStore(ch)
+ mock_auth_service.return_value = get_auth_service_mock()
+
+ ccfg = cs.get_client_config()
+ assert ccfg is not None
+ assert ccfg.client_id == CLIENT_ID
+ assert ccfg.authorization_endpoint == OAUTH_AUTHORIZE
+
+
+def get_client_config() -> ClientConfigStore:
+ cfg_store = MagicMock()
+ cfg_store.get_client_config.return_value = ClientConfig(
+ token_endpoint=TOKEN_ENDPOINT,
+ authorization_endpoint=OAUTH_AUTHORIZE,
+ redirect_uri=REDIRECT_URI,
+ client_id=CLIENT_ID,
+ )
+ return cfg_store
+
+
+def test_get_authenticator_basic():
+ cfg = PlatformConfig(auth_mode=AuthType.BASIC)
+
+ with pytest.raises(ValueError, match="Client ID and Client SECRET both are required"):
+ get_authenticator(cfg, None)
+
+ cfg = PlatformConfig(auth_mode=AuthType.BASIC, client_credentials_secret="xyz", client_id="id")
+ authn = get_authenticator(cfg, get_client_config())
+ assert authn
+ assert isinstance(authn, ClientCredentialsAuthenticator)
+
+ cfg = PlatformConfig(auth_mode=AuthType.CLIENT_CREDENTIALS, client_credentials_secret="xyz", client_id="id")
+ authn = get_authenticator(cfg, get_client_config())
+ assert authn
+ assert isinstance(authn, ClientCredentialsAuthenticator)
+
+ cfg = PlatformConfig(auth_mode=AuthType.CLIENTSECRET, client_credentials_secret="xyz", client_id="id")
+ authn = get_authenticator(cfg, get_client_config())
+ assert authn
+ assert isinstance(authn, ClientCredentialsAuthenticator)
+
+
+def test_get_authenticator_pkce():
+ cfg = PlatformConfig()
+ authn = get_authenticator(cfg, get_client_config())
+ assert authn
+ assert isinstance(authn, PKCEAuthenticator)
+
+ cfg = PlatformConfig(insecure_skip_verify=True)
+ authn = get_authenticator(cfg, get_client_config())
+ assert authn
+ assert isinstance(authn, PKCEAuthenticator)
+ assert authn._verify is False
+
+ cfg = PlatformConfig(ca_cert_file_path="/file")
+ authn = get_authenticator(cfg, get_client_config())
+ assert authn
+ assert isinstance(authn, PKCEAuthenticator)
+ assert authn._verify == "/file"
+
+
+def test_get_authenticator_cmd():
+ cfg = PlatformConfig(auth_mode=AuthType.EXTERNAL_PROCESS)
+ with pytest.raises(AuthenticationError):
+ get_authenticator(cfg, get_client_config())
+
+ cfg = PlatformConfig(auth_mode=AuthType.EXTERNAL_PROCESS, command=["echo"])
+ authn = get_authenticator(cfg, get_client_config())
+ assert authn
+ assert isinstance(authn, CommandAuthenticator)
+
+ cfg = PlatformConfig(auth_mode=AuthType.EXTERNALCOMMAND, command=["echo"])
+ authn = get_authenticator(cfg, get_client_config())
+ assert authn
+ assert isinstance(authn, CommandAuthenticator)
+ assert authn._cmd == ["echo"]
+
+
+def test_wrap_exceptions_channel():
+ ch = MagicMock()
+ out_ch = wrap_exceptions_channel(PlatformConfig(), ch)
+ assert isinstance(out_ch._interceptor, RetryExceptionWrapperInterceptor) # noqa
+
+
+def test_upgrade_channel_to_auth():
+ ch = MagicMock()
+ out_ch = upgrade_channel_to_authenticated(PlatformConfig(), ch)
+ assert isinstance(out_ch._interceptor, AuthUnaryInterceptor) # noqa
+
+
+def test_load_cert():
+ cert_file = os.path.join(os.path.dirname(__file__), "testdata", "rootCACert.pem")
+ f = load_cert(cert_file)
+ assert f
+ print(f)
diff --git a/tests/flytekit/unit/clients/test_raw.py b/tests/flytekit/unit/clients/test_raw.py
index 10a7e09333..ee4e516354 100644
--- a/tests/flytekit/unit/clients/test_raw.py
+++ b/tests/flytekit/unit/clients/test_raw.py
@@ -1,148 +1,9 @@
-import json
-from subprocess import CompletedProcess
+from unittest import mock
-import grpc
-import mock
-import pytest
from flyteidl.admin import project_pb2 as _project_pb2
-from flyteidl.service import auth_pb2
-from mock import MagicMock, patch
-from flytekit.clients.raw import (
- RawSynchronousFlyteClient,
- _handle_invalid_create_request,
- get_basic_authorization_header,
- get_token,
-)
-from flytekit.configuration import AuthType, PlatformConfig
-from flytekit.configuration.internal import Credentials
-
-
-def get_admin_stub_mock() -> mock.MagicMock:
- auth_stub_mock = mock.MagicMock()
- auth_stub_mock.GetPublicClientConfig.return_value = auth_pb2.PublicClientAuthConfigResponse(
- client_id="flytectl",
- redirect_uri="http://localhost:53593/callback",
- scopes=["offline", "all"],
- authorization_metadata_key="flyte-authorization",
- )
- auth_stub_mock.GetOAuth2Metadata.return_value = auth_pb2.OAuth2MetadataResponse(
- issuer="https://your.domain.io",
- authorization_endpoint="https://your.domain.io/oauth2/authorize",
- token_endpoint="https://your.domain.io/oauth2/token",
- response_types_supported=["code", "token", "code token"],
- scopes_supported=["all"],
- token_endpoint_auth_methods_supported=["client_secret_basic"],
- jwks_uri="https://your.domain.io/oauth2/jwks",
- code_challenge_methods_supported=["S256"],
- grant_types_supported=["client_credentials", "refresh_token", "authorization_code"],
- )
- return auth_stub_mock
-
-
-@mock.patch("flytekit.clients.raw.signal_service")
-@mock.patch("flytekit.clients.raw.dataproxy_service")
-@mock.patch("flytekit.clients.raw.auth_service")
-@mock.patch("flytekit.clients.raw._admin_service")
-@mock.patch("flytekit.clients.raw.grpc.insecure_channel")
-@mock.patch("flytekit.clients.raw.grpc.secure_channel")
-def test_client_set_token(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy, mock_signal):
- mock_secure_channel.return_value = True
- mock_channel.return_value = True
- mock_admin.AdminServiceStub.return_value = True
- mock_admin_auth.AuthMetadataServiceStub.return_value = get_admin_stub_mock()
- client = RawSynchronousFlyteClient(PlatformConfig(endpoint="a.b.com", insecure=True))
- client.set_access_token("abc")
- assert client._metadata[0][1] == "Bearer abc"
- assert client.check_access_token("abc")
-
-
-@mock.patch("flytekit.clients.raw.RawSynchronousFlyteClient.set_access_token")
-@mock.patch("flytekit.clients.raw.auth_service")
-@mock.patch("subprocess.run")
-def test_refresh_credentials_from_command(mock_call_to_external_process, mock_admin_auth, mock_set_access_token):
- token = "token"
- command = ["command", "generating", "token"]
-
- mock_admin_auth.AuthMetadataServiceStub.return_value = get_admin_stub_mock()
- client = RawSynchronousFlyteClient(PlatformConfig(command=command))
-
- mock_call_to_external_process.return_value = CompletedProcess(command, 0, stdout=token)
- client._refresh_credentials_from_command()
-
- mock_call_to_external_process.assert_called_with(command, capture_output=True, text=True, check=True)
- mock_set_access_token.assert_called_with(token, client.public_client_config.authorization_metadata_key)
-
-
-@mock.patch("flytekit.clients.raw.signal_service")
-@mock.patch("flytekit.clients.raw.dataproxy_service")
-@mock.patch("flytekit.clients.raw.get_basic_authorization_header")
-@mock.patch("flytekit.clients.raw.get_token")
-@mock.patch("flytekit.clients.raw.auth_service")
-@mock.patch("flytekit.clients.raw._admin_service")
-@mock.patch("flytekit.clients.raw.grpc.insecure_channel")
-@mock.patch("flytekit.clients.raw.grpc.secure_channel")
-def test_refresh_client_credentials_aka_basic(
- mock_secure_channel,
- mock_channel,
- mock_admin,
- mock_admin_auth,
- mock_get_token,
- mock_get_basic_header,
- mock_dataproxy,
- mock_signal,
-):
- mock_secure_channel.return_value = True
- mock_channel.return_value = True
- mock_admin.AdminServiceStub.return_value = True
- mock_get_basic_header.return_value = "Basic 123"
- mock_get_token.return_value = ("token1", 1234567)
-
- mock_admin_auth.AuthMetadataServiceStub.return_value = get_admin_stub_mock()
- client = RawSynchronousFlyteClient(
- PlatformConfig(
- endpoint="a.b.com", insecure=True, client_credentials_secret="sosecret", scopes=["a", "b", "c", "d"]
- )
- )
- client._metadata = None
- assert not client.check_access_token("fdsa")
- client._refresh_credentials_basic()
-
- # Scopes from configuration take precendence.
- mock_get_token.assert_called_once_with("https://your.domain.io/oauth2/token", "Basic 123", "a,b,c,d")
-
- client.set_access_token("token")
- assert client._metadata[0][0] == "authorization"
-
-
-@mock.patch("flytekit.clients.raw.signal_service")
-@mock.patch("flytekit.clients.raw.dataproxy_service")
-@mock.patch("flytekit.clients.raw.auth_service")
-@mock.patch("flytekit.clients.raw._admin_service")
-@mock.patch("flytekit.clients.raw.grpc.insecure_channel")
-@mock.patch("flytekit.clients.raw.grpc.secure_channel")
-def test_raises(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy, mock_signal):
- mock_secure_channel.return_value = True
- mock_channel.return_value = True
- mock_admin.AdminServiceStub.return_value = True
-
- # If the public client config is missing then raise an error
- mocked_auth = get_admin_stub_mock()
- mocked_auth.GetPublicClientConfig.return_value = None
- mock_admin_auth.AuthMetadataServiceStub.return_value = mocked_auth
- client = RawSynchronousFlyteClient(PlatformConfig(endpoint="a.b.com", insecure=True))
- assert client.public_client_config is None
- with pytest.raises(ValueError):
- client._refresh_credentials_basic()
-
- # If the oauth2 metadata is missing then raise an error
- mocked_auth = get_admin_stub_mock()
- mocked_auth.GetOAuth2Metadata.return_value = None
- mock_admin_auth.AuthMetadataServiceStub.return_value = mocked_auth
- client = RawSynchronousFlyteClient(PlatformConfig(endpoint="a.b.com", insecure=True))
- assert client.oauth2_metadata is None
- with pytest.raises(ValueError):
- client._refresh_credentials_basic()
+from flytekit.clients.raw import RawSynchronousFlyteClient
+from flytekit.configuration import PlatformConfig
@mock.patch("flytekit.clients.raw._admin_service")
@@ -161,84 +22,3 @@ def test_list_projects_paginated(mock_channel, mock_admin):
project_list_request = _project_pb2.ProjectListRequest(limit=100, token="", filters=None, sort_by=None)
client.list_projects(project_list_request)
mock_admin.AdminServiceStub().ListProjects.assert_called_with(project_list_request, metadata=None)
-
-
-def test_get_basic_authorization_header():
- header = get_basic_authorization_header("client_id", "abc")
- assert header == "Basic Y2xpZW50X2lkOmFiYw=="
-
-
-@patch("flytekit.clients.raw._requests")
-def test_get_token(mock_requests):
- response = MagicMock()
- response.status_code = 200
- response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""")
- mock_requests.post.return_value = response
- access, expiration = get_token("https://corp.idp.net", "abc123", "my_scope")
- assert access == "abc"
- assert expiration == 60
-
-
-@patch.object(RawSynchronousFlyteClient, "_refresh_credentials_standard")
-def test_refresh_standard(mocked_method):
- cc = RawSynchronousFlyteClient(PlatformConfig())
- cc.refresh_credentials()
- assert mocked_method.called
-
-
-@patch.object(RawSynchronousFlyteClient, "_refresh_credentials_basic")
-def test_refresh_basic(mocked_method):
- cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode=AuthType.BASIC))
- cc.refresh_credentials()
- assert mocked_method.called
-
- cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode=AuthType.CLIENT_CREDENTIALS))
- cc.refresh_credentials()
- assert mocked_method.call_count == 2
-
-
-@patch.object(RawSynchronousFlyteClient, "_refresh_credentials_basic")
-def test_basic_strings(mocked_method):
- cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode="basic"))
- cc.refresh_credentials()
- assert mocked_method.called
-
- cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode="client_credentials"))
- cc.refresh_credentials()
- assert mocked_method.call_count == 2
-
-
-@patch.object(RawSynchronousFlyteClient, "_refresh_credentials_from_command")
-def test_refresh_command(mocked_method):
- cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode=AuthType.EXTERNALCOMMAND))
- cc.refresh_credentials()
- assert mocked_method.called
-
-
-@patch.object(RawSynchronousFlyteClient, "_refresh_credentials_from_command")
-def test_refresh_from_environment_variable(mocked_method, monkeypatch: pytest.MonkeyPatch):
- monkeypatch.setenv(Credentials.AUTH_MODE.legacy.get_env_name(), AuthType.EXTERNAL_PROCESS.name, prepend=False)
- cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode=None).auto(None))
- cc.refresh_credentials()
- assert mocked_method.called
-
-
-def test__handle_invalid_create_request_decorator_happy():
- client = RawSynchronousFlyteClient(PlatformConfig(auth_mode=AuthType.CLIENT_CREDENTIALS))
- mocked_method = client._stub.CreateWorkflow = mock.Mock()
- _handle_invalid_create_request(client.create_workflow("/flyteidl.service.AdminService/CreateWorkflow"))
- mocked_method.assert_called_once()
-
-
-@patch("flytekit.clients.raw.cli_logger")
-@patch("flytekit.clients.raw._MessageToJson")
-def test__handle_invalid_create_request_decorator_raises(mock_to_JSON, mock_logger):
- mock_to_JSON(return_value="test")
- err = grpc.RpcError()
- err.details = "There is already a workflow with different structure."
- err.code = lambda: grpc.StatusCode.INVALID_ARGUMENT
- client = RawSynchronousFlyteClient(PlatformConfig(auth_mode=AuthType.CLIENT_CREDENTIALS))
- client._stub.CreateWorkflow = mock.Mock(side_effect=err)
- with pytest.raises(grpc.RpcError):
- _handle_invalid_create_request(client.create_workflow("/flyteidl.service.AdminService/CreateWorkflow"))
- mock_logger.error.assert_called_with("There is already a workflow with different structure.")
diff --git a/tests/flytekit/unit/clients/testdata/rootCACert.pem b/tests/flytekit/unit/clients/testdata/rootCACert.pem
new file mode 100644
index 0000000000..e9117bced2
--- /dev/null
+++ b/tests/flytekit/unit/clients/testdata/rootCACert.pem
@@ -0,0 +1,17 @@
+-----BEGIN CERTIFICATE-----
+MIICsDCCAZgCCQC8mxWHhdmIjDANBgkqhkiG9w0BAQsFADAaMQswCQYDVQQGEwJC
+UjELMAkGA1UECAwCV0EwHhcNMjMwMjIyMDU1MTAwWhcNMzMwMjE5MDU1MTAwWjAa
+MQswCQYDVQQGEwJCUjELMAkGA1UECAwCV0EwggEiMA0GCSqGSIb3DQEBAQUAA4IB
+DwAwggEKAoIBAQDru2F+xX/Z4Q2W6h5qzwZUpCUjkgjQgSZPHI11hfLqOemXoW9o
+66LtU/QOeMXSLGnhq7WSILqVz9sCmRvQmxYTFuur8eQXPEpqggOVWfJQ6vj4ssvf
+a5j5KWytM9ixgwUmw9xwAAQ4FC1LQsYyD44Lkb+OPJZEK55ZOyGAPblVJW9WSvnX
+DsWXbsDENMU7XlpMtVI5/tToBLaSKIMhbZlssCJtjc7omBTL14yy9L7Bj321/a0m
+8SRH/XtfmTux/HHq60qTvUWiHrL+CjehZQcehGlKpqaYq5sEQCVWt/cfyvHLo0gd
+X6ayAfmno8WeNbXqsuoU+xckI+8WI4vQ2cfnAgMBAAEwDQYJKoZIhvcNAQELBQAD
+ggEBAB++cNraohRh91Xa1GW6IQAInyDnxsXSY0wXrpJsAN541lETdk9+L1CbtHxk
+5hwvrkUItzrUjIIKq11k0NceM3ClYyO826By/DMjWtPMYp0eSXLKJJnAT5euhONy
+9eBLyNFO0yUj77fEiEj6k5PAUBCgs6ZzWTVCgBiNKPAT6WxaYeIwXdQvC0KoJ0t7
+0SD7/I4i9SSjw3lCRZfMKdd7MEPTpi5hXpZPphg9HYJX5o1KSjWvMTYDUzaQtOlf
+GM9zNSXug/GyYgVgUyg2dqp/ohbMtqgFH1kTbvMlLmS6BtQxyi11G6QWv6gdb8z5
+es7Wv+5ZqVjroswzEGi/h72Xo0E=
+-----END CERTIFICATE-----
diff --git a/tests/flytekit/unit/clients/testdata/rootCAKey.pem b/tests/flytekit/unit/clients/testdata/rootCAKey.pem
new file mode 100644
index 0000000000..0f71fce5f1
--- /dev/null
+++ b/tests/flytekit/unit/clients/testdata/rootCAKey.pem
@@ -0,0 +1,27 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIIEowIBAAKCAQEA67thfsV/2eENluoeas8GVKQlI5II0IEmTxyNdYXy6jnpl6Fv
+aOui7VP0DnjF0ixp4au1kiC6lc/bApkb0JsWExbrq/HkFzxKaoIDlVnyUOr4+LLL
+32uY+SlsrTPYsYMFJsPccAAEOBQtS0LGMg+OC5G/jjyWRCueWTshgD25VSVvVkr5
+1w7Fl27AxDTFO15aTLVSOf7U6AS2kiiDIW2ZbLAibY3O6JgUy9eMsvS+wY99tf2t
+JvEkR/17X5k7sfxx6utKk71Foh6y/go3oWUHHoRpSqammKubBEAlVrf3H8rxy6NI
+HV+msgH5p6PFnjW16rLqFPsXJCPvFiOL0NnH5wIDAQABAoIBAA+cVRSEF7diA/he
+gK0qEI1CYYM9hH/qTZMnnOaPfEqukx2Lf0k/cYat7JeYv+DvOAPNzzRiHnkVTreZ
+VBI4cvnIpsq4Nhaj03nCKmKVlkpthRdTH9Un1vWJHL1LlaoLtyeeCNcR6TWdgHJf
+daiTByEVAc51jK3vBYl7NPi9HazZsT8p1IxN8pXrM4yXWZ2UDhVkNfiylmg+qsnt
+s5C8ENmOoAEeUxfQjPUK72UrQAZM7+N+yGvXPHTkJAEs2yx25NbTUZugXa6+Ehlx
+55RuXLbDB0cn/eUp0wmLAryYekTm0AjPy4wsvwoKAkxSnIMdfj6kO0EbEYZv5yFC
+aPTeJyECgYEA+h1L3dqnhykNvu6xk1T6fmJ1up8xYH7syToz/QXTlNHJKYQnmZDt
+u2Y6jmOC4Tv892/CAeDJZqw0KGJb94sDzD4oE0etf79SyfJ3Gh0BBU9P4W1r2uPx
+VNak8V/g/qAfP2cwZz/h9jPWkQ1UcfZa1Mpali4EEHdODAwqxXouHTsCgYEA8Udx
+jzyrQuwINJMKfzws52QYsuuy2IUbhfbz9r7lqauvbuYx0PCUH/fnwxUCqjr1BS/l
+jdy1SjlcvJ1lHsuJ6YaHu30FjNcDiyr0gAEXQeuIpIVD74BdTNJpKoy2AWIdIYtj
+wbjm95V3XOdo/tmFYx+6Ign702Lmi+gkDtDxRUUCgYEAwaLAw6eun5OHEtTVIc1e
+iU5M+wiYP67EPx4Sdcd3APZRmRS5W8i6ZKVGnEoqX5oDxMT/HFkdU6HqV4Ge1c0I
+Sa2tdQ+/IPHMdJCE6PCfg67dlxcRs0tZ4Wa0GDM0i60HxBxteuIYXHXRnkcFo50o
+wSlQbIh/mQfkoqsgyfZHkVUCgYAxiCcp7pyB+o6crGsFP8dAIW5onLZ0eK7zy4S9
+7Oac9F/pdlxXtmvSPERZ6iBH7h6K2BBaFSsqd6gwGGe/8Kz5QeLvfHT9Os7BbSoQ
+dSjfIYlFrQ4LRuDgenmYgJaEpi2wyzrJdDoGLar5aZBGcUVO2h6OClqmRLFrm1Z7
+rC07uQKBgBqrNfLs0aBolWXhD5KcWl7Eg+lr8Zm3htYTXSZaQwpuDLQ/SA7F02fu
+UWbJVNN44xeJZdFtygxqPLOXrRoNYuYwJ6A5SIjERFiwg7kxGqlWKMoVs0HFFsEb
+noBtXOzyx5GEyghotTaAr/wdS7eY8ccmTsKm15td1MLUxjK1nlS/
+-----END RSA PRIVATE KEY-----