Mercurial > hg > smallbox
changeset 175:9eb5f0d4c1a4 danieleb
added MOCOD dictionary update
author | Daniele Barchiesi <daniele.barchiesi@eecs.qmul.ac.uk> |
---|---|
date | Thu, 17 Nov 2011 11:17:00 +0000 |
parents | dc2f0fa21310 |
children | d0645d5fca7d |
files | DL/two-step DL/dico_update.m |
diffstat | 1 files changed, 117 insertions(+), 104 deletions(-) [+] |
line wrap: on
line diff
--- a/DL/two-step DL/dico_update.m Thu Nov 17 11:16:15 2011 +0000 +++ b/DL/two-step DL/dico_update.m Thu Nov 17 11:17:00 2011 +0000 @@ -1,107 +1,120 @@ -function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho) +function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho,mocodParams) - %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - % [dico, amp] = dico_update(dico, sig, amp, type, flow, rho) - % - % perform one iteration of dictionary update for dictionary learning - % - % parameters: - % - dico: the initial dictionary with atoms as columns - % - sig: the training data - % - amp: the amplitude coefficients as a sparse matrix - % - type: the algorithm can be one of the following - % - ols: fixed step gradient descent - % - mailhe: optimal step gradient descent (can be implemented as a - % default for ols?) - % - MOD: pseudo-inverse of the coefficients - % - KSVD: already implemented by Elad - % - flow: 'sequential' or 'parallel'. If sequential, the residual is - % updated after each atom update. If parallel, the residual is only - % updated once the whole dictionary has been computed. Sequential works - % better, there may be no need to implement parallel. Not used with - % MOD. - % - rho: learning rate. If the type is 'ols', it is the descent step of - % the gradient (typical choice: 0.1). If the type is 'mailhe', the - % descent step is the optimal step*rho (typical choice: 1, although 2 - % or 3 seems to work better). Not used for MOD and KSVD. - %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - if ~ exist( 'rho', 'var' ) || isempty(rho) - rho = 0.1; - end - - if ~ exist( 'flow', 'var' ) || isempty(flow) - flow = sequential; - end - - res = sig - dico*amp; - nb_pattern = size(dico, 2); - - switch type - case 'rand' - x = rand(); - if x < 1/3 - type = 'MOD'; - elseif type < 2/3 - type = 'mailhe'; - else - type = 'KSVD'; - end - end - - switch type - case 'MOD' - G = amp*amp'; - dico2 = sig*amp'*G^-1; - for p = 1:nb_pattern - n = norm(dico2(:,p)); - % renormalize - if n > 0 - dico(:,p) = dico2(:,p)/n; - amp(p,:) = amp(p,:)*n; - end - end - case 'ols' - for p = 1:nb_pattern - grad = res*amp(p,:)'; - if norm(grad) > 0 - pat = dico(:,p) + rho*grad; - pat = pat/norm(pat); - if nargin >5 && strcmp(flow, 'sequential') - res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU> - end - dico(:,p) = pat; - end - end - case 'mailhe' - for p = 1:nb_pattern - grad = res*amp(p,:)'; - if norm(grad) > 0 - pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad; - pat = pat/norm(pat); - if nargin >5 && strcmp(flow, 'sequential') - res = res + (dico(:,p)-pat)*amp(p,:); - end - dico(:,p) = pat; - end - end - case 'KSVD' - for p = 1:nb_pattern - index = find(amp(p,:)~=0); - if ~isempty(index) - patch = res(:,index)+dico(:,p)*amp(p,index); - [U,S,V] = svd(patch); - if U(:,1)'*dico(:,p) > 0 - dico(:,p) = U(:,1); - else - dico(:,p) = -U(:,1); - end - dico(:,p) = dico(:,p)/norm(dico(:,p)); - amp(p,index) = dico(:,p)'*patch; - if nargin >5 && strcmp(flow, 'sequential') - res(:,index) = patch-dico(:,p)*amp(p,index); - end - end - end - end +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +% [dico, amp] = dico_update(dico, sig, amp, type, flow, rho) +% +% perform one iteration of dictionary update for dictionary learning +% +% parameters: +% - dico: the initial dictionary with atoms as columns +% - sig: the training data +% - amp: the amplitude coefficients as a sparse matrix +% - type: the algorithm can be one of the following +% - ols: fixed step gradient descent +% - mailhe: optimal step gradient descent (can be implemented as a +% default for ols?) +% - MOD: pseudo-inverse of the coefficients +% - KSVD: already implemented by Elad +% - flow: 'sequential' or 'parallel'. If sequential, the residual is +% updated after each atom update. If parallel, the residual is only +% updated once the whole dictionary has been computed. Sequential works +% better, there may be no need to implement parallel. Not used with +% MOD. +% - rho: learning rate. If the type is 'ols', it is the descent step of +% the gradient (typical choice: 0.1). If the type is 'mailhe', the +% descent step is the optimal step*rho (typical choice: 1, although 2 +% or 3 seems to work better). Not used for MOD and KSVD. +% - mocodParams: struct containing the parameters for the MOCOD dictionary +% update (see Ramirez et Al., Sparse modeling with universal priors and +% learned incoherent dictionaries). The required fields are +% .Dprev: dictionary at previous optimisation step +% .zeta: coherence regularization factor +% .eta: atoms norm regularisation factor +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +if ~ exist( 'rho', 'var' ) || isempty(rho) + rho = 0.1; end +if ~ exist( 'flow', 'var' ) || isempty(flow) + flow = sequential; +end + +res = sig - dico*amp; +nb_pattern = size(dico, 2); + +switch type + case 'rand' + x = rand(); + if x < 1/3 + type = 'MOD'; + elseif type < 2/3 + type = 'mailhe'; + else + type = 'KSVD'; + end +end + +switch upper(type) + case 'MOD' + G = amp*amp'; + dico2 = sig*amp'*G^-1; + for p = 1:nb_pattern + n = norm(dico2(:,p)); + % renormalize + if n > 0 + dico(:,p) = dico2(:,p)/n; + amp(p,:) = amp(p,:)*n; + end + end + case 'OLS' + for p = 1:nb_pattern + grad = res*amp(p,:)'; + if norm(grad) > 0 + pat = dico(:,p) + rho*grad; + pat = pat/norm(pat); + if nargin >5 && strcmp(flow, 'sequential') + res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU> + end + dico(:,p) = pat; + end + end + case 'MAILHE' + for p = 1:nb_pattern + grad = res*amp(p,:)'; + if norm(grad) > 0 + pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad; + pat = pat/norm(pat); + if nargin >5 && strcmp(flow, 'sequential') + res = res + (dico(:,p)-pat)*amp(p,:); + end + dico(:,p) = pat; + end + end + case 'KSVD' + for p = 1:nb_pattern + index = find(amp(p,:)~=0); + if ~isempty(index) + patch = res(:,index)+dico(:,p)*amp(p,index); + [U,~,V] = svd(patch); + if U(:,1)'*dico(:,p) > 0 + dico(:,p) = U(:,1); + else + dico(:,p) = -U(:,1); + end + dico(:,p) = dico(:,p)/norm(dico(:,p)); + amp(p,index) = dico(:,p)'*patch; + if nargin >5 && strcmp(flow, 'sequential') + res(:,index) = patch-dico(:,p)*amp(p,index); + end + end + end + case 'MOCOD' + zeta = mocodParams.zeta; + eta = mocodParams.eta; + Dprev = mocodParams.Dprev; + + dico = (sig*amp' + 2*(zeta+eta)*Dprev)/... + (amp*amp' + 2*zeta*(Dprev'*Dprev) + 2*eta*diag(diag(Dprev'*Dprev))); +end +end +