-
Notifications
You must be signed in to change notification settings - Fork 812
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for MRPC dataset with unit tests (#1712)
* Add support for MRPC dataset
* Add unit tests
* Remove lambda functions
* Add dataset documentation
* Add shuffle and sharding
- Loading branch information
Showing
4 changed files
with
151 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -57,6 +57,11 @@ IMDb | |
|
||
.. autofunction:: IMDB | ||
|
||
MRPC | ||
~~~~ | ||
|
||
.. autofunction:: MRPC | ||
|
||
SogouNews | ||
~~~~~~~~~ | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import os | ||
from collections import defaultdict | ||
from unittest.mock import patch | ||
|
||
from parameterized import parameterized | ||
from torchtext.datasets.mrpc import MRPC | ||
|
||
from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode | ||
from ..common.torchtext_test_case import TorchtextTestCase | ||
|
||
|
||
def _get_mock_dataset(root_dir):
    """Write mocked MRPC data files under *root_dir* and return the expected samples.

    root_dir: directory in which the mocked dataset files are created.
    Returns a dict mapping split name ("train"/"test") to the list of
    (label, sentence1, sentence2) tuples that the dataset should yield.
    """
    base_dir = os.path.join(root_dir, "MRPC")
    os.makedirs(base_dir, exist_ok=True)

    seed = 1
    mocked_data = defaultdict(list)
    for file_name, split in (("msr_paraphrase_train.txt", "train"), ("msr_paraphrase_test.txt", "test")):
        txt_path = os.path.join(base_dir, file_name)
        with open(txt_path, "w", encoding="utf-8") as out:
            # Header row mirrors the real MSR paraphrase corpus layout.
            out.write("Quality\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
            for row_id in range(5):
                label = seed % 2
                sentence1 = get_random_unicode(seed)
                sentence2 = get_random_unicode(seed + 1)
                out.write(f"{label}\t{row_id}\t{row_id}\t{sentence1}\t{sentence2}\n")
                # Record the parsed form expected from the dataset for this split.
                mocked_data[split].append((label, sentence1, sentence2))
                seed += 1

    return mocked_data
|
||
|
||
class TestMRPC(TempDirMixin, TorchtextTestCase):
    """Tests the MRPC dataset pipeline against a mocked on-disk corpus."""

    root_dir = None
    samples = []

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.root_dir = cls.get_base_temp_dir()
        cls.samples = _get_mock_dataset(cls.root_dir)
        # Bypass torchdata's MD5 validation so the mocked files pass the
        # on-disk cache sanity check.
        cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True)
        cls.patcher.start()

    @classmethod
    def tearDownClass(cls):
        cls.patcher.stop()
        super().tearDownClass()

    @parameterized.expand(["train", "test"])
    def test_mrpc(self, split):
        """Dataset yields exactly the mocked samples, in order."""
        dataset = MRPC(root=self.root_dir, split=split)

        samples = list(dataset)
        expected_samples = self.samples[split]
        for sample, expected_sample in zip_equal(samples, expected_samples):
            self.assertEqual(sample, expected_sample)

    @parameterized.expand(["train", "test"])
    def test_mrpc_split_argument(self, split):
        """A string split and a 1-tuple split produce identical datasets.

        Fix: method was named ``test_sst2_split_argument`` — a copy-paste
        leftover from the SST2 test suite; renamed to match this dataset.
        """
        dataset1 = MRPC(root=self.root_dir, split=split)
        (dataset2,) = MRPC(root=self.root_dir, split=(split,))

        for d1, d2 in zip_equal(dataset1, dataset2):
            self.assertEqual(d1, d2)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import csv | ||
import os | ||
from typing import Union, Tuple | ||
|
||
from torchtext._internal.module_utils import is_module_available | ||
from torchtext.data.datasets_utils import ( | ||
_wrap_split_argument, | ||
_create_dataset_directory, | ||
) | ||
|
||
if is_module_available("torchdata"): | ||
from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper | ||
|
||
|
||
# Per-split download URLs for the MSR Paraphrase Corpus text files.
URL = {
    "train": "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt",
    "test": "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt",
}

# MD5 checksums of the files above; consumed by the on-disk cache sanity check.
MD5 = {
    "train": "793daf7b6224281e75fe61c1f80afe35",
    "test": "e437fdddb92535b820fe8852e2df8a49",
}

# Number of data rows per split (header line excluded).
NUM_LINES = {
    "train": 4076,
    "test": 1725,
}


DATASET_NAME = "MRPC"
|
||
|
||
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "test"))
def MRPC(root: str, split: Union[Tuple[str], str]):
    """MRPC Dataset

    For additional details refer to https://www.microsoft.com/en-us/download/details.aspx?id=52398

    Number of lines per split:
        - train: 4076
        - test: 1725

    Args:
        root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache')
        split: split or splits to be returned. Can be a string or tuple of strings. Default: (`train`, `test`)

    :returns: DataPipe that yields data points from MRPC dataset which consist of label, sentence1, sentence2
    :rtype: (int, str, str)
    """
    if not is_module_available("torchdata"):
        raise ModuleNotFoundError(
            "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
        )

    def _cache_path(url):
        # Local cache location: file name taken from the URL, placed under root.
        return os.path.join(root, os.path.basename(url))

    def _to_sample(row):
        # Keep (label, sentence1, sentence2); columns 1-2 are the sentence IDs.
        return (int(row[0]), row[3], row[4])

    source_dp = IterableWrapper([URL[split]])
    # Cache the raw file on disk, validated against the known MD5 checksum.
    cached_dp = source_dp.on_disk_cache(
        filepath_fn=_cache_path,
        hash_dict={_cache_path(URL[split]): MD5[split]},
        hash_type="md5",
    )
    cached_dp = HttpReader(cached_dp).end_caching(mode="wb", same_filepath_fn=True)
    opened_dp = FileOpener(cached_dp, encoding="utf-8")
    parsed_dp = opened_dp.parse_csv(skip_lines=1, delimiter="\t", quoting=csv.QUOTE_NONE).map(_to_sample)
    return parsed_dp.shuffle().set_shuffle(False).sharding_filter()