annotate DL/two-step DL/SMALL_two_step_DL.m @ 160:e3035d45d014 danieleb

Added support classes
author Daniele Barchiesi <daniele.barchiesi@eecs.qmul.ac.uk>
date Wed, 31 Aug 2011 10:53:10 +0100
parents a4d0977d4595
children 88578ec2f94a
rev   line source
function DL=SMALL_two_step_DL(Problem, DL)
%SMALL_TWO_STEP_DL Learn a dictionary by alternating sparse coding and dictionary update.
%   DL = SMALL_TWO_STEP_DL(Problem, DL) trains a dictionary on the columns
%   of Problem.b and returns it in DL.D. Each iteration:
%     1. sparse-codes the training set with SMALL_solve, using the solver
%        structure DL.param.solver and the current dictionary as Problem.A;
%     2. updates the dictionary with dico_update, using the update type
%        given by DL.name ('KSVD', 'MOD', 'ols' or 'mailhe');
%     3. applies normcols to the dictionary;
%     4. optionally decorrelates the atoms (see decFcn below).
%
%   Fields read from DL.param:
%     solver       - solver structure passed to SMALL_solve (required).
%     initdict     - either a vector of integer column indices into
%                    Problem.b, or an initial dictionary with size(Problem.b,1)
%                    rows. If absent, atoms are drawn at random from the
%                    non-zero training signals.
%     dictsize     - number of atoms; overrides the size implied by initdict.
%                    At least one of initdict / dictsize must be given.
%     flow         - 'sequential' (default) or 'parallel'; passed to dico_update.
%     learningRate - step passed to dico_update (default 0.1).
%     iternum      - number of iterations (default 40).
%     coherence    - scalar target coherence mu, used by 'mailhe' decorrelation.
%     decFcn       - 'mailhe' (dico_decorr), 'tropp' (grassmanian) or any
%                    other value for no decorrelation. If absent, 'mailhe'
%                    is used when a scalar coherence is given and no
%                    decorrelation is applied otherwise.
%     show_dict    - if present, the dictionary is shown (as 8x8 patches)
%                    every show_dict iterations.

param = DL.param;

% solver used for the sparse representation step
solver = param.solver;

% type of dictionary update to use ('KSVD', 'MOD', 'ols' or 'mailhe')
typeUpdate = DL.name;

sig = Problem.b;

% determine the dictionary size and build the initial dictionary
dico = init_dictionary(param, sig);

% flow: 'sequential' or 'parallel'. If sequential, the residual is updated
% after each atom update. If parallel, the residual is only updated once
% the whole dictionary has been computed. Sequential works better, there
% may be no need to implement parallel. Not used with MOD.
flow = param_or_default(param, 'flow', 'sequential');

% learningRate. If the type is 'ols', it is the descent step of
% the gradient (typical choice: 0.1). If the type is 'mailhe', the
% descent step is the optimal step*rho (typical choice: 1, although 2
% or 3 seems to work better). Not used for MOD and KSVD.
learningRate = param_or_default(param, 'learningRate', 0.1);

% number of iterations (default is 40)
iternum = param_or_default(param, 'iternum', 40);

% decorrelation: a scalar coherence enables the 'mailhe' step by default;
% decFcn, when given, selects the method explicitly
decorrelate = isfield(param,'coherence') && isscalar(param.coherence);
if decorrelate
    mu = param.coherence;
else
    mu = [];
end
if isfield(param,'decFcn')
    decFcn = param.decFcn;
elseif decorrelate
    decFcn = 'mailhe';
else
    decFcn = 'none';
end
if strcmp(decFcn,'mailhe') && ~decorrelate
    % dico_decorr needs the target coherence; fail before any work is done
    error('Decorrelation ''mailhe'' requires a scalar DL.param.coherence');
end

% show the dictionary every show_iter iterations
show_dictionary = isfield(param,'show_dict');
show_iter = param_or_default(param, 'show_dict', 0);

% In this version of the software Problem.b1 holds the signal to be
% represented (e.g. the whole image), while during learning we want the
% sparse representation of the training set, so point b1 at it.
% Problem is a local copy, so the caller's Problem is not modified.
Problem.b1 = sig;
if isfield(Problem,'reconstruct')
    Problem = rmfield(Problem, 'reconstruct');
end
solver.profile = 0;

% main loop
for i = 1:iternum
    disp([num2str(i) '/' num2str(iternum)]);
    Problem.A = dico;
    solver = SMALL_solve(Problem, solver);
    [dico, solver.solution] = dico_update(dico, sig, solver.solution, ...
        typeUpdate, flow, learningRate);
    dico = normcols(dico);
    switch decFcn
        case 'mailhe'
            dico = dico_decorr(dico, mu, solver.solution);
        case 'tropp'
            [n, m] = size(dico);
            dico = grassmanian(n, m, [], [], [], dico, true);
        otherwise
            % no decorrelation
    end

    if show_dictionary && mod(i, show_iter) == 0
        % NOTE(review): atoms are assumed to be 8x8 patches here
        nside = round(sqrt(size(dico,2)));
        dictimg = SMALL_showdict(dico, [8 8], nside, nside, 'lines', 'highcontrast');
        figure(2); imagesc(dictimg); colormap(gray); axis off; axis image;
        pause(0.02);
    end
end

DL.D = dico;

end

function dico = init_dictionary(param, sig)
% Build the initial dictionary from param.initdict / param.dictsize and the
% training signals sig (one signal per column).

hasInit = isfield(param,'initdict');
% an integer-valued vector initdict is a list of column indices into sig
isIndexInit = hasInit && any(size(param.initdict)==1) && all(iswhole(param.initdict(:)));

if isfield(param,'dictsize') % this supersedes the size determined by initdict
    dictsize = param.dictsize;
elseif isIndexInit
    dictsize = length(param.initdict);
elseif hasInit
    dictsize = size(param.initdict,2);
else
    error('Either DL.param.initdict or DL.param.dictsize must be specified');
end

if (size(sig,2) < dictsize)
    error('Number of training signals is smaller than number of atoms to train');
end

if isIndexInit
    dico = sig(:,param.initdict(1:dictsize));
elseif hasInit
    if (size(param.initdict,1)~=size(sig,1) || size(param.initdict,2)<dictsize)
        error('Invalid initial dictionary');
    end
    dico = param.initdict(:,1:dictsize);
else
    data_ids = find(colnorms_squared(sig) > 1e-6); % ensure no zero data elements are chosen
    if length(data_ids) < dictsize
        error('Not enough non-zero training signals to initialise the dictionary');
    end
    perm = randperm(length(data_ids));
    dico = sig(:,data_ids(perm(1:dictsize)));
end

end

function value = param_or_default(param, name, default)
% Return param.(name) when that field exists, otherwise default.
if isfield(param, name)
    value = param.(name);
else
    value = default;
end
end
ivan@153 138
function Y = colnorms_squared(X)
%COLNORMS_SQUARED Squared Euclidean norm of every column of X.
%   Y = COLNORMS_SQUARED(X) returns a 1-by-size(X,2) row vector with
%   Y(j) = sum(X(:,j).^2). Columns are processed in blocks of 2000 so the
%   temporary X(:,blockids).^2 stays small (to conserve memory).

Y = zeros(1,size(X,2));
blocksize = 2000;
for i = 1:blocksize:size(X,2)
    blockids = i : min(i+blocksize-1,size(X,2));
    % Sum along dimension 1 explicitly: for a single-row X the default sum
    % would run along the columns and collapse the whole block to a scalar.
    Y(blockids) = sum(X(:,blockids).^2, 1);
end

end