% ---
% The DistMeasureNNet class is a wrapper for a
% neural-network-based distance measure, and is compatible with the
% DistMeasure class
% ---
classdef DistMeasureNNet <  handle
    % DistMeasureNNet  Distance measure between clips evaluated by a net.
    %
    % The distance between two clips is obtained by passing the difference
    % of their feature columns to net.calcValue.

    properties (SetAccess = private)

        % object providing calcValue(deltas); evaluated once per clip pair
        net;

        % feature matrix, one column per clip (same order as ids)
        featX;

        % clip ids, in the column order of featX
        ids;

        % reverse index: ret_ids(id) is the column position of that clip id
        % (0 for ids that are not part of this measure)
        ret_ids;
    end

    methods

    % ---
    % constructor
    % ---
    function m = DistMeasureNNet(clips, net, featX)
    % m = DistMeasureNNet(clips, net, featX)
    %
    % clips: array of clips, each exposing an .id field
    % net:   object providing calcValue(deltas)
    % featX: feature matrix with one column per entry of clips
    %
    % raises an error if the number of feature columns does not match
    % the number of clips

        if size(featX, 2) ~= numel(clips)
            error('wrong input format');
        end

        % forward index: column position -> clip id
        m.ids = [clips.id];

        % reverse index: clip id -> column position
        % (the sparse vector grows on assignment if an id exceeds its size)
        m.ret_ids = sparse(numel(m.ids),1);
        m.ret_ids(m.ids) = 1:numel(m.ids);

        % ---
        % save neural net and features
        % ---
        m.net = net;
        m.featX = featX;
    end


    % ---
    % this function returns the
    % distance values for the given clip indices
    % ---
    function out = mat(m, idxa, idxb)
    % out = mat(m, idxa, idxb)
    %
    % returns a numel(idxa) x numel(idxb) matrix whose entry (i,j) is
    % net.calcValue applied to featX(:,idxa(i)) - featX(:,idxb(j)).
    % idxa and idxb are column positions (not clip ids); each of them
    % defaults to all clips when omitted.

        all_idx = 1:numel(m.ids);

        % default each index set independently, so that m.mat(idxa)
        % compares the given clips against all clips
        if nargin < 2
            idxa = all_idx;
        end
        if nargin < 3
            idxb = all_idx;
        end

        % cycle through all index combinations
        out = zeros(numel(idxa), numel(idxb));
        for i = 1:numel(idxa)
            for j = 1:numel(idxb)

                % calculate feature difference vector
                deltas = m.featX(:,idxa(i)) - m.featX(:,idxb(j));

                % return distance from net
                out(i,j) = m.net.calcValue(deltas);
            end
        end
    end

    % ---
    % returns the distance for the two input clips
    % ---
    function out = distance(m, clipa, clipb)
    % out = distance(m, clipa, clipb)
    %
    % looks up the positions of both clips and evaluates mat for them

        posa = m.get_clip_pos(clipa);
        posb = m.get_clip_pos(clipb);

        out = m.mat(posa, posb);
    end

    % ---
    % returns a list of n (default = 10) clips most
    % similar to the  input
    % ---
    function [clips, dist] = get_nearest(m, clip, n)
    % [clips, dist] = get_nearest(m, clip, n)
    %
    % returns a list of n (default = 10) clips most
    % similar to the input, together with their distances in
    % ascending order. n = 0 returns all clips.
    % both outputs are empty if no clip has a distance below inf.

        % default number of results
        if nargin < 3
            n = 10;
        end

        % return all clips in case n = 0
        if n == 0
            n = numel(m.ids);
        end

        % get clip position
        pos = m.get_clip_pos(clip);

        % sort according to distance
        [sc, idx] = sort( m.mat(pos, 1:numel(m.ids)), 'ascend');

        % we only output relevant data (entries failing sc < inf are dropped)
        valid = sc < inf;
        idx = idx(valid);
        sc = sc(valid);

        if numel(idx) > 0
            % create clips from best ids
            clips = MTTClip( m.ids( idx(1:min(n, end))));

            % reuse the sorted distances instead of evaluating the net a
            % second time for the selected clips
            dist = sc(1:min(n, end));
        else
            clips = [];
            dist = [];
        end
    end



    function [clips, dist] = present_nearest(m, clip, n)
    % [clips, dist] = present_nearest(m, clip, n)
    %
    % plays and shows the n (default = 3) best hits for a given clip

        % default number of results
        if nargin < 3
            n = 3;
        end

        % get best list
        [clips, dist] = get_nearest(m, clip, n);

        % show the query clip first, then each hit in rank order
        clip.audio_features_basicsm.visualise();
        for i = 1:numel(clips)
            fprintf('\n\n\n- Rank %d, distance: %1.4f \n\n',i, dist(i));

            clips(i).audio_features_basicsm.visualise();
            h = gcf();

            % wait for the value returned by play before closing the figure
            t = clips(i).play(20);
            pause(t);
            close(h);
        end
    end

    function a = visualise(m)
    % a = visualise(m)
    %
    % shows the full distance matrix as an image in a new figure, with
    % the clip ids as tick labels, and returns the axes handle

        figure;

        % plot data
        imagesc(m.mat());

        a = gca;
        set(a,'YTick',1:numel(m.ids), 'YTickLabel',m.ids);
        set(a,'XTick',1:numel(m.ids), 'XTickLabel', m.ids);

        axis xy;
        colormap(hot);
    end

    % end methods
    end

    % ---
    % private methods
    % ---
    methods(Access = private)

    function out = get_clip_pos(m, clip)
    % out = get_clip_pos(m, clip)
    %
    % returns position in mat for given clip (or array of clips)
    % raises DistMeasureNNet:unknownClip if a clip id is not part of
    % this distance measure

        % collect ids the same way the constructor does, so that clip
        % arrays are handled as well
        clip_ids = [clip.id];

        % ids outside the reverse index cannot be known
        if any(clip_ids < 1 | clip_ids > numel(m.ret_ids))
            error('DistMeasureNNet:unknownClip', ...
                'clip id not contained in this distance measure');
        end

        % convert the sparse lookup result to a plain index
        out = full(m.ret_ids(clip_ids));

        % a zero entry marks an id that was never registered
        if any(out == 0)
            error('DistMeasureNNet:unknownClip', ...
                'clip id not contained in this distance measure');
        end
    end

    end

end