function [Y, cost] = nmf_divergence(V, W, varargin)
%NMF_DIVERGENCE Estimate activations Y such that W*Y approximates V.
%   W is held fixed; only the activations are updated, using the
%   multiplicative rule  H <- H .* (W' * (V ./ (W*H))) ./ sum(W)'.
%   V is F-by-M and W is F-by-K (they must share their row count).
%
%   [Y, COST] = NMF_DIVERGENCE(V, W) runs up to 9 iterations (L = 10)
%   without any activation modifications and seeds the RNG with
%   rng('shuffle').
%
%   [Y, COST] = NMF_DIVERGENCE(V, W, NMF_PARAMS) reads its settings from
%   the struct NMF_PARAMS:
%     Iterations                 - L; at most L-1 iterations are run
%     Convergence_criteria       - stop once |cost(l)-cost(l-1)|/max(cost)
%                                  is <= this value (checked for l > 3)
%     Repition_restriction       - r > 0: entries that are not the maximum
%                                  of their row within +/- r columns are
%                                  attenuated
%     Polyphony_restriction      - p > 0: all but the p largest entries of
%                                  each column are attenuated
%     Continuity_enhancement     - c > 0: activations are convolved with a
%                                  c-by-c kernel chosen by Diagonal_pattern
%     Continuity_enhancement_rot - passed to rot_kernel ('Diagonal' pattern)
%     Diagonal_pattern           - 'Diagonal' | 'Reverse' | 'Blur' | 'Vertical'
%     Modification_application   - true: apply the modifications above only
%                                  on the final iteration (and skip the
%                                  second update inside each iteration)
%     Random_seed                - passed to rng before H is initialised
%
%   Attenuated entries are multiplied by 1-(l+1)/L at iteration l, so the
%   attenuation strengthens each iteration and reaches 0 on iteration L-1.
%
%   Y is the K-by-M activation matrix scaled so its largest entry is 1.
%   COST(l) is KLDivCost(V, W*H) after iteration l (COST is 0 if no
%   iteration runs).

narginchk(2, Inf);

if nargin > 2
    nmf_params  = varargin{1};
    L           = nmf_params.Iterations;
    convergence = nmf_params.Convergence_criteria;
    r           = nmf_params.Repition_restriction;
    p           = nmf_params.Polyphony_restriction;
    c           = nmf_params.Continuity_enhancement;
    rot         = nmf_params.Continuity_enhancement_rot;
    pattern     = nmf_params.Diagonal_pattern;
    endtime     = nmf_params.Modification_application;
    rng(nmf_params.Random_seed);
else
    L           = 10;
    convergence = 0;
    r           = -1;
    p           = -1;
    c           = -1;
    rot         = 0;    % unused while c <= 0; defined so every setting exists
    pattern     = 'Diagonal';
    endtime     = false;
    rng('shuffle');
end

waitbarHandle = waitbar(0, 'Starting NMF synthesis...');
% Close the waitbar however this function exits, including on error.
waitbarCleanup = onCleanup(@() closeIfValid(waitbarHandle)); %#ok<NASGU>

cost = 0;
K = size(W, 2);
M = size(V, 2);

% Same uniform [0,1) draws as random('unif', 0, 1, K, M), without
% requiring the Statistics Toolbox.
H = rand(K, M);

% Offset so no entry is exactly zero (avoids division by zero in V ./ (W*H)).
V = V + 1E-6;
W = W + 1E-6;
den = sum(W);   % 1-by-K column sums of W: the update's normaliser

numIter = 0;
for l = 1:L-1
    numIter = l;
    waitbar(l/(L-1), waitbarHandle, ['Computing approximation...Iteration: ', num2str(l), '/', num2str(L-1)])

    H = klActivationUpdate(V, W, H, den);

    % With Modification_application set, the modifications run only on
    % the final iteration.
    modifyNow = ~endtime || l == L-1;
    decay = 1 - (l+1)/L;

    if r > 0 && modifyNow
        waitbar(l/(L-1), waitbarHandle, ['Repition Restriction...Iteration: ', num2str(l), '/', num2str(L-1)])
        H = restrictRepetition(H, r, decay);
    end

    if p > 0 && modifyNow
        waitbar(l/(L-1), waitbarHandle, ['Polyphony Restriction...Iteration: ', num2str(l), '/', num2str(L-1)])
        H = restrictPolyphony(H, p, decay);
    end

    if c > 0 && modifyNow
        waitbar(l/(L-1), waitbarHandle, ['Continuity Enhancement...Iteration: ', num2str(l), '/', num2str(L-1)])
        H = enhanceContinuity(H, c, rot, pattern);
    end

    if ~endtime
        H = klActivationUpdate(V, W, H, den);
    end

    cost(l) = KLDivCost(V, W*H);
    if l > 3 && abs((cost(l) - cost(l-1)) / max(cost)) <= convergence
        break;
    end
end

fprintf('Iterations: %i/%i\n', numIter, L);
fprintf('Convergence Criteria: %i\n', convergence*100);
fprintf('Repitition: %i\n', r);
fprintf('Polyphony: %i\n', p);
fprintf('Continuity: %i\n', c);

% Normalise activations so the largest entry is 1; skip when the maximum
% is 0, where the division would turn every entry into NaN.
Y = H;
peak = max(Y(:));
if peak ~= 0
    Y = Y ./ peak;
end
end

function H = klActivationUpdate(V, W, H, den)
% One multiplicative update H <- H .* (W' * (V ./ (W*H))) ./ den', applied
% to all columns at once (den is the 1-by-K vector sum(W)).
H = H .* bsxfun(@rdivide, W' * (V ./ (W*H)), den(:));
end

function R = restrictRepetition(H, r, decay)
% Keep an entry unchanged only if it equals the maximum of its row over
% columns m-r..m+r; every other entry (including all entries in the first
% and last r columns) is multiplied by decay.
M = size(H, 2);
R = H * decay;
for m = 1:M
    if m > r && m + r <= M
        isPeak = H(:, m) == max(H(:, m-r:m+r), [], 2);
        R(isPeak, m) = H(isPeak, m);
    end
end
end

function P = restrictPolyphony(H, p, decay)
% Keep the p largest entries of each column unchanged and multiply the rest
% by decay. p is capped at the number of rows so a large p keeps every entry.
keepCount = min(p, size(H, 1));
P = H * decay;
for m = 1:size(H, 2)
    [~, order] = sort(H(:, m), 'descend');
    top = order(1:keepCount);
    P(top, m) = H(top, m);
end
end

function C = enhanceContinuity(H, c, rot, pattern)
% Convolve the activations with a c-by-c kernel selected by pattern.
switch pattern
    case 'Diagonal'
        kernel = rot_kernel(eye(c), rot);   % default
    case 'Reverse'
        kernel = flip(eye(c));
    case 'Blur'
        kernel = ones(c);
    case 'Vertical'
        % A single column of ones. Column floor(c/2)+1 is the one that
        % conv2(..., 'same') aligns with each output column, so the result
        % is not shifted along the second dimension. It is also valid for c = 1.
        kernel = zeros(c, c);
        kernel(:, floor(c/2) + 1) = 1;
    otherwise
        error('nmf_divergence:unknownPattern', ...
              'Unknown Diagonal_pattern ''%s''.', pattern);
end
C = conv2(H, kernel, 'same');
end

function closeIfValid(h)
% Close the waitbar figure if it still exists.
if ishandle(h)
    close(h);
end
end