function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
%
% Perform one iteration of dictionary update for dictionary learning,
% i.e. improve dico (and possibly amp) so that dico*amp better
% approximates sig.
%
% Parameters:
% - dico: the initial dictionary, with atoms as columns
% - sig: the training data, one training signal per column (same number
%   of rows as dico, same number of columns as amp)
% - amp: the amplitude coefficients as a sparse matrix, one row per atom
% - type: the update algorithm, one of the following
%   - 'ols': fixed step gradient descent (step rho)
%   - 'mailhe': optimal step gradient descent (step rho times the
%     optimal step 1/||amp(p,:)||^2)
%   - 'MOD': method of optimal directions, dico = sig*pinv(amp)
%   - 'KSVD': rank-one SVD update of each atom together with its
%     nonzero coefficients
%   - 'rand': one of 'MOD', 'mailhe' or 'KSVD', drawn with equal
%     probability (consumes one rand() draw)
% - flow: 'sequential' (default) or 'parallel'. If sequential, the
%   residual is updated after each atom update. If parallel, the residual
%   is computed once from the input dictionary and not updated during the
%   pass. Not used with MOD.
% - rho: learning rate (default 0.1). If the type is 'ols', it is the
%   descent step of the gradient (typical choice: 0.1). If the type is
%   'mailhe', the descent step is the optimal step*rho (typical choice: 1,
%   although 2 or 3 seems to work better). Not used for MOD and KSVD.
%
% Outputs:
% - dico: the updated dictionary. Every updated atom is renormalized to
%   unit norm. Atoms that are not updated (zero gradient, unused atom)
%   are returned unchanged.
% - amp: the coefficients. They are unchanged for 'ols' and 'mailhe'.
%   For 'MOD' each row is rescaled to compensate for the atom
%   normalization. For 'KSVD' the nonzero entries of each row are
%   recomputed.
%
% Errors: an unknown type raises 'dico_update:unknownType'.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

if ~exist('rho', 'var') || isempty(rho)
    rho = 0.1;
end

if ~exist('flow', 'var') || isempty(flow)
    flow = 'sequential';
end
% Any value other than 'sequential' behaves as 'parallel'.
sequential = strcmp(flow, 'sequential');

res = sig - dico*amp;
nb_pattern = size(dico, 2);

% 'rand' picks one of the three main algorithms with equal probability.
if strcmp(type, 'rand')
    x = rand();
    if x < 1/3
        type = 'MOD';
    elseif x < 2/3
        type = 'mailhe';
    else
        type = 'KSVD';
    end
end

switch type
    case 'MOD'
        % dico = sig*pinv(amp), computed as sig*amp'*pinv(amp*amp').
        % The pseudo-inverse keeps an unused atom (all-zero row of amp,
        % hence singular amp*amp') from turning into Inf/NaN. Such an
        % atom gets a zero column in dico2 and is left untouched below.
        G = full(amp*amp');
        dico2 = (sig*amp') * pinv(G);
        for p = 1:nb_pattern
            n = norm(dico2(:,p));
            % renormalize the atom and move its scale into the coefficients
            if n > 0
                dico(:,p) = dico2(:,p)/n;
                amp(p,:) = amp(p,:)*n;
            end
        end
    case {'ols', 'mailhe'}
        for p = 1:nb_pattern
            grad = res*amp(p,:)';
            if norm(grad) > 0
                if strcmp(type, 'ols')
                    % fixed step rho
                    pat = dico(:,p) + rho*grad;
                else
                    % step rho/||amp(p,:)||^2, written multiplied through by
                    % ||amp(p,:)||^2 since the atom is renormalized below
                    pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad;
                end
                pat = pat/norm(pat);
                if sequential
                    res = res + (dico(:,p)-pat)*amp(p,:);
                end
                dico(:,p) = pat;
            end
        end
    case 'KSVD'
        for p = 1:nb_pattern
            index = find(amp(p,:) ~= 0);
            if ~isempty(index)
                % residual on the signals that use atom p, with the
                % contribution of atom p added back
                patch = res(:,index) + dico(:,p)*amp(p,index);
                % only the leading left singular vector is needed
                [U, ~, ~] = svd(full(patch), 'econ');
                atom = U(:,1);
                % keep the sign that agrees with the previous atom
                if atom'*dico(:,p) <= 0
                    atom = -atom;
                end
                dico(:,p) = atom/norm(atom);
                amp(p,index) = dico(:,p)'*patch;
                if sequential
                    res(:,index) = patch - dico(:,p)*amp(p,index);
                end
            end
        end
    otherwise
        error('dico_update:unknownType', ...
            'Unknown dictionary update type ''%s''.', type);
end
end