Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allows Secret groups to be optional and configurable #2062

Merged
merged 21 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions flytekit/configuration/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def get_remote(
def configure_pyflyte_cli(main: Command) -> Command:
"""Configure pyflyte's CLI."""

@staticmethod
def secret_requires_group() -> bool:
"""Return True if secrets require group entry."""


class FlytekitPlugin:
@staticmethod
Expand All @@ -62,6 +66,11 @@ def configure_pyflyte_cli(main: Command) -> Command:
"""Configure pyflyte's CLI."""
return main

@staticmethod
def secret_requires_group() -> bool:
"""Return True if secrets require group entry."""
return True


def _get_plugin_from_entrypoint():
"""Get plugin from entrypoint."""
Expand Down
20 changes: 15 additions & 5 deletions flytekit/core/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,11 @@ def __getattr__(self, item: str) -> _GroupSecrets:
return self._GroupSecrets(item, self)

def get(
self, group: str, key: Optional[str] = None, group_version: Optional[str] = None, encode_mode: str = "r"
self,
group: Optional[str] = None,
key: Optional[str] = None,
group_version: Optional[str] = None,
encode_mode: str = "r",
) -> str:
"""
Retrieves a secret using the resolution order -> Env followed by file. If not found raises a ValueError
Expand All @@ -370,15 +374,19 @@ def get(
f"in Env Var:{env_var} and FilePath: {fpath}"
)

def get_secrets_env_var(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str:
def get_secrets_env_var(
self, group: Optional[str] = None, key: Optional[str] = None, group_version: Optional[str] = None
) -> str:
"""
Returns a string that matches the ENV Variable to look for the secrets
"""
self.check_group_key(group)
l = [k.upper() for k in filter(None, (group, group_version, key))]
return f"{self._env_prefix}{'_'.join(l)}"

def get_secrets_file(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str:
def get_secrets_file(
self, group: Optional[str] = None, key: Optional[str] = None, group_version: Optional[str] = None
) -> str:
"""
Returns a path that matches the file to look for the secrets
"""
Expand All @@ -388,8 +396,10 @@ def get_secrets_file(self, group: str, key: Optional[str] = None, group_version:
return os.path.join(self._base_dir, *l)

@staticmethod
def check_group_key(group: str):
if group is None or group == "":
def check_group_key(group: Optional[str]):
from flytekit.configuration.plugin import get_plugin

if get_plugin().secret_requires_group() and (group is None or group == ""):
raise ValueError("secrets group is a mandatory field.")


Expand Down
6 changes: 4 additions & 2 deletions flytekit/models/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,15 @@ class MountType(Enum):
Caution: May not be supported in all environments
"""

group: str
group: Optional[str] = None
key: Optional[str] = None
group_version: Optional[str] = None
mount_requirement: MountType = MountType.ANY

def __post_init__(self):
if self.group is None:
from flytekit.configuration.plugin import get_plugin

if get_plugin().secret_requires_group() and self.group is None:
raise ValueError("Group is a required parameter")

def to_flyte_idl(self) -> _sec.Secret:
Expand Down
21 changes: 21 additions & 0 deletions tests/flytekit/unit/core/test_context_manager.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import base64
import os
from datetime import datetime
from pathlib import Path
from unittest.mock import Mock

import mock
import py
import pytest

import flytekit.configuration.plugin
from flytekit.configuration import (
SERIALIZED_CONTEXT_ENV_VAR,
FastSerializationSettings,
Expand Down Expand Up @@ -128,6 +131,24 @@ def test_secrets_manager_get_envvar():
assert sec.get_secrets_env_var("group") == f"{cfg.env_prefix}GROUP"


def test_secret_manager_no_group(monkeypatch):
plugin_mock = Mock()
plugin_mock.secret_requires_group.return_value = False
mock_global_plugin = {"plugin": plugin_mock}
monkeypatch.setattr(flytekit.configuration.plugin, "_GLOBAL_CONFIG", mock_global_plugin)

sec = SecretsManager()
cfg = SecretsConfig.auto()
sec.check_group_key(None)
sec.check_group_key("")

assert sec.get_secrets_env_var(key="ABC") == f"{cfg.env_prefix}ABC"

default_path = Path(cfg.default_dir)
expected_path = default_path / f"{cfg.file_prefix}abc"
assert sec.get_secrets_file(key="ABC") == str(expected_path)


def test_secrets_manager_get_file():
sec = SecretsManager()
with pytest.raises(ValueError):
Expand Down
13 changes: 13 additions & 0 deletions tests/flytekit/unit/models/core/test_security.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from unittest.mock import Mock

import flytekit.configuration.plugin
from flytekit.models.security import Secret


Expand All @@ -11,3 +14,13 @@ def test_secret():
obj2 = Secret.from_flyte_idl(obj.to_flyte_idl())
assert obj2.key is None
assert obj2.group_version == "v1"


def test_secret_no_group(monkeypatch):
plugin_mock = Mock()
plugin_mock.secret_requires_group.return_value = False
mock_global_plugin = {"plugin": plugin_mock}
monkeypatch.setattr(flytekit.configuration.plugin, "_GLOBAL_CONFIG", mock_global_plugin)

s = Secret(key="key")
assert s.group is None