-
Notifications
You must be signed in to change notification settings - Fork 300
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Flytekit Auth system overhaul and pretty printing upgrade #1458
Merged
Merged
Changes from 5 commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
36bb637
[wip] New authentication system
kumare3 ec3e8d0
Better error handling and printing, better exception handling and
kumare3 f3df707
Delete legacy files
kumare3 3251e8b
Merge branch 'master' into auth-update
332a912
add missing None
7444543
keyring removed
703346b
added insecure_skip_verify
e10a3bf
test fixed
8bf6528
Test fixed
fca797a
Auth update
5c3cb8f
updated test
kumare3 9994129
updated
kumare3 20dd899
flush buffer instead of closing, was getting a weird stack trace. mak…
wild-endeavor 4a433ab
updated ca-cert logic
kumare3 6fc56a6
Fixed unit tests
kumare3 171be7c
Merge branch 'master' into auth-update
kumare3 a9ebd4b
updated
kumare3 7097db4
test fix
kumare3 8831fe4
updated
kumare3 e8cdf58
nest raise if exc
wild-endeavor 96c3f60
added keyring.alt for tests
kumare3 6127100
updated
kumare3 9c7dc63
Merge branch 'master' into auth-update
kumare3 2c25466
updated
kumare3 9d9feaa
Lint
eapolinario File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,32 +1,30 @@ | ||
from __future__ import annotations | ||
|
||
import base64 as _base64 | ||
import hashlib as _hashlib | ||
import http.server as _BaseHTTPServer | ||
import logging | ||
import multiprocessing | ||
import os as _os | ||
import re as _re | ||
import typing | ||
import urllib.parse as _urlparse | ||
import webbrowser as _webbrowser | ||
from dataclasses import dataclass | ||
from http import HTTPStatus as _StatusCodes | ||
from multiprocessing import get_context as _mp_get_context | ||
from multiprocessing import get_context | ||
from urllib.parse import urlencode as _urlencode | ||
|
||
import keyring as _keyring | ||
import requests as _requests | ||
|
||
from flytekit.loggers import auth_logger | ||
from .exceptions import AccessTokenNotFoundError | ||
from .keyring import Credentials | ||
|
||
_code_verifier_length = 64 | ||
_random_seed_length = 40 | ||
_utf_8 = "utf-8" | ||
|
||
|
||
# Identifies the service used for storing passwords in keyring | ||
_keyring_service_name = "flyteauth" | ||
# Identifies the key used for storing and fetching from keyring. In our case, instead of a username as the keyring docs | ||
# suggest, we are storing a user's oidc. | ||
_keyring_access_token_storage_key = "access_token" | ||
_keyring_refresh_token_storage_key = "refresh_token" | ||
|
||
|
||
def _generate_code_verifier(): | ||
""" | ||
Generates a 'code_verifier' as described in https://tools.ietf.org/html/rfc7636#section-4.1 | ||
|
@@ -77,6 +75,17 @@ def state(self): | |
return self._state | ||
|
||
|
||
@dataclass | ||
class EndpointMetadata(object): | ||
""" | ||
This class can be used to control the rendering of the page shown on login success or failure | ||
""" | ||
|
||
endpoint: str | ||
success_html: typing.Optional[bytes] = None | ||
failure_html: typing.Optional[bytes] = None | ||
|
||
|
||
class OAuthCallbackHandler(_BaseHTTPServer.BaseHTTPRequestHandler): | ||
""" | ||
A simple wrapper around BaseHTTPServer.BaseHTTPRequestHandler that handles a callback URL that accepts an | ||
|
@@ -87,12 +96,28 @@ def do_GET(self): | |
url = _urlparse.urlparse(self.path) | ||
if url.path.strip("/") == self.server.redirect_path.strip("/"): | ||
self.send_response(_StatusCodes.OK) | ||
self.send_header("Content-type", "text/html") | ||
self.end_headers() | ||
self.handle_login(dict(_urlparse.parse_qsl(url.query))) | ||
if self.server.remote_metadata.success_html is None: | ||
self.wfile.write( | ||
bytes( | ||
f""" | ||
<html> | ||
<head> | ||
<title>Oauth2 authentication Flow</title> | ||
</head> | ||
<body> | ||
<h1>Log in successful to {self.server.remote_metadata.endpoint}</h1> | ||
</body></html>""", | ||
"utf-8", | ||
) | ||
) | ||
self.wfile.close() | ||
else: | ||
self.send_response(_StatusCodes.NOT_FOUND) | ||
|
||
def handle_login(self, data): | ||
def handle_login(self, data: dict): | ||
self.server.handle_authorization_code(AuthorizationCode(data["code"], data["state"])) | ||
|
||
|
||
|
@@ -104,49 +129,79 @@ class OAuthHTTPServer(_BaseHTTPServer.HTTPServer): | |
|
||
def __init__( | ||
self, | ||
server_address, | ||
RequestHandlerClass, | ||
bind_and_activate=True, | ||
redirect_path=None, | ||
queue=None, | ||
server_address: typing.Tuple[str, int], | ||
remote_metadata: EndpointMetadata, | ||
request_handler_class: typing.Type[_BaseHTTPServer.BaseHTTPRequestHandler], | ||
bind_and_activate: bool = True, | ||
redirect_path: str = None, | ||
queue: multiprocessing.Queue = None, | ||
): | ||
_BaseHTTPServer.HTTPServer.__init__(self, server_address, RequestHandlerClass, bind_and_activate) | ||
_BaseHTTPServer.HTTPServer.__init__(self, server_address, request_handler_class, bind_and_activate) | ||
self._redirect_path = redirect_path | ||
self._remote_metadata = remote_metadata | ||
self._auth_code = None | ||
self._queue = queue | ||
|
||
@property | ||
def redirect_path(self): | ||
def redirect_path(self) -> str: | ||
return self._redirect_path | ||
|
||
def handle_authorization_code(self, auth_code): | ||
@property | ||
def remote_metadata(self) -> EndpointMetadata: | ||
return self._remote_metadata | ||
|
||
def handle_authorization_code(self, auth_code: str): | ||
self._queue.put(auth_code) | ||
self.server_close() | ||
|
||
def handle_request(self, queue=None): | ||
def handle_request(self, queue: multiprocessing.Queue = None) -> typing.Any: | ||
self._queue = queue | ||
return super().handle_request() | ||
|
||
|
||
class Credentials(object): | ||
def __init__(self, access_token=None): | ||
self._access_token = access_token | ||
class _SingletonPerEndpoint(type): | ||
""" | ||
A metaclass to create per endpoint singletons for AuthorizationClient objects | ||
""" | ||
|
||
@property | ||
def access_token(self): | ||
return self._access_token | ||
_instances: typing.Dict[str, AuthorizationClient] = {} | ||
|
||
def __call__(cls, *args, **kwargs): | ||
endpoint = "" | ||
if args: | ||
endpoint = args[0] | ||
elif "auth_endpoint" in kwargs: | ||
endpoint = kwargs["auth_endpoint"] | ||
else: | ||
raise ValueError("parameter auth_endpoint is required") | ||
if endpoint not in cls._instances: | ||
cls._instances[endpoint] = super(_SingletonPerEndpoint, cls).__call__(*args, **kwargs) | ||
return cls._instances[endpoint] | ||
|
||
|
||
class AuthorizationClient(object): | ||
class AuthorizationClient(metaclass=_SingletonPerEndpoint): | ||
""" | ||
Authorization client that stores the credentials in keyring and uses oauth2 standard flow to retrieve the | ||
credentials. NOTE: This will open a web browser to retrieve the credentials. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
auth_endpoint=None, | ||
token_endpoint=None, | ||
scopes=None, | ||
client_id=None, | ||
redirect_uri=None, | ||
endpoint: str, | ||
auth_endpoint: str, | ||
token_endpoint: str, | ||
scopes: typing.List[str] = None, | ||
client_id: str = None, | ||
redirect_uri: str = None, | ||
endpoint_metadata: EndpointMetadata = None, | ||
): | ||
self._endpoint = endpoint | ||
self._auth_endpoint = auth_endpoint | ||
if endpoint_metadata is None: | ||
remote_url = _urlparse.urlparse(self._auth_endpoint) | ||
self._remote = EndpointMetadata(endpoint=remote_url.hostname) | ||
else: | ||
self._remote = endpoint_metadata | ||
self._token_endpoint = token_endpoint | ||
self._client_id = client_id | ||
self._scopes = scopes or [] | ||
|
@@ -156,10 +211,7 @@ def __init__( | |
self._code_challenge = code_challenge | ||
state = _generate_state_parameter() | ||
self._state = state | ||
self._credentials = None | ||
self._refresh_token = None | ||
self._headers = {"content-type": "application/x-www-form-urlencoded"} | ||
self._expired = False | ||
|
||
self._params = { | ||
"client_id": client_id, # This must match the Client ID of the OAuth application. | ||
|
@@ -174,55 +226,27 @@ def __init__( | |
"code_challenge_method": "S256", | ||
} | ||
|
||
# Prefer to use already-fetched token values when they've been set globally. | ||
self.set_tokens_from_store() | ||
|
||
def __repr__(self): | ||
return f"AuthorizationClient({self._auth_endpoint}, {self._token_endpoint}, {self._client_id}, {self._scopes}, {self._redirect_uri})" | ||
|
||
def set_tokens_from_store(self): | ||
self._refresh_token = _keyring.get_password(_keyring_service_name, _keyring_refresh_token_storage_key) | ||
access_token = _keyring.get_password(_keyring_service_name, _keyring_access_token_storage_key) | ||
if access_token: | ||
self._credentials = Credentials(access_token=access_token) | ||
|
||
@property | ||
def has_valid_credentials(self) -> bool: | ||
return self._credentials is not None | ||
|
||
def start_authorization_flow(self): | ||
# In the absence of globally-set token values, initiate the token request flow | ||
ctx = _mp_get_context("fork") | ||
q = ctx.Queue() | ||
|
||
# First prepare the callback server in the background | ||
server = self._create_callback_server() | ||
server_process = ctx.Process(target=server.handle_request, args=(q,)) | ||
server_process.daemon = True | ||
server_process.start() | ||
|
||
# Send the call to request the authorization code in the background | ||
self._request_authorization_code() | ||
|
||
# Request the access token once the auth code has been received. | ||
auth_code = q.get() | ||
server_process.terminate() | ||
self.request_access_token(auth_code) | ||
|
||
def _create_callback_server(self): | ||
server_url = _urlparse.urlparse(self._redirect_uri) | ||
server_address = (server_url.hostname, server_url.port) | ||
return OAuthHTTPServer(server_address, OAuthCallbackHandler, redirect_path=server_url.path) | ||
return OAuthHTTPServer( | ||
server_address, | ||
self._remote, | ||
OAuthCallbackHandler, | ||
redirect_path=server_url.path, | ||
) | ||
|
||
def _request_authorization_code(self): | ||
scheme, netloc, path, _, _, _ = _urlparse.urlparse(self._auth_endpoint) | ||
query = _urlencode(self._params) | ||
endpoint = _urlparse.urlunparse((scheme, netloc, path, None, query, None)) | ||
auth_logger.debug(f"Requesting authorization code through {endpoint}") | ||
logging.debug(f"Requesting authorization code through {endpoint}") | ||
_webbrowser.open_new_tab(endpoint) | ||
|
||
def _initialize_credentials(self, auth_token_resp): | ||
|
||
def _credentials_from_response(self, auth_token_resp) -> Credentials: | ||
""" | ||
The auth_token_resp body is of the form: | ||
{ | ||
|
@@ -232,21 +256,16 @@ def _initialize_credentials(self, auth_token_resp): | |
} | ||
""" | ||
response_body = auth_token_resp.json() | ||
refresh_token = None | ||
if "access_token" not in response_body: | ||
raise ValueError('Expected "access_token" in response from oauth server') | ||
if "refresh_token" in response_body: | ||
self._refresh_token = response_body["refresh_token"] | ||
_keyring.set_password( | ||
_keyring_service_name, _keyring_refresh_token_storage_key, response_body["refresh_token"] | ||
) | ||
|
||
refresh_token = response_body["refresh_token"] | ||
access_token = response_body["access_token"] | ||
_keyring.set_password(_keyring_service_name, _keyring_access_token_storage_key, access_token) | ||
|
||
# Once keyring credentials have been updated, get the singleton AuthorizationClient to read them again. | ||
self.set_tokens_from_store() | ||
return Credentials(access_token, refresh_token, self._endpoint) | ||
|
||
def request_access_token(self, auth_code): | ||
def _request_access_token(self, auth_code) -> Credentials: | ||
if self._state != auth_code.state: | ||
raise ValueError(f"Unexpected state parameter [{auth_code.state}] passed") | ||
self._params.update( | ||
|
@@ -269,38 +288,52 @@ def request_access_token(self, auth_code): | |
raise Exception( | ||
"Failed to request access token with response: [{}] {}".format(resp.status_code, resp.content) | ||
) | ||
self._initialize_credentials(resp) | ||
return self._credentials_from_response(resp) | ||
|
||
def get_creds_from_remote(self) -> Credentials: | ||
""" | ||
This is the entrypoint method. It will kickoff the full authentication flow and trigger a web-browser to | ||
retrieve credentials | ||
""" | ||
# In the absence of globally-set token values, initiate the token request flow | ||
ctx = get_context("fork") | ||
q = ctx.Queue() | ||
|
||
# First prepare the callback server in the background | ||
server = self._create_callback_server() | ||
|
||
def refresh_access_token(self): | ||
if self._refresh_token is None: | ||
server_process = ctx.Process(target=server.handle_request, args=(q,)) | ||
server_process.daemon = True | ||
|
||
try: | ||
server_process.start() | ||
|
||
# Send the call to request the authorization code in the background | ||
self._request_authorization_code() | ||
|
||
# Request the access token once the auth code has been received. | ||
auth_code = q.get() | ||
return self._request_access_token(auth_code) | ||
finally: | ||
server_process.terminate() | ||
|
||
def refresh_access_token(self, credentials: Credentials) -> Credentials: | ||
if credentials.refresh_token is None: | ||
raise ValueError("no refresh token available with which to refresh authorization credentials") | ||
|
||
resp = _requests.post( | ||
url=self._token_endpoint, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Comment: We should consider passing — There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Reply: let me incorporate your change |
||
data={"grant_type": "refresh_token", "client_id": self._client_id, "refresh_token": self._refresh_token}, | ||
data={ | ||
"grant_type": "refresh_token", | ||
"client_id": self._client_id, | ||
"refresh_token": credentials.refresh_token, | ||
}, | ||
headers=self._headers, | ||
allow_redirects=False, | ||
) | ||
if resp.status_code != _StatusCodes.OK: | ||
self._expired = True | ||
# In the absence of a successful response, assume the refresh token is expired. This should indicate | ||
# to the caller that the AuthorizationClient is defunct and a new one needs to be re-initialized. | ||
raise AccessTokenNotFoundError(f"Non-200 returned from refresh token endpoint {resp.status_code}") | ||
|
||
_keyring.delete_password(_keyring_service_name, _keyring_access_token_storage_key) | ||
_keyring.delete_password(_keyring_service_name, _keyring_refresh_token_storage_key) | ||
raise ValueError(f"Non-200 returned from refresh token endpoint {resp.status_code}") | ||
self._initialize_credentials(resp) | ||
|
||
@property | ||
def credentials(self): | ||
""" | ||
:return flytekit.clis.auth.auth.Credentials: | ||
""" | ||
return self._credentials | ||
|
||
@property | ||
def expired(self): | ||
""" | ||
:return bool: | ||
""" | ||
return self._expired | ||
return self._credentials_from_response(resp) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: add "please close this window" or something like that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will do. I will make it configurable, and also add the Flyte icon.