-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathIMDBAddMNIST.m
141 lines (106 loc) · 4.42 KB
/
IMDBAddMNIST.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
function [ imdb, select ] = IMDBAddMNIST( imdb, path, task_num, varargin )
% IMDBADDMNIST Add to imdb the MNIST dataset.
%   [IMDB, SELECT] = IMDBADDMNIST(IMDB, PATH, TASK_NUM, ...) reads the MNIST
%   idx files under PATH.path_MNISTroot, stores all pixel data in
%   IMDB.img_MNIST (as single) and all labels in IMDB.lbl_MNIST, and appends
%   one entry per selected sample to IMDB.images.{name,label,set,task}.
%   Each image "name" is a cell holding the sample's global index into
%   IMDB.img_MNIST / IMDB.lbl_MNIST (see opts.preloaded_offset).
% Input:
%   IMDB      existing imdb struct; IMDB.images is created if missing
%   PATH      struct generated by GETPATH() with path.path_OrigNetresponse field adjusted
%             (path.path_MNISTroot is always read; path.path_OrigNetresponse.MNIST
%             only when opts.label is 'probdist')
%   TASK_NUM  the path number (see CNN_CUSTOMTRAIN) that this dataset corresponds to
% Output:
%   IMDB      the updated imdb
%   SELECT    struct with fields 'train' and/or 'val' holding, for each loaded
%             split, the within-split indices of the samples that were added
%             (the MNIST 'test' split is reported as 'val')
% Options:
%   See code comments
%
% Authors: Zhizhong Li
%
% See the COPYING file.
%
% The MNIST helper files are generously provided by Stanford UFLDL
% Source: http://ufldl.stanford.edu/wiki/index.php/Using_the_MNIST_Dataset
opts.partial = 0; % for >0 partial, e.g. 0.3, only include that much portion of # samples. Scalar, or one value per set ([train test]); 0 keeps all samples of that set.
opts.label = 'class'; % 'class' for class label, 'probdist' for a recorded response distribution (uses path.path_OrigNetresponse.MNIST)
opts.trainval = [1 2]; % 1 for train, 2 for val. By default include train+val.
opts.preloaded_offset = [ 0 60000 70000 ]; % Determines how many samples are there in MNIST's train / test.
opts.randstream = []; % use randstream if provided
opts = vl_argparse(opts, varargin);

if ~isfield(imdb, 'images')
  imdb.images.name = [];
  imdb.images.label = [];
  imdb.images.set = [];
  imdb.images.task = [];
end

% Defined up front so the function also returns cleanly when opts.trainval is empty.
select = struct();
sets = {'train', 'test'};

% train or test set
for iset = opts.trainval(:)'
  f = sets{iset};

  % Pixel data goes to fixed global positions so that train and test never collide.
  loadedimg = loadMNISTImages(fullfile(path.path_MNISTroot, [f '-images.idx3-ubyte']));
  ids = ( opts.preloaded_offset(iset)+1 : opts.preloaded_offset(iset+1) )';
  imdb.img_MNIST(:,:, ids) = single(loadedimg);

  loadedlbl = loadMNISTLabels(fullfile(path.path_MNISTroot, [f '-labels.idx1-ubyte']));
  % Digit 0 is stored as class 10 (presumably so that classes are 1-based).
  loadedlbl(loadedlbl==0) = 10;
  imdb.lbl_MNIST(ids,1) = loadedlbl;

  % Image "names" are the global sample indices.
  allnames = num2cell(ids);
  n_all = numel(allnames);

  % Selecting partial. The fraction is decided per set, so a vector such as
  % [0.3 0] subsamples train while keeping all of test.
  if isempty(opts.partial)
    partial = 0;
  elseif numel(opts.partial) == 1
    partial = opts.partial;
  else
    partial = opts.partial(iset);
  end
  if partial > 0
    n_keep = ceil(n_all * partial);
    if isempty(opts.randstream)
      select.(f) = randperm(n_all, n_keep);
    else
      select.(f) = randperm(opts.randstream, n_all, n_keep);
    end
  else
    select.(f) = 1:n_all;
  end

  % names
  names = allnames(select.(f));
  n_set = size(names,1);

  % labels...
  switch opts.label
    case 'probdist'
      % distribution: just load. Assumes lastfc_out holds one column per
      % sample of this split -- TODO confirm against the response files.
      probdist = load(path.path_OrigNetresponse.MNIST.(f));
      classes = num2cell( probdist.lastfc_out, 1 )';
    case 'class'
      % class indices read from the MNIST label file (0 already mapped to 10)
      classes = num2cell( loadedlbl, 2 );
    otherwise
      throw(MException('opts.label:notRecognized', 'probdist/class/multilabel'));
  end
  classes = classes(select.(f),:);

  imdb.images.name = [ imdb.images.name; names ];
  imdb.images.label = [ imdb.images.label; classes ];
  imdb.images.set = [ imdb.images.set; ones(n_set,1) * iset ];
  imdb.images.task = [ imdb.images.task; task_num * ones(n_set, 1) ];
