view DL/RLS-DLA/SMALL_rlsdla1.m @ 51:217a33ac374e

(none)
author idamnjanovic
date Mon, 14 Mar 2011 16:52:27 +0000
parents 6416fc12f2b8
children
line wrap: on
line source
function Dictionary = SMALL_rlsdla1(X, params)
%SMALL_RLSDLA1  Train two dictionaries with recursive rank-one updates.
%
%  Dictionary = SMALL_rlsdla1(X, params) codes columns of X one at a time
%  with omp2mex and, after every coded column, applies a rank-one update
%  to a dictionary and to its companion matrix C (scaled by 1/lambda
%  before each update). Two dictionaries, D1 and D2, start from the same
%  initial atoms and are trained on two disjoint sets of columns; the
%  result is their horizontal concatenation [D1 D2].
%
%  Inputs:
%    X       matrix whose columns are the training signals.
%    params  struct with the fields
%      Edata     (required) threshold: passed to omp2mex, and columns whose
%                Euclidean norm does not exceed it are not used for training.
%      dictsize  number of atoms per dictionary; overrides the size implied
%                by initdict. One of dictsize/initdict must be present.
%      initdict  either a vector of whole-number column indices into X, or
%                a matrix with size(X,1) rows and at least dictsize columns.
%                When absent, dictsize random columns of X with squared
%                norm above 1e-6 are used.
%      memusage  'low', 'normal' (default) or 'high'; stored in the global
%                memusage.
%      muthresh  must be non-negative (default 0.99); only validated here.
%      iternum   accepted for compatibility; not used by this function.
%
%  Output:
%    Dictionary  size(X,1)-by-(2*dictsize) matrix [D1 D2]. Atoms are not
%                normalised by this function.
%
%  Errors are raised for an invalid memusage string, a negative muthresh,
%  a missing dictionary size, fewer signals than atoms, or an invalid
%  initial dictionary.

global CODE_SPARSITY CODE_ERROR codemode
global MEM_LOW MEM_NORMAL MEM_HIGH memusage
global ompfunc ompparams exactsvd

CODE_SPARSITY = 1;
CODE_ERROR = 2;

MEM_LOW = 1;
MEM_NORMAL = 2;
MEM_HIGH = 3;


% signal norms, sorted ascending; p maps sorted position -> column of X %

X_norm = sqrt(sum(X.^2, 1));
[X_norm_sort, p] = sort(X_norm);


% coding threshold %
% (the codemode/Tdata/Edata selection of the K-SVD template is disabled
%  here; only params.Edata is read)

thresh = params.Edata;


% memory usage %

if (isfield(params,'memusage'))
  switch lower(params.memusage)
    case 'low'
      memusage = MEM_LOW;
    case 'normal'
      memusage = MEM_NORMAL;
    case 'high'
      memusage = MEM_HIGH;
    otherwise
      error('Invalid memory usage mode');
  end
else
  memusage = MEM_NORMAL;
end


% omp function %
% codemode is a global that this function never assigns. If it is unset
% (empty) the comparison is false and omp2 is selected.

if (codemode == CODE_SPARSITY)
  ompfunc = @omp;
else
  ompfunc = @omp2;
end


% mutual incoherence limit %

if (isfield(params,'muthresh'))
  muthresh = params.muthresh;
else
  muthresh = 0.99;
end
if (muthresh < 0)
  error('invalid muthresh value, must be non-negative');
end


% determine dictionary size %

if (isfield(params,'initdict'))
  if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:))))
    dictsize = length(params.initdict);
  else
    dictsize = size(params.initdict,2);
  end
end
if (isfield(params,'dictsize'))    % this supersedes the size determined by initdict
  dictsize = params.dictsize;
end
if (~exist('dictsize','var'))
  error('Dictionary size not specified: params must contain dictsize or initdict');
end

if (size(X,2) < dictsize)
  error('Number of training signals is smaller than number of atoms to train');
end


% initialize the dictionary %
% The atoms are drawn from X before the low-norm columns are zeroed below.

