idamnjanovic@40: function Dictionary = SMALL_rlsdla(X, params) idamnjanovic@40: idamnjanovic@40: idamnjanovic@40: idamnjanovic@40: idamnjanovic@40: idamnjanovic@40: global CODE_SPARSITY CODE_ERROR codemode idamnjanovic@40: global MEM_LOW MEM_NORMAL MEM_HIGH memusage idamnjanovic@40: global ompfunc ompparams exactsvd idamnjanovic@40: idamnjanovic@40: CODE_SPARSITY = 1; idamnjanovic@40: CODE_ERROR = 2; idamnjanovic@40: idamnjanovic@40: MEM_LOW = 1; idamnjanovic@40: MEM_NORMAL = 2; idamnjanovic@40: MEM_HIGH = 3; idamnjanovic@40: idamnjanovic@40: idamnjanovic@40: % p = randperm(size(X,2)); idamnjanovic@40: idamnjanovic@40: % coding mode % idamnjanovic@40: %X_norm=sqrt(sum(X.^2, 1)); idamnjanovic@40: % X_norm_1=sum(abs(X)); idamnjanovic@40: % X_norm_inf=max(abs(X)); idamnjanovic@40: %[X_norm_sort, p]=sort(X_norm);%, 'descend'); idamnjanovic@40: % [X_norm_sort1, p5]=sort(X_norm_1);%, 'descend'); idamnjanovic@40: idamnjanovic@40: % if (isfield(params,'codemode')) idamnjanovic@40: % switch lower(params.codemode) idamnjanovic@40: % case 'sparsity' idamnjanovic@40: % codemode = CODE_SPARSITY; idamnjanovic@40: % thresh = params.Tdata; idamnjanovic@40: % case 'error' idamnjanovic@40: % codemode = CODE_ERROR; idamnjanovic@40: % thresh = params.Edata; idamnjanovic@40: % otherwise idamnjanovic@40: % error('Invalid coding mode specified'); idamnjanovic@40: % end idamnjanovic@40: % elseif (isfield(params,'Tdata')) idamnjanovic@40: % codemode = CODE_SPARSITY; idamnjanovic@40: % thresh = params.Tdata; idamnjanovic@40: % elseif (isfield(params,'Edata')) idamnjanovic@40: % codemode = CODE_ERROR; idamnjanovic@40: % thresh = params.Edata; idamnjanovic@40: % idamnjanovic@40: % else idamnjanovic@40: % error('Data sparse-coding target not specified'); idamnjanovic@40: % end idamnjanovic@40: idamnjanovic@40: thresh = params.Edata; idamnjanovic@40: % max number of atoms % idamnjanovic@40: idamnjanovic@40: % if (codemode==CODE_ERROR && isfield(params,'maxatoms')) idamnjanovic@40: % ompparams{end+1} = 'maxatoms'; idamnjanovic@40: % ompparams{end+1} = params.maxatoms; idamnjanovic@40: % end idamnjanovic@40: idamnjanovic@40: idamnjanovic@40: % memory usage % idamnjanovic@40: idamnjanovic@40: if (isfield(params,'memusage')) idamnjanovic@40: switch lower(params.memusage) idamnjanovic@40: case 'low' idamnjanovic@40: memusage = MEM_LOW; idamnjanovic@40: case 'normal' idamnjanovic@40: memusage = MEM_NORMAL; idamnjanovic@40: case 'high' idamnjanovic@40: memusage = MEM_HIGH; idamnjanovic@40: otherwise idamnjanovic@40: error('Invalid memory usage mode'); idamnjanovic@40: end idamnjanovic@40: else idamnjanovic@40: memusage = MEM_NORMAL; idamnjanovic@40: end idamnjanovic@40: idamnjanovic@40: idamnjanovic@40: % iteration count % idamnjanovic@40: idamnjanovic@40: if (isfield(params,'iternum')) idamnjanovic@40: iternum = params.iternum; idamnjanovic@40: else idamnjanovic@40: iternum = 10; idamnjanovic@40: end idamnjanovic@40: idamnjanovic@40: idamnjanovic@40: % omp function % idamnjanovic@40: idamnjanovic@40: if (codemode == CODE_SPARSITY) idamnjanovic@40: ompfunc = @omp; idamnjanovic@40: else idamnjanovic@40: ompfunc = @omp2; idamnjanovic@40: end idamnjanovic@40: idamnjanovic@40: idamnjanovic@40: % % status messages % idamnjanovic@40: % idamnjanovic@40: % printiter = 0; idamnjanovic@40: % printreplaced = 0; idamnjanovic@40: % printerr = 0; idamnjanovic@40: % printgerr = 0; idamnjanovic@40: % idamnjanovic@40: % verbose = 't'; idamnjanovic@40: % msgdelta = -1; idamnjanovic@40: % idamnjanovic@40: idamnjanovic@40: % idamnjanovic@40: % for i = 1:length(verbose) idamnjanovic@40: % switch lower(verbose(i)) idamnjanovic@40: % case 'i' idamnjanovic@40: % printiter = 1; idamnjanovic@40: % case 'r' idamnjanovic@40: % printiter = 1; idamnjanovic@40: % printreplaced = 1; idamnjanovic@40: % case 't' idamnjanovic@40: % printiter = 1; idamnjanovic@40: % printerr = 1; idamnjanovic@40: % if (isfield(params,'testdata')) idamnjanovic@40: % printgerr = 1; idamnjanovic@40: % end idamnjanovic@40: % end idamnjanovic@40: % end idamnjanovic@40: % idamnjanovic@40: % if (msgdelta<=0 || isempty(verbose)) idamnjanovic@40: % msgdelta = -1; idamnjanovic@40: % end idamnjanovic@40: % idamnjanovic@40: % ompparams{end+1} = 'messages'; idamnjanovic@40: % ompparams{end+1} = msgdelta; idamnjanovic@40: % idamnjanovic@40: % idamnjanovic@40: % idamnjanovic@40: % % compute error flag % idamnjanovic@40: % idamnjanovic@40: % comperr = (nargout>=3 || printerr); idamnjanovic@40: % idamnjanovic@40: % idamnjanovic@40: % % validation flag % idamnjanovic@40: % idamnjanovic@40: % testgen = 0; idamnjanovic@40: % if (isfield(params,'testdata')) idamnjanovic@40: % testdata = params.testdata; idamnjanovic@40: % if (nargout>=4 || printgerr) idamnjanovic@40: % testgen = 1; idamnjanovic@40: % end idamnjanovic@40: % end idamnjanovic@40: idamnjanovic@40: % idamnjanovic@40: % % data norms % idamnjanovic@40: % idamnjanovic@40: % XtX = []; XtXg = []; idamnjanovic@40: % if (codemode==CODE_ERROR && memusage==MEM_HIGH) idamnjanovic@40: % XtX = colnorms_squared(data); idamnjanovic@40: % if (testgen) idamnjanovic@40: % XtXg = colnorms_squared(testdata); idamnjanovic@40: % end idamnjanovic@40: % end idamnjanovic@40: idamnjanovic@40: idamnjanovic@40: % mutual incoherence limit % idamnjanovic@40: idamnjanovic@40: if (isfield(params,'muthresh')) idamnjanovic@40: muthresh = params.muthresh; idamnjanovic@40: else idamnjanovic@40: muthresh = 0.99; idamnjanovic@40: end idamnjanovic@40: if (muthresh < 0) idamnjanovic@40: error('invalid muthresh value, must be non-negative'); idamnjanovic@40: end idamnjanovic@40: idamnjanovic@40: idamnjanovic@40: idamnjanovic@40: idamnjanovic@40: idamnjanovic@40: % determine dictionary size % idamnjanovic@40: idamnjanovic@40: if (isfield(params,'initdict')) idamnjanovic@40: if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) idamnjanovic@40: dictsize = length(params.initdict); idamnjanovic@40: else idamnjanovic@40: dictsize = size(params.initdict,2); idamnjanovic@40: end idamnjanovic@40: end idamnjanovic@40: if (isfield(params,'dictsize')) % this superceedes the size determined by initdict idamnjanovic@40: dictsize = params.dictsize; idamnjanovic@40: end idamnjanovic@40: idamnjanovic@40: if (size(X,2) < dictsize) idamnjanovic@40: error('Number of training signals is smaller than number of atoms to train'); idamnjanovic@40: end idamnjanovic@40: idamnjanovic@40: idamnjanovic@40: % initialize the dictionary % idamnjanovic@40: idamnjanovic@40: if (isfield(params,'initdict')) idamnjanovic@40: if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) idamnjanovic@40: D = X(:,params.initdict(1:dictsize)); idamnjanovic@40: else idamnjanovic@40: if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2) 1e-6); % ensure no zero data elements are chosen idamnjanovic@40: perm = randperm(length(data_ids)); idamnjanovic@40: D = X(:,data_ids(perm(1:dictsize))); idamnjanovic@40: end idamnjanovic@40: idamnjanovic@40: idamnjanovic@40: % normalize the dictionary % idamnjanovic@40: idamnjanovic@40: % D = normcols(D); idamnjanovic@40: % DtD=D'*D; idamnjanovic@40: idamnjanovic@40: err = zeros(1,iternum); idamnjanovic@40: gerr = zeros(1,iternum); idamnjanovic@40: idamnjanovic@40: if (codemode == CODE_SPARSITY) idamnjanovic@40: errstr = 'RMSE'; idamnjanovic@40: else idamnjanovic@40: errstr = 'mean atomnum'; idamnjanovic@40: end idamnjanovic@40: %X(:,p(X_norm_sortthresh); idamnjanovic@40: %p1=p(X_norm_sort>thresh); idamnjanovic@40: % X(:,setxor(p1,1:end))=0; idamnjanovic@40: % X_im=col2imstep(X, [256 256], [8 8]); idamnjanovic@40: % figure(10); imshow(X_im); idamnjanovic@40: % if iternum==2 idamnjanovic@40: % D(:,1)=D(:,2); idamnjanovic@40: % end idamnjanovic@40: %p1=p1(p2(1:40000)); idamnjanovic@40: %end-min(40000, end)+1:end));%1:min(40000, end))); idamnjanovic@40: %p1 = randperm(size(data,2));%size(data,2) idamnjanovic@40: %data=data(:,p1); idamnjanovic@40: idamnjanovic@40: C=(100000*thresh)*eye(dictsize); idamnjanovic@40: % figure(11); idamnjanovic@40: w=zeros(dictsize,1); idamnjanovic@40: replaced=zeros(dictsize,1); idamnjanovic@40: u=zeros(dictsize,1); idamnjanovic@40: % dictimg = showdict(D,[8 8],round(sqrt(size(D,2))),round(sqrt(size(D,2))),'lines','highcontrast'); idamnjanovic@40: % figure(11);imshow(imresize(dictimg,2,'nearest')); idamnjanovic@40: % pause(1); idamnjanovic@40: lambda=0.9997%0.99986;%3+0.0001*params.linc; idamnjanovic@40: for j=1:1 idamnjanovic@40: %data=X; idamnjanovic@40: if size(X,2)>40000 idamnjanovic@40: p2 = randperm(size(X,2)); idamnjanovic@40: idamnjanovic@40: p2=sort(p2(1:40000));%min(floor(size(p1,2)/2),40000))); idamnjanovic@40: size(p2,2) idamnjanovic@40: data=X(:,p2); idamnjanovic@40: elseif size(X,2)>0 idamnjanovic@40: %p2 = randperm(size(p1,2)); idamnjanovic@40: size(X,2) idamnjanovic@40: data=X; idamnjanovic@40: else idamnjanovic@40: break; idamnjanovic@40: end idamnjanovic@40: % figure(1); idamnjanovic@40: % plot(sqrt(sum(data.^2, 1))); idamnjanovic@40: % a=size(data,2)/4; idamnjanovic@40: % lambda0=0.99;%1-16/numS+iternum*0.0001-0.0002 idamnjanovic@40: %C(1,1)=0; idamnjanovic@40: modi=1000; idamnjanovic@40: for i = 1:size(data,2) idamnjanovic@40: % if norm(data(:,i))>thresh idamnjanovic@40: % par.multA= @(x,par) multMatr(D,x); % user function y=Ax idamnjanovic@40: % par.multAt=@(x,par) multMatrAdj(D,x); % user function y=A'*x idamnjanovic@40: % par.y=data(:,i); idamnjanovic@40: % w=SolveFISTA(D,data(:,i),'lambda',0.5*thresh); idamnjanovic@40: % w=sesoptn(zeros(dictsize,1),par.func_u, par.func_x, par.multA, par.multAt,options,par); idamnjanovic@40: %w = SMALL_chol(D,data(:,i), 256,32, thresh);% idamnjanovic@40: %w = sparsecode(data(:,i), D, [], [], thresh); idamnjanovic@40: w = omp2mex(D,data(:,i),[],[],[],thresh,0,-1,-1,0); idamnjanovic@40: idamnjanovic@40: %w(find(w<1))=0; idamnjanovic@40: %^2; idamnjanovic@40: % lambda(i)=1-0.001/(1+i/a); idamnjanovic@40: % if i=35000) idamnjanovic@40: % modi=100; idamnjanovic@40: % pause idamnjanovic@40: % end; idamnjanovic@40: % end idamnjanovic@40: % end idamnjanovic@40: end idamnjanovic@40: %p1=p1(setxor(p2,1:end)); idamnjanovic@40: %[D,cleared_atoms] = cleardict(D,X,muthresh,p1,replaced); idamnjanovic@40: %replaced=zeros(dictsize,1); idamnjanovic@40: % W=sparsecode(data, D, [], [], thresh); idamnjanovic@40: % data=D*W; idamnjanovic@40: %lambda=lambda+0.0002 idamnjanovic@40: end idamnjanovic@40: %Gamma=mexLasso(data, D, param); idamnjanovic@40: %err=compute_err(D,Gamma, data); idamnjanovic@40: %[y,i]=max(err); idamnjanovic@40: %D(:,1)=data(:,i)/norm(data(:,i)); idamnjanovic@40: idamnjanovic@40: Dictionary = D;%D(:,p); idamnjanovic@40: % figure(3); idamnjanovic@40: % plot(lambda); idamnjanovic@40: % mean(lambda); idamnjanovic@40: % figure(4+j);plot(w_sp); idamnjanovic@40: end idamnjanovic@40: idamnjanovic@40: %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% idamnjanovic@40: % sparsecode % idamnjanovic@40: %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% idamnjanovic@40: idamnjanovic@40: function Gamma = sparsecode(data,D,XtX,G,thresh) idamnjanovic@40: idamnjanovic@40: global CODE_SPARSITY codemode idamnjanovic@40: global MEM_HIGH memusage idamnjanovic@40: global ompfunc ompparams idamnjanovic@40: idamnjanovic@40: if (memusage < MEM_HIGH) idamnjanovic@40: Gamma = ompfunc(D,data,G,thresh,ompparams{:}); idamnjanovic@40: idamnjanovic@40: else % memusage is high idamnjanovic@40: idamnjanovic@40: if (codemode == CODE_SPARSITY) idamnjanovic@40: Gamma = ompfunc(D'*data,G,thresh,ompparams{:}); idamnjanovic@40: idamnjanovic@40: else idamnjanovic@40: Gamma = ompfunc(D, data, G, thresh,ompparams{:}); idamnjanovic@40: end idamnjanovic@40: idamnjanovic@40: end idamnjanovic@40: idamnjanovic@40: end idamnjanovic@40: idamnjanovic@40: idamnjanovic@40: %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% idamnjanovic@40: % compute_err % idamnjanovic@40: %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% idamnjanovic@40: idamnjanovic@40: idamnjanovic@40: function err = compute_err(D,Gamma,data) idamnjanovic@40: idamnjanovic@40: global CODE_SPARSITY codemode idamnjanovic@40: idamnjanovic@40: if (codemode == CODE_SPARSITY) idamnjanovic@40: err = sqrt(sum(reperror2(data,D,Gamma))/numel(data)); idamnjanovic@40: else idamnjanovic@40: err = nnz(Gamma)/size(data,2); idamnjanovic@40: end idamnjanovic@40: idamnjanovic@40: end idamnjanovic@40: idamnjanovic@40: idamnjanovic@40: idamnjanovic@40: %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% idamnjanovic@40: % cleardict % idamnjanovic@40: %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% idamnjanovic@40: idamnjanovic@40: idamnjanovic@40: function [D,cleared_atoms] = cleardict(D,X,muthresh,unused_sigs,replaced_atoms) idamnjanovic@40: idamnjanovic@40: use_thresh = 4; % at least this number of samples must use the atom to be kept idamnjanovic@40: idamnjanovic@40: dictsize = size(D,2); idamnjanovic@40: idamnjanovic@40: % compute error in blocks to conserve memory idamnjanovic@40: % err = zeros(1,size(X,2)); idamnjanovic@40: % blocks = [1:3000:size(X,2) size(X,2)+1]; idamnjanovic@40: % for i = 1:length(blocks)-1 idamnjanovic@40: % err(blocks(i):blocks(i+1)-1) = sum((X(:,blocks(i):blocks(i+1)-1)-D*Gamma(:,blocks(i):blocks(i+1)-1)).^2); idamnjanovic@40: % end idamnjanovic@40: idamnjanovic@40: cleared_atoms = 0; idamnjanovic@40: usecount = replaced_atoms;%sum(abs(Gamma)>1e-7, 2); idamnjanovic@40: idamnjanovic@40: for j = 1:dictsize idamnjanovic@40: idamnjanovic@40: % compute G(:,j) idamnjanovic@40: Gj = D'*D(:,j); idamnjanovic@40: Gj(j) = 0; idamnjanovic@40: idamnjanovic@40: % replace atom idamnjanovic@40: if ( (max(Gj.^2)>muthresh^2 || usecount(j)