Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use shutil instead of setuptools/distutils to copy dirs #1349

Merged
merged 10 commits into from
Nov 28, 2022
27 changes: 24 additions & 3 deletions flytekit/core/data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@
import os
import pathlib
import re
import shutil
import sys
import tempfile
import typing
from abc import abstractmethod
from distutils import dir_util
from shutil import copyfile
from typing import Dict, Union
from uuid import UUID
Expand All @@ -39,6 +40,9 @@
from flytekit.interfaces.random import random
from flytekit.loggers import logger

CURRENT_PYTHON = sys.version_info[:2]
THREE_SEVEN = (3, 7)


class UnsupportedPersistenceOp(Exception):
"""
Expand Down Expand Up @@ -221,17 +225,34 @@ def listdir(self, path: str, recursive: bool = False) -> typing.Generator[str, N
def exists(self, path: str):
return os.path.exists(self.strip_file_header(path))

def copy_tree(self, from_path: str, to_path: str):
# TODO: Remove this code after support for 3.7 is dropped and inline this function back
# 3.7 doesn't have dirs_exist_ok
if CURRENT_PYTHON == THREE_SEVEN:
tp = pathlib.Path(self.strip_file_header(to_path))
if tp.exists():
if not tp.is_dir():
raise ValueError("not a dir")
files = os.listdir(tp)
if len(files) != 0:
logger.debug(f"Deleting existing target dir {tp} with files {files}")
shutil.rmtree(tp)
shutil.copytree(self.strip_file_header(from_path), self.strip_file_header(to_path))
else:
# copytree will overwrite existing files in the to_path
shutil.copytree(self.strip_file_header(from_path), self.strip_file_header(to_path), dirs_exist_ok=True)

def get(self, from_path: str, to_path: str, recursive: bool = False):
if from_path != to_path:
if recursive:
dir_util.copy_tree(self.strip_file_header(from_path), self.strip_file_header(to_path))
self.copy_tree(from_path, to_path)
else:
copyfile(self.strip_file_header(from_path), self.strip_file_header(to_path))

def put(self, from_path: str, to_path: str, recursive: bool = False):
if from_path != to_path:
if recursive:
dir_util.copy_tree(self.strip_file_header(from_path), self.strip_file_header(to_path))
self.copy_tree(from_path, to_path)
else:
# Emulate s3's flat storage by automatically creating directory path
self._make_local_path(os.path.dirname(self.strip_file_header(to_path)))
Expand Down
7 changes: 6 additions & 1 deletion flytekit/types/schema/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.loggers import logger
from flytekit.models.literals import Literal, Scalar, Schema
from flytekit.models.types import LiteralType, SchemaType

Expand Down Expand Up @@ -361,7 +362,11 @@ def to_literal(
remote_path = python_val.remote_path
if remote_path is None or remote_path == "":
remote_path = ctx.file_access.get_random_remote_path()
ctx.file_access.put_data(python_val.local_path, remote_path, is_multipart=True)
if python_val.supported_mode == SchemaOpenMode.READ and not python_val._downloaded:
# This means the local path is empty. Don't try to overwrite the remote data
logger.debug(f"Skipping upload for {python_val} because it was never downloaded.")
else:
ctx.file_access.put_data(python_val.local_path, remote_path, is_multipart=True)
return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type(python_type))))

schema = python_type(
Expand Down
80 changes: 40 additions & 40 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,46 +50,46 @@ def test_imperative_wf():
assert result.exit_code == 0


# def test_pyflyte_run_cli():
# runner = CliRunner()
# result = runner.invoke(
# pyflyte.main,
# [
# "run",
# WORKFLOW_FILE,
# "my_wf",
# "--a",
# "1",
# "--b",
# "Hello",
# "--c",
# "1.1",
# "--d",
# '{"i":1,"a":["h","e"]}',
# "--e",
# "[1,2,3]",
# "--f",
# '{"x":1.0, "y":2.0}',
# "--g",
# os.path.join(DIR_NAME, "testdata/df.parquet"),
# "--i",
# "2020-05-01",
# "--j",
# "20H",
# "--k",
# "RED",
# "--l",
# '{"hello": "world"}',
# "--remote",
# os.path.join(DIR_NAME, "testdata"),
# "--image",
# os.path.join(DIR_NAME, "testdata"),
# "--h",
# ],
# catch_exceptions=False,
# )
# print(result.stdout)
# assert result.exit_code == 0
def test_pyflyte_run_cli():
runner = CliRunner()
result = runner.invoke(
pyflyte.main,
[
"run",
WORKFLOW_FILE,
"my_wf",
"--a",
"1",
"--b",
"Hello",
"--c",
"1.1",
"--d",
'{"i":1,"a":["h","e"]}',
"--e",
"[1,2,3]",
"--f",
'{"x":1.0, "y":2.0}',
"--g",
os.path.join(DIR_NAME, "testdata/df.parquet"),
"--i",
"2020-05-01",
"--j",
"20H",
"--k",
"RED",
"--l",
'{"hello": "world"}',
"--remote",
os.path.join(DIR_NAME, "testdata"),
"--image",
os.path.join(DIR_NAME, "testdata"),
"--h",
],
catch_exceptions=False,
)
print(result.stdout)
assert result.exit_code == 0


@pytest.mark.parametrize(
Expand Down
1 change: 1 addition & 0 deletions tests/flytekit/unit/core/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ def schema_to_sd_wf() -> (pd.DataFrame, pd.DataFrame):


def test_structured_dataset_wf():
# res = sd_to_schema_wf()
assert_frame_equal(sd_wf(), subset_df)
assert_frame_equal(sd_to_schema_wf(), superset_df)
assert_frame_equal(schema_to_sd_wf()[0], subset_df)
Expand Down