Skip to content

Commit

Permalink
#152 related refactoring / update
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Dec 31, 2024
1 parent 1ba2f6a commit 3d15ee0
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 25 deletions.
7 changes: 6 additions & 1 deletion arelight/pipelines/items/inference_bert_opennre.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from opennre.encoder import BERTEntityEncoder, BERTEncoder
from opennre.model import SoftmaxNN

from arelight.third_party.torch import sentence_re_loader
from arelight.third_party.legacy.torch import sentence_re_loader
from arelight.utils import get_default_download_dir, download

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -120,8 +120,13 @@ def init_bert_model(pretrain_path, labels_scaler, ckpt_path, device_type, predef
model.load_state_dict(torch.load(ckpt_path, map_location=torch.device(device_type))['state_dict'])
return model

# TODO. This is legacy (see #173).
# TODO. This is legacy (see #173).
# TODO. This is legacy (see #173).
@staticmethod
def iter_results(parallel_model, eval_loader, data_ids):
""" NOTE. This method by-passess the predefined inference @ model.infer().
"""

# It is important we should open database.
with eval_loader.dataset:
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from arelight.third_party.sqlite3 import SQLite3Service


# TODO. This component is legacy (#173)
class SQLiteSentenceREDataset(data.Dataset):
""" Sentence-level relation extraction dataset
This is a original OpenNRE implementation, adapted for SQLite.
Expand Down Expand Up @@ -85,6 +86,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.sqlite_service.disconnect()


# TODO. This component is legacy (#173)
def sentence_re_loader(path, table_name, rel2id, tokenizer, batch_size, shuffle,
task_kwargs, num_workers, collate_fn=SQLiteSentenceREDataset.collate_fn, **kwargs):
dataset = SQLiteSentenceREDataset(path=path, table_name=table_name, rel2id=rel2id,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,49 +1,41 @@
import logging

import utils
import torch
import sys
import unittest
from os.path import join, dirname, realpath

from os.path import join

from opennre.encoder import BERTEncoder
from opennre.model import SoftmaxNN
import torch

from arelight.pipelines.items.inference_bert_opennre import BertOpenNREInferencePipelineItem
from arelight.predict.provider import BasePredictProvider
from arelight.predict.writer_csv import TsvPredictWriter
from arelight.run.utils import OPENNRE_CHECKPOINTS
from arelight.third_party.torch import sentence_re_loader
from arelight.third_party.legacy.torch import sentence_re_loader
from arelight.utils import get_default_download_dir

logger = logging.getLogger(__name__)
sys.path.append("..")

current_dir = dirname(realpath(__file__))
TEST_DATA_DIR = join(current_dir, "..", "data")
TEST_OUT_DIR = join("..", "_out")


class TestLoadModel(unittest.TestCase):

CKPT = "ra4-rsr1_DeepPavlov-rubert-base-cased_cls.pth.tar"

def test_launch_model(self):
pretrain_path, ckpt_path, label_scaler = BertOpenNREInferencePipelineItem.try_download_predefined_checkpoint(
checkpoint=self.CKPT, predefined=OPENNRE_CHECKPOINTS, dir_to_download=utils.TEST_OUT_DIR, logger=logger)
model = BERTEncoder(pretrain_path=pretrain_path, mask_entity=True, max_length=512)
rel2id = BertOpenNREInferencePipelineItem.scaler_to_rel2id(label_scaler)
model = SoftmaxNN(model, len(rel2id), rel2id)
model.load_state_dict(torch.load(ckpt_path, map_location=torch.device('cpu'))["state_dict"])

def test_infer(self):
self.infer_bert(pretrain_path=None,
ckpt_path=self.CKPT,
labels_scaler=None,
predefined=OPENNRE_CHECKPOINTS,
output_file_gzip=join(utils.TEST_OUT_DIR, "opennre-data-test.tsv.gz"),
output_file_gzip=join(TEST_OUT_DIR, "opennre-data-test.tsv.gz"),
logger=logging.getLogger(__name__))

@staticmethod
def infer_bert(pretrain_path, labels_scaler, output_file_gzip, predefined, logger, ckpt_path=None,
pooler='cls', batch_size=6, max_length=128, mask_entity=True):

test_data_file = join(utils.TEST_DATA_DIR, "opennre-data-test-predict.sqlite")
test_data_file = join(TEST_DATA_DIR, "opennre-data-test-predict.sqlite")

model = BertOpenNREInferencePipelineItem.init_bert_model(
pretrain_path=pretrain_path, labels_scaler=labels_scaler, ckpt_path=ckpt_path,
Expand All @@ -56,17 +48,16 @@ def infer_bert(pretrain_path, labels_scaler, output_file_gzip, predefined, logge
tokenizer=model.sentence_encoder.tokenize,
batch_size=batch_size,
task_kwargs={
"no_label": "0",
"default_id_column": "id",
"index_columns": ["s_ind", "t_ind"],
"text_columns": ["text_a", "text_b"]
"no_label": "0",
"default_id_column": "id",
"index_columns": ["s_ind", "t_ind"],
"text_columns": ["text_a", "text_b"]
},
shuffle=False,
num_workers=0)

# Open database.
with eval_loader.dataset as dataset:

it_results = BertOpenNREInferencePipelineItem.iter_results(
parallel_model=torch.nn.DataParallel(model),
data_ids=list(dataset.iter_ids()),
Expand Down
54 changes: 54 additions & 0 deletions test/test_opennre_infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import logging
import utils
import torch
import unittest

from tqdm import tqdm

from opennre.encoder import BERTEncoder
from opennre.model import SoftmaxNN

from arelight.pipelines.items.inference_bert_opennre import BertOpenNREInferencePipelineItem
from arelight.run.utils import OPENNRE_CHECKPOINTS
from arelight.utils import get_default_download_dir

logger = logging.getLogger(__name__)


class TestLoadModel(unittest.TestCase):
    """ Tests for downloading a predefined OpenNRE BERT checkpoint and
        running per-sample (non-batched) relation inference with it.
    """

    # Predefined checkpoint name, resolved through the OPENNRE_CHECKPOINTS registry.
    CKPT = "ra4-rsr1_DeepPavlov-rubert-base-cased_cls.pth.tar"

    def test_launch_model(self):
        """ Download the predefined checkpoint and restore model weights from it.
        """
        pretrain_path, ckpt_path, label_scaler = BertOpenNREInferencePipelineItem.try_download_predefined_checkpoint(
            checkpoint=self.CKPT, predefined=OPENNRE_CHECKPOINTS, dir_to_download=utils.TEST_OUT_DIR, logger=logger)
        model = BERTEncoder(pretrain_path=pretrain_path, mask_entity=True, max_length=512)
        rel2id = BertOpenNREInferencePipelineItem.scaler_to_rel2id(label_scaler)
        # Wrap the encoder with a softmax classification head and load the trained weights.
        model = SoftmaxNN(model, len(rel2id), rel2id)
        model.load_state_dict(torch.load(ckpt_path, map_location=torch.device('cpu'))["state_dict"])

    def test_infer(self):
        """ Run sample-by-sample inference over a small repeated set of texts.
        """
        self.infer_no_batch(pretrain_path=None,
                            ckpt_path=self.CKPT,
                            labels_scaler=None,
                            predefined=OPENNRE_CHECKPOINTS,
                            logger=logging.getLogger(__name__))

    @staticmethod
    def infer_no_batch(pretrain_path, labels_scaler, predefined, logger, ckpt_path=None,
                       pooler='cls', max_length=128, mask_entity=True):
        """ Initialize the BERT-based OpenNRE model and call `model.infer`
            once per sample (no batching), timing the loop via tqdm.

            pretrain_path: path to the pretrained encoder (None => taken from `predefined`).
            labels_scaler: label scaler (None => taken from `predefined`).
            predefined: registry of predefined checkpoints.
            logger: logger instance passed through to model initialization.
            ckpt_path: checkpoint name/path to load weights from.
            pooler: encoder pooling mode ('cls').
            max_length: maximum tokenized sequence length.
            mask_entity: whether entities are masked in the encoder input.
        """
        # NOTE(review): "dir_to_donwload" mirrors the (misspelled) keyword of
        # init_bert_model; fixing it here alone would break the call — confirm
        # against the callee before renaming.
        model = BertOpenNREInferencePipelineItem.init_bert_model(
            pretrain_path=pretrain_path, labels_scaler=labels_scaler, ckpt_path=ckpt_path,
            device_type="cpu", max_length=max_length, mask_entity=mask_entity, logger=logger,
            dir_to_donwload=get_default_download_dir(), pooler=pooler, predefined=predefined)

        # OpenNRE sample format: tokenized text plus head ("h") / tail ("t") token spans.
        # NOTE(review): for the two 3-token sentences the tail span [3, 4] lies past
        # the last token — presumably model.infer tolerates an empty span; confirm.
        texts = [
            {"token": "сша ввела санкции против россии".split(), "h": {"pos": [0, 1]}, "t": {"pos": [3, 4]}},
            {"token": "сша поддержала россии".split(), "h": {"pos": [0, 1]}, "t": {"pos": [3, 4]}},
            {"token": "сша и россия".split(), "h": {"pos": [0, 1]}, "t": {"pos": [3, 4]}},
        ]

        # Repeat the samples to make the per-sample timing reported by tqdm meaningful.
        texts = texts * 100

        # `sample` instead of `input` so the builtin is not shadowed.
        for sample in tqdm(texts):
            _ = model.infer(sample)

0 comments on commit 3d15ee0

Please sign in to comment.