From 845bbc07b01906241c168946e51e26f8855f160b Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 2 Apr 2024 16:32:32 -0700 Subject: [PATCH 1/5] first bit Signed-off-by: Yee Hing Tong --- flytekit/clis/sdk_in_container/run.py | 18 ++++++++++++++++-- flytekit/interaction/click_types.py | 5 +++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 9f4effe3eb..7e8659a863 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -26,6 +26,7 @@ from flytekit.configuration import DefaultImages, FastSerializationSettings, ImageConfig, SerializationSettings from flytekit.configuration.plugin import get_plugin from flytekit.core import context_manager +from flytekit.core.artifact import ArtifactQuery from flytekit.core.base_task import PythonTask from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.type_engine import TypeEngine @@ -362,6 +363,12 @@ def to_click_option( This handles converting workflow input types to supported click parameters with callbacks to initialize the input values to their expected types. """ + print(f"input_name: {input_name}") + print(f"default_val: {default_val}") + print(f"required: {required}") + print(f"literal_var: {literal_var}") + print(f"python_type: {python_type}") + print("====================================") run_level_params: RunLevelParams = ctx.obj literal_converter = FlyteLiteralConverter( @@ -374,9 +381,10 @@ def to_click_option( if literal_converter.is_bool() and not default_val: default_val = False + description_extra = "" if literal_var.type.simple == SimpleType.STRUCT: - if default_val: + if default_val and not isinstance(default_val, ArtifactQuery): if type(default_val) == dict or type(default_val) == list: default_val = json.dumps(default_val) else: @@ -384,6 +392,9 @@ def to_click_option( if literal_var.type.metadata: description_extra = f": {json.dumps(literal_var.type.metadata)}" + # If a query has been specified, the input is never strictly required at this layer + required = False if default_val and isinstance(default_val, ArtifactQuery) else required + return click.Option( param_decls=[f"--{input_name}"], type=literal_converter.click_type, @@ -508,7 +519,10 @@ def _run(*args, **kwargs): try: inputs = {} for input_name, _ in entity.python_interface.inputs.items(): - inputs[input_name] = kwargs.get(input_name) + processed_click_value = kwargs.get(input_name) + if isinstance(processed_click_value, ArtifactQuery): + continue + inputs[input_name] = processed_click_value if not run_level_params.is_remote: with FlyteContextManager.with_context(_update_flyte_context(run_level_params)): diff --git a/flytekit/interaction/click_types.py b/flytekit/interaction/click_types.py index 6ab9f88a25..83bff78735 100644 --- a/flytekit/interaction/click_types.py +++ b/flytekit/interaction/click_types.py @@ -14,6 +14,7 @@ from pytimeparse import parse from flytekit import BlobType, FlyteContext, FlyteContextManager, Literal, LiteralType, StructuredDataset +from flytekit.core.artifact import ArtifactQuery from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.type_engine import TypeEngine from flytekit.models.types import SimpleType @@ -80,6 +81,8 @@ class StructuredDatasetParamType(click.ParamType): def convert( self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] ) -> typing.Any: + if isinstance(value, ArtifactQuery): + return value if isinstance(value, str): return StructuredDataset(uri=value) elif isinstance(value, StructuredDataset): @@ -353,6 +356,8 @@ def convert( """ Convert the value to a Flyte Literal or a python native type. This is used by click to convert the input. """ + if isinstance(value, ArtifactQuery): + return value try: # If the expected Python type is datetime.date, adjust the value to date if self._python_type is datetime.date: From 3a02c710265c1952de59f54664f7bdce09ce94c9 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 3 Apr 2024 10:34:30 -0700 Subject: [PATCH 2/5] add formatting Signed-off-by: Yee Hing Tong --- flytekit/clis/sdk_in_container/run.py | 21 ++++++++++------ flytekit/core/artifact.py | 36 +++++++++++++++++++++++++++ flytekit/interaction/click_types.py | 18 +++++++++++++- 3 files changed, 66 insertions(+), 9 deletions(-) diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 7e8659a863..256e319ae0 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -363,12 +363,6 @@ def to_click_option( This handles converting workflow input types to supported click parameters with callbacks to initialize the input values to their expected types. """ - print(f"input_name: {input_name}") - print(f"default_val: {default_val}") - print(f"required: {required}") - print(f"literal_var: {literal_var}") - print(f"python_type: {python_type}") - print("====================================") run_level_params: RunLevelParams = ctx.obj literal_converter = FlyteLiteralConverter( @@ -381,7 +375,6 @@ def to_click_option( if literal_converter.is_bool() and not default_val: default_val = False - description_extra = "" if literal_var.type.simple == SimpleType.STRUCT: if default_val and not isinstance(default_val, ArtifactQuery): @@ -521,7 +514,19 @@ def _run(*args, **kwargs): for input_name, _ in entity.python_interface.inputs.items(): processed_click_value = kwargs.get(input_name) if isinstance(processed_click_value, ArtifactQuery): - continue + if run_level_params.is_remote: + click.secho( + click.style( + f"Input '{input_name}' not passed, supported backends will query" + f" for {processed_click_value.get_str(**kwargs)}", + bold=True, + ) + ) + continue + else: + raise click.UsageError( + f"Default for '{input_name}' is a query, which must be specified when running locally." + ) inputs[input_name] = processed_click_value if not run_level_params.is_remote: diff --git a/flytekit/core/artifact.py b/flytekit/core/artifact.py index e9a7909809..241559ab73 100644 --- a/flytekit/core/artifact.py +++ b/flytekit/core/artifact.py @@ -186,6 +186,42 @@ def to_flyte_idl( ) -> art_id.ArtifactQuery: return Serializer.artifact_query_to_idl(self, **kwargs) + def get_time_partition_str(self, **kwargs) -> str: + tp_str = "" + if self.time_partition: + tp = self.time_partition.value + if tp.HasField("time_value"): + tp = tp.time_value.ToDatetime() + tp_str += f" Time partition: {tp}" + elif tp.HasField("input_binding"): + var = tp.input_binding.var + if var not in kwargs: + raise ValueError(f"Time partition input binding {var} not found in kwargs") + else: + tp_str += f" Time partition from input<{var}>," + return tp_str + + def get_partition_str(self, **kwargs) -> str: + p_str = "" + if self.partitions and self.partitions.partitions and len(self.partitions.partitions) > 0: + p_str = " Partitions: " + for k, v in self.partitions.partitions.items(): + if v.value and v.value.HasField("static_value"): + p_str += f"{k}={v.value.static_value}, " + elif v.value and v.value.HasField("input_binding"): + var = v.value.input_binding.var + if var not in kwargs: + raise ValueError(f"Partition input binding {var} not found in kwargs") + else: + p_str += f"{k} from input<{var}>, " + return p_str.rstrip("\n\r, ") + + def get_str(self, **kwargs): + tp_str = self.get_time_partition_str(**kwargs) + p_str = self.get_partition_str(**kwargs) + + return f"'{self.artifact.name}'...{tp_str}{p_str}" + class TimePartition(object): def __init__( diff --git a/flytekit/interaction/click_types.py b/flytekit/interaction/click_types.py index 83bff78735..439dc3cb73 100644 --- a/flytekit/interaction/click_types.py +++ b/flytekit/interaction/click_types.py @@ -62,6 +62,8 @@ class DirParamType(click.ParamType): def convert( self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] ) -> typing.Any: + if isinstance(value, ArtifactQuery): + return value p = pathlib.Path(value) # set remote_directory to false if running pyflyte run locally. This makes sure that the original # directory is used and not a random one. @@ -96,6 +98,8 @@ class FileParamType(click.ParamType): def convert( self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] ) -> typing.Any: + if isinstance(value, ArtifactQuery): + return value # set remote_directory to false if running pyflyte run locally. This makes sure that the original # file is used and not a random one. remote_path = None if getattr(ctx.obj, "is_remote", False) else False @@ -112,6 +116,8 @@ class PickleParamType(click.ParamType): def convert( self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] ) -> typing.Any: + if isinstance(value, ArtifactQuery): + return value # set remote_directory to false if running pyflyte run locally. This makes sure that the original # file is used and not a random one. remote_path = None if getattr(ctx.obj, "is_remote", None) else False @@ -134,6 +140,8 @@ def __init__(self): def convert( self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] ) -> typing.Any: + if isinstance(value, ArtifactQuery): + return value if value in self._ADDITONAL_FORMATS: if value == self._NOW_FMT: return datetime.datetime.now() @@ -146,6 +154,8 @@ class DurationParamType(click.ParamType): def convert( self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] ) -> typing.Any: + if isinstance(value, ArtifactQuery): + return value if value is None: raise click.BadParameter("None value cannot be converted to a Duration type.") return datetime.timedelta(seconds=parse(value)) @@ -159,6 +169,8 @@ def __init__(self, enum_type: typing.Type[enum.Enum]): def convert( self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] ) -> enum.Enum: + if isinstance(value, ArtifactQuery): + return value if isinstance(value, self._enum_type): return value return self._enum_type(super().convert(value, param, ctx)) @@ -194,6 +206,8 @@ def convert( Important to implement NoneType / Optional. Also could we just determine the click types from the python types """ + if isinstance(value, ArtifactQuery): + return value for t in self._types: try: return t.convert(value, param, ctx) @@ -231,6 +245,8 @@ def _parse(self, value: typing.Any, param: typing.Optional[click.Parameter]): def convert( self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] ) -> typing.Any: + if isinstance(value, ArtifactQuery): + return value if value is None: raise click.BadParameter("None value cannot be converted to a Json type.") @@ -277,7 +293,7 @@ def modify_literal_uris(lit: Literal): SimpleType.STRING: click.STRING, SimpleType.BOOLEAN: click.BOOL, SimpleType.DURATION: DurationParamType(), - SimpleType.DATETIME: click.DateTime(), + SimpleType.DATETIME: DateTimeType(), } From b2e156ea5ba737287a21abb52913b4d8dc48056d Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 3 Apr 2024 10:39:44 -0700 Subject: [PATCH 3/5] better logging for --help Signed-off-by: Yee Hing Tong --- flytekit/core/artifact.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/flytekit/core/artifact.py b/flytekit/core/artifact.py index 241559ab73..77532860cd 100644 --- a/flytekit/core/artifact.py +++ b/flytekit/core/artifact.py @@ -217,11 +217,16 @@ def get_partition_str(self, **kwargs) -> str: return p_str.rstrip("\n\r, ") def get_str(self, **kwargs): + # Detailed string that explains query a bit more, used in running tp_str = self.get_time_partition_str(**kwargs) p_str = self.get_partition_str(**kwargs) return f"'{self.artifact.name}'...{tp_str}{p_str}" + def __str__(self): + # Default string used for printing --help + return f"Artifact Query: on {self.artifact.name}" + class TimePartition(object): def __init__( From 94bf7bcec42995d64131c47d1c7badbd31ce555f Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 3 Apr 2024 15:56:20 -0700 Subject: [PATCH 4/5] test for param types Signed-off-by: Yee Hing Tong --- .../unit/interaction/test_click_types.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/flytekit/unit/interaction/test_click_types.py b/tests/flytekit/unit/interaction/test_click_types.py index b5e94a1ff7..8e63128111 100644 --- a/tests/flytekit/unit/interaction/test_click_types.py +++ b/tests/flytekit/unit/interaction/test_click_types.py @@ -10,13 +10,19 @@ import yaml from flytekit import FlyteContextManager +from flytekit.core.artifact import Artifact from flytekit.core.type_engine import TypeEngine from flytekit.interaction.click_types import ( DateTimeType, + DirParamType, DurationParamType, + EnumParamType, FileParamType, FlyteLiteralConverter, JsonParamType, + PickleParamType, + StructuredDatasetParamType, + UnionParamType, key_value_callback, ) @@ -163,3 +169,23 @@ def test_key_value_callback(): key_value_callback(ctx, "a", ["a=b", "c=d", "e"]) with pytest.raises(click.BadParameter): key_value_callback(ctx, "a", ["a=b", "c=d", "e=f", "g"]) + + +@pytest.mark.parametrize( + "param_type", + [ + (DateTimeType()), + (DurationParamType()), + (JsonParamType(typing.Dict[str, str])), + (UnionParamType([click.FLOAT, click.INT])), + (EnumParamType(Color)), + (PickleParamType()), + (StructuredDatasetParamType()), + (DirParamType()), + ], +) +def test_query_passing(param_type: click.ParamType): + a = Artifact(name="test-artifact") + query = a.query() + + assert param_type.convert(value=query, param=None, ctx=None) is query From f8508838f226550beb2b3a3f0aa98a691b8f131a Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 3 Apr 2024 15:59:25 -0700 Subject: [PATCH 5/5] literal converter test Signed-off-by: Yee Hing Tong --- .../unit/interaction/test_click_types.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/flytekit/unit/interaction/test_click_types.py b/tests/flytekit/unit/interaction/test_click_types.py index 8e63128111..83a191c449 100644 --- a/tests/flytekit/unit/interaction/test_click_types.py +++ b/tests/flytekit/unit/interaction/test_click_types.py @@ -68,6 +68,23 @@ def test_literal_converter(python_type, python_value): assert lc.convert(click_ctx, dummy_param, python_value) == TypeEngine.to_literal(ctx, python_value, python_type, lt) +def test_literal_converter_query(): + ctx = FlyteContextManager.current_context() + lt = TypeEngine.to_literal_type(int) + + lc = FlyteLiteralConverter( + ctx, + literal_type=lt, + python_type=int, + is_remote=True, + ) + + a = Artifact(name="test-artifact") + query = a.query() + click_ctx = click.Context(click.Command("test_command"), obj={"remote": True}) + assert lc.convert(click_ctx, dummy_param, query) == query + + @pytest.mark.parametrize( "python_type, python_str_value, python_value", [