Mercurial > hg > smallbox
diff DL/RLS-DLA/SMALL_rlsdla.m @ 40:6416fc12f2b8
(none)
author | idamnjanovic |
---|---|
date | Mon, 14 Mar 2011 15:35:24 +0000 |
parents | |
children | 55faa9b5d1ac |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/DL/RLS-DLA/SMALL_rlsdla.m Mon Mar 14 15:35:24 2011 +0000 @@ -0,0 +1,148 @@ +function Dictionary = SMALL_rlsdla(X, params) + + + + + + +CODE_SPARSITY = 1; +CODE_ERROR = 2; + + +% Determine which method will be used for sparse representation step - +% Sparsity or Error mode + +if (isfield(params,'codemode')) + switch lower(params.codemode) + case 'sparsity' + codemode = CODE_SPARSITY; + thresh = params.Tdata; + case 'error' + codemode = CODE_ERROR; + thresh = params.Edata; + + otherwise + error('Invalid coding mode specified'); + end +elseif (isfield(params,'Tdata')) + codemode = CODE_SPARSITY; + thresh = params.Tdata; +elseif (isfield(params,'Edata')) + codemode = CODE_ERROR; + thresh = params.Edata; + +else + error('Data sparse-coding target not specified'); +end + + +% max number of atoms % + +if (codemode==CODE_ERROR && isfield(params,'maxatoms')) + maxatoms = params.maxatoms; +else + maxatoms = -1; +end + + +% Forgetting factor + +if (isfield(params,'forgettingMode')) + switch lower(params.forgettingMode) + case 'fix' + if (isfield(params,'forgettingFactor')) + lambda=params.forgettingFactor; + else + lambda=1; + end + otherwise + error('This mode is still not implemented'); + end +elseif (isfield(params,'forgettingFactor')) + lambda=params.forgettingFactor; +else + lambda=1; +end + +% determine dictionary size % + +if (isfield(params,'initdict')) + if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) + dictsize = length(params.initdict); + else + dictsize = size(params.initdict,2); + end +end +if (isfield(params,'dictsize')) % this superceedes the size determined by initdict + dictsize = params.dictsize; +end + +if (size(X,2) < dictsize) + error('Number of training signals is smaller than number of atoms to train'); +end + + +% initialize the dictionary % + +if (isfield(params,'initdict')) + if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) + D = X(:,params.initdict(1:dictsize)); + else + if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize) + error('Invalid initial dictionary'); + end + D = params.initdict(:,1:dictsize); + end +else + data_ids = find(colnorms_squared(X) > 1e-6); % ensure no zero data elements are chosen + perm = randperm(length(data_ids)); + D = X(:,data_ids(perm(1:dictsize))); +end + + +% normalize the dictionary % + +D = normcols(D); + +% Training data + +data=X; + +% + +C=(100000*thresh)*eye(dictsize); +w=zeros(dictsize,1); +u=zeros(dictsize,1); + + +for i = 1:size(data,2) + + if (codemode == CODE_SPARSITY) + w = ompmex(D,data(:,i),[],[],thresh,1,-1,0); + else + w = omp2mex(D,data(:,i),[],[],[],thresh,0,-1,maxatoms,0); + end + + spind=find(w); + + residual = data(:,i) - D * w; + + if (lambda~=1) + C = C *(1/ lambda); + end + + u = C(:,spind) * w(spind); + + + alfa = 1/(1 + w' * u); + + D = D + (alfa * residual) * u'; + + + C = C - (alfa * u)* u'; + +end + +Dictionary = D; + +end