Skip to content

Commit

Permalink
Fix race conditions in the Authentication client (#2635)
Browse files Browse the repository at this point in the history
* Fix race conditions in the Authentication cliente

Signed-off-by: Robert Deaton <[email protected]>

* Update flytekit/clients/auth/auth_client.py

Co-authored-by: Thomas J. Fan <[email protected]>

---------

Signed-off-by: Robert Deaton <[email protected]>
Co-authored-by: Thomas J. Fan <[email protected]>
  • Loading branch information
rdeaton-freenome and thomasjpfan authored Aug 15, 2024
1 parent abb5219 commit 6ababc9
Showing 1 changed file with 30 additions and 12 deletions.
42 changes: 30 additions & 12 deletions flytekit/clients/auth/auth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import logging
import os
import re
import threading
import time
import typing
import urllib.parse as _urlparse
import webbrowser
Expand Down Expand Up @@ -236,6 +238,9 @@ def __init__(
self._verify = verify
self._headers = {"content-type": "application/x-www-form-urlencoded"}
self._session = session or requests.Session()
self._lock = threading.Lock()
self._cached_credentials = None
self._cached_credentials_ts = None

self._request_auth_code_params = {
"client_id": client_id, # This must match the Client ID of the OAuth application.
Expand Down Expand Up @@ -339,25 +344,38 @@ def _request_access_token(self, auth_code) -> Credentials:

def get_creds_from_remote(self) -> Credentials:
"""
This is the entrypoint method. It will kickoff the full authentication flow and trigger a web-browser to
retrieve credentials
This is the entrypoint method. It will kickoff the full authentication
flow and trigger a web-browser to retrieve credentials. Because this
needs to open a port on localhost and may be called from a
multithreaded context (e.g. pyflyte register), this call may block
multiple threads and return a cached result for up to 60 seconds.
"""
# In the absence of globally-set token values, initiate the token request flow
q = Queue()
with self._lock:
# Clear cache if it's been more than 60 seconds since the last check
cache_ttl_s = 60
if self._cached_credentials_ts is not None and self._cached_credentials_ts + cache_ttl_s < time.monotonic():
self._cached_credentials = None

# First prepare the callback server in the background
server = self._create_callback_server()
if self._cached_credentials is not None:
return self._cached_credentials
q = Queue()

self._request_authorization_code()
# First prepare the callback server in the background
server = self._create_callback_server()

server.handle_request(q)
server.server_close()
self._request_authorization_code()

# Send the call to request the authorization code in the background
server.handle_request(q)
server.server_close()

# Request the access token once the auth code has been received.
auth_code = q.get()
return self._request_access_token(auth_code)
# Send the call to request the authorization code in the background

# Request the access token once the auth code has been received.
auth_code = q.get()
self._cached_credentials = self._request_access_token(auth_code)
self._cached_credentials_ts = time.monotonic()
return self._cached_credentials

def refresh_access_token(self, credentials: Credentials) -> Credentials:
if credentials.refresh_token is None:
Expand Down

0 comments on commit 6ababc9

Please sign in to comment.