comparison core/magnatagatune/tests_evals/do_test_rounds.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
1 function [out]= do_test_rounds(trainfun, X, simdata, trainparams, fparams,...
2 paramhash, paramhash_train, clips)
% ---
% DO_TEST_ROUNDS runs cross-validated train/test rounds for a
% similarity-metric learner and collects ranking-fulfilment statistics.
%
% Inputs:
%   trainfun        - function handle; called as
%                     [A, dout] = trainfun(Xtrain, Ytrain, trainparams)
%   X               - feature data, sliced per partition by get_data_compact
%   simdata         - struct with cell partitions partBinTrn {k,m} and
%                     partBinTst {k}; optionally partBinNoTrn {k,m} (unused
%                     training remainders) and clip_type (class name string,
%                     defaults to 'MTTClip')
%   trainparams     - struct forwarded to trainfun; may carry deltafun /
%                     deltafun_params to select a special delta function
%   fparams         - feature parameters, stored verbatim in the output
%   paramhash,
%   paramhash_train - strings used to build the runlog_* output file names
%   clips           - clip collection object (must support clips.id() and
%                     construction via feval(clip_type, ...))
%
% Output:
%   out - struct of mean/var/median test statistics, the winning metric
%         (best_A / best_diag / best_idx) and the run configuration.
%
% Side effects: on a training/test failure the parameters and error are
% written to runlog_*_trainparam.xml / runlog_*_err.xml; at the end the
% results are saved to runlog_*_results.mat ('out' and 'diag').
% ---
3
4 % ---
5 % DEBUG: we mix up the training set
6 % ---
7 % simdata = mixup(simdata);
8
% determine the clip class used to wrap clip id lists for evaluation
9 if isfield(simdata, 'clip_type');
10 clip_type = simdata.clip_type;
11 else
12 clip_type = 'MTTClip';
13 end
14
15 nTestSets = size(simdata.partBinTst, 2); % num cv bins
16 ntrainsizes = size(simdata.partBinTrn, 2); % num increases of training
17
% outer loop: increasing training-set sizes; inner loop: cv folds
18 for m = 1:ntrainsizes
19
% per-fold success scores; -1 marks a failed or skipped fold
20 ok_train = zeros(2, nTestSets);
21 ok_test = zeros(2, nTestSets);
22 equal_test = zeros(1, nTestSets);
23 ok_notin_train = zeros(2, nTestSets);
24
25 % A = cell(nTestSets,1);
26 % dout = cell(nTestSets,1);
27 % clips_train = cell(nTestSets,1);
28 % clips_test = cell(nTestSets,1);
29 % clips_notin_train = cell(nTestSets,1);
30 % Y_notin_train = cell(nTestSets,1);
31 % Ytrain = cell(nTestSets,1);
32 % Ytest = cell(nTestSets,1);
33 % parfor
34 for k = 1:nTestSets
35
36
37 % runlog mlr
38 try
39
40 % ---
41 % Get the training constraints and features for this round
42 % ---
43 % DEBUG: the similarity data in Ytrain and Ytest seems correct.
44 [clips_train{k}, Xtrain, Ytrain{k}] ...
45 = get_data_compact(clips, X, simdata.partBinTrn{k,m});
46 Ytest{k}={};
47
48 % training step
49 [A{k}, dout{k}] = feval(trainfun, Xtrain, Ytrain{k}, trainparams);
50
51 % ---
52 % test step
53 % TODO: the distmeasure object could be created by the wrapper!
54 % ---
% pick the distance-measure constructor: either declared by the
% trainer via dout.interpreter, or inferred from the type of A{k}
55 if isfield(dout{k},'interpreter');
56 interpreter = str2func(dout{k}.interpreter);
57 else
58 % only for backward compatibility
59 % warning ('legacy implementation of dist measure');
60 if isnumeric(A{k})
61 % mahalanobis case
62
63 % special delta mahalanobis
64 interpreter = str2func('DistMeasureMahal');
65 else
66 % neural network case: A{k} is a neural net object
67 interpreter = str2func('DistMeasureGeneric');
68 end
69 end
70
71 if isfield(trainparams,'deltafun')
72 % special delta
73 diss = feval(interpreter,clips, A{k}, X, str2func(trainparams.deltafun), trainparams.deltafun_params);
74 else
75 % standard
76 % ---
77 % TODO: the default delta is different between
78 % similarity measures. except for the mahalmeasure
79 % this should be specified
80 % ---
81 diss = feval(interpreter, clips, A{k}, X);
82 end
83
84 % test training data
85 [ok_train(:,k)] = metric_fulfills_ranking...
86 (diss, Ytrain{k}, feval(clip_type,clips_train{k}));
87
88 % get test data
89 [clips_test{k}, Xtest, Ytest{k}] ...
90 = get_data_compact(clips, X, simdata.partBinTst{k});
91
92 % diss = DistMeasureMahal(feval(clip_type,clips_test{k}), A{k}, Xtest);
93 % test test data
94 [ok_test(:,k), equal_test(k)] = metric_fulfills_ranking...
95 (diss, Ytest{k}, feval(clip_type,clips_test{k}));
96 cprint(3,'%2.2f %2.2f fold performance', ok_test(:,k));
97
98 % ---
99 % extra diag for MLR
100 % TODO: make this wrappable
101 % ---
102 if isequal(trainfun, @mlr_wrapper)
103 dout{k}.mprperf = mlr_test(A{k}, 0, Xtrain, Ytrain{k}(:,1:2), Xtest, Ytest{k}(:,1:2)) ;
104 end
105
106 % ---
107 % this gives data for the unused training set remainders
108 % ---
109 if isfield(simdata,'partBinNoTrn')
110 if ~isempty(simdata.partBinNoTrn{k,m})
111 [clips_notin_train{k}, X_notin_train, Y_notin_train{k}] ...
112 = get_data_compact(clips, X, simdata.partBinNoTrn{k,m});
113
114 % test unused training data
% NOTE(review): this overwrites the equal_test(k) value just
% produced by the test-set evaluation above — confirm intended
115 [ok_notin_train(:,k), equal_test(k)] = metric_fulfills_ranking...
116 (diss, Y_notin_train{k}, feval(clip_type,clips_notin_train{k}));
117
118 % what to do if there is no data ?
119 else
120 ok_notin_train(:,k) = -1;
121 end
122 else
123 ok_notin_train(:,k) = -1;
124 end
125
126 catch err
127
128 % ---
129 % in case training or test fails
130 % ---
131 print_error(err);
132
% mark the whole fold as failed; -1 entries are filtered out of
% the statistics below via the "(1,:) >= 0" column masks
133 A{k} = [];
134 dout{k} = -1;
135
136 ok_test(:,k) = -1;
137 ok_train(:,k) = -1;
138 ok_notin_train(:,k) = -1;
139 equal_test(k) = -1;
140
141 % ---
142 % save feature, system and data configuration
143 % and indicate failure
144 % ---
145 xml_save(sprintf('runlog_%s.%s_trainparam.xml',...
146 paramhash, paramhash_train), trainparams);
147 xml_save(sprintf('runlog_%s.%s_err.xml',...
148 paramhash, paramhash_train), print_error(err));
149 end
150 end
151
% only collect per-size (incremental training) statistics when more
% than one training size was evaluated
152 if ~(ntrainsizes == 1)
153
154 % save elaborate testing data
% average fraction of the full training set used at size step m
155 size_sum = 0;
156 for i = 1:nTestSets
% NOTE(review): size() returns a [rows cols] vector, so this is a
% vector division (mrdivide) — presumably numel() was intended to
% get a scalar fraction; TODO confirm against stored trainfrac
157 size_sum = size_sum + size(simdata.partBinTrn{i,m}) / size(simdata.partBinTrn{i,end});
158 end
159 size_sum = size_sum / nTestSets;
160
161 out.inctrain.trainfrac(:, m) = size_sum;
162 out.inctrain.dataPartition(:, m) = 0;
163
164 % ---
165 % NOTE: the max value is important for debugging,
166 % especially when the maximal training success is reached
167 % in the middle of the data set
168 % ---
169 % out.inctrain.max_ok_test(:, m) = max(ok_test, 2);
% statistics over successful folds only (failed folds carry -1)
170 out.inctrain.mean_ok_test(:, m) = mean(ok_test(:, ok_test(1,:) >=0), 2);
171 out.inctrain.var_ok_test(:, m) = var(ok_test(:, ok_test(1,:) >=0), 0, 2);
172 out.inctrain.equal_test(m) = median(equal_test);
173
174 out.inctrain.mean_ok_train(:, m) = mean(ok_train(:, ok_train(1,:) >=0), 2);
175 out.inctrain.var_ok_train(:, m) = var(ok_train(:, ok_train(1,:) >=0), 0, 2);
176
177 % ---
178 % TODO: DEBUG: this does not work correctly
179 % maybe thats also true for the above?
180 % ---
181 out.inctrain.mean_ok_notin_train(:, m) = mean(ok_notin_train(:, ok_notin_train(1,:) >=0), 2);
182 out.inctrain.var_ok_notin_train(:, m) = var(ok_notin_train(:, ok_notin_train(1,:) >=0), 0, 2);
183
% NOTE(review): the local variable 'diag' shadows MATLAB's built-in
% diag() for the rest of this function
184 diag.inctrain(m).ok_train = ok_train;
185 diag.inctrain(m).ok_test = ok_test;
186 diag.inctrain(m).ok_notin_train = ok_notin_train;
187 diag.inctrain(m).equal_test = equal_test;
188 end
189
190 % ---
191 % save traditional information for full training set
192 % ---
% true only when size step m uses the complete training partition
% (if on a vector requires all elements nonzero, so both dimensions
% must match)
193 if size(simdata.partBinTrn{1,m}) == size(simdata.partBinTrn{1,end});
194
195 % out.max_ok_test = max(ok_test, 2);
196 out.mean_ok_test = mean(ok_test(:, ok_test(1,:) >=0), 2);
197 out.var_ok_test = var(ok_test(:, ok_test(1,:) >=0), 0, 2);
198 out.equal_test = median(equal_test);
199
200 out.mean_ok_train = mean(ok_train(:, ok_train(1,:) >=0), 2);
201 out.var_ok_train = var(ok_train(:, ok_train(1,:) >=0), 0, 2);
202
203 % ---
204 % TODO: DEBUG: this does not work correctly
205 % ---
206 out.mean_ok_notin_train = mean(ok_notin_train(:, ok_notin_train(1,:) >=0), 2);
207 out.var_ok_notin_train = var(ok_notin_train(:, ok_notin_train(1,:) >=0), 0, 2);
208
209 % ---
210 % get winning measure
211 % we use the weighted winning measure if possible
212 % ---
213 if max(ok_test(2,:)) > 0
214 [~, best] = max(ok_test(2,:));
215 else
216 [~, best] = max(ok_test(1,:));
217 end
218
219 diag.A = A;
220 diag.diag = dout;
221
222 diag.ok_test = ok_test;
223 diag.equal_test = equal_test;
224 diag.ok_train = ok_train;
225 diag.ok_notin_train = ok_notin_train;
226
227 % save some metric matrices
228 out.best_A = A{best};
229 out.best_diag = dout{best};
230 out.best_idx = best;
231
232 end
233 end
234
235 % save parameters
236 out.camirrev = camirversion();
237 out.fparams = fparams;
238 out.trainfun = trainfun;
239 out.trainparams = trainparams;
240 out.clip_ids = clips.id();
241 out.dataPartition = [];
% NOTE(review): size of a scalar struct is [1 1] — this looks like a
% placeholder rather than the similarity data itself; confirm intent
242 out.Y = size(simdata);
243 % ---
244 % NOTE: this takes A LOT OF DISC SPACE
245 % ---
246 % out.Ytrain = Ytrain{end};
247 % out.Ytest = Ytest{end};
248
249 % ---
250 % save the diagnostics data to disk
251 % ---
252 save(sprintf('runlog_%s.%s_results.mat',...
253 paramhash, paramhash_train),...
254 'out', 'diag');
255 end