Mercurial > hg > smallbox
view DL/two-step DL/SMALL_two_step_DL.m @ 218:c38d965b5a1d luisf_dev
Moved CVX_add_const_Audio_declipping.m from solvers to the examples/AudioInpainting folder.
author | Aris Gretsistas <aris.gretsistas@eecs.qmul.ac.uk> |
---|---|
date | Thu, 22 Mar 2012 15:37:45 +0000 |
parents | f12a476a4977 |
children | fd0b5d36f6ad |
line wrap: on
line source
function DL=SMALL_two_step_DL(Problem, DL)
%% DL=SMALL_two_step_DL(Problem, DL) learn a dictionary using two_step_DL
%
%   Alternates a sparse-coding step (SMALL_solve with DL.param.solver) and a
%   dictionary-update step (dico_update) for a fixed number of iterations,
%   and stores the learned dictionary in DL.D.
%
%   Inputs:
%     Problem - SMALL problem structure; Problem.b holds the training
%               signals, one signal per column.
%     DL      - dictionary-learning structure (fields listed below).
%
%   Output:
%     DL      - the input structure with the learned dictionary in DL.D.
%
%   The specific parameters of the DL structure are:
%   -name: can be either 'ols', 'opt', 'MOD', 'KSVD' or 'LGD'.
%   -param.solver: solver structure passed to SMALL_solve for the sparse
%       coding step.
%   -param.initdict: initial dictionary, either a matrix whose columns are
%       atoms, or a vector of whole numbers giving the indices of training
%       signals to copy into the initial dictionary.
%   -param.dictsize: number of atoms; supersedes the size implied by
%       initdict. Required when initdict is absent.
%   -param.learningRate: a step size used by 'ols' and 'opt'. Default: 0.1
%       for 'ols', 1 for 'opt'.
%   -param.flow: can be either 'sequential' or 'parallel'. Default:
%       'sequential'. Not used by MOD.
%   -param.iternum: number of learning iterations. Default: 40.
%   -param.coherence: a real number between 0 and 1. If present, then
%       a low-coherence constraint is added to the learning.
%   -param.show_dict: if present, the dictionary is displayed every
%       show_dict iterations.
%
%   See dico_update.m for more details.

% determine which solver is used for sparse representation
solver = DL.param.solver;

% determine which type of update to use ('KSVD', 'MOD', 'ols', 'opt' or 'LGD')
typeUpdate = DL.name;

sig = Problem.b;

% determine dictionary size
dictsize = [];
useIndexInit = false;
if (isfield(DL.param,'initdict'))
    % A vector of whole numbers is interpreted as a list of training-signal
    % indices rather than as a dictionary matrix.
    useIndexInit = any(size(DL.param.initdict)==1) && ...
        all(iswhole(DL.param.initdict(:)));
    if (useIndexInit)
        dictsize = length(DL.param.initdict);
    else
        dictsize = size(DL.param.initdict,2);
    end
end
if (isfield(DL.param,'dictsize'))    % this supersedes the size determined by initdict
    dictsize = DL.param.dictsize;
end
if (isempty(dictsize))
    % Without either field the dictionary size is undefined.
    error('Either DL.param.dictsize or DL.param.initdict must be specified');
end

if (size(sig,2) < dictsize)
    error('Number of training signals is smaller than number of atoms to train');
end

% initialize the dictionary
if (isfield(DL.param,'initdict'))
    if (useIndexInit)
        % dictsize may have been enlarged by DL.param.dictsize
        if (length(DL.param.initdict) < dictsize)
            error('Invalid initial dictionary');
        end
        dico = sig(:,DL.param.initdict(1:dictsize));
    else
        if (size(DL.param.initdict,1)~=size(sig,1) || size(DL.param.initdict,2)<dictsize)
            error('Invalid initial dictionary');
        end
        dico = DL.param.initdict(:,1:dictsize);
    end
else
    % ensure no zero data elements are chosen
    data_ids = find(colnorms_squared(sig) > 1e-6);
    if (length(data_ids) < dictsize)
        error('Number of non-zero training signals is smaller than number of atoms to train');
    end
    perm = randperm(length(data_ids));
    dico = sig(:,data_ids(perm(1:dictsize)));
end

% flow: 'sequential' or 'parallel'. If sequential, the residual is updated
% after each atom update. If parallel, the residual is only updated once
% the whole dictionary has been computed. Sequential works better, there
% may be no need to implement parallel. Not used with MOD.
if isfield(DL.param,'flow')
    flow = DL.param.flow;
else
    flow = 'sequential';
end

% learningRate. If the type is 'ols', it is the descent step of
% the gradient (default: 0.1). If the type is 'opt', the
% descent step is the optimal step*rho (default: 1, although 2 works
% better). Not used for MOD and KSVD.
if isfield(DL.param,'learningRate')
    learningRate = DL.param.learningRate;
else
    switch typeUpdate
        case 'ols'
            learningRate = 0.1;
        otherwise
            learningRate = 1;
    end
end

% number of iterations (default is 40)
if isfield(DL.param,'iternum')
    iternum = DL.param.iternum;
else
    iternum = 40;
end

% determine if we should do decorrelation in every iteration
if isfield(DL.param,'coherence')
    decorrelate = 1;
    mu = DL.param.coherence;
else
    decorrelate = 0;
end

% show dictionary every specified number of iterations
if isfield(DL.param,'show_dict')
    show_dictionary = 1;
    show_iter = DL.param.show_dict;
else
    show_dictionary = 0;
    show_iter = 0;
end

% In this version of the software Problem.b1 stores the signal that needs
% to be represented (for example the whole image), whereas during
% dictionary learning we want a sparse representation of the training set,
% so point Problem.b1 at the training signals. Problem is a local copy,
% so the caller's Problem.b1 is unaffected and no restore is needed; this
% also means Problem.b1 does not have to exist on entry.
Problem.b1 = sig;
if isfield(Problem,'reconstruct')
    Problem = rmfield(Problem, 'reconstruct');
end

solver.profile = 0;

% main loop: alternate sparse coding and dictionary update
for i = 1:iternum
    Problem.A = dico;
    solver = SMALL_solve(Problem, solver);
    [dico, solver.solution] = dico_update(dico, sig, solver.solution, ...
        typeUpdate, flow, learningRate);
    if (decorrelate)
        dico = dico_decorr_symetric(dico, mu, solver.solution);
    end
    if ((show_dictionary)&&(mod(i,show_iter)==0))
        % NOTE(review): [8 8] presumably is the atom patch size, i.e. it
        % assumes 64-dimensional atoms - confirm for other signal sizes.
        dictimg = SMALL_showdict(dico,[8 8],...
            round(sqrt(size(dico,2))),round(sqrt(size(dico,2))),'lines','highcontrast');
        figure(2); imagesc(dictimg); colormap(gray); axis off; axis image;
        pause(0.02);
    end
end

DL.D = dico;

end

function Y = colnorms_squared(X)
% COLNORMS_SQUARED squared Euclidean norm of each column of X, returned as
% a 1-by-size(X,2) row vector. Computed in blocks of columns to conserve
% memory. The sum is taken explicitly along dimension 1 so that a
% single-row X still yields one value per column instead of a scalar.
Y = zeros(1,size(X,2));
blocksize = 2000;
for i = 1:blocksize:size(X,2)
    blockids = i : min(i+blocksize-1,size(X,2));
    Y(blockids) = sum(X(:,blockids).^2, 1);
end
end