Skip to content

Commit

Permalink
added envvars as multi arg
Browse files Browse the repository at this point in the history
Signed-off-by: Ketan Umare <[email protected]>
  • Loading branch information
kumare3 committed Sep 8, 2023
1 parent 5ef611f commit 9698597
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 7 deletions.
14 changes: 8 additions & 6 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from flytekit.core.base_task import PythonTask
from flytekit.core.type_engine import TypeEngine
from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase
from flytekit.interaction.click_types import FlyteLiteralConverter, JsonParamType
from flytekit.interaction.click_types import FlyteLiteralConverter, key_value_callback
from flytekit.models.interface import Parameter, Variable
from flytekit.models.types import SimpleType
from flytekit.remote import FlyteLaunchPlan, FlyteRemote, FlyteTask, FlyteWorkflow
Expand Down Expand Up @@ -142,13 +142,15 @@ class RunLevelParams(PyFlyteParams):
help="Whether to overwrite the cache if it already exists",
)
)
envs: typing.Dict[str, str] = make_field(
envvars: typing.Dict[str, str] = make_field(
click.Option(
param_decls=["--envs", "envs"],
param_decls=["--envvars", "--env"],
required=False,
type=JsonParamType(),
multiple=True,
type=str,
show_default=True,
help="Environment variables to set in the container",
callback=key_value_callback,
help="Environment variables to set in the container, of the format `ENV_NAME=ENV_VALUE`",
)
)
tag: typing.List[str] = make_field(
Expand Down Expand Up @@ -362,7 +364,7 @@ def run_remote(
options=options,
type_hints=type_hints,
overwrite_cache=run_level_params.overwrite_cache,
envs=run_level_params.envs,
envs=run_level_params.envvars,
tags=run_level_params.tag,
)

Expand Down
15 changes: 15 additions & 0 deletions flytekit/interaction/click_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,3 +419,18 @@ def convert(self, ctx, param, value) -> typing.Union[Literal, typing.Any]:
raise
except Exception as e:
raise click.BadParameter(f"Failed to convert param {param}, {value} to {self._python_type}") from e


def key_value_callback(_: typing.Any, param: str, values: typing.List[str]) -> typing.Optional[typing.Dict[str, str]]:
"""
Callback for click to parse key-value pairs.
"""
if not values:
return None
result = {}
for v in values:
if "=" not in v:
raise click.BadParameter(f"Expected key-value pair of the form key=value, got {v}")
k, v = v.split("=", 1)
result[k.strip()] = v.strip()
return result
2 changes: 1 addition & 1 deletion flytekit/models/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def from_flyte_idl(cls, p):
if p.security_context
else None,
overwrite_cache=p.overwrite_cache,
envs=_common_models.Envs.from_flyte_idl(p.envs) if p.HasField("envs") else None,
envs=_common_models.Envs.from_flyte_idl(p.envvars) if p.HasField("envs") else None,
tags=p.tags,
)

Expand Down
16 changes: 16 additions & 0 deletions tests/flytekit/unit/interaction/test_click_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
FileParamType,
FlyteLiteralConverter,
JsonParamType,
key_value_callback,
)
from flytekit.models.types import SimpleType
from flytekit.remote import FlyteRemote
Expand Down Expand Up @@ -141,3 +142,18 @@ def test_json_type():
yaml.dump({"a": "b"}, f)
f.flush()
assert t.convert(value=f.name, param=None, ctx=None) == {"a": "b"}


def test_key_value_callback():
"""Write a test that verifies that the callback works correctly."""
ctx = click.Context(click.Command("test_command"), obj={"remote": True})
assert key_value_callback(ctx, "a", None) is None
assert key_value_callback(ctx, "a", ["a=b"]) == {"a": "b"}
assert key_value_callback(ctx, "a", ["a=b", "c=d"]) == {"a": "b", "c": "d"}
assert key_value_callback(ctx, "a", ["a=b", "c=d", "e=f"]) == {"a": "b", "c": "d", "e": "f"}
with pytest.raises(click.BadParameter):
key_value_callback(ctx, "a", ["a=b", "c"])
with pytest.raises(click.BadParameter):
key_value_callback(ctx, "a", ["a=b", "c=d", "e"])
with pytest.raises(click.BadParameter):
key_value_callback(ctx, "a", ["a=b", "c=d", "e=f", "g"])

0 comments on commit 9698597

Please sign in to comment.