Mercurial > hg > smallbox
view DL/RLS-DLA/SMALL_rlsdla1.m @ 76:d052ec5b742f
update tags
| field | value |
|---|---|
| author | convert-repo |
| date | Wed, 23 Mar 2011 17:08:55 +0000 |
| parents | 6416fc12f2b8 |
| children | |
line wrap: on
line source
function Dictionary = SMALL_rlsdla1(X, params)
%SMALL_RLSDLA1 Train a pair of dictionaries with RLS-DLA.
%
%   Dictionary = SMALL_rlsdla1(X, params)
%
%   The training signals are the columns of X. Signals whose Euclidean norm
%   is not greater than params.Edata are discarded; the rest are ordered by
%   ascending norm and split into a lower-norm half and a higher-norm half
%   (the last signal is dropped when their count is odd). In each of three
%   passes, up to params.maxsamples signals are drawn at random from each
%   half (kept in ascending-norm order) and used one at a time to update
%   dictionary D1 (lower-norm half) or D2 (higher-norm half) with a
%   recursive-least-squares dictionary update. Each signal is sparse coded
%   with omp2mex using params.Edata as the target. The forgetting factor
%   lambda starts at 0.9998 and is increased by 0.0001 after every pass.
%
%   params fields:
%     Edata      - (required) omp2mex target and signal-norm threshold.
%     initdict   - initial dictionary: either a matrix with size(X,1) rows,
%                  or a vector of column indices into X. If absent, the
%                  atoms are randomly chosen columns of X whose squared
%                  norm exceeds 1e-6.
%     dictsize   - atoms per dictionary; overrides the size implied by
%                  initdict. One of initdict / dictsize must be given.
%     memusage   - 'low' | 'normal' | 'high' (default 'normal'). Stored in
%                  a global that only the sparsecode helper reads.
%     muthresh   - validated to be non-negative (default 0.99); not
%                  otherwise used by this function.
%     maxsamples - signals drawn from each half per pass (default 20000).
%     iternum    - ignored; the number of passes is fixed at 3.
%
%   Returns Dictionary = [D1 D2], a size(X,1)-by-(2*dictsize) matrix.

global CODE_SPARSITY CODE_ERROR codemode
global MEM_LOW MEM_NORMAL MEM_HIGH memusage
global ompfunc ompparams exactsvd

CODE_SPARSITY = 1;
CODE_ERROR = 2;

MEM_LOW = 1;
MEM_NORMAL = 2;
MEM_HIGH = 3;

% coding mode %
% Only error-constrained coding is implemented (omp2mex with params.Edata).
% The global codemode used to be read here without being assigned, so its
% value depended on whichever routine had set it last.
codemode = CODE_ERROR;
ompfunc = @omp2;     % global, read by the sparsecode helper
thresh = params.Edata;

% memory usage %
if (isfield(params,'memusage'))
  switch lower(params.memusage)
    case 'low'
      memusage = MEM_LOW;
    case 'normal'
      memusage = MEM_NORMAL;
    case 'high'
      memusage = MEM_HIGH;
    otherwise
      error('Invalid memory usage mode');
  end
else
  memusage = MEM_NORMAL;
end

% mutual incoherence limit %
% (validated only; no atom clearing is performed by this function)
if (isfield(params,'muthresh'))
  muthresh = params.muthresh;
else
  muthresh = 0.99;
end
if (muthresh < 0)
  error('invalid muthresh value, must be non-negative');
end

% samples drawn from each half per pass %
if (isfield(params,'maxsamples'))
  maxsamples = params.maxsamples;
else
  maxsamples = 20000;
end

% determine dictionary size %
has_initdict = isfield(params,'initdict');
% an initdict that is a vector of whole numbers is a list of column indices into X
initdict_is_index = has_initdict && any(size(params.initdict)==1) && ...
    all(iswhole(params.initdict(:)));
dictsize = [];
if (has_initdict)
  if (initdict_is_index)
    dictsize = length(params.initdict);
  else
    dictsize = size(params.initdict,2);
  end
end
if (isfield(params,'dictsize'))    % this supersedes the size determined by initdict
  dictsize = params.dictsize;
end
if (isempty(dictsize))
  error('Dictionary size not specified: set params.initdict or params.dictsize');
end

if (size(X,2) < dictsize)
  error('Number of training signals is smaller than number of atoms to train');
end

% initialize the dictionaries %
if (has_initdict)
  if (initdict_is_index)
    D1 = X(:,params.initdict(1:dictsize));
  else
    if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize)
      error('Invalid initial dictionary');
    end
    D1 = params.initdict(:,1:dictsize);
  end
else
  data_ids = find(colnorms_squared(X) > 1e-6);   % ensure no zero data elements are chosen
  if (length(data_ids) < dictsize)
    error('Number of non-zero training signals is smaller than number of atoms to train');
  end
  perm = randperm(length(data_ids));
  D1 = X(:,data_ids(perm(1:dictsize)));
end
D2 = D1;   % both dictionaries start from the same atoms

% select training signals %
% p1 holds the column indices of the signals whose norm exceeds thresh,
% in ascending-norm order; its first half feeds D1, its second half D2.
X_norm = sqrt(sum(X.^2, 1));
[X_norm_sort, p] = sort(X_norm);
p1 = p(X_norm_sort > thresh);
half = floor(numel(p1)/2);
nsamp = min(maxsamples, half);

