# model.py
# Original TF1 code, run through the compat.v1 API on TF2; the former
# tf.contrib.slim layers are provided by the standalone tf_slim package.
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import tf_slim as slim
def lrelu(x, trainable=None):
    # Leaky ReLU with a fixed negative slope of 0.2; the unused `trainable`
    # argument is kept only to preserve the original signature.
    return tf.maximum(x * 0.2, x)
def upsample_and_concat(x1, x2, output_channels, in_channels, scope_name, trainable=True):
    # Upsample x1 by 2x with a learned transposed convolution to match the
    # spatial size of x2, then concatenate along channels (U-Net skip link).
    with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE):
        pool_size = 2
        # Pass `trainable` through instead of hard-coding True, so the
        # caller's argument actually takes effect.
        deconv_filter = tf.get_variable(
            'weights', [pool_size, pool_size, output_channels, in_channels],
            trainable=trainable)
        deconv = tf.nn.conv2d_transpose(x1, deconv_filter, tf.shape(x2),
                                        strides=[1, pool_size, pool_size, 1],
                                        name=scope_name)
        deconv_output = tf.concat([deconv, x2], 3)
        deconv_output.set_shape([None, None, None, output_channels * 2])
        return deconv_output
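
# Example (hedged, illustrative shapes): with x1 of shape [1, 16, 16, 128] and
# skip tensor x2 of shape [1, 32, 32, 64], upsample_and_concat(x1, x2, 64, 128,
# 'demo') upsamples x1 to 32x32x64 and concatenates it with x2, returning a
# [1, 32, 32, 128] tensor.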
def DecomNet_simple(input):
    # Decomposition network: splits the input image into a 3-channel
    # reflectance map R and a 1-channel illumination map L.
    with tf.variable_scope('DecomNet', reuse=tf.AUTO_REUSE):
        # Shallow U-Net branch for reflectance.
        conv1 = slim.conv2d(input, 32, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv1_1')
        pool1 = slim.max_pool2d(conv1, [2, 2], stride=2, padding='SAME')
        conv2 = slim.conv2d(pool1, 64, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv2_1')
        pool2 = slim.max_pool2d(conv2, [2, 2], stride=2, padding='SAME')
        conv3 = slim.conv2d(pool2, 128, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv3_1')
        up8 = upsample_and_concat(conv3, conv2, 64, 128, 'g_up_1')
        conv8 = slim.conv2d(up8, 64, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv8_1')
        up9 = upsample_and_concat(conv8, conv1, 32, 64, 'g_up_2')
        conv9 = slim.conv2d(up9, 32, [3, 3], rate=1, activation_fn=lrelu, scope='g_conv9_1')
        # A 1x1 kernel replaces the 3x3 ones in the paper to get better results.
        conv10 = slim.conv2d(conv9, 3, [1, 1], rate=1, activation_fn=None, scope='g_conv10')
        R_out = tf.sigmoid(conv10)

        # Lightweight illumination branch reusing conv1 and conv9 features.
        l_conv2 = slim.conv2d(conv1, 32, [3, 3], rate=1, activation_fn=lrelu, scope='l_conv1_2')
        l_conv3 = tf.concat([l_conv2, conv9], 3)
        # A 1x1 kernel replaces the 3x3 ones in the paper to get better results.
        l_conv4 = slim.conv2d(l_conv3, 1, [1, 1], rate=1, activation_fn=None, scope='l_conv1_4')
        L_out = tf.sigmoid(l_conv4)

    return R_out, L_out
def Restoration_net(input_r, input_i):
    # Restoration network: a 5-level U-Net that takes the reflectance map and
    # its illumination map and outputs a restored (denoised) reflectance.
    with tf.variable_scope('Restoration_net', reuse=tf.AUTO_REUSE):
        input_all = tf.concat([input_r, input_i], 3)

        # Encoder: four 2x downsamplings, doubling channels from 32 to 512.
        conv1 = slim.conv2d(input_all, 32, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv1_1')
        conv1 = slim.conv2d(conv1, 32, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv1_2')
        pool1 = slim.max_pool2d(conv1, [2, 2], padding='SAME')
        conv2 = slim.conv2d(pool1, 64, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv2_1')
        conv2 = slim.conv2d(conv2, 64, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv2_2')
        pool2 = slim.max_pool2d(conv2, [2, 2], padding='SAME')
        conv3 = slim.conv2d(pool2, 128, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv3_1')
        conv3 = slim.conv2d(conv3, 128, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv3_2')
        pool3 = slim.max_pool2d(conv3, [2, 2], padding='SAME')
        conv4 = slim.conv2d(pool3, 256, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv4_1')
        conv4 = slim.conv2d(conv4, 256, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv4_2')
        pool4 = slim.max_pool2d(conv4, [2, 2], padding='SAME')
        conv5 = slim.conv2d(pool4, 512, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv5_1')
        conv5 = slim.conv2d(conv5, 512, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv5_2')

        # Decoder: four 2x upsamplings with skip connections to the encoder.
        up6 = upsample_and_concat(conv5, conv4, 256, 512, 'up_6')
        conv6 = slim.conv2d(up6, 256, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv6_1')
        conv6 = slim.conv2d(conv6, 256, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv6_2')
        up7 = upsample_and_concat(conv6, conv3, 128, 256, 'up_7')
        conv7 = slim.conv2d(up7, 128, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv7_1')
        conv7 = slim.conv2d(conv7, 128, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv7_2')
        up8 = upsample_and_concat(conv7, conv2, 64, 128, 'up_8')
        conv8 = slim.conv2d(up8, 64, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv8_1')
        conv8 = slim.conv2d(conv8, 64, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv8_2')
        up9 = upsample_and_concat(conv8, conv1, 32, 64, 'up_9')
        conv9 = slim.conv2d(up9, 32, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv9_1')
        conv9 = slim.conv2d(conv9, 32, [3, 3], rate=1, activation_fn=lrelu, scope='de_conv9_2')

        conv10 = slim.conv2d(conv9, 3, [3, 3], rate=1, activation_fn=None, scope='de_conv10')
        out = tf.sigmoid(conv10)

    return out
def Illumination_adjust_net(input_i, input_ratio):
    # Illumination adjustment network: maps the illumination map and a
    # one-channel exposure-ratio map to an adjusted illumination map.
    with tf.variable_scope('Illumination_adjust_net', reuse=tf.AUTO_REUSE):
        input_all = tf.concat([input_i, input_ratio], 3)
        conv1 = slim.conv2d(input_all, 32, [3, 3], rate=1, activation_fn=lrelu, scope='en_conv_1')
        conv2 = slim.conv2d(conv1, 32, [3, 3], rate=1, activation_fn=lrelu, scope='en_conv_2')
        conv3 = slim.conv2d(conv2, 32, [3, 3], rate=1, activation_fn=lrelu, scope='en_conv_3')
        conv4 = slim.conv2d(conv3, 1, [3, 3], rate=1, activation_fn=lrelu, scope='en_conv_4')
        L_enhance = tf.sigmoid(conv4)

    return L_enhance
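
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): one plausible way to wire the
# three sub-networks together in TF1 graph mode. The placeholder shape, the
# constant ratio value, and the final recomposition step are illustrative
# assumptions, not the repo's official inference pipeline.
if __name__ == '__main__':
    import numpy as np

    # NHWC input; spatial dims fed at run time should be divisible by 16,
    # since Restoration_net downsamples four times by a factor of 2.
    input_low = tf.placeholder(tf.float32, [None, None, None, 3], name='input_low')

    # Decompose the low-light image into reflectance R and illumination L.
    R, L = DecomNet_simple(input_low)

    # Adjust the illumination toward a target exposure; the scalar 5.0
    # broadcast into a one-channel ratio map is an arbitrary example value.
    ratio = tf.ones_like(L) * 5.0
    L_adj = Illumination_adjust_net(L, ratio)

    # Restore the reflectance, conditioned on the illumination.
    R_restored = Restoration_net(R, L)

    # Recompose following the Retinex model I = R * L, tiling the single
    # illumination channel across RGB.
    enhanced = R_restored * tf.concat([L_adj, L_adj, L_adj], axis=3)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        dummy = np.random.rand(1, 384, 384, 3).astype(np.float32)
        result = sess.run(enhanced, feed_dict={input_low: dummy})
        print(result.shape)  # (1, 384, 384, 3)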