diff core/magnatagatune/tests_evals/test_generic_display_param_influence.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/core/magnatagatune/tests_evals/test_generic_display_param_influence.m	Tue Feb 10 15:05:51 2015 +0000
@@ -0,0 +1,238 @@
+function stats = test_generic_display_param_influence(results, show)
+% stats = test_generic_display_param_influence(results, show)
+%
+% returns the mean accuracy influence of each feature and training parameter 
+%
+% the influence is measured by comparing the mean 
+% achieved accuracy for all tries with each specific 
+% parameter being constant
+%
+% results: struct array of experiment results; expected fields are
+%          fparams, trainparams (or legacy mlrparams), mean_ok_test
+%          and mean_ok_train
+% show:    if true, plot the influence statistics in new figures
+% 
+% TODO: evaluate how the comparisons of all configuration
+% tuples with just the specific analysed parameter 
+% changing differ from the above approach
+
+% get statistics for feature parameters
+stats.fparams = gen_param_influence(results, 'fparams');
+
+% get statistics for training parameters
+if isfield(results, 'trainparams')
+    
+    stats.trainparams = gen_param_influence(results, 'trainparams');
+    
+    % the following case is for backwards compatibility
+elseif isfield(results, 'mlrparams')
+    
+    stats.trainparams = gen_param_influence(results, 'mlrparams');
+else
+    % no training-parameter field present: default to empty so the
+    % display code below does not fail on a non-existent field
+    stats.trainparams = [];
+end
+
+if show
+    % display results
+
+    if ~isempty(stats.fparams)
+        figure;
+    %    subplot(2,1,1);
+        display_param_influence(stats.fparams);
+    end
+
+    if ~isempty(stats.trainparams)
+        figure;
+    %    subplot(2,1,2);
+        display_param_influence(stats.trainparams);
+    end
+end
+
+end
+
+% --- 
+% gen_param_influence
+% ---
+function stats = gen_param_influence(results, paramname)
+% generates statistics given results and parameter type as string. 
+%
+% results:   struct array of experiment results
+% paramname: name of the parameter-set field to analyse
+%            (e.g. 'fparams' or 'trainparams')
+% stats:     struct with one field per parameter that actually varies
+%            across the results, or [] if no parameter varies
+
+% get individual fields of this parameter set
+ptypes = fieldnames(results(1).(paramname));
+
+for i = 1:numel(ptypes)
+    % ---
+    % get all individual configurations of this parameter.
+    % ---
+    allvals = [results.(paramname)];
+    
+    % normalise the per-result values into one comparable array
+    if ~ischar(allvals(1).(ptypes{i}))
+        if ~iscell(allvals(1).(ptypes{i}))
+            
+            % numeric parameter: collect into a plain array
+            allvals = [allvals.(ptypes{i})];
+        else
+            % complex parameter array of cells: serialise each cell
+            % value into a comparable string
+            for j=1:numel(allvals)
+                tmpvals{j} = cell2str(allvals(j).(ptypes{i}));
+            end
+            allvals = tmpvals;
+        end
+    else
+        % char parameter: collect into a cell array of strings
+        allvals = {allvals.(ptypes{i})};
+    end
+    
+    % save using original parameter name
+    tmp = param_influence(results, allvals);
+    
+    % param_influence returns [] for parameters with < 2 settings
+    if ~isempty(tmp)
+        stats.(ptypes{i}) = tmp;
+    end   
+end
+
+% no parameter had more than one setting: signal "nothing to show"
+if ~exist('stats','var')
+    stats = [];
+end
+
+end
+
+
+% --- 
+% param_influence
+% ---
+function out = param_influence(results, allvals)
+% give the influence (given results) for the parameter settings
+% given in allvals.
+%
+% numel(results) = numel(allvals)
+%
+% out: struct with per-setting max/min/mean test and train accuracy
+%      statistics, or [] if the parameter has fewer than two
+%      distinct settings
+
+    % ---
+    % get all different settings of this parameter.
+    % NOTE: this might also work out of the box for strings.
+    % not tested, below has to be changed to cell / matrix notations
+    % ---
+    entries = unique(allvals);
+    
+    % just calculate for params with more than one option
+    % (ischar catches the case where unique collapsed to one string)
+    if numel(entries) < 2 || ischar(entries)
+        
+        out = [];
+        return;
+    end
+    
+    % calculate statistics for this fixed parameter
+    for j = 1:numel(entries)
+        
+        % care for string parameters
+        if ~(iscell(allvals) && ischar(allvals{1}))
+            valid_idx = (allvals == entries(j));
+            
+            % mean_ok_test
+            valid_ids = find(valid_idx);
+        else
+            % string values: look up matching positions by string compare
+            valid_ids = strcellfind(allvals, entries{j}, 1);
+        end
+        
+        % ---
+        % get the relevant statistics over the variations
+        % of the further parameters
+        % ---
+        mean_ok_testval = [];
+        for i = 1:numel(valid_ids)
+             mean_ok_testval = [mean_ok_testval results(valid_ids(i)).mean_ok_test(1,:)];
+        end
+
+        [ma,maidx] = max(mean_ok_testval);
+        [mi,miidx] = min(mean_ok_testval);
+        [me] = mean(mean_ok_testval);
+        mean_ok_test(j) = struct('max',ma , ...
+                            'max_idx',valid_ids(maidx) , ...
+                            'min',mi , ...
+                            'min_idx',valid_ids(miidx) , ...
+                            'mean',me);
+                        
+        % ---
+        % get the training statistics over the variations
+        % of the further parameters
+        % ---
+        mean_ok_trainval = [];
+        for i = 1:numel(valid_ids)
+             mean_ok_trainval = [mean_ok_trainval results(valid_ids(i)).mean_ok_train(1,:)];
+        end
+
+        [ma,maidx] = max(mean_ok_trainval);
+        % ---
+        % NOTE: this allowed for assessment of improvement by RBM selection
+%         warning testing random idx instead of best one 
+%         maidx = max(1, round(rand(1)* numel(valid_ids)));
+%       % ---
+
+        [mi,miidx] = min(mean_ok_trainval);
+        [me] = mean(mean_ok_trainval);
+        mean_ok_train(j) = struct('max',ma , ...
+                            'max_idx',valid_ids(maidx) , ...
+                            'min',mi , ...
+                            'min_idx',valid_ids(miidx) , ...
+                            'mean',me);
+    end
+    
+    % ---
+    % get the statistics over the different values 
+    % this parameter can hold
+    %
+    % CAVE/TODO: the idx references are relative to valid_idx
+    % ---
+    [best, absolute.best_idx] = max([mean_ok_test.max]);
+    [worst, absolute.worst_idx] = min([mean_ok_test.max]);
+    
+    % ---
+    % get differences: spread of best-case test accuracy across settings
+    difference.max = max([mean_ok_test.max]) - min([mean_ok_test.max]);
+    
+    % format output
+    out.entries = entries;
+    out.mean_ok_test = mean_ok_test;
+    out.mean_ok_train = mean_ok_train;
+    out.difference = difference;
+    out.absolute = absolute;
+end
+    
+
+% --- 
+% display
+% ---
+function a = display_param_influence(stats)
+% plots the per-parameter accuracy influence as a bar chart and
+% labels each bar with the best-performing parameter value.
+%
+% stats: struct as returned by gen_param_influence, or []
+% a:     handle of the current axes used for plotting
+%
+% NOTE(review): the early return below leaves 'a' unassigned; this is
+% fine for callers that discard the output, but would error if an
+% output is requested with empty stats — confirm intended usage
+
+if isempty(stats)
+    return;
+end
+
+ptypes = fieldnames(stats);
+
+dmean = [];
+dmax = [];
+best_val = {};
+for i = 1:numel(ptypes)
+    
+    % serialise the statistics
+%    dmean = [dmean stats.(ptypes{i}).difference.mean];
+     dmax = [dmax stats.(ptypes{i}).difference.max];
+    best_val = {best_val{:} stats.(ptypes{i}).entries( ...
+        stats.(ptypes{i}).absolute.best_idx) };
+    
+    % take care of string args: numeric values are formatted,
+    % strings are used as labels directly
+    if isnumeric(best_val{i})
+        lbl{i} = sprintf('%5.2f' ,best_val{i});
+    else
+        lbl{i} = best_val{i};
+    end
+end
+
+
+% scale accuracy differences to percent for display
+bar([dmax]'* 100);
+colormap(1-spring);
+% legend({'maximal effect on mean correctness'})
+xlabel('effect on max. correctness for best + worst case of other parameters');
+ylabel('correctness (0-100%)');
+a = gca;
+set(a,'XTick', 1:numel(ptypes), ...
+    'XTickLabel', ptypes);
+
+% display best param results
+for i = 1:numel(ptypes)
+    text(i,0,lbl{i}, 'color','k');
+end
+
+end