import os

import numpy as np
from keras.layers import Input
from keras.models import Model
from keras.optimizers import Adam

from dataloader import DataLoader
from models import Discriminator, Generator
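
# `train` expects an `args` namespace (e.g. built with argparse) providing the
# hyperparameters used below: latent_dim, batch_size, num_epochs, generator_lr,
# discriminator_lr, beta, sample_epoch, sample_path, save_epoch, and
# checkpoints_path.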
def train(args):
    # Optimizers: separate learning rates for the two networks.
    dis_optim = Adam(lr=args.discriminator_lr, beta_1=args.beta)
    gen_optim = Adam(lr=args.generator_lr, beta_1=args.beta)

    # Discriminator, compiled standalone so it can be trained on real/fake batches.
    discriminator = Discriminator(args).build_discriminator()
    print('Discriminator...')
    discriminator.summary()
    discriminator.compile(loss='binary_crossentropy', optimizer=dis_optim)

    # Generator, which maps a latent vector to a voxel volume.
    generator = Generator(args).build_generator()
    print('Generator...')
    generator.summary()

    # Combined model: freeze the discriminator so that adversarial training
    # updates only the generator's weights.
    z = Input(shape=(1, 1, 1, args.latent_dim))
    img = generator(z)
    discriminator.trainable = False
    validity = discriminator(img)
    combined = Model(inputs=z, outputs=validity)
    combined.compile(loss='binary_crossentropy', optimizer=gen_optim)

    # Load the full training set into memory as float32 voxel grids.
    data_loader = DataLoader(args)
    X_train = np.array(data_loader.load_data()).astype(np.float32)

    dl, gl = [], []
    for epoch in range(args.num_epochs):
        # Sample a random batch of real volumes and add a channel dimension.
        idx = np.random.randint(len(X_train), size=args.batch_size)
        real = np.expand_dims(X_train[idx], axis=4)

        # Generate a batch of fake volumes from Gaussian noise.
        z = np.random.normal(0, 0.33, size=[args.batch_size, 1, 1, 1, args.latent_dim]).astype(np.float32)
        fake = generator.predict(z)

        # Labels shaped to match the discriminator's (1, 1, 1, 1) output.
        lab_real = np.reshape([1] * args.batch_size, (-1, 1, 1, 1, 1))
        lab_fake = np.reshape([0] * args.batch_size, (-1, 1, 1, 1, 1))

        # Discriminator loss: average over the real and fake batches.
        d_loss_real = discriminator.train_on_batch(real, lab_real)
        d_loss_fake = discriminator.train_on_batch(fake, lab_fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Generator loss: fresh noise, labeled as real to fool the discriminator.
        z = np.random.normal(0, 0.33, size=[args.batch_size, 1, 1, 1, args.latent_dim]).astype(np.float32)
        g_loss = float(combined.train_on_batch(z, np.reshape([1] * args.batch_size, (-1, 1, 1, 1, 1))))

        dl.append(d_loss)
        gl.append(g_loss)
        avg_d_loss = round(sum(dl) / len(dl), 4)
        avg_g_loss = round(sum(gl) / len(gl), 4)
        print('Training epoch {}/{}, d_loss/avg: {}/{}, g_loss/avg: {}/{}'.format(
            epoch + 1, args.num_epochs, round(d_loss, 4), avg_d_loss, round(g_loss, 4), avg_g_loss))
        # Periodically dump a batch of generated volumes for inspection
        # (reload them later with np.load(path, allow_pickle=True)).
        if epoch % args.sample_epoch == 0:
            if not os.path.exists(args.sample_path):
                os.makedirs(args.sample_path)
            print('Sampling...')
            sample_noise = np.random.normal(0, 0.33, size=[args.batch_size, 1, 1, 1, args.latent_dim]).astype(np.float32)
            generated_volumes = generator.predict(sample_noise, verbose=1)
            generated_volumes.dump(os.path.join(args.sample_path, 'sample_epoch_{}.npy'.format(epoch + 1)))

        # Periodically checkpoint both networks.
        if epoch % args.save_epoch == 0:
            if not os.path.exists(args.checkpoints_path):
                os.makedirs(args.checkpoints_path)
            generator.save_weights(os.path.join(args.checkpoints_path, 'generator_epoch_{}'.format(epoch + 1)), overwrite=True)
            discriminator.save_weights(os.path.join(args.checkpoints_path, 'discriminator_epoch_{}'.format(epoch + 1)), overwrite=True)
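

# Minimal CLI entry point, sketched here for convenience: the flag names match
# the attributes `train` reads from `args`, but the default values are
# illustrative placeholders rather than the repository's tuned settings, and
# the model/data classes may expect additional fields (e.g. a dataset path)
# not shown here.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Train a 3D GAN')
    parser.add_argument('--latent_dim', type=int, default=200)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--num_epochs', type=int, default=1000)
    parser.add_argument('--generator_lr', type=float, default=0.0025)
    parser.add_argument('--discriminator_lr', type=float, default=1e-5)
    parser.add_argument('--beta', type=float, default=0.5)
    parser.add_argument('--sample_epoch', type=int, default=100)
    parser.add_argument('--sample_path', type=str, default='samples')
    parser.add_argument('--save_epoch', type=int, default=100)
    parser.add_argument('--checkpoints_path', type=str, default='checkpoints')
    train(parser.parse_args())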