diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index 972b8655d5..9bdf85a178 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -273,7 +273,10 @@ def _refresh_credentials_from_command(self): except subprocess.CalledProcessError as e: cli_logger.error("Failed to generate token from command {}".format(command)) raise _user_exceptions.FlyteAuthenticationException("Problems refreshing token with command: " + str(e)) - self.set_access_token(output.stdout.strip()) + authorization_header_key = self.public_client_config.authorization_metadata_key or None + if not authorization_header_key: + self.set_access_token(output.stdout.strip()) + self.set_access_token(output.stdout.strip(), authorization_header_key) def _refresh_credentials_noop(self): pass diff --git a/flytekit/configuration/file.py b/flytekit/configuration/file.py index b329a16112..1f79016f63 100644 --- a/flytekit/configuration/file.py +++ b/flytekit/configuration/file.py @@ -32,13 +32,16 @@ class LegacyConfigEntry(object): option: str type_: typing.Type = str + def get_env_name(self): + return f"FLYTE_{self.section.upper()}_{self.option.upper()}" + def read_from_env(self, transform: typing.Optional[typing.Callable] = None) -> typing.Optional[typing.Any]: """ Reads the config entry from environment variable, the structure of the env var is current ``FLYTE_{SECTION}_{OPTION}`` all upper cased. We will change this in the future. :return: """ - env = f"FLYTE_{self.section.upper()}_{self.option.upper()}" + env = self.get_env_name() v = os.environ.get(env, None) if v is None: return None diff --git a/tests/flytekit/unit/clients/test_raw.py b/tests/flytekit/unit/clients/test_raw.py index 4899051169..6636c1e0c8 100644 --- a/tests/flytekit/unit/clients/test_raw.py +++ b/tests/flytekit/unit/clients/test_raw.py @@ -9,6 +9,7 @@ from flytekit.clients.raw import RawSynchronousFlyteClient, get_basic_authorization_header, get_token from flytekit.configuration import AuthType, PlatformConfig +from flytekit.configuration.internal import Credentials def get_admin_stub_mock() -> mock.MagicMock: @@ -49,17 +50,21 @@ def test_client_set_token(mock_secure_channel, mock_channel, mock_admin, mock_ad assert client.check_access_token("abc") +@mock.patch("flytekit.clients.raw.RawSynchronousFlyteClient.set_access_token") +@mock.patch("flytekit.clients.raw.auth_service") @mock.patch("subprocess.run") -def test_refresh_credentials_from_command(mock_call_to_external_process): - command = ["command", "generating", "token"] +def test_refresh_credentials_from_command(mock_call_to_external_process, mock_admin_auth, mock_set_access_token): token = "token" + command = ["command", "generating", "token"] - mock_call_to_external_process.return_value = CompletedProcess(command, 0, stdout=token) + mock_admin_auth.AuthMetadataServiceStub.return_value = get_admin_stub_mock() + client = RawSynchronousFlyteClient(PlatformConfig(command=command)) - cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode=AuthType.EXTERNAL_PROCESS, command=command)) - cc._refresh_credentials_from_command() + mock_call_to_external_process.return_value = CompletedProcess(command, 0, stdout=token) + client._refresh_credentials_from_command() mock_call_to_external_process.assert_called_with(command, capture_output=True, text=True, check=True) + mock_set_access_token.assert_called_with(token, client.public_client_config.authorization_metadata_key) @mock.patch("flytekit.clients.raw.dataproxy_service") @@ -195,6 +200,14 @@ def test_basic_strings(mocked_method): @patch.object(RawSynchronousFlyteClient, "_refresh_credentials_from_command") def test_refresh_command(mocked_method): - cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode=AuthType.EXTERNAL_PROCESS)) + cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode=AuthType.EXTERNALCOMMAND)) + cc.refresh_credentials() + assert mocked_method.called + + +@patch.object(RawSynchronousFlyteClient, "_refresh_credentials_from_command") +def test_refresh_from_environment_variable(mocked_method, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv(Credentials.AUTH_MODE.legacy.get_env_name(), AuthType.EXTERNAL_PROCESS.name, prepend=False) + cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode=None).auto(None)) cc.refresh_credentials() assert mocked_method.called