From 401c505ecb511aad6dba70adcae50762b4288521 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 18 Apr 2023 06:47:04 +0800 Subject: [PATCH] Add support for nested FlyteFile in pyflyte run (#1587) * Add support for nested FlyteFile in pyflyte run Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su --------- Signed-off-by: Kevin Su --- flytekit/clis/sdk_in_container/run.py | 62 +++++++++++++------ flytekit/core/data_persistence.py | 1 + flytekit/remote/remote.py | 6 +- .../unit/cli/pyflyte/test_register.py | 2 +- tests/flytekit/unit/cli/pyflyte/test_run.py | 8 ++- tests/flytekit/unit/cli/pyflyte/workflow.py | 8 ++- tests/flytekit/unit/remote/test_remote.py | 2 +- 7 files changed, 62 insertions(+), 27 deletions(-) diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index c45ec3f150..adfb541e58 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -37,7 +37,7 @@ from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase from flytekit.models import literals from flytekit.models.interface import Variable -from flytekit.models.literals import Blob, BlobMetadata, Primitive, Union +from flytekit.models.literals import Blob, BlobMetadata, LiteralCollection, LiteralMap, Primitive, Union from flytekit.models.types import LiteralType, SimpleType from flytekit.remote.executions import FlyteWorkflowExecution from flytekit.tools import module_loader, script_mode @@ -215,16 +215,18 @@ def is_bool(self) -> bool: return self._literal_type.simple == SimpleType.BOOLEAN return False - def get_uri_for_dir(self, value: Directory, remote_filename: typing.Optional[str] = None): + def get_uri_for_dir( + self, ctx: typing.Optional[click.Context], value: Directory, remote_filename: typing.Optional[str] = None + ): uri = value.dir_path if self._remote and value.local: md5, _ = script_mode.hash_file(value.local_file) if not remote_filename: remote_filename = 
value.local_file.name - df_remote_location = self._create_upload_fn(filename=remote_filename, content_md5=md5) - self._flyte_ctx.file_access.put_data(value.local_file, df_remote_location.signed_url) - uri = df_remote_location.native_url[: -len(remote_filename)] + remote = ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] + _, native_url = remote.upload_file(value.local_file) + uri = native_url[: -len(remote_filename)] return uri @@ -232,7 +234,7 @@ def convert_to_structured_dataset( self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: Directory ) -> Literal: - uri = self.get_uri_for_dir(value, "00000.parquet") + uri = self.get_uri_for_dir(ctx, value, "00000.parquet") lit = Literal( scalar=Scalar( @@ -254,15 +256,13 @@ def convert_to_blob( value: typing.Union[Directory, FileParam], ) -> Literal: if isinstance(value, Directory): - uri = self.get_uri_for_dir(value) + uri = self.get_uri_for_dir(ctx, value) else: uri = value.filepath if self._remote and value.local: fp = pathlib.Path(value.filepath) - md5, _ = script_mode.hash_file(value.filepath) - df_remote_location = self._create_upload_fn(filename=fp.name, content_md5=md5) - self._flyte_ctx.file_access.put_data(fp, df_remote_location.signed_url) - uri = df_remote_location.native_url + remote = ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] + _, uri = remote.upload_file(fp) lit = Literal( scalar=Scalar( @@ -308,14 +308,38 @@ def convert_to_literal( if self._literal_type.blob: return self.convert_to_blob(ctx, param, value) - if self._literal_type.collection_type or self._literal_type.map_value_type: - # TODO Does not support nested flytefile, flyteschema types - v = json.loads(value) if isinstance(value, str) else value - if self._literal_type.collection_type and not isinstance(v, list): - raise click.BadParameter(f"Expected json list '[...]', parsed value is {type(v)}") - if self._literal_type.map_value_type and not isinstance(v, dict): - raise click.BadParameter("Expected json map '{}', parsed value is 
{%s}" % type(v)) - return TypeEngine.to_literal(self._flyte_ctx, v, self._python_type, self._literal_type) + if self._literal_type.collection_type: + python_value = json.loads(value) if isinstance(value, str) else value + if not isinstance(python_value, list): + raise click.BadParameter(f"Expected json list '[...]', parsed value is {type(python_value)}") + converter = FlyteLiteralConverter( + ctx, + self._flyte_ctx, + self._literal_type.collection_type, + type(python_value[0]), + self._create_upload_fn, + ) + lt = Literal(collection=LiteralCollection([])) + for v in python_value: + click_val = converter._click_type.convert(v, param, ctx) + lt.collection.literals.append(converter.convert_to_literal(ctx, param, click_val)) + return lt + if self._literal_type.map_value_type: + python_value = json.loads(value) if isinstance(value, str) else value + if not isinstance(python_value, dict): + raise click.BadParameter("Expected json map '{}', parsed value is {%s}" % type(python_value)) + converter = FlyteLiteralConverter( + ctx, + self._flyte_ctx, + self._literal_type.map_value_type, + type(python_value[next(iter(python_value))]), + self._create_upload_fn, + ) + lt = Literal(map=LiteralMap({})) + for k, v in python_value.items(): + click_val = converter._click_type.convert(v, param, ctx) + lt.map.literals[k] = converter.convert_to_literal(ctx, param, click_val) + return lt if self._literal_type.union_type: return self.convert_to_union(ctx, param, value) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index ea36689874..068771afa6 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -315,6 +315,7 @@ def put_data(self, local_path: Union[str, os.PathLike], remote_path: str, is_mul :param is_multipart: """ try: + local_path = str(local_path) with PerformanceTimer(f"Writing ({local_path} -> {remote_path})"): self.put(cast(str, local_path), remote_path, recursive=is_multipart) except Exception as ex: diff 
--git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 5ed1561da1..8716504dc1 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -714,9 +714,9 @@ def fast_package(self, root: os.PathLike, deref_symlinks: bool = True, output: s md5_bytes, _ = hash_file(pathlib.Path(zip_file)) # Upload zip file to Admin using FlyteRemote. - return self._upload_file(pathlib.Path(zip_file)) + return self.upload_file(pathlib.Path(zip_file)) - def _upload_file( + def upload_file( self, to_upload: pathlib.Path, project: typing.Optional[str] = None, domain: typing.Optional[str] = None ) -> typing.Tuple[bytes, str]: """ @@ -824,7 +824,7 @@ def register_script( with tempfile.TemporaryDirectory() as tmp_dir: archive_fname = pathlib.Path(os.path.join(tmp_dir, "script_mode.tar.gz")) compress_scripts(source_path, str(archive_fname), module_name) - md5_bytes, upload_native_url = self._upload_file( + md5_bytes, upload_native_url = self.upload_file( archive_fname, project or self.default_project, domain or self.default_domain ) diff --git a/tests/flytekit/unit/cli/pyflyte/test_register.py b/tests/flytekit/unit/cli/pyflyte/test_register.py index a6c0bb91d8..0a371b76d1 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_register.py +++ b/tests/flytekit/unit/cli/pyflyte/test_register.py @@ -92,7 +92,7 @@ def test_non_fast_register(mock_client, mock_remote): def test_non_fast_register_require_version(mock_client, mock_remote): mock_remote._client = mock_client mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash" - mock_remote.return_value._upload_file.return_value = "dummy_md5_bytes", "dummy_native_url" + mock_remote.return_value.upload_file.return_value = "dummy_md5_bytes", "dummy_native_url" runner = CliRunner() context_manager.FlyteEntities.entities.clear() with runner.isolated_filesystem(): diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 74302f0127..de09e21829 100644 --- 
a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -1,4 +1,5 @@ import functools +import json import os import pathlib import typing @@ -65,6 +66,7 @@ def test_imperative_wf(): def test_pyflyte_run_cli(): runner = CliRunner() + parquet_file = os.path.join(DIR_NAME, "testdata/df.parquet") result = runner.invoke( pyflyte.main, [ @@ -84,7 +86,7 @@ def test_pyflyte_run_cli(): "--f", '{"x":1.0, "y":2.0}', "--g", - os.path.join(DIR_NAME, "testdata/df.parquet"), + parquet_file, "--i", "2020-05-01", "--j", @@ -98,6 +100,10 @@ def test_pyflyte_run_cli(): "--image", os.path.join(DIR_NAME, "testdata"), "--h", + "--n", + json.dumps([{"x": parquet_file}]), + "--o", + json.dumps({"x": [parquet_file]}), ], catch_exceptions=False, ) diff --git a/tests/flytekit/unit/cli/pyflyte/workflow.py b/tests/flytekit/unit/cli/pyflyte/workflow.py index 85438eb00d..01621a6a01 100644 --- a/tests/flytekit/unit/cli/pyflyte/workflow.py +++ b/tests/flytekit/unit/cli/pyflyte/workflow.py @@ -56,8 +56,10 @@ def print_all( k: Color, l: dict, m: dict, + n: typing.List[typing.Dict[str, FlyteFile]], + o: typing.Dict[str, typing.List[FlyteFile]], ): - print(f"{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}, {j}, {k}, {l}, {m}") + print(f"{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}, {j}, {k}, {l}, {m}, {n}, {o}") @task @@ -84,6 +86,8 @@ def my_wf( j: datetime.timedelta, k: Color, l: dict, + n: typing.List[typing.Dict[str, FlyteFile]], + o: typing.Dict[str, typing.List[FlyteFile]], remote: pd.DataFrame, image: StructuredDataset, m: dict = {"hello": "world"}, @@ -91,5 +95,5 @@ def my_wf( x = get_subset_df(df=remote) # noqa: shown for demonstration; users should use the same types between tasks show_sd(in_sd=x) show_sd(in_sd=image) - print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k, l=l, m=m) + print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k, l=l, m=m, n=n, o=o) return x diff --git a/tests/flytekit/unit/remote/test_remote.py 
b/tests/flytekit/unit/remote/test_remote.py index 5e20eaeee3..5bfd7e4bf6 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -175,7 +175,7 @@ def test_more_stuff(mock_client): # Can't upload a folder with pytest.raises(ValueError): with tempfile.TemporaryDirectory() as tmp_dir: - r._upload_file(pathlib.Path(tmp_dir)) + r.upload_file(pathlib.Path(tmp_dir)) serialization_settings = flytekit.configuration.SerializationSettings( project="project",