forked from gauthierdmn/question_answering
-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathdata_loader.py
23 lines (18 loc) · 862 Bytes
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# external libraries
import torch.utils.data as data
class SquadDataset(data.Dataset):
    """Custom Dataset for SQuAD data compatible with torch.utils.data.DataLoader.

    Wraps five parallel, index-aligned sequences so that item ``i`` is the
    tuple formed by the ``i``-th element of each sequence.
    """

    def __init__(self, w_context, c_context, w_question, c_question, labels):
        """Store the parallel sequences of contexts, questions and labels.

        Args:
            w_context: word context inputs, one entry per example.
            c_context: character context inputs, one entry per example.
            w_question: word question inputs, one entry per example.
            c_question: character question inputs, one entry per example.
            labels: answer labels, one entry per example (format is defined
                by the caller; presumably answer span positions — TODO confirm).

        Raises:
            ValueError: if the five sequences do not all have the same length.
                Checking here surfaces misaligned data at construction time
                instead of as an IndexError midway through a DataLoader epoch.
        """
        lengths = {
            "w_context": len(w_context),
            "c_context": len(c_context),
            "w_question": len(w_question),
            "c_question": len(c_question),
            "labels": len(labels),
        }
        if len(set(lengths.values())) > 1:
            raise ValueError(
                f"SquadDataset inputs must all have the same length, got {lengths}"
            )

        self.w_context = w_context
        self.c_context = c_context
        self.w_question = w_question
        self.c_question = c_question
        self.labels = labels

    def __getitem__(self, index):
        """Return one data tuple of the form (word context, character context,
        word question, character question, answer)."""
        return (
            self.w_context[index],
            self.c_context[index],
            self.w_question[index],
            self.c_question[index],
            self.labels[index],
        )

    def __len__(self):
        """Return the number of examples (all input sequences share this length)."""
        return len(self.w_context)