function SMALL_DL_test
%SMALL_DL_TEST Compare PCA and K-SVD on a toy 2-D dataset.
%   Builds 2-D points as random combinations of three unit vectors at
%   angles pi/6, pi/3 and 2*pi/3, then draws two panels:
%     * left:  the principal directions from an SVD of X*X' (PCA);
%     * right: a 3-atom dictionary learned with K-SVD through the SMALL
%              toolbox (SMALL_init_DL / SMALL_learn), redrawn after every
%              learning call. Press a key to advance each round.

clc, close all   % ('clear' removed: inside a function it only empties this still-empty workspace)

% Create a 2-dimensional dataset of points that are oriented in 3
% directions on the x-y plane
nData = 10000;                  % number of data points
theta = [pi/6 pi/3 4*pi/6];     % angles of the generating directions
m = length(theta);
Q = [cos(theta); sin(theta)];   % 2-by-m matrix; column k is the unit vector at angle theta(k)
X = Q*randmog(m,nData);         % each point is a random (mixture-of-Gaussians) combination of Q's columns

% Find principal directions using PCA and plot them
XXt = X*X';
[U, ~] = svd(XXt);              % columns of U are the principal directions
scale = 3;                      % length factor applied to every plotted arrow
subplot(1,2,1), hold on
title('Principal Component Analysis')
scatter(X(1,:), X(2,:),'.');
O = zeros(size(U));             % all arrows start at the origin
% 'AutoScale','off' so arrows are drawn exactly scale*U instead of being
% rescaled by quiver's automatic arrow scaling.
quiver(O(1,:),O(2,:),scale*U(1,:),scale*U(2,:),'LineWidth',2,'Color','k','AutoScale','off')
axis equal

subplot(1,2,2), hold on
title('K-SVD Dictionary')
scatter(X(1,:), X(2,:),'.');
axis equal
nAtoms = 3;                     % number of dictionary atoms
initDict = randn(2,nAtoms);     % random initial dictionary
nIter = 10;                     % number of learn-and-redraw rounds
nKsvdIter = 10;                 % K-SVD iterations performed by each SMALL_learn call
O = zeros(size(initDict));      % all arrows start at the origin

% apply dictionary learning algorithm
ksvd_params = struct('data',X,...    % training data
    'Tdata',1,...                    % sparsity level
    'dictsize',nAtoms,...            % number of atoms
    'initdict',initDict,...          % initial dictionary
    'iternum',nKsvdIter);            % number of iterations
DL = SMALL_init_DL('ksvd','ksvd',ksvd_params);
DL.D = initDict;

% The atoms are the arrow *directions* (quiver's U/V data). X/Y data are
% the arrow bases and must stay at the origin, so the data sources are
% bound to UData/VData. The plotted vectors are scaled like the PCA panel.
udata = scale*DL.D(1,:);
vdata = scale*DL.D(2,:);
qPlot = quiver(O(1,:),O(2,:),udata,vdata,...
    'LineWidth',2,'Color','k','AutoScale','off',...
    'UDataSource','udata','VDataSource','vdata');
problem = struct('b',X);        % training data

% plot dictionary and learn
for iIter=1:nIter
    % Warm-start this round from the dictionary learned so far.
    % NOTE(review): assumes SMALL_init_DL keeps its parameter struct in
    % DL.param and SMALL_learn reads it from there (the original wrote to
    % DL.ksvd_params, which nothing here reads) -- confirm in SMALLbox.
    DL.param.initdict = DL.D;
    DL = SMALL_learn(problem,DL);    % learn dictionary
    udata = scale*DL.D(1,:);
    vdata = scale*DL.D(2,:);
    pause                            % wait for a key press before redrawing
    refreshdata(qPlot,'caller');     % re-evaluate 'udata'/'vdata' in this function's workspace
end

function X = randmog(m, n)
% RANDMOG - Generate mixture of Gaussians
%   X = RANDMOG(M, N) returns an M-by-N matrix of independent samples from
%   a zero-mean two-component Gaussian mixture: each entry has standard
%   deviation 0.2 with probability 0.9 and standard deviation 2 otherwise.
%   The component choices (rand) are drawn before the Gaussian values
%   (randn), so the output depends on the global random-number state.
sigma = [0.2 2];                 % [narrow, wide] standard deviations
isNarrow = rand(m, n) < 0.9;     % component choice for every entry
stdDev = repmat(sigma(2), m, n); % start with the wide component everywhere
stdDev(isNarrow) = sigma(1);     % switch the selected entries to the narrow one
X = stdDev .* randn(m, n);