pytorch_utils.py
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Vocab(object):
    def __init__(self, name="vocab",
                 offset_items=tuple([]),
                 UNK=None, lower=True):
        self.name = name
        self.item2idx = {}
        self.idx2item = []
        self.size = 0
        self.UNK = UNK
        self.lower = lower

        self.batch_add(offset_items, lower=False)
        if UNK is not None:
            self.add(UNK, lower=False)
            self.UNK_ID = self.item2idx[self.UNK]
        self.offset = self.size

    def add(self, item, lower=True):
        if self.lower and lower:
            item = item.lower()
        if item not in self.item2idx:
            self.item2idx[item] = self.size
            self.size += 1
            self.idx2item.append(item)

    def batch_add(self, items, lower=True):
        for item in items:
            self.add(item, lower=lower)

    def in_vocab(self, item, lower=True):
        if self.lower and lower:
            item = item.lower()
        return item in self.item2idx

    def getidx(self, item, lower=True):
        if self.lower and lower:
            item = item.lower()
        if item not in self.item2idx:
            if self.UNK is None:
                raise RuntimeError("UNK is not defined. %s not in vocab." % item)
            return self.UNK_ID
        return self.item2idx[item]

    def __repr__(self):
        return "Vocab(name={}, size={:d}, UNK={}, offset={:d}, lower={})".format(
            self.name, self.size,
            self.UNK, self.offset,
            self.lower
        )
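

# Hypothetical usage sketch for Vocab (not part of the original module): build a
# vocabulary with a padding token at index 0 and an unknown token, then look up
# known and unseen items. The token names below are illustrative only.
def _demo_vocab():
    vocab = Vocab(name="words", offset_items=("<PAD>",), UNK="<UNK>", lower=True)
    vocab.batch_add(["The", "quick", "brown", "fox"])
    print(vocab)                   # Vocab(name=words, size=6, UNK=<UNK>, offset=2, lower=True)
    print(vocab.getidx("quick"))   # index of "quick" (lowercased on lookup)
    print(vocab.getidx("unseen"))  # unseen item falls back to vocab.UNK_ID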


def load_word_vectors(vector_file, ndims, vocab, cache_file, override_cache=False):
    W = np.zeros((vocab.size, ndims), dtype="float32")
    # Check for a cached embedding matrix and return it if present
    cache_file = Path(cache_file)
    if cache_file.is_file() and not override_cache:
        W = np.load(cache_file)
        return W
    # Otherwise load vectors from the text vector file
    total, found = 0, 0
    with open(vector_file) as fp:
        for i, line in enumerate(fp):
            line = line.rstrip().split()
            if line:
                total += 1
                try:
                    assert len(line) == ndims + 1, (
                        "Line[{}] {} vector dims {} doesn't match ndims={}".format(
                            i, line[0], len(line) - 1, ndims)
                    )
                except AssertionError as e:
                    print(e)
                    continue
                word = line[0]
                idx = vocab.getidx(word)
                if idx >= vocab.offset:
                    found += 1
                vecs = np.array(list(map(float, line[1:])))
                W[idx, :] += vecs
    print("Found {} [{:.2f}%] vectors from {} vectors in {} with ndims={}".format(
        found, found * 100 / vocab.size, total, vector_file, ndims))
    # Normalize rows with non-zero norm to unit L2 length
    norm_W = np.sqrt((W * W).sum(axis=1, keepdims=True))
    valid_idx = norm_W.squeeze() != 0
    W[valid_idx, :] /= norm_W[valid_idx]
    # Write the matrix to the cache file for future runs
    print("Caching embedding with shape {} to {}".format(W.shape, cache_file.as_posix()))
    np.save(cache_file, W)
    return W
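

# Hypothetical usage sketch for load_word_vectors (not part of the original
# module). The file names are placeholders: it assumes a whitespace-separated
# text embedding file with one "word v1 ... v_ndims" entry per line (GloVe-style)
# and caches the resulting matrix as a .npy file. The loaded matrix can then
# seed an nn.Embedding layer as shown.
def _demo_load_word_vectors():
    vocab = Vocab(name="words", offset_items=("<PAD>",), UNK="<UNK>")
    vocab.batch_add(["quick", "brown", "fox"])
    W = load_word_vectors(
        vector_file="glove.6B.100d.txt",    # placeholder path
        ndims=100,
        vocab=vocab,
        cache_file="embedding_cache.npy",   # placeholder path
    )
    embedding = nn.Embedding(vocab.size, 100)
    embedding.weight.data.copy_(torch.from_numpy(W))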


class Seq2Vec(object):
    def __init__(self, vocab):
        self.vocab = vocab

    def encode(self, seq):
        vec = []
        for item in seq:
            vec.append(self.vocab.getidx(item))
        return vec

    def batch_encode(self, seq_batch):
        vecs = [self.encode(seq) for seq in seq_batch]
        return vecs
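

# Hypothetical usage sketch for Seq2Vec (not part of the original module): map
# tokenised sequences to lists of vocabulary indices, ready to be wrapped in
# tensors for a model. Unseen tokens map to the UNK index.
def _demo_seq2vec():
    vocab = Vocab(name="words", offset_items=("<PAD>",), UNK="<UNK>")
    vocab.batch_add(["the", "quick", "brown", "fox"])
    seq2vec = Seq2Vec(vocab)
    print(seq2vec.encode(["the", "fox"]))                   # e.g. [2, 5]
    print(seq2vec.batch_encode([["the", "fox"], ["dog"]]))  # "dog" maps to UNK_ID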


class Seq2OneHot(object):
    def __init__(self, size):
        self.size = size

    def encode(self, x, as_variable=False):
        one_hot = torch.zeros(self.size)
        for i in x:
            one_hot[i] += 1
        one_hot = one_hot.view(1, -1)
        if as_variable:
            return Variable(one_hot)
        return one_hot
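

# Hypothetical usage sketch for Seq2OneHot (not part of the original module):
# turn a list of indices into a (1, size) bag-of-words count vector; repeated
# indices accumulate counts rather than being clipped to 1.
def _demo_seq2onehot():
    seq2onehot = Seq2OneHot(size=6)
    counts = seq2onehot.encode([2, 5, 5])
    print(counts)        # counts at positions 2 and 5 (value 2.0 at position 5)
    print(counts.shape)  # torch.Size([1, 6])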


def print_log_probs(log_probs, label_vocab, label_true=None):
    for i, label_probs in enumerate(log_probs.data.tolist()):
        prob_string = ", ".join([
            "{}: {:.3f}".format(label_vocab.idx2item[j], val)
            for j, val in enumerate(label_probs)
        ])
        true_string = "?"
        if label_true is not None:
            true_string = label_vocab.idx2item[label_true[i]]
        print(prob_string, "True label: ", true_string)
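

# Hypothetical usage sketch for print_log_probs (not part of the original
# module): pretty-print per-class log probabilities for a small batch next to
# the true labels. The label names and scores below are made up for illustration.
def _demo_print_log_probs():
    label_vocab = Vocab(name="labels", lower=False)
    label_vocab.batch_add(["neg", "pos"])
    scores = torch.randn(2, label_vocab.size)   # fake model scores for 2 examples
    log_probs = F.log_softmax(scores, dim=-1)
    print_log_probs(log_probs, label_vocab, label_true=[0, 1])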