comparison examples/SMALL_test_mocod.m @ 181:0dc98f1c60bb danieleb

minor edits
author Daniele Barchiesi <daniele.barchiesi@eecs.qmul.ac.uk>
date Thu, 05 Jan 2012 15:46:13 +0000
parents 714fa7b8c1ad
children
comparison
equal deleted inserted replaced
180:28b20fd46ba7 181:0dc98f1c60bb
1 %% SMALL_test_mocod
2 % Script that tests the twostep dictionary learning algorithm with MOCOD
3 % dictionary update.
4 %
5 % REFERENCES
6 % D. Barchiesi and M. D. Plumbely, Learning incoherenct dictionaries for
7 % sparse approximation using iterative projections and rotations.
8 %% Clear and close
1 clc, clear, close all 9 clc, clear, close all
2 10
3 %% Parameteres 11 %% Parameteres
4 nTrials = 10; %number of trials of the experiment 12 nTrials = 10; %number of trials of the experiment
5 13
6 % Dictionary learning parameters 14 % Dictionary learning parameters
7 toolbox = 'TwoStepDL'; %dictionary learning toolbox 15 toolbox = 'TwoStepDL'; %dictionary learning toolbox
8 dicUpdate = 'mocod'; %dictionary learning updates 16 dicUpdate = 'mocod'; %dictionary learning updates
9 zeta = logspace(-2,2,10); 17 zeta = logspace(-2,2,10); %range of values for the incoherence term
10 eta = logspace(-2,2,10); 18 eta = logspace(-2,2,10); %range of values for the unit norm term
11 19 iterNum = 20; %number of iterations
12 iterNum = 20; %number of iterations 20 epsilon = 1e-6; %tolerance level
13 epsilon = 1e-6; %tolerance level 21 dictSize = 512; %number of atoms in the dictionary
14 dictSize = 512; %number of atoms in the dictionary 22 percActiveAtoms = 5; %percentage of active atoms
15 percActiveAtoms = 5; %percentage of active atoms
16 23
17 % Test signal parameters 24 % Test signal parameters
18 signal = audio('music03_16kHz.wav'); %audio signal 25 signal = audio('music03_16kHz.wav'); %audio signal
19 blockSize = 256; %size of audio frames 26 blockSize = 256; %size of audio frames
20 overlap = 0.5; %overlap between consecutive frames 27 overlap = 0.5; %overlap between consecutive frames
23 nActiveAtoms = fix(blockSize/100*percActiveAtoms); %number of active atoms 30 nActiveAtoms = fix(blockSize/100*percActiveAtoms); %number of active atoms
24 31
25 % Initial dictionaries 32 % Initial dictionaries
26 gaborParam = struct('N',blockSize,'redundancyFactor',2,'wd',@rectwin); 33 gaborParam = struct('N',blockSize,'redundancyFactor',2,'wd',@rectwin);
27 gaborDict = Gabor_Dictionary(gaborParam); 34 gaborDict = Gabor_Dictionary(gaborParam);
28 initDicts = {[],gaborDict}; 35 initDicts = {[],gaborDict}; %cell containing initial dictionaries
29 36
30 %% Generate audio approximation problem 37 %% Generate audio approximation problem
31 signal = buffer(signal,blockSize,blockSize*overlap,@rectwin); %buffer frames of audio into columns of the matrix S 38 %buffer frames of audio into columns of the matrix S
39 signal = buffer(signal,blockSize,blockSize*overlap,@rectwin);
32 SMALL.Problem.b = signal.S; 40 SMALL.Problem.b = signal.S;
33 SMALL.Problem.b1 = SMALL.Problem.b; % copy signals from training set b to test set b1 (needed for later functions)
34 41
35 % omp2 sparse representation solver 42 %copy signals from training set b to test set b1 (needed for later functions)
36 ompParam = struct('X',SMALL.Problem.b,'epsilon',epsilon,'maxatoms',nActiveAtoms); %parameters 43 SMALL.Problem.b1 = SMALL.Problem.b;
44
45 %omp2 sparse representation solver
46 ompParam = struct('X',SMALL.Problem.b,...
47 'epsilon',epsilon,...
48 'maxatoms',nActiveAtoms); %parameters
37 solver = SMALL_init_solver('ompbox','omp2',ompParam,false); %solver structure 49 solver = SMALL_init_solver('ompbox','omp2',ompParam,false); %solver structure
38 50
39 %% Test 51 %% Test
40 nInitDicts = length(initDicts); %number of initial dictionaries 52 nInitDicts = length(initDicts);%number of initial dictionaries
41 nZetas = length(zeta); 53 nZetas = length(zeta); %number of incoherence penalty parameters
42 nEtas = length(eta); 54 nEtas = length(eta); %number of unit norm penalty parameters
43 55
44 SMALL.DL(nTrials,nInitDicts,nZetas,nEtas) = SMALL_init_DL(toolbox); %create dictionary learning structures 56 %create dictionary learning structures
57 SMALL.DL(nTrials,nInitDicts,nZetas,nEtas) = SMALL_init_DL(toolbox);
45 for iTrial=1:nTrials 58 for iTrial=1:nTrials
46 for iInitDicts=1:nInitDicts 59 for iInitDicts=1:nInitDicts
47 for iZetas=1:nZetas 60 for iZetas=1:nZetas
48 for iEtas=1:nEtas 61 for iEtas=1:nEtas
49 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).toolbox = toolbox; 62 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).toolbox = toolbox;
50 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).name = dicUpdate; 63 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).name = dicUpdate;
51 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).profile = true; 64 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).profile = true;
52 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).param = ... 65 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).param = ...
53 struct('data',SMALL.Problem.b,... 66 struct('data',SMALL.Problem.b,... %observed data
54 'Tdata',nActiveAtoms,... 67 'Tdata',nActiveAtoms,... %active atoms
55 'dictsize',dictSize,... 68 'dictsize',dictSize,... %number of atoms
56 'iternum',iterNum,... 69 'iternum',iterNum,... %number of iterations
57 'memusage','high',... 70 'memusage','high',... %memory usage
58 'solver',solver,... 71 'solver',solver,... %sparse approx solver
59 'initdict',initDicts(iInitDicts),... 72 'initdict',initDicts(iInitDicts),...%initial dictionary
60 'zeta',zeta(iZetas),... 73 'zeta',zeta(iZetas),... %incoherence penalty factor
61 'eta',eta(iEtas)); 74 'eta',eta(iEtas)); %unit norm penalty factor
75 %learn dictionary
62 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas) = ... 76 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas) = ...
63 SMALL_learn(SMALL.Problem,SMALL.DL(iTrial,iInitDicts,iZetas,iEtas)); 77 SMALL_learn(SMALL.Problem,SMALL.DL(iTrial,iInitDicts,iZetas,iEtas));
64 end 78 end
65 end 79 end
66 end 80 end
67 end 81 end
68 82
69 %% Evaluate coherence and snr of representation for the various methods 83 %% Evaluate coherence and snr of representation
70 sr = zeros(size(SMALL.DL)); %signal to noise ratio 84 sr = zeros(size(SMALL.DL)); %signal to noise ratio
71 mu = zeros(iTrial,iInitDicts,iZetas,iEtas); %coherence 85 mu = zeros(iTrial,iInitDicts,iZetas,iEtas); %coherence
72 dic(size(SMALL.DL)) = dictionary; %initialise dictionary objects 86 dic(size(SMALL.DL)) = dictionary; %initialise dictionary objects
73 for iTrial=1:nTrials 87 for iTrial=1:nTrials
74 for iInitDicts=1:nInitDicts 88 for iInitDicts=1:nInitDicts
75 for iZetas=1:nZetas 89 for iZetas=1:nZetas
76 for iEtas=1:nEtas 90 for iEtas=1:nEtas
77 %Sparse representation 91 %Sparse representation
78 SMALL.Problem.A = SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).D; 92 SMALL.Problem.A = SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).D;
79 tempSolver = SMALL_solve(SMALL.Problem,solver); 93 tempSolver = SMALL_solve(SMALL.Problem,solver);
80 %calculate snr 94 %calculate snr
81 sr(iTrial,iInitDicts,iZetas,iEtas) = ... 95 sr(iTrial,iInitDicts,iZetas,iEtas) = ...
82 snr(SMALL.Problem.b,SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).D*tempSolver.solution); 96 snr(SMALL.Problem.b,...
97 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).D*tempSolver.solution);
83 %calculate mu 98 %calculate mu
84 dic(iTrial,iInitDicts,iZetas,iEtas) = ... 99 dic(iTrial,iInitDicts,iZetas,iEtas) = ...
85 dictionary(SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).D); 100 dictionary(SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).D);
86 mu(iTrial,iInitDicts,iZetas,iEtas) = ... 101 mu(iTrial,iInitDicts,iZetas,iEtas) = ...
87 dic(iTrial,iInitDicts,iZetas,iEtas).coherence; 102 dic(iTrial,iInitDicts,iZetas,iEtas).coherence;
88 end 103 end
89 end 104 end
90 end 105 end
91 end 106 end
92 107
93 save('MOCOD.mat')
94
95 %% Plot results 108 %% Plot results
96 minMu = sqrt((dictSize-blockSize)/(blockSize*(dictSize-1))); %lowe bound on coherence 109 minMu = sqrt((dictSize-blockSize)/(blockSize*(dictSize-1)));%lower bound on coherence
97 initDictsNames = {'Data','Gabor'}; 110 initDictsNames = {'Data','Gabor'}; %names of initial dictionaries
98 lineStyles = {'k.-','r*-','b+-'}; 111 lineStyles = {'k.-','r*-','b+-'};
99 for iInitDict=1:nInitDicts 112 for iInitDict=1:nInitDicts
100 figure, hold on, grid on 113 figure, hold on, grid on
101 title([initDictsNames{iInitDict} ' Initialisation']); 114 %print initial dictionary as figure title
115 DisplayFigureTitle([initDictsNames{iInitDict} ' Initialisation']);
116 % set(gcf,'Units','Normalized');
117 % txh = annotation(gcf,'textbox',[0.4,0.95,0.2,0.05]);
118 % set(txh,'String',[initDictsNames{iInitDict} ' Initialisation'],...
119 % 'LineStyle','none','HorizontalAlignment','center');
120 %calculate mean coherence levels and SNRs over trials
102 coherenceLevels = squeeze(mean(mu(:,iInitDict,:,:),1)); 121 coherenceLevels = squeeze(mean(mu(:,iInitDict,:,:),1));
103 meanSNRs = squeeze(mean(sr(:,iInitDict,:,:),1)); 122 meanSNRs = squeeze(mean(sr(:,iInitDict,:,:),1));
104 %stdSNRs = squeeze(std(sr(:,iInitDict,iZetas,iEtas),0,1)); 123 % stdSNRs = squeeze(std(sr(:,iInitDict,iZetas,iEtas),0,1));
124 %plot coherence levels
105 subplot(2,2,1) 125 subplot(2,2,1)
106 surf(eta,zeta,coherenceLevels); 126 surf(eta,zeta,coherenceLevels);
107 set(gca,'Xscale','log','Yscale','log','ZLim',[0 1.4]); 127 set(gca,'Xscale','log','Yscale','log','ZLim',[0 1.4]);
108 view(gca,130,20) 128 view(gca,130,20)
109 xlabel('\eta'); 129 xlabel('\eta');
110 ylabel('\zeta'); 130 ylabel('\zeta');
111 zlabel('\mu'); 131 zlabel('\mu');
112 title('Coherence') 132 title('Coherence')
113 133
134 %plot SNRs
114 subplot(2,2,2) 135 subplot(2,2,2)
115 surf(eta,zeta,meanSNRs); 136 surf(eta,zeta,meanSNRs);
116 set(gca,'Xscale','log','Yscale','log','ZLim',[0 25]); 137 set(gca,'Xscale','log','Yscale','log','ZLim',[0 25]);
117 view(gca,130,20) 138 view(gca,130,20)
118 xlabel('\eta'); 139 xlabel('\eta');
119 ylabel('\zeta'); 140 ylabel('\zeta');
120 zlabel('SNR (dB)'); 141 zlabel('SNR (dB)');
121 title('Reconstruction Error') 142 title('Reconstruction Error')
122 143
144 %plot mu/SNR scatter
123 subplot(2,2,[3 4]) 145 subplot(2,2,[3 4])
124 mus = mu(:,iInitDict,:,:); 146 mus = mu(:,iInitDict,:,:);
125 mus = mus(:); 147 mus = mus(:);
126 SNRs = sr(:,iInitDict,:,:); 148 SNRs = sr(:,iInitDict,:,:);
127 SNRs = SNRs(:); 149 SNRs = SNRs(:);