Fix concurrent script loading with force_redownload (#6718)
* fix concurrent script loading with force_redownload

* support various dynamic modules paths

* fix test

* fix tests

* disable on windows

* Update tests/test_load.py

Co-authored-by: Mario Šaško <[email protected]>

---------

Co-authored-by: Mario Šaško <[email protected]>
lhoestq and mariosasko authored Mar 7, 2024
1 parent e52f4d0 commit f45bc6c
Showing 4 changed files with 73 additions and 32 deletions.
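
For context, a minimal sketch of the race this commit guards against, using the filelock package that datasets.utils._filelock wraps (paths and names below are illustrative, not the library's actual layout): with force_redownload, one process deletes and re-creates the importable script directory while another process is still importing the copied script from it, so both sides now serialize on a <directory>.lock file.

import os
import shutil
from filelock import FileLock  # the primitive wrapped by datasets.utils._filelock

importable_dir = "/tmp/dynamic_modules/datasets/my_dataset"  # hypothetical layout
lock_path = importable_dir + ".lock"                         # one lock per dataset module dir

def redownload(script_src: str) -> None:
    # force_redownload side: wipe and re-copy the script under the lock,
    # so readers never observe a half-deleted directory.
    with FileLock(lock_path):
        if os.path.exists(importable_dir):
            shutil.rmtree(importable_dir)
        os.makedirs(os.path.join(importable_dir, "somehash"))
        shutil.copyfile(script_src, os.path.join(importable_dir, "somehash", "my_dataset.py"))

def load_script() -> str:
    # importing side: take the same lock before reading or importing the script.
    with FileLock(lock_path):
        with open(os.path.join(importable_dir, "somehash", "my_dataset.py"), encoding="utf-8") as f:
            return f.read()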
57 changes: 31 additions & 26 deletions src/datasets/load.py
@@ -27,6 +27,7 @@
import time
import warnings
from collections import Counter
from contextlib import nullcontext
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
@@ -70,7 +71,6 @@
)
from .splits import Split
from .utils import _datasets_server
from .utils._filelock import FileLock
from .utils.deprecation_utils import deprecated
from .utils.file_utils import (
OfflineModeIsEnabled,
@@ -87,7 +87,7 @@
from .utils.info_utils import VerificationMode, is_small_dataset
from .utils.logging import get_logger
from .utils.metadata import MetadataConfigs
from .utils.py_utils import get_imports
from .utils.py_utils import get_imports, lock_importable_file
from .utils.version import Version


@@ -244,7 +244,10 @@ def __reduce__(self): # to make dynamically created class pickable, see _Initia
def get_dataset_builder_class(
dataset_module: "DatasetModule", dataset_name: Optional[str] = None
) -> Type[DatasetBuilder]:
builder_cls = import_main_class(dataset_module.module_path)
with lock_importable_file(
dataset_module.importable_file_path
) if dataset_module.importable_file_path else nullcontext():
builder_cls = import_main_class(dataset_module.module_path)
if dataset_module.builder_configs_parameters.builder_configs:
dataset_name = dataset_name or dataset_module.builder_kwargs.get("dataset_name")
if dataset_name is None:
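
Note that the lock above is only taken when the module actually comes from an importable script; for packaged builders, importable_file_path is None and nullcontext() turns the with statement into a no-op. A generic illustration of that conditional-context pattern (not the datasets API):

from contextlib import nullcontext
from filelock import FileLock

def maybe_locked(lock_path=None):
    # Return a real inter-process lock when there is a file to protect,
    # otherwise a context manager that does nothing.
    return FileLock(lock_path) if lock_path else nullcontext()

with maybe_locked():                     # packaged builder: nothing to lock
    pass
with maybe_locked("/tmp/example.lock"):  # script-based builder: serialized across processes
    pass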
@@ -375,17 +378,15 @@ def _copy_script_and_other_resources_in_importable_dir(
download_mode (Optional[Union[DownloadMode, str]]): download mode
Return:
importable_local_file: path to an importable module with importlib.import_module
importable_file: path to an importable module with importlib.import_module
"""

# Define a directory with a unique name in our dataset or metric folder
# path is: ./datasets|metrics/dataset|metric_name/hash_from_code/script.py
# we use a hash as subdirectory_name to be able to have multiple versions of a dataset/metric processing file together
importable_subdirectory = os.path.join(importable_directory_path, subdirectory_name)
importable_local_file = os.path.join(importable_subdirectory, name + ".py")
importable_file = os.path.join(importable_subdirectory, name + ".py")
# Prevent parallel disk operations
lock_path = importable_directory_path + ".lock"
with FileLock(lock_path):
with lock_importable_file(importable_file):
# Create main dataset/metrics folder if needed
if download_mode == DownloadMode.FORCE_REDOWNLOAD and os.path.exists(importable_directory_path):
shutil.rmtree(importable_directory_path)
@@ -406,13 +407,13 @@ def _copy_script_and_other_resources_in_importable_dir(
pass

# Copy dataset.py file in hash folder if needed
if not os.path.exists(importable_local_file):
shutil.copyfile(original_local_path, importable_local_file)
if not os.path.exists(importable_file):
shutil.copyfile(original_local_path, importable_file)
# Record metadata associating original dataset path with local unique folder
# Use os.path.splitext to split extension from importable_local_file
meta_path = os.path.splitext(importable_local_file)[0] + ".json"
meta_path = os.path.splitext(importable_file)[0] + ".json"
if not os.path.exists(meta_path):
meta = {"original file path": original_local_path, "local file path": importable_local_file}
meta = {"original file path": original_local_path, "local file path": importable_file}
# the filename is *.py in our case, so better rename to filename.json instead of filename.py.json
with open(meta_path, "w", encoding="utf-8") as meta_file:
json.dump(meta, meta_file)
@@ -437,7 +438,7 @@ def _copy_script_and_other_resources_in_importable_dir(
original_path, destination_additional_path
):
shutil.copyfile(original_path, destination_additional_path)
return importable_local_file
return importable_file


def _get_importable_file_path(
@@ -447,7 +448,7 @@ def _get_importable_file_path(
name: str,
) -> str:
importable_directory_path = os.path.join(dynamic_modules_path, module_namespace, name.replace("/", "--"))
return os.path.join(importable_directory_path, subdirectory_name, name + ".py")
return os.path.join(importable_directory_path, subdirectory_name, name.split("/")[-1] + ".py")


def _create_importable_file(
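
The switch to name.split("/")[-1] matters for namespaced dataset names: the directory component still encodes the full repo id with "--", but the module file itself is now named after the last path segment. A hypothetical example of the resulting path (the cache location is an assumption, not taken from this diff):

import os

dynamic_modules_path = os.path.expanduser("~/.cache/huggingface/modules/datasets_modules")  # assumed default
name = "user/my_dataset"        # hypothetical namespaced dataset
subdirectory_name = "0123abcd"  # hash of the script

importable_directory_path = os.path.join(dynamic_modules_path, "datasets", name.replace("/", "--"))
importable_file_path = os.path.join(importable_directory_path, subdirectory_name, name.split("/")[-1] + ".py")
print(importable_file_path)
# -> .../datasets_modules/datasets/user--my_dataset/0123abcd/my_dataset.py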
@@ -692,6 +693,7 @@ class DatasetModule:
builder_kwargs: dict
builder_configs_parameters: BuilderConfigsParameters = field(default_factory=BuilderConfigsParameters)
dataset_infos: Optional[DatasetInfosDict] = None
importable_file_path: Optional[str] = None


@dataclass
@@ -983,7 +985,7 @@ def get_module(self) -> DatasetModule:
# make the new module to be noticed by the import system
importlib.invalidate_caches()
builder_kwargs = {"base_path": str(Path(self.path).parent)}
return DatasetModule(module_path, hash, builder_kwargs)
return DatasetModule(module_path, hash, builder_kwargs, importable_file_path=importable_file_path)


class LocalDatasetModuleFactoryWithoutScript(_DatasetModuleFactory):
@@ -1536,7 +1538,7 @@ def get_module(self) -> DatasetModule:
"base_path": hf_hub_url(self.name, "", revision=self.revision).rstrip("/"),
"repo_id": self.name,
}
return DatasetModule(module_path, hash, builder_kwargs)
return DatasetModule(module_path, hash, builder_kwargs, importable_file_path=importable_file_path)


class CachedDatasetModuleFactory(_DatasetModuleFactory):
@@ -1582,21 +1584,24 @@ def _get_modification_time(module_hash):
if not config.HF_DATASETS_OFFLINE:
warning_msg += ", or remotely on the Hugging Face Hub."
logger.warning(warning_msg)
# make the new module to be noticed by the import system
module_path = ".".join(
[
os.path.basename(dynamic_modules_path),
"datasets",
self.name.replace("/", "--"),
hash,
self.name.split("/")[-1],
]
importable_file_path = _get_importable_file_path(
dynamic_modules_path=dynamic_modules_path,
module_namespace="datasets",
subdirectory_name=hash,
name=self.name,
)
module_path, hash = _load_importable_file(
dynamic_modules_path=dynamic_modules_path,
module_namespace="datasets",
subdirectory_name=hash,
name=self.name,
)
# make the new module to be noticed by the import system
importlib.invalidate_caches()
builder_kwargs = {
"repo_id": self.name,
}
return DatasetModule(module_path, hash, builder_kwargs)
return DatasetModule(module_path, hash, builder_kwargs, importable_file_path=importable_file_path)
cache_dir = os.path.expanduser(str(self.cache_dir or config.HF_DATASETS_CACHE))
cached_datasets_directory_path_root = os.path.join(cache_dir, self.name.replace("/", "___"))
cached_directory_paths = [
14 changes: 8 additions & 6 deletions src/datasets/streaming.py
@@ -31,7 +31,7 @@
)
from .utils.logging import get_logger
from .utils.patching import patch_submodule
from .utils.py_utils import get_imports
from .utils.py_utils import get_imports, lock_importable_file


logger = get_logger(__name__)
@@ -120,11 +120,13 @@ def extend_dataset_builder_for_streaming(builder: "DatasetBuilder"):
extend_module_for_streaming(builder.__module__, download_config=download_config)
# if needed, we also have to extend additional internal imports (like wmt14 -> wmt_utils)
if not builder.__module__.startswith("datasets."): # check that it's not a packaged builder like csv
for imports in get_imports(inspect.getfile(builder.__class__)):
if imports[0] == "internal":
internal_import_name = imports[1]
internal_module_name = ".".join(builder.__module__.split(".")[:-1] + [internal_import_name])
extend_module_for_streaming(internal_module_name, download_config=download_config)
importable_file = inspect.getfile(builder.__class__)
with lock_importable_file(importable_file):
for imports in get_imports(importable_file):
if imports[0] == "internal":
internal_import_name = imports[1]
internal_module_name = ".".join(builder.__module__.split(".")[:-1] + [internal_import_name])
extend_module_for_streaming(internal_module_name, download_config=download_config)

# builders can inherit from other builders that might use streaming functionality
# (for example, ImageFolder and AudioFolder inherit from FolderBuilder which implements examples generation)
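
The lock here covers the window in which get_imports re-reads the builder's script from disk; with force_redownload, another process may be deleting and re-copying that same file. A generic sketch of the pattern (simplified, not the datasets API, and locking next to the file rather than two levels up as lock_importable_file does):

import inspect
from filelock import FileLock

def read_defining_source(obj) -> str:
    # Locate the .py file that defines obj's class, then read it while holding
    # a lock derived from that path, so a concurrent rewrite cannot race the read.
    path = inspect.getfile(type(obj))
    with FileLock(path + ".lock"):
        with open(path, encoding="utf-8") as f:
            return f.read()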
11 changes: 11 additions & 0 deletions src/datasets/utils/py_utils.py
@@ -27,6 +27,7 @@
from contextlib import contextmanager
from dataclasses import fields, is_dataclass
from multiprocessing import Manager
from pathlib import Path
from queue import Empty
from shutil import disk_usage
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union
@@ -47,6 +48,7 @@
dumps,
pklregister,
)
from ._filelock import FileLock


try: # pragma: no branch
@@ -537,6 +539,15 @@ def _convert_github_url(url_path: str) -> Tuple[str, Optional[str]]:
return url_path, sub_directory


def lock_importable_file(importable_local_file: str) -> FileLock:
# Check the directory with a unique name in our dataset folder
# path is: ./datasets/dataset_name/hash_from_code/script.py
# we use a hash as subdirectory_name to be able to have multiple versions of a dataset/metric processing file together
importable_directory_path = str(Path(importable_local_file).resolve().parent.parent)
lock_path = importable_directory_path + ".lock"
return FileLock(lock_path)


def get_imports(file_path: str) -> Tuple[str, str, str, str]:
"""Find whether we should import or clone additional files for a given processing script.
And list the import.
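
A short usage sketch of the lock_importable_file helper added above (the script path is illustrative): because the lock path is derived from parent.parent of the script, the lock file lands next to the dataset's module directory, e.g. .../datasets/my_dataset.lock, which is the same lock _copy_script_and_other_resources_in_importable_dir now takes.

import os
from datasets.utils.py_utils import lock_importable_file

# Illustrative layout: hash subdirectory inside the dataset's module directory.
script = "/tmp/dynamic_modules/datasets/my_dataset/0123abcd/my_dataset.py"
os.makedirs(os.path.dirname(script), exist_ok=True)

with lock_importable_file(script):  # lock file: .../datasets/my_dataset.lock
    # Safe to copy or import the script here: other datasets processes,
    # including ones running with force_redownload, block on the same lock.
    pass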
23 changes: 23 additions & 0 deletions tests/test_load.py
@@ -54,6 +54,7 @@
assert_arrow_memory_doesnt_increase,
assert_arrow_memory_increases,
offline,
require_not_windows,
require_pil,
require_sndfile,
set_current_working_directory_to_temp_dir,
@@ -1674,6 +1675,28 @@ def test_load_dataset_distributed(tmp_path, csv_path):
assert all(dataset.cache_files == datasets[0].cache_files for dataset in datasets)


def distributed_load_dataset_with_script(args):
data_name, tmp_dir, download_mode = args
dataset = load_dataset(data_name, cache_dir=tmp_dir, download_mode=download_mode)
return dataset


@require_not_windows # windows doesn't support overwriting Arrow files from other processes
@pytest.mark.parametrize("download_mode", [None, "force_redownload"])
def test_load_dataset_distributed_with_script(tmp_path, download_mode):
# we need to check in the "force_redownload" case
# since in `_copy_script_and_other_resources_in_importable_dir()` we might delete the directory
# containing the .py file while the other processes use it
num_workers = 5
args = (SAMPLE_DATASET_IDENTIFIER, str(tmp_path), download_mode)
with Pool(processes=num_workers) as pool: # start num_workers processes
datasets = pool.map(distributed_load_dataset_with_script, [args] * num_workers)
assert len(datasets) == num_workers
assert all(len(dataset) == len(datasets[0]) > 0 for dataset in datasets)
assert len(datasets[0].cache_files) > 0
assert all(dataset.cache_files == datasets[0].cache_files for dataset in datasets)
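
To exercise just this scenario locally, the new test can be selected directly, e.g. pytest tests/test_load.py::test_load_dataset_distributed_with_script; it is skipped on Windows and runs both the None and force_redownload parametrizations.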


def test_load_dataset_with_storage_options(mockfs):
with mockfs.open("data.txt", "w") as f:
f.write("Hello there\n")
