-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathmodel_deep.lua
137 lines (116 loc) · 5.04 KB
/
model_deep.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
require 'nn'
require 'cunn'
require 'cudnn'
require 'nngraph'
-- cnn: VGG-16 with batch norm
--- Single conv -> batch norm -> ReLU unit with stride 1.
-- When the dilation `d` is 1 (or omitted) a cuDNN convolution is used;
-- otherwise nn.SpatialDilatedConvolution with dilation `d` is used.
-- @tparam number c1 number of input channels
-- @tparam number c2 number of output channels
-- @tparam number p spatial padding (same on both axes)
-- @tparam number k square kernel size
-- @tparam[opt=1] number d dilation factor
-- @return nn.Sequential holding the three modules
local function create_conv_unit(c1, c2, p, k, d)
  d = d or 1
  local conv = nn.Sequential()
  if d == 1 then
    -- trailing 1 is the cuDNN `groups` argument
    conv:add(cudnn.SpatialConvolution(c1, c2, k, k, 1, 1, p, p, 1))
  else
    -- Argument order is (nIn, nOut, kW, kH, dW, dH, padW, padH,
    -- dilationW, dilationH): stride 1 like the cuDNN branch, padding p,
    -- dilation d.
    conv:add(nn.SpatialDilatedConvolution(c1, c2, k, k, 1, 1, p, p, d, d))
  end
  -- nil optional args -> library defaults for eps / momentum / affine
  conv:add(cudnn.SpatialBatchNormalization(c2, nil, nil, nil))
  conv:add(cudnn.ReLU(true))
  return conv
end
--- Two-layer encoder stage: (3x3 conv -> batch norm -> ReLU) x 2.
-- Every convolution uses stride 1 and padding 1, so spatial size is
-- preserved; only the first convolution changes the channel count.
-- @tparam number c1 number of input channels
-- @tparam number c2 number of output channels (used by both layers)
-- @return nn.Sequential holding the six modules
local function create_conv_2(c1, c2)
  local conv = nn.Sequential()
  local n_in = c1
  for _ = 1, 2 do
    -- trailing 1 is the cuDNN `groups` argument
    conv:add(cudnn.SpatialConvolution(n_in, c2, 3, 3, 1, 1, 1, 1, 1))
    conv:add(cudnn.SpatialBatchNormalization(c2, nil, nil, nil))
    conv:add(cudnn.ReLU(true))
    n_in = c2
  end
  return conv
end
--- Three-layer encoder stage: (3x3 conv -> batch norm -> ReLU) x 3.
-- Every convolution uses stride 1 and padding 1, so spatial size is
-- preserved; only the first convolution changes the channel count.
-- @tparam number c1 number of input channels
-- @tparam number c2 number of output channels (used by all three layers)
-- @return nn.Sequential holding the nine modules
local function create_conv_3(c1, c2)
  local conv = nn.Sequential()
  local n_in = c1
  for _ = 1, 3 do
    -- trailing 1 is the cuDNN `groups` argument
    conv:add(cudnn.SpatialConvolution(n_in, c2, 3, 3, 1, 1, 1, 1, 1))
    conv:add(cudnn.SpatialBatchNormalization(c2, nil, nil, nil))
    conv:add(cudnn.ReLU(true))
    n_in = c2
  end
  return conv
end
--- Three-layer decoder stage: (3x3 full conv -> batch norm -> ReLU) x 3.
-- Full convolutions use stride 1 and padding 1, which keeps the spatial
-- size unchanged; the first layer maps c1 channels to c2, the rest c2 -> c2.
-- @tparam number c1 number of input channels
-- @tparam number c2 number of output channels
-- @return nn.Sequential holding the nine modules
local function create_deconv_3(c1, c2)
  local stage = nn.Sequential()
  local planes_in = c1
  for _ = 1, 3 do
    stage:add(nn.SpatialFullConvolution(planes_in, c2, 3, 3, 1, 1, 1, 1))
    stage:add(cudnn.SpatialBatchNormalization(c2, nil, nil, nil))
    stage:add(cudnn.ReLU(true))
    planes_in = c2
  end
  return stage
end
--- Two-layer decoder stage: (3x3 full conv -> batch norm -> ReLU) x 2.
-- Full convolutions use stride 1 and padding 1, which keeps the spatial
-- size unchanged; the first layer maps c1 channels to c2, the second c2 -> c2.
-- @tparam number c1 number of input channels
-- @tparam number c2 number of output channels
-- @return nn.Sequential holding the six modules
local function create_deconv_2(c1, c2)
  local stage = nn.Sequential()
  local planes_in = c1
  for _ = 1, 2 do
    stage:add(nn.SpatialFullConvolution(planes_in, c2, 3, 3, 1, 1, 1, 1))
    stage:add(cudnn.SpatialBatchNormalization(c2, nil, nil, nil))
    stage:add(cudnn.ReLU(true))
    planes_in = c2
  end
  return stage