if (isfield(params,'initdict'))
  if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:))))
    D1 = X(:,params.initdict(1:dictsize));
  else
    if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize)
      error('Invalid initial dictionary');
    end
    D1 = params.initdict(:,1:dictsize);
  end
else
  data_ids = find(colnorms_squared(X) > 1e-6);   % ensure no zero data elements are chosen
  if (length(data_ids) < dictsize)
    error('Number of non-zero training signals is smaller than number of atoms to train');
  end
  perm = randperm(length(data_ids));
  D1 = X(:,data_ids(perm(1:dictsize)));
end
D2 = D1;   % both dictionaries start from the same atoms


% select the training signals %
% Columns with norm below thresh are zeroed; only columns with norm
% strictly above thresh are used for training.

X(:,p(X_norm_sort<thresh)) = 0;
p1 = p(X_norm_sort>thresh);


% initial C matrices %

C1 = (100000*thresh)*eye(dictsize);
C2 = (100000*thresh)*eye(dictsize);

% NOTE(review): dictimg is never used below. The call is kept only in case
% showdict has side effects - confirm and remove if it has none.
dictimg = showdict(D1,[8 8],round(sqrt(size(D1,2))),round(sqrt(size(D1,2))),'lines','highcontrast');


% training passes %
% D1 is trained on columns taken from the first half of p1, D2 on the
% matching positions in the second half. At most 20000 columns per half
% are used in each pass, re-drawn at random every pass.

lambda = 0.9998;
half = floor(size(p1,2)/2);
nsel = min(20000, half);

