% Author: Dr. Elio Quinton

function [ W, H, deleted, finalCost ] = SA_B_NMF(V, W, lambda, varargin )
%SA_B_NMF Sparse NMF activations for a fixed template matrix W.
%   [W, H, deleted, finalCost] = SA_B_NMF(V, W) runs 500 iterations with
%   lambda = 5.
%   [W, H, deleted, finalCost] = SA_B_NMF(V, W, lambda) runs 500
%   iterations with the given sparsity weight.
%   [W, H, deleted, finalCost] = SA_B_NMF(V, W, lambda, nmf_params) takes
%   the iteration count from nmf_params.Iterations and the sparsity weight
%   from nmf_params.Lambda (which overrides the positional lambda).
%
%   Inputs:
%     V      - target magnitude spectrogram; must have size(W,1) rows
%              (assumed nonnegative).
%     W      - template matrix, one template per column. Templates are not
%              learned; kept columns are only rescaled to unit L2 norm.
%     lambda - weight of the sum(sqrt(H)) sparsity penalty.
%
%   Outputs:
%     W         - input W with pruned columns removed and the remaining
%                 columns normalised to unit L2 norm.
%     H         - activations for the kept columns of W (one column per
%                 column of V).
%     deleted   - row vector of ORIGINAL column indices of W that were
%                 pruned because their activation row summed to zero.
%     finalCost - cost per iteration; recomputed on iteration 1 and every
%                 5th iteration, held constant in between.
%
%   H is updated with multiplicative updates for the beta-divergence with
%   beta = 1/2 plus a penalty 2*lambda*sum(sqrt(H)).

% Defaults first; an options struct in the 4th position overrides them.
iterations = 500;
if nargin < 3 || isempty(lambda)
    lambda = 5;
end
if nargin > 3
    nmf_params = varargin{1};
    iterations = nmf_params.Iterations;
    lambda = nmf_params.Lambda;
end

waitbarHandle = waitbar(0, 'Starting sparse NMF synthesis...');
% Close the progress bar on normal exit AND if an error interrupts the loop.
waitbarCleanup = onCleanup(@() close_waitbar(waitbarHandle)); %#ok<NASGU>

K = size(W, 2);
M = size(V, 2);

% Flat initialisation of the activations.
H = 0.01 * ones(K, M);

deleted = [];
remaining = 1:K;   % original indices of the columns still present in W

myeps = 10e-3;     % (= 0.01) a bigger offset here helps prune unused components
V = V + myeps;     % V is the target audio magnitude spectrogram

max_iter = iterations;
finalCost = zeros(1, max_iter);

% Early stopping. The convergence test has always been inactive (the loop
% runs exactly max_iter iterations); keep it that way unless enabled here.
enable_early_stop = false;
min_iter_before_stop = 50;   % earliest iteration at which early stopping may trigger
err = 0.0001;                % relative cost decrease treated as "converged"
cost = get_cost(V, W, H, lambda);   % computed on the same (offset) V as later costs

iter = 0;
converged = false;
waitbarDenom = max(max_iter - 1, 1);   % avoid 0/0 when max_iter == 1

while ~converged && iter < max_iter

    waitbar(iter / waitbarDenom, waitbarHandle, ['Computing approximation...Iteration: ', num2str(iter), '/', num2str(max_iter-1)])
    iter = iter + 1;

    %% Sparse multiplicative update of H (W is held fixed; every iteration
    %% uses the sparse update, there is no separate non-sparse warm-up).
    R = W*H + eps;
    invR = 1 ./ R;
    invR2 = invR.^2;
    sqrtR = sqrt(R);

    % Gradient of the 2*lambda*sum(sqrt(H)) penalty.
    pen = lambda ./ (sqrt(H + eps));

    % NOTE(review): exponent 1/3 is kept as-is; the majorisation-minimisation
    % exponent usually quoted for beta = 1/2 is 2/3 - confirm intent.
    H = H .* (((W' * (V .* invR2 .* sqrtR)) ./ ((W' * (invR .* sqrtR) + pen))).^(1/3));

    %% Prune templates whose activations vanished; normalise the others,
    %% moving each column's scale into H so that W*H is unchanged.
    todel = zeros(1, 0);
    for k = 1:size(W, 2)
        if sum(H(k, :)) == 0
            todel(end+1) = k; %#ok<AGROW>
        else
            nrm = norm(W(:, k));
            if nrm > 0   % an all-zero template would otherwise give 0/0
                W(:, k) = W(:, k) / nrm;
                H(k, :) = H(k, :) * nrm;
            end
        end
    end

    if ~isempty(todel)
        % Record ORIGINAL indices, then drop them from the index map.
        deleted = [deleted, remaining(todel)]; %#ok<AGROW>
        remaining(todel) = [];
        W(:, todel) = [];
        H(todel, :) = [];
    end

    %% Monitor the cost (every 5 iterations and on the first one).
    if mod(iter, 5) == 0 || iter == 1
        new_cost = get_cost(V, W, H, lambda);

        if enable_early_stop && iter > min_iter_before_stop && (cost - new_cost < cost * err)
            converged = true;
        end

        cost = new_cost;
        finalCost(iter) = cost;
    else
        finalCost(iter) = finalCost(iter - 1);
    end

end

% Trim in case the loop stopped early (or never ran).
finalCost = finalCost(1:iter);
end



function cost = get_cost(V, W, H, lambda)
%GET_COST Objective monitored by SA_B_NMF.
%   cost = D(V | W*H) + 2*lambda*sum(sqrt(H(:))), where D is the
%   beta-divergence with beta = 1/2:
%       D(V | R) = sum(sum( 2*(sqrt(V) - sqrt(R)).^2 ./ sqrt(R) )).
%   The derivative of the penalty term, lambda./sqrt(H), is the `pen`
%   term used in the H update.
R = W*H + eps;
hcost = 2 * sum(sum((sqrt(V) - sqrt(R)).^2 ./ sqrt(R)));
nn = 2 * lambda * sum(sum(sqrt(H)));
cost = hcost + nn;

end



function close_waitbar(h)
%CLOSE_WAITBAR Close the progress bar if the user has not already closed it.
if ishandle(h)
    close(h);
end
end