diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index 63914c13b2..f2467e77d8 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -329,13 +329,13 @@ def __getattr__(self, item: str) -> _GroupSecrets: """ return self._GroupSecrets(item, self) - def get(self, group: str, key: str) -> str: + def get(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str: """ Retrieves a secret using the resolution order -> Env followed by file. If not found raises a ValueError """ - self.check_group_key(group, key) - env_var = self.get_secrets_env_var(group, key) - fpath = self.get_secrets_file(group, key) + self.check_group_key(group) + env_var = self.get_secrets_env_var(group, key, group_version) + fpath = self.get_secrets_file(group, key, group_version) v = os.environ.get(env_var) if v is not None: return v @@ -346,26 +346,27 @@ def get(self, group: str, key: str) -> str: f"Unable to find secret for key {key} in group {group} " f"in Env Var:{env_var} and FilePath: {fpath}" ) - def get_secrets_env_var(self, group: str, key: str) -> str: + def get_secrets_env_var(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str: """ Returns a string that matches the ENV Variable to look for the secrets """ - self.check_group_key(group, key) - return f"{self._env_prefix}{group.upper()}_{key.upper()}" + self.check_group_key(group) + l = [k.upper() for k in filter(None, (group, group_version, key))] + return f"{self._env_prefix}{'_'.join(l)}" - def get_secrets_file(self, group: str, key: str) -> str: + def get_secrets_file(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str: """ Returns a path that matches the file to look for the secrets """ - self.check_group_key(group, key) - return os.path.join(self._base_dir, group.lower(), f"{self._file_prefix}{key.lower()}") + self.check_group_key(group) + l = [k.lower() for k in filter(None, (group, group_version, key))] + l[-1] = f"{self._file_prefix}{l[-1]}" + return os.path.join(self._base_dir, *l) @staticmethod - def check_group_key(group: str, key: str): + def check_group_key(group: str): if group is None or group == "": raise ValueError("secrets group is a mandatory field.") - if key is None or key == "": - raise ValueError("secrets key is a mandatory field.") @dataclass(frozen=True) diff --git a/flytekit/models/security.py b/flytekit/models/security.py index 7babb859e4..9af90a4b8a 100644 --- a/flytekit/models/security.py +++ b/flytekit/models/security.py @@ -36,15 +36,13 @@ class MountType(Enum): """ group: str - key: str + key: Optional[str] = None group_version: Optional[str] = None mount_requirement: MountType = MountType.ANY def __post_init__(self): if self.group is None: raise ValueError("Group is a required parameter") - if self.key is None: - raise ValueError("Key is also a required parameter") def to_flyte_idl(self) -> _sec.Secret: return _sec.Secret( @@ -59,7 +57,7 @@ def from_flyte_idl(cls, pb2_object: _sec.Secret) -> "Secret": return cls( group=pb2_object.group, group_version=pb2_object.group_version if pb2_object.group_version else None, - key=pb2_object.key, + key=pb2_object.key if pb2_object.key else None, mount_requirement=Secret.MountType(pb2_object.mount_requirement), ) diff --git a/tests/flytekit/unit/core/test_context_manager.py b/tests/flytekit/unit/core/test_context_manager.py index 6e68c9d4be..f01c6b203c 100644 --- a/tests/flytekit/unit/core/test_context_manager.py +++ b/tests/flytekit/unit/core/test_context_manager.py @@ -115,18 +115,17 @@ def test_secrets_manager_default(): def test_secrets_manager_get_envvar(): sec = SecretsManager() - with pytest.raises(ValueError): - sec.get_secrets_env_var("test", "") with pytest.raises(ValueError): sec.get_secrets_env_var("", "x") cfg = SecretsConfig.auto() assert sec.get_secrets_env_var("group", "test") == f"{cfg.env_prefix}GROUP_TEST" + assert sec.get_secrets_env_var("group", "test", "v1") == f"{cfg.env_prefix}GROUP_V1_TEST" + assert sec.get_secrets_env_var("group", group_version="v1") == f"{cfg.env_prefix}GROUP_V1" + assert sec.get_secrets_env_var("group") == f"{cfg.env_prefix}GROUP" def test_secrets_manager_get_file(): sec = SecretsManager() - with pytest.raises(ValueError): - sec.get_secrets_file("test", "") with pytest.raises(ValueError): sec.get_secrets_file("", "x") cfg = SecretsConfig.auto() @@ -135,6 +134,12 @@ def test_secrets_manager_get_file(): "group", f"{cfg.file_prefix}test", ) + assert sec.get_secrets_file("group", "test", "v1") == os.path.join( + cfg.default_dir, + "group", + "v1", + f"{cfg.file_prefix}test", + ) def test_secrets_manager_file(tmpdir: py.path.local): @@ -145,8 +150,6 @@ def test_secrets_manager_file(tmpdir: py.path.local): with open(f, "w+") as w: w.write("my-password") - with pytest.raises(ValueError): - sec.get("test", "") with pytest.raises(ValueError): sec.get("", "x") # Group dir not exists diff --git a/tests/flytekit/unit/models/core/test_security.py b/tests/flytekit/unit/models/core/test_security.py new file mode 100644 index 0000000000..c2933f9353 --- /dev/null +++ b/tests/flytekit/unit/models/core/test_security.py @@ -0,0 +1,13 @@ +from flytekit.models.security import Secret + + +def test_secret(): + obj = Secret("grp", "key") + obj2 = Secret.from_flyte_idl(obj.to_flyte_idl()) + assert obj2.key == "key" + assert obj2.group_version is None + + obj = Secret("grp", group_version="v1") + obj2 = Secret.from_flyte_idl(obj.to_flyte_idl()) + assert obj2.key is None + assert obj2.group_version == "v1"