From 145e0414c24d1ad102f4d5af605dd92f56e0badf Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Thu, 8 Aug 2024 10:24:06 +0200 Subject: [PATCH] Support HTTP authentication in non-streaming mode (#7082) * Refactor cached_path * Fix for empty storage_options * Allow passing HTTP storage options * Test HTTP storage_options passed by cached_path to get_from_cache Test cached_path passes HTTP storage options to get_from_cache only if passed in DownloadConfig * Test HTTP fsspec is called only if passed HTTP storage_options --- src/datasets/utils/file_utils.py | 13 ++++- tests/test_file_utils.py | 82 ++++++++++++++++++++++++++++---- 2 files changed, 83 insertions(+), 12 deletions(-) diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py index 23d91fe91dd3..10eb274e842f 100644 --- a/src/datasets/utils/file_utils.py +++ b/src/datasets/utils/file_utils.py @@ -201,6 +201,13 @@ def cached_path( url_or_filename, storage_options = _prepare_path_and_storage_options( url_or_filename, download_config=download_config ) + # Pass HTTP storage_options to get_from_cache only if passed HTTP download_config.storage_options + if ( + storage_options + and storage_options.keys() < {"http", "https"} + and not (download_config.storage_options and download_config.storage_options.keys() < {"http", "https"}) + ): + storage_options = {} output_path = get_from_cache( url_or_filename, cache_dir=cache_dir, @@ -525,6 +532,8 @@ def get_from_cache( ConnectionError: in case of unreachable url and no cache on disk """ + if storage_options is None: + storage_options = {} if use_auth_token != "deprecated": warnings.warn( "'use_auth_token' was deprecated in favor of 'token' in version 2.14.0 and will be removed in 3.0.0.\n" @@ -570,7 +579,7 @@ def get_from_cache( scheme = urlparse(url).scheme if scheme == "ftp": connected = ftp_head(url) - elif scheme not in ("http", "https"): + elif scheme not in {"http", "https"} or storage_options.get(scheme): response = fsspec_head(url, storage_options=storage_options) # s3fs uses "ETag", gcsfs uses "etag" etag = (response.get("ETag", None) or response.get("etag", None)) if use_etag else None @@ -676,7 +685,7 @@ def temp_file_manager(mode="w+b"): # GET file object if scheme == "ftp": ftp_get(url, temp_file) - elif scheme not in ("http", "https"): + elif scheme not in {"http", "https"} or storage_options.get(scheme): fsspec_get( url, temp_file, storage_options=storage_options, desc=download_desc, disable_tqdm=disable_tqdm ) diff --git a/tests/test_file_utils.py b/tests/test_file_utils.py index bee49be7d179..c8913280e0de 100644 --- a/tests/test_file_utils.py +++ b/tests/test_file_utils.py @@ -1,5 +1,6 @@ import os import re +from dataclasses import dataclass, field from pathlib import Path from unittest.mock import MagicMock, patch @@ -78,24 +79,85 @@ def tmpfs_file(tmpfs): return FILE_PATH -@pytest.mark.parametrize("protocol", ["hf", "s3"]) -def test_cached_path_protocols(protocol, monkeypatch, tmp_path): +@pytest.mark.parametrize( + "protocol, download_config_storage_options, expected_fsspec_called", + [ + ("hf", {}, True), + ("s3", {"s3": {"anon": True}}, True), + # HTTP calls fsspec only if passed HTTP download_config.storage_options: + ("https", {"https": {"block_size": "omit"}}, True), + ("https", {}, False), + ], +) +def test_cached_path_calls_fsspec_for_protocols( + protocol, download_config_storage_options, expected_fsspec_called, monkeypatch, tmp_path +): # GH-6598: Test no TypeError: __init__() got an unexpected keyword argument 'hf' + # fsspec_head/get: mock_fsspec_head = MagicMock(return_value={}) mock_fsspec_get = MagicMock(return_value=None) monkeypatch.setattr("datasets.utils.file_utils.fsspec_head", mock_fsspec_head) monkeypatch.setattr("datasets.utils.file_utils.fsspec_get", mock_fsspec_get) + + # http_head_get: + @dataclass + class Response: + status_code: int + headers: dict = field(default_factory=dict) + cookies: dict = field(default_factory=dict) + + mock_http_head = MagicMock(return_value=Response(status_code=200)) + mock_http_get = MagicMock(return_value=None) + monkeypatch.setattr("datasets.utils.file_utils.http_head", mock_http_head) + monkeypatch.setattr("datasets.utils.file_utils.http_get", mock_http_get) + # Test: cache_dir = tmp_path / "cache" - storage_options = {} if protocol == "hf" else {"s3": {"anon": True}} - download_config = DownloadConfig(cache_dir=cache_dir, storage_options=storage_options) - urls = {"hf": "hf://datasets/org-name/ds-name@main/filename.ext", "s3": "s3://bucket-name/filename.ext"} + download_config = DownloadConfig(cache_dir=cache_dir, storage_options=download_config_storage_options) + urls = { + "hf": "hf://datasets/org-name/ds-name@main/filename.ext", + "https": "https://doamin.org/filename.ext", + "s3": "s3://bucket-name/filename.ext", + } url = urls[protocol] _ = cached_path(url, download_config=download_config) - for mock in [mock_fsspec_head, mock_fsspec_get]: - assert mock.called - assert mock.call_count == 1 - assert mock.call_args.args[0] == url - assert list(mock.call_args.kwargs["storage_options"].keys()) == [protocol] + if expected_fsspec_called: + for mock in [mock_fsspec_head, mock_fsspec_get]: + assert mock.called + assert mock.call_count == 1 + assert mock.call_args.args[0] == url + assert list(mock.call_args.kwargs["storage_options"].keys()) == [protocol] + for mock in [mock_http_head, mock_http_get]: + assert not mock.called + else: + for mock in [mock_fsspec_head, mock_fsspec_get]: + assert not mock.called + for mock in [mock_http_head, mock_http_get]: + assert mock.called + assert mock.call_count == 1 + assert mock.call_args.args[0] == url + + +@pytest.mark.parametrize( + "download_config_storage_options, expected_storage_options_passed_to_get_from_catch", + [ + ({}, {}), # No DownloadConfig.storage_options + ({"https": {"block_size": "omit"}}, {"https": {"client_kwargs": {"trust_env": True}, "block_size": "omit"}}), + ], +) +def test_cached_path_passes_http_storage_options_to_get_from_cache_only_if_present_in_download_config( + download_config_storage_options, expected_storage_options_passed_to_get_from_catch, monkeypatch, tmp_path +): + # Test cached_path passes HTTP storage_options to get_from_cache only if passed HTTP download_config.storage_options + mock_get_from_catch = MagicMock(return_value=None) + monkeypatch.setattr("datasets.utils.file_utils.get_from_cache", mock_get_from_catch) + url = "https://domain.org/data.txt" + cache_dir = tmp_path / "cache" + download_config = DownloadConfig(cache_dir=cache_dir, storage_options=download_config_storage_options) + _ = cached_path(url, download_config=download_config) + assert mock_get_from_catch.called + assert mock_get_from_catch.call_count == 1 + assert mock_get_from_catch.call_args.args[0] == url + assert mock_get_from_catch.call_args.kwargs["storage_options"] == expected_storage_options_passed_to_get_from_catch @pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"])