test_intent.py
import csv
import json
import pickle
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Dict

import torch
from torch.utils.data import DataLoader

from dataset import SeqClsDataset
from model import SeqClassifier
from utils import Vocab

def main(args):
    with open(args.cache_dir / "vocab.pkl", "rb") as f:
        vocab: Vocab = pickle.load(f)

    intent_idx_path = args.cache_dir / "intent2idx.json"
    intent2idx: Dict[str, int] = json.loads(intent_idx_path.read_text())

    data = json.loads(args.test_file.read_text())
    dataset = SeqClsDataset(data, vocab, intent2idx, args.max_len)

    # create a DataLoader for the test dataset
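    # (a sketch only: it assumes SeqClsDataset provides a collate_fn that pads
    # each batch and keeps the sample ids; dataset.py is not shown in this file)
    test_loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=dataset.collate_fn,
    )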

    embeddings = torch.load(args.cache_dir / "embeddings.pt")

    model = SeqClassifier(
        embeddings,
        args.hidden_size,
        args.num_layers,
        args.dropout,
        args.bidirectional,
        dataset.num_classes,
    ).to(args.device)

    # load weights into model (assumes the checkpoint holds a plain state_dict
    # saved with torch.save(model.state_dict(), ...)), then switch to eval mode
    ckpt = torch.load(args.ckpt_path, map_location=args.device)
    model.load_state_dict(ckpt)
    model.eval()
    # predict on the test set and write predictions to args.pred_file
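    # (a sketch only: it assumes each batch is a dict with padded token ids
    # under "text" and sample ids under "id", that the model returns class
    # logits, and that SeqClsDataset exposes an idx2label() mapping)
    predictions = []
    with torch.no_grad():
        for batch in test_loader:
            logits = model(batch["text"].to(args.device))
            pred_idxs = logits.argmax(dim=-1).cpu().tolist()
            for sample_id, pred_idx in zip(batch["id"], pred_idxs):
                predictions.append((sample_id, dataset.idx2label(pred_idx)))

    # assumed submission format: a CSV with "id" and "intent" columns,
    # matching the default pred.intent.csv file name
    with open(args.pred_file, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["id", "intent"])
        writer.writerows(predictions)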

def parse_args() -> Namespace:
    parser = ArgumentParser()
    parser.add_argument(
        "--test_file",
        type=Path,
        help="Path to the test file.",
        required=True,
    )
    parser.add_argument(
        "--cache_dir",
        type=Path,
        help="Directory to the preprocessed caches.",
        default="./cache/intent/",
    )
    parser.add_argument(
        "--ckpt_path",
        type=Path,
        help="Path to model checkpoint.",
        required=True,
    )
    parser.add_argument("--pred_file", type=Path, default="pred.intent.csv")

    # data
    parser.add_argument("--max_len", type=int, default=128)

    # model
    parser.add_argument("--hidden_size", type=int, default=512)
    parser.add_argument("--num_layers", type=int, default=2)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--bidirectional", type=bool, default=True)

    # data loader
    parser.add_argument("--batch_size", type=int, default=128)

    parser.add_argument(
        "--device", type=torch.device, help="cpu, cuda, cuda:0, cuda:1", default="cpu"
    )

    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = parse_args()
    main(args)