-
Notifications
You must be signed in to change notification settings - Fork 3
/
demo.py
65 lines (52 loc) · 1.79 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# Author: Acer Zhang
# Datetime: 2021/2/25
# Copyright belongs to the author.
# Please indicate the source for reprinting.
import paddle
from paddle.vision.transforms import Compose, Resize, ToTensor
from paddle.vision.models import resnet50
from paddle.vision.datasets import Cifar100
# Import the RIFLE module
from paddle_rifle.rifle import RIFLECallback
# Preprocessing pipeline applied to every sample: Resize(224), then ToTensor().
transform = Compose([Resize(224), ToTensor()])

# CIFAR-100 training and test splits, sharing the same preprocessing.
train_data = Cifar100(mode="train", transform=transform)
test_data = Cifar100(mode="test", transform=transform)

# ResNet-50 requested with pretrained weights and a 100-class output.
net = resnet50(pretrained=True, num_classes=100)

# Keep a handle on the network's output layer (`fc`); it is the layer
# handed to the RIFLE callback further down.
fc_layer = net.fc
"""
# 自定义网络场景下的输出层获取示例
class Net(paddle.nn.Layer):
def __init__(self):
super(Net, self).__init__()
self.layer1 = xxx
self.layer2 = xxx
self.输出层 = paddle.nn.Linear(...)
...
# 实例化Net类
net = Net()
# 获取输出层成员变量
输出层 = net.输出层
"""
# Wrap the network in the high-level paddle.Model API.
# InputSpec shapes describe a whole batch, so the leading dimension is the
# batch size; `None` leaves it dynamic so the spec holds for any batch size
# (the original specs omitted the batch dimension entirely).
model = paddle.Model(
    network=net,
    inputs=paddle.static.InputSpec(
        [None, 3, 224, 224], dtype="float32", name="ipt"),
    labels=paddle.static.InputSpec(
        [None, 1], dtype="int64", name="lab"),
)
# Instantiate the visualization (VisualDL) callback and the RIFLE callback.
vdl = paddle.callbacks.VisualDL("./log_RIFLE")
# To use a different random initialization for the re-initialized layer,
# pass weight_initializer=paddle.nn.initializer.XavierNormal() below.
# NOTE(review): the two positional 3s are RIFLE schedule settings whose
# parameter names are defined in paddle_rifle - confirm against that module.
rifle_cb = RIFLECallback(fc_layer, 3, 3)

# Optimizer, loss and metric for training.
sgd = paddle.optimizer.SGD(parameters=model.parameters())
# Use the documented public path paddle.nn.CrossEntropyLoss rather than
# reaching through the paddle.nn.loss submodule.
loss = paddle.nn.CrossEntropyLoss()
# Report both top-1 and top-5 accuracy.
acc = paddle.metric.Accuracy(topk=(1, 5))
model.prepare(optimizer=sgd, loss=loss, metrics=acc)
# Start training, evaluating on the test split, with both the VisualDL and
# RIFLE callbacks attached.
model.fit(
    train_data=train_data,
    eval_data=test_data,
    batch_size=128,
    epochs=20,
    log_freq=200,
    callbacks=[vdl, rifle_cb],
)