function DL=SMALL_two_step_DL(Problem, DL)
%SMALL_TWO_STEP_DL  Two-step dictionary learning (sparse coding + update).
%
%   DL = SMALL_two_step_DL(Problem, DL) alternates, for a fixed number of
%   iterations, between (1) computing a sparse representation of the
%   training signals with the current dictionary via SMALL_solve and
%   (2) running the configuration script
%   [SMALL_path '/config/SMALL_two_step_DL_config.m'].
%
%   Inputs
%     Problem.b      training signals, one signal per column.
%     Problem.b1     read and temporarily replaced by Problem.b while the
%                    solver is called.
%     DL.name        type of dictionary update ('KSVD', 'MOD', 'ols' or
%                    'mailhe'); stored in the local variable typeUpdate.
%     DL.param.solver        solver structure handed to SMALL_solve (required).
%     DL.param.initdict      optional. Either a vector of whole numbers
%                            (indices of training signals used as initial
%                            atoms) or an explicit initial dictionary
%                            matrix with one atom per column.
%     DL.param.dictsize      optional number of atoms; overrides the size
%                            implied by initdict.
%     DL.param.flow          optional, 'sequential' (default) or 'parallel'.
%     DL.param.learningRate  optional, default 0.1.
%     DL.param.iternum       optional number of iterations, default 40.
%     DL.param.coherence     optional; when present, decorrelate is set to 1
%                            and mu to this value.
%     DL.param.show_dict     optional; when present the dictionary is
%                            displayed every show_dict iterations.
%
%   Output
%     DL.D           the dictionary held in the local variable dico when
%                    the main loop finishes.
%
%   Errors are raised when the dictionary size cannot be determined, when
%   there are fewer training signals than atoms, or when the initial
%   dictionary is invalid.
%
%   IMPORTANT: the configuration script is executed with run(), i.e. inside
%   this function's workspace. It relies on the local variable names used
%   below (dico, sig, solver, typeUpdate, flow, learningRate, decorrelate,
%   mu - see the commented-out reference code in the main loop). Do not
%   rename them without updating the configuration file.

% determine which solver is used for sparse representation %

solver = DL.param.solver;

% determine which type of update to use ('KSVD', 'MOD', 'ols' or 'mailhe') %

typeUpdate = DL.name;

sig = Problem.b;

% determine dictionary size %

% An empty initdict is treated as "not supplied", consistently for both
% the size determination and the initialization below.
hasInitDict = isfield(DL.param,'initdict') && ~isempty(DL.param.initdict);
initIsIndices = false;

if (hasInitDict)
    % a vector of whole numbers is interpreted as indices into the
    % training set, anything else as an explicit dictionary
    initIsIndices = any(size(DL.param.initdict)==1) && ...
        all(iswhole(DL.param.initdict(:)));
    if (initIsIndices)
        dictsize = length(DL.param.initdict);
    else
        dictsize = size(DL.param.initdict,2);
    end
end
if (isfield(DL.param,'dictsize'))    % this supersedes the size determined by initdict
    dictsize = DL.param.dictsize;
end

if (~exist('dictsize','var'))
    error('Dictionary size not specified: set DL.param.dictsize or DL.param.initdict');
end

if (size(sig,2) < dictsize)
    error('Number of training signals is smaller than number of atoms to train');
end


% initialize the dictionary %

if (hasInitDict)
    if (initIsIndices)
        if (numel(DL.param.initdict) < dictsize)
            error('Invalid initial dictionary');
        end
        dico = sig(:,DL.param.initdict(1:dictsize));
    else
        % NOTE(review): the text of this branch was partly lost in the
        % exported source and has been restored from the KSVD-style
        % initialization this routine is based on - confirm against the
        % repository version.
        if (size(DL.param.initdict,1)~=size(sig,1) || size(DL.param.initdict,2)<dictsize)
            error('Invalid initial dictionary');
        end
        dico = DL.param.initdict(:,1:dictsize);
    end
