-
Notifications
You must be signed in to change notification settings - Fork 55
/
Copy pathmodel.py
81 lines (64 loc) · 4.22 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
#import tensorflow.contrib.slim as slim
import tf_slim as slim
#from tensorflow.contrib.layers.python.layers import initializers
from msia_BN_3_M import *
def lrelu(x, trainbable=None):
    """Apply a leaky ReLU with a fixed negative slope of 0.2.

    Args:
        x: Value handed to ``tf.maximum``.
        trainbable: Unused. The (misspelled) name is kept so that any
            existing keyword callers keep working.

    Returns:
        The element-wise maximum of ``x * 0.2`` and ``x``.
    """
    scaled = x * 0.2
    return tf.maximum(scaled, x)
def upsample_and_concat(x1, x2, output_channels, in_channels, scope_name, trainable=True):
    """Upsample ``x1`` with a transposed convolution and concatenate it with ``x2``.

    A 2x2 filter named ``'weights'`` is created (or reused) inside the
    variable scope ``scope_name`` and applied to ``x1`` with stride 2. The
    transposed convolution is asked to produce the dynamic shape of ``x2``,
    and the result is concatenated with ``x2`` along axis 3.

    Args:
        x1: Tensor to upsample.
            NOTE(review): presumably NHWC with ``in_channels`` channels - confirm.
        x2: Tensor whose dynamic shape is the deconvolution target and which
            is concatenated after the upsampled tensor.
        output_channels: Channel count produced by the transposed convolution.
        in_channels: Channel count of ``x1``.
        scope_name: Variable scope for the filter; also used as the name of
            the transposed-convolution op.
        trainable: Whether the deconvolution filter is created as a trainable
            variable. Previously this argument was ignored and the filter was
            always trainable; it is now forwarded to ``tf.get_variable``.

    Returns:
        The concatenated tensor, with its static shape set to
        ``[None, None, None, output_channels * 2]``.
    """
    pool_size = 2
    with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE):
        deconv_filter = tf.get_variable(
            'weights',
            [pool_size, pool_size, output_channels, in_channels],
            trainable=trainable,
        )
        deconv = tf.nn.conv2d_transpose(
            x1,
            deconv_filter,
            tf.shape(x2),
            strides=[1, pool_size, pool_size, 1],
            name=scope_name,
        )
        deconv_output = tf.concat([deconv, x2], 3)
        # The output shape above is dynamic, so the static channel count is
        # lost; restore it so later layers can build their kernels.
        deconv_output.set_shape([None, None, None, output_channels * 2])
        return deconv_output
def DecomNet(input):
    """Build the 'DecomNet' graph and return its two sigmoid outputs.

    The network is a small encoder/decoder: three 3x3 conv stages separated
    by 2x2 max-pooling, followed by two ``upsample_and_concat`` steps that
    merge the decoder features with the matching encoder features. A second,
    shallow branch starts from the first conv stage and is concatenated with
    the last decoder features.

    Args:
        input: Tensor fed to the first convolution.

    Returns:
        A tuple ``(R_out, L_out)``: the sigmoid of a 3-channel 1x1 conv on the
        decoder output, and the sigmoid of a 1-channel 1x1 conv on the merged
        branch. NOTE(review): presumably the reflectance and illumination
        maps, judging by the names - confirm against the caller.
    """

    def _conv3x3(tensor, channels, scope):
        # 3x3 convolution followed by the leaky ReLU defined in this file.
        return slim.conv2d(tensor, channels, [3, 3], rate=1,
                           activation_fn=lrelu, scope=scope)

    def _conv1x1_linear(tensor, channels, scope):
        # 1x1 convolution without activation (the sigmoid is applied later).
        return slim.conv2d(tensor, channels, [1, 1], rate=1,
                           activation_fn=None, scope=scope)

    def _downsample(tensor):
        return slim.max_pool2d(tensor, [2, 2], stride=2, padding='SAME')

    with tf.variable_scope('DecomNet', reuse=tf.AUTO_REUSE):
        # Encoder.
        enc_full = _conv3x3(input, 32, 'g_conv1_1')
        enc_half = _conv3x3(_downsample(enc_full), 64, 'g_conv2_1')
        enc_quarter = _conv3x3(_downsample(enc_half), 128, 'g_conv3_1')

        # Decoder with skip connections back to the encoder features.
        merged_half = upsample_and_concat(enc_quarter, enc_half, 64, 128, 'g_up_1')
        dec_half = _conv3x3(merged_half, 64, 'g_conv8_1')
        merged_full = upsample_and_concat(dec_half, enc_full, 32, 64, 'g_up_2')
        dec_full = _conv3x3(merged_full, 32, 'g_conv9_1')

        # Here, a 1*1 kernel replaces the 3*3 one from the paper to get better results.
        r_logits = _conv1x1_linear(dec_full, 3, 'g_conv10')
        R_out = tf.sigmoid(r_logits)

        # Second branch: first-stage features joined with the decoder output.
        l_features = _conv3x3(enc_full, 32, 'l_conv1_2')
        l_merged = tf.concat([l_features, dec_full], 3)
        # Here, a 1*1 kernel replaces the 3*3 one from the paper to get better results.
        l_logits = _conv1x1_linear(l_merged, 1, 'l_conv1_4')
        L_out = tf.sigmoid(l_logits)

    return R_out, L_out
