Skip to content

Commit

Permalink
Add http_proxy to client & Fix deviceflow (#1611)
Browse files Browse the repository at this point in the history
* Add http_proxy to client & Fix deviceflow

RB=3890720

Signed-off-by: byhsu <[email protected]>

* nit

Signed-off-by: byhsu <[email protected]>

* lint!

Signed-off-by: byhsu <[email protected]>

---------

Signed-off-by: byhsu <[email protected]>
Co-authored-by: byhsu <[email protected]>
  • Loading branch information
ByronHsu and ByronHsu authored May 16, 2023
1 parent a10a3bc commit c65397d
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 36 deletions.
31 changes: 24 additions & 7 deletions flytekit/clients/auth/authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,17 @@ class Authenticator(object):
Base authenticator for all authentication flows
"""

def __init__(self, endpoint: str, header_key: str, credentials: Credentials = None):
def __init__(
self,
endpoint: str,
header_key: str,
credentials: Credentials = None,
http_proxy_url: typing.Optional[str] = None,
):
self._endpoint = endpoint
self._creds = credentials
self._header_key = header_key if header_key else "authorization"
self._http_proxy_url = http_proxy_url

def get_credentials(self) -> Credentials:
return self._creds
Expand Down Expand Up @@ -162,6 +169,7 @@ def __init__(
cfg_store: ClientConfigStore,
header_key: typing.Optional[str] = None,
scopes: typing.Optional[typing.List[str]] = None,
http_proxy_url: typing.Optional[str] = None,
):
if not client_id or not client_secret:
raise ValueError("Client ID and Client SECRET both are required.")
Expand All @@ -171,7 +179,7 @@ def __init__(
self._scopes = scopes or cfg.scopes
self._client_id = client_id
self._client_secret = client_secret
super().__init__(endpoint, cfg.header_key or header_key)
super().__init__(endpoint, cfg.header_key or header_key, http_proxy_url=http_proxy_url)

def refresh_credentials(self):
"""
Expand All @@ -187,7 +195,9 @@ def refresh_credentials(self):
# Note that unlike the Pkce flow, the client ID does not come from Admin.
logging.debug(f"Basic authorization flow with client id {self._client_id} scope {scopes}")
authorization_header = token_client.get_basic_authorization_header(self._client_id, self._client_secret)
token, expires_in = token_client.get_token(token_endpoint, scopes, authorization_header)
token, expires_in = token_client.get_token(
token_endpoint, scopes, authorization_header, http_proxy_url=self._http_proxy_url
)
logging.info("Retrieved new token, expires in {}".format(expires_in))
self._creds = Credentials(token)

Expand All @@ -207,6 +217,7 @@ def __init__(
cfg_store: ClientConfigStore,
header_key: typing.Optional[str] = None,
audience: typing.Optional[str] = None,
http_proxy_url: typing.Optional[str] = None,
):
self._audience = audience
cfg = cfg_store.get_client_config()
Expand All @@ -219,21 +230,27 @@ def __init__(
"Device Authentication is not available on the Flyte backend / authentication server"
)
super().__init__(
endpoint=endpoint, header_key=header_key or cfg.header_key, credentials=KeyringStore.retrieve(endpoint)
endpoint=endpoint,
header_key=header_key or cfg.header_key,
credentials=KeyringStore.retrieve(endpoint),
http_proxy_url=http_proxy_url,
)

def refresh_credentials(self):
resp = token_client.get_device_code(self._device_auth_endpoint, self._client_id, self._audience, self._scope)
resp = token_client.get_device_code(
self._device_auth_endpoint, self._client_id, self._audience, self._scope, self._http_proxy_url
)
print(
f"""
To Authenticate navigate in a browser to the following URL: {resp.verification_uri} and enter code: {resp.user_code}
OR copy paste the following URL: {resp.verification_uri_complete}
"""
)
try:
# Currently the refresh token is not retreived. We may want to add support for refreshTokens so that
# access tokens can be refreshed for once authenticated machines
token, expires_in = token_client.poll_token_endpoint(resp, self._token_endpoint, client_id=self._client_id)
token, expires_in = token_client.poll_token_endpoint(
resp, self._token_endpoint, client_id=self._client_id, http_proxy_url=self._http_proxy_url
)
self._creds = Credentials(access_token=token, expires_in=expires_in, for_endpoint=self._endpoint)
KeyringStore.store(self._creds)
except Exception:
Expand Down
25 changes: 16 additions & 9 deletions flytekit/clients/auth/token_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import requests

from flytekit import logger
from flytekit.clients.auth.exceptions import AuthenticationError, AuthenticationPending

utf_8 = "utf-8"
Expand All @@ -31,15 +32,13 @@ class DeviceCodeResponse:
{'device_code': 'code',
'user_code': 'BNDJJFXL',
'verification_uri': 'url',
'verification_uri_complete': 'url',
'expires_in': 600,
'interval': 5}
"""

device_code: str
user_code: str
verification_uri: str
verification_uri_complete: str
expires_in: int
interval: int

Expand All @@ -49,7 +48,6 @@ def from_json_response(cls, j: typing.Dict) -> "DeviceCodeResponse":
device_code=j["device_code"],
user_code=j["user_code"],
verification_uri=j["verification_uri"],
verification_uri_complete=j["verification_uri_complete"],
expires_in=j["expires_in"],
interval=j["interval"],
)
Expand Down Expand Up @@ -77,6 +75,7 @@ def get_token(
client_id: typing.Optional[str] = None,
device_code: typing.Optional[str] = None,
grant_type: GrantType = GrantType.CLIENT_CREDS,
http_proxy_url: typing.Optional[str] = None,
) -> typing.Tuple[str, int]:
"""
:rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration
Expand All @@ -99,7 +98,8 @@ def get_token(
if scopes is not None:
body["scope"] = ",".join(scopes)

response = requests.post(token_endpoint, data=body, headers=headers)
proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None
response = requests.post(token_endpoint, data=body, headers=headers, proxies=proxies)
if not response.ok:
j = response.json()
if "error" in j:
Expand All @@ -118,19 +118,24 @@ def get_device_code(
client_id: str,
audience: typing.Optional[str] = None,
scope: typing.Optional[typing.List[str]] = None,
http_proxy_url: typing.Optional[str] = None,
) -> DeviceCodeResponse:
"""
Retrieves the device Authentication code that can be done to authenticate the request using a browser on a
separate device
"""
payload = {"client_id": client_id, "scope": scope, "audience": audience}
resp = requests.post(device_auth_endpoint, payload)
_scope = " ".join(scope) if scope is not None else ""
payload = {"client_id": client_id, "scope": _scope, "audience": audience}
proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None
resp = requests.post(device_auth_endpoint, payload, proxies=proxies)
if not resp.ok:
raise AuthenticationError(f"Unable to retrieve Device Authentication Code for {payload}, Reason {resp.reason}")
return DeviceCodeResponse.from_json_response(resp.json())


def poll_token_endpoint(resp: DeviceCodeResponse, token_endpoint: str, client_id: str) -> typing.Tuple[str, int]:
def poll_token_endpoint(
resp: DeviceCodeResponse, token_endpoint: str, client_id: str, http_proxy_url: typing.Optional[str] = None
) -> typing.Tuple[str, int]:
tick = datetime.now()
interval = timedelta(seconds=resp.interval)
end_time = tick + timedelta(seconds=resp.expires_in)
Expand All @@ -141,13 +146,15 @@ def poll_token_endpoint(resp: DeviceCodeResponse, token_endpoint: str, client_id
grant_type=GrantType.DEVICE_CODE,
client_id=client_id,
device_code=resp.device_code,
http_proxy_url=http_proxy_url,
)
print("Authentication successful!")
return access_token, expires_in
except AuthenticationPending:
...
except Exception:
raise
except Exception as e:
logger.error("Authentication attempt failed: ", e)
raise e
print("Authentication Pending...")
time.sleep(interval.total_seconds())
tick = tick + interval
Expand Down
5 changes: 4 additions & 1 deletion flytekit/clients/auth_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth
client_secret=cfg.client_credentials_secret,
cfg_store=cfg_store,
scopes=cfg.scopes,
http_proxy_url=cfg.http_proxy_url,
)
elif cfg_auth == AuthType.EXTERNAL_PROCESS or cfg_auth == AuthType.EXTERNALCOMMAND:
client_cfg = None
Expand All @@ -82,7 +83,9 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth
header_key=client_cfg.header_key if client_cfg else None,
)
elif cfg_auth == AuthType.DEVICEFLOW:
return DeviceCodeAuthenticator(endpoint=cfg.endpoint, cfg_store=cfg_store, audience=cfg.audience)
return DeviceCodeAuthenticator(
endpoint=cfg.endpoint, cfg_store=cfg_store, audience=cfg.audience, http_proxy_url=cfg.http_proxy_url
)
else:
raise ValueError(
f"Invalid auth mode [{cfg_auth}] specified." f"Please update the creds config to use a valid value"
Expand Down
3 changes: 3 additions & 0 deletions flytekit/configuration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ class PlatformConfig(object):
:param scopes: List of scopes to request. This is only applicable to the client credentials flow
:param auth_mode: The OAuth mode to use. Defaults to pkce flow
:param ca_cert_file_path: [optional] str Root Cert to be loaded and used to verify admin
:param http_proxy_url: [optional] HTTP Proxy to be used for OAuth requests
"""

endpoint: str = "localhost:30080"
Expand All @@ -390,6 +391,7 @@ class PlatformConfig(object):
auth_mode: AuthType = AuthType.STANDARD
audience: typing.Optional[str] = None
rpc_retries: int = 3
http_proxy_url: typing.Optional[str] = None

@classmethod
def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None) -> PlatformConfig:
Expand Down Expand Up @@ -426,6 +428,7 @@ def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None
kwargs = set_if_exists(kwargs, "auth_mode", _internal.Credentials.AUTH_MODE.read(config_file))
kwargs = set_if_exists(kwargs, "endpoint", _internal.Platform.URL.read(config_file))
kwargs = set_if_exists(kwargs, "console_endpoint", _internal.Platform.CONSOLE_ENDPOINT.read(config_file))
kwargs = set_if_exists(kwargs, "http_proxy_url", _internal.Platform.HTTP_PROXY_URL.read(config_file))
return PlatformConfig(**kwargs)

@classmethod
Expand Down
1 change: 1 addition & 0 deletions flytekit/configuration/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class Platform(object):
CA_CERT_FILE_PATH = ConfigEntry(
LegacyConfigEntry(SECTION, "ca_cert_file_path"), YamlConfigEntry("admin.caCertFilePath")
)
HTTP_PROXY_URL = ConfigEntry(LegacyConfigEntry(SECTION, "http_proxy_url"), YamlConfigEntry("admin.httpProxyURL"))


class LocalSDK(object):
Expand Down
14 changes: 7 additions & 7 deletions tests/flytekit/unit/clients/auth/test_authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,11 @@ def test_command_authenticator(mock_subprocess: MagicMock):
@patch("flytekit.clients.auth.token_client.requests")
def test_client_creds_authenticator(mock_requests):
authn = ClientCredentialsAuthenticator(
ENDPOINT, client_id="client", client_secret="secret", cfg_store=static_cfg_store
ENDPOINT,
client_id="client",
client_secret="secret",
cfg_store=static_cfg_store,
http_proxy_url="https://my-proxy:31111",
)

response = MagicMock()
Expand Down Expand Up @@ -103,13 +107,9 @@ def test_device_flow_authenticator(poll_mock: MagicMock, device_mock: MagicMock,
device_authorization_endpoint="dev",
)
)
authn = DeviceCodeAuthenticator(
ENDPOINT,
cfg_store,
audience="x",
)
authn = DeviceCodeAuthenticator(ENDPOINT, cfg_store, audience="x", http_proxy_url="http://my-proxy:9000")

device_mock.return_value = DeviceCodeResponse("x", "y", "s", "m", 1000, 0)
device_mock.return_value = DeviceCodeResponse("x", "y", "s", 1000, 0)
poll_mock.return_value = ("access", 100)
authn.refresh_credentials()
assert authn._creds
Expand Down
21 changes: 9 additions & 12 deletions tests/flytekit/unit/clients/auth/test_token_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ def test_get_token(mock_requests):
response.status_code = 200
response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""")
mock_requests.post.return_value = response
access, expiration = get_token("https://corp.idp.net", client_id="abc123", scopes=["my_scope"])
access, expiration = get_token(
"https://corp.idp.net", client_id="abc123", scopes=["my_scope"], http_proxy_url="http://proxy:3000"
)
assert access == "abc"
assert expiration == 60

