diff core/magnatagatune/tests_evals/test_generic_display_results.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/core/magnatagatune/tests_evals/test_generic_display_results.m	Tue Feb 10 15:05:51 2015 +0000
@@ -0,0 +1,209 @@
+function [out, stats, features, individual] = test_generic_display_results(file, get_features, show)
+% [out, stats] = test_generic_display_results([file], get_features)
+% 
+% [out, stats, features, individual] = test_generic_display_results([file], get_features)
+% 
+% displays the finalresults mat file and enables 
+% further analysis and duiagnostics of the individual runs
+
+features = [];
+
+if nargin < 3
+    show = 1;
+end
+if nargin < 2
+    get_features = 1;
+end
+
+global comparison;
+global comparison_ids;
+        
+if nargin < 1 || isempty(file) || isnumeric(file)
+    u = dir();
+    u = {u.name};
+    [idx, strpos] = substrcellfind(u, '_finalresults.mat', 1);
+    
+    if numel(idx) < 1
+        error 'This directory contains no valid test data';
+    end
+    
+    if exist('file','var') && isnumeric(file)
+        file = u{idx(file)};
+    else
+        if numel(idx) > 1
+            file = u{idx(ask_dataset())};
+        else
+            file = u{idx(1)};
+        end
+    end
+end
+
+
+% ---
+% LOAD THE RESULT DATA
+% We have:
+% Y
+% out.fparams
+%     trainparams
+%     dataPartition
+%     mean_ok_test
+%     var_ok_test
+%     mean_ok_train
+% ---
+load(file);
+
+% compability
+if isfield(out, 'mlrparams')
+    for i = 1:numel(out)
+        out(i).trainparams = out(i).mlrparams;
+    end
+end
+
+
+% ---
+% % get statistics for feature parameters
+% Visualise the accuracy and variance
+% ---
+if isfield(out, 'inctrain') && show
+    for i = 1:numel(out)
+        
+       figure;
+       boxplot([out(i).inctrain.mean_ok_test], sqrt([out(i).inctrain.var_ok_test]), [out(i).inctrain.mean_ok_train]);
+       set(gca,'XTick',1:numel(out(i).inctrain.trainfrac), ...
+        'XTickLabel', out(i).inctrain.trainfrac* 100);
+
+       xlabel ('fraction of training data');
+       title (sprintf('increasing training size test, config %d',i));
+       legend('train', 'train weighted', 'test', 'test weighted');
+       
+    end
+end
+
+
+if numel([out.mean_ok_test]) > 1 && show
+    
+    % plot means  % plot std = sqrt(var) % plot training results
+    figure;
+    boxplot([out.mean_ok_test], sqrt([out.var_ok_test]), [out.mean_ok_train]);
+    title (sprintf('Performance for all configs'));
+end
+    
+    
+% --- 
+% write max. test success
+% ---
+    mean_ok_test = [out.mean_ok_test];
+    [val, idx] = max(mean_ok_test(1,:));
+if show
+    fprintf(' --- Maximal test set success: nr. %d, %3.2f percent. --- \n', idx, val * 100)
+end
+
+% ---
+% display parameter statistics
+% ---
+stats = test_generic_display_param_influence(out, show);
+
+
+if nargout < 3 
+    return; 
+end
+% ---
+% display statistics and get features
+%  for run with best test success
+% ---
+[resfile, featfile] = get_res_filename(out, idx);
+
+% ---
+% import features:
+% 1. reset databse
+% 2. import features
+% 3. assign to clip ids as in ranking
+% ---
+if get_features
+    type = MTTAudioFeatureDBgen.import_type(featfile);
+    db_name = MTTAudioFeatureDBgen.db_name(type);
+    eval(sprintf('global %s', db_name));
+    eval(sprintf('%s.reset();', db_name));
+    eval(sprintf('features = %s.import(featfile);', db_name));
+
+    if isfield(out,'clip_ids')
+        clips = MTTClip(out(1).clip_ids);
+        features = clips.features(type);
+    end
+end
+
+% ---
+% Display Metric Stats
+% tmp = test_mlr_display_metric_stats(individual.out, individual.diag, features);
+% ---
+
+if nargout < 4 
+    return; 
+end
+individual = load(resfile);
+for i = 1:numel(out)
+ 
+        [resfile, featfile] = get_res_filename(out, i);
+        
+        if get_features
+            % reset db and load testing features
+            eval(sprintf('global %s', db_name));
+            eval(sprintf('%s.reset();', db_name));
+            eval(sprintf('%s.import(featfile);', db_name));
+        end
+    
+        % load individual results
+        if i == 1;
+            
+            individual = load(resfile);
+        else
+            
+            individual(i) = load(resfile);
+        end
+end
+end
+
+function out = ask_dataset()
+% ---
+% displays the parameters of the datasets, 
+% and asks for the right one to display
+% ---
+clc;
+u = dir();
+u = {u.name};
+[idx, strpos] = substrcellfind(u, '_params.mat', 1);
+
+for i = 1:numel(idx)
+    file = u{idx(i)};
+    fprintf('------------ Dataset nr. %d --------------\n',i);
+    fprintf('Filename: %s\n',file);
+    type(file);
+end
+
+out = (input('Please choose the dataset number: '));
+end
+
+
+function [resfile, featfile] = get_res_filename(out, i)
+% get filename given test results and index
+
+    paramhash = hash(xml_format(out(i).fparams),'MD5');
+    
+    paramhash_mlr = hash(xml_format(out(i).trainparams),'MD5');
+
+    featfile = sprintf('runlog_%s_feat.mat', paramhash);
+
+    resfile = sprintf('runlog_%s.%s_results.mat',...
+            paramhash, paramhash_mlr);
+end
+
+
+function boxplot(mean, std, train);
+
+    bar([train; mean]', 1.5);
+    hold on;
+    errorbar(1:size(mean,2), mean(1,:), std(1,:),'.');
+%     plot(train,'rO');
+    colormap(spring);
+    axis([0 size(mean,2)+1 max(0, min(min([train mean] - 0.1))) max(max([train mean] + 0.1))]);
+end