function stats = test_generic_display_param_influence(results, show)
% returns the mean accuracy influence of each feature and training parameter
%
% The influence of a parameter is measured by comparing the mean
% accuracy achieved over all runs in which that parameter is held
% constant at each of its values.
%
% TODO: evaluate how comparing all configuration tuples that differ
% only in the analysed parameter would change the results relative
% to the approach above
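%
% Example (a sketch; the parameter names and values below are
% hypothetical, only the field names fparams, trainparams,
% mean_ok_test and mean_ok_train are assumed by this function):
%
%   results(1).fparams.nceps = 10;   results(1).trainparams.C = 1;
%   results(1).mean_ok_test  = 0.71; results(1).mean_ok_train = 0.80;
%   results(2).fparams.nceps = 20;   results(2).trainparams.C = 1;
%   results(2).mean_ok_test  = 0.74; results(2).mean_ok_train = 0.83;
%
%   stats = test_generic_display_param_influence(results, false);
%   stats.fparams.nceps.difference.max   % accuracy spread caused by nceps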

% get statistics for feature parameters
stats.fparams = gen_param_influence(results, 'fparams');

% get statistics for training parameters
if isfield(results, 'trainparams')
    
    stats.trainparams = gen_param_influence(results, 'trainparams');
    
    % the following case is for backwards compatibility
elseif isfield(results, 'mlrparams')
    
    stats.trainparams = gen_param_influence(results, 'mlrparams');
end

if show
    % display results

    if ~isempty(stats.fparams)
        figure;
    %    subplot(2,1,1);
        display_param_influence(stats.fparams);
    end

    if isfield(stats, 'trainparams') && ~isempty(stats.trainparams)
        figure;
    %    subplot(2,1,2);
        display_param_influence(stats.trainparams);
    end
end

end

% --- 
% gen_param_influence
% ---
function stats = gen_param_influence(results, paramname)
% generates influence statistics given the results array and the name of
% a parameter set as a string (e.g. 'fparams').
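%
% Returns a struct with one field per parameter that takes more than one
% value across results, each holding the output of param_influence;
% returns [] if no parameter varies.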

% get individual fields of this parameter set
ptypes = fieldnames(results(1).(paramname));

for i = 1:numel(ptypes)
    % ---
    % get all individual configurations of this parameter.
    % ---
    allvals = [results.(paramname)];
    
    % take care of string args
    if ~ischar(allvals(1).(ptypes{i}))
        if ~iscell(allvals(1).(ptypes{i}))
            
            % numeric parameter values: concatenate into an array
            allvals = [allvals.(ptypes{i})];
        else
            % cell-valued parameters: convert each to a string for comparison
            for j=1:numel(allvals)
                tmpvals{j} = cell2str(allvals(j).(ptypes{i}));
            end
            allvals = tmpvals;
        end
    else
        % char parameters: collect into a cell array of strings
        allvals = {allvals.(ptypes{i})};
    end
    
    % save using original parameter name
    tmp = param_influence(results, allvals);
    
    if ~isempty(tmp)
        stats.(ptypes{i}) = tmp;
    end   
end

if ~exist('stats','var')
    stats = [];
end

end


% --- 
% param_influence
% ---
function out = param_influence(results, allvals)
% give the influence (given results) for the parameter settings
% given in allvals.
%
% numel(results) = numel(allvals)
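%
% The returned struct (a sketch, describing the code below) contains:
%   out.entries               unique values taken by the parameter
%   out.mean_ok_test(j)       max / max_idx / min / min_idx / mean of the
%                             test accuracy with the parameter fixed to
%                             entries(j); the idx fields index into results
%   out.mean_ok_train(j)      the same statistics for the training accuracy
%   out.difference.max        spread of the best test accuracy across values
%   out.absolute.best_idx     index (into entries) of the best value
%
% out is empty if the parameter takes fewer than two distinct values.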

    % ---
    % get all different settings of this parameter.
    % NOTE: this might also work out of the box for strings;
    % not tested, the code below would have to be changed to cell / matrix notation
    % ---
    entries = unique(allvals);
    
    % just calculate for params with more than one option
    if numel(entries) < 2 || ischar(entries)
        
        out = [];
        return;
    end
    
    % calculate statistics for each fixed value of this parameter
    for j = 1:numel(entries)
        
        % care for string parameters
        if ~(iscell(allvals) && ischar(allvals{1}))
            valid_idx = (allvals == entries(j));
            
            valid_ids = find(valid_idx);
        else
            valid_ids = strcellfind(allvals, entries{j}, 1);
        end
        
        % ---
        % get the relevant statistics over the variations
        % of the further parameters
        % ---
        mean_ok_testval = [];
        for i = 1:numel(valid_ids)
             mean_ok_testval = [mean_ok_testval results(valid_ids(i)).mean_ok_test(1,:)];
        end

        [ma,maidx] = max(mean_ok_testval);
        [mi,miidx] = min(mean_ok_testval);
        [me] = mean(mean_ok_testval);
        mean_ok_test(j) = struct('max',ma , ...
                            'max_idx',valid_ids(maidx) , ...
                            'min',mi , ...
                            'min_idx',valid_ids(miidx) , ...
                            'mean',me);
                        
        % ---
        % get the training statistics over the variations
        % of the further parameters
        % ---
        mean_ok_trainval = [];
        for i = 1:numel(valid_ids)
             mean_ok_trainval = [mean_ok_trainval results(valid_ids(i)).mean_ok_train(1,:)];
        end

        [ma,maidx] = max(mean_ok_trainval);
        % ---
        % NOTE: this allowed for assessment of the improvement gained by RBM selection
%         warning testing random idx instead of best one 
%         maidx = max(1, round(rand(1)* numel(valid_ids)));
%       % ---

        [mi,miidx] = min(mean_ok_trainval);
        [me] = mean(mean_ok_trainval);
        mean_ok_train(j) = struct('max',ma , ...
                            'max_idx',valid_ids(maidx) , ...
                            'min',mi , ...
                            'min_idx',valid_ids(miidx) , ...
                            'mean',me);
    end
    
    % ---
    % get the statistics over the different values 
    % this parameter can hold
    %
    % CAVE/TODO: the idx references are relative to valid_idx
    % ---
    [best, absolute.best_idx] = max([mean_ok_test.max]);
    [worst, absolute.worst_idx] = min([mean_ok_test.max]);
    
    % ---
    % get differences:
    difference.max = max([mean_ok_test.max]) - min([mean_ok_test.max]);
    
    % format output
    out.entries = entries;
    out.mean_ok_test = mean_ok_test;
    out.mean_ok_train = mean_ok_train;
    out.difference = difference;
    out.absolute = absolute;
end
    

% --- 
% display
% ---
function a = display_param_influence(stats)
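% plots the influence statistics as a bar chart: for each parameter, the
% maximal difference in mean test accuracy (in percent) between its best
% and worst value, with the best-performing value printed as a text label
% at the base of the corresponding bar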

a = [];
if isempty(stats)
    return;
end

ptypes = fieldnames(stats);

dmean = [];
dmax = [];
best_val = {};
for i = 1:numel(ptypes)
    
    % serialise the statistics
%    dmean = [dmean stats.(ptypes{i}).difference.mean];
    dmax = [dmax stats.(ptypes{i}).difference.max];
    best_val = {best_val{:} stats.(ptypes{i}).entries( ...
        stats.(ptypes{i}).absolute.best_idx) };
    
    % take care of string args
    if isnumeric(best_val{i})
        lbl{i} = sprintf('%5.2f' ,best_val{i});
    else
        lbl{i} = best_val{i};
    end
end


bar([dmax]'* 100);
colormap(1-spring);
% legend({'maximal effect on mean correctness'})
xlabel('effect on max. correctness for best + worst case of other parameters');
ylabel('correctness (0-100%)');
a = gca;
set(a,'XTick', 1:numel(ptypes), ...
    'XTickLabel', ptypes);

% display best param results
for i = 1:numel(ptypes)
    text(i,0,lbl{i}, 'color','k');
end

end