Mercurial > hg > smallbox
diff DL/two-step DL/SMALL_two_step_DL.m @ 224:fd0b5d36f6ad danieleb
Updated the contents of this branch with the contents of the default branch.
author | luisf <luis.figueira@eecs.qmul.ac.uk> |
---|---|
date | Thu, 12 Apr 2012 13:52:28 +0100 |
parents | d0645d5fca7d f12a476a4977 |
children |
line wrap: on
line diff
--- a/DL/two-step DL/SMALL_two_step_DL.m Wed Mar 14 16:31:38 2012 +0000 +++ b/DL/two-step DL/SMALL_two_step_DL.m Thu Apr 12 13:52:28 2012 +0100 @@ -1,10 +1,22 @@ function DL=SMALL_two_step_DL(Problem, DL) + + %% DL=SMALL_two_step_DL(Problem, DL) learn a dictionary using two_step_DL + % The specific parameters of the DL structure are: + % -name: can be either 'ols', 'opt', 'MOD', KSVD' or 'LGD'. + % -param.learningRate: a step size used by 'ols' and 'opt'. Default: 0.1 + % for 'ols', 1 for 'opt'. + % -param.flow: can be either 'sequential' or 'parallel'. De fault: + % 'sequential'. Not used by MOD. + % -param.coherence: a real number between 0 and 1. If present, then + % a low-coherence constraint is added to the learning. + % + % See dico_update.m for more details. % determine which solver is used for sparse representation % solver = DL.param.solver; -% determine which type of udate to use ('KSVD', 'MOD','MOCOD','ols' or 'mailhe') % +% determine which type of udate to use ('KSVD', 'MOD', 'ols', 'opt' or 'LGD') % typeUpdate = DL.name; @@ -30,7 +42,7 @@ % initialize the dictionary % -if (isfield(DL.param,'initdict')) && ~isempty(DL.param.initdict); +if (isfield(DL.param,'initdict')) if (any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:)))) dico = sig(:,DL.param.initdict(1:dictsize)); else @@ -57,14 +69,19 @@ end % learningRate. If the type is 'ols', it is the descent step of -% the gradient (typical choice: 0.1). If the type is 'mailhe', the -% descent step is the optimal step*rho (typical choice: 1, although 2 -% or 3 seems to work better). Not used for MOD and KSVD. +% the gradient (default: 0.1). If the type is 'mailhe', the +% descent step is the optimal step*rho (default: 1, although 2 works +% better). Not used for MOD and KSVD. if isfield(DL.param,'learningRate') learningRate = DL.param.learningRate; else - learningRate = 0.1; + switch typeUpdate + case 'ols' + learningRate = 0.1; + otherwise + learningRate = 1; + end end % number of iterations (default is 40) % @@ -76,15 +93,13 @@ end % determine if we should do decorrelation in every iteration % -if isfield(DL.param,'coherence') && isscalar(DL.param.coherence) +if isfield(DL.param,'coherence') decorrelate = 1; mu = DL.param.coherence; else decorrelate = 0; end -if ~isfield(DL.param,'decFcn'), DL.param.decFcn = 'none'; end - % show dictonary every specified number of iterations if isfield(DL.param,'show_dict') @@ -110,34 +125,13 @@ % main loop % for i = 1:iternum - %disp([num2str(i) '/' num2str(iternum)]); - %SPARSE CODING STEP - Problem.A = dico; + Problem.A = dico; solver = SMALL_solve(Problem, solver); - %DICTIONARY UPDATE STEP - if strcmpi(typeUpdate,'mocod') %if update is MOCOD create parameters structure - mocodParams = struct('zeta',DL.param.zeta,... %coherence regularization factor - 'eta',DL.param.eta,... %atoms norm regularization factor - 'Dprev',dico); %previous dictionary - dico = dico_update(dico,sig,solver.solution,typeUpdate,flow,learningRate,mocodParams); - else - [dico, solver.solution] = dico_update(dico, sig, solver.solution, ... - typeUpdate, flow, learningRate); - dico = normcols(dico); - end - - switch lower(DL.param.decFcn) - case 'ink-svd' - dico = dico_decorr_symetric(dico,mu,solver.solution); - case 'grassmannian' - [n m] = size(dico); - dico = grassmannian(n,m,[],0.9,0.99,dico); - case 'shrinkgram' - dico = shrinkgram(dico,mu); - case 'iterproj' - dico = iterativeprojections(dico,mu,Problem.b1,solver.solution); - otherwise - end + [dico, solver.solution] = dico_update(dico, sig, solver.solution, ... + typeUpdate, flow, learningRate); + if (decorrelate) + dico = dico_decorr_symetric(dico, mu, solver.solution); + end if ((show_dictionary)&&(mod(i,show_iter)==0)) dictimg = SMALL_showdict(dico,[8 8],... @@ -162,4 +156,4 @@ Y(blockids) = sum(X(:,blockids).^2); end -end +end \ No newline at end of file