else
    data_ids = find(colnorms_squared(sig) > 1e-6);   % ensure no zero data elements are chosen
    if (length(data_ids) < dictsize)
        error('Number of non-zero training signals is smaller than number of atoms to train');
    end
    perm = randperm(length(data_ids));
    dico = sig(:,data_ids(perm(1:dictsize)));
end

% flow: 'sequential' or 'parallel'. If sequential, the residual is updated
% after each atom update. If parallel, the residual is only updated once
% the whole dictionary has been computed. Sequential works better, there
% may be no need to implement parallel. Not used with MOD.

if isfield(DL.param,'flow')
    flow = DL.param.flow;
else
    flow = 'sequential';
end

% learningRate. If the type is 'ols', it is the descent step of
% the gradient (typical choice: 0.1). If the type is 'mailhe', the
% descent step is the optimal step*rho (typical choice: 1, although 2
% or 3 seems to work better). Not used for MOD and KSVD.

if isfield(DL.param,'learningRate')
    learningRate = DL.param.learningRate;
else
    learningRate = 0.1;
end

% number of iterations (default is 40) %

if isfield(DL.param,'iternum')
    iternum = DL.param.iternum;
else
    iternum = 40;
end

% determine if we should do decorrelation in every iteration %

if isfield(DL.param,'coherence')
    decorrelate = 1;
    mu = DL.param.coherence;
else
    decorrelate = 0;
end

% show dictionary every specified number of iterations

if isfield(DL.param,'show_dict')
    show_dictionary=1;
    show_iter=DL.param.show_dict;
else
    show_dictionary=0;
    show_iter=0;
end

% This is a small patch that needs to be resolved: in dictionary learning
% we want a sparse representation of the training set, but in this version
% of the software Problem.b1 stores the signal that needs to be represented
% (for example the whole image). It is therefore swapped for the training
% set while the solver runs.

tmpTraining = Problem.b1;
Problem.b1 = sig;
if isfield(Problem,'reconstruct')
    Problem = rmfield(Problem, 'reconstruct');
end
solver.profile = 0;

% main loop %

for i = 1:iternum
    % step 1: sparse coding with the current dictionary
    Problem.A = dico;
    solver = SMALL_solve(Problem, solver);

    % step 2: configuration file, executed in this workspace.
    % NOTE(review): its contents are not visible here; judging by the
    % reference code below it is expected to update dico (and
    % solver.solution) and to apply decorrelation when requested.
    run([SMALL_path '/config/SMALL_two_step_DL_config.m'])

    %    [dico, solver.solution] = dico_update(dico, sig, solver.solution, ...
    %        typeUpdate, flow, learningRate);
    %    if (decorrelate)
    %        dico = dico_decorr(dico, mu, solver.solution);
    %    end

    if ((show_dictionary)&&(mod(i,show_iter)==0))
        dictimg = SMALL_showdict(dico,[8 8],...
            round(sqrt(size(dico,2))),round(sqrt(size(dico,2))),'lines','highcontrast');
        figure(2); imagesc(dictimg);colormap(gray);axis off; axis image;
        pause(0.02);
    end
end

% Problem is not returned, so this only restores the local copy.
Problem.b1 = tmpTraining;
DL.D = dico;

end

function Y = colnorms_squared(X)
% COLNORMS_SQUARED  Squared Euclidean norm of every column of X.
%   Y is a 1-by-size(X,2) row vector with Y(j) = sum(X(:,j).^2).

% compute in blocks to conserve memory
Y = zeros(1,size(X,2));
blocksize = 2000;
for i = 1:blocksize:size(X,2)
    blockids = i : min(i+blocksize-1,size(X,2));
    % sum explicitly along dimension 1: without it a single-row X would be
    % summed along its columns and collapse to one scalar
    Y(blockids) = sum(X(:,blockids).^2, 1);
end

end