Mercurial > hg > smallbox
diff DL/two-step DL/SMALL_two_step_DL.m @ 152:485747bf39e0 ivand_dev
Two step dictonary learning - Integration of the code for dictionary update and dictionary decorrelation from Boris Mailhe
author | Ivan Damnjanovic lnx <ivan.damnjanovic@eecs.qmul.ac.uk> |
---|---|
date | Thu, 28 Jul 2011 15:49:32 +0100 |
parents | |
children | af307f247ac7 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/DL/two-step DL/SMALL_two_step_DL.m Thu Jul 28 15:49:32 2011 +0100 @@ -0,0 +1,127 @@ +function DL=SMALL_two_step_DL(Problem, DL) + +% determine which solver is used for sparse representation % + +solver = DL.param.solver; + +% determine which type of udate to use ('KSVD', 'MOD', 'ols' or 'mailhe') % + +typeUpdate = DL.name; + +sig = Problem.b; + +% determine dictionary size % + +if (isfield(DL.param,'initdict')) + if (any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:)))) + dictsize = length(DL.param.initdict); + else + dictsize = size(DL.param.initdict,2); + end +end +if (isfield(DL.param,'dictsize')) % this superceedes the size determined by initdict + dictsize = DL.param.dictsize; +end + +if (size(sig,2) < dictsize) + error('Number of training signals is smaller than number of atoms to train'); +end + + +% initialize the dictionary % + +if (isfield(DL.param,'initdict')) + if (any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:)))) + dico = sig(:,DL.param.initdict(1:dictsize)); + else + if (size(DL.param.initdict,1)~=size(sig,1) || size(DL.param.initdict,2)<dictsize) + error('Invalid initial dictionary'); + end + dico = DL.param.initdict(:,1:dictsize); + end +else + data_ids = find(colnorms_squared(sig) > 1e-6); % ensure no zero data elements are chosen + perm = randperm(length(data_ids)); + dico = sig(:,data_ids(perm(1:dictsize))); +end + +% flow: 'sequential' or 'parallel'. If sequential, the residual is updated +% after each atom update. If parallel, the residual is only updated once +% the whole dictionary has been computed. Sequential works better, there +% may be no need to implement parallel. Not used with MOD. + +if isfield(DL.param,'flow') + flow = DL.param.flow; +else + flow = 'sequential'; +end + +% learningRate. If the type is 'ols', it is the descent step of +% the gradient (typical choice: 0.1). If the type is 'mailhe', the +% descent step is the optimal step*rho (typical choice: 1, although 2 +% or 3 seems to work better). Not used for MOD and KSVD. + +if isfield(DL.param,'learningRate') + learningRate = DL.param.learningRate; +else + learningRate = 0.1; +end + +% number of iterations (default is 40) % + +if isfield(DL.param,'iternum') + iternum = DL.param.iternum; +else + iternum = 40; +end +% determine if we should do decorrelation in every iteration % + +if isfield(DL.param,'coherence') + decorrelate = 1; + mu = DL.param.coherence; +else + decorrelate = 0; +end + +% show dictonary every specified number of iterations + +if (isfield(DL.param,'show_dict')) + show_dictionary=1; + show_iter=DL.param.show_dict; +else + show_dictionary=0; + show_iter=0; +end + +% This is a small patch that needs to be resolved in dictionary learning we +% want sparse representation of training set, and in Problem.b1 in this +% version of software we store the signal that needs to be represented +% (for example the whole image) + +tmpTraining = Problem.b1; +Problem.b1 = sig; +Problem = rmfield(Problem, 'reconstruct'); +solver.profile = 0; + +% main loop % + +for i = 1:iternum + solver = SMALL_solve(Problem, solver); + [dico, solver.solution] = dico_update(dico, sig, solver.solution, ... + typeUpdate, flow, learningRate); + if (decorrelate) + dico = dico_decorr(dico, mu, solver.solution); + end + Problem.A = dico; + if ((show_dictionary)&&(mod(i,show_iter)==0)) + dictimg = SMALL_showdict(dico,[8 8],... + round(sqrt(size(dico,2))),round(sqrt(size(dico,2))),'lines','highcontrast'); + figure(2); imagesc(dictimg);colormap(gray);axis off; axis image; + pause(0.02); + end +end + +Problem.b1 = tmpTraining; +DL.D = dico; + +end \ No newline at end of file