diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py index 38c4bbf8f448..23d91fe91dd3 100644 --- a/src/datasets/utils/file_utils.py +++ b/src/datasets/utils/file_utils.py @@ -1160,7 +1160,7 @@ def _prepare_single_hop_path_and_storage_options( urlpath = "hf://" + urlpath[len(config.HF_ENDPOINT) + 1 :].replace("/resolve/", "@", 1) protocol = urlpath.split("://")[0] if "://" in urlpath else "file" if download_config is not None and protocol in download_config.storage_options: - storage_options = download_config.storage_options[protocol] + storage_options = download_config.storage_options[protocol].copy() elif download_config is not None and protocol not in download_config.storage_options: storage_options = { option_name: option_value @@ -1169,40 +1169,34 @@ def _prepare_single_hop_path_and_storage_options( } else: storage_options = {} - if storage_options: - storage_options = {protocol: storage_options} - if protocol in ["http", "https"]: - storage_options[protocol] = { - "headers": { - **get_authentication_headers_for_url(urlpath, token=token), - "user-agent": get_datasets_user_agent(), - }, - "client_kwargs": {"trust_env": True}, # Enable reading proxy env variables. - **(storage_options.get(protocol, {})), - } + if protocol in {"http", "https"}: + client_kwargs = storage_options.pop("client_kwargs", {}) + storage_options["client_kwargs"] = {"trust_env": True, **client_kwargs} # Enable reading proxy env variables if "drive.google.com" in urlpath: response = http_head(urlpath) - cookies = None for k, v in response.cookies.items(): if k.startswith("download_warning"): urlpath += "&confirm=" + v cookies = response.cookies - storage_options[protocol] = {"cookies": cookies, **storage_options.get(protocol, {})} - # Fix Google Drive URL to avoid Virus scan warning - if "drive.google.com" in urlpath and "confirm=" not in urlpath: - urlpath += "&confirm=t" + storage_options = {"cookies": cookies, **storage_options} + # Fix Google Drive URL to avoid Virus scan warning + if "confirm=" not in urlpath: + urlpath += "&confirm=t" if urlpath.startswith("https://raw.githubusercontent.com/"): # Workaround for served data with gzip content-encoding: https://github.com/fsspec/filesystem_spec/issues/389 - storage_options[protocol]["headers"]["Accept-Encoding"] = "identity" + headers = storage_options.pop("headers", {}) + storage_options["headers"] = {"Accept-Encoding": "identity", **headers} elif protocol == "hf": - storage_options[protocol] = { + storage_options = { "token": token, "endpoint": config.HF_ENDPOINT, - **storage_options.get(protocol, {}), + **storage_options, } # streaming with block_size=0 is only implemented in 0.21 (see https://github.com/huggingface/huggingface_hub/pull/1967) if config.HF_HUB_VERSION < version.parse("0.21.0"): - storage_options[protocol]["block_size"] = "default" + storage_options["block_size"] = "default" + if storage_options: + storage_options = {protocol: storage_options} return urlpath, storage_options diff --git a/tests/test_file_utils.py b/tests/test_file_utils.py index 84e418eba3ab..bee49be7d179 100644 --- a/tests/test_file_utils.py +++ b/tests/test_file_utils.py @@ -12,6 +12,7 @@ from datasets.utils.file_utils import ( OfflineModeIsEnabled, _get_extraction_protocol, + _prepare_single_hop_path_and_storage_options, cached_path, fsspec_get, fsspec_head, @@ -47,7 +48,7 @@ FILE_PATH = "file" -TEST_URL = "https://huggingface.co/datasets/hf-internal-testing/dataset_with_script/raw/main/some_text.txt" +TEST_URL = "https://huggingface.co/datasets/hf-internal-testing/dataset_with_script/resolve/main/some_text.txt" TEST_URL_CONTENT = "foo\nbar\nfoobar" TEST_GG_DRIVE_FILENAME = "train.tsv" @@ -90,7 +91,6 @@ def test_cached_path_protocols(protocol, monkeypatch, tmp_path): urls = {"hf": "hf://datasets/org-name/ds-name@main/filename.ext", "s3": "s3://bucket-name/filename.ext"} url = urls[protocol] _ = cached_path(url, download_config=download_config) - assert True for mock in [mock_fsspec_head, mock_fsspec_get]: assert mock.called assert mock.call_count == 1 @@ -197,6 +197,75 @@ def test_fsspec_offline(tmp_path_factory): fsspec_head("s3://huggingface.co") +@pytest.mark.parametrize( + "urlpath, download_config, expected_urlpath, expected_storage_options", + [ + ( + "https://huggingface.co/datasets/hf-internal-testing/dataset_with_script/resolve/main/some_text.txt", + DownloadConfig(), + "hf://datasets/hf-internal-testing/dataset_with_script@main/some_text.txt", + {"hf": {"endpoint": "https://huggingface.co", "token": None}}, + ), + ( + "https://huggingface.co/datasets/hf-internal-testing/dataset_with_script/resolve/main/some_text.txt", + DownloadConfig(token="MY-TOKEN"), + "hf://datasets/hf-internal-testing/dataset_with_script@main/some_text.txt", + {"hf": {"endpoint": "https://huggingface.co", "token": "MY-TOKEN"}}, + ), + ( + "https://huggingface.co/datasets/hf-internal-testing/dataset_with_script/resolve/main/some_text.txt", + DownloadConfig(token="MY-TOKEN", storage_options={"hf": {"on_error": "omit"}}), + "hf://datasets/hf-internal-testing/dataset_with_script@main/some_text.txt", + {"hf": {"endpoint": "https://huggingface.co", "token": "MY-TOKEN", "on_error": "omit"}}, + ), + ( + "https://domain.org/data.txt", + DownloadConfig(), + "https://domain.org/data.txt", + {"https": {"client_kwargs": {"trust_env": True}}}, + ), + ( + "https://domain.org/data.txt", + DownloadConfig(storage_options={"https": {"block_size": "omit"}}), + "https://domain.org/data.txt", + {"https": {"client_kwargs": {"trust_env": True}, "block_size": "omit"}}, + ), + ( + "https://domain.org/data.txt", + DownloadConfig(storage_options={"https": {"client_kwargs": {"raise_for_status": True}}}), + "https://domain.org/data.txt", + {"https": {"client_kwargs": {"trust_env": True, "raise_for_status": True}}}, + ), + ( + "https://domain.org/data.txt", + DownloadConfig(storage_options={"https": {"client_kwargs": {"trust_env": False}}}), + "https://domain.org/data.txt", + {"https": {"client_kwargs": {"trust_env": False}}}, + ), + ( + "https://raw.githubusercontent.com/data.txt", + DownloadConfig(storage_options={"https": {"headers": {"x-test": "true"}}}), + "https://raw.githubusercontent.com/data.txt", + { + "https": { + "client_kwargs": {"trust_env": True}, + "headers": {"x-test": "true", "Accept-Encoding": "identity"}, + } + }, + ), + ], +) +def test_prepare_single_hop_path_and_storage_options( + urlpath, download_config, expected_urlpath, expected_storage_options +): + original_download_config_storage_options = str(download_config.storage_options) + prepared_urlpath, storage_options = _prepare_single_hop_path_and_storage_options(urlpath, download_config) + assert prepared_urlpath == expected_urlpath + assert storage_options == expected_storage_options + # Check that DownloadConfig.storage_options are not modified: + assert str(download_config.storage_options) == original_download_config_storage_options + + class DummyTestFS(AbstractFileSystem): protocol = "mock" _file_class = AbstractBufferedFile