# HG changeset patch # User Chris Cannam # Date 1395822825 0 # Node ID d42c500b8ad0282c93d91f41d5e45790da8687a2 # Parent 2b7257e4fc8accc9936ea5d576dfde960f97ab33 Add non-convolutional version diff -r 2b7257e4fc8a -r d42c500b8ad0 notes/hnmf.m --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/notes/hnmf.m Wed Mar 26 08:33:45 2014 +0000 @@ -0,0 +1,114 @@ +function [w,h,z,xa] = hnmf( x, K, R, iter, sh, sz, w, h, pl, lh, lz) +% function [w,h,z,xa] = hnmf( x, K, R, iter, sh, sz, w, h, pl, lh, lz) +% +% Perform Multi-source NMF +% +% Inputs: +% x input distribution +% K number of components +% R number of sources +% iter number of EM iterations [default = 100] +% sh sparsity of h +% sz sparsity of z +% w initial value of w +% h initial value of h +% pl plot flag +% lh update h flag +% lz update z flag +% +% Outputs: +% w spectral bases +% h component activation +% z source activation per component +% xa approximation of input +% +% Emmanouil Benetos 2011 + + +% Get sizes +[M,N] = size( x); +sumx = sum(x); + +% Default training iterations +if ~exist( 'iter') + iter = 100; +end + +% Default plot flag +if ~exist( 'pl') + pl = 1; +end + +% Initialize +if ~exist( 'w') || isempty( w) + w = rand( M, R, K); +end +for r=1:R + for k=1:K + w(:,k,r) = w(:,k,r) ./ sum(w(:,k,r)); + end; +end; +if ~exist( 'h') || isempty( h) + h = rand( K, N); +end +n=1:N; +h(:,n) = repmat(sumx(n),K,1) .* (h(:,n) ./ repmat( sum( h(:,n), 1), K, 1)); +if ~exist( 'z') || isempty( z) + z = rand( R, K, N); +end +for k=1:K + for n=1:N + z(:,k,n) = z(:,k,n) ./ sum(z(:,k,n)); + end; +end; + + + +% Iterate +for it = 1:iter + + % E-step + zh = z .* permute(repmat(h,[1 1 R]),[3 1 2]); + xa=eps; + for r=1:R + for k=1:K + xa = xa + w(:,k,r) * squeeze(zh(r,k,:))'; + end; + end; + Q = x ./ xa; + + % M-step (update h,z) + if (lh && lz) + nh=zeros(K,N); + for k=1:K + for r=1:R + nh(k,:) = nh(k,:) + squeeze(z(r,k,:))' .* (squeeze(w(:,k,r))' * Q); + nz = h(k,:) .* (squeeze(w(:,k,r))' * Q); + nz = nz .* squeeze(z(r,k,:))'; + z(r,k,:) = nz; + end; + end; + nh = h .* nh; + end + + + % Assign and normalize + k=1:K; + n=1:N; + if lh + nh = nh.^sh; + h(:,n) = repmat(sumx(n),K,1) .* (nh(:,n) ./ repmat( sum( nh(:,n), 1), K, 1)); + end + if lz + z = z.^sz; + z(:,k,n) = z(:,k,n) ./ repmat( sum( z(:,k,n), 1), R, 1); + end + +end + +% Show me +if pl + subplot(3, 1, 1), imagesc(x), axis xy + subplot(3, 1, 2), imagesc(xa), axis xy + subplot(3, 1, 3), imagesc(h), axis xy +end