-
Notifications
You must be signed in to change notification settings - Fork 51
/
Copy pathnets.py
150 lines (113 loc) · 5.39 KB
/
nets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
## -*- coding: utf-8 -*-
import tensorflow as tf
from utils import BatchNorm, Conv3D
# Padding specifications handed to tf.pad by the network builders below.
# Each holds one [before, after] pair per dimension of a 5-D tensor.
# stp: one element of padding on both sides of dimensions 1, 2 and 3
# (presumably time/depth, height and width -- TODO confirm the axis layout).
stp = [[0,0], [1,1], [1,1], [1,1], [0,0]]
# sp: one element of padding on both sides of dimensions 2 and 3 only;
# dimension 1 is left unpadded.
sp = [[0,0], [0,0], [1,1], [1,1], [0,0]]
def FR_16L(x, is_train, uf=4):
    """Build the FR_16L graph and return its two output tensors.

    The graph is an initial convolution, six densely connected blocks
    (each block's output is concatenated onto its input along axis 4),
    a fusion convolution, and two 1x1x1 convolution heads.

    Args:
        x: Input tensor. It is padded with ``sp`` (a 5-entry padding spec)
            and fed to a convolution whose kernel shape is [1,3,3,3,64],
            so presumably a 5-D tensor with 3 channels last -- TODO
            confirm against the caller.
        is_train: Forwarded unchanged to every ``BatchNorm`` call.
        uf: Integer factor; the head output channel counts are
            ``3*uf*uf`` and ``25*uf*uf``.

    Returns:
        A tuple ``(f, r)``. ``f`` is the 'fconv2' output reshaped so its
        last two axes are ``[25, uf*uf]`` and passed through a softmax
        over axis 4. ``r`` is the raw 'rconv2' output with ``3*uf*uf``
        channels. (Presumably filter weights and a residual image --
        NOTE(review): confirm with the consumer of this function.)
    """
    growth = 32           # channels appended by every dense block
    channels = 64         # running channel count of the feature tensor
    full_pad_blocks = 3   # blocks 1..3 pad dimension 1 as well (stp)
    total_blocks = 6      # blocks 4..6 pad dimensions 2 and 3 only (sp)

    net = Conv3D(tf.pad(x, sp, mode='CONSTANT'), [1, 3, 3, 3, channels],
                 [1, 1, 1, 1, 1], 'VALID', name='conv1')

    for idx in range(1, total_blocks + 1):
        tag = str(idx)
        pads_dim1 = idx <= full_pad_blocks

        # BN -> ReLU -> 1x1x1 conv
        t = tf.nn.relu(BatchNorm(net, is_train, name='Rbn' + tag + 'a'))
        t = Conv3D(t, [1, 1, 1, channels, channels], [1, 1, 1, 1, 1],
                   'VALID', name='Rconv' + tag + 'a')

        # BN -> ReLU -> padded 3x3x3 conv producing `growth` channels
        t = tf.nn.relu(BatchNorm(t, is_train, name='Rbn' + tag + 'b'))
        padded = tf.pad(t, stp if pads_dim1 else sp, mode='CONSTANT')
        t = Conv3D(padded, [3, 3, 3, channels, growth], [1, 1, 1, 1, 1],
                   'VALID', name='Rconv' + tag + 'b')

        # Without padding on dimension 1 the skip tensor is cropped by one
        # element at each end of that dimension before concatenation
        # (presumably to match the VALID conv output -- Conv3D is defined
        # elsewhere).
        skip = net if pads_dim1 else net[:, 1:-1]
        net = tf.concat([skip, t], 4)
        channels += growth

    # channels is now 64 + 6 * 32 = 256, the input width expected by conv2.
    net = tf.nn.relu(BatchNorm(net, is_train, name='fbn1'))
    net = Conv3D(tf.pad(net, sp, mode='CONSTANT'), [1, 3, 3, 256, 256],
                 [1, 1, 1, 1, 1], 'VALID', name='conv2')
    net = tf.nn.relu(net)

    # Head 1: two 1x1x1 convs ending in 3*uf*uf channels.
    residual = Conv3D(net, [1, 1, 1, 256, 256], [1, 1, 1, 1, 1],
                      'VALID', name='rconv1')
    residual = tf.nn.relu(residual)
    residual = Conv3D(residual, [1, 1, 1, 256, 3 * uf * uf], [1, 1, 1, 1, 1],
                      'VALID', name='rconv2')

    # Head 2: two 1x1x1 convs ending in 5*5*uf*uf channels.
    filters = Conv3D(net, [1, 1, 1, 256, 512], [1, 1, 1, 1, 1],
                     'VALID', name='fconv1')
    filters = tf.nn.relu(filters)
    filters = Conv3D(filters, [1, 1, 1, 512, 1 * 5 * 5 * uf * uf],
                     [1, 1, 1, 1, 1], 'VALID', name='fconv2')

    # Split the channel axis into [25, uf*uf] and normalise over the 25.
    dyn_shape = tf.shape(filters)
    filters = tf.reshape(
        filters,
        [dyn_shape[0], dyn_shape[1], dyn_shape[2], dyn_shape[3], 25, uf * uf])
    filters = tf.nn.softmax(filters, dim=4)

    return filters, residual
