Skip to content

Commit

Permalink
fix teacher forcing bug
Browse files Browse the repository at this point in the history
  • Loading branch information
silencesoup committed Nov 7, 2024
1 parent d879f3f commit f5f97e1
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion torchhydro/models/seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,13 @@ def forward(self, *src):
use_teacher_forcing = (
random_vals < self.teacher_forcing_ratio
) * valid_mask
current_input = trg * use_teacher_forcing + output * (~use_teacher_forcing)
current_input = torch.where(
torch.isnan(trg), # if trg is nan
output, # then use output
trg * use_teacher_forcing
+ output
* (~use_teacher_forcing), # else calculate with teacher forcing
)

outputs = torch.stack(outputs, dim=1)
if self.prec_window > 0:
Expand Down

0 comments on commit f5f97e1

Please sign in to comment.