Mercurial > hg > smallbox
changeset 40:6416fc12f2b8
(none)
author | idamnjanovic |
---|---|
date | Mon, 14 Mar 2011 15:35:24 +0000 |
parents | 8f734534839a |
children | 83de4ea524df |
files | DL/RLS-DLA/SMALL_rlsdla 05032011.m DL/RLS-DLA/SMALL_rlsdla.m DL/RLS-DLA/SMALL_rlsdla1.m DL/RLS-DLA/SMALL_rlsdlaFirstClustTry.m DL/RLS-DLA/SolveFISTA.m |
diffstat | 5 files changed, 1820 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/DL/RLS-DLA/SMALL_rlsdla 05032011.m Mon Mar 14 15:35:24 2011 +0000 @@ -0,0 +1,488 @@ +function Dictionary = SMALL_rlsdla(X, params) + + + + + +global CODE_SPARSITY CODE_ERROR codemode +global MEM_LOW MEM_NORMAL MEM_HIGH memusage +global ompfunc ompparams exactsvd + +CODE_SPARSITY = 1; +CODE_ERROR = 2; + +MEM_LOW = 1; +MEM_NORMAL = 2; +MEM_HIGH = 3; + + +% p = randperm(size(X,2)); + + % coding mode % +%X_norm=sqrt(sum(X.^2, 1)); +% X_norm_1=sum(abs(X)); +% X_norm_inf=max(abs(X)); +%[X_norm_sort, p]=sort(X_norm);%, 'descend'); +% [X_norm_sort1, p5]=sort(X_norm_1);%, 'descend'); + +% if (isfield(params,'codemode')) +% switch lower(params.codemode) +% case 'sparsity' +% codemode = CODE_SPARSITY; +% thresh = params.Tdata; +% case 'error' +% codemode = CODE_ERROR; +% thresh = params.Edata; +% otherwise +% error('Invalid coding mode specified'); +% end +% elseif (isfield(params,'Tdata')) +% codemode = CODE_SPARSITY; +% thresh = params.Tdata; +% elseif (isfield(params,'Edata')) +% codemode = CODE_ERROR; +% thresh = params.Edata; +% +% else +% error('Data sparse-coding target not specified'); +% end + +thresh = params.Edata; +% max number of atoms % + +% if (codemode==CODE_ERROR && isfield(params,'maxatoms')) +% ompparams{end+1} = 'maxatoms'; +% ompparams{end+1} = params.maxatoms; +% end + + +% memory usage % + +if (isfield(params,'memusage')) + switch lower(params.memusage) + case 'low' + memusage = MEM_LOW; + case 'normal' + memusage = MEM_NORMAL; + case 'high' + memusage = MEM_HIGH; + otherwise + error('Invalid memory usage mode'); + end +else + memusage = MEM_NORMAL; +end + + +% iteration count % + +if (isfield(params,'iternum')) + iternum = params.iternum; +else + iternum = 10; +end + + +% omp function % + +if (codemode == CODE_SPARSITY) + ompfunc = @omp; +else + ompfunc = @omp2; +end + + +% % status messages % +% +% printiter = 0; +% printreplaced = 0; +% printerr = 0; +% printgerr = 0; +% +% verbose = 't'; +% msgdelta = -1; +% + +% +% for i = 1:length(verbose) +% switch lower(verbose(i)) +% case 'i' +% printiter = 1; +% case 'r' +% printiter = 1; +% printreplaced = 1; +% case 't' +% printiter = 1; +% printerr = 1; +% if (isfield(params,'testdata')) +% printgerr = 1; +% end +% end +% end +% +% if (msgdelta<=0 || isempty(verbose)) +% msgdelta = -1; +% end +% +% ompparams{end+1} = 'messages'; +% ompparams{end+1} = msgdelta; +% +% +% +% % compute error flag % +% +% comperr = (nargout>=3 || printerr); +% +% +% % validation flag % +% +% testgen = 0; +% if (isfield(params,'testdata')) +% testdata = params.testdata; +% if (nargout>=4 || printgerr) +% testgen = 1; +% end +% end + +% +% % data norms % +% +% XtX = []; XtXg = []; +% if (codemode==CODE_ERROR && memusage==MEM_HIGH) +% XtX = colnorms_squared(data); +% if (testgen) +% XtXg = colnorms_squared(testdata); +% end +% end + + +% mutual incoherence limit % + +if (isfield(params,'muthresh')) + muthresh = params.muthresh; +else + muthresh = 0.99; +end +if (muthresh < 0) + error('invalid muthresh value, must be non-negative'); +end + + + + + +% determine dictionary size % + +if (isfield(params,'initdict')) + if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) + dictsize = length(params.initdict); + else + dictsize = size(params.initdict,2); + end +end +if (isfield(params,'dictsize')) % this superceedes the size determined by initdict + dictsize = params.dictsize; +end + +if (size(X,2) < dictsize) + error('Number of training signals is smaller than number of atoms to train'); +end + + +% initialize the dictionary % + +if (isfield(params,'initdict')) + if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) + D = X(:,params.initdict(1:dictsize)); + else + if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize) + error('Invalid initial dictionary'); + end + D = params.initdict(:,1:dictsize); + end +else + data_ids = find(colnorms_squared(X) > 1e-6); % ensure no zero data elements are chosen + perm = randperm(length(data_ids)); + D = X(:,data_ids(perm(1:dictsize))); +end + + +% normalize the dictionary % + +% D = normcols(D); +% DtD=D'*D; + +err = zeros(1,iternum); +gerr = zeros(1,iternum); + +if (codemode == CODE_SPARSITY) + errstr = 'RMSE'; +else + errstr = 'mean atomnum'; +end +%X(:,p(X_norm_sort<thresh))=0; +% if (iternum==4) +% X_im=col2imstep(X, [256 256], [8 8]); +% else +% X_im=col2imstep(X, [512 512], [8 8]); +% end +% figure(10); imshow(X_im); + +%p1=p(cumsum(X_norm_sort)./[1:size(X_norm_sort,2)]>thresh); +%p1=p(X_norm_sort>thresh); +% X(:,setxor(p1,1:end))=0; +% X_im=col2imstep(X, [256 256], [8 8]); +% figure(10); imshow(X_im); +% if iternum==2 +% D(:,1)=D(:,2); +% end +%p1=p1(p2(1:40000)); +%end-min(40000, end)+1:end));%1:min(40000, end))); +%p1 = randperm(size(data,2));%size(data,2) +%data=data(:,p1); + +C=(100000*thresh)*eye(dictsize); +% figure(11); +w=zeros(dictsize,1); +replaced=zeros(dictsize,1); +u=zeros(dictsize,1); +% dictimg = showdict(D,[8 8],round(sqrt(size(D,2))),round(sqrt(size(D,2))),'lines','highcontrast'); +% figure(11);imshow(imresize(dictimg,2,'nearest')); +% pause(1); +lambda=0.9997%0.99986;%3+0.0001*params.linc; +for j=1:1 + %data=X; +if size(X,2)>40000 + p2 = randperm(size(X,2)); + + p2=sort(p2(1:40000));%min(floor(size(p1,2)/2),40000))); + size(p2,2) + data=X(:,p2); +elseif size(X,2)>0 + %p2 = randperm(size(p1,2)); + size(X,2) + data=X; +else + break; +end +% figure(1); +% plot(sqrt(sum(data.^2, 1))); +% a=size(data,2)/4; +% lambda0=0.99;%1-16/numS+iternum*0.0001-0.0002 +%C(1,1)=0; +modi=1000; +for i = 1:size(data,2) +% if norm(data(:,i))>thresh + % par.multA= @(x,par) multMatr(D,x); % user function y=Ax + % par.multAt=@(x,par) multMatrAdj(D,x); % user function y=A'*x + % par.y=data(:,i); + % w=SolveFISTA(D,data(:,i),'lambda',0.5*thresh); + % w=sesoptn(zeros(dictsize,1),par.func_u, par.func_x, par.multA, par.multAt,options,par); + %w = SMALL_chol(D,data(:,i), 256,32, thresh);% + %w = sparsecode(data(:,i), D, [], [], thresh); + w = omp2mex(D,data(:,i),[],[],[],thresh,0,-1,-1,0); + + %w(find(w<1))=0; + %^2; +% lambda(i)=1-0.001/(1+i/a); +% if i<a +% lambda(i)=1-0.001*(1-(i/a)); +% else +% lambda(i)=1; +% end +% param.lambda=thresh; +% param.mode=2; +% param.L=32; +% w=mexLasso(data(:,i), D, param); + spind=find(w); + %replaced(spind)=replaced(spind)+1; + %-0.001*(1/2)^(i/a); +% w_sp(i)=nnz(w); + residual = data(:,i) - D * w; + %if ~isempty(spind) + %i + if (j==1) + C = C *(1/ lambda); + end + u = C(:,spind) * w(spind); + + %spindu=find(u); + % v = D' * residual; + + alfa = 1/(1 + w' * u); + + D = D + (alfa * residual) * u'; + + %uut=; + C = C - (alfa * u)* u'; + % lambda=(19*lambda+1)/20; + % DtD = DtD + alfa * ( v*u' + u*v') + alfa^2 * (residual'*residual) * uut; + +% if (mod(i,modi)==0) +% Ximd=zeros(size(X)); +% Ximd(:,p1((i-modi+1:i)))=data(:,i-modi+1:i); +% +% if (iternum==4) +% X_ima(:,:,1)=col2imstep(Ximd, [256 256], [8 8]); +% X_ima(:,:,2)=col2imstep(X, [256 256], [8 8]); +% X_ima(:,:,3)=zeros(256,256); +% else +% X_ima(:,:,1)=col2imstep(Ximd, [512 512], [8 8]); +% X_ima(:,:,2)=col2imstep(X, [512 512], [8 8]); +% X_ima(:,:,3)=zeros(512,512); +% end +% +% dictimg1=dictimg; +% dictimg = showdict(D,[8 8],... +% round(sqrt(size(D,2))),round(sqrt(size(D,2))),'lines','highcontrast'); +% dictimg1=(dictimg-dictimg1); +% +% figure(2); +% subplot(2,2,1); imshow(X_ima); title(sprintf('%d',i)); +% subplot(2,2,3); imshow(imresize(dictimg,2,'nearest')); +% subplot(2,2,4); imshow(imresize(dictimg1,2,'nearest')); +% subplot(2,2,2);imshow(C*(255/max(max(C)))); +% pause(0.02); +% if (i>=35000) +% modi=100; +% pause +% end; +% end +% end +end +%p1=p1(setxor(p2,1:end)); +%[D,cleared_atoms] = cleardict(D,X,muthresh,p1,replaced); +%replaced=zeros(dictsize,1); +% W=sparsecode(data, D, [], [], thresh); +% data=D*W; +%lambda=lambda+0.0002 +end +%Gamma=mexLasso(data, D, param); +%err=compute_err(D,Gamma, data); +%[y,i]=max(err); +%D(:,1)=data(:,i)/norm(data(:,i)); + +Dictionary = D;%D(:,p); +% figure(3); +% plot(lambda); +% mean(lambda); +% figure(4+j);plot(w_sp); +end + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +% sparsecode % +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +function Gamma = sparsecode(data,D,XtX,G,thresh) + +global CODE_SPARSITY codemode +global MEM_HIGH memusage +global ompfunc ompparams + +if (memusage < MEM_HIGH) + Gamma = ompfunc(D,data,G,thresh,ompparams{:}); + +else % memusage is high + + if (codemode == CODE_SPARSITY) + Gamma = ompfunc(D'*data,G,thresh,ompparams{:}); + + else + Gamma = ompfunc(D, data, G, thresh,ompparams{:}); + end + +end + +end + + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +% compute_err % +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + + +function err = compute_err(D,Gamma,data) + +global CODE_SPARSITY codemode + +if (codemode == CODE_SPARSITY) + err = sqrt(sum(reperror2(data,D,Gamma))/numel(data)); +else + err = nnz(Gamma)/size(data,2); +end + +end + + + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +% cleardict % +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + + +function [D,cleared_atoms] = cleardict(D,X,muthresh,unused_sigs,replaced_atoms) + +use_thresh = 4; % at least this number of samples must use the atom to be kept + +dictsize = size(D,2); + +% compute error in blocks to conserve memory +% err = zeros(1,size(X,2)); +% blocks = [1:3000:size(X,2) size(X,2)+1]; +% for i = 1:length(blocks)-1 +% err(blocks(i):blocks(i+1)-1) = sum((X(:,blocks(i):blocks(i+1)-1)-D*Gamma(:,blocks(i):blocks(i+1)-1)).^2); +% end + +cleared_atoms = 0; +usecount = replaced_atoms;%sum(abs(Gamma)>1e-7, 2); + +for j = 1:dictsize + + % compute G(:,j) + Gj = D'*D(:,j); + Gj(j) = 0; + + % replace atom + if ( (max(Gj.^2)>muthresh^2 || usecount(j)<use_thresh) && ~replaced_atoms(j) ) +% [y,i] = max(err(unused_sigs)); + D(:,j) = X(:,unused_sigs(end)) / norm(X(:,unused_sigs(end))); + unused_sigs = unused_sigs([1:end-1]); + cleared_atoms = cleared_atoms+1; + end +end + +end + + + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +% misc functions % +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + + +function err2 = reperror2(X,D,Gamma) + +% compute in blocks to conserve memory +err2 = zeros(1,size(X,2)); +blocksize = 2000; +for i = 1:blocksize:size(X,2) + blockids = i : min(i+blocksize-1,size(X,2)); + err2(blockids) = sum((X(:,blockids) - D*Gamma(:,blockids)).^2); +end + +end + + +function Y = colnorms_squared(X) + +% compute in blocks to conserve memory +Y = zeros(1,size(X,2)); +blocksize = 2000; +for i = 1:blocksize:size(X,2) + blockids = i : min(i+blocksize-1,size(X,2)); + Y(blockids) = sum(X(:,blockids).^2); +end + +end + +
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/DL/RLS-DLA/SMALL_rlsdla.m Mon Mar 14 15:35:24 2011 +0000 @@ -0,0 +1,148 @@ +function Dictionary = SMALL_rlsdla(X, params) + + + + + + +CODE_SPARSITY = 1; +CODE_ERROR = 2; + + +% Determine which method will be used for sparse representation step - +% Sparsity or Error mode + +if (isfield(params,'codemode')) + switch lower(params.codemode) + case 'sparsity' + codemode = CODE_SPARSITY; + thresh = params.Tdata; + case 'error' + codemode = CODE_ERROR; + thresh = params.Edata; + + otherwise + error('Invalid coding mode specified'); + end +elseif (isfield(params,'Tdata')) + codemode = CODE_SPARSITY; + thresh = params.Tdata; +elseif (isfield(params,'Edata')) + codemode = CODE_ERROR; + thresh = params.Edata; + +else + error('Data sparse-coding target not specified'); +end + + +% max number of atoms % + +if (codemode==CODE_ERROR && isfield(params,'maxatoms')) + maxatoms = params.maxatoms; +else + maxatoms = -1; +end + + +% Forgetting factor + +if (isfield(params,'forgettingMode')) + switch lower(params.forgettingMode) + case 'fix' + if (isfield(params,'forgettingFactor')) + lambda=params.forgettingFactor; + else + lambda=1; + end + otherwise + error('This mode is still not implemented'); + end +elseif (isfield(params,'forgettingFactor')) + lambda=params.forgettingFactor; +else + lambda=1; +end + +% determine dictionary size % + +if (isfield(params,'initdict')) + if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) + dictsize = length(params.initdict); + else + dictsize = size(params.initdict,2); + end +end +if (isfield(params,'dictsize')) % this superceedes the size determined by initdict + dictsize = params.dictsize; +end + +if (size(X,2) < dictsize) + error('Number of training signals is smaller than number of atoms to train'); +end + + +% initialize the dictionary % + +if (isfield(params,'initdict')) + if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) + D = X(:,params.initdict(1:dictsize)); + else + if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize) + error('Invalid initial dictionary'); + end + D = params.initdict(:,1:dictsize); + end +else + data_ids = find(colnorms_squared(X) > 1e-6); % ensure no zero data elements are chosen + perm = randperm(length(data_ids)); + D = X(:,data_ids(perm(1:dictsize))); +end + + +% normalize the dictionary % + +D = normcols(D); + +% Training data + +data=X; + +% + +C=(100000*thresh)*eye(dictsize); +w=zeros(dictsize,1); +u=zeros(dictsize,1); + + +for i = 1:size(data,2) + + if (codemode == CODE_SPARSITY) + w = ompmex(D,data(:,i),[],[],thresh,1,-1,0); + else + w = omp2mex(D,data(:,i),[],[],[],thresh,0,-1,maxatoms,0); + end + + spind=find(w); + + residual = data(:,i) - D * w; + + if (lambda~=1) + C = C *(1/ lambda); + end + + u = C(:,spind) * w(spind); + + + alfa = 1/(1 + w' * u); + + D = D + (alfa * residual) * u'; + + + C = C - (alfa * u)* u'; + +end + +Dictionary = D; + +end
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/DL/RLS-DLA/SMALL_rlsdla1.m Mon Mar 14 15:35:24 2011 +0000 @@ -0,0 +1,479 @@ +function Dictionary = SMALL_rlsdla1(X, params) + + + + + +global CODE_SPARSITY CODE_ERROR codemode +global MEM_LOW MEM_NORMAL MEM_HIGH memusage +global ompfunc ompparams exactsvd + +CODE_SPARSITY = 1; +CODE_ERROR = 2; + +MEM_LOW = 1; +MEM_NORMAL = 2; +MEM_HIGH = 3; + + +% p = randperm(size(X,2)); + + % coding mode % +X_norm=sqrt(sum(X.^2, 1)); +% X_norm_1=sum(abs(X)); +% X_norm_inf=max(abs(X)); +[X_norm_sort, p]=sort(X_norm);%, 'descend'); +% [X_norm_sort1, p5]=sort(X_norm_1);%, 'descend'); + +% if (isfield(params,'codemode')) +% switch lower(params.codemode) +% case 'sparsity' +% codemode = CODE_SPARSITY; +% thresh = params.Tdata; +% case 'error' +% codemode = CODE_ERROR; +% thresh = params.Edata; +% otherwise +% error('Invalid coding mode specified'); +% end +% elseif (isfield(params,'Tdata')) +% codemode = CODE_SPARSITY; +% thresh = params.Tdata; +% elseif (isfield(params,'Edata')) +% codemode = CODE_ERROR; +% thresh = params.Edata; +% +% else +% error('Data sparse-coding target not specified'); +% end + +thresh = params.Edata; +% max number of atoms % + +% if (codemode==CODE_ERROR && isfield(params,'maxatoms')) +% ompparams{end+1} = 'maxatoms'; +% ompparams{end+1} = params.maxatoms; +% end + + +% memory usage % + +if (isfield(params,'memusage')) + switch lower(params.memusage) + case 'low' + memusage = MEM_LOW; + case 'normal' + memusage = MEM_NORMAL; + case 'high' + memusage = MEM_HIGH; + otherwise + error('Invalid memory usage mode'); + end +else + memusage = MEM_NORMAL; +end + + +% iteration count % + +if (isfield(params,'iternum')) + iternum = params.iternum; +else + iternum = 10; +end + + +% omp function % + +if (codemode == CODE_SPARSITY) + ompfunc = @omp; +else + ompfunc = @omp2; +end + + +% % status messages % +% +% printiter = 0; +% printreplaced = 0; +% printerr = 0; +% printgerr = 0; +% +% verbose = 't'; +% msgdelta = -1; +% + +% +% for i = 1:length(verbose) +% switch lower(verbose(i)) +% case 'i' +% printiter = 1; +% case 'r' +% printiter = 1; +% printreplaced = 1; +% case 't' +% printiter = 1; +% printerr = 1; +% if (isfield(params,'testdata')) +% printgerr = 1; +% end +% end +% end +% +% if (msgdelta<=0 || isempty(verbose)) +% msgdelta = -1; +% end +% +% ompparams{end+1} = 'messages'; +% ompparams{end+1} = msgdelta; +% +% +% +% % compute error flag % +% +% comperr = (nargout>=3 || printerr); +% +% +% % validation flag % +% +% testgen = 0; +% if (isfield(params,'testdata')) +% testdata = params.testdata; +% if (nargout>=4 || printgerr) +% testgen = 1; +% end +% end + +% +% % data norms % +% +% XtX = []; XtXg = []; +% if (codemode==CODE_ERROR && memusage==MEM_HIGH) +% XtX = colnorms_squared(data); +% if (testgen) +% XtXg = colnorms_squared(testdata); +% end +% end + + +% mutual incoherence limit % + +if (isfield(params,'muthresh')) + muthresh = params.muthresh; +else + muthresh = 0.99; +end +if (muthresh < 0) + error('invalid muthresh value, must be non-negative'); +end + + + + + +% determine dictionary size % + +if (isfield(params,'initdict')) + if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) + dictsize = length(params.initdict); + else + dictsize = size(params.initdict,2); + end +end +if (isfield(params,'dictsize')) % this superceedes the size determined by initdict + dictsize = params.dictsize; +end + +if (size(X,2) < dictsize) + error('Number of training signals is smaller than number of atoms to train'); +end + + +% initialize the dictionary % + +if (isfield(params,'initdict')) + if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) + D1 = X(:,params.initdict(1:dictsize)); + D2 = X(:,params.initdict(1:dictsize)); + else + if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize) + error('Invalid initial dictionary'); + end + D1 = params.initdict(:,1:dictsize); + D2 = params.initdict(:,1:dictsize); + end +else + data_ids = find(colnorms_squared(X) > 1e-6); % ensure no zero data elements are chosen + perm = randperm(length(data_ids)); + D = X(:,data_ids(perm(1:dictsize))); +end + + +% normalize the dictionary % + +% D = normcols(D); +% DtD=D'*D; + +err = zeros(1,iternum); +gerr = zeros(1,iternum); + +if (codemode == CODE_SPARSITY) + errstr = 'RMSE'; +else + errstr = 'mean atomnum'; +end +X(:,p(X_norm_sort<thresh))=0; +% if (iternum==4) +% X_im=col2imstep(X, [256 256], [8 8]); +% else +% X_im=col2imstep(X, [512 512], [8 8]); +% end +% figure(10); imshow(X_im); +p1=p(X_norm_sort>thresh); + +%p1=p1(p2(1:40000)); +%end-min(40000, end)+1:end));%1:min(40000, end))); +%p1 = randperm(size(data,2));%size(data,2) +%data=data(:,p1); + +C1=(100000*thresh)*eye(dictsize); +C2=(100000*thresh)*eye(dictsize); +% figure(11); +w=zeros(dictsize,1); +replaced=zeros(dictsize,1); +u=zeros(dictsize,1); + dictimg = showdict(D1,[8 8],round(sqrt(size(D1,2))),round(sqrt(size(D1,2))),'lines','highcontrast'); +% h=imshow(imresize(dictimg,2,'nearest')); +lambda=0.9998 +for j=1:3 +if size(p1,2)>20000 + p2 = randperm(floor(size(p1,2)/2)); + p2=sort(p2(1:20000)); + data1=X(:,p1(p2)); + data2=X(:,p1(floor(size(p1,2)/2)+p2)); +elseif size(p1,2)>0 + data=X(:,p1); +else + break; +end +% figure(1); +% plot(sqrt(sum(data.^2, 1))); +% a=size(data,2)/4; +% lambda0=0.99;%1-16/numS+iternum*0.0001-0.0002 +C1(1,1)=0; +C2(1,1)=0; +for i = 1:size(data1,2) +% if norm(data(:,i))>thresh + % par.multA= @(x,par) multMatr(D,x); % user function y=Ax + % par.multAt=@(x,par) multMatrAdj(D,x); % user function y=A'*x + % par.y=data(:,i); + % w=SolveFISTA(D,data(:,i),'lambda',0.5*thresh); + % w=sesoptn(zeros(dictsize,1),par.func_u, par.func_x, par.multA, par.multAt,options,par); + %w = SMALL_chol(D,data(:,i), 256,32, thresh);% + %w = sparsecode(data(:,i), D, [], [], thresh); + w1 = omp2mex(D1,data1(:,i),[],[],[],thresh,0,-1,-1,0); + w2 = omp2mex(D2,data2(:,i),[],[],[],thresh,0,-1,-1,0); + %w(find(w<1))=0; + %^2; +% lambda(i)=1-0.001/(1+i/a); +% if i<a +% lambda(i)=1-0.001*(1-(i/a)); +% else +% lambda(i)=1; +% end +% param.lambda=thresh; +% param.mode=2; +% param.L=32; +% w=mexLasso(data(:,i), D, param); + spind1=find(w1); + spind2=find(w2); + + %replaced(spind)=replaced(spind)+1; + %-0.001*(1/2)^(i/a); +% w_sp(i)=nnz(w); + residual1 = data1(:,i) - D1 * w1; + residual2 = data2(:,i) - D2 * w2; + %if ~isempty(spind) + %i + + C1 = C1 *(1/ lambda); + C2 = C2 *(1/ lambda); + u1 = C1(:,spind1) * w1(spind1); + u2 = C2(:,spind2) * w2(spind2); + %spindu=find(u); + % v = D' * residual; + + alfa1 = 1/(1 + w1' * u1); + alfa2 = 1/(1 + w2' * u2); + D1 = D1 + (alfa1 * residual1) * u1'; + D2 = D2 + (alfa2 * residual2) * u2'; + %uut=; + C1 = C1 - (alfa1 * u1)* u1'; + C2 = C2 - (alfa2 * u2)* u2'; + % lambda=(19*lambda+1)/20; + % DtD = DtD + alfa * ( v*u' + u*v') + alfa^2 * (residual'*residual) * uut; +% modi=5000; +% if (mod(i,modi)==0) +% Ximd=zeros(size(X)); +% Ximd(:,p((i-modi+1:i)))=data(:,i-modi+1:i); +% +% if (iternum==4) +% X_ima=col2imstep(Ximd, [256 256], [8 8]); +% else +% X_ima=col2imstep(Ximd, [512 512], [8 8]); +% end +% dictimg1=dictimg; +% dictimg = showdict(D,[8 8],... +% round(sqrt(size(D,2))),round(sqrt(size(D,2))),'lines','highcontrast'); +% dictimg1=(dictimg-dictimg1)*255; +% +% figure(2); +% subplot(2,2,1); imshow(X_ima); +% subplot(2,2,3); imshow(imresize(dictimg,2,'nearest')); +% subplot(2,2,4); imshow(imresize(dictimg1,2,'nearest')); +% subplot(2,2,2);imshow(C*(255/max(max(C)))); +% pause(0.02); +% end +% end +end +%p1=p1(setxor(p2,1:end)); +%[D,cleared_atoms] = cleardict(D,X,muthresh,p1,replaced); +%replaced=zeros(dictsize,1); +% W=sparsecode(data, D, [], [], thresh); +% data=D*W; +lambda=lambda+0.0001 +end +%Gamma=mexLasso(data, D, param); +%err=compute_err(D,Gamma, data); +%[y,i]=max(err); +%D(:,1)=data(:,i)/norm(data(:,i)); +% D=normcols(D); +% D_norm=sqrt(sum(D.^2, 1)); +% D_norm_1=sum(abs(D)); +% X_norm_1=sum(abs(X)); +% X_norm_inf=max(abs(X)); +% [D_norm_sort, p]=sort(D_norm_1, 'descend'); +Dictionary =[D1 D2]; +% figure(3); +% plot(lambda); +% mean(lambda); +% figure(4+j);plot(w_sp); +end + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +% sparsecode % +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +function Gamma = sparsecode(data,D,XtX,G,thresh) + +global CODE_SPARSITY codemode +global MEM_HIGH memusage +global ompfunc ompparams + +if (memusage < MEM_HIGH) + Gamma = ompfunc(D,data,G,thresh,ompparams{:}); + +else % memusage is high + + if (codemode == CODE_SPARSITY) + Gamma = ompfunc(D'*data,G,thresh,ompparams{:}); + + else + Gamma = ompfunc(D, data, G, thresh,ompparams{:}); + end + +end + +end + + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +% compute_err % +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + + +function err = compute_err(D,Gamma,data) + +global CODE_SPARSITY codemode + +if (codemode == CODE_SPARSITY) + err = sqrt(sum(reperror2(data,D,Gamma))/numel(data)); +else + err = nnz(Gamma)/size(data,2); +end + +end + + + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +% cleardict % +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + + +function [D,cleared_atoms] = cleardict(D,X,muthresh,unused_sigs,replaced_atoms) + +use_thresh = 4; % at least this number of samples must use the atom to be kept + +dictsize = size(D,2); + +% compute error in blocks to conserve memory +% err = zeros(1,size(X,2)); +% blocks = [1:3000:size(X,2) size(X,2)+1]; +% for i = 1:length(blocks)-1 +% err(blocks(i):blocks(i+1)-1) = sum((X(:,blocks(i):blocks(i+1)-1)-D*Gamma(:,blocks(i):blocks(i+1)-1)).^2); +% end + +cleared_atoms = 0; +usecount = replaced_atoms;%sum(abs(Gamma)>1e-7, 2); + +for j = 1:dictsize + + % compute G(:,j) + Gj = D'*D(:,j); + Gj(j) = 0; + + % replace atom + if ( (max(Gj.^2)>muthresh^2 || usecount(j)<use_thresh) && ~replaced_atoms(j) ) +% [y,i] = max(err(unused_sigs)); + D(:,j) = X(:,unused_sigs(end)) / norm(X(:,unused_sigs(end))); + unused_sigs = unused_sigs([1:end-1]); + cleared_atoms = cleared_atoms+1; + end +end + +end + + + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +% misc functions % +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + + +function err2 = reperror2(X,D,Gamma) + +% compute in blocks to conserve memory +err2 = zeros(1,size(X,2)); +blocksize = 2000; +for i = 1:blocksize:size(X,2) + blockids = i : min(i+blocksize-1,size(X,2)); + err2(blockids) = sum((X(:,blockids) - D*Gamma(:,blockids)).^2); +end + +end + + +function Y = colnorms_squared(X) + +% compute in blocks to conserve memory +Y = zeros(1,size(X,2)); +blocksize = 2000; +for i = 1:blocksize:size(X,2) + blockids = i : min(i+blocksize-1,size(X,2)); + Y(blockids) = sum(X(:,blockids).^2); +end + +end + +
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/DL/RLS-DLA/SMALL_rlsdlaFirstClustTry.m Mon Mar 14 15:35:24 2011 +0000 @@ -0,0 +1,494 @@ +function Dictionary = SMALL_rlsdla(X, params) + + + + + +global CODE_SPARSITY CODE_ERROR codemode +global MEM_LOW MEM_NORMAL MEM_HIGH memusage +global ompfunc ompparams exactsvd + +CODE_SPARSITY = 1; +CODE_ERROR = 2; + +MEM_LOW = 1; +MEM_NORMAL = 2; +MEM_HIGH = 3; + + +% p = randperm(size(X,2)); + + % coding mode % +X_norm=sqrt(sum(X.^2, 1)); +% X_norm_1=sum(abs(X)); +% X_norm_inf=max(abs(X)); +[X_norm_sort, p]=sort(X_norm);%, 'descend'); +% [X_norm_sort1, p5]=sort(X_norm_1);%, 'descend'); + +% if (isfield(params,'codemode')) +% switch lower(params.codemode) +% case 'sparsity' +% codemode = CODE_SPARSITY; +% thresh = params.Tdata; +% case 'error' +% codemode = CODE_ERROR; +% thresh = params.Edata; +% otherwise +% error('Invalid coding mode specified'); +% end +% elseif (isfield(params,'Tdata')) +% codemode = CODE_SPARSITY; +% thresh = params.Tdata; +% elseif (isfield(params,'Edata')) +% codemode = CODE_ERROR; +% thresh = params.Edata; +% +% else +% error('Data sparse-coding target not specified'); +% end + +thresh = params.Edata; +% max number of atoms % + +% if (codemode==CODE_ERROR && isfield(params,'maxatoms')) +% ompparams{end+1} = 'maxatoms'; +% ompparams{end+1} = params.maxatoms; +% end + + +% memory usage % + +if (isfield(params,'memusage')) + switch lower(params.memusage) + case 'low' + memusage = MEM_LOW; + case 'normal' + memusage = MEM_NORMAL; + case 'high' + memusage = MEM_HIGH; + otherwise + error('Invalid memory usage mode'); + end +else + memusage = MEM_NORMAL; +end + + +% iteration count % + +if (isfield(params,'iternum')) + iternum = params.iternum; +else + iternum = 10; +end + + +% omp function % + +if (codemode == CODE_SPARSITY) + ompfunc = @omp; +else + ompfunc = @omp2; +end + + +% % status messages % +% +% printiter = 0; +% printreplaced = 0; +% printerr = 0; +% printgerr = 0; +% +% verbose = 't'; +% msgdelta = -1; +% + +% +% for i = 1:length(verbose) +% switch lower(verbose(i)) +% case 'i' +% printiter = 1; +% case 'r' +% printiter = 1; +% printreplaced = 1; +% case 't' +% printiter = 1; +% printerr = 1; +% if (isfield(params,'testdata')) +% printgerr = 1; +% end +% end +% end +% +% if (msgdelta<=0 || isempty(verbose)) +% msgdelta = -1; +% end +% +% ompparams{end+1} = 'messages'; +% ompparams{end+1} = msgdelta; +% +% +% +% % compute error flag % +% +% comperr = (nargout>=3 || printerr); +% +% +% % validation flag % +% +% testgen = 0; +% if (isfield(params,'testdata')) +% testdata = params.testdata; +% if (nargout>=4 || printgerr) +% testgen = 1; +% end +% end + +% +% % data norms % +% +% XtX = []; XtXg = []; +% if (codemode==CODE_ERROR && memusage==MEM_HIGH) +% XtX = colnorms_squared(data); +% if (testgen) +% XtXg = colnorms_squared(testdata); +% end +% end + + +% mutual incoherence limit % + +if (isfield(params,'muthresh')) + muthresh = params.muthresh; +else + muthresh = 0.99; +end +if (muthresh < 0) + error('invalid muthresh value, must be non-negative'); +end + + + + + +% determine dictionary size % + +if (isfield(params,'initdict')) + if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) + dictsize = length(params.initdict); + else + dictsize = size(params.initdict,2); + end +end +if (isfield(params,'dictsize')) % this superceedes the size determined by initdict + dictsize = params.dictsize; +end + +if (size(X,2) < dictsize) + error('Number of training signals is smaller than number of atoms to train'); +end + + +% initialize the dictionary % + +if (isfield(params,'initdict')) + if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) + D = X(:,params.initdict(1:dictsize)); + else + if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize) + error('Invalid initial dictionary'); + end + D = params.initdict(:,1:dictsize); + end +else + data_ids = find(colnorms_squared(X) > 1e-6); % ensure no zero data elements are chosen + perm = randperm(length(data_ids)); + D = X(:,data_ids(perm(1:dictsize))); +end + +% normalize the dictionary % + +% D = normcols(D); +% DtD=D'*D; + +err = zeros(1,iternum); +gerr = zeros(1,iternum); + +if (codemode == CODE_SPARSITY) + errstr = 'RMSE'; +else + errstr = 'mean atomnum'; +end +%X(:,p(X_norm_sort<thresh))=0; +% if (iternum==4) +% X_im=col2imstep(X, [256 256], [8 8]); +% else +% X_im=col2imstep(X, [512 512], [8 8]); +% end +% figure(10); imshow(X_im); + +%p1=p(cumsum(X_norm_sort)./[1:size(X_norm_sort,2)]>thresh); +p1=p(X_norm_sort>thresh); +tic; idx=kmeans(X(:,p1)',4, 'Start', 'cluster','MaxIter',200); toc +D=[D D D D]; +dictsize1=4*dictsize; +% X(:,setxor(p1,1:end))=0; +% X_im=col2imstep(X, [256 256], [8 8]); +% figure(10); imshow(X_im); +% if iternum==2 +% D(:,1)=D(:,2); +% end +%p1=p1(p2(1:40000)); +%end-min(40000, end)+1:end));%1:min(40000, end))); +%p1 = randperm(size(data,2));%size(data,2) +%data=data(:,p1); + +C=(100000*thresh)*eye(dictsize1); +% figure(11); +w=zeros(dictsize,1); +replaced=zeros(dictsize,1); +u=zeros(dictsize,1); +% dictimg = showdict(D,[8 8],round(sqrt(size(D,2))),round(sqrt(size(D,2))),'lines','highcontrast'); +% figure(11);imshow(imresize(dictimg,2,'nearest')); +% pause(1); +lambda=0.99986;%3+0.0001*params.linc; +for j=1:1 +if size(p1,2)>60000 + p2 = randperm(size(p1,2)); + + p2=sort(p2(1:60000));%min(floor(size(p1,2)/2),40000))); + size(p2,2) + data=X(:,p1(p2)); +elseif size(p1,2)>0 + p2 = randperm(size(p1,2)); + size(p2,2) + data=X(:,p1); +else + break; +end +% figure(1); +% plot(sqrt(sum(data.^2, 1))); +% a=size(data,2)/4; +% lambda0=0.99;%1-16/numS+iternum*0.0001-0.0002 +%C(1,1)=0; +modi=1000; +for i = 1:size(data,2) +% if norm(data(:,i))>thresh + % par.multA= @(x,par) multMatr(D,x); % user function y=Ax + % par.multAt=@(x,par) multMatrAdj(D,x); % user function y=A'*x + % par.y=data(:,i); + % w=SolveFISTA(D,data(:,i),'lambda',0.5*thresh); + % w=sesoptn(zeros(dictsize,1),par.func_u, par.func_x, par.multA, par.multAt,options,par); + %w = SMALL_chol(D,data(:,i), 256,32, thresh);% + %w = sparsecode(data(:,i), D, [], [], thresh); + w = omp2mex(D(:,((idx(i)-1)*dictsize+1):idx(i)*dictsize),data(:,i),[],[],[],thresh,0,-1,-1,0); + + %w(find(w<1))=0; + %^2; +% lambda(i)=1-0.001/(1+i/a); +% if i<a +% lambda(i)=1-0.001*(1-(i/a)); +% else +% lambda(i)=1; +% end +% param.lambda=thresh; +% param.mode=2; +% param.L=32; +% w=mexLasso(data(:,i), D, param); + spind=find(w); + %replaced(spind)=replaced(spind)+1; + %-0.001*(1/2)^(i/a); +% w_sp(i)=nnz(w); + residual = data(:,i) - D (:,((idx(i)-1)*dictsize+1):idx(i)*dictsize)* w; + %if ~isempty(spind) + %i + if (j==1) + C = C *(1/ lambda); + end + u = C(((idx(i)-1)*dictsize+1):idx(i)*dictsize,((idx(i)-1)*dictsize)+spind) * w(spind); + + %spindu=find(u); + % v = D' * residual; + + alfa = 1/(1 + w' * u); + + D(:,((idx(i)-1)*dictsize+1):idx(i)*dictsize) = D (:,((idx(i)-1)*dictsize+1):idx(i)*dictsize)+ (alfa * residual) * u'; + + %uut=; + C (((idx(i)-1)*dictsize+1):idx(i)*dictsize,((idx(i)-1)*dictsize+1):idx(i)*dictsize)= C(((idx(i)-1)*dictsize+1):idx(i)*dictsize,((idx(i)-1)*dictsize+1):idx(i)*dictsize) - (alfa * u)* u'; + % lambda=(19*lambda+1)/20; + % DtD = DtD + alfa * ( v*u' + u*v') + alfa^2 * (residual'*residual) * uut; + +% if (mod(i,modi)==0) +% Ximd=zeros(size(X)); +% Ximd(:,p1((i-modi+1:i)))=data(:,i-modi+1:i); +% +% if (iternum==4) +% X_ima(:,:,1)=col2imstep(Ximd, [256 256], [8 8]); +% X_ima(:,:,2)=col2imstep(X, [256 256], [8 8]); +% X_ima(:,:,3)=zeros(256,256); +% else +% X_ima(:,:,1)=col2imstep(Ximd, [512 512], [8 8]); +% X_ima(:,:,2)=col2imstep(X, [512 512], [8 8]); +% X_ima(:,:,3)=zeros(512,512); +% end +% +% dictimg1=dictimg; +% dictimg = showdict(D,[8 8],... +% round(sqrt(size(D,2))),round(sqrt(size(D,2))),'lines','highcontrast'); +% dictimg1=(dictimg-dictimg1); +% +% figure(2); +% subplot(2,2,1); imshow(X_ima); title(sprintf('%d',i)); +% subplot(2,2,3); imshow(imresize(dictimg,2,'nearest')); +% subplot(2,2,4); imshow(imresize(dictimg1,2,'nearest')); +% subplot(2,2,2);imshow(C*(255/max(max(C)))); +% pause(0.02); +% if (i>=35000) +% modi=100; +% pause +% end; +% end +% end +end +%p1=p1(setxor(p2,1:end)); +%[D,cleared_atoms] = cleardict(D,X,muthresh,p1,replaced); +%replaced=zeros(dictsize,1); +% W=sparsecode(data, D, [], [], thresh); +% data=D*W; +lambda=lambda+0.0002 +end +%Gamma=mexLasso(data, D, param); +%err=compute_err(D,Gamma, data); +%[y,i]=max(err); +%D(:,1)=data(:,i)/norm(data(:,i)); +D=normcols(D); +D_norm=sqrt(sum(D.^2, 1)); +D_norm_1=sum(abs(D)); +% X_norm_1=sum(abs(X)); +% X_norm_inf=max(abs(X)); +[D_norm_sort, p]=sort(D_norm_1, 'descend'); +Dictionary = D;%D(:,p); +% figure(3); +% plot(lambda); +% mean(lambda); +% figure(4+j);plot(w_sp); +end + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +% sparsecode % +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +function Gamma = sparsecode(data,D,XtX,G,thresh) + +global CODE_SPARSITY codemode +global MEM_HIGH memusage +global ompfunc ompparams + +if (memusage < MEM_HIGH) + Gamma = ompfunc(D,data,G,thresh,ompparams{:}); + +else % memusage is high + + if (codemode == CODE_SPARSITY) + Gamma = ompfunc(D'*data,G,thresh,ompparams{:}); + + else + Gamma = ompfunc(D, data, G, thresh,ompparams{:}); + end + +end + +end + + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +% compute_err % +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + + +function err = compute_err(D,Gamma,data) + +global CODE_SPARSITY codemode + +if (codemode == CODE_SPARSITY) + err = sqrt(sum(reperror2(data,D,Gamma))/numel(data)); +else + err = nnz(Gamma)/size(data,2); +end + +end + + + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +% cleardict % +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + + +function [D,cleared_atoms] = cleardict(D,X,muthresh,unused_sigs,replaced_atoms) + +use_thresh = 4; % at least this number of samples must use the atom to be kept + +dictsize = size(D,2); + +% compute error in blocks to conserve memory +% err = zeros(1,size(X,2)); +% blocks = [1:3000:size(X,2) size(X,2)+1]; +% for i = 1:length(blocks)-1 +% err(blocks(i):blocks(i+1)-1) = sum((X(:,blocks(i):blocks(i+1)-1)-D*Gamma(:,blocks(i):blocks(i+1)-1)).^2); +% end + +cleared_atoms = 0; +usecount = replaced_atoms;%sum(abs(Gamma)>1e-7, 2); + +for j = 1:dictsize + + % compute G(:,j) + Gj = D'*D(:,j); + Gj(j) = 0; + + % replace atom + if ( (max(Gj.^2)>muthresh^2 || usecount(j)<use_thresh) && ~replaced_atoms(j) ) +% [y,i] = max(err(unused_sigs)); + D(:,j) = X(:,unused_sigs(end)) / norm(X(:,unused_sigs(end))); + unused_sigs = unused_sigs([1:end-1]); + cleared_atoms = cleared_atoms+1; + end +end + +end + + + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +% misc functions % +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + + +function err2 = reperror2(X,D,Gamma) + +% compute in blocks to conserve memory +err2 = zeros(1,size(X,2)); +blocksize = 2000; +for i = 1:blocksize:size(X,2) + blockids = i : min(i+blocksize-1,size(X,2)); + err2(blockids) = sum((X(:,blockids) - D*Gamma(:,blockids)).^2); +end + +end + + +function Y = colnorms_squared(X) + +% compute in blocks to conserve memory +Y = zeros(1,size(X,2)); +blocksize = 2000; +for i = 1:blocksize:size(X,2) + blockids = i : min(i+blocksize-1,size(X,2)); + Y(blockids) = sum(X(:,blockids).^2); +end + +end + +
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/DL/RLS-DLA/SolveFISTA.m Mon Mar 14 15:35:24 2011 +0000 @@ -0,0 +1,211 @@ +% Copyright ©2010. The Regents of the University of California (Regents). +% All Rights Reserved. Contact The Office of Technology Licensing, +% UC Berkeley, 2150 Shattuck Avenue, Suite 510, Berkeley, CA 94720-1620, +% (510) 643-7201, for commercial licensing opportunities. + +% Authors: Arvind Ganesh, Allen Y. Yang, Zihan Zhou. +% Contact: Allen Y. Yang, Department of EECS, University of California, +% Berkeley. <yang@eecs.berkeley.edu> + +% IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, +% SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, +% ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF +% REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +% REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED +% TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +% PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF ANY, +% PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION TO +% PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. + +%% This function is modified from Matlab code proximal_gradient_bp + +function [x_hat,nIter] = SolveFISTA(A,b, varargin) + +% b - m x 1 vector of observations/data (required input) +% A - m x n measurement matrix (required input) +% +% tol - tolerance for stopping criterion. +% - DEFAULT 1e-7 if omitted or -1. +% maxIter - maxilambdam number of iterations +% - DEFAULT 10000, if omitted or -1. +% lineSearchFlag - 1 if line search is to be done every iteration +% - DEFAULT 0, if omitted or -1. +% continuationFlag - 1 if a continuation is to be done on the parameter lambda +% - DEFAULT 1, if omitted or -1. +% eta - line search parameter, should be in (0,1) +% - ignored if lineSearchFlag is 0. +% - DEFAULT 0.9, if omitted or -1. +% lambda - relaxation parameter +% - ignored if continuationFlag is 1. +% - DEFAULT 1e-3, if omitted or -1. +% outputFileName - Details of each iteration are dumped here, if provided. +% +% x_hat - estimate of coeeficient vector +% numIter - number of iterations until convergence +% +% +% References +% "Robust PCA: Exact Recovery of Corrupted Low-Rank Matrices via Convex Optimization", J. Wright et al., preprint 2009. +% "An Accelerated Proximal Gradient Algorithm for Nuclear Norm Regularized Least Squares problems", K.-C. Toh and S. Yun, preprint 2009. +% +% Arvind Ganesh, Summer 2009. Questions? abalasu2@illinois.edu + +DEBUG = 0 ; + +STOPPING_GROUND_TRUTH = -1; +STOPPING_DUALITY_GAP = 1; +STOPPING_SPARSE_SUPPORT = 2; +STOPPING_OBJECTIVE_VALUE = 3; +STOPPING_SUBGRADIENT = 4; +STOPPING_DEFAULT = STOPPING_SUBGRADIENT; + +stoppingCriterion = STOPPING_DEFAULT; +maxIter = 1000 ; +tolerance = 1e-3; +[m,n] = size(A) ; +x0 = zeros(n,1) ; +xG = []; + +%% Initializing optimization variables +t_k = 1 ; +t_km1 = 1 ; +L0 = 1 ; +G = A'*A ; +nIter = 0 ; +c = A'*b ; +lambda0 = 0.99*L0*norm(c,inf) ; +eta = 0.6 ; +lambda_bar = 1e-4*lambda0 ; +xk = zeros(n,1) ; +lambda = lambda0 ; +L = L0 ; +beta = 1.5 ; + +% Parse the optional inputs. +if (mod(length(varargin), 2) ~= 0 ), + error(['Extra Parameters passed to the function ''' mfilename ''' lambdast be passed in pairs.']); +end +parameterCount = length(varargin)/2; + +for parameterIndex = 1:parameterCount, + parameterName = varargin{parameterIndex*2 - 1}; + parameterValue = varargin{parameterIndex*2}; + switch lower(parameterName) + case 'stoppingcriterion' + stoppingCriterion = parameterValue; + case 'groundtruth' + xG = parameterValue; + case 'tolerance' + tolerance = parameterValue; + case 'linesearchflag' + lineSearchFlag = parameterValue; + case 'lambda' + lambda_bar = parameterValue; + case 'maxiteration' + maxIter = parameterValue; + case 'isnonnegative' + isNonnegative = parameterValue; + case 'continuationflag' + continuationFlag = parameterValue; + case 'initialization' + xk = parameterValue; + if ~all(size(xk)==[n,1]) + error('The dimension of the initial xk does not match.'); + end + case 'eta' + eta = parameterValue; + if ( eta <= 0 || eta >= 1 ) + disp('Line search parameter out of bounds, switching to default 0.9') ; + eta = 0.9 ; + end + otherwise + error(['The parameter ''' parameterName ''' is not recognized by the function ''' mfilename '''.']); + end +end +clear varargin + +if stoppingCriterion==STOPPING_GROUND_TRUTH && isempty(xG) + error('The stopping criterion must provide the ground truth value of x.'); +end + +keep_going = 1 ; +nz_x = (abs(xk)> eps*10); +f = 0.5*norm(b-A*xk)^2 + lambda_bar * norm(xk,1); +xkm1 = xk; +while keep_going && (nIter < maxIter) + nIter = nIter + 1 ; + + yk = xk + ((t_km1-1)/t_k)*(xk-xkm1) ; + + stop_backtrack = 0 ; + + temp = G*yk - c ; % gradient of f at yk + + while ~stop_backtrack + + gk = yk - (1/L)*temp ; + + xkp1 = soft(gk,lambda/L) ; + + temp1 = 0.5*norm(b-A*xkp1)^2 ; + temp2 = 0.5*norm(b-A*yk)^2 + (xkp1-yk)'*temp + (L/2)*norm(xkp1-yk)^2 ; + + if temp1 <= temp2 + stop_backtrack = 1 ; + else + L = L*beta ; + end + + end + + switch stoppingCriterion + case STOPPING_GROUND_TRUTH + keep_going = norm(xG-xkp1)>tolerance; + case STOPPING_SUBGRADIENT + sk = L*(yk-xkp1) + G*(xkp1-yk) ; + keep_going = norm(sk) > tolerance*L*max(1,norm(xkp1)); + case STOPPING_SPARSE_SUPPORT + % compute the stopping criterion based on the change + % of the number of non-zero components of the estimate + nz_x_prev = nz_x; + nz_x = (abs(xkp1)>eps*10); + num_nz_x = sum(nz_x(:)); + num_changes_active = (sum(nz_x(:)~=nz_x_prev(:))); + if num_nz_x >= 1 + criterionActiveSet = num_changes_active / num_nz_x; + keep_going = (criterionActiveSet > tolerance); + end + case STOPPING_OBJECTIVE_VALUE + % compute the stopping criterion based on the relative + % variation of the objective function. + prev_f = f; + f = 0.5*norm(b-A*xkp1)^2 + lambda_bar * norm(xk,1); + criterionObjective = abs(f-prev_f)/(prev_f); + keep_going = (criterionObjective > tolerance); + case STOPPING_DUALITY_GAP + error('Duality gap is not a valid stopping criterion for PGBP.'); + otherwise + error('Undefined stopping criterion.'); + end + + lambda = max(eta*lambda,lambda_bar) ; + + + t_kp1 = 0.5*(1+sqrt(1+4*t_k*t_k)) ; + + t_km1 = t_k ; + t_k = t_kp1 ; + xkm1 = xk ; + xk = xkp1 ; +end + +x_hat = xk ; + +function y = soft(x,T) +if sum(abs(T(:)))==0 + y = x; +else + y = max(abs(x) - T, 0); + y = sign(x).*y; +end