-
Notifications
You must be signed in to change notification settings - Fork 2
/
pipeline_entities_bert_ontonotes.py
36 lines (27 loc) · 1.57 KB
/
pipeline_entities_bert_ontonotes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from arekit.common.bound import Bound
from arekit.common.entities.base import Entity
from arekit.common.news.objects_parser import SentenceObjectsParserPipelineItem
from arekit.common.text.partitioning.terms import TermsPartitioning
from arekit.processing.entities.obj_desc import NerObjectDescriptor
from arelight.text.ner_ontonotes import BertOntonotesNER
class BertOntonotesNERPipelineItem(SentenceObjectsParserPipelineItem):
    """Pipeline item that annotates a list of terms with named entities.

    Entity recognition is delegated to a `BertOntonotesNER` instance; every
    recognized object is reported as an `Entity` together with the `Bound`
    describing which terms it covers.
    """

    def __init__(self, obj_filter=None):
        """Create the pipeline item.

        :param obj_filter: optional callable that receives a
            `NerObjectDescriptor`; objects for which it returns a falsy
            value are skipped. `None` keeps every recognized object.
        """
        assert obj_filter is None or callable(obj_filter)

        # Initialize bert-based model instance.
        self.__ontonotes_ner = BertOntonotesNER()
        self.__obj_filter = obj_filter

        super(BertOntonotesNERPipelineItem, self).__init__(TermsPartitioning())

    def _get_parts_provider_func(self, input_data, pipeline_ctx):
        """Return the iterator of (entity, bound) pairs for `input_data`.

        `pipeline_ctx` is accepted for interface compatibility and is not
        used here.
        """
        return self.__iter_subs_values_with_bounds(input_data)

    def __iter_subs_values_with_bounds(self, terms_list):
        """Lazily yield `(Entity, Bound)` pairs found in `terms_list`.

        :param terms_list: list of terms forming a single sequence.
        """
        assert isinstance(terms_list, list)

        # The model API takes a collection of sequences, so the single
        # input sequence is wrapped into a one-element list.
        batch = [terms_list]

        for descriptors in self.__ontonotes_ner.extract(sequences=batch):
            for descriptor in descriptors:
                assert isinstance(descriptor, NerObjectDescriptor)

                # Without a filter everything passes; otherwise the filter decides.
                if self.__obj_filter is None or self.__obj_filter(descriptor):
                    start = descriptor.Position
                    covered_terms = terms_list[start:start + descriptor.Length]
                    found_entity = Entity(value=" ".join(covered_terms),
                                          e_type=descriptor.ObjectType)
                    yield found_entity, Bound(pos=descriptor.Position,
                                              length=descriptor.Length)