Mercurial > hg > silvet
comparison notes/hnmf.m @ 16:d42c500b8ad0
Add non-convolutional version
author | Chris Cannam |
---|---|
date | Wed, 26 Mar 2014 08:33:45 +0000 |
parents | |
children | f1f8c84339d0 |
comparison
equal
deleted
inserted
replaced
15:2b7257e4fc8a | 16:d42c500b8ad0 |
---|---|
function [w,h,z,xa] = hnmf( x, K, R, iter, sh, sz, w, h, pl, lh, lz)
% function [w,h,z,xa] = hnmf( x, K, R, iter, sh, sz, w, h, pl, lh, lz)
%
% Perform Multi-source NMF
%
% The input is approximated as
%   xa(:,n) = eps + sum over k,r of  w(:,k,r) * z(r,k,n) * h(k,n)
% and h, z are re-estimated with multiplicative EM-style updates.
% The bases w are normalised once and are not updated by this function.
%
% Inputs:
%  x     input distribution, M x N
%  K     number of components
%  R     number of sources
%  iter  number of EM iterations [default = 100]
%  sh    sparsity of h: exponent applied to the updated h before it is
%        renormalised [default = 1, i.e. update left unchanged]
%  sz    sparsity of z: exponent applied to the updated z before it is
%        renormalised [default = 1, i.e. update left unchanged]
%  w     initial value of w, indexed as w(:,k,r), i.e. M x K x R
%        [default / empty = random]
%  h     initial value of h, K x N [default / empty = random]
%  pl    plot flag [default = 1]
%  lh    update h flag [default = 1]
%  lz    update z flag [default = 1]
%
% Outputs:
%  w     spectral bases, each w(:,k,r) scaled to sum to 1
%  h     component activation, each column scaled to sum to sum(x(:,n))
%  z     source activation per component, R x K x N, each z(:,k,n)
%        scaled to sum to 1
%  xa    approximation of input. For iter >= 1 this is the reconstruction
%        computed in the E-step of the last iteration, i.e. from the
%        parameters as they were before the final update.
%
% Emmanouil Benetos 2011


% Get sizes
[M,N] = size( x);
sumx = sum(x);

% Defaults. nargin is used instead of exist('name'), because exist without
% the 'var' search type also reports files and functions on the path that
% happen to share the argument's name.
if nargin < 4 || isempty( iter)
    iter = 100;
end
if nargin < 5 || isempty( sh)
    sh = 1;
end
if nargin < 6 || isempty( sz)
    sz = 1;
end
if nargin < 7
    w = [];
end
if nargin < 8
    h = [];
end
if nargin < 9
    pl = 1;
end
if nargin < 10 || isempty( lh)
    lh = 1;
end
if nargin < 11 || isempty( lz)
    lz = 1;
end

% Initialize bases. The layout must be M x K x R, because every access
% below is of the form w(:,k,r).
if isempty( w)
    w = rand( M, K, R);
end
for r=1:R
    for k=1:K
        w(:,k,r) = w(:,k,r) ./ sum(w(:,k,r));
    end
end

% Initialize component activations; each column carries the energy of the
% corresponding input column.
if isempty( h)
    h = rand( K, N);
end
n=1:N;
h(:,n) = repmat(sumx(n),K,1) .* (h(:,n) ./ repmat( sum( h(:,n), 1), K, 1));

% Initialize source activations (z is not an input, so it always starts
% random); normalise over the source axis.
z = rand( R, K, N);
z = z ./ repmat( sum( z, 1), [R 1 1]);

% With no iterations requested the loop below never assigns xa, so build
% the approximation from the initial parameters instead.
if iter < 1
    xa = hnmf_reconstruct( w, h, z, K, R);
end

% Iterate
for it = 1:iter

    % E-step
    xa = hnmf_reconstruct( w, h, z, K, R);
    Q = x ./ xa;

    % M-step (update h,z). Both numerators are accumulated from the
    % current (old) h and z, so either update can be enabled on its own.
    if lh || lz
        nh = zeros(K,N);
        nz = zeros(R,K,N);
        for k=1:K
            for r=1:R
                wq  = w(:,k,r)' * Q;              % 1 x N
                zrk = reshape( z(r,k,:), 1, N);   % 1 x N
                nh(k,:) = nh(k,:) + zrk .* wq;
                nz(r,k,:) = (h(k,:) .* wq) .* zrk;
            end
        end
    end

    % Assign and normalize
    if lh
        nh = (h .* nh).^sh;
        h(:,n) = repmat(sumx(n),K,1) .* (nh(:,n) ./ repmat( sum( nh(:,n), 1), K, 1));
    end
    if lz
        z = nz.^sz;
        z = z ./ repmat( sum( z, 1), [R 1 1]);
    end

end

% Show me
if pl
    subplot(3, 1, 1), imagesc(x), axis xy
    subplot(3, 1, 2), imagesc(xa), axis xy
    subplot(3, 1, 3), imagesc(h), axis xy
end


function xa = hnmf_reconstruct( w, h, z, K, R)
% Build the model approximation: eps plus the sum over all (k,r) of the
% outer product of basis w(:,k,r) with the row z(r,k,:) .* h(k,:).
zh = z .* permute(repmat(h,[1 1 R]),[3 1 2]);   % R x K x N
xa = eps;
for r=1:R
    for k=1:K
        xa = xa + w(:,k,r) * reshape( zh(r,k,:), 1, []);
    end
end