Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flytekit Auth system overhaul and pretty printing upgrade #1458

Merged
merged 25 commits into from
Feb 23, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
241 changes: 137 additions & 104 deletions flytekit/clis/auth/auth.py → flytekit/clients/auth/auth_client.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,30 @@
from __future__ import annotations

import base64 as _base64
import hashlib as _hashlib
import http.server as _BaseHTTPServer
import logging
import multiprocessing
import os as _os
import re as _re
import typing
import urllib.parse as _urlparse
import webbrowser as _webbrowser
from dataclasses import dataclass
from http import HTTPStatus as _StatusCodes
from multiprocessing import get_context as _mp_get_context
from multiprocessing import get_context
from urllib.parse import urlencode as _urlencode

import keyring as _keyring
import requests as _requests

from flytekit.loggers import auth_logger
from .exceptions import AccessTokenNotFoundError
from .keyring import Credentials

_code_verifier_length = 64
_random_seed_length = 40
_utf_8 = "utf-8"


# Identifies the service used for storing passwords in keyring
_keyring_service_name = "flyteauth"
# Identifies the key used for storing and fetching from keyring. In our case, instead of a username as the keyring docs
# suggest, we are storing a user's oidc.
_keyring_access_token_storage_key = "access_token"
_keyring_refresh_token_storage_key = "refresh_token"


def _generate_code_verifier():
"""
Generates a 'code_verifier' as described in https://tools.ietf.org/html/rfc7636#section-4.1
Expand Down Expand Up @@ -77,6 +75,17 @@ def state(self):
return self._state


@dataclass
class EndpointMetadata(object):
"""
This class can be used to control the rendering of the page on login successful or failure
"""

endpoint: str
success_html: typing.Optional[bytes] = None
failure_html: typing.Optional[bytes] = None


class OAuthCallbackHandler(_BaseHTTPServer.BaseHTTPRequestHandler):
"""
A simple wrapper around BaseHTTPServer.BaseHTTPRequestHandler that handles a callback URL that accepts an
Expand All @@ -87,12 +96,28 @@ def do_GET(self):
url = _urlparse.urlparse(self.path)
if url.path.strip("/") == self.server.redirect_path.strip("/"):
self.send_response(_StatusCodes.OK)
self.send_header("Content-type", "text/html")
self.end_headers()
self.handle_login(dict(_urlparse.parse_qsl(url.query)))
if self.server.remote_metadata.success_html is None:
self.wfile.write(
bytes(
f"""
<html>
<head>
<title>Oauth2 authentication Flow</title>
</head>
<body>
<h1>Log in successful to {self.server.remote_metadata.endpoint}</h1>
</body></html>""",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: add please close this window or something like that..

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will do, i will make it a configurable thing, also add Flyte icon

"utf-8",
)
)
self.wfile.close()
else:
self.send_response(_StatusCodes.NOT_FOUND)

def handle_login(self, data):
def handle_login(self, data: dict):
self.server.handle_authorization_code(AuthorizationCode(data["code"], data["state"]))


Expand All @@ -104,49 +129,79 @@ class OAuthHTTPServer(_BaseHTTPServer.HTTPServer):

def __init__(
self,
server_address,
RequestHandlerClass,
bind_and_activate=True,
redirect_path=None,
queue=None,
server_address: typing.Tuple[str, int],
remote_metadata: EndpointMetadata,
request_handler_class: typing.Type[_BaseHTTPServer.BaseHTTPRequestHandler],
bind_and_activate: bool = True,
redirect_path: str = None,
queue: multiprocessing.Queue = None,
):
_BaseHTTPServer.HTTPServer.__init__(self, server_address, RequestHandlerClass, bind_and_activate)
_BaseHTTPServer.HTTPServer.__init__(self, server_address, request_handler_class, bind_and_activate)
self._redirect_path = redirect_path
self._remote_metadata = remote_metadata
self._auth_code = None
self._queue = queue

@property
def redirect_path(self):
def redirect_path(self) -> str:
return self._redirect_path

def handle_authorization_code(self, auth_code):
@property
def remote_metadata(self) -> EndpointMetadata:
return self._remote_metadata

