Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add http_proxy to client & Fix deviceflow #1611

Merged
merged 3 commits into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions flytekit/clients/auth/authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,17 @@ class Authenticator(object):
Base authenticator for all authentication flows
"""

def __init__(self, endpoint: str, header_key: str, credentials: Credentials = None):
def __init__(
self,
endpoint: str,
header_key: str,
credentials: Credentials = None,
http_proxy_url: typing.Optional[str] = None,
):
self._endpoint = endpoint
self._creds = credentials
self._header_key = header_key if header_key else "authorization"
self._http_proxy_url = http_proxy_url

def get_credentials(self) -> Credentials:
return self._creds
Expand Down Expand Up @@ -162,6 +169,7 @@ def __init__(
cfg_store: ClientConfigStore,
header_key: typing.Optional[str] = None,
scopes: typing.Optional[typing.List[str]] = None,
http_proxy_url: typing.Optional[str] = None,
):
if not client_id or not client_secret:
raise ValueError("Client ID and Client SECRET both are required.")
Expand All @@ -171,7 +179,7 @@ def __init__(
self._scopes = scopes or cfg.scopes
self._client_id = client_id
self._client_secret = client_secret
super().__init__(endpoint, cfg.header_key or header_key)
super().__init__(endpoint, cfg.header_key or header_key, http_proxy_url=http_proxy_url)

def refresh_credentials(self):
"""
Expand All @@ -187,7 +195,9 @@ def refresh_credentials(self):
# Note that unlike the Pkce flow, the client ID does not come from Admin.
logging.debug(f"Basic authorization flow with client id {self._client_id} scope {scopes}")
authorization_header = token_client.get_basic_authorization_header(self._client_id, self._client_secret)
token, expires_in = token_client.get_token(token_endpoint, scopes, authorization_header)
token, expires_in = token_client.get_token(
token_endpoint, scopes, authorization_header, http_proxy_url=self._http_proxy_url
)
logging.info("Retrieved new token, expires in {}".format(expires_in))
self._creds = Credentials(token)

Expand All @@ -207,6 +217,7 @@ def __init__(
cfg_store: ClientConfigStore,
header_key: typing.Optional[str] = None,
audience: typing.Optional[str] = None,
http_proxy_url: typing.Optional[str] = None,
):
self._audience = audience
cfg = cfg_store.get_client_config()
Expand All @@ -219,21 +230,27 @@ def __init__(
"Device Authentication is not available on the Flyte backend / authentication server"
)
super().__init__(
endpoint=endpoint, header_key=header_key or cfg.header_key, credentials=KeyringStore.retrieve(endpoint)
endpoint=endpoint,
header_key=header_key or cfg.header_key,
credentials=KeyringStore.retrieve(endpoint),
http_proxy_url=http_proxy_url,
)

def refresh_credentials(self):
resp = token_client.get_device_code(self._device_auth_endpoint, self._client_id, self._audience, self._scope)
resp = token_client.get_device_code(
self._device_auth_endpoint, self._client_id, self._audience, self._scope, self._http_proxy_url
)
print(
f"""
To Authenticate navigate in a browser to the following URL: {resp.verification_uri} and enter code: {resp.user_code}
OR copy paste the following URL: {resp.verification_uri_complete}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we check for this field's existence, and show it if it's there? Even if it's not part of the standard, it's nice to have.

Copy link
Collaborator Author

@ByronHsu ByronHsu May 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don't think it should be there because uri_complete is not in standard. resp.verification_uri + resp.user_code is enough for login.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah but okta supports it. and lots of people use okta. it's not in the standard but i feel it's okay to support a nice to have.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kumare3 wdyt ^^

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we just add back and merge?

"""
)
try:
# Currently the refresh token is not retreived. We may want to add support for refreshTokens so that
# access tokens can be refreshed for once authenticated machines
token, expires_in = token_client.poll_token_endpoint(resp, self._token_endpoint, client_id=self._client_id)
token, expires_in = token_client.poll_token_endpoint(
resp, self._token_endpoint, client_id=self._client_id, http_proxy_url=self._http_proxy_url
)
self._creds = Credentials(access_token=token, expires_in=expires_in, for_endpoint=self._endpoint)
KeyringStore.store(self._creds)
except Exception:
Expand Down
25 changes: 16 additions & 9 deletions flytekit/clients/auth/token_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import requests

from flytekit import logger
from flytekit.clients.auth.exceptions import AuthenticationError, AuthenticationPending

