-
Notifications
You must be signed in to change notification settings - Fork 7
/
rnn.py
53 lines (46 loc) · 1.96 KB
/
rnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
from torch import nn
class RNN(nn.Module):
    """Recurrent sequence classifier: an RNN/GRU/LSTM encoder followed by
    dropout and a shared linear (FC) layer applied at every time step.

    Args:
        input_size: number of input features per time step.
        hidden_size: recurrent hidden-state size.
        num_layers: number of stacked recurrent layers.
        class_size: number of output classes per time step.
        dropout: dropout probability applied before the FC layer.
        rnn_type: one of 'lstm', 'rnn', 'gru'.

    Raises:
        NotImplementedError: if ``rnn_type`` is not a supported value.
    """

    def __init__(self, input_size, hidden_size, num_layers, class_size, dropout=0.5, rnn_type='lstm'):
        super(RNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.class_size = class_size
        self.num_layers = num_layers
        self.rnn_type = rnn_type
        # All three variants share the same constructor arguments;
        # batch_first=True means input & output are (batch, time_step, feature).
        rnn_kwargs = dict(
            input_size=self.input_size,
            hidden_size=self.hidden_size,  # rnn hidden unit
            num_layers=self.num_layers,    # number of rnn layers
            batch_first=True,
        )
        if self.rnn_type == 'lstm':
            self.rnn = nn.LSTM(**rnn_kwargs)
        elif self.rnn_type == 'rnn':
            self.rnn = nn.RNN(**rnn_kwargs)
        elif self.rnn_type == 'gru':
            self.rnn = nn.GRU(**rnn_kwargs)
        else:
            raise NotImplementedError
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(self.hidden_size, self.class_size)  # FC layer in our paper

    def forward(self, x):
        """Encode the sequence and classify every time step.

        Args:
            x: tensor of shape (batch, time_step, input_size).

        Returns:
            Tensor of shape (batch, time_step, class_size).
        """
        # Fix: the original built h0/c0 with bare torch.zeros(...), which
        # lives on the CPU with the default dtype and crashes as soon as the
        # model/input are on a GPU (or a non-default dtype). PyTorch RNNs
        # default to zero initial states on the correct device/dtype when no
        # state is passed, so we simply omit it.
        r_out, _ = self.rnn(x)
        # nn.Linear acts on the last dimension, so dropout + FC can be
        # applied to the whole (batch, time, hidden) tensor at once instead
        # of a per-time-step Python loop (identical output in eval mode).
        return self.out(self.dropout(r_out))