import torch
from torch import nn
from torch.nn import functional as F

from Attn import Attn

class AttnDecoderRNN(nn.Module):
    """Decoder RNN that attends over the encoder outputs at every step."""

    def __init__(self, attn_model, hidden_size, output_size, n_layers=1, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout_p = dropout_p

        self.embedding = nn.Embedding(output_size, hidden_size)
        # The GRU consumes the embedded input token concatenated with the
        # previous context vector, hence the hidden_size * 2 input width.
        self.gru = nn.GRU(hidden_size * 2, hidden_size, n_layers, dropout=dropout_p)
        # The output layer sees the RNN output concatenated with the new context.
        self.out = nn.Linear(hidden_size * 2, output_size)

        if attn_model != 'none':
            self.attn = Attn(attn_model, hidden_size)
    def forward(self, word_input, last_context, last_hidden, encoder_outputs):
        # Embed the current input token (one time step, batch size 1).
        word_embedded = self.embedding(word_input).view(1, 1, -1)

        # Feed the embedding together with the previous context vector.
        rnn_input = torch.cat((word_embedded, last_context.unsqueeze(0)), 2)
        rnn_output, hidden = self.gru(rnn_input, last_hidden)

        # Attend over the encoder outputs using the current RNN output,
        # then build the new context as their attention-weighted sum.
        attn_weights = self.attn(rnn_output.squeeze(0), encoder_outputs)
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))

        # Predict the next token from the RNN output and the new context.
        rnn_output = rnn_output.squeeze(0)  # (1, hidden_size)
        context = context.squeeze(1)        # (1, hidden_size)
        output = F.log_softmax(self.out(torch.cat((rnn_output, context), 1)), dim=1)

        return output, context, hidden, attn_weights
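
# The smoke test below is a minimal sketch, not part of the original file. It
# assumes the repo's Attn module is importable and that Attn(method, hidden_size)
# maps a (1, hidden_size) decoder state and (seq_len, 1, hidden_size) encoder
# outputs to attention weights of shape (1, 1, seq_len); those shapes are
# inferred from how forward() uses them above, not confirmed by the source.
if __name__ == '__main__':
    hidden_size, output_size, seq_len = 16, 10, 5
    decoder = AttnDecoderRNN('general', hidden_size, output_size)

    word_input = torch.tensor([0])                    # current target token id
    last_context = torch.zeros(1, hidden_size)        # context from previous step
    last_hidden = torch.zeros(1, 1, hidden_size)      # initial GRU hidden state
    encoder_outputs = torch.randn(seq_len, 1, hidden_size)

    output, context, hidden, attn_weights = decoder(
        word_input, last_context, last_hidden, encoder_outputs)
    print(output.shape, context.shape, hidden.shape, attn_weights.shape)
    # Expected: (1, output_size), (1, hidden_size),
    #           (1, 1, hidden_size), (1, 1, seq_len)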