Skip to content

Commit

Permalink
#170 refactoring.
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Dec 29, 2024
1 parent 235ca62 commit 6d43710
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 64 deletions.
14 changes: 14 additions & 0 deletions arelight/arekit/indexed_entity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from arekit.common.entities.base import Entity


class IndexedEntity(Entity):
""" Same as the base Entity but supports indexing.
"""

def __init__(self, value, e_type, entity_id):
super(IndexedEntity, self).__init__(value=value, e_type=e_type)
self.__id = entity_id

@property
def ID(self):
return self.__id
10 changes: 5 additions & 5 deletions arelight/run/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,14 @@
from arelight.predict.writer_sqlite3 import SQLite3PredictWriter
from arelight.readers.csv_pd import PandasCsvReader
from arelight.readers.sqlite import SQliteReader
from arelight.run.utils import merge_dictionaries, iter_group_values, create_sentence_parser, \
create_translate_model, iter_content, OPENNRE_CHECKPOINTS, NER_TYPES
from arelight.run.utils import merge_dictionaries, iter_group_values, create_sentence_parser,\
iter_content, OPENNRE_CHECKPOINTS, NER_TYPES
from arelight.run.utils_logger import setup_custom_logger, TqdmToLogger
from arelight.samplers.bert import create_bert_sample_provider
from arelight.samplers.types import SampleFormattersService
from arelight.stemmers.ru_mystem import MystemWrapper
from arelight.third_party.dp_130 import DeepPavlovNER
from arelight.third_party.gt_310a import GoogleTranslateModel
from arelight.utils import flatten

from bulk_translate.src.pipeline.translator import MLTextTranslatorPipelineItem
Expand Down Expand Up @@ -148,7 +149,7 @@ def setup_collection_name(value):

translate_model = {
None: lambda: None,
"googletrans": lambda: create_translate_model("googletrans")
"googletrans": lambda: GoogleTranslateModel()
}

translator = translate_model[args.translate_framework]()
Expand Down Expand Up @@ -244,8 +245,7 @@ def setup_collection_name(value):
None: lambda: None,
"ml-based": lambda: [
MLTextTranslatorPipelineItem(
batch_translate_model=lambda content: translator(
str_list=content,
batch_translate_model=translator.get_func(
src=args.translate_text.split(':')[0],
dest=args.translate_text.split(':')[1]),
do_translate_entity=False,
Expand Down
10 changes: 0 additions & 10 deletions arelight/run/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,6 @@ def iter_content(filepath, csv_column, csv_delimiter, open_func=None):
yield f.read().rstrip()


def create_translate_model(arg):

if arg == "googletrans":
# We do auto-import so we not depend on the actually installed library.
translate_value = auto_import("arelight.third_party.googletrans.translate_value")
# Translation of the list of data.
# Returns the list of strings.
return lambda str_list, src, dest: [translate_value(s, dest=dest, src=src) for s in str_list]


def iter_group_values(filepath):

if filepath is None:
Expand Down
36 changes: 0 additions & 36 deletions arelight/third_party/googletrans.py

This file was deleted.

38 changes: 38 additions & 0 deletions arelight/third_party/gt_310a.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# This implementation has been tested for
# googletrans==3.1.0a0


import time

from googletrans import Translator


class GoogleTranslateModel(object):

def __init__(self, **kwargs):
self._instance = Translator()

@staticmethod
def translate_value(translator, value, src, dest, sec_delay=1, attempts=10):

import logging
logger = logging.getLogger() # get the default logger
logger.setLevel(50)

for i in range(attempts):
try:
translated = translator.translate(value, dest=dest, src=src)
return translated.text
except:
logger.info("Unable to perform translation. Try {} out of {}.".format(i, attempts))
time.sleep(sec_delay)

raise Exception("Can't translate")

def get_func(self, src, dest, **kwargs):
# We do auto-import so we not depend on the actually installed library.
# Translation of the list of data.
# Returns the list of strings.
return lambda str_list: [
GoogleTranslateModel.translate_value(translator=self._instance, value=s, dest=dest, src=src)
for s in str_list]
6 changes: 3 additions & 3 deletions test/test_document_parsing_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from bulk_translate.src.pipeline.translator import MLTextTranslatorPipelineItem

from arelight.arekit.indexed_entity import IndexedEntity
from arelight.run.utils import create_translate_model
from arelight.third_party.dp_130 import DeepPavlovNER
from arelight.third_party.gt_310a import GoogleTranslateModel


class DocumentParsingBenchmark(unittest.TestCase):
Expand Down Expand Up @@ -62,14 +62,14 @@ def test_ner_deeppavlov(self):

def test_translator(self):

translator = create_translate_model("googletrans")
translator = GoogleTranslateModel()

# Declare text parser.
text_parser_pipeline = [
BasePipelineItem(src_func=lambda s: s.Text),
MLTextTranslatorPipelineItem(
src_func=lambda text: split_by_whitespaces(text),
batch_translate_model=lambda content: translator(str_list=content, src="ru", dest="en"),
batch_translate_model=translator.get_func(src="ru", dest="en"),
is_span_func=lambda term: isinstance(term, IndexedEntity),
do_translate_entity=False)
]
Expand Down
10 changes: 0 additions & 10 deletions test/test_translation.py

This file was deleted.

0 comments on commit 6d43710

Please sign in to comment.