-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathCG_CLASSIFY.m
73 lines (63 loc) · 2.33 KB
/
CG_CLASSIFY.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
% Version 1.000
%
% Code provided by Ruslan Salakhutdinov and Geoff Hinton
%
% Permission is granted for anyone to copy, use, modify, or distribute this
% program and accompanying programs and documents for any purpose, provided
% this copyright notice is retained and prominently displayed, along with
% a note saying that the original programs are available from our
% web page.
% The programs and documents are distributed without any warranty, express or
% implied. As the programs were written for research purposes only, they have
% not been tested to the degree that would be advisable in any important
% application. All use of these programs is entirely at the user's own risk.
function [f, df] = CG_CLASSIFY(VV,Dim,XX,target)
% CG_CLASSIFY  Summed cross-entropy loss and its gradient for a classifier
% with two ReLU hidden layers and a softmax output layer.
%
%   [f, df] = CG_CLASSIFY(VV, Dim, XX, target)
%
%   Returns the objective value and its gradient as a pair, the form used
%   by gradient-based minimizers (the CG_ prefix presumably refers to a
%   conjugate-gradient driver such as minimize.m -- confirm with caller).
%
%   VV     - column vector of all weights, packed as
%            [w1(:); w2(:); w_class(:)], where
%              w1      is (Dim(1)+1) x Dim(2)
%              w2      is (Dim(2)+1) x Dim(3)
%              w_class is (Dim(3)+1) x Dim(4)
%            In each matrix the last row multiplies an appended column of
%            ones, i.e. it holds the biases.
%   Dim    - layer sizes [inputs, hidden1, hidden2, classes]; only Dim(1:4)
%            are read.
%   XX     - N x Dim(1) data matrix, one case per row.
%   target - N x Dim(4) target matrix (presumably one-hot rows -- confirm
%            with caller).
%
%   f      - cross-entropy summed (not averaged) over all N cases.
%   df     - gradient of f with respect to VV, packed in the same order as
%            VV (column vector).

l1 = Dim(1);
l2 = Dim(2);
l3 = Dim(3);
l4 = Dim(4);
N = size(XX,1);

% Unpack the flat parameter vector into the three weight matrices.
offset = 0;
w1 = reshape(VV(offset+1:offset+(l1+1)*l2),l1+1,l2);
offset = offset+(l1+1)*l2;
w2 = reshape(VV(offset+1:offset+(l2+1)*l3),l2+1,l3);
offset = offset+(l2+1)*l3;
w_class = reshape(VV(offset+1:offset+(l3+1)*l4),l3+1,l4);

% Forward pass. A column of ones is appended after each layer so that the
% last row of the next weight matrix acts as a bias. The pre-activations
% z1/z2 are kept because the backward pass needs their ReLU masks.
XX = [XX ones(N,1)];
z1 = XX*w1;
w1probs = [max(0, z1) ones(N,1)];
z2 = w1probs*w2;
w2probs = [max(0, z2) ones(N,1)];

% Softmax via a max-shifted log-softmax. Subtracting the row maximum keeps
% exp() from overflowing, and working in the log domain avoids log(0) when
% a probability underflows. The normaliser is replicated across the real
% number of output classes (l4), not a fixed count.
logits = w2probs*w_class;
logits = logits - repmat(max(logits,[],2),1,l4);
logprobs = logits - repmat(log(sum(exp(logits),2)),1,l4);
targetout = exp(logprobs);
f = -sum(sum(target.*logprobs));

% Backward pass. For softmax + cross-entropy, the gradient with respect to
% the logits is (prediction - target).
Ix_class = targetout - target;
dw_class = w2probs'*Ix_class;

% Back-propagate through hidden layer 2. The bias row of w_class is skipped
% because the appended ones column has no incoming weights. The ReLU
% derivative is taken as 0 at z == 0.
Ix2 = (Ix_class*w_class(1:end-1,:)').*(z2 > 0);
dw2 = w1probs'*Ix2;

% Back-propagate through hidden layer 1, again skipping the bias row.
Ix1 = (Ix2*w2(1:end-1,:)').*(z1 > 0);
dw1 = XX'*Ix1;

df = [dw1(:); dw2(:); dw_class(:)];