Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(auth): add audience for auth #1403

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions flytekit/clients/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def _refresh_credentials_basic(self):
raise FlyteAuthenticationException("No client credentials secret provided in the config")
cli_logger.debug(f"Basic authorization flow with client id {self._cfg.client_id} scope {scopes}")
authorization_header = get_basic_authorization_header(self._cfg.client_id, client_secret)
token, expires_in = get_token(token_endpoint, authorization_header, scopes)
token, expires_in = get_token(token_endpoint, authorization_header, scopes, audience=self._cfg.audience)
cli_logger.info("Retrieved new token, expires in {}".format(expires_in))
authorization_header_key = self.public_client_config.authorization_metadata_key or None
self.set_access_token(token, authorization_header_key)
Expand Down Expand Up @@ -865,7 +865,7 @@ def create_download_location(
return self._dataproxy_stub.CreateDownloadLocation(create_download_location_request, metadata=self._metadata)


def get_token(token_endpoint, authorization_header, scope):
def get_token(token_endpoint, authorization_header, scope, audience=None):
"""
:param Text token_endpoint:
:param Text authorization_header: This is the value for the "Authorization" key. (eg 'Bearer abc123')
Expand All @@ -884,6 +884,8 @@ def get_token(token_endpoint, authorization_header, scope):
}
if scope is not None:
body["scope"] = scope
if audience is not None:
body["audience"] = audience
response = _requests.post(token_endpoint, data=body, headers=headers)
if response.status_code != 200:
cli_logger.error("Non-200 ({}) received from IDP: {}".format(response.status_code, response.text))
Expand Down
3 changes: 3 additions & 0 deletions flytekit/configuration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ class PlatformConfig(object):
allow the Flyte engine to read the password directly from the environment variable. Note that this is
less secure! Please only use this if mounting the secret as a file is impossible.
:param scopes: List of scopes to request. This is only applicable to the client credentials flow.
:param audience: The audience to request. This is only applicable to the client credentials flow.
:param auth_mode: The OAuth mode to use. Defaults to pkce flow.
"""

Expand All @@ -319,6 +320,7 @@ class PlatformConfig(object):
client_id: typing.Optional[str] = None
client_credentials_secret: typing.Optional[str] = None
scopes: List[str] = field(default_factory=list)
audience: typing.Optional[str] = None
auth_mode: AuthType = AuthType.STANDARD

@classmethod
Expand Down Expand Up @@ -352,6 +354,7 @@ def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None
client_credentials_secret,
)
kwargs = set_if_exists(kwargs, "scopes", _internal.Credentials.SCOPES.read(config_file))
kwargs = set_if_exists(kwargs, "audience", _internal.Credentials.AUDIENCE.read(config_file))
kwargs = set_if_exists(kwargs, "auth_mode", _internal.Credentials.AUTH_MODE.read(config_file))
kwargs = set_if_exists(kwargs, "endpoint", _internal.Platform.URL.read(config_file))
kwargs = set_if_exists(kwargs, "console_endpoint", _internal.Platform.CONSOLE_ENDPOINT.read(config_file))
Expand Down
2 changes: 2 additions & 0 deletions flytekit/configuration/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ class Credentials(object):
"""

SCOPES = ConfigEntry(LegacyConfigEntry(SECTION, "scopes", list), YamlConfigEntry("admin.scopes", list))

AUDIENCE = ConfigEntry(LegacyConfigEntry(SECTION, "audience"), YamlConfigEntry("admin.audience"))

AUTH_MODE = ConfigEntry(LegacyConfigEntry(SECTION, "auth_mode"), YamlConfigEntry("admin.authType"))
"""
Expand Down