% Author: Dr. Elio Quinton

function [ W, H, deleted, finalCost ] = SA_B_NMF(V, W, lambda, varargin )
%SA_B_NMF Sparse activations for a fixed set of NMF templates.
%   [W, H, deleted, finalCost] = SA_B_NMF(V, W, lambda) estimates a
%   non-negative activation matrix H such that W*H approximates the target
%   V (a magnitude spectrogram), using multiplicative updates with a
%   sqrt-sparsity penalty weighted by lambda. W is NOT learned: its columns
%   are only rescaled to unit 2-norm (H is rescaled to compensate), and
%   templates whose activation row sums to exactly zero are pruned.
%
%   [...] = SA_B_NMF(V, W) uses lambda = 5.
%
%   [...] = SA_B_NMF(V, W, lambda, nmf_params) reads optional fields of the
%   struct nmf_params:
%       Iterations        - number of iterations (default 500)
%       Lambda            - sparsity weight; overrides the positional lambda
%                           when present
%       StopOnConvergence - logical, default false. When true, iterating
%                           stops once the relative cost decrease between
%                           two cost evaluations is below 1e-4 (checked only
%                           after iteration 50).
%
%   Outputs:
%       W         - input templates with pruned columns removed and the
%                   remaining columns normalised to unit 2-norm
%       H         - activations, one row per remaining template
%       deleted   - indices into the columns of the INPUT W of the pruned
%                   templates, in the order they were pruned
%       finalCost - 1-by-(iterations run) vector; get_cost is evaluated at
%                   iteration 1 and every 5th iteration, and the last value
%                   is carried forward in between
%
%   A waitbar window is shown while iterating and closed on exit, including
%   exit by error.

    % ---- Options ---------------------------------------------------------
    iterations = 500;
    stopOnConvergence = false;   % default: always run all iterations
    if nargin < 3 || isempty(lambda)
        lambda = 5;
    end
    if nargin > 3 && ~isempty(varargin{1})
        nmf_params = varargin{1};
        if isfield(nmf_params, 'Iterations')
            iterations = nmf_params.Iterations;
        end
        if isfield(nmf_params, 'Lambda')
            lambda = nmf_params.Lambda;
        end
        if isfield(nmf_params, 'StopOnConvergence')
            stopOnConvergence = logical(nmf_params.StopOnConvergence);
        end
    end

    waitbarHandle = waitbar(0, 'Starting sparse NMF synthesis...');
    % Close the progress window however this function exits.
    waitbarCleanup = onCleanup(@() close_waitbar_if_open(waitbarHandle)); %#ok<NASGU>

    % ---- Initialisation --------------------------------------------------
    K = size(W, 2);
    M = size(V, 2);
    H = 0.01 * ones(K, M);   % flat, deterministic start for the activations
    deleted = [];
    alive = 1:K;             % input-W column index of each remaining template

    myeps = 10e-3;           % (= 0.01) a bigger floor on V helps prune unused components
    V = V + myeps;           % V is the target audio magnitude spectrogram

    relTol = 0.0001;               % relative cost decrease regarded as converged
    minItersBeforeStop = 50;       % never declare convergence before this iteration
    cost = Inf;                    % previous evaluated cost (for the convergence test)
    finalCost = zeros(1, iterations);
    progressDen = max(iterations - 1, 1);   % avoid 0/0 when iterations == 1

    iter = 0;
    converged = false;
    while ~converged && iter < iterations

        waitbar(iter/progressDen, waitbarHandle, ['Computing approximation...Iteration: ', num2str(iter), '/', num2str(iterations-1)])
        iter = iter + 1;

        % ---- Multiplicative update of H (W is kept fixed) ----------------
        R = W*H + eps;
        Rinv = 1 ./ R;
        Rsqrt = sqrt(R);
        pen = lambda ./ sqrt(H + eps);   % gradient of 2*lambda*sum(sqrt(H)), floored by eps
        H = H .* (((W' * (V .* Rinv.^2 .* Rsqrt)) ./ (W' * (Rinv .* Rsqrt) + pen)).^(1/3));

        % ---- Normalise templates; prune those with all-zero activations ---
        todel = [];
        for k = 1:size(W, 2)
            if sum(H(k,:)) == 0
                todel = [todel k]; %#ok<AGROW>
            else
                nrm = norm(W(:,k));
                if nrm > 0   % an all-zero template cannot be normalised
                    W(:,k) = W(:,k) / nrm;
                    H(k,:) = H(k,:) * nrm;
                end
            end
        end
        if ~isempty(todel)
            % Map current column indices back to input-W indices before removal.
            deleted = [deleted alive(todel)]; %#ok<AGROW>
            alive(todel) = [];
            W(:,todel) = [];
            H(todel,:) = [];
        end

        % ---- Cost tracking and optional convergence test -----------------
        if iter == 1 || mod(iter, 5) == 0
            new_cost = get_cost(V, W, H, lambda);
            if stopOnConvergence && iter > minItersBeforeStop && (cost - new_cost) < cost * relTol
                converged = true;
            end
            cost = new_cost;
            finalCost(iter) = cost;
        else
            finalCost(iter) = finalCost(iter - 1);
        end
    end

    finalCost = finalCost(1:iter);   % shorter than 'iterations' only after an early stop
end

function close_waitbar_if_open(h)
% Close the waitbar figure unless it has already been closed.
if ishandle(h)
    close(h);
end
end


function cost = get_cost(V, W, H, lambda)
% GET_COST Beta-divergence (beta = 1/2) fit plus sqrt-sparsity penalty.
%   cost = get_cost(V, W, H, lambda) returns
%       D(V | R) + 2*lambda*sum(sqrt(H(:)))
%   with R = W*H + eps and
%       D(V | R) = sum( 2 * (sqrt(V) - sqrt(R)).^2 ./ sqrt(R) ),
%   i.e. the beta-divergence for beta = 1/2. Its gradient with respect to H,
%   W'*(R.^-0.5) - W'*(V.*R.^-1.5) + lambda./sqrt(H), supplies the numerator and
%   denominator of the H update in SA_B_NMF.
%   V, W and H are expected to be non-negative (sqrt of a negative entry
%   would be complex).
R = W*H + eps;   % eps keeps sqrt(R) strictly positive in the denominator
% Square (sqrt(V) - sqrt(R)), not sqrt(R) alone: sqrt(R).^2 == R.
fit = 2 * sum(sum((sqrt(V) - sqrt(R)).^2 ./ sqrt(R)));
sparsity = 2 * lambda * sum(sum(sqrt(H)));
cost = fit + sparsity;

end