From 8dfa2e5c59898e289bec397c884b0786c0643042 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Fri, 29 Mar 2024 13:00:08 -0700 Subject: [PATCH] Gz encoding (#2306) * wip, make a sandbox test Signed-off-by: Yee Hing Tong * gzip encoding Signed-off-by: Yee Hing Tong * revert Signed-off-by: Yee Hing Tong * fix test Signed-off-by: Yee Hing Tong * lint Signed-off-by: Yee Hing Tong * test Signed-off-by: Yee Hing Tong --------- Signed-off-by: Yee Hing Tong Signed-off-by: Jan Fiedler --- flytekit/types/file/file.py | 11 ++++++-- tests/flytekit/unit/core/test_data.py | 28 +++++++++++++++++++++ tests/flytekit/unit/core/test_flyte_file.py | 5 ++++ 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index de4e49cdaf..73a835a20f 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -446,15 +446,22 @@ def to_literal( # If we're uploading something, that means that the uri should always point to the upload destination. if should_upload: + headers = self.get_additional_headers(source_path) if remote_path is not None: - remote_path = ctx.file_access.put_data(source_path, remote_path, is_multipart=False) + remote_path = ctx.file_access.put_data(source_path, remote_path, is_multipart=False, **headers) else: - remote_path = ctx.file_access.put_raw_data(source_path) + remote_path = ctx.file_access.put_raw_data(source_path, **headers) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) # If not uploading, then we can only take the original source path as the uri. else: return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=source_path))) + @staticmethod + def get_additional_headers(source_path: str | os.PathLike) -> typing.Dict[str, str]: + if str(source_path).endswith(".gz"): + return {"ContentEncoding": "gzip"} + return {} + def to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Union[typing.Type[FlyteFile], os.PathLike] ) -> FlyteFile: diff --git a/tests/flytekit/unit/core/test_data.py b/tests/flytekit/unit/core/test_data.py index 99963621a7..aa308d7929 100644 --- a/tests/flytekit/unit/core/test_data.py +++ b/tests/flytekit/unit/core/test_data.py @@ -12,7 +12,9 @@ from flytekit.configuration import Config, DataConfig, S3Config from flytekit.core.context_manager import FlyteContextManager from flytekit.core.data_persistence import FileAccessProvider, get_fsspec_storage_options, s3_setup_args +from flytekit.core.type_engine import TypeEngine from flytekit.types.directory.types import FlyteDirectory +from flytekit.types.file import FlyteFile local = fsspec.filesystem("file") root = os.path.abspath(os.sep) @@ -418,3 +420,29 @@ def test_walk_local_copy_to_s3(source_folder): new_crawl = fd.crawl() new_suffixes = [y for x, y in new_crawl] assert len(new_suffixes) == 2 # should have written two files + + +@pytest.mark.sandbox_test +def test_s3_metadata(): + dc = Config.for_sandbox().data_config + random_folder = UUID(int=random.getrandbits(64)).hex + raw_output = f"s3://my-s3-bucket/testing/metadata_test/{random_folder}" + provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc) + _, local_zip = tempfile.mkstemp(suffix=".gz") + with open(local_zip, "w") as f: + f.write("hello world") + + # Test writing file + ff = FlyteFile(path=local_zip) + ff2 = FlyteFile(path=local_zip, remote_path=f"{raw_output}/test.gz") + ctx = FlyteContextManager.current_context() + with FlyteContextManager.with_context(ctx.with_file_access(provider)) as ctx: + lt = TypeEngine.to_literal_type(FlyteFile) + TypeEngine.to_literal(ctx, ff, FlyteFile, lt) + TypeEngine.to_literal(ctx, ff2, FlyteFile, lt) + + fd = FlyteDirectory(path=raw_output) + res = fd.crawl() + res = [(x, y) for x, y in res] + files = [os.path.join(x, y) for x, y in res] + assert len(files) == 2 diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py index a12c414f35..420279128b 100644 --- a/tests/flytekit/unit/core/test_flyte_file.py +++ b/tests/flytekit/unit/core/test_flyte_file.py @@ -644,3 +644,8 @@ def test_join(): fs = ctx.file_access.get_filesystem("s3") f = ctx.file_access.join("s3://a", "b", "c", fs=fs) assert f == fs.sep.join(["s3://a", "b", "c"]) + + +def test_headers(): + assert FlyteFilePathTransformer.get_additional_headers("xyz") == {} + assert len(FlyteFilePathTransformer.get_additional_headers(".gz")) == 1