Mercurial > hg > camir-aes2014
comparison core/magnatagatune/tests_evals/test_generic_display_param_influence.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function stats = test_generic_display_param_influence(results, show)
% returns the mean accuracy influence of each feature and training parameter
%
% stats = test_generic_display_param_influence(results, show)
%
% the influence is measured by comparing the mean
% achieved accuracy for all tries with each specific
% parameter being constant
%
% results: struct array of experiment results; each entry carries a
%          .fparams struct and a .trainparams (or legacy .mlrparams) struct
% show:    if true, plot the influence statistics (default: false)
%
% TODO: evaluate how the comparisons of all configuration
% tuples with just the specific analysed parameter
% changing differ from the above approach

% plotting is optional; default to no display when not requested
if nargin < 2
    show = false;
end

% get statistics for feature parameters
stats.fparams = gen_param_influence(results, 'fparams');

% get statistics for training parameters
if isfield(results, 'trainparams')

    stats.trainparams = gen_param_influence(results, 'trainparams');

% the following case is for backwards compatibility
elseif isfield(results, 'mlrparams')

    stats.trainparams = gen_param_influence(results, 'mlrparams');
else
    % neither field present: keep the output shape consistent so the
    % stats.trainparams accesses below (and in callers) cannot error
    stats.trainparams = [];
end

if show
    % display results

    if ~isempty(stats.fparams)
        figure;
        display_param_influence(stats.fparams);
    end

    if ~isempty(stats.trainparams)
        figure;
        display_param_influence(stats.trainparams);
    end
end

end
43 | |
44 % --- | |
45 % gen_param_influence | |
46 % --- | |
function stats = gen_param_influence(results, paramname)
% generates statistics given results and parameter type as string.
%
% paramname selects which parameter struct of each result entry
% ('fparams', 'trainparams', ...) is analysed. returns a struct with
% one field per parameter that actually varies across the results,
% or [] if none does.

stats = struct();

% names of the individual parameters in this parameter set
pnames = fieldnames(results(1).(paramname));

for k = 1:numel(pnames)

    % ---
    % collect the configurations of this parameter across all results
    % ---
    pstructs = [results.(paramname)];
    firstval = pstructs(1).(pnames{k});

    if ischar(firstval)
        % string-valued parameter: keep the values in a cell array
        vals = {pstructs.(pnames{k})};
    elseif iscell(firstval)
        % complex cell-valued parameter: serialise each cell to a string
        vals = cell(1, numel(pstructs));
        for n = 1:numel(pstructs)
            vals{n} = cell2str(pstructs(n).(pnames{k}));
        end
    else
        % numeric-valued parameter: concatenate into a plain array
        vals = [pstructs.(pnames{k})];
    end

    % compute the influence and save it under the original parameter name
    influence = param_influence(results, vals);

    if ~isempty(influence)
        stats.(pnames{k}) = influence;
    end
end

% signal "no parameter varied" with an empty result
if isempty(fieldnames(stats))
    stats = [];
end

end
90 | |
91 | |
92 % --- | |
93 % param_influence | |
94 % --- | |
function out = param_influence(results, allvals)
% give the influence (given results) for the parameter settings
% given in allvals.
%
% results: struct array with fields mean_ok_test and mean_ok_train
% allvals: value of the analysed parameter for each result entry,
%          either a numeric array or a cell array of strings
%
% numel(results) = numel(allvals)
%
% returns [] when the parameter takes fewer than two distinct values.

% ---
% get all different settings of this parameter.
% NOTE: this might also work out of the box for strings.
% not tested, below has to be changed to cell / matrix notations
% ---
entries = unique(allvals);

% just calculate for params with more than one option
if numel(entries) < 2 || ischar(entries)

    out = [];
    return;
end

% calculate statistics for each fixed parameter value
for j = 1:numel(entries)

    % find the results sharing this parameter value,
    % taking care of string parameters
    if ~(iscell(allvals) && ischar(allvals{1}))
        valid_idx = (allvals == entries(j));

        valid_ids = find(valid_idx);
    else
        valid_ids = strcellfind(allvals, entries{j}, 1);
    end

    % ---
    % get the relevant test statistics over the variations
    % of the further parameters
    % ---
    mean_ok_test(j) = collect_stats(results, valid_ids, 'mean_ok_test');

    % ---
    % get the training statistics over the variations
    % of the further parameters
    %
    % NOTE: this allowed for assessment of improvement by RBM selection;
    % a random idx instead of the best one was once tested via
    %   maidx = max(1, round(rand(1) * numel(valid_ids)));
    % ---
    mean_ok_train(j) = collect_stats(results, valid_ids, 'mean_ok_train');
end

% ---
% get the statistics over the different values
% this parameter can hold
%
% CAVE/TODO: the idx references are relative to valid_idx
% ---
[best, absolute.best_idx] = max([mean_ok_test.max]);
[worst, absolute.worst_idx] = min([mean_ok_test.max]);

% ---
% get differences:
difference.max = max([mean_ok_test.max]) - min([mean_ok_test.max]);

% format output
out.entries = entries;
out.mean_ok_test = mean_ok_test;
out.mean_ok_train = mean_ok_train;
out.difference = difference;
out.absolute = absolute;
end

% ---
% collect_stats: max / min / mean of one accuracy field over the
% results selected by valid_ids, with extreme indices mapped back
% to positions in the results array
% ---
function s = collect_stats(results, valid_ids, fieldname)

% gather the first row of the field from each selected result
vals = [];
for i = 1:numel(valid_ids)
    vals = [vals results(valid_ids(i)).(fieldname)(1,:)];
end

[ma, maidx] = max(vals);
[mi, miidx] = min(vals);
me = mean(vals);

s = struct('max', ma, ...
    'max_idx', valid_ids(maidx), ...
    'min', mi, ...
    'min_idx', valid_ids(miidx), ...
    'mean', me);
end
191 | |
192 | |
193 % --- | |
194 % display | |
195 % --- | |
function a = display_param_influence(stats)
% bar-plots the maximal accuracy influence of each parameter in stats,
% labelling each bar with the best-performing parameter value.
%
% stats: struct as produced by gen_param_influence (one field per
%        parameter), or [] / empty
%
% returns a handle to the plot axes, or [] when nothing was plotted.

% default output so the early return cannot leave 'a' unassigned
a = [];

if isempty(stats)
    return;
end

ptypes = fieldnames(stats);

dmax = [];
best_val = {};
for i = 1:numel(ptypes)

    % serialise the statistics
    dmax = [dmax stats.(ptypes{i}).difference.max];
    best_val = {best_val{:} stats.(ptypes{i}).entries( ...
        stats.(ptypes{i}).absolute.best_idx) };

    % take care of string args
    if isnumeric(best_val{i})
        lbl{i} = sprintf('%5.2f', best_val{i});
    else
        lbl{i} = best_val{i};
    end
end

% influence values are fractions; scale to percent for display
bar([dmax]' * 100);
colormap(1 - spring);
xlabel('effect on max. correctness for best + worst case of other parameters');
ylabel('correctness (0-100%)');
a = gca;
set(a, 'XTick', 1:numel(ptypes), ...
    'XTickLabel', ptypes);

% display best param results at the foot of each bar
for i = 1:numel(ptypes)
    text(i, 0, lbl{i}, 'color', 'k');
end

end