comparison core/magnatagatune/tests_evals/test_generic_display_param_influence.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
function stats = test_generic_display_param_influence(results, show)
% returns the mean accuracy influence of each feature and training parameter
%
% stats = test_generic_display_param_influence(results, show)
%
% the influence is measured by comparing the mean
% achieved accuracy for all tries with each specific
% parameter being constant
%
% results: struct array of evaluation runs. Each element carries an
%          'fparams' struct and optionally a 'trainparams' struct (or,
%          for older result files, an 'mlrparams' struct), plus the
%          'mean_ok_test' / 'mean_ok_train' fields read by param_influence.
% show:    (optional, default false) if true, draw one bar chart per
%          parameter group.
%
% stats.fparams, stats.trainparams: per-parameter statistics as returned
%          by gen_param_influence, or [] if nothing could be compared.
%
% TODO: evaluate how the comparisons of all configuration
% tuples with just the specific analysed parameter
% changing differ from the above approach

if nargin < 2
    show = false;
end

% get statistics for feature parameters
stats.fparams = gen_param_influence(results, 'fparams');

% get statistics for training parameters
if isfield(results, 'trainparams')
    stats.trainparams = gen_param_influence(results, 'trainparams');

elseif isfield(results, 'mlrparams')
    % the following case is for backwards compatibility
    stats.trainparams = gen_param_influence(results, 'mlrparams');

else
    % no training parameters recorded: still define the field so that
    % the display code below (and callers) can rely on its existence
    stats.trainparams = [];
end

if show
    % display results, one figure per parameter group
    if ~isempty(stats.fparams)
        figure;
        % subplot(2,1,1);
        display_param_influence(stats.fparams);
    end

    if ~isempty(stats.trainparams)
        figure;
        % subplot(2,1,2);
        display_param_influence(stats.trainparams);
    end
end

end
43
% ---
% gen_param_influence
% ---
function stats = gen_param_influence(results, paramname)
% generates statistics given results and parameter type as string.
%
% results:   struct array; results(k).(paramname) is a struct whose
%            fields are the individual parameters of run k.
% paramname: name of the parameter group, e.g. 'fparams' or 'trainparams'
%
% stats:     struct with one field per parameter for which param_influence
%            returned a non-empty result, or [] if there is none
%            (including the case of empty results).

if isempty(results)
    stats = [];
    return;
end

% get individual fields of this parameter set
ptypes = fieldnames(results(1).(paramname));

% parameter structs of all runs; the same for every analysed parameter,
% so it is computed once outside the loop
paramsets = [results.(paramname)];

for i = 1:numel(ptypes)
    pname = ptypes{i};

    % ---
    % get all individual configurations of this parameter.
    % ---
    firstval = paramsets(1).(pname);
    if ischar(firstval)
        % parameter array of strings -> cell array of strings
        allvals = {paramsets.(pname)};
    elseif iscell(firstval)
        % complex parameter array of cells -> one string per run
        allvals = cell(1, numel(paramsets));
        for j = 1:numel(paramsets)
            allvals{j} = cell2str(paramsets(j).(pname));
        end
    else
        % parameter array of numbers
        allvals = [paramsets.(pname)];
    end

    % save using original parameter name
    tmp = param_influence(results, allvals);

    if ~isempty(tmp)
        stats.(pname) = tmp;
    end
end

if ~exist('stats', 'var')
    stats = [];
end

end
90
91
% ---
% param_influence
% ---
function out = param_influence(results, allvals)
% give the influence (given results) for the parameter settings
% given in allvals.
%
% results: struct array of runs; each run provides mean_ok_test and
%          mean_ok_train, of which only the first row (1,:) is used
% allvals: value of the analysed parameter for each run, either a numeric
%          array or a cell array of strings, with
%          numel(results) = numel(allvals)
%
% out: [] if the parameter takes fewer than two distinct values, else
%   out.entries        distinct parameter values (as returned by unique)
%   out.mean_ok_test   struct array, one element per entry, with fields
%                      max, max_idx, min, min_idx, mean pooled over all
%                      runs using that entry; max_idx / min_idx index
%                      into results
%   out.mean_ok_train  same for the training accuracy
%   out.difference.max spread of mean_ok_test.max over the entries
%   out.absolute       best_idx / worst_idx: indices into out.entries of
%                      the highest / lowest mean_ok_test.max

