diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index 7c4439d83d..806a0af3e3 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -246,7 +246,7 @@ def _refresh_credentials_basic(self): raise FlyteAuthenticationException("No client credentials secret provided in the config") cli_logger.debug(f"Basic authorization flow with client id {self._cfg.client_id} scope {scopes}") authorization_header = get_basic_authorization_header(self._cfg.client_id, client_secret) - token, expires_in = get_token(token_endpoint, authorization_header, scopes) + token, expires_in = get_token(token_endpoint, authorization_header, scopes, audience=self._cfg.audience) cli_logger.info("Retrieved new token, expires in {}".format(expires_in)) authorization_header_key = self.public_client_config.authorization_metadata_key or None self.set_access_token(token, authorization_header_key) @@ -865,7 +865,7 @@ def create_download_location( return self._dataproxy_stub.CreateDownloadLocation(create_download_location_request, metadata=self._metadata) -def get_token(token_endpoint, authorization_header, scope): +def get_token(token_endpoint, authorization_header, scope, audience=None): """ :param Text token_endpoint: :param Text authorization_header: This is the value for the "Authorization" key. (eg 'Bearer abc123') @@ -884,6 +884,8 @@ def get_token(token_endpoint, authorization_header, scope): } if scope is not None: body["scope"] = scope + if audience is not None: + body["audience"] = audience response = _requests.post(token_endpoint, data=body, headers=headers) if response.status_code != 200: cli_logger.error("Non-200 ({}) received from IDP: {}".format(response.status_code, response.text)) diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index cb228b788a..d3a1f63088 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -308,6 +308,7 @@ class PlatformConfig(object): allow the Flyte engine to read the password directly from the environment variable. Note that this is less secure! Please only use this if mounting the secret as a file is impossible. :param scopes: List of scopes to request. This is only applicable to the client credentials flow. + :param audience: The audience to request. This is only applicable to the client credentials flow. :param auth_mode: The OAuth mode to use. Defaults to pkce flow. """ @@ -319,6 +320,7 @@ class PlatformConfig(object): client_id: typing.Optional[str] = None client_credentials_secret: typing.Optional[str] = None scopes: List[str] = field(default_factory=list) + audience: typing.Optional[str] = None auth_mode: AuthType = AuthType.STANDARD @classmethod @@ -352,6 +354,7 @@ def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None client_credentials_secret, ) kwargs = set_if_exists(kwargs, "scopes", _internal.Credentials.SCOPES.read(config_file)) + kwargs = set_if_exists(kwargs, "audience", _internal.Credentials.AUDIENCE.read(config_file)) kwargs = set_if_exists(kwargs, "auth_mode", _internal.Credentials.AUTH_MODE.read(config_file)) kwargs = set_if_exists(kwargs, "endpoint", _internal.Platform.URL.read(config_file)) kwargs = set_if_exists(kwargs, "console_endpoint", _internal.Platform.CONSOLE_ENDPOINT.read(config_file)) diff --git a/flytekit/configuration/internal.py b/flytekit/configuration/internal.py index 5c29045db5..67b776342e 100644 --- a/flytekit/configuration/internal.py +++ b/flytekit/configuration/internal.py @@ -86,6 +86,8 @@ class Credentials(object): """ SCOPES = ConfigEntry(LegacyConfigEntry(SECTION, "scopes", list), YamlConfigEntry("admin.scopes", list)) + + AUDIENCE = ConfigEntry(LegacyConfigEntry(SECTION, "audience"), YamlConfigEntry("admin.audience")) AUTH_MODE = ConfigEntry(LegacyConfigEntry(SECTION, "auth_mode"), YamlConfigEntry("admin.authType")) """