end
-- local function create_simple_conv(c1,c2,)
--- Build an encoder-decoder network as an nngraph gModule.
-- Encoder: five conv stages (2, 2, 3, 3, 3 conv layers producing 64, 128,
-- 256, 512, 512 channels) separated by four 2x2 / stride-2 max-pooling layers.
-- Decoder: full-convolution stages in reverse order; after each stage the map
-- is upsampled with nn.SpatialMaxUnpooling tied to the matching encoder pool
-- (reusing its max indices) and concatenated with that encoder stage's
-- features before entering the next decoder stage.
-- @tparam table config reads `config.input_channel` and `config.output_channel`
-- @return nn.gModule with one input node and one output node
-- NOTE(review): with four pool/unpool pairs the input height and width
-- presumably need to be divisible by 16 for the JoinTable concatenations to
-- line up — confirm.
local function create_model(config)
-- Earlier per-layer variant, kept for reference.
-- NOTE(review): these calls pass (c1, c2, 3, 1, 1), which create_conv_unit's
-- (c1, c2, p, k, d) signature reads as padding 3 / kernel 1; the arguments
-- look swapped for a 3x3 VGG layer — check before re-enabling.
-- local conv1_1 = create_conv_unit(3,64,3,1,1)
-- local conv1_2 = create_conv_unit(64,64,3,1,1)
-- local conv2_1 = create_conv_unit(64,128,3,1,1)
-- local conv2_2 = create_conv_unit(128,128,3,1,1)
-- local conv3_1 = create_conv_unit(128,256,3,1,1)
-- local conv3_2 = create_conv_unit(256,256,3,1,1)
-- local conv3_3 = create_conv_unit(256,256,3,1,1)
-- local conv4_1 = create_conv_unit(256,512,3,1,1)
-- local conv4_2 = create_conv_unit(512,512,3,1,1)
-- local conv4_3 = create_conv_unit(512,512,3,1,1)
-- local conv5_1 = create_conv_unit(512,512,3,1,1)
-- local conv5_2 = create_conv_unit(512,512,3,1,1)
-- local conv5_3 = create_conv_unit(512,512,3,1,1)
local input_channel = config.input_channel
local output_channel = config.output_channel
-- encoder: conv stages (spatial size preserved inside each stage)
local conv1 = create_conv_2(input_channel, 64)
local conv2 = create_conv_2(64, 128)
local conv3 = create_conv_3(128, 256)
local conv4 = create_conv_3(256, 512)
local conv5 = create_conv_3(512, 512)
-- 2x2 / stride-2 pools; each is also handed to a SpatialMaxUnpooling below,
-- so every pool module must be used exactly once in the graph.
local pool1 = nn.SpatialMaxPooling(2, 2, 2, 2, 0, 0)
local pool2 = nn.SpatialMaxPooling(2, 2, 2, 2, 0, 0)
local pool3 = nn.SpatialMaxPooling(2, 2, 2, 2, 0, 0)
local pool4 = nn.SpatialMaxPooling(2, 2, 2, 2, 0, 0)
local input = nn.Identity()()
local features1 = conv1(input) -- input_channel -> 64, full resolution
local features2 = conv2(pool1(features1)) -- 64 -> 128, 1/2 resolution
local features3 = conv3(pool2(features2)) -- 128 ->256, 1/4 resolution
local features4 = conv4(pool3(features3)) -- 256 -> 512, 1/8 resolution
local features5 = conv5(pool4(features4)) -- 512 -> 512, 1/16 resolution
-- decoder stages; input channel counts include the concatenated encoder
-- features (e.g. 512+512 = unpooled decoder output + features4)
local deconv5 = create_deconv_3(512, 512)
local deconv4 = create_deconv_3(512+512, 256)
local deconv3 = create_deconv_3(256+256, 128)
local deconv2 = create_deconv_2(128+128, 64)
local deconv1 = create_deconv_2(64+64, output_channel)
-- Remove module 6 (the final ReLU of create_deconv_2's six modules) so the
-- network output is not clamped to be non-negative; the last batch norm stays.
deconv1:remove(6)
-- nn.JoinTable(1, 3): concatenate along dim 1 of a 3-D (C x H x W) input;
-- for a batched 4-D input this becomes dim 2, i.e. channels in both cases.
local defeature5 = deconv5(features5)
local defeature4 = nn.JoinTable(1,3)({nn.SpatialMaxUnpooling(pool4)(defeature5),features4})
local defeature3t = deconv4(defeature4)
local defeature3 = nn.JoinTable(1,3)({nn.SpatialMaxUnpooling(pool3)(defeature3t),features3})
local defeature2t = deconv3(defeature3)
local defeature2 = nn.JoinTable(1,3)({nn.SpatialMaxUnpooling(pool2)(defeature2t),features2})
local defeature1t = deconv2(defeature2)
local defeature1 = nn.JoinTable(1,3)({nn.SpatialMaxUnpooling(pool1)(defeature1t),features1})
local output = deconv1(defeature1)
local model = nn.gModule({input}, {output})
return model
end
-- The module's value is the model factory: require(...)(config) builds a net.
return create_model