utf_8 = "utf-8"
Expand All @@ -31,15 +32,13 @@ class DeviceCodeResponse:
{'device_code': 'code',
'user_code': 'BNDJJFXL',
'verification_uri': 'url',
'verification_uri_complete': 'url',
wild-endeavor marked this conversation as resolved.
Show resolved Hide resolved
'expires_in': 600,
'interval': 5}
"""

device_code: str
user_code: str
verification_uri: str
verification_uri_complete: str
expires_in: int
interval: int

Expand All @@ -49,7 +48,6 @@ def from_json_response(cls, j: typing.Dict) -> "DeviceCodeResponse":
device_code=j["device_code"],
user_code=j["user_code"],
verification_uri=j["verification_uri"],
verification_uri_complete=j["verification_uri_complete"],
expires_in=j["expires_in"],
interval=j["interval"],
)
Expand Down Expand Up @@ -77,6 +75,7 @@ def get_token(
client_id: typing.Optional[str] = None,
device_code: typing.Optional[str] = None,
grant_type: GrantType = GrantType.CLIENT_CREDS,
http_proxy_url: typing.Optional[str] = None,
) -> typing.Tuple[str, int]:
"""
:rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration
Expand All @@ -99,7 +98,8 @@ def get_token(
if scopes is not None:
body["scope"] = ",".join(scopes)

response = requests.post(token_endpoint, data=body, headers=headers)
proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None
response = requests.post(token_endpoint, data=body, headers=headers, proxies=proxies)
if not response.ok:
j = response.json()
if "error" in j:
Expand All @@ -118,19 +118,24 @@ def get_device_code(
client_id: str,
audience: typing.Optional[str] = None,
scope: typing.Optional[typing.List[str]] = None,
http_proxy_url: typing.Optional[str] = None,
) -> DeviceCodeResponse:
"""
Retrieves the device Authentication code that can be done to authenticate the request using a browser on a
separate device
"""
payload = {"client_id": client_id, "scope": scope, "audience": audience}
resp = requests.post(device_auth_endpoint, payload)
_scope = " ".join(scope) if scope is not None else ""
payload = {"client_id": client_id, "scope": _scope, "audience": audience}
proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None
resp = requests.post(device_auth_endpoint, payload, proxies=proxies)
if not resp.ok:
raise AuthenticationError(f"Unable to retrieve Device Authentication Code for {payload}, Reason {resp.reason}")
return DeviceCodeResponse.from_json_response(resp.json())


def poll_token_endpoint(resp: DeviceCodeResponse, token_endpoint: str, client_id: str) -> typing.Tuple[str, int]:
def poll_token_endpoint(
resp: DeviceCodeResponse, token_endpoint: str, client_id: str, http_proxy_url: typing.Optional[str] = None
) -> typing.Tuple[str, int]:
tick = datetime.now()
interval = timedelta(seconds=resp.interval)
end_time = tick + timedelta(seconds=resp.expires_in)
Expand All @@ -141,13 +146,15 @@ def poll_token_endpoint(resp: DeviceCodeResponse, token_endpoint: str, client_id
grant_type=GrantType.DEVICE_CODE,
client_id=client_id,
device_code=resp.device_code,
http_proxy_url=http_proxy_url,
)
print("Authentication successful!")
return access_token, expires_in
except AuthenticationPending:
...
except Exception:
raise
except Exception as e:
logger.error("Authentication attempt failed: ", e)
raise e
print("Authentication Pending...")
time.sleep(interval.total_seconds())
tick = tick + interval
Expand Down
5 changes: 4 additions & 1 deletion flytekit/clients/auth_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth
client_secret=cfg.client_credentials_secret,
cfg_store=cfg_store,
scopes=cfg.scopes,
http_proxy_url=cfg.http_proxy_url,
)
elif cfg_auth == AuthType.EXTERNAL_PROCESS or cfg_auth == AuthType.EXTERNALCOMMAND:
client_cfg = None
Expand All @@ -82,7 +83,9 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth
header_key=client_cfg.header_key if client_cfg else None,
)
elif cfg_auth == AuthType.DEVICEFLOW:
return DeviceCodeAuthenticator(endpoint=cfg.endpoint, cfg_store=cfg_store, audience=cfg.audience)
return DeviceCodeAuthenticator(
endpoint=cfg.endpoint, cfg_store=cfg_store, audience=cfg.audience, http_proxy_url=cfg.http_proxy_url
)
else:
raise ValueError(
f"Invalid auth mode [{cfg_auth}] specified." f"Please update the creds config to use a valid value"
Expand Down
3 changes: 3 additions & 0 deletions flytekit/configuration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ class PlatformConfig(object):
:param scopes: List of scopes to request. This is only applicable to the client credentials flow
:param auth_mode: The OAuth mode to use. Defaults to pkce flow
:param ca_cert_file_path: [optional] str Root Cert to be loaded and used to verify admin
:param http_proxy_url: [optional] HTTP Proxy to be used for OAuth requests
"""

endpoint: str = "localhost:30080"
Expand All @@ -390,6 +391,7 @@ class PlatformConfig(object):
auth_mode: AuthType = AuthType.STANDARD
audience: typing.Optional[str] = None
rpc_retries: int = 3
http_proxy_url: typing.Optional[str] = None

@classmethod
def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None) -> PlatformConfig:
Expand Down Expand Up @@ -426,6 +428,7 @@ def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None
kwargs = set_if_exists(kwargs, "auth_mode", _internal.Credentials.AUTH_MODE.read(config_file))
kwargs = set_if_exists(kwargs, "endpoint", _internal.Platform.URL.read(config_file))
kwargs = set_if_exists(kwargs, "console_endpoint", _internal.Platform.CONSOLE_ENDPOINT.read(config_file))
kwargs = set_if_exists(kwargs, "http_proxy_url", _internal.Platform.HTTP_PROXY_URL.read(config_file))
return PlatformConfig(**kwargs)

@classmethod
Expand Down
1 change: 1 addition & 0 deletions flytekit/configuration/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class Platform(object):
CA_CERT_FILE_PATH = ConfigEntry(
LegacyConfigEntry(SECTION, "ca_cert_file_path"), YamlConfigEntry("admin.caCertFilePath")
)
HTTP_PROXY_URL = ConfigEntry(LegacyConfigEntry(SECTION, "http_proxy_url"), YamlConfigEntry("admin.httpProxyURL"))


class LocalSDK(object):
Expand Down
14 changes: 7 additions & 7 deletions tests/flytekit/unit/clients/auth/test_authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,11 @@ def test_command_authenticator(mock_subprocess: MagicMock):
@patch("flytekit.clients.auth.token_client.requests")
def test_client_creds_authenticator(mock_requests):
authn = ClientCredentialsAuthenticator(
ENDPOINT, client_id="client", client_secret="secret", cfg_store=static_cfg_store
ENDPOINT,
client_id="client",
client_secret="secret",
cfg_store=static_cfg_store,
http_proxy_url="https://my-proxy:31111",
)

response = MagicMock()
Expand Down Expand Up @@ -103,13 +107,9 @@ def test_device_flow_authenticator(poll_mock: MagicMock, device_mock: MagicMock,
device_authorization_endpoint="dev",
)
)
authn = DeviceCodeAuthenticator(
ENDPOINT,
cfg_store,
audience="x",
)
authn = DeviceCodeAuthenticator(ENDPOINT, cfg_store, audience="x", http_proxy_url="http://my-proxy:9000")

device_mock.return_value = DeviceCodeResponse("x", "y", "s", "m", 1000, 0)
device_mock.return_value = DeviceCodeResponse("x", "y", "s", 1000, 0)
poll_mock.return_value = ("access", 100)
authn.refresh_credentials()
assert authn._creds
Expand Down
21 changes: 9 additions & 12 deletions tests/flytekit/unit/clients/auth/test_token_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ def test_get_token(mock_requests):
response.status_code = 200
response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""")
mock_requests.post.return_value = response
access, expiration = get_token("https://corp.idp.net", client_id="abc123", scopes=["my_scope"])
access, expiration = get_token(
"https://corp.idp.net", client_id="abc123", scopes=["my_scope"], http_proxy_url="http://proxy:3000"
)
assert access == "abc"
assert expiration == 60

