view CAdaptInstrSpec.m @ 0:b4e26b53072f tip

Initial commit.
author Holger Kirchhoff <holger.kirchhoff@eecs.qmul.ac.uk>
date Tue, 04 Dec 2012 13:57:15 +0000
parents
children
line wrap: on
line source
classdef CAdaptInstrSpec    
% CAdaptInstrSpec - Estimation of a filter curve that enables the
%   adaptation of instrument templates from one recording to another. The
%   beta-divergence is used as a cost function between the original and the
%   adapted spectra. All spectra need to be provided on a logarithmic
%   frequency axis (for now, an extension to linear frequency axes should
%   be straightforward). For further details on the estimation method, see
%   the references below.
%
% PROPERTIES
%    no public properties
%
% METHODS
%    CAdaptInstrSpec    - constructor for CAdaptInstrSpec object
%    setH               - sets filter curve 'h'
%    getH               - returns estimate of filter curve 'h'
%    getSmoothedH       - returns smoothed and interpolated version of the
%                         filter curve 'h'
%    updateH            - performs single update of filter curve 'h'
%    estimateSpectra    - estimates spectra based on current estimate of
%                         the filter curve.
%    compBetaDivergence - compute beta-divergence between original and
%                         estimated spectra
%
% For further help on the methods, type 'help CAdaptInstrSpec.[methodName]'
%
%
% References:
%
% [1] H. Kirchhoff, S. Dixon, A. Klapuri. Missing spectral templates
%     estimation for user-assisted music transcription. IEEE International
%     Conference on Acoustics, Speech and Signal Processing, Vancouver,
%     Canada, 2013, submitted.
% [2] H. Kirchhoff, S. Dixon, and A. Klapuri. Cross-recording adaptation of
%     musical instrument spectra. Technical Report C4DM-TR-11-2012,
%     Queen Mary University of London, 2012.
%     http://www.eecs.qmul.ac.uk/~holger/C4DM-TR-11-2012

