comparison notes/hnmf.m @ 16:d42c500b8ad0

Add non-convolutional version
author Chris Cannam
date Wed, 26 Mar 2014 08:33:45 +0000
parents
children f1f8c84339d0
comparison
equal deleted inserted replaced
15:2b7257e4fc8a 16:d42c500b8ad0
function [w,h,z,xa] = hnmf( x, K, R, iter, sh, sz, w, h, pl, lh, lz)
% function [w,h,z,xa] = hnmf( x, K, R, iter, sh, sz, w, h, pl, lh, lz)
%
% Perform Multi-source NMF
%
% Models the input as
%   xa(:,n) = eps + sum_r sum_k w(:,k,r) * z(r,k,n) * h(k,n)
% and refines h and z with multiplicative EM updates. w is normalised
% once at start-up but is never updated.
%
% Inputs:
%  x     input distribution (M x N; the updates expect non-negative data)
%  K     number of components
%  R     number of sources
%  iter  number of EM iterations [default = 100]
%  sh    sparsity exponent applied to h after each update [default = 1]
%  sz    sparsity exponent applied to z after each update [default = 1]
%  w     initial value of w (M x K x R) [default = random]
%  h     initial value of h (K x N) [default = random]
%  pl    plot flag [default = 1]
%  lh    update h flag [default = 1]
%  lz    update z flag [default = 1]
%  Passing [] for any optional argument selects its default.
%
% Outputs:
%  w     spectral bases (M x K x R), each w(:,k,r) normalised to unit sum
%  h     component activation (K x N), column n scaled to sum(x(:,n))
%  z     source activation per component (R x K x N), normalised over r
%  xa    approximation of input, computed from the returned w, h and z
%
% Emmanouil Benetos 2011

% Get sizes. sum along dim 1 explicitly so a single-row x still yields 1 x N.
[M,N] = size( x);
sumx = sum( x, 1);

% Defaults. nargin/isempty are used rather than exist(), which can
% also match files of the same name on the path.
if nargin < 4 || isempty( iter), iter = 100; end
if nargin < 5 || isempty( sh),   sh = 1;     end
if nargin < 6 || isempty( sz),   sz = 1;     end
if nargin < 9 || isempty( pl),   pl = 1;     end
if nargin < 10 || isempty( lh),  lh = 1;     end
if nargin < 11 || isempty( lz),  lz = 1;     end

% Initialize bases: indexed as w(:,k,r) throughout, so the shape is M x K x R.
if nargin < 7 || isempty( w)
    w = rand( M, K, R);
end
w = bsxfun( @rdivide, w, max( sum( w, 1), eps));

% Initialize component activations, scaled per column to the input energy.
if nargin < 8 || isempty( h)
    h = rand( K, N);
end
h = hnmf_normalize_h( h, sumx);

% z is not an input argument, so it always starts from a random
% distribution over sources.
z = rand( R, K, N);
z = bsxfun( @rdivide, z, max( sum( z, 1), eps));

% Iterate
for it = 1:iter

    % E-step: ratio of the input to the current model
    Q = x ./ hnmf_reconstruct( w, h, z);

    % M-step: EM numerators for h and z, both computed from the old h and z
    nh = zeros( K, N);
    nz = z;
    for k = 1:K
        for r = 1:R
            wQ  = w(:,k,r)' * Q;               % 1 x N
            zrk = reshape( z(r,k,:), 1, N);    % 1 x N
            nh(k,:) = nh(k,:) + zrk .* wQ;
            nz(r,k,:) = reshape( h(k,:) .* wQ .* zrk, [1 1 N]);
        end
    end

    % Assign and normalize; each factor is updated only when its flag is set
    if lh
        h = hnmf_normalize_h( (h .* nh).^sh, sumx);
    end
    if lz
        nz = nz.^sz;
        z = bsxfun( @rdivide, nz, max( sum( nz, 1), eps));
    end

end

% Approximation consistent with the parameters actually returned
% (also defined when iter == 0)
xa = hnmf_reconstruct( w, h, z);

% Show me
if pl
    subplot(3, 1, 1), imagesc(x), axis xy
    subplot(3, 1, 2), imagesc(xa), axis xy
    subplot(3, 1, 3), imagesc(h), axis xy
end


function xa = hnmf_reconstruct( w, h, z)
% Model approximation xa = eps + sum_{r,k} w(:,k,r) * (z(r,k,:) .* h(k,:)).
% The eps floor keeps the E-step ratio x ./ xa finite.
[K,N] = size( h);
R = size( z, 1);
xa = repmat( eps, size( w, 1), N);
for r = 1:R
    for k = 1:K
        xa = xa + w(:,k,r) * (reshape( z(r,k,:), 1, N) .* h(k,:));
    end
end


function h = hnmf_normalize_h( h, sumx)
% Scale each column of h to sum to the matching entry of sumx.
% max(...,eps) keeps an all-zero column at zero instead of producing NaN.
h = bsxfun( @times, sumx, bsxfun( @rdivide, h, max( sum( h, 1), eps)));