Mercurial > hg > smallbox
comparison examples/SMALL_test_mocod.m @ 181:0dc98f1c60bb danieleb
minor edits
| author | Daniele Barchiesi <daniele.barchiesi@eecs.qmul.ac.uk> |
|---|---|
| date | Thu, 05 Jan 2012 15:46:13 +0000 |
| parents | 714fa7b8c1ad |
| children |
comparison
equal
deleted
inserted
replaced
| 180:28b20fd46ba7 | 181:0dc98f1c60bb |
|---|---|
| 1 %% SMALL_test_mocod | |
| 2 % Script that tests the twostep dictionary learning algorithm with MOCOD | |
| 3 % dictionary update. | |
| 4 % | |
| 5 % REFERENCES | |
| 6 % D. Barchiesi and M. D. Plumbley, Learning incoherent dictionaries for | |
| 7 % sparse approximation using iterative projections and rotations. | |
| 8 %% Clear and close | |
| 1 clc, clear, close all | 9 clc, clear, close all |
| 2 | 10 |
| 3 %% Parameters | 11 %% Parameters |
| 4 nTrials = 10; %number of trials of the experiment | 12 nTrials = 10; %number of trials of the experiment |
| 5 | 13 |
| 6 % Dictionary learning parameters | 14 % Dictionary learning parameters |
| 7 toolbox = 'TwoStepDL'; %dictionary learning toolbox | 15 toolbox = 'TwoStepDL'; %dictionary learning toolbox |
| 8 dicUpdate = 'mocod'; %dictionary learning updates | 16 dicUpdate = 'mocod'; %dictionary learning updates |
| 9 zeta = logspace(-2,2,10); | 17 zeta = logspace(-2,2,10); %range of values for the incoherence term |
| 10 eta = logspace(-2,2,10); | 18 eta = logspace(-2,2,10); %range of values for the unit norm term |
| 11 | 19 iterNum = 20; %number of iterations |
| 12 iterNum = 20; %number of iterations | 20 epsilon = 1e-6; %tolerance level |
| 13 epsilon = 1e-6; %tolerance level | 21 dictSize = 512; %number of atoms in the dictionary |
| 14 dictSize = 512; %number of atoms in the dictionary | 22 percActiveAtoms = 5; %percentage of active atoms |
| 15 percActiveAtoms = 5; %percentage of active atoms | |
| 16 | 23 |
| 17 % Test signal parameters | 24 % Test signal parameters |
| 18 signal = audio('music03_16kHz.wav'); %audio signal | 25 signal = audio('music03_16kHz.wav'); %audio signal |
| 19 blockSize = 256; %size of audio frames | 26 blockSize = 256; %size of audio frames |
| 20 overlap = 0.5; %overlap between consecutive frames | 27 overlap = 0.5; %overlap between consecutive frames |
| 23 nActiveAtoms = fix(blockSize/100*percActiveAtoms); %number of active atoms | 30 nActiveAtoms = fix(blockSize/100*percActiveAtoms); %number of active atoms |
| 24 | 31 |
| 25 % Initial dictionaries | 32 % Initial dictionaries |
| 26 gaborParam = struct('N',blockSize,'redundancyFactor',2,'wd',@rectwin); | 33 gaborParam = struct('N',blockSize,'redundancyFactor',2,'wd',@rectwin); |
| 27 gaborDict = Gabor_Dictionary(gaborParam); | 34 gaborDict = Gabor_Dictionary(gaborParam); |
| 28 initDicts = {[],gaborDict}; | 35 initDicts = {[],gaborDict}; %cell containing initial dictionaries |
| 29 | 36 |
| 30 %% Generate audio approximation problem | 37 %% Generate audio approximation problem |
| 31 signal = buffer(signal,blockSize,blockSize*overlap,@rectwin); %buffer frames of audio into columns of the matrix S | 38 %buffer frames of audio into columns of the matrix S |
| 39 signal = buffer(signal,blockSize,blockSize*overlap,@rectwin); | |
| 32 SMALL.Problem.b = signal.S; | 40 SMALL.Problem.b = signal.S; |
| 33 SMALL.Problem.b1 = SMALL.Problem.b; % copy signals from training set b to test set b1 (needed for later functions) | |
| 34 | 41 |
| 35 % omp2 sparse representation solver | 42 %copy signals from training set b to test set b1 (needed for later functions) |
| 36 ompParam = struct('X',SMALL.Problem.b,'epsilon',epsilon,'maxatoms',nActiveAtoms); %parameters | 43 SMALL.Problem.b1 = SMALL.Problem.b; |
| 44 | |
| 45 %omp2 sparse representation solver | |
| 46 ompParam = struct('X',SMALL.Problem.b,... | |
| 47 'epsilon',epsilon,... | |
| 48 'maxatoms',nActiveAtoms); %parameters | |
| 37 solver = SMALL_init_solver('ompbox','omp2',ompParam,false); %solver structure | 49 solver = SMALL_init_solver('ompbox','omp2',ompParam,false); %solver structure |
| 38 | 50 |
| 39 %% Test | 51 %% Test |
| 40 nInitDicts = length(initDicts); %number of initial dictionaries | 52 nInitDicts = length(initDicts);%number of initial dictionaries |
| 41 nZetas = length(zeta); | 53 nZetas = length(zeta); %number of incoherence penalty parameters |
| 42 nEtas = length(eta); | 54 nEtas = length(eta); %number of unit norm penalty parameters |
| 43 | 55 |
| 44 SMALL.DL(nTrials,nInitDicts,nZetas,nEtas) = SMALL_init_DL(toolbox); %create dictionary learning structures | 56 %create dictionary learning structures |
| 57 SMALL.DL(nTrials,nInitDicts,nZetas,nEtas) = SMALL_init_DL(toolbox); | |
| 45 for iTrial=1:nTrials | 58 for iTrial=1:nTrials |
| 46 for iInitDicts=1:nInitDicts | 59 for iInitDicts=1:nInitDicts |
| 47 for iZetas=1:nZetas | 60 for iZetas=1:nZetas |
| 48 for iEtas=1:nEtas | 61 for iEtas=1:nEtas |
| 49 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).toolbox = toolbox; | 62 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).toolbox = toolbox; |
| 50 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).name = dicUpdate; | 63 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).name = dicUpdate; |
| 51 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).profile = true; | 64 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).profile = true; |
| 52 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).param = ... | 65 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).param = ... |
| 53 struct('data',SMALL.Problem.b,... | 66 struct('data',SMALL.Problem.b,... %observed data |
| 54 'Tdata',nActiveAtoms,... | 67 'Tdata',nActiveAtoms,... %active atoms |
| 55 'dictsize',dictSize,... | 68 'dictsize',dictSize,... %number of atoms |
| 56 'iternum',iterNum,... | 69 'iternum',iterNum,... %number of iterations |
| 57 'memusage','high',... | 70 'memusage','high',... %memory usage |
| 58 'solver',solver,... | 71 'solver',solver,... %sparse approx solver |
| 59 'initdict',initDicts(iInitDicts),... | 72 'initdict',initDicts(iInitDicts),...%initial dictionary |
| 60 'zeta',zeta(iZetas),... | 73 'zeta',zeta(iZetas),... %incoherence penalty factor |
| 61 'eta',eta(iEtas)); | 74 'eta',eta(iEtas)); %unit norm penalty factor |
| 75 %learn dictionary | |
| 62 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas) = ... | 76 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas) = ... |
| 63 SMALL_learn(SMALL.Problem,SMALL.DL(iTrial,iInitDicts,iZetas,iEtas)); | 77 SMALL_learn(SMALL.Problem,SMALL.DL(iTrial,iInitDicts,iZetas,iEtas)); |
| 64 end | 78 end |
| 65 end | 79 end |
| 66 end | 80 end |
| 67 end | 81 end |
| 68 | 82 |
| 69 %% Evaluate coherence and snr of representation for the various methods | 83 %% Evaluate coherence and snr of representation |
| 70 sr = zeros(size(SMALL.DL)); %signal to noise ratio | 84 sr = zeros(size(SMALL.DL)); %signal to noise ratio |
| 71 mu = zeros(iTrial,iInitDicts,iZetas,iEtas); %coherence | 85 mu = zeros(iTrial,iInitDicts,iZetas,iEtas); %coherence |
| 72 dic(size(SMALL.DL)) = dictionary; %initialise dictionary objects | 86 dic(size(SMALL.DL)) = dictionary; %initialise dictionary objects |
| 73 for iTrial=1:nTrials | 87 for iTrial=1:nTrials |
| 74 for iInitDicts=1:nInitDicts | 88 for iInitDicts=1:nInitDicts |
| 75 for iZetas=1:nZetas | 89 for iZetas=1:nZetas |
| 76 for iEtas=1:nEtas | 90 for iEtas=1:nEtas |
| 77 %Sparse representation | 91 %Sparse representation |
| 78 SMALL.Problem.A = SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).D; | 92 SMALL.Problem.A = SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).D; |
| 79 tempSolver = SMALL_solve(SMALL.Problem,solver); | 93 tempSolver = SMALL_solve(SMALL.Problem,solver); |
| 80 %calculate snr | 94 %calculate snr |
| 81 sr(iTrial,iInitDicts,iZetas,iEtas) = ... | 95 sr(iTrial,iInitDicts,iZetas,iEtas) = ... |
| 82 snr(SMALL.Problem.b,SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).D*tempSolver.solution); | 96 snr(SMALL.Problem.b,... |
| 97 SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).D*tempSolver.solution); | |
| 83 %calculate mu | 98 %calculate mu |
| 84 dic(iTrial,iInitDicts,iZetas,iEtas) = ... | 99 dic(iTrial,iInitDicts,iZetas,iEtas) = ... |
| 85 dictionary(SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).D); | 100 dictionary(SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).D); |
| 86 mu(iTrial,iInitDicts,iZetas,iEtas) = ... | 101 mu(iTrial,iInitDicts,iZetas,iEtas) = ... |
| 87 dic(iTrial,iInitDicts,iZetas,iEtas).coherence; | 102 dic(iTrial,iInitDicts,iZetas,iEtas).coherence; |
| 88 end | 103 end |
| 89 end | 104 end |
| 90 end | 105 end |
| 91 end | 106 end |
| 92 | 107 |
| 93 save('MOCOD.mat') | |
| 94 | |
| 95 %% Plot results | 108 %% Plot results |
| 96 minMu = sqrt((dictSize-blockSize)/(blockSize*(dictSize-1))); %lowe bound on coherence | 109 minMu = sqrt((dictSize-blockSize)/(blockSize*(dictSize-1)));%lower bound on coherence |
| 97 initDictsNames = {'Data','Gabor'}; | 110 initDictsNames = {'Data','Gabor'}; %names of initial dictionaries |
| 98 lineStyles = {'k.-','r*-','b+-'}; | 111 lineStyles = {'k.-','r*-','b+-'}; |
| 99 for iInitDict=1:nInitDicts | 112 for iInitDict=1:nInitDicts |
| 100 figure, hold on, grid on | 113 figure, hold on, grid on |
| 101 title([initDictsNames{iInitDict} ' Initialisation']); | 114 %print initial dictionary as figure title |
| 115 DisplayFigureTitle([initDictsNames{iInitDict} ' Initialisation']); | |
| 116 % set(gcf,'Units','Normalized'); | |
| 117 % txh = annotation(gcf,'textbox',[0.4,0.95,0.2,0.05]); | |
| 118 % set(txh,'String',[initDictsNames{iInitDict} ' Initialisation'],... | |
| 119 % 'LineStyle','none','HorizontalAlignment','center'); | |
| 120 %calculate mean coherence levels and SNRs over trials | |
| 102 coherenceLevels = squeeze(mean(mu(:,iInitDict,:,:),1)); | 121 coherenceLevels = squeeze(mean(mu(:,iInitDict,:,:),1)); |
| 103 meanSNRs = squeeze(mean(sr(:,iInitDict,:,:),1)); | 122 meanSNRs = squeeze(mean(sr(:,iInitDict,:,:),1)); |
| 104 %stdSNRs = squeeze(std(sr(:,iInitDict,iZetas,iEtas),0,1)); | 123 % stdSNRs = squeeze(std(sr(:,iInitDict,iZetas,iEtas),0,1)); |
| 124 %plot coherence levels | |
| 105 subplot(2,2,1) | 125 subplot(2,2,1) |
| 106 surf(eta,zeta,coherenceLevels); | 126 surf(eta,zeta,coherenceLevels); |
| 107 set(gca,'Xscale','log','Yscale','log','ZLim',[0 1.4]); | 127 set(gca,'Xscale','log','Yscale','log','ZLim',[0 1.4]); |
| 108 view(gca,130,20) | 128 view(gca,130,20) |
| 109 xlabel('\eta'); | 129 xlabel('\eta'); |
| 110 ylabel('\zeta'); | 130 ylabel('\zeta'); |
| 111 zlabel('\mu'); | 131 zlabel('\mu'); |
| 112 title('Coherence') | 132 title('Coherence') |
| 113 | 133 |
| 134 %plot SNRs | |
| 114 subplot(2,2,2) | 135 subplot(2,2,2) |
| 115 surf(eta,zeta,meanSNRs); | 136 surf(eta,zeta,meanSNRs); |
| 116 set(gca,'Xscale','log','Yscale','log','ZLim',[0 25]); | 137 set(gca,'Xscale','log','Yscale','log','ZLim',[0 25]); |
| 117 view(gca,130,20) | 138 view(gca,130,20) |
| 118 xlabel('\eta'); | 139 xlabel('\eta'); |
| 119 ylabel('\zeta'); | 140 ylabel('\zeta'); |
| 120 zlabel('SNR (dB)'); | 141 zlabel('SNR (dB)'); |
| 121 title('Reconstruction Error') | 142 title('Reconstruction Error') |
| 122 | 143 |
| 144 %plot mu/SNR scatter | |
| 123 subplot(2,2,[3 4]) | 145 subplot(2,2,[3 4]) |
| 124 mus = mu(:,iInitDict,:,:); | 146 mus = mu(:,iInitDict,:,:); |
| 125 mus = mus(:); | 147 mus = mus(:); |
| 126 SNRs = sr(:,iInitDict,:,:); | 148 SNRs = sr(:,iInitDict,:,:); |
| 127 SNRs = SNRs(:); | 149 SNRs = SNRs(:); |
