Fix Flyte Types Upload Issues in Default Input #2907

Merged: 24 commits, Nov 11, 2024
Changes from all commits
1 change: 1 addition & 0 deletions Dockerfile.dev
@@ -44,6 +44,7 @@ RUN SETUPTOOLS_SCM_PRETEND_VERSION_FOR_FLYTEKIT=$PSEUDO_VERSION \
pandas \
pillow \
plotly \
pyarrow \
pygments \
scikit-learn \
ydata-profiling \
10 changes: 7 additions & 3 deletions flytekit/clis/sdk_in_container/run.py
@@ -458,9 +458,13 @@ def to_click_option(
description_extra = ""
if literal_var.type.simple == SimpleType.STRUCT:
if default_val and not isinstance(default_val, ArtifactQuery):
if type(default_val) == dict or type(default_val) == list:
default_val = json.dumps(default_val)
else:
"""
1. Convert default_val to a JSON string for click.Option, which uses json.loads to parse it.
2. Set a new context with remote access to allow Flyte types (e.g., files) to be uploaded.
3. Use FlyteContextManager for Flyte Types with custom serialization.
If no custom logic exists, fall back to json.dumps.
"""
with FlyteContextManager.with_context(flyte_ctx.new_builder()):
encoder = JSONEncoder(python_type)
default_val = encoder.encode(default_val)
if literal_var.type.metadata:
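
For context, here is a minimal sketch (not part of the PR) of what the new branch does with a dataclass default that contains a Flyte type. The MyInput dataclass and the local path are hypothetical; the encoder is assumed to be mashumaro's JSONEncoder (run.py uses an encoder with this interface), and the upload only happens when the surrounding context has remote file access configured, as it does under pyflyte run --remote.

from dataclasses import dataclass

from mashumaro.codecs.json import JSONEncoder  # assumed import; run.py's encoder has the same interface

from flytekit.core.context_manager import FlyteContextManager
from flytekit.types.file import FlyteFile


@dataclass
class MyInput:  # hypothetical user dataclass with a Flyte type field
    ff: FlyteFile


flyte_ctx = FlyteContextManager.current_context()
default_val = MyInput(ff=FlyteFile("data/local.txt"))  # hypothetical local path

# A fresh context built from the current one lets FlyteFile's custom serializer
# upload the local file and emit its remote URI instead of the local path.
with FlyteContextManager.with_context(flyte_ctx.new_builder()):
    encoder = JSONEncoder(MyInput)
    default_json = encoder.encode(default_val)  # JSON string handed to click.Option

The resulting JSON string is what click later round-trips with json.loads, per point 1 of the docstring above.
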
3 changes: 2 additions & 1 deletion flytekit/types/directory/types.py
@@ -537,7 +537,8 @@ async def async_to_literal(
remote_directory = ctx.file_access.get_random_remote_directory()
if not pathlib.Path(source_path).is_dir():
raise FlyteAssertion("Expected a directory. {} is not a directory".format(source_path))
await ctx.file_access.async_put_data(
# async_put_data resolves the `flyte://` placeholder and returns the actual remote path (e.g. `s3://` or `gs://`)
remote_directory = await ctx.file_access.async_put_data(
source_path, remote_directory, is_multipart=True, batch_size=batch_size
)
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_directory)))
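
The fix above hinges on the return value: the randomly generated remote directory is only a flyte:// placeholder, and async_put_data returns the resolved storage path that should go into the literal's URI. A minimal sketch under that assumption (the upload_dir helper name is made up, and the call is only meaningful in a context configured for remote storage):

import asyncio

from flytekit.core.context_manager import FlyteContextManager


async def upload_dir(local_dir: str) -> str:
    ctx = FlyteContextManager.current_context()
    placeholder = ctx.file_access.get_random_remote_directory()  # may be a flyte:// path
    # async_put_data returns the resolved remote path (e.g. s3:// or gs://);
    # that resolved path, not the placeholder, belongs in Blob(uri=...).
    return await ctx.file_access.async_put_data(
        local_dir, placeholder, is_multipart=True
    )


# asyncio.run(upload_dir("/tmp/my_dir"))
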
1 change: 1 addition & 0 deletions flytekit/types/file/file.py
@@ -551,6 +551,7 @@ async def async_to_literal(
)
else:
remote_path = await ctx.file_access.async_put_raw_data(source_path, **headers)
# If the source path is a local file, the remote path will be a remote storage path.
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=unquote(str(remote_path)))))
# If not uploading, then we can only take the original source path as the uri.
else:
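
Reading the two comments together, the transformer ends up with one of two URIs: the path returned by the upload, or the untouched source path. A condensed sketch of that branching (the standalone file_to_literal helper is made up; ctx, meta, and headers stand in for values the surrounding method already has):

from urllib.parse import unquote

from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar


async def file_to_literal(ctx, source_path: str, meta: BlobMetadata,
                          should_upload: bool, **headers) -> Literal:
    if should_upload:
        # Local source file: put_raw_data uploads it and returns a remote storage path.
        remote_path = await ctx.file_access.async_put_raw_data(source_path, **headers)
        uri = unquote(str(remote_path))
    else:
        # Not uploading: the original source path is the only usable URI.
        uri = source_path
    return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=uri)))
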
46 changes: 39 additions & 7 deletions tests/flytekit/integration/remote/test_remote.py
@@ -9,14 +9,14 @@
import tempfile
import time
import typing

import re
import joblib
from urllib.parse import urlparse
import uuid
import pytest
from mock import mock, patch

from flytekit import LaunchPlan, kwtypes
from flytekit import LaunchPlan, kwtypes, WorkflowExecutionPhase
from flytekit.configuration import Config, ImageConfig, SerializationSettings
from flytekit.core.launch_plan import reference_launch_plan
from flytekit.core.task import reference_task
@@ -62,7 +62,8 @@ def register():
assert out.returncode == 0


def run(file_name, wf_name, *args):
def run(file_name, wf_name, *args) -> str:
# Copy the environment and set the environment variable
out = subprocess.run(
[
"pyflyte",
@@ -82,9 +82,20 @@ def run(file_name, wf_name, *args):
MODULE_PATH / file_name,
wf_name,
*args,
]
],
capture_output=True, # Capture the output streams
text=True, # Return outputs as strings (not bytes)
)
assert out.returncode == 0
assert out.returncode == 0, (f"Command failed with return code {out.returncode}.\n"
f"Standard Output: {out.stdout}\n"
f"Standard Error: {out.stderr}\n")

match = re.search(r'executions/([a-zA-Z0-9]+)', out.stdout)
if match:
execution_id = match.group(1)
return execution_id

return "Unknown"


def test_remote_run():
@@ -93,7 +105,28 @@ def test_remote_run():

# run twice to make sure it will register a new version of the workflow.
run("default_lp.py", "my_wf")
run("default_lp.py", "my_wf")


def test_generic_idl_flytetypes():
os.environ["FLYTE_USE_OLD_DC_FORMAT"] = "true"
# default inputs for flyte types in dataclass
execution_id = run("generic_idl_flytetypes.py", "wf")
remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
execution = remote.fetch_execution(name=execution_id)
execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5))
print("Execution Error:", execution.error)
assert execution.closure.phase == WorkflowExecutionPhase.SUCCEEDED, f"Execution failed with phase: {execution.closure.phase}"
os.environ["FLYTE_USE_OLD_DC_FORMAT"] = "false"


def test_msgpack_idl_flytetypes():
# default inputs for flyte types in dataclass
execution_id = run("msgpack_idl_flytetypes.py", "wf")
remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
execution = remote.fetch_execution(name=execution_id)
execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5))
print("Execution Error:", execution.error)
assert execution.closure.phase == WorkflowExecutionPhase.SUCCEEDED, f"Execution failed with phase: {execution.closure.phase}"


def test_fetch_execute_launch_plan(register):
@@ -736,7 +769,6 @@ def test_execute_workflow_remote_fn_with_maptask():
)
assert out.outputs["o0"] == [4, 5, 6]


def test_register_wf_fast(register):
from workflows.basic.subworkflows import parent_wf

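
The two new tests above follow one pattern: submit the workflow through run(), then poll the execution with FlyteRemote until it finishes. A condensed sketch of that pattern (the run_and_wait helper is made up; run, CONFIG, PROJECT, and DOMAIN are the module-level helpers defined earlier in this test file):

import datetime

from flytekit import WorkflowExecutionPhase
from flytekit.configuration import Config
from flytekit.remote import FlyteRemote


def run_and_wait(script: str, wf_name: str) -> None:
    # run() shells out to `pyflyte run --remote` and parses the execution ID from stdout.
    execution_id = run(script, wf_name)
    remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
    execution = remote.fetch_execution(name=execution_id)
    execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5))
    assert execution.closure.phase == WorkflowExecutionPhase.SUCCEEDED, execution.error
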
Binary file not shown.
@@ -0,0 +1,48 @@
import typing
import os
from dataclasses import dataclass, fields, field
from typing import Dict, List
from flytekit.types.file import FlyteFile
from flytekit.types.structured import StructuredDataset
from flytekit.types.directory import FlyteDirectory
from flytekit import task, workflow, ImageSpec
import datetime
from enum import Enum
import pandas as pd

@dataclass
class DC:
ff: FlyteFile
sd: StructuredDataset
fd: FlyteDirectory


@task
def t1(dc: DC = DC(ff=FlyteFile(os.path.realpath(__file__)),
sd=StructuredDataset(
uri="tests/flytekit/integration/remote/workflows/basic/data/df.parquet",
file_format="parquet"),
fd=FlyteDirectory("tests/flytekit/integration/remote/workflows/basic/data/")
)):

with open(dc.ff, "r") as f:
print("File Content: ", f.read())

print("sd:", dc.sd.open(pd.DataFrame).all())

df_path = os.path.join(dc.fd.path, "df.parquet")
print("fd: ", os.path.isdir(df_path))

return dc

@workflow
def wf(dc: DC = DC(ff=FlyteFile(os.path.realpath(__file__)),
sd=StructuredDataset(
uri="tests/flytekit/integration/remote/workflows/basic/data/df.parquet",
file_format="parquet"),
fd=FlyteDirectory("tests/flytekit/integration/remote/workflows/basic/data/")
)):
t1(dc=dc)

if __name__ == "__main__":
wf()
@@ -0,0 +1,48 @@
import typing
import os
from dataclasses import dataclass, fields, field
from typing import Dict, List
from flytekit.types.file import FlyteFile
from flytekit.types.structured import StructuredDataset
from flytekit.types.directory import FlyteDirectory
from flytekit import task, workflow, ImageSpec
import datetime
from enum import Enum
import pandas as pd

@dataclass
class DC:
ff: FlyteFile
sd: StructuredDataset
fd: FlyteDirectory


@task
def t1(dc: DC = DC(ff=FlyteFile(os.path.realpath(__file__)),
sd=StructuredDataset(
uri="tests/flytekit/integration/remote/workflows/basic/data/df.parquet",
file_format="parquet"),
fd=FlyteDirectory("tests/flytekit/integration/remote/workflows/basic/data/")
)):

with open(dc.ff, "r") as f:
print("File Content: ", f.read())

print("sd:", dc.sd.open(pd.DataFrame).all())

df_path = os.path.join(dc.fd.path, "df.parquet")
print("fd: ", os.path.isdir(df_path))

return dc

@workflow
def wf(dc: DC = DC(ff=FlyteFile(os.path.realpath(__file__)),
sd=StructuredDataset(
uri="tests/flytekit/integration/remote/workflows/basic/data/df.parquet",
file_format="parquet"),
fd=FlyteDirectory("tests/flytekit/integration/remote/workflows/basic/data/")
)):
t1(dc=dc)

if __name__ == "__main__":
wf()