Skip to content

Commit

Permalink
add an en_output_size param in seq2seq
Browse files Browse the repository at this point in the history
  • Loading branch information
OuyangWenyu committed Nov 1, 2024
1 parent 6c2f228 commit d6fcf47
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 42 deletions.
33 changes: 12 additions & 21 deletions tests/test_seq2seq.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Author: Wenyu Ouyang
Date: 2024-04-17 12:55:24
LastEditTime: 2024-10-31 10:17:40
LastEditTime: 2024-11-01 12:02:49
LastEditors: Wenyu Ouyang
Description: Test funcs for seq2seq model
FilePath: \torchhydro\tests\test_seq2seq.py
Expand Down Expand Up @@ -147,35 +147,26 @@ def test_seq2seq(config):
@pytest.fixture
def model():
return GeneralSeq2Seq(
en_input_size=10,
de_input_size=10,
output_size=5,
en_input_size=2,
de_input_size=3,
output_size=2,
hidden_size=20,
forecast_length=5,
prec_window=2,
prec_window=10,
teacher_forcing_ratio=0.5,
)


def test_forward_no_teacher_forcing(model):
src1 = torch.randn(3, 10, 10)
src2 = torch.randn(3, 5, 10)
src1 = torch.randn(3, 10, 2)
src2 = torch.randn(3, 5, 1)
outputs = model(src1, src2)
assert outputs.shape == (3, 6, 5)
assert outputs.shape == (3, 6, 2)


def test_forward_with_teacher_forcing(model):
src1 = torch.randn(3, 10, 10)
src2 = torch.randn(3, 5, 10)
trgs = torch.randn(3, 7, 5)
src1 = torch.randn(3, 10, 2)
src2 = torch.randn(3, 5, 1)
trgs = torch.randn(3, 15, 2)
outputs = model(src1, src2, trgs)
assert outputs.shape == (3, 6, 5)


def test_forward_with_nan_in_trgs(model):
src1 = torch.randn(3, 10, 10)
src2 = torch.randn(3, 5, 10)
trgs = torch.randn(3, 7, 5)
trgs[0, 3, 1] = float("nan")
outputs = model(src1, src2, trgs)
assert outputs.shape == (3, 6, 5)
assert outputs.shape == (3, 6, 2)
76 changes: 55 additions & 21 deletions torchhydro/models/seq2seq.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""
Author: Wenyu Ouyang
Date: 2024-04-17 12:32:26
LastEditTime: 2024-04-17 12:33:34
LastEditors: Xinzhuo Wu
LastEditTime: 2024-11-01 12:01:16
LastEditors: Wenyu Ouyang
Description:
FilePath: /torchhydro/torchhydro/models/seq2seq.py
FilePath: \torchhydro\torchhydro\models\seq2seq.py
Copyright (c) 2021-2024 Wenyu Ouyang. All rights reserved.
"""

Expand Down Expand Up @@ -70,11 +70,15 @@ def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1, dropout=0.3)
self.fc = nn.Linear(hidden_dim, output_dim)

def forward(self, x):
x = self.pre_fc(x)
x = self.pre_relu(x)
outputs, (hidden, cell) = self.lstm(x)
outputs = self.dropout(outputs)
outputs = self.fc(outputs)
# a nonlinear layer to transform the input
x0 = self.pre_fc(x)
x1 = self.pre_relu(x0)
# the LSTM layer
outputs_, (hidden, cell) = self.lstm(x1)
# a dropout layer
dr_outputs = self.dropout(outputs_)
# final linear layer
outputs = self.fc(dr_outputs)
return outputs, hidden, cell


Expand All @@ -89,12 +93,12 @@ def __init__(self, input_dim, output_dim, hidden_dim, num_layers=1, dropout=0.3)
self.fc_out = nn.Linear(hidden_dim, output_dim)

def forward(self, input, hidden, cell):
x = self.pre_fc(input)
x = self.pre_relu(x)
output, (hidden, cell) = self.lstm(x, (hidden, cell))
output = self.dropout(output)
output = self.fc_out(output)
return output, hidden, cell
x0 = self.pre_fc(input)
x1 = self.pre_relu(x0)
output_, (hidden_, cell_) = self.lstm(x1, (hidden, cell))
output_dr = self.dropout(output_)
output = self.fc_out(output_dr)
return output, hidden_, cell_


class StateTransferNetwork(nn.Module):
Expand All @@ -119,11 +123,35 @@ def __init__(
forecast_length,
prec_window=0,
teacher_forcing_ratio=0.5,
en_output_size=1,
):
"""General Seq2Seq model
Parameters
----------
en_input_size : _type_
the size of the input of the encoder
de_input_size : _type_
the size of the input of the decoder
output_size : _type_
the size of the output, same for encoder and decoder
hidden_size : _type_
the size of the hidden state of LSTMs
forecast_length : _type_
the length of the forecast, i.e., the periods of decoder outputs
prec_window : int, optional
starting index of decoder output for teacher forcing; default is 0
teacher_forcing_ratio : float, optional
the probability of using teacher forcing
en_output_size : int, optional
the encoder's final several outputs in the final output;
default is 1 which means the final encoder output is included in the final output
"""
super(GeneralSeq2Seq, self).__init__()
self.trg_len = forecast_length
self.prec_window = prec_window
self.teacher_forcing_ratio = teacher_forcing_ratio
self.en_output_size = en_output_size
self.encoder = Encoder(
input_dim=en_input_size, hidden_dim=hidden_size, output_dim=output_size
)
Expand All @@ -134,25 +162,30 @@ def __init__(

def forward(self, *src):
if len(src) == 3:
src1, src2, trgs = src
encoder_input, decoder_input, trgs = src
else:
src1, src2 = src
encoder_input, decoder_input = src
trgs = None
encoder_outputs, hidden, cell = self.encoder(src1)
hidden, cell = self.transfer(hidden, cell)
encoder_outputs, hidden_, cell_ = self.encoder(encoder_input)
hidden, cell = self.transfer(hidden_, cell_)
outputs = []
current_input = encoder_outputs[:, -1, :].unsqueeze(1)

for t in range(self.trg_len):
p = src2[:, t, :].unsqueeze(1)
p = decoder_input[:, t, :].unsqueeze(1)
current_input = torch.cat((current_input, p), dim=2)
output, hidden, cell = self.decoder(current_input, hidden, cell)
outputs.append(output.squeeze(1))
if trgs is None or self.teacher_forcing_ratio <= 0:
current_input = output
else:
sm_trg = trgs[:, (self.prec_window + t), 1].unsqueeze(1).unsqueeze(1)
# most of soil moisture from remote sensing are not nan,
# so if we meet nan values, we just ignore the teacher forcing
# for streamflow, there are always some ungauged stations,
# so we just use ssm to choose if we use teacher forcing
if not torch.any(torch.isnan(sm_trg)).item():
# random choice of using teacher forcing with probability of teacher_forcing_ratio
use_teacher_forcing = random.random() < self.teacher_forcing_ratio
str_trg = output[:, :, 0].unsqueeze(2)
current_input = (
Expand All @@ -164,8 +197,9 @@ def forward(self, *src):
current_input = output

outputs = torch.stack(outputs, dim=1)
prec_outputs = encoder_outputs[:, -self.prec_window, :].unsqueeze(1)
outputs = torch.cat((prec_outputs, outputs), dim=1)
if self.en_output_size > 0:
prec_outputs = encoder_outputs[:, -self.en_output_size :, :]
outputs = torch.cat((prec_outputs, outputs), dim=1)
return outputs


Expand Down

0 comments on commit d6fcf47

Please sign in to comment.