Mercurial > hg > smallbox
comparison examples/SMALL_test_mocod.m @ 181:0dc98f1c60bb danieleb
minor edits
author | Daniele Barchiesi <daniele.barchiesi@eecs.qmul.ac.uk> |
---|---|
date | Thu, 05 Jan 2012 15:46:13 +0000 |
parents | 714fa7b8c1ad |
children |
comparison
equal
deleted
inserted
replaced
180:28b20fd46ba7 | 181:0dc98f1c60bb |
---|---|
1 %% SMALL_test_mocod | |
2 % Script that tests the twostep dictionary learning algorithm with MOCOD | |
3 % dictionary update. | |
4 % | |
5 % REFERENCES | |
6 % D. Barchiesi and M. D. Plumbely, Learning incoherenct dictionaries for | |
7 % sparse approximation using iterative projections and rotations. | |
8 %% Clear and close | |
1 clc, clear, close all | 9 clc, clear, close all |
2 | 10 |
3 %% Parameteres | 11 %% Parameteres |
4 nTrials = 10; %number of trials of the experiment | 12 nTrials = 10; %number of trials of the experiment |
5 | 13 |
6 % Dictionary learning parameters | 14 % Dictionary learning parameters |
7 toolbox = 'TwoStepDL'; %dictionary learning toolbox | 15 toolbox = 'TwoStepDL'; %dictionary learning toolbox |
8 dicUpdate = 'mocod'; %dictionary learning updates | 16 dicUpdate = 'mocod'; %dictionary learning updates |
9 zeta = logspace(-2,2,10); | 17 zeta = logspace(-2,2,10); %range of values for the incoherence term |
10 eta = logspace(-2,2,10); | 18 eta = logspace(-2,2,10); %range of values for the unit norm term |
11 | 19 iterNum = 20; %number of iterations |
12 iterNum = 20; %number of iterations | 20 epsilon = 1e-6; %tolerance level |
13 epsilon = 1e-6; %tolerance level | 21 dictSize = 512; %number of atoms in the dictionary |
14 dictSize = 512; %number of atoms in the dictionary | 22 percActiveAtoms = 5; %percentage of active atoms |
15 percActiveAtoms = 5; %percentage of active atoms | |
16 | 23 |
17 % Test signal parameters | 24 % Test signal parameters |
18 signal = audio('music03_16kHz.wav'); %audio signal | 25 signal = audio('music03_16kHz.wav'); %audio signal |
19 blockSize = 256; %size of audio frames | 26 blockSize = 256; %size of audio frames |
20 overlap = 0.5; %overlap between consecutive frames | 27 overlap = 0.5; %overlap between consecutive frames |
23 nActiveAtoms = fix(blockSize/100*percActiveAtoms); %number of active atoms | 30 nActiveAtoms = fix(blockSize/100*percActiveAtoms); %number of active atoms |
24 | 31 |
25 % Initial dictionaries | 32 % Initial dictionaries |
26 gaborParam = struct('N',blockSize,'redundancyFactor',2,'wd',@rectwin); | 33 gaborParam = struct('N',blockSize,'redundancyFactor',2,'wd',@rectwin); |
27 gaborDict = Gabor_Dictionary(gaborParam); | 34 gaborDict = Gabor_Dictionary(gaborParam); |
28 initDicts = {[],gaborDict}; | 35 initDicts = {[],gaborDict}; %cell containing initial dictionaries |
29 | 36 |
30 %% Generate audio approximation problem | 37 %% Generate audio approximation problem |
31 signal = buffer(signal,blockSize,blockSize*overlap,@rectwin); %buffer frames of audio into columns of the matrix S | 38 %buffer frames of audio into columns of the matrix S |
39 signal = buffer(signal,blockSize,blockSize*overlap,@rectwin); | |
32 SMALL.Problem.b = signal.S; | 40 SMALL.Problem.b = signal.S; |
33 SMALL.Problem.b1 = SMALL.Problem.b; % copy signals from training set b to test set b1 (needed for later functions) | |
34 | 41 |
35 % omp2 sparse representation solver | 42 %copy signals from training set b to test set b1 (needed for later functions) |
36 ompParam = struct('X',SMALL.Problem.b,'epsilon',epsilon,'maxatoms',nActiveAtoms); %parameters | 43 SMALL.Problem.b1 = SMALL.Problem.b; |
44 | |
45 %omp2 sparse representation solver | |
46 ompParam = struct('X',SMALL.Problem.b,... | |
47 'epsilon',epsilon,... | |
48 'maxatoms',nActiveAtoms); %parameters | |
37 solver = SMALL_init_solver('ompbox','omp2',ompParam,false); %solver structure | 49 solver = SMALL_init_solver('ompbox','omp2',ompParam,false); %solver structure |
38 | 50 |
39 %% Test | 51 %% Test |
40 nInitDicts = length(initDicts); %number of initial dictionaries | 52 nInitDicts = length(initDicts);%number of initial dictionaries |
41 nZetas = length(zeta); | 53 nZetas = length(zeta); %number of incoherence penalty parameters |
42 nEtas = length(eta); | 54 nEtas = length(eta); %number of unit norm penalty parameters |
43 | 55 |
44 SMALL.DL(nTrials,nInitDicts,nZetas,nEtas) = SMALL_init_DL(toolbox); %create dictionary learning structures | 56 %create dictionary learning structures |
57 SMALL.DL(nTrials,nInitDicts,nZetas,nEtas) = SMALL_init_DL(toolbox); | |
45 for iTrial=1:nTrials | 58 for iTrial=1:nTrials |
46 for iInitDicts=1:nInitDicts | 59 for iInitDicts=1:nInitDicts |
47 for iZetas=1:nZetas | 60 for iZetas=1:nZetas |
48 for iEtas=1:nEtas | 61 for iEtas=1:nEtas |
49 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).toolbox = toolbox; | 62 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).toolbox = toolbox; |
50 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).name = dicUpdate; | 63 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).name = dicUpdate; |
51 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).profile = true; | 64 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).profile = true; |
52 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).param = ... | 65 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).param = ... |
53 struct('data',SMALL.Problem.b,... | 66 struct('data',SMALL.Problem.b,... %observed data |
54 'Tdata',nActiveAtoms,... | 67 'Tdata',nActiveAtoms,... %active atoms |
55 'dictsize',dictSize,... | 68 'dictsize',dictSize,... %number of atoms |
56 'iternum',iterNum,... | 69 'iternum',iterNum,... %number of iterations |
57 'memusage','high',... | 70 'memusage','high',... %memory usage |
58 'solver',solver,... | 71 'solver',solver,... %sparse approx solver |
59 'initdict',initDicts(iInitDicts),... | 72 'initdict',initDicts(iInitDicts),...%initial dictionary |
60 'zeta',zeta(iZetas),... | 73 'zeta',zeta(iZetas),... %incoherence penalty factor |
61 'eta',eta(iEtas)); | 74 'eta',eta(iEtas)); %unit norm penalty factor |
75 %learn dictionary | |
62 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas) = ... | 76 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas) = ... |
63 SMALL_learn(SMALL.Problem,SMALL.DL(iTrial,iInitDicts,iZetas,iEtas)); | 77 SMALL_learn(SMALL.Problem,SMALL.DL(iTrial,iInitDicts,iZetas,iEtas)); |
64 end | 78 end |
65 end | 79 end |
66 end | 80 end |
67 end | 81 end |
68 | 82 |
69 %% Evaluate coherence and snr of representation for the various methods | 83 %% Evaluate coherence and snr of representation |
70 sr = zeros(size(SMALL.DL)); %signal to noise ratio | 84 sr = zeros(size(SMALL.DL)); %signal to noise ratio |
71 mu = zeros(iTrial,iInitDicts,iZetas,iEtas); %coherence | 85 mu = zeros(iTrial,iInitDicts,iZetas,iEtas); %coherence |
72 dic(size(SMALL.DL)) = dictionary; %initialise dictionary objects | 86 dic(size(SMALL.DL)) = dictionary; %initialise dictionary objects |
73 for iTrial=1:nTrials | 87 for iTrial=1:nTrials |
74 for iInitDicts=1:nInitDicts | 88 for iInitDicts=1:nInitDicts |
75 for iZetas=1:nZetas | 89 for iZetas=1:nZetas |
76 for iEtas=1:nEtas | 90 for iEtas=1:nEtas |
77 %Sparse representation | 91 %Sparse representation |
78 SMALL.Problem.A = SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).D; | 92 SMALL.Problem.A = SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).D; |
79 tempSolver = SMALL_solve(SMALL.Problem,solver); | 93 tempSolver = SMALL_solve(SMALL.Problem,solver); |
80 %calculate snr | 94 %calculate snr |
81 sr(iTrial,iInitDicts,iZetas,iEtas) = ... | 95 sr(iTrial,iInitDicts,iZetas,iEtas) = ... |
82 snr(SMALL.Problem.b,SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).D*tempSolver.solution); | 96 snr(SMALL.Problem.b,... |
97 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).D*tempSolver.solution); | |
83 %calculate mu | 98 %calculate mu |
84 dic(iTrial,iInitDicts,iZetas,iEtas) = ... | 99 dic(iTrial,iInitDicts,iZetas,iEtas) = ... |
85 dictionary(SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).D); | 100 dictionary(SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).D); |
86 mu(iTrial,iInitDicts,iZetas,iEtas) = ... | 101 mu(iTrial,iInitDicts,iZetas,iEtas) = ... |
87 dic(iTrial,iInitDicts,iZetas,iEtas).coherence; | 102 dic(iTrial,iInitDicts,iZetas,iEtas).coherence; |
88 end | 103 end |
89 end | 104 end |
90 end | 105 end |
91 end | 106 end |
92 | 107 |
93 save('MOCOD.mat') | |
94 | |
95 %% Plot results | 108 %% Plot results |
96 minMu = sqrt((dictSize-blockSize)/(blockSize*(dictSize-1))); %lowe bound on coherence | 109 minMu = sqrt((dictSize-blockSize)/(blockSize*(dictSize-1)));%lower bound on coherence |
97 initDictsNames = {'Data','Gabor'}; | 110 initDictsNames = {'Data','Gabor'}; %names of initial dictionaries |
98 lineStyles = {'k.-','r*-','b+-'}; | 111 lineStyles = {'k.-','r*-','b+-'}; |
99 for iInitDict=1:nInitDicts | 112 for iInitDict=1:nInitDicts |
100 figure, hold on, grid on | 113 figure, hold on, grid on |
101 title([initDictsNames{iInitDict} ' Initialisation']); | 114 %print initial dictionary as figure title |
115 DisplayFigureTitle([initDictsNames{iInitDict} ' Initialisation']); | |
116 % set(gcf,'Units','Normalized'); | |
117 % txh = annotation(gcf,'textbox',[0.4,0.95,0.2,0.05]); | |
118 % set(txh,'String',[initDictsNames{iInitDict} ' Initialisation'],... | |
119 % 'LineStyle','none','HorizontalAlignment','center'); | |
120 %calculate mean coherence levels and SNRs over trials | |
102 coherenceLevels = squeeze(mean(mu(:,iInitDict,:,:),1)); | 121 coherenceLevels = squeeze(mean(mu(:,iInitDict,:,:),1)); |
103 meanSNRs = squeeze(mean(sr(:,iInitDict,:,:),1)); | 122 meanSNRs = squeeze(mean(sr(:,iInitDict,:,:),1)); |
104 %stdSNRs = squeeze(std(sr(:,iInitDict,iZetas,iEtas),0,1)); | 123 % stdSNRs = squeeze(std(sr(:,iInitDict,iZetas,iEtas),0,1)); |
124 %plot coherence levels | |
105 subplot(2,2,1) | 125 subplot(2,2,1) |
106 surf(eta,zeta,coherenceLevels); | 126 surf(eta,zeta,coherenceLevels); |
107 set(gca,'Xscale','log','Yscale','log','ZLim',[0 1.4]); | 127 set(gca,'Xscale','log','Yscale','log','ZLim',[0 1.4]); |
108 view(gca,130,20) | 128 view(gca,130,20) |
109 xlabel('\eta'); | 129 xlabel('\eta'); |
110 ylabel('\zeta'); | 130 ylabel('\zeta'); |
111 zlabel('\mu'); | 131 zlabel('\mu'); |
112 title('Coherence') | 132 title('Coherence') |
113 | 133 |
134 %plot SNRs | |
114 subplot(2,2,2) | 135 subplot(2,2,2) |
115 surf(eta,zeta,meanSNRs); | 136 surf(eta,zeta,meanSNRs); |
116 set(gca,'Xscale','log','Yscale','log','ZLim',[0 25]); | 137 set(gca,'Xscale','log','Yscale','log','ZLim',[0 25]); |
117 view(gca,130,20) | 138 view(gca,130,20) |
118 xlabel('\eta'); | 139 xlabel('\eta'); |
119 ylabel('\zeta'); | 140 ylabel('\zeta'); |
120 zlabel('SNR (dB)'); | 141 zlabel('SNR (dB)'); |
121 title('Reconstruction Error') | 142 title('Reconstruction Error') |
122 | 143 |
144 %plot mu/SNR scatter | |
123 subplot(2,2,[3 4]) | 145 subplot(2,2,[3 4]) |
124 mus = mu(:,iInitDict,:,:); | 146 mus = mu(:,iInitDict,:,:); |
125 mus = mus(:); | 147 mus = mus(:); |
126 SNRs = sr(:,iInitDict,:,:); | 148 SNRs = sr(:,iInitDict,:,:); |
127 SNRs = SNRs(:); | 149 SNRs = SNRs(:); |