-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathpredict.py
25 lines (17 loc) · 1016 Bytes
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from tlxzoo.module.bert.transform import BertTransform
import tensorlayerx as tlx
from tlxzoo.text.text_token_classidication import TextTokenClassification
if __name__ == '__main__':
    # Run a single-sentence token-classification (NER-style) prediction with a
    # BERT model and print the encoded gold labels next to the predicted ids.

    # NOTE(review): the vocabulary is read from the *text_classification* demo
    # directory while the weights below come from *token_classification*. This
    # only works if both were built with the same BERT vocab — confirm.
    transform = BertTransform(
        vocab_file="./demo/text/text_classification/bert/vocab.txt",
        task="token",
        max_length=128,
    )
    model = TextTokenClassification("bert", n_class=9)
    model.load_weights("./demo/text/token_classification/bert/model.npz")
    # Switch to inference mode so training-only layers (e.g. dropout) are
    # disabled and predictions are deterministic.
    model.set_eval()

    tokens = ["CRICKET", "-", "LEICESTERSHIRE", "TAKE", "OVER", "AT", "TOP", "AFTER", "INNINGS", "VICTORY", "."]
    gold_labels = ["O", "O", "B-ORG", "O", "O", "O", "O", "O", "O", "O", "O"]
    # `x` holds the encoded model inputs; `y` is the transform's encoding of
    # `gold_labels` (presumably label ids aligned to `x` — verify in BertTransform).
    x, y = transform(tokens, gold_labels)

    # Wrap each encoded field in a list to form a batch of one example.
    inputs = tlx.convert_to_tensor([x["inputs"]])
    token_type_ids = tlx.convert_to_tensor([x["token_type_ids"]])
    attention_mask = tlx.convert_to_tensor([x["attention_mask"]])

    logits = model(inputs=inputs, token_type_ids=token_type_ids, attention_mask=attention_mask)
    # Predicted class index per position: argmax over the last (class) axis.
    # Kept separate from `gold_labels` so the reference labels are not clobbered.
    predictions = tlx.argmax(logits, axis=-1)

    print(y)
    print(predictions)