Mercurial > hg > camir-aes2014
comparison core/magnatagatune/tests_evals/rbm_subspace/Exp_grad.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
1 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
2 % Experiment with gradient ascent % | |
3 % Project: sub-euclidean distance for music similarity % | |
4 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
5 %% Load features | |
6 %feature_file = 'rel_music_raw_features.mat'; | |
7 feature_file = 'rel_music_raw_features+simdata_ISMIR12.mat'; | |
8 % The four variables are picked by their position in the whos listing, not by name. | |
9 vars = whos('-file', feature_file); | |
10 A = load(feature_file, vars(1:4).name); % load only the first four listed variables | |
11 raw_features = getfield(A, vars(1).name); | |
12 indices = getfield(A, vars(2).name); | |
13 tst_inx = getfield(A, vars(3).name); | |
14 trn_inx = getfield(A, vars(4).name); | |
15 %% Params setting | |
16 ws = [0 5 10 20 30 50 70]; % window size | |
17 dmr = [0 5 10 20 30 50]; % dimension reduction by PCA | |
18 % Parameters of the rbm (if it is used for extraction): | |
19 % hidNum, lr_1, lr_2, mmt and cost. | |
20 % Each one starts as the scalar 0, i.e. a grid with a single entry. | |
21 [hidNum, lr_1, lr_2, mmt, cost] = deal(0); | |
22 | |
23 %% Select parameters (if grid-search is not applied) | |
24 di = 1; | |
25 | |
26 | |
27 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
28 % If grid search is defined: pick the platform-specific directory used to save parameters & results (dir) and its path delimiter (dlm) | |
29 if ~isempty(strfind(computer(),'WIN')) | |
30 dir = 'C:\Pros\Experiments\ISMIR_2013\grad\'; % On Windows platforms | |
31 dlm = '\'; | |
32 elseif ~isempty(strfind(computer(),'linux')) || ~isempty(strfind(computer(),'LNX')) % 'LNX' matches MATLAB names such as GLNXA64; 'linux' presumably targets Octave | |
33 dir = '/home/funzi/Documents/Experiments/ISMIR_2013/grad/'; % On Linux platforms | |
34 dlm = '/'; | |
35 else % any other platform: fail early instead of leaving dir/dlm unset (dir would silently resolve to the built-in function) | |
36 error('Exp_grad:platform','Unsupported platform "%s": set dir and dlm manually.',computer()); | |
37 end | |
38 EXT_TYPE = 2; % feature extraction: 1 = PCA, 2 = rbm, anything else = none | |
39 % Append the extractor-specific sub-directory to the output path. | |
40 if EXT_TYPE == 1 | |
41 dir = [dir 'pca' dlm]; | |
42 elseif EXT_TYPE == 2 | |
43 dir = [dir 'rbm' dlm]; | |
44 % rbm parameter grids replacing the scalar defaults | |
45 hidNum = [100 500 1000 1200]; | |
46 lr_1 = [0.5 0.7]; | |
47 lr_2 = 0.7; | |
48 mmt = 0.1; | |
49 cost = 0.00002; | |
50 else | |
51 dir = [dir 'none' dlm]; | |
52 end | |
53 | |
54 w_num = size(ws,2); % number of window sizes to evaluate | |
55 | |
56 for iiii = 1:200 % repetitions; set the higher range to search for better features in case of extraction using rbm | |
57 log_file = strcat(dir,'exp',num2str(iiii),'.mat') % no semicolon: the log path is echoed for every repetition | |
58 inx = resume_from_grid(log_file,8 + w_num); % resume state; 8+w_num matches the leading [hi l1i l1i mi ci wi cr_ cr max_] fields written by logging - presumably the last logged row, TODO confirm | |
59 if inx(end-w_num+1:end)==ones(1,w_num) % vector condition: true only when every trailing entry equals 1 | |
60 max_= zeros(1,w_num); % an all-ones tail is treated as "no best yet" (presumably the helper's default when nothing was logged - confirm) | |
61 else | |
62 max_ = inx(end-w_num+1:end); % restore the best training rate per window from the resume state | |
63 end | |
64 | |
65 results = zeros(1,w_num); % test rate stored whenever max_(wi) is improved | |
66 W_max = cell(1,w_num); % W1 of the best run per window (filled only when EXT_TYPE==2) | |
67 vB_max = cell(1,w_num); % vB1 of the best run per window | |
68 hB_max = cell(1,w_num); % hB1 of the best run per window | |
69 Ws_max = cell(1,w_num); % Ws returned by gradient_ascent for the best run per window | |
70 | |
71 for hi = inx(1):size(hidNum,2) % grid over hidNum, starting at the resumed index | |
72 for l1i = inx(2):size(lr_1,2) % grid over lr_1 | |
73 % for l1i = inx(3):size(lr_2,2) % disabled loop: lr_2 and inx(3) are never read in this script | |
74 for mi = inx(4):size(mmt,2) % grid over mmt | |
75 for ci = inx(5):size(cost,2) % grid over cost | |
76 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
77 %% Feature extraction | |
78 features = raw_features; % default: raw features, kept when EXT_TYPE is neither 1 nor 2 | |
79 switch (EXT_TYPE) | |
80 case 1 % Using PCA | |
81 assert(~exist('OCTAVE_VERSION'),'This script cannot run in octave'); % aborts when OCTAVE_VERSION exists, i.e. under Octave | |
82 coeff = princomp(raw_features); | |
83 coeff = coeff(:,1:6); % keep the first 6 columns (best = 6) | |
84 features = raw_features*coeff; | |
85 % normalizing: columns whose min equals max are dropped, the rest are rescaled by (x-min)/(max-min) | |
86 mm = minmax(features')'; | |
87 inn= (find(mm(1,:)~=mm(2,:))); | |
88 mm = mm(:,inn); | |
89 features = features(:,inn); | |
90 features = (features-repmat(mm(1,:),size(features,1),1))./(repmat(mm(2,:),size(features,1),1)-repmat(mm(1,:),size(features,1),1)); | |
91 case 2 % Using rbm | |
92 conf.hidNum = hidNum(hi); | |
93 conf.eNum = 100; | |
94 conf.sNum = size(raw_features,1); % number of rows of raw_features | |
95 conf.bNum = 1; | |
96 conf.gNum = 1; | |
97 conf.params = [lr_1(l1i) lr_1(l1i) mmt(mi) cost(ci)]; % NOTE(review): lr_1(l1i) is passed twice and lr_2 is not used - confirm this is intended | |
98 conf.N = 50; | |
99 conf.MAX_INC = 10; | |
100 W1 = zeros(0,0); % empty matrix handed to training_rbm_ as the initial W1 | |
101 [W1 vB1 hB1] = training_rbm_(conf,W1,raw_features); | |
102 features = logistic(raw_features*W1 + repmat(hB1,conf.sNum,1)); % logistic of the affine map given by W1 and hB1 becomes the feature matrix | |
103 otherwise | |
104 % normalizing | |
105 % mm = minmax(features')'; | |
106 % inn= (find(mm(1,:)~=mm(2,:))); | |
107 % mm = mm(:,inn); | |
108 % features = features(:,inn); | |
109 % features = (features-repmat(mm(1,:),size(features,1),1))./(repmat(mm(2,:),size(features,1),1)-repmat(mm(1,:),size(features,1),1)); | |
110 end | |
111 | |
112 for wi = inx(6):w_num % loop over window sizes, starting at the resumed index | |
113 %% Sub-euclidean computation | |
114 w = ws(wi); % w = subspace window size | |
115 num_case = size(trn_inx,1); % also bounds the loop over the test cells further down - assumes tst_inx yields as many cells, TODO confirm | |
116 [trnd_12 trnd_13] = subspace_distances(trn_inx,features,indices,w,0); | |
117 [tstd_12 tstd_13] = subspace_distances(tst_inx,features,indices,w,0); | |
118 cr_ = 0; % correct rate for training | |
119 cr = 0; % correct rate for testing | |
120 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
121 %% CODE HERE %% | |
122 [Ws cr_] = gradient_ascent(trnd_12,trnd_13,0.1,0.1,0.00002); % fixed settings; the meaning of the three constants is defined by gradient_ascent | |
123 | |
124 for i = 1:num_case | |
125 cr = cr + sum((tstd_13{i}-tstd_12{i})*Ws{i}' > 0, 1)/size(tstd_12{i},1); % fraction of rows of test cell i where (d13 - d12)*Ws{i}' is positive | |
126 end | |
127 cr = cr/num_case; % average over the cases | |
128 if cr_>max_(wi) % the best run per window is selected on the training rate | |
129 max_(wi) = cr_; | |
130 results(wi) = cr; | |
131 if EXT_TYPE==2 | |
132 W_max{wi} = W1; | |
133 vB_max{wi} = vB1; | |
134 hB_max{wi} = hB1; | |
135 Ws_max{wi} = Ws; | |
136 end | |
137 end | |
138 | |
139 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
140 fprintf('[window|train|test]= %2d |%f |%f\n',w,cr_,cr); | |
141 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
142 % Using the logging function to save parameters | |
143 % and the result for plotting or in grid search | |
144 switch EXT_TYPE | |
145 case 1 | |
146 % logging(log_file,[100 100 100 100 100 wi cr_ cr max_]); | |
147 case 2 | |
148 logging(log_file,[hi l1i l1i mi ci wi cr_ cr max_ conf.hidNum conf.eNum conf.params]); % l1i is written in both the second and third slots | |
149 otherwise | |
150 logging(log_file,[100 100 100 100 100 wi cr_ cr max_]); % 100 is written in the five grid-index slots | |
151 end | |
152 end | |
153 inx(6)=1; % the resumed start index applies only to the first pass; later passes start at 1 | |
154 end | |
155 inx(5) = 1; % same reset for the cost loop | |
156 end | |
157 inx(4) = 1; % same reset for the mmt loop | |
158 end | |
159 inx(2) = 1; % same reset for the lr_1 loop | |
160 end | |
161 inx(1) = 1; % same reset for the hidNum loop | |
162 %% Test on best features | |
163 | |
164 save(strcat(dir,'res_',num2str(iiii),'.mat'),'max_','results','W_max','vB_max','hB_max','Ws_max','ws'); % one result file per repetition | |
165 [dummy pos] = max(max_); % window with the highest training rate | |
166 fprintf('Accuracy (RBM best fts): w = %d train = %f test = %f\n',ws(pos),max_(pos),results(pos)); | |
167 clc; % NOTE(review): clears the command window right after the summary line is printed | |
168 end | |
169 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
170 clear; % removes all variables from the workspace once the script finishes | |