function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho, mocodParams)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% [dico, amp] = dico_update(dico, sig, amp, type, flow, rho, mocodParams)
%
% Perform one iteration of dictionary update for dictionary learning.
% The model being fitted is sig ~ dico*amp.
%
% Parameters:
% - dico: the initial dictionary with atoms as columns
% - sig: the training data (sig is approximated by dico*amp)
% - amp: the amplitude coefficients as a sparse matrix, one row per atom
% - type: the update algorithm (case-insensitive), one of:
%   - 'ols': fixed-step gradient descent
%   - 'mailhe': optimal-step gradient descent
%   - 'MOD': least-squares update using the pseudo-inverse of the
%     coefficient Gram matrix amp*amp'
%   - 'KSVD': rank-1 SVD update of each atom together with its nonzero
%     coefficients
%   - 'MOCOD': regularised update of Ramirez et al. (needs mocodParams)
%   - 'rand' (lower case only): picks 'MOD', 'mailhe' or 'KSVD' with
%     equal probability
% - flow: 'sequential' (default) or 'parallel'. If sequential, the
%   residual is updated after each atom update. If parallel, the residual
%   is computed once from the input dictionary and is not refreshed during
%   the sweep. Sequential works better. Only used by 'ols', 'mailhe' and
%   'KSVD'.
% - rho: learning rate (default 0.1). For 'ols' it is the gradient
%   descent step (typical choice: 0.1). For 'mailhe' the descent step is
%   the optimal step times rho (typical choice: 1, although 2 or 3 seems
%   to work better). Not used by 'MOD', 'KSVD' and 'MOCOD'.
% - mocodParams: struct with the parameters of the MOCOD update (see
%   Ramirez et al., "Sparse modeling with universal priors and learned
%   incoherent dictionaries"). Required fields:
%   .Dprev: dictionary at the previous optimisation step
%   .zeta: coherence regularisation factor
%   .eta: atom norm regularisation factor
%
% Returns:
% - dico: the updated dictionary. Atoms updated by 'MOD', 'ols', 'mailhe'
%   and 'KSVD' have unit norm; atoms whose update is zero (zero MOD
%   estimate, zero gradient, unused atom in KSVD) are left unchanged.
%   'MOCOD' does not renormalise the atoms.
% - amp: the coefficients. 'MOD' rescales each row to compensate the atom
%   renormalisation, 'KSVD' recomputes the nonzero entries of each row;
%   the other algorithms return amp unchanged.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if ~exist('rho', 'var') || isempty(rho)
    rho = 0.1;
end

if ~exist('flow', 'var') || isempty(flow)
    % Must be a char literal: a bare `sequential` is an undefined name.
    flow = 'sequential';
end
% Decide once whether the residual is refreshed after each atom update.
% This does not depend on nargin, so an explicit 'sequential' flow is
% honoured even when rho is omitted.
isSequential = strcmp(flow, 'sequential');

res = sig - dico*amp;
nb_pattern = size(dico, 2);

if strcmp(type, 'rand')
    % Draw one of three algorithms with equal probability.
    x = rand();
    if x < 1/3
        type = 'MOD';
    elseif x < 2/3
        type = 'mailhe';
    else
        type = 'KSVD';
    end
end

switch upper(type)
    case 'MOD'
        % Least-squares dictionary sig*amp'*(amp*amp')^+. The
        % pseudo-inverse keeps an unused atom (zero row in amp, hence a
        % singular G) from turning into Inf/NaN: its estimate is zero and
        % the old atom is kept below. pinv needs a full matrix, and G is
        % only nb_pattern x nb_pattern.
        G = full(amp*amp');
        dico2 = (sig*amp')*pinv(G);
        for p = 1:nb_pattern
            n = norm(dico2(:,p));
            % Renormalise the atom and rescale its coefficients so that
            % the product dico*amp is unchanged.
            if n > 0
                dico(:,p) = dico2(:,p)/n;
                amp(p,:) = amp(p,:)*n;
            end
        end
    case 'OLS'
        for p = 1:nb_pattern
            grad = res*amp(p,:)';
            if norm(grad) > 0
                pat = dico(:,p) + rho*grad;
                pat = pat/norm(pat);
                if isSequential
                    % Keep res equal to sig - dico*amp once atom p is replaced.
                    res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU>
                end
                dico(:,p) = pat;
            end
        end
    case 'MAILHE'
        for p = 1:nb_pattern
            grad = res*amp(p,:)';
            if norm(grad) > 0
                % Scaling the current atom by the squared norm of its
                % coefficient row is equivalent (up to the normalisation
                % below) to taking the optimal gradient step, times rho.
                pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad;
                pat = pat/norm(pat);
                if isSequential
                    res = res + (dico(:,p)-pat)*amp(p,:);
                end
                dico(:,p) = pat;
            end
        end
    case 'KSVD'
        for p = 1:nb_pattern
            % Only the signals that actually use atom p take part in its update.
            index = find(amp(p,:) ~= 0);
            if ~isempty(index)
                % Residual of those signals with atom p's contribution restored.
                patch = res(:,index) + dico(:,p)*amp(p,index);
                [U, ~, ~] = svd(patch);
                % Keep the sign of the new atom consistent with the old one.
                if U(:,1)'*dico(:,p) > 0
                    dico(:,p) = U(:,1);
                else
                    dico(:,p) = -U(:,1);
                end
                dico(:,p) = dico(:,p)/norm(dico(:,p));
                amp(p,index) = dico(:,p)'*patch;
                if isSequential
                    res(:,index) = patch - dico(:,p)*amp(p,index);
                end
            end
        end
    case 'MOCOD'
        if ~exist('mocodParams', 'var') || isempty(mocodParams)
            error('dico_update:missingMocodParams', ...
                'The MOCOD update requires the mocodParams argument.');
        end
        zeta = mocodParams.zeta;
        eta = mocodParams.eta;
        Dprev = mocodParams.Dprev;
        DtD = Dprev'*Dprev;
        % Closed-form minimiser of the data term plus coherence (zeta) and
        % atom-norm (eta) penalties, linearised around Dprev.
        dico = (sig*amp' + 2*(zeta+eta)*Dprev) / ...
            (amp*amp' + 2*zeta*DtD + 2*eta*diag(diag(DtD)));
    otherwise
        error('dico_update:unknownType', ...
            'Unknown dictionary update type ''%s''.', type);
end
end