Skip to content

Commit

Permalink
restore fixes
Browse files — browse the repository at this point in the history
  • Loading branch information
Vectorrent authored and mryab committed Jul 13, 2024
1 parent 14ff472 commit 4ce1181
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions benchmarks/benchmark_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose:
grad_scaler = hivemind.GradScaler()
else:
# check that hivemind.Optimizer supports regular PyTorch grad scaler as well
grad_scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
grad_scaler = torch.amp.GradScaler(enabled=args.use_amp)

prev_time = time.perf_counter()

Expand All @@ -107,7 +107,7 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose:

batch = torch.randint(0, len(X_train), (batch_size,))

with torch.cuda.amp.autocast() if args.use_amp else nullcontext():
with torch.amp.autocast() if args.use_amp else nullcontext():
loss = F.cross_entropy(model(X_train[batch].to(args.device)), y_train[batch].to(args.device))
grad_scaler.scale(loss).backward()

Expand Down

0 comments on commit 4ce1181

Please sign in to comment.