Mercurial > hg > camir-aes2014
comparison core/magnatagatune/tests_evals/do_test_rounds.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [out] = do_test_rounds(trainfun, X, simdata, trainparams, fparams, ...
    paramhash, paramhash_train, clips)
% DO_TEST_ROUNDS  Cross-validated training/testing of a similarity measure.
%
% Runs the learner TRAINFUN over every cross-validation fold and every
% training-set size increment found in SIMDATA, evaluates the learnt
% distance measure on training, test and left-out training data, and
% collects the statistics into OUT (plus per-fold diagnostics in DIAG,
% both saved to a runlog .mat file named from the parameter hashes).
%
% Inputs:
%   trainfun        - handle of the training wrapper, called as
%                     feval(trainfun, Xtrain, Ytrain, trainparams)
%   X               - feature matrix for all clips
%   simdata         - struct with cv partitions: partBinTrn {fold, size},
%                     partBinTst {fold}, optionally partBinNoTrn and
%                     clip_type
%   trainparams     - struct of training parameters (may carry .deltafun)
%   fparams         - feature parameters, stored verbatim in the output
%   paramhash, paramhash_train - hash strings used to name runlog files
%   clips           - clip collection object (provides .id())
%
% Output:
%   out             - struct of aggregated evaluation results

% ---
% DEBUG hook: shuffling of the training set (disabled)
% ---
% simdata = mixup(simdata);

% Clip class used to wrap clip id lists for the evaluation routine.
if isfield(simdata, 'clip_type')
    clip_type = simdata.clip_type;
else
    clip_type = 'MTTClip';
end

nTestSets   = size(simdata.partBinTst, 2); % number of cv folds
ntrainsizes = size(simdata.partBinTrn, 2); % number of training-set sizes