Expand All @@ -39,19 +41,18 @@ def test_get_device_code(mock_requests):
response.ok = False
mock_requests.post.return_value = response
with pytest.raises(AuthenticationError):
get_device_code("test.com", "test")
get_device_code("test.com", "test", http_proxy_url="http://proxy:3000")

response.ok = True
response.json.return_value = {
"device_code": "code",
"user_code": "BNDJJFXL",
"verification_uri": "url",
"verification_uri_complete": "url",
"expires_in": 600,
"interval": 5,
}
mock_requests.post.return_value = response
c = get_device_code("test.com", "test")
c = get_device_code("test.com", "test", http_proxy_url="http://proxy:3000")
assert c
assert c.device_code == "code"

Expand All @@ -63,19 +64,15 @@ def test_poll_token_endpoint(mock_requests):
response.json.return_value = {"error": error_auth_pending}
mock_requests.post.return_value = response

r = DeviceCodeResponse(
device_code="x", user_code="y", verification_uri="v", verification_uri_complete="v1", expires_in=1, interval=1
)
r = DeviceCodeResponse(device_code="x", user_code="y", verification_uri="v", expires_in=1, interval=1)
with pytest.raises(AuthenticationError):
poll_token_endpoint(r, "test.com", "test")
poll_token_endpoint(r, "test.com", "test", http_proxy_url="http://proxy:3000")

response = MagicMock()
response.ok = True
response.json.return_value = {"access_token": "abc", "expires_in": 60}
mock_requests.post.return_value = response
r = DeviceCodeResponse(
device_code="x", user_code="y", verification_uri="v", verification_uri_complete="v1", expires_in=1, interval=0
)
t, e = poll_token_endpoint(r, "test.com", "test")
r = DeviceCodeResponse(device_code="x", user_code="y", verification_uri="v", expires_in=1, interval=0)
t, e = poll_token_endpoint(r, "test.com", "test", http_proxy_url="http://proxy:3000")
assert t
assert e

0 comments on commit c65397d

Please sign in to comment.