Skip to content

Commit

Permalink
Gz encoding (#2306)
Browse files Browse the repository at this point in the history
* wip, make a sandbox test

Signed-off-by: Yee Hing Tong <[email protected]>

* gzip encoding

Signed-off-by: Yee Hing Tong <[email protected]>

* revert

Signed-off-by: Yee Hing Tong <[email protected]>

* fix test

Signed-off-by: Yee Hing Tong <[email protected]>

* lint

Signed-off-by: Yee Hing Tong <[email protected]>

* test

Signed-off-by: Yee Hing Tong <[email protected]>

---------

Signed-off-by: Yee Hing Tong <[email protected]>
Signed-off-by: Jan Fiedler <[email protected]>
  • Loading branch information
wild-endeavor authored and fiedlerNr9 committed Jul 25, 2024
1 parent a4672e6 commit 8dfa2e5
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 2 deletions.
11 changes: 9 additions & 2 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,15 +446,22 @@ def to_literal(

# If we're uploading something, that means that the uri should always point to the upload destination.
if should_upload:
headers = self.get_additional_headers(source_path)
if remote_path is not None:
remote_path = ctx.file_access.put_data(source_path, remote_path, is_multipart=False)
remote_path = ctx.file_access.put_data(source_path, remote_path, is_multipart=False, **headers)
else:
remote_path = ctx.file_access.put_raw_data(source_path)
remote_path = ctx.file_access.put_raw_data(source_path, **headers)
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))
# If not uploading, then we can only take the original source path as the uri.
else:
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=source_path)))

@staticmethod
def get_additional_headers(source_path: str | os.PathLike) -> typing.Dict[str, str]:
if str(source_path).endswith(".gz"):
return {"ContentEncoding": "gzip"}
return {}

def to_python_value(
self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Union[typing.Type[FlyteFile], os.PathLike]
) -> FlyteFile:
Expand Down
28 changes: 28 additions & 0 deletions tests/flytekit/unit/core/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
from flytekit.configuration import Config, DataConfig, S3Config
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.data_persistence import FileAccessProvider, get_fsspec_storage_options, s3_setup_args
from flytekit.core.type_engine import TypeEngine
from flytekit.types.directory.types import FlyteDirectory
from flytekit.types.file import FlyteFile

local = fsspec.filesystem("file")
root = os.path.abspath(os.sep)
Expand Down Expand Up @@ -418,3 +420,29 @@ def test_walk_local_copy_to_s3(source_folder):
new_crawl = fd.crawl()
new_suffixes = [y for x, y in new_crawl]
assert len(new_suffixes) == 2 # should have written two files


@pytest.mark.sandbox_test
def test_s3_metadata():
    """
    Upload a local ``.gz`` file through the FlyteFile transformer, once without and once with an
    explicit remote path, and check that both objects appear under the raw output prefix.
    """
    dc = Config.for_sandbox().data_config
    random_folder = UUID(int=random.getrandbits(64)).hex
    raw_output = f"s3://my-s3-bucket/testing/metadata_test/{random_folder}"
    provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc)

    # mkstemp returns an already-open OS-level descriptor; wrap it so that it is closed
    # once the content is written instead of leaking for the rest of the test session.
    fd, local_zip = tempfile.mkstemp(suffix=".gz")
    try:
        with os.fdopen(fd, "w") as f:
            f.write("hello world")

        # Test writing file
        ff = FlyteFile(path=local_zip)
        ff2 = FlyteFile(path=local_zip, remote_path=f"{raw_output}/test.gz")
        ctx = FlyteContextManager.current_context()
        with FlyteContextManager.with_context(ctx.with_file_access(provider)) as ctx:
            lt = TypeEngine.to_literal_type(FlyteFile)
            TypeEngine.to_literal(ctx, ff, FlyteFile, lt)
            TypeEngine.to_literal(ctx, ff2, FlyteFile, lt)

        fd_remote = FlyteDirectory(path=raw_output)
        files = [os.path.join(x, y) for x, y in fd_remote.crawl()]
        assert len(files) == 2
    finally:
        # mkstemp never deletes the file it creates; remove it even when the test fails.
        os.remove(local_zip)
5 changes: 5 additions & 0 deletions tests/flytekit/unit/core/test_flyte_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,3 +644,8 @@ def test_join():
fs = ctx.file_access.get_filesystem("s3")
f = ctx.file_access.join("s3://a", "b", "c", fs=fs)
assert f == fs.sep.join(["s3://a", "b", "c"])


def test_headers():
    """Check that only paths ending in ``.gz`` produce the gzip content-encoding header."""
    assert FlyteFilePathTransformer.get_additional_headers("xyz") == {}
    # A ".gz" appearing anywhere but the end of the path must not trigger the header.
    assert FlyteFilePathTransformer.get_additional_headers("data.gz.txt") == {}
    # Pin the exact mapping rather than its length, so a wrong key or value is caught.
    assert FlyteFilePathTransformer.get_additional_headers(".gz") == {"ContentEncoding": "gzip"}
    assert FlyteFilePathTransformer.get_additional_headers("archive.tar.gz") == {"ContentEncoding": "gzip"}

0 comments on commit 8dfa2e5

Please sign in to comment.