# this file is based on code publicly available at
# https://github.com/locuslab/smoothing
# written by Jeremy Cohen.
import argparse
import time
import torch
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from architectures import ARCHITECTURES
from datasets import DATASETS
from train_utils import AverageMeter, accuracy, log, test
from train_utils import prologue

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

parser = argparse.ArgumentParser(description='PyTorch stability training')
parser.add_argument('dataset', type=str, choices=DATASETS)
parser.add_argument('arch', type=str, choices=ARCHITECTURES)
parser.add_argument('--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--batch', default=256, type=int, metavar='N',
                    help='batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    help='initial learning rate', dest='lr')
parser.add_argument('--lr_step_size', type=int, default=30,
                    help='how often (in epochs) to decrease the learning rate by gamma')
parser.add_argument('--gamma', type=float, default=0.1,
                    help='LR is multiplied by gamma on schedule')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--noise_sd', default=0.0, type=float,
                    help='standard deviation of the Gaussian noise used for data augmentation')
parser.add_argument('--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--id', default=None, type=int,
                    help='experiment id, `randint(10000)` if None')

#####################
# Options added by Salman et al. (2019)
parser.add_argument('--resume', action='store_true',
                    help='if true, tries to resume training from an existing checkpoint')
parser.add_argument('--pretrained-model', type=str, default='',
                    help='path to a pretrained model')

#####################
# Stability training hyperparameter (weight of the stability regularizer)
parser.add_argument('--lbd', default=2.0, type=float)

args = parser.parse_args()
args.outdir = f"logs/{args.dataset}/stab/lbd_{args.lbd}/noise_{args.noise_sd}"
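
# Example invocation (a sketch; 'cifar10' and 'cifar_resnet110' are assumed here to be
# valid entries of DATASETS and ARCHITECTURES, which are defined elsewhere in this repo,
# and the hyperparameter values are illustrative only):
#   python train_stab.py cifar10 cifar_resnet110 --lbd 2.0 --noise_sd 0.25 --epochs 150
# Logs and the checkpoint would then land under:
#   logs/cifar10/stab/lbd_2.0/noise_0.25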


def _cross_entropy(input, targets, reduction='mean'):
    """Soft cross-entropy between two sets of logits.

    Treats `softmax(targets)` as the target distribution and computes its
    cross-entropy against `log_softmax(input)`. This is the stability term
    that pulls predictions on noisy inputs toward predictions on clean inputs.
    """
    targets_prob = F.softmax(targets, dim=1)
    xent = (-targets_prob * F.log_softmax(input, dim=1)).sum(1)
    if reduction == 'sum':
        return xent.sum()
    elif reduction == 'mean':
        return xent.mean()
    elif reduction == 'none':
        return xent
    else:
        raise NotImplementedError()
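

# Illustrative sketch, not called anywhere in this script: with the argument order
# used in `train` below (`_cross_entropy(noisy_logits, clean_logits)`), the term is
# H(softmax(clean), softmax(noisy)) = KL(clean || noisy) + H(clean), so it is
# smallest when the noisy-input prediction matches the clean-input prediction.
def _stability_term_example():
    clean = torch.tensor([[2.0, 0.0, -2.0]])
    noisy = torch.tensor([[1.5, 0.2, -1.7]])
    matched = _cross_entropy(clean, clean)    # equals the entropy of softmax(clean)
    perturbed = _cross_entropy(noisy, clean)  # always >= `matched`
    return matched, perturbed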


def main():
    train_loader, test_loader, criterion, model, optimizer, scheduler, \
        starting_epoch, logfilename, model_path, device, writer = prologue(args)

    for epoch in range(starting_epoch, args.epochs):
        before = time.time()
        train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch,
                                      args.noise_sd, device, writer)
        test_loss, test_acc = test(test_loader, model, criterion, epoch, args.noise_sd,
                                   device, writer, args.print_freq)
        after = time.time()

        log(logfilename, "{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
            epoch, after - before,
            scheduler.get_lr()[0], train_loss, train_acc, test_loss, test_acc))

        # In PyTorch 1.1.0 and later, you should call `optimizer.step()` before `lr_scheduler.step()`.
        # See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
        scheduler.step(epoch)

        torch.save({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, model_path)
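

# Hedged sketch (not part of this script): a checkpoint written by `main` could be
# restored along these lines, assuming the same architecture has been constructed
# again (the model factory lives elsewhere in the repo and is not shown here):
#   checkpoint = torch.load(model_path)
#   model.load_state_dict(checkpoint['state_dict'])
#   optimizer.load_state_dict(checkpoint['optimizer'])
#   starting_epoch = checkpoint['epoch']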


def train(loader: DataLoader, model: torch.nn.Module, criterion, optimizer: Optimizer,
          epoch: int, noise_sd: float, device: torch.device, writer=None):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_reg = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # switch to train mode
    model.train()

    for i, (inputs, targets) in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        inputs, targets = inputs.to(device), targets.to(device)
        batch_size = inputs.size(0)

        # augment inputs with Gaussian noise
        noise = torch.randn_like(inputs, device=device) * noise_sd

        logits = model(inputs)
        logits_n = model(inputs + noise)

        # stability training objective: standard cross-entropy on the clean inputs
        # plus lbd times the cross-entropy between the clean and noisy predictive
        # distributions (the stability regularizer)
        loss_xent = criterion(logits, targets)
        stab = _cross_entropy(logits_n, logits)
        loss = loss_xent + args.lbd * stab

        # accuracy is measured on the noisy inputs
        acc1, acc5 = accuracy(logits_n, targets, topk=(1, 5))
        losses.update(loss_xent.item(), batch_size)
        losses_reg.update(stab.item(), batch_size)
        top1.update(acc1.item(), batch_size)
        top5.update(acc5.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.avg:.3f}\t'
                  'Data {data_time.avg:.3f}\t'
                  'Loss {loss.avg:.4f}\t'
                  'Acc@1 {top1.avg:.3f}\t'
                  'Acc@5 {top5.avg:.3f}'.format(
                      epoch, i, len(loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))

    if writer:
        writer.add_scalar('loss/train', losses.avg, epoch)
        writer.add_scalar('loss/stability', losses_reg.avg, epoch)
        writer.add_scalar('batch_time', batch_time.avg, epoch)
        writer.add_scalar('accuracy/train@1', top1.avg, epoch)
        writer.add_scalar('accuracy/train@5', top5.avg, epoch)

    return (losses.avg, top1.avg)


if __name__ == "__main__":
    main()