smallbox: DL/two-step DL/dico_color_separate.m @ 216:a986ee86651e luisf_dev
Calls SMALLboxInit at the beginning of both solve and learn, in order not to lose the SMALL_path variable.
author: luisf <luis.figueira@eecs.qmul.ac.uk>
date: Thu, 22 Mar 2012 11:41:04 +0000
parents: 69ce11724b1f
function [colors, nbColors] = dico_color_separate(dico, mu)
% DICO_COLOR_SEPARATE cluster several dictionaries in pairs of highly
% correlated atoms. Called by dico_decorr.
%
% Parameters:
%   -dico: the dictionaries (cell array; the atoms are the columns of
%    each dictionary)
%   -mu: the correlation threshold
%
% Result:
%   -colors: a cell array of color indices, one vector per dictionary.
%    Two atoms with the same color have a correlation greater than mu;
%    atoms that were not paired keep the color 0.
%   -nbColors: the number of colors (i.e. pairs) assigned

numDico = length(dico);
colors = cell(numDico,1);
for i = 1:numDico
    % one entry per atom; size(.,2) counts columns even when the
    % dictionary has fewer atoms than rows (length would not)
    colors{i} = zeros(size(dico{i},2),1);
end
G = cell(numDico);

% compute the correlations
for i = 1:numDico
    for j = i+1:numDico
        G{i,j} = abs(dico{i}'*dico{j});
    end
end

% iterate on the correlations higher than mu
c = 1;
[maxCorr, i, j, m, n] = findMaxCorr(G);
while maxCorr > mu
    % color the highest correlated pair
    colors{i}(m) = c;
    colors{j}(n) = c;
    c = c+1;

    % make sure these atoms never get selected again:
    % set to zero the relevant lines in the Gram matrix
    for j2 = i+1:numDico
        G{i,j2}(m,:) = 0;
    end
    for i2 = 1:i-1
        G{i2,i}(:,m) = 0;
    end
    for j2 = j+1:numDico
        G{j,j2}(n,:) = 0;
    end
    for i2 = 1:j-1
        G{i2,j}(:,n) = 0;
    end

    % find the next correlation
    [maxCorr, i, j, m, n] = findMaxCorr(G);
end

% complete the coloring with singletons (disabled)
% index = find(colors==0);
% colors(index) = c:c+length(index)-1;

nbColors = c-1;
end

function [val, i, j, m, n] = findMaxCorr(G)
%FINDMAXCORR find the maximal correlation in the cellular Gram matrix
%
% Input:
%   -G: the Gram matrix
%
% Output:
%   -val: value of the correlation
%   -i,j,m,n: indices of the argmax. The maximal correlation is reached
%    for the m^th atom of the i^th dictionary and the n^th atom of the
%    j^th dictionary

val = -1;
% default indices, so every output is assigned even when G holds no pair
% (e.g. a single dictionary)
i = 0; j = 0; m = 0; n = 0;
for tmpI = 1:length(G)
    for tmpJ = tmpI+1:length(G)
        % column-wise maxima, then the best column
        [tmpVal, tmpM] = max(G{tmpI,tmpJ},[],1);
        [tmpVal, tmpN] = max(tmpVal);
        if tmpVal > val
            val = tmpVal;
            i = tmpI;
            j = tmpJ;
            n = tmpN;
            m = tmpM(n);
        end
    end
end
end
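The commented-out singleton step indexes `colors` as if it were a numeric array, but `colors` is a cell array with one vector per dictionary, so `colors==0` would raise an error if those lines were re-enabled. Below is a minimal sketch of how that step could be written for the cell array, assuming each atom left uncolored (color 0) should receive its own new color, numbered after the `nbColors` pair colors. The input values are a made-up example and `nbColorsTotal` is a hypothetical name; this is not code from smallbox.

```matlab
% Hypothetical output of dico_color_separate for two dictionaries
% (3 and 2 atoms) in which one cross-dictionary pair received color 1.
colors   = {[1; 0; 0], [0; 1]};
nbColors = 1;

% Assumed rule: give every uncolored atom (color 0) its own fresh color,
% continuing the numbering after the nbColors pair colors.
c = nbColors + 1;                        % first unused color
for k = 1:numel(colors)
    idx = find(colors{k} == 0);          % uncolored atoms of dictionary k
    colors{k}(idx) = c:c+numel(idx)-1;   % one new color per singleton
    c = c + numel(idx);
end
nbColorsTotal = c - 1;                   % pair colors plus singleton colors
```

Looping over the dictionaries keeps the color numbering contiguous across all of them, so `colors{1}` becomes `[1; 2; 3]`, `colors{2}` becomes `[4; 1]`, and `nbColorsTotal` is 4.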