Expand All @@ -39,19 +41,18 @@ def test_get_device_code(mock_requests):
response.ok = False
mock_requests.post.return_value = response
with pytest.raises(AuthenticationError):
get_device_code("test.com", "test")
get_device_code("test.com", "test", http_proxy_url="http://proxy:3000")

response.ok = True
response.json.return_value = {
"device_code": "code",
"user_code": "BNDJJFXL",
"verification_uri": "url",
"verification_uri_complete": "url",
"expires_in": 600,
"interval": 5,
}
mock_requests.post.return_value = response
c = get_device_code("test.com", "test")
c = get_device_code("test.com", "test", http_proxy_url="http://proxy:3000")
assert c
assert c.device_code == "code"

Expand All @@ -63,19 +64,15 @@ def test_poll_token_endpoint(mock_requests):
response.json.return_value = {"error": error_auth_pending}
mock_requests.post.return_value = response

r = DeviceCodeResponse(
device_code="x", user_code="y", verification_uri="v", verification_uri_complete="v1", expires_in=1, interval=1
)
r = DeviceCodeResponse(device_code="x", user_code="y", verification_uri="v", expires_in=1, interval=1)
with pytest.raises(AuthenticationError):
poll_token_endpoint(r, "test.com", "test")
poll_token_endpoint(r, "test.com", "test", http_proxy_url="http://proxy:3000")

response = MagicMock()
response.ok = True
response.json.return_value = {"access_token": "abc", "expires_in": 60}
mock_requests.post.return_value = response
r = DeviceCodeResponse(
device_code="x", user_code="y", verification_uri="v", verification_uri_complete="v1", expires_in=1, interval=0
)
t, e = poll_token_endpoint(r, "test.com", "test")
r = DeviceCodeResponse(device_code="x", user_code="y", verification_uri="v", expires_in=1, interval=0)
t, e = poll_token_endpoint(r, "test.com", "test", http_proxy_url="http://proxy:3000")
assert t
assert e