for j = 1:3

  if (half < 1)
    break;   % not enough usable signals to feed both dictionaries
  end

  p2 = randperm(half);
  p2 = sort(p2(1:nsel));
  data1 = X(:,p1(p2));
  data2 = X(:,p1(half+p2));

  C1(1,1) = 0;
  C2(1,1) = 0;

  for i = 1:size(data1,2)

    % sparse coding of the current pair of signals
    w1 = omp2mex(D1,data1(:,i),[],[],[],thresh,0,-1,-1,0);
    w2 = omp2mex(D2,data2(:,i),[],[],[],thresh,0,-1,-1,0);

    spind1 = find(w1);
    spind2 = find(w2);

    residual1 = data1(:,i) - D1 * w1;
    residual2 = data2(:,i) - D2 * w2;

    % scale C by 1/lambda before the update
    C1 = C1 * (1/lambda);
    C2 = C2 * (1/lambda);

    % u = C*w, computed from the non-zero coefficients only
    u1 = C1(:,spind1) * w1(spind1);
    u2 = C2(:,spind2) * w2(spind2);

    alfa1 = 1/(1 + w1' * u1);
    alfa2 = 1/(1 + w2' * u2);

    % rank-one updates of the dictionaries and of the C matrices
    D1 = D1 + (alfa1 * residual1) * u1';
    D2 = D2 + (alfa2 * residual2) * u2';

    C1 = C1 - (alfa1 * u1) * u1';
    C2 = C2 - (alfa2 * u2) * u2';

  end

  lambda = lambda + 0.0001;   % 0.9998, 0.9999, 1.0000 over the three passes

end

Dictionary = [D1 D2];

end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%             sparsecode               %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function Gamma = sparsecode(data,D,XtX,G,thresh)
%SPARSECODE  Dispatch a sparse-coding request to the global ompfunc handle.
%
%  Gamma = sparsecode(data,D,XtX,G,thresh) calls ompfunc in one of two
%  forms, chosen from the globals memusage and codemode:
%    - ompfunc(D'*data,G,thresh,...) when memusage is not below MEM_HIGH
%      and codemode equals CODE_SPARSITY;
%    - ompfunc(D,data,G,thresh,...) in every other case.
%  The contents of the global ompparams are appended to either call.
%  XtX is accepted but not used.

global CODE_SPARSITY codemode
global MEM_HIGH memusage
global ompfunc ompparams

% Decide which calling form applies; the direct form is the default.
projected = false;
if (memusage < MEM_HIGH)
  % low/normal memory: direct form
else
  if (codemode == CODE_SPARSITY)
    projected = true;
  end
end

if (projected)
  DtX = D'*data;
  Gamma = ompfunc(DtX,G,thresh,ompparams{:});
else
  Gamma = ompfunc(D,data,G,thresh,ompparams{:});
end

end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%             compute_err              %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


function err = compute_err(D,Gamma,data)
%COMPUTE_ERR  Error measure for the current dictionary and coefficients.
%
%  When the global codemode equals CODE_SPARSITY the result is the root of
%  the summed squared representation error (from reperror2) divided by
%  numel(data). Otherwise it is the number of non-zero entries of Gamma
%  divided by the number of columns of data.

global CODE_SPARSITY codemode

wantRmse = (codemode == CODE_SPARSITY);

if (wantRmse)
  sqErr = reperror2(data,D,Gamma);
  err = sqrt(sum(sqErr)/numel(data));
else
  nSignals = size(data,2);
  err = nnz(Gamma)/nSignals;
end

end



%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%           cleardict                  %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


function [D,cleared_atoms] = cleardict(D,X,muthresh,unused_sigs,replaced_atoms)
%CLEARDICT  Replace selected atoms of D with normalised signals from X.
%
%  [D,cleared_atoms] = cleardict(D,X,muthresh,unused_sigs,replaced_atoms)
%  visits the atoms in order. Atom j is replaced when replaced_atoms(j) is
%  zero and either its largest squared inner product with another atom
%  exceeds muthresh^2 or its use count is below 4. The use count is taken
%  from replaced_atoms itself. The replacement is the column of X indexed
%  by the last entry of unused_sigs, divided by its norm; that entry is
%  then removed from unused_sigs. cleared_atoms counts the replacements.

min_uses = 4;  % at least this number of samples must use the atom to be kept

n_atoms = size(D,2);
cleared_atoms = 0;
usecount = replaced_atoms;

for k = 1:n_atoms

  % inner products of atom k with every atom, self-product masked out
  corr_k = D'*D(:,k);
  corr_k(k) = 0;

  too_coherent = max(corr_k.^2) > muthresh^2;

  if ( (too_coherent || usecount(k)<min_uses) && ~replaced_atoms(k) )
    sig = X(:,unused_sigs(end));
    D(:,k) = sig / norm(sig);
    unused_sigs = unused_sigs(1:end-1);
    cleared_atoms = cleared_atoms + 1;
  end

end

end



%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%            misc functions            %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


function err2 = reperror2(X,D,Gamma)
%REPERROR2  Squared representation error of every column of X.
%
%  err2 = reperror2(X,D,Gamma) returns a 1-by-size(X,2) row vector whose
%  k-th entry is the sum of squares of X(:,k) - D*Gamma(:,k).

% compute in blocks to conserve memory
err2 = zeros(1,size(X,2));
blocksize = 2000;
for i = 1:blocksize:size(X,2)
  blockids = i : min(i+blocksize-1,size(X,2));
  % Sum explicitly along dimension 1: without it, a single-row X makes sum
  % collapse the whole block into one scalar.
  err2(blockids) = sum((X(:,blockids) - D*Gamma(:,blockids)).^2, 1);
end

end


function Y = colnorms_squared(X)
%COLNORMS_SQUARED  Squared Euclidean norm of every column of X.
%
%  Y = colnorms_squared(X) returns a 1-by-size(X,2) row vector with
%  Y(k) equal to the sum of squares of X(:,k).

% compute in blocks to conserve memory
Y = zeros(1,size(X,2));
blocksize = 2000;
for i = 1:blocksize:size(X,2)
  blockids = i : min(i+blocksize-1,size(X,2));
  % Sum explicitly along dimension 1: without it, a single-row X makes sum
  % collapse the whole block into one scalar.
  Y(blockids) = sum(X(:,blockids).^2, 1);
end

end