Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added mtl to quickstart #1640

Merged
merged 5 commits into from
Jun 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion deeppavlov/configs/multitask/mt_glue.json
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,8 @@
"log_every_n_epochs": 1,
"show_examples": false,
"evaluation_targets": ["valid"],
"class_name": "torch_trainer"
"class_name": "torch_trainer",
"pytest_max_batches": 2
},
"metadata": {
"variables": {
Expand Down
3 changes: 2 additions & 1 deletion deeppavlov/configs/multitask/multitask_example.json
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@
"log_every_n_epochs": 1,
"show_examples": false,
"evaluation_targets": ["valid"],
"class_name": "torch_trainer"
"class_name": "torch_trainer",
"pytest_max_batches": 2
},
"metadata": {
"variables": {
Expand Down
3 changes: 0 additions & 3 deletions deeppavlov/models/torch_bert/multitask_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import numpy as np
import torch
import torch.nn as nn
from overrides import overrides
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
from transformers import AutoConfig, AutoModel

Expand Down Expand Up @@ -342,7 +341,6 @@ def __init__(
def _reset_cache(self):
self.preds_cache = {index_: None for index_ in self.types_to_cache if index_ != -1}

@overrides
def init_from_opt(self) -> None:
"""
Initialize from scratch `self.model` with the architecture built
Expand Down Expand Up @@ -401,7 +399,6 @@ def get_decay_params(model): return [
torch.optim.lr_scheduler, self.lr_scheduler_name
)(self.optimizer, **self.lr_scheduler_parameters)

@overrides
def load(self, fname: Optional[str] = None) -> None:
"""
Loads weights.
Expand Down
21 changes: 12 additions & 9 deletions tests/test_quick_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,12 @@
("russian_super_glue/russian_superglue_parus_rubert.json", "russian_super_glue", ('IP',)): [LIST_ARGUMENTS_INFER_CHECK],
("russian_super_glue/russian_superglue_rucos_rubert.json", "russian_super_glue", ('IP',)): [RECORD_ARGUMENTS_INFER_CHECK]
},
"multitask":{
("multitask/multitask_example.json", "multitask", ALL_MODES): [
('Dummy text',) + (('Dummy text', 'Dummy text'),) * 3 + ('Dummy text',) + (None,)],
("multitask/mt_glue.json", "multitask", ALL_MODES): [
('Dummy text',) * 2 + (('Dummy text', 'Dummy text'),) * 6 + (None,)]
},
"entity_extraction": {
("entity_extraction/entity_detection_en.json", "entity_extraction", ('IP',)):
[
Expand Down Expand Up @@ -399,7 +405,8 @@ def infer(config_path, qr_list=None, check_outputs=True):
raise RuntimeError(f'Unexpected results for {config_path}: {errors}')

@staticmethod
def infer_api(config_path):
def infer_api(config_path, qr_list):
*inputs, expected_outputs = zip(*qr_list)
server_params = get_server_params(config_path)

url_base = 'http://{}:{}'.format(server_params['host'], api_port or server_params['port'])
Expand All @@ -422,14 +429,10 @@ def infer_api(config_path):
assert response_code == 200, f"GET /api request returned error code {response_code} with {config_path}"

model_args_names = get_response.json()['in']
post_payload = dict()
for arg_name in model_args_names:
arg_value = ' '.join(['qwerty'] * 10)
post_payload[arg_name] = [arg_value]
post_payload = dict(zip(model_args_names, inputs))
# TODO: remove this if from here and socket
if 'parus' in str(config_path):
post_payload = {k: [v] for k, v in post_payload.items()}

if 'docred' in str(config_path) or 'rured' in str(config_path):
post_payload = {k: v[0] for k, v in post_payload.items()}
post_response = requests.post(url, json=post_payload, headers=post_headers)
response_code = post_response.status_code
assert response_code == 200, f"POST request returned error code {response_code} with {config_path}"
Expand Down Expand Up @@ -519,7 +522,7 @@ def test_inferring_pretrained_model(self, model, conf_file, model_dir, mode):

def test_inferring_pretrained_model_api(self, model, conf_file, model_dir, mode):
if 'IP' in mode:
self.infer_api(test_configs_path / conf_file)
self.infer_api(test_configs_path / conf_file, PARAMS[model][(conf_file, model_dir, mode)])
else:
pytest.skip("Unsupported mode: {}".format(mode))

Expand Down