Skip to content

Commit

Permalink
Flytekit Auth system overhaul and pretty printing upgrade (#1458)
Browse files Browse the repository at this point in the history
* [wip] New authentication system

 - Reuse local keyring better
 - use grpc based auth system

Signed-off-by: Ketan Umare <[email protected]>

* Better error handling and printing, better exception handling and
retrying

Signed-off-by: Ketan Umare <[email protected]>

* Delete legacy files

Signed-off-by: Ketan Umare <[email protected]>

* add missing None

Signed-off-by: Ketan Umare <[email protected]>

* keyring removed

Signed-off-by: Ketan Umare <[email protected]>

* added insecure_skip_verify

Signed-off-by: Ketan Umare <[email protected]>

* test fixed

Signed-off-by: Ketan Umare <[email protected]>

* Test fixed

Signed-off-by: Ketan Umare <[email protected]>

* Auth update

Signed-off-by: Ketan Umare <[email protected]>

* updated test

Signed-off-by: Ketan Umare <[email protected]>

* updated

Signed-off-by: Ketan Umare <[email protected]>

* flush buffer instead of closing, was getting a weird stack trace. make the image smaller

Signed-off-by: Yee Hing Tong <[email protected]>

* updated ca-cert logic

Signed-off-by: Ketan Umare <[email protected]>

* Fixed unit tests

Signed-off-by: Ketan Umare <[email protected]>

* updated

Signed-off-by: Ketan Umare <[email protected]>

* test fix

Signed-off-by: Ketan Umare <[email protected]>

* updated

Signed-off-by: Ketan Umare <[email protected]>

* nest raise if exc

Signed-off-by: Yee Hing Tong <[email protected]>

* added keyring.alt for tests

Signed-off-by: Ketan Umare <[email protected]>

* updated

Signed-off-by: Ketan Umare <[email protected]>

* updated

Signed-off-by: Ketan Umare <[email protected]>

* Lint

Signed-off-by: Eduardo Apolinario <[email protected]>

---------

Signed-off-by: Ketan Umare <[email protected]>
Signed-off-by: Ketan Umare <[email protected]>
Signed-off-by: Ketan Umare <[email protected]>
Signed-off-by: Ketan Umare <[email protected]>
Signed-off-by: Yee Hing Tong <[email protected]>
Signed-off-by: Eduardo Apolinario <[email protected]>
Co-authored-by: Ketan Umare <[email protected]>
Co-authored-by: Ketan Umare <[email protected]>
Co-authored-by: Ketan Umare <[email protected]>
Co-authored-by: Yee Hing Tong <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
Signed-off-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
6 people committed Feb 28, 2023
1 parent d75c800 commit 54e66cb
Show file tree
Hide file tree
Showing 29 changed files with 1,256 additions and 731 deletions.
4 changes: 4 additions & 0 deletions dev-requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ google-cloud-bigquery-storage
IPython
tensorflow
grpcio-status<1.49.0
keyrings.alt

# Only install tensorflow if not running on an arm Mac.
tensorflow==2.8.1; platform_machine!='arm64' or platform_system!='Darwin'
# Newer versions of torch bring in nvidia dependencies that are not present in windows, so
# we put this constraint while we do not have per-environment requirements files
torch<=1.12.1
Expand Down
File renamed without changes.
251 changes: 147 additions & 104 deletions flytekit/clis/auth/auth.py → flytekit/clients/auth/auth_client.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,31 @@
from __future__ import annotations

import base64 as _base64
import hashlib as _hashlib
import http.server as _BaseHTTPServer
import logging
import multiprocessing
import os as _os
import re as _re
import typing
import urllib.parse as _urlparse
import webbrowser as _webbrowser
from dataclasses import dataclass
from http import HTTPStatus as _StatusCodes
from multiprocessing import get_context as _mp_get_context
from multiprocessing import get_context
from urllib.parse import urlencode as _urlencode

import keyring as _keyring
import requests as _requests

from flytekit.loggers import auth_logger
from .default_html import get_default_success_html
from .exceptions import AccessTokenNotFoundError
from .keyring import Credentials

_code_verifier_length = 64
_random_seed_length = 40
_utf_8 = "utf-8"


# Identifies the service used for storing passwords in keyring
_keyring_service_name = "flyteauth"
# Identifies the key used for storing and fetching from keyring. In our case, instead of a username as the keyring docs
# suggest, we are storing a user's oidc.
_keyring_access_token_storage_key = "access_token"
_keyring_refresh_token_storage_key = "refresh_token"


def _generate_code_verifier():
"""
Generates a 'code_verifier' as described in https://tools.ietf.org/html/rfc7636#section-4.1
Expand Down Expand Up @@ -77,6 +76,17 @@ def state(self):
return self._state


@dataclass
class EndpointMetadata(object):
"""
This class can be used to control the rendering of the page on login successful or failure
"""

endpoint: str
success_html: typing.Optional[bytes] = None
failure_html: typing.Optional[bytes] = None


class OAuthCallbackHandler(_BaseHTTPServer.BaseHTTPRequestHandler):
"""
A simple wrapper around BaseHTTPServer.BaseHTTPRequestHandler that handles a callback URL that accepts an
Expand All @@ -87,12 +97,16 @@ def do_GET(self):
url = _urlparse.urlparse(self.path)
if url.path.strip("/") == self.server.redirect_path.strip("/"):
self.send_response(_StatusCodes.OK)
self.send_header("Content-type", "text/html")
self.end_headers()
self.handle_login(dict(_urlparse.parse_qsl(url.query)))
if self.server.remote_metadata.success_html is None:
self.wfile.write(bytes(get_default_success_html(self.server.remote_metadata.endpoint), "utf-8"))
self.wfile.flush()
else:
self.send_response(_StatusCodes.NOT_FOUND)

def handle_login(self, data):
def handle_login(self, data: dict):
self.server.handle_authorization_code(AuthorizationCode(data["code"], data["state"]))


Expand All @@ -104,49 +118,97 @@ class OAuthHTTPServer(_BaseHTTPServer.HTTPServer):

def __init__(
self,
server_address,
RequestHandlerClass,
bind_and_activate=True,
redirect_path=None,
queue=None,
server_address: typing.Tuple[str, int],
remote_metadata: EndpointMetadata,
request_handler_class: typing.Type[_BaseHTTPServer.BaseHTTPRequestHandler],
bind_and_activate: bool = True,
redirect_path: str = None,
queue: multiprocessing.Queue = None,
):
_BaseHTTPServer.HTTPServer.__init__(self, server_address, RequestHandlerClass, bind_and_activate)
_BaseHTTPServer.HTTPServer.__init__(self, server_address, request_handler_class, bind_and_activate)
self._redirect_path = redirect_path
self._remote_metadata = remote_metadata
self._auth_code = None
self._queue = queue

@property
def redirect_path(self):
def redirect_path(self) -> str:
return self._redirect_path

def handle_authorization_code(self, auth_code):
@property
def remote_metadata(self) -> EndpointMetadata:
return self._remote_metadata

def handle_authorization_code(self, auth_code: str):
self._queue.put(auth_code)
self.server_close()

def handle_request(self, queue=None):
def handle_request(self, queue: multiprocessing.Queue = None) -> typing.Any:
self._queue = queue
return super().handle_request()


class Credentials(object):
def __init__(self, access_token=None):
self._access_token = access_token
class _SingletonPerEndpoint(type):
"""
A metaclass to create per endpoint singletons for AuthorizationClient objects
"""

_instances: typing.Dict[str, AuthorizationClient] = {}

@property
def access_token(self):
return self._access_token
def __call__(cls, *args, **kwargs):
endpoint = ""
if args:
endpoint = args[0]
elif "auth_endpoint" in kwargs:
endpoint = kwargs["auth_endpoint"]
else:
raise ValueError("parameter auth_endpoint is required")
if endpoint not in cls._instances:
cls._instances[endpoint] = super(_SingletonPerEndpoint, cls).__call__(*args, **kwargs)
return cls._instances[endpoint]


class AuthorizationClient(object):
class AuthorizationClient(metaclass=_SingletonPerEndpoint):
"""
Authorization client that stores the credentials in keyring and uses oauth2 standard flow to retrieve the
credentials. NOTE: This will open an web browser to retreive the credentials.
"""

def __init__(
self,
auth_endpoint=None,
token_endpoint=None,
scopes=None,
client_id=None,
redirect_uri=None,
endpoint: str,
auth_endpoint: str,
token_endpoint: str,
scopes: typing.Optional[typing.List[str]] = None,
client_id: typing.Optional[str] = None,
redirect_uri: typing.Optional[str] = None,
endpoint_metadata: typing.Optional[EndpointMetadata] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
):
"""
Create new AuthorizationClient
:param endpoint: str endpoint to connect to
:param auth_endpoint: str endpoint where auth metadata can be found
:param token_endpoint: str endpoint to retrieve token from
:param scopes: list[str] oauth2 scopes
:param client_id
:param verify: (optional) Either a boolean, in which case it controls whether we verify
the server's TLS certificate, or a string, in which case it must be a path
to a CA bundle to use. Defaults to ``True``. When set to
``False``, requests will accept any TLS certificate presented by
the server, and will ignore hostname mismatches and/or expired
certificates, which will make your application vulnerable to
man-in-the-middle (MitM) attacks. Setting verify to ``False``
may be useful during local development or testing.
"""
self._endpoint = endpoint
self._auth_endpoint = auth_endpoint
if endpoint_metadata is None:
remote_url = _urlparse.urlparse(self._auth_endpoint)
self._remote = EndpointMetadata(endpoint=remote_url.hostname)
else:
self._remote = endpoint_metadata
self._token_endpoint = token_endpoint
self._client_id = client_id
self._scopes = scopes or []
Expand All @@ -156,10 +218,8 @@ def __init__(
self._code_challenge = code_challenge
state = _generate_state_parameter()
self._state = state
self._credentials = None
self._refresh_token = None
self._verify = verify
self._headers = {"content-type": "application/x-www-form-urlencoded"}
self._expired = False

self._params = {
"client_id": client_id, # This must match the Client ID of the OAuth application.
Expand All @@ -174,55 +234,27 @@ def __init__(
"code_challenge_method": "S256",
}

# Prefer to use already-fetched token values when they've been set globally.
self.set_tokens_from_store()

def __repr__(self):
return f"AuthorizationClient({self._auth_endpoint}, {self._token_endpoint}, {self._client_id}, {self._scopes}, {self._redirect_uri})"

def set_tokens_from_store(self):
self._refresh_token = _keyring.get_password(_keyring_service_name, _keyring_refresh_token_storage_key)
access_token = _keyring.get_password(_keyring_service_name, _keyring_access_token_storage_key)
if access_token:
self._credentials = Credentials(access_token=access_token)

@property
def has_valid_credentials(self) -> bool:
return self._credentials is not None

def start_authorization_flow(self):
# In the absence of globally-set token values, initiate the token request flow
ctx = _mp_get_context("fork")
q = ctx.Queue()

# First prepare the callback server in the background
server = self._create_callback_server()
server_process = ctx.Process(target=server.handle_request, args=(q,))
server_process.daemon = True
server_process.start()

# Send the call to request the authorization code in the background
self._request_authorization_code()

# Request the access token once the auth code has been received.
auth_code = q.get()
server_process.terminate()
self.request_access_token(auth_code)

def _create_callback_server(self):
server_url = _urlparse.urlparse(self._redirect_uri)
server_address = (server_url.hostname, server_url.port)
return OAuthHTTPServer(server_address, OAuthCallbackHandler, redirect_path=server_url.path)
return OAuthHTTPServer(
server_address,
self._remote,
OAuthCallbackHandler,
redirect_path=server_url.path,
)

def _request_authorization_code(self):
scheme, netloc, path, _, _, _ = _urlparse.urlparse(self._auth_endpoint)
query = _urlencode(self._params)
endpoint = _urlparse.urlunparse((scheme, netloc, path, None, query, None))
auth_logger.debug(f"Requesting authorization code through {endpoint}")
logging.debug(f"Requesting authorization code through {endpoint}")
_webbrowser.open_new_tab(endpoint)

def _initialize_credentials(self, auth_token_resp):

def _credentials_from_response(self, auth_token_resp) -> Credentials:
"""
The auth_token_resp body is of the form:
{
Expand All @@ -232,21 +264,16 @@ def _initialize_credentials(self, auth_token_resp):
}
"""
response_body = auth_token_resp.json()
refresh_token = None
if "access_token" not in response_body:
raise ValueError('Expected "access_token" in response from oauth server')
if "refresh_token" in response_body:
self._refresh_token = response_body["refresh_token"]
_keyring.set_password(
_keyring_service_name, _keyring_refresh_token_storage_key, response_body["refresh_token"]
)

refresh_token = response_body["refresh_token"]
access_token = response_body["access_token"]
_keyring.set_password(_keyring_service_name, _keyring_access_token_storage_key, access_token)

# Once keyring credentials have been updated, get the singleton AuthorizationClient to read them again.
self.set_tokens_from_store()
return Credentials(access_token, refresh_token, self._endpoint)

def request_access_token(self, auth_code):
def _request_access_token(self, auth_code) -> Credentials:
if self._state != auth_code.state:
raise ValueError(f"Unexpected state parameter [{auth_code.state}] passed")
self._params.update(
Expand All @@ -262,45 +289,61 @@ def request_access_token(self, auth_code):
data=self._params,
headers=self._headers,
allow_redirects=False,
verify=self._verify,
)
if resp.status_code != _StatusCodes.OK:
# TODO: handle expected (?) error cases:
# https://auth0.com/docs/flows/guides/device-auth/call-api-device-auth#token-responses
raise Exception(
"Failed to request access token with response: [{}] {}".format(resp.status_code, resp.content)
)
self._initialize_credentials(resp)
return self._credentials_from_response(resp)

def get_creds_from_remote(self) -> Credentials:
"""
This is the entrypoint method. It will kickoff the full authentication flow and trigger a web-browser to
retrieve credentials
"""
# In the absence of globally-set token values, initiate the token request flow
ctx = get_context("fork")
q = ctx.Queue()

# First prepare the callback server in the background
server = self._create_callback_server()

server_process = ctx.Process(target=server.handle_request, args=(q,))
server_process.daemon = True

def refresh_access_token(self):
if self._refresh_token is None:
try:
server_process.start()

# Send the call to request the authorization code in the background
self._request_authorization_code()

# Request the access token once the auth code has been received.
auth_code = q.get()
return self._request_access_token(auth_code)
finally:
server_process.terminate()

def refresh_access_token(self, credentials: Credentials) -> Credentials:
if credentials.refresh_token is None:
raise ValueError("no refresh token available with which to refresh authorization credentials")

resp = _requests.post(
url=self._token_endpoint,
data={"grant_type": "refresh_token", "client_id": self._client_id, "refresh_token": self._refresh_token},
data={
"grant_type": "refresh_token",
"client_id": self._client_id,
"refresh_token": credentials.refresh_token,
},
headers=self._headers,
allow_redirects=False,
verify=self._verify,
)
if resp.status_code != _StatusCodes.OK:
self._expired = True
# In the absence of a successful response, assume the refresh token is expired. This should indicate
# to the caller that the AuthorizationClient is defunct and a new one needs to be re-initialized.
raise AccessTokenNotFoundError(f"Non-200 returned from refresh token endpoint {resp.status_code}")

_keyring.delete_password(_keyring_service_name, _keyring_access_token_storage_key)
_keyring.delete_password(_keyring_service_name, _keyring_refresh_token_storage_key)
raise ValueError(f"Non-200 returned from refresh token endpoint {resp.status_code}")
self._initialize_credentials(resp)

@property
def credentials(self):
"""
:return flytekit.clis.auth.auth.Credentials:
"""
return self._credentials

@property
def expired(self):
"""
:return bool:
"""
return self._expired
return self._credentials_from_response(resp)
Loading

0 comments on commit 54e66cb

Please sign in to comment.