def handle_authorization_code(self, auth_code: str):
self._queue.put(auth_code)
self.server_close()

def handle_request(self, queue=None):
def handle_request(self, queue: multiprocessing.Queue = None) -> typing.Any:
self._queue = queue
return super().handle_request()


class Credentials(object):
def __init__(self, access_token=None):
self._access_token = access_token
class _SingletonPerEndpoint(type):
"""
A metaclass to create per endpoint singletons for AuthorizationClient objects
"""

@property
def access_token(self):
return self._access_token
_instances: typing.Dict[str, AuthorizationClient] = {}

def __call__(cls, *args, **kwargs):
endpoint = ""
if args:
endpoint = args[0]
elif "auth_endpoint" in kwargs:
endpoint = kwargs["auth_endpoint"]
else:
raise ValueError("parameter auth_endpoint is required")
if endpoint not in cls._instances:
cls._instances[endpoint] = super(_SingletonPerEndpoint, cls).__call__(*args, **kwargs)
return cls._instances[endpoint]


class AuthorizationClient(object):
class AuthorizationClient(metaclass=_SingletonPerEndpoint):
"""
Authorization client that stores the credentials in keyring and uses oauth2 standard flow to retrieve the
credentials. NOTE: This will open an web browser to retreive the credentials.
"""

