diff --git a/deeppavlov/configs/classifiers/sentiment_imdb_bert.json b/deeppavlov/configs/classifiers/sentiment_imdb_bert.json new file mode 100644 index 0000000000..466ad58690 --- /dev/null +++ b/deeppavlov/configs/classifiers/sentiment_imdb_bert.json @@ -0,0 +1,147 @@ +{ + "dataset_reader": { + "class_name": "imdb_reader", + "data_path": "{DOWNLOADS_PATH}/aclImdb" + }, + "dataset_iterator": { + "class_name": "basic_classification_iterator", + "seed": 42, + "split_seed": 23, + "field_to_split": "train", + "stratify": true, + "split_fields": [ + "train", + "valid" + ], + "split_proportions": [ + 0.9, + 0.1 + ] + }, + "chainer": { + "in": [ + "x" + ], + "in_y": [ + "y" + ], + "pipe": [ + { + "class_name": "bert_preprocessor", + "vocab_file": "{DOWNLOADS_PATH}/bert_models/cased_L-12_H-768_A-12/vocab.txt", + "do_lower_case": false, + "max_seq_length": 450, + "in": [ + "x" + ], + "out": [ + "bert_features" + ] + }, + { + "id": "classes_vocab", + "class_name": "simple_vocab", + "fit_on": [ + "y" + ], + "save_path": "{MODEL_PATH}/classes.dict", + "load_path": "{MODEL_PATH}/classes.dict", + "in": "y", + "out": "y_ids" + }, + { + "in": "y_ids", + "out": "y_onehot", + "class_name": "one_hotter", + "depth": "#classes_vocab.len", + "single_vector": true + }, + { + "class_name": "bert_classifier", + "n_classes": "#classes_vocab.len", + "return_probas": true, + "one_hot_labels": true, + "bert_config_file": "{DOWNLOADS_PATH}/bert_models/cased_L-12_H-768_A-12/bert_config.json", + "pretrained_bert": "{DOWNLOADS_PATH}/bert_models/cased_L-12_H-768_A-12/bert_model.ckpt", + "save_path": "{MODEL_PATH}/model", + "load_path": "{MODEL_PATH}/model", + "keep_prob": 0.5, + "learning_rate": 1e-05, + "learning_rate_drop_patience": 5, + "learning_rate_drop_div": 2.0, + "in": [ + "bert_features" + ], + "in_y": [ + "y_onehot" + ], + "out": [ + "y_pred_probas" + ] + }, + { + "in": "y_pred_probas", + "out": "y_pred_ids", + "class_name": "proba2labels", + "max_proba": true + }, + { + "in": 
"y_pred_ids", + "out": "y_pred_labels", + "ref": "classes_vocab" + } + ], + "out": [ + "y_pred_labels" + ] + }, + "train": { + "batch_size": 8, + "epochs": 100, + "metrics": [ + "f1_weighted", + "f1_macro", + "sets_accuracy", + { + "name": "roc_auc", + "inputs": [ + "y_onehot", + "y_pred_probas" + ] + } + ], + "show_examples": false, + "pytest_max_batches": 2, + "validation_patience": 5, + "val_every_n_epochs": 1, + "log_every_n_epochs": 1, + "evaluation_targets": [ + "train", + "valid", + "test" + ], + "tensorboard_log_dir": "{MODEL_PATH}/" + }, + "metadata": { + "variables": { + "ROOT_PATH": "~/.deeppavlov", + "DOWNLOADS_PATH": "{ROOT_PATH}/downloads", + "MODELS_PATH": "{ROOT_PATH}/models", + "MODEL_PATH": "{MODELS_PATH}/classifiers/sentiment_imdb_bert_v0/" + }, + "requirements": [ + "{DEEPPAVLOV_PATH}/requirements/tf.txt", + "{DEEPPAVLOV_PATH}/requirements/bert_dp.txt" + ], + "labels": { + "telegram_utils": "IntentModel", + "server_utils": "KerasIntentModel" + }, + "download": [ + { + "url": "http://files.deeppavlov.ai/deeppavlov_data/bert/cased_L-12_H-768_A-12.zip", + "subdir": "{DOWNLOADS_PATH}/bert_models" + } + ] + } +} diff --git a/deeppavlov/configs/classifiers/sentiment_imdb_conv_bert.json b/deeppavlov/configs/classifiers/sentiment_imdb_conv_bert.json new file mode 100644 index 0000000000..b8d71d5d87 --- /dev/null +++ b/deeppavlov/configs/classifiers/sentiment_imdb_conv_bert.json @@ -0,0 +1,147 @@ +{ + "dataset_reader": { + "class_name": "imdb_reader", + "data_path": "{DOWNLOADS_PATH}/aclImdb" + }, + "dataset_iterator": { + "class_name": "basic_classification_iterator", + "seed": 42, + "split_seed": 23, + "field_to_split": "train", + "stratify": true, + "split_fields": [ + "train", + "valid" + ], + "split_proportions": [ + 0.9, + 0.1 + ] + }, + "chainer": { + "in": [ + "x" + ], + "in_y": [ + "y" + ], + "pipe": [ + { + "class_name": "bert_preprocessor", + "vocab_file": "{DOWNLOADS_PATH}/bert_models/conversational_cased_L-12_H-768_A-12/vocab.txt", + 
"do_lower_case": false, + "max_seq_length": 450, + "in": [ + "x" + ], + "out": [ + "bert_features" + ] + }, + { + "id": "classes_vocab", + "class_name": "simple_vocab", + "fit_on": [ + "y" + ], + "save_path": "{MODEL_PATH}/classes.dict", + "load_path": "{MODEL_PATH}/classes.dict", + "in": "y", + "out": "y_ids" + }, + { + "in": "y_ids", + "out": "y_onehot", + "class_name": "one_hotter", + "depth": "#classes_vocab.len", + "single_vector": true + }, + { + "class_name": "bert_classifier", + "n_classes": "#classes_vocab.len", + "return_probas": true, + "one_hot_labels": true, + "bert_config_file": "{DOWNLOADS_PATH}/bert_models/conversational_cased_L-12_H-768_A-12/bert_config.json", + "pretrained_bert": "{DOWNLOADS_PATH}/bert_models/conversational_cased_L-12_H-768_A-12/bert_model.ckpt", + "save_path": "{MODEL_PATH}/model", + "load_path": "{MODEL_PATH}/model", + "keep_prob": 0.5, + "learning_rate": 1e-05, + "learning_rate_drop_patience": 5, + "learning_rate_drop_div": 2.0, + "in": [ + "bert_features" + ], + "in_y": [ + "y_onehot" + ], + "out": [ + "y_pred_probas" + ] + }, + { + "in": "y_pred_probas", + "out": "y_pred_ids", + "class_name": "proba2labels", + "max_proba": true + }, + { + "in": "y_pred_ids", + "out": "y_pred_labels", + "ref": "classes_vocab" + } + ], + "out": [ + "y_pred_labels" + ] + }, + "train": { + "batch_size": 8, + "epochs": 100, + "metrics": [ + "f1_weighted", + "f1_macro", + "sets_accuracy", + { + "name": "roc_auc", + "inputs": [ + "y_onehot", + "y_pred_probas" + ] + } + ], + "show_examples": false, + "pytest_max_batches": 2, + "validation_patience": 5, + "val_every_n_epochs": 1, + "log_every_n_epochs": 1, + "evaluation_targets": [ + "train", + "valid", + "test" + ], + "tensorboard_log_dir": "{MODEL_PATH}/" + }, + "metadata": { + "variables": { + "ROOT_PATH": "~/.deeppavlov", + "DOWNLOADS_PATH": "{ROOT_PATH}/downloads", + "MODELS_PATH": "{ROOT_PATH}/models", + "MODEL_PATH": "{MODELS_PATH}/classifiers/sentiment_imdb_conv_bert_v0/" + }, + 
"requirements": [ + "{DEEPPAVLOV_PATH}/requirements/tf.txt", + "{DEEPPAVLOV_PATH}/requirements/bert_dp.txt" + ], + "labels": { + "telegram_utils": "IntentModel", + "server_utils": "KerasIntentModel" + }, + "download": [ + { + "url": "http://files.deeppavlov.ai/deeppavlov_data/bert/conversational_cased_L-12_H-768_A-12.tar.gz", + "subdir": "{DOWNLOADS_PATH}/bert_models" + } + ] + } +} diff --git a/deeppavlov/core/common/registry.json b/deeppavlov/core/common/registry.json index 079010b520..3e06783e0f 100644 --- a/deeppavlov/core/common/registry.json +++ b/deeppavlov/core/common/registry.json @@ -56,6 +56,7 @@ "glove": "deeppavlov.models.embedders.glove_embedder:GloVeEmbedder", "go_bot": "deeppavlov.models.go_bot.network:GoalOrientedBot", "hashing_tfidf_vectorizer": "deeppavlov.models.vectorizers.hashing_tfidf_vectorizer:HashingTfIdfVectorizer", + "imdb_reader": "deeppavlov.dataset_readers.imdb_reader:ImdbReader", "insurance_reader": "deeppavlov.dataset_readers.insurance_reader:InsuranceReader", "kb_answer_parser_wikidata": "deeppavlov.models.kbqa.kb_answer_parser_wikidata:KBAnswerParserWikidata", "kbqa_reader": "deeppavlov.dataset_readers.kbqa_reader:KBQAReader", diff --git a/deeppavlov/core/data/utils.py b/deeppavlov/core/data/utils.py index 10e5ef3592..cf65bfe532 100644 --- a/deeppavlov/core/data/utils.py +++ b/deeppavlov/core/data/utils.py @@ -192,6 +192,11 @@ def download_decompress(url: str, download_path: [Path, str], extract_paths=None extracted = extracted_path.exists() if not extracted and not arch_file_path.exists(): simple_download(url, arch_file_path) + else: + if extracted: + log.info(f'Found cached and extracted {url} in {extracted_path}') + else: + log.info(f'Found cached {url} in {arch_file_path}') else: arch_file_path = download_path / file_name simple_download(url, arch_file_path) diff --git a/deeppavlov/dataset_readers/imdb_reader.py b/deeppavlov/dataset_readers/imdb_reader.py new file mode 100644 index 0000000000..d9af6fa72e --- /dev/null +++ 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from logging import getLogger
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple

from overrides import overrides

from deeppavlov.core.common.registry import register
from deeppavlov.core.data.dataset_reader import DatasetReader
from deeppavlov.core.data.utils import download_decompress, mark_done, is_done

log = getLogger(__name__)


@register('imdb_reader')
class ImdbReader(DatasetReader):
    """This class downloads and reads the IMDb sentiment classification dataset.

    https://ai.stanford.edu/~amaas/data/sentiment/

    Andrew L. Maas, Raymond E. Daly, Peter T. Pham, Dan Huang, Andrew Y. Ng, and Christopher Potts.
    (2011). Learning Word Vectors for Sentiment Analysis. The 49th Annual Meeting of the Association
    for Computational Linguistics (ACL 2011).
    """

    @overrides
    def read(self, data_path: str, url: Optional[str] = None,
             *args, **kwargs) -> Dict[str, List[Tuple[Any, Any]]]:
        """Read the ``train`` and ``test`` parts of the dataset from disk.

        Args:
            data_path: A path to a folder with dataset files.
            url: A url to the archive with the dataset to download if the data folder is empty.
                When ``None``, the Stanford ``aclImdb_v1.tar.gz`` archive url is used.

        Returns:
            A dictionary with the keys ``"train"`` and ``"test"``. Each value is a list of
            ``(text, [label])`` tuples, where ``label`` is ``"neg"`` or ``"pos"``.

        Raises:
            RuntimeError: If a ``<data_type>/<label>`` directory does not exist, or if no
                ``*.txt`` files were found for a data type.
        """
        data_path = Path(data_path)

        if url is None:
            url = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"

        # Download and unpack only when `is_done` does not report this folder as done.
        if not is_done(data_path):
            log.info(f'[downloading data from {url} to {data_path}]')
            download_decompress(url, data_path)
            mark_done(data_path)

        # The files may sit in a nested "aclImdb" folder inside `data_path`;
        # prefer that folder when it exists.
        alternative_data_path = data_path / "aclImdb"
        if alternative_data_path.exists():
            data_path = alternative_data_path

        data = {"train": [],
                "test": []}
        for data_type, samples in data.items():
            for label in ["neg", "pos"]:
                labelpath = data_path / data_type / label
                if not labelpath.exists():
                    raise RuntimeError(f"Cannot load data: {labelpath} does not exist")
                # Directory listing order depends on the filesystem. Sort it so the sample
                # order (and therefore any seeded split or shuffle downstream) is reproducible.
                for filename in sorted(labelpath.glob("*.txt")):
                    samples.append((filename.read_text(encoding='utf-8'), [label]))

            if not samples:
                raise RuntimeError(f"Could not load the '{data_type}' dataset, "
                                   "probably data dirs are empty")

        return data