Add support for MNLI dataset with unit tests (#1715)
* Add support for MNLI + add tests
* Adjust dataset size docstring
* Remove lambda functions
* Add dataset documentation
* Add shuffle and sharding
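For context, a minimal usage sketch of the dataset added by this commit. It assumes the default cache root and imports MNLI directly from the new module; whether MNLI is also re-exported from torchtext.datasets depends on a change not shown in this view.

from torchtext.datasets.mnli import MNLI

# split accepts a single split name or a tuple of names
train_dp, dev_matched_dp = MNLI(split=("train", "dev_matched"))

for label, premise, hypothesis in train_dp:
    # label is an integer: 0 = entailment, 1 = neutral, 2 = contradiction
    print(label, premise, hypothesis)
    break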
Showing 4 changed files with 188 additions and 0 deletions.
Documentation (dataset listing, reStructuredText):

@@ -57,6 +57,11 @@ IMDb

.. autofunction:: IMDB

MNLI
~~~~

.. autofunction:: MNLI

MRPC
~~~~
New unit test for the MNLI dataset:

@@ -0,0 +1,83 @@
import os
import zipfile
from collections import defaultdict
from unittest.mock import patch

from parameterized import parameterized
from torchtext.datasets.mnli import MNLI

from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode
from ..common.torchtext_test_case import TorchtextTestCase


def _get_mock_dataset(root_dir):
    """
    root_dir: directory to the mocked dataset
    """
    base_dir = os.path.join(root_dir, "MNLI")
    temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir")
    os.makedirs(temp_dataset_dir, exist_ok=True)

    seed = 1
    mocked_data = defaultdict(list)
    for file_name in ["multinli_1.0_train.txt", "multinli_1.0_dev_matched.txt", "multinli_1.0_dev_mismatched.txt"]:
        # map the file name back to the split name used by MNLI(): train / dev_matched / dev_mismatched
        split_name = os.path.splitext(file_name)[0].replace("multinli_1.0_", "")
        txt_file = os.path.join(temp_dataset_dir, file_name)
        with open(txt_file, "w", encoding="utf-8") as f:
            f.write(
                "gold_label\tsentence1_binary_parse\tsentence2_binary_parse\tsentence1_parse\tsentence2_parse\tsentence1\tsentence2\tpromptID\tpairID\tgenre\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n"
            )
            for i in range(5):
                label = seed % 3
                # the TSV stores the textual gold label; MNLI() maps it back to this
                # integer via LABEL_TO_INT, so the expected sample keeps the int form
                gold_label = ["entailment", "neutral", "contradiction"][label]
                rand_string = get_random_unicode(seed)
                dataset_line = (label, rand_string, rand_string)
                f.write(
                    f"{gold_label}\t{rand_string}\t{rand_string}\t{rand_string}\t{rand_string}\t{rand_string}\t{rand_string}\t{i}\t{i}\t{i}\t{i}\t{i}\t{i}\t{i}\t{i}\n"
                )

                # append line to correct dataset split
                mocked_data[split_name].append(dataset_line)
                seed += 1

    compressed_dataset_path = os.path.join(base_dir, "multinli_1.0.zip")
    # create zip file from dataset folder
    with zipfile.ZipFile(compressed_dataset_path, "w") as zip_file:
        for file_name in ("multinli_1.0_train.txt", "multinli_1.0_dev_matched.txt", "multinli_1.0_dev_mismatched.txt"):
            txt_file = os.path.join(temp_dataset_dir, file_name)
            zip_file.write(txt_file, arcname=os.path.join("multinli_1.0", file_name))

    return mocked_data

class TestMNLI(TempDirMixin, TorchtextTestCase):
    root_dir = None
    samples = []

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.root_dir = cls.get_base_temp_dir()
        cls.samples = _get_mock_dataset(cls.root_dir)
        cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True)
        cls.patcher.start()

    @classmethod
    def tearDownClass(cls):
        cls.patcher.stop()
        super().tearDownClass()

    @parameterized.expand(["train", "dev_matched", "dev_mismatched"])
    def test_mnli(self, split):
        dataset = MNLI(root=self.root_dir, split=split)

        samples = list(dataset)
        expected_samples = self.samples[split]
        for sample, expected_sample in zip_equal(samples, expected_samples):
            self.assertEqual(sample, expected_sample)

    @parameterized.expand(["train", "dev_matched", "dev_mismatched"])
    def test_mnli_split_argument(self, split):
        dataset1 = MNLI(root=self.root_dir, split=split)
        (dataset2,) = MNLI(root=self.root_dir, split=(split,))

        for d1, d2 in zip_equal(dataset1, dataset2):
            self.assertEqual(d1, d2)
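The test relies on helpers from ..common.case_utils that are not part of this diff. As a rough sketch of the behavior the test depends on (the names are real, but the implementations below are assumed and only illustrative, not torchtext's source):

import random

_SENTINEL = object()


def zip_equal(*iterables):
    # behaves like zip(), but raises if the iterables have different lengths,
    # so a missing or extra sample fails the test instead of being silently ignored
    iterators = [iter(it) for it in iterables]
    while True:
        items = [next(it, _SENTINEL) for it in iterators]
        if all(item is _SENTINEL for item in items):
            return
        if any(item is _SENTINEL for item in items):
            raise ValueError("iterables have different lengths")
        yield tuple(items)


def get_random_unicode(seed, length=10):
    # deterministic pseudo-random unicode string; seeding keeps the mock reproducible
    rng = random.Random(seed)
    return "".join(chr(rng.randint(0x41, 0x24F)) for _ in range(length))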
New dataset module (torchtext/datasets/mnli.py):

@@ -0,0 +1,98 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import csv
import os

from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
    _create_dataset_directory,
    _wrap_split_argument,
)

if is_module_available("torchdata"):
    from torchdata.datapipes.iter import FileOpener, IterableWrapper

    # we import HttpReader from _download_hooks so we can swap out public URLs
    # with internal URLs when the dataset is used within Facebook
    from torchtext._download_hooks import HttpReader


URL = "https://cims.nyu.edu/~sbowman/multinli/multinli_1.0.zip"

MD5 = "0f70aaf66293b3c088a864891db51353"

NUM_LINES = {
    "train": 392702,
    "dev_matched": 9815,
    "dev_mismatched": 9832,
}

_PATH = "multinli_1.0.zip"

DATASET_NAME = "MNLI"

_EXTRACTED_FILES = {
    "train": "multinli_1.0_train.txt",
    "dev_matched": "multinli_1.0_dev_matched.txt",
    "dev_mismatched": "multinli_1.0_dev_mismatched.txt",
}

LABEL_TO_INT = {"entailment": 0, "neutral": 1, "contradiction": 2}


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "dev_matched", "dev_mismatched"))
def MNLI(root, split):
    """MNLI Dataset

    For additional details refer to https://cims.nyu.edu/~sbowman/multinli/

    Number of lines per split:
        - train: 392702
        - dev_matched: 9815
        - dev_mismatched: 9832

    Args:
        root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache')
        split: split or splits to be returned. Can be a string or tuple of strings. Default: (`train`, `dev_matched`, `dev_mismatched`)

    :returns: DataPipe that yields tuples of (label, premise, hypothesis), where the label is an integer from 0 to 2.
    :rtype: Tuple[int, str, str]
    """
    # TODO Remove this after removing conditional dependency
    if not is_module_available("torchdata"):
        raise ModuleNotFoundError(
            "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
        )

    def _filepath_fn(x=None):
        return os.path.join(root, os.path.basename(x))

    def _extracted_filepath_fn(_=None):
        return os.path.join(root, _EXTRACTED_FILES[split])

    def _filter_fn(x):
        return _EXTRACTED_FILES[split] in x[0]

    def _filter_res(x):
        return x[0] in LABEL_TO_INT

    def _modify_res(x):
        return (LABEL_TO_INT[x[0]], x[5], x[6])

    url_dp = IterableWrapper([URL])
    cache_compressed_dp = url_dp.on_disk_cache(
        filepath_fn=_filepath_fn,
        hash_dict={_filepath_fn(URL): MD5},
        hash_type="md5",
    )
    cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_zip().filter(_filter_fn)
    cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

    data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
    parsed_data = (
        data_dp.parse_csv(skip_lines=1, delimiter="\t", quoting=csv.QUOTE_NONE).filter(_filter_res).map(_modify_res)
    )
    return parsed_data.shuffle().set_shuffle(False).sharding_filter()
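The trailing .shuffle().set_shuffle(False).sharding_filter() registers shuffling and sharding points in the pipe without shuffling by default; the consumer decides when to enable them. A minimal sketch of consuming the datapipe through DataLoader, assuming the DataLoader/datapipe integration that ships with torchdata at this time (batch size and worker count are arbitrary):

from torch.utils.data import DataLoader
from torchtext.datasets.mnli import MNLI

train_dp = MNLI(split="train")
# shuffle=True flips the switch left off by set_shuffle(False);
# sharding_filter() splits the stream across the workers
loader = DataLoader(train_dp, batch_size=32, shuffle=True, num_workers=2)

for labels, premises, hypotheses in loader:
    # default collation: labels become a tensor, premises/hypotheses lists of strings
    break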