diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst
index d9c32f201c..af782bca03 100644
--- a/docs/source/datasets.rst
+++ b/docs/source/datasets.rst
@@ -57,6 +57,11 @@ IMDb
 
 .. autofunction:: IMDB
 
+MRPC
+~~~~
+
+.. autofunction:: MRPC
+
 SogouNews
 ~~~~~~~~~
 
diff --git a/test/datasets/test_mrpc.py b/test/datasets/test_mrpc.py
new file mode 100644
index 0000000000..a5ed910a4e
--- /dev/null
+++ b/test/datasets/test_mrpc.py
@@ -0,0 +1,71 @@
+import os
+from collections import defaultdict
+from unittest.mock import patch
+
+from parameterized import parameterized
+from torchtext.datasets.mrpc import MRPC
+
+from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode
+from ..common.torchtext_test_case import TorchtextTestCase
+
+
+def _get_mock_dataset(root_dir):
+    """
+    root_dir: directory to the mocked dataset
+    """
+    base_dir = os.path.join(root_dir, "MRPC")
+    os.makedirs(base_dir, exist_ok=True)
+
+    seed = 1
+    mocked_data = defaultdict(list)
+    for file_name, file_type in [("msr_paraphrase_train.txt", "train"), ("msr_paraphrase_test.txt", "test")]:
+        txt_file = os.path.join(base_dir, file_name)
+        with open(txt_file, "w", encoding="utf-8") as f:
+            f.write("Quality\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
+            for i in range(5):
+                label = seed % 2
+                rand_string_1 = get_random_unicode(seed)
+                rand_string_2 = get_random_unicode(seed + 1)
+                dataset_line = (label, rand_string_1, rand_string_2)
+                f.write(f"{label}\t{i}\t{i}\t{rand_string_1}\t{rand_string_2}\n")
+
+                # append line to correct dataset split
+                mocked_data[file_type].append(dataset_line)
+                seed += 1
+
+    return mocked_data
+
+
+class TestMRPC(TempDirMixin, TorchtextTestCase):
+    root_dir = None
+    samples = []
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        cls.root_dir = cls.get_base_temp_dir()
+        cls.samples = _get_mock_dataset(cls.root_dir)
+        cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True)
+        cls.patcher.start()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.patcher.stop()
+        super().tearDownClass()
+
+    @parameterized.expand(["train", "test"])
+    def test_mrpc(self, split):
+        dataset = MRPC(root=self.root_dir, split=split)
+
+        samples = list(dataset)
+        expected_samples = self.samples[split]
+        for sample, expected_sample in zip_equal(samples, expected_samples):
+            self.assertEqual(sample, expected_sample)
+
+    @parameterized.expand(["train", "test"])
+    def test_mrpc_split_argument(self, split):
+        dataset1 = MRPC(root=self.root_dir, split=split)
+        (dataset2,) = MRPC(root=self.root_dir, split=(split,))
+
+        for d1, d2 in zip_equal(dataset1, dataset2):
+            self.assertEqual(d1, d2)
diff --git a/torchtext/datasets/__init__.py b/torchtext/datasets/__init__.py
index 29dcc5f165..b228e773b4 100644
--- a/torchtext/datasets/__init__.py
+++ b/torchtext/datasets/__init__.py
@@ -11,6 +11,7 @@
 from .imdb import IMDB
 from .iwslt2016 import IWSLT2016
 from .iwslt2017 import IWSLT2017
+from .mrpc import MRPC
 from .multi30k import Multi30k
 from .penntreebank import PennTreebank
 from .sogounews import SogouNews
@@ -36,6 +37,7 @@
     "IMDB": IMDB,
     "IWSLT2016": IWSLT2016,
     "IWSLT2017": IWSLT2017,
+    "MRPC": MRPC,
     "Multi30k": Multi30k,
     "PennTreebank": PennTreebank,
     "SQuAD1": SQuAD1,
diff --git a/torchtext/datasets/mrpc.py b/torchtext/datasets/mrpc.py
new file mode 100644
index 0000000000..c5077a195f
--- /dev/null
+++ b/torchtext/datasets/mrpc.py
@@ -0,0 +1,73 @@
+import csv
+import os
+from typing import Union, Tuple
+
+from torchtext._internal.module_utils import is_module_available
+from torchtext.data.datasets_utils import (
+    _wrap_split_argument,
+    _create_dataset_directory,
+)
+
+if is_module_available("torchdata"):
+    from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper
+
+
+URL = {
+    "train": "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt",
+    "test": "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt",
+}
+
+MD5 = {
+    "train": "793daf7b6224281e75fe61c1f80afe35",
+    "test": "e437fdddb92535b820fe8852e2df8a49",
+}
+
+NUM_LINES = {
+    "train": 4076,
+    "test": 1725,
+}
+
+
+DATASET_NAME = "MRPC"
+
+
+@_create_dataset_directory(dataset_name=DATASET_NAME)
+@_wrap_split_argument(("train", "test"))
+def MRPC(root: str, split: Union[Tuple[str], str]):
+    """MRPC Dataset
+
+    For additional details refer to https://www.microsoft.com/en-us/download/details.aspx?id=52398
+
+    Number of lines per split:
+        - train: 4076
+        - test: 1725
+
+    Args:
+        root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache')
+        split: split or splits to be returned. Can be a string or tuple of strings. Default: (`train`, `test`)
+
+    :returns: DataPipe that yields data points from MRPC dataset which consist of label, sentence1, sentence2
+    :rtype: (int, str, str)
+    """
+    if not is_module_available("torchdata"):
+        raise ModuleNotFoundError(
+            "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
+        )
+
+    def _filepath_fn(x):
+        return os.path.join(root, os.path.basename(x))
+
+    def _modify_res(x):
+        return (int(x[0]), x[3], x[4])
+
+    url_dp = IterableWrapper([URL[split]])
+    # cache data on-disk with sanity check
+    cache_dp = url_dp.on_disk_cache(
+        filepath_fn=_filepath_fn,
+        hash_dict={_filepath_fn(URL[split]): MD5[split]},
+        hash_type="md5",
+    )
+    cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
+    cache_dp = FileOpener(cache_dp, encoding="utf-8")
+    cache_dp = cache_dp.parse_csv(skip_lines=1, delimiter="\t", quoting=csv.QUOTE_NONE).map(_modify_res)
+    return cache_dp.shuffle().set_shuffle(False).sharding_filter()
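
For reference, a minimal usage sketch of the new entry point (not part of the patch). It assumes only what the diff documents: `torchdata` is installed, `MRPC` returns one DataPipe per requested split, and each sample is a `(label, sentence1, sentence2)` tuple of types `(int, str, str)`.

    # Usage sketch for the new MRPC dataset (illustrative, not part of this diff).
    # Data is downloaded and cached under the default root on first iteration.
    from torchtext.datasets import MRPC

    train_dp = MRPC(split="train")  # a single split string returns a single DataPipe
    for label, sentence1, sentence2 in train_dp:
        # label is an int (MRPC "Quality" column); the sentences are raw strings
        print(label, sentence1, sentence2)
        break

    # Passing a tuple of splits returns one DataPipe per split.
    train_dp, test_dp = MRPC(split=("train", "test"))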