Merge pull request #17 from salman-moh/main
Added `model.eval()` and `torch.no_grad()` to disable dropout/gradient calculation during inference
SGenheden authored Nov 23, 2022
2 parents 190dcb5 + 72ef3d9 commit 0ca3c1b
Showing 1 changed file with 13 additions and 11 deletions.
example_scripts/finetune_regression/finetuneRegr.py (13 additions, 11 deletions)
```diff
@@ -117,16 +117,18 @@ def get_targs_preds(model, dl):
     targs = []
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
     model.to(device)
-    for i, batch in enumerate(iter(dl)):
-        batch['encoder_input'] = batch['encoder_input'].to(device)
-        batch['encoder_pad_mask'] = batch['encoder_pad_mask'].to(device)
-        batch['target'] = batch['target'].to(device)
-
-        batch_preds = model(batch).squeeze(dim=1).tolist()
-        batch_targs = batch['target'].squeeze(dim=1).tolist()
-
-        preds.append(batch_preds)
-        targs.append(batch_targs)
+    model.eval()
+    with torch.no_grad():
+        for i, batch in enumerate(iter(dl)):
+            batch['encoder_input'] = batch['encoder_input'].to(device)
+            batch['encoder_pad_mask'] = batch['encoder_pad_mask'].to(device)
+            batch['target'] = batch['target'].to(device)
+
+            batch_preds = model(batch).squeeze(dim=1).tolist()
+            batch_targs = batch['target'].squeeze(dim=1).tolist()
+
+            preds.append(batch_preds)
+            targs.append(batch_targs)
 
     targs = list(itertools.chain.from_iterable(targs))
     preds = list(itertools.chain.from_iterable(preds))
```
```diff
@@ -262,4 +264,4 @@ def main(args):
     parser.add_argument("--limit_val_batches", type=float, default=DEFAULT_LIMIT_VAL_BATCHES)
 
     args = parser.parse_args()
-    main(args)
+    main(args)
```
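
For reference, the general PyTorch inference pattern this commit adopts is sketched below. The `predict` helper, its arguments, and the assumption that each batch is a single input tensor are illustrative only and are not taken from the repository:

```python
import torch

def predict(model, dataloader, device=None):
    """Minimal sketch of gradient-free inference (illustrative, not repository code)."""
    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()                  # put dropout (and batch norm) layers into inference mode
    preds = []
    with torch.no_grad():         # skip autograd bookkeeping to save memory and compute
        for batch in dataloader:  # assumes each batch is a plain input tensor
            batch = batch.to(device)
            preds.extend(model(batch).squeeze(dim=1).tolist())
    return preds
```

Without `model.eval()`, dropout layers keep randomly zeroing activations, so repeated predictions on the same input can differ; without `torch.no_grad()`, PyTorch stores intermediate activations for a backward pass that never happens, inflating memory use during evaluation.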