def Restoration_net(input_r, input_i, training = True):
    """Build the 'Denoise_Net' graph and return its sigmoid output.

    The network runs four stages. Each stage applies two 3x3 lrelu
    convolutions and then ``msia_3_M`` (imported from ``msia_BN_3_M``), which
    receives the stage features together with ``input_i``. A final 32-channel
    conv and a linear 3-channel conv follow.

    Args:
        input_r: Tensor fed to the first convolution.
        input_i: Tensor passed unchanged to every ``msia_3_M`` call.
        training: Forwarded to ``msia_3_M`` as its ``training`` argument.

    Returns:
        The sigmoid of the final 3-channel convolution.
    """

    def _conv3x3(tensor, channels, scope, activation=lrelu):
        return slim.conv2d(tensor, channels, [3, 3], rate=1,
                           activation_fn=activation, scope=scope)

    # (stage name, width of first conv, width of second conv)
    stages = (
        ('de_conv1', 32, 64),
        ('de_conv2', 128, 256),
        ('de_conv3', 512, 256),
        ('de_conv4', 128, 64),
    )

    with tf.variable_scope('Denoise_Net', reuse=tf.AUTO_REUSE):
        features = input_r
        for stage, first_width, second_width in stages:
            features = _conv3x3(features, first_width, stage + '_1')
            features = _conv3x3(features, second_width, stage + '_2')
            features = msia_3_M(features, input_i, name=stage, training=training)

        features = _conv3x3(features, 32, 'de_conv5_1')
        logits = _conv3x3(features, 3, 'de_conv10', activation=None)
        out = tf.sigmoid(logits)

    return out
def Illumination_adjust_net(input_i, input_ratio):
    """Build the 'I_enhance_Net' graph and return its sigmoid output.

    ``input_i`` and ``input_ratio`` are concatenated along axis 3 and passed
    through three 32-channel 3x3 convolutions and one 1-channel 3x3
    convolution, all using the leaky ReLU defined in this file.

    Args:
        input_i: First tensor of the concatenation.
        input_ratio: Second tensor of the concatenation.

    Returns:
        The sigmoid of the final 1-channel convolution.
    """

    def _conv3x3(tensor, channels, scope):
        return slim.conv2d(tensor, channels, [3, 3], rate=1,
                           activation_fn=lrelu, scope=scope)

    with tf.variable_scope('I_enhance_Net', reuse=tf.AUTO_REUSE):
        features = tf.concat([input_i, input_ratio], 3)
        for scope in ('conv_1', 'conv_2', 'conv_3'):
            features = _conv3x3(features, 32, scope)

        # The last conv also goes through lrelu before the sigmoid.
        squeezed = _conv3x3(features, 1, 'conv_4')
        L_enhance = tf.sigmoid(squeezed)

    return L_enhance