-
Notifications
You must be signed in to change notification settings - Fork 3
/
main.py
307 lines (277 loc) · 13.5 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
import argparse
import time
import math
import os
import torch
import sys
import data
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler
import sparse_rnn_core
from sparse_rnn_core import Masking, CosineDecay
from models import RHN, Stacked_LSTM
from Sparse_ASGD import Sparse_ASGD
parser = argparse.ArgumentParser(description='PyTorch implementation of Selfish-RNN')
parser.add_argument('--data', type=str, default='data/penn/',
                    help='location of the data corpus')
parser.add_argument('--model', type=str, default='LSTM',
                    help='type of recurrent net (RHN, LSTM)')
parser.add_argument('--evaluate', default='', type=str, metavar='PATH',
                    help='path to pre-trained model (default: none)')
parser.add_argument('--emsize', type=int, default=1500,
                    help='size of word embeddings')
parser.add_argument('--nhid', type=int, default=1500,
                    help='number of hidden units per layer')
# Used by the SGD -> ASGD switch: training switches once the validation loss is
# worse than the best loss recorded more than `nonmono` epochs earlier.
parser.add_argument('--nonmono', type=int, default=5,
                    help='non-monotone interval: switch from SGD to ASGD when the validation '
                         'loss is worse than the best loss from more than this many epochs ago')
parser.add_argument('--nlayers', type=int, default=2,
                    help='number of layers')
parser.add_argument('--nrecurrence_depth', type=int, default=10,
                    help='recurrence depth of the RHN model')
# NOTE(review): --momentum is parsed but not passed to any optimizer built in this script.
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--beta', type=float, default=1,
                    help='beta slowness regularization applied on RNN activation (beta = 0 means no regularization)')
parser.add_argument('--finetuning', type=int, default=100,
                    help='When (which epochs) to switch to finetuning')
parser.add_argument('--lr', type=float, default=15,
                    help='initial learning rate')
parser.add_argument('--clip', type=float, default=0.25,
                    help='gradient clipping')
parser.add_argument('--epochs', type=int, default=40,
                    help='upper epoch limit')
parser.add_argument('--batch_size', type=int, default=20, metavar='N',
                    help='batch size')
parser.add_argument('--eval_batch_size', type=int, default=20,
                    help='batch size used for validation and test evaluation')
parser.add_argument('--bptt', type=int, default=35,
                    help='sequence length')
parser.add_argument('--dropout', type=float, default=0.65,
                    help='dropout applied to layers (0 = no dropout)')
parser.add_argument('--dropouth', type=float, default=0.25,
                    help='dropout for rnn hidden units (0 = no dropout)')
parser.add_argument('--dropouti', type=float, default=0.65,
                    help='dropout for input embedding layers (0 = no dropout)')
parser.add_argument('--dropoute', type=float, default=0.2,
                    help='dropout to remove words from embedding layer (0 = no dropout)')
parser.add_argument('--tied', action='store_true',
                    help='tie the word embedding and softmax weights')
parser.add_argument('--couple', action='store_true',
                    help='couple the transform and carry weights')
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')
parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')
parser.add_argument('--log-interval', type=int, default=200, metavar='N',
                    help='report interval')
parser.add_argument('--wdecay', type=float, default=1.2e-6,
                    help='weight decay applied to all weights')
parser.add_argument('--optimizer', type=str, default='sgd',
                    help='optimizer to use (sgd, adam)')
parser.add_argument('--testmodel', type=str, default='15877701169307067.pt',
                    help='the name of saved model')
# Default checkpoint name derived from the current timestamp, so runs do not overwrite each other.
randomhash = ''.join(str(time.time()).split('.'))
parser.add_argument('--save', type=str, default=randomhash + '.pt',
                    help='path to save the final model')
# Sparse-training options (--sparse, --density, --death, ...) are registered by sparse_rnn_core.
sparse_rnn_core.add_sparse_args(parser)
args = parser.parse_args()
print(args)

torch.manual_seed(args.seed)
if torch.cuda.is_available():
    if not args.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")
