diff --git a/deeppavlov/_meta.py b/deeppavlov/_meta.py index 794f709ee1..8cd67ba2dd 100644 --- a/deeppavlov/_meta.py +++ b/deeppavlov/_meta.py @@ -1,4 +1,4 @@ -__version__ = '1.1.1' +__version__ = '1.2.0' __author__ = 'Neural Networks and Deep Learning lab, MIPT' __description__ = 'An open source library for building end-to-end dialog systems and training chatbots.' __keywords__ = ['NLP', 'NER', 'SQUAD', 'Intents', 'Chatbot'] diff --git a/deeppavlov/configs/classifiers/glue/glue_cola_cased_bert_torch.json b/deeppavlov/configs/classifiers/glue/glue_cola_roberta.json similarity index 88% rename from deeppavlov/configs/classifiers/glue/glue_cola_cased_bert_torch.json rename to deeppavlov/configs/classifiers/glue/glue_cola_roberta.json index e326765368..b93c8b2ec3 100644 --- a/deeppavlov/configs/classifiers/glue/glue_cola_cased_bert_torch.json +++ b/deeppavlov/configs/classifiers/glue/glue_cola_roberta.json @@ -1,8 +1,8 @@ { "dataset_reader": { "class_name": "huggingface_dataset_reader", - "path": "glue", - "name": "cola", + "path": "{COMPETITION}", + "name": "{TASK}", "train": "train", "valid": "validation", "test": "test" @@ -120,11 +120,19 @@ }, "metadata": { "variables": { + "BASE_MODEL": "roberta-large", "ROOT_PATH": "~/.deeppavlov", "DOWNLOADS_PATH": "{ROOT_PATH}/downloads", "MODELS_PATH": "{ROOT_PATH}/models", - "MODEL_PATH": "{MODELS_PATH}/classifiers/glue_cola_torch_cased_bert", - "BASE_MODEL": "bert-base-cased" - } + "COMPETITION": "glue", + "TASK": "cola", + "MODEL_PATH": "{MODELS_PATH}/{COMPETITION}/{TASK}/{BASE_MODEL}" + }, + "download": [ + { + "url": "http://files.deeppavlov.ai/v1/glue/glue_cola_roberta.tar.gz", + "subdir": "{MODEL_PATH}" + } + ] } } diff --git a/deeppavlov/configs/classifiers/glue/glue_mrpc_cased_bert_torch.json b/deeppavlov/configs/classifiers/glue/glue_mrpc_roberta.json similarity index 63% rename from deeppavlov/configs/classifiers/glue/glue_mrpc_cased_bert_torch.json rename to deeppavlov/configs/classifiers/glue/glue_mrpc_roberta.json index 990ca20dbc..3561695076 100644 --- a/deeppavlov/configs/classifiers/glue/glue_mrpc_cased_bert_torch.json +++ b/deeppavlov/configs/classifiers/glue/glue_mrpc_roberta.json @@ -1,8 +1,8 @@ { "dataset_reader": { "class_name": "huggingface_dataset_reader", - "path": "glue", - "name": "mrpc", + "path": "{COMPETITION}", + "name": "{TASK}", "train": "train", "valid": "validation", "test": "test" @@ -11,38 +11,46 @@ "class_name": "huggingface_dataset_iterator", "features": ["sentence1", "sentence2"], "label": "label", - "use_label_name": false, "seed": 42 }, "chainer": { "in": ["sentence1", "sentence2"], - "in_y": ["y_ids"], + "in_y": ["y"], "pipe": [ { "class_name": "torch_transformers_preprocessor", "vocab_file": "{BASE_MODEL}", "do_lower_case": false, - "max_seq_length": 100, + "max_seq_length": 256, "in": ["sentence1", "sentence2"], "out": ["bert_features"] }, + { + "id": "classes_vocab", + "class_name": "simple_vocab", + "fit_on": ["y"], + "save_path": "{MODEL_PATH}/classes.dict", + "load_path": "{MODEL_PATH}/classes.dict", + "in": ["y"], + "out": ["y_ids"] + }, { "in": ["y_ids"], "out": ["y_onehot"], "class_name": "one_hotter", - "depth": 2, + "depth": "#classes_vocab.len", "single_vector": true }, { "class_name": "torch_transformers_classifier", - "n_classes": 2, + "n_classes": "#classes_vocab.len", "return_probas": true, "pretrained_bert": "{BASE_MODEL}", "save_path": "{MODEL_PATH}/model", "load_path": "{MODEL_PATH}/model", "optimizer": "AdamW", "optimizer_parameters": { - "lr": 2e-05 + "lr": 1e-06 }, 
"learning_rate_drop_patience": 3, "learning_rate_drop_div": 2.0, @@ -55,32 +63,42 @@ "out": ["y_pred_ids"], "class_name": "proba2labels", "max_proba": true + }, + { + "in": ["y_pred_ids"], + "out": ["y_pred_labels"], + "ref": "classes_vocab" } ], - "out": ["y_pred_ids"] + "out": ["y_pred_labels"] }, "train": { - "batch_size": 100, - "metrics": [ - "f1", - "accuracy" - ], - "validation_patience": 10, + "batch_size": 4, + "metrics": ["accuracy"], + "epochs": 2, "val_every_n_epochs": 1, "log_every_n_epochs": 1, "show_examples": false, - "evaluation_targets": ["train", "valid"], + "evaluation_targets": ["valid"], "class_name": "torch_trainer", "tensorboard_log_dir": "{MODEL_PATH}/", "pytest_max_batches": 2 }, "metadata": { "variables": { + "BASE_MODEL": "roberta-large", "ROOT_PATH": "~/.deeppavlov", "DOWNLOADS_PATH": "{ROOT_PATH}/downloads", "MODELS_PATH": "{ROOT_PATH}/models", - "MODEL_PATH": "{MODELS_PATH}/classifiers/glue_mrpc_torch_cased_bert", - "BASE_MODEL": "bert-base-cased" - } + "COMPETITION": "glue", + "TASK": "mrpc", + "MODEL_PATH": "{MODELS_PATH}/{COMPETITION}/{TASK}/{BASE_MODEL}" + }, + "download": [ + { + "url": "http://files.deeppavlov.ai/v1/glue/glue_mrpc_roberta.tar.gz", + "subdir": "{MODEL_PATH}" + } + ] } } diff --git a/deeppavlov/configs/classifiers/glue/glue_qnli_cased_bert_torch.json b/deeppavlov/configs/classifiers/glue/glue_qnli_roberta.json similarity index 86% rename from deeppavlov/configs/classifiers/glue/glue_qnli_cased_bert_torch.json rename to deeppavlov/configs/classifiers/glue/glue_qnli_roberta.json index f4dcc1cd5c..deba445a2b 100644 --- a/deeppavlov/configs/classifiers/glue/glue_qnli_cased_bert_torch.json +++ b/deeppavlov/configs/classifiers/glue/glue_qnli_roberta.json @@ -1,8 +1,8 @@ { "dataset_reader": { "class_name": "huggingface_dataset_reader", - "path": "glue", - "name": "qnli", + "path": "{COMPETITION}", + "name": "{TASK}", "train": "train", "valid": "validation", "test": "test" @@ -86,11 +86,19 @@ }, "metadata": { "variables": { + "BASE_MODEL": "roberta-large", "ROOT_PATH": "~/.deeppavlov", "DOWNLOADS_PATH": "{ROOT_PATH}/downloads", "MODELS_PATH": "{ROOT_PATH}/models", - "MODEL_PATH": "{MODELS_PATH}/classifiers/glue_qnli_torch_cased_bert", - "BASE_MODEL": "bert-base-cased" - } + "COMPETITION": "glue", + "TASK": "qnli", + "MODEL_PATH": "{MODELS_PATH}/{COMPETITION}/{TASK}/{BASE_MODEL}" + }, + "download": [ + { + "url": "http://files.deeppavlov.ai/v1/glue/glue_qnli_roberta.tar.gz", + "subdir": "{MODEL_PATH}" + } + ] } } diff --git a/deeppavlov/configs/classifiers/glue/glue_qqp_cased_bert_torch.json b/deeppavlov/configs/classifiers/glue/glue_qqp_roberta.json similarity index 85% rename from deeppavlov/configs/classifiers/glue/glue_qqp_cased_bert_torch.json rename to deeppavlov/configs/classifiers/glue/glue_qqp_roberta.json index 4a2117c9fc..d7b8dbe6f5 100644 --- a/deeppavlov/configs/classifiers/glue/glue_qqp_cased_bert_torch.json +++ b/deeppavlov/configs/classifiers/glue/glue_qqp_roberta.json @@ -1,8 +1,8 @@ { "dataset_reader": { "class_name": "huggingface_dataset_reader", - "path": "glue", - "name": "qqp", + "path": "{COMPETITION}", + "name": "{TASK}", "train": "train", "valid": "validation", "test": "test" @@ -76,11 +76,19 @@ }, "metadata": { "variables": { + "BASE_MODEL": "roberta-large", "ROOT_PATH": "~/.deeppavlov", "DOWNLOADS_PATH": "{ROOT_PATH}/downloads", "MODELS_PATH": "{ROOT_PATH}/models", - "MODEL_PATH": "{MODELS_PATH}/classifiers/glue_qqp_torch_cased_bert", - "BASE_MODEL" : "bert-base-cased" - } + "COMPETITION": "glue", + "TASK": "qqp", + 
"MODEL_PATH": "{MODELS_PATH}/{COMPETITION}/{TASK}/{BASE_MODEL}" + }, + "download": [ + { + "url": "http://files.deeppavlov.ai/v1/glue/glue_qqp_roberta.tar.gz", + "subdir": "{MODEL_PATH}" + } + ] } } diff --git a/deeppavlov/configs/classifiers/glue/glue_sst2_cased_bert_torch.json b/deeppavlov/configs/classifiers/glue/glue_sst2_roberta.json similarity index 88% rename from deeppavlov/configs/classifiers/glue/glue_sst2_cased_bert_torch.json rename to deeppavlov/configs/classifiers/glue/glue_sst2_roberta.json index 72feec4ee2..a9e4264fa6 100644 --- a/deeppavlov/configs/classifiers/glue/glue_sst2_cased_bert_torch.json +++ b/deeppavlov/configs/classifiers/glue/glue_sst2_roberta.json @@ -1,8 +1,8 @@ { "dataset_reader": { "class_name": "huggingface_dataset_reader", - "path": "glue", - "name": "sst2", + "path": "{COMPETITION}", + "name": "{TASK}", "train": "train", "valid": "validation", "test": "test" @@ -120,11 +120,19 @@ }, "metadata": { "variables": { + "BASE_MODEL": "roberta-large", "ROOT_PATH": "~/.deeppavlov", "DOWNLOADS_PATH": "{ROOT_PATH}/downloads", "MODELS_PATH": "{ROOT_PATH}/models", - "MODEL_PATH": "{MODELS_PATH}/classifiers/glue_sst2_torch_cased_bert", - "BASE_MODEL": "bert-base-cased" - } + "COMPETITION": "glue", + "TASK": "sst2", + "MODEL_PATH": "{MODELS_PATH}/{COMPETITION}/{TASK}/{BASE_MODEL}" + }, + "download": [ + { + "url": "http://files.deeppavlov.ai/v1/glue/glue_sst2_roberta.tar.gz", + "subdir": "{MODEL_PATH}" + } + ] } } diff --git a/deeppavlov/configs/classifiers/glue/glue_stsb_cased_bert_torch.json b/deeppavlov/configs/classifiers/glue/glue_stsb_roberta.json similarity index 81% rename from deeppavlov/configs/classifiers/glue/glue_stsb_cased_bert_torch.json rename to deeppavlov/configs/classifiers/glue/glue_stsb_roberta.json index 701c70b30d..a1ead5a93a 100644 --- a/deeppavlov/configs/classifiers/glue/glue_stsb_cased_bert_torch.json +++ b/deeppavlov/configs/classifiers/glue/glue_stsb_roberta.json @@ -1,8 +1,8 @@ { "dataset_reader": { "class_name": "huggingface_dataset_reader", - "path": "glue", - "name": "stsb", + "path": "{COMPETITION}", + "name": "{TASK}", "train": "train", "valid": "validation", "test": "test" @@ -47,7 +47,7 @@ "out": ["y_pred"] }, "train": { - "batch_size": 128, + "batch_size": 32, "metrics": [ "pearson_correlation", "spearman_correlation" @@ -63,11 +63,19 @@ }, "metadata": { "variables": { + "BASE_MODEL": "roberta-large", "ROOT_PATH": "~/.deeppavlov", "DOWNLOADS_PATH": "{ROOT_PATH}/downloads", "MODELS_PATH": "{ROOT_PATH}/models", - "MODEL_PATH": "{MODELS_PATH}/classifiers/glue_stsb_torch_cased_bert", - "BASE_MODEL": "bert-base-cased" - } + "COMPETITION": "glue", + "TASK": "stsb", + "MODEL_PATH": "{MODELS_PATH}/{COMPETITION}/{TASK}/{BASE_MODEL}" + }, + "download": [ + { + "url": "http://files.deeppavlov.ai/v1/glue/glue_stsb_roberta.tar.gz", + "subdir": "{MODEL_PATH}" + } + ] } } diff --git a/deeppavlov/configs/entity_extraction/entity_detection_en.json b/deeppavlov/configs/entity_extraction/entity_detection_en.json index 2c45fc8703..d4e936ec65 100644 --- a/deeppavlov/configs/entity_extraction/entity_detection_en.json +++ b/deeppavlov/configs/entity_extraction/entity_detection_en.json @@ -5,17 +5,15 @@ { "class_name": "ner_chunker", "batch_size": 16, - "max_chunk_len" : 180, "max_seq_len" : 300, "vocab_file": "{TRANSFORMER}", "in": ["x"], "out": ["x_chunk", "chunk_nums", "chunk_sentences_offsets", "chunk_sentences"] }, { - "thres_proba": 0.05, + "thres_proba": 0.6, "o_tag": "O", "tags_file": "{NER_PATH}/tag.dict", - "return_entities_with_tags": true, 
"class_name": "entity_detection_parser", "id": "edp" }, diff --git a/deeppavlov/configs/entity_extraction/entity_detection_ru.json b/deeppavlov/configs/entity_extraction/entity_detection_ru.json index 5ff48e3fd7..3c0f08a417 100644 --- a/deeppavlov/configs/entity_extraction/entity_detection_ru.json +++ b/deeppavlov/configs/entity_extraction/entity_detection_ru.json @@ -5,7 +5,6 @@ { "class_name": "ner_chunker", "batch_size": 16, - "max_chunk_len" : 180, "max_seq_len" : 300, "vocab_file": "{TRANSFORMER}", "in": ["x"], @@ -15,7 +14,6 @@ "thres_proba": 0.05, "o_tag": "O", "tags_file": "{NER_PATH}/tag.dict", - "return_entities_with_tags": true, "class_name": "entity_detection_parser", "id": "edp" }, diff --git a/deeppavlov/configs/entity_extraction/entity_extraction_en.json b/deeppavlov/configs/entity_extraction/entity_extraction_en.json index 188568dcb6..dc0604c614 100644 --- a/deeppavlov/configs/entity_extraction/entity_extraction_en.json +++ b/deeppavlov/configs/entity_extraction/entity_extraction_en.json @@ -9,11 +9,11 @@ }, { "config_path": "{CONFIGS_PATH}/entity_extraction/entity_linking_en.json", - "in": ["entity_substr", "tags", "sentences", "entity_offsets", "sentences_offsets"], - "out": ["entity_ids", "entity_conf", "entity_pages"] + "in": ["entity_substr", "tags", "probas", "sentences", "entity_offsets", "sentences_offsets"], + "out": ["entity_ids", "entity_conf", "entity_pages", "entity_labels"] } ], - "out": ["entity_substr", "tags", "entity_offsets", "entity_ids", "entity_conf", "entity_pages"] + "out": ["entity_substr", "tags", "entity_offsets", "entity_ids", "entity_conf", "entity_pages", "entity_labels"] }, "metadata": { "variables": { diff --git a/deeppavlov/configs/entity_extraction/entity_extraction_ru.json b/deeppavlov/configs/entity_extraction/entity_extraction_ru.json index 941a59a65f..cf0b8fb69a 100644 --- a/deeppavlov/configs/entity_extraction/entity_extraction_ru.json +++ b/deeppavlov/configs/entity_extraction/entity_extraction_ru.json @@ -9,11 +9,11 @@ }, { "config_path": "{CONFIGS_PATH}/entity_extraction/entity_linking_ru.json", - "in": ["entity_substr", "tags", "sentences", "entity_offsets", "sentences_offsets"], - "out": ["entity_ids", "entity_conf", "entity_pages"] + "in": ["entity_substr", "tags", "probas", "sentences", "entity_offsets", "sentences_offsets"], + "out": ["entity_ids", "entity_conf", "entity_pages", "entity_labels"] } ], - "out": ["entity_substr", "tags", "entity_offsets", "entity_ids", "entity_conf", "entity_pages"] + "out": ["entity_substr", "tags", "entity_offsets", "entity_ids", "entity_conf", "entity_pages", "entity_labels"] }, "metadata": { "variables": { diff --git a/deeppavlov/configs/entity_extraction/entity_linking_en.json b/deeppavlov/configs/entity_extraction/entity_linking_en.json index 9faeac0d8b..b0a7cec606 100644 --- a/deeppavlov/configs/entity_extraction/entity_linking_en.json +++ b/deeppavlov/configs/entity_extraction/entity_linking_en.json @@ -1,6 +1,6 @@ { "chainer": { - "in": ["entity_substr", "tags", "sentences", "entity_offsets", "sentences_offsets"], + "in": ["entity_substr", "tags", "probas", "sentences", "entity_offsets", "sentences_offsets"], "pipe": [ { "class_name": "torch_transformers_entity_ranker_infer", @@ -14,10 +14,10 @@ }, { "class_name": "entity_linker", - "in": ["entity_substr", "tags", "sentences", "entity_offsets", "sentences_offsets"], - "out": ["entity_ids", "entity_conf", "entity_pages"], + "in": ["entity_substr", "tags", "probas", "sentences", "entity_offsets", "sentences_offsets"], + "out": 
["entity_ids", "entity_conf", "entity_pages", "entity_labels"], "load_path": "{DOWNLOADS_PATH}/entity_linking_eng", - "entities_database_filename": "el_eng.db", + "entities_database_filename": "el_eng_v2.db", "entity_ranker": "#entity_descr_ranking", "rank_in_runtime": true, "num_entities_for_bert_ranking": 20, @@ -26,15 +26,14 @@ "num_entities_to_return": 3, "lemmatize": true, "use_descriptions": true, - "wikidata_file": "{DOWNLOADS_PATH}/wikidata/wikidata_lite.hdt", "use_connections": true, "use_tags": true, "full_paragraph": true, "return_confidences": true, - "lang": "ru" + "lang": "en" } ], - "out": ["entity_ids", "entity_conf", "entity_pages"] + "out": ["entity_ids", "entity_conf", "entity_pages", "entity_labels"] }, "metadata": { "variables": { @@ -45,16 +44,12 @@ }, "download": [ { - "url": "http://files.deeppavlov.ai/deeppavlov_data/entity_linking/el_db_eng.tar.gz", + "url": "http://files.deeppavlov.ai/kbqa/downloads/el_db_eng_v2.tar.gz", "subdir": "{DOWNLOADS_PATH}/entity_linking_eng" }, { "url": "http://files.deeppavlov.ai/deeppavlov_data/entity_linking/el_ranker_eng.tar.gz", "subdir": "{MODELS_PATH}/entity_linking_eng" - }, - { - "url": "http://files.deeppavlov.ai/kbqa/wikidata/wikidata_lite.tar.gz", - "subdir": "{DOWNLOADS_PATH}/wikidata" } ] } diff --git a/deeppavlov/configs/entity_extraction/entity_linking_ru.json b/deeppavlov/configs/entity_extraction/entity_linking_ru.json index 4b8589710c..513edab1d3 100644 --- a/deeppavlov/configs/entity_extraction/entity_linking_ru.json +++ b/deeppavlov/configs/entity_extraction/entity_linking_ru.json @@ -1,6 +1,6 @@ { "chainer": { - "in": ["entity_substr", "tags", "sentences", "entity_offsets", "sentences_offsets"], + "in": ["entity_substr", "tags", "probas", "sentences", "entity_offsets", "sentences_offsets"], "pipe": [ { "class_name": "torch_transformers_entity_ranker_infer", @@ -14,13 +14,15 @@ }, { "class_name": "entity_linker", - "in": ["entity_substr", "tags", "sentences", "entity_offsets", "sentences_offsets"], - "out": ["entity_ids", "entity_conf", "entity_pages"], + "in": ["entity_substr", "tags", "probas", "sentences", "entity_offsets", "sentences_offsets"], + "out": ["entity_ids", "entity_conf", "entity_pages", "entity_labels"], "load_path": "{DOWNLOADS_PATH}/entity_linking_rus", - "entities_database_filename": "el_rus.db", + "entities_database_filename": "el_rus_v2.db", + "words_dict_filename": "{DOWNLOADS_PATH}/entity_linking_rus/words_dict.pickle", + "ngrams_matrix_filename": "{DOWNLOADS_PATH}/entity_linking_rus/ngrams_matrix.npz", "entity_ranker": "#entity_descr_ranking", "rank_in_runtime": true, - "num_entities_for_bert_ranking": 20, + "num_entities_for_bert_ranking": 30, "use_gpu": false, "include_mention": false, "num_entities_to_return": 3, @@ -28,13 +30,20 @@ "use_descriptions": true, "use_connections": true, "use_tags": true, - "wikidata_file": "{DOWNLOADS_PATH}/wikidata/wikidata_lite.hdt", + "kb_filename": "{DOWNLOADS_PATH}/wikidata/wikidata_lite.hdt", + "prefixes": {"entity": ["http://we"], + "rels": {"direct": "http://wpd", + "no_type": "http://wp", + "statement": "http://wps", + "qualifier": "http://wpq" + } + }, "full_paragraph": true, "return_confidences": true, "lang": "ru" } ], - "out": ["entity_ids", "entity_conf", "entity_pages"] + "out": ["entity_ids", "entity_conf", "entity_pages", "entity_labels"] }, "metadata": { "variables": { @@ -45,7 +54,7 @@ }, "download": [ { - "url": "http://files.deeppavlov.ai/deeppavlov_data/entity_linking/el_db_rus.tar.gz", + "url": 
"http://files.deeppavlov.ai/kbqa/downloads/el_files_rus_v2.tar.gz", "subdir": "{DOWNLOADS_PATH}/entity_linking_rus" }, { diff --git a/deeppavlov/configs/kbqa/kbqa_cq_en.json b/deeppavlov/configs/kbqa/kbqa_cq_en.json index 3c18230646..3db7403cd6 100644 --- a/deeppavlov/configs/kbqa/kbqa_cq_en.json +++ b/deeppavlov/configs/kbqa/kbqa_cq_en.json @@ -1,7 +1,17 @@ { + "dataset_reader": { + "class_name": "lcquad_reader", + "question_types": ["statement_property", "right-subgraph", "simple question left", + "simple question right", "left-subgraph", "rank"], + "num_samples": 100, + "data_path": "{DOWNLOADS_PATH}/lcquad/lcquad2.json" + }, + "dataset_iterator": { + "class_name": "data_learning_iterator" + }, "chainer": { "in": ["x"], - "in_y": ["y"], + "in_y": ["gold_answer_ids", "gold_answer_labels", "gold_query"], "pipe": [ { "class_name": "question_sign_checker", @@ -9,9 +19,38 @@ "out": ["x_punct"] }, { - "config_path": "{CONFIGS_PATH}/entity_extraction/entity_detection_en.json", + "config_path": "{CONFIGS_PATH}/classifiers/query_pr.json", "in": ["x_punct"], - "out": ["entity_substr", "entity_offsets", "entity_positions", "tags", "sentences_offsets", "sentences", "probas"] + "out": ["template_type"] + }, + { + "class_name": "query_formatter", + "query_info": {"unk_var": "?answer", "mid_var": "?ent"}, + "in": ["gold_query"], + "out": ["f_gold_query"] + }, + { + "config_path": "{CONFIGS_PATH}/entity_extraction/entity_detection_en.json", + "overwrite": { + "chainer.pipe.1.make_tags_from_probas": true, + "chainer.pipe.2.ner": { + "config_path": "{CONFIGS_PATH}/ner/ner_ontonotes_bert.json", + "overwrite": { + "chainer.out": ["x_tokens", "tokens_offsets", "y_pred", "probas"], + "chainer.pipe.2.use_crf": false, + "metadata.variables.TRANSFORMER": "distilbert-base-cased", + "metadata.variables.MODEL_PATH": "{MODELS_PATH}/entity_type_detection_distilbert_lcquad2.0" + } + }, + "metadata.variables.NER_PATH": "{MODELS_PATH}/entity_type_detection_distilbert_lcquad2.0" + }, + "in": ["x_punct", "template_type"], + "out": ["entity_type_substr", "entity_offsets", "entity_positions", "tags", "sentences_offsets", "sentences", "probas"] + }, + { + "class_name": "entity_type_split", + "in": ["entity_type_substr", "tags"], + "out": ["entity_substr", "entity_tags", "type_substr"] }, { "class_name": "answer_types_extractor", @@ -22,13 +61,30 @@ "out": ["answer_types", "f_entity_substr", "f_tags"] }, { - "config_path": "{CONFIGS_PATH}/entity_extraction/entity_linking_en.json", + "class_name": "entity_linker", + "load_path": "{DOWNLOADS_PATH}/entity_linking_eng", + "entities_database_filename": "el_db_lcquad2.db", + "num_entities_to_return": 7, + "lemmatize": true, + "use_descriptions": false, + "use_connections": false, + "use_tags": true, + "alias_coef": 1.0, + "prefixes": {"entity": ["http://we"], + "rels": {"direct": "http://wpd", + "no_type": "http://wp", + "statement": "http://wps", + "qualifier": "http://wpq" + } + }, + "return_confidences": true, + "lang": "en", "id": "entity_linker" }, { "class_name": "wiki_parser", "id": "wiki_p", - "wiki_filename": "{DOWNLOADS_PATH}/wikidata/wikidata_lite.hdt", + "wiki_filename": "{DOWNLOADS_PATH}/wikidata/wikidata_full.hdt", "lang": "@en" }, { @@ -38,19 +94,14 @@ "load_path": "{DOWNLOADS_PATH}/wikidata_eng", "templates_filename": "templates_eng.json" }, - { - "config_path": "{CONFIGS_PATH}/classifiers/query_pr.json", - "in": ["x_punct"], - "out": ["template_type"] - }, { "class_name": "rel_ranking_infer", "id": "rel_r_inf", - "ranker": {"config_path": 
"{CONFIGS_PATH}/ranking/rel_ranking_bert_en.json"}, + "ranker": {"config_path": "{CONFIGS_PATH}/ranking/rel_ranking_roberta_en.json", + "overwrite": {"chainer.out": ["y_pred_probas"]} + }, "wiki_parser": "#wiki_p", "batch_size": 32, - "return_all_possible_answers": true, - "return_answer_ids": false, "rank_answers": true, "load_path": "{DOWNLOADS_PATH}/wikidata_eng", "rel_q2name_filename": "wiki_dict_properties_eng.pickle" @@ -63,26 +114,71 @@ "rel_ranker": "#rel_r_inf", "wiki_parser": "#wiki_p", "load_path": "{DOWNLOADS_PATH}/wikidata", - "rank_rels_filename_1": "rels_0.txt", - "rank_rels_filename_2": "rels_1.txt", - "sparql_queries_filename": "{DOWNLOADS_PATH}/wikidata/sparql_queries.json", + "rels_in_ranking_queries_fname": "rels_in_ranking_queries.json", + "sparql_queries_filename": "{DOWNLOADS_PATH}/wikidata/sparql_queries_eng.json", "entities_to_leave": 5, "rels_to_leave": 10, - "in": ["x_punct", "x_punct", "template_type", "f_entity_substr", "f_tags", "answer_types"], - "out": ["answers"] + "return_answers": false, + "map_query_str_to_kb": [["P0", "http://wd"], ["P00", "http://wl"], ["wd:", "http://we/"], ["wdt:", "http://wpd/"], + [" p:", " http://wp/"], ["ps:", "http://wps/"], ["pq:", "http://wpq/"]], + "kb_prefixes": {"entity": "wd:E", "rel": "wdt:R", "type": "wd:T", "type_rel": "wdt:P", "type_rels": ["P31", "P279"]}, + "gold_query_info": {"unk_var": "?answer", "mid_var": "?ent"}, + "in": ["x_punct", "x_punct", "template_type", "entity_substr", "type_substr", "entity_tags", "probas", "answer_types"], + "out": ["cand_answers", "template_answers"] + }, + { + "class_name": "rel_ranking_infer", + "ranker": {"config_path": "{CONFIGS_PATH}/ranking/path_ranking_nll_roberta_en.json"}, + "wiki_parser": "#wiki_p", + "batch_size": 32, + "nll_path_ranking": true, + "return_elements": ["answer_ids", "queries"], + "rank_answers": true, + "load_path": "{DOWNLOADS_PATH}/wikidata_eng", + "rel_q2name_filename": "wiki_dict_properties_eng.pickle", + "in": ["x_punct", "template_type", "cand_answers", "entity_substr", "template_answers"], + "out": ["answers", "answer_ids", "query"] + } + ], + "out": ["answers", "answer_ids", "query"] + }, + "train": { + "evaluation_targets": ["test"], + "batch_size": 1, + "metrics": [ + { + "name": "kbqa_accuracy", + "inputs": ["x", "answers", "answer_ids", "query", "gold_answer_labels", "gold_answer_ids", "f_gold_query"] } ], - "out": ["answers"] + "class_name": "nn_trainer" }, "metadata": { "variables": { "ROOT_PATH": "~/.deeppavlov", "DOWNLOADS_PATH": "{ROOT_PATH}/downloads", + "MODELS_PATH": "{ROOT_PATH}/models", "CONFIGS_PATH": "{DEEPPAVLOV_PATH}/configs" }, "download": [ { - "url": "http://files.deeppavlov.ai/kbqa/wikidata/queries_and_rels.tar.gz", + "url": "http://files.deeppavlov.ai/kbqa/datasets/lcquad2.tar.gz", + "subdir": "{DOWNLOADS_PATH}/lcquad" + }, + { + "url": "http://files.deeppavlov.ai/kbqa/models/entity_type_detection_distilbert_lcquad2.0.tar.gz", + "subdir": "{MODELS_PATH}/entity_type_detection_distilbert_lcquad2.0" + }, + { + "url": "http://files.deeppavlov.ai/kbqa/wikidata/queries_and_rels_lcquad2_v2.tar.gz", + "subdir": "{DOWNLOADS_PATH}/wikidata" + }, + { + "url": "http://files.deeppavlov.ai/kbqa/downloads/el_db_lcquad2.tar.gz", + "subdir": "{DOWNLOADS_PATH}/entity_linking_eng" + }, + { + "url": "http://files.deeppavlov.ai/kbqa/wikidata/wikidata_full.tar.gz", "subdir": "{DOWNLOADS_PATH}/wikidata" }, { diff --git a/deeppavlov/configs/kbqa/kbqa_cq_ru.json b/deeppavlov/configs/kbqa/kbqa_cq_ru.json index 85f14eed91..0e899baec2 100644 --- 
a/deeppavlov/configs/kbqa/kbqa_cq_ru.json +++ b/deeppavlov/configs/kbqa/kbqa_cq_ru.json @@ -1,16 +1,79 @@ { + "dataset_reader": { + "class_name": "rubq_reader", + "version": "2.0", + "question_types": ["all"], + "num_samples": 100, + "data_path": "{DOWNLOADS_PATH}/rubq/rubq2.0.json" + }, + "dataset_iterator": { + "class_name": "data_learning_iterator" + }, "chainer": { "in": ["x"], - "in_y": ["y"], + "in_y": ["gold_answer_ids", "gold_answer_labels", "gold_query"], "pipe": [ { "class_name": "question_sign_checker", + "delete_brackets": true, "in": ["x"], "out": ["x_punct"] }, { - "config_path": "{CONFIGS_PATH}/entity_extraction/entity_detection_ru.json", + "class_name": "query_formatter", + "query_info": {"unk_var": "?answer", "mid_var": "?ent"}, + "in": ["gold_query"], + "out": ["f_gold_query"] + }, + { + "class_name": "ner_chunker", + "batch_size": 16, + "max_seq_len" : 300, + "vocab_file": "distilbert-base-multilingual-cased", "in": ["x_punct"], + "out": ["x_chunk", "chunk_nums", "chunk_sentences_offsets", "chunk_sentences"] + }, + { + "thres_proba": 0.05, + "o_tag": "O", + "tags_file": "{NER_PATH}/tag.dict", + "class_name": "entity_detection_parser", + "ignored_tags": ["DATE", "CARDINAL", "ORDINAL", "QUANTITY", "PERCENT", "NORP"], + "lang": "ru", + "id": "edp" + }, + { + "thres_proba": 0.05, + "o_tag": "O", + "tags_file": "{NER_PATH2}/tag.dict", + "class_name": "entity_detection_parser", + "ignored_tags": ["T"], + "lang": "ru", + "id": "edp2" + }, + { + "class_name": "ner_chunk_model", + "ner": { + "config_path": "{CONFIGS_PATH}/ner/ner_ontonotes_bert_mult.json", + "overwrite": { + "chainer.pipe.2.use_crf": false, + "metadata.variables.TRANSFORMER": "distilbert-base-multilingual-cased", + "chainer.out": ["x_tokens", "tokens_offsets", "y_pred", "probas"], + "metadata.variables.MODEL_PATH": "{MODELS_PATH}/ner_ontonotes_torch_distilbert_mult" + } + }, + "ner_parser": "#edp", + "ner2": { + "config_path": "{CONFIGS_PATH}/ner/ner_ontonotes_bert_mult.json", + "overwrite": { + "chainer.pipe.2.use_crf": false, + "metadata.variables.TRANSFORMER": "DeepPavlov/distilrubert-small-cased-conversational", + "chainer.out": ["x_tokens", "tokens_offsets", "y_pred", "probas"], + "metadata.variables.MODEL_PATH": "{MODELS_PATH}/entity_detection_rubq" + } + }, + "ner_parser2": "#edp2", + "in": ["x_chunk", "chunk_nums", "chunk_sentences_offsets", "chunk_sentences"], "out": ["entity_substr", "entity_offsets", "entity_positions", "tags", "sentences_offsets", "sentences", "probas"] }, { @@ -22,13 +85,34 @@ "out": ["answer_types", "f_entity_substr", "f_tags"] }, { - "config_path": "{CONFIGS_PATH}/entity_extraction/entity_linking_ru.json", + "class_name": "entity_linker", + "load_path": "{DOWNLOADS_PATH}/entity_linking_rus", + "entities_database_filename": "el_db_rus.db", + "words_dict_filename": "{DOWNLOADS_PATH}/entity_linking_rus/words_dict.pickle", + "ngrams_matrix_filename": "{DOWNLOADS_PATH}/entity_linking_rus/ngrams_matrix.npz", + "include_mention": false, + "num_entities_to_return": 7, + "lemmatize": true, + "use_descriptions": false, + "use_connections": true, + "use_tags": true, + "kb_filename": "{DOWNLOADS_PATH}/wikidata/wikidata_full.hdt", + "prefixes": {"entity": ["http://we"], + "rels": {"direct": "http://wpd", + "no_type": "http://wp", + "statement": "http://wps", + "qualifier": "http://wpq" + } + }, + "return_confidences": true, + "lang": "ru", "id": "entity_linker" }, { "class_name": "wiki_parser", "id": "wiki_p", - "wiki_filename": "{DOWNLOADS_PATH}/wikidata/wikidata_lite.hdt", + "wiki_filename": 
"{DOWNLOADS_PATH}/wikidata/wikidata_full.hdt", + "max_comb_num": 40000, "lang": "@ru" }, { @@ -36,8 +120,8 @@ "load_path": "{MODELS_PATH}/slovnet_syntax_parser", "navec_filename": "{MODELS_PATH}/slovnet_syntax_parser/navec_news_v1_1B_250K_300d_100q.tar", "syntax_parser_filename": "{MODELS_PATH}/slovnet_syntax_parser/slovnet_syntax_news_v1.tar", - "in": ["x_punct", "entity_offsets"], - "out": ["syntax_info"] + "tree_patterns_filename": "{MODELS_PATH}/slovnet_syntax_parser/tree_patterns.json", + "id": "slovnet_parser" }, { "class_name": "ru_adj_to_noun", @@ -46,11 +130,12 @@ }, { "class_name": "tree_to_sparql", - "sparql_queries_filename": "{DOWNLOADS_PATH}/wikidata/sparql_queries.json", + "sparql_queries_filename": "{DOWNLOADS_PATH}/wikidata/sparql_queries_rus.json", "adj_to_noun": "#adj2noun", - "lang": "rus", - "in": ["syntax_info", "entity_positions"], - "out": ["x_sanitized", "query_nums", "entities_dict", "types_dict"] + "syntax_parser": "#slovnet_parser", + "kb_prefixes": {"entity": "wd:E", "rel": "wdt:R", "type": "wd:T", "type_rel": "wdt:P", "type_rels": ["P31", "P279"]}, + "in": ["x_punct", "entity_substr", "tags", "entity_offsets", "entity_positions", "probas"], + "out": ["x_sanitized", "query_nums", "s_entity_substr", "s_tags", "s_probas", "entities_to_link", "s_types_substr"] }, { "class_name": "template_matcher", @@ -62,13 +147,16 @@ { "class_name": "rel_ranking_infer", "id": "rel_r_inf", - "ranker": {"config_path": "{CONFIGS_PATH}/ranking/rel_ranking_bert_ru.json"}, + "ranker": {"config_path": "{CONFIGS_PATH}/ranking/rel_ranking_nll_bert_ru.json"}, "wiki_parser": "#wiki_p", "batch_size": 32, - "return_all_possible_answers": true, - "return_answer_ids": false, + "nll_rel_ranking": true, + "return_elements": ["answer_ids", "queries"], "load_path": "{DOWNLOADS_PATH}/wikidata_rus", - "rel_q2name_filename": "wiki_dict_properties_rus.pickle" + "rank": false, + "rel_thres": -4.0, + "type_rels": ["P31", "P279"], + "rel_q2name_filename": "wiki_dict_properties_full_rus.pickle" }, { "class_name": "query_generator", @@ -78,29 +166,61 @@ "rel_ranker": "#rel_r_inf", "wiki_parser": "#wiki_p", "load_path": "{DOWNLOADS_PATH}/wikidata", - "rank_rels_filename_1": "rels_0.txt", - "rank_rels_filename_2": "rels_1.txt", - "sparql_queries_filename": "{DOWNLOADS_PATH}/wikidata/sparql_queries.json", + "rels_in_ranking_queries_fname": "rels_in_ranking_queries.json", + "sparql_queries_filename": "{DOWNLOADS_PATH}/wikidata/sparql_queries_rus.json", "entities_to_leave": 9, "rels_to_leave": 10, - "return_all_possible_answers": false, + "max_comb_num": 1000, + "map_query_str_to_kb": [["P0", "http://wd"], ["P00", "http://wl"], ["wd:", "http://we/"], ["wdt:", "http://wpd/"], + [" p:", " http://wp/"], ["ps:", "http://wps/"], ["pq:", "http://wpq/"]], + "kb_prefixes": {"entity": "wd:E", "rel": "wdt:R", "type": "wd:T", "type_rel": "wdt:P", "type_rels": ["P31", "P279"]}, + "gold_query_info": {"unk_var": "?answer", "mid_var": "?ent"}, "syntax_structure_known": true, - "in": ["x_punct", "x_sanitized", "query_nums", "f_entity_substr", "f_tags", "answer_types"], - "out": ["answers"] + "in": ["x_punct", "x_sanitized", "query_nums", "s_entity_substr", "s_types_substr", "s_tags", "s_probas", "answer_types", "entities_to_link"], + "out": ["answers", "answer_ids", "query"] } ], - "out": ["answers"] + "out": ["answers", "answer_ids", "query"] + }, + "train": { + "evaluation_targets": ["test"], + "batch_size": 1, + "metrics": [ + { + "name": "kbqa_accuracy", + "inputs": ["x", "answers", "answer_ids", "query", 
"gold_answer_labels", "gold_answer_ids", "f_gold_query"] + } + ], + "class_name": "nn_trainer" }, "metadata": { "variables": { "ROOT_PATH": "~/.deeppavlov", "DOWNLOADS_PATH": "{ROOT_PATH}/downloads", "MODELS_PATH": "{ROOT_PATH}/models", - "CONFIGS_PATH": "{DEEPPAVLOV_PATH}/configs" + "CONFIGS_PATH": "{DEEPPAVLOV_PATH}/configs", + "NER_PATH": "{MODELS_PATH}/ner_ontonotes_torch_distilbert_mult", + "NER_PATH2": "{MODELS_PATH}/entity_detection_rubq" }, "download": [ { - "url": "http://files.deeppavlov.ai/kbqa/wikidata/queries_and_rels.tar.gz", + "url": "http://files.deeppavlov.ai/datasets/rubq2.0.tar.gz", + "subdir": "{DOWNLOADS_PATH}/rubq" + }, + { + "url": "http://files.deeppavlov.ai/kbqa/downloads/el_files_rus.tar.gz", + "subdir": "{DOWNLOADS_PATH}/entity_linking_rus" + }, + { + "url": "http://files.deeppavlov.ai/kbqa/models/ner_ontonotes_torch_distilbert_mult.tar.gz", + "subdir": "{MODELS_PATH}/ner_ontonotes_torch_distilbert_mult" + }, + { + "url": "http://files.deeppavlov.ai/kbqa/models/entity_detection_rubq.tar.gz", + "subdir": "{MODELS_PATH}/entity_detection_rubq" + }, + { + "url": "http://files.deeppavlov.ai/kbqa/wikidata/queries_and_rels_rus_v2.tar.gz", "subdir": "{DOWNLOADS_PATH}/wikidata" }, { @@ -108,8 +228,12 @@ "subdir": "{DOWNLOADS_PATH}/wikidata_rus" }, { - "url": "http://files.deeppavlov.ai/deeppavlov_data/slovnet_syntax_parser.tar.gz", + "url": "http://files.deeppavlov.ai/deeppavlov_data/syntax_parser/slovnet_syntax_parser_v2.tar.gz", "subdir": "{MODELS_PATH}/slovnet_syntax_parser" + }, + { + "url": "http://files.deeppavlov.ai/kbqa/wikidata/wikidata_full.tar.gz", + "subdir": "{DOWNLOADS_PATH}/wikidata" } ] } diff --git a/deeppavlov/configs/multitask/mt_glue.json b/deeppavlov/configs/multitask/mt_glue.json index ca6e529213..e19a8a06d4 100644 --- a/deeppavlov/configs/multitask/mt_glue.json +++ b/deeppavlov/configs/multitask/mt_glue.json @@ -267,7 +267,8 @@ "log_every_n_epochs": 1, "show_examples": false, "evaluation_targets": ["valid"], - "class_name": "torch_trainer" + "class_name": "torch_trainer", + "pytest_max_batches": 2 }, "metadata": { "variables": { diff --git a/deeppavlov/configs/multitask/multitask_example.json b/deeppavlov/configs/multitask/multitask_example.json index cf8cf7ad47..04abf0ebf2 100644 --- a/deeppavlov/configs/multitask/multitask_example.json +++ b/deeppavlov/configs/multitask/multitask_example.json @@ -219,7 +219,8 @@ "log_every_n_epochs": 1, "show_examples": false, "evaluation_targets": ["valid"], - "class_name": "torch_trainer" + "class_name": "torch_trainer", + "pytest_max_batches": 2 }, "metadata": { "variables": { diff --git a/deeppavlov/configs/ranking/path_ranking_nll_roberta_en.json b/deeppavlov/configs/ranking/path_ranking_nll_roberta_en.json new file mode 100644 index 0000000000..0022e88b76 --- /dev/null +++ b/deeppavlov/configs/ranking/path_ranking_nll_roberta_en.json @@ -0,0 +1,43 @@ +{ + "chainer": { + "in": ["question", "rels"], + "pipe": [ + { + "class_name": "path_ranking_preprocessor", + "vocab_file": "{TRANSFORMER}", + "do_lower_case": false, + "additional_special_tokens": ["", "", "", "", "", "", ""], + "max_seq_length": 96, + "in": ["question", "rels"], + "out": ["bert_features"] + }, + { + "class_name": "torch_transformers_nll_ranker", + "in": ["bert_features"], + "out": ["model_output"], + "return_probas": true, + "save_path": "{MODEL_PATH}/model", + "load_path": "{MODEL_PATH}/model", + "encoder_save_path": "{MODEL_PATH}/encoder", + "linear_save_path": "{MODEL_PATH}/linear", + "model_name": "in_batch_ranking_model", + 
"pretrained_bert": "{TRANSFORMER}", + "learning_rate_drop_patience": 5, + "learning_rate_drop_div": 1.5 + } + ], + "out": ["model_output"] + }, + "metadata": { + "variables": { + "TRANSFORMER": "haisongzhang/roberta-tiny-cased", + "MODEL_PATH": "~/.deeppavlov/models/classifiers/path_ranking_nll_roberta_lcquad2" + }, + "download": [ + { + "url": "http://files.deeppavlov.ai/kbqa/models/path_ranking_nll_roberta_lcquad2.tar.gz", + "subdir": "{MODEL_PATH}" + } + ] + } +} diff --git a/deeppavlov/configs/ranking/rel_ranking_bert_ru.json b/deeppavlov/configs/ranking/rel_ranking_bert_ru.json deleted file mode 100644 index 8bc7209c03..0000000000 --- a/deeppavlov/configs/ranking/rel_ranking_bert_ru.json +++ /dev/null @@ -1,106 +0,0 @@ -{ - "dataset_reader": { - "class_name": "sq_reader", - "data_path": "{DOWNLOADS_PATH}/rel_ranking_rus/rubq_rel_ranking.pickle" - }, - "dataset_iterator": { - "class_name": "basic_classification_iterator", - "seed": 42 - }, - "chainer": { - "in": ["question", "rel_list"], - "in_y": ["y"], - "pipe": [ - { - "class_name": "rel_ranking_preprocessor", - "vocab_file": "{TRANSFORMER}", - "do_lower_case": true, - "max_seq_length": 64, - "add_special_tokens": ["", "", ""], - "in": ["question", "rel_list"], - "out": ["bert_features"] - }, - { - "id": "classes_vocab", - "class_name": "simple_vocab", - "fit_on": ["y"], - "save_path": "{MODEL_PATH}/classes.dict", - "load_path": "{MODEL_PATH}/classes.dict", - "in": ["y"], - "out": ["y_ids"] - }, - { - "in": ["y_ids"], - "out": ["y_onehot"], - "class_name": "one_hotter", - "depth": "#classes_vocab.len", - "single_vector": true - }, - { - "class_name": "torch_transformers_classifier", - "n_classes": "#classes_vocab.len", - "return_probas": "true", - "num_special_tokens": 3, - "pretrained_bert": "{TRANSFORMER}", - "save_path": "{MODEL_PATH}/model", - "load_path": "{MODEL_PATH}/model", - "optimizer": "AdamW", - "optimizer_parameters": {"lr": 1e-05}, - "learning_rate_drop_patience": 5, - "learning_rate_drop_div": 2.0, - "in": ["bert_features"], - "in_y": ["y_ids"], - "out": ["y_pred_probas"] - }, - { - "in": ["y_pred_probas"], - "out": ["y_pred_ids"], - "class_name": "proba2labels", - "max_proba": true - }, - { - "in": ["y_pred_ids"], - "out": ["y_pred_labels"], - "ref": "classes_vocab" - } - ], - "out": ["y_pred_probas"] - }, - "train": { - "epochs": 3, - "batch_size": 30, - "metrics": [ - { - "name": "roc_auc", - "inputs": ["y_onehot", "y_pred_probas"] - }, - "accuracy", - "f1_macro" - ], - "validation_patience": 5, - "val_every_n_batches": 100, - "log_every_n_batches": 100, - "show_examples": false, - "evaluation_targets": ["train", "valid", "test"], - "class_name": "torch_trainer" - }, - "metadata": { - "variables": { - "ROOT_PATH": "~/.deeppavlov", - "DOWNLOADS_PATH": "{ROOT_PATH}/downloads", - "MODELS_PATH": "{ROOT_PATH}/models", - "TRANSFORMER": "DeepPavlov/distilrubert-tiny-cased-conversational", - "MODEL_PATH": "{MODELS_PATH}/classifiers/rel_ranking_bert_rus_torch" - }, - "download": [ - { - "url": "http://files.deeppavlov.ai/kbqa/wikidata/rel_ranking_bert_rus_torch.tar.gz", - "subdir": "{MODEL_PATH}" - }, - { - "url": "http://files.deeppavlov.ai/kbqa/wikidata/rubq_rel_ranking.pickle", - "subdir": "{DOWNLOADS_PATH}/rel_ranking_rus" - } - ] - } -} diff --git a/deeppavlov/configs/ranking/rel_ranking_nll_bert_ru.json b/deeppavlov/configs/ranking/rel_ranking_nll_bert_ru.json new file mode 100644 index 0000000000..10390e23d7 --- /dev/null +++ b/deeppavlov/configs/ranking/rel_ranking_nll_bert_ru.json @@ -0,0 +1,45 @@ +{ + "chainer": { 
+ "in": ["question", "rels"], + "pipe": [ + { + "class_name": "path_ranking_preprocessor", + "vocab_file": "{TRANSFORMER}", + "do_lower_case": false, + "max_seq_length": 96, + "in": ["question", "rels"], + "out": ["bert_features"] + }, + { + "class_name": "torch_transformers_nll_ranker", + "in": ["bert_features"], + "out": ["model_output"], + "return_probas": true, + "save_path": "{MODEL_PATH}/model", + "load_path": "{MODEL_PATH}/model", + "encoder_save_path": "{MODEL_PATH}/encoder", + "linear_save_path": "{MODEL_PATH}/linear", + "model_name": "in_batch_ranking_model", + "pretrained_bert": "{TRANSFORMER}", + "learning_rate_drop_patience": 4, + "learning_rate_drop_div": 1.5 + } + ], + "out": ["model_output"] + }, + "metadata": { + "variables": { + "ROOT_PATH": "~/.deeppavlov", + "DOWNLOADS_PATH": "{ROOT_PATH}/downloads", + "MODELS_PATH": "{ROOT_PATH}/models", + "TRANSFORMER": "DeepPavlov/rubert-base-cased", + "MODEL_PATH": "{MODELS_PATH}/classifiers/rel_ranking_nll_bert_ru" + }, + "download": [ + { + "url": "http://files.deeppavlov.ai/kbqa/models/rel_ranking_nll_bert_ru.tar.gz", + "subdir": "{MODEL_PATH}" + } + ] + } +} diff --git a/deeppavlov/configs/ranking/rel_ranking_bert_en.json b/deeppavlov/configs/ranking/rel_ranking_roberta_en.json similarity index 79% rename from deeppavlov/configs/ranking/rel_ranking_bert_en.json rename to deeppavlov/configs/ranking/rel_ranking_roberta_en.json index ae836ebcc9..95e394a480 100644 --- a/deeppavlov/configs/ranking/rel_ranking_bert_en.json +++ b/deeppavlov/configs/ranking/rel_ranking_roberta_en.json @@ -1,7 +1,7 @@ { "dataset_reader": { "class_name": "sq_reader", - "data_path": "{DOWNLOADS_PATH}/rel_ranking_eng/lcquad_rel_ranking.pickle" + "data_path": "{DOWNLOADS_PATH}/rel_ranking_eng/lcquad_one_rel_ranking.json" }, "dataset_iterator": { "class_name": "basic_classification_iterator", @@ -14,9 +14,8 @@ { "class_name": "rel_ranking_preprocessor", "vocab_file": "{TRANSFORMER}", - "do_lower_case": true, + "do_lower_case": false, "max_seq_length": 64, - "add_special_tokens": ["", "", ""], "in": ["question", "rel_list"], "out": ["bert_features"] }, @@ -39,8 +38,7 @@ { "class_name": "torch_transformers_classifier", "n_classes": "#classes_vocab.len", - "return_probas": "true", - "num_special_tokens": 3, + "return_probas": true, "pretrained_bert": "{TRANSFORMER}", "save_path": "{MODEL_PATH}/model", "load_path": "{MODEL_PATH}/model", @@ -64,24 +62,20 @@ "ref": "classes_vocab" } ], - "out": ["y_pred_probas"] + "out": ["y_pred_labels"] }, "train": { "epochs": 3, "batch_size": 30, "metrics": [ - { - "name": "roc_auc", - "inputs": ["y_onehot", "y_pred_probas"] - }, "accuracy", "f1_macro" ], - "validation_patience": 5, + "validation_patience": 10, "val_every_n_batches": 100, "log_every_n_batches": 100, "show_examples": false, - "evaluation_targets": ["train", "valid", "test"], + "evaluation_targets": ["valid", "test"], "class_name": "torch_trainer" }, "metadata": { @@ -90,11 +84,11 @@ "DOWNLOADS_PATH": "{ROOT_PATH}/downloads", "MODELS_PATH": "{ROOT_PATH}/models", "TRANSFORMER": "haisongzhang/roberta-tiny-cased", - "MODEL_PATH": "{MODELS_PATH}/classifiers/rel_ranking_bert_eng_torch" + "MODEL_PATH": "{MODELS_PATH}/classifiers/rel_ranking_roberta_en" }, "download": [ { - "url": "http://files.deeppavlov.ai/kbqa/wikidata/rel_ranking_bert_eng_torch.tar.gz", + "url": "http://files.deeppavlov.ai/kbqa/models/rel_ranking_roberta_en.tar.gz", "subdir": "{MODEL_PATH}" }, { diff --git a/deeppavlov/configs/russian_super_glue/russian_superglue_danetqa_rubert.json 
b/deeppavlov/configs/russian_super_glue/russian_superglue_danetqa_rubert.json index bae5cc4bba..123e2bcf37 100644 --- a/deeppavlov/configs/russian_super_glue/russian_superglue_danetqa_rubert.json +++ b/deeppavlov/configs/russian_super_glue/russian_superglue_danetqa_rubert.json @@ -5,7 +5,9 @@ "name": "{TASK}", "train": "train", "valid": "validation", - "test": "test" + "test": "test", + "data_url": "http://files.deeppavlov.ai/datasets/russian_super_glue/DaNetQA", + "ignore_verifications": true }, "dataset_iterator": { "class_name": "huggingface_dataset_iterator", diff --git a/deeppavlov/configs/russian_super_glue/russian_superglue_lidirus_rubert.json b/deeppavlov/configs/russian_super_glue/russian_superglue_lidirus_rubert.json index ae0e6d2446..4b29607c06 100644 --- a/deeppavlov/configs/russian_super_glue/russian_superglue_lidirus_rubert.json +++ b/deeppavlov/configs/russian_super_glue/russian_superglue_lidirus_rubert.json @@ -3,7 +3,9 @@ "class_name": "huggingface_dataset_reader", "path": "{COMPETITION}", "name": "{TASK}", - "test": "test" + "test": "test", + "data_url": "http://files.deeppavlov.ai/datasets/russian_super_glue/LiDiRus", + "ignore_verifications": true }, "dataset_iterator": { "class_name": "huggingface_dataset_iterator", diff --git a/deeppavlov/configs/russian_super_glue/russian_superglue_muserc_rubert.json b/deeppavlov/configs/russian_super_glue/russian_superglue_muserc_rubert.json index d55f0372c2..5d8dac7478 100644 --- a/deeppavlov/configs/russian_super_glue/russian_superglue_muserc_rubert.json +++ b/deeppavlov/configs/russian_super_glue/russian_superglue_muserc_rubert.json @@ -5,7 +5,9 @@ "name": "{TASK}", "train": "train", "valid": "validation", - "test": "test" + "test": "test", + "data_url": "http://files.deeppavlov.ai/datasets/russian_super_glue/MuSeRC", + "ignore_verifications": true }, "dataset_iterator": { "class_name": "huggingface_dataset_iterator", diff --git a/deeppavlov/configs/russian_super_glue/russian_superglue_parus_rubert.json b/deeppavlov/configs/russian_super_glue/russian_superglue_parus_rubert.json index c2c6f7c233..19dfecf93a 100644 --- a/deeppavlov/configs/russian_super_glue/russian_superglue_parus_rubert.json +++ b/deeppavlov/configs/russian_super_glue/russian_superglue_parus_rubert.json @@ -5,7 +5,9 @@ "name": "{TASK}", "train": "train", "valid": "validation", - "test": "test" + "test": "test", + "data_url": "http://files.deeppavlov.ai/datasets/russian_super_glue/PARus", + "ignore_verifications": true }, "dataset_iterator": { "class_name": "huggingface_dataset_iterator", diff --git a/deeppavlov/configs/russian_super_glue/russian_superglue_rcb_rubert.json b/deeppavlov/configs/russian_super_glue/russian_superglue_rcb_rubert.json index a26894fdba..ae5689a243 100644 --- a/deeppavlov/configs/russian_super_glue/russian_superglue_rcb_rubert.json +++ b/deeppavlov/configs/russian_super_glue/russian_superglue_rcb_rubert.json @@ -5,7 +5,9 @@ "name": "{TASK}", "train": "train", "valid": "validation", - "test": "test" + "test": "test", + "data_url": "http://files.deeppavlov.ai/datasets/russian_super_glue/RCB", + "ignore_verifications": true }, "dataset_iterator": { "class_name": "huggingface_dataset_iterator", diff --git a/deeppavlov/configs/russian_super_glue/russian_superglue_rucos_rubert.json b/deeppavlov/configs/russian_super_glue/russian_superglue_rucos_rubert.json index 63d01f3fc6..48ea406237 100644 --- a/deeppavlov/configs/russian_super_glue/russian_superglue_rucos_rubert.json +++ 
b/deeppavlov/configs/russian_super_glue/russian_superglue_rucos_rubert.json @@ -6,6 +6,8 @@ "train": "train", "valid": "validation", "test": "test", + "data_url": "http://files.deeppavlov.ai/datasets/russian_super_glue/RuCoS", + "ignore_verifications": true, "downsample_ratio": [1.8, 1.8, 1], "do_index_correction": false }, diff --git a/deeppavlov/configs/russian_super_glue/russian_superglue_russe_rubert.json b/deeppavlov/configs/russian_super_glue/russian_superglue_russe_rubert.json index 1c6115f365..35a2e39f2b 100644 --- a/deeppavlov/configs/russian_super_glue/russian_superglue_russe_rubert.json +++ b/deeppavlov/configs/russian_super_glue/russian_superglue_russe_rubert.json @@ -5,7 +5,9 @@ "name": "{TASK}", "train": "train", "valid": "validation", - "test": "test" + "test": "test", + "data_url": "http://files.deeppavlov.ai/datasets/russian_super_glue/RUSSE", + "ignore_verifications": true }, "dataset_iterator": { "class_name": "huggingface_dataset_iterator", diff --git a/deeppavlov/configs/russian_super_glue/russian_superglue_rwsd_rubert.json b/deeppavlov/configs/russian_super_glue/russian_superglue_rwsd_rubert.json index b6eb696cc9..ee577d998a 100644 --- a/deeppavlov/configs/russian_super_glue/russian_superglue_rwsd_rubert.json +++ b/deeppavlov/configs/russian_super_glue/russian_superglue_rwsd_rubert.json @@ -5,7 +5,9 @@ "name": "{TASK}", "train": "train", "valid": "validation", - "test": "test" + "test": "test", + "data_url": "http://files.deeppavlov.ai/datasets/russian_super_glue/RWSD", + "ignore_verifications": true }, "dataset_iterator": { "class_name": "huggingface_dataset_iterator", diff --git a/deeppavlov/configs/russian_super_glue/russian_superglue_terra_rubert.json b/deeppavlov/configs/russian_super_glue/russian_superglue_terra_rubert.json index 5eaebd3124..41d29329bc 100644 --- a/deeppavlov/configs/russian_super_glue/russian_superglue_terra_rubert.json +++ b/deeppavlov/configs/russian_super_glue/russian_superglue_terra_rubert.json @@ -5,7 +5,9 @@ "name": "{TASK}", "train": "train", "valid": "validation", - "test": "test" + "test": "test", + "data_url": "http://files.deeppavlov.ai/datasets/russian_super_glue/TERRa", + "ignore_verifications": true }, "dataset_iterator": { "class_name": "huggingface_dataset_iterator", diff --git a/deeppavlov/core/common/registry.json b/deeppavlov/core/common/registry.json index 6969143433..654361889e 100644 --- a/deeppavlov/core/common/registry.json +++ b/deeppavlov/core/common/registry.json @@ -14,6 +14,7 @@ "document_chunker": "deeppavlov.models.preprocessors.odqa_preprocessors:DocumentChunker", "entity_detection_parser": "deeppavlov.models.entity_extraction.entity_detection_parser:EntityDetectionParser", "entity_linker": "deeppavlov.models.entity_extraction.entity_linking:EntityLinker", + "entity_type_split": "deeppavlov.models.entity_extraction.entity_detection_parser:entity_type_split", "faq_reader": "deeppavlov.dataset_readers.faq_reader:FaqDatasetReader", "fasttext": "deeppavlov.models.embedders.fasttext_embedder:FasttextEmbedder", "fit_trainer": "deeppavlov.core.trainers.fit_trainer:FitTrainer", @@ -22,6 +23,7 @@ "huggingface_dataset_reader": "deeppavlov.dataset_readers.huggingface_dataset_reader:HuggingFaceDatasetReader", "imdb_reader": "deeppavlov.dataset_readers.imdb_reader:ImdbReader", "kenlm_elector": "deeppavlov.models.spelling_correction.electors.kenlm_elector:KenlmElector", + "lcquad_reader": "deeppavlov.dataset_readers.sq_reader:LCQuADReader", "line_reader": "deeppavlov.dataset_readers.line_reader:LineReader", "logit_ranker": 
"deeppavlov.models.doc_retrieval.logit_ranker:LogitRanker", "mask": "deeppavlov.models.preprocessors.mask:Mask", @@ -42,10 +44,12 @@ "one_hotter": "deeppavlov.models.preprocessors.one_hotter:OneHotter", "params_search": "deeppavlov.core.common.params_search:ParamsSearch", "paraphraser_reader": "deeppavlov.dataset_readers.paraphraser_reader:ParaphraserReader", + "path_ranking_preprocessor": "deeppavlov.models.preprocessors.torch_transformers_preprocessor:PathRankingPreprocessor", "pop_ranker": "deeppavlov.models.doc_retrieval.pop_ranker:PopRanker", "proba2labels": "deeppavlov.models.classifiers.proba2labels:Proba2Labels", + "query_formatter": "deeppavlov.models.kbqa.query_generator:QueryFormatter", "query_generator": "deeppavlov.models.kbqa.query_generator:QueryGenerator", - "question_sign_checker": "deeppavlov.models.entity_extraction.entity_detection_parser:question_sign_checker", + "question_sign_checker": "deeppavlov.models.entity_extraction.entity_detection_parser:QuestionSignChecker", "re_classifier": "deeppavlov.models.relation_extraction.relation_extraction_bert:REBertModel", "re_postprocessor": "deeppavlov.models.preprocessors.re_preprocessor:REPostprocessor", "re_preprocessor": "deeppavlov.models.preprocessors.re_preprocessor:REPreprocessor", @@ -53,7 +57,8 @@ "rel_ranking_preprocessor": "deeppavlov.models.preprocessors.torch_transformers_preprocessor:RelRankingPreprocessor", "rel_ranking_reader": "deeppavlov.dataset_readers.rel_ranking_reader:ParaphraserReader", "response_base_loader": "deeppavlov.models.preprocessors.response_base_loader:ResponseBaseLoader", - "ru_adj_to_noun": "deeppavlov.models.kbqa.tree_to_sparql:RuAdjToNoun", + "ru_adj_to_noun": "deeppavlov.models.kbqa.ru_adj_to_noun:RuAdjToNoun", + "rubq_reader": "deeppavlov.dataset_readers.sq_reader:RuBQReader", "rured_reader": "deeppavlov.dataset_readers.rured_reader:RuREDDatasetReader", "russian_words_vocab": "deeppavlov.vocabs.typos:RussianWordsVocab", "sanitizer": "deeppavlov.models.preprocessors.sanitizer:Sanitizer", @@ -65,7 +70,7 @@ "spelling_error_model": "deeppavlov.models.spelling_correction.brillmoore.error_model:ErrorModel", "spelling_levenshtein": "deeppavlov.models.spelling_correction.levenshtein.searcher_component:LevenshteinSearcherComponent", "split_tokenizer": "deeppavlov.models.tokenizers.split_tokenizer:SplitTokenizer", - "sq_reader": "deeppavlov.dataset_readers.sq_reader:OntonotesReader", + "sq_reader": "deeppavlov.dataset_readers.sq_reader:SQReader", "sqlite_iterator": "deeppavlov.dataset_iterators.sqlite_iterator:SQLiteDataIterator", "squad_bert_ans_postprocessor": "deeppavlov.models.preprocessors.squad_preprocessor:SquadBertAnsPostprocessor", "squad_bert_ans_preprocessor": "deeppavlov.models.preprocessors.squad_preprocessor:SquadBertAnsPreprocessor", @@ -95,6 +100,7 @@ "torch_transformers_multiplechoice": "deeppavlov.models.torch_bert.torch_transformers_multiplechoice:TorchTransformersMultiplechoiceModel", "torch_transformers_multiplechoice_preprocessor": "deeppavlov.models.preprocessors.torch_transformers_preprocessor:TorchTransformersMultiplechoicePreprocessor", "torch_transformers_ner_preprocessor": "deeppavlov.models.preprocessors.torch_transformers_preprocessor:TorchTransformersNerPreprocessor", + "torch_transformers_nll_ranker": "deeppavlov.models.torch_bert.torch_transformers_nll_ranking:TorchTransformersNLLRanker", "torch_transformers_preprocessor": "deeppavlov.models.preprocessors.torch_transformers_preprocessor:TorchTransformersPreprocessor", "torch_transformers_sequence_tagger": 
"deeppavlov.models.torch_bert.torch_transformers_sequence_tagger:TorchTransformersSequenceTagger", "torch_transformers_squad": "deeppavlov.models.torch_bert.torch_transformers_squad:TorchTransformersSquad", diff --git a/deeppavlov/core/common/requirements_registry.json b/deeppavlov/core/common/requirements_registry.json index f47311c0c7..420f136844 100644 --- a/deeppavlov/core/common/requirements_registry.json +++ b/deeppavlov/core/common/requirements_registry.json @@ -32,6 +32,10 @@ "nltk_moses_tokenizer": [ "{DEEPPAVLOV_PATH}/requirements/sacremoses.txt" ], + "path_ranking_preprocessor": [ + "{DEEPPAVLOV_PATH}/requirements/pytorch.txt", + "{DEEPPAVLOV_PATH}/requirements/transformers.txt" + ], "query_generator": [ "{DEEPPAVLOV_PATH}/requirements/en_core_web_sm.txt", "{DEEPPAVLOV_PATH}/requirements/hdt.txt", @@ -60,14 +64,15 @@ "{DEEPPAVLOV_PATH}/requirements/transformers.txt" ], "ru_adj_to_noun": [ - "{DEEPPAVLOV_PATH}/requirements/ru_core_news_sm.txt", - "{DEEPPAVLOV_PATH}/requirements/udapi.txt" + "{DEEPPAVLOV_PATH}/requirements/ru_core_news_sm.txt" ], "russian_words_vocab": [ "{DEEPPAVLOV_PATH}/requirements/lxml.txt" ], "slovnet_syntax_parser": [ - "{DEEPPAVLOV_PATH}/requirements/slovnet.txt" + "{DEEPPAVLOV_PATH}/requirements/slovnet.txt", + "{DEEPPAVLOV_PATH}/requirements/razdel.txt", + "{DEEPPAVLOV_PATH}/requirements/ru_core_news_sm.txt" ], "spelling_error_model": [ "{DEEPPAVLOV_PATH}/requirements/lxml.txt" @@ -133,6 +138,10 @@ "{DEEPPAVLOV_PATH}/requirements/pytorch.txt", "{DEEPPAVLOV_PATH}/requirements/transformers.txt" ], + "torch_transformers_nll_ranker": [ + "{DEEPPAVLOV_PATH}/requirements/pytorch.txt", + "{DEEPPAVLOV_PATH}/requirements/transformers.txt" + ], "torch_transformers_preprocessor": [ "{DEEPPAVLOV_PATH}/requirements/pytorch.txt", "{DEEPPAVLOV_PATH}/requirements/transformers.txt" @@ -161,7 +170,7 @@ ], "tree_to_sparql": [ "{DEEPPAVLOV_PATH}/requirements/udapi.txt", - "{DEEPPAVLOV_PATH}/requirements/en_core_web_sm.txt", + "{DEEPPAVLOV_PATH}/requirements/razdel.txt", "{DEEPPAVLOV_PATH}/requirements/ru_core_news_sm.txt" ], "typos_custom_reader": [ diff --git a/deeppavlov/core/models/torch_model.py b/deeppavlov/core/models/torch_model.py index af6edb409e..f43b862df4 100644 --- a/deeppavlov/core/models/torch_model.py +++ b/deeppavlov/core/models/torch_model.py @@ -19,7 +19,6 @@ from typing import Optional import torch -from overrides import overrides from deeppavlov.core.common.errors import ConfigError from deeppavlov.core.models.nn_model import NNModel @@ -130,7 +129,6 @@ def init_from_opt(self, model_func: str) -> None: def is_data_parallel(self) -> bool: return isinstance(self.model, torch.nn.DataParallel) - @overrides def load(self, fname: Optional[str] = None, *args, **kwargs) -> None: """Load model from `fname` (if `fname` is not given, use `self.load_path`) to `self.model` along with the optimizer `self.optimizer`, optionally `self.lr_scheduler`. @@ -187,7 +185,6 @@ def load(self, fname: Optional[str] = None, *args, **kwargs) -> None: log.debug(f"Init from scratch. Load path {self.load_path} is not provided.") self.init_from_opt(model_func) - @overrides def save(self, fname: Optional[str] = None, *args, **kwargs) -> None: """Save torch model to `fname` (if `fname` is not given, use `self.save_path`). Checkpoint includes `model_state_dict`, `optimizer_state_dict`, and `epochs_done` (number of training epochs). 
@@ -224,7 +221,6 @@ def save(self, fname: Optional[str] = None, *args, **kwargs) -> None: # return it back to device (necessary if it was on `cuda`) self.model.to(self.device) - @overrides def process_event(self, event_name: str, data: dict) -> None: """Process event. After epoch, increase `self.epochs_done`. After validation, decrease learning rate in `self.learning_rate_drop_div` times (not lower than `self.min_learning_rate`) diff --git a/deeppavlov/dataset_iterators/sqlite_iterator.py b/deeppavlov/dataset_iterators/sqlite_iterator.py index e18622ecfc..6d89ce851e 100644 --- a/deeppavlov/dataset_iterators/sqlite_iterator.py +++ b/deeppavlov/dataset_iterators/sqlite_iterator.py @@ -18,8 +18,6 @@ from random import Random from typing import List, Any, Dict, Optional, Union, Generator, Tuple -from overrides import overrides - from deeppavlov.core.commands.utils import expand_path from deeppavlov.core.common.registry import register from deeppavlov.core.data.data_fitting_iterator import DataFittingIterator @@ -73,7 +71,6 @@ def __init__(self, load_path: Union[str, Path], batch_size: Optional[int] = None self.shuffle = shuffle self.random = Random(seed) - @overrides def get_doc_ids(self) -> List[Any]: """Get document ids. @@ -112,7 +109,6 @@ def map_doc2idx(self) -> Dict[int, Any]: "SQLite iterator: The size of the database is {} documents".format(len(doc2idx))) return doc2idx - @overrides def get_doc_content(self, doc_id: Any) -> Optional[str]: """Get document content by id. @@ -132,7 +128,6 @@ def get_doc_content(self, doc_id: Any) -> Optional[str]: cursor.close() return result if result is None else result[0] - @overrides def gen_batches(self, batch_size: int, shuffle: bool = None) \ -> Generator[Tuple[List[str], List[int]], Any, None]: """Gen batches of documents. 
diff --git a/deeppavlov/dataset_readers/basic_classification_reader.py b/deeppavlov/dataset_readers/basic_classification_reader.py index 81a6738492..8ef767b368 100644 --- a/deeppavlov/dataset_readers/basic_classification_reader.py +++ b/deeppavlov/dataset_readers/basic_classification_reader.py @@ -17,7 +17,6 @@ from pathlib import Path import pandas as pd -from overrides import overrides from deeppavlov.core.common.registry import register from deeppavlov.core.data.dataset_reader import DatasetReader @@ -32,7 +31,6 @@ class BasicClassificationDatasetReader(DatasetReader): Class provides reading dataset in .csv format """ - @overrides def read(self, data_path: str, url: str = None, format: str = "csv", class_sep: str = None, *args, **kwargs) -> dict: diff --git a/deeppavlov/dataset_readers/docred_reader.py b/deeppavlov/dataset_readers/docred_reader.py index d84b528c06..479854d041 100644 --- a/deeppavlov/dataset_readers/docred_reader.py +++ b/deeppavlov/dataset_readers/docred_reader.py @@ -22,7 +22,6 @@ import numpy as np import pandas as pd -from overrides import overrides from deeppavlov.core.commands.utils import expand_path from deeppavlov.core.common.registry import register @@ -35,7 +34,6 @@ class DocREDDatasetReader(DatasetReader): """ Class to read the datasets in DocRED format""" - @overrides def read( self, data_path: str, diff --git a/deeppavlov/dataset_readers/huggingface_dataset_reader.py b/deeppavlov/dataset_readers/huggingface_dataset_reader.py index 2e62bf966c..d4300dac66 100644 --- a/deeppavlov/dataset_readers/huggingface_dataset_reader.py +++ b/deeppavlov/dataset_readers/huggingface_dataset_reader.py @@ -19,7 +19,6 @@ from typing import Dict, Optional, List, Union from datasets import load_dataset, Dataset, Features, ClassLabel, concatenate_datasets -from overrides import overrides from deeppavlov.core.common.registry import register from deeppavlov.core.data.dataset_reader import DatasetReader @@ -30,7 +29,6 @@ class HuggingFaceDatasetReader(DatasetReader): """Adds HuggingFace Datasets https://huggingface.co/datasets/ to DeepPavlov """ - @overrides def read(self, path: str, name: Optional[str] = None, diff --git a/deeppavlov/dataset_readers/imdb_reader.py b/deeppavlov/dataset_readers/imdb_reader.py index d9af6fa72e..32f8134729 100644 --- a/deeppavlov/dataset_readers/imdb_reader.py +++ b/deeppavlov/dataset_readers/imdb_reader.py @@ -15,8 +15,6 @@ from typing import List, Dict, Any, Optional, Tuple from pathlib import Path -from overrides import overrides - from deeppavlov.core.common.registry import register from deeppavlov.core.data.dataset_reader import DatasetReader from deeppavlov.core.data.utils import download_decompress, mark_done, is_done @@ -35,7 +33,6 @@ class ImdbReader(DatasetReader): for Computational Linguistics (ACL 2011). 
""" - @overrides def read(self, data_path: str, url: Optional[str] = None, *args, **kwargs) -> Dict[str, List[Tuple[Any, Any]]]: """ diff --git a/deeppavlov/dataset_readers/rured_reader.py b/deeppavlov/dataset_readers/rured_reader.py index 8c717a9579..cb950a3ed2 100644 --- a/deeppavlov/dataset_readers/rured_reader.py +++ b/deeppavlov/dataset_readers/rured_reader.py @@ -4,7 +4,6 @@ from typing import Dict, List, Tuple from pathlib import Path from logging import getLogger -from overrides import overrides from deeppavlov.core.common.registry import register from deeppavlov.core.data.dataset_reader import DatasetReader @@ -16,7 +15,6 @@ class RuREDDatasetReader(DatasetReader): """ Class to read the datasets in RuRED format""" - @overrides def read(self, data_path: str, rel2id: Dict = None) -> Dict[str, List[Tuple]]: """ This class processes the RuRED relation extraction dataset diff --git a/deeppavlov/dataset_readers/sq_reader.py b/deeppavlov/dataset_readers/sq_reader.py index 00949a6cb5..2530fcae77 100644 --- a/deeppavlov/dataset_readers/sq_reader.py +++ b/deeppavlov/dataset_readers/sq_reader.py @@ -12,18 +12,89 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import pickle +from typing import List from deeppavlov.core.common.registry import register from deeppavlov.core.data.dataset_reader import DatasetReader +from deeppavlov.core.common.file import load_pickle +from deeppavlov.core.common.file import read_json @register('sq_reader') -class OntonotesReader(DatasetReader): - """Class to read training datasets in OntoNotes format""" +class SQReader(DatasetReader): + """Class to read training datasets""" - def read(self, data_path: str): - with open(data_path, 'rb') as f: - dataset = pickle.load(f) + def read(self, data_path: str, valid_size: int = None): + if str(data_path).endswith(".pickle"): + dataset = load_pickle(data_path) + elif str(data_path).endswith(".json"): + dataset = read_json(data_path) + else: + raise TypeError(f'Unsupported file type: {data_path}') + if valid_size: + dataset["valid"] = dataset["valid"][:valid_size] return dataset + + +@register('rubq_reader') +class RuBQReader(SQReader): + """Class to read RuBQ datasets""" + + def read(self, data_path: str, version: str = "2.0", question_types: List[str] = ["all"], + not_include_question_types: List[str] = None, num_samples: int = -1): + dataset = super().read(data_path) + for data_type in ["valid", "test"]: + samples = dataset[data_type] + samples = [sample for sample in samples if float(sample["RuBQ_version"]) <= float(version) and + (any(tp in sample["tags"] for tp in question_types) or question_types == ["all"])] + if not_include_question_types: + samples = [sample for sample in samples if all([tp not in sample["tags"] + for tp in not_include_question_types])] + samples = [self.preprocess(sample) for sample in samples] + if num_samples > 0: + samples = samples[:num_samples] + dataset[data_type] = samples + return dataset + + def preprocess(self, sample): + question = sample.get("question_text", "") + answers = sample.get("answers", []) + answer_ids = [elem.get("value", "").split("/")[-1] for elem in answers] + answer_labels = [elem.get("label", "").split("/")[-1] for elem in answers] + query = sample.get("query", "") + if query is None: + query = "" + else: + query = query.replace("\n", " ").replace(" ", " ") + return [question, [answer_ids, answer_labels, query]] + + +@register('lcquad_reader') +class LCQuADReader(SQReader): + """Class to read LCQuAD 
dataset""" + + def read(self, data_path: str, question_types: List[str] = "all", + not_include_question_types: List[str] = None, num_samples: int = -1): + dataset = super().read(data_path) + for data_type in ["valid", "test"]: + samples = dataset[data_type] + samples = [sample for sample in samples if (any(tp == sample["subgraph"] for tp in question_types) \ + or question_types == ["all"])] + if not_include_question_types: + samples = [sample for sample in samples + if sample["subgraph"] not in not_include_question_types] + samples = [self.preprocess(sample) for sample in samples] + if num_samples > 0: + samples = samples[:num_samples] + dataset[data_type] = samples + return dataset + + def preprocess(self, sample): + question = sample.get("question", "") + answers = sample.get("answer", []) + answer_labels = sample.get("answer_label", []) + query = sample.get("sparql_wikidata", "") + return [question, [answers, answer_labels, query]] diff --git a/deeppavlov/deep.py b/deeppavlov/deep.py index 9cee573c6e..d13265321c 100644 --- a/deeppavlov/deep.py +++ b/deeppavlov/deep.py @@ -20,7 +20,6 @@ from deeppavlov.core.common.cross_validation import calc_cv_score from deeppavlov.core.common.file import find_config from deeppavlov.download import deep_download -from deeppavlov.utils.agent import start_rabbit_service from deeppavlov.utils.pip_wrapper import install_from_config from deeppavlov.utils.server import start_model_server from deeppavlov.utils.socket import start_socket_server @@ -30,8 +29,8 @@ parser = argparse.ArgumentParser() parser.add_argument("mode", help="select a mode, train or interact", type=str, - choices={'train', 'evaluate', 'interact', 'predict', 'riseapi', 'risesocket', 'agent-rabbit', - 'download', 'install', 'crossval'}) + choices={'train', 'evaluate', 'interact', 'predict', 'riseapi', 'risesocket', 'download', 'install', + 'crossval'}) parser.add_argument("config_path", help="path to a pipeline json config", type=str) parser.add_argument("-e", "--start-epoch-num", dest="start_epoch_num", default=None, @@ -54,15 +53,6 @@ parser.add_argument("--socket-type", default="TCP", type=str, choices={"TCP", "UNIX"}) parser.add_argument("--socket-file", default="/tmp/deeppavlov_socket.s", type=str) -parser.add_argument("-sn", "--service-name", default=None, help="service name for agent-rabbit mode", type=str) -parser.add_argument("-an", "--agent-namespace", default=None, help="dp-agent namespace name", type=str) -parser.add_argument("-ul", "--utterance-lifetime", default=None, help="message expiration in seconds", type=int) -parser.add_argument("-rh", "--rabbit-host", default=None, help="RabbitMQ server host", type=str) -parser.add_argument("-rp", "--rabbit-port", default=None, help="RabbitMQ server port", type=int) -parser.add_argument("-rl", "--rabbit-login", default=None, help="RabbitMQ server login", type=str) -parser.add_argument("-rpwd", "--rabbit-password", default=None, help="RabbitMQ server password", type=str) -parser.add_argument("-rvh", "--rabbit-virtualhost", default=None, help="RabbitMQ server virtualhost", type=str) - def main(): args = parser.parse_args() @@ -85,17 +75,6 @@ def main(): start_model_server(pipeline_config_path, args.https, args.key, args.cert, port=args.port) elif args.mode == 'risesocket': start_socket_server(pipeline_config_path, args.socket_type, port=args.port, socket_file=args.socket_file) - elif args.mode == 'agent-rabbit': - start_rabbit_service(model_config=pipeline_config_path, - service_name=args.service_name, - 
agent_namespace=args.agent_namespace, - batch_size=args.batch_size, - utterance_lifetime_sec=args.utterance_lifetime, - rabbit_host=args.rabbit_host, - rabbit_port=args.rabbit_port, - rabbit_login=args.rabbit_login, - rabbit_password=args.rabbit_password, - rabbit_virtualhost=args.rabbit_virtualhost) elif args.mode == 'predict': predict_on_stream(pipeline_config_path, args.batch_size, args.file_path) elif args.mode == 'crossval': diff --git a/deeppavlov/metrics/accuracy.py b/deeppavlov/metrics/accuracy.py index 86602ed508..dd52ce5743 100644 --- a/deeppavlov/metrics/accuracy.py +++ b/deeppavlov/metrics/accuracy.py @@ -14,12 +14,16 @@ import itertools -from typing import List, Iterable +import re +from logging import getLogger +from typing import List import numpy as np from deeppavlov.core.common.metrics_registry import register_metric +log = getLogger(__name__) + @register_metric('accuracy') def accuracy(y_true: [list, np.ndarray], y_predicted: [list, np.ndarray]) -> float: @@ -47,6 +51,31 @@ def _are_equal(y1, y2): return correct / examples_len if examples_len else 0 +@register_metric('kbqa_accuracy') +def kbqa_accuracy(questions_batch, pred_answer_labels_batch, pred_answer_ids_batch, pred_query_batch, + gold_answer_labels_batch, gold_answer_ids_batch, gold_query_batch) -> float: + num_samples = len(pred_answer_ids_batch) + correct = 0 + for question, pred_answer_label, pred_answer_ids, pred_query, gold_answer_labels, gold_answer_ids, gold_query in \ + zip(questions_batch, pred_answer_labels_batch, pred_answer_ids_batch, pred_query_batch, + gold_answer_labels_batch, gold_answer_ids_batch, gold_query_batch): + found_date = False + if pred_answer_ids and gold_answer_ids and re.findall(r"[\d]{3,4}", pred_answer_ids[0]) and \ + re.findall(r"[\d]{3,4}", pred_answer_ids[0]) == re.findall(r"[\d]{3,4}", gold_answer_ids[0]): + found_date = True + found_label = False + if len(gold_answer_labels) == 1 and len(pred_answer_label) > 1 and pred_answer_label == gold_answer_labels[0]: + found_label = True + no_answer = False + if pred_answer_label == "Not Found" and not gold_answer_ids: + no_answer = True + if set(pred_answer_ids) == set(gold_answer_ids) or gold_query in pred_query or found_date or found_label \ + or no_answer: + correct += 1 + log.debug(f"question: {question} -- gold_answer_ids: {gold_answer_ids} -- pred_answer_ids: {pred_answer_ids}") + return correct / num_samples if num_samples else 0 + + @register_metric('multitask_accuracy') def multitask_accuracy(*args) -> float: """ @@ -178,13 +207,3 @@ def round_accuracy(y_true, y_predicted): examples_len = len(y_true) correct = sum([y1 == y2 for y1, y2 in zip(y_true, predictions)]) return correct / examples_len if examples_len else 0 - - -@register_metric('kbqa_accuracy') -def kbqa_accuracy(y_true, y_predicted): - total_correct = 0 - for answer_true, answer_predicted in zip(y_true, y_predicted): - if answer_predicted in answer_true: - total_correct += 1 - - return total_correct / len(y_true) if len(y_true) else 0 diff --git a/deeppavlov/models/classifiers/torch_classification_model.py b/deeppavlov/models/classifiers/torch_classification_model.py index ceb209810a..11581a29f4 100644 --- a/deeppavlov/models/classifiers/torch_classification_model.py +++ b/deeppavlov/models/classifiers/torch_classification_model.py @@ -13,7 +13,6 @@ # limitations under the License. 
import logging -from overrides import overrides from typing import List, Union, Optional import numpy as np @@ -133,7 +132,6 @@ def __call__(self, texts: List[np.ndarray], *args) -> Union[List[List[float]], L else: return np.argmax(outputs, axis=-1).tolist() - @overrides def process_event(self, event_name: str, data: dict): """Process event after epoch diff --git a/deeppavlov/models/embedders/abstract_embedder.py b/deeppavlov/models/embedders/abstract_embedder.py index c9a52c2b70..34fd08200b 100644 --- a/deeppavlov/models/embedders/abstract_embedder.py +++ b/deeppavlov/models/embedders/abstract_embedder.py @@ -18,7 +18,6 @@ from typing import List, Union, Iterator import numpy as np -from overrides import overrides from deeppavlov.core.data.utils import zero_pad from deeppavlov.core.models.component import Component @@ -56,14 +55,12 @@ def __init__(self, load_path: Union[str, Path], pad_zero: bool = False, mean: bo self.model = None self.load() - @overrides def save(self) -> None: """ Class does not save loaded model again as it is not trained during usage """ raise NotImplementedError - @overrides def __call__(self, batch: List[List[str]], mean: bool = None) -> List[Union[list, np.ndarray]]: """ Embed sentences from batch diff --git a/deeppavlov/models/embedders/fasttext_embedder.py b/deeppavlov/models/embedders/fasttext_embedder.py index bb0f5cfb79..df694f7652 100644 --- a/deeppavlov/models/embedders/fasttext_embedder.py +++ b/deeppavlov/models/embedders/fasttext_embedder.py @@ -18,7 +18,6 @@ import fasttext import numpy as np -from overrides import overrides from deeppavlov.core.common.registry import register from deeppavlov.models.embedders.abstract_embedder import Embedder @@ -54,7 +53,6 @@ def load(self) -> None: self.model = fasttext.load_model(str(self.load_path)) self.dim = self.model.get_dimension() - @overrides def __iter__(self) -> Iterator[str]: """ Iterate over all words from fastText model vocabulary diff --git a/deeppavlov/models/embedders/tfidf_weighted_embedder.py b/deeppavlov/models/embedders/tfidf_weighted_embedder.py index 880138de33..ab93ec05a6 100644 --- a/deeppavlov/models/embedders/tfidf_weighted_embedder.py +++ b/deeppavlov/models/embedders/tfidf_weighted_embedder.py @@ -16,7 +16,6 @@ from typing import List, Union, Optional, Tuple import numpy as np -from overrides import overrides from deeppavlov.core.commands.utils import expand_path from deeppavlov.core.common.errors import ConfigError @@ -177,7 +176,6 @@ def space_detokenizer(batch: List[List[str]]) -> List[str]: """ return [" ".join(tokens) for tokens in batch] - @overrides def __call__(self, batch: List[List[str]], tags_batch: Optional[List[List[str]]] = None, mean: bool = None, *args, **kwargs) -> List[Union[list, np.ndarray]]: """ diff --git a/deeppavlov/models/entity_extraction/entity_detection_parser.py b/deeppavlov/models/entity_extraction/entity_detection_parser.py index 7719122c3d..7fb2073823 100644 --- a/deeppavlov/models/entity_extraction/entity_detection_parser.py +++ b/deeppavlov/models/entity_extraction/entity_detection_parser.py @@ -12,20 +12,60 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import re from collections import defaultdict -from typing import List, Tuple, Union, Dict +from logging import getLogger +from string import punctuation +from typing import List, Tuple, Union, Any import numpy as np +from nltk.corpus import stopwords from deeppavlov.core.commands.utils import expand_path from deeppavlov.core.common.registry import register from deeppavlov.core.models.component import Component +log = getLogger(__name__) +punctuation = punctuation.replace('+', '') + @register('question_sign_checker') -def question_sign_checker(questions: List[str]) -> List[str]: - """Adds question sign if it is absent or replaces dots in the end with question sign.""" - return [question if question.endswith('?') else f'{question.rstrip(".")}?' for question in questions] +class QuestionSignChecker: + def __init__(self, delete_brackets: bool = False, **kwargs): + self.delete_brackets = delete_brackets + self.replace_tokens = [(" '", ' "'), ("' ", '" '), (" ?", "?"), (" ", " ")] + + def __call__(self, questions: List[str]) -> List[str]: + """Adds question sign if it is absent or replaces dots in the end with question sign.""" + questions_clean = [] + for question in questions: + question = question if question.endswith('?') else f'{question.rstrip(".")}?' + if self.delete_brackets: + brackets_text = re.findall(r"(\(.*?\))", question) + for elem in brackets_text: + question = question.replace(elem, " ") + for old_tok, new_tok in self.replace_tokens: + question = question.replace(old_tok, new_tok) + questions_clean.append(question) + return questions_clean + + +@register('entity_type_split') +def entity_type_split(entities_batch: List[List[str]], tags_batch: List[List[str]]) -> Tuple[ + List[List[str]], List[List[str]], List[List[str]]]: + f_entities_batch, f_types_batch, f_tags_batch = [], [], [] + for entities_list, tags_list in zip(entities_batch, tags_batch): + f_entities_list, f_types_list, f_tags_list = [], [], [] + for entity, tag in zip(entities_list, tags_list): + if tag != "T": + f_entities_list.append(entity) + f_tags_list.append(tag.lower()) + else: + f_types_list.append(entity) + f_entities_batch.append(f_entities_list) + f_tags_batch.append(f_tags_list) + f_types_batch.append(f_types_list) + return f_entities_batch, f_tags_batch, f_types_batch @register('entity_detection_parser') @@ -33,25 +73,27 @@ class EntityDetectionParser(Component): """This class parses probabilities of tokens to be a token from the entity substring.""" def __init__(self, o_tag: str, tags_file: str, entity_tags: List[str] = None, ignore_points: bool = False, - return_entities_with_tags: bool = False, thres_proba: float = 0.8, **kwargs): + thres_proba: float = 0.8, make_tags_from_probas: bool = False, lang: str = "en", + ignored_tags: List[str] = None, **kwargs): """ Args: o_tag: tag for tokens which are neither entities nor types tags_file: filename with NER tags entity_tags: tags for entities ignore_points: whether to consider points as separate symbols - return_entities_with_tags: whether to return a dict of tags (keys) and list of entity substrings (values) - or simply a list of entity substrings thres_proba: if the probability of the tag is less than thres_proba, we assign the tag as 'O' + make_tags_from_probas: whether to define token tags from confidences from sequence tagging model + lang: language of texts + ignored_tags: not used tags of entities """ self.entity_tags = entity_tags self.o_tag = o_tag self.ignore_points = ignore_points - self.return_entities_with_tags = return_entities_with_tags 
self.thres_proba = thres_proba self.tag_ind_dict = {} with open(str(expand_path(tags_file))) as fl: tags = [line.split('\t')[0] for line in fl.readlines()] + self.tags = tags if self.entity_tags is None: self.entity_tags = list( {tag.split('-')[1] for tag in tags if len(tag.split('-')) > 1}.difference({self.o_tag})) @@ -64,15 +106,23 @@ def __init__(self, o_tag: str, tags_file: str, entity_tags: List[str] = None, ig for ind in tag_ind: self.tag_ind_dict[ind] = entity_tag self.tag_ind_dict[0] = self.o_tag + self.make_tags_from_probas = make_tags_from_probas + if lang == "en": + self.stopwords = set(stopwords.words("english")) + elif lang == "ru": + self.stopwords = set(stopwords.words("russian")) + else: + raise ValueError(f'Unsupported lang value: "{lang}". Only "en" and "ru" are allowed.') + self.ignored_tags = ignored_tags or [] def __call__(self, question_tokens_batch: List[List[str]], tokens_info_batch: List[List[List[float]]], tokens_probas_batch: np.ndarray) -> \ - Tuple[List[Union[List[str], Dict[str, List[str]]]], List[List[str]], - List[Union[List[int], Dict[str, List[List[int]]]]]]: + Tuple[List[dict], List[dict], List[dict]]: """ Args: - question_tokens: tokenized questions - token_probas: list of probabilities of question tokens + question_tokens_batch: tokenized questions + tokens_info_batch: list of tags of question tokens + tokens_probas_probas: list of probabilities of question tokens Returns: Batch of dicts where keys are tags and values are substrings corresponding to tags Batch of substrings which correspond to entity types @@ -81,14 +131,19 @@ def __call__(self, question_tokens_batch: List[List[str]], tokens_info_batch: Li entities_batch = [] positions_batch = [] probas_batch = [] - for tokens, tokens_info, probas in zip(question_tokens_batch, tokens_info_batch, tokens_probas_batch): - entities, positions, entities_probas = self.entities_from_tags(tokens, tokens_info, probas) + for tokens, tags, probas in \ + zip(question_tokens_batch, tokens_info_batch, tokens_probas_batch): + if self.make_tags_from_probas: + tags, _ = self.tags_from_probas(tokens, probas) + tags = self.correct_quotes(tokens, tags, probas) + tags = self.correct_tags(tokens, tags) + entities, positions, entities_probas = self.entities_from_tags(tokens, tags, probas) entities_batch.append(entities) positions_batch.append(positions) probas_batch.append(entities_probas) return entities_batch, positions_batch, probas_batch - def tags_from_probas(self, tokens, probas): + def tags_from_probas(self, tokens: List[str], probas: np.array) -> Tuple[List[Union[str, List[str]]], List[Any]]: """ This method makes a list of tags from a list of probas for tags Args: @@ -101,18 +156,95 @@ def tags_from_probas(self, tokens, probas): tags = [] tag_probas = [] for token, proba in zip(tokens, probas): - tag_num = np.argmax(proba) - if tag_num in self.et_prob_ind: - if proba[tag_num] < self.thres_proba: - tag_num = 0 + if proba[0] < self.thres_proba: + tag_num = np.argmax(proba[1:]) + 1 else: tag_num = 0 - tags.append(self.tag_ind_dict[tag_num]) + tags.append(self.tags[tag_num]) tag_probas.append(proba[tag_num]) return tags, tag_probas - def entities_from_tags(self, tokens, tags, tag_probas): + def correct_tags(self, tokens: List[str], tags: List[str]) -> List[str]: + for i in range(len(tags) - 2): + if len(tags[i]) > 1 and tags[i].startswith("B-"): + tag = tags[i].split("-")[1] + if tags[i + 2] == f"I-{tag}" and tags[i + 1] != f"I-{tag}": + tags[i + 1] = f"I-{tag}" + if tokens[i + 1] in '«' and tags[i] != "O": + tags[i] = 
"O" + tags[i + 1] = "O" + if len(tags[i]) > 1 and tags[i].split("-")[1] == "EVENT": + found_n = -1 + for j in range(i + 1, i + 3): + if re.findall(r"[\d]{3,4}", tokens[j]): + found_n = j + break + if found_n > 0: + for j in range(i + 1, found_n + 1): + tags[j] = "I-EVENT" + if i < len(tokens) - 3 and len(tokens[i]) == 1 and tokens[i + 1] == "." and len(tokens[i + 2]) == 1 \ + and tokens[i + 3] == "." and tags[i + 2].startswith("B-"): + tag = tags[i + 2].split("-")[1] + tags[i] = f"B-{tag}" + tags[i + 1] = f"I-{tag}" + tags[i + 2] = f"I-{tag}" + return tags + + def correct_quotes(self, tokens: List[str], tags: List[str], probas: np.array) -> List[str]: + quotes = {"«": "»", '"': '"'} + for i in range(len(tokens)): + if tokens[i] in {"«", '"'}: + quote_start = tokens[i] + end_pos = 0 + for j in range(i + 1, len(tokens)): + if tokens[j] == quotes[quote_start]: + end_pos = j + break + if end_pos and end_pos != i + 1: + probas_sum = np.sum(probas[i + 1:end_pos], axis=0) + tags_probas = {} + for tag in self.entity_prob_ind: + for ind in self.entity_prob_ind[tag]: + if tag not in tags_probas: + tags_probas[tag] = probas_sum[ind] + else: + tags_probas[tag] += probas_sum[ind] + tags_probas = list(tags_probas.items()) + tags_probas = sorted(tags_probas, key=lambda x: x[1], reverse=True) + found_tag = "" + for tag, _ in tags_probas: + if tag != "PERSON": + found_tag = tag + break + if found_tag: + tags[i + 1] = f"B-{found_tag}" + for j in range(i + 2, end_pos): + tags[j] = f"I-{found_tag}" + return tags + + def add_entity(self, entity: str, c_tag: str) -> None: + replace_tokens = [(' - ', '-'), ("'s", ''), (' .', '.'), ('{', ''), ('}', ''), + (' ', ' '), ('"', "'"), ('(', ''), (')', ''), (' +', '+')] + if entity and (entity[-1] in punctuation or entity[-1] == "»"): + entity = entity[:-1] + self.ent_pos_dict[c_tag] = self.ent_pos_dict[c_tag][:-1] + if entity and (entity[0] in punctuation or entity[0] == "«"): + entity = entity[1:] + self.ent_pos_dict[c_tag] = self.ent_pos_dict[c_tag][1:] + entity = ' '.join(entity) + for old, new in replace_tokens: + entity = entity.replace(old, new) + if entity and entity.lower() not in self.stopwords: + cur_probas = self.ent_probas_dict[c_tag] + self.ents_pos_probas_dict[c_tag].append((entity, self.ent_pos_dict[c_tag], + round(sum(cur_probas) / len(cur_probas), 4))) + self.ent_dict[c_tag] = [] + self.ent_pos_dict[c_tag] = [] + self.ent_probas_dict[c_tag] = [] + + def entities_from_tags(self, tokens: List[str], tags: List[str], + tag_probas: List[List[float]]) -> Tuple[dict, dict, dict]: """ This method makes lists of substrings corresponding to entities and entity types and a list of indices of tokens which correspond to entities @@ -126,64 +258,40 @@ def entities_from_tags(self, tokens, tags, tag_probas): list of indices of tokens which correspond to entities (or a dict of tags (keys) and list of indices of entity tokens) """ - entities_dict = defaultdict(list) - entity_dict = defaultdict(list) - entity_positions_dict = defaultdict(list) - entities_positions_dict = defaultdict(list) - entities_probas_dict = defaultdict(list) - entity_probas_dict = defaultdict(list) - replace_tokens = [(' - ', '-'), ("'s", ''), (' .', ''), ('{', ''), ('}', ''), - (' ', ' '), ('"', "'"), ('(', ''), (')', '')] - + self.ent_dict = defaultdict(list) + self.ent_pos_dict = defaultdict(list) + self.ent_probas_dict = defaultdict(list) + self.ents_pos_probas_dict = defaultdict(list) cnt = 0 for n, (tok, tag, probas) in enumerate(zip(tokens, tags, tag_probas)): if tag.split('-')[-1] in 
self.entity_tags: f_tag = tag.split("-")[-1] - if tag.startswith("B-") and any(entity_dict.values()): - for c_tag, entity in entity_dict.items(): - entity = ' '.join(entity) - for old, new in replace_tokens: - entity = entity.replace(old, new) - if entity: - entities_dict[c_tag].append(entity) - entities_positions_dict[c_tag].append(entity_positions_dict[c_tag]) - cur_probas = entity_probas_dict[c_tag] - entities_probas_dict[c_tag].append(round(sum(cur_probas) / len(cur_probas), 4)) - entity_dict[c_tag] = [] - entity_positions_dict[c_tag] = [] - entity_probas_dict[c_tag] = [] - - entity_dict[f_tag].append(tok) - entity_positions_dict[f_tag].append(cnt) - entity_probas_dict[f_tag].append(probas[self.tags_ind[tag]]) - - elif any(entity_dict.values()): - for tag, entity in entity_dict.items(): + if tag.startswith("B-") and any(self.ent_dict.values()): + for c_tag, entity in self.ent_dict.items(): + self.add_entity(entity, c_tag) + self.ent_dict[f_tag].append(tok) + self.ent_pos_dict[f_tag].append(cnt) + self.ent_probas_dict[f_tag].append(probas[self.tags_ind[tag]]) + + elif any(self.ent_dict.values()): + for tag, entity in self.ent_dict.items(): c_tag = tag.split("-")[-1] - entity = ' '.join(entity) - for old, new in replace_tokens: - entity = entity.replace(old, new) - if entity: - entities_dict[c_tag].append(entity) - entities_positions_dict[c_tag].append(entity_positions_dict[c_tag]) - cur_probas = entity_probas_dict[c_tag] - entities_probas_dict[c_tag].append(round(sum(cur_probas) / len(cur_probas), 4)) - - entity_dict[c_tag] = [] - entity_positions_dict[c_tag] = [] - entity_probas_dict[c_tag] = [] + self.add_entity(entity, c_tag) cnt += 1 + if any(self.ent_dict.values()): + for tag, entity in self.ent_dict.items(): + c_tag = tag.split("-")[-1] + self.add_entity(entity, c_tag) - entities_list = [entity for tag, entities in entities_dict.items() for entity in entities] - entities_positions_list = [position for tag, positions in entities_positions_dict.items() - for position in positions] - entities_probas_list = [proba for tag, probas in entities_probas_dict.items() for proba in probas] + self.ents_pos_probas_dict = {tag: elements for tag, elements in self.ents_pos_probas_dict.items() + if tag not in self.ignored_tags} - entities_dict = dict(entities_dict) - entities_positions_dict = dict(entities_positions_dict) - entities_probas_dict = dict(entities_probas_dict) + for tag in self.ents_pos_probas_dict: + ents_pos_proba = self.ents_pos_probas_dict[tag] - if self.return_entities_with_tags: - return entities_dict, entities_positions_dict, entities_probas_dict - else: - return entities_list, entities_positions_list, entities_probas_list + entities_dict = {tag: [ent[0] for ent in ents] for tag, ents in self.ents_pos_probas_dict.items()} + entities_positions_dict = {tag: [ent[1] for ent in ents] for tag, ents in self.ents_pos_probas_dict.items()} + entities_probas_dict = {tag: [ent[2] for ent in ents] for tag, ents in self.ents_pos_probas_dict.items()} + log.debug(f"entities_dict {entities_dict}") + + return entities_dict, entities_positions_dict, entities_probas_dict diff --git a/deeppavlov/models/entity_extraction/entity_linking.py b/deeppavlov/models/entity_extraction/entity_linking.py index b91e1ea412..a0666d3741 100644 --- a/deeppavlov/models/entity_extraction/entity_linking.py +++ b/deeppavlov/models/entity_extraction/entity_linking.py @@ -14,10 +14,11 @@ import re import sqlite3 -from collections import defaultdict from logging import getLogger -from typing import List, Dict, Tuple, 
Union, Any +from typing import List, Dict, Tuple, Any, Union +from collections import defaultdict +import nltk import spacy from hdt import HDTDocument from nltk.corpus import stopwords @@ -27,8 +28,10 @@ from deeppavlov.core.common.registry import register from deeppavlov.core.models.component import Component from deeppavlov.core.models.serializable import Serializable +from deeppavlov.models.entity_extraction.find_word import WordSearcher log = getLogger(__name__) +nltk.download("stopwords") @register("entity_linker") @@ -40,48 +43,61 @@ class EntityLinker(Component, Serializable): def __init__( self, load_path: str, - entities_database_filename: str, entity_ranker=None, + entities_database_filename: str = None, + words_dict_filename: str = None, + ngrams_matrix_filename: str = None, num_entities_for_bert_ranking: int = 50, - wikidata_file: str = None, + num_entities_for_conn_ranking: int = 5, num_entities_to_return: int = 10, max_text_len: int = 300, - lang: str = "en", + max_paragraph_len: int = 150, + lang: str = "ru", use_descriptions: bool = True, + alias_coef: float = 1.1, use_tags: bool = False, lemmatize: bool = False, full_paragraph: bool = False, use_connections: bool = False, - max_paragraph_len: int = 250, + kb_filename: str = None, + prefixes: Dict[str, Any] = None, **kwargs, ) -> None: """ Args: load_path: path to folder with inverted index files - entities_database_filename: file with sqlite database with Wikidata entities index - entity_ranker: deeppavlov.models.torch_bert.torch_transformers_el_ranker.TorchTransformersEntityRankerInfer + entity_ranker: component deeppavlov.models.kbqa.rel_ranking_bert + entities_database_filename: filename with database with entities index + words_dict_filename: filename with words and corresponding tags + ngrams_matrix_filename: filename with char tfidf matrix num_entities_for_bert_ranking: number of candidate entities for BERT ranking using description and context - wikidata_file: .hdt file with Wikidata graph + num_entities_for_conn_ranking: number of candidate entities for ranking using connections in the knowledge + graph num_entities_to_return: number of candidate entities for the substring which are returned - max_text_len: max length of context for entity ranking by description + max_text_len: maximal length of entity context + max_paragraph_len: maximal length of context paragraphs lang: russian or english use_description: whether to perform entity ranking by context and description - use_tags: whether to use ner tags for entity filtering + alias_coef: coefficient which is multiplied by the substring matching confidence if the substring is the + title of the entity + use_tags: whether to filter candidate entities by tags lemmatize: whether to lemmatize tokens - full_paragraph: whether to use full paragraph for entity ranking by context and description - use_connections: whether to ranking entities by number of connections in Wikidata - max_paragraph_len: maximum length of paragraph for ranking by context and description + full_paragraph: whether to use full paragraph for entity context + use_connections: whether to rank entities by connections in the knowledge graph + kb_filename: filename with the knowledge base in HDT format + prefixes: entity and title prefixes **kwargs: """ super().__init__(save_path=None, load_path=load_path) self.lemmatize = lemmatize - self.entities_database_filename = entities_database_filename self.num_entities_for_bert_ranking = num_entities_for_bert_ranking - self.wikidata_file = wikidata_file + 
self.num_entities_for_conn_ranking = num_entities_for_conn_ranking self.entity_ranker = entity_ranker + self.entities_database_filename = entities_database_filename self.num_entities_to_return = num_entities_to_return self.max_text_len = max_text_len + self.max_paragraph_len = max_paragraph_len self.lang = f"@{lang}" if self.lang == "@en": self.stopwords = set(stopwords.words("english")) @@ -89,37 +105,49 @@ def __init__( elif self.lang == "@ru": self.stopwords = set(stopwords.words("russian")) self.nlp = spacy.load("ru_core_news_sm") + self.alias_coef = alias_coef self.use_descriptions = use_descriptions self.use_connections = use_connections - self.max_paragraph_len = max_paragraph_len self.use_tags = use_tags self.full_paragraph = full_paragraph self.re_tokenizer = re.compile(r"[\w']+|[^\w ]") - self.not_found_str = "not_in_wiki" - + self.related_tags = { + "loc": ["gpe", "country", "city", "us_state", "river"], + "gpe": ["loc", "country", "city", "us_state"], + "work_of_art": ["product", "law"], + "product": ["work_of_art"], + "law": ["work_of_art"], + "org": ["fac", "business"], + "business": ["org"] + } + self.word_searcher = None + if words_dict_filename: + self.word_searcher = WordSearcher(words_dict_filename, ngrams_matrix_filename, self.lang) + self.kb_filename = kb_filename + self.prefixes = prefixes self.load() def load(self) -> None: self.conn = sqlite3.connect(str(self.load_path / self.entities_database_filename)) self.cur = self.conn.cursor() - self.wikidata = None - if self.wikidata_file: - self.wikidata = HDTDocument(str(expand_path(self.wikidata_file))) + self.kb = None + if self.kb_filename: + self.kb = HDTDocument(str(expand_path(self.kb_filename))) def save(self) -> None: pass def __call__( self, - entity_substr_batch: List[List[str]], - entity_tags_batch: List[List[str]] = None, + substr_batch: List[List[str]], + tags_batch: List[List[str]] = None, + probas_batch: List[List[float]] = None, sentences_batch: List[List[str]] = None, - entity_offsets_batch: List[List[List[int]]] = None, + offsets_batch: List[List[List[int]]] = None, sentences_offsets_batch: List[List[Tuple[int, int]]] = None, - ) -> Tuple[Union[List[List[List[str]]], List[List[str]]], Union[List[List[List[Any]]], List[List[Any]]], - Union[List[List[List[str]]], List[List[str]]]]: - if (not sentences_offsets_batch or sentences_offsets_batch[0] is None) and sentences_batch is not None \ - or not isinstance(sentences_offsets_batch[0][0], (list, tuple)): + entities_to_link_batch: List[List[int]] = None + ): + if (not sentences_offsets_batch or sentences_offsets_batch[0] is None) and sentences_batch is not None: sentences_offsets_batch = [] for sentences_list in sentences_batch: sentences_offsets_list = [] @@ -130,219 +158,289 @@ def __call__( start = end + 1 sentences_offsets_batch.append(sentences_offsets_list) - if entity_tags_batch is None or not entity_tags_batch[0]: - entity_tags_batch = [["" for _ in entity_substr_list] for entity_substr_list in entity_substr_batch] - else: - entity_tags_batch = [[tag.upper() for tag in entity_tags] for entity_tags in entity_tags_batch] - if sentences_batch is None: - sentences_batch = [[] for _ in entity_substr_batch] - sentences_offsets_batch = [[] for _ in entity_substr_batch] - - log.debug(f"sentences_batch {sentences_batch}") - if (not entity_offsets_batch and sentences_batch) or not entity_offsets_batch[0] \ - or not isinstance(entity_offsets_batch[0][0], (list, tuple)): - entity_offsets_batch = [] - for entity_substr_list, sentences_list in 
zip(entity_substr_batch, sentences_batch): + sentences_batch = [[] for _ in substr_batch] + sentences_offsets_batch = [[] for _ in substr_batch] + + if not entities_to_link_batch or entities_to_link_batch[0] is None: + entities_to_link_batch = [[1 for _ in substr_list] for substr_list in substr_batch] + + log.debug(f"substr: {substr_batch} --- sentences_batch: {sentences_batch} --- offsets: {offsets_batch}") + if (not offsets_batch or offsets_batch[0] is None) and sentences_batch: + offsets_batch = [] + for substr_list, sentences_list in zip(substr_batch, sentences_batch): text = " ".join(sentences_list).lower() log.debug(f"text {text}") - entity_offsets_list = [] - for entity_substr in entity_substr_list: - st_offset = text.find(entity_substr.lower()) - end_offset = st_offset + len(entity_substr) - entity_offsets_list.append([st_offset, end_offset]) - entity_offsets_batch.append(entity_offsets_list) - - entity_ids_batch, entity_conf_batch, entity_pages_batch = [], [], [] - for (entity_substr_list, entity_offsets_list, entity_tags_list, sentences_list, sentences_offsets_list,) in zip( - entity_substr_batch, - entity_offsets_batch, - entity_tags_batch, - sentences_batch, - sentences_offsets_batch, - ): - entity_ids_list, entity_conf_list, entity_pages_list = self.link_entities( - entity_substr_list, - entity_offsets_list, - entity_tags_list, - sentences_list, - sentences_offsets_list, - ) - log.debug(f"entity_ids_list {entity_ids_list} entity_conf_list {entity_conf_list}") - entity_ids_batch.append(entity_ids_list) - entity_conf_batch.append(entity_conf_list) - entity_pages_batch.append(entity_pages_list) - return entity_ids_batch, entity_conf_batch, entity_pages_batch + offsets_list = [] + for substr in substr_list: + st_offset = text.find(substr.lower()) + end_offset = st_offset + len(substr) + offsets_list.append([st_offset, end_offset]) + offsets_batch.append(offsets_list) + ids_batch, conf_batch, pages_batch, labels_batch = [], [], [], [] + for substr_list, offsets_list, tags_list, probas_list, sentences_list, sentences_offsets_list, \ + entities_to_link in zip(substr_batch, offsets_batch, tags_batch, probas_batch, sentences_batch, + sentences_offsets_batch, entities_to_link_batch): + ids_list, conf_list, pages_list, labels_list = \ + self.link_entities(substr_list, offsets_list, tags_list, probas_list, sentences_list, + sentences_offsets_list, entities_to_link) + log.debug(f"ids_list {ids_list} conf_list {conf_list}") + if self.num_entities_to_return == 1: + pages_list = [pages[0] for pages in pages_list] + else: + pages_list = [pages[: len(ids)] for pages, ids in zip(pages_list, ids_list)] + ids_batch.append(ids_list) + conf_batch.append(conf_list) + pages_batch.append(pages_list) + labels_batch.append(labels_list) + return ids_batch, conf_batch, pages_batch, labels_batch def link_entities( self, - entity_substr_list: List[str], - entity_offsets_list: List[List[int]], - entity_tags_list: List[str], + substr_list: List[str], + offsets_list: List[List[int]], + tags_list: List[str], + probas_list: List[float], sentences_list: List[str], sentences_offsets_list: List[List[int]], - ) -> Tuple[Union[List[List[str]], List[str]], Union[List[List[Any]], List[Any]], Union[List[List[str]], List[str]]]: - log.debug( - f"entity_substr_list {entity_substr_list} entity_tags_list {entity_tags_list} " - f"entity_offsets_list {entity_offsets_list}" - ) - entity_ids_list, conf_list, pages_list = [], [], [] - if entity_substr_list: + entities_to_link: List[int] + ) -> Tuple[List[Any], List[Any], 
List[List[Union[str, Any]]], List[List[Union[str, Any]]]]: + log.debug(f"substr_list {substr_list} tags_list {tags_list} probas {probas_list} offsets_list {offsets_list}") + ids_list, conf_list, pages_list, label_list, descr_list = [], [], [], [], [] + if substr_list: entities_scores_list = [] cand_ent_scores_list = [] - entity_substr_split_list = [ - [word for word in entity_substr.split(" ") if word not in self.stopwords and len(word) > 0] - for entity_substr in entity_substr_list - ] - for entity_substr, entity_substr_split, tag in zip( - entity_substr_list, entity_substr_split_list, entity_tags_list - ): + for substr, tags, proba in zip(substr_list, tags_list, probas_list): + for old_symb, new_symb in [("'s", ""), ("@", ""), (" ", " "), (".", ""), (",", ""), ("-", " "), + ("'", " "), ("!", ""), (":", ""), ("&", ""), ("/", " "), ('"', ""), + (" ", " ")]: + substr = substr.replace(old_symb, new_symb) + substr = substr.strip() + cand_ent_init = defaultdict(set) + if len(substr) > 1: + if isinstance(tags, str): + tags = [tags] + tags = [tag.lower() for tag in tags] + if tags and not isinstance(tags[0], (list, tuple)): + tags = [(tag, 1.0) for tag in tags] + if tags and tags[0][0] == "e": + use_tags_flag = False + else: + use_tags_flag = True + cand_ent_init = self.find_exact_match(substr, tags, use_tags=use_tags_flag) + new_substr = re.sub(r"\b([a-z]{1}) ([a-z]{1})\b", r"\1\2", substr) + if substr != new_substr: + new_cand_ent_init = self.find_exact_match(new_substr, tags, use_tags=use_tags_flag) + cand_ent_init = self.unite_dicts(cand_ent_init, new_cand_ent_init) + + init_substr_split = substr.lower().split(" ") + if tags[0][0] in {"person", "work_of_art"}: + substr_split = [word for word in substr.lower().split(" ") if len(word) > 0] + else: + substr_split = [word for word in substr.lower().split(" ") + if word not in self.stopwords and len(word) > 0] + + substr_split_lemm = [self.nlp(tok)[0].lemma_ for tok in substr_split] + substr_lemm = " ".join(substr_split_lemm) + if substr_split != substr_split_lemm \ + or (tags[0][0] == "work_of_art" + and len(substr_split) != len(init_substr_split)): + new_cand_ent_init = self.find_fuzzy_match(substr_split, tags, use_tags=use_tags_flag) + cand_ent_init = self.unite_dicts(cand_ent_init, new_cand_ent_init) + if substr_split != substr_split_lemm: + new_cand_ent_init = self.find_exact_match(substr_lemm, tags, use_tags=use_tags_flag) + cand_ent_init = self.unite_dicts(cand_ent_init, new_cand_ent_init) + new_cand_ent_init = self.find_fuzzy_match(substr_split_lemm, tags, use_tags=use_tags_flag) + cand_ent_init = self.unite_dicts(cand_ent_init, new_cand_ent_init) + + all_low_conf = self.define_all_low_conf(cand_ent_init, 1.0) + clean_tags, corr_tags, corr_clean_tags = self.correct_tags(tags) + log.debug(f"substr: {substr} --- lemm: {substr_split_lemm} --- tags: {tags} --- corr_tags: " + f"{corr_tags} --- all_low_conf: {all_low_conf} --- cand_ent_init: {len(cand_ent_init)}") + + if (not cand_ent_init or all_low_conf) and corr_tags: + corr_cand_ent_init = self.find_exact_match(substr, corr_tags, use_tags=use_tags_flag) + cand_ent_init = self.unite_dicts(cand_ent_init, corr_cand_ent_init) + if substr_split != substr_split_lemm: + new_cand_ent_init = self.find_exact_match(substr_lemm, corr_tags, use_tags=use_tags_flag) + cand_ent_init = self.unite_dicts(cand_ent_init, new_cand_ent_init) + new_cand_ent_init = self.find_fuzzy_match(substr_split_lemm, corr_tags, + use_tags=use_tags_flag) + cand_ent_init = self.unite_dicts(cand_ent_init, new_cand_ent_init) + + 
if not cand_ent_init and len(substr_split) == 1 and self.word_searcher: + corr_words = self.word_searcher(substr_split[0], set(clean_tags + corr_clean_tags)) + if corr_words: + cand_ent_init = self.find_exact_match(corr_words[0], tags + corr_tags, + use_tags=use_tags_flag) + + if not cand_ent_init and len(substr_split) > 1: + cand_ent_init = self.find_fuzzy_match(substr_split, tags) + + all_low_conf = self.define_all_low_conf(cand_ent_init, 0.85) + if (not cand_ent_init or all_low_conf) and tags[0][0] != "t": + use_tags_flag = False + new_cand_ent_init = self.find_exact_match(substr, tags, use_tags=use_tags_flag) + cand_ent_init = self.unite_dicts(cand_ent_init, new_cand_ent_init) + if substr_split != substr_split_lemm and (tags[0][0] == "e" or not cand_ent_init): + new_cand_ent_init = self.find_fuzzy_match(substr_split, tags, use_tags=use_tags_flag) + cand_ent_init = self.unite_dicts(cand_ent_init, new_cand_ent_init) + new_cand_ent_init = self.find_fuzzy_match(substr_split_lemm, tags, use_tags=use_tags_flag) + cand_ent_init = self.unite_dicts(cand_ent_init, new_cand_ent_init) + cand_ent_scores = [] - if len(entity_substr) > 1: - entity_substr_split_lemm = [self.nlp(tok)[0].lemma_ for tok in entity_substr_split] - cand_ent_init = self.find_exact_match(entity_substr, tag) - if not cand_ent_init or entity_substr_split != entity_substr_split_lemm: - cand_ent_init = self.find_fuzzy_match(entity_substr_split, tag) - - for entity in cand_ent_init: - entities_scores = list(cand_ent_init[entity]) - entities_scores = sorted(entities_scores, key=lambda x: (x[0], x[1]), reverse=True) - cand_ent_scores.append((entity, entities_scores[0])) - cand_ent_scores = sorted(cand_ent_scores, key=lambda x: (x[1][0], x[1][1]), reverse=True) - - cand_ent_scores = cand_ent_scores[:self.num_entities_for_bert_ranking] + for entity in cand_ent_init: + entities_scores = list(cand_ent_init[entity]) + entities_scores = sorted(entities_scores, key=lambda x: (x[0], x[2], x[1]), reverse=True) + cand_ent_scores.append(([entity] + list(entities_scores[0]))) + + cand_ent_scores = sorted(cand_ent_scores, key=lambda x: (x[1], x[3], x[2]), reverse=True) + cand_ent_scores = cand_ent_scores[: self.num_entities_for_bert_ranking] cand_ent_scores_list.append(cand_ent_scores) entity_ids = [elem[0] for elem in cand_ent_scores] - entities_scores_list.append({ent: score for ent, score in cand_ent_scores}) - entity_ids_list.append(entity_ids) - - if self.use_connections: - entity_ids_list = [] - entities_with_conn_scores_list = self.rank_by_connections(cand_ent_scores_list) - for entities_with_conn_scores in entities_with_conn_scores_list: - entity_ids = [elem[0] for elem in entities_with_conn_scores] - entity_ids_list.append(entity_ids) - - entity_descr_list = [] - pages_dict = {} - for entity_ids in entity_ids_list: - entity_descrs = [] - for entity_id in entity_ids: - res = self.cur.execute("SELECT * FROM entity_labels WHERE entity='{}';".format(entity_id)) - entity_info = res.fetchall() - if entity_info: - ( - cur_entity_id, - cur_entity_label, - cur_entity_descr, - cur_entity_page, - ) = entity_info[0] - entity_descrs.append(cur_entity_descr) - pages_dict[cur_entity_id] = cur_entity_page - else: - entity_descrs.append("") - entity_descr_list.append(entity_descrs) - if self.use_descriptions: - substr_lens = [len(entity_substr.split()) for entity_substr in entity_substr_list] - entity_ids_list, conf_list = self.rank_by_description( - entity_substr_list, - entity_offsets_list, - entity_ids_list, - entity_descr_list, - 
entities_scores_list, - sentences_list, - sentences_offsets_list, - substr_lens, + scores = [elem[1:4] for elem in cand_ent_scores] + conf_list.append(scores) + entities_scores_list.append( + {entity_id: entity_scores for entity_id, entity_scores in zip(entity_ids, scores)} ) - if self.num_entities_to_return == 1: - pages_list = [pages_dict.get(entity_ids, "") for entity_ids in entity_ids_list] + ids_list.append(entity_ids) + pages = [elem[4] for elem in cand_ent_scores] + entity_labels = [elem[5] for elem in cand_ent_scores] + pages_list.append({entity_id: page for entity_id, page in zip(entity_ids, pages)}) + label_list.append( + {entity_id: entity_label for entity_id, entity_label in zip(entity_ids, entity_labels)}) + descr_list.append([elem[6] for elem in cand_ent_scores]) + + scores_dict = {} + if self.use_connections and self.kb: + scores_dict = self.rank_by_connections(ids_list) + + substr_lens = [len(entity_substr.split()) for entity_substr in substr_list] + ids_list, conf_list = self.rank_by_description(substr_list, tags_list, offsets_list, ids_list, + descr_list, entities_scores_list, sentences_list, + sentences_offsets_list, substr_lens, scores_dict) + label_list = [[label_dict.get(entity_id, "") for entity_id in entity_ids] + for entity_ids, label_dict in zip(ids_list, label_list)] + pages_list = [[pages_dict.get(entity_id, "") for entity_id in entity_ids] + for entity_ids, pages_dict in zip(ids_list, pages_list)] + + f_ids_list, f_conf_list, f_pages_list, f_label_list = [], [], [], [] + for ids, confs, pages, labels, add_flag in \ + zip(ids_list, conf_list, pages_list, label_list, entities_to_link): + if add_flag: + f_ids_list.append(ids) + f_conf_list.append(confs) + f_pages_list.append(pages) + f_label_list.append(labels) + return f_ids_list, f_conf_list, f_pages_list, f_label_list + + def define_all_low_conf(self, cand_ent_init, thres): + all_low_conf = True + for entity_id in cand_ent_init: + entity_info_set = cand_ent_init[entity_id] + for entity_info in entity_info_set: + if entity_info[0] >= thres: + all_low_conf = False + break + if not all_low_conf: + break + return all_low_conf + + def correct_tags(self, tags): + clean_tags = [tag for tag, conf in tags] + corr_tags, corr_clean_tags = [], [] + for tag, conf in tags: + if tag in self.related_tags: + corr_tag_list = self.related_tags[tag] + for corr_tag in corr_tag_list: + if corr_tag not in clean_tags and corr_tag not in corr_clean_tags: + corr_tags.append([corr_tag, conf]) + corr_clean_tags.append(corr_tag) + return clean_tags, corr_tags, corr_clean_tags + + def unite_dicts(self, cand_ent_init, new_cand_ent_init): + for entity_id in new_cand_ent_init: + if entity_id in cand_ent_init: + for entity_info in new_cand_ent_init[entity_id]: + cand_ent_init[entity_id].add(entity_info) else: - pages_list = [[pages_dict.get(entity_id, "") for entity_id in entity_ids] - for entity_ids in entity_ids_list] - - return entity_ids_list, conf_list, pages_list - - def process_cand_ent(self, cand_ent_init, entities_and_ids, entity_substr_split, tag): - if self.use_tags: - for cand_entity_title, cand_entity_id, cand_entity_rels, cand_tag, *_ in entities_and_ids: - if not tag or tag == cand_tag: - substr_score = self.calc_substr_score(cand_entity_title, entity_substr_split) - cand_ent_init[cand_entity_id].add((substr_score, cand_entity_rels)) - if not cand_ent_init: - for cand_entity_title, cand_entity_id, cand_entity_rels, cand_tag, *_ in entities_and_ids: - substr_score = self.calc_substr_score(cand_entity_title, entity_substr_split) - 
cand_ent_init[cand_entity_id].add((substr_score, cand_entity_rels)) - else: - for cand_entity_title, cand_entity_id, cand_entity_rels, *_ in entities_and_ids: - substr_score = self.calc_substr_score(cand_entity_title, entity_substr_split) - cand_ent_init[cand_entity_id].add((substr_score, cand_entity_rels)) + cand_ent_init[entity_id] = new_cand_ent_init[entity_id] + return cand_ent_init + + def process_cand_ent(self, cand_ent_init, entities_and_ids, substr_split, tag, tag_conf, use_tags): + for title, entity_id, rels, ent_tag, page, label, descr in entities_and_ids: + if (ent_tag == tag and use_tags) or not use_tags: + substr_score = self.calc_substr_score(title, substr_split, tag, ent_tag, label) + cand_ent_init[entity_id].add((substr_score, rels, tag_conf, page, label, descr)) return cand_ent_init - def find_title(self, entity_substr): - entities_and_ids = [] - try: - res = self.cur.execute("SELECT * FROM inverted_index WHERE title MATCH '{}';".format(entity_substr)) - entities_and_ids = res.fetchall() - except sqlite3.OperationalError as e: - log.debug(f"error in searching an entity {e}") - return entities_and_ids + def sanitize_substr(self, entity_substr, tag): + if tag == "person": + entity_substr_split = entity_substr.split() + if len(entity_substr_split) > 1 and len(entity_substr_split[-1]) > 1 and len(entity_substr_split[-2]) == 1: + entity_substr = entity_substr_split[-1] + return entity_substr - def find_exact_match(self, entity_substr, tag): + def find_exact_match(self, entity_substr, tags, use_tags=True): + entity_substr = entity_substr.lower() entity_substr_split = entity_substr.split() cand_ent_init = defaultdict(set) - entities_and_ids = self.find_title(entity_substr) - if entities_and_ids: - cand_ent_init = self.process_cand_ent(cand_ent_init, entities_and_ids, entity_substr_split, tag) - if entity_substr.startswith("the "): - entity_substr = entity_substr.split("the ")[1] - entity_substr_split = entity_substr_split[1:] - entities_and_ids = self.find_title(entity_substr) - cand_ent_init = self.process_cand_ent(cand_ent_init, entities_and_ids, entity_substr_split, tag) - - entity_substr_split_lemm = [self.nlp(tok)[0].lemma_ for tok in entity_substr_split] - entity_substr_lemm = " ".join(entity_substr_split_lemm) - if entity_substr_lemm != entity_substr: - entities_and_ids = self.find_title(entity_substr_lemm) + for tag, tag_conf in tags: + entity_substr = self.sanitize_substr(entity_substr, tag) + query = "SELECT * FROM inverted_index WHERE title MATCH ?;" + entities_and_ids = [] + try: + res = self.cur.execute(query, (entity_substr,)) + entities_and_ids = res.fetchall() + except: + log.info(f"error in query execute {query}") if entities_and_ids: - cand_ent_init = self.process_cand_ent(cand_ent_init, entities_and_ids, entity_substr_split_lemm, tag) + cand_ent_init = self.process_cand_ent( + cand_ent_init, entities_and_ids, entity_substr_split, tag, tag_conf, use_tags) return cand_ent_init - def find_fuzzy_match(self, entity_substr_split, tag): - entity_substr_split_lemm = [self.nlp(tok)[0].lemma_ for tok in entity_substr_split] + def find_fuzzy_match(self, entity_substr_split, tags, use_tags=True): cand_ent_init = defaultdict(set) - for word in entity_substr_split: - part_entities_and_ids = self.find_title(word) - cand_ent_init = self.process_cand_ent(cand_ent_init, part_entities_and_ids, entity_substr_split, tag) - if self.lang == "@ru": - word_lemm = self.nlp(word)[0].lemma_ - if word != word_lemm: - part_entities_and_ids = self.find_title(word_lemm) - cand_ent_init = 
self.process_cand_ent( - cand_ent_init, - part_entities_and_ids, - entity_substr_split_lemm, - tag - ) + for tag, tag_conf in tags: + if len(entity_substr_split) > 3: + entity_substr_split = [" ".join(entity_substr_split[i:i + 2]) + for i in range(len(entity_substr_split) - 1)] + for word in entity_substr_split: + if len(word) > 1 and word not in self.stopwords: + query = "SELECT * FROM inverted_index WHERE title MATCH ?;" + part_entities_and_ids = [] + try: + res = self.cur.execute(query, (word,)) + part_entities_and_ids = res.fetchall() + except: + log.info(f"error in query execute {query}") + if part_entities_and_ids: + cand_ent_init = self.process_cand_ent( + cand_ent_init, part_entities_and_ids, entity_substr_split, tag, tag_conf, use_tags) return cand_ent_init - def calc_substr_score(self, cand_entity_title, entity_substr_split): - label_tokens = cand_entity_title.split() + def match_tokens(self, entity_substr_split, label_tokens): cnt = 0.0 - for ent_tok in entity_substr_split: - found = False - for label_tok in label_tokens: - if label_tok == ent_tok: - found = True - break - if found: - cnt += 1.0 - else: + if not (len(entity_substr_split) > 1 and len(label_tokens) > 1 + and set(entity_substr_split) != set(label_tokens) and label_tokens[0] != label_tokens[-1] + and ((entity_substr_split[0] == label_tokens[-1]) or (entity_substr_split[-1] == label_tokens[0]))): + for ent_tok in entity_substr_split: + found = False for label_tok in label_tokens: - if label_tok[:2] == ent_tok[:2]: - fuzz_score = fuzz.ratio(label_tok, ent_tok) - if fuzz_score >= 80.0 and not found: - cnt += fuzz_score * 0.01 - break + if label_tok == ent_tok: + found = True + break + if found: + cnt += 1.0 + else: + for label_tok in label_tokens: + if label_tok[:2] == ent_tok[:2]: + fuzz_score = fuzz.ratio(label_tok, ent_tok) + c_long_toks = len(label_tok) >= 8 and label_tok[:6] == ent_tok[:6] and fuzz_score > 70.0 + c_shrt_toks = len(label_tokens) > 2 and len(label_tok) > 3 and label_tok[:4] == ent_tok[:4] + if (fuzz_score >= 75.0 or c_long_toks or c_shrt_toks) and not found: + cnt += fuzz_score * 0.01 + break substr_score = round(cnt / max(len(label_tokens), len(entity_substr_split)), 3) if len(label_tokens) == 2 and len(entity_substr_split) == 1: if entity_substr_split[0] == label_tokens[1]: @@ -351,120 +449,99 @@ def calc_substr_score(self, cand_entity_title, entity_substr_split): substr_score = 0.3 return substr_score - def rank_by_connections(self, cand_ent_scores_list: List[List[Union[str, Tuple[str, str]]]]): - entities_for_ranking_list = [] - for entities_scores in cand_ent_scores_list: - entities_for_ranking = [] - if entities_scores: - max_score = entities_scores[0][1][0] - for entity, scores in entities_scores: - if scores[0] == max_score: - entities_for_ranking.append(entity) - entities_for_ranking_list.append(entities_for_ranking) - - entities_sets_list = [] - for entities_scores in cand_ent_scores_list: - entities_sets_list.append({entity for entity, scores in entities_scores}) - - entities_conn_scores_list = [] - for entities_scores in cand_ent_scores_list: - cur_entity_dict = {} - for entity, scores in entities_scores: - cur_entity_dict[entity] = 0 - entities_conn_scores_list.append(cur_entity_dict) - - entities_objects_list, entities_triplets_list = [], [] - for entities_scores in cand_ent_scores_list: - cur_objects_dict, cur_triplets_dict = {}, {} - for entity, scores in entities_scores: - objects, triplets = set(), set() - tr, cnt = self.wikidata.search_triples(f"http://we/{entity}", "", "") - for 
triplet in tr: - objects.add(triplet[2].split("/")[-1]) - triplets.add((triplet[1].split("/")[-1], triplet[2].split("/")[-1])) - cur_objects_dict[entity] = objects - cur_triplets_dict[entity] = triplets - entities_objects_list.append(cur_objects_dict) - entities_triplets_list.append(cur_triplets_dict) - - already_ranked = {i: False for i in range(len(entities_for_ranking_list))} - - for i in range(len(entities_for_ranking_list)): - for entity1 in entities_for_ranking_list[i]: - for j in range(len(entities_for_ranking_list)): - if i != j and not already_ranked[j]: - inters = entities_objects_list[i][entity1].intersection(entities_sets_list[j]) - if inters: - entities_conn_scores_list[i][entity1] += len(inters) - for entity2 in inters: - entities_conn_scores_list[j][entity2] += len(inters) - already_ranked[j] = True - else: - for entity2 in entities_triplets_list[j]: - inters = entities_triplets_list[i][entity1].intersection( - entities_triplets_list[j][entity2] - ) - inters = {elem for elem in inters if elem[0] != "P31"} - if inters: - prev_score1 = entities_conn_scores_list[i].get(entity1, 0) - prev_score2 = entities_conn_scores_list[j].get(entity2, 0) - entities_conn_scores_list[i][entity1] = max(len(inters), prev_score1) - entities_conn_scores_list[j][entity2] = max(len(inters), prev_score2) - - entities_with_conn_scores_list = [] - for i in range(len(entities_conn_scores_list)): - entities_with_conn_scores_list.append( - sorted( - list(entities_conn_scores_list[i].items()), - key=lambda x: x[1], - reverse=True, - ) - ) - return entities_with_conn_scores_list + def correct_substr_score(self, entity_substr_split, label_tokens, substr_score): + if sum([len(tok) == 1 for tok in entity_substr_split]) == 2 and len(label_tokens) >= 2 \ + and any([(len(tok) == 2 and re.findall(r"[a-z]{2}", tok)) for tok in label_tokens]): + new_label_tokens = [] + for tok in label_tokens: + if len(tok) == 2 and re.findall(r"[a-z]{2}", tok): + new_label_tokens.append(tok[0]) + new_label_tokens.append(tok[1]) + else: + new_label_tokens.append(tok) + label_tokens = new_label_tokens + if any([re.findall(r"[\d]{4}", tok) for tok in entity_substr_split]) \ + and any([re.findall(r"[\d]{4}–[\d]{2}", tok) for tok in label_tokens]): + new_label_tokens = [] + for tok in label_tokens: + if re.findall(r"[\d]{4}–[\d]{2}", tok): + new_label_tokens.append(tok[:4]) + new_label_tokens.append(tok[5:]) + else: + new_label_tokens.append(tok) + label_tokens = new_label_tokens + new_substr_score = self.match_tokens(entity_substr_split, label_tokens) + substr_score = max(substr_score, new_substr_score) + return substr_score + + def calc_substr_score(self, entity_title, entity_substr_split, tag, ent_tag, entity_label): + if self.lang == "@ru": + entity_title = entity_title.replace("ё", "е") + label_tokens = entity_title.split() + substr_score = self.match_tokens(entity_substr_split, label_tokens) + substr_score = self.correct_substr_score(entity_substr_split, label_tokens, substr_score) + if re.findall(r" \(.*\)", entity_label): + entity_label_split = entity_label.replace("(", "").replace(")", "").lower().split() + lbl_substr_score = self.match_tokens(entity_substr_split, entity_label_split) + substr_score = max(substr_score, lbl_substr_score) + if tag == ent_tag and tag.lower() == "person" and len(entity_substr_split) > 1 \ + and len(entity_substr_split[-1]) > 1 and len(entity_substr_split[-2]) == 1 \ + and len(label_tokens) == len(entity_substr_split): + cnt = 0.0 + for j in range(len(label_tokens) - 1): + if label_tokens[j][0] == 
entity_substr_split[j][0]: + cnt += 1.0 + if label_tokens[-1] == entity_substr_split[-1]: + cnt += 1.0 + new_substr_score = cnt / len(label_tokens) + substr_score = max(substr_score, new_substr_score) + + if entity_title.lower() == entity_label.lower() and substr_score == 1.0: + substr_score = substr_score * self.alias_coef + return substr_score def rank_by_description( self, entity_substr_list: List[str], + tags_list: List[str], entity_offsets_list: List[List[int]], cand_ent_list: List[List[str]], cand_ent_descr_list: List[List[str]], entities_scores_list: List[Dict[str, Tuple[int, float]]], sentences_list: List[str], - sentences_offsets_list: List[List[int]], + sentences_offsets_list: List[Tuple[int, int]], substr_lens: List[int], - ) -> Tuple[Union[List[List[str]], List[str]], Union[List[List[Any]], List[Any]]]: + scores_dict: Dict[str, int] = None + ) -> Tuple[List[Union[Union[float, List[Any], List[Union[float, Any]]], Any]], List[ + Union[Union[tuple, List[tuple], List[Any], List[Tuple[Union[float, Any], ...]]], Any]]]: entity_ids_list = [] conf_list = [] contexts = [] - for ( - entity_substr, - (entity_start_offset, entity_end_offset), - candidate_entities, - ) in zip(entity_substr_list, entity_offsets_list, cand_ent_list): - sentence = "" - rel_start_offset = 0 - rel_end_offset = 0 - found_sentence_num = 0 - for num, (sent, (sent_start_offset, sent_end_offset)) in enumerate( - zip(sentences_list, sentences_offsets_list) - ): - if entity_start_offset >= sent_start_offset and entity_end_offset <= sent_end_offset: - sentence = sent - found_sentence_num = num - rel_start_offset = entity_start_offset - sent_start_offset - rel_end_offset = entity_end_offset - sent_start_offset - break - context = "" + for entity_offset in entity_offsets_list: + context, sentence = "", "" + if len(entity_offset) == 2: + entity_start_offset, entity_end_offset = entity_offset + rel_start_offset = 0 + rel_end_offset = 0 + found_sentence_num = 0 + for num, (sent, (sent_start_offset, sent_end_offset)) in enumerate( + zip(sentences_list, sentences_offsets_list) + ): + if entity_start_offset >= sent_start_offset and entity_end_offset <= sent_end_offset: + sentence = sent + found_sentence_num = num + rel_start_offset = entity_start_offset - sent_start_offset + rel_end_offset = entity_end_offset - sent_start_offset + break if sentence: start_of_sentence = 0 end_of_sentence = len(sentence) if len(sentence) > self.max_text_len: start_of_sentence = max(rel_start_offset - self.max_text_len // 2, 0) end_of_sentence = min(rel_end_offset + self.max_text_len // 2, len(sentence)) - context = ( - sentence[start_of_sentence:rel_start_offset] + "[ENT]" + sentence[ - rel_end_offset:end_of_sentence] - ) + text_before = sentence[start_of_sentence:rel_start_offset] + text_after = sentence[rel_end_offset:end_of_sentence] + context = text_before + "[ENT]" + text_after if self.full_paragraph: cur_sent_len = len(re.findall(self.re_tokenizer, context)) first_sentence_num = found_sentence_num @@ -473,24 +550,16 @@ def rank_by_description( while True: added = False if last_sentence_num < len(sentences_list) - 1: - last_sentence_len = len( - re.findall( - self.re_tokenizer, - sentences_list[last_sentence_num + 1], - ) - ) + sentence_tokens = re.findall(self.re_tokenizer, sentences_list[last_sentence_num + 1]) + last_sentence_len = len(sentence_tokens) if cur_sent_len + last_sentence_len < self.max_paragraph_len: context.append(sentences_list[last_sentence_num + 1]) cur_sent_len += last_sentence_len last_sentence_num += 1 added = True if 
first_sentence_num > 0: - first_sentence_len = len( - re.findall( - self.re_tokenizer, - sentences_list[first_sentence_num - 1], - ) - ) + sentence_tokens = re.findall(self.re_tokenizer, sentences_list[first_sentence_num - 1]) + first_sentence_len = len(sentence_tokens) if cur_sent_len + first_sentence_len < self.max_paragraph_len: context = [sentences_list[first_sentence_num - 1]] + context cur_sent_len += first_sentence_len @@ -503,52 +572,56 @@ def rank_by_description( log.debug(f"rank, context: {context}") contexts.append(context) - scores_list = self.entity_ranker(contexts, cand_ent_list, cand_ent_descr_list) - for (entity_substr, candidate_entities, substr_len, entities_scores, scores,) in zip( - entity_substr_list, - cand_ent_list, - substr_lens, - entities_scores_list, - scores_list, + if self.use_descriptions: + scores_list = self.entity_ranker(contexts, cand_ent_list, cand_ent_descr_list) + else: + scores_list = [[(entity_id, 1.0) for entity_id in cand_ent] for cand_ent in cand_ent_list] + + for entity_substr, tag, context, candidate_entities, substr_len, entities_scores, scores in zip( + entity_substr_list, tags_list, contexts, cand_ent_list, substr_lens, entities_scores_list, scores_list ): - log.debug(f"len candidate entities {len(candidate_entities)}") - entities_with_scores = [ - ( - entity, - round(entities_scores.get(entity, (0.0, 0))[0], 2), - entities_scores.get(entity, (0.0, 0))[1], - round(float(score), 2), - ) - for entity, score in scores - ] - log.debug(f"len entities with scores {len(entities_with_scores)}") - entities_with_scores = sorted(entities_with_scores, key=lambda x: (x[1], x[3], x[2]), reverse=True) - log.debug(f"--- entities_with_scores {entities_with_scores}") + entities_with_scores = [] + max_conn_score = 0 + if scores_dict and scores: + max_conn_score = max([scores_dict.get(entity, 0) for entity, _ in scores]) + for entity, score in scores: + substr_score = round(entities_scores.get(entity, (0.0, 0))[0], 2) + num_rels = entities_scores.get(entity, (0.0, 0))[1] + if len(context.split()) < 4: + score = 0.95 + elif scores_dict and 0 < max_conn_score == scores_dict.get(entity, 0): + score = 1.0 + num_rels = 200 + entities_with_scores.append((entity, substr_score, num_rels, float(score))) + + if tag == "t": + entities_with_scores = sorted(entities_with_scores, key=lambda x: (x[1], x[2], x[3]), reverse=True) + else: + entities_with_scores = sorted(entities_with_scores, key=lambda x: (x[1], x[3], x[2]), reverse=True) + log.debug(f"{entity_substr} --- tag: {tag} --- entities_with_scores: {entities_with_scores}") if not entities_with_scores: - top_entities = [self.not_found_str] - top_conf = [(0.0, 0, 0.0)] + top_entities = [] + top_conf = [] elif entities_with_scores and substr_len == 1 and entities_with_scores[0][1] < 1.0: - top_entities = [self.not_found_str] - top_conf = [(0.0, 0, 0.0)] + top_entities = [] + top_conf = [] elif entities_with_scores and ( entities_with_scores[0][1] < 0.3 or (entities_with_scores[0][3] < 0.13 and entities_with_scores[0][2] < 20) or (entities_with_scores[0][3] < 0.3 and entities_with_scores[0][2] < 4) or entities_with_scores[0][1] < 0.6 ): - top_entities = [self.not_found_str] - top_conf = [(0.0, 0, 0.0)] + top_entities = [] + top_conf = [] else: top_entities = [score[0] for score in entities_with_scores] top_conf = [score[1:] for score in entities_with_scores] - log.debug(f"--- top_entities {top_entities} top_conf {top_conf}") - high_conf_entities = [] high_conf_nums = [] for elem_num, (entity, conf) in 
enumerate(zip(top_entities, top_conf)): - if len(conf) == 3 and conf[0] == 1.0 and conf[1] > 50 and conf[2] > 0.3: + if len(conf) == 3 and conf[0] >= 1.0 and conf[1] > 50 and conf[2] > 0.3: new_conf = list(conf) if new_conf[1] > 55: new_conf[2] = 1.0 @@ -557,6 +630,7 @@ def rank_by_description( high_conf_nums.append(elem_num) high_conf_entities = sorted(high_conf_entities, key=lambda x: (x[1], x[3], x[2]), reverse=True) + log.debug(f"high_conf_entities: {high_conf_entities}") for n, elem_num in enumerate(high_conf_nums): if 0 <= elem_num - n < len(top_entities): del top_entities[elem_num - n] @@ -565,12 +639,83 @@ def rank_by_description( top_entities = [elem[0] for elem in high_conf_entities] + top_entities top_conf = [elem[1:] for elem in high_conf_entities] + top_conf - log.debug(f"top entities {top_entities} top_conf {top_conf}") + if not top_entities: + entities_with_scores = sorted(entities_with_scores, key=lambda x: (x[1], x[2], x[3]), reverse=True) + top_entities = [score[0] for score in entities_with_scores] + top_conf = [score[1:] for score in entities_with_scores] if self.num_entities_to_return == 1 and top_entities: entity_ids_list.append(top_entities[0]) - conf_list.append(top_conf[0]) + conf_list.append([round(cnf, 2) for cnf in top_conf[0]]) + elif self.num_entities_to_return == "max": + if top_conf: + max_conf = top_conf[0][0] + max_rank_conf = top_conf[0][2] + entity_ids, confs = [], [] + for entity_id, conf in zip(top_entities, top_conf): + if (conf[0] >= max_conf * 0.9 and max_rank_conf <= 1.0) \ + or (max_rank_conf == 1.0 and conf[2] == 1.0): + entity_ids.append(entity_id) + confs.append([round(cnf, 2) for cnf in conf]) + entity_ids_list.append(entity_ids) + conf_list.append(confs) + else: + entity_ids_list.append([]) + conf_list.append([]) else: entity_ids_list.append(top_entities[: self.num_entities_to_return]) - conf_list.append(top_conf[: self.num_entities_to_return]) + conf_list.append([[round(cnf, 2) for cnf in conf] for conf in top_conf[: self.num_entities_to_return]]) + log.debug(f"{entity_substr} --- top entities {entity_ids_list[-1]} --- top_conf {conf_list[-1]}") return entity_ids_list, conf_list + + def sort_out_low_conf(self, entity_substr, top_entities, top_conf): + if len(entity_substr.split()) > 1 and top_conf: + f_top_entities, f_top_conf = [], [] + for top_conf_thres, conf_thres in [(1.0, 0.9), (0.9, 0.8)]: + if top_conf[0][0] >= top_conf_thres: + for ent, conf in zip(top_entities, top_conf): + if conf[0] > conf_thres: + f_top_entities.append(ent) + f_top_conf.append(conf) + return f_top_entities, f_top_conf + return top_entities, top_conf + + def rank_by_connections(self, ids_list): + objects_sets_dict, scores_dict, conn_dict = {}, {}, {} + for ids in ids_list: + for entity_id in ids: + scores_dict[entity_id] = 0 + conn_dict[entity_id] = set() + for ids in ids_list: + for entity_id in ids[:self.num_entities_for_conn_ranking]: + objects = set() + for prefix in self.prefixes["entity"]: + tr, _ = self.kb.search_triples(f"{prefix}/{entity_id}", "", "") + for subj, rel, obj in tr: + if rel.split("/")[-1] not in {"P31", "P279"}: + if any([obj.startswith(pr) for pr in self.prefixes["entity"]]): + objects.add(obj.split("/")[-1]) + if rel.startswith(self.prefixes["rels"]["no_type"]): + tr2, _ = self.kb.search_triples(obj, "", "") + for _, rel2, obj2 in tr2: + if rel2.startswith(self.prefixes["rels"]["statement"]) \ + or rel2.startswith(self.prefixes["rels"]["qualifier"]): + if any([obj2.startswith(pr) for pr in self.prefixes["entity"]]): + 
objects.add(obj2.split("/")[-1]) + objects_sets_dict[entity_id] = objects + for obj in objects: + if obj not in objects_sets_dict: + objects_sets_dict[obj] = set() + objects_sets_dict[obj].add(entity_id) + + for i in range(len(ids_list)): + for j in range(len(ids_list)): + if i != j: + for entity_id1 in ids_list[i][:self.num_entities_for_conn_ranking]: + for entity_id2 in ids_list[j][:self.num_entities_for_conn_ranking]: + if entity_id1 in objects_sets_dict[entity_id2]: + conn_dict[entity_id1].add(entity_id2) + conn_dict[entity_id2].add(entity_id1) + for entity_id in conn_dict: + scores_dict[entity_id] = len(conn_dict[entity_id]) + return scores_dict diff --git a/deeppavlov/models/entity_extraction/find_word.py b/deeppavlov/models/entity_extraction/find_word.py new file mode 100644 index 0000000000..2bda42cc9a --- /dev/null +++ b/deeppavlov/models/entity_extraction/find_word.py @@ -0,0 +1,102 @@ +# Copyright 2017 Neural Networks and Deep Learning lab, MIPT +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import pickle +from collections import Counter + +import numpy as np +import scipy as sp + +from deeppavlov.core.commands.utils import expand_path + +Sparse = sp.sparse.csr_matrix + + +class WordSearcher: + def __init__(self, words_dict_filename: str, ngrams_matrix_filename: str, lang: str = "@en", thresh: int = 1000): + self.words_dict_filename = words_dict_filename + self.ngrams_matrix_filename = ngrams_matrix_filename + if lang == "@en": + self.letters = "abcdefghijklmnopqrstuvwxyz" + elif lang == "@ru": + self.letters = "абвгдеёжзийклмнопрстуфхцчшщъыьэюя" + else: + raise ValueError(f'Unexpected lang value: "{lang}"') + self.thresh = thresh + self.load() + self.make_ngrams_dicts() + + def load(self): + with open(str(expand_path(self.words_dict_filename)), "rb") as fl: + self.words_dict = pickle.load(fl) + words_list = list(self.words_dict.keys()) + self.words_list = sorted(words_list) + + loader = np.load(str(expand_path(self.ngrams_matrix_filename)), allow_pickle=True) + self.count_matrix = Sparse((loader["data"], loader["indices"], loader["indptr"]), shape=loader["shape"]) + + def make_ngrams_dicts(self): + self.bigrams_dict, self.trigrams_dict = {}, {} + bigram_combs = list(itertools.product(self.letters, self.letters)) + bigram_combs = ["".join(comb) for comb in bigram_combs] + trigram_combs = list(itertools.product(self.letters, self.letters, self.letters)) + trigram_combs = ["".join(comb) for comb in trigram_combs] + for cnt, bigram in enumerate(bigram_combs): + self.bigrams_dict[bigram] = cnt + for cnt, trigram in enumerate(trigram_combs): + self.trigrams_dict[trigram] = cnt + len(bigram_combs) + + def __call__(self, query, tags): + ngrams_list = [] + for i in range(len(query) - 1): + ngram = query[i : i + 2].lower() + if ngram in self.bigrams_dict: + ngram_id = self.bigrams_dict[ngram] + ngrams_list.append(ngram_id) + for i in range(len(query) - 2): + ngram = query[i : i + 3].lower() + if ngram in self.trigrams_dict: + ngram_id = self.trigrams_dict[ngram] + 
ngrams_list.append(ngram_id) + ngrams_with_cnts = Counter(ngrams_list).most_common() + ngram_ids = [elem[0] for elem in ngrams_with_cnts] + ngram_cnts = [1 for _ in ngrams_with_cnts] + + indptr = np.array([0, len(ngram_cnts)]) + query_matrix = Sparse( + (ngram_cnts, ngram_ids, indptr), shape=(1, len(self.bigrams_dict) + len(self.trigrams_dict)) + ) + + scores = query_matrix * self.count_matrix + scores = np.squeeze(scores.toarray()) + + if self.thresh >= len(scores): + o = np.argpartition(-scores, len(scores) - 1)[0:self.thresh] + else: + o = np.argpartition(-scores, self.thresh)[0:self.thresh] + o_sort = o[np.argsort(-scores[o])] + o_sort = o_sort.tolist() + + found_words = [self.words_list[n] for n in o_sort] + found_words = [ + word + for word in found_words + if ( + word.startswith(query[0]) + and abs(len(word) - len(query)) < 3 + and self.words_dict[word].intersection(tags) + ) + ] + return found_words diff --git a/deeppavlov/models/entity_extraction/ner_chunker.py b/deeppavlov/models/entity_extraction/ner_chunker.py index e72037311b..5ecbe57f61 100644 --- a/deeppavlov/models/entity_extraction/ner_chunker.py +++ b/deeppavlov/models/entity_extraction/ner_chunker.py @@ -15,7 +15,7 @@ import re from logging import getLogger from string import punctuation -from typing import List, Tuple +from typing import List, Tuple, Union, Any from nltk import sent_tokenize from transformers import AutoTokenizer @@ -31,19 +31,19 @@ @register('ner_chunker') class NerChunker(Component): """ - Class to split documents into chunks of max_chunk_len symbols so that the length will not exceed + Class to split documents into chunks of max_seq_len symbols so that the length will not exceed maximal sequence length to feed into BERT """ - def __init__(self, vocab_file: str, max_seq_len: int = 400, lowercase: bool = False, max_chunk_len: int = 180, - batch_size: int = 2, **kwargs): + def __init__(self, vocab_file: str, max_seq_len: int = 400, lowercase: bool = False, batch_size: int = 2, **kwargs): """ Args: - max_chunk_len: maximal length of chunks into which the document is split + vocab_file: vocab file of pretrained transformer model + max_seq_len: maximal length of chunks into which the document is split + lowercase: whether to lowercase text batch_size: how many chunks are in batch """ self.max_seq_len = max_seq_len - self.max_chunk_len = max_chunk_len self.batch_size = batch_size self.re_tokenizer = re.compile(r"[\w']+|[^\w ]") self.tokenizer = AutoTokenizer.from_pretrained(vocab_file, @@ -52,10 +52,12 @@ def __init__(self, vocab_file: str, max_seq_len: int = 400, lowercase: bool = Fa self.russian_letters = "абвгдеёжзийклмнопрстуфхцчшщъыьэюя" self.lowercase = lowercase - def __call__(self, docs_batch: List[str]) -> Tuple[List[List[str]], List[List[int]], - List[List[List[Tuple[int, int]]]], List[List[List[str]]]]: + def __call__(self, docs_batch: List[str]) -> Tuple[List[List[str]], List[List[int]], List[List[Union[ + List[Union[Tuple[int, int], Tuple[Union[int, Any], Union[int, Any]]]], List[ + Tuple[Union[int, Any], Union[int, Any]]], List[Tuple[int, int]]]]], List[List[Union[List[Any], List[str]]]], + List[List[str]]]: """ - This method splits each document in the batch into chunks wuth the maximal length of max_chunk_len + This method splits each document in the batch into chunks wuth the maximal length of max_seq_len Args: docs_batch: batch of documents @@ -185,15 +187,22 @@ class NerChunkModel(Component): def __init__(self, ner: Chainer, ner_parser: EntityDetectionParser, + ner2: Chainer = None, + 
ner_parser2: EntityDetectionParser = None, **kwargs) -> None: """ Args: ner: config for entity detection ner_parser: component deeppavlov.models.entity_extraction.entity_detection_parser + ner2: config of additional entity detection model (ensemble of ner and ner2 models gives better + entity detection quality than single ner model) + ner_parser2: component deeppavlov.models.entity_extraction.entity_detection_parser **kwargs: """ self.ner = ner self.ner_parser = ner_parser + self.ner2 = ner2 + self.ner_parser2 = ner_parser2 def __call__(self, text_batch_list: List[List[str]], nums_batch_list: List[List[int]], @@ -222,6 +231,13 @@ def __call__(self, text_batch_list: List[List[str]], ner_tokens_batch, ner_tokens_offsets_batch, ner_probas_batch, probas_batch = self.ner(text_batch) entity_substr_batch, entity_positions_batch, entity_probas_batch = \ self.ner_parser(ner_tokens_batch, ner_probas_batch, probas_batch) + if self.ner2: + ner_tokens_batch2, ner_tokens_offsets_batch2, ner_probas_batch2, probas_batch2 = self.ner2(text_batch) + entity_substr_batch2, entity_positions_batch2, entity_probas_batch2 = \ + self.ner_parser2(ner_tokens_batch2, ner_probas_batch2, probas_batch2) + entity_substr_batch, entity_positions_batch, entity_probas_batch = \ + self.merge_annotations(entity_substr_batch, entity_positions_batch, entity_probas_batch, + entity_substr_batch2, entity_positions_batch2, entity_probas_batch2) entity_pos_tags_probas_batch = [[(entity_substr.lower(), entity_substr_positions, tag, entity_proba) for tag, entity_substr_list in entity_substr_dict.items() @@ -316,3 +332,56 @@ def __call__(self, text_batch_list: List[List[str]], return doc_entity_substr_batch, doc_entity_offsets_batch, doc_entity_positions_batch, doc_tags_batch, \ doc_sentences_offsets_batch, doc_sentences_batch, doc_probas_batch + + def merge_annotations(self, substr_batch, pos_batch, probas_batch, substr_batch2, pos_batch2, probas_batch2): + log.debug(f"ner_chunker, substr2: {substr_batch2} --- pos2: {pos_batch2} --- probas2: {probas_batch2} --- " + f"substr: {substr_batch} --- pos: {pos_batch} --- probas: {probas_batch}") + for i in range(len(substr_batch)): + for key2 in substr_batch2[i]: + substr_list2 = substr_batch2[i][key2] + pos_list2 = pos_batch2[i][key2] + probas_list2 = probas_batch2[i][key2] + for substr2, pos2, probas2 in zip(substr_list2, pos_list2, probas_list2): + found = False + for key in substr_batch[i]: + pos_list = pos_batch[i][key] + for pos in pos_list: + if pos[0] <= pos2[0] <= pos[-1] or pos[0] <= pos2[-1] <= pos[-1]: + found = True + if not found: + if key2 not in substr_batch[i]: + substr_batch[i][key2] = [] + pos_batch[i][key2] = [] + probas_batch[i][key2] = [] + substr_batch[i][key2].append(substr2) + pos_batch[i][key2].append(pos2) + probas_batch[i][key2].append(probas2) + for i in range(len(substr_batch)): + for key2 in substr_batch2[i]: + substr_list2 = substr_batch2[i][key2] + pos_list2 = pos_batch2[i][key2] + probas_list2 = probas_batch2[i][key2] + for substr2, pos2, probas2 in zip(substr_list2, pos_list2, probas_list2): + for key in substr_batch[i]: + inds = [] + substr_list = substr_batch[i][key] + pos_list = pos_batch[i][key] + probas_list = probas_batch[i][key] + for n, (substr, pos, probas) in enumerate(zip(substr_list, pos_list, probas_list)): + if (pos[0] == pos2[0] and pos[-1] < pos2[-1]) or (pos[0] > pos2[0] and pos[-1] == pos2[-1]): + inds.append(n) + elif key == "EVENT" and ((pos[0] >= pos2[0] and pos[-1] <= pos2[-1]) + or (len(substr.split()) == 1 and pos2[0] <= pos[0])): + 
inds.append(n) + + if (len(inds) > 1 or (len(inds) == 1 and key in {"WORK_OF_ART", "EVENT"})) \ + and not (key == "PERSON" and " и " in substr2): + inds = sorted(inds, reverse=True) + for ind in inds: + del substr_batch[i][key][ind] + del pos_batch[i][key][ind] + del probas_batch[i][key][ind] + substr_batch[i][key].append(substr2) + pos_batch[i][key].append(pos2) + probas_batch[i][key].append(probas2) + return substr_batch, pos_batch, probas_batch diff --git a/deeppavlov/models/kbqa/query_generator.py b/deeppavlov/models/kbqa/query_generator.py index 002a893326..501963b4eb 100644 --- a/deeppavlov/models/kbqa/query_generator.py +++ b/deeppavlov/models/kbqa/query_generator.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import itertools import re -from collections import namedtuple, OrderedDict +from collections import defaultdict from logging import getLogger from typing import Tuple, List, Optional, Union, Dict, Any, Set @@ -22,10 +23,11 @@ import numpy as np from deeppavlov.core.common.registry import register +from deeppavlov.core.models.component import Component from deeppavlov.models.kbqa.query_generator_base import QueryGeneratorBase from deeppavlov.models.kbqa.rel_ranking_infer import RelRankerInfer -from deeppavlov.models.kbqa.utils import \ - extract_year, extract_number, order_of_answers_sorting, make_combs, fill_query +from deeppavlov.models.kbqa.utils import extract_year, extract_number, make_combs, fill_query, find_query_features, \ + make_sparql_query, merge_sparql_query from deeppavlov.models.kbqa.wiki_parser import WikiParser log = getLogger(__name__) @@ -40,28 +42,36 @@ class QueryGenerator(QueryGeneratorBase): def __init__(self, wiki_parser: WikiParser, rel_ranker: RelRankerInfer, entities_to_leave: int = 5, + types_to_leave: int = 2, rels_to_leave: int = 7, max_comb_num: int = 10000, - return_all_possible_answers: bool = False, *args, **kwargs) -> None: + gold_query_info: Dict[str, str] = None, + map_query_str_to_kb: List[Tuple[str, str]] = None, + return_answers: bool = True, *args, **kwargs) -> None: """ Args: wiki_parser: component deeppavlov.models.kbqa.wiki_parser rel_ranker: component deeppavlov.models.kbqa.rel_ranking_infer entities_to_leave: how many entities to leave after entity linking + types_to_leave: how many types to leave after entity linking rels_to_leave: how many relations to leave after relation ranking max_comb_num: the maximum number of combinations of candidate entities and relations - return_all_possible_answers: whether to return all found answers + gold_query_info: dict of variable names used for formatting output sparql queries + map_query_str_to_kb: mapping of knowledge base prefixes to full https + return_answers: whether to return answers or candidate relations and answers for further ranking **kwargs: """ self.wiki_parser = wiki_parser self.rel_ranker = rel_ranker self.entities_to_leave = entities_to_leave + self.types_to_leave = types_to_leave self.rels_to_leave = rels_to_leave self.max_comb_num = max_comb_num - self.return_all_possible_answers = return_all_possible_answers - self.replace_tokens = [("wdt:p31", "wdt:P31"), ("pq:p580", "pq:P580"), - ("pq:p582", "pq:P582"), ("pq:p585", "pq:P585"), ("pq:p1545", "pq:P1545")] + self.gold_query_info = gold_query_info + self.map_query_str_to_kb = map_query_str_to_kb + self.return_answers = return_answers + self.replace_tokens = [("wdt:p", "wdt:P"), ("pq:p", "pq:P")] 
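# Illustrative sketch of the generalized prefix normalization above, using a made-up
# lowercased query template: the pair ("wdt:p", "wdt:P") now restores the upper-case
# property id for any Wikidata property, not only the previously hard-coded set.
replace_tokens = [("wdt:p", "wdt:P"), ("pq:p", "pq:P")]
query = "select ?x where { wd:q649 wdt:p36 ?x . ?x pq:p580 ?start }"
for old_tok, new_tok in replace_tokens:
    query = query.replace(old_tok, new_tok)
assert query == "select ?x where { wd:q649 wdt:P36 ?x . ?x pq:P580 ?start }"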
super().__init__(wiki_parser=self.wiki_parser, rel_ranker=self.rel_ranker, entities_to_leave=self.entities_to_leave, rels_to_leave=self.rels_to_leave, *args, **kwargs) @@ -70,164 +80,313 @@ def __call__(self, question_batch: List[str], question_san_batch: List[str], template_type_batch: Union[List[List[str]], List[str]], entities_from_ner_batch: List[List[str]], + types_from_ner_batch: List[List[str]], entity_tags_batch: List[List[str]], - answer_types_batch: List[Set[str]]) -> List[str]: - - candidate_outputs_batch = [] - template_answers_batch = [] - templates_nums_batch = [] - log.debug(f"kbqa inputs {question_batch} {entities_from_ner_batch} {template_type_batch} {entity_tags_batch}") - for question, question_sanitized, template_type, entities_from_ner, entity_tags_list, answer_types in \ - zip(question_batch, question_san_batch, template_type_batch, entities_from_ner_batch, - entity_tags_batch, answer_types_batch): + probas_batch: List[List[float]], + answer_types_batch: List[Set[str]] = None, + entities_to_link_batch: List[List[int]] = None) -> Tuple[List[Any], List[Any]]: + + candidate_outputs_batch, template_answers_batch = [], [] + if not answer_types_batch or answer_types_batch[0] is None: + answer_types_batch = [[] for _ in question_batch] + if not entities_to_link_batch or entities_to_link_batch[0] is None: + entities_to_link_batch = [[1 for _ in substr_list] for substr_list in entities_from_ner_batch] + log.debug(f"kbqa inputs {question_batch} {question_san_batch} template_type_batch: {template_type_batch} --- " + f"entities_from_ner: {entities_from_ner_batch} --- types_from_ner: {types_from_ner_batch} --- " + f"entity_tags_batch: {entity_tags_batch} --- answer_types_batch: " + f"{[list(elem)[:3] for elem in answer_types_batch]}") + for question, question_sanitized, template_type, entities_from_ner, types_from_ner, entity_tags_list, \ + probas, entities_to_link, answer_types in zip(question_batch, question_san_batch, template_type_batch, + entities_from_ner_batch, types_from_ner_batch, + entity_tags_batch, probas_batch, entities_to_link_batch, + answer_types_batch): if template_type == "-1": template_type = "7" - candidate_outputs, template_answer, templates_nums = \ + candidate_outputs, template_answer = \ self.find_candidate_answers(question, question_sanitized, template_type, entities_from_ner, - entity_tags_list, answer_types) + types_from_ner, entity_tags_list, probas, entities_to_link, answer_types) candidate_outputs_batch.append(candidate_outputs) template_answers_batch.append(template_answer) - templates_nums_batch.append(templates_nums) - answers = self.rel_ranker(question_batch, candidate_outputs_batch, entities_from_ner_batch, - template_answers_batch) - log.debug(f"(__call__)answers: {answers}") - if not answers: - answers = ["Not Found" for _ in question_batch] - return answers + if self.return_answers: + answers = self.rel_ranker(question_batch, template_type_batch, candidate_outputs_batch, + entities_from_ner_batch, template_answers_batch) + log.debug(f"(__call__)answers: {answers}") + if not answers: + answers = ["Not Found" for _ in question_batch] + return answers + else: + return candidate_outputs_batch, template_answers_batch + + def parse_queries_info(self, question, queries_info, entity_ids, type_ids, rels_from_template): + parsed_queries_info = [] + question_tokens = nltk.word_tokenize(question) + rels_scores_dict = {} + for query_info in queries_info: + query = query_info["query_template"].lower() + for old_tok, new_tok in self.replace_tokens: + query = 
query.replace(old_tok, new_tok) + log.debug(f"\n_______________________________\nquery: {query}\n_______________________________\n") + entities_and_types_select = query_info["entities_and_types_select"] + rels_for_search = query_info["rank_rels"] + rel_types = query_info["rel_types"] + n_hops = query_info["n_hops"] + unk_rels = query_info.get("unk_rels", []) + query_seq_num = query_info["query_sequence"] + return_if_found = query_info["return_if_found"] + log.debug(f"(query_parser)query: {query}, rels_for_search {rels_for_search}, rel_types {rel_types} " + f"n_hops {n_hops}, {query_seq_num}, {return_if_found}") + query_triplets = re.findall("{[ ]?(.*?)[ ]?}", query)[0].split(' . ') + log.debug(f"(query_parser)query_triplets: {query_triplets}") + query_triplets_split = [triplet.split(' ')[:3] for triplet in query_triplets] + property_types = {} + for rel_type, query_triplet in zip(rel_types, query_triplets_split): + if query_triplet[1].startswith("?") and rel_type == "qualifier": + property_types[query_triplet[1]] = rel_type + query_sequence_dict = {num + 1: triplet for num, triplet in enumerate(query_triplets_split)} + query_sequence = [] + for i in query_seq_num: + query_sequence.append(query_sequence_dict[i]) + triplet_info_list = [("forw" if triplet[2].startswith('?') else "backw", search_source, rel_type, n_hop) + for search_source, triplet, rel_type, n_hop in \ + zip(rels_for_search, query_sequence, rel_types, n_hops) + if search_source != "do_not_rank"] + log.debug(f"(query_parser)query_sequence_dict: {query_sequence_dict} --- rel_directions: " + f"{triplet_info_list} --- query_sequence: {query_sequence}") + entity_ids = [entity[:self.entities_to_leave] for entity in entity_ids] + rels, entities_rel_conn = [], set() + if rels_from_template is not None: + rels = [[(rel, 1.0) for rel in rel_list] for rel_list in rels_from_template] + elif not rels: + for triplet_info in triplet_info_list: + ex_rels, cur_rels_scores_dict, entity_rel_conn = self.find_top_rels(question, entity_ids, + triplet_info) + rels.append(ex_rels) + rels_scores_dict = {**rels_scores_dict, **cur_rels_scores_dict} + entities_rel_conn = entities_rel_conn.union(entity_rel_conn) + log.debug(f"(query_parser)rels: {rels}") + rels_from_query = [triplet[1] for triplet in query_triplets_split if triplet[1].startswith('?')] + qualifier_rels = [triplet[1] for triplet in query_triplets_split if triplet[1].startswith("pq:P")] + + answer_ent, order_info, filter_from_query = find_query_features(query, qualifier_rels, question) + log.debug(f"(query_parser) filter_from_query: {filter_from_query} --- order_info: {order_info}") - def query_parser(self, question: str, query_info: Dict[str, str], - entities_and_types_select: List[str], + year = extract_year(question_tokens, question) + number = extract_number(question_tokens, question) + log.debug(f"year {year}, number {number}") + if year: + filter_info = [(elem[0], elem[1].replace("n", year)) for elem in filter_from_query] + elif number: + filter_info = [(elem[0], elem[1].replace("n", number)) for elem in filter_from_query] + else: + filter_info = [elem for elem in filter_from_query if elem[1] != "n"] + for unk_prop, prop_type in property_types.items(): + filter_info.append((unk_prop, prop_type)) + log.debug(f"(query_parser)filter_from_query: {filter_from_query}") + rel_combs = make_combs(rels, permut=False) + + entity_positions, type_positions = [elem.split('_') for elem in entities_and_types_select.split(' ')] + log.debug(f"entity_positions {entity_positions}, type_positions 
{type_positions}") + selected_entity_ids, selected_type_ids = [], [] + if len(entity_ids) > 1 and len(entity_positions) == 1: + selected_entity_ids = [] + for j in range(max([len(elem) for elem in entity_ids])): + for elem in entity_ids: + if j < len(elem): + selected_entity_ids.append(elem[j]) + selected_entity_ids = [selected_entity_ids] + elif entity_ids: + selected_entity_ids = [entity_ids[int(pos) - 1] for pos in entity_positions if int(pos) > 0] + if type_ids: + selected_type_ids = [type_ids[int(pos) - 1][:self.types_to_leave] + for pos in type_positions if int(pos) > 0] + entity_combs = make_combs(selected_entity_ids, permut=True) + type_combs = make_combs(selected_type_ids, permut=False) + log.debug(f"(query_parser)entity_combs: {entity_combs[:3]}, type_combs: {type_combs[:3]}," + f" rel_combs: {rel_combs[:3]}") + + all_combs_list = list(itertools.product(entity_combs, type_combs, rel_combs)) + all_combs_list = sorted(all_combs_list, key=lambda x: (sum([elem[-1] for elem in x]), x[0][-1])) + parsed_queries_info.append({"query_triplets": query_triplets, + "query_sequence": query_sequence, + "rels_from_query": rels_from_query, + "answer_ent": answer_ent, + "filter_info": filter_info, + "order_info": order_info, + "rel_types": rel_types, + "unk_rels": unk_rels, + "return_if_found": return_if_found, + "selected_entity_ids": selected_entity_ids, + "selected_type_ids": selected_type_ids, + "rels": rels, + "entities_rel_conn": entities_rel_conn, + "entity_combs": entity_combs, + "type_combs": type_combs, + "rel_combs": rel_combs, + "all_combs_list": all_combs_list}) + return parsed_queries_info, rels_scores_dict + + def check_valid_query(self, entities_rel_conn, query_hdt_seq): + entity_rel_valid = True + if entities_rel_conn: + for query_hdt_elem in query_hdt_seq: + entity, rel = "", "" + if len(query_hdt_elem) == 3 and any([query_hdt_elem[i].startswith("?") for i in [0, 2]]): + if "statement" in self.kb_prefixes and query_hdt_elem[1].startswith(self.kb_prefixes["statement"]): + continue + else: + if not query_hdt_elem[0].startswith("?"): + entity = query_hdt_elem[0].split("/")[-1] + elif not query_hdt_elem[2].startswith("?"): + entity = query_hdt_elem[2].split("/")[-1] + if not query_hdt_elem[1].startswith("?"): + rel = query_hdt_elem[1].split("/")[-1] + if entity and rel and rel not in self.kb_prefixes["type_rels"] \ + and (entity, rel) not in entities_rel_conn: + entity_rel_valid = False + return entity_rel_valid + + def query_parser(self, question: str, + queries_info: Dict[str, str], entity_ids: List[List[str]], type_ids: List[List[str]], answer_types: Set[str], - rels_from_template: Optional[List[Tuple[str]]] = None) -> Union[ - List[Dict[str, Union[Union[Tuple[Any, ...], List[Any]], Any]]], List[Dict[str, Any]]]: - question_tokens = nltk.word_tokenize(question) - query = query_info["query_template"].lower() - for old_tok, new_tok in self.replace_tokens: - query = query.replace(old_tok, new_tok) - log.debug(f"\n_______________________________\nquery: {query}\n_______________________________\n") - rels_for_search = query_info["rank_rels"] - rel_types = query_info["rel_types"] - query_seq_num = query_info["query_sequence"] - return_if_found = query_info["return_if_found"] - define_sorting_order = query_info["define_sorting_order"] - property_types = query_info["property_types"] - log.debug(f"(query_parser)query: {query}, {rels_for_search}, {query_seq_num}, {return_if_found}") - query_triplets = re.findall("{[ ]?(.*?)[ ]?}", query)[0].split(' . 
') - log.debug(f"(query_parser)query_triplets: {query_triplets}") - query_triplets = [triplet.split(' ')[:3] for triplet in query_triplets] - query_sequence_dict = {num: triplet for num, triplet in zip(query_seq_num, query_triplets)} - query_sequence = [] - for i in range(1, max(query_seq_num) + 1): - query_sequence.append(query_sequence_dict[i]) - triplet_info_list = [("forw" if triplet[2].startswith('?') else "backw", search_source, rel_type) - for search_source, triplet, rel_type in zip(rels_for_search, query_triplets, rel_types) if - search_source != "do_not_rank"] - log.debug(f"(query_parser)rel_directions: {triplet_info_list}") - entity_ids = [entity[:self.entities_to_leave] for entity in entity_ids] - if rels_from_template is not None: - rels = [[(rel, 1.0) for rel in rel_list] for rel_list in rels_from_template] + rels_from_template: Optional[List[Tuple[str]]] = None) -> Union[List[Dict[str, Any]], list]: + parsed_queries_info, rels_scores_dict = self.parse_queries_info(question, queries_info, entity_ids, type_ids, + rels_from_template) + queries_list, parser_info_list, entity_conf_list = [], [], [] + new_combs_list, query_info_list = [], [] + combs_num_list = [len(parsed_query_info["all_combs_list"]) for parsed_query_info in parsed_queries_info] + if combs_num_list: + max_comb_nums = max(combs_num_list) else: - rels = [self.find_top_rels(question, entity_ids, triplet_info) - for triplet_info in triplet_info_list] - rels = [[rel for rel in rel_list] for rel_list in rels] - log.debug(f"(query_parser)rels: {rels}") - rels_from_query = [triplet[1] for triplet in query_triplets if triplet[1].startswith('?')] - answer_ent = re.findall(r"select [\(]?([\S]+) ", query) - order_info_nt = namedtuple("order_info", ["variable", "sorting_order"]) - order_variable = re.findall("order by (asc|desc)\((.*)\)", query) - if order_variable: - if define_sorting_order: - answers_sorting_order = order_of_answers_sorting(question) - else: - answers_sorting_order = order_variable[0][0] - order_info = order_info_nt(order_variable[0][1], answers_sorting_order) - else: - order_info = order_info_nt(None, None) - log.debug(f"question, order_info: {question}, {order_info}") - filter_from_query = re.findall("contains\((\?\w), (.+?)\)", query) - log.debug(f"(query_parser)filter_from_query: {filter_from_query}") - - year = extract_year(question_tokens, question) - number = extract_number(question_tokens, question) - log.debug(f"year {year}, number {number}") - if year: - filter_info = [(elem[0], elem[1].replace("n", year)) for elem in filter_from_query] - elif number: - filter_info = [(elem[0], elem[1].replace("n", number)) for elem in filter_from_query] - else: - filter_info = [elem for elem in filter_from_query if elem[1] != "n"] - for unk_prop, prop_type in property_types.items(): - filter_info.append((unk_prop, prop_type)) - log.debug(f"(query_parser)filter_from_query: {filter_from_query}") - rel_combs = make_combs(rels, permut=False) - entity_positions, type_positions = [elem.split('_') for elem in entities_and_types_select.split(' ')] - log.debug(f"entity_positions {entity_positions}, type_positions {type_positions}") - selected_entity_ids, selected_type_ids = [], [] - if entity_ids: - selected_entity_ids = [entity_ids[int(pos) - 1] for pos in entity_positions if int(pos) > 0] - if type_ids: - selected_type_ids = [type_ids[int(pos) - 1] for pos in type_positions if int(pos) > 0] - entity_combs = make_combs(selected_entity_ids, permut=True) - type_combs = make_combs(selected_type_ids, permut=False) - 
log.debug(f"(query_parser)entity_combs: {entity_combs[:3]}, type_combs: {type_combs[:3]}," - f" rel_combs: {rel_combs[:3]}") - queries_list = [] - parser_info_list = [] - confidences_list = [] - all_combs_list = list(itertools.product(entity_combs, type_combs, rel_combs)) - for comb_num, combs in enumerate(all_combs_list): - confidence = np.prod([score for rel, score in combs[2][:-1]]) - confidences_list.append(confidence) - query_hdt_seq = [ - fill_query(query_hdt_elem, combs[0], combs[1], combs[2]) for query_hdt_elem in query_sequence] - if comb_num == 0: - log.debug(f"\n__________________________\nfilled query: {query_hdt_seq}\n__________________________\n") - if comb_num > 0: - answer_types = [] - queries_list.append( - (rels_from_query + answer_ent, query_hdt_seq, filter_info, order_info, answer_types, rel_types, - return_if_found)) - - parser_info_list.append("query_execute") - if comb_num == self.max_comb_num: - break - - candidate_outputs = [] - candidate_outputs_list = self.wiki_parser(parser_info_list, queries_list) - if self.use_wp_api_requester and isinstance(candidate_outputs_list, list) and candidate_outputs_list: - candidate_outputs_list = candidate_outputs_list[0] - - if isinstance(candidate_outputs_list, list) and candidate_outputs_list: - outputs_len = len(candidate_outputs_list) - all_combs_list = all_combs_list[:outputs_len] - confidences_list = confidences_list[:outputs_len] - for combs, confidence, candidate_output in zip(all_combs_list, confidences_list, candidate_outputs_list): - candidate_outputs += [[combs[0]] + [rel for rel, score in combs[2][:-1]] + output + [confidence] - for output in candidate_output] - - if self.return_all_possible_answers: - candidate_outputs_dict = OrderedDict() - for candidate_output in candidate_outputs: - candidate_output_key = (tuple(candidate_output[0]), tuple(candidate_output[1:-2])) - if candidate_output_key not in candidate_outputs_dict: - candidate_outputs_dict[candidate_output_key] = [] - candidate_outputs_dict[candidate_output_key].append(candidate_output[-2:]) - candidate_outputs = [] - for (candidate_entity_comb, candidate_rel_comb), candidate_output in candidate_outputs_dict.items(): - candidate_outputs.append({"entities": candidate_entity_comb, - "relations": list(candidate_rel_comb), - "answers": tuple([ans for ans, conf in candidate_output]), - "rel_conf": candidate_output[0][1] - }) - else: - candidate_outputs = [{"entities": f_entities, - "relations": f_relations, - "answers": f_answers, - "rel_conf": f_rel_conf - } for f_entities, *f_relations, f_answers, f_rel_conf in candidate_outputs] - log.debug(f"(query_parser)final outputs: {candidate_outputs[:3]}") - - return candidate_outputs + max_comb_nums = 0 + for comb_num in range(max_comb_nums): + for parsed_query_info in parsed_queries_info: + if comb_num < min(len(parsed_query_info["all_combs_list"]), self.max_comb_num): + query_triplets = parsed_query_info["query_triplets"] + query_sequence = parsed_query_info["query_sequence"] + rels_from_query = parsed_query_info["rels_from_query"] + answer_ent = parsed_query_info["answer_ent"] + filter_info = parsed_query_info["filter_info"] + order_info = parsed_query_info["order_info"] + rel_types = parsed_query_info["rel_types"] + unk_rels = parsed_query_info["unk_rels"] + return_if_found = parsed_query_info["return_if_found"] + entities_rel_conn = parsed_query_info["entities_rel_conn"] + combs = parsed_query_info["all_combs_list"][comb_num] + if combs[0][-1] == 0: + entity_conf_list.append(1.0) + else: + 
entity_conf_list.append(0.9) + query_hdt_seq = [fill_query(query_hdt_elem, combs[0], combs[1], combs[2], + self.map_query_str_to_kb) + for query_hdt_elem in query_sequence] + if comb_num == 0: + log.debug(f"\n______________________\nfilled query: {query_hdt_seq}\n______________________\n") + + entity_rel_valid = self.check_valid_query(entities_rel_conn, query_hdt_seq) + if entity_rel_valid: + new_combs_list.append(combs) + queries_list.append((answer_ent, rels_from_query, query_hdt_seq, filter_info, order_info, + answer_types, rel_types, return_if_found)) + query_info_list.append((query_triplets, query_hdt_seq, answer_ent, filter_info, order_info)) + parser_info_list.append("query_execute") + if comb_num < 3 and unk_rels: + unk_query_sequence = copy.deepcopy(query_sequence) + unk_rels_from_query = copy.deepcopy(rels_from_query) + for unk_rel, rel_var in zip(unk_rels, ["?p", "?p2"]): + unk_query_sequence[int(unk_rel) - 1][1] = rel_var + combs[-1][int(unk_rel) - 1] = (rel_var, 1.0) + if rel_var not in rels_from_query: + unk_rels_from_query.append(rel_var) + query_hdt_seq = [ + fill_query(query_hdt_elem, combs[0], combs[1], combs[2], self.map_query_str_to_kb) + for query_hdt_elem in unk_query_sequence] + new_combs_list.append(combs) + queries_list.append((answer_ent, unk_rels_from_query, query_hdt_seq, filter_info, order_info, + answer_types, rel_types, return_if_found)) + query_info_list.append((query_triplets, query_hdt_seq, answer_ent, filter_info, order_info)) + parser_info_list.append("query_execute") + + outputs_list = self.wiki_parser(parser_info_list, queries_list) + outputs = self.parse_outputs(outputs_list, new_combs_list, query_info_list, entity_conf_list, rels_scores_dict) + return outputs + + def parse_outputs(self, outputs_list, combs_list, query_info_list, entity_conf_list, rels_scores_dict): + outputs = [] + if isinstance(outputs_list, list) and outputs_list: + outputs_len = len(outputs_list) + combs_list = combs_list[:outputs_len] + entity_conf_list = entity_conf_list[:outputs_len] + for combs, query_info, entity_conf, (answers_list, found_rels_list, found_combs_list) in \ + zip(combs_list, query_info_list, entity_conf_list, outputs_list): + for answers, found_rels, found_comb in zip(answers_list, found_rels_list, found_combs_list): + found_rels = [found_rel.split("/")[-1] for found_rel in found_rels] + new_combs = list(copy.deepcopy(combs)) + found_unk_rel = False + for j, rel_var in enumerate(["?p", "?p2"]): + if isinstance(new_combs[2][j], tuple) and new_combs[2][j][0] == rel_var: + if found_rels: + new_combs[2][j] = (found_rels[j], rels_scores_dict.get(found_rels[j], 1.0)) + else: + new_combs[2][j] = (new_combs[2][j][0], 0.0) + found_unk_rel = True + if found_rels and not found_unk_rel: + new_combs[2] = new_combs[2][:-1] + [(found_rels[0], 1.0), new_combs[2][-1]] + confidence = np.prod([score for rel, score in new_combs[2][:-1]]) + if answers: + outputs.append([new_combs[0], new_combs[1]] + [rel for rel, score in new_combs[2][:-1]] + + answers + [(confidence, entity_conf), found_comb, query_info, new_combs[2]]) + outputs_dict = defaultdict(list) + types_dict = defaultdict(list) + for output in outputs: + key = (tuple(output[0]), tuple([rel.split("/")[-1] for rel in output[2:-5]])) + if key not in outputs_dict or output[-5:] not in outputs_dict[key]: + outputs_dict[key].append(output[-5:]) + types_dict[key].append(tuple(output[1])) + outputs = [] + for (entity_comb, rel_comb), output in outputs_dict.items(): + type_comb = types_dict[(entity_comb, rel_comb)] + output_conf = 
[elem[1] for elem in output] + output_conf = sorted(output_conf, key=lambda x: x[0] * x[1], reverse=True) + found_combs = [elem[2] for elem in output] + queries = [elem[3] for elem in output] + rel_combs = [elem[4] for elem in output] + cur_rel_comb = rel_combs[0] + cur_rel_comb = [rel for rel, score in cur_rel_comb[:-1]] + sparql_query = make_sparql_query(queries[0], entity_comb, rel_combs[0], type_comb[0], + self.gold_query_info) + parser_info_list = ["fill_triplets"] + parser_query_list = [(queries[0][1], queries[0][2], found_combs[0])] + filled_triplets = self.wiki_parser(parser_info_list, parser_query_list) + outputs.append({"entities": entity_comb, "types": type_comb, "relations": list(cur_rel_comb), + "answers": tuple([ans for ans, *_ in output]), "output_conf": output_conf[0], + "sparql_query": sparql_query, "triplets": filled_triplets[0]}) + return outputs + + +@register('query_formatter') +class QueryFormatter(Component): + def __init__(self, query_info: Dict[str, str], replace_prefixes: Dict[str, str] = None, **kwargs): + self.query_info = query_info + self.replace_prefixes = replace_prefixes + + def __call__(self, queries_batch): + parsed_queries_batch = [] + for query in queries_batch: + query_split = re.findall("{[ ]?(.*?)[ ]?}", query) + init_query_triplets, query_triplets = [], [] + if query_split: + init_query_triplets = query_split[0].split('. ') + for triplet in init_query_triplets: + triplet = " ".join([elem.strip("<>") for elem in triplet.strip().split()]) + if self.replace_prefixes: + for old_prefix, new_prefix in self.replace_prefixes.items(): + triplet = triplet.replace(old_prefix, new_prefix) + query_triplets.append(triplet) + answer_ent, order_info, filter_from_query = find_query_features(query, order_from_query=True) + query_info = (query_triplets, answer_ent, filter_from_query, order_info) + query = merge_sparql_query(query_info, self.query_info) + parsed_queries_batch.append(query) + return parsed_queries_batch diff --git a/deeppavlov/models/kbqa/query_generator_base.py b/deeppavlov/models/kbqa/query_generator_base.py index f0d80c2efe..5f093612c8 100644 --- a/deeppavlov/models/kbqa/query_generator_base.py +++ b/deeppavlov/models/kbqa/query_generator_base.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
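# Illustrative sketch of the triplet extraction done in QueryFormatter.__call__ above,
# with a made-up SPARQL string and hypothetical entity/relation prefixes ("http://we",
# "http://wpd"); only the brace-matching regex and the "<>"-stripping from that class
# are exercised here.
import re

query = "SELECT ?x WHERE { <http://we/Q649> <http://wpd/P36> ?x. ?x <http://wpd/P31> ?type }"
body = re.findall("{[ ]?(.*?)[ ]?}", query)[0]
triplets = [" ".join(tok.strip("<>") for tok in t.strip().split()) for t in body.split(". ")]
assert triplets == ["http://we/Q649 http://wpd/P36 ?x", "?x http://wpd/P31 ?type"]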
+import itertools import json from logging import getLogger -from typing import Tuple, List, Optional, Union, Any, Set +from typing import Tuple, List, Dict, Optional, Union, Any, Set from bs4 import BeautifulSoup from whapi import search, get_html @@ -26,6 +27,7 @@ from deeppavlov.models.entity_extraction.entity_linking import EntityLinker from deeppavlov.models.kbqa.rel_ranking_infer import RelRankerInfer from deeppavlov.models.kbqa.template_matcher import TemplateMatcher +from deeppavlov.models.kbqa.utils import preprocess_template_queries log = getLogger(__name__) @@ -37,12 +39,11 @@ class QueryGeneratorBase(Component, Serializable): """ def __init__(self, template_matcher: TemplateMatcher, - entity_linker: EntityLinker, rel_ranker: RelRankerInfer, load_path: str, - rank_rels_filename_1: str, - rank_rels_filename_2: str, sparql_queries_filename: str, + entity_linker: EntityLinker, + rels_in_ranking_queries_fname: str = None, wiki_parser=None, entities_to_leave: int = 5, rels_to_leave: int = 7, @@ -50,17 +51,17 @@ def __init__(self, template_matcher: TemplateMatcher, use_wp_api_requester: bool = False, use_el_api_requester: bool = False, use_alt_templates: bool = True, - use_add_templates: bool = False, *args, **kwargs) -> None: + delete_rel_prefix: bool = True, + kb_prefixes: Dict[str, str] = None, *args, **kwargs) -> None: """ Args: template_matcher: component deeppavlov.models.kbqa.template_matcher - entity_linker: component deeppavlov.models.entity_extraction.entity_linking for linking of entities rel_ranker: component deeppavlov.models.kbqa.rel_ranking_infer load_path: path to folder with wikidata files - rank_rels_filename_1: file with list of rels for first rels in questions with ranking - rank_rels_filename_2: file with list of rels for second rels in questions with ranking sparql_queries_filename: file with sparql query templates + entity_linker: component deeppavlov.models.entity_extraction.entity_linking for linking of entities + rels_in_ranking_queries_fname: file with list of rels in queries for questions with ranking wiki_parser: component deeppavlov.models.kbqa.wiki_parser entities_to_leave: how many entities to leave after entity linking rels_to_leave: how many relations to leave after relation ranking @@ -70,37 +71,34 @@ def __init__(self, template_matcher: TemplateMatcher, use_el_api_requester: whether deeppavlov.models.api_requester.api_requester component will be used for Entity Linking use_alt_templates: whether to use alternative templates if no answer was found for default query template + delete_rel_prefix: whether to delete prefix in relations + kb_prefixes: prefixes for entities, relations and types in the knowledge base """ super().__init__(save_path=None, load_path=load_path) self.template_matcher = template_matcher self.entity_linker = entity_linker self.wiki_parser = wiki_parser self.rel_ranker = rel_ranker - self.rank_rels_filename_1 = rank_rels_filename_1 - self.rank_rels_filename_2 = rank_rels_filename_2 - self.rank_list_0 = [] - self.rank_list_1 = [] + self.rels_in_ranking_queries_fname = rels_in_ranking_queries_fname + self.rels_in_ranking_queries = {} self.entities_to_leave = entities_to_leave self.rels_to_leave = rels_to_leave self.syntax_structure_known = syntax_structure_known self.use_wp_api_requester = use_wp_api_requester self.use_el_api_requester = use_el_api_requester self.use_alt_templates = use_alt_templates - self.use_add_templates = use_add_templates self.sparql_queries_filename = sparql_queries_filename + self.delete_rel_prefix = 
delete_rel_prefix + self.kb_prefixes = kb_prefixes self.load() def load(self) -> None: - with open(self.load_path / self.rank_rels_filename_1, 'r') as fl1: - lines = fl1.readlines() - self.rank_list_0 = [line.split('\t')[0] for line in lines] + if self.rels_in_ranking_queries_fname is not None: + self.rels_in_ranking_queries = read_json(self.load_path / self.rels_in_ranking_queries_fname) - with open(self.load_path / self.rank_rels_filename_2, 'r') as fl2: - lines = fl2.readlines() - self.rank_list_1 = [line.split('\t')[0] for line in lines] - - self.template_queries = read_json(str(expand_path(self.sparql_queries_filename))) + template_queries = read_json(str(expand_path(self.sparql_queries_filename))) + self.template_queries = preprocess_template_queries(template_queries, self.kb_prefixes) def save(self) -> None: pass @@ -109,12 +107,13 @@ def find_candidate_answers(self, question: str, question_sanitized: str, template_types: Union[List[str], str], entities_from_ner: List[str], + types_from_ner: List[str], entity_tags: List[str], - answer_types: Set[str]) -> Tuple[Union[Union[List[List[Union[str, float]]], - List[Any]], Any], - Union[str, Any], Union[List[Any], Any]]: + probas: List[float], + entities_to_link: List[int], + answer_types: Set[str]) -> Tuple[Union[List[Dict[str, Any]], list], str]: candidate_outputs = [] - self.template_nums = template_types + self.template_nums = [template_types] replace_tokens = [(' - ', '-'), (' .', ''), ('{', ''), ('}', ''), (' ', ' '), ('"', "'"), ('(', ''), (')', ''), ('–', '-')] @@ -122,47 +121,50 @@ def find_candidate_answers(self, question: str, question = question.replace(old, new) entities_from_template, types_from_template, rels_from_template, rel_dirs_from_template, query_type_template, \ - entity_types, template_answer, answer_types, template_found = self.template_matcher(question_sanitized, - entities_from_ner) - self.template_nums = [query_type_template] - templates_nums = [] + entity_types, template_answer, template_answer_types, template_found = self.template_matcher( + question_sanitized, entities_from_ner) + if query_type_template: + self.template_nums = [query_type_template] log.debug( f"question: {question} entities_from_template {entities_from_template} template_type {self.template_nums} " - f"types from template {types_from_template} rels_from_template {rels_from_template}") + f"types from template {types_from_template} rels_from_template {rels_from_template} entities_from_ner " + f"{entities_from_ner} types_from_ner {types_from_ner} answer_types {list(answer_types)[:3]}") if entities_from_template or types_from_template: if rels_from_template[0][0] == "PHOW": how_to_content = self.find_answer_wikihow(entities_from_template[0]) candidate_outputs = [["PHOW", how_to_content, 1.0]] else: - entity_ids = self.get_entity_ids(entities_from_template, entity_tags, question) - log.debug(f"entities_from_template {entities_from_template}") - log.debug(f"entity_types {entity_types}") - log.debug(f"types_from_template {types_from_template}") - log.debug(f"rels_from_template {rels_from_template}") - log.debug(f"entity_ids {entity_ids}") - candidate_outputs, templates_nums = \ - self.sparql_template_parser(question_sanitized, entity_ids, [], answer_types, - rels_from_template, rel_dirs_from_template) - - if not candidate_outputs and entities_from_ner: + entity_ids = self.get_entity_ids(entities_from_template, entity_tags, probas, question, + entities_to_link) + type_ids = self.get_entity_ids(types_from_template, ["t" for _ in types_from_template], 
+ [1.0 for _ in types_from_template], question) + log.debug(f"entities_from_template: {entities_from_template} --- entity_types: {entity_types} --- " + f"types_from_template: {types_from_template} --- rels_from_template: {rels_from_template} " + f"--- answer_types: {template_answer_types} --- entity_ids: {entity_ids}") + candidate_outputs = self.sparql_template_parser(question_sanitized, entity_ids, type_ids, + template_answer_types, rels_from_template, + rel_dirs_from_template) + if not candidate_outputs and (entities_from_ner or types_from_ner): log.debug(f"(__call__)entities_from_ner: {entities_from_ner}") - entity_ids = self.get_entity_ids(entities_from_ner, entity_tags, question) - log.debug(f"(__call__)entity_ids: {entity_ids}") + entity_ids = self.get_entity_ids(entities_from_ner, entity_tags, probas, question) + type_ids = self.get_entity_ids(types_from_ner, ["t" for _ in types_from_ner], + [1.0 for _ in types_from_ner], question) + log.debug(f"(__call__)entity_ids: {entity_ids} type_ids {type_ids}") self.template_nums = template_types log.debug(f"(__call__)self.template_nums: {self.template_nums}") if not self.syntax_structure_known: entity_ids = entity_ids[:3] - candidate_outputs, templates_nums = self.sparql_template_parser(question_sanitized, entity_ids, [], - answer_types) - return candidate_outputs, template_answer, templates_nums + candidate_outputs = self.sparql_template_parser(question_sanitized, entity_ids, type_ids, answer_types) + return candidate_outputs, template_answer - def get_entity_ids(self, entities: List[str], tags: List[str], question: str) -> List[List[str]]: - entity_ids = [] - el_output = [] + def get_entity_ids(self, entities: List[str], tags: List[str], probas: List[float], question: str, + entities_to_link: List[int] = None) -> List[List[str]]: + entity_ids, el_output = [], [] try: - el_output = self.entity_linker([entities], [tags], [[question]], [None], [None]) + el_output = self.entity_linker([entities], [tags], [probas], [[question]], [None], [None], + [entities_to_link]) except json.decoder.JSONDecodeError: log.warning("not received output from entity linking") if el_output: @@ -181,101 +183,111 @@ def get_entity_ids(self, entities: List[str], tags: List[str], question: str) -> def sparql_template_parser(self, question: str, entity_ids: List[List[str]], type_ids: List[List[str]], - answer_types: List[str], + answer_types: Set[str], rels_from_template: Optional[List[Tuple[str]]] = None, - rel_dirs_from_template: Optional[List[str]] = None) -> Tuple[Union[None, List[Any]], - List[Any]]: + rel_dirs_from_template: Optional[List[str]] = None) -> Union[List[Dict[str, Any]], list]: candidate_outputs = [] - log.debug(f"use alternative templates {self.use_alt_templates}") - log.debug(f"(find_candidate_answers)self.template_nums: {self.template_nums}") - templates = [] - templates_nums = [] + if isinstance(self.template_nums, str): + self.template_nums = [self.template_nums] + template_log_list = [str([elem["query_template"], elem["template_num"]]) + for elem in self.template_queries.values() if elem["template_num"] in self.template_nums] + log.debug(f"(find_candidate_answers)self.template_nums: {' --- '.join(template_log_list)}") + init_templates = [] for template_num in self.template_nums: for num, template in self.template_queries.items(): if (num == template_num and self.syntax_structure_known) or \ (template["template_num"] == template_num and not self.syntax_structure_known): - templates.append(template) - templates_nums.append(num) - new_templates = [] 
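A minimal stand-alone sketch of the template selection logic being rewritten here, assuming toy template dicts and a toy entity id; only the "entities_and_types_num" matching and the entity-count fallback mirror the patched sparql_template_parser, everything else is illustrative:

from typing import Any, Dict, List

def filter_templates(templates: List[Dict[str, Any]],
                     entity_ids: List[List[str]],
                     type_ids: List[List[str]],
                     syntax_structure_known: bool) -> List[Dict[str, Any]]:
    # with a known syntax structure the templates are kept as-is
    if syntax_structure_known:
        return templates
    # otherwise keep templates whose expected (entities, types) counts match the question
    counts = [len(entity_ids), len(type_ids)]
    filtered = [t for t in templates if t["entities_and_types_num"] == counts]
    if not filtered:  # fall back to matching on the entity count only
        filtered = [t for t in templates if t["entities_and_types_num"] == [len(entity_ids), 0]]
    return filtered

# toy templates and a toy entity id, for illustration only
toy_templates = [{"query_template": "SELECT ?x WHERE { ?e ?rel ?x }", "entities_and_types_num": [1, 0]},
                 {"query_template": "SELECT ?x WHERE { ?x ?t ?type . ?x ?rel ?e }", "entities_and_types_num": [1, 1]}]
print(filter_templates(toy_templates, entity_ids=[["Q649"]], type_ids=[], syntax_structure_known=False))

With one entity, no types, and an unknown syntax structure, only the first toy template passes the filter.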
- new_templates_nums = [] - for template, template_num in zip(templates, templates_nums): - if (not self.syntax_structure_known and [len(entity_ids), len(type_ids)] == template[ - "entities_and_types_num"]) or self.syntax_structure_known: - new_templates.append(template) - new_templates_nums.append(template_num) - - templates = new_templates - templates_nums = new_templates_nums - - templates_string = '\n'.join([template["query_template"] for template in templates]) - log.debug(f"{templates_string}") + init_templates.append(template) + templates = [template for template in init_templates if + (not self.syntax_structure_known and [len(entity_ids), len(type_ids)] == template[ + "entities_and_types_num"]) + or self.syntax_structure_known] if not templates: - return candidate_outputs, [] + templates = [template for template in init_templates if + (not self.syntax_structure_known and [len(entity_ids), 0] == template[ + "entities_and_types_num"]) + or self.syntax_structure_known] + if not templates: + return candidate_outputs if rels_from_template is not None: query_template = {} for template in templates: if template["rel_dirs"] == rel_dirs_from_template: query_template = template if query_template: - entities_and_types_select = query_template["entities_and_types_select"] - candidate_outputs = self.query_parser(question, query_template, entities_and_types_select, - entity_ids, type_ids, answer_types, rels_from_template) + candidate_outputs = self.query_parser(question, [query_template], entity_ids, type_ids, answer_types, + rels_from_template) else: - for template in templates: - entities_and_types_select = template["entities_and_types_select"] - candidate_outputs = self.query_parser(question, template, entities_and_types_select, - entity_ids, type_ids, answer_types, rels_from_template) - if self.use_add_templates: - additional_templates = template.get("additional_templates", []) - templates_nums += additional_templates - for add_template_num in additional_templates: - candidate_outputs += self.query_parser(question, self.template_queries[add_template_num], - entities_and_types_select, entity_ids, type_ids, - answer_types, rels_from_template) + candidate_outputs = [] + for priority in range(1, 3): + pr_templates = [template for template in templates if template["priority"] == priority] + candidate_outputs = self.query_parser(question, pr_templates, entity_ids, type_ids, answer_types, + rels_from_template) if candidate_outputs: - templates_nums = list(set(templates_nums)) - return candidate_outputs, templates_nums + return candidate_outputs - if not candidate_outputs and self.use_alt_templates: - alternative_templates = templates[0]["alternative_templates"] - for template_num, entities_and_types_select in alternative_templates: - candidate_outputs = self.query_parser(question, self.template_queries[template_num], - entities_and_types_select, entity_ids, type_ids, - answer_types, rels_from_template) - templates_nums.append(template_num) - if candidate_outputs: - templates_nums = list(set(templates_nums)) - return candidate_outputs, templates_nums + if not candidate_outputs: + alt_template_nums = templates[0].get("alternative_templates", []) + log.debug(f"Using alternative templates {alt_template_nums}") + alt_templates = [self.template_queries[num] for num in alt_template_nums] + candidate_outputs = self.query_parser(question, alt_templates, entity_ids, type_ids, answer_types, + rels_from_template) + if candidate_outputs: + return candidate_outputs log.debug("candidate_rels_and_answers:\n" + 
'\n'.join([str(output) for output in candidate_outputs[:5]])) + return candidate_outputs - templates_nums = list(set(templates_nums)) - return candidate_outputs, templates_nums - - def find_top_rels(self, question: str, entity_ids: List[List[str]], triplet_info: Tuple) -> List[Tuple[str, Any]]: - ex_rels = [] - direction, source, rel_type = triplet_info + def find_top_rels(self, question: str, entity_ids: List[List[str]], triplet_info: Tuple) -> \ + Tuple[List[Tuple[str, float]], Dict[str, float], Set[Tuple[str, str]]]: + ex_rels, entity_rel_conn = [], set() + direction, source, rel_type, n_hop = triplet_info if source == "wiki": queries_list = list({(entity, direction, rel_type) for entity_id in entity_ids for entity in entity_id[:self.entities_to_leave]}) + entity_ids_list = [elem[0] for elem in queries_list] parser_info_list = ["find_rels" for i in range(len(queries_list))] - try: - ex_rels = self.wiki_parser(parser_info_list, queries_list) - except json.decoder.JSONDecodeError: - log.warning("find_top_rels, not received output from wiki parser") + ex_rels = self.wiki_parser(parser_info_list, queries_list) + for ex_rels_elem, entity_id in zip(ex_rels, entity_ids_list): + for rel in ex_rels_elem: + entity_rel_conn.add((entity_id, rel.split("/")[-1])) if self.use_wp_api_requester and ex_rels: ex_rels = [rel[0] for rel in ex_rels] - ex_rels = list(set(ex_rels)) - ex_rels = [rel.split('/')[-1] for rel in ex_rels] - elif source == "rank_list_1": - ex_rels = self.rank_list_0 - elif source == "rank_list_2": - ex_rels = self.rank_list_1 - rels_with_scores = [] - ex_rels = [rel for rel in ex_rels if rel.startswith("P")] - if ex_rels: - rels_with_scores = self.rel_ranker.rank_rels(question, ex_rels) - return rels_with_scores[:self.rels_to_leave] + ex_rels = list(set(itertools.chain.from_iterable(ex_rels))) + if n_hop in {"1-of-2-hop", "2-hop"}: + queries_list = list({(entity, "backw", rel_type) for entity_id in entity_ids + for entity in entity_id[:self.entities_to_leave]}) + entity_ids_list = [elem[0] for elem in queries_list] + parser_info_list = ["find_rels" for i in range(len(queries_list))] + ex_rels_backw = self.wiki_parser(parser_info_list, queries_list) + for ex_rels_elem, entity_id in zip(ex_rels_backw, entity_ids_list): + for rel in ex_rels_elem: + entity_rel_conn.add((entity_id, rel.split("/")[-1])) + ex_rels_backw = list(set(itertools.chain.from_iterable(ex_rels_backw))) + ex_rels += ex_rels_backw + if self.delete_rel_prefix: + ex_rels = [rel.split('/')[-1] for rel in ex_rels] + elif source in {"rank_list_1", "rel_list_1"}: + ex_rels = self.rels_in_ranking_queries.get("one_rel_in_query", []) + elif source in {"rank_list_2", "rel_list_2"}: + ex_rels = self.rels_in_ranking_queries.get("two_rels_in_query", []) + + ex_rels = [rel for rel in ex_rels if not any([rel.endswith(t_rel) for t_rel in self.kb_prefixes["type_rels"]])] + rels_with_scores = self.rel_ranker.rank_rels(question, ex_rels) + if n_hop == "2-hop" and rels_with_scores and entity_ids and entity_ids[0]: + rels_1hop = [rel for rel, score in rels_with_scores] + queries_list = [(entity_ids[0], rels_1hop[:5])] + parser_info_list = ["find_rels_2hop"] + ex_rels_2hop = self.wiki_parser(parser_info_list, queries_list) + if self.delete_rel_prefix: + ex_rels_2hop = [rel.split('/')[-1] for rel in ex_rels_2hop] + rels_with_scores = self.rel_ranker.rank_rels(question, ex_rels_2hop) + + rels_with_scores = list(set(rels_with_scores)) + rels_with_scores = sorted(rels_with_scores, key=lambda x: x[1], reverse=True) + rels_scores_dict = {rel: 
score for rel, score in rels_with_scores} + + return rels_with_scores[:self.rels_to_leave], rels_scores_dict, entity_rel_conn def find_answer_wikihow(self, howto_sentence: str) -> str: tags = [] @@ -291,6 +303,5 @@ def find_answer_wikihow(self, howto_sentence: str) -> str: howto_content = "Not Found" return howto_content - def query_parser(self, question, query_template, entities_and_types_select, entity_ids, type_ids, answer_types, - rels_from_template): + def query_parser(self, question, query_templates, entity_ids, type_ids, answer_types, rels_from_template): raise NotImplementedError diff --git a/deeppavlov/models/kbqa/rel_ranking_infer.py b/deeppavlov/models/kbqa/rel_ranking_infer.py index 669d8280a1..57437c7b49 100644 --- a/deeppavlov/models/kbqa/rel_ranking_infer.py +++ b/deeppavlov/models/kbqa/rel_ranking_infer.py @@ -18,7 +18,7 @@ from scipy.special import softmax from deeppavlov.core.common.chainer import Chainer -from deeppavlov.core.common.file import load_pickle +from deeppavlov.core.common.file import load_pickle, read_json from deeppavlov.core.common.registry import register from deeppavlov.core.models.component import Component from deeppavlov.core.models.serializable import Serializable @@ -34,33 +34,39 @@ class RelRankerInfer(Component, Serializable): def __init__(self, load_path: str, rel_q2name_filename: str, + return_elements: List[str] = None, ranker: Chainer = None, wiki_parser: Optional[WikiParser] = None, batch_size: int = 32, - rels_to_leave: int = 40, softmax: bool = False, - return_all_possible_answers: bool = False, - return_answer_ids: bool = False, use_api_requester: bool = False, - return_sentence_answer: bool = False, rank: bool = True, - return_confidences: bool = False, **kwargs): + nll_rel_ranking: bool = False, + nll_path_ranking: bool = False, + top_possible_answers: int = -1, + top_n: int = 1, + pos_class_num: int = 1, + rel_thres: float = 0.0, + type_rels: List[str] = None, **kwargs): """ Args: load_path: path to folder with wikidata files rel_q2name_filename: name of file which maps relation id to name + return_elements: what elements return in output ranker: component deeppavlov.models.ranking.rel_ranker wiki_parser: component deeppavlov.models.wiki_parser batch_size: infering batch size - rels_to_leave: how many relations to leave after relation ranking softmax: whether to process relation scores with softmax function - return_all_possible_answers: whether to return all found answers - return_answer_ids: whether to return answer ids from Wikidata use_api_requester: whether wiki parser will be used as external api - return_sentence_answer: whether to return answer as a sentence rank: whether to rank relations or simple copy input - return_confidences: whether to return confidences of candidate answers + nll_rel_ranking: whether use components trained with nll loss for relation ranking + nll_path_ranking: whether use components trained with nll loss for relation path ranking + top_possible_answers: number of answers returned for a question in each list of candidate answers + top_n: number of lists of candidate answers returned for a question + pos_class_num: index of positive class in the output of relation ranking model + rel_thres: threshold of relation confidence + type_rels: list of relations in the knowledge base which connect an entity and its type **kwargs: """ super().__init__(save_path=None, load_path=load_path) @@ -68,150 +74,221 @@ def __init__(self, load_path: str, self.ranker = ranker self.wiki_parser = wiki_parser self.batch_size = 
batch_size - self.rels_to_leave = rels_to_leave self.softmax = softmax - self.return_all_possible_answers = return_all_possible_answers - self.return_answer_ids = return_answer_ids + self.return_elements = return_elements or list() self.use_api_requester = use_api_requester - self.return_sentence_answer = return_sentence_answer self.rank = rank - self.return_confidences = return_confidences + self.nll_rel_ranking = nll_rel_ranking + self.nll_path_ranking = nll_path_ranking + self.top_possible_answers = top_possible_answers + self.top_n = top_n + self.pos_class_num = pos_class_num + self.rel_thres = rel_thres + self.type_rels = type_rels or set() self.load() def load(self) -> None: - self.rel_q2name = load_pickle(self.load_path / self.rel_q2name_filename) + if self.rel_q2name_filename.endswith("pickle"): + self.rel_q2name = load_pickle(self.load_path / self.rel_q2name_filename) + elif self.rel_q2name_filename.endswith("json"): + self.rel_q2name = read_json(self.load_path / self.rel_q2name_filename) def save(self) -> None: pass - def __call__(self, questions_list: List[str], - candidate_answers_list: List[List[Tuple[str]]], - entities_list: List[List[str]] = None, - template_answers_list: List[str] = None) -> List[str]: - answers = [] - confidence = 0.0 - if entities_list is None: - entities_list = [[] for _ in questions_list] - if template_answers_list is None: - template_answers_list = ["" for _ in questions_list] - for question, candidate_answers, entities, template_answer in \ - zip(questions_list, candidate_answers_list, entities_list, template_answers_list): + def __call__(self, questions_batch: List[str], + template_type_batch: List[str], + raw_answers_batch: List[List[Tuple[str]]], + entity_substr_batch: List[List[str]], + template_answers_batch: List[str]) -> List[str]: + answers_batch, outp_confidences_batch, answer_ids_batch = [], [], [] + entities_and_rels_batch, queries_batch, triplets_batch = [], [], [] + for question, template_type, raw_answers, entities, template_answer in \ + zip(questions_batch, template_type_batch, raw_answers_batch, entity_substr_batch, + template_answers_batch): answers_with_scores = [] - answer = "Not Found" - if self.rank: - n_batches = len(candidate_answers) // self.batch_size + int( - len(candidate_answers) % self.batch_size > 0) - for i in range(n_batches): - questions_batch = [] - rels_batch = [] - rels_labels_batch = [] - answers_batch = [] - entities_batch = [] - confidences_batch = [] - for candidate_ans_and_rels in candidate_answers[i * self.batch_size: (i + 1) * self.batch_size]: - candidate_rels = [] - candidate_rels_str, candidate_answer = "", "" - candidate_entities, candidate_confidence = [], [] - if candidate_ans_and_rels: - candidate_rels = candidate_ans_and_rels["relations"] - candidate_rels = [candidate_rel.split('/')[-1] for candidate_rel in candidate_rels] - candidate_answer = candidate_ans_and_rels["answers"] - candidate_entities = candidate_ans_and_rels["entities"] - candidate_confidence = candidate_ans_and_rels["rel_conf"] - candidate_rels_str = " # ".join([self.rel_q2name[candidate_rel] \ - for candidate_rel in candidate_rels if - candidate_rel in self.rel_q2name]) - if candidate_rels_str: - questions_batch.append(question) - rels_batch.append(candidate_rels) - rels_labels_batch.append(candidate_rels_str) - answers_batch.append(candidate_answer) - entities_batch.append(candidate_entities) - confidences_batch.append(candidate_confidence) - - if questions_batch: - probas = self.ranker(questions_batch, rels_labels_batch) - probas = 
[proba[1] for proba in probas] - for j, (answer, entities, confidence, rels_ids, rels_labels) in \ - enumerate(zip(answers_batch, entities_batch, confidences_batch, rels_batch, - rels_labels_batch)): - answers_with_scores.append( - (answer, entities, rels_labels, rels_ids, max(probas[j], confidence))) - - answers_with_scores = sorted(answers_with_scores, key=lambda x: x[-1], reverse=True) - else: - answers_with_scores = [(answer, rels, conf) for *rels, answer, conf in candidate_answers] - - answer_ids = tuple() - if answers_with_scores: - log.debug(f"answers: {answers_with_scores[0]}") - answer_ids = answers_with_scores[0][0] - if self.return_all_possible_answers and isinstance(answer_ids, tuple): - answer_ids_input = [(answer_id, question) for answer_id in answer_ids] - answer_ids = list(map(lambda x: x.split("/")[-1] if str(x).startswith("http") else x, answer_ids)) + l_questions, l_rels, l_rels_labels, l_cur_answers, l_entities, l_types, l_sparql_queries, l_triplets, \ + l_confs = self.preprocess_ranking_input(question, raw_answers) + + n_batches = len(l_questions) // self.batch_size + int(len(l_questions) % self.batch_size > 0) + for i in range(n_batches): + if self.rank: + if self.nll_path_ranking: + probas = self.ranker([l_questions[0]], + [l_rels_labels[self.batch_size * i:self.batch_size * (i + 1)]]) + probas = probas[0] + else: + probas = self.ranker(l_questions[self.batch_size * i:self.batch_size * (i + 1)], + l_rels_labels[self.batch_size * i:self.batch_size * (i + 1)]) + probas = [proba[0] for proba in probas] else: - answer_ids_input = [(answer_ids, question)] - if str(answer_ids).startswith("http:"): - answer_ids = answer_ids.split("/")[-1] + probas = [rel_conf for rel_conf, entity_conf in + l_confs[self.batch_size * i:self.batch_size * (i + 1)]] + for j in range(self.batch_size * i, self.batch_size * (i + 1)): + if j < len(l_cur_answers) and (probas[j - self.batch_size * i] > self.rel_thres or + (len(l_rels[j]) > 1 and not set(l_rels[j]).intersection( + self.type_rels))): + answers_with_scores.append((l_cur_answers[j], l_sparql_queries[j], l_triplets[j], + l_entities[j], l_types[j], l_rels_labels[j], l_rels[j], + round(probas[j - self.batch_size * i], 3), + round(l_confs[j][0], 3), l_confs[j][1])) + answers_with_scores = sorted(answers_with_scores, key=lambda x: x[-1] * x[-3], reverse=True) + if template_type == "simple_boolean" and not answers_with_scores: + answers_with_scores = [(["No"], "", [], [], [], [], [], 1.0, 1.0, 1.0)] + res_answers_list, res_answer_ids_list, res_confidences_list, res_entities_and_rels_list = [], [], [], [] + res_queries_list, res_triplets_list = [], [] + for n, ans_sc_elem in enumerate(answers_with_scores): + init_answer_ids, query, triplets, q_entities, q_types, _, q_rels, p_conf, r_conf, e_conf = ans_sc_elem + answer_ids = [] + for answer_id in init_answer_ids: + answer_id = str(answer_id).replace("@en", "").strip('"') + if answer_id not in answer_ids: + answer_ids.append(answer_id) + + if self.top_possible_answers > 0: + answer_ids = answer_ids[:self.top_possible_answers] + answer_ids_input = [(answer_id, question) for answer_id in answer_ids] + answer_ids = [str(answer_id).split("/")[-1] for answer_id in answer_ids] parser_info_list = ["find_label" for _ in answer_ids_input] - answer_labels = self.wiki_parser(parser_info_list, answer_ids_input) - log.debug(f"answer_labels {answer_labels}") - if self.return_all_possible_answers: - answer_labels = list(set(answer_labels)) - answer_labels = [label for label in answer_labels if (label and label 
!= "Not Found")][:5] - answer_labels = [str(label) for label in answer_labels] - if len(answer_labels) > 2: - answer = f"{', '.join(answer_labels[:-1])} and {answer_labels[-1]}" - else: - answer = ', '.join(answer_labels) + init_answer_labels = self.wiki_parser(parser_info_list, answer_ids_input) + if n < 7: + log.debug(f"answers: {init_answer_ids[:3]} --- query {query} --- entities {q_entities} --- " + f"types {q_types[:3]} --- q_rels {q_rels} --- {ans_sc_elem[5:]} --- " + f"answer_labels {init_answer_labels[:3]}") + answer_labels = [] + for label in init_answer_labels: + if label not in answer_labels: + answer_labels.append(label) + answer_labels = [label for label in answer_labels if (label and label != "Not Found")][:5] + answer_labels = [str(label) for label in answer_labels] + if len(answer_labels) > 2: + answer = f"{', '.join(answer_labels[:-1])} and {answer_labels[-1]}" else: - answer = answer_labels[0] - if self.return_sentence_answer: + answer = ', '.join(answer_labels) + + if "sentence_answer" in self.return_elements: try: answer = sentence_answer(question, answer, entities, template_answer) - except: - log.warning("Error in sentence answer") - confidence = answers_with_scores[0][2] - if self.return_confidences: - answers.append((answer, confidence)) - else: - if self.return_answer_ids: - if not answer_ids: - answer_ids = "Not found" - answers.append((answer, answer_ids)) + except ValueError as e: + log.warning(f"Error in sentence answer, {e}") + + res_answers_list.append(answer) + res_answer_ids_list.append(answer_ids) + if "several_confidences" in self.return_elements: + res_confidences_list.append((p_conf, r_conf, e_conf)) + else: + res_confidences_list.append(p_conf) + res_entities_and_rels_list.append([q_entities[:-1], q_rels]) + res_queries_list.append(query) + res_triplets_list.append(triplets) + + if self.top_n == 1: + if answers_with_scores: + answers_batch.append(res_answers_list[0]) + outp_confidences_batch.append(res_confidences_list[0]) + answer_ids_batch.append(res_answer_ids_list[0]) + entities_and_rels_batch.append(res_entities_and_rels_list[0]) + queries_batch.append(res_queries_list[0]) + triplets_batch.append(res_triplets_list[0]) else: - answers.append(answer) - if not answers: - if self.return_confidences: - answers.append(("Not found", 0.0)) + answers_batch.append("Not Found") + outp_confidences_batch.append(0.0) + answer_ids_batch.append([]) + entities_and_rels_batch.append([]) + queries_batch.append([]) + triplets_batch.append([]) else: - answers.append("Not found") + answers_batch.append(res_answers_list[:self.top_n]) + outp_confidences_batch.append(res_confidences_list[:self.top_n]) + answer_ids_batch.append(res_answer_ids_list[:self.top_n]) + entities_and_rels_batch.append(res_entities_and_rels_list[:self.top_n]) + queries_batch.append(res_queries_list[:self.top_n]) + triplets_batch.append(res_triplets_list[:self.top_n]) + + answer_tuple = (answers_batch,) + if "confidences" in self.return_elements: + answer_tuple += (outp_confidences_batch,) + if "answer_ids" in self.return_elements: + answer_tuple += (answer_ids_batch,) + if "entities_and_rels" in self.return_elements: + answer_tuple += (entities_and_rels_batch,) + if "queries" in self.return_elements: + answer_tuple += (queries_batch,) + if "triplets" in self.return_elements: + answer_tuple += (triplets_batch,) + + return answer_tuple - return answers + def preprocess_ranking_input(self, question, answers): + l_questions, l_rels, l_rels_labels, l_cur_answers = [], [], [], [] + l_entities, l_types, 
l_sparql_queries, l_triplets, l_confs = [], [], [], [], [] + for ans_and_rels in answers: + answer, sparql_query, confidence = "", "", [] + entities, types, rels, rels_labels, triplets = [], [], [], [], [] + if ans_and_rels: + rels = [rel.split('/')[-1] for rel in ans_and_rels["relations"]] + answer = ans_and_rels["answers"] + entities = ans_and_rels["entities"] + types = ans_and_rels["types"] + sparql_query = ans_and_rels["sparql_query"] + triplets = ans_and_rels["triplets"] + confidence = ans_and_rels["output_conf"] + rels_labels = [] + for rel in rels: + if rel in self.rel_q2name: + label = self.rel_q2name[rel] + if isinstance(label, list): + label = label[0] + rels_labels.append(label.lower()) + if rels_labels: + l_questions.append(question) + l_rels.append(rels) + l_rels_labels.append(rels_labels) + l_cur_answers.append(answer) + l_entities.append(entities) + l_types.append(types) + l_sparql_queries.append(sparql_query) + l_triplets.append(triplets) + l_confs.append(confidence) + return l_questions, l_rels, l_rels_labels, l_cur_answers, l_entities, l_types, l_sparql_queries, l_triplets, \ + l_confs def rank_rels(self, question: str, candidate_rels: List[str]) -> List[Tuple[str, Any]]: rels_with_scores = [] if question is not None: - n_batches = len(candidate_rels) // self.batch_size + int(len(candidate_rels) % self.batch_size > 0) - for i in range(n_batches): - questions_batch = [] - rels_labels_batch = [] - rels_batch = [] - for candidate_rel in candidate_rels[i * self.batch_size: (i + 1) * self.batch_size]: - if candidate_rel in self.rel_q2name: - questions_batch.append(question) - rels_batch.append(candidate_rel) - rels_labels_batch.append(self.rel_q2name[candidate_rel]) - if questions_batch: - probas = self.ranker(questions_batch, rels_labels_batch) - probas = [proba[1] for proba in probas] - for j, rel in enumerate(rels_batch): + questions, rels_labels, rels = [], [], [] + for candidate_rel in candidate_rels: + if candidate_rel in self.rel_q2name: + cur_rels_labels = self.rel_q2name[candidate_rel] + if isinstance(cur_rels_labels, str): + cur_rels_labels = [cur_rels_labels] + for cur_rel in cur_rels_labels: + questions.append(question) + rels.append(candidate_rel) + rels_labels.append(cur_rel) + if questions: + n_batches = len(rels) // self.batch_size + int(len(rels) % self.batch_size > 0) + for i in range(n_batches): + if self.nll_rel_ranking: + probas = self.ranker([questions[0]], + [rels_labels[i * self.batch_size:(i + 1) * self.batch_size]]) + probas = probas[0] + else: + probas = self.ranker(questions[i * self.batch_size:(i + 1) * self.batch_size], + rels_labels[i * self.batch_size:(i + 1) * self.batch_size]) + probas = [proba[self.pos_class_num] for proba in probas] + for j, rel in enumerate(rels[i * self.batch_size:(i + 1) * self.batch_size]): rels_with_scores.append((rel, probas[j])) if self.softmax: scores = [score for rel, score in rels_with_scores] softmax_scores = softmax(scores) rels_with_scores = [(rel, softmax_score) for (rel, score), softmax_score in zip(rels_with_scores, softmax_scores)] + rels_with_scores_dict = {} + for rel, score in rels_with_scores: + if rel not in rels_with_scores_dict: + rels_with_scores_dict[rel] = [] + rels_with_scores_dict[rel].append(score) + rels_with_scores = [(rel, max(scores)) for rel, scores in rels_with_scores_dict.items()] rels_with_scores = sorted(rels_with_scores, key=lambda x: x[1], reverse=True) - - return rels_with_scores[:self.rels_to_leave] + return rels_with_scores diff --git a/deeppavlov/models/kbqa/ru_adj_to_noun.py 
b/deeppavlov/models/kbqa/ru_adj_to_noun.py new file mode 100644 index 0000000000..9db8c5d3e0 --- /dev/null +++ b/deeppavlov/models/kbqa/ru_adj_to_noun.py @@ -0,0 +1,108 @@ +# Copyright 2017 Neural Networks and Deep Learning lab, MIPT +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from collections import defaultdict +from logging import getLogger +from typing import List + +import numpy as np +import spacy +from scipy.sparse import csr_matrix + +from deeppavlov.core.commands.utils import expand_path +from deeppavlov.core.common.registry import register + +log = getLogger(__name__) + + +@register('ru_adj_to_noun') +class RuAdjToNoun: + """ + Class for converting an adjective in Russian to the corresponding noun, for example: + "московский" -> "Москва", "африканский" -> "Африка" + """ + + def __init__(self, freq_dict_filename: str, candidate_nouns: int = 10, freq_thres: float = 4.5, + score_thres: float = 2.8, **kwargs): + """ + + Args: + freq_dict_filename: file with the dictionary of Russian words with the corresponding frequencies + candidate_nouns: how many candidate nouns to leave after search + **kwargs: + """ + self.candidate_nouns = candidate_nouns + self.freq_thres = freq_thres + self.score_thres = score_thres + alphabet = "абвгдеёжзийклмнопрстуфхцчшщъыьэюя-" + self.alphabet_length = len(alphabet) + self.max_word_length = 24 + self.letter_nums = {letter: num for num, letter in enumerate(alphabet)} + with open(str(expand_path(freq_dict_filename)), 'r') as fl: + lines = fl.readlines() + pos_freq_dict = defaultdict(list) + for line in lines: + line_split = line.strip('\n').split('\t') + if re.match("[\d]+\.[\d]+", line_split[2]): + pos_freq_dict[line_split[1]].append((line_split[0], float(line_split[2]))) + self.nouns_with_freq = pos_freq_dict["s.PROP"] + self.adj_set = set([word for word, freq in pos_freq_dict["a"]]) + self.nouns = [noun[0] for noun in self.nouns_with_freq] + self.matrix = self.make_sparse_matrix(self.nouns).transpose() + self.nlp = spacy.load("ru_core_news_sm") + + def search(self, word: str): + word = self.nlp(word)[0].lemma_ + if word in self.adj_set: + q_matrix = self.make_sparse_matrix([word]) + scores = q_matrix * self.matrix + scores = np.squeeze(scores.toarray()) + indices = np.argsort(-scores)[:self.candidate_nouns] + scores = list(scores[indices]) + candidates = [self.nouns_with_freq[indices[i]] + (scores[i],) for i in range(len(indices))] + candidates = [cand for cand in candidates if cand[0][:3].lower() == word[:3].lower()] + candidates = sorted(candidates, key=lambda x: (x[2], x[1]), reverse=True) + log.debug(f"AdjToNoun, found nouns: {candidates}") + if candidates and candidates[0][1] > self.freq_thres and candidates[0][2] > self.score_thres: + return candidates[0][0] + return "" + + def make_sparse_matrix(self, words: List[str]): + indptr = [] + indices = [] + data = [] + + total_length = 0 + + for n, word in enumerate(words): + indptr.append(total_length) + for cnt, letter in enumerate(word.lower()): + col = self.alphabet_length * cnt + 
self.letter_nums[letter] + indices.append(col) + init_value = 1.0 - cnt * 0.05 + if init_value < 0: + init_value = 0 + data.append(init_value) + total_length += len(word) + + indptr.append(total_length) + + data = np.array(data) + indptr = np.array(indptr) + indices = np.array(indices) + + matrix = csr_matrix((data, indices, indptr), shape=(len(words), self.max_word_length * self.alphabet_length)) + + return matrix diff --git a/deeppavlov/models/kbqa/tree_to_sparql.py b/deeppavlov/models/kbqa/tree_to_sparql.py index d406ce7368..d5f087761a 100644 --- a/deeppavlov/models/kbqa/tree_to_sparql.py +++ b/deeppavlov/models/kbqa/tree_to_sparql.py @@ -12,16 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import re from collections import defaultdict from io import StringIO from logging import getLogger from typing import Any, List, Tuple, Dict, Union -import numpy as np import spacy from navec import Navec -from scipy.sparse import csr_matrix +from razdel import tokenize from slovnet import Syntax from udapi.block.read.conllu import Conllu from udapi.core.node import Node @@ -31,96 +31,28 @@ from deeppavlov.core.common.registry import register from deeppavlov.core.models.component import Component from deeppavlov.core.models.serializable import Serializable +from deeppavlov.models.kbqa.ru_adj_to_noun import RuAdjToNoun +from deeppavlov.models.kbqa.utils import preprocess_template_queries log = getLogger(__name__) -@register('ru_adj_to_noun') -class RuAdjToNoun: - """ - Class for converting an adjective in Russian to the corresponding noun, for example: - "московский" -> "Москва", "африканский" -> "Африка" - """ - - def __init__(self, freq_dict_filename: str, candidate_nouns: int = 10, **kwargs): - """ - - Args: - freq_dict_filename: file with the dictionary of Russian words with the corresponding frequencies - candidate_nouns: how many candidate nouns to leave after search - **kwargs: - """ - self.candidate_nouns = candidate_nouns - alphabet = "абвгдеёжзийклмнопрстуфхцчшщъыьэюя-" - self.alphabet_length = len(alphabet) - self.max_word_length = 24 - self.letter_nums = {letter: num for num, letter in enumerate(alphabet)} - with open(str(expand_path(freq_dict_filename)), 'r') as fl: - lines = fl.readlines() - pos_freq_dict = defaultdict(list) - for line in lines: - line_split = line.strip('\n').split('\t') - if re.match("[\d]+\.[\d]+", line_split[2]): - pos_freq_dict[line_split[1]].append((line_split[0], float(line_split[2]))) - self.nouns_with_freq = pos_freq_dict["s.PROP"] - self.adj_set = set([word for word, freq in pos_freq_dict["a"]]) - self.nouns = [noun[0] for noun in self.nouns_with_freq] - self.matrix = self.make_sparse_matrix(self.nouns).transpose() - self.nlp = spacy.load("ru_core_news_sm") - - def search(self, word: str): - word = self.nlp(word)[0].lemma_ - if word in self.adj_set: - q_matrix = self.make_sparse_matrix([word]) - scores = q_matrix * self.matrix - scores = np.squeeze(scores.toarray() + 0.0001) - indices = np.argsort(-scores)[:self.candidate_nouns] - scores = list(scores[indices]) - candidates = [self.nouns_with_freq[indices[i]] + (scores[i],) for i in range(len(indices))] - candidates = sorted(candidates, key=lambda x: x[1] * x[2], reverse=True) - log.debug(f"AdjToNoun, found nouns: {candidates}") - if candidates[0][2] > 2.5: - return candidates[0][0] - return "" - - def make_sparse_matrix(self, words: List[str]): - indptr = [] - indices = [] - data = [] - - total_length = 0 - - for n, word in 
enumerate(words): - indptr.append(total_length) - for cnt, letter in enumerate(word.lower()): - col = self.alphabet_length * cnt + self.letter_nums[letter] - indices.append(col) - init_value = 1.0 - cnt * 0.05 - if init_value < 0: - init_value = 0 - data.append(init_value) - total_length += len(word) - - indptr.append(total_length) - - data = np.array(data) - indptr = np.array(indptr) - indices = np.array(indices) - - matrix = csr_matrix((data, indices, indptr), shape=(len(words), self.max_word_length * self.alphabet_length)) - - return matrix - - @register('slovnet_syntax_parser') class SlovnetSyntaxParser(Component, Serializable): """Class for syntax parsing using Slovnet library""" - def __init__(self, load_path: str, navec_filename: str, syntax_parser_filename: str, **kwargs): + def __init__(self, load_path: str, navec_filename: str, syntax_parser_filename: str, tree_patterns_filename: str, + **kwargs): super().__init__(save_path=None, load_path=load_path) self.navec_filename = expand_path(navec_filename) self.syntax_parser_filename = expand_path(syntax_parser_filename) + self.tree_patterns = read_json(expand_path(tree_patterns_filename)) self.re_tokenizer = re.compile(r"[\w']+|[^\w ]") + self.pronouns = {"q_pronouns": {"какой", "какая", "какое", "каком", "каким", "какую", "кто", "что", "как", + "когда", "где", "чем", "сколько"}, + "how_many": {"сколько"}} + self.first_tokens = {"первый", "первая", "первое"} + self.nlp = spacy.load("ru_core_news_sm") self.load() def load(self) -> None: @@ -131,35 +63,232 @@ def load(self) -> None: def save(self) -> None: pass - def __call__(self, sentences, entity_offsets_batch): - sentences_tok = [] + def preprocess_sentences(self, sentences, entity_offsets_batch): + sentences_tokens_batch, replace_dict_batch = [], [] for sentence, entity_offsets in zip(sentences, entity_offsets_batch): - for start, end in entity_offsets: - entity_old = sentence[start:end] - entity_new = entity_old.capitalize() - sentence = sentence.replace(entity_old, entity_new) - sentence = sentence.capitalize() - sentences_tok.append(re.findall(self.re_tokenizer, sentence)) - markup = list(self.syntax.map(sentences_tok)) - + if sentence.islower(): + for start, end in entity_offsets: + entity_old = sentence[start:end] + if entity_old: + entity_new = f"{entity_old[0].upper()}{entity_old[1:]}" + sentence = sentence.replace(entity_old, entity_new) + sentence = f"{sentence[0].upper()}{sentence[1:]}" + names3 = re.findall(r"([\w]{1}\.)([ ]?)([\w]{1}\.)([ ])([\w]{3,})", sentence) + replace_dict = {} + for name in names3: + names_str = "".join(name) + replace_dict[name[-1]] = (names_str, "name") + sentence = sentence.replace(names_str, name[-1]) + names2 = re.findall(r"([\w]{1}\.)([ ])([\w]{3,})", sentence) + for name in names2: + names_str = "".join(name) + replace_dict[name[-1]] = (names_str, "name") + sentence = sentence.replace(names_str, name[-1]) + works_of_art = re.findall(r'(["«])(.*?)(["»])', sentence) + for symb_start, work_of_art, symb_end in works_of_art: + work_of_art_tokens = re.findall(self.re_tokenizer, work_of_art) + if len(work_of_art.split()) > 1: + short_substr = "" + for tok in work_of_art_tokens: + if self.nlp(tok)[0].pos_ == "NOUN": + short_substr = tok + break + if not short_substr: + short_substr = work_of_art_tokens[0] + replace_dict[short_substr] = (work_of_art, "name") + sentence = sentence.replace(work_of_art, short_substr) + while True: + tokens = sentence.split() + found_substr = False + for i in range(len(tokens) - 2): + found = True + for j in range(i, i + 
3): + if len(tokens[j]) < 2 or tokens[j][0] in '("' or tokens[j][-1] in '"),.?': + found = False + if found and i > 0: + token_tags = [self.nlp(tokens[j])[0].pos_ for j in range(i, i + 3)] + lemm_tokens = {self.nlp(tok)[0].lemma_ for tok in tokens[i:i + 3]} + if token_tags == ["DET", "DET", "NOUN"] and not lemm_tokens & self.first_tokens: + long_substr = " ".join(tokens[i:i + 3]) + replace_dict[tokens[i + 2]] = (long_substr, "adj") + sentence = sentence.replace(long_substr, tokens[i + 2]) + found_substr = True + if found_substr: + break + if not found_substr: + break + sentence_tokens = [tok.text for tok in tokenize(sentence)] + sentences_tokens_batch.append(sentence_tokens) + log.debug(f"replace_dict: {replace_dict} --- sentence: {sentence_tokens}") + replace_dict_batch.append(replace_dict) + return sentences_tokens_batch, replace_dict_batch + + def get_markup(self, proc_syntax_batch, replace_dict_batch): + markup_batch = [] + for proc_syntax, replace_dict in zip(proc_syntax_batch, replace_dict_batch): + markup_list = [] + for elem in proc_syntax.tokens: + markup_list.append({"id": elem.id, "text": elem.text, "head_id": int(elem.head_id), "rel": elem.rel}) + ids, words, head_ids, rels = self.get_elements(markup_list) + head_ids, markup_list = self.correct_cycle(ids, head_ids, rels, markup_list) + for substr in replace_dict: + substr_full, substr_type = replace_dict[substr] + found_n = -1 + for n, markup_elem in enumerate(markup_list): + if markup_elem["text"] == substr: + found_n = n + if found_n > -1: + before_markup_list = copy.deepcopy(markup_list[:found_n]) + after_markup_list = copy.deepcopy(markup_list[found_n + 1:]) + substr_tokens = [tok.text for tok in tokenize(substr_full)] + new_markup_list = [] + if substr_type == "name": + for j in range(len(substr_tokens)): + new_markup_elem = {"id": str(found_n + j + 1), "text": substr_tokens[j]} + if j == 0: + new_markup_elem["rel"] = markup_list[found_n]["rel"] + if int(markup_list[found_n]["head_id"]) < found_n + 1: + new_markup_elem["head_id"] = markup_list[found_n]["head_id"] + else: + new_markup_elem["head_id"] = str(int(markup_list[found_n]["head_id"]) + len( + substr_tokens) - 1) + else: + new_markup_elem["rel"] = "flat:name" + new_markup_elem["head_id"] = str(found_n + 1) + new_markup_list.append(new_markup_elem) + elif substr_type == "adj": + for j in range(len(substr_tokens)): + new_elem = {"id": str(found_n + j + 1), "text": substr_tokens[j]} + if j == len(substr_tokens) - 1: + new_elem["rel"] = markup_list[found_n]["rel"] + if markup_list[found_n]["head_id"] < found_n + 1: + new_elem["head_id"] = markup_list[found_n]["head_id"] + else: + new_elem["head_id"] = markup_list[found_n]["head_id"] + len(substr_tokens) - 1 + else: + new_elem["rel"] = "amod" + new_elem["head_id"] = str(found_n + len(substr_tokens)) + new_markup_list.append(new_elem) + + for j in range(len(before_markup_list)): + if int(before_markup_list[j]["head_id"]) > found_n + 1: + before_markup_list[j]["head_id"] = int(before_markup_list[j]["head_id"]) + \ + len(substr_tokens) - 1 + if before_markup_list[j]["head_id"] == found_n + 1 and substr_type == "adj": + before_markup_list[j]["head_id"] = found_n + len(substr_tokens) + for j in range(len(after_markup_list)): + after_markup_list[j]["id"] = str(int(after_markup_list[j]["id"]) + len(substr_tokens) - 1) + if int(after_markup_list[j]["head_id"]) > found_n + 1: + after_markup_list[j]["head_id"] = int(after_markup_list[j]["head_id"]) + \ + len(substr_tokens) - 1 + if after_markup_list[j]["head_id"] == found_n + 1 
and substr_type == "adj": + after_markup_list[j]["head_id"] = found_n + len(substr_tokens) + + markup_list = before_markup_list + new_markup_list + after_markup_list + for j in range(len(markup_list)): + markup_list[j]["head_id"] = str(markup_list[j]["head_id"]) + markup_batch.append(markup_list) + return markup_batch + + def find_cycle(self, ids, head_ids): + for i in range(len(ids)): + for j in range(len(ids)): + if i < j and head_ids[j] == str(i + 1) and head_ids[i] == str(j + 1): + return i + 1 + return -1 + + def correct_markup(self, words, head_ids, rels, root_n): + if len(words) > 3: + pos = [self.nlp(words[i])[0].pos_ for i in range(len(words))] + for tree_pattern in self.tree_patterns: + first_word = tree_pattern.get("first_word", "") + (r_start, r_end), rel_info = tree_pattern.get("rels", [[0, 0], ""]) + (p_start, p_end), pos_info = tree_pattern.get("pos", [[0, 0], ""]) + if (not first_word or words[0].lower() in self.pronouns[first_word]) \ + and (not rel_info or rels[r_start:r_end] == rel_info) \ + and (not pos_info or pos[p_start:p_end] == pos_info): + for ind, deprel in tree_pattern.get("rel_ids", {}).items(): + rels[int(ind)] = deprel + for ind, head_id in tree_pattern.get("head_ids", {}).items(): + head_ids[int(ind)] = head_id + root_n = tree_pattern["root_n"] + break + if words[0].lower() in {"какой", "какая", "какое"} and rels[:3] == ["det", "obj", "root"] \ + and pos[1:3] == ["NOUN", "VERB"] and "nsubj" not in rels: + rels[1] = "nsubj" + return head_ids, rels, root_n + + def find_root(self, rels): + root_n = -1 + for n in range(len(rels)): + if rels[n] == "root": + root_n = n + 1 + break + return root_n + + def get_elements(self, markup_elem): + ids, words, head_ids, rels = [], [], [], [] + for elem in markup_elem: + ids.append(elem["id"]) + words.append(elem["text"]) + head_ids.append(elem["head_id"]) + rels.append(elem["rel"]) + return ids, words, head_ids, rels + + def correct_cycle(self, ids, head_ids, rels, markup_elem): + cycle_num = -1 + for n, (elem_id, head_id) in enumerate(zip(ids, head_ids)): + if str(head_id) == str(elem_id): + cycle_num = n + root_n = self.find_root(rels) + if cycle_num > 0 and root_n > -1: + head_ids[cycle_num] = root_n + markup_elem[cycle_num]["head_id"] = root_n + return head_ids, markup_elem + + def process_markup(self, markup_batch): processed_markup_batch = [] - for markup_elem in markup: + for markup_elem in markup_batch: processed_markup = [] - ids, words, head_ids, rels = [], [], [], [] - for elem in markup_elem.tokens: - ids.append(elem.id) - words.append(elem.text) - head_ids.append(elem.head_id) - rels.append(elem.rel) + ids, words, head_ids, rels = self.get_elements(markup_elem) if "root" not in {rel.lower() for rel in rels}: + found_root = False for n, (elem_id, head_id) in enumerate(zip(ids, head_ids)): if elem_id == head_id: rels[n] = "root" head_ids[n] = 0 + found_root = True + if not found_root: + for n in range(len(ids)): + if rels[n] == "nsubj": + rels[n] = "root" + head_ids[n] = 0 + found_root = True + if not found_root: + for n in range(len(ids)): + if self.nlp(words[n])[0].pos_ == "VERB": + rels[n] = "root" + head_ids[n] = 0 + + root_n = self.find_root(rels) + head_ids, rels, root_n = self.correct_markup(words, head_ids, rels, root_n) + if words[-1] == "?" 
and -1 < root_n != head_ids[-1]: + head_ids[-1] = root_n + + head_ids, markup_elem = self.correct_cycle(ids, head_ids, rels, markup_elem) + i = self.find_cycle(ids, head_ids) + if i == 1 and root_n > -1: + head_ids[i - 1] = root_n for elem_id, word, head_id, rel in zip(ids, words, head_ids, rels): processed_markup.append(f"{elem_id}\t{word}\t_\t_\t_\t_\t{head_id}\t{rel}\t_\t_") processed_markup_batch.append("\n".join(processed_markup)) + return processed_markup_batch + def __call__(self, sentences, entity_offsets_batch): + sentences_tokens_batch, substr_dict_batch = self.preprocess_sentences(sentences, entity_offsets_batch) + proc_syntax_batch = list(self.syntax.map(sentences_tokens_batch)) + markup_batch = self.get_markup(proc_syntax_batch, substr_dict_batch) + processed_markup_batch = self.process_markup(markup_batch) return processed_markup_batch @@ -169,141 +298,214 @@ class TreeToSparql(Component): Class for building of sparql query template using syntax parser """ - def __init__(self, sparql_queries_filename: str, lang: str = "rus", adj_to_noun: RuAdjToNoun = None, **kwargs): + def __init__(self, sparql_queries_filename: str, syntax_parser: Component, kb_prefixes: Dict[str, str], + adj_to_noun: RuAdjToNoun = None, **kwargs): """ Args: sparql_queries_filename: file with sparql query templates - lang: english or russian + syntax_parser: component for syntactic parsing of the input question + kb_prefixes: prefixes for entities, relations and types in the knowledge base adj_to_noun: component deeppavlov.models.kbqa.tree_to_sparql:RuAdjToNoun **kwargs: """ - self.lang = lang - if self.lang == "rus": - self.q_pronouns = {"какой", "какая", "какое", "каком", "каким", "какую", "кто", "что", "как", "когда", - "где", "чем", "сколько"} - self.how_many = "сколько" - self.change_root_tokens = {"каким был", "какой была"} - self.first_tokens = {"первый"} - self.last_tokens = {"последний"} - self.begin_tokens = {"начинать", "начать"} - self.end_tokens = {"завершить", "завершать", "закончить"} - self.ranking_tokens = {"самый"} - self.nlp = spacy.load("ru_core_news_sm") - elif self.lang == "eng": - self.q_pronouns = {"what", "who", "how", "when", "where", "which"} - self.how_many = "how many" - self.change_root_tokens = "" - self.first_tokens = {"first"} - self.last_tokens = {"last"} - self.begin_tokens = set() - self.end_tokens = set() - self.ranking_tokens = set() - self.nlp = spacy.load("en_core_web_sm") - else: - raise ValueError(f"unsupported language {lang}") + self.q_pronouns = {"какой", "какая", "какое", "каком", "каким", "какую", "кто", "что", "как", "когда", + "где", "чем", "сколько"} + self.how_many = "сколько" + self.change_root_tokens = {"каким был", "какой была"} + self.first_tokens = {"первый", "первая", "первое"} + self.last_tokens = {"последний"} + self.begin_tokens = {"начинать", "начать"} + self.end_tokens = {"завершить", "завершать", "закончить"} + self.ranking_tokens = {"самый"} + self.date_tokens = {"год", "месяц"} + self.nlp = spacy.load("ru_core_news_sm") + self.re_tokenizer = re.compile(r"[\w']+|[^\w ]") self.sparql_queries_filename = expand_path(sparql_queries_filename) - self.template_queries = read_json(self.sparql_queries_filename) + template_queries = read_json(self.sparql_queries_filename) + self.template_queries = preprocess_template_queries(template_queries, kb_prefixes) + self.syntax_parser = syntax_parser self.adj_to_noun = adj_to_noun - def __call__(self, syntax_tree_batch: List[str], - positions_batch: List[List[List[int]]]) -> Tuple[ - List[str], List[List[str]], 
List[List[str]], List[List[str]]]: - log.debug(f"positions of entity tokens {positions_batch}") - query_nums_batch = [] - entities_dict_batch = [] - types_dict_batch = [] - questions_batch = [] + def __call__(self, questions_batch: List[str], substr_batch: List[List[str]], tags_batch: List[List[str]], + offsets_batch: List[List[List[int]]], positions_batch: List[List[List[int]]], + probas_batch: List[List[float]]) -> Tuple[ + List[Union[str, Any]], List[Union[List[str], List[Union[str, Any]]]], List[Union[List[str], Any]], List[ + Union[List[Union[str, Any]], Any]], List[Union[List[Union[float, Any]], Any]], List[List[int]], List[ + Union[List[str], List[Any]]]]: + substr_batch, tags_batch, offsets_batch, positions_batch, probas_batch = \ + self.sort_substr(substr_batch, tags_batch, offsets_batch, positions_batch, probas_batch) + log.debug(f"substr: {substr_batch} tags: {tags_batch} positions: {positions_batch}") + query_nums_batch, s_substr_batch, s_tags_batch, s_probas_batch, types_batch = [], [], [], [], [] + entities_to_link_batch = [] + clean_questions_batch = [] count = False - for syntax_tree, positions in zip(syntax_tree_batch, positions_batch): - log.debug(f"\n{syntax_tree}") - try: - tree = Conllu(filehandle=StringIO(syntax_tree)).read_tree() - root = self.find_root(tree) - tree_desc = tree.descendants - except ValueError: - root = "" - unknown_node = "" - if root: - log.debug(f"syntax tree info, root: {root.form}") - unknown_node, unknown_branch = self.find_branch_with_unknown(root) - positions = [num for position in positions for num in position] + for question, substr_list, tags_list, offsets_list, probas_list, positions in \ + zip(questions_batch, substr_batch, tags_batch, offsets_batch, probas_batch, positions_batch): + entities_dict, probas_dict = {}, {} + for substr, tag, proba in zip(substr_list, tags_list, probas_list): + entities_dict[substr.lower()] = tag + probas_dict[substr.lower()] = proba + for i in range(len(substr_list)): + substr = substr_list[i] + if len(substr) > 2 and ("-" in substr or f"{substr}-" in question) and " - " not in substr: + if "-" in substr: + length = len(re.findall(self.re_tokenizer, substr)) + else: + length = 3 + substr_tokens = list(tokenize(substr)) + positions[i] = [positions[i][j] for j in range(len(substr_tokens))] + if i < len(substr_list) - 1: + for j in range(i + 1, len(substr_list)): + pos_inds = positions[j] + pos_inds = [ind - length + 1 for ind in pos_inds] + positions[j] = pos_inds + + root, tree, tree_desc, unknown_node, unknown_branch = self.syntax_parse(question, offsets_list) + query_nums = ["7"] + s_substr_list = substr_list + s_tags_list = tags_list + s_probas_list = probas_list + types_list = [] if unknown_node: - log.debug(f"syntax tree info, unknown node: {unknown_node.form}, unknown branch: {unknown_branch.form}") + log.debug(f"syntax tree info 1, unknown node: {unknown_node.form}, unkn branch: {unknown_branch.form}") log.debug(f"wh_leaf: {self.wh_leaf}") clause_node, clause_branch = self.find_clause_node(root, unknown_branch) - if clause_node: - log.debug(f"clause node {clause_node.form}") - else: - log.debug(f"clause node not found") - modifiers, clause_modifiers = self.find_modifiers_of_unknown(unknown_node) - modifiers_debug_list = [] - for modifier in modifiers: - if isinstance(modifier, str): - modifiers_debug_list.append(modifier) - else: - modifiers_debug_list.append(modifier.form) - log.debug(f"modifiers: {' '.join(modifiers_debug_list)}") - if f"{tree_desc[0].form.lower()} {tree_desc[1].form.lower()}" in 
self.change_root_tokens: - new_root = root.children[0] - else: - new_root = root - root_desc = defaultdict(list) - for node in new_root.children: - if node.deprel not in ["punct", "advmod", "cop", "mark"]: - if node == unknown_branch: - root_desc[node.deprel].append(node) - else: - if self.find_entities(node, positions, cut_clause=False) or \ - (self.find_year_or_number(node) and node.deprel in ["obl", "nummod"]): - root_desc[node.deprel].append(node) - - if root.form.lower() == self.how_many or ("nsubj" in root_desc.keys() and - self.how_many in [nd.form.lower() for nd in - root_desc["nsubj"]]): - count = True - log.debug(f"root_desc {root_desc.keys()}") - appos_token_nums = sorted(self.find_appos_tokens(root, [])) + log.debug(f"clause node: {clause_node}") + tok_and_ord = {node.ord: node for node in tree.descendants} + appos_token_nums = sorted(self.find_appos_tokens(root, tok_and_ord, [])) appos_tokens = [elem.form for elem in tree_desc if elem.ord in appos_token_nums] - clause_token_nums = sorted(self.find_clause_tokens(root, clause_node, [])) + clause_token_nums = sorted(self.find_clause_tokens(root, tok_and_ord, clause_node)) clause_tokens = [elem.form for elem in tree_desc if elem.ord in clause_token_nums] log.debug(f"appos tokens: {appos_tokens}") log.debug(f"clause_tokens: {clause_tokens}") - self.root_entity = False - if root.ord - 1 in positions: - self.root_entity = True - - temporal_order = self.find_first_last(new_root) - new_root_nf = self.nlp(new_root.form)[0].lemma_ - if new_root_nf in self.begin_tokens or new_root_nf in self.end_tokens: - temporal_order = new_root_nf - ranking_tokens = self.find_ranking_tokens(new_root) - query_nums, entities_dict, types_dict = self.build_query(new_root, unknown_branch, root_desc, - unknown_node, modifiers, clause_modifiers, - clause_node, positions, count, - temporal_order, ranking_tokens) - - if self.lang == "rus": - if ranking_tokens: - question = [] - for node in tree.descendants: - if node.ord in ranking_tokens or node.form.lower() in self.q_pronouns: - question.append(self.nlp(node.form)[0].lemma_) - else: - question.append(node.form) - question = ' '.join(question) + question, ranking_tokens = self.sanitize_question(tree, root, appos_token_nums, clause_token_nums) + if appos_token_nums or clause_token_nums: + root, tree, tree_desc, unknown_node, unknown_branch = self.syntax_parse(question, offsets_list) + log.debug(f"syntax tree info 2, unknown node: {unknown_node}, unkn branch: {unknown_branch}") + + if unknown_node: + modifiers, clause_modifiers = self.find_modifiers_of_unknown(unknown_node) + log.debug(f"modifiers: {modifiers} --- clause modifiers: {[nd.form for nd in clause_modifiers]}") + if f"{tree_desc[0].form.lower()} {tree_desc[1].form.lower()}" in self.change_root_tokens: + new_root = root.children[0] else: - question = ' '.join([node.form for node in tree.descendants - if - (node.ord not in appos_token_nums or node.ord not in clause_token_nums)]) + new_root = root + root_desc = defaultdict(list) + for node in new_root.children: + if node.deprel not in ["punct", "advmod", "cop", "mark"]: + if node == unknown_branch: + root_desc[node.deprel].append(node) + else: + if self.find_entities(node, positions) or \ + (self.find_year_or_number(node) and node.deprel in ["obl", "nummod"]): + root_desc[node.deprel].append(node) + + if root.form.lower() == self.how_many or ("nsubj" in root_desc.keys() and + self.how_many in [nd.form.lower() for nd in + root_desc["nsubj"]]): + count = True + log.debug(f"root_desc {root_desc.keys()}") + 
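A compact sketch of the data flow feeding this step, assuming a hand-made toy sentence and annotation: the 10-column markup emitted by SlovnetSyntaxParser.process_markup is read back with udapi, as in syntax_parse, and the root's children are grouped by deprel the way root_desc is built here. Only the udapi calls and the deprel filter come from the patch; the words, heads and relations are invented for illustration.

from collections import defaultdict
from io import StringIO

from udapi.block.read.conllu import Conllu

# toy annotation in the same "id \t form \t _ ... head \t deprel ..." layout as process_markup
rows = [
    ("1", "Кто", "2", "nsubj"),
    ("2", "написал", "0", "root"),
    ("3", "Войну", "2", "obj"),
    ("4", "и", "5", "cc"),
    ("5", "мир", "3", "conj"),
    ("6", "?", "2", "punct"),
]
markup = "\n".join(f"{i}\t{form}\t_\t_\t_\t_\t{head}\t{rel}\t_\t_" for i, form, head, rel in rows)

tree = Conllu(filehandle=StringIO(markup)).read_tree()
root = next(node for node in tree.descendants if node.deprel == "root")
root_desc = defaultdict(list)
for node in root.children:
    if node.deprel not in ["punct", "advmod", "cop", "mark"]:
        root_desc[node.deprel].append(node.form)
print(root.form, dict(root_desc))  # написал {'nsubj': ['Кто'], 'obj': ['Войну']}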
self.root_entity = False + if root.ord - 1 in positions: + self.root_entity = True + + temporal_order = self.find_first_last(new_root) + new_root_nf = self.nlp(new_root.form)[0].lemma_ + if new_root_nf in self.begin_tokens or new_root_nf in self.end_tokens: + temporal_order = new_root_nf + query_nums, s_substr_list, types_list = self.build_query(new_root, unknown_branch, root_desc, + unknown_node, modifiers, clause_modifiers, + clause_node, positions, entities_dict, + count, temporal_order, ranking_tokens) + s_tags_list, s_probas_list = [], [] + for substr in s_substr_list: + substr = substr.replace(" - ", "-") + s_tags_list.append(entities_dict.get(substr.lower(), "E")) + s_probas_list.append(probas_dict.get(substr.lower(), 1.0)) + clean_questions_batch.append(question) + if query_nums and s_substr_list: + entities_to_link = [1 for _ in s_substr_list] + s_substr_list_lower = [s.lower() for s in s_substr_list] + for substr, tag, proba in zip(substr_list, tags_list, probas_list): + if substr.lower() not in s_substr_list_lower: + s_substr_list.append(substr) + s_tags_list.append(tag) + s_probas_list.append(proba) + entities_to_link.append(0) + s_substr_batch.append(s_substr_list) + s_tags_batch.append(s_tags_list) + s_probas_batch.append(s_probas_list) + entities_to_link_batch.append(entities_to_link) + else: + mod_len = 0 + gr_len = 1 + if all([tags_list[i] == tags_list[0] for i in range(len(tags_list))]): + gr_len = len(substr_list) + elif len(substr_list) > 1: + mod_len = 1 + for num, template in self.template_queries.items(): + syntax_info = [gr_len, 0, mod_len, 0, False, False, False] + if syntax_info == list(template["syntax_structure"].values()): + query_nums.append(num) + entities_to_link = [1 for _ in s_substr_list] + s_substr_batch.append(substr_list) + s_tags_batch.append(tags_list) + s_probas_batch.append(probas_list) + entities_to_link_batch.append(entities_to_link) + query_nums_batch.append(query_nums) + types_batch.append(types_list) + log.debug(f"clean_questions: {clean_questions_batch} --- substr: {s_substr_batch} --- tags: {s_tags_batch} " + f"--- entities_to_link {entities_to_link_batch} --- types: {types_batch}") + return clean_questions_batch, query_nums_batch, s_substr_batch, s_tags_batch, s_probas_batch, \ + entities_to_link_batch, types_batch + + def sort_substr(self, substr_batch: List[List[str]], tags_batch: List[List[str]], + offsets_batch: List[List[List[int]]], positions_batch: List[List[List[int]]], + probas_batch: List[List[float]]) -> Tuple[ + List[List[str]], List[List[str]], List[List[List[int]]], List[List[List[int]]], List[List[float]]]: + s_substr_batch, s_tags_batch, s_offsets_batch, s_positions_batch, s_probas_batch = [], [], [], [], [] + for substr_list, tags_list, offsets_list, positions_list, probas_list \ + in zip(substr_batch, tags_batch, offsets_batch, positions_batch, probas_batch): + substr_info = [(substr, tag, offsets, positions, proba) for substr, tag, offsets, positions, proba + in zip(substr_list, tags_list, offsets_list, positions_list, probas_list)] + substr_info = sorted(substr_info, key=lambda x: x[3][0]) + s_substr_batch.append([elem[0] for elem in substr_info]) + s_tags_batch.append([elem[1] for elem in substr_info]) + s_offsets_batch.append([elem[2] for elem in substr_info]) + s_positions_batch.append([elem[3] for elem in substr_info]) + s_probas_batch.append([elem[4] for elem in substr_info]) + return s_substr_batch, s_tags_batch, s_offsets_batch, s_positions_batch, s_probas_batch + + def syntax_parse(self, question: str, 
entity_offsets_list: List[List[int]]) -> Tuple[ + Union[str, Any], Union[str, Any], Union[str, Any], str, str]: + syntax_tree = self.syntax_parser([question], [entity_offsets_list])[0] + log.debug(f"syntax tree: \n{syntax_tree}") + root, tree, tree_desc, unknown_node, unknown_branch = "", "", "", "", "" + try: + tree = Conllu(filehandle=StringIO(syntax_tree)).read_tree() + root = self.find_root(tree) + tree_desc = tree.descendants + except ValueError as e: + log.warning(f"error in parsing syntax tree, {e}") + if root: + unknown_node, unknown_branch = self.find_branch_with_unknown(root) + log.debug(f"syntax tree info, root: {root.form} unk_node: {unknown_node} unk_branch: {unknown_branch}") + return root, tree, tree_desc, unknown_node, unknown_branch + + def sanitize_question(self, tree: Node, root: Node, appos_token_nums: List[int], clause_token_nums: List[int]) -> \ + Tuple[str, list]: + ranking_tokens = self.find_ranking_tokens(root, appos_token_nums, clause_token_nums) + question_tokens = [] + for node in tree.descendants: + if node.ord not in appos_token_nums + clause_token_nums: + if ranking_tokens and (node.ord in ranking_tokens or node.form.lower() in self.q_pronouns): + question_tokens.append(self.nlp(node.form)[0].lemma_) else: - question = ' '.join([node.form for node in tree.descendants]) - log.debug(f"sanitized question: {question}") - query_nums_batch.append(query_nums) - entities_dict_batch.append(entities_dict) - types_dict_batch.append(types_dict) - questions_batch.append(question) - return questions_batch, query_nums_batch, entities_dict_batch, types_dict_batch + question_tokens.append(node.form) + question = " ".join(question_tokens) + log.debug(f"sanitized question: {question}") + return question, ranking_tokens def find_root(self, tree: Node) -> Node: for node in tree.descendants: @@ -313,7 +515,6 @@ def find_root(self, tree: Node) -> Node: def find_branch_with_unknown(self, root: Node) -> Tuple[str, str]: self.wh_leaf = False self.one_chain = False - if root.form.lower() in self.q_pronouns: if "nsubj" in [node.deprel for node in root.children] or root.form.lower() in self.how_many: self.one_chain = True @@ -321,7 +522,6 @@ def find_branch_with_unknown(self, root: Node) -> Tuple[str, str]: for node in root.children: if node.deprel == "nsubj": return node, node - if not self.one_chain: for node in root.children: if node.form.lower() in self.q_pronouns: @@ -335,7 +535,6 @@ def find_branch_with_unknown(self, root: Node) -> Tuple[str, str]: for child in node.descendants: if child.form.lower() in self.q_pronouns: return child.parent, node - if self.wh_leaf or self.one_chain: for node in root.children: if node.deprel in ["nsubj", "obl", "obj", "nmod", "xcomp"] and node.form.lower() not in self.q_pronouns: @@ -367,126 +566,69 @@ def find_clause_node(self, root: Node, unknown_branch: Node) -> Tuple[str, str]: return elem, node return "", "" - def find_named_entity(self, node: Node, conj_list: List[Node], desc_list: List[Tuple[str, int]], - positions: List[int], cut_clause: bool) -> List[Tuple[str, int]]: - if node.children: - if self.find_nmod_appos(node, positions): - used_desc = [elem for elem in node.children if elem.deprel == "appos"] - else: - used_desc = node.children - - for elem in used_desc: - if self.check_node(elem, conj_list, cut_clause): - desc_list = self.find_named_entity(elem, conj_list, desc_list, positions, cut_clause) - log.debug(f"find_named_entity: node.ord, {node.ord - 1}, {node.form}, positions, {positions}") - log.debug(f"find nmod appos 
{self.find_nmod_appos(node, positions)}") - if node.ord - 1 in positions: - initials_3 = re.findall("([А-Яа-я]{1}\.)([А-Яа-я]{1}\.)([А-Яа-я]{3,15})", node.form) - initials_2 = re.findall("([А-Яа-я]{1}\.)([А-Яа-я]{3,15})", node.form) - if initials_3: - entity_substring = ' '.join(initials_3[0]) - elif initials_2: - entity_substring = ' '.join(initials_2[0]) - else: - entity_substring = node.form - desc_list.append((entity_substring, node.ord)) - - return desc_list - - def check_node(self, elem: Node, conj_list: List[Node], cut_clause: bool) -> bool: - """ - This function defines whether to go deeper in the syntactic tree to look for named entity tokens - If all the conditions are true, then we recursively look for named entity tokens in elem's descendants. - Args: - elem: node of the syntactic tree for which we decide whether to look for named entities in its descendants - conj_list: list of nodes, connected with the "elem" node with "conj" deprel - cut_clause: if cut_clause is True, we do not want to look for named entities in adjective clauses ("acl") - """ - move_deeper = False - if not cut_clause or (cut_clause and elem.deprel != "acl"): - if elem not in conj_list: - if elem.deprel != "appos" or \ - (elem.deprel == "appos" - and (not elem.children or - (len(elem.children) == 1 and elem.children[0].deprel in ["flat:name", "parataxis"]) or - (len(elem.children) > 1 and {"«", '"', '``', '('} & {nd.form for nd in - elem.descendants}))): - move_deeper = True - return move_deeper - - def find_conj(self, node: Node, conj_list: List[Node], positions: List[int], cut_clause: bool) -> List[Node]: - if node.children: - for elem in node.children: - if not cut_clause or (cut_clause and elem.deprel != "acl"): - conj_list = self.find_conj(elem, conj_list, positions, cut_clause) - - if node.deprel == "conj": - conj_in_ner = False - for elem in node.children: - if elem.deprel == "cc" and (elem.ord - 1) in positions: - conj_in_ner = True - if not conj_in_ner: - conj_list.append(node) - - return conj_list - - def find_entities(self, node: Node, positions: List[int], cut_clause: bool = True) -> List[str]: - entities_list = [] - conj_list = self.find_conj(node, [], positions, cut_clause) - entity = self.find_entity(node, conj_list, positions, cut_clause) - if entity: - entities_list.append(entity) - if conj_list: - for conj_node in conj_list: - curr_conj_list = [elem for elem in conj_list if elem != conj_node] - entity = self.find_entity(conj_node, curr_conj_list, positions, cut_clause) + def find_entities(self, node: Node, positions: List[List[int]]) -> List[str]: + node_desc = [(node.form, node.ord, node.parent)] + \ + [(elem.form, elem.ord, elem.parent) for elem in node.descendants] + node_desc = sorted(node_desc, key=lambda x: x[1]) + entities_list, heads_list = [], [] + for pos_elem in positions: + entity, parents = [], [] + for ind in pos_elem: + for node_elem in node_desc: + if ind + 1 == node_elem[1]: + entity.append(node_elem[0]) + parents.append(node_elem[2]) + break + if len(entity) == len(pos_elem): + entity = " ".join(entity).replace(" .", ".") entities_list.append(entity) - log.debug(f"found_entities, {entities_list}") + heads_list.append(parents[0]) + log.debug(f"node: {node.form} --- found_entities: {entities_list} --- node_desc: {node_desc} --- " + f"positions: {positions}") return entities_list - def find_entity(self, node: Node, conj_list: List[Node], positions: List[int], cut_clause: bool) -> str: - grounded_entity_tokens = self.find_named_entity(node, conj_list, [], positions, 
cut_clause) - grounded_entity = sorted(grounded_entity_tokens, key=lambda x: x[1]) - grounded_entity = " ".join([entity[0] for entity in grounded_entity]) - return grounded_entity - - def find_nmod_appos(self, node: Node, positions: List[int]) -> bool: - node_desc = {elem.deprel: elem for elem in node.children} - node_deprels = sorted([elem.deprel for elem in node.children if elem.deprel != "case"]) - if node.ord - 1 in positions: - return False - elif node_deprels == ["appos", "nmod"] and node_desc["appos"].ord - 1 in positions \ - and node_desc["nmod"].ord in positions: - return True - return False - def find_year_or_number(self, node: Node) -> bool: found = False for elem in node.descendants: - if elem.deprel == "nummod": + if elem.deprel == "nummod" or re.findall(r"[\d]{4}", elem.form): return True return found - def find_appos_tokens(self, node: Node, appos_token_nums: List[int]) -> List[int]: + def find_year_constraint(self, node: Node) -> list: + node_desc = [(node.form, node.ord)] + [(elem.form, elem.ord) for elem in node.descendants] + node_desc = sorted(node_desc, key=lambda x: x[1]) + desc_text = " ".join([elem[0] for elem in node_desc]) + for symb in ".,:;)": + desc_text = desc_text.replace(f" {symb}", symb) + for pattern in [r"в ([\d]{3,4}) году", r"с ([\d]{3,4}) по ([\d]{3,4})"]: + fnd = re.findall(pattern, desc_text) + if fnd: + return fnd + return [] + + def find_appos_tokens(self, node: Node, tok_and_ord: List[Tuple[Node, int]], + appos_token_nums: List[int]) -> List[int]: for elem in node.children: - if elem.deprel == "appos" and (len(elem.descendants) > 1 and - not {"«", '"', '``', '('} & {nd.form for nd in elem.descendants} or - (len(elem.descendants) == 1 and elem.descendants[0].deprel != "flat:name")): + e_desc = elem.descendants + if elem.deprel == "appos" and elem.ord > 1 and tok_and_ord[elem.ord - 1].deprel == "punct" \ + and not all([nd.deprel in {"appos", "flat:name"} for nd in e_desc]) \ + and not ({"«", '"', '``', '('} & {nd.form for nd in e_desc}): appos_token_nums.append(elem.ord) for desc in elem.descendants: appos_token_nums.append(desc.ord) else: - appos_token_nums = self.find_appos_tokens(elem, appos_token_nums) + appos_token_nums = self.find_appos_tokens(elem, tok_and_ord, appos_token_nums) return appos_token_nums - def find_clause_tokens(self, node: Node, clause_node: Node, clause_token_nums: List[int]) -> List[int]: + def find_clause_tokens(self, node: Node, tok_and_ord: Dict[int, Node], clause_node: Node) -> List[int]: + clause_token_nums = [] for elem in node.children: if elem != clause_node and elem.deprel == "acl": clause_token_nums.append(elem.ord) for desc in elem.descendants: clause_token_nums.append(desc.ord) else: - clause_token_nums = self.find_appos_tokens(elem, clause_token_nums) + clause_token_nums = self.find_appos_tokens(elem, tok_and_ord, clause_token_nums) return clause_token_nums def find_first_last(self, node: Node) -> str: @@ -496,11 +638,9 @@ def find_first_last(self, node: Node) -> str: for node in nodes: node_desc = defaultdict(set) for elem in node.children: - parsed_elem = self.nlp(elem.form.lower())[0].lemma_ - if parsed_elem is not None: - node_desc[elem.deprel].add(parsed_elem) - else: - node_desc[elem.deprel].add(elem.form) + normal_form = self.nlp(elem.form.lower())[0].lemma_ + node_desc[elem.deprel].add(normal_form) + log.debug(f"find_first_last {node_desc}") if "amod" in node_desc.keys() and "nmod" in node_desc.keys() and \ node_desc["amod"].intersection(self.first_tokens | self.last_tokens): first_or_last = ' 
'.join(node_desc["amod"].intersection(self.first_tokens | self.last_tokens)) @@ -508,34 +648,48 @@ def find_first_last(self, node: Node) -> str: nodes = [elem for node in nodes for elem in node.children] return first_or_last - def find_ranking_tokens(self, node: Node) -> list: + def find_ranking_tokens(self, node: Node, appos_token_nums: List[int], clause_token_nums: List[int]) -> list: ranking_tokens = [] for elem in node.descendants: - if self.nlp(elem.form)[0].lemma_ in self.ranking_tokens: + if self.nlp(elem.form)[0].lemma_ in self.ranking_tokens \ + and elem.ord not in appos_token_nums + clause_token_nums: ranking_tokens.append(elem.ord) ranking_tokens.append(elem.parent.ord) return ranking_tokens return ranking_tokens - def build_query(self, root: Node, unknown_branch: Node, root_desc: Dict[str, List[Node]], - unknown_node: Node, unknown_modifiers: List[Node], clause_modifiers: List[Node], - clause_node: Node, positions: List[int], - count: bool = False, temporal_order: str = "", - ranking_tokens: List[str] = None) -> Tuple[List[str], List[str], List[str]]: + @staticmethod + def choose_grounded_entity(grounded_entities: List[str], entities_dict: Dict[str, str]): + tags = [entities_dict.get(entity.lower(), "") for entity in grounded_entities] + if len(grounded_entities) > 1: + if not all([tags[i] == tags[0] for i in range(1, len(tags))]): + for f_tag in ["WORK_OF_ART", "FAC", "PERSON", "GPE"]: + for entity, tag in zip(grounded_entities, tags): + if tag == f_tag: + return [entity] + elif not all([entity[0].islower() for entity in grounded_entities]): + for entity in grounded_entities: + if entity[0].isupper(): + return [entity] + return grounded_entities + + def build_query(self, root: Node, unknown_branch: Node, root_desc: Dict[str, List[Node]], unknown_node: Node, + unknown_modifiers: List[Node], clause_modifiers: List[Node], clause_node: Node, + positions: List[List[int]], entities_dict: Dict[str, str], count: bool = False, + temporal_order: str = "", ranking_tokens: List[str] = None) -> Tuple[ + List[str], List[str], List[str]]: query_nums = [] - grounded_entities_list = [] - types_list = [] - modifiers_list = [] - qualifier_entities_list = [] + grounded_entities_list, types_list, modifiers_list, qualifier_entities_list = [], [], [], [] found_year_or_number = False order = False root_desc_deprels = [] for key in root_desc.keys(): for i in range(len(root_desc[key])): - root_desc_deprels.append(key) + if key in {"nsubj", "obj", "obl", "iobj", "acl", "nmod", "xcomp", "cop"}: + root_desc_deprels.append(key) root_desc_deprels = sorted(root_desc_deprels) - log.debug(f"build_query: root_desc.keys, {root_desc_deprels}, positions {positions}") - log.debug(f"temporal order {temporal_order}, ranking tokens {ranking_tokens}") + log.debug(f"build_query: root_desc.keys, {root_desc_deprels}, positions {positions}, wh_leaf {self.wh_leaf}, " + f"one_chain {self.one_chain}, temporal order {temporal_order}, ranking tokens {ranking_tokens}") if root_desc_deprels in [["nsubj", "obl"], ["nsubj", "obj"], ["nsubj", "xcomp"], @@ -552,13 +706,13 @@ def build_query(self, root: Node, unknown_branch: Node, root_desc: Dict[str, Lis ["nsubj"]]: if self.wh_leaf or self.one_chain: if root_desc_deprels == ["nsubj", "obl"]: - grounded_entities_list = self.find_entities(root_desc["nsubj"][0], positions, cut_clause=True) + grounded_entities_list = self.find_entities(root_desc["nsubj"][0], positions) if not grounded_entities_list: - grounded_entities_list = self.find_entities(root_desc["obl"][0], positions, 
cut_clause=True) + grounded_entities_list = self.find_entities(root_desc["obl"][0], positions) else: for nodes in root_desc.values(): if nodes[0].form not in self.q_pronouns: - grounded_entities_list = self.find_entities(nodes[0], positions, cut_clause=True) + grounded_entities_list = self.find_entities(nodes[0], positions) if grounded_entities_list: break else: @@ -566,7 +720,7 @@ def build_query(self, root: Node, unknown_branch: Node, root_desc: Dict[str, Lis grounded_entities_list = [root.form] for nodes in root_desc.values(): if nodes[0] != unknown_branch: - grounded_entities_list = self.find_entities(nodes[0], positions, cut_clause=True) + grounded_entities_list = self.find_entities(nodes[0], positions) if grounded_entities_list: type_entity = unknown_node.form types_list.append(type_entity) @@ -577,40 +731,67 @@ def build_query(self, root: Node, unknown_branch: Node, root_desc: Dict[str, Lis if isinstance(modifier, str): modifiers_list.append(modifier) else: - modifier_entities = self.find_entities(modifier, positions, cut_clause=True) + modifier_entities = self.find_entities(modifier, positions) if modifier_entities: modifiers_list += modifier_entities if clause_modifiers: found_year_or_number = self.find_year_or_number(clause_modifiers[0]) if found_year_or_number: query_nums.append("0") - qualifier_entities_list = self.find_entities(clause_modifiers[0], positions, cut_clause=True) + qualifier_entities_list = self.find_entities(clause_modifiers[0], positions) if root_desc_deprels == ["nsubj", "obl", "obl"]: - grounded_entities_list = self.find_entities(root_desc["nsubj"][0], positions, cut_clause=True) + grounded_entities_list = self.find_entities(root_desc["nsubj"][0], positions) for node in root_desc["obl"]: if node == unknown_branch: types_list.append(node.form) else: - grounded_entities_list += self.find_entities(node, positions, cut_clause=True) + grounded_entities_list += self.find_entities(node, positions) + + if root_desc_deprels == ["nsubj", "obj", "obj"]: + obj_desc = root_desc["obj"] + qualifier_entities_list = self.find_entities(obj_desc[0], positions) + grounded_entities_list = self.find_entities(obj_desc[1], positions) + + year_constraint = self.find_year_constraint(root) + if root_desc_deprels == ["nmod", "nsubj"] and year_constraint: + if len(year_constraint[0]) == 2: + query_nums.append("24") + elif len(year_constraint[0]) == 1: + query_nums.append("0") if root_desc_deprels == ["obj", "xcomp"]: - grounded_entities_list = self.find_entities(root_desc["xcomp"][0], positions, cut_clause=True) + grounded_entities_list = self.find_entities(root_desc["xcomp"][0], positions) - if root_desc_deprels == ["nsubj", "obj", "obl"] or root_desc_deprels == ["obj", "obl"] and self.wh_leaf: + if (self.wh_leaf and root_desc_deprels in [["nsubj", "obj", "obl"], ["obj", "obl"]]) \ + or (root_desc_deprels in [["nsubj", "obj", "obl"], ["obl", "xcomp"]] + and self.find_year_or_number(root_desc["obl"][0])): found_year_or_number = self.find_year_or_number(root_desc["obl"][0]) + nsubj_ent_list, obj_ent_list = [], [] + if "nsubj" in root_desc_deprels: + nsubj_ent_list = self.find_entities(root_desc["nsubj"][0], positions) + if "obj" in root_desc: + obj_ent_list = self.find_entities(root_desc["obj"][0], positions) + obl_ent_list = self.find_entities(root_desc["obl"][0], positions) + log.debug(f"nsubj_ent: {nsubj_ent_list} --- obj_ent: {obj_ent_list} obl_ent: {obl_ent_list}") if self.wh_leaf: - grounded_entities_list = self.find_entities(root_desc["obl"][0], positions, cut_clause=True) - 
qualifier_entities_list = self.find_entities(root_desc["obj"][0], positions, cut_clause=True) + grounded_entities_list = obl_ent_list + qualifier_entities_list = obj_ent_list + elif not found_year_or_number and nsubj_ent_list and obl_ent_list: + grounded_entities_list = nsubj_ent_list + modifiers_list = obl_ent_list else: - grounded_entities_list = self.find_entities(root_desc["obj"][0], positions, cut_clause=True) - if found_year_or_number: - query_nums.append("0") + grounded_entities_list = obj_ent_list + if found_year_or_number: + query_nums.append("0") + if not grounded_entities_list: + grounded_entities_list = self.find_entities(root, positions) + grounded_entities_list = self.choose_grounded_entity(grounded_entities_list, entities_dict) if clause_node: for node in clause_node.children: if node.deprel == "obj": - grounded_entities_list = self.find_entities(node, positions, cut_clause=False) + grounded_entities_list = self.find_entities(node, positions) if self.find_year_or_number(node): query_nums.append("0") @@ -619,58 +800,57 @@ def build_query(self, root: Node, unknown_branch: Node, root_desc: Dict[str, Lis types_list.append(type_entity) if root_desc_deprels == ["nmod", "nmod"]: - grounded_entities_list = self.find_entities(root_desc["nmod"][0], positions, cut_clause=True) - modifiers_list = self.find_entities(root_desc["nmod"][1], positions, cut_clause=True) + grounded_entities_list = self.find_entities(root_desc["nmod"][0], positions) + modifiers_list = self.find_entities(root_desc["nmod"][1], positions) if root_desc_deprels == ["nmod", "nsubj", "nummod"]: if not self.wh_leaf: - grounded_entities_list = self.find_entities(root_desc["nmod"][0], positions, cut_clause=True) + grounded_entities_list = self.find_entities(root_desc["nmod"][0], positions) found_year_or_number = self.find_year_or_number(root_desc["nummod"][0]) - if temporal_order: + if temporal_order and not query_nums: for deprel in root_desc: for node in root_desc[deprel]: - entities = self.find_entities(node, positions, cut_clause=True) + entities = self.find_entities(node, positions) if entities: grounded_entities_list = entities break if grounded_entities_list: break - if temporal_order in self.first_tokens: - query_nums.append("22") - query_nums.append("23") - if temporal_order in self.last_tokens: - query_nums.append("24") - if temporal_order in self.begin_tokens: - query_nums.append("22") - query_nums.append("25") - if temporal_order in self.end_tokens: - query_nums.append("24") - query_nums.append("26") + if temporal_order in self.first_tokens | self.begin_tokens: + query_nums += ["22"] + if temporal_order in self.last_tokens | self.end_tokens: + query_nums += ["23"] + log.debug(f"query_nums: {query_nums} --- year_constraint: {year_constraint}") if count: - grounded_entities_list = self.find_entities(root, positions, cut_clause=True) + grounded_entities_list = self.find_entities(root, positions) + grounded_entities_list = self.choose_grounded_entity(grounded_entities_list, entities_dict) entities_list = grounded_entities_list + qualifier_entities_list + modifiers_list + types_list = [tp for tp in types_list + if not (len(tp.split()) == 1 and self.nlp(tp)[0].lemma_ in self.date_tokens)] - grounded_entities_length = len(grounded_entities_list) - types_length = len(types_list) - modifiers_length = len(modifiers_list) - qualifiers_length = len(qualifier_entities_list) - if qualifiers_length > 0 or modifiers_length or count: - types_length = 0 + gr_len = len(grounded_entities_list) + types_len = len(types_list) + 
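+ # gr_len/types_len/mod_len/qua_len count the filled slots; together with the
+ # year_or_number, count and order flags they form the syntax_info vector that
+ # is compared against each template's "syntax_structure" to pick candidate
+ # query templates.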
mod_len = len(modifiers_list) + qua_len = len(qualifier_entities_list) + if qua_len or count: + types_len = 0 - if not temporal_order: + if not temporal_order and not query_nums: for num, template in self.template_queries.items(): - if [grounded_entities_length, types_length, modifiers_length, - qualifiers_length, found_year_or_number, count, order] == list( - template["syntax_structure"].values()): + syntax_info = [gr_len, types_len, mod_len, qua_len, found_year_or_number, count, order] + if syntax_info == list(template["syntax_structure"].values()): query_nums.append(num) - - log.debug(f"tree_to_sparql, grounded entities {grounded_entities_list}") - log.debug(f"tree_to_sparql, types {types_list}") - log.debug(f"tree_to_sparql, modifier entities {modifiers_list}") - log.debug(f"tree_to_sparql, qualifier entities {qualifier_entities_list}") - log.debug(f"tree to sparql, query nums {query_nums}") + if mod_len: + syntax_info[1] = 0 + if syntax_info == list(template["syntax_structure"].values()): + query_nums.append(num) + + log.debug(f"tree_to_sparql, grounded entities: {grounded_entities_list} --- types: {types_list} --- " + f"modifier entities: {modifiers_list} --- qualifier entities: {qualifier_entities_list} --- " + f"year_or_number {found_year_or_number} --- count: {count} --- order: {order} --- " + f"query nums: {query_nums}") return query_nums, entities_list, types_list diff --git a/deeppavlov/models/kbqa/type_define.py b/deeppavlov/models/kbqa/type_define.py index 1ccdd9b388..312c325551 100644 --- a/deeppavlov/models/kbqa/type_define.py +++ b/deeppavlov/models/kbqa/type_define.py @@ -57,7 +57,6 @@ def __init__(self, lang: str, types_filename: str, types_sets_filename: str, def __call__(self, questions_batch: List[str], entity_substr_batch: List[List[str]], tags_batch: List[List[str]], types_substr_batch: List[List[str]] = None): - types_sets_batch = [] if types_substr_batch is None: types_substr_batch = [] for question, entity_substr_list in zip(questions_batch, entity_substr_batch): @@ -83,43 +82,19 @@ def __call__(self, questions_batch: List[str], entity_substr_batch: List[List[st break elif token.head.text == type_noun and token.dep_ == "prep": if len(list(token.children)) == 1 \ - and not any([[tok.text for tok in token.children][0] in entity_substr.lower() + and not any([list(token.children)[0].text in entity_substr.lower() for entity_substr in entity_substr_list]): - types_substr += [token.text, [tok.text for tok in token.children][0]] + types_substr += [token.text, list(token.children)[0].text] elif any([word in question for word in self.pronouns]): for token in doc: if token.dep_ == "nsubj" and not any([token.text in entity_substr.lower() for entity_substr in entity_substr_list]): types_substr.append(token.text) - types_substr = [(token, token_pos_dict[token]) for token in types_substr] types_substr = sorted(types_substr, key=lambda x: x[1]) types_substr = " ".join([elem[0] for elem in types_substr]) types_substr_batch.append(types_substr) - for types_substr in types_substr_batch: - types_substr_tokens = types_substr.split() - types_substr_tokens = [tok for tok in types_substr_tokens if tok not in self.stopwords] - if self.lang == "@ru": - types_substr_tokens = [self.nlp(tok)[0].lemma_ for tok in types_substr_tokens] - types_substr_tokens = set(types_substr_tokens) - types_scores = [] - for entity in self.types_dict: - labels, cnt = self.types_dict[entity] - cur_cnts = [] - for label in labels: - label_tokens = label.lower().split() - if len(types_substr_tokens) == 1 and 
len(label_tokens) == 2 and \ - list(types_substr_tokens)[0] == label_tokens[0]: - cur_cnts.append(0.3) - else: - inters = types_substr_tokens.intersection(set(label_tokens)) - cur_cnts.append(len(inters) / max(len(types_substr_tokens), len(label_tokens))) - - types_scores.append([entity, max(cur_cnts), cnt]) - types_scores = sorted(types_scores, key=lambda x: (x[1], x[2]), reverse=True) - cur_types = [elem[0] for elem in types_scores if elem[1] > 0][:self.num_types_to_return] - types_sets_batch.append(cur_types) - + types_sets_batch = [set() for _ in questions_batch] for n, (question, types_sets) in enumerate(zip(questions_batch, types_sets_batch)): question = question.lower() if not types_sets: @@ -128,11 +103,18 @@ def __call__(self, questions_batch: List[str], entity_substr_batch: List[List[st types_sets_batch[n] = self.types_sets["PER"] elif question.startswith("где"): types_sets_batch[n] = self.types_sets["LOC"] + elif any([question.startswith(elem) for elem in ["когда", "в каком году", "в каком месяце"]]): + types_sets_batch[n] = {"date"} + elif len(question.split()) > 1 and (any([question.startswith(elem) for elem in ["кем ", "как"]]) \ + or question.split()[1].startswith("как")): + types_sets_batch[n] = {"not_date"} elif self.lang == "@en": if question.startswith("who"): types_sets_batch[n] = self.types_sets["PER"] elif question.startswith("where"): types_sets_batch[n] = self.types_sets["LOC"] + elif any([question.startswith(elem) for elem in ["when", "what year", "what month"]]): + types_sets_batch[n] = {"date"} new_entity_substr_batch, new_entity_offsets_batch, new_tags_batch = [], [], [] for question, entity_substr_list, tags_list in zip(questions_batch, entity_substr_batch, tags_batch): diff --git a/deeppavlov/models/kbqa/utils.py b/deeppavlov/models/kbqa/utils.py index 1af3cb6912..8dbf2443c2 100644 --- a/deeppavlov/models/kbqa/utils.py +++ b/deeppavlov/models/kbqa/utils.py @@ -12,13 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. 
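+ # Helpers for template-based SPARQL construction: find_query_features() pulls
+ # the answer variable, ORDER BY direction and FILTER(contains(...)) constraints
+ # out of a query template; fill_slots() and fill_query() substitute the
+ # e1/e2, t1 and r1/r2 placeholders with entity, type and relation ids;
+ # preprocess_template_queries() derives each template's "syntax_structure"
+ # signature that is later matched against parsed questions.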
-import re import itertools -from typing import List +import re +from collections import namedtuple +from typing import List, Tuple, Dict, Any + + +def find_query_features(query, qualifier_rels=None, question=None, order_from_query=None): + query = query.lower().replace("select distinct", "select") + answer_ent = re.findall(r"select [\(]?([\S]+) ", query) + order_info_nt = namedtuple("order_info", ["variable", "sorting_order"]) + order_variable = re.findall("order by (asc|desc)\((.*)\)", query) + if order_variable: + if (qualifier_rels and len(qualifier_rels[0][4:]) > 1) or order_from_query: + answers_sorting_order = order_variable[0][0] + else: + answers_sorting_order = order_of_answers_sorting(question) + order_info = order_info_nt(order_variable[0][1], answers_sorting_order) + else: + order_info = order_info_nt(None, None) + filter_from_query = re.findall("contains\((\?\w), (.+?)\)", query) + return answer_ent, order_info, filter_from_query def extract_year(question_tokens: List[str], question: str) -> str: question_patterns = [r'.*\d{1,2}/\d{1,2}/(\d{4}).*', r'.*\d{1,2}-\d{1,2}-(\d{4}).*', r'.*(\d{4})-\d{1,2}-\d{1,2}.*'] + from_to_patterns = [r"from ([\d]{3,4}) to [\d]{3,4}", r"с ([\d]{3,4}) по [\d]{3,4}"] token_patterns = [r'(\d{4})', r'^(\d{4})-.*', r'.*-(\d{4})$'] year = "" for pattern in question_patterns: @@ -27,6 +46,10 @@ def extract_year(question_tokens: List[str], question: str) -> str: year = fnd.group(1) break else: + for pattern in from_to_patterns: + fnd = re.findall(pattern, question) + if fnd: + return fnd[0] for token in question_tokens: for pattern in token_patterns: fnd = re.search(pattern, token) @@ -54,18 +77,20 @@ def extract_number(question_tokens: List[str], question: str) -> str: def order_of_answers_sorting(question: str) -> str: question_lower = question.lower() - max_words = ["maximum", "highest", "max ", "greatest", "most", "longest", "biggest", "deepest"] - + max_words = ["maximum", "highest", "max ", "greatest", "most", "longest", "biggest", "deepest", "завершил", + "закончил", "завершает"] for word in max_words: if word in question_lower: return "desc" - return "asc" def make_combs(entity_ids: List[List[str]], permut: bool) -> List[List[str]]: entity_ids = [[(entity, n) for n, entity in enumerate(entities_list)] for entities_list in entity_ids] entity_ids = list(itertools.product(*entity_ids)) + entity_ids = [comb for comb in entity_ids if not + (all([comb[i][0][0].split("/")[-1] == comb[0][0][0].split("/")[-1] for i in range(len(comb))]) + and not all([comb[i][0][0] == comb[0][0][0] for i in range(len(comb))]))] entity_ids_permut = [] if permut: for comb in entity_ids: @@ -77,30 +102,153 @@ def make_combs(entity_ids: List[List[str]], permut: bool) -> List[List[str]]: return ent_combs -def fill_query(query: List[str], entity_comb: List[str], type_comb: List[str], rel_comb: List[str]) -> List[str]: +def fill_slots(query: str, entity_comb: List[str], type_comb: List[str], rel_comb: List[Tuple[str, float]], + delete_rel_prefix: bool = False) -> str: + for n, entity in enumerate(entity_comb[:-1]): + query = query.replace(f"e{n + 1}", entity) + for n, entity_type in enumerate(type_comb[:-1]): # type_entity + query = query.replace(f"t{n + 1}", entity_type) + for n, (rel, score) in enumerate(rel_comb[:-1]): + if not rel.startswith("?"): + if delete_rel_prefix: + rel = rel.split("/")[-1] + query = query.replace(f"r{n + 1}", rel) + return query + + +def correct_variables(query_triplets: List[str], answer_ent: List[str], query_info: Dict[str, str]): + for i in 
range(len(query_triplets)): + for ent_var in answer_ent: + triplet_elements = query_triplets[i].split() + for j in range(len(triplet_elements)): + if triplet_elements[j] not in ent_var and triplet_elements[j].startswith("?"): + triplet_elements[j] = query_info["mid_var"] + if triplet_elements[j].startswith("?") \ + and triplet_elements[j] not in [query_info["mid_var"], query_info["unk_var"]]: + triplet_elements[j] = query_info["unk_var"] + query_triplets[i] = " ".join(triplet_elements) + query_triplets[i] = query_triplets[i].replace(ent_var, query_info["unk_var"]) + return query_triplets + + +def query_from_triplets(query_triplets: List[str], answer_ent: List[str], query_info: Dict[str, str]) -> str: + filled_query = " . ".join(query_triplets) + if answer_ent and answer_ent[0].lower().startswith("count"): + filled_query = f"SELECT COUNT({query_info['unk_var']}) " + \ + f"WHERE {{ {filled_query}. }}" + else: + filled_query = f"SELECT {query_info['unk_var']} WHERE {{ {filled_query}. }}" + filled_query = filled_query.replace(" ..", ".") + return filled_query + + +def fill_query(query: List[str], entity_comb: List[str], type_comb: List[str], rel_comb: List[Tuple[str, float]], + map_query_str_to_kb) -> List[str]: ''' example of query: ["wd:E1", "p:R1", "?s"] entity_comb: ["Q159"] type_comb: [] rel_comb: ["P17"] + map_query_str_to_kb = [("P0", "http://wd"), + ("P00", "http://wl"), + ("wd:", "http://we/"), + ("wdt:", "http://wpd/"), + (" p:", " http://wp/"), + ("ps:", "http://wps/"), + ("pq:", "http://wpq/")] ''' query = " ".join(query) - map_query_str_to_wikidata = [("P0", "http://wd"), - ("P00", "http://wl"), - ("wd:", "http://we/"), - ("wdt:", "http://wpd/"), - (" p:", " http://wp/"), - ("ps:", "http://wps/"), - ("pq:", "http://wpq/")] - - for query_str, wikidata_str in map_query_str_to_wikidata: + + for query_str, wikidata_str in map_query_str_to_kb: query = query.replace(query_str, wikidata_str) - for n, entity in enumerate(entity_comb[:-1]): - query = query.replace(f"e{n + 1}", entity) - for n, entity_type in enumerate(type_comb[:-1]): # type_entity - query = query.replace(f"t{n + 1}", entity_type) - for n, (rel, score) in enumerate(rel_comb[:-1]): - query = query.replace(f"r{n + 1}", rel) + query = fill_slots(query, entity_comb, type_comb, rel_comb) query = query.replace("http://wpd/P0", "http://wd") query = query.replace("http://wpd/P00", "http://wl") query = query.split(' ') return query + + +def make_sparql_query(query_info: Tuple[List[str], List[str], List[str], Dict[str, Any], Dict[str, Any]], + entities: List[str], rels: List[Tuple[str, float]], types: List[str], + query_info_dict: Dict[str, str]) -> List[str]: + query_triplets, filled_triplets, answer_ent, filter_info, order_info = query_info + query_triplets = [fill_slots(elem, entities, types, rels, delete_rel_prefix=True) for elem in query_triplets] + query_triplets = correct_variables(query_triplets, answer_ent, query_info_dict) + filled_queries = [] + for triplets_p in list(itertools.permutations(query_triplets)): + filled_queries.append(query_from_triplets(triplets_p, answer_ent, query_info_dict)) + return filled_queries + + +def merge_sparql_query(query_info: Tuple[List[str], List[str], Dict[str, Any], Dict[str, Any]], + query_info_dict: Dict[str, str]) -> str: + query_triplets, answer_ent, filter_info, order_info = query_info + query = query_from_triplets(query_triplets, answer_ent, query_info_dict) + return query + + +def preprocess_template_queries(template_queries: Dict[str, Any], kb_prefixes: Dict[str, str]) -> Dict[str, 
Any]: + for template_num in template_queries: + template = template_queries[template_num] + query = template["query_template"] + q_triplets = re.findall("{[ ]?(.*?)[ ]?}", query)[0].split(' . ') + q_triplets = [triplet.split(' ')[:3] for triplet in q_triplets] + if not "rel_types" in template: + template["rel_types"] = ["direct" for _ in q_triplets] + rel_types = template["rel_types"] + rel_dirs, n_hops, entities, types, gr_ent, mod_ent, q_ent = [], [], set(), set(), set(), set(), set() + + for n, (triplet, rel_type) in enumerate(zip(q_triplets, rel_types)): + if not triplet[1].startswith(kb_prefixes["type_rel"]): + if triplet[2].startswith("?"): + rel_dirs.append("forw") + else: + rel_dirs.append("backw") + for ind in [0, 2]: + if triplet[ind].startswith(kb_prefixes["entity"]): + entities.add(triplet[ind]) + elif triplet[ind].startswith(kb_prefixes["type"]): + types.add(triplet[ind]) + if rel_type in {"qualifier", "statement"}: + if triplet[2].startswith(kb_prefixes["entity"]): + q_ent.add(triplet[2]) + else: + if triplet[0].startswith(kb_prefixes["entity"]): + gr_ent.add(triplet[0]) + elif triplet[2].startswith(kb_prefixes["entity"]): + mod_ent.add(triplet[2]) + if triplet[1].startswith(kb_prefixes["rel"]) and triplet[0].startswith("?") and triplet[2].startswith("?"): + n_hops.append("2-hop") + elif n == 0 and len(q_triplets) == 2 and q_triplets[1][1].startswith(kb_prefixes["rel"]) \ + and q_triplets[1][0].startswith("?") and q_triplets[1][2].startswith("?"): + n_hops.append("1-of-2-hop") + else: + n_hops.append("1-hop") + syntax_structure = {"gr_ent": len(gr_ent), "types": len(types), "mod_ent": len(mod_ent), + "q_ent": len(q_ent), "year_or_number": False, "count": False, "order": False} + if "filter" in query.lower(): + syntax_structure["year_or_number"] = True + if "order" in query.lower(): + syntax_structure["order"] = True + if "count" in query.lower(): + syntax_structure["count"] = True + if not "query_sequence" in template: + template["query_sequence"] = list(range(1, len(q_triplets) + 1)) + template["rel_dirs"] = rel_dirs + template["n_hops"] = n_hops + template["entities_and_types_num"] = [len(entities), len(types)] + if entities: + entities_str = '_'.join([str(num) for num in list(range(1, len(entities) + 1))]) + else: + entities_str = "0" + if types: + types_str = '_'.join([str(num) for num in list(range(1, len(types) + 1))]) + else: + types_str = "0" + template["entities_and_types_select"] = f"{entities_str} {types_str}" + template["syntax_structure"] = syntax_structure + if "return_if_found" not in template: + template["return_if_found"] = False + if "priority" not in template: + template["priority"] = 1 + template_queries[template_num] = template + return template_queries diff --git a/deeppavlov/models/kbqa/wiki_parser.py b/deeppavlov/models/kbqa/wiki_parser.py index dfc4b3b7d1..fa31afb3e3 100644 --- a/deeppavlov/models/kbqa/wiki_parser.py +++ b/deeppavlov/models/kbqa/wiki_parser.py @@ -14,14 +14,14 @@ import datetime import re +from collections import namedtuple from logging import getLogger from typing import List, Tuple, Dict, Any, Union -from collections import namedtuple from hdt import HDTDocument -from deeppavlov.core.common.file import load_pickle from deeppavlov.core.commands.utils import expand_path +from deeppavlov.core.common.file import load_pickle, read_json from deeppavlov.core.common.registry import register log = getLogger(__name__) @@ -34,6 +34,7 @@ class WikiParser: def __init__(self, wiki_filename: str, file_format: str = "hdt", prefixes: Dict[str, 
Union[str, Dict[str, str]]] = None, + rel_q2name_filename: str = None, max_comb_num: int = 1e6, lang: str = "@en", **kwargs) -> None: """ @@ -55,7 +56,8 @@ def __init__(self, wiki_filename: str, "direct": "http://wpd", "no_type": "http://wp", "statement": "http://wps", - "qualifier": "http://wpq" + "qualifier": "http://wpq", + "type": "http://wpd/P31" }, "statement": "http://ws" } @@ -69,9 +71,19 @@ def __init__(self, wiki_filename: str, self.parsed_document = {} else: raise ValueError("Unsupported file format") + self.used_rels = set() + self.rel_q2name = dict() + if rel_q2name_filename: + if rel_q2name_filename.endswith("json"): + self.rel_q2name = read_json(str(expand_path(rel_q2name_filename))) + elif rel_q2name_filename.endswith("pickle"): + self.rel_q2name = load_pickle(str(expand_path(rel_q2name_filename))) + else: + raise ValueError(f"Unsupported file format: {rel_q2name_filename}") self.max_comb_num = max_comb_num self.lang = lang + self.replace_tokens = [('"', ''), (self.lang, " "), ('$', ' '), (' ', ' ')] def __call__(self, parser_info_list: List[str], queries_list: List[Any]) -> List[Any]: wiki_parser_output = self.execute_queries_list(parser_info_list, queries_list) @@ -82,22 +94,31 @@ def execute_queries_list(self, parser_info_list: List[str], queries_list: List[A query_answer_types = [] for parser_info, query in zip(parser_info_list, queries_list): if parser_info == "query_execute": - candidate_output = [] + answers, found_rels, found_combs = [], [], [] try: - what_return, query_seq, filter_info, order_info, answer_types, rel_types, return_if_found = query + what_return, rels_from_query, query_seq, filter_info, order_info, answer_types, rel_types, \ + return_if_found = query if answer_types: query_answer_types = answer_types - candidate_output = self.execute(what_return, query_seq, filter_info, order_info, - query_answer_types, rel_types) - except: + answers, found_rels, found_combs = \ + self.execute(what_return, rels_from_query, query_seq, filter_info, order_info, + query_answer_types, rel_types) + except ValueError: log.warning("Wrong arguments are passed to wiki_parser") - wiki_parser_output.append(candidate_output) + wiki_parser_output.append([answers, found_rels, found_combs]) elif parser_info == "find_rels": rels = [] try: rels = self.find_rels(*query) except: log.warning("Wrong arguments are passed to wiki_parser") + wiki_parser_output.append(rels) + elif parser_info == "find_rels_2hop": + rels = [] + try: + rels = self.find_rels_2hop(*query) + except ValueError: + log.warning("Wrong arguments are passed to wiki_parser") wiki_parser_output += rels elif parser_info == "find_object": objects = [] @@ -127,6 +148,13 @@ def execute_queries_list(self, parser_info_list: List[str], queries_list: List[A except: log.warning("Wrong arguments are passed to wiki_parser") wiki_parser_output.append(types) + elif parser_info == "fill_triplets": + filled_triplets = [] + try: + filled_triplets = self.fill_triplets(*query) + except ValueError: + log.warning("Wrong arguments are passed to wiki_parser") + wiki_parser_output.append(filled_triplets) elif parser_info == "find_triplets": if self.file_format == "hdt": triplets = [] @@ -171,11 +199,12 @@ def execute_queries_list(self, parser_info_list: List[str], queries_list: List[A return wiki_parser_output def execute(self, what_return: List[str], + rels_from_query: List[str], query_seq: List[List[str]], filter_info: List[Tuple[str]] = None, order_info: namedtuple = None, answer_types: List[str] = None, - rel_types: List[str] = None) -> 
List[List[str]]: + rel_types: List[str] = None): """ Let us consider an example of the question "What is the deepest lake in Russia?" with the corresponding SPARQL query @@ -190,22 +219,22 @@ def execute(self, what_return: List[str], order_info: order_info(variable='?obj', sorting_order='asc') """ extended_combs = [] - combs = [] + answers, found_rels, found_combs = [], [], [] for n, (query, rel_type) in enumerate(zip(query_seq, rel_types)): unknown_elem_positions = [(pos, elem) for pos, elem in enumerate(query) if elem.startswith('?')] """ n = 0, query = ["?ent", "http://www.wikidata.org/prop/direct/P17", - "http://www.wikidata.org/entity/Q159"] + "http://www.wikidata.org/entity/Q159"] unknown_elem_positions = ["?ent"] n = 1, query = ["?ent", "http://www.wikidata.org/prop/direct/P31", - "http://www.wikidata.org/entity/Q23397"] + "http://www.wikidata.org/entity/Q23397"] unknown_elem_positions = [(0, "?ent")] n = 2, query = ["?ent", "http://www.wikidata.org/prop/direct/P4511", "?obj"] unknown_elem_positions = [(0, "?ent"), (2, "?obj")] """ if n == 0: - combs = self.search(query, unknown_elem_positions, rel_type) + combs, triplets = self.search(query, unknown_elem_positions, rel_type) # combs = [{"?ent": "http://www.wikidata.org/entity/Q5513"}, ...] else: if combs: @@ -230,20 +259,22 @@ def execute(self, what_return: List[str], "http://www.wikidata.org/entity/Q23397"], ...] extended_combs = [{"?ent": "http://www.wikidata.org/entity/Q5513"}, ...] """ - known_values = [comb[known_elem] for known_elem in known_elements] - for known_elem, known_value in zip(known_elements, known_values): - filled_query = [elem.replace(known_elem, known_value) for elem in query] - new_combs = self.search(filled_query, unknown_elem_positions, rel_type) - for new_comb in new_combs: - extended_combs.append({**comb, **new_comb}) + if comb: + known_values = [comb[known_elem] for known_elem in known_elements] + for known_elem, known_value in zip(known_elements, known_values): + filled_query = [elem.replace(known_elem, known_value) for elem in query] + new_combs, triplets = self.search(filled_query, unknown_elem_positions, rel_type) + for new_comb in new_combs: + extended_combs.append(self.merge_combs(comb, new_comb)) else: - new_combs = self.search(query, unknown_elem_positions, rel_type) + new_combs, triplets = self.search(query, unknown_elem_positions, rel_type) for comb in combs: for new_comb in new_combs: - extended_combs.append({**comb, **new_comb}) + extended_combs.append(self.merge_combs(comb, new_comb)) combs = extended_combs - if combs: + is_boolean = self.define_is_boolean(query_seq) + if combs or is_boolean: if filter_info: for filter_elem, filter_value in filter_info: if filter_value == "qualifier": @@ -253,48 +284,97 @@ def execute(self, what_return: List[str], if order_info and not isinstance(order_info, list) and order_info.variable is not None: reverse = True if order_info.sorting_order == "desc" else False sort_elem = order_info.variable - for i in range(len(combs)): - value_str = combs[i][sort_elem].split('^^')[0].strip('"') - fnd = re.findall(r"[\d]{3,4}-[\d]{1,2}-[\d]{1,2}", value_str) - if fnd: - combs[i][sort_elem] = fnd[0] - else: - combs[i][sort_elem] = float(value_str) - combs = sorted(combs, key=lambda x: x[sort_elem], reverse=reverse) - combs = [combs[0]] - - if what_return[-1].startswith("count"): - combs = [[combs[0][key] for key in what_return[:-1]] + [len(combs)]] + if combs and "?p" in combs[0]: + rel_combs = {} + for comb in combs: + if comb["?p"] not in rel_combs: + 
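+ # new relation binding: open a bucket for it, so that sorting and selection
+ # of the best candidate are done per relation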
rel_combs[comb["?p"]] = [] + rel_combs[comb["?p"]].append(comb) + rel_combs_list = rel_combs.values() + else: + rel_combs_list = [combs] + new_rel_combs_list = [] + for rel_combs in rel_combs_list: + new_rel_combs = [] + for rel_comb in rel_combs: + value_str = rel_comb[sort_elem].split('^^')[0].strip('"+') + fnd_date = re.findall(r"[\d]{3,4}-[\d]{1,2}-[\d]{1,2}", value_str) + fnd_num = re.findall(r"([\d]+)\.([\d]+)", value_str) + if fnd_date: + rel_comb[sort_elem] = fnd_date[0] + elif fnd_num or value_str.isdigit(): + rel_comb[sort_elem] = float(value_str) + new_rel_combs.append(rel_comb) + new_rel_combs = [(elem, n) for n, elem in enumerate(new_rel_combs)] + new_rel_combs = sorted(new_rel_combs, key=lambda x: (x[0][sort_elem], x[1]), reverse=reverse) + new_rel_combs = [elem[0] for elem in new_rel_combs] + new_rel_combs_list.append(new_rel_combs) + combs = [new_rel_combs[0] for new_rel_combs in new_rel_combs_list] + + if what_return and what_return[-1].startswith("count"): + answers = [[len(combs)]] else: - combs = [[elem[key] for key in what_return] for elem in combs] + answers = [[elem[key] for key in what_return if key in elem] for elem in combs] if answer_types: - if answer_types == ["date"]: - combs = [[entity for entity in comb - if re.findall(r"[\d]{3,4}-[\d]{1,2}-[\d]{1,2}", entity)] for comb in combs] + if list(answer_types) == ["date"]: + answers = [[entity for entity in answer + if re.findall(r"[\d]{3,4}-[\d]{1,2}-[\d]{1,2}", entity)] for answer in answers] + elif list(answer_types) == ["not_date"]: + answers = [[entity for entity in answer + if not re.findall(r"[\d]{3,4}-[\d]{1,2}-[\d]{1,2}", entity)] for answer in answers] else: answer_types = set(answer_types) - combs = [[entity for entity in comb - if answer_types.intersection(self.find_types(entity))] for comb in combs] - combs = [comb for comb in combs if any([entity for entity in comb])] - - return combs - - def search(self, query: List[str], unknown_elem_positions: List[Tuple[int, str]], rel_type) -> List[Dict[str, str]]: + answers = [[entity for entity in answer + if answer_types.intersection(self.find_types(entity))] for answer in answers] + if is_boolean: + answers = [["Yes" if len(triplets) > 0 else "No"]] + found_rels = [[elem[key] for key in rels_from_query if key in elem] for elem in combs] + ans_rels_combs = [(answer, rel, comb) for answer, rel, comb in zip(answers, found_rels, combs) + if any([entity for entity in answer])] + answers = [elem[0] for elem in ans_rels_combs] + found_rels = [elem[1] for elem in ans_rels_combs] + found_combs = [elem[2] for elem in ans_rels_combs] + + return answers, found_rels, found_combs + + @staticmethod + def define_is_boolean(query_hdt_seq): + return len(query_hdt_seq) == 1 and all([not query_hdt_seq[0][i].startswith("?") for i in [0, 2]]) + + @staticmethod + def merge_combs(comb1, comb2): + new_comb = {} + for key in comb1: + if (key in comb2 and comb1[key] == comb2[key]) or key not in comb2: + new_comb[key] = comb1[key] + for key in comb2: + if (key in comb1 and comb2[key] == comb1[key]) or key not in comb1: + new_comb[key] = comb2[key] + return new_comb + + def search(self, query: List[str], unknown_elem_positions: List[Tuple[int, str]], rel_type): query = list(map(lambda elem: "" if elem.startswith('?') else elem, query)) subj, rel, obj = query if self.file_format == "hdt": combs = [] triplets, cnt = self.document.search_triples(subj, rel, obj) if cnt < self.max_comb_num: + triplets = list(triplets) if rel == self.prefixes["description"] or rel == self.prefixes["label"]: 
triplets = [triplet for triplet in triplets if triplet[2].endswith(self.lang)] combs = [{elem: triplet[pos] for pos, elem in unknown_elem_positions} for triplet in triplets] else: - combs = [{elem: triplet[pos] for pos, elem in unknown_elem_positions} for triplet in triplets - if triplet[1].startswith(self.prefixes["rels"][rel_type])] + if isinstance(self.prefixes["rels"][rel_type], str): + combs = [{elem: triplet[pos] for pos, elem in unknown_elem_positions} for triplet in triplets + if (triplet[1].startswith(self.prefixes["rels"][rel_type]) + or triplet[1].startswith(self.prefixes["rels"]["type"]))] + else: + combs = [{elem: triplet[pos] for pos, elem in unknown_elem_positions} for triplet in triplets + if (any(triplet[1].startswith(tp) for tp in self.prefixes["rels"][rel_type]) + or triplet[1].startswith(self.prefixes["rels"]["type"]))] else: - log.debug("max comb num exceede") + log.debug("max comb num exceeds") else: triplets = [] if subj: @@ -311,9 +391,9 @@ def search(self, query: List[str], unknown_elem_positions: List[Tuple[int, str]] triplets = [triplet for triplet in triplets if triplet[1] == rel] combs = [{elem: triplet[pos] for pos, elem in unknown_elem_positions} for triplet in triplets] - return combs + return combs, triplets - def find_label(self, entity: str, question: str) -> str: + def find_label(self, entity: str, question: str = "") -> str: entity = str(entity).replace('"', '') if self.file_format == "hdt": if entity.startswith("Q") or entity.startswith("P"): @@ -327,11 +407,10 @@ def find_label(self, entity: str, question: str) -> str: # '"Lake Baikal"@en'], ...] for label in labels: if label[2].endswith(self.lang): - found_label = label[2].strip(self.lang).replace('"', '').replace('$', ' ').replace(' ', ' ') - return found_label - for label in labels: - if label[2].endswith("@en"): - found_label = label[2].strip("@en").replace('"', '').replace('$', ' ').replace(' ', ' ') + found_label = label[2].strip(self.lang) + for old_tok, new_tok in self.replace_tokens: + found_label = found_label.replace(old_tok, new_tok) + found_label = found_label.strip() return found_label elif entity.endswith(self.lang): @@ -349,11 +428,17 @@ def find_label(self, entity: str, question: str) -> str: for token in ["T00:00:00Z", "+"]: entity = entity.replace(token, '') entity = self.format_date(entity, question).replace('$', '') + return entity + elif re.findall(r"[\d]{3,4}-[\d]{2}-[\d]{2}", entity): + entity = self.format_date(entity, question).replace('$', '') + return entity + + elif entity in ["Yes", "No"]: return entity elif entity.isdigit(): - entity = str(entity).replace('.', ',') + entity = entity.replace('.', ',') return entity if self.file_format == "pickle": @@ -384,7 +469,7 @@ def format_date(self, entity, question): entity = year elif "в каком месяце" in question.lower(): entity = month - elif day != "00": + elif day not in {"00", "0"}: date = datetime.datetime.strptime(f"{year}-{month}-{day}", "%Y-%m-%d") entity = date.strftime("%d %B %Y") else: @@ -403,7 +488,7 @@ def find_alias(self, entity: str) -> List[str]: aliases = [label[2].strip(self.lang).strip('"') for label in labels if label[2].endswith(self.lang)] return aliases - def find_rels(self, entity: str, direction: str, rel_type: str = "no_type", save: bool = False) -> List[str]: + def find_rels(self, entity: str, direction: str, rel_type: str = "no_type") -> List[str]: rels = [] if self.file_format == "hdt": if not rel_type: @@ -413,14 +498,40 @@ def find_rels(self, entity: str, direction: str, rel_type: str = "no_type", 
save else: query = ["", "", f"{self.prefixes['entity']}/{entity}"] triplets, c = self.document.search_triples(*query) - - start_str = f"{self.prefixes['rels'][rel_type]}/P" - rels = {triplet[1] for triplet in triplets if triplet[1].startswith(start_str)} + triplets = list(triplets) + if isinstance(self.prefixes['rels'][rel_type], str): + start_str = f"{self.prefixes['rels'][rel_type]}/P" + rels = {triplet[1] for triplet in triplets if triplet[1].startswith(start_str)} + else: + rels = {triplet[1] for triplet in triplets + if any([triplet[1].startswith(tp) for tp in self.prefixes['rels'][rel_type]])} rels = list(rels) - if self.file_format == "pickle": - triplets = self.document.get(entity, {}).get(direction, []) - triplets = self.uncompress(triplets) - rels = [triplet[0] for triplet in triplets if triplet[0].startswith("P")] + if self.used_rels: + rels = [rel for rel in rels if rel.split("/")[-1] in self.used_rels] + return rels + + def find_rels_2hop(self, entity_ids, rels_1hop): + rels = [] + for entity_id in entity_ids: + for rel_1hop in rels_1hop: + triplets, cnt = self.document.search_triples(f"{self.prefixes['entity']}/{entity_id}", rel_1hop, "") + triplets = [triplet for triplet in triplets if triplet[2].startswith(self.prefixes['entity'])] + objects_1hop = [triplet[2].split("/")[-1] for triplet in triplets] + triplets, cnt = self.document.search_triples("", rel_1hop, f"{self.prefixes['entity']}/{entity_id}") + triplets = [triplet for triplet in triplets if triplet[0].startswith(self.prefixes['entity'])] + objects_1hop += [triplet[0].split("/")[-1] for triplet in triplets] + for object_1hop in objects_1hop[:5]: + tr_2hop, cnt = self.document.search_triples(f"{self.prefixes['entity']}/{object_1hop}", "", "") + rels_2hop = [elem[1] for elem in tr_2hop if elem[1] != rel_1hop] + if self.used_rels: + rels_2hop = [elem for elem in rels_2hop if elem.split("/")[-1] in self.used_rels] + rels += rels_2hop + tr_2hop, cnt = self.document.search_triples("", "", f"{self.prefixes['entity']}/{object_1hop}") + rels_2hop = [elem[1] for elem in tr_2hop if elem[1] != rel_1hop] + if self.used_rels: + rels_2hop = [elem for elem in rels_2hop if elem.split("/")[-1] in self.used_rels] + rels += rels_2hop + rels = list(set(rels)) return rels def find_object(self, entity: str, rel: str, direction: str) -> List[str]: @@ -477,9 +588,10 @@ def find_types(self, entity: str): entity = f"{self.prefixes['entity']}/{entity}" tr, c = self.document.search_triples(entity, f"{self.prefixes['rels']['direct']}/P31", "") types = [triplet[2].split('/')[-1] for triplet in tr] - if "Q5" in types: - tr, c = self.document.search_triples(entity, f"{self.prefixes['rels']['direct']}/P106", "") + for rel in ["P106", "P21"]: + tr, c = self.document.search_triples(entity, f"{self.prefixes['rels']['direct']}/{rel}", "") types += [triplet[2].split('/')[-1] for triplet in tr] + if self.file_format == "pickle": entity = entity.split('/')[-1] triplets = self.document.get(entity, {}).get("forw", []) @@ -532,3 +644,45 @@ def find_triplets(self, subj: str, direction: str) -> Tuple[str, List[List[str]] triplets = self.document.get(subj, {}).get(direction, []) triplets = self.uncompress(triplets) return subj, triplets + + def fill_triplets(self, init_triplets, what_to_return, comb): + filled_triplets = [] + for n, (subj, rel, obj) in enumerate(init_triplets): + if "statement" in self.prefixes and subj.startswith("?") \ + and comb.get(subj, "").startswith(self.prefixes["statement"]) and not rel.startswith("?") \ + and (obj == what_to_return[0] 
or re.findall(r"[\d]{3,4}", comb.get(what_to_return[0], ""))): + continue + else: + if "statement" in self.prefixes and subj.startswith("?") \ + and str(comb.get(subj, "")).startswith(self.prefixes["statement"]): + if not comb.get(what_to_return[0], "").startswith("http") \ + and re.findall(r"[\d]{3,4}", comb.get(what_to_return[0], "")): + subj = init_triplets[1][2] + else: + subj = what_to_return[0] + if "statement" in self.prefixes and obj.startswith("?") \ + and str(comb.get(obj, "")).startswith(self.prefixes["statement"]): + if not str(comb.get(what_to_return[0], "")).startswith("http") \ + and re.findall(r"[\d]{3,4}", str(comb.get(what_to_return[0], ""))): + obj = init_triplets[1][2] + else: + obj = what_to_return[0] + subj, obj = str(subj), str(obj) + if subj.startswith("?"): + subj = comb.get(subj, "") + if obj.startswith("?"): + obj = comb.get(obj, "") + if rel.startswith("?"): + rel = comb.get(rel, "") + subj_label = self.find_label(subj) + obj_label = self.find_label(obj) + if rel in self.rel_q2name: + rel_label = self.rel_q2name[rel] + elif rel.split("/")[-1] in self.rel_q2name: + rel_label = self.rel_q2name[rel.split("/")[-1]] + else: + rel_label = self.find_label(rel) + if isinstance(rel_label, list) and rel_label: + rel_label = rel_label[0] + filled_triplets.append([subj_label, rel_label, obj_label]) + return filled_triplets diff --git a/deeppavlov/models/preprocessors/str_utf8_encoder.py b/deeppavlov/models/preprocessors/str_utf8_encoder.py index 8826380ebb..6647f68ed0 100644 --- a/deeppavlov/models/preprocessors/str_utf8_encoder.py +++ b/deeppavlov/models/preprocessors/str_utf8_encoder.py @@ -20,7 +20,6 @@ from typing import Union, List, Tuple import numpy as np -from overrides import overrides from deeppavlov.core.common.errors import ConfigError from deeppavlov.core.common.registry import register @@ -130,7 +129,6 @@ def __call__(self, batch: Union[List[str], Tuple[str]]) -> StrUTF8EncoderInfo: raise RuntimeError(f'The objects passed to the reverser are not list or tuple of str! 
' f' But they are {type(batch)}.') - @overrides def load(self) -> None: if self.load_path: if self.load_path.is_file(): @@ -144,14 +142,12 @@ def load(self) -> None: else: raise ConfigError(f"`load_path` for {self} is not provided!") - @overrides def save(self) -> None: log.info(f"[saving vocabulary to {self.save_path}]") with self.save_path.open('wt', encoding='utf8') as f: for token in self._word_char_ids.keys(): f.write('{}\n'.format(token)) - @overrides def fit(self, *args) -> None: words = chain(*args) # filter(None, <>) -- to filter empty words diff --git a/deeppavlov/models/preprocessors/torch_transformers_preprocessor.py b/deeppavlov/models/preprocessors/torch_transformers_preprocessor.py index dc6e4dad28..effd4388d1 100644 --- a/deeppavlov/models/preprocessors/torch_transformers_preprocessor.py +++ b/deeppavlov/models/preprocessors/torch_transformers_preprocessor.py @@ -368,15 +368,11 @@ class RelRankingPreprocessor(Component): def __init__(self, vocab_file: str, - add_special_tokens: List[str], do_lower_case: bool = True, max_seq_length: int = 512, **kwargs) -> None: self.max_seq_length = max_seq_length self.tokenizer = AutoTokenizer.from_pretrained(vocab_file, do_lower_case=do_lower_case) - self.add_special_tokens = add_special_tokens - special_tokens_dict = {'additional_special_tokens': add_special_tokens} - self.tokenizer.add_special_tokens(special_tokens_dict) def __call__(self, questions_batch: List[List[str]], rels_batch: List[List[str]] = None) -> Dict[str, torch.tensor]: """Tokenize questions and relations @@ -384,46 +380,97 @@ def __call__(self, questions_batch: List[List[str]], rels_batch: List[List[str]] Args: questions_batch: list of texts, rels_batch: list of relations list - Returns: batch of :class:`transformers.data.processors.utils.InputFeatures` with subtokens, subtoken ids, \ subtoken mask, segment mask, or tuple of batch of InputFeatures and Batch of subtokens """ - lengths = [] + lengths, proc_rels_batch = [], [] for question, rels_list in zip(questions_batch, rels_batch): if isinstance(rels_list, list): - rels_str = self.add_special_tokens[2].join(rels_list) + rels_str = " ".join(rels_list) else: rels_str = rels_list - text_input = f"{self.add_special_tokens[0]} {question} {self.add_special_tokens[1]} {rels_str}" - encoding = self.tokenizer.encode_plus(text=text_input, + encoding = self.tokenizer.encode_plus(text=question, text_pair=rels_str, return_attention_mask=True, add_special_tokens=True, truncation=True) lengths.append(len(encoding["input_ids"])) + proc_rels_batch.append(rels_str) max_len = max(lengths) - input_ids_batch = [] - attention_mask_batch = [] - token_type_ids_batch = [] - for question, rels_list in zip(questions_batch, rels_batch): - if isinstance(rels_list, list): - rels_str = self.add_special_tokens[2].join(rels_list) - else: - rels_str = rels_list - text_input = f"{self.add_special_tokens[0]} {question} {self.add_special_tokens[1]} {rels_str}" - encoding = self.tokenizer.encode_plus(text=text_input, - truncation = True, max_length=max_len, - pad_to_max_length=True, return_attention_mask = True) + input_ids_batch, attention_mask_batch, token_type_ids_batch = [], [], [] + for question, rels_list in zip(questions_batch, proc_rels_batch): + encoding = self.tokenizer.encode_plus(text=question, text_pair=rels_list, + truncation=True, max_length=max_len, + pad_to_max_length=True, return_attention_mask=True) input_ids_batch.append(encoding["input_ids"]) attention_mask_batch.append(encoding["attention_mask"]) if "token_type_ids" in encoding: 
token_type_ids_batch.append(encoding["token_type_ids"]) else: token_type_ids_batch.append([0]) - input_features = {"input_ids": torch.LongTensor(input_ids_batch), "attention_mask": torch.LongTensor(attention_mask_batch), "token_type_ids": torch.LongTensor(token_type_ids_batch)} + return input_features + +@register('path_ranking_preprocessor') +class PathRankingPreprocessor(Component): + def __init__(self, + vocab_file: str, + additional_special_tokens: List[str] = None, + do_lower_case: bool = True, + max_seq_length: int = 67, + **kwargs) -> None: + self.max_seq_length = max_seq_length + self.tokenizer = AutoTokenizer.from_pretrained(vocab_file, do_lower_case=do_lower_case) + self.additional_special_tokens = additional_special_tokens + if self.additional_special_tokens: + self.tokenizer.add_special_tokens({'additional_special_tokens': additional_special_tokens}) + + def __call__(self, questions_batch: List[str], rels_batch: List[List[List[str]]]): + lengths, proc_rels_batch = [], [] + for question, rels_list in zip(questions_batch, rels_batch): + proc_rels_list = [] + for rels in rels_list: + if isinstance(rels, str): + rels = [rels] + rels_str = "" + if len(rels) == 1: + if self.additional_special_tokens: + rels_str = f" {rels[0]} " + else: + rels_str = rels[0] + elif len(rels) == 2: + if rels[0] == rels[1]: + rels_str = f" {rels[0]} " + else: + rels_str = f" {rels[0]} {rels[1]} " + encoding = self.tokenizer.encode_plus(text=question, text_pair=rels_str, + return_attention_mask=True, add_special_tokens=True, + truncation=True) + lengths.append(len(encoding["input_ids"])) + proc_rels_list.append(rels_str) + proc_rels_batch.append(proc_rels_list) + + max_len = min(max(lengths), self.max_seq_length) + input_ids_batch, attention_mask_batch, token_type_ids_batch = [], [], [] + for question, rels_list in zip(questions_batch, proc_rels_batch): + input_ids_list, attention_mask_list, token_type_ids_list = [], [], [] + for rels_str in rels_list: + encoding = self.tokenizer.encode_plus(text=question, text_pair=rels_str, + truncation=True, max_length=max_len, add_special_tokens=True, + pad_to_max_length=True, return_attention_mask=True) + input_ids_list.append(encoding["input_ids"]) + attention_mask_list.append(encoding["attention_mask"]) + if "token_type_ids" in encoding: + token_type_ids_list.append(encoding["token_type_ids"]) + else: + token_type_ids_list.append([0]) + input_ids_batch.append(input_ids_list) + attention_mask_batch.append(attention_mask_list) + token_type_ids_batch.append(token_type_ids_list) + input_features = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch, + "token_type_ids": token_type_ids_batch} return input_features diff --git a/deeppavlov/models/torch_bert/multitask_transformer.py b/deeppavlov/models/torch_bert/multitask_transformer.py index cead7a99c5..e7a8461237 100644 --- a/deeppavlov/models/torch_bert/multitask_transformer.py +++ b/deeppavlov/models/torch_bert/multitask_transformer.py @@ -20,7 +20,6 @@ import numpy as np import torch import torch.nn as nn -from overrides import overrides from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss from transformers import AutoConfig, AutoModel @@ -342,7 +341,6 @@ def __init__( def _reset_cache(self): self.preds_cache = {index_: None for index_ in self.types_to_cache if index_ != -1} - @overrides def init_from_opt(self) -> None: """ Initialize from scratch `self.model` with the architecture built @@ -401,7 +399,6 @@ def get_decay_params(model): return [ torch.optim.lr_scheduler, 
self.lr_scheduler_name )(self.optimizer, **self.lr_scheduler_parameters) - @overrides def load(self, fname: Optional[str] = None) -> None: """ Loads weights. diff --git a/deeppavlov/models/torch_bert/torch_bert_ranker.py b/deeppavlov/models/torch_bert/torch_bert_ranker.py index 261e4bd03e..e574fe4948 100644 --- a/deeppavlov/models/torch_bert/torch_bert_ranker.py +++ b/deeppavlov/models/torch_bert/torch_bert_ranker.py @@ -18,7 +18,6 @@ import numpy as np import torch -from overrides import overrides from transformers import AutoModelForSequenceClassification, AutoConfig from transformers.data.processors.utils import InputFeatures @@ -155,7 +154,6 @@ def __call__(self, features_li: List[List[InputFeatures]]) -> Union[List[int], L return predictions - @overrides def load(self, fname=None): if fname is not None: self.load_path = fname diff --git a/deeppavlov/models/torch_bert/torch_transformers_classifier.py b/deeppavlov/models/torch_bert/torch_transformers_classifier.py index d2449dafc2..80036eb682 100644 --- a/deeppavlov/models/torch_bert/torch_transformers_classifier.py +++ b/deeppavlov/models/torch_bert/torch_transformers_classifier.py @@ -12,20 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re from logging import getLogger from pathlib import Path from typing import List, Dict, Union, Optional, Tuple import numpy as np import torch -from overrides import overrides from torch.nn import BCEWithLogitsLoss from transformers import AutoModelForSequenceClassification, AutoConfig, AutoModel, AutoTokenizer from transformers.modeling_outputs import SequenceClassifierOutput -from deeppavlov.core.common.errors import ConfigError from deeppavlov.core.commands.utils import expand_path +from deeppavlov.core.common.errors import ConfigError from deeppavlov.core.common.registry import register from deeppavlov.core.models.torch_model import TorchModel @@ -196,7 +194,6 @@ def is_data_parallel(self) -> bool: return isinstance(self.model, torch.nn.DataParallel) # TODO this method requires massive refactoring - @overrides def load(self, fname=None): if fname is not None: self.load_path = fname diff --git a/deeppavlov/models/torch_bert/torch_transformers_multiplechoice.py b/deeppavlov/models/torch_bert/torch_transformers_multiplechoice.py index c989715d10..078542e7df 100644 --- a/deeppavlov/models/torch_bert/torch_transformers_multiplechoice.py +++ b/deeppavlov/models/torch_bert/torch_transformers_multiplechoice.py @@ -18,7 +18,6 @@ import numpy as np import torch -from overrides import overrides from transformers import AutoModelForMultipleChoice, AutoConfig from deeppavlov.core.common.errors import ConfigError @@ -158,7 +157,6 @@ def __call__(self, features: Dict[str, torch.tensor]) -> Union[List[int], List[L return pred - @overrides def load(self, fname = None): if fname is not None: self.load_path = fname diff --git a/deeppavlov/models/torch_bert/torch_transformers_nll_ranking.py b/deeppavlov/models/torch_bert/torch_transformers_nll_ranking.py new file mode 100644 index 0000000000..4e4ae1940e --- /dev/null +++ b/deeppavlov/models/torch_bert/torch_transformers_nll_ranking.py @@ -0,0 +1,221 @@ +# Copyright 2017 Neural Networks and Deep Learning lab, MIPT +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from logging import getLogger +from pathlib import Path +from typing import List, Optional, Dict, Tuple, Union, Any + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from transformers import AutoConfig, AutoModel, AutoTokenizer + +from deeppavlov.core.commands.utils import expand_path +from deeppavlov.core.common.errors import ConfigError +from deeppavlov.core.common.registry import register +from deeppavlov.core.models.torch_model import TorchModel + +log = getLogger(__name__) + + +@register('torch_transformers_nll_ranker') +class TorchTransformersNLLRanker(TorchModel): + """Class for ranking of relations using the model trained with NLL loss + Args: + model_name: name of the function which initialises and returns the model class + pretrained_bert: pretrained transformer checkpoint path or key title (e.g. "bert-base-uncased") + encoder_save_path: path to save the encoder checkpoint + linear_save_path: path to save linear layer checkpoint + optimizer: optimizer name from `torch.optim` + optimizer_parameters: dictionary with optimizer's parameters, + e.g. {'lr': 0.1, 'weight_decay': 0.001, 'momentum': 0.9} + return_probas: set this to `True` if you need the probabilities instead of raw answers + clip_norm: clip gradients by norm + """ + + def __init__( + self, + model_name: str, + pretrained_bert: str = None, + encoder_save_path: str = None, + linear_save_path: str = None, + optimizer: str = "AdamW", + optimizer_parameters: Dict = None, + return_probas: bool = False, + clip_norm: Optional[float] = None, + **kwargs + ): + self.pretrained_bert = pretrained_bert + self.encoder_save_path = encoder_save_path + self.linear_save_path = linear_save_path + self.return_probas = return_probas + self.clip_norm = clip_norm + if optimizer_parameters is None: + optimizer_parameters = {"lr": 1e-5, "weight_decay": 0.01, "eps": 1e-6} + + super().__init__( + model_name=model_name, + optimizer=optimizer, + optimizer_parameters=optimizer_parameters, + **kwargs) + + def train_on_batch(self, input_features: Dict[str, Any], positive_idx: List[int]) -> float: + _input = {'positive_idx': positive_idx, + "input_ids": torch.LongTensor(input_features["input_ids"]).to(self.device), + "attention_mask": torch.LongTensor(input_features["attention_mask"]).to(self.device), + "token_type_ids": torch.LongTensor(input_features["token_type_ids"]).to(self.device)} + + self.model.train() + self.model.zero_grad() + self.optimizer.zero_grad() # zero the parameter gradients + + loss, softmax_scores = self.model(**_input) + loss.backward() + self.optimizer.step() + + # Clip the norm of the gradients to prevent the "exploding gradients" problem + if self.clip_norm: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_norm) + + if self.lr_scheduler is not None: + self.lr_scheduler.step() + + return loss.item() + + def __call__(self, input_features: Dict[str, Any]) -> Union[List[int], List[np.ndarray]]: + self.model.eval() + _input = {"input_ids": torch.LongTensor(input_features["input_ids"]).to(self.device), + "attention_mask": 
torch.LongTensor(input_features["attention_mask"]).to(self.device), + "token_type_ids": torch.LongTensor(input_features["token_type_ids"]).to(self.device)} + + with torch.no_grad(): + output = self.model(**_input) + if isinstance(output, tuple) and len(output) == 2: + loss, softmax_scores = output + else: + softmax_scores = output + if self.return_probas: + softmax_scores = softmax_scores.cpu().numpy().tolist() + return softmax_scores + else: + pred = torch.argmax(softmax_scores, dim=1) + pred = pred.cpu() + pred = pred.numpy() + return pred + + def in_batch_ranking_model(self, **kwargs) -> nn.Module: + return NLLRanking( + pretrained_bert=self.pretrained_bert, + encoder_save_path=self.encoder_save_path, + linear_save_path=self.linear_save_path, + bert_tokenizer_config_file=self.pretrained_bert, + device=self.device + ) + + def save(self, fname: Optional[str] = None, *args, **kwargs) -> None: + if fname is None: + fname = self.save_path + if not fname.parent.is_dir(): + raise ConfigError("Provided save path is incorrect!") + weights_path = Path(fname).with_suffix(f".pth.tar") + log.info(f"Saving model to {weights_path}.") + torch.save({ + "model_state_dict": self.model.cpu().state_dict(), + "optimizer_state_dict": self.optimizer.state_dict(), + "epochs_done": self.epochs_done + }, weights_path) + self.model.to(self.device) + + +class NLLRanking(nn.Module): + """Class which implements the relation ranking model + Args: + pretrained_bert: pretrained transformer checkpoint path or key title (e.g. "bert-base-uncased") + encoder_save_path: path to save the encoder checkpoint + linear_save_path: path to save linear layer checkpoint + bert_tokenizer_config_file: path to configuration file of transformer tokenizer + device: cpu or gpu + """ + + def __init__( + self, + pretrained_bert: str = None, + encoder_save_path: str = None, + linear_save_path: str = None, + bert_tokenizer_config_file: str = None, + device: str = "gpu" + ): + super().__init__() + self.pretrained_bert = pretrained_bert + self.encoder_save_path = encoder_save_path + self.linear_save_path = linear_save_path + self.device = torch.device("cuda" if torch.cuda.is_available() and device == "gpu" else "cpu") + + # initialize parameters that would be filled later + self.encoder, self.config, self.bert_config = None, None, None + self.load() + + if Path(bert_tokenizer_config_file).is_file(): + vocab_file = str(expand_path(bert_tokenizer_config_file)) + tokenizer = AutoTokenizer(vocab_file=vocab_file) + else: + tokenizer = AutoTokenizer.from_pretrained(pretrained_bert) + self.encoder.resize_token_embeddings(len(tokenizer) + 7) + + def forward( + self, + input_ids: Tensor, + attention_mask: Tensor, + token_type_ids: Tensor, + positive_idx: List[List[int]] = None + ) -> Union[Tuple[Any, Tensor], Tuple[Tensor]]: + + bs, samples_num, seq_len = input_ids.size() + input_ids = input_ids.reshape(bs * samples_num, -1) + attention_mask = attention_mask.reshape(bs * samples_num, -1) + token_type_ids = token_type_ids.reshape(bs * samples_num, -1) + if hasattr(self.config, "type_vocab_size"): + encoder_output = self.encoder(input_ids=input_ids, attention_mask=attention_mask, + token_type_ids=token_type_ids) + else: + encoder_output = self.encoder(input_ids=input_ids, attention_mask=attention_mask) + cls_emb = encoder_output.last_hidden_state[:, :1, :].squeeze(1) + scores = self.fc(cls_emb) + scores = scores.reshape(bs, samples_num) + + if positive_idx is not None: + scores = F.log_softmax(scores, dim=1) + positive_idx = [] + for i in range(bs): + 
positive_idx.append(0) + loss = F.nll_loss(scores, torch.tensor(positive_idx).to(scores.device), reduction="mean") + return loss, scores + else: + return scores + + def load(self) -> None: + if self.pretrained_bert: + log.info(f"From pretrained {self.pretrained_bert}.") + self.config = AutoConfig.from_pretrained( + self.pretrained_bert, output_hidden_states=True + ) + self.encoder = AutoModel.from_pretrained(self.pretrained_bert, config=self.config) + self.fc = nn.Linear(self.config.hidden_size, 1) + else: + raise ConfigError("No pre-trained BERT model is given.") + + self.encoder.to(self.device) + self.fc.to(self.device) diff --git a/deeppavlov/models/torch_bert/torch_transformers_sequence_tagger.py b/deeppavlov/models/torch_bert/torch_transformers_sequence_tagger.py index fc78e4f32e..e18528f02c 100644 --- a/deeppavlov/models/torch_bert/torch_transformers_sequence_tagger.py +++ b/deeppavlov/models/torch_bert/torch_transformers_sequence_tagger.py @@ -18,7 +18,6 @@ import numpy as np import torch -from overrides import overrides from transformers import AutoModelForTokenClassification, AutoConfig from deeppavlov.core.commands.utils import expand_path @@ -270,7 +269,6 @@ def __call__(self, return pred, probas - @overrides def load(self, fname=None): if fname is not None: self.load_path = fname @@ -311,10 +309,11 @@ def load(self, fname=None): else: log.warning(f"Init from scratch. Load path {weights_path_crf} does not exist.") - @overrides def save(self, fname: Optional[str] = None, *args, **kwargs) -> None: - super().save() + super().save(fname) if self.use_crf: + if fname is None: + fname = self.save_path weights_path_crf = Path(f"{fname}_crf").resolve() weights_path_crf = weights_path_crf.with_suffix(".pth.tar") torch.save({"model_state_dict": self.crf.cpu().state_dict()}, weights_path_crf) diff --git a/deeppavlov/models/torch_bert/torch_transformers_squad.py b/deeppavlov/models/torch_bert/torch_transformers_squad.py index 83aee7bc7e..4122edbf2f 100644 --- a/deeppavlov/models/torch_bert/torch_transformers_squad.py +++ b/deeppavlov/models/torch_bert/torch_transformers_squad.py @@ -12,14 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
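
The NLLRanking module added above scores each candidate relation path against the question and trains with a listwise negative log-likelihood over the candidates, with the positive path expected at index 0. A minimal, self-contained sketch of that objective (the tensor shapes and random scores are illustrative stand-ins for the encoder's [CLS] outputs, not the library's actual forward pass):

.. code:: python

    import torch
    import torch.nn.functional as F

    batch_size, samples_num = 2, 4
    # scores[i, j] ~ linear([CLS] embedding of "question [SEP] candidate path j")
    scores = torch.randn(batch_size, samples_num, requires_grad=True)

    log_probs = F.log_softmax(scores, dim=1)             # normalise over the candidates of each question
    targets = torch.zeros(batch_size, dtype=torch.long)  # the positive candidate is assumed to come first
    loss = F.nll_loss(log_probs, targets, reduction="mean")
    loss.backward()                                      # in the real model this updates the transformer encoder
    print(round(loss.item(), 4))
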
-import re from logging import getLogger from pathlib import Path from typing import List, Tuple, Optional, Dict import numpy as np import torch -from overrides import overrides from transformers import AutoModelForQuestionAnswering, AutoConfig from transformers.data.processors.utils import InputFeatures @@ -264,7 +262,6 @@ def __call__(self, features_batch: List[List[InputFeatures]]) -> Tuple[ return start_pred_batch, end_pred_batch, logits_batch, scores_batch, ind_batch - @overrides def load(self, fname=None): if fname is not None: self.load_path = fname diff --git a/deeppavlov/requirements/razdel.txt b/deeppavlov/requirements/razdel.txt new file mode 100644 index 0000000000..6334fc66ba --- /dev/null +++ b/deeppavlov/requirements/razdel.txt @@ -0,0 +1 @@ +razdel==0.5.0 diff --git a/deeppavlov/utils/agent/__init__.py b/deeppavlov/utils/agent/__init__.py deleted file mode 100644 index b737818baf..0000000000 --- a/deeppavlov/utils/agent/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .server import start_rabbit_service diff --git a/deeppavlov/utils/agent/messages.py b/deeppavlov/utils/agent/messages.py deleted file mode 100644 index 568b417722..0000000000 --- a/deeppavlov/utils/agent/messages.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2019 Neural Networks and Deep Learning lab, MIPT -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Module contains classes defining messages received and sent by service via RabbitMQ message broker. - -The classes created to document the DeepPavlov Agent API and should match the corresponding classes -from https://github.com/deepmipt/dp-agent/blob/master/core/transport/messages.py - -""" - -from typing import Any - - -class MessageBase: - agent_name: str - msg_type: str - - def __init__(self, msg_type: str, agent_name: str) -> None: - self.msg_type = msg_type - self.agent_name = agent_name - - @classmethod - def from_json(cls, message_json: dict): - return cls(**message_json) - - def to_json(self) -> dict: - return self.__dict__ - - -class ServiceTaskMessage(MessageBase): - payload: dict - - def __init__(self, agent_name: str, payload: dict) -> None: - super().__init__('service_task', agent_name) - self.payload = payload - - -class ServiceResponseMessage(MessageBase): - response: Any - task_id: str - - def __init__(self, task_id: str, agent_name: str, response: Any) -> None: - super().__init__('service_response', agent_name) - self.task_id = task_id - self.response = response - - -class ServiceErrorMessage(MessageBase): - formatted_exc: str - - def __init__(self, task_id: str, agent_name: str, formatted_exc: str) -> None: - super().__init__('error', agent_name) - self.task_id = task_id - self.formatted_exc = formatted_exc - - @property - def exception(self) -> Exception: - return Exception(self.formatted_exc) - - -def get_service_task_message(message_json: dict) -> ServiceTaskMessage: - """Creates an instance of ServiceTaskMessage class using its json representation. - - Args: - message_json: Dictionary with class fields. 
- - Returns: - New ServiceTaskMessage instance. - - Raises: - ValueError if dict with instance fields isn't from an instance of ServiceTaskMessage class. - - """ - message_type = message_json.pop('msg_type') - - if message_type != 'service_task': - raise TypeError(f'Unknown transport message type: {message_type}') - - return ServiceTaskMessage.from_json(message_json) diff --git a/deeppavlov/utils/agent/rabbitmq.py b/deeppavlov/utils/agent/rabbitmq.py deleted file mode 100644 index ea82955374..0000000000 --- a/deeppavlov/utils/agent/rabbitmq.py +++ /dev/null @@ -1,261 +0,0 @@ -# Copyright 2019 Neural Networks and Deep Learning lab, MIPT -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import json -import logging -import time -from collections import defaultdict -from pathlib import Path -from traceback import format_exc -from typing import Any, Dict, List, Optional, Union - -import aio_pika -from aio_pika import Connection, Channel, Exchange, Queue, IncomingMessage, Message - -from deeppavlov.core.commands.infer import build_model -from deeppavlov.core.common.chainer import Chainer -from deeppavlov.core.data.utils import jsonify_data -from deeppavlov.utils.agent.messages import ServiceTaskMessage, ServiceResponseMessage, ServiceErrorMessage -from deeppavlov.utils.agent.messages import get_service_task_message -from deeppavlov.utils.connector import DialogLogger -from deeppavlov.utils.server import get_server_params - -dialog_logger = DialogLogger(logger_name='agent_rabbit') -log = logging.getLogger(__name__) - -AGENT_IN_EXCHANGE_NAME_TEMPLATE = '{agent_namespace}_e_in' -AGENT_OUT_EXCHANGE_NAME_TEMPLATE = '{agent_namespace}_e_out' -AGENT_ROUTING_KEY_TEMPLATE = 'agent.{agent_name}' - -SERVICE_QUEUE_NAME_TEMPLATE = '{agent_namespace}_q_service_{service_name}' -SERVICE_ROUTING_KEY_TEMPLATE = 'service.{service_name}' - - -class RabbitMQServiceGateway: - """Class object connects to the RabbitMQ broker to process requests from the DeepPavlov Agent.""" - _add_to_buffer_lock: asyncio.Lock - _infer_lock: asyncio.Lock - _model: Chainer - _model_args_names: List[str] - _incoming_messages_buffer: List[IncomingMessage] - _batch_size: int - _utterance_lifetime_sec: int - _in_queue: Optional[Queue] - _connection: Connection - _agent_in_exchange: Exchange - _agent_out_exchange: Exchange - _agent_in_channel: Channel - _agent_out_channel: Channel - - def __init__(self, - model_config: Union[str, Path], - service_name: str, - agent_namespace: str, - batch_size: int, - utterance_lifetime_sec: int, - rabbit_host: str, - rabbit_port: int, - rabbit_login: str, - rabbit_password: str, - rabbit_virtualhost: str, - loop: asyncio.AbstractEventLoop) -> None: - self._add_to_buffer_lock = asyncio.Lock() - self._infer_lock = asyncio.Lock() - server_params = get_server_params(model_config) - self._model_args_names = server_params['model_args_names'] - self._model = build_model(model_config) - self._in_queue = None - self._utterance_lifetime_sec = utterance_lifetime_sec - self._batch_size = batch_size - 
self._incoming_messages_buffer = [] - - loop.run_until_complete(self._connect(loop=loop, host=rabbit_host, port=rabbit_port, login=rabbit_login, - password=rabbit_password, virtualhost=rabbit_virtualhost, - agent_namespace=agent_namespace)) - loop.run_until_complete(self._setup_queues(service_name, agent_namespace)) - loop.run_until_complete(self._in_queue.consume(callback=self._on_message_callback)) - - log.info(f'Service in queue started consuming') - - async def _connect(self, - loop: asyncio.AbstractEventLoop, - host: str, - port: int, - login: str, - password: str, - virtualhost: str, - agent_namespace: str) -> None: - """Connects to RabbitMQ message broker and initiates agent in and out channels and exchanges.""" - log.info('Starting RabbitMQ connection...') - - while True: - try: - self._connection = await aio_pika.connect_robust(loop=loop, - host=host, - port=port, - login=login, - password=password, - virtualhost=virtualhost) - log.info('RabbitMQ connected') - break - except ConnectionError: - reconnect_timeout = 5 - log.error(f'RabbitMQ connection error, making another attempt in {reconnect_timeout} secs') - time.sleep(reconnect_timeout) - - self._agent_in_channel = await self._connection.channel() - agent_in_exchange_name = AGENT_IN_EXCHANGE_NAME_TEMPLATE.format(agent_namespace=agent_namespace) - self._agent_in_exchange = await self._agent_in_channel.declare_exchange(name=agent_in_exchange_name, - type=aio_pika.ExchangeType.TOPIC) - log.info(f'Declared agent in exchange: {agent_in_exchange_name}') - - self._agent_out_channel = await self._connection.channel() - agent_out_exchange_name = AGENT_OUT_EXCHANGE_NAME_TEMPLATE.format(agent_namespace=agent_namespace) - self._agent_out_exchange = await self._agent_in_channel.declare_exchange(name=agent_out_exchange_name, - type=aio_pika.ExchangeType.TOPIC) - log.info(f'Declared agent out exchange: {agent_out_exchange_name}') - - def disconnect(self): - self._connection.close() - - async def _setup_queues(self, service_name: str, agent_namespace: str) -> None: - """Setups input queue to get messages from DeepPavlov Agent.""" - in_queue_name = SERVICE_QUEUE_NAME_TEMPLATE.format(agent_namespace=agent_namespace, - service_name=service_name) - - self._in_queue = await self._agent_out_channel.declare_queue(name=in_queue_name, durable=True) - log.info(f'Declared service in queue: {in_queue_name}') - - service_routing_key = SERVICE_ROUTING_KEY_TEMPLATE.format(service_name=service_name) - await self._in_queue.bind(exchange=self._agent_out_exchange, routing_key=service_routing_key) - log.info(f'Queue: {in_queue_name} bound to routing key: {service_routing_key}') - - await self._agent_out_channel.set_qos(prefetch_count=self._batch_size * 2) - - async def _on_message_callback(self, message: IncomingMessage) -> None: - """Processes messages from the input queue. - - Collects incoming messages to buffer, sends tasks batches for further processing. Depending on the success of - the processing result sends negative or positive acknowledgements to the input messages. 
- - """ - await self._add_to_buffer_lock.acquire() - self._incoming_messages_buffer.append(message) - log.debug('Incoming message received') - - if len(self._incoming_messages_buffer) < self._batch_size: - self._add_to_buffer_lock.release() - - await self._infer_lock.acquire() - try: - messages_batch = self._incoming_messages_buffer - valid_messages_batch: List[IncomingMessage] = [] - tasks_batch: List[ServiceTaskMessage] = [] - - if messages_batch: - self._incoming_messages_buffer = [] - - if self._add_to_buffer_lock.locked(): - self._add_to_buffer_lock.release() - - for message in messages_batch: - try: - task = get_service_task_message(json.loads(message.body, encoding='utf-8')) - tasks_batch.append(task) - valid_messages_batch.append(message) - except Exception as e: - log.error(f'Failed to get ServiceTaskMessage from the incoming message: {repr(e)}') - await message.reject() - - elif self._add_to_buffer_lock.locked(): - self._add_to_buffer_lock.release() - - if tasks_batch: - try: - await self._process_tasks(tasks_batch) - except Exception as e: - task_ids = [task.payload["task_id"] for task in tasks_batch] - log.error(f'got exception {repr(e)} while processing tasks {", ".join(task_ids)}') - formatted_exception = format_exc() - error_replies = [self._send_results(task, formatted_exception) for task in tasks_batch] - await asyncio.gather(*error_replies) - for message in valid_messages_batch: - await message.reject() - else: - for message in valid_messages_batch: - await message.ack() - finally: - self._infer_lock.release() - - async def _process_tasks(self, tasks_batch: List[ServiceTaskMessage]) -> None: - """Gets from tasks batch payloads to infer model, processes them and creates tasks to send results.""" - task_uuids_batch, payloads = \ - zip(*[(task.payload['task_id'], task.payload['payload']) for task in tasks_batch]) - - log.debug(f'Prepared to infer tasks {", ".join(task_uuids_batch)}') - - responses_batch = await asyncio.wait_for(self._interact(payloads), - self._utterance_lifetime_sec) - - results_replies = [self._send_results(task, response) for task, response in zip(tasks_batch, responses_batch)] - await asyncio.gather(*results_replies) - - log.debug(f'Processed tasks {", ".join(task_uuids_batch)}') - - async def _interact(self, payloads: List[Dict]) -> List[Any]: - """Infers model with the batch.""" - batch = defaultdict(list) - - for payload in payloads: - for arg_name in self._model_args_names: - batch[arg_name].extend(payload.get(arg_name, [None])) - - dialog_logger.log_in(batch) - - prediction = self._model(*batch.values()) - if len(self._model.out_params) == 1: - prediction = [prediction] - prediction = list(zip(*prediction)) - result = jsonify_data(prediction) - - dialog_logger.log_out(result) - - return result - - async def _send_results(self, task: ServiceTaskMessage, response: Union[Dict, str]) -> None: - """Sends responses batch to the DeepPavlov Agent using agent input exchange. - - Args: - task: Task message from DeepPavlov Agent. 
- response: DeepPavlov model response (dict type) if infer was successful otherwise string representation of - raised error - - """ - if isinstance(response, dict): - result = ServiceResponseMessage(agent_name=task.agent_name, - task_id=task.payload["task_id"], - response=response) - else: - result = ServiceErrorMessage(agent_name=task.agent_name, - task_id=task.payload["task_id"], - formatted_exc=response) - - message = Message(body=json.dumps(result.to_json()).encode('utf-8'), - delivery_mode=aio_pika.DeliveryMode.PERSISTENT, - expiration=self._utterance_lifetime_sec) - - routing_key = AGENT_ROUTING_KEY_TEMPLATE.format(agent_name=task.agent_name) - await self._agent_in_exchange.publish(message=message, routing_key=routing_key) - log.debug(f'Sent response for task {str(task.payload["task_id"])} with routing key {routing_key}') diff --git a/deeppavlov/utils/agent/server.py b/deeppavlov/utils/agent/server.py deleted file mode 100644 index b9f3359af8..0000000000 --- a/deeppavlov/utils/agent/server.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2019 Neural Networks and Deep Learning lab, MIPT -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import logging -from pathlib import Path -from typing import Optional, Union - -from deeppavlov.core.common.file import read_json -from deeppavlov.core.common.paths import get_settings_path -from deeppavlov.utils.agent.rabbitmq import RabbitMQServiceGateway - -CONNECTOR_CONFIG_FILENAME = 'server_config.json' - - -def start_rabbit_service(model_config: Union[str, Path], - service_name: Optional[str] = None, - agent_namespace: Optional[str] = None, - batch_size: Optional[int] = None, - utterance_lifetime_sec: Optional[int] = None, - rabbit_host: Optional[str] = None, - rabbit_port: Optional[int] = None, - rabbit_login: Optional[str] = None, - rabbit_password: Optional[str] = None, - rabbit_virtualhost: Optional[str] = None) -> None: - """Launches DeepPavlov model receiving utterances and sending responses via RabbitMQ message broker. - - Args: - model_config: Path to DeepPavlov model to be launched. - service_name: Service name set in DeepPavlov Agent config. Used to format RabbitMQ exchanges, queues and routing - keys names. - agent_namespace: Service processes messages only from agents with the same namespace value. - batch_size: Limits the maximum number of utterances to be processed by service at one inference. - utterance_lifetime_sec: RabbitMQ message expiration time in seconds. - rabbit_host: RabbitMQ server host name. - rabbit_port: RabbitMQ server port number. - rabbit_login: RabbitMQ server administrator username. - rabbit_password: RabbitMQ server administrator password. - rabbit_virtualhost: RabbitMQ server virtualhost name. 
- - """ - service_config_path = get_settings_path() / CONNECTOR_CONFIG_FILENAME - service_config: dict = read_json(service_config_path)['agent-rabbit'] - - service_name = service_name or service_config['service_name'] - agent_namespace = agent_namespace or service_config['agent_namespace'] - batch_size = batch_size or service_config['batch_size'] - utterance_lifetime_sec = utterance_lifetime_sec or service_config['utterance_lifetime_sec'] - rabbit_host = rabbit_host or service_config['rabbit_host'] - rabbit_port = rabbit_port or service_config['rabbit_port'] - rabbit_login = rabbit_login or service_config['rabbit_login'] - rabbit_password = rabbit_password or service_config['rabbit_password'] - rabbit_virtualhost = rabbit_virtualhost or service_config['rabbit_virtualhost'] - - loop = asyncio.get_event_loop() - - gateway = RabbitMQServiceGateway( - model_config=model_config, - service_name=service_name, - agent_namespace=agent_namespace, - batch_size=batch_size, - utterance_lifetime_sec=utterance_lifetime_sec, - rabbit_host=rabbit_host, - rabbit_port=rabbit_port, - rabbit_login=rabbit_login, - rabbit_password=rabbit_password, - rabbit_virtualhost=rabbit_virtualhost, - loop=loop - ) - - try: - loop.run_forever() - except KeyboardInterrupt: - pass - finally: - gateway.disconnect() - loop.stop() - loop.close() - logging.shutdown() diff --git a/deeppavlov/utils/settings/server_config.json b/deeppavlov/utils/settings/server_config.json index 1bae81cf4b..3cbdfd5e4f 100644 --- a/deeppavlov/utils/settings/server_config.json +++ b/deeppavlov/utils/settings/server_config.json @@ -9,16 +9,5 @@ "socket_type": "TCP", "unix_socket_file": "/tmp/deeppavlov_socket.s", "socket_launch_message": "launching socket server at" - }, - "agent-rabbit": { - "service_name": "", - "agent_namespace": "deeppavlov_agent", - "batch_size": 1, - "utterance_lifetime_sec": 120, - "rabbit_host": "0.0.0.0", - "rabbit_port": 5672, - "rabbit_login": "guest", - "rabbit_password": "guest", - "rabbit_virtualhost": "/" } } diff --git a/docs/_static/social/telegram.png b/docs/_static/social/telegram.png new file mode 100644 index 0000000000..6a61600417 Binary files /dev/null and b/docs/_static/social/telegram.png differ diff --git a/docs/_templates/footer.html b/docs/_templates/footer.html index cd7ef617dc..e062a3edfb 100644 --- a/docs/_templates/footer.html +++ b/docs/_templates/footer.html @@ -32,9 +32,10 @@ {%- block extrafooter %}

Problem? Ask a Question or try our Demo

+ medium twitter youtube - medium + medium

{% endblock %}
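
With the agent-rabbit block dropped from server_config.json and the RabbitMQ gateway removed, models are served through the remaining REST and socket interfaces. A small sketch of querying a model over REST, mirroring how the reworked infer_api test below builds its payload; it assumes a model is already running via ``python -m deeppavlov riseapi <config_name>`` on localhost:5000, and the ``/api`` and ``/model`` endpoint paths are assumed defaults here:

.. code:: python

    import requests

    base = "http://127.0.0.1:5000"

    # GET /api reports the model's input argument names under the "in" key
    model_args_names = requests.get(f"{base}/api").json()["in"]

    # one batch (here of size one) per input argument, as in infer_api
    payload = {name: ["What is the currency of Sweden?"] for name in model_args_names}

    response = requests.post(f"{base}/model", json=payload)
    response.raise_for_status()
    print(response.json())
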

diff --git a/docs/apiref/models/entity_extraction.rst b/docs/apiref/models/entity_extraction.rst index 865b51e686..7f47a4ed59 100644 --- a/docs/apiref/models/entity_extraction.rst +++ b/docs/apiref/models/entity_extraction.rst @@ -16,4 +16,4 @@ deeppavlov.models.entity_extraction .. automethod:: __init__ .. automethod:: __call__ -.. autofunction:: deeppavlov.models.entity_extraction.entity_detection_parser.question_sign_checker +.. autoclass:: deeppavlov.models.entity_extraction.entity_detection_parser.QuestionSignChecker diff --git a/docs/apiref/models/kbqa.rst b/docs/apiref/models/kbqa.rst index 8a327251cb..39ff367b25 100644 --- a/docs/apiref/models/kbqa.rst +++ b/docs/apiref/models/kbqa.rst @@ -28,7 +28,7 @@ deeppavlov.models.kbqa .. automethod:: __init__ .. automethod:: __call__ -.. autoclass:: deeppavlov.models.kbqa.tree_to_sparql.RuAdjToNoun +.. autoclass:: deeppavlov.models.kbqa.ru_adj_to_noun.RuAdjToNoun .. automethod:: __init__ .. automethod:: __call__ diff --git a/docs/conf.py b/docs/conf.py index d90853887a..b8ff2326c4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -201,7 +201,7 @@ # -- Extension configuration ------------------------------------------------- -autodoc_mock_imports = ['bs4', 'fasttext', 'hdt', 'kenlm', 'lxml', 'navec', 'nltk', 'opt_einsum', 'rapidfuzz', +autodoc_mock_imports = ['bs4', 'fasttext', 'hdt', 'kenlm', 'lxml', 'navec', 'nltk', 'opt_einsum', 'rapidfuzz', 'razdel', 'sacremoses', 'slovnet', 'sortedcontainers', 'spacy', 'torch', 'torchcrf', 'transformers', 'udapi', 'whapi'] diff --git a/docs/features/models/entity_extraction.rst b/docs/features/models/entity_extraction.rst index 8b5fa7ba9c..3c8f2791bf 100644 --- a/docs/features/models/entity_extraction.rst +++ b/docs/features/models/entity_extraction.rst @@ -49,9 +49,9 @@ Entity Detection model can be used from Python using the following code: .. code:: python - from deeppavlov import configs, build_model + from deeppavlov import build_model - ed = build_model(configs.entity_extraction.entity_detection_en, download=True) + ed = build_model('entity_detection_en', download=True) ed(['Forrest Gump is a comedy-drama film directed by Robert Zemeckis and written by Eric Roth.']) Entity Linking is the task of finding knowledge base entity ids for entity mentions in text. Entity Linking in DeepPavlov supports Wikidata and Wikipedia (for :config:`English ` and :config:`Russian `). Entity Linking component performs the following steps: @@ -101,7 +101,7 @@ Entity Linking model can be used from Python using the following code: .. code:: python - from deeppavlov import configs, build_model + from deeppavlov import build_model - entity_extraction = build_model(configs.kbqa.entity_extraction_en, download=True) + entity_extraction = build_model('entity_extraction_en', download=True) entity_extraction(['Forrest Gump is a comedy-drama film directed by Robert Zemeckis and written by Eric Roth.']) diff --git a/docs/features/models/kbqa.rst b/docs/features/models/kbqa.rst index 17541cd115..b420fdb86a 100755 --- a/docs/features/models/kbqa.rst +++ b/docs/features/models/kbqa.rst @@ -15,7 +15,7 @@ Currently, we support Wikidata as a Knowledge Base (Knowledge Graph). In the fut The question answerer: * validates questions against a preconfigured list of question templates, disambiguates entities using Entity Linking, and answers questions asked in natural language, -* can be used with Wikidata (English, Russian) and (in the future versions) with custom knowledge graphs. 
+* can be used with Wikidata (English, Russian) and DBPedia (Russian). Built-In Models ------------------ @@ -24,7 +24,7 @@ Currently, we provide two built-in models for KBQA in DeepPavlov library: * :config:`kbqa_cq_en ` - for answering complex questions over Wikidata in English, -* :config:`kbqa_cq_ru ` - for answering complex questions over Wikidata in Russian, +* :config:`kbqa_cq_ru ` - for answering complex questions over Wikidata in Russian. These configs use local Wikidata dump in hdt format (3.7 Gb on disk). @@ -124,7 +124,7 @@ Here are the models we've trained for complex question answering: * :config:`entity_detection ` - sequence tagging model for detection of entity and entity types substrings in the question, -* :config:`rel_ranking ` - model for ranking of candidate relations and candidate_relation_paths for the question, +* :config:`rel_ranking ` - model for ranking of candidate relations and candidate_relation_paths for the question, How Do I: Train Query Prediction Model -------------------------------------- @@ -159,34 +159,21 @@ An example of a template:: { "query_template": "SELECT ?obj WHERE { wd:E1 p:R1 ?s . ?s ps:R1 ?obj . ?s ?p ?x filter(contains(?x, N)) }", - "property_types": {"?p": "qualifier"}, "rank_rels": ["wiki", "do_not_rank", "do_not_rank"], "rel_types": ["no_type", "statement", "qualifier"], - "filter_rels": [false], - "rel_dirs": ["forw"], "query_sequence": [1, 2, 3], - "entities_and_types_num": [1, 0], - "entities_and_types_select": "1 0", - "syntax_structure": {"gr_ent": 1, "types": 0, "mod_ent": 0, "q_ent": 0, "count": false, "order": false}, "return_if_found": true, "template_num": "0", "alternative_templates": [] } * ``query_template`` is the template of the SPARQL query, -* ``property_types`` defines the types of unknown relations in the template, * ``rank_rels`` is a list which defines whether to rank relations, in this example **p:R1** relations we extract from Wikidata for **wd:E1** entities and rank with rel_ranker, **ps:R1** and **?p** relations we do not extract and rank, * ``rel_types`` - direct, statement or qualifier relations, -* ``filter_rels`` (only for online version of KBQA) - whether candidate rels will be enumerated in the **filter** expression in the query, for example, - **SELECT ?ent WHERE { ?ent wdt:P31 wd:Q4022 . 
?ent ?p1 wd:Q90 } filter(?p1 = wdt:P131 || ?p1 = wdt:P17)**, -* ``rel_dirs`` - **forw** if the relation connects the subject and unknown object, for example, **wd:Q649 wdt:P17 ?p**, **backw** if the relation connects the unknown object and the subject, for example **?p wdt:P17 wd:Q159**, * ``query_sequence`` (only for offline version of KBQA) - the sequence in which the triplets will be extracted from Wikidata hdt file, -* ``entities_and_types_num`` - numbers of entities and types extracted from the question, which this template can contain, -* ``entities_and_types_select`` - the dictionary where keys are number of entities and types extracted from the question and values are indices of entities and types which should be filled in the template (because we can extract more entities and types than the template contains), -* ``syntax_structure`` - information about syntactic structure of questions corresponding to this query, * ``return_if_found`` - parameter for the cycle which iterates over all possible combinations of entities, relations and types, if **true** - return if the first valid combination is found, if **false** - consider all combinations, * ``template_num`` - the number of template, -* alternative_templates - numbers of alternative templates to use if the answer was not found with the current template. +* ``alternative_templates`` - numbers of alternative templates to use if the answer was not found with the current template. Advanced: Using Wiki Parser As Standalone Service For KBQA ------------------------------------------------------------------------------ @@ -250,8 +237,8 @@ To use Entity Linking service in KBQA, in the :config:`kbqa_cq_en ` you should replace :config:`wiki parser component ` with API Requester component in the following way:: diff --git a/docs/index.rst b/docs/index.rst index dffd62807b..0db3c9891b 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -50,7 +50,6 @@ Welcome to DeepPavlov's documentation! REST API Socket API - DeepPavlov Agent RabbitMQ integration Amazon AWS deployment DeepPavlov settings diff --git a/docs/integrations/dp_agent.rst b/docs/integrations/dp_agent.rst deleted file mode 100644 index 4f64923f2c..0000000000 --- a/docs/integrations/dp_agent.rst +++ /dev/null @@ -1,64 +0,0 @@ -DeepPavlov Agent RabbitMQ integration -===================================== - -Any model specified by a DeepPavlov config can be launched as a service for -`DeepPavlov Agent `_ -communicating with agent through RabbitMQ message broker. You can launch it -using command line interface or using python. - -Command line interface -~~~~~~~~~~~~~~~~~~~~~~ - -To run a model specified by the ```` config file as a DeepPavlov Agent service, run: - -.. code:: bash - - python -m deeppavlov agent-rabbit [-d] \ - [-sn ] \ - [-an ] \ - [-ans ] \ - [-b ] \ - [-ul ] \ - [-rp ] \ - [-rl ] \ - [-rpwd ] \ - [-rvh ] - -* ``-d``: download model specific data before starting the service. -* ``-sn ``: service name set in the connector section of the DeepPavlov Agent config file. -* ``-an ``: namespace the service works in. Messages only from agents from this namespace is processed. -* ``-b ``: inference batch size. -* ``-ul ``: RabbitMQ server host. -* ``-rp ``: RabbitMQ server port. -* ``-rl ``: RabbitMQ server login. -* ``-rpwd ``: RabbitMQ server password. -* ``-rvh ``: RabbitMQ server virtualhost. - -Default values of optional arguments can be modified via changing ``agent-rabbit`` section of the file -``deeppavlov/utils/settings/server_config.json``. 
- -Python interface -~~~~~~~~~~~~~~~~ - -To run a model specified by the ```` config file as a DeepPavlov Agent service using python, -run the following code: - -.. code:: python - - from deeppavlov.utils.agent import start_rabbit_service - - start_rabbit_service(model_config=, - service_name=, - agent_namespace=, - batch_size=, - utterance_lifetime_sec=, - rabbit_host=, - rabbit_port=, - rabbit_login=, - rabbit_password=, - rabbit_virtualhost=) - -All arguments except ```` are optional. Default values of optional arguments can be modified via changing -``agent-rabbit`` section of the file ``deeppavlov/utils/settings/server_config.json``. \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 6707575de1..2137b9e581 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,11 @@ -aio-pika>=3.2.2,<6.9.0 -fastapi>=0.47.0,<0.78.0 +fastapi>=0.47.0,<=0.89.1 filelock>=3.0.0,<3.10.0 -nltk>=3.2.5,<3.10.0 +nltk>=3.2.4,<3.10.0 numpy<1.24 -overrides==4.1.2 pandas>=1.0.0,<1.6.0 -prometheus-client>=0.13.0,<0.15.0 +prometheus-client>=0.13.0,<=1.16.0 pydantic -pybind11==2.2.4 +pybind11==2.10.3 requests>=2.19.0,<3.0.0 scikit-learn>=0.24,<1.1.0 scipy<1.10.0 diff --git a/setup.py b/setup.py index 3a6f642e79..db912de26c 100644 --- a/setup.py +++ b/setup.py @@ -71,11 +71,12 @@ def readme(): 'sphinx==3.5.4;python_version<"3.10"', 'sphinx==4.5.0;python_version>="3.10"', 'sphinx_rtd_theme==0.5.2', + 'docutils<0.17,>=0.12', 'nbsphinx==0.8.4', 'ipykernel==5.5.4', 'jinja2<=3.0.3', 'sphinx-copybutton==0.5.0', - 'pandoc==2.2', + 'pandoc==2.3', 'ipython_genutils==0.2.0' ], 's3': [ diff --git a/tests/test_quick_start.py b/tests/test_quick_start.py index 654984c6b8..c6acf670d9 100644 --- a/tests/test_quick_start.py +++ b/tests/test_quick_start.py @@ -107,8 +107,8 @@ ("classifiers/rusentiment_bert.json", "classifiers", ('IP',)): [ONE_ARGUMENT_INFER_CHECK], ("classifiers/sentiment_twitter.json", "classifiers", ALL_MODES): [ONE_ARGUMENT_INFER_CHECK], ("classifiers/sentiment_sst_conv_bert.json", "classifiers", ('IP',)): [ONE_ARGUMENT_INFER_CHECK], - ("classifiers/glue/glue_mrpc_cased_bert_torch.json", "classifiers", ('TI',)): [TWO_ARGUMENTS_INFER_CHECK], - ("classifiers/glue/glue_stsb_cased_bert_torch.json", "classifiers", ('TI',)): [TWO_ARGUMENTS_INFER_CHECK], + ("classifiers/glue/glue_mrpc_roberta.json", "classifiers", ('TI',)): [TWO_ARGUMENTS_INFER_CHECK], + ("classifiers/glue/glue_stsb_roberta.json", "classifiers", ('TI',)): [TWO_ARGUMENTS_INFER_CHECK], ("classifiers/glue/glue_mnli_roberta.json", "classifiers", ('TI',)): [TWO_ARGUMENTS_INFER_CHECK], ("classifiers/glue/glue_rte_roberta_mnli.json", "classifiers", ('TI',)): [TWO_ARGUMENTS_INFER_CHECK], ("classifiers/superglue/superglue_copa_roberta.json", "classifiers", ('TI',)): [LIST_ARGUMENTS_INFER_CHECK], @@ -138,6 +138,12 @@ ("russian_super_glue/russian_superglue_parus_rubert.json", "russian_super_glue", ('IP',)): [LIST_ARGUMENTS_INFER_CHECK], ("russian_super_glue/russian_superglue_rucos_rubert.json", "russian_super_glue", ('IP',)): [RECORD_ARGUMENTS_INFER_CHECK] }, + "multitask":{ + ("multitask/multitask_example.json", "multitask", ALL_MODES): [ + ('Dummy text',) + (('Dummy text', 'Dummy text'),) * 3 + ('Dummy text',) + (None,)], + ("multitask/mt_glue.json", "multitask", ALL_MODES): [ + ('Dummy text',) * 2 + (('Dummy text', 'Dummy text'),) * 6 + (None,)] + }, "entity_extraction": { ("entity_extraction/entity_detection_en.json", "entity_extraction", ('IP',)): [ @@ -161,41 +167,6 @@ ['Москва — столица России, центр Центрального 
федерального округа и центр Московской области.'], [0.8359, 0.938, 0.9917, 0.9803])) ], - ("entity_extraction/entity_linking_en.json", "entity_extraction", ('IP',)): - [ - (['forrest gump', 'robert zemeckis', 'eric roth'], - ['WORK_OF_ART', 'PERSON', 'PERSON'], - ['Forrest Gump is a comedy-drama film directed by Robert Zemeckis and written by Eric Roth.'], - [(0, 12), (48, 63), (79, 88)], - [(0, 89)], - ([['Q134773', 'Q552213', 'Q12016774'], ['Q187364', 'Q36951156'], ['Q942932', 'Q89320386', 'Q89909683']], - [[(1.0, 110, 1.0), (1.0, 13, 0.73), (1.0, 8, 0.04)], - [(1.0, 73, 1.0), (0.5, 52, 0.29)], - [(1.0, 37, 0.95), (1.0, 2, 0.35), (0.67, 2, 0.35)]], - [['Forrest Gump', 'Forrest Gump (novel)', ''], - ['Robert Zemeckis', 'Welcome to Marwen'], ['Eric Roth', '', '']])) - ], - ("entity_extraction/entity_linking_ru.json", "entity_extraction", ('IP',)): - [ - (['москва', 'россии', 'центрального федерального округа', 'московской области'], - ['CITY', 'COUNTRY', 'LOC', 'LOC'], - ['Москва — столица России, центр Центрального федерального округа и центр Московской области.'], - [(0, 6), (17, 23), (31, 63), (72, 90)], - [(0, 91)], - ([['Q649', 'Q1023006', 'Q2380475'], ['Q159', 'Q2184', 'Q139319'], - ['Q190778', 'Q484215', 'Q21104009'], ['Q1697', 'Q4303932', 'Q24565285']], - [[(1.0, 134, 1.0), (1.0, 20, 0.0), (1.0, 18, 0.0)], - [(1.0, 203, 1.0), (1.0, 58, 1.0), (1.0, 29, 0.93)], - [(1.0, 24, 0.28), (0.67, 11, 0.5), (0.67, 8, 0.4)], - [(0.9, 30, 1.0), (0.9, 6, 0.83), (0.61, 8, 0.03)]], - [['Москва', 'Москоу (Канзас)', 'Москоу (Теннесси)'], - ['Россия', 'Российская Советская Федеративная Социалистическая Республика', - 'Российская республика'], - ['Центральный федеральный округ', 'Федеральные округа Российской Федерации', - 'Центральный административный округ (Назрань)'], - ['Московская область', 'Московская область (1917—1918)', - 'Мостовский (Волгоградская область)']])) - ], ("entity_extraction/entity_extraction_en.json", "entity_extraction", ('IP',)): [ ("Forrest Gump is a comedy-drama film directed by Robert Zemeckis and written by Eric Roth.", @@ -204,10 +175,12 @@ [(0, 12), (48, 63), (79, 88)], [['Q134773', 'Q552213', 'Q12016774'], ['Q187364', 'Q36951156'], ['Q942932', 'Q89320386', 'Q89909683']], - [[(1.0, 110, 1.0), (1.0, 13, 0.73), (1.0, 8, 0.04)], [(1.0, 73, 1.0), (0.5, 52, 0.29)], - [(1.0, 37, 0.95), (1.0, 2, 0.35), (0.67, 2, 0.35)]], + [[[1.1, 110, 1.0], [1.1, 13, 0.73], [1.1, 8, 0.04]], [[1.1, 73, 1.0], [0.5, 52, 0.29]], + [[1.1, 37, 0.95], [1.1, 2, 0.35], [0.67, 2, 0.35]]], [['Forrest Gump', 'Forrest Gump (novel)', ''], ['Robert Zemeckis', 'Welcome to Marwen'], - ['Eric Roth', '', '']])) + ['Eric Roth', '', '']], + [['Forrest Gump', 'Forrest Gump', 'Forrest Gump'], ['Robert Zemeckis', 'Welcome to Marwen'], + ['Eric Roth', 'Eric Roth', 'Eric W Roth']])) ], ("entity_extraction/entity_extraction_ru.json", "entity_extraction", ('IP',)): [ @@ -215,18 +188,23 @@ (['москва', 'россии', 'центрального федерального округа', 'московской области'], ['CITY', 'COUNTRY', 'LOC', 'LOC'], [(0, 6), (17, 23), (31, 63), (72, 90)], - [['Q649', 'Q1023006', 'Q2380475'], ['Q159', 'Q2184', 'Q139319'], - ['Q190778', 'Q484215', 'Q21104009'], ['Q1697', 'Q4303932', 'Q24565285']], - [[(1.0, 134, 1.0), (1.0, 20, 0.0), (1.0, 18, 0.0)], - [(1.0, 203, 1.0), (1.0, 58, 1.0), (1.0, 29, 0.93)], - [(1.0, 24, 0.28), (0.67, 11, 0.5), (0.67, 8, 0.4)], - [(0.9, 30, 1.0), (0.9, 6, 0.83), (0.61, 8, 0.03)]], + [['Q649', 'Q1023006', 'Q2380475'], ['Q159', 'Q2184', 'Q139319'], ['Q190778', 'Q4504288', 'Q27557290'], + ['Q1697', 'Q4303932', 
'Q24565285']], + [[[1.1, 200, 1.0], [1.0, 20, 0.0], [1.0, 18, 0.0]], + [[1.1, 200, 1.0], [1.0, 58, 1.0], [1.0, 29, 0.85]], + [[1.1, 200, 1.0], [0.67, 3, 0.92], [0.67, 3, 0.89]], + [[0.9, 200, 1.0], [0.9, 6, 0.83], [0.61, 8, 0.03]]], [['Москва', 'Москоу (Канзас)', 'Москоу (Теннесси)'], ['Россия', 'Российская Советская Федеративная Социалистическая Республика', 'Российская республика'], - ['Центральный федеральный округ', 'Федеральные округа Российской Федерации', - 'Центральный административный округ (Назрань)'], - ['Московская область', 'Московская область (1917—1918)', 'Мостовский (Волгоградская область)']])) + ['Центральный федеральный округ', 'Центральный округ (Краснодар)', ''], + ['Московская область', 'Московская область (1917—1918)', + 'Мостовский (Волгоградская область)']], + [['Москва', 'Москоу', 'Москоу'], + ['Россия', 'Российская Советская Федеративная Социалистическая Республика', + 'Российская республика'], + ['Центральный федеральный округ', 'Центральный округ (Краснодар)', 'Центральный округ (Братск)'], + ['Московская область', 'Московская область', 'Мостовский']])) ] }, "ner": { @@ -243,16 +221,22 @@ "kbqa": { ("kbqa/kbqa_cq_en.json", "kbqa", ('IP',)): [ - ("What is the currency of Sweden?", ("Swedish krona",)), - ("Where was Napoleon Bonaparte born?", ("Ajaccio",)), - ("When did the Korean War end?", ("27 July 1953",)), - (" ", ("Not Found",)) - ], + ("What is the currency of Sweden?", + ("Swedish krona", ["Q122922"], ["SELECT ?answer WHERE { wd:Q34 wdt:P38 ?answer. }"])), + ("Where was Napoleon Bonaparte born?", + ("Ajaccio", ["Q40104"], ["SELECT ?answer WHERE { wd:Q517 wdt:P19 ?answer. }"])), + ("When did the Korean War end?", + ("27 July 1953", ["+1953-07-27^^T"], ["SELECT ?answer WHERE { wd:Q8663 wdt:P582 ?answer. }"])), + (" ", ("Not Found", [], [])) + ], ("kbqa/kbqa_cq_ru.json", "kbqa", ('IP',)): [ - ("Кто такой Оксимирон?", ("российский рэп-исполнитель",)), - ("Кто написал «Евгений Онегин»?", ("Александр Сергеевич Пушкин",)), - ("абв", ("Not Found",)) + ("Кто такой Оксимирон?", + ("российский рэп-исполнитель", ['российский рэп-исполнитель"@ru'], + ["SELECT ?answer WHERE { wd:Q4046107 wdt:P0 ?answer. }"])), + ("Кто написал «Евгений Онегин»?", + ("Александр Сергеевич Пушкин", ["Q7200"], ["SELECT ?answer WHERE { wd:Q50948 wdt:P50 ?answer. 
}"])), + ("абв", ("Not Found", [], [])) ] }, "ranking": { @@ -399,7 +383,8 @@ def infer(config_path, qr_list=None, check_outputs=True): raise RuntimeError(f'Unexpected results for {config_path}: {errors}') @staticmethod - def infer_api(config_path): + def infer_api(config_path, qr_list): + *inputs, expected_outputs = zip(*qr_list) server_params = get_server_params(config_path) url_base = 'http://{}:{}'.format(server_params['host'], api_port or server_params['port']) @@ -422,14 +407,10 @@ def infer_api(config_path): assert response_code == 200, f"GET /api request returned error code {response_code} with {config_path}" model_args_names = get_response.json()['in'] - post_payload = dict() - for arg_name in model_args_names: - arg_value = ' '.join(['qwerty'] * 10) - post_payload[arg_name] = [arg_value] + post_payload = dict(zip(model_args_names, inputs)) # TODO: remove this if from here and socket - if 'parus' in str(config_path): - post_payload = {k: [v] for k, v in post_payload.items()} - + if 'docred' in str(config_path) or 'rured' in str(config_path): + post_payload = {k: v[0] for k, v in post_payload.items()} post_response = requests.post(url, json=post_payload, headers=post_headers) response_code = post_response.status_code assert response_code == 200, f"POST request returned error code {response_code} with {config_path}" @@ -519,7 +500,7 @@ def test_inferring_pretrained_model(self, model, conf_file, model_dir, mode): def test_inferring_pretrained_model_api(self, model, conf_file, model_dir, mode): if 'IP' in mode: - self.infer_api(test_configs_path / conf_file) + self.infer_api(test_configs_path / conf_file, PARAMS[model][(conf_file, model_dir, mode)]) else: pytest.skip("Unsupported mode: {}".format(mode))