feat: add support for classification data in csv/json formats (#115)
* feat: add support for csv/json classification datasets

* feat: add tests for snips and samples
seliverstov authored Mar 16, 2018
1 parent c7a131c commit 8cc53c6
Showing 7 changed files with 261 additions and 33 deletions.
2 changes: 1 addition & 1 deletion deeppavlov/__init__.py
@@ -10,7 +10,7 @@
 import deeppavlov.dataset_readers.dstc2_reader
 import deeppavlov.dataset_readers.conll2003_reader
 import deeppavlov.dataset_readers.typos_reader
-import deeppavlov.dataset_readers.csv_classification_reader
+import deeppavlov.dataset_readers.basic_classification_reader
 import deeppavlov.dataset_iterators.dialog_iterator
 import deeppavlov.dataset_iterators.dstc2_ner_iterator
 import deeppavlov.dataset_iterators.dstc2_intents_iterator
98 changes: 98 additions & 0 deletions deeppavlov/configs/intents/intents_sample_csv.json
@@ -0,0 +1,98 @@
{
"dataset": {
"type": "classification",
"format": "csv",
"sep": ",",
"header": 0,
"names": ["text", "classes"],
"class_sep": ",",
"train": "sample.csv",
"data_path": "sample",
"x": "text",
"y": "classes",
"url": "http://lnsigo.mipt.ru/export/datasets/snips_intents/train.csv",
"seed": 42,
"field_to_split": "train",
"split_fields": [
"train",
"valid"
],
"split_proportions": [
0.9,
0.1
]
},
"chainer": {
"in": ["x"],
"in_y": ["y"],
"pipe": [
{
"id": "classes_vocab",
"name": "default_vocab",
"fit_on": ["y"],
"level": "token",
"save_path": "vocabs/snips_classes.dict",
"load_path": "vocabs/snips_classes.dict"
},
{
"in": ["x"],
"in_y": ["y"],
"out": ["y_predicted"],
"main": true,
"name": "intent_model",
"save_path": "intents/intent_cnn_snips_v2",
"load_path": "intents/intent_cnn_snips_v2",
"classes": "#classes_vocab.keys()",
"opt": {
"kernel_sizes_cnn": [
1,
2,
3
],
"filters_cnn": 256,
"lear_metrics": [
"binary_accuracy",
"fmeasure"
],
"confident_threshold": 0.5,
"optimizer": "Adam",
"lear_rate": 0.01,
"lear_rate_decay": 0.1,
"loss": "binary_crossentropy",
"text_size": 15,
"coef_reg_cnn": 1e-4,
"coef_reg_den": 1e-4,
"dropout_rate": 0.5,
"epochs": 1000,
"dense_size": 100,
"model_name": "cnn_model"
},
"embedder": {
"name": "fasttext",
"save_path": "embeddings/dstc2_fastText_model.bin",
"load_path": "embeddings/dstc2_fastText_model.bin",
"emb_module": "fasttext",
"dim": 100
},
"tokenizer": {
"name": "nltk_tokenizer",
"tokenizer": "wordpunct_tokenize"
}
}
],
"out": ["y_predicted"]
},
"train": {
"epochs": 100,
"batch_size": 64,
"metrics": [
"sets_accuracy"
],
"validation_patience": 5,
"val_every_n_epochs": 1,
"log_every_n_epochs": 1,
"show_examples": false,
"validate_best": true,
"test_best": false
}
}
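
The CSV-specific keys above ("sep", "header", "names") are passed straight to pandas by the new reader (see the basic_classification_reader diff below), and "class_sep" splits multi-label cells. A minimal sketch of that path, reusing the file and column names from this config:

import pandas as pd

# Read sample.csv the way basic_classification_reader does for format == 'csv'.
df = pd.read_csv("sample/sample.csv", sep=",", header=0, names=["text", "classes"])

# Split each label cell on class_sep, producing (text, [labels]) pairs.
samples = [(row["text"], row["classes"].split(",")) for _, row in df.iterrows()]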
96 changes: 96 additions & 0 deletions deeppavlov/configs/intents/intents_sample_json.json
@@ -0,0 +1,96 @@
{
"dataset": {
"type": "classification",
"format": "json",
"orient": "records",
"lines": true,
"data_path": "sample",
"train": "sample.json",
"x": "text",
"y": "intents",
"url": "http://lnsigo.mipt.ru/export/datasets/snips_intents/train.json",
"seed": 42,
"field_to_split": "train",
"split_fields": [
"train",
"valid"
],
"split_proportions": [
0.9,
0.1
]
},
"chainer": {
"in": ["x"],
"in_y": ["y"],
"pipe": [
{
"id": "classes_vocab",
"name": "default_vocab",
"fit_on": ["y"],
"level": "token",
"save_path": "vocabs/snips_classes.dict",
"load_path": "vocabs/snips_classes.dict"
},
{
"in": ["x"],
"in_y": ["y"],
"out": ["y_predicted"],
"main": true,
"name": "intent_model",
"save_path": "intents/intent_cnn_snips_v2",
"load_path": "intents/intent_cnn_snips_v2",
"classes": "#classes_vocab.keys()",
"opt": {
"kernel_sizes_cnn": [
1,
2,
3
],
"filters_cnn": 256,
"lear_metrics": [
"binary_accuracy",
"fmeasure"
],
"confident_threshold": 0.5,
"optimizer": "Adam",
"lear_rate": 0.01,
"lear_rate_decay": 0.1,
"loss": "binary_crossentropy",
"text_size": 15,
"coef_reg_cnn": 1e-4,
"coef_reg_den": 1e-4,
"dropout_rate": 0.5,
"epochs": 1000,
"dense_size": 100,
"model_name": "cnn_model"
},
"embedder": {
"name": "fasttext",
"save_path": "embeddings/dstc2_fastText_model.bin",
"load_path": "embeddings/dstc2_fastText_model.bin",
"emb_module": "fasttext",
"dim": 100
},
"tokenizer": {
"name": "nltk_tokenizer",
"tokenizer": "wordpunct_tokenize"
}
}
],
"out": ["y_predicted"]
},
"train": {
"epochs": 100,
"batch_size": 64,
"metrics": [
"sets_accuracy"
],
"validation_patience": 5,
"val_every_n_epochs": 1,
"log_every_n_epochs": 1,
"show_examples": false,
"validate_best": true,
"test_best": false
}
}
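
Here the "orient" and "lines" keys are handed to pandas' JSON reader instead. A minimal sketch under the same assumptions (the record in the comment is illustrative; the field names come from this config):

import pandas as pd

# sample.json is read as JSON lines, one record per line, e.g.
# {"text": "play some jazz", "intents": "PlayMusic"}
df = pd.read_json("sample/sample.json", orient="records", lines=True)

samples = [(row["text"], row["intents"].split(",")) for _, row in df.iterrows()]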
4 changes: 3 additions & 1 deletion deeppavlov/configs/intents/intents_snips.json
@@ -1,6 +1,8 @@
 {
   "dataset_reader": {
-    "name": "csv_classification_reader",
+    "name": "basic_classification_reader",
+    "x": "text",
+    "y": "intents",
     "data_path": "snips",
     "url": "http://lnsigo.mipt.ru/export/datasets/snips_intents/train.csv"
   },
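With the reader renamed and the x/y columns made explicit, the snips config trains exactly as before; for example (train_model_from_config is the function patched in the next diff):

from deeppavlov.core.commands.train import train_model_from_config

train_model_from_config('deeppavlov/configs/intents/intents_snips.json')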
17 changes: 15 additions & 2 deletions deeppavlov/core/commands/train.py
@@ -73,6 +73,19 @@ def train_model_from_config(config_path: str):
     config = read_json(config_path)
     set_deeppavlov_root(config)
 
+    dataset_config = config.get('dataset', None)
+
+    if dataset_config is not None:
+        del config['dataset']
+        ds_type = dataset_config['type']
+        if ds_type == 'classification':
+            reader = {'name': 'basic_classification_reader'}
+            iterator = {'name': 'basic_classification_iterator'}
+            config['dataset_reader'] = {**dataset_config, **reader}
+            config['dataset_iterator'] = {**dataset_config, **iterator}
+        else:
+            raise Exception("Unsupported dataset type: {}".format(ds_type))
+
     reader_config = config['dataset_reader']
     reader = get_model(reader_config['name'])()
     data_path = expand_path(reader_config.get('data_path', ''))
@@ -81,8 +94,8 @@ def train_model_from_config(config_path: str):
     if "data_path" in kwargs: del kwargs["data_path"]
     data = reader.read(data_path, **kwargs)
 
-    dataset_config = config['dataset_iterator']
-    dataset: BasicDatasetIterator = from_params(dataset_config, data=data)
+    iterator_config = config['dataset_iterator']
+    dataset: BasicDatasetIterator = from_params(iterator_config, data=data)
 
     if 'chainer' in config:
         model = fit_chainer(config, dataset)
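In effect, one "dataset" block fans out into the two component configs that the rest of the pipeline already understands. A minimal sketch of that transformation, not part of the commit, using keys from the sample CSV config above:

# What train_model_from_config now does with a 'dataset' block.
dataset_config = {
    "type": "classification",
    "format": "csv",
    "train": "sample.csv",
    "data_path": "sample",
}

config = {}
if dataset_config['type'] == 'classification':
    # The shared dataset keys are copied into both sections; the 'name'
    # entries tell get_model/from_params which registered components to build.
    config['dataset_reader'] = {**dataset_config, 'name': 'basic_classification_reader'}
    config['dataset_iterator'] = {**dataset_config, 'name': 'basic_classification_iterator'}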
deeppavlov/dataset_readers/csv_classification_reader.py → deeppavlov/dataset_readers/basic_classification_reader.py
@@ -28,8 +28,8 @@
 log = get_logger(__name__)
 
 
-@register('csv_classification_reader')
-class CsvClassificationDatasetReader(DatasetReader):
+@register('basic_classification_reader')
+class BasicClassificationDatasetReader(DatasetReader):
     """
     Class provides reading dataset in .csv format
     """
@@ -52,20 +52,38 @@ def read(self, data_path, url=None, *args, **kwargs):
         """
         data_types = ["train", "valid", "test"]
 
-        if not Path(data_path, "train.csv").exists():
+        train_file = format(kwargs.get('train', 'train.csv'))
+
+        if not Path(data_path, train_file).exists():
             if url is None:
                 raise Exception("data path {} is not exists or empty and download url parameter not specified!".format(data_path))
             log.info("Loading train data from {} to {}".format(url, data_path))
-            download(source_url=url, dest_file_path=Path(data_path, "train.csv"))
+            download(source_url=url, dest_file_path=Path(data_path, train_file))
 
         data = {"train": [],
                 "valid": [],
                 "test": []}
         for data_type in data_types:
-            try:
-                df = pd.read_csv(Path(data_path).joinpath(data_type + ".csv"))
-                data[data_type] = [(row['text'], row['intents'].split(',')) for _, row in df.iterrows()]
-            except FileNotFoundError:
-                log.warning("Cannot find {}.csv data file".format(data_type))
+            file_format = kwargs.get('format', 'csv')
+            file_name = kwargs.get(data_type, '{}.{}'.format(data_type, file_format))
+            file = Path(data_path).joinpath(file_name)
+            if file.exists():
+                if file_format == 'csv':
+                    keys = ('sep', 'header', 'names')
+                    options = {k: kwargs[k] for k in keys if k in kwargs}
+                    df = pd.read_csv(file, **options)
+                elif file_format == 'json':
+                    keys = ('orient', 'lines')
+                    options = {k: kwargs[k] for k in keys if k in kwargs}
+                    df = pd.read_json(file, **options)
+                else:
+                    raise Exception('Unsupported file format: {}'.format(file_format))
+
+                x = kwargs.get("x", "text")
+                y = kwargs.get('y', 'labels')
+                class_sep = kwargs.get('class_sep', ',')
+                data[data_type] = [(row[x], row[y].split(class_sep)) for _, row in df.iterrows()]
+            else:
+                log.warning("Cannot find {} file".format(file))
 
         return data
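
Taken together, the reader is now fully keyword-driven. A hypothetical standalone call, mirroring the kwargs that train_model_from_config forwards from the sample CSV config:

from deeppavlov.dataset_readers.basic_classification_reader import BasicClassificationDatasetReader

reader = BasicClassificationDatasetReader()
data = reader.read(
    'sample',  # data_path; sample.csv is downloaded from url if not present
    url='http://lnsigo.mipt.ru/export/datasets/snips_intents/train.csv',
    format='csv',
    train='sample.csv',
    sep=',', header=0, names=['text', 'classes'],
    x='text', y='classes', class_sep=',',
)
# data maps 'train'/'valid'/'test' to lists of (text, [labels]) tuples;
# splits with no file on disk stay empty and only log a warning.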