for m = 1:ntrainsizes

    % Per-fold result accumulators; -1 marks a failed / skipped fold.
    ok_train       = zeros(2, nTestSets);
    ok_test        = zeros(2, nTestSets);
    equal_test     = zeros(1, nTestSets);
    ok_notin_train = zeros(2, nTestSets);

    % NOTE(review): cell outputs (A, dout, Ytrain, ...) grow dynamically;
    % preallocation would be needed to re-enable the parfor variant below.
    % parfor
    for k = 1:nTestSets

        % runlog mlr
        try

            % ---
            % Assemble training constraints and features for this fold.
            % DEBUG: the similarity data in Ytrain and Ytest seems correct.
            % ---
            [clips_train{k}, Xtrain, Ytrain{k}] ...
                = get_data_compact(clips, X, simdata.partBinTrn{k,m});
            Ytest{k} = {};

            % Training step.
            [A{k}, dout{k}] = feval(trainfun, Xtrain, Ytrain{k}, trainparams);

            % ---
            % Test step: pick the distance-measure class that interprets
            % the learnt model A{k}.
            % TODO: the distmeasure object could be created by the wrapper!
            % ---
            if isfield(dout{k}, 'interpreter')
                interpreter = str2func(dout{k}.interpreter);
            else
                % Only for backward compatibility.
                % warning ('legacy implementation of dist measure');
                if isnumeric(A{k})
                    % Mahalanobis case (numeric metric matrix).
                    interpreter = str2func('DistMeasureMahal');
                else
                    % Neural-network case: A{k} is a neural net object.
                    interpreter = str2func('DistMeasureGeneric');
                end
            end

            if isfield(trainparams, 'deltafun')
                % Custom delta function supplied in the parameters.
                diss = feval(interpreter, clips, A{k}, X, ...
                    str2func(trainparams.deltafun), trainparams.deltafun_params);
            else
                % Standard delta.
                % ---
                % TODO: the default delta differs between similarity
                % measures; except for the mahal measure this should be
                % specified explicitly.
                % ---
                diss = feval(interpreter, clips, A{k}, X);
            end

            % Evaluate on the training data.
            [ok_train(:,k)] = metric_fulfills_ranking ...
                (diss, Ytrain{k}, feval(clip_type, clips_train{k}));

            % Assemble the test data for this fold.
            [clips_test{k}, Xtest, Ytest{k}] ...
                = get_data_compact(clips, X, simdata.partBinTst{k});

            % Evaluate on the test data.
            [ok_test(:,k), equal_test(k)] = metric_fulfills_ranking ...
                (diss, Ytest{k}, feval(clip_type, clips_test{k}));
            cprint(3, '%2.2f %2.2f fold performance', ok_test(:,k));

            % ---
            % Extra diagnostics for MLR.
            % TODO: make this wrappeable
            % ---
            if isequal(trainfun, @mlr_wrapper)
                dout{k}.mprperf = mlr_test(A{k}, 0, Xtrain, ...
                    Ytrain{k}(:,1:2), Xtest, Ytest{k}(:,1:2)) ;
            end

            % ---
            % Evaluate on the unused remainder of the training set,
            % when such a partition exists.
            % ---
            if isfield(simdata, 'partBinNoTrn')
                if ~isempty(simdata.partBinNoTrn{k,m})
                    [clips_notin_train{k}, X_notin_train, Y_notin_train{k}] ...
                        = get_data_compact(clips, X, simdata.partBinNoTrn{k,m});

                    % NOTE(review): this overwrites equal_test(k) set by the
                    % test-set evaluation above — confirm this is intended.
                    [ok_notin_train(:,k), equal_test(k)] = metric_fulfills_ranking ...
                        (diss, Y_notin_train{k}, feval(clip_type, clips_notin_train{k}));
                else
                    % No left-out data available for this fold.
                    ok_notin_train(:,k) = -1;
                end
            else
                ok_notin_train(:,k) = -1;
            end

        catch err

            % ---
            % Training or testing failed: log, mark fold invalid (-1),
            % and dump the configuration for post-mortem analysis.
            % ---
            print_error(err);

            A{k}    = [];
            dout{k} = -1;

            ok_test(:,k)        = -1;
            ok_train(:,k)       = -1;
            ok_notin_train(:,k) = -1;
            equal_test(k)       = -1;

            % Save feature, system and data configuration and indicate failure.
            xml_save(sprintf('runlog_%s.%s_trainparam.xml', ...
                paramhash, paramhash_train), trainparams);
            xml_save(sprintf('runlog_%s.%s_err.xml', ...
                paramhash, paramhash_train), print_error(err));
        end
    end

    if ntrainsizes ~= 1

        % Save elaborate per-training-size statistics.
        % NOTE(review): size() on the partition cells returns a [rows cols]
        % vector, so "/" here is matrix right division yielding a scalar —
        % presumably numel() ratios were intended; verify before changing.
        size_sum = 0;
        for i = 1:nTestSets
            size_sum = size_sum + size(simdata.partBinTrn{i,m}) / size(simdata.partBinTrn{i,end});
        end
        size_sum = size_sum / nTestSets;

        out.inctrain.trainfrac(:, m)     = size_sum;
        out.inctrain.dataPartition(:, m) = 0;

        % ---
        % NOTE: the max value is important for debugging, especially when
        % the maximal training success is reached in the middle of the
        % data set.
        % ---
        % out.inctrain.max_ok_test(:, m) = max(ok_test, 2);
        out.inctrain.mean_ok_test(:, m) = mean(ok_test(:, ok_test(1,:) >= 0), 2);
        out.inctrain.var_ok_test(:, m)  = var(ok_test(:, ok_test(1,:) >= 0), 0, 2);
        out.inctrain.equal_test(m)      = median(equal_test);

        out.inctrain.mean_ok_train(:, m) = mean(ok_train(:, ok_train(1,:) >= 0), 2);
        out.inctrain.var_ok_train(:, m)  = var(ok_train(:, ok_train(1,:) >= 0), 0, 2);

        % ---
        % TODO: DEBUG: this does not work correctly
        % maybe thats also true for the above?
        % ---
        out.inctrain.mean_ok_notin_train(:, m) = mean(ok_notin_train(:, ok_notin_train(1,:) >= 0), 2);
        out.inctrain.var_ok_notin_train(:, m)  = var(ok_notin_train(:, ok_notin_train(1,:) >= 0), 0, 2);

        diag.inctrain(m).ok_train       = ok_train;
        diag.inctrain(m).ok_test        = ok_test;
        diag.inctrain(m).ok_notin_train = ok_notin_train;
        diag.inctrain(m).equal_test     = equal_test;
    end

    % ---
    % Save traditional information for the full training set.
    % NOTE(review): comparing size() vectors with == yields a logical
    % vector; the if-branch fires only when ALL dimensions match —
    % isequal() would state the intent more clearly.
    % ---
    if size(simdata.partBinTrn{1,m}) == size(simdata.partBinTrn{1,end})

        % out.max_ok_test = max(ok_test, 2);
        out.mean_ok_test = mean(ok_test(:, ok_test(1,:) >= 0), 2);
        out.var_ok_test  = var(ok_test(:, ok_test(1,:) >= 0), 0, 2);
        out.equal_test   = median(equal_test);

        out.mean_ok_train = mean(ok_train(:, ok_train(1,:) >= 0), 2);
        out.var_ok_train  = var(ok_train(:, ok_train(1,:) >= 0), 0, 2);

        % ---
        % TODO: DEBUG: this does not work correctly
        % ---
        out.mean_ok_notin_train = mean(ok_notin_train(:, ok_notin_train(1,:) >= 0), 2);
        out.var_ok_notin_train  = var(ok_notin_train(:, ok_notin_train(1,:) >= 0), 0, 2);

        % ---
        % Pick the winning fold: prefer the weighted measure (row 2)
        % when it carries any signal, otherwise fall back to row 1.
        % ---
        if max(ok_test(2,:)) > 0
            [~, best] = max(ok_test(2,:));
        else
            [~, best] = max(ok_test(1,:));
        end

        diag.A    = A;
        diag.diag = dout;

        diag.ok_test        = ok_test;
        diag.equal_test     = equal_test;
        diag.ok_train       = ok_train;
        diag.ok_notin_train = ok_notin_train;

        % Keep the winning metric matrix and its diagnostics.
        out.best_A    = A{best};
        out.best_diag = dout{best};
        out.best_idx  = best;

    end
end

% Record the run configuration alongside the results.
out.camirrev      = camirversion();
out.fparams       = fparams;
out.trainfun      = trainfun;
out.trainparams   = trainparams;
out.clip_ids      = clips.id();
out.dataPartition = [];
% NOTE(review): stores the struct-array size of simdata, not similarity
% data — possibly a placeholder; confirm downstream usage.
out.Y             = size(simdata);
% ---
% NOTE: saving the full Y matrices takes A LOT OF DISC SPACE
% ---
% out.Ytrain = Ytrain{end};
% out.Ytest = Ytest{end};

% ---
% Persist results and diagnostics to disk.
% ---
save(sprintf('runlog_%s.%s_results.mat', ...
    paramhash, paramhash_train), ...
    'out', 'diag');
end