Updates: Aug 9, 2018 night
- Change naming to formatted
- Add time counting
- Fix logger: log info is no longer incorrectly written to the first logger
Updates: Aug 9, 2018
- Now random pixels and shuffle pixels apply to validation set
- Add option to start from a specific epoch, so training does not need to start from the beginning
- Fix the bug of random pixels, change type to uint8
Updates: Modification:
- Add shuffle pixels
- Add random pixels
- Add save models
- Change default learning rate to 0.01
- Now by default, there is no learning rate decay
- Now run different corruption levels in loop, from 0 to 1 with spacing 0.1
- Now loggers have their own names instead of using the root logger
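As a rough illustration of the two pixel corruptions listed above, here is a minimal NumPy sketch. It assumes the definitions used in the paper (one fixed permutation shared by all images for shuffled pixels, an independent permutation per image for random pixels) and an (N, H, W, C) uint8 image layout. The function names `shuffle_pixels` and `random_pixels` are hypothetical, and the code in this repo may implement these differently.

```python
import numpy as np

def shuffle_pixels(images, rng):
    """Apply one fixed random pixel permutation to every image (hypothetical helper)."""
    n, h, w, c = images.shape
    perm = rng.permutation(h * w)          # a single permutation shared by all images
    flat = images.reshape(n, h * w, c)
    return flat[:, perm, :].reshape(n, h, w, c)

def random_pixels(images, rng):
    """Apply an independent random pixel permutation to each image (hypothetical helper)."""
    n, h, w, c = images.shape
    flat = images.reshape(n, h * w, c)
    out = np.empty_like(flat)              # keeps the uint8 dtype of the input
    for i in range(n):
        out[i] = flat[i, rng.permutation(h * w), :]
    return out.reshape(n, h, w, c)

# Stand-in data with the CIFAR-10 image shape; not the real dataset.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)
shuffled = shuffle_pixels(images, rng)
randomized = random_pixels(images, rng)
print(shuffled.shape, shuffled.dtype, randomized.shape, randomized.dtype)
```

Both helpers only move pixels around, so every image keeps exactly the same set of pixel values and the uint8 dtype.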
Parameters change:
- Add, label-corrupt, bool type, set to True to corrupt labels
- Add, pixel-corrupt, bool type, set to True to corrupt pixels
- Add, pixel-shuffle, bool type, set to True to shuffle pixels
- Add, adjust-lr, bool type, set to True for learning rate decay by a factor of 0.1 at epochs 150 and 225
- Delete, label-corrupt-prob
- Add, save-every, int type, set to an int value N to save the model every N epochs
- Change, learning-rate, change default value from 0.1 to 0.01
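The decay schedule described for adjust-lr can be sketched as a small function. This is only an illustration under assumptions: epochs are taken to be 0-indexed, the decay is assumed to apply from the milestone epoch onward, and the name `learning_rate_at` is hypothetical rather than something defined in train.py.

```python
def learning_rate_at(epoch, base_lr=0.01, milestones=(150, 225), factor=0.1):
    """Return the learning rate for a 0-indexed epoch under step decay (hypothetical helper)."""
    # Count how many milestones have been reached by this epoch.
    num_decays = sum(1 for m in milestones if epoch >= m)
    return base_lr * factor ** num_decays

for epoch in (0, 149, 150, 224, 225, 299):
    print(f"epoch {epoch:3d}: lr = {learning_rate_at(epoch):.6f}")
```

With adjust-lr left off, the equivalent would be an empty `milestones` tuple, which keeps the learning rate at its base value for all epochs.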
Original
This repo contains simple demo code for the following paper, to train over-parameterized models on random label CIFAR-10 datasets.
Chiyuan Zhang, Samy Bengio, Moritz Hardt, Benjamin Recht, Oriol Vinyals. Understanding deep learning requires rethinking generalization. International Conference on Learning Representations (ICLR), 2017. [arXiv:1611.03530].
The original code for the paper belongs to Google, and there seems to be no need to go through all the legal processes to open source it, as the experiments are rather straightforward to implement and reproduce.
However, since people have been asking for the code, and to help people get started easily, I created this repo to demonstrate how one can take an existing implementation of commonly used successful models (here we use an implementation of Wide Resnets in PyTorch) and fit them to random labels. We are not trying to reproduce exactly the same experiments (e.g. with Inception and AlexNet and the corresponding hyperparameters) in the paper. This repo is completely written from scratch, and has no association with Google.
If you are trying to reproduce the fitting random label results yourself but are running into trouble, here are some tips that might be helpful:
- Try to run the examples in this repo directly (PyTorch is needed) and modify them from there.
- Check if the randomization of labels is fixed and consistent throughout epochs. Check if your data augmentation is turned off. It is also good to start playing with all the regularizers (e.g. dropout, weight decay) turned off.
- Try to run more epochs. Random labels could take more epochs to fit.
- If you are trying to reproduce the results on MNIST, you might need to try harder. MNIST contains digits that are very similar even in the pixel space, so assigning different (random) labels to those visually similar digits makes them quite hard to fit. You might need to run more epochs or tune the learning rate a bit, depending on the specific architecture you use.
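To make the tip about fixed and consistent label randomization concrete, here is a minimal sketch of corrupting labels once, before training, with a seeded random generator. The function name `corrupt_labels`, its arguments, and the default seed are assumptions for illustration and not necessarily what cifar10_data.py does.

```python
import numpy as np

def corrupt_labels(labels, corrupt_prob, num_classes=10, seed=12345):
    """Replace each label with a uniformly random class with probability corrupt_prob.

    Hypothetical helper: call it once before training and reuse the result in
    every epoch, so that the random labels stay fixed throughout training.
    """
    labels = np.array(labels)                      # copy, the input is left untouched
    rng = np.random.default_rng(seed)              # fixed seed gives reproducible labels
    mask = rng.random(len(labels)) < corrupt_prob  # which examples get a new label
    labels[mask] = rng.integers(0, num_classes, size=int(mask.sum()))
    return labels

clean = np.arange(1000) % 10                       # stand-in labels, not real CIFAR-10
fully_random = corrupt_labels(clean, corrupt_prob=1.0)
half_random = corrupt_labels(clean, corrupt_prob=0.5)
print("changed (p=1.0):", int(np.sum(fully_random != clean)))
print("changed (p=0.5):", int(np.sum(half_random != clean)))
```

Note that even with a corruption probability of 1.0, roughly one in ten labels will still match the original by chance, since the replacement is drawn uniformly from all classes.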
- train.py: main command line interface and training loops.
- model_wideresnet.py and model_mlp.py: model definition for Wide Resnets and MLPs.
- cifar10_data.py: a thin wrapper of torchvision's CIFAR-10 dataset to support (full or partial) random label corruption.
- cmd_args.py: command line argument parsing.
Show all the program arguments:
python train.py --help
python train.py
and here are some sample outputs from my local run:
Number of parameters: 369498
000: Acc-tr: 58.49, Acc-val: 56.92, L-tr: 1.1574, L-val: 1.1857
001: Acc-tr: 69.64, Acc-val: 68.24, L-tr: 0.8488, L-val: 0.8980
002: Acc-tr: 76.29, Acc-val: 73.24, L-tr: 0.6726, L-val: 0.7650
003: Acc-tr: 73.34, Acc-val: 70.56, L-tr: 0.7899, L-val: 0.9145
004: Acc-tr: 82.07, Acc-val: 77.42, L-tr: 0.5031, L-val: 0.6565
005: Acc-tr: 84.26, Acc-val: 79.33, L-tr: 0.4427, L-val: 0.6233
...
149: Acc-tr: 94.11, Acc-val: 80.31, L-tr: 0.1755, L-val: 0.9136
150: Acc-tr: 99.97, Acc-val: 85.87, L-tr: 0.0061, L-val: 0.5876
151: Acc-tr: 100.00, Acc-val: 86.31, L-tr: 0.0034, L-val: 0.5824
152: Acc-tr: 100.00, Acc-val: 86.25, L-tr: 0.0025, L-val: 0.5874
...
166: Acc-tr: 100.00, Acc-val: 86.44, L-tr: 0.0007, L-val: 0.6017
167: Acc-tr: 100.00, Acc-val: 86.52, L-tr: 0.0006, L-val: 0.6050
...
298: Acc-tr: 100.00, Acc-val: 86.43, L-tr: 0.0005, L-val: 0.5649
299: Acc-tr: 100.00, Acc-val: 86.41, L-tr: 0.0004, L-val: 0.5636
python train.py --label-corrupt-prob=1.0
and here are some sample outputs from my local run:
Number of parameters: 369498
000: Acc-tr: 10.30, Acc-val: 9.91, L-tr: 2.3105, L-val: 2.3129
001: Acc-tr: 10.55, Acc-val: 9.86, L-tr: 2.3038, L-val: 2.3051
002: Acc-tr: 11.07, Acc-val: 10.28, L-tr: 2.3023, L-val: 2.3052
003: Acc-tr: 10.96, Acc-val: 10.16, L-tr: 2.3002, L-val: 2.3043
004: Acc-tr: 10.87, Acc-val: 10.04, L-tr: 2.2993, L-val: 2.3054
005: Acc-tr: 11.40, Acc-val: 9.61, L-tr: 2.3000, L-val: 2.3071
...
038: Acc-tr: 48.70, Acc-val: 10.00, L-tr: 1.4724, L-val: 3.7426
039: Acc-tr: 54.72, Acc-val: 9.83, L-tr: 1.3039, L-val: 3.6905
040: Acc-tr: 63.80, Acc-val: 10.34, L-tr: 1.1148, L-val: 3.6388
041: Acc-tr: 63.55, Acc-val: 10.08, L-tr: 1.0545, L-val: 3.6621
...
147: Acc-tr: 67.70, Acc-val: 10.24, L-tr: 1.0133, L-val: 6.0229
148: Acc-tr: 76.09, Acc-val: 9.82, L-tr: 0.7517, L-val: 5.8342
149: Acc-tr: 77.35, Acc-val: 9.75, L-tr: 0.6599, L-val: 5.4791
150: Acc-tr: 99.00, Acc-val: 9.90, L-tr: 0.0809, L-val: 5.5166
151: Acc-tr: 99.87, Acc-val: 9.92, L-tr: 0.0398, L-val: 5.9411
152: Acc-tr: 99.97, Acc-val: 9.89, L-tr: 0.0240, L-val: 6.2846
153: Acc-tr: 99.98, Acc-val: 9.94, L-tr: 0.0172, L-val: 6.5499
154: Acc-tr: 100.00, Acc-val: 9.94, L-tr: 0.0121, L-val: 6.7529
155: Acc-tr: 100.00, Acc-val: 10.16, L-tr: 0.0093, L-val: 6.9280
...
172: Acc-tr: 100.00, Acc-val: 9.95, L-tr: 0.0019, L-val: 7.8329
173: Acc-tr: 100.00, Acc-val: 9.84, L-tr: 0.0018, L-val: 7.8759
...
297: Acc-tr: 100.00, Acc-val: 9.97, L-tr: 0.0009, L-val: 7.6774
298: Acc-tr: 100.00, Acc-val: 10.04, L-tr: 0.0009, L-val: 7.6958
299: Acc-tr: 100.00, Acc-val: 9.99, L-tr: 0.0008, L-val: 7.7216
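To compare runs like the ones above, the per-epoch lines can be parsed into numbers. The sketch below assumes the exact line format shown in these sample outputs; the helper name `parse_log_line` and the dictionary keys are hypothetical.

```python
import re

# Matches lines such as "150: Acc-tr: 99.00, Acc-val: 9.90, L-tr: 0.0809, L-val: 5.5166"
LOG_LINE = re.compile(
    r"^(\d+): Acc-tr: ([\d.]+), Acc-val: ([\d.]+), L-tr: ([\d.]+), L-val: ([\d.]+)$"
)

def parse_log_line(line):
    """Return a dict of metrics for one epoch line, or None if the line does not match."""
    m = LOG_LINE.match(line.strip())
    if m is None:
        return None
    epoch, acc_tr, acc_val, loss_tr, loss_val = m.groups()
    return {
        "epoch": int(epoch),
        "acc_tr": float(acc_tr),
        "acc_val": float(acc_val),
        "loss_tr": float(loss_tr),
        "loss_val": float(loss_val),
    }

sample = [
    "Number of parameters: 369498",
    "150: Acc-tr: 99.00, Acc-val: 9.90, L-tr: 0.0809, L-val: 5.5166",
    "...",
    "299: Acc-tr: 100.00, Acc-val: 9.99, L-tr: 0.0008, L-val: 7.7216",
]
records = [r for r in map(parse_log_line, sample) if r is not None]
for r in records:
    gap = r["acc_tr"] - r["acc_val"]   # train/validation accuracy gap in percentage points
    print(f"epoch {r['epoch']:03d}: accuracy gap = {gap:.2f}")
```

Lines that are not per-epoch records (the parameter count, the "..." separators) are simply skipped.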
Train an MLP with 1 hidden layer and 512 hidden units with weight decay on random label CIFAR-10 dataset:
python train.py --arch=mlp --mlp-spec=512 --label-corrupt-prob=1.0 --learning-rate=0.01
and here are some sample outputs from my local run:
Number of parameters: 1577984
000: Acc-tr: 12.59, Acc-val: 10.15, L-tr: 2.3094, L-val: 2.3680
001: Acc-tr: 15.92, Acc-val: 9.78, L-tr: 2.2653, L-val: 2.3575
002: Acc-tr: 17.73, Acc-val: 9.36, L-tr: 2.2402, L-val: 2.3735
003: Acc-tr: 18.64, Acc-val: 9.70, L-tr: 2.2231, L-val: 2.4024
004: Acc-tr: 20.90, Acc-val: 9.52, L-tr: 2.2103, L-val: 2.4279
...
220: Acc-tr: 99.96, Acc-val: 9.54, L-tr: 0.0147, L-val: 10.2900
221: Acc-tr: 99.96, Acc-val: 9.58, L-tr: 0.0146, L-val: 10.2958
222: Acc-tr: 99.96, Acc-val: 9.54, L-tr: 0.0145, L-val: 10.2980
223: Acc-tr: 99.96, Acc-val: 9.55, L-tr: 0.0143, L-val: 10.3036
224: Acc-tr: 99.96, Acc-val: 9.57, L-tr: 0.0142, L-val: 10.3074
...
295: Acc-tr: 99.97, Acc-val: 9.55, L-tr: 0.0134, L-val: 10.3331
296: Acc-tr: 99.97, Acc-val: 9.56, L-tr: 0.0134, L-val: 10.3336
297: Acc-tr: 99.97, Acc-val: 9.57, L-tr: 0.0134, L-val: 10.3340
298: Acc-tr: 99.97, Acc-val: 9.59, L-tr: 0.0134, L-val: 10.3344
299: Acc-tr: 99.97, Acc-val: 9.57, L-tr: 0.0134, L-val: 10.3347
In Section 5 of the paper, we talked about SGD implicitly regularizing by finding the minimum norm solution in an over-parameterized linear problem with the simple square loss. A frequently asked question is how the experiments on MNIST are conducted, since it has 60,000 training examples but only 28x28 = 784 features. As mentioned in the paper, the experiments are actually carried out by applying the "kernel trick" to Equation (3) in the paper (for both MNIST and CIFAR-10, with or without pre-processing). We are attaching sample code for solving the MNIST raw pixel problem for reference:
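import numpy as np
import scipy.spatial
import scipy.linalg
# fetch_mldata has been removed from scikit-learn; fetch_openml is used here instead.
from sklearn.datasets import fetch_openml

# In OpenML's 'mnist_784', the first 60,000 rows are the training set and the last 10,000 the test set.
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
targets = mnist.target.astype(int)  # OpenML returns the labels as strings

def onehot_encode(y, n_cls=10):
    y = np.array(y, dtype=int)
    return np.array(np.eye(n_cls)[y], dtype=float)

data = np.array(mnist.data, dtype=float) / 255
labels = onehot_encode(targets)

n_tr = 60000
n_tot = 70000

x_tr = data[:n_tr]; y_tr = labels[:n_tr]
x_tt = data[n_tr:n_tot]; y_tt = labels[n_tr:n_tot]

# bandwidth of the Gaussian kernel exp(-bw * ||x - x'||^2)
bw = 2.0e-2
# pairwise squared distances; the full 60000 x 60000 matrix needs a lot of memory
pdist_tr = scipy.spatial.distance.pdist(x_tr, 'sqeuclidean')
pdist_tr = scipy.spatial.distance.squareform(pdist_tr)
cdist_tt = scipy.spatial.distance.cdist(x_tt, x_tr, 'sqeuclidean')

# solve K * coeff = Y on the training set, then predict with the test/train kernel
coeff = scipy.linalg.solve(np.exp(-bw*pdist_tr), y_tr)
preds = np.argmax(np.dot(np.exp(-bw*cdist_tt), coeff), axis=1)
acc = float(np.sum(np.equal(preds, targets[n_tr:n_tot]))) / (n_tot-n_tr)
print('err = %.2f%%' % (100*(1-acc)))
# err = 1.22%  (as reported from the original run, which loaded the data with fetch_mldata)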
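For intuition on a scale that runs in a fraction of a second, here is a small sketch with a linear kernel on synthetic data. It assumes the linear-kernel form of the system, X Xᵀ α = y with w = Xᵀ α, and checks numerically that the resulting w fits the training targets exactly and coincides with the minimum norm solution given by the pseudo-inverse. The sizes, seed, and variable names are arbitrary choices for illustration, not values from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 20, 50                      # fewer examples than features: over-parameterized
X = rng.standard_normal((n, d))    # synthetic data matrix, one example per row
y = rng.standard_normal(n)         # arbitrary (random) targets

# Kernel trick with a linear kernel: solve the n x n system (X X^T) alpha = y,
# then recover the d-dimensional weights as w = X^T alpha.
alpha = np.linalg.solve(X @ X.T, y)
w = X.T @ alpha

# Minimum norm solution of X w = y, computed via the pseudo-inverse.
w_min_norm = np.linalg.pinv(X) @ y

print("fits training targets:", np.allclose(X @ w, y))
print("matches min-norm solution:", np.allclose(w, w_min_norm))
```

The point of the kernel form is that only an n x n system has to be solved, regardless of the feature dimension; the MNIST code above does the same with a Gaussian kernel in place of X Xᵀ.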