comparison examples/SMALL_DL_test.m @ 194:9b0595a8478d danieleb

Debugged SMALL_DL_test and added copyright info
author Daniele Barchiesi <daniele.barchiesi@eecs.qmul.ac.uk>
date Tue, 13 Mar 2012 17:53:46 +0000
parents cc540df790f4
children fd0b5d36f6ad
comparison
equal deleted inserted replaced
193:cc540df790f4 194:9b0595a8478d
1 function SMALL_DL_test 1 function SMALL_DL_test
2 clear, clc, close all 2 clear, clc, close all
3 % Create a 2-dimensional dataset of points that are oriented in 3 3 % Create a 2-dimensional dataset of points that are oriented in 3
4 % directions on a x-y plane 4 % directions on a x/y plane
5
6 %
7 % Centre for Digital Music, Queen Mary, University of London.
8 % This file copyright 2012 Daniele Barchiesi.
9 %
10 % This program is free software; you can redistribute it and/or
11 % modify it under the terms of the GNU General Public License as
12 % published by the Free Software Foundation; either version 2 of the
13 % License, or (at your option) any later version. See the file
14 % COPYING included with this distribution for more information.
15
5 nData = 10000; %number of data 16 nData = 10000; %number of data
6 theta = [pi/6 pi/3 4*pi/6]; %angles 17 theta = [pi/6 pi/3 4*pi/6]; %angles
7 m = length(theta); 18 nAngles = length(theta); %number of angles
8 Q = [cos(theta); sin(theta)]; %rotation matrix 19 Q = [cos(theta); sin(theta)]; %rotation matrix
9 X = Q*randmog(m,nData); 20 X = Q*randmog(nAngles,nData); %training data
10 21
11 % find principal directions using PCA and plot them 22 % find principal directions using PCA
12 XXt = X*X'; 23 XXt = X*X'; %cross correlation matrix
13 [U ~] = svd(XXt); 24 [U ~] = svd(XXt); %svd of XXt
14 scale = 3; 25
26 scale = 3; %scale factor for plots
15 subplot(1,2,1), hold on 27 subplot(1,2,1), hold on
16 title('Principal Component Analysis') 28 title('Principal Component Analysis')
17 scatter(X(1,:), X(2,:),'.'); 29 scatter(X(1,:), X(2,:),'.'); %scatter training data
18 O = zeros(size(U)); 30 O = zeros(size(U)); %origin
19 quiver(O(1,1:2),O(2,1:2),scale*U(1,:),scale*U(2,:),'LineWidth',2,'Color','k') 31 quiver(O(1,1:2),O(2,1:2),scale*U(1,:),scale*U(2,:),...
20 axis equal 32 'LineWidth',2,'Color','k') %plot atoms
33 axis equal %scale axis
21 34
22 subplot(1,2,2), hold on 35 subplot(1,2,2), hold on
23 title('K-SVD Dictionary') 36 title('K-SVD Dictionary')
24 scatter(X(1,:), X(2,:),'.'); 37 scatter(X(1,:), X(2,:),'.');
25 axis equal 38 axis equal
26 nAtoms = 3; 39
27 initDict = randn(2,nAtoms); 40 nAtoms = 3; %number of atoms in the dictionary
28 nIter = 10; 41 nIter = 1; %number of dictionary learning iterations
29 O = zeros(size(initDict)); 42 initDict = normc(randn(2,nAtoms)); %random initial dictionary
43 O = zeros(size(initDict)); %origin
44
30 % apply dictionary learning algorithm 45 % apply dictionary learning algorithm
31 ksvd_params = struct('data',X,... %training data 46 ksvd_params = struct('data',X,... %training data
32 'Tdata',1,... %sparsity level 47 'Tdata',1,... %sparsity level
33 'dictsize',nAtoms,... %number of atoms 48 'dictsize',nAtoms,... %number of atoms
34 'initdict',initDict,... 49 'initdict',initDict,...%initial dictionary
35 'iternum',10); %number of iterations 50 'iternum',10); %number of iterations
36 DL = SMALL_init_DL('ksvd','ksvd',ksvd_params); 51 DL = SMALL_init_DL('ksvd','ksvd',ksvd_params); %dictionary learning structure
37 DL.D = initDict; 52 DL.D = initDict; %copy initial dictionary in solution variable
53 problem = struct('b',X); %copy training data in problem structure
54
38 xdata = DL.D(1,:); 55 xdata = DL.D(1,:);
39 ydata = DL.D(2,:); 56 ydata = DL.D(2,:);
40 qPlot = quiver(O(1,:),O(2,:),scale*initDict(1,:),scale*initDict(2,:),... 57 qPlot = quiver(O(1,:),O(2,:),scale*initDict(1,:),scale*initDict(2,:),...
41 'LineWidth',2,'Color','k','XDataSource','xdata','YDataSource','ydata'); 58 'LineWidth',2,'Color','k','UDataSource','xdata','VDataSource','ydata');
42 problem = struct('b',X); %training data
43 59
44 %plot dictionary and learn
45 for iIter=1:nIter 60 for iIter=1:nIter
46 DL.ksvd_params.initdict = DL.D; 61 DL.ksvd_params.initdict = DL.D;
62 pause
47 DL = SMALL_learn(problem,DL); %learn dictionary 63 DL = SMALL_learn(problem,DL); %learn dictionary
48 xdata = DL.D(1,:); 64 xdata = scale*DL.D(1,:);
49 ydata = DL.D(2,:); 65 ydata = scale*DL.D(2,:);
50 pause
51 refreshdata(gcf,'caller'); 66 refreshdata(gcf,'caller');
52 %quiver(O(1,:),O(2,:),scale*DL.D(1,:),scale*DL.D(2,:),'LineWidth',2,'Color','k');
53 end 67 end
54 68
55 69
56 function X = randmog(m, n) 70 function X = randmog(m, n)
57 % RANDMOG - Generate mixture of Gaussians 71 % RANDMOG - Generate mixture of Gaussians