-
Notifications
You must be signed in to change notification settings - Fork 0
/
CalculateParamsNV.m
71 lines (40 loc) · 2.33 KB
/
CalculateParamsNV.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
function GMM_Params = CalculateParamsNV(SIFT_data,RGB_data,Log_Likelihoods_SIFT,Log_Likelihoods_RGB)
%CALCULATEPARAMSNV Per-cluster weighted statistics for SIFT and RGB data.
%   GMM_Params = CalculateParamsNV(SIFT_data, RGB_data, ...
%                                  Log_Likelihoods_SIFT, Log_Likelihoods_RGB)
%
%   Inputs
%     SIFT_data, RGB_data   Data matrices with one observation per row.
%     Log_Likelihoods_SIFT  Cell array; element j is a matrix with one row per
%     Log_Likelihoods_RGB   observation of the matching data matrix and one
%                           column per cluster. Elements may be gpuArrays
%                           (they are gathered before use).
%
%   Output
%     GMM_Params  Struct array. Element j holds, for the j-th SIFT entry,
%                 Training_SIFT_mus (K-by-D weighted means),
%                 Training_SIFT_Sigmas (K-by-D weighted per-dimension
%                 variances plus a 1e-6 floor) and Training_SIFT_weights
%                 (1-by-K), and the same three Training_RGB_* fields for the
%                 j-th RGB entry. When both cell arrays are empty an empty
%                 struct array is returned.
%
%   Every model is computed from scratch, so the number of clusters may
%   differ from one cell element to the next.
%
%   NOTE(review): the per-observation weights are exp(LL - max(LL,[],2)) and
%   are NOT normalised to sum to one across clusters, so the returned
%   weights need not sum to 1. Kept as-is; confirm this is intended.

% Start from an empty struct so the output is defined even when both
% likelihood lists are empty.
GMM_Params = struct([]);

%% Create the parameters from the SIFT data
for j = 1 : length(Log_Likelihoods_SIFT)
    [Mus, Sigmas, Weights] = CalculateParamsNV_clusterStats( ...
        SIFT_data, Log_Likelihoods_SIFT{j});
    GMM_Params(j).Training_SIFT_mus     = Mus;
    GMM_Params(j).Training_SIFT_Sigmas  = Sigmas;
    GMM_Params(j).Training_SIFT_weights = Weights;
end

%% Create the parameters from the RGB data
for j = 1 : length(Log_Likelihoods_RGB)
    [Mus, Sigmas, Weights] = CalculateParamsNV_clusterStats( ...
        RGB_data, Log_Likelihoods_RGB{j});
    GMM_Params(j).Training_RGB_mus     = Mus;
    GMM_Params(j).Training_RGB_Sigmas  = Sigmas;
    GMM_Params(j).Training_RGB_weights = Weights;
end
end

function [Mus, Sigmas, Weights] = CalculateParamsNV_clusterStats(Data, LogLikelihood)
% Weighted mean, per-dimension variance and weight of every cluster for one
% log-likelihood matrix. All outputs are gathered back to host memory.
LogLikelihood = gather(LogLikelihood);
numClusters   = size(LogLikelihood, 2);
numSamples    = size(Data, 1);

% Subtract the row-wise maximum before exponentiating so the largest entry
% of every row maps to exp(0) = 1 and exp cannot overflow.
Responsibilities = CalculateParamsNV_toGpu( ...
    exp(LogLikelihood - max(LogLikelihood, [], 2)));

% Counting down makes the first assignment allocate the full-size outputs,
% and, because they are local to this call, nothing from a previously
% processed model (possibly with more clusters) can leak into the result.
for k = numClusters : -1 : 1
    Resp_k  = Responsibilities(:, k)';
    Nonzero = Resp_k > 0;   % observations whose weight did not underflow to 0

    Mus(k, :) = Resp_k * Data / sum(Resp_k);

    % Zero-weight rows contribute nothing to the variance, so skip them.
    Centered = Data(Nonzero, :) - Mus(k, :);
    Sigmas(k, :) = Resp_k(Nonzero) * (Centered.^2) / sum(Resp_k(Nonzero)) + 1e-6;

    Weights(k) = sum(Resp_k) / numSamples;
end

Mus     = gather(Mus);
Sigmas  = gather(Sigmas);
Weights = gather(Weights);
end

function A = CalculateParamsNV_toGpu(A)
% Move A to the GPU when possible; otherwise leave it on the CPU.
try
    A = gpuArray(A);
catch
    % No supported GPU or no Parallel Computing Toolbox: the arithmetic is
    % identical on the CPU, so fall back instead of failing.
end
end