elif args.cuda:
    # Fail fast with a clear CLI error instead of crashing later on `.to(device)`.
    parser.error('--cuda was requested but no CUDA device is available')
device = torch.device("cuda" if args.cuda else "cpu")
def model_save(fn):
    """Serialize the module-level ``model``'s ``state_dict`` to the path ``fn``."""
    weights = model.state_dict()
    torch.save(weights, fn)
def count_parameters(model):
    """Return the total element count of ``model``'s trainable (requires_grad) parameters."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total
def repackage_hidden(h):
    """Detach hidden state(s) from their autograd history.

    Accepts a single tensor or an arbitrarily nested iterable of tensors and
    returns the same nesting with every tensor detached. Non-tensor containers
    are always returned as tuples.
    """
    if not isinstance(h, torch.Tensor):
        return tuple(map(repackage_hidden, h))
    return h.detach()
###############################################################################
# Load data
###############################################################################
# Build the corpus from the directory given by --data. Code below reads
# `corpus.train`, `corpus.valid`, `corpus.test` (passed to batchify, so presumably
# token-id tensors — confirm in data.py) and `len(corpus.dictionary)` as the vocab size.
corpus = data.Corpus(args.data)
def batchify(data, bsz):
    """Lay a sequential token stream out as ``bsz`` independent columns.

    The stream is trimmed so its length is a multiple of ``bsz`` (remainders
    are dropped), then arranged column-wise. With the alphabet as the sequence
    and ``bsz == 4`` the result is::

        ┌ a g m s ┐
        │ b h n t │
        │ c i o u │
        │ d j p v │
        │ e k q w │
        └ f l r x ┘

    The model treats the columns as independent, so e.g. the dependence of
    'g' on 'f' cannot be learned, but batches can be processed efficiently.
    The returned tensor is contiguous and moved to the module-level ``device``.
    """
    rows = data.size(0) // bsz
    usable = data.narrow(0, 0, rows * bsz)
    columns = usable.view(bsz, -1).t().contiguous()
    return columns.to(device)
def get_batch(source, i):
    """Cut one BPTT window out of ``source`` starting at row ``i``.

    Returns ``(inputs, targets)``. ``inputs`` holds up to ``args.bptt`` rows
    (fewer near the end of ``source``). ``targets`` is the same window shifted
    one row ahead, flattened with ``view(-1)``.
    """
    end = i + min(args.bptt, len(source) - 1 - i)
    inputs = source[i:end]
    targets = source[i + 1:end + 1].view(-1)
    return inputs, targets
# Column-batch each split: training uses --batch_size, while validation and
# test both use --eval_batch_size.
train_data, val_data, test_data = (
    batchify(corpus.train, args.batch_size),
    batchify(corpus.valid, args.eval_batch_size),
    batchify(corpus.test, args.eval_batch_size),
)
###############################################################################
# Build the model
###############################################################################
ntokens = len(corpus.dictionary)
if args.model == 'RHN':
    model = RHN(vocab_sz=ntokens, embedding_dim=args.emsize, hidden_dim=args.nhid,
                recurrence_depth=args.nrecurrence_depth, num_layers=args.nlayers,
                input_dp=args.dropouti, output_dp=args.dropout, hidden_dp=args.dropouth,
                embed_dp=args.dropoute, tie_weights=args.tied, couple=args.couple).to(device)
elif args.model == 'LSTM':
    # NOTE(review): only --dropout is forwarded here; --dropouth/--dropouti/--dropoute
    # are consumed by the RHN branch only.
    model = Stacked_LSTM(args.model, ntokens, args.emsize, args.nhid, args.nlayers,
                         args.dropout, args.tied).to(device)
else:
    # Fail fast: otherwise `model` would be unbound and the script would crash
    # later with an unhelpful NameError.
    raise ValueError("Unknown --model '{}'; expected 'RHN' or 'LSTM'".format(args.model))

criterion = nn.CrossEntropyLoss()
###############################################################################
# Train and evaluate code
###############################################################################
def evaluate(data_source):
    """Return the average per-token cross-entropy of ``model`` on ``data_source``.

    Switches the model to eval mode (disabling dropout) and walks
    ``data_source`` in ``args.bptt``-row windows without gradients, carrying
    the hidden state from one window to the next. Each window's mean loss is
    weighted by its row count, and the sum is normalized by
    ``len(data_source) - 1``, the number of rows that have a target.
    """
    model.eval()
    vocab_size = len(corpus.dictionary)
    state = model.init_hidden(args.eval_batch_size)
    weighted_loss = 0.
    with torch.no_grad():
        for start in range(0, data_source.size(0) - 1, args.bptt):
            inputs, targets = get_batch(data_source, start)
            logits, state = model(inputs, state)
            state = repackage_hidden(state)
            window_loss = criterion(logits.view(-1, vocab_size), targets).item()
            weighted_loss += len(inputs) * window_loss
    return weighted_loss / (len(data_source) - 1)
def train(mask=None):
    """Run one training epoch of the module-level ``model`` over ``train_data``.

    Relies on module-level globals: ``model``, ``optimizer``, ``criterion``,
    ``corpus``, ``train_data``, ``args`` and ``epoch`` (the last only for logging).

    Args:
        mask: optional sparse-training ``Masking`` object. When given,
            ``mask.step()`` is called after each backward pass in place of
            ``optimizer.step()``.
    """
    # Turn on training mode which enables dropout.
    model.train()
    ntokens = len(corpus.dictionary)
    batch_starts = range(0, train_data.size(0) - 1, args.bptt)
    num_batches = len(batch_starts)
    # Both supported models (RHN, LSTM) are recurrent, so a hidden state is always
    # needed; the former `args.model != 'Transformer'` guard could leave it unbound.
    hidden = model.init_hidden(args.batch_size)
    total_loss = 0.
    batches_since_log = 0
    start_time = time.time()
    for batch, i in enumerate(batch_starts):
        data, targets = get_batch(train_data, i)
        optimizer.zero_grad()
        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to start of the dataset.
        hidden = repackage_hidden(hidden)
        output, hidden = model(data, hidden)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        if mask is not None:
            mask.step()
        else:
            optimizer.step()
        total_loss += loss.item()
        batches_since_log += 1
        if batch % args.log_interval == 0 and batch > 0:
            # Average over the batches actually accumulated. The first interval
            # spans batches 0..log_interval, i.e. log_interval + 1 batches.
            cur_loss = total_loss / batches_since_log
            elapsed = time.time() - start_time
            # Report the optimizer's live learning rate. It changes under the
            # LR scheduler, after the ASGD switch, and differs from --lr for Adam.
            cur_lr = optimizer.param_groups[0]['lr']
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:.3g} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                      epoch, batch, num_batches, cur_lr,
                      elapsed * 1000 / batches_since_log, cur_loss, math.exp(cur_loss)))
            total_loss = 0.
            batches_since_log = 0
            start_time = time.time()
            sys.stdout.flush()
###############################################################################
# Training
###############################################################################
if args.evaluate:
    # Evaluation-only mode: load a checkpoint and report test loss/perplexity.
    print("=> loading checkpoint '{}'".format(args.evaluate))
    # map_location lets a checkpoint saved on GPU be evaluated on CPU (and vice versa).
    # NOTE: torch.load unpickles its input; only load trusted checkpoints.
    model.load_state_dict(torch.load(args.evaluate, map_location=device))
    print('=> testing...')
    test_loss = evaluate(test_data)
    print('=' * 89)
    print('| Final test | test loss {:5.2f} | test ppl {:8.2f}'.format(
        test_loss, math.exp(test_loss)))
    print('=' * 89)
    sys.stdout.flush()
else:
    mask = None
    lr = args.lr  # initial learning rate, kept as a module-level global
    best_val_loss = []          # per-epoch validation losses used by the ASGD switch
    stored_loss = float('inf')  # best validation loss so far; gates checkpointing
    # At any point you can hit Ctrl + C to break out of training early.
    try:
        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wdecay)
        elif args.optimizer == 'adam':
            # NOTE(review): no lr is passed, so Adam uses PyTorch's default lr, not --lr.
            optimizer = torch.optim.Adam(model.parameters(), betas=(0, 0.999), eps=1e-9,
                                         weight_decay=args.wdecay)
        else:
            raise ValueError("Unknown --optimizer '{}'; expected 'sgd' or 'adam'".format(
                args.optimizer))
        # Only stepped for Adam (see below); halves the lr after 2 epochs without improvement.
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', 0.5, patience=2, threshold=0)
        if args.sparse:
            decay = CosineDecay(args.death_rate, args.epochs * len(train_data) // args.bptt)
            mask = Masking(optimizer, death_rate=args.death_rate, death_mode=args.death,
                           death_rate_decay=decay, growth_mode=args.growth,
                           redistribution_mode=args.redistribution, model=args.model)
            mask.add_module(model, sparse_init=args.sparse_init, density=args.density)
        # Loop over epochs.
        for epoch in range(1, args.epochs + 1):
            epoch_start_time = time.time()
            train(mask)
            if 't0' in optimizer.param_groups[0]:
                # ASGD phase: validate with the averaged weights stored in the optimizer
                # state under 'ax', then restore the raw weights to keep training.
                tmp = {}
                for prm in model.parameters():
                    tmp[prm] = prm.data.clone()
                    prm.data = optimizer.state[prm]['ax'].clone()
                val_loss2 = evaluate(val_data)
                print('-' * 89)
                print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
                      'valid ppl {:8.2f} | valid bpc {:8.3f}'.format(
                          epoch, (time.time() - epoch_start_time), val_loss2,
                          math.exp(val_loss2), val_loss2 / math.log(2)))
                print('-' * 89)
                if val_loss2 < stored_loss:
                    model_save(args.save)
                    print('Saving Averaged!')
                    stored_loss = val_loss2
                for prm in model.parameters():
                    prm.data = tmp[prm]
                if args.sparse:
                    mask.at_end_of_epoch(epoch)
            else:
                val_loss = evaluate(val_data)
                print('-' * 89)
                print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
                      'valid ppl {:8.2f} | valid bpc {:8.3f}'.format(
                          epoch, (time.time() - epoch_start_time), val_loss,
                          math.exp(val_loss), val_loss / math.log(2)))
                print('-' * 89)
                if val_loss < stored_loss:
                    model_save(args.save)
                    print('Saving model (new best validation)')
                    stored_loss = val_loss
                if args.optimizer == 'adam':
                    scheduler.step(val_loss)
                # Non-monotone trigger: switch SGD -> ASGD once the current loss is worse
                # than the best loss recorded more than `nonmono` epochs ago.
                if (args.optimizer == 'sgd' and 't0' not in optimizer.param_groups[0]
                        and len(best_val_loss) > args.nonmono
                        and val_loss > min(best_val_loss[:-args.nonmono])):
                    print('Switching to ASGD')
                    optimizer = Sparse_ASGD(model.parameters(), lr=args.lr, t0=0, lambd=0.,
                                            weight_decay=args.wdecay)
                    # Only sparse runs have a mask; dense runs previously crashed here.
                    if mask is not None:
                        mask.optimizer = optimizer
                        mask.init_optimizer_mask()
                if args.sparse and 't0' not in optimizer.param_groups[0]:
                    mask.at_end_of_epoch(epoch)
                best_val_loss.append(val_loss)
            print("PROGRESS: {}%".format((epoch / args.epochs) * 100))
    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early')

    # Load the best saved model, if training got far enough to save one.
    if not os.path.exists(args.save):
        print("=> no checkpoint found at '{}'; skipping test evaluation".format(args.save))
    else:
        model.load_state_dict(torch.load(args.save, map_location=device))
        # Run on test data.
        test_loss = evaluate(test_data)
        print('=' * 89)
        print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
            test_loss, math.exp(test_loss)))
        print('=' * 89)
    sys.stdout.flush()