Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 88 changed files with 4,887 additions and 1,416 deletions.
@@ -0,0 +1,31 @@

class ChunkIterator:

    def __init__(self, data_iter, batch_size, chunk_limit):
        assert(isinstance(batch_size, int) and batch_size > 0)
        self.__data_iter = data_iter
        self.__index = -1
        self.__batch_size = batch_size
        self.__chunk_limit = chunk_limit
        self.__buffer = []

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            if len(self.__buffer) > 0:
                break
            try:
                data = next(self.__data_iter)
                self.__index += 1
            except StopIteration:
                break
            for chunk_start in range(0, len(data), self.__chunk_limit):
                chunk = data[chunk_start:chunk_start + self.__chunk_limit]
                self.__buffer.append([self.__index, chunk])

        if len(self.__buffer) > 0:
            return self.__buffer.pop(0)

        raise StopIteration
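A minimal usage sketch of the iterator above (hypothetical data; each yielded pair carries the index of the source item followed by one chunk of at most `chunk_limit` elements):

    it = ChunkIterator(iter(["abcdef", "xy"]), batch_size=2, chunk_limit=4)
    print(list(it))  # [[0, 'abcd'], [0, 'ef'], [1, 'xy']]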
@@ -0,0 +1,14 @@

from arelight.readers.base import BaseReader


class CustomSQliteReader(BaseReader):

    def __init__(self, storage_type, **storage_kwargs):
        self._storage_kwargs = storage_kwargs
        self._storage_type = storage_type

    def extension(self):
        return ".sqlite"

    def read(self, target):
        return self._storage_type(path=target, **self._storage_kwargs)
@@ -0,0 +1,22 @@

import sqlite3

from arekit.common.data.const import ID
from arekit.common.data.storages.base import BaseRowsStorage


class JoinedSQliteBasedRowsStorage(BaseRowsStorage):

    def __init__(self, path, table_name_a, table_name_b, **kwargs):
        super(JoinedSQliteBasedRowsStorage, self).__init__(**kwargs)
        self.__path = path
        self.__table_name_a = table_name_a
        self.__table_name_b = table_name_b
        self.__conn = None

    def _iter_rows(self):
        with sqlite3.connect(self.__path) as conn:
            cursor = conn.execute(f"select * from {self.__table_name_a} inner join {self.__table_name_b}"
                                  f" on {self.__table_name_a}.{ID}={self.__table_name_b}.{ID}")
            for row_index, row in enumerate(cursor.fetchall()):
                row_dict = {cursor.description[i][0]: value for i, value in enumerate(row)}
                yield row_index, row_dict
This file was deleted.
@@ -0,0 +1,68 @@

from arekit.common.data.rows_fmt import create_base_column_fmt
from arekit.common.data.rows_parser import ParsedSampleRow
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.common.experiment.data_type import DataType
from arekit.common.labels.base import NoLabel

from arelight.arekit.custom_sqlite_reader import CustomSQliteReader
from arelight.arekit.joined_sqlite import JoinedSQliteBasedRowsStorage
from arelight.arekit.parsed_row_service import ParsedSampleRowExtraService
from arelight.arekit.samples_io import CustomSamplesIO
from arelight.pipelines.demo.labels.base import PositiveLabel, NegativeLabel
from arelight.pipelines.demo.labels.formatter import CustomLabelsFormatter


class AREkitSamplesService(object):

    @staticmethod
    def _extract_label_from_row(parsed_row):
        if parsed_row["col_0"] == 1:
            return NoLabel()
        elif parsed_row["col_1"] == 1:
            return PositiveLabel()
        elif parsed_row["col_2"] == 1:
            return NegativeLabel()

    @staticmethod
    def iter_samples_and_predict_sqlite3(sqlite_filepath, samples_table_name, predict_table_name,
                                         filter_record_func=None):
        assert(callable(filter_record_func) or filter_record_func is None)

        samples_io = CustomSamplesIO(
            create_target_func=lambda _: sqlite_filepath,
            reader=CustomSQliteReader(table_name_a=samples_table_name, table_name_b=predict_table_name,
                                      storage_type=JoinedSQliteBasedRowsStorage))

        samples_filepath = samples_io.create_target(data_type=DataType.Test)
        samples = samples_io.Reader.read(samples_filepath)
        assert(isinstance(samples, BaseRowsStorage))

        column_fmts = [
            # default parameters.
            create_base_column_fmt(fmt_type="parser"),
            # sentiment score.
            {"col_0": lambda value: int(value), "col_1": lambda value: int(value), "col_2": lambda value: int(value)}
        ]

        labels_fmt = CustomLabelsFormatter()

        for ind, sample_row in samples:

            parsed_row = ParsedSampleRow(sample_row, columns_fmts=column_fmts, no_value_func=lambda: None)

            # reading label.
            record = {
                "filename": sample_row["doc_id"].split(':')[0],
                "text": sample_row["text_a"].split(),
                "s_val": ParsedSampleRowExtraService.calc("SourceValue", parsed_row=parsed_row),
                "t_val": ParsedSampleRowExtraService.calc("TargetValue", parsed_row=parsed_row),
                "s_type": ParsedSampleRowExtraService.calc("SourceType", parsed_row=parsed_row),
                "t_type": ParsedSampleRowExtraService.calc("TargetType", parsed_row=parsed_row),
                "label": labels_fmt.label_to_str(AREkitSamplesService._extract_label_from_row(parsed_row))
            }

            if filter_record_func is None:
                yield record
            elif filter_record_func(record):
                yield record
@@ -0,0 +1,38 @@

from arekit.common.experiment.api.base_samples_io import BaseSamplesIO
from arekit.contrib.utils.data.writers.base import BaseWriter

from arelight.readers.base import BaseReader


class CustomSamplesIO(BaseSamplesIO):
    """ Default IO utils for samples.
        A sample is a text part which includes a pair of attitude participants.
        This class provides a saver and a loader for such entries, dubbed samples.
        Samples are required for machine learning training/inferring.
    """

    def __init__(self, create_target_func, writer=None, reader=None):
        assert(isinstance(writer, BaseWriter) or writer is None)
        assert(isinstance(reader, BaseReader) or reader is None)
        assert(callable(create_target_func))

        self.__writer = writer
        self.__reader = reader
        self.__create_target_func = create_target_func

        self.__target_extension = None
        if writer is not None:
            self.__target_extension = writer.extension()
        elif reader is not None:
            self.__target_extension = reader.extension()

    @property
    def Reader(self):
        return self.__reader

    @property
    def Writer(self):
        return self.__writer

    def create_target(self, data_type):
        return self.__create_target_func(data_type) + self.__target_extension
@@ -0,0 +1,9 @@

def string_terms_to_list(terms):
    r = []
    for t in terms:
        if isinstance(t, str):
            for i in t.split(' '):
                r.append(i)
        else:
            r.append(t)
    return r
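For illustration, a hypothetical call mixing plain strings (split into words) with non-string terms (e.g., entity objects), which are kept as-is:

    print(string_terms_to_list(["New York meets", 42]))
    # ['New', 'York', 'meets', 42]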
Large diffs are not rendered by default.
Empty file.
Empty file.
@@ -0,0 +1,68 @@

from arekit.common.data.input.providers.columns.base import BaseColumnsProvider
from arekit.common.data.input.providers.contents import ContentsProvider
from arekit.common.data.input.providers.rows.base import BaseRowProvider
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.contrib.utils.data.storages.row_cache import RowCacheStorage
from arekit.contrib.utils.data.writers.base import BaseWriter


class BaseInputRepository(object):

    def __init__(self, columns_provider, rows_provider, storage):
        assert(isinstance(columns_provider, BaseColumnsProvider))
        assert(isinstance(rows_provider, BaseRowProvider))
        assert(isinstance(storage, BaseRowsStorage))

        self._columns_provider = columns_provider
        self._rows_provider = rows_provider
        self._storage = storage

        # Do setup operations.
        self._setup_columns_provider()
        self._setup_rows_provider()

    # region protected methods

    def _setup_columns_provider(self):
        pass

    def _setup_rows_provider(self):
        pass

    # endregion

    def populate(self, contents_provider, doc_ids, desc="", writer=None, target=None):
        assert(isinstance(contents_provider, ContentsProvider))
        assert(isinstance(self._storage, BaseRowsStorage))
        assert(isinstance(doc_ids, list))
        assert(isinstance(writer, BaseWriter) or writer is None)
        assert(isinstance(target, str) or target is None)

        def iter_rows(idle_mode):
            return self._rows_provider.iter_by_rows(
                contents_provider=contents_provider,
                doc_ids_iter=doc_ids,
                idle_mode=idle_mode)

        self._storage.init_empty(columns_provider=self._columns_provider)

        is_async_write_mode_on = writer is not None and target is not None

        if is_async_write_mode_on:
            writer.open_target(target)

        self._storage.fill(lambda idle_mode: iter_rows(idle_mode),
                           columns_provider=self._columns_provider,
                           row_handler=lambda: writer.commit_line(self._storage) if is_async_write_mode_on else None,
                           desc=desc)

        if is_async_write_mode_on:
            writer.close_target()

    def push(self, writer, target, free_storage=True):
        if not isinstance(self._storage, RowCacheStorage):
            writer.write_all(self._storage, target)

        # After writing we free the contents of the storage.
        if free_storage:
            self._storage.free()
@@ -0,0 +1,23 @@

import logging

from arekit.common.data.input.providers.rows.samples import BaseSampleRowProvider

from arelight.data.repositories.base import BaseInputRepository

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class BaseInputSamplesRepository(BaseInputRepository):

    def _setup_rows_provider(self):
        """ Setup store labels.
        """
        assert(isinstance(self._rows_provider, BaseSampleRowProvider))
        self._rows_provider.set_store_labels(self._columns_provider.StoreLabels)

    def _setup_columns_provider(self):
        """ Setup text column names.
        """
        text_column_names = list(self._rows_provider.TextProvider.iter_columns())
        self._columns_provider.set_text_column_names(text_column_names)
@@ -0,0 +1,95 @@

from arekit.common.data.input.providers.rows.samples import BaseSampleRowProvider
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.common.experiment.api.base_samples_io import BaseSamplesIO
from arekit.common.experiment.data_type import DataType
from arekit.common.pipeline.items.base import BasePipelineItem

from arelight.data.serializer_helper import InputDataSerializationHelper


class BaseSerializerPipelineItem(BasePipelineItem):

    def __init__(self, rows_provider, samples_io, save_labels_func, storage, **kwargs):
        """ rows_provider:
                how we format input texts for a BERT model, for example:
                - a single text;
                - two sequences, separated by the [SEP] token.
            save_labels_func: function
                data_type -> bool
        """
        assert(isinstance(rows_provider, BaseSampleRowProvider))
        assert(isinstance(samples_io, BaseSamplesIO))
        assert(callable(save_labels_func))
        assert(isinstance(storage, BaseRowsStorage))
        super(BaseSerializerPipelineItem, self).__init__(**kwargs)

        self._rows_provider = rows_provider
        self._samples_io = samples_io
        self._save_labels_func = save_labels_func
        self._storage = storage

    def _serialize_iteration(self, data_type, pipeline, data_folding, doc_ids):
        assert(isinstance(data_type, DataType))
        assert(isinstance(pipeline, list))
        assert(isinstance(data_folding, dict) or data_folding is None)
        assert(isinstance(doc_ids, list) or doc_ids is None)
        assert(doc_ids is not None or data_folding is not None)

        repos = {
            "sample": InputDataSerializationHelper.create_samples_repo(
                keep_labels=self._save_labels_func(data_type),
                rows_provider=self._rows_provider,
                storage=self._storage),
        }

        writer_and_targets = {
            "sample": (self._samples_io.Writer,
                       self._samples_io.create_target(data_type=data_type)),
        }

        for description, repo in repos.items():

            if data_folding is None:
                # Consider only the predefined doc_ids.
                doc_ids_iter = doc_ids
            else:
                # Take the particular data_type.
                doc_ids_iter = data_folding[data_type]
                # Consider only the predefined doc_ids.
                if doc_ids is not None:
                    doc_ids_iter = set(doc_ids_iter).intersection(doc_ids)

            InputDataSerializationHelper.fill_and_write(
                repo=repo,
                pipeline=pipeline,
                doc_ids_iter=doc_ids_iter,
                desc="{desc} [{data_type}]".format(desc=description, data_type=data_type),
                writer=writer_and_targets[description][0],
                target=writer_and_targets[description][1])

    def _handle_iteration(self, data_type_pipelines, data_folding, doc_ids):
        """ Performing data serialization for a particular iteration.
        """
        assert(isinstance(data_type_pipelines, dict))
        for data_type, pipeline in data_type_pipelines.items():
            self._serialize_iteration(data_type=data_type, pipeline=pipeline, data_folding=data_folding,
                                      doc_ids=doc_ids)

    def apply_core(self, input_data, pipeline_ctx):
        """ data_type_pipelines: dict, for example:
                {
                    DataType.Train: BasePipeline,
                    DataType.Test: BasePipeline
                }
                where each pipeline performs: doc_id -> parsed_doc -> annot -> opinion linkages;
                for example, the function sentiment_attitude_extraction_default_pipeline.
            doc_ids: optional
                this parameter allows limiting the amount of documents considered for sampling.
        """
        assert("data_type_pipelines" in pipeline_ctx)
        self._handle_iteration(data_type_pipelines=pipeline_ctx.provide("data_type_pipelines"),
                               doc_ids=pipeline_ctx.provide_or_none("doc_ids"),
                               data_folding=pipeline_ctx.provide_or_none("data_folding"))
@@ -0,0 +1,43 @@

import logging

from collections.abc import Iterable

from arekit.common.data.input.providers.columns.sample import SampleColumnsProvider
from arekit.common.data.input.providers.rows.base import BaseRowProvider
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.contrib.utils.data.contents.opinions import InputTextOpinionProvider

from arelight.data.repositories.base import BaseInputRepository
from arelight.data.repositories.sample import BaseInputSamplesRepository

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class InputDataSerializationHelper(object):

    @staticmethod
    def create_samples_repo(keep_labels, rows_provider, storage):
        assert(isinstance(rows_provider, BaseRowProvider))
        assert(isinstance(keep_labels, bool))
        assert(isinstance(storage, BaseRowsStorage))
        return BaseInputSamplesRepository(
            columns_provider=SampleColumnsProvider(store_labels=keep_labels),
            rows_provider=rows_provider,
            storage=storage)

    @staticmethod
    def fill_and_write(pipeline, repo, target, writer, doc_ids_iter, desc=""):
        assert(isinstance(pipeline, list))
        assert(isinstance(doc_ids_iter, Iterable))
        assert(isinstance(repo, BaseInputRepository))

        doc_ids = list(doc_ids_iter)

        repo.populate(contents_provider=InputTextOpinionProvider(pipeline),
                      doc_ids=doc_ids,
                      desc=desc,
                      writer=writer,
                      target=target)

        repo.push(writer=writer, target=target)
Empty file.
@@ -0,0 +1,63 @@

import csv
import os
from os.path import dirname

from arekit.common.data.storages.base import BaseRowsStorage
from arekit.contrib.utils.data.storages.row_cache import RowCacheStorage
from arekit.contrib.utils.data.writers.base import BaseWriter


class NativeCsvWriter(BaseWriter):

    def __init__(self, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL, header=True):
        self.__target_f = None
        self.__writer = None
        self.__create_writer_func = lambda f: csv.writer(
            f, delimiter=delimiter, quotechar=quotechar, quoting=quoting)
        self.__header = header
        self.__header_written = None

    def extension(self):
        return ".csv"

    @staticmethod
    def __iter_storage_column_names(storage):
        """ Iterate only those columns that exist in the storage.
        """
        for col_name in storage.iter_column_names():
            if col_name in storage.RowCache:
                yield col_name

    def open_target(self, target):
        os.makedirs(dirname(target), exist_ok=True)
        self.__target_f = open(target, "w")
        self.__writer = self.__create_writer_func(self.__target_f)
        self.__header_written = not self.__header

    def close_target(self):
        self.__target_f.close()

    def commit_line(self, storage):
        assert(isinstance(storage, RowCacheStorage))
        assert(self.__writer is not None)

        if not self.__header_written:
            self.__writer.writerow(list(self.__iter_storage_column_names(storage)))
            self.__header_written = True

        line_data = list(map(lambda col_name: storage.RowCache[col_name],
                             self.__iter_storage_column_names(storage)))
        self.__writer.writerow(line_data)

    def write_all(self, storage, target):
        """ Writes all the `storage` rows
            into the `target` filepath, formatted as CSV.
        """
        assert(isinstance(storage, BaseRowsStorage))

        with open(target, "w") as f:
            writer = self.__create_writer_func(f)
            for _, row in storage:
                # content = [row[col_name] for col_name in storage.iter_column_names()]
                content = [v for v in row]
                writer.writerow(content)
@@ -0,0 +1,40 @@

import logging

from arekit.common.data.input.providers.columns.base import BaseColumnsProvider
from arekit.common.utils import create_dir_if_not_exists
from arekit.contrib.utils.data.storages.pandas_based import PandasBasedRowsStorage
from arekit.contrib.utils.data.writers.base import BaseWriter

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class PandasCsvWriter(BaseWriter):

    def __init__(self, write_header):
        super(PandasCsvWriter, self).__init__()
        self.__write_header = write_header

    def extension(self):
        return ".tsv.gz"

    def write_all(self, storage, target):
        assert(isinstance(storage, PandasBasedRowsStorage))
        assert(isinstance(target, str))

        create_dir_if_not_exists(target)

        # Temporary hack, remove it in future.
        df = storage.DataFrame

        logger.info("Saving... {length}: {filepath}".format(length=len(storage), filepath=target))
        df.to_csv(target,
                  sep='\t',
                  encoding='utf-8',
                  columns=[c for c in df.columns if c != BaseColumnsProvider.ROW_ID],
                  index=False,
                  float_format="%.0f",
                  compression='gzip',
                  header=self.__write_header)

        logger.info("Saving completed!")
@@ -0,0 +1,132 @@

import json
import logging
import os
from os.path import dirname

from arekit.common.data import const
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.contrib.utils.data.storages.row_cache import RowCacheStorage
from arekit.contrib.utils.data.writers.base import BaseWriter

logger = logging.getLogger(__name__)


class OpenNREJsonWriter(BaseWriter):
    """ This is a bag-based writer for the samples.
        Project page: https://github.com/thunlp/OpenNRE
        Every bag is presented as follows:
            {
                'text' or 'token': ...,
                'h': {'pos': [start, end], 'id': ... },
                't': {'pos': [start, end], 'id': ... },
                'id': "id_of_the_text_opinion"
            }
        For the linked opinions (i0, i1, etc.) we consider the id of the first opinion in the linkage.
        During the dataset reading stage via OpenNRE, these linkages are automatically grouped into bags.
    """

    def __init__(self, text_columns, encoding="utf-8", na_value="NA", keep_extra_columns=True,
                 skip_extra_existed=True):
        """ text_columns: list
                column names that are expected to be joined into a single (token) column.
        """
        assert(isinstance(text_columns, list))
        assert(isinstance(encoding, str))
        self.__text_columns = text_columns
        self.__encoding = encoding
        self.__target_f = None
        self.__keep_extra_columns = keep_extra_columns
        self.__na_value = na_value
        self.__skip_extra_existed = skip_extra_existed

    def extension(self):
        return ".jsonl"

    @staticmethod
    def __format_row(row, na_value, text_columns, keep_extra_columns, skip_extra_existed):
        """ Formatting that is compatible with the OpenNRE.
        """
        assert(isinstance(na_value, str))

        sample_id = row[const.ID]
        s_ind = int(row[const.S_IND])
        t_ind = int(row[const.T_IND])
        bag_id = str(row[const.OPINION_ID])

        # Gather tokens.
        tokens = []
        for text_col in text_columns:
            if text_col in row:
                tokens.extend(row[text_col].split())

        # Filtering JSON row.
        formatted_data = {
            "id": bag_id,
            "id_orig": sample_id,
            "token": tokens,
            "h": {"pos": [s_ind, s_ind + 1], "id": str(bag_id + "s")},
            "t": {"pos": [t_ind, t_ind + 1], "id": str(bag_id + "t")},
            "relation": str(int(row[const.LABEL_UINT])) if const.LABEL_UINT in row else na_value
        }

        # Register extra fields (optionally).
        if keep_extra_columns:
            for key, value in row.items():
                if key not in formatted_data and key not in text_columns:
                    formatted_data[key] = value
                else:
                    if not skip_extra_existed:
                        raise Exception(f"key `{key}` already exists in the formatted data "
                                        f"or is a part of the text columns list: {text_columns}")

        return formatted_data

    def open_target(self, target):
        os.makedirs(dirname(target), exist_ok=True)
        self.__target_f = open(target, "w")

    def close_target(self):
        self.__target_f.close()

    def commit_line(self, storage):
        assert(isinstance(storage, RowCacheStorage))

        # Collect the existing columns.
        row_data = {}
        for col_name in storage.iter_column_names():
            if col_name not in storage.RowCache:
                continue
            row_data[col_name] = storage.RowCache[col_name]

        bag = self.__format_row(row_data, text_columns=self.__text_columns,
                                keep_extra_columns=self.__keep_extra_columns,
                                na_value=self.__na_value,
                                skip_extra_existed=self.__skip_extra_existed)

        self.__write_bag(bag=bag, json_file=self.__target_f)

    @staticmethod
    def __write_bag(bag, json_file):
        assert(isinstance(bag, dict))
        json.dump(bag, json_file, separators=(",", ":"), ensure_ascii=False)
        json_file.write("\n")

    def write_all(self, storage, target):
        assert(isinstance(storage, BaseRowsStorage))
        assert(isinstance(target, str))

        logger.info("Saving... {rows}: {filepath}".format(rows=len(storage), filepath=target))

        os.makedirs(os.path.dirname(target), exist_ok=True)
        with open(target, "w", encoding=self.__encoding) as json_file:
            for row_index, row in storage:
                self.__write_bag(bag=self.__format_row(row, text_columns=self.__text_columns,
                                                       keep_extra_columns=self.__keep_extra_columns,
                                                       na_value=self.__na_value,
                                                       skip_extra_existed=self.__skip_extra_existed),
                                 json_file=json_file)

        logger.info("Saving completed!")
@@ -0,0 +1,114 @@

import os
import sqlite3
from os.path import dirname

from arekit.common.data import const
from arekit.contrib.utils.data.storages.row_cache import RowCacheStorage
from arekit.contrib.utils.data.writers.base import BaseWriter


class SQliteWriter(BaseWriter):
    """ TODO. This implementation is dedicated to the writing concepts of the data
        serialization pipeline. However, once we add the SQLite3 service, it would be
        right to refactor and utilize some core functionality from core/service/sqlite.py.
    """

    def __init__(self, table_name="contents", index_column_names=None, skip_existed=False, clear_table=True):
        """ index_column_names: list or None
                column names to be considered for building a unique index;
                if None, the default `const.ID` will be considered for row indexation.
        """
        assert(isinstance(index_column_names, list) or index_column_names is None)
        self.__index_column_names = index_column_names if index_column_names is not None else [const.ID]
        self.__table_name = table_name
        self.__conn = None
        self.__cur = None
        self.__need_init_table = True
        self.__origin_column_names = None
        self.__skip_existed = skip_existed
        self.__clear_table = clear_table

    def extension(self):
        return ".sqlite"

    @staticmethod
    def __iter_storage_column_names(storage):
        """ Iterate only those columns that exist in the storage.
        """
        assert(isinstance(storage, RowCacheStorage))
        for col_name, col_type in zip(storage.iter_column_names(), storage.iter_column_types()):
            if col_name in storage.RowCache:
                yield col_name, col_type

    def __init_table(self, column_data):
        # Compose column names with the related SQLite types.
        column_types = ",".join([" ".join([col_name, self.type_to_sqlite(col_type)])
                                 for col_name, col_type in column_data])
        # Create the table if it does not exist.
        self.__cur.execute(f"CREATE TABLE IF NOT EXISTS {self.__table_name}({column_types})")
        # The table exists; however, we may optionally remove the content from it.
        if self.__clear_table:
            self.__cur.execute(f"DELETE FROM {self.__table_name};")
        # Create the index.
        index_name = f"i_{self.__table_name}_id"
        self.__cur.execute(f"DROP INDEX IF EXISTS {index_name};")
        self.__cur.execute("CREATE INDEX IF NOT EXISTS {index} ON {table}({columns})".format(
            index=index_name,
            table=self.__table_name,
            columns=", ".join(self.__index_column_names)
        ))
        self.__origin_column_names = [col_name for col_name, _ in column_data]

    @staticmethod
    def type_to_sqlite(col_type):
        """ This is a simple function that provides conversion from the
            base numpy types to SQLite.
            NOTE: this method represents a quick implementation for the supported
            types; however, it is far from a generalized implementation.
        """
        if isinstance(col_type, str):
            if 'int' in col_type:
                return 'INTEGER'

        return "TEXT"

    def open_target(self, target):
        os.makedirs(dirname(target), exist_ok=True)
        self.__conn = sqlite3.connect(target)
        self.__cur = self.__conn.cursor()

    def commit_line(self, storage):
        assert(isinstance(storage, RowCacheStorage))

        column_data = list(self.__iter_storage_column_names(storage))

        if self.__need_init_table:
            self.__init_table(column_data)
            self.__need_init_table = False

        # Check whether the related row already exists in the SQLite database.
        row_id = storage.RowCache[const.ID]
        top_row = self.__cur.execute(f"SELECT EXISTS(SELECT 1 FROM {self.__table_name} WHERE id='{row_id}');")
        is_exists = top_row.fetchone()[0]
        if is_exists == 1 and self.__skip_existed:
            return

        line_data = [storage.RowCache[col_name] for col_name, _ in column_data]
        parameters = ",".join(["?"] * len(line_data))

        assert(len(self.__origin_column_names) == len(line_data))

        self.__cur.execute(
            f"INSERT OR REPLACE INTO {self.__table_name} VALUES ({parameters})",
            tuple(line_data))

        self.__conn.commit()

    def close_target(self):
        self.__cur = None
        self.__origin_column_names = None
        self.__need_init_table = True
        self.__conn.close()

    def write_all(self, storage, target):
        pass
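A quick illustration of the `type_to_sqlite` mapping above (the exact type strings depend on the storage; anything that is not a string containing "int" falls back to TEXT):

    SQliteWriter.type_to_sqlite("int64")   # -> 'INTEGER'
    SQliteWriter.type_to_sqlite("object")  # -> 'TEXT'
    SQliteWriter.type_to_sqlite(None)      # -> 'TEXT'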
@@ -1,70 +1,97 @@
 from arekit.common.bound import Bound
-from arekit.common.docs.objects_parser import SentenceObjectsParserPipelineItem
-from arekit.common.text.partitioning.terms import TermsPartitioning
+from arekit.common.pipeline.items.base import BasePipelineItem
+from arekit.common.pipeline.utils import BatchIterator
+from arekit.common.text.partitioning import Partitioning

+from arelight.arekit.chunk_it import ChunkIterator
 from arelight.ner.deep_pavlov import DeepPavlovNER
 from arelight.ner.obj_desc import NerObjectDescriptor
 from arelight.pipelines.items.entity import IndexedEntity
 from arelight.utils import IdAssigner


-class DeepPavlovNERPipelineItem(SentenceObjectsParserPipelineItem):
+class DeepPavlovNERPipelineItem(BasePipelineItem):

-    def __init__(self, id_assigner, ner_model_name, obj_filter=None, chunk_limit=128, display_value_func=None):
+    def __init__(self, id_assigner, ner_model_name, obj_filter=None,
+                 chunk_limit=128, display_value_func=None, **kwargs):
         """ chunk_limit: int
             length of the text part, in words, that is to be provided as input.
         """
         assert(callable(obj_filter) or obj_filter is None)
         assert(isinstance(chunk_limit, int) and chunk_limit > 0)
         assert(isinstance(id_assigner, IdAssigner))
         assert(callable(display_value_func) or display_value_func is None)
+        super(DeepPavlovNERPipelineItem, self).__init__(**kwargs)

         # Initialize bert-based model instance.
         self.__dp_ner = DeepPavlovNER(ner_model_name)
         self.__obj_filter = obj_filter
         self.__chunk_limit = chunk_limit
         self.__id_assigner = id_assigner
         self.__disp_value_func = display_value_func
-        super(DeepPavlovNERPipelineItem, self).__init__(TermsPartitioning())
+        self.__partitioning = Partitioning(text_fmt="list")

-    def _get_parts_provider_func(self, input_data, pipeline_ctx):
-        return self.__iter_subs_values_with_bounds(input_data)
+    @property
+    def SupportBatching(self):
+        return True

-    def __iter_subs_values_with_bounds(self, terms_list):
-        assert(isinstance(terms_list, list))
+    def __iter_subs_values_with_bounds(self, batch_it):
+        chunk_offset = 0
+        handled_text_index = -1
+        for batch in batch_it:
+            text_indices, texts = zip(*batch)

-        for chunk_start in range(0, len(terms_list), self.__chunk_limit):
-            single_sentence_chunk = [terms_list[chunk_start:chunk_start+self.__chunk_limit]]
-
             # NOTE: in some cases, for example URL links or other long input words,
             # the overall behavior might result in exceeding the assumed threshold.
             # In order to completely prevent it, we consider to wrap the call
             # of NER module into try-catch block.
             try:
-                processed_sequences = self.__dp_ner.extract(sequences=single_sentence_chunk)
+                data = self.__dp_ner.extract(sequences=list(texts))
             except RuntimeError:
-                processed_sequences = []
-
-            entities_it = self.__iter_parsed_entities(processed_sequences,
-                                                      chunk_terms_list=single_sentence_chunk[0],
-                                                      chunk_offset=chunk_start)
+                data = None

-            for entity, bound in entities_it:
-                yield entity, bound
+            if data is not None:
+                for i, d in enumerate(data):
+                    terms, descriptors = d
+                    if text_indices[i] != handled_text_index:
+                        chunk_offset = 0
+                    entities_it = self.__iter_parsed_entities(
+                        descriptors=descriptors, terms_list=terms, chunk_offset=chunk_offset)
+                    handled_text_index = text_indices[i]
+                    chunk_offset += len(terms)
+                    yield text_indices[i], terms, list(entities_it)
+            else:
+                for i in range(len(batch)):
+                    yield text_indices[i], texts[i], []

-    def __iter_parsed_entities(self, processed_sequences, chunk_terms_list, chunk_offset):
-        for p_sequence in processed_sequences:
-            for s_obj in p_sequence:
-                assert(isinstance(s_obj, NerObjectDescriptor))
+    def __iter_parsed_entities(self, descriptors, terms_list, chunk_offset):
+        for s_obj in descriptors:
+            assert(isinstance(s_obj, NerObjectDescriptor))

-                if self.__obj_filter is not None and not self.__obj_filter(s_obj):
-                    continue
+            if self.__obj_filter is not None and not self.__obj_filter(s_obj):
+                continue

-                value = " ".join(chunk_terms_list[s_obj.Position:s_obj.Position + s_obj.Length])
-                entity = IndexedEntity(
-                    value=value, e_type=s_obj.ObjectType, entity_id=self.__id_assigner.get_id(),
-                    display_value=self.__disp_value_func(value) if self.__disp_value_func is not None else None)
-                yield entity, Bound(pos=chunk_offset + s_obj.Position, length=s_obj.Length)
+            value = " ".join(terms_list[s_obj.Position:s_obj.Position + s_obj.Length])
+            entity = IndexedEntity(
+                value=value, e_type=s_obj.ObjectType, entity_id=self.__id_assigner.get_id(),
+                display_value=self.__disp_value_func(value) if self.__disp_value_func is not None else None)
+            yield entity, Bound(pos=chunk_offset + s_obj.Position, length=s_obj.Length)

     def apply_core(self, input_data, pipeline_ctx):
-        return super(DeepPavlovNERPipelineItem, self).apply_core(input_data=input_data, pipeline_ctx=pipeline_ctx)
+        assert(isinstance(input_data, list))
+
+        batch_size = len(input_data)
+
+        c_it = ChunkIterator(iter(input_data), batch_size=batch_size, chunk_limit=self.__chunk_limit)
+        b_it = BatchIterator(c_it, batch_size=batch_size)
+
+        parts_it = self.__iter_subs_values_with_bounds(b_it)
+
+        terms = [[] for _ in range(batch_size)]
+        bounds = [[] for _ in range(batch_size)]
+        for i, t, e in parts_it:
+            terms[i].extend(t)
+            bounds[i].extend(e)
+
+        # Compose batch data.
+        b_data = []
+        for i in range(batch_size):
+            b_data.append(self.__partitioning.provide(text=terms[i], parts_it=bounds[i]))
+
+        return b_data
@@ -1,30 +1,32 @@
 from arekit.common.pipeline.context import PipelineContext
 from arekit.common.pipeline.items.base import BasePipelineItem

-from arelight.predict_provider import BasePredictProvider
-from arelight.predict_writer import BasePredictWriter
+from arelight.predict.provider import BasePredictProvider
+from arelight.predict.writer import BasePredictWriter


 class InferenceWriterPipelineItem(BasePipelineItem):

-    def __init__(self, writer):
+    def __init__(self, writer, **kwargs):
         assert(isinstance(writer, BasePredictWriter))
+        super(InferenceWriterPipelineItem, self).__init__(**kwargs)
         self.__writer = writer

     def apply_core(self, input_data, pipeline_ctx):
         assert(isinstance(input_data, PipelineContext))

         # Setup predicted result writer.
-        target = input_data.provide("predict_filepath")
-        print(target)
+        target = pipeline_ctx.provide("predict_filepath")

         self.__writer.set_target(target)

+        # Extracting list of the uint labels.
+        labels_scaler = pipeline_ctx.provide("labels_scaler")
+        uint_labels = [labels_scaler.label_to_uint(label) for label in labels_scaler.ordered_suppoted_labels()]
+
         # Gathering the content
-        title, contents_it = BasePredictProvider().provide(
-            sample_id_with_uint_labels_iter=input_data.provide("iter_infer"),
-            labels_count=input_data.provide("labels_scaler").LabelsCount)
+        header, contents_it = BasePredictProvider.provide_to_storage(
+            sample_id_with_uint_labels_iter=pipeline_ctx.provide("iter_infer"),
+            uint_labels=uint_labels)

         with self.__writer:
-            self.__writer.write(title=title, contents_it=contents_it,
-                                total=input_data.provide_or_none("iter_total"))
+            self.__writer.write(header=header, contents_it=contents_it,
+                                total=pipeline_ctx.provide_or_none("iter_total"))
This file was deleted.
Empty file.
@@ -0,0 +1,17 @@

from arekit.common.data import const


class PredictHeader:

    @staticmethod
    def create_header(uint_labels, uint_label_to_str, create_id=True):
        assert(callable(uint_label_to_str))

        header = []

        if create_id:
            header.append(const.ID)

        header.extend([uint_label_to_str(uint_label) for uint_label in uint_labels])

        return header
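A minimal sketch of the produced header for three classes (assuming the `col_<i>` formatter that `BasePredictProvider` below supplies as `UINT_TO_STR`):

    header = PredictHeader.create_header(uint_labels=[0, 1, 2],
                                         uint_label_to_str=lambda u: f"col_{u}")
    # [const.ID, 'col_0', 'col_1', 'col_2']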
@@ -0,0 +1,68 @@

from collections.abc import Iterable

from arekit.common.data.storages.base import BaseRowsStorage

from arelight.predict.header import PredictHeader


class BasePredictProvider(object):

    UINT_TO_STR = lambda uint_label: f"col_{uint_label}"

    @staticmethod
    def __iter_contents(sample_id_with_uint_labels_iter, labels_count, column_extra_funcs):
        assert(isinstance(labels_count, int))

        for sample_id, uint_label in sample_id_with_uint_labels_iter:
            assert(isinstance(uint_label, int))

            labels = ['0'] * labels_count
            labels[uint_label] = '1'

            # Composing row contents.
            contents = [sample_id]

            # Optionally provide additional values.
            if column_extra_funcs is not None:
                for _, value_func in column_extra_funcs:
                    contents.append(str(value_func(sample_id)))

            # Providing row labels.
            contents.extend(labels)
            yield contents

    @staticmethod
    def provide_to_storage(sample_id_with_uint_labels_iter, uint_labels):
        assert(isinstance(sample_id_with_uint_labels_iter, Iterable))

        # Provide contents.
        contents_it = BasePredictProvider.__iter_contents(
            sample_id_with_uint_labels_iter=sample_id_with_uint_labels_iter,
            labels_count=len(uint_labels),
            column_extra_funcs=None)

        header = PredictHeader.create_header(uint_labels=uint_labels,
                                             uint_label_to_str=BasePredictProvider.UINT_TO_STR)

        return header, contents_it

    @staticmethod
    def iter_from_storage(predict_data, uint_labels, keep_ind=True):
        assert(isinstance(predict_data, BaseRowsStorage))

        header = PredictHeader.create_header(uint_labels=uint_labels,
                                             uint_label_to_str=BasePredictProvider.UINT_TO_STR,
                                             create_id=False)

        for res_ind, row in predict_data:

            uint_label = None
            for label, field_name in zip(uint_labels, header):
                if row[field_name] > 0:
                    uint_label = label
                    break

            if keep_ind:
                yield res_ind, uint_label
            else:
                yield uint_label
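A minimal sketch of `provide_to_storage` on hypothetical predictions (two samples, three classes); each content row is the sample id followed by a one-hot encoding of the predicted uint label:

    header, rows_it = BasePredictProvider.provide_to_storage(
        sample_id_with_uint_labels_iter=[("s0", 1), ("s1", 0)],
        uint_labels=[0, 1, 2])
    # header: [const.ID, 'col_0', 'col_1', 'col_2']
    # rows:   ['s0', '0', '1', '0'], ['s1', '1', '0', '0']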
@@ -0,0 +1,37 @@

import logging

from arekit.common.utils import progress_bar_defined
from source_iter.service_sqlite import SQLite3Service

from arelight.predict.writer import BasePredictWriter

logger = logging.getLogger(__name__)


class SQLite3PredictWriter(BasePredictWriter):

    def __init__(self, table_name, log_out=None):
        super(SQLite3PredictWriter, self).__init__()
        self.__table_name = table_name
        self.__log_out = log_out

    def write(self, header, contents_it, total=None):

        content_header = header[1:]
        SQLite3Service.write_missed(
            columns=content_header,
            target=self._target,
            table_name=self.__table_name,
            it_type=None,
            data_it=progress_bar_defined(iterable=map(lambda item: [item[0], item[1:]], contents_it),
                                         desc=f'Writing output (sqlite:{self.__table_name})',
                                         unit='rows', total=total, file=self.__log_out),
            sqlite3_column_types=["INTEGER"] * len(content_header),
            id_column_name=header[0],
            id_column_type="INTEGER")

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass
This file was deleted.
Empty file.
@@ -0,0 +1,7 @@

class BaseReader(object):

    def extension(self):
        raise NotImplementedError()

    def read(self, target):
        raise NotImplementedError()
@@ -0,0 +1,39 @@

import importlib

from arekit.contrib.utils.data.storages.pandas_based import PandasBasedRowsStorage

from arelight.readers.base import BaseReader


class PandasCsvReader(BaseReader):
    """ Represents a CSV-based reader, implemented via the pandas API.
    """

    def __init__(self, sep='\t', header='infer', compression='infer', encoding='utf-8', col_types=None,
                 custom_extension=None):
        self.__sep = sep
        self.__compression = compression
        self.__encoding = encoding
        self.__header = header
        self.__custom_extension = custom_extension

        # Special assignment of types for certain columns.
        self.__col_types = col_types
        if self.__col_types is None:
            self.__col_types = dict()

    def extension(self):
        return ".tsv.gz" if self.__custom_extension is None else self.__custom_extension

    def __from_csv(self, filepath):
        pd = importlib.import_module("pandas")
        return pd.read_csv(filepath,
                           sep=self.__sep,
                           encoding=self.__encoding,
                           compression=self.__compression,
                           dtype=self.__col_types,
                           header=self.__header)

    def read(self, target):
        df = self.__from_csv(filepath=target)
        return PandasBasedRowsStorage(df)
@@ -0,0 +1,16 @@

from arekit.contrib.utils.data.storages.jsonl_based import JsonlBasedRowsStorage

from arelight.readers.base import BaseReader


class JsonlReader(BaseReader):

    def extension(self):
        return ".jsonl"

    def read(self, target):
        rows = []
        with open(target, "r") as f:
            for line in f.readlines():
                rows.append(line)
        return JsonlBasedRowsStorage(rows)
@@ -0,0 +1,15 @@

from arekit.contrib.utils.data.storages.sqlite_based import SQliteBasedRowsStorage

from arelight.readers.base import BaseReader


class SQliteReader(BaseReader):

    def __init__(self, table_name):
        self.__table_name = table_name

    def extension(self):
        return ".sqlite"

    def read(self, target):
        return SQliteBasedRowsStorage(path=target, table_name=self.__table_name)
Large diffs are not rendered by default.
@@ -0,0 +1,61 @@

import io
import logging
import sys
from tqdm import tqdm

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class TqdmLoggingHandler(logging.Handler):

    def __init__(self, level=logging.NOTSET):
        super().__init__(level)

    def emit(self, record):
        try:
            msg = self.format(record)
            tqdm.write(msg)
            self.flush()
        except Exception:
            self.handleError(record)


class TqdmToLogger(io.StringIO):
    """ Output stream for TQDM which will output to the logger module instead of
        the StdOut.
    """
    logger = None
    level = None
    buf = ''

    def __init__(self, logger, level=None):
        super(TqdmToLogger, self).__init__()
        self.logger = logger
        self.level = level or logging.INFO

    def write(self, buf):
        self.buf = buf.strip('\r\n\t ')

    def flush(self):
        self.logger.log(self.level, self.buf)


def setup_custom_logger(name, add_screen_handler=False, filepath=None):
    formatter = logging.Formatter(fmt='%(asctime)s %(levelname)-8s %(message)s',
                                  datefmt='%Y-%m-%d %H:%M:%S')
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.addHandler(TqdmLoggingHandler())

    if add_screen_handler:
        screen_handler = logging.StreamHandler(stream=sys.stdout)
        screen_handler.setFormatter(formatter)
        logger.addHandler(screen_handler)

    if filepath is not None:
        handler = logging.FileHandler(filepath, mode='w')
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger
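A hypothetical usage of the helper above; records go to the tqdm-safe handler, optionally to stdout, and to the file:

    logger = setup_custom_logger("arelight", add_screen_handler=True, filepath="arelight.log")
    logger.info("serialization started")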
Empty file.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@
from arekit.common.text.stemmer import Stemmer
from arekit.common.utils import filter_whitespaces
from pymystem3 import Mystem


class MystemWrapper(Stemmer):
    """ Yandex MyStem wrapper.
        Part-of-speech description:
        https://tech.yandex.ru/mystem/doc/grammemes-values-docpage/
    """

    def __init__(self, entire_input=False):
        """ entire_input: bool
                Mystem parameter that either keeps all information from the
                input (True) or strips garbage characters (False).
        """
        self.__mystem = Mystem(entire_input=entire_input)

    # region properties

    @property
    def MystemInstance(self):
        return self.__mystem

    # endregion

    # region public methods

    def lemmatize_to_list(self, text):
        return self.__lemmatize_core(text)

    def lemmatize_to_str(self, text):
        result = " ".join(self.__lemmatize_core(text))
        # Fall back to the lowercased original text when lemmatization
        # produces an empty result.
        return result if len(result) != 0 else self.__process_original_text(text)

    # endregion

    # region private methods

    def __lemmatize_core(self, text):
        assert(isinstance(text, str))
        result_list = self.__mystem.lemmatize(self.__process_original_text(text))
        return filter_whitespaces(result_list)

    @staticmethod
    def __process_original_text(text):
        return text.lower()

    # endregion
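A usage sketch, assuming `pymystem3` (and its bundled Mystem binary) is installed; the Russian input word and its expected lemma are illustrative:

stemmer = MystemWrapper()

# Input is lowercased internally (see __process_original_text).
print(stemmer.lemmatize_to_list("Книги"))  # expected: ["книга"]
print(stemmer.lemmatize_to_str("Книги"))   # expected: "книга"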
@@ -0,0 +1,14 @@
from arekit.common.utils import progress_bar_defined


def iter_synonym_groups(input_file, sep=",", desc=""):
    """ Iterates synonym groups: one group per line, with values separated by `sep`.
    """
    lines = input_file.readlines()

    for line in progress_bar_defined(lines, total=len(lines), desc=desc, unit="opins"):

        if isinstance(line, bytes):
            line = line.decode()

        yield line.split(sep)
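A sketch of the expected input format: one synonym group per line, values separated by `sep`. Entries keep their surrounding whitespace and trailing newline, so callers are expected to strip them:

import io

synonyms_file = io.StringIO("USA,United States\nRF,Russian Federation\n")
for group in iter_synonym_groups(synonyms_file, sep=","):
    print([entry.strip() for entry in group])
# ['USA', 'United States']
# ['RF', 'Russian Federation']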
@@ -0,0 +1,42 @@
import sqlite3


class SQLite3Service(object):
    """ Thin wrapper around a sqlite3 connection.
        Note: table and column names are interpolated into queries directly
        and are therefore assumed to come from a trusted source.
    """

    def __init__(self):
        self.conn = None

    def connect(self, sqlite_path):
        self.conn = sqlite3.connect(sqlite_path)

    def disconnect(self):
        assert(self.conn is not None)
        self.conn.close()

    def table_rows_count(self, table_name):
        count_response = self.conn.execute(f"select count(*) from {table_name}").fetchone()
        return count_response[0]

    def get_column_names(self, table_name, filter_name=None):
        cursor = self.conn.execute(f"select * from {table_name}")
        column_names = list(map(lambda x: x[0], cursor.description))
        return [col_name for col_name in column_names
                if filter_name is None or filter_name(col_name)]

    def iter_rows(self, table_name, select_columns="*", column_value=None, value=None, return_dict=False):
        """ Yields values in the same order as provided in the `select_columns` parameter.
        """
        if column_value is not None and value is not None:
            cursor = self.conn.execute(
                f"select {select_columns} from {table_name} where ({column_value} = ?)", (value,))
        else:
            cursor = self.conn.execute(f"select {select_columns} from {table_name}")

        if not return_dict:
            for row in cursor.fetchall():
                yield row
            # Stop here so we do not fall through to the dict-based branch below.
            return

        # Yield each row as a dictionary keyed by column name.
        column_names = list(map(lambda x: x[0], cursor.description))
        for row in cursor.fetchall():
            yield {col_name: row[i] for i, col_name in enumerate(column_names)}
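A usage sketch against a throwaway database; the schema below is hypothetical. Table and column names go into the SQL verbatim (hence the trusted-source assumption), while row values go through `?` placeholders:

import sqlite3

# Build a tiny database to query (illustrative schema).
with sqlite3.connect("demo.sqlite") as conn:
    conn.execute("create table contents(id integer, text_a text)")
    conn.execute("insert into contents values (0, 'hello'), (1, 'world')")

service = SQLite3Service()
service.connect("demo.sqlite")
print(service.table_rows_count("contents"))   # 2
print(service.get_column_names("contents"))   # ['id', 'text_a']
for row in service.iter_rows("contents", select_columns="text_a",
                             column_value="id", value=1, return_dict=True):
    print(row)                                # {'text_a': 'world'}
service.disconnect()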
@@ -0,0 +1,98 @@
import logging

import torch
from torch.utils import data

from arelight.third_party.sqlite3 import SQLite3Service


class SQLiteSentenceREDataset(data.Dataset):
    """ Sentence-level relation extraction dataset.
        This is the original OpenNRE implementation, adapted for SQLite.
    """

    def __init__(self, path, table_name, rel2id, tokenizer, kwargs, task_kwargs):
        """ Args:
                path: path to the input SQLite file
                rel2id: dictionary of the relation->id mapping
                tokenizer: tokenization function
        """
        assert(isinstance(task_kwargs, dict))
        assert("no_label" in task_kwargs)
        assert("default_id_column" in task_kwargs)
        assert("index_columns" in task_kwargs)
        assert("text_columns" in task_kwargs)
        super().__init__()
        self.path = path
        self.tokenizer = tokenizer
        self.rel2id = rel2id
        self.kwargs = kwargs
        self.task_kwargs = task_kwargs
        self.table_name = table_name
        self.sqlite_service = SQLite3Service()

    def iter_ids(self, id_column=None):
        col_name = self.task_kwargs["default_id_column"] if id_column is None else id_column
        for row in self.sqlite_service.iter_rows(select_columns=col_name, table_name=self.table_name):
            yield row[0]

    def __len__(self):
        return self.sqlite_service.table_rows_count(self.table_name)

    def __getitem__(self, index):
        # Keep only those text columns that actually exist in the table.
        found_text_columns = self.sqlite_service.get_column_names(
            table_name=self.table_name, filter_name=lambda col_name: col_name in self.task_kwargs["text_columns"])

        iter_rows = self.sqlite_service.iter_rows(
            select_columns=",".join(self.task_kwargs["index_columns"] + found_text_columns),
            value=index,
            column_value=self.task_kwargs["default_id_column"],
            table_name=self.table_name)

        fetched_row = next(iter_rows)

        # Compose an OpenNRE-style item: joined text plus head/tail entity positions.
        opennre_item = {
            "text": " ".join(fetched_row[-len(found_text_columns):]),
            "h": {"pos": [fetched_row[0], fetched_row[0] + 1]},
            "t": {"pos": [fetched_row[1], fetched_row[1] + 1]},
        }

        seq = list(self.tokenizer(opennre_item, **self.kwargs))

        return [self.rel2id[self.task_kwargs["no_label"]]] + seq  # label, seq1, seq2, ...

    @staticmethod
    def collate_fn(data):
        data = list(zip(*data))
        labels = data[0]
        seqs = data[1:]
        batch_labels = torch.tensor(labels).long()  # (B)
        batch_seqs = []
        for seq in seqs:
            batch_seqs.append(torch.cat(seq, 0))  # (B, L)
        return [batch_labels] + batch_seqs

    def eval(self, pred_result, use_name=False):
        raise NotImplementedError()

    def __enter__(self):
        self.sqlite_service.connect(self.path)
        logging.info("Loaded sentence RE dataset {} with {} lines and {} relations.".format(
            self.path, len(self), len(self.rel2id)))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.sqlite_service.disconnect()


def sentence_re_loader(path, table_name, rel2id, tokenizer, batch_size, shuffle,
                       task_kwargs, num_workers, collate_fn=SQLiteSentenceREDataset.collate_fn, **kwargs):
    dataset = SQLiteSentenceREDataset(path=path, table_name=table_name, rel2id=rel2id,
                                      tokenizer=tokenizer, kwargs=kwargs, task_kwargs=task_kwargs)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=shuffle,
                                  pin_memory=True,
                                  num_workers=num_workers,
                                  collate_fn=collate_fn)
    return data_loader
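A sketch of wiring the loader end to end. Everything concrete here is an assumption for illustration: the schema mirrors `task_kwargs`, and `tokenize` stands in for an OpenNRE entity-aware tokenizer that consumes an item with "text"/"h"/"t" keys and returns token-id tensors:

import sqlite3
import torch

# Miniature samples table (hypothetical schema).
with sqlite3.connect("samples.sqlite") as conn:
    conn.execute("create table contents(id integer, s_ind integer, t_ind integer, text_a text)")
    conn.execute("insert into contents values (0, 0, 2, 'Paris is lovely'), (1, 0, 3, 'Bob visited ancient Rome')")

def tokenize(item, **kwargs):
    # Stand-in tokenizer: a real one would encode item["text"] together with
    # the entity positions in item["h"]["pos"] and item["t"]["pos"].
    return [torch.zeros((1, 8), dtype=torch.long)]

rel2id = {"neu": 0, "pos": 1, "neg": 2}
task_kwargs = {"no_label": "neu",                    # label assigned prior to inference
               "default_id_column": "id",
               "index_columns": ["s_ind", "t_ind"],  # entity position columns
               "text_columns": ["text_a"]}

loader = sentence_re_loader(path="samples.sqlite", table_name="contents",
                            rel2id=rel2id, tokenizer=tokenize, batch_size=2,
                            shuffle=False, task_kwargs=task_kwargs, num_workers=0)
with loader.dataset:  # __enter__ opens the SQLite connection
    for batch_labels, *batch_seqs in loader:
        print(batch_labels.shape, batch_seqs[0].shape)  # torch.Size([2]) torch.Size([2, 8])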
File renamed without changes.