Skip to content

Commit

Permalink
Use shutil instead of setuptools/distutils to copy dirs (#1349)
Browse files Browse the repository at this point in the history
# TL;DR
The bump up recently to setuptools version from 65.5.1 to 65.6.0 caused one of our click unit tests to fail - some interaction between the directory copy utility provided by setuptools.distutils and the python logging library.

Getting around this by using shutil instead.

Also realized that
* `FlyteSchema` transformers doesn't currently handle well you
  ```python
  @task
  def t3(df: FlyteSchema) -> FlyteSchema:
      return df
  ```
  Because it's never downloaded, there's nothing to upload back.  Added a special case to prevent this.
* [Nested dataclasses](https://github.com/flyteorg/flytekit/blob/430795d9ee4aa48957554a6eb1446fa80993681a/tests/flytekit/unit/core/test_type_engine.py#L1394) will get transformed multiple times (at least the same `FlyteSchema` instance will) by the type engine.  Something to fix for the future.

Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored Nov 28, 2022
1 parent 7cf5b68 commit a47e383
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 53 deletions.
29 changes: 25 additions & 4 deletions flytekit/core/data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,24 @@
import os
import pathlib
import re
import shutil
import sys
import tempfile
import typing
from abc import abstractmethod
from distutils import dir_util
from shutil import copyfile
from typing import Dict, Union
from uuid import UUID

from flytekit.configuration import DataConfig
from flytekit.core.utils import PerformanceTimer
from flytekit.exceptions.user import FlyteAssertion
from flytekit.exceptions.user import FlyteAssertion, FlyteValueException
from flytekit.interfaces.random import random
from flytekit.loggers import logger

CURRENT_PYTHON = sys.version_info[:2]
THREE_SEVEN = (3, 7)


class UnsupportedPersistenceOp(Exception):
"""
Expand Down Expand Up @@ -221,17 +225,34 @@ def listdir(self, path: str, recursive: bool = False) -> typing.Generator[str, N
def exists(self, path: str):
return os.path.exists(self.strip_file_header(path))

def copy_tree(self, from_path: str, to_path: str):
# TODO: Remove this code after support for 3.7 is dropped and inline this function back
# 3.7 doesn't have dirs_exist_ok
if CURRENT_PYTHON == THREE_SEVEN:
tp = pathlib.Path(self.strip_file_header(to_path))
if tp.exists():
if not tp.is_dir():
raise FlyteValueException(tp, f"Target {tp} exists but is not a dir")
files = os.listdir(tp)
if len(files) != 0:
logger.debug(f"Deleting existing target dir {tp} with files {files}")
shutil.rmtree(tp)
shutil.copytree(self.strip_file_header(from_path), self.strip_file_header(to_path))
else:
# copytree will overwrite existing files in the to_path
shutil.copytree(self.strip_file_header(from_path), self.strip_file_header(to_path), dirs_exist_ok=True)

def get(self, from_path: str, to_path: str, recursive: bool = False):
if from_path != to_path:
if recursive:
dir_util.copy_tree(self.strip_file_header(from_path), self.strip_file_header(to_path))
self.copy_tree(from_path, to_path)
else:
copyfile(self.strip_file_header(from_path), self.strip_file_header(to_path))

def put(self, from_path: str, to_path: str, recursive: bool = False):
if from_path != to_path:
if recursive:
dir_util.copy_tree(self.strip_file_header(from_path), self.strip_file_header(to_path))
self.copy_tree(from_path, to_path)
else:
# Emulate s3's flat storage by automatically creating directory path
self._make_local_path(os.path.dirname(self.strip_file_header(to_path)))
Expand Down
7 changes: 6 additions & 1 deletion flytekit/types/schema/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.loggers import logger
from flytekit.models.literals import Literal, Scalar, Schema
from flytekit.models.types import LiteralType, SchemaType

Expand Down Expand Up @@ -361,7 +362,11 @@ def to_literal(
remote_path = python_val.remote_path
if remote_path is None or remote_path == "":
remote_path = ctx.file_access.get_random_remote_path()
ctx.file_access.put_data(python_val.local_path, remote_path, is_multipart=True)
if python_val.supported_mode == SchemaOpenMode.READ and not python_val._downloaded:
# This means the local path is empty. Don't try to overwrite the remote data
logger.debug(f"Skipping upload for {python_val} because it was never downloaded.")
else:
ctx.file_access.put_data(python_val.local_path, remote_path, is_multipart=True)
return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type(python_type))))

schema = python_type(
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-dbt/flytekitplugins/dbt/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,4 +255,4 @@ class DBTFreshnessOutput(BaseDBTOutput):
Raw value of DBT's ``sources.json``.
"""

raw_sources: str
raw_sources: str
7 changes: 1 addition & 6 deletions plugins/flytekit-dbt/flytekitplugins/dbt/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,9 +358,4 @@ def my_workflow() -> DBTFreshnessOutput:
with open(sources_path) as file:
sources = file.read()


return DBTFreshnessOutput(
command=full_command,
exit_code=exit_code,
raw_sources=sources
)
return DBTFreshnessOutput(command=full_command, exit_code=exit_code, raw_sources=sources)
2 changes: 1 addition & 1 deletion plugins/flytekit-dbt/tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,4 +222,4 @@ def test_task_output(self):
with open(f"{DBT_PROJECT_DIR}/target/sources.json", "r") as fp:
exp_sources = fp.read()

assert output.raw_sources == exp_sources
assert output.raw_sources == exp_sources
80 changes: 40 additions & 40 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,46 +50,46 @@ def test_imperative_wf():
assert result.exit_code == 0


# def test_pyflyte_run_cli():
# runner = CliRunner()
# result = runner.invoke(
# pyflyte.main,
# [
# "run",
# WORKFLOW_FILE,
# "my_wf",
# "--a",
# "1",
# "--b",
# "Hello",
# "--c",
# "1.1",
# "--d",
# '{"i":1,"a":["h","e"]}',
# "--e",
# "[1,2,3]",
# "--f",
# '{"x":1.0, "y":2.0}',
# "--g",
# os.path.join(DIR_NAME, "testdata/df.parquet"),
# "--i",
# "2020-05-01",
# "--j",
# "20H",
# "--k",
# "RED",
# "--l",
# '{"hello": "world"}',
# "--remote",
# os.path.join(DIR_NAME, "testdata"),
# "--image",
# os.path.join(DIR_NAME, "testdata"),
# "--h",
# ],
# catch_exceptions=False,
# )
# print(result.stdout)
# assert result.exit_code == 0
def test_pyflyte_run_cli():
runner = CliRunner()
result = runner.invoke(
pyflyte.main,
[
"run",
WORKFLOW_FILE,
"my_wf",
"--a",
"1",
"--b",
"Hello",
"--c",
"1.1",
"--d",
'{"i":1,"a":["h","e"]}',
"--e",
"[1,2,3]",
"--f",
'{"x":1.0, "y":2.0}',
"--g",
os.path.join(DIR_NAME, "testdata/df.parquet"),
"--i",
"2020-05-01",
"--j",
"20H",
"--k",
"RED",
"--l",
'{"hello": "world"}',
"--remote",
os.path.join(DIR_NAME, "testdata"),
"--image",
os.path.join(DIR_NAME, "testdata"),
"--h",
],
catch_exceptions=False,
)
print(result.stdout)
assert result.exit_code == 0


@pytest.mark.parametrize(
Expand Down

0 comments on commit a47e383

Please sign in to comment.