view DL/two-step DL/SMALL_two_step_DL.m @ 167:8324c7ea6602 danieleb
Added symmetric de-correlation function, modified target de-correlation in test function.
author | Daniele Barchiesi <daniele.barchiesi@eecs.qmul.ac.uk> |
---|---|
date | Tue, 20 Sep 2011 14:27:14 +0100 |
parents | 88578ec2f94a |
children | 290cca7d3469 |
function DL=SMALL_two_step_DL(Problem, DL)

% determine which solver is used for sparse representation %

solver = DL.param.solver;

% determine which type of update to use ('KSVD', 'MOD', 'ols' or 'mailhe') %

typeUpdate = DL.name;

sig = Problem.b;

% determine dictionary size %

if (isfield(DL.param,'initdict'))
    if (any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:))))
        % initdict is a vector of signal indices
        dictsize = length(DL.param.initdict);
    else
        % initdict is a dictionary matrix with one atom per column
        dictsize = size(DL.param.initdict,2);
    end
end
if (isfield(DL.param,'dictsize'))    % this supersedes the size determined by initdict
    dictsize = DL.param.dictsize;
end

if (size(sig,2) < dictsize)
    error('Number of training signals is smaller than number of atoms to train');
end

% initialize the dictionary %

if (isfield(DL.param,'initdict'))
    if (any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:))))
        dico = sig(:,DL.param.initdict(1:dictsize));
    else
        if (size(DL.param.initdict,1)~=size(sig,1) || size(DL.param.initdict,2)<dictsize)
            error('Invalid initial dictionary');
        end
        dico = DL.param.initdict(:,1:dictsize);
    end
else
    data_ids = find(colnorms_squared(sig) > 1e-6);   % ensure no zero data elements are chosen
    perm = randperm(length(data_ids));
    dico = sig(:,data_ids(perm(1:dictsize)));
end

% flow: 'sequential' or 'parallel'. If sequential, the residual is updated
% after each atom update. If parallel, the residual is only updated once
% the whole dictionary has been computed. Sequential works better, so there
% may be no need to implement parallel. Not used with MOD.

if isfield(DL.param,'flow')
    flow = DL.param.flow;
else
    flow = 'sequential';
end

% learningRate. If the type is 'ols', it is the descent step of
% the gradient (typical choice: 0.1). If the type is 'mailhe', the
% descent step is the optimal step*rho (typical choice: 1, although 2
% or 3 seem to work better). Not used for MOD and KSVD.

if isfield(DL.param,'learningRate')
    learningRate = DL.param.learningRate;
else
    learningRate = 0.1;
end

% number of iterations (default is 40) %

if isfield(DL.param,'iternum')
    iternum = DL.param.iternum;
else
    iternum = 40;
end

% determine if we should do decorrelation in every iteration %

if isfield(DL.param,'coherence') && isscalar(DL.param.coherence)
    decorrelate = 1;
    mu = DL.param.coherence;
else
    decorrelate = 0;
end

% decorrelation function: 'mailhe' or 'tropp'. Any other value, or a
% missing field, skips the decorrelation step.

if isfield(DL.param,'decFcn')
    decFcn = DL.param.decFcn;
else
    decFcn = 'none';
end

% 'mailhe' needs the target coherence mu, which is only set above when
% DL.param.coherence is a scalar
if (strcmp(decFcn,'mailhe') && ~decorrelate)
    error('decFcn ''mailhe'' requires a scalar DL.param.coherence');
end

% show dictionary every specified number of iterations

if isfield(DL.param,'show_dict')
    show_dictionary=1;
    show_iter=DL.param.show_dict;
else
    show_dictionary=0;
    show_iter=0;
end

% Temporary patch, still to be resolved: dictionary learning needs the
% sparse representation of the training set, but in this version of the
% software Problem.b1 stores the signal that needs to be represented
% (for example the whole image). The training set is swapped in here and
% the original b1 is restored after the main loop.

tmpTraining = Problem.b1;
Problem.b1 = sig;
if isfield(Problem,'reconstruct')
    Problem = rmfield(Problem, 'reconstruct');
end
solver.profile = 0;

% main loop %

for i = 1:iternum
    disp([num2str(i) '/' num2str(iternum)]);

    % step 1: sparse coding with the current dictionary
    Problem.A = dico;
    solver = SMALL_solve(Problem, solver);

    % step 2: dictionary update, followed by column normalisation
    [dico, solver.solution] = dico_update(dico, sig, solver.solution, ...
        typeUpdate, flow, learningRate);
    dico = normcols(dico);

    % optional decorrelation of the updated dictionary
    switch decFcn
        case 'mailhe'
            dico = dico_decorr_symetric(dico, mu, solver.solution);
        case 'tropp'
            [n, m] = size(dico);
            dico = grassmannian(n,m,[],[],[],dico,true);
        otherwise
            % no decorrelation
    end

    if ((show_dictionary)&&(mod(i,show_iter)==0))
        dictimg = SMALL_showdict(dico,[8 8],...
            round(sqrt(size(dico,2))),round(sqrt(size(dico,2))),'lines','highcontrast');
        figure(2); imagesc(dictimg);colormap(gray);axis off; axis image;
        pause(0.02);
    end
end

Problem.b1 = tmpTraining;
DL.D = dico;

end

function Y = colnorms_squared(X)

% compute in blocks to conserve memory
Y = zeros(1,size(X,2));
blocksize = 2000;
for i = 1:blocksize:size(X,2)
    blockids = i : min(i+blocksize-1,size(X,2));
    Y(blockids) = sum(X(:,blockids).^2);
end

end