view examples/SMALL_test_mocod.m @ 193:cc540df790f4 danieleb

Simple example that demonstrates dictionary learning (to be completed)
author Daniele Barchiesi <daniele.barchiesi@eecs.qmul.ac.uk>
date Fri, 09 Mar 2012 15:12:01 +0000
parents 0dc98f1c60bb
children
line wrap: on
line source
%% SMALL_test_mocod
% Script that tests the two-step dictionary learning algorithm with the
% MOCOD dictionary update. It sweeps a grid of incoherence penalties (zeta)
% and unit-norm penalties (eta), for several trials and initial
% dictionaries, then reports the coherence and the reconstruction SNR of
% every learned dictionary.
%
% REFERENCES
% D. Barchiesi and M. D. Plumbley, Learning incoherent dictionaries for
% sparse approximation using iterative projections and rotations.
%% Clear and close
clc, clear, close all

%% Parameters
nTrials = 10;							% how many times the whole experiment is repeated

% --- Dictionary learning settings ---
toolbox   = 'TwoStepDL';				% dictionary learning toolbox
dicUpdate = 'mocod';					% dictionary update rule
zeta      = logspace(-2,2,10);			% 10 log-spaced incoherence penalties in [1e-2, 1e2]
eta       = logspace(-2,2,10);			% 10 log-spaced unit-norm penalties in [1e-2, 1e2]
iterNum   = 20;							% dictionary learning iterations
epsilon   = 1e-6;						% tolerance handed to the OMP solver
dictSize  = 512;						% atoms in the learned dictionary
percActiveAtoms = 5;					% active atoms, as a percentage of blockSize

% --- Test signal settings ---
signal    = audio('music03_16kHz.wav');	% audio object loaded from the wav file
blockSize = 256;						% frame length in samples
overlap   = 0.5;						% fraction of each frame shared with the next one

% --- Derived settings ---
% non-zero coefficients allowed per frame (here fix(12.8) = 12)
nActiveAtoms = fix(blockSize/100*percActiveAtoms);

% --- Initial dictionaries ---
% Gabor dictionary over frames of blockSize samples, redundancy factor 2,
% rectangular window
gaborParam.N                = blockSize;
gaborParam.redundancyFactor = 2;
gaborParam.wd               = @rectwin;
gaborDict = Gabor_Dictionary(gaborParam);
% Initial dictionaries to compare. The empty entry is labelled 'Data' in the
% plots; presumably the toolbox then initialises from the data itself - TODO confirm
initDicts = {[], gaborDict};

%% Generate audio approximation problem
% Cut the audio into frames of blockSize samples, consecutive frames sharing
% blockSize*overlap samples (@rectwin passed as the window function). The
% frames end up as the columns of signal.S, which becomes the training data.
signal = buffer(signal, blockSize, blockSize*overlap, @rectwin);
SMALL.Problem.b = signal.S;

% The test set b1 is an identical copy of the training set; later
% functions expect it to be present.
SMALL.Problem.b1 = SMALL.Problem.b;

% Sparse coding solver: omp2 from ompbox, run on the training data with at
% most nActiveAtoms atoms per frame.
ompParam.X        = SMALL.Problem.b;
ompParam.epsilon  = epsilon;
ompParam.maxatoms = nActiveAtoms;
solver = SMALL_init_solver('ompbox', 'omp2', ompParam, false);

%% Test
nInitDicts = length(initDicts);	% how many initial dictionaries are compared
nZetas     = length(zeta);		% how many incoherence penalty values
nEtas      = length(eta);		% how many unit-norm penalty values

% Grow the struct array of dictionary learning problems to its final size,
% indexed as (trial, initial dictionary, zeta, eta), by assigning its last element.
SMALL.DL(nTrials,nInitDicts,nZetas,nEtas) = SMALL_init_DL(toolbox);

% Configure and learn one dictionary for every combination of trial,
% initial dictionary, incoherence penalty and unit-norm penalty.
for iTrial = 1:nTrials
	for iInitDicts = 1:nInitDicts
		for iZetas = 1:nZetas
			for iEtas = 1:nEtas
				cellIdx = {iTrial, iInitDicts, iZetas, iEtas};	% subscripts of the current problem
				dlParam = struct( ...
					'data',     SMALL.Problem.b, ...		% observed data
					'Tdata',    nActiveAtoms, ...			% active atoms per frame
					'dictsize', dictSize, ...				% number of atoms
					'iternum',  iterNum, ...				% number of iterations
					'memusage', 'high', ...					% memory usage
					'solver',   solver, ...					% sparse approximation solver
					'initdict', initDicts(iInitDicts), ...	% 1x1 cell: struct() stores its content
					'zeta',     zeta(iZetas), ...			% incoherence penalty factor
					'eta',      eta(iEtas));				% unit-norm penalty factor
				SMALL.DL(cellIdx{:}).toolbox = toolbox;
				SMALL.DL(cellIdx{:}).name    = dicUpdate;
				SMALL.DL(cellIdx{:}).profile = true;
				SMALL.DL(cellIdx{:}).param   = dlParam;
				% learn the dictionary and overwrite the problem with the result
				SMALL.DL(cellIdx{:}) = SMALL_learn(SMALL.Problem, SMALL.DL(cellIdx{:}));
			end
		end
	end
