boblsturm@0
|
function [Y, cost] = nmf_divergence(V, W, varargin)
%NMF_DIVERGENCE Fit activations H so that W*H approximates V, with W fixed.
%   [Y, cost] = nmf_divergence(V, W) runs with the built-in defaults
%   (10 iterations, no activation modifications, rng('shuffle')).
%   [Y, cost] = nmf_divergence(V, W, nmf_params) reads the settings from
%   the struct nmf_params (fields listed below).
%
%   Each iteration applies the multiplicative generalized Kullback-Leibler
%   update
%       H <- H .* (W' * (V ./ (W*H))) ./ sum(W)'
%   then the enabled modifications (repetition restriction, polyphony
%   restriction, continuity enhancement), then, unless the modifications
%   are deferred to the last iteration, a second update.
%
%   Inputs
%     V           target matrix, F-by-M (1e-6 is added internally)
%     W           fixed template matrix, F-by-K (1e-6 is added internally)
%     nmf_params  struct with fields
%       Iterations                 L; the loop runs for l = 1:L-1
%       Convergence_criteria       stop once |cost(l)-cost(l-1)|/max(cost)
%                                  is <= this value (checked for l > 3)
%       Repition_restriction       r; half-width of the max window used by
%                                  the repetition restriction (<= 0 disables)
%       Polyphony_restriction      p; entries kept per column (<= 0 disables)
%       Continuity_enhancement     c; continuity kernel size (<= 0 disables)
%       Continuity_enhancement_rot passed to rot_kernel for 'Diagonal'
%       Diagonal_pattern           'Diagonal' | 'Reverse' | 'Blur' | 'Vertical'
%       Modification_application   if true, modifications are applied only
%                                  in the last iteration and the second
%                                  update is skipped in every iteration
%       Random_seed                passed to rng before H is initialized
%
%   Outputs
%     Y     final H (K-by-M) divided by its largest entry
%     cost  KLDivCost(V, W*H) recorded after each completed iteration
%
%   Modifications decay over the run: entries that are not kept are scaled
%   by 1 - (l+1)/L, which reaches 0 in the last iteration.

narginchk(2, Inf);

if nargin > 2
    nmf_params  = varargin{1};
    L           = nmf_params.Iterations;
    convergence = nmf_params.Convergence_criteria;
    r           = nmf_params.Repition_restriction;
    p           = nmf_params.Polyphony_restriction;
    c           = nmf_params.Continuity_enhancement;
    rot         = nmf_params.Continuity_enhancement_rot;
    pattern     = nmf_params.Diagonal_pattern;
    endtime     = nmf_params.Modification_application;
    rng(nmf_params.Random_seed);
else
    L           = 10;
    convergence = 0;
    r           = -1;          % modifications disabled by default
    p           = -1;
    c           = -1;
    rot         = 0;           % only read when c > 0; defined so it always exists
    pattern     = 'Diagonal';
    endtime     = false;
    rng('shuffle');
end

waitbarHandle = waitbar(0, 'Starting NMF synthesis...');
% Close the waitbar however this function exits, including on error.
cleanupWaitbar = onCleanup(@() closeIfValid(waitbarHandle)); %#ok<NASGU>

cost = 0;
K = size(W, 2);
M = size(V, 2);

% Uniform (0,1) initialization; same draw as random('unif', 0, 1, K, M)
% but without requiring the Statistics Toolbox.
H = rand(K, M);

% Small offsets keep W*H strictly positive for the division in the update
% (and presumably keep KLDivCost finite -- confirm against KLDivCost).
V = V + 1E-6;
W = W + 1E-6;
den = sum(W);   % 1-by-K column sums of W: the update's denominator

for l = 1:L-1
    waitbar(l/(L-1), waitbarHandle, ['Computing approximation...Iteration: ', num2str(l), '/', num2str(L-1)])

    H = klUpdate(V, W, H, den);

    % Modifications run every iteration, or only in the last one when
    % Modification_application is set.
    applyNow = ~endtime || l == L-1;
    attenuation = 1 - (l+1)/L;

    if r > 0 && applyNow
        waitbar(l/(L-1), waitbarHandle, ['Repition Restriction...Iteration: ', num2str(l), '/', num2str(L-1)])
        H = repetitionRestriction(H, r, attenuation);
    end

    if p > 0 && applyNow
        waitbar(l/(L-1), waitbarHandle, ['Polyphony Restriction...Iteration: ', num2str(l), '/', num2str(L-1)])
        H = polyphonyRestriction(H, p, attenuation);
    end

    if c > 0 && applyNow
        waitbar(l/(L-1), waitbarHandle, ['Continuity Enhancement...Iteration: ', num2str(l), '/', num2str(L-1)])
        H = continuityEnhancement(H, c, pattern, rot);
    end

    if ~endtime
        H = klUpdate(V, W, H, den);
    end

    cost(l) = KLDivCost(V, W*H);
    if l > 3 && abs((cost(l) - cost(l-1)) / max(cost)) <= convergence
        break;
    end
end

fprintf('Iterations: %i/%i\n', l, L);
fprintf('Convergence Criteria: %g\n', convergence*100);   % %g: value may be fractional
fprintf('Repitition: %i\n', r);
fprintf('Polyphony: %i\n', p);
fprintf('Continuity: %i\n', c);

Y = H;
peak = max(Y(:));
if peak ~= 0
    Y = Y ./ peak;   % Normalize activations (skipped if H is all zeros)
end
end

function H = klUpdate(V, W, H, den)
% One multiplicative KL-divergence update of every column of H for fixed W:
%   H(k,m) <- H(k,m) * sum_f W(f,k) * V(f,m) / (W*H)(f,m) / den(k)
H = H .* bsxfun(@rdivide, W' * (V ./ (W * H)), den(:));
end

function R = repetitionRestriction(H, r, attenuation)
% Keep H(k,m) where it equals max(H(k, m-r:m+r)); scale every other entry
% by ATTENUATION. Columns within r of either edge, where the full window
% does not fit, are always scaled.
[K, M] = size(H);
R = H * attenuation;
for k = 1:K
    for m = r+1:M-r
        if H(k, m) == max(H(k, m-r:m+r))
            R(k, m) = H(k, m);
        end
    end
end
end

function P = polyphonyRestriction(H, p, attenuation)
% In each column keep the p largest entries of H and scale the rest by
% ATTENUATION. A p larger than the number of rows keeps the whole column.
[K, M] = size(H);
keep = min(p, K);
P = H * attenuation;
for m = 1:M
    [~, order] = sort(H(:, m), 'descend');
    top = order(1:keep);
    P(top, m) = H(top, m);
end
end

function C = continuityEnhancement(H, c, pattern, rot)
% Convolve H with a c-by-c kernel selected by PATTERN; output is size(H).
switch pattern
    case 'Diagonal'
        kernel = rot_kernel(eye(c), rot);   % rot_kernel is defined outside this file
    case 'Reverse'
        kernel = flip(eye(c));              % anti-diagonal
    case 'Blur'
        kernel = ones(c);                   % box kernel
    case 'Vertical'
        % One column of ones at the zero-lag tap of conv2(...,'same'),
        % i.e. index floor(c/2)+1 (valid for every c >= 1).
        kernel = zeros(c, c);
        kernel(:, floor(c/2) + 1) = 1;
    otherwise
        error('nmf_divergence:unknownPattern', ...
            'Unknown Diagonal_pattern ''%s''.', char(pattern));
end
C = conv2(H, kernel, 'same');
end

function closeIfValid(h)
% Close the waitbar unless it has already been closed.
if ishandle(h)
    close(h);
end
end