Skip to content

Commit

Permalink
Support HTTP authentication in non-streaming mode (#7082)
Browse files Browse the repository at this point in the history
* Refactor cached_path

* Fix for empty storage_options

* Allow passing HTTP storage options

* Test HTTP storage_options passed by cached_path to get_from_cache

Test cached_path passes HTTP storage options to get_from_cache only if passed in DownloadConfig

* Test HTTP fsspec is called only if passed HTTP storage_options
  • Loading branch information
albertvillanova committed Aug 13, 2024
1 parent b288977 commit 145e041
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 12 deletions.
13 changes: 11 additions & 2 deletions src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,13 @@ def cached_path(
url_or_filename, storage_options = _prepare_path_and_storage_options(
url_or_filename, download_config=download_config
)
# Forward HTTP(S) storage_options to get_from_cache only if the caller supplied HTTP(S) options in download_config.storage_options
if (
storage_options
and storage_options.keys() < {"http", "https"}
and not (download_config.storage_options and download_config.storage_options.keys() < {"http", "https"})
):
storage_options = {}
output_path = get_from_cache(
url_or_filename,
cache_dir=cache_dir,
Expand Down Expand Up @@ -525,6 +532,8 @@ def get_from_cache(
ConnectionError: in case of unreachable url
and no cache on disk
"""
if storage_options is None:
storage_options = {}
if use_auth_token != "deprecated":
warnings.warn(
"'use_auth_token' was deprecated in favor of 'token' in version 2.14.0 and will be removed in 3.0.0.\n"
Expand Down Expand Up @@ -570,7 +579,7 @@ def get_from_cache(
scheme = urlparse(url).scheme
if scheme == "ftp":
connected = ftp_head(url)
elif scheme not in ("http", "https"):
elif scheme not in {"http", "https"} or storage_options.get(scheme):
response = fsspec_head(url, storage_options=storage_options)
# s3fs uses "ETag", gcsfs uses "etag"
etag = (response.get("ETag", None) or response.get("etag", None)) if use_etag else None
Expand Down Expand Up @@ -676,7 +685,7 @@ def temp_file_manager(mode="w+b"):
# GET file object
if scheme == "ftp":
ftp_get(url, temp_file)
elif scheme not in ("http", "https"):
elif scheme not in {"http", "https"} or storage_options.get(scheme):
fsspec_get(
url, temp_file, storage_options=storage_options, desc=download_desc, disable_tqdm=disable_tqdm
)
Expand Down
82 changes: 72 additions & 10 deletions tests/test_file_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import re
from dataclasses import dataclass, field
from pathlib import Path
from unittest.mock import MagicMock, patch

Expand Down Expand Up @@ -78,24 +79,85 @@ def tmpfs_file(tmpfs):
return FILE_PATH


@pytest.mark.parametrize("protocol", ["hf", "s3"])
def test_cached_path_protocols(protocol, monkeypatch, tmp_path):
@pytest.mark.parametrize(
"protocol, download_config_storage_options, expected_fsspec_called",
[
("hf", {}, True),
("s3", {"s3": {"anon": True}}, True),
# HTTP(S) URLs go through fsspec only if download_config.storage_options contains HTTP(S) options:
("https", {"https": {"block_size": "omit"}}, True),
("https", {}, False),
],
)
def test_cached_path_calls_fsspec_for_protocols(
protocol, download_config_storage_options, expected_fsspec_called, monkeypatch, tmp_path
):
# GH-6598: Test no TypeError: __init__() got an unexpected keyword argument 'hf'
# fsspec_head/get:
mock_fsspec_head = MagicMock(return_value={})
mock_fsspec_get = MagicMock(return_value=None)
monkeypatch.setattr("datasets.utils.file_utils.fsspec_head", mock_fsspec_head)
monkeypatch.setattr("datasets.utils.file_utils.fsspec_get", mock_fsspec_get)

# http_head_get:
@dataclass
class Response:
status_code: int
headers: dict = field(default_factory=dict)
cookies: dict = field(default_factory=dict)

mock_http_head = MagicMock(return_value=Response(status_code=200))
mock_http_get = MagicMock(return_value=None)
monkeypatch.setattr("datasets.utils.file_utils.http_head", mock_http_head)
monkeypatch.setattr("datasets.utils.file_utils.http_get", mock_http_get)
# Test:
cache_dir = tmp_path / "cache"
storage_options = {} if protocol == "hf" else {"s3": {"anon": True}}
download_config = DownloadConfig(cache_dir=cache_dir, storage_options=storage_options)
urls = {"hf": "hf://datasets/org-name/ds-name@main/filename.ext", "s3": "s3://bucket-name/filename.ext"}
download_config = DownloadConfig(cache_dir=cache_dir, storage_options=download_config_storage_options)
urls = {
"hf": "hf://datasets/org-name/ds-name@main/filename.ext",
"https": "https://doamin.org/filename.ext",
"s3": "s3://bucket-name/filename.ext",
}
url = urls[protocol]
_ = cached_path(url, download_config=download_config)
for mock in [mock_fsspec_head, mock_fsspec_get]:
assert mock.called
assert mock.call_count == 1
assert mock.call_args.args[0] == url
assert list(mock.call_args.kwargs["storage_options"].keys()) == [protocol]
if expected_fsspec_called:
for mock in [mock_fsspec_head, mock_fsspec_get]:
assert mock.called
assert mock.call_count == 1
assert mock.call_args.args[0] == url
assert list(mock.call_args.kwargs["storage_options"].keys()) == [protocol]
for mock in [mock_http_head, mock_http_get]:
assert not mock.called
else:
for mock in [mock_fsspec_head, mock_fsspec_get]:
assert not mock.called
for mock in [mock_http_head, mock_http_get]:
assert mock.called
assert mock.call_count == 1
assert mock.call_args.args[0] == url


@pytest.mark.parametrize(
    "download_config_storage_options, expected_storage_options_passed_to_get_from_catch",
    [
        ({}, {}),  # No DownloadConfig.storage_options
        ({"https": {"block_size": "omit"}}, {"https": {"client_kwargs": {"trust_env": True}, "block_size": "omit"}}),
    ],
)
def test_cached_path_passes_http_storage_options_to_get_from_cache_only_if_present_in_download_config(
    download_config_storage_options, expected_storage_options_passed_to_get_from_catch, monkeypatch, tmp_path
):
    """Check which ``storage_options`` ``cached_path`` hands to ``get_from_cache`` for an HTTPS URL.

    ``get_from_cache`` is replaced by a mock so nothing is downloaded; the test
    only inspects the single call that ``cached_path`` makes to it.

    Cases:
        * empty ``DownloadConfig.storage_options`` -> ``get_from_cache`` receives ``{}``;
        * ``{"https": {...}}`` in the config -> the HTTPS options are forwarded, with the
          ``client_kwargs`` entry listed in the expected value (presumably added while the
          options are prepared inside ``cached_path`` -- not visible from this test).
    """
    # Stand-in for the real downloader: records the call and returns None.
    get_from_cache_spy = MagicMock(return_value=None)
    monkeypatch.setattr("datasets.utils.file_utils.get_from_cache", get_from_cache_spy)

    config = DownloadConfig(cache_dir=tmp_path / "cache", storage_options=download_config_storage_options)
    target_url = "https://domain.org/data.txt"

    cached_path(target_url, download_config=config)

    # Exactly one delegation to get_from_cache, with the URL as first positional argument.
    get_from_cache_spy.assert_called_once()
    recorded_call = get_from_cache_spy.call_args
    assert recorded_call.args[0] == target_url
    assert recorded_call.kwargs["storage_options"] == expected_storage_options_passed_to_get_from_catch


@pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"])
Expand Down

0 comments on commit 145e041

Please sign in to comment.