% Copyright (C) 2012 Holger Kirchhoff
% 
% This program is free software; you can redistribute it and/or
% modify it under the terms of the GNU General Public License
% as published by the Free Software Foundation; either version 2
% of the License, or (at your option) any later version.
% 
% This program is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
% GNU General Public License for more details.
% 
% You should have received a copy of the GNU General Public License
% along with this program; if not, write to the Free Software
% Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

    
   % Private state of the estimator. All matrices hold one spectrum per
   % column and one frequency bin per row.
   properties (Access='private')

        h = []; % filter transfer function (column vector, one value per frequency bin)
        maxDevFromMedianInDB = 10; % maximum deviation (in dB) from the median of h allowed in getSmoothedH
        numCepstralCoeffs = 20; % number of cepstral coefficients passed to dceps in getSmoothedH
        regularisationParam = 0.001; % regularisation parameter passed to dceps in getSmoothedH
        
        spectra_DB      = []; % basis functions estimated from database (e.g. RWC)
        spectra_data    = []; % basis functions derived from analysis spectrogram
        W_DB            = []; % spectra_DB reduced to peak amplitudes only
        W_data          = []; % spectra_data reduced to peak amplitudes only
        W_data_est      = []; % estimated basis functions (W_DB at the common pitches, each column multiplied by h)
        f0Idcs_DB       = []; % f0 indices of columns in W_DB
        f0Idcs_data     = []; % f0 indices of columns in W_data
        numF0Idcs_DB    = 0; % number of columns in spectra_DB
        numF0Idcs_data  = 0; % number of columns in spectra_data
        
        commonF0Idcs          = []; % f0 indices that occur both in f0Idcs_DB and in f0Idcs_data
        commonF0IdcsIdcs_DB   = []; % positions in f0Idcs_DB   of the common f0 indices
        commonF0IdcsIdcs_data = []; % positions in f0Idcs_data of the common f0 indices
        numCommonShifts       = 0; % number of common f0 indices

        zeroIdcsW_data        = []; % logical mask: entries of W_data at the common pitches that are zero (required for numerical reasons)
        zeroIdcsW_data_est    = []; % logical mask: entries of W_DB at the common pitches that are zero (W_data_est is zero there, too)
        zeroIdcsH             = []; % logical mask: rows whose sum over W_data, or over W_DB at the common pitches, is zero; h is initialised to zero there and skipped by updateH
        
        maxNumFreqs = 0; % number of frequency bins (rows of spectra_DB)
        
        costFctName = ''; % name of the cost function; the constructor maps 'LS', 'KL' and 'IS' to 'BD' with a fixed beta
        costFctNames = {'LS', 'KL', 'IS', 'BD'}; % names accepted by the constructor
        
        beta = 0; % parameter beta for beta divergence
        
   end % properties

   methods
        
        function obj = CAdaptInstrSpec(spectra_DB, spectra_data, f0Idcs_DB, f0Idcs_data, numBinsPerSemitone, costFctName, varargin)
        % CAdaptInstrSpec - constructor of CAdaptInstrSpec class
        %    
        %    myObj = CAdaptInstrSpec(spectra_DB, spectra_data, f0Idcs_DB,
        %            f0Idcs_data, numBinsPerSemitone, costFctName)
        %    constructs a CAdaptInstrSpec object.
        %
        %    Parameters:    
        %    spectra_DB         - matrix containing in its columns the
        %                         database templates that are to be adapted
        %    spectra_data       - matrix containing the spectra estimated
        %                         from the recording to which the database
        %                         spectra should be adapted
        %    f0Idcs_DB          - f0-indices corresponding to the columns
        %                         in spectra_DB
        %    f0Idcs_data        - f0-indices corresponding to the columns
        %                         in spectra_data
        %    numBinsPerSemitone - pitch resolution of the constant-Q
        %                         spectra
        %    costFctName        - name of cost function (matched
        %                         case-insensitively). available cost
        %                         functions are:
        %                         'LS' - least squares error       (beta = 2)
        %                         'KL' - generalised Kullback-Leibler div. (beta = 1)
        %                         'IS' - Itakura-Saito divergence  (beta = 0)
        %                         'BD' - beta divergence
        %
        %   If 'BD' is selected as the cost function, the parameter beta
        %   has to be provided by myObj = CAdaptInstrSpec(..., 'beta',
        %   betaValue) where betaValue is a real, finite scalar.
        %
        %   An error is raised if the matrix/vector sizes are inconsistent,
        %   if costFctName is not one of the implemented names, or if 'BD'
        %   is selected without a valid 'beta' name/value pair.
            
        
            assert(length(f0Idcs_DB)   == size(spectra_DB,2),   'number of values in ''f0Idcs_DB''   must be the same as number of columns in W_DB');
            assert(length(f0Idcs_data) == size(spectra_data,2), 'number of values in ''f0Idcs_data'' must be the same as number of columns in W_data');
            assert(size(spectra_DB,1) == size(spectra_data,1), 'number of frequency bins (rows) in W_DB and W_data must be the same');
            
            % FIXME: check that inputs are valid!
            
            %% set member variables
            obj.f0Idcs_DB      = f0Idcs_DB;
            obj.f0Idcs_data    = f0Idcs_data;
            obj.spectra_DB     = spectra_DB;
            obj.spectra_data   = spectra_data;
            
            obj.maxNumFreqs    = size(spectra_DB,  1);
            obj.numF0Idcs_DB   = size(spectra_DB,  2);
            obj.numF0Idcs_data = size(spectra_data,2);
            
            % intersect returns the common f0 indices together with their
            % positions in both input vectors
            [obj.commonF0Idcs, obj.commonF0IdcsIdcs_DB, obj.commonF0IdcsIdcs_data] = intersect(f0Idcs_DB, f0Idcs_data);
            obj.numCommonShifts = length(obj.commonF0IdcsIdcs_DB);
            
            
            %% reduce spectra to partial amplitudes only
            obj.W_DB   = obj.noteSpec2partialSpec(spectra_DB,   f0Idcs_DB,   numBinsPerSemitone);
            obj.W_data = obj.noteSpec2partialSpec(spectra_data, f0Idcs_data, numBinsPerSemitone);
            
            %% adjust amplitudes in database spectra at common pitches
            obj.W_DB(:, obj.commonF0IdcsIdcs_DB) = obj.adjustPartialPositions(obj.W_DB(:, obj.commonF0IdcsIdcs_DB), ...
                                                                              obj.W_data(:, obj.commonF0IdcsIdcs_data), ...
                                                                              obj.commonF0Idcs, numBinsPerSemitone);

            
            %% find zero-entries in W_data and W_data_est
            obj.zeroIdcsW_data        = (obj.W_data(:,obj.commonF0IdcsIdcs_data) == 0);
            obj.zeroIdcsW_data_est    = (obj.W_DB(:,obj.commonF0IdcsIdcs_DB) == 0); % zero where W_DB is zero (see computeW_data_est)
            obj.zeroIdcsH             = (sum(obj.W_data,2) == 0) | (sum(obj.W_DB(:,obj.commonF0IdcsIdcs_DB),2) == 0);

            %% validate cost function name
            % Both conditions belong into the first argument of assert;
            % '&&' short-circuits so strcmpi is only reached for strings.
            assert(ischar(costFctName) && any(strcmpi(costFctName, obj.costFctNames)), ...
                   'Argument ''costFctName'' must be a string, and must match one of the implemented cost function names.');
            % normalise to upper case so that the case-sensitive switch
            % below agrees with the case-insensitive validation above
            obj.costFctName = upper(costFctName);

            
            %% set beta & cost function name
            switch obj.costFctName
                case 'LS'
                    obj.beta = 2;
                    obj.costFctName = 'BD';
                case 'KL'
                    obj.beta = 1;
                    obj.costFctName = 'BD';
                case 'IS'
                    obj.beta = 0;
                    obj.costFctName = 'BD';
                case 'BD'
                    % check if beta was set
                    % FIXME: if more optional arguments are added later, use MATLAB's inputParser
                    if isempty(varargin)
                        error('When ''BD'' is used as the cost function, beta needs to be set.');
                    elseif numel(varargin) < 2 || ~strcmp(varargin{1}, 'beta')
                        % also covers a trailing 'beta' without a value
                        error('Name/value pair for ''beta'' not found.')
                    end
                    betaValue = varargin{2};
                    % the beta value is the 8th input argument of the constructor
                    validateattributes(betaValue, {'numeric'}, {'scalar', 'real', 'finite', 'nonnan'}, 'CAdaptInstrSpec', 'beta', 8)
                    obj.beta = betaValue;
            end
            
            %% initialise h
            % h starts at 1 wherever it can be estimated and stays 0 elsewhere
            obj.h = zeros(obj.maxNumFreqs,1);
            obj.h(~obj.zeroIdcsH) = 1;
            
            %% compute initial WEst
            obj = computeW_data_est(obj);
            
        end
        
        
        function obj = updateH(obj)
        % updateH - perform single update of filter curve 'h'
        %
        %    myObj = myObj.updateH applies one multiplicative update to the
        %    filter curve 'h' and recomputes the estimated spectra. Only
        %    the entries of 'h' that are not flagged in 'zeroIdcsH' are
        %    changed. The update is computed from the spectra at the f0
        %    indices common to the database and the recording.
        
            %% get spectra at common f0 indices
            W_data     = obj.W_data(:, obj.commonF0IdcsIdcs_data);
            W_DB       = obj.W_DB(:, obj.commonF0IdcsIdcs_DB);
            W_data_est = obj.W_data_est;


            %% compute W_data * W_data_est^(beta-2)
            nomMatrix = W_data .* W_data_est .^ (obj.beta-2);

            % fix divide by 0            
            if obj.beta < 2  % if beta < 2, exponent of WEst^(beta-2) is negative -> division
                maxRatio = max(max(nomMatrix( ~obj.zeroIdcsW_data_est )));
                nomMatrix(obj.zeroIdcsW_data    & obj.zeroIdcsW_data_est) = 1;
                nomMatrix(~obj.zeroIdcsW_data & obj.zeroIdcsW_data_est) = maxRatio;
            end
            
            %% compute W_data_est^(beta-1)
            denomMatrix = W_data_est .^ (obj.beta-1);
            
            % fix divide by 0
            if obj.beta < 1  % if beta < 1, exponent of WEst^(beta-1) is negative -> division
                maxRatio = max(max(denomMatrix(~obj.zeroIdcsW_data_est)));
                denomMatrix(obj.zeroIdcsW_data_est) = maxRatio;
            end
            
            %% multiply by W_DB
            nomMatrix   = nomMatrix   .* W_DB;
            denomMatrix = denomMatrix .* W_DB;
            
                        
            %% compute nominator and denominator
            nom   = sum(nomMatrix, 2);
            denom = sum(denomMatrix, 2);
            
            %% compute ratio
            ratio = nom ./ denom;
            ratio((nom==0) & (denom==0)) = 1;   % 0/0 -> leave h unchanged
            
            % x/0 with x ~= 0 yields Inf. Replace those entries by the
            % largest finite ratio; taking max(ratio) directly would just
            % return the Inf that is supposed to be replaced.
            divByZeroIdcs = (nom~=0) & (denom==0);
            if any(divByZeroIdcs)
                finiteRatios = ratio(isfinite(ratio));
                if isempty(finiteRatios)
                    ratio(divByZeroIdcs) = 1;
                else
                    ratio(divByZeroIdcs) = max(finiteRatios);
                end
            end
            
            %% apply update
            obj.h(~obj.zeroIdcsH) = obj.h(~obj.zeroIdcsH) .* ratio(~obj.zeroIdcsH);
            
            %% recompute WEst
            obj = computeW_data_est(obj);
        end
        
                
        function [spectra, shiftVals] = estimateSpectra(obj, shiftVals)
        % estimateSpectra - computes adapted spectra from the current
        %    estimate of the filter curve 'h'
        %
        %    [spectra f0Idcs] = myObj.estimateSpectra(f0Idcs) estimates the
        %    spectra at the f0 indices provided by 'f0Idcs' by applying
        %    the smoothed estimate of the filter curve (see getSmoothedH)
        %    to the spectra in 'spectra_DB' (see constructor). Spectra are
        %    only returned for those f0Idcs that exist in 'f0Idcs_DB'
        %    specified in the constructor. The output is a matrix 'spectra'
        %    containing the estimated spectra and a vector 'f0Idcs'
        %    containing the f0 index of each column of 'spectra'. Because
        %    the matching is done with intersect, the returned f0 indices
        %    are sorted in ascending order and contain no duplicates. The
        %    returned vector has the same orientation (row/column) as the
        %    input vector.
            
            %FIXME: check that input variable 'shiftVals' is correct
            
            % find values in shiftVals that also exist in obj.f0Idcs_DB
            [commonShiftVals, shiftIdcs_DB] = intersect(obj.f0Idcs_DB, shiftVals);
            numCommonShiftVals = length(commonShiftVals);
            
            h = obj.getSmoothedH();
            %h = obj.getH();
            spectra = obj.spectra_DB(:,shiftIdcs_DB) .* repmat(h, 1, numCommonShiftVals);
            
            % return the f0 indices that actually correspond to the columns
            % of 'spectra' (not the unfiltered input), keeping the
            % orientation of the input vector
            if size(shiftVals, 1) == 1
                shiftVals = commonShiftVals(:)';
            else
                shiftVals = commonShiftVals(:);
            end
        end
        
        function h = getH(obj)
        % getH - accessor for the filter curve 'h'
        %
        %    myH = myObj.getH() returns the current (unsmoothed) value of
        %    the member variable 'h'.

            h = obj.h;
        end
        
        function h = getSmoothedH(obj)
        % getSmoothedH - get smoothed version of filter curve 'h'
        %
        %    myH = myObj.getSmoothedH() returns a smoothed version of the
        %    filter curve 'h' that is defined at all frequency bins.
        %    Entries of 'h' flagged in 'zeroIdcsH' are ignored, the
        %    remaining entries are limited to +/- maxDevFromMedianInDB
        %    around their median (in dB), and the result is passed through
        %    the discrete cepstrum spectral envelope algorithm (dceps /
        %    idceps) by Diemo Schwarz.
            
            numBins    = obj.maxNumFreqs;
            maxDevInDB = obj.maxDevFromMedianInDB;
            
            % bins at which the filter curve has been estimated
            activeBins = find(~obj.zeroIdcsH);
            
            gain     = obj.h(activeBins);
            gainInDB = 20*log10(gain);
            
            % clip outliers to the allowed band around the median level
            medianLevel = median(gainInDB);
            upperLimit  = medianLevel + maxDevInDB;
            lowerLimit  = medianLevel - maxDevInDB;
            
            gain(gainInDB > upperLimit) = 10^( upperLimit/20 );
            gain(gainInDB < lowerLimit) = 10^( lowerLimit/20 );
            
            % frequency axis for the cosine approximation and its values
            % at the active bins
            omega       = (1:numBins)' / numBins * pi;
            omegaActive = omega(activeBins);
            
            % extend the curve to both ends of the axis by repeating the
            % outermost active values
            if activeBins(1) ~= 1
                gain        = [gain(1); gain]; 
                omegaActive = [omega(1); omegaActive];
            end
            if activeBins(end) ~= numBins
                gain        = [gain; gain(end)];
                omegaActive = [omegaActive; omega(end)];
            end
            
            % discrete cepstrum fit, evaluated on the full frequency axis
            cepstralCoeffs = dceps(gain, omegaActive, obj.numCepstralCoeffs, obj.regularisationParam);
            h = idceps(cepstralCoeffs, omega);
            
        end
        
        function obj = setH(obj, h)
        % setH - overwrite the filter curve 'h'
        %
        %    myObj = myObj.setH(myH) stores myH as the new filter curve
        %    and recomputes the estimated spectra from it. myH must be a
        %    non-negative column vector of length [number of frequencies];
        %    the value is not validated here.

            obj.h = h;
            
            % keep W_data_est consistent with the new filter curve
            obj = obj.computeW_data_est();
        end
        
        function betaDiv = compBetaDivergence(obj)
        % compBetaDivergence - computes beta divergence based on the
        %    current estimates
        %
        %    betaDiv = myObj.compBetaDivergence() returns the
        %    beta-divergence between the instrument spectra from the
        %    recording and the adapted database spectra based on the value
        %    for beta specified by the cost function in the constructor.
        %    Only the spectra at the f0 indices common to the database and
        %    the recording are compared, and entries at which the database
        %    spectra are zero are excluded from the sum.
            
            % W_data_est only contains the common pitches, so the same
            % columns have to be selected from W_data before the mask
            % (which has the size of W_data_est) can be applied
            W_data_common = obj.W_data(:, obj.commonF0IdcsIdcs_data);
            validIdcs     = ~obj.zeroIdcsW_data_est;
            
            W_data     = W_data_common(validIdcs);
            W_data_est = obj.W_data_est(validIdcs);
            
            switch obj.beta
                case 0
                    % Itakura-Saito divergence
                    betaDivMat = W_data ./ W_data_est - log(W_data ./ W_data_est) - 1;
                    
                case 1
                    % generalised Kullback-Leibler divergence, as
                    % implemented in the original code
                    betaDivMat = W_data .* log(W_data ./ W_data_est) + W_data - W_data_est;
                    
                otherwise
                    % general beta divergence (beta not in {0, 1})
                    betaDivMat = (W_data .^ obj.beta) / (obj.beta * (obj.beta-1)) ...
                                 + (W_data_est .^ obj.beta) / obj.beta ...
                                 - (W_data .* (W_data_est .^ (obj.beta-1))) / (obj.beta-1);
            end
            
            betaDiv = sum(betaDivMat(:));
        end

   end % methods
   

   methods (Access = private)
       
       function obj = computeW_data_est(obj)
       % computeW_data_est - recomputes the estimated spectra W_data_est by
       % multiplying each database partial spectrum at the common pitches
       % with the current filter curve h. Does nothing while h is empty.
           
           if isempty(obj.h)
               return;
           end
           
           commonW_DB = obj.W_DB(:, obj.commonF0IdcsIdcs_DB);
           hPerPitch  = repmat(obj.h, 1, obj.numCommonShifts);
           obj.W_data_est = commonW_DB .* hPerPitch;
       end
       
   end % methods (Access = private)


   methods (Access = private, Static)
       
        function partialSpectra = noteSpec2partialSpec(noteSpectra, f0Idcs, numBinsPerSemitone)
        % noteSpec2partialSpec - reduces note spectra to their partial peaks
        %
        %    For every column of 'noteSpectra' the frequency axis above the
        %    f0 index of that column is divided into one range per partial.
        %    Within each range only the maximum amplitude is kept and
        %    written to its absolute frequency position; all other entries
        %    of the result are zero.
        %
        %    noteSpectra        - matrix with one note spectrum per column
        %    f0Idcs             - f0 index (frequency bin) of each column
        %    numBinsPerSemitone - pitch resolution of the spectra

            [numBins, numNotes] = size(noteSpectra);
            partialSpectra = zeros(numBins, numNotes);
            
            % ideal partial positions relative to the fundamental; the
            % ranges are split at the geometric mean of neighbouring
            % partial positions
            numPartialsTotal = floor(freqIdx2PartialIdx(numBins, numBinsPerSemitone));
            idealPositions   = partialIdx2FreqIdx((1:numPartialsTotal)', numBinsPerSemitone);
            splitPoints      = floor(geomean( [idealPositions(1:end-1)'; idealPositions(2:end)'] )');
            relRangeStart    = [1; splitPoints+1];
            relRangeEnd      = [splitPoints; numBins];
                    
            for noteIdx = 1:numNotes
                
                f0Idx     = f0Idcs(noteIdx);
                binOffset = f0Idx-1;
                
                % number of partials that fit between this f0 and the top bin
                numPartialsOfNote = floor(freqIdx2PartialIdx(numBins-f0Idx+1, numBinsPerSemitone));
                
                for partialNo = 1:numPartialsOfNote
                    
                    % absolute bin range of this partial (clipped at the top)
                    firstBin = binOffset + relRangeStart(partialNo);
                    lastBin  = min(binOffset + relRangeEnd(partialNo), numBins);
                    
                    % keep only the peak of the range
                    [peakAmpl, peakPos] = max(noteSpectra(firstBin:lastBin, noteIdx));
                    partialSpectra(firstBin-1+peakPos, noteIdx) = peakAmpl;
                    
                end                
            end
        end % noteSpec2partialSpec
        
        
        function W_DB = adjustPartialPositions(W_DB, W_data, f0Idcs, numBinsPerSemitone)
        % adjustPartialPositions - adjust the positions of the partials in
        %    W_DB to those in W_data
        %
        %    For every column and every partial range, the non-zero
        %    amplitude of W_DB is moved to the frequency index at which
        %    W_data has its non-zero entry within the same range. If W_data
        %    has no non-zero entry in the range, the amplitude in W_DB is
        %    set to zero. If W_DB has no non-zero entry in the range, the
        %    range is left untouched.
        %
        %    W_DB               - database partial spectra (one per column)
        %    W_data             - partial spectra of the recording, same
        %                         number of columns as W_DB
        %    f0Idcs             - f0 index of each column
        %    numBinsPerSemitone - pitch resolution of the spectra
        
            assert(size(W_DB,2)   == size(W_data,2), '''W_DB'' and ''W_data'' must contain the same number of columns');
            assert(length(f0Idcs) == size(W_DB,2), 'Number of elements in ''f0Idcs'' must be equal to number of columns in ''W_DB''');

            [numFreqs, numPitches] = size(W_DB);
            
            % get (ideal) relative partial positions
            maxNumPartials = floor(freqIdx2PartialIdx(numFreqs, numBinsPerSemitone));
            relF0IdcsOfPartials = partialIdx2FreqIdx((1:maxNumPartials)', numBinsPerSemitone);
            meansF0Idcs = geomean( [relF0IdcsOfPartials(1:end-1)'; relF0IdcsOfPartials(2:end)'] )';
            relLowerBoundOfPartials = [-ceil(numBinsPerSemitone/2); floor(meansF0Idcs)+1]; % make 1st bound -numBinsPerSemitone/2 to allow 1st partial to deviate below ideal position
            relUpperBoundOfPartials = [floor(meansF0Idcs); numFreqs];
                    
            % go through all spectra
            for pitchIdx = 1:numPitches
                
                currF0Idx = f0Idcs(pitchIdx);
                % NOTE(review): compNumPartials is not defined in this file;
                % noteSpec2partialSpec uses
                % floor(freqIdx2PartialIdx(numFreqs-currF0Idx+1, numBinsPerSemitone))
                % for the same purpose - confirm both agree.
                currNumPartials = compNumPartials(numFreqs-currF0Idx+1, numBinsPerSemitone);
                
                % go through partials
                for partialIdx = 1:currNumPartials
                    
                    % compute bounds for partial range
                    lowerBound = max(currF0Idx-1 + relLowerBoundOfPartials(partialIdx), 1);
                    upperBound = min(currF0Idx-1 + relUpperBoundOfPartials(partialIdx), numFreqs);
                    
                    % get index of partial in W_DB
                    idx_DB = find(W_DB(lowerBound:upperBound, pitchIdx));
                    
                    % no database partial in this range: there is no
                    % amplitude to move, and assigning an empty amplitude
                    % to a non-empty index would raise an error
                    if isempty(idx_DB)
                        continue;
                    end
                    
                    % get index of partial in W_data
                    idx_data = find(W_data(lowerBound:upperBound, pitchIdx));
                    
                    % set partial amplitude in W_DB to frequency index of W_data
                    partAmpl = W_DB(lowerBound-1+idx_DB, pitchIdx);
                    W_DB(lowerBound-1+idx_DB,   pitchIdx) = 0;
                    W_DB(lowerBound-1+idx_data, pitchIdx) = partAmpl;
                end                
            end        
        end % adjustPartialPositions
   end
   
%    methods (Static)
%        
%        function Xshift = shiftSpectra(X, shiftVals)
%        % shifts each spectrum (column in X) down by the amount specified in
%        % shiftVals
%            
%            assert(size(X,2) == length(shiftVals), 'number values in ''shiftVals'' must be the same as number of columns in ''X''');
%        
%            [maxNumFreqs numShifts] = size(X);
%            Xshift = zeros(maxNumFreqs, numShifts);
%            
%            for pitchIdx = 1:numShifts
%                phi = shiftVals(pitchIdx);
%                Xshift(1:maxNumFreqs-phi, pitchIdx) = X(phi+1:maxNumFreqs, pitchIdx);
%            end
%        end
%    end % methods (Static)

end