# word_sequence.py
"""
WordSequence类
维护一个字典,把一个list(或者字符串)编码化,或者反向恢复
"""
import numpy as np
class WordSequence(object):
    """Vocabulary that maps tokens to integer indices and back.

    Four reserved tags (padding, unknown, sentence start, sentence end)
    always occupy indices 0-3; tokens learned by ``fit`` are numbered from
    4 upwards.

    Example:
        ws = WordSequence()
        ws.fit([['hello', 'world']], min_count=1)
        ws.transform(['hello', 'there'])    # -> array([4, 1])
    """

    # Reserved tag strings.
    PAD_TAG = '<pad>'
    UNK_TAG = '<unk>'
    START_TAG = '<s>'
    END_TAG = '</s>'

    # Indices permanently assigned to the reserved tags.
    PAD = 0
    UNK = 1
    START = 2
    END = 3

    def __init__(self):
        """Create an unfitted instance whose dictionary holds only the reserved tags."""
        self.dict = self._reserved_tags()
        self.fited = False

    @staticmethod
    def _reserved_tags():
        """Return a fresh tag -> index mapping containing only the reserved tags."""
        return {
            WordSequence.PAD_TAG: WordSequence.PAD,
            WordSequence.UNK_TAG: WordSequence.UNK,
            WordSequence.START_TAG: WordSequence.START,
            WordSequence.END_TAG: WordSequence.END,
        }

    def to_index(self, word):
        """Return the index of ``word``, or ``UNK`` when it is not in the dictionary.

        Raises:
            AssertionError: if ``fit`` has not been called yet.
        """
        assert self.fited, 'WordSequence 尚未 fit'
        return self.dict.get(word, WordSequence.UNK)

    def to_word(self, index):
        """Return the token stored under ``index``, or ``UNK_TAG`` when there is none.

        This is a linear scan over the dictionary using ``==``, so it also
        accepts index objects that are not hashable.

        Raises:
            AssertionError: if ``fit`` has not been called yet.
        """
        assert self.fited, 'WordSequence 尚未 fit'
        for word, idx in self.dict.items():
            if idx == index:
                return word
        return WordSequence.UNK_TAG

    def size(self):
        """Return the number of dictionary entries plus one.

        Raises:
            AssertionError: if ``fit`` has not been called yet.
        """
        assert self.fited, 'WordSequence 尚未 fit'
        # NOTE(review): the +1 makes the result one larger than the highest
        # assigned index + 1 -- presumably deliberate headroom for callers
        # that size an embedding from it; confirm before changing.
        return len(self.dict) + 1

    def __len__(self):
        """Return the same value as ``size()``."""
        return self.size()

    def fit(self, sentences, min_count=5, max_count=None, max_features=None):
        """Build the dictionary from a corpus. May only be called once.

        Args:
            sentences: iterable of sentences, each itself an iterable of
                hashable tokens (a list of words, or a string of characters).
            min_count: drop tokens seen fewer than this many times;
                ``None`` disables the lower bound.
            max_count: drop tokens seen more than this many times;
                ``None`` disables the upper bound.
            max_features: when an ``int``, keep only that many of the most
                frequent surviving tokens (a value <= 0 keeps none) and
                number them in ascending order of frequency. Any other value,
                including ``None``, disables the cap and tokens are numbered
                in sorted order instead.

        Raises:
            AssertionError: if the instance has already been fitted.

        Example:
            ws = WordSequence()
            ws.fit([['hello', 'world']])
        """
        assert not self.fited, 'WordSequence 只能 fit 一次'

        counts = {}
        for sentence in sentences:
            for token in sentence:
                counts[token] = counts.get(token, 0) + 1

        if min_count is not None:
            counts = {tok: n for tok, n in counts.items() if n >= min_count}
        if max_count is not None:
            counts = {tok: n for tok, n in counts.items() if n <= max_count}

        self.dict = self._reserved_tags()

        # A corpus token spelled exactly like a reserved tag must keep the
        # reserved index. Re-numbering it would overwrite the tag and, since
        # len(self.dict) would not grow, hand the same index to the next token.
        counts = {tok: n for tok, n in counts.items() if tok not in self.dict}

        if isinstance(max_features, int):
            # Stable sort: ties keep first-seen order.
            ranked = sorted(counts.items(), key=lambda item: item[1])
            if len(ranked) > max_features:
                # Slice from the front offset rather than with a negative
                # index: ranked[-0:] would keep everything when the cap is 0.
                ranked = ranked[len(ranked) - max_features:]
            vocabulary = [tok for tok, _ in ranked]
        else:
            vocabulary = sorted(counts)

        for token in vocabulary:
            self.dict[token] = len(self.dict)

        self.fited = True

    def transform(self, sentence, max_len=None):
        """Encode a sentence as a numpy array of indices.

        For example ``['a', 'b', 'c']`` may become ``[4, 5, 6]``; the numbers
        are dictionary ids and their order carries no meaning.

        Args:
            sentence: sequence of tokens. ``len()`` is taken of it when
                ``max_len`` is ``None``.
            max_len: fixed output length. Longer sentences are truncated and
                shorter ones are right-padded with ``PAD``. ``None`` uses the
                sentence's own length.

        Returns:
            ``np.array`` of indices; tokens not in the dictionary map to ``UNK``.

        Raises:
            AssertionError: if ``fit`` has not been called yet.
        """
        assert self.fited, 'WordSequence 尚未 fit'

        length = max_len if max_len is not None else len(sentence)
        encoded = [self.PAD] * length
        # zip stops at ``length`` positions, which performs the truncation.
        for position, token in zip(range(length), sentence):
            encoded[position] = self.to_index(token)
        return np.array(encoded)

    def inverse_transform(self, indices,
                          ignore_pad=False, ignore_unk=False,
                          ignore_start=False, ignore_end=False):
        """Decode an iterable of indices back into a list of tokens.

        Args:
            indices: iterable of indices (e.g. the output of ``transform``).
            ignore_pad: omit ``PAD_TAG`` from the result.
            ignore_unk: omit ``UNK_TAG`` from the result.
            ignore_start: omit ``START_TAG`` from the result.
            ignore_end: omit ``END_TAG`` from the result.

        Returns:
            List of tokens; an index with no dictionary entry decodes to
            ``UNK_TAG``.

        Raises:
            AssertionError: if ``indices`` is non-empty and ``fit`` has not
                been called yet.
        """
        skipped = set()
        if ignore_pad:
            skipped.add(WordSequence.PAD_TAG)
        if ignore_unk:
            skipped.add(WordSequence.UNK_TAG)
        if ignore_start:
            skipped.add(WordSequence.START_TAG)
        if ignore_end:
            skipped.add(WordSequence.END_TAG)

        ret = []
        lookup = None
        for index in indices:
            if lookup is None:
                # Deferred to the first element so that an empty input still
                # returns [] without checking the fit state.
                assert self.fited, 'WordSequence 尚未 fit'
                # One reverse map per call replaces a full dictionary scan
                # per element. setdefault keeps the first token for an index,
                # matching what the scan in to_word would return.
                lookup = {}
                for word, idx in self.dict.items():
                    lookup.setdefault(idx, word)
            try:
                word = lookup[index]
            except (KeyError, TypeError):
                # Unknown indices, unhashable index objects, and objects
                # whose hash does not match the int they equal all take the
                # equality-based scan.
                word = self.to_word(index)
            if word not in skipped:
                ret.append(word)
        return ret
def test():
    """Fit a tiny corpus, encode a sentence, decode it again, and print both."""
    ws = WordSequence()
    # min_count=1 is required here: with fit's default threshold of 5 no
    # token of this two-sentence corpus would survive, and the demo would
    # only ever print <unk>.
    ws.fit([
        ['第', '一', '句', '话'],
        ['第', '二', '句', '话'],
    ], min_count=1)

    # '第' is in the fitted dictionary; '三' is not and encodes as UNK.
    indice = ws.transform(['第', '三'])
    print(indice)

    back = ws.inverse_transform(indice)
    print(back)
if __name__ == '__main__':
    # Run the demo only when executed as a script, not on import.
    test()