From c65397dc8108ae287d9e0a77d1a04d3d485bf3f4 Mon Sep 17 00:00:00 2001 From: ByronHsu Date: Mon, 15 May 2023 22:06:40 -0700 Subject: [PATCH] Add http_proxy to client & Fix deviceflow (#1611) * Add http_proxy to client & Fix deviceflow RB=3890720 Signed-off-by: byhsu * nit Signed-off-by: byhsu * lint! Signed-off-by: byhsu --------- Signed-off-by: byhsu Co-authored-by: byhsu --- flytekit/clients/auth/authenticator.py | 31 ++++++++++++++----- flytekit/clients/auth/token_client.py | 25 +++++++++------ flytekit/clients/auth_helper.py | 5 ++- flytekit/configuration/__init__.py | 3 ++ flytekit/configuration/internal.py | 1 + .../unit/clients/auth/test_authenticator.py | 14 ++++----- .../unit/clients/auth/test_token_client.py | 21 ++++++------- 7 files changed, 64 insertions(+), 36 deletions(-) diff --git a/flytekit/clients/auth/authenticator.py b/flytekit/clients/auth/authenticator.py index 9582c901d8..bde25602e0 100644 --- a/flytekit/clients/auth/authenticator.py +++ b/flytekit/clients/auth/authenticator.py @@ -48,10 +48,17 @@ class Authenticator(object): Base authenticator for all authentication flows """ - def __init__(self, endpoint: str, header_key: str, credentials: Credentials = None): + def __init__( + self, + endpoint: str, + header_key: str, + credentials: Credentials = None, + http_proxy_url: typing.Optional[str] = None, + ): self._endpoint = endpoint self._creds = credentials self._header_key = header_key if header_key else "authorization" + self._http_proxy_url = http_proxy_url def get_credentials(self) -> Credentials: return self._creds @@ -162,6 +169,7 @@ def __init__( cfg_store: ClientConfigStore, header_key: typing.Optional[str] = None, scopes: typing.Optional[typing.List[str]] = None, + http_proxy_url: typing.Optional[str] = None, ): if not client_id or not client_secret: raise ValueError("Client ID and Client SECRET both are required.") @@ -171,7 +179,7 @@ def __init__( self._scopes = scopes or cfg.scopes self._client_id = client_id self._client_secret = client_secret - super().__init__(endpoint, cfg.header_key or header_key) + super().__init__(endpoint, cfg.header_key or header_key, http_proxy_url=http_proxy_url) def refresh_credentials(self): """ @@ -187,7 +195,9 @@ def refresh_credentials(self): # Note that unlike the Pkce flow, the client ID does not come from Admin. logging.debug(f"Basic authorization flow with client id {self._client_id} scope {scopes}") authorization_header = token_client.get_basic_authorization_header(self._client_id, self._client_secret) - token, expires_in = token_client.get_token(token_endpoint, scopes, authorization_header) + token, expires_in = token_client.get_token( + token_endpoint, scopes, authorization_header, http_proxy_url=self._http_proxy_url + ) logging.info("Retrieved new token, expires in {}".format(expires_in)) self._creds = Credentials(token) @@ -207,6 +217,7 @@ def __init__( cfg_store: ClientConfigStore, header_key: typing.Optional[str] = None, audience: typing.Optional[str] = None, + http_proxy_url: typing.Optional[str] = None, ): self._audience = audience cfg = cfg_store.get_client_config() @@ -219,21 +230,27 @@ def __init__( "Device Authentication is not available on the Flyte backend / authentication server" ) super().__init__( - endpoint=endpoint, header_key=header_key or cfg.header_key, credentials=KeyringStore.retrieve(endpoint) + endpoint=endpoint, + header_key=header_key or cfg.header_key, + credentials=KeyringStore.retrieve(endpoint), + http_proxy_url=http_proxy_url, ) def refresh_credentials(self): - resp = token_client.get_device_code(self._device_auth_endpoint, self._client_id, self._audience, self._scope) + resp = token_client.get_device_code( + self._device_auth_endpoint, self._client_id, self._audience, self._scope, self._http_proxy_url + ) print( f""" To Authenticate navigate in a browser to the following URL: {resp.verification_uri} and enter code: {resp.user_code} -OR copy paste the following URL: {resp.verification_uri_complete} """ ) try: # Currently the refresh token is not retreived. We may want to add support for refreshTokens so that # access tokens can be refreshed for once authenticated machines - token, expires_in = token_client.poll_token_endpoint(resp, self._token_endpoint, client_id=self._client_id) + token, expires_in = token_client.poll_token_endpoint( + resp, self._token_endpoint, client_id=self._client_id, http_proxy_url=self._http_proxy_url + ) self._creds = Credentials(access_token=token, expires_in=expires_in, for_endpoint=self._endpoint) KeyringStore.store(self._creds) except Exception: diff --git a/flytekit/clients/auth/token_client.py b/flytekit/clients/auth/token_client.py index 7cbb42a13e..728a56b5f7 100644 --- a/flytekit/clients/auth/token_client.py +++ b/flytekit/clients/auth/token_client.py @@ -9,6 +9,7 @@ import requests +from flytekit import logger from flytekit.clients.auth.exceptions import AuthenticationError, AuthenticationPending utf_8 = "utf-8" @@ -31,7 +32,6 @@ class DeviceCodeResponse: {'device_code': 'code', 'user_code': 'BNDJJFXL', 'verification_uri': 'url', - 'verification_uri_complete': 'url', 'expires_in': 600, 'interval': 5} """ @@ -39,7 +39,6 @@ class DeviceCodeResponse: device_code: str user_code: str verification_uri: str - verification_uri_complete: str expires_in: int interval: int @@ -49,7 +48,6 @@ def from_json_response(cls, j: typing.Dict) -> "DeviceCodeResponse": device_code=j["device_code"], user_code=j["user_code"], verification_uri=j["verification_uri"], - verification_uri_complete=j["verification_uri_complete"], expires_in=j["expires_in"], interval=j["interval"], ) @@ -77,6 +75,7 @@ def get_token( client_id: typing.Optional[str] = None, device_code: typing.Optional[str] = None, grant_type: GrantType = GrantType.CLIENT_CREDS, + http_proxy_url: typing.Optional[str] = None, ) -> typing.Tuple[str, int]: """ :rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration @@ -99,7 +98,8 @@ def get_token( if scopes is not None: body["scope"] = ",".join(scopes) - response = requests.post(token_endpoint, data=body, headers=headers) + proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None + response = requests.post(token_endpoint, data=body, headers=headers, proxies=proxies) if not response.ok: j = response.json() if "error" in j: @@ -118,19 +118,24 @@ def get_device_code( client_id: str, audience: typing.Optional[str] = None, scope: typing.Optional[typing.List[str]] = None, + http_proxy_url: typing.Optional[str] = None, ) -> DeviceCodeResponse: """ Retrieves the device Authentication code that can be done to authenticate the request using a browser on a separate device """ - payload = {"client_id": client_id, "scope": scope, "audience": audience} - resp = requests.post(device_auth_endpoint, payload) + _scope = " ".join(scope) if scope is not None else "" + payload = {"client_id": client_id, "scope": _scope, "audience": audience} + proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None + resp = requests.post(device_auth_endpoint, payload, proxies=proxies) if not resp.ok: raise AuthenticationError(f"Unable to retrieve Device Authentication Code for {payload}, Reason {resp.reason}") return DeviceCodeResponse.from_json_response(resp.json()) -def poll_token_endpoint(resp: DeviceCodeResponse, token_endpoint: str, client_id: str) -> typing.Tuple[str, int]: +def poll_token_endpoint( + resp: DeviceCodeResponse, token_endpoint: str, client_id: str, http_proxy_url: typing.Optional[str] = None +) -> typing.Tuple[str, int]: tick = datetime.now() interval = timedelta(seconds=resp.interval) end_time = tick + timedelta(seconds=resp.expires_in) @@ -141,13 +146,15 @@ def poll_token_endpoint(resp: DeviceCodeResponse, token_endpoint: str, client_id grant_type=GrantType.DEVICE_CODE, client_id=client_id, device_code=resp.device_code, + http_proxy_url=http_proxy_url, ) print("Authentication successful!") return access_token, expires_in except AuthenticationPending: ... - except Exception: - raise + except Exception as e: + logger.error("Authentication attempt failed: ", e) + raise e print("Authentication Pending...") time.sleep(interval.total_seconds()) tick = tick + interval diff --git a/flytekit/clients/auth_helper.py b/flytekit/clients/auth_helper.py index 93bd883324..a068e487cf 100644 --- a/flytekit/clients/auth_helper.py +++ b/flytekit/clients/auth_helper.py @@ -72,6 +72,7 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth client_secret=cfg.client_credentials_secret, cfg_store=cfg_store, scopes=cfg.scopes, + http_proxy_url=cfg.http_proxy_url, ) elif cfg_auth == AuthType.EXTERNAL_PROCESS or cfg_auth == AuthType.EXTERNALCOMMAND: client_cfg = None @@ -82,7 +83,9 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth header_key=client_cfg.header_key if client_cfg else None, ) elif cfg_auth == AuthType.DEVICEFLOW: - return DeviceCodeAuthenticator(endpoint=cfg.endpoint, cfg_store=cfg_store, audience=cfg.audience) + return DeviceCodeAuthenticator( + endpoint=cfg.endpoint, cfg_store=cfg_store, audience=cfg.audience, http_proxy_url=cfg.http_proxy_url + ) else: raise ValueError( f"Invalid auth mode [{cfg_auth}] specified." f"Please update the creds config to use a valid value" diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index 4785bd228c..a7e2c69ebd 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -376,6 +376,7 @@ class PlatformConfig(object): :param scopes: List of scopes to request. This is only applicable to the client credentials flow :param auth_mode: The OAuth mode to use. Defaults to pkce flow :param ca_cert_file_path: [optional] str Root Cert to be loaded and used to verify admin + :param http_proxy_url: [optional] HTTP Proxy to be used for OAuth requests """ endpoint: str = "localhost:30080" @@ -390,6 +391,7 @@ class PlatformConfig(object): auth_mode: AuthType = AuthType.STANDARD audience: typing.Optional[str] = None rpc_retries: int = 3 + http_proxy_url: typing.Optional[str] = None @classmethod def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None) -> PlatformConfig: @@ -426,6 +428,7 @@ def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None kwargs = set_if_exists(kwargs, "auth_mode", _internal.Credentials.AUTH_MODE.read(config_file)) kwargs = set_if_exists(kwargs, "endpoint", _internal.Platform.URL.read(config_file)) kwargs = set_if_exists(kwargs, "console_endpoint", _internal.Platform.CONSOLE_ENDPOINT.read(config_file)) + kwargs = set_if_exists(kwargs, "http_proxy_url", _internal.Platform.HTTP_PROXY_URL.read(config_file)) return PlatformConfig(**kwargs) @classmethod diff --git a/flytekit/configuration/internal.py b/flytekit/configuration/internal.py index 5c3729e63b..4f993b4e11 100644 --- a/flytekit/configuration/internal.py +++ b/flytekit/configuration/internal.py @@ -111,6 +111,7 @@ class Platform(object): CA_CERT_FILE_PATH = ConfigEntry( LegacyConfigEntry(SECTION, "ca_cert_file_path"), YamlConfigEntry("admin.caCertFilePath") ) + HTTP_PROXY_URL = ConfigEntry(LegacyConfigEntry(SECTION, "http_proxy_url"), YamlConfigEntry("admin.httpProxyURL")) class LocalSDK(object): diff --git a/tests/flytekit/unit/clients/auth/test_authenticator.py b/tests/flytekit/unit/clients/auth/test_authenticator.py index fdbddb2ebe..495f4648ac 100644 --- a/tests/flytekit/unit/clients/auth/test_authenticator.py +++ b/tests/flytekit/unit/clients/auth/test_authenticator.py @@ -70,7 +70,11 @@ def test_command_authenticator(mock_subprocess: MagicMock): @patch("flytekit.clients.auth.token_client.requests") def test_client_creds_authenticator(mock_requests): authn = ClientCredentialsAuthenticator( - ENDPOINT, client_id="client", client_secret="secret", cfg_store=static_cfg_store + ENDPOINT, + client_id="client", + client_secret="secret", + cfg_store=static_cfg_store, + http_proxy_url="https://my-proxy:31111", ) response = MagicMock() @@ -103,13 +107,9 @@ def test_device_flow_authenticator(poll_mock: MagicMock, device_mock: MagicMock, device_authorization_endpoint="dev", ) ) - authn = DeviceCodeAuthenticator( - ENDPOINT, - cfg_store, - audience="x", - ) + authn = DeviceCodeAuthenticator(ENDPOINT, cfg_store, audience="x", http_proxy_url="http://my-proxy:9000") - device_mock.return_value = DeviceCodeResponse("x", "y", "s", "m", 1000, 0) + device_mock.return_value = DeviceCodeResponse("x", "y", "s", 1000, 0) poll_mock.return_value = ("access", 100) authn.refresh_credentials() assert authn._creds diff --git a/tests/flytekit/unit/clients/auth/test_token_client.py b/tests/flytekit/unit/clients/auth/test_token_client.py index c22284cd38..f0c10b16d1 100644 --- a/tests/flytekit/unit/clients/auth/test_token_client.py +++ b/tests/flytekit/unit/clients/auth/test_token_client.py @@ -28,7 +28,9 @@ def test_get_token(mock_requests): response.status_code = 200 response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") mock_requests.post.return_value = response - access, expiration = get_token("https://corp.idp.net", client_id="abc123", scopes=["my_scope"]) + access, expiration = get_token( + "https://corp.idp.net", client_id="abc123", scopes=["my_scope"], http_proxy_url="http://proxy:3000" + ) assert access == "abc" assert expiration == 60 @@ -39,19 +41,18 @@ def test_get_device_code(mock_requests): response.ok = False mock_requests.post.return_value = response with pytest.raises(AuthenticationError): - get_device_code("test.com", "test") + get_device_code("test.com", "test", http_proxy_url="http://proxy:3000") response.ok = True response.json.return_value = { "device_code": "code", "user_code": "BNDJJFXL", "verification_uri": "url", - "verification_uri_complete": "url", "expires_in": 600, "interval": 5, } mock_requests.post.return_value = response - c = get_device_code("test.com", "test") + c = get_device_code("test.com", "test", http_proxy_url="http://proxy:3000") assert c assert c.device_code == "code" @@ -63,19 +64,15 @@ def test_poll_token_endpoint(mock_requests): response.json.return_value = {"error": error_auth_pending} mock_requests.post.return_value = response - r = DeviceCodeResponse( - device_code="x", user_code="y", verification_uri="v", verification_uri_complete="v1", expires_in=1, interval=1 - ) + r = DeviceCodeResponse(device_code="x", user_code="y", verification_uri="v", expires_in=1, interval=1) with pytest.raises(AuthenticationError): - poll_token_endpoint(r, "test.com", "test") + poll_token_endpoint(r, "test.com", "test", http_proxy_url="http://proxy:3000") response = MagicMock() response.ok = True response.json.return_value = {"access_token": "abc", "expires_in": 60} mock_requests.post.return_value = response - r = DeviceCodeResponse( - device_code="x", user_code="y", verification_uri="v", verification_uri_complete="v1", expires_in=1, interval=0 - ) - t, e = poll_token_endpoint(r, "test.com", "test") + r = DeviceCodeResponse(device_code="x", user_code="y", verification_uri="v", expires_in=1, interval=0) + t, e = poll_token_endpoint(r, "test.com", "test", http_proxy_url="http://proxy:3000") assert t assert e