Mercurial > hg > camir-aes2014
comparison core/magnatagatune/tests_evals/rbm_subspace/write_mat_results_ISMIR13RBM_singletraining.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [out, stats] = write_mat_results_ISMIR13RBM_singletraining(dirin, fileout)
% WRITE_MAT_RESULTS_ISMIR13RBM_SINGLETRAINING combine and summarise test results
%
% Loads every file whose name matches '_finalresults.mat' (as found by
% substrcellfind) in the directories supplied, groups the loaded runs by the
% training parameter trainparams.C and, for every C value, collects the best
% configuration per test-set bin as selected by
% test_generic_display_param_influence.
%
% Inputs:
%   dirin   - cell array of directory names. A single directory given as a
%             char string is also accepted. Defaults to {'./'} when omitted
%             or empty.
%   fileout - optional file name for the saved summary. When omitted or
%             empty, the summary is saved under
%             [hash(strcat(dirin{:}),'md5') '_summary'] as before.
%
% Outputs:
%   out   - struct array with one entry per distinct C value (ascending, as
%           returned by unique), with fields mean_ok_test, var_ok_test,
%           mean_ok_train, var_ok_train, trainparams.C, ok_config,
%           ok_test, ok_train.
%   stats - the statistics returned by test_generic_display_param_influence
%           for the LAST C value processed.
%
% Errors:
%   if a directory contains no result file, or nothing could be loaded.

show = 1;

if nargin < 1 || isempty(dirin)
    dirin = {'./'};
elseif ischar(dirin)
    % a single directory passed as a string: wrap it so dirin{i} works
    dirin = {dirin};
end

% NOTE(review): these globals are not used in this function; the
% declarations are kept from the original - confirm whether they can go.
global comparison;
global comparison_ids;

% ---
% collect the results of all directories
% (listed/loaded by path, so the working directory is never changed and an
% error cannot leave the caller in another folder)
% ---
newout = [];
for diri = 1:numel(dirin)
    listing = dir(dirin{diri});
    names = {listing.name};
    idx = substrcellfind(names, '_finalresults.mat', 1);

    if numel(idx) < 1
        error('This directory contains no valid test data');
    end

    for filei = 1:numel(idx)
        cprint(1, 'loading result file %i of %i', filei, numel(idx));
        data = load(fullfile(dirin{diri}, names{idx(filei)}));
        newout = sappend(newout, data.out);
    end
end

if isempty(newout)
    error('No test results could be loaded from the supplied directories');
end

% ---
% filter according to training parameter C
%
% NOTE: if we don't filter by C, we get strong overfitting with training
% success > 96 % and test set performance around 65 %
% ---
cs = arrayfun(@(res) res.trainparams.C, newout);
cs = cs(:);
cvals = unique(cs);

for ci = 1:numel(cvals)
    filteredout = newout(cs == cvals(ci));

    % ---
    % get parameter statistics
    % ---
    stats = test_generic_display_param_influence(filteredout, show);

    % ---
    % get maximal values for each test set bin:
    % trainparams.dataset contains sets which each hold only one bin of the
    % ismir testsets; max_idx indexes into filteredout
    % ---
    max_idx = [stats.trainparams.dataset.mean_ok_train.max_idx];

    % number of result components per run (2 in the original data)
    nres = numel(filteredout(1).mean_ok_test);
    ok_test = zeros(nres, numel(max_idx));
    ok_train = zeros(nres, numel(max_idx));
    ok_config = [];

    % cycle over all test sets and save the best result
    for i = 1:numel(max_idx)
        best = filteredout(max_idx(i));
        ok_test(:, i) = best.mean_ok_test;
        ok_train(:, i) = best.mean_ok_train;
        ok_config = sappend(ok_config, struct('trainparams', best.trainparams, ...
            'fparams', best.fparams));
    end

    % save the stuff
    out(ci).mean_ok_test = mean(ok_test, 2);
    out(ci).var_ok_test = var(ok_test, 0, 2);
    out(ci).mean_ok_train = mean(ok_train, 2);
    out(ci).var_ok_train = var(ok_train, 0, 2);
    out(ci).trainparams.C = cvals(ci);
    out(ci).ok_config = ok_config;
    out(ci).ok_test = ok_test;
    out(ci).ok_train = ok_train;
end

% ---
% show results for different C (only meaningful with more than one C value;
% the original test on numel([out.mean_ok_test]) was always true)
% ---
if show && numel(out) > 1
    % plot means, std = sqrt(var), and training results
    figure;
    boxplot([out.mean_ok_test], sqrt([out.var_ok_test]), [out.mean_ok_train]);
    title(sprintf('Performance for all configs'));
end

% ---
% write max. test success (first result component)
% ---
mean_ok_test = [out.mean_ok_test];
[val, best_ci] = max(mean_ok_test(1, :));
if show
    fprintf(' --- Maximal test set success: nr. %d, %3.2f percent. --- \n', best_ci, val * 100)
end

% save, honouring fileout when it is given
if nargin < 2 || isempty(fileout)
    fileout = [hash(strcat(dirin{:}), 'md5') '_summary'];
end
save(fileout, 'out');

end

function boxplot(test_mean, test_std, train)
% BOXPLOT bar chart of training vs. test success with test error bars
%
% File-local plotting helper; draws into the current axes.
%
% Inputs (one column per configuration, rows are result components):
%   test_mean - matrix of mean test success
%   test_std  - standard deviations matching test_mean; only row 1 is drawn
%               as error bars
%   train     - matrix of mean training success, same size as test_mean
%
% The hold state of the current axes is restored before returning (the
% original left the axes in "hold on").
%
% NOTE(review): the error bars are placed at x = 1..ncols (group centres),
% not over the individual test bars; with a single column the bar call
% draws separate bars at x = 1..rows that fall outside the axis limits.

was_held = ishold;

bar([train; test_mean]', 1.5);
hold on;
errorbar(1:size(test_mean, 2), test_mean(1, :), test_std(1, :), '.');
% plot(train,'rO');
colormap(spring);

% zoom to the data range with a 0.1 margin, never starting below zero
all_vals = [train test_mean];
axis([0, size(test_mean, 2) + 1, max(0, min(all_vals(:)) - 0.1), max(all_vals(:)) + 0.1]);

if ~was_held
    hold off;
end
end