-
Notifications
You must be signed in to change notification settings - Fork 8.2k
/
Copy pathmain.py
189 lines (141 loc) · 6.16 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import torch
import torchvision
import torch.nn as nn
import numpy as np
import torchvision.transforms as transforms
# ================================================================== #
# Table of Contents #
# ================================================================== #
# 1. Basic autograd example 1 (Line 25 to 39)
# 2. Basic autograd example 2 (Line 46 to 83)
# 3. Loading data from numpy (Line 90 to 97)
# 4. Input pipeline (Line 104 to 129)
# 5. Input pipeline for custom dataset (Line 136 to 156)
# 6. Pretrained model (Line 163 to 176)
# 7. Save and load the model (Line 183 to 189)
# ================================================================== #
# 1. Basic autograd example 1 #
# ================================================================== #
# Three scalar tensors that all track gradients.
x, w, b = (torch.tensor(value, requires_grad=True) for value in (1., 2., 3.))

# Build the computational graph for a single affine expression.
y = w * x + b    # y = 2 * x + 3

# Back-propagate from y to populate .grad on x, w and b.
y.backward()

# Show the gradients, one per line:
#   x.grad = 2 (dy/dx = w)
#   w.grad = 1 (dy/dw = x)
#   b.grad = 1 (dy/db = 1)
print(x.grad, w.grad, b.grad, sep='\n')
# ================================================================== #
# 2. Basic autograd example 2 #
# ================================================================== #
# Random batch: inputs of shape (10, 3) and targets of shape (10, 2).
x = torch.randn(10, 3)
y = torch.randn(10, 2)

# Fully connected layer mapping 3 input features to 2 outputs.
linear = nn.Linear(in_features=3, out_features=2)
print('w: ', linear.weight)
print('b: ', linear.bias)

# Mean-squared-error loss and plain SGD over the layer's parameters.
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(params=linear.parameters(), lr=0.01)

# Forward pass, then the loss on the predictions.
pred = linear(x)
loss = criterion(pred, y)
print('loss: ', loss.item())

# Backward pass: fills in .grad on the weight and the bias.
loss.backward()

# Inspect the gradients that were just computed.
print('dL/dw: ', linear.weight.grad)
print('dL/db: ', linear.bias.grad)

# Take a single gradient-descent step.
optimizer.step()

# The same update can also be written by hand at a lower level:
# linear.weight.data.sub_(0.01 * linear.weight.grad.data)
# linear.bias.data.sub_(0.01 * linear.bias.grad.data)

# Re-evaluate the loss with the updated parameters.
pred = linear(x)
loss = criterion(pred, y)
print('loss after 1 step optimization: ', loss.item())
# ================================================================== #
# 3. Loading data from numpy #
# ================================================================== #
# Start from a 2x2 numpy array.
x = np.asarray([[1, 2], [3, 4]])

# numpy array -> torch tensor.
y = torch.from_numpy(x)

# torch tensor -> numpy array.
z = y.numpy()
# ================================================================== #
# 4. Input pipeline #
# ================================================================== #
# Download and construct CIFAR-10 dataset.
train_dataset = torchvision.datasets.CIFAR10(root='../../data/',
                                             train=True,
                                             transform=transforms.ToTensor(),
                                             download=True)

# Fetch one data pair (read data from disk).
image, label = train_dataset[0]
print (image.size())
print (label)

# Data loader (this provides queues and threads in a very simple way).
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=64,
                                           shuffle=True)

# When iteration starts, queue and thread start to load data from files.
data_iter = iter(train_loader)

# Mini-batch images and labels.
# Use the built-in next(): the loader's iterator follows the standard
# iterator protocol, whereas the legacy .next() method is not available
# on recent PyTorch releases and raises AttributeError there.
images, labels = next(data_iter)

# Actual usage of the data loader is as below.
for images, labels in train_loader:
    # Training code should be written here.
    pass
# ================================================================== #
# 5. Input pipeline for custom dataset #
# ================================================================== #
# Template for a user-defined dataset; fill in the TODOs below.
class CustomDataset(torch.utils.data.Dataset):
    """Skeleton dataset to be completed with project-specific loading code."""

    def __init__(self):
        """Set up the dataset.

        TODO:
        1. Initialize file paths or a list of file names.
        """

    def __getitem__(self, index):
        """Return the sample at ``index`` (currently a placeholder: None).

        TODO:
        1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        2. Preprocess the data (e.g. torchvision.Transform).
        3. Return a data pair (e.g. image and label).
        """
        return None

    def __len__(self):
        """Return the number of samples in the dataset."""
        # Placeholder: replace 0 with the total size of your dataset.
        return 0
# You can then use the prebuilt data loader.
custom_dataset = CustomDataset()

# The template CustomDataset reports a length of 0, and a shuffling loader
# rejects an empty dataset (its random sampler requires at least one sample
# and raises ValueError otherwise). Enable shuffling only once the dataset
# actually contains data, so the script keeps running with the template as-is.
train_loader = torch.utils.data.DataLoader(dataset=custom_dataset,
                                           batch_size=64,
                                           shuffle=len(custom_dataset) > 0)
# ================================================================== #
# 6. Pretrained model #
# ================================================================== #
# Fetch ResNet-18 together with its pretrained weights.
# NOTE(review): newer torchvision releases appear to prefer a `weights=`
# argument over `pretrained=True` -- confirm against the installed version.
resnet = torchvision.models.resnet18(pretrained=True)

# Freeze every existing parameter so that only a newly attached layer
# would be updated during finetuning.
for param in resnet.parameters():
    param.requires_grad = False

# Swap in a fresh final layer for finetuning (100 outputs is just an example).
resnet.fc = nn.Linear(in_features=resnet.fc.in_features, out_features=100)

# Run a forward pass on a random batch of 64 images of size 3x224x224.
images = torch.randn(64, 3, 224, 224)
outputs = resnet(images)
print(outputs.size())     # (64, 100)
# ================================================================== #
# 7. Save and load the model #
# ================================================================== #
# Save and load the entire model.
torch.save(resnet, 'model.ckpt')
# Restoring a whole pickled model requires full unpickling. Recent PyTorch
# releases default torch.load to weights_only=True, which refuses arbitrary
# model classes, so the flag is passed explicitly here.
# SECURITY: full unpickling can execute arbitrary code -- only do this for
# checkpoint files you created yourself or otherwise trust.
model = torch.load('model.ckpt', weights_only=False)

# Save and load only the model parameters (recommended).
torch.save(resnet.state_dict(), 'params.ckpt')
resnet.load_state_dict(torch.load('params.ckpt'))