diff src/matlab/SA_B_NMF.m @ 0:c52bc3e8d3ad tip

user: boblsturm branch 'default' added README.md added assets/.DS_Store added assets/playButton.jpg added assets/stopButton.png added assets/swapButton.jpg added data/.DS_Store added data/fiveoctaves.mp3 added data/glock2.wav added data/sinScale.mp3 added data/speech_female.mp3 added data/sweep.wav added nimfks.m.lnk added src/.DS_Store added src/matlab/.DS_Store added src/matlab/AnalysisCache.m added src/matlab/CSS.m added src/matlab/DataHash.m added src/matlab/ExistsInCache.m added src/matlab/KLDivCost.m added src/matlab/LoadFromCache.m added src/matlab/SA_B_NMF.m added src/matlab/SaveInCache.m added src/matlab/Sound.m added src/matlab/SynthesisCache.m added src/matlab/chromagram_E.m added src/matlab/chromagram_IF.m added src/matlab/chromagram_P.m added src/matlab/chromsynth.m added src/matlab/computeSTFTFeat.m added src/matlab/controller.m added src/matlab/decibelSliderReleaseCallback.m added src/matlab/drawClickCallBack.m added src/matlab/fft2chromamx.m added src/matlab/hz2octs.m added src/matlab/ifgram.m added src/matlab/ifptrack.m added src/matlab/istft.m added src/matlab/nimfks.fig added src/matlab/nimfks.m added src/matlab/nmfFn.m added src/matlab/nmf_beta.m added src/matlab/nmf_divergence.m added src/matlab/nmf_euclidean.m added src/matlab/prune_corpus.m added src/matlab/rot_kernel.m added src/matlab/templateAdditionResynth.m added src/matlab/templateDelCb.m added src/matlab/templateScrollCb.m
author boblsturm
date Sun, 18 Jun 2017 06:26:13 -0400
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/matlab/SA_B_NMF.m	Sun Jun 18 06:26:13 2017 -0400
@@ -0,0 +1,154 @@
+% Author: Dr. Elio Quinton
+
+function [ W, H, deleted, finalCost ] = SA_B_NMF(V, W, lambda, varargin )
+%SENMF Summary of this function goes here
+%   Detailed explanation goes here
+if nargin > 2
+    nmf_params = varargin{1};
+    iterations = nmf_params.Iterations;
+    lambda = nmf_params.Lambda;
+elseif nargin == 2
+    iterations = 500;
+    lambda = 5;
+end
+
+waitbarHandle = waitbar(0, 'Starting sparse NMF synthesis...');
+
+targetDim=size(V);
+sourceDim=size(W);
+K=sourceDim(2);
+M=targetDim(2);
+H=random('unif',0, 1, K, M);
+
+deleted = [];
+H = 0.01 * ones(size(H));% + (0.01 * rand(size(H)));
+cost = get_cost(V, W, H, lambda); % Function get_cost defined at the end of this file.
+
+converged = 0;
+convergence_threshold = 50; % Note that this is a threshold on the derivative of the cost, i.e. how much it decays at each iteration. This value might not be ideal for our use case.
+
+myeps = 10e-3; % A bigger eps here helps pruning out unused components
+V = V + myeps; % Note that V here is the `target' audio magnitude spectrogram
+
+% num_h_iter = 1;
+err = 0.0001;
+% max_iter = 1000;
+max_iter = iterations;
+
+
+
+iter = 0;
+omit = 0;
+
+while ~converged && iter < max_iter
+    
+    waitbar(iter/(iterations-1), waitbarHandle, ['Computing approximation...Iteration: ', num2str(iter), '/', num2str(iterations-1)])
+    iter = iter + 1;
+    %% sparse NMF decomposition
+    if iter > 0
+        
+        % Update H.
+        R = W*H+eps;
+        RR = 1./R;
+        RRR = RR.^2;
+        RRRR = sqrt(R);
+        
+        pen = lambda./(sqrt(H + eps)); 
+        
+        H = H .* (((W' * (V .* RRR.* RRRR)) ./ ((W' * (RR .* RRRR) + pen))).^(1/3));
+        
+        
+        %Update W: REMOVE THIS FOR OUR USE CASE
+%         nn = sum(sqrt(H'));
+%         NN = lambda * repmat(nn,size(V,1),1);
+%         NNW = NN.*W;
+% 
+%         R = W*H+eps;
+%         RR = 1./R;
+%         RRR = RR.^2;
+%         RRRR = sqrt(R);
+%         W = W .* ( ((V .* RRR.* RRRR)*H') ./ ( ((RR .* RRRR)*H') + NNW + eps)).^(1/3);
+        % Update W: stop deleting here
+
+        
+    else
+        % Non-sparse first iteration. Not sure it is necessary
+        % in our particular use case. We might want to get rid of
+        % it later, but it should not harm anyway.
+        R = W*H;
+        RR = 1./R;
+        RRR = RR.^2;
+        RRRR = sqrt(R);
+        H = H .* (((W' * (V .* RRR.* RRRR)) ./ (W' * (RR .* RRRR) + eps)).^(1/3));
+        
+        %Update W: REMOVE THIS FOR OUR USE CASE
+%         R = W*H + myeps;
+%         RR = 1./R;
+%         RRR = RR.^2;
+%         RRRR = sqrt(R);
+%         W = W .* ((((V .* RRR.* RRRR)*H') ./ (((RR .* RRRR)*H') + eps)).^(1/3));
+        % Update W: stop deleting here
+    end
+    
+    %% normalise and prune templates if their total activation reaches 0.
+    todel = [];
+    shifts = [];
+    for i = 1:size(W,2)
+%        nn =  norm(W(:,i)); % W is not being updated so this is of no use
+       nn =  sum(H(i,:)); % Check if norm of rows of H get to zero instead. 
+       if nn == 0
+           todel = [todel i];
+%            disp(['Deleting  ' int2str(length(todel))]); % This is printing a message everytime templates are deleted
+       else
+            nn =  norm(W(:,i)); % Still normalise against norm of Templates to avoid division by zero.
+            W(:,i) = W(:,i) / nn; 
+            H(i,:) = H(i,:) * nn;
+       end
+    end
+
+    if( length(deleted) == 0 )
+        deleted = [deleted todel];
+    else
+        shifts = zeros(1, length(todel));
+        for i = 1:length(shifts)
+            shifts(i) = length( deleted( deleted >= todel(i) ) );
+        end
+        deleted = [deleted todel+shifts];
+    end
+    W(:,todel) = [];
+    H(todel,:) = [];
+    
+    %% get the cost and monitor convergence
+    if (mod(iter, 5) == 0) || (iter == 1)
+
+        new_cost = get_cost(V, W, H, lambda);
+    
+        if omit == 0 && cost - new_cost < cost * err & iter > convergence_threshold
+           converged = 0; 
+      %     
+        end
+        
+        cost = new_cost;
+        finalCost(iter) = cost;
+        omit = 0;
+        
+%         disp([int2str(iter)  '    ' num2str(cost)]); % this prints the cost function at each iteration. Could be commented out (printing is slow in matlab)
+    elseif iter > 1
+        finalCost(iter) = finalCost(iter - 1);
+    end
+    
+end
+
+close(waitbarHandle);
+end
+
+
+
+function cost = get_cost(V, W, H, lambda)
+R = W*H+eps;
+hcost = sum(sum( (sqrt(V) - sqrt(R).^2 )./sqrt(R) ));
+nn = 2 * lambda * sum(sum(sqrt(H)));
+cost = hcost + nn;
+
+end
+