From 0d2f7607cdc97d23cc4efc9c61f5893716804293 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com> Date: Wed, 14 Sep 2022 17:33:37 -0700 Subject: [PATCH] Strip newline from client secret (#1163) * Strip newline from client secret * Add logging and rework the secret file comparison to work on windows Signed-off-by: Eduardo Apolinario Signed-off-by: Eduardo Apolinario Co-authored-by: Eduardo Apolinario --- flytekit/configuration/__init__.py | 9 ++++++++- .../configs/creds_secret_location.yaml | 2 +- .../unit/configuration/test_internal.py | 17 +++++++++++++++-- 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index c0edf07ab2..047ce4b3d3 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -94,6 +94,7 @@ from flytekit.configuration import internal as _internal from flytekit.configuration.default_images import DefaultImages from flytekit.configuration.file import ConfigEntry, ConfigFile, get_config_file, read_file_if_exists, set_if_exists +from flytekit.loggers import logger PROJECT_PLACEHOLDER = "{{ registration.project }}" DOMAIN_PLACEHOLDER = "{{ registration.domain }}" @@ -336,10 +337,16 @@ def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None kwargs, "client_credentials_secret", _internal.Credentials.CLIENT_CREDENTIALS_SECRET.read(config_file) ) + client_credentials_secret = read_file_if_exists( + _internal.Credentials.CLIENT_CREDENTIALS_SECRET_LOCATION.read(config_file) + ) + if client_credentials_secret and client_credentials_secret.endswith("\n"): + logger.info("Newline stripped from client secret") + client_credentials_secret = client_credentials_secret.strip() kwargs = set_if_exists( kwargs, "client_credentials_secret", - read_file_if_exists(_internal.Credentials.CLIENT_CREDENTIALS_SECRET_LOCATION.read(config_file)), + client_credentials_secret, ) kwargs = set_if_exists(kwargs, "scopes", _internal.Credentials.SCOPES.read(config_file)) kwargs = set_if_exists(kwargs, "auth_mode", _internal.Credentials.AUTH_MODE.read(config_file)) diff --git a/tests/flytekit/unit/configuration/configs/creds_secret_location.yaml b/tests/flytekit/unit/configuration/configs/creds_secret_location.yaml index 9c1ad83a3e..7da41b7c38 100644 --- a/tests/flytekit/unit/configuration/configs/creds_secret_location.yaml +++ b/tests/flytekit/unit/configuration/configs/creds_secret_location.yaml @@ -1,7 +1,7 @@ admin: # For GRPC endpoints you might want to use dns:///flyte.myexample.com endpoint: dns:///flyte.mycorp.io - clientSecretLocation: ../tests/flytekit/unit/configuration/configs/fake_secret + clientSecretLocation: configs/fake_secret authType: Pkce insecure: true clientId: propeller diff --git a/tests/flytekit/unit/configuration/test_internal.py b/tests/flytekit/unit/configuration/test_internal.py index 6ba81f309c..7f6be53a55 100644 --- a/tests/flytekit/unit/configuration/test_internal.py +++ b/tests/flytekit/unit/configuration/test_internal.py @@ -2,7 +2,7 @@ import mock -from flytekit.configuration import get_config_file, read_file_if_exists +from flytekit.configuration import PlatformConfig, get_config_file, read_file_if_exists from flytekit.configuration.internal import AWS, Credentials, Images @@ -31,7 +31,20 @@ def test_client_secret_location(): os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/creds_secret_location.yaml") ) secret_location = Credentials.CLIENT_CREDENTIALS_SECRET_LOCATION.read(cfg) - assert secret_location == "../tests/flytekit/unit/configuration/configs/fake_secret" + assert secret_location == "configs/fake_secret" + + # Modify the path to the secret inline + cfg._yaml_config["admin"]["clientSecretLocation"] = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "configs/fake_secret" + ) + + # Assert secret contains a newline + with open(cfg._yaml_config["admin"]["clientSecretLocation"], "rb") as f: + assert f.read().decode().endswith("\n") is True + + # Assert that secret in platform config does not contain a newline + platform_cfg = PlatformConfig.auto(cfg) + assert platform_cfg.client_credentials_secret == "hello" def test_read_file_if_exists():