def FR_28L(x, is_train, uf=4):
    """Build the FR_28L graph and return its two output tensors.

    The graph is an initial convolution, twelve densely connected blocks
    (each block's output is concatenated onto its input along axis 4),
    a fusion convolution, and two 1x1x1 convolution heads.

    Args:
        x: Input tensor. It is padded with ``sp`` (a 5-entry padding spec)
            and fed to a convolution whose kernel shape is [1,3,3,3,64],
            so presumably a 5-D tensor with 3 channels last -- TODO
            confirm against the caller.
        is_train: Forwarded unchanged to every ``BatchNorm`` call.
        uf: Integer factor; the head output channel counts are
            ``3*uf*uf`` and ``25*uf*uf``.

    Returns:
        A tuple ``(f, r)``. ``f`` is the 'fconv2' output reshaped so its
        last two axes are ``[25, uf*uf]`` and passed through a softmax
        over axis 4. ``r`` is the raw 'rconv2' output with ``3*uf*uf``
        channels. (Presumably filter weights and a residual image --
        NOTE(review): confirm with the consumer of this function.)
    """
    growth = 16           # channels appended by every dense block
    channels = 64         # running channel count of the feature tensor
    full_pad_blocks = 9   # blocks 1..9 pad dimension 1 as well (stp)
    total_blocks = 12     # blocks 10..12 pad dimensions 2 and 3 only (sp)

    net = Conv3D(tf.pad(x, sp, mode='CONSTANT'), [1, 3, 3, 3, channels],
                 [1, 1, 1, 1, 1], 'VALID', name='conv1')

    for idx in range(1, total_blocks + 1):
        tag = str(idx)
        pads_dim1 = idx <= full_pad_blocks

        # BN -> ReLU -> 1x1x1 conv
        t = tf.nn.relu(BatchNorm(net, is_train, name='Rbn' + tag + 'a'))
        t = Conv3D(t, [1, 1, 1, channels, channels], [1, 1, 1, 1, 1],
                   'VALID', name='Rconv' + tag + 'a')

        # BN -> ReLU -> padded 3x3x3 conv producing `growth` channels
        t = tf.nn.relu(BatchNorm(t, is_train, name='Rbn' + tag + 'b'))
        padded = tf.pad(t, stp if pads_dim1 else sp, mode='CONSTANT')
        t = Conv3D(padded, [3, 3, 3, channels, growth], [1, 1, 1, 1, 1],
                   'VALID', name='Rconv' + tag + 'b')

        # Without padding on dimension 1 the skip tensor is cropped by one
        # element at each end of that dimension before concatenation
        # (presumably to match the VALID conv output -- Conv3D is defined
        # elsewhere).
        skip = net if pads_dim1 else net[:, 1:-1]
        net = tf.concat([skip, t], 4)
        channels += growth

    # channels is now 64 + 12 * 16 = 256, the input width expected by conv2.
    net = tf.nn.relu(BatchNorm(net, is_train, name='fbn1'))
    net = Conv3D(tf.pad(net, sp, mode='CONSTANT'), [1, 3, 3, 256, 256],
                 [1, 1, 1, 1, 1], 'VALID', name='conv2')
    net = tf.nn.relu(net)

    # Head 1: two 1x1x1 convs ending in 3*uf*uf channels.
    residual = Conv3D(net, [1, 1, 1, 256, 256], [1, 1, 1, 1, 1],
                      'VALID', name='rconv1')
    residual = tf.nn.relu(residual)
    residual = Conv3D(residual, [1, 1, 1, 256, 3 * uf * uf], [1, 1, 1, 1, 1],
                      'VALID', name='rconv2')

    # Head 2: two 1x1x1 convs ending in 5*5*uf*uf channels.
    filters = Conv3D(net, [1, 1, 1, 256, 512], [1, 1, 1, 1, 1],
                     'VALID', name='fconv1')
    filters = tf.nn.relu(filters)
    filters = Conv3D(filters, [1, 1, 1, 512, 1 * 5 * 5 * uf * uf],
                     [1, 1, 1, 1, 1], 'VALID', name='fconv2')

    # Split the channel axis into [25, uf*uf] and normalise over the 25.
    dyn_shape = tf.shape(filters)
    filters = tf.reshape(
        filters,
        [dyn_shape[0], dyn_shape[1], dyn_shape[2], dyn_shape[3], 25, uf * uf])
    filters = tf.nn.softmax(filters, dim=4)

    return filters, residual
