Skip to content

Commit

Permalink
Add support nested FlyteFile in pyflyte run (#1587)
Browse files Browse the repository at this point in the history
* Add support nested FlyteFile in pyflyte run

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

---------

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Apr 17, 2023
1 parent e865db5 commit 401c505
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 27 deletions.
62 changes: 43 additions & 19 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase
from flytekit.models import literals
from flytekit.models.interface import Variable
from flytekit.models.literals import Blob, BlobMetadata, Primitive, Union
from flytekit.models.literals import Blob, BlobMetadata, LiteralCollection, LiteralMap, Primitive, Union
from flytekit.models.types import LiteralType, SimpleType
from flytekit.remote.executions import FlyteWorkflowExecution
from flytekit.tools import module_loader, script_mode
Expand Down Expand Up @@ -215,24 +215,26 @@ def is_bool(self) -> bool:
return self._literal_type.simple == SimpleType.BOOLEAN
return False

def get_uri_for_dir(self, value: Directory, remote_filename: typing.Optional[str] = None):
def get_uri_for_dir(
self, ctx: typing.Optional[click.Context], value: Directory, remote_filename: typing.Optional[str] = None
):
uri = value.dir_path

if self._remote and value.local:
md5, _ = script_mode.hash_file(value.local_file)
if not remote_filename:
remote_filename = value.local_file.name
df_remote_location = self._create_upload_fn(filename=remote_filename, content_md5=md5)
self._flyte_ctx.file_access.put_data(value.local_file, df_remote_location.signed_url)
uri = df_remote_location.native_url[: -len(remote_filename)]
remote = ctx.obj[FLYTE_REMOTE_INSTANCE_KEY]
_, native_url = remote.upload_file(value.local_file)
uri = native_url[: -len(remote_filename)]

return uri

def convert_to_structured_dataset(
self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: Directory
) -> Literal:

uri = self.get_uri_for_dir(value, "00000.parquet")
uri = self.get_uri_for_dir(ctx, value, "00000.parquet")

lit = Literal(
scalar=Scalar(
Expand All @@ -254,15 +256,13 @@ def convert_to_blob(
value: typing.Union[Directory, FileParam],
) -> Literal:
if isinstance(value, Directory):
uri = self.get_uri_for_dir(value)
uri = self.get_uri_for_dir(ctx, value)
else:
uri = value.filepath
if self._remote and value.local:
fp = pathlib.Path(value.filepath)
md5, _ = script_mode.hash_file(value.filepath)
df_remote_location = self._create_upload_fn(filename=fp.name, content_md5=md5)
self._flyte_ctx.file_access.put_data(fp, df_remote_location.signed_url)
uri = df_remote_location.native_url
remote = ctx.obj[FLYTE_REMOTE_INSTANCE_KEY]
_, uri = remote.upload_file(fp)

lit = Literal(
scalar=Scalar(
Expand Down Expand Up @@ -308,14 +308,38 @@ def convert_to_literal(
if self._literal_type.blob:
return self.convert_to_blob(ctx, param, value)

if self._literal_type.collection_type or self._literal_type.map_value_type:
# TODO Does not support nested flytefile, flyteschema types
v = json.loads(value) if isinstance(value, str) else value
if self._literal_type.collection_type and not isinstance(v, list):
raise click.BadParameter(f"Expected json list '[...]', parsed value is {type(v)}")
if self._literal_type.map_value_type and not isinstance(v, dict):
raise click.BadParameter("Expected json map '{}', parsed value is {%s}" % type(v))
return TypeEngine.to_literal(self._flyte_ctx, v, self._python_type, self._literal_type)
if self._literal_type.collection_type:
python_value = json.loads(value) if isinstance(value, str) else value
if not isinstance(python_value, list):
raise click.BadParameter(f"Expected json list '[...]', parsed value is {type(python_value)}")
converter = FlyteLiteralConverter(
ctx,
self._flyte_ctx,
self._literal_type.collection_type,
type(python_value[0]),
self._create_upload_fn,
)
lt = Literal(collection=LiteralCollection([]))
for v in python_value:
click_val = converter._click_type.convert(v, param, ctx)
lt.collection.literals.append(converter.convert_to_literal(ctx, param, click_val))
return lt
if self._literal_type.map_value_type:
python_value = json.loads(value) if isinstance(value, str) else value
if not isinstance(python_value, dict):
raise click.BadParameter("Expected json map '{}', parsed value is {%s}" % type(python_value))
converter = FlyteLiteralConverter(
ctx,
self._flyte_ctx,
self._literal_type.map_value_type,
type(python_value[next(iter(python_value))]),
self._create_upload_fn,
)
lt = Literal(map=LiteralMap({}))
for k, v in python_value.items():
click_val = converter._click_type.convert(v, param, ctx)
lt.map.literals[k] = converter.convert_to_literal(ctx, param, click_val)
return lt

if self._literal_type.union_type:
return self.convert_to_union(ctx, param, value)
Expand Down
1 change: 1 addition & 0 deletions flytekit/core/data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ def put_data(self, local_path: Union[str, os.PathLike], remote_path: str, is_mul
:param is_multipart:
"""
try:
local_path = str(local_path)
with PerformanceTimer(f"Writing ({local_path} -> {remote_path})"):
self.put(cast(str, local_path), remote_path, recursive=is_multipart)
except Exception as ex:
Expand Down
6 changes: 3 additions & 3 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,9 +714,9 @@ def fast_package(self, root: os.PathLike, deref_symlinks: bool = True, output: s
md5_bytes, _ = hash_file(pathlib.Path(zip_file))

# Upload zip file to Admin using FlyteRemote.
return self._upload_file(pathlib.Path(zip_file))
return self.upload_file(pathlib.Path(zip_file))

def _upload_file(
def upload_file(
self, to_upload: pathlib.Path, project: typing.Optional[str] = None, domain: typing.Optional[str] = None
) -> typing.Tuple[bytes, str]:
"""
Expand Down Expand Up @@ -824,7 +824,7 @@ def register_script(
with tempfile.TemporaryDirectory() as tmp_dir:
archive_fname = pathlib.Path(os.path.join(tmp_dir, "script_mode.tar.gz"))
compress_scripts(source_path, str(archive_fname), module_name)
md5_bytes, upload_native_url = self._upload_file(
md5_bytes, upload_native_url = self.upload_file(
archive_fname, project or self.default_project, domain or self.default_domain
)

Expand Down
2 changes: 1 addition & 1 deletion tests/flytekit/unit/cli/pyflyte/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_non_fast_register(mock_client, mock_remote):
def test_non_fast_register_require_version(mock_client, mock_remote):
mock_remote._client = mock_client
mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash"
mock_remote.return_value._upload_file.return_value = "dummy_md5_bytes", "dummy_native_url"
mock_remote.return_value.upload_file.return_value = "dummy_md5_bytes", "dummy_native_url"
runner = CliRunner()
context_manager.FlyteEntities.entities.clear()
with runner.isolated_filesystem():
Expand Down
8 changes: 7 additions & 1 deletion tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import json
import os
import pathlib
import typing
Expand Down Expand Up @@ -65,6 +66,7 @@ def test_imperative_wf():

def test_pyflyte_run_cli():
runner = CliRunner()
parquet_file = os.path.join(DIR_NAME, "testdata/df.parquet")
result = runner.invoke(
pyflyte.main,
[
Expand All @@ -84,7 +86,7 @@ def test_pyflyte_run_cli():
"--f",
'{"x":1.0, "y":2.0}',
"--g",
os.path.join(DIR_NAME, "testdata/df.parquet"),
parquet_file,
"--i",
"2020-05-01",
"--j",
Expand All @@ -98,6 +100,10 @@ def test_pyflyte_run_cli():
"--image",
os.path.join(DIR_NAME, "testdata"),
"--h",
"--n",
json.dumps([{"x": parquet_file}]),
"--o",
json.dumps({"x": [parquet_file]}),
],
catch_exceptions=False,
)
Expand Down
8 changes: 6 additions & 2 deletions tests/flytekit/unit/cli/pyflyte/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,10 @@ def print_all(
k: Color,
l: dict,
m: dict,
n: typing.List[typing.Dict[str, FlyteFile]],
o: typing.Dict[str, typing.List[FlyteFile]],
):
print(f"{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}, {j}, {k}, {l}, {m}")
print(f"{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}, {j}, {k}, {l}, {m}, {n}, {o}")


@task
Expand All @@ -84,12 +86,14 @@ def my_wf(
j: datetime.timedelta,
k: Color,
l: dict,
n: typing.List[typing.Dict[str, FlyteFile]],
o: typing.Dict[str, typing.List[FlyteFile]],
remote: pd.DataFrame,
image: StructuredDataset,
m: dict = {"hello": "world"},
) -> Annotated[StructuredDataset, subset_cols]:
x = get_subset_df(df=remote) # noqa: shown for demonstration; users should use the same types between tasks
show_sd(in_sd=x)
show_sd(in_sd=image)
print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k, l=l, m=m)
print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k, l=l, m=m, n=n, o=o)
return x
2 changes: 1 addition & 1 deletion tests/flytekit/unit/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def test_more_stuff(mock_client):
# Can't upload a folder
with pytest.raises(ValueError):
with tempfile.TemporaryDirectory() as tmp_dir:
r._upload_file(pathlib.Path(tmp_dir))
r.upload_file(pathlib.Path(tmp_dir))

serialization_settings = flytekit.configuration.SerializationSettings(
project="project",
Expand Down

0 comments on commit 401c505

Please sign in to comment.