#! /usr/bin/python
# -*- coding: utf-8 -*-
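"""Train a BinaryNet-style binary CNN on MNIST with TensorLayer 2.x.

The convolution and dense layers use binary (+1/-1) weights, and the activations passed
between the binary layers are binarized with Sign layers; only the first layer receives
real-valued inputs.
"""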
import time

import numpy as np
import tensorflow as tf

import tensorlayer as tl
from tensorlayer.layers import (BatchNorm, BinaryConv2d, BinaryDense, Flatten, Input, MaxPool2d, Sign)
from tensorlayer.models import Model

tl.logging.set_verbosity(tl.logging.DEBUG)

# Load MNIST as 28x28x1 images, split into train/validation/test sets.
X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1))

batch_size = 128
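
# BinaryNet in brief (background note, not original code):
#   * BinaryConv2d / BinaryDense constrain their weights to +1/-1 (sign binarization).
#   * Sign layers binarize the activations fed into the next binary layer,
#     e.g. sign([0.3, -1.2, 2.5]) -> [1, -1, 1].
#   * BatchNorm with tl.act.htanh (hard tanh) keeps the pre-binarization activations in
#     [-1, 1]; gradients through the binarization are typically handled with a
#     straight-through estimator.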
def model(inputs_shape, n_class=10):
    # In BNN, all the layer inputs are binary, with the exception of the first layer.
    # ref: https://github.com/itayhubara/BinaryNet.tf/blob/master/models/BNN_cifar10.py
    net_in = Input(inputs_shape, name='input')

    # First binary conv block (operates on the real-valued input images).
    net = BinaryConv2d(32, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn1')(net_in)
    net = MaxPool2d((2, 2), (2, 2), padding='SAME', name='pool1')(net)
    net = BatchNorm(act=tl.act.htanh, name='bn1')(net)

    # Binarize the activations before the second binary conv block.
    net = Sign("sign1")(net)
    net = BinaryConv2d(64, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn2')(net)
    net = MaxPool2d((2, 2), (2, 2), padding='SAME', name='pool2')(net)
    net = BatchNorm(act=tl.act.htanh, name='bn2')(net)

    # Binary fully-connected classifier head.
    net = Flatten('ft')(net)
    net = Sign("sign2")(net)
    net = BinaryDense(256, b_init=None, name='dense')(net)
    net = BatchNorm(act=tl.act.htanh, name='bn3')(net)

    net = Sign("sign3")(net)
    net = BinaryDense(n_class, b_init=None, name='bout')(net)
    net = BatchNorm(name='bno')(net)

    net = Model(inputs=net_in, outputs=net, name='binarynet')
    return net
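
# Training follows the TensorFlow 2 eager style: each step runs the forward pass and loss
# inside a tf.GradientTape, then applies the gradients to the model's trainable weights.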
def _train_step(network, X_batch, y_batch, cost, train_op=tf.optimizers.Adam(learning_rate=0.0001), acc=None):
    # Note: the default optimizer is created once at function definition time;
    # the training loop below passes a shared optimizer explicitly instead.
    with tf.GradientTape() as tape:
        y_pred = network(X_batch)
        _loss = cost(y_pred, y_batch)
    grad = tape.gradient(_loss, network.trainable_weights)
    train_op.apply_gradients(zip(grad, network.trainable_weights))
    if acc is not None:
        _acc = acc(y_pred, y_batch)
        return _loss, _acc
    else:
        return _loss, None
def accuracy(_logits, y_batch):
    return np.mean(np.equal(np.argmax(_logits, 1), y_batch))
n_epoch = 200
print_freq = 5
net = model([None, 28, 28, 1])
train_op = tf.optimizers.Adam(learning_rate=0.0001)
cost = tl.cost.cross_entropy
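# Note: tl.cost.cross_entropy computes softmax cross-entropy from raw logits and integer
# class labels, which is why the model's final layer has no activation function.
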
for epoch in range(n_epoch):
    start_time = time.time()
    train_loss, train_acc, n_batch = 0, 0, 0
    net.train()  # training mode (e.g. BatchNorm uses batch statistics)

    for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train, batch_size, shuffle=True):
        _loss, acc = _train_step(net, X_train_a, y_train_a, cost=cost, train_op=train_op, acc=accuracy)
        train_loss += _loss
        train_acc += acc
        n_batch += 1

    # print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time))
    # print(" train loss: %f" % (train_loss / n_batch))
    # print(" train acc: %f" % (train_acc / n_batch))

    if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
        print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time))
        print(" train loss: %f" % (train_loss / n_batch))
        print(" train acc: %f" % (train_acc / n_batch))

        # Evaluate on the validation set every `print_freq` epochs.
        val_loss, val_acc, val_batch = 0, 0, 0
        net.eval()  # evaluation mode
        for X_val_a, y_val_a in tl.iterate.minibatches(X_val, y_val, batch_size, shuffle=True):
            _logits = net(X_val_a)
            val_loss += tl.cost.cross_entropy(_logits, y_val_a, name='eval_loss')
            val_acc += np.mean(np.equal(np.argmax(_logits, 1), y_val_a))
            val_batch += 1
        print(" val loss: {}".format(val_loss / val_batch))
        print(" val acc: {}".format(val_acc / val_batch))

# Final evaluation on the held-out test set after training.
net.test()  # switch the network to evaluation mode for testing
test_loss, test_acc, n_test_batch = 0, 0, 0
for X_test_a, y_test_a in tl.iterate.minibatches(X_test, y_test, batch_size, shuffle=True):
    _logits = net(X_test_a)
    test_loss += tl.cost.cross_entropy(_logits, y_test_a, name='test_loss')
    test_acc += np.mean(np.equal(np.argmax(_logits, 1), y_test_a))
    n_test_batch += 1
print(" test loss: %f" % (test_loss / n_test_batch))
print(" test acc: %f" % (test_acc / n_test_batch))