Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix read authorization_metadata_key from public_client_config in _refresh_credentials_from_command #1065

Merged
merged 17 commits into from
Jun 16, 2022
5 changes: 4 additions & 1 deletion flytekit/clients/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,10 @@ def _refresh_credentials_from_command(self):
except subprocess.CalledProcessError as e:
cli_logger.error("Failed to generate token from command {}".format(command))
raise _user_exceptions.FlyteAuthenticationException("Problems refreshing token with command: " + str(e))
self.set_access_token(output.stdout.strip())
authorization_header_key = self.public_client_config.authorization_metadata_key or None
if not authorization_header_key:
self.set_access_token(output.stdout.strip())
self.set_access_token(output.stdout.strip(), authorization_header_key)
eapolinario marked this conversation as resolved.
Show resolved Hide resolved

def _refresh_credentials_noop(self):
pass
Expand Down
5 changes: 4 additions & 1 deletion flytekit/configuration/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,16 @@ class LegacyConfigEntry(object):
option: str
type_: typing.Type = str

def get_env_name(self):
return f"FLYTE_{self.section.upper()}_{self.option.upper()}"

def read_from_env(self, transform: typing.Optional[typing.Callable] = None) -> typing.Optional[typing.Any]:
"""
Reads the config entry from environment variable, the structure of the env var is current
``FLYTE_{SECTION}_{OPTION}`` all upper cased. We will change this in the future.
:return:
"""
env = f"FLYTE_{self.section.upper()}_{self.option.upper()}"
env = self.get_env_name()
v = os.environ.get(env, None)
if v is None:
return None
Expand Down
25 changes: 19 additions & 6 deletions tests/flytekit/unit/clients/test_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from flytekit.clients.raw import RawSynchronousFlyteClient, get_basic_authorization_header, get_token
from flytekit.configuration import AuthType, PlatformConfig
from flytekit.configuration.internal import Credentials


def get_admin_stub_mock() -> mock.MagicMock:
Expand Down Expand Up @@ -49,17 +50,21 @@ def test_client_set_token(mock_secure_channel, mock_channel, mock_admin, mock_ad
assert client.check_access_token("abc")


@mock.patch("flytekit.clients.raw.RawSynchronousFlyteClient.set_access_token")
@mock.patch("flytekit.clients.raw.auth_service")
@mock.patch("subprocess.run")
def test_refresh_credentials_from_command(mock_call_to_external_process):
command = ["command", "generating", "token"]
def test_refresh_credentials_from_command(mock_call_to_external_process, mock_admin_auth, mock_set_access_token):
token = "token"
command = ["command", "generating", "token"]

mock_call_to_external_process.return_value = CompletedProcess(command, 0, stdout=token)
mock_admin_auth.AuthMetadataServiceStub.return_value = get_admin_stub_mock()
client = RawSynchronousFlyteClient(PlatformConfig(command=command))

cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode=AuthType.EXTERNAL_PROCESS, command=command))
cc._refresh_credentials_from_command()
mock_call_to_external_process.return_value = CompletedProcess(command, 0, stdout=token)
client._refresh_credentials_from_command()

mock_call_to_external_process.assert_called_with(command, capture_output=True, text=True, check=True)
mock_set_access_token.assert_called_with(token, client.public_client_config.authorization_metadata_key)


@mock.patch("flytekit.clients.raw.dataproxy_service")
Expand Down Expand Up @@ -195,6 +200,14 @@ def test_basic_strings(mocked_method):

@patch.object(RawSynchronousFlyteClient, "_refresh_credentials_from_command")
def test_refresh_command(mocked_method):
cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode=AuthType.EXTERNAL_PROCESS))
cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode=AuthType.EXTERNALCOMMAND))
cc.refresh_credentials()
assert mocked_method.called


@patch.object(RawSynchronousFlyteClient, "_refresh_credentials_from_command")
def test_refresh_from_environment_variable(mocked_method, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv(Credentials.AUTH_MODE.legacy.get_env_name(), AuthType.EXTERNAL_PROCESS.name, prepend=False)
cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode=None).auto(None))
cc.refresh_credentials()
assert mocked_method.called