Skip to content

Commit

Permalink
Fix FlyteFS (#2208)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Mar 12, 2024
1 parent 05fc527 commit 7144ae9
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 69 deletions.
32 changes: 19 additions & 13 deletions flytekit/clients/friendly.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,7 @@ def get_upload_signed_url(
filename: typing.Optional[str] = None,
expires_in: typing.Optional[datetime.timedelta] = None,
filename_root: typing.Optional[str] = None,
add_content_md5_metadata: bool = True,
) -> _data_proxy_pb2.CreateUploadLocationResponse:
"""
Get a signed url to be used during fast registration
Expand All @@ -1000,22 +1001,27 @@ def get_upload_signed_url(
the generated url
:param filename_root: If provided will be used as the root of the filename. If not, Admin will use a hash
This option is useful when uploading a series of files that you want to be grouped together.
:param add_content_md5_metadata: If true, the content md5 will be added to the metadata in signed URL
:rtype: flyteidl.service.dataproxy_pb2.CreateUploadLocationResponse
"""
expires_in_pb = None
if expires_in:
expires_in_pb = Duration()
expires_in_pb.FromTimedelta(expires_in)
return super(SynchronousFlyteClient, self).create_upload_location(
_data_proxy_pb2.CreateUploadLocationRequest(
project=project,
domain=domain,
content_md5=content_md5,
filename=filename,
expires_in=expires_in_pb,
filename_root=filename_root,
try:
expires_in_pb = None
if expires_in:
expires_in_pb = Duration()
expires_in_pb.FromTimedelta(expires_in)
return super(SynchronousFlyteClient, self).create_upload_location(
_data_proxy_pb2.CreateUploadLocationRequest(
project=project,
domain=domain,
content_md5=content_md5,
filename=filename,
expires_in=expires_in_pb,
filename_root=filename_root,
add_content_md5_metadata=add_content_md5_metadata,
)
)
)
except Exception as e:
raise RuntimeError(f"Failed to get signed url for {filename}, reason: {e}")

def get_download_signed_url(
self, native_url: str, expires_in: datetime.timedelta = None
Expand Down
4 changes: 2 additions & 2 deletions flytekit/core/data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,10 @@ def recursive_paths(f: str, t: str) -> typing.Tuple[str, str]:
if get_protocol(f) == "file":
local_fs = fsspec.filesystem("file")
if local_fs.exists(f) and local_fs.isdir(f):
print("Adding trailing sep to")
logger.debug("Adding trailing sep to")
f = os.path.join(f, "")
else:
print("Not adding trailing sep")
logger.debug("Not adding trailing sep")
else:
f = os.path.join(f, "")
t = os.path.join(t, "")
Expand Down
4 changes: 4 additions & 0 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,13 +806,15 @@ def upload_file(
to_upload: pathlib.Path,
project: typing.Optional[str] = None,
domain: typing.Optional[str] = None,
filename_root: typing.Optional[str] = None,
) -> typing.Tuple[bytes, str]:
"""
Function will use remote's client to hash and then upload the file using Admin's data proxy service.
:param to_upload: Must be a single file
:param project: Project to upload under, if not supplied will use the remote's default
:param domain: Domain to upload under, if not specified will use the remote's default
:param filename_root: If provided will be used as the root of the filename. If not, Admin will use a hash
:return: The uploaded location.
"""
if not to_upload.is_file():
Expand All @@ -825,9 +827,11 @@ def upload_file(
domain=domain or self.default_domain,
content_md5=md5_bytes,
filename=to_upload.name,
filename_root=filename_root,
)

extra_headers = self.get_extra_headers_for_protocol(upload_location.native_url)
extra_headers.update(upload_location.headers)
encoded_md5 = b64encode(md5_bytes)
with open(str(to_upload), "+rb") as local_file:
content = local_file.read()
Expand Down
58 changes: 6 additions & 52 deletions flytekit/remote/remote_fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,10 @@
import random
import threading
import typing
from base64 import b64encode
from uuid import UUID

import fsspec
import requests
from flyteidl.service.dataproxy_pb2 import CreateUploadLocationResponse
from fsspec.callbacks import NoOpCallback
from fsspec.implementations.http import HTTPFileSystem
from fsspec.utils import get_protocol
Expand Down Expand Up @@ -73,13 +71,6 @@ def _upload_chunk(self, final=False):
self.buffer.seek(0)
data = self.buffer.read()

# h = hashlib.md5()
# h.update(data)
# md5 = h.digest()
# l = len(data)
#
# headers = {"Content-Length": str(l), "Content-MD5": md5}

try:
res = self._remote.client.get_upload_signed_url(
self._remote.default_project,
Expand Down Expand Up @@ -132,32 +123,6 @@ async def _get_file(self, rpath, lpath, **kwargs):
"""
raise NotImplementedError("FlyteFS currently doesn't support downloading files.")

def get_upload_link(
self,
local_file_path: str,
remote_file_part: str,
prefix: str,
hashes: HashStructure,
) -> typing.Tuple[CreateUploadLocationResponse, int, bytes]:
if not pathlib.Path(local_file_path).exists():
raise AssertionError(f"File {local_file_path} does not exist")

p = pathlib.Path(typing.cast(str, local_file_path))
k = str(p.absolute())
if k in hashes:
md5_bytes, content_length = hashes[k]
else:
raise AssertionError(f"File {local_file_path} not found in hashes")
upload_response = self._remote.client.get_upload_signed_url(
self._remote.default_project,
self._remote.default_domain,
md5_bytes,
remote_file_part,
filename_root=prefix,
)
logger.debug(f"Resolved signed url {local_file_path} to {upload_response.native_url}")
return upload_response, content_length, md5_bytes

async def _put_file(
self,
lpath,
Expand All @@ -171,20 +136,11 @@ async def _put_file(
fsspec will call this method to upload a file. If recursive, rpath will already be individual files.
Make the request and upload, but then how do we get the s3 paths back to the user?
"""
# remove from kwargs otherwise super() call will fail
p = kwargs.pop(_PREFIX_KEY)
hashes = kwargs.pop(_HASHES_KEY)
# Parse rpath, strip out everything that doesn't make sense.
rpath = rpath.replace(f"{REMOTE_PLACEHOLDER}/", "", 1)
resp, content_length, md5_bytes = self.get_upload_link(lpath, rpath, p, hashes)

headers = {"Content-Length": str(content_length), "Content-MD5": b64encode(md5_bytes).decode("utf-8")}
kwargs["headers"] = headers
rpath = resp.signed_url
FlytePathResolver.add_mapping(rpath, resp.native_url)
logger.debug(f"Writing {lpath} to {rpath}")
await super()._put_file(lpath, rpath, chunk_size, callback=callback, method=method, **kwargs)
return resp.native_url
prefix = kwargs.pop(_PREFIX_KEY)
_, native_url = self._remote.upload_file(
pathlib.Path(lpath), self._remote.default_project, self._remote.default_domain, prefix
)
return native_url

@staticmethod
def extract_common(native_urls: typing.List[str]) -> str:
Expand Down Expand Up @@ -266,9 +222,6 @@ async def _put(
cp file.txt flyte://data/...
rpath gets ignored, so it doesn't matter what it is.
"""
if rpath != REMOTE_PLACEHOLDER:
logger.debug(f"FlyteFS doesn't yet support specifying full remote path, ignoring {rpath}")

# Hash everything at the top level
file_info = self.get_hashes_and_lengths(pathlib.Path(lpath))
prefix = self.get_filename_root(file_info)
Expand All @@ -278,6 +231,7 @@ async def _put(
res = await super()._put(lpath, REMOTE_PLACEHOLDER, recursive, callback, batch_size, **kwargs)
if isinstance(res, list):
res = self.extract_common(res)
FlytePathResolver.add_mapping(rpath.strip(os.path.sep), res)
return res

async def _isdir(self, path):
Expand Down
4 changes: 2 additions & 2 deletions tests/flytekit/unit/clients/test_friendly.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def test_list_projects_paginated(mock_raw_list_projects):
@_mock.patch("flytekit.clients.friendly._RawSynchronousFlyteClient.create_upload_location")
def test_create_upload_location(mock_raw_create_upload_location):
client = _SynchronousFlyteClient(PlatformConfig.for_endpoint("a.b.com", True))
client.get_upload_signed_url("foo", "bar", bytes(), "baz.qux", timedelta(minutes=42))
client.get_upload_signed_url("foo", "bar", bytes(), "baz.qux", timedelta(minutes=42), add_content_md5_metadata=True)
duration_pb = Duration()
duration_pb.FromTimedelta(timedelta(minutes=42))
create_upload_location_request = _data_proxy_pb2.CreateUploadLocationRequest(
project="foo", domain="bar", filename="baz.qux", expires_in=duration_pb
project="foo", domain="bar", filename="baz.qux", expires_in=duration_pb, add_content_md5_metadata=True
)
mock_raw_create_upload_location.assert_called_with(create_upload_location_request)

0 comments on commit 7144ae9

Please sign in to comment.