def FR_52L(x, is_train, uf=4):
    """Build the FR_52L graph and return its two output tensors.

    The graph is an initial convolution, twenty-four densely connected
    blocks (each block's output is concatenated onto its input along
    axis 4), a fusion convolution, and two 1x1x1 convolution heads.

    Args:
        x: Input tensor. It is padded with ``sp`` (a 5-entry padding spec)
            and fed to a convolution whose kernel shape is [1,3,3,3,64],
            so presumably a 5-D tensor with 3 channels last -- TODO
            confirm against the caller.
        is_train: Forwarded unchanged to every ``BatchNorm`` call.
        uf: Integer factor; the head output channel counts are
            ``3*uf*uf`` and ``25*uf*uf``.

    Returns:
        A tuple ``(f, r)``. ``f`` is the 'fconv2' output reshaped so its
        last two axes are ``[25, uf*uf]`` and passed through a softmax
        over axis 4. ``r`` is the raw 'rconv2' output with ``3*uf*uf``
        channels. (Presumably filter weights and a residual image --
        NOTE(review): confirm with the consumer of this function.)
    """
    growth = 16            # channels appended by every dense block
    channels = 64          # running channel count of the feature tensor
    full_pad_blocks = 21   # blocks 1..21 pad dimension 1 as well (stp)
    total_blocks = 24      # blocks 22..24 pad dimensions 2 and 3 only (sp)

    net = Conv3D(tf.pad(x, sp, mode='CONSTANT'), [1, 3, 3, 3, channels],
                 [1, 1, 1, 1, 1], 'VALID', name='conv1')

    for idx in range(1, total_blocks + 1):
        tag = str(idx)
        pads_dim1 = idx <= full_pad_blocks

        # BN -> ReLU -> 1x1x1 conv
        t = tf.nn.relu(BatchNorm(net, is_train, name='Rbn' + tag + 'a'))
        t = Conv3D(t, [1, 1, 1, channels, channels], [1, 1, 1, 1, 1],
                   'VALID', name='Rconv' + tag + 'a')

        # BN -> ReLU -> padded 3x3x3 conv producing `growth` channels
        t = tf.nn.relu(BatchNorm(t, is_train, name='Rbn' + tag + 'b'))
        padded = tf.pad(t, stp if pads_dim1 else sp, mode='CONSTANT')
        t = Conv3D(padded, [3, 3, 3, channels, growth], [1, 1, 1, 1, 1],
                   'VALID', name='Rconv' + tag + 'b')

        # Without padding on dimension 1 the skip tensor is cropped by one
        # element at each end of that dimension before concatenation
        # (presumably to match the VALID conv output -- Conv3D is defined
        # elsewhere).
        skip = net if pads_dim1 else net[:, 1:-1]
        net = tf.concat([skip, t], 4)
        channels += growth

    # channels is now 64 + 24 * 16 = 448, the input width expected by conv2.
    net = tf.nn.relu(BatchNorm(net, is_train, name='fbn1'))
    net = Conv3D(tf.pad(net, sp, mode='CONSTANT'), [1, 3, 3, 448, 256],
                 [1, 1, 1, 1, 1], 'VALID', name='conv2')
    net = tf.nn.relu(net)

    # Head 1: two 1x1x1 convs ending in 3*uf*uf channels.
    residual = Conv3D(net, [1, 1, 1, 256, 256], [1, 1, 1, 1, 1],
                      'VALID', name='rconv1')
    residual = tf.nn.relu(residual)
    residual = Conv3D(residual, [1, 1, 1, 256, 3 * uf * uf], [1, 1, 1, 1, 1],
                      'VALID', name='rconv2')

    # Head 2: two 1x1x1 convs ending in 5*5*uf*uf channels.
    filters = Conv3D(net, [1, 1, 1, 256, 512], [1, 1, 1, 1, 1],
                     'VALID', name='fconv1')
    filters = tf.nn.relu(filters)
    filters = Conv3D(filters, [1, 1, 1, 512, 1 * 5 * 5 * uf * uf],
                     [1, 1, 1, 1, 1], 'VALID', name='fconv2')

    # Split the channel axis into [25, uf*uf] and normalise over the 25.
    dyn_shape = tf.shape(filters)
    filters = tf.reshape(
        filters,
        [dyn_shape[0], dyn_shape[1], dyn_shape[2], dyn_shape[3], 25, uf * uf])
    filters = tf.nn.softmax(filters, dim=4)

    return filters, residual