% ---
% get all different settings of this parameter.
% handles numeric arrays as well as cell arrays of strings
% ---
entries = unique(allvals);

% just calculate for params with more than one option
if numel(entries) < 2 || ischar(entries)
    out = [];
    return;
end

% string parameters need string matching instead of ==
is_strparam = iscell(allvals) && ischar(allvals{1});

% calculate statistics for this fixed parameter
for j = 1:numel(entries)

    if is_strparam
        valid_ids = strcellfind(allvals, entries{j}, 1);
    else
        valid_ids = find(allvals == entries(j));
    end

    % ---
    % get the relevant statistics over the variations
    % of the further parameters
    % ---
    mean_ok_test(j) = summarise_runs(results, valid_ids, 'mean_ok_test');

    % ---
    % get the training statistics over the variations
    % of the further parameters
    % ---
    % NOTE: this allowed for assessment of improvement by RBM selection
    % warning testing random idx instead of best one
    % maidx = max(1, round(rand(1)* numel(valid_ids)));
    mean_ok_train(j) = summarise_runs(results, valid_ids, 'mean_ok_train');
end

% ---
% get the statistics over the different values
% this parameter can hold; the indices refer to entries
% ---
[best, absolute.best_idx] = max([mean_ok_test.max]);
[worst, absolute.worst_idx] = min([mean_ok_test.max]);

% ---
% get differences:
difference.max = best - worst;

% format output
out.entries = entries;
out.mean_ok_test = mean_ok_test;
out.mean_ok_train = mean_ok_train;
out.difference = difference;
out.absolute = absolute;
end

% ---
% summarise_runs
% ---
function s = summarise_runs(results, valid_ids, field)
% max / min / mean of results(valid_ids).(field)(1,:), pooled over runs.
% max_idx / min_idx are indices into results of the run holding the
% extreme value, also when each run contributes several values.

vals = [];
owner = [];
for k = 1:numel(valid_ids)
    row = results(valid_ids(k)).(field)(1,:);
    vals = [vals row];
    % remember which run each pooled value came from, so that an index
    % into vals can be mapped back to an index into results
    owner = [owner repmat(valid_ids(k), 1, numel(row))];
end

[ma, maidx] = max(vals);
[mi, miidx] = min(vals);
s = struct('max', ma, ...
           'max_idx', owner(maidx), ...
           'min', mi, ...
           'min_idx', owner(miidx), ...
           'mean', mean(vals));
end
191
192
% ---
% display
% ---
function a = display_param_influence(stats)
% draws a bar chart with one bar per parameter in stats, showing
% difference.max (scaled to percent), and writes under each bar the
% parameter value that achieved the best test result.
%
% stats: struct as returned by gen_param_influence (one field per
%        parameter, each with .entries, .difference.max, .absolute.best_idx)
% a:     handle of the current axes after drawing, or [] if stats is empty

a = [];
if isempty(stats)
    return;
end

ptypes = fieldnames(stats);
nparams = numel(ptypes);

dmax = zeros(1, nparams);
lbl = cell(1, nparams);
for i = 1:nparams
    pstats = stats.(ptypes{i});

    % serialise the statistics
    dmax(i) = pstats.difference.max;

    % value of this parameter with the best test accuracy; cell entries
    % are unwrapped so that the label is the string itself
    if iscell(pstats.entries)
        best_val = pstats.entries{pstats.absolute.best_idx};
    else
        best_val = pstats.entries(pstats.absolute.best_idx);
    end

    % take care of string args
    if isnumeric(best_val) || islogical(best_val)
        lbl{i} = sprintf('%5.2f', best_val);
    else
        lbl{i} = best_val;
    end
end

bar(dmax' * 100);
colormap(1-spring);
% legend({'maximal effect on mean correctness'})
xlabel('effect on max. correctness for best + worst case of other parameters');
ylabel('correctness (0-100%)');
a = gca;
set(a, 'XTick', 1:nparams, ...
       'XTickLabel', ptypes);

% display best param results
for i = 1:nparams
    text(i, 0, lbl{i}, 'color', 'k');
end

end