Add support for MRPC dataset with unit tests (#1712)
* Add support for MRPC dataset
* Add unit tests
* Remove lambda functions
* Add dataset documentation
* Add shuffle and sharding
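
Together these changes expose MRPC through the standard torchtext dataset API. A minimal usage sketch (illustrative only; it assumes a torchtext build that includes this commit and a working torchdata installation):

from torchtext.datasets import MRPC

train_dp, test_dp = MRPC(split=("train", "test"))
label, sentence1, sentence2 = next(iter(train_dp))  # each sample is (int, str, str)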
vcm2114 authored May 18, 2022
1 parent ec20f88 commit bb41e4f
Showing 4 changed files with 151 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/source/datasets.rst
@@ -57,6 +57,11 @@ IMDb

.. autofunction:: IMDB

MRPC
~~~~

.. autofunction:: MRPC

SogouNews
~~~~~~~~~

71 changes: 71 additions & 0 deletions test/datasets/test_mrpc.py
@@ -0,0 +1,71 @@
import os
from collections import defaultdict
from unittest.mock import patch

from parameterized import parameterized
from torchtext.datasets.mrpc import MRPC

from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode
from ..common.torchtext_test_case import TorchtextTestCase


def _get_mock_dataset(root_dir):
    """
    root_dir: directory to which the mocked dataset is written
    """
    base_dir = os.path.join(root_dir, "MRPC")
    os.makedirs(base_dir, exist_ok=True)

    seed = 1
    mocked_data = defaultdict(list)
    for file_name, file_type in [("msr_paraphrase_train.txt", "train"), ("msr_paraphrase_test.txt", "test")]:
        txt_file = os.path.join(base_dir, file_name)
        with open(txt_file, "w", encoding="utf-8") as f:
            # MRPC files are tab-separated with a single header row
            f.write("Quality\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
            for i in range(5):
                label = seed % 2
                rand_string_1 = get_random_unicode(seed)
                rand_string_2 = get_random_unicode(seed + 1)
                dataset_line = (label, rand_string_1, rand_string_2)
                f.write(f"{label}\t{i}\t{i}\t{rand_string_1}\t{rand_string_2}\n")

                # append line to correct dataset split
                mocked_data[file_type].append(dataset_line)
                seed += 1

    return mocked_data


class TestMRPC(TempDirMixin, TorchtextTestCase):
    root_dir = None
    samples = []

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.root_dir = cls.get_base_temp_dir()
        cls.samples = _get_mock_dataset(cls.root_dir)
        # bypass the MD5 sanity check so the mocked files are accepted by the on-disk cache
        cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True)
        cls.patcher.start()

    @classmethod
    def tearDownClass(cls):
        cls.patcher.stop()
        super().tearDownClass()

    @parameterized.expand(["train", "test"])
    def test_mrpc(self, split):
        dataset = MRPC(root=self.root_dir, split=split)

        samples = list(dataset)
        expected_samples = self.samples[split]
        for sample, expected_sample in zip_equal(samples, expected_samples):
            self.assertEqual(sample, expected_sample)

    @parameterized.expand(["train", "test"])
    def test_mrpc_split_argument(self, split):
        dataset1 = MRPC(root=self.root_dir, split=split)
        (dataset2,) = MRPC(root=self.root_dir, split=(split,))

        for d1, d2 in zip_equal(dataset1, dataset2):
            self.assertEqual(d1, d2)
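
Note: these tests never hit the network; they write mock msr_paraphrase_*.txt files into a temp directory and patch torchdata's _hash_check so the cached files pass the MD5 sanity check. Assuming the repository's usual pytest setup, they can be run with, e.g., pytest test/datasets/test_mrpc.py.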
2 changes: 2 additions & 0 deletions torchtext/datasets/__init__.py
@@ -11,6 +11,7 @@
from .imdb import IMDB
from .iwslt2016 import IWSLT2016
from .iwslt2017 import IWSLT2017
from .mrpc import MRPC
from .multi30k import Multi30k
from .penntreebank import PennTreebank
from .sogounews import SogouNews
@@ -36,6 +37,7 @@
"IMDB": IMDB,
"IWSLT2016": IWSLT2016,
"IWSLT2017": IWSLT2017,
"MRPC": MRPC,
"Multi30k": Multi30k,
"PennTreebank": PennTreebank,
"SQuAD1": SQuAD1,
73 changes: 73 additions & 0 deletions torchtext/datasets/mrpc.py
@@ -0,0 +1,73 @@
import csv
import os
from typing import Union, Tuple

from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
    _wrap_split_argument,
    _create_dataset_directory,
)

if is_module_available("torchdata"):
    from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper


URL = {
    "train": "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt",
    "test": "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt",
}

MD5 = {
    "train": "793daf7b6224281e75fe61c1f80afe35",
    "test": "e437fdddb92535b820fe8852e2df8a49",
}

NUM_LINES = {
    "train": 4076,
    "test": 1725,
}


DATASET_NAME = "MRPC"


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "test"))
def MRPC(root: str, split: Union[Tuple[str], str]):
    """MRPC Dataset

    For additional details refer to https://www.microsoft.com/en-us/download/details.aspx?id=52398

    Number of lines per split:
        - train: 4076
        - test: 1725

    Args:
        root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache')
        split: split or splits to be returned. Can be a string or tuple of strings. Default: (`train`, `test`)

    :returns: DataPipe that yields data points from the MRPC dataset, each consisting of label, sentence1, sentence2
    :rtype: (int, str, str)
    """
    if not is_module_available("torchdata"):
        raise ModuleNotFoundError(
            "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
        )

    def _filepath_fn(x):
        # map a URL to its on-disk cache path under `root`
        return os.path.join(root, os.path.basename(x))

    def _modify_res(x):
        # keep (Quality, #1 String, #2 String) -> (label, sentence1, sentence2)
        return (int(x[0]), x[3], x[4])

    url_dp = IterableWrapper([URL[split]])
    # cache data on-disk with sanity check
    cache_dp = url_dp.on_disk_cache(
        filepath_fn=_filepath_fn,
        hash_dict={_filepath_fn(URL[split]): MD5[split]},
        hash_type="md5",
    )
    cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
    cache_dp = FileOpener(cache_dp, encoding="utf-8")
    cache_dp = cache_dp.parse_csv(skip_lines=1, delimiter="\t", quoting=csv.QUOTE_NONE).map(_modify_res)
    # shuffle is attached but disabled by default (set_shuffle(False)); the consumer decides
    # whether to enable it, and sharding_filter lets the pipe be sharded across workers
    return cache_dp.shuffle().set_shuffle(False).sharding_filter()
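
Because the returned datapipe ends in shuffle().set_shuffle(False).sharding_filter(), shuffling is attached but off by default, and the consuming DataLoader decides whether to enable it and how to shard across workers. A minimal consumption sketch (assuming PyTorch's DataLoader support for IterDataPipes; batch size and worker count are arbitrary):

from torch.utils.data import DataLoader
from torchtext.datasets import MRPC

train_dp = MRPC(split="train")
# shuffle=True activates the shuffle step attached above;
# num_workers > 1 relies on the trailing sharding_filter
loader = DataLoader(train_dp, batch_size=8, shuffle=True, num_workers=2)
for labels, sentence1, sentence2 in loader:
    pass  # labels: tensor of ints; sentence1/sentence2: lists of strings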