def __init__(
self,
auth_endpoint=None,
token_endpoint=None,
scopes=None,
client_id=None,
redirect_uri=None,
endpoint: str,
auth_endpoint: str,
token_endpoint: str,
scopes: typing.List[str] = None,
client_id: str = None,
redirect_uri: str = None,
endpoint_metadata: EndpointMetadata = None,
):
self._endpoint = endpoint
self._auth_endpoint = auth_endpoint
if endpoint_metadata is None:
remote_url = _urlparse.urlparse(self._auth_endpoint)
self._remote = EndpointMetadata(endpoint=remote_url.hostname)
else:
self._remote = endpoint_metadata
self._token_endpoint = token_endpoint
self._client_id = client_id
self._scopes = scopes or []
Expand All @@ -156,10 +211,7 @@ def __init__(
self._code_challenge = code_challenge
state = _generate_state_parameter()
self._state = state
self._credentials = None
self._refresh_token = None
self._headers = {"content-type": "application/x-www-form-urlencoded"}
self._expired = False

self._params = {
"client_id": client_id, # This must match the Client ID of the OAuth application.
Expand All @@ -174,55 +226,27 @@ def __init__(
"code_challenge_method": "S256",
}

# Prefer to use already-fetched token values when they've been set globally.
self.set_tokens_from_store()

def __repr__(self):
return f"AuthorizationClient({self._auth_endpoint}, {self._token_endpoint}, {self._client_id}, {self._scopes}, {self._redirect_uri})"

def set_tokens_from_store(self):
self._refresh_token = _keyring.get_password(_keyring_service_name, _keyring_refresh_token_storage_key)
access_token = _keyring.get_password(_keyring_service_name, _keyring_access_token_storage_key)
if access_token:
self._credentials = Credentials(access_token=access_token)

@property
def has_valid_credentials(self) -> bool:
return self._credentials is not None

def start_authorization_flow(self):
# In the absence of globally-set token values, initiate the token request flow
ctx = _mp_get_context("fork")
q = ctx.Queue()

# First prepare the callback server in the background
server = self._create_callback_server()
server_process = ctx.Process(target=server.handle_request, args=(q,))
server_process.daemon = True
server_process.start()

# Send the call to request the authorization code in the background
self._request_authorization_code()

# Request the access token once the auth code has been received.
auth_code = q.get()
server_process.terminate()
self.request_access_token(auth_code)

def _create_callback_server(self):
server_url = _urlparse.urlparse(self._redirect_uri)
server_address = (server_url.hostname, server_url.port)
return OAuthHTTPServer(server_address, OAuthCallbackHandler, redirect_path=server_url.path)
return OAuthHTTPServer(
server_address,
self._remote,
OAuthCallbackHandler,
redirect_path=server_url.path,
)

def _request_authorization_code(self):
scheme, netloc, path, _, _, _ = _urlparse.urlparse(self._auth_endpoint)
query = _urlencode(self._params)
endpoint = _urlparse.urlunparse((scheme, netloc, path, None, query, None))
auth_logger.debug(f"Requesting authorization code through {endpoint}")
logging.debug(f"Requesting authorization code through {endpoint}")
_webbrowser.open_new_tab(endpoint)

def _initialize_credentials(self, auth_token_resp):

def _credentials_from_response(self, auth_token_resp) -> Credentials:
"""
The auth_token_resp body is of the form:
{
Expand All @@ -232,21 +256,16 @@ def _initialize_credentials(self, auth_token_resp):
}
"""
response_body = auth_token_resp.json()
refresh_token = None
if "access_token" not in response_body:
raise ValueError('Expected "access_token" in response from oauth server')
if "refresh_token" in response_body:
self._refresh_token = response_body["refresh_token"]
_keyring.set_password(
_keyring_service_name, _keyring_refresh_token_storage_key, response_body["refresh_token"]
)

refresh_token = response_body["refresh_token"]
access_token = response_body["access_token"]
_keyring.set_password(_keyring_service_name, _keyring_access_token_storage_key, access_token)

# Once keyring credentials have been updated, get the singleton AuthorizationClient to read them again.
self.set_tokens_from_store()
return Credentials(access_token, refresh_token, self._endpoint)

def request_access_token(self, auth_code):
def _request_access_token(self, auth_code) -> Credentials:
if self._state != auth_code.state:
raise ValueError(f"Unexpected state parameter [{auth_code.state}] passed")
self._params.update(
Expand All @@ -269,38 +288,52 @@ def request_access_token(self, auth_code):
raise Exception(
"Failed to request access token with response: [{}] {}".format(resp.status_code, resp.content)
)
self._initialize_credentials(resp)
return self._credentials_from_response(resp)

def get_creds_from_remote(self) -> Credentials:
"""
This is the entrypoint method. It will kickoff the full authentication flow and trigger a web-browser to
retrieve credentials
"""
# In the absence of globally-set token values, initiate the token request flow
ctx = get_context("fork")
q = ctx.Queue()

# First prepare the callback server in the background
server = self._create_callback_server()

def refresh_access_token(self):
if self._refresh_token is None:
server_process = ctx.Process(target=server.handle_request, args=(q,))
server_process.daemon = True

try:
server_process.start()

# Send the call to request the authorization code in the background
self._request_authorization_code()

# Request the access token once the auth code has been received.
auth_code = q.get()
return self._request_access_token(auth_code)
finally:
server_process.terminate()

def refresh_access_token(self, credentials: Credentials) -> Credentials:
if credentials.refresh_token is None:
raise ValueError("no refresh token available with which to refresh authorization credentials")

resp = _requests.post(
url=self._token_endpoint,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should consider passing verify switch to request calls as in: #1509

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let me incorporate your change

data={"grant_type": "refresh_token", "client_id": self._client_id, "refresh_token": self._refresh_token},
data={
"grant_type": "refresh_token",
"client_id": self._client_id,
"refresh_token": credentials.refresh_token,
},
headers=self._headers,
allow_redirects=False,
)
if resp.status_code != _StatusCodes.OK:
self._expired = True
# In the absence of a successful response, assume the refresh token is expired. This should indicate
# to the caller that the AuthorizationClient is defunct and a new one needs to be re-initialized.
raise AccessTokenNotFoundError(f"Non-200 returned from refresh token endpoint {resp.status_code}")

_keyring.delete_password(_keyring_service_name, _keyring_access_token_storage_key)
_keyring.delete_password(_keyring_service_name, _keyring_refresh_token_storage_key)
raise ValueError(f"Non-200 returned from refresh token endpoint {resp.status_code}")
self._initialize_credentials(resp)

@property
def credentials(self):
"""
:return flytekit.clis.auth.auth.Credentials:
"""
return self._credentials

@property
def expired(self):
"""
:return bool:
"""
return self._expired
return self._credentials_from_response(resp)
Loading