From 37c0b49d69d461401bba6edcce1836cce3988d1d Mon Sep 17 00:00:00 2001 From: Nicolay Rusnachenko Date: Tue, 2 Aug 2022 12:24:08 +0300 Subject: [PATCH] large and temp API refactoring by switching to 0.22.1 AREkit version. --- README.md | 6 +- arelight/demo/infer_bert_rus.py | 15 +-- arelight/demo/infer_nn_rus.py | 16 +-- arelight/demo/utils.py | 4 +- arelight/exp/doc_ops.py | 6 - arelight/exp/exp_io.py | 16 +-- arelight/exp/opin_ops.py | 24 +--- arelight/ner_obj_desc.py | 21 +++ arelight/network/bert/ctx.py | 17 +-- arelight/network/nn/ctx.py | 34 +---- arelight/network/nn/embedding.py | 69 ---------- arelight/pipelines/inference_nn.py | 4 +- arelight/pipelines/serialize_bert.py | 24 ++-- arelight/pipelines/serialize_nn.py | 63 ++++----- arelight/pipelines/train_bert.py | 2 +- arelight/text/ner_base.py | 2 +- .../text/pipeline_entities_bert_ontonotes.py | 3 +- download.py | 5 +- examples/args/common.py | 7 +- examples/args/const.py | 1 - examples/demo/wui_nn.py | 4 +- examples/infer_texts_nn.py | 2 - examples/{rusentrel => labels}/__init__.py | 0 examples/labels/base.py | 10 ++ examples/labels/formatter.py | 26 ++++ examples/labels/scalers.py | 26 ++++ examples/rusentrel/common.py | 61 --------- examples/rusentrel/config_setups.py | 71 ---------- examples/rusentrel/configs/__init__.py | 0 examples/rusentrel/configs/common.py | 14 -- examples/rusentrel/configs/mi.py | 12 -- examples/rusentrel/configs/single.py | 76 ----------- examples/rusentrel/exp_io.py | 11 -- examples/serialize_rusentrel_for_bert.py | 123 ------------------ examples/serialize_rusentrel_for_nn.py | 120 ----------------- examples/serialize_texts_bert.py | 7 +- examples/serialize_texts_nn.py | 9 +- examples/train_nn_on_rusentrel.py | 12 +- examples/utils.py | 6 +- setup.py | 2 +- test/test_bert_ontonotes_ner.py | 2 +- test/test_bert_ontonotes_ner_pipeline_item.py | 3 +- test/test_bert_serialization.py | 30 ++--- test/test_demo.py | 3 +- update_arekit.sh | 4 +- 45 files changed, 188 insertions(+), 785 deletions(-) create mode 100644 arelight/ner_obj_desc.py delete mode 100644 arelight/network/nn/embedding.py rename examples/{rusentrel => labels}/__init__.py (100%) create mode 100644 examples/labels/base.py create mode 100644 examples/labels/formatter.py create mode 100644 examples/labels/scalers.py delete mode 100644 examples/rusentrel/common.py delete mode 100644 examples/rusentrel/config_setups.py delete mode 100644 examples/rusentrel/configs/__init__.py delete mode 100644 examples/rusentrel/configs/common.py delete mode 100644 examples/rusentrel/configs/mi.py delete mode 100644 examples/rusentrel/configs/single.py delete mode 100644 examples/rusentrel/exp_io.py delete mode 100644 examples/serialize_rusentrel_for_bert.py delete mode 100644 examples/serialize_rusentrel_for_nn.py diff --git a/README.md b/README.md index 284e21c..abe852c 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# ARElight 0.22.0 +# ARElight 0.22.1 ### :point_right: [DEMO](#docker-verion-quick) :point_left: @@ -18,7 +18,7 @@ we adopt [DeepPavlov](https://github.com/deepmipt/DeepPavlov) (BertOntoNotes mo # Dependencies -* arekit == 0.22.0 +* arekit == 0.22.1 * gensim == 3.2.0 * deeppavlov == 0.11.0 * rusenttokenize @@ -151,8 +151,6 @@ python3.6 serialize_texts_nn.py --from-files data/texts-inosmi-rus/e1.txt \ # Other Examples -* Serialize RuSentRel collection for BERT [[code]](examples/serialize_rusentrel_for_bert.py) -* Serialize RuSentRel collection for Neural Networks [[code]](examples/serialize_rusentrel_for_nn.py) * Finetune BERT on samples [[code]](examples/train_bert.py) * Finetune Neural Networks on RuSentRel [[code]](examples/train_nn_on_rusentrel.py) diff --git a/arelight/demo/infer_bert_rus.py b/arelight/demo/infer_bert_rus.py index c21f902..b763735 100644 --- a/arelight/demo/infer_bert_rus.py +++ b/arelight/demo/infer_bert_rus.py @@ -1,18 +1,13 @@ -from arekit.common.experiment.annot.algo.pair_based import PairBasedAnnotationAlgorithm -from arekit.common.experiment.annot.default import DefaultAnnotator from arekit.common.experiment.data_type import DataType from arekit.common.experiment.name_provider import ExperimentNameProvider from arekit.common.folding.nofold import NoFolding from arekit.common.labels.base import NoLabel from arekit.common.labels.provider.constant import ConstantLabelProvider +from arekit.common.opinions.annot.algo.pair_based import PairBasedOpinionAnnotationAlgorithm +from arekit.common.opinions.annot.base import BaseOpinionAnnotator from arekit.common.pipeline.base import BasePipeline -from arekit.contrib.bert.samplers.types import SampleFormattersService -from arekit.contrib.experiment_rusentrel.entities.factory import create_entity_formatter -from arekit.contrib.experiment_rusentrel.entities.types import EntityFormatterTypes -from arekit.contrib.experiment_rusentrel.labels.scalers.three import ThreeLabelScaler -from arekit.contrib.experiment_rusentrel.labels.types import ExperimentPositiveLabel, ExperimentNegativeLabel from arekit.contrib.networks.core.predict.tsv_writer import TsvPredictWriter -from arekit.processing.lemmatization.mystem import MystemWrapper +from arekit.contrib.utils.processing.lemmatization.mystem import MystemWrapper from arelight.demo.utils import read_synonyms_collection from arelight.pipelines.backend_brat_json import BratBackendContentsPipelineItem @@ -48,8 +43,8 @@ def demo_infer_texts_bert_pipeline(texts_count, name_provider=ExperimentNameProvider(name="example-bert", suffix="infer"), text_b_type=text_b_type, output_dir=output_dir, - opin_annot=DefaultAnnotator( - PairBasedAnnotationAlgorithm( + opin_annot=BaseOpinionAnnotator( + PairBasedOpinionAnnotationAlgorithm( dist_in_terms_bound=None, label_provider=ConstantLabelProvider(label_instance=NoLabel()))), data_folding=NoFolding(doc_ids_to_fold=list(range(texts_count)), diff --git a/arelight/demo/infer_nn_rus.py b/arelight/demo/infer_nn_rus.py index c525d3f..cae6ef0 100644 --- a/arelight/demo/infer_nn_rus.py +++ b/arelight/demo/infer_nn_rus.py @@ -1,20 +1,16 @@ -from arekit.common.experiment.annot.algo.pair_based import PairBasedAnnotationAlgorithm -from arekit.common.experiment.annot.default import DefaultAnnotator from arekit.common.experiment.data_type import DataType from arekit.common.experiment.name_provider import ExperimentNameProvider from arekit.common.folding.nofold import NoFolding from arekit.common.labels.base import NoLabel from arekit.common.labels.provider.constant import ConstantLabelProvider +from arekit.common.opinions.annot.algo.pair_based import PairBasedOpinionAnnotationAlgorithm +from arekit.common.opinions.annot.base import BaseOpinionAnnotator from arekit.common.pipeline.base import BasePipeline -from arekit.contrib.experiment_rusentrel.entities.factory import create_entity_formatter -from arekit.contrib.experiment_rusentrel.entities.types import EntityFormattersService -from arekit.contrib.experiment_rusentrel.labels.scalers.three import ThreeLabelScaler -from arekit.contrib.experiment_rusentrel.labels.types import ExperimentPositiveLabel, ExperimentNegativeLabel from arekit.contrib.networks.core.callback.stat import TrainingStatProviderCallback from arekit.contrib.networks.core.callback.train_limiter import TrainingLimiterCallback from arekit.contrib.networks.core.predict.tsv_writer import TsvPredictWriter from arekit.contrib.networks.enum_name_types import ModelNames -from arekit.processing.lemmatization.mystem import MystemWrapper +from arekit.contrib.utils.processing.lemmatization.mystem import MystemWrapper from arelight.demo.utils import read_synonyms_collection from arelight.network.nn.common import create_network_model_io, create_bags_collection_type, create_full_model_name @@ -29,7 +25,6 @@ def demo_infer_texts_tensorflow_nn_pipeline(texts_count, frames_collection, output_dir, synonyms_filepath, - embedding_filepath, embedding_matrix_filepath=None, vocab_filepath=None, bags_per_minibatch=2, @@ -56,14 +51,13 @@ def demo_infer_texts_tensorflow_nn_pipeline(texts_count, frames_collection=frames_collection, synonyms=read_synonyms_collection(synonyms_filepath=synonyms_filepath, stemmer=stemmer), terms_per_context=terms_per_context, - embedding_path=embedding_filepath, entities_parser=BertOntonotesNERPipelineItem( lambda s_obj: s_obj.ObjectType in ["ORG", "PERSON", "LOC", "GPE"]), entity_fmt=create_entity_formatter(entity_fmt_type), stemmer=stemmer, name_provider=exp_name_provider, - opin_annot=DefaultAnnotator( - PairBasedAnnotationAlgorithm( + opin_annot=BaseOpinionAnnotator( + PairBasedOpinionAnnotationAlgorithm( dist_in_terms_bound=None, label_provider=ConstantLabelProvider(label_instance=NoLabel()))), output_dir=output_dir, diff --git a/arelight/demo/utils.py b/arelight/demo/utils.py index bf769ab..89a67e5 100644 --- a/arelight/demo/utils.py +++ b/arelight/demo/utils.py @@ -1,6 +1,6 @@ from arekit.common.text.stemmer import Stemmer -from arekit.contrib.experiment_rusentrel.synonyms.collection import StemmerBasedSynonymCollection -from arekit.contrib.source.rusentrel.utils import iter_synonym_groups +from arekit.contrib.source.synonyms.utils import iter_synonym_groups +from arekit.contrib.utils.synonyms.stemmer_based import StemmerBasedSynonymCollection def iter_groups(filepath): diff --git a/arelight/exp/doc_ops.py b/arelight/exp/doc_ops.py index 2f7e375..eb913f2 100644 --- a/arelight/exp/doc_ops.py +++ b/arelight/exp/doc_ops.py @@ -1,14 +1,8 @@ -from arekit.common.experiment.api.enums import BaseDocumentTag from arekit.common.experiment.api.ops_doc import DocumentOperations class CustomDocOperations(DocumentOperations): - def iter_tagget_doc_ids(self, tag): - assert(isinstance(tag, BaseDocumentTag)) - assert(tag == BaseDocumentTag.Annotate) - return self.__doc_ids - def __init__(self, exp_ctx, text_parser): super(CustomDocOperations, self).__init__(exp_ctx, text_parser) self.__docs = None diff --git a/arelight/exp/exp_io.py b/arelight/exp/exp_io.py index 962fe03..9388b54 100644 --- a/arelight/exp/exp_io.py +++ b/arelight/exp/exp_io.py @@ -1,23 +1,9 @@ -import os -from arekit.contrib.experiment_rusentrel.model_io.tf_networks import RuSentRelExperimentNetworkIOUtils - - -class InferIOUtils(RuSentRelExperimentNetworkIOUtils): +class InferIOUtils(object): def __init__(self, output_dir, exp_ctx): assert(isinstance(output_dir, str)) super(InferIOUtils, self).__init__(exp_ctx=exp_ctx) self.__output_dir = output_dir - def __create_annot_input_target(self, doc_id, data_type): - filename = "annot_input_d{doc_id}_{data_type}.txt".format(doc_id=doc_id, data_type=data_type.name) - return os.path.join(self._get_target_dir(), filename) - def _get_experiment_sources_dir(self): return self.__output_dir - - def create_opinion_collection_target(self, doc_id, data_type, check_existance=False): - return self.__create_annot_input_target(doc_id=doc_id, data_type=data_type) - - def create_result_opinion_collection_target(self, doc_id, data_type, epoch_index): - return self.__create_annot_input_target(doc_id=doc_id, data_type=data_type) diff --git a/arelight/exp/opin_ops.py b/arelight/exp/opin_ops.py index ba6eac8..1b9d1de 100644 --- a/arelight/exp/opin_ops.py +++ b/arelight/exp/opin_ops.py @@ -1,8 +1,7 @@ -from arekit.common.experiment.api.ops_opin import OpinionOperations from arekit.common.opinions.collection import OpinionCollection -class CustomOpinionOperations(OpinionOperations): +class CustomOpinionOperations(object): def __init__(self, labels_formatter, exp_io, synonyms, neutral_labels_fmt): super(CustomOpinionOperations, self).__init__() @@ -11,27 +10,6 @@ def __init__(self, labels_formatter, exp_io, synonyms, neutral_labels_fmt): self.__synonyms = synonyms self.__neutral_labels_fmt = neutral_labels_fmt - @property - def LabelsFormatter(self): - return self.__labels_formatter - - def iter_opinions_for_extraction(self, doc_id, data_type): - # Reading automatically annotated collection of neutral opinions. - # TODO. #250, #251 provide opinion annotation here for the particular document. - return self.__exp_io.read_opinion_collection( - target=self.__exp_io.create_result_opinion_collection_target( - doc_id=doc_id, - data_type=data_type, - epoch_index=0), - labels_formatter=self.__neutral_labels_fmt, - create_collection_func=self.create_opinion_collection) - - def get_etalon_opinion_collection(self, doc_id): - return self.create_opinion_collection(None) - - def get_result_opinion_collection(self, doc_id, data_type, epoch_index): - raise Exception("Not Supported") - def create_opinion_collection(self, opinions=None): return OpinionCollection(opinions=[] if opinions is None else opinions, synonyms=self.__synonyms, diff --git a/arelight/ner_obj_desc.py b/arelight/ner_obj_desc.py new file mode 100644 index 0000000..50796ef --- /dev/null +++ b/arelight/ner_obj_desc.py @@ -0,0 +1,21 @@ +class NerObjectDescriptor: + + def __init__(self, pos, length, obj_type): + self.__pos = pos + self.__len = length + self.__obj_type = obj_type + + @property + def Position(self): + return self.__pos + + @property + def Length(self): + return self.__len + + @property + def ObjectType(self): + return self.__obj_type + + def get_range(self): + return self.__pos, self.__pos + self.__len \ No newline at end of file diff --git a/arelight/network/bert/ctx.py b/arelight/network/bert/ctx.py index 5d04c38..0d54360 100644 --- a/arelight/network/bert/ctx.py +++ b/arelight/network/bert/ctx.py @@ -1,25 +1,12 @@ -from arekit.common.entities.str_fmt import StringEntitiesFormatter from arekit.common.experiment.api.ctx_serialization import ExperimentSerializationContext class BertSerializationContext(ExperimentSerializationContext): - def __init__(self, label_scaler, terms_per_context, str_entity_formatter, - annotator, name_provider, data_folding): - assert(isinstance(str_entity_formatter, StringEntitiesFormatter)) + def __init__(self, label_scaler, terms_per_context, name_provider): assert(isinstance(terms_per_context, int)) - - super(BertSerializationContext, self).__init__(annot=annotator, - name_provider=name_provider, - label_scaler=label_scaler, - data_folding=data_folding) - + super(BertSerializationContext, self).__init__(name_provider=name_provider, label_scaler=label_scaler) self.__terms_per_context = terms_per_context - self.__str_entity_formatter = str_entity_formatter - - @property - def StringEntityFormatter(self): - return self.__str_entity_formatter @property def TermsPerContext(self): diff --git a/arelight/network/nn/ctx.py b/arelight/network/nn/ctx.py index 2be9dfc..931f32f 100644 --- a/arelight/network/nn/ctx.py +++ b/arelight/network/nn/ctx.py @@ -1,26 +1,22 @@ from arekit.common.entities.str_fmt import StringEntitiesFormatter -from arekit.contrib.experiment_rusentrel.connotations.provider import RuSentiFramesConnotationProvider -from arekit.contrib.experiment_rusentrel.labels.scalers.three import ThreeLabelScaler from arekit.contrib.networks.core.input.ctx_serialization import NetworkSerializationContext -from arekit.contrib.networks.embeddings.base import Embedding +from arekit.contrib.networks.embedding import Embedding from arekit.contrib.source.rusentiframes.collection import RuSentiFramesCollection -from arekit.processing.pos.base import POSTagger +from arekit.contrib.utils.connotations.rusentiframes_sentiment import RuSentiFramesConnotationProvider +from arekit.contrib.utils.processing.pos.base import POSTagger -class NetworkSerializationContext(NetworkSerializationContext): +class CustomNeuralNetworkSerializationContext(NetworkSerializationContext): - def __init__(self, labels_scaler, pos_tagger, embedding, - terms_per_context, str_entity_formatter, annotator, - frames_collection, frame_variant_collection, name_provider, data_folding): + def __init__(self, labels_scaler, pos_tagger, embedding, terms_per_context, str_entity_formatter, + frames_collection, frame_variant_collection, name_provider): assert(isinstance(embedding, Embedding)) assert(isinstance(pos_tagger, POSTagger)) assert(isinstance(frames_collection, RuSentiFramesCollection)) assert(isinstance(str_entity_formatter, StringEntitiesFormatter)) assert(isinstance(terms_per_context, int)) - super(NetworkSerializationContext, self).__init__( - labels_scaler=labels_scaler, annot=annotator, - name_provider=name_provider, data_folding=data_folding) + super(NetworkSerializationContext, self).__init__(label_scaler=labels_scaler, name_provider=name_provider) self.__pos_tagger = pos_tagger self.__terms_per_context = terms_per_context @@ -35,22 +31,6 @@ def __init__(self, labels_scaler, pos_tagger, embedding, def PosTagger(self): return self.__pos_tagger - @property - def StringEntityFormatter(self): - return self.__str_entity_formatter - - @property - def StringEntityEmbeddingFormatter(self): - return self.__str_entity_formatter - - @property - def FrameVariantCollection(self): - return self.__frame_variant_collection - - @property - def WordEmbedding(self): - return self.__word_embedding - @property def TermsPerContext(self): return self.__terms_per_context diff --git a/arelight/network/nn/embedding.py b/arelight/network/nn/embedding.py deleted file mode 100644 index 4ebbf69..0000000 --- a/arelight/network/nn/embedding.py +++ /dev/null @@ -1,69 +0,0 @@ -import numpy as np -from arekit.common.text.stemmer import Stemmer -from arekit.contrib.networks.embeddings.base import Embedding -from gensim.models import KeyedVectors - - -class RusvectoresEmbedding(Embedding): - - def __init__(self, matrix, words): - super(RusvectoresEmbedding, self).__init__(matrix=matrix, - words=words) - - self.__index_without_pos = self.__create_terms_without_pos() - self.__stemmer = None - self.__lemmatize_by_default = True - - @classmethod - def from_word2vec_format(cls, filepath, binary): - assert(isinstance(binary, bool)) - - w2v_model = KeyedVectors.load_word2vec_format(filepath, binary=binary) - words_count = len(w2v_model.wv.vocab) - - return cls(matrix=np.array([vector for vector in w2v_model.syn0]), - words=[w2v_model.wv.index2word[index] for index in range(words_count)]) - - def set_stemmer(self, stemmer): - assert(isinstance(stemmer, Stemmer)) - self.__stemmer = stemmer - - def try_find_index_by_plain_word(self, word): - assert(isinstance(word, str)) - - temp = self.__lemmatize_by_default - self.__lemmatize_by_default = False - index = super(RusvectoresEmbedding, self).try_find_index_by_plain_word(word) - self.__lemmatize_by_default = temp - - return index - - def _handler(self, word): - return self.__try_find_word_index_pair_lemmatized(word, self.__lemmatize_by_default) - - # region private methods - - def __try_find_word_index_pair_lemmatized(self, term, lemmatize): - assert(isinstance(term, str)) - assert(isinstance(lemmatize, bool)) - - if lemmatize: - term = self.__stemmer.lemmatize_to_str(term) - - index = self.__index_without_pos[term] \ - if term in self.__index_without_pos else None - - return term, index - - def __create_terms_without_pos(self): - d = {} - for word_with_pos, index in self.iter_vocabulary(): - assert(isinstance(word_with_pos, str)) - word = word_with_pos.split(u'_')[0] - if word in d: - continue - d[word] = index - - return d - - # endregion \ No newline at end of file diff --git a/arelight/pipelines/inference_nn.py b/arelight/pipelines/inference_nn.py index a3d70d3..b3425db 100644 --- a/arelight/pipelines/inference_nn.py +++ b/arelight/pipelines/inference_nn.py @@ -18,7 +18,7 @@ from arekit.contrib.networks.core.predict.base_writer import BasePredictWriter from arekit.contrib.networks.factory import create_network_and_network_config_funcs from arekit.contrib.networks.shapes import NetworkInputShapes -from arekit.processing.languages.ru.pos_service import PartOfSpeechTypesService +from arekit.contrib.utils.processing.languages.ru.pos_service import PartOfSpeechTypesService from arelight.exp.exp_io import InferIOUtils @@ -76,7 +76,7 @@ def apply_core(self, input_data, pipeline_ctx): # Update for further pipeline items. pipeline_ctx.update("predict_fp", tgt) - # Fetch other required in furter information from input_data. + # Fetch other required in further information from input_data. samples_filepath = input_data.create_samples_writer_target(self.__data_type) embedding = input_data.load_embedding() vocab = input_data.load_vocab() diff --git a/arelight/pipelines/serialize_bert.py b/arelight/pipelines/serialize_bert.py index 29144cc..1479b3f 100644 --- a/arelight/pipelines/serialize_bert.py +++ b/arelight/pipelines/serialize_bert.py @@ -1,6 +1,4 @@ from arekit.common.entities.str_fmt import StringEntitiesFormatter -from arekit.common.experiment.api.base import BaseExperiment -from arekit.common.experiment.engine import ExperimentEngine from arekit.common.experiment.name_provider import ExperimentNameProvider from arekit.common.folding.base import BaseDataFolding from arekit.common.labels.base import NoLabel @@ -8,11 +6,9 @@ from arekit.common.labels.str_fmt import StringLabelsFormatter from arekit.common.news.entities_grouping import EntitiesGroupingPipelineItem from arekit.common.pipeline.items.base import BasePipelineItem -from arekit.common.synonyms import SynonymsCollection +from arekit.common.synonyms.base import SynonymsCollection from arekit.common.text.parser import BaseTextParser -from arekit.contrib.bert.handlers.serializer import BertExperimentInputSerializerIterationHandler -from arekit.contrib.bert.samplers.types import BertSampleProviderTypes -from arekit.processing.text.pipeline_terms_splitter import TermsSplitterParser +from arekit.contrib.utils.pipelines.items.text.terms_splitter import TermsSplitterParser from arelight.exp.doc_ops import CustomDocOperations from arelight.exp.exp_io import InferIOUtils @@ -23,7 +19,7 @@ class BertTextsSerializationPipelineItem(BasePipelineItem): - def __init__(self, terms_per_context, entities_parser, synonyms, opin_annot, name_provider, + def __init__(self, terms_per_context, entities_parser, synonyms, name_provider, entity_fmt, text_b_type, data_folding, output_dir): assert(isinstance(entities_parser, BasePipelineItem)) assert(isinstance(entity_fmt, StringEntitiesFormatter)) @@ -37,13 +33,9 @@ def __init__(self, terms_per_context, entities_parser, synonyms, opin_annot, nam # Label provider setup. labels_fmt = StringLabelsFormatter(stol={"neu": NoLabel}) - self.__exp_ctx = BertSerializationContext( - label_scaler=SingleLabelScaler(NoLabel()), - annotator=opin_annot, - terms_per_context=terms_per_context, - str_entity_formatter=entity_fmt, - name_provider=name_provider, - data_folding=data_folding) + self.__exp_ctx = BertSerializationContext(label_scaler=SingleLabelScaler(NoLabel()), + terms_per_context=terms_per_context, + name_provider=name_provider) self.__exp_io = InferIOUtils(exp_ctx=self.__exp_ctx, output_dir=output_dir) @@ -67,6 +59,7 @@ def __init__(self, terms_per_context, entities_parser, synonyms, opin_annot, nam doc_ops=self.__doc_ops, opin_ops=self.__opin_ops) + # TODO. Handler does not exist anymore. self.__handler = BertExperimentInputSerializerIterationHandler( exp_io=self.__exp_io, exp_ctx=self.__exp_ctx, @@ -75,7 +68,7 @@ def __init__(self, terms_per_context, entities_parser, synonyms, opin_annot, nam sample_labels_fmt=labels_fmt, annot_labels_fmt=labels_fmt, sample_provider_type=text_b_type, - entity_formatter=self.__exp_ctx.StringEntityFormatter, + entity_formatter=entity_fmt, value_to_group_id_func=synonyms.get_synonym_group_index, balance_train_samples=True) @@ -94,6 +87,7 @@ def apply_core(self, input_data, pipeline_ctx): # Setup document. self.__doc_ops.set_docs(docs) + # TODO. Engine does not exist anymore. engine = ExperimentEngine(self.__exp_ctx.DataFolding) engine.run([self.__handler]) diff --git a/arelight/pipelines/serialize_nn.py b/arelight/pipelines/serialize_nn.py index 795b05b..19fd03c 100644 --- a/arelight/pipelines/serialize_nn.py +++ b/arelight/pipelines/serialize_nn.py @@ -7,46 +7,49 @@ from arekit.common.labels.str_fmt import StringLabelsFormatter from arekit.common.news.entities_grouping import EntitiesGroupingPipelineItem from arekit.common.pipeline.items.base import BasePipelineItem -from arekit.common.synonyms import SynonymsCollection +from arekit.common.synonyms.base import SynonymsCollection +from arekit.common.synonyms.grouping import SynonymsCollectionValuesGroupingProviders from arekit.common.text.parser import BaseTextParser from arekit.common.text.stemmer import Stemmer -from arekit.contrib.networks.core.input.helper import NetworkInputHelper from arekit.contrib.source.rusentiframes.collection import RuSentiFramesCollection -from arekit.processing.lemmatization.mystem import MystemWrapper -from arekit.processing.pos.mystem_wrap import POSMystemWrapper -from arekit.processing.text.pipeline_frames import FrameVariantsParser -from arekit.processing.text.pipeline_frames_lemmatized import LemmasBasedFrameVariantsParser -from arekit.processing.text.pipeline_frames_negation import FrameVariantsSentimentNegation -from arekit.processing.text.pipeline_terms_splitter import TermsSplitterParser -from arekit.processing.text.pipeline_tokenizer import DefaultTextTokenizer +from arekit.contrib.utils.pipelines.items.text.frames import FrameVariantsParser +from arekit.contrib.utils.pipelines.items.text.frames_lemmatized import LemmasBasedFrameVariantsParser +from arekit.contrib.utils.pipelines.items.text.frames_negation import FrameVariantsSentimentNegation +from arekit.contrib.utils.pipelines.items.text.terms_splitter import TermsSplitterParser +from arekit.contrib.utils.pipelines.items.text.tokenizer import DefaultTextTokenizer +from arekit.contrib.utils.processing.lemmatization.mystem import MystemWrapper +from arekit.contrib.utils.processing.pos.mystem_wrap import POSMystemWrapper +from arekit.contrib.utils.resources import load_embedding_news_mystem_skipgram_1000_20_2015 from arelight.exp.doc_ops import CustomDocOperations from arelight.exp.exp_io import InferIOUtils from arelight.exp.opin_ops import CustomOpinionOperations from arelight.network.nn.common import create_and_fill_variant_collection from arelight.network.nn.ctx import NetworkSerializationContext -from arelight.network.nn.embedding import RusvectoresEmbedding from arelight.pipelines.utils import input_to_docs +# TODO. This become a part of AREkit. +# TODO. This become a part of AREkit. +# TODO. This become a part of AREkit. +# TODO. This should be removed. class NetworkTextsSerializationPipelineItem(BasePipelineItem): - def __init__(self, terms_per_context, entities_parser, synonyms, opin_annot, name_provider, - embedding_path, frames_collection, entity_fmt, stemmer, data_folding, output_dir): + def __init__(self, terms_per_context, entities_parser, synonyms, name_provider, + frames_collection, entity_fmt, stemmer, data_folding, output_dir, embedding=None): assert(isinstance(frames_collection, RuSentiFramesCollection)) assert(isinstance(entities_parser, BasePipelineItem)) assert(isinstance(entity_fmt, StringEntitiesFormatter)) assert(isinstance(synonyms, SynonymsCollection)) assert(isinstance(terms_per_context, int)) - assert(isinstance(embedding_path, str)) assert(isinstance(stemmer, Stemmer)) assert(isinstance(data_folding, BaseDataFolding)) assert(isinstance(name_provider, ExperimentNameProvider)) assert(isinstance(output_dir, str)) - # Initalize embedding. - embedding = RusvectoresEmbedding.from_word2vec_format(filepath=embedding_path, binary=True) - embedding.set_stemmer(stemmer) + # Initialize embedding. + if embedding is None: + self.__embedding = load_embedding_news_mystem_skipgram_1000_20_2015() # Initialize synonyms collection. self.__synonyms = synonyms @@ -60,7 +63,9 @@ def __init__(self, terms_per_context, entities_parser, synonyms, opin_annot, nam self.__text_parser = BaseTextParser(pipeline=[ TermsSplitterParser(), entities_parser, - EntitiesGroupingPipelineItem(lambda value: self.get_synonym_group_index(self.__synonyms, value)), + EntitiesGroupingPipelineItem( + lambda value: SynonymsCollectionValuesGroupingProviders.provide_existed_or_register_missed_value( + self.__synonyms, value)), DefaultTextTokenizer(keep_tokens=True), FrameVariantsParser(frame_variants=frame_variants_collection), LemmasBasedFrameVariantsParser(save_lemmas=False, @@ -68,18 +73,9 @@ def __init__(self, terms_per_context, entities_parser, synonyms, opin_annot, nam frame_variants=frame_variants_collection), FrameVariantsSentimentNegation()]) - # initialize expriment related data. - self.__exp_ctx = NetworkSerializationContext( - labels_scaler=SingleLabelScaler(NoLabel()), - embedding=embedding, - annotator=opin_annot, - terms_per_context=terms_per_context, - str_entity_formatter=entity_fmt, - pos_tagger=pos_tagger, - name_provider=name_provider, - frames_collection=frames_collection, - frame_variant_collection=frame_variants_collection, - data_folding=data_folding) + # initialize experiment related data. + self.__exp_ctx = NetworkSerializationContext(labels_scaler=SingleLabelScaler(NoLabel()), + name_provider=name_provider) self.__exp_io = InferIOUtils(exp_ctx=self.__exp_ctx, output_dir=output_dir) @@ -92,18 +88,12 @@ def __init__(self, terms_per_context, entities_parser, synonyms, opin_annot, nam synonyms=synonyms, neutral_labels_fmt=self.__labels_fmt) + # TODO. Remove this. self.__exp = BaseExperiment(exp_io=self.__exp_io, exp_ctx=self.__exp_ctx, doc_ops=self.__doc_ops, opin_ops=self.__opin_ops) - @staticmethod - def get_synonym_group_index(synonyms, value): - assert(isinstance(synonyms, SynonymsCollection)) - if not synonyms.contains_synonym_value(value): - synonyms.add_synonym_value(value) - return synonyms.get_synonym_group_index(value) - def apply_core(self, input_data, pipeline_ctx): assert(isinstance(input_data, list)) @@ -112,6 +102,7 @@ def apply_core(self, input_data, pipeline_ctx): # Setup document. self.__doc_ops.set_docs(docs) + # TODO. This is outdated. NetworkInputHelper.prepare(exp_ctx=self.__exp.ExperimentContext, exp_io=self.__exp.ExperimentIO, doc_ops=self.__exp.DocumentOperations, diff --git a/arelight/pipelines/train_bert.py b/arelight/pipelines/train_bert.py index 52be8f3..80f5545 100644 --- a/arelight/pipelines/train_bert.py +++ b/arelight/pipelines/train_bert.py @@ -2,7 +2,7 @@ from arekit.common.data.storages.base import BaseRowsStorage from arekit.common.pipeline.context import PipelineContext from arekit.common.pipeline.items.base import BasePipelineItem -from arekit.common.synonyms import SynonymsCollection +from arekit.common.synonyms.base import SynonymsCollection from deeppavlov.models.bert import bert_classifier from deeppavlov.models.preprocessors.bert_preprocessor import BertPreprocessor from tqdm import tqdm diff --git a/arelight/text/ner_base.py b/arelight/text/ner_base.py index d3e6946..02fb3f2 100644 --- a/arelight/text/ner_base.py +++ b/arelight/text/ner_base.py @@ -1,4 +1,4 @@ -from arekit.processing.entities.obj_desc import NerObjectDescriptor +from arelight.ner_obj_desc import NerObjectDescriptor class BaseNER(object): diff --git a/arelight/text/pipeline_entities_bert_ontonotes.py b/arelight/text/pipeline_entities_bert_ontonotes.py index 3f133c2..418bc64 100644 --- a/arelight/text/pipeline_entities_bert_ontonotes.py +++ b/arelight/text/pipeline_entities_bert_ontonotes.py @@ -2,7 +2,8 @@ from arekit.common.entities.base import Entity from arekit.common.news.objects_parser import SentenceObjectsParserPipelineItem from arekit.common.text.partitioning.terms import TermsPartitioning -from arekit.processing.entities.obj_desc import NerObjectDescriptor + +from arelight.ner_obj_desc import NerObjectDescriptor from arelight.text.ner_ontonotes import BertOntonotesNER diff --git a/download.py b/download.py index 194d928..087d4e8 100644 --- a/download.py +++ b/download.py @@ -1,6 +1,8 @@ import os import tarfile -from arekit.contrib.source import utils + +from arekit.common import utils + from examples.args import const @@ -8,7 +10,6 @@ def download_examples_data(): root_dir = utils.get_default_download_dir() data = { - const.EMBEDDING_FILEPATH: "http://rusvectores.org/static/models/rusvectores2/news_mystem_skipgram_1000_20_2015.bin.gz", const.SYNONYMS_FILEPATH: "https://raw.githubusercontent.com/nicolay-r/RuSentRel/v1.1/synonyms.txt", # PCNN: pretrained model dir. const.PCNN_DEFAULT_MODEL_TAR: "https://www.dropbox.com/s/ceqy69vj59te534/fx_ctx_pcnn.tar.gz?dl=1", diff --git a/examples/args/common.py b/examples/args/common.py index 24459ee..018521a 100644 --- a/examples/args/common.py +++ b/examples/args/common.py @@ -1,12 +1,9 @@ -from arekit.contrib.bert.samplers.types import SampleFormattersService -from arekit.contrib.experiment_rusentrel.entities.types import EntityFormattersService -from arekit.contrib.experiment_rusentrel.labels.formatters.rusentiframes import ExperimentRuSentiFramesLabelsFormatter from arekit.contrib.networks.enum_name_types import ModelNamesService from arekit.contrib.source.rusentiframes.collection import RuSentiFramesCollection from arekit.contrib.source.rusentiframes.types import RuSentiFramesVersionsService, RuSentiFramesVersions -from arekit.processing.lemmatization.mystem import MystemWrapper -from arelight.text.pipeline_entities_bert_ontonotes import BertOntonotesNERPipelineItem +from arekit.contrib.utils.processing.lemmatization.mystem import MystemWrapper +from arelight.text.pipeline_entities_bert_ontonotes import BertOntonotesNERPipelineItem from arelight.text.pipeline_entities_default import TextEntitiesParser from examples.args.base import BaseArg diff --git a/examples/args/const.py b/examples/args/const.py index c9bdcac..40a2ecf 100644 --- a/examples/args/const.py +++ b/examples/args/const.py @@ -10,7 +10,6 @@ DATA_DIR = join(current_dir, "../../data") DEFAULT_TEXT_FILEPATH = join(DATA_DIR, "texts-inosmi-rus/e1.txt") -EMBEDDING_FILEPATH = join(DATA_DIR, "news_mystem_skipgram_1000_20_2015.bin.gz") SYNONYMS_FILEPATH = join(DATA_DIR, "synonyms.txt") # Common model dir. diff --git a/examples/demo/wui_nn.py b/examples/demo/wui_nn.py index b247cd5..27df8f7 100644 --- a/examples/demo/wui_nn.py +++ b/examples/demo/wui_nn.py @@ -7,14 +7,13 @@ import sys from os.path import join, basename -from arekit.contrib.experiment_rusentrel.labels.formatters.rusentiframes import ExperimentRuSentiFramesLabelsFormatter from arekit.contrib.networks.enum_input_types import ModelInputType from arekit.contrib.networks.enum_name_types import ModelNames from arekit.contrib.source.rusentiframes.collection import RuSentiFramesCollection from arekit.contrib.source.rusentiframes.types import RuSentiFramesVersions from arelight.demo.infer_nn_rus import demo_infer_texts_tensorflow_nn_pipeline - +from arelight.labels.formatter import ExperimentRuSentiFramesLabelsFormatter bratUrl = '/brat/' @@ -74,7 +73,6 @@ def prepare_template(data, text, bratUrl, model_name): model_input_type=ModelInputType.SingleInstance, synonyms_filepath=join(data_dir, "synonyms.txt"), model_load_dir=join(data_dir, "models"), - embedding_filepath=join(data_dir, "news_mystem_skipgram_1000_20_2015.bin.gz"), frames_collection=frames_collection) brat_json = ppl.run([text.strip()]) diff --git a/examples/infer_texts_nn.py b/examples/infer_texts_nn.py index 0c331c9..327fc55 100644 --- a/examples/infer_texts_nn.py +++ b/examples/infer_texts_nn.py @@ -19,7 +19,6 @@ common.InputTextArg.add_argument(parser, default=None) common.FromFilesArg.add_argument(parser, default=[const.DEFAULT_TEXT_FILEPATH]) common.SynonymsCollectionFilepathArg.add_argument(parser, default=join(const.DATA_DIR, "synonyms.txt")) - common.RusVectoresEmbeddingFilepathArg.add_argument(parser, default=const.EMBEDDING_FILEPATH) common.LabelsCountArg.add_argument(parser, default=3) common.ModelNameArg.add_argument(parser, default=ModelNames.PCNN.value) common.TermsPerContextArg.add_argument(parser, default=const.TERMS_PER_CONTEXT) @@ -60,7 +59,6 @@ frames_collection=common.FramesColectionArg.read_argument(args), vocab_filepath=common.VocabFilepathArg.read_argument(args), embedding_matrix_filepath=common.EmbeddingMatrixFilepathArg.read_argument(args), - embedding_filepath=common.RusVectoresEmbeddingFilepathArg.read_argument(args), model_load_dir=common.ModelLoadDirArg.read_argument(args), terms_per_context=common.TermsPerContextArg.read_argument(args), synonyms_filepath=common.SynonymsCollectionFilepathArg.read_argument(args), diff --git a/examples/rusentrel/__init__.py b/examples/labels/__init__.py similarity index 100% rename from examples/rusentrel/__init__.py rename to examples/labels/__init__.py diff --git a/examples/labels/base.py b/examples/labels/base.py new file mode 100644 index 0000000..63d0a2c --- /dev/null +++ b/examples/labels/base.py @@ -0,0 +1,10 @@ +from arekit.common.labels.base import Label + + +class PositiveLabel(Label): + pass + + +class NegativeLabel(Label): + pass + diff --git a/examples/labels/formatter.py b/examples/labels/formatter.py new file mode 100644 index 0000000..ac5ccf5 --- /dev/null +++ b/examples/labels/formatter.py @@ -0,0 +1,26 @@ +from arekit.contrib.source.rusentiframes.labels_fmt import RuSentiFramesLabelsFormatter +from arekit.contrib.source.rusentrel.labels_fmt import RuSentRelLabelsFormatter + +from examples.labels.base import NegativeLabel, PositiveLabel + + +class RuSentRelExperimentLabelsFormatter(RuSentRelLabelsFormatter): + + @classmethod + def _negative_label_type(cls): + return NegativeLabel + + @classmethod + def _positive_label_type(cls): + return PositiveLabel + + +class ExperimentRuSentiFramesLabelsFormatter(RuSentiFramesLabelsFormatter): + + @classmethod + def _positive_label_type(cls): + return PositiveLabel + + @classmethod + def _negative_label_type(cls): + return NegativeLabel \ No newline at end of file diff --git a/examples/labels/scalers.py b/examples/labels/scalers.py new file mode 100644 index 0000000..057c03b --- /dev/null +++ b/examples/labels/scalers.py @@ -0,0 +1,26 @@ +from collections import OrderedDict + +from arekit.common.labels.base import NoLabel +from arekit.common.labels.scaler.sentiment import SentimentLabelScaler + +from examples.labels.base import NegativeLabel, PositiveLabel + + +class ThreeLabelScaler(SentimentLabelScaler): + + def __init__(self): + + uint_labels = [(NoLabel(), 0), + (PositiveLabel(), 1), + (NegativeLabel(), 2)] + + int_labels = [(NoLabel(), 0), + (PositiveLabel(), 1), + (NegativeLabel(), -1)] + + super(ThreeLabelScaler, self).__init__(uint_dict=OrderedDict(uint_labels), + int_dict=OrderedDict(int_labels)) + + def invert_label(self, label): + int_label = self.label_to_int(label) + return self.int_to_label(-int_label) \ No newline at end of file diff --git a/examples/rusentrel/common.py b/examples/rusentrel/common.py deleted file mode 100644 index 334f00f..0000000 --- a/examples/rusentrel/common.py +++ /dev/null @@ -1,61 +0,0 @@ -from itertools import chain - -from arekit.common.experiment.data_type import DataType -from arekit.common.folding.nofold import NoFolding -from arekit.contrib.experiment_rusentrel.exp_ds.utils import read_ruattitudes_in_memory -from arekit.contrib.source.rusentrel.io_utils import RuSentRelIOUtils - -from arelight.network.nn.embedding import RusvectoresEmbedding - - -class Common: - - @staticmethod - def ra_doc_id_func(doc_id): - return 10000 + doc_id - - @staticmethod - def create_folding(rusentrel_version, ruattitudes_version, doc_id_func): - rsr_indices = list(RuSentRelIOUtils.iter_collection_indices(rusentrel_version)) - - ra_indices_dict = dict() - if ruattitudes_version is not None: - ra_indices_dict = read_ruattitudes_in_memory(version=ruattitudes_version, - keep_doc_ids_only=True, - doc_id_func=doc_id_func) - - return NoFolding(doc_ids_to_fold=list(chain(rsr_indices, ra_indices_dict.keys())), - supported_data_types=[DataType.Train]) - - @staticmethod - def create_exp_name(rusentrel_version, ra_version, folding_type): - return "".join(["rsr-{v}".format(v=rusentrel_version.value), - "-ra-{v}".format(v=ra_version.value) if ra_version is not None else "", - "-{ft}".format(ft=folding_type.value)]) - - @staticmethod - def create_exp_name_suffix(use_balancing, terms_per_context, dist_in_terms_between_att_ends): - """ Provides an external parameters that assumes to be synchronized both - by serialization and training experiment stages. - """ - assert(isinstance(use_balancing, bool)) - assert(isinstance(terms_per_context, int)) - assert(isinstance(dist_in_terms_between_att_ends, int) or dist_in_terms_between_att_ends is None) - - # You may provide your own parameters out there - params = [ - u"balanced" if use_balancing else u"nobalance", - u"tpc{}".format(terms_per_context) - ] - - if dist_in_terms_between_att_ends is not None: - params.append(u"dbe{}".format(dist_in_terms_between_att_ends)) - - return u'-'.join(params) - - @staticmethod - def load_rusvectores_embedding(filepath, stemmer): - embedding = RusvectoresEmbedding.from_word2vec_format(filepath=filepath, binary=True) - embedding.set_stemmer(stemmer) - return embedding - diff --git a/examples/rusentrel/config_setups.py b/examples/rusentrel/config_setups.py deleted file mode 100644 index 8a24871..0000000 --- a/examples/rusentrel/config_setups.py +++ /dev/null @@ -1,71 +0,0 @@ -from arekit.contrib.experiment_rusentrel.types import ExperimentTypes -from arekit.contrib.networks.context.configurations.base.base import DefaultNetworkConfig -from arekit.contrib.networks.enum_input_types import ModelInputType -from arekit.contrib.networks.enum_name_types import ModelNames -from arekit.contrib.networks.multi.configurations.base import BaseMultiInstanceConfig - -from examples.rusentrel.configs.common import apply_classic_mi_settings -from examples.rusentrel.configs.mi import apply_ds_mi_settings -from examples.rusentrel.configs.single import ctx_self_att_bilstm_custom_config, ctx_att_bilstm_p_zhou_custom_config, \ - ctx_att_bilstm_z_yang_custom_config, ctx_bilstm_custom_config, ctx_cnn_custom_config, ctx_lstm_custom_config, \ - ctx_pcnn_custom_config, ctx_rcnn_custom_config, ctx_rcnn_z_yang_custom_config, ctx_rcnn_p_zhou_custom_config - - -def modify_config_for_model(model_name, model_input_type, config): - assert(isinstance(model_name, ModelNames)) - assert(isinstance(model_input_type, ModelInputType)) - assert (isinstance(config, DefaultNetworkConfig)) - - if model_input_type == ModelInputType.SingleInstance: - if model_name == ModelNames.SelfAttentionBiLSTM: - ctx_self_att_bilstm_custom_config(config) - if model_name == ModelNames.AttSelfPZhouBiLSTM: - ctx_att_bilstm_p_zhou_custom_config(config) - if model_name == ModelNames.AttSelfZYangBiLSTM: - ctx_att_bilstm_z_yang_custom_config(config) - if model_name == ModelNames.BiLSTM: - ctx_bilstm_custom_config(config) - if model_name == ModelNames.CNN: - ctx_cnn_custom_config(config) - if model_name == ModelNames.LSTM: - ctx_lstm_custom_config(config) - if model_name == ModelNames.PCNN: - ctx_pcnn_custom_config(config) - if model_name == ModelNames.RCNN: - ctx_rcnn_custom_config(config) - if model_name == ModelNames.RCNNAttZYang: - ctx_rcnn_z_yang_custom_config(config) - if model_name == ModelNames.RCNNAttPZhou: - ctx_rcnn_p_zhou_custom_config(config) - - return - - if model_input_type == ModelInputType.MultiInstanceMaxPooling or \ - model_input_type == ModelInputType.MultiInstanceWithSelfAttention: - assert(isinstance(config, BaseMultiInstanceConfig)) - - # We assign all the settings related to the case of - # single instance model, for the related ContextConfig. - modify_config_for_model(model_name=model_name, - model_input_type=ModelInputType.SingleInstance, - config=config.ContextConfig) - - # We apply modification of some parameters - config.fix_context_parameters() - return - - raise NotImplementedError(u"Model input type {input_type} is not supported".format( - input_type=model_input_type)) - - -def optionally_modify_config_for_experiment(config, exp_type, model_input_type): - assert(isinstance(exp_type, ExperimentTypes)) - assert(isinstance(model_input_type, ModelInputType)) - - if model_input_type == ModelInputType.MultiInstanceMaxPooling: - if exp_type == ExperimentTypes.RuSentRel: - apply_classic_mi_settings(config) - if exp_type == ExperimentTypes.RuAttitudes or exp_type == ExperimentTypes.RuSentRelWithRuAttitudes: - apply_ds_mi_settings(config) - - return diff --git a/examples/rusentrel/configs/__init__.py b/examples/rusentrel/configs/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/examples/rusentrel/configs/common.py b/examples/rusentrel/configs/common.py deleted file mode 100644 index db77fc5..0000000 --- a/examples/rusentrel/configs/common.py +++ /dev/null @@ -1,14 +0,0 @@ -from arekit.contrib.networks.multi.configurations.base import BaseMultiInstanceConfig - -from examples.args.const import BAGS_PER_MINIBATCH - -MI_CONTEXTS_PER_OPINION = 3 - - -def apply_classic_mi_settings(config): - """ - Multi instance version - """ - assert(isinstance(config, BaseMultiInstanceConfig)) - config.set_contexts_per_opinion(MI_CONTEXTS_PER_OPINION) - config.modify_bags_per_minibatch(BAGS_PER_MINIBATCH) diff --git a/examples/rusentrel/configs/mi.py b/examples/rusentrel/configs/mi.py deleted file mode 100644 index 8996de3..0000000 --- a/examples/rusentrel/configs/mi.py +++ /dev/null @@ -1,12 +0,0 @@ -from arekit.contrib.networks.multi.configurations.base import BaseMultiInstanceConfig - -from examples.rusentrel.configs.common import MI_CONTEXTS_PER_OPINION - - -def apply_ds_mi_settings(config): - """ - This function describes a base config setup for all models. - """ - assert(isinstance(config, BaseMultiInstanceConfig)) - config.set_contexts_per_opinion(MI_CONTEXTS_PER_OPINION) - config.modify_bags_per_minibatch(2) \ No newline at end of file diff --git a/examples/rusentrel/configs/single.py b/examples/rusentrel/configs/single.py deleted file mode 100644 index 0b736d4..0000000 --- a/examples/rusentrel/configs/single.py +++ /dev/null @@ -1,76 +0,0 @@ -import tensorflow as tf - -from arekit.contrib.networks.context.configurations.att_self_p_zhou_bilstm import AttentionSelfPZhouBiLSTMConfig -from arekit.contrib.networks.context.configurations.att_self_z_yang_bilstm import AttentionSelfZYangBiLSTMConfig -from arekit.contrib.networks.context.configurations.bilstm import BiLSTMConfig -from arekit.contrib.networks.context.configurations.cnn import CNNConfig -from arekit.contrib.networks.context.configurations.rcnn import RCNNConfig -from arekit.contrib.networks.context.configurations.rnn import RNNConfig -from arekit.contrib.networks.context.configurations.self_att_bilstm import SelfAttentionBiLSTMConfig -from arekit.contrib.networks.tf_helpers.cell_types import CellTypes - -from examples.args.const import TERMS_PER_CONTEXT - - -def ctx_self_att_bilstm_custom_config(config): - assert(isinstance(config, SelfAttentionBiLSTMConfig)) - config.modify_penaltization_term_coef(0.5) - config.modify_cell_type(CellTypes.BasicLSTM) - config.modify_dropout_rnn_keep_prob(0.8) - config.modify_terms_per_context(TERMS_PER_CONTEXT) - - -def ctx_att_bilstm_p_zhou_custom_config(config): - assert(isinstance(config, AttentionSelfPZhouBiLSTMConfig)) - config.modify_hidden_size(128) - config.modify_cell_type(CellTypes.LSTM) - config.modify_dropout_rnn_keep_prob(0.9) - - -def ctx_att_bilstm_z_yang_custom_config(config): - assert(isinstance(config, AttentionSelfZYangBiLSTMConfig)) - config.modify_weight_initializer(tf.contrib.layers.xavier_initializer()) - - -def ctx_bilstm_custom_config(config): - assert(isinstance(config, BiLSTMConfig)) - config.modify_hidden_size(128) - config.modify_cell_type(CellTypes.BasicLSTM) - config.modify_dropout_rnn_keep_prob(0.8) - config.modify_terms_per_context(TERMS_PER_CONTEXT) - - -def ctx_cnn_custom_config(config): - assert(isinstance(config, CNNConfig)) - config.modify_weight_initializer(tf.contrib.layers.xavier_initializer()) - - -def ctx_lstm_custom_config(config): - assert(isinstance(config, RNNConfig)) - config.modify_cell_type(CellTypes.BasicLSTM) - config.modify_hidden_size(128) - config.modify_dropout_rnn_keep_prob(0.8) - config.modify_terms_per_context(TERMS_PER_CONTEXT) - - -def ctx_pcnn_custom_config(config): - assert(isinstance(config, CNNConfig)) - config.modify_weight_initializer(tf.contrib.layers.xavier_initializer()) - - -def ctx_rcnn_custom_config(config): - assert(isinstance(config, RCNNConfig)) - config.modify_cell_type(CellTypes.LSTM) - config.modify_dropout_rnn_keep_prob(0.9) - - -def ctx_rcnn_p_zhou_custom_config(config): - assert(isinstance(config, RCNNConfig)) - config.modify_cell_type(CellTypes.LSTM) - config.modify_dropout_rnn_keep_prob(0.9) - - -def ctx_rcnn_z_yang_custom_config(config): - assert(isinstance(config, RCNNConfig)) - config.modify_cell_type(CellTypes.LSTM) - config.modify_dropout_rnn_keep_prob(0.9) \ No newline at end of file diff --git a/examples/rusentrel/exp_io.py b/examples/rusentrel/exp_io.py deleted file mode 100644 index f22780d..0000000 --- a/examples/rusentrel/exp_io.py +++ /dev/null @@ -1,11 +0,0 @@ -from arekit.contrib.experiment_rusentrel.model_io.tf_networks import RuSentRelExperimentNetworkIOUtils -from examples.args.const import OUTPUT_DIR - - -class CustomRuSentRelNetworkExperimentIO(RuSentRelExperimentNetworkIOUtils): - - def try_prepare(self): - pass - - def _get_experiment_sources_dir(self): - return OUTPUT_DIR diff --git a/examples/serialize_rusentrel_for_bert.py b/examples/serialize_rusentrel_for_bert.py deleted file mode 100644 index 5d30c63..0000000 --- a/examples/serialize_rusentrel_for_bert.py +++ /dev/null @@ -1,123 +0,0 @@ -import argparse - -from arekit.common.experiment.annot.algo.pair_based import PairBasedAnnotationAlgorithm -from arekit.common.experiment.annot.default import DefaultAnnotator -from arekit.common.experiment.engine import ExperimentEngine -from arekit.common.experiment.name_provider import ExperimentNameProvider -from arekit.common.folding.types import FoldingType -from arekit.common.labels.provider.constant import ConstantLabelProvider -from arekit.common.labels.str_fmt import StringLabelsFormatter -from arekit.contrib.bert.handlers.serializer import BertExperimentInputSerializerIterationHandler -from arekit.contrib.experiment_rusentrel.entities.factory import create_entity_formatter -from arekit.contrib.experiment_rusentrel.factory import create_experiment -from arekit.contrib.experiment_rusentrel.labels.types import ExperimentNeutralLabel, ExperimentPositiveLabel, \ - ExperimentNegativeLabel -from arekit.contrib.experiment_rusentrel.synonyms.provider import RuSentRelSynonymsCollectionProvider -from arekit.contrib.experiment_rusentrel.types import ExperimentTypes -from arekit.contrib.source.rusentrel.io_utils import RuSentRelVersions - -from examples.args import const, common -from examples.args.const import DEFAULT_TEXT_FILEPATH -from examples.rusentrel.common import Common -from examples.rusentrel.exp_io import CustomRuSentRelNetworkExperimentIO - -from arelight.network.bert.ctx import BertSerializationContext -from examples.utils import create_labels_scaler - - -class ExperimentBERTTextBThreeScaleLabelsFormatter(StringLabelsFormatter): - - def __init__(self): - stol = {'neu': ExperimentNeutralLabel, - 'pos': ExperimentPositiveLabel, - 'neg': ExperimentNegativeLabel} - super(ExperimentBERTTextBThreeScaleLabelsFormatter, self).__init__(stol=stol) - - -if __name__ == '__main__': - - parser = argparse.ArgumentParser(description="Serialization script for obtaining sources, " - "required for inference and training.") - - # Provide arguments. - common.InputTextArg.add_argument(parser, default=None) - common.FromFilesArg.add_argument(parser, default=[DEFAULT_TEXT_FILEPATH]) - common.EntitiesParserArg.add_argument(parser, default="bert-ontonotes") - common.RusVectoresEmbeddingFilepathArg.add_argument(parser, default=const.EMBEDDING_FILEPATH) - common.TermsPerContextArg.add_argument(parser, default=const.TERMS_PER_CONTEXT) - common.UseBalancingArg.add_argument(parser, default=True) - common.DistanceInTermsBetweenAttitudeEndsArg.add_argument(parser, default=None) - common.EntityFormatterTypesArg.add_argument(parser, default="hidden-bert-styled") - common.BertTextBFormatTypeArg.add_argument(parser, default='nli_m') - common.StemmerArg.add_argument(parser, default="mystem") - - # Parsing arguments. - args = parser.parse_args() - - # Reading arguments. - text_from_arg = common.InputTextArg.read_argument(args) - texts_from_files = common.FromFilesArg.read_argument(args) - terms_per_context = common.TermsPerContextArg.read_argument(args) - use_balancing = common.UseBalancingArg.read_argument(args) - stemmer = common.StemmerArg.read_argument(args) - entity_fmt = common.EntityFormatterTypesArg.read_argument(args) - dist_in_terms_between_attitude_ends = common.DistanceInTermsBetweenAttitudeEndsArg.read_argument(args) - - # Predefined parameters. - labels_count = 3 - rusentrel_version = RuSentRelVersions.V11 - synonyms = RuSentRelSynonymsCollectionProvider.load_collection(stemmer=stemmer) - folding_type = FoldingType.Fixed - - annot_algo = PairBasedAnnotationAlgorithm( - dist_in_terms_bound=None, - label_provider=ConstantLabelProvider(label_instance=ExperimentNeutralLabel())) - - exp_name = Common.create_exp_name(rusentrel_version=rusentrel_version, - ra_version=None, - folding_type=folding_type) - - extra_name_suffix = Common.create_exp_name_suffix( - use_balancing=use_balancing, - terms_per_context=terms_per_context, - dist_in_terms_between_att_ends=dist_in_terms_between_attitude_ends) - - data_folding = Common.create_folding( - rusentrel_version=rusentrel_version, - ruattitudes_version=None, - doc_id_func=lambda doc_id: Common.ra_doc_id_func(doc_id=doc_id)) - - # Preparing necessary structures for further initializations. - exp_ctx = BertSerializationContext( - label_scaler=create_labels_scaler(labels_count), - terms_per_context=terms_per_context, - str_entity_formatter=create_entity_formatter(entity_fmt), - annotator=DefaultAnnotator(annot_algo=annot_algo), - name_provider=ExperimentNameProvider(name=exp_name, suffix=extra_name_suffix + "-bert"), - data_folding=data_folding) - - experiment = create_experiment( - exp_type=ExperimentTypes.RuSentRel, - exp_ctx=exp_ctx, - exp_io=CustomRuSentRelNetworkExperimentIO(exp_ctx), - folding_type=folding_type, - rusentrel_version=rusentrel_version, - ruattitudes_version=None, - load_ruattitude_docs=True, - ra_doc_id_func=lambda doc_id: Common.ra_doc_id_func(doc_id=doc_id)) - - handler = BertExperimentInputSerializerIterationHandler( - exp_io=experiment.ExperimentIO, - exp_ctx=experiment.ExperimentContext, - doc_ops=experiment.DocumentOperations, - opin_ops=experiment.OpinionOperations, - sample_labels_fmt=ExperimentBERTTextBThreeScaleLabelsFormatter(), - annot_labels_fmt=experiment.OpinionOperations.LabelsFormatter, - sample_provider_type=common.BertTextBFormatTypeArg.read_argument(args), - entity_formatter=experiment.ExperimentContext.StringEntityFormatter, - value_to_group_id_func=synonyms.get_synonym_group_index, - balance_train_samples=use_balancing) - - engine = ExperimentEngine(exp_ctx.DataFolding) - - engine.run(handlers=[handler]) diff --git a/examples/serialize_rusentrel_for_nn.py b/examples/serialize_rusentrel_for_nn.py deleted file mode 100644 index f0b3f6e..0000000 --- a/examples/serialize_rusentrel_for_nn.py +++ /dev/null @@ -1,120 +0,0 @@ -import argparse - -from arelight.network.nn.common import create_and_fill_variant_collection -from arelight.network.nn.ctx import NetworkSerializationContext - -from examples.args import const, common - -from arekit.common.experiment.annot.algo.pair_based import PairBasedAnnotationAlgorithm -from arekit.common.experiment.annot.default import DefaultAnnotator -from arekit.common.experiment.engine import ExperimentEngine -from arekit.common.experiment.name_provider import ExperimentNameProvider -from arekit.common.folding.types import FoldingType -from arekit.common.labels.provider.constant import ConstantLabelProvider -from arekit.contrib.experiment_rusentrel.entities.factory import create_entity_formatter -from arekit.contrib.experiment_rusentrel.factory import create_experiment -from arekit.contrib.experiment_rusentrel.labels.types import ExperimentNeutralLabel -from arekit.contrib.experiment_rusentrel.synonyms.provider import RuSentRelSynonymsCollectionProvider -from arekit.contrib.experiment_rusentrel.types import ExperimentTypes -from arekit.contrib.networks.handlers.serializer import NetworksInputSerializerExperimentIteration -from arekit.contrib.source.rusentrel.io_utils import RuSentRelVersions -from arekit.processing.lemmatization.mystem import MystemWrapper -from arekit.processing.pos.mystem_wrap import POSMystemWrapper -from arekit.processing.text.pipeline_frames_lemmatized import LemmasBasedFrameVariantsParser -from arekit.processing.text.pipeline_tokenizer import DefaultTextTokenizer - -from examples.rusentrel.common import Common -from examples.rusentrel.exp_io import CustomRuSentRelNetworkExperimentIO -from examples.utils import create_labels_scaler - -if __name__ == '__main__': - - parser = argparse.ArgumentParser(description="RuSentRel dataset serialization script") - - # Provide arguments. - common.LabelsCountArg.add_argument(parser, default=3) - common.RusVectoresEmbeddingFilepathArg.add_argument(parser, default=const.EMBEDDING_FILEPATH) - common.TermsPerContextArg.add_argument(parser, default=const.TERMS_PER_CONTEXT) - common.EntityFormatterTypesArg.add_argument(parser, default="hidden-simple-eng") - common.StemmerArg.add_argument(parser, default="mystem") - common.UseBalancingArg.add_argument(parser, default=True) - common.DistanceInTermsBetweenAttitudeEndsArg.add_argument(parser, default=None) - common.FramesColectionArg.add_argument(parser) - - # Parsing arguments. - args = parser.parse_args() - - # Reading arguments. - embedding_filepath = common.RusVectoresEmbeddingFilepathArg.read_argument(args) - labels_count = common.LabelsCountArg.read_argument(args) - terms_per_context = common.TermsPerContextArg.read_argument(args) - entity_fmt = common.EntityFormatterTypesArg.read_argument(args) - stemmer = common.StemmerArg.read_argument(args) - use_balancing = common.UseBalancingArg.read_argument(args) - dist_in_terms_between_attitude_ends = common.DistanceInTermsBetweenAttitudeEndsArg.read_argument(args) - frames_collection = common.FramesColectionArg.read_argument(args) - pos_tagger = POSMystemWrapper(MystemWrapper().MystemInstance) - - # Default parameters - rusentrel_version = RuSentRelVersions.V11 - folding_type = FoldingType.Fixed - - synonyms_collection = RuSentRelSynonymsCollectionProvider.load_collection(stemmer=stemmer) - - annot_algo = PairBasedAnnotationAlgorithm( - dist_in_terms_bound=None, - label_provider=ConstantLabelProvider(label_instance=ExperimentNeutralLabel())) - - exp_name = Common.create_exp_name(rusentrel_version=rusentrel_version, - ra_version=None, - folding_type=folding_type) - - extra_name_suffix = Common.create_exp_name_suffix( - use_balancing=use_balancing, - terms_per_context=terms_per_context, - dist_in_terms_between_att_ends=dist_in_terms_between_attitude_ends) - - data_folding = Common.create_folding( - rusentrel_version=rusentrel_version, - ruattitudes_version=None, - doc_id_func=lambda doc_id: Common.ra_doc_id_func(doc_id=doc_id)) - - # Preparing necessary structures for further initializations. - exp_ctx = NetworkSerializationContext( - labels_scaler=create_labels_scaler(labels_count), - embedding=Common.load_rusvectores_embedding(filepath=embedding_filepath, stemmer=stemmer), - terms_per_context=terms_per_context, - str_entity_formatter=create_entity_formatter(entity_fmt), - pos_tagger=pos_tagger, - annotator=DefaultAnnotator(annot_algo=annot_algo), - name_provider=ExperimentNameProvider(name=exp_name, suffix=extra_name_suffix), - frames_collection=frames_collection, - frame_variant_collection=create_and_fill_variant_collection(frames_collection), - data_folding=data_folding) - - experiment = create_experiment( - exp_type=ExperimentTypes.RuSentRel, - exp_ctx=exp_ctx, - exp_io=CustomRuSentRelNetworkExperimentIO(exp_ctx), - folding_type=folding_type, - rusentrel_version=rusentrel_version, - ruattitudes_version=None, - load_ruattitude_docs=True, - text_parser_items=[ - DefaultTextTokenizer(keep_tokens=True), - LemmasBasedFrameVariantsParser(frame_variants=exp_ctx.FrameVariantCollection, stemmer=stemmer) - ], - ra_doc_id_func=lambda doc_id: Common.ra_doc_id_func(doc_id=doc_id)) - - # Performing serialization process. - serialization_handler = NetworksInputSerializerExperimentIteration( - exp_ctx=experiment.ExperimentContext, - doc_ops=experiment.DocumentOperations, - opin_ops=experiment.OpinionOperations, - exp_io=experiment.ExperimentIO, - balance=use_balancing, - value_to_group_id_func=synonyms_collection.get_synonym_group_index) - - engine = ExperimentEngine(exp_ctx.DataFolding) - - engine.run(handlers=[serialization_handler]) diff --git a/examples/serialize_texts_bert.py b/examples/serialize_texts_bert.py index 6b011fc..fca7db8 100644 --- a/examples/serialize_texts_bert.py +++ b/examples/serialize_texts_bert.py @@ -1,15 +1,14 @@ import argparse from os.path import join -from arekit.common.experiment.annot.algo.pair_based import PairBasedAnnotationAlgorithm -from arekit.common.experiment.annot.default import DefaultAnnotator from arekit.common.experiment.data_type import DataType from arekit.common.experiment.name_provider import ExperimentNameProvider from arekit.common.folding.nofold import NoFolding from arekit.common.labels.base import NoLabel from arekit.common.labels.provider.constant import ConstantLabelProvider +from arekit.common.opinions.annot.algo.pair_based import PairBasedOpinionAnnotationAlgorithm +from arekit.common.opinions.annot.base import BaseOpinionAnnotator from arekit.common.pipeline.base import BasePipeline -from arekit.contrib.experiment_rusentrel.entities.factory import create_entity_formatter from arelight.pipelines.serialize_bert import BertTextsSerializationPipelineItem from examples.args import const, common @@ -48,7 +47,7 @@ name_provider=ExperimentNameProvider(name="example-bert", suffix="serialize"), entity_fmt=create_entity_formatter(common.EntityFormatterTypesArg.read_argument(args)), text_b_type=common.BertTextBFormatTypeArg.read_argument(args), - opin_annot=DefaultAnnotator(annot_algo=PairBasedAnnotationAlgorithm( + opin_annot=BaseOpinionAnnotator(annot_algo=PairBasedOpinionAnnotationAlgorithm( dist_in_terms_bound=None, label_provider=ConstantLabelProvider(label_instance=NoLabel()))), data_folding=NoFolding(doc_ids_to_fold=list(range(len(texts_from_files))), diff --git a/examples/serialize_texts_nn.py b/examples/serialize_texts_nn.py index dd4d225..3e9cae7 100644 --- a/examples/serialize_texts_nn.py +++ b/examples/serialize_texts_nn.py @@ -1,15 +1,14 @@ import argparse from os.path import join -from arekit.common.experiment.annot.algo.pair_based import PairBasedAnnotationAlgorithm -from arekit.common.experiment.annot.default import DefaultAnnotator from arekit.common.experiment.data_type import DataType from arekit.common.experiment.name_provider import ExperimentNameProvider from arekit.common.folding.nofold import NoFolding from arekit.common.labels.base import NoLabel from arekit.common.labels.provider.constant import ConstantLabelProvider +from arekit.common.opinions.annot.algo.pair_based import PairBasedOpinionAnnotationAlgorithm +from arekit.common.opinions.annot.base import BaseOpinionAnnotator from arekit.common.pipeline.base import BasePipeline -from arekit.contrib.experiment_rusentrel.entities.factory import create_entity_formatter from examples.args import const from examples.args import common @@ -28,7 +27,6 @@ common.FromFilesArg.add_argument(parser, default=[DEFAULT_TEXT_FILEPATH]) common.SynonymsCollectionFilepathArg.add_argument(parser, default=join(const.DATA_DIR, "synonyms.txt")) common.EntitiesParserArg.add_argument(parser, default="bert-ontonotes") - common.RusVectoresEmbeddingFilepathArg.add_argument(parser, default=const.EMBEDDING_FILEPATH) common.TermsPerContextArg.add_argument(parser, default=const.TERMS_PER_CONTEXT) common.EntityFormatterTypesArg.add_argument(parser, default="hidden-simple-eng") common.StemmerArg.add_argument(parser, default="mystem") @@ -46,10 +44,9 @@ synonyms=synonyms_collection, output_dir=const.OUTPUT_DIR, entities_parser=common.EntitiesParserArg.read_argument(args), - embedding_path=common.RusVectoresEmbeddingFilepathArg.read_argument(args), name_provider=ExperimentNameProvider(name="example", suffix="serialize"), entity_fmt=create_entity_formatter(common.EntityFormatterTypesArg.read_argument(args)), - opin_annot=DefaultAnnotator(annot_algo=PairBasedAnnotationAlgorithm( + opin_annot=BaseOpinionAnnotator(annot_algo=PairBasedOpinionAnnotationAlgorithm( dist_in_terms_bound=None, label_provider=ConstantLabelProvider(label_instance=NoLabel()))), stemmer=common.StemmerArg.read_argument(args), diff --git a/examples/train_nn_on_rusentrel.py b/examples/train_nn_on_rusentrel.py index 3752ddd..58fc6ec 100644 --- a/examples/train_nn_on_rusentrel.py +++ b/examples/train_nn_on_rusentrel.py @@ -1,5 +1,8 @@ import argparse +from arekit.contrib.utils.np_utils.writer import NpzDataWriter +from arekit.contrib.utils.processing.languages.ru.pos_service import PartOfSpeechTypesService + from arelight.network.nn.common import create_bags_collection_type, create_full_model_name, create_network_model_io from examples.args import const, train @@ -8,11 +11,8 @@ from examples.rusentrel.config_setups import optionally_modify_config_for_experiment, modify_config_for_model from examples.rusentrel.exp_io import CustomRuSentRelNetworkExperimentIO -from arekit.common.experiment.api.ctx_training import ExperimentTrainingContext -from arekit.common.experiment.engine import ExperimentEngine from arekit.common.experiment.name_provider import ExperimentNameProvider from arekit.common.folding.types import FoldingType -from arekit.contrib.experiment_rusentrel.types import ExperimentTypes from arekit.contrib.networks.context.configurations.base.base import DefaultNetworkConfig from arekit.contrib.networks.core.callback.hidden import HiddenStatesWriterCallback from arekit.contrib.networks.core.callback.hidden_input import InputHiddenStatesWriterCallback @@ -21,10 +21,7 @@ from arekit.contrib.networks.enum_input_types import ModelInputType from arekit.contrib.networks.enum_name_types import ModelNames from arekit.contrib.networks.factory import create_network_and_network_config_funcs -from arekit.contrib.networks.np_utils.writer import NpzDataWriter -from arekit.contrib.networks.handlers.training import NetworksTrainingIterationHandler from arekit.contrib.source.rusentrel.io_utils import RuSentRelVersions -from arekit.processing.languages.ru.pos_service import PartOfSpeechTypesService from examples.args import common from examples.utils import create_labels_scaler @@ -66,6 +63,7 @@ epochs_count = train.EpochsCountArg.read_argument(args) # Utilize predefined versions and folding format. + # TODO. This is outdated. exp_type = ExperimentTypes.RuSentRel rusentrel_version = RuSentRelVersions.V11 folding_type = FoldingType.Fixed @@ -153,6 +151,7 @@ InputHiddenStatesWriterCallback(log_dir=model_target_dir, writer=data_writer) ] + # TODO. Switch to pipeline. training_handler = NetworksTrainingIterationHandler( load_model=model_load_dir is not None, exp_ctx=exp_ctx, @@ -163,6 +162,7 @@ network_callbacks=nework_callbacks, training_epochs=epochs_count) + # TODO. Engine does not exists anymore. engine = ExperimentEngine(exp_ctx.DataFolding) engine.run(handlers=[training_handler]) diff --git a/examples/utils.py b/examples/utils.py index 87d705d..d355810 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -1,8 +1,4 @@ -from arekit.contrib.experiment_rusentrel.labels.scalers.three import ThreeLabelScaler -from arekit.contrib.experiment_rusentrel.labels.scalers.two import TwoLabelScaler -from arekit.contrib.experiment_rusentrel.synonyms.collection import StemmerBasedSynonymCollection -from arekit.contrib.source.rusentrel.utils import iter_synonym_groups -from arekit.processing.lemmatization.mystem import MystemWrapper +from arelight.labels.scalers import ThreeLabelScaler def create_labels_scaler(labels_count): diff --git a/setup.py b/setup.py index 463f5d3..6249264 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ def get_requirements(filenames): setup( name='arelight', - version='0.22.0', + version='0.22.1', description='About Mass-media text processing application for your ' 'Relation Extraction task, powered by AREkit.', url='https://github.com/nicolay-r/ARElight', diff --git a/test/test_bert_ontonotes_ner.py b/test/test_bert_ontonotes_ner.py index 8fde7fb..4ab36ae 100644 --- a/test/test_bert_ontonotes_ner.py +++ b/test/test_bert_ontonotes_ner.py @@ -1,6 +1,6 @@ import unittest -from arekit.processing.entities.obj_desc import NerObjectDescriptor +from arelight.ner_obj_desc import NerObjectDescriptor from arelight.text.ner_ontonotes import BertOntonotesNER diff --git a/test/test_bert_ontonotes_ner_pipeline_item.py b/test/test_bert_ontonotes_ner_pipeline_item.py index 620c8bd..ee96798 100644 --- a/test/test_bert_ontonotes_ner_pipeline_item.py +++ b/test/test_bert_ontonotes_ner_pipeline_item.py @@ -4,7 +4,8 @@ from arekit.common.news.parser import NewsParser from arekit.common.news.sentence import BaseNewsSentence from arekit.common.text.parser import BaseTextParser -from arekit.processing.text.pipeline_terms_splitter import TermsSplitterParser +from arekit.contrib.utils.pipelines.items.text.terms_splitter import TermsSplitterParser + from arelight.text.pipeline_entities_bert_ontonotes import BertOntonotesNERPipelineItem diff --git a/test/test_bert_serialization.py b/test/test_bert_serialization.py index 5d9a4ad..2efaa5c 100644 --- a/test/test_bert_serialization.py +++ b/test/test_bert_serialization.py @@ -1,11 +1,7 @@ import unittest from os.path import dirname, join, realpath -from arekit.common.experiment.annot.algo.pair_based import PairBasedAnnotationAlgorithm -from arekit.common.experiment.annot.default import DefaultAnnotator -from arekit.common.experiment.api.base import BaseExperiment from arekit.common.experiment.data_type import DataType -from arekit.common.experiment.engine import ExperimentEngine from arekit.common.experiment.name_provider import ExperimentNameProvider from arekit.common.folding.nofold import NoFolding from arekit.common.labels.base import NoLabel @@ -15,15 +11,12 @@ from arekit.common.news.base import News from arekit.common.news.entities_grouping import EntitiesGroupingPipelineItem from arekit.common.news.sentence import BaseNewsSentence +from arekit.common.opinions.annot.algo.pair_based import PairBasedOpinionAnnotationAlgorithm from arekit.common.text.parser import BaseTextParser -from arekit.contrib.bert.handlers.serializer import BertExperimentInputSerializerIterationHandler -from arekit.contrib.bert.samplers.types import BertSampleProviderTypes -from arekit.contrib.experiment_rusentrel.entities.str_simple_sharp_prefixed_fmt import \ - SharpPrefixedEntitiesSimpleFormatter -from arekit.contrib.experiment_rusentrel.synonyms.collection import StemmerBasedSynonymCollection -from arekit.contrib.source.rusentrel.utils import iter_synonym_groups -from arekit.processing.lemmatization.mystem import MystemWrapper -from arekit.processing.text.pipeline_terms_splitter import TermsSplitterParser +from arekit.contrib.source.synonyms.utils import iter_synonym_groups +from arekit.contrib.utils.pipelines.items.text.terms_splitter import TermsSplitterParser +from arekit.contrib.utils.processing.lemmatization.mystem import MystemWrapper +from arekit.contrib.utils.synonyms.stemmer_based import StemmerBasedSynonymCollection from ru_sent_tokenize import ru_sent_tokenize from arelight.exp.doc_ops import CustomDocOperations @@ -95,7 +88,7 @@ def test(self): ]) # Declaring algo. - algo = PairBasedAnnotationAlgorithm( + algo = PairBasedOpinionAnnotationAlgorithm( label_provider=ConstantLabelProvider(label_instance=NoLabel()), dist_in_terms_bound=None) @@ -104,11 +97,9 @@ def test(self): supported_data_types=[DataType.Test]) exp_ctx = BertSerializationContext( label_scaler=SingleLabelScaler(NoLabel()), - annotator=DefaultAnnotator(algo), + # annotator=BaseOpinionAnnotator(algo), terms_per_context=50, - str_entity_formatter=SharpPrefixedEntitiesSimpleFormatter(), - name_provider=ExperimentNameProvider(name="example-bert", suffix="serialize"), - data_folding=no_folding) + name_provider=ExperimentNameProvider(name="example-bert", suffix="serialize")) # Composing labels formatter and experiment preparation. labels_fmt = StringLabelsFormatter(stol={"neu": NoLabel}) @@ -120,8 +111,10 @@ def test(self): synonyms=synonyms, neutral_labels_fmt=labels_fmt) + # TODO. No exp anymore. exp = BaseExperiment(exp_io=exp_io, exp_ctx=exp_ctx, doc_ops=doc_ops, opin_ops=opin_ops) + # TODO. To pipeline. handler = BertExperimentInputSerializerIterationHandler( exp_io=exp_io, exp_ctx=exp_ctx, @@ -136,11 +129,12 @@ def test(self): value_to_group_id_func=synonyms.get_synonym_group_index, balance_train_samples=True) - # Initilize documents. + # Initialize documents. docs = input_to_docs(texts) doc_ops.set_docs(docs) # Run. + # TODO. No experiment engine anymore. engine = ExperimentEngine(exp_ctx.DataFolding) # Present folding limitation. engine.run([handler]) diff --git a/test/test_demo.py b/test/test_demo.py index cb68de4..ce1591e 100644 --- a/test/test_demo.py +++ b/test/test_demo.py @@ -2,7 +2,6 @@ import unittest from os.path import dirname, realpath, join -from arekit.contrib.experiment_rusentrel.labels.formatters.rusentiframes import ExperimentRuSentiFramesLabelsFormatter from arekit.contrib.networks.enum_input_types import ModelInputType from arekit.contrib.networks.enum_name_types import ModelNames from arekit.contrib.source.rusentiframes.collection import RuSentiFramesCollection @@ -10,6 +9,7 @@ from arelight.demo.infer_bert_rus import demo_infer_texts_bert_pipeline from arelight.demo.infer_nn_rus import demo_infer_texts_tensorflow_nn_pipeline +from arelight.labels.formatter import ExperimentRuSentiFramesLabelsFormatter from examples.args import const @@ -47,7 +47,6 @@ def test_demo_rus_nn(self): model_name=ModelNames.PCNN, synonyms_filepath=join(TestDemo.ORIGIN_DATA_DIR, "synonyms.txt"), model_load_dir=const.NEURAL_NETWORKS_TARGET_DIR, - embedding_filepath=const.EMBEDDING_FILEPATH, model_input_type=ModelInputType.SingleInstance, frames_collection=frames_collection) diff --git a/update_arekit.sh b/update_arekit.sh index 0740b83..8c2788d 100755 --- a/update_arekit.sh +++ b/update_arekit.sh @@ -1,2 +1,2 @@ -pip uninstall arekit -pip install git+https://github.com/nicolay-r/AREkit@0.22.0-rc +pip3 uninstall arekit +pip3 install git+https://github.com/nicolay-r/AREkit@master --no-deps