Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run ignore query #2322

Merged
merged 5 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from flytekit.configuration import DefaultImages, FastSerializationSettings, ImageConfig, SerializationSettings
from flytekit.configuration.plugin import get_plugin
from flytekit.core import context_manager
from flytekit.core.artifact import ArtifactQuery
from flytekit.core.base_task import PythonTask
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.type_engine import TypeEngine
Expand Down Expand Up @@ -376,14 +377,17 @@ def to_click_option(

description_extra = ""
if literal_var.type.simple == SimpleType.STRUCT:
if default_val:
if default_val and not isinstance(default_val, ArtifactQuery):
if type(default_val) == dict or type(default_val) == list:
default_val = json.dumps(default_val)
else:
default_val = cast(DataClassJsonMixin, default_val).to_json()
if literal_var.type.metadata:
description_extra = f": {json.dumps(literal_var.type.metadata)}"

# If a query has been specified, the input is never strictly required at this layer
required = False if default_val and isinstance(default_val, ArtifactQuery) else required

return click.Option(
param_decls=[f"--{input_name}"],
type=literal_converter.click_type,
Expand Down Expand Up @@ -508,7 +512,22 @@ def _run(*args, **kwargs):
try:
inputs = {}
for input_name, _ in entity.python_interface.inputs.items():
inputs[input_name] = kwargs.get(input_name)
processed_click_value = kwargs.get(input_name)
if isinstance(processed_click_value, ArtifactQuery):
if run_level_params.is_remote:
click.secho(
click.style(
f"Input '{input_name}' not passed, supported backends will query"
f" for {processed_click_value.get_str(**kwargs)}",
bold=True,
)
)
continue
else:
raise click.UsageError(
f"Default for '{input_name}' is a query, which must be specified when running locally."
)
inputs[input_name] = processed_click_value

if not run_level_params.is_remote:
with FlyteContextManager.with_context(_update_flyte_context(run_level_params)):
Expand Down
41 changes: 41 additions & 0 deletions flytekit/core/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,47 @@
) -> art_id.ArtifactQuery:
return Serializer.artifact_query_to_idl(self, **kwargs)

def get_time_partition_str(self, **kwargs) -> str:
    """Describe this query's time partition for display to the user.

    Args:
        **kwargs: Inputs supplied for the run. Only consulted to check that
            the variable named by an input binding is present.

    Returns:
        An empty string when no time partition is set, or when its value is
        neither a ``time_value`` nor an ``input_binding``. Otherwise a
        description that starts with a single space so it can be appended
        directly to a longer message.

    Raises:
        ValueError: If the time partition is bound to an input whose name is
            not a key of ``kwargs``.
    """
    # NOTE(review): Codecov flagged the lines of this method as not covered by tests.
    if not self.time_partition:
        return ""

    value = self.time_partition.value
    if value.HasField("time_value"):
        return f" Time partition: {value.time_value.ToDatetime()}"

    if value.HasField("input_binding"):
        var = value.input_binding.var
        if var not in kwargs:
            raise ValueError(f"Time partition input binding {var} not found in kwargs")
        # No trailing comma: the time_value branch has none either, and a
        # comma here would dangle at the end of the message whenever the
        # query has no other partitions.
        return f" Time partition from input<{var}>"

    return ""

Check warning on line 202 in flytekit/core/artifact.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/artifact.py#L201-L202

Added lines #L201 - L202 were not covered by tests

def get_partition_str(self, **kwargs) -> str:
    """Describe this query's (non-time) partitions for display to the user.

    Args:
        **kwargs: Inputs supplied for the run. Only consulted to check that
            the variable named by an input binding is present.

    Returns:
        An empty string when there are no partitions, or when none of them
        has a ``static_value`` or ``input_binding``. Otherwise
        ``" Partitions: "`` followed by the comma-separated entries.

    Raises:
        ValueError: If a partition is bound to an input whose name is not a
            key of ``kwargs``.
    """
    # NOTE(review): Codecov flagged the lines of this method as not covered by tests.
    if not (self.partitions and self.partitions.partitions):
        return ""

    rendered = []
    for key, partition in self.partitions.partitions.items():
        value = partition.value
        if not value:
            continue
        if value.HasField("static_value"):
            rendered.append(f"{key}={value.static_value}")
        elif value.HasField("input_binding"):
            var = value.input_binding.var
            if var not in kwargs:
                raise ValueError(f"Partition input binding {var} not found in kwargs")
            rendered.append(f"{key} from input<{var}>")

    if not rendered:
        # Avoid emitting a bare " Partitions:" header with nothing after it.
        return ""
    # Joining (rather than appending ", " and stripping afterwards) keeps any
    # trailing comma or whitespace that is genuinely part of the last value.
    return " Partitions: " + ", ".join(rendered)

Check warning on line 217 in flytekit/core/artifact.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/artifact.py#L216-L217

Added lines #L216 - L217 were not covered by tests

def get_str(self, **kwargs):
    """Build the detailed, human-readable description of this query.

    Unlike ``__str__``, this includes the time-partition and partition
    descriptions, and is meant for messages printed while running.

    Args:
        **kwargs: Forwarded unchanged to ``get_time_partition_str`` and
            ``get_partition_str``.

    Returns:
        The quoted artifact name, an ellipsis, then the two descriptions.
    """
    # NOTE(review): Codecov flagged the lines of this method as not covered by tests.
    details = self.get_time_partition_str(**kwargs) + self.get_partition_str(**kwargs)
    return f"'{self.artifact.name}'...{details}"

Check warning on line 224 in flytekit/core/artifact.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/artifact.py#L224

Added line #L224 was not covered by tests

def __str__(self):
    """Return the short description of this query shown in ``--help`` output."""
    artifact_name = self.artifact.name
    return f"Artifact Query: on {artifact_name}"


class TimePartition(object):
def __init__(
Expand Down
23 changes: 22 additions & 1 deletion flytekit/interaction/click_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pytimeparse import parse

from flytekit import BlobType, FlyteContext, FlyteContextManager, Literal, LiteralType, StructuredDataset
from flytekit.core.artifact import ArtifactQuery
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.type_engine import TypeEngine
from flytekit.models.types import SimpleType
Expand Down Expand Up @@ -61,6 +62,8 @@
def convert(
self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
) -> typing.Any:
if isinstance(value, ArtifactQuery):
return value
p = pathlib.Path(value)
# set remote_directory to false if running pyflyte run locally. This makes sure that the original
# directory is used and not a random one.
Expand All @@ -80,6 +83,8 @@
def convert(
self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
) -> typing.Any:
if isinstance(value, ArtifactQuery):
return value
if isinstance(value, str):
return StructuredDataset(uri=value)
elif isinstance(value, StructuredDataset):
Expand All @@ -93,6 +98,8 @@
def convert(
self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
) -> typing.Any:
if isinstance(value, ArtifactQuery):
return value

Check warning on line 102 in flytekit/interaction/click_types.py

View check run for this annotation

Codecov / codecov/patch

flytekit/interaction/click_types.py#L102

Added line #L102 was not covered by tests
# set remote_directory to false if running pyflyte run locally. This makes sure that the original
# file is used and not a random one.
remote_path = None if getattr(ctx.obj, "is_remote", False) else False
Expand All @@ -109,6 +116,8 @@
def convert(
self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
) -> typing.Any:
if isinstance(value, ArtifactQuery):
return value
# set remote_directory to false if running pyflyte run locally. This makes sure that the original
# file is used and not a random one.
remote_path = None if getattr(ctx.obj, "is_remote", None) else False
Expand All @@ -131,6 +140,8 @@
def convert(
self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
) -> typing.Any:
if isinstance(value, ArtifactQuery):
return value
if value in self._ADDITONAL_FORMATS:
if value == self._NOW_FMT:
return datetime.datetime.now()
Expand All @@ -143,6 +154,8 @@
def convert(
self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
) -> typing.Any:
if isinstance(value, ArtifactQuery):
return value
if value is None:
raise click.BadParameter("None value cannot be converted to a Duration type.")
return datetime.timedelta(seconds=parse(value))
Expand All @@ -156,6 +169,8 @@
def convert(
self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
) -> enum.Enum:
if isinstance(value, ArtifactQuery):
return value
if isinstance(value, self._enum_type):
return value
return self._enum_type(super().convert(value, param, ctx))
Expand Down Expand Up @@ -191,6 +206,8 @@
Important to implement NoneType / Optional.
Also could we just determine the click types from the python types
"""
if isinstance(value, ArtifactQuery):
return value
for t in self._types:
try:
return t.convert(value, param, ctx)
Expand Down Expand Up @@ -228,6 +245,8 @@
def convert(
self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
) -> typing.Any:
if isinstance(value, ArtifactQuery):
return value
if value is None:
raise click.BadParameter("None value cannot be converted to a Json type.")

Expand Down Expand Up @@ -274,7 +293,7 @@
SimpleType.STRING: click.STRING,
SimpleType.BOOLEAN: click.BOOL,
SimpleType.DURATION: DurationParamType(),
SimpleType.DATETIME: click.DateTime(),
SimpleType.DATETIME: DateTimeType(),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch.

}


Expand Down Expand Up @@ -353,6 +372,8 @@
"""
Convert the value to a Flyte Literal or a python native type. This is used by click to convert the input.
"""
if isinstance(value, ArtifactQuery):
return value
try:
# If the expected Python type is datetime.date, adjust the value to date
if self._python_type is datetime.date:
Expand Down
43 changes: 43 additions & 0 deletions tests/flytekit/unit/interaction/test_click_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,19 @@
import yaml

from flytekit import FlyteContextManager
from flytekit.core.artifact import Artifact
from flytekit.core.type_engine import TypeEngine
from flytekit.interaction.click_types import (
DateTimeType,
DirParamType,
DurationParamType,
EnumParamType,
FileParamType,
FlyteLiteralConverter,
JsonParamType,
PickleParamType,
StructuredDatasetParamType,
UnionParamType,
key_value_callback,
)

Expand Down Expand Up @@ -62,6 +68,23 @@ def test_literal_converter(python_type, python_value):
assert lc.convert(click_ctx, dummy_param, python_value) == TypeEngine.to_literal(ctx, python_value, python_type, lt)


def test_literal_converter_query():
    """An ArtifactQuery given to FlyteLiteralConverter.convert compares equal to what comes back."""
    flyte_ctx = FlyteContextManager.current_context()
    converter = FlyteLiteralConverter(
        flyte_ctx,
        literal_type=TypeEngine.to_literal_type(int),
        python_type=int,
        is_remote=True,
    )

    artifact_query = Artifact(name="test-artifact").query()
    click_ctx = click.Context(click.Command("test_command"), obj={"remote": True})

    converted = converter.convert(click_ctx, dummy_param, artifact_query)
    assert converted == artifact_query


@pytest.mark.parametrize(
"python_type, python_str_value, python_value",
[
Expand Down Expand Up @@ -163,3 +186,23 @@ def test_key_value_callback():
key_value_callback(ctx, "a", ["a=b", "c=d", "e"])
with pytest.raises(click.BadParameter):
key_value_callback(ctx, "a", ["a=b", "c=d", "e=f", "g"])


@pytest.mark.parametrize(
    "param_type",
    [
        DateTimeType(),
        DurationParamType(),
        JsonParamType(typing.Dict[str, str]),
        UnionParamType([click.FLOAT, click.INT]),
        EnumParamType(Color),
        PickleParamType(),
        StructuredDatasetParamType(),
        DirParamType(),
        # FileParamType is imported by this module and has the same
        # ArtifactQuery pass-through guard; include it so that guard is
        # exercised too.
        FileParamType(),
    ],
)
def test_query_passing(param_type: click.ParamType):
    """Each click param type returns an ArtifactQuery untouched (same object) from ``convert``."""
    query = Artifact(name="test-artifact").query()

    # param and ctx are None: the query is expected to be returned before
    # either of them is consulted.
    assert param_type.convert(value=query, param=None, ctx=None) is query
Loading