-
Notifications
You must be signed in to change notification settings - Fork 0
/
scriptT3.m
66 lines (54 loc) · 2.15 KB
/
scriptT3.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
%% Script for T3
% Transfer learning from a pre-trained AlexNet: freeze the weights of all
% layers except the last fully connected layer, and fine-tune that layer on
% the same training and validation sets used in the previous tasks.
%% import dataset
% Build a labelled datastore from the 'train' folder; the name of each
% subfolder is used as the label of the images it contains.
LazebnikTrainDatasetPath = fullfile('train');
imds = imageDatastore(LazebnikTrainDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
% Split the training images into train and validation sets (85%-15%),
% taking the same fraction from every label, in random order.
quotaForEachLabel = 0.85;
[imdsTrain,imdsValidation] = splitEachLabel(imds,quotaForEachLabel,'randomize');
% The network used below expects 3-channel input. Replicate single-channel
% (grayscale) images along the third dimension. Images that already have
% more than one channel are returned unchanged, because an unconditional
% repmat would turn an RGB image into a 9-channel array.
toThreeChannels = @(I) repmat(I, 1, 1, 1 + 2*(size(I,3) == 1));
imdsTrain.ReadFcn = @(x) toThreeChannels(imread(x));
imdsValidation.ReadFcn = @(x) toThreeChannels(imread(x));
%% Build the transfer-learning network
net = alexnet;
% Keep every pre-trained layer except the last three (the final fully
% connected layer, the softmax layer and the classification output layer).
layersTransfer = net.Layers(1:end-3);
% Freeze the transferred layers. Setting every *LearnRateFactor property
% (e.g. WeightLearnRateFactor / BiasLearnRateFactor) to 0 keeps their
% pre-trained parameters fixed during training. Without this, trainNetwork
% keeps updating them at InitialLearnRate, so the task's requirement to
% "freeze all layers but the last fully connected one" would not be met.
for k = 1:numel(layersTransfer)
    props = properties(layersTransfer(k));
    for p = 1:numel(props)
        if endsWith(props{p}, 'LearnRateFactor')
            layersTransfer(k).(props{p}) = 0;
        end
    end
end
numClasses = numel(categories(imdsTrain.Labels));
%%
% New classifier head. The fully connected layer is sized to the number of
% classes and is the only trainable layer; its learn-rate factors of 20
% scale the global InitialLearnRate for this layer only.
layers = [
    layersTransfer
    fullyConnectedLayer(numClasses,'WeightLearnRateFactor',20,'BiasLearnRateFactor',20)
    softmaxLayer
    classificationLayer];
%% Augment
% Every image is resized to the network's input size.
inputSize = net.Layers(1).InputSize;
targetSize = inputSize(1:3);

% Validation data: resize only, no augmentation.
augimdsValidation = augmentedImageDatastore(targetSize, imdsValidation);

% Training data: resize plus random left-right reflection.
augmenter = imageDataAugmenter('RandXReflection', 1);
augimdsTrain = augmentedImageDatastore(targetSize, imdsTrain, ...
    'DataAugmentation', augmenter);

% Training hyper-parameters for SGD with momentum. Validation metrics are
% computed on augimdsValidation every 3 iterations.
trainOpts = { ...
    'MiniBatchSize',       32, ...
    'MaxEpochs',           6, ...
    'InitialLearnRate',    1e-4, ...
    'Shuffle',             'every-epoch', ...
    'ValidationData',      augimdsValidation, ...
    'ValidationFrequency', 3, ...
    'Verbose',             false, ...
    'Plots',               'training-progress'};
options = trainingOptions('sgdm', trainOpts{:});
netTransfer = trainNetwork(augimdsTrain, layers, options);
%% Evaluate on the test set
% Load the held-out test images; subfolder names are the class labels.
LazebnikTestDatasetPath = fullfile('test');
imdsTest = imageDatastore(LazebnikTestDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
% Reuse the training read function so test images get exactly the same
% channel conversion as the images the network was trained on.
imdsTest.ReadFcn = imdsTrain.ReadFcn;
augimdsTest = augmentedImageDatastore(inputSize(1:3),imdsTest);
[YPred,scores] = classify(netTransfer,augimdsTest);
% Ground-truth labels of the TEST set (not the validation split).
YTest = imdsTest.Labels;
testAccuracy = mean(YPred == YTest);
fprintf('Test accuracy: %.2f%%\n', 100*testAccuracy);
%%
% Confusion matrix on the test set
figure
plotconfusion(YTest,YPred)