Skip to content

Commit

Permalink
Merge pull request #22 from alexa/jgmf-eval-fix-3
Browse files Browse the repository at this point in the history
fixes for validation engine and for using torchrun
  • Loading branch information
jgmf-amazon authored Jun 22, 2022
2 parents 7175898 + 085111d commit 3932705
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 0 deletions.
3 changes: 3 additions & 0 deletions scripts/hpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import argparse
import datetime
import logging
import os
import sys

import datasets
Expand Down Expand Up @@ -50,6 +51,8 @@ def main():
trainer_args = MASSIVETrainingArguments(**conf.get('train_val.trainer_args'))
if args.local_rank:
trainer_args.local_rank = int(args.local_rank)
elif os.getenv('LOCAL_RANK'):
trainer_args.local_rank = int(os.environ['LOCAL_RANK'])

# Setup logging
logging.basicConfig(
Expand Down
3 changes: 3 additions & 0 deletions scripts/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import argparse
import datetime
import logging
import os
import pprint
import sys
import time
Expand Down Expand Up @@ -53,6 +54,8 @@ def main():
trainer_args = MASSIVETrainingArguments(**conf.get('test.trainer_args'))
if args.local_rank:
trainer_args.local_rank = int(args.local_rank)
elif os.getenv('LOCAL_RANK'):
trainer_args.local_rank = int(os.environ['LOCAL_RANK'])

# Setup logging
logging.basicConfig(
Expand Down
3 changes: 3 additions & 0 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import argparse
import datetime
import logging
import os
import sys

import datasets
Expand Down Expand Up @@ -49,6 +50,8 @@ def main():
trainer_args = MASSIVETrainingArguments(**conf.get('train_val.trainer_args'))
if args.local_rank:
trainer_args.local_rank = int(args.local_rank)
elif os.getenv('LOCAL_RANK'):
trainer_args.local_rank = int(os.environ['LOCAL_RANK'])

# Setup logging
logging.basicConfig(
Expand Down
5 changes: 5 additions & 0 deletions src/massive/utils/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,11 @@ def eval_preds(pred_intents=None, lab_intents=None, pred_slots=None, lab_slots=N
if type(pred) == list:
pred = pred[:len(lab)] + [pad]*(len(lab) - len(pred))

# Fix for Issue 21 -- subwords after the first one from a word should be ignored
for i, x in enumerate(lab):
if x == -100:
pred[i] = -100

# convert to BIO
bio_slot_labels.append(
convert_to_bio(lab, outside=labels_ignore, labels_merge=labels_merge)
Expand Down

0 comments on commit 3932705

Please sign in to comment.