% ---
% The DistMeasureMahal class is a wrapper for a (generalised) Mahalanobis
% distance between clips (smaller values mean more similar clips), and is
% compatible with the DistMeasure class
% ---
classdef DistMeasureMahal <  handle
    % DistMeasureMahal - (generalised) Mahalanobis distance between clips
    %
    % Stores one feature column per clip and a square weight matrix W.
    % The distance between two clips is the quadratic form
    %   delta' * W * delta
    % where delta is produced by an optional user-supplied delta function.
    % Without a delta function, the external helper sqdist is used
    % (presumably the plain feature difference -- confirm against sqdist).
    
    properties (SetAccess = private)
        
        % square weight matrix W (a vector input is expanded via diag)
        mahalW;
        
        % feature matrix, one column per indexed clip (same order as ids)
        featX;
        
        % clip ids, in the column order of featX
        ids;
        % reverse index: ret_ids(id) is the column of clip id (0 = unknown)
        ret_ids;
        
        % optional delta function handle, called as
        % deltafun(xa, xb, deltafun_params{:}), and its extra arguments
        deltafun;
        deltafun_params;
    end

    methods
        
    % ---
    % constructor
    %
    % clips           - array of clips (anything with an .id field)
    % mahalW          - square weight matrix, or a vector of diagonal weights
    % featX           - features, one column per clip
    % deltafun        - (optional) handle returning the column vector delta
    %                   for two feature columns; [] means standard distance
    % deltafun_params - (optional) cell array of extra deltafun arguments,
    %                   defaults to {}
    % ---
    function m = DistMeasureMahal(clips, mahalW, featX, deltafun, deltafun_params)

        if nargin < 4
            deltafun = [];
        end
        if nargin < 5
            deltafun_params = {};
        end

        % every clip needs exactly one feature column
        if size(featX, 2) ~= numel(clips)
            error 'wrong input format'
        end

        % without a delta function the features are compared directly,
        % so their dimension must match the weight matrix
        if isempty(deltafun) && size(featX, 1) ~= length(mahalW)
            error 'wrong input format'
        end

        % index: column k of featX belongs to clip ids(k)
        m.ids = [clips.id];

        % reverse index (the sparse vector grows to the largest id)
        m.ret_ids = sparse(numel(m.ids), 1);
        m.ret_ids(m.ids) = 1:numel(m.ids);

        % ---
        % save mahal matrix; a non-square input is taken as the diagonal
        % ---
        if size(mahalW, 1) ~= size(mahalW, 2)
            m.mahalW = diag(mahalW);
        else
            m.mahalW = mahalW;
        end
        
        % MATLAB copies the feature matrix lazily (copy-on-write)
        m.featX = featX;
        
        % ---
        % special deltas
        % ---
        m.deltafun = deltafun;
        m.deltafun_params = deltafun_params;
    end
    

    % ---
    % this compatibility function returns the Mahalanobis distances
    % between the clips at positions idxa (rows) and idxb (columns).
    % Omitted index vectors default to all indexed clips.
    % ---
    function out = mat(m, idxa, idxb)
    
        if nargin < 2
            idxa = 1:numel(m.ids);
        end
        if nargin < 3
            idxb = 1:numel(m.ids);
        end
        
        if isempty(m.deltafun)
            % standard Mahalanobis distance is much faster to calculate
            out = sqdist(m.featX(:, idxa), m.featX(:, idxb), m.mahalW);
        else
            % custom delta: evaluate the quadratic form pair by pair
            W = m.mahalW;
            X = m.featX;
            params = m.deltafun_params;
            out = zeros(numel(idxa), numel(idxb));
            for i = 1:numel(idxa)
                xa = X(:, idxa(i));
                for j = 1:numel(idxb)
                    delta = m.deltafun(xa, X(:, idxb(j)), params{:});
                    out(i, j) = delta' * W * delta;
                end
            end
        end
    end
  
    % ---
    % returns the distance for the two input clips; for clip arrays a
    % numel(clipa) x numel(clipb) matrix is returned
    % ---
    function out = distance(m, clipa, clipb)
        posa = m.get_clip_pos(clipa);
        posb = m.get_clip_pos(clipb);
        
        out = m.mat(posa, posb);
    end
    
    % ---
    % returns a list of n (default = 10) clips most 
    % similar to the input
    % ---
    function [clips, dist] = get_nearest(m, clip, n)
    % [clips, dist] = get_nearest(m, clip, n)
    % 
    % returns the n (default = 10) clips nearest to the input clip in
    % ascending order of distance, plus those distances. n = 0 returns
    % all clips. The query clip itself is among the candidates. Clips at
    % Inf or NaN distance are dropped.

        % default number of results
        if nargin < 3
            n = 10;
        end
        
        % return all clips in case n = 0
        if n == 0
            n = numel(m.ids);
        end

        % get clip position
        pos = m.get_clip_pos(clip);

        % sort according to distance
        [sc, idx] = sort(m.mat(pos, 1:numel(m.ids)), 'ascend');

        % we only output relevant data (drop Inf / NaN distances)
        valid = sc < inf;
        sc = sc(valid);
        idx = idx(valid);

        if isempty(idx)
            clips = [];
            dist = [];
        else
            k = min(n, numel(idx));
            % create clips from best ids
            clips = MTTClip(m.ids(idx(1:k)));
            % reuse the sorted distances instead of recomputing them
            dist = sc(1:k);
        end
    end
    
    

    function [clips, dist] = present_nearest(m, clip, n)
    % plays and shows the n (default = 3) best hits for a given clip:
    % visualises the query clip's basicsm features, then for each hit
    % prints rank and distance, visualises its features, plays it via
    % play(20) (presumably a 20 s excerpt -- confirm), waits for the
    % returned duration and closes the hit's figure

        % default number of results
        if nargin < 3
            n = 3;
        end

        % get best list
        [clips, dist] = get_nearest(m, clip, n);

        clip.audio_features_basicsm.visualise();
        for i = 1:numel(clips)
            fprintf('\n\n\n- Rank %d, distance: %1.4f \n\n',i, dist(i));

            clips(i).audio_features_basicsm.visualise();
            h = gcf();
            t = clips(i).play(20);
            pause(t);
            close(h);
        end
    end

    function a = visualise(m)
    % plots the full pairwise distance matrix as an image in a new
    % figure, labels both axes with the clip ids, returns the axes handle

        figure;

        % plot data
        imagesc(m.mat());

        a = gca;
        ticks = 1:numel(m.ids);
        set(a, 'YTick', ticks, 'YTickLabel', m.ids);
        set(a, 'XTick', ticks, 'XTickLabel', m.ids);

        axis xy;
        colormap(hot);
    end

    % end methods
    end
    
    % ---
    % private methods
    % ---
    methods(Access = private)
        
    function out = get_clip_pos(m, clip)
        % returns the position(s) in mat for the given clip (array);
        % raises DistMeasureMahal:unknownClip for clips not indexed here

        cids = [clip.id];
        out = zeros(size(cids));
        inrange = cids >= 1 & cids <= numel(m.ret_ids);
        out(inrange) = full(m.ret_ids(cids(inrange)));
        if any(out == 0)
            error('DistMeasureMahal:unknownClip', ...
                'clip id not indexed by this distance measure');
        end
    end
        
    end
    
end