annotate DL/two-step DL/SMALL_two_step_DL.m @ 210:f12a476a4977 luisf_dev

Added help comments to SMALL_two_step_DL.m
author bmailhe
date Wed, 21 Mar 2012 17:25:40 +0000
parents dfa795944aae
children fd0b5d36f6ad
rev   line source
function DL=SMALL_two_step_DL(Problem, DL)

%% DL=SMALL_two_step_DL(Problem, DL) learn a dictionary using two_step_DL
% Alternates sparse coding of the training signals (SMALL_solve) with a
% dictionary update (dico_update) for a fixed number of iterations.
%
% Inputs:
%   Problem.b : training signals, one signal per column.
%   DL        : dictionary-learning structure. DL.param.solver is the
%               sparse-coding solver structure passed to SMALL_solve.
% Output:
%   DL.D      : the learned dictionary, one atom per column.
%
% The specific parameters of the DL structure are:
% -name: can be either 'ols', 'opt', 'MOD', 'KSVD' or 'LGD'.
% -param.learningRate: a step size used by 'ols' and 'opt'. Default: 0.1
% for 'ols', 1 otherwise.
% -param.flow: can be either 'sequential' or 'parallel'. Default:
% 'sequential'. Not used by MOD.
% -param.coherence: a real number between 0 and 1. If present, then
% a low-coherence constraint is added to the learning.
% -param.initdict: either an initial dictionary (one atom per column) or a
% vector of indices of the training signals used as initial atoms. If
% absent, the initial atoms are drawn at random among the training
% signals whose squared norm exceeds 1e-6.
% -param.dictsize: number of atoms. Overrides the size implied by
% initdict; required when initdict is absent.
% -param.iternum: number of iterations. Default: 40.
% -param.show_dict: if present, the dictionary is displayed every
% show_dict iterations.
%
% See dico_update.m for more details.

% sparse-coding solver used at each iteration
solver = DL.param.solver;

% type of dictionary update ('KSVD', 'MOD', 'ols', 'opt' or 'LGD')
typeUpdate = DL.name;

sig = Problem.b;

% determine dictionary size; an explicit dictsize supersedes the size
% implied by initdict
hasInitDict = isfield(DL.param,'initdict');
if hasInitDict
    initIsIndex = is_index_vector(DL.param.initdict);
end
if isfield(DL.param,'dictsize')
    dictsize = DL.param.dictsize;
elseif hasInitDict
    if initIsIndex
        dictsize = length(DL.param.initdict);
    else
        dictsize = size(DL.param.initdict,2);
    end
else
    error('SMALL:two_step_DL:noDictSize', ...
        'Either DL.param.initdict or DL.param.dictsize must be provided');
end

if (size(sig,2) < dictsize)
    error('Number of training signals is smaller than number of atoms to train');
end

% initialize the dictionary
if hasInitDict
    if initIsIndex
        % initdict lists the training signals to use as initial atoms
        if length(DL.param.initdict) < dictsize
            error('SMALL:two_step_DL:initDictTooShort', ...
                'DL.param.initdict contains fewer indices than dictsize');
        end
        dico = sig(:,DL.param.initdict(1:dictsize));
    else
        if (size(DL.param.initdict,1)~=size(sig,1) || size(DL.param.initdict,2)<dictsize)
            error('Invalid initial dictionary');
        end
        dico = DL.param.initdict(:,1:dictsize);
    end
else
    % draw random training signals, avoiding (near-)zero ones
    data_ids = find(colnorms_squared(sig) > 1e-6);
    if numel(data_ids) < dictsize
        error('SMALL:two_step_DL:tooFewSignals', ...
            'Only %d non-zero training signals available for %d atoms', ...
            numel(data_ids), dictsize);
    end
    perm = randperm(length(data_ids));
    dico = sig(:,data_ids(perm(1:dictsize)));
end

% flow: 'sequential' or 'parallel'. If sequential, the residual is updated
% after each atom update. If parallel, the residual is only updated once
% the whole dictionary has been computed. Sequential works better, there
% may be no need to implement parallel. Not used with MOD.
if isfield(DL.param,'flow')
    flow = DL.param.flow;
else
    flow = 'sequential';
end

% learningRate. If the type is 'ols', it is the descent step of
% the gradient (default: 0.1). If the type is 'opt', the
% descent step is the optimal step*rho (default: 1, although 2 works
% better). Not used for MOD and KSVD.
if isfield(DL.param,'learningRate')
    learningRate = DL.param.learningRate;
elseif strcmp(typeUpdate, 'ols')
    learningRate = 0.1;
else
    learningRate = 1;
end

% number of iterations (default is 40)
if isfield(DL.param,'iternum')
    iternum = DL.param.iternum;
else
    iternum = 40;
end

% determine if we should do decorrelation in every iteration
if isfield(DL.param,'coherence')
    decorrelate = 1;
    mu = DL.param.coherence;
else
    decorrelate = 0;
end

% show dictionary every specified number of iterations
if isfield(DL.param,'show_dict')
    show_dictionary=1;
    show_iter=DL.param.show_dict;
else
    show_dictionary=0;
    show_iter=0;
end

% In this version of the software Problem.b1 normally stores the signal
% to be represented (for example the whole image), whereas dictionary
% learning needs the sparse representation of the training set. Problem
% is a local copy of the caller's structure, so overwriting b1 here does
% not need to be undone afterwards.
Problem.b1 = sig;
if isfield(Problem,'reconstruct')
    Problem = rmfield(Problem, 'reconstruct');
end
solver.profile = 0;

% main loop: sparse coding followed by dictionary update
for i = 1:iternum
    Problem.A = dico;
    solver = SMALL_solve(Problem, solver);
    [dico, solver.solution] = dico_update(dico, sig, solver.solution, ...
        typeUpdate, flow, learningRate);
    if (decorrelate)
        dico = dico_decorr_symetric(dico, mu, solver.solution);
    end

    if ((show_dictionary)&&(mod(i,show_iter)==0))
        % NOTE(review): display assumes 8x8 patches on a square grid
        nGrid = round(sqrt(size(dico,2)));
        dictimg = SMALL_showdict(dico,[8 8],nGrid,nGrid,'lines','highcontrast');
        figure(2); imagesc(dictimg);colormap(gray);axis off; axis image;
        pause(0.02);
    end
end

DL.D = dico;

end

function tf = is_index_vector(initdict)
% True when initdict is a vector whose entries all pass iswhole (presumably
% whole-number values), i.e. it lists indices of training signals rather
% than holding dictionary atoms.
tf = any(size(initdict)==1) && all(iswhole(initdict(:)));
end
ivan@153 148
function Y = colnorms_squared(X)
% Y = colnorms_squared(X) returns a 1-by-size(X,2) row vector whose j-th
% entry is the squared Euclidean norm of column j of X.
% Columns are processed in blocks of 2000 to limit the size of the
% temporary array holding the squared entries.

Y = zeros(1,size(X,2));
blocksize = 2000;
for i = 1:blocksize:size(X,2)
    blockids = i : min(i+blocksize-1,size(X,2));
    % sum along dimension 1 explicitly: without it a single-row X would be
    % summed across its columns and the total broadcast to every entry.
    % abs() makes the result correct for complex X and is exact for real X.
    Y(blockids) = sum(abs(X(:,blockids)).^2, 1);
end

end