Mercurial > hg > smallbox
changeset 85:fd1c32cda22c
Comments to small_rlsdla.m, removed unfinished work.
author | Ivan <ivan.damnjanovic@eecs.qmul.ac.uk> |
---|---|
date | Tue, 05 Apr 2011 17:03:26 +0100 |
parents | 67aae1283973 |
children | f6cc633fd94b dab78a3598b6 |
files | DL/RLS-DLA/SMALL_rlsdla 05032011.m DL/RLS-DLA/SMALL_rlsdla.m DL/RLS-DLA/SMALL_rlsdla1.m DL/RLS-DLA/SMALL_rlsdlaFirstClustTry.m DL/RLS-DLA/SolveFISTA.m Problems/Cardiac_MRI_problem.m examples/Image Denoising/SMALL_ImgDenoise_DL_test_KSVDvsRLSDLA.m examples/Image Denoising/SMALL_ImgDenoise_dic_ODCT_solvers_OMP_BPDN_etc_test.m util/SMALL_solve.m |
diffstat | 9 files changed, 110 insertions(+), 2041 deletions(-) [+] |
line wrap: on
line diff
--- a/DL/RLS-DLA/SMALL_rlsdla 05032011.m Fri Apr 01 14:27:44 2011 +0100 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,488 +0,0 @@ -function Dictionary = SMALL_rlsdla(X, params) - - - - - -global CODE_SPARSITY CODE_ERROR codemode -global MEM_LOW MEM_NORMAL MEM_HIGH memusage -global ompfunc ompparams exactsvd - -CODE_SPARSITY = 1; -CODE_ERROR = 2; - -MEM_LOW = 1; -MEM_NORMAL = 2; -MEM_HIGH = 3; - - -% p = randperm(size(X,2)); - - % coding mode % -%X_norm=sqrt(sum(X.^2, 1)); -% X_norm_1=sum(abs(X)); -% X_norm_inf=max(abs(X)); -%[X_norm_sort, p]=sort(X_norm);%, 'descend'); -% [X_norm_sort1, p5]=sort(X_norm_1);%, 'descend'); - -% if (isfield(params,'codemode')) -% switch lower(params.codemode) -% case 'sparsity' -% codemode = CODE_SPARSITY; -% thresh = params.Tdata; -% case 'error' -% codemode = CODE_ERROR; -% thresh = params.Edata; -% otherwise -% error('Invalid coding mode specified'); -% end -% elseif (isfield(params,'Tdata')) -% codemode = CODE_SPARSITY; -% thresh = params.Tdata; -% elseif (isfield(params,'Edata')) -% codemode = CODE_ERROR; -% thresh = params.Edata; -% -% else -% error('Data sparse-coding target not specified'); -% end - -thresh = params.Edata; -% max number of atoms % - -% if (codemode==CODE_ERROR && isfield(params,'maxatoms')) -% ompparams{end+1} = 'maxatoms'; -% ompparams{end+1} = params.maxatoms; -% end - - -% memory usage % - -if (isfield(params,'memusage')) - switch lower(params.memusage) - case 'low' - memusage = MEM_LOW; - case 'normal' - memusage = MEM_NORMAL; - case 'high' - memusage = MEM_HIGH; - otherwise - error('Invalid memory usage mode'); - end -else - memusage = MEM_NORMAL; -end - - -% iteration count % - -if (isfield(params,'iternum')) - iternum = params.iternum; -else - iternum = 10; -end - - -% omp function % - -if (codemode == CODE_SPARSITY) - ompfunc = @omp; -else - ompfunc = @omp2; -end - - -% % status messages % -% -% printiter = 0; -% printreplaced = 0; -% printerr = 0; -% printgerr = 0; -% -% verbose = 't'; -% msgdelta = -1; -% - -% -% for i = 1:length(verbose) -% switch lower(verbose(i)) -% case 'i' -% printiter = 1; -% case 'r' -% printiter = 1; -% printreplaced = 1; -% case 't' -% printiter = 1; -% printerr = 1; -% if (isfield(params,'testdata')) -% printgerr = 1; -% end -% end -% end -% -% if (msgdelta<=0 || isempty(verbose)) -% msgdelta = -1; -% end -% -% ompparams{end+1} = 'messages'; -% ompparams{end+1} = msgdelta; -% -% -% -% % compute error flag % -% -% comperr = (nargout>=3 || printerr); -% -% -% % validation flag % -% -% testgen = 0; -% if (isfield(params,'testdata')) -% testdata = params.testdata; -% if (nargout>=4 || printgerr) -% testgen = 1; -% end -% end - -% -% % data norms % -% -% XtX = []; XtXg = []; -% if (codemode==CODE_ERROR && memusage==MEM_HIGH) -% XtX = colnorms_squared(data); -% if (testgen) -% XtXg = colnorms_squared(testdata); -% end -% end - - -% mutual incoherence limit % - -if (isfield(params,'muthresh')) - muthresh = params.muthresh; -else - muthresh = 0.99; -end -if (muthresh < 0) - error('invalid muthresh value, must be non-negative'); -end - - - - - -% determine dictionary size % - -if (isfield(params,'initdict')) - if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) - dictsize = length(params.initdict); - else - dictsize = size(params.initdict,2); - end -end -if (isfield(params,'dictsize')) % this superceedes the size determined by initdict - dictsize = params.dictsize; -end - -if (size(X,2) < dictsize) - error('Number of training signals is smaller than number of atoms to train'); -end - - -% initialize the dictionary % - -if (isfield(params,'initdict')) - if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) - D = X(:,params.initdict(1:dictsize)); - else - if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize) - error('Invalid initial dictionary'); - end - D = params.initdict(:,1:dictsize); - end -else - data_ids = find(colnorms_squared(X) > 1e-6); % ensure no zero data elements are chosen - perm = randperm(length(data_ids)); - D = X(:,data_ids(perm(1:dictsize))); -end - - -% normalize the dictionary % - -% D = normcols(D); -% DtD=D'*D; - -err = zeros(1,iternum); -gerr = zeros(1,iternum); - -if (codemode == CODE_SPARSITY) - errstr = 'RMSE'; -else - errstr = 'mean atomnum'; -end -%X(:,p(X_norm_sort<thresh))=0; -% if (iternum==4) -% X_im=col2imstep(X, [256 256], [8 8]); -% else -% X_im=col2imstep(X, [512 512], [8 8]); -% end -% figure(10); imshow(X_im); - -%p1=p(cumsum(X_norm_sort)./[1:size(X_norm_sort,2)]>thresh); -%p1=p(X_norm_sort>thresh); -% X(:,setxor(p1,1:end))=0; -% X_im=col2imstep(X, [256 256], [8 8]); -% figure(10); imshow(X_im); -% if iternum==2 -% D(:,1)=D(:,2); -% end -%p1=p1(p2(1:40000)); -%end-min(40000, end)+1:end));%1:min(40000, end))); -%p1 = randperm(size(data,2));%size(data,2) -%data=data(:,p1); - -C=(100000*thresh)*eye(dictsize); -% figure(11); -w=zeros(dictsize,1); -replaced=zeros(dictsize,1); -u=zeros(dictsize,1); -% dictimg = showdict(D,[8 8],round(sqrt(size(D,2))),round(sqrt(size(D,2))),'lines','highcontrast'); -% figure(11);imshow(imresize(dictimg,2,'nearest')); -% pause(1); -lambda=0.9997%0.99986;%3+0.0001*params.linc; -for j=1:1 - %data=X; -if size(X,2)>40000 - p2 = randperm(size(X,2)); - - p2=sort(p2(1:40000));%min(floor(size(p1,2)/2),40000))); - size(p2,2) - data=X(:,p2); -elseif size(X,2)>0 - %p2 = randperm(size(p1,2)); - size(X,2) - data=X; -else - break; -end -% figure(1); -% plot(sqrt(sum(data.^2, 1))); -% a=size(data,2)/4; -% lambda0=0.99;%1-16/numS+iternum*0.0001-0.0002 -%C(1,1)=0; -modi=1000; -for i = 1:size(data,2) -% if norm(data(:,i))>thresh - % par.multA= @(x,par) multMatr(D,x); % user function y=Ax - % par.multAt=@(x,par) multMatrAdj(D,x); % user function y=A'*x - % par.y=data(:,i); - % w=SolveFISTA(D,data(:,i),'lambda',0.5*thresh); - % w=sesoptn(zeros(dictsize,1),par.func_u, par.func_x, par.multA, par.multAt,options,par); - %w = SMALL_chol(D,data(:,i), 256,32, thresh);% - %w = sparsecode(data(:,i), D, [], [], thresh); - w = omp2mex(D,data(:,i),[],[],[],thresh,0,-1,-1,0); - - %w(find(w<1))=0; - %^2; -% lambda(i)=1-0.001/(1+i/a); -% if i<a -% lambda(i)=1-0.001*(1-(i/a)); -% else -% lambda(i)=1; -% end -% param.lambda=thresh; -% param.mode=2; -% param.L=32; -% w=mexLasso(data(:,i), D, param); - spind=find(w); - %replaced(spind)=replaced(spind)+1; - %-0.001*(1/2)^(i/a); -% w_sp(i)=nnz(w); - residual = data(:,i) - D * w; - %if ~isempty(spind) - %i - if (j==1) - C = C *(1/ lambda); - end - u = C(:,spind) * w(spind); - - %spindu=find(u); - % v = D' * residual; - - alfa = 1/(1 + w' * u); - - D = D + (alfa * residual) * u'; - - %uut=; - C = C - (alfa * u)* u'; - % lambda=(19*lambda+1)/20; - % DtD = DtD + alfa * ( v*u' + u*v') + alfa^2 * (residual'*residual) * uut; - -% if (mod(i,modi)==0) -% Ximd=zeros(size(X)); -% Ximd(:,p1((i-modi+1:i)))=data(:,i-modi+1:i); -% -% if (iternum==4) -% X_ima(:,:,1)=col2imstep(Ximd, [256 256], [8 8]); -% X_ima(:,:,2)=col2imstep(X, [256 256], [8 8]); -% X_ima(:,:,3)=zeros(256,256); -% else -% X_ima(:,:,1)=col2imstep(Ximd, [512 512], [8 8]); -% X_ima(:,:,2)=col2imstep(X, [512 512], [8 8]); -% X_ima(:,:,3)=zeros(512,512); -% end -% -% dictimg1=dictimg; -% dictimg = showdict(D,[8 8],... -% round(sqrt(size(D,2))),round(sqrt(size(D,2))),'lines','highcontrast'); -% dictimg1=(dictimg-dictimg1); -% -% figure(2); -% subplot(2,2,1); imshow(X_ima); title(sprintf('%d',i)); -% subplot(2,2,3); imshow(imresize(dictimg,2,'nearest')); -% subplot(2,2,4); imshow(imresize(dictimg1,2,'nearest')); -% subplot(2,2,2);imshow(C*(255/max(max(C)))); -% pause(0.02); -% if (i>=35000) -% modi=100; -% pause -% end; -% end -% end -end -%p1=p1(setxor(p2,1:end)); -%[D,cleared_atoms] = cleardict(D,X,muthresh,p1,replaced); -%replaced=zeros(dictsize,1); -% W=sparsecode(data, D, [], [], thresh); -% data=D*W; -%lambda=lambda+0.0002 -end -%Gamma=mexLasso(data, D, param); -%err=compute_err(D,Gamma, data); -%[y,i]=max(err); -%D(:,1)=data(:,i)/norm(data(:,i)); - -Dictionary = D;%D(:,p); -% figure(3); -% plot(lambda); -% mean(lambda); -% figure(4+j);plot(w_sp); -end - -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -% sparsecode % -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - -function Gamma = sparsecode(data,D,XtX,G,thresh) - -global CODE_SPARSITY codemode -global MEM_HIGH memusage -global ompfunc ompparams - -if (memusage < MEM_HIGH) - Gamma = ompfunc(D,data,G,thresh,ompparams{:}); - -else % memusage is high - - if (codemode == CODE_SPARSITY) - Gamma = ompfunc(D'*data,G,thresh,ompparams{:}); - - else - Gamma = ompfunc(D, data, G, thresh,ompparams{:}); - end - -end - -end - - -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -% compute_err % -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - - -function err = compute_err(D,Gamma,data) - -global CODE_SPARSITY codemode - -if (codemode == CODE_SPARSITY) - err = sqrt(sum(reperror2(data,D,Gamma))/numel(data)); -else - err = nnz(Gamma)/size(data,2); -end - -end - - - -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -% cleardict % -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - - -function [D,cleared_atoms] = cleardict(D,X,muthresh,unused_sigs,replaced_atoms) - -use_thresh = 4; % at least this number of samples must use the atom to be kept - -dictsize = size(D,2); - -% compute error in blocks to conserve memory -% err = zeros(1,size(X,2)); -% blocks = [1:3000:size(X,2) size(X,2)+1]; -% for i = 1:length(blocks)-1 -% err(blocks(i):blocks(i+1)-1) = sum((X(:,blocks(i):blocks(i+1)-1)-D*Gamma(:,blocks(i):blocks(i+1)-1)).^2); -% end - -cleared_atoms = 0; -usecount = replaced_atoms;%sum(abs(Gamma)>1e-7, 2); - -for j = 1:dictsize - - % compute G(:,j) - Gj = D'*D(:,j); - Gj(j) = 0; - - % replace atom - if ( (max(Gj.^2)>muthresh^2 || usecount(j)<use_thresh) && ~replaced_atoms(j) ) -% [y,i] = max(err(unused_sigs)); - D(:,j) = X(:,unused_sigs(end)) / norm(X(:,unused_sigs(end))); - unused_sigs = unused_sigs([1:end-1]); - cleared_atoms = cleared_atoms+1; - end -end - -end - - - -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -% misc functions % -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - - -function err2 = reperror2(X,D,Gamma) - -% compute in blocks to conserve memory -err2 = zeros(1,size(X,2)); -blocksize = 2000; -for i = 1:blocksize:size(X,2) - blockids = i : min(i+blocksize-1,size(X,2)); - err2(blockids) = sum((X(:,blockids) - D*Gamma(:,blockids)).^2); -end - -end - - -function Y = colnorms_squared(X) - -% compute in blocks to conserve memory -Y = zeros(1,size(X,2)); -blocksize = 2000; -for i = 1:blocksize:size(X,2) - blockids = i : min(i+blocksize-1,size(X,2)); - Y(blockids) = sum(X(:,blockids).^2); -end - -end - -
--- a/DL/RLS-DLA/SMALL_rlsdla.m Fri Apr 01 14:27:44 2011 +0100 +++ b/DL/RLS-DLA/SMALL_rlsdla.m Tue Apr 05 17:03:26 2011 +0100 @@ -1,9 +1,44 @@ function Dictionary = SMALL_rlsdla(X, params) +%% Recursive Least Squares Dictionary Learning Algorithm +% +% D = SMALL_rlsdla(X, params) - runs RLS-DLA algorithm for +% training signals specified as columns of matrix X with parameters +% specified in params structure returning the learned dictionary D. +% +% Fields in params structure: +% Required: +% 'Tdata' / 'Edata' sparse-coding target +% 'initdict' / 'dictsize' initial dictionary / dictionary size +% +% Optional (default values in parentheses): +% 'codemode' 'sparsity' or 'error' ('sparsity') +% 'maxatoms' max # of atoms in error sparse-coding (none) +% 'forgettingMode' 'fix' - fix forgetting factor, +% other modes are not implemented in +% this version(exponential etc.) +% 'forgettingFactor' for 'fix' mode (default is 1) +% 'show_dict' shows dictionary after # of +% iterations specified (less then 100 +% can make it running slow). In this +% version it assumes that it is image +% dictionary and atoms size is 8x8 +% +% - RLS-DLA - Skretting, K.; Engan, K.; , "Recursive Least Squares +% Dictionary Learning Algorithm," Signal Processing, IEEE Transactions on, +% vol.58, no.4, pp.2121-2130, April 2010 +% - - - +% Centre for Digital Music, Queen Mary, University of London. +% This file copyright 2011 Ivan Damnjanovic. +% +% This program is free software; you can redistribute it and/or +% modify it under the terms of the GNU General Public License as +% published by the Free Software Foundation; either version 2 of the +% License, or (at your option) any later version. See the file +% COPYING included with this distribution for more information. +% +%% CODE_SPARSITY = 1; CODE_ERROR = 2; @@ -45,6 +80,62 @@ end + + +% determine dictionary size % + +if (isfield(params,'initdict')) + if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) + dictsize = length(params.initdict); + else + dictsize = size(params.initdict,2); + end +end +if (isfield(params,'dictsize')) % this superceedes the size determined by initdict + dictsize = params.dictsize; +end + +if (size(X,2) < dictsize) + error('Number of training signals is smaller than number of atoms to train'); +end + + +% initialize the dictionary % + +if (isfield(params,'initdict')) + if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) + D = X(:,params.initdict(1:dictsize)); + else + if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize) + error('Invalid initial dictionary'); + end + D = params.initdict(:,1:dictsize); + end +else + data_ids = find(colnorms_squared(X) > 1e-6); % ensure no zero data elements are chosen + perm = randperm(length(data_ids)); + D = X(:,data_ids(perm(1:dictsize))); +end + + +% normalize the dictionary % + +D = normcols(D); + +% show dictonary every specified number of iterations + +if (isfield(params,'show_dict')) + show_dictionary=1; + show_iter=params.show_dict; +else + show_dictionary=0; + show_iter=0; +end + +if (show_dictionary) + dictimg = showdict(D,[8 8],round(sqrt(size(D,2))),round(sqrt(size(D,2))),'lines','highcontrast'); + figure(2); imshow(imresize(dictimg,2,'nearest')); +end % Forgetting factor if (isfield(params,'forgettingMode')) @@ -64,50 +155,11 @@ lambda=1; end -% determine dictionary size % - -if (isfield(params,'initdict')) - if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) - dictsize = length(params.initdict); - else - dictsize = size(params.initdict,2); - end -end -if (isfield(params,'dictsize')) % this superceedes the size determined by initdict - dictsize = params.dictsize; -end - -if (size(X,2) < dictsize) - error('Number of training signals is smaller than number of atoms to train'); -end - - -% initialize the dictionary % - -if (isfield(params,'initdict')) - if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) - D = X(:,params.initdict(1:dictsize)); - else - if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize) - error('Invalid initial dictionary'); - end - D = params.initdict(:,1:dictsize); - end -else - data_ids = find(colnorms_squared(X) > 1e-6); % ensure no zero data elements are chosen - perm = randperm(length(data_ids)); - D = X(:,data_ids(perm(1:dictsize))); -end - - -% normalize the dictionary % - -D = normcols(D); - % Training data data=X; cnt=size(data,2); + % C=(100000*thresh)*eye(dictsize); @@ -118,7 +170,7 @@ for i = 1:cnt if (codemode == CODE_SPARSITY) - w = ompmex(D,data(:,i),[],thresh,'checkdict','off'); + w = omp2(D,data(:,i),[],thresh,'checkdict','off'); else w = omp2(D,data(:,i),[],thresh,'maxatoms',maxatoms, 'checkdict','off'); end @@ -140,7 +192,12 @@ C = C - (alfa * u)* u'; - + if (show_dictionary &&(mod(i,show_iter)==0)) + dictimg = showdict(D,[8 8],... + round(sqrt(size(D,2))),round(sqrt(size(D,2))),'lines','highcontrast'); + figure(2); imshow(imresize(dictimg,2,'nearest')); + pause(0.02); + end end Dictionary = D;
--- a/DL/RLS-DLA/SMALL_rlsdla1.m Fri Apr 01 14:27:44 2011 +0100 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,479 +0,0 @@ -function Dictionary = SMALL_rlsdla1(X, params) - - - - - -global CODE_SPARSITY CODE_ERROR codemode -global MEM_LOW MEM_NORMAL MEM_HIGH memusage -global ompfunc ompparams exactsvd - -CODE_SPARSITY = 1; -CODE_ERROR = 2; - -MEM_LOW = 1; -MEM_NORMAL = 2; -MEM_HIGH = 3; - - -% p = randperm(size(X,2)); - - % coding mode % -X_norm=sqrt(sum(X.^2, 1)); -% X_norm_1=sum(abs(X)); -% X_norm_inf=max(abs(X)); -[X_norm_sort, p]=sort(X_norm);%, 'descend'); -% [X_norm_sort1, p5]=sort(X_norm_1);%, 'descend'); - -% if (isfield(params,'codemode')) -% switch lower(params.codemode) -% case 'sparsity' -% codemode = CODE_SPARSITY; -% thresh = params.Tdata; -% case 'error' -% codemode = CODE_ERROR; -% thresh = params.Edata; -% otherwise -% error('Invalid coding mode specified'); -% end -% elseif (isfield(params,'Tdata')) -% codemode = CODE_SPARSITY; -% thresh = params.Tdata; -% elseif (isfield(params,'Edata')) -% codemode = CODE_ERROR; -% thresh = params.Edata; -% -% else -% error('Data sparse-coding target not specified'); -% end - -thresh = params.Edata; -% max number of atoms % - -% if (codemode==CODE_ERROR && isfield(params,'maxatoms')) -% ompparams{end+1} = 'maxatoms'; -% ompparams{end+1} = params.maxatoms; -% end - - -% memory usage % - -if (isfield(params,'memusage')) - switch lower(params.memusage) - case 'low' - memusage = MEM_LOW; - case 'normal' - memusage = MEM_NORMAL; - case 'high' - memusage = MEM_HIGH; - otherwise - error('Invalid memory usage mode'); - end -else - memusage = MEM_NORMAL; -end - - -% iteration count % - -if (isfield(params,'iternum')) - iternum = params.iternum; -else - iternum = 10; -end - - -% omp function % - -if (codemode == CODE_SPARSITY) - ompfunc = @omp; -else - ompfunc = @omp2; -end - - -% % status messages % -% -% printiter = 0; -% printreplaced = 0; -% printerr = 0; -% printgerr = 0; -% -% verbose = 't'; -% msgdelta = -1; -% - -% -% for i = 1:length(verbose) -% switch lower(verbose(i)) -% case 'i' -% printiter = 1; -% case 'r' -% printiter = 1; -% printreplaced = 1; -% case 't' -% printiter = 1; -% printerr = 1; -% if (isfield(params,'testdata')) -% printgerr = 1; -% end -% end -% end -% -% if (msgdelta<=0 || isempty(verbose)) -% msgdelta = -1; -% end -% -% ompparams{end+1} = 'messages'; -% ompparams{end+1} = msgdelta; -% -% -% -% % compute error flag % -% -% comperr = (nargout>=3 || printerr); -% -% -% % validation flag % -% -% testgen = 0; -% if (isfield(params,'testdata')) -% testdata = params.testdata; -% if (nargout>=4 || printgerr) -% testgen = 1; -% end -% end - -% -% % data norms % -% -% XtX = []; XtXg = []; -% if (codemode==CODE_ERROR && memusage==MEM_HIGH) -% XtX = colnorms_squared(data); -% if (testgen) -% XtXg = colnorms_squared(testdata); -% end -% end - - -% mutual incoherence limit % - -if (isfield(params,'muthresh')) - muthresh = params.muthresh; -else - muthresh = 0.99; -end -if (muthresh < 0) - error('invalid muthresh value, must be non-negative'); -end - - - - - -% determine dictionary size % - -if (isfield(params,'initdict')) - if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) - dictsize = length(params.initdict); - else - dictsize = size(params.initdict,2); - end -end -if (isfield(params,'dictsize')) % this superceedes the size determined by initdict - dictsize = params.dictsize; -end - -if (size(X,2) < dictsize) - error('Number of training signals is smaller than number of atoms to train'); -end - - -% initialize the dictionary % - -if (isfield(params,'initdict')) - if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) - D1 = X(:,params.initdict(1:dictsize)); - D2 = X(:,params.initdict(1:dictsize)); - else - if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize) - error('Invalid initial dictionary'); - end - D1 = params.initdict(:,1:dictsize); - D2 = params.initdict(:,1:dictsize); - end -else - data_ids = find(colnorms_squared(X) > 1e-6); % ensure no zero data elements are chosen - perm = randperm(length(data_ids)); - D = X(:,data_ids(perm(1:dictsize))); -end - - -% normalize the dictionary % - -% D = normcols(D); -% DtD=D'*D; - -err = zeros(1,iternum); -gerr = zeros(1,iternum); - -if (codemode == CODE_SPARSITY) - errstr = 'RMSE'; -else - errstr = 'mean atomnum'; -end -X(:,p(X_norm_sort<thresh))=0; -% if (iternum==4) -% X_im=col2imstep(X, [256 256], [8 8]); -% else -% X_im=col2imstep(X, [512 512], [8 8]); -% end -% figure(10); imshow(X_im); -p1=p(X_norm_sort>thresh); - -%p1=p1(p2(1:40000)); -%end-min(40000, end)+1:end));%1:min(40000, end))); -%p1 = randperm(size(data,2));%size(data,2) -%data=data(:,p1); - -C1=(100000*thresh)*eye(dictsize); -C2=(100000*thresh)*eye(dictsize); -% figure(11); -w=zeros(dictsize,1); -replaced=zeros(dictsize,1); -u=zeros(dictsize,1); - dictimg = showdict(D1,[8 8],round(sqrt(size(D1,2))),round(sqrt(size(D1,2))),'lines','highcontrast'); -% h=imshow(imresize(dictimg,2,'nearest')); -lambda=0.9998 -for j=1:3 -if size(p1,2)>20000 - p2 = randperm(floor(size(p1,2)/2)); - p2=sort(p2(1:20000)); - data1=X(:,p1(p2)); - data2=X(:,p1(floor(size(p1,2)/2)+p2)); -elseif size(p1,2)>0 - data=X(:,p1); -else - break; -end -% figure(1); -% plot(sqrt(sum(data.^2, 1))); -% a=size(data,2)/4; -% lambda0=0.99;%1-16/numS+iternum*0.0001-0.0002 -C1(1,1)=0; -C2(1,1)=0; -for i = 1:size(data1,2) -% if norm(data(:,i))>thresh - % par.multA= @(x,par) multMatr(D,x); % user function y=Ax - % par.multAt=@(x,par) multMatrAdj(D,x); % user function y=A'*x - % par.y=data(:,i); - % w=SolveFISTA(D,data(:,i),'lambda',0.5*thresh); - % w=sesoptn(zeros(dictsize,1),par.func_u, par.func_x, par.multA, par.multAt,options,par); - %w = SMALL_chol(D,data(:,i), 256,32, thresh);% - %w = sparsecode(data(:,i), D, [], [], thresh); - w1 = omp2mex(D1,data1(:,i),[],[],[],thresh,0,-1,-1,0); - w2 = omp2mex(D2,data2(:,i),[],[],[],thresh,0,-1,-1,0); - %w(find(w<1))=0; - %^2; -% lambda(i)=1-0.001/(1+i/a); -% if i<a -% lambda(i)=1-0.001*(1-(i/a)); -% else -% lambda(i)=1; -% end -% param.lambda=thresh; -% param.mode=2; -% param.L=32; -% w=mexLasso(data(:,i), D, param); - spind1=find(w1); - spind2=find(w2); - - %replaced(spind)=replaced(spind)+1; - %-0.001*(1/2)^(i/a); -% w_sp(i)=nnz(w); - residual1 = data1(:,i) - D1 * w1; - residual2 = data2(:,i) - D2 * w2; - %if ~isempty(spind) - %i - - C1 = C1 *(1/ lambda); - C2 = C2 *(1/ lambda); - u1 = C1(:,spind1) * w1(spind1); - u2 = C2(:,spind2) * w2(spind2); - %spindu=find(u); - % v = D' * residual; - - alfa1 = 1/(1 + w1' * u1); - alfa2 = 1/(1 + w2' * u2); - D1 = D1 + (alfa1 * residual1) * u1'; - D2 = D2 + (alfa2 * residual2) * u2'; - %uut=; - C1 = C1 - (alfa1 * u1)* u1'; - C2 = C2 - (alfa2 * u2)* u2'; - % lambda=(19*lambda+1)/20; - % DtD = DtD + alfa * ( v*u' + u*v') + alfa^2 * (residual'*residual) * uut; -% modi=5000; -% if (mod(i,modi)==0) -% Ximd=zeros(size(X)); -% Ximd(:,p((i-modi+1:i)))=data(:,i-modi+1:i); -% -% if (iternum==4) -% X_ima=col2imstep(Ximd, [256 256], [8 8]); -% else -% X_ima=col2imstep(Ximd, [512 512], [8 8]); -% end -% dictimg1=dictimg; -% dictimg = showdict(D,[8 8],... -% round(sqrt(size(D,2))),round(sqrt(size(D,2))),'lines','highcontrast'); -% dictimg1=(dictimg-dictimg1)*255; -% -% figure(2); -% subplot(2,2,1); imshow(X_ima); -% subplot(2,2,3); imshow(imresize(dictimg,2,'nearest')); -% subplot(2,2,4); imshow(imresize(dictimg1,2,'nearest')); -% subplot(2,2,2);imshow(C*(255/max(max(C)))); -% pause(0.02); -% end -% end -end -%p1=p1(setxor(p2,1:end)); -%[D,cleared_atoms] = cleardict(D,X,muthresh,p1,replaced); -%replaced=zeros(dictsize,1); -% W=sparsecode(data, D, [], [], thresh); -% data=D*W; -lambda=lambda+0.0001 -end -%Gamma=mexLasso(data, D, param); -%err=compute_err(D,Gamma, data); -%[y,i]=max(err); -%D(:,1)=data(:,i)/norm(data(:,i)); -% D=normcols(D); -% D_norm=sqrt(sum(D.^2, 1)); -% D_norm_1=sum(abs(D)); -% X_norm_1=sum(abs(X)); -% X_norm_inf=max(abs(X)); -% [D_norm_sort, p]=sort(D_norm_1, 'descend'); -Dictionary =[D1 D2]; -% figure(3); -% plot(lambda); -% mean(lambda); -% figure(4+j);plot(w_sp); -end - -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -% sparsecode % -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - -function Gamma = sparsecode(data,D,XtX,G,thresh) - -global CODE_SPARSITY codemode -global MEM_HIGH memusage -global ompfunc ompparams - -if (memusage < MEM_HIGH) - Gamma = ompfunc(D,data,G,thresh,ompparams{:}); - -else % memusage is high - - if (codemode == CODE_SPARSITY) - Gamma = ompfunc(D'*data,G,thresh,ompparams{:}); - - else - Gamma = ompfunc(D, data, G, thresh,ompparams{:}); - end - -end - -end - - -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -% compute_err % -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - - -function err = compute_err(D,Gamma,data) - -global CODE_SPARSITY codemode - -if (codemode == CODE_SPARSITY) - err = sqrt(sum(reperror2(data,D,Gamma))/numel(data)); -else - err = nnz(Gamma)/size(data,2); -end - -end - - - -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -% cleardict % -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - - -function [D,cleared_atoms] = cleardict(D,X,muthresh,unused_sigs,replaced_atoms) - -use_thresh = 4; % at least this number of samples must use the atom to be kept - -dictsize = size(D,2); - -% compute error in blocks to conserve memory -% err = zeros(1,size(X,2)); -% blocks = [1:3000:size(X,2) size(X,2)+1]; -% for i = 1:length(blocks)-1 -% err(blocks(i):blocks(i+1)-1) = sum((X(:,blocks(i):blocks(i+1)-1)-D*Gamma(:,blocks(i):blocks(i+1)-1)).^2); -% end - -cleared_atoms = 0; -usecount = replaced_atoms;%sum(abs(Gamma)>1e-7, 2); - -for j = 1:dictsize - - % compute G(:,j) - Gj = D'*D(:,j); - Gj(j) = 0; - - % replace atom - if ( (max(Gj.^2)>muthresh^2 || usecount(j)<use_thresh) && ~replaced_atoms(j) ) -% [y,i] = max(err(unused_sigs)); - D(:,j) = X(:,unused_sigs(end)) / norm(X(:,unused_sigs(end))); - unused_sigs = unused_sigs([1:end-1]); - cleared_atoms = cleared_atoms+1; - end -end - -end - - - -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -% misc functions % -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - - -function err2 = reperror2(X,D,Gamma) - -% compute in blocks to conserve memory -err2 = zeros(1,size(X,2)); -blocksize = 2000; -for i = 1:blocksize:size(X,2) - blockids = i : min(i+blocksize-1,size(X,2)); - err2(blockids) = sum((X(:,blockids) - D*Gamma(:,blockids)).^2); -end - -end - - -function Y = colnorms_squared(X) - -% compute in blocks to conserve memory -Y = zeros(1,size(X,2)); -blocksize = 2000; -for i = 1:blocksize:size(X,2) - blockids = i : min(i+blocksize-1,size(X,2)); - Y(blockids) = sum(X(:,blockids).^2); -end - -end - -
--- a/DL/RLS-DLA/SMALL_rlsdlaFirstClustTry.m Fri Apr 01 14:27:44 2011 +0100 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,494 +0,0 @@ -function Dictionary = SMALL_rlsdla(X, params) - - - - - -global CODE_SPARSITY CODE_ERROR codemode -global MEM_LOW MEM_NORMAL MEM_HIGH memusage -global ompfunc ompparams exactsvd - -CODE_SPARSITY = 1; -CODE_ERROR = 2; - -MEM_LOW = 1; -MEM_NORMAL = 2; -MEM_HIGH = 3; - - -% p = randperm(size(X,2)); - - % coding mode % -X_norm=sqrt(sum(X.^2, 1)); -% X_norm_1=sum(abs(X)); -% X_norm_inf=max(abs(X)); -[X_norm_sort, p]=sort(X_norm);%, 'descend'); -% [X_norm_sort1, p5]=sort(X_norm_1);%, 'descend'); - -% if (isfield(params,'codemode')) -% switch lower(params.codemode) -% case 'sparsity' -% codemode = CODE_SPARSITY; -% thresh = params.Tdata; -% case 'error' -% codemode = CODE_ERROR; -% thresh = params.Edata; -% otherwise -% error('Invalid coding mode specified'); -% end -% elseif (isfield(params,'Tdata')) -% codemode = CODE_SPARSITY; -% thresh = params.Tdata; -% elseif (isfield(params,'Edata')) -% codemode = CODE_ERROR; -% thresh = params.Edata; -% -% else -% error('Data sparse-coding target not specified'); -% end - -thresh = params.Edata; -% max number of atoms % - -% if (codemode==CODE_ERROR && isfield(params,'maxatoms')) -% ompparams{end+1} = 'maxatoms'; -% ompparams{end+1} = params.maxatoms; -% end - - -% memory usage % - -if (isfield(params,'memusage')) - switch lower(params.memusage) - case 'low' - memusage = MEM_LOW; - case 'normal' - memusage = MEM_NORMAL; - case 'high' - memusage = MEM_HIGH; - otherwise - error('Invalid memory usage mode'); - end -else - memusage = MEM_NORMAL; -end - - -% iteration count % - -if (isfield(params,'iternum')) - iternum = params.iternum; -else - iternum = 10; -end - - -% omp function % - -if (codemode == CODE_SPARSITY) - ompfunc = @omp; -else - ompfunc = @omp2; -end - - -% % status messages % -% -% printiter = 0; -% printreplaced = 0; -% printerr = 0; -% printgerr = 0; -% -% verbose = 't'; -% msgdelta = -1; -% - -% -% for i = 1:length(verbose) -% switch lower(verbose(i)) -% case 'i' -% printiter = 1; -% case 'r' -% printiter = 1; -% printreplaced = 1; -% case 't' -% printiter = 1; -% printerr = 1; -% if (isfield(params,'testdata')) -% printgerr = 1; -% end -% end -% end -% -% if (msgdelta<=0 || isempty(verbose)) -% msgdelta = -1; -% end -% -% ompparams{end+1} = 'messages'; -% ompparams{end+1} = msgdelta; -% -% -% -% % compute error flag % -% -% comperr = (nargout>=3 || printerr); -% -% -% % validation flag % -% -% testgen = 0; -% if (isfield(params,'testdata')) -% testdata = params.testdata; -% if (nargout>=4 || printgerr) -% testgen = 1; -% end -% end - -% -% % data norms % -% -% XtX = []; XtXg = []; -% if (codemode==CODE_ERROR && memusage==MEM_HIGH) -% XtX = colnorms_squared(data); -% if (testgen) -% XtXg = colnorms_squared(testdata); -% end -% end - - -% mutual incoherence limit % - -if (isfield(params,'muthresh')) - muthresh = params.muthresh; -else - muthresh = 0.99; -end -if (muthresh < 0) - error('invalid muthresh value, must be non-negative'); -end - - - - - -% determine dictionary size % - -if (isfield(params,'initdict')) - if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) - dictsize = length(params.initdict); - else - dictsize = size(params.initdict,2); - end -end -if (isfield(params,'dictsize')) % this superceedes the size determined by initdict - dictsize = params.dictsize; -end - -if (size(X,2) < dictsize) - error('Number of training signals is smaller than number of atoms to train'); -end - - -% initialize the dictionary % - -if (isfield(params,'initdict')) - if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) - D = X(:,params.initdict(1:dictsize)); - else - if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize) - error('Invalid initial dictionary'); - end - D = params.initdict(:,1:dictsize); - end -else - data_ids = find(colnorms_squared(X) > 1e-6); % ensure no zero data elements are chosen - perm = randperm(length(data_ids)); - D = X(:,data_ids(perm(1:dictsize))); -end - -% normalize the dictionary % - -% D = normcols(D); -% DtD=D'*D; - -err = zeros(1,iternum); -gerr = zeros(1,iternum); - -if (codemode == CODE_SPARSITY) - errstr = 'RMSE'; -else - errstr = 'mean atomnum'; -end -%X(:,p(X_norm_sort<thresh))=0; -% if (iternum==4) -% X_im=col2imstep(X, [256 256], [8 8]); -% else -% X_im=col2imstep(X, [512 512], [8 8]); -% end -% figure(10); imshow(X_im); - -%p1=p(cumsum(X_norm_sort)./[1:size(X_norm_sort,2)]>thresh); -p1=p(X_norm_sort>thresh); -tic; idx=kmeans(X(:,p1)',4, 'Start', 'cluster','MaxIter',200); toc -D=[D D D D]; -dictsize1=4*dictsize; -% X(:,setxor(p1,1:end))=0; -% X_im=col2imstep(X, [256 256], [8 8]); -% figure(10); imshow(X_im); -% if iternum==2 -% D(:,1)=D(:,2); -% end -%p1=p1(p2(1:40000)); -%end-min(40000, end)+1:end));%1:min(40000, end))); -%p1 = randperm(size(data,2));%size(data,2) -%data=data(:,p1); - -C=(100000*thresh)*eye(dictsize1); -% figure(11); -w=zeros(dictsize,1); -replaced=zeros(dictsize,1); -u=zeros(dictsize,1); -% dictimg = showdict(D,[8 8],round(sqrt(size(D,2))),round(sqrt(size(D,2))),'lines','highcontrast'); -% figure(11);imshow(imresize(dictimg,2,'nearest')); -% pause(1); -lambda=0.99986;%3+0.0001*params.linc; -for j=1:1 -if size(p1,2)>60000 - p2 = randperm(size(p1,2)); - - p2=sort(p2(1:60000));%min(floor(size(p1,2)/2),40000))); - size(p2,2) - data=X(:,p1(p2)); -elseif size(p1,2)>0 - p2 = randperm(size(p1,2)); - size(p2,2) - data=X(:,p1); -else - break; -end -% figure(1); -% plot(sqrt(sum(data.^2, 1))); -% a=size(data,2)/4; -% lambda0=0.99;%1-16/numS+iternum*0.0001-0.0002 -%C(1,1)=0; -modi=1000; -for i = 1:size(data,2) -% if norm(data(:,i))>thresh - % par.multA= @(x,par) multMatr(D,x); % user function y=Ax - % par.multAt=@(x,par) multMatrAdj(D,x); % user function y=A'*x - % par.y=data(:,i); - % w=SolveFISTA(D,data(:,i),'lambda',0.5*thresh); - % w=sesoptn(zeros(dictsize,1),par.func_u, par.func_x, par.multA, par.multAt,options,par); - %w = SMALL_chol(D,data(:,i), 256,32, thresh);% - %w = sparsecode(data(:,i), D, [], [], thresh); - w = omp2mex(D(:,((idx(i)-1)*dictsize+1):idx(i)*dictsize),data(:,i),[],[],[],thresh,0,-1,-1,0); - - %w(find(w<1))=0; - %^2; -% lambda(i)=1-0.001/(1+i/a); -% if i<a -% lambda(i)=1-0.001*(1-(i/a)); -% else -% lambda(i)=1; -% end -% param.lambda=thresh; -% param.mode=2; -% param.L=32; -% w=mexLasso(data(:,i), D, param); - spind=find(w); - %replaced(spind)=replaced(spind)+1; - %-0.001*(1/2)^(i/a); -% w_sp(i)=nnz(w); - residual = data(:,i) - D (:,((idx(i)-1)*dictsize+1):idx(i)*dictsize)* w; - %if ~isempty(spind) - %i - if (j==1) - C = C *(1/ lambda); - end - u = C(((idx(i)-1)*dictsize+1):idx(i)*dictsize,((idx(i)-1)*dictsize)+spind) * w(spind); - - %spindu=find(u); - % v = D' * residual; - - alfa = 1/(1 + w' * u); - - D(:,((idx(i)-1)*dictsize+1):idx(i)*dictsize) = D (:,((idx(i)-1)*dictsize+1):idx(i)*dictsize)+ (alfa * residual) * u'; - - %uut=; - C (((idx(i)-1)*dictsize+1):idx(i)*dictsize,((idx(i)-1)*dictsize+1):idx(i)*dictsize)= C(((idx(i)-1)*dictsize+1):idx(i)*dictsize,((idx(i)-1)*dictsize+1):idx(i)*dictsize) - (alfa * u)* u'; - % lambda=(19*lambda+1)/20; - % DtD = DtD + alfa * ( v*u' + u*v') + alfa^2 * (residual'*residual) * uut; - -% if (mod(i,modi)==0) -% Ximd=zeros(size(X)); -% Ximd(:,p1((i-modi+1:i)))=data(:,i-modi+1:i); -% -% if (iternum==4) -% X_ima(:,:,1)=col2imstep(Ximd, [256 256], [8 8]); -% X_ima(:,:,2)=col2imstep(X, [256 256], [8 8]); -% X_ima(:,:,3)=zeros(256,256); -% else -% X_ima(:,:,1)=col2imstep(Ximd, [512 512], [8 8]); -% X_ima(:,:,2)=col2imstep(X, [512 512], [8 8]); -% X_ima(:,:,3)=zeros(512,512); -% end -% -% dictimg1=dictimg; -% dictimg = showdict(D,[8 8],... -% round(sqrt(size(D,2))),round(sqrt(size(D,2))),'lines','highcontrast'); -% dictimg1=(dictimg-dictimg1); -% -% figure(2); -% subplot(2,2,1); imshow(X_ima); title(sprintf('%d',i)); -% subplot(2,2,3); imshow(imresize(dictimg,2,'nearest')); -% subplot(2,2,4); imshow(imresize(dictimg1,2,'nearest')); -% subplot(2,2,2);imshow(C*(255/max(max(C)))); -% pause(0.02); -% if (i>=35000) -% modi=100; -% pause -% end; -% end -% end -end -%p1=p1(setxor(p2,1:end)); -%[D,cleared_atoms] = cleardict(D,X,muthresh,p1,replaced); -%replaced=zeros(dictsize,1); -% W=sparsecode(data, D, [], [], thresh); -% data=D*W; -lambda=lambda+0.0002 -end -%Gamma=mexLasso(data, D, param); -%err=compute_err(D,Gamma, data); -%[y,i]=max(err); -%D(:,1)=data(:,i)/norm(data(:,i)); -D=normcols(D); -D_norm=sqrt(sum(D.^2, 1)); -D_norm_1=sum(abs(D)); -% X_norm_1=sum(abs(X)); -% X_norm_inf=max(abs(X)); -[D_norm_sort, p]=sort(D_norm_1, 'descend'); -Dictionary = D;%D(:,p); -% figure(3); -% plot(lambda); -% mean(lambda); -% figure(4+j);plot(w_sp); -end - -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -% sparsecode % -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - -function Gamma = sparsecode(data,D,XtX,G,thresh) - -global CODE_SPARSITY codemode -global MEM_HIGH memusage -global ompfunc ompparams - -if (memusage < MEM_HIGH) - Gamma = ompfunc(D,data,G,thresh,ompparams{:}); - -else % memusage is high - - if (codemode == CODE_SPARSITY) - Gamma = ompfunc(D'*data,G,thresh,ompparams{:}); - - else - Gamma = ompfunc(D, data, G, thresh,ompparams{:}); - end - -end - -end - - -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -% compute_err % -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - - -function err = compute_err(D,Gamma,data) - -global CODE_SPARSITY codemode - -if (codemode == CODE_SPARSITY) - err = sqrt(sum(reperror2(data,D,Gamma))/numel(data)); -else - err = nnz(Gamma)/size(data,2); -end - -end - - - -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -% cleardict % -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - - -function [D,cleared_atoms] = cleardict(D,X,muthresh,unused_sigs,replaced_atoms) - -use_thresh = 4; % at least this number of samples must use the atom to be kept - -dictsize = size(D,2); - -% compute error in blocks to conserve memory -% err = zeros(1,size(X,2)); -% blocks = [1:3000:size(X,2) size(X,2)+1]; -% for i = 1:length(blocks)-1 -% err(blocks(i):blocks(i+1)-1) = sum((X(:,blocks(i):blocks(i+1)-1)-D*Gamma(:,blocks(i):blocks(i+1)-1)).^2); -% end - -cleared_atoms = 0; -usecount = replaced_atoms;%sum(abs(Gamma)>1e-7, 2); - -for j = 1:dictsize - - % compute G(:,j) - Gj = D'*D(:,j); - Gj(j) = 0; - - % replace atom - if ( (max(Gj.^2)>muthresh^2 || usecount(j)<use_thresh) && ~replaced_atoms(j) ) -% [y,i] = max(err(unused_sigs)); - D(:,j) = X(:,unused_sigs(end)) / norm(X(:,unused_sigs(end))); - unused_sigs = unused_sigs([1:end-1]); - cleared_atoms = cleared_atoms+1; - end -end - -end - - - -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -% misc functions % -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - - -function err2 = reperror2(X,D,Gamma) - -% compute in blocks to conserve memory -err2 = zeros(1,size(X,2)); -blocksize = 2000; -for i = 1:blocksize:size(X,2) - blockids = i : min(i+blocksize-1,size(X,2)); - err2(blockids) = sum((X(:,blockids) - D*Gamma(:,blockids)).^2); -end - -end - - -function Y = colnorms_squared(X) - -% compute in blocks to conserve memory -Y = zeros(1,size(X,2)); -blocksize = 2000; -for i = 1:blocksize:size(X,2) - blockids = i : min(i+blocksize-1,size(X,2)); - Y(blockids) = sum(X(:,blockids).^2); -end - -end - -
--- a/DL/RLS-DLA/SolveFISTA.m Fri Apr 01 14:27:44 2011 +0100 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,211 +0,0 @@ -% Copyright ©2010. The Regents of the University of California (Regents). -% All Rights Reserved. Contact The Office of Technology Licensing, -% UC Berkeley, 2150 Shattuck Avenue, Suite 510, Berkeley, CA 94720-1620, -% (510) 643-7201, for commercial licensing opportunities. - -% Authors: Arvind Ganesh, Allen Y. Yang, Zihan Zhou. -% Contact: Allen Y. Yang, Department of EECS, University of California, -% Berkeley. <yang@eecs.berkeley.edu> - -% IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, -% SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, -% ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF -% REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -% REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED -% TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A -% PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF ANY, -% PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION TO -% PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. - -%% This function is modified from Matlab code proximal_gradient_bp - -function [x_hat,nIter] = SolveFISTA(A,b, varargin) - -% b - m x 1 vector of observations/data (required input) -% A - m x n measurement matrix (required input) -% -% tol - tolerance for stopping criterion. -% - DEFAULT 1e-7 if omitted or -1. -% maxIter - maxilambdam number of iterations -% - DEFAULT 10000, if omitted or -1. -% lineSearchFlag - 1 if line search is to be done every iteration -% - DEFAULT 0, if omitted or -1. -% continuationFlag - 1 if a continuation is to be done on the parameter lambda -% - DEFAULT 1, if omitted or -1. -% eta - line search parameter, should be in (0,1) -% - ignored if lineSearchFlag is 0. -% - DEFAULT 0.9, if omitted or -1. -% lambda - relaxation parameter -% - ignored if continuationFlag is 1. -% - DEFAULT 1e-3, if omitted or -1. -% outputFileName - Details of each iteration are dumped here, if provided. -% -% x_hat - estimate of coeeficient vector -% numIter - number of iterations until convergence -% -% -% References -% "Robust PCA: Exact Recovery of Corrupted Low-Rank Matrices via Convex Optimization", J. Wright et al., preprint 2009. -% "An Accelerated Proximal Gradient Algorithm for Nuclear Norm Regularized Least Squares problems", K.-C. Toh and S. Yun, preprint 2009. -% -% Arvind Ganesh, Summer 2009. Questions? abalasu2@illinois.edu - -DEBUG = 0 ; - -STOPPING_GROUND_TRUTH = -1; -STOPPING_DUALITY_GAP = 1; -STOPPING_SPARSE_SUPPORT = 2; -STOPPING_OBJECTIVE_VALUE = 3; -STOPPING_SUBGRADIENT = 4; -STOPPING_DEFAULT = STOPPING_SUBGRADIENT; - -stoppingCriterion = STOPPING_DEFAULT; -maxIter = 1000 ; -tolerance = 1e-3; -[m,n] = size(A) ; -x0 = zeros(n,1) ; -xG = []; - -%% Initializing optimization variables -t_k = 1 ; -t_km1 = 1 ; -L0 = 1 ; -G = A'*A ; -nIter = 0 ; -c = A'*b ; -lambda0 = 0.99*L0*norm(c,inf) ; -eta = 0.6 ; -lambda_bar = 1e-4*lambda0 ; -xk = zeros(n,1) ; -lambda = lambda0 ; -L = L0 ; -beta = 1.5 ; - -% Parse the optional inputs. -if (mod(length(varargin), 2) ~= 0 ), - error(['Extra Parameters passed to the function ''' mfilename ''' lambdast be passed in pairs.']); -end -parameterCount = length(varargin)/2; - -for parameterIndex = 1:parameterCount, - parameterName = varargin{parameterIndex*2 - 1}; - parameterValue = varargin{parameterIndex*2}; - switch lower(parameterName) - case 'stoppingcriterion' - stoppingCriterion = parameterValue; - case 'groundtruth' - xG = parameterValue; - case 'tolerance' - tolerance = parameterValue; - case 'linesearchflag' - lineSearchFlag = parameterValue; - case 'lambda' - lambda_bar = parameterValue; - case 'maxiteration' - maxIter = parameterValue; - case 'isnonnegative' - isNonnegative = parameterValue; - case 'continuationflag' - continuationFlag = parameterValue; - case 'initialization' - xk = parameterValue; - if ~all(size(xk)==[n,1]) - error('The dimension of the initial xk does not match.'); - end - case 'eta' - eta = parameterValue; - if ( eta <= 0 || eta >= 1 ) - disp('Line search parameter out of bounds, switching to default 0.9') ; - eta = 0.9 ; - end - otherwise - error(['The parameter ''' parameterName ''' is not recognized by the function ''' mfilename '''.']); - end -end -clear varargin - -if stoppingCriterion==STOPPING_GROUND_TRUTH && isempty(xG) - error('The stopping criterion must provide the ground truth value of x.'); -end - -keep_going = 1 ; -nz_x = (abs(xk)> eps*10); -f = 0.5*norm(b-A*xk)^2 + lambda_bar * norm(xk,1); -xkm1 = xk; -while keep_going && (nIter < maxIter) - nIter = nIter + 1 ; - - yk = xk + ((t_km1-1)/t_k)*(xk-xkm1) ; - - stop_backtrack = 0 ; - - temp = G*yk - c ; % gradient of f at yk - - while ~stop_backtrack - - gk = yk - (1/L)*temp ; - - xkp1 = soft(gk,lambda/L) ; - - temp1 = 0.5*norm(b-A*xkp1)^2 ; - temp2 = 0.5*norm(b-A*yk)^2 + (xkp1-yk)'*temp + (L/2)*norm(xkp1-yk)^2 ; - - if temp1 <= temp2 - stop_backtrack = 1 ; - else - L = L*beta ; - end - - end - - switch stoppingCriterion - case STOPPING_GROUND_TRUTH - keep_going = norm(xG-xkp1)>tolerance; - case STOPPING_SUBGRADIENT - sk = L*(yk-xkp1) + G*(xkp1-yk) ; - keep_going = norm(sk) > tolerance*L*max(1,norm(xkp1)); - case STOPPING_SPARSE_SUPPORT - % compute the stopping criterion based on the change - % of the number of non-zero components of the estimate - nz_x_prev = nz_x; - nz_x = (abs(xkp1)>eps*10); - num_nz_x = sum(nz_x(:)); - num_changes_active = (sum(nz_x(:)~=nz_x_prev(:))); - if num_nz_x >= 1 - criterionActiveSet = num_changes_active / num_nz_x; - keep_going = (criterionActiveSet > tolerance); - end - case STOPPING_OBJECTIVE_VALUE - % compute the stopping criterion based on the relative - % variation of the objective function. - prev_f = f; - f = 0.5*norm(b-A*xkp1)^2 + lambda_bar * norm(xk,1); - criterionObjective = abs(f-prev_f)/(prev_f); - keep_going = (criterionObjective > tolerance); - case STOPPING_DUALITY_GAP - error('Duality gap is not a valid stopping criterion for PGBP.'); - otherwise - error('Undefined stopping criterion.'); - end - - lambda = max(eta*lambda,lambda_bar) ; - - - t_kp1 = 0.5*(1+sqrt(1+4*t_k*t_k)) ; - - t_km1 = t_k ; - t_k = t_kp1 ; - xkm1 = xk ; - xk = xkp1 ; -end - -x_hat = xk ; - -function y = soft(x,T) -if sum(abs(T(:)))==0 - y = x; -else - y = max(abs(x) - T, 0); - y = sign(x).*y; -end
--- a/Problems/Cardiac_MRI_problem.m Fri Apr 01 14:27:44 2011 +0100 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,155 +0,0 @@ -function data = Cardiac_MRI_Problem(varargin) -% CHANGE!!!!PROB503 Shepp-Logan phantom, partial Fourier with sample mask, -% complex domain, total variation. -% -% PROB503 creates a problem structure. The generated signal will -% consist of a N = 256 by N Shepp-Logan phantom. The signal is -% sampled at random locations in frequency domain generated -% according to a probability density function. -% -% The following optional arguments are supported: -% -% PROB503('n',N,flags) is the same as above, but with a -% phantom of size N by N. The 'noseed' flag can be specified to -% suppress initialization of the random number generators. Both -% the parameter pair and flags can be omitted. -% -% Examples: -% P = prob503; % Creates the default 503 problem. -% -% References: -% -% [LustDonoPaul:2007] M. Lustig, D.L. Donoho and J.M. Pauly, -% Sparse MRI: The application of compressed sensing for rapid MR -% imaging, Submitted to Magnetic Resonance in Medicine, 2007. -% -% [sparsemri] M. Lustig, SparseMRI, -% http://www.stanford.edu/~mlustig/SparseMRI.html -% -% See also GENERATEPROBLEM. -% -%MATLAB SPARCO Toolbox. - -% Copyright 2008, Ewout van den Berg and Michael P. Friedlander -% http://www.cs.ubc.ca/labs/scl/sparco -% $Id: prob503.m 1040 2008-06-26 20:29:02Z ewout78 $ - -% Parse parameters and set problem name - -[opts,varg] = parseDefaultOpts(varargin{:}); -[parm,varg] = parseOptions(varg,{'noseed'},{'n','fold','sigma','slice'}); -n = getOption(parm,'n',256); -info.name = 'Cardiac_MRI'; -opts.show = 1; - - -fold = getOption(parm,'fold', 6); % undersampling level -sigma = getOption(parm,'sigma', 0.05);; % noise level -z = getOption(parm,'slice', 5);; % slice number (1-10) -szt = 20; % number of time samples - -% Return problem name if requested -if opts.getname, data = info.name; return; end; - -% Initialize random number generators -if (~parm.noseed), randn('state',0); rand('twister',2000); end; - -% Set up the data -% if allowed use variable density -%pdf = genPDF([n,n],5,0.1,2,0.1,0); - - - -%load heart images -FS=filesep; -TMPpath=pwd; - [pathstr1, name, ext, versn] = fileparts(which('SMALLboxSetup.m')); - cd([pathstr1,FS,'data',FS,'images',FS,'Cardiac_MRI_dataset',FS,'Images']); - [filename,pathname] = uigetfile({'*.mat;'},'Select a patient MRI image set'); - [pathstr, name, ext, versn] = fileparts(filename); -load(filename); -data.name=name; -cd(TMPpath); - -% Set up the problem - -% Get 3D matrix of heart images (size 256x256, 20 frames) and stack them to -% 2D matrix (256 x 256*20) -data.signal = reshape(sol_yxzt(:,:,z,:), [n n*szt]); - -% make a noise matrix - -noise_var=sqrt(sigma*var(reshape(data.signal, [n*n*szt 1]))); -data.noise = randn(n,n*szt)*noise_var + sqrt(-1)*randn(n,n*szt)*noise_var; - -% make a mask of random lines in phase encode and time domain random - vector -% of 0 and 1 of size n*szt multiplied with vector of 1 of size n - -mask = rand(n*szt,1); -mask(mask>(1-1/fold))=1; -mask(mask<=(1-1/fold))=0; -mask=(mask*ones(1,n))'; -data.op.mask = opMask(mask); -data.op.padding = opPadding([n,n*szt],[n,n*szt]); - -% make an fft 2D dictionary. It will do 2D fft on evry image in the stack -data.op.fft2d = opKron(opDiag(szt,1), opFFT2C(n,n)); - -% make measurement operator mask*padding*fft2d -data.M = opFoG(data.op.mask, data.op.padding, ... - data.op.fft2d); - -% make a mesurement vector b = M* (signal + noise) where s+n is stack to 1d vector -data.b = data.M(reshape(data.signal + data.noise,[n*n*szt,1]),1); - - -data = completeOps(data); - -% Additional information -info.title = 'Cardiac-MRI'; -info.thumb = 'figcardiacProblem'; -info.citations = {'LustDonoPaul:2007','sparsemri'}; -info.fig{1}.title = 'Cardiac MRI'; -% info.fig{1}.filename = 'figProblemCardiac'; -% info.fig{2}.title = 'Probability density function'; -% info.fig{2}.filename = 'figProblem503PDF'; -% info.fig{3}.title = 'Sampling mask'; -% info.fig{3}.filename = 'figProblem503Mask'; - -% Set the info field in data -data.info = info; -opts.figinc=1; -% Plot figures -if opts.update || opts.show - - %figure(opts.figno); opts.figno = opts.figno + opts.figinc; - - mov=reshape(data.signal/500, [n n szt]); - - implay(mov); - clear mov; - - %updateFigure(opts, info.fig{1}.title, info.fig{1}.filename); - - movMeas=reshape(abs(data.A(data.b,2))/500, [n n szt]); - implay(movMeas); - clear movMeas; -% figure(opts.figno); opts.figno = opts.figno + opts.figinc; -% imagesc(pdf), colormap gray; -% updateFigure(opts, info.fig{2}.title, info.fig{2}.filename) - - implay(reshape(mask, [n n szt])); - -% figure(opts.figno); opts.figno = opts.figno + opts.figinc; -% imagesc(mask), colormap gray -% updateFigure(opts, info.fig{3}.title, info.fig{3}.filename) -% -% if opts.update -% mn = min(min(data.signal + real(data.noise))); -% mx = max(max(data.signal + real(data.noise))); -% P = (data.signal + real(data.noise) - mn) / (mx - mn); -% P = scaleImage(P,128,128); -% P = P(1:2:end,1:2:end,:); -% thumbwrite(P, info.thumb, opts); -% end -end
--- a/examples/Image Denoising/SMALL_ImgDenoise_DL_test_KSVDvsRLSDLA.m Fri Apr 01 14:27:44 2011 +0100 +++ b/examples/Image Denoising/SMALL_ImgDenoise_DL_test_KSVDvsRLSDLA.m Tue Apr 05 17:03:26 2011 +0100 @@ -205,7 +205,8 @@ 'initdict', SMALL.Problem.initdict,... 'dictsize', SMALL.Problem.p,... 'forgettingMode', 'FIX',... - 'forgettingFactor', lambda); + 'forgettingFactor', lambda,... + 'show_dict', 500); SMALL.DL(3) = SMALL_learn(SMALL.Problem, SMALL.DL(3));
--- a/examples/Image Denoising/SMALL_ImgDenoise_dic_ODCT_solvers_OMP_BPDN_etc_test.m Fri Apr 01 14:27:44 2011 +0100 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,163 +0,0 @@ -%% DICTIONARY LEARNING FOR IMAGE DENOISING -% This file contains an example of how SMALLbox can be used to test different -% dictionary learning techniques in Image Denoising problem. -% It calls generateImageDenoiseProblem that will let you to choose image, -% add noise and use noisy image to generate training set for dictionary -% learning. -% Three dictionary learning techniques were compared: -% - KSVD - M. Elad, R. Rubinstein, and M. Zibulevsky, "Efficient -% Implementation of the K-SVD Algorithm using Batch Orthogonal -% Matching Pursuit", Technical Report - CS, Technion, April 2008. -% - KSVDS - R. Rubinstein, M. Zibulevsky, and M. Elad, "Learning Sparse -% Dictionaries for Sparse Signal Approximation", Technical -% Report - CS, Technion, June 2009. -% - SPAMS - J. Mairal, F. Bach, J. Ponce and G. Sapiro. Online -% Dictionary Learning for Sparse Coding. International -% Conference on Machine Learning,Montreal, Canada, 2009 -% -% -% Ivan Damnjanovic 2010 -%% - -clear; - -% If you want to load the image outside of generateImageDenoiseProblem -% function uncomment following lines. This can be useful if you want to -% denoise more then one image for example. - -% TMPpath=pwd; -% FS=filesep; -% [pathstr1, name, ext, versn] = fileparts(which('SMALLboxSetup.m')); -% cd([pathstr1,FS,'data',FS,'images']); -% [filename,pathname] = uigetfile({'*.png;'},'Select a file containin pre-calculated notes'); -% [pathstr, name, ext, versn] = fileparts(filename); -% test_image = imread(filename); -% test_image = double(test_image); -% cd(TMPpath); -% SMALL.Problem.name=name; - - -% Defining Image Denoising Problem as Dictionary Learning -% Problem. As an input we set the number of training patches. - -SMALL.Problem = generateImageDenoiseProblem('', 40000, '','', 20); - -Edata=sqrt(prod(SMALL.Problem.blocksize)) * SMALL.Problem.sigma * SMALL.Problem.gain; -maxatoms = floor(prod(SMALL.Problem.blocksize)/2); -%% Use KSVD Dictionary Learning Algorithm to Learn overcomplete dictionary -% -% % Initialising Dictionary structure -% % Setting Dictionary structure fields (toolbox, name, param, D and time) -% % to zero values -% -% SMALL.DL(1)=SMALL_init_DL(); -% -% % Defining the parameters needed for dictionary learning -% -% SMALL.DL(1).toolbox = 'KSVD'; -% SMALL.DL(1).name = 'ksvd'; -% -% % Defining the parameters for KSVD -% % In this example we are learning 256 atoms in 20 iterations, so that -% % every patch in the training set can be represented with target error in -% % L2-norm (EData) -% % Type help ksvd in MATLAB prompt for more options. -% -% -% SMALL.DL(1).param=struct(... -% 'Edata', Edata,... -% 'initdict', SMALL.Problem.initdict,... -% 'dictsize', SMALL.Problem.p,... -% 'iternum', 20,... -% 'memusage', 'high'); -% -% % Learn the dictionary -% -% SMALL.DL(1) = SMALL_learn(SMALL.Problem, SMALL.DL(1)); -%% Initialising Dictionary structure -% Setting Dictionary structure fields (toolbox, name, param, D and time) -% to zero values -% - -SMALL.DL(1)=SMALL_init_DL(); -% Take initial dictonary (overcomplete DCT) to be a final dictionary for -% reconstruction - -SMALL.DL(1).D=SMALL.Problem.initdict; -%% - -% Set SMALL.Problem.A dictionary -% (backward compatiblity with SPARCO: solver structure communicate -% only with Problem structure, ie no direct communication between DL and -% solver structures) -SMALL.Problem.A = SMALL.DL(1).D; - -SparseDict=0; -SMALL.Problem.reconstruct = @(x) ImgDenoise_reconstruct(x, SMALL.Problem, SparseDict); - -%% -% Initialising solver structure -% Setting solver structure fields (toolbox, name, param, solution, -% reconstructed and time) to zero values - - -SMALL.solver(1)=SMALL_init_solver; - -% Defining the parameters needed for image denoising - -SMALL.solver(1).toolbox='ompbox'; -SMALL.solver(1).name='omp2'; -SMALL.solver(1).param=struct(... - 'epsilon',Edata,... - 'maxatoms', maxatoms); - -% Denoising the image - SMALL_denoise function is similar to SMALL_solve, -% but backward compatible with KSVD definition of denoising -% Pay attention that since implicit base dictionary is used, denoising -% can be much faster then using explicit dictionary in KSVD example. - -SMALL.solver(1)=SMALL_solve(SMALL.Problem, SMALL.solver(1)); - -%% -% Initialising solver structure -% Setting solver structure fields (toolbox, name, param, solution, -% reconstructed and time) to zero values -lam=2*SMALL.Problem.sigma;%*sqrt(2*log2(size(SMALL.Problem.A,1))) -for i=1:11 - lambda(i)=lam+5-(i-1); -SMALL.DL(2)=SMALL_init_DL(); -i -%SMALL.Problem.A = SMALL.Problem.initdict; -SMALL.DL(2).D=SMALL.Problem.initdict; -SMALL.solver(2)=SMALL_init_solver; - -% Defining the parameters needed for image denoising - -SMALL.solver(2).toolbox='SPAMS'; -SMALL.solver(2).name='mexLasso'; -SMALL.solver(2).param=struct(... - 'mode', 2, ... - 'lambda',lambda(i),... - 'L', maxatoms); - -% Denoising the image - SMALL_denoise function is similar to SMALL_solve, -% but backward compatible with KSVD definition of denoising -% Pay attention that since implicit base dictionary is used, denoising -% can be much faster then using explicit dictionary in KSVD example. - -SMALL.solver(2)=SMALL_solve(SMALL.Problem, SMALL.solver(2)); - - -% show results % - -%SMALL_ImgDeNoiseResult(SMALL); - - time(1,i) = SMALL.solver(2).time; - psnr(1,i) = SMALL.solver(2).reconstructed.psnr; -end%% show time and psnr %% -figure('Name', 'SPAMS LAMBDA TEST'); - -subplot(1,2,1); plot(lambda, time(1,:), 'ro-'); -title('time vs lambda'); -subplot(1,2,2); plot(lambda, psnr(1,:), 'b*-'); -title('PSNR vs lambda'); \ No newline at end of file
--- a/util/SMALL_solve.m Fri Apr 01 14:27:44 2011 +0100 +++ b/util/SMALL_solve.m Tue Apr 05 17:03:26 2011 +0100 @@ -1,6 +1,12 @@ function solver = SMALL_solve(Problem, solver) %%% SMALL sparse solver % +% Function gets as input SMALL structure that contains SPARCO problem to +% be solved, name of the toolbox and solver, and parameters file for +% particular solver. +% +% Outputs are solution, reconstructed signal and time spent + % Centre for Digital Music, Queen Mary, University of London. % This file copyright 2009 Ivan Damnjanovic. % @@ -10,11 +16,6 @@ % License, or (at your option) any later version. See the file % COPYING included with this distribution for more information. % -% Function gets as input SMALL structure that contains SPARCO problem to -% be solved, name of the toolbox and solver, and parameters file for -% particular solver. -% -% Outputs are solution, reconstructed signal and time spent %% if isa(Problem.A,'float')