end

% The MNIST 'test' split plays the role of 'val' (opts.trainval index 2).
if isfield(select, 'test'), select.val = select.test; select = rmfield(select, 'test'); end
function images = loadMNISTImages(filename)
%loadMNISTImages Read an MNIST idx3-ubyte image file.
%   IMAGES = loadMNISTImages(FILENAME) returns a numRows x numCols x numImages
%   double array of the raw pixel bytes (0-255, not rescaled). The dimensions
%   are taken from the big-endian file header (28x28 for the standard MNIST
%   files).
%   Errors if the file cannot be opened, the magic number is not 2051, or the
%   amount of pixel data does not match the header. The file is closed on
%   every exit path, including errors.
fp = fopen(filename, 'rb');
assert(fp ~= -1, ['Could not open ', filename, '']);
% Close the file when this function exits, whether normally or via an error.
closer = onCleanup(@() fclose(fp)); %#ok<NASGU>
magic = fread(fp, 1, 'int32', 0, 'ieee-be');
assert(magic == 2051, ['Bad magic number in ', filename, '']);
numImages = fread(fp, 1, 'int32', 0, 'ieee-be');
numRows = fread(fp, 1, 'int32', 0, 'ieee-be');
numCols = fread(fp, 1, 'int32', 0, 'ieee-be');
images = fread(fp, inf, 'unsigned char');
assert(numel(images) == numRows * numCols * numImages, ...
  ['Mismatch in image data size in ', filename]);
% Bytes vary fastest along the columns of each image, so read as
% (cols, rows, n) and swap the first two dimensions.
images = reshape(images, numCols, numRows, numImages);
images = permute(images, [2 1 3]);
function labels = loadMNISTLabels(filename)
%loadMNISTLabels Read an MNIST idx1-ubyte label file.
%   LABELS = loadMNISTLabels(FILENAME) returns a numLabels x 1 double column
%   of the raw label bytes exactly as stored in the file (no remapping).
%   Errors if the file cannot be opened, the magic number is not 2049, or the
%   number of labels read differs from the big-endian header count. The file
%   is closed on every exit path, including errors.
fp = fopen(filename, 'rb');
assert(fp ~= -1, ['Could not open ', filename, '']);
% Close the file when this function exits, whether normally or via an error.
closer = onCleanup(@() fclose(fp)); %#ok<NASGU>
magic = fread(fp, 1, 'int32', 0, 'ieee-be');
assert(magic == 2049, ['Bad magic number in ', filename, '']);
numLabels = fread(fp, 1, 'int32', 0, 'ieee-be');
labels = fread(fp, inf, 'unsigned char');
assert(size(labels,1) == numLabels, 'Mismatch in label count');