From 157b61344b44a93923de0b309c2c11c23d6cb1aa Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 8 Feb 2024 11:08:09 +0100
Subject: [PATCH] add gdown as optional requirement for dataset GDrive download (#8237)

---
 .github/workflows/tests-schedule.yml |  2 +-
 mypy.ini                             |  4 ++
 setup.py                             |  1 -
 torchvision/datasets/caltech.py      |  4 ++
 torchvision/datasets/celeba.py       |  4 ++
 torchvision/datasets/pcam.py         |  4 ++
 torchvision/datasets/utils.py        | 74 ++++------------------------
 torchvision/datasets/widerface.py    |  4 ++
 8 files changed, 30 insertions(+), 67 deletions(-)

diff --git a/.github/workflows/tests-schedule.yml b/.github/workflows/tests-schedule.yml
index 5426fdc997a..cc13e25d20e 100644
--- a/.github/workflows/tests-schedule.yml
+++ b/.github/workflows/tests-schedule.yml
@@ -36,7 +36,7 @@ jobs:
         run: pip install --no-build-isolation --editable .
 
       - name: Install all optional dataset requirements
-        run: pip install scipy pycocotools lmdb requests
+        run: pip install scipy pycocotools lmdb gdown
 
       - name: Install tests requirements
         run: pip install pytest
diff --git a/mypy.ini b/mypy.ini
index 0c4e51dabd9..f95901e9c6f 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -110,3 +110,7 @@ ignore_missing_imports = True
 [mypy-h5py.*]
 
 ignore_missing_imports = True
+
+[mypy-gdown.*]
+
+ignore_missing_imports = True
diff --git a/setup.py b/setup.py
index f7d3ac2e3ba..ce1cd90ca05 100644
--- a/setup.py
+++ b/setup.py
@@ -59,7 +59,6 @@ def write_version_file():
 
 requirements = [
     "numpy",
-    "requests",
     pytorch_dep,
 ]
 
diff --git a/torchvision/datasets/caltech.py b/torchvision/datasets/caltech.py
index 3a9635dfe09..7532b4f2061 100644
--- a/torchvision/datasets/caltech.py
+++ b/torchvision/datasets/caltech.py
@@ -30,6 +30,10 @@ class Caltech101(VisionDataset):
         download (bool, optional): If true, downloads the dataset from the internet and
             puts it in root directory. If dataset is already downloaded, it is not
             downloaded again.
+
+    .. warning::
+
+        To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
     """
 
     def __init__(
diff --git a/torchvision/datasets/celeba.py b/torchvision/datasets/celeba.py
index d055f92f194..ea86d18e260 100644
--- a/torchvision/datasets/celeba.py
+++ b/torchvision/datasets/celeba.py
@@ -38,6 +38,10 @@ class CelebA(VisionDataset):
         download (bool, optional): If true, downloads the dataset from the internet and
             puts it in root directory. If dataset is already downloaded, it is not
             downloaded again.
+
+    .. warning::
+
+        To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
     """
 
     base_folder = "celeba"
diff --git a/torchvision/datasets/pcam.py b/torchvision/datasets/pcam.py
index c21f66186ce..c8fa1c4a92c 100644
--- a/torchvision/datasets/pcam.py
+++ b/torchvision/datasets/pcam.py
@@ -25,6 +25,10 @@ class PCAM(VisionDataset):
         target_transform (callable, optional): A function/transform that takes in the target and transforms it.
         download (bool, optional): If True, downloads the dataset from the internet and puts it into ``root/pcam``. If
             dataset is already downloaded, it is not downloaded again.
+
+        .. warning::
+
+           To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
""" _FILES = { diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index b79b4ef4e61..dc17b8f92e6 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -1,8 +1,6 @@ import bz2 -import contextlib import gzip import hashlib -import itertools import lzma import os import os.path @@ -13,13 +11,11 @@ import urllib import urllib.error import urllib.request -import warnings import zipfile from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar from urllib.parse import urlparse import numpy as np -import requests import torch from torch.utils.model_zoo import tqdm @@ -187,22 +183,6 @@ def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]: return files -def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[bytes, Iterator[bytes]]: - content = response.iter_content(chunk_size) - first_chunk = None - # filter out keep-alive new chunks - while not first_chunk: - first_chunk = next(content) - content = itertools.chain([first_chunk], content) - - try: - match = re.search("Google Drive - (?P<api_response>.+?)", first_chunk.decode()) - api_response = match["api_response"] if match is not None else None - except UnicodeDecodeError: - api_response = None - return api_response, content - - def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None): """Download a Google Drive file from and place it in root. @@ -212,7 +192,12 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[ filename (str, optional): Name to save the file under. If None, use the id of the file. md5 (str, optional): MD5 checksum of the download. If None, do not check """ - # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url + try: + import gdown + except ModuleNotFoundError: + raise RuntimeError( + "To download files from GDrive, 'gdown' is required. You can install it with 'pip install gdown'." + ) root = os.path.expanduser(root) if not filename: @@ -225,51 +210,10 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[ print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}") return - url = "https://drive.google.com/uc" - params = dict(id=file_id, export="download") - with requests.Session() as session: - response = session.get(url, params=params, stream=True) + gdown.download(id=file_id, output=fpath, quiet=False, user_agent=USER_AGENT) - for key, value in response.cookies.items(): - if key.startswith("download_warning"): - token = value - break - else: - api_response, content = _extract_gdrive_api_response(response) - token = "t" if api_response == "Virus scan warning" else None - - if token is not None: - response = session.get(url, params=dict(params, confirm=token), stream=True) - api_response, content = _extract_gdrive_api_response(response) - - if api_response == "Quota exceeded": - raise RuntimeError( - f"The daily quota of the file {filename} is exceeded and it " - f"can't be downloaded. This is a limitation of Google Drive " - f"and can only be overcome by trying again later." - ) - - _save_response_content(content, fpath) - - # In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB and contain only text - if os.stat(fpath).st_size < 10 * 1024: - with contextlib.suppress(UnicodeDecodeError), open(fpath) as fh: - text = fh.read() - # Regular expression to detect HTML. 
-            if re.search(r"</?\s*[a-z-][^>]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", text):
-                warnings.warn(
-                    f"We detected some HTML elements in the downloaded file. "
-                    f"This most likely means that the download triggered an unhandled API response by GDrive. "
-                    f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
-                    f"the response:\n\n{text}"
-                )
-
-    if md5 and not check_md5(fpath, md5):
-        raise RuntimeError(
-            f"The MD5 checksum of the download file {fpath} does not match the one on record."
-            f"Please delete the file and try again. "
-            f"If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues."
-        )
+    if not check_integrity(fpath, md5):
+        raise RuntimeError("File not found or corrupted.")
 
 
 def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
diff --git a/torchvision/datasets/widerface.py b/torchvision/datasets/widerface.py
index aa520455ef1..9003272ab60 100644
--- a/torchvision/datasets/widerface.py
+++ b/torchvision/datasets/widerface.py
@@ -34,6 +34,10 @@ class WIDERFace(VisionDataset):
             puts it in root directory. If dataset is already downloaded, it is not
             downloaded again.
 
+    .. warning::
+
+        To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
+
     """
 
     BASE_FOLDER = "widerface"
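
A quick sanity check of the new code path, appended after the diff: with this
patch applied, GDrive handling (confirm tokens, virus-scan interstitials,
quota pages) is delegated to gdown, so the helper reduces to the call below.
This is a minimal sketch, not part of the patch; the file id is a placeholder
and the paths are made up:

    from torchvision.datasets.utils import download_file_from_google_drive

    # With gdown installed, this delegates to gdown.download(id=..., output=...).
    # Without it, the new import guard raises a RuntimeError that suggests
    # `pip install gdown` instead of failing mid-download.
    download_file_from_google_drive(
        file_id="...",       # placeholder; use a real GDrive file id
        root="./data",       # hypothetical download directory
        filename="example.bin",
        md5=None,            # with md5=None, check_integrity only verifies the file exists
    )

Dropping the bespoke response parsing (cookie scan, `_extract_gdrive_api_response`,
HTML sniffing) is what accounts for most of the 67 deleted lines in utils.py.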