-
Notifications
You must be signed in to change notification settings - Fork 136
/
module.py
115 lines (82 loc) · 4.6 KB
/
module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import functools
import tensorflow as tf
import tensorflow.contrib.slim as slim
import utils
# Bare (un-activated) layer constructors; each network below attaches its own
# normalizer/activation via functools.partial on top of these.
conv = functools.partial(slim.conv2d, activation_fn=None)
dconv = functools.partial(slim.conv2d_transpose, activation_fn=None)
fc = functools.partial(slim.fully_connected, activation_fn=None)
class UNetGenc:

    def __call__(self, x, dim=64, n_downsamplings=5, weight_decay=0.0,
                 norm_name='batch_norm', training=True, scope='UNetGenc'):
        """Downsampling encoder.

        Applies `n_downsamplings` stride-2 4x4 convolutions (norm + leaky ReLU),
        doubling channels each level (capped at 1024), and returns the list of
        feature maps from every level for use as U-Net shortcuts.
        """
        MAX_DIM = 1024

        regularized_conv = functools.partial(
            conv, weights_regularizer=slim.l2_regularizer(weight_decay))
        norm_fn = utils.get_norm_layer(norm_name, training, updates_collections=None)
        down_block = functools.partial(
            regularized_conv, normalizer_fn=norm_fn, activation_fn=tf.nn.leaky_relu)

        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            feats = []
            h = x
            for level in range(n_downsamplings):
                channels = min(dim * 2 ** level, MAX_DIM)
                h = down_block(h, channels, 4, 2)
                feats.append(h)

            # Expose the variables / regularization losses created under this scope.
            self.variables = tf.global_variables(scope)
            self.trainable_variables = tf.trainable_variables(scope)
            self.reg_losses = tf.losses.get_regularization_losses(scope)

            return feats
class UNetGdec:

    def __call__(self, zs, a, dim=64, n_upsamplings=5, shortcut_layers=1, inject_layers=1, weight_decay=0.0,
                 norm_name='batch_norm', training=True, scope='UNetGdec'):
        """Upsampling decoder with U-Net shortcuts and attribute injection.

        Starts from the deepest encoder feature in `zs` concatenated with the
        (tiled) attribute vector `a`, applies stride-2 transposed convolutions
        (norm + ReLU), optionally concatenating encoder shortcuts and
        re-injecting `a` at the first `shortcut_layers` / `inject_layers`
        levels, and ends with a tanh-activated 3-channel output.
        """
        MAX_DIM = 1024

        regularized_dconv = functools.partial(
            dconv, weights_regularizer=slim.l2_regularizer(weight_decay))
        norm_fn = utils.get_norm_layer(norm_name, training, updates_collections=None)
        up_block = functools.partial(
            regularized_dconv, normalizer_fn=norm_fn, activation_fn=tf.nn.relu)

        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            a = tf.to_float(a)

            h = utils.tile_concat(zs[-1], a)
            for level in range(n_upsamplings - 1):
                channels = min(dim * 2 ** (n_upsamplings - 1 - level), MAX_DIM)
                h = up_block(h, channels, 4, 2)
                if shortcut_layers > level:
                    # U-Net skip connection from the mirrored encoder level.
                    h = utils.tile_concat([h, zs[-2 - level]])
                if inject_layers > level:
                    # Re-inject the attribute vector at this resolution.
                    h = utils.tile_concat(h, a)
            x = tf.nn.tanh(regularized_dconv(h, 3, 4, 2))

            # Expose the variables / regularization losses created under this scope.
            self.variables = tf.global_variables(scope)
            self.trainable_variables = tf.trainable_variables(scope)
            self.reg_losses = tf.losses.get_regularization_losses(scope)

            return x
class ConvD:

    def __call__(self, x, n_atts, dim=64, fc_dim=1024, n_downsamplings=5, weight_decay=0.0,
                 norm_name='instance_norm', training=True, scope='ConvD'):
        """Convolutional discriminator with two heads on a shared trunk.

        The trunk is `n_downsamplings` stride-2 4x4 convolutions (norm +
        leaky ReLU, channels doubling up to 1024). Returns a pair:
        a scalar real/fake GAN logit and `n_atts` attribute logits.
        """
        MAX_DIM = 1024

        regularized_conv = functools.partial(
            conv, weights_regularizer=slim.l2_regularizer(weight_decay))
        regularized_fc = functools.partial(
            fc, weights_regularizer=slim.l2_regularizer(weight_decay))
        norm_fn = utils.get_norm_layer(norm_name, training, updates_collections=None)
        down_block = functools.partial(
            regularized_conv, normalizer_fn=norm_fn, activation_fn=tf.nn.leaky_relu)

        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            h = x
            for level in range(n_downsamplings):
                channels = min(dim * 2 ** level, MAX_DIM)
                h = down_block(h, channels, 4, 2)
            h = slim.flatten(h)

            # Real/fake head.
            logit_gan = tf.nn.leaky_relu(regularized_fc(h, fc_dim))
            logit_gan = regularized_fc(logit_gan, 1)

            # Attribute-classification head.
            logit_att = tf.nn.leaky_relu(regularized_fc(h, fc_dim))
            logit_att = regularized_fc(logit_att, n_atts)

            # Expose the variables / regularization losses created under this scope.
            self.variables = tf.global_variables(scope)
            self.trainable_variables = tf.trainable_variables(scope)
            self.reg_losses = tf.losses.get_regularization_losses(scope)

            return logit_gan, logit_att
def get_model(name, n_atts, weight_decay=0.0):
    """Return `(Genc, Gdec, D)` network constructors configured for `name`.

    Args:
        name: model configuration, one of 'model_128', 'model_256',
            'model_384'. The 128/256 variants use dim=64 / fc_dim=1024;
            the 384 variant uses the slimmer dim=48 / fc_dim=512.
        n_atts: number of attributes for the discriminator's
            classification head.
        weight_decay: L2 regularization strength forwarded to every network.

    Returns:
        A tuple of three callables (encoder, decoder, discriminator), each a
        `functools.partial` over a fresh network instance with the
        configuration baked in.

    Raises:
        ValueError: if `name` is not a recognized model name.
    """
    if name in ['model_128', 'model_256']:
        Genc = functools.partial(UNetGenc(), dim=64, n_downsamplings=5, weight_decay=weight_decay)
        Gdec = functools.partial(UNetGdec(), dim=64, n_upsamplings=5, shortcut_layers=1, inject_layers=1, weight_decay=weight_decay)
        D = functools.partial(ConvD(), n_atts=n_atts, dim=64, fc_dim=1024, n_downsamplings=5, weight_decay=weight_decay)
    elif name == 'model_384':
        Genc = functools.partial(UNetGenc(), dim=48, n_downsamplings=5, weight_decay=weight_decay)
        Gdec = functools.partial(UNetGdec(), dim=48, n_upsamplings=5, shortcut_layers=1, inject_layers=1, weight_decay=weight_decay)
        D = functools.partial(ConvD(), n_atts=n_atts, dim=48, fc_dim=512, n_downsamplings=5, weight_decay=weight_decay)
    else:
        # Previously an unknown name fell through to the return statement and
        # raised a confusing UnboundLocalError; fail fast with a clear message.
        raise ValueError(
            "unknown model name: %r (expected 'model_128', 'model_256' or 'model_384')" % (name,))
    return Genc, Gdec, D