
Commit

Push the 0.25.0 version
nicolay-r committed Dec 21, 2024
2 parents 9a0231a + 1993d47 commit 8b597a6
Showing 88 changed files with 4,887 additions and 1,416 deletions.
36 changes: 7 additions & 29 deletions README.md
@@ -1,7 +1,7 @@
# ARElight 0.24.0
# ARElight 0.25.0

![](https://img.shields.io/badge/Python-3.9-brightgreen.svg)
![](https://img.shields.io/badge/AREkit-0.24.0-orange.svg)
![](https://img.shields.io/badge/AREkit-0.25.1-orange.svg)
[![](https://img.shields.io/badge/demo-0.24.0-purple.svg)](https://guardeec.github.io/arelight_demo/template.html)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nicolay-r/ARElight/blob/v0.24.0/ARElight.ipynb)

@@ -12,13 +12,15 @@
</p>

ARElight is an application for a granular view of sentiments between named entities mentioned in texts.
This repository is a part of the **ECIR-2024** demo paper:
[ARElight: Context Sampling of Large Texts for Deep Learning Relation Extraction](https://link.springer.com/chapter/10.1007/978-3-031-56069-9_23).


# Installation

```bash
pip install git+https://github.com/nicolay-r/arelight@v0.24.0
pip install git+https://github.com/nicolay-r/arelight@v0.24.1
```

## Usage: Inference
@@ -86,7 +88,7 @@ Parameters:
* `batch-size` -- number of samples per single inference iteration.
* `tokens-per-context` -- size of the input context, in tokens.
* `bert-torch-checkpoint` -- fine-tuned model state (checkpoint).
* `device-type` -- `cpu` or `gpu`.
* `device-type` -- `cpu` or `cuda`.
* `labels-fmt` -- list of mappings from `label` to an integer value; `p:1,n:2,u:0` by default, where:
* `p` -- positive label, which is mapped to `1`.
* `n` -- negative label, which is mapped to `2`.
@@ -99,31 +101,7 @@ Parameters:
Framework parameters mentioned above, as well as their related setups, may be omitted.

</details>

To launch the Graph Builder for D3JS and (optionally) start the DEMO server for collections in the `output` dir:

```bash
cd output && python -m http.server 8000
```

Finally, you may open the demo page at `http://0.0.0.0:8000/`.

[![](https://img.shields.io/badge/demo-0.24.0-purple.svg)](https://guardeec.github.io/arelight_demo/template.html)

![image](https://github.com/nicolay-r/ARElight/assets/14871187/341f3b51-d639-46b6-83fe-99b542b1751b)

## Layout of the files in output
```
output/
├── description/
└── ... // graph descriptions in JSON.
├── force/
└── ... // force graphs in JSON.
├── radial/
└── ... // radial graphs in JSON.
└── index.html // main HTML demo page.
```

## Usage: Graph Operations

For graph analysis, you can perform several graph operations with this script:
31 changes: 31 additions & 0 deletions arelight/arekit/chunk_it.py
@@ -0,0 +1,31 @@
class ChunkIterator:

def __init__(self, data_iter, batch_size, chunk_limit):
assert(isinstance(batch_size, int) and batch_size > 0)
self.__data_iter = data_iter
self.__index = -1
self.__batch_size = batch_size
self.__chunk_limit = chunk_limit
self.__buffer = []

def __iter__(self):
return self

def __next__(self):
while True:
if len(self.__buffer) > 0:
break
try:
data = next(self.__data_iter)
self.__index += 1
except StopIteration:
break
for chunk_start in range(0, len(data), self.__chunk_limit):
chunk = data[chunk_start:chunk_start + self.__chunk_limit]
self.__buffer.append([self.__index, chunk])

if len(self.__buffer) > 0:
return self.__buffer.pop(0)

raise StopIteration
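
Below is a minimal usage sketch of `ChunkIterator` (toy inputs, not taken from the repository); note that `batch_size` is only validated by the constructor and does not affect the chunking itself:

```python
# A minimal sketch of ChunkIterator behaviour, assuming the import path introduced in this commit.
from arelight.arekit.chunk_it import ChunkIterator

texts = [list("abcdefg"), list("hi")]  # two "documents" given as term lists
it = ChunkIterator(iter(texts), batch_size=2, chunk_limit=3)

print(list(it))
# [[0, ['a', 'b', 'c']], [0, ['d', 'e', 'f']], [0, ['g']], [1, ['h', 'i']]]
# Each item is a [source_index, chunk] pair: long inputs are split into chunks of at most
# chunk_limit terms, while batch_size is only asserted and not used for chunking.
```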

14 changes: 14 additions & 0 deletions arelight/arekit/custom_sqlite_reader.py
@@ -0,0 +1,14 @@
from arelight.readers.base import BaseReader


class CustomSQliteReader(BaseReader):

def __init__(self, storage_type, **storage_kwargs):
self._storage_kwargs = storage_kwargs
self._storage_type = storage_type

def extension(self):
return ".sqlite"

def read(self, target):
return self._storage_type(path=target, **self._storage_kwargs)
22 changes: 22 additions & 0 deletions arelight/arekit/joined_sqlite.py
@@ -0,0 +1,22 @@
import sqlite3

from arekit.common.data.const import ID
from arekit.common.data.storages.base import BaseRowsStorage


class JoinedSQliteBasedRowsStorage(BaseRowsStorage):

def __init__(self, path, table_name_a, table_name_b, **kwargs):
super(JoinedSQliteBasedRowsStorage, self).__init__(**kwargs)
self.__path = path
self.__table_name_a = table_name_a
self.__table_name_b = table_name_b
self.__conn = None

def _iter_rows(self):
with sqlite3.connect(self.__path) as conn:
cursor = conn.execute(f"select * from {self.__table_name_a} inner join {self.__table_name_b}"
f" on {self.__table_name_a}.{ID}={self.__table_name_b}.{ID}")
for row_index, row in enumerate(cursor.fetchall()):
row_dict = {cursor.description[i][0]: value for i, value in enumerate(row)}
yield row_index, row_dict
19 changes: 0 additions & 19 deletions arelight/arekit/parse_predict.py

This file was deleted.

8 changes: 8 additions & 0 deletions arelight/arekit/parsed_row_service.py
@@ -17,6 +17,14 @@ class ParsedSampleRowExtraService(object):
obj_id=parsed_row[const.T_IND],
obj_ids=parsed_row[const.ENTITIES],
obj_values=parsed_row[const.ENTITY_VALUES]),
"SourceType": lambda parsed_row: ParsedSampleRowExtraService.calc_obj_value(
obj_id=parsed_row[const.S_IND],
obj_ids=parsed_row[const.ENTITIES],
obj_values=parsed_row[const.ENTITY_TYPES]),
"TargetType": lambda parsed_row: ParsedSampleRowExtraService.calc_obj_value(
obj_id=parsed_row[const.T_IND],
obj_ids=parsed_row[const.ENTITIES],
obj_values=parsed_row[const.ENTITY_TYPES]),
}

@staticmethod
68 changes: 68 additions & 0 deletions arelight/arekit/sample_service.py
@@ -0,0 +1,68 @@
from arekit.common.data.rows_fmt import create_base_column_fmt
from arekit.common.data.rows_parser import ParsedSampleRow
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.common.experiment.data_type import DataType
from arekit.common.labels.base import NoLabel

from arelight.arekit.custom_sqlite_reader import CustomSQliteReader
from arelight.arekit.joined_sqlite import JoinedSQliteBasedRowsStorage
from arelight.arekit.parsed_row_service import ParsedSampleRowExtraService
from arelight.arekit.samples_io import CustomSamplesIO
from arelight.pipelines.demo.labels.base import PositiveLabel, NegativeLabel
from arelight.pipelines.demo.labels.formatter import CustomLabelsFormatter


class AREkitSamplesService(object):

@staticmethod
def _extract_label_from_row(parsed_row):
if parsed_row["col_0"] == 1:
return NoLabel()
elif parsed_row["col_1"] == 1:
return PositiveLabel()
elif parsed_row["col_2"] == 1:
return NegativeLabel()

@staticmethod
def iter_samples_and_predict_sqlite3(sqlite_filepath, samples_table_name, predict_table_name,
filter_record_func=None):
assert(callable(filter_record_func) or filter_record_func is None)

samples_io = CustomSamplesIO(
create_target_func=lambda _: sqlite_filepath,
reader=CustomSQliteReader(table_name_a=samples_table_name, table_name_b=predict_table_name,
storage_type=JoinedSQliteBasedRowsStorage))

samples_filepath = samples_io.create_target(data_type=DataType.Test)
samples = samples_io.Reader.read(samples_filepath)
assert (isinstance(samples, BaseRowsStorage))

column_fmts = [
# default parameters.
create_base_column_fmt(fmt_type="parser"),
# sentiment score.
{"col_0": lambda value: int(value), "col_1": lambda value: int(value), "col_2": lambda value: int(value)}
]

labels_fmt = CustomLabelsFormatter()

for ind, sample_row in samples:

parsed_row = ParsedSampleRow(sample_row, columns_fmts=column_fmts, no_value_func=lambda: None)

# reading label.

record = {
"filename": sample_row["doc_id"].split(':')[0],
"text": sample_row["text_a"].split(),
"s_val": ParsedSampleRowExtraService.calc("SourceValue", parsed_row=parsed_row),
"t_val": ParsedSampleRowExtraService.calc("TargetValue", parsed_row=parsed_row),
"s_type": ParsedSampleRowExtraService.calc("SourceType", parsed_row=parsed_row),
"t_type": ParsedSampleRowExtraService.calc("TargetType", parsed_row=parsed_row),
"label": labels_fmt.label_to_str(AREkitSamplesService._extract_label_from_row(parsed_row))
}

if filter_record_func is None:
yield record
elif filter_record_func(record):
yield record
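
A hedged usage sketch of the service above; the database path and table names below are hypothetical placeholders (the `.sqlite` extension is appended by `CustomSQliteReader` when the read target is composed):

```python
# Hypothetical usage sketch; the path and table names are placeholders, not taken from this commit.
from arelight.arekit.sample_service import AREkitSamplesService

records_it = AREkitSamplesService.iter_samples_and_predict_sqlite3(
    sqlite_filepath="output/collection",      # ".sqlite" is appended when the read target is composed
    samples_table_name="contents",
    predict_table_name="open_nre_bert",
    # Optional predicate over the assembled record dict.
    filter_record_func=lambda record: record["s_type"] == record["t_type"])

for record in records_it:
    print(record["filename"], record["s_val"], record["t_val"], record["label"])
```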
38 changes: 38 additions & 0 deletions arelight/arekit/samples_io.py
@@ -0,0 +1,38 @@
from arekit.common.experiment.api.base_samples_io import BaseSamplesIO
from arekit.contrib.utils.data.writers.base import BaseWriter

from arelight.readers.base import BaseReader


class CustomSamplesIO(BaseSamplesIO):
""" Samples default IO utils for samples.
Sample is a text part which include pair of attitude participants.
This class allows to provide saver and loader for such entries, bubbed as samples.
Samples required for machine learning training/inferring.
"""

def __init__(self, create_target_func, writer=None, reader=None):
assert(isinstance(writer, BaseWriter) or writer is None)
assert(isinstance(reader, BaseReader) or reader is None)
assert(callable(create_target_func))

self.__writer = writer
self.__reader = reader
self.__create_target_func = create_target_func

self.__target_extension = None
if writer is not None:
self.__target_extension = writer.extension()
elif reader is not None:
self.__target_extension = reader.extension()

@property
def Reader(self):
return self.__reader

@property
def Writer(self):
return self.__writer

def create_target(self, data_type):
return self.__create_target_func(data_type) + self.__target_extension
9 changes: 9 additions & 0 deletions arelight/arekit/utils_translator.py
@@ -0,0 +1,9 @@
def string_terms_to_list(terms):
r = []
for t in terms:
if isinstance(t, str):
for i in t.split(' '):
r.append(i)
else:
r.append(t)
return r
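
A short behaviour sketch for `string_terms_to_list` (toy inputs): whitespace-separated strings are split into individual terms, while non-string objects pass through unchanged.

```python
# Toy illustration of string_terms_to_list.
entity = object()  # stand-in for a non-string term, e.g. an entity object
print(string_terms_to_list(["New York", entity, "city"]))
# -> ['New', 'York', <object object at 0x...>, 'city']
```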
4 changes: 2 additions & 2 deletions arelight/backend/d3js/relations_graph_builder.py
@@ -1,7 +1,7 @@
from collections import Counter


def make_graph_from_relations_array(relations, entity_values, entity_types, min_links, weights=True):
def make_graph_from_relations_array(graph_name, relations, entity_values, entity_types, min_links, weights=True):
""" This is a method composes a dictionary from the relations data between entities.
(C) Maxim Kolomeets (Originally)
@@ -88,4 +88,4 @@ def __get_type(v):
node_max = max(used_nodes.values()) if used_nodes else 0
nodes = [{"id": id, "c": used_nodes[id]/node_max if weights else 1} for id in used_nodes]

return {"nodes": nodes, "links": links}
return {"basis": [graph_name], "equation": "["+graph_name+"]", "nodes": nodes, "links": links}
19 changes: 16 additions & 3 deletions arelight/backend/d3js/relations_graph_operations.py
@@ -1,10 +1,18 @@
import logging
import warnings


OP_UNION = "UNION"
OP_INTERSECTION = "INTERSECTION"
OP_DIFFERENCE = "DIFFERENCE"

OPERATION_MAP = {}
OPERATION_MAP[OP_UNION] = "+"
OPERATION_MAP[OP_INTERSECTION] = "∩"
OPERATION_MAP[OP_DIFFERENCE] = "-"

logger = logging.getLogger(__name__)


def graphs_operations(graph_A, graph_B, operation=OP_UNION, weights=True):
"""
@@ -22,7 +30,7 @@ def graphs_operations(graph_A, graph_B, operation=OP_UNION, weights=True):
dict: The resulting graph after performing the operation.
"""

print(f"\nPerforming {operation} on graphs...")
logger.info(f"\nPerforming {operation} on graphs...")

def link_key(link):
"""Generate a key for a link."""
@@ -59,7 +67,7 @@ def link_key(link):
if operation == OP_DIFFERENCE:
for l, c in links_A.items():
if l in links_B and c - links_B[l] > 0:
print(" ", l, c, "=>", l, links_B[l])
logger.info(" ", l, c, "=>", l, links_B[l])
links_[l] = c - links_B[l]
if l not in links_B:
links_[l] = c
@@ -85,7 +93,12 @@ def link_key(link):
used_nodes[t] = used_nodes.get(t, 0) + c

nodes = [{"id": id, "c": c} for id, c in used_nodes.items()]
result_graph = {"nodes": nodes, "links": links}
if operation == OP_DIFFERENCE:
basis = list(set(graph_A["basis"]).difference(graph_B["basis"]))
else:
basis = list(set(graph_A["basis"]).union(graph_B["basis"]))
equation = "(" + graph_A["equation"] + ")" + OPERATION_MAP[operation] + "(" + graph_B["equation"] + ")"
result_graph = {"basis": basis, "equation": equation, "nodes": nodes, "links": links}

# Assign weights if not used.
if not weights:
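
A small sketch of the basis/equation bookkeeping introduced above (toy graphs with empty node and link lists, since only the new fields are demonstrated):

```python
# Toy graphs: only the fields relevant to the new bookkeeping are filled in.
graph_A = {"basis": ["graph_2023"], "equation": "[graph_2023]", "nodes": [], "links": []}
graph_B = {"basis": ["graph_2024"], "equation": "[graph_2024]", "nodes": [], "links": []}

# For UNION and INTERSECTION the basis is the set union; for DIFFERENCE it is the set difference.
basis = list(set(graph_A["basis"]).union(graph_B["basis"]))

# The equation string records both operands together with the operation symbol.
equation = "(" + graph_A["equation"] + ")" + OPERATION_MAP[OP_UNION] + "(" + graph_B["equation"] + ")"
# equation == "([graph_2023])+([graph_2024])"
```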
561 changes: 0 additions & 561 deletions arelight/backend/d3js/ui_web.py

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions arelight/backend/d3js/utils_graph.py
@@ -36,7 +36,9 @@ def save_graph(graph, out_dir, out_filename, convert_to_radial=True):
if not exists(out_dir):
makedirs(out_dir)

data_filepath = join(out_dir, "{}.json".format(out_filename))
# Make sure that we have no extension related to the expected format.
no_ext_basename = out_filename.replace(".json", "")
target_filepath = join(out_dir, f"{no_ext_basename}.json")
# Convert to radial graph.
radial_graph = graph_to_radial(graph) if convert_to_radial else graph
save_json(data=radial_graph, file_path=data_filepath)
save_json(data=radial_graph, file_path=target_filepath)
Empty file added arelight/data/__init__.py
Empty file.
Empty file.
68 changes: 68 additions & 0 deletions arelight/data/repositories/base.py
@@ -0,0 +1,68 @@
from arekit.common.data.input.providers.columns.base import BaseColumnsProvider
from arekit.common.data.input.providers.contents import ContentsProvider
from arekit.common.data.input.providers.rows.base import BaseRowProvider
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.contrib.utils.data.storages.row_cache import RowCacheStorage
from arekit.contrib.utils.data.writers.base import BaseWriter


class BaseInputRepository(object):

def __init__(self, columns_provider, rows_provider, storage):
assert(isinstance(columns_provider, BaseColumnsProvider))
assert(isinstance(rows_provider, BaseRowProvider))
assert(isinstance(storage, BaseRowsStorage))

self._columns_provider = columns_provider
self._rows_provider = rows_provider
self._storage = storage

# Do setup operations.
self._setup_columns_provider()
self._setup_rows_provider()

# region protected methods

def _setup_columns_provider(self):
pass

def _setup_rows_provider(self):
pass

# endregion

def populate(self, contents_provider, doc_ids, desc="", writer=None, target=None):
assert(isinstance(contents_provider, ContentsProvider))
assert(isinstance(self._storage, BaseRowsStorage))
assert(isinstance(doc_ids, list))
assert(isinstance(writer, BaseWriter) or writer is None)
assert(isinstance(target, str) or target is None)

def iter_rows(idle_mode):
return self._rows_provider.iter_by_rows(
contents_provider=contents_provider,
doc_ids_iter=doc_ids,
idle_mode=idle_mode)

self._storage.init_empty(columns_provider=self._columns_provider)

is_async_write_mode_on = writer is not None and target is not None

if is_async_write_mode_on:
writer.open_target(target)

self._storage.fill(lambda idle_mode: iter_rows(idle_mode),
columns_provider=self._columns_provider,
row_handler=lambda: writer.commit_line(self._storage) if is_async_write_mode_on else None,
desc=desc)

if is_async_write_mode_on:
writer.close_target()

def push(self, writer, target, free_storage=True):
if not isinstance(self._storage, RowCacheStorage):
writer.write_all(self._storage, target)

# After writing we free the contents of the storage.
if free_storage:
self._storage.free()
23 changes: 23 additions & 0 deletions arelight/data/repositories/sample.py
@@ -0,0 +1,23 @@
import logging

from arekit.common.data.input.providers.rows.samples import BaseSampleRowProvider

from arelight.data.repositories.base import BaseInputRepository

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class BaseInputSamplesRepository(BaseInputRepository):

def _setup_rows_provider(self):
""" Setup store labels.
"""
assert(isinstance(self._rows_provider, BaseSampleRowProvider))
self._rows_provider.set_store_labels(self._columns_provider.StoreLabels)

def _setup_columns_provider(self):
""" Setup text column names.
"""
text_column_names = list(self._rows_provider.TextProvider.iter_columns())
self._columns_provider.set_text_column_names(text_column_names)
95 changes: 95 additions & 0 deletions arelight/data/serializer_base.py
@@ -0,0 +1,95 @@
from arekit.common.data.input.providers.rows.samples import BaseSampleRowProvider
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.common.experiment.api.base_samples_io import BaseSamplesIO
from arekit.common.experiment.data_type import DataType
from arekit.common.pipeline.items.base import BasePipelineItem

from arelight.data.serializer_helper import InputDataSerializationHelper


class BaseSerializerPipelineItem(BasePipelineItem):

def __init__(self, rows_provider, samples_io, save_labels_func, storage, **kwargs):
""" sample_rows_formatter:
how we format input texts for a BERT model, for example:
- single text
- two sequences, separated by [SEP] token
save_labels_func: function
data_type -> bool
"""
assert(isinstance(rows_provider, BaseSampleRowProvider))
assert(isinstance(samples_io, BaseSamplesIO))
assert(callable(save_labels_func))
assert(isinstance(storage, BaseRowsStorage))
super(BaseSerializerPipelineItem, self).__init__(**kwargs)

self._rows_provider = rows_provider
self._samples_io = samples_io
self._save_labels_func = save_labels_func
self._storage = storage

def _serialize_iteration(self, data_type, pipeline, data_folding, doc_ids):
assert(isinstance(data_type, DataType))
assert(isinstance(pipeline, list))
assert(isinstance(data_folding, dict) or data_folding is None)
assert(isinstance(doc_ids, list) or doc_ids is None)
assert(doc_ids is not None or data_folding is not None)

repos = {
"sample": InputDataSerializationHelper.create_samples_repo(
keep_labels=self._save_labels_func(data_type),
rows_provider=self._rows_provider,
storage=self._storage),
}

writer_and_targets = {
"sample": (self._samples_io.Writer,
self._samples_io.create_target(data_type=data_type)),
}

for description, repo in repos.items():

if data_folding is None:
# Consider only the predefined doc_ids.
doc_ids_iter = doc_ids
else:
# Take particular data_type.
doc_ids_iter = data_folding[data_type]
# Consider only predefined doc_ids.
if doc_ids is not None:
doc_ids_iter = set(doc_ids_iter).intersection(doc_ids)

InputDataSerializationHelper.fill_and_write(
repo=repo,
pipeline=pipeline,
doc_ids_iter=doc_ids_iter,
desc="{desc} [{data_type}]".format(desc=description, data_type=data_type),
writer=writer_and_targets[description][0],
target=writer_and_targets[description][1])

def _handle_iteration(self, data_type_pipelines, data_folding, doc_ids):
""" Performing data serialization for a particular iteration
"""
assert(isinstance(data_type_pipelines, dict))
for data_type, pipeline in data_type_pipelines.items():
self._serialize_iteration(data_type=data_type, pipeline=pipeline, data_folding=data_folding,
doc_ids=doc_ids)

def apply_core(self, input_data, pipeline_ctx):
"""
data_type_pipelines: dict of, for example:
{
DataType.Train: BasePipeline,
DataType.Test: BasePipeline
}
data_type_pipelines: doc_id -> parsed_doc -> annot -> opinion linkages
for example, function: sentiment_attitude_extraction_default_pipeline
doc_ids: optional
this parameter allows limiting the number of documents considered for sampling
"""
assert("data_type_pipelines" in pipeline_ctx)
self._handle_iteration(data_type_pipelines=pipeline_ctx.provide("data_type_pipelines"),
doc_ids=pipeline_ctx.provide_or_none("doc_ids"),
data_folding=pipeline_ctx.provide_or_none("data_folding"))
43 changes: 43 additions & 0 deletions arelight/data/serializer_helper.py
@@ -0,0 +1,43 @@
import logging

from collections.abc import Iterable

from arekit.common.data.input.providers.columns.sample import SampleColumnsProvider
from arekit.common.data.input.providers.rows.base import BaseRowProvider
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.contrib.utils.data.contents.opinions import InputTextOpinionProvider

from arelight.data.repositories.base import BaseInputRepository
from arelight.data.repositories.sample import BaseInputSamplesRepository

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class InputDataSerializationHelper(object):

@staticmethod
def create_samples_repo(keep_labels, rows_provider, storage):
assert(isinstance(rows_provider, BaseRowProvider))
assert(isinstance(keep_labels, bool))
assert(isinstance(storage, BaseRowsStorage))
return BaseInputSamplesRepository(
columns_provider=SampleColumnsProvider(store_labels=keep_labels),
rows_provider=rows_provider,
storage=storage)

@staticmethod
def fill_and_write(pipeline, repo, target, writer, doc_ids_iter, desc=""):
assert(isinstance(pipeline, list))
assert(isinstance(doc_ids_iter, Iterable))
assert(isinstance(repo, BaseInputRepository))

doc_ids = list(doc_ids_iter)

repo.populate(contents_provider=InputTextOpinionProvider(pipeline),
doc_ids=doc_ids,
desc=desc,
writer=writer,
target=target)

repo.push(writer=writer, target=target)
Empty file.
63 changes: 63 additions & 0 deletions arelight/data/writers/csv_native.py
@@ -0,0 +1,63 @@
import csv
import os
from os.path import dirname

from arekit.common.data.storages.base import BaseRowsStorage
from arekit.contrib.utils.data.storages.row_cache import RowCacheStorage
from arekit.contrib.utils.data.writers.base import BaseWriter


class NativeCsvWriter(BaseWriter):

def __init__(self, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL, header=True):
self.__target_f = None
self.__writer = None
self.__create_writer_func = lambda f: csv.writer(
f, delimiter=delimiter, quotechar=quotechar, quoting=quoting)
self.__header = header
self.__header_written = None

def extension(self):
return ".csv"

@staticmethod
def __iter_storage_column_names(storage):
""" Iter only those columns that existed in storage.
"""
for col_name in storage.iter_column_names():
if col_name in storage.RowCache:
yield col_name

def open_target(self, target):
os.makedirs(dirname(target), exist_ok=True)
self.__target_f = open(target, "w")
self.__writer = self.__create_writer_func(self.__target_f)
self.__header_written = not self.__header

def close_target(self):
self.__target_f.close()

def commit_line(self, storage):
assert(isinstance(storage, RowCacheStorage))
assert(self.__writer is not None)

if not self.__header_written:
self.__writer.writerow(list(self.__iter_storage_column_names(storage)))
self.__header_written = True

line_data = list(map(lambda col_name: storage.RowCache[col_name],
self.__iter_storage_column_names(storage)))
self.__writer.writerow(line_data)

def write_all(self, storage, target):
""" Writes all the `storage` rows
into the `target` filepath, formatted as CSV.
"""
assert(isinstance(storage, BaseRowsStorage))

with open(target, "w") as f:
writer = self.__create_writer_func(f)
for _, row in storage:
#content = [row[col_name] for col_name in storage.iter_column_names()]
content = [v for v in row]
writer.writerow(content)
40 changes: 40 additions & 0 deletions arelight/data/writers/csv_pd.py
@@ -0,0 +1,40 @@
import logging

from arekit.common.data.input.providers.columns.base import BaseColumnsProvider
from arekit.common.utils import create_dir_if_not_exists
from arekit.contrib.utils.data.storages.pandas_based import PandasBasedRowsStorage
from arekit.contrib.utils.data.writers.base import BaseWriter

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class PandasCsvWriter(BaseWriter):

def __init__(self, write_header):
super(PandasCsvWriter, self).__init__()
self.__write_header = write_header

def extension(self):
return ".tsv.gz"

def write_all(self, storage, target):
assert(isinstance(storage, PandasBasedRowsStorage))
assert(isinstance(target, str))

create_dir_if_not_exists(target)

# Temporary hack, remove it in future.
df = storage.DataFrame

logger.info("Saving... {length}: {filepath}".format(length=len(storage), filepath=target))
df.to_csv(target,
sep='\t',
encoding='utf-8',
columns=[c for c in df.columns if c != BaseColumnsProvider.ROW_ID],
index=False,
float_format="%.0f",
compression='gzip',
header=self.__write_header)

logger.info("Saving completed!")
132 changes: 132 additions & 0 deletions arelight/data/writers/json_opennre.py
@@ -0,0 +1,132 @@
import json
import logging
import os
from os.path import dirname

from arekit.common.data import const
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.contrib.utils.data.storages.row_cache import RowCacheStorage
from arekit.contrib.utils.data.writers.base import BaseWriter

logger = logging.getLogger(__name__)


class OpenNREJsonWriter(BaseWriter):
""" This is a bag-based writer for the samples.
Project page: https://github.com/thunlp/OpenNRE
Every bag is presented as follows:
{
'text' or 'token': ...,
'h': {'pos': [start, end], 'id': ... },
't': {'pos': [start, end], 'id': ... }
'id': "id_of_the_text_opinion"
}
In terms of the linked opinions (i0, i1, etc.), we consider the id of the first opinion in the linkage.
During the dataset reading stage via OpenNRE, these linkages are automatically grouped into bags.
"""

def __init__(self, text_columns, encoding="utf-8", na_value="NA", keep_extra_columns=True,
skip_extra_existed=True):
""" text_columns: list
column names that are expected to be joined into a single (token) column.
"""
assert(isinstance(text_columns, list))
assert(isinstance(encoding, str))
self.__text_columns = text_columns
self.__encoding = encoding
self.__target_f = None
self.__keep_extra_columns = keep_extra_columns
self.__na_value = na_value
self.__skip_extra_existed = skip_extra_existed

def extension(self):
return ".jsonl"

@staticmethod
def __format_row(row, na_value, text_columns, keep_extra_columns, skip_extra_existed):
""" Formatting that is compatible with the OpenNRE.
"""
assert(isinstance(na_value, str))

sample_id = row[const.ID]
s_ind = int(row[const.S_IND])
t_ind = int(row[const.T_IND])
bag_id = str(row[const.OPINION_ID])

# Gather tokens.
tokens = []
for text_col in text_columns:
if text_col in row:
tokens.extend(row[text_col].split())

# Filtering JSON row.
formatted_data = {
"id": bag_id,
"id_orig": sample_id,
"token": tokens,
"h": {"pos": [s_ind, s_ind + 1], "id": str(bag_id + "s")},
"t": {"pos": [t_ind, t_ind + 1], "id": str(bag_id + "t")},
"relation": str(int(row[const.LABEL_UINT])) if const.LABEL_UINT in row else na_value
}

# Register extra fields (optionally).
if keep_extra_columns:
for key, value in row.items():
if key not in formatted_data and key not in text_columns:
formatted_data[key] = value
else:
if not skip_extra_existed:
raise Exception(f"key `{key}` is already exist in formatted data "
f"or a part of the text columns list: {text_columns}")

return formatted_data

def open_target(self, target):
os.makedirs(dirname(target), exist_ok=True)
self.__target_f = open(target, "w")
pass

def close_target(self):
self.__target_f.close()

def commit_line(self, storage):
assert(isinstance(storage, RowCacheStorage))

# Collect the existing columns.
row_data = {}
for col_name in storage.iter_column_names():
if col_name not in storage.RowCache:
continue
row_data[col_name] = storage.RowCache[col_name]

bag = self.__format_row(row_data, text_columns=self.__text_columns,
keep_extra_columns=self.__keep_extra_columns,
na_value=self.__na_value,
skip_extra_existed=self.__skip_extra_existed)

self.__write_bag(bag=bag, json_file=self.__target_f)

@staticmethod
def __write_bag(bag, json_file):
assert(isinstance(bag, dict))
json.dump(bag, json_file, separators=(",", ":"), ensure_ascii=False)
json_file.write("\n")

def write_all(self, storage, target):
assert(isinstance(storage, BaseRowsStorage))
assert(isinstance(target, str))

logger.info("Saving... {rows}: {filepath}".format(rows=(len(storage)), filepath=target))

os.makedirs(os.path.dirname(target), exist_ok=True)
with open(target, "w", encoding=self.__encoding) as json_file:
for row_index, row in storage:
self.__write_bag(bag=self.__format_row(row, text_columns=self.__text_columns,
keep_extra_columns=self.__keep_extra_columns,
na_value=self.__na_value,
skip_extra_existed=self.__skip_extra_existed),
json_file=json_file)

logger.info("Saving completed!")
114 changes: 114 additions & 0 deletions arelight/data/writers/sqlite_native.py
@@ -0,0 +1,114 @@
import os
import sqlite3
from os.path import dirname

from arekit.common.data import const
from arekit.contrib.utils.data.storages.row_cache import RowCacheStorage
from arekit.contrib.utils.data.writers.base import BaseWriter


class SQliteWriter(BaseWriter):
""" TODO. This implementation is dedicated for the writing concepts of the data
serialization pipeline. However we add the SQLite3 service, it would be
right to refactor and utlize some core functionality from the core/service/sqlite.py
"""

def __init__(self, table_name="contents", index_column_names=None, skip_existed=False, clear_table=True):
""" index_column_names: list or None
column names that should be considered for building a unique index;
if None, the default 'const.ID' will be considered for row indexation.
"""
assert (isinstance(index_column_names, list) or index_column_names is None)
self.__index_column_names = index_column_names if index_column_names is not None else [const.ID]
self.__table_name = table_name
self.__conn = None
self.__cur = None
self.__need_init_table = True
self.__origin_column_names = None
self.__skip_existed = skip_existed
self.__clear_table = clear_table

def extension(self):
return ".sqlite"

@staticmethod
def __iter_storage_column_names(storage):
""" Iter only those columns that existed in storage.
"""
assert (isinstance(storage, RowCacheStorage))
for col_name, col_type in zip(storage.iter_column_names(), storage.iter_column_types()):
if col_name in storage.RowCache:
yield col_name, col_type

def __init_table(self, column_data):
# Compose column name with the related SQLITE type.
column_types = ",".join([" ".join([col_name, self.type_to_sqlite(col_type)])
for col_name, col_type in column_data])
# Create table if not exists.
self.__cur.execute(f"CREATE TABLE IF NOT EXISTS {self.__table_name}({column_types})")
# Table exists, however we may optionally remove the content from it.
if self.__clear_table:
self.__cur.execute(f"DELETE FROM {self.__table_name};")
# Create index.
index_name = f"i_{self.__table_name}_id"
self.__cur.execute(f"DROP INDEX IF EXISTS {index_name};")
self.__cur.execute("CREATE INDEX IF NOT EXISTS {index} ON {table}({columns})".format(
index=index_name,
table=self.__table_name,
columns=", ".join(self.__index_column_names)
))
self.__origin_column_names = [col_name for col_name, _ in column_data]

@staticmethod
def type_to_sqlite(col_type):
""" This is a simple function that provides conversion from the
base numpy types to SQLITE.
NOTE: this method represents a quick implementation for the supported
types; however, it is far from a generalized implementation.
"""
if isinstance(col_type, str):
if 'int' in col_type:
return 'INTEGER'

return "TEXT"

def open_target(self, target):
os.makedirs(dirname(target), exist_ok=True)
self.__conn = sqlite3.connect(target)
self.__cur = self.__conn.cursor()

def commit_line(self, storage):
assert (isinstance(storage, RowCacheStorage))

column_data = list(self.__iter_storage_column_names(storage))

if self.__need_init_table:
self.__init_table(column_data)
self.__need_init_table = False

# Check whether the related row already exists in the SQLite database.
row_id = storage.RowCache[const.ID]
top_row = self.__cur.execute(f"SELECT EXISTS(SELECT 1 FROM {self.__table_name} WHERE id='{row_id}');")
is_exists = top_row.fetchone()[0]
if is_exists == 1 and self.__skip_existed:
return

line_data = [storage.RowCache[col_name] for col_name, _ in column_data]
parameters = ",".join(["?"] * len(line_data))

assert (len(self.__origin_column_names) == len(line_data))

self.__cur.execute(
f"INSERT OR REPLACE INTO {self.__table_name} VALUES ({parameters})",
tuple(line_data))

self.__conn.commit()

def close_target(self):
self.__cur = None
self.__origin_column_names = None
self.__need_init_table = True
self.__conn.close()

def write_all(self, storage, target):
pass
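
As a rough sketch, the statements issued by `SQliteWriter` for a storage whose cached columns were, say, an `id` text column and a `label` integer column would look as follows (column names and types are hypothetical):

```python
# Rough sketch of the SQL issued for hypothetical columns (id: TEXT, label: INTEGER).
statements = [
    "CREATE TABLE IF NOT EXISTS contents(id TEXT,label INTEGER)",    # __init_table
    "DELETE FROM contents;",                                         # when clear_table=True
    "DROP INDEX IF EXISTS i_contents_id;",
    "CREATE INDEX IF NOT EXISTS i_contents_id ON contents(id)",
    "SELECT EXISTS(SELECT 1 FROM contents WHERE id='<row_id>');",    # duplicate check per line
    "INSERT OR REPLACE INTO contents VALUES (?,?)",                  # one per committed line
]
```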
56 changes: 24 additions & 32 deletions arelight/ner/base.py
@@ -11,56 +11,48 @@ class BaseNER(object):
inner_tag = 'I'

def extract(self, sequences):
return self.__extract_objects_core(sequences=sequences)

def __extract_objects_core(self, sequences):
assert(isinstance(sequences, list))
seqs_tags = self._extract_tags(sequences)
assert(len(sequences) == len(seqs_tags))

extracted = []
for sequence_ind, sequence in enumerate(sequences):
seq_tags = seqs_tags[sequence_ind]
objs_len = [len(entry) for entry in self.__merge(sequence, seq_tags)]
objs_type = [self.__tag_type(tag) for tag in seq_tags if self.__tag_part(tag) == self.begin_tag]
objs_positions = [j for j, tag in enumerate(seq_tags) if self.__tag_part(tag) == self.begin_tag]

assert(len(objs_len) == len(objs_type) == len(objs_positions))
terms, labels = self._forward(sequences)
return self.iter_descriptors(terms=terms, labels=labels)

seq_obj_descriptors = [NerObjectDescriptor(pos=objs_positions[i],
length=objs_len[i],
obj_type=objs_type[i])
for i in range(len(objs_len))]
def iter_descriptors(self, terms, labels):
assert(len(terms) == len(labels))
for seq, tags in zip(terms, labels):
objs_len = [len(entry) for entry in self.__iter_merged(seq, tags)]
objs_type = [self.__tag_type(tag) for tag in tags if self.__tag_part(tag) == self.begin_tag]
objs_positions = [j for j, tag in enumerate(tags) if self.__tag_part(tag) == self.begin_tag]

extracted.append(seq_obj_descriptors)
descriptors = [NerObjectDescriptor(pos=objs_positions[i], length=objs_len[i], obj_type=objs_type[i])
for i in range(len(objs_len))]
yield seq, descriptors

return extracted

def _extract_tags(self, seqences):
def _forward(self, seqences):
raise NotImplementedError()

# region private methods

def __merge(self, terms, tags):
merged = []
def __iter_merged(self, terms, tags):
buffer = None
for i, tag in enumerate(tags):
current_tag = self.__tag_part(tag)
if current_tag == self.begin_tag:
merged.append([terms[i]])
elif current_tag == self.inner_tag and len(merged) > 0:
merged[-1].append(terms[i])
return merged
if buffer is not None:
yield buffer
buffer = [terms[i]]
elif current_tag == self.inner_tag and buffer is not None:
buffer.append(terms[i])

if buffer is not None:
yield buffer

@staticmethod
def __tag_part(tag):
assert(isinstance(tag, str))
return tag if BaseNER.separator not in tag \
else tag[:tag.index(BaseNER.separator)]
return tag if BaseNER.separator not in tag else tag[:tag.index(BaseNER.separator)]

@staticmethod
def __tag_type(tag):
assert(isinstance(tag, str))
return "" if BaseNER.separator not in tag \
else tag[tag.index(BaseNER.separator) + 1:]
return "" if BaseNER.separator not in tag else tag[tag.index(BaseNER.separator) + 1:]

# endregion
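
A minimal sketch of how `extract`/`iter_descriptors` groups BIO tags into entity descriptors; it assumes `begin_tag == 'B'` and `separator == '-'` (standard BIO tagging) and that `NerObjectDescriptor` exposes `Position`, `Length` and `ObjectType`, as used elsewhere in this commit. The subclass below is a stand-in, not part of the repository.

```python
# A stand-in NER model that returns fixed BIO tags (assumptions stated in the lead-in above).
class DummyNER(BaseNER):
    def _forward(self, sequences):
        # Pretend the model tagged every sequence the same way.
        return sequences, [["B-PER", "I-PER", "O", "B-LOC"] for _ in sequences]

for terms, descriptors in DummyNER().extract([["John", "Smith", "visited", "Paris"]]):
    for d in descriptors:
        print(d.Position, d.Length, d.ObjectType)
# Expected under the assumptions above: "0 2 PER" and "3 1 LOC".
```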
39 changes: 2 additions & 37 deletions arelight/ner/deep_pavlov.py
@@ -13,40 +13,5 @@ def __init__(self, model_name):

# region Properties

def _extract_tags(self, sequences):
tokens, labels = self.__ner_model(sequences)
gathered_labels_seq = []
for i, sequence in enumerate(sequences):
_, labels = self.__tokens_to_terms(terms=sequence, tokens=tokens[i], labels=labels[i])
gathered_labels_seq.append(self.__gather(labels))
return gathered_labels_seq

@staticmethod
def __tokens_to_terms(terms, tokens, labels):
def __cur_term():
return len(joined_tokens) - 1

assert (len(labels) == len(tokens))

terms_lengths = [len(term) for term in terms]
current_lengths = [0] * len(terms)
joined_tokens = [[]]
joined_labels = [[]]
for i, token in enumerate(tokens):
if current_lengths[__cur_term()] == terms_lengths[__cur_term()]:
joined_tokens.append([])
joined_labels.append([])
joined_tokens[-1].append(token)
joined_labels[-1].append(labels[i])
current_lengths[__cur_term()] += len(token)

return joined_tokens, joined_labels

@staticmethod
def __gather(labels_in_lists):
return [labels[0] if len(labels) == 1 else DeepPavlovNER.__gather_many(labels)
for labels in labels_in_lists]

@staticmethod
def __gather_many(labels):
return 'O'
def _forward(self, sequences):
return self.__ner_model(sequences)
9 changes: 5 additions & 4 deletions arelight/pipelines/data/annot_pairs_nolabel.py
@@ -8,8 +8,8 @@
from arekit.contrib.utils.pipelines.text_opinion.filters.distance_based import DistanceLimitedTextOpinionFilter


def create_neutral_annotation_pipeline(synonyms, dist_in_terms_bound, terms_per_context,
doc_provider, text_parser, dist_in_sentences=0):
def create_neutral_annotation_pipeline(synonyms, dist_in_terms_bound, terms_per_context, batch_size,
doc_provider, text_pipeline, dist_in_sentences=0):

nolabel_annotator = AlgorithmBasedTextOpinionAnnotator(
value_to_group_id_func=lambda value:
@@ -28,13 +28,14 @@ def create_neutral_annotation_pipeline(synonyms, dist_in_terms_bound, terms_per_

annotation_pipeline = text_opinion_extraction_pipeline(
entity_index_func=lambda indexed_entity: indexed_entity.ID,
text_parser=text_parser,
pipeline_items=text_pipeline,
get_doc_by_id_func=doc_provider.by_id,
annotators=[
nolabel_annotator
],
text_opinion_filters=[
DistanceLimitedTextOpinionFilter(terms_per_context)
])
],
batch_size=batch_size)

return annotation_pipeline
8 changes: 2 additions & 6 deletions arelight/pipelines/demo/infer_bert.py
@@ -1,9 +1,9 @@
from arelight.pipelines.items.backend_d3js_operations import D3jsGraphOperationsBackendPipelineItem
from arelight.pipelines.items.inference_writer import InferenceWriterPipelineItem
from arelight.predict_writer_csv import TsvPredictWriter


def demo_infer_texts_bert_pipeline(sampling_engines=None, infer_engines=None, backend_engines=None):
def demo_infer_texts_bert_pipeline(sampling_engines=None, infer_engines=None, backend_engines=None,
inference_writer=None):
assert(isinstance(sampling_engines, dict) or sampling_engines is None)
assert(isinstance(infer_engines, dict) or infer_engines is None)
assert(isinstance(backend_engines, dict) or backend_engines is None)
@@ -16,17 +16,13 @@ def demo_infer_texts_bert_pipeline(sampling_engines=None, infer_engines=None, ba
#####################################################################
# Serialization Items
#####################################################################

if "arekit" in sampling_engines:
from arelight.pipelines.items.serializer_arekit import AREkitSerializerPipelineItem
pipeline += [AREkitSerializerPipelineItem(**sampling_engines["arekit"])]

#####################################################################
# Inference Items
#####################################################################

inference_writer = TsvPredictWriter()

if "opennre" in infer_engines:
from arelight.pipelines.items.inference_bert_opennre import BertOpenNREInferencePipelineItem
pipeline += [BertOpenNREInferencePipelineItem(**infer_engines["opennre"]),
3 changes: 0 additions & 3 deletions arelight/pipelines/demo/result.py
@@ -1,5 +1,4 @@
from arekit.common.pipeline.context import PipelineContext
from arekit.contrib.utils.data.readers.csv_pd import PandasCsvReader

from arelight.run.utils import merge_dictionaries

@@ -16,8 +15,6 @@ def __init__(self, extra_params=None):
# Inference stage -------------------------
"iter_infer": None,
"iter_total": None,
# Inference stage -------------------------
"predict_reader": PandasCsvReader(compression='infer'), # The way we can read the predicted results.
}

super(PipelineResult, self).__init__(
32 changes: 17 additions & 15 deletions arelight/pipelines/items/backend_d3js_graphs.py
@@ -7,20 +7,19 @@
from arekit.common.experiment.data_type import DataType
from arekit.common.labels.scaler.base import BaseLabelScaler
from arekit.common.labels.str_fmt import StringLabelsFormatter
from arekit.common.pipeline.context import PipelineContext
from arekit.common.pipeline.items.base import BasePipelineItem

from arelight.arekit.parse_predict import iter_predicted_labels
from arelight.arekit.parsed_row_service import ParsedSampleRowExtraService
from arelight.backend.d3js.relations_graph_builder import make_graph_from_relations_array

from arelight.predict.provider import BasePredictProvider

logger = logging.getLogger(__name__)


class D3jsGraphsBackendPipelineItem(BasePipelineItem):

def __init__(self, graph_min_links=0.01, graph_a_labels=None, weights=True):
def __init__(self, graph_min_links=0.01, graph_a_labels=None, weights=True, **kwargs):
super(D3jsGraphsBackendPipelineItem, self).__init__(**kwargs)
self.__graph_min_links = graph_min_links

# Setup filters for the A and B graphs for further operations application.
@@ -58,28 +57,32 @@ def iter_column_value(self, samples, column_value):
yield parsed_row[column_value]

def apply_core(self, input_data, pipeline_ctx):
assert(isinstance(input_data, PipelineContext))

predict_filepath = input_data.provide("predict_filepath")
result_reader = input_data.provide("predict_reader")
labels_fmt = input_data.provide("labels_formatter")
collection_name = pipeline_ctx.provide("d3js_collection_name")
predict_filepath = pipeline_ctx.provide("predict_filepath")
result_reader = pipeline_ctx.provide("predict_reader")
labels_fmt = pipeline_ctx.provide("labels_formatter")
assert(isinstance(labels_fmt, StringLabelsFormatter))
labels_scaler = input_data.provide("labels_scaler")
labels_scaler = pipeline_ctx.provide("labels_scaler")
assert(isinstance(labels_scaler, BaseLabelScaler))
predict_storage = result_reader.read(predict_filepath)
assert(isinstance(predict_storage, BaseRowsStorage))

# Reading samples.
samples_io = input_data.provide("samples_io")
samples_io = pipeline_ctx.provide("samples_io")
samples_filepath = samples_io.create_target(data_type=DataType.Test)
samples = samples_io.Reader.read(samples_filepath)

# Reading labels.
labels_to_str = {str(labels_scaler.label_to_uint(label)): labels_fmt.label_to_str(label)
for label in labels_scaler.ordered_suppoted_labels()}
labels = list(iter_predicted_labels(predict_data=predict_storage, label_to_str=labels_to_str, keep_ind=False))
uint_labels_iter = BasePredictProvider.iter_from_storage(
predict_data=predict_storage,
uint_labels=[labels_scaler.label_to_uint(label) for label in labels_scaler.ordered_suppoted_labels()],
keep_ind=False)

labels = list(map(lambda item: labels_fmt.label_to_str(labels_scaler.uint_to_label(item)), uint_labels_iter))

graph = make_graph_from_relations_array(
graph_name=collection_name,
relations=self.__iter_relations(samples=samples,
labels=labels,
labels_filter_func=self.__graph_label_filter,
@@ -90,5 +93,4 @@ def apply_core(self, input_data, pipeline_ctx):
weights=self.__graph_weights)

# Saving graph as the collection name for it.
input_data.update("d3js_graph_a", value=graph)
input_data.update("d3js_collection_name", value=samples_io.Prefix)
pipeline_ctx.update("d3js_graph_a", value=graph)
42 changes: 14 additions & 28 deletions arelight/pipelines/items/backend_d3js_operations.py
@@ -4,11 +4,10 @@

from arekit.common.data.rows_fmt import create_base_column_fmt
from arekit.common.labels.str_fmt import StringLabelsFormatter
from arekit.common.pipeline.context import PipelineContext
from arekit.common.pipeline.items.base import BasePipelineItem

from arelight.backend.d3js.relations_graph_operations import graphs_operations
from arelight.backend.d3js.ui_web import save_demo_page, iter_ui_backend_folders, GRAPH_TYPE_RADIAL
from arelight.backend.d3js.ui_web import iter_ui_backend_folders, GRAPH_TYPE_RADIAL
from arelight.backend.d3js.utils_graph import save_graph


@@ -17,21 +16,20 @@

class D3jsGraphOperationsBackendPipelineItem(BasePipelineItem):

def __init__(self):
def __init__(self, **kwargs):
# Parameters for sampler.
super(D3jsGraphOperationsBackendPipelineItem, self).__init__(**kwargs)
self.__column_fmts = [create_base_column_fmt(fmt_type="parser")]

def apply_core(self, input_data, pipeline_ctx):
assert(isinstance(input_data, PipelineContext))

graph_a = input_data.provide_or_none("d3js_graph_a")
graph_b = input_data.provide_or_none("d3js_graph_b")
op = input_data.provide_or_none("d3js_graph_operations")
weights = input_data.provide_or_none("d3js_graph_weights")
target_dir = input_data.provide("d3js_graph_output_dir")
collection_name = input_data.provide("d3js_collection_name")
labels_fmt = input_data.provide("labels_formatter")
host_port = input_data.provide_or_none("d3js_host")

graph_a = pipeline_ctx.provide_or_none("d3js_graph_a")
graph_b = pipeline_ctx.provide_or_none("d3js_graph_b")
op = pipeline_ctx.provide_or_none("d3js_graph_operations")
weights = pipeline_ctx.provide_or_none("d3js_graph_weights")
target_dir = pipeline_ctx.provide("d3js_graph_output_dir")
collection_name = pipeline_ctx.provide("d3js_collection_name")
labels_fmt = pipeline_ctx.provide("labels_formatter")
assert(isinstance(labels_fmt, StringLabelsFormatter))

graph = graphs_operations(graph_A=graph_a, graph_B=graph_b, operation=op, weights=weights) \
@@ -45,19 +43,7 @@ def apply_core(self, input_data, pipeline_ctx):
out_filename=f"{collection_name}",
convert_to_radial=True if graph_type == GRAPH_TYPE_RADIAL else False)

# Save Graph description.
save_demo_page(target_dir=target_dir,
collection_name=collection_name,
host_root_path=f"http://localhost:{host_port}/" if host_port is not None else "./",
desc_name=input_data.provide_or_none("d3js_collection_description"),
desc_labels={label_type.__name__: labels_fmt.label_to_str(label_type())
for label_type in labels_fmt._stol.values()})

print(f"\nDataset is completed and saved in the following locations:")
logger.info(f"\n")
logger.info(f"Dataset is completed and saved in the following locations:")
for subfolder in iter_ui_backend_folders(keep_desc=True, keep_graph=True):
print(f"- {os.path.join(target_dir, subfolder, collection_name)}")

# Print system info.
if host_port is not None:
cmd = f"cd {target_dir} && python -m http.server {host_port}"
print(f"To host, launch manually: {cmd}")
logger.info(f"- {os.path.join(target_dir, subfolder, collection_name)}")
4 changes: 2 additions & 2 deletions arelight/pipelines/items/entities_default.py
@@ -5,10 +5,10 @@

class TextEntitiesParser(BasePipelineItem):

def __init__(self, id_assigner, display_value_func=None):
def __init__(self, id_assigner, display_value_func=None, **kwargs):
assert(isinstance(id_assigner, IdAssigner))
assert(callable(display_value_func) or display_value_func is None)
super(TextEntitiesParser, self).__init__()
super(TextEntitiesParser, self).__init__(**kwargs)
self.__id_assigner = id_assigner
self.__disp_value_func = display_value_func

99 changes: 63 additions & 36 deletions arelight/pipelines/items/entities_ner_dp.py
@@ -1,70 +1,97 @@
from arekit.common.bound import Bound
from arekit.common.docs.objects_parser import SentenceObjectsParserPipelineItem
from arekit.common.text.partitioning.terms import TermsPartitioning
from arekit.common.pipeline.items.base import BasePipelineItem
from arekit.common.pipeline.utils import BatchIterator
from arekit.common.text.partitioning import Partitioning

from arelight.arekit.chunk_it import ChunkIterator
from arelight.ner.deep_pavlov import DeepPavlovNER
from arelight.ner.obj_desc import NerObjectDescriptor
from arelight.pipelines.items.entity import IndexedEntity
from arelight.utils import IdAssigner


class DeepPavlovNERPipelineItem(SentenceObjectsParserPipelineItem):
class DeepPavlovNERPipelineItem(BasePipelineItem):

def __init__(self, id_assigner, ner_model_name, obj_filter=None, chunk_limit=128, display_value_func=None):
def __init__(self, id_assigner, ner_model_name, obj_filter=None,
chunk_limit=128, display_value_func=None, **kwargs):
""" chunk_limit: int
length of the text part, in words, that is going to be provided as input.
"""
assert(callable(obj_filter) or obj_filter is None)
assert(isinstance(chunk_limit, int) and chunk_limit > 0)
assert(isinstance(id_assigner, IdAssigner))
assert(callable(display_value_func) or display_value_func is None)
super(DeepPavlovNERPipelineItem, self).__init__(**kwargs)

# Initialize bert-based model instance.
self.__dp_ner = DeepPavlovNER(ner_model_name)
self.__obj_filter = obj_filter
self.__chunk_limit = chunk_limit
self.__id_assigner = id_assigner
self.__disp_value_func = display_value_func
super(DeepPavlovNERPipelineItem, self).__init__(TermsPartitioning())
self.__partitioning = Partitioning(text_fmt="list")

def _get_parts_provider_func(self, input_data, pipeline_ctx):
return self.__iter_subs_values_with_bounds(input_data)
@property
def SupportBatching(self):
return True

def __iter_subs_values_with_bounds(self, terms_list):
assert(isinstance(terms_list, list))
def __iter_subs_values_with_bounds(self, batch_it):
chunk_offset = 0
handled_text_index = -1
for batch in batch_it:
text_indices, texts = zip(*batch)

for chunk_start in range(0, len(terms_list), self.__chunk_limit):
single_sentence_chunk = [terms_list[chunk_start:chunk_start+self.__chunk_limit]]

# NOTE: in some cases, for example URL links or other long input words,
# the overall behavior might result in exceeding the assumed threshold.
# In order to completely prevent it, we wrap the call
# of the NER module into a try-except block.
try:
processed_sequences = self.__dp_ner.extract(sequences=single_sentence_chunk)
data = self.__dp_ner.extract(sequences=list(texts))
except RuntimeError:
processed_sequences = []

entities_it = self.__iter_parsed_entities(processed_sequences,
chunk_terms_list=single_sentence_chunk[0],
chunk_offset=chunk_start)
data = None

for entity, bound in entities_it:
yield entity, bound
if data is not None:
for i, d in enumerate(data):
terms, descriptors = d
if text_indices[i] != handled_text_index:
chunk_offset = 0
entities_it = self.__iter_parsed_entities(
descriptors=descriptors, terms_list=terms, chunk_offset=chunk_offset)
handled_text_index = text_indices[i]
chunk_offset += len(terms)
yield text_indices[i], terms, list(entities_it)
else:
for i in range(len(batch)):
yield text_indices[i], texts[i], []

def __iter_parsed_entities(self, processed_sequences, chunk_terms_list, chunk_offset):
for p_sequence in processed_sequences:
for s_obj in p_sequence:
assert (isinstance(s_obj, NerObjectDescriptor))
def __iter_parsed_entities(self, descriptors, terms_list, chunk_offset):
for s_obj in descriptors:
assert (isinstance(s_obj, NerObjectDescriptor))

if self.__obj_filter is not None and not self.__obj_filter(s_obj):
continue
if self.__obj_filter is not None and not self.__obj_filter(s_obj):
continue

value = " ".join(chunk_terms_list[s_obj.Position:s_obj.Position + s_obj.Length])
entity = IndexedEntity(
value=value, e_type=s_obj.ObjectType, entity_id=self.__id_assigner.get_id(),
display_value=self.__disp_value_func(value) if self.__disp_value_func is not None else None)
yield entity, Bound(pos=chunk_offset + s_obj.Position, length=s_obj.Length)
value = " ".join(terms_list[s_obj.Position:s_obj.Position + s_obj.Length])
entity = IndexedEntity(
value=value, e_type=s_obj.ObjectType, entity_id=self.__id_assigner.get_id(),
display_value=self.__disp_value_func(value) if self.__disp_value_func is not None else None)
yield entity, Bound(pos=chunk_offset + s_obj.Position, length=s_obj.Length)

def apply_core(self, input_data, pipeline_ctx):
return super(DeepPavlovNERPipelineItem, self).apply_core(input_data=input_data, pipeline_ctx=pipeline_ctx)
assert(isinstance(input_data, list))

batch_size = len(input_data)

c_it = ChunkIterator(iter(input_data), batch_size=batch_size, chunk_limit=self.__chunk_limit)
b_it = BatchIterator(c_it, batch_size=batch_size)

parts_it = self.__iter_subs_values_with_bounds(b_it)

terms = [[] for _ in range(batch_size)]
bounds = [[] for _ in range(batch_size)]
for i, t, e in parts_it:
terms[i].extend(t)
bounds[i].extend(e)

# Compose batch data.
b_data = []
for i in range(batch_size):
b_data.append(self.__partitioning.provide(text=terms[i], parts_it=bounds[i]))

return b_data
31 changes: 23 additions & 8 deletions arelight/pipelines/items/entities_ner_transformers.py
@@ -1,38 +1,40 @@
from arekit.common.bound import Bound
from arekit.common.docs.objects_parser import SentenceObjectsParserPipelineItem
from arekit.common.text.partitioning.str import StringPartitioning
from arekit.common.pipeline.items.base import BasePipelineItem
from arekit.common.text.partitioning import Partitioning
from arekit.common.utils import split_by_whitespaces

from arelight.pipelines.items.entity import IndexedEntity
from arelight.utils import IdAssigner, auto_import


class TransformersNERPipelineItem(SentenceObjectsParserPipelineItem):
class TransformersNERPipelineItem(BasePipelineItem):

def __init__(self, id_assigner, ner_model_name, device, obj_filter=None, display_value_func=None):
def __init__(self, id_assigner, ner_model_name, device, obj_filter=None, display_value_func=None, **kwargs):
""" chunk_limit: int
length of text part in words that is going to be provided in input.
"""
assert(callable(obj_filter) or obj_filter is None)
assert(isinstance(id_assigner, IdAssigner))
assert(callable(display_value_func) or display_value_func is None)
super(TransformersNERPipelineItem, self).__init__(**kwargs)

# Setup third-party modules.
model_init = auto_import("arelight.third_party.transformers.init_token_classification_model")
self.annotate_ner = auto_import("arelight.third_party.transformers.annotate_ner")

# Transformers-related parameters.

self.__device = device
self.__model, self.__tokenizer = model_init(model_path=ner_model_name, device=self.__device)

# Initialize bert-based model instance.
self.__obj_filter = obj_filter
self.__id_assigner = id_assigner
self.__disp_value_func = display_value_func
self.__partitioning = Partitioning(text_fmt="str")

super(TransformersNERPipelineItem, self).__init__(StringPartitioning())
# region Private methods

def _get_parts_provider_func(self, input_data, pipeline_ctx):
def __get_parts_provider_func(self, input_data):
assert(isinstance(input_data, str))
parts = self.annotate_ner(model=self.__model, tokenizer=self.__tokenizer, text=input_data,
device=self.__device)
@@ -56,5 +58,18 @@ def __iter_parsed_entities(self, parts):

yield entity, Bound(pos=p["start"], length=p["end"] - p["start"])

@staticmethod
def __iter_fixed_terms(terms):
for e in terms:
if isinstance(e, str):
for term in split_by_whitespaces(e):
yield term
else:
yield e

# endregion

def apply_core(self, input_data, pipeline_ctx):
return super(TransformersNERPipelineItem, self).apply_core(input_data=input_data, pipeline_ctx=pipeline_ctx)
parts_it = self.__get_parts_provider_func(input_data)
handled = self.__partitioning.provide(text=input_data, parts_it=parts_it)
return list(self.__iter_fixed_terms(handled))
123 changes: 67 additions & 56 deletions arelight/pipelines/items/inference_bert_opennre.py
@@ -1,21 +1,17 @@
import json
import logging
import os
from os.path import exists, join

import logging
import torch

from arekit.common.experiment.data_type import DataType
from arekit.common.pipeline.context import PipelineContext
from arekit.common.pipeline.items.base import BasePipelineItem
from arekit.common.utils import download

from opennre.encoder import BERTEntityEncoder, BERTEncoder
from opennre.framework import SentenceRELoader
from opennre.model import SoftmaxNN

from arelight.utils import get_default_download_dir

from arelight.third_party.torch import sentence_re_loader
from arelight.utils import get_default_download_dir, download

logger = logging.getLogger(__name__)

@@ -24,8 +20,15 @@ class BertOpenNREInferencePipelineItem(BasePipelineItem):

def __init__(self, pretrained_bert=None, checkpoint_path=None, device_type='cpu',
max_seq_length=128, pooler='cls', batch_size=10, tokenizers_parallelism=True,
predefined_ckpts=None):
table_name="contents", task_kwargs=None, predefined_ckpts=None, logger=None,
data_loader_num_workers=0, **kwargs):
"""
NOTE: data_loader_num_workers is set to 0 to cope with the following issue #147:
https://github.com/nicolay-r/ARElight/issues/147
"""
assert(isinstance(tokenizers_parallelism, bool))
super(BertOpenNREInferencePipelineItem, self).__init__(**kwargs)

self.__model = None
self.__pretrained_bert = pretrained_bert
@@ -35,6 +38,10 @@ def __init__(self, pretrained_bert=None, checkpoint_path=None, device_type='cpu'
self.__pooler = pooler
self.__batch_size = batch_size
self.__predefined_ckpts = {} if predefined_ckpts is None else predefined_ckpts
self.__task_kwargs = task_kwargs
self.__table_name = table_name
self.__logger = logger
self.__data_loader_num_workers = data_loader_num_workers

# Huggingface/Tokenizers compatibility.
os.environ['TOKENIZERS_PARALLELISM'] = str(tokenizers_parallelism).lower()
@@ -67,7 +74,7 @@ def scaler_to_rel2id(labels_scaler):
return rel2id

@staticmethod
def try_download_predefined_checkpoint(checkpoint, predefined, dir_to_download):
def try_download_predefined_checkpoint(checkpoint, predefined, dir_to_download, logger=None):
""" This is for the simplicity of using the framework straightaway.
"""
assert (isinstance(checkpoint, str))
@@ -81,21 +88,23 @@ def try_download_predefined_checkpoint(checkpoint, predefined, dir_to_download):
# No need to do anything, file has been already downloaded.
if not exists(target_checkpoint_path):
logger.info("Downloading checkpoint to: {}".format(target_checkpoint_path))
download(dest_file_path=target_checkpoint_path, source_url=data["checkpoint"])
download(dest_file_path=target_checkpoint_path,
source_url=data["checkpoint"],
logger=logger)

return data["state"], target_checkpoint_path, data["label_scaler"]

return None, None, None

@staticmethod
def init_bert_model(pretrain_path, labels_scaler, ckpt_path, device_type, predefined, dir_to_donwload=None,
pooler='cls', max_length=128, mask_entity=True):
def init_bert_model(pretrain_path, labels_scaler, ckpt_path, device_type, predefined, logger=None,
dir_to_donwload=None, pooler='cls', max_length=128, mask_entity=True):
""" This is a main and core method for inference based on OpenNRE framework.
"""
# Check predefined checkpoints for local downloading.
predefined_pretrain_path, predefined_ckpt_path, ckpt_label_scaler = \
BertOpenNREInferencePipelineItem.try_download_predefined_checkpoint(
checkpoint=ckpt_path, dir_to_download=dir_to_donwload, predefined=predefined)
checkpoint=ckpt_path, dir_to_download=dir_to_donwload, predefined=predefined, logger=logger)

# Update checkpoint and pretrain paths with the predefined.
ckpt_path = predefined_ckpt_path if predefined_ckpt_path is not None else ckpt_path
@@ -111,68 +120,69 @@ def init_bert_model(pretrain_path, labels_scaler, ckpt_path, device_type, predef
model.load_state_dict(torch.load(ckpt_path, map_location=torch.device(device_type))['state_dict'])
return model

@staticmethod
def extract_ids(data_file):
with open(data_file) as input_file:
for line_str in input_file.readlines():
data = json.loads(line_str)
yield data["id_orig"]

@staticmethod
def iter_results(parallel_model, eval_loader, data_ids):
l_ind = 0
with torch.no_grad():
for iter, data in enumerate(eval_loader):
if torch.cuda.is_available():
for i in range(len(data)):
try:
data[i] = data[i].cuda()
except:
pass

args = data[1:]
logits = parallel_model(*args)
score, pred = logits.max(-1) # (B)

# Save result
batch_size = pred.size(0)
for i in range(batch_size):
yield data_ids[l_ind], pred[i].item()
l_ind += 1

# It is important we should open database.
with eval_loader.dataset:

l_ind = 0
with torch.no_grad():
for iter, data in enumerate(eval_loader):
if torch.cuda.is_available():
for i in range(len(data)):
try:
data[i] = data[i].cuda()
except:
pass

args = data[1:]
logits = parallel_model(*args)
score, pred = logits.max(-1) # (B)

# Save result
batch_size = pred.size(0)
for i in range(batch_size):
yield data_ids[l_ind], int(pred[i].item())
l_ind += 1

def __iter_predict_result(self, samples_filepath, batch_size):
# Compose evaluator.
sentence_eval = SentenceRELoader(path=samples_filepath,
rel2id=self.__model.rel2id,
tokenizer=self.__model.sentence_encoder.tokenize,
batch_size=batch_size,
shuffle=False)
sentence_eval = sentence_re_loader(path=samples_filepath,
rel2id=self.__model.rel2id,
tokenizer=self.__model.sentence_encoder.tokenize,
batch_size=batch_size,
table_name=self.__table_name,
task_kwargs=self.__task_kwargs,
num_workers=self.__data_loader_num_workers,
shuffle=False)

with sentence_eval.dataset as dataset:

# Iter output results.
results_it = self.iter_results(parallel_model=torch.nn.DataParallel(self.__model),
data_ids=list(self.extract_ids(samples_filepath)),
eval_loader=sentence_eval)
# Iter output results.
results_it = self.iter_results(parallel_model=torch.nn.DataParallel(self.__model),
data_ids=list(dataset.iter_ids()),
eval_loader=sentence_eval)

total = len(sentence_eval.dataset)
total = len(sentence_eval.dataset)

return results_it, total

def apply_core(self, input_data, pipeline_ctx):
assert(isinstance(input_data, PipelineContext))

# Fetching the input data.
labels_scaler = input_data.provide("labels_scaler")
labels_scaler = pipeline_ctx.provide("labels_scaler")

# Try to obtain it from the specific input variable.
samples_filepath = input_data.provide_or_none("opennre_samples_filepath")
samples_filepath = pipeline_ctx.provide_or_none("opennre_samples_filepath")
if samples_filepath is None:
samples_io = input_data.provide("samples_io")
samples_io = pipeline_ctx.provide("samples_io")
samples_filepath = samples_io.create_target(data_type=DataType.Test)

# Initialize model if the latter has not been yet.
if self.__model is None:

ckpt_dir = input_data.provide_or_none("opennre_ckpt_cache_dir")
ckpt_dir = pipeline_ctx.provide_or_none("opennre_ckpt_cache_dir")

self.__model = self.init_bert_model(
pretrain_path=self.__pretrained_bert,
@@ -183,8 +193,9 @@ def apply_core(self, input_data, pipeline_ctx):
labels_scaler=labels_scaler,
mask_entity=True,
predefined=self.__predefined_ckpts,
logger=self.__logger,
dir_to_donwload=get_default_download_dir() if ckpt_dir is None else ckpt_dir)

iter_infer, total = self.__iter_predict_result(samples_filepath=samples_filepath, batch_size=self.__batch_size)
input_data.update("iter_infer", iter_infer)
input_data.update("iter_total", total)
pipeline_ctx.update("iter_infer", iter_infer)
pipeline_ctx.update("iter_total", total)
26 changes: 14 additions & 12 deletions arelight/pipelines/items/inference_writer.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,32 @@
from arekit.common.pipeline.context import PipelineContext
from arekit.common.pipeline.items.base import BasePipelineItem

from arelight.predict_provider import BasePredictProvider
from arelight.predict_writer import BasePredictWriter
from arelight.predict.provider import BasePredictProvider
from arelight.predict.writer import BasePredictWriter


class InferenceWriterPipelineItem(BasePipelineItem):

def __init__(self, writer):
def __init__(self, writer, **kwargs):
assert(isinstance(writer, BasePredictWriter))
super(InferenceWriterPipelineItem, self).__init__(**kwargs)
self.__writer = writer

def apply_core(self, input_data, pipeline_ctx):
assert(isinstance(input_data, PipelineContext))

# Setup predicted result writer.
target = input_data.provide("predict_filepath")
print(target)
target = pipeline_ctx.provide("predict_filepath")

self.__writer.set_target(target)

# Extracting list of the uint labels.
labels_scaler = pipeline_ctx.provide("labels_scaler")
uint_labels = [labels_scaler.label_to_uint(label) for label in labels_scaler.ordered_suppoted_labels()]

# Gathering the content
title, contents_it = BasePredictProvider().provide(
sample_id_with_uint_labels_iter=input_data.provide("iter_infer"),
labels_count=input_data.provide("labels_scaler").LabelsCount)
header, contents_it = BasePredictProvider.provide_to_storage(
sample_id_with_uint_labels_iter=pipeline_ctx.provide("iter_infer"),
uint_labels=uint_labels)

with self.__writer:
self.__writer.write(title=title, contents_it=contents_it,
total=input_data.provide_or_none("iter_total"))
self.__writer.write(header=header, contents_it=contents_it,
total=pipeline_ctx.provide_or_none("iter_total"))
6 changes: 2 additions & 4 deletions arelight/pipelines/items/serializer_arekit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from arekit.common.pipeline.context import PipelineContext
from arekit.contrib.utils.pipelines.items.sampling.base import BaseSerializerPipelineItem
from arelight.data.serializer_base import BaseSerializerPipelineItem


class AREkitSerializerPipelineItem(BaseSerializerPipelineItem):
@@ -8,9 +7,8 @@ class AREkitSerializerPipelineItem(BaseSerializerPipelineItem):
"""

def apply_core(self, input_data, pipeline_ctx):
assert(isinstance(input_data, PipelineContext))
super(AREkitSerializerPipelineItem, self).apply_core(input_data=input_data,
pipeline_ctx=pipeline_ctx)

# Host samples into the result for further pipeline items.
input_data.update("samples_io", self._samples_io)
pipeline_ctx.update("samples_io", self._samples_io)
16 changes: 0 additions & 16 deletions arelight/pipelines/items/terms_splitter.py

This file was deleted.

Empty file added arelight/predict/__init__.py
Empty file.
17 changes: 17 additions & 0 deletions arelight/predict/header.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from arekit.common.data import const


class PredictHeader:

@staticmethod
def create_header(uint_labels, uint_label_to_str, create_id=True):
assert(callable(uint_label_to_str))

header = []

if create_id:
header.append(const.ID)

header.extend([uint_label_to_str(uint_label) for uint_label in uint_labels])

return header
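
A small usage sketch for the new header helper; the exact name of AREkit's `const.ID` column is not asserted here and is only approximated in the comment.

```python
from arelight.predict.header import PredictHeader
from arelight.predict.provider import BasePredictProvider

# For three labels this yields the id column followed by "col_0", "col_1", "col_2".
header = PredictHeader.create_header(uint_labels=[0, 1, 2],
                                     uint_label_to_str=BasePredictProvider.UINT_TO_STR)
print(header)
```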
68 changes: 68 additions & 0 deletions arelight/predict/provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from collections.abc import Iterable

from arekit.common.data.storages.base import BaseRowsStorage

from arelight.predict.header import PredictHeader


class BasePredictProvider(object):

UINT_TO_STR = lambda uint_label: f"col_{uint_label}"

@staticmethod
def __iter_contents(sample_id_with_uint_labels_iter, labels_count, column_extra_funcs):
assert(isinstance(labels_count, int))

for sample_id, uint_label in sample_id_with_uint_labels_iter:
assert(isinstance(uint_label, int))

labels = ['0'] * labels_count
labels[uint_label] = '1'

# Composing row contents.
contents = [sample_id]

# Optionally provide additional values.
if column_extra_funcs is not None:
for _, value_func in column_extra_funcs:
contents.append(str(value_func(sample_id)))

# Providing row labels.
contents.extend(labels)
yield contents

@staticmethod
def provide_to_storage(sample_id_with_uint_labels_iter, uint_labels):
assert(isinstance(sample_id_with_uint_labels_iter, Iterable))

# Provide contents.
contents_it = BasePredictProvider.__iter_contents(
sample_id_with_uint_labels_iter=sample_id_with_uint_labels_iter,
labels_count=len(uint_labels),
column_extra_funcs=None)

header = PredictHeader.create_header(uint_labels=uint_labels,
uint_label_to_str=BasePredictProvider.UINT_TO_STR)

return header, contents_it

@staticmethod
def iter_from_storage(predict_data, uint_labels, keep_ind=True):
assert (isinstance(predict_data, BaseRowsStorage))

header = PredictHeader.create_header(uint_labels=uint_labels,
uint_label_to_str=BasePredictProvider.UINT_TO_STR,
create_id=False)

for res_ind, row in predict_data:

uint_label = None
for label, field_name in zip(uint_labels, header):
if row[field_name] > 0:
uint_label = label
break

if keep_ind:
yield res_ind, uint_label
else:
yield uint_label
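
The provider turns (sample id, uint label) pairs into one-hot rows; a minimal sketch with made-up predictions:

```python
from arelight.predict.provider import BasePredictProvider

predictions = [(0, 1), (1, 2), (2, 0)]  # invented (sample_id, uint_label) pairs

header, contents_it = BasePredictProvider.provide_to_storage(
    sample_id_with_uint_labels_iter=iter(predictions),
    uint_labels=[0, 1, 2])

for row in contents_it:
    print(row)  # e.g. [0, '0', '1', '0']: sample id followed by a one-hot label encoding
```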
2 changes: 1 addition & 1 deletion arelight/predict_writer.py → arelight/predict/writer.py
Original file line number Diff line number Diff line change
@@ -6,5 +6,5 @@ def __init__(self):
def set_target(self, target):
self._target = target

def write(self, title, contents_it, total=None):
def write(self, header, contents_it, total=None):
raise NotImplementedError()
Original file line number Diff line number Diff line change
@@ -3,27 +3,28 @@

from arekit.common.utils import progress_bar_defined, create_dir_if_not_exists

from arelight.predict_writer import BasePredictWriter
from arelight.predict.writer import BasePredictWriter

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class TsvPredictWriter(BasePredictWriter):

def __init__(self):
def __init__(self, log_out=None):
super(TsvPredictWriter, self).__init__()
self.__col_separator = '\t'
self.__f = None
self.__log_out = log_out

def __write(self, params):
line = "{}\n".format(self.__col_separator.join([str(p) for p in params]))
self.__f.write(line.encode())

def write(self, title, contents_it, total=None):
self.__write(title)
def write(self, header, contents_it, total=None):
self.__write(header)

wrapped_it = progress_bar_defined(iterable=contents_it, desc='Writing output', unit='rows', total=total)
wrapped_it = progress_bar_defined(iterable=contents_it, desc='Writing output (tsv)', unit='rows',
total=total, file=self.__log_out)

for contents in wrapped_it:
self.__write(contents)
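
Putting the renamed provider and the TSV writer together; the import path of `TsvPredictWriter` and the output filename are assumptions, and the sketch relies on the writer opening its target inside the `with` block, as the inference writer pipeline item does.

```python
from arelight.predict.provider import BasePredictProvider
from arelight.predict.writer_csv import TsvPredictWriter  # assumed module path after the rename

header, contents_it = BasePredictProvider.provide_to_storage(
    sample_id_with_uint_labels_iter=iter([(0, 1), (1, 2)]),
    uint_labels=[0, 1, 2])

writer = TsvPredictWriter()
writer.set_target("output/predictions.tsv.gz")  # hypothetical target path

with writer:
    writer.write(header=header, contents_it=contents_it, total=2)
```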
37 changes: 37 additions & 0 deletions arelight/predict/writer_sqlite3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import logging

from arekit.common.utils import progress_bar_defined
from source_iter.service_sqlite import SQLite3Service

from arelight.predict.writer import BasePredictWriter

logger = logging.getLogger(__name__)


class SQLite3PredictWriter(BasePredictWriter):

def __init__(self, table_name, log_out=None):
super(SQLite3PredictWriter, self).__init__()
self.__table_name = table_name
self.__log_out = log_out

def write(self, header, contents_it, total=None):

content_header = header[1:]
SQLite3Service.write_missed(
columns=content_header,
target=self._target,
table_name=self.__table_name,
it_type=None,
data_it=progress_bar_defined(iterable=map(lambda item: [item[0], item[1:]], contents_it),
desc=f'Writing output (sqlite:{self.__table_name})',
unit='rows', total=total, file=self.__log_out),
sqlite3_column_types=["INTEGER"] * len(content_header),
id_column_name=header[0],
id_column_type="INTEGER")

def __enter__(self):
pass

def __exit__(self, exc_type, exc_val, exc_tb):
pass
47 changes: 0 additions & 47 deletions arelight/predict_provider.py

This file was deleted.

Empty file added arelight/readers/__init__.py
Empty file.
7 changes: 7 additions & 0 deletions arelight/readers/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class BaseReader(object):

def extension(self):
raise NotImplementedError()

def read(self, target):
raise NotImplementedError()
39 changes: 39 additions & 0 deletions arelight/readers/csv_pd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import importlib

from arekit.contrib.utils.data.storages.pandas_based import PandasBasedRowsStorage

from arelight.readers.base import BaseReader


class PandasCsvReader(BaseReader):
""" Represents a CSV-based reader, implmented via pandas API.
"""

def __init__(self, sep='\t', header='infer', compression='infer', encoding='utf-8', col_types=None,
custom_extension=None):
self.__sep = sep
self.__compression = compression
self.__encoding = encoding
self.__header = header
self.__custom_extension = custom_extension

# Special assignation of types for certain columns.
self.__col_types = col_types
if self.__col_types is None:
self.__col_types = dict()

def extension(self):
return ".tsv.gz" if self.__custom_extension is None else self.__custom_extension

def __from_csv(self, filepath):
pd = importlib.import_module("pandas")
return pd.read_csv(filepath,
sep=self.__sep,
encoding=self.__encoding,
compression=self.__compression,
dtype=self.__col_types,
header=self.__header)

def read(self, target):
df = self.__from_csv(filepath=target)
return PandasBasedRowsStorage(df)
16 changes: 16 additions & 0 deletions arelight/readers/jsonl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from arekit.contrib.utils.data.storages.jsonl_based import JsonlBasedRowsStorage

from arelight.readers.base import BaseReader


class JsonlReader(BaseReader):

def extension(self):
return ".jsonl"

def read(self, target):
rows = []
with open(target, "r") as f:
for line in f.readlines():
rows.append(line)
return JsonlBasedRowsStorage(rows)
15 changes: 15 additions & 0 deletions arelight/readers/sqlite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from arekit.contrib.utils.data.storages.sqlite_based import SQliteBasedRowsStorage

from arelight.readers.base import BaseReader


class SQliteReader(BaseReader):

def __init__(self, table_name):
self.__table_name = table_name

def extension(self):
return ".sqlite"

def read(self, target):
return SQliteBasedRowsStorage(path=target, table_name=self.__table_name)
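
All readers above share the same two-method contract, `extension()` and `read(target)`; a sketch with placeholder file paths:

```python
from arelight.readers.jsonl import JsonlReader
from arelight.readers.sqlite import SQliteReader

# Placeholder paths; read() returns an AREkit rows storage in both cases.
jsonl_storage = JsonlReader().read(target="output/samples-test.jsonl")
sqlite_storage = SQliteReader(table_name="contents").read(target="output/samples-test.sqlite")
print(JsonlReader().extension(), SQliteReader(table_name="contents").extension())
```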
158 changes: 101 additions & 57 deletions arelight/run/infer.py

Large diffs are not rendered by default.

54 changes: 33 additions & 21 deletions arelight/run/operations.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@
import os
from datetime import datetime

from arekit.common.pipeline.base import BasePipeline
from arekit.common.pipeline.base import BasePipelineLauncher

from arelight.backend.d3js.relations_graph_operations import OP_UNION, OP_DIFFERENCE, OP_INTERSECTION
from arelight.backend.d3js.ui_web import GRAPH_TYPE_FORCE
@@ -11,14 +11,15 @@
from arelight.pipelines.demo.result import PipelineResult
from arelight.pipelines.items.backend_d3js_operations import D3jsGraphOperationsBackendPipelineItem
from arelight.run.utils import get_binary_choice, get_list_choice, get_int_choice, is_port_number
from arelight.run.utils_logger import setup_custom_logger


def get_input_with_default(prompt, default_value):
user_input = input(prompt)
return user_input.strip() or default_value


def get_graph_path(text):
def get_graph_path_interactive(text):
while True:
folder_path = input(text)

@@ -57,12 +58,11 @@ def get_graph_path(text):
print("Invalid input. Please enter a number.")


if __name__ == '__main__':
def create_operations_parser(op_list):

op_list = [OP_UNION, OP_INTERSECTION, OP_DIFFERENCE]
parser = argparse.ArgumentParser(description="Graph Operations")

# Providing arguments.
parser = argparse.ArgumentParser(description="Graph Operations")
parser.add_argument("--operation", required=False, choices=op_list,
help="Select operation: {ops}".format(ops=",".join(op_list)))
parser.add_argument("--graph_a_file", required=False,
@@ -76,15 +76,29 @@ def get_graph_path(text):
parser.add_argument("--name", required=False, help="Specify name of new graph")
parser.add_argument("--label-names", dest="d3js_label_names", type=str, default="p:pos,n:neg,u:neu")
parser.add_argument("--description", required=False, help="Specify description of new graph")
parser.add_argument('--log-file', dest="log_file", default=None, type=str)
parser.add_argument("--host", required=False, default=None, help="Server port for launching hosting (optional)")

return parser


if __name__ == '__main__':

op_list = [OP_UNION, OP_INTERSECTION, OP_DIFFERENCE]

# Completing list of arguments.
parser = create_operations_parser(op_list)

# Parsing arguments.
args = parser.parse_args()

# Setup logger
logger = setup_custom_logger(name="arelight", filepath=args.log_file)

operation = args.operation if args.operation else get_list_choice(op_list)
graph_A_file_path = args.graph_a_file if args.graph_a_file else get_graph_path(
graph_A_file_path = args.graph_a_file if args.graph_a_file else get_graph_path_interactive(
"Enter the path to the folder for graph_A: ")
graph_B_file_path = args.graph_b_file if args.graph_b_file else get_graph_path(
graph_B_file_path = args.graph_b_file if args.graph_b_file else get_graph_path_interactive(
"Enter the path to the folder for graph_B: ")
weights = args.weights.lower() == 'y' if args.weights else get_binary_choice("Use weights? (y/n)\n")
do_host = args.host if is_port_number(args.host) \
@@ -117,21 +131,19 @@ def get_graph_path(text):
description = args.description if args.description else \
get_input_with_default("Specify description of new graph (enter to skip)\n", default_description)

pipeline = BasePipeline([
D3jsGraphOperationsBackendPipelineItem()
])

labels_fmt = {a: v for a, v in map(lambda item: item.split(":"), args.d3js_label_names.split(','))}

# Launch application.
pipeline.run(input_data=PipelineResult({
# We provide this settings for inference.
"labels_formatter": CustomLabelsFormatter(**labels_fmt),
"d3js_graph_output_dir": output_dir,
"d3js_collection_description": description,
"d3js_host": str(8000) if do_host else None,
"d3js_graph_a": load_graph(graph_A_file_path),
"d3js_graph_b": load_graph(graph_B_file_path),
"d3js_graph_operations": operation,
"d3js_collection_name": collection_name
BasePipelineLauncher.run(
pipeline=[D3jsGraphOperationsBackendPipelineItem()],
pipeline_ctx=PipelineResult({
# We provide this settings for inference.
"labels_formatter": CustomLabelsFormatter(**labels_fmt),
"d3js_graph_output_dir": output_dir,
"d3js_collection_description": description,
"d3js_graph_a": load_graph(graph_A_file_path),
"d3js_graph_b": load_graph(graph_B_file_path),
"d3js_graph_operations": operation,
"d3js_collection_name": collection_name,
"result": None
}))
17 changes: 8 additions & 9 deletions arelight/run/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import logging
from enum import Enum
from io import TextIOWrapper
from zipfile import ZipFile

from arekit.common.docs.base import Document
from arekit.common.docs.sentence import BaseDocumentSentence
from arekit.contrib.source.synonyms.utils import iter_synonym_groups

from arelight.pipelines.demo.labels.scalers import CustomLabelScaler
from arelight.synonyms import iter_synonym_groups
from arelight.utils import auto_import, iter_csv_lines

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

NER_TYPES = ["ORG", "PERSON", "LOC", "GPE"]


def create_sentence_parser(framework, language):
@@ -182,14 +181,14 @@ def is_port_number(number, is_optional=True):


OPENNRE_CHECKPOINTS = {
"ra4-rsr1_DeepPavlov-rubert-base-cased_cls.pth.tar": {
"state": "DeepPavlov/rubert-base-cased",
"checkpoint": "https://www.dropbox.com/scl/fi/rwjf7ag3w3z90pifeywrd/ra4-rsr1_DeepPavlov-rubert-base-cased_cls.pth.tar?rlkey=p0mmu81o6c2u6iboe9m20uzqk&dl=1",
"label_scaler": CustomLabelScaler(p=1, n=2, u=0)
},
"ra4-rsr1_bert-base-cased_cls.pth.tar": {
"state": "bert-base-cased",
"checkpoint": "https://www.dropbox.com/scl/fi/k5arragv1g4wwftgw5xxd/ra-rsr_bert-base-cased_cls.pth.tar?rlkey=8hzavrxunekf0woesxrr0zqys&dl=1",
"label_scaler": CustomLabelScaler(p=1, n=2, u=0)
},
"ra4-rsr1_DeepPavlov-rubert-base-cased_cls.pth.tar": {
"state": "DeepPavlov/rubert-base-cased",
"checkpoint": "https://www.dropbox.com/scl/fi/rwjf7ag3w3z90pifeywrd/ra4-rsr1_DeepPavlov-rubert-base-cased_cls.pth.tar?rlkey=p0mmu81o6c2u6iboe9m20uzqk&dl=1",
"label_scaler": CustomLabelScaler(p=1, n=2, u=0)
}
}
61 changes: 61 additions & 0 deletions arelight/run/utils_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import io
import logging
import sys
from tqdm import tqdm

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class TqdmLoggingHandler(logging.Handler):
def __init__(self, level=logging.NOTSET):
super().__init__(level)

def emit(self, record):
try:
msg = self.format(record)
tqdm.write(msg)
self.flush()
except Exception:
self.handleError(record)


class TqdmToLogger(io.StringIO):
"""
Output stream for tqdm which redirects output to the logger module instead of
stdout.
"""
logger = None
level = None
buf = ''

def __init__(self, logger, level=None):
super(TqdmToLogger, self).__init__()
self.logger = logger
self.level = level or logging.INFO

def write(self, buf):
self.buf = buf.strip('\r\n\t ')

def flush(self):
self.logger.log(self.level, self.buf)


def setup_custom_logger(name, add_screen_handler=False, filepath=None):
formatter = logging.Formatter(fmt='%(asctime)s %(levelname)-8s %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
logger.addHandler(TqdmLoggingHandler())

if add_screen_handler:
screen_handler = logging.StreamHandler(stream=sys.stdout)
screen_handler.setFormatter(formatter)
logger.addHandler(screen_handler)

if filepath is not None:
handler = logging.FileHandler(filepath, mode='w')
handler.setFormatter(formatter)
logger.addHandler(handler)

return logger
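
A short sketch of how the logger utilities above are meant to be combined; the log file name is arbitrary.

```python
from tqdm import tqdm
from arelight.run.utils_logger import setup_custom_logger, TqdmToLogger

logger = setup_custom_logger(name="arelight", filepath="arelight.log")  # arbitrary file name
tqdm_out = TqdmToLogger(logger)  # can be passed as `file=` to tqdm-based progress bars

for _ in tqdm(range(3), file=tqdm_out, desc="demo"):
    logger.info("log records go through tqdm.write, so progress bars are not broken")
```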
Empty file added arelight/stemmers/__init__.py
Empty file.
51 changes: 51 additions & 0 deletions arelight/stemmers/ru_mystem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from arekit.common.text.stemmer import Stemmer
from arekit.common.utils import filter_whitespaces
from pymystem3 import Mystem


class MystemWrapper(Stemmer):
""" Yandex MyStem wrapper
part of speech description:
https://tech.yandex.ru/mystem/doc/grammemes-values-docpage/
"""

def __init__(self, entire_input=False):
"""
entire_input: bool
Mystem parameter that allows keeping all information from the input (True) or
removing garbage characters (False).
"""
self.__mystem = Mystem(entire_input=entire_input)

# region properties

@property
def MystemInstance(self):
return self.__mystem

# endregion

# region public methods

def lemmatize_to_list(self, text):
return self.__lemmatize_core(text)

def lemmatize_to_str(self, text):
result = " ".join(self.__lemmatize_core(text))
return result if len(result) != 0 else self.__process_original_text(text)

# endregion

# region private methods

def __lemmatize_core(self, text):
assert(isinstance(text, str))
result_list = self.__mystem.lemmatize(self.__process_original_text(text))
return filter_whitespaces(result_list)

@staticmethod
def __process_original_text(text):
return text.lower()

# endregion
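
Assuming `pymystem3` is installed (it fetches the Mystem binary on first use), the wrapper is used as follows; the lemmas in the comments are approximate.

```python
from arelight.stemmers.ru_mystem import MystemWrapper

stemmer = MystemWrapper()
print(stemmer.lemmatize_to_list("Минские соглашения"))  # roughly: ['минский', 'соглашение']
print(stemmer.lemmatize_to_str("Минские соглашения"))   # roughly: "минский соглашение"
```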
14 changes: 14 additions & 0 deletions arelight/synonyms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from arekit.common.utils import progress_bar_defined


def iter_synonym_groups(input_file, sep=",", desc=""):
""" All the synonyms groups organized in lines, separated by `sep`
"""
lines = input_file.readlines()

for line in progress_bar_defined(lines, total=len(lines), desc=desc, unit="opins"):

if isinstance(line, bytes):
line = line.decode()

yield line.split(sep)
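
A usage sketch for the relocated synonyms iterator; the file name and its contents are hypothetical, one comma-separated group per line.

```python
from arelight.synonyms import iter_synonym_groups

# Hypothetical file where each line looks like: "USA,United States,США"
with open("synonyms.txt", "r") as input_file:
    for group in iter_synonym_groups(input_file, sep=",", desc="Reading synonyms"):
        print(group)
```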
42 changes: 42 additions & 0 deletions arelight/third_party/sqlite3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import sqlite3


class SQLite3Service(object):

def __init__(self):
self.conn = None

def connect(self, sqlite_path):
self.conn = sqlite3.connect(sqlite_path)

def disconnect(self):
assert(self.conn is not None)
self.conn.close()

def table_rows_count(self, table_name):
count_response = self.conn.execute(f"select count(*) from {table_name}").fetchone()
return count_response[0]

def get_column_names(self, table_name, filter_name=None):
cursor = self.conn.execute(f"select * from {table_name}")
column_names = list(map(lambda x: x[0], cursor.description))
return [col_name for col_name in column_names
if filter_name is None or (filter_name is not None and filter_name(col_name))]

def iter_rows(self, table_name, select_columns="*", column_value=None, value=None, return_dict=False):
""" Returns array of the values in the same order as the one provided in `select_columns` parameter.
"""
if column_value is not None and value is not None:
cursor = self.conn.execute(
f"select {select_columns} from {table_name} where ({column_value} = ?)", (value,))
else:
cursor = self.conn.execute(f"select {select_columns} from {table_name}")

if not return_dict:
for row in cursor.fetchall():
yield row

# Return as dictionary
column_names = list(map(lambda x: x[0], cursor.description))
for row in cursor.fetchall():
yield {col_name: row[i] for i, col_name in enumerate(column_names)}
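
A sketch of the thin SQLite wrapper above; the database path and table name are assumptions that should match the serialized samples.

```python
from arelight.third_party.sqlite3 import SQLite3Service

service = SQLite3Service()
service.connect("output/samples-test.sqlite")  # assumed path of a serialized collection

print(service.table_rows_count("contents"))
print(service.get_column_names("contents"))

# Rows can be fetched either as tuples or as {column: value} dictionaries.
for row in service.iter_rows(table_name="contents", select_columns="*", return_dict=True):
    print(row)
    break

service.disconnect()
```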
98 changes: 98 additions & 0 deletions arelight/third_party/torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import logging

import torch
from torch.utils import data

from arelight.third_party.sqlite3 import SQLite3Service


class SQLiteSentenceREDataset(data.Dataset):
""" Sentence-level relation extraction dataset
This is the original OpenNRE implementation, adapted for SQLite.
"""

def __init__(self, path, table_name, rel2id, tokenizer, kwargs, task_kwargs):
"""
Args:
path: path of the input SQLite file
rel2id: dictionary of relation->id mapping
tokenizer: tokenization function
"""
assert(isinstance(task_kwargs, dict))
assert("no_label" in task_kwargs)
assert("default_id_column" in task_kwargs)
assert("index_columns" in task_kwargs)
assert("text_columns" in task_kwargs)
super().__init__()
self.path = path
self.tokenizer = tokenizer
self.rel2id = rel2id
self.kwargs = kwargs
self.task_kwargs = task_kwargs
self.table_name = table_name
self.sqlite_service = SQLite3Service()

def iter_ids(self, id_column=None):
col_name = self.task_kwargs["default_id_column"] if id_column is None else id_column
for row in self.sqlite_service.iter_rows(select_columns=col_name, table_name=self.table_name):
yield row[0]

def __len__(self):
return self.sqlite_service.table_rows_count(self.table_name)

def __getitem__(self, index):
found_text_columns = self.sqlite_service.get_column_names(
table_name=self.table_name, filter_name=lambda col_name: col_name in self.task_kwargs["text_columns"])

iter_rows = self.sqlite_service.iter_rows(
select_columns=",".join(self.task_kwargs["index_columns"] + found_text_columns),
value=index,
column_value=self.task_kwargs["default_id_column"],
table_name=self.table_name)

fetched_row = next(iter_rows)

opennre_item = {
"text": " ".join(fetched_row[-len(found_text_columns):]),
"h": {"pos": [fetched_row[0], fetched_row[0] + 1]},
"t": {"pos": [fetched_row[1], fetched_row[1] + 1]},
}

seq = list(self.tokenizer(opennre_item, **self.kwargs))

return [self.rel2id[self.task_kwargs["no_label"]]] + seq # label, seq1, seq2, ...

def collate_fn(data):
data = list(zip(*data))
labels = data[0]
seqs = data[1:]
batch_labels = torch.tensor(labels).long() # (B)
batch_seqs = []
for seq in seqs:
batch_seqs.append(torch.cat(seq, 0)) # (B, L)
return [batch_labels] + batch_seqs

def eval(self, pred_result, use_name=False):
raise NotImplementedError()

def __enter__(self):
self.sqlite_service.connect(self.path)
logging.info("Loaded sentence RE dataset {} with {} lines and {} relations.".format(
self.path, len(self), len(self.rel2id)))
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.sqlite_service.disconnect()


def sentence_re_loader(path, table_name, rel2id, tokenizer, batch_size, shuffle,
task_kwargs, num_workers, collate_fn=SQLiteSentenceREDataset.collate_fn, **kwargs):
dataset = SQLiteSentenceREDataset(path=path, table_name=table_name, rel2id=rel2id,
tokenizer=tokenizer, kwargs=kwargs, task_kwargs=task_kwargs)
data_loader = data.DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
pin_memory=True,
num_workers=num_workers,
collate_fn=collate_fn)
return data_loader
4 changes: 2 additions & 2 deletions arelight/third_party/transformers.py
Original file line number Diff line number Diff line change
@@ -12,12 +12,12 @@ def init_token_classification_model(model_path, device):
return model, tokenizer


def annotate_ner_ppl(model, tokenizer, device, batch_size=4):
def annotate_ner_ppl(model, tokenizer, device="cpu", batch_size=4):
return pipeline("ner", model=model, aggregation_strategy='simple', tokenizer=tokenizer,
grouped_entities=True, batch_size=batch_size, device=device)


def annotate_ner(model, tokenizer, text, device):
def annotate_ner(model, tokenizer, text, device="cpu"):
""" This code is related to collection of the annotated objects from texts.
return: list of dict
41 changes: 41 additions & 0 deletions arelight/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import collections
import csv
import importlib
import os
import sys

import requests
from tqdm import tqdm


def auto_import(name):
""" Import from the external python packages.
@@ -14,6 +18,17 @@ def __get_module(comps_list):
return getattr(__get_module(components[:-1]), components[-1])


def flatten(xss):
l = []
for xs in xss:
if isinstance(xs, collections.abc.Iterable):
for x in xs:
l.append(x)
else:
l.append(xs)
return l


def get_default_download_dir():
""" Refered to NLTK toolkit approach
https://github.com/nltk/nltk/blob/8e771679cee1b4a9540633cc3ea17f4421ffd6c0/nltk/downloader.py#L1051
@@ -53,3 +68,29 @@ def iter_csv_lines(csv_file, column_name, delimiter=","):

for row in csv_reader:
yield row[column_name]


def download(dest_file_path, source_url, logger):
""" Refered to https://github.com/nicolay-r/ner-bilstm-crf-tensorflow/blob/master/ner/utils.py
Simple http file downloader
"""
if logger is not None:
logger.info(('Downloading from {src} to {dest}'.format(src=source_url, dest=dest_file_path)))

sys.stdout.flush()
datapath = os.path.dirname(dest_file_path)

if not os.path.exists(datapath):
os.makedirs(datapath, mode=0o755)

dest_file_path = os.path.abspath(dest_file_path)

r = requests.get(source_url, stream=True)
total_length = int(r.headers.get('content-length', 0))

with open(dest_file_path, 'wb') as f:
pbar = tqdm(total=total_length, unit='B', unit_scale=True)
for chunk in r.iter_content(chunk_size=32 * 1024):
if chunk: # filter out keep-alive new chunks
pbar.update(len(chunk))
f.write(chunk)
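
The downloader is used by the checkpoint resolution code above; a standalone sketch with a placeholder URL:

```python
from os.path import join

from arelight.run.utils_logger import setup_custom_logger
from arelight.utils import download, get_default_download_dir

logger = setup_custom_logger("arelight")

# Placeholder URL; any direct-download link to a checkpoint works the same way.
download(dest_file_path=join(get_default_download_dir(), "checkpoint.pth.tar"),
         source_url="https://example.org/checkpoint.pth.tar",
         logger=logger)
```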
10 changes: 8 additions & 2 deletions dependencies.txt
Original file line number Diff line number Diff line change
@@ -2,7 +2,13 @@ deeppavlov==1.3.0
transformers==4.24.0
torch==2.0.1
pytorch-crf==0.7.2
arekit==0.24.0
source-iter==0.24.2
arekit @ git+https://github.com/nicolay-r/AREkit@0.25.1-rc
bulk-translate @ git+https://github.com/nicolay-r/bulk-translate@master
open-nre==0.1.1
nltk==3.8.1
googletrans==3.1.0a0
googletrans==3.1.0a0
httpcore==0.9.1
argparse_to_json==0.0.1
httpx==0.13.3
requests
Binary file modified logo.png
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@ def get_requirements(filenames):

setup(
name='arelight',
version='0.24.0',
version='0.25.0',
description='About Mass-media text processing application for your '
'Relation Extraction task, powered by AREkit.',
url='https://github.com/nicolay-r/ARElight',
File renamed without changes.