Added pathlib support to datasets/utils.py #8200

Merged (15 commits) · Jan 16, 2024
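For context, a minimal sketch of what this change enables, assuming the PR is merged; the paths below are placeholders, not part of the PR:

import pathlib
from torchvision.datasets import utils

archive = pathlib.Path("data") / "archive.zip"  # placeholder path
utils.check_integrity(archive)  # fpath arguments may now be pathlib.Path
# extract_archive mirrors the input type: Path in -> Path out, str in -> str out
out = utils.extract_archive(archive, pathlib.Path("data") / "out")
assert isinstance(out, pathlib.Path)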
35 changes: 29 additions & 6 deletions test/test_datasets_utils.py
@@ -58,8 +58,11 @@ def test_get_redirect_url_max_hops_exceeded(self, mocker):
assert mock.call_count == 1
assert mock.call_args[0][0].full_url == url

def test_check_md5(self):
@pytest.mark.parametrize("use_pathlib", (True, False))
def test_check_md5(self, use_pathlib):
fpath = TEST_FILE
if use_pathlib:
fpath = pathlib.Path(fpath)
correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc"
false_md5 = ""
assert utils.check_md5(fpath, correct_md5)
@@ -116,7 +119,8 @@ def test_detect_file_type_incompatible(self, file):
utils._detect_file_type(file)

@pytest.mark.parametrize("extension", [".bz2", ".gz", ".xz"])
def test_decompress(self, extension, tmpdir):
@pytest.mark.parametrize("use_pathlib", (True, False))
def test_decompress(self, extension, tmpdir, use_pathlib):
def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file")
compressed = f"{file}{extension}"
@@ -128,6 +132,8 @@ def create_compressed(root, content="this is the content"):
return compressed, file, content

compressed, file, content = create_compressed(tmpdir)
if use_pathlib:
compressed = pathlib.Path(compressed)

utils._decompress(compressed)

@@ -140,7 +146,8 @@ def test_decompress_no_compression(self):
with pytest.raises(RuntimeError):
utils._decompress("foo.tar")

def test_decompress_remove_finished(self, tmpdir):
@pytest.mark.parametrize("use_pathlib", (True, False))
def test_decompress_remove_finished(self, tmpdir, use_pathlib):
def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file")
compressed = f"{file}.gz"
@@ -151,10 +158,20 @@ def create_compressed(root, content="this is the content"):
return compressed, file, content

compressed, file, content = create_compressed(tmpdir)
if use_pathlib:
compressed = pathlib.Path(compressed)
tmpdir = pathlib.Path(tmpdir)

utils.extract_archive(compressed, tmpdir, remove_finished=True)
extracted_dir = utils.extract_archive(compressed, tmpdir, remove_finished=True)

assert not os.path.exists(compressed)
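# extract_archive should hand back the same path type it was given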
if use_pathlib:
assert isinstance(extracted_dir, pathlib.Path)
assert isinstance(compressed, pathlib.Path)
else:
assert isinstance(extracted_dir, str)
assert isinstance(compressed, str)

@pytest.mark.parametrize("extension", [".gz", ".xz"])
@pytest.mark.parametrize("remove_finished", [True, False])
@@ -167,7 +184,8 @@ def test_extract_archive_defer_to_decompress(self, extension, remove_finished, mocker):

mocked.assert_called_once_with(file, filename, remove_finished=remove_finished)

def test_extract_zip(self, tmpdir):
@pytest.mark.parametrize("use_pathlib", (True, False))
def test_extract_zip(self, tmpdir, use_pathlib):
def create_archive(root, content="this is the content"):
file = os.path.join(root, "dst.txt")
archive = os.path.join(root, "archive.zip")
@@ -177,6 +195,8 @@ def create_archive(root, content="this is the content"):

return archive, file, content

if use_pathlib:
tmpdir = pathlib.Path(tmpdir)
archive, file, content = create_archive(tmpdir)

utils.extract_archive(archive, tmpdir)
@@ -189,7 +209,8 @@ def create_archive(root, content="this is the content"):
@pytest.mark.parametrize(
"extension, mode", [(".tar", "w"), (".tar.gz", "w:gz"), (".tgz", "w:gz"), (".tar.xz", "w:xz")]
)
def test_extract_tar(self, extension, mode, tmpdir):
@pytest.mark.parametrize("use_pathlib", (True, False))
def test_extract_tar(self, extension, mode, tmpdir, use_pathlib):
def create_archive(root, extension, mode, content="this is the content"):
src = os.path.join(root, "src.txt")
dst = os.path.join(root, "dst.txt")
@@ -203,6 +224,8 @@ def create_archive(root, extension, mode, content="this is the content"):

return archive, dst, content

if use_pathlib:
tmpdir = pathlib.Path(tmpdir)
archive, file, content = create_archive(tmpdir, extension, mode)

utils.extract_archive(archive, tmpdir)
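A quick way to exercise the updated tests locally (assuming a torchvision checkout with pytest installed):

pytest test/test_datasets_utils.py -v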
69 changes: 46 additions & 23 deletions torchvision/datasets/utils.py
@@ -30,7 +30,7 @@

def _save_response_content(
content: Iterator[bytes],
destination: str,
destination: Union[str, pathlib.Path],
length: Optional[int] = None,
) -> None:
with open(destination, "wb") as fh, tqdm(total=length) as pbar:
@@ -43,12 +43,12 @@
pbar.update(len(chunk))


def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
def _urlretrieve(url: str, filename: Union[str, pathlib.Path], chunk_size: int = 1024 * 32) -> None:
with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
_save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)


def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
def calculate_md5(fpath: Union[str, pathlib.Path], chunk_size: int = 1024 * 1024) -> str:

Check warning (GitHub Actions / bc, line 51): Function calculate_md5: fpath changed from str to Union[str, pathlib.Path]
# Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
# not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
# it torchvision.datasets is unusable in these environments since we perform a MD5 check everywhere.
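# (The flag in question is the `usedforsecurity` keyword argument of hashlib.md5(), available since Python 3.9.)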
@@ -62,11 +62,11 @@
return md5.hexdigest()


def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
def check_md5(fpath: Union[str, pathlib.Path], md5: str, **kwargs: Any) -> bool:

Check warning (GitHub Actions / bc, line 65): Function check_md5: fpath changed from str to Union[str, pathlib.Path]
return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
def check_integrity(fpath: Union[str, pathlib.Path], md5: Optional[str] = None) -> bool:

Check warning (GitHub Actions / bc, line 69): Function check_integrity: fpath changed from str to Union[str, pathlib.Path]
if not os.path.isfile(fpath):
return False
if md5 is None:
@@ -106,7 +106,7 @@
def download_url(
url: str,
root: Union[str, pathlib.Path],
filename: Optional[str] = None,
filename: Optional[Union[str, pathlib.Path]] = None,

Check warning (GitHub Actions / bc, line 109): Function download_url: filename changed from Optional[str] to Optional[Union[str, pathlib.Path]]
md5: Optional[str] = None,
max_redirect_hops: int = 3,
) -> None:
@@ -159,7 +159,7 @@
raise RuntimeError("File not found or corrupted.")


def list_dir(root: str, prefix: bool = False) -> List[str]:
def list_dir(root: Union[str, pathlib.Path], prefix: bool = False) -> List[str]:

Check warning (GitHub Actions / bc, line 162): Function list_dir: root changed from str to Union[str, pathlib.Path]
"""List all directories at a given root

Args:
@@ -174,7 +174,7 @@
return directories


def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
def list_files(root: Union[str, pathlib.Path], suffix: str, prefix: bool = False) -> List[str]:

Check warning (GitHub Actions / bc, line 177): Function list_files: root changed from str to Union[str, pathlib.Path]
"""List all files ending with a suffix at a given root

Args:
@@ -208,7 +208,10 @@


def download_file_from_google_drive(
file_id: str, root: Union[str, pathlib.Path], filename: Optional[str] = None, md5: Optional[str] = None
file_id: str,
root: Union[str, pathlib.Path],
filename: Optional[Union[str, pathlib.Path]] = None,

Check warning (GitHub Actions / bc, line 213): Function download_file_from_google_drive: filename changed from Optional[str] to Optional[Union[str, pathlib.Path]]
md5: Optional[str] = None,
):
"""Download a Google Drive file from and place it in root.

@@ -278,7 +281,9 @@
)


def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
def _extract_tar(
from_path: Union[str, pathlib.Path], to_path: Union[str, pathlib.Path], compression: Optional[str]
) -> None:
with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
tar.extractall(to_path)

@@ -289,14 +294,16 @@
}


def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
def _extract_zip(
from_path: Union[str, pathlib.Path], to_path: Union[str, pathlib.Path], compression: Optional[str]
) -> None:
with zipfile.ZipFile(
from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
) as zip:
zip.extractall(to_path)


_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
_ARCHIVE_EXTRACTORS: Dict[str, Callable[[Union[str, pathlib.Path], Union[str, pathlib.Path], Optional[str]], None]] = {
".tar": _extract_tar,
".zip": _extract_zip,
}
@@ -312,7 +319,7 @@
}


def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
def _detect_file_type(file: Union[str, pathlib.Path]) -> Tuple[str, Optional[str], Optional[str]]:
"""Detect the archive type and/or compression of a file.

Args:
@@ -355,7 +362,11 @@
raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")


def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
def _decompress(
from_path: Union[str, pathlib.Path],
to_path: Optional[Union[str, pathlib.Path]] = None,
remove_finished: bool = False,
) -> pathlib.Path:
r"""Decompress a file.

The compression is automatically detected from the file name.
@@ -373,7 +384,7 @@
raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")

if to_path is None:
to_path = from_path.replace(suffix, archive_type if archive_type is not None else "")
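# os.fspath() accepts str and pathlib.Path alike, so the suffix replacement below works for either input type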
to_path = pathlib.Path(os.fspath(from_path).replace(suffix, archive_type if archive_type is not None else ""))

# We don't need to check for a missing key here, since this was already done in _detect_file_type()
compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]
Expand All @@ -384,10 +395,14 @@
if remove_finished:
os.remove(from_path)

return to_path
return pathlib.Path(to_path)


def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
def extract_archive(
from_path: Union[str, pathlib.Path],

Check warning (GitHub Actions / bc, line 402): Function extract_archive: from_path changed from str to Union[str, pathlib.Path]
to_path: Optional[Union[str, pathlib.Path]] = None,

Check warning (GitHub Actions / bc, line 403): Function extract_archive: to_path changed from Optional[str] to Optional[Union[str, pathlib.Path]]
remove_finished: bool = False,
) -> Union[str, pathlib.Path]:
"""Extract an archive.

The archive type and a possible compression is automatically detected from the file name. If the file is compressed
@@ -402,16 +417,24 @@
Returns:
(str): Path to the directory the file was extracted to.
"""

def path_or_str(ret_path: pathlib.Path) -> Union[str, pathlib.Path]:
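# Mirror the caller's input type: a str from_path yields a str result, a pathlib.Path yields a pathlib.Path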
if isinstance(from_path, str):
return os.fspath(ret_path)
else:
return ret_path

if to_path is None:
to_path = os.path.dirname(from_path)

suffix, archive_type, compression = _detect_file_type(from_path)
if not archive_type:
return _decompress(
ret_path = _decompress(
from_path,
os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")),
remove_finished=remove_finished,
)
return path_or_str(ret_path)

# We don't need to check for a missing key here, since this was already done in _detect_file_type()
extractor = _ARCHIVE_EXTRACTORS[archive_type]
@@ -420,14 +443,14 @@
if remove_finished:
os.remove(from_path)

return to_path
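# to_path may still be a plain str here (e.g. from os.path.dirname), so wrap it before mirroring the input type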
return path_or_str(pathlib.Path(to_path))


def download_and_extract_archive(
url: str,
download_root: str,
extract_root: Optional[str] = None,
filename: Optional[str] = None,
download_root: Union[str, pathlib.Path],

Check warning (GitHub Actions / bc, line 451): Function download_and_extract_archive: download_root changed from str to Union[str, pathlib.Path]
extract_root: Optional[Union[str, pathlib.Path]] = None,
filename: Optional[Union[str, pathlib.Path]] = None,
md5: Optional[str] = None,
remove_finished: bool = False,
) -> None:
@@ -479,7 +502,7 @@
return value


def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
def _read_pfm(file_name: Union[str, pathlib.Path], slice_channels: int = 2) -> np.ndarray:
"""Read file in .pfm format. Might contain either 1 or 3 channels of data.

Args:
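For illustration, a corresponding sketch for the download helper, whose root and filename arguments likewise accept pathlib.Path after this change (the URL and paths are placeholders, not part of the PR):

import pathlib
from torchvision.datasets import utils

root = pathlib.Path("/tmp/ds")
utils.download_and_extract_archive(
    "https://example.com/archive.tar.gz",  # placeholder URL
    download_root=root,
    extract_root=root / "extracted",
    remove_finished=True,
)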