function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
%
% perform one iteration of dictionary update for dictionary learning
%
% parameters:
% - dico: the initial dictionary with atoms as columns
% - sig: the training data, approximated as dico*amp
% - amp: the amplitude coefficients as a sparse matrix
% - type: the algorithm can be one of the following
%   - ols: fixed step gradient descent, as described in Olshausen &
%     Field95
%   - opt: optimal step gradient descent, as described in Mailhe et
%     al.08
%   - MOD: pseudo-inverse of the coefficients, as described in Engan99
%   - KSVD: PCA update as described in Aharon06. For fast applications,
%     use KSVDbox rather than this code.
%   - LGD: large step gradient descent. Equivalent to 'opt' with
%     rho=2.
%   - rand: pick one of 'MOD', 'opt' or 'KSVD' uniformly at random
%   Any other value raises the error 'dico_update:unknownType'.
% - flow: 'sequential' or 'parallel'. If sequential, the residual is
%   updated after each atom update. If parallel, the residual is only
%   updated once the whole dictionary has been computed.
%   Default: 'sequential' (sequential usually works better). Not used
%   with MOD.
% - rho: learning rate. If the type is 'ols', it is the descent step of
%   the gradient (default: 0.1). If the type is 'opt', the descent step
%   is the optimal step*rho (default: 1, although 2 works better. See
%   LGD for more details). Not used for MOD, KSVD and LGD.
%
% outputs:
% - dico: the updated dictionary; every atom that gets updated is
%   renormalized to unit norm. Atoms whose coefficient row is all zero
%   keep their previous value.
% - amp: the coefficients, rescaled (MOD) or recomputed on their support
%   (KSVD) to match the updated atoms; unchanged for ols, opt and LGD.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

if ~exist('flow', 'var') || isempty(flow)
    flow = 'sequential';
end
% Honour flow whenever it is set (explicitly or by default). Previously the
% sequential residual update also required rho to be passed (nargin > 5).
sequential = strcmp(flow, 'sequential');

res = sig - dico*amp;
nb_pattern = size(dico, 2);

% if the type is random, then randomly pick another type
if strcmp(type, 'rand')
    x = rand();
    if x < 1/3
        type = 'MOD';
    elseif x < 2/3
        type = 'opt';
    else
        type = 'KSVD';
    end
end

% set the learning rate to default if not provided
if ~exist('rho', 'var') || isempty(rho)
    switch type
        case 'ols'
            rho = 0.1;
        case 'opt'
            rho = 1;
    end
end

switch type
    case 'MOD'
        % Least-squares update dico2 = sig*amp'*(amp*amp')^-1, solved with
        % mrdivide instead of an explicit inverse. The solve is restricted
        % to the atoms that are actually used: an all-zero row of amp makes
        % amp*amp' singular and would otherwise corrupt every atom.
        used = full(any(amp ~= 0, 2));
        dico2 = zeros(size(dico));
        if any(used)
            amp_used = amp(used, :);
            dico2(:, used) = (sig*amp_used') / (amp_used*amp_used');
        end
        for p = 1:nb_pattern
            n = norm(dico2(:,p));
            % renormalize, moving the scale into the coefficients
            if n > 0
                dico(:,p) = dico2(:,p)/n;
                amp(p,:) = amp(p,:)*n;
            end
        end
    case 'ols'
        for p = 1:nb_pattern
            grad = res*amp(p,:)';
            if norm(grad) > 0
                pat = dico(:,p) + rho*grad;
                pat = pat/norm(pat);
                if sequential
                    res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU>
                end
                dico(:,p) = pat;
            end
        end
    case {'opt', 'LGD'}
        % LGD is 'opt' with a fixed step multiplier of 2 (rho is ignored)
        if strcmp(type, 'LGD')
            step = 2;
        else
            step = rho;
        end
        for p = 1:nb_pattern
            % only the signals that use atom p contribute to its update
            index = find(amp(p,:)~=0);
            vec = amp(p,index);
            grad = res(:,index)*vec';
            if norm(grad) > 0
                pat = (vec*vec')*dico(:,p) + step*grad;
                pat = pat/norm(pat);
                if sequential
                    res(:,index) = res(:,index) + (dico(:,p)-pat)*vec;
                end
                dico(:,p) = pat;
            end
        end
    case 'KSVD'
        for p = 1:nb_pattern
            index = find(amp(p,:)~=0);
            if ~isempty(index)
                % add atom p's contribution back on its support, then take
                % the leading singular vector of that patch as the new atom
                patch = res(:,index) + dico(:,p)*amp(p,index);
                [U, ~, ~] = svd(full(patch), 'econ');
                % keep the sign closest to the previous atom
                if U(:,1)'*dico(:,p) > 0
                    dico(:,p) = U(:,1);
                else
                    dico(:,p) = -U(:,1);
                end
                dico(:,p) = dico(:,p)/norm(dico(:,p));
                amp(p,index) = dico(:,p)'*patch;
                if sequential
                    res(:,index) = patch-dico(:,p)*amp(p,index);
                end
            end
        end
    otherwise
        error('dico_update:unknownType', ...
              'Unknown dictionary update type ''%s''.', type);
end
end