view examples/SMALL_DL_test.m @ 193:cc540df790f4 danieleb

Simple example that demonstrates dictionary learning... to be completed
author Daniele Barchiesi <daniele.barchiesi@eecs.qmul.ac.uk>
date Fri, 09 Mar 2012 15:12:01 +0000
parents
children 9b0595a8478d
line wrap: on
line source
function SMALL_DL_test
% SMALL_DL_TEST  Visual comparison of PCA and K-SVD dictionary learning in 2-D.
%
%   Generates a 2-D point cloud whose samples are spread along three
%   directions, then opens a figure with two panels:
%     left  - the data with the two principal directions found by an SVD
%             of the scatter matrix X*X';
%     right - the data with the atoms of a 3-atom dictionary, redrawn
%             after every call to SMALL_learn.
%
%   The function takes no inputs and returns nothing.  It clears the command
%   window, closes all open figures, and waits for a key press (PAUSE)
%   after each learning pass before redrawing the atoms.

% A function starts with an empty workspace, so only the command window
% and the figures need resetting.
clc, close all

% Create a 2-dimensional dataset of points that are oriented in 3
% directions on a x-y plane
nData = 10000;                      % number of data points
theta = [pi/6 pi/3 4*pi/6];         % angles of the three directions
m     = length(theta);
Q     = [cos(theta); sin(theta)];   % one unit direction vector per column
X     = Q*randmog(m,nData);

% find principal directions using PCA and plot them
XXt = X*X';
[U, ~] = svd(XXt);
scale = 3;                          % arrow length multiplier used in both panels
subplot(1,2,1), hold on
title('Principal Component Analysis')
scatter(X(1,:), X(2,:),'.');
O = zeros(size(U));                 % arrow origins: all at (0,0)
quiver(O(1,1:2),O(2,1:2),scale*U(1,:),scale*U(2,:),'LineWidth',2,'Color','k')
axis equal

subplot(1,2,2), hold on
title('K-SVD Dictionary')
scatter(X(1,:), X(2,:),'.');
axis equal
nAtoms   = 3;
initDict = randn(2,nAtoms);
nIter    = 10;                      % number of learn/redraw passes
O = zeros(size(initDict));          % arrow origins: all at (0,0)

% apply dictionary learning algorithm
ksvd_params = struct('data',X,...           % training data
                     'Tdata',1,...          % sparsity level
                     'dictsize',nAtoms,...  % number of atoms
                     'initdict',initDict,...
                     'iternum',10);         % K-SVD iterations per SMALL_learn call
DL = SMALL_init_DL('ksvd','ksvd',ksvd_params);
DL.D = initDict;

% Draw the initial atoms and keep the handle so the arrows can be updated
% in place instead of stacking a new quiver on every pass.
qPlot = quiver(O(1,:),O(2,:),scale*initDict(1,:),scale*initDict(2,:),...
    'LineWidth',2,'Color','k');
problem = struct('b',X);            % training data

% plot dictionary and learn
for iIter = 1:nIter
    % Warm-start from the dictionary learned so far.  The parameter struct
    % passed to SMALL_init_DL is expected to live in DL.param (confirm against
    % SMALL_init_DL); writing to any other field would be ignored and every
    % pass would restart from initDict.
    DL.param.initdict = DL.D;
    DL = SMALL_learn(problem,DL);   % learn dictionary
    pause
    % The arrow components of a quiver are UData/VData; XData/YData are the
    % arrow origins and must stay at zero.
    set(qPlot,'UData',scale*DL.D(1,:),'VData',scale*DL.D(2,:));
    drawnow
end


function X = randmog(m, n)
% RANDMOG - Draw an m-by-n matrix from a zero-mean mixture of two Gaussians
%   Each entry independently uses the narrow standard deviation (0.2) with
%   probability 0.9 and the wide standard deviation (2) otherwise.
sigmaNarrow = 0.2;
sigmaWide   = 2;
pNarrow     = 0.9;
% Pick the component for every entry (uniform draw comes first)
useNarrow = rand(m, n) < pNarrow;
% Per-entry standard deviation
sigma = sigmaWide * ones(m, n);
sigma(useNarrow) = sigmaNarrow;
% Scale standard normal samples
X = sigma .* randn(m, n);