# HG changeset patch # User idamnjanovic # Date 1300282862 0 # Node ID 55faa9b5d1acc9e3a6475c2f6e375f71eae4bd9e # Parent 8288a23f041fbe001e328dade6acc1056d9be7b0 diff -r 8288a23f041f -r 55faa9b5d1ac DL/RLS-DLA/SMALL_rlsdla 05032011.m --- a/DL/RLS-DLA/SMALL_rlsdla 05032011.m Tue Mar 15 15:30:54 2011 +0000 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,488 +0,0 @@ -function Dictionary = SMALL_rlsdla(X, params) - - - - - -global CODE_SPARSITY CODE_ERROR codemode -global MEM_LOW MEM_NORMAL MEM_HIGH memusage -global ompfunc ompparams exactsvd - -CODE_SPARSITY = 1; -CODE_ERROR = 2; - -MEM_LOW = 1; -MEM_NORMAL = 2; -MEM_HIGH = 3; - - -% p = randperm(size(X,2)); - - % coding mode % -%X_norm=sqrt(sum(X.^2, 1)); -% X_norm_1=sum(abs(X)); -% X_norm_inf=max(abs(X)); -%[X_norm_sort, p]=sort(X_norm);%, 'descend'); -% [X_norm_sort1, p5]=sort(X_norm_1);%, 'descend'); - -% if (isfield(params,'codemode')) -% switch lower(params.codemode) -% case 'sparsity' -% codemode = CODE_SPARSITY; -% thresh = params.Tdata; -% case 'error' -% codemode = CODE_ERROR; -% thresh = params.Edata; -% otherwise -% error('Invalid coding mode specified'); -% end -% elseif (isfield(params,'Tdata')) -% codemode = CODE_SPARSITY; -% thresh = params.Tdata; -% elseif (isfield(params,'Edata')) -% codemode = CODE_ERROR; -% thresh = params.Edata; -% -% else -% error('Data sparse-coding target not specified'); -% end - -thresh = params.Edata; -% max number of atoms % - -% if (codemode==CODE_ERROR && isfield(params,'maxatoms')) -% ompparams{end+1} = 'maxatoms'; -% ompparams{end+1} = params.maxatoms; -% end - - -% memory usage % - -if (isfield(params,'memusage')) - switch lower(params.memusage) - case 'low' - memusage = MEM_LOW; - case 'normal' - memusage = MEM_NORMAL; - case 'high' - memusage = MEM_HIGH; - otherwise - error('Invalid memory usage mode'); - end -else - memusage = MEM_NORMAL; -end - - -% iteration count % - -if (isfield(params,'iternum')) - iternum = params.iternum; -else - iternum = 10; -end - - -% omp function % - -if (codemode == CODE_SPARSITY) - ompfunc = @omp; -else - ompfunc = @omp2; -end - - -% % status messages % -% -% printiter = 0; -% printreplaced = 0; -% printerr = 0; -% printgerr = 0; -% -% verbose = 't'; -% msgdelta = -1; -% - -% -% for i = 1:length(verbose) -% switch lower(verbose(i)) -% case 'i' -% printiter = 1; -% case 'r' -% printiter = 1; -% printreplaced = 1; -% case 't' -% printiter = 1; -% printerr = 1; -% if (isfield(params,'testdata')) -% printgerr = 1; -% end -% end -% end -% -% if (msgdelta<=0 || isempty(verbose)) -% msgdelta = -1; -% end -% -% ompparams{end+1} = 'messages'; -% ompparams{end+1} = msgdelta; -% -% -% -% % compute error flag % -% -% comperr = (nargout>=3 || printerr); -% -% -% % validation flag % -% -% testgen = 0; -% if (isfield(params,'testdata')) -% testdata = params.testdata; -% if (nargout>=4 || printgerr) -% testgen = 1; -% end -% end - -% -% % data norms % -% -% XtX = []; XtXg = []; -% if (codemode==CODE_ERROR && memusage==MEM_HIGH) -% XtX = colnorms_squared(data); -% if (testgen) -% XtXg = colnorms_squared(testdata); -% end -% end - - -% mutual incoherence limit % - -if (isfield(params,'muthresh')) - muthresh = params.muthresh; -else - muthresh = 0.99; -end -if (muthresh < 0) - error('invalid muthresh value, must be non-negative'); -end - - - - - -% determine dictionary size % - -if (isfield(params,'initdict')) - if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) - dictsize = length(params.initdict); - else - dictsize = size(params.initdict,2); - end -end -if (isfield(params,'dictsize')) % this superceedes the size determined by initdict - dictsize = params.dictsize; -end - -if (size(X,2) < dictsize) - error('Number of training signals is smaller than number of atoms to train'); -end - - -% initialize the dictionary % - -if (isfield(params,'initdict')) - if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) - D = X(:,params.initdict(1:dictsize)); - else - if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2) 1e-6); % ensure no zero data elements are chosen - perm = randperm(length(data_ids)); - D = X(:,data_ids(perm(1:dictsize))); -end - - -% normalize the dictionary % - -% D = normcols(D); -% DtD=D'*D; - -err = zeros(1,iternum); -gerr = zeros(1,iternum); - -if (codemode == CODE_SPARSITY) - errstr = 'RMSE'; -else - errstr = 'mean atomnum'; -end -%X(:,p(X_norm_sortthresh); -%p1=p(X_norm_sort>thresh); -% X(:,setxor(p1,1:end))=0; -% X_im=col2imstep(X, [256 256], [8 8]); -% figure(10); imshow(X_im); -% if iternum==2 -% D(:,1)=D(:,2); -% end -%p1=p1(p2(1:40000)); -%end-min(40000, end)+1:end));%1:min(40000, end))); -%p1 = randperm(size(data,2));%size(data,2) -%data=data(:,p1); - -C=(100000*thresh)*eye(dictsize); -% figure(11); -w=zeros(dictsize,1); -replaced=zeros(dictsize,1); -u=zeros(dictsize,1); -% dictimg = showdict(D,[8 8],round(sqrt(size(D,2))),round(sqrt(size(D,2))),'lines','highcontrast'); -% figure(11);imshow(imresize(dictimg,2,'nearest')); -% pause(1); -lambda=0.9997%0.99986;%3+0.0001*params.linc; -for j=1:1 - %data=X; -if size(X,2)>40000 - p2 = randperm(size(X,2)); - - p2=sort(p2(1:40000));%min(floor(size(p1,2)/2),40000))); - size(p2,2) - data=X(:,p2); -elseif size(X,2)>0 - %p2 = randperm(size(p1,2)); - size(X,2) - data=X; -else - break; -end -% figure(1); -% plot(sqrt(sum(data.^2, 1))); -% a=size(data,2)/4; -% lambda0=0.99;%1-16/numS+iternum*0.0001-0.0002 -%C(1,1)=0; -modi=1000; -for i = 1:size(data,2) -% if norm(data(:,i))>thresh - % par.multA= @(x,par) multMatr(D,x); % user function y=Ax - % par.multAt=@(x,par) multMatrAdj(D,x); % user function y=A'*x - % par.y=data(:,i); - % w=SolveFISTA(D,data(:,i),'lambda',0.5*thresh); - % w=sesoptn(zeros(dictsize,1),par.func_u, par.func_x, par.multA, par.multAt,options,par); - %w = SMALL_chol(D,data(:,i), 256,32, thresh);% - %w = sparsecode(data(:,i), D, [], [], thresh); - w = omp2mex(D,data(:,i),[],[],[],thresh,0,-1,-1,0); - - %w(find(w<1))=0; - %^2; -% lambda(i)=1-0.001/(1+i/a); -% if i=35000) -% modi=100; -% pause -% end; -% end -% end -end -%p1=p1(setxor(p2,1:end)); -%[D,cleared_atoms] = cleardict(D,X,muthresh,p1,replaced); -%replaced=zeros(dictsize,1); -% W=sparsecode(data, D, [], [], thresh); -% data=D*W; -%lambda=lambda+0.0002 -end -%Gamma=mexLasso(data, D, param); -%err=compute_err(D,Gamma, data); -%[y,i]=max(err); -%D(:,1)=data(:,i)/norm(data(:,i)); - -Dictionary = D;%D(:,p); -% figure(3); -% plot(lambda); -% mean(lambda); -% figure(4+j);plot(w_sp); -end - -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -% sparsecode % -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - -function Gamma = sparsecode(data,D,XtX,G,thresh) - -global CODE_SPARSITY codemode -global MEM_HIGH memusage -global ompfunc ompparams - -if (memusage < MEM_HIGH) - Gamma = ompfunc(D,data,G,thresh,ompparams{:}); - -else % memusage is high - - if (codemode == CODE_SPARSITY) - Gamma = ompfunc(D'*data,G,thresh,ompparams{:}); - - else - Gamma = ompfunc(D, data, G, thresh,ompparams{:}); - end - -end - -end - - -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -% compute_err % -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - - -function err = compute_err(D,Gamma,data) - -global CODE_SPARSITY codemode - -if (codemode == CODE_SPARSITY) - err = sqrt(sum(reperror2(data,D,Gamma))/numel(data)); -else - err = nnz(Gamma)/size(data,2); -end - -end - - - -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -% cleardict % -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - - -function [D,cleared_atoms] = cleardict(D,X,muthresh,unused_sigs,replaced_atoms) - -use_thresh = 4; % at least this number of samples must use the atom to be kept - -dictsize = size(D,2); - -% compute error in blocks to conserve memory -% err = zeros(1,size(X,2)); -% blocks = [1:3000:size(X,2) size(X,2)+1]; -% for i = 1:length(blocks)-1 -% err(blocks(i):blocks(i+1)-1) = sum((X(:,blocks(i):blocks(i+1)-1)-D*Gamma(:,blocks(i):blocks(i+1)-1)).^2); -% end - -cleared_atoms = 0; -usecount = replaced_atoms;%sum(abs(Gamma)>1e-7, 2); - -for j = 1:dictsize - - % compute G(:,j) - Gj = D'*D(:,j); - Gj(j) = 0; - - % replace atom - if ( (max(Gj.^2)>muthresh^2 || usecount(j)=3 || printerr); -% -% -% % validation flag % -% -% testgen = 0; -% if (isfield(params,'testdata')) -% testdata = params.testdata; -% if (nargout>=4 || printgerr) -% testgen = 1; -% end -% end - -% -% % data norms % -% -% XtX = []; XtXg = []; -% if (codemode==CODE_ERROR && memusage==MEM_HIGH) -% XtX = colnorms_squared(data); -% if (testgen) -% XtXg = colnorms_squared(testdata); -% end -% end - - -% mutual incoherence limit % - -if (isfield(params,'muthresh')) - muthresh = params.muthresh; -else - muthresh = 0.99; -end -if (muthresh < 0) - error('invalid muthresh value, must be non-negative'); -end - - - - - -% determine dictionary size % - -if (isfield(params,'initdict')) - if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) - dictsize = length(params.initdict); - else - dictsize = size(params.initdict,2); - end -end -if (isfield(params,'dictsize')) % this superceedes the size determined by initdict - dictsize = params.dictsize; -end - -if (size(X,2) < dictsize) - error('Number of training signals is smaller than number of atoms to train'); -end - - -% initialize the dictionary % - -if (isfield(params,'initdict')) - if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) - D1 = X(:,params.initdict(1:dictsize)); - D2 = X(:,params.initdict(1:dictsize)); - else - if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2) 1e-6); % ensure no zero data elements are chosen - perm = randperm(length(data_ids)); - D = X(:,data_ids(perm(1:dictsize))); -end - - -% normalize the dictionary % - -% D = normcols(D); -% DtD=D'*D; - -err = zeros(1,iternum); -gerr = zeros(1,iternum); - -if (codemode == CODE_SPARSITY) - errstr = 'RMSE'; -else - errstr = 'mean atomnum'; -end -X(:,p(X_norm_sortthresh); - -%p1=p1(p2(1:40000)); -%end-min(40000, end)+1:end));%1:min(40000, end))); -%p1 = randperm(size(data,2));%size(data,2) -%data=data(:,p1); - -C1=(100000*thresh)*eye(dictsize); -C2=(100000*thresh)*eye(dictsize); -% figure(11); -w=zeros(dictsize,1); -replaced=zeros(dictsize,1); -u=zeros(dictsize,1); - dictimg = showdict(D1,[8 8],round(sqrt(size(D1,2))),round(sqrt(size(D1,2))),'lines','highcontrast'); -% h=imshow(imresize(dictimg,2,'nearest')); -lambda=0.9998 -for j=1:3 -if size(p1,2)>20000 - p2 = randperm(floor(size(p1,2)/2)); - p2=sort(p2(1:20000)); - data1=X(:,p1(p2)); - data2=X(:,p1(floor(size(p1,2)/2)+p2)); -elseif size(p1,2)>0 - data=X(:,p1); -else - break; -end -% figure(1); -% plot(sqrt(sum(data.^2, 1))); -% a=size(data,2)/4; -% lambda0=0.99;%1-16/numS+iternum*0.0001-0.0002 -C1(1,1)=0; -C2(1,1)=0; -for i = 1:size(data1,2) -% if norm(data(:,i))>thresh - % par.multA= @(x,par) multMatr(D,x); % user function y=Ax - % par.multAt=@(x,par) multMatrAdj(D,x); % user function y=A'*x - % par.y=data(:,i); - % w=SolveFISTA(D,data(:,i),'lambda',0.5*thresh); - % w=sesoptn(zeros(dictsize,1),par.func_u, par.func_x, par.multA, par.multAt,options,par); - %w = SMALL_chol(D,data(:,i), 256,32, thresh);% - %w = sparsecode(data(:,i), D, [], [], thresh); - w1 = omp2mex(D1,data1(:,i),[],[],[],thresh,0,-1,-1,0); - w2 = omp2mex(D2,data2(:,i),[],[],[],thresh,0,-1,-1,0); - %w(find(w<1))=0; - %^2; -% lambda(i)=1-0.001/(1+i/a); -% if i1e-7, 2); - -for j = 1:dictsize - - % compute G(:,j) - Gj = D'*D(:,j); - Gj(j) = 0; - - % replace atom - if ( (max(Gj.^2)>muthresh^2 || usecount(j)=3 || printerr); -% -% -% % validation flag % -% -% testgen = 0; -% if (isfield(params,'testdata')) -% testdata = params.testdata; -% if (nargout>=4 || printgerr) -% testgen = 1; -% end -% end - -% -% % data norms % -% -% XtX = []; XtXg = []; -% if (codemode==CODE_ERROR && memusage==MEM_HIGH) -% XtX = colnorms_squared(data); -% if (testgen) -% XtXg = colnorms_squared(testdata); -% end -% end - - -% mutual incoherence limit % - -if (isfield(params,'muthresh')) - muthresh = params.muthresh; -else - muthresh = 0.99; -end -if (muthresh < 0) - error('invalid muthresh value, must be non-negative'); -end - - - - - -% determine dictionary size % - -if (isfield(params,'initdict')) - if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) - dictsize = length(params.initdict); - else - dictsize = size(params.initdict,2); - end -end -if (isfield(params,'dictsize')) % this superceedes the size determined by initdict - dictsize = params.dictsize; -end - -if (size(X,2) < dictsize) - error('Number of training signals is smaller than number of atoms to train'); -end - - -% initialize the dictionary % - -if (isfield(params,'initdict')) - if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) - D = X(:,params.initdict(1:dictsize)); - else - if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2) 1e-6); % ensure no zero data elements are chosen - perm = randperm(length(data_ids)); - D = X(:,data_ids(perm(1:dictsize))); -end - -% normalize the dictionary % - -% D = normcols(D); -% DtD=D'*D; - -err = zeros(1,iternum); -gerr = zeros(1,iternum); - -if (codemode == CODE_SPARSITY) - errstr = 'RMSE'; -else - errstr = 'mean atomnum'; -end -%X(:,p(X_norm_sortthresh); -p1=p(X_norm_sort>thresh); -tic; idx=kmeans(X(:,p1)',4, 'Start', 'cluster','MaxIter',200); toc -D=[D D D D]; -dictsize1=4*dictsize; -% X(:,setxor(p1,1:end))=0; -% X_im=col2imstep(X, [256 256], [8 8]); -% figure(10); imshow(X_im); -% if iternum==2 -% D(:,1)=D(:,2); -% end -%p1=p1(p2(1:40000)); -%end-min(40000, end)+1:end));%1:min(40000, end))); -%p1 = randperm(size(data,2));%size(data,2) -%data=data(:,p1); - -C=(100000*thresh)*eye(dictsize1); -% figure(11); -w=zeros(dictsize,1); -replaced=zeros(dictsize,1); -u=zeros(dictsize,1); -% dictimg = showdict(D,[8 8],round(sqrt(size(D,2))),round(sqrt(size(D,2))),'lines','highcontrast'); -% figure(11);imshow(imresize(dictimg,2,'nearest')); -% pause(1); -lambda=0.99986;%3+0.0001*params.linc; -for j=1:1 -if size(p1,2)>60000 - p2 = randperm(size(p1,2)); - - p2=sort(p2(1:60000));%min(floor(size(p1,2)/2),40000))); - size(p2,2) - data=X(:,p1(p2)); -elseif size(p1,2)>0 - p2 = randperm(size(p1,2)); - size(p2,2) - data=X(:,p1); -else - break; -end -% figure(1); -% plot(sqrt(sum(data.^2, 1))); -% a=size(data,2)/4; -% lambda0=0.99;%1-16/numS+iternum*0.0001-0.0002 -%C(1,1)=0; -modi=1000; -for i = 1:size(data,2) -% if norm(data(:,i))>thresh - % par.multA= @(x,par) multMatr(D,x); % user function y=Ax - % par.multAt=@(x,par) multMatrAdj(D,x); % user function y=A'*x - % par.y=data(:,i); - % w=SolveFISTA(D,data(:,i),'lambda',0.5*thresh); - % w=sesoptn(zeros(dictsize,1),par.func_u, par.func_x, par.multA, par.multAt,options,par); - %w = SMALL_chol(D,data(:,i), 256,32, thresh);% - %w = sparsecode(data(:,i), D, [], [], thresh); - w = omp2mex(D(:,((idx(i)-1)*dictsize+1):idx(i)*dictsize),data(:,i),[],[],[],thresh,0,-1,-1,0); - - %w(find(w<1))=0; - %^2; -% lambda(i)=1-0.001/(1+i/a); -% if i=35000) -% modi=100; -% pause -% end; -% end -% end -end -%p1=p1(setxor(p2,1:end)); -%[D,cleared_atoms] = cleardict(D,X,muthresh,p1,replaced); -%replaced=zeros(dictsize,1); -% W=sparsecode(data, D, [], [], thresh); -% data=D*W; -lambda=lambda+0.0002 -end -%Gamma=mexLasso(data, D, param); -%err=compute_err(D,Gamma, data); -%[y,i]=max(err); -%D(:,1)=data(:,i)/norm(data(:,i)); -D=normcols(D); -D_norm=sqrt(sum(D.^2, 1)); -D_norm_1=sum(abs(D)); -% X_norm_1=sum(abs(X)); -% X_norm_inf=max(abs(X)); -[D_norm_sort, p]=sort(D_norm_1, 'descend'); -Dictionary = D;%D(:,p); -% figure(3); -% plot(lambda); -% mean(lambda); -% figure(4+j);plot(w_sp); -end - -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -% sparsecode % -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - -function Gamma = sparsecode(data,D,XtX,G,thresh) - -global CODE_SPARSITY codemode -global MEM_HIGH memusage -global ompfunc ompparams - -if (memusage < MEM_HIGH) - Gamma = ompfunc(D,data,G,thresh,ompparams{:}); - -else % memusage is high - - if (codemode == CODE_SPARSITY) - Gamma = ompfunc(D'*data,G,thresh,ompparams{:}); - - else - Gamma = ompfunc(D, data, G, thresh,ompparams{:}); - end - -end - -end - - -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -% compute_err % -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - - -function err = compute_err(D,Gamma,data) - -global CODE_SPARSITY codemode - -if (codemode == CODE_SPARSITY) - err = sqrt(sum(reperror2(data,D,Gamma))/numel(data)); -else - err = nnz(Gamma)/size(data,2); -end - -end - - - -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -% cleardict % -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - - -function [D,cleared_atoms] = cleardict(D,X,muthresh,unused_sigs,replaced_atoms) - -use_thresh = 4; % at least this number of samples must use the atom to be kept - -dictsize = size(D,2); - -% compute error in blocks to conserve memory -% err = zeros(1,size(X,2)); -% blocks = [1:3000:size(X,2) size(X,2)+1]; -% for i = 1:length(blocks)-1 -% err(blocks(i):blocks(i+1)-1) = sum((X(:,blocks(i):blocks(i+1)-1)-D*Gamma(:,blocks(i):blocks(i+1)-1)).^2); -% end - -cleared_atoms = 0; -usecount = replaced_atoms;%sum(abs(Gamma)>1e-7, 2); - -for j = 1:dictsize - - % compute G(:,j) - Gj = D'*D(:,j); - Gj(j) = 0; - - % replace atom - if ( (max(Gj.^2)>muthresh^2 || usecount(j)