# HG changeset patch # User luisf # Date 1332258795 0 # Node ID 5fe60504a6a9e3ab85f6d8f9d4c77670b70f32ac # Parent 5bb579c9874ed42e6085f755ed4ad34061be61ad# Parent f3b6ddd2f04f0b4f85bdc39afa0a3caf5f3df86c Merge from 203:f3b6ddd2f04f diff -r f3b6ddd2f04f -r 5fe60504a6a9 DL/two-step DL/dico_color_separate.m --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/DL/two-step DL/dico_color_separate.m Tue Mar 20 15:53:15 2012 +0000 @@ -0,0 +1,95 @@ +function [colors nbColors] = dico_color_separate(dico, mu) + % DICO_COLOR_SEPARATE cluster several dictionaries in pairs of high correlation + % atoms. Called by dico_decorr_symetric. + % + % Parameters: + % -dico: the dictionaries (cell array of matrices, one atom per column) + % -mu: the correlation threshold + % + % Result: + % -colors: a cell array of indices. Two atoms with the same color have + % a correlation greater than mu + % -nbColors: the number of colors (i.e. pairs) that were assigned + + numDico = length(dico); + colors = cell(numDico,1); + for i = 1:numDico + colors{i} = zeros(size(dico{i},2),1); + end + + G = cell(numDico); + + % compute the correlations + for i = 1:numDico + for j = i+1:numDico + G{i,j} = abs(dico{i}'*dico{j}); + end + end + + % iterate on the correlations higher than mu + c = 1; + [maxCorr, i, j, m, n] = findMaxCorr(G); + while maxCorr > mu + % find the highest correlated pair + + % color them + colors{i}(m) = c; + colors{j}(n) = c; + c = c+1; + + % make sure these atoms never get selected again + % Set to zero relevant lines in the Gram Matrix + for j2 = i+1:numDico + G{i,j2}(m,:) = 0; + end + + for i2 = 1:i-1 + G{i2,i}(:,m) = 0; + end + + for j2 = j+1:numDico + G{j,j2}(n,:) = 0; + end + + for i2 = 1:j-1 + G{i2,j}(:,n) = 0; + end + + % find the next correlation + [maxCorr, i, j, m, n] = findMaxCorr(G); + end + + % complete the coloring with singletons + % index = find(colors==0); + % colors(index) = c:c+length(index)-1; + nbColors = c-1; +end + +function [val, i, j, m, n] = findMaxCorr(G) + %FINDMAXCORR find the maximal correlation in the cellular Gram matrix + % + % Input: + % -G: the Gram matrix + % + % Output: + % 
-val: value of the correlation + % -i,j,m,n: indices of the argmax. The maximal correlation is reached + % for the m^th atom of the i^th dictionary and the n^th atom of the + % j^th dictionary (val = -1 and empty indices when G holds no pair) + + val = -1; i = []; j = []; m = []; n = []; + for tmpI = 1:length(G) + for tmpJ = tmpI+1:length(G) + [tmpVal tmpM] = max(G{tmpI,tmpJ},[],1); + [tmpVal tmpN] = max(tmpVal); + if tmpVal > val + val = tmpVal; + i = tmpI; + j = tmpJ; + n = tmpN; + m = tmpM(n); + end + end + end +end + \ No newline at end of file diff -r f3b6ddd2f04f -r 5fe60504a6a9 DL/two-step DL/dico_decorr_symetric.m --- a/DL/two-step DL/dico_decorr_symetric.m Tue Mar 20 15:52:22 2012 +0000 +++ b/DL/two-step DL/dico_decorr_symetric.m Tue Mar 20 15:53:15 2012 +0000 @@ -1,15 +1,17 @@ -function dico = dico_decorr_symetric(dico, mu, amp) +function dico = dico_decorr_symetric(dico, mu) %DICO_DECORR decorrelate a dictionary % Parameters: - % dico: the dictionary + % dico: the dictionary, either a matrix or a cell array of matrices. % mu: the coherence threshold - % amp: the amplitude coefficients, only used to decide which atom to - % project % % Result: - % dico: a dictionary close to the input one with coherence mu. + % dico: if the input dico was a matrix, then a matrix close to the + % input one with coherence mu. + % If the input was a cell array, a cell array of the same size + % containing matrices such that the coherence between different cells + % is lower than mu. - eps = 1e-6; % define tolerance for normalisation term alpha + eps = 1e-3; % define tolerance for normalisation term alpha % convert mu to the to the mean direction theta = acos(mu)/2; @@ -23,39 +25,115 @@ % rank = randperm(length(dico)); % end - % several decorrelation iterations might be needed to reach global - % coherence mu. niter can be adjusted to needs. 
- niter = 1; - while max(max(abs(dico'*dico -eye(length(dico))))) > mu + 0.01 - % find pairs of high correlation atoms - colors = dico_color(dico, mu); + % if only one dictionary is provided, then decorrelate it + if ~iscell(dico) + % several decorrelation iterations might be needed to reach global + % coherence mu. niter can be adjusted to needs. + niter = 1; + while max(max(abs(dico'*dico -eye(length(dico))))) > mu + eps + % find pairs of high correlation atoms + colors = dico_color(dico, mu); + + % iterate on all pairs + nbColors = max(colors); + for c = 1:nbColors + index = find(colors==c); + if numel(index) == 2 + if dico(:,index(1))'*dico(:,index(2)) > 0 + %build the basis vectors + v1 = dico(:,index(1))+dico(:,index(2)); + v1 = v1/norm(v1); + v2 = dico(:,index(1))-dico(:,index(2)); + v2 = v2/norm(v2); + + dico(:,index(1)) = ctheta*v1+stheta*v2; + dico(:,index(2)) = ctheta*v1-stheta*v2; + else + v1 = dico(:,index(1))-dico(:,index(2)); + v1 = v1/norm(v1); + v2 = dico(:,index(1))+dico(:,index(2)); + v2 = v2/norm(v2); + + dico(:,index(1)) = ctheta*v1+stheta*v2; + dico(:,index(2)) = -ctheta*v1+stheta*v2; + end + end + end + niter = niter+1; + end + %if a cell array of dictionaries is provided, decorrelate among + %different dictionaries only + else + niter = 1; + numDicos = length(dico); + G = cell(numDicos); + maxCorr = 0; + for i = 1:numDicos + for j = i+1:numDicos + G{i,j} = dico{i}'*dico{j}; + maxCorr = max(maxCorr,max(max(abs(G{i,j})))); + end + end - % iterate on all pairs - nbColors = max(colors); - for c = 1:nbColors - index = find(colors==c); - if numel(index) == 2 - if dico(:,index(1))'*dico(:,index(2)) > 0 + while maxCorr > mu + eps + % find pairs of high correlation atoms + [colors nbColors] = dico_color_separate(dico, mu); + + % iterate on all pairs: for color c, locate dictionaries i < j and their atoms m, n + for c = 1:nbColors + for tmpI = 1:numDicos + index = find(colors{tmpI}==c); + if ~isempty(index) + i = tmpI; + m = index; + break; + end + end + for tmpJ = i+1:numDicos + index = find(colors{tmpJ}==c); 
+ if ~isempty(index) + j = tmpJ; + n = index; + break; + end + end + + if dico{i}(:,m)'*dico{j}(:,n) > 0 %build the basis vectors - v1 = dico(:,index(1))+dico(:,index(2)); + v1 = dico{i}(:,m)+dico{j}(:,n); v1 = v1/norm(v1); - v2 = dico(:,index(1))-dico(:,index(2)); + v2 = dico{i}(:,m)-dico{j}(:,n); v2 = v2/norm(v2); - dico(:,index(1)) = ctheta*v1+stheta*v2; - dico(:,index(2)) = ctheta*v1-stheta*v2; + dico{i}(:,m) = ctheta*v1+stheta*v2; + dico{j}(:,n) = ctheta*v1-stheta*v2; else - v1 = dico(:,index(1))-dico(:,index(2)); + v1 = dico{i}(:,m)-dico{j}(:,n); v1 = v1/norm(v1); - v2 = dico(:,index(1))+dico(:,index(2)); + v2 = dico{i}(:,m)+dico{j}(:,n); v2 = v2/norm(v2); - dico(:,index(1)) = ctheta*v1+stheta*v2; - dico(:,index(2)) = -ctheta*v1+stheta*v2; + dico{i}(:,m) = ctheta*v1+stheta*v2; + dico{j}(:,n) = -ctheta*v1+stheta*v2; + end + end + niter = niter+1; + + % Remove negative components and renormalize (NOTE(review): an atom left with no positive entry becomes all-zero, so the division by its norm below gives NaN) + for i = 1:length(dico) + dico{i} = max(dico{i},0); + for m = 1:size(dico{i},2) + dico{i}(:,m) = dico{i}(:,m)/norm(dico{i}(:,m)); + end + end + + maxCorr = 0; + for i = 1:numDicos + for j = i+1:numDicos + G{i,j} = dico{i}'*dico{j}; + maxCorr = max(maxCorr,max(max(abs(G{i,j})))); end end end - niter = niter+1; end end - diff -r f3b6ddd2f04f -r 5fe60504a6a9 DL/two-step DL/dico_update.m --- a/DL/two-step DL/dico_update.m Tue Mar 20 15:52:22 2012 +0000 +++ b/DL/two-step DL/dico_update.m Tue Mar 20 15:53:15 2012 +0000 @@ -1,5 +1,5 @@ function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho) - + %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% % [dico, amp] = dico_update(dico, sig, amp, type, flow, rho) % @@ -10,44 +10,56 @@ % - sig: the training data % - amp: the amplitude coefficients as a sparse matrix % - type: the algorithm can be one of the following - % - ols: fixed step gradient descent - % - mailhe: optimal step gradient descent (can be implemented as a - % default for ols?) 
- % - MOD: pseudo-inverse of the coefficients - % - KSVD: already implemented by Elad + % - ols: fixed step gradient descent, as described in Olshausen & + % Field95 + % - opt: optimal step gradient descent, as described in Mailhe et + % al.08 + % - MOD: pseudo-inverse of the coefficients, as described in Engan99 + % - KSVD: PCA update as described in Aharon06. For fast applications, + % use KSVDbox rather than this code. + % - LGD: large step gradient descent. Equivalent to 'opt' with + % rho=2. % - flow: 'sequential' or 'parallel'. If sequential, the residual is % updated after each atom update. If parallel, the residual is only - % updated once the whole dictionary has been computed. Sequential works - % better, there may be no need to implement parallel. Not used with + % updated once the whole dictionary has been computed. + % Default: Sequential (sequential usually works better). Not used with % MOD. % - rho: learning rate. If the type is 'ols', it is the descent step of - % the gradient (typical choice: 0.1). If the type is 'mailhe', the - % descent step is the optimal step*rho (typical choice: 1, although 2 - % or 3 seems to work better). Not used for MOD and KSVD. + % the gradient (default: 0.1). If the type is 'opt', the + % descent step is the optimal step*rho (default: 1, although 2 works + % better. See LGD for more details). Not used for MOD and KSVD. 
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - if ~ exist( 'rho', 'var' ) || isempty(rho) - rho = 0.1; - end if ~ exist( 'flow', 'var' ) || isempty(flow) - flow = sequential; + flow = 'sequential'; end res = sig - dico*amp; nb_pattern = size(dico, 2); + % if the type is random, then randomly pick another type switch type case 'rand' x = rand(); if x < 1/3 type = 'MOD'; elseif type < 2/3 - type = 'mailhe'; + type = 'opt'; else type = 'KSVD'; end end + % set the learning rate to default if not provided + if ~ exist( 'rho', 'var' ) || isempty(rho) + switch type + case 'ols' + rho = 0.1; + case 'opt' + rho = 1; + end + end + switch type case 'MOD' G = amp*amp'; @@ -72,14 +84,15 @@ dico(:,p) = pat; end end - case 'mailhe' + case 'opt' for p = 1:nb_pattern - grad = res*amp(p,:)'; + vec = amp(p,:); % row of coefficients of atom p, reused below + grad = res*vec'; if norm(grad) > 0 - pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad; + pat = (vec*vec')*dico(:,p) + rho*grad; pat = pat/norm(pat); if nargin >5 && strcmp(flow, 'sequential') - res = res + (dico(:,p)-pat)*amp(p,:); + res = res + (dico(:,p)-pat)*vec; end dico(:,p) = pat; end @@ -89,7 +102,7 @@ index = find(amp(p,:)~=0); if ~isempty(index) patch = res(:,index)+dico(:,p)*amp(p,index); - [U,S,V] = svd(patch); + [U,~,V] = svd(patch); if U(:,1)'*dico(:,p) > 0 dico(:,p) = U(:,1); else