Large and temporary API refactoring: switch to AREkit version 0.22.1.
nicolay-r committed Aug 2, 2022
1 parent 74d424b commit 37c0b49
Showing 45 changed files with 188 additions and 785 deletions.
6 changes: 2 additions & 4 deletions README.md
@@ -1,4 +1,4 @@
# ARElight 0.22.0
# ARElight 0.22.1

### :point_right: [DEMO](#docker-verion-quick) :point_left:

@@ -18,7 +18,7 @@ we adopt [DeepPavlov](https://github.com/deepmipt/DeepPavlov) (BertOntoNotes mo

# Dependencies

* arekit == 0.22.0
* arekit == 0.22.1
* gensim == 3.2.0
* deeppavlov == 0.11.0
* rusenttokenize
@@ -151,8 +151,6 @@ python3.6 serialize_texts_nn.py --from-files data/texts-inosmi-rus/e1.txt \

# Other Examples

* Serialize RuSentRel collection for BERT [[code]](examples/serialize_rusentrel_for_bert.py)
* Serialize RuSentRel collection for Neural Networks [[code]](examples/serialize_rusentrel_for_nn.py)
* Finetune BERT on samples [[code]](examples/train_bert.py)
* Finetune Neural Networks on RuSentRel [[code]](examples/train_nn_on_rusentrel.py)

15 changes: 5 additions & 10 deletions arelight/demo/infer_bert_rus.py
@@ -1,18 +1,13 @@
from arekit.common.experiment.annot.algo.pair_based import PairBasedAnnotationAlgorithm
from arekit.common.experiment.annot.default import DefaultAnnotator
from arekit.common.experiment.data_type import DataType
from arekit.common.experiment.name_provider import ExperimentNameProvider
from arekit.common.folding.nofold import NoFolding
from arekit.common.labels.base import NoLabel
from arekit.common.labels.provider.constant import ConstantLabelProvider
from arekit.common.opinions.annot.algo.pair_based import PairBasedOpinionAnnotationAlgorithm
from arekit.common.opinions.annot.base import BaseOpinionAnnotator
from arekit.common.pipeline.base import BasePipeline
from arekit.contrib.bert.samplers.types import SampleFormattersService
from arekit.contrib.experiment_rusentrel.entities.factory import create_entity_formatter
from arekit.contrib.experiment_rusentrel.entities.types import EntityFormatterTypes
from arekit.contrib.experiment_rusentrel.labels.scalers.three import ThreeLabelScaler
from arekit.contrib.experiment_rusentrel.labels.types import ExperimentPositiveLabel, ExperimentNegativeLabel
from arekit.contrib.networks.core.predict.tsv_writer import TsvPredictWriter
from arekit.processing.lemmatization.mystem import MystemWrapper
from arekit.contrib.utils.processing.lemmatization.mystem import MystemWrapper

from arelight.demo.utils import read_synonyms_collection
from arelight.pipelines.backend_brat_json import BratBackendContentsPipelineItem
@@ -48,8 +43,8 @@ def demo_infer_texts_bert_pipeline(texts_count,
name_provider=ExperimentNameProvider(name="example-bert", suffix="infer"),
text_b_type=text_b_type,
output_dir=output_dir,
opin_annot=DefaultAnnotator(
PairBasedAnnotationAlgorithm(
opin_annot=BaseOpinionAnnotator(
PairBasedOpinionAnnotationAlgorithm(
dist_in_terms_bound=None,
label_provider=ConstantLabelProvider(label_instance=NoLabel()))),
data_folding=NoFolding(doc_ids_to_fold=list(range(texts_count)),
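
The hunk above replaces the removed `DefaultAnnotator`/`PairBasedAnnotationAlgorithm` pair with the 0.22.1 opinion-annotation classes. A minimal sketch of the new wiring, assuming only the constructor arguments visible in this diff:

```python
# Sketch of the 0.22.1-style annotator construction used above; argument
# values other than those shown in the diff are assumptions.
from arekit.common.labels.base import NoLabel
from arekit.common.labels.provider.constant import ConstantLabelProvider
from arekit.common.opinions.annot.algo.pair_based import PairBasedOpinionAnnotationAlgorithm
from arekit.common.opinions.annot.base import BaseOpinionAnnotator

# Annotate every entity pair with a neutral (NoLabel) relation and no
# distance bound between terms -- exactly as passed to opin_annot above.
annotator = BaseOpinionAnnotator(
    PairBasedOpinionAnnotationAlgorithm(
        dist_in_terms_bound=None,
        label_provider=ConstantLabelProvider(label_instance=NoLabel())))
```
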
16 changes: 5 additions & 11 deletions arelight/demo/infer_nn_rus.py
@@ -1,20 +1,16 @@
from arekit.common.experiment.annot.algo.pair_based import PairBasedAnnotationAlgorithm
from arekit.common.experiment.annot.default import DefaultAnnotator
from arekit.common.experiment.data_type import DataType
from arekit.common.experiment.name_provider import ExperimentNameProvider
from arekit.common.folding.nofold import NoFolding
from arekit.common.labels.base import NoLabel
from arekit.common.labels.provider.constant import ConstantLabelProvider
from arekit.common.opinions.annot.algo.pair_based import PairBasedOpinionAnnotationAlgorithm
from arekit.common.opinions.annot.base import BaseOpinionAnnotator
from arekit.common.pipeline.base import BasePipeline
from arekit.contrib.experiment_rusentrel.entities.factory import create_entity_formatter
from arekit.contrib.experiment_rusentrel.entities.types import EntityFormattersService
from arekit.contrib.experiment_rusentrel.labels.scalers.three import ThreeLabelScaler
from arekit.contrib.experiment_rusentrel.labels.types import ExperimentPositiveLabel, ExperimentNegativeLabel
from arekit.contrib.networks.core.callback.stat import TrainingStatProviderCallback
from arekit.contrib.networks.core.callback.train_limiter import TrainingLimiterCallback
from arekit.contrib.networks.core.predict.tsv_writer import TsvPredictWriter
from arekit.contrib.networks.enum_name_types import ModelNames
from arekit.processing.lemmatization.mystem import MystemWrapper
from arekit.contrib.utils.processing.lemmatization.mystem import MystemWrapper

from arelight.demo.utils import read_synonyms_collection
from arelight.network.nn.common import create_network_model_io, create_bags_collection_type, create_full_model_name
@@ -29,7 +25,6 @@ def demo_infer_texts_tensorflow_nn_pipeline(texts_count,
frames_collection,
output_dir,
synonyms_filepath,
embedding_filepath,
embedding_matrix_filepath=None,
vocab_filepath=None,
bags_per_minibatch=2,
@@ -56,14 +51,13 @@ def demo_infer_texts_tensorflow_nn_pipeline(texts_count,
frames_collection=frames_collection,
synonyms=read_synonyms_collection(synonyms_filepath=synonyms_filepath, stemmer=stemmer),
terms_per_context=terms_per_context,
embedding_path=embedding_filepath,
entities_parser=BertOntonotesNERPipelineItem(
lambda s_obj: s_obj.ObjectType in ["ORG", "PERSON", "LOC", "GPE"]),
entity_fmt=create_entity_formatter(entity_fmt_type),
stemmer=stemmer,
name_provider=exp_name_provider,
opin_annot=DefaultAnnotator(
PairBasedAnnotationAlgorithm(
opin_annot=BaseOpinionAnnotator(
PairBasedOpinionAnnotationAlgorithm(
dist_in_terms_bound=None,
label_provider=ConstantLabelProvider(label_instance=NoLabel()))),
output_dir=output_dir,
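
Both demo pipelines above also pick up the relocated lemmatizer module. A small sketch of the import-path migration, assuming nothing beyond the paths shown in the diff:

```python
# Lemmatizer import after the switch to AREkit 0.22.1 (old location kept
# as a comment for reference):
# from arekit.processing.lemmatization.mystem import MystemWrapper               # 0.22.0
from arekit.contrib.utils.processing.lemmatization.mystem import MystemWrapper   # 0.22.1
```
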
4 changes: 2 additions & 2 deletions arelight/demo/utils.py
@@ -1,6 +1,6 @@
from arekit.common.text.stemmer import Stemmer
from arekit.contrib.experiment_rusentrel.synonyms.collection import StemmerBasedSynonymCollection
from arekit.contrib.source.rusentrel.utils import iter_synonym_groups
from arekit.contrib.source.synonyms.utils import iter_synonym_groups
from arekit.contrib.utils.synonyms.stemmer_based import StemmerBasedSynonymCollection


def iter_groups(filepath):
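
The synonym utilities used by the demo helpers move as well. A sketch of the updated imports, limited to the paths visible in this hunk:

```python
# Synonym-collection helpers after the 0.22.1 relocation (previous
# locations commented for reference):
# from arekit.contrib.experiment_rusentrel.synonyms.collection import StemmerBasedSynonymCollection
# from arekit.contrib.source.rusentrel.utils import iter_synonym_groups
from arekit.contrib.source.synonyms.utils import iter_synonym_groups
from arekit.contrib.utils.synonyms.stemmer_based import StemmerBasedSynonymCollection
```
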
6 changes: 0 additions & 6 deletions arelight/exp/doc_ops.py
@@ -1,14 +1,8 @@
from arekit.common.experiment.api.enums import BaseDocumentTag
from arekit.common.experiment.api.ops_doc import DocumentOperations


class CustomDocOperations(DocumentOperations):

def iter_tagget_doc_ids(self, tag):
assert(isinstance(tag, BaseDocumentTag))
assert(tag == BaseDocumentTag.Annotate)
return self.__doc_ids

def __init__(self, exp_ctx, text_parser):
super(CustomDocOperations, self).__init__(exp_ctx, text_parser)
self.__docs = None
16 changes: 1 addition & 15 deletions arelight/exp/exp_io.py
@@ -1,23 +1,9 @@
import os
from arekit.contrib.experiment_rusentrel.model_io.tf_networks import RuSentRelExperimentNetworkIOUtils


class InferIOUtils(RuSentRelExperimentNetworkIOUtils):
class InferIOUtils(object):

def __init__(self, output_dir, exp_ctx):
assert(isinstance(output_dir, str))
super(InferIOUtils, self).__init__(exp_ctx=exp_ctx)
self.__output_dir = output_dir

def __create_annot_input_target(self, doc_id, data_type):
filename = "annot_input_d{doc_id}_{data_type}.txt".format(doc_id=doc_id, data_type=data_type.name)
return os.path.join(self._get_target_dir(), filename)

def _get_experiment_sources_dir(self):
return self.__output_dir

def create_opinion_collection_target(self, doc_id, data_type, check_existance=False):
return self.__create_annot_input_target(doc_id=doc_id, data_type=data_type)

def create_result_opinion_collection_target(self, doc_id, data_type, epoch_index):
return self.__create_annot_input_target(doc_id=doc_id, data_type=data_type)
24 changes: 1 addition & 23 deletions arelight/exp/opin_ops.py
@@ -1,8 +1,7 @@
from arekit.common.experiment.api.ops_opin import OpinionOperations
from arekit.common.opinions.collection import OpinionCollection


class CustomOpinionOperations(OpinionOperations):
class CustomOpinionOperations(object):

def __init__(self, labels_formatter, exp_io, synonyms, neutral_labels_fmt):
super(CustomOpinionOperations, self).__init__()
@@ -11,27 +10,6 @@ def __init__(self, labels_formatter, exp_io, synonyms, neutral_labels_fmt):
self.__synonyms = synonyms
self.__neutral_labels_fmt = neutral_labels_fmt

@property
def LabelsFormatter(self):
return self.__labels_formatter

def iter_opinions_for_extraction(self, doc_id, data_type):
# Reading automatically annotated collection of neutral opinions.
# TODO. #250, #251 provide opinion annotation here for the particular document.
return self.__exp_io.read_opinion_collection(
target=self.__exp_io.create_result_opinion_collection_target(
doc_id=doc_id,
data_type=data_type,
epoch_index=0),
labels_formatter=self.__neutral_labels_fmt,
create_collection_func=self.create_opinion_collection)

def get_etalon_opinion_collection(self, doc_id):
return self.create_opinion_collection(None)

def get_result_opinion_collection(self, doc_id, data_type, epoch_index):
raise Exception("Not Supported")

def create_opinion_collection(self, opinions=None):
return OpinionCollection(opinions=[] if opinions is None else opinions,
synonyms=self.__synonyms,
21 changes: 21 additions & 0 deletions arelight/ner_obj_desc.py
@@ -0,0 +1,21 @@
class NerObjectDescriptor:

def __init__(self, pos, length, obj_type):
self.__pos = pos
self.__len = length
self.__obj_type = obj_type

@property
def Position(self):
return self.__pos

@property
def Length(self):
return self.__len

@property
def ObjectType(self):
return self.__obj_type

def get_range(self):
return self.__pos, self.__pos + self.__len
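
The new `NerObjectDescriptor` is a plain value object for NER spans. A hypothetical usage sketch (the values below are illustrative, not taken from this commit):

```python
from arelight.ner_obj_desc import NerObjectDescriptor

# An entity starting at term position 3 and spanning 2 terms; "PERSON" is one
# of the object types filtered elsewhere in this commit.
desc = NerObjectDescriptor(pos=3, length=2, obj_type="PERSON")

start, end = desc.get_range()                       # (3, 5)
print(desc.Position, desc.Length, desc.ObjectType)  # 3 2 PERSON
```
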
17 changes: 2 additions & 15 deletions arelight/network/bert/ctx.py
@@ -1,25 +1,12 @@
from arekit.common.entities.str_fmt import StringEntitiesFormatter
from arekit.common.experiment.api.ctx_serialization import ExperimentSerializationContext


class BertSerializationContext(ExperimentSerializationContext):

def __init__(self, label_scaler, terms_per_context, str_entity_formatter,
annotator, name_provider, data_folding):
assert(isinstance(str_entity_formatter, StringEntitiesFormatter))
def __init__(self, label_scaler, terms_per_context, name_provider):
assert(isinstance(terms_per_context, int))

super(BertSerializationContext, self).__init__(annot=annotator,
name_provider=name_provider,
label_scaler=label_scaler,
data_folding=data_folding)

super(BertSerializationContext, self).__init__(name_provider=name_provider, label_scaler=label_scaler)
self.__terms_per_context = terms_per_context
self.__str_entity_formatter = str_entity_formatter

@property
def StringEntityFormatter(self):
return self.__str_entity_formatter

@property
def TermsPerContext(self):
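
After this change `BertSerializationContext` only needs a label scaler, a context size, and a name provider. A sketch of constructing it under 0.22.1; the no-argument `ThreeLabelScaler()` call and the `terms_per_context` value are assumptions, while the `ExperimentNameProvider` usage mirrors the demo diff above:

```python
from arekit.common.experiment.name_provider import ExperimentNameProvider
from arekit.contrib.experiment_rusentrel.labels.scalers.three import ThreeLabelScaler

from arelight.network.bert.ctx import BertSerializationContext

ctx = BertSerializationContext(
    label_scaler=ThreeLabelScaler(),   # assumed no-argument construction
    terms_per_context=50,              # illustrative value, not from this commit
    name_provider=ExperimentNameProvider(name="example-bert", suffix="infer"))
```
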
34 changes: 7 additions & 27 deletions arelight/network/nn/ctx.py
@@ -1,26 +1,22 @@
from arekit.common.entities.str_fmt import StringEntitiesFormatter
from arekit.contrib.experiment_rusentrel.connotations.provider import RuSentiFramesConnotationProvider
from arekit.contrib.experiment_rusentrel.labels.scalers.three import ThreeLabelScaler
from arekit.contrib.networks.core.input.ctx_serialization import NetworkSerializationContext
from arekit.contrib.networks.embeddings.base import Embedding
from arekit.contrib.networks.embedding import Embedding
from arekit.contrib.source.rusentiframes.collection import RuSentiFramesCollection
from arekit.processing.pos.base import POSTagger
from arekit.contrib.utils.connotations.rusentiframes_sentiment import RuSentiFramesConnotationProvider
from arekit.contrib.utils.processing.pos.base import POSTagger


class NetworkSerializationContext(NetworkSerializationContext):
class CustomNeuralNetworkSerializationContext(NetworkSerializationContext):

def __init__(self, labels_scaler, pos_tagger, embedding,
terms_per_context, str_entity_formatter, annotator,
frames_collection, frame_variant_collection, name_provider, data_folding):
def __init__(self, labels_scaler, pos_tagger, embedding, terms_per_context, str_entity_formatter,
frames_collection, frame_variant_collection, name_provider):
assert(isinstance(embedding, Embedding))
assert(isinstance(pos_tagger, POSTagger))
assert(isinstance(frames_collection, RuSentiFramesCollection))
assert(isinstance(str_entity_formatter, StringEntitiesFormatter))
assert(isinstance(terms_per_context, int))

super(NetworkSerializationContext, self).__init__(
labels_scaler=labels_scaler, annot=annotator,
name_provider=name_provider, data_folding=data_folding)
super(NetworkSerializationContext, self).__init__(label_scaler=labels_scaler, name_provider=name_provider)

self.__pos_tagger = pos_tagger
self.__terms_per_context = terms_per_context
@@ -35,22 +31,6 @@ def __init__(self, labels_scaler, pos_tagger, embedding,
def PosTagger(self):
return self.__pos_tagger

@property
def StringEntityFormatter(self):
return self.__str_entity_formatter

@property
def StringEntityEmbeddingFormatter(self):
return self.__str_entity_formatter

@property
def FrameVariantCollection(self):
return self.__frame_variant_collection

@property
def WordEmbedding(self):
return self.__word_embedding

@property
def TermsPerContext(self):
return self.__terms_per_context
69 changes: 0 additions & 69 deletions arelight/network/nn/embedding.py

This file was deleted.

4 changes: 2 additions & 2 deletions arelight/pipelines/inference_nn.py
@@ -18,7 +18,7 @@
from arekit.contrib.networks.core.predict.base_writer import BasePredictWriter
from arekit.contrib.networks.factory import create_network_and_network_config_funcs
from arekit.contrib.networks.shapes import NetworkInputShapes
from arekit.processing.languages.ru.pos_service import PartOfSpeechTypesService
from arekit.contrib.utils.processing.languages.ru.pos_service import PartOfSpeechTypesService

from arelight.exp.exp_io import InferIOUtils

@@ -76,7 +76,7 @@ def apply_core(self, input_data, pipeline_ctx):
# Update for further pipeline items.
pipeline_ctx.update("predict_fp", tgt)

# Fetch other required in furter information from input_data.
# Fetch other required in further information from input_data.
samples_filepath = input_data.create_samples_writer_target(self.__data_type)
embedding = input_data.load_embedding()
vocab = input_data.load_vocab()