end

%% Evaluate coherence and snr of representation
% For every learned dictionary: sparse-code the data with it, measure the SNR
% of the reconstruction, and measure the coherence of the dictionary.
% All result arrays are preallocated with explicit dimensions.
% dic(size(...)) = ... would index the linear positions listed in the size
% vector instead of creating an array of that size.
sr = zeros(nTrials,nInitDicts,nZetas,nEtas);			%signal to noise ratio
mu = zeros(nTrials,nInitDicts,nZetas,nEtas);			%coherence
dic(nTrials,nInitDicts,nZetas,nEtas) = dictionary;	%dictionary objects
for iTrial=1:nTrials
	for iInitDicts=1:nInitDicts
		for iZetas=1:nZetas
			for iEtas=1:nEtas
				D = SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).D;	%learned dictionary
				%sparse representation of the data in the learned dictionary
				SMALL.Problem.A = D;
				tempSolver = SMALL_solve(SMALL.Problem,solver);
				%snr of the reconstruction D*X with respect to the data
				%NOTE(review): assumes snr(x,xhat) is SMALLbox's reconstruction SNR,
				%not the Signal Processing Toolbox snr(x,noise) - confirm on path
				sr(iTrial,iInitDicts,iZetas,iEtas) = ...
					snr(SMALL.Problem.b,D*tempSolver.solution);
				%coherence of the learned dictionary
				dic(iTrial,iInitDicts,iZetas,iEtas) = dictionary(D);
				mu(iTrial,iInitDicts,iZetas,iEtas) = ...
					dic(iTrial,iInitDicts,iZetas,iEtas).coherence;
			end
		end
	end
end

%% Plot results
% Welch lower bound on the coherence of dictSize unit-norm atoms in
% dimension blockSize
minMu = sqrt((dictSize-blockSize)/(blockSize*(dictSize-1)));
initDictsNames = {'Data','Gabor'};	%names of initial dictionaries, in the order of initDicts
muLim  = [0 1.4];					%axis range used for coherence
snrLim = [0 25];					%axis range used for SNR (dB)
for iInitDict=1:nInitDicts
	figure, hold on, grid on
	%print initial dictionary as figure title
	DisplayFigureTitle([initDictsNames{iInitDict} ' Initialisation']);
	%mean coherence levels and SNRs over trials (rows: zeta, columns: eta)
	coherenceLevels = squeeze(mean(mu(:,iInitDict,:,:),1));
	meanSNRs        = squeeze(mean(sr(:,iInitDict,:,:),1));
	
	%plot coherence levels over the (eta,zeta) grid
	subplot(2,2,1)
	surf(eta,zeta,coherenceLevels);
	set(gca,'Xscale','log','Yscale','log','ZLim',muLim);
	view(gca,130,20)
	xlabel('\eta');
	ylabel('\zeta');
	zlabel('\mu');
	title('Coherence')
	
	%plot SNRs over the (eta,zeta) grid
	subplot(2,2,2)
	surf(eta,zeta,meanSNRs);
	set(gca,'Xscale','log','Yscale','log','ZLim',snrLim);
	view(gca,130,20)
	xlabel('\eta');
	ylabel('\zeta');
	zlabel('SNR (dB)');
	title('Reconstruction Error')
	
	%plot mu/SNR scatter of every trial and parameter pair
	subplot(2,2,[3 4])
	mus  = mu(:,iInitDict,:,:);
	SNRs = sr(:,iInitDict,:,:);
	plot([1 1],snrLim,'k')				%mu = 1, labelled mu_max
	hold on, grid on
	scatter(mus(:),SNRs(:),'k+');
	plot([minMu minMu],snrLim,'k--')	%Welch bound, labelled mu_min
	set(gca,'YLim',snrLim,'XLim',muLim);
	xlabel('\mu');
	ylabel('SNR (dB)');
	legend({'\mu_{max}',upper(dicUpdate),'\mu_{min}'});
	title('Coherence-Reconstruction Error Tradeoff')
end