Skip to content

Commit

Permalink
pyflyte run supports pickle (#1646)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored and eapolinario committed Jul 10, 2023
1 parent 7bbd44b commit dac0603
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 4 deletions.
22 changes: 20 additions & 2 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from dataclasses import dataclass
from typing import cast

import cloudpickle
import rich_click as click
import yaml
from dataclasses_json import DataClassJsonMixin
Expand All @@ -34,7 +35,7 @@
from flytekit.configuration.default_images import DefaultImages
from flytekit.core import context_manager
from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import FlyteContext
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.type_engine import TypeEngine
from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase
Expand All @@ -46,6 +47,7 @@
from flytekit.tools import module_loader, script_mode
from flytekit.tools.script_mode import _find_project_root
from flytekit.tools.translator import Options
from flytekit.types.pickle.pickle import FlytePickleTransformer

REMOTE_FLAG_KEY = "remote"
RUN_LEVEL_PARAMS_KEY = "run_level_params"
Expand Down Expand Up @@ -104,6 +106,19 @@ def convert(
raise click.BadParameter(f"parameter should be a valid file path, {value}")


class PickleParamType(click.ParamType):
name = "pickle"

def convert(
self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
) -> typing.Any:

uri = FlyteContextManager.current_context().file_access.get_random_local_path()
with open(uri, "w+b") as outfile:
cloudpickle.dump(value, outfile)
return FileParam(filepath=str(pathlib.Path(uri).resolve()))


class DateTimeType(click.DateTime):

_NOW_FMT = "now"
Expand Down Expand Up @@ -228,7 +243,10 @@ def __init__(

if self._literal_type.blob:
if self._literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE:
self._click_type = FileParamType()
if self._literal_type.blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT:
self._click_type = PickleParamType()
else:
self._click_type = FileParamType()
else:
self._click_type = DirParamType()

Expand Down
2 changes: 2 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ def test_pyflyte_run_cli():
json.dumps([{"x": parquet_file}]),
"--o",
json.dumps({"x": [parquet_file]}),
"--p",
"Any",
],
catch_exceptions=False,
)
Expand Down
6 changes: 4 additions & 2 deletions tests/flytekit/unit/cli/pyflyte/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,9 @@ def print_all(
m: dict,
n: typing.List[typing.Dict[str, FlyteFile]],
o: typing.Dict[str, typing.List[FlyteFile]],
p: typing.Any,
):
print(f"{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}, {j}, {k}, {l}, {m}, {n}, {o}")
print(f"{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}, {j}, {k}, {l}, {m}, {n}, {o} , {p}")


@task
Expand Down Expand Up @@ -88,12 +89,13 @@ def my_wf(
l: dict,
n: typing.List[typing.Dict[str, FlyteFile]],
o: typing.Dict[str, typing.List[FlyteFile]],
p: typing.Any,
remote: pd.DataFrame,
image: StructuredDataset,
m: dict = {"hello": "world"},
) -> Annotated[StructuredDataset, subset_cols]:
x = get_subset_df(df=remote) # noqa: shown for demonstration; users should use the same types between tasks
show_sd(in_sd=x)
show_sd(in_sd=image)
print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k, l=l, m=m, n=n, o=o)
print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k, l=l, m=m, n=n, o=o, p=p)
return x

0 comments on commit dac0603

Please sign in to comment.