# HG changeset patch # User Daniele Barchiesi # Date 1321528705 0 # Node ID 714fa7b8c1ad9e018397581e2568f4eb75401a30 # Parent d0645d5fca7d451b52951ca15928a89a716aa051 added ramirez dl (to be completed) and MOCOD dictionary update diff -r d0645d5fca7d -r 714fa7b8c1ad DL/dl_ramirez.m --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/DL/dl_ramirez.m Thu Nov 17 11:18:25 2011 +0000 @@ -0,0 +1,189 @@ +function DL = dl_ramirez(Problem,DL) +%% Dictionary learning with incoherent dictionary +% +% REFERENCE +% I. Ramirez, F. Lecumberry and G. Sapiro, Sparse modeling with universal +% priors and learned incoherent dictionaries. + +%% +% Centre for Digital Music, Queen Mary, University of London. +% This file copyright 2011 Daniele Barchiesi. +% +% This program is free software; you can redistribute it and/or +% modify it under the terms of the GNU General Public License as +% published by the Free Software Foundation; either version 2 of the +% License, or (at your option) any later version. See the file +% COPYING included with this distribution for more information. + +%% Test function +if ~nargin, testdl_ramirez; return; end + +%% Parameters & Defaults +X = Problem.b; %matrix of observed signals + +% determine dictionary size % +if (isfield(DL.param,'initdict')) %if the dictionary has been initialised + if (any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:)))) + dictSize = length(DL.param.initdict); + else + dictSize = size(DL.param.initdict,2); + end +end +if (isfield(DL.param,'dictsize')) + dictSize = DL.param.dictsize; +end + +if (size(X,2) < dictSize) + error('Number of training signals is smaller than number of atoms to train'); +end + + +% initialize the dictionary % +if (isfield(DL.param,'initdict')) && ~isempty(DL.param.initdict); + if (any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:)))) + D = X(:,DL.param.initdict(1:dictSize)); + else + if (size(DL.param.initdict,1)~=size(X,1) || size(DL.param.initdict,2) 1e-6); % ensure no zero data elements are chosen + perm = randperm(length(data_ids)); + D = X(:,data_ids(perm(1:dictSize))); +end + + +% coherence penalty factor +if isfield(DL.param,'zeta') + zeta = DL.param.zeta; +else + zeta = 0.1; +end + +% atoms norm penalty factor +if isfield(DL.param,'eta') + eta = DL.param.eta; +else + eta = 0.1; +end + +% number of iterations (default is 40) % +if isfield(DL.param,'iternum') + iternum = DL.param.iternum; +else + iternum = 40; +end + +% show dictonary every specified number of iterations +if isfield(DL.param,'show_dict') + show_dictionary=1; + show_iter=DL.param.show_dict; +else + show_dictionary=0; + show_iter=0; +end + +tmpTraining = Problem.b1; +Problem.b1 = X; +if isfield(Problem,'reconstruct') + Problem = rmfield(Problem, 'reconstruct'); +end + + +%% Main Algorithm +Dprev = D; %initial dictionary +Aprev = D\X; %set initial solution as pseudoinverse +for i = 1:iternum + %Sparse Coding by + A = sparsecoding(X,D,Aprev); + %Dictionary Update + D = dictionaryupdate(X,A,Dprev,zeta,eta); + + Dprev = D; + Aprev = A; + if ((show_dictionary)&&(mod(i,show_iter)==0)) + dictimg = SMALL_showdict(dico,[8 8],... + round(sqrt(size(dico,2))),round(sqrt(size(dico,2))),'lines','highcontrast'); + figure(2); imagesc(dictimg);colormap(gray);axis off; axis image; + pause(0.02); + end +end + +Problem.b1 = tmpTraining; +DL.D = D; + +end + +function A = sparsecoding(X,D,Aprev) +%Sparse coding using a mixture of laplacians (MOL) as universal prior. + +%parameters +K = size(D,2); %number of atoms +M = size(X,2); %number of signals + +mu1 = mean(abs(Aprev(:))); %first moment of distribution of Aprev +mu2 = (norm(Aprev(:))^2)/numel(Aprev);%second moment of distribution of Aprev +kappa = 2*(mu2-mu1^2)/(mu2-2*mu2^2); %parameter kappa of the MOL distribution +beta = (kappa-1)*mu1; %parameter beta of the MOL distribution + +E = X-D*Aprev; %error term +sigmasq = mean(var(E)); %error variance +tau = 2*sigmasq*(kappa+1); %sparsity factor + +%solve a succession of subproblems to approximate the non-convex cost +%function +nIter = 10; %number of iterations of surrogate subproblem +Psi = zeros(K,M); %initialise solution of subproblem +for iIter=1:nIter + Reg = 1./(abs(Psi) + beta); + Psi = solvel1(X,D,tau,Reg); +end +A = Psi; +end + +function Psi = solvel1(X,D,tau,A) + [K M] = size(A); + Psi = zeros(K,M); + for m=1:M + cvx_begin quiet + variable v(K) + minimise (norm(X(:,m)-D*v) + tau*norm(A(:,m).*v,1)); + cvx_end + Psi(:,m) = v; + end +end + +function D = dictionaryupdate(X,A,Dprev,zeta,eta) + D = (X*A' + 2*(zeta + eta)*Dprev)/(A*A' + 2*zeta*(Dprev'*Dprev) + 2*eta*diag(diag(Dprev'*Dprev))); +end + + + +function Y = colnorms_squared(X) +% compute in blocks to conserve memory +Y = zeros(1,size(X,2)); +blocksize = 2000; +for i = 1:blocksize:size(X,2) + blockids = i : min(i+blocksize-1,size(X,2)); + Y(blockids) = sum(X(:,blockids).^2); +end +end + +function testdl_ramirez + clc + N = 10; %ambient dimension + K = 20; %number of atoms + M = 30; %number of observed signals + X = randn(N,M); %observed signals + D = normcol(randn(N,K)); %initial dictionary + Problem.b = X; %sparse representation problem + Problem.b1 = X; + DL = SMALL_init_DL('dl_ramirez'); + DL.param.initdict = D; + DL.param = struct('initdict',D,... + 'zeta',0.5,... + 'eta',0.5); + DL = SMALL_learn(Problem,DL); +end diff -r d0645d5fca7d -r 714fa7b8c1ad examples/SMALL_test_mocod.m --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/examples/SMALL_test_mocod.m Thu Nov 17 11:18:25 2011 +0000 @@ -0,0 +1,145 @@ +clc, clear, close all + +%% Parameteres +nTrials = 10; %number of trials of the experiment + +% Dictionary learning parameters +toolbox = 'TwoStepDL'; %dictionary learning toolbox +dicUpdate = 'mocod'; %dictionary learning updates +zeta = logspace(-2,2,10); +eta = logspace(-2,2,10); + +iterNum = 20; %number of iterations +epsilon = 1e-6; %tolerance level +dictSize = 512; %number of atoms in the dictionary +percActiveAtoms = 5; %percentage of active atoms + +% Test signal parameters +signal = audio('music03_16kHz.wav'); %audio signal +blockSize = 256; %size of audio frames +overlap = 0.5; %overlap between consecutive frames + +% Dependent parameters +nActiveAtoms = fix(blockSize/100*percActiveAtoms); %number of active atoms + +% Initial dictionaries +gaborParam = struct('N',blockSize,'redundancyFactor',2,'wd',@rectwin); +gaborDict = Gabor_Dictionary(gaborParam); +initDicts = {[],gaborDict}; + +%% Generate audio approximation problem +signal = buffer(signal,blockSize,blockSize*overlap,@rectwin); %buffer frames of audio into columns of the matrix S +SMALL.Problem.b = signal.S; +SMALL.Problem.b1 = SMALL.Problem.b; % copy signals from training set b to test set b1 (needed for later functions) + +% omp2 sparse representation solver +ompParam = struct('X',SMALL.Problem.b,'epsilon',epsilon,'maxatoms',nActiveAtoms); %parameters +solver = SMALL_init_solver('ompbox','omp2',ompParam,false); %solver structure + +%% Test +nInitDicts = length(initDicts); %number of initial dictionaries +nZetas = length(zeta); +nEtas = length(eta); + +SMALL.DL(nTrials,nInitDicts,nZetas,nEtas) = SMALL_init_DL(toolbox); %create dictionary learning structures +for iTrial=1:nTrials + for iInitDicts=1:nInitDicts + for iZetas=1:nZetas + for iEtas=1:nEtas + SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).toolbox = toolbox; + SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).name = dicUpdate; + SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).profile = true; + SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).param = ... + struct('data',SMALL.Problem.b,... + 'Tdata',nActiveAtoms,... + 'dictsize',dictSize,... + 'iternum',iterNum,... + 'memusage','high',... + 'solver',solver,... + 'initdict',initDicts(iInitDicts),... + 'zeta',zeta(iZetas),... + 'eta',eta(iEtas)); + SMALL.DL(iTrial,iInitDicts,iZetas,iEtas) = ... + SMALL_learn(SMALL.Problem,SMALL.DL(iTrial,iInitDicts,iZetas,iEtas)); + end + end + end +end + +%% Evaluate coherence and snr of representation for the various methods +sr = zeros(size(SMALL.DL)); %signal to noise ratio +mu = zeros(iTrial,iInitDicts,iZetas,iEtas); %coherence +dic(size(SMALL.DL)) = dictionary; %initialise dictionary objects +for iTrial=1:nTrials + for iInitDicts=1:nInitDicts + for iZetas=1:nZetas + for iEtas=1:nEtas + %Sparse representation + SMALL.Problem.A = SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).D; + tempSolver = SMALL_solve(SMALL.Problem,solver); + %calculate snr + sr(iTrial,iInitDicts,iZetas,iEtas) = ... + snr(SMALL.Problem.b,SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).D*tempSolver.solution); + %calculate mu + dic(iTrial,iInitDicts,iZetas,iEtas) = ... + dictionary(SMALL.DL(iTrial,iInitDicts,iZetas,iEtas).D); + mu(iTrial,iInitDicts,iZetas,iEtas) = ... + dic(iTrial,iInitDicts,iZetas,iEtas).coherence; + end + end + end +end + +save('MOCOD.mat') + +%% Plot results +minMu = sqrt((dictSize-blockSize)/(blockSize*(dictSize-1))); %lowe bound on coherence +initDictsNames = {'Data','Gabor'}; +lineStyles = {'k.-','r*-','b+-'}; +for iInitDict=1:nInitDicts + figure, hold on, grid on + title([initDictsNames{iInitDict} ' Initialisation']); + coherenceLevels = squeeze(mean(mu(:,iInitDict,:,:),1)); + meanSNRs = squeeze(mean(sr(:,iInitDict,:,:),1)); + %stdSNRs = squeeze(std(sr(:,iInitDict,iZetas,iEtas),0,1)); + subplot(2,2,1) + surf(eta,zeta,coherenceLevels); + set(gca,'Xscale','log','Yscale','log','ZLim',[0 1.4]); + view(gca,130,20) + xlabel('\eta'); + ylabel('\zeta'); + zlabel('\mu'); + title('Coherence') + + subplot(2,2,2) + surf(eta,zeta,meanSNRs); + set(gca,'Xscale','log','Yscale','log','ZLim',[0 25]); + view(gca,130,20) + xlabel('\eta'); + ylabel('\zeta'); + zlabel('SNR (dB)'); + title('Reconstruction Error') + + subplot(2,2,[3 4]) + mus = mu(:,iInitDict,:,:); + mus = mus(:); + SNRs = sr(:,iInitDict,:,:); + SNRs = SNRs(:); + [un idx] = sort(mus); + plot([1 1],[0 25],'k') + hold on, grid on + scatter(mus(idx),SNRs(idx),'k+'); + plot([minMu minMu],[0 25],'k--') + set(gca,'YLim',[0 25],'XLim',[0 1.4]); + xlabel('\mu'); + ylabel('SNR (dB)'); + legend([{'\mu_{max}'},'MOCOD',{'\mu_{min}'}]); + title('Coherence-Reconstruction Error Tradeoff') + +% plot([minMu minMu],[0 25],'k--') +% +% set(gca,'YLim',[0 25],'XLim',[0 1.4]); +% legend([{'\mu_{max}'},dicDecorrNames,{'\mu_{min}'}]); +% xlabel('\mu'); +% ylabel('SNR (dB)'); +end