Mercurial > hg > smallbox
comparison examples/SMALL_DL_test.m @ 194:9b0595a8478d danieleb
Debugged SMALL_DL_test and added copyright info
author | Daniele Barchiesi <daniele.barchiesi@eecs.qmul.ac.uk> |
---|---|
date | Tue, 13 Mar 2012 17:53:46 +0000 |
parents | cc540df790f4 |
children | fd0b5d36f6ad |
comparison
equal
deleted
inserted
replaced
193:cc540df790f4 | 194:9b0595a8478d |
---|---|
1 function SMALL_DL_test | 1 function SMALL_DL_test |
2 clear, clc, close all | 2 clear, clc, close all |
3 % Create a 2-dimensional dataset of points that are oriented in 3 | 3 % Create a 2-dimensional dataset of points that are oriented in 3 |
4 % directions on a x-y plane | 4 % directions on a x/y plane |
| 5 |
| 6 % |
| 7 % Centre for Digital Music, Queen Mary, University of London. |
| 8 % This file copyright 2012 Daniele Barchiesi. |
| 9 % |
| 10 % This program is free software; you can redistribute it and/or |
| 11 % modify it under the terms of the GNU General Public License as |
| 12 % published by the Free Software Foundation; either version 2 of the |
| 13 % License, or (at your option) any later version. See the file |
| 14 % COPYING included with this distribution for more information. |
| 15 |
5 nData = 10000; %number of data | 16 nData = 10000; %number of data |
6 theta = [pi/6 pi/3 4*pi/6]; %angles | 17 theta = [pi/6 pi/3 4*pi/6]; %angles |
7 m = length(theta); | 18 nAngles = length(theta); %number of angles |
8 Q = [cos(theta); sin(theta)]; %rotation matrix | 19 Q = [cos(theta); sin(theta)]; %rotation matrix |
9 X = Q*randmog(m,nData); | 20 X = Q*randmog(nAngles,nData); %training data |
10 | 21 |
11 % find principal directions using PCA and plot them | 22 % find principal directions using PCA |
12 XXt = X*X'; | 23 XXt = X*X'; %cross correlation matrix |
13 [U ~] = svd(XXt); | 24 [U ~] = svd(XXt); %svd of XXt |
14 scale = 3; | 25 |
| 26 scale = 3; %scale factor for plots |
15 subplot(1,2,1), hold on | 27 subplot(1,2,1), hold on |
16 title('Principal Component Analysis') | 28 title('Principal Component Analysis') |
17 scatter(X(1,:), X(2,:),'.'); | 29 scatter(X(1,:), X(2,:),'.'); %scatter training data |
18 O = zeros(size(U)); | 30 O = zeros(size(U)); %origin |
19 quiver(O(1,1:2),O(2,1:2),scale*U(1,:),scale*U(2,:),'LineWidth',2,'Color','k') | 31 quiver(O(1,1:2),O(2,1:2),scale*U(1,:),scale*U(2,:),... |
20 axis equal | 32 'LineWidth',2,'Color','k') %plot atoms |
| 33 axis equal %scale axis |
21 | 34 |
22 subplot(1,2,2), hold on | 35 subplot(1,2,2), hold on |
23 title('K-SVD Dictionary') | 36 title('K-SVD Dictionary') |
24 scatter(X(1,:), X(2,:),'.'); | 37 scatter(X(1,:), X(2,:),'.'); |
25 axis equal | 38 axis equal |
26 nAtoms = 3; | 39 |
27 initDict = randn(2,nAtoms); | 40 nAtoms = 3; %number of atoms in the dictionary |
28 nIter = 10; | 41 nIter = 1; %number of dictionary learning iterations |
29 O = zeros(size(initDict)); | 42 initDict = normc(randn(2,nAtoms)); %random initial dictionary |
| 43 O = zeros(size(initDict)); %origin |
| 44 |
30 % apply dictionary learning algorithm | 45 % apply dictionary learning algorithm |
31 ksvd_params = struct('data',X,... %training data | 46 ksvd_params = struct('data',X,... %training data |
32 'Tdata',1,... %sparsity level | 47 'Tdata',1,... %sparsity level |
33 'dictsize',nAtoms,... %number of atoms | 48 'dictsize',nAtoms,... %number of atoms |
34 'initdict',initDict,... | 49 'initdict',initDict,...%initial dictionary |
35 'iternum',10); %number of iterations | 50 'iternum',10); %number of iterations |
36 DL = SMALL_init_DL('ksvd','ksvd',ksvd_params); | 51 DL = SMALL_init_DL('ksvd','ksvd',ksvd_params); %dictionary learning structure |
37 DL.D = initDict; | 52 DL.D = initDict; %copy initial dictionary in solution variable |
| 53 problem = struct('b',X); %copy training data in problem structure |
| 54 |
38 xdata = DL.D(1,:); | 55 xdata = DL.D(1,:); |
39 ydata = DL.D(2,:); | 56 ydata = DL.D(2,:); |
40 qPlot = quiver(O(1,:),O(2,:),scale*initDict(1,:),scale*initDict(2,:),... | 57 qPlot = quiver(O(1,:),O(2,:),scale*initDict(1,:),scale*initDict(2,:),... |
41 'LineWidth',2,'Color','k','XDataSource','xdata','YDataSource','ydata'); | 58 'LineWidth',2,'Color','k','UDataSource','xdata','VDataSource','ydata'); |
42 problem = struct('b',X); %training data | |
43 | 59 |
44 %plot dictionary and learn | |
45 for iIter=1:nIter | 60 for iIter=1:nIter |
46 DL.ksvd_params.initdict = DL.D; | 61 DL.ksvd_params.initdict = DL.D; |
| 62 pause |
47 DL = SMALL_learn(problem,DL); %learn dictionary | 63 DL = SMALL_learn(problem,DL); %learn dictionary |
48 xdata = DL.D(1,:); | 64 xdata = scale*DL.D(1,:); |
49 ydata = DL.D(2,:); | 65 ydata = scale*DL.D(2,:); |
50 pause | |
51 refreshdata(gcf,'caller'); | 66 refreshdata(gcf,'caller'); |
52 %quiver(O(1,:),O(2,:),scale*DL.D(1,:),scale*DL.D(2,:),'LineWidth',2,'Color','k'); | |
53 end | 67 end |
54 | 68 |
55 | 69 |
56 function X = randmog(m, n) | 70 function X = randmog(m, n) |
57 % RANDMOG - Generate mixture of Gaussians | 71 % RANDMOG - Generate mixture of Gaussians |