train_Meso.py
import argparse
import copy
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.optim import lr_scheduler

from network.classifier import *
from network.transform import mesonet_data_transforms
def main():
    args = parse.parse_args()
    name = args.name
    train_path = args.train_path
    val_path = args.val_path
    continue_train = args.continue_train
    epoches = args.epoches
    batch_size = args.batch_size
    model_name = args.model_name
    model_path = args.model_path
    output_path = os.path.join('./output', name)
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    torch.backends.cudnn.benchmark = True

    # Create the train and val dataloaders.
    train_dataset = torchvision.datasets.ImageFolder(train_path, transform=mesonet_data_transforms['train'])
    val_dataset = torchvision.datasets.ImageFolder(val_path, transform=mesonet_data_transforms['val'])
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=8)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=8)
    train_dataset_size = len(train_dataset)
    val_dataset_size = len(val_dataset)

    # Create the model, optionally resuming from an existing checkpoint.
    model = Meso4()
    if continue_train:
        model.load_state_dict(torch.load(model_path))
    model = model.cuda()
    criterion = nn.CrossEntropyLoss()
    #optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001)
    optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
    # To train with multiple GPUs, wrap the model in DataParallel:
    #model = nn.DataParallel(model)

    # Keep a deep copy of the best weights; state_dict() alone returns
    # references that would keep changing as training continues.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    iteration = 0
    for epoch in range(epoches):
        print('Epoch {}/{}'.format(epoch + 1, epoches))
        print('-' * 10)

        # Training phase.
        model.train()
        train_loss = 0.0
        train_corrects = 0.0
        val_loss = 0.0
        val_corrects = 0.0
        for (image, labels) in train_loader:
            image = image.cuda()
            labels = labels.cuda()
            optimizer.zero_grad()
            outputs = model(image)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            iter_loss = loss.item()
            train_loss += iter_loss
            iter_corrects = torch.sum(preds == labels).to(torch.float32)
            train_corrects += iter_corrects
            iteration += 1
            if not (iteration % 20):
                print('iteration {} train loss: {:.4f} Acc: {:.4f}'.format(iteration, iter_loss / batch_size, iter_corrects / batch_size))
        epoch_loss = train_loss / train_dataset_size
        epoch_acc = train_corrects / train_dataset_size
        print('epoch train loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))

        # Validation phase.
        model.eval()
        with torch.no_grad():
            for (image, labels) in val_loader:
                image = image.cuda()
                labels = labels.cuda()
                outputs = model(image)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                val_corrects += torch.sum(preds == labels).to(torch.float32)
            epoch_loss = val_loss / val_dataset_size
            epoch_acc = val_corrects / val_dataset_size
            print('epoch val loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
        scheduler.step()
        if not (epoch % 10):
            # When training with DataParallel, save model.module.state_dict() instead:
            #torch.save(model.module.state_dict(), os.path.join(output_path, str(epoch) + '_' + model_name))
            torch.save(model.state_dict(), os.path.join(output_path, str(epoch) + '_' + model_name))

    print('Best val Acc: {:.4f}'.format(best_acc))
    model.load_state_dict(best_model_wts)
    #torch.save(model.module.state_dict(), os.path.join(output_path, "best.pkl"))
    torch.save(model.state_dict(), os.path.join(output_path, "best.pkl"))
if __name__ == '__main__':
    parse = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parse.add_argument('--name', '-n', type=str, default='Mesonet')
    parse.add_argument('--train_path', '-tp', type=str, default='./deepfake_database/train')
    parse.add_argument('--val_path', '-vp', type=str, default='./deepfake_database/val')
    parse.add_argument('--batch_size', '-bz', type=int, default=64)
    parse.add_argument('--epoches', '-e', type=int, default=50)
    parse.add_argument('--model_name', '-mn', type=str, default='meso4.pkl')
    # action='store_true' instead of type=bool: with type=bool, any non-empty
    # string (e.g. --continue_train False) would evaluate to True.
    parse.add_argument('--continue_train', action='store_true')
    parse.add_argument('--model_path', '-mp', type=str, default='./output/Mesonet/best.pkl')
    main()
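
# Example invocation (a sketch based on the defaults above; the dataset layout
# is an assumption and must follow torchvision ImageFolder conventions, i.e.
# one sub-folder per class under train/ and val/):
#
#   python train_Meso.py -n Mesonet \
#       -tp ./deepfake_database/train -vp ./deepfake_database/val \
#       -bz 64 -e 50 -mn meso4.pkl
#
# To resume from a saved checkpoint, add --continue_train and point -mp at the
# .pkl file to load.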