Mercurial > hg > smallbox
diff DL/two-step DL/dico_update.m @ 224:fd0b5d36f6ad danieleb
Updated the contents of this branch with the contents of the default branch.
author | luisf <luis.figueira@eecs.qmul.ac.uk> |
---|---|
date | Thu, 12 Apr 2012 13:52:28 +0100 |
parents | 9eb5f0d4c1a4 233e75809e4a |
children |
line wrap: on
line diff
--- a/DL/two-step DL/dico_update.m Wed Mar 14 16:31:38 2012 +0000 +++ b/DL/two-step DL/dico_update.m Thu Apr 12 13:52:28 2012 +0100 @@ -1,120 +1,135 @@ -function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho,mocodParams) - -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -% [dico, amp] = dico_update(dico, sig, amp, type, flow, rho) -% -% perform one iteration of dictionary update for dictionary learning -% -% parameters: -% - dico: the initial dictionary with atoms as columns -% - sig: the training data -% - amp: the amplitude coefficients as a sparse matrix -% - type: the algorithm can be one of the following -% - ols: fixed step gradient descent -% - mailhe: optimal step gradient descent (can be implemented as a -% default for ols?) -% - MOD: pseudo-inverse of the coefficients -% - KSVD: already implemented by Elad -% - flow: 'sequential' or 'parallel'. If sequential, the residual is -% updated after each atom update. If parallel, the residual is only -% updated once the whole dictionary has been computed. Sequential works -% better, there may be no need to implement parallel. Not used with -% MOD. -% - rho: learning rate. If the type is 'ols', it is the descent step of -% the gradient (typical choice: 0.1). If the type is 'mailhe', the -% descent step is the optimal step*rho (typical choice: 1, although 2 -% or 3 seems to work better). Not used for MOD and KSVD. -% - mocodParams: struct containing the parameters for the MOCOD dictionary -% update (see Ramirez et Al., Sparse modeling with universal priors and -% learned incoherent dictionaries). The required fields are -% .Dprev: dictionary at previous optimisation step -% .zeta: coherence regularization factor -% .eta: atoms norm regularisation factor -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -if ~ exist( 'rho', 'var' ) || isempty(rho) - rho = 0.1; +function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho) + + %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + % [dico, amp] = dico_update(dico, sig, amp, type, flow, rho) + % + % perform one iteration of dictionary update for dictionary learning + % + % parameters: + % - dico: the initial dictionary with atoms as columns + % - sig: the training data + % - amp: the amplitude coefficients as a sparse matrix + % - type: the algorithm can be one of the following + % - ols: fixed step gradient descent, as described in Olshausen & + % Field95 + % - opt: optimal step gradient descent, as described in Mailhe et + % al.08 + % - MOD: pseudo-inverse of the coefficients, as described in Engan99 + % - KSVD: PCA update as described in Aharon06. For fast applications, + % use KSVDbox rather than this code. + % - LGD: large step gradient descent. Equivalent to 'opt' with + % rho=2. + % - flow: 'sequential' or 'parallel'. If sequential, the residual is + % updated after each atom update. If parallel, the residual is only + % updated once the whole dictionary has been computed. + % Default: Sequential (sequential usually works better). Not used with + % MOD. + % - rho: learning rate. If the type is 'ols', it is the descent step of + % the gradient (default: 0.1). If the type is 'opt', the + % descent step is the optimal step*rho (default: 1, although 2 works + % better. See LGD for more details). Not used for MOD and KSVD. + %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + + if ~ exist( 'flow', 'var' ) || isempty(flow) + flow = 'sequential'; + end + + res = sig - dico*amp; + nb_pattern = size(dico, 2); + + % if the type is random, then randomly pick another type + switch type + case 'rand' + x = rand(); + if x < 1/3 + type = 'MOD'; + elseif type < 2/3 + type = 'opt'; + else + type = 'KSVD'; + end + end + + % set the learning rate to default if not provided + if ~ exist( 'rho', 'var' ) || isempty(rho) + switch type + case 'ols' + rho = 0.1; + case 'opt' + rho = 1; + end + end + + switch type + case 'MOD' + G = amp*amp'; + dico2 = sig*amp'*G^-1; + for p = 1:nb_pattern + n = norm(dico2(:,p)); + % renormalize + if n > 0 + dico(:,p) = dico2(:,p)/n; + amp(p,:) = amp(p,:)*n; + end + end + case 'ols' + for p = 1:nb_pattern + grad = res*amp(p,:)'; + if norm(grad) > 0 + pat = dico(:,p) + rho*grad; + pat = pat/norm(pat); + if nargin >5 && strcmp(flow, 'sequential') + res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU> + end + dico(:,p) = pat; + end + end + case 'opt' + for p = 1:nb_pattern + index = find(amp(p,:)~=0); + vec = amp(p,index); + grad = res(:,index)*vec'; + if norm(grad) > 0 + pat = (vec*vec')*dico(:,p) + rho*grad; + pat = pat/norm(pat); + if nargin >5 && strcmp(flow, 'sequential') + res(:,index) = res(:,index) + (dico(:,p)-pat)*vec; + end + dico(:,p) = pat; + end + end + case 'LGD' + for p = 1:nb_pattern + index = find(amp(p,:)~=0); + vec = amp(p,index); + grad = res(:,index)*vec'; + if norm(grad) > 0 + pat = (vec*vec')*dico(:,p) + 2*grad; + pat = pat/norm(pat); + if nargin >5 && strcmp(flow, 'sequential') + res(:,index) = res(:,index) + (dico(:,p)-pat)*vec; + end + dico(:,p) = pat; + end + end + case 'KSVD' + for p = 1:nb_pattern + index = find(amp(p,:)~=0); + if ~isempty(index) + patch = res(:,index)+dico(:,p)*amp(p,index); + [U,~,V] = svd(patch); + if U(:,1)'*dico(:,p) > 0 + dico(:,p) = U(:,1); + else + dico(:,p) = -U(:,1); + end + dico(:,p) = dico(:,p)/norm(dico(:,p)); + amp(p,index) = dico(:,p)'*patch; + if nargin >5 && strcmp(flow, 'sequential') + res(:,index) = patch-dico(:,p)*amp(p,index); + end + end + end + end end -if ~ exist( 'flow', 'var' ) || isempty(flow) - flow = sequential; -end - -res = sig - dico*amp; -nb_pattern = size(dico, 2); - -switch type - case 'rand' - x = rand(); - if x < 1/3 - type = 'MOD'; - elseif type < 2/3 - type = 'mailhe'; - else - type = 'KSVD'; - end -end - -switch upper(type) - case 'MOD' - G = amp*amp'; - dico2 = sig*amp'*G^-1; - for p = 1:nb_pattern - n = norm(dico2(:,p)); - % renormalize - if n > 0 - dico(:,p) = dico2(:,p)/n; - amp(p,:) = amp(p,:)*n; - end - end - case 'OLS' - for p = 1:nb_pattern - grad = res*amp(p,:)'; - if norm(grad) > 0 - pat = dico(:,p) + rho*grad; - pat = pat/norm(pat); - if nargin >5 && strcmp(flow, 'sequential') - res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU> - end - dico(:,p) = pat; - end - end - case 'MAILHE' - for p = 1:nb_pattern - grad = res*amp(p,:)'; - if norm(grad) > 0 - pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad; - pat = pat/norm(pat); - if nargin >5 && strcmp(flow, 'sequential') - res = res + (dico(:,p)-pat)*amp(p,:); - end - dico(:,p) = pat; - end - end - case 'KSVD' - for p = 1:nb_pattern - index = find(amp(p,:)~=0); - if ~isempty(index) - patch = res(:,index)+dico(:,p)*amp(p,index); - [U,~,V] = svd(patch); - if U(:,1)'*dico(:,p) > 0 - dico(:,p) = U(:,1); - else - dico(:,p) = -U(:,1); - end - dico(:,p) = dico(:,p)/norm(dico(:,p)); - amp(p,index) = dico(:,p)'*patch; - if nargin >5 && strcmp(flow, 'sequential') - res(:,index) = patch-dico(:,p)*amp(p,index); - end - end - end - case 'MOCOD' - zeta = mocodParams.zeta; - eta = mocodParams.eta; - Dprev = mocodParams.Dprev; - - dico = (sig*amp' + 2*(zeta+eta)*Dprev)/... - (amp*amp' + 2*zeta*(Dprev'*Dprev) + 2*eta*diag(diag(Dprev'*Dprev))); -end -end -