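-- model.lua
-- Builds the network and the training criterion. Expects the global 'opt'
-- option table (retrain, netType, nClasses, nGPU) to be defined by the caller
-- before this file is run.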
require 'nn'
require 'cunn'
require 'cudnn'
local init_utils = require 'models.init_model_weight'
local parallel_utils = require 'utils.parallel_utils'

model = {}
criterion = {}
if opt.retrain then
   assert(paths.filep(opt.retrain),
      'File not found: ' .. opt.retrain)
   print('===> Loading model from file: ' .. opt.retrain)
   -- Single-GPU retraining: the checkpoint stores a container (e.g. a
   -- DataParallelTable), so pull the underlying network out of it.
   model = torch.load(opt.retrain)
   model = model:get(1)
   --[[
   -- Alternative (disabled): resception checkpoint. Replace the last two
   -- modules with a new opt.nClasses-way classifier.
   -- print(model)
   model:remove(22)
   model:remove(21)
   model:add(nn.Linear(2048, opt.nClasses))
   model:add(cudnn.LogSoftMax())
   --]]
   --[[
   -- Alternative (disabled): inception-v3-2015-12-06 checkpoint. Drop the last
   -- module and append a new opt.nClasses-way classifier.
   model.modules[#model.modules] = nil
   model:get(1):add(nn.View(2048))
   model:get(1):add(nn.Linear(2048, opt.nClasses))
   model:get(1):add(cudnn.LogSoftMax())
   --]]
   --[[
   -- Alternative (disabled): inception-v3-2015-12-05 checkpoint. Keep the
   -- feature encoder and rebuild a separate 1000-way classifier.
   feature_encoder = torch.load(opt.retrain)
   feature_encoder.modules[#feature_encoder.modules] = nil
   feature_encoder.modules[#feature_encoder.modules] = nil
   feature_encoder.modules[#feature_encoder.modules] = nil
   classifier = nn.Sequential()
   classifier:add(nn.View(2048))
   classifier:add(nn.Linear(2048, 1000))
   classifier:add(cudnn.LogSoftMax())
   cudnn.convert(feature_encoder, cudnn)
   cudnn.convert(classifier, cudnn)
   feature_encoder:cuda()
   classifier:cuda()
   --]]
else
   local model_filename = opt.netType .. '.lua'
   local model_filepath = paths.concat('models', model_filename)
   assert(paths.filep(model_filepath),
      'File not found: ' .. model_filepath)
   paths.dofile(model_filepath)
   print('===> Creating model from file: ' .. model_filepath)
   model = createModel()
   -- Optionally re-initialize the weights (e.g. MSR-style init):
   -- init_utils.MSRinit(model)
end
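
-- For reference, models/<opt.netType>.lua is assumed to define a global
-- createModel() that returns the network. A minimal sketch is kept disabled
-- below; the layers and sizes are illustrative only, not this repo's actual
-- definitions.
--[[
function createModel()
   local net = nn.Sequential()
   net:add(cudnn.SpatialConvolution(3, 64, 7, 7, 2, 2, 3, 3))
   net:add(cudnn.ReLU(true))
   net:add(nn.SpatialAdaptiveMaxPooling(1, 1))
   net:add(nn.View(64):setNumInputDims(3))
   net:add(nn.Linear(64, opt.nClasses))
   net:add(cudnn.LogSoftMax())
   return net
end
--]]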
-- Let cuDNN benchmark available algorithms and pick the fastest ones.
cudnn.fastest, cudnn.benchmark = true, true

if opt.nGPU > 1 then
   model = parallel_utils.makeDataParallel(model, opt.nGPU)
end
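
-- parallel_utils.makeDataParallel (utils/parallel_utils.lua, not shown here) is
-- assumed to wrap the network in an nn.DataParallelTable over GPUs 1..opt.nGPU,
-- roughly like the disabled sketch below.
--[[
local function makeDataParallel(net, nGPU)
   local dpt = nn.DataParallelTable(1)             -- split mini-batches along dim 1
   dpt:add(net, torch.range(1, nGPU):totable())    -- replicate on GPUs 1..nGPU
   return dpt
end
--]]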
model:cuda()
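-- ClassNLLCriterion expects log-probabilities, which matches the LogSoftMax
-- output layers used above.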
criterion = nn.ClassNLLCriterion():cuda()
print(model)
print(criterion)
print('===> Loading model complete')
collectgarbage()