diff CAdaptInstrSpec.m @ 0:b4e26b53072f tip

Initial commit.
author Holger Kirchhoff <holger.kirchhoff@eecs.qmul.ac.uk>
date Tue, 04 Dec 2012 13:57:15 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/CAdaptInstrSpec.m	Tue Dec 04 13:57:15 2012 +0000
@@ -0,0 +1,484 @@
+classdef CAdaptInstrSpec    
+% CAdaptInstrSpec - Estimation of a filter curve that enables the
+%   adaptation of instrument templates from one recording to another. The
+%   beta-divergence is used as a cost function between the original and the
+%   adapted spectra. All spectra need to be provided on a logarithmic
+%   frequency axis (for now, an extension to linear frequency axes should
+%   be straightforward). For further details on the estimation method, see
+%   the references below.
+%
+% PROPERTIES
+%    no public properties
+%
+% METHODS
+%    CAdaptInstrSpec    - constructor for CAdaptInstrSpec object
+%    setH               - sets filter curve 'h'
+%    getH               - returns estimate of filter curve ''h''
+%    getSmoothedH       - returns smoothed and interpolated version of the
+%                         filter curve ''h''
+%    updateH            - performs single update of filter curve ''h''
+%    estimateSpectra    - estimates spectra based on current estimate of
+%                         the filter curve.
+%    compBetaDivergence - compute beta-divergence between original and
+%                         estimated spectra
+%
+% For further help on the methods, type 'help CAdaptInstrSpec.[methodName]'
+%
+%
+% References:
+%
+% [1] H. Kirchhoff, S. Dixon, A. Klapuri. Missing spectral templates
+%     estimation for user-assisted music transcription. IEEE International
+%     Conference on Acoustics, Speech and Signal Processing, Vancouver,
+%     Canada, 2013, submitted.
+% [2] H. Kirchhoff, S. Dixon, and A. Klapuri. Cross-recording adaptation of
+%     musical instrument spectra. Technical Report C4DM-TR-11-2012,
+%     Queen Mary University of London, 2012.
+%     http://www.eecs.qmul.ac.uk/~holger/C4DM-TR-11-2012
+
+% Copyright (C) 2012 Holger Kirchhoff
+% 
+% This program is free software; you can redistribute it and/or
+% modify it under the terms of the GNU General Public License
+% as published by the Free Software Foundation; either version 2
+% of the License, or (at your option) any later version.
+% 
+% This program is distributed in the hope that it will be useful,
+% but WITHOUT ANY WARRANTY; without even the implied warranty of
+% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+% GNU General Public License for more details.
+% 
+% You should have received a copy of the GNU General Public License
+% along with this program; if not, write to the Free Software
+% Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+
+    
+   properties (Access='private')
+
+        h = []; % filter transfer function
+        maxDevFromMedianInDB = 10; % maximum deviation from mean values
+        numCepstralCoeffs = 20;
+        regularisationParam = 0.001;
+        
+        spectra_DB      = []; % basis functions estimated from database (e.g. RWC)
+        spectra_data    = []; % basis functions derived from analysis spectrogram
+        W_DB            = []; % spectra_DB reduced to peak amplitudes only
+        W_data          = []; % spectra_data reduced to peak amplitudes only
+        W_data_est      = []; % estimated basis functions (W_data * h)
+        f0Idcs_DB       = []; % pitch values of columns in W_DB
+        f0Idcs_data     = []; % pitch values of columns in W_data
+        numF0Idcs_DB    = 0;
+        numF0Idcs_data  = 0;
+        
+        commonF0Idcs          = []; % midi pitch values that occur both in f0Idcs_DB and in f0Idcs_data
+        commonF0IdcsIdcs_DB   = []; % pitch indices in f0Idcs_DB   that also occur in f0Idcs_data
+        commonF0IdcsIdcs_data = []; % pitch indices in f0Idcs_data that also occur in f0Idcs_DB
+        numCommonShifts       = 0;
+
+        zeroIdcsW_data        = []; % indices in W_data that are zero (required for numerical reasons)
+        zeroIdcsW_data_est    = []; % indices in W_DB that are zero
+        zeroIdcsH             = []; % indices in h that are zero
+        
+        maxNumFreqs = 0;
+        
+        costFctName = '';
+        costFctNames = {'LS', 'KL', 'IS', 'BD'};
+        
+        beta = 0; % parameter beta for beta divergence
+        
+   end % properties
+
+   methods
+        
+        function obj = CAdaptInstrSpec(spectra_DB, spectra_data, f0Idcs_DB, f0Idcs_data, numBinsPerSemitone, costFctName, varargin)
+        % CAdaptInstrSpec - constructor of CAdaptInstrSpec class
+        %    
+        %    myObj = CAdaptInstrSpec(spectra_DB, spectra_data, f0Idcs_DB,
+        %            f0Idcs_data, numBinsPerSemitone, costFctName)
+        %    constructs a CAdaptInstrSpec object.
+        %
+        %    Parameters:    
+        %    spectra_DB         - matrix containing in its columns the the
+        %                         database templates that are to be adapted
+        %    spectra_data       - matrix containing the spectra estimated
+        %                         from the recording to which the database
+        %                         spectra should be adapted
+        %    f0Idcs_DB          - f0-indices corresponding to the columns
+        %                         in spectra_DB
+        %    f0Idcs_data        - f0-indices corresponding to the columns
+        %                         in spectra_data
+        %    numBinsPerSemitone - pitch resolution of the constant-Q
+        %                         spectra
+        %    costFctName        - name of cost function. available cost
+        %                         functions are:
+        %                         'LS' - least squares error
+        %                         'KL' - generalised Kullback-Leibler div.
+        %                         'IS' - Itakura-Saito divergence
+        %                         'BD' - beta divergence
+        %
+        %   If 'BD' is selected as the cost function, the parameter beta
+        %   has to be provided by myObj = CAdaptInstrSpec(..., 'beta',
+        %   betaValue) where betaValue is a real, finite scalar.
+            
+        
+            assert(length(f0Idcs_DB)   == size(spectra_DB,2),   'number of values in ''f0Idcs_DB''   must be the same as number of columns in W_DB');
+            assert(length(f0Idcs_data) == size(spectra_data,2), 'number of values in ''f0Idcs_data'' must be the same as number of columns in W_data');
+            assert(size(spectra_DB,1) == size(spectra_data,1), 'number of frequency bins (rows) in W_DB and W_data must be the same');
+            
+            % FIXME: check that inputs are valid!
+            
+            %% set member variables
+            obj.f0Idcs_DB      = f0Idcs_DB;
+            obj.f0Idcs_data    = f0Idcs_data;
+            obj.spectra_DB     = spectra_DB;
+            obj.spectra_data   = spectra_data;
+            
+            obj.maxNumFreqs    = size(spectra_DB,  1);
+            obj.numF0Idcs_DB   = size(spectra_DB,  2);
+            obj.numF0Idcs_data = size(spectra_data,2);
+            
+            [obj.commonF0Idcs, obj.commonF0IdcsIdcs_DB, obj.commonF0IdcsIdcs_data] = intersect(f0Idcs_DB, f0Idcs_data);
+            obj.numCommonShifts = length(obj.commonF0IdcsIdcs_DB);
+            
+            
+            %% reduce spectra to partial amplitudes only
+            obj.W_DB   = obj.noteSpec2partialSpec(spectra_DB,   f0Idcs_DB,   numBinsPerSemitone);
+            obj.W_data = obj.noteSpec2partialSpec(spectra_data, f0Idcs_data, numBinsPerSemitone);
+            
+            %% adjust amplitudes in database spectra at common pitches
+            obj.W_DB(:, obj.commonF0IdcsIdcs_DB) = obj.adjustPartialPositions(obj.W_DB(:, obj.commonF0IdcsIdcs_DB), ...
+                                                                              obj.W_data(:, obj.commonF0IdcsIdcs_data), ...
+                                                                              obj.commonF0Idcs, numBinsPerSemitone);
+
+            
+            %% find zero-entries in W_data and W_data_est
+            obj.zeroIdcsW_data        = (obj.W_data(:,obj.commonF0IdcsIdcs_data) == 0);
+            obj.zeroIdcsW_data_est    = (obj.W_DB(:,obj.commonF0IdcsIdcs_DB) == 0); % zero where W_DB is zero (see computeW_data_est)
+            obj.zeroIdcsH             = (sum(obj.W_data,2) == 0) | (sum(obj.W_DB(:,obj.commonF0IdcsIdcs_DB),2) == 0);
+
+            assert(ischar(costFctName), ~any(strcmpi(costFctName, obj.costFctNames)), ...
+                   'Argument ''costFctName'' must be a string, and must match one of the implemented cost function names.');
+            obj.costFctName = costFctName;
+
+            
+            %% set beta & cost function name
+            switch obj.costFctName
+                case 'LS'
+                    obj.beta = 2;
+                    obj.costFctName = 'BD';
+                case 'KL'
+                    obj.beta = 1;
+                    obj.costFctName = 'BD';
+                case 'IS'
+                    obj.beta = 0;
+                    obj.costFctName = 'BD';
+                case 'BD'
+                    %check if beta was set
+                    if isempty(varargin) % FIXME: if more optional arguments are added later, use MATLAB's inputparser
+                        error('When ''BD'' is used as the cost function, beta needs to be set.');
+                    elseif ~strcmp(varargin{1}, 'beta')
+                        error('Name/value pair for ''beta'' not found.')
+                    end
+                    beta = varargin{2};
+                    validateattributes(beta, {'numeric'}, {'scalar', 'real', 'finite', 'nonnan'}, 'CSourceFilter', 'beta', 5)
+                    obj.beta = beta;
+            end
+            
+            %% initialise h
+            obj.h = zeros(obj.maxNumFreqs,1);
+            obj.h(~obj.zeroIdcsH) = 1;
+            
+            %% compute initial WEst
+            obj = computeW_data_est(obj);
+            
+        end
+        
+        
+        function obj = updateH(obj)
+        % updateH - perform single update of filter curve ''h''
+        %
+        %    myObj = myObj.updateH applies the update functions to the
+        %    filter curve ''h''.
+        
+            %% get spectra at common f0 indices
+            W_data     = obj.W_data(:, obj.commonF0IdcsIdcs_data);
+            W_DB       = obj.W_DB(:, obj.commonF0IdcsIdcs_DB);
+            W_data_est = obj.W_data_est;
+
+
+            %% compute W_data * W_data_est^(beta-2)
+            nomMatrix = W_data .* W_data_est .^ (obj.beta-2);
+
+            % fix divide by 0            
+            if obj.beta < 2  % if beta < 2, exponent of WEst^(beta-2) is negative -> division
+                maxRatio = max(max(nomMatrix( ~obj.zeroIdcsW_data_est )));
+                nomMatrix(obj.zeroIdcsW_data    & obj.zeroIdcsW_data_est) = 1;
+                nomMatrix(~obj.zeroIdcsW_data & obj.zeroIdcsW_data_est) = maxRatio;
+            end
+            
+            %% compute W_data^(beta-1)
+            denomMatrix = W_data_est .^ (obj.beta-1);
+            
+            % fix divide by 0
+            if obj.beta < 1  % if beta < 1, exponent of WEst^(beta-1) is negative -> division
+                maxRatio = max(max(denomMatrix(~obj.zeroIdcsW_data_est)));
+                denomMatrix(obj.zeroIdcsW_data_est) = maxRatio;
+            end
+            
+            %% multiply by W_DB
+            nomMatrix   = nomMatrix   .* W_DB;
+            denomMatrix = denomMatrix .* W_DB;
+            
+                        
+            %% compute nominator and denominator
+            nom   = sum(nomMatrix, 2);
+            denom = sum(denomMatrix, 2);
+            
+            %% compute ratio
+            ratio = nom ./ denom;
+            ratio((nom==0) & (denom==0)) = 1;
+            ratio((nom~=0) & (denom==0)) = max(ratio);
+            
+            %% apply update
+            obj.h(~obj.zeroIdcsH) = obj.h(~obj.zeroIdcsH) .* ratio(~obj.zeroIdcsH);
+            
+            %% recompute WEst
+            obj = computeW_data_est(obj);
+        end
+        
+                
+        function [spectra shiftVals] = estimateSpectra(obj, shiftVals)
+        % estimateSpectra - computes basis functions from the current
+        %    estimates for e and h
+        %
+        %    [spectra f0Idcs] = myObj.estimateSpectra(f0Idcs) estimates the
+        %    spectra at the f0 indices provided by ''f0Idcs'' by applying
+        %    the current estimate of the filter curve to the spectra in
+        %    ''spectra_DB'' (see constructor). Spectra are only returned
+        %    for those f0Idcs that exist in ''f0Idcs_DB'' specified in the
+        %    constructor. The output is a matrix ''spectra'' containing the
+        %    estimated spectra and a vector ''f0Idcs'' containing the
+        %    corresponding f0 indices.
+            
+            %FIXME: check that input variable 'shiftVals' is correct
+            
+            % find values in shiftVals that also exist in obj.f0Idcs_DB
+            [commonShiftVals, shiftIdcs_DB, dummy] = intersect(obj.f0Idcs_DB, shiftVals);
+            numCommonShiftVals = length(commonShiftVals);
+            
+            h = obj.getSmoothedH();
+            %h = obj.getH();
+            spectra = obj.spectra_DB(:,shiftIdcs_DB) .* repmat(h, 1, numCommonShiftVals);
+        end
+        
+        function h = getH(obj)
+        % getH - get filter curve ''h''
+        %
+        %    myH = myObj.getH() returns the member variable ''h''.
+
+            h = obj.h;
+        end
+        
+        function h = getSmoothedH(obj)
+        % getH - get smoothed version of filter curve ''h''
+        %
+        %    myH = myObj.getSmoothedH() returns a smoothed version of the
+        %    filter curve ''h''. Smoothing is done by applying the discrete
+        %    cepstrum spectral envelope algorithm from Diemo Schwartz to 
+        %    the filter curve ''h''.
+            
+            nonZeroFreqIdcsH = find(~obj.zeroIdcsH);
+            
+            % select nonzero entries from h
+            h = obj.h;
+            h_nonzero = h(nonZeroFreqIdcsH);
+            h_nonzero_DB = 20*log10(h_nonzero);
+                        
+            
+            % correct outliers that are more than 10 dB above or below median
+            medianInDB = median(h_nonzero_DB);
+            
+            idcs = h_nonzero_DB > medianInDB + obj.maxDevFromMedianInDB;
+            h_nonzero(idcs) = 10^( (medianInDB + obj.maxDevFromMedianInDB)/20 );
+            
+            idcs = h_nonzero_DB < medianInDB - obj.maxDevFromMedianInDB;
+            h_nonzero(idcs) = 10^( (medianInDB - obj.maxDevFromMedianInDB)/20 );
+            
+            
+            % setup vector containing frequencies for cosine approx.
+            w = (1:obj.maxNumFreqs)' / obj.maxNumFreqs * pi;
+            
+            % select w at nonzero entries of h
+            w_nonzero = w(nonZeroFreqIdcsH);
+            
+            % copy first and last nonzero entry to boundaries
+            if nonZeroFreqIdcsH(1) ~= 1
+                h_nonzero = [h_nonzero(1); h_nonzero]; 
+                w_nonzero = [w(1); w_nonzero];
+            end
+            if nonZeroFreqIdcsH(end) ~= obj.maxNumFreqs
+                h_nonzero = [h_nonzero; h_nonzero(end)];
+                w_nonzero = [w_nonzero; w(end)];
+            end
+
+            
+            % apply cosine approximation to h (discrete cepstrum)
+            coeffs = dceps(h_nonzero, w_nonzero, obj.numCepstralCoeffs, obj.regularisationParam);
+            h = idceps(coeffs, w);
+            
+        end
+        
+        function obj = setH(obj, h)
+        % setH - set member variable h 
+        %
+        %    myObj = myObj.setH(myH) sets the member variable h to myH.
+        %    myH must be a non-negative column vector of length [number of 
+        %    frequencies].
+
+            obj.h = h;
+            obj = computeW_data_est(obj);
+        end
+        
+        function betaDiv = compBetaDivergence(obj)
+        % compBetaDivergence - computes beta divergence based on the
+        %    current estiates
+        %
+        %    betaDiv = myObj.compBetaDivergence() returns the
+        %    beta-divergence between the instrument spectra from the
+        %    recording and the adapted database spectra based on the value
+        %    for beta specified by the cost function in the constructor.
+            
+            W_data = obj.W_data(~obj.zeroIdcsW_data_est);
+            W_data_est = obj.W_data_est(~obj.zeroIdcsW_data_est);
+            
+            switch obj.beta
+                case 0
+                    betaDivMat = W_data ./ W_data_est - log(W_data ./ W_data_est) - 1;
+                    
+                case 1
+                    betaDivMat = W_data .* log(W_data ./ W_data_est) + W_data - W_data_est;
+                    
+                otherwise
+                    betaDivMat = (W_data .^ obj.beta) / (obj.beta * (obj.beta-1)) ...
+                                 + (W_data_est .^ obj.beta) / obj.beta ...
+                                 - (W_data .* (W_data_est .^ (obj.beta-1))) / (obj.beta-1);
+            end
+            
+            betaDiv = sum(betaDivMat(:));
+        end
+
+   end % methods
+   
+
+   methods (Access = private)
+       
+       function obj = computeW_data_est(obj)
+       % computes basis functions from the current estimates for s, e and h
+           
+           if ~isempty(obj.h)
+               obj.W_data_est = obj.W_DB(:,obj.commonF0IdcsIdcs_DB) .* repmat(obj.h, 1, obj.numCommonShifts);
+           end
+       end
+       
+   end % methods (Access = private)
+
+
+   methods (Access = private, Static)
+       
+        function partialSpectra = noteSpec2partialSpec(noteSpectra, f0Idcs, numBinsPerSemitone)
+        % goes through all note spectra, extracts the partial amplitudes
+        % and writes them to their absolute frequency positions
+
+            % initialize matrix for result
+            [numFreqs numPitches] = size(noteSpectra);
+            partialSpectra = zeros(numFreqs, numPitches);
+            
+            % get (ideal) relative partial positions
+            maxNumPartials = floor(freqIdx2PartialIdx(numFreqs, numBinsPerSemitone));
+            relF0IdcsOfPartials = partialIdx2FreqIdx((1:maxNumPartials)', numBinsPerSemitone);
+            meansF0Idcs = geomean( [relF0IdcsOfPartials(1:end-1)'; relF0IdcsOfPartials(2:end)'] )';
+            relLowerBoundOfPartials = [1; floor(meansF0Idcs)+1];
+            relUpperBoundOfPartials = [floor(meansF0Idcs); numFreqs];
+                    
+            % go through all spectra
+            for pitchIdx = 1:numPitches
+                
+                currF0Idx = f0Idcs(pitchIdx);
+                currNumPartials = floor(freqIdx2PartialIdx(numFreqs-currF0Idx+1, numBinsPerSemitone));
+                
+                % go through partials
+                for partialIdx = 1:currNumPartials
+                    
+                    % find maximum with partial range
+                    lowerBound = currF0Idx-1 + relLowerBoundOfPartials(partialIdx);
+                    upperBound = min(currF0Idx-1 + relUpperBoundOfPartials(partialIdx), numFreqs);
+                    [maxAmpl maxIdx] = max(noteSpectra(lowerBound:upperBound, pitchIdx));
+                    
+                    % write to result matrix
+                    partialSpectra(lowerBound-1+maxIdx, pitchIdx) = maxAmpl;
+                    
+                end                
+            end
+        end % noteSpec2partialSpec
+        
+        
+        function W_DB = adjustPartialPositions(W_DB, W_data, f0Idcs, numBinsPerSemitone)
+        % adjust the positions of the partials in W_DB to those in W_data
+        
+            assert(size(W_DB,2)   == size(W_data,2), '''W_DB'' and ''W_data'' must contain the same number of columns');
+            assert(length(f0Idcs) == size(W_DB,2), 'Number of elements in ''f0Idcs'' must be equal to number of columns in ''W_DB''');
+
+            [numFreqs numPitches] = size(W_DB);
+            
+            % get (ideal) relative partial positions
+            maxNumPartials = floor(freqIdx2PartialIdx(numFreqs, numBinsPerSemitone));
+            relF0IdcsOfPartials = partialIdx2FreqIdx((1:maxNumPartials)', numBinsPerSemitone);
+            meansF0Idcs = geomean( [relF0IdcsOfPartials(1:end-1)'; relF0IdcsOfPartials(2:end)'] )';
+            relLowerBoundOfPartials = [-ceil(numBinsPerSemitone/2); floor(meansF0Idcs)+1]; % make 1st bound -numBinsPerSemitone/2 to allow 1st partial to deviate below ideal position
+            relUpperBoundOfPartials = [floor(meansF0Idcs); numFreqs];
+                    
+            % go through all spectra
+            for pitchIdx = 1:numPitches
+                
+                currF0Idx = f0Idcs(pitchIdx);
+                currNumPartials = compNumPartials(numFreqs-currF0Idx+1, numBinsPerSemitone);
+                
+                % go through partials
+                for partialIdx = 1:currNumPartials
+                    
+                    % compute bounds for partial range
+                    lowerBound = max(currF0Idx-1 + relLowerBoundOfPartials(partialIdx), 1);
+                    upperBound = min(currF0Idx-1 + relUpperBoundOfPartials(partialIdx), numFreqs);
+                    
+                    % get indices of partial in both W_DB and W_data
+                    idx_DB   = find(W_DB(lowerBound:upperBound, pitchIdx));
+                    idx_data = find(W_data(lowerBound:upperBound, pitchIdx));
+                    
+                    % set partial amplitude in W_DB to frequency index of W_data
+                    partAmpl = W_DB(lowerBound-1+idx_DB, pitchIdx);
+                    W_DB(lowerBound-1+idx_DB,   pitchIdx) = 0;
+                    W_DB(lowerBound-1+idx_data, pitchIdx) = partAmpl;
+                end                
+            end        
+        end % adjustPartialPositions
+   end
+   
+%    methods (Static)
+%        
+%        function Xshift = shiftSpectra(X, shiftVals)
+%        % shifts each spectrum (column in X) down by the amount specified in
+%        % shiftVals
+%            
+%            assert(size(X,2) == length(shiftVals), 'number values in ''shiftVals'' must be the same as number of columns in ''X''');
+%        
+%            [maxNumFreqs numShifts] = size(X);
+%            Xshift = zeros(maxNumFreqs, numShifts);
+%            
+%            for pitchIdx = 1:numShifts
+%                phi = shiftVals(pitchIdx);
+%                Xshift(1:maxNumFreqs-phi, pitchIdx) = X(phi+1:maxNumFreqs, pitchIdx);
+%            end
+%        end
+%    end % methods (Static)
+
+end
\ No newline at end of file