From 3d15ee0d27332307292a2017d2b27a195bbc4ac6 Mon Sep 17 00:00:00 2001 From: Nicolay Rusnachenko Date: Tue, 31 Dec 2024 13:14:59 +0000 Subject: [PATCH] #152 related refactoring / update --- .../pipelines/items/inference_bert_opennre.py | 7 ++- arelight/third_party/legacy/__init__.py | 0 arelight/third_party/{ => legacy}/torch.py | 2 + .../test_opennre_infer_batching.py} | 39 ++++++-------- test/test_opennre_infer.py | 54 +++++++++++++++++++ 5 files changed, 77 insertions(+), 25 deletions(-) create mode 100644 arelight/third_party/legacy/__init__.py rename arelight/third_party/{ => legacy}/torch.py (97%) rename test/{test_opennre_bert_infer.py => legacy/test_opennre_infer_batching.py} (66%) create mode 100644 test/test_opennre_infer.py diff --git a/arelight/pipelines/items/inference_bert_opennre.py b/arelight/pipelines/items/inference_bert_opennre.py index 8feab11..abbe4ff 100644 --- a/arelight/pipelines/items/inference_bert_opennre.py +++ b/arelight/pipelines/items/inference_bert_opennre.py @@ -10,7 +10,7 @@ from opennre.encoder import BERTEntityEncoder, BERTEncoder from opennre.model import SoftmaxNN -from arelight.third_party.torch import sentence_re_loader +from arelight.third_party.legacy.torch import sentence_re_loader from arelight.utils import get_default_download_dir, download logger = logging.getLogger(__name__) @@ -120,8 +120,13 @@ def init_bert_model(pretrain_path, labels_scaler, ckpt_path, device_type, predef model.load_state_dict(torch.load(ckpt_path, map_location=torch.device(device_type))['state_dict']) return model + # TODO. This is legacy (see #173). + # TODO. This is legacy (see #173). + # TODO. This is legacy (see #173). @staticmethod def iter_results(parallel_model, eval_loader, data_ids): + """ NOTE. This method bypasses the predefined inference @ model.infer(). + """ # It is important we should open database. 
with eval_loader.dataset: diff --git a/arelight/third_party/legacy/__init__.py b/arelight/third_party/legacy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/arelight/third_party/torch.py b/arelight/third_party/legacy/torch.py similarity index 97% rename from arelight/third_party/torch.py rename to arelight/third_party/legacy/torch.py index f0b03a3..47f83b2 100644 --- a/arelight/third_party/torch.py +++ b/arelight/third_party/legacy/torch.py @@ -6,6 +6,7 @@ from arelight.third_party.sqlite3 import SQLite3Service +# TODO. This component is legacy (#173) class SQLiteSentenceREDataset(data.Dataset): """ Sentence-level relation extraction dataset This is a original OpenNRE implementation, adapted for SQLite. @@ -85,6 +86,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.sqlite_service.disconnect() +# TODO. This component is legacy (#173) def sentence_re_loader(path, table_name, rel2id, tokenizer, batch_size, shuffle, task_kwargs, num_workers, collate_fn=SQLiteSentenceREDataset.collate_fn, **kwargs): dataset = SQLiteSentenceREDataset(path=path, table_name=table_name, rel2id=rel2id, diff --git a/test/test_opennre_bert_infer.py b/test/legacy/test_opennre_infer_batching.py similarity index 66% rename from test/test_opennre_bert_infer.py rename to test/legacy/test_opennre_infer_batching.py index 57785cf..f548a14 100644 --- a/test/test_opennre_bert_infer.py +++ b/test/legacy/test_opennre_infer_batching.py @@ -1,49 +1,41 @@ import logging - -import utils -import torch +import sys import unittest +from os.path import join, dirname, realpath -from os.path import join - -from opennre.encoder import BERTEncoder -from opennre.model import SoftmaxNN +import torch from arelight.pipelines.items.inference_bert_opennre import BertOpenNREInferencePipelineItem from arelight.predict.provider import BasePredictProvider from arelight.predict.writer_csv import TsvPredictWriter from arelight.run.utils import OPENNRE_CHECKPOINTS -from arelight.third_party.torch 
import sentence_re_loader +from arelight.third_party.legacy.torch import sentence_re_loader from arelight.utils import get_default_download_dir -logger = logging.getLogger(__name__) +sys.path.append("..") + +current_dir = dirname(realpath(__file__)) +TEST_DATA_DIR = join(current_dir, "..", "data") +TEST_OUT_DIR = join("..", "_out") class TestLoadModel(unittest.TestCase): CKPT = "ra4-rsr1_DeepPavlov-rubert-base-cased_cls.pth.tar" - def test_launch_model(self): - pretrain_path, ckpt_path, label_scaler = BertOpenNREInferencePipelineItem.try_download_predefined_checkpoint( - checkpoint=self.CKPT, predefined=OPENNRE_CHECKPOINTS, dir_to_download=utils.TEST_OUT_DIR, logger=logger) - model = BERTEncoder(pretrain_path=pretrain_path, mask_entity=True, max_length=512) - rel2id = BertOpenNREInferencePipelineItem.scaler_to_rel2id(label_scaler) - model = SoftmaxNN(model, len(rel2id), rel2id) - model.load_state_dict(torch.load(ckpt_path, map_location=torch.device('cpu'))["state_dict"]) - def test_infer(self): self.infer_bert(pretrain_path=None, ckpt_path=self.CKPT, labels_scaler=None, predefined=OPENNRE_CHECKPOINTS, - output_file_gzip=join(utils.TEST_OUT_DIR, "opennre-data-test.tsv.gz"), + output_file_gzip=join(TEST_OUT_DIR, "opennre-data-test.tsv.gz"), logger=logging.getLogger(__name__)) @staticmethod def infer_bert(pretrain_path, labels_scaler, output_file_gzip, predefined, logger, ckpt_path=None, pooler='cls', batch_size=6, max_length=128, mask_entity=True): - test_data_file = join(utils.TEST_DATA_DIR, "opennre-data-test-predict.sqlite") + test_data_file = join(TEST_DATA_DIR, "opennre-data-test-predict.sqlite") model = BertOpenNREInferencePipelineItem.init_bert_model( pretrain_path=pretrain_path, labels_scaler=labels_scaler, ckpt_path=ckpt_path, @@ -56,17 +48,16 @@ def infer_bert(pretrain_path, labels_scaler, output_file_gzip, predefined, logge tokenizer=model.sentence_encoder.tokenize, batch_size=batch_size, task_kwargs={ - "no_label": "0", - "default_id_column": "id", - 
"index_columns": ["s_ind", "t_ind"], - "text_columns": ["text_a", "text_b"] + "no_label": "0", + "default_id_column": "id", + "index_columns": ["s_ind", "t_ind"], + "text_columns": ["text_a", "text_b"] }, shuffle=False, num_workers=0) # Open database. with eval_loader.dataset as dataset: - it_results = BertOpenNREInferencePipelineItem.iter_results( parallel_model=torch.nn.DataParallel(model), data_ids=list(dataset.iter_ids()), diff --git a/test/test_opennre_infer.py b/test/test_opennre_infer.py new file mode 100644 index 0000000..8b5833f --- /dev/null +++ b/test/test_opennre_infer.py @@ -0,0 +1,54 @@ +import logging +import utils +import torch +import unittest + +from tqdm import tqdm + +from opennre.encoder import BERTEncoder +from opennre.model import SoftmaxNN + +from arelight.pipelines.items.inference_bert_opennre import BertOpenNREInferencePipelineItem +from arelight.run.utils import OPENNRE_CHECKPOINTS +from arelight.utils import get_default_download_dir + +logger = logging.getLogger(__name__) + + +class TestLoadModel(unittest.TestCase): + + CKPT = "ra4-rsr1_DeepPavlov-rubert-base-cased_cls.pth.tar" + + def test_launch_model(self): + pretrain_path, ckpt_path, label_scaler = BertOpenNREInferencePipelineItem.try_download_predefined_checkpoint( + checkpoint=self.CKPT, predefined=OPENNRE_CHECKPOINTS, dir_to_download=utils.TEST_OUT_DIR, logger=logger) + model = BERTEncoder(pretrain_path=pretrain_path, mask_entity=True, max_length=512) + rel2id = BertOpenNREInferencePipelineItem.scaler_to_rel2id(label_scaler) + model = SoftmaxNN(model, len(rel2id), rel2id) + model.load_state_dict(torch.load(ckpt_path, map_location=torch.device('cpu'))["state_dict"]) + + def test_infer(self): + self.infer_no_batch(pretrain_path=None, + ckpt_path=self.CKPT, + labels_scaler=None, + predefined=OPENNRE_CHECKPOINTS, + logger=logging.getLogger(__name__)) + + @staticmethod + def infer_no_batch(pretrain_path, labels_scaler, predefined, logger, ckpt_path=None, + pooler='cls', max_length=128, 
mask_entity=True): + model = BertOpenNREInferencePipelineItem.init_bert_model( + pretrain_path=pretrain_path, labels_scaler=labels_scaler, ckpt_path=ckpt_path, + device_type="cpu", max_length=max_length, mask_entity=mask_entity, logger=logger, + dir_to_donwload=get_default_download_dir(), pooler=pooler, predefined=predefined) + + texts = [ + {"token": "сша ввела санкции против россии".split(), "h": {"pos": [0, 1]}, "t": {"pos": [3, 4]}}, + {"token": "сша поддержала россии".split(), "h": {"pos": [0, 1]}, "t": {"pos": [3, 4]}}, + {"token": "сша и россия".split(), "h": {"pos": [0, 1]}, "t": {"pos": [3, 4]}}, + ] + + texts = texts * 100 + + for input in tqdm(texts): + _ = model.infer(input)