% RLS-DLA passes %
C1 = (100000*thresh)*eye(dictsize);
C2 = (100000*thresh)*eye(dictsize);
lambda = 0.9998;   % forgetting factor
for j = 1:3
  if (nsamp == 0)
    break;   % fewer than two usable signals: nothing to train on
  end
  % random subset of positions within a half, kept in ascending order
  p2 = randperm(half);
  p2 = sort(p2(1:nsamp));
  data1 = X(:,p1(p2));          % from the lower-norm half
  data2 = X(:,p1(half + p2));   % same positions in the higher-norm half

  % NOTE(review): C(1,1) is zeroed at the start of every pass, as in the
  % original code; the intent is not evident from this file.
  C1(1,1) = 0;
  C2(1,1) = 0;

  for i = 1:nsamp
    [D1, C1] = rls_update(D1, C1, data1(:,i), thresh, lambda);
    [D2, C2] = rls_update(D2, C2, data2(:,i), thresh, lambda);
  end

  lambda = lambda + 0.0001;
end

Dictionary = [D1 D2];
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%              rls_update              %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [D, C] = rls_update(D, C, x, thresh, lambda)
% One RLS-DLA step: sparse-code the signal x over D, then update D and C.
%   D      - current dictionary, one atom per column
%   C      - current dictsize-by-dictsize RLS matrix
%   x      - a single training signal (column vector)
%   thresh - target passed to omp2mex
%   lambda - forgetting factor

% The trailing omp2mex arguments are forwarded unchanged from the original call.
w = omp2mex(D,x,[],[],[],thresh,0,-1,-1,0);
spind = find(w);                 % support of the sparse code
residual = x - D * w;            % uses the dictionary before this update
C = C * (1/lambda);              % apply the forgetting factor
u = C(:,spind) * w(spind);       % equals C*w, since w is zero off its support
alfa = 1/(1 + w' * u);
D = D + (alfa * residual) * u';
C = C - (alfa * u) * u';
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%              sparsecode              %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function Gamma = sparsecode(data,D,XtX,G,thresh)
% Sparse-code the columns of data over D with the global ompfunc/ompparams.
% The call form depends on the globals memusage and codemode. XtX is
% accepted but not used. Not called by SMALL_rlsdla1.

global CODE_SPARSITY codemode
global MEM_HIGH memusage
global ompfunc ompparams

if (memusage < MEM_HIGH)
  Gamma = ompfunc(D,data,G,thresh,ompparams{:});
else  % memusage is high
  if (codemode == CODE_SPARSITY)
    Gamma = ompfunc(D'*data,G,thresh,ompparams{:});
  else
    Gamma = ompfunc(D, data, G, thresh,ompparams{:});
  end
end
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%             compute_err              %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function err = compute_err(D,Gamma,data)
% Training error of the representation data ~ D*Gamma.
% If the global codemode equals CODE_SPARSITY, this is the RMSE over all
% entries of data; otherwise it is the mean number of non-zero
% coefficients per signal (column of data).

global CODE_SPARSITY codemode

if (codemode == CODE_SPARSITY)
  err = sqrt(sum(reperror2(data,D,Gamma))/numel(data));
else
  err = nnz(Gamma)/size(data,2);
end
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%              cleardict               %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [D,cleared_atoms] = cleardict(D,X,muthresh,unused_sigs,replaced_atoms)
% Replace selected atoms of D by normalized training signals.
%   D              - dictionary, one atom per column
%   X              - training signals, one per column
%   muthresh       - limit on the inner product between two atoms
%   unused_sigs    - column indices into X available as replacements;
%                    consumed from the end
%   replaced_atoms - per-atom values; used both as usage counts and as a
%                    flag (a non-zero entry keeps the atom)
% Returns the updated D and the number of atoms replaced. Replacement
% stops early once unused_sigs is exhausted.

use_thresh = 4;  % at least this number of samples must use the atom to be kept
dictsize = size(D,2);

cleared_atoms = 0;
% NOTE(review): usage counts are taken from replaced_atoms, so whenever
% replaced_atoms(j)==0 the usage test (0 < use_thresh) already holds and the
% inner-product test has no effect; confirm this is intended.
usecount = replaced_atoms;

for j = 1:dictsize
  % inner products of atom j with every other atom (column j of D'*D)
  Gj = D'*D(:,j);
  Gj(j) = 0;

  % replace atom
  if ( (max(Gj.^2)>muthresh^2 || usecount(j)<use_thresh) && ~replaced_atoms(j) )
    if (isempty(unused_sigs))
      break;   % no replacement signals left
    end
    sig = X(:,unused_sigs(end));
    D(:,j) = sig / norm(sig);
    unused_sigs = unused_sigs(1:end-1);
    cleared_atoms = cleared_atoms+1;
  end
end
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%            misc functions            %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function err2 = reperror2(X,D,Gamma)
% Squared representation error of each column: sum((X - D*Gamma).^2, 1).
% Computed in blocks of columns to conserve memory. The explicit dimension
% keeps the result per-column even when X has a single row.

err2 = zeros(1,size(X,2));
blocksize = 2000;
for i = 1:blocksize:size(X,2)
  blockids = i : min(i+blocksize-1,size(X,2));
  err2(blockids) = sum((X(:,blockids) - D*Gamma(:,blockids)).^2, 1);
end
end


function Y = colnorms_squared(X)
% Squared Euclidean norm of each column of X, as a 1-by-size(X,2) row.
% Computed in blocks of columns to conserve memory. The explicit dimension
% keeps the result per-column even when X has a single row.

Y = zeros(1,size(X,2));
blocksize = 2000;
for i = 1:blocksize:size(X,2)
  blockids = i : min(i+blocksize-1,size(X,2));
  Y(blockids) = sum(X(:,blockids).^2, 1);
end
end